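"""Compare VampNet checkpoints trained with different noise modes.

For each audio file under `sources`, excerpts are encoded to codec tokens,
masked, and resampled by every model; the reconstructions and sampled
outputs are collected into a single zip of audio for comparison.
"""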
import random
from collections import defaultdict
from pathlib import Path

import argbind
import torch
from audiotools import AudioSignal
from tqdm import tqdm


def coarse2fine_infer(
    signal,
    model,
    vqvae,
    device,
    signal_window=3,
    signal_hop=1.5,
    max_excerpts=25,
):
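    """Resample masked codec tokens for random excerpts of `signal`.

    Each excerpt is encoded with the codec, everything but the conditioning
    codebooks is masked, and the model fills in the masked tokens. Returns a
    dict mapping output type ("reconstructed", "sampled", "argmax") to a list
    of AudioSignals, one per excerpt.
    """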
    output = defaultdict(list)

    # Split the signal into overlapping windows of `signal_window` seconds,
    # hopped every `signal_hop` seconds.
    windows = list(signal.clone().windows(signal_window, signal_hop))
    # Drop the first window, which is mostly zero-padded, *before* shuffling,
    # then take up to max_excerpts windows at random.
    windows = windows[1:]
    random.shuffle(windows)

    model.to(device)
    for w in windows[:max_excerpts]:
        with torch.no_grad():
            # Encode the window into discrete codec tokens.
            w = w.to(device)
            z = vqvae.encode(w.audio_data, w.sample_rate)["codes"]

            # Codec reconstruction of the ground-truth codes, for reference.
            output["reconstructed"].append(model.to_signal(z, vqvae).cpu())

            # Mask everything except the conditioning (coarse) codebooks,
            # which stay fixed as the prompt.
            mask = torch.ones_like(z)
            mask[:, :model.n_conditioning_codebooks, :] = 0
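            # Two decoding strategies are compared below: iterative sampling
            # over 12 steps, and a single-step argmax decode as a baseline.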
output["sampled"].append(model.sample(
codec=vqvae,
time_steps=z.shape[-1],
sampling_steps=12,
start_tokens=z,
mask=mask,
temperature=0.85,
top_k=None,
sample="gumbel",
typical_filtering=True,
return_signal=True
).cpu())
output["argmax"].append(model.sample(
codec=vqvae,
time_steps=z.shape[-1],
sampling_steps=1,
start_tokens=z,
mask=mask,
temperature=1.0,
top_k=None,
sample="argmax",
typical_filtering=True,
return_signal=True
).cpu())
    return output


@argbind.bind(without_prefix=True)
def main(
    sources=[
        "/home/hugo/data/spotdl/audio/val", "/home/hugo/data/spotdl/audio/test"
    ],
    audio_ext="mp3",
    exp_name="noise_mode",
    model_paths=[
        "ckpt/mask/best/vampnet/weights.pth",
        "ckpt/random/best/vampnet/weights.pth",
    ],
    model_keys=[
        "noise_mode=mask",
        "noise_mode=random",
    ],
    vqvae_path="ckpt/wav2wav.pth",
    device="cuda",
):
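    """Run coarse2fine_infer for every model on every file under `sources`
    and write all outputs to `{exp_name}-results.zip`."""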
    from vampnet.modules.transformer import VampNet
    from lac.model.lac import LAC
    from audiotools.post import audio_zip

    # Load one VampNet per checkpoint, keyed by its noise mode.
    models = {
        k: VampNet.load(p) for k, p in zip(model_keys, model_paths)
    }
    for model in models.values():
        model.eval()
    print(f"Loaded {len(models)} models.")

    vqvae = LAC.load(vqvae_path)
    vqvae.to(device)
    vqvae.eval()
    print("Loaded VQVAE.")
    audio_dict = defaultdict(list)
    for source in sources:
        print(f"Processing {source}...")
        for path in tqdm(list(Path(source).glob(f"**/*.{audio_ext}"))):
            sig = AudioSignal(path)
            # Match the codec's sample rate, loudness-normalize to -24 dB,
            # and clip peaks at 1.0.
            sig.resample(vqvae.sample_rate).normalize(-24).ensure_max_of_audio(1.0)

            for model_key, model in models.items():
                out = coarse2fine_infer(sig, model, vqvae, device)
                for k in out:
                    audio_dict[f"{model_key}-{k}"].extend(out[k])

    audio_zip(audio_dict, f"{exp_name}-results.zip")

if __name__ == "__main__":
    args = argbind.parse_args()
    with argbind.scope(args):
        main()
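
# Example invocation (a sketch: it assumes argbind exposes the bound keyword
# arguments of main() as command-line flags, and the script filename here is
# illustrative):
#
#   python eval_noise_mode.py --exp_name noise_mode --device cuda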