|
import os
|
|
import ujson
|
|
|
|
from functools import partial
|
|
from colbert.utils.utils import print_message
|
|
from colbert.modeling.tokenization import QueryTokenizer, DocTokenizer, tensorize_triples
|
|
|
|
from colbert.utils.runs import Run
|
|
|
|
|
|
class LazyBatcher:
    """
    Iterator over training triples that tensorizes one batch at a time.

    Loads (qid, pos_pid, neg_pid) triples (sharded across ranks), a qid -> query
    map, and a pid-indexed passage collection up front, then lazily yields
    tensorized batches of exactly `bsize` examples via `__next__`.
    """

    def __init__(self, args, rank=0, nranks=1):
        # Batch size and gradient-accumulation steps come from the training args.
        self.bsize, self.accumsteps = args.bsize, args.accumsteps

        self.query_tokenizer = QueryTokenizer(args.query_maxlen)
        self.doc_tokenizer = DocTokenizer(args.doc_maxlen)

        # Bind the tokenizers once so collate() only has to supply raw strings.
        self.tensorize_triples = partial(tensorize_triples, self.query_tokenizer, self.doc_tokenizer)
        self.position = 0

        self.triples = self._load_triples(args.triples, rank, nranks)
        self.queries = self._load_queries(args.queries)
        self.collection = self._load_collection(args.collection)

    def _load_triples(self, path, rank, nranks):
        """
        Load this rank's share of the JSON-lines triples file: line i belongs
        to rank (i % nranks). Each line is a [qid, pos_pid, neg_pid] triple.

        NOTE: For distributed sampling, this isn't equivalent to perfectly uniform sampling.
        In particular, each subset is perfectly represented in every batch! However, since we never
        repeat passes over the data, we never repeat any particular triple, and the split across
        nodes is random (since the underlying file is pre-shuffled), there's no concern here.
        """
        print_message("#> Loading triples...")

        triples = []

        with open(path) as f:
            for line_idx, line in enumerate(f):
                if line_idx % nranks == rank:
                    qid, pos, neg = ujson.loads(line)
                    triples.append((qid, pos, neg))

        return triples

    def _load_queries(self, path):
        """Load a TSV of `qid <tab> query` lines into a {int qid: query} dict."""
        print_message("#> Loading queries...")

        queries = {}

        with open(path) as f:
            for line in f:
                qid, query = line.strip().split('\t')
                queries[int(qid)] = query

        return queries

    def _load_collection(self, path):
        """
        Load a TSV collection of `pid <tab> passage <tab> title [<tab> ...]`
        lines into a list where collection[pid] == 'title | passage'.
        """
        print_message("#> Loading collection...")

        collection = []

        with open(path) as f:
            for line_idx, line in enumerate(f):
                pid, passage, title, *_ = line.strip().split('\t')

                # Positional indexing in __next__ requires pid == line number
                # ('id' tolerates a header row). Raise explicitly rather than
                # `assert` so the check survives `python -O`.
                if pid != 'id' and int(pid) != line_idx:
                    raise ValueError(f"Collection misordered: pid {pid} at line {line_idx} of {path}")

                passage = title + ' | ' + passage
                collection.append(passage)

        return collection

    def __iter__(self):
        """The batcher is its own iterator."""
        return self

    def __len__(self):
        """Number of triples assigned to this rank (not the number of batches)."""
        return len(self.triples)

    def __next__(self):
        """
        Return the next tensorized batch of exactly `bsize` triples.

        Raises StopIteration once fewer than `bsize` triples remain; a trailing
        partial batch is dropped, since collate() requires exactly `bsize`
        examples. `self.position` is advanced before the exhaustion check,
        mirroring the original semantics.
        """
        offset, endpos = self.position, min(self.position + self.bsize, len(self.triples))
        self.position = endpos

        if offset + self.bsize > len(self.triples):
            raise StopIteration

        queries, positives, negatives = [], [], []

        for position in range(offset, endpos):
            query, pos, neg = self.triples[position]
            # Resolve ids to raw text: qid via the query map, pids by position.
            query, pos, neg = self.queries[query], self.collection[pos], self.collection[neg]

            queries.append(query)
            positives.append(pos)
            negatives.append(neg)

        return self.collate(queries, positives, negatives)

    def collate(self, queries, positives, negatives):
        """Tensorize one full batch, split into `bsize // accumsteps` sub-batches."""
        # Internal invariant (callers are __next__ only), so a plain assert is fine.
        assert len(queries) == len(positives) == len(negatives) == self.bsize

        return self.tensorize_triples(queries, positives, negatives, self.bsize // self.accumsteps)

    def skip_to_batch(self, batch_idx, intended_batch_size):
        """Fast-forward past the first `batch_idx` batches (e.g., when resuming training)."""
        Run.warn(f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.')
        self.position = intended_batch_size * batch_idx
|
|
|