|
import glob |
|
import json |
|
import os |
|
import pathlib |
|
import random |
|
import re |
|
import sys |
|
import time |
|
|
|
import matplotlib.pylab as plt |
|
import numpy as np |
|
import torch |
|
import yaml |
|
from torch import distributed as dist |
|
from torch.nn.utils import weight_norm |
|
|
|
|
|
def seed_everything(seed, cudnn_deterministic=False):
    """
    Function that sets seed for pseudo-random number generators in:
    pytorch, numpy, python.random

    Args:
        seed: the integer value seed for global random state. When ``None``
            no generator is reseeded.
        cudnn_deterministic: when True, switch cuDNN to deterministic
            algorithms (``torch.backends.cudnn.deterministic = True``) and
            disable its autotuner (``torch.backends.cudnn.benchmark = False``)
            so convolution results are reproducible. This can slow training.
            When False (default) the cuDNN flags are left untouched.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # Safe on CPU-only machines: CUDA seeding is deferred until CUDA init.
        torch.cuda.manual_seed_all(seed)

    # This flag used to be accepted but silently ignored; honor it so callers
    # requesting reproducible cuDNN kernels actually get them.
    if cudnn_deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_primary():
    """Report whether the current process should act as the primary worker.

    True exactly when ``get_rank()`` reports rank 0, which also covers runs
    where torch.distributed is unavailable or not initialized.
    """
    rank = get_rank()
    return rank == 0
|
|
|
|
|
def get_rank():
    """Return this process's rank in the default torch.distributed group.

    Falls back to 0 when torch.distributed is not available in this build
    or no process group has been initialized, so single-process runs look
    like rank 0.
    """
    # ``and`` short-circuits, so is_initialized() is only queried when the
    # distributed package is available, mirroring the original guard order.
    distributed_ready = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if distributed_ready else 0
|
|
|
|
|
def load_yaml_config(path):
    """Parse the YAML file at ``path`` and return the resulting object.

    NOTE(review): uses ``yaml.full_load`` rather than ``yaml.safe_load``;
    only point this at trusted config files.
    """
    with open(path) as stream:
        return yaml.full_load(stream)
|
|
|
|
|
def save_config_to_yaml(config, path):
    """Serialize ``config`` with ``yaml.dump`` and write it to ``path``.

    Args:
        config: object to dump as YAML.
        path: destination file (str or os.PathLike); must end with ``.yaml``.
            An existing file is overwritten.

    Raises:
        AssertionError: if ``path`` does not end with ``.yaml``.
    """
    # os.fspath lets pathlib.Path destinations through (Path has no
    # ``endswith``); plain str paths are returned unchanged.
    assert os.fspath(path).endswith('.yaml'), \
        'config path must end with .yaml, got {}'.format(path)
    # The ``with`` block closes the file, so no explicit close() is needed.
    with open(path, 'w') as f:
        yaml.dump(config, f)
|
|
|
|
|
def save_dict_to_json(d, path, indent=None):
    """Write ``d`` to ``path`` as JSON, overwriting any existing file.

    Args:
        d: JSON-serializable object (typically a dict).
        path: destination file path.
        indent: forwarded to ``json.dump`` for pretty-printing; ``None``
            writes compact single-line JSON.
    """
    # A context manager flushes and closes the handle deterministically
    # instead of leaving it to garbage collection.
    with open(path, 'w') as f:
        json.dump(d, f, indent=indent)
|
|
|
|
|
def load_dict_from_json(path):
    """Read the JSON file at ``path`` and return the decoded object.

    Args:
        path: file containing a JSON document.

    Returns:
        The decoded value (a dict when the file holds a JSON object).
    """
    # Close the handle deterministically rather than relying on GC.
    with open(path, 'r') as f:
        return json.load(f)
|
|
|
|
|
def write_args(args, path):
    """Append run metadata and the public attributes of ``args`` to ``path``.

    Writes the torch and cuDNN versions, the raw ``sys.argv`` list, and then
    one ``name: value`` line for every attribute of ``args`` (found via
    ``dir``) whose name does not start with ``_``, sorted by name.

    Args:
        args: namespace-like object, e.g. an ``argparse.Namespace``.
        path: text file to append to; created if it does not exist.
    """
    args_dict = {name: getattr(args, name) for name in dir(args)
                 if not name.startswith('_')}
    # The ``with`` block closes the file; no explicit close() is needed.
    with open(path, 'a') as args_file:
        args_file.write('==> torch version: {}\n'.format(torch.__version__))
        args_file.write(
            '==> cudnn version: {}\n'.format(torch.backends.cudnn.version()))
        args_file.write('==> Cmd:\n')
        args_file.write(str(sys.argv))
        args_file.write('\n==> args:\n')
        for k, v in sorted(args_dict.items()):
            args_file.write(' %s: %s\n' % (str(k), str(v)))
|
|
|
|
|
class Logger(object):
    """Rank-aware experiment logger.

    On the primary rank (``is_primary()``) it:
    - creates ``args.save_dir``;
    - dumps ``args`` to ``<save_dir>/configs/args.txt``;
    - appends timestamped messages to ``<save_dir>/logs/log.txt``;
    - forwards scalar/image calls to a TensorBoard ``SummaryWriter`` when
      ``args.tensorboard`` is truthy.

    On other ranks the file/TensorBoard methods do nothing.
    """

    def __init__(self, args):
        """Set up output directories and writers on the primary rank.

        Args:
            args: object exposing ``save_dir`` and ``tensorboard``; all of
                its public attributes are also recorded via ``write_args``.
        """
        self.args = args
        self.save_dir = args.save_dir
        self.is_primary = is_primary()
        # Path only; the directory itself is created on the primary rank.
        self.config_dir = os.path.join(self.save_dir, 'configs')
        # Defined on every rank so later methods never hit AttributeError.
        self.text_writer = None
        self.tb_writer = None

        if self.is_primary:
            os.makedirs(self.save_dir, exist_ok=True)

            # Save the args alongside the run.
            os.makedirs(self.config_dir, exist_ok=True)
            write_args(args, os.path.join(self.config_dir, 'args.txt'))

            log_dir = os.path.join(self.save_dir, 'logs')
            os.makedirs(log_dir, exist_ok=True)
            self.text_writer = open(os.path.join(log_dir, 'log.txt'), 'a')

            if args.tensorboard:
                self.log_info('using tensorboard')
                # ``import torch`` does not load the tensorboard submodule,
                # so ``torch.utils.tensorboard`` must be imported explicitly.
                from torch.utils.tensorboard import SummaryWriter
                self.tb_writer = SummaryWriter(log_dir=log_dir)

    def save_config(self, config):
        """Write ``config`` to ``<save_dir>/configs/config.yaml`` (primary only)."""
        if self.is_primary:
            save_config_to_yaml(config,
                                os.path.join(self.config_dir, 'config.yaml'))

    def log_info(self, info, check_primary=True):
        """Print ``info`` and, on the primary rank, append it to log.txt.

        Args:
            info: message; converted with ``str`` before being written, and
                prefixed with a local-time stamp (``%Y-%m-%d-%H-%M``).
            check_primary: when False, non-primary ranks also print the
                message (they still never write to the log file).
        """
        if self.is_primary or (not check_primary):
            print(info)
        if self.is_primary:
            info = str(info)
            time_str = time.strftime('%Y-%m-%d-%H-%M')
            info = '{}: {}'.format(time_str, info)
            if not info.endswith('\n'):
                info += '\n'
            self.text_writer.write(info)
            # Flush each message so the log survives a crash.
            self.text_writer.flush()

    def add_scalar(self, **kargs):
        """Forward to ``SummaryWriter.add_scalar`` (primary rank, TensorBoard on)."""
        if self.is_primary and self.tb_writer is not None:
            self.tb_writer.add_scalar(**kargs)

    def add_scalars(self, **kargs):
        """Forward to ``SummaryWriter.add_scalars`` (primary rank, TensorBoard on)."""
        if self.is_primary and self.tb_writer is not None:
            self.tb_writer.add_scalars(**kargs)

    def add_image(self, **kargs):
        """Forward to ``SummaryWriter.add_image`` (primary rank, TensorBoard on)."""
        if self.is_primary and self.tb_writer is not None:
            self.tb_writer.add_image(**kargs)

    def add_images(self, **kargs):
        """Forward to ``SummaryWriter.add_images`` (primary rank, TensorBoard on)."""
        if self.is_primary and self.tb_writer is not None:
            self.tb_writer.add_images(**kargs)

    def close(self):
        """Close the log file and, if one was created, the TensorBoard writer.

        Safe to call when TensorBoard logging is disabled (``tb_writer`` is
        ``None``) and on non-primary ranks.
        """
        if self.is_primary:
            if self.text_writer is not None:
                self.text_writer.close()
            if self.tb_writer is not None:
                self.tb_writer.close()
|
|
|
|
|
def plot_spectrogram(spectrogram):
    """Render a 2-D array as a heat map with a colorbar and return the Figure.

    The image uses ``origin="lower"`` (first row at the bottom),
    ``aspect="auto"`` and no interpolation on a 10x2-inch canvas. The canvas
    is drawn once, then ``plt.close()`` detaches the figure from pyplot. The
    Figure object itself is still returned to the caller.
    """
    figure, axes = plt.subplots(figsize=(10, 2))
    heatmap = axes.imshow(spectrogram,
                          aspect="auto",
                          origin="lower",
                          interpolation='none')
    plt.colorbar(heatmap, ax=axes)

    figure.canvas.draw()
    plt.close()

    return figure
|
|
|
|
|
def init_weights(m, mean=0.0, std=0.01):
    """Refill the ``weight`` of convolution modules from N(mean, std) in place.

    A module qualifies when its class name contains "Conv" (so Conv1d,
    ConvTranspose1d, ... all match); every other module is left untouched.
    Takes a single module, so it can be passed to ``nn.Module.apply``.
    """
    if "Conv" not in m.__class__.__name__:
        return
    m.weight.data.normal_(mean, std)
|
|
|
|
|
def apply_weight_norm(m):
    """Reparameterize convolution modules with ``torch.nn.utils.weight_norm``.

    Only modules whose class name contains "Conv" are modified (in place);
    anything else is ignored. Takes a single module, so it can be passed to
    ``nn.Module.apply``.
    """
    is_conv = "Conv" in m.__class__.__name__
    if is_conv:
        weight_norm(m)
|
|
|
|
|
def get_padding(kernel_size, dilation=1):
    """Return the symmetric padding for a stride-1 dilated convolution.

    Computed as ``(kernel_size * dilation - dilation) / 2`` truncated toward
    zero with ``int``. For odd kernel sizes this keeps the output length
    equal to the input length.
    """
    receptive_extent = kernel_size * dilation - dilation
    return int(receptive_extent / 2)
|
|
|
|
|
def load_checkpoint(filepath, device):
    """Load a ``torch.save`` file from ``filepath`` with tensors mapped to ``device``.

    Progress messages are printed before and after loading.
    NOTE(review): ``torch.load`` relies on pickle; only load trusted files.

    Raises:
        AssertionError: if ``filepath`` is not an existing regular file.
    """
    assert os.path.isfile(filepath)
    print(f"Loading '{filepath}'")
    loaded = torch.load(filepath, map_location=device)
    print("Complete.")
    return loaded
|
|
|
|
|
def save_checkpoint(filepath, obj, num_ckpt_keep=5):
    """Save ``obj`` with ``torch.save`` and prune older checkpoints.

    Checkpoints are expected to be named ``g_<digits>`` or ``do_<digits>``.
    The prefix before the underscore selects which family in the same
    directory is pruned. After the new file is written, only the
    ``num_ckpt_keep`` lexicographically greatest files matching
    ``<prefix>_*`` are kept, so zero-padded step numbers keep the most
    recent ones. The file just written is never deleted.

    Args:
        filepath: destination path of the new checkpoint.
        obj: object to serialize with ``torch.save``.
        num_ckpt_keep: how many checkpoints of this family to keep,
            including the new one. ``None`` or a value <= 0 disables
            pruning. Names outside the ``g_``/``do_`` scheme are saved
            without pruning.
    """
    path = pathlib.Path(filepath)

    print("Saving checkpoint to {}".format(filepath))
    # Write first: if saving fails, the older checkpoints are still intact.
    torch.save(obj, filepath)
    print("Complete.")

    match = re.match(r'(do|g)_\d+', path.name)
    if match is None or num_ckpt_keep is None or num_ckpt_keep <= 0:
        return

    ckpts = sorted(path.parent.glob(f'{match.group(1)}_*'))
    # Keep the newest ``num_ckpt_keep`` files *including* the one just
    # written, so at most num_ckpt_keep checkpoints remain on disk.
    for stale in ckpts[:-num_ckpt_keep]:
        if stale.name != path.name:
            os.remove(stale)
|
|
|
|
|
def scan_checkpoint(cp_dir, prefix):
    """Return the path of the latest checkpoint for ``prefix`` in ``cp_dir``.

    Candidates are entries named ``prefix`` followed by exactly eight
    characters (glob ``????????``, e.g. a zero-padded step number). The
    lexicographically greatest match is returned, or ``None`` when nothing
    matches.
    """
    candidates = glob.glob(os.path.join(cp_dir, prefix + '????????'))
    if not candidates:
        return None
    return max(candidates)
|
|