moritz648
committed on
Commit
·
8a30d86
1
Parent(s):
c4e65da
app.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, TypedDict

import gradio as gr
import joblib
import pandas as pd
from datasets import load_dataset
|
11 |
+
|
12 |
+
|
13 |
+
@dataclass
class Document:
    """A single retrievable passage in the corpus."""

    collection_id: str  # unique id of this document within the collection
    text: str  # raw passage text
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
class Query:
    """A search query with its unique identifier."""

    query_id: str  # unique id of this query
    text: str  # natural-language question text
|
23 |
+
|
24 |
+
|
25 |
+
@dataclass
class QRel:
    """A relevance judgment linking a query to a document."""

    query_id: str  # id of the judged query
    collection_id: str  # id of the judged document
    relevance: int  # relevance label (load_sciq below always uses 1)
    answer: Optional[str] = None  # gold answer attached to the judgment, if any
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
class Split(str, Enum):
    """Dataset split names; str-valued so members format/compare as plain strings."""

    train = "train"
    dev = "dev"
    test = "test"
|
38 |
+
|
39 |
+
|
40 |
+
@dataclass
class IRDataset:
    """An in-memory IR dataset: corpus, queries, and per-split relevance judgments."""

    corpus: List[Document]
    queries: List[Query]
    split2qrels: Dict[Split, List[QRel]]

    def get_stats(self) -> Dict[str, int]:
        """Return size statistics for the corpus, the queries, and each qrels split."""
        stats: Dict[str, int] = {
            "|corpus|": len(self.corpus),
            "|queries|": len(self.queries),
        }
        stats.update(
            {f"|qrels-{split}|": len(qrels) for split, qrels in self.split2qrels.items()}
        )
        return stats

    def get_qrels_dict(self, split: Split) -> Dict[str, Dict[str, int]]:
        """Return a nested mapping query_id -> {collection_id -> relevance} for `split`."""
        qrels_dict: Dict[str, Dict[str, int]] = {}
        for judgment in self.split2qrels[split]:
            per_query = qrels_dict.setdefault(judgment.query_id, {})
            per_query[judgment.collection_id] = judgment.relevance
        return qrels_dict

    def get_split_queries(self, split: Split) -> List[Query]:
        """Return only the queries that have at least one qrel in `split`."""
        judged_qids = {judgment.query_id for judgment in self.split2qrels[split]}
        return [query for query in self.queries if query.query_id in judged_qids]
|
63 |
+
|
64 |
+
|
65 |
+
@(joblib.Memory(".cache").cache)
def load_sciq(verbose: bool = False) -> IRDataset:
    """Load allenai/sciq and convert it into an :class:`IRDataset`.

    Questions become queries, their "support" passages become the corpus, and
    each kept (question, support) pair yields one positive qrel carrying the
    gold answer. Rows with an empty support and duplicated supports/questions
    are dropped (first occurrence wins). Results are disk-cached by joblib
    under ".cache".

    Args:
        verbose: if True, print the raw number of rows per split.

    Returns:
        IRDataset with corpus, queries, and per-split qrels.
    """
    train = load_dataset("allenai/sciq", split="train")
    validation = load_dataset("allenai/sciq", split="validation")
    test = load_dataset("allenai/sciq", split="test")
    data = {Split.train: train, Split.dev: validation, Split.test: test}

    # Sanity check: duplicated questions must be exact copies of each other,
    # so the first-occurrence deduplication below loses no information.
    # (Fix: the previous code combined the splits with `+`, which performs
    # element-wise DataFrame addition, not concatenation; and its assertion
    # `== len(group)` contradicted this invariant — it should be `== 1`.)
    df = pd.concat([rows.to_pandas() for rows in data.values()], ignore_index=True)
    for _question, group in df.groupby("question"):
        assert len(set(group["support"].tolist())) == 1
        assert len(set(group["correct_answer"].tolist())) == 1

    # Build:
    corpus: List[Document] = []
    queries: List[Query] = []
    split2qrels: Dict[Split, List[QRel]] = {}
    question2id: Dict[str, str] = {}
    support2id: Dict[str, str] = {}
    for split, rows in data.items():
        if verbose:
            print(f"|raw_{split}|", len(rows))
        split2qrels[split] = []
        for i, row in enumerate(rows):
            example_id = f"{split}-{i}"
            support: str = row["support"]
            if len(support.strip()) == 0:
                # Skip rows without a supporting passage.
                continue
            question = row["question"]
            # Keep only the first occurrence of each support and each question.
            if support in support2id:
                continue
            support2id[support] = example_id
            if question in question2id:
                continue
            question2id[question] = example_id
            corpus.append(Document(collection_id=example_id, text=support))
            queries.append(Query(query_id=example_id, text=question))
            split2qrels[split].append(
                QRel(
                    query_id=example_id,
                    collection_id=example_id,
                    relevance=1,
                    answer=row["correct_answer"],
                )
            )

    # Assembly and return:
    return IRDataset(corpus=corpus, queries=queries, split2qrels=split2qrels)
|
118 |
+
|
119 |
+
|
120 |
+
if __name__ == "__main__":
    # Smoke test / CLI usage: python -m nlp4web_codebase.ir.data_loaders.sciq
    import ujson
    import time

    # Time the (possibly joblib-cached) dataset load and print size statistics.
    start = time.time()
    dataset = load_sciq(verbose=True)
    print(f"Loading costs: {time.time() - start}s")
    print(ujson.dumps(dataset.get_stats(), indent=4))
    # Expected output transcript from an uncached run:
    # ________________________________________________________________________________
    # [Memory] Calling __main__--home-kwang-research-nlp4web-ir-exercise-nlp4web-nlp4web-ir-data_loaders-sciq.load_sciq...
    # load_sciq(verbose=True)
    # |raw_train| 11679
    # |raw_dev| 1000
    # |raw_test| 1000
    # ________________________________________________________load_sciq - 7.3s, 0.1min
    # Loading costs: 7.260092735290527s
    # {
    #     "|corpus|": 12160,
    #     "|queries|": 12160,
    #     "|qrels-train|": 10409,
    #     "|qrels-dev|": 875,
    #     "|qrels-test|": 876
    # }
|
144 |
+
|
145 |
+
|
146 |
+
class Hit(TypedDict):
    """One retrieval result returned by `search`."""

    cid: str  # collection id of the retrieved document
    score: float  # retrieval (BM25) score
    text: str  # retrieved document text
|
150 |
+
|
151 |
+
## YOUR_CODE_STARTS_HERE
# Retriever is built lazily once and reused; the previous code rebuilt the full
# BM25 index from scratch on every single query.
_bm25_retriever: Optional["BM25Retriever"] = None


def _get_bm25_retriever() -> "BM25Retriever":
    """Build the BM25 index on first use and cache the retriever.

    Also fixes two issues of the original inline code: the hard-coded
    ndocs=12160 is derived from the corpus instead, and the unused
    `run_counting(...)` call (result was never read) is dropped.
    """
    global _bm25_retriever
    if _bm25_retriever is None:
        sciq = load_sciq()
        bm25_index = BM25Index.build_from_documents(
            documents=iter(sciq.corpus),
            ndocs=len(sciq.corpus),  # was hard-coded to 12160
            show_progress_bar=True,
        )
        bm25_index.save("output/bm25_index")
        _bm25_retriever = BM25Retriever(index_dir="output/bm25_index")
    return _bm25_retriever


def search(query: str) -> List[Hit]:
    """Run BM25 retrieval for `query`.

    Args:
        query: free-text search query.

    Returns:
        A list of Hit dicts, one per retrieved document, with its collection
        id ("cid"), retrieval score, and text.
    """
    retriever = _get_bm25_retriever()
    results = retriever.retrieve(query=query)  # mapping cid -> score

    hits: List[Hit] = []
    for cid, score in results.items():
        docid = retriever.index.cid2docid[cid]
        text = retriever.index.doc_texts[docid]
        hits.append({"cid": cid, "score": score, "text": text})
    return hits
## YOUR_CODE_ENDS_HERE
|
176 |
+
|
177 |
+
# Gradio UI: a single text box feeding `search`, with results rendered as JSON.
demo: Optional[gr.Interface] = gr.Interface(
    fn=search,
    inputs=gr.Textbox(label="Query"),
    outputs=gr.JSON(label="Results")
)  # Assign your gradio demo to this variable
return_type = List[Hit]  # NOTE(review): presumably grader/exercise metadata — confirm it is read elsewhere
demo.launch()  # Starts the web server at import time (runs whenever this module is loaded)
|