import os
import tempfile

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import DiffusionPipeline
from dotenv import load_dotenv
from pydub import AudioSegment
from scipy.io.wavfile import write
from transformers import pipeline

load_dotenv()
hf_token = os.getenv("HF_TKN")

# Use the first GPU for the captioning pipeline if available, otherwise CPU
device_id = 0 if torch.cuda.is_available() else -1

# Initialize models
captioning_pipeline = pipeline(
    "image-to-text",
    model="nlpconnect/vit-gpt2-image-captioning",
    device=device_id,
)

pipe = DiffusionPipeline.from_pretrained(
    "cvssp/audioldm2",
    use_auth_token=hf_token,
)


@spaces.GPU(duration=120)
def analyze_image(image_file):
    """Return a text description of the uploaded image (or an error message)."""
    try:
        results = captioning_pipeline(image_file)
        if not results or not isinstance(results, list):
            return "Error: Could not generate caption."
        caption = results[0].get("generated_text", "").strip()
        return caption if caption else "No caption generated."
    except Exception as e:
        return f"Error analyzing image: {e}"


@spaces.GPU(duration=120)
def generate_audio(prompt):
    """Generate an audio waveform (NumPy array) from a text prompt."""
    try:
        pipe.to("cuda")
        audio_output = pipe(
            prompt=prompt,
            num_inference_steps=50,
            guidance_scale=7.5,
        )
        pipe.to("cpu")
        return audio_output.audios[0]
    except Exception as e:
        print(f"Error generating audio: {e}")
        return None


def blend_audios(audio_list):
    """Mix several waveforms into one normalized WAV file and return its path."""
    try:
        # Find the longest audio duration
        max_length = max(arr.shape[0] for arr in audio_list)

        # Pad (or trim) every clip to the same length and sum them
        mixed = np.zeros(max_length)
        for arr in audio_list:
            if arr.shape[0] < max_length:
                padded = np.pad(arr, (0, max_length - arr.shape[0]))
            else:
                padded = arr[:max_length]
            mixed += padded

        # Normalize the mix to avoid clipping (guard against an all-zero signal)
        peak = np.max(np.abs(mixed))
        if peak > 0:
            mixed = mixed / peak

        # Save to a temporary 16 kHz WAV file
        fd, tmp_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        write(tmp_path, 16000, mixed.astype(np.float32))

        return tmp_path
    except Exception as e:
        print(f"Error blending audio: {e}")
        return None


css = """
#col-container { max-width: 800px; margin: 0 auto; }
.toggle-row { margin: 1rem 0; }
.prompt-box { margin-bottom: 0.5rem; }
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""
⚡ Powered by Bilsimaging
""") # Input mode toggle input_mode = gr.Radio( choices=["Image Input", "Text Prompts"], value="Image Input", label="Select Input Mode", elem_classes="toggle-row" ) # Image input section with gr.Column(visible=True) as image_col: image_upload = gr.Image(type="filepath", label="Upload Image") generate_desc_btn = gr.Button("Generate Description from Image") caption_display = gr.Textbox(label="Generated Description", interactive=False) # Text input section with gr.Column(visible=False) as text_col: with gr.Row(): prompt1 = gr.Textbox(label="Sound Prompt 1", lines=2) prompt2 = gr.Textbox(label="Sound Prompt 2", lines=2) additional_prompts = gr.Column() add_prompt_btn = gr.Button("➕ Add Another Prompt", variant="secondary") generate_sound_btn = gr.Button("Generate Blended Sound", variant="primary") # Audio output audio_output = gr.Audio(label="Final Sound Composition", interactive=False) # Documentation section gr.Markdown(""" ## 🎚️ How to Use 1. **Choose Input Mode** above 2. For images: Upload + Generate Description → Generate Sound 3. For text: Enter multiple sound prompts → Generate Blended Sound [Support on Ko-fi](https://ko-fi.com/bilsimaging) """) # Visitor badge gr.HTML(""" """) # Toggle visibility based on input mode def toggle_input(mode): if mode == "Image Input": return [gr.update(visible=True), gr.update(visible=False)] return [gr.update(visible=False), gr.update(visible=True)] input_mode.change( fn=toggle_input, inputs=input_mode, outputs=[image_col, text_col] ) # Image processing chain generate_desc_btn.click( fn=analyze_image, inputs=image_upload, outputs=caption_display ).then( fn=lambda: gr.update(interactive=True), outputs=generate_sound_btn ) # Text processing chain generate_sound_btn.click( fn=lambda *prompts: [p for p in prompts if p.strip()], inputs=[prompt1, prompt2], outputs=[] ).then( fn=lambda prompts: [generate_audio(p) for p in prompts], outputs=[] ).then( fn=blend_audios, outputs=audio_output ) # Queue management demo.queue(concurrency_count=2) if __name__ == "__main__": demo.launch()