# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
"""Model using memory component.
The model embeds images using a standard CNN architecture.
These embeddings are used as keys to the memory component,
which returns nearest neighbors.
"""
import tensorflow as tf
import memory
FLAGS = tf.flags.FLAGS
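
# A minimal usage sketch (illustrative only; the dimensions, episode layout and
# training loop below are assumptions, not part of this module):
#
#   model = Model(input_dim=784, output_dim=784, rep_dim=128,
#                 memory_size=8192, vocab_size=4096)
#   model.setup()
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     losses = model.episode_step(sess, episode_x, episode_y,
#                                 clear_memory=True)
#     preds = model.episode_predict(sess, episode_x, episode_y,
#                                   clear_memory=True)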

class BasicClassifier(object):

  def __init__(self, output_dim):
    self.output_dim = output_dim
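
  # core_builder below returns the memory module's retrieved values directly
  # as predictions; the classifier itself adds no trainable parameters and no
  # loss term.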
  def core_builder(self, memory_val, x, y):
    del x, y
    y_pred = memory_val
    loss = 0.0
    return loss, y_pred

class LeNet(object):
  """Standard CNN architecture."""

  def __init__(self, image_size, num_channels, hidden_dim):
    self.image_size = image_size
    self.num_channels = num_channels
    self.hidden_dim = hidden_dim
    self.matrix_init = tf.truncated_normal_initializer(stddev=0.1)
    self.vector_init = tf.constant_initializer(0.0)

  def core_builder(self, x):
    """Embeds x using standard CNN architecture.

    Args:
      x: Batch of images as a 2-d Tensor [batch_size, -1].

    Returns:
      A 2-d Tensor [batch_size, hidden_dim] of embedded images.
    """
    ch1 = 32 * 2  # number of channels in 1st layer
    ch2 = 64 * 2  # number of channels in 2nd layer

    conv1_weights = tf.get_variable('conv1_w',
                                    [3, 3, self.num_channels, ch1],
                                    initializer=self.matrix_init)
    conv1_biases = tf.get_variable('conv1_b', [ch1],
                                   initializer=self.vector_init)
    conv1a_weights = tf.get_variable('conv1a_w',
                                     [3, 3, ch1, ch1],
                                     initializer=self.matrix_init)
    conv1a_biases = tf.get_variable('conv1a_b', [ch1],
                                    initializer=self.vector_init)

    conv2_weights = tf.get_variable('conv2_w', [3, 3, ch1, ch2],
                                    initializer=self.matrix_init)
    conv2_biases = tf.get_variable('conv2_b', [ch2],
                                   initializer=self.vector_init)
    conv2a_weights = tf.get_variable('conv2a_w', [3, 3, ch2, ch2],
                                     initializer=self.matrix_init)
    conv2a_biases = tf.get_variable('conv2a_b', [ch2],
                                    initializer=self.vector_init)

    # fully connected
    fc1_weights = tf.get_variable(
        'fc1_w', [self.image_size // 4 * self.image_size // 4 * ch2,
                  self.hidden_dim], initializer=self.matrix_init)
    fc1_biases = tf.get_variable('fc1_b', [self.hidden_dim],
                                 initializer=self.vector_init)

    # define model
    x = tf.reshape(x,
                   [-1, self.image_size, self.image_size, self.num_channels])
    batch_size = tf.shape(x)[0]

    conv1 = tf.nn.conv2d(x, conv1_weights,
                         strides=[1, 1, 1, 1], padding='SAME')
    relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1_biases))
    conv1 = tf.nn.conv2d(relu1, conv1a_weights,
                         strides=[1, 1, 1, 1], padding='SAME')
    relu1 = tf.nn.relu(tf.nn.bias_add(conv1, conv1a_biases))

    pool1 = tf.nn.max_pool(relu1, ksize=[1, 2, 2, 1],
                           strides=[1, 2, 2, 1], padding='SAME')

    conv2 = tf.nn.conv2d(pool1, conv2_weights,
                         strides=[1, 1, 1, 1], padding='SAME')
    relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2_biases))
    conv2 = tf.nn.conv2d(relu2, conv2a_weights,
                         strides=[1, 1, 1, 1], padding='SAME')
    relu2 = tf.nn.relu(tf.nn.bias_add(conv2, conv2a_biases))

    pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1],
                           strides=[1, 2, 2, 1], padding='SAME')

    reshape = tf.reshape(pool2, [batch_size, -1])
    hidden = tf.matmul(reshape, fc1_weights) + fc1_biases

    return hidden

class Model(object):
  """Model for coordinating between CNN embedder and Memory module."""

  def __init__(self, input_dim, output_dim, rep_dim, memory_size, vocab_size,
               learning_rate=0.0001, use_lsh=False):
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.rep_dim = rep_dim
    self.memory_size = memory_size
    self.vocab_size = vocab_size
    self.learning_rate = learning_rate
    self.use_lsh = use_lsh

    self.embedder = self.get_embedder()
    self.memory = self.get_memory()
    self.classifier = self.get_classifier()
    self.global_step = tf.train.get_or_create_global_step()
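
  # The embedder assumes input_dim is a flattened square, single-channel image;
  # e.g. input_dim=784 corresponds to 28x28 inputs (image_size=28).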
  def get_embedder(self):
    return LeNet(int(self.input_dim ** 0.5), 1, self.rep_dim)

  def get_memory(self):
    cls = memory.LSHMemory if self.use_lsh else memory.Memory
    return cls(self.rep_dim, self.memory_size, self.vocab_size)

  def get_classifier(self):
    return BasicClassifier(self.output_dim)
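
  # core_builder wires the pieces together: images are embedded by the CNN,
  # optionally dropped out during training, and used as query keys into the
  # memory. Since BasicClassifier contributes a loss of 0.0, the training
  # signal comes entirely from the memory module's teacher loss.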
  def core_builder(self, x, y, keep_prob, use_recent_idx=True):
    embeddings = self.embedder.core_builder(x)
    if keep_prob < 1.0:
      embeddings = tf.nn.dropout(embeddings, keep_prob)
    memory_val, _, teacher_loss = self.memory.query(
        embeddings, y, use_recent_idx=use_recent_idx)
    loss, y_pred = self.classifier.core_builder(memory_val, x, y)

    return loss + teacher_loss, y_pred

  def train(self, x, y):
    loss, _ = self.core_builder(x, y, keep_prob=0.3)
    gradient_ops = self.training_ops(loss)
    return loss, gradient_ops

  def eval(self, x, y):
    _, y_preds = self.core_builder(x, y, keep_prob=1.0,
                                   use_recent_idx=False)
    return y_preds

  def get_xy_placeholders(self):
    return (tf.placeholder(tf.float32, [None, self.input_dim]),
            tf.placeholder(tf.int32, [None]))
  def setup(self):
    """Sets up all components of the computation graph."""

    self.x, self.y = self.get_xy_placeholders()

    # This context creates variables
    with tf.variable_scope('core', reuse=None):
      self.loss, self.gradient_ops = self.train(self.x, self.y)

    # And this one re-uses them (thus the `reuse=True`)
    with tf.variable_scope('core', reuse=True):
      self.y_preds = self.eval(self.x, self.y)
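
  # Gradients are clipped to a global norm of 5.0 before being applied by the
  # Adam optimizer below; the global step is incremented on each update.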
  def training_ops(self, loss):
    opt = self.get_optimizer()
    params = tf.trainable_variables()
    gradients = tf.gradients(loss, params)
    clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
    return opt.apply_gradients(zip(clipped_gradients, params),
                               global_step=self.global_step)

  def get_optimizer(self):
    return tf.train.AdamOptimizer(learning_rate=self.learning_rate,
                                  epsilon=1e-4)

  def one_step(self, sess, x, y):
    outputs = [self.loss, self.gradient_ops]
    return sess.run(outputs, feed_dict={self.x: x, self.y: y})
  def episode_step(self, sess, x, y, clear_memory=False):
    """Performs training steps on episodic input.

    Args:
      sess: A Tensorflow Session.
      x: A list of batches of images defining the episode.
      y: A list of batches of labels corresponding to x.
      clear_memory: Whether to clear the memory before the episode.

    Returns:
      List of losses the same length as the episode.
    """
    outputs = [self.loss, self.gradient_ops]

    if clear_memory:
      self.clear_memory(sess)

    losses = []
    for xx, yy in zip(x, y):
      out = sess.run(outputs, feed_dict={self.x: xx, self.y: yy})
      loss = out[0]
      losses.append(loss)

    return losses
  def predict(self, sess, x, y=None):
    """Predict the labels on a single batch of examples.

    Args:
      sess: A Tensorflow Session.
      x: A batch of images.
      y: The labels for the images in x.
        This allows for updating the memory.

    Returns:
      Predicted y.
    """

    # Storing current memory state to restore it after prediction
    mem_keys, mem_vals, mem_age, _ = self.memory.get()
    cur_memory = (
        tf.identity(mem_keys),
        tf.identity(mem_vals),
        tf.identity(mem_age),
        None,
    )

    outputs = [self.y_preds]
    if y is None:
      ret = sess.run(outputs, feed_dict={self.x: x})
    else:
      ret = sess.run(outputs, feed_dict={self.x: x, self.y: y})

    # Restoring memory state
    self.memory.set(*cur_memory)

    return ret

  def episode_predict(self, sess, x, y, clear_memory=False):
    """Predict the labels on an episode of examples.

    Args:
      sess: A Tensorflow Session.
      x: A list of batches of images.
      y: A list of labels for the images in x.
        This allows for updating the memory.
      clear_memory: Whether to clear the memory before the episode.

    Returns:
      List of predicted y.
    """

    # Storing current memory state to restore it after prediction
    mem_keys, mem_vals, mem_age, _ = self.memory.get()
    cur_memory = (
        tf.identity(mem_keys),
        tf.identity(mem_vals),
        tf.identity(mem_age),
        None,
    )

    if clear_memory:
      self.clear_memory(sess)

    outputs = [self.y_preds]
    y_preds = []
    for xx, yy in zip(x, y):
      out = sess.run(outputs, feed_dict={self.x: xx, self.y: yy})
      y_pred = out[0]
      y_preds.append(y_pred)

    # Restoring memory state
    self.memory.set(*cur_memory)

    return y_preds

  def clear_memory(self, sess):
    sess.run([self.memory.clear()])