Spaces:
Running
Running
import gradio as gr | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import torch | |
import yaml | |
import json | |
import pyloudnorm as pyln | |
from hydra.utils import instantiate | |
from soxr import resample | |
from functools import partial | |
from torchcomp import coef2ms, ms2coef | |
from copy import deepcopy | |
from modules.utils import chain_functions, vec2statedict, get_chunks | |
from modules.fx import clip_delay_eq_Q | |
from plot_utils import get_log_mags_from_eq | |
title_md = "# Vocal Effects Generator" | |
description_md = """ | |
This is a demo of the paper [DiffVox: A Differentiable Model for Capturing and Analysing Professional Effects Distributions](https://arxiv.org/abs/2504.14735), accepted at DAFx 2025. | |
In this demo, you can upload a raw vocal audio file (in mono) and apply random effects to make it sound better! | |
The effects consist of a series of EQ, compressor, delay, and reverb.
The generator is a PCA model derived from 365 vocal effects presets fitted with the same effects chain. | |
This interface allows you to control the principal components (PCs) of the generator, randomise them, and render the audio. | |
To give you some idea, we empirically found that the first PC controls the amount of reverb and the second PC controls the amount of brightness.
Note that adding these PCs together does not necessarily mean that their effects are additive in the final audio. | |
We found sometimes the effects of least important PCs are more perceptible. | |
Try to play around with the sliders and buttons and see what you can come up with! | |
Currently only PCs are tweakable, but in the future we will add more controls and visualisation tools. | |
For example: | |
- Directly controlling the parameters of the effects | |
- Visualising the PCA space | |
- Visualising the frequency responses/dynamic curves of the effects | |
""" | |
# Lower/upper bound shared by every principal-component slider.
SLIDER_MIN = -3
SLIDER_MAX = 3
# Number of leading PCs that get a dedicated slider; the remaining PCs are
# reached through the "extra PC" dropdown.
NUMBER_OF_PCS = 4
# NOTE(review): not referenced anywhere in the visible code - confirm before removing.
TEMPERATURE = 0.7

# Files describing the effect chain and the fitted PCA model.
CONFIG_PATH = "presets/rt_config.yaml"
PCA_PARAM_FILE = "presets/internal/gaussian.npz"
INFO_PATH = "presets/internal/info.json"
MASK_PATH = "presets/internal/feature_mask.npy"
# Load the effect-chain description; only its "model" section is instantiated.
with open(CONFIG_PATH, encoding="utf-8") as fp:
    fx_config = yaml.safe_load(fp)["model"]

# Global effect: template chain that vec2fx() deep-copies for every request.
global_fx = instantiate(fx_config)
global_fx.eval()

pca_params = np.load(PCA_PARAM_FILE)
mean = pca_params["mean"]
cov = pca_params["cov"]

# eigh() returns eigenvalues in ascending order; flip both outputs so that
# column 0 of the basis is the principal component with the largest variance.
eigvals, eigvecs = np.linalg.eigh(cov)
eigvals = np.flip(eigvals, axis=0)
eigvecs = np.flip(eigvecs, axis=1)
# Round-off can push near-zero eigenvalues of a PSD matrix slightly below
# zero; clamp them so the square root cannot put NaN columns into U.
eigvals = np.clip(eigvals, 0.0, None)

# Scale each eigenvector by its standard deviation: x = U @ z + mean.
U = eigvecs * np.sqrt(eigvals)
U = torch.from_numpy(U).float()
mean = torch.from_numpy(mean).float()
feature_mask = torch.from_numpy(np.load(MASK_PATH))

with open(INFO_PATH, encoding="utf-8") as f:
    info = json.load(f)

param_keys = info["params_keys"]
# Parameters stored with an empty shape are treated as shape [1].
original_shapes = [
    shape if len(shape) else [1] for shape in info["params_original_shapes"]
]

# get_chunks() supplies the remaining vec2statedict arguments; its last
# return value is not needed here.
*vec2dict_args, _ = get_chunks(param_keys, original_shapes)
vec2dict_args = [param_keys, original_shapes] + vec2dict_args
vec2dict = partial(
    vec2statedict,
    **dict(
        zip(
            (
                "keys",
                "original_shapes",
                "selected_chunks",
                "position",
                "U_matrix_shape",
            ),
            vec2dict_args,
        )
    ),
)

# Start the template chain at the PCA mean.
global_fx.load_state_dict(vec2dict(mean), strict=False)

# Loudness meter used by inference(); all audio is processed at 44.1 kHz.
meter = pyln.Meter(44100)
def z2x(z):
    """Map principal-component coordinates ``z`` to a parameter vector.

    Computes ``U @ z + mean``. As a side effect every open matplotlib figure
    is closed first, so figures from earlier callbacks do not pile up.
    """
    plt.close("all")
    projected = U @ z
    return projected + mean
def fx2x(fx):
    """Flatten the state of ``fx`` into the masked parameter vector.

    The tensors listed in ``param_keys`` are flattened and concatenated in
    that order, then indexed with ``feature_mask``. Closes all open
    matplotlib figures first.
    """
    plt.close("all")
    tensors = fx.state_dict()
    pieces = [tensors[key].flatten() for key in param_keys]
    return torch.cat(pieces)[feature_mask]
def x2z(x):
    """Map a parameter vector ``x`` back to principal-component coordinates.

    This is the inverse of ``z2x`` (``x = U @ z + mean``). ``U`` is built as
    orthonormal eigenvectors scaled column-wise by ``sqrt(eigenvalue)``, so
    ``U.T @ U`` is ``diag(eigenvalues)`` rather than the identity. Using
    ``U.T`` alone therefore returns ``eigenvalues * z``; the projection has
    to be divided by the squared column norms to recover ``z``.
    """
    centered = x - mean
    # Squared column norms of U equal the eigenvalues of the covariance.
    column_energy = U.square().sum(dim=0)
    # A zero-variance direction carries no information; the clamp turns the
    # resulting 0/0 into 0 instead of NaN.
    column_energy = column_energy.clamp_min(torch.finfo(U.dtype).tiny)
    return (U.T @ centered) / column_energy
def _float_to_int16(signal):
    """Convert float audio to int16 PCM, saturating instead of wrapping."""
    return np.clip(signal * 32768, -32768, 32767).astype(np.int16)


def inference(audio, fx):
    """Render ``audio`` through the effect chain ``fx``.

    Args:
        audio: ``(sample_rate, samples)`` tuple as delivered by a numpy
            ``gr.Audio`` component; ``samples`` is 1-D or ``(frames, channels)``.
        fx: callable effect chain returning a ``(direct, wet)`` pair.

    Returns:
        Three ``(44100, int16 array)`` tuples: the mix, the direct signal and
        the wet signal.

    Raises:
        gr.Error: if no audio has been provided.
    """
    if audio is None:
        raise gr.Error("Please upload an audio file first.")

    sr, y = audio
    if sr != 44100:
        y = resample(y, sr, 44100)
    if y.dtype.kind != "f":
        y = y / 32768.0
    if y.ndim == 1:
        y = y[:, None]

    loudness = meter.integrated_loudness(y)
    # Silent input measures as -inf, which would turn into an infinite gain.
    if np.isfinite(loudness):
        y = pyln.normalize.loudness(y, loudness, -18.0)

    # (frames, channels) -> (1, channels, frames), then down-mix to mono.
    y = torch.from_numpy(y).float().T.unsqueeze(0)
    if y.shape[1] != 1:
        y = y.mean(dim=1, keepdim=True)

    # No gradients are needed for rendering, and tensors that require grad
    # cannot be converted with .numpy().
    with torch.no_grad():
        direct, wet = fx(y)
    direct = direct.squeeze(0).T.numpy()
    wet = wet.squeeze(0).T.numpy()
    rendered = direct + wet

    # Peak-normalise only when the mix would clip; the same factor is applied
    # to all three signals so their relative levels are kept.
    peak = np.max(np.abs(rendered))
    if peak > 1:
        rendered = rendered / peak
        direct = direct / peak
        wet = wet / peak

    # A sample at exactly +1.0 maps to 32768, which does not fit in int16,
    # and direct/wet may individually exceed the mix; saturate on conversion.
    return (
        (44100, _float_to_int16(rendered)),
        (44100, _float_to_int16(direct)),
        (44100, _float_to_int16(wet)),
    )
def get_important_pcs(n=10, **kwargs):
    """Create ``n`` sliders labelled "PC 1" .. "PC n".

    Every slider spans ``SLIDER_MIN``..``SLIDER_MAX``; extra keyword
    arguments are forwarded to ``gr.Slider``.
    """
    sliders = []
    for pc_number in range(1, n + 1):
        sliders.append(
            gr.Slider(
                minimum=SLIDER_MIN,
                maximum=SLIDER_MAX,
                label=f"PC {pc_number}",
                **kwargs,
            )
        )
    return sliders
def model2json(fx):
    """Summarise the effect chain ``fx`` as a nested dict.

    The first seven modules of ``fx`` are reported under ``"Direct"``
    (together with the panner of ``fx[7]``); the delay, the FDN reverb and
    the cross-send level of ``fx[7]`` are reported under ``"Sends"``. Each
    module contributes whatever its ``toJSON()`` returns.

    The cross-send level is ``20 * log10(sends_0)``, so it is ``-inf`` when
    the send is zero. The previous version defined a helper meant to replace
    such values but never called it (and its replacement, ``-1e500``, is
    itself ``-inf``); that dead code has been removed.
    """
    fx_names = ["PK1", "PK2", "LS", "HS", "LP", "HP", "DRC"]
    spatial = fx[7]
    delay, fdn = spatial.effects[0], spatial.effects[1]

    direct = {name: module.toJSON() for name, module in zip(fx_names, fx)}
    direct["Panner"] = spatial.pan.toJSON()

    # The reverb's tone-correction EQ reuses the first four band names.
    tone_correction = {
        name: band.toJSON() for name, band in zip(fx_names[:4], fdn.eq)
    }
    sends = {
        "DLY": delay.toJSON() | {"LP": delay.eq.toJSON()},
        "FDN": fdn.toJSON() | {"Tone correction PEQ": tone_correction},
        "Cross Send (dB)": spatial.params.sends_0.log10().mul(20).item(),
    }
    return {
        "Direct": direct,
        "Sends": sends,
    }
def plot_eq(fx):
    """Plot the log-magnitude response of the six EQ stages ``fx[:6]``.

    Each stage is drawn as a faint filled curve; their sum is drawn as a
    solid black line. Returns the matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
    w, eq_log_mags = get_log_mags_from_eq(fx[:6])

    # Combined response first, then the individual stages.
    ax.plot(w, sum(eq_log_mags), color="black", linestyle="-")
    for band_response in eq_log_mags:
        ax.plot(w, band_response, "k-", alpha=0.3)
        ax.fill_between(
            w, band_response, 0, facecolor="gray", edgecolor="none", alpha=0.1
        )

    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Magnitude (dB)")
    ax.set_xlim(20, 20000)
    ax.set_ylim(-40, 20)
    ax.set_xscale("log")
    ax.grid()
    return fig
def plot_comp(fx):
    """Plot the static input/output curve of the dynamics stage ``fx[6]``.

    The curve is evaluated on 100 input levels from -80 to 0 dB using the
    stage's thresholds, ratios and make-up gain; the red line is the
    identity for reference. Returns the matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(6, 5), constrained_layout=True)
    comp = fx[6]
    cmp_th = comp.params.cmp_th.item()
    exp_th = comp.params.exp_th.item()
    cmp_ratio = comp.params.cmp_ratio.item()
    exp_ratio = comp.params.exp_ratio.item()
    make_up = comp.params.make_up.item()

    comp_in = np.linspace(-80, 0, 100)

    # Above the compressor threshold the level is reduced.
    above_threshold = comp_in > cmp_th
    compressed = comp_in - (comp_in - cmp_th) * (cmp_ratio - 1) / cmp_ratio
    comp_curve = np.where(above_threshold, compressed, comp_in)

    # Below the expander threshold the level is pushed further down.
    below_threshold = comp_curve < exp_th
    expanded = comp_curve - (exp_th - comp_curve) / exp_ratio
    comp_out = np.where(below_threshold, expanded, comp_curve) + make_up

    ax.plot(comp_in, comp_out, c="black", linestyle="-")
    ax.plot(comp_in, comp_in, c="r", alpha=0.5)
    ax.set_xlabel("Input Level (dB)")
    ax.set_ylabel("Output Level (dB)")
    ax.set_xlim(-80, 0)
    ax.set_ylim(-80, 0)
    ax.grid()
    return fig
def plot_delay(fx):
    """Plot the magnitude response of the delay ``fx[7].effects[0]``.

    The first curve is the delay's EQ response plus its gain in dB. Nine
    further curves add one more EQ pass and one more feedback attenuation
    each, and are drawn with an opacity that decreases with the repeat index
    and the delay time. Returns the matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
    delay = fx[7].effects[0]
    w, eq_log_mags = get_log_mags_from_eq([delay.eq])
    log_gain = delay.params.gain.log10().item() * 20
    # The stored delay is divided by 1000 before it is used for the fade.
    d = delay.params.delay.item() / 1000
    log_mag = sum(eq_log_mags)

    ax.plot(w, log_mag + log_gain, color="black", linestyle="-")

    log_feedback = delay.params.feedback.log10().item() * 20
    for repeat in range(1, 10):
        opacity = max(0, (10 - repeat * d * 4) / 10)
        repeat_log_mag = log_mag * (repeat + 1) + log_feedback * repeat + log_gain
        ax.plot(w, repeat_log_mag, c="black", alpha=opacity, linestyle="-")

    ax.set_xscale("log")
    ax.set_xlim(20, 20000)
    ax.set_ylim(-80, 0)
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Magnitude (dB)")
    ax.grid()
    return fig
def plot_reverb(fx):
    """Plot the tone-correction EQ response of the reverb ``fx[7].effects[1]``.

    The summed EQ response is offset by ``20 * log10(||c|| * ||b||)``, spread
    evenly over the individual EQ curves before summing. Returns the
    matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
    fdn = fx[7].effects[1]
    w, eq_log_mags = get_log_mags_from_eq(fdn.eq)

    bc = fdn.params.c.norm() * fdn.params.b.norm()
    log_bc = torch.log10(bc).item() * 20
    per_curve_offset = log_bc / len(eq_log_mags)
    eq_log_mags = [curve + per_curve_offset for curve in eq_log_mags]

    ax.plot(w, sum(eq_log_mags), color="black", linestyle="-")
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("Magnitude (dB)")
    ax.set_xlim(20, 20000)
    ax.set_ylim(-40, 6)
    ax.set_xscale("log")
    ax.grid()
    return fig
def plot_t60(fx):
    """Plot a T60 curve derived from the reverb's ``gamma`` parameter.

    ``gamma`` is converted to dB, divided by the shortest entry of
    ``fdn.delays`` and turned into the time needed for a 60 dB drop at
    44100 samples per second. The frequency axis spreads the entries of
    ``gamma`` linearly over 0..22050 Hz. Returns the matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True)
    fdn = fx[7].effects[1]
    gamma = fdn.params.gamma.squeeze().numpy()
    delays = fdn.delays.numpy()

    w = np.linspace(0, 22050, gamma.size)
    # The small offset keeps log10 finite when gamma is zero.
    gamma_db = 20 * np.log10(gamma + 1e-10)
    t60 = -60 / (gamma_db / np.min(delays)) / 44100

    ax.plot(w, t60, color="black", linestyle="-")
    ax.set_xlabel("Frequency (Hz)")
    ax.set_ylabel("T60 (s)")
    ax.set_xlim(20, 20000)
    ax.set_ylim(0, 9)
    ax.set_xscale("log")
    ax.grid()
    return fig
def update_param(m, attr_name, value):
    """Set attribute ``attr_name`` of ``m`` to ``value``.

    When the attribute is an ``nn.Parameter`` its data is overwritten in
    place, so the parameter object itself is kept. Otherwise the attribute
    is assigned a new tensor built from ``value``.

    Args:
        m: object holding the attribute.
        attr_name: name of the attribute to update.
        value: new value; a Python number or a tensor.
    """
    current = getattr(m, attr_name)
    if isinstance(current, torch.nn.Parameter):
        # Tensor.copy_ needs a tensor source, so wrap plain numbers first.
        current.data.copy_(torch.as_tensor(value, dtype=current.dtype))
    else:
        setattr(m, attr_name, torch.tensor(value))
def update_atrt(comp, attr_name, value):
    """Store a time given in milliseconds as a coefficient on ``comp``.

    ``value`` is converted with ``ms2coef`` at 44100 Hz and assigned to
    ``attr_name``.
    """
    coefficient = ms2coef(torch.tensor(value), 44100)
    setattr(comp, attr_name, coefficient)
def vec2fx(x):
    """Build an effect chain from the parameter vector ``x``.

    A deep copy of ``global_fx`` is loaded (non-strictly) with the state dict
    derived from ``x``; ``clip_delay_eq_Q`` with ``Q=0.707`` is then applied
    to every sub-module. ``global_fx`` itself is left untouched.
    """
    chain = deepcopy(global_fx)
    chain.load_state_dict(vec2dict(x), strict=False)
    clip_q = partial(clip_delay_eq_Q, Q=0.707)
    chain.apply(clip_q)
    return chain
def get_last_attribute(m, attr_name):
    """Resolve a dotted attribute path down to its final owner.

    Returns ``(owner, leaf_name)`` such that ``getattr(owner, leaf_name)`` is
    the attribute addressed by ``attr_name``. For a name without dots the
    result is ``(m, attr_name)``.
    """
    *parents, leaf_name = attr_name.split(".")
    owner = m
    for parent in parents:
        owner = getattr(owner, parent)
    return owner, leaf_name
with gr.Blocks() as demo:
    # NOTE(review): the nesting below is reconstructed from a listing whose
    # indentation was lost; the placement of the extra-PC dropdown and slider
    # (directly in the left column) is the most likely reading - confirm.

    # State: latent PC coordinates and the parameter vector they map to.
    z = gr.State(torch.zeros_like(mean))
    fx_params = gr.State(mean)
    # Effect chain at the initial parameters; seeds the widget values below.
    fx = vec2fx(fx_params.value)

    gr.Markdown(title_md, elem_id="title")
    with gr.Row():
        gr.Markdown(description_md, elem_id="description")
        gr.Image("diffvox_diagram.png", elem_id="diagram")

    with gr.Row():
        # Left column: input audio, action buttons and PC controls.
        with gr.Column():
            audio_input = gr.Audio(
                type="numpy", sources="upload", label="Input Audio", loop=True
            )
            with gr.Row():
                random_button = gr.Button(
                    "Randomise PCs", elem_id="randomise-button"
                )
                reset_button = gr.Button("Reset", elem_id="reset-button")
                render_button = gr.Button(
                    "Run", elem_id="render-button", variant="primary"
                )

            pc_slider = partial(
                gr.Slider,
                minimum=SLIDER_MIN,
                maximum=SLIDER_MAX,
                value=0,
                interactive=True,
            )
            with gr.Row():
                s1 = pc_slider(label="PC 1")
                s2 = pc_slider(label="PC 2")
            with gr.Row():
                s3 = pc_slider(label="PC 3")
                s4 = pc_slider(label="PC 4")
            sliders = [s1, s2, s3, s4]

            # Remaining PCs share one slider; the dropdown picks which one.
            extra_pc_dropdown = gr.Dropdown(
                list(range(NUMBER_OF_PCS + 1, mean.numel() + 1)),
                label=f"PC > {NUMBER_OF_PCS}",
                info="Select which extra PC to adjust",
                interactive=True,
            )
            extra_slider = gr.Slider(
                minimum=SLIDER_MIN,
                maximum=SLIDER_MAX,
                label="Extra PC",
                value=0,
            )

        # Right column: rendered mix plus its direct and wet components.
        with gr.Column():
            output_audio = partial(
                gr.Audio, type="numpy", interactive=False, loop=True
            )
            audio_output = output_audio(label="Output Audio")
            direct_output = output_audio(label="Direct Audio")
            wet_output = output_audio(label="Wet Audio")
# Parametric EQ section: one response plot and one column of sliders per
# filter, initialised from the effect chain `fx`.
_ = gr.Markdown("## Parametric EQ")
peq_plot = gr.Plot(plot_eq(fx), label="PEQ Frequency Response", elem_id="peq-plot")
peq_knob = partial(gr.Slider, interactive=True)
with gr.Row():
    with gr.Column(min_width=160):
        _ = gr.Markdown("High Pass")
        hp = fx[5]
        hp_freq = peq_knob(
            minimum=16,
            maximum=5300,
            value=hp.params.freq.item(),
            label="Frequency (Hz)",
        )
        hp_q = peq_knob(
            minimum=0.5, maximum=10, value=hp.params.Q.item(), label="Q"
        )
    with gr.Column(min_width=160):
        _ = gr.Markdown("Low Shelf")
        ls = fx[2]
        ls_freq = peq_knob(
            minimum=30,
            maximum=200,
            value=ls.params.freq.item(),
            label="Frequency (Hz)",
        )
        ls_gain = peq_knob(
            minimum=-12, maximum=12, value=ls.params.gain.item(), label="Gain (dB)"
        )
    with gr.Column(min_width=160):
        _ = gr.Markdown("Peak filter 1")
        pk1 = fx[0]
        pk1_freq = peq_knob(
            minimum=33,
            maximum=5400,
            value=pk1.params.freq.item(),
            label="Frequency (Hz)",
        )
        pk1_gain = peq_knob(
            minimum=-12, maximum=12, value=pk1.params.gain.item(), label="Gain (dB)"
        )
        pk1_q = peq_knob(
            minimum=0.2, maximum=20, value=pk1.params.Q.item(), label="Q"
        )
    with gr.Column(min_width=160):
        _ = gr.Markdown("Peak filter 2")
        pk2 = fx[1]
        pk2_freq = peq_knob(
            minimum=200,
            maximum=17500,
            value=pk2.params.freq.item(),
            label="Frequency (Hz)",
        )
        pk2_gain = peq_knob(
            minimum=-12, maximum=12, value=pk2.params.gain.item(), label="Gain (dB)"
        )
        pk2_q = peq_knob(
            minimum=0.2, maximum=20, value=pk2.params.Q.item(), label="Q"
        )
    with gr.Column(min_width=160):
        _ = gr.Markdown("High Shelf")
        hs = fx[3]
        hs_freq = peq_knob(
            minimum=750,
            maximum=8300,
            value=hs.params.freq.item(),
            label="Frequency (Hz)",
        )
        hs_gain = peq_knob(
            minimum=-12, maximum=12, value=hs.params.gain.item(), label="Gain (dB)"
        )
    with gr.Column(min_width=160):
        _ = gr.Markdown("Low Pass")
        lp = fx[4]
        lp_freq = peq_knob(
            minimum=200,
            maximum=18000,
            value=lp.params.freq.item(),
            label="Frequency (Hz)",
        )
        lp_q = peq_knob(
            minimum=0.5, maximum=10, value=lp.params.Q.item(), label="Q"
        )
# Compressor/expander section: sliders on the left, static curve on the right.
# The two compressor labels previously read "fx[6]. Threshold (dB)" and
# "fx[6]. Ratio" (a find/replace artifact); they now follow the naming of the
# "Exp." labels.
_ = gr.Markdown("## Compressor and Expander")
with gr.Row():
    with gr.Column():
        comp = fx[6]
        cmp_th = gr.Slider(
            minimum=-60,
            maximum=0,
            value=comp.params.cmp_th.item(),
            interactive=True,
            label="Comp. Threshold (dB)",
        )
        cmp_ratio = gr.Slider(
            minimum=1,
            maximum=20,
            value=comp.params.cmp_ratio.item(),
            interactive=True,
            label="Comp. Ratio",
        )
        make_up = gr.Slider(
            minimum=-12,
            maximum=12,
            value=comp.params.make_up.item(),
            interactive=True,
            label="Make Up (dB)",
        )
        # Attack/release are stored as coefficients and shown in milliseconds.
        attack_time = gr.Slider(
            minimum=0.1,
            maximum=100,
            value=coef2ms(comp.params.at, 44100).item(),
            interactive=True,
            label="Attack Time (ms)",
        )
        release_time = gr.Slider(
            minimum=50,
            maximum=1000,
            value=coef2ms(comp.params.rt, 44100).item(),
            interactive=True,
            label="Release Time (ms)",
        )
        exp_ratio = gr.Slider(
            minimum=0,
            maximum=1,
            value=comp.params.exp_ratio.item(),
            interactive=True,
            label="Exp. Ratio",
        )
        exp_th = gr.Slider(
            minimum=-80,
            maximum=0,
            value=comp.params.exp_th.item(),
            interactive=True,
            label="Exp. Threshold (dB)",
        )
    with gr.Column():
        comp_plot = gr.Plot(
            plot_comp(fx), label="Compressor Curve", elem_id="comp-plot"
        )
# Ping-pong delay section, followed by the reverb plots and the JSON summary.
_ = gr.Markdown("## Ping-Pong Delay")
delay_knob = partial(gr.Slider, interactive=True)
with gr.Row():
    with gr.Column():
        delay = fx[7].effects[0]
        delay_time = delay_knob(
            minimum=100,
            maximum=1000,
            value=delay.params.delay.item(),
            label="Delay Time (ms)",
        )
        feedback = delay_knob(
            minimum=0,
            maximum=1,
            value=delay.params.feedback.item(),
            label="Feedback",
        )
        # The gain is stored linearly and shown in dB.
        delay_gain = delay_knob(
            minimum=-80,
            maximum=0,
            value=delay.params.gain.log10().item() * 20,
            label="Gain (dB)",
        )
        # Pan values are mapped from the stored value p to p * 200 - 100.
        odd_pan = delay_knob(
            minimum=-100,
            maximum=100,
            value=delay.odd_pan.params.pan.item() * 200 - 100,
            label="Odd Delay Pan",
        )
        even_pan = delay_knob(
            minimum=-100,
            maximum=100,
            value=delay.even_pan.params.pan.item() * 200 - 100,
            label="Even Delay Pan",
        )
        delay_lp_freq = delay_knob(
            minimum=200,
            maximum=16000,
            value=delay.eq.params.freq.item(),
            label="Low Pass Frequency (Hz)",
        )
    with gr.Column():
        delay_plot = gr.Plot(
            plot_delay(fx), label="Delay Frequency Response", elem_id="delay-plot"
        )

with gr.Row():
    reverb_plot = gr.Plot(
        plot_reverb(fx),
        label="Reverb Tone Correction PEQ",
        elem_id="reverb-plot",
        min_width=160,
    )
    t60_plot = gr.Plot(
        plot_t60(fx), label="Reverb T60", elem_id="t60-plot", min_width=160
    )

with gr.Row():
    json_output = gr.JSON(
        model2json(fx), label="Effect Settings", max_height=800, open=True
    )
def update_pc(z, i):
    """Values for the PC sliders: the first NUMBER_OF_PCS entries of ``z``
    followed by entry ``i - 1`` (``i`` is the 1-based PC picked in the dropdown).
    """
    return z[:NUMBER_OF_PCS].tolist() + [z[i - 1].item()]


update_pc_outputs = sliders + [extra_slider]

# --- Parametric EQ: slider, index of its module in the chain, attribute ---
peq_sliders = [
    pk1_freq, pk1_gain, pk1_q,
    pk2_freq, pk2_gain, pk2_q,
    ls_freq, ls_gain,
    hs_freq, hs_gain,
    lp_freq, lp_q,
    hp_freq, hp_q,
]
_peq_layout = [
    (0, ("freq", "gain", "Q")),
    (1, ("freq", "gain", "Q")),
    (2, ("freq", "gain")),
    (3, ("freq", "gain")),
    (4, ("freq", "Q")),
    (5, ("freq", "Q")),
]
peq_attr_names = [name for _, names in _peq_layout for name in names]
peq_indices = [index for index, names in _peq_layout for _ in names]

# --- Compressor: the two time constants go through update_atrt ---
cmp_sliders = [
    cmp_th,
    cmp_ratio,
    make_up,
    exp_ratio,
    exp_th,
    attack_time,
    release_time,
]
cmp_update_funcs = [update_param] * 5 + [update_atrt] * 2
cmp_attr_names = ["cmp_th", "cmp_ratio", "make_up", "exp_ratio", "exp_th", "at", "rt"]


# --- Delay: gain is shown in dB, pans in -100..100 ---
def _update_from_db(m, a, v):
    """Store a dB slider value as a linear gain."""
    return update_param(m, a, 10 ** (v / 20))


def _update_from_pan(m, a, v):
    """Store a -100..100 pan slider value as (v + 100) / 200."""
    return update_param(m, a, (v + 100) / 200)


delay_sliders = [delay_time, feedback, delay_lp_freq, delay_gain, odd_pan, even_pan]
delay_update_funcs = [update_param] * 3 + [
    _update_from_db,
    _update_from_pan,
    _update_from_pan,
]
delay_attr_names = [
    "params.delay",
    "params.feedback",
    "eq.params.freq",
    "params.gain",
    "odd_pan.params.pan",
    "even_pan.params.pan",
]
# The two pan sliders do not trigger a redraw of the delay plot.
delay_update_plot_flag = [True] * 4 + [False] * 2

all_effect_sliders = peq_sliders + cmp_sliders + delay_sliders
split_sizes = [len(peq_sliders), len(cmp_sliders), len(delay_sliders)]
def assign_fx_params(fx, *args):
    """Write the current slider values into the effect chain ``fx``.

    ``args`` holds one value per slider in the order of
    ``all_effect_sliders`` (EQ, then compressor, then delay). ``fx`` is
    modified in place and returned.
    """
    peq_end = split_sizes[0]
    cmp_end = peq_end + split_sizes[1]
    peq_values = args[:peq_end]
    cmp_values = args[peq_end:cmp_end]
    delay_values = args[cmp_end:]

    for module_index, value, attr_name in zip(peq_indices, peq_values, peq_attr_names):
        update_param(fx[module_index].params, attr_name, value)

    compressor_params = fx[6].params
    for setter, value, attr_name in zip(cmp_update_funcs, cmp_values, cmp_attr_names):
        setter(compressor_params, attr_name, value)

    for setter, value, attr_path in zip(
        delay_update_funcs, delay_values, delay_attr_names
    ):
        # Delay attributes are dotted paths relative to the delay module.
        owner, leaf_name = get_last_attribute(fx[7].effects[0], attr_path)
        setter(owner, leaf_name, value)

    return fx
def _make_effect_handler(apply_change, plot_fn=None):
    """Build the callback for one effect slider.

    The callback receives the stored parameter vector, the value of the
    slider that moved, the selected extra PC and the values of all effect
    sliders. It rebuilds the chain, applies the change, and returns the new
    latent vector, the new parameter vector, the JSON summary, optionally one
    refreshed plot, and the PC slider values.
    """

    def handler(*args):
        x, value, pc_number, *slider_values = args
        fx = assign_fx_params(vec2fx(x), *slider_values)
        apply_change(fx, value)
        new_x = fx2x(fx)
        new_z = x2z(new_x)
        outputs = [new_z, new_x, model2json(fx)]
        if plot_fn is not None:
            outputs.append(plot_fn(fx))
        return outputs + update_pc(new_z, pc_number)

    return handler


_shared_effect_inputs = [extra_pc_dropdown] + all_effect_sliders

# Loop variables are bound as lambda defaults so each slider keeps its own.
for idx, s, attr_name in zip(peq_indices, peq_sliders, peq_attr_names):
    s.input(
        _make_effect_handler(
            lambda fx, v, idx=idx, attr_name=attr_name: update_param(
                fx[idx].params, attr_name, v
            ),
            plot_eq,
        ),
        inputs=[fx_params, s] + _shared_effect_inputs,
        outputs=[z, fx_params, json_output, peq_plot] + update_pc_outputs,
    )

for f, s, attr_name in zip(cmp_update_funcs, cmp_sliders, cmp_attr_names):
    s.input(
        _make_effect_handler(
            lambda fx, v, f=f, attr_name=attr_name: f(fx[6].params, attr_name, v),
            plot_comp,
        ),
        inputs=[fx_params, s] + _shared_effect_inputs,
        outputs=[z, fx_params, json_output, comp_plot] + update_pc_outputs,
    )

for f, s, attr_name, update_plot in zip(
    delay_update_funcs, delay_sliders, delay_attr_names, delay_update_plot_flag
):
    s.input(
        _make_effect_handler(
            lambda fx, v, f=f, attr_name=attr_name: f(
                *get_last_attribute(fx[7].effects[0], attr_name), v
            ),
            plot_delay if update_plot else None,
        ),
        inputs=[fx_params, s] + _shared_effect_inputs,
        outputs=[z, fx_params, json_output]
        + ([delay_plot] if update_plot else [])
        + update_pc_outputs,
    )
def _render(audio, x, *slider_values):
    """Rebuild the chain from the stored vector and slider values, then
    run the uploaded audio through it."""
    fx = assign_fx_params(vec2fx(x), *slider_values)
    return inference(audio, fx)


render_button.click(
    _render,
    inputs=[audio_input, fx_params] + all_effect_sliders,
    outputs=[audio_output, direct_output, wet_output],
)
def update_fx(fx):
    """Read the slider values back from the effect chain ``fx``.

    The result is ordered like ``update_fx_outputs``.
    """
    pk1, pk2, ls, hs, lp, hp = (fx[k].params for k in range(6))
    comp = fx[6].params
    delay = fx[7].effects[0]
    return [
        pk1.freq.item(),
        pk1.gain.item(),
        pk1.Q.item(),
        pk2.freq.item(),
        pk2.gain.item(),
        pk2.Q.item(),
        ls.freq.item(),
        ls.gain.item(),
        hs.freq.item(),
        hs.gain.item(),
        lp.freq.item(),
        lp.Q.item(),
        hp.freq.item(),
        hp.Q.item(),
        comp.cmp_th.item(),
        comp.cmp_ratio.item(),
        comp.make_up.item(),
        comp.exp_th.item(),
        comp.exp_ratio.item(),
        coef2ms(comp.at, 44100).item(),
        coef2ms(comp.rt, 44100).item(),
        delay.params.delay.item(),
        delay.params.feedback.item(),
        delay.params.gain.log10().item() * 20,
        delay.eq.params.freq.item(),
        delay.odd_pan.params.pan.item() * 200 - 100,
        delay.even_pan.params.pan.item() * 200 - 100,
    ]


# Components receiving the values of update_fx(), in the same order.
update_fx_outputs = [
    pk1_freq, pk1_gain, pk1_q,
    pk2_freq, pk2_gain, pk2_q,
    ls_freq, ls_gain,
    hs_freq, hs_gain,
    lp_freq, lp_q,
    hp_freq, hp_q,
    cmp_th, cmp_ratio, make_up, exp_th, exp_ratio,
    attack_time, release_time,
    delay_time, feedback, delay_gain, delay_lp_freq,
    odd_pan, even_pan,
]


def update_plots(fx):
    """Redraw every plot for ``fx``, ordered like ``update_plots_outputs``."""
    plotters = (plot_eq, plot_comp, plot_delay, plot_reverb, plot_t60)
    return [draw(fx) for draw in plotters]


update_plots_outputs = [peq_plot, comp_plot, delay_plot, reverb_plot, t60_plot]
# Handlers for everything that edits the latent PC vector.
#
# Every handler below also returns the parameter vector x = z2x(z) into
# `fx_params`. The render button and the effect sliders rebuild the chain
# from `fx_params`, so without this they would keep using the initial vector
# for every parameter that has no slider of its own.
update_all = lambda z, fx, i: update_pc(z, i) + update_fx(fx) + update_plots(fx)
update_all_outputs = update_pc_outputs + update_fx_outputs + update_plots_outputs


def update_z(z, s, i):
    """Set entry ``i`` (0-based) of the latent vector ``z`` to ``s``.

    ``z`` is modified in place and returned.
    """
    z[i] = s
    return z


def _chain_from_latent(latent):
    """Return ``(x, fx)`` for a latent vector: parameters and effect chain."""
    x = z2x(latent)
    return x, vec2fx(x)


def _on_randomise(pc_number):
    """Draw a random latent vector, clipped to the slider range."""
    latent = torch.randn_like(mean).clip(SLIDER_MIN, SLIDER_MAX)
    x, fx = _chain_from_latent(latent)
    return update_all(latent, fx, pc_number) + [model2json(fx), latent, x]


def _on_reset():
    """Return to the all-zero latent vector."""
    latent = torch.zeros_like(mean)
    x, fx = _chain_from_latent(latent)
    return update_all(latent, fx, NUMBER_OF_PCS) + [model2json(fx), latent, x]


def _after_pc_edit(latent):
    """Outputs shared by the PC sliders, ordered like ``_pc_edit_outputs``."""
    x, fx = _chain_from_latent(latent)
    return [latent, x] + update_fx(fx) + update_plots(fx) + [model2json(fx)]


def _on_extra_pc_edit(latent, value, pc_number):
    """Apply the extra slider to the PC picked in the dropdown."""
    # The dropdown lists 1-based PC numbers (it is read as z[i - 1] elsewhere),
    # while update_z indexes from 0. With no selection the vector is left
    # unchanged: indexing with None would overwrite every entry.
    if pc_number is not None:
        latent = update_z(latent, value, pc_number - 1)
    return _after_pc_edit(latent)


_pc_edit_outputs = (
    [z, fx_params] + update_fx_outputs + update_plots_outputs + [json_output]
)

random_button.click(
    _on_randomise,
    inputs=extra_pc_dropdown,
    outputs=update_all_outputs + [json_output, z, fx_params],
)
reset_button.click(
    _on_reset,
    outputs=update_all_outputs + [json_output, z, fx_params],
)

for i, slider in enumerate(sliders):
    slider.input(
        # Bind the loop index as a default so each slider edits its own PC.
        lambda latent, value, i=i: _after_pc_edit(update_z(latent, value, i)),
        inputs=[z, slider],
        outputs=_pc_edit_outputs,
    )

extra_slider.input(
    _on_extra_pc_edit,
    inputs=[z, extra_slider, extra_pc_dropdown],
    outputs=_pc_edit_outputs,
)

# Show the current value of the newly selected PC on the extra slider.
extra_pc_dropdown.input(
    lambda z, i: z[i - 1].item(),
    inputs=[z, extra_pc_dropdown],
    outputs=extra_slider,
)
# Start the server only when the file is run as a script, so importing the
# module (e.g. to reuse `demo`) does not launch it as a side effect.
# This statement sits outside the `with gr.Blocks()` context.
if __name__ == "__main__":
    demo.launch()