"""Adversarial losses for text models.""" |
|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange
import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS

# Adversarial and virtual adversarial training parameters.
flags.DEFINE_float('perturb_norm_length', 5.0,
                   'Norm length of adversarial perturbation, tuned on the '
                   'validation set. '
                   '5.0 is optimal on IMDB with virtual adversarial training.')

# Virtual adversarial training parameters.
flags.DEFINE_integer('num_power_iteration', 1,
                     'The number of power iterations.')
flags.DEFINE_float('small_constant_for_finite_diff', 1e-1,
                   'Small constant for the finite difference method.')

# Parameters for building the graph.
flags.DEFINE_string('adv_training_method', None,
                    'The flag which specifies the training method. '
                    '"" : non-adversarial training (e.g. for running the '
                    'semi-supervised sequence learning model) '
                    '"rp" : random perturbation training '
                    '"at" : adversarial training '
                    '"vat" : virtual adversarial training '
                    '"atvat" : at + vat ')
flags.DEFINE_float('adv_reg_coeff', 1.0,
                   'Regularization coefficient of adversarial loss.')


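# A minimal sketch of how these flags are typically consumed by a caller;
# `cl_loss`, `logits`, `embedded`, `inputs`, `loss_fn`, and
# `logits_from_embedding_fn` are hypothetical names supplied by the
# surrounding model code:
#
#   if FLAGS.adv_training_method == 'rp':
#     adv_loss = random_perturbation_loss(embedded, inputs.length, loss_fn)
#   elif FLAGS.adv_training_method == 'at':
#     adv_loss = adversarial_loss(embedded, cl_loss, loss_fn)
#   elif FLAGS.adv_training_method == 'vat':
#     adv_loss = virtual_adversarial_loss(logits, embedded, inputs,
#                                         logits_from_embedding_fn)
#   total_loss = cl_loss + FLAGS.adv_reg_coeff * adv_loss

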
def random_perturbation_loss(embedded, length, loss_fn):
  """Adds noise to embeddings and recomputes classification loss."""
  noise = tf.random_normal(shape=tf.shape(embedded))
  perturb = _scale_l2(_mask_by_length(noise, length), FLAGS.perturb_norm_length)
  return loss_fn(embedded + perturb)


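# Random perturbation is the non-adversarial baseline: the same norm
# constraint as adversarial_loss below, but the direction is an isotropic
# Gaussian sample rather than the loss gradient.

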
def adversarial_loss(embedded, loss, loss_fn):
  """Adds gradient to embedding and recomputes classification loss."""
  grad, = tf.gradients(
      loss,
      embedded,
      aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
  grad = tf.stop_gradient(grad)
  perturb = _scale_l2(grad, FLAGS.perturb_norm_length)
  return loss_fn(embedded + perturb)


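# The perturbation above is the stop-gradient loss gradient rescaled to a
# fixed L2 norm: r_adv = epsilon * g / ||g||_2 with g = grad_e loss(e) and
# epsilon = FLAGS.perturb_norm_length, i.e. adversarial training applied to
# the (continuous) embeddings rather than the discrete tokens.

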
def virtual_adversarial_loss(logits, embedded, inputs,
                             logits_from_embedding_fn):
  """Virtual adversarial loss.

  Computes virtual adversarial perturbation by finite difference method and
  power iteration, adds it to the embedding, and computes the KL divergence
  between the new logits and the original logits.

  Args:
    logits: 3-D float Tensor, [batch_size, num_timesteps, m], where m=1 if
      num_classes=2, otherwise m=num_classes.
    embedded: 3-D float Tensor, [batch_size, num_timesteps, embedding_dim].
    inputs: VatxtInput.
    logits_from_embedding_fn: callable that takes embeddings and returns
      classifier logits.

  Returns:
    kl: float scalar.
  """
  # Stop gradient of logits. See https://arxiv.org/abs/1507.00677 for details.
  logits = tf.stop_gradient(logits)

  # Only care about the KL divergence on the final timestep.
  weights = inputs.eos_weights
  assert weights is not None
  if FLAGS.single_label:
    indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
    weights = tf.expand_dims(tf.gather_nd(inputs.eos_weights, indices), 1)

  # Initialize perturbation with random noise.
  # shape(embedded) = (batch_size, num_timesteps, embedding_dim)
  d = tf.random_normal(shape=tf.shape(embedded))

  # Perform the finite difference method and power iteration: adding small
  # noise to the input and taking the gradient of the KL divergence with
  # respect to that noise corresponds to one power iteration.
  for _ in xrange(FLAGS.num_power_iteration):
    d = _scale_l2(
        _mask_by_length(d, inputs.length), FLAGS.small_constant_for_finite_diff)
    d_logits = logits_from_embedding_fn(embedded + d)
    kl = _kl_divergence_with_logits(logits, d_logits, weights)
    d, = tf.gradients(
        kl,
        d,
        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
    d = tf.stop_gradient(d)

  perturb = _scale_l2(d, FLAGS.perturb_norm_length)
  vadv_logits = logits_from_embedding_fn(embedded + perturb)
  return _kl_divergence_with_logits(logits, vadv_logits, weights)


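# Math sketch of the loop above: with p(.|e) the model distribution at
# embedding e, VAT seeks r_vadv ~= argmax_{||r||_2 <= epsilon}
# KL(p(.|e) || p(.|e + r)). Each iteration rescales d to norm xi
# (FLAGS.small_constant_for_finite_diff), then replaces d with
# grad_d KL(p(.|e) || p(.|e + d)) -- a finite-difference power iteration on
# the Hessian of the KL divergence -- and the final d is rescaled to epsilon
# (FLAGS.perturb_norm_length).

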
def random_perturbation_loss_bidir(embedded, length, loss_fn):
  """Adds noise to embeddings and recomputes classification loss."""
  noise = [tf.random_normal(shape=tf.shape(emb)) for emb in embedded]
  masked = [_mask_by_length(n, length) for n in noise]
  scaled = [_scale_l2(m, FLAGS.perturb_norm_length) for m in masked]
  return loss_fn([e + s for (e, s) in zip(embedded, scaled)])


def adversarial_loss_bidir(embedded, loss, loss_fn):
  """Adds gradient to embeddings and recomputes classification loss."""
  grads = tf.gradients(
      loss,
      embedded,
      aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
  adv_exs = [
      emb + _scale_l2(tf.stop_gradient(g), FLAGS.perturb_norm_length)
      for emb, g in zip(embedded, grads)
  ]
  return loss_fn(adv_exs)


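# In the bidirectional case `embedded` is a list with one tensor per
# direction; a single tf.gradients call returns one gradient per embedding,
# and each embedding is perturbed independently under the shared norm length.

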
def virtual_adversarial_loss_bidir(logits, embedded, inputs,
                                   logits_from_embedding_fn):
  """Virtual adversarial loss for bidirectional models."""
  logits = tf.stop_gradient(logits)
  f_inputs, _ = inputs
  weights = f_inputs.eos_weights
  if FLAGS.single_label:
    indices = tf.stack([tf.range(FLAGS.batch_size), f_inputs.length - 1], 1)
    weights = tf.expand_dims(tf.gather_nd(f_inputs.eos_weights, indices), 1)
  assert weights is not None

  perturbs = [
      _mask_by_length(tf.random_normal(shape=tf.shape(emb)), f_inputs.length)
      for emb in embedded
  ]
  for _ in xrange(FLAGS.num_power_iteration):
    perturbs = [
        _scale_l2(d, FLAGS.small_constant_for_finite_diff) for d in perturbs
    ]
    d_logits = logits_from_embedding_fn(
        [emb + d for (emb, d) in zip(embedded, perturbs)])
    kl = _kl_divergence_with_logits(logits, d_logits, weights)
    perturbs = tf.gradients(
        kl,
        perturbs,
        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
    perturbs = [tf.stop_gradient(d) for d in perturbs]

  perturbs = [_scale_l2(d, FLAGS.perturb_norm_length) for d in perturbs]
  vadv_logits = logits_from_embedding_fn(
      [emb + d for (emb, d) in zip(embedded, perturbs)])
  return _kl_divergence_with_logits(logits, vadv_logits, weights)


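# Same power iteration as the unidirectional case, but over a list of
# perturbations: tf.gradients(kl, perturbs) returns the matching list of
# gradients, and each direction's perturbation is rescaled to
# perturb_norm_length independently.

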
def _mask_by_length(t, length):
  """Mask t, 3-D [batch, time, dim], by length, 1-D [batch,]."""
  maxlen = t.get_shape().as_list()[1]

  # Subtract 1 from length to prevent the perturbation from going on 'eos'.
  mask = tf.sequence_mask(length - 1, maxlen=maxlen)
  mask = tf.expand_dims(tf.cast(mask, tf.float32), -1)
  # shape(mask) = (batch, num_timesteps, 1)
  return t * mask


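# Example: for length = [3] and t of shape [1, 5, dim], the mask keeps
# timesteps 0 and 1 and zeroes timesteps 2-4, so the final real token
# ('eos', index 2) is excluded from the perturbation along with the padding.

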
def _scale_l2(x, norm_length):
  # shape(x) = (batch, num_timesteps, d)
  # Divide x by max(abs(x)) for a numerically stable l2 norm:
  # 2norm(x) = a * 2norm(x/a).
  # Scale over the full sequence, dims (1, 2).
  alpha = tf.reduce_max(tf.abs(x), (1, 2), keep_dims=True) + 1e-12
  l2_norm = alpha * tf.sqrt(
      tf.reduce_sum(tf.pow(x / alpha, 2), (1, 2), keep_dims=True) + 1e-6)
  x_unit = x / l2_norm
  return norm_length * x_unit


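# After _scale_l2(d, norm_length), each example's perturbation has an L2 norm
# (taken jointly over timesteps and embedding dims) of approximately
# norm_length; the 1e-12 and 1e-6 constants only guard against division by
# zero for an all-zero input.

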
def _kl_divergence_with_logits(q_logits, p_logits, weights):
  """Returns weighted KL divergence between distributions q and p.

  Args:
    q_logits: logits for 1st argument of KL divergence shape
      [batch_size, num_timesteps, num_classes] if num_classes > 2, and
      [batch_size, num_timesteps, 1] if num_classes == 2.
    p_logits: logits for 2nd argument of KL divergence with same shape as
      q_logits.
    weights: 2-D float tensor with shape [batch_size, num_timesteps].
      Elements should be 1.0 only on end of sequences.

  Returns:
    kl: float scalar.
  """
  # For logistic regression.
  if FLAGS.num_classes == 2:
    q = tf.nn.sigmoid(q_logits)
    kl = (-tf.nn.sigmoid_cross_entropy_with_logits(logits=q_logits, labels=q) +
          tf.nn.sigmoid_cross_entropy_with_logits(logits=p_logits, labels=q))
    kl = tf.squeeze(kl, 2)

  # For softmax regression.
  else:
    q = tf.nn.softmax(q_logits)
    kl = tf.reduce_sum(
        q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), -1)

  num_labels = tf.reduce_sum(weights)
  num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)

  kl.get_shape().assert_has_rank(2)
  weights.get_shape().assert_has_rank(2)

  loss = tf.identity(tf.reduce_sum(weights * kl) / num_labels, name='kl')
  return loss
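
# Example: with batch_size=2, num_timesteps=3, and eos at the final timestep,
# weights = [[0., 0., 1.], [0., 0., 1.]] gives num_labels = 2, so the loss is
# the mean of the per-sequence KL values at the two sequence-final positions.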