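# Text-to-audio inference over the AudioCaps test set: load a trained MaskDiT
# checkpoint, sample latents with DDIM conditioned on T5-encoded captions,
# decode them to audio with the codec autoencoder, and write one WAV per caption.
#
# Example invocation (script file name assumed):
#   python generate_audiocaps.py --config-name configs/udit_ada.yml \
#       --ckpt-id 120 --ddim-steps 50 --guidance-scale 3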
import random
import argparse
import os
import time
import soundfile as sf
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDIMScheduler

from models.conditioners import MaskDiT
from modules.autoencoder_wrapper import Autoencoder
from transformers import T5Tokenizer, T5EncoderModel

from inference import inference
from utils import scale_shift, get_lr_scheduler, compute_snr, load_yaml_with_includes
parser = argparse.ArgumentParser()
# config settings
parser.add_argument('--config-name', type=str, default='configs/udit_ada.yml')
parser.add_argument('--ckpt-path', type=str, default='../ckpts/')
parser.add_argument('--ckpt-id', type=str, default='120')
parser.add_argument('--save-path', type=str, default='../output/')
parser.add_argument('--test-df', type=str, default='audiocaps_test.csv')
# parser.add_argument('--test-split', type=str, default='test')
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--guidance-scale', type=float, default=3)
parser.add_argument('--guidance-rescale', type=float, default=0)
parser.add_argument('--ddim-steps', type=int, default=50)
parser.add_argument('--eta', type=float, default=1)
parser.add_argument('--random-seed', type=int, default=None)

args = parser.parse_args()
params = load_yaml_with_includes(args.config_name)

# args.ckpt_path = f"{args.ckpt_path}/{params['model_name']}/{args.ckpt_id}.pt"
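# the output folder name encodes checkpoint id and sampling hyperparameters,
# so results from different settings land in separate directories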
args.save_path = f"{args.save_path}/{params['model_name']}/{args.ckpt_id}_{args.ddim_steps}_{args.guidance_scale}_{args.guidance_rescale}/"
args.ckpt_path = f"{args.ckpt_path}/{args.ckpt_id}.pt"
if __name__ == '__main__':
    # Codec Model
    autoencoder = Autoencoder(ckpt_path=params['autoencoder']['path'],
                              model_type=params['autoencoder']['name'],
                              quantization_first=params['autoencoder']['q_first'])
    autoencoder.to(args.device)
    autoencoder.eval()

    # text encoder
    tokenizer = T5Tokenizer.from_pretrained(params['text_encoder']['model'])
    text_encoder = T5EncoderModel.from_pretrained(params['text_encoder']['model'],
                                                  device_map='cpu').to(args.device)
    text_encoder.eval()

    # main diffusion model (MaskDiT)
    unet = MaskDiT(**params['model']).to(args.device)
    unet.eval()
    unet.load_state_dict(torch.load(args.ckpt_path, map_location=args.device)['model'])

    total_params = sum(p.nelement() for p in unet.parameters())
    print("Number of parameters: %.2fM" % (total_params / 1e6))

    noise_scheduler = DDIMScheduler(**params['diff'])
    # dummy add_noise call: forces the scheduler to create its internal
    # tensors with the right device and dtype before sampling
    latents = torch.randn((1, 128, 128), device=args.device)
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=args.device)
    _ = noise_scheduler.add_noise(latents, noise, timesteps)
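
    # evaluation captions: AudioCaps test set; rows with audio_length == 0
    # (presumably missing or unavailable clips) are skipped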
    df = pd.read_csv(args.test_df)
    # df = df[df['split'] == args.test_split]
    df = df[df['audio_length'] != 0]
    # df = df.sample(10)
    os.makedirs(args.save_path, exist_ok=True)
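
    # generation length: the clip length used during training, taken from the config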
    audio_frames = params['data']['train_frames']

    for i in tqdm(range(len(df))):
        row = df.iloc[i]
        text = row['caption']
        audio_id = row['audiocap_id']
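
        # the None arguments appear to be placeholders for optional conditioning
        # inputs (e.g. reference audio or masks) that plain text-to-audio does not use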
        pred = inference(autoencoder, unet, None, None,
                         tokenizer, text_encoder,
                         params, noise_scheduler,
                         text, None,
                         audio_frames,
                         args.guidance_scale, args.guidance_rescale,
                         args.ddim_steps, args.eta, args.random_seed,
                         args.device)

        # drop the batch and channel dims to get a 1-D waveform for soundfile
        pred = pred.cpu().numpy().squeeze(0).squeeze(0)

        sf.write(f"{args.save_path}/{audio_id}.wav",
                 pred, samplerate=params['data']['sr'])