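"""Image-to-sound-effect Gradio app.

Captions an uploaded image with nlpconnect/vit-gpt2-image-captioning, then
turns the caption into a short sound effect with cvssp/audioldm2.
"""
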
import io
import os
from pathlib import Path
from typing import Tuple, Optional

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from dotenv import load_dotenv
from diffusers import DiffusionPipeline
from transformers import pipeline
from huggingface_hub import login

# Load environment variables (expects HF_TKN in .env or the Space secrets)
load_dotenv()
hf_token = os.getenv("HF_TKN")
if hf_token:
    login(token=hf_token)
# Device configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
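# float16 halves the diffusion model's memory footprint on GPU; CPU inference
# stays in float32, which PyTorch supports more reliably.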

# Load models
@spaces.GPU
def load_models():
    """Load both models with proper device placement"""
    caption_pipe = pipeline(
        "image-to-text",
        model="nlpconnect/vit-gpt2-image-captioning",
        device=device,
    )
    audio_pipe = DiffusionPipeline.from_pretrained(
        "cvssp/audioldm2",
        token=hf_token,
        torch_dtype=torch_dtype,
    )
    return caption_pipe, audio_pipe

caption_pipe, audio_pipe = load_models()
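# The two pipelines above are created once at import time, so every Gradio
# request reuses the same model instances.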

def analyze_image(image_bytes: bytes) -> Tuple[str, bool]:
    """Generate caption from image bytes with enhanced error handling"""
    try:
        image = Image.open(io.BytesIO(image_bytes))
        if image.mode != "RGB":
            image = image.convert("RGB")
        results = caption_pipe(image)
        if not results or not isinstance(results, list):
            return "Error: Invalid response from caption model", True
        caption = results[0].get("generated_text", "").strip()
        return caption or "No caption generated", not bool(caption)
    except Exception as e:
        return f"Image processing error: {str(e)}", True

@spaces.GPU(duration=120)
def generate_audio(caption: str) -> Optional[Tuple[int, np.ndarray]]:
    """Generate audio from caption with resource management"""
    # DiffusionPipeline tracks placement via .device (it is not an
    # nn.Module, so it has no .parameters())
    original_device = audio_pipe.device
    try:
        audio_pipe.to(device)
        # Generation with progress awareness
        audio = audio_pipe(
            prompt=caption,
            num_inference_steps=50,
            guidance_scale=7.5,
            audio_length_in_s=5.0,  # Keep audio generation short
        ).audios[0]
        # Post-processing
        audio = audio.squeeze()  # Handle mono channel
        audio = np.clip(audio, -1, 1)  # Ensure valid range
        return (16000, audio)  # AudioLDM2 outputs 16 kHz audio
    except Exception as e:
        print(f"Audio generation error: {str(e)}")
        return None
    finally:
        audio_pipe.to(original_device)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# UI Components
css = """
#col-container {
max-width: 800px;
margin: 0 auto;
}
.disclaimer {
font-size: 0.9em;
color: #666;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""
        <h1 style="text-align: center;">🎶 Image to Sound Effect Generator</h1>
        <p style="text-align: center;">
            ⚡ Powered by <a href="https://bilsimaging.com" target="_blank">Bilsimaging</a>
        </p>
        """)
        with gr.Row():
            image_input = gr.Image(type="filepath", label="Upload Image")
            caption_output = gr.Textbox(label="Generated Description", interactive=False)
        with gr.Row():
            generate_btn = gr.Button("Generate Description", variant="primary")
            audio_output = gr.Audio(label="Generated Sound", interactive=False)
            sound_btn = gr.Button("Generate Sound", variant="secondary")
        gr.Examples(
            examples=[str(Path(__file__).parent / "examples" / f) for f in ["storm.jpg", "city.jpg"]],
            inputs=image_input,
            outputs=[caption_output, audio_output],
            fn=lambda x: (analyze_image(Path(x).read_bytes())[0], None),
            cache_examples=True,
        )
        gr.Markdown("### 🛠️ Usage Tips")
        gr.Markdown("""
        - Use clear, high-contrast images for best results
        - Complex scenes may require multiple generations
        - Keep sound generation under 10 seconds for quick results
        """)
        gr.Markdown("### ⚠️ Disclaimer", elem_classes="disclaimer")
        gr.Markdown("""
        Generated content may not always be accurate. Use at your own discretion.
        [Privacy Policy](https://bilsimaging.com/privacy) |
        [Terms of Service](https://bilsimaging.com/terms)
        """)

    # Event handling
    generate_btn.click(
        fn=lambda x: analyze_image(Path(x).read_bytes())[0],
        inputs=image_input,
        outputs=caption_output,
        api_name="describe",
    )
    sound_btn.click(
        fn=generate_audio,
        inputs=caption_output,
        outputs=audio_output,
        api_name="generate_sound",
    )

    # Clear stale outputs whenever a new image is selected
    image_input.change(
        fn=lambda: [gr.update(value=""), gr.update(value=None)],
        outputs=[caption_output, audio_output],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0" if os.getenv("SPACE_ID") else "127.0.0.1")