
Commit 0417c3f

Implement Task Governance feature for the SDK flow
1 parent f571859 commit 0417c3f

9 files changed: +604 -6 lines changed

hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/model.py

Lines changed: 3 additions & 3 deletions
@@ -20,7 +20,7 @@
     'topology.k8s.aws/network-node-layer-2',
     'topology.k8s.aws/network-node-layer-3'
 }
-from .quota_allocation_util import _is_valid, _get_resources_from_compute_quotas, _get_resources_from_instance, _get_limits
+from hyperpod_pytorch_job_template.quota_allocation_util import _is_valid, _get_resources_from_compute_quotas, _get_resources_from_instance, _get_limits
 
 class VolumeConfig(BaseModel):
     model_config = ConfigDict(extra="forbid")
@@ -111,7 +111,7 @@ class PyTorchJobConfig(BaseModel):
         min_length=1
     )
     node_count: Optional[int] = Field(
-        default=1,
+        default=None,
         alias="node_count",
         description="Number of nodes",
         ge=1
@@ -290,7 +290,7 @@ def to_domain(self) -> Dict:
         valid, error = _is_valid(
             self.vcpu, self.memory, self.accelerators, self.node_count, self.instance_type
         )
-
+
         if not valid:
             raise ValueError(error)
 

hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/schema.json

Lines changed: 1 addition & 1 deletion
@@ -202,7 +202,7 @@
           "type": "null"
         }
       ],
-      "default": 1,
+      "default": null,
       "description": "Number of nodes",
       "title": "Node Count"
     },
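Both the field default in model.py and the mirrored schema default drop the implicit single node: node_count now stays unset unless the caller provides it, so the quota-allocation helpers can decide between quota-derived resources and per-instance defaults. A rough sketch of that decision, matching the positional call shapes used elsewhere in this commit; the instance type and resource figures are illustrative only:

from hyperpod_pytorch_job_template.quota_allocation_util import (
    _get_resources_from_compute_quotas,
    _get_resources_from_instance,
)

# If compute quotas (vcpu, memory, accelerators) were supplied, derive requests from them;
# otherwise fall back to the defaults for the instance type and (possibly unset) node count.
instance_type, node_count = "ml.g5.8xlarge", None  # example values, not from the commit
requests = (_get_resources_from_compute_quotas(instance_type, 4.0, 16.0, 1)
            or _get_resources_from_instance(instance_type, node_count))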

src/sagemaker/hyperpod/cli/constants/command_constants.py

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@
 SAGEMAKER_MANAGED_CLUSTER_QUEUE_SUFFIX = "-clusterqueue"
 SAGEMAKER_TRAINING_LAUNCHER_DIR = str(Path(__file__).parent.parent / "sagemaker_hyperpod_recipes")
 NVIDIA_GPU_RESOURCE_LIMIT_KEY = "nvidia.com/gpu"
+NEURON_RESOURCE_LIMIT_KEY = "aws.amazon.com/neurondevice"
 AVAILABLE_ACCELERATOR_DEVICES_KEY = "AvailableAcceleratorDevices"
 TOTAL_ACCELERATOR_DEVICES_KEY = "TotalAcceleratorDevices"
 USER_NAME_LABEL_KEY = "sagemaker.user/created-by"
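The new constant lets the quota path count Neuron devices (Trainium/Inferentia) the same way it counts NVIDIA GPUs: whichever resource key appears in the container requests supplies the accelerator count. A small, self-contained illustration of that lookup; the request values are made up:

NVIDIA_GPU_RESOURCE_LIMIT_KEY = "nvidia.com/gpu"
NEURON_RESOURCE_LIMIT_KEY = "aws.amazon.com/neurondevice"

requests = {"cpu": "32", "memory": "64Gi", NEURON_RESOURCE_LIMIT_KEY: "2"}
device_count = requests.get(NVIDIA_GPU_RESOURCE_LIMIT_KEY) or requests.get(NEURON_RESOURCE_LIMIT_KEY)
accelerators = int(device_count) if device_count else None  # -> 2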

src/sagemaker/hyperpod/training/config/hyperpod_pytorch_job_unified_config.py

Lines changed: 1 addition & 1 deletion
@@ -2979,7 +2979,7 @@ class ReplicaSpec(BaseModel):
 
     name: str = Field(description="The name for the replica set")
     replicas: Optional[int] = Field(
-        default=1,
+        default=0,
         description="Replicas is the desired number of replicas of the given template.",
     )
     spares: Optional[int] = Field(
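The new default of 0 acts as a sentinel for "replica count not set by the caller": as the hyperpod_pytorch_job.py changes below show, create() first lets quota allocation run and only afterwards promotes a still-zero replica count back to the old single-node default. Condensed from the diff below, for orientation:

spec = self.allocate_quotas_if_applicable(spec)  # may recompute per-replica resources
if spec.replicaSpecs[0].replicas == 0:           # caller never set a replica count
    spec.replicaSpecs[0].replicas = 1            # fall back to one node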

src/sagemaker/hyperpod/training/hyperpod_pytorch_job.py

Lines changed: 115 additions & 0 deletions
@@ -1,4 +1,7 @@
 from pydantic import ConfigDict, Field
+
+from sagemaker.hyperpod.cli.constants.command_constants import INSTANCE_TYPE_LABEL, NVIDIA_GPU_RESOURCE_LIMIT_KEY, \
+    NEURON_RESOURCE_LIMIT_KEY
 from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import (
     _HyperPodPytorchJob, HyperPodPytorchJobStatus
 )
@@ -18,6 +21,9 @@
 import yaml
 import logging
 
+from hyperpod_pytorch_job_template.quota_allocation_util import _is_valid, _get_resources_from_compute_quotas, _get_resources_from_instance, _get_limits
+
+
 
 TRAINING_GROUP = "sagemaker.amazonaws.com"
 API_VERSION = "v1"
@@ -52,6 +58,109 @@ def verify_kube_config(cls):
 
         # Verify Kubernetes version compatibility
         verify_kubernetes_version_compatibility(cls.get_logger())
+    @classmethod
+    def _extract_numeric_value(cls, value):
+        """Extract the numeric part of strings like '1.5Gi' -> 1.5"""
+        if not value:
+            return None
+        import re
+        match = re.match(r'^([0-9]*\.?[0-9]+)', str(value))
+        return float(match.group(1)) if match else None
+
+    @classmethod
+    def sanitize_memory(cls, resource):
+        try:
+            if 'memory' in resource:
+                memory = resource['memory']
+                # Case when quotas have already been initialized in the CLI layer
+                # ToDo: Clean up quota initialization in the CLI layer and use the SDK layer directly for init.
+                memory = memory.replace('GiGi', 'Gi')
+                resource['memory'] = memory
+            return resource
+        except Exception:
+            return resource
+
+
+    @classmethod
+    def _process_replica_resources(cls, data):
+        """Process and validate replica resource configuration."""
+        try:
+            node_count = data.get('replicas', None)
+
+            # Extract nested configuration with validation
+            template = data.get('template', {})
+            spec = template.get('spec', {})
+            node_selector = spec.get('nodeSelector', {})
+            instance_type = node_selector.get(INSTANCE_TYPE_LABEL) if node_selector else None
+
+            if not instance_type:
+                return None
+
+            containers = spec.get('containers', [])
+
+            if not containers:
+                raise ValueError("No containers found in template spec")
+
+            container = containers[0]
+            resources = container.get('resources', {})
+            requests = resources.get('requests', {})
+            limits = resources.get('limits', {})
+
+            # Extract resource values (accelerator keys may be absent for CPU-only containers)
+            vcpu = float(requests.get('cpu')) if requests.get('cpu') else None
+            memory = cls._extract_numeric_value(requests.get('memory'))
+            accelerators = int(requests.get(NVIDIA_GPU_RESOURCE_LIMIT_KEY) or requests.get(NEURON_RESOURCE_LIMIT_KEY) or 0) or None
+            memory_limit = cls._extract_numeric_value(limits.get('memory'))
+            vcpu_limit = float(limits.get('cpu')) if limits.get('cpu') else None
+            accelerators_limit = int(limits.get(NVIDIA_GPU_RESOURCE_LIMIT_KEY) or limits.get(NEURON_RESOURCE_LIMIT_KEY) or 0) or None
+
+            # Validate configuration
+            valid, error = _is_valid(vcpu, memory, accelerators, node_count, instance_type)
+            if not valid:
+                raise ValueError(error)
+
+            # Calculate resource values
+            requests_value = (_get_resources_from_compute_quotas(instance_type, vcpu, memory, accelerators)
+                              or _get_resources_from_instance(instance_type, node_count))
+            limits_value = _get_limits(instance_type, vcpu_limit, memory_limit, accelerators_limit)
+
+            requests_value = cls.sanitize_memory(requests_value)
+            limits_value = cls.sanitize_memory(limits_value)
+
+            # Update data with calculated values
+            data['template']['spec']['containers'][0]['resources']['requests'] = requests_value
+            data['template']['spec']['containers'][0]['resources']['limits'] = limits_value
+            return data
+        except KeyError as e:
+            raise ValueError(f"Missing required configuration key: {str(e)}")
+
+    @classmethod
+    def _get_container_resources(cls, replica_spec):
+        """Extract container resources from replica spec."""
+        container_resources = replica_spec['template']['spec']['containers'][0]['resources']
+        return container_resources['requests'], container_resources['limits']
+
+    @classmethod
+    def allocate_quotas_if_applicable(cls, spec):
+        """Recompute container requests/limits from quotas or instance defaults, if an instance type is set."""
+        logger = cls.get_logger()
+        logger = setup_logging(logger)
+        try:
+            spec_dict = spec.model_dump()
+            replica_spec = spec_dict['replicaSpecs'][0]
+            cls._process_replica_resources(replica_spec)
+
+            # Update the original spec object directly
+            requests, limits = cls._get_container_resources(replica_spec)
+            spec.replicaSpecs[0].template.spec.containers[0].resources.requests = requests
+            spec.replicaSpecs[0].template.spec.containers[0].resources.limits = limits
+
+            return spec
+        except ValueError as e:
+            logger.error(f"Error in quota allocation: {e}")
+            raise ValueError(e)
+        except Exception as e:
+            logger.info(f"Warning in quota allocation: {e}. Using defaults.")
+            return spec
 
     @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_pytorchjob")
     def create(self, debug=False):
@@ -65,6 +174,10 @@ def create(self, debug=False):
         if not self.metadata.namespace:
             self.metadata.namespace = get_default_namespace()
 
+        spec = self.allocate_quotas_if_applicable(spec)
+        if spec.replicaSpecs[0].replicas == 0:
+            spec.replicaSpecs[0].replicas = 1  # default value
+
         config = {
             "apiVersion": f"{TRAINING_GROUP}/{API_VERSION}",
             "kind": KIND,
@@ -91,6 +204,8 @@ def create(self, debug=False):
             logger.error(f"Failed to create HyperPodPytorchJob {self.metadata.name}!")
             handle_exception(e, self.metadata.name, self.metadata.namespace)
 
+
+
     @classmethod
     @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pytorchjobs")
     def list(cls, namespace=None) -> List["HyperPodPytorchJob"]:
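Taken together, the new helpers recompute a container's requests and limits whenever its replica spec pins an instance type. A minimal sketch of what _process_replica_resources receives and returns; it assumes the public class here is HyperPodPytorchJob and that the import paths match this repository layout, and the instance type and resource figures are illustrative rather than taken from the commit:

from sagemaker.hyperpod.cli.constants.command_constants import INSTANCE_TYPE_LABEL
from sagemaker.hyperpod.training.hyperpod_pytorch_job import HyperPodPytorchJob

# Shape of one entry of spec.model_dump()["replicaSpecs"], as consumed by
# allocate_quotas_if_applicable(). replicas stays None now that node_count has no default.
replica_spec = {
    "replicas": None,
    "template": {
        "spec": {
            "nodeSelector": {INSTANCE_TYPE_LABEL: "ml.g5.8xlarge"},  # example instance type
            "containers": [{
                "resources": {
                    "requests": {"cpu": "4", "memory": "16Gi", "nvidia.com/gpu": "1"},
                    "limits": {"cpu": "4", "memory": "16Gi", "nvidia.com/gpu": "1"},
                }
            }],
        }
    },
}

# With an instance type present, requests and limits are recomputed via the quota helpers
# (compute quotas when given, otherwise per-instance defaults) and written back into the dict.
# Without an instance type the method returns None and the spec is left as the user wrote it.
updated = HyperPodPytorchJob._process_replica_resources(replica_spec)

create() runs this same logic through allocate_quotas_if_applicable() before building the Kubernetes custom object, so the submitted job carries concrete resource requests even when the caller only named an instance type.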

0 commit comments