PyTorch
ssl-aasist
custom_code
ash56's picture
Add files using upload-large-folder tool
29c9ba5 verified
raw
history blame
1.41 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import pickle
import numpy as np
class ShardedTensor(object):
    """A list of variable-length arrays stored as one concatenated buffer.

    ``data`` holds every shard concatenated along axis 0, and ``starts`` holds
    ``len(self) + 1`` offsets into it, so shard ``i`` is
    ``data[starts[i]:starts[i + 1]]``.
    """

    def __init__(self, data, starts):
        """Wrap already-built ``data`` and ``starts`` arrays.

        Args:
            data: concatenated shard buffer, indexed along axis 0.
            starts: 1-D numpy integer array of shard offsets. It must begin
                at 0, end at ``len(data)``, be non-decreasing and
                non-negative.

        Raises:
            AssertionError: if ``starts`` violates one of the invariants above.
        """
        self.data = data
        self.starts = starts
        assert self.starts[0] == 0, "starts must begin at 0"
        assert self.starts[-1] == len(self.data), "starts must end at len(data)"
        assert (self.starts[1:] >= self.starts[:-1]).all(), "starts must be non-decreasing"
        assert (self.starts > -1).all(), "starts must be non-negative"

    @staticmethod
    def from_list(xs):
        """Build a ShardedTensor by concatenating ``xs`` along axis 0.

        Args:
            xs: non-empty sequence of arrays whose shapes agree on every axis
                except the first.

        Returns:
            ShardedTensor whose ``i``-th shard equals ``xs[i]``.

        Raises:
            ValueError: if ``xs`` is empty (there is no array to take the data
                dtype/shape from) or the arrays cannot be concatenated.
        """
        if len(xs) == 0:
            raise ValueError("ShardedTensor.from_list requires at least one array")
        data = np.concatenate(xs, axis=0)
        # np.int64 instead of np.long: np.long was removed in NumPy 1.24 and in
        # NumPy 2.x denotes the platform C long (32-bit on Windows), which could
        # overflow offsets into large buffers.
        starts = np.zeros(len(xs) + 1, dtype=np.int64)
        starts[1:] = np.cumsum([x.shape[0] for x in xs], dtype=np.int64)
        return ShardedTensor(data, starts)

    def __getitem__(self, i):
        """Return shard ``i`` as a slice of ``data`` (negative ``i`` counts from the end).

        Raises:
            IndexError: if ``i`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        if i < 0:
            # Without this, e.g. i == -1 would slice data[starts[-1]:starts[0]],
            # which is empty instead of the last shard.
            i += n
        if not 0 <= i < n:
            raise IndexError("ShardedTensor index out of range")
        return self.data[self.starts[i] : self.starts[i + 1]]

    def __len__(self):
        """Return the number of shards (one fewer than the number of offsets)."""
        return len(self.starts) - 1

    def lengths(self):
        """Return an array with the length (along axis 0) of every shard."""
        return self.starts[1:] - self.starts[:-1]

    def save(self, path):
        """Write ``starts`` and ``data`` to ``<path>_starts.npy`` and ``<path>_data.npy``.

        Args:
            path: filename prefix (``str`` or ``os.PathLike``). ``np.save``
                appends the ``.npy`` extension.
        """
        path = os.fspath(path)
        np.save(path + "_starts", self.starts)
        np.save(path + "_data", self.data)

    @staticmethod
    def load(path, mmap_mode=None):
        """Load a ShardedTensor previously written by :meth:`save`.

        Args:
            path: the same filename prefix given to :meth:`save`
                (``str`` or ``os.PathLike``).
            mmap_mode: forwarded to ``np.load`` for both files (e.g. ``"r"``
                to memory-map instead of reading into memory).

        Returns:
            ShardedTensor built from ``<path>_data.npy`` and ``<path>_starts.npy``.
        """
        path = os.fspath(path)
        starts = np.load(path + "_starts.npy", mmap_mode=mmap_mode)
        data = np.load(path + "_data.npy", mmap_mode=mmap_mode)
        return ShardedTensor(data, starts)