Jingxiang Mo committed
Commit bde6562 (1 parent: 782aff2)

Added deliverable 3

deliverables/MAIS 202 - Project Deliverable 3.pdf ADDED
Binary file (511 kB).
 
test.py ADDED
@@ -0,0 +1,147 @@
+ import os
+ import gradio as gr
+ import numpy as np
+ import wikipediaapi as wk
+ from transformers import (
+     TokenClassificationPipeline,
+     AutoModelForTokenClassification,
+     AutoTokenizer,
+ )
+ import torch
+ from transformers.pipelines import AggregationStrategy
+ from transformers import BertForQuestionAnswering
+ from transformers import BertTokenizer
+
+ # =====[ DEFINE PIPELINE ]===== #
+ class KeyphraseExtractionPipeline(TokenClassificationPipeline):
+     def __init__(self, model, *args, **kwargs):
+         super().__init__(
+             model=AutoModelForTokenClassification.from_pretrained(model),
+             tokenizer=AutoTokenizer.from_pretrained(model),
+             *args,
+             **kwargs
+         )
+
+     def postprocess(self, model_outputs):
+         results = super().postprocess(
+             model_outputs=model_outputs,
+             aggregation_strategy=AggregationStrategy.SIMPLE,
+         )
+         return np.unique([result.get("word").strip() for result in results])
+
+ # =====[ LOAD PIPELINE ]===== #
+ keyPhraseExtractionModel = "ml6team/keyphrase-extraction-kbir-inspec"
+ extractor = KeyphraseExtractionPipeline(model=keyPhraseExtractionModel)
+ model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
+ tokenizer = BertTokenizer.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
+
+ # TODO: add further preprocessing
+ def keyphrases_extraction(text: str) -> np.ndarray:
+     keyphrases = extractor(text)
+     return keyphrases
+
+ def wikipedia_search(query: str) -> str:
+     query = query.replace("\n", " ")
+     keyphrases = keyphrases_extraction(query)
+     wiki = wk.Wikipedia('en')
+
+     try:
+         # TODO: add better extraction and search
+         keyphrase_index = 0
+         page = wiki.page(keyphrases[keyphrase_index])
+
+         while not page.exists() or '.' not in page.summary:
+             keyphrase_index += 1
+             if keyphrase_index == len(keyphrases):
+                 raise Exception
+             page = wiki.page(keyphrases[keyphrase_index])
+         return page.summary
+     except Exception:
+         return "I cannot answer this question"
+
+ def answer_question(question):
+
+     context = wikipedia_search(question)
+     if context == "I cannot answer this question":
+         return context
+
+     # ======== Tokenize ========
+     # Apply the tokenizer to the input, treating question and context as a text pair.
+     input_ids = tokenizer.encode(question, context)
+
+     # If the sequence exceeds BERT's 512-token limit, truncate the context but keep the final [SEP].
+     if len(input_ids) > 512:
+         input_ids = input_ids[:511] + [tokenizer.sep_token_id]
+
+     print('Query has {:,} tokens.\n'.format(len(input_ids)))
+
+     # ======== Set Segment IDs ========
+     # Search the input_ids for the first instance of the `[SEP]` token.
+     sep_index = input_ids.index(tokenizer.sep_token_id)
+
+     # The number of segment A tokens includes the [SEP] token itself.
+     num_seg_a = sep_index + 1
+
+     # The remainder are segment B.
+     num_seg_b = len(input_ids) - num_seg_a
+
+     # Construct the list of 0s and 1s.
+     segment_ids = [0] * num_seg_a + [1] * num_seg_b
+
+     # There should be a segment_id for every input token.
+     assert len(segment_ids) == len(input_ids)
+
+     # ======== Evaluate ========
+     # Run the example through the model.
+     outputs = model(torch.tensor([input_ids]),  # The tokens representing the input text.
+                     token_type_ids=torch.tensor([segment_ids]),  # The segment IDs to differentiate question from context.
+                     return_dict=True)
+
+     start_scores = outputs.start_logits
+     end_scores = outputs.end_logits
+
+     # ======== Reconstruct Answer ========
+     # Find the tokens with the highest `start` and `end` scores.
+     answer_start = torch.argmax(start_scores).item()
+     answer_end = torch.argmax(end_scores).item()
+
+     # Get the string versions of the input tokens.
+     tokens = tokenizer.convert_ids_to_tokens(input_ids)
+
+     # Start with the first token.
+     answer = tokens[answer_start]
+
+     # Select the remaining answer tokens and join them with whitespace.
+     for i in range(answer_start + 1, answer_end + 1):
+
+         # If it's a subword token, recombine it with the previous token.
+         if tokens[i][0:2] == '##':
+             answer += tokens[i][2:]
+
+         # Otherwise, add a space and then the token.
+         else:
+             answer += ' ' + tokens[i]
+
+     return 'Answer: "' + answer + '"'
+
+ # =====[ DEFINE INTERFACE ]===== #
+ title = "Azza Chatbot"
+ examples = [
+     ["Where is the Eiffel Tower?"],
+     ["What is the population of France?"]
+ ]
+
+
+
+ demo = gr.Interface(
+     title=title,
+
+     fn=answer_question,
+     inputs="text",
+     outputs="text",
+
+     examples=examples,
+ )
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
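
A minimal way to smoke-test the new file without launching the Gradio UI might look like the sketch below; it is not part of the commit, and it assumes test.py sits in the repository root and that the module-level model downloads complete on import:

    from test import answer_question

    # One of the commit's own example questions: wikipedia_search supplies the
    # context and the SQuAD-finetuned BERT model extracts the answer span.
    print(answer_question("Where is the Eiffel Tower?"))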