
Commit 395b88d

Enable Telemetry for Cluster creation (#230)
* Enable Telemetry for Cluster creation
* Telemetry for CLI and updates
* Fix
1 parent a1c1094 commit 395b88d

File tree

2 files changed: 34 additions & 24 deletions
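
Both files import the same _hyperpod_telemetry_emitter decorator and Feature enum from sagemaker.hyperpod.common.telemetry and apply them to the cluster-stack entry points. Judging by its name and usage, the decorator takes a Feature member plus an event-name string and wraps the target callable so a usage event is recorded when it runs. A minimal sketch of that pattern follows; the Feature string values and the logging body are illustrative assumptions, since the diff does not show the emitter's actual implementation.

import functools
import logging
from enum import Enum


class Feature(Enum):
    # Member names taken from the diff; the string values are assumptions.
    HYPERPOD = "hyperpod"
    HYPERPOD_CLI = "hyperpod_cli"


def _hyperpod_telemetry_emitter(feature: Feature, event_name: str):
    """Hypothetical stand-in: wrap a callable and record an invocation event."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # A real emitter would publish a metric; this sketch only logs it.
            logging.getLogger(__name__).debug(
                "telemetry: feature=%s event=%s", feature.value, event_name
            )
            return func(*args, **kwargs)
        return wrapper
    return decorator


@_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_cluster_stack")
def create_stack(region=None):
    return f"creating stack in {region or 'default region'}"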

src/sagemaker/hyperpod/cli/commands/cluster_stack.py

Lines changed: 5 additions & 0 deletions
@@ -14,6 +14,8 @@
 
 from tabulate import tabulate
 from sagemaker.hyperpod.cluster_management.hp_cluster_stack import HpClusterStack
+from sagemaker.hyperpod.common.telemetry import _hyperpod_telemetry_emitter
+from sagemaker.hyperpod.common.telemetry.constants import Feature
 from sagemaker.hyperpod.common.utils import setup_logging
 from sagemaker.hyperpod.cli.utils import convert_datetimes
 
@@ -135,6 +137,7 @@ def create_cluster_stack_helper(config_file: str, region: Optional[str] = None,
 @click.argument("stack-name", required=True)
 @click.option("--region", help="AWS region")
 @click.option("--debug", is_flag=True, help="Enable debug logging")
+@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "describe_cluster_stack_cli")
 def describe_cluster_stack(stack_name: str, debug: bool, region: str) -> None:
     """Describe the status of a HyperPod cluster stack.
 
@@ -212,6 +215,7 @@ def describe_cluster_stack(stack_name: str, debug: bool, region: str) -> None:
 @click.option("--status",
               callback=parse_status_list,
               help="Filter by stack status. Format: \"['CREATE_COMPLETE', 'UPDATE_COMPLETE']\"")
+@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "list_cluster_stack_cli")
 def list_cluster_stacks(region, debug, status):
     """List all HyperPod cluster stacks.
 
@@ -305,6 +309,7 @@ def delete(stack_name: str, debug: bool) -> None:
 @click.option("--region", help="Region")
 @click.option("--node-recovery", help="Node Recovery (Automatic or None)")
 @click.option("--debug", is_flag=True, help="Enable debug logging")
+@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "update_cluster_cli")
 def update_cluster(
     cluster_name: str,
     instance_groups: Optional[str],
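
In the CLI module the new decorator sits below the @click.option/@click.argument lines, directly above each def. Decorators apply bottom-up, so the emitter wraps the plain callback first and Click then builds its command from the already-instrumented function. A small illustration of that ordering, using a hypothetical emit decorator as a stand-in for _hyperpod_telemetry_emitter:

import functools

import click


def emit(feature, event):
    # Hypothetical stand-in for _hyperpod_telemetry_emitter, used only to show ordering.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            click.echo(f"[telemetry] {feature}:{event}")  # stand-in for a real metric
            return func(*args, **kwargs)
        return wrapper
    return decorator


@click.command()
@click.argument("stack-name", required=True)
@click.option("--region", help="AWS region")
@emit("HYPERPOD_CLI", "describe_cluster_stack_cli")  # innermost, mirroring the diff
def describe_cluster_stack(stack_name, region):
    click.echo(f"describing {stack_name} in {region or 'the default region'}")


if __name__ == "__main__":
    describe_cluster_stack()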

src/sagemaker/hyperpod/cluster_management/hp_cluster_stack.py

Lines changed: 29 additions & 24 deletions
@@ -11,6 +11,8 @@
 from hyperpod_cluster_stack_template.v1_0.model import ClusterStackBase
 
 from sagemaker.hyperpod import create_boto3_client
+from sagemaker.hyperpod.common.telemetry import _hyperpod_telemetry_emitter
+from sagemaker.hyperpod.common.telemetry.constants import Feature
 
 CAPABILITIES_FOR_STACK_CREATION = [
     'CAPABILITY_IAM',
@@ -33,7 +35,7 @@ class HpClusterStack(ClusterStackBase):
     >>> # Create a cluster stack instance
     >>> stack = HpClusterStack()
     >>> response = stack.create(region="us-west-2")
-    >>> 
+    >>>
     >>> # Check stack status
     >>> status = stack.get_status()
     >>> print(status)
@@ -46,17 +48,17 @@ class HpClusterStack(ClusterStackBase):
         None,
         description="CloudFormation stack name set after stack creation"
     )
-    
+
     def __init__(self, **data):
         super().__init__(**data)
-    
+
     @field_validator('kubernetes_version', mode='before')
     @classmethod
     def validate_kubernetes_version(cls, v):
         if v is not None:
             return str(v)
         return v
-    
+
     @field_validator('availability_zone_ids', 'nat_gateway_ids', 'eks_private_subnet_ids', 'security_group_ids', 'private_route_table_ids', 'private_subnet_ids', 'instance_group_settings', 'rig_settings', 'tags', mode='before')
     @classmethod
     def validate_list_fields(cls, v):
@@ -71,7 +73,7 @@ def validate_list_fields(cls, v):
                 v = ast.literal_eval(v)
             except:
                 pass # Keep original value if parsing fails
-        
+
         if isinstance(v, list) and len(v) == 0:
             raise ValueError('Empty lists [] are not allowed. Use proper YAML array format or leave field empty.')
         return v
@@ -80,14 +82,15 @@ def validate_list_fields(cls, v):
     def get_template() -> str:
         try:
             template_content = importlib.resources.read_text(
-                'hyperpod_cluster_stack_template', 
+                'hyperpod_cluster_stack_template',
                 'creation_template.yaml'
             )
             yaml_data = yaml.safe_load(template_content)
             return json.dumps(yaml_data, indent=2, ensure_ascii=False)
         except Exception as e:
             raise RuntimeError(f"Failed to load template from package: {e}")
 
+    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_cluster_stack")
     def create(self,
                region: Optional[str] = None) -> str:
         """Creates a new HyperPod cluster CloudFormation stack.
@@ -121,7 +124,7 @@ def create(self,
         >>> # Create stack in default region
         >>> stack = HpClusterStack()
         >>> response = stack.create()
-        >>> 
+        >>>
         >>> # Create stack in specific region
         >>> response = stack.create(region="us-east-1")
         """
@@ -178,12 +181,12 @@ def _create_parameters(self) -> List[Dict[str, str]]:
                         settings_list = json.loads(str(value))
                     except (json.JSONDecodeError, TypeError):
                         settings_list = []
-                
+
                 for i, setting in enumerate(settings_list, 1):
                     formatted_setting = self._convert_nested_keys(setting)
                     parameters.append({
                         'ParameterKey': f'InstanceGroupSettings{i}',
-                        'ParameterValue': "[" + json.dumps(formatted_setting) + "]" if isinstance(formatted_setting, (dict, list)) else str(formatted_setting)
+                        'ParameterValue': "[" + json.dumps(formatted_setting) + "]" if isinstance(formatted_setting, (dict, list)) else str(formatted_setting)
                     })
             elif field_name == 'rig_settings':
                 # Handle both list and JSON string formats
@@ -195,7 +198,7 @@ def _create_parameters(self) -> List[Dict[str, str]]:
                         settings_list = json.loads(str(value))
                     except (json.JSONDecodeError, TypeError):
                         settings_list = []
-                
+
                 for i, setting in enumerate(settings_list, 1):
                     formatted_setting = self._convert_nested_keys(setting)
                     parameters.append({
@@ -204,7 +207,7 @@ def _create_parameters(self) -> List[Dict[str, str]]:
                     })
             else:
                 # Convert array fields to comma-separated strings
-                if field_name in ['availability_zone_ids', 'nat_gateway_ids', 'eks_private_subnet_ids',
+                if field_name in ['availability_zone_ids', 'nat_gateway_ids', 'eks_private_subnet_ids',
                                   'security_group_ids', 'private_route_table_ids', 'private_subnet_ids']:
                     if isinstance(value, list):
                         value = ','.join(str(item) for item in value)
@@ -236,22 +239,22 @@ def _parse_tags(self) -> List[Dict[str, str]]:
         """Parse tags field and return proper CloudFormation tags format."""
         if not self.tags:
             return []
-        
+
         tags_list = self.tags
         if isinstance(self.tags, str):
             try:
                 tags_list = json.loads(self.tags)
             except (json.JSONDecodeError, TypeError):
                 return []
-        
+
         # Convert array of strings to Key-Value format
         if isinstance(tags_list, list) and tags_list:
             # Check if already in Key-Value format
             if isinstance(tags_list[0], dict) and 'Key' in tags_list[0]:
                 return tags_list
             # Convert string array to Key-Value format
             return [{'Key': tag, 'Value': ''} for tag in tags_list if isinstance(tag, str)]
-        
+
         return []
 
     def _convert_nested_keys(self, obj: Any) -> Any:
@@ -267,7 +270,7 @@ def _snake_to_pascal(snake_str: str) -> str:
         """Convert snake_case string to PascalCase."""
         if not snake_str:
             return snake_str
-        
+
         # Handle specific cases
         mappings = {
             "eks_cluster_name": "EKSClusterName",
@@ -289,14 +292,14 @@ def _snake_to_pascal(snake_str: str) -> str:
             "EbsVolumeConfig": "EbsVolumeConfig",
             "VolumeSizeInGB": "VolumeSizeInGB"
         }
-        
+
         if snake_str in mappings:
             return mappings[snake_str]
 
 
         # Default case: capitalize each word
         return ''.join(word.capitalize() for word in snake_str.split('_'))
-    
+
     def _snake_to_camel(self, snake_str: str) -> str:
         """Convert snake_case string to camelCase for nested JSON keys."""
         if not snake_str:
@@ -305,6 +308,7 @@ def _snake_to_camel(self, snake_str: str) -> str:
         return words[0] + ''.join(word.capitalize() for word in words[1:])
 
     @staticmethod
+    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "describe_cluster_stack")
     def describe(stack_name, region: Optional[str] = None):
         """Describes a CloudFormation stack by name.
 
@@ -343,7 +347,7 @@ def describe(stack_name, region: Optional[str] = None):
 
         >>> # Describe a stack by name
         >>> response = HpClusterStack.describe("my-stack-name")
-        >>> 
+        >>>
         >>> # Describe stack in specific region
        >>> response = HpClusterStack.describe("my-stack", region="us-west-2")
         """
@@ -368,6 +372,7 @@ def describe(stack_name, region: Optional[str] = None):
         raise RuntimeError("Stack operation failed")
 
     @staticmethod
+    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_cluster_stack")
     def list(region: Optional[str] = None, stack_status_filter: Optional[List[str]] = None):
         """Lists all CloudFormation stacks in the specified region.
 
@@ -403,7 +408,7 @@ def list(region: Optional[str] = None, stack_status_filter: Optional[List[str]]
 
         >>> # List stacks in current region
         >>> stacks = HpClusterStack.list()
-        >>> 
+        >>>
         >>> # List stacks in specific region
         >>> stacks = HpClusterStack.list(region="us-east-1")
         """
@@ -412,19 +417,19 @@ def list(region: Optional[str] = None, stack_status_filter: Optional[List[str]]
         try:
             # Prepare API call parameters
             list_params = {}
-            
+
             if stack_status_filter is not None:
                 list_params['StackStatusFilter'] = stack_status_filter
-            
+
             response = cf.list_stacks(**list_params)
-            
+
             # Only filter DELETE_COMPLETE when no explicit filter is provided
             if stack_status_filter is None and 'StackSummaries' in response:
                 response['StackSummaries'] = [
-                    stack for stack in response['StackSummaries'] 
+                    stack for stack in response['StackSummaries']
                     if stack.get('StackStatus') != 'DELETE_COMPLETE'
                 ]
-            
+
             return response
         except cf.exceptions.ClientError as e:
             error_code = e.response['Error']['Code']
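
In the SDK module the static methods describe and list get the same treatment, with the emitter placed below @staticmethod. The ordering matters: decorators apply bottom-up, so the emitter wraps the plain function and @staticmethod then wraps the instrumented result; reversing the two would hand the emitter a staticmethod object, which is not callable before Python 3.10. A short sketch with the same hypothetical emit stand-in:

import functools


def emit(feature, event):
    # Hypothetical stand-in for _hyperpod_telemetry_emitter.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            print(f"[telemetry] {feature}:{event}")  # stand-in for a real metric
            return func(*args, **kwargs)
        return wrapper
    return decorator


class ClusterStackSketch:
    @staticmethod
    @emit("HYPERPOD", "describe_cluster_stack")  # emitter first, then staticmethod, as in the diff
    def describe(stack_name, region=None):
        return {"StackName": stack_name, "Region": region}


print(ClusterStackSketch.describe("my-stack", region="us-west-2"))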
