# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A multi-task and semi-supervised NLP model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from model import encoder
from model import shared_inputs


class Inference(object):
  """Runs the shared encoder and the per-task prediction modules."""

  def __init__(self, config, inputs, pretrained_embeddings, tasks):
    with tf.variable_scope('encoder'):
      self.encoder = encoder.Encoder(config, inputs, pretrained_embeddings)
    self.modules = {}
    for task in tasks:
      with tf.variable_scope(task.name):
        self.modules[task.name] = task.get_module(inputs, self.encoder)


class Model(object):
  """A multi-task model trained with supervised and consistency losses."""

  def __init__(self, config, pretrained_embeddings, tasks):
    self._config = config
    self._tasks = tasks

    self._global_step, self._optimizer = self._get_optimizer()
    self._inputs = shared_inputs.Inputs(config)
    with tf.variable_scope('model', reuse=tf.AUTO_REUSE) as scope:
      inference = Inference(config, self._inputs, pretrained_embeddings,
                            tasks)
      # By default the same inference graph is used for training, testing,
      # and producing teacher predictions.
      self._trainer = inference
      self._tester = inference
      self._teacher = inference
      if config.ema_test or config.ema_teacher:
        # Build a second inference graph whose weights are exponential
        # moving averages of the trained weights.
        ema = tf.train.ExponentialMovingAverage(config.ema_decay)
        model_vars = tf.get_collection("trainable_variables", "model")
        ema_op = ema.apply(model_vars)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema_op)

        def ema_getter(getter, name, *args, **kwargs):
          var = getter(name, *args, **kwargs)
          return ema.average(var)

        scope.set_custom_getter(ema_getter)
        inference_ema = Inference(
            config, self._inputs, pretrained_embeddings, tasks)
        if config.ema_teacher:
          self._teacher = inference_ema
        if config.ema_test:
          self._tester = inference_ema

    self._unlabeled_loss = self._get_consistency_loss(tasks)
    self._unlabeled_train_op = self._get_train_op(self._unlabeled_loss)
    self._labeled_train_ops = {}
    for task in self._tasks:
      task_loss = self._trainer.modules[task.name].supervised_loss
      self._labeled_train_ops[task.name] = self._get_train_op(task_loss)

  def _get_consistency_loss(self, tasks):
    return sum([self._trainer.modules[task.name].unsupervised_loss
                for task in tasks])

  def _get_optimizer(self):
    global_step = tf.get_variable('global_step', initializer=0,
                                  trainable=False)
    # Linear warm-up over the first warm_up_steps steps, followed by
    # inverse-square-root decay.
    warm_up_multiplier = (tf.minimum(tf.to_float(global_step),
                                     self._config.warm_up_steps) /
                          self._config.warm_up_steps)
    decay_multiplier = 1.0 / (1 + self._config.lr_decay *
                              tf.sqrt(tf.to_float(global_step)))
    lr = self._config.lr * warm_up_multiplier * decay_multiplier
    optimizer = tf.train.MomentumOptimizer(lr, self._config.momentum)
    return global_step, optimizer
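
  # Illustrative note (added, with hypothetical config values): if lr=1.0,
  # warm_up_steps=5000, and lr_decay=0.005, the schedule above gives
  #   step 1000:  1.0 * (1000/5000) * 1/(1 + 0.005*sqrt(1000))  ~= 0.17
  #   step 10000: 1.0 * 1.0         * 1/(1 + 0.005*sqrt(10000)) ~= 0.67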

  def _get_train_op(self, loss):
    grads, vs = zip(*self._optimizer.compute_gradients(loss))
    grads, _ = tf.clip_by_global_norm(grads, self._config.grad_clip)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
      return self._optimizer.apply_gradients(
          zip(grads, vs), global_step=self._global_step)

  def _create_feed_dict(self, mb, model, is_training=True):
    feed = self._inputs.create_feed_dict(mb, is_training)
    if mb.task_name in model.modules:
      model.modules[mb.task_name].update_feed_dict(feed, mb)
    else:
      # Minibatches without a matching task module (e.g. unlabeled data)
      # are fed to every module.
      for module in model.modules.values():
        module.update_feed_dict(feed, mb)
    return feed

  def train_unlabeled(self, sess, mb):
    return sess.run([self._unlabeled_train_op, self._unlabeled_loss],
                    feed_dict=self._create_feed_dict(mb, self._trainer))[1]

  def train_labeled(self, sess, mb):
    return sess.run([self._labeled_train_ops[mb.task_name],
                     self._trainer.modules[mb.task_name].supervised_loss],
                    feed_dict=self._create_feed_dict(mb, self._trainer))[1]

  def run_teacher(self, sess, mb):
    result = sess.run({task.name: self._teacher.modules[task.name].probs
                       for task in self._tasks},
                      feed_dict=self._create_feed_dict(mb, self._teacher,
                                                       False))
    for task_name, probs in result.items():
      # Store teacher predictions as float16 to save memory.
      mb.teacher_predictions[task_name] = probs.astype('float16')

  def test(self, sess, mb):
    return sess.run(
        [self._tester.modules[mb.task_name].supervised_loss,
         self._tester.modules[mb.task_name].preds],
        feed_dict=self._create_feed_dict(mb, self._tester, False))

  def get_global_step(self, sess):
    return sess.run(self._global_step)
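

# Usage sketch (illustrative, not part of the original training pipeline).
# `config`, `pretrained_embeddings`, `tasks`, and the minibatch iterators
# below are assumptions about the surrounding code; only Model's methods
# come from this file:
#
#   model = Model(config, pretrained_embeddings, tasks)
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     for mb in labeled_minibatches:
#       loss = model.train_labeled(sess, mb)    # supervised update
#     for mb in unlabeled_minibatches:
#       model.run_teacher(sess, mb)             # fills mb.teacher_predictions
#       loss = model.train_unlabeled(sess, mb)  # consistency update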