# Spaces: Sleeping  (page-status text captured along with the source; not part of the program)
import html
import os
import re

import openai
import torch
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering
import gradio as gr
# The OpenAI API key is taken from the OPENAI_API_KEY environment variable.
openai.api_key = os.getenv("OPENAI_API_KEY")

# Pick the CUDA device when torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


def _load_qa_pair(checkpoint):
    """Load the question-answering model and tokenizer for *checkpoint*.

    The model is loaded first, then the tokenizer; both are returned as a
    ``(model, tokenizer)`` tuple.
    """
    model = TFAutoModelForQuestionAnswering.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return model, tokenizer


# English checkpoints
qa_model_name_v1 = 'salsarra/ConfliBERT-QA'
qa_model_v1, qa_tokenizer_v1 = _load_qa_pair(qa_model_name_v1)

bert_model_name_v1 = 'salsarra/BERT-base-cased-SQuAD-v1'
bert_qa_model_v1, bert_qa_tokenizer_v1 = _load_qa_pair(bert_model_name_v1)

# Spanish checkpoints (NewsQA)
confli_model_spanish = 'salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA'
confli_model_spanish_qa, confli_tokenizer_spanish = _load_qa_pair(confli_model_spanish)

beto_model_spanish = 'salsarra/Beto-Spanish-Cased-NewsQA'
beto_model_spanish_qa, beto_tokenizer_spanish = _load_qa_pair(beto_model_spanish)

# Spanish checkpoints (SQAC)
confli_sqac_model_spanish = 'salsarra/ConfliBERT-Spanish-Beto-Cased-SQAC'
confli_sqac_model_spanish_qa, confli_sqac_tokenizer_spanish = _load_qa_pair(confli_sqac_model_spanish)

beto_sqac_model_spanish = 'salsarra/Beto-Spanish-Cased-SQAC'
beto_sqac_model_spanish_qa, beto_sqac_tokenizer_spanish = _load_qa_pair(beto_sqac_model_spanish)
# Error formatting: input-size overflows get a dedicated message, everything
# else is reported verbatim.

# Shape-mismatch message, e.g.
# "The size of tensor a (600) must match the size of tensor b (512)".
_TENSOR_SIZE_MISMATCH_RE = re.compile(
    r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)"
)

# Out-of-range index message, e.g. "indices[0,600] = 30000 is not in [0, 512)".
_INDEX_OUT_OF_RANGE_RE = re.compile(
    r"indices\[0,(\d+)\] = \d+ is not in \[0, (\d+)\)"
)


def handle_error_message(e, default_limit=512):
    """Turn an exception into a red, bold HTML error snippet.

    Two message shapes are recognised as "input too long" errors and are
    rewritten into a friendlier explanation built from the two numbers they
    contain. Any other message is shown as-is, HTML-escaped so that markup
    characters in the exception text are displayed rather than interpreted.

    Args:
        e: The exception (or any object) whose ``str()`` is the error text.
        default_limit: Unused; kept so existing callers that pass it keep
            working.

    Returns:
        An HTML ``<span>`` string describing the error.
    """
    error_message = str(e)

    # Both overflow formats carry exactly two capture groups:
    # (reported input size/position, model limit).
    for pattern in (_TENSOR_SIZE_MISMATCH_RE, _INDEX_OUT_OF_RANGE_RE):
        match = pattern.search(error_message)
        if match:
            number_1, number_2 = match.groups()
            return f"<span style='color: red; font-weight: bold;'>Error: Text Input is over limit where inserted text size {number_1} is larger than model limits of {number_2}</span>"

    # Unknown error: escape it because the result is rendered as HTML.
    return f"<span style='color: red; font-weight: bold;'>Error: {html.escape(error_message)}</span>"
# Main comparison function with language selection
def compare_question_answering(language, context, question):
    """Run every model for the chosen language and render the answers as HTML.

    Args:
        language: "English" or "Spanish". Any other value (for example
            ``None`` when nothing has been selected, or an empty string)
            yields a short HTML message asking the user to pick a language
            instead of an empty result.
        context: Passage the answer should be extracted from.
        question: Question to ask about ``context``.

    Returns:
        An HTML string with one section per model followed by links to the
        model pages.

    The per-model helpers called here are defined elsewhere in the file.
    NOTE(review): their results are interpolated into the HTML unescaped —
    presumably they return display-ready text or pre-formatted error spans;
    confirm against those helpers.
    """
    if language == "English":
        confli_answer_v1 = question_answering_v1(context, question)
        bert_answer_v1 = bert_question_answering_v1(context, question)
        chatgpt_answer = chatgpt_question_answering(context, question)
        return f"""
<div>
<h2 style='color: #2e8b57; font-weight: bold;'>Answers:</h2>
</div><br>
<div>
<strong>ConfliBERT-cont-cased-SQuAD-v1:</strong><br>{confli_answer_v1}</div><br>
<div>
<strong>BERT-base-cased-SQuAD-v1:</strong><br>{bert_answer_v1}
</div><br>
<div>
<strong>ChatGPT:</strong><br>{chatgpt_answer}
</div><br>
<div>
<strong>Model Information:</strong><br>
ConfliBERT-cont-cased-SQuAD-v1: <a href='https://huggingface.co/salsarra/ConfliBERT-QA' target='_blank'>salsarra/ConfliBERT-QA</a><br>
BERT-base-cased-SQuAD-v1: <a href='https://huggingface.co/salsarra/BERT-base-cased-SQuAD-v1' target='_blank'>salsarra/BERT-base-cased-SQuAD-v1</a><br>
ChatGPT (GPT-3.5 Turbo): <a href='https://platform.openai.com/docs/models/gpt-3-5' target='_blank'>OpenAI API</a><br>
</div>
"""
    elif language == "Spanish":
        confli_answer_spanish = question_answering_spanish(context, question)
        beto_answer_spanish = beto_question_answering_spanish(context, question)
        confli_sqac_answer_spanish = confli_sqac_question_answering_spanish(context, question)
        beto_sqac_answer_spanish = beto_sqac_question_answering_spanish(context, question)
        chatgpt_answer_spanish = chatgpt_question_answering_spanish(context, question)
        return f"""
<div>
<h2 style='color: #2e8b57; font-weight: bold;'>Answers:</h2>
</div><br>
<div>
<strong>ConfliBERT-Spanish-Beto-Cased-NewsQA:</strong><br>{confli_answer_spanish}</div><br>
<div>
<strong>Beto-Spanish-Cased-NewsQA:</strong><br>{beto_answer_spanish}
</div><br>
<div>
<strong>ConfliBERT-Spanish-Beto-Cased-SQAC:</strong><br>{confli_sqac_answer_spanish}
</div><br>
<div>
<strong>Beto-Spanish-Cased-SQAC:</strong><br>{beto_sqac_answer_spanish}
</div><br>
<div>
<strong>ChatGPT:</strong><br>{chatgpt_answer_spanish}
</div><br>
<div>
<strong>Model Information:</strong><br>
ConfliBERT-Spanish-Beto-Cased-NewsQA: <a href='https://huggingface.co/salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA' target='_blank'>salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA</a><br>
Beto-Spanish-Cased-NewsQA: <a href='https://huggingface.co/salsarra/Beto-Spanish-Cased-NewsQA' target='_blank'>salsarra/Beto-Spanish-Cased-NewsQA</a><br>
ConfliBERT-Spanish-Beto-Cased-SQAC: <a href='https://huggingface.co/salsarra/ConfliBERT-Spanish-Beto-Cased-SQAC' target='_blank'>salsarra/ConfliBERT-Spanish-Beto-Cased-SQAC</a><br>
Beto-Spanish-Cased-SQAC: <a href='https://huggingface.co/salsarra/Beto-Spanish-Cased-SQAC' target='_blank'>salsarra/Beto-Spanish-Cased-SQAC</a><br>
ChatGPT (GPT-3.5 Turbo): <a href='https://platform.openai.com/docs/models/gpt-3-5' target='_blank'>OpenAI API</a><br>
</div>
"""
    # No (or an unknown) language selected: say so instead of returning None,
    # which would leave the output panel silently empty.
    return "<span style='color: red; font-weight: bold;'>Error: Please select a language (English or Spanish).</span>"
# Setting up Gradio Blocks interface with footer.
# The button row is created with elem_id="button-row", so its CSS rule must use
# the id selector (#button-row); a class selector (.button-row) never matches.
with gr.Blocks(css="""
body {
background-color: #f0f8ff;
font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
}
h1, h1 a {
color: #2e8b57;
text-align: center;
font-size: 2em;
text-decoration: none;
}
h1 a:hover {
color: #ff8c00;
}
h2 {
color: #ff8c00;
text-align: center;
font-size: 1.5em;
}
.gradio-container {
max-width: 100%;
margin: 10px auto;
padding: 10px;
background-color: #ffffff;
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
#button-row {
display: flex;
justify-content: center;
gap: 10px;
}
""") as demo:
    # Title and one-line description.
    gr.Markdown("# [ConfliBERT-QA](https://eventdata.utdallas.edu/conflibert/)", elem_id="title")
    gr.Markdown("Compare answers between ConfliBERT, BERT, and ChatGPT for English, and ConfliBERT, BETO, ConfliBERT-SQAC, Beto-SQAC, and ChatGPT for Spanish.")

    # Inputs and the HTML panel that receives the rendered comparison.
    language = gr.Dropdown(choices=["English", "Spanish"], label="Select Language")
    context = gr.Textbox(lines=5, placeholder="Enter the context here...", label="Context")
    question = gr.Textbox(lines=2, placeholder="Enter your question here...", label="Question")
    output = gr.HTML(label="Output")

    with gr.Row(elem_id="button-row"):
        clear_btn = gr.Button("Clear")
        submit_btn = gr.Button("Submit")

    submit_btn.click(fn=compare_question_answering, inputs=[language, context, question], outputs=output)
    # Reset the dropdown to None (no selection): "" is not one of its choices.
    clear_btn.click(fn=lambda: (None, "", "", ""), inputs=[], outputs=[language, context, question, output])

    # Footer.
    gr.Markdown("""
<div style="text-align: center; margin-top: 20px;">
Built by: <a href="https://www.linkedin.com/in/sultan-alsarra-phd-56977a63/" target="_blank">Sultan Alsarra</a>
</div>
""")

demo.launch(share=True)