pablo-rf commited on
Commit
b52e67d
1 Parent(s): a289a30

Undo refactor pass_to_input

Browse files
Files changed (1) hide show
  1. app.py +13 -23
app.py CHANGED
@@ -54,6 +54,18 @@ def predict(prompt, max_length, repetition_penalty=1.3):
54
  return generated_sequence
55
 
56
  # Gradio app ---------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  def parameters_default(text):
59
  return (
@@ -67,28 +79,6 @@ def parameters_fewshot_prompt(text):
67
  gr.update(value=1) # repetition_penalty
68
  )
69
 
70
- def clear():
71
- return (
72
- None,
73
- None,
74
- gr.update(value=20),
75
- gr.update(value=1.3)
76
- )
77
- def pass_to_input(generated_gl):
78
- few_shot_tasks = [example[0].splitlines()[0] for example in few_shot_prompts_examples]
79
- if generated_gl.splitlines()[0] in few_shot_tasks:
80
- return (
81
- gr.update(value=generated_gl),
82
- None,
83
- parameters_fewshot_prompt(generated_gl)
84
- )
85
- else:
86
- return (
87
- gr.update(value=generated_gl),
88
- None,
89
- parameters_fewshot_prompt(generated_gl)
90
- )
91
-
92
  def gradio_app():
93
  with gr.Blocks(theme=fronted_theme) as demo:
94
  with gr.Row():
@@ -129,7 +119,7 @@ def gradio_app():
129
 
130
  generator_btn.click(predict, inputs=[text_gl,max_length, repetition_penalty], outputs=generated_gl, api_name="generate-flor-gl")
131
  clean_btn.click(fn=clear, inputs=[], outputs=[text_gl, generated_gl, max_length, repetition_penalty], queue=False, api_name=False)
132
- pass_btn.click(fn=pass_to_input, inputs=[generated_gl], outputs=[text_gl,generated_gl, max_length, repetition_penalty], queue=False, api_name=False)
133
 
134
  with gr.Row():
135
  with gr.Column(scale=0.5):
 
54
  return generated_sequence
55
 
56
  # Gradio app ---------------------------------------------------------
57
+ def clear():
58
+ return (
59
+ None,
60
+ None,
61
+ gr.update(value=20),
62
+ gr.update(value=1.3)
63
+ )
64
+ def pass_to_input(generated_gl):
65
+ return (
66
+ gr.update(value=generated_gl),
67
+ None
68
+ )
69
 
70
  def parameters_default(text):
71
  return (
 
79
  gr.update(value=1) # repetition_penalty
80
  )
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def gradio_app():
83
  with gr.Blocks(theme=fronted_theme) as demo:
84
  with gr.Row():
 
119
 
120
  generator_btn.click(predict, inputs=[text_gl,max_length, repetition_penalty], outputs=generated_gl, api_name="generate-flor-gl")
121
  clean_btn.click(fn=clear, inputs=[], outputs=[text_gl, generated_gl, max_length, repetition_penalty], queue=False, api_name=False)
122
+ pass_btn.click(fn=pass_to_input, inputs=[generated_gl], outputs=[text_gl,generated_gl], queue=False, api_name=False)
123
 
124
  with gr.Row():
125
  with gr.Column(scale=0.5):