#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Demucs denoiser: a 1-D convolutional encoder/decoder with skip connections
and an LSTM bottleneck, operating directly on waveforms.

https://arxiv.org/abs/2006.12847
https://github.com/facebookresearch/denoiser
"""
import math
import os
from typing import List, Optional, Union

import torch
import torch.nn as nn
from torch.nn import functional as F

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.demucs.configuration_demucs import DemucsConfig
from toolbox.torchaudio.models.demucs.resample import upsample2, downsample2


# Name -> activation module class lookup.
# NOTE(review): not used by the model code in this file; kept for external callers.
activation_layer_dict = {
    "glu": nn.GLU,
    "relu": nn.ReLU,
    "identity": nn.Identity,
    "sigmoid": nn.Sigmoid,
}


class BLSTM(nn.Module):
    """LSTM whose input and output feature sizes are both ``hidden_size``.

    When ``bidirectional`` is True, the LSTM emits ``2 * hidden_size`` features
    (forward and backward directions concatenated). A linear layer projects
    them back to ``hidden_size``, so the output can replace the input.

    The LSTM uses PyTorch's default (non batch-first) layout, so ``x`` is
    expected as ``[seq_len, batch, hidden_size]``.
    """

    def __init__(self,
                 hidden_size: int,
                 num_layers: int = 2,
                 bidirectional: bool = True,
                 ):
        super().__init__()
        self.lstm = nn.LSTM(bidirectional=bidirectional,
                            num_layers=num_layers,
                            hidden_size=hidden_size,
                            input_size=hidden_size
                            )
        # Only needed to fold the doubled bidirectional output back down.
        self.linear = None
        if bidirectional:
            self.linear = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self,
                x: torch.Tensor,
                hx: Optional[tuple] = None,
                ):
        """Run the LSTM (and the projection, if bidirectional).

        :param x: Tensor, shape: [seq_len, batch_size, hidden_size].
        :param hx: optional initial LSTM state ``(h_0, c_0)``; None means zeros.
        :return: ``(output, (h_n, c_n))``; output has ``hidden_size`` features.
        """
        # Call the module (not .forward) so registered hooks are honoured.
        x, hx = self.lstm(x, hx)
        if self.linear is not None:
            x = self.linear(x)
        return x, hx


def rescale_conv(conv, reference):
    """Rescale a convolution's weight (and bias) in place.

    Both are divided by ``sqrt(std(weight) / reference)``. This moves the
    weight standard deviation from ``std`` to ``sqrt(std * reference)``,
    i.e. towards ``reference``.
    """
    std = conv.weight.std().detach()
    scale = (std / reference) ** 0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale


def rescale_module(module, reference):
    """Apply :func:`rescale_conv` to every Conv1d / ConvTranspose1d in ``module``."""
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
            rescale_conv(sub, reference)


class DemucsModel(nn.Module):
    """Demucs waveform model.

    :param in_channels: number of input audio channels.
    :param out_channels: number of output audio channels.
    :param hidden_channels: channels of the first encoder layer. Each deeper
        layer multiplies this by ``growth``, capped at ``max_hidden``.
    :param depth: number of encoder (and decoder) layers.
    :param kernel_size: kernel size of the strided (transposed) convolutions.
    :param stride: stride of the strided (transposed) convolutions.
    :param causal: if True the LSTM bottleneck is unidirectional.
    :param resample: 1, 2 or 4. The input is upsampled by this factor before
        the encoder and downsampled back after the decoder.
    :param growth: per-layer multiplier for the hidden channel count.
    :param max_hidden: upper bound on the hidden channel count.
    :param do_normalize: if True, divide the input by ``floor + std`` of its
        channel mean, and multiply the output by the same std.
    :param rescale: reference std passed to :func:`rescale_module` at init;
        a falsy value disables rescaling.
    :param floor: added to the std before dividing, to avoid division by zero.
    :raises ValueError: if ``resample`` is not 1, 2 or 4.
    """

    def __init__(self,
                 in_channels: int = 1,
                 out_channels: int = 1,
                 hidden_channels: int = 48,
                 depth: int = 5,
                 kernel_size: int = 8,
                 stride: int = 4,
                 causal: bool = True,
                 resample: int = 4,
                 growth: int = 2,
                 max_hidden: int = 10_000,
                 do_normalize: bool = True,
                 rescale: float = 0.1,
                 floor: float = 1e-3,
                 ):
        super(DemucsModel, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels

        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = stride

        self.causal = causal

        self.resample = resample
        self.growth = growth
        self.max_hidden = max_hidden
        self.do_normalize = do_normalize
        self.rescale = rescale
        self.floor = floor

        if resample not in [1, 2, 4]:
            raise ValueError("Resample should be 1, 2 or 4.")

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        for index in range(depth):
            # Encoder layer: strided conv, then a 1x1 conv doubling the
            # channels so the GLU (which halves them) restores hidden_channels.
            encode = []
            encode += [
                nn.Conv1d(in_channels, hidden_channels, kernel_size, stride),
                nn.ReLU(),
                nn.Conv1d(hidden_channels, hidden_channels * 2, 1), nn.GLU(1),
            ]
            self.encoder.append(nn.Sequential(*encode))

            # Decoder layer mirrors the encoder. It is inserted at the front,
            # so the decoder runs from the deepest layer back to the first.
            decode = []
            decode += [
                nn.Conv1d(hidden_channels, 2 * hidden_channels, 1), nn.GLU(1),
                nn.ConvTranspose1d(hidden_channels, out_channels, kernel_size, stride),
            ]
            # The outermost decoder layer (index 0) has no trailing ReLU, so
            # the output waveform can be negative.
            if index > 0:
                decode.append(nn.ReLU())
            self.decoder.insert(0, nn.Sequential(*decode))
            out_channels = hidden_channels
            in_channels = hidden_channels
            hidden_channels = min(int(growth * hidden_channels), max_hidden)

        self.lstm = BLSTM(in_channels, bidirectional=not causal)

        if rescale:
            rescale_module(self, reference=rescale)

    @staticmethod
    def valid_length(length: int, depth: int, kernel_size: int, stride: int, resample: int):
        """
        Return the nearest valid input length for the model, so that no time
        steps are left over in any convolution, i.e. for every layer
        ``(input_size - kernel_size) % stride == 0``.

        ``forward`` right-pads the input to this length. If the input already
        has a valid length, the estimated output has exactly the same length.
        """
        length = math.ceil(length * resample)
        for idx in range(depth):
            length = math.ceil((length - kernel_size) / stride) + 1
            length = max(length, 1)
        for idx in range(depth):
            length = (length - 1) * stride + kernel_size
        length = int(math.ceil(length / resample))
        return int(length)

    def forward(self, noisy: torch.Tensor):
        """
        :param noisy: Tensor, shape: [batch_size, num_samples] or [batch_size, channels, num_samples]
        :return: Tensor, shape: [batch_size, out_channels, num_samples]
            (always 3-D, even for a 2-D input).
        """
        if noisy.dim() == 2:
            noisy = noisy.unsqueeze(1)
        # noisy shape: [batch_size, channels, num_samples]

        if self.do_normalize:
            mono = noisy.mean(dim=1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            noisy = noisy / (self.floor + std)
        else:
            std = 1

        _, _, length = noisy.shape
        x = noisy

        # Right-pad so every strided conv consumes its input exactly.
        length_ = self.valid_length(length, self.depth, self.kernel_size, self.stride, self.resample)
        x = F.pad(x, (0, length_ - length))

        if self.resample == 2:
            x = upsample2(x)
        elif self.resample == 4:
            x = upsample2(x)
            x = upsample2(x)

        skips = []
        for encode in self.encoder:
            x = encode(x)
            skips.append(x)

        # [batch, channels, time] -> [time, batch, channels] for the LSTM.
        x = x.permute(2, 0, 1)
        x, _ = self.lstm(x)
        x = x.permute(1, 2, 0)

        for decode in self.decoder:
            # U-Net style skip: add the matching encoder output, trimmed to
            # the current time length.
            skip = skips.pop(-1)
            x = x + skip[..., :x.shape[-1]]
            x = decode(x)

        if self.resample == 2:
            x = downsample2(x)
        elif self.resample == 4:
            x = downsample2(x)
            x = downsample2(x)

        # Drop the padding and undo the input normalisation.
        x = x[..., :length]
        return std * x


MODEL_FILE = "model.pt"


class DemucsPretrainedModel(DemucsModel):
    """:class:`DemucsModel` built from a :class:`DemucsConfig`, with save/load helpers."""

    def __init__(self,
                 config: DemucsConfig,
                 ):
        super(DemucsPretrainedModel, self).__init__(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            hidden_channels=config.hidden_channels,
            depth=config.depth,
            kernel_size=config.kernel_size,
            stride=config.stride,
            causal=config.causal,
            resample=config.resample,
            growth=config.growth,
            max_hidden=config.max_hidden,
            do_normalize=config.do_normalize,
            rescale=config.rescale,
            floor=config.floor,
        )
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a model from its config and load its weights.

        :param pretrained_model_name_or_path: a directory containing
            ``MODEL_FILE``, or a path to the checkpoint file itself. It is
            also passed to ``DemucsConfig.from_pretrained``.
        :param kwargs: forwarded to ``DemucsConfig.from_pretrained``.
        :return: the model, with weights loaded onto the CPU (strict key match).
        """
        config = DemucsConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            # weights_only=True refuses arbitrary pickled objects in the checkpoint.
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """Write ``MODEL_FILE`` (weights) and ``CONFIG_FILE`` (yaml config).

        :param save_directory: target directory; created if missing.
        :param state_dict: weights to save; defaults to ``self.state_dict()``.
        :return: ``save_directory``.
        """
        model = self

        if state_dict is None:
            state_dict = model.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory


def main():
    """Smoke test: build a default model and run it on 32000 random samples."""
    config = DemucsConfig()
    model = DemucsModel(
        in_channels=config.in_channels,
        out_channels=config.out_channels,
        hidden_channels=config.hidden_channels,
        depth=config.depth,
        kernel_size=config.kernel_size,
        stride=config.stride,
        causal=config.causal,
        resample=config.resample,
        growth=config.growth,
        max_hidden=config.max_hidden,
        do_normalize=config.do_normalize,
        rescale=config.rescale,
        floor=config.floor,
    )
    print(model)

    noisy = torch.rand(size=(1, 8000 * 4), dtype=torch.float32)

    denoise = model(noisy)
    print(denoise.shape)
    return


if __name__ == "__main__":
    main()