
Commit 31cf556

Implementing Task Gov. feature for SDK flow
1 parent f571859 commit 31cf556

File tree

8 files changed, +575 -3 lines changed

hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/model.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
     'topology.k8s.aws/network-node-layer-2',
     'topology.k8s.aws/network-node-layer-3'
 }
-from .quota_allocation_util import _is_valid, _get_resources_from_compute_quotas, _get_resources_from_instance, _get_limits
+from hyperpod_pytorch_job_template.quota_allocation_util import _is_valid, _get_resources_from_compute_quotas, _get_resources_from_instance, _get_limits
 
 class VolumeConfig(BaseModel):
     model_config = ConfigDict(extra="forbid")

src/sagemaker/hyperpod/cli/constants/command_constants.py

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@
 SAGEMAKER_MANAGED_CLUSTER_QUEUE_SUFFIX = "-clusterqueue"
 SAGEMAKER_TRAINING_LAUNCHER_DIR = str(Path(__file__).parent.parent / "sagemaker_hyperpod_recipes")
 NVIDIA_GPU_RESOURCE_LIMIT_KEY = "nvidia.com/gpu"
+NEURON_RESOURCE_LIMIT_KEY = "aws.amazon.com/neurondevice"
 AVAILABLE_ACCELERATOR_DEVICES_KEY = "AvailableAcceleratorDevices"
 TOTAL_ACCELERATOR_DEVICES_KEY = "TotalAcceleratorDevices"
 USER_NAME_LABEL_KEY = "sagemaker.user/created-by"
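
NEURON_RESOURCE_LIMIT_KEY mirrors NVIDIA_GPU_RESOURCE_LIMIT_KEY and names the Kubernetes extended resource for AWS Neuron accelerators. A minimal sketch of how either key could appear in the container resources block the SDK inspects (the dict shape follows the diff below; all values are illustrative):

# Illustrative container resources dict; the keys mirror the constants above.
resources = {
    "requests": {
        "vcpu": "32",
        "memory": "64Gi",
        "aws.amazon.com/neurondevice": "4",  # NEURON_RESOURCE_LIMIT_KEY
    },
    "limits": {},
}

# Either accelerator key is accepted when reading the requested device count.
accelerators = (resources["requests"].get("nvidia.com/gpu")               # NVIDIA_GPU_RESOURCE_LIMIT_KEY
                or resources["requests"].get("aws.amazon.com/neurondevice"))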

src/sagemaker/hyperpod/training/config/hyperpod_pytorch_job_unified_config.py

Lines changed: 1 addition & 1 deletion
@@ -2979,7 +2979,7 @@ class ReplicaSpec(BaseModel):
 
     name: str = Field(description="The name for the replica set")
     replicas: Optional[int] = Field(
-        default=1,
+        default=0,
         description="Replicas is the desired number of replicas of the given template.",
     )
     spares: Optional[int] = Field(
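
Changing the default from 1 to 0 presumably lets the SDK tell an explicitly requested replica count apart from an unset one: as the hyperpod_pytorch_job.py diff below shows, create() treats 0 as "not set" and restores the previous default of 1 after quota allocation has run.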

src/sagemaker/hyperpod/training/hyperpod_pytorch_job.py

Lines changed: 94 additions & 0 deletions
@@ -1,4 +1,7 @@
 from pydantic import ConfigDict, Field
+
+from sagemaker.hyperpod.cli.constants.command_constants import INSTANCE_TYPE_LABEL, NVIDIA_GPU_RESOURCE_LIMIT_KEY, \
+    NEURON_RESOURCE_LIMIT_KEY
 from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import (
     _HyperPodPytorchJob, HyperPodPytorchJobStatus
 )
@@ -18,6 +21,9 @@
 import yaml
 import logging
 
+from hyperpod_pytorch_job_template.quota_allocation_util import _is_valid, _get_resources_from_compute_quotas, _get_resources_from_instance, _get_limits
+
+
 
 TRAINING_GROUP = "sagemaker.amazonaws.com"
 API_VERSION = "v1"
@@ -52,6 +58,88 @@ def verify_kube_config(cls):
 
         # Verify Kubernetes version compatibility
         verify_kubernetes_version_compatibility(cls.get_logger())
+
+    @classmethod
+    def sanitize_memory(cls, resource):
+        if 'memory' in resource:
+            # Case when quotas have already been initialized in the CLI layer
+            # ToDo: Clean up quota initialization in the CLI layer and directly use the SDK layer for init.
+            resource['memory'] = resource['memory'].replace('GiGi', 'Gi')
+        return resource
+
+    @classmethod
+    def _process_replica_resources(cls, data):
+        """Process and validate replica resource configuration."""
+        try:
+            node_count = data['replicas']
+
+            # Extract nested configuration with validation
+            template = data.get('template', {})
+            spec = template.get('spec', {})
+            node_selector = spec.get('nodeSelector', {})
+            containers = spec.get('containers', [])
+
+            if not containers:
+                raise ValueError("No containers found in template spec")
+
+            instance_type = node_selector.get(INSTANCE_TYPE_LABEL, None)
+            if not instance_type:
+                raise ValueError("Instance type not found in node selector")
+
+            container = containers[0]
+            resources = container.get('resources', {})
+            requests = resources.get('requests', {})
+            limits = resources.get('limits', {})
+
+            # Extract resource values
+            vcpu = requests.get('vcpu', None)
+            memory = requests.get('memory', None)
+            accelerators = requests.get(NVIDIA_GPU_RESOURCE_LIMIT_KEY) or requests.get(NEURON_RESOURCE_LIMIT_KEY) or None
+            memory_limit = limits.get('memory', None)
+            vcpu_limit = limits.get('vcpu', None)
+            accelerators_limit = limits.get(NVIDIA_GPU_RESOURCE_LIMIT_KEY) or limits.get(NEURON_RESOURCE_LIMIT_KEY) or None
+
+            # Validate configuration
+            valid, error = _is_valid(vcpu, memory, accelerators, node_count, instance_type)
+            if not valid:
+                raise ValueError(error)
+
+            # Calculate resource values
+            requests_value = (_get_resources_from_compute_quotas(instance_type, vcpu, memory, accelerators)
+                              or _get_resources_from_instance(instance_type, node_count))
+            limits_value = _get_limits(instance_type, vcpu_limit, memory_limit, accelerators_limit)
+            requests_value = cls.sanitize_memory(requests_value)
+            limits_value = cls.sanitize_memory(limits_value)
+
+            # Update data with calculated values
+            data['template']['spec']['containers'][0]['resources']['requests'] = requests_value
+            data['template']['spec']['containers'][0]['resources']['limits'] = limits_value
+            return data
+        except KeyError as e:
+            raise ValueError(f"Missing required configuration key: {str(e)}")
+
+    @classmethod
+    def _get_container_resources(cls, replica_spec):
+        """Extract container resources from replica spec."""
+        container_resources = replica_spec['template']['spec']['containers'][0]['resources']
+        return container_resources['requests'], container_resources['limits']
+
+    @classmethod
+    def allocate_quotas_if_applicable(cls, spec):
+        try:
+            spec_dict = spec.model_dump()
+            replica_spec = spec_dict['replicaSpecs'][0]
+            cls._process_replica_resources(replica_spec)
+
+            # Update the original spec object directly
+            requests, limits = cls._get_container_resources(replica_spec)
+            spec.replicaSpecs[0].template.spec.containers[0].resources.requests = requests
+            spec.replicaSpecs[0].template.spec.containers[0].resources.limits = limits
+
+            return spec
+        except Exception as e:
+            print(f"Warning: quota allocation failed: {e}. Using defaults.")
+            return spec
 
     @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_pytorchjob")
     def create(self, debug=False):
@@ -65,6 +153,10 @@ def create(self, debug=False):
         if not self.metadata.namespace:
             self.metadata.namespace = get_default_namespace()
 
+        spec = self.allocate_quotas_if_applicable(spec)
+        if spec.replicaSpecs[0].replicas == 0:
+            spec.replicaSpecs[0].replicas = 1  # default value
+
         config = {
             "apiVersion": f"{TRAINING_GROUP}/{API_VERSION}",
             "kind": KIND,
@@ -91,6 +183,8 @@
             logger.error(f"Failed to create HyperPodPytorchJob {self.metadata.name}!")
             handle_exception(e, self.metadata.name, self.metadata.namespace)
 
+
+
     @classmethod
     @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pytorchjobs")
     def list(cls, namespace=None) -> List["HyperPodPytorchJob"]:
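
Taken together, the new class methods wire Task Governance quota allocation into job creation: create() calls allocate_quotas_if_applicable(spec) before building the custom-resource config, and only falls back to replicas=1 when no replica count was set. A rough sketch of the underlying helper, assuming a replica-spec dict shaped the way _process_replica_resources expects (instance type under the nodeSelector label, resource values under the first container); all field values below are illustrative, not defaults:

from sagemaker.hyperpod.cli.constants.command_constants import INSTANCE_TYPE_LABEL
from sagemaker.hyperpod.training.hyperpod_pytorch_job import HyperPodPytorchJob

# Hypothetical replica spec; instance type and resource figures are examples only.
replica_spec = {
    "replicas": 2,
    "template": {
        "spec": {
            "nodeSelector": {INSTANCE_TYPE_LABEL: "ml.p5.48xlarge"},
            "containers": [{
                "name": "trainer",
                "resources": {
                    "requests": {"vcpu": "16", "memory": "64Gi", "nvidia.com/gpu": "2"},
                    "limits": {},
                },
            }],
        },
    },
}

# Validates the request against the instance type, then rewrites requests/limits
# from the compute quotas (if any were given) or from the instance defaults.
processed = HyperPodPytorchJob._process_replica_resources(replica_spec)
print(processed["template"]["spec"]["containers"][0]["resources"])

In the normal SDK flow none of this is called directly: HyperPodPytorchJob.create() runs allocate_quotas_if_applicable() on the job spec, and if anything in the quota path raises, the spec is left unchanged and the job is created with its original resource values.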
