Spaces:
Runtime error
Runtime error
File size: 4,393 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
from colbert.infra.run import Run
import os
import ujson
from colbert.evaluation.loaders import load_queries
# TODO: Look up the path in some global (per-thread or otherwise thread-safe) list.
# TODO: `path` could be a list of paths, but then how would we tell it apart from a list of queries?
class Queries:
    """In-memory collection of queries keyed by query id.

    ``self.data`` maps each qid to its query text.  When the source provides
    full QA records (dict values in ``data``, or a ``.json`` file with one
    JSON object per line), the complete records are additionally kept in
    ``self._qas``; that attribute does not exist otherwise.
    """

    def __init__(self, path=None, data=None):
        """Build the collection from ``data`` if given, otherwise from ``path``.

        Args:
            path: File to load when ``data`` is None; also stored as the
                provenance of the collection.
            data: Optional dict mapping qid to either the query text or a
                dict holding at least a ``'question'`` key.
        """
        self.path = path

        # Test against None (not truthiness) so that empty non-dict inputs
        # such as [] are rejected here instead of failing later on .items().
        if data is not None:
            assert isinstance(data, dict), type(data)

        if not self._load_data(data):
            self._load_file(path)

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        """Iterate over ``(qid, query)`` pairs."""
        return iter(self.data.items())

    def provenance(self):
        """Return the path this collection was created with (may be None)."""
        return self.path

    def toDict(self):
        return {'provenance': self.provenance()}

    def _load_data(self, data):
        """Populate the collection from an in-memory dict.

        Returns None (and loads nothing) when ``data`` is None, True otherwise.
        """
        if data is None:
            return None

        self.data = {}
        self._qas = {}

        for qid, content in data.items():
            if isinstance(content, dict):
                # Full QA record: keep it and expose only the question text.
                self.data[qid] = content['question']
                self._qas[qid] = content
            else:
                self.data[qid] = content

        # The attribute is only kept when at least one QA record was seen.
        if not self._qas:
            del self._qas

        return True

    def _load_file(self, path):
        """Populate the collection from ``path``.

        A path ending in ``.json`` is read as one JSON object per line, each
        with ``'qid'`` and ``'question'`` keys; any other path is handed to
        ``load_queries``.  Returns True for the ``load_queries`` branch and
        the loaded qid -> question dict for the JSON branch.
        """
        if not path.endswith('.json'):
            self.data = load_queries(path)
            return True

        # Load QAs
        self.data = {}
        self._qas = {}

        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(path, encoding='utf-8') as f:
            for line in f:
                qa = ujson.loads(line)

                assert qa['qid'] not in self.data
                self.data[qa['qid']] = qa['question']
                self._qas[qa['qid']] = qa

        return self.data

    def qas(self):
        """Return a shallow copy of the QA records (empty if none were loaded)."""
        # _qas is absent when the source had no QA records; report that as an
        # empty mapping rather than leaking an AttributeError to the caller.
        return dict(getattr(self, '_qas', {}))

    def __getitem__(self, key):
        return self.data[key]

    def keys(self):
        return self.data.keys()

    def values(self):
        return self.data.values()

    def items(self):
        return self.data.items()

    def save(self, new_path):
        """Write the queries as ``qid<TAB>query`` lines to a new ``.tsv`` file.

        The file is opened through ``Run().open``; the ``name`` of the
        resulting file object is returned.  ``new_path`` must end in ``.tsv``
        and must not exist yet.
        """
        assert new_path.endswith('.tsv')
        assert not os.path.exists(new_path), new_path

        with Run().open(new_path, 'w') as f:
            for qid, content in self.data.items():
                f.write(f'{qid}\t{content}\n')

            return f.name

    def save_qas(self, new_path):
        """Write the QA records, one JSON object per line, to a new ``.json`` file.

        Each written record carries its qid under the ``'qid'`` key.
        ``new_path`` must end in ``.json`` and must not exist yet.
        """
        assert new_path.endswith('.json')
        assert not os.path.exists(new_path), new_path

        with open(new_path, 'w', encoding='utf-8') as f:
            for qid, qa in self._qas.items():
                # Serialise a copy so the stored (possibly caller-owned)
                # record is not modified as a side effect of saving.
                record = {**qa, 'qid': qid}
                f.write(ujson.dumps(record) + '\n')

    def _load_tsv(self, path):
        raise NotImplementedError

    def _load_jsonl(self, path):
        raise NotImplementedError

    @classmethod
    def cast(cls, obj):
        """Coerce ``obj`` into an instance of this class.

        A string is treated as a path, a dict (or list) is passed on as
        ``data``, and an existing instance is returned unchanged.  Anything
        else raises AssertionError.
        """
        if isinstance(obj, str):
            return cls(path=obj)

        if isinstance(obj, (dict, list)):
            return cls(data=obj)

        # isinstance (rather than an exact type match) also accepts subclasses.
        if isinstance(obj, cls):
            return obj

        # Raised explicitly: a bare `assert False` is stripped under `python -O`
        # and would let this method silently return None.
        raise AssertionError(f"obj has type {type(obj)} which is not compatible with cast()")
# class QuerySet:
# def __init__(self, *paths, renumber=False):
# self.paths = paths
# self.original_queries = [load_queries(path) for path in paths]
# if renumber:
# self.queries = flatten([q.values() for q in self.original_queries])
# self.queries = {idx: text for idx, text in enumerate(self.queries)}
# else:
# self.queries = {}
# for queries in self.original_queries:
# assert len(set.intersection(set(queries.keys()), set(self.queries.keys()))) == 0, \
# "renumber=False requires non-overlapping query IDs"
# self.queries.update(queries)
# assert len(self.queries) == sum(map(len, self.original_queries))
# def todict(self):
# return dict(self.queries)
# def tolist(self):
# return list(self.queries.values())
# def query_sets(self):
# return self.original_queries
# def split_rankings(self, rankings):
# assert type(rankings) is list
# assert len(rankings) == len(self.queries)
# sub_rankings = []
# offset = 0
# for source in self.original_queries:
# sub_rankings.append(rankings[offset:offset+len(source)])
# offset += len(source)
# return sub_rankings
|