|
import functools
import re

import cv2
import gradio as gr
import torch
from gradio_client import Client
from moviepy.editor import *
from transformers import pipeline
|
|
|
zephyr_model = "HuggingFaceH4/zephyr-7b-beta"

# Text-generation pipeline consumed by llm_process; it is built once, when the
# module is imported.
pipe = pipeline(
    "text-generation",
    model=zephyr_model,
    torch_dtype=torch.bfloat16,  # load the weights in bfloat16
    device_map="auto",  # let transformers decide device placement
)
|
|
|
standard_sys = f""" |
|
You will be provided a list of visual events, and an audio description. All these informations come from a single video. |
|
List of visual events are actually images extracted from this video every 12 frames. |
|
Audio events are actually the description from the audio of the video. |
|
Your job is to use these information to provide a short resume about what is happening in the video. |
|
""" |
|
|
|
def extract_frames(video_in, interval=24, output_format='.jpg'):
    """Extract every ``interval``-th frame of a video and save each one as an image.

    Frames are numbered from 0; frame 0, frame ``interval``, frame
    ``2 * interval``, ... are written to the current working directory as
    ``frame_<k><output_format>`` where ``k`` is the index of the saved image.

    Args:
        video_in: string or path-like object pointing to the video file.
        interval: positive integer; one frame is kept every ``interval`` frames
            (default: 24).
        output_format: file extension, including the leading dot, which
            ``cv2.imwrite`` uses to pick the image encoding (default: '.jpg').

    Returns:
        A list of strings containing the paths of the saved images, in frame order.

    Raises:
        ValueError: if ``interval`` is smaller than 1.
        IOError: if the video cannot be opened or an image cannot be written.
    """
    # Validate before touching the file: interval == 0 would otherwise raise
    # ZeroDivisionError mid-decode, and negative values make no sense here.
    if interval < 1:
        raise ValueError(f"interval must be a positive integer, got {interval!r}")

    vidcap = cv2.VideoCapture(video_in)
    try:
        # VideoCapture does not raise on a bad path; it just yields no frames.
        # Fail loudly instead of silently returning an empty list.
        if not vidcap.isOpened():
            raise IOError(f"Could not open video: {video_in!r}")

        frames = []
        count = 0
        while True:
            success, image = vidcap.read()
            if not success:
                # No more frames (end of stream or decode failure).
                break

            if count % interval == 0:
                filename = f'frame_{count // interval}{output_format}'
                # imwrite reports failure (e.g. unsupported extension) via its
                # return value rather than an exception.
                if not cv2.imwrite(filename, image):
                    raise IOError(f"Could not write frame image: {filename!r}")
                frames.append(filename)
                print(f'Saved {filename}')

            count += 1
    finally:
        # Always release the capture, even if decoding or writing failed.
        vidcap.release()

    print('Done extracting frames!')
    return frames
|
|
|
@functools.lru_cache(maxsize=1)
def _moondream_client():
    """Return the Client for the moondream1 Space, created on first use and reused afterwards."""
    return Client("https://vikhyatk-moondream1.hf.space/")


def process_image(image_in):
    """Caption a single image with the moondream1 Space.

    Args:
        image_in: path to the image file to describe.

    Returns:
        The result of the Space's ``/answer_question`` endpoint for the prompt
        "Describe precisely the image in one sentence." (also printed).
    """
    # infer() calls this once per extracted frame; sharing one Client avoids
    # re-running the Client's construction/setup for every frame.
    result = _moondream_client().predict(
        image_in,
        "Describe precisely the image in one sentence.",
        api_name="/answer_question",
    )
    print(result)
    return result
|
|
|
def extract_audio(video_path, output_path="output_audio.mp3"):
    """Write the audio track of a video to an audio file.

    Args:
        video_path: path to the video file.
        output_path: destination audio file; moviepy infers the codec from the
            extension (default: "output_audio.mp3", the previous fixed path).

    Returns:
        ``output_path``.

    Raises:
        ValueError: if the video has no audio track.
    """
    video_clip = VideoFileClip(video_path)
    try:
        audio_clip = video_clip.audio
        # moviepy sets .audio to None for videos without sound; calling
        # write_audiofile on None would give an unhelpful AttributeError.
        if audio_clip is None:
            raise ValueError(f"Video has no audio track: {video_path!r}")
        audio_clip.write_audiofile(output_path)
    finally:
        # Release the readers the clip opened; previously they were never closed.
        video_clip.close()
    return output_path
|
|
|
def get_salmonn(audio_in):
    """Ask the SALMONN-7B Space to describe an audio file.

    Args:
        audio_in: path to the audio file to describe.

    Returns:
        Whatever the Space's ``/gradio_answer`` endpoint returns (also printed).
    """
    # The three numeric positional arguments are forwarded unchanged to the
    # /gradio_answer endpoint. NOTE(review): their meaning is defined by the
    # Space's API signature, not visible here; confirm before changing them.
    answer = Client("fffiloni/SALMONN-7B-gradio").predict(
        audio_in,
        "Please describe the audio",
        4,
        1,
        0.9,
        api_name="/gradio_answer",
    )
    print(answer)
    return answer
|
|
|
def llm_process(user_prompt):
    """Summarize the video captions with the Zephyr text-generation pipeline.

    Args:
        user_prompt: the formatted visual/audio event text to summarize.

    Returns:
        The model's reply only (the prompt is not echoed), with leading
        newlines removed. The reply is also printed.
    """
    # Zephyr chat format: each turn starts with a role tag and ends with </s>.
    # The final "<|assistant|>" tag is required so the model knows it is its
    # turn to answer; without it the reply may continue the user turn instead.
    prompt = (
        f"<|system|>\n{standard_sys.strip()}</s>\n"
        f"<|user|>\n{user_prompt}</s>\n"
        "<|assistant|>\n"
    )
    outputs = pipe(
        prompt,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        # Return only the newly generated text, so no regex is needed to strip
        # the prompt back out of the output.
        return_full_text=False,
    )
    cleaned_text = outputs[0]["generated_text"]

    print(f"SUGGESTED video description: {cleaned_text}")

    return cleaned_text.lstrip("\n")
|
|
|
def infer(video_in):
    """Produce a text description of a video.

    Pipeline: sample frames and caption each one, extract and describe the
    audio, then ask the LLM to summarize both.

    Args:
        video_in: path to the uploaded video (None when nothing was uploaded).

    Returns:
        The LLM-generated description of the video.

    Raises:
        gr.Error: if no video was provided, so Gradio shows a readable message.
    """
    # Clicking Submit without uploading passes None; fail with a UI-friendly
    # message instead of an obscure error deep inside OpenCV.
    if video_in is None:
        raise gr.Error("Please upload a video first.")

    # Visual events: one caption per sampled frame.
    frame_files = extract_frames(video_in)
    processed_texts = [process_image(frame_file) for frame_file in frame_files]
    print(processed_texts)
    string_list = '\n'.join(processed_texts)

    # Audio events: a single description of the extracted soundtrack.
    extracted_audio = extract_audio(video_in)
    print(extracted_audio)
    audio_content_described = get_salmonn(extracted_audio)

    formatted_captions = (
        f"### Visual events:\n{string_list}\n"
        f"### Audio events:\n{audio_content_described}"
    )
    print(formatted_captions)

    return llm_process(formatted_captions)
|
|
|
# Gradio UI: upload a video, press Submit, and read the generated description.
with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""
            <h2 style="text-align: center;">Video description</h2>
        """)

        # Input, trigger and output widgets, top to bottom.
        video_in = gr.Video(label="Video input")
        submit_btn = gr.Button("Submit")
        video_description = gr.Textbox(label="Video description")

        submit_btn.click(
            fn=infer,
            inputs=[video_in],
            outputs=[video_description],
        )

# queue() returns the same Blocks object, so this equals demo.queue().launch().
demo.queue()
demo.launch()