fffiloni committed
Commit d7bf027 · verified · 1 Parent(s): 7170008

Update app.py

Files changed (1)
  1. app.py +53 -18
app.py CHANGED
@@ -7,6 +7,31 @@ import json
 import tempfile
 from huggingface_hub import snapshot_download
 
+import soundfile as sf
+import tempfile
+from datetime import datetime
+
+is_shared_ui = True if "fffiloni/Meigen-MultiTalk" in os.environ['SPACE_ID'] else False
+
+def trim_audio_to_5s_temp(audio_path, sample_rate=16000):
+    max_duration_sec = 5
+    audio, sr = sf.read(audio_path)
+
+    if sr != sample_rate:
+        raise ValueError(f"Expected sample rate {sample_rate}, but got {sr}")
+
+    max_samples = max_duration_sec * sample_rate
+    if len(audio) > max_samples:
+        audio = audio[:max_samples]
+
+    timestamp = datetime.now().strftime("%Y%m%d%H%M%S%f")
+    base_name = os.path.splitext(os.path.basename(audio_path))[0]
+    temp_filename = f"{base_name}_trimmed_{timestamp}.wav"
+    temp_path = os.path.join(tempfile.gettempdir(), temp_filename)
+
+    sf.write(temp_path, audio, samplerate=sample_rate)
+    return temp_path
+
 num_gpus = torch.cuda.device_count()
 print(f"GPU AVAILABLE: {num_gpus}")
 
@@ -114,6 +139,11 @@ def create_temp_input_json(prompt: str, cond_image_path: str, cond_audio_path: s
 
 
 def infer(prompt, cond_image_path, cond_audio_path):
+
+    if is_shared_ui:
+        trimmed_audio_path = trim_audio_to_5s_temp(cond_audio_path)
+        cond_audio_path = trimmed_audio_path
+
     # Prepare input JSON
     input_json_path = create_temp_input_json(prompt, cond_image_path, cond_audio_path)
 
@@ -140,24 +170,29 @@ def infer(prompt, cond_image_path, cond_audio_path):
     else:
         cmd = ["python3", "generate_multitalk.py"] + common_args
 
-    # Log to file and stream
-    with open("inference.log", "w") as log_file:
-        process = subprocess.Popen(
-            cmd,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            text=True,
-            bufsize=1
-        )
-        for line in process.stdout:
-            print(line, end="")
-            log_file.write(line)
-        process.wait()
-
-    if process.returncode != 0:
-        raise RuntimeError("Inference failed. Check inference.log for details.")
-
-    return "multi_long_multigpu_exp.mp4"
+    try:
+        # Log to file and stream
+        with open("inference.log", "w") as log_file:
+            process = subprocess.Popen(
+                cmd,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.STDOUT,
+                text=True,
+                bufsize=1
+            )
+            for line in process.stdout:
+                print(line, end="")
+                log_file.write(line)
+            process.wait()
+
+        if process.returncode != 0:
+            raise RuntimeError("Inference failed. Check inference.log for details.")
+
+        return "multi_long_multigpu_exp.mp4"
+
+    finally:
+        if os.path.exists(trimmed_audio_path):
+            os.remove(trimmed_audio_path)
 
 
 with gr.Blocks(title="MultiTalk Inference") as demo:
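For context, here is a small, hypothetical smoke test for the new trim_audio_to_5s_temp helper. It is not part of the commit; the tone_7s.wav file name and the test tone are made up for illustration. It writes a 7-second, 16 kHz wave with soundfile, trims it, and checks that the result is capped at 5 seconds (80,000 samples):

import numpy as np
import soundfile as sf

# Illustrative only: assumes trim_audio_to_5s_temp from the diff above is in scope.
sr = 16000
t = np.linspace(0, 7, 7 * sr, endpoint=False)
sf.write("tone_7s.wav", 0.1 * np.sin(2 * np.pi * 440 * t), samplerate=sr)  # 7 s input

trimmed_path = trim_audio_to_5s_temp("tone_7s.wav")
audio, sr_out = sf.read(trimmed_path)
assert sr_out == sr
assert len(audio) <= 5 * sr  # at most 80,000 samples
print(f"{trimmed_path}: {len(audio) / sr:.2f} s")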
 
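One caveat with the new gating line: os.environ['SPACE_ID'] raises KeyError when the SPACE_ID variable is unset, for example when the Space code is run locally. A tolerant variant (a suggestion, not the committed code) would use os.environ.get:

import os

# Suggested variant, not committed: fall back to an empty string when
# SPACE_ID is absent, so local runs do not crash at import time.
is_shared_ui = "fffiloni/Meigen-MultiTalk" in os.environ.get("SPACE_ID", "")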
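Also worth noting: trimmed_audio_path is only bound inside the `if is_shared_ui:` branch, so the new `finally` block raises UnboundLocalError whenever is_shared_ui is False. A defensive sketch of infer(), assuming the same helpers as the diff and eliding the unchanged subprocess body:

import os

# Sketch, not the committed code: initialize trimmed_audio_path up front so
# the cleanup in `finally` is safe on shared and non-shared deployments alike.
def infer(prompt, cond_image_path, cond_audio_path):
    trimmed_audio_path = None
    if is_shared_ui:
        trimmed_audio_path = trim_audio_to_5s_temp(cond_audio_path)
        cond_audio_path = trimmed_audio_path
    try:
        # ... build cmd and run generate_multitalk.py exactly as in the diff ...
        return "multi_long_multigpu_exp.mp4"
    finally:
        # Only delete the trimmed temp file if one was actually created.
        if trimmed_audio_path is not None and os.path.exists(trimmed_audio_path):
            os.remove(trimmed_audio_path)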