import spaces
import os
import tempfile
import gradio as gr
from dotenv import load_dotenv
import torch
from scipy.io.wavfile import write
from diffusers import DiffusionPipeline
from transformers import pipeline
from pydub import AudioSegment
import numpy as np

load_dotenv()
hf_token = os.getenv("HF_TKN")
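
# `device` for the transformers pipeline is a device index: 0 selects the first
# CUDA device and -1 keeps the captioning model on CPU.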
device_id = 0 if torch.cuda.is_available() else -1
# Initialize models
captioning_pipeline = pipeline(
    "image-to-text",
    model="nlpconnect/vit-gpt2-image-captioning",
    device=device_id
)

pipe = DiffusionPipeline.from_pretrained(
    "cvssp/audioldm2",
    use_auth_token=hf_token
)
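
# The AudioLDM2 pipeline stays on CPU at import time; it is moved to the GPU only
# inside the @spaces.GPU-decorated generate_audio function below and moved back to
# CPU afterwards to free GPU memory.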
@spaces.GPU(duration=120)
def analyze_image(image_file):
    # Return a single string (caption or error message) so the value maps cleanly
    # onto the caption Textbox output.
    try:
        results = captioning_pipeline(image_file)
        if not results or not isinstance(results, list):
            return "Error: Could not generate caption."
        caption = results[0].get("generated_text", "").strip()
        return caption if caption else "No caption generated."
    except Exception as e:
        return f"Error analyzing image: {e}"

@spaces.GPU(duration=120)
def generate_audio(prompt):
    try:
        # Move the pipeline to the GPU for generation; fall back to CPU if no CUDA device is available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        pipe.to(device)
        audio_output = pipe(
            prompt=prompt,
            num_inference_steps=50,
            guidance_scale=7.5
        )
        pipe.to("cpu")
        return audio_output.audios[0]
    except Exception as e:
        print(f"Error generating audio: {e}")
        return None
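
# blend_audios expects a list of mono waveforms (1-D NumPy arrays at the same
# sample rate); it pads each clip to the longest one, sums them, and
# peak-normalizes the mix before writing a 16 kHz WAV file.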
def blend_audios(audio_list):
    try:
        # Find the longest audio duration
        max_length = max([arr.shape[0] for arr in audio_list])
        # Mix all audios, padding shorter clips with trailing silence
        mixed = np.zeros(max_length)
        for arr in audio_list:
            if arr.shape[0] < max_length:
                padded = np.pad(arr, (0, max_length - arr.shape[0]))
            else:
                padded = arr[:max_length]
            mixed += padded
        # Normalize the mix to avoid clipping
        mixed = mixed / np.max(np.abs(mixed))
        # Save to a temporary WAV file (AudioLDM2 generates 16 kHz audio)
        fd, tmp_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        write(tmp_path, 16000, mixed.astype(np.float32))
        return tmp_path
    except Exception as e:
        print(f"Error blending audio: {e}")
        return None

css = """
#col-container { max-width: 800px; margin: 0 auto; }
.toggle-row { margin: 1rem 0; }
.prompt-box { margin-bottom: 0.5rem; }
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""
            <h1 style="text-align: center;">🎶 Advanced Sound Generator</h1>
            <p style="text-align: center;">⚡ Powered by Bilsimaging</p>
        """)

        # Input mode toggle
        input_mode = gr.Radio(
            choices=["Image Input", "Text Prompts"],
            value="Image Input",
            label="Select Input Mode",
            elem_classes="toggle-row"
        )

        # Image input section
        with gr.Column(visible=True) as image_col:
            image_upload = gr.Image(type="filepath", label="Upload Image")
            generate_desc_btn = gr.Button("Generate Description from Image")
            caption_display = gr.Textbox(label="Generated Description", interactive=False)

        # Text input section
        with gr.Column(visible=False) as text_col:
            with gr.Row():
                prompt1 = gr.Textbox(label="Sound Prompt 1", lines=2)
                prompt2 = gr.Textbox(label="Sound Prompt 2", lines=2)
            # Placeholder for dynamically added prompts; the button below is not yet wired to a handler
            additional_prompts = gr.Column()
            add_prompt_btn = gr.Button("➕ Add Another Prompt", variant="secondary")

        # Shared controls, visible in both input modes
        generate_sound_btn = gr.Button("Generate Blended Sound", variant="primary")

        # Audio output
        audio_output = gr.Audio(label="Final Sound Composition", interactive=False)

        # Documentation section
        gr.Markdown("""
## 🎚️ How to Use
1. **Choose Input Mode** above
2. For images: Upload + Generate Description → Generate Sound
3. For text: Enter multiple sound prompts → Generate Blended Sound

[Support on Ko-fi](https://ko-fi.com/bilsimaging)
""")

        # Visitor badge
        gr.HTML("""
            <div style="text-align: center; margin-top: 2rem;">
                <a href="https://visitorbadge.io/status?path=YOUR_SPACE_URL">
                    <img src="https://api.visitorbadge.io/api/visitors?path=YOUR_SPACE_URL&countColor=%23263759"/>
                </a>
            </div>
        """)

    # Toggle visibility based on input mode
    def toggle_input(mode):
        if mode == "Image Input":
            return [gr.update(visible=True), gr.update(visible=False)]
        return [gr.update(visible=False), gr.update(visible=True)]

    input_mode.change(
        fn=toggle_input,
        inputs=input_mode,
        outputs=[image_col, text_col]
    )

    # Image processing chain
    generate_desc_btn.click(
        fn=analyze_image,
        inputs=image_upload,
        outputs=caption_display
    ).then(
        fn=lambda: gr.update(interactive=True),
        outputs=generate_sound_btn
    )

    # Text processing chain: collect non-empty prompts, synthesize audio for each,
    # and blend the results into a single clip
    def generate_blended_sound(*prompts):
        texts = [p for p in prompts if p and p.strip()]
        if not texts:
            return None
        audios = [a for a in (generate_audio(p) for p in texts) if a is not None]
        if not audios:
            return None
        return blend_audios(audios)

    generate_sound_btn.click(
        fn=generate_blended_sound,
        # The generated image description is included so image mode also produces sound
        inputs=[caption_display, prompt1, prompt2],
        outputs=audio_output
    )

# Queue management (`concurrency_count` is the Gradio 3.x keyword; Gradio 4.x uses
# `default_concurrency_limit` instead)
demo.queue(concurrency_count=2)

if __name__ == "__main__":
    demo.launch()