Skip to content

Commit 29fe7b1

Browse files
committed
Remove LoRA tab, move it into the Parameters menu
1 parent 214dc68 commit 29fe7b1

File tree

3 files changed

+14
-16
lines changed

3 files changed

+14
-16
lines changed

modules/LoRA.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ def add_lora_to_model(lora_name):
1010

1111
# Is there a more efficient way of returning to the base model?
1212
if lora_name == "None":
13+
print(f"Reloading the model to remove the LoRA...")
1314
shared.model, shared.tokenizer = load_model(shared.model_name)
1415
else:
16+
print(f"Adding the LoRA {lora_name} to the model...")
1517
shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"))

modules/shared.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
},
5757
'lora_prompts': {
5858
'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
59-
'alpaca-lora-7b': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a Python script that generates text using the transformers library.\n### Response:\n"
59+
'alpaca-lora-7b': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
6060
}
6161
}
6262

server.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,15 @@ def load_model_wrapper(selected_model):
6464
return selected_model
6565

6666
def load_lora_wrapper(selected_lora):
67+
shared.lora_name = selected_lora
68+
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
69+
6770
if not shared.args.cpu:
6871
gc.collect()
6972
torch.cuda.empty_cache()
7073
add_lora_to_model(selected_lora)
71-
return selected_lora
74+
75+
return selected_lora, default_text
7276

7377
def load_preset_values(preset_menu, return_dict=False):
7478
generate_params = {
@@ -156,6 +160,10 @@ def create_settings_menus(default_preset):
156160
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
157161
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
158162

163+
with gr.Row():
164+
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
165+
ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
166+
159167
with gr.Accordion('Soft prompt', open=False):
160168
with gr.Row():
161169
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
@@ -167,6 +175,7 @@ def create_settings_menus(default_preset):
167175

168176
shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
169177
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['encoder_repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']])
178+
shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True)
170179
shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
171180
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
172181

@@ -226,8 +235,8 @@ def set_interface_arguments(interface_mode, extensions, cmd_active):
226235
shared.model_name = available_models[i]
227236
shared.model, shared.tokenizer = load_model(shared.model_name)
228237
if shared.args.lora:
238+
print(shared.args.lora)
229239
shared.lora_name = shared.args.lora
230-
print(f"Adding the LoRA {shared.lora_name} to the model...")
231240
add_lora_to_model(shared.lora_name)
232241

233242
# Default UI settings
@@ -419,19 +428,6 @@ def create_interface():
419428
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
420429
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
421430

422-
with gr.Tab("LoRA", elem_id="lora"):
423-
with gr.Row():
424-
with gr.Column():
425-
gr.Markdown("Load")
426-
with gr.Row():
427-
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
428-
ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
429-
with gr.Column():
430-
gr.Markdown("Train (TODO)")
431-
gr.Button("Practice your button clicking skills")
432-
433-
shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True)
434-
435431
with gr.Tab("Interface mode", elem_id="interface-mode"):
436432
modes = ["default", "notebook", "chat", "cai_chat"]
437433
current_mode = "default"

0 commit comments

Comments
 (0)