import os

# fetch the pre-trained effect models (one file per effect type loaded below;
# the original listing fetched delay_full.pt twice and never fetched
# synth2synth_full.pt, which is loaded later)
os.system("wget https://csteinmetz1.github.io/steerable-nafx/models/compressor_full.pt")
os.system("wget https://csteinmetz1.github.io/steerable-nafx/models/reverb_full.pt")
os.system("wget https://csteinmetz1.github.io/steerable-nafx/models/amp_full.pt")
os.system("wget https://csteinmetz1.github.io/steerable-nafx/models/delay_full.pt")
os.system("wget https://csteinmetz1.github.io/steerable-nafx/models/synth2synth_full.pt")
import torch
import torchaudio
import numpy as np
import scipy.signal
def measure_rt60(h, fs=1, decay_db=30, rt60_tgt=None):
    """
    Estimate the RT60 of an impulse response.

    Args:
        h (ndarray): The discrete-time impulse response as a 1d array.
        fs (float, optional): Sample rate of the impulse response. (Default: 1)
        decay_db (float, optional): The decay in decibels over which the time is actually measured. (Default: 30)
        rt60_tgt (float, optional): An optional target RT60, kept for API compatibility; unused here. (Default: None)

    Returns:
        est_rt60 (float): Estimated RT60 in seconds (in samples when fs=1).
    """
    h = np.array(h)
    fs = float(fs)

    # squared impulse response (power)
    power = h ** 2
    energy = np.cumsum(power[::-1])[::-1]  # Schroeder backward integration

    try:
        # remove the possibly all-zero tail
        i_nz = np.max(np.where(energy > 0)[0])
        energy = energy[:i_nz]
        energy_db = 10 * np.log10(energy)
        energy_db -= energy_db[0]

        # -5 dB headroom
        i_5db = np.min(np.where(-5 - energy_db > 0)[0])
        t_5db = i_5db / fs

        # end of the measured decay
        i_decay = np.min(np.where(-5 - decay_db - energy_db > 0)[0])
        t_decay = i_decay / fs

        # extrapolate the measured decay to a full 60 dB decay
        decay_time = t_decay - t_5db
        est_rt60 = (60 / decay_db) * decay_time
    except ValueError:
        # the response never decayed far enough to measure
        est_rt60 = np.array(0.0)

    return est_rt60
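# A quick sanity check for measure_rt60 (illustrative helper, not called by the
# app): an exponential decay with time constant tau has a true RT60 of
# 60 / (20 * log10(e) / tau) ≈ 6.91 * tau, so tau = 0.1 s should measure ≈ 0.69 s.
def _demo_measure_rt60():
    fs = 48000
    t = np.arange(2 * fs) / fs
    h = np.exp(-t / 0.1)
    return measure_rt60(h, fs=fs)  # expect roughly 0.69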
def causal_crop(x, length: int):
    """Crop x along its last axis to the trailing `length` samples."""
    if x.shape[-1] != length:
        stop = x.shape[-1] - 1
        start = stop - length
        x = x[..., start:stop]
    return x
class FiLM(torch.nn.Module):
    """Feature-wise linear modulation: scale and shift conv channels from a conditioning vector."""
    def __init__(
        self,
        cond_dim,      # dim of conditioning input
        num_features,  # dim of the conv channel
        batch_norm=True,
    ):
        super().__init__()
        self.num_features = num_features
        self.batch_norm = batch_norm
        if batch_norm:
            self.bn = torch.nn.BatchNorm1d(num_features, affine=False)
        self.adaptor = torch.nn.Linear(cond_dim, num_features * 2)

    def forward(self, x, cond):
        cond = self.adaptor(cond)
        g, b = torch.chunk(cond, 2, dim=-1)
        g = g.permute(0, 2, 1)
        b = b.permute(0, 2, 1)
        if self.batch_norm:
            x = self.bn(x)   # apply BatchNorm without affine
        x = (x * g) + b      # then apply the conditional affine transform
        return x
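# Shape check for FiLM (illustrative helper): an 8-channel activation modulated
# by a 2-dim conditioning vector keeps its (batch, channels, time) shape.
def _demo_film():
    film = FiLM(cond_dim=2, num_features=8, batch_norm=False)
    x = torch.randn(1, 8, 100)  # (batch, channels, time)
    c = torch.randn(1, 1, 2)    # (batch, 1, cond_dim)
    return film(x, c).shape     # torch.Size([1, 8, 100])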
class TCNBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation, cond_dim=0, activation=True):
        super().__init__()
        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            dilation=dilation,
            padding=0,  # no padding: the block is causal and shortens the signal
            bias=True)
        if cond_dim > 0:
            self.film = FiLM(cond_dim, out_channels, batch_norm=False)
        if activation:
            self.act = torch.nn.PReLU()
        self.res = torch.nn.Conv1d(in_channels, out_channels, 1, bias=False)

    def forward(self, x, c=None):
        x_in = x
        x = self.conv(x)
        if hasattr(self, "film"):
            x = self.film(x, c)
        if hasattr(self, "act"):
            x = self.act(x)
        # 1x1 residual connection, cropped to match the conv output length
        x_res = causal_crop(self.res(x_in), x.shape[-1])
        x = x + x_res
        return x
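# Shape check for TCNBlock (illustrative helper): with kernel_size=13 and
# dilation=1, the un-padded causal conv shortens a 100-sample input to
# 100 - 12 = 88 samples, and the 1x1 residual branch is cropped to match.
def _demo_tcn_block():
    block = TCNBlock(in_channels=1, out_channels=4, kernel_size=13, dilation=1)
    x = torch.randn(1, 1, 100)
    return block(x).shape  # torch.Size([1, 4, 88])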
class TCN(torch.nn.Module):
    def __init__(self, n_inputs=1, n_outputs=1, n_blocks=10, kernel_size=13, n_channels=64, dilation_growth=4, cond_dim=0):
        super().__init__()
        self.kernel_size = kernel_size
        self.n_channels = n_channels
        self.dilation_growth = dilation_growth
        self.n_blocks = n_blocks
        self.stack_size = n_blocks

        self.blocks = torch.nn.ModuleList()
        for n in range(n_blocks):
            in_ch = n_inputs if n == 0 else n_channels
            out_ch = n_outputs if (n + 1) == n_blocks else n_channels
            dilation = dilation_growth ** n
            self.blocks.append(TCNBlock(in_ch, out_ch, kernel_size, dilation, cond_dim=cond_dim, activation=True))

    def forward(self, x, c=None):
        for block in self.blocks:
            x = block(x, c)
        return x
    def compute_receptive_field(self):
        """Compute the receptive field in samples."""
        rf = self.kernel_size
        for n in range(1, self.n_blocks):
            dilation = self.dilation_growth ** (n % self.stack_size)
            rf = rf + ((self.kernel_size - 1) * dilation)
        return rf
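# Worked receptive-field example (illustrative helper): with kernel_size=13,
# dilation_growth=2, and n_blocks=5 the dilations are 1, 2, 4, 8, 16, so
# rf = 13 + 12*(2 + 4 + 8 + 16) = 373 samples, about 8.5 ms at 44.1 kHz.
def _demo_receptive_field():
    tcn = TCN(n_blocks=5, kernel_size=13, n_channels=8, dilation_growth=2)
    return tcn.compute_receptive_field()  # 373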
# load the pre-trained models
model_comp = torch.load("compressor_full.pt", map_location="cpu").eval()
model_verb = torch.load("reverb_full.pt", map_location="cpu").eval()
model_amp = torch.load("amp_full.pt", map_location="cpu").eval()
model_delay = torch.load("delay_full.pt", map_location="cpu").eval()
model_synth = torch.load("synth2synth_full.pt", map_location="cpu").eval()
def inference(aud, effect_type):
    x_p, sample_rate = torchaudio.load(aud.file)

    # fixed processing settings (formerly Colab form fields)
    gain_dB = -24    # input gain in dB, range [-24, 24]
    c0 = -1.4        # first conditioning value, range [-10, 10]
    c1 = 3           # second conditioning value, range [-10, 10]
    mix = 70         # wet/dry mix in percent, range [0, 100]
    width = 50       # stereo width in percent, range [0, 100]
    max_length = 30  # maximum input length in seconds, range [5, 120]
    stereo = True    # duplicate a mono input to two channels
    tail = True      # let the effect tail ring out past the input
    # select the pre-trained model
    if effect_type == "Compressor":
        pt_model = model_comp
    elif effect_type == "Reverb":
        pt_model = model_verb
    elif effect_type == "Amp":
        pt_model = model_amp
    elif effect_type == "Analog Delay":
        pt_model = model_delay
    elif effect_type == "Synth2Synth":
        pt_model = model_synth
    else:
        raise ValueError(f"Unknown effect_type: {effect_type}")
    # measure the receptive field
    pt_model_rf = pt_model.compute_receptive_field()

    # crop the input signal if needed
    max_samples = int(sample_rate * max_length)
    x_p_crop = x_p[:, :max_samples]
    chs = x_p_crop.shape[0]

    # if the input is mono and stereo output was requested, duplicate the channel
    if chs == 1 and stereo:
        x_p_crop = x_p_crop.repeat(2, 1)
        chs = 2

    # front-pad by the receptive field so the model sees enough context;
    # back-pad to capture the effect tail if requested
    front_pad = pt_model_rf - 1
    back_pad = 0 if not tail else front_pad
    x_p_pad = torch.nn.functional.pad(x_p_crop, (front_pad, back_pad))

    # design a highpass filter to remove DC and subsonic content
    sos = scipy.signal.butter(
        8,
        20.0,
        fs=sample_rate,
        output="sos",
        btype="highpass",
    )

    # compute the linear gain
    gain_ln = 10 ** (gain_dB / 20.0)
    # process the audio with the pre-trained model
    with torch.no_grad():
        y_hat = torch.zeros(x_p_crop.shape[0], x_p_crop.shape[1] + back_pad)
        for n in range(chs):
            # offset the conditioning per channel to create stereo width
            if n == 0:
                factor = width * 5e-3
            elif n == 1:
                factor = -(width * 5e-3)
            c = torch.tensor([float(c0 + factor), float(c1 + factor)]).view(1, 1, -1)
            y_hat_ch = pt_model(gain_ln * x_p_pad[n, :].view(1, 1, -1), c)
            y_hat_ch = scipy.signal.sosfilt(sos, y_hat_ch.view(-1).numpy())
            y_hat[n, :] = torch.from_numpy(y_hat_ch.astype(np.float32))

    # pad the dry signal to match the wet signal's length
    x_dry = torch.nn.functional.pad(x_p_crop, (0, back_pad))

    # peak-normalize both signals before mixing
    y_hat /= y_hat.abs().max()
    x_dry /= x_dry.abs().max()

    # wet/dry mix
    mix = mix / 100.0
    y_hat = (mix * y_hat) + ((1 - mix) * x_dry)

    # remove the start-up transient and normalize
    y_hat = y_hat[..., 8192:]
    y_hat /= y_hat.abs().max()

    torchaudio.save("output.mp3", y_hat.view(chs, -1), sample_rate, compression=320.0)
    return "output.mp3"