import torch
import tensorflow as tf
from tf_keras import models, layers
from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering, AutoModelForCausalLM
import gradio as gr
import re
import os

# Check if a GPU is available and use it if possible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
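# Note: `device` only routes the PyTorch generative models (GPT-2, BLOOM)
# loaded below; the TensorFlow QA models place themselves on a device
# automatically and ignore this setting.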
# Version information (display names; kept for reference, not otherwise used)
confli_version_spanish = 'ConfliBERT-Spanish-Beto-Cased-NewsQA'
beto_version_spanish = 'Beto-Spanish-Cased-NewsQA'
gpt2_spanish_version = 'GPT-2-Small-Spanish'
bloom_spanish_version = 'BLOOM-1.7B'
beto_sqac_version_spanish = 'Beto-Spanish-Cased-SQAC'
# Load Spanish QA models and tokenizers
confli_model_spanish = 'salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA'
confli_model_spanish_qa = TFAutoModelForQuestionAnswering.from_pretrained(confli_model_spanish)
confli_tokenizer_spanish = AutoTokenizer.from_pretrained(confli_model_spanish)

beto_model_spanish = 'salsarra/Beto-Spanish-Cased-NewsQA'
beto_model_spanish_qa = TFAutoModelForQuestionAnswering.from_pretrained(beto_model_spanish)
beto_tokenizer_spanish = AutoTokenizer.from_pretrained(beto_model_spanish)

beto_sqac_model_spanish = 'salsarra/Beto-Spanish-Cased-SQAC'
beto_sqac_model_spanish_qa = TFAutoModelForQuestionAnswering.from_pretrained(beto_sqac_model_spanish)
beto_sqac_tokenizer_spanish = AutoTokenizer.from_pretrained(beto_sqac_model_spanish)

# Load Spanish GPT-2 model and tokenizer
gpt2_spanish_model_name = 'datificate/gpt2-small-spanish'
gpt2_spanish_tokenizer = AutoTokenizer.from_pretrained(gpt2_spanish_model_name)
gpt2_spanish_model = AutoModelForCausalLM.from_pretrained(gpt2_spanish_model_name).to(device)

# Load BLOOM-1.7B model and tokenizer for Spanish
bloom_model_name = 'bigscience/bloom-1b7'
bloom_tokenizer = AutoTokenizer.from_pretrained(bloom_model_name)
bloom_model = AutoModelForCausalLM.from_pretrained(bloom_model_name).to(device)
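# The three TF QA checkpoints above share one load pattern; a small helper
# could factor out the repetition. A sketch only -- `load_tf_qa` is
# hypothetical and not used elsewhere in this app:
#
# def load_tf_qa(name):
#     return (TFAutoModelForQuestionAnswering.from_pretrained(name),
#             AutoTokenizer.from_pretrained(name))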
def handle_error_message(e, default_limit=512):
    """Turn common length-related model errors into a user-facing HTML message."""
    error_message = str(e)
    pattern = re.compile(r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)")
    match = pattern.search(error_message)
    if match:
        number_1, number_2 = match.groups()
        return f"<span style='color: red; font-weight: bold;'>Error: Input text is over the limit: its size ({number_1}) exceeds the model limit of {number_2} tokens.</span>"
    pattern_qa = re.compile(r"indices\[0,(\d+)\] = \d+ is not in \[0, (\d+)\)")
    match_qa = pattern_qa.search(error_message)
    if match_qa:
        number_1, number_2 = match_qa.groups()
        return f"<span style='color: red; font-weight: bold;'>Error: Input text is over the limit: its size ({number_1}) exceeds the model limit of {number_2} tokens.</span>"
    return f"<span style='color: red; font-weight: bold;'>Error: Input text is over the limit: it exceeds the model limit of {default_limit} tokens.</span>"
# Spanish extractive QA functions
def question_answering_spanish(context, question):
    try:
        inputs = confli_tokenizer_spanish(question, context, return_tensors='tf', truncation=True)
        outputs = confli_model_spanish_qa(inputs)
        # Take the most likely start and end token positions of the answer span
        answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
        answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
        # Decode that token span back into text
        answer = confli_tokenizer_spanish.convert_tokens_to_string(
            confli_tokenizer_spanish.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end])
        )
        return f"<span style='color: green; font-weight: bold;'>{answer}</span>"
    except Exception as e:
        return handle_error_message(e)
def beto_question_answering_spanish(context, question):
    try:
        inputs = beto_tokenizer_spanish(question, context, return_tensors='tf', truncation=True)
        outputs = beto_model_spanish_qa(inputs)
        answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
        answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
        answer = beto_tokenizer_spanish.convert_tokens_to_string(
            beto_tokenizer_spanish.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end])
        )
        return f"<span style='color: blue; font-weight: bold;'>{answer}</span>"
    except Exception as e:
        return handle_error_message(e)
def beto_sqac_question_answering_spanish(context, question):
    try:
        inputs = beto_sqac_tokenizer_spanish(question, context, return_tensors='tf', truncation=True)
        outputs = beto_sqac_model_spanish_qa(inputs)
        answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
        answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
        answer = beto_sqac_tokenizer_spanish.convert_tokens_to_string(
            beto_sqac_tokenizer_spanish.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end])
        )
        return f"<span style='color: brown; font-weight: bold;'>{answer}</span>"
    except Exception as e:
        return handle_error_message(e)
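# The three extractive QA functions above differ only in the model, the
# tokenizer, and the answer color. A generic helper could replace them; the
# sketch below assumes that refactor (`answer_extractive` is not part of the
# original app and is not called anywhere):
def answer_extractive(model, tokenizer, context, question, color):
    try:
        inputs = tokenizer(question, context, return_tensors='tf', truncation=True)
        outputs = model(inputs)
        # Most likely start/end token indices define the answer span
        answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
        answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
        tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end])
        answer = tokenizer.convert_tokens_to_string(tokens)
        return f"<span style='color: {color}; font-weight: bold;'>{answer}</span>"
    except Exception as e:
        return handle_error_message(e)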
# Generative QA functions for Spanish GPT-2 and BLOOM-1.7B
def gpt2_spanish_question_answering(context, question):
    try:
        # Prompt with context and question, then let the model continue after "Respuesta:"
        prompt = f"Contexto:\n{context}\n\nPregunta:\n{question}\n\nRespuesta:"
        inputs = gpt2_spanish_tokenizer(prompt, return_tensors='pt').to(device)
        outputs = gpt2_spanish_model.generate(
            inputs['input_ids'],
            max_length=inputs['input_ids'].shape[1] + 50,  # allow up to 50 new tokens
            num_return_sequences=1,
            pad_token_id=gpt2_spanish_tokenizer.eos_token_id,
            do_sample=True,
            top_k=40,
            temperature=0.8
        )
        answer = gpt2_spanish_tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the text generated after the "Respuesta:" marker
        answer = answer.split("Respuesta:")[-1].strip()
        return f"<span style='color: orange; font-weight: bold;'>{answer}</span>"
    except Exception as e:
        return handle_error_message(e)
def bloom_question_answering(context, question):
    try:
        prompt = f"Contexto:\n{context}\n\nPregunta:\n{question}\n\nRespuesta:"
        inputs = bloom_tokenizer(prompt, return_tensors='pt').to(device)
        outputs = bloom_model.generate(
            inputs['input_ids'],
            max_length=inputs['input_ids'].shape[1] + 50,
            num_return_sequences=1,
            pad_token_id=bloom_tokenizer.eos_token_id,
            do_sample=True,
            top_k=40,
            temperature=0.8
        )
        answer = bloom_tokenizer.decode(outputs[0], skip_special_tokens=True)
        answer = answer.split("Respuesta:")[-1].strip()
        return f"<span style='color: purple; font-weight: bold;'>{answer}</span>"
    except Exception as e:
        return handle_error_message(e)
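# Note: both generative functions sample (do_sample=True), so their answers
# vary between runs. Passing attention_mask=inputs['attention_mask'] to
# generate() would also silence the tokenizer's padding warning -- an
# optional tweak, not part of the original code.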
# Main function for Spanish QA: run all five models and assemble the HTML report
def compare_question_answering_spanish(context, question):
    confli_answer_spanish = question_answering_spanish(context, question)
    beto_answer_spanish = beto_question_answering_spanish(context, question)
    beto_sqac_answer_spanish = beto_sqac_question_answering_spanish(context, question)
    gpt2_answer_spanish = gpt2_spanish_question_answering(context, question)
    bloom_answer = bloom_question_answering(context, question)
    return f"""
    <div>
        <h2 style='color: #2e8b57; font-weight: bold;'>Respuestas:</h2>
    </div><br>
    <div>
        <strong>ConfliBERT-Spanish-Beto-Cased-NewsQA:</strong><br>{confli_answer_spanish}
    </div><br>
    <div>
        <strong>Beto-Spanish-Cased-NewsQA:</strong><br>{beto_answer_spanish}
    </div><br>
    <div>
        <strong>Beto-Spanish-Cased-SQAC:</strong><br>{beto_sqac_answer_spanish}
    </div><br>
    <div>
        <strong>GPT-2-Small-Spanish:</strong><br>{gpt2_answer_spanish}
    </div><br>
    <div>
        <strong>BLOOM-1.7B:</strong><br>{bloom_answer}
    </div><br>
    <div>
        <strong>Información del modelo:</strong><br>
        ConfliBERT-Spanish-Beto-Cased-NewsQA: <a href='https://huggingface.co/salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA' target='_blank'>salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA</a><br>
        Beto-Spanish-Cased-NewsQA: <a href='https://huggingface.co/salsarra/Beto-Spanish-Cased-NewsQA' target='_blank'>salsarra/Beto-Spanish-Cased-NewsQA</a><br>
        Beto-Spanish-Cased-SQAC: <a href='https://huggingface.co/salsarra/Beto-Spanish-Cased-SQAC' target='_blank'>salsarra/Beto-Spanish-Cased-SQAC</a><br>
        GPT-2-Small-Spanish: <a href='https://huggingface.co/datificate/gpt2-small-spanish' target='_blank'>datificate GPT-2 Small Spanish</a><br>
        BLOOM-1.7B: <a href='https://huggingface.co/bigscience/bloom-1b7' target='_blank'>bigscience BLOOM-1.7B</a><br>
    </div>
    """
# Define the CSS for the Gradio interface
css_styles = """
body {
    background-color: #f0f8ff;
    font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
}
h1 a {
    color: #2e8b57;
    text-align: center;
    font-size: 2em;
    text-decoration: none;
}
h1 a:hover {
    color: #ff8c00;
}
h2 {
    color: #ff8c00;
    text-align: center;
    font-size: 1.5em;
}
.gradio-container {
    max-width: 100%;
    margin: 10px auto;
    padding: 10px;
    background-color: #ffffff;
    border-radius: 10px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.gr-input, .gr-output {
    background-color: #ffffff;
    border: 1px solid #ddd;
    border-radius: 5px;
    padding: 10px;
    font-size: 1em;
}
.gr-title {
    font-size: 1.5em;
    font-weight: bold;
    color: #2e8b57;
    margin-bottom: 10px;
    text-align: center;
}
.gr-description {
    font-size: 1.2em;
    color: #ff8c00;
    margin-bottom: 10px;
    text-align: center;
}
.header-title-center a {
    font-size: 4em;
    font-weight: bold;
    color: darkorange;
    text-align: center;
    display: block;
}
.gr-button {
    background-color: #ff8c00;
    color: white;
    border: none;
    padding: 10px 20px;
    font-size: 1em;
    border-radius: 5px;
    cursor: pointer;
}
.gr-button:hover {
    background-color: #ff4500;
}
.footer {
    text-align: center;
    margin-top: 10px;
    font-size: 0.9em;
    color: #666;
    width: 100%;
}
.footer a {
    color: #2e8b57;
    font-weight: bold;
    text-decoration: none;
}
.footer a:hover {
    text-decoration: underline;
}
"""
# Define the Gradio interface
demo = gr.Interface(
    fn=compare_question_answering_spanish,
    inputs=[
        gr.Textbox(lines=5, placeholder="Ingrese el contexto aquí...", label="Contexto"),
        gr.Textbox(lines=2, placeholder="Ingrese su pregunta aquí...", label="Pregunta")
    ],
    outputs=gr.HTML(label="Salida"),
    title="<a href='https://eventdata.utdallas.edu/conflibert/' target='_blank'>ConfliBERT-Spanish-QA</a>",
    description="Compare respuestas entre los modelos ConfliBERT, BETO, Beto SQAC, GPT-2 Small Spanish y BLOOM-1.7B para preguntas en español.",
    css=css_styles,
    allow_flagging="never"
)

# Launch the Gradio demo
demo.launch(share=True)
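# Note: share=True asks Gradio for a temporary public gradio.live link; when
# the app runs as a hosted Hugging Face Space it is already publicly served,
# so the flag typically has no effect there.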