# wiki-chat / QueryProcessor.py
import numpy as np
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
from nltk.tokenize import RegexpTokenizer
class QueryProcessor:

    def __init__(self, question, section_texts, N, avg_doc_len):
        # section_texts: dict mapping section name -> list of paragraph strings
        # N: total number of paragraphs across all sections
        # avg_doc_len: average paragraph length in words
        self.section_texts = section_texts
        self.N = N
        self.avg_doc_len = avg_doc_len
        self.query_items = self.set_query(question)
        self.section_document_idx = None

    def set_query(self, question):
        # Lower-case the question, strip punctuation, and drop English stopwords.
        punct_regex = RegexpTokenizer(r'\w+')
        return [q for q in punct_regex.tokenize(question.lower()) if q not in stopwords.words('english')]

    def get_query(self):
        return self.query_items

    def bm25(self, word, paragraph, k=1.2, b=0.75):
        # frequency of the word in this paragraph
        freq = paragraph.split().count(word)
        # length-normalized term frequency
        tf = (freq * (k + 1)) / (freq + k * (1 - b + b * len(paragraph.split()) / self.avg_doc_len))
        # number of documents (paragraphs) across all sections that contain the word
        N_q = sum(1 for _, docs in self.section_texts.items() for doc in docs if word in doc.split())
        # inverse document frequency
        idf = np.log(((self.N - N_q + 0.5) / (N_q + 0.5)) + 1)
        return round(tf * idf, 4)
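
    # Worked example with illustrative numbers (not taken from the repository):
    # for freq = 2, k = 1.2, b = 0.75, a 10-word paragraph and avg_doc_len = 10,
    # tf = (2 * 2.2) / (2 + 1.2 * 1.0) = 1.375; with N = 100 documents of which
    # N_q = 10 contain the word, idf = ln((100 - 10 + 0.5) / (10 + 0.5) + 1) ≈ 2.264,
    # so the BM25 score is roughly 1.375 * 2.264 ≈ 3.11.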

    def get_bm25_scores(self):
        # Score every (query term, section, paragraph) triple with BM25.
        # Returns a nested dict: {query term: {section: {paragraph index: score}}}.
        bm25_scores = {}
        for query in self.query_items:
            bm25_scores[query] = {}
            for section, docs in self.section_texts.items():
                bm25_scores[query][section] = {}
                for doc_index in range(len(docs)):
                    score = self.bm25(query, docs[doc_index])
                    if score > 0.0:
                        bm25_scores[query][section][doc_index] = score
                # drop sections where no paragraph matched this query term
                if len(bm25_scores[query][section]) <= 0:
                    del bm25_scores[query][section]
        return bm25_scores

    def filter_bad_documents(self, bm25_scores):
        # Keep only paragraphs whose BM25 score exceeds 0.5 for at least one query term.
        # Returns {section: [paragraph indices]}.
        section_document_idx = {}
        for sec_docs in bm25_scores.values():
            for sec, doc_scores in sec_docs.items():
                if sec not in section_document_idx:
                    section_document_idx[sec] = []
                for doc_idx, score in doc_scores.items():
                    if score > 0.5 and doc_idx not in section_document_idx[sec]:
                        section_document_idx[sec].append(doc_idx)
                if len(section_document_idx[sec]) <= 0:
                    del section_document_idx[sec]
        return section_document_idx

    def get_context(self):
        # Concatenate every paragraph that survived the BM25 filter into one context string.
        bm25_scores = self.get_bm25_scores()
        self.section_document_idx = self.filter_bad_documents(bm25_scores)
        context = ' '.join([self.section_texts[section][d_id]
                            for section, doc_ids in self.section_document_idx.items()
                            for d_id in doc_ids])
        return context

    def match_section_with_answer_text(self, text):
        # Return the sections whose filtered paragraphs contain the given answer text.
        sections = []
        for sec, doc_ids in self.section_document_idx.items():
            for d_id in doc_ids:
                if self.section_texts[sec][d_id].find(text) > -1:
                    sections.append(sec)
        return sections
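

# Example usage: a minimal, illustrative sketch. The sample sections, question,
# and answer text below are hypothetical and not part of the wiki-chat data.
if __name__ == '__main__':
    section_texts = {
        'history': [
            'python was created by guido van rossum and first released in 1991',
            'the language emphasizes code readability',
        ],
        'usage': [
            'python is widely used for data science and web development',
        ],
    }
    # N is the total paragraph count; avg_doc_len is the mean paragraph length in words.
    N = sum(len(docs) for docs in section_texts.values())
    avg_doc_len = sum(len(d.split()) for docs in section_texts.values() for d in docs) / N

    qp = QueryProcessor('who created python?', section_texts, N, avg_doc_len)
    print(qp.get_query())          # query terms with stopwords removed
    print(qp.get_context())        # paragraphs that pass the BM25 filter
    print(qp.match_section_with_answer_text('guido van rossum'))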