#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
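# Example invocation (a sketch: the flag names below come from this script's argument parser,
# but the script path and all file/directory names are illustrative placeholders):
#
#   python ltr_inference_script.py \
#       --input run.bm25.dev.trec --input-format trec --rerank \
#       --model ltr_model_dir --ibm-model ibm_model_dir \
#       --index msmarco-passage-ltr --topic queries.dev.tsv \
#       --qrel qrels.dev.tsv --output run.ltr.dev.tsv
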
import sys

# We're going to explicitly use a local installation of Pyserini (as opposed to a pip-installed one).
# Comment these lines out to use a pip-installed one instead.
sys.path.insert(0, './')

import argparse

import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
from transformers import AutoTokenizer

from pyserini.search.lucene.ltr._search_msmarco import MsmarcoLtrSearcher
# The wildcard import supplies SpacyTextParser (used in query_loader below).
from pyserini.search.lucene.ltr import *
from pyserini.search.lucene import LuceneSearcher
from pyserini.analysis import Analyzer, get_lucene_analyzer

"""
Run learning-to-rank (LTR) prediction on first-stage candidates and write out the reranked run.
"""


def dev_data_loader(file, format, topic, rerank, prebuilt, qrel, granularity, top=1000):
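    """Load (or retrieve) the candidates to rerank and attach relevance labels.

    With `rerank`, candidates are read from `file` in tsv (qid, pid, rank) or trec run
    format and truncated to the top `top` ranks; otherwise they are retrieved with BM25
    (k1=0.82, b=0.68) from `args.index` for every query in the module-level `queries`
    dict. Qrels from `qrel` are left-joined in and a first-stage recall curve is printed.

    Returns a tuple (dev, dev_qrel): candidates indexed by (qid, pid) with a 'rel'
    column, and the qrels DataFrame.
    """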
    if rerank:
        if format == 'tsv':
            dev = pd.read_csv(file, sep="\t",
                              names=['qid', 'pid', 'rank'],
                              dtype={'qid': 'S', 'pid': 'S', 'rank': 'i'})
        elif format == 'trec':
            dev = pd.read_csv(file, sep=r"\s+",
                              names=['qid', 'q0', 'pid', 'rank', 'score', 'tag'],
                              usecols=['qid', 'pid', 'rank'],
                              dtype={'qid': 'S', 'pid': 'S', 'rank': 'i'})
        else:
            raise ValueError(f'unknown input format: {format}')
        assert dev['qid'].dtype == object
        assert dev['pid'].dtype == object
        assert dev['rank'].dtype == np.int32
        dev = dev[dev['rank'] <= top]
    else:
        if prebuilt:
            bm25search = LuceneSearcher.from_prebuilt_index(args.index)
        else:
            bm25search = LuceneSearcher(args.index)
        bm25search.set_bm25(0.82, 0.68)
        dev_dic = {"qid": [], "pid": [], "rank": []}
        # Retrieval relies on the module-level `queries` dict and `args` (both set in __main__).
        for query_id in tqdm(queries.keys()):
            query_text = queries[query_id]['raw']
            bm25_dev = bm25search.search(query_text, args.hits)
            doc_ids = [bm25_result.docid for bm25_result in bm25_dev]
            dev_dic['qid'].extend([query_id] * len(doc_ids))
            dev_dic['pid'].extend(doc_ids)
            dev_dic['rank'].extend(range(1, len(doc_ids) + 1))
        dev = pd.DataFrame(dev_dic)
        dev['rank'] = dev['rank'].astype(np.int32)
    if granularity == 'document':
        separator = "\t"
    else:
        separator = " "
    dev_qrel = pd.read_csv(qrel, sep=separator,
                           names=["qid", "q0", "pid", "rel"], usecols=['qid', 'pid', 'rel'],
                           dtype={'qid': 'S', 'pid': 'S', 'rel': 'i'})
    dev = dev.merge(dev_qrel, left_on=['qid', 'pid'], right_on=['qid', 'pid'], how='left')
    dev['rel'] = dev['rel'].fillna(0).astype(np.int32)
    dev = dev.sort_values(['qid', 'pid']).set_index(['qid', 'pid'])

    print(dev.shape)
    print(dev.index.get_level_values('qid').drop_duplicates().shape)
    print(dev.groupby('qid').count().mean())
    print(dev.head(10))
    print(dev.info())
    # Recall curve of the first-stage candidates (before LTR reranking).
    dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel']
    recall_point = [10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000]
    recall_curve = {k: [] for k in recall_point}
    for qid, group in tqdm(dev.groupby('qid')):
        group = group.reset_index()
        assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
        total_rel = dev_rel_num.loc[qid]
        query_recall = [0 for _ in recall_point]
        for t in group.sort_values('rank').itertuples():
            if t.rel > 0:
                for i, p in enumerate(recall_point):
                    if t.rank <= p:
                        query_recall[i] += 1
        for i, p in enumerate(recall_point):
            if total_rel > 0:
                recall_curve[p].append(query_recall[i] / total_rel)
            else:
                recall_curve[p].append(0.)
    for k, v in recall_curve.items():
        avg = np.mean(v)
        print(f'recall@{k}:{avg}')

    return dev, dev_qrel


def query_loader(topic):
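    """Parse the topics file into the query representations used for LTR features.

    Each line of `topic` is expected to be 'qid<TAB>query text'. For every query this
    builds the raw text, spaCy lemmas, unlemmatized tokens, Lucene-analyzed tokens,
    and bert-base-uncased tokens, and returns a dict keyed by query id.
    """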
    queries = {}
    nlp = SpacyTextParser('en_core_web_sm', keep_only_alpha_num=True, lower_case=True)
    analyzer = Analyzer(get_lucene_analyzer())
    bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    inp_file = open(topic)
    ln = 0
    for line in tqdm(inp_file):
        ln += 1
        line = line.strip()
        if not line:
            continue
        fields = line.split('\t')
        if len(fields) != 2:
            print('Misformatted line %d, ignoring:' % ln)
            print(line.replace('\t', '<field delimiter>'))
            continue
        did, query = fields
        query_lemmas, query_unlemm = nlp.proc_text(query)
        analyzed = analyzer.analyze(query)
        # Sanity check: Lucene-analyzed tokens should not contain whitespace.
        for token in analyzed:
            if ' ' in token:
                print(analyzed)
        query_toks = query_lemmas.split()
        # Skip queries that are empty after lemmatization.
        if len(query_toks) > 0:
            query = {"raw": query,
                     "text": query_lemmas.split(' '),
                     "text_unlemm": query_unlemm.split(' '),
                     "analyzed": analyzed,
                     "text_bert_tok": bert_tokenizer.tokenize(query.lower())}
            queries[did] = query
        if ln % 10000 == 0:
            print('Processed %d queries' % ln)
    inp_file.close()
    print('Processed %d queries' % ln)
    return queries


def eval_mrr(dev_data):
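    """Compute MRR@10 over the reranked scores and report score ties.

    Documents are sorted by descending 'score' with a stable mergesort (matching
    LightGBM's tie handling); the reciprocal rank of the first relevant document
    within the top 10 is averaged over all queries.
    """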
    score_tie_counter = 0
    score_tie_query = set()
    MRR = []
    for qid, group in tqdm(dev_data.groupby('qid')):
        group = group.reset_index()
        rank = 0
        prev_score = None
        assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
        # Stable sort is also used in LightGBM.
        for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples():
            if prev_score is not None and abs(t.score - prev_score) < 1e-8:
                score_tie_counter += 1
                score_tie_query.add(qid)
            prev_score = t.score
            rank += 1
            if t.rel > 0:
                MRR.append(1.0 / rank)
                break
            elif rank == 10 or rank == len(group):
                MRR.append(0.)
                break
    score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
    print(score_tie)
    mrr_10 = np.mean(MRR).item()
    print(f'MRR@10: {mrr_10} with {len(MRR)} queries')
    return {'score_tie': score_tie, 'mrr_10': mrr_10}


def eval_recall(dev_qrel, dev_data):
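    """Compute recall at several cutoffs over the reranked scores.

    For each query, counts how many judged-relevant documents appear at or above each
    cutoff after sorting by descending 'score' (stable mergesort), normalized by the
    number of relevant documents for that query in the qrels.
    """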
    dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel']
    score_tie_counter = 0
    score_tie_query = set()
    recall_point = [10, 20, 50, 100, 200, 250, 300, 333, 400, 500, 1000]
    recall_curve = {k: [] for k in recall_point}
    for qid, group in tqdm(dev_data.groupby('qid')):
        group = group.reset_index()
        rank = 0
        prev_score = None
        assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
        # Stable sort is also used in LightGBM.
        total_rel = dev_rel_num.loc[qid]
        query_recall = [0 for _ in recall_point]
        for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples():
            if prev_score is not None and abs(t.score - prev_score) < 1e-8:
                score_tie_counter += 1
                score_tie_query.add(qid)
            prev_score = t.score
            rank += 1
            if t.rel > 0:
                for i, p in enumerate(recall_point):
                    if rank <= p:
                        query_recall[i] += 1
        for i, p in enumerate(recall_point):
            if total_rel > 0:
                recall_curve[p].append(query_recall[i] / total_rel)
            else:
                recall_curve[p].append(0.)
    score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
    print(score_tie)
    res = {'score_tie': score_tie}
    for k, v in recall_curve.items():
        avg = np.mean(v)
        print(f'recall@{k}:{avg}')
        res[f'recall@{k}'] = avg
    return res


def output(file, dev_data, format, maxp):
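    """Write the reranked run to `file` in TREC or MS MARCO tsv format.

    With `maxp`, passage ids of the form 'docid#passage' are collapsed to their
    document id, keeping the highest passage score per document (MaxP).
    """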
    score_tie_counter = 0
    score_tie_query = set()
    output_file = open(file, 'w')
    results = defaultdict(dict)
    for qid, group in tqdm(dev_data.groupby('qid')):
        group = group.reset_index()
        prev_score = None
        assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
        # Stable sort is also used in LightGBM.
        for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples():
            if prev_score is not None and abs(t.score - prev_score) < 1e-8:
                score_tie_counter += 1
                score_tie_query.add(qid)
            prev_score = t.score
            if maxp:
                # MaxP: collapse passage ids 'docid#passage' to the document id, keeping the best passage score.
                docid = t.pid.split('#')[0]
                if qid not in results or docid not in results[qid] or t.score > results[qid][docid]:
                    results[qid][docid] = t.score
            else:
                results[qid][t.pid] = t.score
    for qid in tqdm(results.keys()):
        rank = 1
        docid_score = results[qid]
        docid_score = sorted(docid_score.items(), key=lambda kv: kv[1], reverse=True)
        for docid, score in docid_score:
            if format == 'trec':
                output_file.write(f"{qid}\tQ0\t{docid}\t{rank}\t{score}\tltr\n")
            else:
                output_file.write(f"{qid}\t{docid}\t{rank}\n")
            rank += 1
    output_file.close()
    score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
    print(score_tie)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Learning to rank reranking')
    parser.add_argument('--input', default='')
    parser.add_argument('--hits', type=int, default=1000)
    parser.add_argument('--input-format', default='trec')
    parser.add_argument('--model', required=True)
    parser.add_argument('--index', required=True)
    parser.add_argument('--output', required=True)
    parser.add_argument('--ibm-model', required=True)
    parser.add_argument('--topic', required=True)
    parser.add_argument('--output-format', default='tsv')
    parser.add_argument('--max-passage', action='store_true')
    parser.add_argument('--rerank', action='store_true')
    parser.add_argument('--qrel', required=True)
    parser.add_argument('--granularity', default='passage')
    args = parser.parse_args()

    queries = query_loader(args.topic)

    print("---------------------loading dev----------------------------------------")
    # Use a Pyserini prebuilt index when one of the known LTR index names is given.
    prebuilt = args.index in ('msmarco-passage-ltr', 'msmarco-doc-per-passage-ltr')
    dev, dev_qrel = dev_data_loader(args.input, args.input_format, args.topic, args.rerank,
                                    prebuilt, args.qrel, args.granularity, args.hits)

    searcher = MsmarcoLtrSearcher(args.model, args.ibm_model, args.index, args.granularity, prebuilt, args.topic)
    searcher.add_fe()
    batch_info = searcher.search(dev, queries)
    del dev, queries

    eval_res = eval_mrr(batch_info)
    eval_recall(dev_qrel, batch_info)
    output(args.output, batch_info, args.output_format, args.max_passage)
    print('Done!')