r"""Script for training model. |
|
|
|
Simple command to get up and running: |
|
python train.py --memory_size=8192 \ |
|
--batch_size=16 --validation_length=50 \ |
|
--episode_width=5 --episode_length=30 |
|
""" |
|
|
|
import logging |
|
import os |
|
import random |
|
|
|
import numpy as np |
|
from six.moves import xrange |
|
import tensorflow as tf |
|
|
|
import data_utils |
|
import model |
|
|
|
FLAGS = tf.flags.FLAGS |

tf.flags.DEFINE_integer('rep_dim', 128,
                        'dimension of keys to use in memory')
tf.flags.DEFINE_integer('episode_length', 100, 'length of episode')
tf.flags.DEFINE_integer('episode_width', 5,
                        'number of distinct labels in a single episode')
tf.flags.DEFINE_integer('memory_size', None, 'number of slots in memory. '
                        'Leave as None to default to '
                        'episode_length * batch_size')
tf.flags.DEFINE_integer('batch_size', 16, 'batch size')
tf.flags.DEFINE_integer('num_episodes', 100000, 'number of training episodes')
tf.flags.DEFINE_integer('validation_frequency', 20,
                        'every so many training episodes, '
                        'assess validation accuracy')
tf.flags.DEFINE_integer('validation_length', 10,
                        'number of episodes to use to compute '
                        'validation accuracy')
tf.flags.DEFINE_integer('seed', 888, 'random seed for training sampling')
tf.flags.DEFINE_string('save_dir', '', 'directory to save model to')
tf.flags.DEFINE_bool('use_lsh', False,
                     'use locality-sensitive hashing '
                     '(NOTE: not fully tested)')


class Trainer(object):
  """Class that takes care of training, validating, and checkpointing model."""

  def __init__(self, train_data, valid_data, input_dim, output_dim=None):
    self.train_data = train_data
    self.valid_data = valid_data
    self.input_dim = input_dim

    self.rep_dim = FLAGS.rep_dim
    self.episode_length = FLAGS.episode_length
    self.episode_width = FLAGS.episode_width
    self.batch_size = FLAGS.batch_size
    # Default to one memory slot per example in a full batch of episodes.
    self.memory_size = (self.episode_length * self.batch_size
                        if FLAGS.memory_size is None else FLAGS.memory_size)
    self.use_lsh = FLAGS.use_lsh

    self.output_dim = (output_dim if output_dim is not None
                       else self.episode_width)
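    # Illustrative arithmetic, not a guarantee: with the default flags
    # (episode_length=100, batch_size=16) the memory defaults to
    # 100 * 16 == 1600 slots; the docstring example instead passes
    # --memory_size=8192 to enlarge it.
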
  def get_model(self):
    # vocab_size is the number of distinct labels seen in one training
    # step: each of the batch_size episodes contributes episode_width
    # labels, offset by b * episode_width in sample_episode_batch.
    vocab_size = self.episode_width * self.batch_size
    return model.Model(
        self.input_dim, self.output_dim, self.rep_dim, self.memory_size,
        vocab_size, use_lsh=self.use_lsh)

  def sample_episode_batch(self, data,
                           episode_length, episode_width, batch_size):
    """Generates a random batch for training or validation.

    Structures each element of the batch as an 'episode'.
    Each episode contains episode_length examples and
    episode_width distinct labels.

    Args:
      data: A dictionary mapping label to list of examples.
      episode_length: Number of examples in each episode.
      episode_width: Number of distinct labels in each episode.
      batch_size: Batch size (number of episodes).

    Returns:
      A tuple (x, y) where x is a list of length episode_length whose
      elements are batches of examples, and y is the matching list of
      label batches.
    """
    episodes_x = [[] for _ in xrange(episode_length)]
    episodes_y = [[] for _ in xrange(episode_length)]
    assert len(data) >= episode_width
    keys = list(data.keys())  # random.sample needs a sequence, not a view.
    for b in xrange(batch_size):
      episode_labels = random.sample(keys, episode_width)
      # Split episode_length as evenly as possible across the labels: the
      # first (episode_width - remainder) labels get the base count, the
      # rest get one extra example each.
      remainder = episode_length % episode_width
      remainders = [0] * (episode_width - remainder) + [1] * remainder
      episode_x = [
          random.sample(data[lab],
                        r + (episode_length - remainder) // episode_width)
          for lab, r in zip(episode_labels, remainders)]
      episode = sum([[(x, i, ii) for ii, x in enumerate(xx)]
                     for i, xx in enumerate(episode_x)], [])
      random.shuffle(episode)
      # Arrange the episode so every first showing of a label precedes any
      # second showing, every second precedes any third, and so on.
      episode.sort(key=lambda elem: elem[2])
      assert len(episode) == episode_length
      for i in xrange(episode_length):
        episodes_x[i].append(episode[i][0])
        # Offset labels by b * episode_width so ids are unique across
        # the batch.
        episodes_y[i].append(episode[i][1] + b * episode_width)

    return ([np.array(xx).astype('float32') for xx in episodes_x],
            [np.array(yy).astype('int32') for yy in episodes_y])
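
  # Worked example of the per-label allocation above (illustrative numbers,
  # not the defaults): with episode_length=32 and episode_width=5,
  # remainder == 2, remainders == [0, 0, 0, 1, 1], and the base count is
  # (32 - 2) // 5 == 6, so three labels contribute 6 examples and two
  # contribute 7, for 32 in total.
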
  def compute_correct(self, ys, y_preds):
    return np.mean(np.equal(y_preds, np.array(ys)))

  def individual_compute_correct(self, y, y_pred):
    return y_pred == y

  def run(self):
    """Performs training.

    Trains a model using episodic training. Every validation_frequency
    episode batches, evaluates on validation data.
    """
    train_data, valid_data = self.train_data, self.valid_data
    input_dim, output_dim = self.input_dim, self.output_dim
    rep_dim, episode_length = self.rep_dim, self.episode_length
    episode_width, memory_size = self.episode_width, self.memory_size
    batch_size = self.batch_size

    train_size = len(train_data)
    valid_size = len(valid_data)
    logging.info('train_size (number of labels) %d', train_size)
    logging.info('valid_size (number of labels) %d', valid_size)
    logging.info('input_dim %d', input_dim)
    logging.info('output_dim %d', output_dim)
    logging.info('rep_dim %d', rep_dim)
    logging.info('episode_length %d', episode_length)
    logging.info('episode_width %d', episode_width)
    logging.info('memory_size %d', memory_size)
    logging.info('batch_size %d', batch_size)

    # Every label must have enough examples to fill its share of an episode.
    assert all(len(v) >= float(episode_length) / episode_width
               for v in train_data.values())
    assert all(len(v) >= float(episode_length) / episode_width
               for v in valid_data.values())
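    # E.g. with episode_length=30 and episode_width=5 (the docstring
    # example), every label needs at least 30 / 5 == 6 examples so that
    # sampling without replacement in sample_episode_batch cannot fail.
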
    output_dim = episode_width
    self.model = self.get_model()
    self.model.setup()

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver(max_to_keep=10)
    ckpt = None
    if FLAGS.save_dir:
      ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir)
    if ckpt and ckpt.model_checkpoint_path:
      logging.info('restoring from %s', ckpt.model_checkpoint_path)
      saver.restore(sess, ckpt.model_checkpoint_path)

    logging.info('starting now')
    losses = []
    random.seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    for i in xrange(FLAGS.num_episodes):
      x, y = self.sample_episode_batch(
          train_data, episode_length, episode_width, batch_size)
      outputs = self.model.episode_step(sess, x, y, clear_memory=True)
      loss = outputs
      losses.append(loss)

      if i % FLAGS.validation_frequency == 0:
        logging.info('episode batch %d, avg train loss %f',
                     i, np.mean(losses))
        losses = []

        # Validation: run validation_length single-episode batches,
        # tracking overall accuracy and accuracy per shot (first, second,
        # ... showing of a label within an episode).
        correct = []
        num_shots = episode_length // episode_width
        correct_by_shot = dict((k, []) for k in xrange(num_shots))
        for _ in xrange(FLAGS.validation_length):
          x, y = self.sample_episode_batch(
              valid_data, episode_length, episode_width, 1)
          outputs = self.model.episode_predict(
              sess, x, y, clear_memory=True)
          y_preds = outputs
          correct.append(self.compute_correct(np.array(y), y_preds))

          # Compute per-shot accuracies.
          seen_counts = [0] * episode_width
          # Loop over episode steps; batch size is 1 here, so each yy and
          # yy_preds holds a single element.
          for yy, yy_preds in zip(y, y_preds):
            yyy, yyy_preds = int(yy[0]), int(yy_preds[0])
            count = seen_counts[yyy % episode_width]
            if count in correct_by_shot:
              correct_by_shot[count].append(
                  self.individual_compute_correct(yyy, yyy_preds))
            seen_counts[yyy % episode_width] = count + 1
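          # Illustrative trace of the shot bookkeeping: with episode_width=2
          # and label sequence [0, 1, 0, 1], the first two steps count
          # toward 0-shot accuracy and the last two toward 1-shot accuracy.
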
        logging.info('validation overall accuracy %f', np.mean(correct))
        logging.info('%d-shot: %.3f, ' * num_shots,
                     *sum([[k, np.mean(correct_by_shot[k])]
                           for k in xrange(num_shots)], []))
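
        # The log call above tiles the '%d-shot: %.3f, ' template num_shots
        # times and flattens [shot, accuracy] pairs into its arguments, so
        # with num_shots=2 it prints e.g. '0-shot: 0.950, 1-shot: 0.990, '.
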
        if saver and FLAGS.save_dir:
          saved_file = saver.save(sess,
                                  os.path.join(FLAGS.save_dir, 'model.ckpt'),
                                  global_step=self.model.global_step)
          logging.info('saved model to %s', saved_file)


def main(unused_argv):
  train_data, valid_data = data_utils.get_data()
  trainer = Trainer(train_data, valid_data, data_utils.IMAGE_NEW_SIZE ** 2)
  trainer.run()


if __name__ == '__main__':
  logging.basicConfig(level=logging.INFO)
  tf.app.run()