File size: 3,734 Bytes
80ebcb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from typing import List, Literal, Tuple

import torch
import torch.nn.functional as F


def center_crop_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
    """
    Cuts a centered spatial window out of every frame of a video tensor.

    Args:
        video (`torch.Tensor`):
            4-D tensor whose last two dimensions are treated as height and width.
        size (`Tuple[int, int]`):
            `(crop_height, crop_width)` of the window to extract.

    Returns:
        `torch.Tensor`:
            The first two dimensions untouched, the last two sliced to the centered window.
            No bounds check is performed on `size`.
    """
    _, _, full_h, full_w = video.shape
    out_h, out_w = size
    # Floor division biases the window toward the top-left when the margin is odd.
    row_start = (full_h - out_h) // 2
    col_start = (full_w - out_w) // 2
    rows = slice(row_start, row_start + out_h)
    cols = slice(col_start, col_start + out_w)
    return video[:, :, rows, cols]


def resize_crop_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
    """
    Resizes a video so that it fully covers the target size, then center-crops it to that size.

    The same scale factor is applied to height and width (the larger of the two per-axis
    ratios), so the aspect ratio is kept up to integer truncation and the excess along one
    axis is removed by the crop.

    Args:
        video (`torch.Tensor`):
            4-D tensor whose last two dimensions are treated as height and width. It is passed
            to `F.interpolate` in bilinear mode.
        size (`Tuple[int, int]`):
            Target `(height, width)`.

    Returns:
        `torch.Tensor`:
            Video resized bilinearly and center-cropped to `size`.
    """
    _, _, height, width = video.shape
    target_h, target_w = size
    scale = max(target_h / height, target_w / width)
    # `dim * (target / dim)` can round to just below `target` in floating point
    # (e.g. 49 * (1 / 49) evaluates to 0.9999999999999999), and `int()` would then truncate
    # to one pixel short of the crop window. Clamp so the resized frame always covers it.
    new_h = max(target_h, int(height * scale))
    new_w = max(target_w, int(width * scale))
    video = F.interpolate(video, size=(new_h, new_w), mode="bilinear", align_corners=False)
    return center_crop_video(video, size)


def bicubic_resize_video(video: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
    """
    Resizes every frame of a video tensor to `size` with bicubic interpolation.

    The aspect ratio is not preserved: the output height and width are exactly `size`.

    Args:
        video (`torch.Tensor`):
            4-D tensor whose last two dimensions are treated as height and width.
        size (`Tuple[int, int]`):
            Target `(height, width)`.

    Returns:
        `torch.Tensor`:
            Result of `F.interpolate(..., mode="bicubic", align_corners=False)`.
    """
    # Unpacking is kept purely as a rank check: anything but a 4-D tensor fails here.
    _frames, _channels, _height, _width = video.shape
    return F.interpolate(video, size=size, mode="bicubic", align_corners=False)


def find_nearest_video_resolution(
    video: torch.Tensor, resolution_buckets: List[Tuple[int, int, int]]
) -> Tuple[int, int, int]:
    """
    Picks the resolution bucket that best matches a video's frame count and aspect ratio.

    Selection happens in two steps:
    - Frame count: the largest bucket frame count that does not exceed the video's frame
      count; if every bucket has more frames than the video, the frame count closest to
      the video's is used instead.
    - Aspect ratio: among buckets with that frame count, the one whose `width / height`
      is closest to the video's. Ties go to the bucket listed first.

    Args:
        video (`torch.Tensor`):
            4-D tensor unpacked as `(num_frames, num_channels, height, width)`.
        resolution_buckets (`List[Tuple[int, int, int]]`):
            Available `(num_frames, height, width)` buckets.

    Returns:
        `Tuple[int, int, int]`:
            The selected `(num_frames, height, width)` bucket.
    """
    frame_count, _, src_h, src_w = video.shape
    src_ratio = src_w / src_h

    fitting_counts = [bucket[0] for bucket in resolution_buckets if bucket[0] <= frame_count]
    if fitting_counts:
        chosen_frames = max(fitting_counts)
    else:
        # No bucket is short enough; fall back to the closest frame count overall.
        closest = min(resolution_buckets, key=lambda bucket: abs(bucket[0] - frame_count))
        chosen_frames = closest[0]

    candidates = [bucket for bucket in resolution_buckets if bucket[0] == chosen_frames]
    return min(candidates, key=lambda bucket: abs((bucket[2] / bucket[1]) - src_ratio))


def resize_to_nearest_bucket_video(
    video: torch.Tensor,
    resolution_buckets: List[Tuple[int, int, int]],
    resize_mode: Literal["center_crop", "resize_crop", "bicubic"] = "bicubic",
) -> Tuple[torch.Tensor, bool]:
    """
    Resizes a video tensor to the nearest resolution bucket using the specified mode.
    - It first finds a frame match with <= T frames (or the closest frame count if none exists).
    - Then, it selects the bucket with the closest aspect ratio among those.
    - If the video has more frames than the bucket, frames are subsampled at evenly spaced
      indices. If it has fewer, the frame count is left unchanged.

    Args:
        video (`torch.Tensor`):
            Input video tensor of shape `(T, C, H, W)`.
        resolution_buckets (`List[Tuple[int, int, int]]`):
            Available (num_frames, height, width) resolution buckets.
        resize_mode (`str`):
            One of ["center_crop", "resize_crop", "bicubic"].

    Returns:
        `Tuple[torch.Tensor, bool]`:
            The video resized to the bucket's height and width, and a "first frame only"
            flag. The flag is always `False` here: no code path in this function sets it.

    Raises:
        `ValueError`:
            If `resize_mode` is not one of the supported modes.
    """
    # Pick the spatial resize routine first so an unsupported mode fails before any tensor work.
    if resize_mode == "center_crop":
        resize_fn = center_crop_video
    elif resize_mode == "resize_crop":
        resize_fn = resize_crop_video
    elif resize_mode == "bicubic":
        resize_fn = bicubic_resize_video
    else:
        raise ValueError(
            f"Invalid resize_mode: {resize_mode}. Choose from 'center_crop', 'resize_crop', or 'bicubic'."
        )

    target_frames, target_h, target_w = find_nearest_video_resolution(video, resolution_buckets)

    # NOTE(review): the flag name suggests a "fewer frames than the bucket" case was meant to
    # set it, but the original code never did; the always-False value is preserved for callers.
    _first_frame_only = False

    num_frames = video.shape[0]
    if num_frames > target_frames:
        # Downsample: select frames at evenly spaced positions, endpoints included.
        indices = torch.linspace(0, num_frames - 1, target_frames).long()
        video = video[indices, :, :, :]

    # Resize spatial resolution
    return resize_fn(video, (target_h, target_w)), _first_frame_only