import torch
import torchvision.transforms as T
import PIL.Image

from typing import List

size = (224, 224)


class ResizeWithPadding:
    """Resize an image to fit inside a target_size square and pad the rest with `fill`."""

    def __init__(self, target_size: int = 224, fill: int = 0, mode: str = "RGB") -> None:
        self.target_size = target_size
        self.fill = fill
        self.mode = mode

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        original_width, original_height = image.size
        aspect_ratio = original_width / original_height

        # Scale the longer side to target_size while keeping the aspect ratio.
        if aspect_ratio > 1:
            new_width = self.target_size
            new_height = int(self.target_size / aspect_ratio)
        else:
            new_height = self.target_size
            new_width = int(self.target_size * aspect_ratio)

        # Bicubic for RGB images, nearest-neighbour for label masks so class ids are preserved.
        resample = PIL.Image.BICUBIC if self.mode == "RGB" else PIL.Image.NEAREST
        resized_image = image.resize((new_width, new_height), resample)

        # Centre the resized image on a target_size x target_size canvas filled with `fill`.
        delta_w = self.target_size - new_width
        delta_h = self.target_size - new_height
        padding = (delta_w // 2, delta_h // 2, delta_w - delta_w // 2, delta_h - delta_h // 2)
        padded_image = PIL.Image.new(self.mode, (self.target_size, self.target_size), self.fill)
        padded_image.paste(resized_image, (padding[0], padding[1]))
        return padded_image


def get_transform(mean: List[float], std: List[float]) -> T.Compose:
    """Image pipeline: letterbox to a 224x224 square, convert to tensor, normalize."""
    return T.Compose([
        ResizeWithPadding(),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ])


# Mask pipeline: ToTensor rescales pixel values to [0, 1], so multiply by 255 and
# round before casting to recover the integer class ids stored in the mask pixels
# (rounding avoids float truncation, e.g. 126.9999 -> 126).
mask_transform = T.Compose([
    ResizeWithPadding(mode="L"),
    T.ToTensor(),
    T.Lambda(lambda x: (x * 255).round().long()),
])


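# Example image transform built with get_transform. The mean/std values below are
# the usual ImageNet statistics; they are only an illustrative default here, not a
# requirement of this module — substitute the statistics that match your backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
image_transform = get_transform(IMAGENET_MEAN, IMAGENET_STD)

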
class EMA:
    """Exponential moving average of a scalar value (e.g. a running loss)."""

    def __init__(self, alpha: float = 0.9) -> None:
        self.value = None
        self.alpha = alpha

    def __call__(self, value: float) -> float:
        if self.value is None:
            self.value = value
        else:
            self.value = self.alpha * self.value + (1 - self.alpha) * value
        return self.value
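

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): run a dummy image and mask through
    # the pipelines and smooth a toy loss sequence with EMA. The ImageNet mean/std
    # used here are an assumed default, not something this module mandates.
    image = PIL.Image.new("RGB", (640, 480), (128, 64, 32))
    mask = PIL.Image.new("L", (640, 480), 1)

    image_tensor = get_transform([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image)
    mask_tensor = mask_transform(mask)
    print(image_tensor.shape, mask_tensor.shape)  # (3, 224, 224) and (1, 224, 224)

    ema = EMA(alpha=0.9)
    for loss in [1.0, 0.8, 0.9, 0.6]:
        print(ema(loss))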