asimmetti committed
Commit 3274425 · verified · 1 Parent(s): 9abdfdc

Create app.py

Files changed (1):
  app.py +342 -0
app.py ADDED
@@ -0,0 +1,342 @@
from __future__ import annotations
from dataclasses import dataclass
import pickle
import os
from typing import Iterable, Callable, List, Dict, Optional, Type, TypeVar
from nlp4web_codebase.ir.data_loaders.dm import Document
from collections import Counter
import tqdm
import re
import nltk
nltk.download("stopwords", quiet=True)
from nltk.corpus import stopwords as nltk_stopwords

LANGUAGE = "english"
word_splitter = re.compile(r"(?u)\b\w\w+\b").findall
stopwords = set(nltk_stopwords.words(LANGUAGE))


def word_splitting(text: str) -> List[str]:
    return word_splitter(text.lower())

def lemmatization(words: List[str]) -> List[str]:
    return words  # We ignore lemmatization here for simplicity

def simple_tokenize(text: str) -> List[str]:
    words = word_splitting(text)
    tokenized = list(filter(lambda w: w not in stopwords, words))
    tokenized = lemmatization(tokenized)
    return tokenized

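# Quick illustrative check of the tokenizer above (the sample sentence is our
# own, not part of the original app): stopwords such as "the"/"of" are dropped
# and the remaining words are lowercased by word_splitting.
assert simple_tokenize("The Mitochondria of the Cell") == ["mitochondria", "cell"]
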
T = TypeVar("T", bound="InvertedIndex")

@dataclass
class PostingList:
    term: str  # The term
    docid_postings: List[int]  # docid_postings[i] means the docid (int) of the i-th associated posting
    tweight_postings: List[float]  # tweight_postings[i] means the term weight (float) of the i-th associated posting

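# Illustrative shape of a posting list (the numbers are made up): a term that
# occurs twice in docid 0 and once in docid 7 would be stored as
#   PostingList(term="cell", docid_postings=[0, 7], tweight_postings=[2.0, 1.0])
# before the BM25 weighting below replaces the raw term frequencies.
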

@dataclass
class InvertedIndex:
    posting_lists: List[PostingList]  # tid -> posting_list
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]  # collection_id -> docid
    collection_ids: List[str]  # docid -> collection_id
    doc_texts: Optional[List[str]] = None  # docid -> document text

    def save(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "index.pkl"), "wb") as f:
            pickle.dump(self, f)

    @classmethod
    def from_saved(cls: Type[T], saved_dir: str) -> T:
        index = cls(
            posting_lists=[], vocab={}, cid2docid={}, collection_ids=[], doc_texts=None
        )
        with open(os.path.join(saved_dir, "index.pkl"), "rb") as f:
            index = pickle.load(f)
        return index

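# Persistence is a plain pickle round trip (the path below is illustrative):
# index.save("output/my_index") writes output/my_index/index.pkl, and
# InvertedIndex.from_saved("output/my_index") unpickles it back. Subclasses
# such as BM25Index below reuse the same mechanism unchanged.
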

# The output of the counting function:
@dataclass
class Counting:
    posting_lists: List[PostingList]
    vocab: Dict[str, int]
    cid2docid: Dict[str, int]
    collection_ids: List[str]
    dfs: List[int]  # tid -> df
    dls: List[int]  # docid -> doc length
    avgdl: float
    nterms: int
    doc_texts: Optional[List[str]] = None

def run_counting(
    documents: Iterable[Document],
    tokenize_fn: Callable[[str], List[str]] = simple_tokenize,
    store_raw: bool = True,  # store the document text in doc_texts
    ndocs: Optional[int] = None,
    show_progress_bar: bool = True,
) -> Counting:
    """Counting TFs, DFs, doc_lengths, etc."""
    posting_lists: List[PostingList] = []
    vocab: Dict[str, int] = {}
    cid2docid: Dict[str, int] = {}
    collection_ids: List[str] = []
    dfs: List[int] = []  # tid -> df
    dls: List[int] = []  # docid -> doc length
    nterms: int = 0
    doc_texts: Optional[List[str]] = []
    for doc in tqdm.tqdm(
        documents,
        desc="Counting",
        total=ndocs,
        disable=not show_progress_bar,
    ):
        if doc.collection_id in cid2docid:
            continue
        collection_ids.append(doc.collection_id)
        docid = cid2docid.setdefault(doc.collection_id, len(cid2docid))
        toks = tokenize_fn(doc.text)
        tok2tf = Counter(toks)
        dls.append(sum(tok2tf.values()))
        for tok, tf in tok2tf.items():
            nterms += tf
            tid = vocab.get(tok, None)
            if tid is None:
                posting_lists.append(
                    PostingList(term=tok, docid_postings=[], tweight_postings=[])
                )
                tid = vocab.setdefault(tok, len(vocab))
            posting_lists[tid].docid_postings.append(docid)
            posting_lists[tid].tweight_postings.append(tf)
            if tid < len(dfs):
                dfs[tid] += 1
            else:
                dfs.append(0)
        if store_raw:
            doc_texts.append(doc.text)
        else:
            doc_texts = None
    return Counting(
        posting_lists=posting_lists,
        vocab=vocab,
        cid2docid=cid2docid,
        collection_ids=collection_ids,
        dfs=dfs,
        dls=dls,
        avgdl=sum(dls) / len(dls),
        nterms=nterms,
        doc_texts=doc_texts,
    )

from nlp4web_codebase.ir.data_loaders.sciq import load_sciq
sciq = load_sciq()
counting = run_counting(documents=iter(sciq.corpus), ndocs=len(sciq.corpus))

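# A small illustrative peek at the counting output (the field names come from
# the Counting dataclass above; the actual values depend on the SciQ corpus).
_vocab_size = len(counting.vocab)          # number of distinct terms
_num_docs = len(counting.collection_ids)   # number of indexed documents
_avg_doc_len = counting.avgdl              # average tokenized document length
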


from dataclasses import asdict, dataclass
import math
import os
from typing import Iterable, List, Optional, Type
import tqdm
from nlp4web_codebase.ir.data_loaders.dm import Document


@dataclass
class BM25Index(InvertedIndex):

    @staticmethod
    def tokenize(text: str) -> List[str]:
        return simple_tokenize(text)

    @staticmethod
    def cache_term_weights(
        posting_lists: List[PostingList],
        total_docs: int,
        avgdl: float,
        dfs: List[int],
        dls: List[int],
        k1: float,
        b: float,
    ) -> None:
        """Compute term weights and caching"""

        N = total_docs
        for tid, posting_list in enumerate(
            tqdm.tqdm(posting_lists, desc="Regularizing TFs")
        ):
            idf = BM25Index.calc_idf(df=dfs[tid], N=N)
            for i in range(len(posting_list.docid_postings)):
                docid = posting_list.docid_postings[i]
                tf = posting_list.tweight_postings[i]
                dl = dls[docid]
                regularized_tf = BM25Index.calc_regularized_tf(
                    tf=tf, dl=dl, avgdl=avgdl, k1=k1, b=b
                )
                posting_list.tweight_postings[i] = regularized_tf * idf

    @staticmethod
    def calc_regularized_tf(
        tf: int, dl: float, avgdl: float, k1: float, b: float
    ) -> float:
        return tf / (tf + k1 * (1 - b + b * dl / avgdl))

    @staticmethod
    def calc_idf(df: int, N: int):
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    @classmethod
    def build_from_documents(
        cls: Type[BM25Index],
        documents: Iterable[Document],
        store_raw: bool = True,
        output_dir: Optional[str] = None,
        ndocs: Optional[int] = None,
        show_progress_bar: bool = True,
        k1: float = 0.9,
        b: float = 0.4,
    ) -> BM25Index:
        # Counting TFs, DFs, doc_lengths, etc.:
        counting = run_counting(
            documents=documents,
            tokenize_fn=BM25Index.tokenize,
            store_raw=store_raw,
            ndocs=ndocs,
            show_progress_bar=show_progress_bar,
        )

        # Compute term weights and caching:
        posting_lists = counting.posting_lists
        total_docs = len(counting.cid2docid)
        BM25Index.cache_term_weights(
            posting_lists=posting_lists,
            total_docs=total_docs,
            avgdl=counting.avgdl,
            dfs=counting.dfs,
            dls=counting.dls,
            k1=k1,
            b=b,
        )

        # Assembly and save:
        index = BM25Index(
            posting_lists=posting_lists,
            vocab=counting.vocab,
            cid2docid=counting.cid2docid,
            collection_ids=counting.collection_ids,
            doc_texts=counting.doc_texts,
        )
        return index

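# Worked example of the BM25 weighting cached above (all numbers are
# illustrative, not taken from the SciQ index): a term with tf=3 in a document
# of length 120, with avgdl=100, appearing in 50 of 10,000 documents.
_example_tf = BM25Index.calc_regularized_tf(tf=3, dl=120, avgdl=100, k1=0.9, b=0.4)
_example_idf = BM25Index.calc_idf(df=50, N=10_000)
# _example_tf ≈ 0.755, _example_idf ≈ 5.29, so the cached weight would be ≈ 3.99.
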
from nlp4web_codebase.ir.models import BaseRetriever
from typing import Type
from abc import abstractmethod


class BaseInvertedIndexRetriever(BaseRetriever):

    @property
    @abstractmethod
    def index_class(self) -> Type[InvertedIndex]:
        pass

    def __init__(self, index_dir: str) -> None:
        self.index = self.index_class.from_saved(index_dir)

    def get_term_weights(self, query: str, cid: str) -> Dict[str, float]:
        toks = self.index.tokenize(query)
        target_docid = self.index.cid2docid[cid]
        term_weights = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                if docid == target_docid:
                    term_weights[tok] = tweight
                    break
        return term_weights

    def score(self, query: str, cid: str) -> float:
        return sum(self.get_term_weights(query=query, cid=cid).values())

    def retrieve(self, query: str, topk: int = 10) -> Dict[str, float]:
        toks = self.index.tokenize(query)
        docid2score: Dict[int, float] = {}
        for tok in toks:
            if tok not in self.index.vocab:
                continue
            tid = self.index.vocab[tok]
            posting_list = self.index.posting_lists[tid]
            for docid, tweight in zip(
                posting_list.docid_postings, posting_list.tweight_postings
            ):
                docid2score.setdefault(docid, 0)
                docid2score[docid] += tweight
        docid2score = dict(
            sorted(docid2score.items(), key=lambda pair: pair[1], reverse=True)[:topk]
        )
        return {
            self.index.collection_ids[docid]: score
            for docid, score in docid2score.items()
        }


class BM25Retriever(BaseInvertedIndexRetriever):

    @property
    def index_class(self) -> Type[BM25Index]:
        return BM25Index

import gradio as gr
from typing import List, Optional, TypedDict

class Hit(TypedDict):
    cid: str
    score: float
    text: str

demo: Optional[gr.Interface] = None  # Assign your gradio demo to this variable
return_type = List[Hit]

## YOUR_CODE_STARTS_HERE
bm25_index = BM25Index.build_from_documents(
    documents=iter(sciq.corpus),
    ndocs=len(sciq.corpus),
    k1=0.9,
    b=0.4,
)
bm25_index.save("output/bm25_index_b")  # Save index to directory
bm25_retriever = BM25Retriever(index_dir="output/bm25_index_b")

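# Illustrative retrieval call (the query string is ours; the scores depend on
# the index just built): retrieve() returns a dict mapping SciQ collection_ids
# to BM25 scores, highest first.
_example_hits = bm25_retriever.retrieve("What is the powerhouse of the cell?", topk=3)
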
corpus_dict = {doc.collection_id: doc.text for doc in sciq.corpus}

def get_query(query: str) -> List[Hit]:
    results = bm25_retriever.retrieve(query)
    hits = [
        {
            "cid": cid,
            "score": score,
            "text": corpus_dict[cid],
        }
        for cid, score in results.items()
    ]
    return hits

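# The returned hits follow the Hit TypedDict declared above; an entry looks
# roughly like {"cid": "...", "score": 12.3, "text": "..."} (values illustrative).
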

demo = gr.Interface(
    fn=get_query,
    inputs=gr.Textbox(label="Enter your query"),
    outputs=gr.Textbox(label="Results", lines=20, interactive=False),
    title="BM25 Query Engine",
)
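# Note: gr.Textbox renders str() of the returned list of Hit dicts. A gr.JSON
# output component would show the same hits as structured JSON, if preferred.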
## YOUR_CODE_ENDS_HERE
demo.launch()