|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for BERT configurations and model instantiation."""
|
|
|
import tensorflow as tf |
|
|
|
from official.nlp.configs import bert |
|
from official.nlp.configs import encoders |
|
|
|
|
|
class BertModelsTest(tf.test.TestCase):
  """Checks that BERT pretrainer configs can be turned into models."""

  def _small_encoder_config(self):
    """Returns a tiny Transformer encoder config (vocab_size=10, one layer)."""
    return encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1)

  def _next_sentence_head_config(self):
    """Returns a 2-class head config named `next_sentence` (inner_dim=10)."""
    return bert.ClsHeadConfig(
        inner_dim=10, num_classes=2, name="next_sentence")

  def test_network_invocation(self):
    """Builds pretrainers with zero and one heads; duplicate names must fail."""
    # Encoder only, no classification heads.
    headless_config = bert.BertPretrainerConfig(
        encoder=self._small_encoder_config())
    _ = bert.instantiate_bertpretrainer_from_cfg(headless_config)

    # Encoder plus a single `next_sentence` head.
    single_head_config = bert.BertPretrainerConfig(
        encoder=self._small_encoder_config(),
        cls_heads=[self._next_sentence_head_config()])
    _ = bert.instantiate_bertpretrainer_from_cfg(single_head_config)

    # Two heads sharing the same name must be rejected with ValueError. Config
    # construction and instantiation are both inside the context manager, so
    # the error may come from either step.
    with self.assertRaises(ValueError):
      duplicate_head_config = bert.BertPretrainerConfig(
          encoder=self._small_encoder_config(),
          cls_heads=[
              self._next_sentence_head_config(),
              self._next_sentence_head_config(),
          ])
      _ = bert.instantiate_bertpretrainer_from_cfg(duplicate_head_config)

  def test_checkpoint_items(self):
    """Checkpoint items are keyed `encoder` and `next_sentence.pooler_dense`."""
    pretrainer_config = bert.BertPretrainerConfig(
        encoder=self._small_encoder_config(),
        cls_heads=[self._next_sentence_head_config()])
    pretrainer = bert.instantiate_bertpretrainer_from_cfg(pretrainer_config)

    # Order-insensitive comparison of the exposed checkpoint keys.
    expected_keys = ["encoder", "next_sentence.pooler_dense"]
    self.assertSameElements(pretrainer.checkpoint_items.keys(), expected_keys)
|
|
|
|
|
if __name__ == "__main__":
  # When this file is executed as a script, hand off to TensorFlow's test
  # runner so the test cases defined above are run.
  tf.test.main()
|
|