|
import torch |
|
import random |
|
import numbers

import numpy as np
|
from torchvision.transforms import RandomCrop, RandomResizedCrop |
|
from PIL import Image |
|
from torchvision.utils import _log_api_usage_once |
|
|
|
def _is_tensor_video_clip(clip): |
|
if not torch.is_tensor(clip): |
|
raise TypeError("clip should be Tensor. Got %s" % type(clip)) |
|
|
|
if not clip.ndimension() == 4: |
|
raise ValueError("clip should be 4D. Got %dD" % clip.dim()) |
|
|
|
return True |
|
|
|
|
|
def center_crop_arr(pil_image, image_size): |
|
""" |
|
Center cropping implementation from ADM. |
|
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 |
|
""" |
|
while min(*pil_image.size) >= 2 * image_size: |
|
pil_image = pil_image.resize( |
|
tuple(x // 2 for x in pil_image.size), resample=Image.BOX |
|
) |
|
|
|
scale = image_size / min(*pil_image.size) |
|
pil_image = pil_image.resize( |
|
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC |
|
) |
|
|
|
arr = np.array(pil_image) |
|
crop_y = (arr.shape[0] - image_size) // 2 |
|
crop_x = (arr.shape[1] - image_size) // 2 |
|
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) |
|
|
|
|
|
def crop(clip, i, j, h, w): |
|
""" |
|
Args: |
|
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) |
|
""" |
|
if len(clip.size()) != 4: |
|
raise ValueError("clip should be a 4D tensor") |
|
return clip[..., i : i + h, j : j + w] |
|
|
|
|
|
def resize(clip, target_size, interpolation_mode): |
|
if len(target_size) != 2: |
|
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") |
|
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) |
|
|
|
def resize_scale(clip, target_size, interpolation_mode): |
|
if len(target_size) != 2: |
|
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") |
|
H, W = clip.size(-2), clip.size(-1) |
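    # scale factor chosen so the shorter spatial side becomes target_size[0]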
|
scale_ = target_size[0] / min(H, W) |
|
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) |
|
|
|
def resize_with_scale_factor(clip, scale_factor, interpolation_mode): |
|
return torch.nn.functional.interpolate(clip, scale_factor=scale_factor, mode=interpolation_mode, align_corners=False) |
|
|
|
def resize_scale_with_height(clip, target_size, interpolation_mode): |
|
H, W = clip.size(-2), clip.size(-1) |
|
scale_ = target_size / H |
|
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) |
|
|
|
def resize_scale_with_weight(clip, target_size, interpolation_mode): |
|
H, W = clip.size(-2), clip.size(-1) |
|
scale_ = target_size / W |
|
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) |
|
|
|
|
|
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): |
|
""" |
|
Do spatial cropping and resizing to the video clip |
|
Args: |
|
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) |
|
i (int): i in (i,j) i.e coordinates of the upper left corner. |
|
j (int): j in (i,j) i.e coordinates of the upper left corner. |
|
h (int): Height of the cropped region. |
|
w (int): Width of the cropped region. |
|
size (tuple(int, int)): height and width of resized clip |
|
Returns: |
|
clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W) |
|
""" |
|
if not _is_tensor_video_clip(clip): |
|
raise ValueError("clip should be a 4D torch.tensor") |
|
clip = crop(clip, i, j, h, w) |
|
clip = resize(clip, size, interpolation_mode) |
|
return clip |
|
|
|
|
|
def center_crop(clip, crop_size): |
|
if not _is_tensor_video_clip(clip): |
|
raise ValueError("clip should be a 4D torch.tensor") |
|
h, w = clip.size(-2), clip.size(-1) |
|
|
|
th, tw = crop_size |
|
if h < th or w < tw: |
|
|
|
raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w)) |
|
|
|
i = int(round((h - th) / 2.0)) |
|
j = int(round((w - tw) / 2.0)) |
|
return crop(clip, i, j, th, tw), i, j |
|
|
|
|
|
def center_crop_using_short_edge(clip): |
|
if not _is_tensor_video_clip(clip): |
|
raise ValueError("clip should be a 4D torch.tensor") |
|
h, w = clip.size(-2), clip.size(-1) |
|
if h < w: |
|
th, tw = h, h |
|
i = 0 |
|
j = int(round((w - tw) / 2.0)) |
|
else: |
|
th, tw = w, w |
|
i = int(round((h - th) / 2.0)) |
|
j = 0 |
|
return crop(clip, i, j, th, tw) |
|
|
|
|
|
def random_shift_crop(clip): |
|
''' |
|
Slide along the long edge, with the short edge as crop size |
|
''' |
|
if not _is_tensor_video_clip(clip): |
|
raise ValueError("clip should be a 4D torch.tensor") |
|
h, w = clip.size(-2), clip.size(-1) |
|
|
|
if h <= w: |
|
long_edge = w |
|
short_edge = h |
|
else: |
|
long_edge = h |
|
        short_edge = w
|
|
|
th, tw = short_edge, short_edge |
|
|
|
i = torch.randint(0, h - th + 1, size=(1,)).item() |
|
j = torch.randint(0, w - tw + 1, size=(1,)).item() |
|
return crop(clip, i, j, th, tw), i, j |
|
|
|
def random_crop(clip, crop_size): |
|
if not _is_tensor_video_clip(clip): |
|
raise ValueError("clip should be a 4D torch.tensor") |
|
h, w = clip.size(-2), clip.size(-1) |
|
th, tw = crop_size[-2], crop_size[-1] |
|
|
|
if h < th or w < tw: |
|
raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w)) |
|
|
|
i = torch.randint(0, h - th + 1, size=(1,)).item() |
|
j = torch.randint(0, w - tw + 1, size=(1,)).item() |
|
clip_crop = crop(clip, i, j, th, tw) |
|
return clip_crop, i, j |
|
|
|
|
|
def to_tensor(clip): |
|
""" |
|
    Convert tensor data type from uint8 to float and divide values by 255.0.
|
Args: |
|
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) |
|
Return: |
|
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) |
|
""" |
|
_is_tensor_video_clip(clip) |
|
if not clip.dtype == torch.uint8: |
|
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) |
|
|
|
return clip.float() / 255.0 |
|
|
|
|
|
def normalize(clip, mean, std, inplace=False): |
|
""" |
|
Args: |
|
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) |
|
mean (tuple): pixel RGB mean. Size is (3) |
|
std (tuple): pixel standard deviation. Size is (3) |
|
Returns: |
|
normalized clip (torch.tensor): Size is (T, C, H, W) |
|
""" |
|
if not _is_tensor_video_clip(clip): |
|
raise ValueError("clip should be a 4D torch.tensor") |
|
if not inplace: |
|
clip = clip.clone() |
|
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) |
|
|
|
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) |
|
    # mean/std are per-channel; channels are at dim 1 for (T, C, H, W) clips
    clip.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
|
return clip |
|
|
|
|
|
def hflip(clip): |
|
""" |
|
Args: |
|
        clip (torch.tensor): Video clip to be flipped. Size is (T, C, H, W)
|
Returns: |
|
flipped clip (torch.tensor): Size is (T, C, H, W) |
|
""" |
|
if not _is_tensor_video_clip(clip): |
|
raise ValueError("clip should be a 4D torch.tensor") |
|
return clip.flip(-1) |
|
|
|
|
|
class RandomCropVideo: |
|
def __init__(self, size): |
|
if isinstance(size, numbers.Number): |
|
self.size = (int(size), int(size)) |
|
else: |
|
self.size = size |
|
|
|
def __call__(self, clip): |
|
""" |
|
Args: |
|
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) |
|
Returns: |
|
torch.tensor: randomly cropped video clip. |
|
size is (T, C, OH, OW) |
|
""" |
|
i, j, h, w = self.get_params(clip) |
|
return crop(clip, i, j, h, w) |
|
|
|
def get_params(self, clip): |
|
h, w = clip.shape[-2:] |
|
th, tw = self.size |
|
|
|
if h < th or w < tw: |
|
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") |
|
|
|
if w == tw and h == th: |
|
return 0, 0, h, w |
|
|
|
i = torch.randint(0, h - th + 1, size=(1,)).item() |
|
j = torch.randint(0, w - tw + 1, size=(1,)).item() |
|
|
|
return i, j, th, tw |
|
|
|
def __repr__(self) -> str: |
|
return f"{self.__class__.__name__}(size={self.size})" |
|
|
|
class CenterCropResizeVideo: |
|
''' |
|
First use the short side for cropping length, |
|
center crop video, then resize to the specified size |
|
''' |
|
def __init__( |
|
self, |
|
size, |
|
interpolation_mode="bilinear", |
|
): |
|
if isinstance(size, tuple): |
|
if len(size) != 2: |
|
raise ValueError(f"size should be tuple (height, width), instead got {size}") |
|
self.size = size |
|
else: |
|
self.size = (size, size) |
|
|
|
self.interpolation_mode = interpolation_mode |
|
|
|
|
|
def __call__(self, clip): |
|
""" |
|
Args: |
|
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) |
|
Returns: |
|
torch.tensor: scale resized / center cropped video clip. |
|
size is (T, C, crop_size, crop_size) |
|
""" |
|
|
|
clip_center_crop = center_crop_using_short_edge(clip) |
|
|
|
clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode) |
|
return clip_center_crop_resize |
|
|
|
def __repr__(self) -> str: |
|
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" |
|
|
|
|
|
class SDXL: |
|
def __init__( |
|
self, |
|
size, |
|
interpolation_mode="bilinear", |
|
): |
|
if isinstance(size, tuple): |
|
if len(size) != 2: |
|
raise ValueError(f"size should be tuple (height, width), instead got {size}") |
|
self.size = size |
|
else: |
|
self.size = (size, size) |
|
|
|
self.interpolation_mode = interpolation_mode |
|
|
|
def __call__(self, clip): |
|
""" |
|
Args: |
|
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) |
|
Returns: |
|
            torch.tensor: scale resized and randomly cropped video clip of size (T, C, *self.size),
                returned together with the pre-crop height/width and the (top, left) crop coordinates.
|
""" |
|
|
|
ori_h, ori_w = clip.size(-2), clip.size(-1) |
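        # The target sizes below use +1 so that, after resizing by the larger of the two ratios,
        # both spatial dims strictly exceed self.size and the subsequent crop cannot fail on rounding.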
|
        tar_h, tar_w = self.size[0] + 1, self.size[1] + 1

        if ori_h >= tar_h and ori_w >= tar_w:
|
tar_h_div_ori_h = tar_h / ori_h |
|
tar_w_div_ori_w = tar_w / ori_w |
|
if tar_h_div_ori_h > tar_w_div_ori_w: |
|
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode) |
|
else: |
|
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode) |
|
ori_h, ori_w = clip.size(-2), clip.size(-1) |
|
clip_tar_crop, i, j = random_crop(clip, self.size) |
|
else: |
|
tar_h_div_ori_h = tar_h / ori_h |
|
tar_w_div_ori_w = tar_w / ori_w |
|
if tar_h_div_ori_h > tar_w_div_ori_w: |
|
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode) |
|
else: |
|
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode) |
|
clip_tar_crop, i, j = random_crop(clip, self.size) |
|
return clip_tar_crop, ori_h, ori_w, i, j |
|
|
|
def __repr__(self) -> str: |
|
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" |
|
|
|
|
|
class SDXLCenterCrop: |
|
def __init__( |
|
self, |
|
size, |
|
interpolation_mode="bilinear", |
|
): |
|
if isinstance(size, tuple): |
|
if len(size) != 2: |
|
raise ValueError(f"size should be tuple (height, width), instead got {size}") |
|
self.size = size |
|
else: |
|
self.size = (size, size) |
|
|
|
self.interpolation_mode = interpolation_mode |
|
|
|
|
|
def __call__(self, clip): |
|
""" |
|
Args: |
|
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) |
|
Returns: |
|
            torch.tensor: scale resized / center cropped video clip of size (T, C, *self.size),
                returned together with the original height/width and the (top, left) crop coordinates.
|
""" |
|
|
|
ori_h, ori_w = clip.size(-2), clip.size(-1) |
|
tar_h, tar_w = self.size[0] + 1, self.size[1] + 1 |
|
tar_h_div_ori_h = tar_h / ori_h |
|
tar_w_div_ori_w = tar_w / ori_w |
|
|
|
if tar_h_div_ori_h > tar_w_div_ori_w: |
|
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode) |
|
|
|
else: |
|
clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode) |
|
|
|
|
|
|
|
clip_tar_crop, i, j = center_crop(clip, self.size) |
|
|
|
|
|
return clip_tar_crop, ori_h, ori_w, i, j |
|
|
|
def __repr__(self) -> str: |
|
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" |
|
|
|
|
|
class InternVideo320512: |
|
def __init__( |
|
self, |
|
size, |
|
interpolation_mode="bilinear", |
|
): |
|
if isinstance(size, tuple): |
|
if len(size) != 2: |
|
raise ValueError(f"size should be tuple (height, width), instead got {size}") |
|
self.size = size |
|
else: |
|
self.size = (size, size) |
|
|
|
self.interpolation_mode = interpolation_mode |
|
|
|
|
|
def __call__(self, clip): |
|
""" |
|
Args: |
|
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) |
|
Returns: |
|
torch.tensor: scale resized / center cropped video clip. |
|
size is (T, C, crop_size, crop_size) |
|
""" |
|
|
|
h, w = clip.size(-2), clip.size(-1) |
|
|
|
if h < 320: |
|
clip = resize_scale_with_height(clip=clip, target_size=321, interpolation_mode=self.interpolation_mode) |
|
|
|
if w < 512: |
|
clip = resize_scale_with_weight(clip=clip, target_size=513, interpolation_mode=self.interpolation_mode) |
|
|
|
|
|
        # center_crop returns (clip, i, j); keep only the cropped clip.
        # Crop a smaller (220, 352) window to drop hard-coded subtitles, then resize back up to the target size.
        clip_center_crop_no_subtitles, _, _ = center_crop(clip, (220, 352))

        clip_center_resize = resize(clip_center_crop_no_subtitles, target_size=self.size, interpolation_mode=self.interpolation_mode)
|
|
|
return clip_center_resize |
|
|
|
def __repr__(self) -> str: |
|
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" |
|
|
|
class UCFCenterCropVideo:
|
''' |
|
First scale to the specified size in equal proportion to the short edge, |
|
then center cropping |
|
''' |
|
def __init__( |
|
self, |
|
size, |
|
interpolation_mode="bilinear", |
|
): |
|
if isinstance(size, tuple): |
|
if len(size) != 2: |
|
raise ValueError(f"size should be tuple (height, width), instead got {size}") |
|
self.size = size |
|
else: |
|
self.size = (size, size) |
|
|
|
self.interpolation_mode = interpolation_mode |
|
|
|
|
|
def __call__(self, clip): |
|
""" |
|
Args: |
|
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) |
|
Returns: |
|
torch.tensor: scale resized / center cropped video clip. |
|
size is (T, C, crop_size, crop_size) |
|
""" |
|
clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode) |
|
        clip_center_crop, _, _ = center_crop(clip_resize, self.size)
|
return clip_center_crop |
|
|
|
def __repr__(self) -> str: |
|
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" |
|
|
|
class KineticsRandomCropResizeVideo: |
|
''' |
|
    Slide along the long edge, with the short edge as crop size, then resize to the desired size.
|
''' |
|
def __init__( |
|
self, |
|
size, |
|
interpolation_mode="bilinear", |
|
): |
|
if isinstance(size, tuple): |
|
if len(size) != 2: |
|
raise ValueError(f"size should be tuple (height, width), instead got {size}") |
|
self.size = size |
|
else: |
|
self.size = (size, size) |
|
|
|
self.interpolation_mode = interpolation_mode |
|
|
|
def __call__(self, clip): |
|
        clip_random_crop, _, _ = random_shift_crop(clip)
|
clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode) |
|
return clip_resize |
|
|
|
class ResizeVideo(): |
|
''' |
|
    Resize the video clip to the specified size (no cropping).
|
''' |
|
def __init__( |
|
self, |
|
size, |
|
interpolation_mode="bilinear", |
|
): |
|
if isinstance(size, tuple): |
|
if len(size) != 2: |
|
raise ValueError(f"size should be tuple (height, width), instead got {size}") |
|
self.size = size |
|
else: |
|
self.size = (size, size) |
|
|
|
self.interpolation_mode = interpolation_mode |
|
|
|
|
|
def __call__(self, clip): |
|
""" |
|
Args: |
|
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) |
|
Returns: |
|
            torch.tensor: resized video clip.
                size is (T, C, size[0], size[1])
|
""" |
|
clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode) |
|
return clip_resize |
|
|
|
def __repr__(self) -> str: |
|
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" |
|
|
|
class CenterCropVideo: |
|
def __init__( |
|
self, |
|
size, |
|
interpolation_mode="bilinear", |
|
): |
|
if isinstance(size, tuple): |
|
if len(size) != 2: |
|
raise ValueError(f"size should be tuple (height, width), instead got {size}") |
|
self.size = size |
|
else: |
|
self.size = (size, size) |
|
|
|
self.interpolation_mode = interpolation_mode |
|
|
|
|
|
def __call__(self, clip): |
|
""" |
|
Args: |
|
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) |
|
Returns: |
|
torch.tensor: center cropped video clip. |
|
size is (T, C, crop_size, crop_size) |
|
""" |
|
        clip_center_crop, _, _ = center_crop(clip, self.size)
|
return clip_center_crop |
|
|
|
def __repr__(self) -> str: |
|
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" |
|
|
|
|
|
class NormalizeVideo: |
|
""" |
|
Normalize the video clip by mean subtraction and division by standard deviation |
|
Args: |
|
mean (3-tuple): pixel RGB mean |
|
std (3-tuple): pixel RGB standard deviation |
|
inplace (boolean): whether do in-place normalization |
|
""" |
|
|
|
def __init__(self, mean, std, inplace=False): |
|
self.mean = mean |
|
self.std = std |
|
self.inplace = inplace |
|
|
|
def __call__(self, clip): |
|
""" |
|
Args: |
|
            clip (torch.tensor): video clip to be normalized. Size is (T, C, H, W)
|
""" |
|
return normalize(clip, self.mean, self.std, self.inplace) |
|
|
|
def __repr__(self) -> str: |
|
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" |
|
|
|
|
|
class ToTensorVideo: |
|
""" |
|
    Convert tensor data type from uint8 to float and divide values by 255.0.
|
""" |
|
|
|
def __init__(self): |
|
pass |
|
|
|
def __call__(self, clip): |
|
""" |
|
Args: |
|
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) |
|
Return: |
|
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) |
|
""" |
|
return to_tensor(clip) |
|
|
|
def __repr__(self) -> str: |
|
return self.__class__.__name__ |
|
|
|
|
|
class RandomHorizontalFlipVideo: |
|
""" |
|
Flip the video clip along the horizontal direction with a given probability |
|
Args: |
|
p (float): probability of the clip being flipped. Default value is 0.5 |
|
""" |
|
|
|
def __init__(self, p=0.5): |
|
self.p = p |
|
|
|
def __call__(self, clip): |
|
""" |
|
Args: |
|
clip (torch.tensor): Size is (T, C, H, W) |
|
Return: |
|
clip (torch.tensor): Size is (T, C, H, W) |
|
""" |
|
if random.random() < self.p: |
|
clip = hflip(clip) |
|
return clip |
|
|
|
def __repr__(self) -> str: |
|
return f"{self.__class__.__name__}(p={self.p})" |
|
|
|
class Compose: |
|
"""Composes several transforms together. This transform does not support torchscript. |
|
Please, see the note below. |
|
|
|
Args: |
|
transforms (list of ``Transform`` objects): list of transforms to compose. |
|
|
|
Example: |
|
>>> transforms.Compose([ |
|
>>> transforms.CenterCrop(10), |
|
>>> transforms.PILToTensor(), |
|
>>> transforms.ConvertImageDtype(torch.float), |
|
>>> ]) |
|
|
|
.. note:: |
|
In order to script the transformations, please use ``torch.nn.Sequential`` as below. |
|
|
|
>>> transforms = torch.nn.Sequential( |
|
>>> transforms.CenterCrop(10), |
|
>>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), |
|
>>> ) |
|
>>> scripted_transforms = torch.jit.script(transforms) |
|
|
|
        Make sure to use only scriptable transformations, i.e. those that work with ``torch.Tensor`` and do not
        require ``lambda`` functions or ``PIL.Image``.
|
|
|
""" |
|
|
|
def __init__(self, transforms): |
|
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): |
|
_log_api_usage_once(self) |
|
self.transforms = transforms |
|
|
|
    def __call__(self, img):

        # SDXL and SDXLCenterCrop also return the original clip size and crop coordinates
        # in addition to the transformed clip; default them so the return statement is
        # well-defined when neither transform is present.
        ori_h, ori_w, crops_coords_top, crops_coords_left = None, None, None, None

        for t in self.transforms:

            if isinstance(t, (SDXLCenterCrop, SDXL)):

                img, ori_h, ori_w, crops_coords_top, crops_coords_left = t(img)

            else:

                img = t(img)

        return img, ori_h, ori_w, crops_coords_top, crops_coords_left
|
|
|
def __repr__(self) -> str: |
|
format_string = self.__class__.__name__ + "(" |
|
for t in self.transforms: |
|
format_string += "\n" |
|
format_string += f" {t}" |
|
format_string += "\n)" |
|
return format_string |
|
|
|
|
|
|
|
|
|
class TemporalRandomCrop(object): |
|
"""Temporally crop the given frame indices at a random location. |
|
|
|
Args: |
|
        size (int): Desired number of frames to be seen by the model.
|
""" |
|
|
|
def __init__(self, size): |
|
self.size = size |
|
|
|
def __call__(self, total_frames): |
|
rand_end = max(0, total_frames - self.size - 1) |
|
begin_index = random.randint(0, rand_end) |
|
end_index = min(begin_index + self.size, total_frames) |
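        # Note: if total_frames < size, the returned window is shorter than the requested length.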
|
return begin_index, end_index |
|
|
|
|
|
if __name__ == '__main__': |
|
from torchvision import transforms |
|
import torchvision.io as io |
|
import numpy as np |
|
from torchvision.utils import save_image |
|
import os |
|
|
|
vframes, aframes, info = io.read_video( |
|
filename='./v_Archery_g01_c03.avi', |
|
pts_unit='sec', |
|
output_format='TCHW' |
|
) |
|
|
|
trans = transforms.Compose([ |
|
ToTensorVideo(), |
|
RandomHorizontalFlipVideo(), |
|
UCFCenterCropVideo(512), |
|
|
|
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) |
|
]) |
|
|
|
target_video_len = 32 |
|
frame_interval = 1 |
|
total_frames = len(vframes) |
|
print(total_frames) |
|
|
|
temporal_sample = TemporalRandomCrop(target_video_len * frame_interval) |
|
|
|
|
|
|
|
start_frame_ind, end_frame_ind = temporal_sample(total_frames) |
|
|
|
|
|
assert end_frame_ind - start_frame_ind >= target_video_len |
|
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int) |
|
print(frame_indice) |
|
|
|
select_vframes = vframes[frame_indice] |
|
print(select_vframes.shape) |
|
print(select_vframes.dtype) |
|
|
|
select_vframes_trans = trans(select_vframes) |
|
print(select_vframes_trans.shape) |
|
print(select_vframes_trans.dtype) |
|
|
|
select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8) |
|
print(select_vframes_trans_int.dtype) |
|
print(select_vframes_trans_int.permute(0, 2, 3, 1).shape) |
|
|
|
io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8) |
|
|
|
    os.makedirs('./test000', exist_ok=True)

    for i in range(target_video_len):
|
save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True, value_range=(-1, 1)) |