Skip to content

Commit 8d07ffd

Browse files
authored
wandb step fix (#1405)
* wandb step fix * Update __version__.py * +
1 parent f390239 commit 8d07ffd

File tree

4 files changed

+18
-9
lines changed

4 files changed

+18
-9
lines changed

catalyst/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "22.02"
1+
__version__ = "22.02.1"

catalyst/loggers/mlflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def log_image(
172172
"""Logs image to MLflow for current scope on current step."""
173173
if scope == "batch" or scope == "loader":
174174
log_path = "_".join(
175-
[tag, f"epoch-{runner.epoch_step:04d}", f"loader-{runner.loader}"]
175+
[tag, f"epoch-{runner.epoch_step:04d}", f"loader-{runner.loader_key}"]
176176
)
177177
elif scope == "epoch":
178178
log_path = "_".join([tag, f"epoch-{runner.epoch_step:04d}"])

catalyst/loggers/neptune.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def log_artifact(
174174
self.base_namespace,
175175
"_artifacts",
176176
f"epoch-{runner.epoch_step:04d}",
177-
f"loader-{runner.loader}",
177+
f"loader-{runner.loader_key}",
178178
f"batch-{runner.batch_step:04d}",
179179
tag,
180180
]
@@ -185,7 +185,7 @@ def log_artifact(
185185
self.base_namespace,
186186
"_artifacts",
187187
f"epoch-{runner.epoch_step:04d}",
188-
f"loader-{runner.loader}",
188+
f"loader-{runner.loader_key}",
189189
tag,
190190
]
191191
)
@@ -216,7 +216,7 @@ def log_image(
216216
self.base_namespace,
217217
"_images",
218218
f"epoch-{runner.epoch_step:04d}",
219-
f"loader-{runner.loader}",
219+
f"loader-{runner.loader_key}",
220220
tag,
221221
]
222222
)

catalyst/loggers/wandb.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Dict, Optional, TYPE_CHECKING
22
import os
33
import pickle
4+
import warnings
45

56
import numpy as np
67

@@ -72,6 +73,12 @@ def __init__(
7273
super().__init__(
7374
log_batch_metrics=log_batch_metrics, log_epoch_metrics=log_epoch_metrics
7475
)
76+
if self.log_batch_metrics:
77+
warnings.warn(
78+
"Wandb does NOT support several x-axes for logging."
79+
"For this reason, everything has to be logged in the batch-based regime."
80+
)
81+
7582
self.project = project
7683
self.name = name
7784
self.entity = entity
@@ -142,7 +149,8 @@ def log_image(
142149
elif scope == "experiment" or scope is None:
143150
log_path = tag
144151

145-
self.run.log({f"{log_path}.png": wandb.Image(image)}, step=runner.sample_step)
152+
step = runner.sample_step if self.log_batch_metrics else runner.epoch_step
153+
self.run.log({f"{log_path}.png": wandb.Image(image)}, step=step)
146154

147155
def log_hparams(self, hparams: Dict, runner: "IRunner" = None) -> None:
148156
"""Logs hyperparameters to the logger."""
@@ -155,18 +163,19 @@ def log_metrics(
155163
runner: "IRunner",
156164
) -> None:
157165
"""Logs batch and epoch metrics to wandb."""
166+
step = runner.sample_step if self.log_batch_metrics else runner.epoch_step
158167
if scope == "batch" and self.log_batch_metrics:
159168
metrics = {k: float(v) for k, v in metrics.items()}
160169
self._log_metrics(
161170
metrics=metrics,
162-
step=runner.sample_step,
171+
step=step,
163172
loader_key=runner.loader_key,
164173
prefix="batch",
165174
)
166175
elif scope == "loader" and self.log_epoch_metrics:
167176
self._log_metrics(
168177
metrics=metrics,
169-
step=runner.sample_step,
178+
step=step,
170179
loader_key=runner.loader_key,
171180
prefix="epoch",
172181
)
@@ -175,7 +184,7 @@ def log_metrics(
175184
per_loader_metrics = metrics[loader_key]
176185
self._log_metrics(
177186
metrics=per_loader_metrics,
178-
step=runner.sample_step,
187+
step=step,
179188
loader_key=loader_key,
180189
prefix="epoch",
181190
)

0 commit comments

Comments
 (0)