@@ -183,7 +183,21 @@ def _get_hub_document(self, model_id):
183
183
HubContentName = model_id
184
184
)["HubContentDocument" ]
185
185
186
-
186
def _get_supported_instance_types(self, model_id):
    """Extract supported instance types from the model's hub document.

    The hub document returned by ``self._get_hub_document(model_id)`` is
    parsed as JSON and the ``SupportedInferenceInstanceTypes`` and
    ``DefaultInferenceInstanceType`` entries are read from it. When the
    default type is one of the supported types it is moved to the front
    of the returned list.

    Args:
        model_id: Identifier passed through to ``_get_hub_document``.

    Returns:
        A dict with three keys:
            "types":   the supported instance types (default first when
                       present), or ``[]`` when none are available.
            "default": the default instance type, or ``None``.
            "error":   ``None`` on success, otherwise the string form of
                       the exception that prevented the lookup.

    This method never raises: any failure while fetching or parsing the
    document is reported through the "error" entry so callers can fall
    back to a placeholder.
    """
    try:
        hub_doc = self._get_hub_document(model_id)
        doc_data = json.loads(hub_doc)

        # ``.get(key, [])`` only covers a missing key. An explicit JSON
        # ``null`` would come back as None and make the membership test
        # below raise TypeError, discarding an otherwise valid default.
        supported_types = doc_data.get("SupportedInferenceInstanceTypes") or []
        default_type = doc_data.get("DefaultInferenceInstanceType")

        # List the default first so it is the most visible suggestion.
        if default_type and default_type in supported_types:
            supported_types = [default_type] + [
                t for t in supported_types if t != default_type
            ]

        return {"types": supported_types, "default": default_type, "error": None}
    except Exception as e:
        # Deliberate best-effort: surface the failure to the caller
        # instead of propagating it.
        return {"types": [], "default": None, "error": str(e)}
187
201
188
202
def _create_config_link (self , model_id ):
189
203
"""Create an HTML link that generates deployment config."""
@@ -192,16 +206,30 @@ def _create_config_link(self, model_id):
192
206
193
207
def _generate_deployment_config (self , model_id ):
194
208
"""Generate deployment configuration code for a model."""
209
+ instance_data = self ._get_supported_instance_types (model_id )
210
+ supported_types = instance_data ["types" ]
211
+ default_type = instance_data ["default" ]
212
+ error = instance_data ["error" ]
213
+
214
+ if error :
215
+ instance_type = '<ENTER-INSTANCE-TYPE>'
216
+ types_comment = ""
217
+ else :
218
+ instance_type = default_type if default_type else '<ENTER-INSTANCE-TYPE>'
219
+ types_comment = f"# Supported instance types: { ', ' .join (supported_types )} " if supported_types else "# No supported instance types found"
220
+
195
221
config_code = f'''# Deployment configuration for { model_id }
196
222
from sagemaker.hyperpod.inference.config.hp_jumpstart_endpoint_config import Model, Server, SageMakerEndpoint
197
223
from sagemaker.hyperpod.inference.hp_jumpstart_endpoint import HPJumpStartEndpoint
198
224
199
- # Create configs - REPLACE PLACEHOLDER VALUES BELOW
225
+ { types_comment }
226
+
227
+ # Create configs - REPLACE PLACEHOLDER VALUE BELOW
200
228
model = Model(
201
229
model_id='{ model_id } ',
202
230
)
203
231
server = Server(
204
- instance_type='<ENTER-INSTANCE-TYPE> ',
232
+ instance_type='{ instance_type } ',
205
233
)
206
234
endpoint_name = SageMakerEndpoint(name='<ENTER-YOUR-ENDPOINT-NAME>')
207
235
0 commit comments