|
from collections import OrderedDict |
|
import copy |
|
import io |
|
import logging |
|
import sys |
|
|
|
import yaml |
|
|
|
from espnet.utils.dynamic_import import dynamic_import |
|
|
|
|
|
# True when the interpreter's major version is 2; selects the import sources
# for `Sequence` and `signature` below.
PY2 = sys.version_info[0] == 2

if PY2:
    from collections import Sequence
    # `inspect.signature` is unavailable on Python 2; `funcsigs` supplies it.
    from funcsigs import signature
else:
    # On Python 3 the ABCs are imported from `collections.abc`.
    from collections.abc import Sequence
    from inspect import signature
|
|
|
|
|
|
|
|
|
# Short names usable as the ``"type"`` of a process entry.  Each value is a
# ``"module.path:ClassName"`` string; the table is handed to
# ``dynamic_import`` together with the requested type.
import_alias = dict(
    identity="espnet.transform.transform_interface:Identity",
    time_warp="espnet.transform.spec_augment:TimeWarp",
    time_mask="espnet.transform.spec_augment:TimeMask",
    freq_mask="espnet.transform.spec_augment:FreqMask",
    spec_augment="espnet.transform.spec_augment:SpecAugment",
    speed_perturbation="espnet.transform.perturb:SpeedPerturbation",
    volume_perturbation="espnet.transform.perturb:VolumePerturbation",
    noise_injection="espnet.transform.perturb:NoiseInjection",
    bandpass_perturbation="espnet.transform.perturb:BandpassPerturbation",
    rir_convolve="espnet.transform.perturb:RIRConvolve",
    delta="espnet.transform.add_deltas:AddDeltas",
    cmvn="espnet.transform.cmvn:CMVN",
    utterance_cmvn="espnet.transform.cmvn:UtteranceCMVN",
    fbank="espnet.transform.spectrogram:LogMelSpectrogram",
    spectrogram="espnet.transform.spectrogram:Spectrogram",
    stft="espnet.transform.spectrogram:Stft",
    istft="espnet.transform.spectrogram:IStft",
    stft2fbank="espnet.transform.spectrogram:Stft2LogMelSpectrogram",
    wpe="espnet.transform.wpe:WPE",
    channel_selector="espnet.transform.channel_selector:ChannelSelector",
)
|
|
|
|
|
class Transformation(object):
    """Apply some functions to the mini-batch

    The configuration is a dict (or a YAML file that loads to a dict).  In
    the only supported mode, ``"sequential"``, each entry of
    ``conf["process"]`` is a dict whose ``"type"`` is resolved with
    ``dynamic_import(type, import_alias)`` and whose remaining items are
    passed to the resolved class as keyword arguments.  The instances are
    stored in ``self.functions`` keyed by their position and are applied in
    that order by ``__call__``.

    Examples:
        >>> import numpy as np
        >>> kwargs = {"process": [{"type": "fbank",
        ...                        "n_mels": 80,
        ...                        "fs": 16000},
        ...                       {"type": "cmvn",
        ...                        "stats": "data/train/cmvn.ark",
        ...                        "norm_vars": True},
        ...                       {"type": "delta", "window": 2, "order": 2}]}
        >>> transform = Transformation(kwargs)
        >>> bs = 10
        >>> xs = [np.random.randn(100, 80).astype(np.float32)
        ...       for _ in range(bs)]
        >>> xs = transform(xs)
    """

    def __init__(self, conffile=None):
        """Build the chain of functions described by the configuration

        :param Union[dict, str, None] conffile: a configuration dict
            (deep-copied), the path of a UTF-8 YAML file, or None for an
            empty sequential chain
        :raises NotImplementedError: if the configured mode is not
            "sequential"
        :raises TypeError: re-raised when a class rejects its options; the
            expected signature is logged first when it can be obtained
        """
        if conffile is None:
            self.conf = {"mode": "sequential", "process": []}
        elif isinstance(conffile, dict):
            # Work on a private copy so the caller's dict is never shared.
            self.conf = copy.deepcopy(conffile)
        else:
            with io.open(conffile, encoding="utf-8") as f:
                self.conf = yaml.safe_load(f)
            assert isinstance(self.conf, dict), type(self.conf)

        self.functions = OrderedDict()
        mode = self.conf.get("mode", "sequential")
        if mode != "sequential":
            raise NotImplementedError("Not supporting mode={}".format(mode))

        for idx, process in enumerate(self.conf["process"]):
            assert isinstance(process, dict), type(process)
            # Copy before popping so self.conf keeps its "type" entries.
            opts = dict(process)
            process_type = opts.pop("type")
            class_obj = dynamic_import(process_type, import_alias)

            try:
                self.functions[idx] = class_obj(**opts)
            except TypeError:
                # Help the user fix the options by logging what the class
                # expects, then let the original error propagate.
                try:
                    signa = signature(class_obj)
                except ValueError:
                    # No signature could be provided for this callable.
                    pass
                else:
                    logging.error(
                        "Expected signature: {}({})".format(
                            class_obj.__name__, signa
                        )
                    )
                raise

    def __repr__(self):
        """Return the class name followed by one "index: function" line each"""
        rep = "\n" + "\n".join(
            " {}: {}".format(k, v) for k, v in self.functions.items()
        )
        return "{}({})".format(self.__class__.__name__, rep)

    def __call__(self, xs, uttid_list=None, **kwargs):
        """Return new mini-batch

        :param Union[Sequence[np.ndarray], np.ndarray] xs: a batch, or a
            single item (anything that is not a ``Sequence``)
        :param Union[Sequence[str], str] uttid_list: ids paired item-by-item
            with ``xs``; a single str is repeated for every item.  Only
            passed to functions whose signature has an ``uttid`` parameter
        :param kwargs: extra keyword arguments; each function receives only
            those named in its signature
        :return: batch: the transformed batch, or the single transformed
            item when ``xs`` was not a ``Sequence``
        :rtype: List[np.ndarray]
        :raises ValueError: if ``uttid_list`` is consumed by a function but
            its length differs from the number of items
        :raises NotImplementedError: if the configured mode is not
            "sequential"
        """
        is_batch = isinstance(xs, Sequence)
        if not is_batch:
            xs = [xs]

        if isinstance(uttid_list, str):
            uttid_list = [uttid_list] * len(xs)

        mode = self.conf.get("mode", "sequential")
        if mode != "sequential":
            raise NotImplementedError("Not supporting mode={}".format(mode))

        for idx, func in self.functions.items():
            # Derive only the arguments which the function accepts.
            try:
                param = signature(func).parameters
            except ValueError:
                # No signature could be provided; pass no optional arguments.
                param = {}
            _kwargs = {k: v for k, v in kwargs.items() if k in param}

            use_uttid = uttid_list is not None and "uttid" in param
            if (
                use_uttid
                and hasattr(uttid_list, "__len__")
                and len(uttid_list) != len(xs)
            ):
                # zip() would silently drop the unmatched items or ids.
                raise ValueError(
                    "uttid_list has {} items but the batch has {}".format(
                        len(uttid_list), len(xs)
                    )
                )

            try:
                if use_uttid:
                    xs = [func(x, u, **_kwargs) for x, u in zip(xs, uttid_list)]
                else:
                    xs = [func(x, **_kwargs) for x in xs]
            except Exception:
                logging.fatal(
                    "Catch a exception from {}th func: {}".format(idx, func)
                )
                raise

        return xs if is_batch else xs[0]
|
|