Spaces:
Runtime error
Runtime error
Add log for preiously generated texts
Browse files- gradio_app.py +18 -11
gradio_app.py
CHANGED
@@ -159,24 +159,31 @@ class TextGeneration:
|
|
159 |
# return generated
|
160 |
|
161 |
|
162 |
-
def generate(self, text, generation_kwargs):
|
163 |
-
|
|
|
164 |
generation_kwargs["max_length"] = min(max_length, self.model.config.n_positions)
|
165 |
generated_text = None
|
166 |
-
if
|
167 |
for _ in range(10):
|
168 |
generated_text = self.generator(
|
169 |
-
|
170 |
**generation_kwargs,
|
171 |
)[0]["generated_text"]
|
172 |
if generation_kwargs["do_clean"]:
|
173 |
generated_text = cleaner.clean_txt(generated_text)
|
174 |
-
if generated_text.strip().startswith(
|
175 |
-
generated_text = generated_text.replace(
|
176 |
if generated_text:
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
return (
|
178 |
-
|
179 |
-
|
180 |
)
|
181 |
if not generated_text:
|
182 |
return (
|
@@ -221,7 +228,7 @@ def expand_with_gpt(hidden, text, max_length, top_k, top_p, temperature, do_samp
|
|
221 |
"do_sample": do_sample,
|
222 |
"do_clean": do_clean,
|
223 |
}
|
224 |
-
return generator.generate(
|
225 |
|
226 |
def chat_with_gpt(user, agent, context, user_message, history, max_length, top_k, top_p, temperature, do_sample, do_clean):
|
227 |
# agent = AGENT
|
@@ -339,7 +346,7 @@ with gr.Blocks() as demo:
|
|
339 |
hidden = gr.Textbox(visible=False, show_label=False)
|
340 |
with gr.Box():
|
341 |
# output = gr.Markdown()
|
342 |
-
output = gr.HighlightedText(label="Resultado", combine_adjacent=True, color_map={
|
343 |
with gr.Row():
|
344 |
generate_btn = gr.Button("Generar")
|
345 |
generate_btn.click(complete_with_gpt, inputs=[textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[hidden, output])
|
@@ -358,7 +365,7 @@ with gr.Blocks() as demo:
|
|
358 |
with gr.Row():
|
359 |
agent = gr.Textbox(label="Agente", value=AGENT)
|
360 |
user = gr.Textbox(label="Usuario", value=USER)
|
361 |
-
history = gr.Variable(
|
362 |
chatbot = gr.Chatbot(color_map=("green", "gray"))
|
363 |
with gr.Row():
|
364 |
message = gr.Textbox(placeholder="Escriba aquí su mensaje y pulse 'Enviar'", show_label=False)
|
|
|
159 |
# return generated
|
160 |
|
161 |
|
162 |
+
def generate(self, text, generation_kwargs, previous_text=None):
|
163 |
+
input_text = previous_text or text
|
164 |
+
max_length = len(self.tokenizer(input_text)["input_ids"]) + generation_kwargs["max_length"]
|
165 |
generation_kwargs["max_length"] = min(max_length, self.model.config.n_positions)
|
166 |
generated_text = None
|
167 |
+
if input_text:
|
168 |
for _ in range(10):
|
169 |
generated_text = self.generator(
|
170 |
+
input_text,
|
171 |
**generation_kwargs,
|
172 |
)[0]["generated_text"]
|
173 |
if generation_kwargs["do_clean"]:
|
174 |
generated_text = cleaner.clean_txt(generated_text)
|
175 |
+
if generated_text.strip().startswith(input_text):
|
176 |
+
generated_text = generated_text.replace(input_text, "", 1).strip()
|
177 |
if generated_text:
|
178 |
+
if previous_text and previous_text != text:
|
179 |
+
diff = [
|
180 |
+
(text, None), (previous_text.replace(text, " ", 1).strip(), " "), (generated_text, AGENT)
|
181 |
+
]
|
182 |
+
else:
|
183 |
+
diff = [(text, None), (generated_text, AGENT)]
|
184 |
return (
|
185 |
+
input_text + " " + generated_text,
|
186 |
+
diff
|
187 |
)
|
188 |
if not generated_text:
|
189 |
return (
|
|
|
228 |
"do_sample": do_sample,
|
229 |
"do_clean": do_clean,
|
230 |
}
|
231 |
+
return generator.generate(text, generation_kwargs, previous_text=hidden)
|
232 |
|
233 |
def chat_with_gpt(user, agent, context, user_message, history, max_length, top_k, top_p, temperature, do_sample, do_clean):
|
234 |
# agent = AGENT
|
|
|
346 |
hidden = gr.Textbox(visible=False, show_label=False)
|
347 |
with gr.Box():
|
348 |
# output = gr.Markdown()
|
349 |
+
output = gr.HighlightedText(label="Resultado", combine_adjacent=True, color_map={AGENT: "green", "ERROR": "red", " ": "blue"})
|
350 |
with gr.Row():
|
351 |
generate_btn = gr.Button("Generar")
|
352 |
generate_btn.click(complete_with_gpt, inputs=[textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[hidden, output])
|
|
|
365 |
with gr.Row():
|
366 |
agent = gr.Textbox(label="Agente", value=AGENT)
|
367 |
user = gr.Textbox(label="Usuario", value=USER)
|
368 |
+
history = gr.Variable(value=[])
|
369 |
chatbot = gr.Chatbot(color_map=("green", "gray"))
|
370 |
with gr.Row():
|
371 |
message = gr.Textbox(placeholder="Escriba aquí su mensaje y pulse 'Enviar'", show_label=False)
|