diff --git a/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_0/model.py b/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_0/model.py index 2e346a91..08e9cfc8 100644 --- a/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_0/model.py +++ b/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_0/model.py @@ -10,7 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator, ConfigDict from typing import Optional, List, Dict, Union, Literal from sagemaker.hyperpod.inference.config.hp_endpoint_config import ( @@ -31,9 +31,19 @@ from sagemaker.hyperpod.inference.hp_endpoint import HPEndpoint class FlatHPEndpoint(BaseModel): + model_config = ConfigDict(extra="forbid") + + metadata_name: Optional[str] = Field( + None, + alias="metadata_name", + description="Name of the jumpstart endpoint object", + max_length=63, + pattern=r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + ) + # endpoint_name endpoint_name: Optional[str] = Field( - "", + None, alias="endpoint_name", description="Name of SageMaker endpoint; empty string means no creation", max_length=63, @@ -130,7 +140,7 @@ class FlatHPEndpoint(BaseModel): description="FSX File System DNS Name", ) fsx_file_system_id: Optional[str] = Field( - ..., + None, alias="fsx_file_system_id", description="FSX File System ID", ) @@ -142,12 +152,12 @@ class FlatHPEndpoint(BaseModel): # S3Storage s3_bucket_name: Optional[str] = Field( - ..., + None, alias="s3_bucket_name", description="S3 bucket location", ) s3_region: Optional[str] = Field( - ..., + None, alias="s3_region", description="S3 bucket region", ) @@ -229,12 +239,22 @@ class FlatHPEndpoint(BaseModel): invocation_endpoint: Optional[str] = Field( default="invocations", description=( - "The invocation endpoint of the model server. " - "http://:/ would be pre-populated based on the other fields. " + "The invocation endpoint of the model server. http://:/ would be pre-populated based on the other fields. " "Please fill in the path after http://:/ specific to your model server.", ) ) - + + @model_validator(mode='after') + def validate_model_source_config(self): + """Validate that required fields are provided based on model_source_type""" + if self.model_source_type == "s3": + if not self.s3_bucket_name or not self.s3_region: + raise ValueError("s3_bucket_name and s3_region are required when model_source_type is 's3'") + elif self.model_source_type == "fsx": + if not self.fsx_file_system_id: + raise ValueError("fsx_file_system_id is required when model_source_type is 'fsx'") + return self + def to_domain(self) -> HPEndpoint: env_vars = None if self.env: diff --git a/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_0/schema.json b/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_0/schema.json index 389df921..8474449b 100644 --- a/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_0/schema.json +++ b/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_0/schema.json @@ -1,184 +1,457 @@ { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "title": "FlatHPEndpoint", - "type": "object", "additionalProperties": false, - "required": [ - "instance_type", - "model_name", - "model_source_type", - "image_uri", - "container_port", - "model_volume_mount_name" - ], "properties": { + "metadata_name": { + "anyOf": [ + { + "maxLength": 63, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Name of the jumpstart endpoint object", + "title": "Metadata Name" + }, "endpoint_name": { - "type": ["string", "null"], - "description": "Name used for SageMaker endpoint; empty string means no creation", - "default": "", - "maxLength": 63, - "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$" + "anyOf": [ + { + "maxLength": 63, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Name of SageMaker endpoint; empty string means no creation", + "title": "Endpoint Name" }, "env": { - "type": ["object", "null"], + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, "description": "Map of environment variable names to their values", - "additionalProperties": { "type": "string" } + "title": "Env" }, "instance_type": { - "type": "string", "description": "EC2 instance type for the inference server", - "pattern": "^ml\\..*" + "pattern": "^ml\\..*", + "title": "Instance Type", + "type": "string" }, "metrics_enabled": { - "type": "boolean", + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": false, "description": "Enable metrics collection", - "default": false + "title": "Metrics Enabled" }, "model_name": { - "type": "string", "description": "Name of model to create on SageMaker", - "minLength": 1, "maxLength": 63, - "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$" + "minLength": 1, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "title": "Model Name", + "type": "string" }, "model_version": { - "type": ["string", "null"], + "anyOf": [ + { + "maxLength": 14, + "minLength": 5, + "pattern": "^\\d{1,4}\\.\\d{1,4}\\.\\d{1,4}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, "description": "Version of the model for the endpoint", - "minLength": 5, - "maxLength": 14, - "pattern": "^\\d{1,4}\\.\\d{1,4}\\.\\d{1,4}$" + "title": "Model Version" }, "model_source_type": { - "type": "string", "description": "Source type: fsx or s3", - "enum": ["fsx", "s3"] + "enum": [ + "fsx", + "s3" + ], + "title": "Model Source Type", + "type": "string" }, "model_location": { - "type": ["string", "null"], - "description": "Specific model data location" + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Specific model data location", + "title": "Model Location" }, "prefetch_enabled": { - "type": "boolean", + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": false, "description": "Whether to pre-fetch model data", - "default": false + "title": "Prefetch Enabled" }, "tls_certificate_output_s3_uri": { - "type": ["string", "null"], + "anyOf": [ + { + "pattern": "^s3://([^/]+)/?(.*)$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, "description": "S3 URI for TLS certificate output", - "pattern": "^s3://([^/]+)/?(.*)$" - }, - "fsx_dns_name": { - "type": ["string", "null"], - "description": "FSX File System DNS Name" - }, - "fsx_file_system_id": { - "type": ["string", "null"], - "description": "FSX File System ID" - }, - "fsx_mount_name": { - "type": ["string", "null"], - "description": "FSX File System Mount Name" - }, - "s3_bucket_name": { - "type": ["string", "null"], - "description": "S3 bucket location" - }, - "s3_region": { - "type": ["string", "null"], - "description": "S3 bucket region" + "title": "Tls Certificate Output S3 Uri" }, "image_uri": { - "type": "string", - "description": "Inference server image name" + "description": "Inference server image name", + "title": "Image Uri", + "type": "string" }, "container_port": { - "type": "integer", - "format": "int32", "description": "Port on which the model server listens", + "maximum": 65535, "minimum": 1, - "maximum": 65535 + "title": "Container Port", + "type": "integer" }, "model_volume_mount_path": { - "type": "string", + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": "/opt/ml/model", "description": "Path inside container for model volume", - "default": "/opt/ml/model" + "title": "Model Volume Mount Path" }, "model_volume_mount_name": { - "type": "string", - "description": "Name of the model volume mount" + "description": "Name of the model volume mount", + "title": "Model Volume Mount Name", + "type": "string" + }, + "fsx_dns_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "FSX File System DNS Name", + "title": "Fsx Dns Name" + }, + "fsx_file_system_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "FSX File System ID", + "title": "Fsx File System Id" + }, + "fsx_mount_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "FSX File System Mount Name", + "title": "Fsx Mount Name" + }, + "s3_bucket_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "S3 bucket location", + "title": "S3 Bucket Name" + }, + "s3_region": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "S3 bucket region", + "title": "S3 Region" }, "resources_limits": { - "type": ["object", "null"], + "anyOf": [ + { + "additionalProperties": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "string" + } + ] + }, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, "description": "Resource limits for the worker", - "additionalProperties": { - "type": ["integer", "string"] - } + "title": "Resources Limits" }, "resources_requests": { - "type": ["object", "null"], + "anyOf": [ + { + "additionalProperties": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "string" + } + ] + }, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, "description": "Resource requests for the worker", - "additionalProperties": { - "type": ["integer", "string"] - } + "title": "Resources Requests" }, "dimensions": { - "type": ["object", "null"], - "description": "CloudWatch Metric dimensions as key–value pairs", - "additionalProperties": { - "type": "string" - } + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "CloudWatch Metric dimensions as key\u2013value pairs", + "title": "Dimensions" }, "metric_collection_period": { - "type": "integer", + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": 300, "description": "Defines the Period for CloudWatch query", - "default": 300 + "title": "Metric Collection Period" }, "metric_collection_start_time": { - "type": "integer", + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": 300, "description": "Defines the StartTime for CloudWatch query", - "default": 300 + "title": "Metric Collection Start Time" }, "metric_name": { - "type": ["string", "null"], - "description": "Metric name to query for CloudWatch trigger" + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Metric name to query for CloudWatch trigger", + "title": "Metric Name" }, "metric_stat": { - "type": "string", + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": "Average", "description": "Statistics metric to be used by Trigger. Defines the Stat for the CloudWatch query. Default is Average.", - "default": "Average" + "title": "Metric Stat" }, "metric_type": { - "type": "string", - "description": "The type of metric to be used by HPA. `Average` – Uses average value per pod; `Value` – Uses absolute metric value.", - "enum": ["Value", "Average"], - "default": "Average" + "anyOf": [ + { + "enum": [ + "Value", + "Average" + ], + "type": "string" + }, + { + "type": "null" + } + ], + "default": "Average", + "description": "The type of metric to be used by HPA. `Average` \u2013 Uses average value per pod; `Value` \u2013 Uses absolute metric value.", + "title": "Metric Type" }, "min_value": { - "type": "number", + "anyOf": [ + { + "type": "number" + }, + { + "type": "null" + } + ], + "default": 0, "description": "Minimum metric value used in case of empty response from CloudWatch. Default is 0.", - "default": 0 + "title": "Min Value" }, "cloud_watch_trigger_name": { - "type": ["string", "null"], - "description": "Name for the CloudWatch trigger" + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Name for the CloudWatch trigger", + "title": "Cloud Watch Trigger Name" }, "cloud_watch_trigger_namespace": { - "type": ["string", "null"], - "description": "AWS CloudWatch namespace for the metric" + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "AWS CloudWatch namespace for the metric", + "title": "Cloud Watch Trigger Namespace" }, "target_value": { - "type": ["number", "null"], - "description": "Target value for the CloudWatch metric" + "anyOf": [ + { + "type": "number" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Target value for the CloudWatch metric", + "title": "Target Value" }, "use_cached_metrics": { - "type": "boolean", + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": true, "description": "Enable caching of metric values during polling interval. Default is true.", - "default": true + "title": "Use Cached Metrics" }, "invocation_endpoint": { - "type": "string", + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": "invocations", "description": "The invocation endpoint of the model server. http://:/ would be pre-populated based on the other fields. Please fill in the path after http://:/ specific to your model server.", - "default": "invocations" + "title": "Invocation Endpoint" } - } -} + }, + "required": [ + "instance_type", + "model_name", + "model_source_type", + "image_uri", + "container_port", + "model_volume_mount_name" + ], + "title": "FlatHPEndpoint", + "type": "object" +} \ No newline at end of file diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_0/model.py b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_0/model.py index 44ad2d63..2dd257ed 100644 --- a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_0/model.py +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_0/model.py @@ -10,7 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -from pydantic import BaseModel, Field, constr +from pydantic import BaseModel, Field, model_validator, ConfigDict from typing import Optional # reuse the nested types @@ -23,10 +23,20 @@ from sagemaker.hyperpod.inference.hp_jumpstart_endpoint import HPJumpStartEndpoint class FlatHPJumpStartEndpoint(BaseModel): + model_config = ConfigDict(extra="forbid") + accept_eula: bool = Field( False, alias="accept_eula", description="Whether model terms of use have been accepted" ) + metadata_name: Optional[str] = Field( + None, + alias="metadata_name", + description="Name of the jumpstart endpoint object", + max_length=63, + pattern=r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + ) + model_id: str = Field( ..., alias="model_id", @@ -53,7 +63,7 @@ class FlatHPJumpStartEndpoint(BaseModel): ) endpoint_name: Optional[str] = Field( - "", + None, alias="endpoint_name", description="Name of SageMaker endpoint; empty string means no creation", max_length=63, diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_0/schema.json b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_0/schema.json index efe6f340..307ffdd2 100644 --- a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_0/schema.json +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_0/schema.json @@ -1,49 +1,91 @@ { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "title": "FlatHPJumpStartEndpointV1", - "type": "object", "additionalProperties": false, - "required": [ - "model_id", - "instance_type" - ], "properties": { "accept_eula": { - "type": "boolean", + "default": false, "description": "Whether model terms of use have been accepted", - "default": false + "title": "Accept Eula", + "type": "boolean" + }, + "metadata_name": { + "anyOf": [ + { + "maxLength": 63, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Name of the jumpstart endpoint object", + "title": "Metadata Name" }, "model_id": { - "type": "string", "description": "Unique identifier of the model within the hub", - "minLength": 1, "maxLength": 63, - "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$" + "minLength": 1, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "title": "Model Id", + "type": "string" }, "model_version": { - "type": ["string", "null"], + "anyOf": [ + { + "maxLength": 14, + "minLength": 5, + "pattern": "^\\d{1,4}\\.\\d{1,4}\\.\\d{1,4}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, "description": "Semantic version of the model to deploy (e.g. 1.0.0)", - "minLength": 5, - "maxLength": 14, - "pattern": "^\\d{1,4}\\.\\d{1,4}\\.\\d{1,4}$", - "default": null + "title": "Model Version" }, "instance_type": { - "type": "string", "description": "EC2 instance type for the inference server", - "pattern": "^ml\\..*" + "pattern": "^ml\\..*", + "title": "Instance Type", + "type": "string" }, "endpoint_name": { - "type": "string", + "anyOf": [ + { + "maxLength": 63, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, "description": "Name of SageMaker endpoint; empty string means no creation", - "default": "", - "maxLength": 63, - "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$" + "title": "Endpoint Name" }, "tls_certificate_output_s3_uri": { - "type": ["string", "null"], + "anyOf": [ + { + "pattern": "^s3://([^/]+)/?(.*)$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, "description": "S3 URI to write the TLS certificate (optional)", - "pattern": "^s3://([^/]+)/?(.*)$" + "title": "Tls Certificate Output S3 Uri" } - } -} + }, + "required": [ + "model_id", + "instance_type" + ], + "title": "FlatHPJumpStartEndpoint", + "type": "object" +} \ No newline at end of file diff --git a/src/sagemaker/hyperpod/cli/commands/inference.py b/src/sagemaker/hyperpod/cli/commands/inference.py index 7314432e..71e8cdd1 100644 --- a/src/sagemaker/hyperpod/cli/commands/inference.py +++ b/src/sagemaker/hyperpod/cli/commands/inference.py @@ -31,12 +31,12 @@ registry=JS_REG, ) @_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "create_js_endpoint_cli") -def js_create(namespace, version, js_endpoint): +def js_create(name, namespace, version, js_endpoint): """ Create a jumpstart model endpoint. """ - js_endpoint.create(namespace=namespace) + js_endpoint.create(name=name, namespace=namespace) @click.command("hyp-custom-endpoint") @@ -53,12 +53,12 @@ def js_create(namespace, version, js_endpoint): registry=C_REG, ) @_hyperpod_telemetry_emitter(Feature.HYPERPOD_CLI, "create_custom_endpoint_cli") -def custom_create(namespace, version, custom_endpoint): +def custom_create(name, namespace, version, custom_endpoint): """ Create a custom model endpoint. """ - custom_endpoint.create(namespace=namespace) + custom_endpoint.create(name=name, namespace=namespace) # INVOKE diff --git a/src/sagemaker/hyperpod/cli/inference_utils.py b/src/sagemaker/hyperpod/cli/inference_utils.py index e402eb71..db44c77a 100644 --- a/src/sagemaker/hyperpod/cli/inference_utils.py +++ b/src/sagemaker/hyperpod/cli/inference_utils.py @@ -30,6 +30,7 @@ def _parse_json_flag(ctx, param, value): # 1) the wrapper click actually invokes def wrapped_func(*args, **kwargs): namespace = kwargs.pop("namespace", None) + name = kwargs.pop("metadata_name", None) pop_version = kwargs.pop("version", "1.0") Model = registry.get(version) @@ -38,7 +39,7 @@ def wrapped_func(*args, **kwargs): flat = Model(**kwargs) domain = flat.to_domain() - return func(namespace, version, domain) + return func(name, namespace, version, domain) # 2) inject JSON flags only if they exist in the schema schema = load_schema_for_version(version, schema_pkg) diff --git a/src/sagemaker/hyperpod/inference/hp_endpoint.py b/src/sagemaker/hyperpod/inference/hp_endpoint.py index 8a7907a1..f4bc2b22 100644 --- a/src/sagemaker/hyperpod/inference/hp_endpoint.py +++ b/src/sagemaker/hyperpod/inference/hp_endpoint.py @@ -38,7 +38,7 @@ def create( spec = _HPEndpoint(**self.model_dump(by_alias=True, exclude_none=True)) if not spec.endpointName and not name: - raise Exception('Input "name" is required if endpoint name is not provided') + raise Exception('Either metadata name or endpoint name must be provided') if not namespace: namespace = get_default_namespace() diff --git a/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py b/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py index 6110f20c..c3a45711 100644 --- a/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py +++ b/src/sagemaker/hyperpod/inference/hp_jumpstart_endpoint.py @@ -43,7 +43,7 @@ def create( endpoint_name = spec.sageMakerEndpoint.name if not endpoint_name and not name: - raise Exception('Input "name" is required if endpoint name is not provided') + raise Exception('Either metadata name or endpoint name must be provided') if not name: name = endpoint_name diff --git a/test/integration_tests/inference/cli/test_cli_custom_fsx_inference.py b/test/integration_tests/inference/cli/test_cli_custom_fsx_inference.py index 7caba854..1dc20f4e 100644 --- a/test/integration_tests/inference/cli/test_cli_custom_fsx_inference.py +++ b/test/integration_tests/inference/cli/test_cli_custom_fsx_inference.py @@ -51,7 +51,6 @@ def test_custom_create(runner, custom_endpoint_name): "--model-source-type", "fsx", "--model-location", "hf-eqa", "--fsx-file-system-id", FSX_LOCATION, - "--s3-region", REGION, "--image-uri", "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:2.3.0-transformers4.48.0-cpu-py311-ubuntu22.04", "--container-port", "8080", "--model-volume-mount-name", "model-weights", diff --git a/test/unit_tests/cli/test_inference.py b/test/unit_tests/cli/test_inference.py index 3a884c54..0957cc19 100644 --- a/test/unit_tests/cli/test_inference.py +++ b/test/unit_tests/cli/test_inference.py @@ -63,7 +63,7 @@ def test_js_create_with_required_args(): ]) assert result.exit_code == 0, result.output - domain_obj.create.assert_called_once_with(namespace='test-ns') + domain_obj.create.assert_called_once_with(name=None, namespace='test-ns') def test_js_create_missing_required_args(): @@ -180,7 +180,7 @@ def test_custom_create_with_required_args(): ]) assert result.exit_code == 0, result.output - domain_obj.create.assert_called_once_with(namespace='test-ns') + domain_obj.create.assert_called_once_with(name=None, namespace='test-ns') def test_custom_create_missing_required_args(): diff --git a/test/unit_tests/cli/test_inference_utils.py b/test/unit_tests/cli/test_inference_utils.py index 657bf14f..1e6d3ad8 100644 --- a/test/unit_tests/cli/test_inference_utils.py +++ b/test/unit_tests/cli/test_inference_utils.py @@ -76,7 +76,7 @@ def to_domain(self): return self @click.command() @generate_click_command(registry=registry) - def cmd(namespace, version, domain): + def cmd(name, namespace, version, domain): click.echo(json.dumps({ 'env': domain.env, 'dimensions': domain.dimensions, 'limits': domain.resources_limits, 'reqs': domain.resources_requests @@ -118,7 +118,7 @@ def to_domain(self): return self @click.command() @generate_click_command(registry=registry) - def cmd(namespace, version, domain): + def cmd(name, namespace, version, domain): click.echo(f"{domain.s},{domain.i},{domain.n},{domain.b},{domain.e},{domain.d}") res = self.runner.invoke(cmd, [ @@ -148,7 +148,7 @@ def to_domain(self): # Create test command @click.command() @generate_click_command(schema_pkg='mypkg', registry=registry) - def cmd(namespace, version, domain): + def cmd(name, namespace, version, domain): click.echo(f"version: {version}") # Test command execution