|
import importlib |
|
import os |
|
from abc import ABC, abstractmethod |
|
from typing import Dict, Optional |
|
|
|
|
|
class AudioFeatureTransform(ABC):
    """Abstract base class for audio feature transforms.

    Concrete transforms must provide the ``from_config_dict`` classmethod.
    """

    @classmethod
    @abstractmethod
    def from_config_dict(cls, config: Optional[Dict] = None):
        """Create a transform from an optional configuration dictionary.

        Args:
            config: configuration mapping for the transform, or ``None``.

        The base implementation does nothing and returns ``None``;
        subclasses are required to override it.
        """
|
|
|
|
|
# Maps each registered transform name to the transform class stored under it.
AUDIO_FEATURE_TRANSFORM_REGISTRY = {}

# Class names (``cls.__name__``) of every registered transform; consulted at
# registration time to reject a second class with the same name.
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()
|
|
|
|
|
def register_audio_feature_transform(name):
    """Return a class decorator that registers a transform under ``name``.

    The returned decorator stores the class in
    ``AUDIO_FEATURE_TRANSFORM_REGISTRY`` and records its class name in
    ``AUDIO_FEATURE_TRANSFORM_CLASS_NAMES``, then returns the class unchanged.

    Raises (from the decorator, when applied to a class):
        ValueError: if ``name`` is already registered, if the class does not
            extend ``AudioFeatureTransform``, or if another registered class
            already has the same class name.
    """

    def register_audio_feature_transform_cls(cls):
        # The registry key must not already be taken.
        if name in AUDIO_FEATURE_TRANSFORM_REGISTRY:
            raise ValueError(f"Cannot register duplicate transform ({name})")

        # Only subclasses of the abstract base may be registered.
        if not issubclass(cls, AudioFeatureTransform):
            message = (
                f"Transform ({name}: {cls.__name__}) must extend "
                "AudioFeatureTransform"
            )
            raise ValueError(message)

        # Class names must be unique across all registered transforms.
        cls_name = cls.__name__
        if cls_name in AUDIO_FEATURE_TRANSFORM_CLASS_NAMES:
            message = (
                "Cannot register audio feature transform with duplicate "
                f"class name ({cls_name})"
            )
            raise ValueError(message)

        AUDIO_FEATURE_TRANSFORM_CLASS_NAMES.add(cls_name)
        AUDIO_FEATURE_TRANSFORM_REGISTRY[name] = cls
        return cls

    return register_audio_feature_transform_cls
|
|
|
|
|
def get_audio_feature_transform(name):
    """Return the transform class registered under ``name``.

    The lookup is a plain dictionary index into
    ``AUDIO_FEATURE_TRANSFORM_REGISTRY``, so an unregistered ``name`` raises
    ``KeyError``.
    """
    transform_cls = AUDIO_FEATURE_TRANSFORM_REGISTRY[name]
    return transform_cls
|
|
|
|
|
# Import every public sibling module and sub-directory of this file's
# directory at load time.
# NOTE(review): presumably this is what triggers the
# @register_audio_feature_transform decorators in those modules -- confirm.
transforms_dir = os.path.dirname(__file__)
for file in os.listdir(transforms_dir):
    path = os.path.join(transforms_dir, file)
    # Skip private ("_"-prefixed) and hidden ("."-prefixed) entries; keep only
    # Python source files and directories.
    if not file.startswith(("_", ".")) and (
        file.endswith(".py") or os.path.isdir(path)
    ):
        # Strip only the trailing ".py" extension. Slicing at
        # file.find(".py") would cut at the FIRST ".py" occurrence and so
        # truncate names such as "a.pyx.py" to "a", importing the wrong module.
        name = os.path.splitext(file)[0] if file.endswith(".py") else file
        importlib.import_module("fairseq.data.audio.feature_transforms." + name)
|
|
|
|
|
class CompositeAudioFeatureTransform(AudioFeatureTransform):
    """Applies a sequence of transforms one after another."""

    def __init__(self, transforms):
        """Store the given transforms, discarding any ``None`` entries."""
        self.transforms = [t for t in transforms if t is not None]

    @classmethod
    def from_config_dict(cls, config=None):
        """Build a composite transform from a configuration dictionary.

        Args:
            config: mapping that may contain a ``"transforms"`` entry listing
                registered transform names; the value stored under each name
                (if any) is passed to that transform's ``from_config_dict``.

        Returns:
            A ``CompositeAudioFeatureTransform``, or ``None`` when ``config``
            is ``None`` or has no ``"transforms"`` entry.
        """
        settings = config if config is not None else {}
        transform_names = settings.get("transforms")
        if transform_names is None:
            return None

        built = []
        for transform_name in transform_names:
            transform_cls = get_audio_feature_transform(transform_name)
            built.append(
                transform_cls.from_config_dict(settings.get(transform_name))
            )
        return CompositeAudioFeatureTransform(built)

    def __call__(self, x):
        """Feed ``x`` through every transform in order and return the result."""
        result = x
        for transform in self.transforms:
            result = transform(result)
        return result

    def __repr__(self):
        """Return a multi-line listing: class name, one line per transform, ")"."""
        lines = [f"{self.__class__.__name__}("]
        lines.extend(f" {t.__repr__()}" for t in self.transforms)
        lines.append(")")
        return "\n".join(lines)
|
|