import os
import torch
import random
import spaces
import numpy as np
import gradio as gr
import soundfile as sf
from transformers import T5Tokenizer, T5EncoderModel
from diffusers import DDIMScheduler
from src.models.conditioners import MaskDiT
from src.modules.autoencoder_wrapper import Autoencoder
from src.inference import inference
from src.utils import load_yaml_with_includes


# Load models and config
def load_models(config_name, ckpt_path, vae_path, device):
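    """Load the autoencoder, T5 text encoder, MaskDiT diffusion model, and DDIM scheduler."""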
    params = load_yaml_with_includes(config_name)

    # Load codec model
    autoencoder = Autoencoder(ckpt_path=vae_path,
                              model_type=params['autoencoder']['name'],
                              quantization_first=params['autoencoder']['q_first']).to(device)
    autoencoder.eval()

    # Load text encoder
    tokenizer = T5Tokenizer.from_pretrained(params['text_encoder']['model'])
    text_encoder = T5EncoderModel.from_pretrained(params['text_encoder']['model']).to(device)
    text_encoder.eval()

    # Load main diffusion model (MaskDiT transformer)
    unet = MaskDiT(**params['model']).to(device)
    unet.load_state_dict(torch.load(ckpt_path, map_location='cpu')['model'])
    unet.eval()

    # Load noise scheduler
    noise_scheduler = DDIMScheduler(**params['diff'])

    # Warm-up: run one dummy add_noise step so the scheduler's internal
    # buffers are initialized before the first real request
    latents = torch.randn((1, 128, 128), device=device)
    noise = torch.randn_like(latents)
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device=device)
    _ = noise_scheduler.add_noise(latents, noise, timesteps)

    return autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params

MAX_SEED = np.iinfo(np.int32).max

# Model and config paths
config_name = 'ckpts/ezaudio-xl.yml'
ckpt_path = 'ckpts/s3/ezaudio_s3_xl.pt'
vae_path = 'ckpts/vae/1m.pt'
save_path = 'output/'
os.makedirs(save_path, exist_ok=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

autoencoder, unet, tokenizer, text_encoder, noise_scheduler, params = load_models(config_name, ckpt_path, vae_path,
                                                                                  device)

@spaces.GPU
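# (On Hugging Face ZeroGPU Spaces this decorator allocates a GPU for each call;
# outside of Spaces it is intended to be a no-op.)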
def generate_audio(text, length,
                   guidance_scale, guidance_rescale, ddim_steps, eta,
                   random_seed, randomize_seed):
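    """Generate audio from a text prompt; returns (sample_rate, waveform) for gr.Audio."""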
    neg_text = None
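    # Convert the requested duration (seconds) into a latent-frame length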
    length = length * params['autoencoder']['latent_sr']

    if randomize_seed:
        random_seed = random.randint(0, MAX_SEED)

    pred = inference(autoencoder, unet, None, None,
                     tokenizer, text_encoder,
                     params, noise_scheduler,
                     text, neg_text,
                     length,
                     guidance_scale, guidance_rescale,
                     ddim_steps, eta, random_seed,
                     device)

    # Drop the batch and channel dimensions to get a 1-D waveform
    pred = pred.cpu().numpy().squeeze(0).squeeze(0)
    # output_file = f"{save_path}/{text}.wav"
    # sf.write(output_file, pred, samplerate=params['autoencoder']['sr'])

    # gr.Audio(type="numpy") expects a (sample_rate, waveform) tuple
    return params['autoencoder']['sr'], pred


# Gradio Interface
def gradio_interface():
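    """Build and launch the Gradio demo."""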
    # Input components
    text_input = gr.Textbox(label="Text Prompt", value="the sound of dog barking")
    length_input = gr.Slider(minimum=1, maximum=10, step=1, value=10, label="Audio Length (in seconds)")

    # Advanced settings
    guidance_scale_input = gr.Slider(minimum=1.0, maximum=10, step=0.1, value=5, label="Guidance Scale")
    guidance_rescale_input = gr.Slider(minimum=0.0, maximum=1, step=0.05, value=0.75, label="Guidance Rescale")
    ddim_steps_input = gr.Slider(minimum=25, maximum=200, step=5, value=100, label="DDIM Steps")
    eta_input = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1, label="Eta")
    random_seed_input = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=0, label="Seed")

    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)

    # Output component
    output_audio = gr.Audio(label="Generated Audio", type="numpy")

    # Interface
    gr.Interface(
        fn=generate_audio,
        inputs=[text_input, length_input, guidance_scale_input, guidance_rescale_input, ddim_steps_input, eta_input,
                random_seed_input, randomize_seed],
        outputs=output_audio,
        title="EzAudio Text-to-Audio Generator",
        description="Generate audio from text using a diffusion model. Adjust advanced settings for more control.",
        allow_flagging="never"
    ).launch()


if __name__ == "__main__":
    gradio_interface()