import argparse
import os

import pandas as pd
import soundfile as sf
import torch
from diffusers import DDIMScheduler
from tqdm import tqdm
from transformers import T5EncoderModel, T5Tokenizer

from inference import inference
from models.conditioners import MaskDiT
from modules.autoencoder_wrapper import Autoencoder
from utils import load_yaml_with_includes

parser = argparse.ArgumentParser()

# config settings
parser.add_argument('--config-name', type=str, default='configs/udit_ada.yml')
parser.add_argument('--ckpt-path', type=str, default='../ckpts/')
parser.add_argument('--ckpt-id', type=str, default='120')
parser.add_argument('--save-path', type=str, default='../output/')
parser.add_argument('--test-df', type=str, default='audiocaps_test.csv')
parser.add_argument('--device', type=str, default='cuda')

# sampling settings
parser.add_argument('--guidance-scale', type=float, default=3)
parser.add_argument('--guidance-rescale', type=float, default=0)
parser.add_argument('--ddim-steps', type=int, default=50)
parser.add_argument('--eta', type=float, default=1)
parser.add_argument('--random-seed', type=int, default=None)

args = parser.parse_args()
params = load_yaml_with_includes(args.config_name)

args.save_path = (f"{args.save_path}/{params['model_name']}/"
                  f"{args.ckpt_id}_{args.ddim_steps}_{args.guidance_scale}_{args.guidance_rescale}/")
args.ckpt_path = f"{args.ckpt_path}/{args.ckpt_id}.pt"

if __name__ == '__main__':
    # latent autoencoder (codec): decodes generated latents back to waveforms
    autoencoder = Autoencoder(ckpt_path=params['autoencoder']['path'],
                              model_type=params['autoencoder']['name'],
                              quantization_first=params['autoencoder']['q_first'])
    autoencoder.to(args.device)
    autoencoder.eval()

    # T5 text encoder for caption conditioning
    tokenizer = T5Tokenizer.from_pretrained(params['text_encoder']['model'])
    text_encoder = T5EncoderModel.from_pretrained(params['text_encoder']['model'],
                                                  device_map='cpu').to(args.device)
    text_encoder.eval()

    # main diffusion backbone
    unet = MaskDiT(**params['model']).to(args.device)
    unet.load_state_dict(torch.load(args.ckpt_path, map_location='cpu')['model'])
    unet.eval()

    total_params = sum(param.nelement() for param in unet.parameters())
    print(f"Number of parameters: {total_params / 1e6:.2f}M")

    noise_scheduler = DDIMScheduler(**params['diff'])

    # a dummy add_noise call moves the scheduler's internal tensors to the
    # correct device/dtype before sampling
    latents = torch.randn((1, 128, 128), device=args.device)
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                              (1,), device=args.device)
    _ = noise_scheduler.add_noise(latents, noise, timesteps)

    # test captions; drop rows with empty audio
    df = pd.read_csv(args.test_df)
    df = df[df['audio_length'] != 0]

    os.makedirs(args.save_path, exist_ok=True)
    audio_frames = params['data']['train_frames']

    for i in tqdm(range(len(df))):
        row = df.iloc[i]
        text = row['caption']
        audio_id = row['audiocap_id']

        pred = inference(autoencoder, unet, None, None,
                         tokenizer, text_encoder, params, noise_scheduler,
                         text, None, audio_frames,
                         args.guidance_scale, args.guidance_rescale,
                         args.ddim_steps, args.eta, args.random_seed,
                         args.device)
        pred = pred.cpu().numpy().squeeze(0).squeeze(0)

        sf.write(f"{args.save_path}/{audio_id}.wav", pred,
                 samplerate=params['data']['sr'])
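
# Example invocation, using only the flags defined above. The script filename
# and checkpoint layout are illustrative; adjust the paths to your setup:
#
#   python generate.py --config-name configs/udit_ada.yml \
#       --ckpt-path ../ckpts --ckpt-id 120 \
#       --test-df audiocaps_test.csv \
#       --guidance-scale 3 --ddim-steps 50
#
# Generated audio is written to
# <save-path>/<model_name>/<ckpt-id>_<ddim-steps>_<guidance-scale>_<guidance-rescale>/<audiocap_id>.wav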