File size: 2,999 Bytes
ad5354d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import copy

import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F

from src.efficientvit.models.utils import torch_random_choices

# Public names of this module (what a star-import exposes).
__all__ = [
    "RRSController",
    "get_interpolate",
    "MyRandomResizedCrop",
]


class RRSController:
    """Class-level (global) state for the random-resize output resolution.

    The class is used as a namespace: all state is stored in class attributes
    and every method is a ``staticmethod``, so it is never instantiated here.

    Attributes:
        ACTIVE_SIZE: The size currently in effect. ``MyRandomResizedCrop.forward``
            reads it and passes it as the ``size`` argument of
            ``F.resized_crop``.
        IMAGE_SIZE_LIST: Candidate sizes that ``set_epoch`` samples from.
        CHOICE_LIST: Per-batch schedule produced by ``set_epoch``; ``None``
            until ``set_epoch`` has been called.
    """

    ACTIVE_SIZE = (224, 224)
    IMAGE_SIZE_LIST = [(224, 224)]

    CHOICE_LIST = None

    @staticmethod
    def get_candidates() -> list[tuple[int, int]]:
        """Return a deep copy of ``IMAGE_SIZE_LIST``.

        A copy is returned so that callers cannot mutate the shared candidate
        list through the returned object.
        """
        return copy.deepcopy(RRSController.IMAGE_SIZE_LIST)

    @staticmethod
    def sample_resolution(batch_id: int) -> None:
        """Make the scheduled size for batch ``batch_id`` the active size.

        Args:
            batch_id: Index into the schedule built by ``set_epoch``.

        Raises:
            RuntimeError: If ``set_epoch`` has not been called yet, i.e. there
                is no schedule to index.
            IndexError: If ``batch_id`` is outside the schedule.
        """
        choices = RRSController.CHOICE_LIST
        if choices is None:
            # Without this guard the failure is an opaque
            # "'NoneType' object is not subscriptable".
            raise RuntimeError(
                "RRSController.set_epoch() must be called before "
                "RRSController.sample_resolution()"
            )
        RRSController.ACTIVE_SIZE = choices[batch_id]

    @staticmethod
    def set_epoch(epoch: int, batch_per_epoch: int) -> None:
        """Build the per-batch size schedule for ``epoch``.

        A fresh ``torch.Generator`` is seeded with ``epoch`` and handed to
        ``torch_random_choices`` together with the candidate sizes and the
        number of batches; the result is stored in ``CHOICE_LIST``.

        Args:
            epoch: Used as the generator seed.
            batch_per_epoch: Number of choices requested from the helper.
        """
        generator = torch.Generator()
        generator.manual_seed(epoch)
        # NOTE(review): torch_random_choices is defined outside this file. If
        # it returns a bare element (not a list) when asked for one choice,
        # CHOICE_LIST[batch_id] would index into a single size tuple when
        # batch_per_epoch == 1 -- confirm against the helper.
        RRSController.CHOICE_LIST = torch_random_choices(
            RRSController.get_candidates(),
            generator,
            batch_per_epoch,
        )


def get_interpolate(name: str) -> F.InterpolationMode:
    """Translate an interpolation name into a torchvision ``InterpolationMode``.

    Args:
        name: One of ``"nearest"``, ``"bilinear"``, ``"bicubic"``, ``"box"``,
            ``"hamming"``, ``"lanczos"``, or ``"random"``. With ``"random"``
            the six modes are handed to ``torch_random_choices`` and its
            result is returned (presumably a single mode -- the helper is
            defined outside this file).

    Returns:
        The matching ``InterpolationMode``.

    Raises:
        NotImplementedError: If ``name`` is not one of the accepted values.
    """
    mapping = {
        "nearest": F.InterpolationMode.NEAREST,
        "bilinear": F.InterpolationMode.BILINEAR,
        "bicubic": F.InterpolationMode.BICUBIC,
        "box": F.InterpolationMode.BOX,
        "hamming": F.InterpolationMode.HAMMING,
        "lanczos": F.InterpolationMode.LANCZOS,
    }
    if name == "random":
        # Derive the candidates from the mapping (same modes, same order) so
        # the two lists cannot drift apart.
        return torch_random_choices(list(mapping.values()))
    try:
        return mapping[name]
    except KeyError:
        raise NotImplementedError(
            f"unsupported interpolation {name!r}; "
            f"expected 'random' or one of {list(mapping)}"
        ) from None


class MyRandomResizedCrop(transforms.RandomResizedCrop):
    """``RandomResizedCrop`` whose output size comes from ``RRSController``.

    The crop box is sampled by the parent's ``get_params``; the resize target
    is ``RRSController.ACTIVE_SIZE`` at call time, and the interpolation mode
    is resolved from a name on every call via ``get_interpolate``.
    """

    def __init__(
        self,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation: str = "random",
    ):
        """Store the crop parameters and the interpolation *name*.

        Args:
            scale: Forwarded to ``RandomResizedCrop.__init__``.
            ratio: Forwarded to ``RandomResizedCrop.__init__``.
            interpolation: A name understood by ``get_interpolate``; kept as a
                string and resolved on each ``forward`` call.
        """
        # The parent requires a size; 224 is passed here, but forward() takes
        # its target size from RRSController.ACTIVE_SIZE instead.
        super().__init__(224, scale, ratio)
        # Replace the attribute set by the parent with the name string.
        self.interpolation = interpolation

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Crop a sampled region of ``img`` and resize it to the active size.

        Args:
            img: Input image, passed unchanged to ``get_params`` and
                ``F.resized_crop``.

        Returns:
            The result of ``F.resized_crop``.
        """
        top, left, height, width = self.get_params(
            img, list(self.scale), list(self.ratio)
        )
        out_size = list(RRSController.ACTIVE_SIZE)
        mode = get_interpolate(self.interpolation)
        return F.resized_crop(img, top, left, height, width, out_size, mode)

    def __repr__(self) -> str:
        """Return a multi-line summary listing candidate sizes and settings."""
        rounded_scale = tuple(round(s, 4) for s in self.scale)
        rounded_ratio = tuple(round(r, 4) for r in self.ratio)
        fields = [
            f"size={RRSController.get_candidates()}",
            f"scale={rounded_scale}",
            f"ratio={rounded_ratio}",
            f"interpolation={self.interpolation}",
        ]
        return self.__class__.__name__ + "(\n\t" + ",\n\t".join(fields) + ")"