import numpy as np
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import RegexpTokenizer

# Make sure the NLTK stopword list is available locally.
nltk.download('stopwords')

class QueryProcessor:
    """Selects relevant paragraphs for a question using BM25 scoring."""

    def __init__(self, question, section_texts, N, avg_doc_len):
        self.section_texts = section_texts  # {section name: [paragraph strings]}
        self.N = N                          # total number of paragraphs
        self.avg_doc_len = avg_doc_len      # average paragraph length in words
        self.query_items = self.set_query(question)
        # Filled in by get_context(): {section: [indices of relevant paragraphs]}
        self.section_document_idx = None

    def set_query(self, question):
        # Lower-case the question, keep word tokens only, and drop stopwords.
        punct_regex = RegexpTokenizer(r'\w+')
        stop_words = set(stopwords.words('english'))
        return [q for q in punct_regex.tokenize(question.lower()) if q not in stop_words]

    def get_query(self):
        return self.query_items

    def bm25(self, word, paragraph, k=1.2, b=0.75):
        # Frequency of `word` in this paragraph (whitespace split, so
        # punctuation stays attached to tokens).
        freq = paragraph.split().count(word)
        # Saturated, length-normalised term-frequency component
        tf = (freq * (k + 1)) / (freq + k * (1 - b + b * len(paragraph.split()) / self.avg_doc_len))
        # Number of paragraphs (across all sections) that contain the word
        N_q = sum(1 for docs in self.section_texts.values() for doc in docs if word in doc.split())
        # Inverse document frequency
        idf = np.log(((self.N - N_q + 0.5) / (N_q + 0.5)) + 1)
        return round(tf * idf, 4)
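
    # For reference, a sketch of the weight the method above computes: the
    # standard Okapi BM25 term weight with the "+1" smoothed IDF variant.
    #
    #   score(q, d) = IDF(q) * f(q, d) * (k + 1) / ( f(q, d) + k * (1 - b + b * |d| / avg_doc_len) )
    #
    #   IDF(q) = ln( (N - n(q) + 0.5) / (n(q) + 0.5) + 1 )
    #
    # where f(q, d) is the count of term q in paragraph d, |d| is the paragraph
    # length in words, N is the total number of paragraphs, and n(q) is the
    # number of paragraphs containing q.
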
    def get_bm25_scores(self):
        # {query term: {section: {paragraph index: score}}}, keeping only
        # paragraphs with a non-zero score.
        bm25_scores = {}
        for query in self.query_items:
            bm25_scores[query] = {}
            for section, docs in self.section_texts.items():
                bm25_scores[query][section] = {}
                for doc_index, doc in enumerate(docs):
                    score = self.bm25(query, doc)
                    if score > 0.0:
                        bm25_scores[query][section][doc_index] = score
                # Drop sections where no paragraph matched this term.
                if not bm25_scores[query][section]:
                    del bm25_scores[query][section]
        return bm25_scores

    def filter_bad_documents(self, bm25_scores):
        # Keep, per section, the indices of paragraphs whose BM25 score for at
        # least one query term clears the 0.5 relevance threshold.
        section_document_idx = {}
        for sec_docs in bm25_scores.values():
            for sec, doc_scores in sec_docs.items():
                if sec not in section_document_idx:
                    section_document_idx[sec] = []
                for doc_idx, score in doc_scores.items():
                    if score > 0.5 and doc_idx not in section_document_idx[sec]:
                        section_document_idx[sec].append(doc_idx)
                if not section_document_idx[sec]:
                    del section_document_idx[sec]
        return section_document_idx

    def get_context(self):
        # Score every paragraph, keep the relevant ones, and join them into a
        # single context string.
        bm25_scores = self.get_bm25_scores()
        self.section_document_idx = self.filter_bad_documents(bm25_scores)
        context = ' '.join([self.section_texts[section][d_id]
                            for section, doc_ids in self.section_document_idx.items()
                            for d_id in doc_ids])
        return context

    def match_section_with_answer_text(self, text):
        # Return the sections whose selected paragraphs contain `text`.
        # Requires get_context() to have been called first, since that call
        # populates self.section_document_idx.
        sections = []
        for sec, doc_ids in self.section_document_idx.items():
            for d_id in doc_ids:
                if text in self.section_texts[sec][d_id]:
                    sections.append(sec)
        return sections
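
# Example usage: a minimal sketch. The question, section names, and paragraph
# texts below are made-up illustrations, not data from this project; the only
# assumption is the interface above: section_texts maps section names to lists
# of paragraphs, N is the total paragraph count, and avg_doc_len is the average
# paragraph length in words.
if __name__ == '__main__':
    section_texts = {
        'History': ['The tower was completed in 1889 for the world fair.',
                    'It was initially criticised by artists.'],
        'Design': ['The lattice tower stands 330 metres tall and was the tallest structure.'],
    }
    all_docs = [doc for docs in section_texts.values() for doc in docs]
    N = len(all_docs)
    avg_doc_len = sum(len(doc.split()) for doc in all_docs) / N

    qp = QueryProcessor('How tall is the tower?', section_texts, N, avg_doc_len)
    print(qp.get_query())       # query terms with stopwords removed, e.g. ['tall', 'tower']
    context = qp.get_context()  # paragraphs that cleared the BM25 threshold, joined together
    print(context)
    # Once an answer span is extracted from the context, map it back to its section(s):
    print(qp.match_section_with_answer_text('330 metres'))  # e.g. ['Design']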