|
import glob |
|
import io |
|
import os |
|
import re |
|
import zipfile |
|
from abc import ABC, abstractmethod |
|
from contextlib import contextmanager |
|
from dataclasses import dataclass |
|
from typing import Dict, Iterator, List, Optional, Sequence, Tuple |
|
|
|
import numpy as np |
|
|
|
|
|
@dataclass
class NumpyArrayInfo:
    """
    Information about an array in an npz file.
    """

    # Key of the array inside the npz archive (member name without ".npy").
    name: str
    # Element dtype as stored on disk.
    dtype: np.dtype
    # Full on-disk shape; the first axis is the batch dimension.
    # NOTE: was annotated Tuple[int] (a 1-tuple) — fixed to variable-length.
    shape: Tuple[int, ...]

    @classmethod
    def infos_from_first_file(cls, glob_path: str) -> Dict[str, "NumpyArrayInfo"]:
        """
        Like infos_from_file(), but applied to the first (sorted) path that
        matches a glob pattern (optionally suffixed with "[:N]").
        """
        paths, _ = _npz_paths_and_length(glob_path)
        return cls.infos_from_file(paths[0])

    @classmethod
    def infos_from_file(cls, npz_path: str) -> Dict[str, "NumpyArrayInfo"]:
        """
        Extract the info of every array in an npz file.

        Only the .npy headers are parsed, so this is cheap even for large
        archives — the array payloads are never decompressed.

        :raises FileNotFoundError: if npz_path does not exist.
        :raises ValueError: if a member uses an unsupported npy format version.
        """
        if not os.path.exists(npz_path):
            raise FileNotFoundError(f"batch of samples was not found: {npz_path}")
        results = {}
        # An npz file is a zip archive whose members are ".npy" files.
        with open(npz_path, "rb") as f, zipfile.ZipFile(f, "r") as zip_f:
            for name in zip_f.namelist():
                if not name.endswith(".npy"):
                    continue
                key_name = name[: -len(".npy")]
                with zip_f.open(name, "r") as arr_f:
                    version = np.lib.format.read_magic(arr_f)
                    if version == (1, 0):
                        header = np.lib.format.read_array_header_1_0(arr_f)
                    elif version == (2, 0):
                        header = np.lib.format.read_array_header_2_0(arr_f)
                    else:
                        raise ValueError(f"unknown numpy array version: {version}")
                    shape, _, dtype = header
                    results[key_name] = cls(name=key_name, dtype=dtype, shape=shape)
        return results

    @property
    def elem_shape(self) -> Tuple[int, ...]:
        # Shape of a single element: the on-disk shape minus the batch axis.
        return self.shape[1:]

    def validate(self):
        """
        Check that this array looks like a supported sample batch.

        :raises ValueError: if the shape/dtype combination is unsupported.
        """
        if self.name in {"R", "G", "B"}:
            # Per-channel arrays must be exactly [N x flattened_pixels].
            if len(self.shape) != 2:
                raise ValueError(
                    f"expecting exactly 2-D shape for '{self.name}' but got: {self.shape}"
                )
        elif self.name == "arr_0":
            if len(self.shape) < 2:
                raise ValueError(f"expecting at least 2-D shape but got: {self.shape}")
            elif len(self.shape) == 3:
                # 3-D batches are treated as audio and must be floating point.
                if not np.issubdtype(self.dtype, np.floating):
                    raise ValueError(
                        f"invalid dtype for audio batch: {self.dtype} (expected float)"
                    )
            elif self.dtype != np.uint8:
                # Any other rank is treated as an image batch and must be uint8.
                raise ValueError(f"invalid dtype for image batch: {self.dtype} (expected uint8)")
|
|
|
|
|
class NpzStreamer:
    """
    Stream fixed-size batches of arrays out of one or more npz files
    matching a glob pattern (optionally suffixed with "[:N]" to cap the
    total number of elements consumed).
    """

    def __init__(self, glob_path: str):
        # Sorted matching paths plus the optional "[:N]" truncation length.
        self.paths, self.trunc_length = _npz_paths_and_length(glob_path)
        # Array metadata is read from the first file; presumably all files
        # share the same keys/dtypes — not verified here.
        self.infos = NumpyArrayInfo.infos_from_file(self.paths[0])

    def keys(self) -> List[str]:
        """Names of the arrays present in the npz files."""
        return list(self.infos.keys())

    def stream(self, batch_size: int, keys: Sequence[str]) -> Iterator[Dict[str, np.ndarray]]:
        """
        Yield dicts mapping each requested key to a batch of rows.

        Every yielded batch has exactly batch_size rows except possibly the
        last one. Batches may span file boundaries: a partial batch left at
        the end of one file is topped up from the next file before yielding.
        """
        # Partial batch carried over until it reaches batch_size rows.
        cur_batch = None
        # Remaining element budget for "[:N]" truncation; None = unlimited.
        num_remaining = self.trunc_length
        for path in self.paths:
            if num_remaining is not None and num_remaining <= 0:
                break
            with open_npz_arrays(path, keys) as readers:
                combined_reader = CombinedReader(keys, readers)
                while num_remaining is None or num_remaining > 0:
                    # Read only enough rows to top up the current batch,
                    # without exceeding the truncation budget.
                    read_bs = batch_size
                    if cur_batch is not None:
                        read_bs -= _dict_batch_size(cur_batch)
                    if num_remaining is not None:
                        read_bs = min(read_bs, num_remaining)

                    batch = combined_reader.read_batch(read_bs)
                    if batch is None:
                        # This file is exhausted; move on to the next path.
                        break
                    if num_remaining is not None:
                        num_remaining -= _dict_batch_size(batch)
                    if cur_batch is None:
                        cur_batch = batch
                    else:
                        # Append the freshly read rows to the carried batch.
                        cur_batch = {
                            k: np.concatenate([cur_batch[k], v], axis=0)
                            for k, v in batch.items()
                        }
                    if _dict_batch_size(cur_batch) == batch_size:
                        yield cur_batch
                        cur_batch = None
        if cur_batch is not None:
            # Flush the final (possibly smaller) batch.
            yield cur_batch
|
|
|
|
|
def _npz_paths_and_length(glob_path: str) -> Tuple[List[str], Optional[int]]: |
|
|
|
count_match = re.match("^(.*)\\[:([0-9]*)\\]$", glob_path) |
|
if count_match: |
|
raw_path = count_match[1] |
|
max_count = int(count_match[2]) |
|
else: |
|
raw_path = glob_path |
|
max_count = None |
|
paths = sorted(glob.glob(raw_path)) |
|
if not len(paths): |
|
raise ValueError(f"no paths found matching: {glob_path}") |
|
return paths, max_count |
|
|
|
|
|
class NpzArrayReader(ABC):
    """
    Interface for incrementally reading rows out of a single npz array.
    """

    @abstractmethod
    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        """
        Return up to batch_size rows as an array, or None once exhausted.
        """
|
|
|
|
|
class StreamingNpzArrayReader(NpzArrayReader):
    """
    Reads rows of a single npy array lazily from an open file handle that
    is positioned at the start of the raw array payload.
    """

    def __init__(self, arr_f, shape, dtype):
        self.arr_f = arr_f
        self.shape = shape
        self.dtype = dtype
        self.idx = 0

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        """Return up to batch_size rows, or None once exhausted."""
        total_rows = self.shape[0]
        if self.idx >= total_rows:
            return None

        rows = min(batch_size, total_rows - self.idx)
        self.idx += rows

        # A zero-size dtype carries no payload bytes, so materialize the
        # correctly shaped (uninitialized) array directly.
        if self.dtype.itemsize == 0:
            return np.ndarray([rows, *self.shape[1:]], dtype=self.dtype)

        elem_count = rows * np.prod(self.shape[1:])
        byte_count = int(elem_count * self.dtype.itemsize)
        raw = _read_bytes(self.arr_f, byte_count, "array data")
        return np.frombuffer(raw, dtype=self.dtype).reshape([rows, *self.shape[1:]])
|
|
|
|
|
class MemoryNpzArrayReader(NpzArrayReader):
    """
    Batch reader backed by an array that is held fully in memory.

    Used as the fallback when an array cannot be streamed row-by-row.
    """

    def __init__(self, arr):
        self.arr = arr
        self.idx = 0

    @classmethod
    def load(cls, path: str, arr_name: str):
        """Load a single named array out of an npz file into memory."""
        with open(path, "rb") as f:
            arr = np.load(f)[arr_name]
        return cls(arr)

    def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
        """Return up to batch_size rows, or None once exhausted."""
        if self.idx >= self.arr.shape[0]:
            return None

        start = self.idx
        self.idx = start + batch_size
        return self.arr[start : start + batch_size]
|
|
|
|
|
@contextmanager
def open_npz_arrays(path: str, arr_names: Sequence[str]) -> List[NpzArrayReader]:
    """
    Open a reader for each named array in an npz file, in order.

    Arrays with a recognized npy header (format 1.0 or 2.0) that are
    C-ordered and free of object dtypes are streamed lazily from the open
    zip member; anything else is loaded fully into memory instead.

    Opens are nested recursively so every underlying handle stays open for
    the lifetime of the context and is closed on exit.
    """
    if not len(arr_names):
        yield []
        return
    arr_name = arr_names[0]
    with open_array(path, arr_name) as arr_f:
        version = np.lib.format.read_magic(arr_f)
        header = None
        if version == (1, 0):
            header = np.lib.format.read_array_header_1_0(arr_f)
        elif version == (2, 0):
            header = np.lib.format.read_array_header_2_0(arr_f)

        if header is None:
            # Unrecognized npy format version: fall back to in-memory load.
            reader = MemoryNpzArrayReader.load(path, arr_name)
        else:
            shape, fortran, dtype = header
            if fortran or dtype.hasobject:
                # Fortran-ordered / object arrays can't be sliced off the
                # raw byte stream row-by-row.
                reader = MemoryNpzArrayReader.load(path, arr_name)
            else:
                # arr_f is now positioned at the raw array data, exactly
                # where StreamingNpzArrayReader expects it.
                reader = StreamingNpzArrayReader(arr_f, shape, dtype)

        # Recurse for the remaining names while this handle is still open.
        with open_npz_arrays(path, arr_names[1:]) as next_readers:
            yield [reader] + next_readers
|
|
|
|
|
class CombinedReader:
    """
    Advances several NpzArrayReaders in lockstep, combining their rows into
    dicts keyed by the corresponding array names.
    """

    def __init__(self, keys: "List[str]", readers: "List[NpzArrayReader]"):
        self.keys = keys
        self.readers = readers

    def read_batch(self, batch_size: int) -> "Optional[Dict[str, np.ndarray]]":
        """
        Read up to batch_size rows from every underlying reader.

        Returns None once all readers are exhausted.

        :raises RuntimeError: if the readers disagree about how many rows
                              remain (i.e. keys have different lengths).
        """
        results = [reader.read_batch(batch_size) for reader in self.readers]
        none_flags = [x is None for x in results]
        if any(none_flags) != all(none_flags):
            raise RuntimeError("different keys had different numbers of elements")
        if any(none_flags):
            return None
        first_len = len(results[0])
        if any(len(x) != first_len for x in results):
            raise RuntimeError("different keys had different numbers of elements")
        return dict(zip(self.keys, results))
|
|
|
|
|
def _read_bytes(fp, size, error_template="ran out of data"): |
|
""" |
|
Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886 |
|
|
|
Read from file-like object until size bytes are read. |
|
Raises ValueError if not EOF is encountered before size bytes are read. |
|
Non-blocking objects only supported if they derive from io objects. |
|
Required as e.g. ZipExtFile in python 2.6 can return less data than |
|
requested. |
|
""" |
|
data = bytes() |
|
while True: |
|
|
|
|
|
|
|
try: |
|
r = fp.read(size - len(data)) |
|
data += r |
|
if len(r) == 0 or len(data) == size: |
|
break |
|
except io.BlockingIOError: |
|
pass |
|
if len(data) != size: |
|
msg = "EOF: reading %s, expected %d bytes got %d" |
|
raise ValueError(msg % (error_template, size, len(data))) |
|
else: |
|
return data |
|
|
|
|
|
@contextmanager
def open_array(path: str, arr_name: str):
    """
    Yield an open read handle for a single ".npy" member of an npz archive.

    :raises ValueError: if the archive has no member named arr_name.
    """
    member = f"{arr_name}.npy"
    with open(path, "rb") as f, zipfile.ZipFile(f, "r") as zip_f:
        if member not in zip_f.namelist():
            raise ValueError(f"missing {arr_name} in npz file")
        with zip_f.open(member, "r") as arr_f:
            yield arr_f
|
|
|
|
|
def _dict_batch_size(objs: Dict[str, np.ndarray]) -> int: |
|
return len(next(iter(objs.values()))) |
|
|