Amitontheweb commited on
Commit
2f5d998
·
verified ·
1 Parent(s): e0c4432

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -8
app.py CHANGED
@@ -136,11 +136,14 @@ def generate(input_text, number_steps, number_beams, number_beam_groups, diversi
136
 
137
  #--------ON SELECTING MODEL------------------------
138
 
139
- def load_model(model_selected):
 
 
140
 
141
  if model_selected == "GPT2":
142
- tokenizer = tokenizer_gpt2
143
- model = model_gpt2
 
144
  #print (model_selected + " loaded")
145
 
146
  #if model_selected == "Gemma 2":
@@ -148,10 +151,19 @@ def load_model(model_selected):
148
  #model = model_gemma
149
 
150
  if model_selected == "Qwen2":
151
- tokenizer = tokenizer_qwen
152
- model = model_qwen
 
 
 
153
 
 
 
 
 
 
154
 
 
155
 
156
  #--------ON SELECT NO. OF RETURN SEQUENCES----------
157
 
@@ -313,7 +325,7 @@ with gr.Blocks() as demo:
313
 
314
  model_selected = gr.Radio (["GPT2", "Qwen2"], label="ML Model", value="GPT2")
315
  strategy_selected = gr.Radio (["Sampling", "Beam Search", "Diversity Beam Search","Contrastive"], label="Search strategy", value = "Sampling", interactive=True)
316
-
317
 
318
  with gr.Column(scale=1):
319
 
@@ -390,10 +402,10 @@ with gr.Blocks() as demo:
390
 
391
 
392
 
393
- #----------ON SELECTING/CHANGING: RETURN SEEQUENCES/NO OF BEAMS/BEAM GROUPS/TEMPERATURE--------
394
 
395
  model_selected.change(
396
- fn=load_model, inputs=[model_selected], outputs=[]
397
  )
398
 
399
  #num_return_sequences.change(
@@ -421,6 +433,12 @@ with gr.Blocks() as demo:
421
  outputs=[out_markdown]
422
  )
423
 
 
 
 
 
 
 
424
 
425
  with gr.Row():
426
 
 
136
 
137
  #--------ON SELECTING MODEL------------------------
138
 
139
+ def select_model(model_selected):
140
+
141
+ global model_name
142
 
143
  if model_selected == "GPT2":
144
+ model_name = "openai-community/gpt2"
145
+ #tokenizer = tokenizer_gpt2
146
+ #model = model_gpt2
147
  #print (model_selected + " loaded")
148
 
149
  #if model_selected == "Gemma 2":
 
151
  #model = model_gemma
152
 
153
  if model_selected == "Qwen2":
154
+ model_name = "Qwen/Qwen2-0.5B"
155
+ #tokenizer = tokenizer_qwen
156
+ #model = model_qwen
157
+
158
+ # On clicking load button
159
 
160
+ def load_model ():
161
+
162
+ global model_name
163
+ tokenizer_gpt2 = AutoTokenizer.from_pretrained(model_name)
164
+ model_gpt2 = AutoModelForCausalLM.from_pretrained(model_name)
165
 
166
+
167
 
168
  #--------ON SELECT NO. OF RETURN SEQUENCES----------
169
 
 
325
 
326
  model_selected = gr.Radio (["GPT2", "Qwen2"], label="ML Model", value="GPT2")
327
  strategy_selected = gr.Radio (["Sampling", "Beam Search", "Diversity Beam Search","Contrastive"], label="Search strategy", value = "Sampling", interactive=True)
328
+ load_model_button = gr.Button("Load")
329
 
330
  with gr.Column(scale=1):
331
 
 
402
 
403
 
404
 
405
+ #----------ON SELECTING/CHANGING: RETURN SEEQUENCES/NO OF BEAMS/BEAM GROUPS/TEMPERATURE--------
406
 
407
  model_selected.change(
408
+ fn=select_model, inputs=[model_selected], outputs=[]
409
  )
410
 
411
  #num_return_sequences.change(
 
433
  outputs=[out_markdown]
434
  )
435
 
436
+ load_model_button.click(
437
+ fn=load_model,
438
+ inputs=[model_selected],
439
+ outputs=[]
440
+ )
441
+
442
 
443
  with gr.Row():
444