pablo-rf commited on
Commit
9f8226f
1 Parent(s): 0aa6a75

RAdapt pass_to_input to tasks

Browse files
Files changed (1) hide show
  1. app.py +25 -14
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  from gradio.components import Slider
3
  import torch
4
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
5
 
6
  # Model, information and examples ----------------------------------------------
7
  model_id = "proxectonos/FLOR-1.3B-GL"
@@ -54,18 +54,7 @@ def predict(prompt, max_length, repetition_penalty=1.3):
54
  return generated_sequence
55
 
56
  # Gradio app ---------------------------------------------------------
57
- def clear():
58
- return (
59
- None,
60
- None,
61
- gr.update(value=20),
62
- gr.update(value=1.3)
63
- )
64
- def pass_to_input(generated_gl):
65
- return (
66
- gr.update(value=generated_gl),
67
- None,
68
- )
69
  def parameters_default(text):
70
  return (
71
  gr.update(value=30), # max_length
@@ -78,6 +67,28 @@ def parameters_fewshot_prompt(text):
78
  gr.update(value=1) # repetition_penalty
79
  )
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  def gradio_app():
82
  with gr.Blocks(theme=fronted_theme) as demo:
83
  with gr.Row():
@@ -118,7 +129,7 @@ def gradio_app():
118
 
119
  generator_btn.click(predict, inputs=[text_gl,max_length, repetition_penalty], outputs=generated_gl, api_name="generate-flor-gl")
120
  clean_btn.click(fn=clear, inputs=[], outputs=[text_gl, generated_gl, max_length, repetition_penalty], queue=False, api_name=False)
121
- pass_btn.click(fn=pass_to_input, inputs=[generated_gl], outputs=[text_gl,generated_gl], queue=False, api_name=False)
122
 
123
  with gr.Row():
124
  with gr.Column(scale=0.5):
 
1
  import gradio as gr
2
  from gradio.components import Slider
3
  import torch
4
+ from transformers import pipeline
5
 
6
  # Model, information and examples ----------------------------------------------
7
  model_id = "proxectonos/FLOR-1.3B-GL"
 
54
  return generated_sequence
55
 
56
  # Gradio app ---------------------------------------------------------
57
+
 
 
 
 
 
 
 
 
 
 
 
58
  def parameters_default(text):
59
  return (
60
  gr.update(value=30), # max_length
 
67
  gr.update(value=1) # repetition_penalty
68
  )
69
 
70
+ def clear():
71
+ return (
72
+ None,
73
+ None,
74
+ gr.update(value=20),
75
+ gr.update(value=1.3)
76
+ )
77
+ def pass_to_input(generated_gl):
78
+ few_shot_tasks = [example.splitlines()[0] for example in few_shot_prompts_examples]
79
+ if generated_gl.splitlines()[0] in few_shot_tasks:
80
+ return (
81
+ gr.update(value=generated_gl),
82
+ None,
83
+ parameters_fewshot_prompt(generated_gl)
84
+ )
85
+ else:
86
+ return (
87
+ gr.update(value=generated_gl),
88
+ None,
89
+ parameters_fewshot_prompt(generated_gl)
90
+ )
91
+
92
  def gradio_app():
93
  with gr.Blocks(theme=fronted_theme) as demo:
94
  with gr.Row():
 
129
 
130
  generator_btn.click(predict, inputs=[text_gl,max_length, repetition_penalty], outputs=generated_gl, api_name="generate-flor-gl")
131
  clean_btn.click(fn=clear, inputs=[], outputs=[text_gl, generated_gl, max_length, repetition_penalty], queue=False, api_name=False)
132
+ pass_btn.click(fn=pass_to_input, inputs=[generated_gl], outputs=[text_gl,generated_gl, max_length, repetition_penalty], queue=False, api_name=False)
133
 
134
  with gr.Row():
135
  with gr.Column(scale=0.5):