CarlosMalaga committed on
Commit
8e4a1aa
·
verified ·
1 Parent(s): a22a927

Delete examples

Browse files
examples/explore_faiss.md DELETED
@@ -1,8 +0,0 @@
1
- # table to store results
2
-
3
- | Index | nprobe | Recall | Time |
4
- |----------------|--------|--------|-------|
5
- | Flat | 1 | 98.7 | 38.64 |
6
- | IVFx,Flat | 1 | 42.5 | 23.46 |
7
- | IVFx,Flat | 14 | 88.5 | 133 |
8
- | IVFx_HNSW,Flat | 1 | 88.5 | 133 |
 
 
 
 
 
 
 
 
 
examples/explore_faiss.py DELETED
@@ -1,163 +0,0 @@
1
- import argparse
2
- import json
3
- import logging
4
- import os
5
- from pathlib import Path
6
- import time
7
- from typing import Union
8
-
9
- import torch
10
- import tqdm
11
-
12
- from relik.retriever import GoldenRetriever
13
- from relik.common.log import get_logger
14
- from relik.retriever.common.model_inputs import ModelInputs
15
- from relik.retriever.data.base.datasets import BaseDataset
16
- from relik.retriever.indexers.base import BaseDocumentIndex
17
- from relik.retriever.indexers.faiss import FaissDocumentIndex
18
-
19
- logger = get_logger(level=logging.INFO)
20
-
21
-
22
def compute_retriever_stats(dataset) -> None:
    """Print the retrieval recall over an iterable of windowed samples.

    Each sample is a mapping with:

    * ``"window_candidates"``: the retrieved candidate titles;
    * ``"window_labels"``: ``(start, end, label)`` triples.

    A label counts as recalled when, after normalization (underscores
    replaced by spaces, lowercased), it appears among the normalized
    candidates of the same sample. Labels equal to ``"--NME--"`` are
    skipped and do not contribute to the total.

    The result is printed as ``Recall: <value>``. When there is no scorable
    label at all (empty dataset, or only ``"--NME--"`` labels) the recall is
    reported as ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    correct, total = 0, 0
    for sample in dataset:
        # A set gives O(1) membership tests for every label of the window.
        candidates = {
            c.replace("_", " ").lower() for c in sample["window_candidates"]
        }

        for _, _, label in sample["window_labels"]:
            if label == "--NME--":
                continue
            if label.replace("_", " ").lower() in candidates:
                correct += 1
            total += 1

    # Guard against datasets without any scorable label.
    recall = correct / total if total else 0.0
    print("Recall:", recall)
37
-
38
-
39
@torch.no_grad()
def add_candidates(
    retriever_name_or_path: Union[str, os.PathLike],
    document_index_name_or_path: Union[str, os.PathLike],
    input_path: Union[str, os.PathLike],
    batch_size: int = 128,
    num_workers: int = 4,
    index_type: str = "Flat",
    nprobe: int = 1,
    device: str = "cpu",
    precision: str = "fp32",
    topics: bool = False,
):
    """Retrieve candidates for every sample of a JSONL file and report recall.

    The document index and the question encoder are loaded, the top-100
    candidates are retrieved for each sample's ``"text"``, and the candidate
    titles and scores are stored on the sample under ``"window_candidates"``
    and ``"window_candidates_scores"``. Finally the recall (via
    ``compute_retriever_stats``) and the elapsed retrieval time are printed.

    Args:
        retriever_name_or_path: Question encoder passed to ``GoldenRetriever``.
        document_index_name_or_path: Index passed to
            ``BaseDocumentIndex.from_pretrained``.
        input_path: JSONL file with one sample per line; blank lines are
            ignored.
        batch_size: Batch size of the dataloader.
        num_workers: Number of dataloader workers.
        index_type: Currently unused; kept for interface compatibility.
        nprobe: Currently unused; kept for interface compatibility.
        device: Device used for both the retriever and the index.
        precision: Precision used for the retriever, the index and retrieval.
        topics: When true *and* the first sample has a ``"doc_topic"`` key,
            the topic is passed to the tokenizer as ``text_pair``.

    Returns:
        None. If the input file contains no samples a warning is logged and
        nothing is retrieved.
    """
    # NOTE: `index_type` and `nprobe` are not forwarded to the index; the index
    # is loaded with whatever configuration it was saved with.
    document_index = BaseDocumentIndex.from_pretrained(
        document_index_name_or_path,
        device=device,
        precision=precision,
    )

    retriever = GoldenRetriever(
        question_encoder=retriever_name_or_path,
        document_index=document_index,
        device=device,
        precision=precision,
        index_device=device,
        index_precision=precision,
    )
    retriever.eval()

    logger.info(f"Loading from {input_path}")
    with open(input_path, encoding="utf-8") as f:
        # Iterate the file lazily and tolerate blank (e.g. trailing) lines.
        samples = [json.loads(line) for line in f if line.strip()]

    if not samples:
        logger.warning(f"No samples found in {input_path}, nothing to retrieve")
        return

    # Topics can only be used when the data actually carries them.
    topics = topics and "doc_topic" in samples[0]

    tokenizer = retriever.question_tokenizer

    def collate_fn(batch):
        """Tokenize a batch of samples into the retriever's model inputs."""
        return ModelInputs(
            tokenizer(
                [b["text"] for b in batch],
                text_pair=[b["doc_topic"] for b in batch] if topics else None,
                padding=True,
                return_tensors="pt",
                truncation=True,
            )
        )

    logger.info(f"Creating dataloader with batch size {batch_size}")
    dataloader = torch.utils.data.DataLoader(
        BaseDataset(name="passage", data=samples),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=False,
        collate_fn=collate_fn,
    )

    retrieved_accumulator = []
    with torch.inference_mode():
        start = time.time()
        for documents_batch in tqdm.tqdm(dataloader):
            retrieve_kwargs = {
                **documents_batch,
                "k": 100,
                "precision": precision,
            }
            retrieved_accumulator.extend(retriever.retrieve(**retrieve_kwargs))
        end = time.time()

    # The dataloader is not shuffled, so the i-th retrieval result belongs to
    # the i-th sample.
    output_data = []
    for sample, retrieved in zip(samples, retrieved_accumulator):
        # The candidate title is the label text before the " <def>" marker.
        sample["window_candidates"] = [
            c.label.split(" <def>", 1)[0] for c in retrieved
        ]
        sample["window_candidates_scores"] = [c.score for c in retrieved]
        output_data.append(sample)

    compute_retriever_stats(output_data)
    print(f"Retrieval took {end - start:.2f} seconds")
140
-
141
-
142
if __name__ == "__main__":
    # Command-line entry point. Every default reproduces the experiment that
    # used to be hard-coded here, so running the script without arguments
    # behaves as before. `index_type` and `nprobe` are not exposed because
    # `add_candidates` does not forward them to the index.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--retriever_name_or_path",
        type=str,
        default="/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder",
    )
    arg_parser.add_argument(
        "--document_index_name_or_path",
        type=str,
        default="/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered",
    )
    arg_parser.add_argument(
        "--input_path",
        type=str,
        default="/root/relik-spaces/data/reader/aida/testa_windowed.jsonl",
    )
    arg_parser.add_argument("--batch_size", type=int, default=128)
    arg_parser.add_argument("--num_workers", type=int, default=4)
    arg_parser.add_argument("--device", type=str, default="cuda")
    arg_parser.add_argument("--precision", type=str, default="fp32")
    # Topics are enabled by default; pass --no_topics to disable them.
    arg_parser.add_argument(
        "--no_topics", dest="topics", action="store_false", default=True
    )

    add_candidates(**vars(arg_parser.parse_args()))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/train_retriever.py DELETED
@@ -1,45 +0,0 @@
1
- from relik.retriever.trainer import RetrieverTrainer
2
- from relik import GoldenRetriever
3
- from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
4
- from relik.retriever.data.datasets import AidaInBatchNegativesDataset
5
-
6
if __name__ == "__main__":
    # Build the retriever: an in-memory document index paired with the
    # question encoder.
    index = InMemoryDocumentIndex(
        documents="/root/golden-retriever-v2/data/dpr-like/el/definitions.txt",
        device="cuda",
        precision="16",
    )
    retriever = GoldenRetriever(
        question_encoder="intfloat/e5-small-v2",
        document_index=index,
    )

    # Settings shared by the training and validation datasets.
    shared_dataset_kwargs = dict(
        tokenizer=retriever.question_tokenizer,
        question_batch_size=64,
        passage_batch_size=400,
        max_passage_length=64,
        use_topics=True,
    )
    data_dir = "/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic"

    # Only the training split is shuffled.
    train_dataset = AidaInBatchNegativesDataset(
        name="aida_train",
        path=data_dir + "/train.jsonl",
        shuffle=True,
        **shared_dataset_kwargs,
    )
    val_dataset = AidaInBatchNegativesDataset(
        name="aida_val",
        path=data_dir + "/val.jsonl",
        **shared_dataset_kwargs,
    )

    # Configure the trainer and start training.
    RetrieverTrainer(
        retriever=retriever,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        max_steps=25_000,
        wandb_offline_mode=True,
    ).train()