import gradio as gr
import spaces
import torch
import torchaudio
import io
import base64
import uuid
import os
import time
import re
import threading
import gc
import random
import numpy as np
from einops import rearrange
from huggingface_hub import login
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
from gradio_client import Client
from contextlib import contextmanager

# Global model storage
model_cache = {}
model_lock = threading.Lock()
@contextmanager
def resource_cleanup():
    """Context manager to ensure proper cleanup of GPU resources."""
    try:
        yield
    finally:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        gc.collect()
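# resource_cleanup() is used below as `with resource_cleanup(): ...`, so CUDA is
# synchronized and cached GPU memory is released even if generation raises.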
def load_stable_audio_model():
    """Load stable-audio-open-small model if not already loaded."""
    with model_lock:
        if 'stable_audio_model' not in model_cache:
            print("🔄 Loading stable-audio-open-small model...")

            # Authenticate with HF
            hf_token = os.getenv('HF_TOKEN')
            if hf_token:
                login(token=hf_token)
                print("✅ HF authenticated")

            # Load model
            model, config = get_pretrained_model("stabilityai/stable-audio-open-small")
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = model.to(device)
            if device == "cuda":
                model = model.half()

            model_cache['stable_audio_model'] = model
            model_cache['stable_audio_config'] = config
            model_cache['stable_audio_device'] = device
            print(f"✅ Stable Audio model loaded on {device}")

    return (model_cache['stable_audio_model'],
            model_cache['stable_audio_config'],
            model_cache['stable_audio_device'])
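# The weights live in the module-level model_cache, so they are downloaded and
# loaded once per process and reused on later requests; model_lock keeps two
# concurrent requests from loading the model twice.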
# ZeroGPU: the Space runs on Zero hardware, so the GPU-bound entry point is
# decorated with @spaces.GPU to request a GPU for the duration of each call.
@spaces.GPU
def generate_stable_audio_loop(prompt, loop_type, bpm, bars, seed=-1):
    """Generate a BPM-aware loop using stable-audio-open-small."""
    try:
        model, config, device = load_stable_audio_model()

        # Calculate loop duration based on BPM and bars
        seconds_per_beat = 60.0 / bpm
        seconds_per_bar = seconds_per_beat * 4  # 4/4 time
        target_loop_duration = seconds_per_bar * bars
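        # Worked example: at 120 BPM one beat is 0.5 s and one 4/4 bar is 2 s,
        # so a 4-bar loop targets 8 s (well inside the model's 12 s output).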
        # Enhance prompt based on loop type and BPM
        if loop_type == "drums":
            enhanced_prompt = f"{prompt} drum loop {bpm}bpm"
            negative_prompt = "melody, harmony, pitched instruments, vocals, singing"
        else:  # instruments
            enhanced_prompt = f"{prompt} instrumental loop {bpm}bpm"
            negative_prompt = "drums, percussion, kick, snare, hi-hat"

        # Set seed (gr.Number delivers floats, so cast before seeding torch)
        seed = int(seed)
        if seed == -1:
            seed = random.randint(0, 2**32 - 1)
        torch.manual_seed(seed)
        if device == "cuda":
            torch.cuda.manual_seed(seed)

        print(f"🎵 Generating {loop_type} loop:")
        print(f"   Enhanced prompt: {enhanced_prompt}")
        print(f"   Target duration: {target_loop_duration:.2f}s ({bars} bars at {bpm}bpm)")
        print(f"   Seed: {seed}")

        # Prepare conditioning
        conditioning = [{
            "prompt": enhanced_prompt,
            "seconds_total": 12  # Model generates 12s max
        }]
        negative_conditioning = [{
            "prompt": negative_prompt,
            "seconds_total": 12
        }]
        start_time = time.time()

        with resource_cleanup():
            if device == "cuda":
                torch.cuda.empty_cache()

            with torch.cuda.amp.autocast(enabled=(device == "cuda")):
                output = generate_diffusion_cond(
                    model,
                    steps=8,  # Fast generation
                    cfg_scale=1.0,  # Good balance for loops
                    conditioning=conditioning,
                    negative_conditioning=negative_conditioning,
                    sample_size=config["sample_size"],
                    sampler_type="pingpong",
                    device=device,
                    seed=seed
                )

        generation_time = time.time() - start_time
        # Post-process audio
        output = rearrange(output, "b d n -> d (b n)")  # (2, N) stereo
        output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1)

        # Extract the loop portion
        sample_rate = config["sample_rate"]
        loop_samples = int(target_loop_duration * sample_rate)
        available_samples = output.shape[1]

        if loop_samples > available_samples:
            loop_samples = available_samples
            actual_duration = available_samples / sample_rate
            print(f"⚠️ Requested {target_loop_duration:.2f}s, got {actual_duration:.2f}s")

        # Extract loop from beginning (cleanest beat alignment)
        loop_output = output[:, :loop_samples]
        loop_output_int16 = loop_output.mul(32767).to(torch.int16).cpu()
        # Save to temporary file
        loop_filename = f"loop_{loop_type}_{bpm}bpm_{bars}bars_{seed}.wav"
        torchaudio.save(loop_filename, loop_output_int16, sample_rate)

        actual_duration = loop_samples / sample_rate
        print(f"✅ {loop_type.title()} loop generated: {actual_duration:.2f}s in {generation_time:.2f}s")
        return loop_filename, f"Generated {actual_duration:.2f}s {loop_type} loop at {bpm}bpm ({bars} bars)"

    except Exception as e:
        print(f"❌ Generation error: {str(e)}")
        return None, f"Error: {str(e)}"
def combine_loops(drums_audio, instruments_audio, bpm, bars, num_repeats):
    """Combine drum and instrument loops with specified repetitions."""
    try:
        if not drums_audio and not instruments_audio:
            return None, "No audio files to combine"

        # Calculate timing
        seconds_per_beat = 60.0 / bpm
        seconds_per_bar = seconds_per_beat * 4
        loop_duration = seconds_per_bar * bars
        total_duration = loop_duration * num_repeats

        print("🎛️ Combining loops:")
        print(f"   Loop duration: {loop_duration:.2f}s ({bars} bars)")
        print(f"   Repeats: {num_repeats}")
        print(f"   Total duration: {total_duration:.2f}s")
        combined_audio = None
        sample_rate = None

        # Process each audio file
        for audio_path, audio_type in [(drums_audio, "drums"), (instruments_audio, "instruments")]:
            if audio_path:
                # Load audio
                waveform, sr = torchaudio.load(audio_path)
                if sample_rate is None:
                    sample_rate = sr

                # Ensure we have the exact loop duration
                target_samples = int(loop_duration * sr)
                if waveform.shape[1] > target_samples:
                    waveform = waveform[:, :target_samples]
                elif waveform.shape[1] < target_samples:
                    # Pad if necessary
                    padding = target_samples - waveform.shape[1]
                    waveform = torch.cat([waveform, torch.zeros(waveform.shape[0], padding)], dim=1)

                # Repeat the loop
                repeated_waveform = waveform.repeat(1, num_repeats)
                print(f"   {audio_type}: {waveform.shape[1]/sr:.2f}s repeated {num_repeats}x = {repeated_waveform.shape[1]/sr:.2f}s")

                # Add to combined audio
                if combined_audio is None:
                    combined_audio = repeated_waveform
                else:
                    combined_audio = combined_audio + repeated_waveform

        if combined_audio is None:
            return None, "No valid audio to combine"
        # Normalize to prevent clipping
        combined_audio = combined_audio / torch.max(torch.abs(combined_audio))
        combined_audio = combined_audio.clamp(-1, 1)

        # Convert to int16 and save
        combined_audio_int16 = combined_audio.mul(32767).to(torch.int16)
        combined_filename = f"combined_{bpm}bpm_{bars}bars_{num_repeats}loops_{random.randint(1000, 9999)}.wav"
        torchaudio.save(combined_filename, combined_audio_int16, sample_rate)

        actual_duration = combined_audio.shape[1] / sample_rate
        status = f"Combined into {actual_duration:.2f}s audio ({num_repeats} × {bars} bars at {bpm}bpm)"
        print(f"✅ {status}")
        return combined_filename, status

    except Exception as e:
        print(f"❌ Combine error: {str(e)}")
        return None, f"Combine error: {str(e)}"
def transform_with_melodyflow_api(audio_path, prompt, solver="euler", flowstep=0.12):
    """Transform audio using the facebook/MelodyFlow Space API."""
    if audio_path is None:
        return None, "❌ No audio file provided"

    try:
        # Initialize client for the Facebook MelodyFlow Space
        client = Client("facebook/MelodyFlow")

        # Set steps based on solver
        if solver == "midpoint":
            base_steps = 128
            effective_steps = base_steps // 2  # 64 effective steps
        else:  # euler
            base_steps = 125
            effective_steps = base_steps // 5  # 25 effective steps

        print("🎛️ MelodyFlow transformation:")
        print(f"   Prompt: {prompt}")
        print(f"   Solver: {solver} ({effective_steps} effective steps)")
        print(f"   Flowstep: {flowstep}")
        # Call the MelodyFlow API - pass the file path directly
        result = client.predict(
            model="facebook/melodyflow-t24-30secs",
            text=prompt,
            solver=solver,
            steps=base_steps,
            target_flowstep=flowstep,
            regularize=(solver == "euler"),
            regularization_strength=0.2,
            duration=30,
            melody=audio_path,  # Pass file path directly instead of handle_file(audio_path)
            api_name="/predict"
        )
        if result and len(result) > 0 and result[0]:
            # Save the result locally
            output_filename = f"melodyflow_transformed_{random.randint(1000, 9999)}.wav"
            import shutil
            shutil.copy2(result[0], output_filename)

            status_msg = f"✅ Transformed with prompt: '{prompt}' (flowstep: {flowstep}, {effective_steps} steps)"
            return output_filename, status_msg
        else:
            return None, "❌ MelodyFlow API returned no results"

    except Exception as e:
        return None, f"❌ MelodyFlow API error: {str(e)}"
def calculate_optimal_bars(bpm):
    """Calculate optimal bar count for given BPM to fit in ~10s."""
    seconds_per_beat = 60.0 / bpm
    seconds_per_bar = seconds_per_beat * 4
    max_duration = 10.0

    for bars in [8, 4, 2, 1]:
        if seconds_per_bar * bars <= max_duration:
            return bars
    return 1
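# Example: at 90 BPM a 4/4 bar lasts 4 * 60/90 ≈ 2.67 s, so 8 bars (~21.3 s) and
# 4 bars (~10.7 s) both exceed the 10 s budget and the function returns 2 bars (~5.3 s).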
# ========== GRADIO INTERFACE ==========
with gr.Blocks(title="🎵 Stable Audio Loop Generator") as iface:
    gr.Markdown("# 🎵 Stable Audio Loop Generator")
    gr.Markdown("**Generate synchronized drum and instrument loops with stable-audio-open-small, then transform with MelodyFlow!**")

    with gr.Accordion("How This Works", open=False):
        gr.Markdown("""
**Workflow:**
1. **Set global BPM and bars** - affects both drum and instrument generation
2. **Generate drum loop** - creates BPM-aware percussion
3. **Generate instrument loop** - creates melodic/harmonic content
4. **Combine loops** - layer them together with repetitions (up to 30s)
5. **Transform** - use MelodyFlow to stylistically transform the combined result

**Features:**
- BPM-aware generation ensures perfect sync between loops
- Negative prompting separates drums from instruments cleanly
- Smart bar calculation optimizes loop length for the BPM
- MelodyFlow integration for advanced style transfer
""")
    # ========== GLOBAL CONTROLS ==========
    gr.Markdown("## 🎛️ Global Settings")

    with gr.Row():
        global_bpm = gr.Dropdown(
            label="Global BPM",
            choices=[90, 100, 110, 120, 130, 140, 150],
            value=120,
            info="BPM applied to both drum and instrument generation"
        )
        global_bars = gr.Dropdown(
            label="Loop Length (Bars)",
            choices=[1, 2, 4, 8],
            value=4,
            info="Number of bars for each loop"
        )
        base_prompt = gr.Textbox(
            label="Base Prompt",
            value="techno",
            placeholder="e.g., 'techno', 'jazz', 'ambient', 'hip-hop'",
            info="Style applied to both loops"
        )
    # Auto-suggest optimal bars based on BPM
    def update_suggested_bars(bpm):
        optimal = calculate_optimal_bars(bpm)
        return gr.update(info=f"Suggested: {optimal} bars for {bpm}bpm (≤10s)")

    global_bpm.change(update_suggested_bars, inputs=[global_bpm], outputs=[global_bars])
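    # This handler only refreshes the helper text under "Loop Length (Bars)";
    # the selected bar count itself is not changed automatically.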
    # ========== LOOP GENERATION ==========
    gr.Markdown("## 🥁 Step 1: Generate Individual Loops")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🥁 Drum Loop")
            generate_drums_btn = gr.Button("Generate Drums", variant="primary", size="lg")
            drums_audio = gr.Audio(label="Drum Loop", type="filepath")
            drums_status = gr.Textbox(label="Drums Status", value="Ready to generate")

        with gr.Column():
            gr.Markdown("### 🎹 Instrument Loop")
            generate_instruments_btn = gr.Button("Generate Instruments", variant="secondary", size="lg")
            instruments_audio = gr.Audio(label="Instrument Loop", type="filepath")
            instruments_status = gr.Textbox(label="Instruments Status", value="Ready to generate")

    # Seed controls
    with gr.Row():
        drums_seed = gr.Number(label="Drums Seed", value=-1, info="-1 for random")
        instruments_seed = gr.Number(label="Instruments Seed", value=-1, info="-1 for random")
    # ========== COMBINATION ==========
    gr.Markdown("## 🎛️ Step 2: Combine Loops")

    with gr.Row():
        num_repeats = gr.Slider(
            label="Number of Repetitions",
            minimum=1,
            maximum=5,
            step=1,
            value=2,
            info="How many times to repeat each loop (creates longer audio)"
        )

    combine_btn = gr.Button("🎛️ Combine Loops", variant="primary", size="lg")
    combined_audio = gr.Audio(label="Combined Loops", type="filepath")
    combine_status = gr.Textbox(label="Combine Status", value="Generate loops first")
    # ========== MELODYFLOW TRANSFORMATION ==========
    gr.Markdown("## 🎨 Step 3: Transform with MelodyFlow")

    with gr.Row():
        with gr.Column():
            transform_prompt = gr.Textbox(
                label="Transformation Prompt",
                value="aggressive industrial techno with distorted sounds",
                placeholder="Describe the style transformation",
                lines=2
            )
        with gr.Column():
            transform_solver = gr.Dropdown(
                label="Solver",
                choices=["euler", "midpoint"],
                value="euler",
                info="EULER: faster (25 steps), MIDPOINT: slower (64 steps)"
            )
            transform_flowstep = gr.Slider(
                label="Transform Intensity",
                minimum=0.0,
                maximum=0.15,
                step=0.01,
                value=0.12,
                info="Lower = more dramatic transformation"
            )

    transform_btn = gr.Button("🎨 Transform Audio", variant="secondary", size="lg")
    transformed_audio = gr.Audio(label="Transformed Audio", type="filepath")
    transform_status = gr.Textbox(label="Transform Status", value="Combine audio first")
    # ========== EVENT HANDLERS ==========
    # Generate drums
    generate_drums_btn.click(
        generate_stable_audio_loop,
        inputs=[base_prompt, gr.State("drums"), global_bpm, global_bars, drums_seed],
        outputs=[drums_audio, drums_status]
    )

    # Generate instruments
    generate_instruments_btn.click(
        generate_stable_audio_loop,
        inputs=[base_prompt, gr.State("instruments"), global_bpm, global_bars, instruments_seed],
        outputs=[instruments_audio, instruments_status]
    )
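    # gr.State("drums") / gr.State("instruments") feed a fixed loop_type value into
    # generate_stable_audio_loop alongside the user-controlled inputs.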
    # Combine loops
    combine_btn.click(
        combine_loops,
        inputs=[drums_audio, instruments_audio, global_bpm, global_bars, num_repeats],
        outputs=[combined_audio, combine_status]
    )

    # Transform with MelodyFlow
    transform_btn.click(
        transform_with_melodyflow_api,
        inputs=[combined_audio, transform_prompt, transform_solver, transform_flowstep],
        outputs=[transformed_audio, transform_status]
    )
    # ========== EXAMPLES ==========
    gr.Markdown("## 🎯 Example Workflows")

    # Note: BPM values here must be one of the Global BPM dropdown choices.
    examples = gr.Examples(
        examples=[
            ["techno", 130, 4, "aggressive industrial techno"],
            ["jazz", 110, 2, "smooth lo-fi jazz with vinyl crackle"],
            ["ambient", 90, 8, "ethereal ambient soundscape"],
            ["hip-hop", 100, 4, "classic boom bap hip-hop"],
            ["drum and bass", 140, 4, "liquid drum and bass"],
        ],
        inputs=[base_prompt, global_bpm, global_bars, transform_prompt],
    )
if __name__ == "__main__":
    iface.launch()