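"""Coarse-to-fine evaluation script.

Runs several VampNet checkpoints over random excerpts of a held-out audio
set, decodes reconstructed / sampled / argmax outputs through a LAC VQ-VAE,
and zips the resulting audio for listening tests.
"""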
from collections import defaultdict
from pathlib import Path
import random

import argbind
import torch
from audiotools import AudioSignal
from tqdm import tqdm

def coarse2fine_infer(
    signal,
    model,
    vqvae,
    device,
    signal_window=3,
    signal_hop=1.5,
    max_excerpts=25,
):
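    """Collect three decodes (straight reconstruction, multi-step Gumbel
    sampling, single-step argmax) of the VQ-VAE codes for up to
    `max_excerpts` random windows of `signal`. Returns a dict mapping
    decode type to a list of CPU AudioSignals.
    """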
    output = defaultdict(list)
    model.to(device)

    # Split the signal into overlapping windows (signal_window seconds,
    # hopped by signal_hop seconds).
    windows = list(signal.clone().windows(signal_window, signal_hop))

    # Skip the first window, which is mostly zero padded, then pick up to
    # max_excerpts of the remaining windows at random.
    windows = windows[1:]
    random.shuffle(windows)

    for w in windows[:max_excerpts]:
        with torch.no_grad():
            w = w.to(device)

            # Encode the excerpt into discrete codes with the VQ-VAE.
            z = vqvae.encode(w.audio_data, w.sample_rate)["codes"]

            # Straight reconstruction from the codes (no sampling).
            output["reconstructed"].append(model.to_signal(z, vqvae).cpu())

            # Mask everything except the conditioning codebooks, which
            # stay fixed while the rest is regenerated.
            mask = torch.ones_like(z)
            mask[:, : model.n_conditioning_codebooks, :] = 0
output["sampled"].append(model.sample( | |
codec=vqvae, | |
time_steps=z.shape[-1], | |
sampling_steps=12, | |
start_tokens=z, | |
mask=mask, | |
temperature=0.85, | |
top_k=None, | |
sample="gumbel", | |
typical_filtering=True, | |
return_signal=True | |
).cpu()) | |
output["argmax"].append(model.sample( | |
codec=vqvae, | |
time_steps=z.shape[-1], | |
sampling_steps=1, | |
start_tokens=z, | |
mask=mask, | |
temperature=1.0, | |
top_k=None, | |
sample="argmax", | |
typical_filtering=True, | |
return_signal=True | |
).cpu()) | |

    return output

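# A minimal standalone sketch of calling coarse2fine_infer; the checkpoint
# and audio paths below are illustrative, not shipped with this script:
#
#     model = VampNet.load("ckpt/mask/best/vampnet/weights.pth").eval()
#     vqvae = LAC.load("ckpt/wav2wav.pth").to("cuda").eval()
#     sig = AudioSignal("example.wav").resample(vqvae.sample_rate)
#     out = coarse2fine_infer(sig, model, vqvae, "cuda")
#     out["sampled"][0].write("sampled_excerpt.wav")
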
# Expose main's keyword arguments on the CLI; argbind needs main to be
# bound for the parse_args/scope calls below to have any effect.
@argbind.bind(without_prefix=True)
def main(
    sources=[
        "/home/hugo/data/spotdl/audio/val", "/home/hugo/data/spotdl/audio/test"
    ],
    audio_ext="mp3",
    exp_name="noise_mode",
    model_paths=[
        "ckpt/mask/best/vampnet/weights.pth",
        "ckpt/random/best/vampnet/weights.pth",
    ],
    model_keys=[
        "noise_mode=mask",
        "noise_mode=random",
    ],
    vqvae_path="ckpt/wav2wav.pth",
    device="cuda",
):
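    """Run coarse2fine_infer for every model over every audio file found
    under `sources`, and zip all resulting signals into
    `{exp_name}-results.zip`.
    """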
    from vampnet.modules.transformer import VampNet
    from lac.model.lac import LAC
    from audiotools.post import audio_zip

    # Load one VampNet checkpoint per key.
    models = {
        k: VampNet.load(p) for k, p in zip(model_keys, model_paths)
    }
    for model in models.values():
        model.eval()
    print(f"Loaded {len(models)} models.")

    vqvae = LAC.load(vqvae_path)
    vqvae.to(device)
    vqvae.eval()
    print("Loaded VQVAE.")

    audio_dict = defaultdict(list)
    for source in sources:
        print(f"Processing {source}...")
        for path in tqdm(list(Path(source).glob(f"**/*.{audio_ext}"))):
            sig = AudioSignal(path)
            sig.resample(vqvae.sample_rate).normalize(-24).ensure_max_of_audio(1.0)

            for model_key, model in models.items():
                out = coarse2fine_infer(sig, model, vqvae, device)
                for k in out:
                    audio_dict[f"{model_key}-{k}"].extend(out[k])

    audio_zip(audio_dict, f"{exp_name}-results.zip")


if __name__ == "__main__":
    args = argbind.parse_args()
    with argbind.scope(args):
        main()
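
# Example invocation (the filename eval_c2f.py is illustrative; any of
# main's keyword arguments can be overridden on the command line):
#
#     python eval_c2f.py --exp_name noise_mode --device cuda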