"""tf.keras Models for NHNet.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
|
|
from __future__ import print_function |
|
|
|
from absl import logging |
|
import gin |
|
import tensorflow as tf |
|
from typing import Optional, Text |
|
|
|
from official.modeling import tf_utils |
|
from official.modeling.hyperparams import params_dict |
|
from official.nlp.modeling import networks |
|
from official.nlp.modeling.layers import multi_channel_attention |
|
from official.nlp.nhnet import configs |
|
from official.nlp.nhnet import decoder |
|
from official.nlp.nhnet import utils |
|
from official.nlp.transformer import beam_search |
|
|
|
|
|
def embedding_linear(embedding_matrix, x):
  """Uses embeddings as linear transformation weights."""
  with tf.name_scope("presoftmax_linear"):
    batch_size = tf.shape(x)[0]
    length = tf.shape(x)[1]
    hidden_size = tf.shape(x)[2]
    vocab_size = tf.shape(embedding_matrix)[0]

    x = tf.reshape(x, [-1, hidden_size])
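    # The output projection is tied to the embedding table: multiplying by the
    # transposed [vocab_size, hidden_size] matrix maps hidden states to
    # vocabulary logits.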
    logits = tf.matmul(x, embedding_matrix, transpose_b=True)

    return tf.reshape(logits, [batch_size, length, vocab_size])


def _add_sos_to_seq(seq, start_token_id):
  """Add a start sequence token while keeping seq length."""
  batch_size = tf.shape(seq)[0]
  seq_len = tf.shape(seq)[1]
  sos_ids = tf.ones([batch_size], tf.int32) * start_token_id
  targets = tf.concat([tf.expand_dims(sos_ids, axis=1), seq], axis=1)
  targets = targets[:, :-1]
  tf.assert_equal(tf.shape(targets), (batch_size, seq_len))
  return targets


def remove_sos_from_seq(seq, pad_token_id):
  """Remove the start sequence token while keeping seq length."""
  batch_size, seq_len = tf_utils.get_shape_list(seq, expected_rank=2)
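  # Drop the leading start token, then append a pad id so the output keeps the
  # original sequence length.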
  targets = seq[:, 1:]
  pad_ids = tf.ones([batch_size], tf.int32) * pad_token_id
  targets = tf.concat([targets, tf.expand_dims(pad_ids, axis=1)], axis=1)
  tf.assert_equal(tf.shape(targets), (batch_size, seq_len))
  return targets


class Bert2Bert(tf.keras.Model):
  """Bert2Bert encoder decoder model for training."""

  def __init__(self, params, bert_layer, decoder_layer, name=None):
    super(Bert2Bert, self).__init__(name=name)
    self.params = params
    if not bert_layer.built:
      raise ValueError("bert_layer should be built.")
    if not decoder_layer.built:
      raise ValueError("decoder_layer should be built.")
    self.bert_layer = bert_layer
    self.decoder_layer = decoder_layer

  def get_config(self):
    return {"params": self.params.as_dict()}

  def get_decode_logits(self,
                        decoder_inputs,
                        ids,
                        decoder_self_attention_bias,
                        step,
                        cache=None):
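    # With a cache, only the newest token is fed to the decoder and the
    # self-attention bias is sliced to the current step; `padded_decode` keeps
    # all shapes static by slicing a single row from the full-length bias.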
    if cache:
      if self.params.get("padded_decode", False):
        bias_shape = decoder_self_attention_bias.shape.as_list()
        self_attention_bias = tf.slice(
            decoder_self_attention_bias, [0, 0, step, 0],
            [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
      else:
        self_attention_bias = decoder_self_attention_bias[:, :, step:step + 1,
                                                          :step + 1]
      decoder_input = ids[:, -1:]
    else:
      self_attention_bias = decoder_self_attention_bias[:, :, :step + 1,
                                                        :step + 1]
      decoder_input = ids
    decoder_inputs["target_ids"] = decoder_input
    decoder_inputs["self_attention_bias"] = self_attention_bias
    if cache:
      decoder_outputs = self.decoder_layer(
          decoder_inputs,
          cache,
          decode_loop_step=step,
          padded_decode=self.params.get("padded_decode", False))
    else:
      decoder_outputs = self.decoder_layer(decoder_inputs)
    logits = embedding_linear(self.decoder_layer.embedding_lookup.embeddings,
                              decoder_outputs[:, -1:, :])
    logits = tf.squeeze(logits, axis=[1])
    return logits

  def _get_symbols_to_logits_fn(self, max_decode_length):
    """Returns a decoding function that calculates logits of the next tokens."""
    decoder_self_attention_bias = decoder.get_attention_bias(
        input_tensor=None,
        bias_type="decoder_self",
        max_length=max_decode_length)

    def _symbols_to_logits_fn(ids, i, cache):
      """Generate logits for next candidate IDs.

      Args:
        ids: Current decoded sequences. int tensor with shape [batch_size *
          beam_size, i + 1]
        i: Loop index
        cache: dictionary of values storing the encoder output, encoder-decoder
          attention bias, and previous decoder attention values.

      Returns:
        Tuple of
          (logits with shape [batch_size * beam_size, vocab_size],
           updated cache values)
      """
      decoder_inputs = dict(
          all_encoder_outputs=cache["all_encoder_outputs"],
          attention_bias=cache["attention_bias"])
      logits = self.get_decode_logits(
          decoder_inputs,
          ids,
          decoder_self_attention_bias,
          step=i,
          cache=cache if self.params.use_cache else None)
      return logits, cache

    return _symbols_to_logits_fn

  def train_decode(self, decode_outputs):
    logits = embedding_linear(self.decoder_layer.embedding_lookup.embeddings,
                              decode_outputs)
    decode_output_ids = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
    output_log_probs = tf.nn.log_softmax(logits, axis=-1)
    return logits, decode_output_ids, output_log_probs

  def predict_decode(self, start_token_ids, cache):
    symbols_to_logits_fn = self._get_symbols_to_logits_fn(self.params.len_title)
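    # Beam search returns decoded ids that include the initial start token at
    # position 0 (hence the `1:` slices in call), together with per-beam scores.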
    decoded_ids, scores = beam_search.sequence_beam_search(
        symbols_to_logits_fn=symbols_to_logits_fn,
        initial_ids=start_token_ids,
        initial_cache=cache,
        vocab_size=self.params.vocab_size,
        beam_size=self.params.beam_size,
        alpha=self.params.alpha,
        max_decode_length=self.params.len_title,
        padded_decode=self.params.get("padded_decode", False),
        eos_id=self.params.end_token_id)
    return decoded_ids, scores

  def _get_logits_for_decode_ids(self, decoder_inputs, top_decoded_ids):
    """Returns the logits for the decoded ids."""
    target_ids = _add_sos_to_seq(top_decoded_ids, self.params.start_token_id)
    decoder_inputs["self_attention_bias"] = decoder.get_attention_bias(
        target_ids, bias_type="decoder_self")
    decoder_inputs["target_ids"] = target_ids
    decoder_outputs = self.decoder_layer(decoder_inputs)
    logits = embedding_linear(self.decoder_layer.embedding_lookup.embeddings,
                              decoder_outputs)
    return logits

  def _init_cache(self, batch_size):
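    # Per-layer key/value caches for incremental decoding. With padded_decode,
    # the buffers are pre-allocated to the full title length so shapes stay
    # static; otherwise they start empty and grow one step at a time.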
    num_heads = self.params.num_decoder_attn_heads
    dim_per_head = self.params.hidden_size // num_heads
    init_decode_length = (
        self.params.len_title if self.params.get("padded_decode", False) else 0)
    cache = {}
    for layer in range(self.params.num_decoder_layers):
      cache[str(layer)] = {
          "key":
              tf.zeros(
                  [batch_size, init_decode_length, num_heads, dim_per_head],
                  dtype=tf.float32),
          "value":
              tf.zeros(
                  [batch_size, init_decode_length, num_heads, dim_per_head],
                  dtype=tf.float32)
      }
    return cache

def call(self, inputs, mode="train"): |
|
"""Implements call(). |
|
|
|
Args: |
|
inputs: a dictionary of tensors. |
|
mode: string, an enum for mode, train/eval. |
|
|
|
Returns: |
|
logits, decode_output_ids, output_log_probs for training. top_decoded_ids |
|
for eval. |
|
""" |
|
input_ids = inputs["input_ids"] |
|
input_mask = inputs["input_mask"] |
|
segment_ids = inputs["segment_ids"] |
|
all_encoder_outputs, _ = self.bert_layer( |
|
[input_ids, input_mask, segment_ids]) |
|
|
|
if mode not in ("train", "eval", "predict"): |
|
raise ValueError("Invalid call mode: %s" % mode) |
|
encoder_decoder_attention_bias = decoder.get_attention_bias( |
|
input_ids, |
|
bias_type="single_cross", |
|
padding_value=self.params.pad_token_id) |
|
if mode == "train": |
|
self_attention_bias = decoder.get_attention_bias( |
|
inputs["target_ids"], bias_type="decoder_self") |
|
decoder_inputs = dict( |
|
attention_bias=encoder_decoder_attention_bias, |
|
all_encoder_outputs=all_encoder_outputs, |
|
target_ids=inputs["target_ids"], |
|
self_attention_bias=self_attention_bias) |
|
decoder_outputs = self.decoder_layer(decoder_inputs) |
|
return self.train_decode(decoder_outputs) |
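
    # Eval/predict path: run beam search, seeding the cache with the encoder
    # outputs and the encoder-decoder attention bias.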
    batch_size = tf.shape(input_ids)[0]
    start_token_ids = tf.ones([batch_size],
                              tf.int32) * self.params.start_token_id
    if self.params.use_cache:
      cache = self._init_cache(batch_size)
    else:
      cache = {}
    cache["all_encoder_outputs"] = all_encoder_outputs
    cache["attention_bias"] = encoder_decoder_attention_bias
    decoded_ids, scores = self.predict_decode(start_token_ids, cache)
    if mode == "predict":
      return decoded_ids[:, :self.params.beam_size,
                         1:], scores[:, :self.params.beam_size]

    decoder_inputs = dict(
        attention_bias=encoder_decoder_attention_bias,
        all_encoder_outputs=all_encoder_outputs)
    top_decoded_ids = decoded_ids[:, 0, 1:]
    return self._get_logits_for_decode_ids(decoder_inputs, top_decoded_ids)


class NHNet(Bert2Bert):
  """NHNet model which performs multi-doc decoding."""

  def __init__(self, params, bert_layer, decoder_layer, name=None):
    super(NHNet, self).__init__(params, bert_layer, decoder_layer, name=name)
    self.doc_attention = multi_channel_attention.VotingAttention(
        num_heads=params.num_decoder_attn_heads,
        head_size=params.hidden_size // params.num_decoder_attn_heads)

  def _expand_doc_attention_probs(self, doc_attention_probs, target_length):
    """Expands doc attention probs to fit the decoding sequence length."""
    doc_attention_probs = tf.expand_dims(doc_attention_probs, axis=[1])
    doc_attention_probs = tf.expand_dims(doc_attention_probs, axis=[2])
    return tf.tile(doc_attention_probs,
                   [1, self.params.num_decoder_attn_heads, target_length, 1])

  def _get_symbols_to_logits_fn(self, max_decode_length):
    """Returns a decoding function that calculates logits of the next tokens."""
    decoder_self_attention_bias = decoder.get_attention_bias(
        input_tensor=None,
        bias_type="decoder_self",
        max_length=max_decode_length)

    def _symbols_to_logits_fn(ids, i, cache):
      """Generate logits for next candidate IDs."""
      if self.params.use_cache:
        target_length = 1
      else:
        target_length = i + 1
      decoder_inputs = dict(
          doc_attention_probs=self._expand_doc_attention_probs(
              cache["doc_attention_probs"], target_length),
          all_encoder_outputs=cache["all_encoder_outputs"],
          attention_bias=cache["attention_bias"])
      logits = self.get_decode_logits(
          decoder_inputs,
          ids,
          decoder_self_attention_bias,
          step=i,
          cache=cache if self.params.use_cache else None)
      return logits, cache

    return _symbols_to_logits_fn

def call(self, inputs, mode="training"): |
|
input_shape = tf_utils.get_shape_list(inputs["input_ids"], expected_rank=3) |
|
batch_size, num_docs, len_passage = (input_shape[0], input_shape[1], |
|
input_shape[2]) |
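    # Flatten [batch, num_docs, len_passage] -> [batch * num_docs, len_passage]
    # so the single-document BERT encoder processes every passage independently.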
    input_ids = tf.reshape(inputs["input_ids"], [-1, len_passage])
    input_mask = tf.reshape(inputs["input_mask"], [-1, len_passage])
    segment_ids = tf.reshape(inputs["segment_ids"], [-1, len_passage])
    all_encoder_outputs, _ = self.bert_layer(
        [input_ids, input_mask, segment_ids])
    encoder_outputs = tf.reshape(
        all_encoder_outputs[-1],
        [batch_size, num_docs, len_passage, self.params.hidden_size])
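    # A document counts as present for voting attention only if it has more
    # than two non-padding tokens, i.e. presumably something beyond the special
    # tokens.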
    doc_attention_mask = tf.reshape(
        tf.cast(
            tf.math.count_nonzero(input_mask, axis=1, dtype=tf.int32) > 2,
            tf.int32), [batch_size, num_docs])

    doc_attention_probs = self.doc_attention(encoder_outputs,
                                             doc_attention_mask)
    encoder_decoder_attention_bias = decoder.get_attention_bias(
        inputs["input_ids"],
        bias_type="multi_cross",
        padding_value=self.params.pad_token_id)

    if mode == "train":
      target_length = tf_utils.get_shape_list(
          inputs["target_ids"], expected_rank=2)[1]
      doc_attention_probs = self._expand_doc_attention_probs(
          doc_attention_probs, target_length)
      self_attention_bias = decoder.get_attention_bias(
          inputs["target_ids"], bias_type="decoder_self")
      decoder_inputs = dict(
          attention_bias=encoder_decoder_attention_bias,
          self_attention_bias=self_attention_bias,
          target_ids=inputs["target_ids"],
          all_encoder_outputs=encoder_outputs,
          doc_attention_probs=doc_attention_probs)
      decoder_outputs = self.decoder_layer(decoder_inputs)
      return self.train_decode(decoder_outputs)

    if self.params.use_cache:
      cache = self._init_cache(batch_size)
    else:
      cache = {}
    cache["all_encoder_outputs"] = [encoder_outputs]
    cache["attention_bias"] = encoder_decoder_attention_bias
    cache["doc_attention_probs"] = doc_attention_probs

    start_token_ids = tf.ones([batch_size],
                              tf.int32) * self.params.start_token_id
    decoded_ids, scores = self.predict_decode(start_token_ids, cache)
    if mode == "predict":
      return decoded_ids[:, :self.params.beam_size,
                         1:], scores[:, :self.params.beam_size]

    top_decoded_ids = decoded_ids[:, 0, 1:]
    target_length = tf_utils.get_shape_list(top_decoded_ids)[-1]
    decoder_inputs = dict(
        attention_bias=encoder_decoder_attention_bias,
        all_encoder_outputs=[encoder_outputs],
        doc_attention_probs=self._expand_doc_attention_probs(
            doc_attention_probs, target_length))
    return self._get_logits_for_decode_ids(decoder_inputs, top_decoded_ids)


def get_bert2bert_layers(params: configs.BERT2BERTConfig):
  """Creates a Bert2Bert stem model and returns Bert encoder/decoder.

  We use a functional-style stem model because we need all layers to be built
  in order to restore variables in a customized way. The layers are called with
  placeholder inputs to make them fully built.

  Args:
    params: ParamsDict.

  Returns:
    two keras Layers, bert_model_layer and decoder_layer
  """
  input_ids = tf.keras.layers.Input(
      shape=(None,), name="input_ids", dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(None,), name="input_mask", dtype=tf.int32)
  segment_ids = tf.keras.layers.Input(
      shape=(None,), name="segment_ids", dtype=tf.int32)
  target_ids = tf.keras.layers.Input(
      shape=(None,), name="target_ids", dtype=tf.int32)
  bert_config = utils.get_bert_config_from_params(params)
  bert_model_layer = networks.TransformerEncoder(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=None,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      return_all_encoder_outputs=True,
      name="bert_encoder")
  all_encoder_outputs, _ = bert_model_layer(
      [input_ids, input_mask, segment_ids])
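
  # The decoder shares the encoder's word-embedding layer; embedding_linear
  # above reuses the same weights as the output projection.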
  decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)

  cross_attention_bias = decoder.AttentionBias(bias_type="single_cross")(
      input_ids)
  self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
      target_ids)
  decoder_inputs = dict(
      attention_bias=cross_attention_bias,
      self_attention_bias=self_attention_bias,
      target_ids=target_ids,
      all_encoder_outputs=all_encoder_outputs)
  _ = decoder_layer(decoder_inputs)

  return bert_model_layer, decoder_layer


def get_nhnet_layers(params: configs.NHNetConfig):
  """Creates a multi-doc encoder/decoder.

  Args:
    params: ParamsDict.

  Returns:
    two keras Layers, bert_model_layer and decoder_layer
  """
  input_ids = tf.keras.layers.Input(
      shape=(None,), name="input_ids", dtype=tf.int32)
  input_mask = tf.keras.layers.Input(
      shape=(None,), name="input_mask", dtype=tf.int32)
  segment_ids = tf.keras.layers.Input(
      shape=(None,), name="segment_ids", dtype=tf.int32)
  bert_config = utils.get_bert_config_from_params(params)
  bert_model_layer = networks.TransformerEncoder(
      vocab_size=bert_config.vocab_size,
      hidden_size=bert_config.hidden_size,
      num_layers=bert_config.num_hidden_layers,
      num_attention_heads=bert_config.num_attention_heads,
      intermediate_size=bert_config.intermediate_size,
      activation=tf_utils.get_activation(bert_config.hidden_act),
      dropout_rate=bert_config.hidden_dropout_prob,
      attention_dropout_rate=bert_config.attention_probs_dropout_prob,
      sequence_length=None,
      max_sequence_length=bert_config.max_position_embeddings,
      type_vocab_size=bert_config.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=bert_config.initializer_range),
      return_all_encoder_outputs=True,
      name="bert_encoder")
  bert_model_layer([input_ids, input_mask, segment_ids])
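
  # Rebuild placeholder inputs with multi-document shapes so the decoder is
  # fully built before its variables are restored.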
  input_ids = tf.keras.layers.Input(
      shape=(None, None), name="input_ids", dtype=tf.int32)
  all_encoder_outputs = tf.keras.layers.Input((None, None, params.hidden_size),
                                              dtype=tf.float32)
  target_ids = tf.keras.layers.Input(
      shape=(None,), name="target_ids", dtype=tf.int32)
  doc_attention_probs = tf.keras.layers.Input(
      (params.num_decoder_attn_heads, None, None), dtype=tf.float32)

  decoder_layer = decoder.Decoder(params, bert_model_layer._embedding_layer)

  cross_attention_bias = decoder.AttentionBias(bias_type="multi_cross")(
      input_ids)
  self_attention_bias = decoder.AttentionBias(bias_type="decoder_self")(
      target_ids)
  decoder_inputs = dict(
      attention_bias=cross_attention_bias,
      self_attention_bias=self_attention_bias,
      target_ids=target_ids,
      all_encoder_outputs=all_encoder_outputs,
      doc_attention_probs=doc_attention_probs)
  _ = decoder_layer(decoder_inputs)

  return bert_model_layer, decoder_layer


def create_transformer_model(params,
                             init_checkpoint: Optional[Text] = None
                            ) -> tf.keras.Model:
  """A helper to create Transformer model."""
  bert_layer, decoder_layer = get_bert2bert_layers(params=params)
  model = Bert2Bert(
      params=params,
      bert_layer=bert_layer,
      decoder_layer=decoder_layer,
      name="transformer")

  if init_checkpoint:
    logging.info(
        "Checkpoint file %s found and restoring from "
        "initial checkpoint.", init_checkpoint)
    ckpt = tf.train.Checkpoint(model=model)
    ckpt.restore(init_checkpoint).expect_partial()

  return model


def create_bert2bert_model(
    params: configs.BERT2BERTConfig,
    cls=Bert2Bert,
    init_checkpoint: Optional[Text] = None) -> tf.keras.Model:
  """A helper to create Bert2Bert model."""
  bert_layer, decoder_layer = get_bert2bert_layers(params=params)
  if init_checkpoint:
    utils.initialize_bert2bert_from_pretrained_bert(bert_layer, decoder_layer,
                                                    init_checkpoint)
  return cls(
      params=params,
      bert_layer=bert_layer,
      decoder_layer=decoder_layer,
      name="bert2bert")


def create_nhnet_model(
    params: configs.NHNetConfig,
    cls=NHNet,
    init_checkpoint: Optional[Text] = None) -> tf.keras.Model:
  """A helper to create NHNet model."""
  bert_layer, decoder_layer = get_nhnet_layers(params=params)
  model = cls(
      params=params,
      bert_layer=bert_layer,
      decoder_layer=decoder_layer,
      name="nhnet")
  if init_checkpoint:
    logging.info(
        "Checkpoint file %s found and restoring from "
        "initial checkpoint.", init_checkpoint)
    if params.init_from_bert2bert:
      ckpt = tf.train.Checkpoint(model=model)
      ckpt.restore(init_checkpoint).assert_existing_objects_matched()
    else:
      utils.initialize_bert2bert_from_pretrained_bert(bert_layer,
                                                      decoder_layer,
                                                      init_checkpoint)
  return model


@gin.configurable
def get_model_params(model: Optional[Text] = "bert2bert",
                     config_class=None) -> params_dict.ParamsDict:
  """Helper function to convert config file to ParamsDict."""
  if model == "bert2bert":
    return configs.BERT2BERTConfig()
  elif model == "nhnet":
    return configs.NHNetConfig()
  elif config_class:
    return config_class()
  else:
    raise KeyError("The model type is not defined: %s" % model)


@gin.configurable
def create_model(model_type: Text,
                 params,
                 init_checkpoint: Optional[Text] = None):
  """A factory function to create different types of models."""
  if model_type == "bert2bert":
    return create_bert2bert_model(params, init_checkpoint=init_checkpoint)
  elif model_type == "nhnet":
    return create_nhnet_model(params, init_checkpoint=init_checkpoint)
  elif "transformer" in model_type:
    return create_transformer_model(
        params, init_checkpoint=init_checkpoint)
  else:
    raise KeyError("The model type is not defined: %s" % model_type)
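

# Example usage (a minimal sketch; the exact keys and shapes of `inputs` come
# from the NHNet input pipeline and are only summarized here):
#
#   params = get_model_params("nhnet")
#   model = create_model("nhnet", params, init_checkpoint=None)
#   # NHNet "train" mode expects input_ids/input_mask/segment_ids of shape
#   # [batch_size, num_docs, len_passage] plus target_ids of shape
#   # [batch_size, len_title]; "predict" mode returns (decoded_ids, scores)
#   # from beam search.
#   outputs = model(inputs, mode="predict")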