Skip to content

Commit 3b03301

Browse files
mollyheamazonnargokul
authored andcommitted
add inference template submit backend logic, fix namespace default across template (#160)
* add inference template submit backend logic, fix namespace default across template * add namespace to jumpstart and custom endpoint template to simplify logic, no special handling for namespace for any templates, add unit tests for init experience
1 parent 84f3526 commit 3b03301

File tree

5 files changed

+1879
-238
lines changed

5 files changed

+1879
-238
lines changed

src/sagemaker/hyperpod/cli/commands/init.py

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
from sagemaker.hyperpod.cli.constants.init_constants import (
1111
USAGE_GUIDE_TEXT,
12-
CFN
12+
CFN,
13+
CRD
1314
)
1415
from sagemaker.hyperpod.cluster_management.hp_cluster_stack import HpClusterStack
1516
import json
@@ -22,21 +23,18 @@
2223
validate_config_against_model,
2324
filter_validation_errors_for_user_input,
2425
display_validation_results,
25-
extract_user_provided_args_from_cli,
2626
build_config_from_schema,
2727
save_template
2828
)
2929

3030
@click.command("init")
3131
@click.argument("template", type=click.Choice(list(TEMPLATES.keys())))
3232
@click.argument("directory", type=click.Path(file_okay=False), default=".")
33-
@click.option("--namespace", "-n", default="default", help="Namespace, default to default")
3433
@click.option("--version", "-v", default="1.0", help="Schema version")
3534
@generate_click_command(require_schema_fields=False)
3635
def init(
3736
template: str,
3837
directory: str,
39-
namespace: str,
4038
version: str,
4139
model_config, # Pydantic model from decorator
4240
):
@@ -102,7 +100,7 @@ def init(
102100
# 3) Build config dict + comment map, then write config.yaml
103101
try:
104102
# Use the common function to build config from schema
105-
full_cfg, comment_map = build_config_from_schema(template, namespace, version, model_config)
103+
full_cfg, comment_map = build_config_from_schema(template, version, model_config)
106104

107105
save_config_yaml(
108106
prefill=full_cfg,
@@ -144,16 +142,15 @@ def init(
144142
def reset():
145143
"""
146144
Reset the current directory's config.yaml to an "empty" scaffold:
147-
all schema keys set to default values (but keeping the template and namespace).
145+
all schema keys set to default values (but keeping the template and version).
148146
"""
149147
dir_path = Path(".").resolve()
150148

151149
# 1) Load and validate config
152150
data, template, version = load_config(dir_path)
153-
namespace = data.get("namespace", "default")
154151

155152
# 2) Build config with default values from schema
156-
full_cfg, comment_map = build_config_from_schema(template, namespace, version)
153+
full_cfg, comment_map = build_config_from_schema(template, version)
157154

158155
# 3) Overwrite config.yaml
159156
try:
@@ -175,67 +172,65 @@ def reset():
175172
@click.command("configure")
176173
@generate_click_command(
177174
require_schema_fields=False, # flags are all optional
178-
auto_load_config=True, # load template/namespace/version from config.yaml
175+
auto_load_config=True, # load template/version from config.yaml
179176
)
180177
@click.pass_context
181178
def configure(ctx, model_config):
182179
"""
183-
Update any subset of fields in ./config.yaml by passing
184-
--<field> flags. E.g.
185-
186-
hyp configure --model-name my-model --instance-type ml.m5.large
187-
hyp configure --namespace production
188-
hyp configure --namespace test --stage gamma
189-
"""
190-
# Extract namespace from command line arguments manually
191-
import sys
192-
namespace = None
193-
args = sys.argv
194-
for i, arg in enumerate(args):
195-
if arg in ['--namespace', '-n'] and i + 1 < len(args):
196-
namespace = args[i + 1]
197-
break
180+
Update any subset of fields in ./config.yaml by passing --<field> flags.
181+
182+
This command allows you to modify specific configuration fields without having
183+
to regenerate the entire config or fix unrelated validation issues. Only the
184+
fields you explicitly provide will be validated, making it easy to update
185+
configurations incrementally.
186+
187+
Examples:
188+
189+
# Update a single field
190+
hyp configure --hyperpod-cluster-name my-new-cluster
191+
192+
# Update multiple fields at once
193+
hyp configure --instance-type ml.g5.xlarge --endpoint-name my-endpoint
194+
195+
# Update complex fields with JSON
196+
hyp configure --tags '{"Environment": "prod", "Team": "ml"}'
198197
198+
"""
199199
# 1) Load existing config without validation
200200
dir_path = Path(".").resolve()
201201
data, template, version = load_config(dir_path)
202202

203-
# Use provided namespace or fall back to existing config namespace
204-
config_namespace = namespace if namespace is not None else data.get("namespace", "default")
205-
206-
# 2) Extract ONLY the user's input arguments by checking what was actually provided
207-
provided_args = extract_user_provided_args_from_cli()
203+
# 2) Determine which fields the user actually provided
204+
# Use Click's parameter source tracking to identify command-line provided parameters
205+
user_input_fields = set()
208206

209-
# Filter model_config to only include user-provided fields
210-
all_model_data = model_config.model_dump(exclude_none=True) if model_config else {}
211-
user_input = {k: v for k, v in all_model_data.items() if k in provided_args}
207+
if ctx and hasattr(ctx, 'params') and model_config:
208+
# Check which parameters were provided via command line (not defaults)
209+
for param_name, param_value in ctx.params.items():
210+
# Skip if the parameter source indicates it came from default
211+
param_source = ctx.get_parameter_source(param_name)
212+
if param_source and param_source.name == 'COMMANDLINE':
213+
user_input_fields.add(param_name)
212214

213-
if not user_input and namespace is None:
215+
if not user_input_fields:
214216
click.secho("⚠️ No arguments provided to configure.", fg="yellow")
215217
return
216218

217219
# 3) Build merged config with user input
218220
full_cfg, comment_map = build_config_from_schema(
219221
template=template,
220-
namespace=config_namespace,
221222
version=version,
222223
model_config=model_config,
223224
existing_config=data
224225
)
225226

226-
# 4) Validate the merged config and filter errors for user input fields only
227+
# 4) Validate the merged config, but only check user-provided fields
227228
all_validation_errors = validate_config_against_model(full_cfg, template, version)
228-
229-
# Include namespace in user input fields if it was provided
230-
user_input_fields = set(user_input.keys())
231-
if namespace is not None:
232-
user_input_fields.add("namespace")
233-
234229
user_input_errors = filter_validation_errors_for_user_input(all_validation_errors, user_input_fields)
235230

236231
is_valid = display_validation_results(
237232
user_input_errors,
238-
success_message="User input is valid!" if user_input_errors else "Merged configuration is valid!",
233+
success_message="User input is valid!" if user_input_errors else "Configuration updated successfully!",
239234
error_prefix="Invalid input arguments:"
240235
)
241236

@@ -376,6 +371,22 @@ def submit(region):
376371
from sagemaker.hyperpod.cli.commands.cluster_stack import create_cluster_stack_helper
377372
create_cluster_stack_helper(config_file=f"{out_dir}/config.yaml",
378373
region=region)
374+
else:
375+
dir_path = Path(".").resolve()
376+
data, template, version = load_config(dir_path)
377+
namespace = data.get("namespace", "default")
378+
registry = TEMPLATES[template]["registry"]
379+
model = registry.get(version)
380+
if model:
381+
filtered_config = {
382+
k: v for k, v in data.items()
383+
if k not in ('template', 'version') and v is not None
384+
}
385+
flat = model(**filtered_config)
386+
domain = flat.to_domain()
387+
domain.create(namespace=namespace)
388+
389+
379390
except Exception as e:
380391
click.secho(f"❌ Failed to sumbit the command: {e}", fg="red")
381392
sys.exit(1)

0 commit comments

Comments
 (0)