diff --git a/src/sagemaker/hyperpod/cli/commands/cluster.py b/src/sagemaker/hyperpod/cli/commands/cluster.py index 3e4aacd2..4f47dd3c 100644 --- a/src/sagemaker/hyperpod/cli/commands/cluster.py +++ b/src/sagemaker/hyperpod/cli/commands/cluster.py @@ -42,7 +42,7 @@ TEMP_KUBE_CONFIG_FILE, OutputFormat, ) -from sagemaker.hyperpod.cli.telemetry.user_agent import ( +from sagemaker.hyperpod.common.telemetry.user_agent import ( get_user_agent_extra_suffix, ) from sagemaker.hyperpod.cli.service.list_pods import ( @@ -61,8 +61,17 @@ from sagemaker.hyperpod.cli.utils import ( get_eks_cluster_name, ) -from sagemaker.hyperpod.common.utils import get_cluster_context as get_cluster_context_util -from sagemaker.hyperpod.observability.utils import get_monitoring_config, is_observability_addon_enabled +from sagemaker.hyperpod.common.utils import ( + get_cluster_context as get_cluster_context_util, +) +from sagemaker.hyperpod.observability.utils import ( + get_monitoring_config, + is_observability_addon_enabled, +) +from sagemaker.hyperpod.common.telemetry.telemetry_logging import ( + _hyperpod_telemetry_emitter, +) +from sagemaker.hyperpod.common.telemetry.constants import Feature RATE_LIMIT = 4 RATE_LIMIT_PERIOD = 1 # 1 second @@ -103,12 +112,13 @@ multiple=True, help="Optional. The namespace that you want to check the capacity for. Only SageMaker managed namespaces are supported.", ) +@_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_cluster") def list_cluster( region: Optional[str], output: Optional[str], clusters: Optional[str], debug: bool, - namespace: Optional[List] + namespace: Optional[List], ): """List SageMaker Hyperpod Clusters with cluster metadata. @@ -261,8 +271,14 @@ def rate_limited_operation( for ns in namespace: sm_managed_namespace = k8s_client.get_sagemaker_managed_namespace(ns) if sm_managed_namespace: - quota_allocation_id = sm_managed_namespace.metadata.labels[SAGEMAKER_QUOTA_ALLOCATION_LABEL] - cluster_queue_name = HYPERPOD_NAMESPACE_PREFIX + quota_allocation_id + SAGEMAKER_MANAGED_CLUSTER_QUEUE_SUFFIX + quota_allocation_id = sm_managed_namespace.metadata.labels[ + SAGEMAKER_QUOTA_ALLOCATION_LABEL + ] + cluster_queue_name = ( + HYPERPOD_NAMESPACE_PREFIX + + quota_allocation_id + + SAGEMAKER_MANAGED_CLUSTER_QUEUE_SUFFIX + ) cluster_queue = k8s_client.get_cluster_queue(cluster_queue_name) nominal_quota = _get_cluster_queue_nominal_quota(cluster_queue) quota_usage = _get_cluster_queue_quota_usage(cluster_queue) @@ -282,8 +298,19 @@ def rate_limited_operation( nodes_summary["deep_health_check_passed"], ] for ns in namespace: - capacities.append(ns_nominal_quota.get(ns).get(instance_type, {}).get(NVIDIA_GPU_RESOURCE_LIMIT_KEY, "N/A")) - capacities.append(_get_available_quota(ns_nominal_quota.get(ns), ns_quota_usage.get(ns), instance_type, NVIDIA_GPU_RESOURCE_LIMIT_KEY)) + capacities.append( + ns_nominal_quota.get(ns) + .get(instance_type, {}) + .get(NVIDIA_GPU_RESOURCE_LIMIT_KEY, "N/A") + ) + capacities.append( + _get_available_quota( + ns_nominal_quota.get(ns), + ns_quota_usage.get(ns), + instance_type, + NVIDIA_GPU_RESOURCE_LIMIT_KEY, + ) + ) cluster_capacities.append(capacities) except Exception as e: logger.error(f"Error processing cluster {cluster_name}: {e}, continue...") @@ -305,7 +332,7 @@ def _get_cluster_queue_nominal_quota(cluster_queue): if resource_name == NVIDIA_GPU_RESOURCE_LIMIT_KEY: quota = int(quota) nominal_quota[flavor_name][resource_name] = quota - + return nominal_quota @@ -336,7 +363,7 @@ def _get_available_quota(nominal, usage, flavor, resource_name): # Some 
resources need to be further processed by parsing unit like memory, e.g 10Gi if nominal_quota is not None and usage_quota is not None: return int(nominal_quota) - int(usage_quota) - + return "N/A" @@ -358,7 +385,9 @@ def _restructure_output(summary_list, namespaces): for node_summary in summary_list: node_summary["Namespaces"] = {} for ns in namespaces: - available_accelerators = node_summary[ns + AVAILABLE_ACCELERATOR_DEVICES_KEY] + available_accelerators = node_summary[ + ns + AVAILABLE_ACCELERATOR_DEVICES_KEY + ] total_accelerators = node_summary[ns + TOTAL_ACCELERATOR_DEVICES_KEY] quota_accelerator_info = { AVAILABLE_ACCELERATOR_DEVICES_KEY: available_accelerators, @@ -425,9 +454,9 @@ def _aggregate_nodes_info( # Accelerator Devices available = Allocatable devices - Allocated devices if node_name in nodes_resource_allocated_dict: - nodes_summary[instance_type]["accelerator_devices_available"] -= ( - nodes_resource_allocated_dict[node_name] - ) + nodes_summary[instance_type][ + "accelerator_devices_available" + ] -= nodes_resource_allocated_dict[node_name] logger.debug(f"nodes_summary: {nodes_summary}") return nodes_summary @@ -550,7 +579,6 @@ def get_cluster_context( sys.exit(1) - @click.command() @click.option("--grafana", is_flag=True, help="Returns Grafana Dashboard URL") @click.option("--prometheus", is_flag=True, help="Returns Prometheus Workspace URL") @@ -572,14 +600,21 @@ def get_monitoring(grafana: bool, prometheus: bool, list: bool) -> None: print(f"Grafana dashboard URL: {monitor_config.grafanaURL}") if list: metrics_data = monitor_config.availableMetrics - print(tabulate([[k, v.get('level', v.get('enabled'))] for k, v in metrics_data.items()], - headers=['Metric', 'Level/Enabled'], tablefmt='presto')) + print( + tabulate( + [ + [k, v.get("level", v.get("enabled"))] + for k, v in metrics_data.items() + ], + headers=["Metric", "Level/Enabled"], + tablefmt="presto", + ) + ) except Exception as e: logger.error(f"Failed to get metrics: {e}") sys.exit(1) - def _update_kube_config( eks_name: str, region: Optional[str], diff --git a/src/sagemaker/hyperpod/cli/telemetry/__init__.py b/src/sagemaker/hyperpod/common/telemetry/__init__.py similarity index 85% rename from src/sagemaker/hyperpod/cli/telemetry/__init__.py rename to src/sagemaker/hyperpod/common/telemetry/__init__.py index 65490521..3bb710cc 100644 --- a/src/sagemaker/hyperpod/cli/telemetry/__init__.py +++ b/src/sagemaker/hyperpod/common/telemetry/__init__.py @@ -10,3 +10,5 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
+from __future__ import absolute_import +from .telemetry_logging import _hyperpod_telemetry_emitter diff --git a/src/sagemaker/hyperpod/common/telemetry/constants.py b/src/sagemaker/hyperpod/common/telemetry/constants.py new file mode 100644 index 00000000..fc7a7579 --- /dev/null +++ b/src/sagemaker/hyperpod/common/telemetry/constants.py @@ -0,0 +1,60 @@ +from __future__ import absolute_import +from enum import Enum + + +class Feature(Enum): + """Enumeration of feature names used in telemetry.""" + + HYPERPOD = 6 # Added to support telemetry in sagemaker-hyperpod-cli + + def __str__(self): # pylint: disable=E0307 + """Return the feature name.""" + return self.name + + +class Status(Enum): + """Enumeration of status values used in telemetry.""" + + SUCCESS = 1 + FAILURE = 0 + + def __str__(self): # pylint: disable=E0307 + """Return the status name.""" + return self.name + + +class Region(str, Enum): + """Telemetry: List of all supported AWS regions.""" + + # Classic + US_EAST_1 = "us-east-1" # IAD + US_EAST_2 = "us-east-2" # CMH + US_WEST_1 = "us-west-1" # SFO + US_WEST_2 = "us-west-2" # PDX + AP_NORTHEAST_1 = "ap-northeast-1" # NRT + AP_NORTHEAST_2 = "ap-northeast-2" # ICN + AP_NORTHEAST_3 = "ap-northeast-3" # KIX + AP_SOUTH_1 = "ap-south-1" # BOM + AP_SOUTHEAST_1 = "ap-southeast-1" # SIN + AP_SOUTHEAST_2 = "ap-southeast-2" # SYD + CA_CENTRAL_1 = "ca-central-1" # YUL + EU_CENTRAL_1 = "eu-central-1" # FRA + EU_NORTH_1 = "eu-north-1" # ARN + EU_WEST_1 = "eu-west-1" # DUB + EU_WEST_2 = "eu-west-2" # LHR + EU_WEST_3 = "eu-west-3" # CDG + SA_EAST_1 = "sa-east-1" # GRU + # Opt-in + AP_EAST_1 = "ap-east-1" # HKG + AP_SOUTHEAST_3 = "ap-southeast-3" # CGK + AF_SOUTH_1 = "af-south-1" # CPT + EU_SOUTH_1 = "eu-south-1" # MXP + ME_SOUTH_1 = "me-south-1" # BAH + MX_CENTRAL_1 = "mx-central-1" # QRO + AP_SOUTHEAST_7 = "ap-southeast-7" # BKK + AP_SOUTH_2 = "ap-south-2" # HYD + AP_SOUTHEAST_4 = "ap-southeast-4" # MEL + EU_CENTRAL_2 = "eu-central-2" # ZRH + EU_SOUTH_2 = "eu-south-2" # ZAZ + IL_CENTRAL_1 = "il-central-1" # TLV + ME_CENTRAL_1 = "me-central-1" # DXB diff --git a/src/sagemaker/hyperpod/common/telemetry/telemetry_logging.py b/src/sagemaker/hyperpod/common/telemetry/telemetry_logging.py new file mode 100644 index 00000000..79eb2d29 --- /dev/null +++ b/src/sagemaker/hyperpod/common/telemetry/telemetry_logging.py @@ -0,0 +1,186 @@ +from __future__ import absolute_import +import logging +import platform +import sys +from time import perf_counter +from typing import List, Tuple +import functools +import requests +import subprocess +import re + +import boto3 +from sagemaker.hyperpod.common.telemetry.constants import Feature, Status, Region +import importlib.metadata + +SDK_VERSION = importlib.metadata.version("sagemaker-hyperpod") +DEFAULT_AWS_REGION = "us-east-2" +OS_NAME = platform.system() or "UnresolvedOS" +OS_VERSION = platform.release() or "UnresolvedOSVersion" +OS_NAME_VERSION = "{}/{}".format(OS_NAME, OS_VERSION) +PYTHON_VERSION = "{}.{}.{}".format( + sys.version_info.major, sys.version_info.minor, sys.version_info.micro +) + +FEATURE_TO_CODE = { + str(Feature.HYPERPOD): 6, # Added to support telemetry in sagemaker-hyperpod-cli +} + +STATUS_TO_CODE = { + str(Status.SUCCESS): 1, + str(Status.FAILURE): 0, +} + +logger = logging.getLogger(__name__) + + +def get_region_and_account_from_current_context() -> Tuple[str, str]: + """ + Get region and account ID from current kubernetes context + Returns: (region, account_id) + """ + try: + # Get current context + result = subprocess.run( + ["kubectl", 
"config", "current-context"], capture_output=True, text=True + ) + + if result.returncode == 0: + context = result.stdout.strip() + + # Extract region + region_pattern = r"([a-z]{2}-[a-z]+-\d{1})" + region = DEFAULT_AWS_REGION + if match := re.search(region_pattern, context): + region = match.group(1) + + # Extract account ID (12 digits) + account_pattern = r"(\d{12})" + account = "unknown" + if match := re.search(account_pattern, context): + account = match.group(1) + + return region, account + + except Exception as e: + logger.debug(f"Failed to get context info from kubectl: {e}") + + return DEFAULT_AWS_REGION, "unknown" + + +def _requests_helper(url, timeout): + """Make a GET request to the given URL""" + + response = None + try: + response = requests.get(url, timeout) + except requests.exceptions.RequestException as e: + logger.exception("Request exception: %s", str(e)) + return response + + +def _construct_url( + accountId: str, + region: str, + status: str, + feature: str, + failure_reason: str, + failure_type: str, + extra_info: str, +) -> str: + """Construct the URL for the telemetry request""" + + base_url = ( + f"https://sm-pysdk-t-{region}.s3.{region}.amazonaws.com/telemetry?" + f"x-accountId={accountId}" + f"&x-status={status}" + f"&x-feature={feature}" + ) + logger.debug("Failure reason: %s", failure_reason) + if failure_reason: + base_url += f"&x-failureReason={failure_reason}" + base_url += f"&x-failureType={failure_type}" + if extra_info: + base_url += f"&x-extra={extra_info}" + return base_url + + +def _send_telemetry_request( + status: int, + feature_list: List[int], + session, + failure_reason: str = None, + failure_type: str = None, + extra_info: str = None, +) -> None: + """Make GET request to an empty object in S3 bucket""" + try: + region, accountId = get_region_and_account_from_current_context() + + try: + Region(region) # Validate the region + except ValueError: + logger.warning( + "Region not found in supported regions. Telemetry request will not be emitted." 
+ ) + return + + url = _construct_url( + accountId, + region, + str(status), + str( + ",".join(map(str, feature_list)) + ), # Remove brackets and quotes to cut down on length + failure_reason, + failure_type, + extra_info, + ) + # Send the telemetry request + logger.info("Sending telemetry request to [%s]", url) + _requests_helper(url, 2) + logger.info("SageMaker Python SDK telemetry successfully emitted.") + except Exception: # pylint: disable=W0703 + logger.warning("SageMaker Python SDK telemetry not emitted!") + + +def _hyperpod_telemetry_emitter(feature: str, func_name: str): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + extra = ( + f"{func_name}" + f"&x-sdkVersion={SDK_VERSION}" + f"&x-env={PYTHON_VERSION}" + f"&x-sys={OS_NAME_VERSION}" + ) + start = perf_counter() + try: + result = func(*args, **kwargs) + duration = round(perf_counter() - start, 2) + extra += f"&x-latency={duration}" + _send_telemetry_request( + Status.SUCCESS, + [FEATURE_TO_CODE[str(feature)]], + None, + None, + None, + extra, + ) + return result + except Exception as e: + duration = round(perf_counter() - start, 2) + extra += f"&x-latency={duration}" + _send_telemetry_request( + Status.FAILURE, + [FEATURE_TO_CODE[str(feature)]], + None, + str(e), + type(e).__name__, + extra, + ) + raise + + return wrapper + + return decorator diff --git a/src/sagemaker/hyperpod/cli/telemetry/user_agent.py b/src/sagemaker/hyperpod/common/telemetry/user_agent.py similarity index 100% rename from src/sagemaker/hyperpod/cli/telemetry/user_agent.py rename to src/sagemaker/hyperpod/common/telemetry/user_agent.py diff --git a/src/sagemaker/hyperpod/inference/hp_endpoint.py b/src/sagemaker/hyperpod/inference/hp_endpoint.py index e952baba..8a7907a1 100644 --- a/src/sagemaker/hyperpod/inference/hp_endpoint.py +++ b/src/sagemaker/hyperpod/inference/hp_endpoint.py @@ -11,6 +11,10 @@ InferenceEndpointConfigStatus, _HPEndpoint, ) +from sagemaker.hyperpod.common.telemetry.telemetry_logging import ( + _hyperpod_telemetry_emitter, +) +from sagemaker.hyperpod.common.telemetry.constants import Feature from sagemaker.hyperpod.inference.hp_endpoint_base import HPEndpointBase from typing import Dict, List, Optional from sagemaker_core.main.resources import Endpoint @@ -21,6 +25,7 @@ class HPEndpoint(_HPEndpoint, HPEndpointBase): metadata: Optional[Metadata] = Field(default=None) status: Optional[InferenceEndpointConfigStatus] = Field(default=None) + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_endpoint") def create( self, name=None, @@ -59,6 +64,7 @@ def create( f"Creating sagemaker model and endpoint. Endpoint name: {spec.endpointName}.\n The process may take a few minutes..." 
) + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_endpoint_from_dict") def create_from_dict( self, input: Dict, @@ -116,6 +122,7 @@ def refresh(self): return self @classmethod + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_endpoints") def list( cls, namespace: str = None, @@ -138,6 +145,7 @@ def list( return endpoints @classmethod + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "get_endpoint") def get(cls, name: str, namespace: str = None) -> Endpoint: if not namespace: namespace = get_default_namespace() @@ -163,6 +171,7 @@ def get(cls, name: str, namespace: str = None) -> Endpoint: return endpoint + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "delete_endpoint") def delete(self) -> None: logger = self.get_logger() logger = setup_logging(logger) @@ -174,6 +183,7 @@ def delete(self) -> None: ) logger.info(f"Deleting HPEndpoint: {self.metadata.name}...") + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "invoke_endpoint") def invoke(self, body, content_type="application/json"): if not self.endpointName: raise Exception("SageMaker endpoint name not found in this object!") diff --git a/src/sagemaker/hyperpod/inference/hp_endpoint_base.py b/src/sagemaker/hyperpod/inference/hp_endpoint_base.py index 1cb0432a..f80308ad 100644 --- a/src/sagemaker/hyperpod/inference/hp_endpoint_base.py +++ b/src/sagemaker/hyperpod/inference/hp_endpoint_base.py @@ -15,6 +15,10 @@ setup_logging, get_default_namespace, ) +from sagemaker.hyperpod.common.telemetry.telemetry_logging import ( + _hyperpod_telemetry_emitter, +) +from sagemaker.hyperpod.common.telemetry.constants import Feature class HPEndpointBase: @@ -130,6 +134,7 @@ def call_delete_api( handle_exception(e, name, namespace) @classmethod + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "get_operator_logs") def get_operator_logs(cls, since_hours: float): cls.verify_kube_config() @@ -159,6 +164,7 @@ def get_operator_logs(cls, since_hours: float): return logs @classmethod + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "get_logs_endpoint") def get_logs( cls, pod: str, @@ -194,6 +200,7 @@ def get_logs( return logs @classmethod + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pods_endpoint") def list_pods(cls, namespace=None): cls.verify_kube_config() @@ -210,6 +217,7 @@ def list_pods(cls, namespace=None): return pods @classmethod + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_namespaces") def list_namespaces(cls): cls.verify_kube_config() diff --git a/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py b/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py index 71b6635b..6110f20c 100644 --- a/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py +++ b/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py @@ -16,12 +16,17 @@ _HPJumpStartEndpoint, JumpStartModelStatus, ) +from sagemaker.hyperpod.common.telemetry.telemetry_logging import ( + _hyperpod_telemetry_emitter, +) +from sagemaker.hyperpod.common.telemetry.constants import Feature class HPJumpStartEndpoint(_HPJumpStartEndpoint, HPEndpointBase): metadata: Optional[Metadata] = Field(default=None) status: Optional[JumpStartModelStatus] = Field(default=None) + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_js_endpoint") def create( self, name=None, @@ -64,6 +69,7 @@ def create( f"Creating JumpStart model and sagemaker endpoint. Endpoint name: {endpoint_name}.\n The process may take a few minutes..." 
) + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_js_endpoint_from_dict") def create_from_dict( self, input: Dict, @@ -125,6 +131,7 @@ def refresh(self): return self @classmethod + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_js_endpoints") def list( cls, namespace: str = None, @@ -147,6 +154,7 @@ def list( return endpoints @classmethod + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "get_js_endpoint") def get(cls, name: str, namespace: str = None): if not namespace: namespace = get_default_namespace() @@ -172,6 +180,7 @@ def get(cls, name: str, namespace: str = None): return endpoint + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "delete_js_endpoint") def delete(self) -> None: logger = self.get_logger() logger = setup_logging(logger) @@ -185,6 +194,7 @@ def delete(self) -> None: f"Deleting JumpStart model and sagemaker endpoint: {self.metadata.name}. This may take a few minutes..." ) + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "invoke_js_endpoint") def invoke(self, body, content_type="application/json"): if not self.sageMakerEndpoint or not self.sageMakerEndpoint.name: raise Exception("SageMaker endpoint name not found in this object!") diff --git a/src/sagemaker/hyperpod/training/hyperpod_pytorch_job.py b/src/sagemaker/hyperpod/training/hyperpod_pytorch_job.py index 78e8f86a..ac131c05 100644 --- a/src/sagemaker/hyperpod/training/hyperpod_pytorch_job.py +++ b/src/sagemaker/hyperpod/training/hyperpod_pytorch_job.py @@ -13,6 +13,10 @@ get_default_namespace, setup_logging, ) +from sagemaker.hyperpod.common.telemetry.telemetry_logging import ( + _hyperpod_telemetry_emitter, +) +from sagemaker.hyperpod.common.telemetry.constants import Feature import yaml import logging @@ -45,6 +49,7 @@ def verify_kube_config(cls): def get_logger(cls): return logging.getLogger(__name__) + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "create_pytorchjob") def create(self, debug=False): self.verify_kube_config() @@ -83,6 +88,7 @@ def create(self, debug=False): handle_exception(e, self.metadata.name, self.metadata.namespace) @classmethod + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pytorchjobs") def list(cls, namespace=None) -> List["HyperPodPytorchJob"]: cls.verify_kube_config() @@ -106,6 +112,7 @@ def list(cls, namespace=None) -> List["HyperPodPytorchJob"]: logger.error(f"Failed to list HyperpodPytorchJobs!") handle_exception(e, "", namespace) + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "delete_pytorchjob") def delete(self): self.verify_kube_config() @@ -128,6 +135,7 @@ def delete(self): handle_exception(e, self.metadata.name, self.metadata.namespace) @classmethod + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "get_pytorchjob") def get(cls, name, namespace=None) -> "HyperPodPytorchJob": cls.verify_kube_config() @@ -175,6 +183,7 @@ def refresh(self) -> "HyperPodPytorchJob": logger.error(f"Failed to refresh HyperPodPytorchJob {self.metadata.name}!") handle_exception(e, self.metadata.name, self.metadata.namespace) + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "list_pods_pytorchjob") def list_pods(self) -> List[str]: self.verify_kube_config() @@ -196,6 +205,7 @@ def list_pods(self) -> List[str]: logger.error(f"Failed to list pod in namespace {self.metadata.namespace}!") handle_exception(e, self.metadata.name, self.metadata.namespace) + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "get_pytorchjob_logs_from_pod") def get_logs_from_pod(self, pod_name: str, container: Optional[str] = None) -> str: self.verify_kube_config() diff --git 
a/test/unit_tests/common/telemetry/__init__.py b/test/unit_tests/common/telemetry/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/unit_tests/common/telemetry/test_telemetry_logging.py b/test/unit_tests/common/telemetry/test_telemetry_logging.py new file mode 100644 index 00000000..12939bdc --- /dev/null +++ b/test/unit_tests/common/telemetry/test_telemetry_logging.py @@ -0,0 +1,319 @@ +import pytest +from unittest.mock import patch, MagicMock, Mock +import subprocess +from typing import Tuple + +# Import your module +from sagemaker.hyperpod.common.telemetry.telemetry_logging import ( + get_region_and_account_from_current_context, + _send_telemetry_request, + _hyperpod_telemetry_emitter, + _requests_helper, + _construct_url, + DEFAULT_AWS_REGION, + FEATURE_TO_CODE, +) +from sagemaker.hyperpod.common.telemetry.constants import Feature, Status +import requests +import logging + +# Test data +MOCK_CONTEXTS = { + "eks_arn": "arn:aws:eks:us-west-2:123456789012:cluster/my-cluster", + "simple": "cluster-123456789012-us-east-1", + "invalid": "invalid-context", + "partial": "cluster-us-west-2-invalid", +} + + +@pytest.fixture +def mock_subprocess(): + with patch("subprocess.run") as mock_run: + yield mock_run + + +@pytest.fixture +def mock_requests(): + with patch("requests.get") as mock_get: + yield mock_get + + +@pytest.mark.parametrize( + "context,expected", + [ + (MOCK_CONTEXTS["eks_arn"], ("us-west-2", "123456789012")), + (MOCK_CONTEXTS["simple"], ("us-east-1", "123456789012")), + (MOCK_CONTEXTS["invalid"], (DEFAULT_AWS_REGION, "unknown")), + (MOCK_CONTEXTS["partial"], ("us-west-2", "unknown")), + ], +) +def test_get_region_and_account_from_current_context( + mock_subprocess, context, expected +): + # Setup mock + mock_subprocess.return_value = MagicMock(returncode=0, stdout=context) + + # Test + result = get_region_and_account_from_current_context() + assert result == expected + + +def test_get_region_and_account_subprocess_failure(mock_subprocess): + # Setup mock to simulate failure + mock_subprocess.return_value = MagicMock(returncode=1) + + # Test + result = get_region_and_account_from_current_context() + assert result == (DEFAULT_AWS_REGION, "unknown") + + +def test_get_region_and_account_exception(mock_subprocess): + # Setup mock to raise exception + mock_subprocess.side_effect = Exception("Command failed") + + # Test + result = get_region_and_account_from_current_context() + assert result == (DEFAULT_AWS_REGION, "unknown") + + +@pytest.fixture +def mock_get_region_account(): + with patch( + "sagemaker.hyperpod.common.telemetry.telemetry_logging.get_region_and_account_from_current_context" + ) as mock: + mock.return_value = ("us-west-2", "123456789012") + yield mock + + +def test_send_telemetry_request(mock_get_region_account, mock_requests): + # Test successful telemetry request + _send_telemetry_request(status=1, feature_list=[1], session=None, extra_info="test") + + # Verify request was made + assert mock_requests.called + + +def test_send_telemetry_request_failure(mock_get_region_account, mock_requests): + # Setup mock to simulate failure + mock_requests.side_effect = Exception("Request failed") + + # Test + _send_telemetry_request(status=1, feature_list=[1], session=None, extra_info="test") + # Should not raise exception + + +# Test the decorator +def test_hyperpod_telemetry_emitter(): + # Create a mock function + @_hyperpod_telemetry_emitter(feature="HYPERPOD", func_name="test_func") + def test_function(): + return "success" + + # Mock the telemetry 
request + with patch( + "sagemaker.hyperpod.common.telemetry.telemetry_logging._send_telemetry_request" + ) as mock_telemetry: + # Test successful execution + result = test_function() + assert result == "success" + assert mock_telemetry.called + + +def test_hyperpod_telemetry_emitter_failure(): + # Create a mock function that raises an exception + @_hyperpod_telemetry_emitter(feature="HYPERPOD", func_name="test_func") + def failing_function(): + raise ValueError("Test error") + + # Mock the telemetry request + with patch( + "sagemaker.hyperpod.common.telemetry.telemetry_logging._send_telemetry_request" + ) as mock_telemetry: + # Test exception handling + with pytest.raises(ValueError): + failing_function() + assert mock_telemetry.called + + +# Test invalid region handling +def test_send_telemetry_request_invalid_region(mock_get_region_account, mock_requests): + # Setup mock to return invalid region + mock_get_region_account.return_value = ("invalid-region", "123456789012") + + # Test + _send_telemetry_request(status=1, feature_list=[1], session=None, extra_info="test") + + # Verify no request was made due to invalid region + assert not mock_requests.called + + +def test_telemetry_decorator_details(): + with patch( + "sagemaker.hyperpod.common.telemetry.telemetry_logging._send_telemetry_request" + ) as mock_telemetry: + + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "test_func") + def sample_function(): + return "success" + + result = sample_function() + + # Verify telemetry call details + mock_telemetry.assert_called_once() + args = mock_telemetry.call_args[0] + + # Check status + assert args[0] == Status.SUCCESS + + # Check feature code + assert args[1] == [FEATURE_TO_CODE[str(Feature.HYPERPOD)]] + + # Check extra info contains required fields + extra_info = args[5] + assert "test_func" in extra_info + assert "x-sdkVersion" in extra_info + assert "x-latency" in extra_info + + +def test_multiple_telemetry_calls(): + with patch( + "sagemaker.hyperpod.common.telemetry.telemetry_logging._send_telemetry_request" + ) as mock_telemetry: + + @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "test_func") + def sample_function(succeed: bool): + if not succeed: + raise ValueError("Failed") + return "success" + + # Success case + sample_function(True) + + # Failure case + with pytest.raises(ValueError): + sample_function(False) + + # Verify both calls + assert mock_telemetry.call_count == 2 + + # Check success call + success_call = mock_telemetry.call_args_list[0] + assert success_call[0][0] == Status.SUCCESS + + # Check failure call + failure_call = mock_telemetry.call_args_list[1] + assert failure_call[0][0] == Status.FAILURE + + +# Test _requests_helper +def test_requests_helper_success(): + """Test successful request""" + with patch("requests.get") as mock_get: + # Setup mock response + mock_response = Mock() + mock_response.status_code = 200 + mock_get.return_value = mock_response + + # Make request + response = _requests_helper("https://test.com", 2) + + # Verify + assert response == mock_response + mock_get.assert_called_once_with("https://test.com", 2) + + +def test_requests_helper_with_invalid_url(caplog): + """Test requests helper with invalid URL""" + with patch("requests.get") as mock_get: + # Set up the mock to raise InvalidURL + mock_get.side_effect = requests.exceptions.InvalidURL("Invalid URL") + + # Capture logs at DEBUG level + with caplog.at_level(logging.DEBUG): + response = _requests_helper("invalid://url", 2) + + # Verify response is None + assert response is None + + # Verify 
log message + assert "Request exception: Invalid URL" in caplog.text + + +def test_construct_url_basic(): + """Test basic URL construction""" + url = _construct_url( + accountId="123456789012", + region="us-west-2", + status="SUCCESS", + feature="TEST", + failure_reason=None, + failure_type=None, + extra_info=None, + ) + + expected = ( + "https://sm-pysdk-t-us-west-2.s3.us-west-2.amazonaws.com/telemetry?" + "x-accountId=123456789012&x-status=SUCCESS&x-feature=TEST" + ) + assert url == expected + + +def test_construct_url_with_failure(): + """Test URL construction with failure information""" + url = _construct_url( + accountId="123456789012", + region="us-west-2", + status="FAILURE", + feature="TEST", + failure_reason="Test failed", + failure_type="TestError", + extra_info=None, + ) + + expected = ( + "https://sm-pysdk-t-us-west-2.s3.us-west-2.amazonaws.com/telemetry?" + "x-accountId=123456789012&x-status=FAILURE&x-feature=TEST" + "&x-failureReason=Test failed&x-failureType=TestError" + ) + assert url == expected + + +def test_construct_url_with_extra_info(): + """Test URL construction with extra information""" + url = _construct_url( + accountId="123456789012", + region="us-west-2", + status="SUCCESS", + feature="TEST", + failure_reason=None, + failure_type=None, + extra_info="additional=info", + ) + + expected = ( + "https://sm-pysdk-t-us-west-2.s3.us-west-2.amazonaws.com/telemetry?" + "x-accountId=123456789012&x-status=SUCCESS&x-feature=TEST" + "&x-extra=additional=info" + ) + assert url == expected + + +def test_construct_url_all_parameters(): + """Test URL construction with all parameters""" + url = _construct_url( + accountId="123456789012", + region="us-west-2", + status="FAILURE", + feature="TEST", + failure_reason="Test failed", + failure_type="TestError", + extra_info="additional=info", + ) + + expected = ( + "https://sm-pysdk-t-us-west-2.s3.us-west-2.amazonaws.com/telemetry?" + "x-accountId=123456789012&x-status=FAILURE&x-feature=TEST" + "&x-failureReason=Test failed&x-failureType=TestError" + "&x-extra=additional=info" + ) + assert url == expected
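
Usage note (editorial, not part of the patch): the _hyperpod_telemetry_emitter decorator introduced above wraps an SDK or CLI entry point, measures its latency, and emits a SUCCESS or FAILURE record without changing the wrapped call's behavior. Below is a minimal sketch of how a caller might apply it, assuming the sagemaker-hyperpod package containing this change is installed; describe_cluster_stub and its argument are hypothetical names used only for illustration and do not appear in the patch.

    from sagemaker.hyperpod.common.telemetry.telemetry_logging import (
        _hyperpod_telemetry_emitter,
    )
    from sagemaker.hyperpod.common.telemetry.constants import Feature


    # Hypothetical entry point. The decorator resolves the numeric code for
    # Feature.HYPERPOD via FEATURE_TO_CODE and appends the function name, SDK
    # version, Python version, OS, and measured latency to the telemetry payload.
    @_hyperpod_telemetry_emitter(Feature.HYPERPOD, "describe_cluster_stub")
    def describe_cluster_stub(cluster_name: str) -> str:
        # Business logic runs unchanged; on return the decorator emits a
        # Status.SUCCESS record, and on an exception it emits Status.FAILURE
        # with the error type and message, then re-raises the original error.
        return f"describing {cluster_name}"


    if __name__ == "__main__":
        print(describe_cluster_stub("my-hyperpod-cluster"))

Because _send_telemetry_request wraps its work in a broad try/except, a failed request or an unsupported region only logs a warning; the decorated function's result and exceptions are never affected by telemetry emission.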