|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import tensorflow as tf |
|
|
|
from official.nlp.bert import bert_models |
|
from official.nlp.bert import configs as bert_configs |
|
from official.nlp.modeling import networks |
|
|
|
|
|
class BertModelsTest(tf.test.TestCase):
  """Smoke tests for the model-building functions in `bert_models`.

  Every test builds a model from a deliberately tiny BERT configuration and
  then inspects only static properties of the result: Python types, the
  number of outputs, and their static shapes. No training or inference is
  performed.
  """

  # Hidden width used by the test config; every expected shape below that
  # ends in this value refers to it.
  _HIDDEN_SIZE = 16

  def setUp(self):
    super(BertModelsTest, self).setUp()
    # Small (2 layers, 2 heads, hidden width 16) so that model construction
    # stays cheap; both dropout probabilities are zero.
    self._bert_test_config = bert_configs.BertConfig(
        vocab_size=30522,
        hidden_size=self._HIDDEN_SIZE,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=32,
        hidden_act='gelu',
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=128,
        type_vocab_size=2,
        initializer_range=0.02)

  def _check_output_shapes(self, outputs, expected_shapes):
    """Asserts that `outputs` is a list of tensors with the given shapes.

    Args:
      outputs: The `output` attribute of a Keras model under test.
      expected_shapes: A sequence with one expected static shape per element
        of `outputs`, each written as a list with `None` for unknown
        dimensions.
    """
    self.assertIsInstance(outputs, list)
    self.assertLen(outputs, len(expected_shapes))
    for tensor, shape in zip(outputs, expected_shapes):
      self.assertEqual(tensor.shape.as_list(), shape)

  def test_pretrain_model(self):
    pretrainer, bert_encoder = bert_models.pretrain_model(
        self._bert_test_config,
        seq_length=5,
        max_predictions_per_seq=2,
        initializer=None,
        use_next_sentence_label=True)
    self.assertIsInstance(pretrainer, tf.keras.Model)
    self.assertIsInstance(bert_encoder, networks.TransformerEncoder)

    # The pretraining model has a single rank-1 output whose only dimension
    # is unknown (presumably the batch dimension).
    self.assertEqual(pretrainer.output.shape.as_list(), [None])

    # The encoder yields two outputs: [batch, seq_length, hidden] and
    # [batch, hidden] (presumably the sequence and pooled outputs).
    self._check_output_shapes(
        bert_encoder.output,
        [[None, 5, self._HIDDEN_SIZE], [None, self._HIDDEN_SIZE]])

  def test_squad_model(self):
    squad, backbone = bert_models.squad_model(
        self._bert_test_config,
        max_seq_length=5,
        initializer=None,
        hub_module_url=None,
        hub_module_trainable=None)
    self.assertIsInstance(squad, tf.keras.Model)
    self.assertIsInstance(backbone, tf.keras.Model)

    # Two per-position outputs of shape [batch, max_seq_length]
    # (presumably start and end logits).
    self._check_output_shapes(squad.output, [[None, 5], [None, 5]])

    # The wrapped core model yields [batch, max_seq_length, hidden] and
    # [batch, hidden].
    self._check_output_shapes(
        backbone.output,
        [[None, 5, self._HIDDEN_SIZE], [None, self._HIDDEN_SIZE]])

  def test_classifier_model(self):
    classifier, backbone = bert_models.classifier_model(
        self._bert_test_config,
        num_labels=3,
        max_seq_length=5,
        final_layer_initializer=None,
        hub_module_url=None,
        hub_module_trainable=None)
    self.assertIsInstance(classifier, tf.keras.Model)
    self.assertIsInstance(backbone, tf.keras.Model)

    # One output with a trailing dimension equal to num_labels (3).
    self.assertEqual(classifier.output.shape.as_list(), [None, 3])

    # NOTE(review): unlike the SQuAD case, the first core output has a middle
    # dimension of 1 rather than max_seq_length (5); presumably the
    # classifier builder keeps only a single position. Confirm against
    # bert_models.classifier_model.
    self._check_output_shapes(
        backbone.output,
        [[None, 1, self._HIDDEN_SIZE], [None, self._HIDDEN_SIZE]])
|
|
|
|
|
# Run every test case in this module through TensorFlow's test runner when
# the file is executed directly.
if __name__ == '__main__':
  tf.test.main()
|
|