Skip to content

Commit fc3cb62

Browse files
rsareddy0329Roja Reddy Sareddy
andauthored
Training CLI for Launch - Changes per SDK HyperPodPytorchJob constructor (#64)
* Training CLI for Launch * Training CLI for Launch --------- Co-authored-by: Roja Reddy Sareddy <[email protected]>
1 parent 60e56f4 commit fc3cb62

File tree

2 files changed

+28
-23
lines changed
  • hyperpod-pytorchjob-config-schemas/hyperpod_pytorchjob_config_schemas/v1_0
  • sagemaker-hyperpod/src/sagemaker/hyperpod/cli/commands

2 files changed

+28
-23
lines changed

hyperpod-pytorchjob-config-schemas/hyperpod_pytorchjob_config_schemas/v1_0/model.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from pydantic import BaseModel, ConfigDict, Field
22
from typing import Optional, List, Dict, Union
3-
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_config import _HyperPodPytorchJob, ReplicaSpec, RunPolicy, Template, Metadata, Spec
3+
from sagemaker.hyperpod.training.config.hyperpod_pytorch_job_config import ReplicaSpec, RunPolicy, Template, Metadata, Spec
44

55

66
class PyTorchJobConfig(BaseModel):
@@ -28,7 +28,7 @@ class PyTorchJobConfig(BaseModel):
2828

2929

3030

31-
def to_domain(self) -> _HyperPodPytorchJob:
31+
def to_domain(self) -> Dict:
3232
"""
3333
Convert flat config to domain model (HyperPodPytorchJobSpec)
3434
"""
@@ -38,13 +38,6 @@ def to_domain(self) -> _HyperPodPytorchJob:
3838
"image": self.image,
3939
}
4040

41-
# Add resources if needed (could be moved to SDK default)
42-
container["resources"] = {
43-
"limits": {
44-
"nvidia.com/gpu": 8
45-
}
46-
}
47-
4841
# Add optional container fields only if they're not None
4942
optional_container_fields = [
5043
("command", "command", self.command),
@@ -157,5 +150,5 @@ def to_domain(self) -> _HyperPodPytorchJob:
157150
# Create and return the domain model
158151
return { "name":self.job_name ,
159152
"namespace":self.namespace,
160-
"spec":_HyperPodPytorchJob(**job_kwargs)
153+
"spec":job_kwargs
161154
}

sagemaker-hyperpod/src/sagemaker/hyperpod/cli/commands/training.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,28 @@ def pytorch_create(version, config):
3636
job_name = config.get("name")
3737
namespace = config.get("namespace")
3838
spec = config.get("spec")
39-
# Create job with or without namespace
40-
if namespace is None:
41-
job = HyperPodPytorchJob(metadata=Metadata(name=job_name), spec=spec)
42-
else:
43-
job = HyperPodPytorchJob(
44-
metadata=Metadata(name=job_name, namespace=namespace), spec=spec
45-
)
4639

40+
# Prepare metadata
41+
metadata_kwargs = {"name": job_name}
42+
if namespace:
43+
metadata_kwargs["namespace"] = namespace
44+
45+
# Prepare job kwargs
46+
job_kwargs = {
47+
"metadata": Metadata(**metadata_kwargs),
48+
"replica_specs": spec.get("replica_specs", [])
49+
}
50+
51+
# Add nproc_per_node if present
52+
if "nproc_per_node" in spec:
53+
job_kwargs["nproc_per_node"] = spec["nproc_per_node"]
54+
55+
# Add run_policy if present
56+
if "run_policy" in spec:
57+
job_kwargs["run_policy"] = spec["run_policy"]
58+
59+
# Create job
60+
job = HyperPodPytorchJob(**job_kwargs)
4761
job.create()
4862

4963
except Exception as e:
@@ -138,16 +152,14 @@ def pytorch_describe(job_name: str, namespace: str):
138152
click.echo("=" * 80)
139153
click.echo(f"Name: {job.metadata.name}")
140154
click.echo(f"Namespace: {job.metadata.namespace}")
141-
click.echo(f"API Version: {job.apiVersion}")
142-
click.echo(f"Kind: {job.kind}")
143155

144156
# Print Spec details
145157
click.echo("\nSpec:")
146158
click.echo("-" * 80)
147-
click.echo(f"Processes per Node: {job.spec.nprocPerNode}")
159+
click.echo(f"Processes per Node: {job.nprocPerNode}")
148160

149161
# Print Replica Specs
150-
for replica in job.spec.replicaSpecs:
162+
for replica in job.replicaSpecs:
151163
click.echo(f"\nReplica Spec:")
152164
click.echo(f" Name: {replica.name}")
153165
click.echo(f" Replicas: {replica.replicas}")
@@ -169,9 +181,9 @@ def pytorch_describe(job_name: str, namespace: str):
169181
# Print Run Policy
170182
click.echo("\nRun Policy:")
171183
click.echo("-" * 80)
172-
click.echo(f"Clean Pod Policy: {job.spec.runPolicy.cleanPodPolicy}")
184+
click.echo(f"Clean Pod Policy: {job.runPolicy.cleanPodPolicy}")
173185
click.echo(
174-
f"TTL Seconds After Finished: {job.spec.runPolicy.ttlSecondsAfterFinished}"
186+
f"TTL Seconds After Finished: {job.runPolicy.ttlSecondsAfterFinished}"
175187
)
176188

177189
# Print Status

0 commit comments

Comments
 (0)