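"""Compare coarse-to-fine VampNet checkpoints on random audio excerpts.

For each excerpt drawn from the source folders, every loaded checkpoint
produces three outputs: a straight codec reconstruction, a 12-step gumbel
sample of the fine codebooks, and a single-step argmax decode. Results are
written as WAV files under <output_dir>/<exp_name>-samples/<model_key>/.
"""
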
from pathlib import Path
from typing import List

import argbind
import torch
from audiotools import AudioSignal
from tqdm import tqdm


@torch.no_grad()  # inference only, so skip autograd bookkeeping
def coarse2fine_infer(
    signal,
    model,
    vqvae,
    device,
):
    """Run reconstruction, sampled, and argmax coarse-to-fine decoding on one signal."""
    output = {}

    w = signal.to(device)
    # encode the waveform into discrete codec tokens
    z = vqvae.encode(w.audio_data, w.sample_rate)["codes"]

    model.to(device)
    # straight reconstruction: decode the codec tokens back to audio
    output["reconstructed"] = model.to_signal(z, vqvae).cpu()

    # build a full mask (1 = generate, 0 = condition): the first
    # n_conditioning_codebooks (the coarse codes) stay unmasked, so the
    # model predicts only the fine codebooks from them
    mask = torch.ones_like(z)
    mask[:, : model.n_conditioning_codebooks, :] = 0

    # iterative 12-step gumbel sampling of the fine codes
    output["sampled"] = model.sample(
        codec=vqvae,
        time_steps=z.shape[-1],
        sampling_steps=12,
        start_tokens=z,
        mask=mask,
        temperature=0.85,
        top_k=None,
        sample="gumbel",
        typical_filtering=True,
        return_signal=True,
    ).cpu()

    # single-step greedy (argmax) decoding for comparison
    output["argmax"] = model.sample(
        codec=vqvae,
        time_steps=z.shape[-1],
        sampling_steps=1,
        start_tokens=z,
        mask=mask,
        temperature=1.0,
        top_k=None,
        sample="argmax",
        typical_filtering=True,
        return_signal=True,
    ).cpu()

    return output

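
# A minimal sketch (hypothetical paths) of calling coarse2fine_infer on a
# single file outside of main():
#
#   sig = AudioSignal("input.wav")
#   model = VampNet.load("runs/.../weights.pth")  # any c2f checkpoint
#   vqvae = LAC.load("runs/codec-ckpt/codec.pth")
#   out = coarse2fine_infer(sig, model, vqvae, "cuda")
#   out["sampled"].write("sampled.wav")
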
@argbind.bind(without_prefix=True)
def main(
    sources: List[str] = [
        "/data/spotdl/audio/val", "/data/spotdl/audio/test"
    ],
    exp_name: str = "noise_mode",
    model_paths: List[str] = [
        "runs/c2f-exp-03.22.23/ckpt/mask/epoch=400/vampnet/weights.pth",
        "runs/c2f-exp-03.22.23/ckpt/random/epoch=400/vampnet/weights.pth",
    ],
    model_keys: List[str] = [
        "mask",
        "random",
    ],
    vqvae_path: str = "runs/codec-ckpt/codec.pth",
    device: str = "cuda",
    output_dir: str = ".",
    max_excerpts: int = 5000,
    duration: float = 3.0,
):
    # local imports keep `--help` fast; the heavy deps load only when main() runs
    from vampnet.modules.transformer import VampNet
    from lac.model.lac import LAC

    # load one VampNet checkpoint per key
    models = {
        k: VampNet.load(p) for k, p in zip(model_keys, model_paths)
    }
    for model in models.values():
        model.eval()
    print(f"Loaded {len(models)} models.")

    vqvae = LAC.load(vqvae_path)
    vqvae.to(device)
    vqvae.eval()
    print("Loaded VQVAE.")

    output_dir = Path(output_dir) / f"{exp_name}-samples"

    from audiotools.data.datasets import AudioLoader, AudioDataset

    # draw random fixed-duration excerpts from the source folders
    loader = AudioLoader(sources=sources)
    dataset = AudioDataset(
        loader,
        sample_rate=vqvae.sample_rate,
        duration=duration,
        n_examples=max_excerpts,
        without_replacement=True,
    )
    for i in tqdm(range(max_excerpts)):
        sig = dataset[i]["signal"]
        # AudioSignal transforms operate in place and return self, so the
        # chained result does not need to be reassigned
        sig.resample(vqvae.sample_rate).normalize(-24).ensure_max_of_audio(1.0)

        # run every model on the same excerpt and write each output variant
        for model_key, model in models.items():
            out = coarse2fine_infer(sig, model, vqvae, device)
            out_dir = output_dir / model_key / Path(sig.path_to_file).stem
            out_dir.mkdir(parents=True, exist_ok=True)
            for k, s in out.items():
                s.write(out_dir / f"{k}.wav")

if __name__ == "__main__":
    args = argbind.parse_args()
    with argbind.scope(args):
        main()
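
# Example invocation (argbind exposes main()'s keyword arguments as CLI
# flags; the paths below are placeholders):
#
#   python <this_script>.py \
#       --exp_name noise_mode \
#       --vqvae_path runs/codec-ckpt/codec.pth \
#       --max_excerpts 100 \
#       --output_dir ./samples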