Upload 8 files
- utils/__init__.py +21 -0
- utils/__pycache__/__init__.cpython-39.pyc +0 -0
- utils/__pycache__/common.cpython-39.pyc +0 -0
- utils/__pycache__/image_processing.cpython-39.pyc +0 -0
- utils/common.py +165 -0
- utils/fast_numpyio.py +43 -0
- utils/image_processing.py +135 -0
- utils/logger.py +24 -0
utils/__init__.py
ADDED
@@ -0,0 +1,21 @@
+from .common import *
+from .image_processing import *
+
+class DefaultArgs:
+    dataset = 'Hayao'
+    data_dir = '/content'
+    epochs = 10
+    batch_size = 1
+    checkpoint_dir = '/content/checkpoints'
+    save_image_dir = '/content/images'
+    display_image = True
+    save_interval = 2
+    debug_samples = 0
+    lr_g = 0.001
+    lr_d = 0.002
+    wadvg = 300.0
+    wadvd = 300.0
+    wcon = 1.5
+    wgra = 3
+    wcol = 10
+    use_sn = False
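A minimal usage sketch (a hypothetical call site, not part of this commit): DefaultArgs is a plain namespace of training hyperparameters, so code written against an argparse-style namespace can consume it directly.

    from utils import DefaultArgs

    args = DefaultArgs()
    print(args.dataset, args.lr_g, args.wadvg)  # Hayao 0.001 300.0
    args.epochs = 20  # override a field for a quick experiment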
utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (739 Bytes).
utils/__pycache__/common.cpython-39.pyc
ADDED
Binary file (4.29 kB).
utils/__pycache__/image_processing.cpython-39.pyc
ADDED
Binary file (2.83 kB).
utils/common.py
ADDED
@@ -0,0 +1,165 @@
+import torch
+import gc
+import os
+import torch.nn as nn
+import urllib.request
+import cv2
+from tqdm import tqdm
+
+HTTP_PREFIXES = [
+    'http',
+    'data:image/jpeg',
+]
+
+
+RELEASED_WEIGHTS = {
+    "hayao:v2": (
+        # Trained with Google Landmarks (micro) as the real-photo dataset
+        "v2",
+        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.1/GeneratorV2_gldv2_Hayao.pt"
+    ),
+    "hayao:v1": (
+        "v1",
+        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth"
+    ),
+    "hayao": (
+        "v1",
+        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth"
+    ),
+    "shinkai:v1": (
+        "v1",
+        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth"
+    ),
+    "shinkai": (
+        "v1",
+        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth"
+    ),
+}
+
+
+def is_image_file(path):
+    _, ext = os.path.splitext(path)
+    return ext.lower() in (".png", ".jpg", ".jpeg")
+
+
+def read_image(path):
+    """
+    Read an image from a local path or URL and return it as RGB.
+    """
+    if any(path.startswith(p) for p in HTTP_PREFIXES):
+        urllib.request.urlretrieve(path, "temp.jpg")
+        path = "temp.jpg"
+
+    # cv2.imread returns BGR; reverse the channel axis to get RGB
+    return cv2.imread(path)[:, :, ::-1]
+
+
+def save_checkpoint(model, path, optimizer=None, epoch=None):
+    checkpoint = {
+        'model_state_dict': model.state_dict(),
+        'epoch': epoch,
+    }
+    if optimizer is not None:
+        checkpoint['optimizer_state_dict'] = optimizer.state_dict()
+
+    torch.save(checkpoint, path)
+
+
+def maybe_remove_module(state_dict):
+    # Remove the 'module.' prefix that DDP training prepends to state_dict keys
+    # https://discuss.pytorch.org/t/why-are-state-dict-keys-getting-prepended-with-the-string-module/104627/3
+    new_state_dict = {}
+    module_str = 'module.'
+    for k, v in state_dict.items():
+        if k.startswith(module_str):
+            k = k[len(module_str):]
+        new_state_dict[k] = v
+    return new_state_dict
+
+
+def load_checkpoint(model, path, optimizer=None, strip_optimizer=False, map_location=None) -> int:
+    state_dict = load_state_dict(path, map_location)
+    model_state_dict = maybe_remove_module(state_dict['model_state_dict'])
+    model.load_state_dict(
+        model_state_dict,
+        strict=True
+    )
+    if 'optimizer_state_dict' in state_dict:
+        if optimizer is not None:
+            optimizer.load_state_dict(state_dict['optimizer_state_dict'])
+        if strip_optimizer:
+            del state_dict["optimizer_state_dict"]
+            torch.save(state_dict, path)
+            print(f"Optimizer stripped and saved to {path}")
+
+    epoch = state_dict.get('epoch', 0)
+    return epoch
+
+
+def load_state_dict(weight, map_location) -> dict:
+    if weight.lower() in RELEASED_WEIGHTS:
+        weight = _download_weight(weight.lower())
+
+    if map_location is None:
+        # Auto-select the device
+        map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
+    state_dict = torch.load(weight, map_location=map_location)
+
+    return state_dict
+
+
+def initialize_weights(net):
+    for m in net.modules():
+        try:
+            if isinstance(m, nn.Conv2d):
+                # m.weight.data.normal_(0, 0.02)
+                torch.nn.init.xavier_uniform_(m.weight)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.ConvTranspose2d):
+                # m.weight.data.normal_(0, 0.02)
+                torch.nn.init.xavier_uniform_(m.weight)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                # m.weight.data.normal_(0, 0.02)
+                torch.nn.init.xavier_uniform_(m.weight)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+        except Exception as e:
+            # Skip layers without the expected parameters (e.g. bias=None)
+            # print(f'Skip layer {m}, {e}')
+            pass
+
+
+def set_lr(optimizer, lr):
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+
+
+class DownloadProgressBar(tqdm):
+    '''
+    https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
+    '''
+    def update_to(self, b=1, bsize=1, tsize=None):
+        if tsize is not None:
+            self.total = tsize
+        self.update(b * bsize - self.n)
+
+
+def _download_weight(weight):
+    '''
+    Download the weight file and cache it locally.
+    '''
+    os.makedirs('.cache', exist_ok=True)
+    url = RELEASED_WEIGHTS[weight][1]
+    filename = os.path.basename(url)
+    save_path = f'.cache/{filename}'
+
+    if os.path.isfile(save_path):
+        return save_path
+
+    desc = f'Downloading {url} to {save_path}'
+    with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=desc) as t:
+        urllib.request.urlretrieve(url, save_path, reporthook=t.update_to)
+
+    return save_path
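A hedged usage sketch (the generator model is not part of this commit, so `model` below is a hypothetical placeholder): load_checkpoint accepts either a local checkpoint path or a key from RELEASED_WEIGHTS, in which case the weight is downloaded to .cache/ on first use.

    from utils.common import load_checkpoint

    model = ...  # placeholder: an nn.Module whose state_dict keys match the checkpoint
    epoch = load_checkpoint(model, 'hayao:v2', map_location='cpu')
    print(f'Checkpoint was saved at epoch {epoch}')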
utils/fast_numpyio.py
ADDED
@@ -0,0 +1,43 @@
+# code from https://github.com/divideconcept/fastnumpyio/blob/main/fastnumpyio.py
+
+import sys
+import numpy as np
+import numpy.lib.format
+import struct
+
+def save(file, array):
+    # Hand-rolled .npy v1.0 writer: magic string + fixed 118-byte header block
+    magic_string = b"\x93NUMPY\x01\x00v\x00"
+    header = bytes(("{'descr': '" + array.dtype.descr[0][1] + "', 'fortran_order': False, 'shape': " + str(array.shape) + ", }").ljust(127 - len(magic_string)) + "\n", 'utf-8')
+    if type(file) == str:
+        file = open(file, "wb")
+    file.write(magic_string)
+    file.write(header)
+    file.write(array.data)
+
+def pack(array):
+    # Compact binary layout: dtype (byte order + kind + itemsize), ndim, shape, raw data
+    size = len(array.shape)
+    return bytes(array.dtype.byteorder.replace('=', '<' if sys.byteorder == 'little' else '>') + array.dtype.kind, 'utf-8') + array.dtype.itemsize.to_bytes(1, byteorder='little') + struct.pack(f'<B{size}I', size, *array.shape) + array.data
+
+def load(file):
+    if type(file) == str:
+        file = open(file, "rb")
+    header = file.read(128)
+    if not header:
+        return None
+    # Parse descr and shape out of the fixed-offset .npy header
+    descr = str(header[19:25], 'utf-8').replace("'", "").replace(" ", "")
+    shape = tuple(int(num) for num in str(header[60:120], 'utf-8').replace(', }', '').replace('(', '').replace(')', '').split(','))
+    datasize = numpy.lib.format.descr_to_dtype(descr).itemsize
+    for dimension in shape:
+        datasize *= dimension
+    return np.ndarray(shape, dtype=descr, buffer=file.read(datasize))
+
+def unpack(data):
+    # Inverse of pack(): bytes 0-2 encode the dtype, byte 3 is ndim
+    dtype = str(data[:2], 'utf-8')
+    dtype += str(data[2])
+    size = data[3]
+    shape = struct.unpack_from(f'<{size}I', data, 4)
+    datasize = data[2]
+    for dimension in shape:
+        datasize *= dimension
+    return np.ndarray(shape, dtype=dtype, buffer=data[4 + size * 4:4 + size * 4 + datasize])
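A round-trip sketch of the save/load pair (my own example, arbitrary file name; these helpers trade generality for speed and assume simple C-contiguous dtypes):

    import numpy as np
    from utils import fast_numpyio

    arr = np.random.rand(4, 8).astype(np.float32)
    fast_numpyio.save('example.npy', arr)        # hand-rolled .npy writer
    restored = fast_numpyio.load('example.npy')  # fixed-offset header parser
    assert np.array_equal(arr, restored)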
utils/image_processing.py
ADDED
@@ -0,0 +1,135 @@
+import torch
+import cv2
+import os
+import numpy as np
+from tqdm import tqdm
+
+
+def gram(input):
+    """
+    Calculate the Gram matrix
+
+    https://pytorch.org/tutorials/advanced/neural_style_tutorial.html#style-loss
+    """
+    b, c, w, h = input.size()
+
+    x = input.contiguous().view(b * c, w * h)
+
+    # x = x / 2
+
+    # Workaround: torch.mm can produce inf values in mixed precision.
+    # https://discuss.pytorch.org/t/gram-matrix-in-mixed-precision/166800/2
+    # x = torch.clamp(x, max=1.0e2, min=-1.0e2)
+    # x[x > 1.0e2] = 1.0e2
+    # x[x < -1.0e2] = -1.0e2
+
+    G = torch.mm(x, x.T)
+    G = torch.clamp(G, -64990.0, 64990.0)
+    # Normalize by the total number of elements
+    result = G.div(b * c * w * h)
+    return result
+
+
+def divisible(dim):
+    '''
+    Make width and height divisible by 32
+    '''
+    width, height = dim
+    return width - (width % 32), height - (height % 32)
+
+
+def resize_image(image, width=None, height=None, inter=cv2.INTER_AREA):
+    dim = None
+    h, w = image.shape[:2]
+
+    if width and height:
+        return cv2.resize(image, divisible((width, height)), interpolation=inter)
+
+    if width is None and height is None:
+        return cv2.resize(image, divisible((w, h)), interpolation=inter)
+
+    if width is None:
+        # Scale to the requested height, keeping the aspect ratio
+        r = height / float(h)
+        dim = (int(w * r), height)
+    else:
+        # Scale to the requested width, keeping the aspect ratio
+        r = width / float(w)
+        dim = (width, int(h * r))
+
+    return cv2.resize(image, divisible(dim), interpolation=inter)
+
+
+def normalize_input(images):
+    '''
+    [0, 255] -> [-1, 1]
+    '''
+    return images / 127.5 - 1.0
+
+
+def denormalize_input(images, dtype=None):
+    '''
+    [-1, 1] -> [0, 255]
+    '''
+    images = images * 127.5 + 127.5
+
+    if dtype is not None:
+        if isinstance(images, torch.Tensor):
+            images = images.type(dtype)
+        else:
+            # numpy.ndarray
+            images = images.astype(dtype)
+
+    return images
+
+
+def preprocess_images(images):
+    '''
+    Preprocess images for inference
+
+    @Arguments:
+        - images: np.ndarray
+
+    @Returns:
+        - images: torch.tensor
+    '''
+    images = images.astype(np.float32)
+
+    # Normalize to [-1, 1]
+    images = normalize_input(images)
+    images = torch.from_numpy(images)
+
+    # Add batch dim
+    if len(images.shape) == 3:
+        images = images.unsqueeze(0)
+
+    # Channel first (NHWC -> NCHW)
+    images = images.permute(0, 3, 1, 2)
+
+    return images
+
+
+def compute_data_mean(data_folder):
+    if not os.path.exists(data_folder):
+        raise FileNotFoundError(f'Folder {data_folder} does not exist')
+
+    image_files = os.listdir(data_folder)
+    total = np.zeros(3)
+
+    print(f"Compute mean (R, G, B) from {len(image_files)} images")
+
+    for img_file in tqdm(image_files):
+        path = os.path.join(data_folder, img_file)
+        image = cv2.imread(path)
+        total += image.mean(axis=(0, 1))
+
+    channel_mean = total / len(image_files)
+    mean = np.mean(channel_mean)
+
+    return mean - channel_mean[..., ::-1]  # Convert to BGR for training
+
+
+if __name__ == '__main__':
+    t = torch.rand(2, 14, 32, 32)
+
+    with torch.autocast("cpu"):
+        print(gram(t))
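A sketch of the intended inference flow (my own example; the generator is not part of this commit, so the model call is stubbed out):

    import numpy as np
    from utils.common import read_image
    from utils.image_processing import resize_image, preprocess_images, denormalize_input

    img = read_image('photo.jpg')        # RGB uint8, HWC
    img = resize_image(img, width=512)   # dimensions forced divisible by 32
    batch = preprocess_images(img)       # float tensor, NCHW, values in [-1, 1]
    # fake = generator(batch)            # hypothetical generator call
    fake = batch                         # identity stand-in for this sketch
    out = denormalize_input(fake)        # back to [0, 255]
    out = out[0].permute(1, 2, 0).numpy().astype(np.uint8)  # NCHW -> HWC image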
utils/logger.py
ADDED
@@ -0,0 +1,24 @@
+import logging
+
+
+def get_logger(path, *args, **kwargs):
+    # logger = logging.getLogger('train')
+    # logger.setLevel(logging.NOTSET)
+    # formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    # # add filehandler
+    # fh = logging.FileHandler(path)
+    # fh.setLevel(logging.NOTSET)
+    # fh.setFormatter(formatter)
+    # ch = logging.StreamHandler()
+    # ch.setLevel(logging.ERROR)
+    # logger.addHandler(fh)
+    # logger.addHandler(ch)
+    # return logger
+    # Configure the root logger to write to both a file and the console,
+    # then return the logging module itself.
+    logging.basicConfig(format='%(asctime)s %(message)s',
+                        datefmt='%m/%d/%Y %I:%M:%S %p',
+                        handlers=[
+                            logging.FileHandler(path),
+                            logging.StreamHandler()
+                        ],
+                        level=logging.DEBUG)
+    return logging
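Usage sketch (the log file name is arbitrary): since get_logger configures the root logger and returns the logging module itself, call sites can use the module-level helpers directly.

    from utils.logger import get_logger

    logger = get_logger('train.log')
    logger.info('epoch 1 started')  # written to train.log and to the console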