import torch
import gradio as gr
from googletrans import Translator
from transformers import (
    AutoTokenizer,
    BartForConditionalGeneration,
    BartTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)


# End-to-end question generation pipeline: given a context passage, the T5
# model produces several questions in one pass, separated by "<sep>" tokens.
class E2EQGPipeline:
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer
    ):

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Move the model to the selected device so it matches the input
        # tensors built in generate_questions().
        self.model = model.to(self.device)
        self.tokenizer = tokenizer

        self.model_type = "t5"

        # Beam-search decoding settings passed to model.generate().
        self.kwargs = {
            "max_length": 256,
            "num_beams": 4,
            "length_penalty": 1.5,
            "no_repeat_ngram_size": 3,
            "early_stopping": True,
        }

    def generate_questions(self, context: str):
        inputs = self._prepare_inputs_for_e2e_qg(context)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            **self.kwargs
        )

        prediction = self.tokenizer.decode(outs[0], skip_special_tokens=True)

        # The model emits "q1 <sep> q2 <sep> ... <sep>", so the last element of
        # the split is empty and is dropped.
        questions = prediction.split("<sep>")
        questions = [question.strip() for question in questions[:-1]]
        return questions

    def _prepare_inputs_for_e2e_qg(self, context):
        source_text = f"generate questions: {context}"

        inputs = self._tokenize([source_text], padding=False)

        return inputs

    def _tokenize(
        self,
        inputs,
        padding=True,
        truncation=True,
        add_special_tokens=True,
        max_length=512
    ):
        inputs = self.tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            return_tensors="pt"
        )

        return inputs


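# Pretrained T5 model fine-tuned for end-to-end question generation (English).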
qg_model = T5ForConditionalGeneration.from_pretrained('valhalla/t5-base-e2e-qg')
qg_tokenizer = T5Tokenizer.from_pretrained('valhalla/t5-base-e2e-qg')


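# Generate questions in English from the input text, then translate each one
# to Spanish with googletrans.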
def generate_questions(text):
    qg_final_model = E2EQGPipeline(qg_model, qg_tokenizer)
    questions = qg_final_model.generate_questions(text)
    translator = Translator()
    translated_questions = [translator.translate(question, dest='es').text for question in questions]
    return translated_questions


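# BART model fine-tuned on CNN/DailyMail, used for abstractive summarization.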
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')


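# Summarize the input text with beam search, keeping the summary between
# 50 and 150 tokens.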
def generate_summary(text):
    inputs = tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=1024, truncation=True)
    summary_ids = model.generate(inputs, max_length=150, min_length=50, length_penalty=2.0, num_beams=4, early_stopping=True)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary


# Question answering: Spanish T5 model fine-tuned on the SQAC dataset.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ckpt = 'mrm8488/spanish-t5-small-sqac-for-qa'
qa_tokenizer = AutoTokenizer.from_pretrained(ckpt)
qa_model = T5ForConditionalGeneration.from_pretrained(ckpt).to(device)


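# Answer a question about the given context with the Spanish QA model.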
def generate_question_response(question, context):
    input_text = 'question: %s  context: %s' % (question, context)
    print(input_text)
    features = qa_tokenizer([input_text], padding='max_length', truncation=True, max_length=512, return_tensors='pt')
    output = qa_model.generate(
        input_ids=features['input_ids'].to(device),
        attention_mask=features['attention_mask'].to(device),
        temperature=1.0
    )
    return qa_tokenizer.decode(output[0], skip_special_tokens=True)


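# Holds the shared state of the app: the pasted text is stored when the
# summary tab runs, so the questions and chatbot tabs can reuse it.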
class SummarizerAndQA:
    def __init__(self):
        self.input_text = ''
        self.question = ''

        self.summary = ''
        self.study_generated_questions = ''
        self.question_response = ''

    def is_text_loaded(self):
        return self.input_text != ''

    def process_summarizer(self, text):
        self.input_text = text
        return generate_summary(text)

    def process_questions(self):
        if not self.is_text_loaded():
            return "Primero ingresa el texto en la seccion de Resumen"
        # Join the questions so they render as plain text in the output Textbox.
        return "\n".join(generate_questions(self.input_text))

    def process_question_response(self, question, history):
        if not self.is_text_loaded():
            return "Primero ingresa el texto en la seccion de Resumen"
        return generate_question_response(question, self.input_text)

summarizer_and_qa = SummarizerAndQA()

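# Gradio UI: three tabs share the same SummarizerAndQA instance.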
textbox_input = gr.Textbox(label="Pega el texto acá:", placeholder="Texto...", lines=15)
summary_output = gr.Textbox(label="Resumen", lines=15)
questions_output = gr.Textbox(label="Preguntas de guia generadas", lines=5)
# The button stays enabled; process_questions() itself handles the case where
# no text has been loaded yet.
questions_generate_button = gr.Button("Generate", variant="primary")

summarizer_interface = gr.Interface(fn=summarizer_and_qa.process_summarizer, inputs=[textbox_input], outputs=[summary_output], allow_flagging="never")
questions_interface = gr.Interface(fn=summarizer_and_qa.process_questions, inputs=[], outputs=[questions_output], allow_flagging="never", submit_btn=questions_generate_button, live=True)
chatbot_interface = gr.ChatInterface(fn=summarizer_and_qa.process_question_response, type="messages", examples=[], title="Preguntas sobre el texto")
gr.TabbedInterface([summarizer_interface, questions_interface, chatbot_interface], ["Resumidor", "Preguntas de guia", "Chatbot"]).launch()