import argparse
import logging
from pathlib import Path
import sys
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

import humanfriendly
import numpy as np
import torch
from tqdm import trange
from typeguard import check_argument_types

from espnet.utils.cli_utils import get_commandline_args
from espnet2.fileio.sound_scp import SoundScpWriter
from espnet2.tasks.enh import EnhancementTask
from espnet2.torch_utils.device_funcs import to_device
from espnet2.torch_utils.set_all_random_seed import set_all_random_seed
from espnet2.utils import config_argparse
from espnet2.utils.types import str2bool
from espnet2.utils.types import str2triple_str
from espnet2.utils.types import str_or_none

EPS = torch.finfo(torch.get_default_dtype()).eps

class SeparateSpeech:
    """SeparateSpeech class

    Examples:
        >>> import soundfile
        >>> separate_speech = SeparateSpeech("enh_config.yml", "enh.pth")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> separate_speech(audio)
        [separated_audio1, separated_audio2, ...]

    """

    def __init__(
        self,
        enh_train_config: Union[Path, str],
        enh_model_file: Optional[Union[Path, str]] = None,
        segment_size: Optional[float] = None,
        hop_size: Optional[float] = None,
        normalize_segment_scale: bool = False,
        show_progressbar: bool = False,
        ref_channel: Optional[int] = None,
        normalize_output_wav: bool = False,
        device: str = "cpu",
        dtype: str = "float32",
    ):
        assert check_argument_types()

        # Build the enhancement/separation model from the training config
        # and (optionally) a trained checkpoint
        enh_model, enh_train_args = EnhancementTask.build_model_from_file(
            enh_train_config, enh_model_file, device
        )
        enh_model.to(dtype=getattr(torch, dtype)).eval()

        self.device = device
        self.dtype = dtype
        self.enh_train_args = enh_train_args
        self.enh_model = enh_model

        # Segment-wise processing is enabled only when both segment_size
        # and hop_size (in seconds) are given
        self.segment_size = segment_size
        self.hop_size = hop_size
        self.normalize_segment_scale = normalize_segment_scale
        self.normalize_output_wav = normalize_output_wav
        self.show_progressbar = show_progressbar

        self.num_spk = enh_model.num_spk
        task = "enhancement" if self.num_spk == 1 else "separation"

        if ref_channel is not None:
            logging.info(
                "Overwrite enh_model.separator.ref_channel with {}".format(ref_channel)
            )
            enh_model.separator.ref_channel = ref_channel
            self.ref_channel = ref_channel
        else:
            self.ref_channel = enh_model.ref_channel

        self.segmenting = segment_size is not None and hop_size is not None
        if self.segmenting:
            logging.info("Perform segment-wise speech %s" % task)
            logging.info(
                "Segment length = {} sec, hop length = {} sec".format(
                    segment_size, hop_size
                )
            )
        else:
            logging.info("Perform direct speech %s on the input" % task)

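    # Illustrative note (not from the original source): segment-wise
    # processing only needs the two extra constructor arguments, e.g.
    #   SeparateSpeech("enh_config.yml", "enh.pth", segment_size=2.4, hop_size=1.2)
    # processes long inputs in 2.4 s windows advanced by 1.2 s hops.
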
    @torch.no_grad()
    def __call__(
        self, speech_mix: Union[torch.Tensor, np.ndarray], fs: int = 8000
    ) -> List[np.ndarray]:
        """Inference

        Args:
            speech_mix: Input speech data (Batch, Nsamples [, Channels])
            fs: sampling rate in Hz (used to convert segment_size/hop_size
                from seconds to samples)
        Returns:
            [separated_audio1, separated_audio2, ...]

        """
        assert check_argument_types()

        if isinstance(speech_mix, np.ndarray):
            speech_mix = torch.as_tensor(speech_mix)

        assert speech_mix.dim() > 1, speech_mix.size()
        batch_size = speech_mix.size(0)
        speech_mix = speech_mix.to(getattr(torch, self.dtype))

        # lengths: (Batch,) filled with the number of samples per utterance
        lengths = speech_mix.new_full(
            [batch_size], dtype=torch.long, fill_value=speech_mix.size(1)
        )

        speech_mix = to_device(speech_mix, device=self.device)
        lengths = to_device(lengths, device=self.device)

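        # Segment-wise processing: cut the input into overlapping windows,
        # enhance each window, and stitch the windows back together below.
        # Worked example (illustrative): with fs=8000, segment_size=2.4 and
        # hop_size=1.2, overlap_length = round(8000 * 1.2) = 9600 samples,
        # and an 80000-sample input gives
        # num_segments = ceil((80000 - 9600) / 9600) = 8.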
        if self.segmenting and lengths[0] > self.segment_size * fs:
            overlap_length = int(np.round(fs * (self.segment_size - self.hop_size)))
            num_segments = int(
                np.ceil((speech_mix.size(1) - overlap_length) / (self.hop_size * fs))
            )
            t = T = int(self.segment_size * fs)
            pad_shape = speech_mix[:, :T].shape
            enh_waves = []
            range_ = trange if self.show_progressbar else range
            for i in range_(num_segments):
                st = int(i * self.hop_size * fs)
                en = st + T
                if en >= lengths[0]:
                    # the last segment: zero-pad it up to the full window size
                    en = lengths[0]
                    speech_seg = speech_mix.new_zeros(pad_shape)
                    t = en - st
                    speech_seg[:, :t] = speech_mix[:, st:en]
                else:
                    t = T
                    speech_seg = speech_mix[:, st:en]

                lengths_seg = speech_mix.new_full(
                    [batch_size], dtype=torch.long, fill_value=T
                )

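                # Per-segment forward pass: encoder -> separator -> one
                # decoded waveform per estimated speaker.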
                feats, f_lens = self.enh_model.encoder(speech_seg, lengths_seg)
                feats, _, _ = self.enh_model.separator(feats, f_lens)
                processed_wav = [
                    self.enh_model.decoder(f, lengths_seg)[0] for f in feats
                ]
                if speech_seg.dim() > 2:
                    # multi-channel input: keep only the reference channel
                    speech_seg_ = speech_seg[:, self.ref_channel]
                else:
                    speech_seg_ = speech_seg

                if self.normalize_segment_scale:
                    # match each separated stream's energy to the mixture on
                    # the reference channel, so that consecutive segments do
                    # not differ in scale
                    processed_wav = [
                        self.normalize_scale(w, speech_seg_) for w in processed_wav
                    ]

                # (num_spk, Batch, T)
                enh_waves.append(torch.stack(processed_wav, dim=0))

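            # Stitch the segments back together. The separator may order the
            # speakers differently in each segment, so align each segment to
            # the accumulated output via a permutation computed on the
            # overlapped region, then average the overlap (overlap-add).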
            waves = enh_waves[0]
            for i in range(1, num_segments):
                # permutation between the previous output and this segment
                perm = self.cal_permutation(
                    waves[:, :, -overlap_length:],
                    enh_waves[i][:, :, :overlap_length],
                    criterion="si_snr",
                )
                # re-permute the separated streams in this segment
                for batch in range(batch_size):
                    enh_waves[i][:, batch] = enh_waves[i][perm[batch], batch]

                if i == num_segments - 1:
                    # drop the zero-padded tail of the last segment
                    enh_waves[i][:, :, t:] = 0
                    enh_waves_res_i = enh_waves[i][:, :, overlap_length:t]
                else:
                    enh_waves_res_i = enh_waves[i][:, :, overlap_length:]

                # average over the overlapped part
                waves[:, :, -overlap_length:] = (
                    waves[:, :, -overlap_length:] + enh_waves[i][:, :, :overlap_length]
                ) / 2

                waves = torch.cat([waves, enh_waves_res_i], dim=2)

            assert waves.size(2) == speech_mix.size(1), (waves.shape, speech_mix.shape)
            waves = torch.unbind(waves, dim=0)
        else:
            # Direct (non-segmented) processing of the whole input
            feats, f_lens = self.enh_model.encoder(speech_mix, lengths)
            feats, _, _ = self.enh_model.separator(feats, f_lens)
            waves = [self.enh_model.decoder(f, lengths)[0] for f in feats]

        assert len(waves) == self.num_spk, (len(waves), self.num_spk)
        assert len(waves[0]) == batch_size, (len(waves[0]), batch_size)
        if self.normalize_output_wav:
            # scale each output so that its peak amplitude is 0.9
            waves = [
                (w / abs(w).max(dim=1, keepdim=True)[0] * 0.9).cpu().numpy()
                for w in waves
            ]
        else:
            waves = [w.cpu().numpy() for w in waves]

        return waves

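    # Illustrative example for normalize_scale below: if the reference-channel
    # mixture segment has RMS 0.1 and a separated segment has RMS 0.2, the
    # separated segment is scaled by 0.1 / 0.2 = 0.5.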
    @staticmethod
    @torch.no_grad()
    def normalize_scale(enh_wav, ref_ch_wav):
        """Normalize the energy of enh_wav to match that of ref_ch_wav.

        Args:
            enh_wav (torch.Tensor): (B, Nsamples)
            ref_ch_wav (torch.Tensor): (B, Nsamples)
        Returns:
            enh_wav (torch.Tensor): (B, Nsamples)
        """
        ref_energy = torch.sqrt(torch.mean(ref_ch_wav.pow(2), dim=1))
        enh_energy = torch.sqrt(torch.mean(enh_wav.pow(2), dim=1))
        # EPS guards against division by zero for all-zero (silent) outputs
        return enh_wav * (ref_energy / (enh_energy + EPS))[:, None]

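    # cal_permutation below returns, per batch item, the speaker ordering of
    # the new segment that best matches the previous output according to the
    # chosen criterion (lowest loss on the overlapped samples).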
    @torch.no_grad()
    def cal_permutation(self, ref_wavs, enh_wavs, criterion="si_snr"):
        """Calculate the permutation between separated streams in two adjacent segments.

        Args:
            ref_wavs (List[torch.Tensor]): [(Batch, Nsamples)]
            enh_wavs (List[torch.Tensor]): [(Batch, Nsamples)]
            criterion (str): one of ("si_snr", "mse", "corr")
        Returns:
            perm (torch.Tensor): permutation for enh_wavs (Batch, num_spk)
        """
        loss_func = {
            "si_snr": self.enh_model.si_snr_loss,
            "mse": lambda enh, ref: torch.mean((enh - ref).pow(2), dim=1),
            "corr": lambda enh, ref: (
                (enh * ref).sum(dim=1)
                / (enh.pow(2).sum(dim=1) * ref.pow(2).sum(dim=1) + EPS)
            ).clamp(min=EPS, max=1 - EPS),
        }[criterion]

        _, perm = self.enh_model._permutation_loss(ref_wavs, enh_wavs, loss_func)
        return perm


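# humanfriendly_or_none lets --fs accept human-friendly sizes as well as the
# string "none"; e.g. humanfriendly.parse_size("8k") == 8000 and
# humanfriendly.parse_size("16000") == 16000.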
def humanfriendly_or_none(value: str):
    if value in ("none", "None", "NONE"):
        return None
    return humanfriendly.parse_size(value)


def inference(
    output_dir: str,
    batch_size: int,
    dtype: str,
    fs: int,
    ngpu: int,
    seed: int,
    num_workers: int,
    log_level: Union[int, str],
    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
    key_file: Optional[str],
    enh_train_config: str,
    enh_model_file: str,
    allow_variable_data_keys: bool,
    segment_size: Optional[float],
    hop_size: Optional[float],
    normalize_segment_scale: bool,
    show_progressbar: bool,
    ref_channel: Optional[int],
    normalize_output_wav: bool,
):
    assert check_argument_types()
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    if ngpu >= 1:
        device = "cuda"
    else:
        device = "cpu"

    set_all_random_seed(seed)

    separate_speech = SeparateSpeech(
        enh_train_config=enh_train_config,
        enh_model_file=enh_model_file,
        segment_size=segment_size,
        hop_size=hop_size,
        normalize_segment_scale=normalize_segment_scale,
        show_progressbar=show_progressbar,
        ref_channel=ref_channel,
        normalize_output_wav=normalize_output_wav,
        device=device,
        dtype=dtype,
    )

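    # Each --data_path_and_name_and_type argument is a (path, name, type)
    # triple, e.g. "dump/test/wav.scp,speech_mix,sound" (illustrative); the
    # name becomes the keyword argument passed to separate_speech below.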
    loader = EnhancementTask.build_streaming_iterator(
        data_path_and_name_and_type,
        dtype=dtype,
        batch_size=batch_size,
        key_file=key_file,
        num_workers=num_workers,
        preprocess_fn=EnhancementTask.build_preprocess_fn(
            separate_speech.enh_train_args, False
        ),
        collate_fn=EnhancementTask.build_collate_fn(
            separate_speech.enh_train_args, False
        ),
        allow_variable_data_keys=allow_variable_data_keys,
        inference=True,
    )

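    # One writer per output stream: wavs go to {output_dir}/wavs/<spk>/ and
    # the corresponding scp entries to {output_dir}/spk<spk>.scp.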
    writers = []
    for i in range(separate_speech.num_spk):
        writers.append(
            SoundScpWriter(f"{output_dir}/wavs/{i + 1}", f"{output_dir}/spk{i + 1}.scp")
        )

    for keys, batch in loader:
        assert isinstance(batch, dict), type(batch)
        assert all(isinstance(s, str) for s in keys), keys
        _bs = len(next(iter(batch.values())))
        assert len(keys) == _bs, f"{len(keys)} != {_bs}"
        # drop the *_lengths entries, which SeparateSpeech does not accept
        batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}

        # forward the sampling rate so that segment_size/hop_size (given in
        # seconds) are converted with the configured fs, not the default
        waves = separate_speech(**batch, fs=fs)
        for spk, w in enumerate(waves):
            for b in range(batch_size):
                writers[spk][keys[b]] = fs, w[b]

    for writer in writers:
        writer.close()


def get_parser():
    parser = config_argparse.ArgumentParser(
        description="Frontend inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--fs", type=humanfriendly_or_none, default=8000, help="Sampling rate"
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=True,
        action="append",
    )
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)

    group = parser.add_argument_group("Output data related")
    group.add_argument(
        "--normalize_output_wav",
        type=str2bool,
        default=False,
        help="Whether to normalize the predicted wav to [-1~1]",
    )

    group = parser.add_argument_group("The model configuration related")
    group.add_argument("--enh_train_config", type=str, required=True)
    group.add_argument("--enh_model_file", type=str, required=True)

    group = parser.add_argument_group("Data loading related")
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )

    group = parser.add_argument_group("SeparateSpeech related")
    group.add_argument(
        "--segment_size",
        type=float,
        default=None,
        help="Segment length in seconds for segment-wise speech enhancement/separation",
    )
    group.add_argument(
        "--hop_size",
        type=float,
        default=None,
        help="Hop length in seconds for segment-wise speech enhancement/separation",
    )
    group.add_argument(
        "--normalize_segment_scale",
        type=str2bool,
        default=False,
        help="Whether to normalize the energy of the separated streams in each segment",
    )
    group.add_argument(
        "--show_progressbar",
        type=str2bool,
        default=False,
        help="Whether to show a progress bar when performing segment-wise speech "
        "enhancement/separation",
    )
    group.add_argument(
        "--ref_channel",
        type=int,
        default=None,
        help="If not None, this will overwrite the ref_channel defined in the "
        "separator module (for multi-channel speech processing)",
    )

    return parser

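# Example invocation (illustrative paths; see get_parser above for all flags):
#   python enh_inference.py \
#       --output_dir exp/enhanced \
#       --enh_train_config exp/enh_train/config.yaml \
#       --enh_model_file exp/enh_train/valid.loss.best.pth \
#       --data_path_and_name_and_type "dump/test/wav.scp,speech_mix,sound"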
def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
    args = parser.parse_args(cmd)
    kwargs = vars(args)
    kwargs.pop("config", None)
    inference(**kwargs)


if __name__ == "__main__":
    main()