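"""Interactive audio source separation demo.

Mix a clean signal with noise at a user-controlled SNR, run one or several
pre-trained separation models on the mixture, and inspect the result as
curves and audio through an interactive_pipe GUI (Qt, Matplotlib or Gradio).
"""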
from batch_processing import Batch
import argparse
from pathlib import Path
from gyraudio.audio_separation.experiment_tracking.experiments import get_experience
from gyraudio.audio_separation.experiment_tracking.storage import get_output_folder, load_checkpoint
from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT
from gyraudio.audio_separation.properties import (
SHORT_NAME, CLEAN, NOISY, MIXED, PREDICTED, ANNOTATIONS, PATHS, BUFFERS, SAMPLING_RATE, NAME
)
import torch
from gyraudio.audio_separation.visualization.pre_load_audio import (
parse_command_line_audio_load, load_buffers, audio_loading_batch)
from gyraudio.audio_separation.visualization.pre_load_custom_audio import (
parse_command_line_generic_audio_load, generic_audio_loading_batch,
load_buffers_custom
)
from torchaudio.functional import resample
from typing import List
import numpy as np
import logging
from interactive_pipe.data_objects.curves import Curve, SingleCurve
from interactive_pipe import interactive, KeyboardControl, Control
from interactive_pipe import interactive_pipeline
from gyraudio.audio_separation.visualization.audio_player import audio_selector, audio_trim, audio_player
default_device = "cuda" if torch.cuda.is_available() else "cpu"
LEARNT_SAMPLING_RATE = 8000


@interactive(
snr=(0., [-10., 10.], "SNR [dB]")
)
def remix(signals, snr=0., global_params={}):
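    """Mix the clean signal with the noise so that the resulting
    signal-to-noise ratio is `snr` dB: the noise is rescaled by
    10**(-snr/20) * ||clean|| / ||noise||.
    """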
signal = signals[BUFFERS][CLEAN]
noisy = signals[BUFFERS][NOISY]
alpha = 10 ** (-snr / 20) * torch.norm(signal) / torch.norm(noisy)
mixed_signal = signal + alpha * noisy
global_params["snr"] = snr
return mixed_signal


@interactive(std_dev=Control(0., value_range=[0., 0.1], name="extra noise std", step=0.0001),
amplify=(1., [0., 10.], "amplification of everything"))
def augment(signals, mixed, std_dev=0., amplify=1.):
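    """Amplify all buffers (in-place) and add white Gaussian noise of
    standard deviation `std_dev` to the mixed signal.
    """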
signals[BUFFERS][MIXED] *= amplify
signals[BUFFERS][NOISY] *= amplify
signals[BUFFERS][CLEAN] *= amplify
    mixed = mixed * amplify + torch.randn_like(mixed) * std_dev
return signals, mixed


# @interactive(
# device=("cuda", ["cpu", "cuda"]
# ) if default_device == "cuda" else ("cpu", ["cpu"])
# )
def select_device(device=default_device, global_params={}):
global_params["device"] = device


# @interactive(
# model=KeyboardControl(value_default=0, value_range=[
# 0, 99], keyup="pagedown", keydown="pageup")
# )
ALL_MODELS = ["Tiny UNET", "Large UNET", "Large UNET (Bias Free)"]
@interactive(
model=(ALL_MODELS[-1], ALL_MODELS, "Model selection")
)
def audio_sep_inference(mixed, models, configs, model=0, global_params={}):
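    """Run the selected separation model on the mixed signal.

    `model` arrives either as a dropdown label (one of ALL_MODELS) or as an
    index. Returns the predicted clean signal both as a torch tensor and as
    a numpy array for plotting.
    """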
if isinstance(model, str):
model = ALL_MODELS.index(model)
assert isinstance(model, int)
selected_model = models[model % len(models)]
config = configs[model % len(models)]
short_name = config.get(SHORT_NAME, "")
annotations = config.get(ANNOTATIONS, "")
global_params[SHORT_NAME] = short_name
global_params[ANNOTATIONS] = annotations
device = global_params.get("device", "cpu")
with torch.no_grad():
selected_model.eval()
selected_model.to(device)
predicted_signal, predicted_noise = selected_model(
mixed.to(device).unsqueeze(0))
predicted_signal = predicted_signal.squeeze(0)
pred_curve = predicted_signal.detach().cpu().numpy()
return predicted_signal, pred_curve


def compute_metrics(pred, sig, global_params={}):
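    """Store MSE and output SNR (dB) between prediction and clean target
    in global_params["metrics"].
    """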
METRICS = "metrics"
target = sig[BUFFERS][CLEAN]
global_params[METRICS] = {}
global_params[METRICS]["MSE"] = torch.mean((target-pred.cpu())**2)
global_params[METRICS]["SNR"] = 10. * \
torch.log10(torch.sum(target**2)/torch.sum((target-pred.cpu())**2))


def get_trim(sig, zoom, center, num_samples=300):
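    """Map (zoom, center) to (start, end, skip) indices: a window of roughly
    N/zoom samples around center*N, decimated so that about num_samples
    points remain.
    """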
N = len(sig)
native_ds = N/num_samples
center_idx = int(center*N)
window = int(num_samples/zoom*native_ds)
start_idx = max(0, center_idx - window//2)
end_idx = min(N, center_idx + window//2)
skip_factor = max(1, int(native_ds/zoom))
return start_idx, end_idx, skip_factor


def zin(sig, zoom, center, num_samples=300):
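    """Return a fixed-length (num_samples) zoomed view of `sig`,
    zero-padded when the trimmed slice is shorter.
    """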
start_idx, end_idx, skip_factor = get_trim(
sig, zoom, center, num_samples=num_samples)
out = np.zeros(num_samples)
trimmed = sig[start_idx:end_idx:skip_factor]
out[:len(trimmed)] = trimmed[:num_samples]
return out


@interactive(
center=KeyboardControl(value_default=0.5, value_range=[
0., 1.], step=0.01, keyup="6", keydown="4", name="Trim (center)"),
zoom=KeyboardControl(value_default=0., value_range=[
0., 15.], step=1, keyup="+", keydown="-", name="Trim (zoom)"),
# zoomy=KeyboardControl(
# value_default=0., value_range=[-15., 15.], step=1, keyup="up", keydown="down")
)
def visualize_audio(signal: dict, mixed_signal, predicted_signal, zoom=1, zoomy=0., center=0.5, global_params={}):
"""Create curves
"""
selected = global_params.get("selected_audio", MIXED)
short_name = global_params.get(SHORT_NAME, "")
annotations = global_params.get(ANNOTATIONS, "")
zval = 1.5**zoom
start_idx, end_idx, _skip_factor = get_trim(
signal[BUFFERS][CLEAN][0, :], zval, center)
global_params["trim"] = dict(start=start_idx, end=end_idx)
pred = SingleCurve(y=zin(predicted_signal[0, :], zval, center),
style="g-", label=("*" if selected == PREDICTED else " ")+f"predicted_{short_name} {annotations}")
clean = SingleCurve(y=zin(signal[BUFFERS][CLEAN][0, :], zval, center),
alpha=1.,
style="k-",
linewidth=0.9,
label=("*" if selected == CLEAN else " ")+"clean")
noisy = SingleCurve(y=zin(signal[BUFFERS][NOISY][0, :], zval, center),
alpha=0.3,
style="y--",
linewidth=1,
label=("*" if selected == NOISY else " ") + "noisy"
)
mixed = SingleCurve(y=zin(mixed_signal[0, :], zval, center), style="r-",
alpha=0.1,
linewidth=2,
label=("*" if selected == MIXED else " ") + "mixed")
# true_mixed = SingleCurve(y=zin(signal[BUFFERS][MIXED][0, :], zval, center),
# alpha=0.3, style="b-", linewidth=1, label="true mixed")
curves = [noisy, mixed, pred, clean]
title = f"SNR in {global_params.get('snr', np.nan):.1f} dB"
if "selected_info" in global_params:
title += f" | {global_params['selected_info']}"
title += "\n"
for metric_name, metric_value in global_params.get("metrics", {}).items():
title += f" | {metric_name} "
title += f"{metric_value:.2e}" if (abs(metric_value) < 1e-2 or abs(metric_value)
> 1000) else f"{metric_value:.2f}"
# if global_params.get("premixed_snr", None) is not None:
# title += f"| Premixed SNR : {global_params['premixed_snr']:.1f} dB"
return Curve(curves, ylim=[-0.04 * 1.5 ** zoomy, 0.04 * 1.5 ** zoomy], xlabel="Time index", ylabel="Amplitude", title=title)


@interactive(
idx=("Voice 1", ["Voice 1", "Voice 2",
"Voice 3", "Voice 4"], "Clean signal"),
# idx=KeyboardControl(value_default=0, value_range=[
# 0, 1000], modulo=True, keyup="8", keydown="2", name="clean signal index"),
# idn=KeyboardControl(value_default=0, value_range=[
# 0, 1000], modulo=True, keyup="9", keydown="3", name="noisy signal index")
)
def signal_selector(signals, idx="Voice 1", idn=0, global_params={}):
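    """Select the audio to process.

    Two layouts are supported: a dict of clean/noise file lists (custom
    audio, resampled to 8 kHz and cropped to a common length) or a flat
    list of premixed signals. Buffers are loaded lazily on first access.
    """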
    # "Voice N" labels are 1-based: map to a 0-based list index
    idx = int(idx.split("Voice ")[-1]) - 1
if isinstance(signals, dict):
clean_sigs = signals[CLEAN]
clean = clean_sigs[idx % len(clean_sigs)]
if BUFFERS not in clean:
load_buffers_custom(clean)
noise_sigs = signals[NOISY]
noise = noise_sigs[idn % len(noise_sigs)]
if BUFFERS not in noise:
load_buffers_custom(noise)
cbuf, nbuf = clean[BUFFERS], noise[BUFFERS]
        if clean[SAMPLING_RATE] != LEARNT_SAMPLING_RATE:
            cbuf = resample(cbuf, clean[SAMPLING_RATE], LEARNT_SAMPLING_RATE)
            # Write the resampled buffer back so the cached dict stays
            # consistent with the updated sampling rate on the next call
            clean[BUFFERS] = cbuf
            clean[SAMPLING_RATE] = LEARNT_SAMPLING_RATE
        if noise[SAMPLING_RATE] != LEARNT_SAMPLING_RATE:
            nbuf = resample(nbuf, noise[SAMPLING_RATE], LEARNT_SAMPLING_RATE)
            noise[BUFFERS] = nbuf
            noise[SAMPLING_RATE] = LEARNT_SAMPLING_RATE
        min_length = min(cbuf.shape[-1], nbuf.shape[-1])
        # Crop to a multiple of 1024 samples, presumably so the models'
        # downsampling stages divide the length evenly
        min_length = min_length - min_length % 1024
signal = {
PATHS: {
CLEAN: clean[PATHS],
NOISY: noise[PATHS]
},
            BUFFERS: {
                # Keep the first channel only, cropped to the common length
                CLEAN: cbuf[..., :1, :min_length],
                NOISY: nbuf[..., :1, :min_length],
            },
NAME: f"Clean={clean[NAME]} | Noise={noise[NAME]}",
SAMPLING_RATE: LEARNT_SAMPLING_RATE
}
else:
# signals are loaded in CPU
signal = signals[idx % len(signals)]
if BUFFERS not in signal:
load_buffers(signal)
global_params["premixed_snr"] = signal.get("premixed_snr", None)
signal[NAME] = f"File={signal[NAME]}"
global_params["selected_info"] = signal[NAME]
global_params[SAMPLING_RATE] = signal[SAMPLING_RATE]
return signal


def interactive_audio_separation_processing(signals, model_list, config_list):
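    """Pipeline definition: signal selection -> remix at chosen SNR ->
    device selection -> model inference -> metrics -> audio player and
    curve display.
    """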
sig = signal_selector(signals)
mixed = remix(sig)
# sig, mixed = augment(sig, mixed)
select_device()
pred, pred_curve = audio_sep_inference(mixed, model_list, config_list)
compute_metrics(pred, sig)
sound = audio_selector(sig, mixed, pred)
curve = visualize_audio(sig, mixed, pred_curve)
trimmed_sound = audio_trim(sound)
audio_player(trimmed_sound)
return curve


def interactive_audio_separation_visualization(
all_signals: List[dict],
model_list: List[torch.nn.Module],
config_list: List[dict],
gui="gradio"
):
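    """Wrap the processing graph into an interactive_pipe GUI and launch it."""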
interactive_pipeline(gui=gui, cache=True, audio=True)(
interactive_audio_separation_processing
)(
all_signals, model_list, config_list
)


def visualization(
all_signals: List[dict],
model_list: List[torch.nn.Module],
config_list: List[dict],
device="cuda"
):
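    """Non-interactive fallback: for each signal, plot the clean and noisy
    curves together with each model's prediction on the premixed buffer.
    """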
for signal in all_signals:
if BUFFERS not in signal:
load_buffers(signal, device="cpu")
clean = SingleCurve(y=signal[BUFFERS][CLEAN][0, :], label="clean")
noisy = SingleCurve(y=signal[BUFFERS][NOISY]
[0, :], label="noise", alpha=0.3)
curves = [clean, noisy]
for config, model in zip(config_list, model_list):
short_name = config.get(SHORT_NAME, "unknown")
            with torch.no_grad():  # inference only, no gradient tracking needed
                predicted_signal, _predicted_noise = model(
                    signal[BUFFERS][MIXED].to(device).unsqueeze(0))
predicted = SingleCurve(y=predicted_signal.squeeze(0)[0, :].detach().cpu().numpy(),
label=f"predicted_{short_name}")
curves.append(predicted)
Curve(curves).show()


def parse_command_line(parser: Batch = None, gradio_demo=True) -> argparse.ArgumentParser:
if gradio_demo:
parser = parse_command_line_gradio(parser)
else:
parser = parse_command_line_generic(parser)
return parser


def parse_command_line_gradio(parser: Batch = None) -> argparse.ArgumentParser:
if parser is None:
parser = parse_command_line_audio_load()
default_device = "cuda" if torch.cuda.is_available() else "cpu"
iparse = parser.add_argument_group("Audio separation visualization")
iparse.add_argument("-e", "--experiments", type=int, nargs="+", default=[4, 1004, 3001,],
help="Experiment ids to be inferred sequentially")
iparse.add_argument("-p", "--interactive", default=True,
action="store_true", help="Play = Interactive mode")
iparse.add_argument("-m", "--model-root", type=str,
default=EXPERIMENT_STORAGE_ROOT)
iparse.add_argument("-d", "--device", type=str, default=default_device,
choices=["cpu", "cuda"] if default_device == "cuda" else ["cpu"])
iparse.add_argument("-gui", "--gui", type=str,
default="gradio", choices=["qt", "mpl", "gradio"])
return parser


def parse_command_line_generic(parser: Batch = None) -> argparse.ArgumentParser:
if parser is None:
parser = parse_command_line_audio_load()
default_device = "cuda" if torch.cuda.is_available() else "cpu"
iparse = parser.add_argument_group("Audio separation visualization")
iparse.add_argument("-e", "--experiments", type=int, nargs="+", required=True,
help="Experiment ids to be inferred sequentially")
iparse.add_argument("-p", "--interactive",
action="store_true", help="Play = Interactive mode")
iparse.add_argument("-m", "--model-root", type=str,
default=EXPERIMENT_STORAGE_ROOT)
iparse.add_argument("-d", "--device", type=str, default=default_device,
choices=["cpu", "cuda"] if default_device == "cuda" else ["cpu"])
iparse.add_argument("-gui", "--gui", type=str,
default="qt", choices=["qt", "mpl", "gradio"])
return parser


def main(argv: List[str]):
"""Paired signals and noise in folders"""
batch = Batch(argv)
batch.set_io_description(
input_help='input audio files',
output_help=argparse.SUPPRESS
)
batch.set_multiprocessing_enabled(False)
parser = parse_command_line()
args = batch.parse_args(parser)
device = args.device
models_list = []
config_list = []
logging.info(f"Loading experiments models {args.experiments}")
for exp in args.experiments:
model_dir = Path(args.model_root)
short_name, model, config, _dl = get_experience(exp)
_, exp_dir = get_output_folder(
config, root_dir=model_dir, override=False)
        assert exp_dir.exists(), \
            f"Experiment {short_name} does not exist in {model_dir}"
model.eval()
model.to(device)
model, __optimizer, epoch, config = load_checkpoint(
model, exp_dir, epoch=None, device=args.device)
config[SHORT_NAME] = short_name
models_list.append(model)
config_list.append(config)
logging.info("Load audio buffers:")
all_signals = batch.run(audio_loading_batch)
if not args.interactive:
visualization(all_signals, models_list, config_list, device=device)
else:
interactive_audio_separation_visualization(
all_signals, models_list, config_list, gui=args.gui)


def main_custom(argv: List[str]):
"""Handle custom noise and custom signals
"""
parser = parse_command_line()
parser.add_argument("-s", "--signal", type=str, required=True,
nargs="+", help="Signal to be preloaded")
parser.add_argument("-n", "--noise", type=str, required=True,
nargs="+", help="Noise to be preloaded")
args = parser.parse_args(argv)
device = args.device
models_list = []
config_list = []
logging.info(f"Loading experiments models {args.experiments}")
for exp in args.experiments:
model_dir = Path(args.model_root)
short_name, model, config, _dl = get_experience(exp)
_, exp_dir = get_output_folder(
config, root_dir=model_dir, override=False)
        assert exp_dir.exists(), \
            f"Experiment {short_name} does not exist in {model_dir}"
model.eval()
model.to(device)
model, __optimizer, epoch, config = load_checkpoint(
model, exp_dir, epoch=None, device=args.device)
config[SHORT_NAME] = short_name
models_list.append(model)
config_list.append(config)
    all_signals = {}
    # Run two separate loading batches: one for clean signals, one for noises
    for args_paths, key in zip([args.signal, args.noise], [CLEAN, NOISY]):
new_argv = ["-i"] + args_paths
if args.preload:
new_argv += ["--preload"]
batch = Batch(new_argv)
new_parser = parse_command_line_generic_audio_load()
batch.set_io_description(
input_help=argparse.SUPPRESS, # 'input audio files',
output_help=argparse.SUPPRESS
)
batch.set_multiprocessing_enabled(False)
_ = batch.parse_args(new_parser)
all_signals[key] = batch.run(generic_audio_loading_batch)
interactive_audio_separation_visualization(
all_signals, models_list, config_list, gui=args.gui)