|
import os
|
|
import sys
|
|
import ujson
|
|
import mlflow
|
|
import traceback
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from colbert.utils.utils import print_message, create_directory
|
|
|
|
|
|
class Logger:
    """Rank-aware run logger that writes to MLflow, TensorBoard and plain files.

    Only the main process (``rank`` -1 or 0) configures MLflow, creates the
    run's ``logs/`` directory, writes TensorBoard scalars and records
    artifacts/params. ``info_all`` and ``warn`` execute on every rank.

    ``run`` must expose the attributes ``path``, ``experiments_root``,
    ``experiment``, ``script`` and ``name`` (all read by the methods below).
    """

    def __init__(self, rank, run):
        self.rank = rank
        # NOTE(review): -1 presumably means "not running distributed"; both -1
        # and 0 are treated as the main process.
        self.is_main = self.rank in (-1, 0)
        self.run = run
        self.logs_path = os.path.join(self.run.path, "logs/")

        # Defined on every rank so the attribute set is consistent; the
        # TensorBoard writer itself is created lazily by the first
        # log_metric() call on the main process.
        self.initialized_tensorboard = False
        self.writer = None

        if self.is_main:
            self._init_mlflow()
            create_directory(self.logs_path)

    def _init_mlflow(self):
        """Point MLflow at ``<experiments_root>/logs/mlruns/`` and tag the run.

        The MLflow experiment is named ``"<experiment>/<script>"``; the run
        is tagged with its experiment, name and path.
        """
        mlflow.set_tracking_uri('file://' + os.path.join(self.run.experiments_root, "logs/mlruns/"))
        mlflow.set_experiment('/'.join([self.run.experiment, self.run.script]))

        mlflow.set_tag('experiment', self.run.experiment)
        mlflow.set_tag('name', self.run.name)
        mlflow.set_tag('path', self.run.path)

    def _init_tensorboard(self):
        """Create the TensorBoard writer for this run.

        Logs go to ``<experiments_root>/logs/tensorboard/<experiment>__<script>__<name>``.
        """
        root = os.path.join(self.run.experiments_root, "logs/tensorboard/")
        logdir = '__'.join([self.run.experiment, self.run.script, self.run.name])
        logdir = os.path.join(root, logdir)

        self.writer = SummaryWriter(log_dir=logdir)
        self.initialized_tensorboard = True

    def _log_exception(self, etype, value, tb):
        """Print a formatted traceback and save it as the ``exception.txt`` artifact.

        The parameters mirror ``sys.excepthook(type, value, traceback)``
        (presumably installed as one by the caller — confirm). This is a no-op
        on non-main ranks.
        """
        if not self.is_main:
            return

        output_path = os.path.join(self.logs_path, 'exception.txt')
        trace = ''.join(traceback.format_exception(etype, value, tb)) + '\n'
        print_message(trace, '\n\n')

        self.log_new_artifact(output_path, trace)

    def _log_all_artifacts(self):
        """Upload every file under ``logs_path`` to MLflow (main process only)."""
        if not self.is_main:
            return

        mlflow.log_artifacts(self.logs_path)

    def _log_args(self, args):
        """Record ``args`` as MLflow params and as files under ``logs_path``.

        This only runs on the main process. Attributes of ``args`` whose exact
        type is int, float, str or bool become MLflow params; subclasses of
        those types are skipped. ``args.input_arguments.__dict__`` is written
        to ``args.json``, and the command line (``sys.argv``) to ``args.txt``.
        """
        if not self.is_main:
            return

        for key in vars(args):
            value = getattr(args, key)
            # Exact type match on purpose: only plain scalars become params.
            if type(value) in (int, float, str, bool):
                mlflow.log_param(key, value)

        json_path = os.path.join(self.logs_path, 'args.json')
        with open(json_path, 'w', encoding='utf-8') as output_metadata:
            ujson.dump(args.input_arguments.__dict__, output_metadata, indent=4)
            output_metadata.write('\n')

        argv_path = os.path.join(self.logs_path, 'args.txt')
        with open(argv_path, 'w', encoding='utf-8') as output_metadata:
            output_metadata.write(' '.join(sys.argv) + '\n')

    def log_metric(self, name, value, step, log_to_mlflow=True):
        """Log a scalar ``value`` for metric ``name`` at ``step`` (main process only).

        The value is always written to TensorBoard; the writer is created on
        first use. ``log_to_mlflow=False`` skips only the MLflow write.
        """
        if not self.is_main:
            return

        if not self.initialized_tensorboard:
            self._init_tensorboard()

        if log_to_mlflow:
            mlflow.log_metric(name, value, step=step)
        self.writer.add_scalar(name, value, step)

    def log_new_artifact(self, path, content):
        """Write ``content`` to ``path`` as UTF-8 and register the file with MLflow.

        This method performs no rank check; callers decide whether it should
        run. UTF-8 is explicit so that non-ASCII text (e.g. in tracebacks)
        cannot fail under a non-UTF-8 locale default.
        """
        with open(path, 'w', encoding='utf-8') as f:
            f.write(content)

        mlflow.log_artifact(path)

    def warn(self, *args):
        """Print a ``[WARNING]`` message and append it to ``logs_path/warnings.txt``.

        This runs on every rank.
        """
        # NOTE(review): assumes print_message returns the formatted message
        # string (its return value is written to the file below).
        msg = print_message('[WARNING]', '\t', *args)

        warnings_path = os.path.join(self.logs_path, 'warnings.txt')
        with open(warnings_path, 'a', encoding='utf-8') as output_metadata:
            output_metadata.write(msg + '\n\n\n')

    def info_all(self, *args):
        """Print ``args`` on every rank, prefixed with ``[<rank>]``."""
        print_message('[' + str(self.rank) + ']', '\t', *args)

    def info(self, *args):
        """Print ``args`` only on the main process."""
        if self.is_main:
            print_message(*args)
|
|
|