# vampnet/scripts/utils/process_folder-c2f.py
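"""Run coarse-to-fine (c2f) inference on folders of audio with one or more
VampNet checkpoints and collect the outputs into a zip of audio excerpts.

Each audio file is cut into short overlapping windows, encoded to codec
tokens with the VQ-VAE, and decoded three ways per model: a straight
reconstruction, a multi-step sampled generation, and a single-step argmax
generation.
"""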
from audiotools import AudioSignal
import torch
from pathlib import Path
import argbind
from tqdm import tqdm
import random
from collections import defaultdict


def coarse2fine_infer(
    signal,
    model,
    vqvae,
    device,
    signal_window=3,
    signal_hop=1.5,
    max_excerpts=25,
):
    output = defaultdict(list)

    # cut the signal into overlapping excerpts
    # (signal_window seconds long, hopping by signal_hop seconds)
    windows = [s for s in signal.clone().windows(signal_window, signal_hop)]
    random.shuffle(windows)

    for w in windows[1:max_excerpts]:  # skip the first window since it's mostly zero padded?
        with torch.no_grad():
            # get codes
            w = w.to(device)
            z = vqvae.encode(w.audio_data, w.sample_rate)["codes"]

            model.to(device)
output["reconstructed"] = model.to_signal(z, vqvae).cpu()
            # make a full mask: 1 everywhere except the conditioning codebooks,
            # which are kept fixed (0) during sampling
            mask = torch.ones_like(z)
            mask[:, :model.n_conditioning_codebooks, :] = 0
output["sampled"].append(model.sample(
codec=vqvae,
time_steps=z.shape[-1],
sampling_steps=12,
start_tokens=z,
mask=mask,
temperature=0.85,
top_k=None,
sample="gumbel",
typical_filtering=True,
return_signal=True
).cpu())
output["argmax"].append(model.sample(
codec=vqvae,
time_steps=z.shape[-1],
sampling_steps=1,
start_tokens=z,
mask=mask,
temperature=1.0,
top_k=None,
sample="argmax",
typical_filtering=True,
return_signal=True
).cpu())
return output
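
# A minimal sketch of calling coarse2fine_infer directly on a single file,
# assuming a VampNet checkpoint and LAC codec like the ones loaded in main()
# below (the audio path and output filename are placeholders):
#
#     from vampnet.modules.transformer import VampNet
#     from lac.model.lac import LAC
#
#     model = VampNet.load("ckpt/mask/best/vampnet/weights.pth")
#     vqvae = LAC.load("ckpt/wav2wav.pth")
#     model.eval()
#     vqvae.to("cuda").eval()
#
#     sig = AudioSignal("some_song.mp3").resample(vqvae.sample_rate)
#     out = coarse2fine_infer(sig, model, vqvae, "cuda")
#     out["sampled"][0].write("sampled_excerpt.wav")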


@argbind.bind(without_prefix=True)
def main(
    sources=[
        "/home/hugo/data/spotdl/audio/val", "/home/hugo/data/spotdl/audio/test"
    ],
    audio_ext="mp3",
    exp_name="noise_mode",
    model_paths=[
        "ckpt/mask/best/vampnet/weights.pth",
        "ckpt/random/best/vampnet/weights.pth",
    ],
    model_keys=[
        "noise_mode=mask",
        "noise_mode=random",
    ],
    vqvae_path="ckpt/wav2wav.pth",
    device="cuda",
):
    from vampnet.modules.transformer import VampNet
    from lac.model.lac import LAC
    from audiotools.post import audio_zip

    models = {
        k: VampNet.load(p) for k, p in zip(model_keys, model_paths)
    }
    for model in models.values():
        model.eval()
    print(f"Loaded {len(models)} models.")

    vqvae = LAC.load(vqvae_path)
    vqvae.to(device)
    vqvae.eval()
    print("Loaded VQVAE.")

    audio_dict = defaultdict(list)
    for source in sources:
        print(f"Processing {source}...")
        for path in tqdm(list(Path(source).glob(f"**/*.{audio_ext}"))):
            sig = AudioSignal(path)
            sig.resample(vqvae.sample_rate).normalize(-24).ensure_max_of_audio(1.0)

            for model_key, model in models.items():
                out = coarse2fine_infer(sig, model, vqvae, device)
                for k in out:
                    audio_dict[f"{model_key}-{k}"].extend(out[k])

    audio_zip(audio_dict, f"{exp_name}-results.zip")


if __name__ == "__main__":
    args = argbind.parse_args()
    with argbind.scope(args):
        main()
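
# Example invocation (a sketch; argbind exposes main()'s keyword arguments as
# CLI flags, so the flag names follow the parameters above, and the paths
# shown are just the defaults):
#
#   python process_folder-c2f.py \
#       --sources /home/hugo/data/spotdl/audio/val /home/hugo/data/spotdl/audio/test \
#       --exp_name noise_mode \
#       --model_paths ckpt/mask/best/vampnet/weights.pth ckpt/random/best/vampnet/weights.pth \
#       --model_keys noise_mode=mask noise_mode=random \
#       --vqvae_path ckpt/wav2wav.pth \
#       --device cuda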