Coverage for src/ipyvizzu/data/converters/spark/converter.py: 100%
48 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-10-12 08:13 +0000
1"""
2This module provides the `SparkDataFrameConverter` class,
3which allows converting a `pyspark` `DataFrame`
4into a list of dictionaries representing series.
5"""
7from types import ModuleType
8from typing import List, Tuple
10from ipyvizzu.data.converters.defaults import NAN_DIMENSION, NAN_MEASURE
11from ipyvizzu.data.converters.df.defaults import MAX_ROWS
12from ipyvizzu.data.converters.df.converter import DataFrameConverter
13from ipyvizzu.data.infer_type import InferType
14from ipyvizzu.data.type_alias import (
15 DimensionValue,
16 MeasureValue,
17 SeriesValues,
18)
class SparkDataFrameConverter(DataFrameConverter):
    """
    Converts a `pyspark` `DataFrame` into a list of dictionaries representing series.
    Each dictionary contains information about the series `name`, `values` and `type`.

    Parameters:
        df: The `pyspark` `DataFrame` to convert.
        default_measure_value:
            Default value to use for missing measure values. Defaults to 0.
        default_dimension_value:
            Default value to use for missing dimension values. Defaults to an empty string.
        max_rows: The maximum number of rows to include in the converted series list.
            If the `df` contains more rows,
            a random sample of the given number of rows (approximately) will be taken.

    Example:
        Get series list from `DataFrame` columns:

            converter = SparkDataFrameConverter(df)
            series_list = converter.get_series_list()
    """

    # pylint: disable=too-few-public-methods

    def __init__(
        self,
        df: "pyspark.sql.DataFrame",  # type: ignore
        default_measure_value: MeasureValue = NAN_MEASURE,
        default_dimension_value: DimensionValue = NAN_DIMENSION,
        max_rows: int = MAX_ROWS,
    ) -> None:
        super().__init__(default_measure_value, default_dimension_value, max_rows)
        # Import lazily so pyspark is only required when this converter is used.
        self._pyspark, self._pyspark_func = self._get_pyspark()
        self._df = self._get_sampled_df(df)

    def _get_pyspark(self) -> Tuple[ModuleType, ModuleType]:
        """
        Return the `pyspark` module and its `sql.functions` submodule.

        Raises:
            ImportError: If `pyspark` is not installed.
        """
        try:
            import pyspark  # pylint: disable=import-outside-toplevel
            from pyspark.sql import functions  # pylint: disable=import-outside-toplevel

            return pyspark, functions
        except ImportError as error:
            raise ImportError(
                "pyspark is not available. Please install pyspark to use this feature."
            ) from error

    def _get_sampled_df(
        self, df: "pyspark.sql.DataFrame"  # type: ignore
    ) -> "pyspark.sql.DataFrame":  # type: ignore
        """
        Return `df` downsampled to at most `self._max_rows` rows.

        `DataFrame.sample` returns an *approximate* fraction of rows, so the
        result is additionally capped with `limit` to enforce the maximum.
        """
        row_number = df.count()
        if self._is_max_rows_exceeded(row_number):
            fraction = self._max_rows / row_number
            # Fixed seed keeps the sampled chart reproducible across runs.
            sample_df = df.sample(withReplacement=False, fraction=fraction, seed=42)
            return sample_df.limit(self._max_rows)
        return df

    def _get_columns(self) -> List[str]:
        """Return the column names of the underlying `DataFrame`."""
        return self._df.columns

    def _convert_to_series_values_and_type(
        self, obj: str
    ) -> Tuple[SeriesValues, InferType]:
        """
        Return the values of column `obj` and their inferred series type.

        Any numeric Spark SQL column is treated as a measure; everything else
        as a dimension.
        """
        column_name = obj
        column = self._df.select(column_name)
        # NumericType is the base class of all numeric Spark SQL types
        # (ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
        # DecimalType). Checking only IntegerType/DoubleType — as done
        # previously — misclassified LongType columns (the default type for
        # Python ints) as dimensions.
        numeric_type = self._pyspark.sql.types.NumericType
        if isinstance(column.schema[column_name].dataType, numeric_type):
            return self._convert_to_measure_values(column_name), InferType.MEASURE
        return self._convert_to_dimension_values(column_name), InferType.DIMENSION

    def _convert_to_measure_values(self, obj: str) -> List[MeasureValue]:
        """
        Return column `obj` as floats, replacing nulls with the default
        measure value.
        """
        column_name = obj
        func = self._pyspark_func
        df = self._df.withColumn(
            column_name,
            func.when(
                func.col(column_name).isNull(), self._default_measure_value
            ).otherwise(func.col(column_name)),
        )
        df_rdd = (
            df.withColumn(column_name, func.col(column_name).cast("float"))
            .select(column_name)
            .rdd
        )
        # Each Row holds a single field; flatMap(list) unwraps it to a scalar.
        return df_rdd.flatMap(list).collect()

    def _convert_to_dimension_values(self, obj: str) -> List[DimensionValue]:
        """
        Return column `obj` as strings, replacing nulls with the default
        dimension value.
        """
        column_name = obj
        func = self._pyspark_func
        df = self._df.withColumn(
            column_name,
            func.when(
                func.col(column_name).isNull(), self._default_dimension_value
            ).otherwise(func.col(column_name)),
        )
        df_rdd = (
            df.withColumn(column_name, func.col(column_name).cast("string"))
            .select(column_name)
            .rdd
        )
        # Each Row holds a single field; flatMap(list) unwraps it to a scalar.
        return df_rdd.flatMap(list).collect()