benchmarks/mimiciv/discharge_prediction.ipynb (2 additions, 2 deletions)
@@ -1182,9 +1182,9 @@
     "# Reformatting the fairness metrics\n",
     "fairness_results = copy.deepcopy(results[\"fairness\"])\n",
     "fairness_metrics = {}\n",
-    "# remove the group size from the fairness results and add it to the slice name\n",
+    "# remove the sample_size from the fairness results and add it to the slice name\n",
     "for slice_name, slice_results in fairness_results.items():\n",
-    "    group_size = slice_results.pop(\"Group Size\")\n",
+    "    group_size = slice_results.pop(\"sample_size\")\n",
     "    fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results"
    ]
   },
benchmarks/mimiciv/icu_mortality_prediction.ipynb (2 additions, 2 deletions)
@@ -1159,9 +1159,9 @@
     "# Reformatting the fairness metrics\n",
     "fairness_results = copy.deepcopy(results[\"fairness\"])\n",
     "fairness_metrics = {}\n",
-    "# remove the group size from the fairness results and add it to the slice name\n",
+    "# remove the sample_size from the fairness results and add it to the slice name\n",
     "for slice_name, slice_results in fairness_results.items():\n",
-    "    group_size = slice_results.pop(\"Group Size\")\n",
+    "    group_size = slice_results.pop(\"sample_size\")\n",
     "    fairness_metrics[f\"{slice_name} (Size={group_size})\"] = slice_results"
    ]
   },
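The two notebook cells above reshape the fairness results after the key rename. Below is a minimal, self-contained sketch of that reshaping; the structure of `results["fairness"]`, the slice names, and the metric values are assumptions for illustration (the real dictionary is produced by cyclops' fairness evaluator):

```python
import copy

# Assumed shape of the fairness results; slice names and values are illustrative only.
results = {
    "fairness": {
        "age:[20 - 50)": {"BinaryAccuracy": 0.91, "sample_size": 1200},
        "age:[50 - 80)": {"BinaryAccuracy": 0.88, "sample_size": 950},
    },
}

fairness_results = copy.deepcopy(results["fairness"])
fairness_metrics = {}
# remove the sample_size from the fairness results and add it to the slice name
for slice_name, slice_results in fairness_results.items():
    group_size = slice_results.pop("sample_size")
    fairness_metrics[f"{slice_name} (Size={group_size})"] = slice_results

print(fairness_metrics)
# {'age:[20 - 50) (Size=1200)': {'BinaryAccuracy': 0.91},
#  'age:[50 - 80) (Size=950)': {'BinaryAccuracy': 0.88}}
```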
cyclops/evaluate/evaluator.py (1 addition, 0 deletions)
@@ -311,6 +311,7 @@ def _compute_metrics(
             model_name: str = "model_for_%s" % prediction_column
             results.setdefault(model_name, {})
             results[model_name][slice_name] = metric_output
+            results[model_name][slice_name]["sample_size"] = len(sliced_dataset)

     set_decode(dataset, True)  # restore decoding features

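The one-line addition above attaches the slice's row count to each per-slice metric dictionary. A tiny sketch of the resulting structure, with placeholder names standing in for the real objects inside `_compute_metrics`:

```python
# Placeholders for the real objects inside _compute_metrics (assumptions for illustration).
results: dict = {}
prediction_column = "preds_prob"
model_name = "model_for_%s" % prediction_column
slice_name = "overall"
metric_output = {"BinaryAccuracy": 0.87}  # assumed metric output for this slice
sliced_dataset = range(5000)              # stands in for the sliced dataset

results.setdefault(model_name, {})
results[model_name][slice_name] = metric_output
results[model_name][slice_name]["sample_size"] = len(sliced_dataset)

assert results == {
    "model_for_preds_prob": {"overall": {"BinaryAccuracy": 0.87, "sample_size": 5000}},
}
```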
cyclops/evaluate/fairness/evaluator.py (2 additions, 2 deletions)
@@ -260,7 +260,7 @@ def evaluate_fairness( # noqa: PLR0912
         for prediction_column in fmt_prediction_columns:
             results.setdefault(prediction_column, {})
             results[prediction_column].setdefault(slice_name, {}).update(
-                {"Group Size": len(sliced_dataset)},
+                {"sample_size": len(sliced_dataset)},
             )

             pred_result = _get_metric_results_for_prediction_and_slice(
@@ -966,7 +966,7 @@ def _compute_parity_metrics(
         parity_results[key] = {}
         for slice_name, slice_result in prediction_result.items():
             for metric_name, metric_value in slice_result.items():
-                if metric_name == "Group Size":
+                if metric_name == "sample_size":
                     continue

                 # add 'Parity' to the metric name before @threshold, if specified
cyclops/report/model_card/fields.py (10 additions, 0 deletions)
@@ -380,6 +380,11 @@ class PerformanceMetric(
         default_factory=list,
     )

+    sample_size: Optional[StrictInt] = Field(
+        None,
+        description="The sample size used to compute this metric.",
+    )
+

 class User(
     BaseModelCardField,
@@ -599,6 +604,11 @@ class MetricCard(
         description="Timestamps for each point in the history.",
     )

+    sample_sizes: Optional[List[int]] = Field(
+        None,
+        description="Sample sizes for each point in the history.",
+    )
+

 class MetricCardCollection(BaseModelCardField, composable_with="Overview"):
     """A collection of metric cards to be displayed in the model card."""
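Below is a stripped-down stand-in for the two models extended above, only to show how the new optional fields behave; the real classes in `cyclops.report.model_card.fields` carry many more fields and validators than this sketch. Note that `StrictInt` rejects non-integer sample sizes instead of coercing them.

```python
from typing import List, Optional

from pydantic import BaseModel, Field, StrictInt, ValidationError


class PerformanceMetricSketch(BaseModel):
    """Stand-in for PerformanceMetric, reduced to the new field."""

    sample_size: Optional[StrictInt] = Field(
        None,
        description="The sample size used to compute this metric.",
    )


class MetricCardSketch(BaseModel):
    """Stand-in for MetricCard, reduced to the new field."""

    sample_sizes: Optional[List[int]] = Field(
        None,
        description="Sample sizes for each point in the history.",
    )


print(PerformanceMetricSketch().sample_size)                  # None when not logged
print(PerformanceMetricSketch(sample_size=5000).sample_size)  # 5000
print(MetricCardSketch(sample_sizes=[4800, 4900, 5000]).sample_sizes)

try:
    PerformanceMetricSketch(sample_size=5000.0)               # float is rejected, not coerced
except ValidationError as err:
    print(err)
```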
cyclops/report/report.py (62 additions, 30 deletions)
@@ -48,6 +48,7 @@
     get_histories,
     get_names,
     get_passed,
+    get_sample_sizes,
     get_slices,
     get_thresholds,
     get_timestamps,
@@ -855,6 +856,7 @@ def log_quantitative_analysis(
         pass_fail_threshold_fns: Optional[
             Union[Callable[[Any, float], bool], List[Callable[[Any, float], bool]]]
         ] = None,
+        sample_size: Optional[int] = None,
         **extra: Any,
     ) -> None:
         """Add a quantitative analysis to the report.
@@ -921,6 +923,7 @@ def log_quantitative_analysis(
             "slice": metric_slice,
             "decision_threshold": decision_threshold,
             "description": description,
+            "sample_size": sample_size,
             **extra,
         }

@@ -958,42 +961,70 @@ def log_quantitative_analysis(
             field_type=field_type,
         )

-    def log_performance_metrics(self, metrics: Dict[str, Any]) -> None:
-        """Add a performance metric to the `Quantitative Analysis` section.
+    def log_performance_metrics(
+        self,
+        results: Dict[str, Any],
+        metric_descriptions: Dict[str, str],
+        pass_fail_thresholds: Union[float, Dict[str, float]] = 0.7,
+        pass_fail_threshold_fn: Callable[[float, float], bool] = lambda x,
+        threshold: bool(x >= threshold),
+    ) -> None:
+        """
+        Log all performance metrics to the model card report.

         Parameters
         ----------
-        metrics : Dict[str, Any]
-            A dictionary of performance metrics. The keys should be the name of the
-            metric, and the values should be the value of the metric. If the metric
-            is a slice metric, the key should be the slice name followed by a slash
-            and then the metric name (e.g. "slice_name/metric_name"). If no slice
-            name is provided, the slice name will be "overall".
-
-        Raises
-        ------
-        TypeError
-            If the given metrics are not a dictionary with string keys.
+        results : Dict[str, Any]
+            Dictionary containing the results,
+            with keys in the format "split/metric_name".
+        metric_descriptions : Dict[str, str]
+            Dictionary mapping metric names to their descriptions.
+        pass_fail_thresholds : Union[float, Dict[str, float]], optional
+            The threshold(s) for pass/fail tests.
+            Can be a single float applied to all metrics,
+            or a dictionary mapping "split/metric_name" to individual thresholds.
+            Default is 0.7.
+        pass_fail_threshold_fn : Callable[[float, float], bool], optional
+            Function to determine if a metric passes or fails.
+            Default is lambda x, threshold: bool(x >= threshold).
+
+        Returns
+        -------
+        None
         """
-        _raise_if_not_dict_with_str_keys(metrics)
-        for metric_name, metric_value in metrics.items():
-            name_split = metric_name.split("/")
-            if len(name_split) == 1:
-                slice_name = "overall"
-                metric_name = name_split[0]  # noqa: PLW2901
-            else:  # everything before the last slash is the slice name
-                slice_name = "/".join(name_split[:-1])
-                metric_name = name_split[-1]  # noqa: PLW2901
-
-            # TODO: create plot
+        # Extract sample sizes
+        sample_sizes = {
+            key.split("/")[0]: value
+            for key, value in results.items()
+            if "sample_size" in key.split("/")[1]
+        }

-            self._log_field(
-                data={"type": metric_name, "value": metric_value, "slice": slice_name},
-                section_name="quantitative_analysis",
-                field_name="performance_metrics",
-                field_type=PerformanceMetric,
-            )
+        # Log metrics
+        for name, metric in results.items():
+            split, metric_name = name.split("/")
+            if metric_name != "sample_size":
+                metric_value = metric.tolist() if hasattr(metric, "tolist") else metric
+
+                # Determine the threshold for this specific metric
+                if isinstance(pass_fail_thresholds, dict):
+                    threshold = pass_fail_thresholds.get(
+                        name, 0.7
+                    )  # Default to 0.7 if not specified
+                else:
+                    threshold = pass_fail_thresholds
+
+                self.log_quantitative_analysis(
+                    "performance",
+                    name=metric_name,
+                    value=metric_value,
+                    description=metric_descriptions.get(
+                        metric_name, "No description provided."
+                    ),
+                    metric_slice=split,
+                    pass_fail_thresholds=threshold,
+                    pass_fail_threshold_fns=pass_fail_threshold_fn,
+                    sample_size=sample_sizes.get(split),
+                )

     # TODO: MERGE/COMPARE MODEL CARDS

@@ -1162,6 +1193,7 @@ def export(
             "get_names": get_names,
             "get_histories": get_histories,
             "get_timestamps": get_timestamps,
+            "get_sample_sizes": get_sample_sizes,
         }
         template.globals.update(func_dict)

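A hedged usage sketch of the reworked `log_performance_metrics`, based only on the signature and docstring in this diff. The metric names, values, and per-metric threshold are illustrative, and `ModelCardReport` is assumed to be the class in `cyclops.report` that exposes this method.

```python
from cyclops.report import ModelCardReport  # assumed import path for the report class

report = ModelCardReport()

# Keys follow the "split/metric_name" convention; the "<split>/sample_size" entry is
# extracted and attached to every metric logged for that split.
results = {
    "test/BinaryAccuracy": 0.87,
    "test/BinaryF1Score": 0.81,
    "test/sample_size": 5000,
}

metric_descriptions = {
    "BinaryAccuracy": "Fraction of correctly classified samples.",
    "BinaryF1Score": "Harmonic mean of precision and recall.",
}

report.log_performance_metrics(
    results,
    metric_descriptions,
    # Per-metric threshold for accuracy; anything not listed falls back to 0.7.
    pass_fail_thresholds={"test/BinaryAccuracy": 0.8},
)
```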