#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Demucs speech denoiser.

Paper: https://arxiv.org/abs/2006.12847
Reference implementation: https://github.com/facebookresearch/denoiser
"""
import math
import os
from typing import Optional, Union

import torch
import torch.nn as nn
from torch.nn import functional as F

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.demucs.configuration_demucs import DemucsConfig
from toolbox.torchaudio.models.demucs.resample import upsample2, downsample2

# Map activation names to layer constructors.
activation_layer_dict = {
    "glu": nn.GLU,
    "relu": nn.ReLU,
    "identity": nn.Identity,
    "sigmoid": nn.Sigmoid,
}


class BLSTM(nn.Module):
    def __init__(self,
                 hidden_size: int,
                 num_layers: int = 2,
                 bidirectional: bool = True,
                 ):
        super().__init__()
        self.lstm = nn.LSTM(bidirectional=bidirectional,
                            num_layers=num_layers,
                            hidden_size=hidden_size,
                            input_size=hidden_size,
                            )
        self.linear = None
        if bidirectional:
            # Project the concatenated forward/backward outputs back to hidden_size.
            self.linear = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self,
                x: torch.Tensor,
                hx: Optional[tuple] = None,
                ):
        # hx is the (h_0, c_0) tuple that nn.LSTM accepts, not a single tensor.
        x, hx = self.lstm.forward(x, hx)
        if self.linear:
            x = self.linear(x)
        return x, hx
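
# A minimal usage sketch (shapes are illustrative, not from the original file):
# nn.LSTM defaults to batch_first=False, so BLSTM expects sequence-first input
# of shape [time, batch, hidden_size] and returns the same shape.
#
#   blstm = BLSTM(hidden_size=96)
#   x = torch.rand(100, 1, 96)   # [time, batch, hidden]
#   y, _ = blstm(x)              # y: [time, batch, hidden]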


def rescale_conv(conv, reference):
    """Scale a conv layer's weights (and bias) so their std moves toward `reference`."""
    std = conv.weight.std().detach()
    scale = (std / reference) ** 0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale


def rescale_module(module, reference):
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
            rescale_conv(sub, reference)
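
# A rough numeric sketch (hypothetical values): if conv.weight.std() == 0.4 and
# reference == 0.1, then scale = (0.4 / 0.1) ** 0.5 == 2.0 and both weight and
# bias are halved. The square root applies only a partial correction per layer,
# pulling the initial activation scale toward the reference.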


class DemucsModel(nn.Module):
    def __init__(self,
                 in_channels: int = 1,
                 out_channels: int = 1,
                 hidden_channels: int = 48,
                 depth: int = 5,
                 kernel_size: int = 8,
                 stride: int = 4,
                 causal: bool = True,
                 resample: int = 4,
                 growth: int = 2,
                 max_hidden: int = 10_000,
                 do_normalize: bool = True,
                 rescale: float = 0.1,
                 floor: float = 1e-3,
                 ):
        super(DemucsModel, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = stride
        self.causal = causal
        self.resample = resample
        self.growth = growth
        self.max_hidden = max_hidden
        self.do_normalize = do_normalize
        self.rescale = rescale
        self.floor = floor

        if resample not in (1, 2, 4):
            raise ValueError("resample should be 1, 2 or 4.")

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        for index in range(depth):
            encode = [
                nn.Conv1d(in_channels, hidden_channels, kernel_size, stride),
                nn.ReLU(),
                nn.Conv1d(hidden_channels, hidden_channels * 2, 1),
                nn.GLU(1),
            ]
            self.encoder.append(nn.Sequential(*encode))

            decode = [
                nn.Conv1d(hidden_channels, 2 * hidden_channels, 1),
                nn.GLU(1),
                nn.ConvTranspose1d(hidden_channels, out_channels, kernel_size, stride),
            ]
            if index > 0:
                decode.append(nn.ReLU())
            # Prepend so the decoder mirrors the encoder: the deepest encoder
            # layer pairs with the first decoder layer.
            self.decoder.insert(0, nn.Sequential(*decode))

            out_channels = hidden_channels
            in_channels = hidden_channels
            hidden_channels = min(int(growth * hidden_channels), max_hidden)

        # Unidirectional LSTM in the causal (streaming) setting, bidirectional otherwise.
        self.lstm = BLSTM(in_channels, bidirectional=not causal)
        if rescale:
            rescale_module(self, reference=rescale)
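
    # With the defaults (hidden_channels=48, growth=2, depth=5), the encoder
    # channel widths grow as 1 -> 48 -> 96 -> 192 -> 384 -> 768 and the decoder
    # mirrors them back down to out_channels; max_hidden only caps growth at
    # larger depths. (Derived from the loop above, not stated in the original.)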

    @staticmethod
    def valid_length(length: int, depth: int, kernel_size: int, stride: int, resample: int):
        """
        Return the nearest valid length to use with the model so that
        no time steps are left over in a convolution, i.e. for every
        layer, (input_length - kernel_size) % stride == 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length.
        """
        length = math.ceil(length * resample)
        for idx in range(depth):
            length = math.ceil((length - kernel_size) / stride) + 1
            length = max(length, 1)
        for idx in range(depth):
            length = (length - 1) * stride + kernel_size
        length = int(math.ceil(length / resample))
        return int(length)
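
    # A worked example under the default config (depth=5, kernel_size=8,
    # stride=4, resample=4), computed by hand from the loops above:
    #
    #   DemucsModel.valid_length(32000, 5, 8, 4, 4) == 32085
    #
    # so a 32000-sample input is padded by 85 samples before encoding, and
    # forward() trims the output back to 32000 samples.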

    def forward(self, noisy: torch.Tensor):
        """
        :param noisy: Tensor, shape: [batch_size, num_samples] or [batch_size, channels, num_samples]
        :return: Tensor, shape: [batch_size, channels, num_samples]
        """
        if noisy.dim() == 2:
            noisy = noisy.unsqueeze(1)
        # noisy shape: [batch_size, channels, num_samples]

        if self.do_normalize:
            # Normalize by the std of the mono mix; `floor` guards against
            # division by (near) zero on silent inputs.
            mono = noisy.mean(dim=1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            noisy = noisy / (self.floor + std)
        else:
            std = 1

        _, _, length = noisy.shape
        x = noisy
        # Pad to the nearest valid length so the strided convolutions tile exactly.
        length_ = self.valid_length(length, self.depth, self.kernel_size, self.stride, self.resample)
        x = F.pad(x, (0, length_ - length))

        if self.resample == 2:
            x = upsample2(x)
        elif self.resample == 4:
            x = upsample2(x)
            x = upsample2(x)

        skips = []
        for encode in self.encoder:
            x = encode(x)
            skips.append(x)

        # The LSTM expects sequence-first input: [time, batch, channels].
        x = x.permute(2, 0, 1)
        x, _ = self.lstm(x)
        x = x.permute(1, 2, 0)

        for decode in self.decoder:
            skip = skips.pop(-1)
            # Trim the skip connection to match the (possibly shorter) decoder input.
            x = x + skip[..., :x.shape[-1]]
            x = decode(x)

        if self.resample == 2:
            x = downsample2(x)
        elif self.resample == 4:
            x = downsample2(x)
            x = downsample2(x)

        x = x[..., :length]
        # Undo the input normalization.
        return std * x


MODEL_FILE = "model.pt"


class DemucsPretrainedModel(DemucsModel):
    def __init__(self,
                 config: DemucsConfig,
                 ):
        super(DemucsPretrainedModel, self).__init__(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            hidden_channels=config.hidden_channels,
            depth=config.depth,
            kernel_size=config.kernel_size,
            stride=config.stride,
            causal=config.causal,
            resample=config.resample,
            growth=config.growth,
            max_hidden=config.max_hidden,
            do_normalize=config.do_normalize,
            rescale=config.rescale,
            floor=config.floor,
        )
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config = DemucsConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        model = self
        if state_dict is None:
            state_dict = model.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory
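
# A minimal save/load round-trip sketch (the directory name is hypothetical):
#
#   model = DemucsPretrainedModel(DemucsConfig())
#   model.save_pretrained("exp/demucs")   # writes model.pt and the config file
#   restored = DemucsPretrainedModel.from_pretrained("exp/demucs")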


def main():
    config = DemucsConfig()
    model = DemucsModel(
        in_channels=config.in_channels,
        out_channels=config.out_channels,
        hidden_channels=config.hidden_channels,
        depth=config.depth,
        kernel_size=config.kernel_size,
        stride=config.stride,
        causal=config.causal,
        resample=config.resample,
        growth=config.growth,
        max_hidden=config.max_hidden,
        do_normalize=config.do_normalize,
        rescale=config.rescale,
        floor=config.floor,
    )
    print(model)

    # A batch of 32000 random samples (e.g. 4 s at 8 kHz); forward() adds the
    # channel dimension for 2-dim input.
    noisy = torch.rand(size=(1, 8000 * 4), dtype=torch.float32)
    denoise = model.forward(noisy)
    print(denoise.shape)  # [batch_size, channels, num_samples]
    return


if __name__ == "__main__":
    main()