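"""Text-to-music sampling script for MusicFlux.

Loads a trained checkpoint, encodes the prompts in --prompt_file (one free-form
text prompt per line), samples latents with the RF sampler and classifier-free
guidance (cfg=7.0), decodes them to mel spectrograms with the AudioLDM2 VAE,
and vocodes 16 kHz waveforms with the SpeechT5 HiFi-GAN vocoder, written to
wav/sample_<i>.wav.

Example invocation (defaults from the argparse flags below; the script file
name here is assumed):

    python sample.py --version base --prompt_file config/example.txt --seed 2024
"""
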
import argparse
import math
import os

import torch
from einops import rearrange, repeat
from PIL import Image
from diffusers import AutoencoderKL
from transformers import SpeechT5HifiGan

from utils import load_t5, load_clap, load_ae
from train import RF
from constants import build_model


def prepare(t5, clip, img, prompt):
    """Patchify the latent into a token sequence and build the conditioning
    dict (text embeddings, conditioning vector, positional ids) for the model."""
    bs, c, h, w = img.shape
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    # Patchify: 2x2 latent patches -> sequence of tokens.
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    # Positional ids over the (h//2, w//2) patch grid: channel 1 holds the
    # row index, channel 2 the column index, channel 0 stays zero.
    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    # Per-token text embeddings from t5 and a conditioning vector from clip
    # (the CLAP encoder in main()).
    if isinstance(prompt, str):
        prompt = [prompt]
    txt = t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    print(img_ids.size(), txt.size(), vec.size())
    return img, {
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "y": vec.to(img.device),
    }


def main(args):
    print('generate with MusicFlux')
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Latent spatial size; patchified 2x2 inside prepare().
    latent_size = (256, 16)

    # Build the model for the chosen version and load EMA weights from a
    # local checkpoint.
    model = build_model(args.version).to(device)
    local_path = '/maindata/data/shared/multimodal/zhengcong.fei/code/music-flow/results/base/checkpoints/0050000.pt'
    state_dict = torch.load(local_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict['ema'])
    model.eval()  # important!
    diffusion = RF()  # rectified-flow sampler defined in train.py

    model_path = '/maindata/data/shared/multimodal/public/ckpts/FLUX.1-dev'  # note: overwritten below before use

    # Text encoders.
    t5 = load_t5(device, max_length=256)
    clap = load_clap(device, max_length=256)

    # AudioLDM2 VAE (mel-spectrogram latents) and SpeechT5 HiFi-GAN vocoder.
    model_path = '/maindata/data/shared/multimodal/public/dataset_music/audioldm2'
    vae = AutoencoderKL.from_pretrained(os.path.join(model_path, 'vae')).to(device)
    vocoder = SpeechT5HifiGan.from_pretrained(os.path.join(model_path, 'vocoder')).to(device)

    # One prompt per line; the same negative prompt is shared across the batch.
    with open(args.prompt_file, 'r') as f:
        conds_txt = f.readlines()
    L = len(conds_txt)
    unconds_txt = ["low quality, gentle"] * L
    print(L, conds_txt, unconds_txt)

    # Initial noise in the VAE latent space: (batch, 8 channels, 256, 16).
    init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).to(device)

    STEPSIZE = 50
    img, conds = prepare(t5, clap, init_noise, conds_txt)
    _, unconds = prepare(t5, clap, init_noise, unconds_txt)

    # Rectified-flow sampling with classifier-free guidance.
    with torch.autocast(device_type=device):
        images = diffusion.sample_with_xps(model, img, conds=conds, null_cond=unconds, sample_steps=STEPSIZE, cfg=7.0)
    print(images[-1].size())

    # Un-patchify the final sample back to the latent shape (b, c, 256, 16).
    images = rearrange(
        images[-1],
        "b (h w) (c ph pw) -> b c (h ph) (w pw)",
        h=latent_size[0] // 2,
        w=latent_size[1] // 2,
        ph=2,
        pw=2,
    )

    # Decode latents to mel spectrograms with the VAE.
    latents = 1 / vae.config.scaling_factor * images
    mel_spectrogram = vae.decode(latents).sample
    print(mel_spectrogram.size())

    # Vocode each mel spectrogram to a waveform and save it as a 16 kHz WAV.
    from scipy.io import wavfile
    os.makedirs('wav', exist_ok=True)  # ensure the output directory exists
    for i in range(L):
        x_i = mel_spectrogram[i]
        if x_i.dim() == 4:
            x_i = x_i.squeeze(1)
        waveform = vocoder(x_i)
        waveform = waveform[0].cpu().float().detach().numpy()
        print(waveform.shape)
        # import soundfile as sf
        # sf.write('reconstruct.wav', waveform, samplerate=16000)
        wavfile.write('wav/sample_' + str(i) + '.wav', 16000, waveform)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--version", type=str, default="base")
    parser.add_argument("--prompt_file", type=str, default='config/example.txt')
    parser.add_argument("--seed", type=int, default=2024)
    args = parser.parse_args()
    main(args)
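
# Hypothetical example of a prompt file (config/example.txt) as consumed above,
# one free-form text prompt per line; these lines are illustrative only, not
# the repo's actual file:
#   An upbeat jazz track with saxophone and brushed drums
#   A calm ambient pad with slowly evolving textures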