#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/haoxiangsnr/IRM-based-Speech-Enhancement-using-LSTM/blob/master/model/lstm_model.py
"""
import os
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import functional as F

from toolbox.torchaudio.models.lstm.configuration_lstm import LstmConfig
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT


MODEL_FILE = "model.pt"


class Transpose(nn.Module):
    def __init__(self, dim0: int, dim1: int):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, inputs: torch.Tensor):
        inputs = torch.transpose(inputs, dim0=self.dim0, dim1=self.dim1)
        return inputs
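
# Transpose wraps torch.transpose as an nn.Module so a dimension swap can sit
# inside nn.Sequential. A minimal sketch (hypothetical usage, not from this
# file, which does not use the class directly):
#
#   swap = Transpose(dim0=2, dim1=1)
#   x = torch.randn(2, 257, 10)   # (b, f, t)
#   y = swap(x)                   # (b, t, f) -> torch.Size([2, 10, 257])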


class LstmModel(nn.Module):
    def __init__(self,
                 nfft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 win_type: str = "hann",
                 hidden_size: int = 1024,
                 num_layers: int = 2,
                 batch_first: bool = True,
                 dropout: float = 0.2,
                 ):
        super(LstmModel, self).__init__()
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type
        self.spec_bins = nfft // 2 + 1
        self.hidden_size = hidden_size
        self.eps = 1e-8

        self.stft = ConvSTFT(
            nfft=self.nfft,
            win_size=self.win_size,
            hop_size=self.hop_size,
            win_type=self.win_type,
            power=None,
            requires_grad=False
        )
        self.istft = ConviSTFT(
            nfft=self.nfft,
            win_size=self.win_size,
            hop_size=self.hop_size,
            win_type=self.win_type,
            requires_grad=False
        )

        self.lstm = nn.LSTM(input_size=self.spec_bins,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            batch_first=batch_first,
                            dropout=dropout,
                            )
        self.linear = nn.Linear(in_features=hidden_size, out_features=self.spec_bins)
        self.activation = nn.Sigmoid()

    def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
        if signal.dim() == 2:
            signal = torch.unsqueeze(signal, dim=1)
        _, _, n_samples = signal.shape
        remainder = (n_samples - self.win_size) % self.hop_size
        if remainder > 0:
            n_samples_pad = self.hop_size - remainder
            signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
        return signal
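
    # Worked example of the padding arithmetic above (with the defaults
    # win_size=512, hop_size=256, and a 16000-sample input):
    #   remainder = (16000 - 512) % 256 = 128
    #   pad       = 256 - 128 = 128        -> 16128 samples after padding
    #   frames    = (16128 - 512) // 256 + 1 = 62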

    def forward(self,
                noisy: torch.Tensor,
                h_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                ):
        num_samples = noisy.shape[-1]
        noisy = self.signal_prepare(noisy)
        batch_size, _, num_samples_pad = noisy.shape
        # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}")

        mag_noisy, pha_noisy = self.mag_pha_stft(noisy)
        # shape: (b, f, t)
        # t = (num_samples_pad - win_size) / hop_size + 1
        mask, h_state = self.forward_chunk(mag_noisy, h_state)
        # mask shape: (b, f, t)
        stft_denoise = self.do_mask(mag_noisy, pha_noisy, mask)
        denoise = self.istft.forward(stft_denoise)
        # denoise shape: [b, 1, num_samples_pad]
        denoise = denoise[:, :, :num_samples]
        # denoise shape: [b, 1, num_samples]
        return denoise, mask, h_state

    def mag_pha_stft(self, noisy: torch.Tensor):
        # noisy shape: [b, 1, num_samples]
        stft_noisy = self.stft.forward(noisy)
        # stft_noisy shape: [b, f, t], torch.complex64
        real = torch.real(stft_noisy)
        imag = torch.imag(stft_noisy)
        mag_noisy = torch.sqrt(real ** 2 + imag ** 2)
        pha_noisy = torch.atan2(imag, real)
        # shape: (b, f, t)
        return mag_noisy, pha_noisy
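
    # Since stft_noisy is complex (power=None in ConvSTFT), the same pair can be
    # read off directly; a minimal equivalent sketch:
    #   mag_noisy = torch.abs(stft_noisy)
    #   pha_noisy = torch.angle(stft_noisy)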

    def forward_chunk(self,
                      mag_noisy: torch.Tensor,
                      h_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
                      ):
        # mag_noisy shape: (b, f, t)
        x = torch.transpose(mag_noisy, dim0=2, dim1=1)
        # x shape: (b, t, f)
        x, h_state = self.lstm.forward(x, hx=h_state)
        x = self.linear.forward(x)
        mask = self.activation(x)
        # mask shape: (b, t, f)
        mask = torch.transpose(mask, dim0=2, dim1=1)
        # mask shape: (b, f, t)
        return mask, h_state
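
    # For streaming use, frames can be fed one at a time while the LSTM state
    # (h, c) is carried across calls; a minimal sketch, assuming mag_frames is
    # a (b, f, t) magnitude spectrogram:
    #   h_state = None
    #   for i in range(mag_frames.shape[-1]):
    #       mask, h_state = model.forward_chunk(mag_frames[:, :, i:i+1], h_state)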

    def do_mask(self,
                mag_noisy: torch.Tensor,
                pha_noisy: torch.Tensor,
                mask: torch.Tensor,
                ):
        # (b, f, t)
        mag_denoise = mag_noisy * mask
        stft_denoise = mag_denoise * torch.exp(1j * pha_noisy)
        return stft_denoise
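
    # Note: mag * exp(1j * pha) rebuilds the complex STFT from polar form;
    # torch.polar(mag_denoise, pha_noisy) is an equivalent built-in.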


class LstmPretrainedModel(LstmModel):
    def __init__(self,
                 config: LstmConfig,
                 ):
        super(LstmPretrainedModel, self).__init__(
            nfft=config.nfft,
            win_size=config.win_size,
            hop_size=config.hop_size,
            win_type=config.win_type,
            hidden_size=config.hidden_size,
            num_layers=config.num_layers,
            dropout=config.dropout,
        )
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config = LstmConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path
        with open(ckpt_file, "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model
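
    # A minimal load sketch; "path/to/checkpoint_dir" is a hypothetical
    # directory holding MODEL_FILE plus the yaml config:
    #   model = LstmPretrainedModel.from_pretrained("path/to/checkpoint_dir")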

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        model = self

        if state_dict is None:
            state_dict = model.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory
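
    # Save/load round-trip sketch ("ckpt_dir" is a hypothetical path):
    #   model.save_pretrained("ckpt_dir")   # writes model.pt + the yaml config
    #   model2 = LstmPretrainedModel.from_pretrained("ckpt_dir")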


def main():
    config = LstmConfig()
    model = LstmPretrainedModel(config)
    model.eval()

    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
    noisy = model.signal_prepare(noisy)
    b, _, num_samples = noisy.shape
    t = (num_samples - config.win_size) // config.hop_size + 1

    waveform, mask, h_state = model.forward(noisy)
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])

    # noisy shape: [b, 1, num_samples] (already padded by signal_prepare)
    h_state = None
    sub_spec_list = list()
    for i in range(t):
        begin = i * config.hop_size
        end = begin + config.win_size
        sub_noisy = noisy[:, :, begin:end]
        mag_noisy, pha_noisy = model.mag_pha_stft(sub_noisy)
        mask, h_state = model.forward_chunk(mag_noisy, h_state)
        sub_spec = model.do_mask(mag_noisy, pha_noisy, mask)
        sub_spec_list.append(sub_spec)
    spec = torch.concat(sub_spec_list, dim=2)

    # 1. offline reconstruction: run the iSTFT over the whole masked spectrogram
    waveform = model.istft.forward(spec)
    waveform = waveform[:, :, :num_samples]
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])

    # 2. streaming reconstruction: feed the iSTFT one frame at a time
    cache_dict = None
    waveform = torch.zeros(size=(b, 1, num_samples), dtype=torch.float32)
    for i in range(t):
        sub_spec = spec[:, :, i:i+1]
        begin = i * config.hop_size
        end = begin + config.win_size - config.hop_size
        sub_waveform, cache_dict = model.istft.forward_chunk(sub_spec, cache_dict=cache_dict)
        # end = begin + config.win_size
        # sub_waveform = model.istft.forward(sub_spec)
        # sub_waveform shape: (b, 1, win_size - hop_size)
        waveform[:, :, begin:end] = sub_waveform
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])
    return


if __name__ == "__main__":
    main()