Skip to content

Commit e8b5b27

Browse files
Add pre and post scripts args to run before and after job execution (#64)
* feat: add pre-scripts and post-scripts args for start-job command * fix: corrected delimiter in documentation and updatetd config file instead of using override-parameters * test: added unit test cases * fix: updated config field used to training_cfg * chore: found bug in cli adding extra parm in test to unblock PR
1 parent 7d4ec29 commit e8b5b27

File tree

3 files changed

+83
-2
lines changed

3 files changed

+83
-2
lines changed

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ hyperpod connect-cluster --cluster-name <cluster-name> [--region <region>] [--na
134134
This command submits a new training job to the connected SageMaker HyperPod cluster.
135135
136136
```
137-
hyperpod start-job --job-name <job-name> [--namespace <namespace>] [--job-kind <kubeflow/PyTorchJob>] [--image <image>] [--command <command>] [--entry-script <script>] [--script-args <arg1 arg2>] [--environment <key=value>] [--pull-policy <Always|IfNotPresent|Never>] [--instance-type <instance-type>] [--node-count <count>] [--tasks-per-node <count>] [--label-selector <key=value>] [--deep-health-check-passed-nodes-only] [--scheduler-type <Kueue SageMaker None>] [--queue-name <queue-name>] [--priority <priority>] [--auto-resume] [--max-retry <count>] [--restart-policy <Always|OnFailure|Never|ExitCode>] [--volumes <volume1,volume2>] [--persistent-volume-claims <claim1:/mount/path,claim2:/mount/path>] [--results-dir <dir>] [--service-account-name <account>]
137+
hyperpod start-job --job-name <job-name> [--namespace <namespace>] [--job-kind <kubeflow/PyTorchJob>] [--image <image>] [--command <command>] [--entry-script <script>] [--script-args <arg1 arg2>] [--environment <key=value>] [--pull-policy <Always|IfNotPresent|Never>] [--instance-type <instance-type>] [--node-count <count>] [--tasks-per-node <count>] [--label-selector <key=value>] [--deep-health-check-passed-nodes-only] [--scheduler-type <Kueue SageMaker None>] [--queue-name <queue-name>] [--priority <priority>] [--auto-resume] [--max-retry <count>] [--restart-policy <Always|OnFailure|Never|ExitCode>] [--volumes <volume1,volume2>] [--persistent-volume-claims <claim1:/mount/path,claim2:/mount/path>] [--results-dir <dir>] [--service-account-name <account>] [--pre-script <cmd1 cmd2>] [--post-script <cmd1 cmd2>]
138138
```
139139
140140
* `job-name` (string) - Required. The base name of the job. A unique identifier (UUID) will automatically be appended to the name like `<job-name>-<UUID>`.
@@ -147,6 +147,9 @@ hyperpod start-job --job-name <job-name> [--namespace <namespace>] [--job-kind <
147147
* `script-args` (list[string]) - Optional. The list of arguments for entry scripts.
148148
* `environment` (dict[string, string]) - Optional. The environment variables (key-value pairs) to set in the containers.
149149
* `node-count` (int) - Required. The number of nodes (instances) to launch the jobs on.
150+
* `instance-type` (string) - Required. The instance type to launch the job on. Note that the instance types you can use are the available instances within your SageMaker quotas for instances prefixed with `ml`.
151+
* `pre-script` (list[string]) - Optional. Commands to run before the job starts. Multiple commands should be separated by comma.
152+
* `post-script` (list[string]) - Optional. Commands to run after the job completes. Multiple commands should be separated by comma.
150153
* `instance-type` (string) - Required. The instance type to launch the job on. Note that the instance types you can use are the available instances within your SageMaker quotas for instances prefixed with `ml`. If `node.kubernetes.io/instance-type` is provided via the `label-selector` it will take precedence for node selection.
151154
* `tasks-per-node` (int) - Optional. The number of devices to use per instance.
152155
* `label-selector` (dict[string, list[string]]) - Optional. A dictionary of labels and their values that will override the predefined node selection rules based on the SageMaker HyperPod `node-health-status` label and values. If users provide this field, the CLI will launch the job with this customized label selection.

src/hyperpod_cli/commands/job.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,18 @@ def cancel_job(
431431
help="Optional. Add a temp directory for containers to store data in the hosts."
432432
" <volume_name>:</host/mount/path>:</container/mount/path>,<volume_name>:</host/mount/path1>:</container/mount/path1>",
433433
)
434+
@click.option(
435+
"--pre-script",
436+
type=click.STRING,
437+
required=False,
438+
help="Optional. Commands to run before the job starts. Multiple commands should be separated by semicolons.",
439+
)
440+
@click.option(
441+
"--post-script",
442+
type=click.STRING,
443+
required=False,
444+
help="Optional. Commands to run after the job completes. Multiple commands should be separated by semicolons.",
445+
)
434446
@click.option(
435447
"--recipe",
436448
type=click.STRING,
@@ -549,6 +561,8 @@ def start_job(
549561
service_account_name: Optional[str],
550562
persistent_volume_claims: Optional[str],
551563
volumes: Optional[str],
564+
pre_script: Optional[str],
565+
post_script: Optional[str],
552566
recipe: Optional[str],
553567
override_parameters: Optional[str],
554568
debug: bool,
@@ -721,6 +735,23 @@ def start_job(
721735
custom_labels[KUEUE_WORKLOAD_PRIORITY_CLASS_LABEL_KEY] = priority
722736
priority = None
723737

738+
# Handle pre_script
739+
if pre_script:
740+
_override_or_remove(
741+
config["training_cfg"],
742+
"pre_script",
743+
pre_script.split(',')
744+
)
745+
746+
# Handle post_script
747+
if post_script:
748+
_override_or_remove(
749+
config["training_cfg"],
750+
"post_script",
751+
post_script.split(',')
752+
)
753+
754+
724755
_override_or_remove(
725756
config["cluster"]["cluster_config"],
726757
"custom_labels",
@@ -807,7 +838,7 @@ def start_job(
807838
auto_resume=auto_resume,
808839
label_selector=label_selector,
809840
max_retry=max_retry,
810-
deep_health_check_passed_nodes_only=deep_health_check_passed_nodes_only,
841+
deep_health_check_passed_nodes_only=deep_health_check_passed_nodes_only
811842
)
812843
# TODO: Unblock this after fixing customer using EKS cluster.
813844
console_link = utils.get_cluster_console_url()

test/unit_tests/test_job.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -886,6 +886,53 @@ def test_start_job_with_cli_args_label_selection_not_json_str(
886886
)
887887
self.assertEqual(result.exit_code, 1)
888888

889+
@mock.patch("yaml.dump")
890+
@mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
891+
@mock.patch("hyperpod_cli.commands.job.JobValidator")
892+
@mock.patch("boto3.Session")
893+
def test_start_job_with_cli_args_pre_script_and_post_script(
894+
self,
895+
mock_boto3,
896+
mock_validator_cls,
897+
mock_kubernetes_client,
898+
mock_yaml_dump,
899+
):
900+
mock_validator = mock_validator_cls.return_value
901+
mock_validator.validate_aws_credential.return_value = True
902+
mock_kubernetes_client.get_current_context_namespace.return_value = "kubeflow"
903+
mock_yaml_dump.return_value = None
904+
result = self.runner.invoke(
905+
start_job,
906+
[
907+
"--job-name",
908+
"test-job",
909+
"--instance-type",
910+
"ml.c5.xlarge",
911+
"--image",
912+
"pytorch:1.9.0-cuda11.1-cudnn8-runtime",
913+
"--node-count",
914+
"2",
915+
"--label-selector",
916+
"{NonJsonStr",
917+
"--entry-script",
918+
"/opt/train/src/train.py",
919+
"--pre-script",
920+
"echo 'test', echo 'test 1'",
921+
"--post-script",
922+
"echo 'test 1', echo 'test 2'",
923+
"--label-selector",
924+
'{"preferred": {"node.kubernetes.io/instance-type": ["ml.c5.xlarge"]}}'
925+
],
926+
)
927+
928+
# Assert that yaml.dump was called with the correct configuration
929+
mock_yaml_dump.assert_called_once()
930+
call_args = mock_yaml_dump.call_args[0]
931+
self.assertEqual(call_args[0]['training_cfg']['pre_script'], ["echo 'test'", " echo 'test 1'"])
932+
self.assertEqual(call_args[0]['training_cfg']['post_script'], ["echo 'test 1'", " echo 'test 2'"])
933+
934+
self.assertEqual(result.exit_code, 1)
935+
889936
@mock.patch("yaml.dump")
890937
@mock.patch("hyperpod_cli.clients.kubernetes_client.KubernetesClient.__new__")
891938
@mock.patch("hyperpod_cli.commands.job.JobValidator")

0 commit comments

Comments
 (0)