|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Runs training for CVT text models.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import bisect |
|
import time |
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
from base import utils |
|
from model import multitask_model |
|
from task_specific import task_definitions |
|
|
|
|
|
class Trainer(object):
    """Builds the multi-task model and drives its training and evaluation.

    Attributes:
        tasks: One task object per entry of ``config.task_names``, as returned
            by ``task_definitions.get_task``.
    """

    def __init__(self, config):
        """Creates the tasks, loads pretrained embeddings and builds the model.

        Args:
            config: Configuration object. It is stored and later read for
                ``task_names``, ``word_embeddings``, ``model_name``, the
                ``*_every`` intervals, the batch sizes, ``is_semisup`` and
                ``history_file``.
        """
        self._config = config
        self.tasks = [task_definitions.get_task(self._config, task_name)
                      for task_name in self._config.task_names]

        utils.log('Loading Pretrained Embeddings')
        pretrained_embeddings = utils.load_cpickle(
            self._config.word_embeddings)

        utils.log('Building Model')
        self._model = multitask_model.Model(
            self._config, pretrained_embeddings, self.tasks)
        utils.log()

    def train(self, sess, progress, summary_writer):
        """Runs the training loop, with periodic logging, eval and saving.

        The minibatch stream from ``_get_training_mbs`` is endless, so this
        loop does not terminate on its own.

        Args:
            sess: Session object handed through to every model call.
            progress: Object providing ``unlabeled_data_reader``, ``history``,
                ``save_if_best_dev_model(sess, step)`` and
                ``write(sess, step)``.
            summary_writer: Writer handed through to ``evaluate_all_tasks``.
        """
        def heading(text):
            utils.heading(text, '(' + self._config.model_name + ')')

        trained_on_sentences = 0
        start_time = time.time()
        unsupervised_loss_total, unsupervised_loss_count = 0, 0
        supervised_loss_total, supervised_loss_count = 0, 0
        for mb in self._get_training_mbs(progress.unlabeled_data_reader):
            if mb.task_name == 'unlabeled':
                # The teacher runs before the unlabeled training step; its
                # predictions are cleared from the minibatch afterwards.
                self._model.run_teacher(sess, mb)
                loss = self._model.train_unlabeled(sess, mb)
                unsupervised_loss_total += loss
                unsupervised_loss_count += 1
                mb.teacher_predictions.clear()
            else:
                loss = self._model.train_labeled(sess, mb)
                supervised_loss_total += loss
                supervised_loss_count += 1

            trained_on_sentences += mb.size
            global_step = self._model.get_global_step(sess)

            if global_step % self._config.print_every == 0:
                # Guard against a zero elapsed time on coarse clocks.
                elapsed = max(time.time() - start_time, 1e-8)
                utils.log('step {:} - '
                          'supervised loss: {:.2f} - '
                          'unsupervised loss: {:.2f} - '
                          '{:.1f} sentences per second'.format(
                              global_step,
                              supervised_loss_total /
                              max(1, supervised_loss_count),
                              unsupervised_loss_total /
                              max(1, unsupervised_loss_count),
                              trained_on_sentences / elapsed))
                # Loss averages cover only the interval since the last print.
                unsupervised_loss_total, unsupervised_loss_count = 0, 0
                supervised_loss_total, supervised_loss_count = 0, 0

            if global_step % self._config.eval_dev_every == 0:
                heading('EVAL ON DEV')
                self.evaluate_all_tasks(sess, summary_writer, progress.history)
                progress.save_if_best_dev_model(sess, global_step)
                utils.log()

            if global_step % self._config.eval_train_every == 0:
                heading('EVAL ON TRAIN')
                self.evaluate_all_tasks(
                    sess, summary_writer, progress.history, True)
                utils.log()

            if global_step % self._config.save_model_every == 0:
                heading('CHECKPOINTING MODEL')
                progress.write(sess, global_step)
                utils.log()

    def evaluate_all_tasks(self, sess, summary_writer, history, train_set=False):
        """Evaluates every task and optionally records the results.

        Args:
            sess: Session object handed through to the model.
            summary_writer: Writer that receives the per-task summaries.
            history: List to which each task's results (extended with a
                ``('step', global_step)`` entry) are appended, or ``None`` to
                skip recording. When not ``None`` it is also written to
                ``config.history_file`` after all tasks have been evaluated.
            train_set: If true, evaluate on each task's ``train_set`` instead
                of its ``val_set``.
        """
        for task in self.tasks:
            results = self._evaluate_task(sess, task, summary_writer, train_set)
            if history is not None:
                results.append(('step', self._model.get_global_step(sess)))
                history.append(results)
        if history is not None:
            utils.write_cpickle(history, self._config.history_file)

    def _evaluate_task(self, sess, task, summary_writer, train_set):
        """Scores one task on its train or validation set.

        Args:
            sess: Session object handed through to the model.
            task: Task providing ``get_scorer()``, ``name``, ``train_set`` and
                ``val_set``.
            summary_writer: Writer passed to ``write_summary``.
            train_set: Selects ``task.train_set`` when true, else
                ``task.val_set``.

        Returns:
            Whatever ``scorer.get_results`` returns for the prefix
            ``<task.name>_train_`` or ``<task.name>_dev_``.
        """
        scorer = task.get_scorer()
        data = task.train_set if train_set else task.val_set
        for mb in data.get_minibatches(self._config.test_batch_size):
            loss, batch_preds = self._model.test(sess, mb)
            scorer.update(mb.examples, batch_preds, loss)

        results = scorer.get_results(
            task.name + ('_train_' if train_set else '_dev_'))
        utils.log(task.name.upper() + ': ' + scorer.results_str())
        write_summary(summary_writer, results,
                      global_step=self._model.get_global_step(sess))
        return results

    def _get_training_mbs(self, unlabeled_data_reader):
        """Yields an endless stream of training minibatches.

        Each iteration samples one task's training set with probability
        proportional to the square root of its ``size`` and yields its next
        labeled minibatch. When ``config.is_semisup`` is true, one unlabeled
        minibatch is yielded after every labeled one.

        Args:
            unlabeled_data_reader: Object whose ``endless_minibatches()``
                returns an iterator of unlabeled minibatches.

        Yields:
            Minibatch objects from the labeled and unlabeled iterators.

        Raises:
            ValueError: If there are no tasks to sample from.
        """
        datasets = [task.train_set for task in self.tasks]
        if not datasets:
            raise ValueError('Cannot sample minibatches: no tasks configured.')

        weights = [np.sqrt(dataset.size) for dataset in datasets]
        thresholds = np.cumsum([w / np.sum(weights) for w in weights])

        labeled_mbs = [
            dataset.endless_minibatches(self._config.train_batch_size)
            for dataset in datasets]
        unlabeled_mbs = unlabeled_data_reader.endless_minibatches()
        last_ind = len(labeled_mbs) - 1
        while True:
            # np.random.random() draws from [0, 1), but the last cumulative
            # threshold can round to just below 1.0, in which case bisect
            # would return an index one past the end. Clamp to stay in range.
            dataset_ind = min(
                bisect.bisect(thresholds, np.random.random()), last_ind)
            yield next(labeled_mbs[dataset_ind])
            if self._config.is_semisup:
                yield next(unlabeled_mbs)
|
|
|
|
|
def write_summary(writer, results, global_step):
    """Sends selected scalar results to a summary writer, then flushes it.

    Only entries whose name contains ``'f1'``, ``'acc'`` or ``'loss'`` are
    written; each one becomes its own ``tf.Summary`` with a single value.

    Args:
        writer: Object providing ``add_summary(summary, global_step)`` and
            ``flush()``.
        results: Iterable of ``(name, value)`` pairs.
        global_step: Step passed along with every ``add_summary`` call.
    """
    tracked_substrings = ('f1', 'acc', 'loss')
    for tag, value in results:
        if not any(part in tag for part in tracked_substrings):
            continue
        scalar = tf.Summary.Value(tag=tag, simple_value=value)
        writer.add_summary(tf.Summary(value=[scalar]), global_step)
    writer.flush()
|
|