File size: 1,098 Bytes
828992f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import os
import torch
import ujson
from math import ceil
from itertools import accumulate
from colbert.utils.utils import print_message
def get_parts(directory):
    """Discover the integer-numbered ``.pt`` part files in ``directory``.

    Every entry of ``directory`` whose name ends in ``.pt`` is interpreted as
    ``<int>.pt``; the integer stems are sorted numerically and must form the
    contiguous sequence ``0, 1, ..., n - 1``.

    Args:
        directory: Path of the directory to scan.

    Returns:
        A 3-tuple ``(parts, parts_paths, samples_paths)`` where ``parts`` is
        the sorted list of integer part ids, ``parts_paths`` holds the
        ``<directory>/<id>.pt`` path for each id, and ``samples_paths`` holds
        the ``<directory>/<id>.sample`` path for each id (the ``.sample``
        files are not checked for existence here).

    Raises:
        AssertionError: If the part ids are not exactly ``0..n-1``; the
            offending list of ids is attached as the exception argument.
        ValueError: If a ``.pt`` file name has a non-integer stem.
    """
    extension = '.pt'

    parts = sorted(
        int(filename[:-len(extension)])
        for filename in os.listdir(directory)
        if filename.endswith(extension)
    )

    # Raised explicitly (rather than via `assert`) so the contiguity check
    # still runs under `python -O`; the exception type is unchanged.
    if list(range(len(parts))) != parts:
        raise AssertionError(parts)

    # Integer-sortedness matters.
    parts_paths = [os.path.join(directory, '{}{}'.format(part, extension)) for part in parts]
    samples_paths = [os.path.join(directory, '{}.sample'.format(part)) for part in parts]

    return parts, parts_paths, samples_paths
def load_doclens(directory, flatten=True):
    """Load the ``doclens.<part>.json`` files that accompany each index part.

    The part ids come from ``get_parts(directory)``; for every id the file
    ``<directory>/doclens.<id>.json`` is parsed with ``ujson``.

    Args:
        directory: Directory containing the ``.pt`` parts and the matching
            ``doclens.<id>.json`` files.
        flatten: When true (the default), concatenate the per-part contents
            into a single flat list in part order. When false, return one
            parsed object per part.

    Returns:
        A flat list of all entries when ``flatten`` is true, otherwise a list
        with one parsed JSON value per part.
        NOTE(review): entries are presumably per-document lengths -- confirm
        against the code that writes these files.
    """
    parts, _, _ = get_parts(directory)

    all_doclens = []
    for part in parts:
        doclens_path = os.path.join(directory, 'doclens.{}.json'.format(part))
        # Use a context manager so each handle is closed as soon as it has
        # been parsed instead of lingering until garbage collection.
        with open(doclens_path) as doclens_file:
            all_doclens.append(ujson.load(doclens_file))

    if flatten:
        all_doclens = [x for sub_doclens in all_doclens for x in sub_doclens]

    return all_doclens
|