versae commited on
Commit
7e65dc3
1 Parent(s): 95792cc

Add API support

Browse files
Files changed (1) hide show
  1. gradio_app.py +34 -15
gradio_app.py CHANGED
@@ -7,6 +7,31 @@ import torch
7
  from transformers import pipeline, set_seed
8
  from transformers import AutoTokenizer, AutoModelForCausalLM
9
  import logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  logger = logging.getLogger()
11
  logger.addHandler(logging.StreamHandler())
12
 
@@ -149,27 +174,18 @@ class TextGeneration:
149
  generated_text = generated_text.replace(text, "", 1).strip()
150
  if generated_text:
151
  return (
152
- text,
153
  text + " " + generated_text,
154
  [(text, None), (generated_text, "BERTIN")]
155
  )
156
  if not generated_text:
157
  return (
158
- "",
159
  "",
160
  [("Tras 10 intentos BERTIN no gener贸 nada. Pruebe cambiando las opciones.", "ERROR")]
161
  )
162
  return (
163
- "",
164
  "",
165
  [("Debe escribir algo primero.", "ERROR")]
166
  )
167
- # return (text + " " + generated_text,
168
- # f'<p class="ltr ltr-box">'
169
- # f'<span class="result-text">{text} <span>'
170
- # f'<span class="result-text generated-text">{generated_text}</span>'
171
- # f'</p>'
172
- # )
173
 
174
 
175
  #@st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None})
@@ -234,7 +250,7 @@ def chat_with_gpt(user, agent, context, user_message, history, max_length, top_k
234
  break
235
  context += history_context
236
  for _ in range(5):
237
- response = generator.generate(f"{context}\n\n{user}: {message}.\n", generation_kwargs)[1]
238
  if DEBUG:
239
  print("\n-----" + response + "-----\n")
240
  response = response.split("\n")[-1]
@@ -325,9 +341,9 @@ with gr.Blocks() as demo:
325
  output = gr.HighlightedText(label="Resultado", combine_adjacent=True, color_map={"BERTIN": "green", "ERROR": "red"})
326
  with gr.Row():
327
  generate_btn = gr.Button("Generar")
328
- generate_btn.click(complete_with_gpt, inputs=[textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[textbox, hidden, output])
329
  expand_btn = gr.Button("A帽adir")
330
- expand_btn.click(expand_with_gpt, inputs=[hidden, textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[textbox, hidden, output])
331
 
332
  edit_btn = gr.Button("Editar", variant="secondary")
333
  edit_btn.click(lambda x: (x, "", []), inputs=[hidden], outputs=[textbox, hidden, output])
@@ -346,10 +362,13 @@ with gr.Blocks() as demo:
346
  with gr.Row():
347
  message = gr.Textbox(placeholder="Escriba aqu铆 su mensaje y pulse 'Enviar'", show_label=False)
348
  chat_btn = gr.Button("Enviar")
349
- chat_btn.click(chat_with_gpt, inputs=[agent, user, context, message, history, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[chatbot, history, message])
350
  gr.Markdown(FOOTER)
351
 
352
-
 
 
 
 
353
 
354
  demo.launch()
355
- # gr.Interface(complete_with_gpt, inputs=[textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[hidden, output]).launch()
 
7
  from transformers import pipeline, set_seed
8
  from transformers import AutoTokenizer, AutoModelForCausalLM
9
  import logging
10
+
11
+
12
+ # Monkey patch
13
+ import inspect
14
+ from gradio import routes
15
+ from typing import List, Type
16
+
17
+ def get_types(cls_set: List[Type], component: str):
18
+ docset = []
19
+ types = []
20
+ if component == "input":
21
+ for cls in cls_set:
22
+ doc = inspect.getdoc(cls)
23
+ doc_lines = doc.split("\n")
24
+ docset.append(doc_lines[1].split(":")[-1])
25
+ types.append(doc_lines[1].split(")")[0].split("(")[-1])
26
+ else:
27
+ for cls in cls_set:
28
+ doc = inspect.getdoc(cls)
29
+ doc_lines = doc.split("\n")
30
+ docset.append(doc_lines[-1].split(":")[-1])
31
+ types.append(doc_lines[-1].split(")")[0].split("(")[-1])
32
+ return docset, types
33
+ routes.get_types = get_types
34
+
35
  logger = logging.getLogger()
36
  logger.addHandler(logging.StreamHandler())
37
 
 
174
  generated_text = generated_text.replace(text, "", 1).strip()
175
  if generated_text:
176
  return (
 
177
  text + " " + generated_text,
178
  [(text, None), (generated_text, "BERTIN")]
179
  )
180
  if not generated_text:
181
  return (
 
182
  "",
183
  [("Tras 10 intentos BERTIN no gener贸 nada. Pruebe cambiando las opciones.", "ERROR")]
184
  )
185
  return (
 
186
  "",
187
  [("Debe escribir algo primero.", "ERROR")]
188
  )
 
 
 
 
 
 
189
 
190
 
191
  #@st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None})
 
250
  break
251
  context += history_context
252
  for _ in range(5):
253
+ response = generator.generate(f"{context}\n\n{user}: {message}.\n", generation_kwargs)[0]
254
  if DEBUG:
255
  print("\n-----" + response + "-----\n")
256
  response = response.split("\n")[-1]
 
341
  output = gr.HighlightedText(label="Resultado", combine_adjacent=True, color_map={"BERTIN": "green", "ERROR": "red"})
342
  with gr.Row():
343
  generate_btn = gr.Button("Generar")
344
+ generate_btn.click(complete_with_gpt, inputs=[textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[hidden, output])
345
  expand_btn = gr.Button("A帽adir")
346
+ expand_btn.click(expand_with_gpt, inputs=[hidden, textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[hidden, output])
347
 
348
  edit_btn = gr.Button("Editar", variant="secondary")
349
  edit_btn.click(lambda x: (x, "", []), inputs=[hidden], outputs=[textbox, hidden, output])
 
362
  with gr.Row():
363
  message = gr.Textbox(placeholder="Escriba aqu铆 su mensaje y pulse 'Enviar'", show_label=False)
364
  chat_btn = gr.Button("Enviar")
365
+ chat_btn.click(chat_with_gpt, inputs=[agent, user, context, message, history, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[chatbot, history, message]))
366
  gr.Markdown(FOOTER)
367
 
368
+ with gr.Interface(lambda: None, inputs=["text", max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[hidden, output]) as iface:
369
+ demo.examples = None
370
+ demo.predict_durations = []
371
+ demo.input_components = iface.input_components
372
+ demo.output_components = iface.output_components
373
 
374
  demo.launch()