Spaces:
Runtime error
Runtime error
crypto-code
commited on
Commit
•
ead7a82
1
Parent(s):
15a96ee
Update app.py
Browse files
app.py
CHANGED
@@ -20,6 +20,7 @@ import torchvision.transforms as transforms
|
|
20 |
import av
|
21 |
import subprocess
|
22 |
import librosa
|
|
|
23 |
|
24 |
args = {"model": "./ckpts/checkpoint.pth", "llama_type": "7B", "llama_dir": "./ckpts/LLaMA-2",
|
25 |
"mert_path": "m-a-p/MERT-v1-330M", "vit_path": "google/vit-base-patch16-224", "vivit_path": "google/vivit-b-16x2-kinetics400",
|
@@ -33,7 +34,7 @@ class dotdict(dict):
|
|
33 |
|
34 |
args = dotdict(args)
|
35 |
|
36 |
-
generated_audio_files =
|
37 |
|
38 |
llama_type = args.llama_type
|
39 |
llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
|
@@ -117,7 +118,7 @@ def parse_text(text, image_path, video_path, audio_path):
|
|
117 |
return text, outputs
|
118 |
|
119 |
|
120 |
-
def save_audio_to_local(audio, sec):
|
121 |
global generated_audio_files
|
122 |
if not os.path.exists('temp'):
|
123 |
os.mkdir('temp')
|
@@ -126,11 +127,11 @@ def save_audio_to_local(audio, sec):
|
|
126 |
scipy.io.wavfile.write(filename, rate=16000, data=audio[0])
|
127 |
else:
|
128 |
scipy.io.wavfile.write(filename, rate=model.generation_model.config.audio_encoder.sampling_rate, data=audio)
|
129 |
-
generated_audio_files.append(filename)
|
130 |
return filename
|
131 |
|
132 |
|
133 |
-
def parse_reponse(model_outputs, audio_length_in_s):
|
134 |
response = ''
|
135 |
text_outputs = []
|
136 |
for output_i, p in enumerate(model_outputs):
|
@@ -146,7 +147,7 @@ def parse_reponse(model_outputs, audio_length_in_s):
|
|
146 |
response += '<br>'
|
147 |
_temp_output += m.replace(' '.join([f'[AUD{i}]' for i in range(8)]), '')
|
148 |
else:
|
149 |
-
filename = save_audio_to_local(m, audio_length_in_s)
|
150 |
print(filename)
|
151 |
_temp_output = f'<Audio>{filename}</Audio> ' + _temp_output
|
152 |
response += f'<audio controls playsinline><source src="./file={filename}" type="audio/wav"></audio>'
|
@@ -161,15 +162,15 @@ def reset_user_input():
|
|
161 |
return gr.update(value='')
|
162 |
|
163 |
|
164 |
-
def reset_dialog():
|
165 |
global generated_audio_files
|
166 |
-
generated_audio_files = []
|
167 |
return [], []
|
168 |
|
169 |
|
170 |
-
def reset_state():
|
171 |
global generated_audio_files
|
172 |
-
generated_audio_files = []
|
173 |
return None, None, None, None, [], [], []
|
174 |
|
175 |
|
@@ -218,6 +219,7 @@ def get_audio_length(filename):
|
|
218 |
|
219 |
|
220 |
def predict(
|
|
|
221 |
prompt_input,
|
222 |
image_path,
|
223 |
audio_path,
|
@@ -247,28 +249,30 @@ def predict(
|
|
247 |
indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
|
248 |
video = read_video_pyav(container=container, indices=indices)
|
249 |
|
250 |
-
if len(generated_audio_files) != 0:
|
251 |
-
audio_length_in_s = get_audio_length(generated_audio_files[-1])
|
252 |
sample_rate = 24000
|
253 |
-
waveform, sr = torchaudio.load(generated_audio_files[-1])
|
254 |
if sample_rate != sr:
|
255 |
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
|
256 |
audio = torch.mean(waveform, 0)
|
257 |
audio_length_in_s = int(len(audio)//sample_rate)
|
258 |
print(f"Audio Length: {audio_length_in_s}")
|
|
|
|
|
259 |
if video_path is not None:
|
260 |
audio_length_in_s = get_video_length(video_path)
|
261 |
print(f"Video Length: {audio_length_in_s}")
|
262 |
if audio_path is not None:
|
263 |
audio_length_in_s = get_audio_length(audio_path)
|
264 |
-
generated_audio_files.append(audio_path)
|
265 |
print(f"Audio Length: {audio_length_in_s}")
|
266 |
|
267 |
print(image, video, audio)
|
268 |
response = model.generate(prompts, audio, image, video, 200, temperature, top_p,
|
269 |
audio_length_in_s=audio_length_in_s)
|
270 |
print(response)
|
271 |
-
response_chat, response_outputs = parse_reponse(response, audio_length_in_s)
|
272 |
print('text_outputs: ', response_outputs)
|
273 |
user_chat, user_outputs = parse_text(prompt_input, image_path, video_path, audio_path)
|
274 |
chatbot.append((user_chat, response_chat))
|
@@ -319,9 +323,11 @@ with gr.Blocks() as demo:
|
|
319 |
|
320 |
history = gr.State([])
|
321 |
modality_cache = gr.State([])
|
|
|
322 |
|
323 |
submitBtn.click(
|
324 |
predict, [
|
|
|
325 |
user_input,
|
326 |
image_path,
|
327 |
audio_path,
|
@@ -343,8 +349,8 @@ with gr.Blocks() as demo:
|
|
343 |
show_progress=True
|
344 |
)
|
345 |
|
346 |
-
submitBtn.click(reset_user_input, [], [user_input])
|
347 |
-
emptyBtn.click(reset_state, outputs=[
|
348 |
image_path,
|
349 |
audio_path,
|
350 |
video_path,
|
|
|
20 |
import av
|
21 |
import subprocess
|
22 |
import librosa
|
23 |
+
import uuid
|
24 |
|
25 |
args = {"model": "./ckpts/checkpoint.pth", "llama_type": "7B", "llama_dir": "./ckpts/LLaMA-2",
|
26 |
"mert_path": "m-a-p/MERT-v1-330M", "vit_path": "google/vit-base-patch16-224", "vivit_path": "google/vivit-b-16x2-kinetics400",
|
|
|
34 |
|
35 |
args = dotdict(args)
|
36 |
|
37 |
+
generated_audio_files = {}
|
38 |
|
39 |
llama_type = args.llama_type
|
40 |
llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
|
|
|
118 |
return text, outputs
|
119 |
|
120 |
|
121 |
+
def save_audio_to_local(uid, audio, sec):
|
122 |
global generated_audio_files
|
123 |
if not os.path.exists('temp'):
|
124 |
os.mkdir('temp')
|
|
|
127 |
scipy.io.wavfile.write(filename, rate=16000, data=audio[0])
|
128 |
else:
|
129 |
scipy.io.wavfile.write(filename, rate=model.generation_model.config.audio_encoder.sampling_rate, data=audio)
|
130 |
+
generated_audio_files[uid].append(filename)
|
131 |
return filename
|
132 |
|
133 |
|
134 |
+
def parse_reponse(uid, model_outputs, audio_length_in_s):
|
135 |
response = ''
|
136 |
text_outputs = []
|
137 |
for output_i, p in enumerate(model_outputs):
|
|
|
147 |
response += '<br>'
|
148 |
_temp_output += m.replace(' '.join([f'[AUD{i}]' for i in range(8)]), '')
|
149 |
else:
|
150 |
+
filename = save_audio_to_local(uid, m, audio_length_in_s)
|
151 |
print(filename)
|
152 |
_temp_output = f'<Audio>{filename}</Audio> ' + _temp_output
|
153 |
response += f'<audio controls playsinline><source src="./file={filename}" type="audio/wav"></audio>'
|
|
|
162 |
return gr.update(value='')
|
163 |
|
164 |
|
165 |
+
def reset_dialog(uid):
|
166 |
global generated_audio_files
|
167 |
+
generated_audio_files[uid] = []
|
168 |
return [], []
|
169 |
|
170 |
|
171 |
+
def reset_state(uid):
|
172 |
global generated_audio_files
|
173 |
+
generated_audio_files[uid] = []
|
174 |
return None, None, None, None, [], [], []
|
175 |
|
176 |
|
|
|
219 |
|
220 |
|
221 |
def predict(
|
222 |
+
uid,
|
223 |
prompt_input,
|
224 |
image_path,
|
225 |
audio_path,
|
|
|
249 |
indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
|
250 |
video = read_video_pyav(container=container, indices=indices)
|
251 |
|
252 |
+
if uid in generated_audio_files and len(generated_audio_files[uid]) != 0:
|
253 |
+
audio_length_in_s = get_audio_length(generated_audio_files[uid][-1])
|
254 |
sample_rate = 24000
|
255 |
+
waveform, sr = torchaudio.load(generated_audio_files[uid][-1])
|
256 |
if sample_rate != sr:
|
257 |
waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
|
258 |
audio = torch.mean(waveform, 0)
|
259 |
audio_length_in_s = int(len(audio)//sample_rate)
|
260 |
print(f"Audio Length: {audio_length_in_s}")
|
261 |
+
else:
|
262 |
+
generated_audio_files[uid] = []
|
263 |
if video_path is not None:
|
264 |
audio_length_in_s = get_video_length(video_path)
|
265 |
print(f"Video Length: {audio_length_in_s}")
|
266 |
if audio_path is not None:
|
267 |
audio_length_in_s = get_audio_length(audio_path)
|
268 |
+
generated_audio_files[uid].append(audio_path)
|
269 |
print(f"Audio Length: {audio_length_in_s}")
|
270 |
|
271 |
print(image, video, audio)
|
272 |
response = model.generate(prompts, audio, image, video, 200, temperature, top_p,
|
273 |
audio_length_in_s=audio_length_in_s)
|
274 |
print(response)
|
275 |
+
response_chat, response_outputs = parse_reponse(uid, response, audio_length_in_s)
|
276 |
print('text_outputs: ', response_outputs)
|
277 |
user_chat, user_outputs = parse_text(prompt_input, image_path, video_path, audio_path)
|
278 |
chatbot.append((user_chat, response_chat))
|
|
|
323 |
|
324 |
history = gr.State([])
|
325 |
modality_cache = gr.State([])
|
326 |
+
uid = gr.State(uuid.uuid4())
|
327 |
|
328 |
submitBtn.click(
|
329 |
predict, [
|
330 |
+
uid,
|
331 |
user_input,
|
332 |
image_path,
|
333 |
audio_path,
|
|
|
349 |
show_progress=True
|
350 |
)
|
351 |
|
352 |
+
submitBtn.click(reset_user_input, [uid], [user_input])
|
353 |
+
emptyBtn.click(reset_state, [uid], outputs=[
|
354 |
image_path,
|
355 |
audio_path,
|
356 |
video_path,
|