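# FluxMusic Gradio app: generate music from text prompts with a rectified-flow
# transformer, decoding latents to mel spectrograms with the AudioLDM2 VAE and
# vocoding to audio with SpeechT5's HiFi-GAN.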
import os
import torch
import gradio as gr
from einops import rearrange, repeat
from diffusers import AutoencoderKL
from transformers import SpeechT5HifiGan
from scipy.io import wavfile
import glob
import random
import numpy as np
import re
import requests
import time
import gc

# Import necessary functions and classes
from utils import load_t5, load_clap
from train import RF
from constants import build_model
# Global variables to store loaded models and resources
global_model = None
global_t5 = None
global_clap = None
global_vae = None
global_vocoder = None
global_diffusion = None
current_model_name = None

# Model and output directories
MODELS_DIR = os.path.join(os.path.dirname(__file__), "models")
GENERATIONS_DIR = os.path.join(os.path.dirname(__file__), "generations")
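
# prepare() packs a latent into 2x2 patch tokens and builds the conditioning
# dict the transformer expects: patch position ids, T5 token embeddings
# ("txt"), and a pooled CLAP text vector ("y").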
def prepare(t5, clip, img, prompt):
    bs, c, h, w = img.shape
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    if isinstance(prompt, str):
        prompt = [prompt]

    # Generate text embeddings
    txt = t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return img, {
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "y": vec.to(img.device),
    }
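
# Drop the current checkpoint and reclaim CPU/GPU memory before loading another.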
def unload_current_model():
    global global_model, current_model_name
    if global_model is not None:
        del global_model
        global_model = None
        current_model_name = None
    torch.cuda.empty_cache()
    gc.collect()
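
# Load a FluxMusic checkpoint from the local models/ directory or a URL.
# The architecture size is inferred from the filename, and weights are read
# from the checkpoint's EMA (exponential moving average) copy.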
def load_model(model_name, device, model_url=None):
    global global_model, current_model_name
    unload_current_model()

    if model_url:
        print(f"Downloading model from URL: {model_url}")
        response = requests.get(model_url)
        if response.status_code == 200:
            os.makedirs(MODELS_DIR, exist_ok=True)  # ensure the target directory exists
            model_path = os.path.join(MODELS_DIR, "downloaded_model.pt")
            with open(model_path, 'wb') as f:
                f.write(response.content)
            model_name = "downloaded_model.pt"
        else:
            return f"Failed to download model from URL: {model_url}"
    else:
        model_path = os.path.join(MODELS_DIR, model_name)
        if not os.path.exists(model_path):
            return f"Model file not found: {model_path}"

    # Determine model size from the filename
    if 'musicflow_b' in model_name:
        model_size = "base"
    elif 'musicflow_g' in model_name:
        model_size = "giant"
    elif 'musicflow_l' in model_name:
        model_size = "large"
    elif 'musicflow_s' in model_name:
        model_size = "small"
    else:
        model_size = "base"  # Default to base if unrecognized

    print(f"Loading {model_size} model: {model_name}")
    try:
        start_time = time.time()
        global_model = build_model(model_size).to(device)
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        global_model.load_state_dict(state_dict['ema'], strict=False)
        global_model.eval()
        global_model.model_path = model_path
        current_model_name = model_name
        load_time = time.time() - start_time
        return f"Successfully loaded model: {model_name} in {load_time:.2f} seconds"
    except Exception as e:
        unload_current_model()
        print(f"Error loading model {model_name}: {str(e)}")
        return f"Failed to load model: {model_name}. Error: {str(e)}"
def load_resources(device):
    global global_t5, global_clap, global_vae, global_vocoder, global_diffusion
    try:
        start_time = time.time()
        print("Loading T5 and CLAP models...")
        global_t5 = load_t5(device, max_length=256)
        global_clap = load_clap(device, max_length=256)

        print("Loading VAE and vocoder...")
        global_vae = AutoencoderKL.from_pretrained('cvssp/audioldm2', subfolder="vae").to(device)
        global_vocoder = SpeechT5HifiGan.from_pretrained('cvssp/audioldm2', subfolder="vocoder").to(device)

        print("Initializing diffusion...")
        global_diffusion = RF()

        load_time = time.time() - start_time
        print(f"Base resources loaded successfully in {load_time:.2f} seconds!")
        return f"Resources loaded successfully in {load_time:.2f} seconds!"
    except Exception as e:
        print(f"Error loading resources: {str(e)}")
        return f"Failed to load resources. Error: {str(e)}"
def generate_music(prompt, seed, cfg_scale, steps, duration, device, batch_size=1, progress=gr.Progress()):
    global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion

    if global_model is None:
        return "Please select and load a model first.", None
    if global_t5 is None or global_clap is None or global_vae is None or global_vocoder is None or global_diffusion is None:
        return "Resources not properly loaded. Please reload the page and try again.", None

    if seed == 0:
        seed = random.randint(1, 1000000)
    print(f"Using seed: {seed}")

    torch.manual_seed(seed)
    torch.set_grad_enabled(False)
    # Ensure we're using CPU if CUDA is not available
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available. Falling back to CPU.")
        device = "cpu"

    # Calculate the number of segments needed for the desired duration
    segment_duration = 10  # Each segment is 10 seconds
    num_segments = int(np.ceil(duration / segment_duration))
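    # e.g. duration=25 -> ceil(25 / 10) = 3 segments; the concatenated waveform
    # is trimmed back to exactly 25 s after generation.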
    all_waveforms = []
    for i in range(num_segments):
        progress(i / num_segments, desc=f"Generating segment {i+1}/{num_segments}")

        # Offset the base seed by i to vary each segment slightly while keeping
        # the generation reproducible
        torch.manual_seed(seed + i)

        latent_size = (256, 16)
        conds_txt = [prompt]
        unconds_txt = ["low quality, gentle"]
        L = len(conds_txt)

        init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).to(device)

        img, conds = prepare(global_t5, global_clap, init_noise, conds_txt)
        _, unconds = prepare(global_t5, global_clap, init_noise, unconds_txt)
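        # conds/unconds drive classifier-free guidance: the sampler mixes the
        # conditional and unconditional predictions, weighted by cfg_scale.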
        # Implement batching for inference
        images = []
        for batch_start in range(0, img.shape[0], batch_size):
            batch_end = min(batch_start + batch_size, img.shape[0])
            batch_img = img[batch_start:batch_end]
            batch_conds = {k: v[batch_start:batch_end] for k, v in conds.items()}
            batch_unconds = {k: v[batch_start:batch_end] for k, v in unconds.items()}
            with torch.no_grad():
                batch_images = global_diffusion.sample_with_xps(
                    global_model, batch_img, conds=batch_conds, null_cond=batch_unconds,
                    sample_steps=steps, cfg=cfg_scale
                )
            images.append(batch_images[-1])
        images = torch.cat(images, dim=0)
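        # Un-patchify: undo the 2x2 patching from prepare(), turning 1024 tokens
        # of 32 dims back into an 8-channel 256x16 latent.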
        images = rearrange(
            images,
            "b (h w) (c ph pw) -> b c (h ph) (w pw)",
            h=128,
            w=8,
            ph=2,
            pw=2,
        )

        latents = 1 / global_vae.config.scaling_factor * images
        mel_spectrogram = global_vae.decode(latents).sample

        x_i = mel_spectrogram[0]
        if x_i.dim() == 4:
            x_i = x_i.squeeze(1)
        waveform = global_vocoder(x_i)
        waveform = waveform[0].cpu().float().detach().numpy()
        all_waveforms.append(waveform)

        # Clear some memory after each segment
        del images, latents, mel_spectrogram, x_i
        torch.cuda.empty_cache()
        gc.collect()

    # Concatenate all waveforms
    final_waveform = np.concatenate(all_waveforms)

    # Trim to the exact requested duration
    sample_rate = 16000
    final_waveform = final_waveform[:int(duration * sample_rate)]
    progress(0.9, desc="Saving audio file")

    # Create the 'generations' folder if needed
    os.makedirs(GENERATIONS_DIR, exist_ok=True)

    # Build a filename from the prompt, seed, and model name
    prompt_part = re.sub(r'[^\w\s-]', '', prompt)[:10].strip().replace(' ', '_')
    model_name = os.path.splitext(os.path.basename(global_model.model_path))[0]
    model_suffix = '_mf_b' if model_name == 'musicflow_b' else f'_{model_name}'
    base_filename = f"{prompt_part}_{seed}{model_suffix}"
    output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}.wav")

    # If the file already exists, add a numerical suffix
    counter = 1
    while os.path.exists(output_path):
        output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}_{counter}.wav")
        counter += 1

    wavfile.write(output_path, sample_rate, final_waveform)

    progress(1.0, desc="Audio generation complete")
    return f"Generated with seed: {seed}", output_path
# Get the list of .pt files in the models directory
model_files = glob.glob(os.path.join(MODELS_DIR, "*.pt"))
model_choices = [os.path.basename(f) for f in model_files]

# Ensure we have at least one model
if not model_choices:
    print(f"No models found in the models directory: {MODELS_DIR}")
    print("Available files in the directory:")
    print(os.listdir(MODELS_DIR))
    model_choices = ["No models available"]

# Set the default model
default_model = 'musicflow_b.pt' if 'musicflow_b.pt' in model_choices else model_choices[0]
# Set up a dark grey theme
theme = gr.themes.Monochrome(
    primary_hue="gray",
    secondary_hue="gray",
    neutral_hue="gray",
    radius_size=gr.themes.sizes.radius_sm,
)
# Gradio interface
with gr.Blocks(theme=theme) as iface:
    gr.Markdown(
        """
        <div style="text-align: center;">
            <h1>FluxMusic Generator</h1>
            <p>Generate music from text prompts using the FluxMusic model.</p>
            <p>Feel free to clone this space and run it on a GPU locally or on Hugging Face.</p>
        </div>
        """)

    with gr.Row():
        model_dropdown = gr.Dropdown(choices=model_choices, label="Select Model", value=default_model)
        model_url = gr.Textbox(label="Or enter model URL")

    device_choice = gr.Radio(["cpu", "cuda"], label="Device", value="cpu")
    load_model_button = gr.Button("Load Model")
    model_status = gr.Textbox(label="Model Status", value="No model loaded")

    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        seed = gr.Number(label="Seed", value=0)
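    # Higher CFG pushes generations closer to the prompt at some cost to
    # variety; more sampling steps are slower but typically cleaner.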
    with gr.Row():
        cfg_scale = gr.Slider(minimum=1, maximum=40, step=0.1, label="CFG Scale", value=20)
        steps = gr.Slider(minimum=10, maximum=200, step=1, label="Steps", value=100)

    duration = gr.Number(label="Duration (seconds)", value=10, minimum=10, maximum=300, step=1)

    generate_button = gr.Button("Generate Music")
    output_status = gr.Textbox(label="Generation Status")
    output_audio = gr.Audio(type="filepath")
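
    # Load the shared resources first, then the selected checkpoint
    # (a URL, if given, takes precedence over the dropdown choice).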
    def on_load_model_click(model_name, device, url):
        # Ensure we're using CPU if CUDA is not available
        if device == "cuda" and not torch.cuda.is_available():
            print("CUDA is not available. Falling back to CPU.")
            device = "cpu"
        resource_status = load_resources(device)
        if "Failed" in resource_status:
            return resource_status
        if url:
            result = load_model(None, device, model_url=url)
        else:
            result = load_model(model_name, device)
        return result

    load_model_button.click(on_load_model_click, inputs=[model_dropdown, device_choice, model_url], outputs=[model_status])
    generate_button.click(generate_music, inputs=[prompt, seed, cfg_scale, steps, duration, device_choice], outputs=[output_status, output_audio])
    # Load the default model and resources on startup, showing the result in the status box
    iface.load(lambda: on_load_model_click(default_model, "cpu", None), inputs=None, outputs=[model_status])

# Launch the interface
iface.launch()