Coverage for src/ipyvizzu/data/converters/spark/converter.py: 100%

48 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-12 08:13 +0000

1""" 

2This module provides the `SparkDataFrameConverter` class, 

3which allows converting a `pyspark` `DataFrame` 

4into a list of dictionaries representing series. 

5""" 

6 

7from types import ModuleType 

8from typing import List, Tuple 

9 

10from ipyvizzu.data.converters.defaults import NAN_DIMENSION, NAN_MEASURE 

11from ipyvizzu.data.converters.df.defaults import MAX_ROWS 

12from ipyvizzu.data.converters.df.converter import DataFrameConverter 

13from ipyvizzu.data.infer_type import InferType 

14from ipyvizzu.data.type_alias import ( 

15 DimensionValue, 

16 MeasureValue, 

17 SeriesValues, 

18) 

19 

20 

class SparkDataFrameConverter(DataFrameConverter):
    """
    Converts a `pyspark` `DataFrame` into a list of dictionaries representing series.
    Each dictionary contains information about the series `name`, `values` and `type`.

    Parameters:
        df: The `pyspark` `DataFrame` to convert.
        default_measure_value:
            Default value to use for missing measure values. Defaults to 0.
        default_dimension_value:
            Default value to use for missing dimension values. Defaults to an empty string.
        max_rows: The maximum number of rows to include in the converted series list.
            If the `df` contains more rows,
            a random sample of the given number of rows (approximately) will be taken.

    Example:
        Get series list from `DataFrame` columns:

            converter = SparkDataFrameConverter(df)
            series_list = converter.get_series_list()
    """

    # pylint: disable=too-few-public-methods

    def __init__(
        self,
        df: "pyspark.sql.DataFrame",  # type: ignore
        default_measure_value: MeasureValue = NAN_MEASURE,
        default_dimension_value: DimensionValue = NAN_DIMENSION,
        max_rows: int = MAX_ROWS,
    ) -> None:
        super().__init__(default_measure_value, default_dimension_value, max_rows)
        # pyspark is imported lazily so the rest of the package works without it.
        self._pyspark, self._pyspark_func = self._get_pyspark()
        self._df = self._get_sampled_df(df)

    def _get_pyspark(self) -> Tuple[ModuleType, ModuleType]:
        """
        Import `pyspark` and `pyspark.sql.functions` lazily.

        Returns:
            The `pyspark` module and the `pyspark.sql.functions` module.

        Raises:
            ImportError: If `pyspark` is not installed.
        """
        try:
            import pyspark  # pylint: disable=import-outside-toplevel
            from pyspark.sql import functions  # pylint: disable=import-outside-toplevel

            return pyspark, functions
        except ImportError as error:
            raise ImportError(
                "pyspark is not available. Please install pyspark to use this feature."
            ) from error

    def _get_sampled_df(
        self, df: "pyspark.sql.DataFrame"  # type: ignore
    ) -> "pyspark.sql.DataFrame":  # type: ignore
        """
        Downsample `df` to at most `self._max_rows` rows.

        Spark's `sample` is approximate, so the result is additionally capped
        with `limit`. A fixed seed keeps the sample reproducible across calls.
        """
        row_number = df.count()
        if self._is_max_rows_exceeded(row_number):
            fraction = self._max_rows / row_number
            sample_df = df.sample(withReplacement=False, fraction=fraction, seed=42)
            return sample_df.limit(self._max_rows)
        return df

    def _get_columns(self) -> List[str]:
        """Return the column names of the underlying `DataFrame`."""
        return self._df.columns

    def _convert_to_series_values_and_type(
        self, obj: str
    ) -> Tuple[SeriesValues, InferType]:
        """
        Convert the named column into its values and inferred series type.

        Columns of `IntegerType` or `DoubleType` are measures; everything else
        is a dimension.
        """
        column_name = obj
        column = self._df.select(column_name)
        # NOTE(review): only IntegerType/DoubleType are classified as measures.
        # Spark commonly infers LongType/FloatType for numeric columns, which
        # would be treated as dimensions here — confirm this is intended.
        integer_type = self._pyspark.sql.types.IntegerType
        double_type = self._pyspark.sql.types.DoubleType
        if isinstance(column.schema[column_name].dataType, (integer_type, double_type)):
            return self._convert_to_measure_values(column_name), InferType.MEASURE
        return self._convert_to_dimension_values(column_name), InferType.DIMENSION

    def _fill_na_and_collect(
        self,
        column_name: str,
        default_value: "MeasureValue | DimensionValue",
        cast_to: str,
    ) -> list:
        """
        Replace nulls in `column_name` with `default_value`, cast the column
        to `cast_to`, and collect its values as a flat Python list.

        Shared implementation for measure and dimension conversion.
        """
        func = self._pyspark_func
        filled = self._df.withColumn(
            column_name,
            func.when(
                func.col(column_name).isNull(), default_value
            ).otherwise(func.col(column_name)),
        )
        df_rdd = (
            filled.withColumn(column_name, func.col(column_name).cast(cast_to))
            .select(column_name)
            .rdd
        )
        return df_rdd.flatMap(list).collect()

    def _convert_to_measure_values(self, obj: str) -> List[MeasureValue]:
        """Return the column's values as floats, nulls replaced by the measure default."""
        return self._fill_na_and_collect(obj, self._default_measure_value, "float")

    def _convert_to_dimension_values(self, obj: str) -> List[DimensionValue]:
        """Return the column's values as strings, nulls replaced by the dimension default."""
        return self._fill_na_and_collect(obj, self._default_dimension_value, "string")