# cartoonize/utils/common.py
# (Uploaded by YANGYYYY in commit 922e494, "Upload 8 files"; 4.87 kB.)
import torch
import gc
import os
import torch.nn as nn
import urllib.request
import cv2
from tqdm import tqdm
# Prefixes that mark an image "path" as remote (http(s) URL or image data URI),
# so it is fetched with urllib before decoding. Matching the full scheme instead
# of the bare "http" keeps local paths such as "http_cache/img.jpg" from being
# mistaken for URLs; "data:image/" accepts image data URIs of any subtype.
HTTP_PREFIXES = [
    'http://',
    'https://',
    'data:image/',
]
# Downloadable pretrained generator weights, keyed by "<style>" or
# "<style>:<version>". Each value is (version tag, release URL); a bare
# "<style>" key points at the same file as its ":v1" entry.
RELEASED_WEIGHTS = {
    "hayao:v2": (
        # Trained with Google Landmark (micro) images as the real-photo domain.
        "v2",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.1/GeneratorV2_gldv2_Hayao.pt"
    ),
    "hayao:v1": (
        "v1",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth"
    ),
    "hayao": (
        "v1",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_hayao.pth"
    ),
    "shinkai:v1": (
        "v1",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth"
    ),
    "shinkai": (
        "v1",
        "https://github.com/ptran1203/pytorch-animeGAN/releases/download/v1.0/generator_shinkai.pth"
    ),
}
def is_image_file(path):
    """Return True when *path* has a .png, .jpg or .jpeg extension (any case)."""
    extension = os.path.splitext(path)[1]
    return extension.lower() in {".png", ".jpg", ".jpeg"}
def read_image(path):
    """
    Read an image from a local path or a remote source and return it as RGB.

    Args:
        path: local file path (str or path-like), or a string starting with one
            of ``HTTP_PREFIXES`` (URL / data URI). Remote sources are downloaded
            with ``urllib.request.urlretrieve`` into a private temporary file,
            which is deleted after decoding.

    Returns:
        The array from ``cv2.imread`` with its last axis reversed. OpenCV
        decodes colour images as BGR, so the result is RGB.

    Raises:
        ValueError: if OpenCV cannot read or decode the image.
    """
    import tempfile

    path = os.fspath(path)
    if any(path.startswith(p) for p in HTTP_PREFIXES):
        # Use a unique temp file instead of a shared "temp.jpg" in the CWD, so
        # concurrent calls cannot clobber each other and nothing is left behind.
        fd, tmp_path = tempfile.mkstemp(suffix=".jpg")
        os.close(fd)
        try:
            urllib.request.urlretrieve(path, tmp_path)
            image = cv2.imread(tmp_path)
        finally:
            os.remove(tmp_path)
    else:
        image = cv2.imread(path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Could not read image from {path[:200]!r}")
    # Reverse the channel axis: BGR (OpenCV) -> RGB.
    return image[:, :, ::-1]
def save_checkpoint(model, path, optimizer=None, epoch=None):
    """
    Write a training checkpoint to *path* with ``torch.save``.

    The saved dict always contains ``'model_state_dict'`` and ``'epoch'``
    (None when not given). ``'optimizer_state_dict'`` is added only when an
    optimizer is passed.
    """
    entries = [
        ('model_state_dict', model.state_dict()),
        ('epoch', epoch),
    ]
    if optimizer is not None:
        entries.append(('optimizer_state_dict', optimizer.state_dict()))
    torch.save(dict(entries), path)
def maybe_remove_module(state_dict):
    """
    Return a plain-dict copy of *state_dict* with one leading ``'module.'``
    removed from every key that has it.

    Checkpoints saved from a wrapped model during DDP training carry this
    prefix; see
    https://discuss.pytorch.org/t/why-are-state-dict-keys-getting-prepended-with-the-string-module/104627/3
    """
    prefix = 'module.'
    cut = len(prefix)
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
def load_checkpoint(model, path, optimizer=None, strip_optimizer=False, map_location=None) -> int:
    """
    Restore model (and optionally optimizer) state from a checkpoint.

    Args:
        model: module that receives ``checkpoint['model_state_dict']`` with
            ``strict=True``, after any ``'module.'`` key prefix is removed.
        path: checkpoint file, or a released-weight key understood by
            ``load_state_dict``.
        optimizer: if given and the checkpoint contains
            ``'optimizer_state_dict'``, that state is loaded into it.
        strip_optimizer: if True and the checkpoint contains optimizer state,
            remove it and overwrite *path* with the smaller checkpoint.
        map_location: forwarded to ``load_state_dict`` (None = auto-select).

    Returns:
        The stored epoch, or 0 when the checkpoint has none. A missing key and
        ``epoch=None`` (``save_checkpoint``'s default) are both treated as 0,
        so the ``int`` return annotation holds.
    """
    state_dict = load_state_dict(path, map_location)
    model_state_dict = maybe_remove_module(state_dict['model_state_dict'])
    model.load_state_dict(model_state_dict, strict=True)
    if 'optimizer_state_dict' in state_dict:
        if optimizer is not None:
            optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        if strip_optimizer:
            # NOTE(review): if *path* is a released-weight key, this writes a
            # new file named after the key in the CWD, not the cached file.
            del state_dict["optimizer_state_dict"]
            torch.save(state_dict, path)
            print(f"Optimizer stripped and saved to {path}")
    epoch = state_dict.get('epoch')
    return 0 if epoch is None else epoch
def load_state_dict(weight, map_location=None) -> dict:
    """
    Load a checkpoint with ``torch.load``.

    Args:
        weight: path to a checkpoint file (str or path-like), or a key of
            ``RELEASED_WEIGHTS`` (matched case-insensitively). A key is first
            downloaded into the local ``.cache`` directory.
        map_location: device for the loaded tensors. None (the default) picks
            ``'cuda'`` when available, otherwise ``'cpu'``.

    Returns:
        The object deserialized by ``torch.load``; callers expect a dict.
    """
    weight = os.fspath(weight)
    if weight.lower() in RELEASED_WEIGHTS:
        weight = _download_weight(weight.lower())
    if map_location is None:
        # auto select
        map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): torch.load is pickle-based. Depending on the installed
    # PyTorch version's `weights_only` default, loading an untrusted file may
    # execute arbitrary code. Consider weights_only=True once supported.
    return torch.load(weight, map_location=map_location)
def initialize_weights(net):
    """
    Re-initialize, in place, the parameters of every sub-module of *net*.

    - Conv2d / ConvTranspose2d / Linear: Xavier-uniform weights, zero bias.
    - BatchNorm2d: scale set to 1, shift set to 0.
    - Any other module type is left untouched.

    Layers built with ``bias=False`` and BatchNorm with ``affine=False`` hold
    ``None`` for those parameters. They are skipped explicitly rather than
    through a blanket ``except Exception: pass``, which would also hide real
    errors.
    """
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                nn.init.ones_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
def set_lr(optimizer, lr):
    """Overwrite the ``'lr'`` entry of every parameter group of *optimizer* with *lr*."""
    for group in optimizer.param_groups:
        group['lr'] = lr
class DownloadProgressBar(tqdm):
    """
    ``tqdm`` bar whose ``update_to`` method can serve as the ``reporthook``
    of ``urllib.request.urlretrieve``.

    Adapted from https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads
    """

    def update_to(self, b=1, bsize=1, tsize=None):
        """
        Move the bar to ``b * bsize`` transferred units.

        Args:
            b: number of blocks transferred so far.
            bsize: size of each block.
            tsize: total size. When not None, it replaces the bar's total.
        """
        transferred = b * bsize
        if tsize is not None:
            self.total = tsize
        # tqdm.update() expects an increment, so subtract the position already
        # shown (self.n) from the absolute amount transferred.
        self.update(transferred - self.n)
def _download_weight(weight):
    '''
    Download a released weight file into ``.cache`` (at most once) and return
    its local path.

    Args:
        weight: key of ``RELEASED_WEIGHTS`` (e.g. ``"hayao:v2"``).

    Returns:
        ``.cache/<basename of the release URL>``. An existing file at that path
        is reused without downloading.
    '''
    cache_dir = '.cache'
    os.makedirs(cache_dir, exist_ok=True)
    url = RELEASED_WEIGHTS[weight][1]
    filename = os.path.basename(url)
    save_path = os.path.join(cache_dir, filename)
    if os.path.isfile(save_path):
        return save_path
    # Download under a temporary name and move it into place only on success.
    # Otherwise an interrupted or short download (urlretrieve leaves the file
    # behind on ContentTooShortError) would be accepted by the cache check
    # above on every later call.
    tmp_path = save_path + '.part'
    desc = f'Downloading {url} to {save_path}'
    try:
        with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=desc) as t:
            urllib.request.urlretrieve(url, tmp_path, reporthook=t.update_to)
        os.replace(tmp_path, save_path)
    finally:
        # Only present if the download or the rename failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return save_path