Bils committed on
Commit e18ae6e · verified · Parent: 4a36f0d

Update app.py

Files changed (1)
  1. app.py +132 -88
app.py CHANGED
@@ -10,169 +10,213 @@ from transformers import pipeline
 from pydub import AudioSegment
 import numpy as np
 
+# Load environment variables
 load_dotenv()
 hf_token = os.getenv("HF_TKN")
 
-device_id = 0 if torch.cuda.is_available() else -1
-
-# Initialize models
-captioning_pipeline = pipeline(
-    "image-to-text",
-    model="nlpconnect/vit-gpt2-image-captioning",
-    device=device_id
-)
-
-pipe = DiffusionPipeline.from_pretrained(
-    "cvssp/audioldm2",
-    use_auth_token=hf_token
-)
-
-@spaces.GPU(duration=120)
+# Device configuration
+device = "cuda" if torch.cuda.is_available() else "cpu"
+torch_dtype = torch.float16 if device == "cuda" else torch.float32
+
+# Initialize models with automatic device detection
+@spaces.GPU(duration=120)
+def load_models():
+    global captioning_pipeline, pipe
+    captioning_pipeline = pipeline(
+        "image-to-text",
+        model="nlpconnect/vit-gpt2-image-captioning",
+        device=0 if torch.cuda.is_available() else -1
+    )
+    pipe = DiffusionPipeline.from_pretrained(
+        "cvssp/audioldm2",
+        use_auth_token=hf_token,
+        torch_dtype=torch_dtype
+    ).to(device)
+
+load_models()
+
+@spaces.GPU(duration=60)
 def analyze_image(image_file):
+    """Generate caption from image with error handling"""
     try:
         results = captioning_pipeline(image_file)
-        if not results or not isinstance(results, list):
-            return "Error: Could not generate caption.", True
-
-        caption = results[0].get("generated_text", "").strip()
-        return caption if caption else "No caption generated.", not bool(caption)
+        if results and isinstance(results, list):
+            return results[0].get("generated_text", "").strip()
+        return "Could not generate caption"
     except Exception as e:
-        return f"Error analyzing image: {e}", True
+        return f"Error: {str(e)}"
 
 @spaces.GPU(duration=120)
 def generate_audio(prompt):
+    """Generate audio from text prompt"""
     try:
-        pipe.to("cuda")
-        audio_output = pipe(
+        return pipe(
             prompt=prompt,
             num_inference_steps=50,
             guidance_scale=7.5
-        )
-        pipe.to("cpu")
-        return audio_output.audios[0]
+        ).audios[0]
     except Exception as e:
-        print(f"Error generating audio: {e}")
+        print(f"Audio generation error: {str(e)}")
         return None
 
 def blend_audios(audio_list):
+    """Mix multiple audio arrays into one"""
     try:
-        # Find the longest audio duration
-        max_length = max([arr.shape[0] for arr in audio_list])
-
-        # Mix all audios
+        valid_audios = [arr for arr in audio_list if arr is not None]
+        if not valid_audios:
+            return None
+
+        max_length = max(arr.shape[0] for arr in valid_audios)
         mixed = np.zeros(max_length)
-        for arr in audio_list:
+
+        for arr in valid_audios:
             if arr.shape[0] < max_length:
                 padded = np.pad(arr, (0, max_length - arr.shape[0]))
             else:
                 padded = arr[:max_length]
             mixed += padded
 
-        # Normalize the audio
         mixed = mixed / np.max(np.abs(mixed))
-
-        # Save to temporary file
         _, tmp_path = tempfile.mkstemp(suffix=".wav")
         write(tmp_path, 16000, mixed)
         return tmp_path
     except Exception as e:
-        print(f"Error blending audio: {e}")
+        print(f"Blending error: {str(e)}")
         return None
 
 css = """
 #col-container { max-width: 800px; margin: 0 auto; }
 .toggle-row { margin: 1rem 0; }
 .prompt-box { margin-bottom: 0.5rem; }
+.danger { color: #ff4444; font-weight: bold; }
 """
 
 with gr.Blocks(css=css) as demo:
     with gr.Column(elem_id="col-container"):
+        # Header Section
         gr.HTML("""
-            <h1 style="text-align: center;">🎶 Advanced Sound Generator</h1>
-            <p style="text-align: center;">⚡ Powered by Bilsimaging</p>
+            <h1 style="text-align: center;">🎶 Generate Sound Effects from Image or Text</h1>
+            <p style="text-align: center;">
+                ⚡ Powered by <a href="https://bilsimaging.com" target="_blank">Bilsimaging</a>
+            </p>
         """)
 
-        # Input mode toggle
+        # Input Mode Toggle
        input_mode = gr.Radio(
-            choices=["Image Input", "Text Prompts"],
+            choices=["Image Input", "Text Input"],
             value="Image Input",
             label="Select Input Mode",
             elem_classes="toggle-row"
         )
 
-        # Image input section
+        # Image Input Section
         with gr.Column(visible=True) as image_col:
             image_upload = gr.Image(type="filepath", label="Upload Image")
-            generate_desc_btn = gr.Button("Generate Description from Image")
+            generate_desc_btn = gr.Button("Generate Description from Image", variant="primary")
             caption_display = gr.Textbox(label="Generated Description", interactive=False)
 
-        # Text input section
+        # Text Input Section
         with gr.Column(visible=False) as text_col:
             with gr.Row():
-                prompt1 = gr.Textbox(label="Sound Prompt 1", lines=2)
-                prompt2 = gr.Textbox(label="Sound Prompt 2", lines=2)
+                prompt1 = gr.Textbox(label="Sound Prompt 1", lines=2, placeholder="Enter sound description...")
+                prompt2 = gr.Textbox(label="Sound Prompt 2", lines=2, placeholder="Enter sound description...")
             additional_prompts = gr.Column()
             add_prompt_btn = gr.Button("➕ Add Another Prompt", variant="secondary")
-            generate_sound_btn = gr.Button("Generate Blended Sound", variant="primary")
-
-        # Audio output
-        audio_output = gr.Audio(label="Final Sound Composition", interactive=False)
-
-        # Documentation section
+            gr.Markdown("<div class='danger'>Max 5 prompts for stability</div>")
+
+        # Generation Controls
+        generate_sound_btn = gr.Button("Generate Sound Effect", variant="primary")
+        audio_output = gr.Audio(label="Generated Sound Effect", interactive=False)
+
+        # Documentation Section
         gr.Markdown("""
-            ## 🎚️ How to Use
-            1. **Choose Input Mode** above
-            2. For images: Upload + Generate Description → Generate Sound
-            3. For text: Enter multiple sound prompts → Generate Blended Sound
-            [Support on Ko-fi](https://ko-fi.com/bilsimaging)
+            ## 👥 How You Can Contribute
+            We welcome contributions! Contact us at [[email protected]](mailto:[email protected]).
+            Support us on [Ko-fi](https://ko-fi.com/bilsimaging) - Bilel Aroua
         """)
 
-        # Visitor badge
+        # Visitor Badge
         gr.HTML("""
-            <div style="text-align: center; margin-top: 2rem;">
-                <a href="https://visitorbadge.io/status?path=YOUR_SPACE_URL">
-                    <img src="https://api.visitorbadge.io/api/visitors?path=YOUR_SPACE_URL&countColor=%23263759"/>
+            <div style="text-align: center;">
+                <a href="https://visitorbadge.io/status?path=https://huggingface.co/spaces/Bils/Generate-Sound-Effects-from-Image">
+                    <img src="https://api.visitorbadge.io/api/visitors?path=https://huggingface.co/spaces/Bils/Generate-Sound-Effects-from-Image&countColor=%23263759"/>
                 </a>
             </div>
         """)
 
-        # Toggle visibility based on input mode
-        def toggle_input(mode):
-            if mode == "Image Input":
-                return [gr.update(visible=True), gr.update(visible=False)]
-            return [gr.update(visible=False), gr.update(visible=True)]
-
+        # Input Mode Toggle Handler
         input_mode.change(
-            fn=toggle_input,
+            lambda mode: (gr.update(visible=mode == "Image Input"), gr.update(visible=mode == "Text Input")),
             inputs=input_mode,
-            outputs=[image_col, text_col]
+            outputs=[image_col, text_col],
+            concurrency_limit=1
         )
 
-        # Image processing chain
+        # Image Description Generation
         generate_desc_btn.click(
-            fn=analyze_image,
+            analyze_image,
             inputs=image_upload,
-            outputs=caption_display
-        ).then(
-            fn=lambda: gr.update(interactive=True),
-            outputs=generate_sound_btn
+            outputs=caption_display,
+            concurrency_limit=2
         )
 
-        # Text processing chain
-        generate_sound_btn.click(
-            fn=lambda *prompts: [p for p in prompts if p.strip()],
-            inputs=[prompt1, prompt2],
-            outputs=[]
-        ).then(
-            fn=lambda prompts: [generate_audio(p) for p in prompts],
-            outputs=[]
-        ).then(
-            fn=blend_audios,
-            outputs=audio_output
+        # Dynamic Prompt Addition
+        def add_prompt(current_count):
+            if current_count >= 5:
+                return current_count, gr.update()
+            new_count = current_count + 1
+            new_prompt = gr.Textbox(
+                label=f"Sound Prompt {new_count}",
+                lines=2,
+                visible=True,
+                placeholder="Enter sound description..."
+            )
+            return new_count, new_prompt
+
+        prompt_count = gr.State(2)
+        add_prompt_btn.click(
+            add_prompt,
+            inputs=prompt_count,
+            outputs=[prompt_count, additional_prompts],
+            concurrency_limit=1
         )
 
-        # Queue management
-        demo.queue(concurrency_count=2)
+        # Sound Generation Handler
+        def process_inputs(mode, image_file, caption, *prompts):
+            try:
+                if mode == "Image Input":
+                    if not image_file:
+                        raise gr.Error("Please upload an image")
+                    caption = analyze_image(image_file)
+                    prompts = [caption]
+                else:
+                    prompts = [p.strip() for p in prompts if p.strip()]
+                    if not prompts:
+                        raise gr.Error("Please enter at least one valid prompt")
+
+                # Generate individual audio tracks
+                audio_tracks = []
+                for prompt in prompts:
+                    if not prompt:
+                        continue
+                    audio = generate_audio(prompt)
+                    if audio is not None:
+                        audio_tracks.append(audio)
+
+                # Blend audio tracks
+                if not audio_tracks:
+                    return None
+                return blend_audios(audio_tracks)
+
+            except Exception as e:
+                raise gr.Error(f"Processing error: {str(e)}")
+
+        generate_sound_btn.click(
+            process_inputs,
+            inputs=[input_mode, image_upload, caption_display, prompt1, prompt2],
+            outputs=audio_output,
+            concurrency_limit=2
+        )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(max_threads=4)
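
Two structural changes stand out in this commit. Model construction moves out of module scope into `load_models()`, decorated with `@spaces.GPU`, the ZeroGPU pattern in which a GPU is attached only while the decorated call runs. And the old `demo.queue(concurrency_count=2)` call is dropped in favor of per-event `concurrency_limit=` arguments, which is how Gradio 4 caps concurrent workers (`concurrency_count` was removed from `queue()` in Gradio 4). A minimal, self-contained sketch of the combined pattern, using a placeholder text model rather than the captioning and audioldm2 pipelines loaded here:

```python
import spaces
import gradio as gr
from transformers import pipeline

pipe = None  # loaded lazily inside the GPU-decorated call

@spaces.GPU(duration=60)  # a GPU is attached for at most ~60 s per call
def generate(prompt):
    """Run the model while the Space holds a GPU."""
    global pipe
    if pipe is None:
        # Placeholder model for this sketch; the commit loads
        # nlpconnect/vit-gpt2-image-captioning and cvssp/audioldm2 instead.
        pipe = pipeline("text-generation", model="gpt2")
    return pipe(prompt, max_new_tokens=32)[0]["generated_text"]

with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    output_box = gr.Textbox(label="Output")
    run_btn = gr.Button("Run")
    # Gradio 4 style: concurrency is capped per event listener,
    # replacing the removed demo.queue(concurrency_count=...) API.
    run_btn.click(generate, inputs=prompt_box, outputs=output_box,
                  concurrency_limit=2)

if __name__ == "__main__":
    demo.launch()
```

One wiring caveat in the committed code: `add_prompt` can create up to five prompt boxes, but `generate_sound_btn.click` lists only `prompt1` and `prompt2` in its `inputs`, so any dynamically added prompts never reach `process_inputs`.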