salsarra committed
Commit 70ca632 (verified)
Parent(s): 6cf991d

Create app.py

Files changed (1)
  1. app.py +275 -0
app.py ADDED
@@ -0,0 +1,275 @@
import torch
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering, AutoModelForCausalLM
import gradio as gr
import re

# Check if GPU is available and use it if possible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
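# Note: `device` only affects the PyTorch models loaded below (GPT-2 and
# BLOOM); the TF QA models are placed on CPU/GPU by TensorFlow itself.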

# Load Spanish models and tokenizers
confli_model_spanish = 'salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA'
confli_model_spanish_qa = TFAutoModelForQuestionAnswering.from_pretrained(confli_model_spanish)
confli_tokenizer_spanish = AutoTokenizer.from_pretrained(confli_model_spanish)

beto_model_spanish = 'salsarra/Beto-Spanish-Cased-NewsQA'
beto_model_spanish_qa = TFAutoModelForQuestionAnswering.from_pretrained(beto_model_spanish)
beto_tokenizer_spanish = AutoTokenizer.from_pretrained(beto_model_spanish)

confli_sqac_model_spanish = 'salsarra/ConfliBERT-Spanish-Beto-Cased-SQAC'
confli_sqac_model_spanish_qa = TFAutoModelForQuestionAnswering.from_pretrained(confli_sqac_model_spanish)
confli_sqac_tokenizer_spanish = AutoTokenizer.from_pretrained(confli_sqac_model_spanish)

beto_sqac_model_spanish = 'salsarra/Beto-Spanish-Cased-SQAC'
beto_sqac_model_spanish_qa = TFAutoModelForQuestionAnswering.from_pretrained(beto_sqac_model_spanish)
beto_sqac_tokenizer_spanish = AutoTokenizer.from_pretrained(beto_sqac_model_spanish)

# Load Spanish GPT-2 model and tokenizer
gpt2_spanish_model_name = 'datificate/gpt2-small-spanish'
gpt2_spanish_tokenizer = AutoTokenizer.from_pretrained(gpt2_spanish_model_name)
gpt2_spanish_model = AutoModelForCausalLM.from_pretrained(gpt2_spanish_model_name).to(device)

# Load BLOOM-1.7B model and tokenizer for Spanish
bloom_model_name = 'bigscience/bloom-1b7'
bloom_tokenizer = AutoTokenizer.from_pretrained(bloom_model_name)
bloom_model = AutoModelForCausalLM.from_pretrained(bloom_model_name).to(device)
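# Rough sizing note (an estimate, not measured): BLOOM-1.7B in float32 is
# about 1.7B params x 4 bytes, i.e. ~7 GB, so it dominates this app's memory
# footprint; a smaller checkpoint may be needed on low-RAM hardware.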

# Preload models with a dummy pass to improve first-time loading
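# (The first call to a TF/Keras model triggers graph construction, so paying
# that cost once at startup keeps the first real user request fast.)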
def preload_models():
    dummy_context = "Este es un contexto de prueba."
    dummy_question = "¿Cuál es el propósito de este contexto?"

    # Run each model with a dummy input to initialize them
    inputs = confli_tokenizer_spanish(dummy_question, dummy_context, return_tensors='tf')
    _ = confli_model_spanish_qa(inputs)

    inputs = beto_tokenizer_spanish(dummy_question, dummy_context, return_tensors='tf')
    _ = beto_model_spanish_qa(inputs)

    inputs = confli_sqac_tokenizer_spanish(dummy_question, dummy_context, return_tensors='tf')
    _ = confli_sqac_model_spanish_qa(inputs)

    inputs = beto_sqac_tokenizer_spanish(dummy_question, dummy_context, return_tensors='tf')
    _ = beto_sqac_model_spanish_qa(inputs)

preload_models()  # Initialize models

# Error handling: turn an oversize-input exception into a readable message
def handle_error_message(e, default_limit=512):
    error_message = str(e)
    pattern = re.compile(r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)")
    match = pattern.search(error_message)
    if match:
        number_1, number_2 = match.groups()
        return f"<span style='color: red; font-weight: bold;'>Error: the input text ({number_1} tokens) exceeds the model limit of {number_2} tokens.</span>"
    return f"<span style='color: red; font-weight: bold;'>Error: the input text exceeds the model limit of {default_limit} tokens.</span>"
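
# Example of the message the regex above matches (PyTorch's tensor-size
# mismatch, raised when an un-truncated input is longer than the model's
# context window): "The size of tensor a (600) must match the size of tensor b (512)"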

# Spanish QA functions
def question_answering_spanish(context, question):
    try:
        inputs = confli_tokenizer_spanish(question, context, return_tensors='tf', truncation=True)
        outputs = confli_model_spanish_qa(inputs)
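        # Extractive QA: the model scores each token as a possible answer
        # start and end; the argmax of each set of logits gives the most
        # likely span, which is decoded back into text below.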
        answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
        answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
        answer = confli_tokenizer_spanish.convert_tokens_to_string(confli_tokenizer_spanish.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end]))
        return f"<span style='color: green; font-weight: bold;'>{answer}</span>"
    except Exception as e:
        return handle_error_message(e)
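
# The three helpers below repeat the same extract-span pattern, differing
# only in checkpoint and highlight color.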

def beto_question_answering_spanish(context, question):
    try:
        inputs = beto_tokenizer_spanish(question, context, return_tensors='tf', truncation=True)
        outputs = beto_model_spanish_qa(inputs)
        answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
        answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
        answer = beto_tokenizer_spanish.convert_tokens_to_string(beto_tokenizer_spanish.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end]))
        return f"<span style='color: blue; font-weight: bold;'>{answer}</span>"
    except Exception as e:
        return handle_error_message(e)

def confli_sqac_question_answering_spanish(context, question):
    try:
        inputs = confli_sqac_tokenizer_spanish(question, context, return_tensors='tf', truncation=True)
        outputs = confli_sqac_model_spanish_qa(inputs)
        answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
        answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
        answer = confli_sqac_tokenizer_spanish.convert_tokens_to_string(confli_sqac_tokenizer_spanish.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end]))
        return f"<span style='color: teal; font-weight: bold;'>{answer}</span>"
    except Exception as e:
        return handle_error_message(e)

def beto_sqac_question_answering_spanish(context, question):
    try:
        inputs = beto_sqac_tokenizer_spanish(question, context, return_tensors='tf', truncation=True)
        outputs = beto_sqac_model_spanish_qa(inputs)
        answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
        answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
        answer = beto_sqac_tokenizer_spanish.convert_tokens_to_string(beto_sqac_tokenizer_spanish.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end]))
        return f"<span style='color: brown; font-weight: bold;'>{answer}</span>"
    except Exception as e:
        return handle_error_message(e)

def gpt2_spanish_question_answering(context, question):
    try:
        prompt = f"Contexto:\n{context}\n\nPregunta:\n{question}\n\nRespuesta:"
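        # Generative QA: unlike the extractive models above, GPT-2 answers by
        # continuing the prompt after "Respuesta:". Sampling with top_k=40 and
        # temperature=0.8 trades some faithfulness to the context for fluency.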
        inputs = gpt2_spanish_tokenizer(prompt, return_tensors='pt').to(device)
        outputs = gpt2_spanish_model.generate(
            inputs['input_ids'],
            max_length=inputs['input_ids'].shape[1] + 50,
            num_return_sequences=1,
            pad_token_id=gpt2_spanish_tokenizer.eos_token_id,
            do_sample=True,
            top_k=40,
            temperature=0.8
        )
        answer = gpt2_spanish_tokenizer.decode(outputs[0], skip_special_tokens=True)
        answer = answer.split("Respuesta:")[-1].strip()
        return f"<span style='color: orange; font-weight: bold;'>{answer}</span>"
    except Exception as e:
        return handle_error_message(e)

def bloom_question_answering(context, question):
    try:
        prompt = f"Contexto:\n{context}\n\nPregunta:\n{question}\n\nRespuesta:"
        inputs = bloom_tokenizer(prompt, return_tensors='pt').to(device)
        outputs = bloom_model.generate(
            inputs['input_ids'],
            max_length=inputs['input_ids'].shape[1] + 50,
            num_return_sequences=1,
            pad_token_id=bloom_tokenizer.eos_token_id,
            do_sample=True,
            top_k=40,
            temperature=0.8
        )
        answer = bloom_tokenizer.decode(outputs[0], skip_special_tokens=True)
        answer = answer.split("Respuesta:")[-1].strip()
        return f"<span style='color: purple; font-weight: bold;'>{answer}</span>"
    except Exception as e:
        return handle_error_message(e)

# Main function for Spanish QA
def compare_question_answering_spanish(context, question):
    confli_answer_spanish = question_answering_spanish(context, question)
    beto_answer_spanish = beto_question_answering_spanish(context, question)
    confli_sqac_answer_spanish = confli_sqac_question_answering_spanish(context, question)
    beto_sqac_answer_spanish = beto_sqac_question_answering_spanish(context, question)
    gpt2_answer_spanish = gpt2_spanish_question_answering(context, question)
    bloom_answer = bloom_question_answering(context, question)
    return f"""
    <div>
        <h2 style='color: #2e8b57; font-weight: bold;'>Respuestas:</h2>
    </div><br>
    <div>
        <strong>ConfliBERT-Spanish-Beto-Cased-NewsQA:</strong><br>{confli_answer_spanish}
    </div><br>
    <div>
        <strong>Beto-Spanish-Cased-NewsQA:</strong><br>{beto_answer_spanish}
    </div><br>
    <div>
        <strong>ConfliBERT-Spanish-Beto-Cased-SQAC:</strong><br>{confli_sqac_answer_spanish}
    </div><br>
    <div>
        <strong>Beto-Spanish-Cased-SQAC:</strong><br>{beto_sqac_answer_spanish}
    </div><br>
    <div>
        <strong>GPT-2-Small-Spanish:</strong><br>{gpt2_answer_spanish}
    </div><br>
    <div>
        <strong>BLOOM-1.7B:</strong><br>{bloom_answer}
    </div><br>
    <div>
        <strong>Información del modelo:</strong><br>
        ConfliBERT-Spanish-Beto-Cased-NewsQA: <a href='https://huggingface.co/salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA' target='_blank'>salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA</a><br>
        Beto-Spanish-Cased-NewsQA: <a href='https://huggingface.co/salsarra/Beto-Spanish-Cased-NewsQA' target='_blank'>salsarra/Beto-Spanish-Cased-NewsQA</a><br>
        ConfliBERT-Spanish-Beto-Cased-SQAC: <a href='https://huggingface.co/salsarra/ConfliBERT-Spanish-Beto-Cased-SQAC' target='_blank'>salsarra/ConfliBERT-Spanish-Beto-Cased-SQAC</a><br>
        Beto-Spanish-Cased-SQAC: <a href='https://huggingface.co/salsarra/Beto-Spanish-Cased-SQAC' target='_blank'>salsarra/Beto-Spanish-Cased-SQAC</a><br>
        GPT-2-Small-Spanish: <a href='https://huggingface.co/datificate/gpt2-small-spanish' target='_blank'>datificate GPT-2 Small Spanish</a><br>
        BLOOM-1.7B: <a href='https://huggingface.co/bigscience/bloom-1b7' target='_blank'>bigscience BLOOM-1.7B</a><br>
    </div>
    """

# CSS for Gradio interface
css_styles = """
body {
    background-color: #f0f8ff;
    font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
}
h1 a {
    color: #2e8b57;
    text-align: center;
    font-size: 2em;
    text-decoration: none;
}
h1 a:hover {
    color: #ff8c00;
}
h2 {
    color: #ff8c00;
    text-align: center;
    font-size: 1.5em;
}
.description-light {
    color: black;
    display: block;
    font-size: 1em;
    text-align: center;
}
.description-dark {
    color: white;
    display: none;
    font-size: 1em;
    text-align: center;
}
@media (prefers-color-scheme: dark) {
    .description-light {
        display: none;
    }
    .description-dark {
        display: block;
    }
}
.footer {
    text-align: center;
    margin-top: 10px;
    font-size: 0.9em;
    color: #666;
    width: 100%;
}
.footer a {
    color: #2e8b57;
    font-weight: bold;
    text-decoration: none;
}
.footer a:hover {
    text-decoration: underline;
}
"""

# Define the Gradio interface with footer directly in the layout
demo = gr.Interface(
    fn=compare_question_answering_spanish,
    inputs=[
        gr.Textbox(lines=5, placeholder="Ingrese el contexto aquí...", label="Contexto"),
        gr.Textbox(lines=2, placeholder="Ingrese su pregunta aquí...", label="Pregunta")
    ],
    outputs=gr.HTML(label="Salida"),
    title="<a href='https://eventdata.utdallas.edu/conflibert/' target='_blank'>ConfliBERT-Spanish-QA</a>",
    description="""
    <span class="description-light">Compare respuestas entre los modelos ConfliBERT, BETO, ConfliBERT SQAC, Beto SQAC, GPT-2 Small Spanish y BLOOM-1.7B para preguntas en español.</span>
    <span class="description-dark">Compare respuestas entre los modelos ConfliBERT, BETO, ConfliBERT SQAC, Beto SQAC, GPT-2 Small Spanish y BLOOM-1.7B para preguntas en español.</span>
    """,
    css=css_styles,
    allow_flagging="never",
    # Footer HTML with centered, green links
    article="""
    <div class='footer'>
        <a href='https://eventdata.utdallas.edu/' style='color: #2e8b57; font-weight: bold;'>UTD Event Data</a> |
        <a href='https://www.utdallas.edu/' style='color: #2e8b57; font-weight: bold;'>University of Texas at Dallas</a>
    </div>
    <div class='footer'>
        Developed By: <a href='https://www.linkedin.com/in/sultan-alsarra-phd-56977a63/' target='_blank' style='color: #2e8b57; font-weight: bold;'>Sultan Alsarra</a>
    </div>
    """
)

# Launch the Gradio demo
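# (share=True asks Gradio for a temporary public gradio.live link; on a
# Hugging Face Space the app is already served publicly, so the flag is
# effectively a no-op there.)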
demo.launch(share=True)