import torch
import gradio as gr
from googletrans import Translator
from transformers import (
    AutoTokenizer,
    BartForConditionalGeneration,
    BartTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)
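# Note: T5Tokenizer (the slow tokenizer) also requires the sentencepiece package to be installed.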
# End-to-end question generation pipeline
class E2EQGPipeline:
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer
    ):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Move the model to the same device the inputs are sent to in generate_questions
        self.model = model.to(self.device)
        self.tokenizer = tokenizer
        self.model_type = "t5"
        self.kwargs = {
            "max_length": 256,
            "num_beams": 4,
            "length_penalty": 1.5,
            "no_repeat_ngram_size": 3,
            "early_stopping": True,
        }
    def generate_questions(self, context: str):
        inputs = self._prepare_inputs_for_e2e_qg(context)
        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            **self.kwargs
        )
        prediction = self.tokenizer.decode(outs[0], skip_special_tokens=True)
        # The model separates questions with "<sep>"; the final split element is an empty trailer
        questions = prediction.split("<sep>")
        questions = [question.strip() for question in questions[:-1]]
        return questions
    def _prepare_inputs_for_e2e_qg(self, context):
        source_text = f"generate questions: {context}"
        inputs = self._tokenize([source_text], padding=False)
        return inputs
    def _tokenize(
        self,
        inputs,
        padding=True,
        truncation=True,
        add_special_tokens=True,
        max_length=512
    ):
        inputs = self.tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            return_tensors="pt"
        )
        return inputs
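# Question generation checkpoint (generates English questions)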
qg_model = T5ForConditionalGeneration.from_pretrained('valhalla/t5-base-e2e-qg')
qg_tokenizer = T5Tokenizer.from_pretrained('valhalla/t5-base-e2e-qg')
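# Generate English questions from the text, then translate them to Spanish.
# Assumes a googletrans release with a synchronous Translator.translate (e.g. 4.0.0rc1).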
def generate_questions(text):
    qg_final_model = E2EQGPipeline(qg_model, qg_tokenizer)
    questions = qg_final_model.generate_questions(text)
    translator = Translator()
    translated_questions = [translator.translate(question, dest='es').text for question in questions]
    return translated_questions
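# Summarization model and tokenizer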
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
def generate_summary(text):
    # BART was not trained with a task prefix, so the text is passed in directly
    inputs = tokenizer.encode(text, return_tensors="pt", max_length=1024, truncation=True)
    summary_ids = model.generate(inputs, max_length=150, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary
# Question answering model (Spanish, extractive QA fine-tuned on SQAC)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ckpt = 'mrm8488/spanish-t5-small-sqac-for-qa'
qa_tokenizer = AutoTokenizer.from_pretrained(ckpt)
qa_model = T5ForConditionalGeneration.from_pretrained(ckpt).to(device)
def generate_question_response(question, context):
    input_text = f'question: {question} context: {context}'
    print(input_text)
    features = qa_tokenizer([input_text], padding='max_length', truncation=True, max_length=512, return_tensors='pt')
    output = qa_model.generate(
        input_ids=features['input_ids'].to(device),
        attention_mask=features['attention_mask'].to(device)
    )
    return qa_tokenizer.decode(output[0], skip_special_tokens=True)
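# App state: keeps the pasted text so the question and chat tabs can reuse it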
class SummarizerAndQA:
    def __init__(self):
        self.input_text = ''
        self.question = ''
        self.summary = ''
        self.study_generated_questions = ''
        self.question_response = ''

    def is_text_loaded(self):
        return self.input_text != ''

    def process_summarizer(self, text):
        self.input_text = text
        return generate_summary(text)

    def process_questions(self):
        if not self.is_text_loaded():
            return "Primero ingresa el texto en la sección de Resumen"
        return generate_questions(self.input_text)

    def process_question_response(self, question, history):
        if not self.is_text_loaded():
            return "Primero ingresa el texto en la sección de Resumen"
        return generate_question_response(question, self.input_text)
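# Gradio UI: one tab each for the summary, the study questions, and chat-based QA over the text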
summarizer_and_qa = SummarizerAndQA()

textbox_input = gr.Textbox(label="Pega el texto acá:", placeholder="Texto...", lines=15)
summary_output = gr.Textbox(label="Resumen", lines=15)
questions_output = gr.Textbox(label="Preguntas de guía generadas", lines=5)

summarizer_interface = gr.Interface(fn=summarizer_and_qa.process_summarizer, inputs=[textbox_input], outputs=[summary_output], allow_flagging="never")
questions_interface = gr.Interface(fn=summarizer_and_qa.process_questions, inputs=[], outputs=[questions_output], allow_flagging="never")
chatbot_interface = gr.ChatInterface(fn=summarizer_and_qa.process_question_response, type="messages", examples=[], title="Preguntas sobre el texto")

gr.TabbedInterface([summarizer_interface, questions_interface, chatbot_interface], ["Resumidor", "Preguntas de guía", "Chatbot"]).launch()