Spaces:
Building
Building
from batch_processing import Batch | |
import argparse | |
from pathlib import Path | |
from gyraudio.audio_separation.experiment_tracking.experiments import get_experience | |
from gyraudio.audio_separation.experiment_tracking.storage import get_output_folder | |
from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT | |
from gyraudio.audio_separation.properties import ( | |
SHORT_NAME, CLEAN, NOISY, MIXED, PREDICTED, ANNOTATIONS, PATHS, BUFFERS, SAMPLING_RATE, NAME | |
) | |
import torch | |
from gyraudio.audio_separation.experiment_tracking.storage import load_checkpoint | |
from gyraudio.audio_separation.visualization.pre_load_audio import ( | |
parse_command_line_audio_load, load_buffers, audio_loading_batch) | |
from gyraudio.audio_separation.visualization.pre_load_custom_audio import ( | |
parse_command_line_generic_audio_load, generic_audio_loading_batch, | |
load_buffers_custom | |
) | |
from torchaudio.functional import resample | |
from typing import List | |
import numpy as np | |
import logging | |
from interactive_pipe.data_objects.curves import Curve, SingleCurve | |
from interactive_pipe import interactive, KeyboardControl, Control | |
from interactive_pipe import interactive_pipeline | |
from gyraudio.audio_separation.visualization.audio_player import audio_selector, audio_trim, audio_player | |
# Sampling rate (Hz) the separation models run at; custom inputs are
# resampled to this rate in signal_selector before inference.
LEARNT_SAMPLING_RATE = 8_000
# Default torch device: CUDA when torch reports it as available, CPU otherwise.
default_device = "cuda" if torch.cuda.is_available() else "cpu"
def remix(signals, snr=0., global_params={}):
    """Mix the clean signal with the noise at a target signal-to-noise ratio.

    The noise is rescaled by
    ``alpha = 10 ** (-snr / 20) * ||clean|| / ||noise||`` so that
    ``20 * log10(||clean|| / ||alpha * noise||) == snr``.

    Args:
        signals: signal dict whose ``BUFFERS`` entry holds ``CLEAN`` and
            ``NOISY`` tensors.
        snr: target SNR in dB.
        global_params: shared pipeline state; receives ``"snr"``.

    Returns:
        The mixture ``clean + alpha * noise``. When the noise has zero energy
        the mixture equals the clean signal, instead of being filled with NaNs
        from ``inf * 0``.
    """
    signal = signals[BUFFERS][CLEAN]
    noisy = signals[BUFFERS][NOISY]
    noise_norm = torch.norm(noisy)
    if noise_norm > 0:
        alpha = 10 ** (-snr / 20) * torch.norm(signal) / noise_norm
    else:
        # All-zero noise: there is nothing to scale, and dividing by a zero
        # norm would poison the whole mixture with NaNs.
        alpha = 0.
    mixed_signal = signal + alpha * noisy
    global_params["snr"] = snr
    return mixed_signal
def augment(signals, mixed, std_dev=0., amplify=1.):
    """Scale the reference buffers and the mixture, then add Gaussian noise.

    A new signal dict is returned rather than scaling the caller's buffers in
    place. Signal dicts are cached and reused across pipeline runs, so in-place
    scaling compounded the gain on every run. It also failed with ``KeyError``
    for custom clean/noise pairs, which carry no ``MIXED`` buffer.

    Args:
        signals: signal dict with a ``BUFFERS`` mapping of tensors.
        mixed: mixture tensor.
        std_dev: standard deviation of the additive Gaussian noise.
        amplify: gain applied to ``MIXED``/``NOISY``/``CLEAN`` buffers (when
            present) and to ``mixed``.

    Returns:
        ``(augmented_signals, augmented_mixed)``. ``augmented_signals`` is a
        shallow copy of ``signals`` with a new ``BUFFERS`` dict.
    """
    scaled_keys = (MIXED, NOISY, CLEAN)
    scaled_buffers = {
        key: (buffer * amplify if key in scaled_keys else buffer)
        for key, buffer in signals[BUFFERS].items()
    }
    augmented = {**signals, BUFFERS: scaled_buffers}
    noisy_mixed = mixed * amplify + torch.randn_like(mixed) * std_dev
    return augmented, noisy_mixed
# @interactive(
#     device=("cuda", ["cpu", "cuda"]
#             ) if default_device == "cuda" else ("cpu", ["cpu"])
# )
def select_device(device=default_device, global_params={}):
    """Publish the torch device to use for inference in ``global_params``.

    audio_sep_inference reads it back (falling back to "cpu" when unset).
    """
    global_params.update(device=device)
# Display names of the selectable models; a string ``model`` argument is
# mapped to its index in this list.
ALL_MODELS = ["Tiny UNET", "Large UNET", "Large UNET (Bias Free)"]


# @interactive(
#     model=KeyboardControl(value_default=0, value_range=[
#         0, 99], keyup="pagedown", keydown="pageup")
# )
def audio_sep_inference(mixed, models, configs, model: int = 0, global_params={}):
    """Run one separation model on the mixed signal.

    Args:
        mixed: mixture tensor; a batch axis is added before the forward pass.
        models: list of separation models.
        configs: per-model config dicts, indexed like ``models``.
        model: model index (wrapped modulo ``len(models)``) or a name from
            ``ALL_MODELS``.
        global_params: shared pipeline state. ``"device"`` is read (default
            "cpu"). The chosen config's ``SHORT_NAME`` and ``ANNOTATIONS`` are
            written.

    Returns:
        ``(predicted_signal, pred_curve)``: the predicted tensor with the batch
        axis removed (left on the inference device) and a NumPy copy on CPU.
    """
    model_index = ALL_MODELS.index(model) if isinstance(model, str) else model
    assert isinstance(model_index, int)
    slot = model_index % len(models)
    chosen_model, chosen_config = models[slot], configs[slot]
    global_params[SHORT_NAME] = chosen_config.get(SHORT_NAME, "")
    global_params[ANNOTATIONS] = chosen_config.get(ANNOTATIONS, "")
    target_device = global_params.get("device", "cpu")
    with torch.no_grad():
        chosen_model.eval()
        chosen_model.to(target_device)
        batch = mixed.to(target_device).unsqueeze(0)
        # The model returns (signal, noise) estimates; only the signal is used.
        estimated_signal, _estimated_noise = chosen_model(batch)
    estimated_signal = estimated_signal.squeeze(0)
    return estimated_signal, estimated_signal.detach().cpu().numpy()
def compute_metrics(pred, sig, global_params={}):
    """Score the prediction against the clean reference.

    Writes a fresh dict to ``global_params["metrics"]`` with:
      * ``"MSE"``: mean of ``(clean - pred) ** 2``;
      * ``"SNR"``: ``10 * log10(sum(clean ** 2) / sum((clean - pred) ** 2))``
        in dB.
    Both values are torch tensors. ``pred`` is moved to the CPU first.
    """
    reference = sig[BUFFERS][CLEAN]
    scores = global_params["metrics"] = {}
    error = reference - pred.cpu()
    scores["MSE"] = torch.mean(error ** 2)
    scores["SNR"] = 10. * torch.log10(torch.sum(reference ** 2) / torch.sum(error ** 2))
def get_trim(sig, zoom, center, num_samples=300):
    """Compute the slice bounds and decimation step for a zoomed view of ``sig``.

    At ``zoom == 1`` the window spans the whole signal. Each zoom factor
    shrinks it to roughly ``len(sig) / zoom`` samples centred on
    ``center * len(sig)``. The window is clipped, not shifted, at the signal
    boundaries. The step decimates the window to about ``num_samples``
    points.

    Args:
        sig: any sized sequence (list, array, tensor).
        zoom: magnification factor (> 0).
        center: relative position of the window centre, in [0, 1].
        num_samples: target number of displayed points.

    Returns:
        ``(start_idx, end_idx, skip_factor)`` for ``sig[start_idx:end_idx:skip_factor]``.
    """
    length = len(sig)
    samples_per_point = length / num_samples
    span = int(num_samples / zoom * samples_per_point)
    half_span = span // 2
    middle = int(center * length)
    first = max(0, middle - half_span)
    last = min(length, middle + half_span)
    step = max(1, int(samples_per_point / zoom))
    return first, last, step
def zin(sig, zoom, center, num_samples=300):
    """Return a zoomed, decimated excerpt of ``sig`` as exactly ``num_samples`` points.

    The excerpt is ``sig[start:end:step]`` with bounds from get_trim. It is
    truncated to ``num_samples`` and zero-padded at the end when shorter.
    """
    first, last, step = get_trim(sig, zoom, center, num_samples=num_samples)
    excerpt = sig[first:last:step][:num_samples]
    padded = np.zeros(num_samples)
    padded[:len(excerpt)] = excerpt
    return padded
def visualize_audio(signal: dict, mixed_signal, predicted_signal, zoom=1, zoomy=0., center=0.5, global_params={}):
    """Build the waveform plot comparing noisy, mixed, predicted and clean signals.

    Args:
        signal: signal dict whose ``BUFFERS`` hold ``CLEAN`` and ``NOISY``.
        mixed_signal: mixture actually fed to the model.
        predicted_signal: model prediction as an array indexable with ``[0, :]``.
        zoom: horizontal zoom exponent (magnification ``1.5 ** zoom``).
        zoomy: vertical zoom exponent (y-limits ``±0.04 * 1.5 ** zoomy``).
        center: relative horizontal position of the view, in [0, 1].
        global_params: shared state. It reads ``"selected_audio"``,
            ``SHORT_NAME``, ``ANNOTATIONS``, ``"snr"``, ``"selected_info"``
            and ``"metrics"``, and writes ``"trim"`` with the visible window
            bounds.

    Returns:
        A ``Curve`` whose title shows the SNR, selection info and metrics.
    """
    selected = global_params.get("selected_audio", MIXED)
    short_name = global_params.get(SHORT_NAME, "")
    annotations = global_params.get(ANNOTATIONS, "")

    def marker(key):
        # "*" flags the curve matching global_params["selected_audio"].
        return "*" if selected == key else " "

    x_zoom = 1.5 ** zoom
    clean_track = signal[BUFFERS][CLEAN][0, :]
    start_idx, end_idx, _ = get_trim(clean_track, x_zoom, center)
    global_params["trim"] = dict(start=start_idx, end=end_idx)

    pred_curve = SingleCurve(
        y=zin(predicted_signal[0, :], x_zoom, center),
        style="g-",
        label=marker(PREDICTED) + f"predicted_{short_name} {annotations}",
    )
    clean_curve = SingleCurve(
        y=zin(clean_track, x_zoom, center),
        alpha=1.,
        style="k-",
        linewidth=0.9,
        label=marker(CLEAN) + "clean",
    )
    noisy_curve = SingleCurve(
        y=zin(signal[BUFFERS][NOISY][0, :], x_zoom, center),
        alpha=0.3,
        style="y--",
        linewidth=1,
        label=marker(NOISY) + "noisy",
    )
    mixed_curve = SingleCurve(
        y=zin(mixed_signal[0, :], x_zoom, center),
        style="r-",
        alpha=0.1,
        linewidth=2,
        label=marker(MIXED) + "mixed",
    )
    # true_mixed = SingleCurve(y=zin(signal[BUFFERS][MIXED][0, :], zval, center),
    #                          alpha=0.3, style="b-", linewidth=1, label="true mixed")

    header = f"SNR in {global_params.get('snr', np.nan):.1f} dB"
    if "selected_info" in global_params:
        header += f" | {global_params['selected_info']}"
    metric_chunks = []
    for metric_name, metric_value in global_params.get("metrics", {}).items():
        magnitude = abs(metric_value)
        # Scientific notation for very small or very large values.
        use_sci = magnitude < 1e-2 or magnitude > 1000
        formatted = f"{metric_value:.2e}" if use_sci else f"{metric_value:.2f}"
        metric_chunks.append(f" | {metric_name} {formatted}")
    title = header + "\n" + "".join(metric_chunks)
    # if global_params.get("premixed_snr", None) is not None:
    #     title += f"| Premixed SNR : {global_params['premixed_snr']:.1f} dB"

    y_limit = 0.04 * 1.5 ** zoomy
    return Curve(
        [noisy_curve, mixed_curve, pred_curve, clean_curve],
        ylim=[-y_limit, y_limit],
        xlabel="Time index",
        ylabel="Amplitude",
        title=title,
    )
def _buffers_at_learnt_rate(entry: dict):
    """Return ``entry[BUFFERS]`` at ``LEARNT_SAMPLING_RATE``, loading and resampling lazily.

    The resampled buffer is written back into ``entry`` together with the new
    rate, so ``entry[BUFFERS]`` and ``entry[SAMPLING_RATE]`` stay consistent
    on later calls.
    """
    if BUFFERS not in entry:
        load_buffers_custom(entry)
    if entry[SAMPLING_RATE] != LEARNT_SAMPLING_RATE:
        entry[BUFFERS] = resample(entry[BUFFERS], entry[SAMPLING_RATE], LEARNT_SAMPLING_RATE)
        entry[SAMPLING_RATE] = LEARNT_SAMPLING_RATE
    return entry[BUFFERS]


def signal_selector(signals, idx="Voice 1", idn=0, global_params={}):
    """Pick the signal (clean + noise) to process.

    Two input layouts are supported.

    * ``signals`` is a dict with ``CLEAN`` and ``NOISY`` lists of custom audio
      entries. Clean entry ``idx`` and noise entry ``idn`` are loaded lazily
      and resampled to ``LEARNT_SAMPLING_RATE``; the resampled buffers are
      cached back into the entries. They are then cropped to a common length
      rounded down to a multiple of 1024 samples.
    * Otherwise ``signals`` is a list of pre-paired signal dicts, and entry
      ``idx`` is returned with its buffers loaded lazily. The ``NAME`` is
      prefixed with "File=".

    Args:
        signals: one of the two layouts above.
        idx: voice selector, either a label like ``"Voice 3"`` or a plain int
            (wrapped modulo the number of entries).
        idn: noise index (dict layout only, wrapped modulo).
        global_params: shared state. It receives ``"selected_info"`` and
            ``SAMPLING_RATE``, plus ``"premixed_snr"`` for the list layout.

    Returns:
        The selected signal dict. In the list layout this is a shallow copy of
        the cached entry, sharing its ``BUFFERS``.
    """
    if isinstance(idx, str):
        idx = int(idx.split("Voice ")[-1])
    if isinstance(signals, dict):
        clean_sigs = signals[CLEAN]
        noise_sigs = signals[NOISY]
        clean = clean_sigs[idx % len(clean_sigs)]
        noise = noise_sigs[idn % len(noise_sigs)]
        # Storing the resampled buffer back (not just the new rate) matters: the
        # entries are reused, and a rate of 8000 paired with an unresampled
        # buffer would be played/processed at the wrong speed on reselection.
        cbuf = _buffers_at_learnt_rate(clean)
        nbuf = _buffers_at_learnt_rate(noise)
        # Common length rounded down to a multiple of 1024 samples
        # (presumably a model input constraint - TODO confirm).
        min_length = min(cbuf.shape[-1], nbuf.shape[-1])
        min_length -= min_length % 1024
        signal = {
            PATHS: {
                CLEAN: clean[PATHS],
                NOISY: noise[PATHS],
            },
            BUFFERS: {
                # Keep only index 0 of the second-to-last (presumably channel) axis.
                CLEAN: cbuf[..., :1, :min_length],
                NOISY: nbuf[..., :1, :min_length],
            },
            NAME: f"Clean={clean[NAME]} | Noise={noise[NAME]}",
            SAMPLING_RATE: LEARNT_SAMPLING_RATE,
        }
    else:
        # signals are loaded in CPU
        signal = signals[idx % len(signals)]
        if BUFFERS not in signal:
            load_buffers(signal)
        global_params["premixed_snr"] = signal.get("premixed_snr", None)
        # Shallow copy: rewriting NAME on the cached entry itself would stack
        # another "File=" prefix every time the same voice is selected again.
        signal = {**signal, NAME: f"File={signal[NAME]}"}
    global_params["selected_info"] = signal[NAME]
    global_params[SAMPLING_RATE] = signal[SAMPLING_RATE]
    return signal
def interactive_audio_separation_processing(signals, model_list, config_list):
    # Processing graph of the interactive demo: select a signal, remix clean and
    # noise at the requested SNR, run the selected separation model, score it,
    # then plot the curves and play the selected (trimmed) audio.
    # NOTE(review): this function is handed to interactive_pipeline (see
    # interactive_audio_separation_visualization), which presumably builds its
    # filter graph from this source. Keep one `outputs = filter(inputs)` call
    # per line, and avoid docstrings or other non-call statements - confirm
    # before restructuring.
    sig = signal_selector(signals)  # chosen signal dict (clean + noise buffers)
    mixed = remix(sig)  # clean + scaled noise at the target SNR
    # sig, mixed = augment(sig, mixed)
    select_device()  # stores the inference device in global_params
    pred, pred_curve = audio_sep_inference(mixed, model_list, config_list)  # tensor + NumPy copy
    compute_metrics(pred, sig)  # MSE / SNR into global_params["metrics"]
    # presumably picks the track to play (clean / noisy / mixed / predicted) - TODO confirm
    sound = audio_selector(sig, mixed, pred)
    curve = visualize_audio(sig, mixed, pred_curve)  # also writes global_params["trim"]
    # presumably crops the sound to the window stored in global_params["trim"]
    trimmed_sound = audio_trim(sound)
    audio_player(trimmed_sound)
    return curve
def interactive_audio_separation_visualization(
    all_signals: List[dict],
    model_list: List[torch.nn.Module],
    config_list: List[dict],
    gui="gradio"
):
    """Launch the interactive separation viewer on the given signals and models.

    Wraps interactive_audio_separation_processing with ``interactive_pipeline``
    (caching and audio enabled) using the requested GUI backend, then runs it.
    """
    decorate = interactive_pipeline(gui=gui, cache=True, audio=True)
    app = decorate(interactive_audio_separation_processing)
    app(all_signals, model_list, config_list)
def visualization(
    all_signals: List[dict],
    model_list: List[torch.nn.Module],
    config_list: List[dict],
    device="cuda"
):
    """Static (non-interactive) mode: plot each signal and every model's prediction.

    For every signal, buffers are loaded on CPU if needed. The clean and noise
    waveforms are plotted together with each model's predicted signal for the
    stored ``MIXED`` buffer, and one figure is shown per signal.

    Inference now runs under ``torch.no_grad()``. It previously ran with
    autograd enabled, building a graph for every forward pass only to
    ``detach`` the result.
    """
    for signal in all_signals:
        if BUFFERS not in signal:
            load_buffers(signal, device="cpu")
        buffers = signal[BUFFERS]
        curves = [
            SingleCurve(y=buffers[CLEAN][0, :], label="clean"),
            SingleCurve(y=buffers[NOISY][0, :], label="noise", alpha=0.3),
        ]
        # Same input for every model: move it to the device once.
        mixed_batch = buffers[MIXED].to(device).unsqueeze(0)
        with torch.no_grad():
            for config, model in zip(config_list, model_list):
                short_name = config.get(SHORT_NAME, "unknown")
                predicted_signal, _predicted_noise = model(mixed_batch)
                curves.append(SingleCurve(
                    y=predicted_signal.squeeze(0)[0, :].detach().cpu().numpy(),
                    label=f"predicted_{short_name}",
                ))
        Curve(curves).show()
def parse_command_line(parser: Batch = None, gradio_demo=True) -> argparse.ArgumentParser:
    """Return the CLI parser: the gradio-demo variant by default, else the generic one."""
    builder = parse_command_line_gradio if gradio_demo else parse_command_line_generic
    return builder(parser)
def parse_command_line_gradio(parser: Batch = None, gradio_demo=True) -> argparse.ArgumentParser:
    """Add the audio-separation options used by the gradio demo.

    Builds on ``parse_command_line_audio_load()`` when no parser is given.
    Experiments default to ``[4, 1004, 3001]`` and the GUI defaults to
    "gradio". ``--interactive`` has ``default=True`` with ``store_true``, so
    it is always on with this parser. Device choices include "cuda" only when
    torch sees a GPU. ``gradio_demo`` is unused.
    """
    parser = parse_command_line_audio_load() if parser is None else parser
    cuda_ok = torch.cuda.is_available()
    group = parser.add_argument_group("Audio separation visualization")
    group.add_argument("-e", "--experiments", type=int, nargs="+", default=[4, 1004, 3001],
                       help="Experiment ids to be inferred sequentially")
    group.add_argument("-p", "--interactive", default=True, action="store_true",
                       help="Play = Interactive mode")
    group.add_argument("-m", "--model-root", type=str, default=EXPERIMENT_STORAGE_ROOT)
    group.add_argument("-d", "--device", type=str,
                       default="cuda" if cuda_ok else "cpu",
                       choices=["cpu", "cuda"] if cuda_ok else ["cpu"])
    group.add_argument("-gui", "--gui", type=str, default="gradio",
                       choices=["qt", "mpl", "gradio"])
    return parser
def parse_command_line_generic(parser: Batch = None, gradio_demo=True) -> argparse.ArgumentParser:
    """Add the audio-separation options for the general-purpose CLI.

    Builds on ``parse_command_line_audio_load()`` when no parser is given.
    Unlike the gradio variant, ``--experiments`` is required,
    ``--interactive`` is an opt-in flag and the GUI defaults to "qt". Device
    choices include "cuda" only when torch sees a GPU. ``gradio_demo`` is
    unused.
    """
    parser = parse_command_line_audio_load() if parser is None else parser
    cuda_ok = torch.cuda.is_available()
    group = parser.add_argument_group("Audio separation visualization")
    group.add_argument("-e", "--experiments", type=int, nargs="+", required=True,
                       help="Experiment ids to be inferred sequentially")
    group.add_argument("-p", "--interactive", action="store_true",
                       help="Play = Interactive mode")
    group.add_argument("-m", "--model-root", type=str, default=EXPERIMENT_STORAGE_ROOT)
    group.add_argument("-d", "--device", type=str,
                       default="cuda" if cuda_ok else "cpu",
                       choices=["cpu", "cuda"] if cuda_ok else ["cpu"])
    group.add_argument("-gui", "--gui", type=str, default="qt",
                       choices=["qt", "mpl", "gradio"])
    return parser
def main(argv: List[str]):
    """Paired signals and noise in folders.

    Loads every experiment listed in ``--experiments`` from ``--model-root``
    and loads the paired audio through ``Batch``. It then opens the
    interactive viewer, or plots static curves when ``--interactive`` is off.
    With the parser built by ``parse_command_line()`` (gradio variant),
    ``--interactive`` is always True.

    Raises:
        FileNotFoundError: if an experiment's output folder does not exist
            under ``--model-root``.
    """
    batch = Batch(argv)
    batch.set_io_description(
        input_help='input audio files',
        output_help=argparse.SUPPRESS
    )
    batch.set_multiprocessing_enabled(False)
    parser = parse_command_line()
    args = batch.parse_args(parser)
    device = args.device
    model_dir = Path(args.model_root)
    models_list = []
    config_list = []
    logging.info("Loading experiments models %s", args.experiments)
    for exp in args.experiments:
        short_name, model, config, _dl = get_experience(exp)
        _, exp_dir = get_output_folder(
            config, root_dir=model_dir, override=False)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not exp_dir.exists():
            raise FileNotFoundError(
                f"Experiment {short_name} does not exist in {model_dir}")
        model.eval()
        model.to(device)
        model, _optimizer, _epoch, config = load_checkpoint(
            model, exp_dir, epoch=None, device=args.device)
        config[SHORT_NAME] = short_name
        models_list.append(model)
        config_list.append(config)
    logging.info("Load audio buffers:")
    all_signals = batch.run(audio_loading_batch)
    if not args.interactive:
        visualization(all_signals, models_list, config_list, device=device)
    else:
        interactive_audio_separation_visualization(
            all_signals, models_list, config_list, gui=args.gui)
def main_custom(argv: List[str]):
    """Handle custom noise and custom signals.

    Clean files (``-s/--signal``) and noise files (``-n/--noise``) are loaded
    as two separate lists, stored under ``CLEAN`` and ``NOISY``. Any clean
    signal can then be mixed with any noise in the interactive viewer, which
    handles this in the dict branch of signal_selector.

    Raises:
        FileNotFoundError: if an experiment's output folder does not exist
            under ``--model-root``.
    """
    parser = parse_command_line()
    parser.add_argument("-s", "--signal", type=str, required=True,
                        nargs="+", help="Signal to be preloaded")
    parser.add_argument("-n", "--noise", type=str, required=True,
                        nargs="+", help="Noise to be preloaded")
    args = parser.parse_args(argv)
    device = args.device
    model_dir = Path(args.model_root)
    models_list = []
    config_list = []
    logging.info("Loading experiments models %s", args.experiments)
    for exp in args.experiments:
        short_name, model, config, _dl = get_experience(exp)
        _, exp_dir = get_output_folder(
            config, root_dir=model_dir, override=False)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not exp_dir.exists():
            raise FileNotFoundError(
                f"Experiment {short_name} does not exist in {model_dir}")
        model.eval()
        model.to(device)
        model, _optimizer, _epoch, config = load_checkpoint(
            model, exp_dir, epoch=None, device=args.device)
        config[SHORT_NAME] = short_name
        models_list.append(model)
        config_list.append(config)
    all_signals = {}
    # Load the clean list and the noise list with separate Batch runs.
    for args_paths, key in zip([args.signal, args.noise], [CLEAN, NOISY]):
        new_argv = ["-i"] + args_paths
        if args.preload:
            new_argv += ["--preload"]
        batch = Batch(new_argv)
        new_parser = parse_command_line_generic_audio_load()
        batch.set_io_description(
            input_help=argparse.SUPPRESS,  # 'input audio files',
            output_help=argparse.SUPPRESS
        )
        batch.set_multiprocessing_enabled(False)
        _ = batch.parse_args(new_parser)
        all_signals[key] = batch.run(generic_audio_loading_batch)
    interactive_audio_separation_visualization(
        all_signals, models_list, config_list, gui=args.gui)