#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://arxiv.org/abs/2206.07293
https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/frcrn.py
https://huggingface.co/spaces/alibabasglab/ClearVoice/blob/main/models/frcrn_se/frcrn.py
"""
import os
from typing import Optional, Union
import torch
import torch.nn as nn
from torch.nn import functional as F
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig
from toolbox.torchaudio.models.frcrn.conv_stft import ConviSTFT, ConvSTFT
from toolbox.torchaudio.models.frcrn.unet import UNet


class FRCRN(nn.Module):
    """ Frequency Recurrent CRN """

def __init__(self,
use_complex_networks: bool = True,
model_complexity: int = 45,
model_depth: int = 14,
padding_mode: str = "zeros",
nfft: int = 640,
win_size: int = 640,
hop_size: int = 320,
win_type: str = "hann",
):
"""
:param use_complex_networks: bool, Whether to use complex networks.
:param model_complexity: int, define the model complexity with the number of layers
:param model_depth: int, Only two options are available : 10, 20
:param padding_mode: str, Encoder's convolution filter. 'zeros', 'reflect'
:param nfft: int, number of Short Time Fourier Transform (STFT) points
:param win_size: int, length of window used for defining one frame of sample points
:param hop_size: int, length of window shifting (equivalent to hop_size)
:param win_type: str, windowing type used in STFT, eg. 'hanning', 'hamming'
"""
super().__init__()
self.freq_bins = nfft // 2 + 1
self.nfft = nfft
self.win_size = win_size
self.hop_size = hop_size
self.win_type = win_type
self.eps = 1e-8
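        # With the defaults (nfft=640, win_size=640, hop_size=320), each frame
        # spans 40 ms with a 20 ms shift at a 16 kHz sample rate (the rate the
        # reference FRCRN models target; an assumption, not enforced here), and
        # the STFT produces freq_bins = 640 // 2 + 1 = 321 frequency bins.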
self.stft = ConvSTFT(
nfft=self.nfft,
win_size=self.win_size,
hop_size=self.hop_size,
win_type=self.win_type,
feature_type="complex",
requires_grad=False
)
self.istft = ConviSTFT(
nfft=self.nfft,
win_size=self.win_size,
hop_size=self.hop_size,
win_type=self.win_type,
feature_type="complex",
requires_grad=False
)
self.unet = UNet(
in_channels=1,
use_complex_networks=use_complex_networks,
model_complexity=model_complexity,
model_depth=model_depth,
padding_mode=padding_mode
)
self.unet2 = UNet(
in_channels=1,
use_complex_networks=use_complex_networks,
model_complexity=model_complexity,
model_depth=model_depth,
padding_mode=padding_mode
)

    def forward(self, noisy: torch.Tensor):
"""
:param noisy: torch.Tensor, shape: [b, n_samples] or [b, c, n_samples]
:return:
"""
if noisy.dim() == 2:
noisy = torch.unsqueeze(noisy, dim=1)
_, _, n_samples = noisy.shape
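        # Zero-pad the tail so that (n_samples - win_size) is a multiple of
        # hop_size, i.e. the final STFT frame covers the end of the signal.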
remainder = (n_samples - self.win_size) % self.hop_size
if remainder > 0:
n_samples_pad = self.hop_size - remainder
noisy = F.pad(noisy, pad=(0, n_samples_pad), mode="constant", value=0)
# [batch_size, freq_bins * 2, time_steps]
cmp_spec = self.stft.forward(noisy)
# [batch_size, 1, freq_bins * 2, time_steps]
cmp_spec = torch.unsqueeze(cmp_spec, 1)
# [batch_size, 2, freq_bins, time_steps]
cmp_spec = torch.cat([
cmp_spec[:, :, :self.freq_bins, :],
cmp_spec[:, :, self.freq_bins:, :],
], dim=1)
# [batch_size, 2, freq_bins, time_steps, 1]
cmp_spec = torch.unsqueeze(cmp_spec, dim=4)
cmp_spec = torch.transpose(cmp_spec, 1, 4)
# [batch_size, 1, freq_bins, time_steps, 2]
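        # Two cascaded UNets: the second refines the raw feature output of the
        # first, and the two tanh-bounded complex masks are summed below.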
unet1_out = self.unet.forward(cmp_spec)
cmp_mask1 = torch.tanh(unet1_out)
unet2_out = self.unet2.forward(unet1_out)
cmp_mask2 = torch.tanh(unet2_out)
# est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask1)
cmp_mask2 = cmp_mask2 + cmp_mask1
est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask2)
# est_wav shape: [b, n_samples]
est_wav = est_wav[:, :n_samples]
return est_spec, est_wav, est_mask

    def apply_mask(self,
cmp_spec: torch.Tensor,
cmp_mask: torch.Tensor,
):
"""
:param cmp_spec: torch.Tensor, shape: [batch_size, 1, freq_bins, time_steps, 2]
:param cmp_mask: torch.Tensor, shape: [batch_size, 1, freq_bins, time_steps, 2]
:return:
"""
est_spec = torch.cat(
tensors=[
cmp_spec[..., 0] * cmp_mask[..., 0] - cmp_spec[..., 1] * cmp_mask[..., 1],
cmp_spec[..., 0] * cmp_mask[..., 1] + cmp_spec[..., 1] * cmp_mask[..., 0]
], dim=1
)
# est_spec shape: [b, 2, n//2+1, t]
est_spec = torch.cat(tensors=[est_spec[:, 0, :, :], est_spec[:, 1, :, :]], dim=1)
# est_spec shape: [b, n+2, t]
# cmp_mask shape: [b, 1, n//2+1, t, 2]
cmp_mask = torch.squeeze(cmp_mask, dim=1)
# cmp_mask shape: [b, n//2+1, t, 2]
cmp_mask = torch.cat(tensors=[cmp_mask[:, :, :, 0], cmp_mask[:, :, :, 1]], dim=1)
# cmp_mask shape: [b, n+2, t]
# est_spec shape: [b, n+2, t]
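        # ConviSTFT expects the real and imaginary halves concatenated along the
        # frequency axis, matching the [real; imag] layout produced above.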
est_wav = self.istft(est_spec)
# est_wav shape: [b, 1, n_samples]
est_wav = torch.squeeze(est_wav, 1)
# est_wav shape: [b, n_samples]
return est_spec, est_wav, cmp_mask

    def get_params(self, weight_decay=0.0):
"""
为可训练参数配置 weight_decay (权重衰减) 的作用是实现 L2 正则化。
1. 防止过拟合: 通过向损失函数添加参数的 L2 范数 (平方和) 作为惩罚项, weight_decay 会限制模型权重的大小.
这使得模型倾向于学习更小的权重值, 降低对训练数据的过度敏感, 从而提高泛化能力.
2. 控制模型复杂度: 权重衰减直接作用于优化过程, 在梯度更新时对权重进行衰减,
公式: weight = weight - lr * (gradient + weight_decay * weight).
这相当于在梯度下降中额外引入了一个与当前权重值成正比的衰减力, 抑制权重快速增长.
3. 与优化器的具体实现相关
在 SGD 等传统优化器中, weight_decay 直接等价于 L2 正则化.
在 Adam 优化器中, 权重衰减的实现与参数更新耦合, 可能因学习率调整而效果减弱.
在 AdamW 优化器改进了这一点, 将权重衰减与学习率解耦, 使其更符合 L2 正则化的理论效果.
注意:
值过大会导致欠拟合, 过小则正则化效果弱, 常用范围是 1e-4到 1e-2.
某些场景 (如 BatchNorm 层) 可能需要通过参数分组对不同层设置不同的 weight_decay.
:param weight_decay:
:return:
"""
weights, biases = [], []
for name, param in self.named_parameters():
if "bias" in name:
biases += [param]
else:
weights += [param]
params = [{
'params': weights,
'weight_decay': weight_decay,
}, {
'params': biases,
'weight_decay': 0.0,
}]
return params

    def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
"""
:param est_mask: torch.Tensor, shape: [b, n+2, t]
:param clean:
:param noisy:
:return:
"""
clean_stft = self.stft(clean)
clean_re = clean_stft[:, :self.freq_bins, :]
clean_im = clean_stft[:, self.freq_bins:, :]
noisy_stft = self.stft(noisy)
noisy_re = noisy_stft[:, :self.freq_bins, :]
noisy_im = noisy_stft[:, self.freq_bins:, :]
noisy_power = noisy_re ** 2 + noisy_im ** 2
sr = clean_re
yr = noisy_re
si = clean_im
yi = noisy_im
y_pow = noisy_power
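        # Ground-truth complex ideal ratio mask: M = S / Y = S * conj(Y) / |Y|^2,
        # expanded into its real and imaginary parts below.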
# (Sr * Yr + Si * Yi) / (Y_pow + 1e-8)
gth_mask_re = (sr * yr + si * yi) / (y_pow + self.eps)
# (Si * Yr - Sr * Yi) / (Y_pow + 1e-8)
        gth_mask_im = (si * yr - sr * yi) / (y_pow + self.eps)
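        # Bound the ground-truth mask: values above 2 are set to 1 and values
        # below -2 to -1, as in the reference implementation.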
gth_mask_re[gth_mask_re > 2] = 1
gth_mask_re[gth_mask_re < -2] = -1
gth_mask_im[gth_mask_im > 2] = 1
gth_mask_im[gth_mask_im < -2] = -1
mask_re = est_mask[:, :self.freq_bins, :]
mask_im = est_mask[:, self.freq_bins:, :]
loss_re = F.mse_loss(gth_mask_re, mask_re)
loss_im = F.mse_loss(gth_mask_im, mask_im)
loss = loss_re + loss_im
return loss


MODEL_FILE = "model.pt"


class FRCRNPretrainedModel(FRCRN):
def __init__(self,
config: FRCRNConfig,
):
super(FRCRNPretrainedModel, self).__init__(
use_complex_networks=config.use_complex_networks,
model_complexity=config.model_complexity,
model_depth=config.model_depth,
nfft=config.nfft,
win_size=config.win_size,
hop_size=config.hop_size,
win_type=config.win_type,
)
self.config = config

    @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
config = FRCRNConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
model = cls(config)
if os.path.isdir(pretrained_model_name_or_path):
ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
else:
ckpt_file = pretrained_model_name_or_path
with open(ckpt_file, "rb") as f:
state_dict = torch.load(f, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict, strict=True)
return model

    def save_pretrained(self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
):
model = self
if state_dict is None:
state_dict = model.state_dict()
os.makedirs(save_directory, exist_ok=True)
# save state dict
model_file = os.path.join(save_directory, MODEL_FILE)
torch.save(state_dict, model_file)
# save config
config_file = os.path.join(save_directory, CONFIG_FILE)
self.config.to_yaml_file(config_file)
return save_directory


def main():
# model = FRCRN(
# use_complex_networks=True,
# model_complexity=45,
# model_depth=14,
# padding_mode="zeros",
# nfft=512,
# win_size=400,
# hop_size=200,
# win_type="hann",
# )
model = FRCRN(
use_complex_networks=True,
model_complexity=45,
model_depth=14,
padding_mode="zeros",
nfft=640,
win_size=640,
hop_size=320,
win_type="hann",
)
mixture = torch.rand(size=(1, 8000), dtype=torch.float32)
est_spec, est_wav, est_mask = model.forward(mixture)
print(est_spec.shape)
print(est_wav.shape)
print(est_mask.shape)
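
    # A minimal sketch of wiring get_params() into an optimizer. AdamW decouples
    # weight decay from the learning rate, so the per-group weight_decay set in
    # get_params() behaves like true L2 regularization (lr and weight_decay here
    # are illustrative values, not tuned).
    params = model.get_params(weight_decay=1e-4)
    optimizer = torch.optim.AdamW(params, lr=1e-3)
    print(optimizer)

    # Exercise the mask loss with a random "clean" reference. This assumes
    # ConvSTFT accepts [b, n_samples] inputs, as the ModelScope reference
    # implementation does; the random signals only smoke-test the code path.
    clean = torch.rand(size=(1, 8000), dtype=torch.float32)
    loss = model.mask_loss_fn(est_mask, clean, mixture)
    print(loss)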
return


if __name__ == "__main__":
main()