import numpy as np


class QueryProcessor:
    """BM25-based retrieval over a mapping of section name -> list of paragraphs.

    Given a natural-language question, scores every paragraph against the
    query terms with BM25, keeps the well-scoring paragraphs, and joins them
    into a single context string for downstream question answering.
    """

    def __init__(self, question, section_texts, N, avg_doc_len):
        """
        Args:
            question: raw question string; tokenized and stopword-filtered here.
            section_texts: dict mapping section name -> list of paragraph strings.
            N: total number of documents (paragraphs) in the collection.
            avg_doc_len: average paragraph length in words, used by BM25.
        """
        self.section_texts = section_texts
        self.N = N
        self.avg_doc_len = avg_doc_len
        self.query_items = self.set_query(question)
        # Populated by get_context(): section -> list of kept doc indices.
        self.section_document_idx = None

    def set_query(self, question):
        """Lowercase the question, tokenize on word characters, drop stopwords."""
        # nltk is imported lazily so that merely importing this module does not
        # require nltk to be installed or trigger a network download of the
        # stopword corpus; the download still happens on first use of the class.
        import nltk
        nltk.download('stopwords')
        from nltk.corpus import stopwords
        from nltk.tokenize import RegexpTokenizer

        # Build the stopword set once: O(1) membership tests instead of
        # re-reading the corpus word list for every token.
        stop_words = set(stopwords.words('english'))
        tokenizer = RegexpTokenizer(r'\w+')
        return [tok for tok in tokenizer.tokenize(question.lower())
                if tok not in stop_words]

    def get_query(self):
        """Return the processed (tokenized, stopword-free) query terms."""
        return self.query_items

    def bm25(self, word, paragraph, k=1.2, b=0.75):
        """Return the BM25 score of `word` against `paragraph`, rounded to 4 dp.

        Uses the Okapi BM25 term-frequency saturation (parameters k and b)
        and the `log(1 + ...)` idf variant, which is never negative.
        """
        # Raw occurrence count of the word in this paragraph.
        freq = paragraph.split().count(word)
        # Saturating term frequency, normalized by paragraph length.
        tf = (freq * (k + 1)) / (
            freq + k * (1 - b + b * len(paragraph.split()) / self.avg_doc_len)
        )
        # Number of paragraphs (across all sections) that contain the word.
        N_q = sum(1 for docs in self.section_texts.values()
                  for doc in docs if word in doc.split())
        # Inverse document frequency.
        idf = np.log(((self.N - N_q + 0.5) / (N_q + 0.5)) + 1)
        return round(tf * idf, 4)

    def get_bm25_scores(self):
        """Score every query term against every paragraph.

        Returns:
            dict: query term -> section -> doc index -> positive BM25 score.
            Sections where no paragraph scored above zero are omitted.
        """
        bm25_scores = {}
        for query in self.query_items:
            bm25_scores[query] = {}
            for section, docs in self.section_texts.items():
                bm25_scores[query][section] = {}
                for doc_index, doc in enumerate(docs):
                    score = self.bm25(query, doc)
                    if score > 0.0:
                        bm25_scores[query][section][doc_index] = score
                # Drop sections with no positive-scoring paragraph.
                if not bm25_scores[query][section]:
                    del bm25_scores[query][section]
        return bm25_scores

    def filter_bad_documents(self, bm25_scores):
        """Keep only documents scoring above 0.5 for at least one query term.

        Args:
            bm25_scores: output of get_bm25_scores().

        Returns:
            dict: section -> list of unique doc indices worth keeping.
        """
        section_document_idx = {}
        for sec_docs in bm25_scores.values():
            for sec, doc_scores in sec_docs.items():
                if sec not in section_document_idx:
                    section_document_idx[sec] = []
                for doc_idx, score in doc_scores.items():
                    # 0.5 is the relevance cut-off; dedupe across query terms.
                    if score > 0.5 and doc_idx not in section_document_idx[sec]:
                        section_document_idx[sec].append(doc_idx)
                # Remove sections where nothing passed the cut-off.
                if not section_document_idx[sec]:
                    del section_document_idx[sec]
        return section_document_idx

    def get_context(self):
        """Concatenate all relevant paragraphs into a single context string.

        Side effect: caches the kept section/doc indices on
        `self.section_document_idx` for later answer-to-section matching.
        """
        bm25_scores = self.get_bm25_scores()
        self.section_document_idx = self.filter_bad_documents(bm25_scores)
        return ' '.join(
            self.section_texts[section][d_id]
            for section, doc_ids in self.section_document_idx.items()
            for d_id in doc_ids
        )

    def match_section_with_answer_text(self, text):
        """Return the sections whose kept paragraphs contain `text` as a substring.

        Requires get_context() to have been called first (it fills
        `self.section_document_idx`). A section may appear multiple times if
        several of its paragraphs match.
        """
        sections = []
        for sec, doc_ids in self.section_document_idx.items():
            for d_id in doc_ids:
                if text in self.section_texts[sec][d_id]:
                    sections.append(sec)
        return sections