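"""Inference script for MusicFlux: samples mel-spectrogram latents with a
FLUX-style transformer, decodes them with the AudioLDM2 VAE, and renders
16 kHz waveforms with the SpeechT5 HiFi-GAN vocoder."""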
import os
import argparse

import torch
from einops import rearrange, repeat
from scipy.io import wavfile
from diffusers import AutoencoderKL
from transformers import SpeechT5HifiGan

from utils import load_t5, load_clap
from train import RF
from constants import build_model


def prepare(t5, clap, img, prompt):
    """Patchify the latent into a token sequence and build the conditioning
    dict: positional ids for image/text tokens, T5 token embeddings, and a
    pooled CLAP vector."""
    bs, c, h, w = img.shape
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    # 2x2 patchify: (b, c, h, w) -> (b, h/2 * w/2, c * 4)
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    # Positional ids: channel 1 carries the row index, channel 2 the column index
    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    if isinstance(prompt, str):
        prompt = [prompt]
    txt = t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    # Text tokens carry no spatial position, so their ids stay all-zero
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clap(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return img, {
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "y": vec.to(img.device),
    }

def main(args):
    print('generate with MusicFlux')
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    latent_size = (256, 16)  # (H, W) of the 8-channel AudioLDM2 VAE latent

    model = build_model(args.version).to(device)
    local_path = '/maindata/data/shared/multimodal/zhengcong.fei/code/music-flow/results/base/checkpoints/0050000.pt'
    state_dict = torch.load(local_path, map_location='cpu')
    model.load_state_dict(state_dict['ema'])  # load the EMA weights
    model.eval()  # important: disables dropout for inference
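    # RF (from train.py) provides the rectified-flow sampler used below;
    # sample_with_xps returns the intermediate states along the trajectory.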
    diffusion = RF()

    # Text encoders: T5 for token embeddings, CLAP for the pooled vector
    t5 = load_t5(device, max_length=256)
    clap = load_clap(device, max_length=256)

    # AudioLDM2 VAE (mel-spectrogram latents) and HiFi-GAN vocoder
    model_path = '/maindata/data/shared/multimodal/public/dataset_music/audioldm2'
    vae = AutoencoderKL.from_pretrained(os.path.join(model_path, 'vae')).to(device)
    vocoder = SpeechT5HifiGan.from_pretrained(os.path.join(model_path, 'vocoder')).to(device)

    with open(args.prompt_file, 'r') as f:
        conds_txt = [line.strip() for line in f if line.strip()]
    L = len(conds_txt)
    # Negative prompt shared by all samples for classifier-free guidance
    unconds_txt = ["low quality, gentle"] * L
    print(L, conds_txt, unconds_txt)

    init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).to(device)

    sample_steps = 50
    img, conds = prepare(t5, clap, init_noise, conds_txt)
    _, unconds = prepare(t5, clap, init_noise, unconds_txt)
    # null_cond enables classifier-free guidance; cfg=7.0 is the guidance scale
    with torch.autocast(device_type=device):
        images = diffusion.sample_with_xps(model, img, conds=conds, null_cond=unconds,
                                           sample_steps=sample_steps, cfg=7.0)
    
    # Unpatchify the final state back to latent shape (L, 8, 256, 16)
    images = rearrange(
        images[-1],
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=latent_size[0] // 2,
        w=latent_size[1] // 2,
        ph=2,
        pw=2,
    )
    # Undo the VAE scaling, then decode the latents to mel spectrograms
    latents = 1 / vae.config.scaling_factor * images
    mel_spectrogram = vae.decode(latents).sample
    
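    # Vocode each mel spectrogram to a 16 kHz waveform and save it as WAV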
    os.makedirs('wav', exist_ok=True)
    for i in range(L):
        # mel_spectrogram[i] is (1, T, n_mels); the vocoder treats the leading
        # channel dim as a batch of one
        x_i = mel_spectrogram[i]
        waveform = vocoder(x_i)
        waveform = waveform[0].cpu().float().detach().numpy()
        wavfile.write(f'wav/sample_{i}.wav', 16000, waveform)
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--version", type=str, default="base",
                        help="model config name passed to build_model")
    parser.add_argument("--prompt_file", type=str, default='config/example.txt',
                        help="text file with one prompt per line")
    parser.add_argument("--seed", type=int, default=2024)
    args = parser.parse_args()
    main(args)
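

# Example invocation (script filename is assumed; adjust to this file's name):
#   python sample.py --version base --prompt_file config/example.txt --seed 2024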