Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/inference/SDK/inference-jumpstart-e2e.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
"outputs": [],
"source": [
"# Import the helper module\n",
"from jumpstart_public_hub_visualization_utils import get_all_public_hub_model_data\n",
"from sagemaker.hyperpod.inference.jumpstart_public_hub_visualization_utils import get_all_public_hub_model_data\n",
"\n",
"# Load and display SageMaker public hub models\n",
"get_all_public_hub_model_data(region=\"us-east-2\")"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import itables
import pandas
import logging
import json
from botocore.config import Config
from ipywidgets import Button, Output
from IPython.display import display
Expand Down Expand Up @@ -160,6 +161,7 @@ def _get_model_summary(self, full_summary):
"Model Type": model_type,
"Model Description": full_summary["HubContentDescription"],
"Search Keywords": keywords,
"Deployment Configs": self._create_config_link(full_summary["HubContentName"]),
}

def _determine_model_type(self, keywords, model_id):
Expand All @@ -180,6 +182,84 @@ def _get_hub_document(self, model_id):
HubContentType="Model",
HubContentName=model_id
)["HubContentDocument"]

def _get_supported_instance_types(self, model_id):
"""Extract supported instance types from hub document."""
try:
hub_doc = self._get_hub_document(model_id)
doc_data = json.loads(hub_doc)

supported_types = doc_data.get("SupportedInferenceInstanceTypes", [])
default_type = doc_data.get("DefaultInferenceInstanceType")

if default_type and default_type in supported_types:
supported_types = [default_type] + [t for t in supported_types if t != default_type]

return {"types": supported_types, "default": default_type, "error": None}
except Exception as e:
return {"types": [], "default": None, "error": str(e)}

def _create_config_link(self, model_id):
"""Create deployment config display using collapsible details for all environments."""
return f'<details><summary style="color: #007bff; cursor: pointer;">View SDK Config</summary><pre style="font-size: 10px; background: #f5f5f5; padding: 5px; margin: 5px 0;">{self._generate_deployment_config(model_id)}</pre></details>'

def _generate_deployment_config(self, model_id):
"""Generate deployment configuration code for a model."""
instance_data = self._get_supported_instance_types(model_id)
supported_types = instance_data["types"]
default_type = instance_data["default"]
error = instance_data["error"]

if error:
instance_type = '<ENTER-INSTANCE-TYPE>'
types_comment = ""
else:
instance_type = default_type if default_type else '\<ENTER-INSTANCE-TYPE\>'
types_comment = self._format_instance_types_comment(supported_types)

config_code = f'''# Deployment configuration for {model_id}
from sagemaker.hyperpod.inference.config.hp_jumpstart_endpoint_config import (
Model, Server, SageMakerEndpoint
)
from sagemaker.hyperpod.inference.hp_jumpstart_endpoint import HPJumpStartEndpoint

{types_comment}

# Create configs - REPLACE PLACEHOLDER VALUE BELOW
model = Model(
model_id='{model_id}',
)
server = Server(
instance_type='{instance_type}',
)
endpoint_name = SageMakerEndpoint(name='ENTER-YOUR-ENDPOINT-NAME')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Can we have a variable and use that for the endpoint name?

Something along the lines of model_id+timestamp would work for endpoint name too


# Create endpoint spec
js_endpoint = HPJumpStartEndpoint(
model=model,
server=server,
sage_maker_endpoint=endpoint_name,
)

# Deploy the endpoint
js_endpoint.create()'''
return config_code

def _format_instance_types_comment(self, supported_types):
"""Format instance types comment with line breaks for better readability."""
if not supported_types:
return "# No supported instance types found"

if len(supported_types) <= 5:
return f"# Supported instance types: {', '.join(supported_types)}"

# For more than 5 instance types, format with newlines every 5 types
comment_lines = ["# Supported instance types:"]
for i in range(0, len(supported_types), 5):
batch = supported_types[i:i+5]
comment_lines.append(f"# {', '.join(batch)}")

return '\n'.join(comment_lines)


def get_all_public_hub_model_data(region: str):
Expand All @@ -198,14 +278,14 @@ def interactive_view(tabular_data: list):
styled_df = _style_dataframe(df)
layout = _get_table_layout(len(tabular_data))

itables.show(styled_df, layout=layout)
itables.show(styled_df, layout=layout, allow_html=True)


def _configure_itables():
    """Set up itables for notebook use.

    Calls ``itables.init_notebook_mode`` with ``all_interactive=True`` and
    then sets the ``allow_html`` option to ``True``.
    """
    itables.init_notebook_mode(all_interactive=True)

    table_options = itables.options
    table_options.allow_html = True


def _style_dataframe(df):
"""Apply styling to dataframe."""
Expand All @@ -216,4 +296,4 @@ def _style_dataframe(df):

def _get_table_layout(data_length):
"""Get appropriate table layout based on data size."""
return {} if data_length > 10 else {"topStart": None, "topEnd": "search"}
return {} if data_length > 10 else {"topStart": None, "topEnd": "search"}
Loading