# torchnet / BaseTrainer.py
# Last commit df07554 by milselarch: "push to main"
import os
import numpy as np
import shutil
import resource
import options as opt
from helpers import *
from datetime import datetime as Datetime
from tensorboardX import SummaryWriter
# Raise the soft limit on open file descriptors for this process, leaving the
# hard limit unchanged. NOTE(review): presumably needed because data loading
# keeps many files/handles open at once — confirm.
_NOFILE_SOFT_TARGET = 65536


def _clamp_soft_limit(target, hard):
    """Return ``target`` capped at ``hard``, unless ``hard`` is unlimited.

    ``resource.setrlimit`` rejects a soft limit above the hard limit with
    ``ValueError``, so an uncapped target would crash the import on systems
    whose hard limit is lower than the target (e.g. 4096).
    """
    if hard == resource.RLIM_INFINITY:
        return target
    return min(target, hard)


rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(
    resource.RLIMIT_NOFILE,
    (_clamp_soft_limit(_NOFILE_SOFT_TARGET, rlimit[1]), rlimit[1])
)
class BaseTrainer(object):
    """Shared setup for trainers: run naming, dataset kwargs and TensorBoard.

    Attributes:
        name: Short model/run name; the prefix of ``save_name``.
        base_dir: Stored as given; not read by any method of this class
            (subclasses may use it).
        date_stamp: Local-time stamp in ``YYMMDD-HHMM`` form, taken when the
            trainer is constructed.
        save_name: ``f'{name}-{date_stamp}'``; names this run's log and
            weights directories.
        weights_dir: ``weights/<save_name>``; ``None`` until
            ``init_tensorboard`` is called.
        log_dir: ``runs/<save_name>``; ``None`` until ``init_tensorboard``
            is called.
        writer: ``SummaryWriter`` on ``log_dir``; ``None`` until
            ``init_tensorboard`` is called.
    """

    def __init__(self, name='M', base_dir=''):
        self.name = name
        self.base_dir = base_dir
        self.date_stamp = self.make_date_stamp()
        self.save_name = f'{self.name}-{self.date_stamp}'
        # The following are populated by init_tensorboard().
        self.weights_dir = None
        self.log_dir = None
        self.writer = None

    @staticmethod
    def get_dataset_kwargs(
        shared_dict=None, base_dir='',
        char_map=opt.char_map, **kwargs
    ):
        """Build the keyword arguments used to construct a dataset.

        Paths, padding and frame-doubling settings are read from the
        ``options`` module at call time. ``char_map``'s default, however, is
        captured once when this class is defined.

        Args:
            shared_dict: Passed through unchanged.
            base_dir: Passed through unchanged.
            char_map: Passed through; defaults to ``opt.char_map``.
            **kwargs: Extra entries forwarded to ``kwargify``. Supplying a key
                that is already set explicitly here (e.g. ``video_path``)
                raises ``TypeError`` for a duplicate keyword argument.

        Returns:
            Whatever ``kwargify`` returns for these keywords (presumably a
            dict of the keyword arguments — TODO confirm in ``helpers``).
        """
        return kwargify(
            video_path=opt.video_path,
            shared_dict=shared_dict,
            alignments_dir=opt.alignments_dir,
            vid_pad=opt.vid_padding,
            image_dir=opt.images_dir,
            txt_pad=opt.txt_padding,
            phonemes_dir=opt.phonemes_dir,
            frame_doubling=opt.frame_doubling,
            char_map=char_map,
            base_dir=base_dir,
            **kwargs
        )

    def init_tensorboard(self):
        """Create this run's log/weights directories and TensorBoard writer.

        Sets ``self.log_dir`` to ``runs/<save_name>`` and ``self.weights_dir``
        to ``weights/<save_name>``, creating both along with any missing
        parent directories. Then opens a ``SummaryWriter`` on the log
        directory and copies ``options.py``, resolved relative to the current
        working directory, into it.
        """
        self.log_dir = f'runs/{self.save_name}'
        self.weights_dir = f'weights/{self.save_name}'
        # makedirs also creates the top-level runs/ and weights/ folders on a
        # fresh checkout, where os.mkdir raised FileNotFoundError. exist_ok
        # replaces the race-prone exists-then-create check.
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.weights_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        # Snapshot the current options file next to the logs so this run's
        # configuration is preserved.
        shutil.copyfile(
            'options.py', os.path.join(self.log_dir, 'options.py')
        )

    @staticmethod
    def make_date_stamp():
        """Return the current local time formatted as ``YYMMDD-HHMM``."""
        return Datetime.now().strftime("%y%m%d-%H%M")

    def log_scalar(self, name, value, iterations, label):
        """Log ``value`` as series ``label`` of scalar group ``name``.

        The value is recorded at step ``iterations``. ``init_tensorboard``
        must be called first; before that, ``self.writer`` is ``None``.
        """
        self.writer.add_scalars(name, {label: value}, iterations)