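"""Coarse-to-fine evaluation script.

Loads one or more coarse-to-fine VampNet checkpoints together with a LAC codec,
runs them over random excerpts of every audio file found under `sources`, and
writes the reconstructed / sampled / argmax outputs for each excerpt to
`<output_dir>/<exp_name>-samples/<track>/`.
"""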
from audiotools import AudioSignal
import torch
from pathlib import Path
import argbind
from tqdm import tqdm
import random
from typing import List
from collections import defaultdict
def coarse2fine_infer(
    signal,
    model,
    vqvae,
    device,
    signal_window=3,
    signal_hop=1.5,
    max_excerpts=20,
):
    """Run coarse-to-fine inference on random excerpts of a signal.

    Returns a dict mapping "reconstructed", "sampled", and "argmax" to lists
    of decoded AudioSignals, one entry per excerpt.
    """
    output = defaultdict(list)

    # split the signal into overlapping windows
    windows = list(signal.clone().windows(signal_window, signal_hop))
    windows = windows[1:]  # skip the first window since it's half zero-padded
    random.shuffle(windows)

    model.to(device)
    for w in windows[:max_excerpts]:
        with torch.no_grad():
            # encode the excerpt into codec tokens
            w = w.to(device)
            z = vqvae.encode(w.audio_data, w.sample_rate)["codes"]

            # decode the tokens directly as a reconstruction baseline
            output["reconstructed"].append(model.to_signal(z, vqvae).cpu())

            # mask everything except the conditioning codebooks
            mask = torch.ones_like(z)
            mask[:, :model.n_conditioning_codebooks, :] = 0
            output["sampled"].append(model.sample(
                codec=vqvae,
                time_steps=z.shape[-1],
                sampling_steps=12,
                start_tokens=z,
                mask=mask,
                temperature=0.85,
                top_k=None,
                sample="gumbel",
                typical_filtering=True,
                return_signal=True,
            ).cpu())

            output["argmax"].append(model.sample(
                codec=vqvae,
                time_steps=z.shape[-1],
                sampling_steps=1,
                start_tokens=z,
                mask=mask,
                temperature=1.0,
                top_k=None,
                sample="argmax",
                typical_filtering=True,
                return_signal=True,
            ).cpu())

    return output
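# Example usage of coarse2fine_infer (a minimal sketch; the checkpoint and audio
# paths below are placeholders, and VampNet / LAC are imported inside main()):
#
#   model = VampNet.load("runs/c2f-exp/ckpt/mask/best/vampnet/weights.pth")
#   vqvae = LAC.load("runs/codec-ckpt/codec.pth")
#   vqvae.to("cuda")
#   sig = AudioSignal("some_song.mp3")
#   out = coarse2fine_infer(sig, model, vqvae, device="cuda")
#   out["sampled"][0].write("sampled-0.wav")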
# expose main's keyword arguments as command-line flags via argbind
@argbind.bind(without_prefix=True)
def main(
    sources: List[str] = [
        "/data/spotdl/audio/val", "/data/spotdl/audio/test"
    ],
    audio_ext: str = "mp3",
    exp_name: str = "noise_mode",
    model_paths: List[str] = [
        "runs/c2f-exp-03.22.23/ckpt/mask/best/vampnet/weights.pth",
        "runs/c2f-exp-03.22.23/ckpt/random/best/vampnet/weights.pth",
    ],
    model_keys: List[str] = [
        "mask",
        "random",
    ],
    vqvae_path: str = "runs/codec-ckpt/codec.pth",
    device: str = "cuda",
    output_dir: str = ".",
):
    from vampnet.modules.transformer import VampNet
    from lac.model.lac import LAC
    from audiotools.post import audio_zip

    models = {
        k: VampNet.load(p) for k, p in zip(model_keys, model_paths)
    }
    for model in models.values():
        model.eval()
    print(f"Loaded {len(models)} models.")

    vqvae = LAC.load(vqvae_path)
    vqvae.to(device)
    vqvae.eval()
    print("Loaded VQVAE.")

    output_dir = Path(output_dir) / f"{exp_name}-samples"

    for source in sources:
        print(f"Processing {source}...")
        source_files = list(Path(source).glob(f"**/*.{audio_ext}"))
        random.shuffle(source_files)
        for path in tqdm(source_files):
            sig = AudioSignal(path)
            sig.resample(vqvae.sample_rate).normalize(-24).ensure_max_of_audio(1.0)

            out_dir = output_dir / path.stem
            # check before creating the directory, otherwise it always exists
            # and every file would be skipped
            if out_dir.exists():
                print(f"Skipping {path.stem} since {out_dir} already exists.")
                continue
            out_dir.mkdir(parents=True, exist_ok=True)

            for model_key, model in models.items():
                out = coarse2fine_infer(sig, model, vqvae, device)
                for k, sig_list in out.items():
                    for i, s in enumerate(sig_list):
                        s.write(out_dir / f"{model_key}-{k}-{i}.wav")
if __name__ == "__main__":
    args = argbind.parse_args()
    with argbind.scope(args):
        main()
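# Invocation sketch: with @argbind.bind(without_prefix=True), main's keyword
# arguments are exposed as command-line flags. Assuming this file is saved as
# eval_c2f.py (a placeholder name), a run might look like:
#
#   python eval_c2f.py --exp_name noise_mode --device cuda --output_dir samples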