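"""Florence-2 for Videos: caption the first frame with Florence-2, ground the
caption's phrases in every sampled frame, track the resulting boxes with
ByteTrack, and write the annotated result to an MP4."""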
import os
from unittest.mock import patch

import gradio as gr
import numpy as np
import spaces
import supervision as sv
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoProcessor

from utils.imports import fixed_get_imports
from utils.models import (
    CAPTIONING_TASK,
    run_captioning,
    run_caption_to_phrase_grounding,
)
from utils.video import (
    calculate_end_frame_index,
    create_directory,
    generate_file_name,
    remove_files_older_than,
)
MARKDOWN = """ | |
# Florence-2 for Videos 🎬 | |
<div> | |
<a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-finetune-florence-2-on-detection-dataset.ipynb"> | |
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab" style="display:inline-block;"> | |
</a> | |
<a href="https://blog.roboflow.com/florence-2/"> | |
<img src="https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg" alt="Roboflow" style="display:inline-block;"> | |
</a> | |
<a href="https://arxiv.org/abs/2311.06242"> | |
<img src="https://img.shields.io/badge/arXiv-2311.06242-b31b1b.svg" alt="arXiv" style="display:inline-block;"> | |
</a> | |
</div> | |
""" | |

RESULTS = "results"
CHECKPOINT = "microsoft/Florence-2-base-ft"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
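
# `fixed_get_imports` (defined in utils.imports) is assumed to implement the
# common Florence-2 workaround: it filters the optional `flash_attn` entry out
# of the checkpoint's dynamically resolved imports so the remote code loads on
# machines without flash-attention installed.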
with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
    MODEL = AutoModelForCausalLM.from_pretrained(
        CHECKPOINT, trust_remote_code=True).to(DEVICE)
    PROCESSOR = AutoProcessor.from_pretrained(
        CHECKPOINT, trust_remote_code=True)

BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator(color_lookup=sv.ColorLookup.TRACK)
LABEL_ANNOTATOR = sv.LabelAnnotator(color_lookup=sv.ColorLookup.TRACK)
TRACKER = sv.ByteTrack()
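# Both annotators use ColorLookup.TRACK, so colors are keyed to ByteTrack's
# tracker IDs and each tracked object keeps a stable color across frames.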

# Create the directory that processed videos are written to.
create_directory(directory_path=RESULTS)


def annotate_image(
    input_image: np.ndarray,
    detections: sv.Detections
) -> np.ndarray:
    """Draw tracked bounding boxes and labels on a copy of the frame."""
    output_image = input_image.copy()
    output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections)
    return output_image


@spaces.GPU  # request a ZeroGPU slot for this call; the Space runs on Zero
def process_video(
    input_video: str,
    progress=gr.Progress(track_tqdm=True)
) -> str:
    # Clean up old result videos.
    remove_files_older_than(RESULTS, 30)

    OUTPUT_LENGTH = 4
    video_info = sv.VideoInfo.from_video_path(input_video)
    video_info.fps = video_info.fps // OUTPUT_LENGTH
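    # Every OUTPUT_LENGTH-th frame is kept and the fps is divided by the same
    # factor, so the annotated video plays back at roughly the original speed.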
    total = calculate_end_frame_index(input_video, OUTPUT_LENGTH)
    frame_generator = sv.get_video_frames_generator(
        source_path=input_video,
        end=total,
        stride=OUTPUT_LENGTH
    )
    result_file_name = generate_file_name(extension="mp4")
    result_file_path = os.path.join(RESULTS, result_file_name)

    # Reset tracker state so track IDs do not leak between videos.
    TRACKER.reset()
    caption = None
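    # The caption is generated once, from the first sampled frame, and then
    # reused as the phrase-grounding prompt for every subsequent frame.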
    with sv.VideoSink(result_file_path, video_info=video_info) as sink:
        for _ in tqdm(range(total // OUTPUT_LENGTH), desc="Processing video..."):
            frame = next(frame_generator)
            if caption is None:
                caption = run_captioning(
                    model=MODEL,
                    processor=PROCESSOR,
                    image=frame,
                    device=DEVICE
                )[CAPTIONING_TASK]
            detections = run_caption_to_phrase_grounding(
                model=MODEL,
                processor=PROCESSOR,
                caption=caption,
                image=frame,
                device=DEVICE
            )
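            # Florence-2 grounding returns neither confidence scores nor class
            # ids; fill in dummy values so ByteTrack can consume the detections.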
            detections.confidence = np.ones(len(detections))
            detections.class_id = np.zeros(len(detections), dtype=int)
            detections = TRACKER.update_with_detections(detections)
            frame = annotate_image(
                input_image=frame,
                detections=detections
            )
            sink.write_frame(frame)
    return result_file_path


with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        input_video_component = gr.Video(
            label='Input Video'
        )
        output_video_component = gr.Video(
            label='Output Video'
        )
    with gr.Row():
        submit_button_component = gr.Button(
            value='Submit',
            scale=1,
            variant='primary'
        )

    submit_button_component.click(
        fn=process_video,
        inputs=[
            input_video_component,
        ],
        outputs=output_video_component
    )

# max_threads=1 serializes request handling, which matters here because MODEL
# and TRACKER are shared module-level state.
demo.launch(debug=False, show_error=True, max_threads=1)