# -*- coding: utf-8 -*-
"""
End-to-End Referring Video Object Segmentation with Multimodal Transformers

This is a (limited) hands-on demonstration of MTTR.
Given a text query and a short clip based on a YouTube video, we demonstrate how MTTR can be used to segment the referred object instance throughout the video.

### Disclaimer

This is a **limited** demonstration of MTTR's performance. The model used here was trained **exclusively** on Refer-YouTube-VOS with window size `w=12` (as described in our paper). No additional training data was used whatsoever.
Hence, the model's performance may be limited, especially on instances from unseen categories.

Additionally, slow processing times may be encountered, depending on the length and/or resolution of the input clip, and due to Hugging Face's limited computational resources (no GPU acceleration, unfortunately).

Finally, we emphasize that this demonstration is intended for academic purposes only. We do not take any responsibility for how the created content is used or distributed.
"""
import gradio as gr
import torch
import torchvision
import torchvision.transforms.functional as F
from einops import rearrange
import numpy as np
from PIL import Image, ImageDraw, ImageOps, ImageFont
from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip
from tqdm import tqdm
class NestedTensor(object):
    def __init__(self, tensors, mask):
        self.tensors = tensors
        self.mask = mask


def nested_tensor_from_videos_list(videos_list):
    def _max_by_axis(the_list):
        maxes = the_list[0]
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

    max_size = _max_by_axis([list(img.shape) for img in videos_list])
    padded_batch_shape = [len(videos_list)] + max_size
    b, t, c, h, w = padded_batch_shape
    dtype = videos_list[0].dtype
    device = videos_list[0].device
    padded_videos = torch.zeros(padded_batch_shape, dtype=dtype, device=device)
    videos_pad_masks = torch.ones((b, t, h, w), dtype=torch.bool, device=device)
    for vid_frames, pad_vid_frames, vid_pad_m in zip(videos_list, padded_videos, videos_pad_masks):
        pad_vid_frames[:vid_frames.shape[0], :, :vid_frames.shape[2], :vid_frames.shape[3]].copy_(vid_frames)
        vid_pad_m[:vid_frames.shape[0], :vid_frames.shape[2], :vid_frames.shape[3]] = False
    return NestedTensor(padded_videos.transpose(0, 1), videos_pad_masks.transpose(0, 1))
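
# Shape reference: each clip in `videos_list` has shape (T, C, H, W). The clips are
# zero-padded to a shared grid, and the returned NestedTensor holds `.tensors` of shape
# (T, B, C, H, W) and a boolean `.mask` of shape (T, B, H, W), where True marks padded
# (invalid) pixels.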

def apply_mask(image, mask, color, transparency=0.7):
    mask = mask[..., np.newaxis].repeat(repeats=3, axis=2)
    mask = mask * transparency
    color_matrix = np.ones(image.shape, dtype=np.float64) * color
    out_image = color_matrix * mask + image * (1.0 - mask)
    return out_image
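
# In other words, apply_mask alpha-blends the instance color into the frame: where the
# mask is 1, the output pixel is `transparency * color + (1 - transparency) * image`;
# where the mask is 0, the original pixel is kept. Both the frame and the color are
# expected as floats in [0, 1] (see the call site below, where both are divided by 255).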

def process(text_query, full_video_path):
    start_pt, end_pt = 0, 10
    input_clip_path = '/tmp/input.mp4'

    # extract the relevant subclip (the first 10 seconds of the uploaded video):
    with VideoFileClip(full_video_path) as video:
        subclip = video.subclip(start_pt, end_pt)
        subclip.write_videofile(input_clip_path)

    # load the model architecture via torch.hub and the Refer-YouTube-VOS checkpoint from disk:
    checkpoint_path = './refer-youtube-vos_window-12.pth.tar'
    model, postprocessor = torch.hub.load('Randl/MTTR:main', 'mttr_refer_youtube_vos', get_weights=False)
    model_state_dict = torch.load(checkpoint_path, map_location='cpu')
    if 'model_state_dict' in model_state_dict.keys():
        model_state_dict = model_state_dict['model_state_dict']
    model.load_state_dict(model_state_dict, strict=True)

    text_queries = [text_query]
    window_length = 24  # length of window during inference
    window_overlap = 6  # overlap (in frames) between consecutive windows

    with torch.inference_mode():
        # read and preprocess the video clip:
        video, audio, meta = torchvision.io.read_video(filename=input_clip_path)
        video = rearrange(video, 't h w c -> t c h w')
        input_video = F.resize(video, size=360, max_size=640)
        input_video = input_video.to(torch.float).div_(255)
        input_video = F.normalize(input_video, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        video_metadata = {'resized_frame_size': input_video.shape[-2:], 'original_frame_size': video.shape[-2:]}

        # partition the clip into overlapping windows of frames:
        windows = [input_video[i:i + window_length] for i in range(0, len(input_video), window_length - window_overlap)]
        # clean up the text queries:
        text_queries = [" ".join(q.lower().split()) for q in text_queries]

        pred_masks_per_query = []
        t, _, h, w = video.shape
        for text_query in tqdm(text_queries, desc='text queries'):
            pred_masks = torch.zeros(size=(t, 1, h, w))
            for i, window in enumerate(tqdm(windows, desc='windows')):
                window = nested_tensor_from_videos_list([window])
                valid_indices = torch.arange(len(window.tensors))
                outputs = model(window, valid_indices, [text_query])
                window_masks = postprocessor(outputs, [video_metadata], window.tensors.shape[-2:])[0]['pred_masks']
                win_start_idx = i * (window_length - window_overlap)
                pred_masks[win_start_idx:win_start_idx + window_length] = window_masks
            pred_masks_per_query.append(pred_masks)
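
    # Note: consecutive windows overlap by `window_overlap` frames; for the overlapping
    # frames, the masks predicted by the later window simply overwrite those of the
    # earlier window in `pred_masks`.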
"""Finally, we apply the generated instance masks and their corresponding text queries on the input clip for visualization:""" | |
# RGB colors for instance masks: | |
light_blue = (41, 171, 226) | |
purple = (237, 30, 121) | |
dark_green = (35, 161, 90) | |
orange = (255, 148, 59) | |
colors = np.array([light_blue, purple, dark_green, orange]) | |
# width (in pixels) of the black strip above the video on which the text queries will be displayed: | |
text_border_height_per_query = 40 | |
video_np = rearrange(video, 't c h w -> t h w c').numpy() / 255.0 | |
# del video | |
pred_masks_per_frame = rearrange(torch.stack(pred_masks_per_query), 'q t 1 h w -> t q h w').numpy() | |
masked_video = [] | |
for vid_frame, frame_masks in tqdm(zip(video_np, pred_masks_per_frame), total=len(video_np), desc='applying masks...'): | |
# apply the masks: | |
for inst_mask, color in zip(frame_masks, colors): | |
vid_frame = apply_mask(vid_frame, inst_mask, color / 255.0) | |
vid_frame = Image.fromarray((vid_frame * 255).astype(np.uint8)) | |
# visualize the text queries: | |
vid_frame = ImageOps.expand(vid_frame, border=(0, len(text_queries)*text_border_height_per_query, 0, 0)) | |
W, H = vid_frame.size | |
draw = ImageDraw.Draw(vid_frame) | |
font = ImageFont.truetype(font='LiberationSans-Regular.ttf', size=30) | |
for i, (text_query, color) in enumerate(zip(text_queries, colors), start=1): | |
w, h = draw.textsize(text_query, font=font) | |
draw.text(((W - w) / 2, (text_border_height_per_query * i) - h - 8), | |
text_query, fill=tuple(color) + (255,), font=font) | |
masked_video.append(np.array(vid_frame)) | |
# generate and save the output clip: | |
output_clip_path = '/tmp/output_clip.mp4' | |
clip = ImageSequenceClip(sequence=masked_video, fps=meta['video_fps']) | |
clip = clip.set_audio(AudioFileClip(input_clip_path)) | |
clip.write_videofile(output_clip_path, fps=meta['video_fps'], audio=True) | |
del masked_video | |
return output_clip_path | |

title = "End-to-End Referring Video Object Segmentation with Multimodal Transformers - Interactive Demo"

description = "This is a (limited) hands-on demonstration of MTTR. Given a text query and a short clip based on a YouTube video, we demonstrate how MTTR can be used to segment the referred object instance throughout the video. To use it, upload an .mp4 video file and enter a text query which describes one of the object instances in that video."

article = "Check out [MTTR's GitHub page](https://github.com/mttr2021/MTTR) for more info about this project. <br> Also, check out our [Colab notebook](https://gradio.app/docs/) for much faster (GPU-accelerated) processing and more options! <br> **Disclaimer:** <br> This is a **limited** demonstration of MTTR's performance. The model used here was trained **exclusively** on Refer-YouTube-VOS with window size `w=12` (as described in our paper). No additional training data was used whatsoever. Hence, the model's performance may be limited, especially on instances from unseen categories. <br> Additionally, slow processing times may be encountered, depending on the length and/or resolution of the input clip, and due to Hugging Face's limited computational resources (no GPU acceleration, unfortunately). <br> Finally, we emphasize that this demonstration is intended for academic purposes only. We do not take any responsibility for how the created content is used or distributed. <br> <p style='text-align: center'><a href='https://github.com/mttr2021/MTTR'>Github Repo</a></p>"
examples = [['guy in white shirt performing tricks on a bike', 'bike_tricks_2.mp4'],
            ['a man riding a surfboard', 'surfing.mp4'],
            ['a guy performing tricks on a skateboard', 'skateboarding.mp4'],
            ['man in red shirt playing tennis', 'tennis.mp4'],
            ['brown and black dog playing', 'dogs_playing_1.mp4'],
            ['a dog to the left playing with a toy', 'dogs_playing_2.mp4'],
            ['person in blue riding a bike', 'blue_biker_riding.mp4'],
            ['a dog to the right', 'dog_and_cat.mp4'],
            ['a person hugging a dog', 'girl_hugging_dog.mp4'],
            ['a black bike used to perform tricks', 'bike_tricks_1.mp4']]
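
# For a quick smoke test without launching the UI, `process` can also be called directly,
# e.g. `output_path = process(*examples[0])`. This assumes the example clips and the
# checkpoint file are present in the working directory, and will be slow on CPU-only hardware.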
iface = gr.Interface(fn=process,
                     inputs=[gr.inputs.Textbox(label="text query"),
                             gr.inputs.Video(label="input video - first 10 seconds are used")],
                     outputs='video',
                     title=title,
                     description=description,
                     enable_queue=True,
                     examples=examples,
                     article=article)

iface.launch(debug=True)