#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
LSTM mask-based speech enhancement model.

Reference implementation:
https://github.com/haoxiangsnr/IRM-based-Speech-Enhancement-using-LSTM/blob/master/model/lstm_model.py
"""
import os
from typing import Optional, Union, Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F
import torchaudio

from toolbox.torchaudio.models.lstm.configuration_lstm import LstmConfig
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT


# File name of the serialized state dict inside a pretrained-model directory.
MODEL_FILE = "model.pt"


class Transpose(nn.Module):
    """Module wrapper around ``torch.transpose`` for two fixed dimensions."""

    def __init__(self, dim0: int, dim1: int):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, inputs: torch.Tensor):
        """Return ``inputs`` with ``dim0`` and ``dim1`` swapped."""
        inputs = torch.transpose(inputs, dim0=self.dim0, dim1=self.dim1)
        return inputs


class LstmModel(nn.Module):
    """Estimate a sigmoid magnitude mask with an LSTM and resynthesize audio.

    Pipeline: STFT -> magnitude / phase -> LSTM + linear + sigmoid mask ->
    masked magnitude recombined with the noisy phase -> inverse STFT.
    """

    def __init__(self,
                 nfft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 win_type: str = "hann",
                 hidden_size: int = 1024,
                 num_layers: int = 2,
                 batch_first: bool = True,
                 dropout: float = 0.2,
                 ):
        """Build the STFT / iSTFT front-ends and the mask estimator.

        :param nfft: FFT size; the model works on ``nfft // 2 + 1`` bins.
        :param win_size: analysis window length in samples.
        :param hop_size: hop between consecutive frames in samples.
        :param win_type: window name forwarded to ConvSTFT / ConviSTFT.
        :param hidden_size: LSTM hidden size.
        :param num_layers: number of stacked LSTM layers.
        :param batch_first: forwarded to ``nn.LSTM``. NOTE: ``forward_chunk``
            feeds the LSTM a (b, t, f) tensor, i.e. it assumes ``True``.
        :param dropout: dropout between stacked LSTM layers.
        """
        super(LstmModel, self).__init__()
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type

        self.spec_bins = nfft // 2 + 1
        self.hidden_size = hidden_size

        self.eps = 1e-8

        self.stft = ConvSTFT(
            nfft=self.nfft,
            win_size=self.win_size,
            hop_size=self.hop_size,
            win_type=self.win_type,
            power=None,
            requires_grad=False
        )
        self.istft = ConviSTFT(
            nfft=self.nfft,
            win_size=self.win_size,
            hop_size=self.hop_size,
            win_type=self.win_type,
            requires_grad=False
        )

        self.lstm = nn.LSTM(
            input_size=self.spec_bins,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
            # nn.LSTM applies dropout only between stacked layers; with a
            # single layer a non-zero value has no effect and only triggers
            # a warning, so pass 0.0 in that case.
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.linear = nn.Linear(in_features=hidden_size, out_features=self.spec_bins)
        self.activation = nn.Sigmoid()

    def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
        """Zero-pad ``signal`` on the right so it splits into whole frames.

        A 2-D input (b, num_samples) gets a channel axis inserted at dim 1.
        The result has a length ``n`` with ``n >= win_size`` and
        ``(n - win_size) % hop_size == 0``.

        :param signal: tensor of shape (b, num_samples) or (b, c, num_samples).
        :return: the (possibly padded) 3-D signal.
        """
        if signal.dim() == 2:
            signal = torch.unsqueeze(signal, dim=1)
        _, _, n_samples = signal.shape

        if n_samples < self.win_size:
            # Inputs shorter than one window must be padded up to a full
            # window. Aligning only to the hop grid (as for longer inputs)
            # can leave fewer than win_size samples, i.e. zero frames.
            n_samples_pad = self.win_size - n_samples
        else:
            remainder = (n_samples - self.win_size) % self.hop_size
            n_samples_pad = self.hop_size - remainder if remainder > 0 else 0

        if n_samples_pad > 0:
            signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
        return signal

    def forward(self,
                noisy: torch.Tensor,
                h_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                ):
        """Denoise a waveform.

        :param noisy: noisy waveform; the last axis is time.
        :param h_state: optional initial LSTM state ``(h_0, c_0)``.
        :return: ``(denoise, mask, h_state)``. ``denoise`` is trimmed back to
            the input length, ``mask`` is (b, f, t), ``h_state`` is the final
            LSTM state.
        """
        num_samples = noisy.shape[-1]
        noisy = self.signal_prepare(noisy)

        mag_noisy, pha_noisy = self.mag_pha_stft(noisy)
        # shape: (b, f, t)
        # t = (num_samples - win_size) / hop_size + 1

        mask, h_state = self.forward_chunk(mag_noisy, h_state)
        # mask shape: (b, f, t)

        stft_denoise = self.do_mask(mag_noisy, pha_noisy, mask)
        denoise = self.istft.forward(stft_denoise)
        # denoise shape: [b, 1, num_samples_pad]

        # Drop the samples that only exist because of the right padding.
        denoise = denoise[:, :, :num_samples]
        # denoise shape: [b, 1, num_samples]
        return denoise, mask, h_state

    def mag_pha_stft(self, noisy: torch.Tensor):
        """Return the magnitude and phase of the STFT of ``noisy``.

        :param noisy: waveform accepted by ``self.stft``.
        :return: ``(mag_noisy, pha_noisy)``, each of shape (b, f, t).
        """
        stft_noisy = self.stft.forward(noisy)
        # stft_noisy shape: [b, f, t], torch.complex64

        real = torch.real(stft_noisy)
        imag = torch.imag(stft_noisy)
        mag_noisy = torch.sqrt(real ** 2 + imag ** 2)
        pha_noisy = torch.atan2(imag, real)
        # shape: (b, f, t)
        return mag_noisy, pha_noisy

    def forward_chunk(self,
                      mag_noisy: torch.Tensor,
                      h_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                      ):
        """Estimate the mask for a block of magnitude frames.

        Passing the returned state back in on the next call lets consecutive
        chunks be processed as one continuous stream.

        :param mag_noisy: magnitude spectrogram, shape (b, f, t).
        :param h_state: optional LSTM state ``(h, c)`` from a previous call.
        :return: ``(mask, h_state)``; ``mask`` has shape (b, f, t) with
            values in (0, 1).
        """
        # mag_noisy shape: (b, f, t)
        x = torch.transpose(mag_noisy, dim0=2, dim1=1)
        # x shape: (b, t, f)

        x, h_state = self.lstm.forward(x, hx=h_state)
        x = self.linear.forward(x)
        mask = self.activation(x)
        # mask shape: (b, t, f)

        mask = torch.transpose(mask, dim0=2, dim1=1)
        # mask shape: (b, f, t)
        return mask, h_state

    def do_mask(self,
                mag_noisy: torch.Tensor,
                pha_noisy: torch.Tensor,
                mask: torch.Tensor,
                ):
        """Apply ``mask`` to the magnitude and recombine with the noisy phase.

        :param mag_noisy: magnitude spectrogram, shape (b, f, t).
        :param pha_noisy: phase in radians, shape (b, f, t).
        :param mask: multiplicative mask, shape (b, f, t).
        :return: complex spectrogram ``mag_noisy * mask * exp(1j * pha_noisy)``.
        """
        # (b, f, t)
        mag_denoise = mag_noisy * mask
        stft_denoise = mag_denoise * torch.exp((1j * pha_noisy))
        return stft_denoise


class LstmPretrainedModel(LstmModel):
    """``LstmModel`` built from an ``LstmConfig``, with load / save helpers."""

    def __init__(self,
                 config: LstmConfig,
                 ):
        """Build the model from the hyper-parameters stored in ``config``."""
        super(LstmPretrainedModel, self).__init__(
            nfft=config.nfft,
            win_size=config.win_size,
            hop_size=config.hop_size,
            win_type=config.win_type,
            hidden_size=config.hidden_size,
            num_layers=config.num_layers,
            dropout=config.dropout,
        )
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a model and load its weights.

        :param pretrained_model_name_or_path: a directory containing
            ``MODEL_FILE``, or the path of the checkpoint file itself. The
            same value is also handed to ``LstmConfig.from_pretrained``.
        :param kwargs: forwarded to ``LstmConfig.from_pretrained``.
        :return: the model with weights loaded (strict key matching).
        """
        config = LstmConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            # weights_only=True restricts unpickling to tensors / plain data.
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """Write the weights and the config into ``save_directory``.

        :param save_directory: target directory; created if missing.
        :param state_dict: weights to save; defaults to ``self.state_dict()``.
        :return: ``save_directory``.
        """
        if state_dict is None:
            state_dict = self.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory


def main():
    """Run the model offline, then frame-by-frame, and print sample values.

    The same two output samples are printed after each of three passes
    (whole-utterance forward, streaming analysis + offline iSTFT, streaming
    analysis + streaming iSTFT) so the results can be compared by eye.
    """
    config = LstmConfig()
    model = LstmPretrainedModel(config)
    model.eval()

    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
    noisy = model.signal_prepare(noisy)
    b, _, num_samples = noisy.shape
    # signal_prepare guarantees this division is exact.
    num_frames = (num_samples - config.win_size) // config.hop_size + 1

    waveform, mask, h_state = model.forward(noisy)
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])

    # noisy_pad shape: [b, 1, num_samples_pad]

    # Streaming analysis: one window at a time, carrying the LSTM state.
    h_state = None
    sub_spec_list = list()
    for i in range(num_frames):
        begin = i * config.hop_size
        end = begin + config.win_size
        sub_noisy = noisy[:, :, begin:end]
        mag_noisy, pha_noisy = model.mag_pha_stft(sub_noisy)
        mask, h_state = model.forward_chunk(mag_noisy, h_state)
        sub_spec = model.do_mask(mag_noisy, pha_noisy, mask)
        sub_spec_list.append(sub_spec)

    spec = torch.concat(sub_spec_list, dim=2)

    # 1: offline synthesis of the streamed spectrogram.
    waveform = model.istft.forward(spec)
    waveform = waveform[:, :, :num_samples]
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])

    # 2: streaming synthesis, one frame at a time with the iSTFT cache.
    cache_dict = None
    waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
    for i in range(num_frames):
        sub_spec = spec[:, :, i:i+1]
        begin = i * config.hop_size
        end = begin + config.win_size - config.hop_size

        sub_waveform, cache_dict = model.istft.forward_chunk(sub_spec, cache_dict=cache_dict)
        waveform[:, :, begin:end] = sub_waveform

    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])


if __name__ == "__main__":
    main()