# Provenance: Hugging Face upload "Upload 9 files" by eksemyashkina (commit af720c2, verified)
import torch
import torchvision.transforms as T
import PIL.Image
from typing import List
# NOTE(review): module-level default input size; not referenced by the visible
# code below (ResizeWithPadding hard-codes its own 224 default) — confirm
# whether external callers import this before removing.
size = (224, 224)
class ResizeWithPadding:
    """Resize an image so its longer side equals ``target_size`` (preserving
    aspect ratio), then center it on a square ``target_size`` x ``target_size``
    canvas padded with ``fill``.

    Args:
        target_size: Edge length of the square output image, in pixels.
        fill: Pad value for the canvas (0 = black for "RGB"; for "L" masks
            this is the background label).
        mode: PIL image mode of the output canvas. "RGB" inputs are resized
            with bicubic resampling; any other mode (e.g. "L" masks) uses
            nearest-neighbor so discrete label values are not interpolated.
    """

    def __init__(self, target_size: int = 224, fill: int = 0, mode: str = "RGB") -> None:
        self.target_size = target_size
        self.fill = fill
        self.mode = mode

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        original_width, original_height = image.size
        aspect_ratio = original_width / original_height
        # Scale the longer edge to target_size; clamp the shorter edge to at
        # least 1 px so extreme aspect ratios never truncate to a zero-sized
        # image (Image.resize rejects a 0 dimension).
        if aspect_ratio > 1:
            new_width = self.target_size
            new_height = max(1, int(self.target_size / aspect_ratio))
        else:
            new_height = self.target_size
            new_width = max(1, int(self.target_size * aspect_ratio))
        # NOTE(review): PIL.Image.BICUBIC/NEAREST are the pre-Pillow-10 names;
        # Pillow >= 9.1 prefers PIL.Image.Resampling.* — confirm pinned version.
        resample = PIL.Image.BICUBIC if self.mode == "RGB" else PIL.Image.NEAREST
        resized_image = image.resize((new_width, new_height), resample)
        # Center the resized image: left/top offsets split the slack evenly
        # (the extra pixel of an odd remainder goes to the right/bottom).
        delta_w = self.target_size - new_width
        delta_h = self.target_size - new_height
        padded_image = PIL.Image.new(self.mode, (self.target_size, self.target_size), self.fill)
        padded_image.paste(resized_image, (delta_w // 2, delta_h // 2))
        return padded_image
def get_transform(mean: List[float], std: List[float]) -> T.Compose:
    """Build the RGB-image preprocessing pipeline: aspect-preserving square
    resize with padding, conversion to a float tensor, and per-channel
    normalization with the given ``mean`` and ``std``."""
    steps = [
        ResizeWithPadding(),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ]
    return T.Compose(steps)
# Mask preprocessing: pad-resize in mode "L" (nearest-neighbor, so label
# values stay discrete), convert to a [0, 1] float tensor, then recover
# integer class ids. `.round()` is required before `.long()`: ToTensor stores
# v / 255 in float32, so (v / 255) * 255 can come back as e.g. 0.99999994 and
# plain truncation would shift the label down by one.
mask_transform = T.Compose([
    ResizeWithPadding(mode="L"),
    T.ToTensor(),
    T.Lambda(lambda x: (x * 255).round().long()),
])
class EMA:
    """Exponential moving average of a scalar stream.

    ``alpha`` weights the previous average, so a larger alpha means heavier
    smoothing. The first observed value seeds the average unchanged.
    """

    def __init__(self, alpha: float = 0.9) -> None:
        self.value = None  # running average; None until the first update
        self.alpha = alpha

    def __call__(self, value: float) -> float:
        """Fold one observation into the running average and return it."""
        previous = self.value
        if previous is None:
            self.value = value
        else:
            self.value = self.alpha * previous + (1.0 - self.alpha) * value
        return self.value