|
|
|
|
|
|
|
|
|
|
|
|
|
import time |
|
import datetime |
|
from typing import List |
|
import functools |
|
import os |
|
from PIL import Image |
|
from termcolor import colored |
|
import sys |
|
import logging |
|
from omegaconf import OmegaConf |
|
import json |
|
|
|
try:
    import torch
    from torch import Tensor
    from torch.utils.tensorboard import SummaryWriter
except ImportError as err:
    # A bare `except:` here would also swallow KeyboardInterrupt/SystemExit
    # and hide the real failure; catch only ImportError and chain the cause
    # so the original traceback is preserved.
    raise ImportError("Please install torch to use this module!") from err
|
|
|
""" |
|
NOTE: The `log` instance is a global variable, which should be imported by other modules as: |
|
`import optvq.utils.logger as logger` |
|
rather than |
|
`from optvq.utils.logger import log`. |
|
""" |
|
|
|
def setup_printer(file_log_dir: str, use_console: bool = True) -> logging.Logger:
    """Configure and return the shared "LOG" logger.

    Messages always go to ``<file_log_dir>/record.txt``; when ``use_console``
    is True they are also echoed (colorized) to stdout.

    Args:
        file_log_dir (str): directory that receives the ``record.txt`` file
            (must already exist).
        use_console (bool): whether to attach a stdout handler as well.

    Returns:
        logging.Logger: the fully configured logger instance.
    """
    printer = logging.getLogger("LOG")
    printer.setLevel(logging.DEBUG)
    printer.propagate = False

    # BUGFIX: logging.getLogger returns a process-wide singleton, so calling
    # setup_printer twice used to stack duplicate handlers and emit every
    # message multiple times. Drop any handlers from a previous call first.
    printer.handlers.clear()

    fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
    color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
        colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'

    if use_console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.DEBUG)
        console_handler.setFormatter(
            logging.Formatter(fmt=color_fmt, datefmt="%Y-%m-%d %H:%M:%S")
        )
        printer.addHandler(console_handler)

    # File output is unconditional; append mode keeps records across restarts.
    file_handler = logging.FileHandler(os.path.join(file_log_dir, "record.txt"), mode="a")
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(
        logging.Formatter(fmt=fmt, datefmt="%Y-%m-%d %H:%M:%S")
    )
    printer.addHandler(file_handler)

    return printer
|
|
|
@functools.lru_cache()
def config_loggers(log_dir: str, local_rank: int = 0, master_rank: int = 0):
    """Replace the module-level `log` with a real LogManager.

    Only the master rank becomes the main logger (owns tensorboard/file
    output); every other rank gets a silent manager. The lru_cache makes
    repeated calls with the same arguments a no-op.
    """
    global log
    log = LogManager(log_dir=log_dir, main_logger=(local_rank == master_rank))
|
|
|
class ProgressWithIndices:
    """Iteration progress tracker that estimates speed and remaining time.

    `update()` advances the counter and refreshes timing statistics;
    `print()` formats a header plus the meter string into rows and sends
    the result to the global `log`.
    """

    def __init__(self, total: int, sep_char: str = "| ",
                 num_per_row: int = 4):
        # layout configuration
        self.total = total
        self.sep_char = sep_char
        self.num_per_row = num_per_row

        # timing state
        self.count = 0
        self.start_time = time.time()
        self.past_time = None
        self.current_time = None
        self.eta = None
        self.speed = None
        self.used_time = 0

    def update(self):
        """Advance the counter by one and refresh eta/speed/used_time."""
        self.count += 1
        if self.count > self.total:
            # Past the end: zero the estimates and drop the timestamps.
            self.eta = 0
            self.speed = 0
            self.past_time = None
            self.current_time = None
            return

        self.past_time, self.current_time = self.current_time, time.time()

        if self.past_time is not None:
            # ETA extrapolates the duration of the most recent step.
            step_seconds = self.current_time - self.past_time
            remaining = (self.total - self.count) * step_seconds
            self.eta = str(datetime.timedelta(seconds=int(remaining)))
            self.speed = 1 / (step_seconds + 1e-8)

        elapsed = self.current_time - self.start_time
        self.used_time = str(datetime.timedelta(seconds=int(elapsed)))

    def print(self, prefix: str = "", content: str = "", ):
        """Log `prefix` + progress header, then `content` wrapped into rows."""
        global log
        prefix_str = f"{prefix}\t" + f"[{self.count}/{self.total} {self.used_time}/Eta:{self.eta}], Speed:{self.speed}iters/s\n"
        pieces = [piece.strip() for piece in content.split(self.sep_char)]
        rows = []
        for start in range(0, len(pieces), self.num_per_row):
            rows.append("\t\t" + self.sep_char.join(pieces[start:start + self.num_per_row]))
        log.info(prefix_str + "\n".join(rows))
|
|
|
class LogManager:
    """
    This class encapsulates the tensorboard writer, the statistic meters, the console printer, and the progress counters.

    Only the instance constructed with ``main_logger=True`` owns the
    tensorboard writer and the file/console printer; non-main instances keep
    meters and counters but perform no disk or console output.

    Args:
        log_dir (str): the parent directory to save all the logs
        init_meters (List[str]): the initial meters to be shown
        show_avg (bool): whether to show the average value of the meters
        main_logger (bool): whether this process is the master logger
    """
    def __init__(self, log_dir: str, init_meters: List[str] = None,
                 show_avg: bool = True, main_logger: bool = False):
        self.show_avg = show_avg
        self.log_dir = log_dir
        self.main_logger = main_logger
        self.setup_dirs()

        # BUGFIX: the default used to be a mutable `[]`, which is shared
        # across calls; `None` is the safe sentinel with identical behavior.
        self.meters = {meter: AverageMeter() for meter in (init_meters or [])}

        # global progress counters (persisted by state_dict/load_state_dict)
        self.total_steps = 0
        self.total_epochs = 0

        if self.main_logger:
            self.board = SummaryWriter(log_dir=self.tb_log_dir)
            self.printer = setup_printer(self.file_log_dir, use_console=True)

    def state_dict(self):
        """Serialize counters and all meters into a JSON-friendly dict."""
        return {
            "total_steps": self.total_steps,
            "total_epochs": self.total_epochs,
            "meters": {
                meter_name: meter.state_dict() for meter_name, meter in self.meters.items()
            }
        }

    def load_state_dict(self, state_dict: dict):
        """Restore counters and meters; missing meters are created lazily."""
        self.total_steps = state_dict["total_steps"]
        self.total_epochs = state_dict["total_epochs"]
        for meter_name, meter_state_dict in state_dict["meters"].items():
            if meter_name not in self.meters:
                self.meters[meter_name] = AverageMeter()
            self.meters[meter_name].load_state_dict(meter_state_dict)

    def setup_dirs(self):
        """
        The structure of the log directory:
        - log_dir: [tb_log, txt_log, img_log, model_log]

        Directories are only created on the main logger.
        """
        self.tb_log_dir = os.path.join(self.log_dir, "tb_log")
        self.file_log_dir = self.log_dir
        self.img_log_dir = os.path.join(self.log_dir, "img_log")

        self.config_path = os.path.join(self.log_dir, "config.yaml")
        self.checkpoint_path = os.path.join(self.log_dir, "checkpoint.pth")
        # BUGFIX: this used to be initialized to the same path as
        # `checkpoint_path`, so the first save_checkpoint() deleted the
        # checkpoint it had just written. `None` means "no backup yet".
        self.backup_checkpoint_path = None
        self.save_logger_path = os.path.join(self.log_dir, "logger.json")

        if self.main_logger:
            os.makedirs(self.tb_log_dir, exist_ok=True)
            os.makedirs(self.file_log_dir, exist_ok=True)
            os.makedirs(self.img_log_dir, exist_ok=True)

    def info(self, msg, *args, **kwargs):
        """Forward to the printer; silently ignored on non-main loggers."""
        if self.main_logger:
            self.printer.info(msg, *args, **kwargs)

    def show(self, include_key: str = ""):
        """Format all meters whose name contains any of `include_key`.

        Args:
            include_key (str or list): substring filter(s); "" matches all.
        Returns:
            str: "| "-joined "name: val" (or "name: val/avg") entries.
        """
        if isinstance(include_key, str):
            include_key = [include_key]
        if self.show_avg:
            fmt = lambda name, meter: f"{name}: {meter.val:.4f}/{meter.avg:.4f}"
        else:
            fmt = lambda name, meter: f"{name}: {meter.val:.4f}"
        return "| ".join([
            fmt(name, meter) for name, meter in self.meters.items()
            if any([k in name for k in include_key])
        ])

    def update_steps(self):
        """Increment and return the global step counter."""
        self.total_steps += 1
        return self.total_steps

    def update_epochs(self):
        """Increment and return the global epoch counter."""
        self.total_epochs += 1
        return self.total_epochs

    def add_histogram(self, tag: str, values: Tensor, global_step: int = None):
        """Write a histogram to tensorboard (main logger only)."""
        if self.main_logger:
            global_step = self.total_steps if global_step is None else global_step
            self.board.add_histogram(tag, values, global_step)

    def add_scalar(self, tag: str, scalar_value: float, global_step: int = None):
        """Update the meter for `tag` (creating it on first use) and, on the
        main logger, mirror the value to tensorboard."""
        if isinstance(scalar_value, Tensor):
            scalar_value = scalar_value.item()
        if tag not in self.meters:
            self.meters[tag] = AverageMeter()
            if self.main_logger:
                print(f"Create new meter: {tag}!")
        # Default the tensorboard step to this meter's own sample count.
        cur_step = self.meters[tag].update(scalar_value)
        cur_step = cur_step if global_step is None else global_step
        if self.main_logger:
            self.board.add_scalar(tag, scalar_value, cur_step)

    def add_scalar_dict(self, scalar_dict: dict, global_step: int = None):
        """Apply add_scalar to every (tag, value) pair in `scalar_dict`."""
        for tag, scalar_value in scalar_dict.items():
            self.add_scalar(tag, scalar_value, global_step)

    def add_images(self, tag: str, images: Tensor, global_step: int = None):
        """Write an NCHW image batch to tensorboard (main logger only)."""
        if self.main_logger:
            global_step = self.total_steps if global_step is None else global_step
            self.board.add_images(tag, images, global_step, dataformats="NCHW")

    def save_configs(self, config):
        """Persist the OmegaConf config and the logger state (main only)."""
        if self.main_logger:
            OmegaConf.save(config, self.config_path)
            self.info(f"Save config to {self.config_path}.")

            state_dict = self.state_dict()
            with open(self.save_logger_path, "w") as f:
                json.dump(state_dict, f)

    def load_configs(self):
        """Load the saved config and restore the logger state.

        Returns:
            the OmegaConf config object read from `config_path`.
        Raises:
            AssertionError: if either saved file is missing.
        """
        assert os.path.exists(self.config_path), f"Config {self.config_path} does not exist!"
        config = OmegaConf.load(self.config_path)

        assert os.path.exists(self.save_logger_path), f"Logger {self.save_logger_path} does not exist!"
        with open(self.save_logger_path, "r") as f:
            state_dict = json.load(f)
        self.load_state_dict(state_dict)

        return config

    def save_checkpoint(self, model, optimizers, schedulers, scalers, suffix: str = ""):
        """
        checkpoint_dict: model, optimizer, scheduler, scalers

        Saves to ``checkpoint_path + suffix`` and keeps one rolling
        epoch-tagged backup, removing the previous backup file.
        """
        if self.main_logger:
            checkpoint_dict = {
                "model": model.state_dict(),
                "epoch": self.total_epochs,
                "step": self.total_steps
            }
            checkpoint_dict.update({k: v.state_dict() for k, v in optimizers.items()})
            # None schedulers are skipped here; load_checkpoint mirrors this.
            checkpoint_dict.update({k: v.state_dict() for k, v in schedulers.items() if v is not None})
            checkpoint_dict.update({k: v.state_dict() for k, v in scalers.items()})

            checkpoint_path = self.checkpoint_path + suffix
            torch.save(checkpoint_dict, checkpoint_path)
            # BUGFIX: rotate the backup without ever deleting the file we
            # just wrote (the old code removed `checkpoint_path` itself on
            # the first save because the backup path aliased it).
            if (self.backup_checkpoint_path is not None
                    and self.backup_checkpoint_path != checkpoint_path
                    and os.path.exists(self.backup_checkpoint_path)):
                os.remove(self.backup_checkpoint_path)
            self.backup_checkpoint_path = checkpoint_path + f".epoch{self.total_epochs}"
            torch.save(checkpoint_dict, self.backup_checkpoint_path)

            self.info(f"### Epoch: {self.total_epochs}| Steps: {self.total_steps}| Save checkpoint to {checkpoint_path}.")

    def load_checkpoint(self, device, model, optimizers, schedulers, scalers, resume: str = None):
        """Restore model/optimizer/scheduler/scaler states and counters.

        Args:
            device: map_location passed to torch.load
            resume (str): explicit checkpoint path; defaults to self.checkpoint_path
        Returns:
            int: the restored epoch counter.
        """
        resume_path = self.checkpoint_path if resume is None else resume
        assert os.path.exists(resume_path), f"Resume {resume_path} does not exist!"

        checkpoint_dict = torch.load(resume_path, map_location=device)
        model.load_state_dict(checkpoint_dict["model"])
        self.total_epochs = checkpoint_dict["epoch"]
        self.total_steps = checkpoint_dict["step"]
        for k, v in optimizers.items():
            v.load_state_dict(checkpoint_dict[k])
        for k, v in schedulers.items():
            # BUGFIX: save_checkpoint omits None schedulers, so skip them
            # here too instead of crashing on None / a missing key.
            if v is not None and k in checkpoint_dict:
                v.load_state_dict(checkpoint_dict[k])
        for k, v in scalers.items():
            v.load_state_dict(checkpoint_dict[k])

        self.info(f"### Epoch: {self.total_epochs}| Steps: {self.total_steps}| Resume checkpoint from {resume_path}.")

        return self.total_epochs
|
|
|
class EmptyManager:
    """Placeholder that mirrors LogManager's public API with no-op stubs.

    Used as the module-level `log` before `config_loggers` installs a real
    LogManager; every public method just prints a warning.
    """
    def __init__(self):
        for func_name in LogManager.__dict__.keys():
            if not func_name.startswith("_"):
                # BUGFIX: bind the loop variable as a default argument.
                # The old lambda closed over `func_name` late, so every
                # stub reported the name of the LAST method in the class.
                setattr(
                    self, func_name,
                    lambda *args, _name=func_name, **kwargs:
                        print(f"Empty Manager! {_name} is not available!")
                )
|
|
|
class AverageMeter:
    """Track the latest value, running sum, sample count, and mean of a scalar."""

    def __init__(self):
        self.reset()

    def state_dict(self):
        """Export the meter fields as a plain dict."""
        return {
            "val": self.val,
            "avg": self.avg,
            "sum": self.sum,
            "count": self.count,
        }

    def load_state_dict(self, state_dict: dict):
        """Restore all meter fields from `state_dict`."""
        for field in ("val", "avg", "sum", "count"):
            setattr(self, field, state_dict[field])

    def reset(self):
        """Zero every statistic and return 0."""
        self.val = self.avg = self.sum = self.count = 0
        return 0

    def update(self, val: float, n: int = 1):
        """Record `val` with weight `n`; return the new sample count."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
        return self.count

    def __str__(self):
        return "{:.4f}".format(self.avg)
|
|
|
def save_image(x: Tensor, save_path: str, scale_to_256: bool = True):
    """Write a single CHW image tensor to disk as an image file.

    Args:
        x (tensor): default data range is [0, 1]
        save_path (str): destination file path
        scale_to_256 (bool): multiply by 255 and clamp before casting
    """
    if scale_to_256:
        x = x.mul(255).clamp(0, 255)
    # CHW -> HWC uint8 numpy array, detached and moved to CPU for PIL.
    pixels = x.permute(1, 2, 0).detach().cpu().numpy().astype("uint8")
    Image.fromarray(pixels).save(save_path)
|
|
|
def save_images(images_list, ids_list, meta_path):
    """Save each image tensor to ``<meta_path>/<id>.png`` via `save_image`.

    Args:
        images_list: sequence of CHW image tensors
        ids_list: parallel sequence of identifiers used as file stems
        meta_path: destination directory (must exist)
    """
    # The old loop used enumerate with an unused index and shadowed the
    # builtin `id`; plain zip with a descriptive name is clearer.
    for image, image_id in zip(images_list, ids_list):
        save_path = os.path.join(meta_path, f"{image_id}.png")
        save_image(image, save_path)
|
|
|
def save_images_multithread(images_list, ids_list, meta_path):
    """Save images in parallel batches of 32 using a thread pool.

    Args:
        images_list: sequence of CHW image tensors
        ids_list: parallel sequence of identifiers used as file stems
        meta_path: destination directory (must exist)
    """
    n_workers = 32
    from concurrent.futures import ThreadPoolExecutor
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        futures = [
            executor.submit(
                save_images,
                images_list[i:(i + n_workers)],
                ids_list[i:(i + n_workers)],
                meta_path,
            )
            for i in range(0, len(images_list), n_workers)
        ]
        # BUGFIX: the old code discarded the futures, so any exception in a
        # worker (bad path, bad tensor) was silently swallowed and images
        # were lost. Calling result() re-raises the first failure.
        for future in futures:
            future.result()
|
|
|
def add_prefix(log_dict: dict, prefix: str):
    """Return a copy of `log_dict` with every key renamed to '<prefix>/<key>'."""
    prefixed = {}
    for key, val in log_dict.items():
        prefixed[f"{prefix}/{key}"] = val
    return prefixed
|
|
|
|
|
# Module-wide logger handle. Starts as a no-op EmptyManager and is replaced
# with a real LogManager by `config_loggers`; other modules must therefore
# access it as `logger.log`, not via `from ... import log` (see NOTE above).
log = EmptyManager()

# Global switch for statistics collection, read from the ENABLE_STATS
# environment variable; defaults to enabled ("1").
GET_STATS: bool = (os.environ.get("ENABLE_STATS", "1") == "1")
|
|