salsarra committed
Commit 8740f66 · verified · 1 Parent(s): 70ca632

Delete appOLD.py

Files changed (1)
  1. appOLD.py +0 -263
appOLD.py DELETED
@@ -1,263 +0,0 @@
- import torch
- import tensorflow as tf
- from tf_keras import models, layers
- from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering, AutoModelForCausalLM
- import gradio as gr
- import re
- import os
-
- # Check if a GPU is available and use it if possible
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
- # Version Information:
- confli_version_spanish = 'ConfliBERT-Spanish-Beto-Cased-NewsQA'
- beto_version_spanish = 'Beto-Spanish-Cased-NewsQA'
- gpt2_spanish_version = 'GPT-2-Small-Spanish'
- bloom_spanish_version = 'BLOOM-1.7B'
- beto_sqac_version_spanish = 'Beto-Spanish-Cased-SQAC'
-
- # Load Spanish models and tokenizers
- confli_model_spanish = 'salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA'
- confli_model_spanish_qa = TFAutoModelForQuestionAnswering.from_pretrained(confli_model_spanish)
- confli_tokenizer_spanish = AutoTokenizer.from_pretrained(confli_model_spanish)
-
- beto_model_spanish = 'salsarra/Beto-Spanish-Cased-NewsQA'
- beto_model_spanish_qa = TFAutoModelForQuestionAnswering.from_pretrained(beto_model_spanish)
- beto_tokenizer_spanish = AutoTokenizer.from_pretrained(beto_model_spanish)
-
- beto_sqac_model_spanish = 'salsarra/Beto-Spanish-Cased-SQAC'
- beto_sqac_model_spanish_qa = TFAutoModelForQuestionAnswering.from_pretrained(beto_sqac_model_spanish)
- beto_sqac_tokenizer_spanish = AutoTokenizer.from_pretrained(beto_sqac_model_spanish)
-
- # Load Spanish GPT-2 model and tokenizer
- gpt2_spanish_model_name = 'datificate/gpt2-small-spanish'
- gpt2_spanish_tokenizer = AutoTokenizer.from_pretrained(gpt2_spanish_model_name)
- gpt2_spanish_model = AutoModelForCausalLM.from_pretrained(gpt2_spanish_model_name).to(device)
-
- # Load BLOOM-1.7B model and tokenizer for Spanish
- bloom_model_name = 'bigscience/bloom-1b7'
- bloom_tokenizer = AutoTokenizer.from_pretrained(bloom_model_name)
- bloom_model = AutoModelForCausalLM.from_pretrained(bloom_model_name).to(device)
-
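Note that the torch device selected at the top of the file applies only to the two PyTorch causal LMs (gpt2_spanish_model and bloom_model); the three TF question-answering models handle GPU placement through TensorFlow itself. A minimal check, assuming standard torch and TensorFlow APIs (not part of the original file):

    print(device)                                  # e.g. cuda or cpu
    print(tf.config.list_physical_devices('GPU'))  # TensorFlow discovers GPUs independently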
- def handle_error_message(e, default_limit=512):
-     error_message = str(e)
-     pattern = re.compile(r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)")
-     match = pattern.search(error_message)
-     if match:
-         number_1, number_2 = match.groups()
-         return f"<span style='color: red; font-weight: bold;'>Error: input length {number_1} exceeds the model limit of {number_2}</span>"
-     pattern_qa = re.compile(r"indices\[0,(\d+)\] = \d+ is not in \[0, (\d+)\)")
-     match_qa = pattern_qa.search(error_message)
-     if match_qa:
-         number_1, number_2 = match_qa.groups()
-         return f"<span style='color: red; font-weight: bold;'>Error: input length {number_1} exceeds the model limit of {number_2}</span>"
-     return f"<span style='color: red; font-weight: bold;'>Error: input length exceeds the model limit of {default_limit}</span>"
-
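A quick illustration of how handle_error_message parses a typical tensor-size mismatch; the input below is hypothetical and not part of the original file:

    # Simulate the error raised when an over-length input hits a 512-token model.
    sample = ValueError("The size of tensor a (612) must match the size of tensor b (512)")
    print(handle_error_message(sample))
    # -> <span style='color: red; font-weight: bold;'>Error: input length 612 exceeds the model limit of 512</span>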
- # Spanish QA functions
- def question_answering_spanish(context, question):
-     try:
-         inputs = confli_tokenizer_spanish(question, context, return_tensors='tf', truncation=True)
-         outputs = confli_model_spanish_qa(inputs)
-         answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
-         answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
-         answer = confli_tokenizer_spanish.convert_tokens_to_string(confli_tokenizer_spanish.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end]))
-         return f"<span style='color: green; font-weight: bold;'>{answer}</span>"
-     except Exception as e:
-         return handle_error_message(e)
-
- def beto_question_answering_spanish(context, question):
-     try:
-         inputs = beto_tokenizer_spanish(question, context, return_tensors='tf', truncation=True)
-         outputs = beto_model_spanish_qa(inputs)
-         answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
-         answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
-         answer = beto_tokenizer_spanish.convert_tokens_to_string(beto_tokenizer_spanish.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end]))
-         return f"<span style='color: blue; font-weight: bold;'>{answer}</span>"
-     except Exception as e:
-         return handle_error_message(e)
-
- def beto_sqac_question_answering_spanish(context, question):
-     try:
-         inputs = beto_sqac_tokenizer_spanish(question, context, return_tensors='tf', truncation=True)
-         outputs = beto_sqac_model_spanish_qa(inputs)
-         answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
-         answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
-         answer = beto_sqac_tokenizer_spanish.convert_tokens_to_string(beto_sqac_tokenizer_spanish.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end]))
-         return f"<span style='color: brown; font-weight: bold;'>{answer}</span>"
-     except Exception as e:
-         return handle_error_message(e)
-
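The three extractive QA functions above differ only in the model, tokenizer, and highlight color they use. A shared helper along these lines (a sketch, not part of the original file) would remove the duplication:

    def extractive_answer(model, tokenizer, context, question, color):
        # Same TF span decoding as the three functions above, parameterized.
        try:
            inputs = tokenizer(question, context, return_tensors='tf', truncation=True)
            outputs = model(inputs)
            start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
            end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
            tokens = inputs['input_ids'].numpy()[0][start:end]
            answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(tokens))
            return f"<span style='color: {color}; font-weight: bold;'>{answer}</span>"
        except Exception as e:
            return handle_error_message(e)

With it, question_answering_spanish reduces to extractive_answer(confli_model_spanish_qa, confli_tokenizer_spanish, context, question, 'green'), and likewise for the other two.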
- # Functions for Spanish GPT-2 and BLOOM-1.7B models
- def gpt2_spanish_question_answering(context, question):
-     try:
-         prompt = f"Contexto:\n{context}\n\nPregunta:\n{question}\n\nRespuesta:"
-         inputs = gpt2_spanish_tokenizer(prompt, return_tensors='pt').to(device)
-         outputs = gpt2_spanish_model.generate(
-             inputs['input_ids'],
-             max_length=inputs['input_ids'].shape[1] + 50,
-             num_return_sequences=1,
-             pad_token_id=gpt2_spanish_tokenizer.eos_token_id,
-             do_sample=True,
-             top_k=40,
-             temperature=0.8
-         )
-         answer = gpt2_spanish_tokenizer.decode(outputs[0], skip_special_tokens=True)
-         answer = answer.split("Respuesta:")[-1].strip()
-         return f"<span style='color: orange; font-weight: bold;'>{answer}</span>"
-     except Exception as e:
-         return handle_error_message(e)
-
- def bloom_question_answering(context, question):
-     try:
-         prompt = f"Contexto:\n{context}\n\nPregunta:\n{question}\n\nRespuesta:"
-         inputs = bloom_tokenizer(prompt, return_tensors='pt').to(device)
-         outputs = bloom_model.generate(
-             inputs['input_ids'],
-             max_length=inputs['input_ids'].shape[1] + 50,
-             num_return_sequences=1,
-             pad_token_id=bloom_tokenizer.eos_token_id,
-             do_sample=True,
-             top_k=40,
-             temperature=0.8
-         )
-         answer = bloom_tokenizer.decode(outputs[0], skip_special_tokens=True)
-         answer = answer.split("Respuesta:")[-1].strip()
-         return f"<span style='color: purple; font-weight: bold;'>{answer}</span>"
-     except Exception as e:
-         return handle_error_message(e)
-
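Both generative functions rely on the model echoing the prompt: the decoded text still begins with the Contexto/Pregunta/Respuesta scaffold, so splitting on the last "Respuesta:" isolates the generated continuation. A minimal illustration with hypothetical strings (not part of the original file):

    decoded = "Contexto:\nBogotá es la capital de Colombia.\n\nPregunta:\n¿Cuál es la capital?\n\nRespuesta: Bogotá."
    print(decoded.split("Respuesta:")[-1].strip())  # -> Bogotá.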
- # Main function for Spanish QA
- def compare_question_answering_spanish(context, question):
-     confli_answer_spanish = question_answering_spanish(context, question)
-     beto_answer_spanish = beto_question_answering_spanish(context, question)
-     beto_sqac_answer_spanish = beto_sqac_question_answering_spanish(context, question)
-     gpt2_answer_spanish = gpt2_spanish_question_answering(context, question)
-     bloom_answer = bloom_question_answering(context, question)
-     return f"""
-     <div>
-         <h2 style='color: #2e8b57; font-weight: bold;'>Respuestas:</h2>
-     </div><br>
-     <div>
-         <strong>ConfliBERT-Spanish-Beto-Cased-NewsQA:</strong><br>{confli_answer_spanish}
-     </div><br>
-     <div>
-         <strong>Beto-Spanish-Cased-NewsQA:</strong><br>{beto_answer_spanish}
-     </div><br>
-     <div>
-         <strong>Beto-Spanish-Cased-SQAC:</strong><br>{beto_sqac_answer_spanish}
-     </div><br>
-     <div>
-         <strong>GPT-2-Small-Spanish:</strong><br>{gpt2_answer_spanish}
-     </div><br>
-     <div>
-         <strong>BLOOM-1.7B:</strong><br>{bloom_answer}
-     </div><br>
-     <div>
-         <strong>Información del modelo:</strong><br>
-         ConfliBERT-Spanish-Beto-Cased-NewsQA: <a href='https://huggingface.co/salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA' target='_blank'>salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA</a><br>
-         Beto-Spanish-Cased-NewsQA: <a href='https://huggingface.co/salsarra/Beto-Spanish-Cased-NewsQA' target='_blank'>salsarra/Beto-Spanish-Cased-NewsQA</a><br>
-         Beto-Spanish-Cased-SQAC: <a href='https://huggingface.co/salsarra/Beto-Spanish-Cased-SQAC' target='_blank'>salsarra/Beto-Spanish-Cased-SQAC</a><br>
-         GPT-2-Small-Spanish: <a href='https://huggingface.co/datificate/gpt2-small-spanish' target='_blank'>datificate GPT-2 Small Spanish</a><br>
-         BLOOM-1.7B: <a href='https://huggingface.co/bigscience/bloom-1b7' target='_blank'>bigscience BLOOM-1.7B</a><br>
-     </div>
-     """
-
- # Define the CSS for the Gradio interface
- css_styles = """
- body {
-     background-color: #f0f8ff;
-     font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
- }
- h1 a {
-     color: #2e8b57;
-     text-align: center;
-     font-size: 2em;
-     text-decoration: none;
- }
- h1 a:hover {
-     color: #ff8c00;
- }
- h2 {
-     color: #ff8c00;
-     text-align: center;
-     font-size: 1.5em;
- }
- .gradio-container {
-     max-width: 100%;
-     margin: 10px auto;
-     padding: 10px;
-     background-color: #ffffff;
-     border-radius: 10px;
-     box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
- }
- .gr-input, .gr-output {
-     background-color: #ffffff;
-     border: 1px solid #ddd;
-     border-radius: 5px;
-     padding: 10px;
-     font-size: 1em;
- }
- .gr-title {
-     font-size: 1.5em;
-     font-weight: bold;
-     color: #2e8b57;
-     margin-bottom: 10px;
-     text-align: center;
- }
- .gr-description {
-     font-size: 1.2em;
-     color: #ff8c00;
-     margin-bottom: 10px;
-     text-align: center;
- }
- .header-title-center a {
-     font-size: 4em;
-     font-weight: bold;
-     color: darkorange;
-     text-align: center;
-     display: block;
- }
- .gr-button {
-     background-color: #ff8c00;
-     color: white;
-     border: none;
-     padding: 10px 20px;
-     font-size: 1em;
-     border-radius: 5px;
-     cursor: pointer;
- }
- .gr-button:hover {
-     background-color: #ff4500;
- }
- .footer {
-     text-align: center;
-     margin-top: 10px;
-     font-size: 0.9em;
-     color: #666;
-     width: 100%;
- }
- .footer a {
-     color: #2e8b57;
-     font-weight: bold;
-     text-decoration: none;
- }
- .footer a:hover {
-     text-decoration: underline;
- }
- """
-
- # Define the Gradio interface
- demo = gr.Interface(
-     fn=compare_question_answering_spanish,
-     inputs=[
-         gr.Textbox(lines=5, placeholder="Ingrese el contexto aquí...", label="Contexto"),
-         gr.Textbox(lines=2, placeholder="Ingrese su pregunta aquí...", label="Pregunta")
-     ],
-     outputs=gr.HTML(label="Salida"),
-     title="<a href='https://eventdata.utdallas.edu/conflibert/' target='_blank'>ConfliBERT-Spanish-QA</a>",
-     description="Compare respuestas entre los modelos ConfliBERT, BETO, Beto SQAC, GPT-2 Small Spanish y BLOOM-1.7B para preguntas en español.",
-     css=css_styles,
-     allow_flagging="never"
- )
-
- # Launch the Gradio demo
- demo.launch(share=True)
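On the final line, share=True asks Gradio to open a temporary public tunnel in addition to serving locally; when the app is hosted on Hugging Face Spaces the Space URL already makes it public, so a plain launch is the usual alternative (a sketch, not the original call):

    demo.launch()  # serve locally (or via the hosting Space) without a share tunnel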