Skip to content

Commit 13a2049

Browse files
authored
Fix bugs in Optuna integration with the prompt2model demo script (#374)
* Fix bugs in Optuna integration
* Lint
1 parent cc4995a commit 13a2049

File tree

2 files changed

+25
-11
lines changed

2 files changed

+25
-11
lines changed

prompt2model/utils/config.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,8 @@
22
MAX_SUPPORTED_BATCH_SIZE = 4
33

44
DEFAULT_HYPERPARAMETERS_SPACE = {
5-
"min_num_train_epochs": 10,
6-
"max_num_train_epochs": 20,
5+
"min_num_train_epochs": 5,
6+
"max_num_train_epochs": 15,
77
"save_strategy": ["no"],
88
"evaluation_strategy": ["no"],
99
"per_device_train_batch_size": MAX_SUPPORTED_BATCH_SIZE,

prompt2model_demo.py

Lines changed: 23 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -91,12 +91,12 @@ def parse_model_size_limit(line: str, default_size=3e9) -> float:
9191
return default_size
9292
model_units = {"B": 1e0, "KB": 1e3, "MB": 1e6, "GB": 1e9, "TB": 1e12, "PB": 1e15}
9393
unit_disambiguations = {
94-
"B": ["b", "bytes"],
9594
"KB": ["Kb", "kb", "kilobytes"],
9695
"MB": ["Mb", "mb", "megabytes"],
9796
"GB": ["Gb", "gb", "gigabytes"],
9897
"TB": ["Tb", "tb", "terabytes"],
9998
"PB": ["Pb", "pb", "petabytes"],
99+
"B": ["b", "bytes"],
100100
}
101101
unit_matched = False
102102
for unit, disambiguations in unit_disambiguations.items():
@@ -110,6 +110,7 @@ def parse_model_size_limit(line: str, default_size=3e9) -> float:
110110
numerical_part = line.strip()[: -len(unit_name)].strip()
111111
else:
112112
numerical_part = line.strip()
113+
113114
if not str.isdecimal(numerical_part):
114115
raise ValueError(
115116
"Invalid input. Please enter a number (integer " + "or number with units)."
@@ -368,28 +369,40 @@ def main():
368369

369370
if line == "y":
370371
line_print("Starting training with hyperparameter selection.")
372+
default_min_num_epochs = DEFAULT_HYPERPARAMETERS_SPACE[
373+
"min_num_train_epochs"
374+
]
375+
min_num_epochs = input(
376+
f"Enter min number of epochs. Press enter to use default value ({default_min_num_epochs}): " # noqa E501
377+
)
378+
default_max_num_epochs = DEFAULT_HYPERPARAMETERS_SPACE[
379+
"max_num_train_epochs"
380+
]
371381
max_num_epochs = input(
372-
"Enter max number of epochs: Press enter to use default value: "
382+
f"Enter max number of epochs. Press enter to use default value ({default_max_num_epochs}): " # noqa E501
373383
)
384+
default_num_trials = 10
374385
num_trials = input(
375-
"Enter the number of trials to conduct hypeparamter search. Press enter to use default value: " # noqa E501
386+
f"Enter the number of trials (maximum number of hyperparameter configurations to consider) for hyperparameter search. Press enter to use default value ({default_num_trials}): " # noqa E501
376387
)
377388
default_batch_size = DEFAULT_HYPERPARAMETERS_SPACE[
378389
"per_device_train_batch_size"
379390
] # noqa E501
380391
max_batch_size = input(
381392
"Enter the max batch size. "
382-
+ f"Press enter to use default: {default_batch_size}"
393+
+ f"Press enter to use default ({default_batch_size}): "
383394
)
384395

396+
min_num_epochs = (
397+
default_min_num_epochs if min_num_epochs == "" else eval(min_num_epochs)
398+
)
385399
max_num_epochs = (
386-
DEFAULT_HYPERPARAMETERS_SPACE["max_num_train_epochs"]
387-
if max_num_epochs == ""
388-
else eval(max_num_epochs)
400+
default_max_num_epochs if max_num_epochs == "" else eval(max_num_epochs)
389401
)
390402
num_trials = 1 if num_trials == "" else eval(num_trials)
403+
391404
max_batch_size = (
392-
DEFAULT_HYPERPARAMETERS_SPACE["per_device_train_batch_size"][0]
405+
DEFAULT_HYPERPARAMETERS_SPACE["per_device_train_batch_size"]
393406
if max_batch_size == ""
394407
else eval(max_batch_size)
395408
)
@@ -406,12 +419,13 @@ def main():
406419
line_print("Starting training.")
407420

408421
trained_model, trained_tokenizer = OptunaParamSelector(
409-
n_trial=num_trials,
422+
n_trials=num_trials,
410423
trainer=trainer,
411424
).select_from_hyperparameters(
412425
training_datasets=training_datasets,
413426
validation=validation_datasets,
414427
hyperparameters={
428+
"min_num_train_epochs": min_num_epochs,
415429
"max_num_train_epochs": max_num_epochs,
416430
"per_device_train_batch_size": [max_batch_size],
417431
},

0 commit comments

Comments
 (0)