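"""Dataset utilities: serve fixed-length feature windows drawn from
memory-mapped .npy blocks paired with per-example .len offset files."""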
import glob
import logging
import os
from typing import List

import numpy as np
from torch.utils.data import Dataset

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
)
logger = logging.getLogger("dataset")


class ReprDataset(Dataset):
    """Serves fixed-length windows sampled from variable-length examples
    stored in memory-mapped feature blocks."""

    def __init__(
        self,
        data_dir: str,
        batch_len: int,
    ):
        self.batch_len = batch_len

        self.blocks = self._load_blocks(data_dir)
        self.offsets = self._load_offsets(data_dir)
        assert len(self.blocks) == len(self.offsets), "Every .npy block needs a matching .len file"

        # The final cumulative offset of each block must equal the number of
        # feature rows actually stored in that block.
        for i in range(len(self.blocks)):
            assert self.blocks[i].shape[0] == self.offsets[i][-1]

        # Cumulative example counts per block, e.g. [0, n_0, n_0 + n_1, ...],
        # used to map a global example index to a (block, local offset) pair.
        self.n_examples = np.cumsum([0] + [offset.shape[0] - 1 for offset in self.offsets])

    def __len__(self):
        return self.n_examples[-1]

    def __getitem__(self, idx):
        # Locate the block that contains this global example index.
        block_id = -1
        for n in range(len(self.n_examples) - 1):
            if self.n_examples[n] <= idx < self.n_examples[n + 1]:
                block_id = n
                break
        assert 0 <= block_id < len(self.blocks), f"Failed to find {idx}"

        # Translate the global index into a row range within the block.
        block_offset = idx - self.n_examples[block_id]
        start = self.offsets[block_id][block_offset]
        end = self.offsets[block_id][block_offset + 1]

        if end - start < self.batch_len:
            # Example too short to yield a full window; the caller must
            # filter these out, e.g. with a custom collate_fn.
            return None
        elif end - start == self.batch_len:
            return self.blocks[block_id][start:end]
        else:
            # randint's high bound is exclusive, so +1 keeps the window that
            # ends exactly at `end` reachable.
            start_offset = np.random.randint(low=start, high=end - self.batch_len + 1)
            return self.blocks[block_id][start_offset:start_offset + self.batch_len]

    @staticmethod
    def _load_blocks(feat_dir: str) -> List[np.ndarray]:
        file_names = glob.glob(os.path.join(feat_dir, "*.npy"), recursive=False)
        # Block files are expected to be named "<index>_*.npy"; sort numerically
        # by index so blocks line up with their offset files.
        file_names = sorted(file_names, key=lambda x: int(os.path.basename(x).split("_")[0]))
        logger.info(f"Found the following blocks: {file_names}")
        # Memory-map the arrays so data is paged in lazily on access.
        blocks = [np.load(name, mmap_mode="r") for name in file_names]
        return blocks

    @staticmethod
    def _load_offsets(feat_dir: str) -> List[np.ndarray]:
        def load_lens(file_name: str) -> np.ndarray:
            # Each .len file lists one example length per line; prepend 0 and
            # take the cumulative sum to get row offsets into the block.
            with open(file_name, mode="r") as fp:
                res = fp.read().strip().split("\n")
            res = [0] + [int(r) for r in res]
            return np.cumsum(res, dtype=int)

        file_names = glob.glob(os.path.join(feat_dir, "*.len"), recursive=False)
        file_names = sorted(file_names, key=lambda x: int(os.path.basename(x).split("_")[0]))
        return [load_lens(name) for name in file_names]
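

# Minimal usage sketch (assumptions: a hypothetical "features" directory with
# index-prefixed "<i>_*.npy" blocks and matching "<i>_*.len" files; batch_len
# of 128 and the collate_fn below are illustrative, not part of this module).
# Because __getitem__ returns None for examples shorter than batch_len, the
# default collate would fail, so short examples are filtered before stacking.
if __name__ == "__main__":
    import torch
    from torch.utils.data import DataLoader

    def drop_short_collate(batch):
        # Copy each memory-mapped window into a writable array, then stack
        # into a (kept, batch_len, feat_dim) tensor; return None if all of
        # the sampled examples were too short.
        kept = [torch.from_numpy(np.array(b)) for b in batch if b is not None]
        return torch.stack(kept) if kept else None

    dataset = ReprDataset(data_dir="features", batch_len=128)
    loader = DataLoader(dataset, batch_size=8, collate_fn=drop_short_collate)
    for batch in loader:
        if batch is not None:
            logger.info(f"batch shape: {tuple(batch.shape)}")
        break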