@@ -42,6 +42,12 @@
 
 has_compile = hasattr(torch, 'compile')
 
+try:
+    from sklearn.metrics import precision_score, recall_score, f1_score
+    has_sklearn = True
+except ImportError:
+    has_sklearn = False
+
 _logger = logging.getLogger('validate')
 
 
@@ -158,6 +164,11 @@
 parser.add_argument('--retry', default=False, action='store_true',
                     help='Enable batch size decay & retry for single model validation')
 
+parser.add_argument('--metrics-avg', type=str, default=None,
+                    choices=['micro', 'macro', 'weighted'],
+                    help='Enable precision, recall, F1-score calculation and specify the averaging method. '
+                         'Requires scikit-learn. (default: None)')
+
 # NaFlex loader arguments
 parser.add_argument('--naflex-loader', action='store_true', default=False,
                     help='Use NaFlex loader (Requires NaFlex compatible model)')
@@ -176,6 +187,11 @@ def validate(args):
 
     device = torch.device(args.device)
 
+    if args.metrics_avg and not has_sklearn:
+        _logger.warning(
+            "scikit-learn not installed, disabling metrics calculation. Please install with 'pip install scikit-learn'.")
+        args.metrics_avg = None
+
     model_dtype = None
     if args.model_dtype:
         assert args.model_dtype in ('float32', 'float16', 'bfloat16')
@@ -346,6 +362,10 @@ def validate(args):
     top1 = AverageMeter()
     top5 = AverageMeter()
 
+    if args.metrics_avg:
+        all_preds = []
+        all_targets = []
+
     model.eval()
     with torch.inference_mode():
         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
@@ -382,6 +402,11 @@ def validate(args):
             top1.update(acc1.item(), batch_size)
             top5.update(acc5.item(), batch_size)
 
+            if args.metrics_avg:
+                predictions = torch.argmax(output, dim=1)
+                all_preds.append(predictions.cpu())
+                all_targets.append(target.cpu())
+
             # measure elapsed time
             batch_time.update(time.time() - end)
             end = time.time()
@@ -408,18 +433,41 @@ def validate(args):
         top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
     else:
         top1a, top5a = top1.avg, top5.avg
+
+    metric_results = {}
+    if args.metrics_avg:
+        all_preds = torch.cat(all_preds).numpy()
+        all_targets = torch.cat(all_targets).numpy()
+        precision = precision_score(all_targets, all_preds, average=args.metrics_avg, zero_division=0)
+        recall = recall_score(all_targets, all_preds, average=args.metrics_avg, zero_division=0)
+        f1 = f1_score(all_targets, all_preds, average=args.metrics_avg, zero_division=0)
+        metric_results = {
+            f'{args.metrics_avg}_precision': round(100 * precision, 4),
+            f'{args.metrics_avg}_recall': round(100 * recall, 4),
+            f'{args.metrics_avg}_f1_score': round(100 * f1, 4),
+        }
+
     results = OrderedDict(
         model=args.model,
         top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
         top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
+        **metric_results,
         param_count=round(param_count / 1e6, 2),
         img_size=data_config['input_size'][-1],
         crop_pct=crop_pct,
         interpolation=data_config['interpolation'],
     )
 
-    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
-        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
+    log_string = ' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
+        results['top1'], results['top1_err'], results['top5'], results['top5_err'])
+    if metric_results:
+        log_string += ' | Precision({avg}) {prec:.3f} | Recall({avg}) {rec:.3f} | F1-score({avg}) {f1:.3f}'.format(
+            avg=args.metrics_avg,
+            prec=metric_results[f'{args.metrics_avg}_precision'],
+            rec=metric_results[f'{args.metrics_avg}_recall'],
+            f1=metric_results[f'{args.metrics_avg}_f1_score'],
+        )
+    _logger.info(log_string)
 
     return results
 
@@ -534,4 +582,4 @@ def write_results(results_file, results, format='csv'):
 
 
 if __name__ == '__main__':
-    main()
+    main()
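
As a side note on how the new `--metrics-avg` choices behave, here is a minimal standalone sketch of the metrics path added above, using the same `sklearn.metrics` calls and `zero_division=0` guard as the patch. The toy predictions and targets are made up for illustration and stand in for the `all_preds` / `all_targets` lists that `validate()` accumulates batch by batch:

```python
import torch
from sklearn.metrics import precision_score, recall_score, f1_score

# Hypothetical 3-class outputs standing in for the per-batch tensors
# collected in validate(); a real run has one entry per batch.
all_preds = [torch.tensor([0, 1, 2, 2]), torch.tensor([1, 1, 0, 2])]
all_targets = [torch.tensor([0, 1, 1, 2]), torch.tensor([1, 1, 0, 0])]

preds = torch.cat(all_preds).numpy()
targets = torch.cat(all_targets).numpy()

for avg in ('micro', 'macro', 'weighted'):
    # zero_division=0 returns 0.0 (instead of warning) for any class
    # with an empty denominator, e.g. a class that is never predicted
    p = precision_score(targets, preds, average=avg, zero_division=0)
    r = recall_score(targets, preds, average=avg, zero_division=0)
    f = f1_score(targets, preds, average=avg, zero_division=0)
    print(f'{avg:>8}: precision {p:.3f}  recall {r:.3f}  f1 {f:.3f}')
```

The three averaging modes follow scikit-learn's definitions: `micro` pools true/false positives and negatives across all classes (for single-label multiclass data it equals top-1 accuracy), `macro` takes the unweighted mean of per-class scores, and `weighted` averages per-class scores weighted by each class's support.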