Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -136,11 +136,14 @@ def generate(input_text, number_steps, number_beams, number_beam_groups, diversi
|
|
136 |
|
137 |
#--------ON SELECTING MODEL------------------------
|
138 |
|
139 |
-
def
|
|
|
|
|
140 |
|
141 |
if model_selected == "GPT2":
|
142 |
-
|
143 |
-
|
|
|
144 |
#print (model_selected + " loaded")
|
145 |
|
146 |
#if model_selected == "Gemma 2":
|
@@ -148,10 +151,19 @@ def load_model(model_selected):
|
|
148 |
#model = model_gemma
|
149 |
|
150 |
if model_selected == "Qwen2":
|
151 |
-
|
152 |
-
|
|
|
|
|
|
|
153 |
|
|
|
|
|
|
|
|
|
|
|
154 |
|
|
|
155 |
|
156 |
#--------ON SELECT NO. OF RETURN SEQUENCES----------
|
157 |
|
@@ -313,7 +325,7 @@ with gr.Blocks() as demo:
|
|
313 |
|
314 |
model_selected = gr.Radio (["GPT2", "Qwen2"], label="ML Model", value="GPT2")
|
315 |
strategy_selected = gr.Radio (["Sampling", "Beam Search", "Diversity Beam Search","Contrastive"], label="Search strategy", value = "Sampling", interactive=True)
|
316 |
-
|
317 |
|
318 |
with gr.Column(scale=1):
|
319 |
|
@@ -390,10 +402,10 @@ with gr.Blocks() as demo:
|
|
390 |
|
391 |
|
392 |
|
393 |
-
#----------ON SELECTING/CHANGING: RETURN SEEQUENCES/NO OF BEAMS/BEAM GROUPS/TEMPERATURE--------
|
394 |
|
395 |
model_selected.change(
|
396 |
-
fn=
|
397 |
)
|
398 |
|
399 |
#num_return_sequences.change(
|
@@ -421,6 +433,12 @@ with gr.Blocks() as demo:
|
|
421 |
outputs=[out_markdown]
|
422 |
)
|
423 |
|
|
|
|
|
|
|
|
|
|
|
|
|
424 |
|
425 |
with gr.Row():
|
426 |
|
|
|
136 |
|
137 |
#--------ON SELECTING MODEL------------------------
|
138 |
|
139 |
+
def select_model(model_selected):
    """Remember the Hugging Face repo id for the radio choice.

    Updates the module-level ``model_name`` that ``load_model`` reads
    later; an unrecognized selection leaves ``model_name`` untouched.
    """
    global model_name

    # Radio label -> Hugging Face hub repo id.
    # NOTE(review): a "Gemma 2" option was tried here previously and is
    # currently disabled in the UI.
    hub_ids = {
        "GPT2": "openai-community/gpt2",
        "Qwen2": "Qwen/Qwen2-0.5B",
    }
    if model_selected in hub_ids:
        model_name = hub_ids[model_selected]
|
157 |
+
|
158 |
+
# On clicking load button
def load_model(model_selected=None):
    """Instantiate the tokenizer and model for the current ``model_name``.

    The Load button is wired as ``click(fn=load_model,
    inputs=[model_selected], ...)``, so Gradio calls this with one
    positional argument; the original zero-argument signature made every
    click raise ``TypeError``. The parameter defaults to ``None`` so
    existing no-argument callers keep working. When the radio value is
    supplied, it is applied first so the click loads exactly what the
    user currently has selected.
    """
    # tokenizer_gpt2/model_gpt2 must be module globals: without the
    # ``global`` declaration they were function locals and the freshly
    # downloaded objects were discarded on return.
    global model_name, tokenizer_gpt2, model_gpt2

    if model_selected is not None:
        select_model(model_selected)

    tokenizer_gpt2 = AutoTokenizer.from_pretrained(model_name)
    model_gpt2 = AutoModelForCausalLM.from_pretrained(model_name)
|
165 |
|
166 |
+
|
167 |
|
168 |
#--------ON SELECT NO. OF RETURN SEQUENCES----------
|
169 |
|
|
|
325 |
|
326 |
model_selected = gr.Radio (["GPT2", "Qwen2"], label="ML Model", value="GPT2")
|
327 |
strategy_selected = gr.Radio (["Sampling", "Beam Search", "Diversity Beam Search","Contrastive"], label="Search strategy", value = "Sampling", interactive=True)
|
328 |
+
load_model_button = gr.Button("Load")
|
329 |
|
330 |
with gr.Column(scale=1):
|
331 |
|
|
|
402 |
|
403 |
|
404 |
|
405 |
+
#----------ON SELECTING/CHANGING: RETURN SEQUENCES/NO OF BEAMS/BEAM GROUPS/TEMPERATURE--------
|
406 |
|
407 |
model_selected.change(
|
408 |
+
fn=select_model, inputs=[model_selected], outputs=[]
|
409 |
)
|
410 |
|
411 |
#num_return_sequences.change(
|
|
|
433 |
outputs=[out_markdown]
|
434 |
)
|
435 |
|
436 |
+
load_model_button.click(
|
437 |
+
fn=load_model,
|
438 |
+
inputs=[model_selected],
|
439 |
+
outputs=[]
|
440 |
+
)
|
441 |
+
|
442 |
|
443 |
with gr.Row():
|
444 |
|