Skip to content

Commit 853dfa8

Browse files
rsareddy0329 and Roja Reddy Sareddy authored
feat: add get_operator_logs to pytorch job (#218)
* feat: add get_operator_logs to pytorch job * feat: add get_operator_logs to pytorch job * feat: add get_operator_logs to pytorch job * feat: add get_operator_logs to pytorch job --------- Co-authored-by: Roja Reddy Sareddy <[email protected]>
1 parent 99121e7 commit 853dfa8

File tree

7 files changed

+103
-2
lines changed

7 files changed

+103
-2
lines changed

src/sagemaker/hyperpod/cli/commands/training.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,3 +354,21 @@ def pytorch_get_logs(job_name: str, pod_name: str, namespace: str):
354354

355355
except Exception as e:
356356
raise click.UsageError(f"Failed to list jobs: {str(e)}")
357+
358+
359+
@click.command("hyp-pytorch-job")
@click.option(
    "--since-hours",
    type=click.FLOAT,
    required=True,
    help="Required. The time frame to get logs for.",
)
@_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "get_pytorch_operator_logs")
def pytorch_get_operator_logs(since_hours: float):
    """
    Get operator logs for pytorch training jobs.
    """
    # NOTE: the docstring above is also what click shows as the command's
    # help text, so it is user-facing output rather than a plain comment.
    #
    # The SDK classmethod does all the work (locating the operator pod and
    # reading its log); this command only forwards --since-hours and prints
    # whatever text comes back.
    click.echo(HyperPodPytorchJob.get_operator_logs(since_hours=since_hours))

src/sagemaker/hyperpod/cli/hyp_cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
pytorch_delete,
1717
pytorch_list_pods,
1818
pytorch_get_logs,
19+
pytorch_get_operator_logs,
1920
)
2021
from sagemaker.hyperpod.cli.commands.inference import (
2122
js_create,
@@ -139,6 +140,7 @@ def get_operator_logs():
139140
get_logs.add_command(js_get_logs)
140141
get_logs.add_command(custom_get_logs)
141142

143+
get_operator_logs.add_command(pytorch_get_operator_logs)
142144
get_operator_logs.add_command(js_get_operator_logs)
143145
get_operator_logs.add_command(custom_get_operator_logs)
144146

src/sagemaker/hyperpod/training/hyperpod_pytorch_job.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
API_VERSION = "v1"
2424
PLURAL = "hyperpodpytorchjobs"
2525
KIND = "HyperPodPyTorchJob"
26+
TRAINING_OPERATOR_NAMESPACE = "aws-hyperpod"
27+
TRAINING_OPERATOR_LABEL = "hp-training-control-plane"
2628

2729

2830
class HyperPodPytorchJob(_HyperPodPytorchJob):
@@ -233,6 +235,40 @@ def get_logs_from_pod(self, pod_name: str, container: Optional[str] = None) -> s
233235
logger.error(f"Failed to get logs from pod {pod_name}!")
234236
handle_exception(e, self.metadata.name, self.metadata.namespace)
235237

238+
@classmethod
@_hyperpod_telemetry_emitter(Feature.HYPERPOD, "get_operator_logs_pytorchjob")
def get_operator_logs(cls, since_hours: float) -> str:
    """Return recent logs of the HyperPod training operator pod.

    Pods are looked up in ``TRAINING_OPERATOR_NAMESPACE`` using
    ``TRAINING_OPERATOR_LABEL`` as the label selector, and the log of the
    first matching pod is read with timestamps enabled.

    Args:
        since_hours: Size of the look-back window in hours. Must be a
            positive number; fractional values are allowed and are
            converted to whole seconds (truncated, minimum 1 second).

    Returns:
        The value returned by ``CoreV1Api.read_namespaced_pod_log`` for the
        operator pod (presumably the log text as a string; confirm against
        the kubernetes client in use).

    Raises:
        ValueError: If ``since_hours`` is not greater than zero (NaN
            included).
        Exception: If no pod matches the operator label, or if reading the
            log fails. Read failures are first passed to
            ``handle_exception``; if that helper returns without raising,
            the original error is re-raised.
    """
    # Fail fast on a meaningless window before touching the cluster.
    # `not (x > 0)` is used instead of `x <= 0` so that NaN is rejected too.
    if not since_hours > 0:
        raise ValueError(
            f"since_hours must be a positive number, got {since_hours!r}"
        )

    cls.verify_kube_config()

    v1 = client.CoreV1Api()

    # Get pods with the training operator label directly
    pods = v1.list_namespaced_pod(
        namespace=TRAINING_OPERATOR_NAMESPACE,
        label_selector=TRAINING_OPERATOR_LABEL,
    )

    if not pods.items:
        raise Exception(
            f"No training operator pod found with label {TRAINING_OPERATOR_LABEL}"
        )

    # Use the first pod found.
    # NOTE(review): if the operator runs several replicas, the logs of the
    # remaining pods are not returned; confirm this is intended.
    pod_name = pods.items[0].metadata.name

    # Truncating a tiny positive window (e.g. 0.0001 h) would yield 0
    # seconds. The Kubernetes log endpoint is expected to reject a
    # non-positive sinceSeconds (confirm), so clamp to at least 1 second.
    since_seconds = max(1, int(3600 * since_hours))

    try:
        return v1.read_namespaced_pod_log(
            name=pod_name,
            namespace=TRAINING_OPERATOR_NAMESPACE,
            timestamps=True,
            since_seconds=since_seconds,
        )
    except Exception as e:
        handle_exception(e, pod_name, TRAINING_OPERATOR_NAMESPACE)
        # handle_exception is expected to raise; if it ever returns, do not
        # fall through (the original code would then hit an unbound `logs`
        # and mask the real failure with UnboundLocalError).
        raise
271+
236272

237273
def _load_hp_job(response: dict) -> HyperPodPytorchJob:
238274

test/integration_tests/training/cli/test_cli_training.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,4 +239,9 @@ def test_delete_job(self, test_job_name):
239239
assert list_result.returncode == 0
240240

241241
# The job name should no longer be in the output
242-
assert test_job_name not in list_result.stdout
242+
assert test_job_name not in list_result.stdout
243+
244+
def test_pytorch_get_operator_logs():
    """Test getting operator logs via CLI"""
    # Ask the CLI for the last hour of PyTorch training operator logs.
    command = ["hyp", "get-operator-logs", "hyp-pytorch-job", "--since-hours", "1"]
    completed = execute_command(command)

    # Only the exit status is checked; the log content itself is not asserted.
    assert completed.returncode == 0

test/integration_tests/training/sdk/test_sdk_training.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,8 @@ def test_delete_job(self, pytorch_job):
112112
jobs = HyperPodPytorchJob.list()
113113
job_names = [job.metadata.name for job in jobs]
114114
assert pytorch_job.metadata.name not in job_names
115+
116+
def test_get_operator_logs():
    """Test getting operator logs"""
    # Fetch the last hour of operator logs through the SDK entry point.
    operator_logs = HyperPodPytorchJob.get_operator_logs(since_hours=1)

    # Any truthy (non-empty) result counts as success.
    assert operator_logs

test/unit_tests/cli/test_training.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
pytorch_create,
77
list_jobs,
88
pytorch_describe,
9+
pytorch_get_operator_logs,
910
)
1011
from hyperpod_pytorch_job_template.v1_1.model import ALLOWED_TOPOLOGY_LABELS
1112
import sys
@@ -827,3 +828,12 @@ def test_none_topology_labels(self):
827828
)
828829
self.assertIsNone(config.preferred_topology)
829830
self.assertIsNone(config.required_topology)
831+
832+
@patch('sagemaker.hyperpod.cli.commands.training.HyperPodPytorchJob')
def test_pytorch_get_operator_logs(mock_hp):
    """The CLI command prints the SDK result and forwards --since-hours as a float."""
    # Stub the SDK class as seen from the CLI module so no cluster is needed.
    mock_hp.get_operator_logs.return_value = "operator logs"

    outcome = CliRunner().invoke(pytorch_get_operator_logs, ['--since-hours', '2'])

    assert outcome.exit_code == 0
    assert 'operator logs' in outcome.output
    # click.FLOAT converts the textual "2" before it reaches the SDK call.
    mock_hp.get_operator_logs.assert_called_once_with(since_hours=2.0)

test/unit_tests/training/test_hyperpod_pytorch_job.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,31 @@ def test_get_logs_from_pod_with_container_name(
283283
)
284284
self.assertEqual(result, "test logs")
285285

286+
@patch("kubernetes.client.CoreV1Api")
@patch.object(HyperPodPytorchJob, "verify_kube_config")
def test_get_operator_logs(self, mock_verify_config, mock_core_api):
    """get_operator_logs selects the operator pod by label and reads its log."""
    # The instance that `client.CoreV1Api()` hands back inside the method.
    api = mock_core_api.return_value

    # Only the training operator pod is mocked, because the method filters
    # with a label selector rather than scanning every pod.
    operator_pod = MagicMock()
    operator_pod.metadata.name = "training-operator-pod-abc123"
    api.list_namespaced_pod.return_value.items = [operator_pod]
    api.read_namespaced_pod_log.return_value = "training operator logs"

    fetched = HyperPodPytorchJob.get_operator_logs(2.5)

    self.assertEqual(fetched, "training operator logs")

    # Verify label selector is used
    api.list_namespaced_pod.assert_called_once_with(
        namespace="aws-hyperpod",
        label_selector="hp-training-control-plane",
    )

    # 2.5 hours must reach the API as 9000 whole seconds.
    api.read_namespaced_pod_log.assert_called_once_with(
        name="training-operator-pod-abc123",
        namespace="aws-hyperpod",
        timestamps=True,
        since_seconds=9000,
    )
310+
286311

287312
class TestLoadHpJob(unittest.TestCase):
288313
"""Test the _load_hp_job function"""
@@ -350,4 +375,4 @@ def test_load_hp_job_list_empty(self):
350375
result = _load_hp_job_list(response)
351376

352377
self.assertEqual(len(result), 0)
353-
self.assertEqual(result, [])
378+
self.assertEqual(result, [])

0 commit comments

Comments
 (0)