|
import os |
|
import re |
|
import sys |
|
import time |
|
import json |
|
import torch |
|
import pickle |
|
import random |
|
import getpass |
|
import logging |
|
import argparse |
|
import subprocess |
|
import numpy as np |
|
from datetime import timedelta, date |
|
from .utils import get_code_version |
|
|
|
|
|
class LogFormatter():
    """
    Formatter that prefixes every message with the level name, the current
    wall-clock time, and the time elapsed since `start_time`.
    """

    def __init__(self):
        # Reference instant for the "elapsed" column of each log line.
        self.start_time = time.time()

    def format(self, record):
        """
        Render `record` as "LEVEL - date time - elapsed - message".

        Continuation lines of a multi-line message are indented so they line
        up under the text of the first line. An empty message yields ''.
        """
        elapsed = timedelta(seconds=round(record.created - self.start_time))
        header = "%s - %s - %s" % (record.levelname, time.strftime('%x %X'), elapsed)

        text = record.getMessage()
        if not text:
            return ''

        # 3 == len(" - "), the separator placed between header and message.
        pad = ' ' * (len(header) + 3)
        return "%s - %s" % (header, text.replace('\n', '\n' + pad))
|
|
|
|
|
def create_logger(filepath, rank):
    """
    Configure and return the root logger.

    A different log file is used for each process: when `rank` is greater
    than 0 the file is "<filepath>-<rank>". If `filepath` is None, only the
    console handler is installed. The returned logger carries a
    `reset_time()` attribute that restarts the elapsed-time reference used by
    the formatter.
    """
    formatter = LogFormatter()
    handlers = []

    # File handler: DEBUG and above, appended to the per-rank log file.
    if filepath is not None:
        if rank > 0:
            filepath = '%s-%i' % (filepath, rank)
        to_file = logging.FileHandler(filepath, "a", encoding='utf-8')
        to_file.setLevel(logging.DEBUG)
        to_file.setFormatter(formatter)
        handlers.append(to_file)

    # Console handler: INFO and above.
    to_console = logging.StreamHandler()
    to_console.setLevel(logging.INFO)
    to_console.setFormatter(formatter)
    handlers.append(to_console)

    # Reset the root logger and attach the handlers (file first, then console).
    root = logging.getLogger()
    root.handlers = []
    root.setLevel(logging.DEBUG)
    root.propagate = False
    for handler in handlers:
        root.addHandler(handler)

    # Expose a hook to restart the elapsed-time counter.
    def reset_time():
        formatter.start_time = time.time()
    root.reset_time = reset_time

    return root
|
|
|
|
|
def initialize_exp(params):
    """
    Initialize the experiment:
    - dump parameters
    - create a logger

    `params` is a namespace-like object (it is passed to `vars()`); it must
    provide `exp_name` and `exp_id`, plus whatever `get_dump_path` reads.
    `params.command` is set to the reconstructed command line, and
    `params.global_rank` (default 0) selects the per-process log file.

    Returns the logger created for the experiment.
    Raises AssertionError if the experiment name is blank or if a command
    line argument contains a quote that cannot be reproduced safely.
    """
    # Validate the name before anything is written to disk.
    assert len(params.exp_name.strip()) > 0

    # Dump parameters. The context manager guarantees the file is flushed
    # and closed even if serialization fails.
    # NOTE: the file is named 'params.pkl' but its content is JSON; the name
    # is kept as-is because other code may look for it.
    exp_folder = get_dump_path(params)
    with open(os.path.join(exp_folder, 'params.pkl'), 'w') as f:
        json.dump(vars(params), f, indent=4)

    # Rebuild the running command: flags are kept verbatim, plain
    # alphanumeric values are left bare, anything else is single-quoted.
    command = ["python", sys.argv[0]]
    for x in sys.argv[1:]:
        if x.startswith('--'):
            assert '"' not in x and "'" not in x
            command.append(x)
        else:
            assert "'" not in x
            if re.match('^[a-zA-Z0-9_]+$', x):
                command.append(x)
            else:
                command.append("'%s'" % x)
    command = ' '.join(command)
    params.command = command + ' --exp_id "%s"' % params.exp_id

    # Create a logger (one log file per process rank).
    logger = create_logger(os.path.join(exp_folder, 'train.log'), rank=getattr(params, 'global_rank', 0))
    logger.info("============ Initialized logger ============")
    logger.info("The experiment will be stored in %s\n", exp_folder)
    logger.info("Running command: %s", command)
    logger.info("")
    return logger
|
|
|
|
|
def get_dump_path(params):
    """
    Create a directory to store the experiment.

    The layout is "<dump_path>/<MMDD>-<exp_name>/<exp_id>", where MMDD is
    today's date. If `params.exp_id` is '' a random 10-character id (lowercase
    letters and digits) that does not already name a directory is generated
    and stored back into `params.exp_id`.

    Returns the path of the experiment folder, which exists on return.
    Raises AssertionError if `exp_name` is empty or `dump_path` is '' / None,
    and OSError if a directory cannot be created.
    """
    assert len(params.exp_name) > 0
    assert params.dump_path not in ('', None), \
        'Please choose your favorite destination for dump.'
    dump_path = params.dump_path

    # Create the sweep path if it does not exist. os.makedirs is used instead
    # of a shell "mkdir -p" so paths with spaces or shell metacharacters work,
    # failures are reported, and concurrent creation by several processes is
    # tolerated (exist_ok=True).
    when = date.today().strftime('%m%d-')
    sweep_path = os.path.join(dump_path, when + params.exp_name)
    os.makedirs(sweep_path, exist_ok=True)

    # Pick a random experiment id that is not already taken.
    if params.exp_id == '':
        chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
        while True:
            exp_id = ''.join(random.choice(chars) for _ in range(10))
            if not os.path.isdir(os.path.join(sweep_path, exp_id)):
                break
        params.exp_id = exp_id

    # Create the experiment folder itself.
    exp_folder = os.path.join(sweep_path, params.exp_id)
    os.makedirs(exp_folder, exist_ok=True)
    return exp_folder
|
|
|
|
|
# Script entry point: intentionally a no-op, nothing runs when this module is
# executed directly.
if __name__ == '__main__':
    pass
|
|