Spaces:
Runtime error
Runtime error
# -*- coding:utf-8 -*- | |
import os | |
import sys | |
import shutil | |
import logging | |
import colorlog | |
from tqdm import tqdm | |
import time | |
import yaml | |
import random | |
import importlib | |
from PIL import Image | |
from warnings import simplefilter | |
import imageio | |
import math | |
import collections | |
import json | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torch.optim import Adam | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader | |
from torch.utils.data import DataLoader, Dataset | |
from einops import rearrange, repeat | |
import torch.distributed as dist | |
from torchvision import datasets, transforms, utils | |
# Quiet noisy output process-wide: the root logger only reports WARNING and
# above, and FutureWarning messages are suppressed.
simplefilter('ignore', FutureWarning)
logging.root.setLevel(logging.WARNING)
class UniqueLogger:
    """Rank-aware logger wrapper: only rank 0 emits debug/info/warning.

    Under ``torch.distributed`` every process executes the same logging
    calls, so chatty messages would be printed once per process. This
    wrapper forwards ``debug``, ``info`` and ``warning`` only when
    ``local_rank == 0``. Every other attribute (``error``, ``exception``,
    ``critical``, ``setLevel``, ``handlers``, ...) is delegated unchanged to
    the wrapped logger, so errors raised on any rank are still reported.

    Args:
        logger: the ``logging.Logger`` to wrap.
        local_rank: rank of this process. If ``None`` it is read from
            ``torch.distributed.get_rank()``, which requires the default
            process group to be initialised.
    """

    def __init__(self, logger, local_rank=None):
        self.logger = logger
        if local_rank is None:
            local_rank = torch.distributed.get_rank()
        self.local_rank = local_rank

    def debug(self, msg, *args, **kwargs):
        if self.local_rank == 0:
            return self.logger.debug(msg, *args, **kwargs)

    def info(self, msg, *args, **kwargs):
        if self.local_rank == 0:
            return self.logger.info(msg, *args, **kwargs)

    def warning(self, msg, *args, **kwargs):
        if self.local_rank == 0:
            return self.logger.warning(msg, *args, **kwargs)

    def __getattr__(self, name):
        # Only reached for names not defined on the wrapper itself. Guard
        # 'logger' so a half-constructed instance (e.g. during copy/unpickle)
        # raises AttributeError instead of recursing forever.
        if name == 'logger':
            raise AttributeError(name)
        return getattr(self.logger, name)


def get_logger(filename=None):
    """Return the shared ``'utils'`` logger with colored console output.

    The logger is set to DEBUG and does not propagate to the root logger.
    A colorlog-formatted ``StreamHandler`` is attached, plus a plain
    ``FileHandler`` when ``filename`` is given. Handlers are attached at
    most once, so calling this function repeatedly does not duplicate
    output.

    When ``torch.distributed`` is initialised, the logger is wrapped in
    ``UniqueLogger`` so that only rank 0 prints debug/info/warning.

    Args:
        filename: optional path of a log file to also write to.

    Returns:
        A ``logging.Logger``, or a ``UniqueLogger`` wrapping it under
        distributed training.

    Example::

        logger = get_logger('try_logging.txt')
        logger.debug("Do something.")
        logger.info("Start print log.")
        logger.warning("Something maybe fail.")
        try:
            raise ValueError()
        except ValueError:
            logger.error("Error", exc_info=True)

    Tip: do not log big tensors; coloring does not help readability there.
    """
    logger = logging.getLogger('utils')
    level = logging.DEBUG
    logger.setLevel(level=level)
    # Do not propagate to the root logger, to avoid printing messages twice.
    logger.propagate = False
    # No %(levelname)s: colorlog conveys the level through the color.
    format_str = '[%(asctime)s <%(filename)s:%(lineno)d> %(funcName)s] %(message)s'

    # logging.getLogger('utils') returns the same object on every call, so
    # only attach handlers that are not there yet. FileHandler subclasses
    # StreamHandler, hence the exact type check for the console handler.
    if not any(type(h) is logging.StreamHandler for h in logger.handlers):
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(level)
        stream_handler.setFormatter(colorlog.ColoredFormatter(
            '%(log_color)s' + format_str,
            datefmt='%Y-%m-%d %H:%M:%S',
            reset=True,
            log_colors={
                'DEBUG': 'cyan',
                # INFO is intentionally left uncolored.
                'WARNING': 'yellow',
                'ERROR': 'red',
                'CRITICAL': 'red,bg_white',
            },
        ))
        logger.addHandler(stream_handler)

    if filename:
        abs_filename = os.path.abspath(filename)
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == abs_filename
            for h in logger.handlers
        )
        if not already_attached:
            file_handler = logging.FileHandler(filename)
            file_handler.setLevel(level)
            file_handler.setFormatter(logging.Formatter(format_str))
            logger.addHandler(file_handler)

    # Under torch.distributed every rank would print the same messages;
    # only wrap when a process group actually exists.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        logger = UniqueLogger(logger)
    return logger
# Module-wide logger used by the helpers in this file; created at import time.
logger = get_logger(filename=None)
def split_filename(filename):
    """Split ``filename`` into (absolute directory, root name, extension).

    The extension is whatever follows the last ``'.'`` of the base name,
    without the dot. It is ``None`` when the base name contains no ``'.'``.

    Examples (POSIX paths)::

        split_filename('/data/a.tar.gz')  -> ('/data', 'a.tar', 'gz')
        split_filename('/data/README')    -> ('/data', 'README', None)
        split_filename('/data/.bashrc')   -> ('/data', '', 'bashrc')
    """
    dirname, basename = os.path.split(os.path.abspath(filename))
    rootname, dot, extname = basename.rpartition('.')
    if not dot:
        # With no separator, rpartition puts the whole string in the last slot.
        return dirname, extname, None
    return dirname, rootname, extname
def data2file(data, filename, type=None, override=False, printable=False, **kwargs):
    """Save ``data`` to ``filename``, choosing the writer from the extension.

    Supported types (matched case-insensitively):
      * ``jpg`` / ``jpeg`` / ``png``: ``utils.save_image(data, filename, **kwargs)``.
      * ``gif``: ``imageio.mimsave`` with ``duration=kwargs.get('duration')``
        and ``loop=0``.
      * ``txt``: one ``str(item)`` per line for each item of ``data``,
        stopping after ``kwargs['max_step']`` items when that is given.

    The parent directory is created if missing. If ``filename`` already
    exists and ``override`` is False, nothing is written and a message is
    logged instead.

    Args:
        data: object to save; its expected form depends on the type.
        filename: destination path.
        type: optional extension override (e.g. ``'txt'``).
        override: overwrite ``filename`` if it already exists.
        printable: log the absolute path after a successful save.
        **kwargs: type-specific options (see above).

    Raises:
        ValueError: if the type is unsupported and the file would be written.
    """
    dirname, _, extname = split_filename(filename)
    extname = (type or extname or '').lower()
    if not os.path.exists(dirname):
        os.makedirs(dirname, exist_ok=True)
    if os.path.exists(filename) and not override:
        logger.info('Did not save data to %s because file exists and override is False',
                    os.path.abspath(filename))
        return

    if extname in ('jpg', 'png', 'jpeg'):
        utils.save_image(data, filename, **kwargs)
    elif extname == 'gif':
        imageio.mimsave(filename, data, format='GIF', duration=kwargs.get('duration'), loop=0)
    elif extname == 'txt':
        max_step = kwargs.get('max_step')
        if max_step is None:
            # np.Infinity no longer exists in NumPy 2.x; math.inf works everywhere.
            max_step = math.inf
        with open(filename, 'w', encoding='utf-8') as f:
            for i, e in enumerate(data):
                if i >= max_step:
                    break
                f.write(str(e) + '\n')
    else:
        raise ValueError('Do not support this type')

    if printable:
        logger.info('Saved data to %s', os.path.abspath(filename))
def file2data(filename, type=None, printable=True, **kwargs):
    """Load data from ``filename``, choosing the reader from the extension.

    Supported types (matched case-insensitively):
      * ``pth`` / ``ckpt``: ``torch.load`` with
        ``map_location=kwargs.get('map_location')``.
      * ``txt``: a list of lines. With ``top=N`` (N > 0) the first N raw
        lines are read with ``readline`` (trailing newline kept, ``''`` past
        EOF); otherwise the file is split on ``'\\n'`` and empty lines are
        dropped.
      * ``yaml`` / ``yml``: parsed with ``yaml.safe_load``.

    Args:
        filename: path of the file to read.
        type: optional extension override (e.g. ``'txt'``).
        printable: log the absolute path after loading.
        **kwargs: ``map_location`` for torch files, ``top`` for txt files.

    Returns:
        The loaded object.

    Raises:
        ValueError: if the type is not one of the above.
    """
    _, _, extname = split_filename(filename)
    extname = (type or extname or '').lower()

    if extname in ('pth', 'ckpt'):
        data = torch.load(filename, map_location=kwargs.get('map_location'))
    elif extname == 'txt':
        top = kwargs.get('top', None)
        with open(filename, encoding='utf-8') as f:
            if top:
                data = [f.readline() for _ in range(top)]
            else:
                data = [e for e in f.read().split('\n') if e]
    elif extname in ('yaml', 'yml'):
        # yaml.load() without an explicit Loader raises TypeError on
        # PyYAML >= 6 and may construct arbitrary Python objects on older
        # versions; safe_load only builds plain data types. Pass an explicit
        # Loader via yaml.load yourself if Python-specific tags are needed.
        with open(filename, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
    else:
        raise ValueError('type can only support pth, ckpt, txt, yaml, yml')

    if printable:
        logger.info('Loaded data from %s', os.path.abspath(filename))
    return data
def ensure_dirname(dirname, override=False):
    """Make sure ``dirname`` exists as a directory.

    Args:
        dirname: directory path to create if missing.
        override: if True and ``dirname`` already exists, it is removed with
            ``shutil.rmtree`` (including everything inside) and recreated
            empty.

    Raises:
        ValueError: if the existing path cannot be removed; the underlying
            ``OSError`` is chained as ``__cause__``.
    """
    if override and os.path.exists(dirname):
        logger.info('Removing dirname: %s', os.path.abspath(dirname))
        try:
            shutil.rmtree(dirname)
        except OSError as e:
            raise ValueError('Failed to delete %s because %s' % (dirname, e)) from e
    if not os.path.exists(dirname):
        logger.info('Making dirname: %s', os.path.abspath(dirname))
        os.makedirs(dirname, exist_ok=True)
def import_filename(filename, module_name="mymodule"):
    """Import a Python source file by path and return the module object.

    The module is registered in ``sys.modules`` under ``module_name`` before
    it is executed, and any previous entry with that name is replaced.

    Args:
        filename: path to a Python source file.
        module_name: key used in ``sys.modules``. Defaults to ``"mymodule"``
            for backward compatibility; pass distinct names to keep several
            imported files registered side by side.

    Returns:
        The executed module.

    Raises:
        ImportError: if no loader can be found for ``filename`` (e.g. an
            unsupported extension).
        Any exception raised while executing the file propagates, after the
        half-initialised module has been removed from ``sys.modules``.
    """
    # ``import importlib`` alone does not guarantee that the
    # ``importlib.util`` submodule is loaded, so import it explicitly.
    import importlib.util

    spec = importlib.util.spec_from_file_location(module_name, filename)
    if spec is None or spec.loader is None:
        raise ImportError('Cannot import %s: no loader found for this file' % filename)
    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        sys.modules.pop(spec.name, None)
        raise
    return module
def adaptively_load_state_dict(target, state_dict):
    """Load the compatible part of ``state_dict`` into ``target``.

    ``target`` must provide ``state_dict()`` and ``load_state_dict()``
    (e.g. an ``nn.Module`` or an optimizer). An entry of ``state_dict`` is
    used when its key exists in ``target.state_dict()``. If both values are
    tensors, their sizes must also match. Non-tensor values (such as an
    optimizer's ``state`` / ``param_groups``) are used on key match alone.
    Entries of ``target`` that receive nothing keep their current values.

    If both dicts carry ``param_groups`` and the first group's ``params``
    differ, the target's ``params`` are substituted so that loading
    succeeds. The caller's ``state_dict`` is never mutated.

    Unused keys (from ``state_dict``) and missing keys (of ``target``) are
    reported via ``logger.warning``.
    """
    target_dict = target.state_dict()

    def _is_compatible(key, value):
        # Decide per entry, so one non-tensor value does not disable the
        # size check for every other entry.
        if key not in target_dict:
            return False
        current = target_dict[key]
        if torch.is_tensor(value) and torch.is_tensor(current):
            return value.size() == current.size()
        return True

    common_dict = {k: v for k, v in state_dict.items() if _is_compatible(k, v)}

    loaded_groups = common_dict.get('param_groups')
    target_groups = target_dict.get('param_groups')
    if loaded_groups and target_groups and \
            loaded_groups[0]['params'] != target_groups[0]['params']:
        logger.warning('Detected mismatch params, auto adapte state_dict to current')
        # Copy the list and the first group rather than editing them in
        # place: they are the very objects held by the caller's state_dict.
        loaded_groups = list(loaded_groups)
        loaded_groups[0] = dict(loaded_groups[0], params=target_groups[0]['params'])
        common_dict['param_groups'] = loaded_groups

    target_dict.update(common_dict)
    target.load_state_dict(target_dict)

    missing_keys = [k for k in target_dict if k not in common_dict]
    unexpected_keys = [k for k in state_dict if k not in common_dict]
    if unexpected_keys:
        logger.warning(
            f"Some weights of state_dict were not used in target: {unexpected_keys}"
        )
    if missing_keys:
        logger.warning(
            f"Some weights of state_dict are missing used in target {missing_keys}"
        )
    if not unexpected_keys and not missing_keys:
        logger.warning("Strictly Loaded state_dict.")
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs for reproducible runs.

    Also exports ``PYTHONHASHSEED`` (this only affects subprocesses started
    afterwards; the running interpreter's hash seed is fixed at start-up),
    and configures cuDNN for deterministic algorithms with autotuning off.

    Args:
        seed: the seed value applied to every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN pick algorithms at runtime, which can
    # defeat determinism even with deterministic=True.
    torch.backends.cudnn.benchmark = False
def image2pil(filename):
    """Open the image at ``filename`` with PIL and return the ``Image``."""
    opened = Image.open(filename)
    return opened
def image2arr(filename):
    """Read an image file and return it as a NumPy array (via ``pil2arr``)."""
    return pil2arr(image2pil(filename))
# Format conversion helpers.
def pil2arr(pil):
    """Convert PIL image(s) to a NumPy array.

    Args:
        pil: a single PIL image, or a list of PIL images.

    Returns:
        For a single image: ``np.array(pil)`` with the image mode preserved.
        For a list: every image is converted to RGB, turned into a
        (height, width, 3) uint8 array, and the results are combined with
        ``np.array``. The images are expected to share one size so the
        result is (N, height, width, 3).
    """
    if isinstance(pil, list):
        # np.asarray reads the pixel buffer directly instead of iterating
        # getdata() pixel by pixel in Python; the resulting array is the same.
        return np.array([np.asarray(e.convert('RGB'), dtype=np.uint8) for e in pil])
    return np.array(pil)
def arr2pil(arr):
    """Convert a NumPy array into PIL RGB image(s).

    A 3-D array (expected (H, W, 3)) becomes one image. A 4-D array
    (expected (N, H, W, 3)) becomes a list of N images. Values are cast to
    uint8 before conversion.

    Raises:
        ValueError: for any other number of dimensions.
    """
    ndim = arr.ndim
    if ndim == 4:
        return [Image.fromarray(frame.astype('uint8'), 'RGB') for frame in arr]
    if ndim == 3:
        return Image.fromarray(arr.astype('uint8'), 'RGB')
    raise ValueError('arr must has ndim of 3 or 4, but got %s' % ndim)
def notebook_show(*images):
    """Display every argument inline in an IPython/Jupyter front end.

    Each positional argument is wrapped in ``IPython.display.Image`` and all
    of them are passed to a single ``display`` call.
    """
    # Imported lazily so IPython is only needed when this is called; the
    # alias keeps it distinct from PIL's ``Image`` used elsewhere in the file.
    from IPython.display import Image as IPythonImage, display
    display(*(IPythonImage(item) for item in images))