Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion catalyst/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "22.02"
__version__ = "22.02.1"
2 changes: 1 addition & 1 deletion catalyst/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def log_image(
"""Logs image to MLflow for current scope on current step."""
if scope == "batch" or scope == "loader":
log_path = "_".join(
[tag, f"epoch-{runner.epoch_step:04d}", f"loader-{runner.loader}"]
[tag, f"epoch-{runner.epoch_step:04d}", f"loader-{runner.loader_key}"]
)
elif scope == "epoch":
log_path = "_".join([tag, f"epoch-{runner.epoch_step:04d}"])
Expand Down
6 changes: 3 additions & 3 deletions catalyst/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def log_artifact(
self.base_namespace,
"_artifacts",
f"epoch-{runner.epoch_step:04d}",
f"loader-{runner.loader}",
f"loader-{runner.loader_key}",
f"batch-{runner.batch_step:04d}",
tag,
]
Expand All @@ -185,7 +185,7 @@ def log_artifact(
self.base_namespace,
"_artifacts",
f"epoch-{runner.epoch_step:04d}",
f"loader-{runner.loader}",
f"loader-{runner.loader_key}",
tag,
]
)
Expand Down Expand Up @@ -216,7 +216,7 @@ def log_image(
self.base_namespace,
"_images",
f"epoch-{runner.epoch_step:04d}",
f"loader-{runner.loader}",
f"loader-{runner.loader_key}",
tag,
]
)
Expand Down
17 changes: 13 additions & 4 deletions catalyst/loggers/wandb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict, Optional, TYPE_CHECKING
import os
import pickle
import warnings

import numpy as np

Expand Down Expand Up @@ -72,6 +73,12 @@ def __init__(
super().__init__(
log_batch_metrics=log_batch_metrics, log_epoch_metrics=log_epoch_metrics
)
if self.log_batch_metrics:
warnings.warn(
"Wandb does NOT support several x-axes for logging."
"For this reason, everything has to be logged in the batch-based regime."
)

self.project = project
self.name = name
self.entity = entity
Expand Down Expand Up @@ -142,7 +149,8 @@ def log_image(
elif scope == "experiment" or scope is None:
log_path = tag

self.run.log({f"{log_path}.png": wandb.Image(image)}, step=runner.sample_step)
step = runner.sample_step if self.log_batch_metrics else runner.epoch_step
self.run.log({f"{log_path}.png": wandb.Image(image)}, step=step)

def log_hparams(self, hparams: Dict, runner: "IRunner" = None) -> None:
"""Logs hyperparameters to the logger."""
Expand All @@ -155,18 +163,19 @@ def log_metrics(
runner: "IRunner",
) -> None:
"""Logs batch and epoch metrics to wandb."""
step = runner.sample_step if self.log_batch_metrics else runner.epoch_step
if scope == "batch" and self.log_batch_metrics:
metrics = {k: float(v) for k, v in metrics.items()}
self._log_metrics(
metrics=metrics,
step=runner.sample_step,
step=step,
loader_key=runner.loader_key,
prefix="batch",
)
elif scope == "loader" and self.log_epoch_metrics:
self._log_metrics(
metrics=metrics,
step=runner.sample_step,
step=step,
loader_key=runner.loader_key,
prefix="epoch",
)
Expand All @@ -175,7 +184,7 @@ def log_metrics(
per_loader_metrics = metrics[loader_key]
self._log_metrics(
metrics=per_loader_metrics,
step=runner.sample_step,
step=step,
loader_key=loader_key,
prefix="epoch",
)
Expand Down