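# Hugging Face Space app: summarizes pasted text with BART (facebook/bart-large-cnn),
# generates questions with an end-to-end T5 QG model (valhalla/t5-base-e2e-qg),
# translates the questions to Spanish with googletrans, and serves both through Gradio.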
import torch
import gradio as gr
from googletrans import Translator
from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

# Summarization model and tokenizer: BART fine-tuned on CNN/DailyMail.
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')

# End-to-end question generation pipeline (wraps a T5 model fine-tuned for QG).
class E2EQGPipeline:
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer
    ):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Keep the model on the same device the inputs are sent to in generate_questions.
        self.model = model.to(self.device)
        self.tokenizer = tokenizer
        self.model_type = "t5"
        # Beam-search settings used for question generation.
        self.kwargs = {
            "max_length": 256,
            "num_beams": 4,
            "length_penalty": 1.5,
            "no_repeat_ngram_size": 3,
            "early_stopping": True,
        }

    def generate_questions(self, context: str):
        inputs = self._prepare_inputs_for_e2e_qg(context)
        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            **self.kwargs
        )
        prediction = self.tokenizer.decode(outs[0], skip_special_tokens=True)
        # The model emits questions separated by "<sep>"; the piece after the
        # trailing separator is empty, so it is dropped.
        questions = prediction.split("<sep>")
        questions = [question.strip() for question in questions[:-1]]
        return questions

    def _prepare_inputs_for_e2e_qg(self, context):
        # "generate questions:" is the task prefix the e2e-qg T5 checkpoint expects.
        source_text = f"generate questions: {context}"
        inputs = self._tokenize([source_text], padding=False)
        return inputs

    def _tokenize(
        self,
        inputs,
        padding=True,
        truncation=True,
        add_special_tokens=True,
        max_length=512
    ):
        inputs = self.tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            return_tensors="pt"
        )
        return inputs


def generate_questions(text):
    # Load the question-generation model; note this happens on every request.
    qg_model = T5ForConditionalGeneration.from_pretrained('valhalla/t5-base-e2e-qg')
    qg_tokenizer = T5Tokenizer.from_pretrained('valhalla/t5-base-e2e-qg')
    qg_final_model = E2EQGPipeline(qg_model, qg_tokenizer)
    questions = qg_final_model.generate_questions(text)
    # Translate the generated English questions to Spanish.
    translator = Translator()
    translated_questions = [translator.translate(question, dest='es').text for question in questions]
    return translated_questions


def generate_summary(text):
    # bart-large-cnn takes the raw text directly; it does not use a T5-style task prefix.
    inputs = tokenizer.encode(text, return_tensors="pt", max_length=1024, truncation=True)
    summary_ids = model.generate(inputs, max_length=150, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary


def process(text):
    return generate_summary(text), generate_questions(text)


# Gradio UI (labels in Spanish, matching the translated output).
textbox_input = gr.Textbox(label="Pega el texto acá:", placeholder="Texto...", lines=15)
summary_output = gr.Textbox(label="Resumen", lines=15)
questions_output = gr.Textbox(label="Preguntas", lines=5)

demo = gr.Interface(fn=process, inputs=textbox_input, outputs=[summary_output, questions_output])
demo.launch()
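
# Note (assumption): this code calls Translator.translate synchronously, so the Space
# presumably pins a googletrans release with a sync API (e.g. 4.0.0rc1); in newer
# googletrans releases translate() is a coroutine and would need to be awaited.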