fffiloni's picture
Update app.py
720f703 verified
raw
history blame
5.58 kB
import gradio as gr
from gradio_client import Client
import cv2
from moviepy.editor import *
# 1. Extract and store one frame out of every `interval` frames (24 by default) from the input video.
# 2. Extract the audio track.
# 3. For each extracted frame, get a caption from the captioning model and collect the captions into a list.
# 4. Ask an audio question-answering model to describe the sounds/scene in the audio.
# 5. Give everything to an LLM and ask it to summarize the video from the frame captions combined with the audio description (currently disabled in infer()).
import re
import torch
from transformers import pipeline
# Chat-tuned LLM used by llm_process(); loaded once when the module is imported.
zephyr_model = "HuggingFaceH4/zephyr-7b-beta"
pipe = pipeline(
    "text-generation",
    model=zephyr_model,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# System message placed in the <|system|> turn of the prompt built by llm_process().
# NOTE(review): this is currently just a newline, so the model receives no
# instructions — TODO fill in the summarization instructions.
standard_sys = "\n"
def extract_frames(video_in, interval=24, output_format='.jpg'):
    """Extract every `interval`-th frame of a video and save each one as an image.

    Frames are written to the current working directory as
    ``frame_<k><output_format>`` with k = 0, 1, 2, ...; existing files with
    the same name are overwritten.

    Args:
    - video_in: string or path-like object pointing to the video file
    - interval: positive integer; one frame is saved out of every `interval`
      frames read, starting with the first one (default: 24)
    - output_format: image file extension, with or without the leading dot
      (default: '.jpg')
    Returns:
    A list of strings containing paths to the saved images, in frame order.
    Raises:
    - ValueError: if `interval` is less than 1.
    - OSError: if the video cannot be opened or a frame cannot be written.
    """
    # 0 would raise ZeroDivisionError on the modulo below and negative values
    # would produce negative frame indices, so reject them up front.
    if interval < 1:
        raise ValueError(f'interval must be a positive integer, got {interval!r}')
    # Accept 'jpg' as well as '.jpg'; cv2.imwrite chooses the encoder from the extension.
    if not output_format.startswith('.'):
        output_format = '.' + output_format

    vidcap = cv2.VideoCapture(video_in)
    frames = []
    try:
        if not vidcap.isOpened():
            raise OSError(f'Could not open video: {video_in}')
        count = 0
        # Read frames until the capture reports no more (end of video or read error).
        while True:
            success, image = vidcap.read()
            if not success:
                break
            if count % interval == 0:
                filename = f'frame_{count // interval}{output_format}'
                # cv2.imwrite signals failure by returning False instead of raising.
                if not cv2.imwrite(filename, image):
                    raise OSError(f'Failed to write frame to {filename}')
                frames.append(filename)
                print(f'Saved {filename}')
            count += 1
    finally:
        # Release the capture even when reading or writing raised.
        vidcap.release()
    print('Done extracting frames!')
    return frames
# Client for the vikhyatk/moondream2 Space, created on first use and shared by
# every process_image() call. Creating a Client connects to the Space, so
# reusing it avoids reconnecting once per extracted frame.
_moondream_client = None


def process_image(image_in, question="Describe precisely the image in one sentence."):
    """Caption one image with the vikhyatk/moondream2 Space.

    Args:
    - image_in: path to the image file sent to the Space's image input
    - question: text sent to the Space's question input
      (default asks for a one-sentence description)
    Returns:
    The Space's answer as returned by Client.predict (also printed).
    """
    global _moondream_client
    if _moondream_client is None:
        _moondream_client = Client("vikhyatk/moondream2")
    result = _moondream_client.predict(
        image_in,  # filepath in 'image' Image component
        question,  # str in 'Question' Textbox component
        api_name="/answer_question",
    )
    print(result)
    return result
def extract_audio(video_path, output_path='output_audio.wav'):
    """Extract the audio track of a video and save it to disk.

    Args:
    - video_path: path to the video file
    - output_path: destination audio file (default: 'output_audio.wav' in the
      current working directory; overwritten if it exists). moviepy picks the
      codec from the file extension, e.g. 16-bit PCM for '.wav'.
    Returns:
    `output_path`, once the file exists on disk.
    Raises:
    - ValueError: if the video has no audio track.
    - FileNotFoundError: if the audio file was not written.
    """
    import os  # imported explicitly rather than relying on the moviepy star import

    # The context manager closes the clip (and its ffmpeg readers) when done.
    with VideoFileClip(video_path) as clip:
        audio = clip.audio
        # moviepy leaves .audio as None for videos without a sound track.
        if audio is None:
            raise ValueError(f'No audio track found in {video_path}.')
        # Let moviepy infer the codec from the extension. Forcing codec='aac'
        # would put AAC data inside a WAV container, which is non-standard and
        # unreadable by many WAV readers.
        audio.write_audiofile(output_path)

    # Confirm that the audio file was written successfully
    if os.path.exists(output_path):
        print(f'Successfully wrote audio to {output_path}.')
        return output_path
    raise FileNotFoundError(f'Failed to write audio to {output_path}.')
def get_salmonn(audio_in):
    """Ask the fffiloni/SALMONN-7B-gradio Space to list the events heard in an audio file.

    Args:
    - audio_in: path to the audio file to upload
    Returns:
    The Space's answer as returned by Client.predict (also printed).
    """
    question = "Please list each event in the audio in order."
    # Positional inputs in the order the /gradio_answer endpoint expects:
    # audio file, user question, beam search numbers (1-10),
    # temperature (0.8-2.0), top p (0.1-1.0).
    endpoint_args = (audio_in, question, 4, 1, 0.9)
    salmonn = Client("fffiloni/SALMONN-7B-gradio")
    answer = salmonn.predict(*endpoint_args, api_name="/gradio_answer")
    print(answer)
    return answer
def llm_process(user_prompt):
    """Generate a reply from the Zephyr text-generation pipeline.

    The prompt uses Zephyr's chat turn tags, with `standard_sys` as the
    system message and `user_prompt` as the user message.

    Args:
    - user_prompt: text placed in the user turn
    Returns:
    The generated reply, with the echoed prompt and leading newlines removed.
    """
    # End with the assistant tag so the model answers as the assistant instead
    # of having to emit the tag itself. This also guarantees the stripping
    # regex below finds an <|assistant|> to match.
    prompt = (
        f"<|system|>\n{standard_sys}</s>\n"
        f"<|user|>\n{user_prompt}</s>\n"
        "<|assistant|>\n"
    )
    outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
    # generated_text starts with the prompt; drop everything from <|system|>
    # up to and including the first <|assistant|>.
    pattern = r'\<\|system\|\>(.*?)\<\|assistant\|\>'
    cleaned_text = re.sub(pattern, '', outputs[0]["generated_text"], flags=re.DOTALL)
    print(f"SUGGESTED video description: {cleaned_text}")
    return cleaned_text.lstrip("\n")
def infer(video_in):
    """Build a text description of a video from frame captions and its audio.

    Samples frames with extract_frames() and captions each one with
    process_image(). Extracts the audio with extract_audio() and describes it
    with get_salmonn().

    Args:
    - video_in: path to the video file from the gr.Video input
      (Gradio passes None when nothing has been uploaded)
    Returns:
    A string with the frame captions (one per line) under "### Visual events:",
    followed by the audio description under "### Audio events:".
    Raises:
    - gr.Error: if no video was provided, so the message is shown in the UI
      instead of failing later during frame extraction.
    """
    if video_in is None:
        raise gr.Error("Please upload a video first.")

    # Caption every sampled frame, preserving frame order.
    frame_files = extract_frames(video_in)
    processed_texts = [process_image(frame_file) for frame_file in frame_files]
    print(processed_texts)
    string_list = '\n'.join(processed_texts)

    # Describe the audio track.
    extracted_audio = extract_audio(video_in)
    print(extracted_audio)
    audio_content_described = get_salmonn(extracted_audio)

    # Assemble captions
    formatted_captions = (
        f"\n### Visual events:\n{string_list}\n ### Audio events:\n{audio_content_described}\n"
    )
    print(formatted_captions)
    # TODO: pass formatted_captions to llm_process() for a summarized
    # description; that step is currently disabled.
    return formatted_captions
# Minimal UI: upload a video, press Submit, read the generated description.
with gr.Blocks() as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML('\n<h2 style="text-align: center;">Video description</h2>\n')
        video_in = gr.Video(label="Video input")
        submit_btn = gr.Button("Submit")
        video_description = gr.Textbox(label="Video description")
    submit_btn.click(
        infer,
        inputs=[video_in],
        outputs=[video_description],
    )

# queue() returns the same Blocks instance, so launching demo afterwards is equivalent.
demo.queue()
demo.launch()