#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
FRCRN speech-enhancement model.

https://arxiv.org/abs/2206.07293
https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/frcrn.py
https://huggingface.co/spaces/alibabasglab/ClearVoice/blob/main/models/frcrn_se/frcrn.py
"""
import os
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import functional as F

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig
from toolbox.torchaudio.models.frcrn.conv_stft import ConviSTFT, ConvSTFT
from toolbox.torchaudio.models.frcrn.unet import UNet


class FRCRN(nn.Module):
    """
    Frequency Recurrent CRN.

    Pipeline: STFT -> two cascaded UNets that each predict a complex mask ->
    the two masks are summed -> the mask is applied to the noisy spectrum -> iSTFT.
    """

    def __init__(self,
                 use_complex_networks: bool = True,
                 model_complexity: int = 45,
                 model_depth: int = 14,
                 padding_mode: str = "zeros",
                 nfft: int = 640,
                 win_size: int = 640,
                 hop_size: int = 320,
                 win_type: str = "hann",
                 ):
        """
        :param use_complex_networks: bool, whether to use complex networks (forwarded to UNet).
        :param model_complexity: int, model complexity setting (forwarded to UNet).
        :param model_depth: int, depth of each UNet (forwarded to UNet; default 14).
            NOTE(review): the set of supported depths is defined by UNet - confirm there.
        :param padding_mode: str, padding mode of the encoder convolutions, e.g. 'zeros', 'reflect'.
        :param nfft: int, number of Short Time Fourier Transform (STFT) points.
        :param win_size: int, length of the window used for one frame of sample points.
        :param hop_size: int, number of samples the window is shifted between frames.
        :param win_type: str, window type used by the STFT (default 'hann').
            NOTE(review): accepted names are defined by ConvSTFT / ConviSTFT - confirm there.
        """
        super().__init__()
        # Number of one-sided frequency bins; the STFT output stacks real and
        # imaginary parts, i.e. it has freq_bins * 2 channels.
        self.freq_bins = nfft // 2 + 1

        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type

        # Guards the division in mask_loss_fn against a zero noisy power.
        self.eps = 1e-8

        self.stft = ConvSTFT(
            nfft=self.nfft,
            win_size=self.win_size,
            hop_size=self.hop_size,
            win_type=self.win_type,
            feature_type="complex",
            requires_grad=False
        )
        self.istft = ConviSTFT(
            nfft=self.nfft,
            win_size=self.win_size,
            hop_size=self.hop_size,
            win_type=self.win_type,
            feature_type="complex",
            requires_grad=False
        )
        self.unet = UNet(
            in_channels=1,
            use_complex_networks=use_complex_networks,
            model_complexity=model_complexity,
            model_depth=model_depth,
            padding_mode=padding_mode
        )
        self.unet2 = UNet(
            in_channels=1,
            use_complex_networks=use_complex_networks,
            model_complexity=model_complexity,
            model_depth=model_depth,
            padding_mode=padding_mode
        )

    def forward(self,
                noisy: torch.Tensor
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Enhance a noisy waveform.

        :param noisy: torch.Tensor, shape: [b, n_samples] or [b, c, n_samples]
        :return: tuple (est_spec, est_wav, est_mask).
            est_spec: estimated spectrum, shape: [b, n+2, t] (real part stacked on imaginary part).
            est_wav: estimated waveform, cropped to the input length, shape: [b, n_samples].
            est_mask: estimated complex mask, shape: [b, n+2, t].
        """
        if noisy.dim() == 2:
            noisy = torch.unsqueeze(noisy, dim=1)
        _, _, n_samples = noisy.shape

        # Right-pad with zeros so that (length - win_size) is a multiple of hop_size.
        remainder = (n_samples - self.win_size) % self.hop_size
        if remainder > 0:
            n_samples_pad = self.hop_size - remainder
            noisy = F.pad(noisy, pad=(0, n_samples_pad), mode="constant", value=0)

        # [batch_size, freq_bins * 2, time_steps]
        cmp_spec = self.stft.forward(noisy)
        # [batch_size, 1, freq_bins * 2, time_steps]
        cmp_spec = torch.unsqueeze(cmp_spec, 1)

        # Split the stacked real / imaginary halves into two channels.
        # [batch_size, 2, freq_bins, time_steps]
        cmp_spec = torch.cat([
            cmp_spec[:, :, :self.freq_bins, :],
            cmp_spec[:, :, self.freq_bins:, :],
        ], dim=1)

        # [batch_size, 2, freq_bins, time_steps, 1]
        cmp_spec = torch.unsqueeze(cmp_spec, dim=4)

        # Move the real / imaginary axis to the last dimension.
        # [batch_size, 1, freq_bins, time_steps, 2]
        cmp_spec = torch.transpose(cmp_spec, 1, 4)

        # The second UNet consumes the raw (pre-tanh) output of the first one.
        unet1_out = self.unet.forward(cmp_spec)
        cmp_mask1 = torch.tanh(unet1_out)

        unet2_out = self.unet2.forward(unet1_out)
        cmp_mask2 = torch.tanh(unet2_out)

        # The final mask is the sum of both stage masks; only it is applied.
        cmp_mask2 = cmp_mask2 + cmp_mask1
        est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask2)

        # est_wav shape: [b, n_samples]; drop the samples produced by the zero padding.
        est_wav = est_wav[:, :n_samples]
        return est_spec, est_wav, est_mask

    def apply_mask(self,
                   cmp_spec: torch.Tensor,
                   cmp_mask: torch.Tensor,
                   ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Multiply the spectrum by the complex mask and synthesize the waveform.

        :param cmp_spec: torch.Tensor, shape: [batch_size, 1, freq_bins, time_steps, 2]
        :param cmp_mask: torch.Tensor, shape: [batch_size, 1, freq_bins, time_steps, 2]
        :return: tuple (est_spec, est_wav, cmp_mask).
            est_spec shape: [b, n+2, t]; est_wav shape: [b, n_samples]; cmp_mask shape: [b, n+2, t].
        """
        spec_re, spec_im = cmp_spec[..., 0], cmp_spec[..., 1]
        mask_re, mask_im = cmp_mask[..., 0], cmp_mask[..., 1]

        # Complex multiplication: (a + jb)(c + jd) = (ac - bd) + j(ad + bc).
        est_spec = torch.cat(
            tensors=[
                spec_re * mask_re - spec_im * mask_im,
                spec_re * mask_im + spec_im * mask_re,
            ],
            dim=1
        )
        # est_spec shape: [b, 2, n//2+1, t]
        est_spec = torch.cat(tensors=[est_spec[:, 0, :, :], est_spec[:, 1, :, :]], dim=1)
        # est_spec shape: [b, n+2, t]

        # cmp_mask shape: [b, 1, n//2+1, t, 2]
        cmp_mask = torch.squeeze(cmp_mask, dim=1)
        # cmp_mask shape: [b, n//2+1, t, 2]
        cmp_mask = torch.cat(tensors=[cmp_mask[:, :, :, 0], cmp_mask[:, :, :, 1]], dim=1)
        # cmp_mask shape: [b, n+2, t]

        # est_spec shape: [b, n+2, t]
        est_wav = self.istft(est_spec)
        # est_wav shape: [b, 1, n_samples]
        est_wav = torch.squeeze(est_wav, 1)
        # est_wav shape: [b, n_samples]
        return est_spec, est_wav, cmp_mask

    def get_params(self, weight_decay=0.0) -> list:
        """
        Build optimizer parameter groups that configure weight_decay for the trainable parameters.

        Weight decay implements L2 regularization.
        1. Prevents overfitting:
           By adding the L2 norm (sum of squares) of the parameters to the loss as a penalty,
           weight_decay limits the magnitude of the model weights. The model then tends to learn
           smaller weights and is less sensitive to the training data, which improves generalization.
        2. Controls model complexity:
           Weight decay acts directly on the optimization process by decaying the weights at each
           update: weight = weight - lr * (gradient + weight_decay * weight).
           This adds a decay force proportional to the current weight and suppresses fast weight growth.
        3. Depends on the optimizer implementation:
           In classic optimizers such as SGD, weight_decay is equivalent to L2 regularization.
           In Adam, weight decay is coupled with the parameter update and may be weakened by the
           adaptive learning rate.
           AdamW decouples weight decay from the learning rate so that it better matches the
           theoretical effect of L2 regularization.

        Note:
        A value that is too large causes underfitting, a value that is too small regularizes weakly;
        a common range is 1e-4 to 1e-2.
        In some cases (e.g. BatchNorm layers) different layers may need different weight_decay
        values via parameter groups.

        :param weight_decay: float, weight decay applied to every parameter whose name does not contain "bias".
        :return: list of two parameter-group dicts: non-bias parameters with `weight_decay`,
            and bias parameters with weight decay 0.0.
        """
        weights, biases = [], []
        for name, param in self.named_parameters():
            # Parameters are classified by name only.
            if "bias" in name:
                biases.append(param)
            else:
                weights.append(param)

        params = [{
            'params': weights,
            'weight_decay': weight_decay,
        }, {
            'params': biases,
            'weight_decay': 0.0,
        }]
        return params

    def mask_loss_fn(self,
                     est_mask: torch.Tensor,
                     clean: torch.Tensor,
                     noisy: torch.Tensor
                     ) -> torch.Tensor:
        """
        MSE between the estimated mask and the ground-truth complex ratio mask S / Y.

        With S = Sr + jSi (clean) and Y = Yr + jYi (noisy):
            S / Y = S * conj(Y) / |Y|^2
                  = [(Sr * Yr + Si * Yi) + j(Si * Yr - Sr * Yi)] / (Yr^2 + Yi^2)

        :param est_mask: torch.Tensor, shape: [b, n+2, t]
        :param clean: torch.Tensor, clean waveform, passed to self.stft.
        :param noisy: torch.Tensor, noisy waveform, passed to self.stft.
        :return: torch.Tensor, scalar loss (real-part MSE + imaginary-part MSE).
        """
        # NOTE(review): forward() zero-pads `noisy` before the STFT while this method does not;
        # confirm the time axis of est_mask matches for lengths that are not hop-aligned.
        clean_stft = self.stft(clean)
        noisy_stft = self.stft(noisy)

        sr = clean_stft[:, :self.freq_bins, :]
        si = clean_stft[:, self.freq_bins:, :]
        yr = noisy_stft[:, :self.freq_bins, :]
        yi = noisy_stft[:, self.freq_bins:, :]
        y_pow = yr ** 2 + yi ** 2

        # (Sr * Yr + Si * Yi) / (Y_pow + 1e-8)
        gth_mask_re = (sr * yr + si * yi) / (y_pow + self.eps)
        # (Si * Yr - Sr * Yi) / (Y_pow + 1e-8)
        # Fixed: the code previously computed (Sr * Yr - Si * Yi), which is not the
        # imaginary part of S / Y and contradicted the formula above.
        gth_mask_im = (si * yr - sr * yi) / (y_pow + self.eps)

        # Outliers beyond +-2 are reset to +-1 (not clipped to +-2).
        # NOTE(review): presumably kept from the reference implementation - confirm this is intended.
        gth_mask_re[gth_mask_re > 2] = 1
        gth_mask_re[gth_mask_re < -2] = -1
        gth_mask_im[gth_mask_im > 2] = 1
        gth_mask_im[gth_mask_im < -2] = -1

        mask_re = est_mask[:, :self.freq_bins, :]
        mask_im = est_mask[:, self.freq_bins:, :]

        loss_re = F.mse_loss(gth_mask_re, mask_re)
        loss_im = F.mse_loss(gth_mask_im, mask_im)

        loss = loss_re + loss_im
        return loss


# File name of the serialized state dict inside a model directory.
MODEL_FILE = "model.pt"


class FRCRNPretrainedModel(FRCRN):
    """FRCRN built from an FRCRNConfig, with load / save helpers."""

    def __init__(self,
                 config: FRCRNConfig,
                 ):
        """
        :param config: FRCRNConfig, supplies the constructor arguments of FRCRN.
            padding_mode is not read from the config, so the FRCRN default is used.
        """
        super(FRCRNPretrainedModel, self).__init__(
            use_complex_networks=config.use_complex_networks,
            model_complexity=config.model_complexity,
            model_depth=config.model_depth,
            nfft=config.nfft,
            win_size=config.win_size,
            hop_size=config.hop_size,
            win_type=config.win_type,
        )
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Build the model from a config and load its weights.

        :param pretrained_model_name_or_path: a directory containing MODEL_FILE, or the path
            of a checkpoint file. The same value is passed to FRCRNConfig.from_pretrained.
        :param kwargs: forwarded to FRCRNConfig.from_pretrained.
        :return: the model with the state dict loaded (strict) on CPU.
        """
        config = FRCRNConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            # weights_only=True restricts unpickling to tensors and plain containers.
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """
        Write the weights (MODEL_FILE) and the config (CONFIG_FILE) into a directory.

        :param save_directory: target directory; created if it does not exist.
        :param state_dict: optional state dict to save instead of this model's own.
        :return: save_directory.
        """
        model = self

        if state_dict is None:
            state_dict = model.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory


def main():
    """Smoke test: run a random 8000-sample signal through FRCRN and print the output shapes."""
    # An alternative STFT setting that was tried here: nfft=512, win_size=400, hop_size=200.
    model = FRCRN(
        use_complex_networks=True,
        model_complexity=45,
        model_depth=14,
        padding_mode="zeros",
        nfft=640,
        win_size=640,
        hop_size=320,
        win_type="hann",
    )

    mixture = torch.rand(size=(1, 8000), dtype=torch.float32)

    est_spec, est_wav, est_mask = model.forward(mixture)
    print(est_spec.shape)
    print(est_wav.shape)
    print(est_mask.shape)


if __name__ == "__main__":
    main()