"""Tests for official.nlp.tasks.masked_lm.""" |
|
|
|
import tensorflow as tf |
|
|
|
from official.nlp.configs import bert |
|
from official.nlp.configs import encoders |
|
from official.nlp.tasks import masked_lm |
|
|
|
|
|
class MLMTaskTest(tf.test.TestCase):

  def test_task(self):
    config = masked_lm.MaskedLMConfig(
        network=bert.BertPretrainerConfig(
            encoders.TransformerEncoderConfig(vocab_size=30522, num_layers=1),
            num_masked_tokens=20,
            cls_heads=[
                bert.ClsHeadConfig(
                    inner_dim=10, num_classes=2, name="next_sentence")
            ]),
        train_data=bert.BertPretrainDataConfig(
            input_path="dummy",
            max_predictions_per_seq=20,
            seq_length=128,
            global_batch_size=1))
    task = masked_lm.MaskedLMTask(config)
    model = task.build_model()
    metrics = task.build_metrics()
    dataset = task.build_inputs(config.train_data)

    iterator = iter(dataset)
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
    # Run one training step and one validation step to exercise the task
    # end to end.
    task.train_step(next(iterator), model, optimizer, metrics=metrics)
    task.validation_step(next(iterator), model, metrics=metrics)


if __name__ == "__main__":
  tf.test.main()