|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for nlp.nhnet.models.""" |
|
|
|
import os |
|
|
|
from absl import logging |
|
from absl.testing import parameterized |
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
|
|
from tensorflow.python.distribute import combinations |
|
from tensorflow.python.distribute import strategy_combinations |
|
|
|
from official.nlp.nhnet import configs |
|
from official.nlp.nhnet import models |
|
from official.nlp.nhnet import utils |
|
|
|
|
|
def all_strategy_combinations():
  """Returns eager-mode test combinations over the distribution strategies."""
  strategies = [
      strategy_combinations.default_strategy,
      strategy_combinations.tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
      strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
      strategy_combinations.mirrored_strategy_with_two_gpus,
  ]
  return combinations.combine(distribution=strategies, mode="eager")
|
|
|
|
|
def distribution_forward_path(strategy,
                              model,
                              inputs,
                              batch_size,
                              mode="train"):
  """Runs one forward pass per batch of `inputs` under `strategy`.

  Args:
    strategy: A `tf.distribute.Strategy` used to distribute the dataset and
      replicate the forward computation.
    model: Model under test; invoked as `model(batch, mode=mode,
      training=False)`.
    inputs: A (possibly nested) structure of arrays to slice into a
      `tf.data.Dataset`.
    batch_size: Global batch size used when batching the dataset.
    mode: Forwarded to the model call (e.g. "train", "eval", "predict").

  Returns:
    A list with one entry per batch; each entry contains the per-replica
    local results of the model call.
  """
  dataset = tf.data.Dataset.from_tensor_slices(inputs)
  dataset = strategy.experimental_distribute_dataset(dataset.batch(batch_size))

  @tf.function
  def forward_step(batch):
    """Calculates evaluation metrics on distributed devices."""

    def _replica_fn(replica_batch):
      """Replicated accuracy calculation."""
      return model(replica_batch, mode=mode, training=False)

    per_replica = strategy.run(_replica_fn, args=(batch,))
    return tf.nest.map_structure(strategy.experimental_local_results,
                                 per_replica)

  return [forward_step(batch) for batch in dataset]
|
|
|
|
|
def process_decoded_ids(predictions, end_token_id):
  """Transforms decoded tensors to lists ending with END_TOKEN_ID."""
  if isinstance(predictions, tf.Tensor):
    predictions = predictions.numpy()
  last_dim = predictions.shape[-1]
  results = []
  for row in predictions.reshape((-1, last_dim)):
    token_ids = list(row)
    try:
      # Truncate at the first end token, if present.
      token_ids = token_ids[:token_ids.index(end_token_id)]
    except ValueError:
      pass  # No end token: keep the full sequence.
    results.append(token_ids)
  return results
|
|
|
|
|
class Bert2BertTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the model built by `models.create_bert2bert_model`."""

  def setUp(self):
    # Zero-arg super(): the file is TF2/eager-only, so Python 3 is assumed.
    super().setUp()
    self._config = utils.get_test_params()

  def test_model_creation(self):
    """Model builds and accepts a train-mode feature dict."""
    model = models.create_bert2bert_model(params=self._config)
    fake_ids = np.zeros((2, 10), dtype=np.int32)
    fake_inputs = {
        "input_ids": fake_ids,
        "input_mask": fake_ids,
        "segment_ids": fake_ids,
        "target_ids": fake_ids,
    }
    model(fake_inputs)

  @combinations.generate(all_strategy_combinations())
  def test_bert2bert_train_forward(self, distribution):
    """Train-mode forward pass yields one result per batch."""
    seq_length = 10

    with distribution.scope():
      batch_size = 2
      batches = 4
      fake_ids = np.zeros((batch_size * batches, seq_length), dtype=np.int32)
      fake_inputs = {
          "input_ids": fake_ids,
          "input_mask": fake_ids,
          "segment_ids": fake_ids,
          "target_ids": fake_ids,
      }
      model = models.create_bert2bert_model(params=self._config)
      results = distribution_forward_path(distribution, model, fake_inputs,
                                          batch_size)
      logging.info("Forward path results: %s", str(results))
      self.assertLen(results, batches)

  def _predict_with_overrides(self, overrides, init_checkpoint, fake_inputs):
    """Builds a model under `overrides`, restores the checkpoint, predicts.

    Args:
      overrides: Dict of config overrides applied non-strictly.
      init_checkpoint: Checkpoint path whose weights the new model restores.
      fake_inputs: Feature dict passed to the model in predict mode.

    Returns:
      The `(top_ids, scores)` pair returned by the model's predict call.
    """
    self._config.override(overrides, is_strict=False)
    model = models.create_bert2bert_model(params=self._config)
    ckpt = tf.train.Checkpoint(model=model)
    ckpt.restore(init_checkpoint).assert_existing_objects_matched()
    return model(fake_inputs, mode="predict")

  def test_bert2bert_decoding(self):
    """Cached and padded decoding match the uncached baseline."""
    seq_length = 10
    self._config.override(
        {
            "beam_size": 3,
            "len_title": seq_length,
            "alpha": 0.6,
        },
        is_strict=False)

    batch_size = 2
    fake_ids = np.zeros((batch_size, seq_length), dtype=np.int32)
    fake_inputs = {
        "input_ids": fake_ids,
        "input_mask": fake_ids,
        "segment_ids": fake_ids,
    }
    # Baseline: no padded decode, no cache. Save a checkpoint so every
    # variant below starts from identical weights.
    self._config.override({
        "padded_decode": False,
        "use_cache": False,
    },
                          is_strict=False)
    model = models.create_bert2bert_model(params=self._config)
    ckpt = tf.train.Checkpoint(model=model)
    init_checkpoint = ckpt.save(os.path.join(self.get_temp_dir(), "ckpt"))
    ckpt.restore(init_checkpoint).assert_existing_objects_matched()
    top_ids, scores = model(fake_inputs, mode="predict")

    # Decoding with the cache enabled must produce the same ids and scores.
    cached_top_ids, cached_scores = self._predict_with_overrides(
        {
            "padded_decode": False,
            "use_cache": True,
        }, init_checkpoint, fake_inputs)
    self.assertEqual(
        process_decoded_ids(top_ids, self._config.end_token_id),
        process_decoded_ids(cached_top_ids, self._config.end_token_id))
    self.assertAllClose(scores, cached_scores)

    # Padded decoding (fixed shapes, as used on TPU) must also match.
    padded_top_ids, padded_scores = self._predict_with_overrides(
        {
            "padded_decode": True,
            "use_cache": True,
        }, init_checkpoint, fake_inputs)
    self.assertEqual(
        process_decoded_ids(top_ids, self._config.end_token_id),
        process_decoded_ids(padded_top_ids, self._config.end_token_id))
    self.assertAllClose(scores, padded_scores)

  @combinations.generate(all_strategy_combinations())
  def test_bert2bert_eval(self, distribution):
    """Predict/eval forward passes yield one result per batch."""
    seq_length = 10
    # TPU decoding requires statically padded shapes.
    padded_decode = isinstance(distribution,
                               tf.distribute.experimental.TPUStrategy)
    self._config.override(
        {
            "beam_size": 3,
            "len_title": seq_length,
            "alpha": 0.6,
            "padded_decode": padded_decode,
        },
        is_strict=False)

    with distribution.scope():
      batch_size = 2
      batches = 4
      fake_ids = np.zeros((batch_size * batches, seq_length), dtype=np.int32)
      fake_inputs = {
          "input_ids": fake_ids,
          "input_mask": fake_ids,
          "segment_ids": fake_ids,
      }
      model = models.create_bert2bert_model(params=self._config)
      results = distribution_forward_path(
          distribution, model, fake_inputs, batch_size, mode="predict")
      self.assertLen(results, batches)
      results = distribution_forward_path(
          distribution, model, fake_inputs, batch_size, mode="eval")
      self.assertLen(results, batches)
|
|
|
|
|
class NHNetTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for NHNet layer construction and checkpoint compatibility."""

  def setUp(self):
    # Zero-arg super(): the file is TF2/eager-only, so Python 3 is assumed.
    super().setUp()
    self._nhnet_config = configs.NHNetConfig()
    self._nhnet_config.override(utils.get_test_params().as_dict())
    self._bert2bert_config = configs.BERT2BERTConfig()
    self._bert2bert_config.override(utils.get_test_params().as_dict())

  def _count_params(self, layer, trainable_only=True):
    """Returns the count of all model parameters, or just trainable ones."""
    if not trainable_only:
      return layer.count_params()
    # Guard clause above lets the trainable path read flat (no else-branch).
    return int(
        np.sum([
            tf.keras.backend.count_params(p) for p in layer.trainable_weights
        ]))

  def test_create_nhnet_layers(self):
    """NHNet layers match Bert2Bert layers in trainable parameter count."""
    single_doc_bert, single_doc_decoder = models.get_bert2bert_layers(
        self._bert2bert_config)
    multi_doc_bert, multi_doc_decoder = models.get_nhnet_layers(
        self._nhnet_config)
    self.assertEqual(
        self._count_params(multi_doc_bert), self._count_params(single_doc_bert))
    self.assertEqual(
        self._count_params(multi_doc_decoder),
        self._count_params(single_doc_decoder))

  def test_checkpoint_restore(self):
    """NHNet built from a Bert2Bert checkpoint copies every weight."""
    bert2bert_model = models.create_bert2bert_model(self._bert2bert_config)
    ckpt = tf.train.Checkpoint(model=bert2bert_model)
    init_checkpoint = ckpt.save(os.path.join(self.get_temp_dir(), "ckpt"))
    nhnet_model = models.create_nhnet_model(
        params=self._nhnet_config, init_checkpoint=init_checkpoint)
    source_weights = (
        bert2bert_model.bert_layer.trainable_weights +
        bert2bert_model.decoder_layer.trainable_weights)
    dest_weights = (
        nhnet_model.bert_layer.trainable_weights +
        nhnet_model.decoder_layer.trainable_weights)
    # Weight lists are assumed to be in matching order — compare pairwise.
    for source_weight, dest_weight in zip(source_weights, dest_weights):
      self.assertAllClose(source_weight.numpy(), dest_weight.numpy())

  @combinations.generate(all_strategy_combinations())
  def test_nhnet_train_forward(self, distribution):
    """Train-mode forward pass yields one result per batch."""
    seq_length = 10

    with distribution.scope():
      batch_size = 2
      num_docs = 2
      batches = 4
      fake_ids = np.zeros((batch_size * batches, num_docs, seq_length),
                          dtype=np.int32)
      fake_inputs = {
          "input_ids":
              fake_ids,
          "input_mask":
              fake_ids,
          "segment_ids":
              fake_ids,
          "target_ids":
              np.zeros((batch_size * batches, seq_length * 2), dtype=np.int32),
      }
      model = models.create_nhnet_model(params=self._nhnet_config)
      results = distribution_forward_path(distribution, model, fake_inputs,
                                          batch_size)
      logging.info("Forward path results: %s", str(results))
      self.assertLen(results, batches)

  @combinations.generate(all_strategy_combinations())
  def test_nhnet_eval(self, distribution):
    """Predict/eval forward passes yield one result per batch."""
    seq_length = 10
    # TPU decoding requires statically padded shapes.
    padded_decode = isinstance(distribution,
                               tf.distribute.experimental.TPUStrategy)
    self._nhnet_config.override(
        {
            "beam_size": 4,
            "len_title": seq_length,
            "alpha": 0.6,
            "multi_channel_cross_attention": True,
            "padded_decode": padded_decode,
        },
        is_strict=False)

    with distribution.scope():
      batch_size = 2
      num_docs = 2
      batches = 4
      fake_ids = np.zeros((batch_size * batches, num_docs, seq_length),
                          dtype=np.int32)
      fake_inputs = {
          "input_ids": fake_ids,
          "input_mask": fake_ids,
          "segment_ids": fake_ids,
          "target_ids": np.zeros((batch_size * batches, 5), dtype=np.int32),
      }
      model = models.create_nhnet_model(params=self._nhnet_config)
      results = distribution_forward_path(
          distribution, model, fake_inputs, batch_size, mode="predict")
      self.assertLen(results, batches)
      results = distribution_forward_path(
          distribution, model, fake_inputs, batch_size, mode="eval")
      self.assertLen(results, batches)
|
|
|
|
|
# Run all test cases via the TensorFlow test runner when executed directly.
if __name__ == "__main__":
  tf.test.main()
|
|