|
|
|
import argparse |
|
import os |
|
import random |
|
import time |
|
import logging |
|
import numpy as np |
|
from base import config |
|
|
|
|
|
def get_parser(argv=None):
    """Parse command-line arguments and return the loaded configuration.

    Reads ``--config`` (config file path) plus any trailing ``opts`` tokens,
    loads the config with ``config.load_cfg_from_cfg_file`` and, when override
    tokens were supplied, merges them with ``config.merge_cfg_from_list``.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, as before.

    Returns:
        The configuration object returned by the ``config`` helpers.
    """
    parser = argparse.ArgumentParser(description=' ')
    parser.add_argument('--config', type=str, default='**.yaml', help='config file')
    # REMAINDER collects every token left after the known options; presumably
    # these are KEY VALUE override pairs for merge_cfg_from_list -- confirm
    # against base.config.
    parser.add_argument('opts', help=' ', default=None,
                        nargs=argparse.REMAINDER)
    args = parser.parse_args(argv)

    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    if args.config is None:
        parser.error('--config is required')

    cfg = config.load_cfg_from_cfg_file(args.config)
    # argparse yields a list (empty when nothing trails) for a REMAINDER
    # positional, never None, so test for a non-empty list of overrides.
    if args.opts:
        cfg = config.merge_cfg_from_list(cfg, args.opts)
    return cfg
|
|
|
|
|
def get_logger():
    """Return the ``"main-logger"`` logger configured to emit INFO records.

    A single ``logging.StreamHandler`` is attached. Its format includes the
    timestamp, level, file name, line number and process id. The handler is
    added only when the logger has none yet, so repeated calls do not
    duplicate every log line.

    Returns:
        logging.Logger: the shared ``"main-logger"`` instance.
    """
    logger_name = "main-logger"
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)
    # logging.getLogger returns the same object on every call; without this
    # guard each call would stack one more handler and each message would be
    # printed once per previous call.
    if not logger.handlers:
        handler = logging.StreamHandler()
        fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d]=>%(message)s"
        handler.setFormatter(logging.Formatter(fmt))
        logger.addHandler(handler)
    return logger
|
|
|
|
|
class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes:
        val: the most recently recorded value.
        sum: weighted running total (each value multiplied by its weight ``n``).
        count: total weight recorded so far.
        avg: ``sum / count``; stays at its previous value while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: The new observation; stored as ``self.val``.
            n: Weight applied to ``val`` in the running sum. Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid ZeroDivisionError when only zero-weight updates have been
        # recorded (e.g. a first call with n=0); avg keeps its prior value.
        if self.count != 0:
            self.avg = self.sum / self.count
|
|
|
|
|
def check_mkdir(dir_name):
    """Create directory ``dir_name`` (a single level) unless the path exists.

    The parent directory must already exist; otherwise ``FileNotFoundError``
    from ``os.mkdir`` propagates. An existing path of any kind is left alone.
    """
    # Attempt the mkdir and tolerate "already exists" instead of checking
    # os.path.exists() first: with a separate check, another process could
    # create the directory in between and make os.mkdir raise.
    try:
        os.mkdir(dir_name)
    except FileExistsError:
        pass
|
|
|
|
|
def check_makedirs(dir_name):
    """Create ``dir_name`` and any missing parents unless the path exists.

    An existing path of any kind (directory or file) is left alone.
    """
    # Attempt the creation and tolerate "already exists" instead of checking
    # os.path.exists() first, which races with concurrent creators.
    # exist_ok=True is not used because it still raises when the path is an
    # existing file, whereas the original check silently skipped that case.
    try:
        os.makedirs(dir_name)
    except FileExistsError:
        pass
|
|
|
|
|
def main_process(args):
    """Tell whether this process is the designated "main" process.

    When ``args.multiprocessing_distributed`` is falsy, every process is
    main. Otherwise a process is main only if ``args.rank`` is a multiple of
    ``args.ngpus_per_node``. Presumably that is the first process on each
    node; confirm against how ``rank`` is assigned.
    """
    if not args.multiprocessing_distributed:
        return True
    return args.rank % args.ngpus_per_node == 0
|
|