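# Spaces: Running
# (Status banner captured from the hosting page along with the source; it is
# not part of the module and is kept here only as a comment.)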
import logging
import os
from pprint import pformat

import numpy as np
import tensorflow as tf
class BaseModel:
    """Base class owning a TF1-style graph, session and training ops.

    Construction creates a private ``tf.Graph`` and a ``Session`` bound to
    it, calls :meth:`build_net` (which subclasses must implement) followed by
    :meth:`build_ops`, runs the global-variable initializer, and creates a
    ``Saver`` plus a summary ``FileWriter`` under ``<hps.exp_dir>/summary``.

    Subclass contract, as required by the code in this class:
      * ``build_net`` must set ``self.loss`` (read by ``build_ops``).
      * ``self.x``, ``self.b`` and ``self.m`` must exist before
        :meth:`execute` is called; they are used as feed-dict keys.

    Hyper-parameters read from ``hps``: ``exp_dir``, ``lr``, ``decay_steps``,
    ``decay_rate``, ``optimizer``, ``clip_gradient``.
    """

    # Number of initial global steps that use the scaled-down learning rate.
    WARMUP_STEPS = 1000
    # Factor applied to ``hps.lr`` while ``global_step < WARMUP_STEPS``.
    WARMUP_LR_SCALE = 0.001
    # Weight of the L2 penalty on "magnitude" / "rescaling_scale" variables.
    L2_REG_WEIGHT = 0.00005

    def __init__(self, hps):
        """Build the graph, open the session and initialize all variables.

        Args:
            hps: Hyper-parameter object; see the class docstring for the
                attributes that are read.
        """
        super().__init__()
        self.hps = hps

        graph = tf.Graph()
        # Everything below relies on `graph` being the default graph: the
        # model/ops builders, the initializer, the Saver and the variable
        # listing all look variables up through the default graph.
        with graph.as_default():
            # open a session
            config = tf.compat.v1.ConfigProto()
            config.log_device_placement = True
            config.allow_soft_placement = True
            config.gpu_options.allow_growth = True
            self.sess = tf.compat.v1.Session(config=config, graph=graph)

            # build model
            self.build_net()
            self.build_ops()

            # initialize
            self.sess.run(tf.compat.v1.global_variables_initializer())
            self.saver = tf.compat.v1.train.Saver()
            self.writer = tf.compat.v1.summary.FileWriter(
                self.hps.exp_dir + '/summary')

            # logging
            trainable_variables = tf.compat.v1.trainable_variables()
            total_params = sum(
                np.prod(v.get_shape().as_list()) for v in trainable_variables)
            logging.info('=' * 20)
            logging.info("Variables:")
            logging.info(pformat(trainable_variables))
            logging.info(
                "TOTAL TENSORS: %d TOTAL PARAMS: %f[M]",
                len(trainable_variables), total_params / 1e6)
            logging.info('=' * 20)

    def save(self, filename='params'):
        """Write a checkpoint to ``<exp_dir>/weights/<filename>.ckpt``.

        The target directory is created first if it is missing, because
        ``Saver.save`` fails when the parent directory does not exist.
        """
        fname = f'{self.hps.exp_dir}/weights/{filename}.ckpt'
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        self.saver.save(self.sess, fname)

    def load(self, filename='params'):
        """Restore variables from ``<exp_dir>/weights/<filename>.ckpt``."""
        fname = f'{self.hps.exp_dir}/weights/{filename}.ckpt'
        self.saver.restore(self.sess, fname)

    def build_net(self):
        """Construct the model graph; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def build_ops(self):
        """Create the optimizer, training op and merged summary op.

        Sets ``self.global_step``, ``self.train_op`` and ``self.summ_op``.
        Requires ``self.loss`` to have been defined already.
        """
        # optimizer
        self.global_step = tf.compat.v1.train.get_or_create_global_step()
        learning_rate = tf.compat.v1.train.inverse_time_decay(
            self.hps.lr, self.global_step,
            self.hps.decay_steps, self.hps.decay_rate,
            staircase=True)
        # Same decay schedule, but starting from a much smaller base rate.
        warmup_lr = tf.compat.v1.train.inverse_time_decay(
            self.WARMUP_LR_SCALE * self.hps.lr, self.global_step,
            self.hps.decay_steps, self.hps.decay_rate,
            staircase=True)
        # Use the warmup rate until WARMUP_STEPS, then the regular schedule.
        learning_rate = tf.cond(
            pred=tf.less(self.global_step, self.WARMUP_STEPS),
            true_fn=lambda: warmup_lr,
            false_fn=lambda: learning_rate)
        tf.compat.v1.summary.scalar('lr', learning_rate)

        if self.hps.optimizer == 'adam':
            optimizer = tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate,
                beta1=0.9, beta2=0.999, epsilon=1e-08,
                use_locking=False, name="Adam")
        elif self.hps.optimizer == 'rmsprop':
            optimizer = tf.compat.v1.train.RMSPropOptimizer(
                learning_rate=learning_rate)
        elif self.hps.optimizer == 'mom':
            optimizer = tf.compat.v1.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=0.9)
        else:
            # Any other value falls back to plain gradient descent.
            optimizer = tf.compat.v1.train.GradientDescentOptimizer(
                learning_rate=learning_rate)

        # regularization: sum of squares over variables selected by name
        l2_reg = sum(
            [tf.reduce_sum(input_tensor=tf.square(v))
             for v in tf.compat.v1.trainable_variables()
             if ("magnitude" in v.name) or ("rescaling_scale" in v.name)])
        reg_loss = self.L2_REG_WEIGHT * l2_reg

        # train
        grads_and_vars = optimizer.compute_gradients(
            self.loss + reg_loss, tf.compat.v1.trainable_variables())
        grads, vars_ = zip(*grads_and_vars)
        if self.hps.clip_gradient > 0:
            grads, gradient_norm = tf.clip_by_global_norm(
                grads, clip_norm=self.hps.clip_gradient)
            # `gradient_norm` only exists when clipping is enabled, so the
            # numerics check and its summary belong to this branch.
            gradient_norm = tf.debugging.check_numerics(
                gradient_norm, "Gradient norm is NaN or Inf.")
            tf.compat.v1.summary.scalar('gradient_norm', gradient_norm)
        capped_grads_and_vars = zip(grads, vars_)
        self.train_op = optimizer.apply_gradients(
            capped_grads_and_vars, global_step=self.global_step)

        # summary
        self.summ_op = tf.compat.v1.summary.merge_all()

    def execute(self, cmd, batch):
        """Run ``cmd`` in the session, feeding ``batch`` into x, b and m.

        Args:
            cmd: Fetch argument passed through to ``Session.run``.
            batch: Mapping with keys ``'x'``, ``'b'`` and ``'m'``; the values
                are fed to ``self.x``, ``self.b`` and ``self.m``.

        Returns:
            Whatever ``Session.run`` returns for ``cmd``.
        """
        feed = {self.x: batch['x'], self.b: batch['b'], self.m: batch['m']}
        return self.sess.run(cmd, feed)