Commit 9f23e0b
Parent(s): 0b52499
uploaded code files
- Article.py +46 -0
- QueryProcessor.py +98 -0
- QuestionAnswer.py +129 -0
- app.py +84 -0
Article.py
ADDED
@@ -0,0 +1,46 @@
import wikipediaapi


class Article:

    def __init__(self, article_name):
        self.article_data = {}
        self.article = wikipediaapi.Wikipedia('en').page(article_name)

    def article_exists(self):
        # exists() is False for missing pages; treat lookup errors as "not found"
        # rather than letting them propagate.
        try:
            return self.article.exists()
        except Exception:
            return False

    def get_sections_and_texts(self, sections):
        # Seed the data with the page summary before walking the section tree.
        if 'Summary' not in self.article_data:
            self.article_data['Summary'] = []
            if self.article.summary:
                self.article_data['Summary'] = self.article.summary.lower().split('\n')

        # Recursively collect the text of every (sub)section, one paragraph per document.
        for section in sections:
            if section.text:
                self.article_data[section.title] = section.text.lower().split('\n')
            if len(section.sections) > 0:
                self.get_sections_and_texts(section.sections)

    def remove_empty_sections(self):
        # Rebuild each paragraph list without empty strings; removing items from
        # a list while iterating over it would skip elements.
        for title, docs in self.article_data.items():
            self.article_data[title] = [d for d in docs if len(d) > 0]

    def get_article_data(self):
        self.get_sections_and_texts(self.article.sections)
        self.remove_empty_sections()

        num_docs = sum(len(docs) for docs in self.article_data.values())
        # max(..., 1) guards against a page with no collected text.
        avg_doc_len = sum(len(doc.split()) for docs in self.article_data.values() for doc in docs) / max(num_docs, 1)

        return {
            'article_data': self.article_data,
            'num_docs': num_docs,
            'avg_doc_len': avg_doc_len
        }
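For reference, a minimal usage sketch of Article, not part of the commit (the page title is an arbitrary example; needs the wikipedia-api package and network access):

from Article import Article

article = Article(article_name='Alan Turing')  # hypothetical example title
if article.article_exists():
    data = article.get_article_data()
    print(data['num_docs'], data['avg_doc_len'])
    print(list(data['article_data'].keys()))  # section titles, including 'Summary'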
QueryProcessor.py
ADDED
@@ -0,0 +1,98 @@
import numpy as np

# The 'stopwords' corpus must be downloaded once: nltk.download('stopwords')
from nltk.corpus import stopwords
from nltk.tokenize import RegexpTokenizer


class QueryProcessor:

    def __init__(self, question, section_texts, N, avg_doc_len):
        self.section_texts = section_texts
        self.N = N
        self.avg_doc_len = avg_doc_len

        self.query_items = self.set_query(question)
        self.section_document_idx = None

    def set_query(self, question):
        # Tokenize on word characters (drops punctuation) and filter out English stopwords.
        punct_regex = RegexpTokenizer(r'\w+')
        stop_words = set(stopwords.words('english'))

        return [q for q in punct_regex.tokenize(question.lower()) if q not in stop_words]

    def get_query(self):
        return self.query_items

    def bm25(self, word, paragraph, k=1.2, b=0.75):
        # Frequency of the word in this document (paragraph).
        freq = paragraph.split().count(word)

        # Length-normalized term frequency.
        tf = (freq * (k + 1)) / (freq + k * (1 - b + b * len(paragraph.split()) / self.avg_doc_len))

        # Number of documents that contain the word.
        N_q = sum(1 for docs in self.section_texts.values() for doc in docs if word in doc.split())

        # Inverse document frequency.
        idf = np.log(((self.N - N_q + 0.5) / (N_q + 0.5)) + 1)

        return round(tf * idf, 4)

    def get_bm25_scores(self):
        # Score every document against every query term; keep only positive scores.
        bm25_scores = {}

        for query in self.query_items:
            bm25_scores[query] = {}
            for section, docs in self.section_texts.items():
                bm25_scores[query][section] = {}
                for doc_index in range(len(docs)):
                    score = self.bm25(query, docs[doc_index])
                    if score > 0.0:
                        bm25_scores[query][section][doc_index] = score

                if len(bm25_scores[query][section]) <= 0:
                    del bm25_scores[query][section]

        return bm25_scores

    def filter_bad_documents(self, bm25_scores):
        # Keep only documents whose BM25 score clears the 0.5 threshold.
        section_document_idx = {}

        for sec_docs in bm25_scores.values():
            for sec, doc_scores in sec_docs.items():
                if sec not in section_document_idx:
                    section_document_idx[sec] = []
                for doc_idx, score in doc_scores.items():
                    if score > 0.5 and doc_idx not in section_document_idx[sec]:
                        section_document_idx[sec].append(doc_idx)

                if len(section_document_idx[sec]) <= 0:
                    del section_document_idx[sec]

        return section_document_idx

    def get_context(self):
        # Concatenate every surviving document into one context string for the QA model.
        bm25_scores = self.get_bm25_scores()
        self.section_document_idx = self.filter_bad_documents(bm25_scores)

        context = ' '.join([self.section_texts[section][d_id] for section, doc_ids in self.section_document_idx.items() for d_id in doc_ids])

        return context

    def match_section_with_answer_text(self, text):
        # Return the titles of the sections whose retained documents contain the answer text.
        sections = []
        for sec, doc_ids in self.section_document_idx.items():
            for d_id in doc_ids:
                if self.section_texts[sec][d_id].find(text) > -1:
                    sections.append(sec)

        return sections
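The bm25 method is the standard Okapi BM25 per-term score: tf = freq*(k+1) / (freq + k*(1 - b + b*doc_len/avg_doc_len)) and idf = ln((N - N_q + 0.5)/(N_q + 0.5) + 1), where N_q is the number of documents containing the term. A quick sanity check with hand-picked numbers (illustrative only):

import numpy as np

freq, doc_len, avg_doc_len = 2, 20, 25.0  # term count, doc length, corpus average
N, N_q, k, b = 100, 10, 1.2, 0.75         # corpus size, docs containing the term

tf = (freq * (k + 1)) / (freq + k * (1 - b + b * doc_len / avg_doc_len))
idf = np.log(((N - N_q + 0.5) / (N_q + 0.5)) + 1)
print(round(tf * idf, 4))  # ~3.2982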
QuestionAnswer.py
ADDED
@@ -0,0 +1,129 @@
import torch
import numpy as np


class QuestionAnswer:

    def __init__(self, data, model, tokenizer, torch_device):
        self.max_length = 384
        self.doc_stride = 128

        self.tokenizer = tokenizer
        self.model = model
        self.data = data
        self.torch_device = torch_device

        self.output = None
        self.features = None
        self.results = None

    def get_output_from_model(self):
        # Tokenize the (question, context) pair into overlapping windows and run the model.
        with torch.no_grad():
            tokenized_data = self.tokenizer(
                self.data['question'],
                self.data['context'],
                truncation='only_second',
                max_length=self.max_length,
                stride=self.doc_stride,
                return_overflowing_tokens=True,
                return_offsets_mapping=True,
                padding='max_length',
                return_tensors='pt'
            ).to(self.torch_device)

            output = self.model(tokenized_data['input_ids'], tokenized_data['attention_mask'])

        return output

    def prepare_features(self, example):
        # Re-tokenize without tensors to keep offset mappings and sequence ids,
        # which map token indices back to character positions in the context.
        tokenized_example = self.tokenizer(
            example['question'],
            example['context'],
            truncation='only_second',
            max_length=self.max_length,
            stride=self.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding='max_length',
        )

        for i in range(len(tokenized_example['input_ids'])):
            sequence_ids = tokenized_example.sequence_ids(i)
            context_index = 1

            # Null out offsets belonging to the question so answers can only
            # be drawn from the context.
            tokenized_example['offset_mapping'][i] = [
                (o if sequence_ids[k] == context_index else None)
                for k, o in enumerate(tokenized_example['offset_mapping'][i])
            ]

        return tokenized_example

    def postprocess_qa_predictions(self, data, features, raw_predictions, top_n_answers=5, max_answer_length=30):
        all_start_logits, all_end_logits = raw_predictions.start_logits, raw_predictions.end_logits

        results = []
        context = data['context']

        for i in range(len(features['input_ids'])):
            start_logits = all_start_logits[i].cpu().numpy()
            end_logits = all_end_logits[i].cpu().numpy()

            offset_mapping = features['offset_mapping'][i]

            # Consider the top-n start and end positions by logit score.
            start_indices = np.argsort(start_logits)[-1: -top_n_answers - 1: -1].tolist()
            end_indices = np.argsort(end_logits)[-1: -top_n_answers - 1: -1].tolist()

            for start_index in start_indices:
                for end_index in end_indices:
                    # Skip spans that fall outside the context, are reversed, or are too long.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                        or end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]

                    score = start_logits[start_index] + end_logits[end_index]
                    results.append(
                        {
                            'score': float('%.*g' % (3, score)),  # 3 significant figures
                            'text': context[start_char: end_char]
                        }
                    )

        results = sorted(results, key=lambda x: x['score'], reverse=True)[:top_n_answers]
        return results

    def get_results(self):
        self.output = self.get_output_from_model()
        self.features = self.prepare_features(self.data)
        self.results = self.postprocess_qa_predictions(self.data, self.features, self.output)

        return self.results
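A standalone sketch of how QuestionAnswer is driven (mirrors app.py; the question and context strings are made-up examples):

from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from QuestionAnswer import QuestionAnswer

name = 'Pennywise881/distilbert-base-uncased-finetuned-squad-v2'
model = AutoModelForQuestionAnswering.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

data = {
    'question': 'where was he born?',
    'context': 'alan turing was born in maida vale, london.'  # made-up context
}
results = QuestionAnswer(data, model, tokenizer, 'cpu').get_results()
print(results)  # up to top_n_answers spans: [{'score': ..., 'text': ...}, ...]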
app.py
ADDED
@@ -0,0 +1,84 @@
import streamlit as st
import wikipediaapi
from Article import Article
from QueryProcessor import QueryProcessor
from QuestionAnswer import QuestionAnswer

from transformers import AutoTokenizer, AutoModelForQuestionAnswering

model = AutoModelForQuestionAnswering.from_pretrained('Pennywise881/distilbert-base-uncased-finetuned-squad-v2')
tokenizer = AutoTokenizer.from_pretrained('Pennywise881/distilbert-base-uncased-finetuned-squad-v2')

st.write("""
# Wiki Q & A
""")

placeholder = st.empty()
wiki_wiki = wikipediaapi.Wikipedia('en')

if 'found_article' not in st.session_state:
    st.session_state.page = 0
    st.session_state.found_article = False
    st.session_state.article = ''
    st.session_state.conversation = []
    st.session_state.article_data = {}


def get_article():
    article_name = placeholder.text_input('Enter the name of a Wikipedia article', '')

    if article_name:
        page = wiki_wiki.page(article_name)
        if page.exists():
            st.session_state.found_article = True
            st.session_state.article = article_name

            article = Article(article_name=article_name)
            st.session_state.article_data = article.get_article_data()

            ask_questions()
        else:
            st.write(f'Sorry, could not find Wikipedia article: {article_name}')


def ask_questions():
    question = placeholder.text_input(f'Ask questions about {st.session_state.article}', '')
    st.header('Questions and Answers:')

    if question:
        query_processor = QueryProcessor(
            question=question,
            section_texts=st.session_state.article_data['article_data'],
            N=st.session_state.article_data['num_docs'],
            avg_doc_len=st.session_state.article_data['avg_doc_len']
        )

        context = query_processor.get_context()

        data = {
            'question': question,
            'context': context
        }

        qa = QuestionAnswer(data, model, tokenizer, 'cpu')
        results = qa.get_results()

        # Join the top answers into a single comma-separated string.
        answer = ', '.join(r['text'] for r in results)

        # Newest question first; inserting at the front avoids the append+reverse
        # pattern, which scrambles the order after the second question.
        st.session_state.conversation.insert(0, {'question': question, 'answer': answer})

    if len(st.session_state.conversation) > 0:
        for data in st.session_state.conversation:
            st.text('Question: ' + data['question'] + '\n' + 'Answer: ' + data['answer'])


if not st.session_state.found_article:
    get_article()
else:
    ask_questions()
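To run the Space locally, something like the following should work (package list inferred from the imports above; the NLTK stopwords corpus must be downloaded once):

pip install streamlit wikipedia-api nltk torch transformers
python -c "import nltk; nltk.download('stopwords')"
streamlit run app.py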