Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion catalyst/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
warnings.filterwarnings("ignore", message="numpy.dtype size changed", append=True)
warnings.filterwarnings("ignore", module="tqdm", append=True)
warnings.filterwarnings("once", append=True)
warnings.filterwarnings("ignore", message="This overload of add_ is deprecated", append=True)
warnings.filterwarnings(
"ignore", message="This overload of add_ is deprecated", append=True
)

from catalyst.__version__ import __version__
from catalyst.settings import SETTINGS
2 changes: 1 addition & 1 deletion catalyst/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "21.12rc1"
__version__ = "22.01rc0"
7 changes: 4 additions & 3 deletions catalyst/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@
from catalyst.callbacks.optimizer import IOptimizerCallback, OptimizerCallback
from catalyst.core.callback import (
Callback,
CallbackList,
CallbackNode,
CallbackOrder,
CallbackScope,
CallbackWrapper,
ICallback,
IBackwardCallback,
ICriterionCallback,
IOptimizerCallback,
ISchedulerCallback,
)
from catalyst.settings import SETTINGS

Expand Down
7 changes: 5 additions & 2 deletions catalyst/callbacks/batch_overfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def __init__(self, **kwargs):
for loader, num_batches in kwargs.items():
if not isinstance(num_batches, (int, float)):
raise TypeError(
"Expected loader num_batches type is int/float " f"but got {type(num_batches)}"
"Expected loader num_batches type is int/float "
f"but got {type(num_batches)}"
)
self.loader_batches[loader] = num_batches

Expand All @@ -110,7 +111,9 @@ def on_epoch_start(self, runner: "IRunner") -> None:
num_batches = self.loader_batches.get(name, 1)
if isinstance(num_batches, float):
num_batches = int(len(loader) * num_batches)
epoch_loaders[name] = BatchLimitLoaderWrapper(loader=loader, num_batches=num_batches)
epoch_loaders[name] = BatchLimitLoaderWrapper(
loader=loader, num_batches=num_batches
)

runner.loaders = epoch_loaders

Expand Down
12 changes: 9 additions & 3 deletions catalyst/callbacks/batch_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ def __init__(
output_key = output_key or input_key
if output_key is not None:
if input_key is None:
raise TypeError("You should define input_key in " "case if output_key is not None")
raise TypeError(
"You should define input_key in " "case if output_key is not None"
)
if not isinstance(output_key, (list, str)):
raise TypeError("output key should be str or a list of str.")
if isinstance(output_key, str):
Expand All @@ -220,15 +222,19 @@ def __init__(
if isinstance(scope, str) and scope in ["on_batch_end", "on_batch_start"]:
self.scope = scope
else:
raise TypeError('Expected scope to be on of the ["on_batch_end", "on_batch_start"]')
raise TypeError(
'Expected scope to be on of the ["on_batch_end", "on_batch_start"]'
)
self.input_key = input_key
self.output_key = output_key
self.transform = transform

def _handle_value(self, runner):
batch_in = [runner.batch[key] for key in self.input_key]
batch_out = self.transform(*batch_in)
runner.batch.update(**{key: value for key, value in zip(self.output_key, batch_out)})
runner.batch.update(
**{key: value for key, value in zip(self.output_key, batch_out)}
)

def _handle_key_value(self, runner):
runner.batch = self.transform(runner.batch)
Expand Down
33 changes: 22 additions & 11 deletions catalyst/callbacks/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path
import shutil

from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.core.callback import Callback, CallbackOrder
from catalyst.core.runner import IRunner
from catalyst.extras.metric_handler import MetricHandler
from catalyst.utils.config import save_config
Expand Down Expand Up @@ -423,7 +423,7 @@ def __init__(
use_runner_logdir: bool = False,
):
"""Init."""
super().__init__(order=CallbackOrder.external, node=CallbackNode.all)
super().__init__(order=CallbackOrder.external)
possible_states = _default_states.union([None])
assert save_n_best >= 0
if save_n_best == 0:
Expand All @@ -439,7 +439,9 @@ def __init__(
"requires both `loader_key` and `metric_key` specified."
)
self._use_model_selection = True
self.minimize = minimize if minimize is not None else True # loss-oriented selection
self.minimize = (
minimize if minimize is not None else True
) # loss-oriented selection
else:
self._use_model_selection = False
self.minimize = False # epoch-num-oriented selection
Expand Down Expand Up @@ -556,11 +558,15 @@ def _truncate_checkpoints(self) -> None:
if len(self.top_best_metrics) > self.save_n_best:
last_item = self.top_best_metrics.pop(-1)
last_filepath = Path(last_item[1])
last_filepaths = last_filepath.parent.glob(last_filepath.name.replace(".pth", "*"))
last_filepaths = last_filepath.parent.glob(
last_filepath.name.replace(".pth", "*")
)
for filepath in last_filepaths:
os.remove(filepath)

def _prepare_metrics_log(self, last_epoch_score: float, last_epoch_metrics: Dict) -> Dict:
def _prepare_metrics_log(
self, last_epoch_score: float, last_epoch_metrics: Dict
) -> Dict:
top_best_checkpoints = [
(Path(filepath).stem, {**epoch_metrics, **{"_score_": score}})
for (score, filepath, _, _, epoch_metrics) in self.top_best_metrics
Expand Down Expand Up @@ -616,9 +622,7 @@ def on_stage_start(self, runner: "IRunner") -> None:
self.resume = None
elif self.load_on_stage_start is not None:
_load_runner(
logdir=self.logdir,
runner=runner,
mapping=self.load_on_stage_start,
logdir=self.logdir, runner=runner, mapping=self.load_on_stage_start,
)

def on_epoch_end(self, runner: "IRunner") -> None:
Expand Down Expand Up @@ -664,7 +668,9 @@ def on_epoch_end(self, runner: "IRunner") -> None:
# truncate checkpoints
self._truncate_checkpoints()
# save checkpoint metrics
metrics_log = self._prepare_metrics_log(float(score), dict(runner.epoch_metrics))
metrics_log = self._prepare_metrics_log(
float(score), dict(runner.epoch_metrics)
)
save_config(metrics_log, f"{self.logdir}/{self.metrics_filename}")

def on_stage_end(self, runner: "IRunner") -> None:
Expand Down Expand Up @@ -698,12 +704,17 @@ def on_stage_end(self, runner: "IRunner") -> None:
)
# add metrics to records
# save checkpoint metrics
metrics_log = self._prepare_metrics_log(float(score), dict(runner.epoch_metrics))
metrics_log = self._prepare_metrics_log(
float(score), dict(runner.epoch_metrics)
)
save_config(metrics_log, f"{self.logdir}/{self.metrics_filename}")
log_message += f"{checkpoint_path}\t{score:3.4f}"
else:
log_message += "\n".join(
[f"{filepath}\t{score:3.4f}" for score, filepath, _, _, _ in self.top_best_metrics]
[
f"{filepath}\t{score:3.4f}"
for score, filepath, _, _, _ in self.top_best_metrics
]
)
print(log_message)

Expand Down
13 changes: 10 additions & 3 deletions catalyst/callbacks/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@


class _EpochFilterFn:
def __init__(self, epochs: Union[int, float, Sequence[int]], reverse_condition: bool):
def __init__(
self, epochs: Union[int, float, Sequence[int]], reverse_condition: bool
):
if not isinstance(epochs, (int, float, list, tuple)):
raise ValueError(
"'epochs' should be int/float/Sequence[int]! " f"(got {type(epochs)})"
Expand Down Expand Up @@ -106,7 +108,8 @@ def __init__(self, filter_fn: Union[str, FILTER_FN]):
raise ValueError("'filter_fn' should be a callable!")
if filter_fn.__code__.co_argcount != 3:
raise ValueError(
"Filter function should have three arguments - " "'stage', 'epoch' and 'loader'!"
"Filter function should have three arguments - "
"'stage', 'epoch' and 'loader'!"
)
self.filter_fn = filter_fn

Expand Down Expand Up @@ -342,7 +345,11 @@ def on_loader_start(self, runner: "IRunner") -> None:
"""
stage = runner.stage_key
loader = runner.loader_key
epoch = runner.global_epoch_step if self.use_global_epochs else runner.stage_epoch_step
epoch = (
runner.global_epoch_step
if self.use_global_epochs
else runner.stage_epoch_step
)

if self.filter_fn is not None:
self._is_enabled = self.filter_fn(stage, epoch, loader)
Expand Down
23 changes: 17 additions & 6 deletions catalyst/callbacks/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.core.callback import Callback, CallbackOrder
from catalyst.core.runner import IRunner
from catalyst.metrics._functional_metric import FunctionalBatchMetric
from catalyst.metrics._metric import ICallbackBatchMetric, ICallbackLoaderMetric, IMetric
Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(
target_key: Union[str, Iterable[str], Dict[str, str]],
):
"""Init MetricCallback"""
super().__init__(order=CallbackOrder.metric, node=CallbackNode.all)
super().__init__(order=CallbackOrder.metric)
self.metric = metric
assert isinstance(metric, IMetric)
self._metric_update_method = self.metric.update
Expand Down Expand Up @@ -89,7 +89,9 @@ def __init__(
}

@staticmethod
def _convert_keys_to_kv(keys: Union[str, Iterable[str], Dict[str, str]]) -> Dict[str, str]:
def _convert_keys_to_kv(
keys: Union[str, Iterable[str], Dict[str, str]]
) -> Dict[str, str]:
"""
Convert keys to key-value format

Expand Down Expand Up @@ -246,10 +248,15 @@ def __init__(
"""Init."""
assert isinstance(metric, FunctionalBatchMetric)
super().__init__(
metric=metric, input_key=input_key, target_key=target_key, log_on_batch=log_on_batch
metric=metric,
input_key=input_key,
target_key=target_key,
log_on_batch=log_on_batch,
)

def _get_value_inputs(self, runner: "IRunner") -> Tuple[float, torch.Tensor, torch.Tensor]:
def _get_value_inputs(
self, runner: "IRunner"
) -> Tuple[float, torch.Tensor, torch.Tensor]:
"""Get data from batch in value input case

Args:
Expand All @@ -258,7 +265,11 @@ def _get_value_inputs(self, runner: "IRunner") -> Tuple[float, torch.Tensor, tor
Returns:
tuple of tensor of inputs and tensor of targets
"""
return runner.batch_size, runner.batch[self.input_key], runner.batch[self.target_key]
return (
runner.batch_size,
runner.batch[self.input_key],
runner.batch[self.target_key],
)

def _get_key_value_inputs(self, runner: "IRunner") -> Dict[str, torch.Tensor]:
"""Get data from batch in key-value input case
Expand Down
15 changes: 10 additions & 5 deletions catalyst/callbacks/metric_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.core.callback import Callback, CallbackOrder

if TYPE_CHECKING:
from catalyst.core.runner import IRunner
Expand Down Expand Up @@ -121,7 +121,7 @@ def __init__(
multiplier: float = 1.0,
) -> None:
"""Init."""
super().__init__(order=CallbackOrder.metric_aggregation, node=CallbackNode.all)
super().__init__(order=CallbackOrder.metric_aggregation)

if metric_key is None or not isinstance(metric_key, str):
raise ValueError("prefix must be str")
Expand All @@ -141,7 +141,8 @@ def __init__(
)
elif not callable(mode):
raise NotImplementedError(
"mode must be `sum`, `mean` " "or `weighted_sum` or `weighted_mean` or be Callable"
"mode must be `sum`, `mean` "
"or `weighted_sum` or `weighted_mean` or be Callable"
)

assert scope in ("batch", "loader")
Expand All @@ -159,7 +160,9 @@ def __init__(
self.aggregation_fn = _sum_aggregation
if mode == "weighted_mean":
weights_sum = sum(metrics.items())
self.metrics = {key: weight / weights_sum for key, weight in metrics.items()}
self.metrics = {
key: weight / weights_sum for key, weight in metrics.items()
}
elif mode == "mean":
self.aggregation_fn = _mean_aggregation
elif callable(mode):
Expand All @@ -169,7 +172,9 @@ def _get_metrics_list(self, metrics: Dict) -> List[float]:
if self.metrics is not None:
try:
if self.mode == "weighted_sum":
result = [metrics[key] * value for key, value in self.metrics.items()]
result = [
metrics[key] * value for key, value in self.metrics.items()
]
else:
result = [metrics[key] for key in self.metrics]
except KeyError:
Expand Down
10 changes: 8 additions & 2 deletions catalyst/callbacks/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

from catalyst.settings import SETTINGS

from catalyst.callbacks.metrics.accuracy import AccuracyCallback, MultilabelAccuracyCallback
from catalyst.callbacks.metrics.accuracy import (
AccuracyCallback,
MultilabelAccuracyCallback,
)
from catalyst.callbacks.metrics.auc import AUCCallback

from catalyst.callbacks.metrics.classification import (
Expand Down Expand Up @@ -33,4 +36,7 @@
)

if SETTINGS.ml_required:
from catalyst.callbacks.metrics.scikit_learn import SklearnBatchCallback, SklearnLoaderCallback
from catalyst.callbacks.metrics.scikit_learn import (
SklearnBatchCallback,
SklearnLoaderCallback,
)
9 changes: 7 additions & 2 deletions catalyst/callbacks/metrics/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ def __init__(
"""Init."""
super().__init__(
metric=AccuracyMetric(
topk_args=topk_args, num_classes=num_classes, prefix=prefix, suffix=suffix
topk_args=topk_args,
num_classes=num_classes,
prefix=prefix,
suffix=suffix,
),
input_key=input_key,
target_key=target_key,
Expand Down Expand Up @@ -185,7 +188,9 @@ def __init__(
):
"""Init."""
super().__init__(
metric=MultilabelAccuracyMetric(threshold=threshold, prefix=prefix, suffix=suffix),
metric=MultilabelAccuracyMetric(
threshold=threshold, prefix=prefix, suffix=suffix
),
input_key=input_key,
target_key=target_key,
log_on_batch=log_on_batch,
Expand Down
4 changes: 3 additions & 1 deletion catalyst/callbacks/metrics/auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def __init__(
"""Init."""
super().__init__(
metric=AUCMetric(
compute_per_class_metrics=compute_per_class_metrics, prefix=prefix, suffix=suffix
compute_per_class_metrics=compute_per_class_metrics,
prefix=prefix,
suffix=suffix,
),
input_key=input_key,
target_key=target_key,
Expand Down
9 changes: 6 additions & 3 deletions catalyst/callbacks/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Dict, List, TYPE_CHECKING

from catalyst.core.callback import Callback, CallbackNode, CallbackOrder
from catalyst.core.callback import Callback, CallbackOrder
from catalyst.metrics._confusion_matrix import ConfusionMatrixMetric
from catalyst.settings import SETTINGS

if SETTINGS.ml_required:
from catalyst.contrib.utils.visualization import plot_confusion_matrix, render_figure_to_array
from catalyst.contrib.utils.visualization import (
plot_confusion_matrix,
render_figure_to_array,
)

if TYPE_CHECKING:
from catalyst.core.runner import IRunner
Expand Down Expand Up @@ -97,7 +100,7 @@ def __init__(
plot_params: Dict = None,
):
"""Callback initialisation."""
super().__init__(CallbackOrder.metric, CallbackNode.all)
super().__init__(CallbackOrder.metric)
assert num_classes is not None or class_names is not None
self.prefix = prefix or "confusion_matrix"
self.input_key = input_key
Expand Down
Loading