Skip to content

Commit ee838fa

Browse files
authored
Merge pull request #413 from capitalone/develop
2 parents 8208f20 + 49d01a3 commit ee838fa

File tree

5 files changed

+81
-26
lines changed

5 files changed

+81
-26
lines changed

datacompy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
Then extended to carry that functionality over to Spark Dataframes.
1919
"""
2020

21-
__version__ = "0.16.6"
21+
__version__ = "0.16.7"
2222

2323
import platform
2424
from warnings import warn

datacompy/polars.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1014,7 +1014,7 @@ def generate_id_within_group(
10141014
dataframe[join_columns]
10151015
.cast(pl.String)
10161016
.fill_null(default_value)
1017-
.select(rn=pl.col(dataframe.columns[0]).cum_count().over(join_columns))
1017+
.select(rn=pl.col(join_columns[0]).cum_count().over(join_columns))
10181018
.to_series()
10191019
)
10201020
else:

datacompy/snowflake.py

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from ordered_set import OrderedSet
3232

3333
from datacompy.base import BaseCompare
34-
from datacompy.spark.sql import decimal_comparator
3534

3635
LOG = logging.getLogger(__name__)
3736

@@ -51,6 +50,29 @@
5150
trim,
5251
when,
5352
)
53+
from snowflake.snowpark.types import (
54+
ByteType,
55+
DateType,
56+
DecimalType,
57+
DoubleType,
58+
FloatType,
59+
IntegerType,
60+
LongType,
61+
ShortType,
62+
StringType,
63+
TimestampType,
64+
)
65+
66+
NUMERIC_SNOWPARK_TYPES = [
67+
ByteType,
68+
ShortType,
69+
IntegerType,
70+
LongType,
71+
FloatType,
72+
DoubleType,
73+
DecimalType,
74+
]
75+
5476

5577
except ImportError:
5678
LOG.warning(
@@ -59,17 +81,6 @@
5981
)
6082

6183

62-
NUMERIC_SNOWPARK_TYPES = [
63-
"tinyint",
64-
"smallint",
65-
"int",
66-
"bigint",
67-
"float",
68-
"double",
69-
decimal_comparator(),
70-
]
71-
72-
7384
class SnowflakeCompare(BaseCompare):
7485
"""Comparison class to be used to compare whether two Snowpark dataframes are equal.
7586
@@ -490,17 +501,18 @@ def _calculate_column_compare_stats(self, column: str) -> None:
490501
match_rate = 0
491502
LOG.info(f"{column}: {match_cnt} / {row_cnt} ({match_rate:.2%}) match")
492503

493-
col1_dtype, _ = _get_column_dtypes(self.df1, column, column)
494-
col2_dtype, _ = _get_column_dtypes(self.df2, column, column)
504+
col1_dtype_instance, _ = _get_column_dtypes(self.df1, column, column)
505+
col2_dtype_instance, _ = _get_column_dtypes(self.df2, column, column)
506+
col1_dtype, col2_dtype = type(col1_dtype_instance), type(col2_dtype_instance)
495507

496508
self.column_stats.append(
497509
{
498510
"column": column,
499511
"match_column": col_match,
500512
"match_cnt": match_cnt,
501513
"unequal_cnt": row_cnt - match_cnt,
502-
"dtype1": str(col1_dtype),
503-
"dtype2": str(col2_dtype),
514+
"dtype1": col1_dtype_instance.simple_string(),
515+
"dtype2": col2_dtype_instance.simple_string(),
504516
"all_match": all(
505517
(
506518
col1_dtype == col2_dtype,
@@ -995,7 +1007,10 @@ def columns_equal(
9951007
A column of boolean values are added. True == the values match, False == the
9961008
values don't match.
9971009
"""
998-
base_dtype, compare_dtype = _get_column_dtypes(dataframe, col_1, col_2)
1010+
base_dtype_instance, compare_dtype_instance = _get_column_dtypes(
1011+
dataframe, col_1, col_2
1012+
)
1013+
base_dtype, compare_dtype = type(base_dtype_instance), type(compare_dtype_instance)
9991014
if _is_comparable(base_dtype, compare_dtype):
10001015
if (base_dtype in NUMERIC_SNOWPARK_TYPES) and (
10011016
compare_dtype in NUMERIC_SNOWPARK_TYPES
@@ -1028,7 +1043,7 @@ def columns_equal(
10281043
)
10291044
else:
10301045
LOG.debug(
1031-
f"Skipping {col_1}({base_dtype}) and {col_2}({compare_dtype}), columns are not comparable"
1046+
f"Skipping {col_1}({base_dtype_instance.simple_string()}) and {col_2}({compare_dtype_instance.simple_string()}), columns are not comparable"
10321047
)
10331048
dataframe = dataframe.withColumn(col_match, lit(False))
10341049
return dataframe
@@ -1217,8 +1232,14 @@ def _get_column_dtypes(
12171232
Tuple(str, str)
12181233
Tuple of base and compare datatype
12191234
"""
1220-
base_dtype = next(d[1] for d in dataframe.dtypes if d[0] == col_1)
1221-
compare_dtype = next(d[1] for d in dataframe.dtypes if d[0] == col_2)
1235+
df_raw_dtypes = [
1236+
(name, field.datatype)
1237+
for name, field in zip(
1238+
dataframe.schema.names, dataframe.schema.fields, strict=False
1239+
)
1240+
]
1241+
base_dtype = next(d[1] for d in df_raw_dtypes if d[0] == col_1)
1242+
compare_dtype = next(d[1] for d in df_raw_dtypes if d[0] == col_2)
12221243
return base_dtype, compare_dtype
12231244

12241245

@@ -1244,10 +1265,10 @@ def _is_comparable(type1: str, type2: str) -> bool:
12441265
return (
12451266
type1 == type2
12461267
or (type1 in NUMERIC_SNOWPARK_TYPES and type2 in NUMERIC_SNOWPARK_TYPES)
1247-
or ("string" in type1 and type2 == "date")
1248-
or (type1 == "date" and "string" in type2)
1249-
or ("string" in type1 and type2 == "timestamp")
1250-
or (type1 == "timestamp" and "string" in type2)
1268+
or (type1 == StringType and type2 == DateType)
1269+
or (type1 == DateType and type2 == StringType)
1270+
or (type1 == StringType and type2 == TimestampType)
1271+
or (type1 == TimestampType and type2 == StringType)
12511272
)
12521273

12531274

docs/source/img/benchmarks.png

3.57 KB
Loading

tests/test_snowflake.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@
4343
)
4444
from pandas.testing import assert_frame_equal, assert_series_equal
4545
from snowflake.snowpark.exceptions import SnowparkSQLException
46+
from snowflake.snowpark.types import (
47+
DecimalType,
48+
StringType,
49+
StructField,
50+
StructType,
51+
)
4652

4753
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
4854

@@ -189,6 +195,34 @@ def test_date_columns_equal_with_ignore_spaces(snowpark_session):
189195
assert_series_equal(expect_out, actual_out_rev, check_names=False)
190196

191197

198+
def test_columns_equal_same_type_dif_length(snowpark_session):
199+
schema = StructType(
200+
[
201+
StructField("NAME", StringType(length=20)),
202+
StructField("DECIMAL_VAL", DecimalType(precision=7, scale=5)),
203+
StructField("NAME_COPY", StringType(is_max_size=True)),
204+
StructField("DECIMAL_VAL_COPY", DecimalType(precision=20, scale=10)),
205+
]
206+
)
207+
data = [
208+
["Alice", 10.44556, "Alice", 10.44556],
209+
["Bob", 2.33445, "Bob", 2.33445],
210+
["Charlie", 5.2234, "Charlie", 5.2234],
211+
]
212+
213+
df = snowpark_session.create_dataframe(data, schema=schema)
214+
assert (
215+
columns_equal(df, "NAME", "NAME_COPY", "NAME_ACTUAL")
216+
.toPandas()["NAME_ACTUAL"]
217+
.all()
218+
)
219+
assert (
220+
columns_equal(df, "DECIMAL_VAL", "DECIMAL_VAL_COPY", "DECIMAL_VAL_ACTUAL")
221+
.toPandas()["DECIMAL_VAL_ACTUAL"]
222+
.all()
223+
)
224+
225+
192226
def test_date_columns_unequal(snowpark_session):
193227
"""I want datetime fields to match with dates stored as strings"""
194228
data = [{"A": "2017-01-01", "B": "2017-01-02"}, {"A": "2017-01-01"}]

0 commit comments

Comments
 (0)