31
31
from ordered_set import OrderedSet
32
32
33
33
from datacompy .base import BaseCompare
34
- from datacompy .spark .sql import decimal_comparator
35
34
36
35
LOG = logging .getLogger (__name__ )
37
36
51
50
trim ,
52
51
when ,
53
52
)
53
+ from snowflake .snowpark .types import (
54
+ ByteType ,
55
+ DateType ,
56
+ DecimalType ,
57
+ DoubleType ,
58
+ FloatType ,
59
+ IntegerType ,
60
+ LongType ,
61
+ ShortType ,
62
+ StringType ,
63
+ TimestampType ,
64
+ )
65
+
66
+ NUMERIC_SNOWPARK_TYPES = [
67
+ ByteType ,
68
+ ShortType ,
69
+ IntegerType ,
70
+ LongType ,
71
+ FloatType ,
72
+ DoubleType ,
73
+ DecimalType ,
74
+ ]
75
+
54
76
55
77
except ImportError :
56
78
LOG .warning (
59
81
)
60
82
61
83
62
- NUMERIC_SNOWPARK_TYPES = [
63
- "tinyint" ,
64
- "smallint" ,
65
- "int" ,
66
- "bigint" ,
67
- "float" ,
68
- "double" ,
69
- decimal_comparator (),
70
- ]
71
-
72
-
73
84
class SnowflakeCompare (BaseCompare ):
74
85
"""Comparison class to be used to compare whether two Snowpark dataframes are equal.
75
86
@@ -490,17 +501,18 @@ def _calculate_column_compare_stats(self, column: str) -> None:
490
501
match_rate = 0
491
502
LOG .info (f"{ column } : { match_cnt } / { row_cnt } ({ match_rate :.2%} ) match" )
492
503
493
- col1_dtype , _ = _get_column_dtypes (self .df1 , column , column )
494
- col2_dtype , _ = _get_column_dtypes (self .df2 , column , column )
504
+ col1_dtype_instance , _ = _get_column_dtypes (self .df1 , column , column )
505
+ col2_dtype_instance , _ = _get_column_dtypes (self .df2 , column , column )
506
+ col1_dtype , col2_dtype = type (col1_dtype_instance ), type (col2_dtype_instance )
495
507
496
508
self .column_stats .append (
497
509
{
498
510
"column" : column ,
499
511
"match_column" : col_match ,
500
512
"match_cnt" : match_cnt ,
501
513
"unequal_cnt" : row_cnt - match_cnt ,
502
- "dtype1" : str ( col1_dtype ),
503
- "dtype2" : str ( col2_dtype ),
514
+ "dtype1" : col1_dtype_instance . simple_string ( ),
515
+ "dtype2" : col2_dtype_instance . simple_string ( ),
504
516
"all_match" : all (
505
517
(
506
518
col1_dtype == col2_dtype ,
@@ -995,7 +1007,10 @@ def columns_equal(
995
1007
A column of boolean values are added. True == the values match, False == the
996
1008
values don't match.
997
1009
"""
998
- base_dtype , compare_dtype = _get_column_dtypes (dataframe , col_1 , col_2 )
1010
+ base_dtype_instance , compare_dtype_instance = _get_column_dtypes (
1011
+ dataframe , col_1 , col_2
1012
+ )
1013
+ base_dtype , compare_dtype = type (base_dtype_instance ), type (compare_dtype_instance )
999
1014
if _is_comparable (base_dtype , compare_dtype ):
1000
1015
if (base_dtype in NUMERIC_SNOWPARK_TYPES ) and (
1001
1016
compare_dtype in NUMERIC_SNOWPARK_TYPES
@@ -1028,7 +1043,7 @@ def columns_equal(
1028
1043
)
1029
1044
else :
1030
1045
LOG .debug (
1031
- f"Skipping { col_1 } ({ base_dtype } ) and { col_2 } ({ compare_dtype } ), columns are not comparable"
1046
+ f"Skipping { col_1 } ({ base_dtype_instance . simple_string () } ) and { col_2 } ({ compare_dtype_instance . simple_string () } ), columns are not comparable"
1032
1047
)
1033
1048
dataframe = dataframe .withColumn (col_match , lit (False ))
1034
1049
return dataframe
@@ -1217,8 +1232,14 @@ def _get_column_dtypes(
1217
1232
Tuple(str, str)
1218
1233
Tuple of base and compare datatype
1219
1234
"""
1220
- base_dtype = next (d [1 ] for d in dataframe .dtypes if d [0 ] == col_1 )
1221
- compare_dtype = next (d [1 ] for d in dataframe .dtypes if d [0 ] == col_2 )
1235
+ df_raw_dtypes = [
1236
+ (name , field .datatype )
1237
+ for name , field in zip (
1238
+ dataframe .schema .names , dataframe .schema .fields , strict = False
1239
+ )
1240
+ ]
1241
+ base_dtype = next (d [1 ] for d in df_raw_dtypes if d [0 ] == col_1 )
1242
+ compare_dtype = next (d [1 ] for d in df_raw_dtypes if d [0 ] == col_2 )
1222
1243
return base_dtype , compare_dtype
1223
1244
1224
1245
@@ -1244,10 +1265,10 @@ def _is_comparable(type1: str, type2: str) -> bool:
1244
1265
return (
1245
1266
type1 == type2
1246
1267
or (type1 in NUMERIC_SNOWPARK_TYPES and type2 in NUMERIC_SNOWPARK_TYPES )
1247
- or ("string" in type1 and type2 == "date" )
1248
- or (type1 == "date" and "string" in type2 )
1249
- or ("string" in type1 and type2 == "timestamp" )
1250
- or (type1 == "timestamp" and "string" in type2 )
1268
+ or (type1 == StringType and type2 == DateType )
1269
+ or (type1 == DateType and type2 == StringType )
1270
+ or (type1 == StringType and type2 == TimestampType )
1271
+ or (type1 == TimestampType and type2 == StringType )
1251
1272
)
1252
1273
1253
1274
0 commit comments