Spaces:
Runtime error
Runtime error
Commit
·
8644233
1
Parent(s):
a70737f
Upload 4 files
Browse files
- Article.py +55 -0
- QuestionAnswer.py +129 -0
- VectorDB.py +34 -0
- app.py +97 -0
Article.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import wikipediaapi
|
2 |
+
|
3 |
+
class Article:
    """Fetch a Wikipedia article and flatten it into per-paragraph records.

    Every record appended to ``article_data`` is a dict with the keys
    ``'id'`` (a 1-based running counter), ``'section'`` (the section title,
    or ``'Summary'`` for the lead text) and ``'text'`` (the lower-cased
    paragraph).
    """

    def __init__(self):
        self.article = None      # page object returned by wikipediaapi
        self.article_data = []   # paragraph records collected so far
        self.id_counter = 0      # last id that was handed out

    def _add_paragraphs(self, section_title, text):
        """Append one record per newline-separated paragraph of ``text``."""
        for paragraph in text.split('\n'):
            self.id_counter += 1
            self.article_data.append(
                {
                    'id': self.id_counter,
                    'section': section_title,
                    'text': paragraph.lower(),
                }
            )

    def set_summary(self):
        """Add the article summary (if any) under the section 'Summary'."""
        if self.article.summary:
            self._add_paragraphs('Summary', self.article.summary)

    def set_sections_and_texts(self, sections):
        """Add the text of ``sections`` and, recursively, of their subsections.

        Args:
            sections: iterable of objects exposing ``title``, ``text`` and
                ``sections`` attributes.
        """
        for section in sections:
            if section.text:
                self._add_paragraphs(section.title, section.text)
            if section.sections:
                self.set_sections_and_texts(section.sections)

    def clean_data(self):
        """Drop blank/one-character paragraphs and unwanted sections.

        Ids are not renumbered, so gaps may remain after cleaning.
        """
        unwanted_sections = ('See also', 'External links')
        self.article_data = [
            record
            for record in self.article_data
            # strip() so whitespace-only lines are not kept as paragraphs
            if len(record['text'].strip()) > 1
            and record['section'] not in unwanted_sections
        ]

    def get_article_data(self, article_name):
        """Return the cleaned paragraph records for ``article_name``.

        Args:
            article_name: title of the Wikipedia page to look up.

        Returns:
            A list of record dicts, or an empty list when the page does
            not exist.
        """
        # Start from a clean slate so reusing an instance does not mix the
        # paragraphs (and ids) of a previous article into this one.
        self.article_data = []
        self.id_counter = 0

        # NOTE(review): newer wikipedia-api releases appear to require a
        # user_agent argument here -- confirm against the pinned version.
        self.article = wikipediaapi.Wikipedia('en').page(article_name)

        if not self.article.exists():
            return []

        self.set_summary()
        self.set_sections_and_texts(self.article.sections)
        self.clean_data()

        return self.article_data
|
QuestionAnswer.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
# # from transformers import AutoTokenizer, AutoModelForQuestionAnswering
|
4 |
+
|
5 |
+
|
6 |
+
class QuestionAnswer:
    """Extract answer spans for one question/context pair.

    Args:
        data: dict with the keys ``'question'`` and ``'context'``.
        model: callable invoked as ``model(input_ids, attention_mask)``; its
            output must expose ``start_logits`` and ``end_logits``.
        tokenizer: tokenizer called with the question and context; its
            result must provide ``sequence_ids(i)`` and an
            ``'offset_mapping'`` entry.
        torch_device: device the tokenized tensors are moved to.
    """

    def __init__(self, data, model, tokenizer, torch_device):
        self.max_length = 384   # passed to the tokenizer as max_length
        self.doc_stride = 128   # passed to the tokenizer as stride

        self.tokenizer = tokenizer
        self.model = model
        self.data = data
        self.torch_device = torch_device

        self.output = None
        self.features = None
        self.results = None

    def _tokenize(self, example, **extra):
        """Tokenize ``example`` with the settings shared by all callers."""
        return self.tokenizer(
            example['question'],
            example['context'],
            truncation='only_second',
            max_length=self.max_length,
            stride=self.doc_stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding='max_length',
            **extra,
        )

    def get_output_from_model(self):
        """Run the model on ``self.data`` without tracking gradients."""
        with torch.no_grad():
            tokenized_data = self._tokenize(
                self.data, return_tensors='pt'
            ).to(self.torch_device)

            return self.model(
                tokenized_data['input_ids'], tokenized_data['attention_mask']
            )

    def prepare_features(self, example):
        """Tokenize ``example`` and blank out non-context offsets.

        Offsets of tokens whose sequence id is not 1 are replaced with
        ``None`` so that post-processing can skip them.
        """
        tokenized_example = self._tokenize(example)
        context_index = 1

        offset_mappings = tokenized_example["offset_mapping"]
        for i, offsets in enumerate(offset_mappings):
            sequence_ids = tokenized_example.sequence_ids(i)
            offset_mappings[i] = [
                (offset if sequence_ids[k] == context_index else None)
                for k, offset in enumerate(offsets)
            ]

        return tokenized_example

    def postprocess_qa_predictions(self, data, features, raw_predictions, top_n_answers=5, max_answer_length=30):
        """Turn start/end logits into the best-scoring answer strings.

        Args:
            data: dict whose ``'context'`` value is sliced to build answers.
            features: mapping with ``'input_ids'`` and ``'offset_mapping'``
                as produced by ``prepare_features``.
            raw_predictions: object with ``start_logits`` and ``end_logits``
                indexable per feature.
            top_n_answers: how many start/end candidates are considered per
                feature, and how many answers are returned.
            max_answer_length: maximum answer length in tokens.

        Returns:
            Up to ``top_n_answers`` dicts ``{'score', 'text'}`` sorted by
            descending score (score rounded to 3 significant digits).
        """
        all_start_logits = raw_predictions.start_logits
        all_end_logits = raw_predictions.end_logits
        context = data['context']
        results = []

        for i in range(len(features['input_ids'])):
            start_logits = all_start_logits[i].cpu().numpy()
            end_logits = all_end_logits[i].cpu().numpy()
            offset_mapping = features['offset_mapping'][i]
            n_tokens = len(offset_mapping)

            # Indices of the highest logits, best first.
            start_indices = np.argsort(start_logits)[::-1][:top_n_answers].tolist()
            end_indices = np.argsort(end_logits)[::-1][:top_n_answers].tolist()

            for start_index in start_indices:
                # A start outside the context can never form a valid span.
                if start_index >= n_tokens or offset_mapping[start_index] is None:
                    continue

                for end_index in end_indices:
                    if (
                        end_index >= n_tokens
                        or offset_mapping[end_index] is None
                        or end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    score = start_logits[start_index] + end_logits[end_index]
                    results.append(
                        {
                            'score': float('%.*g' % (3, score)),
                            'text': context[start_char: end_char],
                        }
                    )

        results.sort(key=lambda x: x["score"], reverse=True)
        return results[:top_n_answers]

    def get_results(self):
        """Run the model, prepare features and return the ranked answers."""
        self.output = self.get_output_from_model()
        self.features = self.prepare_features(self.data)
        self.results = self.postprocess_qa_predictions(
            self.data, self.features, self.output
        )

        return self.results
|
VectorDB.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pinecone
|
2 |
+
|
3 |
+
class VectorDB:
    """Thin wrapper around a Pinecone index holding article paragraphs.

    Args:
        retreiver: encoder exposing ``encode(...)`` (result must support
            ``.tolist()``) and ``get_sentence_embedding_dimension()``.
        API_KEY: Pinecone API key.
        index_name: name of the index to create/open.
        environment: Pinecone environment passed to ``pinecone.init``.
    """

    def __init__(self, retreiver, API_KEY, index_name='wikiqav2-index', environment='us-east1-gcp'):
        pinecone.init(api_key=API_KEY, environment=environment)
        self.retreiver = retreiver

        # Create the index on first use, sized to the encoder's output.
        if index_name not in pinecone.list_indexes():
            pinecone.create_index(
                name=index_name,
                dimension=self.retreiver.get_sentence_embedding_dimension(),
                metric='cosine',
            )

        self.index = pinecone.Index(index_name)

    def upsert_data(self, article_data, batch_size=10):
        """Encode every record's text and upsert it into the index.

        The caller's records are left untouched (no ``'encoding'`` key is
        added to them).

        Args:
            article_data: list of dicts with ``'id'``, ``'text'`` and
                ``'section'`` keys.
            batch_size: number of vectors sent per upsert call.
        """
        if not article_data:
            return

        # NOTE(review): ids restart at 1 for every article, so vectors of a
        # previously indexed, longer article presumably stay in the index --
        # confirm whether the index should be cleared first.
        texts = [record['text'] for record in article_data]
        # Encode all paragraphs in one call instead of one call per record.
        encodings = self.retreiver.encode(texts).tolist()

        upserts = [
            (
                str(record['id']),
                encoding,
                {'text': record['text'], 'section': record['section']},
            )
            for record, encoding in zip(article_data, encodings)
        ]

        # Slicing past the end is safe, so no explicit clamp is needed.
        for start in range(0, len(upserts), batch_size):
            self.index.upsert(vectors=upserts[start:start + batch_size])

    def get_contexts(self, question, top_k=1):
        """Return the index query result for ``question``.

        Args:
            question: question text to encode and search for.
            top_k: number of matches requested from the index.
        """
        xq = self.retreiver.encode([question]).tolist()
        return self.index.query(xq, top_k=top_k, include_metadata=True)
|
app.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import os
|
3 |
+
|
4 |
+
from Article import Article
|
5 |
+
from VectorDB import VectorDB
|
6 |
+
from QuestionAnswer import QuestionAnswer
|
7 |
+
|
8 |
+
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
|
9 |
+
from sentence_transformers import models, SentenceTransformer
|
10 |
+
|
11 |
+
|
12 |
+
# Hub ids of the reader (extractive QA) and retriever (sentence embedding)
# checkpoints.
READER_MODEL_NAME = 'Pennywise881/distilbert-base-uncased-finetuned-squad-v2'
RETRIEVER_MODEL_NAME = "Pennywise881/distilbert-base-uncased-mnr-squadv2"

# NOTE(review): Streamlit presumably re-executes this script on every
# interaction, which would reload the models below each time -- consider
# caching them if the installed Streamlit version supports it.
reader = AutoModelForQuestionAnswering.from_pretrained(READER_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(READER_MODEL_NAME)

# Retriever: transformer encoder followed by mean pooling over tokens.
distilbert = models.Transformer(RETRIEVER_MODEL_NAME)
pooler = models.Pooling(
    distilbert.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True
)

retreiver = SentenceTransformer(modules=[distilbert, pooler])

# First run of the session: initialise everything kept across reruns.
if 'found_article' not in st.session_state:
    st.session_state.found_article = False
    st.session_state.article_name = ''
    st.session_state.db = None
    st.session_state.qas = []

st.write("""
# Wiki Q&A V2
""")
# Single slot that hosts either the article-name or the question input.
placeholder = st.empty()
|
33 |
+
|
34 |
+
def get_article(retreiver):
    """Ask for an article name, index the article and start the Q&A step.

    Args:
        retreiver: sentence encoder handed to ``VectorDB``.

    Raises:
        KeyError: if the ``API_KEY`` environment variable is not set.
    """
    article_name = placeholder.text_input("Enter the name of a Wikipedia article")

    if not article_name:
        return

    article_data = Article().get_article_data(article_name=article_name)

    if not article_data:
        st.write(f'Sorry, could not find Wikipedia article: {article_name}')
        return

    API_KEY = os.environ['API_KEY']
    db = VectorDB(retreiver=retreiver, API_KEY=API_KEY)
    db.upsert_data(article_data=article_data)

    # Persist the indexed article before rendering the question UI, so a
    # failure while answering does not force a re-fetch and re-upsert on
    # the next rerun.
    st.session_state.found_article = True
    st.session_state.article_name = article_name
    st.session_state.db = db

    ask_questions(article_name=article_name, db=db)
|
52 |
+
|
53 |
+
def ask_questions(article_name, db: VectorDB):
    """Show the question box, answer the question and list past Q&As.

    Args:
        article_name: title shown in the input prompt.
        db: vector store queried for the most relevant paragraph.
    """
    question = placeholder.text_input(f"Ask questions about '{article_name}'", '')
    st.header("Questions and Answers:")

    if question:
        # The question is lower-cased before retrieval and reading; the
        # original casing is kept for display.
        normalized_question = question.lower()
        contexts = db.get_contexts(normalized_question)
        matches = contexts['matches']

        if matches:
            best_match = matches[0]
            data = {
                'question': normalized_question,
                'context': best_match['metadata']['text']
            }
            qa = QuestionAnswer(data, reader, tokenizer, 'cpu')
            results = qa.get_results()

            st.session_state.qas.append(
                {
                    'question': question,
                    'answer': ", ".join(r['text'] for r in results),
                    'section': best_match['metadata']['section'],
                    'para': best_match['id']
                }
            )
        else:
            # An empty result would otherwise raise IndexError on matches[0].
            st.write('Sorry, could not find a relevant passage for that question.')

    for entry in st.session_state.qas:
        st.text(
            f"Question: {entry['question']}\n"
            f"Answer: {entry['answer']}\n"
            f"Section: {entry['section']}\n"
            f"Paragraph #: {entry['para']}"
        )
|
92 |
+
|
93 |
+
# Until an article has been indexed, ask for one; afterwards go straight to
# the question step with the stored article name and vector store.
if not st.session_state.found_article:
    get_article(retreiver)
else:
    ask_questions(st.session_state.article_name, st.session_state.db)
|