import torch
import torchvision.transforms as T
import PIL.Image
from typing import List, Optional

size = (224, 224)


class ResizeWithPadding:
    """Resize an image so its longer side matches ``target_size`` (preserving
    aspect ratio), then center it on a square canvas padded with ``fill``."""

    def __init__(self, target_size: int = 224, fill: int = 0, mode: str = "RGB") -> None:
        self.target_size = target_size
        self.fill = fill
        self.mode = mode

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        original_width, original_height = image.size
        aspect_ratio = original_width / original_height

        # Scale the longer side to target_size; the shorter side shrinks proportionally.
        if aspect_ratio > 1:
            new_width = self.target_size
            new_height = int(self.target_size / aspect_ratio)
        else:
            new_height = self.target_size
            new_width = int(self.target_size * aspect_ratio)

        # Bicubic for images; nearest-neighbor for masks so label values are never interpolated.
        resample = PIL.Image.BICUBIC if self.mode == "RGB" else PIL.Image.NEAREST
        resized_image = image.resize((new_width, new_height), resample)

        # Center the resized image on a square canvas filled with `fill`.
        delta_w = self.target_size - new_width
        delta_h = self.target_size - new_height
        padded_image = PIL.Image.new(self.mode, (self.target_size, self.target_size), self.fill)
        padded_image.paste(resized_image, (delta_w // 2, delta_h // 2))
        return padded_image


def get_transform(mean: List[float], std: List[float]) -> T.Compose:
    """Image pipeline: square-pad resize, convert to tensor, normalize."""
    return T.Compose([
        ResizeWithPadding(),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ])


# Mask pipeline: nearest-neighbor resize (mode="L"), then recover integer class
# ids. ToTensor scales uint8 values to [0, 1], so multiply back by 255; round
# before casting, because float round-trip error can land just below the
# integer (e.g. a label of 3 can come back as 2.9999998, and .long() alone
# would truncate it to 2).
mask_transform = T.Compose([
    ResizeWithPadding(mode="L"),
    T.ToTensor(),
    T.Lambda(lambda x: (x * 255).round().long()),
])


class EMA:
    """Exponential moving average of a scalar, e.g. for smoothing a loss curve."""

    def __init__(self, alpha: float = 0.9) -> None:
        self.value: Optional[float] = None
        self.alpha = alpha

    def __call__(self, value: float) -> float:
        if self.value is None:
            self.value = value  # seed with the first observation
        else:
            self.value = self.alpha * self.value + (1 - self.alpha) * value
        return self.value
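
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the pipeline above). The
# ImageNet mean/std values and the synthetic inputs below are assumptions
# chosen for demonstration; substitute your dataset's own statistics and data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Assumed ImageNet normalization statistics.
    transform = get_transform(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # A synthetic 300x200 RGB image stands in for a real input.
    image = PIL.Image.new("RGB", (300, 200), color=(128, 64, 32))
    tensor = transform(image)
    print(tensor.shape)  # torch.Size([3, 224, 224])

    # A synthetic single-channel mask with a constant class id of 1.
    mask = PIL.Image.new("L", (300, 200), color=1)
    mask_tensor = mask_transform(mask)
    print(mask_tensor.shape, mask_tensor.dtype)  # torch.Size([1, 224, 224]) torch.int64

    # Smooth a noisy scalar series with the EMA helper.
    ema = EMA(alpha=0.9)
    for loss in [1.0, 0.8, 0.9, 0.7]:
        print(ema(loss))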