# Web file-viewer residue (not part of the script): NCTCMumbai's picture |
# Upload 2583 files | 18ddfe2 verified | raw | history blame | 9.37 kB
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
r"""Script for training model.
Simple command to get up and running:
python train.py --memory_size=8192 \
--batch_size=16 --validation_length=50 \
--episode_width=5 --episode_length=30
"""
import logging
import os
import random
import numpy as np
from six.moves import xrange
import tensorflow as tf
import data_utils
import model
# Command-line flags. Trainer reads these once, in its constructor and in
# run(); they are not re-read per episode.
FLAGS = tf.flags.FLAGS

# Model / memory configuration.
tf.flags.DEFINE_integer('rep_dim', 128,
                        'dimension of keys to use in memory')
tf.flags.DEFINE_integer('episode_length', 100, 'length of episode')
tf.flags.DEFINE_integer('episode_width', 5,
                        'number of distinct labels in a single episode')
# Trainer falls back to episode_length * batch_size when this is None, so the
# help text states that product rather than just the episode length.
tf.flags.DEFINE_integer('memory_size', None, 'number of slots in memory. '
                        'Leave as None to default to '
                        'episode_length * batch_size')
tf.flags.DEFINE_integer('batch_size', 16, 'batch size')

# Training / validation schedule.
tf.flags.DEFINE_integer('num_episodes', 100000, 'number of training episodes')
tf.flags.DEFINE_integer('validation_frequency', 20,
                        'every so many training episodes, '
                        'assess validation accuracy')
tf.flags.DEFINE_integer('validation_length', 10,
                        'number of episodes to use to compute '
                        'validation accuracy')

# Reproducibility, checkpointing and experimental options.
tf.flags.DEFINE_integer('seed', 888, 'random seed for training sampling')
tf.flags.DEFINE_string('save_dir', '', 'directory to save model to')
tf.flags.DEFINE_bool('use_lsh', False,
                     'use locality-sensitive hashing '
                     '(NOTE: not fully tested)')
class Trainer(object):
  """Class that takes care of training, validating, and checkpointing model.

  Hyperparameters (rep_dim, episode_length, episode_width, batch_size,
  memory_size, use_lsh) are read from FLAGS when the trainer is constructed.
  """

  def __init__(self, train_data, valid_data, input_dim, output_dim=None):
    """Stores the datasets and snapshots the hyperparameters from FLAGS.

    Args:
      train_data: A dictionary mapping label to list of training examples.
      valid_data: A dictionary mapping label to list of validation examples.
      input_dim: Input dimension forwarded to model.Model.
      output_dim: Output dimension forwarded to model.Model. Defaults to
        FLAGS.episode_width when None.
    """
    self.train_data = train_data
    self.valid_data = valid_data
    self.input_dim = input_dim

    self.rep_dim = FLAGS.rep_dim
    self.episode_length = FLAGS.episode_length
    self.episode_width = FLAGS.episode_width
    self.batch_size = FLAGS.batch_size
    # Without an explicit size, allot one slot per example in a training batch.
    self.memory_size = (self.episode_length * self.batch_size
                        if FLAGS.memory_size is None else FLAGS.memory_size)
    self.use_lsh = FLAGS.use_lsh

    self.output_dim = (output_dim if output_dim is not None
                       else self.episode_width)

  def get_model(self):
    """Returns a model.Model built from this trainer's settings."""
    # vocab size is the number of distinct values that
    # could go into the memory key-value storage
    vocab_size = self.episode_width * self.batch_size
    return model.Model(
        self.input_dim, self.output_dim, self.rep_dim, self.memory_size,
        vocab_size, use_lsh=self.use_lsh)

  def sample_episode_batch(self, data,
                           episode_length, episode_width, batch_size):
    """Generates a random batch for training or validation.

    Structures each element of the batch as an 'episode'.
    Each episode contains episode_length examples and
    episode_width distinct labels. Within an episode the examples are
    ordered so that every label is shown once before any label is shown a
    second time, and so on. Labels are renumbered per episode: the labels of
    batch element b lie in [b * episode_width, (b + 1) * episode_width).

    Args:
      data: A dictionary mapping label to list of examples.
      episode_length: Number of examples in each episode.
      episode_width: Distinct number of labels in each episode.
      batch_size: Batch size (number of episodes).

    Returns:
      A tuple (x, y) where x is a list of batches of examples
      with size episode_length and y is a list of batches of labels.
      The x arrays are float32 and the y arrays are int32.

    Raises:
      AssertionError: If data has fewer than episode_width labels.
      ValueError: If a sampled label has too few examples (raised by
        random.sample).
    """
    episodes_x = [[] for _ in xrange(episode_length)]
    episodes_y = [[] for _ in xrange(episode_length)]
    assert len(data) >= episode_width
    # random.sample() requires a sequence: a dict view raises TypeError on
    # Python 3.11+ (and is deprecated before that), so materialize the labels.
    keys = list(data)

    # Split episode_length as evenly as possible over the labels: every label
    # gets base_count examples and the last `remainder` labels get one extra.
    remainder = episode_length % episode_width
    base_count = (episode_length - remainder) // episode_width
    remainders = [0] * (episode_width - remainder) + [1] * remainder

    for b in xrange(batch_size):
      episode_labels = random.sample(keys, episode_width)
      episode_x = [
          random.sample(data[lab], r + base_count)
          for lab, r in zip(episode_labels, remainders)]
      # Tag each example with (label index in episode, showing index).
      episode = [(x, i, ii)
                 for i, xx in enumerate(episode_x)
                 for ii, x in enumerate(xx)]
      random.shuffle(episode)
      # Arrange episode so that each distinct label is seen before moving to
      # 2nd showing. The sort is stable, so the shuffled order is kept within
      # each showing.
      episode.sort(key=lambda elem: elem[2])
      assert len(episode) == episode_length
      for i in xrange(episode_length):
        episodes_x[i].append(episode[i][0])
        # Offset by the batch index so labels are unique across the batch.
        episodes_y[i].append(episode[i][1] + b * episode_width)

    return ([np.array(xx).astype('float32') for xx in episodes_x],
            [np.array(yy).astype('int32') for yy in episodes_y])

  def compute_correct(self, ys, y_preds):
    """Returns the fraction of entries of y_preds equal to those of ys."""
    return np.mean(np.equal(y_preds, np.array(ys)))

  def individual_compute_correct(self, y, y_pred):
    """Returns whether a single prediction equals its label."""
    return y_pred == y

  def run(self):
    """Performs training.

    Trains a model using episodic training.
    Every so often, runs some evaluations on validation data.

    Builds the model, opens a tf.Session, restores the latest checkpoint from
    FLAGS.save_dir when one exists, and then runs FLAGS.num_episodes training
    batches. Every FLAGS.validation_frequency batches it logs the average
    training loss, evaluates FLAGS.validation_length single-episode batches
    of validation data (overall and per-shot accuracy), and saves a
    checkpoint if FLAGS.save_dir is set.
    """
    train_data, valid_data = self.train_data, self.valid_data
    input_dim, output_dim = self.input_dim, self.output_dim
    rep_dim, episode_length = self.rep_dim, self.episode_length
    episode_width, memory_size = self.episode_width, self.memory_size
    batch_size = self.batch_size

    train_size = len(train_data)
    valid_size = len(valid_data)
    logging.info('train_size (number of labels) %d', train_size)
    logging.info('valid_size (number of labels) %d', valid_size)
    logging.info('input_dim %d', input_dim)
    logging.info('output_dim %d', output_dim)
    logging.info('rep_dim %d', rep_dim)
    logging.info('episode_length %d', episode_length)
    logging.info('episode_width %d', episode_width)
    logging.info('memory_size %d', memory_size)
    logging.info('batch_size %d', batch_size)

    # Every label must hold enough examples to fill its share of an episode.
    assert all(len(v) >= float(episode_length) / episode_width
               for v in train_data.values())
    assert all(len(v) >= float(episode_length) / episode_width
               for v in valid_data.values())

    self.model = self.get_model()
    self.model.setup()

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver(max_to_keep=10)
    ckpt = None
    if FLAGS.save_dir:
      ckpt = tf.train.get_checkpoint_state(FLAGS.save_dir)
    if ckpt and ckpt.model_checkpoint_path:
      logging.info('restoring from %s', ckpt.model_checkpoint_path)
      saver.restore(sess, ckpt.model_checkpoint_path)

    logging.info('starting now')
    losses = []
    random.seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    for i in xrange(FLAGS.num_episodes):
      x, y = self.sample_episode_batch(
          train_data, episode_length, episode_width, batch_size)
      # The value returned by episode_step is treated as the batch loss.
      loss = self.model.episode_step(sess, x, y, clear_memory=True)
      losses.append(loss)

      if i % FLAGS.validation_frequency == 0:
        logging.info('episode batch %d, avg train loss %f',
                     i, np.mean(losses))
        losses = []

        # validation
        correct = []
        num_shots = episode_length // episode_width
        # correct_by_shot[k] collects correctness of predictions made when
        # the true label had already been shown k times in the episode.
        correct_by_shot = {k: [] for k in xrange(num_shots)}
        for _ in xrange(FLAGS.validation_length):
          x, y = self.sample_episode_batch(
              valid_data, episode_length, episode_width, 1)
          y_preds = self.model.episode_predict(
              sess, x, y, clear_memory=True)
          correct.append(self.compute_correct(np.array(y), y_preds))

          # compute per-shot accuracies
          seen_counts = [0] * episode_width
          # loop over episode steps
          for yy, yy_preds in zip(y, y_preds):
            # validation batches have a single episode, so only index 0 exists
            yyy, yyy_preds = int(yy[0]), int(yy_preds[0])
            count = seen_counts[yyy % episode_width]
            if count in correct_by_shot:
              correct_by_shot[count].append(
                  self.individual_compute_correct(yyy, yyy_preds))
            seen_counts[yyy % episode_width] = count + 1

        logging.info('validation overall accuracy %f', np.mean(correct))
        # Interleave (shot index, accuracy) pairs to match the repeated
        # format string.
        shot_report = []
        for k in xrange(num_shots):
          shot_report.extend([k, np.mean(correct_by_shot[k])])
        logging.info('%d-shot: %.3f, ' * num_shots, *shot_report)

        if saver and FLAGS.save_dir:
          saved_file = saver.save(sess,
                                  os.path.join(FLAGS.save_dir, 'model.ckpt'),
                                  global_step=self.model.global_step)
          logging.info('saved model to %s', saved_file)
def main(unused_argv):
  """Loads the datasets and runs training.

  Args:
    unused_argv: Ignored; present because the function takes one positional
      argument.
  """
  training_set, validation_set = data_utils.get_data()
  # The trainer's input dimension is IMAGE_NEW_SIZE squared.
  input_dim = data_utils.IMAGE_NEW_SIZE ** 2
  Trainer(training_set, validation_set, input_dim).run()
if __name__ == '__main__':
  # Emit INFO-level records so the training progress logged via
  # logging.info is visible.
  logging.basicConfig(level=logging.INFO)
  # tf.app.run() parses the command-line flags and then calls main().
  tf.app.run()