Skip to content

Commit 9dc9ee8

Browse files
fdosanimparaz
andauthored
feat: add per-column tolerance support for comparison operations (#426)
* feat: add per-column tolerance validation and comparison functionality for pandas * Snowflake per-column tolerances (#417) * Add Snowflake per-column tolerances. * Use the original abs_tol and rel_tol parameters but take both the float and the dict options. * Remove docstring for removed parameters. * Add tests for _all_mismatch, fix ruff lint. * Change report format to include per-column tolerances and change the row summary to default. * Ruff formatting. * Remove breakpoint() . * Use orig_col_name without _MATCH as the key for the tolerances. --------- Co-authored-by: Faisal <[email protected]> * feat: enhance report template with default tolerances and remove unused snowflake row summary template * refactor tolerance validation and enhance column tolerance retrieval functionality * implement per-column tolerance validation and comparison in PolarsCompare * enhance SparkSQLCompare to support per-column absolute and relative tolerances * make relative and absolute tolerances available in fugue report output * small cleanup and improvements * add tests for get_column_tolerance --------- Co-authored-by: Miguel Paraz <[email protected]>
1 parent a0a145e commit 9dc9ee8

File tree

14 files changed

+928
-95
lines changed

14 files changed

+928
-95
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ pip install datacompy[snowflake]
4646

4747
### LegacySparkCompare and SparkPandasCompare removal
4848

49-
With version ``v0.17.0`` the ``LegacySparkCompare`` and ``SparkPandasCompare`` have been removed.
49+
Starting with v0.17.0, both `LegacySparkCompare` and `SparkPandasCompare` have been removed.
5050

5151

52-
#### Supported versions and dependncies
52+
#### Supported versions and dependencies
5353

5454
Different versions of Spark, Pandas, and Python interact differently. Below is a matrix of what we test with.
5555
With the move to the Pandas on Spark API and compatibility issues with Pandas 2+, we will for the time being not support Pandas 2

datacompy/base.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import logging
2525
from abc import ABC, abstractmethod
2626
from pathlib import Path
27-
from typing import Any
27+
from typing import Any, Dict
2828

2929
from jinja2 import Environment, FileSystemLoader, select_autoescape
3030
from ordered_set import OrderedSet
@@ -348,3 +348,98 @@ def df_to_str(df: Any, sample_count: int | None = None, on_index: bool = False)
348348

349349
# Fallback to str() if we can't determine the type
350350
return str(df)
351+
352+
353+
def get_column_tolerance(column: str, tol_dict: Dict[str, float]) -> float:
    """Look up the tolerance to apply for a given column.

    Parameters
    ----------
    column : str
        Column name whose tolerance is requested.
    tol_dict : dict of str to float
        Mapping of column names to tolerance values. A "default" key,
        if present, supplies the fallback for columns not listed.

    Returns
    -------
    float
        The column's tolerance, the "default" tolerance when the column
        is absent, or 0.0 when neither key exists.
    """
    # Resolve the fallback first so a missing column and a missing
    # "default" both degrade gracefully to 0.0.
    fallback = tol_dict.get("default", 0.0)
    return tol_dict.get(column, fallback)
372+
373+
374+
def _validate_tolerance_parameter(
375+
param_value: float | Dict[str, float],
376+
param_name: str,
377+
case_mode: str = "lower",
378+
) -> Dict[str, float]:
379+
"""Validate and normalize tolerance parameter input.
380+
381+
Parameters
382+
----------
383+
param_value : float or dict
384+
The tolerance value to validate. Can be either a float or a dictionary mapping
385+
column names to float values.
386+
param_name : str
387+
Name of the parameter being validated ('abs_tol' or 'rel_tol')
388+
case_mode : str
389+
How to handle column name case. Options are:
390+
- "lower": convert to lowercase
391+
- "upper": convert to uppercase
392+
- "preserve": keep original case
393+
394+
Returns
395+
-------
396+
dict
397+
Normalized dictionary of tolerance values
398+
399+
Raises
400+
------
401+
TypeError
402+
If param_value is not a float or dict
403+
ValueError
404+
If any tolerance values are not numeric or negative or if case_mode is invalid
405+
"""
406+
if case_mode not in ["lower", "upper", "preserve"]:
407+
raise ValueError("case_mode must be 'lower', 'upper', or 'preserve'")
408+
409+
# If float, convert to dict with default value
410+
if isinstance(param_value, int | float):
411+
if param_value < 0:
412+
raise ValueError(f"{param_name} cannot be negative")
413+
return {"default": float(param_value)}
414+
415+
# If dict, validate values and format
416+
if isinstance(param_value, dict):
417+
result = {}
418+
419+
# Convert all values to float and validate
420+
for col, value in param_value.items():
421+
if not isinstance(value, int | float):
422+
raise ValueError(
423+
f"Value for column '{col}' in {param_name} must be numeric"
424+
)
425+
if value < 0:
426+
raise ValueError(
427+
f"Value for column '{col}' in {param_name} cannot be negative"
428+
)
429+
430+
# Handle column name case according to case_mode
431+
col_key = str(col)
432+
if case_mode == "lower":
433+
col_key = col_key.lower()
434+
elif case_mode == "upper":
435+
col_key = col_key.upper()
436+
437+
result[col_key] = float(value)
438+
439+
# If no default provided, add 0.0
440+
if "default" not in result:
441+
result["default"] = 0.0
442+
443+
return result
444+
445+
raise TypeError(f"{param_name} must be a float or dictionary")

datacompy/core.py

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030

3131
from datacompy.base import (
3232
BaseCompare,
33+
_validate_tolerance_parameter,
3334
df_to_str,
35+
get_column_tolerance,
3436
render,
3537
save_html_report,
3638
temp_column_name,
@@ -59,10 +61,14 @@ class Compare(BaseCompare):
5961
If True, the index will be used to join the two dataframes. If both
6062
``join_columns`` and ``on_index`` are provided, an exception will be
6163
raised.
62-
abs_tol : float, optional
63-
Absolute tolerance between two values.
64-
rel_tol : float, optional
65-
Relative tolerance between two values.
64+
abs_tol : float or dict, optional
65+
Absolute tolerance between two values. Can be either a float value applied to all columns,
66+
or a dictionary mapping column names to specific tolerance values. The special key "default"
67+
in the dictionary specifies the tolerance for columns not explicitly listed.
68+
rel_tol : float or dict, optional
69+
Relative tolerance between two values. Can be either a float value applied to all columns,
70+
or a dictionary mapping column names to specific tolerance values. The special key "default"
71+
in the dictionary specifies the tolerance for columns not explicitly listed.
6672
df1_name : str, optional
6773
A string name for the first dataframe. This allows the reporting to
6874
print out an actual name instead of "df1", and allows human users to
@@ -91,15 +97,24 @@ def __init__(
9197
df2: pd.DataFrame,
9298
join_columns: List[str] | str | None = None,
9399
on_index: bool = False,
94-
abs_tol: float = 0,
95-
rel_tol: float = 0,
100+
abs_tol: float | Dict[str, float] = 0,
101+
rel_tol: float | Dict[str, float] = 0,
96102
df1_name: str = "df1",
97103
df2_name: str = "df2",
98104
ignore_spaces: bool = False,
99105
ignore_case: bool = False,
100106
cast_column_names_lower: bool = True,
101107
) -> None:
102108
self.cast_column_names_lower = cast_column_names_lower
109+
110+
# Validate tolerance parameters first
111+
self._abs_tol_dict = _validate_tolerance_parameter(
112+
abs_tol, "abs_tol", "lower" if cast_column_names_lower else "preserve"
113+
)
114+
self._rel_tol_dict = _validate_tolerance_parameter(
115+
rel_tol, "rel_tol", "lower" if cast_column_names_lower else "preserve"
116+
)
117+
103118
if on_index and join_columns is not None:
104119
raise Exception("Only provide on_index or join_columns")
105120
elif on_index:
@@ -369,12 +384,12 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
369384
[
370385
self.intersect_rows,
371386
columns_equal(
372-
self.intersect_rows[col_1],
373-
self.intersect_rows[col_2],
374-
self.rel_tol,
375-
self.abs_tol,
376-
ignore_spaces,
377-
ignore_case,
387+
col_1=self.intersect_rows[col_1],
388+
col_2=self.intersect_rows[col_2],
389+
rel_tol=get_column_tolerance(column, self._rel_tol_dict),
390+
abs_tol=get_column_tolerance(column, self._abs_tol_dict),
391+
ignore_spaces=ignore_spaces,
392+
ignore_case=ignore_case,
378393
).to_frame(name=col_match),
379394
],
380395
axis=1,
@@ -414,6 +429,8 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
414429
),
415430
"max_diff": max_diff,
416431
"null_diff": null_diff,
432+
"rel_tol": get_column_tolerance(column, self._rel_tol_dict),
433+
"abs_tol": get_column_tolerance(column, self._abs_tol_dict),
417434
}
418435
)
419436

@@ -589,12 +606,12 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> pd.DataFrame:
589606
orig_col_name = col[:-6]
590607

591608
col_comparison = columns_equal(
592-
self.intersect_rows[orig_col_name + "_" + self.df1_name],
593-
self.intersect_rows[orig_col_name + "_" + self.df2_name],
594-
self.rel_tol,
595-
self.abs_tol,
596-
self.ignore_spaces,
597-
self.ignore_case,
609+
col_1=self.intersect_rows[orig_col_name + "_" + self.df1_name],
610+
col_2=self.intersect_rows[orig_col_name + "_" + self.df2_name],
611+
rel_tol=get_column_tolerance(orig_col_name, self._rel_tol_dict),
612+
abs_tol=get_column_tolerance(orig_col_name, self._abs_tol_dict),
613+
ignore_spaces=self.ignore_spaces,
614+
ignore_case=self.ignore_case,
598615
)
599616

600617
if not ignore_matching_cols or (
@@ -717,6 +734,8 @@ def _get_mismatch_stats(self, sample_count: int) -> dict:
717734
"unequal_cnt": column["unequal_cnt"],
718735
"max_diff": column["max_diff"],
719736
"null_diff": column["null_diff"],
737+
"rel_tol": column["rel_tol"],
738+
"abs_tol": column["abs_tol"],
720739
}
721740
)
722741
if column["unequal_cnt"] > 0:
@@ -961,6 +980,7 @@ def columns_equal(
961980
col_2, ignore_spaces=ignore_spaces, ignore_case=ignore_case
962981
)
963982

983+
# Rest of comparison logic using rel_tol and abs_tol
964984
# short circuit if comparing mixed type columns. Check list/arrays or just return false for everything else.
965985
if pd.api.types.infer_dtype(col_1).startswith("mixed") or pd.api.types.infer_dtype(
966986
col_2

datacompy/fugue.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,8 @@ def _any(col: str) -> int:
625625
"unequal_cnt": col["unequal_cnt"],
626626
"max_diff": col["max_diff"],
627627
"null_diff": col["null_diff"],
628+
"rel_tol": rel_tol,
629+
"abs_tol": abs_tol,
628630
}
629631
for col in column_stats
630632
if not col["all_match"]

datacompy/polars.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131

3232
from datacompy.base import (
3333
BaseCompare,
34+
_validate_tolerance_parameter,
3435
df_to_str,
36+
get_column_tolerance,
3537
render,
3638
save_html_report,
3739
temp_column_name,
@@ -59,10 +61,13 @@ class PolarsCompare(BaseCompare):
5961
join_columns : list or str
6062
Column(s) to join dataframes on. If a string is passed in, that one
6163
column will be used.
62-
abs_tol : float, optional
63-
Absolute tolerance between two values.
64-
rel_tol : float, optional
65-
Relative tolerance between two values.
64+
abs_tol : float or dict, optional
65+
Absolute tolerance between two values. Can be either a float value applied to all columns,
66+
or a dictionary mapping column names to specific tolerance values. The special key "default"
67+
in the dictionary specifies the tolerance for columns not explicitly listed.
68+
rel_tol : float or dict, optional
69+
Relative tolerance between two values. Can be either a float value applied to all columns,
70+
or a dictionary mapping column names to specific tolerance values. The special key "default"
in the dictionary specifies the tolerance for columns not explicitly listed.
6671
df1_name : str, optional
6772
A string name for the first dataframe. This allows the reporting to
6873
print out an actual name instead of "df1", and allows human users to
@@ -90,8 +95,8 @@ def __init__(
9095
df1: pl.DataFrame,
9196
df2: pl.DataFrame,
9297
join_columns: List[str] | str,
93-
abs_tol: float = 0,
94-
rel_tol: float = 0,
98+
abs_tol: float | Dict[str, float] = 0,
99+
rel_tol: float | Dict[str, float] = 0,
95100
df1_name: str = "df1",
96101
df2_name: str = "df2",
97102
ignore_spaces: bool = False,
@@ -100,6 +105,14 @@ def __init__(
100105
) -> None:
101106
self.cast_column_names_lower = cast_column_names_lower
102107

108+
# Validate tolerance parameters first
109+
self._abs_tol_dict = _validate_tolerance_parameter(
110+
abs_tol, "abs_tol", "lower" if cast_column_names_lower else "preserve"
111+
)
112+
self._rel_tol_dict = _validate_tolerance_parameter(
113+
rel_tol, "rel_tol", "lower" if cast_column_names_lower else "preserve"
114+
)
115+
103116
if isinstance(join_columns, str):
104117
self.join_columns = [
105118
str(join_columns).lower()
@@ -371,12 +384,12 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
371384
col_match = column + "_match"
372385
self.intersect_rows = self.intersect_rows.with_columns(
373386
columns_equal(
374-
self.intersect_rows[col_1],
375-
self.intersect_rows[col_2],
376-
self.rel_tol,
377-
self.abs_tol,
378-
ignore_spaces,
379-
ignore_case,
387+
col_1=self.intersect_rows[col_1],
388+
col_2=self.intersect_rows[col_2],
389+
rel_tol=get_column_tolerance(column, self._rel_tol_dict),
390+
abs_tol=get_column_tolerance(column, self._abs_tol_dict),
391+
ignore_spaces=ignore_spaces,
392+
ignore_case=ignore_case,
380393
).alias(col_match)
381394
)
382395
match_cnt = self.intersect_rows[col_match].sum()
@@ -409,6 +422,8 @@ def _intersect_compare(self, ignore_spaces: bool, ignore_case: bool) -> None:
409422
),
410423
"max_diff": max_diff,
411424
"null_diff": null_diff,
425+
"rel_tol": get_column_tolerance(column, self._rel_tol_dict),
426+
"abs_tol": get_column_tolerance(column, self._abs_tol_dict),
412427
}
413428
)
414429

@@ -588,12 +603,12 @@ def all_mismatch(self, ignore_matching_cols: bool = False) -> pl.DataFrame:
588603
orig_col_name = col[:-6]
589604

590605
col_comparison = columns_equal(
591-
self.intersect_rows[orig_col_name + "_" + self.df1_name],
592-
self.intersect_rows[orig_col_name + "_" + self.df2_name],
593-
self.rel_tol,
594-
self.abs_tol,
595-
self.ignore_spaces,
596-
self.ignore_case,
606+
col_1=self.intersect_rows[orig_col_name + "_" + self.df1_name],
607+
col_2=self.intersect_rows[orig_col_name + "_" + self.df2_name],
608+
rel_tol=get_column_tolerance(orig_col_name, self._rel_tol_dict),
609+
abs_tol=get_column_tolerance(orig_col_name, self._abs_tol_dict),
610+
ignore_spaces=self.ignore_spaces,
611+
ignore_case=self.ignore_case,
597612
)
598613

599614
if not ignore_matching_cols or (
@@ -717,6 +732,8 @@ def _get_mismatch_stats(self, sample_count: int) -> dict:
717732
"unequal_cnt": column["unequal_cnt"],
718733
"max_diff": column["max_diff"],
719734
"null_diff": column["null_diff"],
735+
"rel_tol": column["rel_tol"],
736+
"abs_tol": column["abs_tol"],
720737
}
721738
)
722739
if column["unequal_cnt"] > 0:
@@ -768,7 +785,9 @@ def _get_unique_rows_data(self, sample_count: int, column_count: int) -> dict:
768785
"df1_unique_rows": {
769786
"has_rows": min_sample_count_df1 > 0,
770787
"rows": df_to_str(
771-
self.df1_unq_rows[:, :min_column_count_df1],
788+
self.df1_unq_rows.select(
789+
self.df1_unq_rows.columns[:min_column_count_df1]
790+
),
772791
sample_count=min_sample_count_df1,
773792
)
774793
if self.df1_unq_rows.shape[0] > 0
@@ -780,7 +799,9 @@ def _get_unique_rows_data(self, sample_count: int, column_count: int) -> dict:
780799
"df2_unique_rows": {
781800
"has_rows": min_sample_count_df2 > 0,
782801
"rows": df_to_str(
783-
self.df2_unq_rows[:, :min_column_count_df2],
802+
self.df2_unq_rows.select(
803+
self.df2_unq_rows.columns[:min_column_count_df2]
804+
),
784805
sample_count=min_sample_count_df2,
785806
)
786807
if self.df2_unq_rows.shape[0] > 0

0 commit comments

Comments
 (0)