Coverage for src/ipyvizzu/data/converters/spark/converter.py: 100%
48 statements
coverage.py v7.4.3, created at 2024-02-26 10:12 +0000
1"""
2This module provides the `SparkDataFrameConverter` class,
3which allows converting a `pyspark` `DataFrame`
4into a list of dictionaries representing series.
5"""
7from types import ModuleType
8from typing import Dict, List, Optional, Tuple
10from ipyvizzu.data.converters.defaults import NAN_DIMENSION, NAN_MEASURE
11from ipyvizzu.data.converters.df.defaults import MAX_ROWS
12from ipyvizzu.data.converters.df.converter import DataFrameConverter
13from ipyvizzu.data.infer_type import InferType
14from ipyvizzu.data.type_alias import (
15 DimensionValue,
16 MeasureValue,
17 SeriesValues,
18)

class SparkDataFrameConverter(DataFrameConverter):
    """
    Converts a `pyspark` `DataFrame` into a list of dictionaries representing series.
    Each dictionary contains information about the series `name`, `values` and `type`.

    Parameters:
        df: The `pyspark` `DataFrame` to convert.
        default_measure_value:
            Default value to use for missing measure values. Defaults to 0.
        default_dimension_value:
            Default value to use for missing dimension values. Defaults to an empty string.
        max_rows: The maximum number of rows to include in the converted series list.
            If the `df` contains more rows,
            a random sample of approximately that many rows will be taken.

    Example:
        Get series list from `DataFrame` columns:

            converter = SparkDataFrameConverter(df)
            series_list = converter.get_series_list()
    """

    # pylint: disable=too-few-public-methods
    def __init__(
        self,
        df: "pyspark.sql.DataFrame",  # type: ignore
        default_measure_value: MeasureValue = NAN_MEASURE,
        default_dimension_value: DimensionValue = NAN_DIMENSION,
        max_rows: int = MAX_ROWS,
        units: Optional[Dict[str, str]] = None,
    ) -> None:
        # pylint: disable=too-many-arguments

        super().__init__(
            default_measure_value, default_dimension_value, max_rows, units
        )
        self._pyspark, self._pyspark_func = self._get_pyspark()
        self._df = self._get_sampled_df(df)
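    # The pyspark modules are resolved at runtime inside the method below, so
    # pyspark stays an optional dependency and a clear ImportError with an
    # install hint is raised only when this converter is actually used.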
    def _get_pyspark(self) -> Tuple[ModuleType, ModuleType]:
        try:
            import pyspark  # pylint: disable=import-outside-toplevel
            from pyspark.sql import functions  # pylint: disable=import-outside-toplevel

            return pyspark, functions
        except ImportError as error:
            raise ImportError(
                "pyspark is not available. Please install pyspark to use this feature."
            ) from error
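    # Down-samples the DataFrame when it has more rows than `max_rows`.
    # Fraction-based `sample()` is only approximate, so the result is also
    # capped with `limit()` to guarantee at most `max_rows` rows.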
    def _get_sampled_df(
        self, df: "pyspark.sql.DataFrame"  # type: ignore
    ) -> "pyspark.sql.DataFrame":  # type: ignore
        row_number = df.count()
        if self._is_max_rows_exceeded(row_number):
            fraction = self._max_rows / row_number
            sample_df = df.sample(withReplacement=False, fraction=fraction, seed=42)
            return sample_df.limit(self._max_rows)
        return df
    def _get_columns(self) -> List[str]:
        return self._df.columns
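    # Columns with an IntegerType or DoubleType schema are converted as
    # measures; every other column type is converted as a dimension.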
    def _convert_to_series_values_and_type(
        self, obj: str
    ) -> Tuple[SeriesValues, InferType]:
        column_name = obj
        column = self._df.select(column_name)
        integer_type = self._pyspark.sql.types.IntegerType
        double_type = self._pyspark.sql.types.DoubleType
        if isinstance(column.schema[column_name].dataType, (integer_type, double_type)):
            return self._convert_to_measure_values(column_name), InferType.MEASURE
        return self._convert_to_dimension_values(column_name), InferType.DIMENSION
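    # Replaces nulls with the default measure value, casts the column to
    # float, and collects the values through the column's RDD.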
    def _convert_to_measure_values(self, obj: str) -> List[MeasureValue]:
        column_name = obj
        func = self._pyspark_func
        df = self._df.withColumn(
            column_name,
            func.when(
                func.col(column_name).isNull(), self._default_measure_value
            ).otherwise(func.col(column_name)),
        )
        df_rdd = (
            df.withColumn(column_name, func.col(column_name).cast("float"))
            .select(column_name)
            .rdd
        )
        return df_rdd.flatMap(list).collect()
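    # Mirror of the measure conversion: nulls become the default dimension
    # value and the column is cast to string before being collected.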
    def _convert_to_dimension_values(self, obj: str) -> List[DimensionValue]:
        column_name = obj
        func = self._pyspark_func
        df = self._df.withColumn(
            column_name,
            func.when(
                func.col(column_name).isNull(), self._default_dimension_value
            ).otherwise(func.col(column_name)),
        )
        df_rdd = (
            df.withColumn(column_name, func.col(column_name).cast("string"))
            .select(column_name)
            .rdd
        )
        return df_rdd.flatMap(list).collect()
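Usage sketch (not part of the covered module): a minimal end-to-end run, assuming pyspark is installed and a local SparkSession can be started; the session settings and demo DataFrame are illustrative only. `get_series_list` is inherited from `DataFrameConverter`, as shown in the class docstring's example.

from pyspark.sql import SparkSession

from ipyvizzu.data.converters.spark.converter import SparkDataFrameConverter

# Illustrative local session and demo data; any pyspark DataFrame works.
spark = SparkSession.builder.master("local[1]").appName("demo").getOrCreate()
df = spark.createDataFrame(
    [("Fruit", 10.0), ("Vegetable", None)],
    schema=["Category", "Amount"],
)

# The DoubleType "Amount" column becomes a measure, "Category" a dimension;
# the missing "Amount" value is filled with the default measure value (0).
converter = SparkDataFrameConverter(df)
series_list = converter.get_series_list()
print(series_list)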