|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""CNN-BiLSTM sentence encoder.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import tensorflow as tf |
|
from base import embeddings |
|
from model import model_helpers |
|
|
|
|
|
class Encoder(object):
  """CNN-BiLSTM sentence encoder.

  Builds the encoder graph in three stages:
    1. per-token representations: pretrained word embeddings, optionally
       concatenated with char-CNN features (`word_reprs`),
    2. a first bidirectional LSTM layer over those tokens
       (`uni_fw`, `uni_bw`, and their concatenation `uni_reprs`),
    3. a further stack of bidirectional LSTM layers on top
       (`bi_fw`, `bi_bw`, and their concatenation `bi_reprs`).
  All results are exposed as attributes set in the constructor.
  """

  def __init__(self, config, inputs, pretrained_embeddings):
    self._config = config
    self._inputs = inputs

    # Stage 1: per-token representations.
    self.word_reprs = self._get_word_reprs(pretrained_embeddings)
    # Stage 2: first BiLSTM layer; keep each direction as well as the concat.
    self.uni_fw, self.uni_bw = self._get_unidirectional_reprs(self.word_reprs)
    self.uni_reprs = tf.concat([self.uni_fw, self.uni_bw], axis=-1)
    # Stage 3: remaining BiLSTM stack over the concatenated first-layer output.
    self.bi_fw, self.bi_bw, self.bi_reprs = self._get_bidirectional_reprs(
        self.uni_reprs)

  def _get_word_reprs(self, pretrained_embeddings):
    """Embeds each token: word embeddings plus optional char-CNN features."""
    with tf.variable_scope('word_embeddings'):
      embedding_matrix = tf.get_variable(
          'word_embedding_matrix', initializer=pretrained_embeddings)
      embedded_words = tf.nn.embedding_lookup(
          embedding_matrix, self._inputs.words)
      embedded_words = tf.nn.dropout(embedded_words, self._inputs.keep_prob)
      # Learned scalar rescaling of the pretrained embeddings.
      embedded_words *= tf.get_variable('emb_scale', initializer=1.0)

    if not self._config.use_chars:
      return embedded_words

    with tf.variable_scope('char_embeddings'):
      char_matrix = tf.get_variable(
          'char_embeddings',
          shape=[embeddings.NUM_CHARS, self._config.char_embedding_size])
      embedded_chars = tf.nn.embedding_lookup(char_matrix, self._inputs.chars)
      dims = tf.shape(embedded_chars)
      # Fold batch and word axes together so conv1d sees
      # [batch * words, chars-per-word, char_embedding_size].
      embedded_chars = tf.reshape(
          embedded_chars,
          shape=[-1, dims[-2], self._config.char_embedding_size])
      cnn_outputs = []
      for filter_width in self._config.char_cnn_filter_widths:
        # One convolution + max-over-time pooling per filter width.
        activations = tf.nn.relu(tf.layers.conv1d(
            embedded_chars, self._config.char_cnn_n_filters, filter_width))
        pooled = tf.nn.dropout(tf.reduce_max(activations, axis=1),
                               self._inputs.keep_prob)
        # Unfold back to [batch, words, n_filters].
        pooled = tf.reshape(
            pooled, shape=[-1, dims[1], self._config.char_cnn_n_filters])
        cnn_outputs.append(pooled)
      return tf.concat([embedded_words] + cnn_outputs, axis=-1)

  def _get_unidirectional_reprs(self, word_reprs):
    """Runs the first BiLSTM layer; returns (forward, backward) outputs."""
    with tf.variable_scope('unidirectional_reprs'):
      # Pin the static feature size so the LSTM kernels can be constructed.
      input_size = self._config.word_embedding_size
      if self._config.use_chars:
        input_size += (len(self._config.char_cnn_filter_widths) *
                       self._config.char_cnn_n_filters)
      word_reprs.set_shape([None, None, input_size])
      cell_fw = model_helpers.multi_lstm_cell(
          self._config.unidirectional_sizes, self._inputs.keep_prob,
          self._config.projection_size)
      cell_bw = model_helpers.multi_lstm_cell(
          self._config.unidirectional_sizes, self._inputs.keep_prob,
          self._config.projection_size)
      (outputs_fw, outputs_bw), _ = tf.nn.bidirectional_dynamic_rnn(
          cell_fw, cell_bw, word_reprs,
          dtype=tf.float32,
          sequence_length=self._inputs.lengths,
          scope='unilstm')
      return outputs_fw, outputs_bw

  def _get_bidirectional_reprs(self, uni_reprs):
    """Stacks further BiLSTM layers; returns (fw, bw, concatenated) outputs."""
    with tf.variable_scope('bidirectional_reprs'):
      layer_input = uni_reprs
      fw, bw = None, None
      for layer_size in self._config.bidirectional_sizes:
        cell_fw = model_helpers.lstm_cell(
            layer_size, self._inputs.keep_prob, self._config.projection_size)
        cell_bw = model_helpers.lstm_cell(
            layer_size, self._inputs.keep_prob, self._config.projection_size)
        (fw, bw), _ = tf.nn.bidirectional_dynamic_rnn(
            cell_fw, cell_bw, layer_input,
            dtype=tf.float32,
            sequence_length=self._inputs.lengths,
            scope='bilstm')
        # Feed the concatenated directions into the next layer (if any).
        layer_input = tf.concat([fw, bw], axis=-1)
      return fw, bw, layer_input
|
|