Spaces:
Running
Running
RAdapt pass_to_input to tasks
Browse files
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
from gradio.components import Slider
|
3 |
import torch
|
4 |
-
from transformers import pipeline
|
5 |
|
6 |
# Model, information and examples ----------------------------------------------
|
7 |
model_id = "proxectonos/FLOR-1.3B-GL"
|
@@ -54,18 +54,7 @@ def predict(prompt, max_length, repetition_penalty=1.3):
|
|
54 |
return generated_sequence
|
55 |
|
56 |
# Gradio app ---------------------------------------------------------
|
57 |
-
|
58 |
-
return (
|
59 |
-
None,
|
60 |
-
None,
|
61 |
-
gr.update(value=20),
|
62 |
-
gr.update(value=1.3)
|
63 |
-
)
|
64 |
-
def pass_to_input(generated_gl):
|
65 |
-
return (
|
66 |
-
gr.update(value=generated_gl),
|
67 |
-
None,
|
68 |
-
)
|
69 |
def parameters_default(text):
|
70 |
return (
|
71 |
gr.update(value=30), # max_length
|
@@ -78,6 +67,28 @@ def parameters_fewshot_prompt(text):
|
|
78 |
gr.update(value=1) # repetition_penalty
|
79 |
)
|
80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
def gradio_app():
|
82 |
with gr.Blocks(theme=fronted_theme) as demo:
|
83 |
with gr.Row():
|
@@ -118,7 +129,7 @@ def gradio_app():
|
|
118 |
|
119 |
generator_btn.click(predict, inputs=[text_gl,max_length, repetition_penalty], outputs=generated_gl, api_name="generate-flor-gl")
|
120 |
clean_btn.click(fn=clear, inputs=[], outputs=[text_gl, generated_gl, max_length, repetition_penalty], queue=False, api_name=False)
|
121 |
-
pass_btn.click(fn=pass_to_input, inputs=[generated_gl], outputs=[text_gl,generated_gl], queue=False, api_name=False)
|
122 |
|
123 |
with gr.Row():
|
124 |
with gr.Column(scale=0.5):
|
|
|
1 |
import gradio as gr
|
2 |
from gradio.components import Slider
|
3 |
import torch
|
4 |
+
from transformers import pipeline
|
5 |
|
6 |
# Model, information and examples ----------------------------------------------
|
7 |
model_id = "proxectonos/FLOR-1.3B-GL"
|
|
|
54 |
return generated_sequence
|
55 |
|
56 |
# Gradio app ---------------------------------------------------------
|
57 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
def parameters_default(text):
|
59 |
return (
|
60 |
gr.update(value=30), # max_length
|
|
|
67 |
gr.update(value=1) # repetition_penalty
|
68 |
)
|
69 |
|
70 |
+
def clear():
|
71 |
+
return (
|
72 |
+
None,
|
73 |
+
None,
|
74 |
+
gr.update(value=20),
|
75 |
+
gr.update(value=1.3)
|
76 |
+
)
|
77 |
+
def pass_to_input(generated_gl):
|
78 |
+
few_shot_tasks = [example.splitlines()[0] for example in few_shot_prompts_examples]
|
79 |
+
if generated_gl.splitlines()[0] in few_shot_tasks:
|
80 |
+
return (
|
81 |
+
gr.update(value=generated_gl),
|
82 |
+
None,
|
83 |
+
parameters_fewshot_prompt(generated_gl)
|
84 |
+
)
|
85 |
+
else:
|
86 |
+
return (
|
87 |
+
gr.update(value=generated_gl),
|
88 |
+
None,
|
89 |
+
parameters_fewshot_prompt(generated_gl)
|
90 |
+
)
|
91 |
+
|
92 |
def gradio_app():
|
93 |
with gr.Blocks(theme=fronted_theme) as demo:
|
94 |
with gr.Row():
|
|
|
129 |
|
130 |
generator_btn.click(predict, inputs=[text_gl,max_length, repetition_penalty], outputs=generated_gl, api_name="generate-flor-gl")
|
131 |
clean_btn.click(fn=clear, inputs=[], outputs=[text_gl, generated_gl, max_length, repetition_penalty], queue=False, api_name=False)
|
132 |
+
pass_btn.click(fn=pass_to_input, inputs=[generated_gl], outputs=[text_gl,generated_gl, max_length, repetition_penalty], queue=False, api_name=False)
|
133 |
|
134 |
with gr.Row():
|
135 |
with gr.Column(scale=0.5):
|