import torch
import gc
import os
import torch.nn as nn
import urllib.request
import cv2
from tqdm import tqdm
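# Prefixes that mark `path` as a URL (http(s) or a data URI) rather than a local file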
HTTP_PREFIXES = [
'http',
'data:image/jpeg',
]
RELEASED_WEIGHTS = {
"hayao:v2": (
        # v2 generator, trained with the Google Landmarks v2 (micro) dataset as the real-photo domain
"v2",
"https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.1/GeneratorV2_gldv2_Hayao.pt"
),
"hayao:v1": (
"v1",
"https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth"
),
"hayao": (
"v1",
"https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth"
),
"shinkai:v1": (
"v1",
"https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth"
),
"shinkai": (
"v1",
"https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth"
),
}
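# Each entry maps a released weight name (optionally suffixed with ":v1"/":v2")
# to an (architecture version, download URL) pair, resolved by _download_weight.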
def is_image_file(path):
_, ext = os.path.splitext(path)
return ext.lower() in (".png", ".jpg", ".jpeg")
def read_image(path):
    """
    Read an image from a local path or a URL and return it in RGB order.
    Remote images are first downloaded to a temporary local file.
    """
    if any(path.startswith(p) for p in HTTP_PREFIXES):
        urllib.request.urlretrieve(path, "temp.jpg")
        path = "temp.jpg"
    # cv2.imread returns BGR; reverse the channel axis to get RGB.
    return cv2.imread(path)[:, :, ::-1]
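# Usage sketch (the URL and file name below are illustrative, not from this repo):
#   img = read_image("https://example.com/photo.jpg")  # remote: downloaded first
#   img = read_image("photos/house.jpg")               # local file, same call
# Both return an H x W x 3 numpy array in RGB channel order.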
def save_checkpoint(model, path, optimizer=None, epoch=None):
    """Save model weights (and, if given, optimizer state and epoch) to path."""
    checkpoint = {
'model_state_dict': model.state_dict(),
'epoch': epoch,
}
if optimizer is not None:
checkpoint['optimizer_state_dict'] = optimizer.state_dict()
torch.save(checkpoint, path)
def maybe_remove_module(state_dict):
    # Strip the 'module.' prefix that DDP training prepends to state_dict keys,
    # e.g. {'module.conv.weight': w} -> {'conv.weight': w}.
    # https://discuss.pytorch.org/t/why-are-state-dict-keys-getting-prepended-with-the-string-module/104627/3
new_state_dict = {}
module_str = 'module.'
for k, v in state_dict.items():
if k.startswith(module_str):
k = k[len(module_str):]
new_state_dict[k] = v
return new_state_dict
def load_checkpoint(model, path, optimizer=None, strip_optimizer=False, map_location=None) -> int:
    """Load model (and optionally optimizer) state from path; return the saved epoch."""
    state_dict = load_state_dict(path, map_location)
model_state_dict = maybe_remove_module(state_dict['model_state_dict'])
model.load_state_dict(
model_state_dict,
strict=True
)
if 'optimizer_state_dict' in state_dict:
if optimizer is not None:
optimizer.load_state_dict(state_dict['optimizer_state_dict'])
if strip_optimizer:
del state_dict["optimizer_state_dict"]
torch.save(state_dict, path)
print(f"Optimizer stripped and saved to {path}")
epoch = state_dict.get('epoch', 0)
return epoch
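# Round-trip sketch (model/opt are illustrative names for an nn.Module and its
# optimizer): save at epoch 10, then resume from the same file.
#   save_checkpoint(model, "ckpt.pth", optimizer=opt, epoch=10)
#   start_epoch = load_checkpoint(model, "ckpt.pth", optimizer=opt)  # -> 10
# strip_optimizer=True rewrites the file without optimizer state, shrinking it
# for inference-only use.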
def load_state_dict(weight, map_location) -> dict:
    """Load a checkpoint dict from a local path or a released weight name."""
    if weight.lower() in RELEASED_WEIGHTS:
weight = _download_weight(weight.lower())
if map_location is None:
# auto select
map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
state_dict = torch.load(weight, map_location=map_location)
return state_dict
def initialize_weights(net):
    """Xavier-initialize conv/linear layers; reset BatchNorm to identity scale."""
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
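# Usage sketch ("Generator" stands in for any nn.Module defined elsewhere):
#   net = Generator()
#   initialize_weights(net)  # Xavier weights, zero biases, unit BatchNorm scale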
def set_lr(optimizer, lr):
    """Set the learning rate of every param group in the optimizer."""
for param_group in optimizer.param_groups:
param_group['lr'] = lr
class DownloadProgressBar(tqdm):
    '''
    tqdm progress bar adapted to urllib's reporthook interface.
    https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
    '''
def update_to(self, b=1, bsize=1, tsize=None):
if tsize is not None:
self.total = tsize
self.update(b * bsize - self.n)
def _download_weight(weight):
    '''
    Download a released weight file into the local .cache directory,
    reusing a previously downloaded copy when one exists.
    '''
os.makedirs('.cache', exist_ok=True)
url = RELEASED_WEIGHTS[weight][1]
filename = os.path.basename(url)
save_path = f'.cache/{filename}'
if os.path.isfile(save_path):
return save_path
desc = f'Downloading {url} to {save_path}'
with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=desc) as t:
urllib.request.urlretrieve(url, save_path, reporthook=t.update_to)
return save_path
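if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the released entry points.
    # It assumes the release URL above is reachable and that the file is a
    # dict-style checkpoint as expected by load_checkpoint.
    sd = load_state_dict("hayao:v2", map_location="cpu")
    print(list(sd.keys()))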