Spaces:
Sleeping
Sleeping
# Copyright 2024 The YourMT3 Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Please see the details in the LICENSE file. | |
from typing import Literal | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from einops import rearrange | |
def init_layer(layer: nn.Module) -> None: | |
"""Initialize a Linear or Convolutional layer.""" | |
nn.init.xavier_uniform_(layer.weight) | |
if hasattr(layer, "bias") and layer.bias is not None: | |
layer.bias.data.zero_() | |
def init_bn(bn: nn.Module) -> None: | |
"""Initialize a Batchnorm layer.""" | |
bn.bias.data.zero_() | |
bn.weight.data.fill_(1.0) | |
bn.running_mean.data.zero_() | |
bn.running_var.data.fill_(1.0) | |
def act(x: torch.Tensor, activation: str) -> torch.Tensor: | |
"""Activation function.""" | |
funcs = {"relu": F.relu_, "leaky_relu": lambda x: F.leaky_relu_(x, 0.01), "swish": lambda x: x * torch.sigmoid(x)} | |
return funcs.get(activation, lambda x: Exception("Incorrect activation!"))(x) | |
class Res2DAVPBlock(nn.Module): | |
def __init__(self, in_channels, out_channels, kernel_size, avp_kernel_size, activation): | |
"""Convolutional residual block modified fromr bytedance/music_source_separation.""" | |
super().__init__() | |
padding = kernel_size[0] // 2, kernel_size[1] // 2 | |
self.activation = activation | |
self.bn1, self.bn2 = nn.BatchNorm2d(out_channels), nn.BatchNorm2d(out_channels) | |
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False) | |
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding, bias=False) | |
self.is_shortcut = in_channels != out_channels | |
if self.is_shortcut: | |
self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1)) | |
self.avp = nn.AvgPool2d(avp_kernel_size) | |
self.init_weights() | |
def init_weights(self): | |
for m in [self.conv1, self.conv2] + ([self.shortcut] if self.is_shortcut else []): | |
init_layer(m) | |
for m in [self.bn1, self.bn2]: | |
init_bn(m) | |
def forward(self, x): | |
origin = x | |
x = act(self.bn1(self.conv1(x)), self.activation) | |
x = self.bn2(self.conv2(x)) | |
x += self.shortcut(origin) if self.is_shortcut else origin | |
x = act(x, self.activation) | |
return self.avp(x) | |
class PreEncoderBlockRes3B(nn.Module): | |
def __init__(self, in_channels, out_channels, kernel_size=(3, 3), avp_kernerl_size=(1, 2), activation='relu'): | |
"""Pre-Encoder with 3 Res2DAVPBlocks.""" | |
super().__init__() | |
self.blocks = nn.ModuleList([ | |
Res2DAVPBlock(in_channels if i == 0 else out_channels, out_channels, kernel_size, avp_kernerl_size, | |
activation) for i in range(3) | |
]) | |
def forward(self, x): # (B, T, F) | |
x = rearrange(x, 'b t f -> b 1 t f') | |
for block in self.blocks: | |
x = block(x) | |
return rearrange(x, 'b c t f -> b t f c') | |
def test_res3b(): | |
# mel-spec input | |
x = torch.randn(2, 256, 512) # (B, T, F) | |
pre = PreEncoderBlockRes3B(in_channels=1, out_channels=128) | |
x = pre(x) # (2, 256, 64, 128): B T,F,C | |
x = torch.randn(2, 110, 1024) # (B, T, F) | |
pre = PreEncoderBlockRes3B(in_channels=1, out_channels=128) | |
x = pre(x) # (2, 110, 128, 128): B,T,F,C | |
# ==================================================================================================================== | |
# PreEncoderBlockHFTT: hFT-Transformer-like Pre-encoder | |
# ==================================================================================================================== | |
class PreEncoderBlockHFTT(nn.Module): | |
def __init__(self, margin_pre=15, margin_post=16) -> None: | |
"""Pre-Encoder with hFT-Transformer-like convolutions.""" | |
super().__init__() | |
self.margin_pre, self.margin_post = margin_pre, margin_post | |
self.conv = nn.Conv2d(1, 4, kernel_size=(1, 5), padding='same', padding_mode='zeros') | |
self.emb_freq = nn.Linear(128, 128) | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
# x: (B, T, F) | |
x = rearrange(x, 'b t f -> b 1 f t') # (B, 1, F, T) or (2, 1, 128, 110) | |
x = F.pad(x, (self.margin_pre, self.margin_post), value=1e-7) # (B, 1, F, T+margin) or (2,1,128,141) | |
x = self.conv(x) # (B, C, F, T+margin) or (2, 4, 128, 141) | |
x = x.unfold(dimension=3, size=32, step=1) # (B, c1, T, F, c2) or (2, 4, 128, 110, 32) | |
x = rearrange(x, 'b c1 f t c2 -> b t f (c1 c2)') # (B, T, F, C) or (2, 110, 128, 128) | |
return self.emb_freq(x) # (B, T, F, C) or (2, 110, 128, 128) | |
def test_hftt(): | |
# from model.spectrogram import get_spectrogram_layer_from_audio_cfg | |
# from config.config import audio_cfg as default_audio_cfg | |
# audio_cfg = default_audio_cfg | |
# audio_cfg['codec'] = 'melspec' | |
# audio_cfg['hop_length'] = 300 | |
# audio_cfg['n_mels'] = 128 | |
# x = torch.randn(2, 1, 32767) | |
# mspec, _ = get_spectrogram_layer_from_audio_cfg(audio_cfg) | |
# x = mspec(x) | |
x = torch.randn(2, 110, 128) # (B, T, F) | |
pre_enc_hftt = PreEncoderBlockHFTT() | |
y = pre_enc_hftt(x) # (2, 110, 128, 128): B, T, F, C | |
# ==================================================================================================================== | |
# PreEncoderBlockRes3BHFTT: hFT-Transformer-like Pre-encoder with Res2DAVPBlock and spec input | |
# ==================================================================================================================== | |
class PreEncoderBlockRes3BHFTT(nn.Module): | |
def __init__(self, margin_pre: int = 15, margin_post: int = 16) -> None: | |
"""Pre-Encoder with hFT-Transformer-like convolutions. | |
Args: | |
margin_pre (int): padding before the input | |
margin_post (int): padding after the input | |
stack_dim (Literal['c', 'f']): stack dimension. channel or frequency | |
""" | |
super().__init__() | |
self.margin_pre, self.margin_post = margin_pre, margin_post | |
self.res3b = PreEncoderBlockRes3B(in_channels=1, out_channels=4) | |
self.emb_freq = nn.Linear(128, 128) | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
# x: (B, T, F) or (2, 110, 1024), input spectrogram | |
x = rearrange(x, 'b t f -> b f t') # (2, 1024, 110): B,F,T | |
x = F.pad(x, (self.margin_pre, self.margin_post), value=1e-7) # (2, 1024, 141): B,F,T+margin | |
x = rearrange(x, 'b f t -> b t f') # (2, 141, 1024): B,T+margin,F | |
x = self.res3b(x) # (2, 141, 128, 4): B,T+margin,F,C | |
x = x.unfold(dimension=1, size=32, step=1) # (B, T, F, C1, C2) or (2, 110, 128, 4, 32) | |
x = rearrange(x, 'b t f c1 c2 -> b t f (c1 c2)') # (B, T, F, C) or (2, 110, 128, 128) | |
return self.emb_freq(x) # (B, T, F, C) or (2, 110, 128, 128) | |
def test_res3b_hftt(): | |
# from model.spectrogram import get_spectrogram_layer_from_audio_cfg | |
# from config.config import audio_cfg as default_audio_cfg | |
# audio_cfg = default_audio_cfg | |
# audio_cfg['codec'] = 'spec' | |
# audio_cfg['hop_length'] = 300 | |
# x = torch.randn(2, 1, 32767) | |
# spec, _ = get_spectrogram_layer_from_audio_cfg(audio_cfg) | |
# x = spec(x) # (2, 110, 1024): B,T,F | |
x = torch.randn(2, 110, 1024) # (B, T, F) | |
pre_enc_res3b_hftt = PreEncoderBlockRes3BHFTT() | |
y = pre_enc_res3b_hftt(x) # (2, 110, 128, 128): B, T, F, C | |
# # ==================================================================================================================== | |
# # PreEncoderBlockConv1D: Pre-encoder without activation, with Melspec input | |
# # ==================================================================================================================== | |
# class PreEncoderBlockConv1D(nn.Module): | |
# def __init__(self, | |
# in_channels, | |
# out_channels, | |
# kernel_size=3) -> None: | |
# """Pre-Encoder with 1D convolution.""" | |
# super().__init__() | |
# self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=1) | |
# self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=1) | |
# def forward(self, x: torch.Tensor) -> torch.Tensor: | |
# # x: (B, T, F) or (2, 128, 256), input melspec | |
# x = rearrange(x, 'b t f -> b f t') # (2, 256, 128): B,F,T | |
# x = self.conv1(x) # (2, 128, 128): B,F,T | |
# return rearrange(x, 'b f t -> b t f') # (2, 110, 128): B,T,F | |
# def test_conv1d(): | |
# # from model.spectrogram import get_spectrogram_layer_from_audio_cfg | |
# # from config.config import audio_cfg as default_audio_cfg | |
# # audio_cfg = default_audio_cfg | |
# # audio_cfg['codec'] = 'melspec' | |
# # audio_cfg['hop_length'] = 256 | |
# # audio_cfg['n_mels'] = 512 | |
# # x = torch.randn(2, 1, 32767) | |
# # mspec, _ = get_spectrogram_layer_from_audio_cfg(audio_cfg) | |
# # x = mspec(x) | |
# x = torch.randn(2, 128, 128) # (B, T, F) | |
# pre_enc_conv1d = PreEncoderBlockConv1D(in_channels=1, out_channels=128) | |
# y = pre_enc_conv1d(x) # (2, 110, 128, 128): B, T, F, C | |