@@ -206,7 +206,8 @@ class RichProgressBar(ProgressBarBase):
206
206
trainer = Trainer(callbacks=RichProgressBar())
207
207
208
208
Args:
209
- refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled.
209
+ refresh_rate: Determines at which rate (in number of batches) the progress bars get updated.
210
+ Set it to ``0`` to disable the display.
210
211
leave: Leaves the finished progress bar in the terminal at the end of the epoch. Default: False
211
212
theme: Contains styles used to stylize the progress bar.
212
213
@@ -222,7 +223,7 @@ class RichProgressBar(ProgressBarBase):
222
223
223
224
def __init__ (
224
225
self ,
225
- refresh_rate_per_second : int = 10 ,
226
+ refresh_rate : int = 1 ,
226
227
leave : bool = False ,
227
228
theme : RichProgressBarTheme = RichProgressBarTheme (),
228
229
) -> None :
@@ -231,7 +232,7 @@ def __init__(
231
232
"`RichProgressBar` requires `rich` to be installed. Install it by running `pip install -U rich`."
232
233
)
233
234
super ().__init__ ()
234
- self ._refresh_rate_per_second : int = refresh_rate_per_second
235
+ self ._refresh_rate : int = refresh_rate
235
236
self ._leave : bool = leave
236
237
self ._enabled : bool = True
237
238
self .progress : Optional [Progress ] = None
@@ -242,17 +243,12 @@ def __init__(
242
243
self .theme = theme
243
244
244
245
@property
245
- def refresh_rate_per_second (self ) -> float :
246
- """Refresh rate for Rich Progress.
247
-
248
- Returns: Refresh rate for Progress Bar.
249
- Return 1 if not enabled, as a positive integer is required (ignored by Rich Progress).
250
- """
251
- return self ._refresh_rate_per_second if self ._refresh_rate_per_second > 0 else 1
246
+ def refresh_rate (self ) -> float :
247
+ return self ._refresh_rate
252
248
253
249
@property
254
250
def is_enabled (self ) -> bool :
255
- return self ._enabled and self ._refresh_rate_per_second > 0
251
+ return self ._enabled and self .refresh_rate > 0
256
252
257
253
@property
258
254
def is_disabled (self ) -> bool :
@@ -289,14 +285,18 @@ def _init_progress(self, trainer):
289
285
self .progress = CustomProgress (
290
286
* self .configure_columns (trainer ),
291
287
self ._metric_component ,
292
- refresh_per_second = self . refresh_rate_per_second ,
288
+ auto_refresh = False ,
293
289
disable = self .is_disabled ,
294
290
console = self ._console ,
295
291
)
296
292
self .progress .start ()
297
293
# progress has started
298
294
self ._progress_stopped = False
299
295
296
+ def refresh (self ) -> None :
297
+ if self .progress :
298
+ self .progress .refresh ()
299
+
300
300
def on_train_start (self , trainer , pl_module ):
301
301
super ().on_train_start (trainer , pl_module )
302
302
self ._init_progress (trainer )
@@ -328,10 +328,12 @@ def on_sanity_check_start(self, trainer, pl_module):
328
328
super ().on_sanity_check_start (trainer , pl_module )
329
329
self ._init_progress (trainer )
330
330
self .val_sanity_progress_bar_id = self ._add_task (trainer .num_sanity_val_steps , self .sanity_check_description )
331
+ self .refresh ()
331
332
332
333
def on_sanity_check_end (self , trainer , pl_module ):
333
334
super ().on_sanity_check_end (trainer , pl_module )
334
335
self ._update (self .val_sanity_progress_bar_id , visible = False )
336
+ self .refresh ()
335
337
336
338
def on_train_epoch_start (self , trainer , pl_module ):
337
339
super ().on_train_epoch_start (trainer , pl_module )
@@ -354,6 +356,7 @@ def on_train_epoch_start(self, trainer, pl_module):
354
356
self .progress .reset (
355
357
self .main_progress_bar_id , total = total_batches , description = train_description , visible = True
356
358
)
359
+ self .refresh ()
357
360
358
361
def on_validation_epoch_start (self , trainer , pl_module ):
359
362
super ().on_validation_epoch_start (trainer , pl_module )
@@ -364,52 +367,62 @@ def on_validation_epoch_start(self, trainer, pl_module):
364
367
val_checks_per_epoch = self .total_train_batches // trainer .val_check_batch
365
368
total_val_batches = self .total_val_batches * val_checks_per_epoch
366
369
self .val_progress_bar_id = self ._add_task (total_val_batches , self .validation_description , visible = False )
370
+ self .refresh ()
367
371
368
372
def _add_task (self , total_batches : int , description : str , visible : bool = True ) -> Optional [int ]:
369
373
if self .progress is not None :
370
374
return self .progress .add_task (
371
375
f"[{ self .theme .description } ]{ description } " , total = total_batches , visible = visible
372
376
)
373
377
374
- def _update (self , progress_bar_id : int , visible : bool = True ) -> None :
375
- if self .progress is not None :
376
- self .progress .update (progress_bar_id , advance = 1.0 , visible = visible )
378
+ def _update (self , progress_bar_id : int , current : int , total : int , visible : bool = True ) -> None :
379
+ if self .progress is not None and self ._should_update (current , total ):
380
+ self .progress .update (progress_bar_id , advance = self .refresh_rate , visible = visible )
381
+ self .refresh ()
382
+
383
+ def _should_update (self , current : int , total : int ) -> bool :
384
+ return self .is_enabled and (current % self .refresh_rate == 0 or current == total )
377
385
378
386
def on_validation_epoch_end (self , trainer , pl_module ):
379
387
super ().on_validation_epoch_end (trainer , pl_module )
380
388
if self .val_progress_bar_id is not None :
381
- self ._update (self .val_progress_bar_id , visible = False )
389
+ self ._update (self .val_progress_bar_id , self . val_batch_idx , self . total_val_batches , visible = False )
382
390
383
391
def on_test_epoch_start (self , trainer , pl_module ):
384
- super ().on_train_epoch_start (trainer , pl_module )
385
392
self .test_progress_bar_id = self ._add_task (self .total_test_batches , self .test_description )
393
+ self .refresh ()
386
394
387
395
def on_predict_epoch_start (self , trainer , pl_module ):
388
396
super ().on_predict_epoch_start (trainer , pl_module )
389
397
self .predict_progress_bar_id = self ._add_task (self .total_predict_batches , self .predict_description )
398
+ self .refresh ()
390
399
391
400
def on_train_batch_end (self , trainer , pl_module , outputs , batch , batch_idx ):
392
401
super ().on_train_batch_end (trainer , pl_module , outputs , batch , batch_idx )
393
- self ._update (self .main_progress_bar_id )
402
+ self ._update (self .main_progress_bar_id , self . train_batch_idx , self . total_train_batches )
394
403
self ._update_metrics (trainer , pl_module )
404
+ self .refresh ()
395
405
396
406
def on_validation_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
397
407
super ().on_validation_batch_end (trainer , pl_module , outputs , batch , batch_idx , dataloader_idx )
398
408
if trainer .sanity_checking :
399
- self ._update (self .val_sanity_progress_bar_id )
409
+ self ._update (self .val_sanity_progress_bar_id , self . val_batch_idx , self . total_val_batches )
400
410
elif self .val_progress_bar_id is not None :
401
411
# check to see if we should update the main training progress bar
402
412
if self .main_progress_bar_id is not None :
403
- self ._update (self .main_progress_bar_id )
404
- self ._update (self .val_progress_bar_id )
413
+ self ._update (self .main_progress_bar_id , self .val_batch_idx , self .total_val_batches )
414
+ self ._update (self .val_progress_bar_id , self .val_batch_idx , self .total_val_batches )
415
+ self .refresh ()
405
416
406
417
def on_test_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
407
418
super ().on_test_batch_end (trainer , pl_module , outputs , batch , batch_idx , dataloader_idx )
408
- self ._update (self .test_progress_bar_id )
419
+ self ._update (self .test_progress_bar_id , self .test_batch_idx , self .total_test_batches )
420
+ self .refresh ()
409
421
410
422
def on_predict_batch_end (self , trainer , pl_module , outputs , batch , batch_idx , dataloader_idx ):
411
423
super ().on_predict_batch_end (trainer , pl_module , outputs , batch , batch_idx , dataloader_idx )
412
- self ._update (self .predict_progress_bar_id )
424
+ self ._update (self .predict_progress_bar_id , self .predict_batch_idx , self .total_predict_batches )
425
+ self .refresh ()
413
426
414
427
def _get_train_description (self , current_epoch : int ) -> str :
415
428
train_description = f"Epoch { current_epoch } "
0 commit comments