#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://arxiv.org/abs/2006.12847
https://github.com/facebookresearch/denoiser
"""
import math
import os
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import functional as F
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.demucs.configuration_demucs import DemucsConfig
from toolbox.torchaudio.models.demucs.resample import upsample2, downsample2
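# Name -> activation layer map; presumably used by config-driven code
# elsewhere, since nothing in this file references it.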
activation_layer_dict = {
"glu": nn.GLU,
"relu": nn.ReLU,
"identity": nn.Identity,
"sigmoid": nn.Sigmoid,
}
class BLSTM(nn.Module):
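    """LSTM bottleneck. Bidirectional by default; when bidirectional, a linear
    layer projects the doubled hidden size back down to ``hidden_size``."""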
def __init__(self,
hidden_size: int,
num_layers: int = 2,
bidirectional: bool = True,
):
super().__init__()
self.lstm = nn.LSTM(bidirectional=bidirectional,
num_layers=num_layers,
hidden_size=hidden_size,
input_size=hidden_size
)
self.linear = None
if bidirectional:
self.linear = nn.Linear(2 * hidden_size, hidden_size)
def forward(self,
x: torch.Tensor,
                hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
):
        x, hx = self.lstm(x, hx)
if self.linear:
x = self.linear(x)
return x, hx
def rescale_conv(conv, reference):
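    """Rescale ``conv``'s weights (and bias) so the weight std moves toward
    ``reference``: the new std is the geometric mean of the old std and
    ``reference``."""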
std = conv.weight.std().detach()
scale = (std / reference)**0.5
conv.weight.data /= scale
if conv.bias is not None:
conv.bias.data /= scale
def rescale_module(module, reference):
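    """Apply ``rescale_conv`` to every Conv1d / ConvTranspose1d in ``module``."""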
for sub in module.modules():
if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
rescale_conv(sub, reference)
class DemucsModel(nn.Module):
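    """
    Waveform-domain denoiser from "Real Time Speech Enhancement in the
    Waveform Domain" (https://arxiv.org/abs/2006.12847): a U-Net style
    encoder/decoder over raw audio with an LSTM bottleneck (bidirectional
    only when the model is non-causal).
    """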
def __init__(self,
in_channels: int = 1,
out_channels: int = 1,
hidden_channels: int = 48,
depth: int = 5,
kernel_size: int = 8,
stride: int = 4,
causal: bool = True,
resample: int = 4,
growth: int = 2,
max_hidden: int = 10_000,
do_normalize: bool = True,
rescale: float = 0.1,
floor: float = 1e-3,
):
super(DemucsModel, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.depth = depth
self.kernel_size = kernel_size
self.stride = stride
self.causal = causal
self.resample = resample
self.growth = growth
self.max_hidden = max_hidden
self.do_normalize = do_normalize
self.rescale = rescale
self.floor = floor
if resample not in [1, 2, 4]:
raise ValueError("Resample should be 1, 2 or 4.")
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()
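        # Build `depth` mirrored encoder/decoder levels. The channel width
        # grows by `growth` per level (capped at `max_hidden`); each GLU
        # halves the channel dim again, so the doubled 1x1-conv outputs land
        # back on the target width.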
for index in range(depth):
encode = []
encode += [
nn.Conv1d(in_channels, hidden_channels, kernel_size, stride),
nn.ReLU(),
nn.Conv1d(hidden_channels, hidden_channels * 2, 1),
nn.GLU(1),
]
self.encoder.append(nn.Sequential(*encode))
decode = []
decode += [
nn.Conv1d(hidden_channels, 2 * hidden_channels, 1),
nn.GLU(1),
nn.ConvTranspose1d(hidden_channels, out_channels, kernel_size, stride),
]
if index > 0:
decode.append(nn.ReLU())
self.decoder.insert(0, nn.Sequential(*decode))
out_channels = hidden_channels
in_channels = hidden_channels
hidden_channels = min(int(growth * hidden_channels), max_hidden)
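        # After the loop, `in_channels` holds the deepest encoder width; the
        # bottleneck LSTM is unidirectional when the model must be causal.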
self.lstm = BLSTM(in_channels, bidirectional=not causal)
if rescale:
rescale_module(self, reference=rescale)
@staticmethod
def valid_length(length: int, depth: int, kernel_size: int, stride: int, resample: int):
"""
Return the nearest valid length to use with the model so that
there is no time steps left over in a convolutions, e.g. for all
layers, size of the input - kernel_size % stride = 0.
If the mixture has a valid length, the estimated sources
will have exactly the same length.
"""
length = math.ceil(length * resample)
for idx in range(depth):
length = math.ceil((length - kernel_size) / stride) + 1
length = max(length, 1)
for idx in range(depth):
length = (length - 1) * stride + kernel_size
length = int(math.ceil(length / resample))
return int(length)
def forward(self, noisy: torch.Tensor):
"""
:param noisy: Tensor, shape: [batch_size, num_samples] or [batch_size, channels, num_samples]
:return:
"""
if noisy.dim() == 2:
noisy = noisy.unsqueeze(1)
# noisy shape: [batch_size, channels, num_samples]
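        # Normalize by the std of the mono mix so the network sees roughly
        # unit-scale input; the same std rescales the output on return.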
if self.do_normalize:
mono = noisy.mean(dim=1, keepdim=True)
std = mono.std(dim=-1, keepdim=True)
noisy = noisy / (self.floor + std)
else:
std = 1
_, _, length = noisy.shape
x = noisy
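        # Pad up to the nearest valid length so every conv / transposed-conv
        # pair lines up exactly; the padding is trimmed again before returning.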
length_ = self.valid_length(length, self.depth, self.kernel_size, self.stride, self.resample)
x = F.pad(x, (0, length_ - length))
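        # Optionally run the network at 2x or 4x the input rate (upsample2 is
        # applied once or twice); the matching downsample2 calls undo it below.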
if self.resample == 2:
x = upsample2(x)
elif self.resample == 4:
x = upsample2(x)
x = upsample2(x)
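        # Encoder: keep each level's output for the U-Net skip connections.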
skips = []
for encode in self.encoder:
x = encode(x)
skips.append(x)
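        # nn.LSTM (batch_first=False) expects [time, batch, channels].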
x = x.permute(2, 0, 1)
x, _ = self.lstm(x)
x = x.permute(1, 2, 0)
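        # Decoder: add the matching skip, trimmed to the current length
        # (the encoder side may be a few frames longer), then decode up a level.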
for decode in self.decoder:
skip = skips.pop(-1)
x = x + skip[..., :x.shape[-1]]
x = decode(x)
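        # Undo the internal resampling, then drop the padding added above.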
if self.resample == 2:
x = downsample2(x)
elif self.resample == 4:
x = downsample2(x)
x = downsample2(x)
x = x[..., :length]
return std * x
MODEL_FILE = "model.pt"
class DemucsPretrainedModel(DemucsModel):
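    """``DemucsModel`` constructed from a ``DemucsConfig``, with simple
    save/load helpers (``model.pt`` plus a YAML config file)."""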
def __init__(self,
config: DemucsConfig,
):
super(DemucsPretrainedModel, self).__init__(
# sample_rate=config.sample_rate,
in_channels=config.in_channels,
out_channels=config.out_channels,
hidden_channels=config.hidden_channels,
depth=config.depth,
kernel_size=config.kernel_size,
stride=config.stride,
causal=config.causal,
resample=config.resample,
growth=config.growth,
max_hidden=config.max_hidden,
do_normalize=config.do_normalize,
rescale=config.rescale,
floor=config.floor,
)
self.config = config
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
config = DemucsConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
model = cls(config)
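        # Accept either a directory containing MODEL_FILE or a direct path
        # to a checkpoint file.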
if os.path.isdir(pretrained_model_name_or_path):
ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
else:
ckpt_file = pretrained_model_name_or_path
with open(ckpt_file, "rb") as f:
state_dict = torch.load(f, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict, strict=True)
return model
def save_pretrained(self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
):
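        """Write the state dict as ``model.pt`` and the config as YAML into
        ``save_directory``."""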
model = self
if state_dict is None:
state_dict = model.state_dict()
os.makedirs(save_directory, exist_ok=True)
# save state dict
model_file = os.path.join(save_directory, MODEL_FILE)
torch.save(state_dict, model_file)
# save config
config_file = os.path.join(save_directory, CONFIG_FILE)
self.config.to_yaml_file(config_file)
return save_directory
def main():
config = DemucsConfig()
model = DemucsModel(
in_channels=config.in_channels,
out_channels=config.out_channels,
hidden_channels=config.hidden_channels,
depth=config.depth,
kernel_size=config.kernel_size,
stride=config.stride,
causal=config.causal,
resample=config.resample,
growth=config.growth,
max_hidden=config.max_hidden,
do_normalize=config.do_normalize,
rescale=config.rescale,
floor=config.floor,
)
print(model)
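    # Smoke test: 8000 * 4 = 32000 random samples stand in for a real
    # waveform (2 s at 16 kHz, or 4 s at 8 kHz).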
noisy = torch.rand(size=(1, 8000*4), dtype=torch.float32)
    denoise = model(noisy)
print(denoise.shape)
return
if __name__ == "__main__":
main()