|
|
|
|
|
|
|
|
|
|
|
"""Utility functions.""" |
|
|
|
import fnmatch |
|
import logging |
|
import os |
|
import sys |
|
|
|
import h5py |
|
import numpy as np |
|
|
|
|
|
def find_files(root_dir, query="*.wav", include_root_dir=True):
    """Find files recursively.

    Args:
        root_dir (str): Root directory to search.
        query (str): Filename pattern (``fnmatch`` style) to match.
        include_root_dir (bool): If False, the returned paths are relative
            to ``root_dir`` instead of being prefixed with it.

    Returns:
        list: List of found filenames.

    """
    files = []
    # followlinks=True so that symlinked sub-directories are walked as well.
    for root, _, filenames in os.walk(root_dir, followlinks=True):
        files.extend(
            os.path.join(root, filename)
            for filename in fnmatch.filter(filenames, query)
        )

    if not include_root_dir:
        # Strip only the leading root_dir. A plain str.replace() would also
        # remove later occurrences of the same text inside the path, and it
        # would not match at all when root_dir ends with a path separator.
        files = [os.path.relpath(file_, root_dir) for file_ in files]

    return files
|
|
|
|
|
def read_hdf5(hdf5_name, hdf5_path):
    """Read hdf5 dataset.

    Logs an error and terminates the process via ``sys.exit(1)`` when either
    the file or the requested dataset does not exist.

    Args:
        hdf5_name (str): Filename of hdf5 file.
        hdf5_path (str): Dataset name in hdf5 file.

    Returns:
        any: Dataset values.

    """
    if not os.path.exists(hdf5_name):
        logging.error(f"There is no such a hdf5 file ({hdf5_name}).")
        sys.exit(1)

    # The context manager closes the file on every exit path, including the
    # SystemExit raised below and any error raised while reading.
    with h5py.File(hdf5_name, "r") as hdf5_file:
        if hdf5_path not in hdf5_file:
            logging.error(f"There is no such a data in hdf5 file. ({hdf5_path})")
            sys.exit(1)

        # Index with () to pull the dataset contents out while the file is
        # still open.
        hdf5_data = hdf5_file[hdf5_path][()]

    return hdf5_data
|
|
|
|
|
def write_hdf5(hdf5_name, hdf5_path, write_data, is_overwrite=True):
    """Write dataset to hdf5.

    The parent directory of ``hdf5_name`` is created when missing. An
    existing file is opened for update, otherwise a new file is created.
    If the dataset already exists and ``is_overwrite`` is False, an error is
    logged and the process terminates via ``sys.exit(1)``.

    Args:
        hdf5_name (str): Hdf5 dataset filename.
        hdf5_path (str): Dataset path in hdf5.
        write_data (ndarray): Data to write.
        is_overwrite (bool): Whether to overwrite dataset.

    """
    # Convert to ndarray so that lists and scalars are accepted too.
    write_data = np.array(write_data)

    # Create the parent directory if needed. exist_ok avoids failing when
    # another process creates the same directory concurrently.
    folder_name = os.path.dirname(hdf5_name)
    if folder_name:
        os.makedirs(folder_name, exist_ok=True)

    # Update an existing file in place, otherwise start a fresh one.
    mode = "r+" if os.path.exists(hdf5_name) else "w"

    # The context manager closes the file on every exit path, including the
    # SystemExit raised below and any error raised by create_dataset().
    with h5py.File(hdf5_name, mode) as hdf5_file:
        if hdf5_path in hdf5_file:
            if not is_overwrite:
                logging.error("Dataset in hdf5 file already exists. "
                              "if you want to overwrite, please set is_overwrite = True.")
                sys.exit(1)
            logging.warning("Dataset in hdf5 file already exists. "
                            "recreate dataset in hdf5.")
            del hdf5_file[hdf5_path]

        hdf5_file.create_dataset(hdf5_path, data=write_data)
        hdf5_file.flush()
|
|
|
|
|
class HDF5ScpLoader(object):
    """Loader class for a feats.scp file of hdf5 file.

    Each non-blank line of the scp file has two whitespace-separated fields:
    a key and either ``<hdf5 file>:<dataset path>`` or just ``<hdf5 file>``.
    In the latter case ``default_hdf5_path`` is used as the dataset path.

    Examples:
        key1 /some/path/a.h5:feats
        key2 /some/path/b.h5:feats
        key3 /some/path/c.h5:feats
        key4 /some/path/d.h5:feats
        ...
        >>> loader = HDF5ScpLoader("hdf5.scp")
        >>> array = loader["key1"]

        key1 /some/path/a.h5
        key2 /some/path/b.h5
        key3 /some/path/c.h5
        key4 /some/path/d.h5
        ...
        >>> loader = HDF5ScpLoader("hdf5.scp", "feats")
        >>> array = loader["key1"]

    """

    def __init__(self, feats_scp, default_hdf5_path="feats"):
        """Initialize HDF5 scp loader.

        Args:
            feats_scp (str): Kaldi-style feats.scp file with hdf5 format.
            default_hdf5_path (str): Path in hdf5 file. If the scp contain the info, not used.

        Raises:
            ValueError: If a non-blank line does not consist of exactly two
                whitespace-separated fields.

        """
        self.default_hdf5_path = default_hdf5_path
        self.data = {}
        with open(feats_scp, encoding="utf-8") as f:
            for line_number, line in enumerate(f, 1):
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                if len(fields) != 2:
                    raise ValueError(
                        f"Invalid line {line_number} in {feats_scp}: "
                        f"expected '<key> <value>', got {line.rstrip()!r}"
                    )
                key, value = fields
                self.data[key] = value

    def get_path(self, key):
        """Get hdf5 file path for a given key."""
        return self.data[key]

    def __getitem__(self, key):
        """Get ndarray for a given key."""
        p = self.data[key]
        if ":" in p:
            # Split on the last colon only, so that a colon inside the file
            # path does not produce more than two parts.
            hdf5_name, _, hdf5_path = p.rpartition(":")
            return read_hdf5(hdf5_name, hdf5_path)
        return read_hdf5(p, self.default_hdf5_path)

    def __len__(self):
        """Return the length of the scp file."""
        return len(self.data)

    def __iter__(self):
        """Return the iterator of the scp file."""
        return iter(self.data)

    def keys(self):
        """Return the keys of the scp file."""
        return self.data.keys()
|
|