1616
1717import logging
1818import warnings
19+ import os
20+
1921from collections import OrderedDict
2022
2123from tensorflow_tts .configs import (
4042 SavableTFFastSpeech2 ,
4143 SavableTFTacotron2
4244)
45+ from tensorflow_tts .utils import CACHE_DIRECTORY , MODEL_FILE_NAME , LIBRARY_NAME
46+ from tensorflow_tts import __version__ as VERSION
47+ from huggingface_hub import hf_hub_url , cached_download
4348
4449
4550TF_MODEL_MAPPING = OrderedDict (
@@ -62,8 +67,35 @@ def __init__(self):
6267 raise EnvironmentError ("Cannot be instantiated using `__init__()`" )
6368
6469 @classmethod
65- def from_pretrained (cls , config , pretrained_path = None , ** kwargs ):
70+ def from_pretrained (cls , config = None , pretrained_path = None , ** kwargs ):
6671 is_build = kwargs .pop ("is_build" , True )
72+
73+ # load weights from hf hub
74+ if pretrained_path is not None :
75+ if not os .path .isfile (pretrained_path ):
76+ # retrieve correct hub url
77+ download_url = hf_hub_url (repo_id = pretrained_path , filename = MODEL_FILE_NAME )
78+
79+ downloaded_file = str (
80+ cached_download (
81+ url = download_url ,
82+ library_name = LIBRARY_NAME ,
83+ library_version = VERSION ,
84+ cache_dir = CACHE_DIRECTORY ,
85+ )
86+ )
87+
88+ # load config from repo as well
89+ if config is None :
90+ from tensorflow_tts .inference import AutoConfig
91+
92+ config = AutoConfig .from_pretrained (pretrained_path )
93+
94+ pretraine_path = downloaded_file
95+
96+
97+ assert config is not None , "Please make sure to pass a config along to load a model from a local file"
98+
6799 for config_class , model_class in TF_MODEL_MAPPING .items ():
68100 if isinstance (config , config_class ) and str (config_class .__name__ ) in str (
69101 config
@@ -79,6 +111,7 @@ def from_pretrained(cls, config, pretrained_path=None, **kwargs):
79111 pretrained_path , by_name = True , skip_mismatch = True
80112 )
81113 return model
114+
82115 raise ValueError (
83116 "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n "
84117 "Model type should be one of {}." .format (
0 commit comments