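"""Batch text-to-audio inference over a test caption CSV.

Loads the trained MaskDiT diffusion model, a T5 text encoder, and the
autoencoder codec defined in the YAML config, then synthesizes one waveform
per caption via DDIM sampling and writes it as `<audiocap_id>.wav` under
the save path.

Example invocation (script name illustrative):
    python eval_audiocaps.py --config-name configs/udit_ada.yml \
        --ckpt-id 120 --test-df audiocaps_test.csv --guidance-scale 3
"""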
import argparse
import os
import soundfile as sf
import pandas as pd
from tqdm import tqdm

import torch

from diffusers import DDIMScheduler
from models.conditioners import MaskDiT
from modules.autoencoder_wrapper import Autoencoder
from transformers import T5Tokenizer, T5EncoderModel
from inference import inference
from utils import load_yaml_with_includes


parser = argparse.ArgumentParser()
# config settings
parser.add_argument('--config-name', type=str, default='configs/udit_ada.yml')
parser.add_argument('--ckpt-path', type=str, default='../ckpts/')
parser.add_argument('--ckpt-id', type=str, default='120')
parser.add_argument('--save-path', type=str, default='../output/')
parser.add_argument('--test-df', type=str, default='audiocaps_test.csv')
# parser.add_argument('--test-split', type=str, default='test')

parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--guidance-scale', type=float, default=3)
parser.add_argument('--guidance-rescale', type=float, default=0)
parser.add_argument('--ddim-steps', type=int, default=50)
parser.add_argument('--eta', type=float, default=1)
parser.add_argument('--random-seed', type=int, default=None)

args = parser.parse_args()
params = load_yaml_with_includes(args.config_name)

# resolve the checkpoint file and a sampler-specific output directory
args.save_path = f"{args.save_path}/{params['model_name']}/{args.ckpt_id}_{args.ddim_steps}_{args.guidance_scale}_{args.guidance_rescale}/"
args.ckpt_path = f"{args.ckpt_path}/{args.ckpt_id}.pt"

if __name__ == '__main__':
    # Codec Model
    autoencoder = Autoencoder(ckpt_path=params['autoencoder']['path'],
                              model_type=params['autoencoder']['name'],
                              quantization_first=params['autoencoder']['q_first'])
    autoencoder.to(args.device)
    autoencoder.eval()

    # text encoder
    tokenizer = T5Tokenizer.from_pretrained(params['text_encoder']['model'])
    text_encoder = T5EncoderModel.from_pretrained(params['text_encoder']['model'],
                                                  device_map='cpu').to(args.device)
    text_encoder.eval()

    # main U-Net
    unet = MaskDiT(**params['model']).to(args.device)
    unet.load_state_dict(torch.load(args.ckpt_path, map_location='cpu')['model'])
    unet.eval()

    total_params = sum(param.nelement() for param in unet.parameters())
    print("Number of parameters: %.2fM" % (total_params / 1e6))

    noise_scheduler = DDIMScheduler(**params['diff'])
    # a dummy add_noise call moves the scheduler's internal tensors
    # (e.g. alphas_cumprod) onto the target device and dtype
    latents = torch.randn((1, 128, 128), device=args.device)
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=args.device)
    _ = noise_scheduler.add_noise(latents, noise, timesteps)

    # the test CSV (AudioCaps-style) is expected to provide at least
    # `caption`, `audiocap_id`, and `audio_length` columns
    df = pd.read_csv(args.test_df)
    # df = df[df['split'] == args.test_split]
    df = df[df['audio_length'] != 0]  # skip entries with zero-length audio
    # df = df.sample(10)  # uncomment for a quick smoke test on a subset
    os.makedirs(args.save_path, exist_ok=True)
    audio_frames = params['data']['train_frames']

    for i in tqdm(range(len(df))):
        row = df.iloc[i]
        text = row['caption']
        audio_id = row['audiocap_id']

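        # the None arguments appear to disable optional conditioning inputs
        # (e.g. reference latents/masks and a negative prompt) that plain
        # text-to-audio generation does not use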
        pred = inference(autoencoder, unet, None, None,
                         tokenizer, text_encoder, 
                         params, noise_scheduler,
                         text, None,
                         audio_frames,
                         args.guidance_scale, args.guidance_rescale,
                         args.ddim_steps, args.eta, args.random_seed,
                         args.device)
        pred = pred.cpu().numpy().squeeze(0).squeeze(0)  # drop batch and channel dims

        sf.write(f"{args.save_path}/{audio_id}.wav",
                 pred, samplerate=params['data']['sr'])