Skip to content

Commit 389d6ae

Browse files
authored
Add unit test and fix HyperPod Manager (#84)
* Add unit test and fix HyperPod Manager: (1) the default namespace can be set by HyperPodManager.set_context(); (2) added unit tests for inference.
* Remove debug print
1 parent 7c30bbc commit 389d6ae

13 files changed

+575
-936
lines changed

src/sagemaker/hyperpod/common/config/metadata.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ class Metadata(BaseModel):
99
description="Name must match the name of one entry in pod.spec.resourceClaims of the Pod where this field is used. It makes that resource available inside a container."
1010
)
1111
namespace: Optional[str] = Field(
12-
default="default",
12+
default=None,
1313
description="Name must match the name of one entry in pod.spec.resourceClaims of the Pod where this field is used. It makes that resource available inside a container.",
1414
)
1515
labels: Optional[Dict[str, str]] = Field(
1616
default=None,
1717
description="Labels are key value pairs that are attached to objects, such as Pod. Labels are intended to be used to specify identifying attributes of objects. The system ignores labels that are not in the service's selector. Labels can only be added to objects during creation. More info: XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
18-
)
18+
)

src/sagemaker/hyperpod/common/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from kubernetes import config as k8s_config
33
from pydantic import ValidationError
44
from kubernetes.client.exceptions import ApiException
5+
from kubernetes import config
56

67

78
def validate_cluster_connection():
@@ -13,6 +14,23 @@ def validate_cluster_connection():
1314
return False
1415

1516

17+
def get_default_namespace():
    """Return the namespace recorded on the active kubeconfig context.

    Reads the active context via ``config.list_kube_config_contexts()``.
    When that context carries a non-empty ``namespace`` entry it is
    returned; otherwise the literal ``"default"`` is returned.

    Raises:
        Exception: if there is no active context, or the active context
            has no ``"context"`` entry.
    """
    _, current = config.list_kube_config_contexts()

    # Guard clause: without an active context there is nothing to inspect.
    if not current or "context" not in current:
        raise Exception(
            "No active context. Please use set_context() method to set current context."
        )

    details = current["context"]
    has_namespace = "namespace" in details and details["namespace"]
    return details["namespace"] if has_namespace else "default"
32+
33+
1634
def handle_exception(e: Exception, name: str, namespace: str):
1735
if isinstance(e, ApiException):
1836
if e.status == 401:

src/sagemaker/hyperpod/hyperpod_manager.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import boto3
2-
from tabulate import tabulate
32
from kubernetes import config
43
import yaml
54
from typing import Optional
@@ -9,12 +8,17 @@
98
import os
109
import subprocess
1110
import re
11+
import logging
1212

1313
KUBE_CONFIG_PATH = os.path.expanduser(KUBE_CONFIG_DEFAULT_LOCATION)
1414
TEMP_KUBE_CONFIG_FILE = "/tmp/kubeconfig"
1515

1616

1717
class HyperPodManager:
18+
@classmethod
def get_logger(cls):
    """Return the logger for this module.

    The first parameter is named ``cls`` because this is a classmethod
    and therefore receives the class, not an instance.

    Returns:
        logging.Logger: the logger registered under this module's
        ``__name__``.
    """
    return logging.getLogger(__name__)
21+
1822
def _get_eks_name_from_arn(self, arn: str) -> str:
1923

2024
pattern = r"arn:aws:eks:[\w-]+:\d+:cluster/([\w-]+)"
@@ -25,7 +29,8 @@ def _get_eks_name_from_arn(self, arn: str) -> str:
2529
else:
2630
raise RuntimeError("cannot get EKS cluster name")
2731

28-
def _is_eks_orchestrator(self, sagemaker_client, cluster_name: str):
32+
@classmethod
33+
def _is_eks_orchestrator(cls, sagemaker_client, cluster_name: str):
2934
response = sagemaker_client.describe_cluster(ClusterName=cluster_name)
3035
return "Eks" in response["Orchestrator"]
3136

@@ -86,7 +91,7 @@ def _set_current_context(
8691
if namespace is not None:
8792
context["context"]["namespace"] = namespace
8893
else:
89-
context["context"].pop("namespace", None)
94+
context["context"]["namespace"] = "default"
9095
exist = True
9196

9297
if not exist:
@@ -107,8 +112,6 @@ def list_clusters(
107112
cls,
108113
region: Optional[str] = None,
109114
):
110-
instance = cls()
111-
112115
client = boto3.client("sagemaker", region_name=region)
113116
clusters = client.list_clusters()
114117

@@ -118,15 +121,16 @@ def list_clusters(
118121
for cluster in clusters["ClusterSummaries"]:
119122
cluster_name = cluster["ClusterName"]
120123

121-
if instance._is_eks_orchestrator(client, cluster_name):
122-
eks_clusters.append(("EKS", cluster_name))
124+
if cls._is_eks_orchestrator(client, cluster_name):
125+
eks_clusters.append(cluster_name)
123126
else:
124-
slurm_clusters.append((cluster_name, "Slurm"))
125-
126-
table_data = eks_clusters + slurm_clusters
127-
headers = ["Orchestrator", "Cluster Name"]
128-
129-
print(tabulate(table_data, headers=headers))
127+
slurm_clusters.append(cluster_name)
128+
129+
return {
130+
"Eks": eks_clusters,
131+
"Slurm": slurm_clusters
132+
}
133+
130134

131135
@classmethod
132136
def set_context(
@@ -145,7 +149,10 @@ def set_context(
145149
instance._update_kube_config(eks_name, region, TEMP_KUBE_CONFIG_FILE)
146150
instance._set_current_context(eks_cluster_arn, namespace)
147151

148-
print(f"Successfully set current cluster as: {cluster_name}")
152+
if namespace:
153+
cls.get_logger().info(f"Successfully set current context as: {cluster_name}, namespace: {namespace}")
154+
else:
155+
cls.get_logger().info(f"Successfully set current context as: {cluster_name}")
149156

150157
@classmethod
151158
def get_context(cls):
@@ -155,6 +162,6 @@ def get_context(cls):
155162
]
156163
return current_context
157164
except Exception as e:
158-
print(
165+
raise Exception(
159166
f"Failed to get current context: {e}. Check your config file at {TEMP_KUBE_CONFIG_FILE}"
160-
)
167+
)

src/sagemaker/hyperpod/inference/hp_endpoint.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
_HPEndpoint,
66
)
77
from sagemaker.hyperpod.inference.hp_endpoint_base import HPEndpointBase
8-
from typing import Dict, List, Optional
9-
from typing_extensions import Self
8+
from sagemaker.hyperpod.common.utils import get_default_namespace
9+
from typing import Dict, List, Optional, Self
1010
from sagemaker_core.main.resources import Endpoint
1111
from pydantic import Field, ValidationError
1212
import logging
@@ -19,7 +19,7 @@ class HPEndpoint(_HPEndpoint, HPEndpointBase):
1919
def create(
2020
self,
2121
name=None,
22-
namespace="default",
22+
namespace=None,
2323
debug=False,
2424
) -> None:
2525
logging.basicConfig()
@@ -32,6 +32,9 @@ def create(
3232
if not name:
3333
name = spec.modelName
3434

35+
if not namespace:
36+
namespace = get_default_namespace()
37+
3538
self.call_create_api(
3639
name=name, # use model name as metadata name
3740
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
@@ -51,13 +54,17 @@ def create(
5154
def create_from_dict(
5255
self,
5356
input: Dict,
54-
namespace: str = "default",
57+
name: str = None,
58+
namespace: str = None,
5559
) -> None:
5660
spec = _HPEndpoint.model_validate(input, by_name=True)
5761

5862
if not name:
5963
name = spec.modelName
6064

65+
if not namespace:
66+
namespace = get_default_namespace()
67+
6168
self.call_create_api(
6269
name=name, # use model name as metadata name
6370
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
@@ -95,8 +102,11 @@ def refresh(self) -> Self:
95102
@classmethod
96103
def list(
97104
cls,
98-
namespace: str = "default",
105+
namespace: str = None,
99106
) -> List[Endpoint]:
107+
if not namespace:
108+
namespace = get_default_namespace()
109+
100110
response = cls.call_list_api(
101111
kind=INFERENCE_ENDPOINT_CONFIG_KIND,
102112
namespace=namespace,
@@ -112,7 +122,10 @@ def list(
112122
return endpoints
113123

114124
@classmethod
115-
def get(cls, name: str, namespace: str = "default") -> Endpoint:
125+
def get(cls, name: str, namespace: str = None) -> Endpoint:
126+
if not namespace:
127+
namespace = get_default_namespace()
128+
116129
response = cls.call_get_api(
117130
name=name,
118131
kind=INFERENCE_ENDPOINT_CONFIG_KIND,

src/sagemaker/hyperpod/inference/hp_endpoint_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
from sagemaker.hyperpod.common.utils import (
1616
validate_cluster_connection,
1717
handle_exception,
18+
get_default_namespace,
1819
)
1920
import logging
2021
import yaml
22+
from kubernetes import config
2123

2224

2325
class HPEndpointBase:
@@ -44,7 +46,7 @@ def call_create_api(
4446
name: str,
4547
kind: str,
4648
namespace: str,
47-
spec: Union[_HPJumpStartEndpoint , _HPEndpoint],
49+
spec: Union[_HPJumpStartEndpoint, _HPEndpoint],
4850
):
4951
if not validate_cluster_connection():
5052
raise Exception(

src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
_HPJumpStartEndpoint,
66
JumpStartModelStatus,
77
)
8-
from typing import Dict, List, Optional
9-
from typing_extensions import Self
8+
from sagemaker.hyperpod.common.utils import get_default_namespace
9+
from typing import Dict, List, Optional, Self
1010
from sagemaker_core.main.resources import Endpoint
1111
from pydantic import Field, ValidationError
1212
import logging
@@ -19,7 +19,7 @@ class HPJumpStartEndpoint(_HPJumpStartEndpoint, HPEndpointBase):
1919
def create(
2020
self,
2121
name=None,
22-
namespace="default",
22+
namespace=None,
2323
debug=False,
2424
) -> None:
2525
logging.basicConfig()
@@ -33,6 +33,9 @@ def create(
3333
if not name:
3434
name = spec.model.modelId
3535

36+
if not namespace:
37+
namespace = get_default_namespace()
38+
3639
self.call_create_api(
3740
name=name, # use model name as metadata name
3841
kind=JUMPSTART_MODEL_KIND,
@@ -53,13 +56,16 @@ def create_from_dict(
5356
self,
5457
input: Dict,
5558
name: str = None,
56-
namespace: str = "default",
59+
namespace: str = None,
5760
) -> None:
5861
spec = _HPJumpStartEndpoint.model_validate(input, by_name=True)
5962

6063
if not name:
6164
name = spec.model.modelId
6265

66+
if not namespace:
67+
namespace = get_default_namespace()
68+
6369
self.call_create_api(
6470
name=name, # use model name as metadata name
6571
kind=JUMPSTART_MODEL_KIND,
@@ -97,8 +103,11 @@ def refresh(self) -> Self:
97103
@classmethod
98104
def list(
99105
cls,
100-
namespace: str = "default",
106+
namespace: str = None,
101107
) -> List[Endpoint]:
108+
if not namespace:
109+
namespace = get_default_namespace()
110+
102111
response = cls.call_list_api(
103112
kind=JUMPSTART_MODEL_KIND,
104113
namespace=namespace,
@@ -114,7 +123,10 @@ def list(
114123
return endpoints
115124

116125
@classmethod
117-
def get(cls, name: str, namespace: str = "default") -> Self:
126+
def get(cls, name: str, namespace: str = None) -> Self:
127+
if not namespace:
128+
namespace = get_default_namespace()
129+
118130
response = cls.call_get_api(
119131
name=name,
120132
kind=JUMPSTART_MODEL_KIND,

src/sagemaker/hyperpod/training/hyperpod_pytorch_job.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sagemaker.hyperpod.common.utils import (
1212
validate_cluster_connection,
1313
handle_exception,
14+
get_default_namespace,
1415
)
1516
import yaml
1617
import logging
@@ -52,18 +53,20 @@ def create(self, debug=False):
5253

5354
spec = _HyperPodPytorchJob(**self.model_dump(by_alias=True, exclude_none=True))
5455

56+
if not self.metadata.namespace:
57+
self.metadata.namespace = get_default_namespace()
58+
5559
config = {
5660
"apiVersion": f"{TRAINING_GROUP}/{API_VERSION}",
5761
"kind": KIND,
58-
"metadata": self.metadata.model_dump(),
59-
"spec": spec.model_dump(),
62+
"metadata": self.metadata.model_dump(exclude_none=True),
63+
"spec": spec.model_dump(exclude_none=True),
6064
}
6165

6266
custom_api = client.CustomObjectsApi()
6367
logger.debug(
64-
"Deploying HyperPodPytorchJob with config:\n",
68+
"Deploying HyperPodPytorchJob with config:\n%s",
6569
yaml.dump(config),
66-
sep="",
6770
)
6871

6972
try:
@@ -80,7 +83,10 @@ def create(self, debug=False):
8083
handle_exception(e, self.metadata.name, self.metadata.namespace)
8184

8285
@classmethod
83-
def list(cls, namespace="default") -> List["HyperPodPytorchJob"]:
86+
def list(cls, namespace=None) -> List["HyperPodPytorchJob"]:
87+
if namespace is None:
88+
namespace = get_default_namespace()
89+
8490
if not validate_cluster_connection():
8591
raise Exception(
8692
"Failed to connect to the Kubernetes cluster. Please check your kubeconfig."
@@ -124,7 +130,10 @@ def delete(self):
124130
handle_exception(e, self.metadata.name, self.metadata.namespace)
125131

126132
@classmethod
127-
def get(cls, name, namespace="default") -> "HyperPodPytorchJob":
133+
def get(cls, name, namespace=None) -> "HyperPodPytorchJob":
134+
if namespace is None:
135+
namespace = get_default_namespace()
136+
128137
if not validate_cluster_connection():
129138
raise Exception(
130139
"Failed to connect to the Kubernetes cluster. Please check your kubeconfig."

0 commit comments

Comments
 (0)