Coverage for src/ipyvizzu/data/converters/spark/converter.py: 100%

48 statements  

coverage.py v7.4.3, created at 2024-02-26 10:12 +0000

1""" 

2This module provides the `SparkDataFrameConverter` class, 

3which allows converting a `pyspark` `DataFrame` 

4into a list of dictionaries representing series. 

5""" 

6 

7from types import ModuleType 

8from typing import Dict, List, Optional, Tuple 

9 

10from ipyvizzu.data.converters.defaults import NAN_DIMENSION, NAN_MEASURE 

11from ipyvizzu.data.converters.df.defaults import MAX_ROWS 

12from ipyvizzu.data.converters.df.converter import DataFrameConverter 

13from ipyvizzu.data.infer_type import InferType 

14from ipyvizzu.data.type_alias import ( 

15 DimensionValue, 

16 MeasureValue, 

17 SeriesValues, 

18) 

19 

20 

21class SparkDataFrameConverter(DataFrameConverter): 

22 """ 

23 Converts a `pyspark` `DataFrame` into a list of dictionaries representing series. 

24 Each dictionary contains information about the series `name`, `values` and `type`. 

25 

26 Parameters: 

27 df: The `pyspark` `DataFrame` to convert. 

28 default_measure_value: 

29 Default value to use for missing measure values. Defaults to 0. 

30 default_dimension_value: 

31 Default value to use for missing dimension values. Defaults to an empty string. 

32 max_rows: The maximum number of rows to include in the converted series list. 

33 If the `df` contains more rows, 

34 a random sample of the given number of rows (approximately) will be taken. 

35 

36 Example: 

37 Get series list from `DataFrame` columns: 

38 

39 converter = SparkDataFrameConverter(df) 

40 series_list = converter.get_series_list() 

41 """ 

42 

43 # pylint: disable=too-few-public-methods 

44 

45 def __init__( 

46 self, 

47 df: "pyspark.sql.DataFrame", # type: ignore 

48 default_measure_value: MeasureValue = NAN_MEASURE, 

49 default_dimension_value: DimensionValue = NAN_DIMENSION, 

50 max_rows: int = MAX_ROWS, 

51 units: Optional[Dict[str, str]] = None, 

52 ) -> None: 

53 # pylint: disable=too-many-arguments 

54 

55 super().__init__( 

56 default_measure_value, default_dimension_value, max_rows, units 

57 ) 

58 self._pyspark, self._pyspark_func = self._get_pyspark() 

59 self._df = self._get_sampled_df(df) 

60 

61 def _get_pyspark(self) -> Tuple[ModuleType, ModuleType]: 

62 try: 

63 import pyspark # pylint: disable=import-outside-toplevel 

64 from pyspark.sql import functions # pylint: disable=import-outside-toplevel 

65 

66 return pyspark, functions 

67 except ImportError as error: 

68 raise ImportError( 

69 "pyspark is not available. Please install pyspark to use this feature." 

70 ) from error 

71 

72 def _get_sampled_df( 

73 self, df: "pyspark.sql.DataFrame" # type: ignore 

74 ) -> "pyspark.sql.DataFrame": # type: ignore 

75 row_number = df.count() 

76 if self._is_max_rows_exceeded(row_number): 

77 fraction = self._max_rows / row_number 

78 sample_df = df.sample(withReplacement=False, fraction=fraction, seed=42) 

79 return sample_df.limit(self._max_rows) 

80 return df 

81 

82 def _get_columns(self) -> List[str]: 

83 return self._df.columns 

84 

85 def _convert_to_series_values_and_type( 

86 self, obj: str 

87 ) -> Tuple[SeriesValues, InferType]: 

88 column_name = obj 

89 column = self._df.select(column_name) 

90 integer_type = self._pyspark.sql.types.IntegerType 

91 double_type = self._pyspark.sql.types.DoubleType 

92 if isinstance(column.schema[column_name].dataType, (integer_type, double_type)): 

93 return self._convert_to_measure_values(column_name), InferType.MEASURE 

94 return self._convert_to_dimension_values(column_name), InferType.DIMENSION 

95 

96 def _convert_to_measure_values(self, obj: str) -> List[MeasureValue]: 

97 column_name = obj 

98 func = self._pyspark_func 

99 df = self._df.withColumn( 

100 column_name, 

101 func.when( 

102 func.col(column_name).isNull(), self._default_measure_value 

103 ).otherwise(func.col(column_name)), 

104 ) 

105 df_rdd = ( 

106 df.withColumn(column_name, func.col(column_name).cast("float")) 

107 .select(column_name) 

108 .rdd 

109 ) 

110 return df_rdd.flatMap(list).collect() 

111 

112 def _convert_to_dimension_values(self, obj: str) -> List[DimensionValue]: 

113 column_name = obj 

114 func = self._pyspark_func 

115 df = self._df.withColumn( 

116 column_name, 

117 func.when( 

118 func.col(column_name).isNull(), self._default_dimension_value 

119 ).otherwise(func.col(column_name)), 

120 ) 

121 df_rdd = ( 

122 df.withColumn(column_name, func.col(column_name).cast("string")) 

123 .select(column_name) 

124 .rdd 

125 ) 

126 return df_rdd.flatMap(list).collect()
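
Below is a minimal usage sketch, not part of the covered module above. It assumes a local pyspark installation; the session name, app name, column names (Category, Value) and sample rows are made up for illustration, and only the constructor and get_series_list() documented in the class docstring are used.

from pyspark.sql import SparkSession

from ipyvizzu.data.converters.spark.converter import SparkDataFrameConverter

# Hypothetical session and data, for illustration only.
spark = SparkSession.builder.appName("ipyvizzu-example").getOrCreate()
df = spark.createDataFrame(
    [("A", 1.0), ("B", None), (None, 3.0)],
    ["Category", "Value"],
)

# Nulls are replaced by the documented defaults (0 for measures, "" for dimensions);
# the DoubleType column is inferred as a measure, the StringType column as a dimension.
converter = SparkDataFrameConverter(df)
series_list = converter.get_series_list()
print(series_list)

spark.stop()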