Skip to content

Commit 68d87de

Browse files
update print for inference CLI for list and describe, bug fix for since-hours flag to support float, minor update to notebook (#85)
1 parent 389d6ae commit 68d87de

File tree

4 files changed

+258
-16
lines changed

4 files changed

+258
-16
lines changed

examples/inference/CLI/inference-fsx-model-e2e-cli.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
"metadata": {},
2626
"outputs": [],
2727
"source": [
28-
"!hyperpod list-cluster"
28+
"!hyperpod list-cluster --output table"
2929
]
3030
},
3131
{
@@ -114,7 +114,7 @@
114114
"metadata": {},
115115
"outputs": [],
116116
"source": [
117-
"!hyp get-operator-logs hyp-custom-endpoint --since-hours 4"
117+
"!hyp get-operator-logs hyp-custom-endpoint --since-hours 0.5"
118118
]
119119
},
120120
{

examples/inference/CLI/inference-jumpstart-e2e-cli.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
"metadata": {},
2626
"outputs": [],
2727
"source": [
28-
"!hyperpod list-cluster"
28+
"!hyperpod list-cluster --output table"
2929
]
3030
},
3131
{
@@ -101,7 +101,7 @@
101101
"metadata": {},
102102
"outputs": [],
103103
"source": [
104-
"!hyp get-operator-logs hyp-jumpstart-endpoint --since-hours 4"
104+
"!hyp get-operator-logs hyp-jumpstart-endpoint --since-hours 0.5"
105105
]
106106
},
107107
{

examples/inference/CLI/inference-s3-model-e2e-cli.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
"metadata": {},
2626
"outputs": [],
2727
"source": [
28-
"!hyperpod list-cluster"
28+
"!hyperpod list-cluster --output table"
2929
]
3030
},
3131
{
@@ -128,7 +128,7 @@
128128
"metadata": {},
129129
"outputs": [],
130130
"source": [
131-
"!hyp get-operator-logs hyp-custom-endpoint --since-hours 4"
131+
"!hyp get-operator-logs hyp-custom-endpoint --since-hours 0.5"
132132
]
133133
},
134134
{

src/sagemaker/hyperpod/cli/commands/inference.py

Lines changed: 252 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import boto3
44
from typing import Optional
5+
from tabulate import tabulate
56

67
from sagemaker.hyperpod.cli.inference_utils import generate_click_command
78
from jumpstart_inference_config_schemas.registry import SCHEMA_REGISTRY as JS_REG
@@ -104,8 +105,18 @@ def js_list(
104105
"""
105106

106107
endpoints = HPJumpStartEndpoint.model_construct().list(namespace)
107-
out = [ep.metadata.model_dump() for ep in endpoints]
108-
click.echo(json.dumps(out, indent=2))
108+
data = [ep.metadata.model_dump() for ep in endpoints]
109+
110+
if not data:
111+
click.echo("No endpoints found")
112+
return
113+
114+
headers = ["name", "namespace", "labels"]
115+
rows = [
116+
[item.get("name", ""), item.get("namespace", ""), item.get("labels", "")]
117+
for item in data
118+
]
119+
click.echo(tabulate(rows, headers=headers, tablefmt="github"))
109120

110121

111122
@click.command("hyp-custom-endpoint")
@@ -124,8 +135,18 @@ def custom_list(
124135
"""
125136

126137
endpoints = HPEndpoint.model_construct().list(namespace)
127-
out = [ep.metadata.model_dump() for ep in endpoints]
128-
click.echo(json.dumps(out, indent=2))
138+
data = [ep.metadata.model_dump() for ep in endpoints]
139+
140+
if not data:
141+
click.echo("No endpoints found")
142+
return
143+
144+
headers = ["name", "namespace", "labels"]
145+
rows = [
146+
[item.get("name", ""), item.get("namespace", ""), item.get("labels", "")]
147+
for item in data
148+
]
149+
click.echo(tabulate(rows, headers=headers, tablefmt="github"))
129150

130151

131152
@click.command("hyp-jumpstart-endpoint")
@@ -142,16 +163,86 @@ def custom_list(
142163
default="default",
143164
help="Optional. The namespace of the jumpstart model to describe. Default set to 'default'.",
144165
)
166+
@click.option(
167+
"--full",
168+
type=click.BOOL,
169+
is_flag=True,
170+
default=False,
171+
required=False,
172+
help="Optional. If set to `True`, the full json will be displayed",
173+
)
145174
def js_describe(
146175
name: str,
147176
namespace: Optional[str],
177+
full: bool
148178
):
149179
"""
150180
Describe a jumpstart model endpoint with provided name and namespace.
151181
"""
152182

153183
my_endpoint = HPJumpStartEndpoint.model_construct().get(name, namespace)
154-
click.echo(json.dumps(my_endpoint.model_dump(), indent=2))
184+
data = my_endpoint.model_dump()
185+
186+
if full:
187+
click.echo("\nFull JSON:")
188+
click.echo(json.dumps(data, indent=2))
189+
190+
else:
191+
summary = [
192+
("Deployment State:", data.get("status", {}).get("deploymentStatus", {}).get("deploymentObjectOverallState")),
193+
("Model ID:", data.get("model", {}).get("modelId")),
194+
("Instance Type:", data.get("server", {}).get("instanceType")),
195+
("Accept eula:", data.get("model", {}).get("acceptEula")),
196+
("Model Version:", data.get("model", {}).get("modelVersion")),
197+
("TLS Cert. Output S3 URI:",data.get("tlsConfig", {}).get("tlsCertificateOutputS3Uri")),
198+
]
199+
click.echo(tabulate(summary, tablefmt="plain"))
200+
201+
click.echo("\nSageMaker Endpoint:")
202+
ep_rows = [
203+
("State:", data.get("status", {}).get("endpoints", {}).get("sagemaker", {}).get("state")),
204+
("Name:", data.get("sageMakerEndpoint", {}).get("name")),
205+
("ARN:", data.get("status", {}).get("endpoints", {}).get("sagemaker", {}).get("endpointArn")),
206+
]
207+
click.echo(tabulate(ep_rows, tablefmt="plain"))
208+
209+
click.echo("\nConditions:")
210+
conds = data.get("status", {}).get("conditions", [])
211+
if conds:
212+
headers = ["TYPE", "STATUS", "LAST TRANSITION", "LAST UPDATE", "MESSAGE"]
213+
rows = [
214+
[
215+
c.get("type", ""),
216+
c.get("status", ""),
217+
c.get("lastTransitionTime", ""),
218+
c.get("lastUpdateTime", ""),
219+
c.get("message") or ""
220+
]
221+
for c in conds
222+
]
223+
click.echo(tabulate(rows, headers=headers, tablefmt="github"))
224+
else:
225+
click.echo(" <none>")
226+
227+
click.echo("\nDeploymentStatus Conditions:")
228+
dep_status = data.get("status", {}).get("deploymentStatus", {})
229+
dep_conds = dep_status.get("status", {}).get("conditions", [])
230+
if dep_conds:
231+
headers = ["TYPE", "STATUS", "LAST TRANSITION", "LAST UPDATE", "MESSAGE"]
232+
rows = [
233+
[
234+
c.get("type", ""),
235+
c.get("status", ""),
236+
c.get("lastTransitionTime", ""),
237+
c.get("lastUpdateTime", ""),
238+
c.get("message") or ""
239+
]
240+
for c in dep_conds
241+
]
242+
click.echo(tabulate(rows, headers=headers, tablefmt="github"))
243+
else:
244+
click.echo(" <none>")
245+
155246

156247

157248
@click.command("hyp-custom-endpoint")
@@ -168,16 +259,167 @@ def js_describe(
168259
default="default",
169260
help="Optional. The namespace of the custom model to describe. Default set to 'default'.",
170261
)
262+
@click.option(
263+
"--full",
264+
type=click.BOOL,
265+
is_flag=True,
266+
default=False,
267+
required=False,
268+
help="Optional. If set to `True`, the full json will be displayed",
269+
)
171270
def custom_describe(
172271
name: str,
173272
namespace: Optional[str],
273+
full: bool
174274
):
175275
"""
176276
Describe a custom model endpoint with provided name and namespace.
177277
"""
178278

179279
my_endpoint = HPEndpoint.model_construct().get(name, namespace)
180-
click.echo(json.dumps(my_endpoint.model_dump(), indent=2))
280+
data = my_endpoint.model_dump()
281+
282+
if full:
283+
click.echo("\nFull JSON:")
284+
click.echo(json.dumps(data, indent=2))
285+
286+
else:
287+
summary = [
288+
("Deployment State:", data.get("status", {}).get("deploymentStatus", {}).get("deploymentObjectOverallState")),
289+
("Invocation Endpoint", data.get("invocationEndpoint")),
290+
("Instance Type", data.get("instanceType")),
291+
("Metrics Enabled", data.get("metrics", {}).get("enabled")),
292+
("Model Name", data.get("modelName")),
293+
("Model Version", data.get("modelVersion")),
294+
("Model Source Type", data.get("modelSourceConfig", {}).get("modelSourceType")),
295+
("Model Location", data.get("modelSourceConfig", {}).get("modelLocation")),
296+
("Prefetch Enabled", data.get("modelSourceConfig", {}).get("prefetchEnabled")),
297+
("TLS Cert S3 URI", data.get("tlsConfig", {}).get("tlsCertificateOutputS3Uri")),
298+
("FSx DNS Name", data.get("modelSourceConfig", {}).get("fsxStorage", {}).get("dnsName")),
299+
("FSx File System ID", data.get("modelSourceConfig", {}).get("fsxStorage", {}).get("fileSystemId")),
300+
("FSx Mount Name", data.get("modelSourceConfig", {}).get("fsxStorage", {}).get("mountName")),
301+
("S3 Bucket Name", data.get("modelSourceConfig", {}).get("s3Storage", {}).get("bucketName")),
302+
("S3 Region", data.get("modelSourceConfig", {}).get("s3Storage", {}).get("region")),
303+
("Image URI", data.get("imageUri")
304+
or data.get("worker", {}).get("image")),
305+
("Container Port", data.get("containerPort")
306+
or data.get("worker", {})
307+
.get("modelInvocationPort", {})
308+
.get("containerPort")),
309+
("Model Volume Mount Path", data.get("modelVolumeMountPath")
310+
or data.get("worker", {})
311+
.get("modelVolumeMount", {})
312+
.get("mountPath")),
313+
("Model Volume Mount Name", data.get("modelVolumeMountName")
314+
or data.get("worker", {})
315+
.get("modelVolumeMount", {})
316+
.get("name")),
317+
("Resources Limits", data.get("resourcesLimits")
318+
or data.get("worker", {})
319+
.get("resources", {})
320+
.get("limits")),
321+
("Resources Requests", data.get("resourcesRequests")
322+
or data.get("worker", {})
323+
.get("resources", {})
324+
.get("requests")),
325+
("Dimensions", data.get("dimensions")
326+
or data.get("autoScalingSpec", {})
327+
.get("cloudWatchTrigger", {})
328+
.get("dimensions")),
329+
("Metric Collection Period", data.get("metricCollectionPeriod")
330+
or data.get("autoScalingSpec", {})
331+
.get("cloudWatchTrigger", {})
332+
.get("metricCollectionPeriod")),
333+
("Metric Collection Start Time",data.get("metricCollectionStartTime")
334+
or data.get("autoScalingSpec", {})
335+
.get("cloudWatchTrigger", {})
336+
.get("metricCollectionStartTime")),
337+
("Metric Name", data.get("metricName")
338+
or data.get("autoScalingSpec", {})
339+
.get("cloudWatchTrigger", {})
340+
.get("metricName")),
341+
("Metric Stat", data.get("metricStat")
342+
or data.get("autoScalingSpec", {})
343+
.get("cloudWatchTrigger", {})
344+
.get("metricStat")),
345+
("Metric Type", data.get("metricType")
346+
or data.get("autoScalingSpec", {})
347+
.get("cloudWatchTrigger", {})
348+
.get("metricType")),
349+
("Min Value", data.get("minValue")
350+
or data.get("autoScalingSpec", {})
351+
.get("cloudWatchTrigger", {})
352+
.get("minValue")),
353+
("CW Trigger Name", data.get("cloudWatchTriggerName")
354+
or data.get("autoScalingSpec", {})
355+
.get("cloudWatchTrigger", {})
356+
.get("name")),
357+
("CW Trigger Namespace", data.get("cloudWatchTriggerNamespace")
358+
or data.get("autoScalingSpec", {})
359+
.get("cloudWatchTrigger", {})
360+
.get("namespace")),
361+
("Target Value", data.get("targetValue")
362+
or data.get("autoScalingSpec", {})
363+
.get("cloudWatchTrigger", {})
364+
.get("targetValue")),
365+
("Use Cached Metrics", data.get("useCachedMetrics")
366+
or data.get("autoScalingSpec", {})
367+
.get("cloudWatchTrigger", {})
368+
.get("useCachedMetrics")),
369+
]
370+
371+
click.echo(tabulate(summary, tablefmt="plain"))
372+
373+
click.echo("\nSageMaker Endpoint:")
374+
status = data.get("status") or {}
375+
endpoints = status.get("endpoints") or {}
376+
sagemaker_info = endpoints.get("sagemaker")
377+
if not sagemaker_info:
378+
click.secho(" <no SageMaker endpoint information available>", fg="yellow")
379+
else:
380+
ep_rows = [
381+
("State:", data.get("status", {}).get("endpoints", {}).get("sagemaker", {}).get("state")),
382+
("Name:", data.get("sageMakerEndpoint", {}).get("name")),
383+
("ARN:", data.get("status", {}).get("endpoints", {}).get("sagemaker", {}).get("endpointArn")),
384+
]
385+
click.echo(tabulate(ep_rows, tablefmt="plain"))
386+
387+
click.echo("\nConditions:")
388+
conds = data.get("status", {}).get("conditions", [])
389+
if conds:
390+
headers = ["TYPE", "STATUS", "LAST TRANSITION", "LAST UPDATE", "MESSAGE"]
391+
rows = [
392+
[
393+
c.get("type", ""),
394+
c.get("status", ""),
395+
c.get("lastTransitionTime", ""),
396+
c.get("lastUpdateTime", ""),
397+
c.get("message") or ""
398+
]
399+
for c in conds
400+
]
401+
click.echo(tabulate(rows, headers=headers, tablefmt="github"))
402+
else:
403+
click.echo(" <none>")
404+
405+
click.echo("\nDeploymentStatus Conditions:")
406+
dep_status = data.get("status", {}).get("deploymentStatus", {})
407+
dep_conds = dep_status.get("status", {}).get("conditions", [])
408+
if dep_conds:
409+
headers = ["TYPE", "STATUS", "LAST TRANSITION", "LAST UPDATE", "MESSAGE"]
410+
rows = [
411+
[
412+
c.get("type", ""),
413+
c.get("status", ""),
414+
c.get("lastTransitionTime", ""),
415+
c.get("lastUpdateTime", ""),
416+
c.get("message") or ""
417+
]
418+
for c in dep_conds
419+
]
420+
click.echo(tabulate(rows, headers=headers, tablefmt="github"))
421+
else:
422+
click.echo(" <none>")
181423

182424

183425
@click.command("hyp-jumpstart-endpoint")
@@ -301,12 +543,12 @@ def custom_get_logs(
301543
@click.command("hyp-jumpstart-endpoint")
302544
@click.option(
303545
"--since-hours",
304-
type=click.INT,
546+
type=click.FLOAT,
305547
required=True,
306548
help="Required. The time frame to get logs for.",
307549
)
308550
def js_get_operator_logs(
309-
since_hours: int,
551+
since_hours: float,
310552
):
311553
"""
312554
Get specific pod log for jumpstart model endpoint.
@@ -319,12 +561,12 @@ def js_get_operator_logs(
319561
@click.command("hyp-custom-endpoint")
320562
@click.option(
321563
"--since-hours",
322-
type=click.INT,
564+
type=click.FLOAT,
323565
required=True,
324566
help="Required. The time frame get logs for.",
325567
)
326568
def custom_get_operator_logs(
327-
since_hours: int,
569+
since_hours: float,
328570
):
329571
"""
330572
Get specific pod log for custom model endpoint.

0 commit comments

Comments
 (0)