3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for variable batch size in `ThroughputMonitor` ([#20236](https://github.com/Lightning-AI/pytorch-lightning/pull/20236))


- Added `EMAWeightAveraging` callback that extends Lightning's `WeightAveraging` class ([#21260](https://github.com/Lightning-AI/pytorch-lightning/pull/21260))


### Changed

- Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise ([#20896](https://github.com/Lightning-AI/pytorch-lightning/pull/20896))
3 changes: 2 additions & 1 deletion src/lightning/pytorch/callbacks/__init__.py
@@ -32,7 +32,7 @@
from lightning.pytorch.callbacks.stochastic_weight_avg import StochasticWeightAveraging
from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor
from lightning.pytorch.callbacks.timer import Timer
from lightning.pytorch.callbacks.weight_averaging import WeightAveraging
from lightning.pytorch.callbacks.weight_averaging import EMAWeightAveraging, WeightAveraging

__all__ = [
"BackboneFinetuning",
@@ -59,5 +59,6 @@
"ThroughputMonitor",
"Timer",
"TQDMProgressBar",
"EMAWeightAveraging",
"WeightAveraging",
]
54 changes: 53 additions & 1 deletion src/lightning/pytorch/callbacks/weight_averaging.py
@@ -21,7 +21,7 @@
from typing import Any, Optional, Union

import torch
from torch.optim.swa_utils import AveragedModel
from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
from typing_extensions import override

import lightning.pytorch as pl
@@ -361,3 +361,55 @@ def _copy_average_to_current(self, pl_module: "pl.LightningModule") -> None:
current_params = itertools.chain(pl_module.parameters(), pl_module.buffers())
for average_param, current_param in zip(average_params, current_params):
current_param.data.copy_(average_param.data)


class EMAWeightAveraging(WeightAveraging):
"""Exponential Moving Average (EMA) Weight Averaging callback."""

def __init__(
self,
device: Optional[Union[torch.device, str, int]] = None,
use_buffers: bool = True,
decay: float = 0.999,
update_every_n_steps: int = 1,
update_starting_at_step: Optional[int] = None,
update_starting_at_epoch: Optional[int] = None,
**kwargs: Any,
):
super().__init__(
device=device,
use_buffers=use_buffers,
**kwargs,
avg_fn=get_ema_avg_fn(decay=decay),
)

self.update_every_n_steps = update_every_n_steps
self.update_starting_at_step = update_starting_at_step
self.update_starting_at_epoch = update_starting_at_epoch

def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
"""Decide when to update the model weights.

Args:
step_idx: The current step index.
epoch_idx: The current epoch index.
Returns:
bool: True if the average model weights should be updated, False otherwise.

"""
if step_idx is not None:
# Check step-based conditions only if we have a valid step_idx
meets_step_requirement = self.update_starting_at_step is None or step_idx >= self.update_starting_at_step
meets_step_frequency = self.update_every_n_steps > 0 and step_idx % self.update_every_n_steps == 0
if meets_step_requirement and meets_step_frequency:
return True

if epoch_idx is not None:
# Check epoch-based condition only if we specify one
meets_epoch_requirement = (
self.update_starting_at_epoch is not None and epoch_idx >= self.update_starting_at_epoch
)
if meets_epoch_requirement:
return True

return False
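
For context, a minimal usage sketch (not part of this diff) showing how the new callback plugs into a `Trainer`. The model, dataset, and hyperparameter values below are illustrative only:

```python
# Hypothetical usage sketch for the new callback; values are illustrative.
from torch.utils.data import DataLoader

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EMAWeightAveraging
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset

model = BoringModel()
train_loader = DataLoader(RandomDataset(32, 64), batch_size=4)

# Maintain an EMA copy of the weights, updated every optimizer step with decay 0.999.
ema_callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1)

trainer = Trainer(max_epochs=2, callbacks=[ema_callback], logger=False, enable_checkpointing=False)
trainer.fit(model, train_loader)
```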
115 changes: 114 additions & 1 deletion tests/tests_pytorch/callbacks/test_weight_averaging.py
@@ -23,7 +23,7 @@
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import WeightAveraging
from lightning.pytorch.callbacks import EMAWeightAveraging, WeightAveraging
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
from tests_pytorch.helpers.runif import RunIf

@@ -329,3 +329,116 @@ def _train_and_resume(model: TestModel, dataset: Dataset, tmp_path: str, devices
callback = EMATestCallback(devices=devices)
_train(model, dataset, tmp_path, callback, devices=devices, checkpoint_path=checkpoint_path, **kwargs)
return model


@pytest.mark.parametrize(
("strategy", "accelerator", "devices"),
[
("auto", "cpu", 1),
pytest.param("auto", "gpu", 1, marks=RunIf(min_cuda_gpus=1)),
],
)
def test_ema_weight_averaging(tmp_path, strategy, accelerator, devices):
"""Test EMAWeightAveraging callback with various update configurations."""
model = TestModel()
dataset = RandomDataset(32, 32)

# Test with default settings (update every step)
callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1)
_train(model, dataset, tmp_path, callback, strategy=strategy, accelerator=accelerator, devices=devices)

# Verify the average model was created and updated
assert callback._average_model is not None
assert callback._average_model.n_averaged > 0


def test_ema_weight_averaging_step_frequency(tmp_path):
"""Test EMAWeightAveraging with custom step update frequency."""
model = TestModel()
dataset = RandomDataset(32, 32)

# Update every 5 steps
callback = EMAWeightAveraging(decay=0.95, update_every_n_steps=5)
_train(model, dataset, tmp_path, callback)

assert callback._average_model is not None


def test_ema_weight_averaging_starting_step(tmp_path):
"""Test EMAWeightAveraging with delayed start based on steps."""
model = TestModel()
dataset = RandomDataset(32, 32)

# Start updating after step 10
callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1, update_starting_at_step=10)
_train(model, dataset, tmp_path, callback)

assert callback._average_model is not None


def test_ema_weight_averaging_starting_epoch(tmp_path):
"""Test EMAWeightAveraging with delayed start based on epochs."""
model = TestModel()
dataset = RandomDataset(32, 32)

# Start updating after epoch 3
callback = EMAWeightAveraging(decay=0.999, update_every_n_steps=1, update_starting_at_epoch=3)
_train(model, dataset, tmp_path, callback)

assert callback._average_model is not None


def test_ema_weight_averaging_should_update(tmp_path):
"""Test the should_update logic of EMAWeightAveraging."""
# Test with step-based updates
callback = EMAWeightAveraging(update_every_n_steps=5, update_starting_at_step=10)

# Before starting step
assert not callback.should_update(step_idx=5)
assert not callback.should_update(step_idx=9)

# At and after the starting step, updates occur only at the configured step frequency
assert callback.should_update(step_idx=10) # First update
assert not callback.should_update(step_idx=11)
assert not callback.should_update(step_idx=14)
assert callback.should_update(step_idx=15) # Second update

# Test with epoch-based updates
callback = EMAWeightAveraging(update_starting_at_epoch=2)

assert not callback.should_update(epoch_idx=0)
assert not callback.should_update(epoch_idx=1)
assert callback.should_update(epoch_idx=2)
assert callback.should_update(epoch_idx=3)


def test_ema_weight_averaging_checkpoint_save_load(tmp_path):
"""Test that EMAWeightAveraging correctly saves and loads checkpoints."""
model = TestModel()
dataset = RandomDataset(32, 32)

callback = EMAWeightAveraging(decay=0.99, update_every_n_steps=2)

# Train and create checkpoint
_train(model, dataset, tmp_path, callback, will_crash=True)

# Resume from checkpoint
model2 = TestModel()
callback2 = EMAWeightAveraging(decay=0.99, update_every_n_steps=2)
checkpoint_path = str(tmp_path / "lightning_logs" / "version_0" / "checkpoints" / "*.ckpt")

_train(model2, dataset, tmp_path, callback2, checkpoint_path=checkpoint_path)

assert callback2._average_model is not None


@pytest.mark.parametrize("decay", [0.9, 0.99, 0.999, 0.9999])
def test_ema_weight_averaging_decay_values(tmp_path, decay):
"""Test EMAWeightAveraging with different decay values."""
model = TestModel()
dataset = RandomDataset(32, 32)

callback = EMAWeightAveraging(decay=decay, update_every_n_steps=1)
_train(model, dataset, tmp_path, callback)

assert callback._average_model is not None
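
As background for the decay parametrization above, a small illustrative sketch of the EMA update rule that `torch.optim.swa_utils.get_ema_avg_fn` applies each time the averaged model is updated; the tensors and decay value here are made up for illustration:

```python
import torch
from torch.optim.swa_utils import get_ema_avg_fn

# get_ema_avg_fn(decay) returns an avg_fn that computes: decay * averaged + (1 - decay) * current
avg_fn = get_ema_avg_fn(decay=0.9)

averaged = torch.tensor([1.0])  # current EMA parameter
current = torch.tensor([2.0])   # freshly trained parameter
updated = avg_fn(averaged, current, torch.tensor(1))  # third arg (num_averaged) is not used by the EMA rule

print(updated)  # tensor([1.1000]) == 0.9 * 1.0 + 0.1 * 2.0
```

A higher decay keeps the averaged weights closer to their previous value, so the EMA model tracks training more slowly and smooths out more of the step-to-step noise.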