Spaces:
Runtime error
Runtime error
Ahsen Khaliq
committed on
Commit
•
0fbd9ed
1
Parent(s):
c43590a
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
# Fetch every pre-trained checkpoint that the model-loading code further down
# opens. synth2synth_full.pt must be in this list: it is loaded alongside the
# other four, so a missing download makes the script fail at start-up.
for checkpoint in (
    "compressor_full.pt",
    "reverb_full.pt",
    "amp_full.pt",
    "delay_full.pt",
    "synth2synth_full.pt",
):
    os.system(f"wget https://csteinmetz1.github.io/steerable-nafx/models/{checkpoint}")
|
8 |
+
|
9 |
+
import sys
|
10 |
+
import math
|
11 |
+
import torch
|
12 |
+
import librosa.display
|
13 |
+
import IPython
|
14 |
+
import auraloss
|
15 |
+
import torchaudio
|
16 |
+
import numpy as np
|
17 |
+
import scipy.signal
|
18 |
+
from google.colab import files
|
19 |
+
from tqdm.notebook import tqdm
|
20 |
+
from time import sleep
|
21 |
+
import matplotlib
|
22 |
+
import pyloudnorm as pyln
|
23 |
+
import matplotlib.pyplot as plt
|
24 |
+
from IPython.display import Image
|
25 |
+
|
26 |
+
def measure_rt60(h, fs=1, decay_db=30, rt60_tgt=None):
    """
    Analyze the RT60 of an impulse response.

    The energy decay curve is obtained by backward (Schroeder) integration of
    the squared response. The time needed to fall by `decay_db` dB, counted
    from the point where the curve first drops below -5 dB, is then scaled
    linearly to a 60 dB decay.

    Args:
        h (ndarray): The discrete time impulse response as 1d array.
        fs (float, optional): Sample rate of the impulse response. With the
            default value the result is expressed in samples. (Default: 1)
        decay_db (float, optional): The decay in decibels for which we
            actually estimate the time. (Default: 30)
        rt60_tgt (float, optional): This parameter can be used to indicate a
            target RT60. It is accepted but not used by this implementation.
            (Default: None)

    Returns:
        est_rt60 (float): Estimated RT60. ``np.array(0.0)`` is returned when
            no estimate can be made (e.g. an all-zero response, or a decay
            curve that never reaches ``-5 - decay_db`` dB).
    """
    h = np.array(h)
    fs = float(fs)

    # The power of the impulse response
    power = h ** 2
    energy = np.cumsum(power[::-1])[::-1]  # Integration according to Schroeder

    try:
        # remove the possibly all zero tail
        i_nz = np.max(np.where(energy > 0)[0])
        energy = energy[:i_nz]
        energy_db = 10 * np.log10(energy)
        energy_db -= energy_db[0]

        # -5 dB headroom: first index at which the curve is below -5 dB
        i_5db = np.min(np.where(-5 - energy_db > 0)[0])
        t_5db = i_5db / fs

        # first index at which the curve has dropped a further `decay_db` dB
        i_decay = np.min(np.where(-5 - decay_db - energy_db > 0)[0])
        t_decay = i_decay / fs

        # extrapolate the measured decay time to a full 60 dB decay
        decay_time = t_decay - t_5db
        est_rt60 = (60 / decay_db) * decay_time
    except (ValueError, IndexError, ZeroDivisionError):
        # ValueError: np.max / np.min on an empty index set (level never reached);
        # IndexError: nothing left after trimming the tail;
        # ZeroDivisionError: decay_db == 0.
        est_rt60 = np.array(0.0)

    return est_rt60
|
68 |
+
|
69 |
+
def causal_crop(x, length: int):
    """Crop the last axis of `x` to `length` samples taken from its end.

    If the last axis already holds exactly `length` samples, `x` is returned
    as is. Otherwise the returned window is `length` samples long and stops
    one sample short of the end, i.e. the very last sample is not included.
    """
    n_samples = x.shape[-1]
    if n_samples == length:
        return x

    # The window ends at index n_samples - 1 (exclusive).
    end = n_samples - 1
    return x[..., end - length:end]
|
75 |
+
|
76 |
+
class FiLM(torch.nn.Module):
    """Conditional affine layer: scales and shifts features using a conditioning input.

    A linear layer maps the conditioning vector to `2 * num_features` values,
    which are split into a per-feature gain and a per-feature bias.
    """

    def __init__(self, cond_dim, num_features, batch_norm=True):
        """
        Args:
            cond_dim: Size of the conditioning input.
            num_features: Number of channels of the tensor being modulated.
            batch_norm: When true, a non-affine BatchNorm1d is applied before
                the conditional scale and shift.
        """
        super().__init__()
        self.num_features = num_features
        self.batch_norm = batch_norm
        if batch_norm:
            self.bn = torch.nn.BatchNorm1d(num_features, affine=False)
        # One linear projection yields both the gains and the biases.
        self.adaptor = torch.nn.Linear(cond_dim, num_features * 2)

    def forward(self, x, cond):
        """Return `x` (optionally batch-normalised) times the gain plus the bias.

        `cond` must be 3-D because the projected gain and bias have their last
        two axes swapped before being combined with `x`.
        """
        gain, bias = self.adaptor(cond).chunk(2, dim=-1)
        # Move the feature axis to position 1 so it lines up with x's channels.
        gain = gain.permute(0, 2, 1)
        bias = bias.permute(0, 2, 1)

        normed = self.bn(x) if self.batch_norm else x
        return (normed * gain) + bias
|
102 |
+
|
103 |
+
class TCNBlock(torch.nn.Module):
    """Dilated 1-D convolution block with a 1x1-convolution residual path.

    Conditioning (FiLM) and the PReLU activation are optional and only exist
    as attributes when they were requested at construction time.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, cond_dim=0, activation=True):
        """
        Args:
            in_channels: Channels of the incoming tensor.
            out_channels: Channels produced by the block.
            kernel_size: Kernel size of the main convolution.
            dilation: Dilation of the main convolution.
            cond_dim: Size of the conditioning input; a FiLM layer is created
                only when this is greater than zero.
            activation: Whether to create the PReLU activation.
        """
        super().__init__()
        # No padding: the output is shorter than the input, which is why the
        # residual path is cropped in forward().
        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            dilation=dilation,
            padding=0,
            bias=True,
        )
        if cond_dim > 0:
            self.film = FiLM(cond_dim, out_channels, batch_norm=False)
        if activation:
            self.act = torch.nn.PReLU()
        self.res = torch.nn.Conv1d(in_channels, out_channels, 1, bias=False)

    def forward(self, x, c=None):
        """Run conv -> (FiLM) -> (activation) and add the cropped residual."""
        out = self.conv(x)

        film = getattr(self, "film", None)
        if film is not None:
            out = film(out, c)

        act = getattr(self, "act", None)
        if act is not None:
            out = act(out)

        # Crop the 1x1-conv skip signal to the length of the main path.
        skip = causal_crop(self.res(x), out.shape[-1])
        return out + skip
|
131 |
+
|
132 |
+
class TCN(torch.nn.Module):
    """Stack of TCNBlocks whose dilation grows as `dilation_growth ** n`."""

    def __init__(self, n_inputs=1, n_outputs=1, n_blocks=10, kernel_size=13, n_channels=64, dilation_growth=4, cond_dim=0):
        """
        Args:
            n_inputs: Input channels of the first block.
            n_outputs: Output channels of the last block (when there is more
                than one block).
            n_blocks: Number of blocks in the stack.
            kernel_size: Kernel size shared by all blocks.
            n_channels: Channel count used between blocks.
            dilation_growth: Base of the per-block dilation.
            cond_dim: Size of the conditioning input handed to every block.
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.n_channels = n_channels
        self.dilation_growth = dilation_growth
        self.n_blocks = n_blocks
        self.stack_size = n_blocks

        self.blocks = torch.nn.ModuleList()
        for idx in range(n_blocks):
            is_first = idx == 0
            is_last = (idx + 1) == n_blocks
            in_ch = n_inputs if is_first else n_channels
            # The first block always widens to n_channels, even if it is also
            # the last one; only a later final block narrows to n_outputs.
            out_ch = n_outputs if (is_last and not is_first) else n_channels
            self.blocks.append(
                TCNBlock(
                    in_ch,
                    out_ch,
                    kernel_size,
                    dilation_growth ** idx,
                    cond_dim=cond_dim,
                    activation=True,  # every block, including the last, is activated
                )
            )

    def forward(self, x, c=None):
        """Pass `x` through every block in order, giving each the same `c`."""
        out = x
        for layer in self.blocks:
            out = layer(out, c)
        return out

    def compute_receptive_field(self):
        """Compute the receptive field in samples."""
        # Start from the first block's kernel, then add each later block's
        # contribution of (kernel_size - 1) * dilation.
        return sum(
            (
                (self.kernel_size - 1) * self.dilation_growth ** (n % self.stack_size)
                for n in range(1, self.n_blocks)
            ),
            self.kernel_size,
        )
|
172 |
+
|
173 |
+
# setup the pre-trained models: each checkpoint is loaded onto the CPU and the
# loaded object is immediately switched to eval mode
model_comp, model_verb, model_amp, model_delay, model_synth = (
    torch.load(weights_file, map_location="cpu").eval()
    for weights_file in (
        "compressor_full.pt",
        "reverb_full.pt",
        "amp_full.pt",
        "delay_full.pt",
        "synth2synth_full.pt",
    )
)
|
179 |
+
|
180 |
+
|
181 |
+
|
182 |
+
def _peak_normalize(signal):
    """Scale `signal` so its largest absolute sample is 1; silent input is returned unchanged."""
    peak = signal.abs().max()
    # Dividing an all-zero signal by its zero peak would fill it with NaNs.
    return signal / peak if peak > 0 else signal


def inference(aud, effect_type, *, gain_dB=-24, c0=-1.4, c1=3, mix=70,
              width=50, max_length=30, stereo=True, tail=True):
    """Process an audio file with one of the pre-trained effect models.

    The input is cropped to `max_length` seconds, run channel by channel
    through the selected model, high-pass filtered, blended with the dry
    signal, trimmed, peak-normalised and written to ``output.mp3``.

    Args:
        aud: Object whose ``file`` attribute is handed to ``torchaudio.load``.
        effect_type (str): One of "Compressor", "Reverb", "Amp",
            "Analog Delay" or "Synth2Synth".
        gain_dB (float): Gain in dB applied to the input before the model.
            (Default: -24)
        c0 (float): First conditioning value given to the model. (Default: -1.4)
        c1 (float): Second conditioning value given to the model. (Default: 3)
        mix (float): Percentage (0-100) of processed signal in the output; the
            remainder is the dry signal. (Default: 70)
        width (float): Controls the per-channel offset of the conditioning
            values: ``width * 5e-3`` is added for the first channel and
            subtracted for every other channel. (Default: 50)
        max_length (float): Maximum number of seconds of input that is
            processed. (Default: 30)
        stereo (bool): Duplicate a mono input into two channels.
            (Default: True)
        tail (bool): Append zeros (receptive field minus one samples) after
            the input so the end of the processed signal is kept.
            (Default: True)

    Returns:
        str: The path of the written file, always ``"output.mp3"``.

    Raises:
        ValueError: If `effect_type` is not one of the supported names.
    """
    # select model type
    models = {
        "Compressor": model_comp,
        "Reverb": model_verb,
        "Amp": model_amp,
        "Analog Delay": model_delay,
        "Synth2Synth": model_synth,
    }
    if effect_type not in models:
        raise ValueError(
            f"Unknown effect_type {effect_type!r}; expected one of {sorted(models)}"
        )
    pt_model = models[effect_type]

    x_p, sample_rate = torchaudio.load(aud.file)

    # measure the receptive field
    pt_model_rf = pt_model.compute_receptive_field()

    # crop input signal if needed
    max_samples = int(sample_rate * max_length)
    x_p_crop = x_p[:, :max_samples]
    chs = x_p_crop.shape[0]

    # if mono and stereo requested
    if chs == 1 and stereo:
        x_p_crop = x_p_crop.repeat(2, 1)
        chs = 2

    # pad the input signal: the front padding makes the model output as long
    # as the (cropped) input, the back padding keeps the effect tail
    front_pad = pt_model_rf - 1
    back_pad = front_pad if tail else 0
    x_p_pad = torch.nn.functional.pad(x_p_crop, (front_pad, back_pad))

    # design highpass filter (8th-order Butterworth, 20 Hz cutoff)
    sos = scipy.signal.butter(
        8,
        20.0,
        fs=sample_rate,
        output="sos",
        btype="highpass",
    )

    # compute linear gain
    gain_ln = 10 ** (gain_dB / 20.0)
    spread = width * 5e-3

    # process audio with pre-trained model
    with torch.no_grad():
        y_hat = torch.zeros(chs, x_p_crop.shape[1] + back_pad)
        for n in range(chs):
            # first channel is pushed one way, all remaining channels the other
            factor = spread if n == 0 else -spread
            c = torch.tensor([float(c0 + factor), float(c1 + factor)]).view(1, 1, -1)
            y_hat_ch = pt_model(gain_ln * x_p_pad[n, :].view(1, 1, -1), c)
            y_hat_ch = scipy.signal.sosfilt(sos, y_hat_ch.view(-1).numpy())
            y_hat[n, :] = torch.tensor(y_hat_ch)

    # pad the dry signal so it matches the processed length
    x_dry = torch.nn.functional.pad(x_p_crop, (0, back_pad))

    # normalize each first
    y_hat = _peak_normalize(y_hat)
    x_dry = _peak_normalize(x_dry)

    # mix
    wet = mix / 100.0
    y_hat = (wet * y_hat) + ((1 - wet) * x_dry)

    # remove transient: the first 8192 samples are discarded
    y_hat = _peak_normalize(y_hat[..., 8192:])

    torchaudio.save("output.mp3", y_hat.view(chs, -1), sample_rate, compression=320.0)
    return "output.mp3"
|
268 |
+
|
269 |
+
|