Spaces:
Runtime error
Runtime error
File size: 1,995 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import re
import os
import ujson
def get_parts(directory):
extension = '.pt'
parts = sorted([int(filename[: -1 * len(extension)]) for filename in os.listdir(directory)
if filename.endswith(extension)])
assert list(range(len(parts))) == parts, parts
# Integer-sortedness matters.
parts_paths = [os.path.join(directory, '{}{}'.format(filename, extension)) for filename in parts]
samples_paths = [os.path.join(directory, '{}.sample'.format(filename)) for filename in parts]
return parts, parts_paths, samples_paths
def load_doclens(directory, flatten=True):
doclens_filenames = {}
for filename in os.listdir(directory):
match = re.match("doclens.(\d+).json", filename)
if match is not None:
doclens_filenames[int(match.group(1))] = filename
doclens_filenames = [os.path.join(directory, doclens_filenames[i]) for i in sorted(doclens_filenames.keys())]
all_doclens = [ujson.load(open(filename)) for filename in doclens_filenames]
if flatten:
all_doclens = [x for sub_doclens in all_doclens for x in sub_doclens]
if len(all_doclens) == 0:
raise ValueError("Could not load doclens")
return all_doclens
def get_deltas(directory):
extension = '.residuals.pt'
parts = sorted([int(filename[: -1 * len(extension)]) for filename in os.listdir(directory)
if filename.endswith(extension)])
assert list(range(len(parts))) == parts, parts
# Integer-sortedness matters.
parts_paths = [os.path.join(directory, '{}{}'.format(filename, extension)) for filename in parts]
return parts, parts_paths
# def load_compression_data(level, path):
# with open(path, "r") as f:
# for line in f:
# line = line.split(',')
# bits = int(line[0])
# if bits == level:
# return [float(v) for v in line[1:]]
# raise ValueError(f"No data found for {level}-bit compression")
|