gsarti commited on
Commit
2a644e6
Β·
1 Parent(s): 2b66ced

Load model after preset

Browse files
Files changed (1) hide show
  1. app.py +28 -12
app.py CHANGED
@@ -24,7 +24,7 @@ from presets import (
24
  from style import custom_css
25
  from utils import get_formatted_attribute_context_results
26
 
27
- from inseq import list_feature_attribution_methods, list_step_functions, load_model
28
  from inseq.commands.attribute_context.attribute_context import (
29
  AttributeContextArgs,
30
  attribute_context_with_model,
@@ -65,7 +65,7 @@ def pecore(
65
  )
66
  if loaded_model is None or model_name_or_path != loaded_model.model_name:
67
  gr.Info("Loading model...")
68
- loaded_model = load_model(
69
  model_name_or_path,
70
  attribution_method,
71
  model_kwargs=json.loads(model_kwargs),
@@ -130,7 +130,7 @@ def preload_model(
130
  global loaded_model
131
  if loaded_model is None or model_name_or_path != loaded_model.model_name:
132
  gr.Info("Loading model...")
133
- loaded_model = load_model(
134
  model_name_or_path,
135
  attribution_method,
136
  model_kwargs=json.loads(model_kwargs),
@@ -192,7 +192,9 @@ with gr.Blocks(css=custom_css) as demo:
192
  outputs=pecore_output_highlights,
193
  )
194
  with gr.Tab("βš™οΈ Parameters") as params_tab:
195
- gr.Markdown("## ✨ Presets")
 
 
196
  with gr.Row(equal_height=True):
197
  with gr.Column():
198
  default_preset = gr.Button("Default", variant="secondary")
@@ -218,7 +220,7 @@ with gr.Blocks(css=custom_css) as demo:
218
  "Present for multilingual MT models such as <a href='https://huggingface.co/facebook/nllb-200-distilled-600M' target='_blank'>NLLB</a> and <a href='https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt' target='_blank'>mBART</a> using language tags."
219
  )
220
  with gr.Column(scale=1):
221
- chatml_template = gr.Button("ChatML Template", variant="secondary")
222
  gr.Markdown(
223
  "Preset for models using the <a href='https://github.com/MicrosoftDocs/azure-docs/blob/main/articles/ai-services/openai/includes/chat-markup-language.md' target='_blank'>ChatML conversational template</a>.\nUses <code><|im_start|></code>, <code><|im_end|></code> special tokens."
224
  )
@@ -401,6 +403,15 @@ with gr.Blocks(css=custom_css) as demo:
401
  gr.Markdown(how_to_use)
402
  gr.Markdown(citation)
403
 
 
 
 
 
 
 
 
 
 
404
  attribute_input_button.click(
405
  pecore,
406
  inputs=[
@@ -435,7 +446,7 @@ with gr.Blocks(css=custom_css) as demo:
435
 
436
  load_model_button.click(
437
  preload_model,
438
- inputs=[model_name_or_path, attribution_method, model_kwargs, tokenizer_kwargs],
439
  outputs=[],
440
  )
441
 
@@ -461,11 +472,13 @@ with gr.Blocks(css=custom_css) as demo:
461
 
462
  # Presets
463
 
464
- default_preset.click(**reset_kwargs)
 
465
  cora_preset.click(**reset_kwargs).then(
466
  set_cora_preset,
467
  outputs=[model_name_or_path, input_template, contextless_input_current_text],
468
- )
 
469
  zephyr_preset.click(**reset_kwargs).then(
470
  set_zephyr_preset,
471
  outputs=[
@@ -474,11 +487,13 @@ with gr.Blocks(css=custom_css) as demo:
474
  contextless_input_current_text,
475
  decoder_input_output_separator,
476
  ],
477
- )
 
478
  multilingual_mt_template.click(**reset_kwargs).then(
479
  set_mmt_preset,
480
  outputs=[model_name_or_path, input_template, output_template, tokenizer_kwargs],
481
- )
 
482
  chatml_template.click(**reset_kwargs).then(
483
  set_chatml_preset,
484
  outputs=[
@@ -488,7 +503,8 @@ with gr.Blocks(css=custom_css) as demo:
488
  decoder_input_output_separator,
489
  special_tokens_to_keep,
490
  ],
491
- )
 
492
  towerinstruct_template.click(**reset_kwargs).then(
493
  set_towerinstruct_preset,
494
  outputs=[
@@ -497,6 +513,6 @@ with gr.Blocks(css=custom_css) as demo:
497
  contextless_input_current_text,
498
  decoder_input_output_separator,
499
  ],
500
- )
501
 
502
  demo.launch(allowed_paths=["outputs/"])
 
24
  from style import custom_css
25
  from utils import get_formatted_attribute_context_results
26
 
27
+ from inseq import list_feature_attribution_methods, list_step_functions
28
  from inseq.commands.attribute_context.attribute_context import (
29
  AttributeContextArgs,
30
  attribute_context_with_model,
 
65
  )
66
  if loaded_model is None or model_name_or_path != loaded_model.model_name:
67
  gr.Info("Loading model...")
68
+ loaded_model = HuggingfaceModel.load(
69
  model_name_or_path,
70
  attribution_method,
71
  model_kwargs=json.loads(model_kwargs),
 
130
  global loaded_model
131
  if loaded_model is None or model_name_or_path != loaded_model.model_name:
132
  gr.Info("Loading model...")
133
+ loaded_model = HuggingfaceModel.load(
134
  model_name_or_path,
135
  attribution_method,
136
  model_kwargs=json.loads(model_kwargs),
 
192
  outputs=pecore_output_highlights,
193
  )
194
  with gr.Tab("βš™οΈ Parameters") as params_tab:
195
+ gr.Markdown(
196
+ "## ✨ Presets\nSelect a preset to load default parameters into the fields below."
197
+ )
198
  with gr.Row(equal_height=True):
199
  with gr.Column():
200
  default_preset = gr.Button("Default", variant="secondary")
 
220
  "Present for multilingual MT models such as <a href='https://huggingface.co/facebook/nllb-200-distilled-600M' target='_blank'>NLLB</a> and <a href='https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt' target='_blank'>mBART</a> using language tags."
221
  )
222
  with gr.Column(scale=1):
223
+ chatml_template = gr.Button("Qwen ChatML", variant="secondary")
224
  gr.Markdown(
225
  "Preset for models using the <a href='https://github.com/MicrosoftDocs/azure-docs/blob/main/articles/ai-services/openai/includes/chat-markup-language.md' target='_blank'>ChatML conversational template</a>.\nUses <code><|im_start|></code>, <code><|im_end|></code> special tokens."
226
  )
 
403
  gr.Markdown(how_to_use)
404
  gr.Markdown(citation)
405
 
406
+ # Main logic
407
+
408
+ load_model_args = [
409
+ model_name_or_path,
410
+ attribution_method,
411
+ model_kwargs,
412
+ tokenizer_kwargs,
413
+ ]
414
+
415
  attribute_input_button.click(
416
  pecore,
417
  inputs=[
 
446
 
447
  load_model_button.click(
448
  preload_model,
449
+ inputs=load_model_args,
450
  outputs=[],
451
  )
452
 
 
472
 
473
  # Presets
474
 
475
+ default_preset.click(**reset_kwargs).success(preload_model, inputs=load_model_args)
476
+
477
  cora_preset.click(**reset_kwargs).then(
478
  set_cora_preset,
479
  outputs=[model_name_or_path, input_template, contextless_input_current_text],
480
+ ).success(preload_model, inputs=load_model_args)
481
+
482
  zephyr_preset.click(**reset_kwargs).then(
483
  set_zephyr_preset,
484
  outputs=[
 
487
  contextless_input_current_text,
488
  decoder_input_output_separator,
489
  ],
490
+ ).success(preload_model, inputs=load_model_args)
491
+
492
  multilingual_mt_template.click(**reset_kwargs).then(
493
  set_mmt_preset,
494
  outputs=[model_name_or_path, input_template, output_template, tokenizer_kwargs],
495
+ ).success(preload_model, inputs=load_model_args)
496
+
497
  chatml_template.click(**reset_kwargs).then(
498
  set_chatml_preset,
499
  outputs=[
 
503
  decoder_input_output_separator,
504
  special_tokens_to_keep,
505
  ],
506
+ ).success(preload_model, inputs=load_model_args)
507
+
508
  towerinstruct_template.click(**reset_kwargs).then(
509
  set_towerinstruct_preset,
510
  outputs=[
 
513
  contextless_input_current_text,
514
  decoder_input_output_separator,
515
  ],
516
+ ).success(preload_model, inputs=load_model_args)
517
 
518
  demo.launch(allowed_paths=["outputs/"])