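# MAMA Video-Text Generation Pipeline (Gradio demo).
# Pipeline: detect keyframes -> keep 12 evenly spaced frames ->
# tile them into a 4x3 grid image -> caption the grid with TinyLLaVA.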
import os
import shutil

import cv2
import gradio as gr
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

from video_keyframe_detector.cli import keyframeDetection
# Load TinyLLaVA (Phi-2 + SigLIP) via its remote code; this Space runs it on CPU.
hf_path = 'tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B'
model = AutoModelForCausalLM.from_pretrained(hf_path, trust_remote_code=True)
config = model.config
tokenizer = AutoTokenizer.from_pretrained(
    hf_path,
    use_fast=False,
    model_max_length=config.tokenizer_model_max_length,
    padding_side=config.tokenizer_padding_side,
)
def extract_keyframes(video_path, num_keyframes=12):
    """Detect keyframes and keep `num_keyframes` evenly spaced ones."""
    video_id = os.path.basename(video_path).strip().split('.')[0]
    os.makedirs("temp", exist_ok=True)
    # Writes detected keyframes to temp/keyFrames/ (difference threshold 0.2).
    keyframeDetection(video_path, "temp", 0.2)
    # Output files are assumed to be named "keyframeN.jpg"; the slice [8:]
    # strips the "keyframe" prefix so we can sort numerically by N.
    video_frame_list = sorted(
        os.listdir(os.path.join("temp", "keyFrames")),
        key=lambda x: int(x.split('.')[0][8:]),
    )
    os.makedirs(os.path.join("video_frames", video_id), exist_ok=True)
    # Pick num_keyframes indices evenly spaced over the detected frames,
    # skipping frame 0 (assumes the detector found at least num_keyframes frames).
    selected_frame_idx_set = set(
        np.linspace(1, len(video_frame_list) - 1, num_keyframes).astype(int)
    )
    cnt = 0
    for i in range(len(video_frame_list)):
        if i in selected_frame_idx_set:
            source_file = os.path.join("temp", "keyFrames", video_frame_list[i])
            target_file = os.path.join("video_frames", video_id, f"frame_{cnt}.jpg")
            shutil.copyfile(source_file, target_file)
            cnt += 1
    shutil.rmtree("temp", ignore_errors=True)
def concatenate_frames(video_path):
    """Tile the 12 extracted frames into a single 4x3 grid image."""
    os.makedirs("concatenated_frames", exist_ok=True)
    video_id = os.path.basename(video_path).strip().split('.')[0]
    image_frame_dir = os.path.join("video_frames", video_id)
    # Frames were saved as "frame_N.jpg"; sort numerically by N.
    image_frame_list = sorted(
        os.listdir(image_frame_dir),
        key=lambda x: int(x.split('.')[0].split('_')[1]),
    )
    img_list = [cv2.imread(os.path.join(image_frame_dir, f)) for f in image_frame_list]
    # Three rows of four frames each; all frames come from the same video,
    # so they share the dimensions cv2.hconcat/vconcat require.
    img_row1 = cv2.hconcat(img_list[:4])
    img_row2 = cv2.hconcat(img_list[4:8])
    img_row3 = cv2.hconcat(img_list[8:12])
    img_grid = cv2.vconcat([img_row1, img_row2, img_row3])
    cv2.imwrite(os.path.join("concatenated_frames", f"{video_id}.jpg"), img_grid)
def image_parser(args):
    # Split a separator-delimited list of image paths (unused in this demo).
    return args.image_file.split(args.sep)
def generate_video_caption(video_path):
    """Caption the tiled frame grid with TinyLLaVA's chat interface."""
    video_id = os.path.basename(video_path).strip().split('.')[0]
    image_file = os.path.join("concatenated_frames", f"{video_id}.jpg")
    prompt = "In a short sentence, describe the process in the video."
    output_text, generation_time = model.chat(prompt=prompt, image=image_file, tokenizer=tokenizer)
    return output_text
def clean_files_and_folders():
    # Remove intermediate artifacts so repeated runs start clean.
    shutil.rmtree("concatenated_frames", ignore_errors=True)
    shutil.rmtree("video_frames", ignore_errors=True)
def video_to_text(video_file):
    # gr.File passes a file wrapper whose .name holds the uploaded temp path.
    video_path = video_file.name
    extract_keyframes(video_path)
    concatenate_frames(video_path)
    video_caption = generate_video_caption(video_path)
    clean_files_and_folders()
    return video_caption
iface = gr.Interface(
    fn=video_to_text,
    inputs=gr.File(file_types=["video"]),
    outputs="text",
    title="MAMA Video-Text Generation Pipeline",
    description=(
        "Upload a video and get a short description. Due to a limited budget, "
        "this demo runs TinyLLaVA on CPU only, so please try videos smaller "
        "than 1 MB. Thank you, and welcome to MAMA!"
    ),
)
iface.launch(share=True)
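
# To run locally (a sketch, assuming the video_keyframe_detector package is
# available alongside this file; exact dependency versions are not pinned here):
#   pip install torch transformers gradio opencv-python numpy
#   python app.py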