|
""" |
|
"Music Mixing Style Transfer: A Contrastive Learning Approach to Disentangle Audio Effects" |
|
|
|
Implementation of neural networks used in the task 'Music Mixing Style Transfer' |
|
- 'FXencoder' |
|
- TCN based 'MixFXcloner' |
|
|
|
We modify the TCN implementation from: https://github.com/csteinmetz1/micro-tcn |
|
which was introduced in the work "Efficient neural networks for real-time modeling of analog dynamic range compression" by Christian J. Steinmetz, and Joshua D. Reiss |
|
""" |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.nn.init as init |
|
|
|
import os |
|
import sys |
|
currentdir = os.path.dirname(os.path.realpath(__file__)) |
|
sys.path.append(os.path.dirname(currentdir)) |
|
|
|
from networks.network_utils import * |
|
|
|
|
|
|
|
|
|
class FXencoder(nn.Module):
    """ Audio-effects encoder (FXencoder).

    A stack of 1-D convolutional blocks followed by global average pooling
    over the time axis, yielding a single effects-related embedding vector
    per input clip.

    Args:
        config (dict): architecture description with keys
            "channels", "kernels", "strides", "dilation" (per-layer lists),
            "conv_block" ('res' or 'conv'), "norm", and "activation".
    """
    def __init__(self, config):
        super(FXencoder, self).__init__()

        # Prepend the stereo input channel count WITHOUT mutating the caller's
        # config. The original `config["channels"].insert(0, 2)` modified the
        # shared dict in place, so constructing a second FXencoder from the
        # same config produced a corrupted channel list.
        channels = [2] + list(config["channels"])

        encoder = []
        for i in range(len(config["kernels"])):
            if config["conv_block"] == 'res':
                encoder.append(Res_ConvBlock(dimension=1,
                                             in_channels=channels[i],
                                             out_channels=channels[i+1],
                                             kernel_size=config["kernels"][i],
                                             stride=config["strides"][i],
                                             padding="SAME",
                                             dilation=config["dilation"][i],
                                             norm=config["norm"],
                                             activation=config["activation"],
                                             last_activation=config["activation"]))
            elif config["conv_block"] == 'conv':
                encoder.append(ConvBlock(dimension=1,
                                         layer_num=1,
                                         in_channels=channels[i],
                                         out_channels=channels[i+1],
                                         kernel_size=config["kernels"][i],
                                         stride=config["strides"][i],
                                         padding="VALID",
                                         dilation=config["dilation"][i],
                                         norm=config["norm"],
                                         activation=config["activation"],
                                         last_activation=config["activation"],
                                         mode='conv'))
        self.encoder = nn.Sequential(*encoder)

        # collapse the time dimension to a single value per channel
        self.glob_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, input):
        """ input: (batch, 2, time) waveform -> (batch, C_last) embedding """
        enc_output = self.encoder(input)
        glob_pooled = self.glob_pool(enc_output).squeeze(-1)
        return glob_pooled
|
|
|
|
|
|
|
|
|
import pytorch_lightning as pl |
|
class TCNModel(pl.LightningModule):
    """ Temporal convolutional network with conditioning module.

    Args:
        nparams (int): Number of conditioning parameters.
        ninputs (int): Number of input channels (mono = 1, stereo 2). Default: 1
        noutputs (int): Number of output channels (mono = 1, stereo 2). Default: 1
        nblocks (int): Number of total TCN blocks. Default: 10
        kernel_size (int): Width of the convolutional kernels. Default: 3
        dilation_growth (int): Compute the dilation factor at each block as dilation_growth ** (n % stack_size). Default: 1
        channel_growth (int): Compute the output channels at each block as in_ch * channel_growth. Default: 2
        channel_width (int): When channel_growth = 1 all blocks use convolutions with this many channels. Default: 64
        stack_size (int): Number of blocks that constitute a single stack of blocks. Default: 10
        cond_dim (int): Dimensionality of the conditioning (FiLM) embedding. Default: 2048
        grouped (bool): Use grouped convolutions to reduce the total number of parameters. Default: False
        causal (bool): Causal TCN configuration does not consider future input values. Default: False
        skip_connections (bool): Skip connections from each block to the output. Default: False
            NOTE(review): not implemented — `forward` always adds a constant zero.
        num_examples (int): Number of evaluation audio examples to log after each epoch. Default: 4
    """
    def __init__(self,
                 nparams,
                 ninputs=1,
                 noutputs=1,
                 nblocks=10,
                 kernel_size=3,
                 dilation_growth=1,
                 channel_growth=1,
                 channel_width=32,
                 stack_size=10,
                 cond_dim=2048,
                 grouped=False,
                 causal=False,
                 skip_connections=False,
                 num_examples=4,
                 save_dir=None,
                 **kwargs):
        super(TCNModel, self).__init__()
        # stores every constructor argument on self.hparams
        self.save_hyperparameters()

        self.blocks = torch.nn.ModuleList()
        for n in range(nblocks):
            # the first block maps the raw audio channels; each later block
            # chains on the previous block's output width
            in_ch = out_ch if n > 0 else ninputs

            if self.hparams.channel_growth > 1:
                out_ch = in_ch * self.hparams.channel_growth
            else:
                out_ch = self.hparams.channel_width

            # dilation pattern repeats every `stack_size` blocks
            dilation = self.hparams.dilation_growth ** (n % self.hparams.stack_size)
            self.blocks.append(TCNBlock(in_ch,
                                        out_ch,
                                        kernel_size=self.hparams.kernel_size,
                                        dilation=dilation,
                                        padding="same" if self.hparams.causal else "valid",
                                        causal=self.hparams.causal,
                                        cond_dim=cond_dim,
                                        grouped=self.hparams.grouped,
                                        conditional=True if self.hparams.nparams > 0 else False))

        # 1x1 conv projects the last block's channels to the output channels
        self.output = torch.nn.Conv1d(out_ch, noutputs, kernel_size=1)

    def forward(self, x, cond):
        """ Run all TCN blocks conditioned on `cond`, then the output conv.

        Args:
            x (Tensor): input audio of shape (batch, ninputs, time).
            cond (Tensor or list): a single conditioning embedding shared by
                all blocks, or a list holding one embedding per block.

        Returns:
            Tensor: output audio, clamped to the range [-1, 1].
        """
        for idx, block in enumerate(self.blocks):
            if isinstance(cond, list):
                # per-block conditioning
                x = block(x, cond[idx])
            else:
                # shared conditioning
                x = block(x, cond)

        # skip connections are not implemented; the zero term is kept so the
        # code mirrors the documented (x + skips) structure
        skips = 0

        out = torch.clamp(self.output(x + skips), min=-1, max=1)

        return out

    def compute_receptive_field(self):
        """ Compute the receptive field in samples."""
        rf = self.hparams.kernel_size
        for n in range(1, self.hparams.nblocks):
            dilation = self.hparams.dilation_growth ** (n % self.hparams.stack_size)
            rf = rf + ((self.hparams.kernel_size-1) * dilation)
        return rf

    @staticmethod
    def add_model_specific_args(parent_parser):
        # function-scope import: `ArgumentParser` was referenced without being
        # imported anywhere in this module, raising NameError at call time
        from argparse import ArgumentParser
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        parser.add_argument('--ninputs', type=int, default=1)
        parser.add_argument('--noutputs', type=int, default=1)
        parser.add_argument('--nblocks', type=int, default=4)
        parser.add_argument('--kernel_size', type=int, default=5)
        parser.add_argument('--dilation_growth', type=int, default=10)
        parser.add_argument('--channel_growth', type=int, default=1)
        parser.add_argument('--channel_width', type=int, default=32)
        parser.add_argument('--stack_size', type=int, default=10)
        parser.add_argument('--grouped', default=False, action='store_true')
        parser.add_argument('--causal', default=False, action="store_true")
        parser.add_argument('--skip_connections', default=False, action="store_true")

        return parser
|
|
|
|
|
class TCNBlock(torch.nn.Module):
    """ Single TCN block: dilated conv -> BatchNorm -> LeakyReLU -> FiLM,
    plus a grouped 1x1 residual connection from the block input.

    Args:
        in_ch (int): number of input channels.
        out_ch (int): number of output channels (must be a multiple of
            in_ch for the grouped residual 1x1 conv).
        kernel_size (int): width of the dilated convolution. Default: 3
        dilation (int): dilation factor of the main convolution. Default: 1
        cond_dim (int): dimensionality of the FiLM conditioning vector. Default: 2048
        grouped (bool): use a grouped main conv followed by a 1x1 channel mixer. Default: False
        causal (bool): left-only padding, implemented by symmetric padding
            plus trimming the right side after the conv. Default: False
        conditional (bool): apply FiLM conditioning in forward. Default: False
    """
    def __init__(self,
                 in_ch,
                 out_ch,
                 kernel_size=3,
                 dilation=1,
                 cond_dim=2048,
                 grouped=False,
                 causal=False,
                 conditional=False,
                 **kwargs):
        super(TCNBlock, self).__init__()

        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.grouped = grouped
        self.causal = causal
        self.conditional = conditional

        # grouped conv only possible when out_ch is a multiple of in_ch
        groups = out_ch if grouped and (in_ch % out_ch == 0) else 1

        # causal: pad the full receptive field (trimmed from the right in
        # forward); non-causal: symmetric "same" padding
        self.pad_length = ((kernel_size-1)*dilation) if self.causal else ((kernel_size-1)*dilation)//2
        self.conv1 = torch.nn.Conv1d(in_ch,
                                     out_ch,
                                     kernel_size=kernel_size,
                                     padding=self.pad_length,
                                     dilation=dilation,
                                     groups=groups,
                                     bias=False)
        if grouped:
            # 1x1 pointwise conv mixes channels after the grouped conv
            self.conv1b = torch.nn.Conv1d(out_ch, out_ch, kernel_size=1)

        if conditional:
            self.film = FiLM(cond_dim, out_ch)
        self.bn = torch.nn.BatchNorm1d(out_ch)

        self.relu = torch.nn.LeakyReLU()
        self.res = torch.nn.Conv1d(in_ch,
                                   out_ch,
                                   kernel_size=1,
                                   groups=in_ch,
                                   bias=False)

    def forward(self, x, p):
        """ x: (batch, in_ch, time); p: FiLM conditioning (or None when the
        block is unconditional). Returns (batch, out_ch, time). """
        x_in = x

        x = self.conv1(x)
        if self.grouped:
            # BUGFIX: conv1b was created but never applied, leaving its
            # parameters dead whenever grouped=True
            x = self.conv1b(x)
        x = self.relu(self.bn(x))
        # BUGFIX: self.film only exists when conditional=True; the original
        # unconditional call raised AttributeError for non-conditional blocks
        if self.conditional and p is not None:
            x = self.film(x, p)

        x_res = self.res(x_in)

        if self.causal and self.pad_length > 0:
            # drop the right-side padding so no future samples leak in
            # (pad_length > 0 guard avoids emptying the tensor via [:-0])
            x = x[..., :-self.pad_length]
        x += x_res

        return x
|
|
|
|
|
|
|
if __name__ == '__main__':
    ''' check model I/O shape '''
    import yaml
    with open('networks/configs.yaml', 'r') as f:
        configs = yaml.full_load(f)

    batch_size = 32
    sr = 44100
    input_length = sr * 5

    # named `input_audio` to avoid shadowing the builtin `input`
    input_audio = torch.rand(batch_size, 2, input_length)
    print(f"Input Shape : {input_audio.shape}\n")

    print('\n========== Audio Effects Encoder (FXencoder) ==========')
    config = configs["FXencoder"]["default"]
    print(f"configuration: {config}")

    network = FXencoder(config)
    pytorch_total_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
    print(f"Number of trainable parameters : {pytorch_total_params}")

    # reference-style embedding, reused below to condition the MixFXcloner
    output_c = network(input_audio)
    print(f"Output Shape : {output_c.shape}")

    print('\n========== TCN based MixFXcloner ==========')
    config = configs["TCN"]["default"]
    print(f"configuration: {config}")

    network = TCNModel(nparams=config["condition_dimension"], ninputs=2, noutputs=2,
                       nblocks=config["nblocks"],
                       dilation_growth=config["dilation_growth"],
                       kernel_size=config["kernel_size"],
                       channel_width=config["channel_width"],
                       stack_size=config["stack_size"],
                       cond_dim=config["condition_dimension"],
                       causal=config["causal"])
    pytorch_total_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
    print(f"Number of trainable parameters : {pytorch_total_params}\tReceptive field duration : {network.compute_receptive_field() / sr:.3f}")

    output = network(input_audio, output_c)
    print(f"Output Shape : {output.shape}")