Update app.py
app.py CHANGED
@@ -1,5 +1,7 @@
 import gradio as gr
 from gradio_client import Client
+import cv2
+from moviepy.editor import *

 # 1. extract and store 1 image every 5 images from video input
 # 2. extract audio
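The plan above samples by frame count rather than by time; if a time-based rate is wanted instead, the frame interval can be derived from the clip's frame rate. A minimal sketch (the frames_per_interval helper is hypothetical, not part of this commit):

import cv2

def frames_per_interval(video_path, seconds=1.0):
    # Convert "one image every N seconds" into the frame-count interval
    # expected by the frame-sampling loop defined later in this file.
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0  # fall back if FPS metadata is missing
    cap.release()
    return max(1, int(round(fps * seconds)))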
@@ -7,17 +9,150 @@ from gradio_client import Client
 # 4. for audio, ask audio questioning model to describe sound/scene
 # 5. give all to LLM, and ask it to summarize, based on the image caption list combined with the audio caption

+import re
+import os  # needed for os.path.exists() in extract_audio()
+import torch
+from transformers import pipeline

+zephyr_model = "HuggingFaceH4/zephyr-7b-beta"
+pipe = pipeline("text-generation", model=zephyr_model, torch_dtype=torch.bfloat16, device_map="auto")

+standard_sys = f"""

+"""
+
+def extract_frames(video_in, interval=12, output_format='.jpg'):
+    """Extract frames from a video at a specified interval and store them in a list.
+
+    Args:
+    - video_in: string or path-like object pointing to the video file
+    - interval: integer specifying how many frames apart to extract images (default: 12)
+    - output_format: string indicating the desired format for saved images (default: '.jpg')
+
+    Returns:
+    A list of strings containing paths to saved images.
+    """
+
+    # Initialize variables
+    vidcap = cv2.VideoCapture(video_in)
+    frames = []
+    count = 0
+
+    # Loop through frames until there are no more
+    while True:
+        success, image = vidcap.read()
+
+        # Check for a successful read (stops past the end of the video)
+        if success:
+            print('Read a new frame:', success)
+
+            # Save the current frame if it falls on the sampling interval
+            if count % interval == 0:
+                filename = f'frame_{count // interval}{output_format}'
+                frames.append(filename)
+                cv2.imwrite(filename, image)
+                print(f'Saved {filename}')
+
+            # Increment counter
+            count += 1
+
+        # Break out of the loop when done reading frames
+        else:
+            break
+
+    # Close the video capture
+    vidcap.release()
+    print('Done extracting frames!')
+
+    return frames
+
+def process_image(image_in):
+    client = Client("vikhyatk/moondream2")
+    result = client.predict(
+        image_in,  # filepath in 'image' Image component
+        "Describe precisely the image in one sentence.",  # str in 'Question' Textbox component
+        api_name="/answer_question"
+        #api_name="/predict"
+    )
+    print(result)
+    return result
+
+def extract_audio(video_path):
+    # Open the video clip and extract the audio stream
+    audio = VideoFileClip(video_path).audio
+
+    # Set the output file path and format
+    output_path = 'output_audio.wav'
+
+    # Write the audio stream to disk as 16-bit PCM (a codec the .wav container supports)
+    audio.write_audiofile(output_path, codec='pcm_s16le')
+
+    # Confirm that the audio file was written successfully
+    if os.path.exists(output_path):
+        print(f'Successfully wrote audio to {output_path}.')
+        return output_path
+    else:
+        raise FileNotFoundError(f'Failed to write audio to {output_path}.')
+
+def get_salmonn(audio_in):
+    salmonn_prompt = "Please list each event in the audio in order."
+    client = Client("fffiloni/SALMONN-7B-gradio")
+    result = client.predict(
+        audio_in,        # filepath in 'Audio' Audio component
+        salmonn_prompt,  # str in 'User question' Textbox component
+        4,    # float (numeric value between 1 and 10) in 'beam search numbers' Slider component
+        1,    # float (numeric value between 0.8 and 2.0) in 'temperature' Slider component
+        0.9,  # float (numeric value between 0.1 and 1.0) in 'top p' Slider component
+        api_name="/gradio_answer"
+    )
+    print(result)
+    return result
+
+def llm_process(user_prompt):
+    agent_maker_sys = standard_sys
+
+    instruction = f"""
+<|system|>
+{agent_maker_sys}</s>
+<|user|>
+"""
+
+    prompt = f"{instruction.strip()}\n{user_prompt}</s>"
+    outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
+    pattern = r'\<\|system\|\>(.*?)\<\|assistant\|\>'
+    cleaned_text = re.sub(pattern, '', outputs[0]["generated_text"], flags=re.DOTALL)
+
+    print(f"SUGGESTED video description: {cleaned_text}")
+    return cleaned_text.lstrip("\n")

 def infer(video_in):
+    # Extract frames from the video
+    frame_files = extract_frames(video_in)

+    # Caption each extracted frame and collect the results in a list
+    processed_texts = []
+    for frame_file in frame_files:
+        text = process_image(frame_file)
+        processed_texts.append(text)
+    print(processed_texts)
+
+    # Join the frame captions into a single string with line breaks
+    string_list = '\n'.join(processed_texts)
+
+    # Extract audio from the video
+    extracted_audio = extract_audio(video_in)
+
+    # Get a description of the audio content
+    audio_content_described = get_salmonn(extracted_audio)
+
+    # Assemble captions
+    formatted_captions = f"""
+    ### Visual events:\n{string_list}\n ### Audio events:\n{audio_content_described}
+    """
+    print(formatted_captions)
+
+    # Send formatted captions to the LLM
+    #video_description_from_llm = llm_process(formatted_captions)
+    return formatted_captions

 with gr.Blocks() as demo :
     with gr.Column(elem_id="col-container"):
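The two captioning helpers above call remote Spaces through gradio_client, so they depend on those Spaces keeping the /answer_question and /gradio_answer endpoints and the argument order shown in the comments. A quick way to double-check that before relying on them, as a hypothetical verification step rather than part of the commit:

from gradio_client import Client

# Print each remote Space's documented endpoints and parameters,
# e.g. to confirm /answer_question still expects an image filepath plus a question string.
Client("vikhyatk/moondream2").view_api()
Client("fffiloni/SALMONN-7B-gradio").view_api()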
@@ -25,7 +160,7 @@ with gr.Blocks() as demo :
         <h2 style="text-align: center;">Video description</h2>
         """)
         video_in = gr.Video(label="Video input")
-        submit_btn = gr.Button("
+        submit_btn = gr.Button("Submit")
         video_description = gr.Textbox(label="Video description")
         submit_btn.click(
             fn = infer,
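llm_process assembles the zephyr prompt by hand with <|system|>/<|user|>/</s> markers. An equivalent construction that leans on the tokenizer's own chat template, sketched under the assumption that pipe is the text-generation pipeline created above and captions_text stands in for the formatted captions (neither of these lines is part of the commit):

# Build the zephyr-style prompt from the model's chat template instead of
# concatenating the <|system|>/<|user|> markers manually.
messages = [
    {"role": "system", "content": standard_sys},
    {"role": "user", "content": captions_text},  # e.g. the formatted_captions built in infer()
]
prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)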