File size: 5,323 Bytes
6e5cc8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import pathlib
import pickle
import re
import numpy as np
import tensorflow as tf
from tensorflow.keras import mixed_precision as prec
try:
from tensorflow.python.distribute import values
except Exception:
from google3.third_party.tensorflow.python.distribute import values
tf.tensor = tf.convert_to_tensor
for base in (tf.Tensor, tf.Variable, values.PerReplica):
base.mean = tf.math.reduce_mean
base.std = tf.math.reduce_std
base.var = tf.math.reduce_variance
base.sum = tf.math.reduce_sum
base.any = tf.math.reduce_any
base.all = tf.math.reduce_all
base.min = tf.math.reduce_min
base.max = tf.math.reduce_max
base.abs = tf.math.abs
base.logsumexp = tf.math.reduce_logsumexp
base.transpose = tf.transpose
base.reshape = tf.reshape
base.astype = tf.cast
# values.PerReplica.dtype = property(lambda self: self.values[0].dtype)
# tf.TensorHandle.__repr__ = lambda x: '<tensor>'
# tf.TensorHandle.__str__ = lambda x: '<tensor>'
# np.set_printoptions(threshold=5, edgeitems=0)
class Module(tf.Module):
def save(self, filename):
values = tf.nest.map_structure(lambda x: x.numpy(), self.variables)
amount = len(tf.nest.flatten(values))
count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values)))
print(f'Save checkpoint with {amount} tensors and {count} parameters.')
with pathlib.Path(filename).open('wb') as f:
pickle.dump(values, f)
def load(self, filename):
with pathlib.Path(filename).open('rb') as f:
values = pickle.load(f)
amount = len(tf.nest.flatten(values))
count = int(sum(np.prod(x.shape) for x in tf.nest.flatten(values)))
print(f'Load checkpoint with {amount} tensors and {count} parameters.')
tf.nest.map_structure(lambda x, y: x.assign(y), self.variables, values)
def get(self, name, ctor, *args, **kwargs):
# Create or get layer by name to avoid mentioning it in the constructor.
if not hasattr(self, '_modules'):
self._modules = {}
if name not in self._modules:
self._modules[name] = ctor(*args, **kwargs)
return self._modules[name]
class Optimizer(tf.Module):
def __init__(
self, name, lr, eps=1e-4, clip=None, wd=None,
opt='adam', wd_pattern=r'.*'):
assert 0 <= wd < 1
assert not clip or 1 <= clip
self._name = name
self._clip = clip
self._wd = wd
self._wd_pattern = wd_pattern
self._opt = {
'adam': lambda: tf.optimizers.Adam(lr, epsilon=eps),
'nadam': lambda: tf.optimizers.Nadam(lr, epsilon=eps),
'adamax': lambda: tf.optimizers.Adamax(lr, epsilon=eps),
'sgd': lambda: tf.optimizers.SGD(lr),
'momentum': lambda: tf.optimizers.SGD(lr, 0.9),
}[opt]()
self._mixed = (prec.global_policy().compute_dtype == tf.float16)
if self._mixed:
self._opt = prec.LossScaleOptimizer(self._opt, dynamic=True)
self._once = True
@property
def variables(self):
return self._opt.variables()
def __call__(self, tape, loss, modules):
assert loss.dtype is tf.float32, (self._name, loss.dtype)
assert len(loss.shape) == 0, (self._name, loss.shape)
metrics = {}
# Find variables.
modules = modules if hasattr(modules, '__len__') else (modules,)
varibs = tf.nest.flatten([module.variables for module in modules])
count = sum(np.prod(x.shape) for x in varibs)
if self._once:
print(f'Found {count} {self._name} parameters.')
self._once = False
# Check loss.
tf.debugging.check_numerics(loss, self._name + '_loss')
metrics[f'{self._name}_loss'] = loss
# Compute scaled gradient.
if self._mixed:
with tape:
loss = self._opt.get_scaled_loss(loss)
grads = tape.gradient(loss, varibs)
if self._mixed:
grads = self._opt.get_unscaled_gradients(grads)
if self._mixed:
metrics[f'{self._name}_loss_scale'] = self._opt.loss_scale
# Distributed sync.
context = tf.distribute.get_replica_context()
if context:
grads = context.all_reduce('mean', grads)
# Gradient clipping.
norm = tf.linalg.global_norm(grads)
if not self._mixed:
tf.debugging.check_numerics(norm, self._name + '_norm')
if self._clip:
grads, _ = tf.clip_by_global_norm(grads, self._clip, norm)
metrics[f'{self._name}_grad_norm'] = norm
# Weight decay.
if self._wd:
self._apply_weight_decay(varibs)
# Apply gradients.
self._opt.apply_gradients(
zip(grads, varibs),
experimental_aggregate_gradients=False)
return metrics
def _apply_weight_decay(self, varibs):
nontrivial = (self._wd_pattern != r'.*')
if nontrivial:
print('Applied weight decay to variables:')
for var in varibs:
if re.search(self._wd_pattern, self._name + '/' + var.name):
if nontrivial:
print('- ' + self._name + '/' + var.name)
var.assign((1 - self._wd) * var)
|