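# Gradio app: generate a caption for an uploaded image with
# vit-gpt2-image-captioning, then synthesize a short matching sound
# effect from that caption with AudioLDM2.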
import io
import os
from pathlib import Path
from typing import Tuple, Optional

import gradio as gr
import numpy as np
import torch
from PIL import Image
from dotenv import load_dotenv
from diffusers import DiffusionPipeline
from transformers import pipeline
from huggingface_hub import login

# Load environment variables
load_dotenv()
hf_token = os.getenv("HF_TKN")
if hf_token:
    login(token=hf_token)

# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load models
def load_models():
    """Load both models with proper device placement"""
    caption_pipe = pipeline(
        "image-to-text",
        model="nlpconnect/vit-gpt2-image-captioning",
        device=device,
    )
    audio_pipe = DiffusionPipeline.from_pretrained(
        "cvssp/audioldm2",
        token=hf_token,
        torch_dtype=torch_dtype,
    )
    return caption_pipe, audio_pipe

caption_pipe, audio_pipe = load_models()
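
# Note: per the original Space header ("Running on Zero"), this app targets
# ZeroGPU hardware, where GPU-bound functions are typically wrapped with the
# @spaces.GPU decorator from the `spaces` package; it is omitted here so the
# script also runs locally.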

def analyze_image(image_bytes: bytes) -> Tuple[str, bool]:
    """Generate caption from image bytes with enhanced error handling.

    Returns (message, error_flag): error_flag is True when captioning failed.
    """
    try:
        image = Image.open(io.BytesIO(image_bytes))
        if image.mode != "RGB":
            image = image.convert("RGB")
        results = caption_pipe(image)
        if not results or not isinstance(results, list):
            return "Error: Invalid response from caption model", True
        caption = results[0].get("generated_text", "").strip()
        return caption or "No caption generated", not bool(caption)
    except Exception as e:
        return f"Image processing error: {str(e)}", True
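
# Example usage with one of the bundled example images:
#   caption, failed = analyze_image(Path("examples/storm.jpg").read_bytes())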

def generate_audio(caption: str) -> Optional[Tuple[int, np.ndarray]]:
    """Generate audio from caption with resource management"""
    # Capture the pipeline's current device before the try block so the finally
    # clause can always restore it, even if generation fails early. Diffusers
    # pipelines expose a .device property rather than nn.Module's .parameters().
    original_device = audio_pipe.device
    try:
        audio_pipe.to(device)
        # Generation with progress awareness
        audio = audio_pipe(
            prompt=caption,
            num_inference_steps=50,
            guidance_scale=7.5,
            audio_length_in_s=5.0,  # keep audio generation short
        ).audios[0]
        # Post-processing
        audio = audio.squeeze()        # collapse to a mono channel
        audio = np.clip(audio, -1, 1)  # ensure a valid waveform range
        return 16000, audio            # AudioLDM2 generates 16 kHz audio
    except Exception as e:
        print(f"Audio generation error: {str(e)}")
        return None
    finally:
        audio_pipe.to(original_device)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
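
# gr.Audio accepts a (sample_rate, numpy_array) tuple directly, so the return
# value above can be wired straight into the Audio output component.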

# UI Components
css = """
#col-container {
    max-width: 800px;
    margin: 0 auto;
}
.disclaimer {
    font-size: 0.9em;
    color: #666;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""
            <h1 style="text-align: center;">🎶 Image to Sound Effect Generator</h1>
            <p style="text-align: center;">
                ⚡ Powered by <a href="https://bilsimaging.com" target="_blank">Bilsimaging</a>
            </p>
        """)

        with gr.Row():
            image_input = gr.Image(type="filepath", label="Upload Image")
            caption_output = gr.Textbox(label="Generated Description", interactive=False)

        with gr.Row():
            generate_btn = gr.Button("Generate Description", variant="primary")
            audio_output = gr.Audio(label="Generated Sound", interactive=False)
            sound_btn = gr.Button("Generate Sound", variant="secondary")

        gr.Examples(
            examples=[str(Path(__file__).parent / "examples" / f) for f in ["storm.jpg", "city.jpg"]],
            inputs=image_input,
            outputs=[caption_output, audio_output],
            fn=lambda x: (analyze_image(Path(x).read_bytes())[0], None),
            cache_examples=True,
        )

        gr.Markdown("### 🛠️ Usage Tips")
        gr.Markdown("""
        - Use clear, high-contrast images for best results
        - Complex scenes may require multiple generations
        - Keep sound generation under 10 seconds for quick results
        """)

        gr.Markdown("### ⚠️ Disclaimer", elem_classes="disclaimer")
        gr.Markdown("""
        Generated content may not always be accurate. Use at your own discretion.
        [Privacy Policy](https://bilsimaging.com/privacy) |
        [Terms of Service](https://bilsimaging.com/terms)
        """)

    # Event handling (registered inside the Blocks context so the listeners
    # attach to this demo)
    generate_btn.click(
        fn=lambda x: analyze_image(Path(x).read_bytes())[0] if x else "Please upload an image first.",
        inputs=image_input,
        outputs=caption_output,
        api_name="describe",
    )

    sound_btn.click(
        fn=generate_audio,
        inputs=caption_output,
        outputs=audio_output,
        api_name="generate_sound",
    )
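
    # The api_name values expose the two handlers as named endpoints
    # (/describe and /generate_sound) on the app's API, callable for
    # instance via the gradio_client package.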

    # Input handling: clear stale outputs whenever a new image is uploaded
    image_input.change(
        fn=lambda: [gr.update(value=""), gr.update(value=None)],
        outputs=[caption_output, audio_output],
    )

if __name__ == "__main__":
    # SPACE_ID is set in the Hugging Face Spaces environment; bind to all
    # interfaces there, localhost otherwise.
    demo.launch(server_name="0.0.0.0" if os.getenv("SPACE_ID") else "127.0.0.1")