@@ -271,38 +271,49 @@ class CheckpointSaver(Callback):
     Args:
         save_dir (str): path to folder where to save the model
-        save_name (str): name of the saved model. can additionally
+        save_name (str): name of the saved model. Can additionally
             add epoch and metric to model save name
+        monitor (str): quantity to monitor. Implicitly prefers validation metrics over train. Either
+            `loss` or the name of any metric passed to the runner.
         mode (str): one of "min" or "max". Whether to decide to save based
             on minimizing or maximizing loss
-        include_optimizer (bool): if True would also save `optimizers` state_dict.
+        include_optimizer (bool): if True, also saves the `optimizer` state_dict.
             This increases checkpoint size ~2x.
+        verbose (bool): if True, reports each time a new best value is found
     """

     def __init__(
-        self, save_dir, save_name="model_{ep}_{metric:.2f}.chpn", mode="min", include_optimizer=False
+        self,
+        save_dir,
+        save_name="model_{ep}_{metric:.2f}.chpn",
+        monitor="loss",
+        mode="min",
+        include_optimizer=False,
+        verbose=True,
     ):
         super().__init__()
         self.save_dir = save_dir
         self.save_name = save_name
-        self.mode = ReduceMode(mode)
-        self.best = float("inf") if self.mode == ReduceMode.MIN else -float("inf")
+        self.monitor = monitor
+        mode = ReduceMode(mode)
+        if mode == ReduceMode.MIN:
+            self.best = np.inf
+            self.monitor_op = np.less
+        elif mode == ReduceMode.MAX:
+            self.best = -np.inf
+            self.monitor_op = np.greater
         self.include_optimizer = include_optimizer
+        self.verbose = verbose

     def on_begin(self):
         os.makedirs(self.save_dir, exist_ok=True)

     def on_epoch_end(self):
-        # TODO zakirov(1.11.19) Add support for saving based on metric
-        if self.state.val_loss is not None:
-            current = self.state.val_loss.avg
-        else:
-            current = self.state.train_loss.avg
-        if (self.mode == ReduceMode.MIN and current < self.best) or (
-            self.mode == ReduceMode.MAX and current > self.best
-        ):
-            ep = self.state.epoch
-            # print(f"Epoch {ep}: best loss improved from {self.best:.4f} to {current:.4f}")
+        current = self.get_monitor_value()
+        if self.monitor_op(current, self.best):
+            ep = self.state.epoch_log
+            if self.verbose:
+                print(f"Epoch {ep:2d}: best {self.monitor} improved from {self.best:.4f} to {current:.4f}")
             self.best = current
             save_name = os.path.join(self.save_dir, self.save_name.format(ep=ep, metric=current))
             self._save_checkpoint(save_name)
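
The min/max branches are collapsed into a stored comparison function (`np.less` for "min", `np.greater` for "max"), the same pattern Keras's ModelCheckpoint uses. A standalone sketch of why this simplifies the improvement check (illustration only, not from the diff):

    import numpy as np

    # "min" mode: a lower value is an improvement
    best, monitor_op = np.inf, np.less
    print(monitor_op(0.35, best))   # True: 0.35 improves on inf

    # "max" mode: a higher value is an improvement
    best, monitor_op = -np.inf, np.greater
    print(monitor_op(0.91, best))   # True: 0.91 improves on -inf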
@@ -317,6 +328,18 @@ def _save_checkpoint(self, path):
             save_dict["optimizer"] = self.state.optimizer.state_dict()
         torch.save(save_dict, path)

+    def get_monitor_value(self):
+        value = None
+        if self.monitor == "loss":
+            value = self.state.loss_meter.avg
+        else:
+            for metric_meter in self.state.metric_meters:
+                if metric_meter.name == self.monitor:
+                    value = metric_meter.avg
+        if value is None:
+            raise ValueError(f"CheckpointSaver can't find {self.monitor} value to monitor")
+        return value
+

 class TensorBoard(Callback):
     """
@@ -407,7 +430,7 @@ def on_batch_end(self):
     def on_loader_end(self):
         super().on_loader_end()
-        f = plot_confusion_matrix(self.cmap, self.class_names, show=False)
+        f = plot_confusion_matrix(self.cmap, self.class_names, normalize=True, show=False)
         cm_img = render_figure_to_tensor(f)
         if self.state.is_train:
             self.train_cm_img = cm_img
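
The only change here is `normalize=True`, so the logged confusion matrix shows per-class rates instead of raw counts. Presumably the helper divides each row by its sum; a minimal sketch of that assumed behavior (the diff does not show `plot_confusion_matrix` itself):

    import numpy as np

    cm = np.array([[50, 10], [5, 35]], dtype=np.float64)
    cm_norm = cm / cm.sum(axis=1, keepdims=True)
    # row 0 becomes [0.833, 0.167]: 83% of class-0 samples were predicted correctly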
@@ -527,10 +550,11 @@ def mixup(self, data, target):
         if not self.state.is_train or np.random.rand() > self.prob:
             return data, target_one_hot
         prev_data, prev_target = (data, target_one_hot) if self.prev_input is None else self.prev_input
-        self.prev_input = data, target_one_hot
+        self.prev_input = data.clone(), target_one_hot.clone()
+        perm = torch.randperm(data.size(0)).cuda()
         c = self.tb.sample()
-        md = c * data + (1 - c) * prev_data
-        mt = c * target_one_hot + (1 - c) * prev_target
+        md = c * data + (1 - c) * prev_data[perm]
+        mt = c * target_one_hot + (1 - c) * prev_target[perm]
         return md, mt
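
Two fixes land together here: the stored `prev_input` is cloned so it cannot alias a batch that gets mutated later, and the saved batch is shuffled with `randperm` so each sample mixes with a random partner rather than the sample at the same index. A standalone sketch of the mixing step (shapes illustrative; `c` stands in for the Beta sample from `self.tb.sample()`):

    import torch

    data = torch.randn(4, 3, 8, 8)        # current batch
    prev_data = torch.randn(4, 3, 8, 8)   # batch saved on the previous call
    perm = torch.randperm(data.size(0))   # random cross-batch pairing
    c = 0.7                               # stand-in for self.tb.sample()
    mixed = c * data + (1 - c) * prev_data[perm]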
@@ -570,16 +594,17 @@ def cutmix(self, data, target):
         if not self.state.is_train or np.random.rand() > self.prob:
             return data, target_one_hot
         prev_data, prev_target = (data, target_one_hot) if self.prev_input is None else self.prev_input
-        self.prev_input = data, target_one_hot
+        self.prev_input = data.clone(), target_one_hot.clone()
         # prev_data shape can be different from current, so take the min
         H, W = min(data.size(2), prev_data.size(2)), min(data.size(3), prev_data.size(3))
+        perm = torch.randperm(data.size(0)).cuda()
         lam = self.tb.sample()
         lam = min([lam, 1 - lam])
         bbh1, bbw1, bbh2, bbw2 = self.rand_bbox(H, W, lam)
         # real lambda may be different from sampled; adjust for it
         lam = (bbh2 - bbh1) * (bbw2 - bbw1) / (H * W)
-        data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[:, :, bbh1:bbh2, bbw1:bbw2]
-        mixed_target = (1 - lam) * target_one_hot + lam * prev_target
+        data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[perm, :, bbh1:bbh2, bbw1:bbw2]
+        mixed_target = (1 - lam) * target_one_hot + lam * prev_target[perm]
         return data, mixed_target

     @staticmethod
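
Note how the sampled lambda is only a request: after `rand_bbox` produces a concrete box, the mixing weight is recomputed from the actual box area so the target weights match the pixels that were actually pasted. A worked check with illustrative numbers:

    H, W = 224, 224
    bbh1, bbh2, bbw1, bbw2 = 10, 110, 20, 120   # box as rand_bbox might return it
    lam = (bbh2 - bbh1) * (bbw2 - bbw1) / (H * W)
    # 100 * 100 / 50176 ≈ 0.199: the pasted region carries ~20% of the target weight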
@@ -609,11 +634,32 @@ def cutmix(self, data, target):
         if not self.state.is_train or np.random.rand() > self.prob:
             return data, target
         prev_data, prev_target = (data, target) if self.prev_input is None else self.prev_input
-        self.prev_input = data, target
+        self.prev_input = data.clone(), target.clone()
         H, W = min(data.size(2), prev_data.size(2)), min(data.size(3), prev_data.size(3))
+        perm = torch.randperm(data.size(0)).cuda()
         lam = self.tb.sample()
         lam = min([lam, 1 - lam])
         bbh1, bbw1, bbh2, bbw2 = self.rand_bbox(H, W, lam)
-        data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[:, :, bbh1:bbh2, bbw1:bbw2]
-        target[:, :, bbh1:bbh2, bbw1:bbw2] = prev_target[:, :, bbh1:bbh2, bbw1:bbw2]
+        data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[perm, :, bbh1:bbh2, bbw1:bbw2]
+        target[:, :, bbh1:bbh2, bbw1:bbw2] = prev_target[perm, :, bbh1:bbh2, bbw1:bbw2]
         return data, target
+
+
+class ScheduledDropout(Callback):
+    def __init__(self, drop_rate=0.1, epochs=30, attr_name="dropout.p"):
+        """
+        Slowly increases the dropout probability at `attr_name` each epoch.
+        Ref: https://arxiv.org/abs/1703.06229
+        Args:
+            drop_rate (float): max dropout rate
+            epochs (int): number of epochs for dropout to reach its max value
+            attr_name (str): dotted path to the dropout probability in the model
+        """
+        super().__init__()
+        self.drop_rate = drop_rate
+        self.epochs = epochs
+        self.attr_name = attr_name
+
+    def on_epoch_end(self):
+        current_rate = self.drop_rate * min(1, self.state.epoch / self.epochs)
+        # walk the dotted path: setattr(model, "dropout.p", ...) would not reach model.dropout
+        *module_path, attr = self.attr_name.split(".")
+        module = self.state.model
+        for name in module_path:
+            module = getattr(module, name)
+        setattr(module, attr, current_rate)
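
A usage sketch, assuming the model exposes its dropout layer as a submodule named `dropout` (the `Net` class is a hypothetical example, not part of this diff):

    import torch.nn as nn

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(128, 10)
            self.dropout = nn.Dropout(p=0.0)  # starts at 0; the callback raises it each epoch

        def forward(self, x):
            return self.fc(self.dropout(x))

    # p grows linearly from 0 to 0.1 over the first 30 epochs, then stays at 0.1
    callback = ScheduledDropout(drop_rate=0.1, epochs=30, attr_name="dropout.p")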