from pydantic import ConfigDict, Field
+
+from sagemaker.hyperpod.cli.constants.command_constants import INSTANCE_TYPE_LABEL, NVIDIA_GPU_RESOURCE_LIMIT_KEY, \
+    NEURON_RESOURCE_LIMIT_KEY
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_unified_config import (
    _HyperPodPytorchJob, HyperPodPytorchJobStatus
)
@@ -18,6 +21,9 @@
import yaml
import logging

+from hyperpod_pytorch_job_template.quota_allocation_util import _is_valid, _get_resources_from_compute_quotas, _get_resources_from_instance, _get_limits
+
+

TRAINING_GROUP = "sagemaker.amazonaws.com"
API_VERSION = "v1"
@@ -52,6 +58,109 @@ def verify_kube_config(cls):

        # Verify Kubernetes version compatibility
        verify_kubernetes_version_compatibility(cls.get_logger())
+    @classmethod
+    def _extract_numeric_value(cls, value):
+        """Extract numeric value from strings like '1.5Gi' -> 1.5"""
+        if not value:
+            return None
+        import re
+        match = re.match(r'^([0-9]*\.?[0-9]+)', str(value))
+        return float(match.group(1)) if match else None
+
+    @classmethod
+    def sanitize_memory(cls, resource):
+        """Collapse a doubled memory suffix (e.g. '32GiGi' -> '32Gi') before the spec is submitted."""
+        try:
+            if 'memory' in resource:
+                # Case where quotas have already been initialized in the CLI layer
+                # TODO: Clean up quota initialization in the CLI layer and use the SDK layer directly for init.
+                resource['memory'] = resource['memory'].replace('GiGi', 'Gi')
+            return resource
+        except Exception:
+            return resource
+
+
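+    # _process_replica_resources returns None when the replica's nodeSelector does not pin an
+    # instance type via INSTANCE_TYPE_LABEL; in that case the container resources are left as provided.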
+    @classmethod
+    def _process_replica_resources(cls, data):
+        """Process and validate replica resource configuration."""
+        try:
+            node_count = data.get('replicas', None)
+
+            # Extract nested configuration with validation
+            template = data.get('template', {})
+            spec = template.get('spec', {})
+            node_selector = spec.get('nodeSelector', {})
+            instance_type = node_selector.get(INSTANCE_TYPE_LABEL) if node_selector else None
+
+            if not instance_type:
+                return None
+
+            containers = spec.get('containers', [])
+
+            if not containers:
+                raise ValueError("No containers found in template spec")
+
+            container = containers[0]
+            resources = container.get('resources', {})
+            requests = resources.get('requests', {})
+            limits = resources.get('limits', {})
+
+            # Extract resource values; accelerator counts fall back from NVIDIA GPUs to Neuron devices
+            vcpu = float(requests.get('cpu')) if requests.get('cpu') else None
+            memory = cls._extract_numeric_value(requests.get('memory'))
+            accelerator_request = requests.get(NVIDIA_GPU_RESOURCE_LIMIT_KEY) or requests.get(NEURON_RESOURCE_LIMIT_KEY)
+            accelerators = int(accelerator_request) if accelerator_request else None
+            memory_limit = cls._extract_numeric_value(limits.get('memory'))
+            vcpu_limit = float(limits.get('cpu')) if limits.get('cpu') else None
+            accelerator_limit = limits.get(NVIDIA_GPU_RESOURCE_LIMIT_KEY) or limits.get(NEURON_RESOURCE_LIMIT_KEY)
+            accelerators_limit = int(accelerator_limit) if accelerator_limit else None
+
+            # Validate configuration
+            valid, error = _is_valid(vcpu, memory, accelerators, node_count, instance_type)
+            if not valid:
+                raise ValueError(error)
+
+            # Calculate resource values
+            requests_value = (_get_resources_from_compute_quotas(instance_type, vcpu, memory, accelerators)
+                              or _get_resources_from_instance(instance_type, node_count))
+            limits_value = _get_limits(instance_type, vcpu_limit, memory_limit, accelerators_limit)
+
+            requests_value = cls.sanitize_memory(requests_value)
+            limits_value = cls.sanitize_memory(limits_value)
+
+            # Update data with calculated values
+            data['template']['spec']['containers'][0]['resources']['requests'] = requests_value
+            data['template']['spec']['containers'][0]['resources']['limits'] = limits_value
+            return data
+        except KeyError as e:
+            raise ValueError(f"Missing required configuration key: {str(e)}")
+
+    @classmethod
+    def _get_container_resources(cls, replica_spec):
+        """Extract container resources from replica spec."""
+        container_resources = replica_spec['template']['spec']['containers'][0]['resources']
+        return container_resources['requests'], container_resources['limits']
+
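+    # Called from create(): validates the requested vCPU/memory/accelerator counts against the
+    # selected instance type and rewrites the first replica spec's requests/limits via the
+    # quota_allocation_util helpers. Invalid combinations raise ValueError; any other error
+    # falls back to the resources already on the spec.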
+    @classmethod
+    def allocate_quotas_if_applicable(cls, spec):
+        logger = cls.get_logger()
+        logger = setup_logging(logger)
+        try:
+            spec_dict = spec.model_dump()
+            replica_spec = spec_dict['replicaSpecs'][0]
+            cls._process_replica_resources(replica_spec)
+
+            # Update the original spec object directly
+            requests, limits = cls._get_container_resources(replica_spec)
+            spec.replicaSpecs[0].template.spec.containers[0].resources.requests = requests
+            spec.replicaSpecs[0].template.spec.containers[0].resources.limits = limits
+
+            return spec
+        except ValueError as e:
+            logger.error(f"Error in quota allocation: {e}")
+            raise
+        except Exception as e:
+            logger.info(f"Warning in quota allocation: {e}. Using defaults.")
+            return spec

    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_pytorchjob")
    def create(self, debug=False):
@@ -65,6 +174,10 @@ def create(self, debug=False):
        if not self.metadata.namespace:
            self.metadata.namespace = get_default_namespace()

+        spec = self.allocate_quotas_if_applicable(spec)
+        if spec.replicaSpecs[0].replicas == 0:
+            spec.replicaSpecs[0].replicas = 1  # default value
+
        config = {
            "apiVersion": f"{TRAINING_GROUP}/{API_VERSION}",
            "kind": KIND,
@@ -91,6 +204,8 @@ def create(self, debug=False):
            logger.error(f"Failed to create HyperPodPytorchJob {self.metadata.name}!")
            handle_exception(e, self.metadata.name, self.metadata.namespace)

+
+
    @classmethod
    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pytorchjobs")
    def list(cls, namespace=None) -> List["HyperPodPytorchJob"]: