import os

import cv2
import numpy as np
import torch
from tqdm import tqdm

def gram(input):
    """
    Calculate the Gram matrix of a feature map, used for the style loss.
    https://pytorch.org/tutorials/advanced/neural_style_tutorial.html#style-loss
    """
    b, c, h, w = input.size()
    x = input.contiguous().view(b * c, h * w)
    G = torch.mm(x, x.T)
    # Workaround: under mixed precision, torch.mm can produce inf values in
    # fp16, so clamp just below the fp16 maximum (65504).
    # https://discuss.pytorch.org/t/gram-matrix-in-mixed-precision/166800/2
    G = torch.clamp(G, -64990.0, 64990.0)
    # Normalize by the total number of elements
    return G.div(b * c * h * w)
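
# Usage sketch for gram() (assumes an NCHW feature map, e.g. the output of a
# convolution layer):
#   features = torch.rand(1, 256, 64, 64)
#   G = gram(features)   # shape (256, 256)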

def divisible(dim):
    '''
    Make width and height divisible by 32 (rounding down)
    '''
    width, height = dim
    return width - (width % 32), height - (height % 32)
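
# Worked example: divisible((1280, 721)) == (1280, 704); 1280 is already a
# multiple of 32, while 721 rounds down to 704.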

def resize_image(image, width=None, height=None, inter=cv2.INTER_AREA):
    '''
    Resize an image so both sides are divisible by 32, preserving the
    aspect ratio when only one of width/height is given
    '''
    h, w = image.shape[:2]

    if width and height:
        return cv2.resize(image, divisible((width, height)), interpolation=inter)

    if width is None and height is None:
        return cv2.resize(image, divisible((w, h)), interpolation=inter)

    if width is None:
        r = height / float(h)
        dim = (int(w * r), height)
    else:
        r = width / float(w)
        dim = (width, int(h * r))

    return cv2.resize(image, divisible(dim), interpolation=inter)
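
# Usage sketch (the file name is hypothetical):
#   img = cv2.imread("photo.jpg")
#   img = resize_image(img, width=512)   # height follows the aspect ratio,
#                                        # then both sides snap to /32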

def normalize_input(images):
    '''
    [0, 255] -> [-1, 1]
    '''
    return images / 127.5 - 1.0

def denormalize_input(images, dtype=None):
    '''
    [-1, 1] -> [0, 255]
    '''
    images = images * 127.5 + 127.5

    if dtype is not None:
        if isinstance(images, torch.Tensor):
            images = images.type(dtype)
        else:
            # numpy.ndarray
            images = images.astype(dtype)

    return images
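
# normalize_input and denormalize_input are inverses (up to rounding when an
# integer dtype such as np.uint8 is requested):
#   x = np.array([0.0, 127.5, 255.0])
#   denormalize_input(normalize_input(x))   # -> array([  0. , 127.5, 255. ])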

def preprocess_images(images):
    '''
    Preprocess image for inference

    @Arguments:
        - images: np.ndarray

    @Returns:
        - images: torch.Tensor
    '''
    images = images.astype(np.float32)

    # Normalize to [-1, 1]
    images = normalize_input(images)
    images = torch.from_numpy(images)

    # Add batch dim
    if len(images.shape) == 3:
        images = images.unsqueeze(0)

    # Channel first: NHWC -> NCHW
    images = images.permute(0, 3, 1, 2)

    return images
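
# Usage sketch (the path is hypothetical; cv2 loads BGR, so convert first if
# the model expects RGB):
#   img = cv2.imread("photo.jpg")
#   img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#   batch = preprocess_images(img)   # shape (1, 3, H, W), values in [-1, 1]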

def compute_data_mean(data_folder):
    '''
    Compute the mean deviation of each color channel over all images
    in a folder
    '''
    if not os.path.exists(data_folder):
        raise FileNotFoundError(f'Folder {data_folder} does not exist')

    image_files = os.listdir(data_folder)
    total = np.zeros(3)

    print(f"Compute mean (R, G, B) from {len(image_files)} images")
    for img_file in tqdm(image_files):
        path = os.path.join(data_folder, img_file)
        image = cv2.imread(path)
        total += image.mean(axis=(0, 1))

    channel_mean = total / len(image_files)
    mean = np.mean(channel_mean)

    # cv2.imread returns BGR, so reverse to report the offset in (R, G, B) order
    return mean - channel_mean[..., ::-1]
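
# Usage sketch (the folder path is hypothetical):
#   offset = compute_data_mean("dataset/train_photo")
#   # offset[i] is how far channel i sits below/above the global mean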

if __name__ == '__main__':
    t = torch.rand(2, 14, 32, 32)
    with torch.autocast("cpu"):
        print(gram(t))
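
    # A quick smoke test of the image helpers: a sketch that substitutes a
    # random array for a real photo, so no image files are needed.
    fake_image = np.random.randint(0, 256, (481, 640, 3), dtype=np.uint8)
    resized = resize_image(fake_image, width=256)
    print(resized.shape)   # (192, 256, 3): height scaled by 256/640, snapped to /32
    batch = preprocess_images(resized)
    print(batch.shape, batch.min().item(), batch.max().item())   # NCHW, ~[-1, 1]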