from pydantic import ConfigDict, Field
+
+from sagemaker.hyperpod.cli.constants.command_constants import INSTANCE_TYPE_LABEL, NVIDIA_GPU_RESOURCE_LIMIT_KEY, \
+    NEURON_RESOURCE_LIMIT_KEY
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import (
    _HyperPodPytorchJob, HyperPodPytorchJobStatus
)
import yaml
import logging

+from hyperpod_pytorch_job_template.quota_allocation_util import _is_valid, _get_resources_from_compute_quotas, _get_resources_from_instance, _get_limits
+
+

TRAINING_GROUP = "sagemaker.amazonaws.com"
API_VERSION = "v1"
@@ -52,6 +58,88 @@ def verify_kube_config(cls):

        # Verify Kubernetes version compatibility
        verify_kubernetes_version_compatibility(cls.get_logger())
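+    # Quota-allocation helpers: derive container resource requests/limits either
+    # from the user-supplied vCPU/memory/accelerator quotas or from the instance
+    # type, and normalize memory strings before writing them back to the spec.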
+    @classmethod
+    def sanitize_memory(cls, resource):
+        if 'memory' in resource:
+            memory = resource['memory']
+            # Handles the case where quotas were already initialized in the CLI layer,
+            # which can leave a duplicated unit suffix (e.g. "16GiGi").
+            # TODO: Clean up quota initialization in the CLI layer and directly use the SDK layer for init.
+            resource['memory'] = memory.replace('GiGi', 'Gi')
+        return resource
+
+    @classmethod
+    def _process_replica_resources(cls, data):
+        """Process and validate replica resource configuration."""
+        try:
+            node_count = data['replicas']
+
+            # Extract nested configuration with validation
+            template = data.get('template', {})
+            spec = template.get('spec', {})
+            node_selector = spec.get('nodeSelector', {})
+            containers = spec.get('containers', [])
+
+            if not containers:
+                raise ValueError("No containers found in template spec")
+
+            instance_type = node_selector.get(INSTANCE_TYPE_LABEL, None)
+            if not instance_type:
+                raise ValueError("Instance type not found in node selector")
+
+            container = containers[0]
+            resources = container.get('resources', {})
+            requests = resources.get('requests', {})
+            limits = resources.get('limits', {})
+
+            # Extract requested and limit resource values
+            vcpu = requests.get('vcpu', None)
+            memory = requests.get('memory', None)
+            accelerators = requests.get(NVIDIA_GPU_RESOURCE_LIMIT_KEY) or requests.get(NEURON_RESOURCE_LIMIT_KEY) or None
+            memory_limit = limits.get('memory', None)
+            vcpu_limit = limits.get('vcpu', None)
+            accelerators_limit = limits.get(NVIDIA_GPU_RESOURCE_LIMIT_KEY) or limits.get(NEURON_RESOURCE_LIMIT_KEY) or None
+
+            # Validate configuration
+            valid, error = _is_valid(vcpu, memory, accelerators, node_count, instance_type)
+            if not valid:
+                raise ValueError(error)
+
+            # Calculate resource values: use explicit compute quotas when given,
+            # otherwise fall back to the defaults for the instance type
+            requests_value = (_get_resources_from_compute_quotas(instance_type, vcpu, memory, accelerators)
+                              or _get_resources_from_instance(instance_type, node_count))
+            limits_value = _get_limits(instance_type, vcpu_limit, memory_limit, accelerators_limit)
+            requests_value = cls.sanitize_memory(requests_value)
+            limits_value = cls.sanitize_memory(limits_value)
+
+            # Update data with calculated values
+            data['template']['spec']['containers'][0]['resources']['requests'] = requests_value
+            data['template']['spec']['containers'][0]['resources']['limits'] = limits_value
+            return data
+        except KeyError as e:
+            raise ValueError(f"Missing required configuration key: {str(e)}")
+
+    @classmethod
+    def _get_container_resources(cls, replica_spec):
+        """Extract container resources from replica spec."""
+        container_resources = replica_spec['template']['spec']['containers'][0]['resources']
+        return container_resources['requests'], container_resources['limits']
+
+    @classmethod
+    def allocate_quotas_if_applicable(cls, spec):
+        """Fill in container resource requests/limits from quotas; on failure, return the spec unchanged."""
+        try:
+            spec_dict = spec.model_dump()
+            replica_spec = spec_dict['replicaSpecs'][0]
+            cls._process_replica_resources(replica_spec)
+
+            # Update the original spec object directly
+            requests, limits = cls._get_container_resources(replica_spec)
+            spec.replicaSpecs[0].template.spec.containers[0].resources.requests = requests
+            spec.replicaSpecs[0].template.spec.containers[0].resources.limits = limits
+
+            return spec
+        except Exception as e:
+            cls.get_logger().warning(f"Quota allocation failed: {e}. Using default resources.")
+            return spec

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_pytorchjob")
    def create(self, debug=False):
@@ -65,6 +153,10 @@ def create(self, debug=False):
        if not self.metadata.namespace:
            self.metadata.namespace = get_default_namespace()

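+        # Derive container resource requests/limits from the configured quotas;
+        # falls back to the original spec if allocation is not applicable or fails.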
+        spec = self.allocate_quotas_if_applicable(spec)
+        if spec.replicaSpecs[0].replicas == 0:
+            spec.replicaSpecs[0].replicas = 1  # default value
+
        config = {
            "apiVersion": f"{TRAINING_GROUP}/{API_VERSION}",
            "kind": KIND,
@@ -91,6 +183,8 @@ def create(self, debug=False):
            logger.error(f"Failed to create HyperPodPytorchJob {self.metadata.name}!")
            handle_exception(e, self.metadata.name, self.metadata.namespace)

+
+
    @classmethod
    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pytorchjobs")
    def list(cls, namespace=None) -> List["HyperPodPytorchJob"]: