# Will be fixed soon, but meanwhile: import os if os.getenv('SPACES_ZERO_GPU') == "true": os.environ['SPACES_ZERO_GPU'] = "1" import gradio as gr import random import torch from torch import inference_mode # from tempfile import NamedTemporaryFile from typing import Optional import numpy as np from models import load_model import utils import spaces from inversion_utils import inversion_forward_process, inversion_reverse_process # current_loaded_model = "cvssp/audioldm2-music" # # current_loaded_model = "cvssp/audioldm2-music" # ldm_stable = load_model(current_loaded_model, device, 200) # deafult model LDM2 = "cvssp/audioldm2" MUSIC = "cvssp/audioldm2-music" LDM2_LARGE = "cvssp/audioldm2-large" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") ldm2 = load_model(model_id=LDM2, device=device) ldm2_large = load_model(model_id=LDM2_LARGE, device=device) ldm2_music = load_model(model_id=MUSIC, device=device) def randomize_seed_fn(seed, randomize_seed): if randomize_seed: seed = random.randint(0, np.iinfo(np.int32).max) torch.manual_seed(seed) return seed def invert(ldm_stable, x0, prompt_src, num_diffusion_steps, cfg_scale_src): # , ldm_stable): # ldm_stable.model.scheduler.set_timesteps(num_diffusion_steps, device=device) with inference_mode(): w0 = ldm_stable.vae_encode(x0) # find Zs and wts - forward process _, zs, wts = inversion_forward_process(ldm_stable, w0, etas=1, prompts=[prompt_src], cfg_scales=[cfg_scale_src], prog_bar=True, num_inference_steps=num_diffusion_steps, numerical_fix=True) return zs, wts def sample(ldm_stable, zs, wts, steps, prompt_tar, tstart, cfg_scale_tar): # , ldm_stable): # reverse process (via Zs and wT) tstart = torch.tensor(tstart, dtype=torch.int) skip = steps - tstart w0, _ = inversion_reverse_process(ldm_stable, xT=wts, skips=steps - skip, etas=1., prompts=[prompt_tar], neg_prompts=[""], cfg_scales=[cfg_scale_tar], prog_bar=True, zs=zs[:int(steps - skip)]) # vae decode image with inference_mode(): x0_dec = 
ldm_stable.vae_decode(w0) if x0_dec.dim() < 4: x0_dec = x0_dec[None, :, :, :] with torch.no_grad(): audio = ldm_stable.decode_to_mel(x0_dec) return (16000, audio.squeeze().cpu().numpy()) def get_duration(input_audio, model_id: str, do_inversion: bool, wts: Optional[torch.Tensor], zs: Optional[torch.Tensor], saved_inv_model: str, source_prompt="", target_prompt="", steps=200, cfg_scale_src=3.5, cfg_scale_tar=12, t_start=45, randomize_seed=True): if model_id == LDM2: factor = 0.8 elif model_id == LDM2_LARGE: factor = 1.5 else: # MUSIC factor = 1 mult = 0 if do_inversion or randomize_seed: mult = steps if input_audio is None: raise gr.Error('Input audio missing!') duration = min(utils.get_duration(input_audio), 30) time_per_iter_of_full = factor * ((t_start /100 * steps)*2 + mult) * 0.2 print('expected time:', time_per_iter_of_full / 30 * duration) return time_per_iter_of_full / 30 * duration @spaces.GPU(duration=get_duration) def edit( # cache_dir, input_audio, model_id: str, do_inversion: bool, # wtszs_file: str, wts: Optional[torch.Tensor], zs: Optional[torch.Tensor], saved_inv_model: str, source_prompt="", target_prompt="", steps=200, cfg_scale_src=3.5, cfg_scale_tar=12, t_start=45, randomize_seed=True): print(model_id) if model_id == LDM2: ldm_stable = ldm2 elif model_id == LDM2_LARGE: ldm_stable = ldm2_large else: # MUSIC ldm_stable = ldm2_music ldm_stable.model.scheduler.set_timesteps(steps, device=device) # If the inversion was done for a different model, we need to re-run the inversion if not do_inversion and (saved_inv_model is None or saved_inv_model != model_id): do_inversion = True if input_audio is None: raise gr.Error('Input audio missing!') x0 = utils.load_audio(input_audio, ldm_stable.get_fn_STFT(), device=device) # if not (do_inversion or randomize_seed): # if not os.path.exists(wtszs_file): # do_inversion = True # Too much time has passed if wts is None or zs is None: do_inversion = True if do_inversion or randomize_seed: # always re-run inversion 
zs_tensor, wts_tensor = invert(ldm_stable=ldm_stable, x0=x0, prompt_src=source_prompt, num_diffusion_steps=steps, cfg_scale_src=cfg_scale_src) # f = NamedTemporaryFile("wb", dir=cache_dir, suffix=".pth", delete=False) # torch.save({'wts': wts_tensor, 'zs': zs_tensor}, f.name) # wtszs_file = f.name # wtszs_file = gr.State(value=f.name) # wts = gr.State(value=wts_tensor) wts = wts_tensor zs = zs_tensor # zs = gr.State(value=zs_tensor) # demo.move_resource_to_block_cache(f.name) saved_inv_model = model_id do_inversion = False else: # wtszs = torch.load(wtszs_file, map_location=device) # # wtszs = torch.load(wtszs_file.f, map_location=device) # wts_tensor = wtszs['wts'] # zs_tensor = wtszs['zs'] wts_tensor = wts.to(device) zs_tensor = zs.to(device) # make sure t_start is in the right limit # t_start = change_tstart_range(t_start, steps) output = sample(ldm_stable, zs_tensor, wts_tensor, steps, prompt_tar=target_prompt, tstart=int(t_start / 100 * steps), cfg_scale_tar=cfg_scale_tar) return output, wts.cpu(), zs.cpu(), saved_inv_model, do_inversion # return output, wtszs_file, saved_inv_model, do_inversion def get_example(): case = [ ['Examples/Beethoven.wav', '', 'A recording of an arcade game soundtrack.', 45, 'cvssp/audioldm2-music', '27s', 'Examples/Beethoven_arcade.wav', ], ['Examples/Beethoven.wav', 'A high quality recording of wind instruments and strings playing.', 'A high quality recording of a piano playing.', 45, 'cvssp/audioldm2-music', '27s', 'Examples/Beethoven_piano.wav', ], ['Examples/ModalJazz.wav', 'Trumpets playing alongside a piano, bass and drums in an upbeat old-timey cool jazz song.', 'A banjo playing alongside a piano, bass and drums in an upbeat old-timey cool country song.', 45, 'cvssp/audioldm2-music', '106s', 'Examples/ModalJazz_banjo.wav',], ['Examples/Cat.wav', '', 'A dog barking.', 75, 'cvssp/audioldm2-large', '10s', 'Examples/Cat_dog.wav',] ] return case intro = """
For faster inference without waiting in the queue, you may duplicate this Space and upgrade it to a GPU in its settings.
""" help = """load_audio
function in the utils.py
file,
change duration = min(audioldm.utils.get_duration(audio_path), 30)
to
duration = audioldm.utils.get_duration(audio_path)
.