Skip to content

Commit 137b62d

Browse files
kaushikb11ananthsubawaelchli
authored
Add refresh_rate to RichProgressBar (#10497)
Co-authored-by: ananthsub <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 7d3ad5b commit 137b62d

File tree

3 files changed

+63
-26
lines changed

3 files changed

+63
-26
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3737
- Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))
3838

3939

40+
- Renamed `refresh_rate_per_second` parameter to `referesh_rate` for `RichProgressBar` signature ([#10497](https://github.com/PyTorchLightning/pytorch-lightning/pull/10497))
41+
42+
4043
- Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))
4144

4245

pytorch_lightning/callbacks/progress/rich_progress.py

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@ class RichProgressBar(ProgressBarBase):
206206
trainer = Trainer(callbacks=RichProgressBar())
207207
208208
Args:
209-
refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled.
209+
refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
210+
Set it to ``0`` to disable the display.
210211
leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
211212
theme: Contains styles used to stylize the progress bar.
212213
@@ -222,7 +223,7 @@ class RichProgressBar(ProgressBarBase):
222223

223224
def __init__(
224225
self,
225-
refresh_rate_per_second: int = 10,
226+
refresh_rate: int = 1,
226227
leave: bool = False,
227228
theme: RichProgressBarTheme = RichProgressBarTheme(),
228229
) -> None:
@@ -231,7 +232,7 @@ def __init__(
231232
"`RichProgressBar` requires `rich` to be installed. Install it by running `pip install -U rich`."
232233
)
233234
super().__init__()
234-
self._refresh_rate_per_second: int = refresh_rate_per_second
235+
self._refresh_rate: int = refresh_rate
235236
self._leave: bool = leave
236237
self._enabled: bool = True
237238
self.progress: Optional[Progress] = None
@@ -242,17 +243,12 @@ def __init__(
242243
self.theme = theme
243244

244245
@property
245-
def refresh_rate_per_second(self) -> float:
246-
"""Refresh rate for Rich Progress.
247-
248-
Returns: Refresh rate for Progress Bar.
249-
Return 1 if not enabled, as a positive integer is required (ignored by Rich Progress).
250-
"""
251-
return self._refresh_rate_per_second if self._refresh_rate_per_second > 0 else 1
246+
def refresh_rate(self) -> float:
247+
return self._refresh_rate
252248

253249
@property
254250
def is_enabled(self) -> bool:
255-
return self._enabled and self._refresh_rate_per_second > 0
251+
return self._enabled and self.refresh_rate > 0
256252

257253
@property
258254
def is_disabled(self) -> bool:
@@ -289,14 +285,18 @@ def _init_progress(self, trainer):
289285
self.progress = CustomProgress(
290286
*self.configure_columns(trainer),
291287
self._metric_component,
292-
refresh_per_second=self.refresh_rate_per_second,
288+
auto_refresh=False,
293289
disable=self.is_disabled,
294290
console=self._console,
295291
)
296292
self.progress.start()
297293
# progress has started
298294
self._progress_stopped = False
299295

296+
def refresh(self) -> None:
297+
if self.progress:
298+
self.progress.refresh()
299+
300300
def on_train_start(self, trainer, pl_module):
301301
super().on_train_start(trainer, pl_module)
302302
self._init_progress(trainer)
@@ -328,10 +328,12 @@ def on_sanity_check_start(self, trainer, pl_module):
328328
super().on_sanity_check_start(trainer, pl_module)
329329
self._init_progress(trainer)
330330
self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description)
331+
self.refresh()
331332

332333
def on_sanity_check_end(self, trainer, pl_module):
333334
super().on_sanity_check_end(trainer, pl_module)
334335
self._update(self.val_sanity_progress_bar_id, visible=False)
336+
self.refresh()
335337

336338
def on_train_epoch_start(self, trainer, pl_module):
337339
super().on_train_epoch_start(trainer, pl_module)
@@ -354,6 +356,7 @@ def on_train_epoch_start(self, trainer, pl_module):
354356
self.progress.reset(
355357
self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
356358
)
359+
self.refresh()
357360

358361
def on_validation_epoch_start(self, trainer, pl_module):
359362
super().on_validation_epoch_start(trainer, pl_module)
@@ -364,52 +367,62 @@ def on_validation_epoch_start(self, trainer, pl_module):
364367
val_checks_per_epoch = self.total_train_batches // trainer.val_check_batch
365368
total_val_batches = self.total_val_batches * val_checks_per_epoch
366369
self.val_progress_bar_id = self._add_task(total_val_batches, self.validation_description, visible=False)
370+
self.refresh()
367371

368372
def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]:
369373
if self.progress is not None:
370374
return self.progress.add_task(
371375
f"[{self.theme.description}]{description}", total=total_batches, visible=visible
372376
)
373377

374-
def _update(self, progress_bar_id: int, visible: bool = True) -> None:
375-
if self.progress is not None:
376-
self.progress.update(progress_bar_id, advance=1.0, visible=visible)
378+
def _update(self, progress_bar_id: int, current: int, total: int, visible: bool = True) -> None:
379+
if self.progress is not None and self._should_update(current, total):
380+
self.progress.update(progress_bar_id, advance=self.refresh_rate, visible=visible)
381+
self.refresh()
382+
383+
def _should_update(self, current: int, total: int) -> bool:
384+
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
377385

378386
def on_validation_epoch_end(self, trainer, pl_module):
379387
super().on_validation_epoch_end(trainer, pl_module)
380388
if self.val_progress_bar_id is not None:
381-
self._update(self.val_progress_bar_id, visible=False)
389+
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False)
382390

383391
def on_test_epoch_start(self, trainer, pl_module):
384-
super().on_train_epoch_start(trainer, pl_module)
385392
self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description)
393+
self.refresh()
386394

387395
def on_predict_epoch_start(self, trainer, pl_module):
388396
super().on_predict_epoch_start(trainer, pl_module)
389397
self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)
398+
self.refresh()
390399

391400
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
392401
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
393-
self._update(self.main_progress_bar_id)
402+
self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches)
394403
self._update_metrics(trainer, pl_module)
404+
self.refresh()
395405

396406
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
397407
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
398408
if trainer.sanity_checking:
399-
self._update(self.val_sanity_progress_bar_id)
409+
self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches)
400410
elif self.val_progress_bar_id is not None:
401411
# check to see if we should update the main training progress bar
402412
if self.main_progress_bar_id is not None:
403-
self._update(self.main_progress_bar_id)
404-
self._update(self.val_progress_bar_id)
413+
self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches)
414+
self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches)
415+
self.refresh()
405416

406417
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
407418
super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
408-
self._update(self.test_progress_bar_id)
419+
self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches)
420+
self.refresh()
409421

410422
def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
411423
super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
412-
self._update(self.predict_progress_bar_id)
424+
self._update(self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches)
425+
self.refresh()
413426

414427
def _get_train_description(self, current_epoch: int) -> str:
415428
train_description = f"Epoch {current_epoch}"

tests/callbacks/test_rich_progress_bar.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ def test_rich_progress_bar_callback():
3636

3737

3838
@RunIf(rich=True)
39-
def test_rich_progress_bar_refresh_rate():
40-
progress_bar = RichProgressBar(refresh_rate_per_second=1)
39+
def test_rich_progress_bar_refresh_rate_enabled():
40+
progress_bar = RichProgressBar(refresh_rate=1)
4141
assert progress_bar.is_enabled
4242
assert not progress_bar.is_disabled
43-
progress_bar = RichProgressBar(refresh_rate_per_second=0)
43+
progress_bar = RichProgressBar(refresh_rate=0)
4444
assert not progress_bar.is_enabled
4545
assert progress_bar.is_disabled
4646

@@ -180,3 +180,24 @@ def test_rich_progress_bar_leave(tmpdir, leave, reset_call_count):
180180
)
181181
trainer.fit(model)
182182
assert mock_progress_reset.call_count == reset_call_count
183+
184+
185+
@RunIf(rich=True)
186+
@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update")
187+
@pytest.mark.parametrize(("refresh_rate", "expected_call_count"), ([(0, 0), (3, 7)]))
188+
def test_rich_progress_bar_refresh_rate(progress_update, tmpdir, refresh_rate, expected_call_count):
189+
190+
model = BoringModel()
191+
192+
trainer = Trainer(
193+
default_root_dir=tmpdir,
194+
num_sanity_val_steps=0,
195+
limit_train_batches=6,
196+
limit_val_batches=6,
197+
max_epochs=1,
198+
callbacks=RichProgressBar(refresh_rate=refresh_rate),
199+
)
200+
201+
trainer.fit(model)
202+
203+
assert progress_update.call_count == expected_call_count

0 commit comments

Comments
 (0)