@@ -478,6 +478,11 @@ def func(modelcls: AnyModel) -> AnyModel:
478478 return modelcls
479479 return func
480480
481+ @classmethod
482+ def print_registered_models (cls ):
483+ for name in cls ._model_classes .keys ():
484+ logger .error (f"- { name } " )
485+
481486 @classmethod
482487 def from_model_architecture (cls , arch : str ) -> type [Model ]:
483488 try :
@@ -4929,6 +4934,7 @@ def parse_args() -> argparse.Namespace:
49294934 parser .add_argument (
49304935 "model" , type = Path ,
49314936 help = "directory containing model file" ,
4937+ nargs = "?" ,
49324938 )
49334939 parser .add_argument (
49344940 "--use-temp-file" , action = "store_true" ,
@@ -4966,8 +4972,15 @@ def parse_args() -> argparse.Namespace:
49664972 "--metadata" , type = Path ,
49674973 help = "Specify the path for an authorship metadata override file"
49684974 )
4975+ parser .add_argument (
4976+ "--print-supported-models" , action = "store_true" ,
4977+ help = "Print the supported models"
4978+ )
49694979
4970- return parser .parse_args ()
4980+ args = parser .parse_args ()
4981+ if not args .print_supported_models and args .model is None :
4982+ parser .error ("the following arguments are required: model" )
4983+ return args
49714984
49724985
49734986def split_str_to_n_bytes (split_str : str ) -> int :
@@ -4991,6 +5004,11 @@ def split_str_to_n_bytes(split_str: str) -> int:
49915004def main () -> None :
49925005 args = parse_args ()
49935006
5007+ if args .print_supported_models :
5008+ logger .error ("Supported models:" )
5009+ Model .print_registered_models ()
5010+ sys .exit (0 )
5011+
49945012 if args .verbose :
49955013 logging .basicConfig (level = logging .DEBUG )
49965014 else :
0 commit comments