Pennywise881 committed
Commit 8644233 · 1 Parent(s): a70737f

Upload 4 files

Files changed (4):
  1. Article.py +55 -0
  2. QuestionAnswer.py +129 -0
  3. VectorDB.py +34 -0
  4. app.py +97 -0
Article.py ADDED
@@ -0,0 +1,55 @@
+ import wikipediaapi
+
+ class Article:
+     def __init__(self):
+         self.article = None
+         self.article_data = []
+         self.id_counter = 0
+
+     def set_summary(self):
+         if self.article.summary:
+             for text in self.article.summary.split('\n'):
+                 self.id_counter += 1
+                 self.article_data.append(
+                     {
+                         'id': self.id_counter,
+                         'section': 'Summary',
+                         'text': text.lower()
+                     }
+                 )
+
+     def set_sections_and_texts(self, sections):
+         for section in sections:
+             if section.text:
+                 for text in section.text.split('\n'):
+                     self.id_counter += 1
+                     self.article_data.append(
+                         {
+                             'id': self.id_counter,
+                             'section': section.title,
+                             'text': text.lower()
+                         }
+                     )
+             if len(section.sections) > 0:
+                 self.set_sections_and_texts(section.sections)
+
+     def clean_data(self):
+         unwanted_sections = ['See also', 'External links']
+         cleaned_data = []
+         for data in self.article_data:
+             if len(data['text']) > 1 and data['section'] not in unwanted_sections:
+                 cleaned_data.append(data)
+
+         self.article_data = cleaned_data
+
+     def get_article_data(self, article_name):
+         self.article = wikipediaapi.Wikipedia('en').page(article_name)
+
+         if not self.article.exists():
+             return []
+         else:
+             self.set_summary()
+             self.set_sections_and_texts(self.article.sections)
+             self.clean_data()
+
+             return self.article_data
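A minimal usage sketch for Article, assuming the wikipedia-api package and network access; the article title is a hypothetical example:

    from Article import Article

    article = Article()
    data = article.get_article_data('Alan Turing')  # hypothetical title; returns [] if the page does not exist
    for row in data[:3]:
        print(row['id'], row['section'], row['text'][:60])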
QuestionAnswer.py ADDED
@@ -0,0 +1,129 @@
+ import torch
+ import numpy as np
+ # # from transformers import AutoTokenizer, AutoModelForQuestionAnswering
+
+
+ class QuestionAnswer:
+
+     def __init__(self, data, model, tokenizer, torch_device):
+
+         self.max_length = 384
+         self.doc_stride = 128
+
+         self.tokenizer = tokenizer
+         self.model = model
+         self.data = data
+         self.torch_device = torch_device
+
+         self.output = None
+         self.features = None
+         self.results = None
+
+     def get_output_from_model(self):
+         # data = {'question': question, 'context': context}
+
+         with torch.no_grad():
+             tokenized_data = self.tokenizer(
+                 self.data['question'],
+                 self.data['context'],
+                 truncation='only_second',
+                 max_length=self.max_length,
+                 stride=self.doc_stride,
+                 return_overflowing_tokens=True,
+                 return_offsets_mapping=True,
+                 padding='max_length',
+                 return_tensors='pt'
+             ).to(self.torch_device)
+
+             output = self.model(tokenized_data['input_ids'], tokenized_data['attention_mask'])
+
+             return output
+
+         # print(output.keys())
+         # print(output['start_logits'].shape)
+         # print(output['end_logits'].shape)
+         # print(tokenized_data.keys())
+
+     def prepare_features(self, example):
+         tokenized_example = self.tokenizer(
+             example['question'],
+             example['context'],
+             truncation='only_second',
+             max_length=self.max_length,
+             stride=self.doc_stride,
+             return_overflowing_tokens=True,
+             return_offsets_mapping=True,
+             padding='max_length',
+         )
+
+         # sample_mapping = tokenized_example.pop("overflow_to_sample_mapping")
+
+         for i in range(len(tokenized_example['input_ids'])):
+             sequence_ids = tokenized_example.sequence_ids(i)
+             # print(sequence_ids)
+             context_index = 1
+
+             # sample_index = sample_mapping[i]
+
+             tokenized_example["offset_mapping"][i] = [
+                 (o if sequence_ids[k] == context_index else None)
+                 for k, o in enumerate(tokenized_example["offset_mapping"][i])
+             ]
+
+         return tokenized_example
+
+     def postprocess_qa_predictions(self, data, features, raw_predictions, top_n_answers=5, max_answer_length=30):
+         all_start_logits, all_end_logits = raw_predictions.start_logits, raw_predictions.end_logits
+
+         # print(all_start_logits)
+
+         results = []
+         context = data['context']
+
+         # print(len(features['input_ids']))
+         for i in range(len(features['input_ids'])):
+             start_logits = all_start_logits[i].cpu().numpy()
+             end_logits = all_end_logits[i].cpu().numpy()
+
+             # print(start_logits)
+
+             offset_mapping = features['offset_mapping'][i]
+
+             start_indices = np.argsort(start_logits)[-1: -top_n_answers - 1: -1].tolist()
+             end_indices = np.argsort(end_logits)[-1: -top_n_answers - 1: -1].tolist()
+
+             for start_index in start_indices:
+                 for end_index in end_indices:
+                     if (
+                         start_index >= len(offset_mapping)
+                         or end_index >= len(offset_mapping)
+                         or offset_mapping[start_index] is None
+                         or offset_mapping[end_index] is None
+                         or end_index < start_index
+                         or end_index - start_index + 1 > max_answer_length
+                     ):
+                         continue
+
+                     start_char = offset_mapping[start_index][0]
+                     end_char = offset_mapping[end_index][1]
+
+                     # print(start_logits[start_index])
+                     # print(end_logits[end_index])
+                     score = start_logits[start_index] + end_logits[end_index]
+                     results.append(
+                         {
+                             'score': float('%.*g' % (3, score)),
+                             'text': context[start_char: end_char]
+                         }
+                     )
+
+         results = sorted(results, key=lambda x: x["score"], reverse=True)[:top_n_answers]
+         return results
+
+     def get_results(self):
+         self.output = self.get_output_from_model()
+         self.features = self.prepare_features(self.data)
+         self.results = self.postprocess_qa_predictions(self.data, self.features, self.output)
+
+         return self.results
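A minimal usage sketch for QuestionAnswer; the model and tokenizer names are the ones loaded in app.py below, while the question and context are made-up examples:

    from transformers import AutoTokenizer, AutoModelForQuestionAnswering
    from QuestionAnswer import QuestionAnswer

    model = AutoModelForQuestionAnswering.from_pretrained('Pennywise881/distilbert-base-uncased-finetuned-squad-v2')
    tokenizer = AutoTokenizer.from_pretrained('Pennywise881/distilbert-base-uncased-finetuned-squad-v2')

    data = {
        'question': 'who founded the company?',                       # made-up example question
        'context': 'the company was founded by jane doe in 1999.'     # made-up example context
    }
    qa = QuestionAnswer(data, model, tokenizer, 'cpu')
    print(qa.get_results())  # list of {'score', 'text'} dicts, best first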
VectorDB.py ADDED
@@ -0,0 +1,34 @@
+ import pinecone
+
+ class VectorDB:
+
+     def __init__(self, retreiver, API_KEY):
+         pinecone.init(api_key=API_KEY, environment='us-east1-gcp')
+         self.retreiver = retreiver
+
+         if 'wikiqav2-index' not in pinecone.list_indexes():
+             pinecone.create_index(
+                 name='wikiqav2-index', dimension=self.retreiver.get_sentence_embedding_dimension(), metric='cosine'
+             )
+
+         self.index = pinecone.Index('wikiqav2-index')
+
+     def upsert_data(self, article_data):
+         for i in range(len(article_data)):
+             article_data[i]['encoding'] = self.retreiver.encode(article_data[i]['text']).tolist()
+
+         upserts = [(str(v['id']), v['encoding'], {'text': v['text'], 'section': v['section']}) for v in article_data]
+
+         # index.upsert(vectors=upserts[0])
+
+         for i in range(0, len(upserts), 10):
+             i_end = i + 10
+             if i_end > len(upserts):
+                 i_end = len(upserts)
+             self.index.upsert(vectors=upserts[i:i_end])
+
+     def get_contexts(self, question):
+         xq = self.retreiver.encode([question]).tolist()
+         contexts = self.index.query(xq, top_k=1, include_metadata=True)
+         return contexts
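A minimal usage sketch for VectorDB; the retriever is built the same way as in app.py below, the Pinecone key is read from the API_KEY environment variable, and the article title and question are hypothetical:

    import os
    from sentence_transformers import models, SentenceTransformer
    from Article import Article
    from VectorDB import VectorDB

    distilbert = models.Transformer('Pennywise881/distilbert-base-uncased-mnr-squadv2')
    pooler = models.Pooling(distilbert.get_word_embedding_dimension(), pooling_mode_mean_tokens=True)
    retreiver = SentenceTransformer(modules=[distilbert, pooler])

    article_data = Article().get_article_data('Alan Turing')           # hypothetical article title
    db = VectorDB(retreiver=retreiver, API_KEY=os.environ['API_KEY'])  # Pinecone key from the environment
    db.upsert_data(article_data=article_data)
    print(db.get_contexts('where was alan turing born?'))              # top-1 paragraph match with metadata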
app.py ADDED
@@ -0,0 +1,97 @@
+ import streamlit as st
+ import os
+
+ from Article import Article
+ from VectorDB import VectorDB
+ from QuestionAnswer import QuestionAnswer
+
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
+ from sentence_transformers import models, SentenceTransformer
+
+
+ reader = AutoModelForQuestionAnswering.from_pretrained('Pennywise881/distilbert-base-uncased-finetuned-squad-v2')
+ tokenizer = AutoTokenizer.from_pretrained('Pennywise881/distilbert-base-uncased-finetuned-squad-v2')
+
+ distilbert = models.Transformer("Pennywise881/distilbert-base-uncased-mnr-squadv2")
+ pooler = models.Pooling(
+     distilbert.get_word_embedding_dimension(),
+     pooling_mode_mean_tokens=True
+ )
+
+ retreiver = SentenceTransformer(modules=[distilbert, pooler])
+
+ if 'found_article' not in st.session_state:
+     st.session_state.found_article = False
+     st.session_state.article_name = ''
+     st.session_state.db = None
+     st.session_state.qas = []
+
+ st.write("""
+ # Wiki Q&A V2
+ """)
+ placeholder = st.empty()
+
+ def get_article(retreiver):
+     article_name = placeholder.text_input("Enter the name of a Wikipedia article")
+
+     if article_name:
+         article = Article()
+         article_data = article.get_article_data(article_name=article_name)
+
+         if len(article_data) > 0:
+             API_KEY = os.environ['API_KEY']
+             db = VectorDB(retreiver=retreiver, API_KEY=API_KEY)
+             db.upsert_data(article_data=article_data)
+             ask_questions(article_name=article_name, db=db)
+
+             st.session_state.found_article = True
+             st.session_state.article_name = article_name
+             st.session_state.db = db
+         else:
+             st.write(f'Sorry, could not find Wikipedia article: {article_name}')
+
+ def ask_questions(article_name, db: VectorDB):
+     question = placeholder.text_input(f"Ask questions about '{article_name}'", '')
+     st.header("Questions and Answers:")
+
+     if question:
+         contexts = db.get_contexts(question.lower())
+         # print(contexts)
+
+         data = {
+             'question': question.lower(),
+             'context': contexts['matches'][0]['metadata']['text']
+         }
+         qa = QuestionAnswer(data, reader, tokenizer, 'cpu')
+         results = qa.get_results()
+
+         paragraph_index = contexts['matches'][0]['id']
+         section = contexts['matches'][0]['metadata']['section']
+         answer = ''
+         for r in results:
+             answer += r['text'] + ", "
+
+         answer = answer[:len(answer) - 2]
+         st.session_state.qas.append(
+             {
+                 'question': question,
+                 'answer': answer,
+                 'section': section,
+                 'para': paragraph_index
+             }
+         )
+
+     if len(st.session_state.qas) > 0:
+         for data in st.session_state.qas:
+             st.text(
+                 "Question: " + data['question'] + '\n' +
+                 "Answer: " + data['answer'] + '\n' +
+                 "Section: " + data['section'] + '\n' +
+                 "Paragraph #: " + data['para']
+             )
+
+ if not st.session_state.found_article:
+     get_article(retreiver)
+
+ else:
+     ask_questions(st.session_state.article_name, st.session_state.db)
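To try the app locally, the dependencies imported above need to be installed and the API_KEY environment variable set to a valid Pinecone key; the app is then launched with the standard Streamlit command, "streamlit run app.py".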