Spaces:
Running
Running
# Copyright (c) Facebook, Inc. and its affiliates. | |
import contextlib | |
import os | |
import random | |
import tempfile | |
import unittest | |
import torch | |
import torchvision.io as io | |
from densepose.data.transform import ImageResizeTransform | |
from densepose.data.video import RandomKFramesSelector, VideoKeyframeDataset | |
try: | |
import av | |
except ImportError: | |
av = None | |
# copied from torchvision test/test_io.py | |
def _create_video_frames(num_frames, height, width): | |
y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width)) | |
data = [] | |
for i in range(num_frames): | |
xc = float(i) / num_frames | |
yc = 1 - float(i) / (2 * num_frames) | |
d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255 | |
data.append(d.unsqueeze(2).repeat(1, 1, 3).byte()) | |
return torch.stack(data, 0) | |
# adapted from torchvision test/test_io.py | |
@contextlib.contextmanager
def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, options=None):
    """Encode a synthetic clip to a temporary ``.mp4`` file for the duration of a ``with`` block.

    Usage::

        with temp_video(60, 300, 300, 5, video_codec="mpeg4") as (path, frames):
            ...

    Args:
        num_frames: number of frames in the clip.
        height: frame height in pixels.
        width: frame width in pixels.
        fps: frame rate passed to ``io.write_video``.
        lossless: if True, encode with ``libx264rgb`` and ``crf=0``. Cannot be combined
            with an explicit ``video_codec`` or ``options``.
        video_codec: codec name for ``io.write_video``; defaults to ``"libx264"``.
        options: encoder options for ``io.write_video``; defaults to ``{}``.

    Yields:
        ``(path, frames)``: the temporary file path and the frame tensor that was encoded
        (as produced by ``_create_video_frames``).

    Raises:
        ValueError: if ``lossless`` is combined with ``video_codec`` or ``options``.
    """
    if lossless:
        if video_codec is not None:
            raise ValueError("video_codec can't be specified together with lossless")
        if options is not None:
            raise ValueError("options can't be specified together with lossless")
        video_codec = "libx264rgb"
        options = {"crf": "0"}
    if video_codec is None:
        video_codec = "libx264"
    if options is None:
        options = {}

    data = _create_video_frames(num_frames, height, width)
    with tempfile.NamedTemporaryFile(suffix=".mp4") as f:
        # With the default delete=True, closing the NamedTemporaryFile removes it.
        # That leaves a unique, unused path that write_video can create from scratch.
        f.close()
        try:
            io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options)
            yield f.name, data
        finally:
            # Remove the video even if encoding or the caller's `with` body raised.
            with contextlib.suppress(FileNotFoundError):
                os.unlink(f.name)
@unittest.skipIf(av is None, "PyAV unavailable")
class TestVideoKeyframeDataset(unittest.TestCase):
    """Tests of ``VideoKeyframeDataset`` on a synthetic 60-frame, 300x300 MPEG-4 clip.

    The clip is written through ``torchvision.io``, which relies on PyAV. The whole
    class is therefore skipped when ``av`` could not be imported.
    """

    def _assert_single_video_sample(self, dataset, expected_shape):
        """Check that ``dataset`` holds one video whose sample has the expected images.

        Args:
            dataset: a ``VideoKeyframeDataset`` built from one video with category ``None``.
            expected_shape: expected shape of the sample's ``"images"`` tensor.
        """
        self.assertEqual(len(dataset), 1)
        # Index once so "images" and "categories" come from the same sample.
        # Frame selection may be random, and each access may decode the video again.
        sample = dataset[0]
        images, categories = sample["images"], sample["categories"]
        self.assertEqual(images.shape, torch.Size(expected_shape))
        self.assertEqual(images.dtype, torch.float32)
        self.assertIsNone(categories[0])

    def test_read_keyframes_all(self):
        with temp_video(60, 300, 300, 5, video_codec="mpeg4") as (fname, _):
            # No frame selector: the test expects 5 keyframes for this clip.
            dataset = VideoKeyframeDataset([fname], [None])
            self._assert_single_video_sample(dataset, (5, 3, 300, 300))
            return
        # Reached only if temp_video swallowed an exception raised in the body.
        self.fail("temp_video exited without running the test body")

    def test_read_keyframes_with_selector(self):
        with temp_video(60, 300, 300, 5, video_codec="mpeg4") as (fname, _):
            # Seed for reproducibility; presumably RandomKFramesSelector draws from `random`.
            random.seed(0)
            frame_selector = RandomKFramesSelector(3)
            dataset = VideoKeyframeDataset([fname], [None], frame_selector)
            self._assert_single_video_sample(dataset, (3, 3, 300, 300))
            return
        self.fail("temp_video exited without running the test body")

    def test_read_keyframes_with_selector_with_transform(self):
        with temp_video(60, 300, 300, 5, video_codec="mpeg4") as (fname, _):
            random.seed(0)
            frame_selector = RandomKFramesSelector(1)
            transform = ImageResizeTransform()
            dataset = VideoKeyframeDataset([fname], [None], frame_selector, transform)
            # With the default ImageResizeTransform, frames are expected at 800x800.
            self._assert_single_video_sample(dataset, (1, 3, 800, 800))
            return
        self.fail("temp_video exited without running the test body")