From ecc528cdad4a87bf3cdae91e9fd2d828063e1364 Mon Sep 17 00:00:00 2001 From: Franklin Ogidi <41602287+fcogidi@users.noreply.github.com> Date: Wed, 15 May 2024 16:34:03 -0400 Subject: [PATCH 1/2] fixes for slicer and evaluators --- cyclops/data/slicer.py | 14 ++++++++++++++ cyclops/evaluate/evaluator.py | 6 ++++-- cyclops/evaluate/fairness/evaluator.py | 3 +++ cyclops/evaluate/metrics/experimental/metric.py | 1 + docs/source/tutorials/synthea/los_prediction.ipynb | 10 ---------- .../evaluate/metrics/experimental/test_metric.py | 4 ++-- 6 files changed, 24 insertions(+), 14 deletions(-) diff --git a/cyclops/data/slicer.py b/cyclops/data/slicer.py index a17047db1..cc4487405 100644 --- a/cyclops/data/slicer.py +++ b/cyclops/data/slicer.py @@ -3,6 +3,7 @@ import copy import datetime import itertools +import json from dataclasses import dataclass, field from functools import partial from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union @@ -248,6 +249,19 @@ def _create_intersections(self) -> None: ) self.spec_list.extend(intersect_list) + # remove duplicates + seen = set() + result = [] + + for spec in self.spec_list: + spec_str = json.dumps(spec, sort_keys=True) + if spec_str not in seen: + seen.add(spec_str) + result.append(spec) + + seen.clear() + self.spec_list = result + def _parse_and_register_slice_specs( self, slice_spec: Dict[str, Dict[str, Any]], diff --git a/cyclops/evaluate/evaluator.py b/cyclops/evaluate/evaluator.py index d883ed133..bfeb54bb7 100644 --- a/cyclops/evaluate/evaluator.py +++ b/cyclops/evaluate/evaluator.py @@ -152,7 +152,9 @@ def evaluate( fairness_config.batch_size = batch_size fairness_config.remove_columns = ignore_columns - fairness_results = evaluate_fairness(**asdict(fairness_config)) + fairness_results = evaluate_fairness( + **asdict(fairness_config), array_lib=array_lib + ) results["fairness"] = fairness_results return results @@ -304,7 +306,7 @@ def _compute_metrics( metrics.update(targets, predictions) metric_output = metrics.compute() - metrics.reset() + metrics.reset() model_name: str = "model_for_%s" % prediction_column results.setdefault(model_name, {}) diff --git a/cyclops/evaluate/fairness/evaluator.py b/cyclops/evaluate/fairness/evaluator.py index 6c41f7454..9b25773bf 100644 --- a/cyclops/evaluate/fairness/evaluator.py +++ b/cyclops/evaluate/fairness/evaluator.py @@ -728,6 +728,9 @@ def _compute_metrics( # noqa: C901, PLR0912 The batch size to use for the computation. metric_name : Optional[str] The name of the metric to compute. + array_lib : {"torch", "numpy, "cupy"}, default="numpy" + The array library to use for the metric computation. The metric results + will be returned in the format of `array_lib`. Returns ------- diff --git a/cyclops/evaluate/metrics/experimental/metric.py b/cyclops/evaluate/metrics/experimental/metric.py index 07279e989..ea1ef398b 100644 --- a/cyclops/evaluate/metrics/experimental/metric.py +++ b/cyclops/evaluate/metrics/experimental/metric.py @@ -431,6 +431,7 @@ def reset(self) -> None: "object or a list of array API objects. But got " f"`{type(default_value)} instead.", ) + self._defaults = {} self._update_count = 0 self._computed = None diff --git a/docs/source/tutorials/synthea/los_prediction.ipynb b/docs/source/tutorials/synthea/los_prediction.ipynb index 4dc0bf923..53d4e4304 100644 --- a/docs/source/tutorials/synthea/los_prediction.ipynb +++ b/docs/source/tutorials/synthea/los_prediction.ipynb @@ -1068,16 +1068,6 @@ ")" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "172a1654", - "metadata": {}, - "outputs": [], - "source": [ - "results" - ] - }, { "cell_type": "markdown", "id": "7d2d1d75-f7d8-44d3-a782-2aba9a4fbac0", diff --git a/tests/cyclops/evaluate/metrics/experimental/test_metric.py b/tests/cyclops/evaluate/metrics/experimental/test_metric.py index 205fed5a0..ff94db0d2 100644 --- a/tests/cyclops/evaluate/metrics/experimental/test_metric.py +++ b/tests/cyclops/evaluate/metrics/experimental/test_metric.py @@ -331,7 +331,7 @@ def test_reset_compute(): anp.asarray(42, dtype=anp.float32), ) metric.reset() - assert metric.state_vars == {"x": anp.asarray(0, dtype=anp.float32)} + assert metric.state_vars == {} def test_error_on_compute_before_update(): @@ -397,7 +397,7 @@ def test_call(): assert metric._computed is None metric.reset() - assert metric.state_vars == {"x": anp.asarray(0, dtype=anp.float32)} + assert metric.state_vars == {} assert metric._computed is None From c7e9d13c9d92af72af782701a4fe3345b509324a Mon Sep 17 00:00:00 2001 From: Franklin Ogidi <41602287+fcogidi@users.noreply.github.com> Date: Wed, 15 May 2024 22:49:15 -0400 Subject: [PATCH 2/2] fix state reference creation in `MetricDict` --- cyclops/evaluate/metrics/experimental/metric_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cyclops/evaluate/metrics/experimental/metric_dict.py b/cyclops/evaluate/metrics/experimental/metric_dict.py index 341e28e99..8c21eb446 100644 --- a/cyclops/evaluate/metrics/experimental/metric_dict.py +++ b/cyclops/evaluate/metrics/experimental/metric_dict.py @@ -361,7 +361,7 @@ def deepcopy_state(obj: Any) -> Any: for metric_names in self._metric_groups.values(): base_metric = self.data[metric_names[0]] for metric_name in metric_names[1:]: - for state in self.data[metric_name]._defaults: + for state in base_metric._defaults: base_metric_state = getattr(base_metric, state) setattr( self.data[metric_name],