Spaces:
Running
Running
File size: 3,083 Bytes
922e494 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import torch
import cv2
import os
import numpy as np
from tqdm import tqdm
def gram(input):
    """
    Compute the normalized Gram matrix of a 4-D feature tensor.
    https://pytorch.org/tutorials/advanced/neural_style_tutorial.html#style-loss

    The tensor is flattened to (dim0 * dim1, dim2 * dim3), multiplied by its
    own transpose, clamped to [-64990, 64990] and divided by the total number
    of elements of the input.

    @Arguments:
        - input: torch.Tensor with exactly four dimensions
          (presumably batch, channels and two spatial dims -- TODO confirm)

    @Returns:
        - torch.Tensor of shape (dim0 * dim1, dim0 * dim1)
    """
    n_batch, n_chan, dim_a, dim_b = input.size()
    rows = n_batch * n_chan
    cols = dim_a * dim_b
    features = input.contiguous().view(rows, cols)

    # Work around: torch.mm can produce inf values in mixed precision, see
    # https://discuss.pytorch.org/t/gram-matrix-in-mixed-precision/166800/2
    # so the product is clamped (the bound sits just under the float16 max).
    products = torch.mm(features, features.t())
    products = products.clamp(min=-64990.0, max=64990.0)

    # Normalize by the total number of elements in the input.
    return products / (rows * cols)
def divisible(dim):
    """
    Round a (width, height) pair down to multiples of 32.

    Each side is reduced to the nearest multiple of 32 that does not exceed
    it. A side smaller than 32 would otherwise round down to 0, which is not
    a usable image dimension, so the result is clamped to a minimum of 32.

    @Arguments:
        - dim: (width, height) pair

    @Returns:
        - (width, height) tuple, each a multiple of 32 and at least 32
    """
    width, height = dim
    # max(32, ...) guards against sides < 32 collapsing to a zero dimension.
    return max(32, width - (width % 32)), max(32, height - (height % 32))
def resize_image(image, width=None, height=None, inter=cv2.INTER_AREA):
    """
    Resize an image to a size accepted by `divisible`.

    Target size selection:
        - both width and height given (truthy): use them as-is
        - neither given: keep the image's current size
        - only height given: scale width by the same ratio as height
        - otherwise: scale height by the same ratio as width

    The chosen (width, height) is passed through `divisible` before the
    actual resize.

    @Arguments:
        - image: array whose first two shape entries are (height, width)
        - width: target width, or None
        - height: target height, or None
        - inter: interpolation flag forwarded to cv2.resize

    @Returns:
        - the resized image as returned by cv2.resize
    """
    src_h, src_w = image.shape[:2]

    if width and height:
        target = (width, height)
    elif width is None and height is None:
        target = (src_w, src_h)
    elif width is None:
        # Only the height is fixed: derive the width from the aspect ratio.
        scale = height / float(src_h)
        target = (int(src_w * scale), height)
    else:
        # Width is fixed: derive the height from the aspect ratio.
        scale = width / float(src_w)
        target = (width, int(src_h * scale))

    return cv2.resize(image, divisible(target), interpolation=inter)
def normalize_input(images):
    """
    Map pixel values from [0, 255] to [-1, 1].

    @Arguments:
        - images: number, np.ndarray or torch.Tensor of pixel values

    @Returns:
        - the input divided by 127.5, then shifted down by 1.0
    """
    scaled = images / 127.5
    return scaled - 1.0
def denormalize_input(images, dtype=None):
    """
    Map values from [-1, 1] back to [0, 255].

    @Arguments:
        - images: torch.Tensor or np.ndarray of normalized values
        - dtype: optional target dtype; a torch dtype for tensors, a numpy
          dtype otherwise. When None, no cast is performed.

    @Returns:
        - the rescaled values, cast to `dtype` when one is given
    """
    restored = images * 127.5 + 127.5

    if dtype is None:
        return restored

    if isinstance(restored, torch.Tensor):
        return restored.type(dtype)

    # Anything else is treated as a numpy.ndarray.
    return restored.astype(dtype)
def preprocess_images(images):
    """
    Turn raw image array(s) into a normalized, channel-first tensor.

    @Arguments:
        - images: np.ndarray holding one image (3 dims) or a batch (4 dims);
          assumed channel-last with values in [0, 255] -- TODO confirm

    @Returns:
        - images: torch.Tensor, float32, with a leading batch dimension and
          the last input axis moved to position 1
    """
    as_float = images.astype(np.float32)

    # Scale to [-1, 1] before handing the data over to torch.
    tensor = torch.from_numpy(normalize_input(as_float))

    # A single image gets a batch dimension of size 1.
    if tensor.dim() == 3:
        tensor = tensor.unsqueeze(0)

    # Channel-last -> channel-first.
    return tensor.permute(0, 3, 1, 2)
def compute_data_mean(data_folder):
    """
    Compute per-channel mean offsets over every image in a folder.

    Each readable image contributes its per-channel mean; the per-channel
    means are averaged over the images that were actually read, and the
    result is `overall_mean - channel_mean` with the channel axis reversed.

    Entries that cv2.imread cannot decode (non-image files, sub-folders, ...)
    are skipped instead of crashing, and they do not count towards the
    average.

    @Arguments:
        - data_folder: path of the folder containing the images

    @Returns:
        - np.ndarray of 3 values, one offset per channel

    @Raises:
        - FileNotFoundError: data_folder does not exist
        - ValueError: data_folder holds no readable image
    """
    if not os.path.exists(data_folder):
        raise FileNotFoundError(f'Folder {data_folder} does not exits')

    image_files = os.listdir(data_folder)
    if not image_files:
        # Dividing by zero images would silently yield NaN offsets.
        raise ValueError(f'Folder {data_folder} contains no files')

    total = np.zeros(3)
    num_read = 0
    print(f"Compute mean (R, G, B) from {len(image_files)} images")

    for img_file in tqdm(image_files):
        path = os.path.join(data_folder, img_file)
        image = cv2.imread(path)
        if image is None:
            # cv2.imread returns None when the file cannot be read as an image.
            continue
        total += image.mean(axis=(0, 1))
        num_read += 1

    if num_read == 0:
        raise ValueError(f'Folder {data_folder} contains no readable images')

    skipped = len(image_files) - num_read
    if skipped:
        print(f"Skipped {skipped} unreadable files")

    # Average over the images actually read, not over all directory entries.
    channel_mean = total / num_read
    mean = np.mean(channel_mean)

    # NOTE(review): cv2.imread yields channels in BGR order, so this reversal
    # produces RGB order; the original comment said "Convert to BGR for
    # training" -- confirm which order the training code expects.
    return mean - channel_mean[..., ::-1]
if __name__ == '__main__':
    # Smoke check: print the Gram matrix of a random 4-D tensor computed
    # under CPU autocast.
    sample = torch.rand(2, 14, 32, 32)
    with torch.autocast("cpu"):
        print(gram(sample))
|