moritz648 committed on
Commit
8a30d86
·
1 Parent(s): c4e65da
Files changed (1) hide show
  1. app.py +183 -0
app.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from typing import TypedDict
3
+ from typing import Dict, List
4
+ from datasets import load_dataset
5
+ import joblib
6
+ from dataclasses import dataclass
7
+ from enum import Enum
8
+ from typing import Dict, List
9
+ from dataclasses import dataclass
10
+ from typing import Optional
11
+
12
+
13
@dataclass
class Document:
    """A single corpus entry: one retrievable passage of text."""

    collection_id: str  # unique identifier of this passage within the corpus
    text: str  # the passage text
17
+
18
+
19
@dataclass
class Query:
    """A search query with its dataset-wide identifier."""

    query_id: str  # unique identifier of this query
    text: str  # the query text
23
+
24
+
25
@dataclass
class QRel:
    """A relevance judgement linking a query to a document.

    `answer` optionally carries the gold answer string for the query.
    """

    query_id: str
    collection_id: str
    relevance: int  # graded relevance label (load_sciq only emits 1)
    answer: Optional[str] = None  # gold answer text, when available
31
+
32
+
33
+
34
class Split(str, Enum):
    """Dataset split names; str-valued so members compare equal to plain strings."""

    train = "train"
    dev = "dev"
    test = "test"
38
+
39
+
40
@dataclass
class IRDataset:
    """An in-memory IR dataset: a corpus, a query pool, and per-split qrels."""

    corpus: List[Document]
    queries: List[Query]
    split2qrels: Dict[Split, List[QRel]]

    def get_stats(self) -> Dict[str, int]:
        """Return size statistics for the corpus, queries, and each qrels split."""
        stats = {
            "|corpus|": len(self.corpus),
            "|queries|": len(self.queries),
        }
        for split, qrels in self.split2qrels.items():
            stats[f"|qrels-{split}|"] = len(qrels)
        return stats

    def get_qrels_dict(self, split: Split) -> Dict[str, Dict[str, int]]:
        """Return the split's qrels as a nested mapping: query_id -> {collection_id: relevance}."""
        nested: Dict[str, Dict[str, int]] = {}
        for judgement in self.split2qrels[split]:
            per_query = nested.setdefault(judgement.query_id, {})
            per_query[judgement.collection_id] = judgement.relevance
        return nested

    def get_split_queries(self, split: Split) -> List[Query]:
        """Return only the queries that have at least one qrel in the given split."""
        judged_ids = {qrel.query_id for qrel in self.split2qrels[split]}
        return [query for query in self.queries if query.query_id in judged_ids]
63
+
64
+
65
@(joblib.Memory(".cache").cache)
def load_sciq(verbose: bool = False) -> IRDataset:
    """Load allenai/sciq and build a deduplicated IRDataset.

    Each kept row yields one Document (the support passage), one Query
    (the question), and one relevance-1 QRel linking them, all sharing the
    id f"{split}-{i}". Rows with an empty support, or whose support or
    question text was already seen, are skipped. Results are cached on
    disk via joblib.Memory(".cache").

    Args:
        verbose: if True, print the raw row count of each split.

    Returns:
        IRDataset with corpus, queries, and per-split qrels.
    """
    import pandas as pd  # local import: only needed for the sanity check below

    train = load_dataset("allenai/sciq", split="train")
    validation = load_dataset("allenai/sciq", split="validation")
    test = load_dataset("allenai/sciq", split="test")
    data = {Split.train: train, Split.dev: validation, Split.test: test}

    # Sanity check: every duplicated question carries identical support and
    # answer text, so the per-question value sets collapse to size 1.
    # BUGFIX: the original used `df1 + df2 + df3`, which on DataFrames does
    # index-aligned element-wise addition (garbling rows / producing NaN),
    # not concatenation — pd.concat is the correct way to stack the splits.
    df = pd.concat([train.to_pandas(), validation.to_pandas(), test.to_pandas()])
    for _question, group in df.groupby("question"):
        assert len(set(group["support"].tolist())) == 1
        assert len(set(group["correct_answer"].tolist())) == 1

    # Build:
    corpus: List[Document] = []
    queries: List[Query] = []
    split2qrels: Dict[Split, List[QRel]] = {}  # fixed annotation (was Dict[str, List[dict]])
    question2id: Dict[str, str] = {}
    support2id: Dict[str, str] = {}
    for split, rows in data.items():
        if verbose:
            print(f"|raw_{split}|", len(rows))
        split2qrels[split] = []
        for i, row in enumerate(rows):
            example_id = f"{split}-{i}"
            support: str = row["support"]
            if len(support.strip()) == 0:
                # Skip records without a support passage.
                # (The original repeated this exact check twice; the dead
                # duplicate has been removed.)
                continue
            question = row["question"]
            # Deduplicate on both the support text and the question text:
            if support in support2id:
                continue
            support2id[support] = example_id
            if question in question2id:
                continue
            question2id[question] = example_id
            corpus.append(Document(collection_id=example_id, text=support))
            queries.append(Query(query_id=example_id, text=question))
            split2qrels[split].append(
                QRel(
                    query_id=example_id,
                    collection_id=example_id,
                    relevance=1,
                    answer=row["correct_answer"],
                )
            )

    # Assembly and return:
    return IRDataset(corpus=corpus, queries=queries, split2qrels=split2qrels)
118
+
119
+
120
if __name__ == "__main__":
    # python -m nlp4web_codebase.ir.data_loaders.sciq
    # Use stdlib json instead of the third-party ujson: json.dumps(..., indent=4)
    # is equivalent here and drops an unnecessary dependency.
    import json
    import time

    start = time.time()
    dataset = load_sciq(verbose=True)
    print(f"Loading costs: {time.time() - start}s")
    print(json.dumps(dataset.get_stats(), indent=4))
    # Reference output from a previous (cached) run:
    # ________________________________________________________________________________
    # [Memory] Calling __main__--home-kwang-research-nlp4web-ir-exercise-nlp4web-nlp4web-ir-data_loaders-sciq.load_sciq...
    # load_sciq(verbose=True)
    # |raw_train| 11679
    # |raw_dev| 1000
    # |raw_test| 1000
    # ________________________________________________________load_sciq - 7.3s, 0.1min
    # Loading costs: 7.260092735290527s
    # {
    #     "|corpus|": 12160,
    #     "|queries|": 12160,
    #     "|qrels-train|": 10409,
    #     "|qrels-dev|": 875,
    #     "|qrels-test|": 876
    # }
144
+
145
+
146
class Hit(TypedDict):
    """One retrieval result as returned by `search`."""

    cid: str  # collection id of the retrieved document
    score: float  # retrieval score assigned by the retriever
    text: str  # text of the retrieved document
150
+
151
+ ## YOUR_CODE_STARTS_HERE
152
def search(query: str) -> List[Hit]:
    """Retrieve SciQ documents for `query` with BM25 and return them as Hit dicts.

    The corpus is loaded and the BM25 index is built and saved ONCE, on the
    first call, then cached on the function object — the original rebuilt the
    whole index on every query. The original also called `run_counting`, a
    name never defined or imported in this file (a guaranteed NameError) whose
    result was unused; that call has been removed.

    Args:
        query: free-text query string.

    Returns:
        List of hits, each with the document's collection id, score, and text.
    """
    if not hasattr(search, "_retriever"):
        sciq = load_sciq()
        bm25_index = BM25Index.build_from_documents(
            documents=iter(sciq.corpus),
            ndocs=len(sciq.corpus),  # was hard-coded to 12160
            show_progress_bar=True,
        )
        bm25_index.save("output/bm25_index")
        search._retriever = BM25Retriever(index_dir="output/bm25_index")
    retriever = search._retriever

    results = retriever.retrieve(query=query)

    hits: List[Hit] = []
    for cid, score in results.items():
        docid = retriever.index.cid2docid[cid]
        text = retriever.index.doc_texts[docid]
        hits.append({"cid": cid, "score": score, "text": text})

    return hits
175
+ ## YOUR_CODE_ENDS_HERE
176
+
177
# Gradio UI: one query textbox in, the JSON-serialized list of hits out.
demo: Optional[gr.Interface] = gr.Interface(
    fn=search,
    inputs=gr.Textbox(label="Query"),
    outputs=gr.JSON(label="Results"),
)  # Assign your gradio demo to this variable
return_type = List[Hit]

demo.launch()