import argparse
import os
import random
import time

import numpy as np
import pandas as pd
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDIMScheduler
from tqdm import tqdm
from transformers import T5Tokenizer, T5EncoderModel

from inference import inference
from models.conditioners import MaskDiT
from modules.autoencoder_wrapper import Autoencoder
from utils import scale_shift, get_lr_scheduler, compute_snr, load_yaml_with_includes
parser = argparse.ArgumentParser()
# config settings
parser.add_argument('--config-name', type=str, default='configs/udit_ada.yml')
parser.add_argument('--ckpt-path', type=str, default='../ckpts/')
parser.add_argument('--ckpt-id', type=str, default='120')
parser.add_argument('--save-path', type=str, default='../output/')
parser.add_argument('--test-df', type=str, default='audiocaps_test.csv')
# parser.add_argument('--test-split', type=str, default='test')
# sampling settings
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--guidance-scale', type=float, default=3)
parser.add_argument('--guidance-rescale', type=float, default=0)
parser.add_argument('--ddim-steps', type=int, default=50)
parser.add_argument('--eta', type=float, default=1)
parser.add_argument('--random-seed', type=int, default=None)
args = parser.parse_args()

params = load_yaml_with_includes(args.config_name)

# resolve checkpoint and output locations from the config and CLI arguments
# args.ckpt_path = f"{args.ckpt_path}/{params['model_name']}/{args.ckpt_id}.pt"
args.save_path = f"{args.save_path}/{params['model_name']}/{args.ckpt_id}_{args.ddim_steps}_{args.guidance_scale}_{args.guidance_rescale}/"
args.ckpt_path = f"{args.ckpt_path}/{args.ckpt_id}.pt"
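
# Example invocation (the script filename "generate.py" is an assumption, not
# taken from the repo):
#   python generate.py --config-name configs/udit_ada.yml --ckpt-id 120 \
#       --guidance-scale 3 --ddim-steps 50
# With the defaults above, generated audio lands under
#   ../output/<model_name>/120_50_3.0_0.0/<audiocap_id>.wav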

if __name__ == '__main__':
    # Codec Model
    autoencoder = Autoencoder(ckpt_path=params['autoencoder']['path'],
                              model_type=params['autoencoder']['name'],
                              quantization_first=params['autoencoder']['q_first'])
    autoencoder.to(args.device)
    autoencoder.eval()
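    # The codec's decoder turns latents sampled by the diffusion model back
    # into a waveform; inference() handles that call (see
    # modules/autoencoder_wrapper for the wrapper's exact encode/decode API,
    # which is not shown here).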

    # text encoder
    tokenizer = T5Tokenizer.from_pretrained(params['text_encoder']['model'])
    text_encoder = T5EncoderModel.from_pretrained(params['text_encoder']['model'],
                                                  device_map='cpu').to(args.device)
    text_encoder.eval()
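    # For reference, a caption is typically encoded along these lines inside
    # inference() (a sketch using the standard transformers API, not code taken
    # from this repo):
    #   batch = tokenizer(text, padding=True, return_tensors='pt').to(args.device)
    #   with torch.no_grad():
    #       text_emb = text_encoder(**batch).last_hidden_state  # (1, seq, dim)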

    # main U-Net (MaskDiT backbone)
    unet = MaskDiT(**params['model']).to(args.device)
    unet.eval()
    unet.load_state_dict(torch.load(args.ckpt_path)['model'])
    total_params = sum(param.nelement() for param in unet.parameters())
    print("Number of parameters: %.2fM" % (total_params / 1e6))

    noise_scheduler = DDIMScheduler(**params['diff'])
    # dummy add_noise call: running it once moves the scheduler's internal
    # tensors (e.g. alphas_cumprod) onto the right device/dtype before sampling
    latents = torch.randn((1, 128, 128), device=args.device)
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=args.device)
    _ = noise_scheduler.add_noise(latents, noise, timesteps)
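    # For reference, the DDIM sampling that inference() performs with this
    # scheduler follows the standard diffusers pattern (a sketch; the MaskDiT
    # call signature is assumed, not taken from this repo):
    #   noise_scheduler.set_timesteps(args.ddim_steps)
    #   for t in noise_scheduler.timesteps:
    #       eps_uncond = unet(latents, t, ...)      # unconditional prediction
    #       eps_text = unet(latents, t, text_emb)   # caption-conditioned prediction
    #       # classifier-free guidance: push the prediction toward the caption
    #       eps = eps_uncond + args.guidance_scale * (eps_text - eps_uncond)
    #       latents = noise_scheduler.step(eps, t, latents, eta=args.eta).prev_sample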

    df = pd.read_csv(args.test_df)
    # df = df[df['split'] == args.test_split]
    df = df[df['audio_length'] != 0]  # drop entries with no audio
    # df = df.sample(10)

    os.makedirs(args.save_path, exist_ok=True)
    audio_frames = params['data']['train_frames']

    for i in tqdm(range(len(df))):
        row = df.iloc[i]
        text = row['caption']
        audio_id = row['audiocap_id']

        pred = inference(autoencoder, unet, None, None,
                         tokenizer, text_encoder,
                         params, noise_scheduler,
                         text, None,
                         audio_frames,
                         args.guidance_scale, args.guidance_rescale,
                         args.ddim_steps, args.eta, args.random_seed,
                         args.device)
        # (1, 1, T) -> (T,) mono waveform
        pred = pred.cpu().numpy().squeeze(0).squeeze(0)

        sf.write(f"{args.save_path}/{audio_id}.wav",
                 pred, samplerate=params['data']['sr'])