import gradio as gr
import spaces
import yaml
import torch
import librosa
import numpy as np
from diffusers import DDIMScheduler
from transformers import AutoProcessor, ClapModel
from model.udit import UDiT
from vae_modules.autoencoder_wrapper import Autoencoder

# Paths to the diffusion config and pretrained weights
diffusion_config = './config/SoloAudio.yaml'
diffusion_ckpt = './pretrained_models/soloaudio_v2.pt'
autoencoder_path = './pretrained_models/audio-vae.pt'
uncond_path = './pretrained_models/uncond.npz'
sample_rate = 24000
device = 'cuda' if torch.cuda.is_available() else 'cpu'

with open(diffusion_config, 'r') as fp:
    diff_config = yaml.safe_load(fp)
v_prediction = diff_config["ddim"]["v_prediction"]

# CLAP text encoder used to embed the language prompt
clapmodel = ClapModel.from_pretrained("laion/larger_clap_general").to(device)
processor = AutoProcessor.from_pretrained('laion/larger_clap_general')

# Audio VAE for encoding the mixture and decoding the extracted latent
autoencoder = Autoencoder(autoencoder_path, 'stable_vae', quantization_first=True)
autoencoder.eval()
autoencoder.to(device)

# UDiT diffusion backbone
unet = UDiT(
    **diff_config['diffwrap']['UDiT']
).to(device)
unet.load_state_dict(torch.load(diffusion_ckpt, map_location=device)['model'])
unet.eval()

if v_prediction:
    print('v prediction')
    scheduler = DDIMScheduler(**diff_config["ddim"]['diffusers'])
else:
    print('noise prediction')
    scheduler = DDIMScheduler(**diff_config["ddim"]['diffusers'])

# these steps reset dtype of noise_scheduler params
latents = torch.randn((1, 128, 128), device=device)
noise = torch.randn(latents.shape).to(latents.device)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps,
                          (noise.shape[0],), device=latents.device).long()
_ = scheduler.add_noise(latents, noise, timesteps)


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of
    [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
    See Section 3.4
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" samples
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg


@spaces.GPU
def sample_diffusion(mixture, timbre, ddim_steps=50, eta=0, seed=2023,
                     guidance_scale=False, guidance_rescale=0.0):
    with torch.no_grad():
        scheduler.set_timesteps(ddim_steps)
        generator = torch.Generator(device=device).manual_seed(seed)
        # init noise
        noise = torch.randn(mixture.shape, generator=generator, device=device)
        pred = noise

        for t in scheduler.timesteps:
            pred = scheduler.scale_model_input(pred, t)
            if guidance_scale:
                # classifier-free guidance: run conditional and unconditional passes in one batch
                uncond = torch.tensor(np.load(uncond_path)['arr_0']).unsqueeze(0).to(device)
                pred_combined = torch.cat([pred, pred], dim=0)
                mixture_combined = torch.cat([mixture, mixture], dim=0)
                timbre_combined = torch.cat([timbre, uncond], dim=0)
                output_combined = unet(x=pred_combined, timesteps=t,
                                       mixture=mixture_combined, timbre=timbre_combined)
                output_pos, output_neg = torch.chunk(output_combined, 2, dim=0)

                model_output = output_neg + guidance_scale * (output_pos - output_neg)
                if guidance_rescale > 0.0:
                    # avoid overexposed
                    model_output = rescale_noise_cfg(model_output, output_pos,
                                                     guidance_rescale=guidance_rescale)
            else:
                model_output = unet(x=pred, timesteps=t, mixture=mixture, timbre=timbre)
            pred = scheduler.step(model_output=model_output, timestep=t, sample=pred,
                                  eta=eta, generator=generator).prev_sample

        # decode the extracted latent back to a waveform
        pred = autoencoder(embedding=pred).squeeze(1)
        return pred


@spaces.GPU
def tse(gt_file_input, text_input, num_infer_steps, eta, seed, guidance_scale, guidance_rescale):
    with torch.no_grad():
        mixture, _ = librosa.load(gt_file_input, sr=sample_rate)

        # Check the length of the audio in samples
        current_length = len(mixture)
        target_length = sample_rate * 10

        # Cut or pad the audio to match the target length
        if current_length > target_length:
            # Trim the audio if it's longer than the target length
            mixture = mixture[:target_length]
        elif current_length < target_length:
            # Pad the audio with zeros if it's shorter than the target length
            padding = target_length - current_length
            mixture = np.pad(mixture, (0, padding), mode='constant')

        mixture = torch.tensor(mixture).unsqueeze(0).to(device)
        mixture = autoencoder(audio=mixture.unsqueeze(1))

        text_inputs = processor(
            text=[text_input],
            max_length=10,         # Fixed length for text
            padding='max_length',  # Pad text to max length
            truncation=True,       # Truncate text if it's longer than max length
            return_tensors="pt"
        )
        inputs = {
            "input_ids": text_inputs["input_ids"][0].unsqueeze(0),            # Text input IDs
            "attention_mask": text_inputs["attention_mask"][0].unsqueeze(0),  # Attention mask for text
        }
        inputs = {key: value.to(device) for key, value in inputs.items()}
        timbre = clapmodel.get_text_features(**inputs)

        pred = sample_diffusion(mixture, timbre, num_infer_steps, eta, seed,
                                guidance_scale, guidance_rescale)
        return sample_rate, pred.squeeze().cpu().numpy()


# CSS styling (optional)
css = """
#col-container {
    margin: 0 auto;
    max-width: 1280px;
}
"""

# Gradio Blocks layout
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # SoloAudio: Target Sound Extraction with Language-oriented Audio Diffusion Transformer
        Adjust the advanced settings for more control. This Space currently supports 10-second audio inputs only.
        Learn more about 🟣**SoloAudio** on the [SoloAudio Homepage](https://wanghelin1997.github.io/SoloAudio-Demo/).
        """)

        with gr.Tab("Target Sound Extraction"):
            # Basic inputs: mixture audio and text prompt
            with gr.Row():
                gt_file_input = gr.Audio(label="Upload Audio to Extract", type="filepath", value="demo/0_mix.wav")
                text_input = gr.Textbox(
                    label="Text Prompt",
                    show_label=True,
                    max_lines=2,
                    placeholder="Enter your prompt",
                    container=True,
                    value="The sound of gunshot",
                    scale=4
                )
                # Run button
                run_button = gr.Button("Extract", scale=1)

            # Output Component
            result = gr.Audio(label="Extracted Audio", type="numpy")

            # Advanced settings in an Accordion
            with gr.Accordion("Advanced Settings", open=False):
                # Diffusion sampling controls
                guidance_scale = gr.Slider(minimum=1.0, maximum=10, step=0.1, value=3.0, label="Guidance Scale")
                guidance_rescale = gr.Slider(minimum=0.0, maximum=1, step=0.05, value=0.0, label="Guidance Rescale")
                num_infer_steps = gr.Slider(minimum=25, maximum=200, step=5, value=50, label="DDIM Steps")
                eta = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.0, label="Eta")
                seed = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Seed")

            # Define the trigger and input-output linking for generation
            run_button.click(
                fn=tse,
                inputs=[gt_file_input, text_input, num_infer_steps, eta, seed, guidance_scale, guidance_rescale],
                outputs=[result]
            )
            text_input.submit(
                fn=tse,
                inputs=[gt_file_input, text_input, num_infer_steps, eta, seed, guidance_scale, guidance_rescale],
                outputs=[result]
            )

# Launch the Gradio demo
demo.launch()
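
# The block below is a minimal, commented-out sketch of how one might call the
# extraction pipeline directly, without the Gradio UI (e.g. for offline/batch use).
# It reuses the tse() function defined above with its demo defaults; the output
# filename is a placeholder and `soundfile` is an assumed extra dependency.
#
# import soundfile as sf
# sr, extracted = tse(
#     gt_file_input="demo/0_mix.wav",     # 10-second mixture to separate
#     text_input="The sound of gunshot",  # language prompt describing the target sound
#     num_infer_steps=50,
#     eta=0.0,
#     seed=0,
#     guidance_scale=3.0,
#     guidance_rescale=0.0,
# )
# sf.write("extracted.wav", extracted, sr)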