"""Functions to build the Attention OCR model. |
|
|
|
Usage example:

  ocr_model = model.Model(num_char_classes, seq_length, num_views, null_code)

  data = ...  # create namedtuple InputEndpoints

  endpoints = ocr_model.create_base(data.images, data.labels_one_hot)

  # endpoints.predicted_chars is a tensor with predicted character codes.

  total_loss = ocr_model.create_loss(data, endpoints)
|
""" |
|
import sys |
|
import collections |
|
import logging |
|
import tensorflow as tf |
|
from tensorflow.contrib import slim |
|
from tensorflow.contrib.slim.nets import inception |
|
|
|
import metrics |
|
import sequence_layers |
|
import utils |
|
|
|
OutputEndpoints = collections.namedtuple('OutputEndpoints', [ |
|
'chars_logit', 'chars_log_prob', 'predicted_chars', 'predicted_scores', |
|
'predicted_text' |
|
]) |
|
|
|
|
|
ModelParams = collections.namedtuple('ModelParams', [ |
|
'num_char_classes', 'seq_length', 'num_views', 'null_code' |
|
]) |
|
|
|
ConvTowerParams = collections.namedtuple('ConvTowerParams', ['final_endpoint']) |
|
|
|
SequenceLogitsParams = collections.namedtuple('SequenceLogitsParams', [ |
|
'use_attention', 'use_autoregression', 'num_lstm_units', 'weight_decay', |
|
'lstm_state_clip_value' |
|
]) |
|
|
|
SequenceLossParams = collections.namedtuple('SequenceLossParams', [ |
|
'label_smoothing', 'ignore_nulls', 'average_across_timesteps' |
|
]) |
|
|
|
EncodeCoordinatesParams = collections.namedtuple('EncodeCoordinatesParams', [ |
|
'enabled' |
|
]) |
|
|
|
|
|
def _dict_to_array(id_to_char, default_character): |
|
num_char_classes = max(id_to_char.keys()) + 1 |
|
array = [default_character] * num_char_classes |
|
for k, v in id_to_char.items(): |
|
array[k] = v |
|
return array |
|
|
|
|
|
class CharsetMapper(object): |
|
"""A simple class to map tensor ids into strings. |
|
|
|
  It works only when the character set is a 1:1 mapping between individual

  characters and individual ids.
|
|
|
Make sure you call tf.tables_initializer().run() as part of the init op. |
|
""" |
|
|
|
def __init__(self, charset, default_character='?'): |
|
"""Creates a lookup table. |
|
|
|
    Args:

      charset: a dictionary with id-to-character mapping.

      default_character: a character to substitute for ids that are missing

        from charset; also used as the lookup table's default value.
    """
|
mapping_strings = tf.constant(_dict_to_array(charset, default_character)) |
|
self.table = tf.contrib.lookup.index_to_string_table_from_tensor( |
|
mapping=mapping_strings, default_value=default_character) |
|
|
|
def get_text(self, ids): |
|
"""Returns a string corresponding to a sequence of character ids. |
|
|
|
Args: |
|
ids: a tensor with shape [batch_size, max_sequence_length] |
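
    For example, a sketch (requires tf.tables_initializer().run() first):

      mapper = CharsetMapper({0: 'a', 1: 'b'})

      text = mapper.get_text(tf.constant([[0, 1, 1]]))  # -> ['abb']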
|
""" |
|
return tf.reduce_join( |
|
self.table.lookup(tf.to_int64(ids)), reduction_indices=1) |
|
|
|
|
|
def get_softmax_loss_fn(label_smoothing): |
|
"""Returns sparse or dense loss function depending on the label_smoothing. |
|
|
|
Args: |
|
label_smoothing: weight for label smoothing |
|
|
|
Returns: |
|
a function which takes labels and predictions as arguments and returns |
|
a softmax loss for the selected type of labels (sparse or dense). |
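
  For example, a sketch using names from this file: with label_smoothing > 0

  the returned function expects dense labels, otherwise sparse ids:

    loss_fn = get_softmax_loss_fn(0.1)

    loss = loss_fn(labels=smoothed_one_hot_labels, logits=chars_logits)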
|
""" |
|
if label_smoothing > 0: |
|
|
|
def loss_fn(labels, logits): |
|
return (tf.nn.softmax_cross_entropy_with_logits( |
|
logits=logits, labels=labels)) |
|
else: |
|
|
|
def loss_fn(labels, logits): |
|
return tf.nn.sparse_softmax_cross_entropy_with_logits( |
|
logits=logits, labels=labels) |
|
|
|
return loss_fn |
|
|
|
|
|
class Model(object): |
|
"""Class to create the Attention OCR Model.""" |
|
|
|
def __init__(self, |
|
num_char_classes, |
|
seq_length, |
|
num_views, |
|
null_code, |
|
mparams=None, |
|
charset=None): |
|
"""Initialized model parameters. |
|
|
|
Args: |
|
num_char_classes: size of character set. |
|
seq_length: number of characters in a sequence. |
|
num_views: Number of views (conv towers) to use. |
|
null_code: A character code corresponding to a character which |
|
indicates end of a sequence. |
|
mparams: a dictionary with hyper parameters for methods, keys - |
|
function names, values - corresponding namedtuples. |
|
      charset: an optional dictionary with a mapping between character ids and

        utf8 strings. If specified, OutputEndpoints.predicted_text will contain

        utf8 encoded strings corresponding to the character ids returned by

        OutputEndpoints.predicted_chars (by default predicted_text contains

        an empty vector).

        NOTE: Make sure you call tf.tables_initializer().run() if the charset

        is specified.
|
""" |
|
super(Model, self).__init__() |
|
self._params = ModelParams( |
|
num_char_classes=num_char_classes, |
|
seq_length=seq_length, |
|
num_views=num_views, |
|
null_code=null_code) |
|
self._mparams = self.default_mparams() |
|
if mparams: |
|
self._mparams.update(mparams) |
|
self._charset = charset |
|
|
|
def default_mparams(self): |
|
return { |
|
'conv_tower_fn': |
|
ConvTowerParams(final_endpoint='Mixed_5d'), |
|
'sequence_logit_fn': |
|
SequenceLogitsParams( |
|
use_attention=True, |
|
use_autoregression=True, |
|
num_lstm_units=256, |
|
weight_decay=0.00004, |
|
lstm_state_clip_value=10.0), |
|
'sequence_loss_fn': |
|
SequenceLossParams( |
|
label_smoothing=0.1, |
|
ignore_nulls=True, |
|
average_across_timesteps=False), |
|
'encode_coordinates_fn': EncodeCoordinatesParams(enabled=False) |
|
} |
|
|
|
def set_mparam(self, function, **kwargs): |
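    """Replaces hyperparameter values for the given function.

    For example, a sketch:

      model.set_mparam('sequence_loss_fn', label_smoothing=0.0)
    """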
|
self._mparams[function] = self._mparams[function]._replace(**kwargs) |
|
|
|
def conv_tower_fn(self, images, is_training=True, reuse=None): |
|
"""Computes convolutional features using the InceptionV3 model. |
|
|
|
Args: |
|
images: A tensor of shape [batch_size, height, width, channels]. |
|
is_training: whether is training or not. |
|
      reuse: whether or not the network and its variables should be reused. To

        be able to reuse, 'scope' must be given.
|
|
|
Returns: |
|
      A tensor of shape [batch_size, OH, OW, N], where OH x OW is the

      resolution of the output feature map and N is the number of output

      features (N depends on the network architecture).
|
""" |
|
mparams = self._mparams['conv_tower_fn'] |
|
logging.debug('Using final_endpoint=%s', mparams.final_endpoint) |
|
with tf.variable_scope('conv_tower_fn/INCE'): |
|
if reuse: |
|
tf.get_variable_scope().reuse_variables() |
|
with slim.arg_scope(inception.inception_v3_arg_scope()): |
|
with slim.arg_scope([slim.batch_norm, slim.dropout], |
|
is_training=is_training): |
|
net, _ = inception.inception_v3_base( |
|
images, final_endpoint=mparams.final_endpoint) |
|
return net |
|
|
|
def _create_lstm_inputs(self, net): |
|
"""Splits an input tensor into a list of tensors (features). |
|
|
|
Args: |
|
net: A feature map of shape [batch_size, num_features, feature_size]. |
|
|
|
Raises: |
|
AssertionError: if num_features is less than seq_length. |
|
|
|
Returns: |
|
A list with seq_length tensors of shape [batch_size, feature_size] |
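
    For example, a sketch: a net of shape [32, 40, 288] with seq_length=37 is

    sliced to [32, 37, 288] and unstacked into 37 tensors of shape [32, 288].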
|
""" |
|
num_features = net.get_shape().dims[1].value |
|
if num_features < self._params.seq_length: |
|
      raise AssertionError('Incorrect dimension #1 of input tensor'

                           ' %d should be at least %d (shape=%s)' %

                           (num_features, self._params.seq_length,

                            net.get_shape()))
|
elif num_features > self._params.seq_length: |
|
logging.warning('Ignoring some features: use %d of %d (shape=%s)', |
|
self._params.seq_length, num_features, net.get_shape()) |
|
net = tf.slice(net, [0, 0, 0], [-1, self._params.seq_length, -1]) |
|
|
|
return tf.unstack(net, axis=1) |
|
|
|
def sequence_logit_fn(self, net, labels_one_hot): |
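    """Computes character logits using a recurrent sequence layer.

    A sketch of the flow: picks a sequence layer class based on the

    use_attention/use_autoregression hyperparameters and uses it to produce

    logits of shape [batch_size, seq_length, num_char_classes].
    """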
|
mparams = self._mparams['sequence_logit_fn'] |
|
|
|
with tf.variable_scope('sequence_logit_fn/SQLR'): |
|
layer_class = sequence_layers.get_layer_class(mparams.use_attention, |
|
mparams.use_autoregression) |
|
layer = layer_class(net, labels_one_hot, self._params, mparams) |
|
return layer.create_logits() |
|
|
|
def max_pool_views(self, nets_list): |
|
"""Max pool across all nets in spatial dimensions. |
|
|
|
Args: |
|
nets_list: A list of 4D tensors with identical size. |
|
|
|
Returns: |
|
A tensor with the same size as any input tensors. |
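
    For example, a sketch: 4 nets of shape [32, 10, 25, 288] are each reshaped

    to [32, 1, 250, 288], concatenated to [32, 4, 250, 288], max pooled with a

    [4, 1] kernel to [32, 1, 250, 288] and reshaped back to [32, 10, 25, 288].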
|
""" |
|
batch_size, height, width, num_features = [ |
|
d.value for d in nets_list[0].get_shape().dims |
|
] |
|
xy_flat_shape = (batch_size, 1, height * width, num_features) |
|
nets_for_merge = [] |
|
with tf.variable_scope('max_pool_views', values=nets_list): |
|
for net in nets_list: |
|
nets_for_merge.append(tf.reshape(net, xy_flat_shape)) |
|
merged_net = tf.concat(nets_for_merge, 1) |
|
net = slim.max_pool2d( |
|
merged_net, kernel_size=[len(nets_list), 1], stride=1) |
|
net = tf.reshape(net, (batch_size, height, width, num_features)) |
|
return net |
|
|
|
def pool_views_fn(self, nets): |
|
"""Combines output of multiple convolutional towers into a single tensor. |
|
|
|
    It stacks towers one on top of another (in the height dimension) in a 4x1

    grid. The order is an arbitrary design choice and shouldn't matter much.
|
|
|
Args: |
|
nets: list of tensors of shape=[batch_size, height, width, num_features]. |
|
|
|
Returns: |
|
A tensor of shape [batch_size, seq_length, features_size]. |
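
    For example, a sketch: 4 towers of shape [32, 16, 16, 288] are concatenated

    in height to [32, 64, 16, 288] and reshaped to [32, 1024, 288].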
|
""" |
|
with tf.variable_scope('pool_views_fn/STCK'): |
|
net = tf.concat(nets, 1) |
|
batch_size = net.get_shape().dims[0].value |
|
feature_size = net.get_shape().dims[3].value |
|
return tf.reshape(net, [batch_size, -1, feature_size]) |
|
|
|
def char_predictions(self, chars_logit): |
|
"""Returns confidence scores (softmax values) for predicted characters. |
|
|
|
Args: |
|
chars_logit: chars logits, a tensor with shape |
|
[batch_size x seq_length x num_char_classes] |
|
|
|
Returns: |
|
A tuple (ids, log_prob, scores), where: |
|
ids - predicted characters, a int32 tensor with shape |
|
[batch_size x seq_length]; |
|
log_prob - a log probability of all characters, a float tensor with |
|
shape [batch_size, seq_length, num_char_classes]; |
|
        scores - corresponding confidence scores for characters, a float

          tensor with shape [batch_size x seq_length].
|
""" |
|
log_prob = utils.logits_to_log_prob(chars_logit) |
|
ids = tf.to_int32(tf.argmax(log_prob, axis=2), name='predicted_chars') |
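
    # For each (batch, step) position select the softmax score of the

    # predicted class: one_hot(ids) is a boolean mask over classes and

    # boolean_mask flattens the selected scores to [batch_size * seq_length].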
|
mask = tf.cast( |
|
slim.one_hot_encoding(ids, self._params.num_char_classes), tf.bool) |
|
all_scores = tf.nn.softmax(chars_logit) |
|
selected_scores = tf.boolean_mask(all_scores, mask, name='char_scores') |
|
scores = tf.reshape(selected_scores, shape=(-1, self._params.seq_length)) |
|
return ids, log_prob, scores |
|
|
|
def encode_coordinates_fn(self, net): |
|
"""Adds one-hot encoding of coordinates to different views in the networks. |
|
|
|
For each "pixel" of a feature map it adds a onehot encoded x and y |
|
coordinates. |
|
|
|
Args: |
|
net: a tensor of shape=[batch_size, height, width, num_features] |
|
|
|
Returns: |
|
a tensor with the same height and width, but altered feature_size. |
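
    For example, a sketch: a net of shape [32, 16, 72, 288] becomes

    [32, 16, 72, 288 + 16 + 72] = [32, 16, 72, 376] when enabled.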
|
""" |
|
mparams = self._mparams['encode_coordinates_fn'] |
|
if mparams.enabled: |
|
batch_size, h, w, _ = net.shape.as_list() |
|
x, y = tf.meshgrid(tf.range(w), tf.range(h)) |
|
w_loc = slim.one_hot_encoding(x, num_classes=w) |
|
h_loc = slim.one_hot_encoding(y, num_classes=h) |
|
loc = tf.concat([h_loc, w_loc], 2) |
|
loc = tf.tile(tf.expand_dims(loc, 0), [batch_size, 1, 1, 1]) |
|
return tf.concat([net, loc], 3) |
|
else: |
|
return net |
|
|
|
def create_base(self, |
|
images, |
|
labels_one_hot, |
|
scope='AttentionOcr_v1', |
|
reuse=None): |
|
"""Creates a base part of the Model (no gradients, losses or summaries). |
|
|
|
Args: |
|
images: A tensor of shape [batch_size, height, width, channels]. |
|
labels_one_hot: Optional (can be None) one-hot encoding for ground truth |
|
labels. If provided the function will create a model for training. |
|
scope: Optional variable_scope. |
|
      reuse: whether or not the network and its variables should be reused. To

        be able to reuse, 'scope' must be given.
|
|
|
Returns: |
|
A named tuple OutputEndpoints. |
|
""" |
|
logging.debug('images: %s', images) |
|
is_training = labels_one_hot is not None |
|
with tf.variable_scope(scope, reuse=reuse): |
|
views = tf.split( |
|
value=images, num_or_size_splits=self._params.num_views, axis=2) |
|
logging.debug('Views=%d single view: %s', len(views), views[0]) |
|
|
|
nets = [ |
|
self.conv_tower_fn(v, is_training, reuse=(i != 0)) |
|
for i, v in enumerate(views) |
|
] |
|
logging.debug('Conv tower: %s', nets[0]) |
|
|
|
nets = [self.encode_coordinates_fn(net) for net in nets] |
|
logging.debug('Conv tower w/ encoded coordinates: %s', nets[0]) |
|
|
|
net = self.pool_views_fn(nets) |
|
logging.debug('Pooled views: %s', net) |
|
|
|
chars_logit = self.sequence_logit_fn(net, labels_one_hot) |
|
logging.debug('chars_logit: %s', chars_logit) |
|
|
|
predicted_chars, chars_log_prob, predicted_scores = ( |
|
self.char_predictions(chars_logit)) |
|
if self._charset: |
|
character_mapper = CharsetMapper(self._charset) |
|
predicted_text = character_mapper.get_text(predicted_chars) |
|
else: |
|
predicted_text = tf.constant([]) |
|
return OutputEndpoints( |
|
chars_logit=chars_logit, |
|
chars_log_prob=chars_log_prob, |
|
predicted_chars=predicted_chars, |
|
predicted_scores=predicted_scores, |
|
predicted_text=predicted_text) |
|
|
|
def create_loss(self, data, endpoints): |
|
"""Creates all losses required to train the model. |
|
|
|
Args: |
|
data: InputEndpoints namedtuple. |
|
      endpoints: OutputEndpoints namedtuple.
|
|
|
Returns: |
|
Total loss. |
|
""" |
|
    # NOTE: sequence_loss_fn registers its loss via tf.losses.add_loss, so the

    # total loss below also includes any other registered losses,

    # e.g. regularization losses.
self.sequence_loss_fn(endpoints.chars_logit, data.labels) |
|
total_loss = slim.losses.get_total_loss() |
|
tf.summary.scalar('TotalLoss', total_loss) |
|
return total_loss |
|
|
|
def label_smoothing_regularization(self, chars_labels, weight=0.1): |
|
"""Applies a label smoothing regularization. |
|
|
|
Uses the same method as in https://arxiv.org/abs/1512.00567. |
|
|
|
Args: |
|
      chars_labels: ground truth ids of characters,
|
shape=[batch_size, seq_length]; |
|
weight: label-smoothing regularization weight. |
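
    For example, a sketch: with weight=0.1 and num_char_classes=10, the ground

    truth class gets probability 0.9 + 0.1/10 = 0.91 and every other class

    gets 0.1/10 = 0.01.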
|
|
|
Returns: |
|
      A tensor with the same shape as the input.
|
""" |
|
one_hot_labels = tf.one_hot( |
|
chars_labels, depth=self._params.num_char_classes, axis=-1) |
|
pos_weight = 1.0 - weight |
|
neg_weight = weight / self._params.num_char_classes |
|
return one_hot_labels * pos_weight + neg_weight |
|
|
|
def sequence_loss_fn(self, chars_logits, chars_labels): |
|
"""Loss function for char sequence. |
|
|
|
Depending on values of hyper parameters it applies label smoothing and can |
|
also ignore all null chars after the first one. |
|
|
|
Args: |
|
chars_logits: logits for predicted characters, |
|
shape=[batch_size, seq_length, num_char_classes]; |
|
chars_labels: ground truth ids of characters, |
|
        shape=[batch_size, seq_length].
|
|
|
Returns: |
|
A Tensor with shape [batch_size] - the log-perplexity for each sequence. |
|
""" |
|
mparams = self._mparams['sequence_loss_fn'] |
|
with tf.variable_scope('sequence_loss_fn/SLF'): |
|
if mparams.label_smoothing > 0: |
|
smoothed_one_hot_labels = self.label_smoothing_regularization( |
|
chars_labels, mparams.label_smoothing) |
|
labels_list = tf.unstack(smoothed_one_hot_labels, axis=1) |
|
else: |
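        # NOTE: without label smoothing, labels remain sparse ids;

        # get_softmax_loss_fn then selects the sparse cross-entropy loss.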
|
|
|
|
|
labels_list = tf.unstack(chars_labels, axis=1) |
|
|
|
batch_size, seq_length, _ = chars_logits.shape.as_list() |
|
if mparams.ignore_nulls: |
|
weights = tf.ones((batch_size, seq_length), dtype=tf.float32) |
|
else: |
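        # Assume the null character is the last one in the charset and give

        # zero weight to timesteps labeled with it.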
|
|
|
reject_char = tf.constant( |
|
self._params.num_char_classes - 1, |
|
shape=(batch_size, seq_length), |
|
dtype=tf.int64) |
|
known_char = tf.not_equal(chars_labels, reject_char) |
|
weights = tf.to_float(known_char) |
|
|
|
logits_list = tf.unstack(chars_logits, axis=1) |
|
weights_list = tf.unstack(weights, axis=1) |
|
loss = tf.contrib.legacy_seq2seq.sequence_loss( |
|
logits_list, |
|
labels_list, |
|
weights_list, |
|
softmax_loss_function=get_softmax_loss_fn(mparams.label_smoothing), |
|
average_across_timesteps=mparams.average_across_timesteps) |
|
tf.losses.add_loss(loss) |
|
return loss |
|
|
|
def create_summaries(self, data, endpoints, charset, is_training): |
|
"""Creates all summaries for the model. |
|
|
|
Args: |
|
data: InputEndpoints namedtuple. |
|
endpoints: OutputEndpoints namedtuple. |
|
charset: A dictionary with mapping between character codes and |
|
unicode characters. Use the one provided by a dataset.charset. |
|
      is_training: If True, creates summaries prefixed for the training job;

        otherwise, for evaluation.
|
|
|
Returns: |
|
A list of evaluation ops |
|
""" |
|
|
|
def sname(label): |
|
prefix = 'train' if is_training else 'eval' |
|
return '%s/%s' % (prefix, label) |
|
|
|
max_outputs = 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tf.summary.image(sname('image'), data.images, max_outputs=max_outputs) |
|
|
|
if is_training: |
|
tf.summary.image( |
|
sname('image/orig'), data.images_orig, max_outputs=max_outputs) |
|
for var in tf.trainable_variables(): |
|
tf.summary.histogram(var.op.name, var) |
|
return None |
|
|
|
else: |
|
names_to_values = {} |
|
names_to_updates = {} |
|
|
|
def use_metric(name, value_update_tuple): |
|
names_to_values[name] = value_update_tuple[0] |
|
names_to_updates[name] = value_update_tuple[1] |
|
|
|
use_metric('CharacterAccuracy', |
|
metrics.char_accuracy( |
|
endpoints.predicted_chars, |
|
data.labels, |
|
streaming=True, |
|
rej_char=self._params.null_code)) |
|
|
|
use_metric('SequenceAccuracy', |
|
metrics.sequence_accuracy( |
|
endpoints.predicted_chars, |
|
data.labels, |
|
streaming=True, |
|
rej_char=self._params.null_code)) |
|
|
|
for name, value in names_to_values.items(): |
|
summary_name = 'eval/' + name |
|
tf.summary.scalar(summary_name, tf.Print(value, [value], summary_name)) |
|
return list(names_to_updates.values()) |
|
|
|
def create_init_fn_to_restore(self, master_checkpoint, |
|
inception_checkpoint=None): |
|
"""Creates an init operations to restore weights from various checkpoints. |
|
|
|
Args: |
|
master_checkpoint: path to a checkpoint which contains all weights for |
|
the whole model. |
|
inception_checkpoint: path to a checkpoint which contains weights for the |
|
inception part only. |
|
|
|
Returns: |
|
a function to run initialization ops. |
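
    For example, a sketch (the checkpoint path is a placeholder):

      init_fn = model.create_init_fn_to_restore('/path/to/model.ckpt')

      with tf.Session() as sess:

        init_fn(sess)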
|
""" |
|
all_assign_ops = [] |
|
all_feed_dict = {} |
|
|
|
def assign_from_checkpoint(variables, checkpoint): |
|
      logging.info('Request to restore %d weights from %s',

                   len(variables), checkpoint)
|
if not variables: |
|
logging.error('Can\'t find any variables to restore.') |
|
sys.exit(1) |
|
assign_op, feed_dict = slim.assign_from_checkpoint(checkpoint, variables) |
|
all_assign_ops.append(assign_op) |
|
all_feed_dict.update(feed_dict) |
|
|
|
    logging.info('variables_to_restore:\n%s',

                 utils.variables_to_restore().keys())

    logging.info('moving_average_variables:\n%s',

                 [v.op.name for v in tf.moving_average_variables()])

    logging.info('trainable_variables:\n%s',

                 [v.op.name for v in tf.trainable_variables()])
|
if master_checkpoint: |
|
assign_from_checkpoint(utils.variables_to_restore(), master_checkpoint) |
|
|
|
if inception_checkpoint: |
|
variables = utils.variables_to_restore( |
|
'AttentionOcr_v1/conv_tower_fn/INCE', strip_scope=True) |
|
assign_from_checkpoint(variables, inception_checkpoint) |
|
|
|
def init_assign_fn(sess): |
|
logging.info('Restoring checkpoint(s)') |
|
sess.run(all_assign_ops, all_feed_dict) |
|
|
|
return init_assign_fn |
|
|