Last commit not found
import cv2 | |
import torch as th | |
import os | |
import numpy as np | |
from decord import VideoReader, cpu | |
import ffmpeg | |
class Normalize(object):
    """Standardize a tensor channel-wise with a fixed mean and std.

    Both statistics are stored as float tensors reshaped to
    ``(1, 3, 1, 1)`` so they broadcast over inputs whose channel axis
    (of length 3) sits at position -3, e.g. ``(N, 3, H, W)``.
    """

    def __init__(self, mean, std):
        """Store the three per-channel ``mean`` and ``std`` values."""
        stat_shape = (1, 3, 1, 1)
        self.mean = th.FloatTensor(mean).view(*stat_shape)
        self.std = th.FloatTensor(std).view(*stat_shape)

    def __call__(self, tensor):
        """Return ``(tensor - mean) / (std + 1e-8)``."""
        centered = tensor - self.mean
        # The epsilon keeps the division finite when a std entry is zero.
        return centered / (self.std + 1e-8)
class Preprocessing(object):
    """Divide pixel values by 255 and standardize each of the 3 channels."""

    def __init__(self):
        # NOTE(review): these constants look like the statistics published
        # with OpenAI CLIP -- confirm against the model actually used.
        channel_mean = [0.48145466, 0.4578275, 0.40821073]
        channel_std = [0.26862954, 0.26130258, 0.27577711]
        self.norm = Normalize(mean=channel_mean, std=channel_std)

    def __call__(self, tensor):
        """Return ``tensor / 255`` passed through the stored normalizer."""
        scaled = tensor / 255.0
        return self.norm(scaled)
class VideoLoader:
    """Pytorch video loader.

    Decodes a video file with ffmpeg at ``framerate`` frames per second,
    rescales (and optionally center-crops) every frame, and returns exactly
    ``max_feats`` normalized frames together with the number of real
    (non-padding) frames.
    """

    def __init__(
        self,
        framerate=1,
        size=224,
        centercrop=True,
    ):
        """Configure decoding.

        Args:
            framerate: Frames per second requested from ffmpeg's ``fps``
                filter.
            size: Either an int (target length of the shorter side, and the
                side of the square crop when ``centercrop`` is set) or a
                ``(height, width)`` tuple used as the exact output size.
            centercrop: Crop a centered ``size`` x ``size`` square after
                scaling. Only applied when ``size`` is an int.
        """
        self.centercrop = centercrop
        self.size = size
        self.framerate = framerate
        self.preprocess = Preprocessing()
        self.max_feats = 10
        # Kept for callers that read it; frame padding no longer uses it.
        self.features_dim = 768

    def _get_output_dim(self, h, w):
        """Return the ``(height, width)`` frames are scaled to.

        A 2-tuple ``size`` is returned unchanged. Otherwise the shorter
        side becomes ``self.size`` and the longer side is scaled by the
        same factor (truncated to an int).
        """
        if isinstance(self.size, tuple) and len(self.size) == 2:
            return self.size
        elif h >= w:
            return int(h * self.size / w), self.size
        else:
            return self.size, int(w * self.size / h)

    def _get_video_dim(self, video_path):
        """Probe ``video_path`` and return ``(height, width, frame_rate)``.

        Raises:
            ValueError: If the probe result contains no video stream.
            ZeroDivisionError: If ``avg_frame_rate`` has a zero denominator.
            Exception: Whatever ``ffmpeg.probe`` raises on failure.
        """
        probe = ffmpeg.probe(video_path)
        video_stream = next(
            (stream for stream in probe["streams"] if stream["codec_type"] == "video"),
            None,
        )
        if video_stream is None:
            # Fail with a clear message instead of a TypeError on None.
            raise ValueError("no video stream found in: {}".format(video_path))
        width = int(video_stream["width"])
        height = int(video_stream["height"])
        # avg_frame_rate is a "numerator/denominator" string.
        num, denum = video_stream["avg_frame_rate"].split("/")
        frame_rate = int(num) / int(denum)
        return height, width, frame_rate

    @staticmethod
    def _failure(video_path):
        """Return the sentinel result used when a video cannot be decoded."""
        return {"video": th.zeros(1), "input": video_path}

    def _getvideo(self, video_path):
        """Decode ``video_path`` into a float32 frame tensor.

        Returns:
            A dict with ``"input"`` (the path) and ``"video"``: a
            ``(t, 3, height, width)`` float32 tensor holding 0-255 pixel
            values on success, or the 1-D sentinel ``th.zeros(1)`` when the
            file is missing, cannot be probed, reports a frame rate below
            1, or ffmpeg fails.
        """
        if not os.path.isfile(video_path):
            return self._failure(video_path)

        print("Decoding video: {}".format(video_path))
        try:
            h, w, fr = self._get_video_dim(video_path)
        except Exception:
            print("ffprobe failed at: {}".format(video_path))
            return self._failure(video_path)
        if fr < 1:
            print("Corrupted Frame Rate: {}".format(video_path))
            return self._failure(video_path)

        height, width = self._get_output_dim(h, w)
        # A square crop only makes sense for an int size; with a tuple size
        # the scale filter already produces the exact requested dimensions.
        do_crop = self.centercrop and isinstance(self.size, int)
        try:
            cmd = (
                ffmpeg.input(video_path)
                .filter("fps", fps=self.framerate)
                .filter("scale", width, height)
            )
            if do_crop:
                x = int((width - self.size) / 2.0)
                y = int((height - self.size) / 2.0)
                cmd = cmd.crop(x, y, self.size, self.size)
            out, _ = cmd.output("pipe:", format="rawvideo", pix_fmt="rgb24").run(
                capture_stdout=True, quiet=True
            )
        except Exception:
            print("ffmpeg error at: {}".format(video_path))
            return self._failure(video_path)

        if do_crop:
            height, width = self.size, self.size
        # Raw rgb24 output is a flat byte stream of t*height*width*3 values.
        video = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])
        video = th.from_numpy(video.astype("float32"))
        video = video.permute(0, 3, 1, 2)  # t,c,h,w
        return {"video": video, "input": video_path}

    def _empty_video(self):
        """Return a zero-frame ``(0, 3, H, W)`` tensor for failed decodes.

        ``H, W`` come from ``self.size`` (a tuple is used as-is, an int
        gives a square). NOTE(review): with an int ``size`` and
        ``centercrop=False`` real frames are generally not square, so the
        placeholder shape can differ from that of decoded videos.
        """
        if isinstance(self.size, tuple) and len(self.size) == 2:
            height, width = self.size
        else:
            height = width = self.size
        return th.zeros(0, 3, height, width)

    def _fit_length(self, video):
        """Subsample or zero-pad ``video`` to exactly ``max_feats`` frames.

        Returns:
            ``(video, video_len)`` where ``video_len`` is the number of
            frames that come from the input (at most ``max_feats``).
        """
        num_frames = len(video)
        if num_frames > self.max_feats:
            # Evenly spaced subsampling over the whole clip.
            video = th.stack(
                [
                    video[(j * num_frames) // self.max_feats]
                    for j in range(self.max_feats)
                ]
            )
            return video, self.max_feats
        if num_frames < self.max_feats:
            # Pad with zero frames shaped like the real ones.
            padding = video.new_zeros(
                (self.max_feats - num_frames, *video.shape[1:])
            )
            video = th.cat([video, padding], 0)
        return video, num_frames

    def __call__(self, video_path):
        """Load ``video_path`` and return ``(video, video_len)``.

        Returns:
            video: The output of ``self.preprocess`` applied to a
                ``(max_feats, 3, H, W)`` frame tensor. Padding is added
                before preprocessing, so padded frames are preprocessed
                too.
            video_len: Number of real frames, ``0`` when decoding failed.
        """
        video = self._getvideo(video_path)["video"]
        if video.dim() != 4:
            # Failure sentinel from _getvideo: treat as a clip with no frames.
            video = self._empty_video()
        video, video_len = self._fit_length(video)
        video = self.preprocess(video)
        return video, video_len