Skip to content

Commit f571859

Browse files
update v1.1 pytorch job template to match parity with v1.0 change in staging repo (#228)
1 parent cc9eec6 commit f571859

File tree

3 files changed

+14
-13
lines changed

3 files changed

+14
-13
lines changed

hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/model.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from .quota_allocation_util import _is_valid, _get_resources_from_compute_quotas, _get_resources_from_instance, _get_limits
2424

2525
class VolumeConfig(BaseModel):
26+
model_config = ConfigDict(extra="forbid")
27+
2628
name: str = Field(
2729
...,
2830
description="Volume name",
@@ -109,16 +111,15 @@ class PyTorchJobConfig(BaseModel):
109111
min_length=1
110112
)
111113
node_count: Optional[int] = Field(
112-
default=None,
114+
default=1,
113115
alias="node_count",
114116
description="Number of nodes",
115117
ge=1
116118
)
117-
tasks_per_node: Optional[int] = Field(
118-
default=None,
119+
tasks_per_node: Optional[str] = Field(
120+
default="auto",
119121
alias="tasks_per_node",
120-
description="Number of tasks per node",
121-
ge=1
122+
description="Number of workers per node; supported values: [auto,cpu, gpu, int]",
122123
)
123124
label_selector: Optional[Dict[str, str]] = Field(
124125
default=None,

hyperpod-pytorch-job-template/hyperpod_pytorch_job_template/v1_1/schema.json

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,22 +202,22 @@
202202
"type": "null"
203203
}
204204
],
205-
"default": null,
205+
"default": 1,
206206
"description": "Number of nodes",
207207
"title": "Node Count"
208208
},
209209
"tasks_per_node": {
210210
"anyOf": [
211211
{
212212
"minimum": 1,
213-
"type": "integer"
213+
"type": "string"
214214
},
215215
{
216216
"type": "null"
217217
}
218218
],
219-
"default": null,
220-
"description": "Number of tasks per node",
219+
"default": "auto",
220+
"description": "Number of workers per node; supported values: [auto,cpu, gpu, int]",
221221
"title": "Tasks Per Node"
222222
},
223223
"label_selector": {

test/unit_tests/cli/test_training.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -441,9 +441,9 @@ def test_integer_field_validation_success(self):
441441
config = PyTorchJobConfig(
442442
job_name="test-job",
443443
image="pytorch:latest",
444-
tasks_per_node=8
444+
tasks_per_node="auto"
445445
)
446-
self.assertEqual(config.tasks_per_node, 8)
446+
self.assertEqual(config.tasks_per_node, "auto")
447447

448448
# Test max_retry
449449
config = PyTorchJobConfig(
@@ -755,7 +755,7 @@ def test_comprehensive_valid_config(self):
755755
pull_policy="Always",
756756
instance_type="ml.p4d.24xlarge",
757757
node_count=2,
758-
tasks_per_node=8,
758+
tasks_per_node="auto",
759759
label_selector={"accelerator": "nvidia"},
760760
queue_name="training-queue",
761761
priority="high",
@@ -774,7 +774,7 @@ def test_comprehensive_valid_config(self):
774774
self.assertEqual(config.pull_policy, "Always")
775775
self.assertEqual(config.instance_type, "ml.p4d.24xlarge")
776776
self.assertEqual(config.node_count, 2)
777-
self.assertEqual(config.tasks_per_node, 8)
777+
self.assertEqual(config.tasks_per_node, "auto")
778778
self.assertEqual(config.label_selector, {"accelerator": "nvidia"})
779779
self.assertEqual(config.queue_name, "training-queue")
780780
self.assertEqual(config.priority, "high")

0 commit comments

Comments
 (0)