muchad committed on
Commit 3afe23e · 1 Parent(s): 2b9a9f2

Create app.py

app.py ADDED
import streamlit as st
import itertools
from typing import Dict, Union

import nltk
from nltk import sent_tokenize
nltk.download('punkt')  # sentence tokenizer data used by sent_tokenize

import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)

class QGPipeline:

    def __init__(self):
        # One T5 checkpoint serves both answer extraction and question generation
        self.model = AutoModelForSeq2SeqLM.from_pretrained("muchad/idt5-qa-qg")
        self.tokenizer = AutoTokenizer.from_pretrained("muchad/idt5-qa-qg")
        self.qg_format = "highlight"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.ans_model = self.model
        self.ans_tokenizer = self.tokenizer
        assert self.model.__class__.__name__ == "T5ForConditionalGeneration"
        self.model_type = "t5"

    def __call__(self, inputs: str):
        # Normalize whitespace, extract candidate answers, then generate a question per answer
        inputs = " ".join(inputs.split())
        sents, answers = self._extract_answers(inputs)
        flat_answers = list(itertools.chain(*answers))

        if len(flat_answers) == 0:
            return []

        qg_examples = self._prepare_inputs_for_qg_from_answers_hl(sents, answers)
        qg_inputs = [example['source_text'] for example in qg_examples]
        questions = self._generate_questions(qg_inputs)
        output = [{'answer': example['answer'], 'question': que} for example, que in zip(qg_examples, questions)]
        return output
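    # Each returned item has the shape {'answer': ..., 'question': ...}; answers whose
    # text cannot be located in their source sentence are dropped during input preparation.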

    def _generate_questions(self, inputs):
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=80,
            num_beams=4,
        )

        questions = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
        return questions

    def _extract_answers(self, context):
        sents, inputs = self._prepare_inputs_for_ans_extraction(context)

        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.ans_model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=80,
        )

        dec = [self.ans_tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
        # The model separates answers with <sep>; splitting leaves an empty trailing
        # element, which [:-1] drops
        answers = [item.split('<sep>') for item in dec]
        answers = [i[:-1] for i in answers]
        return sents, answers

    def _tokenize(
        self,
        inputs,
        padding=True,
        truncation=True,
        add_special_tokens=True,
        max_length=512
    ):
        # Pad/truncate every batch item to max_length and return PyTorch tensors
        inputs = self.tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            return_tensors="pt"
        )
        return inputs

    def _prepare_inputs_for_ans_extraction(self, text):
        sents = sent_tokenize(text)

        inputs = []
        for i in range(len(sents)):
            source_text = "extract answers:"
            for j, sent in enumerate(sents):
                if i == j:
                    sent = "<hl> %s <hl>" % sent
                source_text = "%s %s" % (source_text, sent)
            source_text = source_text.strip()

            source_text = source_text + " </s>"
            inputs.append(source_text)
        return sents, inputs
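    # Illustration with hypothetical values: for a two-sentence context the first
    # prepared input is
    #   "extract answers: <hl> Kapitan Pattimura adalah pahlawan dari Maluku. <hl> Beliau lahir ... </s>"
    # i.e. each sentence takes its turn inside the <hl> markers.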

    def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers):
        inputs = []
        for i, answer in enumerate(answers):
            if len(answer) == 0:
                continue
            for answer_text in answer:
                sent = sents[i]
                sents_copy = sents[:]

                answer_text = answer_text.strip()
                try:
                    # str.index raises ValueError when the extracted answer does not
                    # literally occur in its source sentence
                    ans_start_idx = sent.index(answer_text)

                    sent = f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> {sent[ans_start_idx + len(answer_text):]}"
                    sents_copy[i] = sent

                    source_text = " ".join(sents_copy)
                    source_text = f"generate question: {source_text}"
                    if self.model_type == "t5":
                        source_text = source_text + " </s>"
                except ValueError:
                    continue

                inputs.append({"answer": answer_text, "source_text": source_text})

        return inputs
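    # Illustration with hypothetical values: a prepared QG input looks like
    #   "generate question: <hl> Kapitan Pattimura <hl> adalah pahlawan dari Maluku. Beliau lahir ... </s>"
    # so the model sees the full context with only the target answer highlighted.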

class TaskPipeline(QGPipeline):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, inputs: Union[Dict, str]):
        return super().__call__(inputs)

@st.cache(ttl=24 * 3600, allow_output_mutation=True)  # reuse one loaded model across reruns for 24h
def pipeline():
    return TaskPipeline()

st.title("Indonesian Question Generation")
st.write("Indonesian Question Generation System using [idT5](https://huggingface.co/muchad/idt5-base)")
qg = pipeline()
default_context = "Kapitan Pattimura adalah pahlawan dari Maluku. Beliau lahir pada tanggal 8 Juni 1783 dan meninggal pada tanggal 16 Desember 1817."
context_in = st.text_area('Context:', default_context, height=200)
if st.button('Generate Question'):
    if context_in:
        questions = qg(context_in)
        result = ""
        for i, q in enumerate(questions):
            # Two trailing spaces force a Markdown line break in st.write
            result += f"{i + 1}. Answer: {q['answer']}  \n  Question: {q['question']}  \n"
        st.write(result)
    else:
        st.write("Please check your context")
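
For reference, a minimal sketch of driving the pipeline without the Streamlit UI (assumes this file is saved as app.py with streamlit, torch, transformers, sentencepiece, and nltk installed; importing app also executes the Streamlit calls above, which only log warnings outside streamlit run):

# Hypothetical usage sketch, not part of this commit
from app import TaskPipeline

qg = TaskPipeline()  # downloads muchad/idt5-qa-qg on first use
for pair in qg("Kapitan Pattimura adalah pahlawan dari Maluku. Beliau lahir pada tanggal 8 Juni 1783."):
    print(pair["answer"], "->", pair["question"])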