pablo-rf commited on
Commit
589e0e8
1 Parent(s): 29e8e30

[ADD] Temperature parameter

Browse files
Files changed (1) hide show
  1. app.py +19 -9
app.py CHANGED
@@ -36,14 +36,15 @@ def remove_empty_lines(text):
36
  non_empty_lines = [line for line in lines if line.strip()]
37
  return "\n".join(non_empty_lines)
38
 
39
- def predict(prompt, max_length, repetition_penalty=1.3):
40
  print("Dentro da xeración...")
41
  prompt_length = len(generator_model.tokenizer.encode(prompt))
42
  generated_text = generator_model(
43
  prompt,
44
  max_length=prompt_length + max_length,
45
  pad_token_id=generator_model.tokenizer.eos_token_id,
46
- repetition_penalty=repetition_penalty)
 
47
 
48
  generated_sequence = generated_text[0]['generated_text']
49
  if generated_sequence is None:
@@ -60,7 +61,8 @@ def clear():
60
  None,
61
  None,
62
  gr.update(value=20),
63
- gr.update(value=1.3)
 
64
  )
65
  def pass_to_input(generated_gl):
66
  return (
@@ -71,13 +73,15 @@ def pass_to_input(generated_gl):
71
  def parameters_default(text):
72
  return (
73
  gr.update(value=30), # max_length
74
- gr.update(value=1.3) # repetition_penalty
 
75
  )
76
 
77
  def parameters_fewshot_prompt(text):
78
  return (
79
  gr.update(value=15), # max_length
80
- gr.update(value=1) # repetition_penalty
 
81
  )
82
 
83
  def gradio_app():
@@ -108,6 +112,12 @@ def gradio_app():
108
  value=1.3,
109
  label="Repetition penalty"
110
  )
 
 
 
 
 
 
111
  generator_btn = gr.Button(value="Generate",variant='primary')
112
  with gr.Column():
113
  generated_gl = gr.Textbox(label="Output",
@@ -118,8 +128,8 @@ def gradio_app():
118
  pass_btn = gr.Button(value="Pass text to input")
119
  clean_btn = gr.Button(value="Clean")
120
 
121
- generator_btn.click(predict, inputs=[text_gl,max_length, repetition_penalty], outputs=generated_gl, api_name="generate-flor-gl")
122
- clean_btn.click(fn=clear, inputs=[], outputs=[text_gl, generated_gl, max_length, repetition_penalty], queue=False, api_name=False)
123
  pass_btn.click(fn=pass_to_input, inputs=[generated_gl], outputs=[text_gl,generated_gl], queue=False, api_name=False)
124
 
125
  with gr.Row():
@@ -128,7 +138,7 @@ def gradio_app():
128
  label = "Short prompts",
129
  examples = short_prompts_examples,
130
  inputs = [text_gl],
131
- outputs = [max_length, repetition_penalty],
132
  fn = parameters_default,
133
  run_on_click = True
134
  )
@@ -136,7 +146,7 @@ def gradio_app():
136
  label = "Few-shot prompts",
137
  examples = few_shot_prompts_examples,
138
  inputs = [text_gl],
139
- outputs = [max_length, repetition_penalty],
140
  fn = parameters_fewshot_prompt,
141
  run_on_click = True
142
  )
 
36
  non_empty_lines = [line for line in lines if line.strip()]
37
  return "\n".join(non_empty_lines)
38
 
39
+ def predict(prompt, max_length, repetition_penalty, temperature):
40
  print("Dentro da xeración...")
41
  prompt_length = len(generator_model.tokenizer.encode(prompt))
42
  generated_text = generator_model(
43
  prompt,
44
  max_length=prompt_length + max_length,
45
  pad_token_id=generator_model.tokenizer.eos_token_id,
46
+ repetition_penalty=repetition_penalty,
47
+ temperature=temperature)
48
 
49
  generated_sequence = generated_text[0]['generated_text']
50
  if generated_sequence is None:
 
61
  None,
62
  None,
63
  gr.update(value=20),
64
+ gr.update(value=1.3),
65
+ gr.update(value=0.5)
66
  )
67
  def pass_to_input(generated_gl):
68
  return (
 
73
  def parameters_default(text):
74
  return (
75
  gr.update(value=30), # max_length
76
+ gr.update(value=1.3), # repetition_penalty
77
+ gr.update(value=0.5) # temperature
78
  )
79
 
80
  def parameters_fewshot_prompt(text):
81
  return (
82
  gr.update(value=15), # max_length
83
+ gr.update(value=1), # repetition_penalty
84
+ gr.update(value=0.5) # temperature
85
  )
86
 
87
  def gradio_app():
 
112
  value=1.3,
113
  label="Repetition penalty"
114
  )
115
+ temperature = Slider(
116
+ minimum=0,
117
+ maximum=1,
118
+ value=0.5,
119
+ label="Temperaturr"
120
+ )
121
  generator_btn = gr.Button(value="Generate",variant='primary')
122
  with gr.Column():
123
  generated_gl = gr.Textbox(label="Output",
 
128
  pass_btn = gr.Button(value="Pass text to input")
129
  clean_btn = gr.Button(value="Clean")
130
 
131
+ generator_btn.click(predict, inputs=[text_gl,max_length, repetition_penalty, temperature], outputs=generated_gl, api_name="generate-flor-gl")
132
+ clean_btn.click(fn=clear, inputs=[], outputs=[text_gl, generated_gl, max_length, repetition_penalty, temperature], queue=False, api_name=False)
133
  pass_btn.click(fn=pass_to_input, inputs=[generated_gl], outputs=[text_gl,generated_gl], queue=False, api_name=False)
134
 
135
  with gr.Row():
 
138
  label = "Short prompts",
139
  examples = short_prompts_examples,
140
  inputs = [text_gl],
141
+ outputs = [max_length, repetition_penalty, temperature],
142
  fn = parameters_default,
143
  run_on_click = True
144
  )
 
146
  label = "Few-shot prompts",
147
  examples = few_shot_prompts_examples,
148
  inputs = [text_gl],
149
+ outputs = [max_length, repetition_penalty, temperature],
150
  fn = parameters_fewshot_prompt,
151
  run_on_click = True
152
  )