orionweller commited on
Commit
a907241
·
1 Parent(s): 5273fd3
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
- title: Scifact Prompting
3
  emoji: ⚡
4
- colorFrom: green
5
  colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 4.41.0
 
1
  ---
2
+ title: Retrieval Prompting
3
  emoji: ⚡
4
+ colorFrom: yellow
5
  colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 4.41.0
app.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pickle
3
+ import numpy as np
4
+ import glob
5
+ from tqdm import tqdm
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from transformers import AutoTokenizer, AutoModel
9
+ from peft import PeftModel
10
+ from tevatron.retriever.searcher import FaissFlatSearcher
11
+ import logging
12
+ import os
13
+ import json
14
+ import spaces
15
+ import ir_datasets
16
+ import subprocess
17
+
18
+ # Set up logging
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # Global variables
23
+ CUR_MODEL = "orionweller/repllama-instruct-hard-positives-v2-joint"
24
+ base_model = "meta-llama/Llama-2-7b-hf"
25
+ tokenizer = None
26
+ model = None
27
+ retriever = None
28
+ corpus_lookup = None
29
+ queries = None
30
+ q_lookup = None
31
+
32
+ def load_model():
33
+ global tokenizer, model
34
+ tokenizer = AutoTokenizer.from_pretrained(base_model)
35
+ tokenizer.pad_token_id = tokenizer.eos_token_id
36
+ tokenizer.pad_token = tokenizer.eos_token
37
+ tokenizer.padding_side = "right"
38
+
39
+ base_model_instance = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf")
40
+ model = PeftModel.from_pretrained(base_model_instance, CUR_MODEL)
41
+ model = model.merge_and_unload()
42
+ model.eval()
43
+ model.cuda()
44
+
45
+ def load_corpus_embeddings(dataset_name):
46
+ global retriever, corpus_lookup
47
+ corpus_path = f"{dataset_name}/corpus_emb*"
48
+ index_files = glob.glob(corpus_path)
49
+ logger.info(f'Pattern match found {len(index_files)} files; loading them into index.')
50
+
51
+ p_reps_0, p_lookup_0 = pickle_load(index_files[0])
52
+ retriever = FaissFlatSearcher(p_reps_0)
53
+
54
+ shards = [(p_reps_0, p_lookup_0)] + [pickle_load(f) for f in index_files[1:]]
55
+ corpus_lookup = []
56
+
57
+ for p_reps, p_lookup in tqdm(shards, desc='Loading shards into index', total=len(index_files)):
58
+ retriever.add(p_reps)
59
+ corpus_lookup += p_lookup
60
+
61
+ def pickle_load(path):
62
+ with open(path, 'rb') as f:
63
+ reps, lookup = pickle.load(f)
64
+ return np.array(reps), lookup
65
+
66
+ def load_queries(dataset_name):
67
+ global queries, q_lookup
68
+ dataset = ir_datasets.load(f"beir/{dataset_name.lower()}/test")
69
+
70
+ queries = []
71
+ q_lookup = {}
72
+ for query in dataset.queries_iter():
73
+ queries.append(query.text)
74
+ q_lookup[query.query_id] = query.text
75
+
76
+ def encode_queries(prefix, postfix):
77
+ global queries
78
+ input_texts = [f"{prefix}Query: {query} {postfix}".strip() for query in queries]
79
+
80
+ encoded_embeds = []
81
+ batch_size = 32 # Adjust as needed
82
+
83
+ for start_idx in range(0, len(input_texts), batch_size):
84
+ batch_input_texts = input_texts[start_idx: start_idx + batch_size]
85
+
86
+ inputs = tokenizer(batch_input_texts, padding=True, truncation=True, return_tensors="pt").to(model.device)
87
+
88
+ with torch.no_grad():
89
+ outputs = model(**inputs)
90
+ embeds = outputs.last_hidden_state[:, 0, :] # Use [CLS] token embedding
91
+ embeds = F.normalize(embeds, p=2, dim=-1)
92
+ encoded_embeds.append(embeds.cpu().numpy())
93
+
94
+ return np.concatenate(encoded_embeds, axis=0)
95
+
96
+ def search_queries(q_reps, depth=1000):
97
+ all_scores, all_indices = retriever.search(q_reps, depth)
98
+ psg_indices = [[str(corpus_lookup[x]) for x in q_dd] for q_dd in all_indices]
99
+ return all_scores, np.array(psg_indices)
100
+
101
+ def write_ranking(corpus_indices, corpus_scores, ranking_save_file):
102
+ with open(ranking_save_file, 'w') as f:
103
+ for qid, q_doc_scores, q_doc_indices in zip(q_lookup.keys(), corpus_scores, corpus_indices):
104
+ score_list = [(s, idx) for s, idx in zip(q_doc_scores, q_doc_indices)]
105
+ score_list = sorted(score_list, key=lambda x: x[0], reverse=True)
106
+ for rank, (s, idx) in enumerate(score_list, 1):
107
+ f.write(f'{qid} Q0 {idx} {rank} {s} pyserini\n')
108
+
109
+ def evaluate_with_subprocess(dataset, ranking_file):
110
+ # Convert to TREC format
111
+ trec_file = f"rank.{dataset}.trec"
112
+ convert_cmd = [
113
+ "python", "-m", "tevatron.utils.format.convert_result_to_trec",
114
+ "--input", ranking_file,
115
+ "--output", trec_file,
116
+ "--remove_query"
117
+ ]
118
+ subprocess.run(convert_cmd, check=True)
119
+
120
+ # Evaluate using trec_eval
121
+ eval_cmd = [
122
+ "python", "-m", "pyserini.eval.trec_eval",
123
+ "-c", "-mrecall.100", "-mndcg_cut.10",
124
+ f"beir-v1.0.0-{dataset}-test", trec_file
125
+ ]
126
+ result = subprocess.run(eval_cmd, capture_output=True, text=True, check=True)
127
+
128
+ # Parse the output
129
+ lines = result.stdout.strip().split('\n')
130
+ ndcg_10 = float(lines[0].split()[-1])
131
+ recall_100 = float(lines[1].split()[-1])
132
+
133
+ # Clean up temporary files
134
+ os.remove(ranking_file)
135
+ os.remove(trec_file)
136
+
137
+ return f"nDCG@10: {ndcg_10:.4f}, Recall@100: {recall_100:.4f}"
138
+
139
+ @spaces.GPU
140
+ def run_evaluation(dataset, prefix, postfix):
141
+ global queries, q_lookup
142
+
143
+ # Load corpus embeddings and queries if not already loaded
144
+ if retriever is None or queries is None:
145
+ load_corpus_embeddings(dataset)
146
+ load_queries(dataset)
147
+
148
+ # Encode queries
149
+ q_reps = encode_queries(prefix, postfix)
150
+
151
+ # Search
152
+ all_scores, psg_indices = search_queries(q_reps)
153
+
154
+ # Write ranking
155
+ ranking_file = f"temp_ranking_{dataset}.txt"
156
+ write_ranking(psg_indices, all_scores, ranking_file)
157
+
158
+ # Evaluate
159
+ results = evaluate_with_subprocess(dataset, ranking_file)
160
+
161
+ return results
162
+
163
+ def gradio_interface(dataset, prefix, postfix):
164
+ return run_evaluation(dataset, prefix, postfix)
165
+
166
+ # Load model
167
+ load_model()
168
+
169
+ # Create Gradio interface
170
+ iface = gr.Interface(
171
+ fn=gradio_interface,
172
+ inputs=[
173
+ gr.Dropdown(choices=["scifact", "arguana"], label="Dataset"),
174
+ gr.Textbox(label="Prefix prompt"),
175
+ gr.Textbox(label="Postfix prompt")
176
+ ],
177
+ outputs=gr.Textbox(label="Evaluation Results"),
178
+ title="Query Evaluation with Custom Prompts",
179
+ description="Select a dataset and enter prefix and postfix prompts to evaluate queries using Pyserini."
180
+ )
181
+
182
+ # Launch the interface
183
+ iface.launch()
arguana/corpus_emb.0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21742104cd3b5ff805fe1a7432c960b6933159c1092d49f5c4cad74922916a9b
3
+ size 35619068
arguana/corpus_emb.1.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37b86d66f10459bc2c032741f70e9f1edba68b2b3fbf9c7b1c1bb6f6139fef02
3
+ size 35619074
arguana/corpus_emb.2.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0ebf48b7b6ad18402a90da0ca92b71128b43d81a85fdb425edcaa3767f213d1
3
+ size 35602692
arguana/corpus_emb.3.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82bb37e7f6d7e0728da781aaccaad49830708cd44abc455f8c3a81db1e4b4b0f
3
+ size 35602679
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ default-jre
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio==4.39.0
2
+ pyserini==0.23.0
3
+ faiss-cpu==1.7.4
4
+ torch==2.1.0
5
+ ir_datasets
6
+ peft==0.12.0
7
+ ir_datasets==0.5.8
8
+ tevatron @ git+https://github.com/texttron/tevatron@7d298b4
scifact/corpus_emb.0.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0bb98e68350983519732b0b39e8f98ec0225abd2c68775e7317da9b17f0db1dd
3
+ size 21247618
scifact/corpus_emb.1.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3dd3501342754aeb2ffb895480868e0976895bded3e5accbd8e5b6fa404e5484
3
+ size 21247619
scifact/corpus_emb.2.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e1a98c698cbe367bc1abc789da76794a8e79e92743059b26faafbd34808aa15
3
+ size 21247619
scifact/corpus_emb.3.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:911c8d6654bfb14a3d68363c96a70462348cfbbf35a591e020877ed28591339c
3
+ size 21231225