Skip to content

Commit a90505c

Browse files
committed
Updated to handle YAML arrays in config file (#190)
1 parent 4f0e26a commit a90505c

File tree

4 files changed

+206
-23
lines changed

4 files changed

+206
-23
lines changed

src/sagemaker/hyperpod/cli/init_utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def _parse_json_flag(ctx, param, value):
226226
try:
227227
# Remove brackets and split by comma
228228
inner = value.strip()[1:-1]
229-
items = [item.strip() for item in inner.split(',')]
229+
items = [item.strip().strip('"').strip("'") for item in inner.split(',')]
230230
return items
231231
except:
232232
pass
@@ -458,6 +458,15 @@ def save_config_yaml(prefill: dict, comment_map: dict, directory: str):
458458
if vol.get('read_only') is not None:
459459
f.write(f" read_only: {vol.get('read_only')}\n")
460460
f.write("\n")
461+
elif isinstance(val, list):
462+
# Handle arrays in YAML format
463+
if val:
464+
f.write(f"{key}:\n")
465+
for item in val:
466+
f.write(f" - {item}\n")
467+
else:
468+
f.write(f"{key}: []\n")
469+
f.write("\n")
461470
else:
462471
# Handle simple values
463472
val = "" if val is None else val
@@ -780,7 +789,7 @@ def build_config_from_schema(template: str, version: str, model_config=None, exi
780789
if val_stripped.startswith('[') and val_stripped.endswith(']'):
781790
try:
782791
inner = val_stripped[1:-1]
783-
val = [item.strip() for item in inner.split(',')]
792+
val = [item.strip().strip('"').strip("'") for item in inner.split(',')]
784793
except:
785794
pass
786795

src/sagemaker/hyperpod/cluster_management/hp_cluster_stack.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212

1313
from sagemaker.hyperpod import create_boto3_client
1414

15-
CLUSTER_CREATION_TEMPLATE_FILE_NAME = "v1_0/main-stack-eks-based-cfn-template.yaml"
16-
CLUSTER_STACK_TEMPLATE_PACKAGE_NAME = "hyperpod_cluster_stack_template"
17-
1815
CAPABILITIES_FOR_STACK_CREATION = [
1916
'CAPABILITY_IAM',
2017
'CAPABILITY_NAMED_IAM'
@@ -31,17 +28,25 @@ class HpClusterStack(ClusterStackBase):
3128
None,
3229
description="CloudFormation stack name set after stack creation"
3330
)
31+
32+
def __init__(self, **data):
33+
# Convert array values to JSON strings
34+
for key, value in data.items():
35+
if isinstance(value, list):
36+
data[key] = json.dumps(value)
37+
super().__init__(**data)
3438

3539
@staticmethod
3640
def get_template() -> str:
37-
s3 = create_boto3_client('s3')
38-
response = s3.get_object(
39-
Bucket='sagemaker-hyperpod-cluster-stack-bucket',
40-
Key='1.0/main-stack-eks-based-cfn-template.yaml'
41-
)
42-
yaml_content = response['Body'].read().decode('utf-8')
43-
yaml_data = yaml.safe_load(yaml_content)
44-
return json.dumps(yaml_data, indent=2, ensure_ascii=False)
41+
try:
42+
template_content = importlib.resources.read_text(
43+
'hyperpod_cluster_stack_template',
44+
'creation_template.yaml'
45+
)
46+
yaml_data = yaml.safe_load(template_content)
47+
return json.dumps(yaml_data, indent=2, ensure_ascii=False)
48+
except Exception as e:
49+
raise RuntimeError(f"Failed to load template from package: {e}")
4550

4651
def create(self,
4752
region: Optional[str] = None) -> str:
@@ -56,7 +61,7 @@ def create(self,
5661
# Get account ID and create bucket name
5762
bucket_name = f"sagemaker-hyperpod-cluster-stack-bucket"
5863
template_key = f"1.0/main-stack-eks-based-cfn-template.yaml"
59-
64+
6065
try:
6166
# Use TemplateURL for large templates (>51KB)
6267
template_url = f"https://{bucket_name}.s3.amazonaws.com/{template_key}"
@@ -65,10 +70,7 @@ def create(self,
6570
StackName=stack_name,
6671
TemplateURL=template_url,
6772
Parameters=parameters,
68-
Tags=self.tags or [{
69-
'Key': 'Environment',
70-
'Value': 'Development'
71-
}],
73+
Tags=self._parse_tags(),
7274
Capabilities=CAPABILITIES_FOR_STACK_CREATION
7375
)
7476

@@ -137,6 +139,13 @@ def _create_parameters(self) -> List[Dict[str, str]]:
137139
})
138140
return parameters
139141

142+
def _parse_tags(self) -> List[Dict[str, str]]:
143+
"""Parse tags field and return proper CloudFormation tags format."""
144+
try:
145+
return json.loads(self.tags) if self.tags else []
146+
except (json.JSONDecodeError, TypeError):
147+
return []
148+
140149
def _convert_nested_keys(self, obj: Any) -> Any:
141150
"""Convert nested JSON keys from snake_case to PascalCase."""
142151
if isinstance(obj, dict):

test/unit_tests/cli/test_init_utils.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -928,4 +928,45 @@ def test_process_cfn_template_content_preserves_template_structure(self, mock_cl
928928
assert "CloudFormation Template" in result
929929
assert original_content in result
930930
assert "OtherParam" in result
931-
assert mock_template in result
931+
assert mock_template in result
932+
933+
934+
class TestProcessCfnTemplateContentUpdated:
935+
"""Test updated _process_cfn_template_content function"""
936+
937+
@patch('sagemaker.hyperpod.cli.init_utils.HpClusterStack.get_template')
938+
def test_process_cfn_template_content_uses_package_template(self, mock_get_template):
939+
"""Test that _process_cfn_template_content uses HpClusterStack.get_template from package"""
940+
original_content = "Original template content"
941+
mock_template = '{"Parameters": {"TestParam": {"Type": "String"}}}'
942+
mock_get_template.return_value = mock_template
943+
944+
result = _process_cfn_template_content(original_content)
945+
946+
# Verify get_template was called (not S3)
947+
mock_get_template.assert_called_once()
948+
949+
# Verify content structure
950+
assert "CloudFormation Template:" in result
951+
assert mock_template in result
952+
assert original_content in result
953+
954+
@patch('sagemaker.hyperpod.cli.init_utils.HpClusterStack.get_template')
955+
@patch('sagemaker.hyperpod.cli.init_utils.click.secho')
956+
def test_process_cfn_template_content_handles_package_error(self, mock_secho, mock_get_template):
957+
"""Test error handling when package template fails to load"""
958+
original_content = "Original template content"
959+
mock_get_template.side_effect = RuntimeError("Failed to load template from package")
960+
961+
result = _process_cfn_template_content(original_content)
962+
963+
# Verify error was logged
964+
mock_secho.assert_called_once_with(
965+
"⚠️ Failed to generate CloudFormation template: Failed to load template from package",
966+
fg="red"
967+
)
968+
969+
# Verify fallback behavior
970+
assert "CloudFormation Template:" in result
971+
assert original_content in result
972+
assert result.count("CloudFormation Template:") == 1

test/unit_tests/cluster_management/test_hp_cluster_stack.py

Lines changed: 128 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from botocore.exceptions import ClientError
55
import boto3
66

7-
from sagemaker.hyperpod.cluster_management.hp_cluster_stack import HpClusterStack, CLUSTER_STACK_TEMPLATE_PACKAGE_NAME, CLUSTER_CREATION_TEMPLATE_FILE_NAME
8-
7+
from sagemaker.hyperpod.cluster_management.hp_cluster_stack import HpClusterStack
98

109
class TestHpClusterStack(unittest.TestCase):
1110
@patch('uuid.uuid4')
@@ -392,9 +391,134 @@ def test_create_parameters_preserves_other_fields(self):
392391

393392
# Should have the other fields
394393
param_keys = [p['ParameterKey'] for p in other_params]
395-
self.assertIn('HyperpodClusterName', param_keys)
394+
self.assertIn('HyperPodClusterName', param_keys)
396395
self.assertIn('CreateVPCStack', param_keys)
397396

398397
# Verify boolean conversion
399398
vpc_param = next(p for p in other_params if p['ParameterKey'] == 'CreateVPCStack')
400-
self.assertEqual(vpc_param['ParameterValue'], 'true')
399+
self.assertEqual(vpc_param['ParameterValue'], 'true')
400+
401+
class TestHpClusterStackInit(unittest.TestCase):
402+
"""Test HpClusterStack __init__ method array conversion"""
403+
404+
def test_init_converts_arrays_to_json_strings(self):
405+
"""Test that __init__ converts array values to JSON strings"""
406+
data = {
407+
'tags': [{'Key': 'Environment', 'Value': 'Test'}],
408+
'availability_zone_ids': ['us-east-1a', 'us-east-1b'],
409+
'hyperpod_cluster_name': 'test-cluster',
410+
'storage_capacity': 1200
411+
}
412+
413+
stack = HpClusterStack(**data)
414+
415+
# Arrays should be converted to JSON strings
416+
self.assertEqual(stack.tags, '[{"Key": "Environment", "Value": "Test"}]')
417+
self.assertEqual(stack.availability_zone_ids, '["us-east-1a", "us-east-1b"]')
418+
419+
# Other types should remain unchanged
420+
self.assertEqual(stack.hyperpod_cluster_name, 'test-cluster')
421+
self.assertEqual(stack.storage_capacity, 1200)
422+
423+
def test_init_handles_empty_arrays(self):
424+
"""Test that empty arrays are converted to empty JSON arrays"""
425+
data = {'tags': []}
426+
427+
stack = HpClusterStack(**data)
428+
429+
self.assertEqual(stack.tags, '[]')
430+
431+
def test_init_handles_no_arrays(self):
432+
"""Test that __init__ works normally when no arrays are present"""
433+
data = {
434+
'hyperpod_cluster_name': 'test-cluster',
435+
'stage': 'gamma'
436+
}
437+
438+
stack = HpClusterStack(**data)
439+
440+
self.assertEqual(stack.hyperpod_cluster_name, 'test-cluster')
441+
self.assertEqual(stack.stage, 'gamma')
442+
443+
444+
class TestHpClusterStackParseTags(unittest.TestCase):
445+
"""Test HpClusterStack _parse_tags method"""
446+
447+
def test_parse_tags_valid_json_array(self):
448+
"""Test parsing valid JSON array of tags"""
449+
tags_json = '[{"Key": "Environment", "Value": "Test"}, {"Key": "Project", "Value": "HyperPod"}]'
450+
stack = HpClusterStack(tags=tags_json)
451+
452+
result = stack._parse_tags()
453+
454+
expected = [
455+
{"Key": "Environment", "Value": "Test"},
456+
{"Key": "Project", "Value": "HyperPod"}
457+
]
458+
self.assertEqual(result, expected)
459+
460+
def test_parse_tags_empty_string(self):
461+
"""Test parsing empty tags string returns empty list"""
462+
stack = HpClusterStack(tags="")
463+
464+
result = stack._parse_tags()
465+
466+
self.assertEqual(result, [])
467+
468+
def test_parse_tags_none_value(self):
469+
"""Test parsing None tags returns empty list"""
470+
stack = HpClusterStack(tags=None)
471+
472+
result = stack._parse_tags()
473+
474+
self.assertEqual(result, [])
475+
476+
def test_parse_tags_invalid_json(self):
477+
"""Test parsing invalid JSON returns empty list"""
478+
stack = HpClusterStack(tags="invalid json")
479+
480+
result = stack._parse_tags()
481+
482+
self.assertEqual(result, [])
483+
484+
def test_parse_tags_empty_json_array(self):
485+
"""Test parsing empty JSON array returns empty list"""
486+
stack = HpClusterStack(tags="[]")
487+
488+
result = stack._parse_tags()
489+
490+
self.assertEqual(result, [])
491+
492+
493+
class TestHpClusterStackGetTemplate(unittest.TestCase):
494+
"""Test HpClusterStack get_template method using package instead of S3"""
495+
496+
@patch('sagemaker.hyperpod.cluster_management.hp_cluster_stack.importlib.resources.read_text')
497+
@patch('sagemaker.hyperpod.cluster_management.hp_cluster_stack.yaml.safe_load')
498+
def test_get_template_from_package(self, mock_yaml_load, mock_read_text):
499+
"""Test get_template reads from package instead of S3"""
500+
mock_yaml_content = "Parameters:\n TestParam:\n Type: String"
501+
mock_read_text.return_value = mock_yaml_content
502+
503+
mock_yaml_data = {"Parameters": {"TestParam": {"Type": "String"}}}
504+
mock_yaml_load.return_value = mock_yaml_data
505+
506+
result = HpClusterStack.get_template()
507+
508+
# Verify package resource was read
509+
mock_read_text.assert_called_once_with('hyperpod_cluster_stack_template', 'creation_template.yaml')
510+
mock_yaml_load.assert_called_once_with(mock_yaml_content)
511+
512+
# Verify JSON output
513+
expected_json = json.dumps(mock_yaml_data, indent=2, ensure_ascii=False)
514+
self.assertEqual(result, expected_json)
515+
516+
@patch('sagemaker.hyperpod.cluster_management.hp_cluster_stack.importlib.resources.read_text')
517+
def test_get_template_handles_package_error(self, mock_read_text):
518+
"""Test get_template handles package read errors"""
519+
mock_read_text.side_effect = FileNotFoundError("Template not found")
520+
521+
with self.assertRaises(RuntimeError) as context:
522+
HpClusterStack.get_template()
523+
524+
self.assertIn("Failed to load template from package", str(context.exception))

0 commit comments

Comments
 (0)