diff --git a/modeling/official/README-TPU.md b/modeling/official/README-TPU.md
new file mode 100644
index 0000000000000000000000000000000000000000..a6031c44f0338e18762d8d7183299931aa3a285f
--- /dev/null
+++ b/modeling/official/README-TPU.md
@@ -0,0 +1,32 @@
+# Offically Supported TensorFlow 2.1+ Models on Cloud TPU
+
+## Natural Language Processing
+
+* [bert](nlp/bert): A powerful pre-trained language representation model:
+ BERT, which stands for Bidirectional Encoder Representations from
+ Transformers.
+ [BERT FineTuning with Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/bert-2.x) provides step by step instructions on Cloud TPU training. You can look [Bert MNLI Tensorboard.dev metrics](https://tensorboard.dev/experiment/LijZ1IrERxKALQfr76gndA) for MNLI fine tuning task.
+* [transformer](nlp/transformer): A transformer model to translate the WMT
+ English to German dataset.
+ [Training transformer on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/transformer-2.x) for step by step instructions on Cloud TPU training.
+
+## Computer Vision
+
+* [efficientnet](vision/image_classification): A family of convolutional
+ neural networks that scale by balancing network depth, width, and
+ resolution and can be used to classify ImageNet's dataset of 1000 classes.
+ See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/KnaWjrq5TXGfv0NW5m7rpg/#scalars).
+* [mnist](vision/image_classification): A basic model to classify digits
+ from the MNIST dataset. See [Running MNIST on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/mnist-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/mIah5lppTASvrHqWrdr6NA).
+* [mask-rcnn](vision/detection): An object detection and instance segmentation model. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/LH7k0fMsRwqUAcE09o9kPA).
+* [resnet](vision/image_classification): A deep residual network that can
+ be used to classify ImageNet's dataset of 1000 classes.
+ See [Training ResNet on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/resnet-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/CxlDK8YMRrSpYEGtBRpOhg).
+* [retinanet](vision/detection): A fast and powerful object detector. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/b8NRnWU3TqG6Rw0UxueU6Q).
+* [shapemask](vision/detection): An object detection and instance segmentation model using shape priors. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/ZbXgVoc6Rf6mBRlPj0JpLA).
+
+## Recommendation
+* [dlrm](recommendation/ranking): [Deep Learning Recommendation Model for
+Personalization and Recommendation Systems](https://arxiv.org/abs/1906.00091).
+* [dcn v2](recommendation/ranking): [Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535).
+* [ncf](recommendation): Neural Collaborative Filtering. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/0k3gKjZlR1ewkVTRyLB6IQ).
diff --git a/modeling/official/README.md b/modeling/official/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0235710b2c3e7fb917d7655fe96605c1c7f62923
--- /dev/null
+++ b/modeling/official/README.md
@@ -0,0 +1,166 @@
+
+
+
+
+# TensorFlow Official Models
+
+The TensorFlow official models are a collection of models
+that use TensorFlow’s high-level APIs.
+They are intended to be well-maintained, tested, and kept up to date
+with the latest TensorFlow API.
+
+They should also be reasonably optimized for fast performance while still
+being easy to read.
+These models are used as end-to-end tests, ensuring that the models run
+with the same or improved speed and performance with each new TensorFlow build.
+
+The API documentation of the latest stable release is published to
+[tensorflow.org](https://www.tensorflow.org/api_docs/python/tfm).
+
+## More models to come!
+
+The team is actively developing new models.
+In the near future, we will add:
+
+* State-of-the-art language understanding models.
+* State-of-the-art image classification models.
+* State-of-the-art object detection and instance segmentation models.
+* State-of-the-art video classification models.
+
+## Table of Contents
+
+- [Models and Implementations](#models-and-implementations)
+ * [Computer Vision](#computer-vision)
+ + [Image Classification](#image-classification)
+ + [Object Detection and Segmentation](#object-detection-and-segmentation)
+ + [Video Classification](#video-classification)
+ * [Natural Language Processing](#natural-language-processing)
+ * [Recommendation](#recommendation)
+- [How to get started with the official models](#how-to-get-started-with-the-official-models)
+- [Contributions](#contributions)
+
+## Models and Implementations
+
+### [Computer Vision](vision/README.md)
+
+#### Image Classification
+
+| Model | Reference (Paper) |
+|-------|-------------------|
+| [ResNet](vision/MODEL_GARDEN.md) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
+| [ResNet-RS](vision/MODEL_GARDEN.md) | [Revisiting ResNets: Improved Training and Scaling Strategies](https://arxiv.org/abs/2103.07579) |
+| [EfficientNet](vision/MODEL_GARDEN.md) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) |
+| [Vision Transformer](vision/MODEL_GARDEN.md) | [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) |
+
+#### Object Detection and Segmentation
+
+| Model | Reference (Paper) |
+|-------|-------------------|
+| [RetinaNet](vision/MODEL_GARDEN.md) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
+| [Mask R-CNN](vision/MODEL_GARDEN.md) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
+| [YOLO](projects/yolo/README.md) | [YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors](https://arxiv.org/abs/2207.02696) |
+| [SpineNet](vision/MODEL_GARDEN.md) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
+| [Cascade RCNN-RS and RetinaNet-RS](vision/MODEL_GARDEN.md) | [Simple Training Strategies and Model Scaling for Object Detection](https://arxiv.org/abs/2107.00057)|
+
+#### Video Classification
+
+| Model | Reference (Paper) |
+|-------|-------------------|
+| [Mobile Video Networks (MoViNets)](projects/movinet) | [MoViNets: Mobile Video Networks for Efficient Video Recognition](https://arxiv.org/abs/2103.11511) |
+
+### [Natural Language Processing](nlp/README.md)
+
+#### Pre-trained Language Model
+
+| Model | Reference (Paper) |
+|-------|-------------------|
+| [ALBERT](nlp/MODEL_GARDEN.md#available-model-configs) | [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) |
+| [BERT](nlp/MODEL_GARDEN.md#available-model-configs) | [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) |
+| [ELECTRA](nlp/tasks/electra_task.py) | [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://arxiv.org/abs/2003.10555) |
+
+
+#### Neural Machine Translation
+
+| Model | Reference (Paper) |
+|-------|-------------------|
+| [Transformer](nlp/MODEL_GARDEN.md#available-model-configs) | [Attention Is All You Need](https://arxiv.org/abs/1706.03762) |
+
+#### Natural Language Generation
+
+| Model | Reference (Paper) |
+|-------|-------------------|
+| [NHNet (News Headline generation model)](projects/nhnet) | [Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386) |
+
+
+#### Knowledge Distillation
+
+| Model | Reference (Paper) |
+|-------|-------------------|
+| [MobileBERT](projects/mobilebert) | [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) |
+
+### Recommendation
+
+Model | Reference (Paper)
+-------------------------------- | -----------------
+[DLRM](recommendation/ranking) | [Deep Learning Recommendation Model for Personalization and Recommendation Systems](https://arxiv.org/abs/1906.00091)
+[DCN v2](recommendation/ranking) | [Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535)
+[NCF](recommendation) | [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031)
+
+## How to get started with the official models
+
+* The official models in the master branch are developed using
+[master branch of TensorFlow 2](https://github.com/tensorflow/tensorflow/tree/master).
+When you clone (the repository) or download (`pip` binary) master branch of
+official models , master branch of TensorFlow gets downloaded as a
+dependency. This is equivalent to the following.
+
+```shell
+pip3 install tf-models-nightly
+pip3 install tensorflow-text-nightly # when model uses `nlp` packages
+```
+
+* Incase of stable versions, targeting a specific release, Tensorflow-models
+repository version numbers match with the target TensorFlow release. For
+example, [TensorFlow-models v2.8.x](https://github.com/tensorflow/models/releases/tag/v2.8.0)
+is compatible with [TensorFlow v2.8.x](https://github.com/tensorflow/tensorflow/releases/tag/v2.8.0).
+This is equivalent to the following:
+
+```shell
+pip3 install tf-models-official==2.8.0
+pip3 install tensorflow-text==2.8.0 # when models in uses `nlp` packages
+```
+
+Starting from 2.9.x release, we release the modeling library as
+`tensorflow_models` package and users can `import tensorflow_models` directly to
+access to the exported symbols. If you are
+using the latest nightly version or github code directly, please follow the
+docstrings in the github.
+
+Please follow the below steps before running models in this repository.
+
+### Requirements
+
+* The latest TensorFlow Model Garden release and the latest TensorFlow 2
+ * If you are on a version of TensorFlow earlier than 2.2, please
+upgrade your TensorFlow to [the latest TensorFlow 2](https://www.tensorflow.org/install/).
+* Python 3.7+
+
+Our integration tests run with Python 3.7. Although Python 3.6 should work, we
+don't recommend earlier versions.
+
+### Installation
+
+Please check [here](https://github.com/tensorflow/models#Installation) for the
+instructions.
+
+Available pypi packages:
+
+* [tf-models-official](https://pypi.org/project/tf-models-official/)
+* [tf-models-nightly](https://pypi.org/project/tf-models-nightly/): nightly
+release with the latest changes.
+* [tf-models-no-deps](https://pypi.org/project/tf-models-no-deps/): without
+`tensorflow` and `tensorflow-text` in the `install_requires` list.
+
+## Contributions
+
+If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
diff --git a/modeling/official/__init__.py b/modeling/official/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/common/__init__.py b/modeling/official/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f338592c943c69c8ca66bc1f0981a619ea10e27
--- /dev/null
+++ b/modeling/official/common/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
diff --git a/modeling/official/common/dataset_fn.py b/modeling/official/common/dataset_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..420099abd37abe2c33afc137bba587d4c1a4f877
--- /dev/null
+++ b/modeling/official/common/dataset_fn.py
@@ -0,0 +1,44 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility library for picking an appropriate dataset function."""
+
+import functools
+from typing import Any, Callable, Type, Union
+
+import tensorflow as tf, tf_keras
+
+PossibleDatasetType = Union[Type[tf.data.Dataset], Callable[[tf.Tensor], Any]]
+
+
+def pick_dataset_fn(file_type: str) -> PossibleDatasetType:
+ if file_type == 'tfrecord':
+ return tf.data.TFRecordDataset
+ if file_type == 'tfrecord_compressed':
+ return functools.partial(tf.data.TFRecordDataset, compression_type='GZIP')
+ raise ValueError('Unrecognized file_type: {}'.format(file_type))
diff --git a/modeling/official/common/distribute_utils.py b/modeling/official/common/distribute_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..35841ff4587e9d18fb935ae79e88b0770ee2a13f
--- /dev/null
+++ b/modeling/official/common/distribute_utils.py
@@ -0,0 +1,233 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helper functions for running models in a distributed setting."""
+
+import json
+import os
+import tensorflow as tf, tf_keras
+
+
+def _collective_communication(all_reduce_alg):
+ """Return a CollectiveCommunication based on all_reduce_alg.
+
+ Args:
+ all_reduce_alg: a string specifying which collective communication to pick,
+ or None.
+
+ Returns:
+ tf.distribute.experimental.CollectiveCommunication object
+
+ Raises:
+ ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"]
+ """
+ collective_communication_options = {
+ None: tf.distribute.experimental.CollectiveCommunication.AUTO,
+ "ring": tf.distribute.experimental.CollectiveCommunication.RING,
+ "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
+ }
+ if all_reduce_alg not in collective_communication_options:
+ raise ValueError(
+ "When used with `multi_worker_mirrored`, valid values for "
+ "all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
+ all_reduce_alg))
+ return collective_communication_options[all_reduce_alg]
+
+
+def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
+ """Return a CrossDeviceOps based on all_reduce_alg and num_packs.
+
+ Args:
+ all_reduce_alg: a string specifying which cross device op to pick, or None.
+ num_packs: an integer specifying number of packs for the cross device op.
+
+ Returns:
+ tf.distribute.CrossDeviceOps object or None.
+
+ Raises:
+ ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
+ """
+ if all_reduce_alg is None:
+ return None
+ mirrored_all_reduce_options = {
+ "nccl": tf.distribute.NcclAllReduce,
+ "hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
+ }
+ if all_reduce_alg not in mirrored_all_reduce_options:
+ raise ValueError(
+ "When used with `mirrored`, valid values for all_reduce_alg are "
+ "[`nccl`, `hierarchical_copy`]. Supplied value: {}".format(
+ all_reduce_alg))
+ cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
+ return cross_device_ops_class(num_packs=num_packs)
+
+
+def tpu_initialize(tpu_address):
+ """Initializes TPU for TF 2.x training.
+
+ Args:
+ tpu_address: string, bns address of master TPU worker.
+
+ Returns:
+ A TPUClusterResolver.
+ """
+ cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
+ tpu=tpu_address)
+ if tpu_address not in ("", "local"):
+ tf.config.experimental_connect_to_cluster(cluster_resolver)
+ tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
+ return cluster_resolver
+
+
+def get_distribution_strategy(distribution_strategy="mirrored",
+ num_gpus=0,
+ all_reduce_alg=None,
+ num_packs=1,
+ tpu_address=None,
+ **kwargs):
+ """Return a Strategy for running the model.
+
+ Args:
+ distribution_strategy: a string specifying which distribution strategy to
+ use. Accepted values are "off", "one_device", "mirrored",
+ "parameter_server", "multi_worker_mirrored", and "tpu" -- case
+ insensitive. "tpu" means to use TPUStrategy using `tpu_address`.
+ "off" means to use the default strategy which is obtained from
+ tf.distribute.get_strategy (for details on the default strategy, see
+ https://www.tensorflow.org/guide/distributed_training#default_strategy).
+ num_gpus: Number of GPUs to run this model.
+ all_reduce_alg: Optional. Specifies which algorithm to use when performing
+ all-reduce. For `MirroredStrategy`, valid values are "nccl" and
+ "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
+ "ring" and "nccl". If None, DistributionStrategy will choose based on
+ device topology.
+ num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
+ or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
+ tpu_address: Optional. String that represents TPU to connect to. Must not be
+ None if `distribution_strategy` is set to `tpu`.
+ **kwargs: Additional kwargs for internal usages.
+
+ Returns:
+ tf.distribute.Strategy object.
+ Raises:
+ ValueError: if `distribution_strategy` is "off" or "one_device" and
+ `num_gpus` is larger than 1; or `num_gpus` is negative or if
+ `distribution_strategy` is `tpu` but `tpu_address` is not specified.
+ """
+ del kwargs
+ if num_gpus < 0:
+ raise ValueError("`num_gpus` can not be negative.")
+
+ if not isinstance(distribution_strategy, str):
+ msg = ("distribution_strategy must be a string but got: %s." %
+ (distribution_strategy,))
+ if distribution_strategy == False: # pylint: disable=singleton-comparison,g-explicit-bool-comparison
+ msg += (" If you meant to pass the string 'off', make sure you add "
+ "quotes around 'off' so that yaml interprets it as a string "
+ "instead of a bool.")
+ raise ValueError(msg)
+
+ distribution_strategy = distribution_strategy.lower()
+ if distribution_strategy == "off":
+ if num_gpus > 1:
+ raise ValueError(f"When {num_gpus} GPUs are specified, "
+ "distribution_strategy flag cannot be set to `off`.")
+ # Return the default distribution strategy.
+ return tf.distribute.get_strategy()
+
+ if distribution_strategy == "tpu":
+ # When tpu_address is an empty string, we communicate with local TPUs.
+ cluster_resolver = tpu_initialize(tpu_address)
+ return tf.distribute.TPUStrategy(cluster_resolver)
+
+ if distribution_strategy == "multi_worker_mirrored":
+ return tf.distribute.experimental.MultiWorkerMirroredStrategy(
+ communication=_collective_communication(all_reduce_alg))
+
+ if distribution_strategy == "one_device":
+ if num_gpus == 0:
+ return tf.distribute.OneDeviceStrategy("device:CPU:0")
+ if num_gpus > 1:
+ raise ValueError("`OneDeviceStrategy` can not be used for more than "
+ "one device.")
+ return tf.distribute.OneDeviceStrategy("device:GPU:0")
+
+ if distribution_strategy == "mirrored":
+ if num_gpus == 0:
+ devices = ["device:CPU:0"]
+ else:
+ devices = ["device:GPU:%d" % i for i in range(num_gpus)]
+ return tf.distribute.MirroredStrategy(
+ devices=devices,
+ cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))
+
+ if distribution_strategy == "parameter_server":
+ cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
+ return tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
+
+ raise ValueError("Unrecognized Distribution Strategy: %r" %
+ distribution_strategy)
+
+
+def configure_cluster(worker_hosts=None, task_index=-1):
+ """Set multi-worker cluster spec in TF_CONFIG environment variable.
+
+ Args:
+ worker_hosts: comma-separated list of worker ip:port pairs.
+ task_index: index of the worker.
+
+ Returns:
+ Number of workers in the cluster.
+ """
+ tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
+ if tf_config:
+ num_workers = (
+ len(tf_config["cluster"].get("chief", [])) +
+ len(tf_config["cluster"].get("worker", [])))
+ elif worker_hosts:
+ workers = worker_hosts.split(",")
+ num_workers = len(workers)
+ if num_workers > 1 and task_index < 0:
+ raise ValueError("Must specify task_index when number of workers > 1")
+ task_index = 0 if num_workers == 1 else task_index
+ os.environ["TF_CONFIG"] = json.dumps({
+ "cluster": {
+ "worker": workers
+ },
+ "task": {
+ "type": "worker",
+ "index": task_index
+ }
+ })
+ else:
+ num_workers = 1
+ return num_workers
+
+
+def get_strategy_scope(strategy):
+ if strategy:
+ strategy_scope = strategy.scope()
+ else:
+ strategy_scope = DummyContextManager()
+
+ return strategy_scope
+
+
+class DummyContextManager(object):
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, *args):
+ pass
diff --git a/modeling/official/common/distribute_utils_test.py b/modeling/official/common/distribute_utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f510f808c2c42dc96d013d426c318a15efb1a75
--- /dev/null
+++ b/modeling/official/common/distribute_utils_test.py
@@ -0,0 +1,124 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for distribution util functions."""
+
+import sys
+import tensorflow as tf, tf_keras
+
+from official.common import distribute_utils
+
+TPU_TEST = 'test_tpu' in sys.argv[0]
+
+
+class DistributeUtilsTest(tf.test.TestCase):
+ """Tests for distribute util functions."""
+
+ def test_invalid_args(self):
+ with self.assertRaisesRegex(ValueError, '`num_gpus` can not be negative.'):
+ _ = distribute_utils.get_distribution_strategy(num_gpus=-1)
+
+ with self.assertRaisesRegex(ValueError,
+ '.*If you meant to pass the string .*'):
+ _ = distribute_utils.get_distribution_strategy(
+ distribution_strategy=False, num_gpus=0)
+ with self.assertRaisesRegex(ValueError, 'When 2 GPUs are specified.*'):
+ _ = distribute_utils.get_distribution_strategy(
+ distribution_strategy='off', num_gpus=2)
+ with self.assertRaisesRegex(ValueError,
+ '`OneDeviceStrategy` can not be used.*'):
+ _ = distribute_utils.get_distribution_strategy(
+ distribution_strategy='one_device', num_gpus=2)
+
+ def test_one_device_strategy_cpu(self):
+ ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
+ self.assertEquals(ds.num_replicas_in_sync, 1)
+ self.assertEquals(len(ds.extended.worker_devices), 1)
+ self.assertIn('CPU', ds.extended.worker_devices[0])
+
+ def test_one_device_strategy_gpu(self):
+ ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=1)
+ self.assertEquals(ds.num_replicas_in_sync, 1)
+ self.assertEquals(len(ds.extended.worker_devices), 1)
+ self.assertIn('GPU', ds.extended.worker_devices[0])
+
+ def test_mirrored_strategy(self):
+ # CPU only.
+ _ = distribute_utils.get_distribution_strategy(num_gpus=0)
+ # 5 GPUs.
+ ds = distribute_utils.get_distribution_strategy(num_gpus=5)
+ self.assertEquals(ds.num_replicas_in_sync, 5)
+ self.assertEquals(len(ds.extended.worker_devices), 5)
+ for device in ds.extended.worker_devices:
+ self.assertIn('GPU', device)
+
+ _ = distribute_utils.get_distribution_strategy(
+ distribution_strategy='mirrored',
+ num_gpus=2,
+ all_reduce_alg='nccl',
+ num_packs=2)
+ with self.assertRaisesRegex(
+ ValueError,
+ 'When used with `mirrored`, valid values for all_reduce_alg are.*'):
+ _ = distribute_utils.get_distribution_strategy(
+ distribution_strategy='mirrored',
+ num_gpus=2,
+ all_reduce_alg='dummy',
+ num_packs=2)
+
+ def test_mwms(self):
+ distribute_utils.configure_cluster(worker_hosts=None, task_index=-1)
+ ds = distribute_utils.get_distribution_strategy(
+ 'multi_worker_mirrored', all_reduce_alg='nccl')
+ self.assertIsInstance(
+ ds, tf.distribute.experimental.MultiWorkerMirroredStrategy)
+
+ with self.assertRaisesRegex(
+ ValueError,
+ 'When used with `multi_worker_mirrored`, valid values.*'):
+ _ = distribute_utils.get_distribution_strategy(
+ 'multi_worker_mirrored', all_reduce_alg='dummy')
+
+ def test_no_strategy(self):
+ ds = distribute_utils.get_distribution_strategy('off')
+ self.assertIs(ds, tf.distribute.get_strategy())
+
+ def test_tpu_strategy(self):
+ if not TPU_TEST:
+ self.skipTest('Only Cloud TPU VM instances can have local TPUs.')
+ with self.assertRaises(ValueError):
+ _ = distribute_utils.get_distribution_strategy('tpu')
+
+ ds = distribute_utils.get_distribution_strategy('tpu', tpu_address='local')
+ self.assertIsInstance(
+ ds, tf.distribute.TPUStrategy)
+
+ def test_invalid_strategy(self):
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'distribution_strategy must be a string but got: False. If'):
+ distribute_utils.get_distribution_strategy(False)
+ with self.assertRaisesRegexp(
+ ValueError, 'distribution_strategy must be a string but got: 1'):
+ distribute_utils.get_distribution_strategy(1)
+
+ def test_get_strategy_scope(self):
+ ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
+ with distribute_utils.get_strategy_scope(ds):
+ self.assertIs(tf.distribute.get_strategy(), ds)
+ with distribute_utils.get_strategy_scope(None):
+ self.assertIsNot(tf.distribute.get_strategy(), ds)
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/common/flags.py b/modeling/official/common/flags.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5a471522f961b2b1e3fe78d4174b2f73fe9ae97
--- /dev/null
+++ b/modeling/official/common/flags.py
@@ -0,0 +1,114 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The central place to define flags."""
+
+from absl import flags
+
+
+def define_flags():
+ """Defines flags.
+
+ All flags are defined as optional, but in practice most models use some of
+ these flags and so mark_flags_as_required() should be called after calling
+ this function. Typically, 'experiment', 'mode', and 'model_dir' are required.
+ For example:
+
+ ```
+ from absl import flags
+ from official.common import flags as tfm_flags # pylint: disable=line-too-long
+ ...
+ tfm_flags.define_flags()
+ flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
+ ```
+
+ The reason all flags are optional is because unit tests often do not set or
+ use any of the flags.
+ """
+ flags.DEFINE_string(
+ 'experiment', default=None, help=
+ 'The experiment type registered, specifying an ExperimentConfig.')
+
+ flags.DEFINE_enum(
+ 'mode',
+ default=None,
+ enum_values=[
+ 'train', 'eval', 'train_and_eval', 'continuous_eval',
+ 'continuous_train_and_eval', 'train_and_validate',
+ 'train_and_post_eval'
+ ],
+ help='Mode to run: `train`, `eval`, `train_and_eval`, '
+ '`continuous_eval`, `continuous_train_and_eval` and '
+ '`train_and_validate` (which is not implemented in '
+ 'the open source version).')
+
+ flags.DEFINE_string(
+ 'model_dir',
+ default=None,
+ help='The directory where the model and training/evaluation summaries'
+ 'are stored.')
+
+ flags.DEFINE_multi_string(
+ 'config_file',
+ default=None,
+ help='YAML/JSON files which specifies overrides. The override order '
+ 'follows the order of args. Note that each file '
+ 'can be used as an override template to override the default parameters '
+ 'specified in Python. If the same parameter is specified in both '
+ '`--config_file` and `--params_override`, `config_file` will be used '
+ 'first, followed by params_override.')
+
+ flags.DEFINE_string(
+ 'params_override',
+ default=None,
+ help='a YAML/JSON string or a YAML file which specifies additional '
+ 'overrides over the default parameters and those specified in '
+ '`--config_file`. Note that this is supposed to be used only to override '
+ 'the model parameters, but not the parameters like TPU specific flags. '
+ 'One canonical use case of `--config_file` and `--params_override` is '
+ 'users first define a template config file using `--config_file`, then '
+ 'use `--params_override` to adjust the minimal set of tuning parameters, '
+ 'for example setting up different `train_batch_size`. The final override '
+ 'order of parameters: default_model_params --> params from config_file '
+ '--> params in params_override. See also the help message of '
+ '`--config_file`.')
+
+ # The libraries rely on gin often make mistakes that include flags inside
+ # the library files which causes conflicts.
+ try:
+ flags.DEFINE_multi_string(
+ 'gin_file', default=None, help='List of paths to the config files.')
+ except flags.DuplicateFlagError:
+ pass
+
+ try:
+ flags.DEFINE_multi_string(
+ 'gin_params',
+ default=None,
+ help='Newline separated list of Gin parameter bindings.')
+ except flags.DuplicateFlagError:
+ pass
+
+ flags.DEFINE_string(
+ 'tpu',
+ default=None,
+ help='The Cloud TPU to use for training. This should be either the name '
+ 'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
+ 'url.')
+
+ flags.DEFINE_string(
+ 'tf_data_service', default=None, help='The tf.data service address')
+
+ flags.DEFINE_string(
+ 'tpu_platform', default=None, help='TPU platform type.')
diff --git a/modeling/official/common/registry_imports.py b/modeling/official/common/registry_imports.py
new file mode 100644
index 0000000000000000000000000000000000000000..05ba505882b452d94a049dc12fb4e1b69058f68a
--- /dev/null
+++ b/modeling/official/common/registry_imports.py
@@ -0,0 +1,20 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""All necessary imports for registration."""
+# pylint: disable=unused-import
+from official import vision
+from official.nlp import tasks
+from official.nlp.configs import experiment_configs
+from official.utils.testing import mock_task
diff --git a/modeling/official/common/streamz_counters.py b/modeling/official/common/streamz_counters.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6066ab54f9f7bdf26c44b6a927e40b582f0b351
--- /dev/null
+++ b/modeling/official/common/streamz_counters.py
@@ -0,0 +1,27 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Global streamz counters."""
+
+from tensorflow.python.eager import monitoring
+
+
+progressive_policy_creation_counter = monitoring.Counter(
+ "/tensorflow/training/fast_training/progressive_policy_creation",
+ "Counter for the number of ProgressivePolicy creations.")
+
+
+stack_vars_to_vars_call_counter = monitoring.Counter(
+ "/tensorflow/training/fast_training/tf_vars_to_vars",
+ "Counter for the number of low-level stacking API calls.")
diff --git a/modeling/official/core/__init__.py b/modeling/official/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..82b4f7623831e38c5c7cb6558cafac691ed4c232
--- /dev/null
+++ b/modeling/official/core/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Core is shared by both `nlp` and `vision`."""
+
+from official.core import actions
+from official.core import base_task
+from official.core import base_trainer
+from official.core import config_definitions
+from official.core import exp_factory
+from official.core import export_base
+from official.core import file_writers
+from official.core import input_reader
+from official.core import registry
+from official.core import savedmodel_checkpoint_manager
+from official.core import task_factory
+from official.core import tf_example_builder
+from official.core import tf_example_feature_key
+from official.core import train_lib
+from official.core import train_utils
diff --git a/modeling/official/core/actions.py b/modeling/official/core/actions.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5555614f7700a7c2ff344a0c37af0ae3262959b
--- /dev/null
+++ b/modeling/official/core/actions.py
@@ -0,0 +1,236 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Provides TFM orbit actions and associated helper functions/classes."""
+
+import os
+from typing import List
+from absl import logging
+
+import gin
+import orbit
+import tensorflow as tf, tf_keras
+
+from official.core import base_trainer
+from official.core import config_definitions
+from official.modeling import optimization
+
+
+class PruningAction:
+ """Train action to updates pruning related information.
+
+ This action updates pruning steps at the end of trainig loop, and log
+ pruning metrics to tensorboard.
+
+ This action must be used when training a pruned model to avoid pruning error.
+ """
+
+ def __init__(
+ self,
+ export_dir: str,
+ model: tf_keras.Model,
+ optimizer: tf_keras.optimizers.Optimizer,
+ ):
+ """Initializes the instance.
+
+ Args:
+ export_dir: `str` for the export directory of the pruning summaries.
+ model: `tf_keras.Model` model instance used for training. This will be
+ used to assign a pruning step to each prunable weight.
+ optimizer: `tf_keras.optimizers.Optimizer` optimizer instance used for
+ training. This will be used to find the current training steps.
+ """
+ # TODO(b/221490190): Avoid local import when the bug is fixed.
+ import tensorflow_model_optimization as tfmot # pylint: disable=g-import-not-at-top
+ self._optimizer = optimizer
+ self.update_pruning_step = tfmot.sparsity.keras.UpdatePruningStep()
+ self.update_pruning_step.set_model(model)
+ self.update_pruning_step.on_train_begin()
+
+ self.pruning_summaries = tfmot.sparsity.keras.PruningSummaries(
+ log_dir=export_dir)
+ model.optimizer = optimizer
+ self.pruning_summaries.set_model(model)
+
+ def __call__(self, output: orbit.runner.Output):
+ """Update pruning step and log pruning summaries.
+
+ Args:
+ output: The train output.
+ """
+ self.update_pruning_step.on_epoch_end(batch=None)
+ self.pruning_summaries.on_epoch_begin(epoch=None)
+
+
+class EMACheckpointing:
+ """Eval action to save checkpoint with average weights when EMA is used.
+
+ This action swaps the weights of the model with the average weights, then it
+ saves the checkpoint under export_dir/ema_checkpoints. Checkpointing is
+ expensive for large models, so doing this action in eval is more efficient
+ than training.
+ """
+
+ def __init__(self,
+ export_dir: str,
+ optimizer: tf_keras.optimizers.Optimizer,
+ checkpoint: tf.train.Checkpoint,
+ max_to_keep: int = 1):
+ """Initializes the instance.
+
+ Args:
+ export_dir: `str` for the export directory of the EMA average weights.
+ optimizer: `tf_keras.optimizers.Optimizer` optimizer instance used for
+ training. This will be used to swap the model weights with the average
+ weigths.
+ checkpoint: `tf.train.Checkpoint` instance.
+ max_to_keep: `int` for max checkpoints to keep in ema_checkpoints subdir.
+ """
+ if not isinstance(optimizer, optimization.ExponentialMovingAverage):
+ raise ValueError('Optimizer has to be instance of'
+ 'optimization.ExponentialMovingAverage for'
+ 'EMACheckpointing action')
+
+ export_dir = os.path.join(export_dir, 'ema_checkpoints')
+ tf.io.gfile.makedirs(os.path.dirname(export_dir))
+ self._optimizer = optimizer
+ self._checkpoint = checkpoint
+ self._checkpoint_manager = tf.train.CheckpointManager(
+ checkpoint,
+ directory=export_dir,
+ max_to_keep=max_to_keep,
+ checkpoint_name='average_weights')
+
+ def __call__(self, output: orbit.runner.Output):
+ """Swaps model weights, and saves the checkpoint.
+
+ Args:
+ output: The train or eval output.
+ """
+ self._optimizer.swap_weights()
+ self._checkpoint_manager.save(checkpoint_number=self._optimizer.iterations)
+ self._optimizer.swap_weights()
+
+
+class RecoveryAction:
+ """Train action to recover from loss blowup.
+
+ Checks the loss value by the given threshold. If applicable, recover the
+ model by reading the checkpoint on disk.
+ """
+
+ def __init__(self, checkpoint_manager: tf.train.CheckpointManager):
+ self.checkpoint_manager = checkpoint_manager
+
+ def __call__(self, _):
+ """Recovers the training by triggering checkpoint restoration."""
+ # Loads the previous good checkpoint.
+ checkpoint_path = self.checkpoint_manager.restore_or_initialize()
+ logging.warning('Recovering the model from checkpoint: %s.',
+ checkpoint_path)
+
+
+class RecoveryCondition:
+ """Recovery Condition."""
+
+ def __init__(self,
+ global_step: tf.Variable,
+ loss_upper_bound: float,
+ recovery_begin_steps: int = 0,
+ recovery_max_trials: int = 3):
+ self.recover_counter = 0
+ self.recovery_begin_steps = recovery_begin_steps
+ self.recovery_max_trials = recovery_max_trials
+ self.loss_upper_bound = loss_upper_bound
+ self.global_step = global_step
+
+ def __call__(self, outputs: orbit.runner.Output):
+ loss_value = outputs['training_loss']
+ if tf.math.is_nan(loss_value):
+ self.recover_counter += 1
+ if self.recover_counter > self.recovery_max_trials:
+ raise RuntimeError(
+ 'The loss value is NaN after training loop and it happens %d times.'
+ % self.recover_counter)
+ return True
+ if (self.global_step >= self.recovery_begin_steps and
+ loss_value > self.loss_upper_bound):
+ self.recover_counter += 1
+ if self.recover_counter > self.recovery_max_trials:
+ raise RuntimeError(
+ f'The loss value is {loss_value}, which is larger than the bound {self.loss_upper_bound}, happens {self.recover_counter} times.'
+ )
+ return True
+ return False
+
+
+@gin.configurable
+def get_eval_actions(params: config_definitions.ExperimentConfig,
+ trainer: base_trainer.Trainer,
+ model_dir: str) -> List[orbit.Action]:
+ """Gets eval actions for TFM trainer."""
+ eval_actions = []
+ # Adds ema checkpointing action to save the average weights under
+ # ema_checkpoints subdir.
+ if isinstance(trainer.optimizer, optimization.ExponentialMovingAverage):
+ eval_actions.append(
+ EMACheckpointing(
+ export_dir=model_dir,
+ optimizer=trainer.optimizer,
+ checkpoint=trainer.checkpoint,
+ max_to_keep=params.trainer.max_to_keep))
+
+ return eval_actions
+
+
+@gin.configurable
+def get_train_actions(
+ params: config_definitions.ExperimentConfig, trainer: base_trainer.Trainer,
+ model_dir: str,
+ checkpoint_manager: tf.train.CheckpointManager) -> List[orbit.Action]:
+ """Gets train actions for TFM trainer."""
+ train_actions = []
+ # Adds pruning callback actions.
+ if hasattr(params.task, 'pruning') and params.task.pruning:
+ train_actions.append(
+ PruningAction(
+ export_dir=model_dir,
+ model=trainer.model,
+ optimizer=trainer.optimizer))
+
+ if params.trainer.recovery_max_trials >= 0:
+ recovery_condition = RecoveryCondition(
+ global_step=trainer.global_step,
+ loss_upper_bound=params.trainer.loss_upper_bound,
+ recovery_begin_steps=params.trainer.recovery_begin_steps,
+ recovery_max_trials=params.trainer.recovery_max_trials,
+ )
+ recover_action = orbit.actions.ConditionalAction(
+ condition=recovery_condition,
+ action=RecoveryAction(checkpoint_manager),
+ )
+ train_actions.append(recover_action)
+
+ if (
+ params.trainer.preemption_on_demand_checkpoint
+ and trainer.strategy.cluster_resolver
+ ):
+ on_demand_checkpoint_action = orbit.actions.SaveCheckpointIfPreempted(
+ trainer.strategy.cluster_resolver,
+ checkpoint_manager,
+ trainer.global_step,
+ keep_running_after_save=True,
+ )
+ train_actions.append(on_demand_checkpoint_action)
+ return train_actions
diff --git a/modeling/official/core/actions_test.py b/modeling/official/core/actions_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d5ee54501fafe903fc74018a43d87eee78d6b63
--- /dev/null
+++ b/modeling/official/core/actions_test.py
@@ -0,0 +1,131 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for TFM actions."""
+
+import os
+
+from absl.testing import parameterized
+import numpy as np
+import orbit
+import tensorflow as tf, tf_keras
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.core import actions
+from official.modeling import optimization
+
+
+class TestModel(tf_keras.Model):
+
+ def __init__(self):
+ super().__init__()
+ self.value = tf.Variable(0.0)
+ self.dense = tf_keras.layers.Dense(2)
+ _ = self.dense(tf.zeros((2, 2), tf.float32))
+
+ def call(self, x, training=None):
+ return self.value + x
+
+
+class ActionsTest(tf.test.TestCase, parameterized.TestCase):
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.cloud_tpu_strategy,
+ strategy_combinations.one_device_strategy,
+ ],))
+ def test_ema_checkpointing(self, distribution):
+ with distribution.scope():
+ directory = self.create_tempdir()
+ model = TestModel()
+ optimizer = tf_keras.optimizers.SGD()
+ optimizer = optimization.ExponentialMovingAverage(
+ optimizer, trainable_weights_only=False)
+
+ # Creats average weights for the model variables. Average weights are
+ # initialized to zero.
+ optimizer.shadow_copy(model)
+ checkpoint = tf.train.Checkpoint(model=model)
+
+ # Changes model.value to 3, average value is still 0.
+ model.value.assign(3)
+
+ # Checks model.value is 3
+ self.assertEqual(model(0.), 3)
+ ema_action = actions.EMACheckpointing(directory, optimizer, checkpoint)
+
+ ema_action({})
+ self.assertNotEmpty(
+ tf.io.gfile.glob(os.path.join(directory, 'ema_checkpoints')))
+
+ checkpoint.read(
+ tf.train.latest_checkpoint(
+ os.path.join(directory, 'ema_checkpoints')))
+
+ # Checks model.value is 0 after swapping.
+ self.assertEqual(model(0.), 0)
+
+ # Raises an error for a normal optimizer.
+ with self.assertRaisesRegex(ValueError,
+ 'Optimizer has to be instance of.*'):
+ _ = actions.EMACheckpointing(directory, tf_keras.optimizers.SGD(),
+ checkpoint)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.cloud_tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ ],))
+ def test_recovery_condition(self, distribution):
+ with distribution.scope():
+ global_step = orbit.utils.create_global_step()
+ recover_condition = actions.RecoveryCondition(
+ global_step, loss_upper_bound=0.5, recovery_max_trials=2)
+ outputs = {'training_loss': 0.6}
+ self.assertTrue(recover_condition(outputs))
+ self.assertTrue(recover_condition(outputs))
+ with self.assertRaises(RuntimeError):
+ recover_condition(outputs)
+
+ global_step = orbit.utils.create_global_step()
+ recover_condition = actions.RecoveryCondition(
+ global_step, loss_upper_bound=0.5, recovery_max_trials=2)
+ outputs = {'training_loss': tf.constant([np.nan], tf.float32)}
+ self.assertTrue(recover_condition(outputs))
+ self.assertTrue(recover_condition(outputs))
+ with self.assertRaises(RuntimeError):
+ recover_condition(outputs)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.one_device_strategy_gpu,
+ strategy_combinations.one_device_strategy,
+ ],))
+ def test_pruning(self, distribution):
+ with distribution.scope():
+ directory = self.get_temp_dir()
+ model = TestModel()
+ optimizer = tf_keras.optimizers.SGD()
+ pruning = actions.PruningAction(directory, model, optimizer)
+
+ pruning({})
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/core/base_task.py b/modeling/official/core/base_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb95b1f3ca047dcc6e0419d81355b384e648b1e8
--- /dev/null
+++ b/modeling/official/core/base_task.py
@@ -0,0 +1,360 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines the base task abstraction."""
+import abc
+import functools
+from typing import Optional
+
+from absl import logging
+import tensorflow as tf, tf_keras
+
+from official.core import config_definitions
+from official.modeling import optimization
+from official.modeling import performance
+from official.modeling.privacy import configs
+from official.modeling.privacy import ops
+
+OptimizationConfig = optimization.OptimizationConfig
+RuntimeConfig = config_definitions.RuntimeConfig
+DifferentialPrivacyConfig = configs.DifferentialPrivacyConfig
+
+
+class Task(tf.Module, metaclass=abc.ABCMeta):
+ """A single-replica view of training procedure.
+
+ Tasks provide artifacts for training/validation procedures, including
+ loading/iterating over Datasets, training/validation steps, calculating the
+ loss and customized metrics with reduction.
+ """
+
+ # Special keys in train/validate step returned logs.
+ loss = "loss"
+
+ def __init__(self,
+ params,
+ logging_dir: Optional[str] = None,
+ name: Optional[str] = None):
+ """Task initialization.
+
+ Args:
+ params: the task configuration instance, which can be any of dataclass,
+ ConfigDict, namedtuple, etc.
+ logging_dir: a string pointing to where the model, summaries etc. will be
+ saved. You can also write additional stuff in this directory.
+ name: the task name.
+ """
+ super().__init__(name=name)
+ self._task_config = params
+ self._logging_dir = (
+ logging_dir or ""
+ ) # Empty directory hints current working dir.
+
+ @property
+ def task_config(self):
+ return self._task_config
+
+ @property
+ def logging_dir(self) -> str:
+ return self._logging_dir
+
+ @classmethod
+ def create_optimizer(cls, optimizer_config: OptimizationConfig,
+ runtime_config: Optional[RuntimeConfig] = None,
+ dp_config: Optional[DifferentialPrivacyConfig] = None):
+ """Creates an TF optimizer from configurations.
+
+ Args:
+ optimizer_config: the parameters of the Optimization settings.
+ runtime_config: the parameters of the runtime.
+ dp_config: the parameter of differential privacy.
+
+ Returns:
+ A tf.optimizers.Optimizer object.
+ """
+ gradient_transformers = None
+ if dp_config is not None:
+ logging.info("Adding differential privacy transform with config %s.",
+ dp_config.as_dict())
+ noise_stddev = dp_config.clipping_norm * dp_config.noise_multiplier
+ gradient_transformers = [
+ functools.partial(
+ ops.clip_l2_norm, l2_norm_clip=dp_config.clipping_norm),
+ functools.partial(
+ ops.add_noise, noise_stddev=noise_stddev)
+ ]
+
+ opt_factory = optimization.OptimizerFactory(optimizer_config)
+ optimizer = opt_factory.build_optimizer(
+ opt_factory.build_learning_rate(),
+ gradient_transformers=gradient_transformers
+ )
+ # Configuring optimizer when loss_scale is set in runtime config. This helps
+ # avoiding overflow/underflow for float16 computations.
+ if runtime_config:
+ optimizer = performance.configure_optimizer(
+ optimizer,
+ use_float16=runtime_config.mixed_precision_dtype == "float16",
+ loss_scale=runtime_config.loss_scale)
+
+ return optimizer
+
+ def initialize(self, model: tf_keras.Model):
+ """[Optional] A callback function used as CheckpointManager's init_fn.
+
+ This function will be called when no checkpoint is found for the model.
+ If there is a checkpoint, the checkpoint will be loaded and this function
+ will not be called. You can use this callback function to load a pretrained
+ checkpoint, saved under a directory other than the model_dir.
+
+ Args:
+ model: The keras.Model built or used by this task.
+ """
+ ckpt_dir_or_file = self.task_config.init_checkpoint
+ logging.info("Trying to load pretrained checkpoint from %s",
+ ckpt_dir_or_file)
+ if ckpt_dir_or_file and tf.io.gfile.isdir(ckpt_dir_or_file):
+ ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
+ if not ckpt_dir_or_file:
+ logging.info("No checkpoint file found from %s. Will not load.",
+ ckpt_dir_or_file)
+ return
+
+ if hasattr(model, "checkpoint_items"):
+ checkpoint_items = model.checkpoint_items
+ else:
+ checkpoint_items = dict(model=model)
+ ckpt = tf.train.Checkpoint(**checkpoint_items)
+ status = ckpt.read(ckpt_dir_or_file)
+ status.expect_partial().assert_existing_objects_matched()
+ logging.info("Finished loading pretrained checkpoint from %s",
+ ckpt_dir_or_file)
+
+ def build_model(self) -> tf_keras.Model:
+ """[Optional] Creates model architecture.
+
+ Returns:
+ A model instance.
+ """ # pytype: disable=bad-return-type # typed-keras
+
+ @abc.abstractmethod
+ def build_inputs(self,
+ params,
+ input_context: Optional[tf.distribute.InputContext] = None):
+ """Returns a dataset or a nested structure of dataset functions.
+
+ Dataset functions define per-host datasets with the per-replica batch size.
+ With distributed training, this method runs on remote hosts.
+
+ Args:
+ params: hyperparams to create input pipelines, which can be any of
+ dataclass, ConfigDict, namedtuple, etc.
+ input_context: optional distribution input pipeline context.
+
+ Returns:
+ A nested structure of per-replica input functions.
+ """
+
+ def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
+ """Standard interface to compute losses.
+
+ Args:
+ labels: optional label tensors.
+ model_outputs: a nested structure of output tensors.
+ aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
+
+ Returns:
+ The total loss tensor.
+ """
+ del model_outputs, labels
+
+ if aux_losses is None:
+ losses = [tf.constant(0.0, dtype=tf.float32)]
+ else:
+ losses = aux_losses
+ total_loss = tf.add_n(losses)
+ return total_loss
+
+ def build_metrics(self, training: bool = True):
+ """Gets streaming metrics for training/validation."""
+ del training
+ return []
+
+ def process_metrics(self, metrics, labels, model_outputs, **kwargs):
+ """Process and update metrics.
+
+ Called when using custom training loop API.
+
+ Args:
+ metrics: a nested structure of metrics objects. The return of function
+ self.build_metrics.
+ labels: a tensor or a nested structure of tensors.
+ model_outputs: a tensor or a nested structure of tensors. For example,
+ output of the keras model built by self.build_model.
+ **kwargs: other args.
+ """
+ for metric in metrics:
+ metric.update_state(labels, model_outputs)
+
+ def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
+ """Process and update compiled_metrics.
+
+ call when using compile/fit API.
+
+ Args:
+ compiled_metrics: the compiled metrics (model.compiled_metrics).
+ labels: a tensor or a nested structure of tensors.
+ model_outputs: a tensor or a nested structure of tensors. For example,
+ output of the keras model built by self.build_model.
+ """
+ compiled_metrics.update_state(labels, model_outputs)
+
+ def train_step(self,
+ inputs,
+ model: tf_keras.Model,
+ optimizer: tf_keras.optimizers.Optimizer,
+ metrics=None):
+ """Does forward and backward.
+
+ With distribution strategies, this method runs on devices.
+
+ Args:
+ inputs: a dictionary of input tensors.
+ model: the model, forward pass definition.
+ optimizer: the optimizer for this training step.
+ metrics: a nested structure of metrics objects.
+
+ Returns:
+ A dictionary of logs.
+ """
+ if isinstance(inputs, tuple) and len(inputs) == 2:
+ features, labels = inputs
+ else:
+ features, labels = inputs, inputs
+ with tf.GradientTape() as tape:
+ outputs = model(features, training=True)
+ # Computes per-replica loss.
+ if model.compiled_loss:
+ loss = model.compiled_loss(
+ labels, outputs, regularization_losses=model.losses)
+ loss += self.build_losses(
+ labels=labels, model_outputs=outputs, aux_losses=None)
+ else:
+ loss = self.build_losses(
+ labels=labels, model_outputs=outputs, aux_losses=model.losses)
+ # Scales loss as the default gradients allreduce performs sum inside the
+ # optimizer.
+ scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
+
+ # For mixed precision, when a LossScaleOptimizer is used, the loss is
+ # scaled to avoid numeric underflow.
+ if isinstance(optimizer,
+ tf_keras.mixed_precision.LossScaleOptimizer):
+ scaled_loss = optimizer.get_scaled_loss(scaled_loss)
+
+ tvars = model.trainable_variables
+ grads = tape.gradient(scaled_loss, tvars)
+
+ if isinstance(optimizer,
+ tf_keras.mixed_precision.LossScaleOptimizer):
+ grads = optimizer.get_unscaled_gradients(grads)
+ optimizer.apply_gradients(list(zip(grads, tvars)))
+ logs = {self.loss: loss}
+ if metrics:
+ self.process_metrics(metrics, labels, outputs)
+ if model.compiled_metrics:
+ self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
+ logs.update({m.name: m.result() for m in metrics or []})
+ logs.update({m.name: m.result() for m in model.metrics})
+ return logs
+
+ def validation_step(self, inputs, model: tf_keras.Model, metrics=None):
+ """Validation step.
+
+ With distribution strategies, this method runs on devices.
+
+ Args:
+ inputs: a dictionary of input tensors.
+ model: the keras.Model.
+ metrics: a nested structure of metrics objects.
+
+ Returns:
+ A dictionary of logs.
+ """
+ if isinstance(inputs, tuple) and len(inputs) == 2:
+ features, labels = inputs
+ else:
+ features, labels = inputs, inputs
+ outputs = self.inference_step(features, model)
+ loss = self.build_losses(
+ labels=labels, model_outputs=outputs, aux_losses=model.losses)
+ logs = {self.loss: loss}
+ if metrics:
+ self.process_metrics(metrics, labels, outputs)
+ if model.compiled_metrics:
+ self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
+ logs.update({m.name: m.result() for m in metrics or []})
+ logs.update({m.name: m.result() for m in model.metrics})
+ return logs
+
+ def inference_step(self, inputs, model: tf_keras.Model):
+ """Performs the forward step.
+
+ With distribution strategies, this method runs on devices.
+
+ Args:
+ inputs: a dictionary of input tensors.
+ model: the keras.Model.
+
+ Returns:
+ Model outputs.
+ """
+ return model(inputs, training=False)
+
+ def aggregate_logs(self, state, step_logs):
+ """Optional aggregation over logs returned from a validation step.
+
+ Given step_logs from a validation step, this function aggregates the logs
+ after each eval_step() (see eval_reduce() function in
+ official/core/base_trainer.py). It runs on CPU and can be used to aggregate
+ metrics during validation, when there are too many metrics that cannot fit
+ into TPU memory. Note that this may increase latency due to data transfer
+ between TPU and CPU. Also, the step output from a validation step may be a
+ tuple with elements from replicas, and a concatenation of the elements is
+ needed in such case.
+
+ Args:
+ state: The current state of training, for example, it can be a sequence of
+ metrics.
+ step_logs: Logs from a validation step. Can be a dictionary.
+ """
+ pass
+
+ def reduce_aggregated_logs(self,
+ aggregated_logs,
+ global_step: Optional[tf.Tensor] = None):
+ """Optional reduce of aggregated logs over validation steps.
+
+ This function reduces aggregated logs at the end of validation, and can be
+ used to compute the final metrics. It runs on CPU and in each eval_end() in
+ base trainer (see eval_end() function in official/core/base_trainer.py).
+
+ Args:
+ aggregated_logs: Aggregated logs over multiple validation steps.
+ global_step: An optional variable of global step.
+
+ Returns:
+ A dictionary of reduced results.
+ """
+ return {}
diff --git a/modeling/official/core/base_trainer.py b/modeling/official/core/base_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d8c7821700d26a9b7a9092c8497acc3d915eeae
--- /dev/null
+++ b/modeling/official/core/base_trainer.py
@@ -0,0 +1,498 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Standard Trainer implementation.
+
+The base trainer implements the Orbit `StandardTrainable` and
+`StandardEvaluable` interfaces. Trainers inside this project should be
+interchangable and independent on model architectures and tasks.
+"""
+import functools
+from typing import Union, Optional
+from absl import logging
+import gin
+import orbit
+import tensorflow as tf, tf_keras
+
+from official.core import base_task
+from official.core import config_definitions
+from official.modeling import optimization
+
+ExperimentConfig = config_definitions.ExperimentConfig
+TrainerConfig = config_definitions.TrainerConfig
+
+
+class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
+ """Trainer class for both sync and async Strategy."""
+
+ def init_async(self):
+ """Initializes the Async Trainer base class."""
+ assert isinstance(self._strategy, tf.distribute.Strategy)
+ self._is_async = isinstance(
+ self._strategy, tf.distribute.experimental.ParameterServerStrategy)
+ self._coordinator = None
+ if self._is_async:
+ self._coordinator = (
+ tf.distribute.experimental.coordinator.ClusterCoordinator(
+ self._strategy))
+
+ def coordinator_for_async(
+ self,
+ ) -> tf.distribute.experimental.coordinator.ClusterCoordinator:
+ if not self._coordinator:
+ raise ValueError(
+ "Coordinator uninitialized for async run. Call init_async() first."
+ )
+ return self._coordinator
+
+ def join(self):
+ """Join all async steps. Only useful in aysnc training."""
+ if getattr(self, "_is_async", False):
+ self.coordinator_for_async().join()
+
+ def create_train_loop_fn(self):
+ """Creates a eval loop from the given step function and options."""
+ train_loop_fn = super().create_train_loop_fn()
+ if getattr(self, "_is_async", False):
+
+ def _async_loop_fn(iterator, num_steps):
+ self.coordinator_for_async().schedule(
+ train_loop_fn, args=(iterator, num_steps)
+ )
+
+ return _async_loop_fn
+ else:
+ return train_loop_fn
+
+ def create_eval_loop_fn(self, has_state: bool):
+ """Creates a training loop from the given step function and options."""
+ eval_loop_fn = super().create_eval_loop_fn(has_state)
+
+ if getattr(self, "_is_async", False):
+ if has_state:
+ raise ValueError(
+ "Stateful eval loop is not supported in async training.")
+
+ def _async_loop_fn(iterator, num_steps, state=None, reduce_fn=None):
+ assert state is None
+ assert reduce_fn is None
+ self.coordinator_for_async().schedule(
+ eval_loop_fn, args=(iterator, num_steps)
+ )
+
+ return _async_loop_fn
+ else:
+ return eval_loop_fn
+
+ def distribute_dataset(self, dataset_or_fn, *args, **kwargs):
+ """A utility function to help create a `tf.distribute.DistributedDataset`.
+
+ Args:
+ dataset_or_fn: A instance of `tf.data.Dataset`, or a "dataset function"
+ returning a `tf.data.Dataset`. If it is a function, it may optionally
+ have an argument named `input_context` which will be passed a
+ `tf.distribute.InputContext` instance.
+ *args: Any positional arguments to pass through to `dataset_or_fn`.
+ **kwargs: Any keyword arguments to pass through to `dataset_or_fn`.
+
+ Returns:
+ A distributed Dataset.
+ """
+ if getattr(self, "_is_async", False):
+ per_worker_dataset_fn = functools.partial(
+ orbit.utils.make_distributed_dataset, self._strategy, dataset_or_fn,
+ *args, **kwargs)
+ per_worker_dataset_fn = tf.function(per_worker_dataset_fn)
+
+ return self.coordinator_for_async().create_per_worker_dataset(
+ per_worker_dataset_fn
+ )
+ else:
+ return orbit.utils.make_distributed_dataset(self._strategy, dataset_or_fn,
+ *args, **kwargs)
+
+
+def get_runtime_options(config: ExperimentConfig):
+ """Get tf.distribute.RunOptions from config."""
+ xla_options = {}
+ if config.runtime.tpu_enable_xla_dynamic_padder is not None:
+ xla_options["enable_xla_dynamic_padder"] = (
+ config.runtime.tpu_enable_xla_dynamic_padder)
+ return tf.distribute.RunOptions(
+ experimental_xla_options=tf.tpu.XLAOptions(**xla_options))
+
+
+@gin.configurable
+class Trainer(_AsyncTrainer):
+ """Implements the common trainer shared for TensorFlow models."""
+
+ # pylint: disable=super-init-not-called
+ def __init__(
+ self,
+ config: ExperimentConfig,
+ task: base_task.Task,
+ model: tf_keras.Model,
+ optimizer: tf.optimizers.Optimizer,
+ train: bool = True,
+ evaluate: bool = True,
+ train_dataset: Optional[Union[tf.data.Dataset,
+ tf.distribute.DistributedDataset]] = None,
+ validation_dataset: Optional[Union[
+ tf.data.Dataset, tf.distribute.DistributedDataset]] = None,
+ checkpoint_exporter=None):
+ """Initialize common trainer for TensorFlow models.
+
+ Args:
+ config: An `ExperimentConfig` instance specifying experiment config.
+ task: A base_task.Task instance.
+ model: The model instance, e.g. a tf_keras.Model instance.
+ optimizer: tf.optimizers.Optimizer instance.
+ train: bool, whether or not this trainer will be used for training.
+ default to True.
+ evaluate: bool, whether or not this trainer will be used for evaluation.
+ default to True.
+ train_dataset: a dataset object created for training. With tf.distribute,
+ it needs to be a `DistributedDataset`.
+ validation_dataset: a dataset object created for evaluation. With
+ tf.distribute, it needs to be a `DistributedDataset`. The evaluator will
+ create a dataset iterator for each eval round, so the dataset does not
+ need to repeat.
+ checkpoint_exporter: an object that has the `maybe_export_checkpoint`
+ interface.
+ """
+ # Gets the current distribution strategy. If not inside any strategy scope,
+ # it gets a single-replica no-op strategy.
+ self._strategy = tf.distribute.get_strategy()
+ self._validate_params(
+ config,
+ check_train_data=train_dataset is None,
+ check_validation_data=validation_dataset is None)
+ self._config = config
+ self._task = task
+ self._model = model
+ self._optimizer = optimizer
+ self._checkpoint_exporter = checkpoint_exporter
+ self._recovery = None
+ # Runtime options are only applied to train_step.
+ # We use default for eval_step.
+ self._runtime_options = get_runtime_options(config)
+
+ # Creates a shadow copy of the weights to store weights moving average.
+ if isinstance(self._optimizer, optimization.ExponentialMovingAverage
+ ) and not self._optimizer.has_shadow_copy:
+ self._optimizer.shadow_copy(self._model)
+
+ # global_step increases by 1 after each training iteration.
+ # We should have global_step.numpy() == self.optimizer.iterations.numpy()
+ # when there is only 1 optimizer.
+ self._global_step = orbit.utils.create_global_step()
+ if hasattr(self.model, "checkpoint_items"):
+ checkpoint_items = self.model.checkpoint_items
+ else:
+ checkpoint_items = {}
+ self._checkpoint = tf.train.Checkpoint(
+ global_step=self.global_step,
+ model=self.model,
+ optimizer=self.optimizer,
+ **checkpoint_items)
+
+ self._train_loss = tf_keras.metrics.Mean("training_loss", dtype=tf.float32)
+ self._validation_loss = tf_keras.metrics.Mean(
+ "validation_loss", dtype=tf.float32)
+ model_metrics = model.metrics if hasattr(model, "metrics") else []
+
+ self.init_async()
+
+ if train:
+ self._train_metrics = self.task.build_metrics(
+ training=True) + model_metrics
+ train_dataset = train_dataset or self.distribute_dataset(
+ self.task.build_inputs, self.config.task.train_data)
+ orbit.StandardTrainer.__init__(
+ self,
+ train_dataset,
+ options=orbit.StandardTrainerOptions(
+ use_tf_while_loop=config.trainer.train_tf_while_loop,
+ use_tf_function=config.trainer.train_tf_function,
+ use_tpu_summary_optimization=config.trainer.allow_tpu_summary))
+
+ if evaluate:
+ self._validation_metrics = self.task.build_metrics(
+ training=False) + model_metrics
+ validation_dataset = validation_dataset or self.distribute_dataset(
+ self.task.build_inputs, self.config.task.validation_data)
+ orbit.StandardEvaluator.__init__(
+ self,
+ validation_dataset,
+ options=orbit.StandardEvaluatorOptions(
+ use_tf_function=config.trainer.eval_tf_function,
+ use_tf_while_loop=config.trainer.eval_tf_while_loop))
+
+ def _validate_params(self,
+ config,
+ check_train_data=True,
+ check_validation_data=True):
+ r"""Validates if the configuration object passed to the Trainer.
+
+ The experiment configuration should be structured as:
+ \trainer
+ \task
+ \train_data
+ \validation_data
+
+ Args:
+ config: a namedtuple, dataclass, ConfigDict, etc.
+ check_train_data: whether to check task.train_data field.
+ check_validation_data: whether to check task.validation_data field.
+ """
+ if not hasattr(config, "trainer"):
+ raise AttributeError("The trainer requires the configuration contains an"
+ " attribute `trainer`.")
+
+ if not hasattr(config, "task"):
+ raise AttributeError("The trainer requires the configuration contains an"
+ " attribute `task`.")
+
+ if check_train_data and not hasattr(config.task, "train_data"):
+ raise AttributeError("The trainer requires the configuration contains an"
+ " attribute `task.train_data`.")
+
+ if check_validation_data and not hasattr(config.task, "validation_data"):
+ raise AttributeError("The trainer requires the configuration contains an"
+ " attribute `task.validation_data`.")
+
+ @property
+ def strategy(self):
+ return self._strategy
+
+ @property
+ def config(self):
+ return self._config
+
+ @property
+ def task(self):
+ return self._task
+
+ @property
+ def model(self):
+ return self._model
+
+ @property
+ def optimizer(self):
+ if hasattr(self, "_optimizer"):
+ return self._optimizer
+ else:
+ return None
+
+ @property
+ def global_step(self):
+ return self._global_step
+
+ @property
+ def train_loss(self):
+ """Accesses the training loss metric object."""
+ return self._train_loss
+
+ @property
+ def validation_loss(self):
+ """Accesses the validation loss metric object."""
+ return self._validation_loss
+
+ @property
+ def train_metrics(self):
+ """Accesses all training metric objects."""
+ return self._train_metrics
+
+ @property
+ def validation_metrics(self):
+ """Accesses all validation metric metric objects."""
+ return self._validation_metrics
+
+ def initialize(self):
+ """A callback function.
+
+ This function will be called when no checkpoint found for the model.
+ If there is a checkpoint, the checkpoint will be loaded and this function
+ will not be called. Tasks may use this callback function to load a
+ pretrained checkpoint, saved under a directory other than the model_dir.
+ """
+ self.task.initialize(self.model)
+
+ @property
+ def checkpoint(self):
+ """Accesses the training checkpoint."""
+ return self._checkpoint
+
+ @property
+ def checkpoint_exporter(self):
+ """Accesses the checkpoint exporter."""
+ return self._checkpoint_exporter
+
+ def train_loop_end(self):
+ """See base class."""
+ self.join()
+ logs = {}
+ for metric in self.train_metrics + [self.train_loss]:
+ logs[metric.name] = metric.result()
+ metric.reset_states()
+ if callable(self.optimizer.learning_rate):
+ # Maybe a self-implemented optimizer does not have `optimizer.iterations`.
+ # So just to be safe here.
+ if hasattr(self.optimizer, "iterations"):
+ logs["learning_rate"] = self.optimizer.learning_rate(
+ self.optimizer.iterations)
+ else:
+ logs["learning_rate"] = self.optimizer.learning_rate(self.global_step)
+ else:
+ logs["learning_rate"] = self.optimizer.learning_rate
+ return logs
+
+ def next_train_inputs(self, iterator):
+ """Fetches the next inputs for the model during train.
+
+ This method consumes the input iterator and returns the next inputs for the
+ model.
+
+ This method provides a way to control how to fetch the next model input, and
+ what data to send to the model.
+
+ Note: This function runs on the host side when accelerators are used.
+
+ Note: Depending on the training setup this may or may not run in eager mode.
+ In most cases it will be run in graph mode.
+
+ Args:
+ iterator: Dataset iterator to generate the next inputs from.
+
+ Returns:
+ The inputs to the model.
+ """
+ return next(iterator)
+
+ def train_step(self, iterator):
+ """See base class."""
+
+ def step_fn(inputs):
+ if self.config.runtime.enable_xla and (self.config.runtime.num_gpus > 0):
+ task_train_step = tf.function(self.task.train_step, jit_compile=True)
+ else:
+ task_train_step = self.task.train_step
+ logs = task_train_step(
+ inputs,
+ model=self.model,
+ optimizer=self.optimizer,
+ metrics=self.train_metrics)
+ self._train_loss.update_state(logs[self.task.loss])
+ self.global_step.assign_add(1)
+
+ inputs = self.next_train_inputs(iterator)
+ self.strategy.run(step_fn, args=(inputs,), options=self._runtime_options)
+
+ def eval_begin(self):
+ """Sets up metrics."""
+ for metric in self.validation_metrics + [self.validation_loss]:
+ metric.reset_states()
+ # Swaps weights to test on weights moving average.
+ if self.optimizer and isinstance(self.optimizer,
+ optimization.ExponentialMovingAverage):
+ self.optimizer.swap_weights()
+
+ def next_eval_inputs(self, iterator):
+ """Fetches the next inputs for the model during eval.
+
+ This method consumes the input iterator and returns the next inputs for the
+ model and an additional logs dict. The output dict remains in the host (not
+ sent to GPUs/TPUs) and is merged with the model outputs which will be
+ processed later in `aggregate_logs`. This is useful for sending extra logs
+ downstream that are not compatible with the accelerators.
+
+ Note: This function runs on the host side when accelerators are used.
+
+ Note: Depending on the training setup this may or may not run in eager mode.
+ In most cases it will be run in graph mode.
+
+ Args:
+ iterator: Dataset iterator to generate the next inputs from.
+
+ Returns:
+ The inputs to the model, and an additional logs dictionnary. The logs
+ are not passed to the model, instead they are merged with model output
+ logs.
+ """
+ passthrough_logs = dict()
+ return next(iterator), passthrough_logs
+
+ def eval_step(self, iterator):
+ """See base class."""
+
+ def step_fn(inputs):
+ logs = self.task.validation_step(
+ inputs, model=self.model, metrics=self.validation_metrics)
+ if self.task.loss in logs:
+ self._validation_loss.update_state(logs[self.task.loss])
+ return logs
+
+ inputs, passthrough_logs = self.next_eval_inputs(iterator)
+ distributed_outputs = self.strategy.run(step_fn, args=(inputs,))
+ logs = tf.nest.map_structure(
+ self.strategy.experimental_local_results, distributed_outputs
+ )
+
+ if set(logs.keys()) & set(passthrough_logs.keys()):
+ logging.warning(
+ (
+ "Conflict between the pasthrough log keys and the returned model"
+ " log keys. Found %r keys in the passthrough logs and %r keys in"
+ " the model logs. Model log keys takes precedence."
+ ),
+ logs.keys(),
+ passthrough_logs.keys(),
+ )
+
+ return passthrough_logs | logs
+
+ def eval_end(self, aggregated_logs=None):
+ """Processes evaluation results."""
+ self.join()
+ logs = {}
+ for metric in self.validation_metrics:
+ logs[metric.name] = metric.result()
+ if self.validation_loss.count.numpy() != 0:
+ logs[self.validation_loss.name] = self.validation_loss.result()
+ else:
+ # `self.validation_loss` metric was not updated, because the validation
+ # loss was not returned from the task's `validation_step` method.
+ logging.info("The task did not report validation loss.")
+ if aggregated_logs:
+ metrics = self.task.reduce_aggregated_logs(
+ aggregated_logs, global_step=self.global_step)
+ logs.update(metrics)
+
+ if self._checkpoint_exporter:
+ self._checkpoint_exporter.maybe_export_checkpoint(
+ self.checkpoint, logs, self.global_step.numpy())
+ metric_name = self.config.trainer.best_checkpoint_eval_metric
+ logs["best_" +
+ metric_name] = self._checkpoint_exporter.best_ckpt_logs[metric_name]
+
+ # Swaps back weights after testing when EMA is used.
+ # This happens after best checkpoint export so that average weights used for
+ # eval are exported instead of regular weights.
+ if self.optimizer and isinstance(self.optimizer,
+ optimization.ExponentialMovingAverage):
+ self.optimizer.swap_weights()
+ return logs
+
+ def eval_reduce(self, state=None, step_outputs=None):
+ return self.task.aggregate_logs(state, step_outputs)
diff --git a/modeling/official/core/base_trainer_test.py b/modeling/official/core/base_trainer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..1290cb4880cb4fb9805d702124e1a65dc68faf32
--- /dev/null
+++ b/modeling/official/core/base_trainer_test.py
@@ -0,0 +1,363 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for tensorflow_models.core.trainers.trainer."""
+# pylint: disable=g-direct-tensorflow-import
+import gc
+import multiprocessing
+import os
+import sys
+
+from absl.testing import parameterized
+import orbit
+import portpicker
+import tensorflow as tf, tf_keras
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.core import base_trainer as trainer_lib
+from official.core import config_definitions as cfg
+from official.core import train_lib
+from official.utils.testing import mock_task
+
+TPU_TEST = 'test_tpu' in sys.argv[0]
+GPU_TEST = 'test_gpu' in sys.argv[0]
+
+
+def all_strategy_combinations():
+ return combinations.combine(
+ distribution=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.cloud_tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ ],)
+
+
+def create_in_process_cluster(num_workers, num_ps):
+ """Creates and starts local servers and returns the cluster_resolver."""
+ worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
+ ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+
+ cluster_dict = {}
+ cluster_dict['worker'] = ['localhost:%s' % port for port in worker_ports]
+ if num_ps > 0:
+ cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
+
+ cluster_spec = tf.train.ClusterSpec(cluster_dict)
+
+ # Workers need some inter_ops threads to work properly.
+ worker_config = tf.compat.v1.ConfigProto()
+ if multiprocessing.cpu_count() < num_workers + 1:
+ worker_config.inter_op_parallelism_threads = num_workers + 1
+
+ for i in range(num_workers):
+ tf.distribute.Server(
+ cluster_spec,
+ job_name='worker',
+ task_index=i,
+ config=worker_config,
+ protocol='grpc')
+
+ for i in range(num_ps):
+ tf.distribute.Server(
+ cluster_spec, job_name='ps', task_index=i, protocol='grpc')
+
+ cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
+ cluster_spec, rpc_layer='grpc')
+ return cluster_resolver
+
+
+def dataset_fn(input_context=None):
+ del input_context
+
+ def dummy_data(_):
+ return tf.zeros((1, 1), dtype=tf.float32)
+
+ dataset = tf.data.Dataset.range(1)
+ dataset = dataset.repeat()
+ dataset = dataset.map(
+ dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+
+class MockAsyncTrainer(trainer_lib._AsyncTrainer):
+ """Mock AsyncTrainer to test the _AsyncTrainer class."""
+
+ def __init__(self):
+ self._strategy = tf.distribute.get_strategy()
+ self.init_async()
+
+ self.global_step = tf.Variable(
+ 0,
+ dtype=tf.int64,
+ name='global_step',
+ trainable=False,
+ aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
+ self.eval_global_step = tf.Variable(
+ 0,
+ dtype=tf.int64,
+ name='eval_global_step',
+ trainable=False,
+ aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
+
+ train_dataset = self.distribute_dataset(dataset_fn)
+ orbit.StandardTrainer.__init__(
+ self, train_dataset, options=orbit.StandardTrainerOptions())
+
+ validation_dataset = self.distribute_dataset(dataset_fn)
+ orbit.StandardEvaluator.__init__(
+ self,
+ validation_dataset,
+ options=orbit.StandardEvaluatorOptions(use_tf_while_loop=True))
+
+ def train_loop_begin(self):
+ self.global_step.assign(0)
+
+ def train_step(self, iterator):
+
+ def replica_step(_):
+ self.global_step.assign_add(1)
+
+ self._strategy.run(replica_step, args=(next(iterator),))
+
+ def train_loop_end(self):
+ self.join()
+ return self.global_step.numpy()
+
+ def eval_begin(self):
+ self.eval_global_step.assign(0)
+
+ def eval_step(self, iterator):
+
+ def replica_step(_):
+ self.eval_global_step.assign_add(1)
+
+ self._strategy.run(replica_step, args=(next(iterator),))
+
+ def eval_end(self):
+ self.join()
+ return self.eval_global_step.numpy()
+
+
+class TrainerTest(tf.test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._config = cfg.ExperimentConfig(
+ trainer=cfg.TrainerConfig(
+ optimizer_config=cfg.OptimizationConfig({
+ 'optimizer': {
+ 'type': 'sgd'
+ },
+ 'learning_rate': {
+ 'type': 'constant'
+ }
+ })))
+
+ def tearDown(self):
+ gc.collect()
+ # This will only contain uncollectable garbage, i.e. reference cycles
+ # involving objects with __del__ defined.
+ self.assertEmpty(gc.garbage)
+ super().tearDown()
+
+ def create_test_trainer(self, config, model_dir=None, task=None):
+ task = task or mock_task.MockTask(config.task, logging_dir=model_dir)
+ ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
+ trainer = trainer_lib.Trainer(
+ config,
+ task,
+ model=task.build_model(),
+ optimizer=task.create_optimizer(config.trainer.optimizer_config,
+ config.runtime),
+ checkpoint_exporter=ckpt_exporter)
+ return trainer
+
+ @combinations.generate(all_strategy_combinations())
+ def test_trainer_train(self, distribution):
+ with distribution.scope():
+ trainer = self.create_test_trainer(self._config)
+ logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
+ self.assertIn('training_loss', logs)
+ self.assertIn('learning_rate', logs)
+
+ @combinations.generate(all_strategy_combinations())
+ def test_trainer_passing_datasets(self, distribution):
+ with distribution.scope():
+ task = mock_task.MockTask(self._config)
+ train_dataset = orbit.utils.make_distributed_dataset(
+ distribution, task.build_inputs, self._config.task.train_data)
+ validation_dataset = orbit.utils.make_distributed_dataset(
+ distribution, task.build_inputs, self._config.task.validation_data)
+ self._config.task.train_data = None
+ self._config.task.validation_data = None
+ trainer = trainer_lib.Trainer(
+ self._config,
+ task,
+ model=task.build_model(),
+ optimizer=task.create_optimizer(self._config.trainer.optimizer_config,
+ self._config.runtime),
+ train_dataset=train_dataset,
+ validation_dataset=validation_dataset)
+ logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
+ self.assertIn('training_loss', logs)
+ self.assertIn('learning_rate', logs)
+ logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
+ self.assertIn('validation_loss', logs)
+
+ def test_base_async_trainer(self):
+ if TPU_TEST or GPU_TEST:
+ self.skipTest('Aysnc training is not available on GPU/GPU.')
+ num_workers = 3
+ num_ps = 2
+ cluster_resolver = create_in_process_cluster(num_workers, num_ps)
+ distribution = tf.distribute.experimental.ParameterServerStrategy(
+ cluster_resolver)
+ with distribution.scope():
+ trainer = MockAsyncTrainer()
+ trainer.init_async()
+ self.assertIsInstance(
+ trainer._coordinator,
+ tf.distribute.experimental.coordinator.ClusterCoordinator)
+ self.assertEqual(trainer.train(tf.constant(10)), 10)
+ self.assertEqual(trainer.evaluate(tf.constant(11)), 11)
+
+ def test_async_trainer_train(self):
+ if TPU_TEST or GPU_TEST:
+ self.skipTest('Aysnc training is not available on GPU/TPU.')
+ num_workers = 3
+ num_ps = 2
+ cluster_resolver = create_in_process_cluster(num_workers, num_ps)
+ distribution = tf.distribute.experimental.ParameterServerStrategy(
+ cluster_resolver)
+ with distribution.scope():
+ config = cfg.ExperimentConfig(**self._config.as_dict())
+ config.trainer.eval_tf_while_loop = True
+ trainer = self.create_test_trainer(config)
+ logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
+ self.assertIn('training_loss', logs)
+ self.assertIn('learning_rate', logs)
+
+ def test_async_trainer_validate(self):
+ if TPU_TEST or GPU_TEST:
+ self.skipTest('Aysnc training is not available on GPU/GPU.')
+ num_workers = 3
+ num_ps = 2
+ cluster_resolver = create_in_process_cluster(num_workers, num_ps)
+ distribution = tf.distribute.experimental.ParameterServerStrategy(
+ cluster_resolver)
+ with distribution.scope():
+ config = cfg.ExperimentConfig(**self._config.as_dict())
+ config.trainer.eval_tf_while_loop = True
+ trainer = self.create_test_trainer(config)
+ logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
+ self.assertIn('acc', logs)
+ self.assertIn('validation_loss', logs)
+
+ @combinations.generate(all_strategy_combinations())
+ def test_trainer_validate(self, distribution):
+ with distribution.scope():
+ trainer = self.create_test_trainer(self._config)
+ logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
+ self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
+ self.assertIn('validation_loss', logs)
+
+ @combinations.generate(all_strategy_combinations())
+ def test_trainer_validate_without_loss(self, distribution):
+
+ class MockTaskWithoutValidationLoss(mock_task.MockTask):
+
+ def validation_step(self, inputs, model, metrics=None):
+ # Disable validation loss.
+ logs = super().validation_step(inputs, model)
+ del logs[self.loss]
+ return logs
+
+ with distribution.scope():
+ task = MockTaskWithoutValidationLoss()
+ trainer = self.create_test_trainer(self._config, task=task)
+ logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
+ self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
+ self.assertNotIn('validation_loss', logs)
+
+ @combinations.generate(
+ combinations.combine(
+ mixed_precision_dtype=['float32', 'bfloat16', 'float16'],
+ loss_scale=[None, 'dynamic', 128, 256],
+ ))
+ def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
+ config = cfg.ExperimentConfig(
+ runtime=cfg.RuntimeConfig(
+ mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
+ trainer=cfg.TrainerConfig(
+ optimizer_config=cfg.OptimizationConfig({
+ 'optimizer': {
+ 'type': 'sgd'
+ },
+ 'learning_rate': {
+ 'type': 'constant'
+ },
+ })))
+ trainer = self.create_test_trainer(config)
+ if mixed_precision_dtype == 'float16':
+ self.assertIsInstance(trainer.optimizer,
+ tf_keras.mixed_precision.LossScaleOptimizer)
+ if loss_scale in (None, 'dynamic'):
+ self.assertTrue(trainer.optimizer.dynamic)
+ else:
+ self.assertFalse(trainer.optimizer.dynamic)
+ self.assertEqual(trainer.optimizer.initial_scale, loss_scale)
+ else:
+ self.assertIsInstance(
+ trainer.optimizer,
+ (tf_keras.optimizers.SGD, tf_keras.optimizers.legacy.SGD))
+
+ metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
+ self.assertIn('training_loss', metrics)
+
+ def test_export_best_ckpt(self):
+ config = cfg.ExperimentConfig(
+ trainer=cfg.TrainerConfig(
+ best_checkpoint_export_subdir='best_ckpt',
+ best_checkpoint_eval_metric='acc',
+ optimizer_config=cfg.OptimizationConfig({
+ 'optimizer': {
+ 'type': 'sgd'
+ },
+ 'learning_rate': {
+ 'type': 'constant'
+ }
+ })))
+ model_dir = self.get_temp_dir()
+ trainer = self.create_test_trainer(config, model_dir=model_dir)
+ trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
+ trainer.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
+ self.assertTrue(
+ tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))
+
+ def test_model_with_compiled_loss(self):
+ task = mock_task.MockTask()
+ model = task.build_model()
+ model.compile(loss=tf_keras.losses.CategoricalCrossentropy())
+ trainer = trainer_lib.Trainer(
+ self._config,
+ task,
+ model=model,
+ optimizer=task.create_optimizer(self._config.trainer.optimizer_config))
+ logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
+ self.assertIn('training_loss', logs)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/core/config_definitions.py b/modeling/official/core/config_definitions.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef1f8acf1fb1c67d1f39d0285e64bba79da78083
--- /dev/null
+++ b/modeling/official/core/config_definitions.py
@@ -0,0 +1,309 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common configuration settings."""
+
+import dataclasses
+from typing import Optional, Sequence, Union
+
+from official.modeling.hyperparams import base_config
+from official.modeling.optimization.configs import optimization_config
+from official.modeling.privacy import configs as dp_configs
+
+OptimizationConfig = optimization_config.OptimizationConfig
+
+
+@dataclasses.dataclass
+class DataConfig(base_config.Config):
+ """The base configuration for building datasets.
+
+ Attributes:
+ input_path: The path to the input. It can be either (1) a str indicating a
+ file path/pattern, or (2) a str indicating multiple file paths/patterns
+ separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or (3) a list of
+ str, each of which is a file path/pattern or multiple file paths/patterns
+ separated by comma, or (4) a dictionary of the previous three approaches
+ for more advanced data mixing using named access. It should not be
+ specified when the following `tfds_name` is specified.
+ tfds_name: The name of the tensorflow dataset (TFDS). It should not be
+ specified when the above `input_path` is specified.
+ tfds_split: A str indicating which split of the data to load from TFDS. It
+ is required when above `tfds_name` is specified.
+ global_batch_size: The global batch size across all replicas.
+ is_training: Whether this data is used for training or not. This flag is
+ useful for consumers of this object to determine whether the data should
+ be repeated or shuffled.
+ drop_remainder: Whether the last batch should be dropped in the case it has
+ fewer than `global_batch_size` elements.
+ shuffle_buffer_size: The buffer size used for shuffling training data.
+ cache: Whether to cache dataset examples. If `True`, we will cache the
+ dataset after applying the decode_fn and parse_fn. It can be used to avoid
+ re-reading from disk, re-decoding and re-parsing the example on the second
+ epoch, but it requires significant memory overhead.
+ cycle_length: The number of files that will be processed concurrently when
+ interleaving files.
+ block_length: The number of consecutive elements to produce from each input
+ element before cycling to another input element when interleaving files.
+ deterministic: A boolean controlling whether determinism should be enforced.
+ sharding: Whether sharding is used in the input pipeline.
+ enable_tf_data_service: A boolean indicating whether to enable tf.data
+ service for the input pipeline.
+ tf_data_service_address: The URI of a tf.data service to offload
+ preprocessing onto during training. The URI should be in the format
+ "protocol://address", e.g. "grpc://tf-data-service:5050". It can be
+ overridden by `FLAGS.tf_data_service` flag in the binary.
+ tf_data_service_job_name: The name of the tf.data service job. This argument
+ makes it possible for multiple datasets to share the same job. The default
+ behavior is that the dataset creates anonymous, exclusively owned jobs.
+ tfds_data_dir: A str specifying the directory to read/write TFDS data.
+ tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
+ returned tf.data.Dataset will have a 2-tuple structure (input, label)
+ according to builder.info.supervised_keys; if False, the default, the
+ returned tf.data.Dataset will have a dictionary with all the features.
+ tfds_skip_decoding_feature: A str to indicate which features are skipped for
+ decoding when loading dataset from TFDS. Use comma to separate multiple
+ features. The main use case is to skip the image/video decoding for better
+ performance.
+ enable_shared_tf_data_service_between_parallel_trainers: A bool. When set to
+ true, only a single tf.data service will be started, and it will be shared
+ between all the trainer run simultaneously, e.g. using vizier to tune
+ hyperparameters. This will save CPU and RAM resources compared to running
+ separate tf.data service for each trainer. Notice that if batch size is
+ different for different trainers, the field
+ apply_tf_data_service_before_batching also needs to be true so that only a
+ single tf.data service instance will be created. In this case, tf.data
+ service will be applied before batching operation. So make sure to not
+ apply any processing steps after batching (e.g. in postprocess_fn) since
+ they wouldn't be paralleled by tf.data service and may slow down your
+ tf.data pipeline. When using shared tf.data service, the tf.data dataset
+ must be infinite, and slow trainer may skip certain training examples.
+ More details about shared tf.data service can be found at:
+ https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers.
+ apply_tf_data_service_before_batching: A bool. If set to True, tf.data
+ service will be applied before batching operation. This is useful to make
+ sure only a single tf.data service instance is created when
+ enable_shared_tf_data_service_between_parallel_trainers is true and batch
+ size is changing between parallel trainers.
+ trainer_id: A string. The id of the trainer if there are multiple parallel
+ trainer running at the same time, e.g. in vizier tuning case. It will be
+ automatically set if this field is needed. Users does not need to set it
+ when creating experiment configs.
+ seed: An optional seed to use for deterministic shuffling/preprocessing.
+ prefetch_buffer_size: An int specifying the buffer size of prefetch
+ datasets. If None, the buffer size is autotuned. Specifying this is useful
+ in case autotuning uses up too much memory by making the buffer size too
+ high.
+ autotune_algorithm: If specified, use this algorithm for AUTOTUNE. See:
+ https://www.tensorflow.org/api_docs/python/tf/data/experimental/AutotuneAlgorithm
+ """
+ input_path: Union[Sequence[str], str, base_config.Config] = ""
+ tfds_name: Union[str, base_config.Config] = ""
+ tfds_split: str = ""
+ global_batch_size: int = 0
+ is_training: Optional[bool] = None
+ drop_remainder: bool = True
+ shuffle_buffer_size: int = 100
+ cache: bool = False
+ cycle_length: Optional[int] = None
+ block_length: int = 1
+ deterministic: Optional[bool] = None
+ sharding: bool = True
+ enable_tf_data_service: bool = False
+ tf_data_service_address: Optional[str] = None
+ tf_data_service_job_name: Optional[str] = None
+ tfds_data_dir: str = ""
+ tfds_as_supervised: bool = False
+ tfds_skip_decoding_feature: str = ""
+ enable_shared_tf_data_service_between_parallel_trainers: bool = False
+ apply_tf_data_service_before_batching: bool = False
+ trainer_id: Optional[str] = None
+ seed: Optional[int] = None
+ prefetch_buffer_size: Optional[int] = None
+ autotune_algorithm: Optional[str] = None
+
+
+@dataclasses.dataclass
+class RuntimeConfig(base_config.Config):
+ """High-level configurations for Runtime.
+
+ These include parameters that are not directly related to the experiment,
+ e.g. directories, accelerator type, etc.
+
+ Attributes:
+ distribution_strategy: e.g. 'mirrored', 'tpu', etc.
+ enable_xla: Whether or not to enable XLA.
+ per_gpu_thread_count: thread count per GPU.
+ gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
+ dataset_num_private_threads: Number of threads for a private threadpool
+ created for all datasets computation.
+ tpu: The address of the TPU to use, if any.
+ num_gpus: The number of GPUs to use, if any.
+ worker_hosts: comma-separated list of worker ip:port pairs for running
+ multi-worker models with DistributionStrategy.
+ task_index: If multi-worker training, the task index of this worker.
+ all_reduce_alg: Defines the algorithm for performing all-reduce.
+ num_packs: Sets `num_packs` in the cross device ops used in
+ MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
+ mixed_precision_dtype: dtype of mixed precision policy. It can be 'float32',
+ 'float16', or 'bfloat16'.
+ loss_scale: The type of loss scale, or 'float' value. This is used when
+ setting the mixed precision policy.
+ run_eagerly: Whether or not to run the experiment eagerly.
+ batchnorm_spatial_persistent: Whether or not to enable the spatial
+ persistent mode for CuDNN batch norm kernel for improved GPU performance.
+ """
+ distribution_strategy: str = "mirrored"
+ enable_xla: bool = False
+ gpu_thread_mode: Optional[str] = None
+ dataset_num_private_threads: Optional[int] = None
+ per_gpu_thread_count: int = 0
+ tpu: Optional[str] = None
+ num_gpus: int = 0
+ worker_hosts: Optional[str] = None
+ task_index: int = -1
+ all_reduce_alg: Optional[str] = None
+ num_packs: int = 1
+ mixed_precision_dtype: Optional[str] = None
+ loss_scale: Optional[Union[str, float]] = None
+ run_eagerly: bool = False
+ batchnorm_spatial_persistent: bool = False
+
+ # XLA runtime params.
+ # XLA params are only applied to the train_step.
+ # These augments can improve training speed. They can also improve eval, but
+ # may reduce usability and users would need to make changes to code.
+
+ # Whether to enable XLA dynamic padder
+ # infrastructure to handle dynamic shapes inputs inside XLA. True by
+ # default. Disabling this may cause correctness issues with dynamic shapes
+ # inputs, as XLA will just assume the inputs are with padded shapes. However
+ # users can optionally set it to False to improve device time if masking is
+ # already handled in the user side.
+ # If None, will respect XLA default.
+ tpu_enable_xla_dynamic_padder: Optional[bool] = None
+
+ # Global model parallelism configurations.
+ num_cores_per_replica: int = 1
+ default_shard_dim: int = -1
+ use_tpu_mp_strategy: bool = False
+
+ def model_parallelism(self):
+ return dict(
+ num_cores_per_replica=self.num_cores_per_replica,
+ default_shard_dim=self.default_shard_dim)
+
+
+@dataclasses.dataclass
+class TrainerConfig(base_config.Config):
+ """Configuration for trainer.
+
+ Attributes:
+ optimizer_config: optimizer config, it includes optimizer, learning rate,
+ and warmup schedule configs.
+ train_tf_while_loop: whether or not to use tf while loop.
+ train_tf_function: whether or not to use tf_function for training loop.
+ eval_tf_function: whether or not to use tf_function for eval.
+ eval_tf_while_loop: whether or not to use tf while loop for eval.
+ allow_tpu_summary: Whether to allow summary happen inside the XLA program
+ runs on TPU through automatic outside compilation.
+ steps_per_loop: number of steps per loop to report training metrics. This
+ can also be used to reduce host worker communication in a TPU setup.
+ summary_interval: number of steps between each summary.
+ checkpoint_interval: number of steps between checkpoints.
+ max_to_keep: max checkpoints to keep.
+ continuous_eval_timeout: maximum number of seconds to wait between
+ checkpoints, if set to None, continuous eval will wait indefinitely. This
+ is only used continuous_train_and_eval and continuous_eval modes. Default
+ value is 1 hrs.
+ train_steps: number of train steps.
+ validation_steps: number of eval steps. If -1, the entire eval dataset is
+ used.
+ validation_interval: number of training steps to run between evaluations.
+ best_checkpoint_export_subdir: if set, the trainer will keep track of the
+ best evaluation metric, and export the corresponding best checkpoint under
+ `model_dir/best_checkpoint_export_subdir`. Note that this only works if
+ mode contains eval (such as `train_and_eval`, `continuous_eval`, and
+ `continuous_train_and_eval`).
+ best_checkpoint_eval_metric: for exporting the best checkpoint, which
+ evaluation metric the trainer should monitor. This can be any evaluation
+ metric appears on tensorboard.
+ best_checkpoint_metric_comp: for exporting the best checkpoint, how the
+ trainer should compare the evaluation metrics. This can be either `higher`
+ (higher the better) or `lower` (lower the better).
+ validation_summary_subdir: A 'str', sub directory for saving eval summary.
+ preemption_on_demand_checkpoint: whether or not to save on-demand
+ checkpoints after a preemption.
+ """
+ optimizer_config: OptimizationConfig = dataclasses.field(
+ default_factory=OptimizationConfig
+ )
+ # Orbit settings.
+ train_tf_while_loop: bool = True
+ train_tf_function: bool = True
+ eval_tf_function: bool = True
+ eval_tf_while_loop: bool = False
+ allow_tpu_summary: bool = False
+ # Trainer intervals.
+ steps_per_loop: int = 1000
+ summary_interval: int = 1000
+ checkpoint_interval: int = 1000
+ # Checkpoint manager.
+ max_to_keep: int = 5
+ continuous_eval_timeout: int = 60 * 60
+ # Train/Eval routines.
+ train_steps: int = 0
+ # Sets validation steps to be -1 to evaluate the entire dataset.
+ validation_steps: int = -1
+ validation_interval: int = 1000
+ # Best checkpoint export.
+ best_checkpoint_export_subdir: str = ""
+ best_checkpoint_eval_metric: str = ""
+ best_checkpoint_metric_comp: str = "higher"
+ # Blowup recovery.
+ loss_upper_bound: float = 1e6
+ recovery_begin_steps: int = 0 # Enforcing the loss bound after these steps.
+ # When max trials < 0, no recovery module; max trials = 0, we will check
+ # the condition and fail the job if the condition happens; max trials > 0,
+ # we will retore the model states.
+ recovery_max_trials: int = 0
+ validation_summary_subdir: str = "validation"
+ # Preemption on-demand checkpoint.
+ preemption_on_demand_checkpoint: bool = True # copybara-replace
+
+
+@dataclasses.dataclass
+class TaskConfig(base_config.Config):
+ """Config passed to task."""
+ init_checkpoint: str = ""
+ model: Optional[base_config.Config] = None
+ train_data: DataConfig = dataclasses.field(default_factory=DataConfig)
+ validation_data: DataConfig = dataclasses.field(default_factory=DataConfig)
+ name: Optional[str] = None
+ # Configs for differential privacy
+ # These configs are only effective if you use create_optimizer in
+ # tensorflow_models/official/core/base_task.py
+ # DEPRECATED b/264611883
+ differential_privacy_config: Optional[
+ dp_configs.DifferentialPrivacyConfig] = None
+ # Whether to show image summary. Useful to visualize model predictions. Only
+ # work for vision tasks.
+ allow_image_summary: bool = False
+
+
+@dataclasses.dataclass
+class ExperimentConfig(base_config.Config):
+ """Top-level configuration."""
+ task: TaskConfig = dataclasses.field(default_factory=TaskConfig)
+ trainer: TrainerConfig = dataclasses.field(default_factory=TrainerConfig)
+ runtime: RuntimeConfig = dataclasses.field(default_factory=RuntimeConfig)
diff --git a/modeling/official/core/exp_factory.py b/modeling/official/core/exp_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5c2cc132e9480e71b3af248cabd5de9be6b647c
--- /dev/null
+++ b/modeling/official/core/exp_factory.py
@@ -0,0 +1,32 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Experiment factory methods."""
+
+from official.core import config_definitions as cfg
+from official.core import registry
+
+
+_REGISTERED_CONFIGS = {}
+
+
+def register_config_factory(name):
+ """Register ExperimentConfig factory method."""
+ return registry.register(_REGISTERED_CONFIGS, name)
+
+
+def get_exp_config(exp_name: str) -> cfg.ExperimentConfig:
+ """Looks up the `ExperimentConfig` according to the `exp_name`."""
+ exp_creater = registry.lookup(_REGISTERED_CONFIGS, exp_name)
+ return exp_creater()
diff --git a/modeling/official/core/export_base.py b/modeling/official/core/export_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6166878cd88ebd8944477ca8d1daab9ab36db59
--- /dev/null
+++ b/modeling/official/core/export_base.py
@@ -0,0 +1,182 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Base class for model export."""
+
+import abc
+import functools
+import time
+from typing import Any, Callable, Dict, Mapping, List, Optional, Text, Union
+
+from absl import logging
+import tensorflow as tf, tf_keras
+
+MAX_DIRECTORY_CREATION_ATTEMPTS = 10
+
+
+class ExportModule(tf.Module, metaclass=abc.ABCMeta):
+ """Base Export Module."""
+
+ def __init__(self,
+ params,
+ model: Union[tf.Module, tf_keras.Model],
+ inference_step: Optional[Callable[..., Any]] = None,
+ *,
+ preprocessor: Optional[Callable[..., Any]] = None,
+ postprocessor: Optional[Callable[..., Any]] = None):
+ """Instantiates an ExportModel.
+
+ Examples:
+
+ `inference_step` must be a function that has `model` as an kwarg or the
+ second positional argument.
+ ```
+ def _inference_step(inputs, model=None):
+ return model(inputs, training=False)
+
+ module = ExportModule(params, model, inference_step=_inference_step)
+ ```
+
+ `preprocessor` and `postprocessor` could be either functions or `tf.Module`.
+ The usages of preprocessor and postprocessor are managed by the
+ implementation of `serve()` method.
+
+ Args:
+ params: A dataclass for parameters to the module.
+ model: A model instance which contains weights and forward computation.
+ inference_step: An optional callable to forward-pass the model. If not
+ specified, it creates a parital function with `model` as an required
+ kwarg.
+ preprocessor: An optional callable to preprocess the inputs.
+ postprocessor: An optional callable to postprocess the model outputs.
+ """
+ super().__init__(name=None)
+ self.model = model
+ self.params = params
+
+ if inference_step is not None:
+ self.inference_step = functools.partial(inference_step, model=self.model)
+ else:
+ if issubclass(type(model), tf_keras.Model):
+ # Default to self.model.call instead of self.model.__call__ to avoid
+ # keras tracing logic designed for training.
+ # Since most of Model Garden's call doesn't not have training kwargs
+ # or the default is False, we don't pass anything here.
+ # Please pass custom inference step if your model has training=True as
+ # default.
+ self.inference_step = self.model.call
+ else:
+ self.inference_step = functools.partial(
+ self.model.__call__, training=False)
+ self.preprocessor = preprocessor
+ self.postprocessor = postprocessor
+
+ @abc.abstractmethod
+ def serve(self) -> Mapping[Text, tf.Tensor]:
+ """The bare inference function which should run on all devices.
+
+ Expecting tensors are passed in through keyword arguments. Returns a
+ dictionary of tensors, when the keys will be used inside the SignatureDef.
+ """
+
+ @abc.abstractmethod
+ def get_inference_signatures(
+ self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
+ """Get defined function signatures."""
+
+
+def export(export_module: ExportModule,
+ function_keys: Union[List[Text], Dict[Text, Text]],
+ export_savedmodel_dir: Text,
+ checkpoint_path: Optional[Text] = None,
+ timestamped: bool = True,
+ save_options: Optional[tf.saved_model.SaveOptions] = None,
+ checkpoint: Optional[tf.train.Checkpoint] = None) -> Text:
+ """Exports to SavedModel format.
+
+ Args:
+ export_module: a ExportModule with the keras Model and serving tf.functions.
+ function_keys: a list of string keys to retrieve pre-defined serving
+ signatures. The signaute keys will be set with defaults. If a dictionary
+ is provided, the values will be used as signature keys.
+ export_savedmodel_dir: Output saved model directory.
+ checkpoint_path: Object-based checkpoint path or directory.
+ timestamped: Whether to export the savedmodel to a timestamped directory.
+ save_options: `SaveOptions` for `tf.saved_model.save`.
+ checkpoint: An optional tf.train.Checkpoint. If provided, the export module
+ will use it to read the weights.
+
+ Returns:
+ The savedmodel directory path.
+ """
+ ckpt_dir_or_file = checkpoint_path
+ if ckpt_dir_or_file is not None and tf.io.gfile.isdir(ckpt_dir_or_file):
+ ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
+ if ckpt_dir_or_file:
+ if checkpoint is None:
+ checkpoint = tf.train.Checkpoint(model=export_module.model)
+ checkpoint.read(
+ ckpt_dir_or_file).assert_existing_objects_matched().expect_partial()
+ if isinstance(function_keys, list):
+ if len(function_keys) == 1:
+ function_keys = {
+ function_keys[0]: tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ }
+ else:
+ raise ValueError(
+ 'If the function_keys is a list, it must contain a single element. %s'
+ % function_keys)
+
+ signatures = export_module.get_inference_signatures(function_keys)
+ if timestamped:
+ export_dir = get_timestamped_export_dir(export_savedmodel_dir).decode(
+ 'utf-8')
+ else:
+ export_dir = export_savedmodel_dir
+ tf.saved_model.save(
+ export_module, export_dir, signatures=signatures, options=save_options)
+ return export_dir
+
+
+def get_timestamped_export_dir(export_dir_base):
+ """Builds a path to a new subdirectory within the base directory.
+
+ Args:
+ export_dir_base: A string containing a directory to write the exported graph
+ and checkpoints.
+
+ Returns:
+ The full path of the new subdirectory (which is not actually created yet).
+
+ Raises:
+ RuntimeError: if repeated attempts fail to obtain a unique timestamped
+ directory name.
+ """
+ attempts = 0
+ while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
+ timestamp = int(time.time())
+
+ result_dir = tf.io.gfile.join(
+ tf.compat.as_bytes(export_dir_base), tf.compat.as_bytes(str(timestamp)))
+ if not tf.io.gfile.exists(result_dir):
+ # Collisions are still possible (though extremely unlikely): this
+ # directory is not actually created yet, but it will be almost
+ # instantly on return from this function.
+ return result_dir
+ time.sleep(1)
+ attempts += 1
+ logging.warning('Directory %s already exists; retrying (attempt %s/%s)',
+ str(result_dir), attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)
+ raise RuntimeError('Failed to obtain a unique export directory name after '
+ f'{MAX_DIRECTORY_CREATION_ATTEMPTS} attempts.')
diff --git a/modeling/official/core/export_base_test.py b/modeling/official/core/export_base_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f968735614f74646c379ddcf4c4dc06c5d1fb077
--- /dev/null
+++ b/modeling/official/core/export_base_test.py
@@ -0,0 +1,133 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for official.core.export_base."""
+import os
+from typing import Any, Dict, Mapping, Text
+
+import tensorflow as tf, tf_keras
+
+from official.core import export_base
+
+
+class TestModule(export_base.ExportModule):
+
+ @tf.function
+ def serve(self, inputs: tf.Tensor) -> Mapping[Text, tf.Tensor]:
+ x = inputs if self.preprocessor is None else self.preprocessor(
+ inputs=inputs)
+ x = self.inference_step(x)
+ x = self.postprocessor(x) if self.postprocessor else x
+ return {'outputs': x}
+
+ def get_inference_signatures(
+ self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
+ input_signature = tf.TensorSpec(shape=[None, None], dtype=tf.float32)
+ return {'foo': self.serve.get_concrete_function(input_signature)}
+
+
+class ExportBaseTest(tf.test.TestCase):
+
+ def test_export_module(self):
+ tmp_dir = self.get_temp_dir()
+ model = tf_keras.layers.Dense(2)
+ inputs = tf.ones([2, 4], tf.float32)
+ expected_output = model(inputs, training=False)
+ module = TestModule(params=None, model=model)
+ ckpt_path = tf.train.Checkpoint(model=model).save(
+ os.path.join(tmp_dir, 'ckpt'))
+ export_dir = export_base.export(
+ module, ['foo'],
+ export_savedmodel_dir=tmp_dir,
+ checkpoint_path=ckpt_path,
+ timestamped=True)
+ self.assertTrue(os.path.exists(os.path.join(export_dir, 'saved_model.pb')))
+ self.assertTrue(
+ os.path.exists(
+ os.path.join(export_dir, 'variables', 'variables.index')))
+ self.assertTrue(
+ os.path.exists(
+ os.path.join(export_dir, 'variables',
+ 'variables.data-00000-of-00001')))
+
+ imported = tf.saved_model.load(export_dir)
+ output = imported.signatures['foo'](inputs)
+ self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())
+
+ def test_custom_inference_step(self):
+ tmp_dir = self.get_temp_dir()
+ model = tf_keras.layers.Dense(2)
+ inputs = tf.ones([2, 4], tf.float32)
+
+ def _inference_step(inputs, model):
+ return tf.nn.softmax(model(inputs, training=False))
+
+ module = TestModule(
+ params=None, model=model, inference_step=_inference_step)
+ expected_output = _inference_step(inputs, model)
+ ckpt_path = tf.train.Checkpoint(model=model).save(
+ os.path.join(tmp_dir, 'ckpt'))
+ export_dir = export_base.export(
+ module, ['foo'],
+ export_savedmodel_dir=tmp_dir,
+ checkpoint_path=ckpt_path,
+ timestamped=False)
+ imported = tf.saved_model.load(export_dir)
+ output = imported.signatures['foo'](inputs)
+ self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())
+
+ def test_processors(self):
+ model = tf.Module()
+ inputs = tf.zeros((), tf.float32)
+
+ def _inference_step(inputs, model):
+ del model
+ return inputs + 1.0
+
+ def _preprocessor(inputs):
+ print(inputs)
+ return inputs + 0.1
+
+ module = TestModule(
+ params=None,
+ model=model,
+ inference_step=_inference_step,
+ preprocessor=_preprocessor)
+ output = module.serve(inputs)
+ self.assertAllClose(output['outputs'].numpy(), 1.1)
+
+ class _PostProcessor(tf.Module):
+
+ def __call__(self, inputs):
+ return inputs + 0.01
+
+ module = TestModule(
+ params=None,
+ model=model,
+ inference_step=_inference_step,
+ preprocessor=_preprocessor,
+ postprocessor=_PostProcessor())
+ output = module.serve(inputs)
+ self.assertAllClose(output['outputs'].numpy(), 1.11)
+
+ def test_get_timestamped_export_dir(self):
+ export_dir = self.get_temp_dir()
+ timed_dir = export_base.get_timestamped_export_dir(
+ export_dir_base=export_dir)
+ self.assertFalse(tf.io.gfile.exists(timed_dir))
+ self.assertIn(export_dir, str(timed_dir))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/core/file_writers.py b/modeling/official/core/file_writers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a23d5d2e8addb51e32872f36cacb7ffc8fb556dc
--- /dev/null
+++ b/modeling/official/core/file_writers.py
@@ -0,0 +1,80 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""File writer functions for dataset preparation, infra validation, and unit tests."""
+
+import io
+from typing import Optional, Sequence, Union
+
+import tensorflow as tf, tf_keras
+
+
+def write_small_dataset(examples: Sequence[Union[tf.train.Example,
+ tf.train.SequenceExample]],
+ output_path: str,
+ file_type: str = 'tfrecord') -> None:
+ """Writes `examples` to a file at `output_path` with type `file_type`.
+
+ CAVEAT: This function is not recommended for writing large datasets, since it
+ will loop through `examples` and perform write operation sequentially.
+
+ Args:
+ examples: List of tf.train.Example or tf.train.SequenceExample.
+ output_path: Output path for the dataset.
+ file_type: A string indicating the file format, could be: 'tfrecord',
+ 'tfrecords', 'tfrecord_compressed', 'tfrecords_gzip', 'riegeli'. The
+ string is case insensitive.
+ """
+ file_type = file_type.lower()
+
+ if file_type == 'tfrecord' or file_type == 'tfrecords':
+ _write_tfrecord(examples, output_path)
+ elif file_type == 'tfrecord_compressed' or file_type == 'tfrecords_gzip':
+ _write_tfrecord(examples, output_path,
+ tf.io.TFRecordOptions(compression_type='GZIP'))
+ elif file_type == 'riegeli':
+ _write_riegeli(examples, output_path)
+ else:
+ raise ValueError(f'Unknown file_type: {file_type}')
+
+
+def _write_tfrecord(examples: Sequence[Union[tf.train.Example,
+ tf.train.SequenceExample]],
+ output_path: str,
+ options: Optional[tf.io.TFRecordOptions] = None) -> None:
+ """Writes `examples` to a TFRecord file at `output_path`.
+
+ Args:
+ examples: A list of tf.train.Example.
+ output_path: Output path for the dataset.
+ options: Options used for manipulating TFRecord files.
+ """
+ with tf.io.TFRecordWriter(output_path, options) as writer:
+ for example in examples:
+ writer.write(example.SerializeToString())
+
+
+def _write_riegeli(examples: Sequence[Union[tf.train.Example,
+ tf.train.SequenceExample]],
+ output_path: str) -> None:
+ """Writes `examples` to a Riegeli file at `output_path`.
+
+ Args:
+ examples: A list of tf.train.Example.
+ output_path: Output path for the dataset.
+ """
+ with io.FileIO(output_path, 'wb') as fileio:
+ import riegeli # pylint: disable=g-import-not-at-top
+ with riegeli.RecordWriter(fileio) as writer:
+ writer.write_messages(examples)
diff --git a/modeling/official/core/file_writers_test.py b/modeling/official/core/file_writers_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..50c472bbd5cf15da7a816f16c897f2aca773138d
--- /dev/null
+++ b/modeling/official/core/file_writers_test.py
@@ -0,0 +1,53 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for file_writers."""
+
+import os
+from absl.testing import parameterized
+import tensorflow as tf, tf_keras
+
+from official.core import file_writers
+from official.core import tf_example_builder
+
+
+class FileWritersTest(tf.test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ example_builder = tf_example_builder.TfExampleBuilder()
+ example_builder.add_bytes_feature('foo', 'Hello World!')
+ self._example = example_builder.example
+
+ @parameterized.parameters('tfrecord', 'TFRecord', 'tfrecords',
+ 'tfrecord_compressed', 'TFRecord_Compressed',
+ 'tfrecords_gzip')
+ def test_write_small_dataset_success(self, file_type):
+ temp_dir = self.create_tempdir()
+ temp_dataset_file = os.path.join(temp_dir.full_path, 'train')
+ file_writers.write_small_dataset([self._example], temp_dataset_file,
+ file_type)
+ self.assertTrue(os.path.exists(temp_dataset_file))
+
+ def test_write_small_dataset_unrecognized_format(self):
+ file_type = 'bar'
+ temp_dir = self.create_tempdir()
+ temp_dataset_file = os.path.join(temp_dir.full_path, 'train')
+ with self.assertRaises(ValueError):
+ file_writers.write_small_dataset([self._example], temp_dataset_file,
+ file_type)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/core/input_reader.py b/modeling/official/core/input_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..356779e5e5afd8b52f5561ccb1faa747c71be2b5
--- /dev/null
+++ b/modeling/official/core/input_reader.py
@@ -0,0 +1,591 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A common dataset reader."""
+import dataclasses
+import random
+from typing import Any, Callable, Dict, List, Optional, Sequence, Text, Union
+
+from absl import logging
+import tensorflow as tf, tf_keras
+import tensorflow_datasets as tfds
+
+from official.core import config_definitions as cfg
+
+
+def _get_random_integer():
+ return random.randint(0, (1 << 31) - 1)
+
+
+def _maybe_map_fn(dataset: tf.data.Dataset,
+ fn: Optional[Callable[..., Any]] = None) -> tf.data.Dataset:
+ """Calls dataset.map if a valid function is passed in."""
+ return dataset if fn is None else dataset.map(
+ fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+
+def match_files(input_path: Union[Sequence[str], str]) -> List[str]:
+ """Matches files from an input_path."""
+ matched_files = []
+ # Read dataset from files.
+ usage = ('`input_path` should be either (1) a str indicating a file '
+ 'path/pattern, or (2) a str indicating multiple file '
+ 'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
+ '"a,b,c", or (3) a list of str, each of which is a file '
+ 'path/pattern or multiple file paths/patterns separated by '
+ 'comma, but got: %s')
+ if isinstance(input_path, str):
+ input_path_list = [input_path]
+ elif isinstance(input_path, (list, tuple)):
+ if any(not isinstance(x, str) for x in input_path):
+ raise ValueError(usage % input_path)
+ input_path_list = input_path
+ else:
+ raise ValueError(usage % input_path)
+
+ for input_path in input_path_list:
+ input_patterns = input_path.strip().split(',')
+ for input_pattern in input_patterns:
+ input_pattern = input_pattern.strip()
+ if not input_pattern:
+ continue
+ if '*' in input_pattern or '?' in input_pattern:
+ tmp_matched_files = tf.io.gfile.glob(input_pattern)
+ if not tmp_matched_files:
+ raise ValueError('%s does not match any files.' % input_pattern)
+ matched_files.extend(tmp_matched_files)
+ else:
+ matched_files.append(input_pattern)
+
+ if not matched_files:
+ raise ValueError('%s does not match any files.' % input_path)
+
+ return matched_files
+
+
+def _read_files_then_shard(matched_files: List[str],
+ dataset_fn,
+ input_context: Optional[
+ tf.distribute.InputContext] = None,
+ sharding: bool = False,
+ repeat: bool = False) -> tf.data.Dataset:
+ """Sends all data files to every worker and then shard by data."""
+ dataset = dataset_fn(matched_files)
+
+ # When `input_file` is a path to a single file or the number of files is
+ # less than the number of input pipelines, disable auto sharding
+ # so that same input file is sent to all workers.
+ options = tf.data.Options()
+ options.experimental_distribute.auto_shard_policy = (
+ tf.data.experimental.AutoShardPolicy.OFF)
+ dataset = dataset.with_options(options)
+ # Do not enable sharding if tf.data service is enabled, as sharding will be
+ # handled inside tf.data service.
+ if sharding and input_context and (input_context.num_input_pipelines > 1):
+ dataset = dataset.shard(input_context.num_input_pipelines,
+ input_context.input_pipeline_id)
+
+ if repeat:
+ dataset = dataset.repeat()
+ return dataset
+
+
+def _shard_files_then_read(matched_files: List[str],
+ dataset_fn,
+ input_context: Optional[
+ tf.distribute.InputContext] = None,
+ seed: Optional[Union[int, tf.Tensor]] = None,
+ is_training: bool = False,
+ sharding: bool = False,
+ cache: bool = False,
+ cycle_length: Optional[int] = None,
+ block_length: Optional[int] = None,
+ deterministic: bool = False) -> tf.data.Dataset:
+ """Shards the data files and then sent a split to every worker to read."""
+ dataset = tf.data.Dataset.from_tensor_slices(matched_files)
+
+ # Shuffle and repeat at file level.
+ # If cache is enabled, `reshuffle_each_iteration` is set to False,
+ # because we will read the same cached data in every iteration anyway.
+ if is_training:
+ # We need a seed to shuffle the files so that when each TPU workers gets
+ # its own shard the files do not overlap.
+ if sharding and seed is None:
+ seed = _get_random_integer()
+ dataset = dataset.shuffle(
+ len(matched_files),
+ seed=seed,
+ reshuffle_each_iteration=True if not cache else False)
+
+ # Do not enable sharding if tf.data service is enabled, as sharding will be
+ # handled inside tf.data service.
+ if sharding and input_context and (input_context.num_input_pipelines > 1):
+ dataset = dataset.shard(input_context.num_input_pipelines,
+ input_context.input_pipeline_id)
+
+ # If cache is enabled, we will call `repeat()` later after `cache()`.
+ if is_training and not cache:
+ dataset = dataset.repeat()
+
+ dataset = dataset.interleave(
+ map_func=dataset_fn,
+ cycle_length=cycle_length,
+ block_length=block_length,
+ num_parallel_calls=(cycle_length
+ if cycle_length else tf.data.experimental.AUTOTUNE),
+ deterministic=deterministic)
+ return dataset
+
+
+def _read_tfds(tfds_name: Text,
+ tfds_data_dir: Text,
+ tfds_split: Text,
+ tfds_skip_decoding_feature: Text,
+ tfds_as_supervised: bool,
+ input_context: Optional[tf.distribute.InputContext] = None,
+ seed: Optional[Union[int, tf.Tensor]] = None,
+ is_training: bool = False,
+ cache: bool = False,
+ cycle_length: Optional[int] = None,
+ block_length: Optional[int] = None) -> tf.data.Dataset:
+ """Reads a dataset from tfds."""
+ repeat_filenames = is_training and not cache
+ read_config = tfds.ReadConfig(
+ interleave_cycle_length=cycle_length,
+ interleave_block_length=block_length,
+ input_context=input_context,
+ shuffle_seed=seed,
+ repeat_filenames=repeat_filenames,
+ # Only assert cardinality when we have a finite dataset.
+ assert_cardinality=not repeat_filenames,
+ skip_prefetch=True)
+
+ decoders = {}
+ if tfds_skip_decoding_feature:
+ for skip_feature in tfds_skip_decoding_feature.split(','):
+ decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
+
+ if tfds_name.startswith('mldataset.'):
+ dataset = tfds.load(name=tfds_name,
+ split=tfds_split,
+ as_supervised=tfds_as_supervised,
+ decoders=decoders if decoders else None,
+ read_config=read_config)
+ else:
+ builder = tfds.builder(tfds_name, data_dir=tfds_data_dir)
+ if builder.info.splits:
+ num_shards = len(builder.info.splits[tfds_split].file_instructions)
+ else:
+ # The tfds mock path often does not provide splits.
+ num_shards = 1
+ load_kwargs = dict(
+ name=tfds_name, download=True, split=tfds_split,
+ shuffle_files=is_training, as_supervised=tfds_as_supervised,
+ decoders=decoders if decoders else None)
+ if tfds_data_dir:
+ load_kwargs.update({'data_dir': tfds_data_dir})
+
+ if input_context and num_shards < input_context.num_input_pipelines:
+ # The number of files in the dataset split is smaller than the number of
+ # input pipelines. We read the entire dataset first and then shard in the
+ # host memory.
+ read_config = dataclasses.replace(read_config, input_context=None)
+ load_kwargs.update({'read_config': read_config})
+ dataset = tfds.load(**load_kwargs)
+ dataset = dataset.shard(input_context.num_input_pipelines,
+ input_context.input_pipeline_id)
+ else:
+ load_kwargs.update({'read_config': read_config})
+ dataset = tfds.load(**load_kwargs)
+ return dataset
+
+
+class InputReader:
+ """Input reader that returns a tf.data.Dataset instance."""
+
+ # A static random number which is the same across different InputReader
+ # instances.
+ static_randnum = _get_random_integer()
+
+ def __init__(
+ self,
+ params: cfg.DataConfig,
+ dataset_fn=tf.data.TFRecordDataset,
+ decoder_fn: Optional[Callable[..., Any]] = None,
+ combine_fn: Optional[Callable[..., Any]] = None,
+ sample_fn: Optional[Callable[..., Any]] = None,
+ parser_fn: Optional[Callable[..., Any]] = None,
+ filter_fn: Optional[Callable[..., tf.Tensor]] = None,
+ transform_and_batch_fn: Optional[
+ Callable[
+ [tf.data.Dataset, Optional[tf.distribute.InputContext]],
+ tf.data.Dataset,
+ ]
+ ] = None,
+ postprocess_fn: Optional[Callable[..., Any]] = None,
+ ):
+ """Initializes an InputReader instance.
+
+ Args:
+ params: A config_definitions.DataConfig object.
+ dataset_fn: A `tf.data.Dataset` that consumes the input files. For
+ example, it can be `tf.data.TFRecordDataset`.
+ decoder_fn: An optional `callable` that takes the serialized data string
+ and decodes them into the raw tensor dictionary.
+ combine_fn: An optional `callable` that takes a dictionarty of
+ `tf.data.Dataset` objects as input and outputs a combined dataset. It
+ will be executed after the decoder_fn and before the sample_fn.
+ sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
+ input and outputs the transformed dataset. It performs sampling on the
+ decoded raw tensors dict before the parser_fn.
+ parser_fn: An optional `callable` that takes the decoded raw tensors dict
+ and parse them into a dictionary of tensors that can be consumed by the
+ model. It will be executed after decoder_fn.
+ filter_fn: An optional `callable` mapping a dataset element to a boolean.
+ It will be executed after parser_fn.
+ transform_and_batch_fn: An optional `callable` that takes a
+ `tf.data.Dataset` object and an optional `tf.distribute.InputContext` as
+ input, and returns a `tf.data.Dataset` object. It will be executed after
+ `parser_fn` to transform and batch the dataset; if None, after
+ `parser_fn` is executed, the dataset will be batched into per-replica
+ batch size.
+ postprocess_fn: A optional `callable` that processes batched tensors. It
+ will be executed after batching.
+ """
+ if params.input_path and params.tfds_name:
+ raise ValueError('At most one of `input_path` and `tfds_name` can be '
+ 'specified, but got %s and %s.' %
+ (params.input_path, params.tfds_name))
+
+ if (isinstance(params.input_path, cfg.base_config.Config) or
+ isinstance(params.tfds_name, cfg.base_config.Config)
+ ) and combine_fn is None:
+ raise ValueError(
+ 'A combine_fn is required if `input_path` or `tfds_name` is a dict.')
+
+ self._tfds_name = params.tfds_name
+ self._tfds_data_dir = params.tfds_data_dir
+ self._matched_files = None
+ if not params.input_path:
+ # Read dataset from TFDS.
+ if not params.tfds_split:
+ raise ValueError(
+ '`tfds_name` is %s, but `tfds_split` is not specified.' %
+ params.tfds_name)
+ else:
+ self._matched_files = self.get_files(params.input_path)
+
+ self._global_batch_size = params.global_batch_size
+ self._is_training = params.is_training
+ self._drop_remainder = params.drop_remainder
+ self._shuffle_buffer_size = params.shuffle_buffer_size
+ self._cache = params.cache
+ self._cycle_length = params.cycle_length
+ self._block_length = params.block_length
+ self._deterministic = params.deterministic
+ self._sharding = params.sharding
+ self._tfds_split = params.tfds_split
+ self._tfds_as_supervised = params.tfds_as_supervised
+ self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature
+
+ self._dataset_fn = dataset_fn
+ self._decoder_fn = decoder_fn
+ self._combine_fn = combine_fn
+ self._sample_fn = sample_fn
+ self._parser_fn = parser_fn
+ self._transform_and_batch_fn = transform_and_batch_fn
+ self._postprocess_fn = postprocess_fn
+ self._filter_fn = filter_fn
+ self._seed = params.seed
+ self._prefetch_buffer_size = (
+ params.prefetch_buffer_size or tf.data.experimental.AUTOTUNE)
+ self._autotune_algorithm = params.autotune_algorithm
+
+ # When tf.data service is enabled, each data service worker should get
+ # different random seeds. Thus, we set `seed` to None.
+ # Sharding should also be disabled because tf data service handles how
+ # each worker shard data with `processing_mode` in distribute method.
+ if params.enable_tf_data_service:
+ self._seed = None
+ self._sharding = False
+
+ self._enable_tf_data_service = (
+ params.enable_tf_data_service and params.tf_data_service_address)
+ self._tf_data_service_address = params.tf_data_service_address
+ self._enable_shared_tf_data_service_between_parallel_trainers = (
+ params.enable_shared_tf_data_service_between_parallel_trainers)
+ self._apply_tf_data_service_before_batching = (
+ params.apply_tf_data_service_before_batching)
+ self._trainer_id = params.trainer_id
+ if self._enable_tf_data_service:
+ # Add a random seed as the tf.data service job name suffix, so tf.data
+ # service doesn't reuse the previous state if TPU worker gets preempted.
+ # It's necessary to add global batch size into the tf data service job
+ # name because when tuning batch size with vizier and tf data service is
+ # also enable, the tf data servce job name should be different for
+ # different vizier trials since once batch size is changed, from the
+ # tf.data perspective, the dataset is a different instance, and a
+ # different job name should be used for tf data service. Otherwise, the
+ # model would read tensors from the incorrect tf data service job, which
+ # would causes dimension mismatch on the batch size dimension.
+ self._tf_data_service_job_name = (
+ f'{params.tf_data_service_job_name}_bs{params.global_batch_size}_'
+ f'{self.static_randnum}')
+ self._enable_round_robin_tf_data_service = params.get(
+ 'enable_round_robin_tf_data_service', False)
+ if self._enable_shared_tf_data_service_between_parallel_trainers:
+ # When shared tf.data service is enabled, only a single tf.data service
+ # instance should be created and shared between parallel trainers. If
+ # the global batch size is different across trainers,
+ # params.apply_tf_data_service_before_batching should be set to true
+ # because tf.data service with different batch sizes will be considered
+ # separate tf.data service instances.
+ self._tf_data_service_job_name = (
+ f'{params.tf_data_service_job_name}_{self.static_randnum}')
+
+ def get_files(self, input_path):
+ """Gets matched files. Can be overridden by subclasses."""
+ if not input_path:
+ return None
+ # we want to combine / mix datasets
+ if isinstance(input_path, cfg.base_config.Config):
+ matched_files = {}
+ for k, v in input_path.as_dict().items():
+ matched_files[k] = match_files(v)
+ # single dataset
+ else:
+ matched_files = match_files(input_path)
+ return matched_files
+
+ def _read_data_source(
+ self,
+ matched_files: Union[Dict[str, List[str]], List[str]],
+ dataset_fn,
+ input_context: Optional[tf.distribute.InputContext] = None,
+ ):
+ """Reads the data source (files/tfds) to a dataset."""
+
+ def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
+ if len(files) > 1:
+ if input_context and (len(files) < input_context.num_input_pipelines):
+ logging.warn(
+ (
+ 'The number of files %d is less than the number of input '
+ 'pipelines %d. We will send all input files to every worker. '
+ 'Please consider sharding your data into more files.'
+ ),
+ len(files),
+ input_context.num_input_pipelines,
+ )
+ return _read_files_then_shard(
+ files,
+ dataset_fn,
+ input_context,
+ sharding=self._sharding,
+ repeat=self._is_training and not self._cache)
+ else:
+ return _shard_files_then_read(
+ files,
+ dataset_fn,
+ input_context,
+ seed=self._seed,
+ is_training=self._is_training,
+ sharding=self._sharding,
+ cache=self._cache,
+ cycle_length=self._cycle_length,
+ block_length=self._block_length,
+ deterministic=self._deterministic)
+ elif len(files) == 1:
+ return _read_files_then_shard(
+ files,
+ dataset_fn,
+ input_context,
+ sharding=self._sharding,
+ repeat=self._is_training and not self._cache)
+ else:
+ raise ValueError('It is unexpected that `tfds_builder` is None and '
+ 'there is also no `files`.')
+
+ if self._tfds_name:
+ if isinstance(self._tfds_name, cfg.base_config.Config):
+ dataset = {}
+ for k, tfds_name in self._tfds_name.as_dict().items():
+ dataset[k] = _read_tfds(
+ tfds_name=tfds_name,
+ tfds_data_dir=self._tfds_data_dir,
+ tfds_split=self._tfds_split,
+ tfds_skip_decoding_feature=self._tfds_skip_decoding_feature,
+ tfds_as_supervised=self._tfds_as_supervised,
+ input_context=input_context,
+ seed=self._seed,
+ is_training=self._is_training,
+ cache=self._cache,
+ cycle_length=self._cycle_length,
+ block_length=self._block_length)
+ else:
+ dataset = _read_tfds(
+ tfds_name=self._tfds_name,
+ tfds_data_dir=self._tfds_data_dir,
+ tfds_split=self._tfds_split,
+ tfds_skip_decoding_feature=self._tfds_skip_decoding_feature,
+ tfds_as_supervised=self._tfds_as_supervised,
+ input_context=input_context,
+ seed=self._seed,
+ is_training=self._is_training,
+ cache=self._cache,
+ cycle_length=self._cycle_length,
+ block_length=self._block_length)
+ elif isinstance(matched_files, (list, tuple)):
+ dataset = _files_to_dataset(matched_files)
+ elif isinstance(matched_files, dict):
+ dataset = {}
+ for k, fs in matched_files.items():
+ dataset[k] = _files_to_dataset(fs)
+ else:
+ raise ValueError('`matched_files` should be a list or dict.')
+
+ return dataset
+
+ def _decode_and_parse_dataset(
+ self,
+ dataset: Union[tf.data.Dataset, Dict[Text, tf.data.Dataset]],
+ batch_size: int,
+ input_context: Optional[tf.distribute.InputContext] = None
+ ) -> tf.data.Dataset:
+ """Returns a tf.data.Dataset object after shuffling, decoding, and parsing."""
+
+ def _shuffle_and_decode(ds):
+ # If cache is enabled, we will call `shuffle()` later after `cache()`.
+ if self._is_training and not self._cache:
+ ds = ds.shuffle(self._shuffle_buffer_size, seed=self._seed)
+ # Decode
+ ds = _maybe_map_fn(ds, self._decoder_fn)
+ return ds
+
+ dataset = tf.nest.map_structure(_shuffle_and_decode, dataset)
+ if tf.nest.is_nested(dataset):
+ dataset = self._combine_fn(dataset)
+
+ if self._sample_fn is not None:
+ dataset = dataset.apply(self._sample_fn)
+ dataset = _maybe_map_fn(dataset, self._parser_fn)
+
+ if self._filter_fn is not None:
+ dataset = dataset.filter(self._filter_fn)
+
+ if self._cache:
+ dataset = dataset.cache()
+ if self._is_training:
+ dataset = dataset.repeat()
+ dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
+
+ # Applies tf.data service before batching operations. This is useful when
+ # tf.data service is shared between parallel trainers, and batch size is
+ # changing between parallel trainers. Then batch size is changing, tf.data
+ # services will be considered different instances if applied after batching
+ # operations, which make it difficult to share between parallel trainers.
+ # However, if there are additional expensive operations in
+ # self._transform_and_batch_fn and self._postprocess_fn, the entire tf.data
+ # pipeline could be slowed down. In this case, try to move these dataset
+ # operations into early stages if possible.
+ if (self._enable_shared_tf_data_service_between_parallel_trainers and
+ self._apply_tf_data_service_before_batching):
+ dataset = self._maybe_apply_data_service(dataset, input_context)
+
+ if self._transform_and_batch_fn is not None:
+ dataset = self._transform_and_batch_fn(dataset, input_context)
+ else:
+ per_replica_batch_size = input_context.get_per_replica_batch_size(
+ batch_size) if input_context else batch_size
+ dataset = dataset.batch(
+ per_replica_batch_size, drop_remainder=self._drop_remainder)
+
+ return dataset
+
+ def _maybe_apply_data_service(
+ self,
+ dataset: tf.data.Dataset,
+ input_context: Optional[tf.distribute.InputContext] = None
+ ) -> tf.data.Dataset:
+ """Potentially distributes a dataset."""
+ if self._enable_tf_data_service and input_context:
+ if self._enable_round_robin_tf_data_service:
+ replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
+ input_context.num_input_pipelines)
+ base_consumer_index = input_context.input_pipeline_id * (
+ replicas_per_input_pipeline)
+ num_consumers = input_context.num_input_pipelines * (
+ replicas_per_input_pipeline)
+ range_dataset = tf.data.Dataset.range(replicas_per_input_pipeline)
+ tfds_kwargs = {
+ 'processing_mode': 'parallel_epochs',
+ 'service': self._tf_data_service_address,
+ 'job_name': self._tf_data_service_job_name,
+ 'num_consumers': num_consumers
+ }
+ if self._enable_shared_tf_data_service_between_parallel_trainers:
+ raise ValueError('Shared tf.data service does not support round-robin'
+ ' tf.data service.')
+ dataset = range_dataset.map(lambda i: dataset.apply( # pylint: disable=g-long-lambda
+ tf.data.experimental.service.distribute(
+ consumer_index=base_consumer_index + i, **tfds_kwargs)))
+ # Use parallel interleave to read multiple batches from a tf.data
+ # service worker in parallel.
+ dataset = dataset.interleave(
+ lambda x: x,
+ cycle_length=replicas_per_input_pipeline,
+ num_parallel_calls=replicas_per_input_pipeline,
+ deterministic=True)
+ else:
+ tfds_kwargs = {
+ 'processing_mode': 'parallel_epochs',
+ 'service': self._tf_data_service_address,
+ 'job_name': self._tf_data_service_job_name,
+ }
+ if self._enable_shared_tf_data_service_between_parallel_trainers:
+ tfds_kwargs.update({
+ 'processing_mode':
+ tf.data.experimental.service.ShardingPolicy.OFF,
+ 'cross_trainer_cache':
+ tf.data.experimental.service.CrossTrainerCache(
+ trainer_id=self._trainer_id)
+ })
+ dataset = dataset.apply(
+ tf.data.experimental.service.distribute(**tfds_kwargs))
+ return dataset
+
+ def read(self,
+ input_context: Optional[tf.distribute.InputContext] = None,
+ dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
+ """Generates a tf.data.Dataset object."""
+ if dataset is None:
+ dataset = self._read_data_source(self._matched_files, self._dataset_fn,
+ input_context)
+ dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
+ input_context)
+ dataset = _maybe_map_fn(dataset, self._postprocess_fn)
+ if not (self._enable_shared_tf_data_service_between_parallel_trainers and
+ self._apply_tf_data_service_before_batching):
+ dataset = self._maybe_apply_data_service(dataset, input_context)
+
+ if self._deterministic is not None:
+ options = tf.data.Options()
+ options.deterministic = self._deterministic
+ dataset = dataset.with_options(options)
+ if self._autotune_algorithm:
+ options = tf.data.Options()
+ options.autotune.autotune_algorithm = (
+ tf.data.experimental.AutotuneAlgorithm[self._autotune_algorithm])
+ dataset = dataset.with_options(options)
+ return dataset.prefetch(self._prefetch_buffer_size)
diff --git a/modeling/official/core/registry.py b/modeling/official/core/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..de3e0f1c814fe127d1df180de29e8641423cb3d0
--- /dev/null
+++ b/modeling/official/core/registry.py
@@ -0,0 +1,101 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Registry utility."""
+
+
+def register(registered_collection, reg_key):
+ """Register decorated function or class to collection.
+
+ Register decorated function or class into registered_collection, in a
+ hierarchical order. For example, when reg_key="my_model/my_exp/my_config_0"
+ the decorated function or class is stored under
+ registered_collection["my_model"]["my_exp"]["my_config_0"].
+ This decorator is supposed to be used together with the lookup() function in
+ this file.
+
+ Args:
+ registered_collection: a dictionary. The decorated function or class will be
+ put into this collection.
+ reg_key: The key for retrieving the registered function or class. If reg_key
+ is a string, it can be hierarchical like my_model/my_exp/my_config_0
+ Returns:
+ A decorator function
+ Raises:
+ KeyError: when function or class to register already exists.
+ """
+ def decorator(fn_or_cls):
+ """Put fn_or_cls in the dictionary."""
+ if isinstance(reg_key, str):
+ hierarchy = reg_key.split("/")
+ collection = registered_collection
+ for h_idx, entry_name in enumerate(hierarchy[:-1]):
+ if entry_name not in collection:
+ collection[entry_name] = {}
+ collection = collection[entry_name]
+ if not isinstance(collection, dict):
+ raise KeyError(
+ "Collection path {} at position {} already registered as "
+ "a function or class.".format(entry_name, h_idx))
+ leaf_reg_key = hierarchy[-1]
+ else:
+ collection = registered_collection
+ leaf_reg_key = reg_key
+
+ if leaf_reg_key in collection:
+ raise KeyError("Function or class {} registered multiple times.".format(
+ leaf_reg_key))
+
+ collection[leaf_reg_key] = fn_or_cls
+ return fn_or_cls
+ return decorator
+
+
+def lookup(registered_collection, reg_key):
+ """Lookup and return decorated function or class in the collection.
+
+ Lookup decorated function or class in registered_collection, in a
+ hierarchical order. For example, when
+ reg_key="my_model/my_exp/my_config_0",
+ this function will return
+ registered_collection["my_model"]["my_exp"]["my_config_0"].
+
+ Args:
+ registered_collection: a dictionary. The decorated function or class will be
+ retrieved from this collection.
+ reg_key: The key for retrieving the registered function or class. If reg_key
+ is a string, it can be hierarchical like my_model/my_exp/my_config_0
+ Returns:
+ The registered function or class.
+ Raises:
+ LookupError: when reg_key cannot be found.
+ """
+ if isinstance(reg_key, str):
+ hierarchy = reg_key.split("/")
+ collection = registered_collection
+ for h_idx, entry_name in enumerate(hierarchy):
+ if entry_name not in collection:
+ raise LookupError(
+ f"collection path {entry_name} at position {h_idx} is never "
+ f"registered. Please make sure the {entry_name} and its library is "
+ "imported and linked to the trainer binary.")
+ collection = collection[entry_name]
+ return collection
+ else:
+ if reg_key not in registered_collection:
+ raise LookupError(
+ f"registration key {reg_key} is never "
+ f"registered. Please make sure the {reg_key} and its library is "
+ "imported and linked to the trainer binary.")
+ return registered_collection[reg_key]
diff --git a/modeling/official/core/registry_test.py b/modeling/official/core/registry_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbae9dc32325e31ef3810beb5b38d09eb7b73b26
--- /dev/null
+++ b/modeling/official/core/registry_test.py
@@ -0,0 +1,88 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for registry."""
+
+import tensorflow as tf, tf_keras
+from official.core import registry
+
+
+class RegistryTest(tf.test.TestCase):
+
+ def test_register(self):
+ collection = {}
+
+ @registry.register(collection, 'functions/func_0')
+ def func_test():
+ pass
+
+ self.assertEqual(registry.lookup(collection, 'functions/func_0'), func_test)
+
+ @registry.register(collection, 'classes/cls_0')
+ class ClassRegistryKey:
+ pass
+
+ self.assertEqual(
+ registry.lookup(collection, 'classes/cls_0'), ClassRegistryKey)
+
+ @registry.register(collection, ClassRegistryKey)
+ class ClassRegistryValue:
+ pass
+
+ self.assertEqual(
+ registry.lookup(collection, ClassRegistryKey), ClassRegistryValue)
+
+ def test_register_hierarchy(self):
+ collection = {}
+
+ @registry.register(collection, 'functions/func_0')
+ def func_test0():
+ pass
+
+ @registry.register(collection, 'func_1')
+ def func_test1():
+ pass
+
+ @registry.register(collection, func_test1)
+ def func_test2():
+ pass
+
+ expected_collection = {
+ 'functions': {
+ 'func_0': func_test0,
+ },
+ 'func_1': func_test1,
+ func_test1: func_test2,
+ }
+ self.assertEqual(collection, expected_collection)
+
+ def test_register_error(self):
+ collection = {}
+
+ @registry.register(collection, 'functions/func_0')
+ def func_test0(): # pylint: disable=unused-variable
+ pass
+
+ with self.assertRaises(KeyError):
+
+ @registry.register(collection, 'functions/func_0/sub_func')
+ def func_test1(): # pylint: disable=unused-variable
+ pass
+
+ with self.assertRaises(LookupError):
+ registry.lookup(collection, 'non-exist')
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/core/savedmodel_checkpoint_manager.py b/modeling/official/core/savedmodel_checkpoint_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3b958b83f984626df64cb258886234bab88d491
--- /dev/null
+++ b/modeling/official/core/savedmodel_checkpoint_manager.py
@@ -0,0 +1,258 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Custom checkpoint manager that also exports saved models."""
+
+import os
+import re
+import time
+from typing import Callable, List, Mapping, Optional, Union
+
+from absl import logging
+import tensorflow as tf, tf_keras
+
+SAVED_MODULES_PATH_SUFFIX = 'saved_modules'
+
+
+def make_saved_modules_directory_name(checkpoint_name: str) -> str:
+ return f'{checkpoint_name}_{SAVED_MODULES_PATH_SUFFIX}'
+
+
+class SavedModelCheckpointManager(tf.train.CheckpointManager):
+ """A CheckpointManager that also exports `SavedModel`s."""
+
+ def __init__(self,
+ checkpoint: tf.train.Checkpoint,
+ directory: str,
+ max_to_keep: int,
+ modules_to_export: Optional[Mapping[str, tf.Module]] = None,
+ keep_checkpoint_every_n_hours: Optional[int] = None,
+ checkpoint_name: str = 'ckpt',
+ step_counter: Optional[tf.Variable] = None,
+ checkpoint_interval: Optional[int] = None,
+ init_fn: Optional[Callable[[], None]] = None):
+ """See base class."""
+ super().__init__(
+ checkpoint=checkpoint,
+ directory=directory,
+ max_to_keep=max_to_keep,
+ keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
+ checkpoint_name=checkpoint_name,
+ step_counter=step_counter,
+ checkpoint_interval=checkpoint_interval,
+ init_fn=init_fn)
+ self._modules_to_export = modules_to_export
+ self._savedmodels = self.get_existing_savedmodels()
+
+ def save(self,
+ checkpoint_number: Optional[int] = None,
+ check_interval: bool = True,
+ options: Optional[tf.train.CheckpointOptions] = None):
+ """See base class."""
+ checkpoint_path = super().save(
+ checkpoint_number=checkpoint_number,
+ check_interval=check_interval,
+ options=options)
+ if not checkpoint_path: # Nothing got written.
+ return
+ if not self._modules_to_export: # No modules to export.
+ logging.info('Skip saving SavedModel due to empty modules_to_export.')
+ return checkpoint_path
+
+ # Save the models for the checkpoint that just got written.
+ saved_modules_directory = make_saved_modules_directory_name(checkpoint_path)
+ # Atomic export of SavedModel. Write into a temporary direcotory and then
+ # rename as the final direcotory after finishing the writing.
+ # This can avoid trying to read an unfinished savedmodel.
+ saved_modules_directory_tmp = saved_modules_directory + '_temp'
+ for model_name, model in self._modules_to_export.items():
+ signatures = getattr(model, 'saved_model_signatures', None)
+ if signatures is not None:
+ tf.saved_model.save(
+ obj=model,
+ export_dir=os.path.join(saved_modules_directory_tmp, model_name),
+ signatures=signatures)
+ if tf.io.gfile.exists(saved_modules_directory_tmp):
+ tf.io.gfile.rename(saved_modules_directory_tmp, saved_modules_directory)
+
+ saved_modules_directories_to_keep = [
+ make_saved_modules_directory_name(ckpt) for ckpt in self.checkpoints
+ ]
+ existing_saved_modules_dirs = self.get_existing_savedmodels()
+
+ self._savedmodels = []
+ # Keep savedmodels in the same order as checkpoints (from oldest to newest).
+ for saved_modules_dir_to_keep in saved_modules_directories_to_keep:
+ if saved_modules_dir_to_keep in existing_saved_modules_dirs:
+ self._savedmodels.append(saved_modules_dir_to_keep)
+
+ for existing_saved_modules_dir in existing_saved_modules_dirs:
+ if existing_saved_modules_dir not in self._savedmodels:
+ tf.io.gfile.rmtree(existing_saved_modules_dir)
+
+ return checkpoint_path
+
+ def get_existing_savedmodels(self) -> List[str]:
+ """Gets a list of all existing SavedModel paths in `directory`.
+
+ Returns:
+ A list of all existing SavedModel paths.
+ """
+ saved_modules_glob = make_saved_modules_directory_name(
+ self._checkpoint_prefix + '-*')
+ savedmodels = tf.io.gfile.glob(saved_modules_glob)
+ # Filter out temporary savedmodel.
+ savedmodels = [
+ savedmodel
+ for savedmodel in savedmodels
+ if savedmodel.endswith(SAVED_MODULES_PATH_SUFFIX)
+ ]
+ return savedmodels
+
+ @property
+ def latest_savedmodel(self) -> Union[str, None]:
+ """The path of the most recent SavedModel in `directory`.
+
+ Returns:
+ The latest SavedModel path. If there are no SavedModels, returns `None`.
+ """
+ if self._savedmodels:
+ return self._savedmodels[-1]
+ return None
+
+ @property
+ def savedmodels(self) -> List[str]:
+ """A list of managed SavedModels.
+
+ Returns:
+ A list of SavedModel paths, sorted from oldest to newest.
+ """
+ return self._savedmodels
+
+ @property
+ def modules_to_export(self) -> Union[Mapping[str, tf.Module], None]:
+ return self._modules_to_export
+
+ def get_savedmodel_number_from_path(self,
+ savedmodel_path: str) -> Union[int, None]:
+ """Gets the savedmodel_number/checkpoint_number from savedmodel filepath.
+
+ The savedmodel_number is global step when using with orbit controller.
+
+ Args:
+ savedmodel_path: savedmodel directory path.
+
+ Returns:
+ Savedmodel number or None if no matched pattern found in savedmodel path.
+ """
+ pattern = rf'\d+_{SAVED_MODULES_PATH_SUFFIX}$'
+ savedmodel_number = re.search(pattern, savedmodel_path)
+ if savedmodel_number:
+ savedmodel_number = savedmodel_number.group()
+ return int(savedmodel_number[:-len(SAVED_MODULES_PATH_SUFFIX) - 1])
+ return None
+
+ def savedmodels_iterator(self,
+ min_interval_secs: float = 0,
+ timeout: Optional[float] = None,
+ timeout_fn: Optional[Callable[[], bool]] = None):
+ """Continuously yield new SavedModel files as they appear.
+
+ The iterator only checks for new savedmodels when control flow has been
+ reverted to it. The logic is same to the `train.checkpoints_iterator`.
+
+ Args:
+ min_interval_secs: The minimum number of seconds between yielding
+ savedmodels.
+ timeout: The maximum number of seconds to wait between savedmodels. If
+ left as `None`, then the process will wait indefinitely.
+ timeout_fn: Optional function to call after a timeout. If the function
+ returns True, then it means that no new savedmodels will be generated
+ and the iterator will exit. The function is called with no arguments.
+
+ Yields:
+ String paths to latest SavedModel files as they arrive.
+ """
+ savedmodel_path = None
+ while True:
+ new_savedmodel_path = self.wait_for_new_savedmodel(
+ savedmodel_path, timeout=timeout)
+ if new_savedmodel_path is None:
+ if not timeout_fn:
+ # timed out
+ logging.info('Timed-out waiting for a savedmodel.')
+ return
+ if timeout_fn():
+ # The timeout_fn indicated that we are truly done.
+ return
+ else:
+ # The timeout_fn indicated that more savedmodels may come.
+ continue
+ start = time.time()
+ savedmodel_path = new_savedmodel_path
+ yield savedmodel_path
+ time_to_next_eval = start + min_interval_secs - time.time()
+ if time_to_next_eval > 0:
+ time.sleep(time_to_next_eval)
+
+ def wait_for_new_savedmodel(
+ self,
+ last_savedmodel: Optional[str] = None,
+ seconds_to_sleep: float = 1.0,
+ timeout: Optional[float] = None) -> Union[str, None]:
+ """Waits until a new savedmodel file is found.
+
+ Args:
+ last_savedmodel: The last savedmodel path used or `None` if we're
+ expecting a savedmodel for the first time.
+ seconds_to_sleep: The number of seconds to sleep for before looking for a
+ new savedmodel.
+ timeout: The maximum number of seconds to wait. If left as `None`, then
+ the process will wait indefinitely.
+
+ Returns:
+ A new savedmodel path, or None if the timeout was reached.
+ """
+ logging.info('Waiting for new savedmodel at %s', self._directory)
+ stop_time = time.time() + timeout if timeout is not None else None
+
+ last_savedmodel_number = -1
+ if last_savedmodel:
+ last_savedmodel_number = self.get_savedmodel_number_from_path(
+ last_savedmodel)
+
+ while True:
+ if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
+ return None
+
+ existing_savedmodels = {}
+ for savedmodel_path in self.get_existing_savedmodels():
+ savedmodel_number = self.get_savedmodel_number_from_path(
+ savedmodel_path)
+ if savedmodel_number is not None:
+ existing_savedmodels[savedmodel_number] = savedmodel_path
+
+ # Find the first savedmodel with larger step number as next savedmodel.
+ savedmodel_path = None
+ existing_savedmodels = dict(sorted(existing_savedmodels.items()))
+ for savedmodel_number in existing_savedmodels:
+ if savedmodel_number > last_savedmodel_number:
+ savedmodel_path = existing_savedmodels[savedmodel_number]
+ break
+
+ if savedmodel_path:
+ logging.info('Found new savedmodel at %s', savedmodel_path)
+ return savedmodel_path
+ else:
+ time.sleep(seconds_to_sleep)
diff --git a/modeling/official/core/savedmodel_checkpoint_manager_test.py b/modeling/official/core/savedmodel_checkpoint_manager_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..56ca19d7950f379345c3c3f02a8da3da60ea13e6
--- /dev/null
+++ b/modeling/official/core/savedmodel_checkpoint_manager_test.py
@@ -0,0 +1,125 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+from typing import Iterable
+
+import tensorflow as tf, tf_keras
+
+from official.core import savedmodel_checkpoint_manager
+
+
+def _models_exist(checkpoint_path: str, models: Iterable[str]) -> bool:
+ for model_name in models:
+ if not tf.io.gfile.isdir(
+ os.path.join(
+ savedmodel_checkpoint_manager.make_saved_modules_directory_name(
+ checkpoint_path), model_name)):
+ return False
+ return True
+
+
+class _ModelForTest(tf_keras.Model):
+ def __init__(self, hidden_size: int = 8):
+ super().__init__()
+ self.dense = tf_keras.layers.Dense(hidden_size)
+
+ @tf.function(input_signature=[tf.TensorSpec([None, 16])])
+ def call(self, inputs):
+ return self.dense(inputs)
+
+ @property
+ def saved_model_signatures(self):
+ # Build SavedModel signatures.
+ return dict(serving_default=self.call)
+
+
+class CheckpointManagerTest(tf.test.TestCase):
+
+ def _create_manager(self, max_to_keep: int = 1) -> tf.train.CheckpointManager:
+ """Sets up SavedModelCheckpointManager object.
+
+ Args:
+ max_to_keep: max number of savedmodels to keep.
+
+ Returns:
+ created savedmodel manager.
+ """
+ models = {
+ 'model_1': _ModelForTest(12),
+ 'model_2': _ModelForTest(14),
+ }
+ checkpoint = tf.train.Checkpoint()
+ manager = savedmodel_checkpoint_manager.SavedModelCheckpointManager(
+ checkpoint=checkpoint,
+ directory=self.get_temp_dir(),
+ max_to_keep=max_to_keep,
+ modules_to_export=models)
+ return manager
+
+ def test_max_to_keep(self):
+ manager = self._create_manager()
+ models = manager.modules_to_export
+ first_path = manager.save()
+ second_path = manager.save()
+
+ savedmodel = savedmodel_checkpoint_manager.make_saved_modules_directory_name(
+ manager.latest_checkpoint)
+ self.assertEqual(savedmodel, manager.latest_savedmodel)
+ self.assertTrue(_models_exist(second_path, models.keys()))
+ self.assertFalse(_models_exist(first_path, models.keys()))
+
+ def test_returns_none_after_timeout(self):
+ manager = self._create_manager()
+ start = time.time()
+ ret = manager.wait_for_new_savedmodel(
+ None, timeout=1.0, seconds_to_sleep=0.5)
+ end = time.time()
+ self.assertIsNone(ret)
+ # We've waited 0.5 second.
+ self.assertGreater(end, start + 0.5)
+ # The timeout kicked in.
+ self.assertLess(end, start + 0.6)
+
+ def test_saved_model_iterator(self):
+ manager = self._create_manager(max_to_keep=2)
+ self.assertIsNotNone(manager.save(checkpoint_number=1))
+ self.assertIsNotNone(manager.save(checkpoint_number=2))
+ self.assertIsNotNone(manager.save(checkpoint_number=3))
+
+ # Savedmodels are in time order.
+ expected_savedmodels = manager.savedmodels
+ # Order not guaranteed.
+ existing_savedmodels = manager.get_existing_savedmodels()
+ savedmodels = list(manager.savedmodels_iterator(timeout=3.0))
+ self.assertEqual(savedmodels, expected_savedmodels)
+ self.assertEqual(set(savedmodels), set(existing_savedmodels))
+
+ def test_saved_model_iterator_timeout_fn(self):
+ manager = self._create_manager()
+ timeout_fn_calls = [0]
+
+ def timeout_fn():
+ timeout_fn_calls[0] += 1
+ return timeout_fn_calls[0] > 3
+
+ results = list(
+ manager.savedmodels_iterator(timeout=0.1, timeout_fn=timeout_fn))
+ self.assertEqual([], results)
+ self.assertEqual(4, timeout_fn_calls[0])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/core/task_factory.py b/modeling/official/core/task_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8833231ce32882a7d240bf2dd6eeb180994f35a
--- /dev/null
+++ b/modeling/official/core/task_factory.py
@@ -0,0 +1,70 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A global factory to register and access all registered tasks."""
+
+from official.core import registry
+
+_REGISTERED_TASK_CLS = {}
+
+
+# TODO(b/158741360): Add type annotations once pytype checks across modules.
+def register_task_cls(task_config_cls):
+ """Decorates a factory of Tasks for lookup by a subclass of TaskConfig.
+
+ This decorator supports registration of tasks as follows:
+
+ ```
+ @dataclasses.dataclass
+ class MyTaskConfig(TaskConfig):
+ # Add fields here.
+ pass
+
+ @register_task_cls(MyTaskConfig)
+ class MyTask(Task):
+ # Inherits def __init__(self, task_config).
+ pass
+
+ my_task_config = MyTaskConfig()
+ my_task = get_task(my_task_config) # Returns MyTask(my_task_config).
+ ```
+
+ Besisdes a class itself, other callables that create a Task from a TaskConfig
+ can be decorated by the result of this function, as long as there is at most
+ one registration for each config class.
+
+ Args:
+ task_config_cls: a subclass of TaskConfig (*not* an instance of TaskConfig).
+ Each task_config_cls can only be used for a single registration.
+
+ Returns:
+ A callable for use as class decorator that registers the decorated class
+ for creation from an instance of task_config_cls.
+ """
+ return registry.register(_REGISTERED_TASK_CLS, task_config_cls)
+
+
+def get_task(task_config, **kwargs):
+ """Creates a Task (of suitable subclass type) from task_config."""
+ # TODO(hongkuny): deprecate the task factory to use config.BUILDER.
+ if task_config.BUILDER is not None:
+ return task_config.BUILDER(task_config, **kwargs)
+ return get_task_cls(task_config.__class__)(task_config, **kwargs)
+
+
+# The user-visible get_task() is defined after classes have been registered.
+# TODO(b/158741360): Add type annotations once pytype checks across modules.
+def get_task_cls(task_config_cls):
+ task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
+ return task_cls
diff --git a/modeling/official/core/test_utils.py b/modeling/official/core/test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..266702e69fb12afc66cdda084e8f5bf2cdaf008b
--- /dev/null
+++ b/modeling/official/core/test_utils.py
@@ -0,0 +1,59 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils for testing."""
+
+import tensorflow as tf, tf_keras
+
+
+class FakeKerasModel(tf_keras.Model):
+ """Fake keras model for testing."""
+
+ def __init__(self):
+ super().__init__()
+ self.dense = tf_keras.layers.Dense(4, activation=tf.nn.relu)
+ self.dense2 = tf_keras.layers.Dense(4, activation=tf.nn.relu)
+
+ def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
+ return self.dense2(self.dense(inputs))
+
+
+class _Dense(tf.Module):
+ """A dense layer."""
+
+ def __init__(self, input_dim, output_size, name=None):
+ super().__init__(name=name)
+ with self.name_scope:
+ self.w = tf.Variable(
+ tf.random.normal([input_dim, output_size]), name='w')
+ self.b = tf.Variable(tf.zeros([output_size]), name='b')
+
+ @tf.Module.with_name_scope
+ def __call__(self, x):
+ y = tf.matmul(x, self.w) + self.b
+ return tf.nn.relu(y)
+
+
+class FakeModule(tf.Module):
+ """Fake model using tf.Module for testing."""
+
+ def __init__(self, input_size, name=None):
+ super().__init__(name=name)
+ with self.name_scope:
+ self.dense = _Dense(input_size, 4, name='dense')
+ self.dense2 = _Dense(4, 4, name='dense_1')
+
+ @tf.Module.with_name_scope
+ def __call__(self, x):
+ return self.dense2(self.dense(x))
diff --git a/modeling/official/core/tf_example_builder.py b/modeling/official/core/tf_example_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2be1ca2b9c30c7f88f662354052d48d12070aa30
--- /dev/null
+++ b/modeling/official/core/tf_example_builder.py
@@ -0,0 +1,144 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Builder class for preparing tf.train.Example."""
+
+# https://www.python.org/dev/peps/pep-0563/#enabling-the-future-behavior-in-python-3-7
+from __future__ import annotations
+
+from typing import Mapping, Sequence, Union
+
+import numpy as np
+import tensorflow as tf, tf_keras
+
+BytesValueType = Union[bytes, Sequence[bytes], str, Sequence[str]]
+
+_to_array = lambda v: [v] if not isinstance(v, (list, np.ndarray)) else v
+_to_bytes = lambda v: v.encode() if isinstance(v, str) else v
+_to_bytes_array = lambda v: list(map(_to_bytes, _to_array(v)))
+
+
+class TfExampleBuilder(object):
+ """Builder class for preparing tf.train.Example.
+
+ Read API doc at https://www.tensorflow.org/api_docs/python/tf/train/Example.
+
+ Example usage:
+ >>> example_builder = TfExampleBuilder()
+ >>> example = (
+ example_builder.add_bytes_feature('feature_a', 'foobarbaz')
+ .add_ints_feature('feature_b', [1, 2, 3])
+ .example)
+ """
+
+ def __init__(self) -> None:
+ self._example = tf.train.Example()
+
+ @property
+ def example(self) -> tf.train.Example:
+ """Returns a copy of the generated tf.train.Example proto."""
+ return self._example
+
+ @property
+ def serialized_example(self) -> str:
+ """Returns a serialized string of the generated tf.train.Example proto."""
+ return self._example.SerializeToString()
+
+ def set(self, example: tf.train.Example) -> TfExampleBuilder:
+ """Sets the example."""
+ self._example = example
+ return self
+
+ def reset(self) -> TfExampleBuilder:
+ """Resets the example to an empty proto."""
+ self._example = tf.train.Example()
+ return self
+
+ ###### Basic APIs for primitive data types ######
+ def add_feature_dict(
+ self, feature_dict: Mapping[str, tf.train.Feature]) -> TfExampleBuilder:
+ """Adds the predefined `feature_dict` to the example.
+
+ Note: Please prefer to using feature-type-specific methods.
+
+ Args:
+ feature_dict: A dictionary from tf.Example feature key to
+ tf.train.Feature.
+
+ Returns:
+ The builder object for subsequent method calls.
+ """
+ for k, v in feature_dict.items():
+ self._example.features.feature[k].CopyFrom(v)
+ return self
+
+ def add_feature(self, key: str,
+ feature: tf.train.Feature) -> TfExampleBuilder:
+ """Adds predefined `feature` with `key` to the example.
+
+ Args:
+ key: String key of the feature.
+ feature: The feature to be added to the example.
+
+ Returns:
+ The builder object for subsequent method calls.
+ """
+ self._example.features.feature[key].CopyFrom(feature)
+ return self
+
+ def add_bytes_feature(self, key: str,
+ value: BytesValueType) -> TfExampleBuilder:
+ """Adds byte(s) or string(s) with `key` to the example.
+
+ Args:
+ key: String key of the feature.
+ value: The byte(s) or string(s) to be added to the example.
+
+ Returns:
+ The builder object for subsequent method calls.
+ """
+ return self.add_feature(
+ key,
+ tf.train.Feature(
+ bytes_list=tf.train.BytesList(value=_to_bytes_array(value))))
+
+ def add_ints_feature(self, key: str,
+ value: Union[int, Sequence[int]]) -> TfExampleBuilder:
+ """Adds integer(s) with `key` to the example.
+
+ Args:
+ key: String key of the feature.
+ value: The integer(s) to be added to the example.
+
+ Returns:
+ The builder object for subsequent method calls.
+ """
+ return self.add_feature(
+ key,
+ tf.train.Feature(int64_list=tf.train.Int64List(value=_to_array(value))))
+
+ def add_floats_feature(
+ self, key: str, value: Union[float, Sequence[float]]) -> TfExampleBuilder:
+ """Adds float(s) with `key` to the example.
+
+ Args:
+ key: String key of the feature.
+ value: The float(s) to be added to the example.
+
+ Returns:
+ The builder object for subsequent method calls.
+ """
+ return self.add_feature(
+ key,
+ tf.train.Feature(float_list=tf.train.FloatList(value=_to_array(value))))
diff --git a/modeling/official/core/tf_example_builder_test.py b/modeling/official/core/tf_example_builder_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b26ef5443734452de17b30ce50a856d80bb6c28d
--- /dev/null
+++ b/modeling/official/core/tf_example_builder_test.py
@@ -0,0 +1,165 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for tf_example_builder.
+
+See `test_add_image_matrix_feature_with_fake_image` for the typical structure of
+a unit test.
+"""
+
+from absl.testing import parameterized
+import tensorflow as tf, tf_keras
+from official.core import tf_example_builder
+
+
+class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
+
+ def test_init_an_empty_example(self):
+ example_builder = tf_example_builder.TfExampleBuilder()
+ example = example_builder.example
+ self.assertProtoEquals('', example)
+
+ def test_init_an_empty_serialized_example(self):
+ example_builder = tf_example_builder.TfExampleBuilder()
+ example = example_builder.serialized_example
+ self.assertProtoEquals('', example)
+
+ def test_add_feature(self):
+ example_builder = tf_example_builder.TfExampleBuilder()
+ example_builder.add_feature(
+ 'foo',
+ tf.train.Feature(
+ bytes_list=tf.train.BytesList(value=[b'Hello World!'])))
+ example = example_builder.example
+ # Use proto text to show how the entire proto would look like.
+ self.assertProtoEquals(
+ """
+ features: {
+ feature: {
+ key: "foo"
+ value: {
+ bytes_list: {
+ value: "Hello World!"
+ }
+ }
+ }
+ }""", example)
+
+ def test_add_feature_dict(self):
+ example_builder = tf_example_builder.TfExampleBuilder()
+ example_builder.add_feature_dict({
+ 'foo':
+ tf.train.Feature(
+ bytes_list=tf.train.BytesList(value=[b'Hello World!'])),
+ 'bar':
+ tf.train.Feature(
+ int64_list=tf.train.Int64List(value=[299, 792, 458]))
+ })
+ example = example_builder.example
+ # Use proto text to show how the entire proto would look like.
+ self.assertProtoEquals(
+ """
+ features: {
+ feature: {
+ key: "foo"
+ value: {
+ bytes_list: {
+ value: "Hello World!"
+ }
+ }
+ }
+ feature: {
+ key: "bar"
+ value: {
+ int64_list: {
+ value: 299
+ value: 792
+ value: 458
+ }
+ }
+ }
+ }""", example)
+
+ @parameterized.named_parameters(
+ ('single_bytes', b'Hello World!', b'Hello World!'),
+ ('single_string', 'Hello World!', b'Hello World!'))
+ def test_add_single_byte_feature(self, value, expected_value):
+ example_builder = tf_example_builder.TfExampleBuilder()
+ example_builder.add_bytes_feature('foo', value)
+ example = example_builder.example
+ # Use constructor to easily work with test parameters.
+ self.assertProtoEquals(
+ tf.train.Example(
+ features=tf.train.Features(
+ feature={
+ 'foo':
+ tf.train.Feature(
+ bytes_list=tf.train.BytesList(
+ value=[expected_value]))
+ })), example)
+
+ @parameterized.named_parameters(
+ ('multiple_bytes', [b'Hello World!', b'Good Morning!'
+ ], [b'Hello World!', b'Good Morning!']),
+ ('multiple_sring', ['Hello World!', 'Good Morning!'
+ ], [b'Hello World!', b'Good Morning!']))
+ def test_add_multiple_bytes_feature(self, values, expected_values):
+ example_builder = tf_example_builder.TfExampleBuilder()
+ example_builder.add_bytes_feature('foo', values)
+ example = example_builder.example
+ self.assertProtoEquals(
+ tf.train.Example(
+ features=tf.train.Features(
+ feature={
+ 'foo':
+ tf.train.Feature(
+ bytes_list=tf.train.BytesList(
+ value=expected_values))
+ })), example)
+
+ @parameterized.named_parameters(
+ ('single_integer', 123, [123]),
+ ('multiple_integers', [123, 456, 789], [123, 456, 789]))
+ def test_add_ints_feature(self, value, expected_value):
+ example_builder = tf_example_builder.TfExampleBuilder()
+ example_builder.add_ints_feature('bar', value)
+ example = example_builder.example
+ self.assertProtoEquals(
+ tf.train.Example(
+ features=tf.train.Features(
+ feature={
+ 'bar':
+ tf.train.Feature(
+ int64_list=tf.train.Int64List(value=expected_value))
+ })), example)
+
+ @parameterized.named_parameters(
+ ('single_float', 3.14, [3.14]),
+ ('multiple_floats', [3.14, 1.57, 6.28], [3.14, 1.57, 6.28]))
+ def test_add_floats_feature(self, value, expected_value):
+ example_builder = tf_example_builder.TfExampleBuilder()
+ example_builder.add_floats_feature('baz', value)
+ example = example_builder.example
+ self.assertProtoEquals(
+ tf.train.Example(
+ features=tf.train.Features(
+ feature={
+ 'baz':
+ tf.train.Feature(
+ float_list=tf.train.FloatList(value=expected_value))
+ })), example)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/core/tf_example_feature_key.py b/modeling/official/core/tf_example_feature_key.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6cd748aedb9dc560aeb2b02ec4327bba66b0c78
--- /dev/null
+++ b/modeling/official/core/tf_example_feature_key.py
@@ -0,0 +1,62 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Data classes for tf.Example proto feature keys.
+
+Feature keys are grouped by feature types. Key names follow conventions in
+go/tf-example.
+"""
+import dataclasses
+import functools
+from typing import Optional
+
+# Disable init function to use the one defined in base class.
+dataclass = functools.partial(dataclasses.dataclass(init=False))
+
+
+@dataclass
+class TfExampleFeatureKeyBase:
+ """Base dataclass for defining tf.Example proto feature keys.
+
+ This class defines the logic of adding prefix to feature keys. Subclasses
+ will define feature keys for a specific feature type in data fields.
+
+ NOTE: Please follow subclass examples in this module to define feature keys
+ for a new feature type.
+ """
+
+ def __init__(self, prefix: Optional[str] = None):
+ """Instantiates the feature key class.
+
+ Adds a string prefix to all fields of a feature key instance if `prefix` is
+ not None nor empty.
+
+ Example usage:
+
+ >>> test_key = EncodedImageFeatureKey()
+ >>> test_key.encoded
+ image/encoded
+ >>> test_key = EncodedImageFeatureKey('prefix')
+ >>> test_key.encoded
+ prefix/image/encoded
+
+ Args:
+ prefix: A prefix string that will be added before the feature key string
+ with a trailing slash '/'.
+ """
+ if prefix:
+ for field in dataclasses.fields(self): # pytype: disable=wrong-arg-types # re-none
+ key_name = field.name
+ key_value = getattr(self, key_name)
+ setattr(self, key_name, f'{prefix}/{key_value}')
diff --git a/modeling/official/core/tf_example_feature_key_test.py b/modeling/official/core/tf_example_feature_key_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e40552c08aa68860958b97bb28b9e87c824e2624
--- /dev/null
+++ b/modeling/official/core/tf_example_feature_key_test.py
@@ -0,0 +1,49 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for tf_example_feature_key."""
+import dataclasses
+import inspect
+from absl.testing import absltest
+from absl.testing import parameterized
+
+from official.core import tf_example_feature_key
+
+
+@tf_example_feature_key.dataclass
+class TestFeatureKey(tf_example_feature_key.TfExampleFeatureKeyBase):
+ test: str = 'foo/bar'
+
+
+class TfExampleFeatureKeyTest(parameterized.TestCase):
+
+ def test_add_prefix_success(self):
+ test_key = TestFeatureKey('prefix')
+ self.assertEqual(test_key.test, 'prefix/foo/bar')
+
+ @parameterized.parameters(None, '')
+ def test_add_prefix_skip_success(self, prefix):
+ test_key = TestFeatureKey(prefix)
+ self.assertEqual(test_key.test, 'foo/bar')
+
+ def test_all_feature_key_classes_are_valid(self):
+ for _, obj in inspect.getmembers(tf_example_feature_key):
+ if inspect.isclass(obj):
+ self.assertTrue(dataclasses.is_dataclass(obj))
+ self.assertTrue(
+ issubclass(obj, tf_example_feature_key.TfExampleFeatureKeyBase))
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/modeling/official/core/train_lib.py b/modeling/official/core/train_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..37bc051e8bfac74b114008a03f84ffbd7dd8259f
--- /dev/null
+++ b/modeling/official/core/train_lib.py
@@ -0,0 +1,372 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TFM common training driver library."""
+# pytype: disable=attribute-error
+import os
+import tempfile
+from typing import Any, List, Mapping, Optional, Tuple
+
+# Import libraries
+
+from absl import logging
+import orbit
+import tensorflow as tf, tf_keras
+
+from official.core import actions
+from official.core import base_task
+from official.core import base_trainer
+from official.core import config_definitions
+from official.core import train_utils
+
+maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
+
+
+class OrbitExperimentRunner:
+ """Runs experiment with Orbit training loop.
+
+ The default experiment runner for model garden experiments. User can
+ customize the experiment pipeline by subclassing this class and replacing
+ components or functions.
+
+ For example, an experiment runner with customized checkpoint manager:
+
+ ```python
+ class MyExpRunnerWithExporter(OrbitExperimentRunner):
+ def _maybe_build_checkpoint_manager(sefl):
+ # Replaces the default CheckpointManger with a customized one.
+ return MyCheckpointManager(*args)
+
+ # In user code, instead of the orginal
+ # `OrbitExperimentRunner(..).run(mode)`, now user can do:
+ MyExpRunnerWithExporter(**needed_kwargs).run(mode)
+ ```
+
+ Similar override can be done to other components.
+ """
+
+ def __init__(
+ self,
+ distribution_strategy: tf.distribute.Strategy,
+ task: base_task.Task,
+ mode: str,
+ params: config_definitions.ExperimentConfig,
+ model_dir: str,
+ run_post_eval: bool = False,
+ save_summary: bool = True,
+ train_actions: Optional[List[orbit.Action]] = None,
+ eval_actions: Optional[List[orbit.Action]] = None,
+ trainer: Optional[base_trainer.Trainer] = None,
+ controller_cls=orbit.Controller,
+ summary_manager: Optional[orbit.utils.SummaryManager] = None,
+ eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
+ enable_async_checkpointing: bool = False,
+ ):
+ """Constructor.
+
+ Args:
+ distribution_strategy: A distribution strategy.
+ task: A Task instance.
+ mode: A 'str', specifying the mode. Can be 'train', 'eval',
+ 'train_and_eval' or 'continuous_eval'.
+ params: ExperimentConfig instance.
+ model_dir: A 'str', a path to store model checkpoints and summaries.
+ run_post_eval: Whether to run post eval once after training, metrics logs
+ are returned.
+ save_summary: Whether to save train and validation summary.
+ train_actions: Optional list of Orbit train actions.
+ eval_actions: Optional list of Orbit eval actions.
+ trainer: the base_trainer.Trainer instance. It should be created within
+ the strategy.scope().
+ controller_cls: The controller class to manage the train and eval process.
+ Must be a orbit.Controller subclass.
+ summary_manager: Instance of the summary manager to override default
+ summary manager.
+ eval_summary_manager: Instance of the eval summary manager to override
+ default eval summary manager.
+ enable_async_checkpointing: Optional boolean indicating whether to enable
+ async checkpoint saving.
+ """
+ self.strategy = distribution_strategy or tf.distribute.get_strategy()
+ self._params = params
+ self._model_dir = model_dir
+ self._mode = mode
+ self._run_post_eval = run_post_eval
+
+ self._trainer = trainer or self._build_trainer(
+ task,
+ train='train' in mode,
+ evaluate=('eval' in mode) or run_post_eval)
+ assert self.trainer is not None
+ self._checkpoint_manager = self._maybe_build_checkpoint_manager()
+ self._summary_manager = summary_manager
+ self._eval_summary_manager = eval_summary_manager
+ self._controller = self._build_controller(
+ trainer=self.trainer if 'train' in mode else None,
+ evaluator=self.trainer,
+ save_summary=save_summary,
+ train_actions=train_actions,
+ eval_actions=eval_actions,
+ controller_cls=controller_cls,
+ enable_async_checkpointing=enable_async_checkpointing)
+
+ @property
+ def params(self) -> config_definitions.ExperimentConfig:
+ """The whole experiment parameters object."""
+ return self._params
+
+ @property
+ def model_dir(self) -> str:
+ """Path to the model folder, which stores checkpoints, params, log, etc."""
+ return self._model_dir
+
+ @property
+ def trainer(self) -> base_trainer.Trainer:
+ """The underlying Orbit Trainer object."""
+ return self._trainer
+
+ @property
+ def checkpoint_manager(self) -> Optional[tf.train.CheckpointManager]:
+ """The CheckpointManager that stores the checkpoints in a train job."""
+ return self._checkpoint_manager
+
+ @property
+ def controller(self) -> orbit.Controller:
+ """The Orbit controller object."""
+ return self._controller
+
+ def _build_trainer(self, task: base_task.Task, train: bool,
+ evaluate: bool) -> base_trainer.Trainer:
+ """Create trainer."""
+ with self.strategy.scope():
+ trainer = train_utils.create_trainer(
+ self.params,
+ task,
+ train=train,
+ evaluate=evaluate,
+ checkpoint_exporter=self._build_best_checkpoint_exporter())
+ return trainer
+
+ def _build_best_checkpoint_exporter(self):
+ return maybe_create_best_ckpt_exporter(self.params, self.model_dir)
+
+ def _maybe_build_checkpoint_manager(
+ self) -> Optional[tf.train.CheckpointManager]:
+ """Maybe create a CheckpointManager."""
+ assert self.trainer is not None
+ if self.trainer.checkpoint:
+ if self.model_dir is None:
+ raise ValueError('model_dir must be specified, but got None')
+
+ if (not self.strategy) or self.strategy.extended.should_checkpoint:
+ ckpt_path = self.model_dir
+ max_to_keep = self.params.trainer.max_to_keep
+ else:
+ # In multi worker training we need every worker to save checkpoint,
+ # because variables can trigger synchronization on read and
+ # synchronization needs all workers to participate. To avoid workers
+ # overriding each other we save to a temporary directory on non-chief
+ # workers.
+ ckpt_path = tempfile.mkdtemp()
+ max_to_keep = 1
+
+ checkpoint_manager = tf.train.CheckpointManager(
+ self.trainer.checkpoint,
+ directory=ckpt_path,
+ max_to_keep=max_to_keep,
+ step_counter=self.trainer.global_step,
+ checkpoint_interval=self.params.trainer.checkpoint_interval,
+ init_fn=self.trainer.initialize)
+ else:
+ checkpoint_manager = None
+ return checkpoint_manager
+
+ def _build_controller(
+ self,
+ trainer,
+ evaluator,
+ save_summary: bool = True,
+ train_actions: Optional[List[orbit.Action]] = None,
+ eval_actions: Optional[List[orbit.Action]] = None,
+ controller_cls=orbit.Controller,
+ enable_async_checkpointing: bool = False,
+ ) -> orbit.Controller:
+ """Builds a Orbit controler."""
+ train_actions = [] if not train_actions else train_actions
+ if trainer:
+ checkpoint_manager = self.checkpoint_manager
+ assert checkpoint_manager, 'Checkpoint manager required but undefined.'
+ train_actions += actions.get_train_actions(
+ self.params,
+ trainer,
+ self.model_dir,
+ checkpoint_manager=checkpoint_manager,
+ )
+
+ eval_actions = [] if not eval_actions else eval_actions
+ if evaluator:
+ eval_actions += actions.get_eval_actions(self.params, evaluator,
+ self.model_dir)
+
+ if save_summary:
+ eval_summary_dir = os.path.join(
+ self.model_dir, self.params.trainer.validation_summary_subdir
+ )
+ else:
+ eval_summary_dir = None
+
+ controller = controller_cls(
+ strategy=self.strategy,
+ trainer=trainer,
+ evaluator=evaluator,
+ global_step=self.trainer.global_step,
+ steps_per_loop=self.params.trainer.steps_per_loop,
+ checkpoint_manager=self.checkpoint_manager,
+ enable_async_checkpointing=enable_async_checkpointing,
+ summary_dir=os.path.join(self.model_dir, 'train')
+ if (save_summary)
+ else None,
+ eval_summary_dir=eval_summary_dir,
+ summary_interval=self.params.trainer.summary_interval
+ if (save_summary)
+ else None,
+ train_actions=train_actions,
+ eval_actions=eval_actions,
+ summary_manager=self._summary_manager
+ if hasattr(self, '_summary_manager')
+ else None,
+ eval_summary_manager=self._eval_summary_manager
+ if hasattr(self, '_eval_summary_manager')
+ else None,
+ )
+ return controller
+
+ def run(self) -> Tuple[tf_keras.Model, Mapping[str, Any]]:
+ """Run experiments by mode.
+
+ Returns:
+ A 2-tuple of (model, eval_logs).
+ model: `tf_keras.Model` instance.
+ eval_logs: returns eval metrics logs when run_post_eval is set to True,
+ otherwise, returns {}.
+ """
+ mode = self._mode
+ params = self.params
+ logging.info('Starts to execute mode: %s', mode)
+ with self.strategy.scope():
+ if mode == 'train' or mode == 'train_and_post_eval':
+ self.controller.train(steps=params.trainer.train_steps)
+ elif mode == 'train_and_eval':
+ self.controller.train_and_evaluate(
+ train_steps=params.trainer.train_steps,
+ eval_steps=params.trainer.validation_steps,
+ eval_interval=params.trainer.validation_interval)
+ elif mode == 'eval':
+ self.controller.evaluate(steps=params.trainer.validation_steps)
+ elif mode == 'continuous_eval':
+
+ def timeout_fn():
+ if self.trainer.global_step.numpy() >= params.trainer.train_steps:
+ return True
+ return False
+
+ self.controller.evaluate_continuously(
+ steps=params.trainer.validation_steps,
+ timeout=params.trainer.continuous_eval_timeout,
+ timeout_fn=timeout_fn)
+ else:
+ raise NotImplementedError('The mode is not implemented: %s' % mode)
+
+ num_params = train_utils.try_count_params(self.trainer.model)
+ if num_params is not None:
+ logging.info('Number of trainable params in model: %f Millions.',
+ num_params / 10.**6)
+
+ flops = train_utils.try_count_flops(self.trainer.model)
+ if flops is not None:
+ logging.info('FLOPs (multi-adds) in model: %f Billions.',
+ flops / 10.**9 / 2)
+
+ if self._run_post_eval or mode == 'train_and_post_eval':
+ with self.strategy.scope():
+ return self.trainer.model, self.controller.evaluate(
+ steps=params.trainer.validation_steps)
+ else:
+ return self.trainer.model, {}
+
+
+def run_experiment(
+ distribution_strategy: tf.distribute.Strategy,
+ task: base_task.Task,
+ mode: str,
+ params: config_definitions.ExperimentConfig,
+ model_dir: str,
+ run_post_eval: bool = False,
+ save_summary: bool = True,
+ train_actions: Optional[List[orbit.Action]] = None,
+ eval_actions: Optional[List[orbit.Action]] = None,
+ trainer: Optional[base_trainer.Trainer] = None,
+ controller_cls=orbit.Controller,
+ summary_manager: Optional[orbit.utils.SummaryManager] = None,
+ eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
+ enable_async_checkpointing: bool = False,
+) -> Tuple[tf_keras.Model, Mapping[str, Any]]:
+ """Runs train/eval configured by the experiment params.
+
+ Args:
+ distribution_strategy: A distribution distribution_strategy.
+ task: A Task instance.
+ mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
+ or 'continuous_eval'.
+ params: ExperimentConfig instance.
+ model_dir: A 'str', a path to store model checkpoints and summaries.
+ run_post_eval: Whether to run post eval once after training, metrics logs
+ are returned.
+ save_summary: Whether to save train and validation summary.
+ train_actions: Optional list of Orbit train actions.
+ eval_actions: Optional list of Orbit eval actions.
+ trainer: the base_trainer.Trainer instance. It should be created within the
+ strategy.scope().
+ controller_cls: The controller class to manage the train and eval process.
+ Must be a orbit.Controller subclass.
+ summary_manager: Instance of the summary manager to override default summary
+ manager.
+ eval_summary_manager: Instance of the eval summary manager to override
+ default eval summary manager.
+ enable_async_checkpointing: Optional boolean indicating whether to enable
+ async checkpoint saving.
+
+ Returns:
+ A 2-tuple of (model, eval_logs).
+ model: `tf_keras.Model` instance.
+ eval_logs: returns eval metrics logs when run_post_eval is set to True,
+ otherwise, returns {}.
+ """
+ runner = OrbitExperimentRunner(
+ distribution_strategy=distribution_strategy,
+ task=task,
+ mode=mode,
+ params=params,
+ model_dir=model_dir,
+ run_post_eval=run_post_eval,
+ save_summary=save_summary,
+ train_actions=train_actions,
+ eval_actions=eval_actions,
+ trainer=trainer,
+ controller_cls=controller_cls,
+ summary_manager=summary_manager,
+ eval_summary_manager=eval_summary_manager,
+ enable_async_checkpointing=enable_async_checkpointing,
+ )
+ return runner.run()
diff --git a/modeling/official/core/train_lib_test.py b/modeling/official/core/train_lib_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b77fde2ca6962651a2022ef54c48e99a1d426aad
--- /dev/null
+++ b/modeling/official/core/train_lib_test.py
@@ -0,0 +1,280 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for train_ctl_lib."""
+import json
+import os
+
+from absl import flags
+from absl.testing import flagsaver
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.common import flags as tfm_flags
+# pylint: disable=unused-import
+from official.common import registry_imports
+# pylint: enable=unused-import
+from official.core import task_factory
+from official.core import train_lib
+from official.core import train_utils
+from official.utils.testing import mock_task
+
+FLAGS = flags.FLAGS
+
+tfm_flags.define_flags()
+
+
+class TrainTest(tf.test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ super(TrainTest, self).setUp()
+ self._test_config = {
+ 'trainer': {
+ 'checkpoint_interval': 10,
+ 'steps_per_loop': 10,
+ 'summary_interval': 10,
+ 'train_steps': 10,
+ 'validation_steps': 5,
+ 'validation_interval': 10,
+ 'continuous_eval_timeout': 1,
+ 'validation_summary_subdir': 'validation',
+ 'optimizer_config': {
+ 'optimizer': {
+ 'type': 'sgd',
+ },
+ 'learning_rate': {
+ 'type': 'constant'
+ }
+ }
+ },
+ }
+
+ @combinations.generate(
+ combinations.combine(
+ distribution_strategy=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.cloud_tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ ],
+ flag_mode=['train', 'eval', 'train_and_eval'],
+ run_post_eval=[True, False]))
+ def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
+ model_dir = self.get_temp_dir()
+ flags_dict = dict(
+ experiment='mock',
+ mode=flag_mode,
+ model_dir=model_dir,
+ params_override=json.dumps(self._test_config))
+ with flagsaver.flagsaver(**flags_dict):
+ params = train_utils.parse_configuration(flags.FLAGS)
+ train_utils.serialize_config(params, model_dir)
+ with distribution_strategy.scope():
+ task = task_factory.get_task(params.task, logging_dir=model_dir)
+
+ _, logs = train_lib.run_experiment(
+ distribution_strategy=distribution_strategy,
+ task=task,
+ mode=flag_mode,
+ params=params,
+ model_dir=model_dir,
+ run_post_eval=run_post_eval)
+
+ if 'eval' in flag_mode:
+ self.assertTrue(
+ tf.io.gfile.exists(
+ os.path.join(model_dir,
+ params.trainer.validation_summary_subdir)))
+ if run_post_eval:
+ self.assertNotEmpty(logs)
+ else:
+ self.assertEmpty(logs)
+ self.assertNotEmpty(
+ tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
+ if flag_mode == 'eval':
+ return
+ self.assertNotEmpty(
+ tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
+ # Tests continuous evaluation.
+ _, logs = train_lib.run_experiment(
+ distribution_strategy=distribution_strategy,
+ task=task,
+ mode='continuous_eval',
+ params=params,
+ model_dir=model_dir,
+ run_post_eval=run_post_eval)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution_strategy=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.cloud_tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ ],
+ flag_mode=['train', 'eval', 'train_and_eval'],
+ run_post_eval=[True, False]))
+ def test_end_to_end_class(self, distribution_strategy, flag_mode,
+ run_post_eval):
+ model_dir = self.get_temp_dir()
+ flags_dict = dict(
+ experiment='mock',
+ mode=flag_mode,
+ model_dir=model_dir,
+ params_override=json.dumps(self._test_config))
+ with flagsaver.flagsaver(**flags_dict):
+ params = train_utils.parse_configuration(flags.FLAGS)
+ train_utils.serialize_config(params, model_dir)
+ with distribution_strategy.scope():
+ task = task_factory.get_task(params.task, logging_dir=model_dir)
+
+ _, logs = train_lib.OrbitExperimentRunner(
+ distribution_strategy=distribution_strategy,
+ task=task,
+ mode=flag_mode,
+ params=params,
+ model_dir=model_dir,
+ run_post_eval=run_post_eval).run()
+
+ if 'eval' in flag_mode:
+ self.assertTrue(
+ tf.io.gfile.exists(
+ os.path.join(model_dir,
+ params.trainer.validation_summary_subdir)))
+ if run_post_eval:
+ self.assertNotEmpty(logs)
+ else:
+ self.assertEmpty(logs)
+ self.assertNotEmpty(
+ tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
+ if flag_mode == 'eval':
+ return
+ self.assertNotEmpty(
+ tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
+ # Tests continuous evaluation.
+ _, logs = train_lib.OrbitExperimentRunner(
+ distribution_strategy=distribution_strategy,
+ task=task,
+ mode='continuous_eval',
+ params=params,
+ model_dir=model_dir,
+ run_post_eval=run_post_eval).run()
+
+ @combinations.generate(
+ combinations.combine(
+ distribution_strategy=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.cloud_tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ ],
+ flag_mode=['train', 'train_and_eval'],
+ ))
+ def test_recovery_nan_error(self, distribution_strategy, flag_mode):
+ model_dir = self.get_temp_dir()
+ flags_dict = dict(
+ experiment='mock',
+ mode=flag_mode,
+ model_dir=model_dir,
+ params_override=json.dumps(self._test_config))
+ with flagsaver.flagsaver(**flags_dict):
+ params = train_utils.parse_configuration(flags.FLAGS)
+ train_utils.serialize_config(params, model_dir)
+ with distribution_strategy.scope():
+ # task = task_factory.get_task(params.task, logging_dir=model_dir)
+ task = mock_task.MockTask(params.task, logging_dir=model_dir)
+
+ # Set the loss to NaN to trigger RunTimeError.
+ def build_losses(labels, model_outputs, aux_losses=None):
+ del labels, model_outputs
+ return tf.constant([np.nan], tf.float32) + aux_losses
+
+ task.build_losses = build_losses
+
+ with self.assertRaises(RuntimeError):
+ train_lib.OrbitExperimentRunner(
+ distribution_strategy=distribution_strategy,
+ task=task,
+ mode=flag_mode,
+ params=params,
+ model_dir=model_dir).run()
+
+ @combinations.generate(
+ combinations.combine(
+ distribution_strategy=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.cloud_tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ ],
+ flag_mode=['train'],
+ ))
+ def test_recovery(self, distribution_strategy, flag_mode):
+ loss_threshold = 1.0
+ model_dir = self.get_temp_dir()
+ flags_dict = dict(
+ experiment='mock',
+ mode=flag_mode,
+ model_dir=model_dir,
+ params_override=json.dumps(self._test_config))
+ with flagsaver.flagsaver(**flags_dict):
+ params = train_utils.parse_configuration(flags.FLAGS)
+ params.trainer.loss_upper_bound = loss_threshold
+ params.trainer.recovery_max_trials = 1
+ train_utils.serialize_config(params, model_dir)
+ with distribution_strategy.scope():
+ task = task_factory.get_task(params.task, logging_dir=model_dir)
+
+ # Saves a checkpoint for reference.
+ model = task.build_model()
+ checkpoint = tf.train.Checkpoint(model=model)
+ checkpoint_manager = tf.train.CheckpointManager(
+ checkpoint, self.get_temp_dir(), max_to_keep=2)
+ checkpoint_manager.save()
+ before_weights = model.get_weights()
+
+ def build_losses(labels, model_outputs, aux_losses=None):
+ del labels, model_outputs
+ return tf.constant([loss_threshold], tf.float32) + aux_losses
+
+ task.build_losses = build_losses
+
+ model, _ = train_lib.OrbitExperimentRunner(
+ distribution_strategy=distribution_strategy,
+ task=task,
+ mode=flag_mode,
+ params=params,
+ model_dir=model_dir).run()
+ after_weights = model.get_weights()
+ for left, right in zip(before_weights, after_weights):
+ self.assertAllEqual(left, right)
+
+ def test_parse_configuration(self):
+ model_dir = self.get_temp_dir()
+ flags_dict = dict(
+ experiment='mock',
+ mode='train',
+ model_dir=model_dir,
+ params_override=json.dumps(self._test_config))
+ with flagsaver.flagsaver(**flags_dict):
+ params = train_utils.parse_configuration(flags.FLAGS, lock_return=True)
+ with self.assertRaises(ValueError):
+ params.override({'task': {'init_checkpoint': 'Foo'}})
+
+ params = train_utils.parse_configuration(flags.FLAGS, lock_return=False)
+ params.override({'task': {'init_checkpoint': 'Bar'}})
+ self.assertEqual(params.task.init_checkpoint, 'Bar')
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/core/train_utils.py b/modeling/official/core/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..984ac319d504b067d2de0af138a012d8782af467
--- /dev/null
+++ b/modeling/official/core/train_utils.py
@@ -0,0 +1,610 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Training utils."""
+
+import dataclasses
+import inspect
+import json
+import os
+import pprint
+from typing import Any, Callable, Dict, List, Optional, Union
+
+from absl import logging
+import gin
+import numpy as np
+import orbit
+import tensorflow as tf, tf_keras
+
+# pylint: disable=g-direct-tensorflow-import
+from tensorflow.python.framework import ops
+from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
+# pylint: enable=g-direct-tensorflow-import
+from official.core import base_task
+from official.core import base_trainer
+from official.core import config_definitions
+from official.core import exp_factory
+from official.modeling import hyperparams
+
+
+BEST_CHECKPOINT_NAME = 'best_ckpt'
+
+
+def get_leaf_nested_dict(d: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
+ """Get leaf from a dictionary with arbitrary depth with a list of keys.
+
+ Args:
+ d: The dictionary to extract value from.
+ keys: The list of keys to extract values recursively.
+
+ Returns:
+ The value of the leaf.
+
+ Raises:
+ KeyError: If the value of keys extracted is a dictionary.
+ """
+ leaf = d
+ for k in keys:
+ if not isinstance(leaf, dict) or k not in leaf:
+ raise KeyError(
+ 'Path not exist while traversing the dictionary: d with keys'
+ ': %s.' % keys)
+ leaf = leaf[k]
+
+ if isinstance(leaf, dict):
+ raise KeyError('The value extracted with keys: %s is not a leaf of the '
+ 'dictionary: %s.' % (keys, d))
+ return leaf
+
+
+def cast_leaf_nested_dict(d: Dict[str, Any],
+ cast_fn: Callable[[Any], Any]) -> Dict[str, Any]:
+ """Cast the leaves of a dictionary with arbitrary depth in place.
+
+ Args:
+ d: The dictionary to extract value from.
+ cast_fn: The casting function.
+
+ Returns:
+ A dictionray with the same structure as d.
+ """
+ for key, value in d.items():
+ if isinstance(value, dict):
+ d[key] = cast_leaf_nested_dict(value, cast_fn)
+ else:
+ d[key] = cast_fn(value)
+ return d
+
+
+def _filter_leaf_nested_dict(
+ d: Dict[str, Any], predicate: Callable[[Any], bool]
+) -> Dict[str, Any]:
+ """Filters the leaves of a dictionary with arbitrary depth in place.
+
+ Args:
+ d: The dictionary to extract value from.
+ predicate: A function that will be called on every leave item. When the
+ function returns True the leave will be kept. Otherwise the leave will be
+ dropped.
+
+ Returns:
+ A new dictionray with filtered result.
+ """
+ result = {}
+ for key, value in d.items():
+ if isinstance(value, dict):
+ result[key] = _filter_leaf_nested_dict(value, predicate)
+ elif predicate(value):
+ result[key] = value
+ return result
+
+
+def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
+ data_dir: str) -> Any:
+ """Maybe create a BestCheckpointExporter object, according to the config."""
+ export_subdir = params.trainer.best_checkpoint_export_subdir
+ metric_name = params.trainer.best_checkpoint_eval_metric
+ metric_comp = params.trainer.best_checkpoint_metric_comp
+ if data_dir and export_subdir and metric_name:
+ best_ckpt_dir = os.path.join(data_dir, export_subdir)
+ best_ckpt_exporter = BestCheckpointExporter(best_ckpt_dir, metric_name,
+ metric_comp)
+ logging.info(
+ 'Created the best checkpoint exporter. '
+ 'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
+ export_subdir, metric_name)
+ else:
+ best_ckpt_exporter = None
+
+ return best_ckpt_exporter
+
+
+class BestCheckpointExporter:
+ """Keeps track of the best result, and saves its checkpoint.
+
+ Orbit will support an API for checkpoint exporter. This class will be used
+ together with orbit once this functionality is ready.
+ """
+
+ def __init__(self, export_dir: str, metric_name: str, metric_comp: str):
+ """Initialization.
+
+ Args:
+ export_dir: The directory that will contain exported checkpoints.
+ metric_name: Indicates which metric to look at, when determining which
+ result is better. If eval_logs being passed to maybe_export_checkpoint
+ is a nested dictionary, use `|` as a seperator for different layers.
+ metric_comp: Indicates how to compare results. Either `lower` or `higher`.
+ """
+ self._export_dir = export_dir
+ self._metric_name = metric_name.split('|')
+ self._metric_comp = metric_comp
+ if self._metric_comp not in ('lower', 'higher'):
+ raise ValueError('best checkpoint metric comp must be one of '
+ 'higher, lower. Got: {}'.format(self._metric_comp))
+ tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
+ self._best_ckpt_logs = self._maybe_load_best_eval_metric()
+ self._checkpoint_manager = None
+
+ def _get_checkpoint_manager(self, checkpoint):
+ """Gets an existing checkpoint manager or creates a new one."""
+ if self._checkpoint_manager is None or (self._checkpoint_manager.checkpoint
+ != checkpoint):
+ logging.info('Creates a new checkpoint manager.')
+ self._checkpoint_manager = tf.train.CheckpointManager(
+ checkpoint,
+ directory=self._export_dir,
+ max_to_keep=1,
+ checkpoint_name=BEST_CHECKPOINT_NAME)
+
+ return self._checkpoint_manager
+
+ def maybe_export_checkpoint(
+ self, checkpoint, eval_logs, global_step, write_logs=True) -> bool:
+ """Compare eval_logs with past eval_logs and export checkpoint if better."""
+ logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
+ eval_logs, global_step)
+ if self._best_ckpt_logs is None or self._new_metric_is_better(
+ self._best_ckpt_logs, eval_logs):
+ self._best_ckpt_logs = eval_logs
+ if write_logs:
+ self.export_best_eval_metric(self._best_ckpt_logs, global_step)
+ self._get_checkpoint_manager(checkpoint).save()
+ return True
+ return False
+
+ def _maybe_load_best_eval_metric(self):
+ if not tf.io.gfile.exists(self.best_ckpt_logs_path):
+ return None
+ with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'r') as reader:
+ return json.loads(reader.read())
+
+ def _new_metric_is_better(self, old_logs, new_logs):
+ """Check if the metric in new_logs is better than the metric in old_logs."""
+ old_value = float(
+ orbit.utils.get_value(
+ get_leaf_nested_dict(old_logs, self._metric_name)))
+ new_value = float(
+ orbit.utils.get_value(
+ get_leaf_nested_dict(new_logs, self._metric_name)))
+
+ logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
+ old_value, new_value)
+ if self._metric_comp == 'higher':
+ if new_value > old_value:
+ logging.info('[BestCheckpointExporter] '
+ 'the new number is better since it is higher.')
+ return True
+ else: # self._metric_comp == 'lower':
+ if new_value < old_value:
+ logging.info('[BestCheckpointExporter] '
+ 'the new number is better since it is lower.')
+ return True
+ return False
+
+ def export_best_eval_metric(self, eval_logs, global_step):
+ """Export evaluation results of the best checkpoint into a json file."""
+ # eval_log_ext may contains non-scalar tensors, such as image data when
+ # `allow_image_summary` is True. Here we only keep scalar tensors.
+ eval_logs_ext = _filter_leaf_nested_dict(
+ eval_logs, lambda x: tf.rank(x) <= 1
+ )
+ eval_logs_ext['best_ckpt_global_step'] = global_step
+ eval_logs_ext = cast_leaf_nested_dict(
+ eval_logs_ext, lambda x: float(orbit.utils.get_value(x)))
+ # Saving json file is very fast.
+ with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
+ writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
+
+ @property
+ def best_ckpt_logs(self):
+ return self._best_ckpt_logs
+
+ @property
+ def best_ckpt_logs_path(self):
+ return os.path.join(self._export_dir, 'info.json')
+
+ @property
+ def best_ckpt_path(self):
+ """Returns the best ckpt path or None if there is no ckpt yet."""
+ return tf.train.latest_checkpoint(self._export_dir)
+
+
+def create_optimizer(task: base_task.Task,
+ params: config_definitions.ExperimentConfig
+ ) -> tf_keras.optimizers.Optimizer:
+ """A create optimizer util to be backward compatability with new args."""
+ if 'dp_config' in inspect.signature(task.create_optimizer).parameters:
+ dp_config = None
+ if hasattr(params.task, 'differential_privacy_config'):
+ dp_config = params.task.differential_privacy_config
+ optimizer = task.create_optimizer(
+ params.trainer.optimizer_config, params.runtime,
+ dp_config=dp_config)
+ else:
+ if hasattr(params.task, 'differential_privacy_config'
+ ) and params.task.differential_privacy_config is not None:
+ raise ValueError('Differential privacy config is specified but '
+ 'task.create_optimizer api does not accept it.')
+ optimizer = task.create_optimizer(
+ params.trainer.optimizer_config,
+ params.runtime)
+ return optimizer
+
+
+@gin.configurable
+def create_trainer(params: config_definitions.ExperimentConfig,
+ task: base_task.Task,
+ train: bool,
+ evaluate: bool,
+ checkpoint_exporter: Optional[BestCheckpointExporter] = None,
+ trainer_cls=base_trainer.Trainer) -> base_trainer.Trainer:
+ """Create trainer."""
+ logging.info('Running default trainer.')
+ model = task.build_model()
+ optimizer = create_optimizer(task, params)
+ return trainer_cls(
+ params,
+ task,
+ model=model,
+ optimizer=optimizer,
+ train=train,
+ evaluate=evaluate,
+ checkpoint_exporter=checkpoint_exporter)
+
+
+@dataclasses.dataclass
+class ParseConfigOptions:
+ """Use this dataclass instead of FLAGS to customize parse_configuration()."""
+ experiment: str
+ config_file: List[str]
+ tpu: str = ''
+ tf_data_service: str = ''
+ params_override: str = ''
+
+ def __contains__(self, name):
+ return name in dataclasses.asdict(self)
+
+
+class ExperimentParser:
+ """Constructs the Experiment config from Flags or equivalent object.
+
+ Most of the cases, users only need to call the `parse()` function:
+ ```
+ builder = ExperimentParser(FLAGS)
+ params = builder.parse()
+ ```
+
+ The advanced users can modify the flow by calling the parse_*() functions
+ separately.
+ """
+
+ def __init__(self, flags_obj):
+ self._flags_obj = flags_obj
+
+ def parse(self):
+ """Overrall process of constructing Experiment config."""
+ params = self.base_experiment()
+ params = self.parse_config_file(params)
+ params = self.parse_runtime(params)
+ params = self.parse_data_service(params)
+ params = self.parse_params_override(params)
+ return params
+
+ def base_experiment(self):
+ """Get the base experiment config from --experiment field."""
+ if self._flags_obj.experiment is None:
+ raise ValueError('The flag --experiment must be specified.')
+ return exp_factory.get_exp_config(self._flags_obj.experiment)
+
+ def parse_config_file(self, params):
+ """Override the configs of params from the config_file."""
+ for config_file in self._flags_obj.config_file or []:
+ params = hyperparams.override_params_dict(
+ params, config_file, is_strict=True)
+ return params
+
+ def parse_runtime(self, params):
+ """Override the runtime configs of params from flags."""
+ # Override the TPU address and tf.data service address.
+ params.override({
+ 'runtime': {
+ 'tpu': self._flags_obj.tpu,
+ },
+ })
+ return params
+
+ def parse_data_service(self, params):
+ """Override the data service configs of params from flags."""
+ if ('tf_data_service' in self._flags_obj and
+ self._flags_obj.tf_data_service and
+ isinstance(params.task, config_definitions.TaskConfig)):
+ params.override({
+ 'task': {
+ 'train_data': {
+ 'tf_data_service_address': self._flags_obj.tf_data_service,
+ },
+ 'validation_data': {
+ 'tf_data_service_address': self._flags_obj.tf_data_service,
+ }
+ }
+ })
+ return params
+
+ def parse_params_override(self, params):
+ # Get the second level of override from `--params_override`.
+ # `--params_override` is typically used as a further override over the
+ # template. For example, one may define a particular template for training
+ # ResNet50 on ImageNet in a config file and pass it via `--config_file`,
+ # then define different learning rates and pass it via `--params_override`.
+ if self._flags_obj.params_override:
+ params = hyperparams.override_params_dict(
+ params, self._flags_obj.params_override, is_strict=True)
+ return params
+
+
+def parse_configuration(flags_obj, lock_return=True, print_return=True):
+ """Parses ExperimentConfig from flags."""
+
+ params = ExperimentParser(flags_obj).parse()
+
+ params.validate()
+ if lock_return:
+ params.lock()
+
+ if print_return:
+ pp = pprint.PrettyPrinter()
+ logging.info('Final experiment parameters:\n%s',
+ pp.pformat(params.as_dict()))
+
+ return params
+
+
+def serialize_config(params: config_definitions.ExperimentConfig,
+ model_dir: str):
+ """Serializes and saves the experiment config."""
+ if model_dir is None:
+ raise ValueError('model_dir must be specified, but got None')
+ params_save_path = os.path.join(model_dir, 'params.yaml')
+ logging.info('Saving experiment configuration to %s', params_save_path)
+ tf.io.gfile.makedirs(model_dir)
+ hyperparams.save_params_dict_to_yaml(params, params_save_path)
+
+
+def save_gin_config(filename_suffix: str, model_dir: str):
+ """Serializes and saves the experiment config."""
+ gin_save_path = os.path.join(
+ model_dir, 'operative_config.{}.gin'.format(filename_suffix))
+ logging.info('Saving gin configurations to %s', gin_save_path)
+ tf.io.gfile.makedirs(model_dir)
+ with tf.io.gfile.GFile(gin_save_path, 'w') as f:
+ f.write(gin.operative_config_str())
+
+
+def read_global_step_from_checkpoint(ckpt_file_path):
+ """Read global step from checkpoint, or get global step from its filename."""
+ global_step = tf.Variable(-1, dtype=tf.int64)
+ ckpt = tf.train.Checkpoint(global_step=global_step)
+ try:
+ ckpt.restore(ckpt_file_path).expect_partial()
+ global_step_maybe_restored = global_step.numpy()
+ except tf.errors.InvalidArgumentError:
+ global_step_maybe_restored = -1
+
+ if global_step_maybe_restored == -1:
+ raise ValueError('global_step not found in checkpoint {}. '
+ 'If you want to run finetune eval jobs, you need to '
+ 'make sure that your pretrain model writes '
+ 'global_step in its checkpoints.'.format(ckpt_file_path))
+ global_step_restored = global_step.numpy()
+ logging.info('get global_step %d from checkpoint %s', global_step_restored,
+ ckpt_file_path)
+ return global_step_restored
+
+
+def write_json_summary(log_dir, global_step, eval_metrics):
+ """Dump evaluation metrics to json file."""
+ serializable_dict = {}
+ for name, value in eval_metrics.items():
+ if hasattr(value, 'numpy'):
+ serializable_dict[name] = str(value.numpy())
+ else:
+ serializable_dict[name] = str(value)
+ output_json = os.path.join(log_dir, 'metrics-{}.json'.format(global_step))
+ logging.info('Evaluation results at pretrain step %d: %s', global_step,
+ serializable_dict)
+ with tf.io.gfile.GFile(output_json, 'w') as writer:
+ writer.write(json.dumps(serializable_dict, indent=4) + '\n')
+
+
+def write_summary(summary_writer, global_step, eval_metrics):
+ """Write evaluation metrics to TF summary."""
+ numeric_dict = {}
+ for name, value in eval_metrics.items():
+ numeric_dict[name] = float(orbit.utils.get_value(value))
+ with summary_writer.as_default():
+ for name, value in numeric_dict.items():
+ tf.summary.scalar(name, value, step=global_step)
+ summary_writer.flush()
+
+
+def remove_ckpts(model_dir):
+ """Remove model checkpoints, so we can restart."""
+ ckpts = os.path.join(model_dir, 'ckpt-*')
+ logging.info('removing checkpoint files %s', ckpts)
+ for file_to_remove in tf.io.gfile.glob(ckpts):
+ tf.io.gfile.rmtree(file_to_remove)
+
+ file_to_remove = os.path.join(model_dir, 'checkpoint')
+ if tf.io.gfile.exists(file_to_remove):
+ tf.io.gfile.remove(file_to_remove)
+
+
+def write_model_params(model: Union[tf.Module, tf_keras.Model],
+ output_path: str) -> None:
+ """Writes the model parameters and shapes to a file.
+
+ Args:
+ model: A model instance.
+ output_path: Output file path.
+ """
+ with tf.io.gfile.GFile(output_path, 'w') as f:
+ total_params = 0
+ for var in model.variables:
+ shape = tf.shape(var)
+ total_params += tf.math.reduce_prod(shape).numpy()
+ f.write(f'{var.name} {shape.numpy().tolist()}\n')
+ f.write(f'\nTotal params: {total_params}\n')
+
+
+def try_count_params(
+ model: Union[tf.Module, tf_keras.Model],
+ trainable_only: bool = False):
+ """Count the number of parameters if model is possible.
+
+ Args:
+ model: Try to count the number of params in this model.
+ trainable_only: Whether to calculate trainable params only. This flag is
+ not used when the model has `count_params` attribute.
+
+ Returns:
+ The number of parameters or None.
+ """
+ if hasattr(model, 'count_params'):
+ try:
+ return model.count_params()
+ except ValueError:
+ logging.info('Number of trainable params unknown, because the build() '
+ 'methods in keras layers were not called. This is probably '
+ 'because the model was not feed any input, e.g., the max '
+ 'train step already reached before this run.')
+ return None
+ else:
+ total_params = 0
+ variables = model.trainable_variables if trainable_only else model.variables
+ for var in variables:
+ shape = tf.shape(var)
+ total_params += tf.math.reduce_prod(shape).numpy()
+ return total_params
+
+
+def try_count_flops(model: Union[tf.Module, tf_keras.Model],
+ inputs_kwargs: Optional[Dict[str, Any]] = None,
+ output_path: Optional[str] = None):
+ """Counts and returns model FLOPs.
+
+ Args:
+ model: A model instance.
+ inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
+ shape specifications to getting corresponding concrete function.
+ output_path: A file path to write the profiling results to.
+
+ Returns:
+ The model's FLOPs.
+ """
+ if hasattr(model, 'inputs'):
+ try:
+ # Get input shape and set batch size to 1.
+ if model.inputs:
+ inputs = [
+ tf.TensorSpec([1] + input.shape[1:], input.dtype)
+ for input in model.inputs
+ ]
+ concrete_func = tf.function(model).get_concrete_function(inputs)
+ # If model.inputs is invalid, try to use the input to get concrete
+ # function for model.call (subclass model).
+ else:
+ concrete_func = tf.function(model.call).get_concrete_function(
+ **inputs_kwargs)
+ frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)
+
+ # Calculate FLOPs.
+ run_meta = tf.compat.v1.RunMetadata()
+ opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
+ if output_path is not None:
+ opts['output'] = f'file:outfile={output_path}'
+ else:
+ opts['output'] = 'none'
+ flops = tf.compat.v1.profiler.profile(
+ graph=frozen_func.graph, run_meta=run_meta, options=opts)
+ return flops.total_float_ops
+ except Exception as e: # pylint: disable=broad-except
+ logging.info(
+ 'Failed to count model FLOPs with error %s, because the build() '
+ 'methods in keras layers were not called. This is probably because '
+ 'the model was not feed any input, e.g., the max train step already '
+ 'reached before this run.', e)
+ return None
+ return None
+
+
+@ops.RegisterStatistics('Einsum', 'flops')
+def _einsum_flops(graph, node):
+ """Calculates the compute resources needed for Einsum."""
+ assert len(node.input) == 2
+ x_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
+ graph, node.input[0])
+ y_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
+ graph, node.input[1])
+ x_shape.assert_is_fully_defined()
+ y_shape.assert_is_fully_defined()
+ x_shape = x_shape.as_list()
+ y_shape = y_shape.as_list()
+ equation = str(node.attr['equation'])
+ equation = (
+ equation.replace('s:', '')
+ .replace('"', '')
+ .replace(' ', '')
+ .replace('\n', '')
+ )
+ x_str = equation.split(',')[0]
+ y_r_str = equation.split(',')[1]
+ y_str = y_r_str.split('->')[0]
+ r_str = y_r_str.split('->')[1]
+ shape_dic = {}
+ contracted = set()
+ for indice in x_str + y_str:
+ if indice in x_str:
+ indice_dim = x_shape[x_str.find(indice)]
+ elif indice in y_str:
+ indice_dim = y_shape[y_str.find(indice)]
+ else:
+ raise ValueError('indice {} not found in inputs'.format(indice))
+ shape_dic[indice] = indice_dim
+ if indice not in r_str:
+ contracted.add(indice)
+ madds = np.prod([shape_dic[indice] for indice in r_str]) * (
+ np.prod([shape_dic[indice] for indice in contracted]))
+ flops = 2 * madds
+ return ops.OpStats('flops', flops)
diff --git a/modeling/official/core/train_utils_test.py b/modeling/official/core/train_utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..807e95cbf245bfa644aad189a1d8523e51c960e0
--- /dev/null
+++ b/modeling/official/core/train_utils_test.py
@@ -0,0 +1,215 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for official.core.train_utils."""
+import json
+import os
+import pprint
+
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from official.core import exp_factory
+from official.core import test_utils
+from official.core import train_utils
+from official.modeling import hyperparams
+
+
+@exp_factory.register_config_factory('foo')
+def foo():
+ """Multitask experiment for test."""
+ experiment_config = hyperparams.Config(
+ default_params={
+ 'runtime': {
+ 'tpu': 'fake',
+ },
+ 'task': {
+ 'model': {
+ 'model_id': 'bar',
+ },
+ },
+ 'trainer': {
+ 'train_steps': -1,
+ 'validation_steps': -1,
+ },
+ })
+ return experiment_config
+
+
+class TrainUtilsTest(tf.test.TestCase):
+
+ def test_get_leaf_nested_dict(self):
+ d = {'a': {'i': {'x': 5}}}
+ self.assertEqual(train_utils.get_leaf_nested_dict(d, ['a', 'i', 'x']), 5)
+
+ def test_get_leaf_nested_dict_not_leaf(self):
+ with self.assertRaisesRegex(KeyError, 'The value extracted with keys.*'):
+ d = {'a': {'i': {'x': 5}}}
+ train_utils.get_leaf_nested_dict(d, ['a', 'i'])
+
+ def test_get_leaf_nested_dict_path_not_exist_missing_key(self):
+ with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
+ d = {'a': {'i': {'x': 5}}}
+ train_utils.get_leaf_nested_dict(d, ['a', 'i', 'y'])
+
+ def test_get_leaf_nested_dict_path_not_exist_out_of_range(self):
+ with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
+ d = {'a': {'i': {'x': 5}}}
+ train_utils.get_leaf_nested_dict(d, ['a', 'i', 'z'])
+
+ def test_get_leaf_nested_dict_path_not_exist_meets_leaf(self):
+ with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
+ d = {'a': {'i': 5}}
+ train_utils.get_leaf_nested_dict(d, ['a', 'i', 'z'])
+
+ def test_cast_leaf_nested_dict(self):
+ d = {'a': {'i': {'x': '123'}}, 'b': 456.5}
+ d = train_utils.cast_leaf_nested_dict(d, int)
+ self.assertEqual(d['a']['i']['x'], 123)
+ self.assertEqual(d['b'], 456)
+
+ def test_write_model_params_keras_model(self):
+ inputs = np.zeros([2, 3])
+ model = test_utils.FakeKerasModel()
+ model(inputs) # Must do forward pass to build the model.
+
+ filepath = os.path.join(self.create_tempdir(), 'model_params.txt')
+ train_utils.write_model_params(model, filepath)
+ actual = tf.io.gfile.GFile(filepath, 'r').read().splitlines()
+
+ expected = [
+ 'fake_keras_model/dense/kernel:0 [3, 4]',
+ 'fake_keras_model/dense/bias:0 [4]',
+ 'fake_keras_model/dense_1/kernel:0 [4, 4]',
+ 'fake_keras_model/dense_1/bias:0 [4]',
+ '',
+ 'Total params: 36',
+ ]
+ self.assertEqual(actual, expected)
+
+ def test_write_model_params_module(self):
+ inputs = np.zeros([2, 3], dtype=np.float32)
+ model = test_utils.FakeModule(3, name='fake_module')
+ model(inputs) # Must do forward pass to build the model.
+
+ filepath = os.path.join(self.create_tempdir(), 'model_params.txt')
+ train_utils.write_model_params(model, filepath)
+ actual = tf.io.gfile.GFile(filepath, 'r').read().splitlines()
+
+ expected = [
+ 'fake_module/dense/b:0 [4]',
+ 'fake_module/dense/w:0 [3, 4]',
+ 'fake_module/dense_1/b:0 [4]',
+ 'fake_module/dense_1/w:0 [4, 4]',
+ '',
+ 'Total params: 36',
+ ]
+ self.assertEqual(actual, expected)
+
+ def test_construct_experiment_from_flags(self):
+ options = train_utils.ParseConfigOptions(
+ experiment='foo',
+ config_file=[],
+ tpu='bar',
+ tf_data_service='',
+ params_override='task.model.model_id=new,'
+ 'trainer.train_steps=10,'
+ 'trainer.validation_steps=11')
+ builder = train_utils.ExperimentParser(options)
+ params_from_obj = builder.parse()
+ params_from_func = train_utils.parse_configuration(options)
+ pp = pprint.PrettyPrinter()
+ self.assertEqual(
+ pp.pformat(params_from_obj.as_dict()),
+ pp.pformat(params_from_func.as_dict()))
+ self.assertEqual(params_from_obj.runtime.tpu, 'bar')
+ self.assertEqual(params_from_obj.task.model.model_id, 'new')
+ self.assertEqual(params_from_obj.trainer.train_steps, 10)
+ self.assertEqual(params_from_obj.trainer.validation_steps, 11)
+
+
+class BestCheckpointExporterTest(tf.test.TestCase):
+
+ def test_maybe_export(self):
+ model_dir = self.create_tempdir().full_path
+ best_ckpt_path = os.path.join(model_dir, 'best_ckpt-1')
+ metric_name = 'test_metric|metric_1'
+ exporter = train_utils.BestCheckpointExporter(
+ model_dir, metric_name, 'higher')
+ v = tf.Variable(1.0)
+ checkpoint = tf.train.Checkpoint(v=v)
+ ret = exporter.maybe_export_checkpoint(
+ checkpoint, {'test_metric': {'metric_1': 5.0}}, 100)
+ with self.subTest(name='Successful first save.'):
+ self.assertEqual(ret, True)
+ v_2 = tf.Variable(2.0)
+ checkpoint_2 = tf.train.Checkpoint(v=v_2)
+ checkpoint_2.restore(best_ckpt_path)
+ self.assertEqual(v_2.numpy(), 1.0)
+
+ v = tf.Variable(3.0)
+ checkpoint = tf.train.Checkpoint(v=v)
+ ret = exporter.maybe_export_checkpoint(
+ checkpoint, {'test_metric': {'metric_1': 6.0}}, 200)
+ with self.subTest(name='Successful better metic save.'):
+ self.assertEqual(ret, True)
+ v_2 = tf.Variable(2.0)
+ checkpoint_2 = tf.train.Checkpoint(v=v_2)
+ checkpoint_2.restore(best_ckpt_path)
+ self.assertEqual(v_2.numpy(), 3.0)
+
+ v = tf.Variable(5.0)
+ checkpoint = tf.train.Checkpoint(v=v)
+ ret = exporter.maybe_export_checkpoint(
+ checkpoint, {'test_metric': {'metric_1': 1.0}}, 300)
+ with self.subTest(name='Worse metic no save.'):
+ self.assertEqual(ret, False)
+ v_2 = tf.Variable(2.0)
+ checkpoint_2 = tf.train.Checkpoint(v=v_2)
+ checkpoint_2.restore(best_ckpt_path)
+ self.assertEqual(v_2.numpy(), 3.0)
+
+ def test_export_best_eval_metric(self):
+ model_dir = self.create_tempdir().full_path
+ metric_name = 'test_metric|metric_1'
+ exporter = train_utils.BestCheckpointExporter(model_dir, metric_name,
+ 'higher')
+ exporter.export_best_eval_metric({'test_metric': {'metric_1': 5.0}}, 100)
+ with tf.io.gfile.GFile(os.path.join(model_dir, 'info.json'),
+ 'rb') as reader:
+ metric = json.loads(reader.read())
+ self.assertAllEqual(
+ metric,
+ {'test_metric': {'metric_1': 5.0}, 'best_ckpt_global_step': 100.0})
+
+ def test_export_best_eval_metric_skips_non_scalar_values(self):
+ model_dir = self.create_tempdir().full_path
+ metric_name = 'test_metric|metric_1'
+ exporter = train_utils.BestCheckpointExporter(model_dir, metric_name,
+ 'higher')
+ image = tf.zeros(shape=[16, 8, 1])
+ eval_logs = {'test_metric': {'metric_1': 5.0, 'image': image}}
+
+ exporter.export_best_eval_metric(eval_logs, 100)
+
+ with tf.io.gfile.GFile(os.path.join(model_dir, 'info.json'),
+ 'rb') as reader:
+ metric = json.loads(reader.read())
+ self.assertAllEqual(
+ metric,
+ {'test_metric': {'metric_1': 5.0}, 'best_ckpt_global_step': 100.0})
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/legacy/README.md b/modeling/official/legacy/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ced1fce05dfcd8308d7bec8b01186a8804bc074f
--- /dev/null
+++ b/modeling/official/legacy/README.md
@@ -0,0 +1,5 @@
+Models in this `legacy` directory are mainly are used for benchmarking the
+models.
+
+Please note that the models in this `legacy` directory are not supported like
+the models in official/nlp and official/vision.
diff --git a/modeling/official/legacy/__init__.py b/modeling/official/legacy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/legacy/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/legacy/albert/README.md b/modeling/official/legacy/albert/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..117a050314e586e646c60d0845aecb4eced22757
--- /dev/null
+++ b/modeling/official/legacy/albert/README.md
@@ -0,0 +1,4 @@
+# ALBERT (ALBERT: A Lite BERT for Self-supervised Learning of Language Representations)
+
+**WARNING**: This directory is deprecated.
+See `nlp/docs/MODEL_GARDEN.md` for the new ALBERT implementation.
diff --git a/modeling/official/legacy/albert/__init__.py b/modeling/official/legacy/albert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/legacy/albert/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/legacy/albert/configs.py b/modeling/official/legacy/albert/configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..e71757fa719578e68e0cf72f0099e966534365ef
--- /dev/null
+++ b/modeling/official/legacy/albert/configs.py
@@ -0,0 +1,50 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The ALBERT configurations."""
+
+import six
+
+from official.legacy.bert import configs
+
+
+class AlbertConfig(configs.BertConfig):
+ """Configuration for `ALBERT`."""
+
+ def __init__(self, num_hidden_groups=1, inner_group_num=1, **kwargs):
+ """Constructs AlbertConfig.
+
+ Args:
+ num_hidden_groups: Number of group for the hidden layers, parameters in
+ the same group are shared. Note that this value and also the following
+ 'inner_group_num' has to be 1 for now, because all released ALBERT
+ models set them to 1. We may support arbitary valid values in future.
+ inner_group_num: Number of inner repetition of attention and ffn.
+ **kwargs: The remaining arguments are the same as above 'BertConfig'.
+ """
+ super(AlbertConfig, self).__init__(**kwargs)
+
+ # TODO(chendouble): 'inner_group_num' and 'num_hidden_groups' are always 1
+ # in the released ALBERT. Support other values in AlbertEncoder if needed.
+ if inner_group_num != 1 or num_hidden_groups != 1:
+ raise ValueError("We only support 'inner_group_num' and "
+ "'num_hidden_groups' as 1.")
+
+ @classmethod
+ def from_dict(cls, json_object):
+ """Constructs a `AlbertConfig` from a Python dictionary of parameters."""
+ config = AlbertConfig(vocab_size=None)
+ for (key, value) in six.iteritems(json_object):
+ config.__dict__[key] = value
+ return config
diff --git a/modeling/official/legacy/bert/README.md b/modeling/official/legacy/bert/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..82bd5e9d9c577258c679447fc5dcbb8efc93bca2
--- /dev/null
+++ b/modeling/official/legacy/bert/README.md
@@ -0,0 +1,395 @@
+# BERT (Bidirectional Encoder Representations from Transformers)
+
+**WARNING**: We are on the way to deprecating most of the code in this directory.
+Please see
+[this link](../g3doc/tutorials/bert_new.md)
+for the new tutorial and use the new code in `nlp/modeling`. This README is
+still correct for this legacy implementation.
+
+The academic paper which describes BERT in detail and provides full results on a
+number of tasks can be found here: https://arxiv.org/abs/1810.04805.
+
+This repository contains TensorFlow 2.x implementation for BERT.
+
+## Contents
+ * [Contents](#contents)
+ * [Pre-trained Models](#pre-trained-models)
+ * [Restoring from Checkpoints](#restoring-from-checkpoints)
+ * [Set Up](#set-up)
+ * [Process Datasets](#process-datasets)
+ * [Fine-tuning with BERT](#fine-tuning-with-bert)
+ * [Cloud GPUs and TPUs](#cloud-gpus-and-tpus)
+ * [Sentence and Sentence-pair Classification Tasks](#sentence-and-sentence-pair-classification-tasks)
+ * [SQuAD 1.1](#squad-1.1)
+
+
+## Pre-trained Models
+
+We released both checkpoints and tf.hub modules as the pretrained models for
+fine-tuning. They are TF 2.x compatible and are converted from the checkpoints
+released in TF 1.x official BERT repository
+[google-research/bert](https://github.com/google-research/bert)
+in order to keep consistent with BERT paper.
+
+
+### Access to Pretrained Checkpoints
+
+Pretrained checkpoints can be found in the following links:
+
+**Note: We have switched BERT implementation
+to use Keras functional-style networks in [nlp/modeling](../modeling).
+The new checkpoints are:**
+
+* **[`BERT-Large, Uncased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_uncased_L-24_H-1024_A-16.tar.gz)**:
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
+* **[`BERT-Large, Cased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_cased_L-24_H-1024_A-16.tar.gz)**:
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
+* **[`BERT-Base, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12.tar.gz)**:
+ 12-layer, 768-hidden, 12-heads, 110M parameters
+* **[`BERT-Large, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16.tar.gz)**:
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
+* **[`BERT-Base, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-12_H-768_A-12.tar.gz)**:
+ 12-layer, 768-hidden, 12-heads , 110M parameters
+* **[`BERT-Large, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-24_H-1024_A-16.tar.gz)**:
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
+* **[`BERT-Base, Multilingual Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/multi_cased_L-12_H-768_A-12.tar.gz)**:
+ 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
+
+We recommend to host checkpoints on Google Cloud Storage buckets when you use
+Cloud GPU/TPU.
+
+### Restoring from Checkpoints
+
+`tf.train.Checkpoint` is used to manage model checkpoints in TF 2. To restore
+weights from provided pre-trained checkpoints, you can use the following code:
+
+```python
+init_checkpoint='the pretrained model checkpoint path.'
+model=tf.keras.Model() # Bert pre-trained model as feature extractor.
+checkpoint = tf.train.Checkpoint(model=model)
+checkpoint.restore(init_checkpoint)
+```
+
+Checkpoints featuring native serialized Keras models
+(i.e. model.load()/load_weights()) will be available soon.
+
+### Access to Pretrained hub modules.
+
+Pretrained tf.hub modules in TF 2.x SavedModel format can be found in the
+following links:
+
+* **[`BERT-Large, Uncased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/)**:
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
+* **[`BERT-Large, Cased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_cased_L-24_H-1024_A-16/)**:
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
+* **[`BERT-Base, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/)**:
+ 12-layer, 768-hidden, 12-heads, 110M parameters
+* **[`BERT-Large, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/)**:
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
+* **[`BERT-Base, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/)**:
+ 12-layer, 768-hidden, 12-heads , 110M parameters
+* **[`BERT-Large, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-24_H-1024_A-16/)**:
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
+* **[`BERT-Base, Multilingual Cased`](https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/)**:
+ 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
+* **[`BERT-Base, Chinese`](https://tfhub.dev/tensorflow/bert_zh_L-12_H-768_A-12/)**:
+ Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads,
+ 110M parameters
+
+## Set Up
+
+```shell
+export PYTHONPATH="$PYTHONPATH:/path/to/models"
+```
+
+Install `tf-nightly` to get latest updates:
+
+```shell
+pip install tf-nightly-gpu
+```
+
+With TPU, GPU support is not necessary. First, you need to create a `tf-nightly`
+TPU with [ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu):
+
+```shell
+ctpu up -name --tf-version=”nightly”
+```
+
+Second, you need to install TF 2 `tf-nightly` on your VM:
+
+```shell
+pip install tf-nightly
+```
+
+## Process Datasets
+
+### Pre-training
+
+There is no change to generate pre-training data. Please use the script
+[`../data/create_pretraining_data.py`](../data/create_pretraining_data.py)
+which is essentially branched from the [BERT research repo](https://github.com/google-research/bert)
+to get processed pre-training data and it adapts to TF2 symbols and python3
+compatibility.
+
+Running the pre-training script requires an input and output directory, as well as a vocab file. Note that max_seq_length will need to match the sequence length parameter you specify when you run pre-training.
+
+Example shell script to call create_pretraining_data.py
+```
+export WORKING_DIR='local disk or cloud location'
+export BERT_DIR='local disk or cloud location'
+python models/official/nlp/data/create_pretraining_data.py \
+ --input_file=$WORKING_DIR/input/input.txt \
+ --output_file=$WORKING_DIR/output/tf_examples.tfrecord \
+ --vocab_file=$BERT_DIR/wwm_uncased_L-24_H-1024_A-16/vocab.txt \
+ --do_lower_case=True \
+ --max_seq_length=512 \
+ --max_predictions_per_seq=76 \
+ --masked_lm_prob=0.15 \
+ --random_seed=12345 \
+ --dupe_factor=5
+```
+
+### Fine-tuning
+
+To prepare the fine-tuning data for final model training, use the
+[`../data/create_finetuning_data.py`](../data/create_finetuning_data.py) script.
+Resulting datasets in `tf_record` format and training meta data should be later
+passed to training or evaluation scripts. The task-specific arguments are
+described in the following sections:
+
+* GLUE
+
+Users can download the
+[GLUE data](https://gluebenchmark.com/tasks) by running
+[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
+and unpack it to some directory `$GLUE_DIR`.
+Also, users can download [Pretrained Checkpoint](#access-to-pretrained-checkpoints) and locate it on some directory `$BERT_DIR` instead of using checkpoints on Google Cloud Storage.
+
+```shell
+export GLUE_DIR=~/glue
+export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
+
+export TASK_NAME=MNLI
+export OUTPUT_DIR=gs://some_bucket/datasets
+python ../data/create_finetuning_data.py \
+ --input_data_dir=${GLUE_DIR}/${TASK_NAME}/ \
+ --vocab_file=${BERT_DIR}/vocab.txt \
+ --train_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_train.tf_record \
+ --eval_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_eval.tf_record \
+ --meta_data_file_path=${OUTPUT_DIR}/${TASK_NAME}_meta_data \
+ --fine_tuning_task_type=classification --max_seq_length=128 \
+ --classification_task_name=${TASK_NAME}
+```
+
+* SQUAD
+
+The [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/) contains
+detailed information about the SQuAD datasets and evaluation.
+
+The necessary files can be found here:
+
+* [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
+* [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
+* [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
+* [train-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json)
+* [dev-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json)
+* [evaluate-v2.0.py](https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/)
+
+```shell
+export SQUAD_DIR=~/squad
+export SQUAD_VERSION=v1.1
+export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
+export OUTPUT_DIR=gs://some_bucket/datasets
+
+python ../data/create_finetuning_data.py \
+ --squad_data_file=${SQUAD_DIR}/train-${SQUAD_VERSION}.json \
+ --vocab_file=${BERT_DIR}/vocab.txt \
+ --train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
+ --meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \
+ --fine_tuning_task_type=squad --max_seq_length=384
+```
+
+Note: To create fine-tuning data with SQUAD 2.0, you need to add flag `--version_2_with_negative=True`.
+
+## Fine-tuning with BERT
+
+### Cloud GPUs and TPUs
+
+* Cloud Storage
+
+The unzipped pre-trained model files can also be found in the Google Cloud
+Storage folder `gs://cloud-tpu-checkpoints/bert/keras_bert`. For example:
+
+```shell
+export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
+export MODEL_DIR=gs://some_bucket/my_output_dir
+```
+
+Currently, users are able to access to `tf-nightly` TPUs and the following TPU
+script should run with `tf-nightly`.
+
+* GPU -> TPU
+
+Just add the following flags to `run_classifier.py` or `run_squad.py`:
+
+```shell
+ --distribution_strategy=tpu
+ --tpu=grpc://${TPU_IP_ADDRESS}:8470
+```
+
+### Sentence and Sentence-pair Classification Tasks
+
+This example code fine-tunes `BERT-Large` on the Microsoft Research Paraphrase
+Corpus (MRPC) corpus, which only contains 3,600 examples and can fine-tune in a
+few minutes on most GPUs.
+
+We use the `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
+workflow.
+For GPU memory of 16GB or smaller, you may try to use `BERT-Base`
+(uncased_L-12_H-768_A-12).
+
+```shell
+export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
+export MODEL_DIR=gs://some_bucket/my_output_dir
+export GLUE_DIR=gs://some_bucket/datasets
+export TASK=MRPC
+
+python run_classifier.py \
+ --mode='train_and_eval' \
+ --input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
+ --train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
+ --eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
+ --bert_config_file=${BERT_DIR}/bert_config.json \
+ --init_checkpoint=${BERT_DIR}/bert_model.ckpt \
+ --train_batch_size=4 \
+ --eval_batch_size=4 \
+ --steps_per_loop=1 \
+ --learning_rate=2e-5 \
+ --num_train_epochs=3 \
+ --model_dir=${MODEL_DIR} \
+ --distribution_strategy=mirrored
+```
+
+Alternatively, instead of specifying `init_checkpoint`, you can specify
+`hub_module_url` to employ a pre-trained BERT hub module, e.g.,
+` --hub_module_url=https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/1`.
+
+After training a model, to get predictions from the classifier, you can set the
+`--mode=predict` and offer the test set tfrecords to `--eval_data_path`.
+The output will be created in file called test_results.tsv in the output folder.
+Each line will contain output for each sample, columns are the class
+probabilities.
+
+```shell
+python run_classifier.py \
+ --mode='predict' \
+ --input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
+ --eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
+ --bert_config_file=${BERT_DIR}/bert_config.json \
+ --eval_batch_size=4 \
+ --model_dir=${MODEL_DIR} \
+ --distribution_strategy=mirrored
+```
+
+To use TPU, you only need to switch the distribution strategy type to `tpu` with TPU
+information and use remote storage for model checkpoints.
+
+```shell
+export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
+export TPU_IP_ADDRESS='???'
+export MODEL_DIR=gs://some_bucket/my_output_dir
+export GLUE_DIR=gs://some_bucket/datasets
+export TASK=MRPC
+
+python run_classifier.py \
+ --mode='train_and_eval' \
+ --input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
+ --train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
+ --eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
+ --bert_config_file=${BERT_DIR}/bert_config.json \
+ --init_checkpoint=${BERT_DIR}/bert_model.ckpt \
+ --train_batch_size=32 \
+ --eval_batch_size=32 \
+ --steps_per_loop=1000 \
+ --learning_rate=2e-5 \
+ --num_train_epochs=3 \
+ --model_dir=${MODEL_DIR} \
+ --distribution_strategy=tpu \
+ --tpu=grpc://${TPU_IP_ADDRESS}:8470
+```
+
+Note that, we specify `steps_per_loop=1000` for TPU, because running a loop of
+training steps inside a `tf.function` can significantly increase TPU utilization
+and callbacks will not be called inside the loop.
+
+### SQuAD 1.1
+
+The Stanford Question Answering Dataset (SQuAD) is a popular question answering
+benchmark dataset. See more on [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/).
+
+We use the `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
+workflow.
+For GPU memory of 16GB or smaller, you may try to use `BERT-Base`
+(uncased_L-12_H-768_A-12).
+
+```shell
+export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
+export SQUAD_DIR=gs://some_bucket/datasets
+export MODEL_DIR=gs://some_bucket/my_output_dir
+export SQUAD_VERSION=v1.1
+
+python run_squad.py \
+ --input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
+ --train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
+ --predict_file=${SQUAD_DIR}/dev-v1.1.json \
+ --vocab_file=${BERT_DIR}/vocab.txt \
+ --bert_config_file=${BERT_DIR}/bert_config.json \
+ --init_checkpoint=${BERT_DIR}/bert_model.ckpt \
+ --train_batch_size=4 \
+ --predict_batch_size=4 \
+ --learning_rate=8e-5 \
+ --num_train_epochs=2 \
+ --model_dir=${MODEL_DIR} \
+ --distribution_strategy=mirrored
+```
+
+Similarly, you can replace `init_checkpoint` FLAG with `hub_module_url` to
+specify a hub module path.
+
+`run_squad.py` writes the prediction for `--predict_file` by default. If you set
+the `--model=predict` and offer the SQuAD test data, the scripts will generate
+the prediction json file.
+
+To use TPU, you need to switch the distribution strategy type to `tpu` with TPU
+information.
+
+```shell
+export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
+export TPU_IP_ADDRESS='???'
+export MODEL_DIR=gs://some_bucket/my_output_dir
+export SQUAD_DIR=gs://some_bucket/datasets
+export SQUAD_VERSION=v1.1
+
+python run_squad.py \
+ --input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
+ --train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
+ --predict_file=${SQUAD_DIR}/dev-v1.1.json \
+ --vocab_file=${BERT_DIR}/vocab.txt \
+ --bert_config_file=${BERT_DIR}/bert_config.json \
+ --init_checkpoint=${BERT_DIR}/bert_model.ckpt \
+ --train_batch_size=32 \
+ --learning_rate=8e-5 \
+ --num_train_epochs=2 \
+ --model_dir=${MODEL_DIR} \
+ --distribution_strategy=tpu \
+ --tpu=grpc://${TPU_IP_ADDRESS}:8470
+```
+
+The dev set predictions will be saved into a file called predictions.json in the
+model_dir:
+
+```shell
+python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ./squad/predictions.json
+```
+
+
diff --git a/modeling/official/legacy/bert/__init__.py b/modeling/official/legacy/bert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f338592c943c69c8ca66bc1f0981a619ea10e27
--- /dev/null
+++ b/modeling/official/legacy/bert/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
diff --git a/modeling/official/legacy/bert/bert_cloud_tpu.md b/modeling/official/legacy/bert/bert_cloud_tpu.md
new file mode 100644
index 0000000000000000000000000000000000000000..baf6f9bdc0c155cb53b30cea5f404aa166c3a2c6
--- /dev/null
+++ b/modeling/official/legacy/bert/bert_cloud_tpu.md
@@ -0,0 +1,110 @@
+# BERT FineTuning with Cloud TPU: Sentence and Sentence-Pair Classification Tasks (TF 2.1)
+This tutorial shows you how to train the Bidirectional Encoder Representations from Transformers (BERT) model on Cloud TPU.
+
+
+## Set up Cloud Storage and Compute Engine VM
+1. [Open a cloud shell window](https://console.cloud.google.com/?cloudshell=true&_ga=2.11844148.-1612541229.1552429951)
+2. Create a variable for the project's id:
+```
+export PROJECT_ID=your-project_id
+```
+3. Configure `gcloud` command-line tool to use the project where you want to create Cloud TPU.
+```
+gcloud config set project ${PROJECT_ID}
+```
+4. Create a Cloud Storage bucket using the following command:
+```
+gsutil mb -p ${PROJECT_ID} -c standard -l europe-west4 -b on gs://your-bucket-name
+```
+This Cloud Storage bucket stores the data you use to train your model and the training results.
+5. Launch a Compute Engine VM and Cloud TPU using the ctpu up command.
+```
+ctpu up --tpu-size=v3-8 \
+ --machine-type=n1-standard-8 \
+ --zone=europe-west4-a \
+ --tf-version=2.1 [optional flags: --project, --name]
+```
+6. The configuration you specified appears. Enter y to approve or n to cancel.
+7. When the ctpu up command has finished executing, verify that your shell prompt has changed from username@project to username@tpuname. This change shows that you are now logged into your Compute Engine VM.
+```
+gcloud compute ssh vm-name --zone=europe-west4-a
+(vm)$ export TPU_NAME=vm-name
+```
+As you continue these instructions, run each command that begins with `(vm)$` in your VM session window.
+
+## Prepare the Dataset
+1. From your Compute Engine virtual machine (VM), install requirements.txt.
+```
+(vm)$ cd /usr/share/models
+(vm)$ sudo pip3 install -r official/requirements.txt
+```
+2. Optional: download download_glue_data.py
+
+This tutorial uses the General Language Understanding Evaluation (GLUE) benchmark to evaluate and analyze the performance of the model. The GLUE data is provided for this tutorial at gs://cloud-tpu-checkpoints/bert/classification.
+
+## Define parameter values
+Next, define several parameter values that are required when you train and evaluate your model:
+
+```
+(vm)$ export PYTHONPATH="$PYTHONPATH:/usr/share/tpu/models"
+(vm)$ export STORAGE_BUCKET=gs://your-bucket-name
+(vm)$ export BERT_BASE_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
+(vm)$ export MODEL_DIR=${STORAGE_BUCKET}/bert-output
+(vm)$ export GLUE_DIR=gs://cloud-tpu-checkpoints/bert/classification
+(vm)$ export TASK=mnli
+```
+
+## Train the model
+From your Compute Engine VM, run the following command.
+
+```
+(vm)$ python3 official/nlp/bert/run_classifier.py \
+ --mode='train_and_eval' \
+ --input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
+ --train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
+ --eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
+ --bert_config_file=$BERT_BASE_DIR/bert_config.json \
+ --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
+ --train_batch_size=32 \
+ --eval_batch_size=32 \
+ --learning_rate=2e-5 \
+ --num_train_epochs=3 \
+ --model_dir=${MODEL_DIR} \
+ --distribution_strategy=tpu \
+ --tpu=${TPU_NAME}
+```
+
+## Verify your results
+The training takes approximately 1 hour on a v3-8 TPU. When script completes, you should see results similar to the following:
+```
+Training Summary:
+{'train_loss': 0.28142181038856506,
+'last_train_metrics': 0.9467429518699646,
+'eval_metrics': 0.8599063158035278,
+'total_training_steps': 36813}
+```
+
+## Clean up
+To avoid incurring charges to your GCP account for the resources used in this topic:
+1. Disconnect from the Compute Engine VM:
+```
+(vm)$ exit
+```
+2. In your Cloud Shell, run ctpu delete with the --zone flag you used when you set up the Cloud TPU to delete your Compute Engine VM and your Cloud TPU:
+```
+$ ctpu delete --zone=your-zone
+```
+3. Run ctpu status specifying your zone to make sure you have no instances allocated to avoid unnecessary charges for TPU usage. The deletion might take several minutes. A response like the one below indicates there are no more allocated instances:
+```
+$ ctpu status --zone=your-zone
+```
+4. Run gsutil as shown, replacing your-bucket with the name of the Cloud Storage bucket you created for this tutorial:
+```
+$ gsutil rm -r gs://your-bucket
+```
+
+
+
+
+
+
diff --git a/modeling/official/legacy/bert/bert_models.py b/modeling/official/legacy/bert/bert_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..59560d8412ea90239d5bc3bf631ab24056118ed6
--- /dev/null
+++ b/modeling/official/legacy/bert/bert_models.py
@@ -0,0 +1,365 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BERT models that are compatible with TF 2.0."""
+
+import gin
+import tensorflow as tf, tf_keras
+import tensorflow_hub as hub
+from official.legacy.albert import configs as albert_configs
+from official.legacy.bert import configs
+from official.modeling import tf_utils
+from official.nlp.modeling import models
+from official.nlp.modeling import networks
+
+
+class BertPretrainLossAndMetricLayer(tf_keras.layers.Layer):
+ """Returns layer that computes custom loss and metrics for pretraining."""
+
+ def __init__(self, vocab_size, **kwargs):
+ super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
+ self._vocab_size = vocab_size
+ self.config = {
+ 'vocab_size': vocab_size,
+ }
+
+ def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
+ lm_example_loss, sentence_output, sentence_labels,
+ next_sentence_loss):
+ """Adds metrics."""
+ masked_lm_accuracy = tf_keras.metrics.sparse_categorical_accuracy(
+ lm_labels, lm_output)
+ numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
+ denominator = tf.reduce_sum(lm_label_weights) + 1e-5
+ masked_lm_accuracy = numerator / denominator
+ self.add_metric(
+ masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')
+
+ self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')
+
+ if sentence_labels is not None:
+ next_sentence_accuracy = tf_keras.metrics.sparse_categorical_accuracy(
+ sentence_labels, sentence_output)
+ self.add_metric(
+ next_sentence_accuracy,
+ name='next_sentence_accuracy',
+ aggregation='mean')
+
+ if next_sentence_loss is not None:
+ self.add_metric(
+ next_sentence_loss, name='next_sentence_loss', aggregation='mean')
+
+ def call(self,
+ lm_output_logits,
+ sentence_output_logits,
+ lm_label_ids,
+ lm_label_weights,
+ sentence_labels=None):
+ """Implements call() for the layer."""
+ lm_label_weights = tf.cast(lm_label_weights, tf.float32)
+ lm_output_logits = tf.cast(lm_output_logits, tf.float32)
+
+ lm_prediction_losses = tf_keras.losses.sparse_categorical_crossentropy(
+ lm_label_ids, lm_output_logits, from_logits=True)
+ lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
+ lm_denominator_loss = tf.reduce_sum(lm_label_weights)
+ mask_label_loss = tf.math.divide_no_nan(lm_numerator_loss,
+ lm_denominator_loss)
+
+ if sentence_labels is not None:
+ sentence_output_logits = tf.cast(sentence_output_logits, tf.float32)
+ sentence_loss = tf_keras.losses.sparse_categorical_crossentropy(
+ sentence_labels, sentence_output_logits, from_logits=True)
+ sentence_loss = tf.reduce_mean(sentence_loss)
+ loss = mask_label_loss + sentence_loss
+ else:
+ sentence_loss = None
+ loss = mask_label_loss
+
+ batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
+ # TODO(hongkuny): Avoids the hack and switches add_loss.
+ final_loss = tf.fill(batch_shape, loss)
+
+ self._add_metrics(lm_output_logits, lm_label_ids, lm_label_weights,
+ mask_label_loss, sentence_output_logits, sentence_labels,
+ sentence_loss)
+ return final_loss
+
+
+@gin.configurable
+def get_transformer_encoder(bert_config,
+ sequence_length=None,
+ transformer_encoder_cls=None,
+ output_range=None):
+ """Gets a 'TransformerEncoder' object.
+
+ Args:
+ bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
+ sequence_length: [Deprecated].
+ transformer_encoder_cls: A EncoderScaffold class. If it is None, uses the
+ default BERT encoder implementation.
+ output_range: the sequence output range, [0, output_range). Default setting
+ is to return the entire sequence output.
+
+ Returns:
+ A encoder object.
+ """
+ del sequence_length
+ if transformer_encoder_cls is not None:
+ # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
+ embedding_cfg = dict(
+ vocab_size=bert_config.vocab_size,
+ type_vocab_size=bert_config.type_vocab_size,
+ hidden_size=bert_config.hidden_size,
+ max_seq_length=bert_config.max_position_embeddings,
+ initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=bert_config.initializer_range),
+ dropout_rate=bert_config.hidden_dropout_prob,
+ )
+ hidden_cfg = dict(
+ num_attention_heads=bert_config.num_attention_heads,
+ intermediate_size=bert_config.intermediate_size,
+ intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
+ dropout_rate=bert_config.hidden_dropout_prob,
+ attention_dropout_rate=bert_config.attention_probs_dropout_prob,
+ kernel_initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=bert_config.initializer_range),
+ )
+ kwargs = dict(
+ embedding_cfg=embedding_cfg,
+ hidden_cfg=hidden_cfg,
+ num_hidden_instances=bert_config.num_hidden_layers,
+ pooled_output_dim=bert_config.hidden_size,
+ pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=bert_config.initializer_range))
+
+ # Relies on gin configuration to define the Transformer encoder arguments.
+ return transformer_encoder_cls(**kwargs)
+
+ kwargs = dict(
+ vocab_size=bert_config.vocab_size,
+ hidden_size=bert_config.hidden_size,
+ num_layers=bert_config.num_hidden_layers,
+ num_attention_heads=bert_config.num_attention_heads,
+ intermediate_size=bert_config.intermediate_size,
+ activation=tf_utils.get_activation(bert_config.hidden_act),
+ dropout_rate=bert_config.hidden_dropout_prob,
+ attention_dropout_rate=bert_config.attention_probs_dropout_prob,
+ max_sequence_length=bert_config.max_position_embeddings,
+ type_vocab_size=bert_config.type_vocab_size,
+ embedding_width=bert_config.embedding_size,
+ initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=bert_config.initializer_range))
+ if isinstance(bert_config, albert_configs.AlbertConfig):
+ return networks.AlbertEncoder(**kwargs)
+ else:
+ assert isinstance(bert_config, configs.BertConfig)
+ kwargs['output_range'] = output_range
+ return networks.BertEncoder(**kwargs)
+
+
+def pretrain_model(bert_config,
+ seq_length,
+ max_predictions_per_seq,
+ initializer=None,
+ use_next_sentence_label=True,
+ return_core_pretrainer_model=False):
+ """Returns model to be used for pre-training.
+
+ Args:
+ bert_config: Configuration that defines the core BERT model.
+ seq_length: Maximum sequence length of the training data.
+ max_predictions_per_seq: Maximum number of tokens in sequence to mask out
+ and use for pretraining.
+ initializer: Initializer for weights in BertPretrainer.
+ use_next_sentence_label: Whether to use the next sentence label.
+ return_core_pretrainer_model: Whether to also return the `BertPretrainer`
+ object.
+
+ Returns:
+ A Tuple of (1) Pretraining model, (2) core BERT submodel from which to
+ save weights after pretraining, and (3) optional core `BertPretrainer`
+ object if argument `return_core_pretrainer_model` is True.
+ """
+ input_word_ids = tf_keras.layers.Input(
+ shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
+ input_mask = tf_keras.layers.Input(
+ shape=(seq_length,), name='input_mask', dtype=tf.int32)
+ input_type_ids = tf_keras.layers.Input(
+ shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
+ masked_lm_positions = tf_keras.layers.Input(
+ shape=(max_predictions_per_seq,),
+ name='masked_lm_positions',
+ dtype=tf.int32)
+ masked_lm_ids = tf_keras.layers.Input(
+ shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
+ masked_lm_weights = tf_keras.layers.Input(
+ shape=(max_predictions_per_seq,),
+ name='masked_lm_weights',
+ dtype=tf.int32)
+
+ if use_next_sentence_label:
+ next_sentence_labels = tf_keras.layers.Input(
+ shape=(1,), name='next_sentence_labels', dtype=tf.int32)
+ else:
+ next_sentence_labels = None
+
+ transformer_encoder = get_transformer_encoder(bert_config, seq_length)
+ if initializer is None:
+ initializer = tf_keras.initializers.TruncatedNormal(
+ stddev=bert_config.initializer_range)
+ pretrainer_model = models.BertPretrainer(
+ network=transformer_encoder,
+ embedding_table=transformer_encoder.get_embedding_table(),
+ num_classes=2, # The next sentence prediction label has two classes.
+ activation=tf_utils.get_activation(bert_config.hidden_act),
+ num_token_predictions=max_predictions_per_seq,
+ initializer=initializer,
+ output='logits')
+
+ outputs = pretrainer_model(
+ [input_word_ids, input_mask, input_type_ids, masked_lm_positions])
+ lm_output = outputs['masked_lm']
+ sentence_output = outputs['classification']
+ pretrain_loss_layer = BertPretrainLossAndMetricLayer(
+ vocab_size=bert_config.vocab_size)
+ output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
+ masked_lm_weights, next_sentence_labels)
+ inputs = {
+ 'input_word_ids': input_word_ids,
+ 'input_mask': input_mask,
+ 'input_type_ids': input_type_ids,
+ 'masked_lm_positions': masked_lm_positions,
+ 'masked_lm_ids': masked_lm_ids,
+ 'masked_lm_weights': masked_lm_weights,
+ }
+ if use_next_sentence_label:
+ inputs['next_sentence_labels'] = next_sentence_labels
+
+ keras_model = tf_keras.Model(inputs=inputs, outputs=output_loss)
+ if return_core_pretrainer_model:
+ return keras_model, transformer_encoder, pretrainer_model
+ else:
+ return keras_model, transformer_encoder
+
+
+def squad_model(bert_config,
+ max_seq_length,
+ initializer=None,
+ hub_module_url=None,
+ hub_module_trainable=True):
+ """Returns BERT Squad model along with core BERT model to import weights.
+
+ Args:
+ bert_config: BertConfig, the config defines the core Bert model.
+ max_seq_length: integer, the maximum input sequence length.
+ initializer: Initializer for the final dense layer in the span labeler.
+ Defaulted to TruncatedNormal initializer.
+ hub_module_url: TF-Hub path/url to Bert module.
+ hub_module_trainable: True to finetune layers in the hub module.
+
+ Returns:
+ A tuple of (1) keras model that outputs start logits and end logits and
+ (2) the core BERT transformer encoder.
+ """
+ if initializer is None:
+ initializer = tf_keras.initializers.TruncatedNormal(
+ stddev=bert_config.initializer_range)
+ if not hub_module_url:
+ bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
+ return models.BertSpanLabeler(
+ network=bert_encoder, initializer=initializer), bert_encoder
+
+ input_word_ids = tf_keras.layers.Input(
+ shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
+ input_mask = tf_keras.layers.Input(
+ shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
+ input_type_ids = tf_keras.layers.Input(
+ shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
+ core_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
+ pooled_output, sequence_output = core_model(
+ [input_word_ids, input_mask, input_type_ids])
+ bert_encoder = tf_keras.Model(
+ inputs={
+ 'input_word_ids': input_word_ids,
+ 'input_mask': input_mask,
+ 'input_type_ids': input_type_ids,
+ },
+ outputs=[sequence_output, pooled_output],
+ name='core_model')
+ return models.BertSpanLabeler(
+ network=bert_encoder, initializer=initializer), bert_encoder
+
+
+def classifier_model(bert_config,
+ num_labels,
+ max_seq_length=None,
+ final_layer_initializer=None,
+ hub_module_url=None,
+ hub_module_trainable=True):
+ """BERT classifier model in functional API style.
+
+ Construct a Keras model for predicting `num_labels` outputs from an input with
+ maximum sequence length `max_seq_length`.
+
+ Args:
+ bert_config: BertConfig or AlbertConfig, the config defines the core BERT or
+ ALBERT model.
+ num_labels: integer, the number of classes.
+ max_seq_length: integer, the maximum input sequence length.
+ final_layer_initializer: Initializer for final dense layer. Defaulted
+ TruncatedNormal initializer.
+ hub_module_url: TF-Hub path/url to Bert module.
+ hub_module_trainable: True to finetune layers in the hub module.
+
+ Returns:
+ Combined prediction model (words, mask, type) -> (one-hot labels)
+ BERT sub-model (words, mask, type) -> (bert_outputs)
+ """
+ if final_layer_initializer is not None:
+ initializer = final_layer_initializer
+ else:
+ initializer = tf_keras.initializers.TruncatedNormal(
+ stddev=bert_config.initializer_range)
+
+ if not hub_module_url:
+ bert_encoder = get_transformer_encoder(
+ bert_config, max_seq_length, output_range=1)
+ return models.BertClassifier(
+ bert_encoder,
+ num_classes=num_labels,
+ dropout_rate=bert_config.hidden_dropout_prob,
+ initializer=initializer), bert_encoder
+
+ input_word_ids = tf_keras.layers.Input(
+ shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
+ input_mask = tf_keras.layers.Input(
+ shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
+ input_type_ids = tf_keras.layers.Input(
+ shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
+ bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
+ pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
+ output = tf_keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
+ pooled_output)
+
+ output = tf_keras.layers.Dense(
+ num_labels, kernel_initializer=initializer, name='output')(
+ output)
+ return tf_keras.Model(
+ inputs={
+ 'input_word_ids': input_word_ids,
+ 'input_mask': input_mask,
+ 'input_type_ids': input_type_ids
+ },
+ outputs=output), bert_model
diff --git a/modeling/official/legacy/bert/bert_models_test.py b/modeling/official/legacy/bert/bert_models_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..20ebd8926e22a8c18cc7e4ae4a9c16bd567c250d
--- /dev/null
+++ b/modeling/official/legacy/bert/bert_models_test.py
@@ -0,0 +1,106 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.bert import bert_models
+from official.legacy.bert import configs as bert_configs
+from official.nlp.modeling import networks
+
+
+class BertModelsTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(BertModelsTest, self).setUp()
+ self._bert_test_config = bert_configs.BertConfig(
+ attention_probs_dropout_prob=0.0,
+ hidden_act='gelu',
+ hidden_dropout_prob=0.0,
+ hidden_size=16,
+ initializer_range=0.02,
+ intermediate_size=32,
+ max_position_embeddings=128,
+ num_attention_heads=2,
+ num_hidden_layers=2,
+ type_vocab_size=2,
+ vocab_size=30522)
+
+ def test_pretrain_model(self):
+ model, encoder = bert_models.pretrain_model(
+ self._bert_test_config,
+ seq_length=5,
+ max_predictions_per_seq=2,
+ initializer=None,
+ use_next_sentence_label=True)
+ self.assertIsInstance(model, tf_keras.Model)
+ self.assertIsInstance(encoder, networks.BertEncoder)
+
+ # model has one scalar output: loss value.
+ self.assertEqual(model.output.shape.as_list(), [
+ None,
+ ])
+
+ # Expect two output from encoder: sequence and classification output.
+ self.assertIsInstance(encoder.output, list)
+ self.assertLen(encoder.output, 2)
+ # shape should be [batch size, hidden_size]
+ self.assertEqual(encoder.output[1].shape.as_list(), [None, 16])
+
+ def test_squad_model(self):
+ model, core_model = bert_models.squad_model(
+ self._bert_test_config,
+ max_seq_length=5,
+ initializer=None,
+ hub_module_url=None,
+ hub_module_trainable=None)
+ self.assertIsInstance(model, tf_keras.Model)
+ self.assertIsInstance(core_model, tf_keras.Model)
+
+ # Expect two output from model: start positions and end positions
+ self.assertIsInstance(model.output, list)
+ self.assertLen(model.output, 2)
+
+ # Expect two output from core_model: sequence and classification output.
+ self.assertIsInstance(core_model.output, list)
+ self.assertLen(core_model.output, 2)
+ # shape should be [batch size, None, hidden_size]
+ self.assertEqual(core_model.output[0].shape.as_list(), [None, None, 16])
+ # shape should be [batch size, hidden_size]
+ self.assertEqual(core_model.output[1].shape.as_list(), [None, 16])
+
+ def test_classifier_model(self):
+ model, core_model = bert_models.classifier_model(
+ self._bert_test_config,
+ num_labels=3,
+ max_seq_length=5,
+ final_layer_initializer=None,
+ hub_module_url=None,
+ hub_module_trainable=None)
+ self.assertIsInstance(model, tf_keras.Model)
+ self.assertIsInstance(core_model, tf_keras.Model)
+
+ # model has one classification output with num_labels=3.
+ self.assertEqual(model.output.shape.as_list(), [None, 3])
+
+ # Expect two output from core_model: sequence and classification output.
+ self.assertIsInstance(core_model.output, list)
+ self.assertLen(core_model.output, 2)
+ # shape should be [batch size, None, hidden_size]
+ self.assertEqual(core_model.output[0].shape.as_list(), [None, None, 16])
+ # shape should be [batch size, hidden_size]
+ self.assertEqual(core_model.output[1].shape.as_list(), [None, 16])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/legacy/bert/common_flags.py b/modeling/official/legacy/bert/common_flags.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f21a3f59e6c4a98f31c89f18d4ae9f148b13a65
--- /dev/null
+++ b/modeling/official/legacy/bert/common_flags.py
@@ -0,0 +1,125 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defining common flags used across all BERT models/applications."""
+
+from absl import flags
+import tensorflow as tf, tf_keras
+
+from official.utils import hyperparams_flags
+from official.utils.flags import core as flags_core
+
+
+def define_common_bert_flags():
+ """Define common flags for BERT tasks."""
+ flags_core.define_base(
+ data_dir=False,
+ model_dir=True,
+ clean=False,
+ train_epochs=False,
+ epochs_between_evals=False,
+ stop_threshold=False,
+ batch_size=False,
+ num_gpu=True,
+ export_dir=False,
+ distribution_strategy=True,
+ run_eagerly=True)
+ flags_core.define_distribution()
+ flags.DEFINE_string('bert_config_file', None,
+ 'Bert configuration file to define core bert layers.')
+ flags.DEFINE_string(
+ 'model_export_path', None,
+ 'Path to the directory, where trainined model will be '
+ 'exported.')
+ flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
+ flags.DEFINE_string(
+ 'init_checkpoint', None,
+ 'Initial checkpoint (usually from a pre-trained BERT model).')
+ flags.DEFINE_integer('num_train_epochs', 3,
+ 'Total number of training epochs to perform.')
+ flags.DEFINE_integer(
+ 'steps_per_loop', None,
+ 'Number of steps per graph-mode loop. Only training step '
+ 'happens inside the loop. Callbacks will not be called '
+ 'inside. If not set the value will be configured depending on the '
+ 'devices available.')
+ flags.DEFINE_float('learning_rate', 5e-5,
+ 'The initial learning rate for Adam.')
+ flags.DEFINE_float('end_lr', 0.0,
+ 'The end learning rate for learning rate decay.')
+ flags.DEFINE_string('optimizer_type', 'adamw',
+ 'The type of optimizer to use for training (adamw|lamb)')
+ flags.DEFINE_boolean(
+ 'scale_loss', False,
+ 'Whether to divide the loss by number of replica inside the per-replica '
+ 'loss function.')
+ flags.DEFINE_boolean(
+ 'use_keras_compile_fit', False,
+ 'If True, uses Keras compile/fit() API for training logic. Otherwise '
+ 'use custom training loop.')
+ flags.DEFINE_string(
+ 'hub_module_url', None, 'TF-Hub path/url to Bert module. '
+ 'If specified, init_checkpoint flag should not be used.')
+ flags.DEFINE_bool('hub_module_trainable', True,
+ 'True to make keras layers in the hub module trainable.')
+ flags.DEFINE_string(
+ 'sub_model_export_name', None,
+ 'If set, `sub_model` checkpoints are exported into '
+ 'FLAGS.model_dir/FLAGS.sub_model_export_name.')
+ flags.DEFINE_bool('explicit_allreduce', False,
+ 'True to use explicit allreduce instead of the implicit '
+ 'allreduce in optimizer.apply_gradients(). If fp16 mixed '
+ 'precision training is used, this also enables allreduce '
+ 'gradients in fp16.')
+ flags.DEFINE_integer('allreduce_bytes_per_pack', 0,
+ 'Number of bytes of a gradient pack for allreduce. '
+ 'Should be positive integer, if set to 0, all '
+ 'gradients are in one pack. Breaking gradient into '
+ 'packs could enable overlap between allreduce and '
+ 'backprop computation. This flag only takes effect '
+ 'when explicit_allreduce is set to True.')
+
+ flags_core.define_log_steps()
+
+ # Adds flags for mixed precision and multi-worker training.
+ flags_core.define_performance(
+ num_parallel_calls=False,
+ inter_op=False,
+ intra_op=False,
+ synthetic_data=False,
+ max_train_steps=False,
+ dtype=True,
+ loss_scale=True,
+ all_reduce_alg=True,
+ num_packs=False,
+ tf_gpu_thread_mode=True,
+ datasets_num_private_threads=True,
+ enable_xla=True,
+ fp16_implementation=True,
+ )
+
+ # Adds gin configuration flags.
+ hyperparams_flags.define_gin_flags()
+
+
+def dtype():
+ return flags_core.get_tf_dtype(flags.FLAGS)
+
+
+def use_float16():
+ return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16
+
+
+def get_loss_scale():
+ return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
diff --git a/modeling/official/legacy/bert/configs.py b/modeling/official/legacy/bert/configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..402f7de7a22da12e730705800d86e221027bf6d0
--- /dev/null
+++ b/modeling/official/legacy/bert/configs.py
@@ -0,0 +1,104 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The main BERT model and related functions."""
+
+import copy
+import json
+
+import six
+import tensorflow as tf, tf_keras
+
+
+class BertConfig(object):
+ """Configuration for `BertModel`."""
+
+ def __init__(self,
+ vocab_size,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=16,
+ initializer_range=0.02,
+ embedding_size=None,
+ backward_compatible=True):
+ """Constructs BertConfig.
+
+ Args:
+ vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
+ hidden_size: Size of the encoder layers and the pooler layer.
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
+ num_attention_heads: Number of attention heads for each attention layer in
+ the Transformer encoder.
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
+ layer in the Transformer encoder.
+ hidden_act: The non-linear activation function (function or string) in the
+ encoder and pooler.
+ hidden_dropout_prob: The dropout probability for all fully connected
+ layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob: The dropout ratio for the attention
+ probabilities.
+ max_position_embeddings: The maximum sequence length that this model might
+ ever be used with. Typically set this to something large just in case
+ (e.g., 512 or 1024 or 2048).
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
+ `BertModel`.
+ initializer_range: The stdev of the truncated_normal_initializer for
+ initializing all weight matrices.
+ embedding_size: (Optional) width of the factorized word embeddings.
+ backward_compatible: Boolean, whether the variables shape are compatible
+ with checkpoints converted from TF 1.x BERT.
+ """
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.embedding_size = embedding_size
+ self.backward_compatible = backward_compatible
+
+ @classmethod
+ def from_dict(cls, json_object):
+ """Constructs a `BertConfig` from a Python dictionary of parameters."""
+ config = BertConfig(vocab_size=None)
+ for (key, value) in six.iteritems(json_object):
+ config.__dict__[key] = value
+ return config
+
+ @classmethod
+ def from_json_file(cls, json_file):
+ """Constructs a `BertConfig` from a json file of parameters."""
+ with tf.io.gfile.GFile(json_file, "r") as reader:
+ text = reader.read()
+ return cls.from_dict(json.loads(text))
+
+ def to_dict(self):
+ """Serializes this instance to a Python dictionary."""
+ output = copy.deepcopy(self.__dict__)
+ return output
+
+ def to_json_string(self):
+ """Serializes this instance to a JSON string."""
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
diff --git a/modeling/official/legacy/bert/export_tfhub.py b/modeling/official/legacy/bert/export_tfhub.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cee3fdfcc58411681c3cad57476cc0cb8bc336a
--- /dev/null
+++ b/modeling/official/legacy/bert/export_tfhub.py
@@ -0,0 +1,139 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A script to export BERT as a TF-Hub SavedModel.
+
+This script is **DEPRECATED** for exporting BERT encoder models;
+see the error message in by main() for details.
+"""
+
+from typing import Text
+
+# Import libraries
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf, tf_keras
+from official.legacy.bert import bert_models
+from official.legacy.bert import configs
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("bert_config_file", None,
+ "Bert configuration file to define core bert layers.")
+flags.DEFINE_string("model_checkpoint_path", None,
+ "File path to TF model checkpoint.")
+flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
+flags.DEFINE_string("vocab_file", None,
+ "The vocabulary file that the BERT model was trained on.")
+flags.DEFINE_bool(
+ "do_lower_case", None, "Whether to lowercase. If None, "
+ "do_lower_case will be enabled if 'uncased' appears in the "
+ "name of --vocab_file")
+flags.DEFINE_enum("model_type", "encoder", ["encoder", "squad"],
+ "What kind of BERT model to export.")
+
+
+def create_bert_model(bert_config: configs.BertConfig) -> tf_keras.Model:
+ """Creates a BERT keras core model from BERT configuration.
+
+ Args:
+ bert_config: A `BertConfig` to create the core model.
+
+ Returns:
+ A keras model.
+ """
+ # Adds input layers just as placeholders.
+ input_word_ids = tf_keras.layers.Input(
+ shape=(None,), dtype=tf.int32, name="input_word_ids")
+ input_mask = tf_keras.layers.Input(
+ shape=(None,), dtype=tf.int32, name="input_mask")
+ input_type_ids = tf_keras.layers.Input(
+ shape=(None,), dtype=tf.int32, name="input_type_ids")
+ transformer_encoder = bert_models.get_transformer_encoder(
+ bert_config, sequence_length=None)
+ sequence_output, pooled_output = transformer_encoder(
+ [input_word_ids, input_mask, input_type_ids])
+ # To keep consistent with legacy hub modules, the outputs are
+ # "pooled_output" and "sequence_output".
+ return tf_keras.Model(
+ inputs=[input_word_ids, input_mask, input_type_ids],
+ outputs=[pooled_output, sequence_output]), transformer_encoder
+
+
+def export_bert_tfhub(bert_config: configs.BertConfig,
+ model_checkpoint_path: Text,
+ hub_destination: Text,
+ vocab_file: Text,
+ do_lower_case: bool = None):
+ """Restores a tf_keras.Model and saves for TF-Hub."""
+ # If do_lower_case is not explicit, default to checking whether "uncased" is
+ # in the vocab file name
+ if do_lower_case is None:
+ do_lower_case = "uncased" in vocab_file
+ logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
+ do_lower_case, vocab_file)
+ core_model, encoder = create_bert_model(bert_config)
+ checkpoint = tf.train.Checkpoint(
+ model=encoder, # Legacy checkpoints.
+ encoder=encoder)
+ checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
+ core_model.vocab_file = tf.saved_model.Asset(vocab_file)
+ core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
+ core_model.save(hub_destination, include_optimizer=False, save_format="tf")
+
+
+def export_bert_squad_tfhub(bert_config: configs.BertConfig,
+ model_checkpoint_path: Text,
+ hub_destination: Text,
+ vocab_file: Text,
+ do_lower_case: bool = None):
+ """Restores a tf_keras.Model for BERT with SQuAD and saves for TF-Hub."""
+ # If do_lower_case is not explicit, default to checking whether "uncased" is
+ # in the vocab file name
+ if do_lower_case is None:
+ do_lower_case = "uncased" in vocab_file
+ logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
+ do_lower_case, vocab_file)
+ span_labeling, _ = bert_models.squad_model(bert_config, max_seq_length=None)
+ checkpoint = tf.train.Checkpoint(model=span_labeling)
+ checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
+ span_labeling.vocab_file = tf.saved_model.Asset(vocab_file)
+ span_labeling.do_lower_case = tf.Variable(do_lower_case, trainable=False)
+ span_labeling.save(hub_destination, include_optimizer=False, save_format="tf")
+
+
+def main(_):
+ bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+ if FLAGS.model_type == "encoder":
+ deprecation_note = (
+ "nlp/bert/export_tfhub is **DEPRECATED** for exporting BERT encoder "
+ "models. Please switch to nlp/tools/export_tfhub for exporting BERT "
+ "(and other) encoders with dict inputs/outputs conforming to "
+ "https://www.tensorflow.org/hub/common_saved_model_apis/text#transformer-encoders"
+ )
+ logging.error(deprecation_note)
+ print("\n\nNOTICE:", deprecation_note, "\n")
+ export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path,
+ FLAGS.export_path, FLAGS.vocab_file, FLAGS.do_lower_case)
+ elif FLAGS.model_type == "squad":
+ export_bert_squad_tfhub(bert_config, FLAGS.model_checkpoint_path,
+ FLAGS.export_path, FLAGS.vocab_file,
+ FLAGS.do_lower_case)
+ else:
+ raise ValueError("Unsupported model_type %s." % FLAGS.model_type)
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/modeling/official/legacy/bert/export_tfhub_test.py b/modeling/official/legacy/bert/export_tfhub_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa8cf2d67355bf1c9d524d2473048a75817544e7
--- /dev/null
+++ b/modeling/official/legacy/bert/export_tfhub_test.py
@@ -0,0 +1,108 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests official.nlp.bert.export_tfhub."""
+
+import os
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf, tf_keras
+import tensorflow_hub as hub
+
+from official.legacy.bert import configs
+from official.legacy.bert import export_tfhub
+
+
+class ExportTfhubTest(tf.test.TestCase, parameterized.TestCase):
+
+ @parameterized.parameters("model", "encoder")
+ def test_export_tfhub(self, ckpt_key_name):
+ # Exports a savedmodel for TF-Hub
+ hidden_size = 16
+ bert_config = configs.BertConfig(
+ vocab_size=100,
+ hidden_size=hidden_size,
+ intermediate_size=32,
+ max_position_embeddings=128,
+ num_attention_heads=2,
+ num_hidden_layers=1)
+ bert_model, encoder = export_tfhub.create_bert_model(bert_config)
+ model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
+ checkpoint = tf.train.Checkpoint(**{ckpt_key_name: encoder})
+ checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
+ model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)
+
+ vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
+ with tf.io.gfile.GFile(vocab_file, "w") as f:
+ f.write("dummy content")
+
+ hub_destination = os.path.join(self.get_temp_dir(), "hub")
+ export_tfhub.export_bert_tfhub(bert_config, model_checkpoint_path,
+ hub_destination, vocab_file)
+
+ # Restores a hub KerasLayer.
+ hub_layer = hub.KerasLayer(hub_destination, trainable=True)
+
+ if hasattr(hub_layer, "resolved_object"):
+ # Checks meta attributes.
+ self.assertTrue(hub_layer.resolved_object.do_lower_case.numpy())
+ with tf.io.gfile.GFile(
+ hub_layer.resolved_object.vocab_file.asset_path.numpy()) as f:
+ self.assertEqual("dummy content", f.read())
+ # Checks the hub KerasLayer.
+ for source_weight, hub_weight in zip(bert_model.trainable_weights,
+ hub_layer.trainable_weights):
+ self.assertAllClose(source_weight.numpy(), hub_weight.numpy())
+
+ seq_length = 10
+ dummy_ids = np.zeros((2, seq_length), dtype=np.int32)
+ hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
+ source_outputs = bert_model([dummy_ids, dummy_ids, dummy_ids])
+
+ # The outputs of hub module are "pooled_output" and "sequence_output",
+ # while the outputs of encoder is in reversed order, i.e.,
+ # "sequence_output" and "pooled_output".
+ encoder_outputs = reversed(encoder([dummy_ids, dummy_ids, dummy_ids]))
+ self.assertEqual(hub_outputs[0].shape, (2, hidden_size))
+ self.assertEqual(hub_outputs[1].shape, (2, seq_length, hidden_size))
+ for source_output, hub_output, encoder_output in zip(
+ source_outputs, hub_outputs, encoder_outputs):
+ self.assertAllClose(source_output.numpy(), hub_output.numpy())
+ self.assertAllClose(source_output.numpy(), encoder_output.numpy())
+
+ # Test that training=True makes a difference (activates dropout).
+ def _dropout_mean_stddev(training, num_runs=20):
+ input_ids = np.array([[14, 12, 42, 95, 99]], np.int32)
+ inputs = [input_ids, np.ones_like(input_ids), np.zeros_like(input_ids)]
+ outputs = np.concatenate(
+ [hub_layer(inputs, training=training)[0] for _ in range(num_runs)])
+ return np.mean(np.std(outputs, axis=0))
+
+ self.assertLess(_dropout_mean_stddev(training=False), 1e-6)
+ self.assertGreater(_dropout_mean_stddev(training=True), 1e-3)
+
+ # Test propagation of seq_length in shape inference.
+ input_word_ids = tf_keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
+ input_mask = tf_keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
+ input_type_ids = tf_keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
+ pooled_output, sequence_output = hub_layer(
+ [input_word_ids, input_mask, input_type_ids])
+ self.assertEqual(pooled_output.shape.as_list(), [None, hidden_size])
+ self.assertEqual(sequence_output.shape.as_list(),
+ [None, seq_length, hidden_size])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/legacy/bert/input_pipeline.py b/modeling/official/legacy/bert/input_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc6ba48ded1ab1f938d2a024a9c79b2f9dee9b4e
--- /dev/null
+++ b/modeling/official/legacy/bert/input_pipeline.py
@@ -0,0 +1,302 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BERT model input pipelines."""
+
+import tensorflow as tf, tf_keras
+
+
+def decode_record(record, name_to_features):
+ """Decodes a record to a TensorFlow example."""
+ example = tf.io.parse_single_example(record, name_to_features)
+
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+ # So cast all int64 to int32.
+ for name in list(example.keys()):
+ t = example[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ example[name] = t
+
+ return example
+
+
+def single_file_dataset(input_file, name_to_features, num_samples=None):
+ """Creates a single-file dataset to be passed for BERT custom training."""
+ # For training, we want a lot of parallel reading and shuffling.
+ # For eval, we want no shuffling and parallel reading doesn't matter.
+ d = tf.data.TFRecordDataset(input_file)
+ if num_samples:
+ d = d.take(num_samples)
+ d = d.map(
+ lambda record: decode_record(record, name_to_features),
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ # When `input_file` is a path to a single file or a list
+ # containing a single path, disable auto sharding so that
+ # same input file is sent to all workers.
+ if isinstance(input_file, str) or len(input_file) == 1:
+ options = tf.data.Options()
+ options.experimental_distribute.auto_shard_policy = (
+ tf.data.experimental.AutoShardPolicy.OFF)
+ d = d.with_options(options)
+ return d
+
+
+def create_pretrain_dataset(input_patterns,
+ seq_length,
+ max_predictions_per_seq,
+ batch_size,
+ is_training=True,
+ input_pipeline_context=None,
+ use_next_sentence_label=True,
+ use_position_id=False,
+ output_fake_labels=True):
+ """Creates input dataset from (tf)records files for pretraining."""
+ name_to_features = {
+ 'input_ids':
+ tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'input_mask':
+ tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'segment_ids':
+ tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'masked_lm_positions':
+ tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
+ 'masked_lm_ids':
+ tf.io.FixedLenFeature([max_predictions_per_seq], tf.int64),
+ 'masked_lm_weights':
+ tf.io.FixedLenFeature([max_predictions_per_seq], tf.float32),
+ }
+ if use_next_sentence_label:
+ name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
+ tf.int64)
+ if use_position_id:
+ name_to_features['position_ids'] = tf.io.FixedLenFeature([seq_length],
+ tf.int64)
+ for input_pattern in input_patterns:
+ if not tf.io.gfile.glob(input_pattern):
+ raise ValueError('%s does not match any files.' % input_pattern)
+
+ dataset = tf.data.Dataset.list_files(input_patterns, shuffle=is_training)
+
+ if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+ dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+ input_pipeline_context.input_pipeline_id)
+ if is_training:
+ dataset = dataset.repeat()
+
+ # We set shuffle buffer to exactly match total number of
+ # training files to ensure that training data is well shuffled.
+ input_files = []
+ for input_pattern in input_patterns:
+ input_files.extend(tf.io.gfile.glob(input_pattern))
+ dataset = dataset.shuffle(len(input_files))
+
+ # In parallel, create tf record dataset for each train files.
+ # cycle_length = 8 means that up to 8 files will be read and deserialized in
+ # parallel. You may want to increase this number if you have a large number of
+ # CPU cores.
+ dataset = dataset.interleave(
+ tf.data.TFRecordDataset,
+ cycle_length=8,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ if is_training:
+ dataset = dataset.shuffle(100)
+
+ decode_fn = lambda record: decode_record(record, name_to_features)
+ dataset = dataset.map(
+ decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ def _select_data_from_record(record):
+ """Filter out features to use for pretraining."""
+ x = {
+ 'input_word_ids': record['input_ids'],
+ 'input_mask': record['input_mask'],
+ 'input_type_ids': record['segment_ids'],
+ 'masked_lm_positions': record['masked_lm_positions'],
+ 'masked_lm_ids': record['masked_lm_ids'],
+ 'masked_lm_weights': record['masked_lm_weights'],
+ }
+ if use_next_sentence_label:
+ x['next_sentence_labels'] = record['next_sentence_labels']
+ if use_position_id:
+ x['position_ids'] = record['position_ids']
+
+ # TODO(hongkuny): Remove the fake labels after migrating bert pretraining.
+ if output_fake_labels:
+ return (x, record['masked_lm_weights'])
+ else:
+ return x
+
+ dataset = dataset.map(
+ _select_data_from_record,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ dataset = dataset.batch(batch_size, drop_remainder=is_training)
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+ return dataset
+
+
+def create_classifier_dataset(file_path,
+ seq_length,
+ batch_size,
+ is_training=True,
+ input_pipeline_context=None,
+ label_type=tf.int64,
+ include_sample_weights=False,
+ num_samples=None):
+ """Creates input dataset from (tf)records files for train/eval."""
+ name_to_features = {
+ 'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'label_ids': tf.io.FixedLenFeature([], label_type),
+ }
+ if include_sample_weights:
+ name_to_features['weight'] = tf.io.FixedLenFeature([], tf.float32)
+ dataset = single_file_dataset(file_path, name_to_features,
+ num_samples=num_samples)
+
+ # The dataset is always sharded by number of hosts.
+ # num_input_pipelines is the number of hosts rather than number of cores.
+ if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+ dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+ input_pipeline_context.input_pipeline_id)
+
+ def _select_data_from_record(record):
+ x = {
+ 'input_word_ids': record['input_ids'],
+ 'input_mask': record['input_mask'],
+ 'input_type_ids': record['segment_ids']
+ }
+ y = record['label_ids']
+ if include_sample_weights:
+ w = record['weight']
+ return (x, y, w)
+ return (x, y)
+
+ if is_training:
+ dataset = dataset.shuffle(100)
+ dataset = dataset.repeat()
+
+ dataset = dataset.map(
+ _select_data_from_record,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ dataset = dataset.batch(batch_size, drop_remainder=is_training)
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+ return dataset
+
+
+def create_squad_dataset(file_path,
+ seq_length,
+ batch_size,
+ is_training=True,
+ input_pipeline_context=None):
+ """Creates input dataset from (tf)records files for train/eval."""
+ name_to_features = {
+ 'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
+ }
+ if is_training:
+ name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
+ name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
+ else:
+ name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)
+
+ dataset = single_file_dataset(file_path, name_to_features)
+
+ # The dataset is always sharded by number of hosts.
+ # num_input_pipelines is the number of hosts rather than number of cores.
+ if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+ dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+ input_pipeline_context.input_pipeline_id)
+
+ def _select_data_from_record(record):
+ """Dispatches record to features and labels."""
+ x, y = {}, {}
+ for name, tensor in record.items():
+ if name in ('start_positions', 'end_positions'):
+ y[name] = tensor
+ elif name == 'input_ids':
+ x['input_word_ids'] = tensor
+ elif name == 'segment_ids':
+ x['input_type_ids'] = tensor
+ else:
+ x[name] = tensor
+ return (x, y)
+
+ if is_training:
+ dataset = dataset.shuffle(100)
+ dataset = dataset.repeat()
+
+ dataset = dataset.map(
+ _select_data_from_record,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ dataset = dataset.batch(batch_size, drop_remainder=True)
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+ return dataset
+
+
+def create_retrieval_dataset(file_path,
+ seq_length,
+ batch_size,
+ input_pipeline_context=None):
+ """Creates input dataset from (tf)records files for scoring."""
+ name_to_features = {
+ 'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
+ 'example_id': tf.io.FixedLenFeature([1], tf.int64),
+ }
+ dataset = single_file_dataset(file_path, name_to_features)
+
+ # The dataset is always sharded by number of hosts.
+ # num_input_pipelines is the number of hosts rather than number of cores.
+ if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+ dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+ input_pipeline_context.input_pipeline_id)
+
+ def _select_data_from_record(record):
+ x = {
+ 'input_word_ids': record['input_ids'],
+ 'input_mask': record['input_mask'],
+ 'input_type_ids': record['segment_ids']
+ }
+ y = record['example_id']
+ return (x, y)
+
+ dataset = dataset.map(
+ _select_data_from_record,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ dataset = dataset.batch(batch_size, drop_remainder=False)
+
+ def _pad_to_batch(x, y):
+ cur_size = tf.shape(y)[0]
+ pad_size = batch_size - cur_size
+
+ pad_ids = tf.zeros(shape=[pad_size, seq_length], dtype=tf.int32)
+ for key in ('input_word_ids', 'input_mask', 'input_type_ids'):
+ x[key] = tf.concat([x[key], pad_ids], axis=0)
+
+ pad_labels = -tf.ones(shape=[pad_size, 1], dtype=tf.int32)
+ y = tf.concat([y, pad_labels], axis=0)
+ return x, y
+
+ dataset = dataset.map(
+ _pad_to_batch,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+ return dataset
diff --git a/modeling/official/legacy/bert/model_saving_utils.py b/modeling/official/legacy/bert/model_saving_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f833e5ea8f6b40bf8fda3cba2d0f2dcce433967
--- /dev/null
+++ b/modeling/official/legacy/bert/model_saving_utils.py
@@ -0,0 +1,67 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities to save models."""
+
+import os
+import typing
+from absl import logging
+import tensorflow as tf, tf_keras
+
+
+def export_bert_model(model_export_path: typing.Text,
+ model: tf_keras.Model,
+ checkpoint_dir: typing.Optional[typing.Text] = None,
+ restore_model_using_load_weights: bool = False) -> None:
+ """Export BERT model for serving which does not include the optimizer.
+
+ Args:
+ model_export_path: Path to which exported model will be saved.
+ model: Keras model object to export.
+ checkpoint_dir: Path from which model weights will be loaded, if
+ specified.
+ restore_model_using_load_weights: Whether to use checkpoint.restore() API
+ for custom checkpoint or to use model.load_weights() API. There are 2
+ different ways to save checkpoints. One is using tf.train.Checkpoint and
+ another is using Keras model.save_weights(). Custom training loop
+ implementation uses tf.train.Checkpoint API and Keras ModelCheckpoint
+ callback internally uses model.save_weights() API. Since these two API's
+ cannot be used toghether, model loading logic must be take into account
+ how model checkpoint was saved.
+
+ Raises:
+ ValueError when either model_export_path or model is not specified.
+ """
+ if not model_export_path:
+ raise ValueError('model_export_path must be specified.')
+ if not isinstance(model, tf_keras.Model):
+ raise ValueError('model must be a tf_keras.Model object.')
+
+ if checkpoint_dir:
+ if restore_model_using_load_weights:
+ model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
+ assert tf.io.gfile.exists(model_weight_path)
+ model.load_weights(model_weight_path)
+ else:
+ checkpoint = tf.train.Checkpoint(model=model)
+
+ # Restores the model from latest checkpoint.
+ latest_checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
+ assert latest_checkpoint_file
+ logging.info('Checkpoint file %s found and restoring from '
+ 'checkpoint', latest_checkpoint_file)
+ checkpoint.restore(
+ latest_checkpoint_file).assert_existing_objects_matched()
+
+ model.save(model_export_path, include_optimizer=False, save_format='tf')
diff --git a/modeling/official/legacy/bert/model_training_utils.py b/modeling/official/legacy/bert/model_training_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a6bb589c7e462886fd559b777b19367a7bfad53
--- /dev/null
+++ b/modeling/official/legacy/bert/model_training_utils.py
@@ -0,0 +1,590 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A light weight utilities to train NLP models."""
+
+import json
+import os
+import tempfile
+
+from absl import logging
+import tensorflow as tf, tf_keras
+from tensorflow.python.util import deprecation
+from official.common import distribute_utils
+from official.modeling import grad_utils
+
+_SUMMARY_TXT = 'training_summary.txt'
+_MIN_SUMMARY_STEPS = 10
+
+
+def _should_export_checkpoint(strategy):
+ return (not strategy) or strategy.extended.should_checkpoint
+
+
+def _should_export_summary(strategy):
+ return (not strategy) or strategy.extended.should_save_summary
+
+
+def _save_checkpoint(strategy, checkpoint, model_dir, checkpoint_prefix):
+ """Saves model to with provided checkpoint prefix."""
+
+ if _should_export_checkpoint(strategy):
+ checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
+ saved_path = checkpoint.save(checkpoint_path)
+ logging.info('Saving model as TF checkpoint: %s', saved_path)
+ else:
+ # In multi worker training we need every worker to save checkpoint, because
+ # variables can trigger synchronization on read and synchronization needs
+ # all workers to participate. To avoid workers overriding each other we save
+ # to a temporary directory on non-chief workers.
+ tmp_dir = tempfile.mkdtemp()
+ checkpoint.save(os.path.join(tmp_dir, 'ckpt'))
+ tf.io.gfile.rmtree(tmp_dir)
+ return
+
+
+def _get_input_iterator(input_fn, strategy):
+ """Returns distributed dataset iterator."""
+ # When training with TPU pods, datasets needs to be cloned across
+ # workers. Since Dataset instance cannot be cloned in eager mode, we instead
+ # pass callable that returns a dataset.
+ if not callable(input_fn):
+ raise ValueError('`input_fn` should be a closure that returns a dataset.')
+ iterator = iter(strategy.distribute_datasets_from_function(input_fn))
+ return iterator
+
+
+def _float_metric_value(metric):
+ """Gets the value of a float-value keras metric."""
+ return metric.result().numpy().astype(float)
+
+
+def clip_by_global_norm_callback(grads_and_vars):
+ """Performs gradient clipping."""
+ grads, variables = zip(*grads_and_vars)
+ (clipped_grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
+ return zip(clipped_grads, variables)
+
+
+def steps_to_run(current_step, steps_per_epoch, steps_per_loop):
+ """Calculates steps to run on device."""
+ if steps_per_loop <= 0:
+ raise ValueError('steps_per_loop should be positive integer.')
+ if steps_per_loop == 1:
+ return steps_per_loop
+ remainder_in_epoch = current_step % steps_per_epoch
+ if remainder_in_epoch != 0:
+ return min(steps_per_epoch - remainder_in_epoch, steps_per_loop)
+ else:
+ return steps_per_loop
+
+
+def write_txt_summary(training_summary, summary_dir):
+ """Writes a summary text file to record stats."""
+ if not tf.io.gfile.exists(summary_dir):
+ tf.io.gfile.mkdir(summary_dir)
+ summary_path = os.path.join(summary_dir, _SUMMARY_TXT)
+ with tf.io.gfile.GFile(summary_path, 'wb') as f:
+ logging.info('Training Summary: \n%s', str(training_summary))
+ f.write(json.dumps(training_summary, indent=4))
+
+
+@deprecation.deprecated(
+ None, 'This function is deprecated and we do not expect adding new '
+ 'functionalities. Please do not have your code depending '
+ 'on this library.')
+def run_customized_training_loop(
+ # pylint: disable=invalid-name
+ _sentinel=None,
+ # pylint: enable=invalid-name
+ strategy=None,
+ model_fn=None,
+ loss_fn=None,
+ scale_loss=True,
+ model_dir=None,
+ train_input_fn=None,
+ steps_per_epoch=None,
+ num_eval_per_epoch=1,
+ steps_per_loop=None,
+ epochs=1,
+ eval_input_fn=None,
+ eval_steps=None,
+ metric_fn=None,
+ init_checkpoint=None,
+ custom_callbacks=None,
+ run_eagerly=False,
+ sub_model_export_name=None,
+ explicit_allreduce=False,
+ pre_allreduce_callbacks=None,
+ post_allreduce_callbacks=None,
+ train_summary_interval=0,
+ allreduce_bytes_per_pack=0):
+ """Run BERT pretrain model training using low-level API.
+
+ Args:
+ _sentinel: Used to prevent positional parameters. Internal, do not use.
+ strategy: Distribution strategy on which to run low level training loop.
+ model_fn: Function that returns a tuple (model, sub_model). Caller of this
+ function should add optimizer to the `model` via calling
+ `model.compile()` API or manually setting `model.optimizer` attribute.
+ Second element of the returned tuple(sub_model) is an optional sub model
+ to be used for initial checkpoint -- if provided.
+ loss_fn: Function with signature func(labels, logits) and returns a loss
+ tensor.
+ scale_loss: Whether to divide the raw loss by number of replicas before
+ gradients calculation.
+ model_dir: Model directory used during training for restoring/saving model
+ weights.
+ train_input_fn: Function that returns a tf.data.Dataset used for training.
+ steps_per_epoch: Number of steps to run per epoch. At the end of each
+ epoch, model checkpoint will be saved and evaluation will be conducted
+ if evaluation dataset is provided.
+ num_eval_per_epoch: Number of evaluations per epoch.
+ steps_per_loop: Number of steps per graph-mode loop. In order to reduce
+ communication in eager context, training logs are printed every
+ steps_per_loop.
+ epochs: Number of epochs to train.
+ eval_input_fn: Function that returns evaluation dataset. If none,
+ evaluation is skipped.
+ eval_steps: Number of steps to run evaluation. Required if `eval_input_fn`
+ is not none.
+ metric_fn: A metrics function that returns either a Keras Metric object or
+ a list of Keras Metric objects to record evaluation result using
+ evaluation dataset or with training dataset after every epoch.
+ init_checkpoint: Optional checkpoint to load to `sub_model` returned by
+ `model_fn`.
+ custom_callbacks: A list of Keras Callbacks objects to run during
+ training. More specifically, `on_train_begin(), on_train_end(),
+ on_batch_begin()`, `on_batch_end()`, `on_epoch_begin()`,
+ `on_epoch_end()` methods are invoked during training. Note that some
+ metrics may be missing from `logs`.
+ run_eagerly: Whether to run model training in pure eager execution. This
+ should be disable for TPUStrategy.
+ sub_model_export_name: If not None, will export `sub_model` returned by
+ `model_fn` into checkpoint files. The name of intermediate checkpoint
+ file is {sub_model_export_name}_step_{step}.ckpt and the last
+ checkpint's name is {sub_model_export_name}.ckpt; if None, `sub_model`
+ will not be exported as checkpoint.
+ explicit_allreduce: Whether to explicitly perform gradient allreduce,
+ instead of relying on implicit allreduce in optimizer.apply_gradients().
+ default is False. For now, if training using FP16 mixed precision,
+ explicit allreduce will aggregate gradients in FP16 format. For TPU and
+ GPU training using FP32, explicit allreduce will aggregate gradients in
+ FP32 format.
+ pre_allreduce_callbacks: A list of callback functions that takes gradients
+ and model variables pairs as input, manipulate them, and returns a new
+ gradients and model variables paris. The callback functions will be
+ invoked in the list order and before gradients are allreduced. With
+ mixed precision training, the pre_allreduce_allbacks will be applied on
+ scaled_gradients. Default is no callbacks. Only used when
+ explicit_allreduce=True.
+ post_allreduce_callbacks: A list of callback functions that takes
+ gradients and model variables pairs as input, manipulate them, and
+ returns a new gradients and model variables paris. The callback
+ functions will be invoked in the list order and right before gradients
+ are applied to variables for updates. Default is no callbacks. Only used
+ when explicit_allreduce=True.
+ train_summary_interval: Step interval for training summaries. If the value
+ is a negative number, then training summaries are not enabled.
+ allreduce_bytes_per_pack: A non-negative integer. Breaks collective
+ operations into packs of certain size. If it's zero, all gradients are
+ in one pack. Breaking gradient into packs could enable overlap between
+ allreduce and backprop computation. This flag only takes effect when
+ explicit_allreduce is set to True.'
+
+ Returns:
+ Trained model.
+
+ Raises:
+ ValueError: (1) When model returned by `model_fn` does not have optimizer
+ attribute or when required parameters are set to none. (2) eval args are
+ not specified correctly. (3) metric_fn must be a callable if specified.
+ (4) sub_model_checkpoint_name is specified, but `sub_model` returned
+ by `model_fn` is None.
+ """
+
+ if _sentinel is not None:
+ raise ValueError('only call `run_customized_training_loop()` '
+ 'with named arguments.')
+
+ required_arguments = [
+ strategy, model_fn, loss_fn, model_dir, steps_per_epoch, train_input_fn
+ ]
+
+ steps_between_evals = int(steps_per_epoch / num_eval_per_epoch)
+ if [arg for arg in required_arguments if arg is None]:
+ raise ValueError('`strategy`, `model_fn`, `loss_fn`, `model_dir`, '
+ '`steps_per_epoch` and `train_input_fn` are required '
+ 'parameters.')
+ if not steps_per_loop:
+ if tf.config.list_logical_devices('TPU'):
+ # One can't fully utilize a TPU with steps_per_loop=1, so in this case
+ # default users to a more useful value.
+ steps_per_loop = min(1000, steps_between_evals)
+ else:
+ steps_per_loop = 1
+ logging.info('steps_per_loop not specified. Using steps_per_loop=%d',
+ steps_per_loop)
+ if steps_per_loop > steps_between_evals:
+ logging.warning(
+ 'steps_per_loop: %d is specified to be greater than '
+ ' steps_between_evals: %d, we will use steps_between_evals as'
+ ' steps_per_loop.', steps_per_loop, steps_between_evals)
+ steps_per_loop = steps_between_evals
+ assert tf.executing_eagerly()
+
+ if run_eagerly:
+ if isinstance(
+ strategy,
+ (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy)):
+ raise ValueError(
+ 'TPUStrategy should not run eagerly as it heavily relies on graph'
+ ' optimization for the distributed system.')
+
+ if eval_input_fn and eval_steps is None:
+ raise ValueError(
+ '`eval_step` is required when `eval_input_fn ` is not none.')
+ if metric_fn and not callable(metric_fn):
+ raise ValueError(
+ 'if `metric_fn` is specified, metric_fn must be a callable.')
+
+ total_training_steps = steps_per_epoch * epochs
+ train_iterator = _get_input_iterator(train_input_fn, strategy)
+ eval_loss_metric = tf_keras.metrics.Mean('training_loss', dtype=tf.float32)
+
+ with distribute_utils.get_strategy_scope(strategy):
+ # To correctly place the model weights on accelerators,
+ # model and optimizer should be created in scope.
+ model, sub_model = model_fn()
+ if not hasattr(model, 'optimizer'):
+ raise ValueError('User should set optimizer attribute to model '
+ 'inside `model_fn`.')
+ if sub_model_export_name and sub_model is None:
+ raise ValueError('sub_model_export_name is specified as %s, but '
+ 'sub_model is None.' % sub_model_export_name)
+
+ callback_list = tf_keras.callbacks.CallbackList(
+ callbacks=custom_callbacks, model=model)
+
+ optimizer = model.optimizer
+
+ if init_checkpoint:
+ logging.info(
+ 'Checkpoint file %s found and restoring from '
+ 'initial checkpoint for core model.', init_checkpoint)
+ checkpoint = tf.train.Checkpoint(model=sub_model, encoder=sub_model)
+ checkpoint.read(init_checkpoint).assert_existing_objects_matched()
+ logging.info('Loading from checkpoint file completed')
+
+ train_loss_metric = tf_keras.metrics.Mean('training_loss', dtype=tf.float32)
+ eval_metrics = metric_fn() if metric_fn else []
+ if not isinstance(eval_metrics, list):
+ eval_metrics = [eval_metrics]
+ # If evaluation is required, make a copy of metric as it will be used by
+ # both train and evaluation.
+ train_metrics = [
+ metric.__class__.from_config(metric.get_config())
+ for metric in eval_metrics
+ ]
+
+ # Create summary writers
+ if _should_export_summary(strategy):
+ summary_dir = os.path.join(model_dir, 'summaries')
+ else:
+ # In multi worker training we need every worker to write summary, because
+ # variables can trigger synchronization on read and synchronization needs
+ # all workers to participate.
+ summary_dir = tempfile.mkdtemp()
+ eval_summary_writer = tf.summary.create_file_writer(
+ os.path.join(summary_dir, 'eval'))
+ last_summary_step = 0
+ if steps_per_loop >= _MIN_SUMMARY_STEPS and train_summary_interval >= 0:
+ # Only writes summary when the stats are collected sufficiently over
+ # enough steps.
+ train_summary_writer = tf.summary.create_file_writer(
+ os.path.join(summary_dir, 'train'))
+ else:
+ train_summary_writer = tf.summary.create_noop_writer()
+
+ # Collects training variables.
+ training_vars = model.trainable_variables
+
+ def _replicated_step(inputs):
+ """Replicated training step."""
+
+ inputs, labels = inputs
+ with tf.GradientTape() as tape:
+ model_outputs = model(inputs, training=True)
+ loss = loss_fn(labels, model_outputs)
+ # Raw loss is used for reporting in metrics/logs.
+ raw_loss = loss
+ if scale_loss:
+ # Scales down the loss for gradients to be invariant from replicas.
+ loss = loss / strategy.num_replicas_in_sync
+
+ if explicit_allreduce:
+ grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
+ training_vars,
+ pre_allreduce_callbacks,
+ post_allreduce_callbacks,
+ allreduce_bytes_per_pack)
+ else:
+ if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
+ with tape:
+ scaled_loss = optimizer.get_scaled_loss(loss)
+ scaled_grads = tape.gradient(scaled_loss, training_vars)
+ grads = optimizer.get_unscaled_gradients(scaled_grads)
+ else:
+ grads = tape.gradient(loss, training_vars)
+ optimizer.apply_gradients(zip(grads, training_vars))
+ # For reporting, the metric takes the mean of losses.
+ train_loss_metric.update_state(raw_loss)
+ for metric in train_metrics:
+ metric.update_state(labels, model_outputs)
+
+ @tf.function
+ def train_steps(iterator, steps):
+ """Performs distributed training steps in a loop.
+
+ Args:
+ iterator: the distributed iterator of training datasets.
+ steps: an tf.int32 integer tensor to specify number of steps to run
+ inside host training loop.
+
+ Raises:
+ ValueError: Any of the arguments or tensor shapes are invalid.
+ """
+ if not isinstance(steps, tf.Tensor):
+ raise ValueError('steps should be an Tensor. Python object may cause '
+ 'retracing.')
+
+ for _ in tf.range(steps):
+ strategy.run(_replicated_step, args=(next(iterator),))
+
+ def train_single_step(iterator):
+ """Performs a distributed training step.
+
+ Args:
+ iterator: the distributed iterator of training datasets.
+
+ Raises:
+ ValueError: Any of the arguments or tensor shapes are invalid.
+ """
+ strategy.run(_replicated_step, args=(next(iterator),))
+
+ def test_step(iterator):
+ """Calculates evaluation metrics on distributed devices."""
+
+ def _test_step_fn(inputs):
+ """Replicated accuracy calculation."""
+
+ inputs, labels = inputs
+ model_outputs = model(inputs, training=False)
+ for metric in eval_metrics:
+ metric.update_state(labels, model_outputs)
+ return model_outputs, labels
+
+ outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),))
+ outputs = tf.nest.map_structure(strategy.experimental_local_results,
+ outputs)
+ labels = tf.nest.map_structure(strategy.experimental_local_results,
+ labels)
+ return outputs, labels
+
+ if not run_eagerly:
+ train_single_step = tf.function(train_single_step)
+ test_step = tf.function(test_step)
+
+ def _run_evaluation(current_training_step, test_iterator):
+ """Runs validation steps and aggregate metrics.
+
+ Args:
+ current_training_step: tf.int32 tensor containing the current step.
+ test_iterator: distributed iterator of test datasets.
+
+ Returns:
+ A dict of metic names and values.
+ """
+ # The last batch of the evaluation is often smaller than previous ones.
+ # Moreover, in some distributed pieces it might even be empty. Therefore,
+ # different from the way training_loss is calculated, it is needed to
+ # gather all the logits and labels here to calculate the evaluation loss
+ # outside.
+ loss_list, loss_weights = list(), list()
+ for _ in range(eval_steps):
+ outputs, labels = test_step(test_iterator)
+ for cur_logits, cur_labels in zip(outputs, labels):
+ # This is to handle cases when cur_labels is not a single tensor,
+ # but a dict of tensors.
+ cur_weight = tf.shape(tf.nest.flatten(cur_labels)[0])[0]
+ if cur_weight != 0:
+ loss_list.append(loss_fn(cur_labels, cur_logits).numpy())
+ loss_weights.append(cur_weight)
+ # The sample_weights are the actual number of examples in each batch,
+ # a summation of numbers of examples in each replica if using
+ # distributed training.
+ eval_loss_metric.update_state(loss_list, sample_weight=loss_weights)
+
+ logs = {}
+ with eval_summary_writer.as_default():
+ for metric in [eval_loss_metric] + eval_metrics + model.metrics:
+ metric_value = _float_metric_value(metric)
+ logs[metric.name] = metric_value
+ logging.info('Step: [%d] Validation %s = %f', current_training_step,
+ metric.name, metric_value)
+ tf.summary.scalar(
+ metric.name, metric_value, step=current_training_step)
+ eval_summary_writer.flush()
+
+ return logs
+
+ # Training loop starts here.
+ checkpoint = tf.train.Checkpoint(
+ model=model, optimizer=optimizer, global_step=optimizer.iterations)
+ sub_model_checkpoint = tf.train.Checkpoint(
+ model=sub_model,
+ global_step=optimizer.iterations) if sub_model_export_name else None
+
+ latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
+ if latest_checkpoint_file:
+ logging.info('Checkpoint file %s found and restoring from '
+ 'checkpoint', latest_checkpoint_file)
+ checkpoint.restore(latest_checkpoint_file)
+ logging.info('Loading from checkpoint file completed')
+
+ current_step = optimizer.iterations.numpy()
+ checkpoint_name = 'ctl_step_{step}.ckpt'
+
+ logs = {}
+ callback_list.on_train_begin()
+ while current_step < total_training_steps and not model.stop_training:
+ if current_step % steps_per_epoch == 0:
+ callback_list.on_epoch_begin(int(current_step / steps_per_epoch) + 1)
+
+ # Training loss/metric are taking average over steps inside micro
+ # training loop. We reset the their values before each round.
+ train_loss_metric.reset_states()
+ for metric in train_metrics + model.metrics:
+ metric.reset_states()
+
+ callback_list.on_batch_begin(current_step)
+ # Runs several steps in the host while loop.
+ steps = steps_to_run(current_step, steps_between_evals, steps_per_loop)
+
+ if tf.config.list_physical_devices('GPU'):
+ # TODO(zongweiz): merge with train_steps once tf.while_loop
+ # GPU performance bugs are fixed.
+ for _ in range(steps):
+ train_single_step(train_iterator)
+ else:
+ # Converts steps to a Tensor to avoid tf.function retracing.
+ train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
+ train_loss = _float_metric_value(train_loss_metric)
+ current_step += steps
+
+ # Updates training logging.
+ training_status = 'Train Step: %d/%d / loss = %s' % (
+ current_step, total_training_steps, train_loss)
+
+ if current_step >= last_summary_step + train_summary_interval:
+ summary_writer = train_summary_writer
+ last_summary_step = current_step
+ else:
+ summary_writer = tf.summary.create_noop_writer()
+
+ with summary_writer.as_default():
+ if callable(optimizer.learning_rate):
+ tf.summary.scalar(
+ 'learning_rate',
+ optimizer.learning_rate(current_step),
+ step=current_step)
+ tf.summary.scalar(train_loss_metric.name, train_loss, step=current_step)
+ for metric in train_metrics + model.metrics:
+ metric_value = _float_metric_value(metric)
+ training_status += ' %s = %f' % (metric.name, metric_value)
+ tf.summary.scalar(metric.name, metric_value, step=current_step)
+ summary_writer.flush()
+ logging.info(training_status)
+
+ # If no need for evaluation, we only call on_batch_end with train_loss,
+ # this is to ensure we get granular global_step/sec on Tensorboard.
+ if current_step % steps_between_evals:
+ callback_list.on_batch_end(current_step - 1, {'loss': train_loss})
+ else:
+ # Save a submodel with the step in the file name after each epoch.
+ if sub_model_export_name:
+ _save_checkpoint(
+ strategy, sub_model_checkpoint, model_dir,
+ '%s_step_%d.ckpt' % (sub_model_export_name, current_step))
+
+ # Save model checkpoints and run validation steps after each epoch
+ # (with the exception of the final epoch which is handled after the
+ # training loop).
+ if current_step < total_training_steps:
+ _save_checkpoint(strategy, checkpoint, model_dir,
+ checkpoint_name.format(step=current_step))
+ if eval_input_fn:
+ # Re-initialize evaluation metric.
+ eval_loss_metric.reset_states()
+ for metric in eval_metrics + model.metrics:
+ metric.reset_states()
+
+ logging.info('Running evaluation after step: %s.', current_step)
+ logs = _run_evaluation(current_step,
+ _get_input_iterator(eval_input_fn, strategy))
+ # We add train_loss here rather than call on_batch_end twice to make
+ # sure that no duplicated values are generated.
+ logs['loss'] = train_loss
+ callback_list.on_batch_end(current_step - 1, logs)
+
+ # Calls on_epoch_end after each real epoch ends to prevent mis-calculation
+ # of training steps.
+ if current_step % steps_per_epoch == 0:
+ callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)
+
+ if sub_model_export_name:
+ _save_checkpoint(strategy, sub_model_checkpoint, model_dir,
+ '%s.ckpt' % sub_model_export_name)
+
+ _save_checkpoint(strategy, checkpoint, model_dir,
+ checkpoint_name.format(step=current_step))
+ if eval_input_fn:
+ # Re-initialize evaluation metric.
+ eval_loss_metric.reset_states()
+ for metric in eval_metrics + model.metrics:
+ metric.reset_states()
+
+ logging.info('Running final evaluation after training is complete.')
+ logs = _run_evaluation(current_step,
+ _get_input_iterator(eval_input_fn, strategy))
+ callback_list.on_epoch_end(int(current_step / steps_per_epoch), logs)
+ training_summary = {
+ 'total_training_steps': total_training_steps,
+ 'train_loss': _float_metric_value(train_loss_metric),
+ }
+ for metric in model.metrics:
+ training_summary[metric.name] = _float_metric_value(metric)
+ if eval_metrics:
+ training_summary['last_train_metrics'] = _float_metric_value(
+ train_metrics[0])
+ training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])
+
+ write_txt_summary(training_summary, summary_dir)
+
+ if not _should_export_summary(strategy):
+ tf.io.gfile.rmtree(summary_dir)
+
+ callback_list.on_train_end()
+
+ return model
diff --git a/modeling/official/legacy/bert/model_training_utils_test.py b/modeling/official/legacy/bert/model_training_utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d44ea785caa745cbcfe9a07d93ccec00d00c8c1
--- /dev/null
+++ b/modeling/official/legacy/bert/model_training_utils_test.py
@@ -0,0 +1,306 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for official.modeling.training.model_training_utils."""
+
+import os
+
+from absl import logging
+from absl.testing import flagsaver
+from absl.testing import parameterized
+from absl.testing.absltest import mock
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.legacy.bert import common_flags
+from official.legacy.bert import model_training_utils
+
+
+common_flags.define_common_bert_flags()
+
+
+def eager_strategy_combinations():
+ return combinations.combine(
+ distribution=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.cloud_tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.mirrored_strategy_with_two_gpus,
+ ],)
+
+
+def eager_gpu_strategy_combinations():
+ return combinations.combine(
+ distribution=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+ strategy_combinations.mirrored_strategy_with_two_gpus,
+ ],)
+
+
+def create_fake_data_input_fn(batch_size, features_shape, num_classes):
+ """Creates a dummy input function with the given feature and label shapes.
+
+ Args:
+ batch_size: integer.
+ features_shape: list[int]. Feature shape for an individual example.
+ num_classes: integer. Number of labels.
+
+ Returns:
+ An input function that is usable in the executor.
+ """
+
+ def _dataset_fn(input_context=None):
+ """An input function for generating fake data."""
+ local_batch_size = input_context.get_per_replica_batch_size(batch_size)
+ features = np.random.rand(64, *features_shape)
+ labels = np.random.randint(2, size=[64, num_classes])
+ # Convert the inputs to a Dataset.
+ dataset = tf.data.Dataset.from_tensor_slices((features, labels))
+ dataset = dataset.shard(input_context.num_input_pipelines,
+ input_context.input_pipeline_id)
+
+ def _assign_dtype(features, labels):
+ features = tf.cast(features, tf.float32)
+ labels = tf.cast(labels, tf.float32)
+ return features, labels
+
+ # Shuffle, repeat, and batch the examples.
+ dataset = dataset.map(_assign_dtype)
+ dataset = dataset.shuffle(64).repeat()
+ dataset = dataset.batch(local_batch_size, drop_remainder=True)
+ dataset = dataset.prefetch(buffer_size=64)
+ return dataset
+
+ return _dataset_fn
+
+
+def create_model_fn(input_shape, num_classes, use_float16=False):
+
+ def _model_fn():
+ """A one-layer softmax model suitable for testing."""
+ input_layer = tf_keras.layers.Input(shape=input_shape)
+ x = tf_keras.layers.Dense(num_classes, activation='relu')(input_layer)
+ output_layer = tf_keras.layers.Dense(num_classes, activation='softmax')(x)
+ sub_model = tf_keras.models.Model(input_layer, x, name='sub_model')
+ model = tf_keras.models.Model(input_layer, output_layer, name='model')
+ model.add_metric(
+ tf.reduce_mean(input_layer), name='mean_input', aggregation='mean')
+ model.optimizer = tf_keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
+ if use_float16:
+ model.optimizer = tf_keras.mixed_precision.LossScaleOptimizer(
+ model.optimizer)
+ return model, sub_model
+
+ return _model_fn
+
+
+def metric_fn():
+ """Gets a tf.keras metric object."""
+ return tf_keras.metrics.CategoricalAccuracy(name='accuracy', dtype=tf.float32)
+
+
+def summaries_with_matching_keyword(keyword, summary_dir):
+ """Yields summary protos matching given keyword from event file."""
+ event_paths = tf.io.gfile.glob(os.path.join(summary_dir, 'events*'))
+ for event in tf.compat.v1.train.summary_iterator(event_paths[-1]):
+ if event.summary is not None:
+ for value in event.summary.value:
+ if keyword in value.tag:
+ logging.error(event)
+ yield event.summary
+
+
+def check_eventfile_for_keyword(keyword, summary_dir):
+ """Checks event files for the keyword."""
+ return any(summaries_with_matching_keyword(keyword, summary_dir))
+
+
+class RecordingCallback(tf_keras.callbacks.Callback):
+
+ def __init__(self):
+ self.batch_begin = [] # (batch, logs)
+ self.batch_end = [] # (batch, logs)
+ self.epoch_begin = [] # (epoch, logs)
+ self.epoch_end = [] # (epoch, logs)
+
+ def on_batch_begin(self, batch, logs=None):
+ self.batch_begin.append((batch, logs))
+
+ def on_batch_end(self, batch, logs=None):
+ self.batch_end.append((batch, logs))
+
+ def on_epoch_begin(self, epoch, logs=None):
+ self.epoch_begin.append((epoch, logs))
+
+ def on_epoch_end(self, epoch, logs=None):
+ self.epoch_end.append((epoch, logs))
+
+
+class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ super(ModelTrainingUtilsTest, self).setUp()
+ self._model_fn = create_model_fn(input_shape=[128], num_classes=3)
+
+ @flagsaver.flagsaver
+ def run_training(self, strategy, model_dir, steps_per_loop, run_eagerly):
+ input_fn = create_fake_data_input_fn(
+ batch_size=8, features_shape=[128], num_classes=3)
+ model_training_utils.run_customized_training_loop(
+ strategy=strategy,
+ model_fn=self._model_fn,
+ loss_fn=tf_keras.losses.categorical_crossentropy,
+ model_dir=model_dir,
+ steps_per_epoch=20,
+ steps_per_loop=steps_per_loop,
+ epochs=2,
+ train_input_fn=input_fn,
+ eval_input_fn=input_fn,
+ eval_steps=10,
+ init_checkpoint=None,
+ sub_model_export_name='my_submodel_name',
+ metric_fn=metric_fn,
+ custom_callbacks=None,
+ run_eagerly=run_eagerly)
+
+ @combinations.generate(eager_strategy_combinations())
+ def test_train_eager_single_step(self, distribution):
+ model_dir = self.create_tempdir().full_path
+ if isinstance(
+ distribution,
+ (tf.distribute.TPUStrategy, tf.distribute.experimental.TPUStrategy)):
+ with self.assertRaises(ValueError):
+ self.run_training(
+ distribution, model_dir, steps_per_loop=1, run_eagerly=True)
+ else:
+ self.run_training(
+ distribution, model_dir, steps_per_loop=1, run_eagerly=True)
+
+ @combinations.generate(eager_gpu_strategy_combinations())
+ def test_train_eager_mixed_precision(self, distribution):
+ model_dir = self.create_tempdir().full_path
+ tf_keras.mixed_precision.set_global_policy('mixed_float16')
+ self._model_fn = create_model_fn(
+ input_shape=[128], num_classes=3, use_float16=True)
+ self.run_training(
+ distribution, model_dir, steps_per_loop=1, run_eagerly=True)
+
+ @combinations.generate(eager_strategy_combinations())
+ def test_train_check_artifacts(self, distribution):
+ model_dir = self.create_tempdir().full_path
+ self.run_training(
+ distribution, model_dir, steps_per_loop=10, run_eagerly=False)
+
+ # Two checkpoints should be saved after two epochs.
+ files = map(os.path.basename,
+ tf.io.gfile.glob(os.path.join(model_dir, 'ctl_step_*index')))
+ self.assertCountEqual(
+ ['ctl_step_20.ckpt-1.index', 'ctl_step_40.ckpt-2.index'], files)
+
+ # Three submodel checkpoints should be saved after two epochs (one after
+ # each epoch plus one final).
+ files = map(
+ os.path.basename,
+ tf.io.gfile.glob(os.path.join(model_dir, 'my_submodel_name*index')))
+ self.assertCountEqual([
+ 'my_submodel_name.ckpt-3.index',
+ 'my_submodel_name_step_20.ckpt-1.index',
+ 'my_submodel_name_step_40.ckpt-2.index'
+ ], files)
+
+ self.assertNotEmpty(
+ tf.io.gfile.glob(
+ os.path.join(model_dir, 'summaries/training_summary*')))
+
+ # Loss and accuracy values should be written into summaries.
+ self.assertTrue(
+ check_eventfile_for_keyword('loss',
+ os.path.join(model_dir, 'summaries/train')))
+ self.assertTrue(
+ check_eventfile_for_keyword('accuracy',
+ os.path.join(model_dir, 'summaries/train')))
+ self.assertTrue(
+ check_eventfile_for_keyword('mean_input',
+ os.path.join(model_dir, 'summaries/train')))
+ self.assertTrue(
+ check_eventfile_for_keyword('accuracy',
+ os.path.join(model_dir, 'summaries/eval')))
+ self.assertTrue(
+ check_eventfile_for_keyword('mean_input',
+ os.path.join(model_dir, 'summaries/eval')))
+
+ @combinations.generate(eager_strategy_combinations())
+ def test_train_check_callbacks(self, distribution):
+ model_dir = self.create_tempdir().full_path
+ callback = RecordingCallback()
+ callbacks = [callback]
+ input_fn = create_fake_data_input_fn(
+ batch_size=8, features_shape=[128], num_classes=3)
+ model_training_utils.run_customized_training_loop(
+ strategy=distribution,
+ model_fn=self._model_fn,
+ loss_fn=tf_keras.losses.categorical_crossentropy,
+ model_dir=model_dir,
+ steps_per_epoch=20,
+ num_eval_per_epoch=4,
+ steps_per_loop=10,
+ epochs=2,
+ train_input_fn=input_fn,
+ eval_input_fn=input_fn,
+ eval_steps=10,
+ init_checkpoint=None,
+ metric_fn=metric_fn,
+ custom_callbacks=callbacks,
+ run_eagerly=False)
+ self.assertEqual(callback.epoch_begin, [(1, {}), (2, {})])
+ epoch_ends, epoch_end_infos = zip(*callback.epoch_end)
+ self.assertEqual(list(epoch_ends), [1, 2, 2])
+ for info in epoch_end_infos:
+ self.assertIn('accuracy', info)
+
+ self.assertEqual(callback.batch_begin, [(0, {}), (5, {}), (10, {}),
+ (15, {}), (20, {}), (25, {}),
+ (30, {}), (35, {})])
+ batch_ends, batch_end_infos = zip(*callback.batch_end)
+ self.assertEqual(list(batch_ends), [4, 9, 14, 19, 24, 29, 34, 39])
+ for info in batch_end_infos:
+ self.assertIn('loss', info)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.one_device_strategy_gpu,
+ ],))
+ def test_train_check_artifacts_non_chief(self, distribution):
+ # We shouldn't export artifacts on non-chief workers. Since there's no easy
+ # way to test with real MultiWorkerMirroredStrategy, we patch the strategy
+ # to make it as if it's MultiWorkerMirroredStrategy on non-chief workers.
+ extended = distribution.extended
+ with mock.patch.object(extended.__class__, 'should_checkpoint',
+ new_callable=mock.PropertyMock, return_value=False), \
+ mock.patch.object(extended.__class__, 'should_save_summary',
+ new_callable=mock.PropertyMock, return_value=False):
+ model_dir = self.create_tempdir().full_path
+ self.run_training(
+ distribution, model_dir, steps_per_loop=10, run_eagerly=False)
+ self.assertEmpty(tf.io.gfile.listdir(model_dir))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/legacy/bert/run_classifier.py b/modeling/official/legacy/bert/run_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..125675b27465b197aac8f9034d5e175091962b60
--- /dev/null
+++ b/modeling/official/legacy/bert/run_classifier.py
@@ -0,0 +1,515 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BERT classification or regression finetuning runner in TF 2.x."""
+
+import functools
+import json
+import math
+import os
+
+# Import libraries
+from absl import app
+from absl import flags
+from absl import logging
+import gin
+import tensorflow as tf, tf_keras
+from official.common import distribute_utils
+from official.legacy.bert import bert_models
+from official.legacy.bert import common_flags
+from official.legacy.bert import configs as bert_configs
+from official.legacy.bert import input_pipeline
+from official.legacy.bert import model_saving_utils
+from official.modeling import performance
+from official.nlp import optimization
+from official.utils.misc import keras_utils
+
+flags.DEFINE_enum(
+ 'mode', 'train_and_eval', ['train_and_eval', 'export_only', 'predict'],
+ 'One of {"train_and_eval", "export_only", "predict"}. `train_and_eval`: '
+ 'trains the model and evaluates in the meantime. '
+ '`export_only`: will take the latest checkpoint inside '
+ 'model_dir and export a `SavedModel`. `predict`: takes a checkpoint and '
+ 'restores the model to output predictions on the test set.')
+flags.DEFINE_string('train_data_path', None,
+ 'Path to training data for BERT classifier.')
+flags.DEFINE_string('eval_data_path', None,
+ 'Path to evaluation data for BERT classifier.')
+flags.DEFINE_string(
+ 'input_meta_data_path', None,
+ 'Path to file that contains meta data about input '
+ 'to be used for training and evaluation.')
+flags.DEFINE_integer('train_data_size', None, 'Number of training samples '
+ 'to use. If None, uses the full train data. '
+ '(default: None).')
+flags.DEFINE_string('predict_checkpoint_path', None,
+ 'Path to the checkpoint for predictions.')
+flags.DEFINE_integer(
+ 'num_eval_per_epoch', 1,
+ 'Number of evaluations per epoch. The purpose of this flag is to provide '
+ 'more granular evaluation scores and checkpoints. For example, if original '
+ 'data has N samples and num_eval_per_epoch is n, then each epoch will be '
+ 'evaluated every N/n samples.')
+flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
+flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
+
+common_flags.define_common_bert_flags()
+
+FLAGS = flags.FLAGS
+
+LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
+
+
+def get_loss_fn(num_classes):
+ """Gets the classification loss function."""
+
+ def classification_loss_fn(labels, logits):
+ """Classification loss."""
+ labels = tf.reshape(labels, [-1])
+ log_probs = tf.nn.log_softmax(logits, axis=-1)
+ one_hot_labels = tf.one_hot(
+ tf.cast(labels, dtype=tf.int32), depth=num_classes, dtype=tf.float32)
+ per_example_loss = -tf.reduce_sum(
+ tf.cast(one_hot_labels, dtype=tf.float32) * log_probs, axis=-1)
+ return tf.reduce_mean(per_example_loss)
+
+ return classification_loss_fn
+
+
+def get_dataset_fn(input_file_pattern,
+ max_seq_length,
+ global_batch_size,
+ is_training,
+ label_type=tf.int64,
+ include_sample_weights=False,
+ num_samples=None):
+ """Gets a closure to create a dataset."""
+
+ def _dataset_fn(ctx=None):
+ """Returns tf.data.Dataset for distributed BERT pretraining."""
+ batch_size = ctx.get_per_replica_batch_size(
+ global_batch_size) if ctx else global_batch_size
+ dataset = input_pipeline.create_classifier_dataset(
+ tf.io.gfile.glob(input_file_pattern),
+ max_seq_length,
+ batch_size,
+ is_training=is_training,
+ input_pipeline_context=ctx,
+ label_type=label_type,
+ include_sample_weights=include_sample_weights,
+ num_samples=num_samples)
+ return dataset
+
+ return _dataset_fn
+
+
+def run_bert_classifier(strategy,
+ bert_config,
+ input_meta_data,
+ model_dir,
+ epochs,
+ steps_per_epoch,
+ steps_per_loop,
+ eval_steps,
+ warmup_steps,
+ initial_lr,
+ init_checkpoint,
+ train_input_fn,
+ eval_input_fn,
+ training_callbacks=True,
+ custom_callbacks=None,
+ custom_metrics=None):
+ """Run BERT classifier training using low-level API."""
+ max_seq_length = input_meta_data['max_seq_length']
+ num_classes = input_meta_data.get('num_labels', 1)
+ is_regression = num_classes == 1
+
+ def _get_classifier_model():
+ """Gets a classifier model."""
+ classifier_model, core_model = (
+ bert_models.classifier_model(
+ bert_config,
+ num_classes,
+ max_seq_length,
+ hub_module_url=FLAGS.hub_module_url,
+ hub_module_trainable=FLAGS.hub_module_trainable))
+ optimizer = optimization.create_optimizer(initial_lr,
+ steps_per_epoch * epochs,
+ warmup_steps, FLAGS.end_lr,
+ FLAGS.optimizer_type)
+ classifier_model.optimizer = performance.configure_optimizer(
+ optimizer,
+ use_float16=common_flags.use_float16())
+ return classifier_model, core_model
+
+ # tf_keras.losses objects accept optional sample_weight arguments (eg. coming
+ # from the dataset) to compute weighted loss, as used for the regression
+ # tasks. The classification tasks, using the custom get_loss_fn don't accept
+ # sample weights though.
+ loss_fn = (tf_keras.losses.MeanSquaredError() if is_regression
+ else get_loss_fn(num_classes))
+
+ # Defines evaluation metrics function, which will create metrics in the
+ # correct device and strategy scope.
+ if custom_metrics:
+ metric_fn = custom_metrics
+ elif is_regression:
+ metric_fn = functools.partial(
+ tf_keras.metrics.MeanSquaredError,
+ 'mean_squared_error',
+ dtype=tf.float32)
+ else:
+ metric_fn = functools.partial(
+ tf_keras.metrics.SparseCategoricalAccuracy,
+ 'accuracy',
+ dtype=tf.float32)
+
+ # Start training using Keras compile/fit API.
+ logging.info('Training using TF 2.x Keras compile/fit API with '
+ 'distribution strategy.')
+ return run_keras_compile_fit(
+ model_dir,
+ strategy,
+ _get_classifier_model,
+ train_input_fn,
+ eval_input_fn,
+ loss_fn,
+ metric_fn,
+ init_checkpoint,
+ epochs,
+ steps_per_epoch,
+ steps_per_loop,
+ eval_steps,
+ training_callbacks=training_callbacks,
+ custom_callbacks=custom_callbacks)
+
+
+def run_keras_compile_fit(model_dir,
+ strategy,
+ model_fn,
+ train_input_fn,
+ eval_input_fn,
+ loss_fn,
+ metric_fn,
+ init_checkpoint,
+ epochs,
+ steps_per_epoch,
+ steps_per_loop,
+ eval_steps,
+ training_callbacks=True,
+ custom_callbacks=None):
+ """Runs BERT classifier model using Keras compile/fit API."""
+
+ with strategy.scope():
+ training_dataset = train_input_fn()
+ evaluation_dataset = eval_input_fn() if eval_input_fn else None
+ bert_model, sub_model = model_fn()
+ optimizer = bert_model.optimizer
+
+ if init_checkpoint:
+ checkpoint = tf.train.Checkpoint(model=sub_model, encoder=sub_model)
+ checkpoint.read(init_checkpoint).assert_existing_objects_matched()
+
+ if not isinstance(metric_fn, (list, tuple)):
+ metric_fn = [metric_fn]
+ bert_model.compile(
+ optimizer=optimizer,
+ loss=loss_fn,
+ metrics=[fn() for fn in metric_fn],
+ steps_per_execution=steps_per_loop)
+
+ summary_dir = os.path.join(model_dir, 'summaries')
+ summary_callback = tf_keras.callbacks.TensorBoard(summary_dir)
+ checkpoint = tf.train.Checkpoint(model=bert_model, optimizer=optimizer)
+ checkpoint_manager = tf.train.CheckpointManager(
+ checkpoint,
+ directory=model_dir,
+ max_to_keep=None,
+ step_counter=optimizer.iterations,
+ checkpoint_interval=0)
+ checkpoint_callback = keras_utils.SimpleCheckpoint(checkpoint_manager)
+
+ if training_callbacks:
+ if custom_callbacks is not None:
+ custom_callbacks += [summary_callback, checkpoint_callback]
+ else:
+ custom_callbacks = [summary_callback, checkpoint_callback]
+
+ history = bert_model.fit(
+ x=training_dataset,
+ validation_data=evaluation_dataset,
+ steps_per_epoch=steps_per_epoch,
+ epochs=epochs,
+ validation_steps=eval_steps,
+ callbacks=custom_callbacks)
+ stats = {'total_training_steps': steps_per_epoch * epochs}
+ if 'loss' in history.history:
+ stats['train_loss'] = history.history['loss'][-1]
+ if 'val_accuracy' in history.history:
+ stats['eval_metrics'] = history.history['val_accuracy'][-1]
+ return bert_model, stats
+
+
+def get_predictions_and_labels(strategy,
+ trained_model,
+ eval_input_fn,
+ is_regression=False,
+ return_probs=False):
+ """Obtains predictions of trained model on evaluation data.
+
+ Note that list of labels is returned along with the predictions because the
+ order changes on distributing dataset over TPU pods.
+
+ Args:
+ strategy: Distribution strategy.
+ trained_model: Trained model with preloaded weights.
+ eval_input_fn: Input function for evaluation data.
+ is_regression: Whether it is a regression task.
+ return_probs: Whether to return probabilities of classes.
+
+ Returns:
+ predictions: List of predictions.
+ labels: List of gold labels corresponding to predictions.
+ """
+
+ @tf.function
+ def test_step(iterator):
+ """Computes predictions on distributed devices."""
+
+ def _test_step_fn(inputs):
+ """Replicated predictions."""
+ inputs, labels = inputs
+ logits = trained_model(inputs, training=False)
+ if not is_regression:
+ probabilities = tf.nn.softmax(logits)
+ return probabilities, labels
+ else:
+ return logits, labels
+
+ outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),))
+ # outputs: current batch logits as a tuple of shard logits
+ outputs = tf.nest.map_structure(strategy.experimental_local_results,
+ outputs)
+ labels = tf.nest.map_structure(strategy.experimental_local_results, labels)
+ return outputs, labels
+
+ def _run_evaluation(test_iterator):
+ """Runs evaluation steps."""
+ preds, golds = list(), list()
+ try:
+ with tf.experimental.async_scope():
+ while True:
+ probabilities, labels = test_step(test_iterator)
+ for cur_probs, cur_labels in zip(probabilities, labels):
+ if return_probs:
+ preds.extend(cur_probs.numpy().tolist())
+ else:
+ preds.extend(tf.math.argmax(cur_probs, axis=1).numpy())
+ golds.extend(cur_labels.numpy().tolist())
+ except (StopIteration, tf.errors.OutOfRangeError):
+ tf.experimental.async_clear_error()
+ return preds, golds
+
+ test_iter = iter(strategy.distribute_datasets_from_function(eval_input_fn))
+ predictions, labels = _run_evaluation(test_iter)
+
+ return predictions, labels
+
+
+def export_classifier(model_export_path, input_meta_data, bert_config,
+ model_dir):
+ """Exports a trained model as a `SavedModel` for inference.
+
+ Args:
+ model_export_path: a string specifying the path to the SavedModel directory.
+ input_meta_data: dictionary containing meta data about input and model.
+ bert_config: Bert configuration file to define core bert layers.
+ model_dir: The directory where the model weights and training/evaluation
+ summaries are stored.
+
+ Raises:
+ Export path is not specified, got an empty string or None.
+ """
+ if not model_export_path:
+ raise ValueError('Export path is not specified: %s' % model_export_path)
+ if not model_dir:
+ raise ValueError('Export path is not specified: %s' % model_dir)
+
+ # Export uses float32 for now, even if training uses mixed precision.
+ tf_keras.mixed_precision.set_global_policy('float32')
+ classifier_model = bert_models.classifier_model(
+ bert_config,
+ input_meta_data.get('num_labels', 1),
+ hub_module_url=FLAGS.hub_module_url,
+ hub_module_trainable=False)[0]
+
+ model_saving_utils.export_bert_model(
+ model_export_path, model=classifier_model, checkpoint_dir=model_dir)
+
+
+def run_bert(strategy,
+ input_meta_data,
+ model_config,
+ train_input_fn=None,
+ eval_input_fn=None,
+ init_checkpoint=None,
+ custom_callbacks=None,
+ custom_metrics=None):
+ """Run BERT training."""
+ # Enables XLA in Session Config. Should not be set for TPU.
+ keras_utils.set_session_config(FLAGS.enable_xla)
+ performance.set_mixed_precision_policy(common_flags.dtype())
+
+ epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
+ train_data_size = (
+ input_meta_data['train_data_size'] // FLAGS.num_eval_per_epoch)
+ if FLAGS.train_data_size:
+ train_data_size = min(train_data_size, FLAGS.train_data_size)
+ logging.info('Updated train_data_size: %s', train_data_size)
+ steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
+ warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
+ eval_steps = int(
+ math.ceil(input_meta_data['eval_data_size'] / FLAGS.eval_batch_size))
+
+ if not strategy:
+ raise ValueError('Distribution strategy has not been specified.')
+
+ if not custom_callbacks:
+ custom_callbacks = []
+
+ if FLAGS.log_steps:
+ custom_callbacks.append(
+ keras_utils.TimeHistory(
+ batch_size=FLAGS.train_batch_size,
+ log_steps=FLAGS.log_steps,
+ logdir=FLAGS.model_dir))
+
+ trained_model, _ = run_bert_classifier(
+ strategy,
+ model_config,
+ input_meta_data,
+ FLAGS.model_dir,
+ epochs,
+ steps_per_epoch,
+ FLAGS.steps_per_loop,
+ eval_steps,
+ warmup_steps,
+ FLAGS.learning_rate,
+ init_checkpoint or FLAGS.init_checkpoint,
+ train_input_fn,
+ eval_input_fn,
+ custom_callbacks=custom_callbacks,
+ custom_metrics=custom_metrics)
+
+ if FLAGS.model_export_path:
+ model_saving_utils.export_bert_model(
+ FLAGS.model_export_path, model=trained_model)
+ return trained_model
+
+
+def custom_main(custom_callbacks=None, custom_metrics=None):
+ """Run classification or regression.
+
+ Args:
+ custom_callbacks: list of tf_keras.Callbacks passed to training loop.
+ custom_metrics: list of metrics passed to the training loop.
+ """
+ gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
+
+ with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
+ input_meta_data = json.loads(reader.read().decode('utf-8'))
+ label_type = LABEL_TYPES_MAP[input_meta_data.get('label_type', 'int')]
+ include_sample_weights = input_meta_data.get('has_sample_weights', False)
+
+ if not FLAGS.model_dir:
+ FLAGS.model_dir = '/tmp/bert20/'
+
+ bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+
+ if FLAGS.mode == 'export_only':
+ export_classifier(FLAGS.model_export_path, input_meta_data, bert_config,
+ FLAGS.model_dir)
+ return
+
+ strategy = distribute_utils.get_distribution_strategy(
+ distribution_strategy=FLAGS.distribution_strategy,
+ num_gpus=FLAGS.num_gpus,
+ tpu_address=FLAGS.tpu)
+ eval_input_fn = get_dataset_fn(
+ FLAGS.eval_data_path,
+ input_meta_data['max_seq_length'],
+ FLAGS.eval_batch_size,
+ is_training=False,
+ label_type=label_type,
+ include_sample_weights=include_sample_weights)
+
+ if FLAGS.mode == 'predict':
+ num_labels = input_meta_data.get('num_labels', 1)
+ with strategy.scope():
+ classifier_model = bert_models.classifier_model(
+ bert_config, num_labels)[0]
+ checkpoint = tf.train.Checkpoint(model=classifier_model)
+ latest_checkpoint_file = (
+ FLAGS.predict_checkpoint_path or
+ tf.train.latest_checkpoint(FLAGS.model_dir))
+ assert latest_checkpoint_file
+ logging.info('Checkpoint file %s found and restoring from '
+ 'checkpoint', latest_checkpoint_file)
+ checkpoint.restore(
+ latest_checkpoint_file).assert_existing_objects_matched()
+ preds, _ = get_predictions_and_labels(
+ strategy,
+ classifier_model,
+ eval_input_fn,
+ is_regression=(num_labels == 1),
+ return_probs=True)
+ output_predict_file = os.path.join(FLAGS.model_dir, 'test_results.tsv')
+ with tf.io.gfile.GFile(output_predict_file, 'w') as writer:
+ logging.info('***** Predict results *****')
+ for probabilities in preds:
+ output_line = '\t'.join(
+ str(class_probability)
+ for class_probability in probabilities) + '\n'
+ writer.write(output_line)
+ return
+
+ if FLAGS.mode != 'train_and_eval':
+ raise ValueError('Unsupported mode is specified: %s' % FLAGS.mode)
+ train_input_fn = get_dataset_fn(
+ FLAGS.train_data_path,
+ input_meta_data['max_seq_length'],
+ FLAGS.train_batch_size,
+ is_training=True,
+ label_type=label_type,
+ include_sample_weights=include_sample_weights,
+ num_samples=FLAGS.train_data_size)
+ run_bert(
+ strategy,
+ input_meta_data,
+ bert_config,
+ train_input_fn,
+ eval_input_fn,
+ custom_callbacks=custom_callbacks,
+ custom_metrics=custom_metrics)
+
+
+def main(_):
+ custom_main(custom_callbacks=None, custom_metrics=None)
+
+
+if __name__ == '__main__':
+ flags.mark_flag_as_required('bert_config_file')
+ flags.mark_flag_as_required('input_meta_data_path')
+ flags.mark_flag_as_required('model_dir')
+ app.run(main)
diff --git a/modeling/official/legacy/bert/run_pretraining.py b/modeling/official/legacy/bert/run_pretraining.py
new file mode 100644
index 0000000000000000000000000000000000000000..d385a0ed76e8dec50942df23d015d0739cae6231
--- /dev/null
+++ b/modeling/official/legacy/bert/run_pretraining.py
@@ -0,0 +1,217 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Run masked LM/next sentence pre-training for BERT in TF 2.x."""
+
+# Import libraries
+from absl import app
+from absl import flags
+from absl import logging
+import gin
+import tensorflow as tf, tf_keras
+from official.common import distribute_utils
+from official.legacy.bert import bert_models
+from official.legacy.bert import common_flags
+from official.legacy.bert import configs
+from official.legacy.bert import input_pipeline
+from official.legacy.bert import model_training_utils
+from official.modeling import performance
+from official.nlp import optimization
+
+
+flags.DEFINE_string('input_files', None,
+ 'File path to retrieve training data for pre-training.')
+# Model training specific flags.
+flags.DEFINE_integer(
+ 'max_seq_length', 128,
+ 'The maximum total input sequence length after WordPiece tokenization. '
+ 'Sequences longer than this will be truncated, and sequences shorter '
+ 'than this will be padded.')
+flags.DEFINE_integer('max_predictions_per_seq', 20,
+ 'Maximum predictions per sequence_output.')
+flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
+flags.DEFINE_integer('num_steps_per_epoch', 1000,
+ 'Total number of training steps to run per epoch.')
+flags.DEFINE_float('warmup_steps', 10000,
+ 'Warmup steps for Adam weight decay optimizer.')
+flags.DEFINE_bool('use_next_sentence_label', True,
+ 'Whether to use next sentence label to compute final loss.')
+flags.DEFINE_bool('train_summary_interval', 0, 'Step interval for training '
+ 'summaries. If the value is a negative number, '
+ 'then training summaries are not enabled.')
+
+common_flags.define_common_bert_flags()
+
+FLAGS = flags.FLAGS
+
+
+def get_pretrain_dataset_fn(input_file_pattern, seq_length,
+ max_predictions_per_seq, global_batch_size,
+ use_next_sentence_label=True):
+ """Returns input dataset from input file string."""
+ def _dataset_fn(ctx=None):
+ """Returns tf.data.Dataset for distributed BERT pretraining."""
+ input_patterns = input_file_pattern.split(',')
+ batch_size = ctx.get_per_replica_batch_size(global_batch_size)
+ train_dataset = input_pipeline.create_pretrain_dataset(
+ input_patterns,
+ seq_length,
+ max_predictions_per_seq,
+ batch_size,
+ is_training=True,
+ input_pipeline_context=ctx,
+ use_next_sentence_label=use_next_sentence_label)
+ return train_dataset
+
+ return _dataset_fn
+
+
+def get_loss_fn():
+ """Returns loss function for BERT pretraining."""
+
+ def _bert_pretrain_loss_fn(unused_labels, losses, **unused_args):
+ return tf.reduce_mean(losses)
+
+ return _bert_pretrain_loss_fn
+
+
+def run_customized_training(strategy,
+ bert_config,
+ init_checkpoint,
+ max_seq_length,
+ max_predictions_per_seq,
+ model_dir,
+ steps_per_epoch,
+ steps_per_loop,
+ epochs,
+ initial_lr,
+ warmup_steps,
+ end_lr,
+ optimizer_type,
+ input_files,
+ train_batch_size,
+ use_next_sentence_label=True,
+ train_summary_interval=0,
+ custom_callbacks=None,
+ explicit_allreduce=False,
+ pre_allreduce_callbacks=None,
+ post_allreduce_callbacks=None,
+ allreduce_bytes_per_pack=0):
+ """Run BERT pretrain model training using low-level API."""
+
+ train_input_fn = get_pretrain_dataset_fn(input_files, max_seq_length,
+ max_predictions_per_seq,
+ train_batch_size,
+ use_next_sentence_label)
+
+ def _get_pretrain_model():
+ """Gets a pretraining model."""
+ pretrain_model, core_model = bert_models.pretrain_model(
+ bert_config, max_seq_length, max_predictions_per_seq,
+ use_next_sentence_label=use_next_sentence_label)
+ optimizer = optimization.create_optimizer(
+ initial_lr, steps_per_epoch * epochs, warmup_steps,
+ end_lr, optimizer_type)
+ pretrain_model.optimizer = performance.configure_optimizer(
+ optimizer,
+ use_float16=common_flags.use_float16())
+ return pretrain_model, core_model
+
+ trained_model = model_training_utils.run_customized_training_loop(
+ strategy=strategy,
+ model_fn=_get_pretrain_model,
+ loss_fn=get_loss_fn(),
+ scale_loss=FLAGS.scale_loss,
+ model_dir=model_dir,
+ init_checkpoint=init_checkpoint,
+ train_input_fn=train_input_fn,
+ steps_per_epoch=steps_per_epoch,
+ steps_per_loop=steps_per_loop,
+ epochs=epochs,
+ sub_model_export_name='pretrained/bert_model',
+ explicit_allreduce=explicit_allreduce,
+ pre_allreduce_callbacks=pre_allreduce_callbacks,
+ post_allreduce_callbacks=post_allreduce_callbacks,
+ allreduce_bytes_per_pack=allreduce_bytes_per_pack,
+ train_summary_interval=train_summary_interval,
+ custom_callbacks=custom_callbacks)
+
+ return trained_model
+
+
+def run_bert_pretrain(strategy, custom_callbacks=None):
+ """Runs BERT pre-training."""
+
+ bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+ if not strategy:
+ raise ValueError('Distribution strategy is not specified.')
+
+ # Runs customized training loop.
+ logging.info('Training using customized training loop TF 2.0 with distributed'
+ 'strategy.')
+
+ performance.set_mixed_precision_policy(common_flags.dtype())
+
+ # Only when explicit_allreduce = True, post_allreduce_callbacks and
+ # allreduce_bytes_per_pack will take effect. optimizer.apply_gradients() no
+ # longer implicitly allreduce gradients, users manually allreduce gradient and
+ # pass the allreduced grads_and_vars to apply_gradients().
+ # With explicit_allreduce = True, clip_by_global_norm is moved to after
+ # allreduce.
+ return run_customized_training(
+ strategy,
+ bert_config,
+ FLAGS.init_checkpoint, # Used to initialize only the BERT submodel.
+ FLAGS.max_seq_length,
+ FLAGS.max_predictions_per_seq,
+ FLAGS.model_dir,
+ FLAGS.num_steps_per_epoch,
+ FLAGS.steps_per_loop,
+ FLAGS.num_train_epochs,
+ FLAGS.learning_rate,
+ FLAGS.warmup_steps,
+ FLAGS.end_lr,
+ FLAGS.optimizer_type,
+ FLAGS.input_files,
+ FLAGS.train_batch_size,
+ FLAGS.use_next_sentence_label,
+ FLAGS.train_summary_interval,
+ custom_callbacks=custom_callbacks,
+ explicit_allreduce=FLAGS.explicit_allreduce,
+ pre_allreduce_callbacks=[
+ model_training_utils.clip_by_global_norm_callback
+ ],
+ allreduce_bytes_per_pack=FLAGS.allreduce_bytes_per_pack)
+
+
+def main(_):
+ gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
+ if not FLAGS.model_dir:
+ FLAGS.model_dir = '/tmp/bert20/'
+ # Configures cluster spec for multi-worker distribution strategy.
+ if FLAGS.num_gpus > 0:
+ _ = distribute_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)
+ strategy = distribute_utils.get_distribution_strategy(
+ distribution_strategy=FLAGS.distribution_strategy,
+ num_gpus=FLAGS.num_gpus,
+ all_reduce_alg=FLAGS.all_reduce_alg,
+ tpu_address=FLAGS.tpu)
+ if strategy:
+ print('***** Number of cores used : ', strategy.num_replicas_in_sync)
+
+ run_bert_pretrain(strategy)
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/modeling/official/legacy/bert/run_squad.py b/modeling/official/legacy/bert/run_squad.py
new file mode 100644
index 0000000000000000000000000000000000000000..3620b42831c4eb9addc2c8d2b202de450d527e26
--- /dev/null
+++ b/modeling/official/legacy/bert/run_squad.py
@@ -0,0 +1,148 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Run BERT on SQuAD 1.1 and SQuAD 2.0 in TF 2.x."""
+
+import json
+import os
+import time
+
+# Import libraries
+from absl import app
+from absl import flags
+from absl import logging
+import gin
+import tensorflow as tf, tf_keras
+from official.common import distribute_utils
+from official.legacy.bert import configs as bert_configs
+from official.legacy.bert import run_squad_helper
+from official.nlp.data import squad_lib as squad_lib_wp
+from official.nlp.tools import tokenization
+from official.utils.misc import keras_utils
+
+
+flags.DEFINE_string('vocab_file', None,
+ 'The vocabulary file that the BERT model was trained on.')
+
+# More flags can be found in run_squad_helper.
+run_squad_helper.define_common_squad_flags()
+
+FLAGS = flags.FLAGS
+
+
+def train_squad(strategy,
+ input_meta_data,
+ custom_callbacks=None,
+ run_eagerly=False,
+ init_checkpoint=None,
+ sub_model_export_name=None):
+ """Run bert squad training."""
+ bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+ init_checkpoint = init_checkpoint or FLAGS.init_checkpoint
+ run_squad_helper.train_squad(strategy, input_meta_data, bert_config,
+ custom_callbacks, run_eagerly, init_checkpoint,
+ sub_model_export_name=sub_model_export_name)
+
+
+def predict_squad(strategy, input_meta_data):
+ """Makes predictions for the squad dataset."""
+ bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+ tokenizer = tokenization.FullTokenizer(
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+ run_squad_helper.predict_squad(
+ strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
+
+
+def eval_squad(strategy, input_meta_data):
+ """Evaluate on the squad dataset."""
+ bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+ tokenizer = tokenization.FullTokenizer(
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+ eval_metrics = run_squad_helper.eval_squad(
+ strategy, input_meta_data, tokenizer, bert_config, squad_lib_wp)
+ return eval_metrics
+
+
+def export_squad(model_export_path, input_meta_data):
+ """Exports a trained model as a `SavedModel` for inference.
+
+ Args:
+ model_export_path: a string specifying the path to the SavedModel directory.
+ input_meta_data: dictionary containing meta data about input and model.
+
+ Raises:
+ Export path is not specified, got an empty string or None.
+ """
+ bert_config = bert_configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+ run_squad_helper.export_squad(model_export_path, input_meta_data, bert_config)
+
+
+def main(_):
+ gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
+
+ with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
+ input_meta_data = json.loads(reader.read().decode('utf-8'))
+
+ if FLAGS.mode == 'export_only':
+ export_squad(FLAGS.model_export_path, input_meta_data)
+ return
+
+ # Configures cluster spec for multi-worker distribution strategy.
+ if FLAGS.num_gpus > 0:
+ _ = distribute_utils.configure_cluster(FLAGS.worker_hosts, FLAGS.task_index)
+ strategy = distribute_utils.get_distribution_strategy(
+ distribution_strategy=FLAGS.distribution_strategy,
+ num_gpus=FLAGS.num_gpus,
+ all_reduce_alg=FLAGS.all_reduce_alg,
+ tpu_address=FLAGS.tpu)
+
+ if 'train' in FLAGS.mode:
+ if FLAGS.log_steps:
+ custom_callbacks = [keras_utils.TimeHistory(
+ batch_size=FLAGS.train_batch_size,
+ log_steps=FLAGS.log_steps,
+ logdir=FLAGS.model_dir,
+ )]
+ else:
+ custom_callbacks = None
+
+ train_squad(
+ strategy,
+ input_meta_data,
+ custom_callbacks=custom_callbacks,
+ run_eagerly=FLAGS.run_eagerly,
+ sub_model_export_name=FLAGS.sub_model_export_name,
+ )
+ if 'predict' in FLAGS.mode:
+ predict_squad(strategy, input_meta_data)
+ if 'eval' in FLAGS.mode:
+ eval_metrics = eval_squad(strategy, input_meta_data)
+ f1_score = eval_metrics['final_f1']
+ logging.info('SQuAD eval F1-score: %f', f1_score)
+ summary_dir = os.path.join(FLAGS.model_dir, 'summaries', 'eval')
+ summary_writer = tf.summary.create_file_writer(summary_dir)
+ with summary_writer.as_default():
+ # TODO(lehou): write to the correct step number.
+ tf.summary.scalar('F1-score', f1_score, step=0)
+ summary_writer.flush()
+ # Also write eval_metrics to json file.
+ squad_lib_wp.write_to_json_files(
+ eval_metrics, os.path.join(summary_dir, 'eval_metrics.json'))
+ time.sleep(60)
+
+
+if __name__ == '__main__':
+ flags.mark_flag_as_required('bert_config_file')
+ flags.mark_flag_as_required('model_dir')
+ app.run(main)
diff --git a/modeling/official/legacy/bert/run_squad_helper.py b/modeling/official/legacy/bert/run_squad_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba56596c95753e44a949b948d9910a357dd1698f
--- /dev/null
+++ b/modeling/official/legacy/bert/run_squad_helper.py
@@ -0,0 +1,471 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Library for running BERT family models on SQuAD 1.1/2.0 in TF 2.x."""
+
+import collections
+import json
+import os
+
+from absl import flags
+from absl import logging
+import tensorflow as tf, tf_keras
+from official.legacy.bert import bert_models
+from official.legacy.bert import common_flags
+from official.legacy.bert import input_pipeline
+from official.legacy.bert import model_saving_utils
+from official.legacy.bert import model_training_utils
+from official.modeling import performance
+from official.nlp import optimization
+from official.nlp.data import squad_lib_sp
+from official.nlp.tools import squad_evaluate_v1_1
+from official.nlp.tools import squad_evaluate_v2_0
+from official.utils.misc import keras_utils
+
+
+def define_common_squad_flags():
+ """Defines common flags used by SQuAD tasks."""
+ flags.DEFINE_enum(
+ 'mode', 'train_and_eval', [
+ 'train_and_eval', 'train_and_predict', 'train', 'eval', 'predict',
+ 'export_only'
+ ], 'One of {"train_and_eval", "train_and_predict", '
+ '"train", "eval", "predict", "export_only"}. '
+ '`train_and_eval`: train & predict to json files & compute eval metrics. '
+ '`train_and_predict`: train & predict to json files. '
+ '`train`: only trains the model. '
+ '`eval`: predict answers from squad json file & compute eval metrics. '
+ '`predict`: predict answers from the squad json file. '
+ '`export_only`: will take the latest checkpoint inside '
+ 'model_dir and export a `SavedModel`.')
+ flags.DEFINE_string('train_data_path', '',
+ 'Training data path with train tfrecords.')
+ flags.DEFINE_string(
+ 'input_meta_data_path', None,
+ 'Path to file that contains meta data about input '
+ 'to be used for training and evaluation.')
+ # Model training specific flags.
+ flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
+ # Predict processing related.
+ flags.DEFINE_string(
+ 'predict_file', None, 'SQuAD prediction json file path. '
+ '`predict` mode supports multiple files: one can use '
+ 'wildcard to specify multiple files and it can also be '
+ 'multiple file patterns separated by comma. Note that '
+ '`eval` mode only supports a single predict file.')
+ flags.DEFINE_bool(
+ 'do_lower_case', True,
+ 'Whether to lower case the input text. Should be True for uncased '
+ 'models and False for cased models.')
+ flags.DEFINE_float(
+ 'null_score_diff_threshold', 0.0,
+ 'If null_score - best_non_null is greater than the threshold, '
+ 'predict null. This is only used for SQuAD v2.')
+ flags.DEFINE_bool(
+ 'verbose_logging', False,
+ 'If true, all of the warnings related to data processing will be '
+ 'printed. A number of warnings are expected for a normal SQuAD '
+ 'evaluation.')
+ flags.DEFINE_integer('predict_batch_size', 8,
+ 'Total batch size for prediction.')
+ flags.DEFINE_integer(
+ 'n_best_size', 20,
+ 'The total number of n-best predictions to generate in the '
+ 'nbest_predictions.json output file.')
+ flags.DEFINE_integer(
+ 'max_answer_length', 30,
+ 'The maximum length of an answer that can be generated. This is needed '
+ 'because the start and end predictions are not conditioned on one '
+ 'another.')
+
+ common_flags.define_common_bert_flags()
+
+
+FLAGS = flags.FLAGS
+
+
+def squad_loss_fn(start_positions, end_positions, start_logits, end_logits):
+ """Returns sparse categorical crossentropy for start/end logits."""
+ start_loss = tf_keras.losses.sparse_categorical_crossentropy(
+ start_positions, start_logits, from_logits=True)
+ end_loss = tf_keras.losses.sparse_categorical_crossentropy(
+ end_positions, end_logits, from_logits=True)
+
+ total_loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
+ return total_loss
+
+
+def get_loss_fn():
+ """Gets a loss function for squad task."""
+
+ def _loss_fn(labels, model_outputs):
+ start_positions = labels['start_positions']
+ end_positions = labels['end_positions']
+ start_logits, end_logits = model_outputs
+ return squad_loss_fn(start_positions, end_positions, start_logits,
+ end_logits)
+
+ return _loss_fn
+
+
+RawResult = collections.namedtuple('RawResult',
+ ['unique_id', 'start_logits', 'end_logits'])
+
+
+def get_raw_results(predictions):
+ """Converts multi-replica predictions to RawResult."""
+ for unique_ids, start_logits, end_logits in zip(predictions['unique_ids'],
+ predictions['start_logits'],
+ predictions['end_logits']):
+ for values in zip(unique_ids.numpy(), start_logits.numpy(),
+ end_logits.numpy()):
+ yield RawResult(
+ unique_id=values[0],
+ start_logits=values[1].tolist(),
+ end_logits=values[2].tolist())
+
+
+def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
+ is_training):
+ """Gets a closure to create a dataset.."""
+
+ def _dataset_fn(ctx=None):
+ """Returns tf.data.Dataset for distributed BERT pretraining."""
+ batch_size = ctx.get_per_replica_batch_size(
+ global_batch_size) if ctx else global_batch_size
+ dataset = input_pipeline.create_squad_dataset(
+ input_file_pattern,
+ max_seq_length,
+ batch_size,
+ is_training=is_training,
+ input_pipeline_context=ctx)
+ return dataset
+
+ return _dataset_fn
+
+
+def get_squad_model_to_predict(strategy, bert_config, checkpoint_path,
+ input_meta_data):
+ """Gets a squad model to make predictions."""
+ with strategy.scope():
+ # Prediction always uses float32, even if training uses mixed precision.
+ tf_keras.mixed_precision.set_global_policy('float32')
+ squad_model, _ = bert_models.squad_model(
+ bert_config,
+ input_meta_data['max_seq_length'],
+ hub_module_url=FLAGS.hub_module_url)
+
+ if checkpoint_path is None:
+ checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)
+ logging.info('Restoring checkpoints from %s', checkpoint_path)
+ checkpoint = tf.train.Checkpoint(model=squad_model)
+ checkpoint.restore(checkpoint_path).expect_partial()
+ return squad_model
+
+
+def predict_squad_customized(strategy, input_meta_data, predict_tfrecord_path,
+ num_steps, squad_model):
+ """Make predictions using a Bert-based squad model."""
+ predict_dataset_fn = get_dataset_fn(
+ predict_tfrecord_path,
+ input_meta_data['max_seq_length'],
+ FLAGS.predict_batch_size,
+ is_training=False)
+ predict_iterator = iter(
+ strategy.distribute_datasets_from_function(predict_dataset_fn))
+
+ @tf.function
+ def predict_step(iterator):
+ """Predicts on distributed devices."""
+
+ def _replicated_step(inputs):
+ """Replicated prediction calculation."""
+ x, _ = inputs
+ unique_ids = x.pop('unique_ids')
+ start_logits, end_logits = squad_model(x, training=False)
+ return dict(
+ unique_ids=unique_ids,
+ start_logits=start_logits,
+ end_logits=end_logits)
+
+ outputs = strategy.run(_replicated_step, args=(next(iterator),))
+ return tf.nest.map_structure(strategy.experimental_local_results, outputs)
+
+ all_results = []
+ for _ in range(num_steps):
+ predictions = predict_step(predict_iterator)
+ for result in get_raw_results(predictions):
+ all_results.append(result)
+ if len(all_results) % 100 == 0:
+ logging.info('Made predictions for %d records.', len(all_results))
+ return all_results
+
+
+def train_squad(strategy,
+ input_meta_data,
+ bert_config,
+ custom_callbacks=None,
+ run_eagerly=False,
+ init_checkpoint=None,
+ sub_model_export_name=None):
+ """Run bert squad training."""
+ if strategy:
+ logging.info('Training using customized training loop with distribution'
+ ' strategy.')
+ # Enables XLA in Session Config. Should not be set for TPU.
+ keras_utils.set_session_config(FLAGS.enable_xla)
+ performance.set_mixed_precision_policy(common_flags.dtype())
+
+ epochs = FLAGS.num_train_epochs
+ num_train_examples = input_meta_data['train_data_size']
+ max_seq_length = input_meta_data['max_seq_length']
+ steps_per_epoch = int(num_train_examples / FLAGS.train_batch_size)
+ warmup_steps = int(epochs * num_train_examples * 0.1 / FLAGS.train_batch_size)
+ train_input_fn = get_dataset_fn(
+ FLAGS.train_data_path,
+ max_seq_length,
+ FLAGS.train_batch_size,
+ is_training=True)
+
+ def _get_squad_model():
+ """Get Squad model and optimizer."""
+ squad_model, core_model = bert_models.squad_model(
+ bert_config,
+ max_seq_length,
+ hub_module_url=FLAGS.hub_module_url,
+ hub_module_trainable=FLAGS.hub_module_trainable)
+ optimizer = optimization.create_optimizer(FLAGS.learning_rate,
+ steps_per_epoch * epochs,
+ warmup_steps, FLAGS.end_lr,
+ FLAGS.optimizer_type)
+
+ squad_model.optimizer = performance.configure_optimizer(
+ optimizer,
+ use_float16=common_flags.use_float16())
+ return squad_model, core_model
+
+ # Only when explicit_allreduce = True, post_allreduce_callbacks and
+ # allreduce_bytes_per_pack will take effect. optimizer.apply_gradients() no
+ # longer implicitly allreduce gradients, users manually allreduce gradient and
+ # pass the allreduced grads_and_vars to apply_gradients().
+ # With explicit_allreduce = True, clip_by_global_norm is moved to after
+ # allreduce.
+ model_training_utils.run_customized_training_loop(
+ strategy=strategy,
+ model_fn=_get_squad_model,
+ loss_fn=get_loss_fn(),
+ model_dir=FLAGS.model_dir,
+ steps_per_epoch=steps_per_epoch,
+ steps_per_loop=FLAGS.steps_per_loop,
+ epochs=epochs,
+ train_input_fn=train_input_fn,
+ init_checkpoint=init_checkpoint or FLAGS.init_checkpoint,
+ sub_model_export_name=sub_model_export_name,
+ run_eagerly=run_eagerly,
+ custom_callbacks=custom_callbacks,
+ explicit_allreduce=FLAGS.explicit_allreduce,
+ pre_allreduce_callbacks=[
+ model_training_utils.clip_by_global_norm_callback
+ ],
+ allreduce_bytes_per_pack=FLAGS.allreduce_bytes_per_pack)
+
+
+def prediction_output_squad(strategy, input_meta_data, tokenizer, squad_lib,
+ predict_file, squad_model):
+ """Makes predictions for a squad dataset."""
+ doc_stride = input_meta_data['doc_stride']
+ max_query_length = input_meta_data['max_query_length']
+ # Whether data should be in Ver 2.0 format.
+ version_2_with_negative = input_meta_data.get('version_2_with_negative',
+ False)
+ eval_examples = squad_lib.read_squad_examples(
+ input_file=predict_file,
+ is_training=False,
+ version_2_with_negative=version_2_with_negative)
+
+ eval_writer = squad_lib.FeatureWriter(
+ filename=os.path.join(FLAGS.model_dir, 'eval.tf_record'),
+ is_training=False)
+ eval_features = []
+
+ def _append_feature(feature, is_padding):
+ if not is_padding:
+ eval_features.append(feature)
+ eval_writer.process_feature(feature)
+
+ # TPU requires a fixed batch size for all batches, therefore the number
+ # of examples must be a multiple of the batch size, or else examples
+ # will get dropped. So we pad with fake examples which are ignored
+ # later on.
+ kwargs = dict(
+ examples=eval_examples,
+ tokenizer=tokenizer,
+ max_seq_length=input_meta_data['max_seq_length'],
+ doc_stride=doc_stride,
+ max_query_length=max_query_length,
+ is_training=False,
+ output_fn=_append_feature,
+ batch_size=FLAGS.predict_batch_size)
+
+ # squad_lib_sp requires one more argument 'do_lower_case'.
+ if squad_lib == squad_lib_sp:
+ kwargs['do_lower_case'] = FLAGS.do_lower_case
+ dataset_size = squad_lib.convert_examples_to_features(**kwargs)
+ eval_writer.close()
+
+ logging.info('***** Running predictions *****')
+ logging.info(' Num orig examples = %d', len(eval_examples))
+ logging.info(' Num split examples = %d', len(eval_features))
+ logging.info(' Batch size = %d', FLAGS.predict_batch_size)
+
+ num_steps = int(dataset_size / FLAGS.predict_batch_size)
+ all_results = predict_squad_customized(strategy, input_meta_data,
+ eval_writer.filename, num_steps,
+ squad_model)
+
+ all_predictions, all_nbest_json, scores_diff_json = (
+ squad_lib.postprocess_output(
+ eval_examples,
+ eval_features,
+ all_results,
+ FLAGS.n_best_size,
+ FLAGS.max_answer_length,
+ FLAGS.do_lower_case,
+ version_2_with_negative=version_2_with_negative,
+ null_score_diff_threshold=FLAGS.null_score_diff_threshold,
+ verbose=FLAGS.verbose_logging))
+
+ return all_predictions, all_nbest_json, scores_diff_json
+
+
+def dump_to_files(all_predictions,
+ all_nbest_json,
+ scores_diff_json,
+ squad_lib,
+ version_2_with_negative,
+ file_prefix=''):
+ """Save output to json files."""
+ output_prediction_file = os.path.join(FLAGS.model_dir,
+ '%spredictions.json' % file_prefix)
+ output_nbest_file = os.path.join(FLAGS.model_dir,
+ '%snbest_predictions.json' % file_prefix)
+ output_null_log_odds_file = os.path.join(FLAGS.model_dir, file_prefix,
+ '%snull_odds.json' % file_prefix)
+ logging.info('Writing predictions to: %s', (output_prediction_file))
+ logging.info('Writing nbest to: %s', (output_nbest_file))
+
+ squad_lib.write_to_json_files(all_predictions, output_prediction_file)
+ squad_lib.write_to_json_files(all_nbest_json, output_nbest_file)
+ if version_2_with_negative:
+ squad_lib.write_to_json_files(scores_diff_json, output_null_log_odds_file)
+
+
+def _get_matched_files(input_path):
+ """Returns all files that matches the input_path."""
+ input_patterns = input_path.strip().split(',')
+ all_matched_files = []
+ for input_pattern in input_patterns:
+ input_pattern = input_pattern.strip()
+ if not input_pattern:
+ continue
+ matched_files = tf.io.gfile.glob(input_pattern)
+ if not matched_files:
+ raise ValueError('%s does not match any files.' % input_pattern)
+ else:
+ all_matched_files.extend(matched_files)
+ return sorted(all_matched_files)
+
+
+def predict_squad(strategy,
+ input_meta_data,
+ tokenizer,
+ bert_config,
+ squad_lib,
+ init_checkpoint=None):
+ """Get prediction results and evaluate them to hard drive."""
+ if init_checkpoint is None:
+ init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
+
+ all_predict_files = _get_matched_files(FLAGS.predict_file)
+ squad_model = get_squad_model_to_predict(strategy, bert_config,
+ init_checkpoint, input_meta_data)
+ for idx, predict_file in enumerate(all_predict_files):
+ all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
+ strategy, input_meta_data, tokenizer, squad_lib, predict_file,
+ squad_model)
+ if len(all_predict_files) == 1:
+ file_prefix = ''
+ else:
+ # if predict_file is /path/xquad.ar.json, the `file_prefix` may be
+ # "xquad.ar-0-"
+ file_prefix = '%s-' % os.path.splitext(
+ os.path.basename(all_predict_files[idx]))[0]
+ dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
+ input_meta_data.get('version_2_with_negative', False),
+ file_prefix)
+
+
+def eval_squad(strategy,
+ input_meta_data,
+ tokenizer,
+ bert_config,
+ squad_lib,
+ init_checkpoint=None):
+ """Get prediction results and evaluate them against ground truth."""
+ if init_checkpoint is None:
+ init_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
+
+ all_predict_files = _get_matched_files(FLAGS.predict_file)
+ if len(all_predict_files) != 1:
+ raise ValueError('`eval_squad` only supports one predict file, '
+ 'but got %s' % all_predict_files)
+
+ squad_model = get_squad_model_to_predict(strategy, bert_config,
+ init_checkpoint, input_meta_data)
+ all_predictions, all_nbest_json, scores_diff_json = prediction_output_squad(
+ strategy, input_meta_data, tokenizer, squad_lib, all_predict_files[0],
+ squad_model)
+ dump_to_files(all_predictions, all_nbest_json, scores_diff_json, squad_lib,
+ input_meta_data.get('version_2_with_negative', False))
+
+ with tf.io.gfile.GFile(FLAGS.predict_file, 'r') as reader:
+ dataset_json = json.load(reader)
+ pred_dataset = dataset_json['data']
+ if input_meta_data.get('version_2_with_negative', False):
+ eval_metrics = squad_evaluate_v2_0.evaluate(pred_dataset, all_predictions,
+ scores_diff_json)
+ else:
+ eval_metrics = squad_evaluate_v1_1.evaluate(pred_dataset, all_predictions)
+ return eval_metrics
+
+
+def export_squad(model_export_path, input_meta_data, bert_config):
+ """Exports a trained model as a `SavedModel` for inference.
+
+ Args:
+ model_export_path: a string specifying the path to the SavedModel directory.
+ input_meta_data: dictionary containing meta data about input and model.
+ bert_config: Bert configuration file to define core bert layers.
+
+ Raises:
+ Export path is not specified, got an empty string or None.
+ """
+ if not model_export_path:
+ raise ValueError('Export path is not specified: %s' % model_export_path)
+ # Export uses float32 for now, even if training uses mixed precision.
+ tf_keras.mixed_precision.set_global_policy('float32')
+ squad_model, _ = bert_models.squad_model(bert_config,
+ input_meta_data['max_seq_length'])
+ model_saving_utils.export_bert_model(
+ model_export_path, model=squad_model, checkpoint_dir=FLAGS.model_dir)
diff --git a/modeling/official/legacy/bert/serving.py b/modeling/official/legacy/bert/serving.py
new file mode 100644
index 0000000000000000000000000000000000000000..60893e21383039544ae6abc674f4c03fda13fb39
--- /dev/null
+++ b/modeling/official/legacy/bert/serving.py
@@ -0,0 +1,133 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Examples of SavedModel export for tf-serving."""
+
+from absl import app
+from absl import flags
+import tensorflow as tf, tf_keras
+
+from official.legacy.bert import bert_models
+from official.legacy.bert import configs
+
+flags.DEFINE_integer(
+ "sequence_length", None, "Sequence length to parse the tf.Example. If "
+ "sequence_length > 0, add a signature for serialized "
+ "tf.Example and define the parsing specification by the "
+ "sequence_length.")
+flags.DEFINE_string("bert_config_file", None,
+ "Bert configuration file to define core bert layers.")
+flags.DEFINE_string("model_checkpoint_path", None,
+ "File path to TF model checkpoint.")
+flags.DEFINE_string("export_path", None,
+ "Destination folder to export the serving SavedModel.")
+
+FLAGS = flags.FLAGS
+
+
+class BertServing(tf_keras.Model):
+ """Bert transformer encoder model for serving."""
+
+ def __init__(self, bert_config, name_to_features=None, name="serving_model"):
+ super(BertServing, self).__init__(name=name)
+ self.encoder = bert_models.get_transformer_encoder(
+ bert_config, sequence_length=None)
+ self.name_to_features = name_to_features
+
+ def call(self, inputs):
+ input_word_ids = inputs["input_ids"]
+ input_mask = inputs["input_mask"]
+ input_type_ids = inputs["segment_ids"]
+
+ encoder_outputs, _ = self.encoder(
+ [input_word_ids, input_mask, input_type_ids])
+ return encoder_outputs
+
+ def serve_body(self, input_ids, input_mask=None, segment_ids=None):
+ if segment_ids is None:
+ # Requires CLS token is the first token of inputs.
+ segment_ids = tf.zeros_like(input_ids)
+ if input_mask is None:
+ # The mask has 1 for real tokens and 0 for padding tokens.
+ input_mask = tf.where(
+ tf.equal(input_ids, 0), tf.zeros_like(input_ids),
+ tf.ones_like(input_ids))
+
+ inputs = dict(
+ input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids)
+ return self.call(inputs)
+
+ @tf.function
+ def serve(self, input_ids, input_mask=None, segment_ids=None):
+ outputs = self.serve_body(input_ids, input_mask, segment_ids)
+ # Returns a dictionary to control SignatureDef output signature.
+ return {"outputs": outputs[-1]}
+
+ @tf.function
+ def serve_examples(self, inputs):
+ features = tf.io.parse_example(inputs, self.name_to_features)
+ for key in list(features.keys()):
+ t = features[key]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ features[key] = t
+ return self.serve(
+ features["input_ids"],
+ input_mask=features["input_mask"] if "input_mask" in features else None,
+ segment_ids=features["segment_ids"]
+ if "segment_ids" in features else None)
+
+ @classmethod
+ def export(cls, model, export_dir):
+ if not isinstance(model, cls):
+ raise ValueError("Invalid model instance: %s, it should be a %s" %
+ (model, cls))
+
+ signatures = {
+ "serving_default":
+ model.serve.get_concrete_function(
+ input_ids=tf.TensorSpec(
+ shape=[None, None], dtype=tf.int32, name="inputs")),
+ }
+ if model.name_to_features:
+ signatures[
+ "serving_examples"] = model.serve_examples.get_concrete_function(
+ tf.TensorSpec(shape=[None], dtype=tf.string, name="examples"))
+ tf.saved_model.save(model, export_dir=export_dir, signatures=signatures)
+
+
+def main(_):
+ sequence_length = FLAGS.sequence_length
+ if sequence_length is not None and sequence_length > 0:
+ name_to_features = {
+ "input_ids": tf.io.FixedLenFeature([sequence_length], tf.int64),
+ "input_mask": tf.io.FixedLenFeature([sequence_length], tf.int64),
+ "segment_ids": tf.io.FixedLenFeature([sequence_length], tf.int64),
+ }
+ else:
+ name_to_features = None
+ bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
+ serving_model = BertServing(
+ bert_config=bert_config, name_to_features=name_to_features)
+ checkpoint = tf.train.Checkpoint(model=serving_model.encoder)
+ checkpoint.restore(FLAGS.model_checkpoint_path
+ ).assert_existing_objects_matched().run_restore_ops()
+ BertServing.export(serving_model, FLAGS.export_path)
+
+
+if __name__ == "__main__":
+ flags.mark_flag_as_required("bert_config_file")
+ flags.mark_flag_as_required("model_checkpoint_path")
+ flags.mark_flag_as_required("export_path")
+ app.run(main)
diff --git a/modeling/official/legacy/detection/README.md b/modeling/official/legacy/detection/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6999579271df36c27144a0828c3202e326be27ec
--- /dev/null
+++ b/modeling/official/legacy/detection/README.md
@@ -0,0 +1,429 @@
+# Object Detection Models on TensorFlow 2
+
+**WARNING**: This repository will be deprecated and replaced by the solid
+implementations inside vision/beta/.
+
+## Prerequsite
+To get started, download the code from TensorFlow models GitHub repository or
+use the pre-installed Google Cloud VM.
+
+```bash
+git clone https://github.com/tensorflow/models.git
+```
+
+Next, make sure to use TensorFlow 2.1+ on Google Cloud. Also here are
+a few package you need to install to get started:
+
+```bash
+sudo apt-get install -y python-tk && \
+pip3 install -r ~/models/official/requirements.txt
+```
+
+## Train RetinaNet on TPU
+
+### Train a vanilla ResNet-50 based RetinaNet.
+
+```bash
+TPU_NAME=""
+MODEL_DIR=""
+RESNET_CHECKPOINT=""
+TRAIN_FILE_PATTERN=""
+EVAL_FILE_PATTERN=""
+VAL_JSON_FILE=""
+python3 ~/models/official/legacy/detection/main.py \
+ --strategy_type=tpu \
+ --tpu="${TPU_NAME?}" \
+ --model_dir="${MODEL_DIR?}" \
+ --mode=train \
+ --params_override="{ type: retinanet, train: { checkpoint: { path: ${RESNET_CHECKPOINT?}, prefix: resnet50/ }, train_file_pattern: ${TRAIN_FILE_PATTERN?} }, eval: { val_json_file: ${VAL_JSON_FILE?}, eval_file_pattern: ${EVAL_FILE_PATTERN?} } }"
+```
+
+The pre-trained ResNet-50 checkpoint can be downloaded [here](https://storage.cloud.google.com/cloud-tpu-checkpoints/model-garden-vision/detection/resnet50-2018-02-07.tar.gz).
+
+Note: The ResNet implementation under
+[detection/](https://github.com/tensorflow/models/tree/master/official/legacy/detection)
+is currently different from the one under
+[classification/](https://github.com/tensorflow/models/tree/master/official/vision/image_classification),
+so the checkpoints are not compatible.
+We will unify the implementation soon.
+
+
+### Train a SpineNet-49 based RetinaNet.
+
+```bash
+TPU_NAME=""
+MODEL_DIR=""
+TRAIN_FILE_PATTERN=""
+EVAL_FILE_PATTERN=""
+VAL_JSON_FILE=""
+python3 ~/models/official/legacy/detection/main.py \
+ --strategy_type=tpu \
+ --tpu="${TPU_NAME?}" \
+ --model_dir="${MODEL_DIR?}" \
+ --mode=train \
+ --params_override="{ type: retinanet, architecture: {backbone: spinenet, multilevel_features: identity}, spinenet: {model_id: 49}, train_file_pattern: ${TRAIN_FILE_PATTERN?} }, eval: { val_json_file: ${VAL_JSON_FILE?}, eval_file_pattern: ${EVAL_FILE_PATTERN?} } }"
+```
+
+
+### Train a custom RetinaNet using the config file.
+
+First, create a YAML config file, e.g. *my_retinanet.yaml*. This file specifies
+the parameters to be overridden, which should at least include the following
+fields.
+
+```YAML
+# my_retinanet.yaml
+type: 'retinanet'
+train:
+ train_file_pattern:
+eval:
+ eval_file_pattern:
+ val_json_file:
+```
+
+Once the YAML config file is created, you can launch the training using the
+following command.
+
+```bash
+TPU_NAME=""
+MODEL_DIR=""
+python3 ~/models/official/legacy/detection/main.py \
+ --strategy_type=tpu \
+ --tpu="${TPU_NAME?}" \
+ --model_dir="${MODEL_DIR?}" \
+ --mode=train \
+ --config_file="my_retinanet.yaml"
+```
+
+## Train RetinaNet on GPU
+
+Training on GPU is similar to that on TPU. The major change is the strategy
+type (use "[mirrored](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy)" for multiple GPU and
+"[one_device](https://www.tensorflow.org/api_docs/python/tf/distribute/OneDeviceStrategy)" for single GPU).
+
+Multi-GPUs example (assuming there are 8GPU connected to the host):
+
+```bash
+MODEL_DIR=""
+python3 ~/models/official/legacy/detection/main.py \
+ --strategy_type=mirrored \
+ --num_gpus=8 \
+ --model_dir="${MODEL_DIR?}" \
+ --mode=train \
+ --config_file="my_retinanet.yaml"
+```
+
+```bash
+MODEL_DIR=""
+python3 ~/models/official/legacy/detection/main.py \
+ --strategy_type=one_device \
+ --num_gpus=1 \
+ --model_dir="${MODEL_DIR?}" \
+ --mode=train \
+ --config_file="my_retinanet.yaml"
+```
+
+An example with inline configuration (YAML or JSON format):
+
+```
+python3 ~/models/official/legacy/detection/main.py \
+ --model_dir= \
+ --strategy_type=one_device \
+ --num_gpus=1 \
+ --mode=train \
+ --params_override="eval:
+ eval_file_pattern:
+ batch_size: 8
+ val_json_file:
+predict:
+ predict_batch_size: 8
+architecture:
+ use_bfloat16: False
+train:
+ total_steps: 1
+ batch_size: 8
+ train_file_pattern:
+use_tpu: False
+"
+```
+
+---
+
+## Train Mask R-CNN on TPU
+
+### Train a vanilla ResNet-50 based Mask R-CNN.
+
+```bash
+TPU_NAME=""
+MODEL_DIR=""
+RESNET_CHECKPOINT=""
+TRAIN_FILE_PATTERN=""
+EVAL_FILE_PATTERN=""
+VAL_JSON_FILE=""
+python3 ~/models/official/legacy/detection/main.py \
+ --strategy_type=tpu \
+ --tpu=${TPU_NAME} \
+ --model_dir=${MODEL_DIR} \
+ --mode=train \
+ --model=mask_rcnn \
+ --params_override="{train: { checkpoint: { path: ${RESNET_CHECKPOINT}, prefix: resnet50/ }, train_file_pattern: ${TRAIN_FILE_PATTERN} }, eval: { val_json_file: ${VAL_JSON_FILE}, eval_file_pattern: ${EVAL_FILE_PATTERN} } }"
+```
+
+The pre-trained ResNet-50 checkpoint can be downloaded [here](https://storage.cloud.google.com/cloud-tpu-checkpoints/model-garden-vision/detection/resnet50-2018-02-07.tar.gz).
+
+Note: The ResNet implementation under
+[detection/](https://github.com/tensorflow/models/tree/master/official/legacy/detection)
+is currently different from the one under
+[classification/](https://github.com/tensorflow/models/tree/master/official/vision/image_classification),
+so the checkpoints are not compatible.
+We will unify the implementation soon.
+
+
+### Train a SpineNet-49 based Mask R-CNN.
+
+```bash
+TPU_NAME=""
+MODEL_DIR=""
+TRAIN_FILE_PATTERN=""
+EVAL_FILE_PATTERN=""
+VAL_JSON_FILE=""
+python3 ~/models/official/legacy/detection/main.py \
+ --strategy_type=tpu \
+ --tpu="${TPU_NAME?}" \
+ --model_dir="${MODEL_DIR?}" \
+ --mode=train \
+ --model=mask_rcnn \
+ --params_override="{architecture: {backbone: spinenet, multilevel_features: identity}, spinenet: {model_id: 49}, train_file_pattern: ${TRAIN_FILE_PATTERN?} }, eval: { val_json_file: ${VAL_JSON_FILE?}, eval_file_pattern: ${EVAL_FILE_PATTERN?} } }"
+```
+
+
+### Train a custom Mask R-CNN using the config file.
+
+First, create a YAML config file, e.g. *my_maskrcnn.yaml*.
+This file specifies the parameters to be overridden,
+which should at least include the following fields.
+
+```YAML
+# my_maskrcnn.yaml
+train:
+ train_file_pattern:
+eval:
+ eval_file_pattern:
+ val_json_file:
+```
+
+Once the YAML config file is created, you can launch the training using the
+following command.
+
+```bash
+TPU_NAME=""
+MODEL_DIR=""
+python3 ~/models/official/legacy/detection/main.py \
+ --strategy_type=tpu \
+ --tpu=${TPU_NAME} \
+ --model_dir=${MODEL_DIR} \
+ --mode=train \
+ --model=mask_rcnn \
+ --config_file="my_maskrcnn.yaml"
+```
+
+## Train Mask R-CNN on GPU
+
+Training on GPU is similar to that on TPU. The major change is the strategy type
+(use
+"[mirrored](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy)"
+for multiple GPU and
+"[one_device](https://www.tensorflow.org/api_docs/python/tf/distribute/OneDeviceStrategy)"
+for single GPU).
+
+Multi-GPUs example (assuming there are 8GPU connected to the host):
+
+```bash
+MODEL_DIR=""
+python3 ~/models/official/legacy/detection/main.py \
+ --strategy_type=mirrored \
+ --num_gpus=8 \
+ --model_dir=${MODEL_DIR} \
+ --mode=train \
+ --model=mask_rcnn \
+ --config_file="my_maskrcnn.yaml"
+```
+
+```bash
+MODEL_DIR=""
+python3 ~/models/official/legacy/detection/main.py \
+ --strategy_type=one_device \
+ --num_gpus=1 \
+ --model_dir=${MODEL_DIR} \
+ --mode=train \
+ --model=mask_rcnn \
+ --config_file="my_maskrcnn.yaml"
+```
+
+An example with inline configuration (YAML or JSON format):
+
+```
+python3 ~/models/official/legacy/detection/main.py \
+ --model_dir= \
+ --strategy_type=one_device \
+ --num_gpus=1 \
+ --mode=train \
+ --model=mask_rcnn \
+ --params_override="eval:
+ eval_file_pattern:
+ batch_size: 8
+ val_json_file:
+predict:
+ predict_batch_size: 8
+architecture:
+ use_bfloat16: False
+train:
+ total_steps: 1000
+ batch_size: 8
+ train_file_pattern:
+use_tpu: False
+"
+```
+
+## Train ShapeMask on TPU
+
+### Train a ResNet-50 based ShapeMask.
+
+```bash
+TPU_NAME=""
+MODEL_DIR=""
+RESNET_CHECKPOINT=""
+TRAIN_FILE_PATTERN=""
+EVAL_FILE_PATTERN=""
+VAL_JSON_FILE=""
+SHAPE_PRIOR_PATH=""
+python3 ~/models/official/legacy/detection/main.py \
+ --strategy_type=tpu \
+ --tpu=${TPU_NAME} \
+ --model_dir=${MODEL_DIR} \
+ --mode=train \
+ --model=shapemask \
+ --params_override="{train: { checkpoint: { path: ${RESNET_CHECKPOINT}, prefix: resnet50/ }, train_file_pattern: ${TRAIN_FILE_PATTERN} }, eval: { val_json_file: ${VAL_JSON_FILE}, eval_file_pattern: ${EVAL_FILE_PATTERN} } shapemask_head: {use_category_for_mask: true, shape_prior_path: ${SHAPE_PRIOR_PATH}} }"
+```
+
+The pre-trained ResNet-50 checkpoint can be downloaded [here](https://storage.cloud.google.com/cloud-tpu-checkpoints/model-garden-vision/detection/resnet50-2018-02-07.tar.gz).
+
+The shape priors can be downloaded [here]
+(https://storage.googleapis.com/cloud-tpu-checkpoints/shapemask/kmeans_class_priors_91x20x32x32.npy)
+
+
+### Train a custom ShapeMask using the config file.
+
+First, create a YAML config file, e.g. *my_shapemask.yaml*.
+This file specifies the parameters to be overridden:
+
+```YAML
+# my_shapemask.yaml
+train:
+ train_file_pattern:
+ total_steps:
+ batch_size:
+eval:
+ eval_file_pattern:
+ val_json_file:
+ batch_size:
+shapemask_head:
+ shape_prior_path:
+```
+
+Once the YAML config file is created, you can launch the training using the
+following command.
+
+```bash
+TPU_NAME=""
+MODEL_DIR=""
+python3 ~/models/official/legacy/detection/main.py \
+ --strategy_type=tpu \
+ --tpu=${TPU_NAME} \
+ --model_dir=${MODEL_DIR} \
+ --mode=train \
+ --model=shapemask \
+ --config_file="my_shapemask.yaml"
+```
+
+## Train ShapeMask on GPU
+
+Training on GPU is similar to that on TPU. The major change is the strategy type
+(use
+"[mirrored](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy)"
+for multiple GPU and
+"[one_device](https://www.tensorflow.org/api_docs/python/tf/distribute/OneDeviceStrategy)"
+for single GPU).
+
+Multi-GPUs example (assuming there are 8GPU connected to the host):
+
+```bash
+MODEL_DIR=""
+python3 ~/models/official/legacy/detection/main.py \
+ --strategy_type=mirrored \
+ --num_gpus=8 \
+ --model_dir=${MODEL_DIR} \
+ --mode=train \
+ --model=shapemask \
+ --config_file="my_shapemask.yaml"
+```
+
+A single GPU example
+
+```bash
+MODEL_DIR=""
+python3 ~/models/official/legacy/detection/main.py \
+ --strategy_type=one_device \
+ --num_gpus=1 \
+ --model_dir=${MODEL_DIR} \
+ --mode=train \
+ --model=shapemask \
+ --config_file="my_shapemask.yaml"
+```
+
+
+An example with inline configuration (YAML or JSON format):
+
+```
+python3 ~/models/official/legacy/detection/main.py \
+ --model_dir= \
+ --strategy_type=one_device \
+ --num_gpus=1 \
+ --mode=train \
+ --model=shapemask \
+ --params_override="eval:
+ eval_file_pattern:
+ batch_size: 8
+ val_json_file:
+train:
+ total_steps: 1000
+ batch_size: 8
+ train_file_pattern:
+use_tpu: False
+"
+```
+
+
+### Run the evaluation (after training)
+
+```
+python3 /usr/share/models/official/legacy/detection/main.py \
+ --strategy_type=tpu \
+ --tpu=${TPU_NAME} \
+ --model_dir=${MODEL_DIR} \
+ --mode=eval \
+ --model=shapemask \
+ --params_override="{eval: { val_json_file: ${VAL_JSON_FILE}, eval_file_pattern: ${EVAL_FILE_PATTERN}, eval_samples: 5000 } }"
+```
+
+`MODEL_DIR` needs to point to the trained path of ShapeMask model.
+Change `strategy_type=mirrored` and `num_gpus=1` to run on a GPU.
+
+Note: The JSON groundtruth file is useful for [COCO dataset](http://cocodataset.org/#home) and can be
+downloaded from the [COCO website](http://cocodataset.org/#download). For custom dataset, it is unncessary because the groundtruth can be included in the TFRecord files.
+
+## References
+
+1. [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002).
+ Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, and Piotr Dollár. IEEE
+ International Conference on Computer Vision (ICCV), 2017.
diff --git a/modeling/official/legacy/detection/__init__.py b/modeling/official/legacy/detection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/legacy/detection/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/legacy/detection/configs/__init__.py b/modeling/official/legacy/detection/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/legacy/detection/configs/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/legacy/detection/configs/base_config.py b/modeling/official/legacy/detection/configs/base_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1d1f903b06a149b7c0706b3ce5cca33e1209659
--- /dev/null
+++ b/modeling/official/legacy/detection/configs/base_config.py
@@ -0,0 +1,140 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Base config template."""
+
+
+BACKBONES = [
+ 'resnet',
+ 'spinenet',
+]
+
+MULTILEVEL_FEATURES = [
+ 'fpn',
+ 'identity',
+]
+
+# pylint: disable=line-too-long
+# For ResNet, this freezes the variables of the first conv1 and conv2_x
+# layers [1], which leads to higher training speed and slightly better testing
+# accuracy. The intuition is that the low-level architecture (e.g., ResNet-50)
+# is able to capture low-level features such as edges; therefore, it does not
+# need to be fine-tuned for the detection task.
+# Note that we need to trailing `/` to avoid the incorrect match.
+# [1]: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py#L198
+RESNET_FROZEN_VAR_PREFIX = r'(resnet\d+)\/(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
+REGULARIZATION_VAR_REGEX = r'.*(kernel|weight):0$'
+
+BASE_CFG = {
+ 'model_dir': '',
+ 'use_tpu': True,
+ 'strategy_type': 'tpu',
+ 'isolate_session_state': False,
+ 'train': {
+ 'iterations_per_loop': 100,
+ 'batch_size': 64,
+ 'total_steps': 22500,
+ 'num_cores_per_replica': None,
+ 'input_partition_dims': None,
+ 'optimizer': {
+ 'type': 'momentum',
+ 'momentum': 0.9,
+ 'nesterov': True, # `False` is better for TPU v3-128.
+ },
+ 'learning_rate': {
+ 'type': 'step',
+ 'warmup_learning_rate': 0.0067,
+ 'warmup_steps': 500,
+ 'init_learning_rate': 0.08,
+ 'learning_rate_levels': [0.008, 0.0008],
+ 'learning_rate_steps': [15000, 20000],
+ },
+ 'checkpoint': {
+ 'path': '',
+ 'prefix': '',
+ },
+ # One can use 'RESNET_FROZEN_VAR_PREFIX' to speed up ResNet training
+ # when loading from the checkpoint.
+ 'frozen_variable_prefix': '',
+ 'train_file_pattern': '',
+ 'train_dataset_type': 'tfrecord',
+ # TODO(b/142174042): Support transpose_input option.
+ 'transpose_input': False,
+ 'regularization_variable_regex': REGULARIZATION_VAR_REGEX,
+ 'l2_weight_decay': 0.0001,
+ 'gradient_clip_norm': 0.0,
+ 'input_sharding': False,
+ },
+ 'eval': {
+ 'input_sharding': True,
+ 'batch_size': 8,
+ 'eval_samples': 5000,
+ 'min_eval_interval': 180,
+ 'eval_timeout': None,
+ 'num_steps_per_eval': 1000,
+ 'type': 'box',
+ 'use_json_file': True,
+ 'val_json_file': '',
+ 'eval_file_pattern': '',
+ 'eval_dataset_type': 'tfrecord',
+ # When visualizing images, set evaluation batch size to 40 to avoid
+ # potential OOM.
+ 'num_images_to_visualize': 0,
+ },
+ 'predict': {
+ 'batch_size': 8,
+ },
+ 'architecture': {
+ 'backbone': 'resnet',
+ 'min_level': 3,
+ 'max_level': 7,
+ 'multilevel_features': 'fpn',
+ 'use_bfloat16': True,
+ # Note that `num_classes` is the total number of classes including
+ # one background classes whose index is 0.
+ 'num_classes': 91,
+ },
+ 'anchor': {
+ 'num_scales': 3,
+ 'aspect_ratios': [1.0, 2.0, 0.5],
+ 'anchor_size': 4.0,
+ },
+ 'norm_activation': {
+ 'activation': 'relu',
+ 'batch_norm_momentum': 0.997,
+ 'batch_norm_epsilon': 1e-4,
+ 'batch_norm_trainable': True,
+ 'use_sync_bn': False,
+ },
+ 'resnet': {
+ 'resnet_depth': 50,
+ },
+ 'spinenet': {
+ 'model_id': '49',
+ },
+ 'fpn': {
+ 'fpn_feat_dims': 256,
+ 'use_separable_conv': False,
+ 'use_batch_norm': True,
+ },
+ 'postprocess': {
+ 'use_batched_nms': False,
+ 'max_total_size': 100,
+ 'nms_iou_threshold': 0.5,
+ 'score_threshold': 0.05,
+ 'pre_nms_num_boxes': 5000,
+ },
+ 'enable_summary': False,
+}
+# pylint: enable=line-too-long
diff --git a/modeling/official/legacy/detection/configs/factory.py b/modeling/official/legacy/detection/configs/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7ea676085ef63b1276abf18972b16f2ea14aedb
--- /dev/null
+++ b/modeling/official/legacy/detection/configs/factory.py
@@ -0,0 +1,41 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Factory to provide model configs."""
+
+from official.legacy.detection.configs import maskrcnn_config
+from official.legacy.detection.configs import olnmask_config
+from official.legacy.detection.configs import retinanet_config
+from official.legacy.detection.configs import shapemask_config
+from official.modeling.hyperparams import params_dict
+
+
+def config_generator(model):
+ """Model function generator."""
+ if model == 'retinanet':
+ default_config = retinanet_config.RETINANET_CFG
+ restrictions = retinanet_config.RETINANET_RESTRICTIONS
+ elif model == 'mask_rcnn':
+ default_config = maskrcnn_config.MASKRCNN_CFG
+ restrictions = maskrcnn_config.MASKRCNN_RESTRICTIONS
+ elif model == 'olnmask':
+ default_config = olnmask_config.OLNMASK_CFG
+ restrictions = olnmask_config.OLNMASK_RESTRICTIONS
+ elif model == 'shapemask':
+ default_config = shapemask_config.SHAPEMASK_CFG
+ restrictions = shapemask_config.SHAPEMASK_RESTRICTIONS
+ else:
+ raise ValueError('Model %s is not supported.' % model)
+
+ return params_dict.ParamsDict(default_config, restrictions)
diff --git a/modeling/official/legacy/detection/configs/maskrcnn_config.py b/modeling/official/legacy/detection/configs/maskrcnn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..da565c6fdd873f3f2209f2aab42859d2ab208e85
--- /dev/null
+++ b/modeling/official/legacy/detection/configs/maskrcnn_config.py
@@ -0,0 +1,115 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Config template to train Mask R-CNN."""
+
+from official.legacy.detection.configs import base_config
+from official.modeling.hyperparams import params_dict
+
+
+# pylint: disable=line-too-long
+MASKRCNN_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
+MASKRCNN_CFG.override({
+ 'type': 'mask_rcnn',
+ 'eval': {
+ 'type': 'box_and_mask',
+ 'num_images_to_visualize': 0,
+ },
+ 'architecture': {
+ 'parser': 'maskrcnn_parser',
+ 'min_level': 2,
+ 'max_level': 6,
+ 'include_mask': True,
+ 'mask_target_size': 28,
+ },
+ 'maskrcnn_parser': {
+ 'output_size': [1024, 1024],
+ 'num_channels': 3,
+ 'rpn_match_threshold': 0.7,
+ 'rpn_unmatched_threshold': 0.3,
+ 'rpn_batch_size_per_im': 256,
+ 'rpn_fg_fraction': 0.5,
+ 'aug_rand_hflip': True,
+ 'aug_scale_min': 1.0,
+ 'aug_scale_max': 1.0,
+ 'skip_crowd_during_training': True,
+ 'max_num_instances': 100,
+ 'mask_crop_size': 112,
+ },
+ 'anchor': {
+ 'num_scales': 1,
+ 'anchor_size': 8,
+ },
+ 'rpn_head': {
+ 'num_convs': 2,
+ 'num_filters': 256,
+ 'use_separable_conv': False,
+ 'use_batch_norm': False,
+ },
+ 'frcnn_head': {
+ 'num_convs': 0,
+ 'num_filters': 256,
+ 'use_separable_conv': False,
+ 'num_fcs': 2,
+ 'fc_dims': 1024,
+ 'use_batch_norm': False,
+ },
+ 'mrcnn_head': {
+ 'num_convs': 4,
+ 'num_filters': 256,
+ 'use_separable_conv': False,
+ 'use_batch_norm': False,
+ },
+ 'rpn_score_loss': {
+ 'rpn_batch_size_per_im': 256,
+ },
+ 'rpn_box_loss': {
+ 'huber_loss_delta': 1.0 / 9.0,
+ },
+ 'frcnn_box_loss': {
+ 'huber_loss_delta': 1.0,
+ },
+ 'roi_proposal': {
+ 'rpn_pre_nms_top_k': 2000,
+ 'rpn_post_nms_top_k': 1000,
+ 'rpn_nms_threshold': 0.7,
+ 'rpn_score_threshold': 0.0,
+ 'rpn_min_size_threshold': 0.0,
+ 'test_rpn_pre_nms_top_k': 1000,
+ 'test_rpn_post_nms_top_k': 1000,
+ 'test_rpn_nms_threshold': 0.7,
+ 'test_rpn_score_threshold': 0.0,
+ 'test_rpn_min_size_threshold': 0.0,
+ 'use_batched_nms': False,
+ },
+ 'roi_sampling': {
+ 'num_samples_per_image': 512,
+ 'fg_fraction': 0.25,
+ 'fg_iou_thresh': 0.5,
+ 'bg_iou_thresh_hi': 0.5,
+ 'bg_iou_thresh_lo': 0.0,
+ 'mix_gt_boxes': True,
+ },
+ 'mask_sampling': {
+ 'num_mask_samples_per_image': 128, # Typically = `num_samples_per_image` * `fg_fraction`.
+ },
+ 'postprocess': {
+ 'pre_nms_num_boxes': 1000,
+ },
+}, is_strict=False)
+
+
+MASKRCNN_RESTRICTIONS = [
+]
+# pylint: enable=line-too-long
diff --git a/modeling/official/legacy/detection/configs/olnmask_config.py b/modeling/official/legacy/detection/configs/olnmask_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c93e7a261347b19d85cb8de8a90478842e79f47
--- /dev/null
+++ b/modeling/official/legacy/detection/configs/olnmask_config.py
@@ -0,0 +1,143 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Config template to train Object Localization Network (OLN)."""
+
+from official.legacy.detection.configs import base_config
+from official.modeling.hyperparams import params_dict
+
+
+# pylint: disable=line-too-long
+OLNMASK_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
+OLNMASK_CFG.override({
+ 'type': 'olnmask',
+ 'eval': {
+ 'type': 'oln_xclass_box',
+ 'use_category': False,
+ 'seen_class': 'voc',
+ 'num_images_to_visualize': 0,
+ },
+ 'architecture': {
+ 'parser': 'olnmask_parser',
+ 'min_level': 2,
+ 'max_level': 6,
+ 'include_rpn_class': False,
+ 'include_frcnn_class': False,
+ 'include_frcnn_box': True,
+ 'include_mask': False,
+ 'mask_target_size': 28,
+ 'num_classes': 2,
+ },
+ 'olnmask_parser': {
+ 'output_size': [640, 640],
+ 'num_channels': 3,
+ 'rpn_match_threshold': 0.7,
+ 'rpn_unmatched_threshold': 0.3,
+ 'rpn_batch_size_per_im': 256,
+ 'rpn_fg_fraction': 0.5,
+ 'aug_rand_hflip': True,
+ 'aug_scale_min': 0.5,
+ 'aug_scale_max': 2.0,
+ 'skip_crowd_during_training': True,
+ 'max_num_instances': 100,
+ 'mask_crop_size': 112,
+ # centerness targets.
+ 'has_centerness': True,
+ 'rpn_center_match_iou_threshold': 0.3,
+ 'rpn_center_unmatched_iou_threshold': 0.1,
+ 'rpn_num_center_samples_per_im': 256,
+ # class manipulation.
+ 'class_agnostic': True,
+ 'train_class': 'voc',
+ },
+ 'anchor': {
+ 'num_scales': 1,
+ 'aspect_ratios': [1.0],
+ 'anchor_size': 8,
+ },
+ 'rpn_head': {
+ 'num_convs': 2,
+ 'num_filters': 256,
+ 'use_separable_conv': False,
+ 'use_batch_norm': False,
+ # RPN-Centerness learning {
+ 'has_centerness': True, # }
+ },
+ 'frcnn_head': {
+ 'num_convs': 0,
+ 'num_filters': 256,
+ 'use_separable_conv': False,
+ 'num_fcs': 2,
+ 'fc_dims': 1024,
+ 'use_batch_norm': False,
+ 'has_scoring': True,
+ },
+ 'mrcnn_head': {
+ 'num_convs': 4,
+ 'num_filters': 256,
+ 'use_separable_conv': False,
+ 'use_batch_norm': False,
+ 'has_scoring': False,
+ },
+ 'rpn_score_loss': {
+ 'rpn_batch_size_per_im': 256,
+ },
+ 'rpn_box_loss': {
+ 'huber_loss_delta': 1.0 / 9.0,
+ },
+ 'frcnn_box_loss': {
+ 'huber_loss_delta': 1.0,
+ },
+ 'frcnn_box_score_loss': {
+ 'ignore_threshold': 0.3,
+ },
+ 'roi_proposal': {
+ 'rpn_pre_nms_top_k': 2000,
+ 'rpn_post_nms_top_k': 2000,
+ 'rpn_nms_threshold': 0.7,
+ 'rpn_score_threshold': 0.0,
+ 'rpn_min_size_threshold': 0.0,
+ 'test_rpn_pre_nms_top_k': 2000,
+ 'test_rpn_post_nms_top_k': 2000,
+ 'test_rpn_nms_threshold': 0.7,
+ 'test_rpn_score_threshold': 0.0,
+ 'test_rpn_min_size_threshold': 0.0,
+ 'use_batched_nms': False,
+ },
+ 'roi_sampling': {
+ 'num_samples_per_image': 512,
+ 'fg_fraction': 0.25,
+ 'fg_iou_thresh': 0.5,
+ 'bg_iou_thresh_hi': 0.5,
+ 'bg_iou_thresh_lo': 0.0,
+ 'mix_gt_boxes': True,
+ },
+ 'mask_sampling': {
+ 'num_mask_samples_per_image': 128, # Typically = `num_samples_per_image` * `fg_fraction`.
+ },
+ 'postprocess': {
+ 'use_batched_nms': False,
+ 'max_total_size': 100,
+ 'nms_iou_threshold': 0.5,
+ 'score_threshold': 0.00,
+ 'pre_nms_num_boxes': 2000,
+ },
+}, is_strict=False)
+
+
+OLNMASK_RESTRICTIONS = [
+ # 'anchor.aspect_ratios == [1.0]',
+ # 'anchor.scales == 1',
+]
+# pylint: enable=line-too-long
diff --git a/modeling/official/legacy/detection/configs/retinanet_config.py b/modeling/official/legacy/detection/configs/retinanet_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c443f1787f7490dd1e5e14e38482379a4a5a614
--- /dev/null
+++ b/modeling/official/legacy/detection/configs/retinanet_config.py
@@ -0,0 +1,58 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Config template to train Retinanet."""
+
+from official.legacy.detection.configs import base_config
+from official.modeling.hyperparams import params_dict
+
+
+# pylint: disable=line-too-long
+RETINANET_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
+RETINANET_CFG.override({
+ 'type': 'retinanet',
+ 'architecture': {
+ 'parser': 'retinanet_parser',
+ },
+ 'retinanet_parser': {
+ 'output_size': [640, 640],
+ 'num_channels': 3,
+ 'match_threshold': 0.5,
+ 'unmatched_threshold': 0.5,
+ 'aug_rand_hflip': True,
+ 'aug_scale_min': 1.0,
+ 'aug_scale_max': 1.0,
+ 'use_autoaugment': False,
+ 'autoaugment_policy_name': 'v0',
+ 'skip_crowd_during_training': True,
+ 'max_num_instances': 100,
+ },
+ 'retinanet_head': {
+ 'num_convs': 4,
+ 'num_filters': 256,
+ 'use_separable_conv': False,
+ },
+ 'retinanet_loss': {
+ 'focal_loss_alpha': 0.25,
+ 'focal_loss_gamma': 1.5,
+ 'huber_loss_delta': 0.1,
+ 'box_loss_weight': 50,
+ },
+ 'enable_summary': True,
+}, is_strict=False)
+
+RETINANET_RESTRICTIONS = [
+]
+
+# pylint: enable=line-too-long
diff --git a/modeling/official/legacy/detection/configs/shapemask_config.py b/modeling/official/legacy/detection/configs/shapemask_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ac4dd243d931477a454a3ce7a2ecdbeb40662f4
--- /dev/null
+++ b/modeling/official/legacy/detection/configs/shapemask_config.py
@@ -0,0 +1,97 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Config to train shapemask on COCO."""
+
+from official.legacy.detection.configs import base_config
+from official.modeling.hyperparams import params_dict
+
+SHAPEMASK_RESNET_FROZEN_VAR_PREFIX = r'(conv2d(|_([1-9]|10))|batch_normalization(|_([1-9]|10)))\/'
+
+SHAPEMASK_CFG = params_dict.ParamsDict(base_config.BASE_CFG)
+SHAPEMASK_CFG.override({
+ 'type': 'shapemask',
+ 'architecture': {
+ 'parser': 'shapemask_parser',
+ 'backbone': 'resnet',
+ 'multilevel_features': 'fpn',
+ 'outer_box_scale': 1.25,
+ },
+ 'train': {
+ 'total_steps': 45000,
+ 'learning_rate': {
+ 'learning_rate_steps': [30000, 40000],
+ },
+ 'frozen_variable_prefix': SHAPEMASK_RESNET_FROZEN_VAR_PREFIX,
+ 'regularization_variable_regex': None,
+ },
+ 'eval': {
+ 'type': 'shapemask_box_and_mask',
+ 'mask_eval_class': 'all', # 'all', 'voc', or 'nonvoc'.
+ },
+ 'shapemask_parser': {
+ 'output_size': [640, 640],
+ 'num_channels': 3,
+ 'match_threshold': 0.5,
+ 'unmatched_threshold': 0.5,
+ 'aug_rand_hflip': True,
+ 'aug_scale_min': 0.8,
+ 'aug_scale_max': 1.2,
+ 'skip_crowd_during_training': True,
+ 'max_num_instances': 100,
+ # Shapemask specific parameters
+ 'mask_train_class': 'all', # 'all', 'voc', or 'nonvoc'.
+ 'use_category': True,
+ 'outer_box_scale': 1.25,
+ 'num_sampled_masks': 8,
+ 'mask_crop_size': 32,
+ 'mask_min_level': 3,
+ 'mask_max_level': 5,
+ 'box_jitter_scale': 0.025,
+ 'upsample_factor': 4,
+ },
+ 'retinanet_head': {
+ 'num_convs': 4,
+ 'num_filters': 256,
+ 'use_separable_conv': False,
+ 'use_batch_norm': True,
+ },
+ 'shapemask_head': {
+ 'num_downsample_channels': 128,
+ 'mask_crop_size': 32,
+ 'use_category_for_mask': True,
+ 'num_convs': 4,
+ 'upsample_factor': 4,
+ 'shape_prior_path': '',
+ },
+ 'retinanet_loss': {
+ 'focal_loss_alpha': 0.4,
+ 'focal_loss_gamma': 1.5,
+ 'huber_loss_delta': 0.15,
+ 'box_loss_weight': 50,
+ },
+ 'shapemask_loss': {
+ 'shape_prior_loss_weight': 0.1,
+ 'coarse_mask_loss_weight': 1.0,
+ 'fine_mask_loss_weight': 1.0,
+ },
+}, is_strict=False)
+
+SHAPEMASK_RESTRICTIONS = [
+ 'shapemask_head.mask_crop_size == shapemask_parser.mask_crop_size',
+ 'shapemask_head.upsample_factor == shapemask_parser.upsample_factor',
+ 'shapemask_parser.outer_box_scale == architecture.outer_box_scale',
+]
+
+# pylint: enable=line-too-long
diff --git a/modeling/official/legacy/detection/dataloader/__init__.py b/modeling/official/legacy/detection/dataloader/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/legacy/detection/dataloader/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/legacy/detection/dataloader/anchor.py b/modeling/official/legacy/detection/dataloader/anchor.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb36ad53a0a4ef7598973d648053425b238630ae
--- /dev/null
+++ b/modeling/official/legacy/detection/dataloader/anchor.py
@@ -0,0 +1,458 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Anchor box and labeler definition."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import tensorflow as tf, tf_keras
+from official.legacy.detection.utils import box_utils
+from official.vision.ops import iou_similarity
+from official.vision.utils.object_detection import argmax_matcher
+from official.vision.utils.object_detection import balanced_positive_negative_sampler
+from official.vision.utils.object_detection import box_list
+from official.vision.utils.object_detection import faster_rcnn_box_coder
+from official.vision.utils.object_detection import target_assigner
+
+
+class Anchor(object):
+ """Anchor class for anchor-based object detectors."""
+
+ def __init__(self, min_level, max_level, num_scales, aspect_ratios,
+ anchor_size, image_size):
+ """Constructs multiscale anchors.
+
+ Args:
+ min_level: integer number of minimum level of the output feature pyramid.
+ max_level: integer number of maximum level of the output feature pyramid.
+ num_scales: integer number representing intermediate scales added on each
+ level. For instances, num_scales=2 adds one additional intermediate
+ anchor scales [2^0, 2^0.5] on each level.
+ aspect_ratios: list of float numbers representing the aspect ratio anchors
+ added on each level. The number indicates the ratio of width to height.
+ For instances, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors on each
+ scale level.
+ anchor_size: float number representing the scale of size of the base
+ anchor to the feature stride 2^level.
+ image_size: a list of integer numbers or Tensors representing [height,
+ width] of the input image size.The image_size should be divisible by the
+ largest feature stride 2^max_level.
+ """
+ self.min_level = min_level
+ self.max_level = max_level
+ self.num_scales = num_scales
+ self.aspect_ratios = aspect_ratios
+ self.anchor_size = anchor_size
+ self.image_size = image_size
+ self.boxes = self._generate_boxes()
+
+ def _generate_boxes(self):
+ """Generates multiscale anchor boxes.
+
+ Returns:
+ a Tensor of shape [N, 4], represneting anchor boxes of all levels
+ concatenated together.
+ """
+ boxes_all = []
+ for level in range(self.min_level, self.max_level + 1):
+ boxes_l = []
+ for scale in range(self.num_scales):
+ for aspect_ratio in self.aspect_ratios:
+ stride = 2**level
+ intermediate_scale = 2**(scale / float(self.num_scales))
+ base_anchor_size = self.anchor_size * stride * intermediate_scale
+ aspect_x = aspect_ratio**0.5
+ aspect_y = aspect_ratio**-0.5
+ half_anchor_size_x = base_anchor_size * aspect_x / 2.0
+ half_anchor_size_y = base_anchor_size * aspect_y / 2.0
+ x = tf.range(stride / 2, self.image_size[1], stride)
+ y = tf.range(stride / 2, self.image_size[0], stride)
+ xv, yv = tf.meshgrid(x, y)
+ xv = tf.cast(tf.reshape(xv, [-1]), dtype=tf.float32)
+ yv = tf.cast(tf.reshape(yv, [-1]), dtype=tf.float32)
+ # Tensor shape Nx4.
+ boxes = tf.stack([
+ yv - half_anchor_size_y, xv - half_anchor_size_x,
+ yv + half_anchor_size_y, xv + half_anchor_size_x
+ ],
+ axis=1)
+ boxes_l.append(boxes)
+ # Concat anchors on the same level to tensor shape NxAx4.
+ boxes_l = tf.stack(boxes_l, axis=1)
+ boxes_l = tf.reshape(boxes_l, [-1, 4])
+ boxes_all.append(boxes_l)
+ return tf.concat(boxes_all, axis=0)
+
+ def unpack_labels(self, labels):
+ """Unpacks an array of labels into multiscales labels."""
+ unpacked_labels = collections.OrderedDict()
+ count = 0
+ for level in range(self.min_level, self.max_level + 1):
+ feat_size_y = tf.cast(self.image_size[0] / 2**level, tf.int32)
+ feat_size_x = tf.cast(self.image_size[1] / 2**level, tf.int32)
+ steps = feat_size_y * feat_size_x * self.anchors_per_location
+ unpacked_labels[level] = tf.reshape(labels[count:count + steps],
+ [feat_size_y, feat_size_x, -1])
+ count += steps
+ return unpacked_labels
+
+ @property
+ def anchors_per_location(self):
+ return self.num_scales * len(self.aspect_ratios)
+
+ @property
+ def multilevel_boxes(self):
+ return self.unpack_labels(self.boxes)
+
+
+class AnchorLabeler(object):
+ """Labeler for dense object detector."""
+
+ def __init__(self, anchor, match_threshold=0.5, unmatched_threshold=0.5):
+ """Constructs anchor labeler to assign labels to anchors.
+
+ Args:
+ anchor: an instance of class Anchors.
+ match_threshold: a float number between 0 and 1 representing the
+ lower-bound threshold to assign positive labels for anchors. An anchor
+ with a score over the threshold is labeled positive.
+ unmatched_threshold: a float number between 0 and 1 representing the
+ upper-bound threshold to assign negative labels for anchors. An anchor
+ with a score below the threshold is labeled negative.
+ """
+ similarity_calc = iou_similarity.IouSimilarity()
+ matcher = argmax_matcher.ArgMaxMatcher(
+ match_threshold,
+ unmatched_threshold=unmatched_threshold,
+ negatives_lower_than_unmatched=True,
+ force_match_for_each_row=True)
+ box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder()
+
+ self._target_assigner = target_assigner.TargetAssigner(
+ similarity_calc, matcher, box_coder)
+ self._anchor = anchor
+ self._match_threshold = match_threshold
+ self._unmatched_threshold = unmatched_threshold
+
+ def label_anchors(self, gt_boxes, gt_labels):
+ """Labels anchors with ground truth inputs.
+
+ Args:
+ gt_boxes: A float tensor with shape [N, 4] representing groundtruth boxes.
+ For each row, it stores [y0, x0, y1, x1] for four corners of a box.
+ gt_labels: A integer tensor with shape [N, 1] representing groundtruth
+ classes.
+
+ Returns:
+ cls_targets_dict: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, num_anchors_per_location]. The height_l and
+ width_l represent the dimension of class logits at l-th level.
+ box_targets_dict: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, num_anchors_per_location * 4]. The height_l
+ and width_l represent the dimension of bounding box regression output at
+ l-th level.
+ num_positives: scalar tensor storing number of positives in an image.
+ """
+ gt_box_list = box_list.BoxList(gt_boxes)
+ anchor_box_list = box_list.BoxList(self._anchor.boxes)
+
+ # The cls_weights, box_weights are not used.
+ cls_targets, _, box_targets, _, matches = self._target_assigner.assign(
+ anchor_box_list, gt_box_list, gt_labels)
+
+ # Labels definition in matches.match_results:
+ # (1) match_results[i]>=0, meaning that column i is matched with row
+ # match_results[i].
+ # (2) match_results[i]=-1, meaning that column i is not matched.
+ # (3) match_results[i]=-2, meaning that column i is ignored.
+ match_results = tf.expand_dims(matches.match_results, axis=1)
+ cls_targets = tf.cast(cls_targets, tf.int32)
+ cls_targets = tf.where(
+ tf.equal(match_results, -1), -tf.ones_like(cls_targets), cls_targets)
+ cls_targets = tf.where(
+ tf.equal(match_results, -2), -2 * tf.ones_like(cls_targets),
+ cls_targets)
+
+ # Unpacks labels into multi-level representations.
+ cls_targets_dict = self._anchor.unpack_labels(cls_targets)
+ box_targets_dict = self._anchor.unpack_labels(box_targets)
+ num_positives = tf.reduce_sum(
+ input_tensor=tf.cast(tf.greater(matches.match_results, -1), tf.float32))
+
+ return cls_targets_dict, box_targets_dict, num_positives
+
+
+class RpnAnchorLabeler(AnchorLabeler):
+ """Labeler for Region Proposal Network."""
+
+ def __init__(self,
+ anchor,
+ match_threshold=0.7,
+ unmatched_threshold=0.3,
+ rpn_batch_size_per_im=256,
+ rpn_fg_fraction=0.5):
+ AnchorLabeler.__init__(
+ self, anchor, match_threshold=0.7, unmatched_threshold=0.3)
+ self._rpn_batch_size_per_im = rpn_batch_size_per_im
+ self._rpn_fg_fraction = rpn_fg_fraction
+
+ def _get_rpn_samples(self, match_results):
+ """Computes anchor labels.
+
+ This function performs subsampling for foreground (fg) and background (bg)
+ anchors.
+ Args:
+ match_results: A integer tensor with shape [N] representing the matching
+ results of anchors. (1) match_results[i]>=0, meaning that column i is
+ matched with row match_results[i]. (2) match_results[i]=-1, meaning that
+ column i is not matched. (3) match_results[i]=-2, meaning that column i
+ is ignored.
+
+ Returns:
+ score_targets: a integer tensor with the a shape of [N].
+ (1) score_targets[i]=1, the anchor is a positive sample.
+ (2) score_targets[i]=0, negative. (3) score_targets[i]=-1, the anchor is
+ don't care (ignore).
+ """
+ sampler = (
+ balanced_positive_negative_sampler.BalancedPositiveNegativeSampler(
+ positive_fraction=self._rpn_fg_fraction, is_static=False))
+ # indicator includes both positive and negative labels.
+ # labels includes only positives labels.
+ # positives = indicator & labels.
+ # negatives = indicator & !labels.
+ # ignore = !indicator.
+ indicator = tf.greater(match_results, -2)
+ labels = tf.greater(match_results, -1)
+
+ samples = sampler.subsample(indicator, self._rpn_batch_size_per_im, labels)
+ positive_labels = tf.where(
+ tf.logical_and(samples, labels),
+ tf.constant(2, dtype=tf.int32, shape=match_results.shape),
+ tf.constant(0, dtype=tf.int32, shape=match_results.shape))
+ negative_labels = tf.where(
+ tf.logical_and(samples, tf.logical_not(labels)),
+ tf.constant(1, dtype=tf.int32, shape=match_results.shape),
+ tf.constant(0, dtype=tf.int32, shape=match_results.shape))
+ ignore_labels = tf.fill(match_results.shape, -1)
+
+ return (ignore_labels + positive_labels + negative_labels, positive_labels,
+ negative_labels)
+
+ def label_anchors(self, gt_boxes, gt_labels):
+ """Labels anchors with ground truth inputs.
+
+ Args:
+ gt_boxes: A float tensor with shape [N, 4] representing groundtruth boxes.
+ For each row, it stores [y0, x0, y1, x1] for four corners of a box.
+ gt_labels: A integer tensor with shape [N, 1] representing groundtruth
+ classes.
+
+ Returns:
+ score_targets_dict: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, num_anchors]. The height_l and width_l
+ represent the dimension of class logits at l-th level.
+ box_targets_dict: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, num_anchors * 4]. The height_l and
+ width_l represent the dimension of bounding box regression output at
+ l-th level.
+ """
+ gt_box_list = box_list.BoxList(gt_boxes)
+ anchor_box_list = box_list.BoxList(self._anchor.boxes)
+
+ # cls_targets, cls_weights, box_weights are not used.
+ _, _, box_targets, _, matches = self._target_assigner.assign(
+ anchor_box_list, gt_box_list, gt_labels)
+
+ # score_targets contains the subsampled positive and negative anchors.
+ score_targets, _, _ = self._get_rpn_samples(matches.match_results)
+
+ # Unpacks labels.
+ score_targets_dict = self._anchor.unpack_labels(score_targets)
+ box_targets_dict = self._anchor.unpack_labels(box_targets)
+
+ return score_targets_dict, box_targets_dict
+
+
+class OlnAnchorLabeler(RpnAnchorLabeler):
+ """Labeler for Region Proposal Network."""
+
+ def __init__(self,
+ anchor,
+ match_threshold=0.7,
+ unmatched_threshold=0.3,
+ rpn_batch_size_per_im=256,
+ rpn_fg_fraction=0.5,
+ has_centerness=False,
+ center_match_iou_threshold=0.3,
+ center_unmatched_iou_threshold=0.1,
+ num_center_samples_per_im=256):
+ """Constructs rpn anchor labeler to assign labels and centerness to anchors.
+
+ Args:
+ anchor: an instance of class Anchors.
+ match_threshold: a float number between 0 and 1 representing the
+ lower-bound threshold to assign positive labels for anchors. An anchor
+ with a score over the threshold is labeled positive.
+ unmatched_threshold: a float number between 0 and 1 representing the
+ upper-bound threshold to assign negative labels for anchors. An anchor
+ with a score below the threshold is labeled negative.
+ rpn_batch_size_per_im: number of anchors that are sampled per image.
+ rpn_fg_fraction:
+ has_centerness: whether to include centerness target creation. An anchor
+ is paired with one centerness score.
+ center_match_iou_threshold: a float number between 0 and 1 representing
+ the lower-bound threshold to sample foreground anchors for centerness
+ regression. An anchor with a score over the threshold is sampled as
+ foreground sample for centerness regression. We sample mostly from the
+ foreground region (255 out of 256 samples). That is, we sample 255 vs 1
+ (foreground vs background) anchor points to learn centerness regression.
+ center_unmatched_iou_threshold: a float number between 0 and 1
+ representing the lower-bound threshold to sample background anchors for
+ centerness regression. An anchor with a score over the threshold is
+ sampled as foreground sample for centerness regression. We sample very
+ sparsely from the background region (1 out of 256 samples). That is, we
+ sample 255 vs 1 (foreground vs background) anchor points to learn
+ centerness regression.
+ num_center_samples_per_im: number of anchor points per image that are
+ sampled as centerness targets.
+ """
+ super(OlnAnchorLabeler, self).__init__(
+ anchor, match_threshold=match_threshold,
+ unmatched_threshold=unmatched_threshold,
+ rpn_batch_size_per_im=rpn_batch_size_per_im,
+ rpn_fg_fraction=rpn_fg_fraction)
+ similarity_calc = iou_similarity.IouSimilarity()
+ matcher = argmax_matcher.ArgMaxMatcher(
+ match_threshold,
+ unmatched_threshold=unmatched_threshold,
+ negatives_lower_than_unmatched=True,
+ force_match_for_each_row=True)
+ box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder()
+ if has_centerness:
+ center_matcher = argmax_matcher.ArgMaxMatcher(
+ center_match_iou_threshold,
+ unmatched_threshold=center_match_iou_threshold,
+ negatives_lower_than_unmatched=True,
+ force_match_for_each_row=True,)
+ else:
+ center_matcher = None
+
+ self._target_assigner = target_assigner.OlnTargetAssigner(
+ similarity_calc, matcher, box_coder,
+ center_matcher=center_matcher)
+ self._num_center_samples_per_im = num_center_samples_per_im
+ self._center_unmatched_iou_threshold = center_unmatched_iou_threshold
+ self._rpn_batch_size_per_im = rpn_batch_size_per_im
+ self._rpn_fg_fraction = rpn_fg_fraction
+
+ def label_anchors_lrtb(self, gt_boxes, gt_labels):
+ """Labels anchors with ground truth inputs.
+
+ Args:
+ gt_boxes: A float tensor with shape [N, 4] representing groundtruth boxes.
+ For each row, it stores [y0, x0, y1, x1] for four corners of a box.
+ gt_labels: A integer tensor with shape [N, 1] representing groundtruth
+ classes.
+
+ Returns:
+ score_targets_dict: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, num_anchors]. The height_l and width_l
+ represent the dimension of class logits at l-th level.
+ box_targets_dict: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, num_anchors * 4]. The height_l and
+ width_l represent the dimension of bounding box regression output at
+ l-th level.
+ lrtb_targets_dict: Same strucure to box_target_dict, except the regression
+ targets are converted from xyhw to lrtb format. Ordered dictionary with
+ keys [min_level, min_level+1, ..., max_level]. The values are tensor
+ with shape [height_l, width_l, num_anchors * 4]. The height_l and
+ width_l represent the dimension of bounding box regression output at
+ l-th level.
+ center_targets_dict: Same structure to score_tragets_dict, except the
+ scores are centerness values ranging from 0 to 1. Ordered dictionary
+ with keys [min_level, min_level+1, ..., max_level]. The values are
+ tensor with shape [height_l, width_l, num_anchors]. The height_l and
+ width_l represent the dimension of class logits at l-th level.
+ """
+ gt_box_list = box_list.BoxList(gt_boxes)
+ anchor_box_list = box_list.BoxList(self._anchor.boxes)
+
+ # cls_targets, cls_weights, box_weights are not used.
+ (_, _, box_targets, _, matches,
+ matched_gt_box_list, matched_anchors_mask,
+ center_matched_gt_box_list, center_matched_anchors_mask,
+ matched_ious) = self._target_assigner.assign(
+ anchor_box_list, gt_box_list, gt_labels)
+ # Box lrtb_targets.
+ lrtb_targets, _ = box_utils.encode_boxes_lrtb(
+ matched_gt_box_list.data['boxes'],
+ anchor_box_list.data['boxes'],
+ weights=[1.0, 1.0, 1.0, 1.0])
+ lrtb_sanity = tf.logical_and(
+ tf.greater(tf.reduce_min(lrtb_targets, -1), 0.),
+ matched_anchors_mask)
+ # To broadcast lrtb_sanity to the same shape as lrtb_targets.
+ lrtb_sanity = tf.tile(tf.expand_dims(lrtb_sanity, 1),
+ [1, tf.shape(lrtb_targets)[1]])
+ lrtb_targets = tf.where(lrtb_sanity,
+ lrtb_targets,
+ tf.zeros_like(lrtb_targets))
+ # RPN anchor-gtbox iou values.
+ iou_targets = tf.where(tf.greater(matched_ious, 0.0),
+ matched_ious,
+ tf.zeros_like(matched_ious))
+ # Centerness_targets.
+ _, center_targets = box_utils.encode_boxes_lrtb(
+ center_matched_gt_box_list.data['boxes'],
+ anchor_box_list.data['boxes'],
+ weights=[1.0, 1.0, 1.0, 1.0])
+ # Positive-negative centerness sampler.
+ num_center_samples_per_im = self._num_center_samples_per_im
+ center_pos_neg_sampler = (
+ balanced_positive_negative_sampler.BalancedPositiveNegativeSampler(
+ positive_fraction=(1.- 1./num_center_samples_per_im),
+ is_static=False))
+ center_pos_neg_indicator = tf.logical_or(
+ center_matched_anchors_mask,
+ tf.less(iou_targets, self._center_unmatched_iou_threshold))
+ center_pos_labels = center_matched_anchors_mask
+ center_samples = center_pos_neg_sampler.subsample(
+ center_pos_neg_indicator, num_center_samples_per_im, center_pos_labels)
+ is_valid = center_samples
+ center_targets = tf.where(is_valid,
+ center_targets,
+ (-1) * tf.ones_like(center_targets))
+
+ # score_targets contains the subsampled positive and negative anchors.
+ score_targets, _, _ = self._get_rpn_samples(matches.match_results)
+
+ # Unpacks labels.
+ score_targets_dict = self._anchor.unpack_labels(score_targets)
+ box_targets_dict = self._anchor.unpack_labels(box_targets)
+ lrtb_targets_dict = self._anchor.unpack_labels(lrtb_targets)
+ center_targets_dict = self._anchor.unpack_labels(center_targets)
+
+ return (score_targets_dict, box_targets_dict,
+ lrtb_targets_dict, center_targets_dict)
diff --git a/modeling/official/legacy/detection/dataloader/factory.py b/modeling/official/legacy/detection/dataloader/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ca7519895180f79ec7f691f9ef4854337ba9922
--- /dev/null
+++ b/modeling/official/legacy/detection/dataloader/factory.py
@@ -0,0 +1,136 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Model architecture factory."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from official.legacy.detection.dataloader import maskrcnn_parser
+from official.legacy.detection.dataloader import olnmask_parser
+from official.legacy.detection.dataloader import retinanet_parser
+from official.legacy.detection.dataloader import shapemask_parser
+
+
+def parser_generator(params, mode):
+ """Generator function for various dataset parser."""
+ if params.architecture.parser == 'retinanet_parser':
+ anchor_params = params.anchor
+ parser_params = params.retinanet_parser
+ parser_fn = retinanet_parser.Parser(
+ output_size=parser_params.output_size,
+ min_level=params.architecture.min_level,
+ max_level=params.architecture.max_level,
+ num_scales=anchor_params.num_scales,
+ aspect_ratios=anchor_params.aspect_ratios,
+ anchor_size=anchor_params.anchor_size,
+ match_threshold=parser_params.match_threshold,
+ unmatched_threshold=parser_params.unmatched_threshold,
+ aug_rand_hflip=parser_params.aug_rand_hflip,
+ aug_scale_min=parser_params.aug_scale_min,
+ aug_scale_max=parser_params.aug_scale_max,
+ use_autoaugment=parser_params.use_autoaugment,
+ autoaugment_policy_name=parser_params.autoaugment_policy_name,
+ skip_crowd_during_training=parser_params.skip_crowd_during_training,
+ max_num_instances=parser_params.max_num_instances,
+ use_bfloat16=params.architecture.use_bfloat16,
+ mode=mode)
+ elif params.architecture.parser == 'maskrcnn_parser':
+ anchor_params = params.anchor
+ parser_params = params.maskrcnn_parser
+ parser_fn = maskrcnn_parser.Parser(
+ output_size=parser_params.output_size,
+ min_level=params.architecture.min_level,
+ max_level=params.architecture.max_level,
+ num_scales=anchor_params.num_scales,
+ aspect_ratios=anchor_params.aspect_ratios,
+ anchor_size=anchor_params.anchor_size,
+ rpn_match_threshold=parser_params.rpn_match_threshold,
+ rpn_unmatched_threshold=parser_params.rpn_unmatched_threshold,
+ rpn_batch_size_per_im=parser_params.rpn_batch_size_per_im,
+ rpn_fg_fraction=parser_params.rpn_fg_fraction,
+ aug_rand_hflip=parser_params.aug_rand_hflip,
+ aug_scale_min=parser_params.aug_scale_min,
+ aug_scale_max=parser_params.aug_scale_max,
+ skip_crowd_during_training=parser_params.skip_crowd_during_training,
+ max_num_instances=parser_params.max_num_instances,
+ include_mask=params.architecture.include_mask,
+ mask_crop_size=parser_params.mask_crop_size,
+ use_bfloat16=params.architecture.use_bfloat16,
+ mode=mode)
+ elif params.architecture.parser == 'olnmask_parser':
+ anchor_params = params.anchor
+ parser_params = params.olnmask_parser
+ parser_fn = olnmask_parser.Parser(
+ output_size=parser_params.output_size,
+ min_level=params.architecture.min_level,
+ max_level=params.architecture.max_level,
+ num_scales=anchor_params.num_scales,
+ aspect_ratios=anchor_params.aspect_ratios,
+ anchor_size=anchor_params.anchor_size,
+ rpn_match_threshold=parser_params.rpn_match_threshold,
+ rpn_unmatched_threshold=parser_params.rpn_unmatched_threshold,
+ rpn_batch_size_per_im=parser_params.rpn_batch_size_per_im,
+ rpn_fg_fraction=parser_params.rpn_fg_fraction,
+ aug_rand_hflip=parser_params.aug_rand_hflip,
+ aug_scale_min=parser_params.aug_scale_min,
+ aug_scale_max=parser_params.aug_scale_max,
+ skip_crowd_during_training=parser_params.skip_crowd_during_training,
+ max_num_instances=parser_params.max_num_instances,
+ include_mask=params.architecture.include_mask,
+ mask_crop_size=parser_params.mask_crop_size,
+ use_bfloat16=params.architecture.use_bfloat16,
+ mode=mode,
+ has_centerness=parser_params.has_centerness,
+ rpn_center_match_iou_threshold=(
+ parser_params.rpn_center_match_iou_threshold),
+ rpn_center_unmatched_iou_threshold=(
+ parser_params.rpn_center_unmatched_iou_threshold),
+ rpn_num_center_samples_per_im=(
+ parser_params.rpn_num_center_samples_per_im),
+ class_agnostic=parser_params.class_agnostic,
+ train_class=parser_params.train_class,)
+ elif params.architecture.parser == 'shapemask_parser':
+ anchor_params = params.anchor
+ parser_params = params.shapemask_parser
+ parser_fn = shapemask_parser.Parser(
+ output_size=parser_params.output_size,
+ min_level=params.architecture.min_level,
+ max_level=params.architecture.max_level,
+ num_scales=anchor_params.num_scales,
+ aspect_ratios=anchor_params.aspect_ratios,
+ anchor_size=anchor_params.anchor_size,
+ use_category=parser_params.use_category,
+ outer_box_scale=parser_params.outer_box_scale,
+ box_jitter_scale=parser_params.box_jitter_scale,
+ num_sampled_masks=parser_params.num_sampled_masks,
+ mask_crop_size=parser_params.mask_crop_size,
+ mask_min_level=parser_params.mask_min_level,
+ mask_max_level=parser_params.mask_max_level,
+ upsample_factor=parser_params.upsample_factor,
+ match_threshold=parser_params.match_threshold,
+ unmatched_threshold=parser_params.unmatched_threshold,
+ aug_rand_hflip=parser_params.aug_rand_hflip,
+ aug_scale_min=parser_params.aug_scale_min,
+ aug_scale_max=parser_params.aug_scale_max,
+ skip_crowd_during_training=parser_params.skip_crowd_during_training,
+ max_num_instances=parser_params.max_num_instances,
+ use_bfloat16=params.architecture.use_bfloat16,
+ mask_train_class=parser_params.mask_train_class,
+ mode=mode)
+ else:
+ raise ValueError('Parser %s is not supported.' % params.architecture.parser)
+
+ return parser_fn
diff --git a/modeling/official/legacy/detection/dataloader/input_reader.py b/modeling/official/legacy/detection/dataloader/input_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..c114f8a80675b2908a1b6b6566ed8b7d31268aa9
--- /dev/null
+++ b/modeling/official/legacy/detection/dataloader/input_reader.py
@@ -0,0 +1,105 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Data loader and input processing."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from typing import Optional, Text
+import tensorflow as tf, tf_keras
+from official.legacy.detection.dataloader import factory
+from official.legacy.detection.dataloader import mode_keys as ModeKeys
+from official.modeling.hyperparams import params_dict
+
+
+class InputFn(object):
+ """Input function that creates dataset from files."""
+
+ def __init__(self,
+ file_pattern: Text,
+ params: params_dict.ParamsDict,
+ mode: Text,
+ batch_size: int,
+ num_examples: Optional[int] = -1):
+ """Initialize.
+
+ Args:
+ file_pattern: the file pattern for the data example (TFRecords).
+ params: the parameter object for constructing example parser and model.
+ mode: ModeKeys.TRAIN or ModeKeys.Eval
+ batch_size: the data batch size.
+ num_examples: If positive, only takes this number of examples and raise
+ tf.errors.OutOfRangeError after that. If non-positive, it will be
+ ignored.
+ """
+ assert file_pattern is not None
+ assert mode is not None
+ assert batch_size is not None
+ self._file_pattern = file_pattern
+ self._mode = mode
+ self._is_training = (mode == ModeKeys.TRAIN)
+ self._batch_size = batch_size
+ self._num_examples = num_examples
+ self._parser_fn = factory.parser_generator(params, mode)
+ self._dataset_fn = tf.data.TFRecordDataset
+
+ self._input_sharding = (not self._is_training)
+ try:
+ if self._is_training:
+ self._input_sharding = params.train.input_sharding
+ else:
+ self._input_sharding = params.eval.input_sharding
+ except AttributeError:
+ pass
+
+ def __call__(self, ctx=None, batch_size: int = None):
+ """Provides tf.data.Dataset object.
+
+ Args:
+ ctx: context object.
+ batch_size: expected batch size input data.
+
+ Returns:
+ tf.data.Dataset object.
+ """
+ if not batch_size:
+ batch_size = self._batch_size
+ assert batch_size is not None
+ dataset = tf.data.Dataset.list_files(
+ self._file_pattern, shuffle=self._is_training)
+
+ if self._input_sharding and ctx and ctx.num_input_pipelines > 1:
+ dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
+ dataset = dataset.cache()
+
+ if self._is_training:
+ dataset = dataset.repeat()
+
+ dataset = dataset.interleave(
+ map_func=self._dataset_fn,
+ cycle_length=32,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ if self._is_training:
+ dataset = dataset.shuffle(1000)
+ if self._num_examples > 0:
+ dataset = dataset.take(self._num_examples)
+
+ # Parses the fetched records to input tensors for model function.
+ dataset = dataset.map(
+ self._parser_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ dataset = dataset.batch(batch_size, drop_remainder=True)
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+ return dataset
diff --git a/modeling/official/legacy/detection/dataloader/maskrcnn_parser.py b/modeling/official/legacy/detection/dataloader/maskrcnn_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..d37e3f86435020523b433bce17be930caa2f95fe
--- /dev/null
+++ b/modeling/official/legacy/detection/dataloader/maskrcnn_parser.py
@@ -0,0 +1,381 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Data parser and processing for Mask R-CNN."""
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.dataloader import anchor
+from official.legacy.detection.dataloader import mode_keys as ModeKeys
+from official.legacy.detection.dataloader import tf_example_decoder
+from official.legacy.detection.utils import box_utils
+from official.legacy.detection.utils import dataloader_utils
+from official.legacy.detection.utils import input_utils
+
+
+class Parser(object):
+ """Parser to parse an image and its annotations into a dictionary of tensors."""
+
+ def __init__(self,
+ output_size,
+ min_level,
+ max_level,
+ num_scales,
+ aspect_ratios,
+ anchor_size,
+ rpn_match_threshold=0.7,
+ rpn_unmatched_threshold=0.3,
+ rpn_batch_size_per_im=256,
+ rpn_fg_fraction=0.5,
+ aug_rand_hflip=False,
+ aug_scale_min=1.0,
+ aug_scale_max=1.0,
+ skip_crowd_during_training=True,
+ max_num_instances=100,
+ include_mask=False,
+ mask_crop_size=112,
+ use_bfloat16=True,
+ mode=None):
+ """Initializes parameters for parsing annotations in the dataset.
+
+ Args:
+ output_size: `Tensor` or `list` for [height, width] of output image. The
+ output_size should be divided by the largest feature stride 2^max_level.
+ min_level: `int` number of minimum level of the output feature pyramid.
+ max_level: `int` number of maximum level of the output feature pyramid.
+ num_scales: `int` number representing intermediate scales added
+ on each level. For instances, num_scales=2 adds one additional
+ intermediate anchor scales [2^0, 2^0.5] on each level.
+ aspect_ratios: `list` of float numbers representing the aspect raito
+ anchors added on each level. The number indicates the ratio of width to
+ height. For instances, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors
+ on each scale level.
+ anchor_size: `float` number representing the scale of size of the base
+ anchor to the feature stride 2^level.
+ rpn_match_threshold:
+ rpn_unmatched_threshold:
+ rpn_batch_size_per_im:
+ rpn_fg_fraction:
+ aug_rand_hflip: `bool`, if True, augment training with random
+ horizontal flip.
+ aug_scale_min: `float`, the minimum scale applied to `output_size` for
+ data augmentation during training.
+ aug_scale_max: `float`, the maximum scale applied to `output_size` for
+ data augmentation during training.
+ skip_crowd_during_training: `bool`, if True, skip annotations labeled with
+ `is_crowd` equals to 1.
+ max_num_instances: `int` number of maximum number of instances in an
+ image. The groundtruth data will be padded to `max_num_instances`.
+ include_mask: a bool to indicate whether parse mask groundtruth.
+ mask_crop_size: the size which groundtruth mask is cropped to.
+ use_bfloat16: `bool`, if True, cast output image to tf.bfloat16.
+ mode: a ModeKeys. Specifies if this is training, evaluation, prediction
+ or prediction with groundtruths in the outputs.
+ """
+ self._mode = mode
+ self._max_num_instances = max_num_instances
+ self._skip_crowd_during_training = skip_crowd_during_training
+ self._is_training = (mode == ModeKeys.TRAIN)
+
+ self._example_decoder = tf_example_decoder.TfExampleDecoder(
+ include_mask=include_mask)
+
+ # Anchor.
+ self._output_size = output_size
+ self._min_level = min_level
+ self._max_level = max_level
+ self._num_scales = num_scales
+ self._aspect_ratios = aspect_ratios
+ self._anchor_size = anchor_size
+
+ # Target assigning.
+ self._rpn_match_threshold = rpn_match_threshold
+ self._rpn_unmatched_threshold = rpn_unmatched_threshold
+ self._rpn_batch_size_per_im = rpn_batch_size_per_im
+ self._rpn_fg_fraction = rpn_fg_fraction
+
+ # Data augmentation.
+ self._aug_rand_hflip = aug_rand_hflip
+ self._aug_scale_min = aug_scale_min
+ self._aug_scale_max = aug_scale_max
+
+ # Mask.
+ self._include_mask = include_mask
+ self._mask_crop_size = mask_crop_size
+
+ # Device.
+ self._use_bfloat16 = use_bfloat16
+
+ # Data is parsed depending on the model Modekey.
+ if mode == ModeKeys.TRAIN:
+ self._parse_fn = self._parse_train_data
+ elif mode == ModeKeys.EVAL:
+ self._parse_fn = self._parse_eval_data
+ elif mode == ModeKeys.PREDICT or mode == ModeKeys.PREDICT_WITH_GT:
+ self._parse_fn = self._parse_predict_data
+ else:
+ raise ValueError('mode is not defined.')
+
+ def __call__(self, value):
+ """Parses data to an image and associated training labels.
+
+ Args:
+ value: a string tensor holding a serialized tf.Example proto.
+
+ Returns:
+ image, labels: if mode == ModeKeys.TRAIN. see _parse_train_data.
+ {'images': image, 'labels': labels}: if mode == ModeKeys.PREDICT
+ or ModeKeys.PREDICT_WITH_GT.
+ """
+ with tf.name_scope('parser'):
+ data = self._example_decoder.decode(value)
+ return self._parse_fn(data)
+
+ def _parse_train_data(self, data):
+ """Parses data for training.
+
+ Args:
+ data: the decoded tensor dictionary from TfExampleDecoder.
+
+ Returns:
+ image: image tensor that is preproessed to have normalized value and
+ dimension [output_size[0], output_size[1], 3]
+ labels: a dictionary of tensors used for training. The following describes
+ {key: value} pairs in the dictionary.
+ image_info: a 2D `Tensor` that encodes the information of the image and
+ the applied preprocessing. It is in the format of
+ [[original_height, original_width], [scaled_height, scaled_width],
+ anchor_boxes: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, 4] representing anchor boxes at each level.
+ rpn_score_targets: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, anchors_per_location]. The height_l and
+ width_l represent the dimension of class logits at l-th level.
+ rpn_box_targets: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, anchors_per_location * 4]. The height_l and
+ width_l represent the dimension of bounding box regression output at
+ l-th level.
+ gt_boxes: Groundtruth bounding box annotations. The box is represented
+ in [y1, x1, y2, x2] format. The coordinates are w.r.t the scaled
+ image that is fed to the network. The tennsor is padded with -1 to
+ the fixed dimension [self._max_num_instances, 4].
+ gt_classes: Groundtruth classes annotations. The tennsor is padded
+ with -1 to the fixed dimension [self._max_num_instances].
+ gt_masks: groundtrugh masks cropped by the bounding box and
+ resized to a fixed size determined by mask_crop_size.
+ """
+ classes = data['groundtruth_classes']
+ boxes = data['groundtruth_boxes']
+ if self._include_mask:
+ masks = data['groundtruth_instance_masks']
+
+ is_crowds = data['groundtruth_is_crowd']
+ # Skips annotations with `is_crowd` = True.
+ if self._skip_crowd_during_training and self._is_training:
+ num_groundtruths = tf.shape(classes)[0]
+ with tf.control_dependencies([num_groundtruths, is_crowds]):
+ indices = tf.cond(
+ tf.greater(tf.size(is_crowds), 0),
+ lambda: tf.where(tf.logical_not(is_crowds))[:, 0],
+ lambda: tf.cast(tf.range(num_groundtruths), tf.int64))
+ classes = tf.gather(classes, indices)
+ boxes = tf.gather(boxes, indices)
+ if self._include_mask:
+ masks = tf.gather(masks, indices)
+
+ # Gets original image and its size.
+ image = data['image']
+ image_shape = tf.shape(image)[0:2]
+
+ # Normalizes image with mean and std pixel values.
+ image = input_utils.normalize_image(image)
+
+ # Flips image randomly during training.
+ if self._aug_rand_hflip:
+ if self._include_mask:
+ image, boxes, masks = input_utils.random_horizontal_flip(
+ image, boxes, masks)
+ else:
+ image, boxes = input_utils.random_horizontal_flip(
+ image, boxes)
+
+ # Converts boxes from normalized coordinates to pixel coordinates.
+ # Now the coordinates of boxes are w.r.t. the original image.
+ boxes = box_utils.denormalize_boxes(boxes, image_shape)
+
+ # Resizes and crops image.
+ image, image_info = input_utils.resize_and_crop_image(
+ image,
+ self._output_size,
+ padded_size=input_utils.compute_padded_size(
+ self._output_size, 2 ** self._max_level),
+ aug_scale_min=self._aug_scale_min,
+ aug_scale_max=self._aug_scale_max)
+ image_height, image_width, _ = image.get_shape().as_list()
+
+ # Resizes and crops boxes.
+ # Now the coordinates of boxes are w.r.t the scaled image.
+ image_scale = image_info[2, :]
+ offset = image_info[3, :]
+ boxes = input_utils.resize_and_crop_boxes(
+ boxes, image_scale, image_info[1, :], offset)
+
+ # Filters out ground truth boxes that are all zeros.
+ indices = box_utils.get_non_empty_box_indices(boxes)
+ boxes = tf.gather(boxes, indices)
+ classes = tf.gather(classes, indices)
+ if self._include_mask:
+ masks = tf.gather(masks, indices)
+ # Transfer boxes to the original image space and do normalization.
+ cropped_boxes = boxes + tf.tile(tf.expand_dims(offset, axis=0), [1, 2])
+ cropped_boxes /= tf.tile(tf.expand_dims(image_scale, axis=0), [1, 2])
+ cropped_boxes = box_utils.normalize_boxes(cropped_boxes, image_shape)
+ num_masks = tf.shape(masks)[0]
+ masks = tf.image.crop_and_resize(
+ tf.expand_dims(masks, axis=-1),
+ cropped_boxes,
+ box_indices=tf.range(num_masks, dtype=tf.int32),
+ crop_size=[self._mask_crop_size, self._mask_crop_size],
+ method='bilinear')
+ masks = tf.squeeze(masks, axis=-1)
+
+ # Assigns anchor targets.
+ # Note that after the target assignment, box targets are absolute pixel
+ # offsets w.r.t. the scaled image.
+ input_anchor = anchor.Anchor(
+ self._min_level,
+ self._max_level,
+ self._num_scales,
+ self._aspect_ratios,
+ self._anchor_size,
+ (image_height, image_width))
+ anchor_labeler = anchor.RpnAnchorLabeler(
+ input_anchor,
+ self._rpn_match_threshold,
+ self._rpn_unmatched_threshold,
+ self._rpn_batch_size_per_im,
+ self._rpn_fg_fraction)
+ rpn_score_targets, rpn_box_targets = anchor_labeler.label_anchors(
+ boxes, tf.cast(tf.expand_dims(classes, axis=-1), dtype=tf.float32))
+
+ # If bfloat16 is used, casts input image to tf.bfloat16.
+ if self._use_bfloat16:
+ image = tf.cast(image, dtype=tf.bfloat16)
+
+ inputs = {
+ 'image': image,
+ 'image_info': image_info,
+ }
+ # Packs labels for model_fn outputs.
+ labels = {
+ 'anchor_boxes': input_anchor.multilevel_boxes,
+ 'image_info': image_info,
+ 'rpn_score_targets': rpn_score_targets,
+ 'rpn_box_targets': rpn_box_targets,
+ }
+ inputs['gt_boxes'] = input_utils.pad_to_fixed_size(boxes,
+ self._max_num_instances,
+ -1)
+ inputs['gt_classes'] = input_utils.pad_to_fixed_size(
+ classes, self._max_num_instances, -1)
+ if self._include_mask:
+ inputs['gt_masks'] = input_utils.pad_to_fixed_size(
+ masks, self._max_num_instances, -1)
+
+ return inputs, labels
+
+ def _parse_eval_data(self, data):
+ """Parses data for evaluation."""
+ raise NotImplementedError('Not implemented!')
+
+ def _parse_predict_data(self, data):
+ """Parses data for prediction.
+
+ Args:
+ data: the decoded tensor dictionary from TfExampleDecoder.
+
+ Returns:
+ A dictionary of {'images': image, 'labels': labels} where
+ image: image tensor that is preproessed to have normalized value and
+ dimension [output_size[0], output_size[1], 3]
+ labels: a dictionary of tensors used for training. The following
+ describes {key: value} pairs in the dictionary.
+ source_ids: Source image id. Default value -1 if the source id is
+ empty in the groundtruth annotation.
+ image_info: a 2D `Tensor` that encodes the information of the image
+ and the applied preprocessing. It is in the format of
+ [[original_height, original_width], [scaled_height, scaled_width],
+ anchor_boxes: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, 4] representing anchor boxes at each
+ level.
+ """
+ # Gets original image and its size.
+ image = data['image']
+ image_shape = tf.shape(image)[0:2]
+
+ # Normalizes image with mean and std pixel values.
+ image = input_utils.normalize_image(image)
+
+ # Resizes and crops image.
+ image, image_info = input_utils.resize_and_crop_image(
+ image,
+ self._output_size,
+ padded_size=input_utils.compute_padded_size(
+ self._output_size, 2 ** self._max_level),
+ aug_scale_min=1.0,
+ aug_scale_max=1.0)
+ image_height, image_width, _ = image.get_shape().as_list()
+
+ # If bfloat16 is used, casts input image to tf.bfloat16.
+ if self._use_bfloat16:
+ image = tf.cast(image, dtype=tf.bfloat16)
+
+ # Compute Anchor boxes.
+ _ = anchor.Anchor(self._min_level, self._max_level, self._num_scales,
+ self._aspect_ratios, self._anchor_size,
+ (image_height, image_width))
+
+ labels = {
+ 'image_info': image_info,
+ }
+
+ if self._mode == ModeKeys.PREDICT_WITH_GT:
+ # Converts boxes from normalized coordinates to pixel coordinates.
+ boxes = box_utils.denormalize_boxes(
+ data['groundtruth_boxes'], image_shape)
+ groundtruths = {
+ 'source_id': data['source_id'],
+ 'height': data['height'],
+ 'width': data['width'],
+ 'num_detections': tf.shape(data['groundtruth_classes']),
+ 'boxes': boxes,
+ 'classes': data['groundtruth_classes'],
+ 'areas': data['groundtruth_area'],
+ 'is_crowds': tf.cast(data['groundtruth_is_crowd'], tf.int32),
+ }
+ groundtruths['source_id'] = dataloader_utils.process_source_id(
+ groundtruths['source_id'])
+ groundtruths = dataloader_utils.pad_groundtruths_to_fixed_size(
+ groundtruths, self._max_num_instances)
+ # TODO(yeqing): Remove the `groundtrtuh` layer key (no longer needed).
+ labels['groundtruths'] = groundtruths
+ inputs = {
+ 'image': image,
+ 'image_info': image_info,
+ }
+
+ return inputs, labels
diff --git a/modeling/official/legacy/detection/dataloader/mode_keys.py b/modeling/official/legacy/detection/dataloader/mode_keys.py
new file mode 100644
index 0000000000000000000000000000000000000000..35d3c1a294024498b2bf2f652289b79cec4af7c1
--- /dev/null
+++ b/modeling/official/legacy/detection/dataloader/mode_keys.py
@@ -0,0 +1,33 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Standard names for input dataloader modes.
+
+The following standard keys are defined:
+
+* `TRAIN`: training mode.
+* `EVAL`: evaluation mode.
+* `PREDICT`: prediction mode.
+* `PREDICT_WITH_GT`: prediction mode with groundtruths in returned variables.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+TRAIN = 'train'
+EVAL = 'eval'
+PREDICT = 'predict'
+PREDICT_WITH_GT = 'predict_with_gt'
diff --git a/modeling/official/legacy/detection/dataloader/olnmask_parser.py b/modeling/official/legacy/detection/dataloader/olnmask_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b8e9bcebcb62fb0d35291ecd88743a6b993a6d1
--- /dev/null
+++ b/modeling/official/legacy/detection/dataloader/olnmask_parser.py
@@ -0,0 +1,327 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Data parser and processing for Mask R-CNN."""
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.dataloader import anchor
+from official.legacy.detection.dataloader.maskrcnn_parser import Parser as MaskrcnnParser
+from official.legacy.detection.utils import box_utils
+from official.legacy.detection.utils import class_utils
+from official.legacy.detection.utils import input_utils
+
+
+class Parser(MaskrcnnParser):
+ """Parser to parse an image and its annotations into a dictionary of tensors."""
+
+ def __init__(self,
+ output_size,
+ min_level,
+ max_level,
+ num_scales,
+ aspect_ratios,
+ anchor_size,
+ rpn_match_threshold=0.7,
+ rpn_unmatched_threshold=0.3,
+ rpn_batch_size_per_im=256,
+ rpn_fg_fraction=0.5,
+ aug_rand_hflip=False,
+ aug_scale_min=1.0,
+ aug_scale_max=1.0,
+ skip_crowd_during_training=True,
+ max_num_instances=100,
+ include_mask=False,
+ mask_crop_size=112,
+ use_bfloat16=True,
+ mode=None,
+ # for centerness learning.
+ has_centerness=False,
+ rpn_center_match_iou_threshold=0.3,
+ rpn_center_unmatched_iou_threshold=0.1,
+ rpn_num_center_samples_per_im=256,
+ # for class manipulation.
+ class_agnostic=False,
+ train_class='all',
+ ):
+ """Initializes parameters for parsing annotations in the dataset.
+
+ Args:
+ output_size: `Tensor` or `list` for [height, width] of output image. The
+ output_size should be divided by the largest feature stride 2^max_level.
+ min_level: `int` number of minimum level of the output feature pyramid.
+ max_level: `int` number of maximum level of the output feature pyramid.
+ num_scales: `int` number representing intermediate scales added
+ on each level. For instances, num_scales=2 adds one additional
+ intermediate anchor scales [2^0, 2^0.5] on each level.
+ aspect_ratios: `list` of float numbers representing the aspect raito
+ anchors added on each level. The number indicates the ratio of width to
+ height. For instances, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors
+ on each scale level.
+ anchor_size: `float` number representing the scale of size of the base
+ anchor to the feature stride 2^level.
+ rpn_match_threshold:
+ rpn_unmatched_threshold:
+ rpn_batch_size_per_im:
+ rpn_fg_fraction:
+ aug_rand_hflip: `bool`, if True, augment training with random
+ horizontal flip.
+ aug_scale_min: `float`, the minimum scale applied to `output_size` for
+ data augmentation during training.
+ aug_scale_max: `float`, the maximum scale applied to `output_size` for
+ data augmentation during training.
+ skip_crowd_during_training: `bool`, if True, skip annotations labeled with
+ `is_crowd` equals to 1.
+ max_num_instances: `int` number of maximum number of instances in an
+ image. The groundtruth data will be padded to `max_num_instances`.
+ include_mask: a bool to indicate whether parse mask groundtruth.
+ mask_crop_size: the size which groundtruth mask is cropped to.
+ use_bfloat16: `bool`, if True, cast output image to tf.bfloat16.
+ mode: a ModeKeys. Specifies if this is training, evaluation, prediction
+ or prediction with groundtruths in the outputs.
+ has_centerness: whether to create centerness targets
+ rpn_center_match_iou_threshold: iou threshold for valid centerness samples
+ ,set to 0.3 by default.
+ rpn_center_unmatched_iou_threshold: iou threshold for invalid centerness
+ samples, set to 0.1 by default.
+ rpn_num_center_samples_per_im: number of centerness samples per image,
+ 256 by default.
+ class_agnostic: whether to merge class ids into one foreground(=1) class,
+ False by default.
+ train_class: 'all' or 'voc' or 'nonvoc', 'all' by default.
+ """
+ super(Parser, self).__init__(
+ output_size=output_size,
+ min_level=min_level,
+ max_level=max_level,
+ num_scales=num_scales,
+ aspect_ratios=aspect_ratios,
+ anchor_size=anchor_size,
+ rpn_match_threshold=rpn_match_threshold,
+ rpn_unmatched_threshold=rpn_unmatched_threshold,
+ rpn_batch_size_per_im=rpn_batch_size_per_im,
+ rpn_fg_fraction=rpn_fg_fraction,
+ aug_rand_hflip=aug_rand_hflip,
+ aug_scale_min=aug_scale_min,
+ aug_scale_max=aug_scale_max,
+ skip_crowd_during_training=skip_crowd_during_training,
+ max_num_instances=max_num_instances,
+ include_mask=include_mask,
+ mask_crop_size=mask_crop_size,
+ use_bfloat16=use_bfloat16,
+ mode=mode,)
+
+ # Centerness target assigning.
+ self._has_centerness = has_centerness
+ self._rpn_center_match_iou_threshold = rpn_center_match_iou_threshold
+ self._rpn_center_unmatched_iou_threshold = (
+ rpn_center_unmatched_iou_threshold)
+ self._rpn_num_center_samples_per_im = rpn_num_center_samples_per_im
+
+ # Class manipulation.
+ self._class_agnostic = class_agnostic
+ self._train_class = train_class
+
+ def _parse_train_data(self, data):
+ """Parses data for training.
+
+ Args:
+ data: the decoded tensor dictionary from TfExampleDecoder.
+
+ Returns:
+ image: image tensor that is preproessed to have normalized value and
+ dimension [output_size[0], output_size[1], 3]
+ labels: a dictionary of tensors used for training. The following describes
+ {key: value} pairs in the dictionary.
+ image_info: a 2D `Tensor` that encodes the information of the image and
+ the applied preprocessing. It is in the format of
+ [[original_height, original_width], [scaled_height, scaled_width],
+ anchor_boxes: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, 4] representing anchor boxes at each level.
+ rpn_score_targets: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, anchors_per_location]. The height_l and
+ width_l represent the dimension of class logits at l-th level.
+ rpn_box_targets: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, anchors_per_location * 4]. The height_l and
+ width_l represent the dimension of bounding box regression output at
+ l-th level.
+ gt_boxes: Groundtruth bounding box annotations. The box is represented
+ in [y1, x1, y2, x2] format. The coordinates are w.r.t the scaled
+ image that is fed to the network. The tennsor is padded with -1 to
+ the fixed dimension [self._max_num_instances, 4].
+ gt_classes: Groundtruth classes annotations. The tennsor is padded
+ with -1 to the fixed dimension [self._max_num_instances].
+ gt_masks: groundtrugh masks cropped by the bounding box and
+ resized to a fixed size determined by mask_crop_size.
+ """
+ classes = data['groundtruth_classes']
+ boxes = data['groundtruth_boxes']
+ if self._include_mask:
+ masks = data['groundtruth_instance_masks']
+
+ is_crowds = data['groundtruth_is_crowd']
+ # Skips annotations with `is_crowd` = True.
+ if self._skip_crowd_during_training and self._is_training:
+ num_groundtruths = tf.shape(classes)[0]
+ with tf.control_dependencies([num_groundtruths, is_crowds]):
+ indices = tf.cond(
+ tf.greater(tf.size(is_crowds), 0),
+ lambda: tf.where(tf.logical_not(is_crowds))[:, 0],
+ lambda: tf.cast(tf.range(num_groundtruths), tf.int64))
+ classes = tf.gather(classes, indices)
+ boxes = tf.gather(boxes, indices)
+ if self._include_mask:
+ masks = tf.gather(masks, indices)
+
+ # Gets original image and its size.
+ image = data['image']
+ image_shape = tf.shape(image)[0:2]
+
+ # Normalizes image with mean and std pixel values.
+ image = input_utils.normalize_image(image)
+
+ # Flips image randomly during training.
+ if self._aug_rand_hflip:
+ if self._include_mask:
+ image, boxes, masks = input_utils.random_horizontal_flip(
+ image, boxes, masks)
+ else:
+ image, boxes = input_utils.random_horizontal_flip(
+ image, boxes)
+
+ # Converts boxes from normalized coordinates to pixel coordinates.
+ # Now the coordinates of boxes are w.r.t. the original image.
+ boxes = box_utils.denormalize_boxes(boxes, image_shape)
+
+ # Resizes and crops image.
+ image, image_info = input_utils.resize_and_crop_image(
+ image,
+ self._output_size,
+ padded_size=input_utils.compute_padded_size(
+ self._output_size, 2 ** self._max_level),
+ aug_scale_min=self._aug_scale_min,
+ aug_scale_max=self._aug_scale_max)
+ image_height, image_width, _ = image.get_shape().as_list()
+
+ # Resizes and crops boxes.
+ # Now the coordinates of boxes are w.r.t the scaled image.
+ image_scale = image_info[2, :]
+ offset = image_info[3, :]
+ boxes = input_utils.resize_and_crop_boxes(
+ boxes, image_scale, image_info[1, :], offset)
+
+ # Filters out ground truth boxes that are all zeros.
+ indices = box_utils.get_non_empty_box_indices(boxes)
+ boxes = tf.gather(boxes, indices)
+ classes = tf.gather(classes, indices)
+ if self._include_mask:
+ masks = tf.gather(masks, indices)
+ # Transfer boxes to the original image space and do normalization.
+ cropped_boxes = boxes + tf.tile(tf.expand_dims(offset, axis=0), [1, 2])
+ cropped_boxes /= tf.tile(tf.expand_dims(image_scale, axis=0), [1, 2])
+ cropped_boxes = box_utils.normalize_boxes(cropped_boxes, image_shape)
+ num_masks = tf.shape(masks)[0]
+ masks = tf.image.crop_and_resize(
+ tf.expand_dims(masks, axis=-1),
+ cropped_boxes,
+ box_indices=tf.range(num_masks, dtype=tf.int32),
+ crop_size=[self._mask_crop_size, self._mask_crop_size],
+ method='bilinear')
+ masks = tf.squeeze(masks, axis=-1)
+
+ # Class manipulation.
+ # Filter out novel split classes from training.
+ if self._train_class != 'all':
+ valid_classes = tf.cast(
+ class_utils.coco_split_class_ids(self._train_class),
+ dtype=classes.dtype)
+ match = tf.reduce_any(tf.equal(
+ tf.expand_dims(valid_classes, 1),
+ tf.expand_dims(classes, 0)), 0)
+ # kill novel split classes and boxes.
+ boxes = tf.gather(boxes, tf.where(match)[:, 0])
+ classes = tf.gather(classes, tf.where(match)[:, 0])
+ if self._include_mask:
+ masks = tf.gather(masks, tf.where(match)[:, 0])
+
+ # Assigns anchor targets.
+ # Note that after the target assignment, box targets are absolute pixel
+ # offsets w.r.t. the scaled image.
+ input_anchor = anchor.Anchor(
+ self._min_level,
+ self._max_level,
+ self._num_scales,
+ self._aspect_ratios,
+ self._anchor_size,
+ (image_height, image_width))
+ anchor_labeler = anchor.OlnAnchorLabeler(
+ input_anchor,
+ self._rpn_match_threshold,
+ self._rpn_unmatched_threshold,
+ self._rpn_batch_size_per_im,
+ self._rpn_fg_fraction,
+ # for centerness target.
+ self._has_centerness,
+ self._rpn_center_match_iou_threshold,
+ self._rpn_center_unmatched_iou_threshold,
+ self._rpn_num_center_samples_per_im,)
+
+ if self._has_centerness:
+ rpn_score_targets, _, rpn_lrtb_targets, rpn_center_targets = (
+ anchor_labeler.label_anchors_lrtb(
+ gt_boxes=boxes,
+ gt_labels=tf.cast(
+ tf.expand_dims(classes, axis=-1), dtype=tf.float32)))
+ else:
+ rpn_score_targets, rpn_box_targets = anchor_labeler.label_anchors(
+ boxes, tf.cast(tf.expand_dims(classes, axis=-1), dtype=tf.float32))
+ # For base rpn, dummy placeholder for centerness target.
+ rpn_center_targets = rpn_score_targets.copy()
+
+ # If bfloat16 is used, casts input image to tf.bfloat16.
+ if self._use_bfloat16:
+ image = tf.cast(image, dtype=tf.bfloat16)
+
+ inputs = {
+ 'image': image,
+ 'image_info': image_info,
+ }
+ # Packs labels for model_fn outputs.
+ labels = {
+ 'anchor_boxes': input_anchor.multilevel_boxes,
+ 'image_info': image_info,
+ 'rpn_score_targets': rpn_score_targets,
+ 'rpn_box_targets': (rpn_lrtb_targets if self._has_centerness
+ else rpn_box_targets),
+ 'rpn_center_targets': rpn_center_targets,
+ }
+ # If class_agnostic, convert to binary classes.
+ if self._class_agnostic:
+ classes = tf.where(tf.greater(classes, 0),
+ tf.ones_like(classes),
+ tf.zeros_like(classes))
+
+ inputs['gt_boxes'] = input_utils.pad_to_fixed_size(boxes,
+ self._max_num_instances,
+ -1)
+ inputs['gt_classes'] = input_utils.pad_to_fixed_size(
+ classes, self._max_num_instances, -1)
+ if self._include_mask:
+ inputs['gt_masks'] = input_utils.pad_to_fixed_size(
+ masks, self._max_num_instances, -1)
+
+ return inputs, labels
diff --git a/modeling/official/legacy/detection/dataloader/retinanet_parser.py b/modeling/official/legacy/detection/dataloader/retinanet_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..adcd53a5b72a9acc882c3a09703e697f6133973b
--- /dev/null
+++ b/modeling/official/legacy/detection/dataloader/retinanet_parser.py
@@ -0,0 +1,425 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Data parser and processing.
+
+Parse image and ground truths in a dataset to training targets and package them
+into (image, labels) tuple for RetinaNet.
+
+T.-Y. Lin, P. Goyal, R. Girshick, K. He, and P. Dollar
+Focal Loss for Dense Object Detection. arXiv:1708.02002
+"""
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.dataloader import anchor
+from official.legacy.detection.dataloader import mode_keys as ModeKeys
+from official.legacy.detection.dataloader import tf_example_decoder
+from official.legacy.detection.utils import box_utils
+from official.legacy.detection.utils import input_utils
+
+
+def process_source_id(source_id):
+ """Processes source_id to the right format."""
+ if source_id.dtype == tf.string:
+ source_id = tf.cast(tf.strings.to_number(source_id), tf.int32)
+ with tf.control_dependencies([source_id]):
+ source_id = tf.cond(
+ pred=tf.equal(tf.size(input=source_id), 0),
+ true_fn=lambda: tf.cast(tf.constant(-1), tf.int32),
+ false_fn=lambda: tf.identity(source_id))
+ return source_id
+
+
+def pad_groundtruths_to_fixed_size(gt, n):
+ """Pads the first dimension of groundtruths labels to the fixed size."""
+ gt['boxes'] = input_utils.pad_to_fixed_size(gt['boxes'], n, -1)
+ gt['is_crowds'] = input_utils.pad_to_fixed_size(gt['is_crowds'], n, 0)
+ gt['areas'] = input_utils.pad_to_fixed_size(gt['areas'], n, -1)
+ gt['classes'] = input_utils.pad_to_fixed_size(gt['classes'], n, -1)
+ return gt
+
+
+class Parser(object):
+ """Parser to parse an image and its annotations into a dictionary of tensors."""
+
+ def __init__(self,
+ output_size,
+ min_level,
+ max_level,
+ num_scales,
+ aspect_ratios,
+ anchor_size,
+ match_threshold=0.5,
+ unmatched_threshold=0.5,
+ aug_rand_hflip=False,
+ aug_scale_min=1.0,
+ aug_scale_max=1.0,
+ use_autoaugment=False,
+ autoaugment_policy_name='v0',
+ skip_crowd_during_training=True,
+ max_num_instances=100,
+ use_bfloat16=True,
+ mode=None):
+ """Initializes parameters for parsing annotations in the dataset.
+
+ Args:
+ output_size: `Tensor` or `list` for [height, width] of output image. The
+ output_size should be divided by the largest feature stride 2^max_level.
+ min_level: `int` number of minimum level of the output feature pyramid.
+ max_level: `int` number of maximum level of the output feature pyramid.
+ num_scales: `int` number representing intermediate scales added on each
+ level. For instances, num_scales=2 adds one additional intermediate
+ anchor scales [2^0, 2^0.5] on each level.
+ aspect_ratios: `list` of float numbers representing the aspect raito
+ anchors added on each level. The number indicates the ratio of width to
+ height. For instances, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors
+ on each scale level.
+ anchor_size: `float` number representing the scale of size of the base
+ anchor to the feature stride 2^level.
+ match_threshold: `float` number between 0 and 1 representing the
+ lower-bound threshold to assign positive labels for anchors. An anchor
+ with a score over the threshold is labeled positive.
+ unmatched_threshold: `float` number between 0 and 1 representing the
+ upper-bound threshold to assign negative labels for anchors. An anchor
+ with a score below the threshold is labeled negative.
+ aug_rand_hflip: `bool`, if True, augment training with random horizontal
+ flip.
+ aug_scale_min: `float`, the minimum scale applied to `output_size` for
+ data augmentation during training.
+ aug_scale_max: `float`, the maximum scale applied to `output_size` for
+ data augmentation during training.
+ use_autoaugment: `bool`, if True, use the AutoAugment augmentation policy
+ during training.
+ autoaugment_policy_name: `string` that specifies the name of the
+ AutoAugment policy that will be used during training.
+ skip_crowd_during_training: `bool`, if True, skip annotations labeled with
+ `is_crowd` equals to 1.
+ max_num_instances: `int` number of maximum number of instances in an
+ image. The groundtruth data will be padded to `max_num_instances`.
+ use_bfloat16: `bool`, if True, cast output image to tf.bfloat16.
+ mode: a ModeKeys. Specifies if this is training, evaluation, prediction or
+ prediction with groundtruths in the outputs.
+ """
+ self._mode = mode
+ self._max_num_instances = max_num_instances
+ self._skip_crowd_during_training = skip_crowd_during_training
+ self._is_training = (mode == ModeKeys.TRAIN)
+
+ self._example_decoder = tf_example_decoder.TfExampleDecoder(
+ include_mask=False)
+
+ # Anchor.
+ self._output_size = output_size
+ self._min_level = min_level
+ self._max_level = max_level
+ self._num_scales = num_scales
+ self._aspect_ratios = aspect_ratios
+ self._anchor_size = anchor_size
+ self._match_threshold = match_threshold
+ self._unmatched_threshold = unmatched_threshold
+
+ # Data augmentation.
+ self._aug_rand_hflip = aug_rand_hflip
+ self._aug_scale_min = aug_scale_min
+ self._aug_scale_max = aug_scale_max
+
+ # Data Augmentation with AutoAugment.
+ self._use_autoaugment = use_autoaugment
+ self._autoaugment_policy_name = autoaugment_policy_name
+
+ # Device.
+ self._use_bfloat16 = use_bfloat16
+
+ # Data is parsed depending on the model Modekey.
+ if mode == ModeKeys.TRAIN:
+ self._parse_fn = self._parse_train_data
+ elif mode == ModeKeys.EVAL:
+ self._parse_fn = self._parse_eval_data
+ elif mode == ModeKeys.PREDICT or mode == ModeKeys.PREDICT_WITH_GT:
+ self._parse_fn = self._parse_predict_data
+ else:
+ raise ValueError('mode is not defined.')
+
+ def __call__(self, value):
+ """Parses data to an image and associated training labels.
+
+ Args:
+ value: a string tensor holding a serialized tf.Example proto.
+
+ Returns:
+ image: image tensor that is preproessed to have normalized value and
+ dimension [output_size[0], output_size[1], 3]
+ labels:
+ cls_targets: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, anchors_per_location]. The height_l and
+ width_l represent the dimension of class logits at l-th level.
+ box_targets: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, anchors_per_location * 4]. The height_l and
+ width_l represent the dimension of bounding box regression output at
+ l-th level.
+ num_positives: number of positive anchors in the image.
+ anchor_boxes: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, 4] representing anchor boxes at each level.
+ image_info: a 2D `Tensor` that encodes the information of the image and
+ the applied preprocessing. It is in the format of
+ [[original_height, original_width], [scaled_height, scaled_width],
+ [y_scale, x_scale], [y_offset, x_offset]].
+ groundtruths:
+ source_id: source image id. Default value -1 if the source id is empty
+ in the groundtruth annotation.
+ boxes: groundtruth bounding box annotations. The box is represented in
+ [y1, x1, y2, x2] format. The tennsor is padded with -1 to the fixed
+ dimension [self._max_num_instances, 4].
+ classes: groundtruth classes annotations. The tennsor is padded with
+ -1 to the fixed dimension [self._max_num_instances].
+ areas: groundtruth areas annotations. The tennsor is padded with -1
+ to the fixed dimension [self._max_num_instances].
+ is_crowds: groundtruth annotations to indicate if an annotation
+ represents a group of instances by value {0, 1}. The tennsor is
+ padded with 0 to the fixed dimension [self._max_num_instances].
+ """
+ with tf.name_scope('parser'):
+ data = self._example_decoder.decode(value)
+ return self._parse_fn(data)
+
+ def _parse_train_data(self, data):
+ """Parses data for training and evaluation."""
+ classes = data['groundtruth_classes']
+ boxes = data['groundtruth_boxes']
+ is_crowds = data['groundtruth_is_crowd']
+ # Skips annotations with `is_crowd` = True.
+ if self._skip_crowd_during_training and self._is_training:
+ num_groundtrtuhs = tf.shape(input=classes)[0]
+ with tf.control_dependencies([num_groundtrtuhs, is_crowds]):
+ indices = tf.cond(
+ pred=tf.greater(tf.size(input=is_crowds), 0),
+ true_fn=lambda: tf.where(tf.logical_not(is_crowds))[:, 0],
+ false_fn=lambda: tf.cast(tf.range(num_groundtrtuhs), tf.int64))
+ classes = tf.gather(classes, indices)
+ boxes = tf.gather(boxes, indices)
+
+ # Gets original image and its size.
+ image = data['image']
+
+ image_shape = tf.shape(input=image)[0:2]
+
+ # Normalizes image with mean and std pixel values.
+ image = input_utils.normalize_image(image)
+
+ # Flips image randomly during training.
+ if self._aug_rand_hflip:
+ image, boxes = input_utils.random_horizontal_flip(image, boxes)
+
+ # Converts boxes from normalized coordinates to pixel coordinates.
+ boxes = box_utils.denormalize_boxes(boxes, image_shape)
+
+ # Resizes and crops image.
+ image, image_info = input_utils.resize_and_crop_image(
+ image,
+ self._output_size,
+ padded_size=input_utils.compute_padded_size(self._output_size,
+ 2**self._max_level),
+ aug_scale_min=self._aug_scale_min,
+ aug_scale_max=self._aug_scale_max)
+ image_height, image_width, _ = image.get_shape().as_list()
+
+ # Resizes and crops boxes.
+ image_scale = image_info[2, :]
+ offset = image_info[3, :]
+ boxes = input_utils.resize_and_crop_boxes(boxes, image_scale,
+ image_info[1, :], offset)
+ # Filters out ground truth boxes that are all zeros.
+ indices = box_utils.get_non_empty_box_indices(boxes)
+ boxes = tf.gather(boxes, indices)
+ classes = tf.gather(classes, indices)
+
+ # Assigns anchors.
+ input_anchor = anchor.Anchor(self._min_level, self._max_level,
+ self._num_scales, self._aspect_ratios,
+ self._anchor_size, (image_height, image_width))
+ anchor_labeler = anchor.AnchorLabeler(input_anchor, self._match_threshold,
+ self._unmatched_threshold)
+ (cls_targets, box_targets, num_positives) = anchor_labeler.label_anchors(
+ boxes, tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
+
+ # If bfloat16 is used, casts input image to tf.bfloat16.
+ if self._use_bfloat16:
+ image = tf.cast(image, dtype=tf.bfloat16)
+
+ # Packs labels for model_fn outputs.
+ labels = {
+ 'cls_targets': cls_targets,
+ 'box_targets': box_targets,
+ 'anchor_boxes': input_anchor.multilevel_boxes,
+ 'num_positives': num_positives,
+ 'image_info': image_info,
+ }
+ return image, labels
+
+ def _parse_eval_data(self, data):
+ """Parses data for training and evaluation."""
+ groundtruths = {}
+ classes = data['groundtruth_classes']
+ boxes = data['groundtruth_boxes']
+
+ # Gets original image and its size.
+ image = data['image']
+ image_shape = tf.shape(input=image)[0:2]
+
+ # Normalizes image with mean and std pixel values.
+ image = input_utils.normalize_image(image)
+
+ # Converts boxes from normalized coordinates to pixel coordinates.
+ boxes = box_utils.denormalize_boxes(boxes, image_shape)
+
+ # Resizes and crops image.
+ image, image_info = input_utils.resize_and_crop_image(
+ image,
+ self._output_size,
+ padded_size=input_utils.compute_padded_size(self._output_size,
+ 2**self._max_level),
+ aug_scale_min=1.0,
+ aug_scale_max=1.0)
+ image_height, image_width, _ = image.get_shape().as_list()
+
+ # Resizes and crops boxes.
+ image_scale = image_info[2, :]
+ offset = image_info[3, :]
+ boxes = input_utils.resize_and_crop_boxes(boxes, image_scale,
+ image_info[1, :], offset)
+ # Filters out ground truth boxes that are all zeros.
+ indices = box_utils.get_non_empty_box_indices(boxes)
+ boxes = tf.gather(boxes, indices)
+ classes = tf.gather(classes, indices)
+
+ # Assigns anchors.
+ input_anchor = anchor.Anchor(self._min_level, self._max_level,
+ self._num_scales, self._aspect_ratios,
+ self._anchor_size, (image_height, image_width))
+ anchor_labeler = anchor.AnchorLabeler(input_anchor, self._match_threshold,
+ self._unmatched_threshold)
+ (cls_targets, box_targets, num_positives) = anchor_labeler.label_anchors(
+ boxes, tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
+
+ # If bfloat16 is used, casts input image to tf.bfloat16.
+ if self._use_bfloat16:
+ image = tf.cast(image, dtype=tf.bfloat16)
+
+ # Sets up groundtruth data for evaluation.
+ groundtruths = {
+ 'source_id':
+ data['source_id'],
+ 'num_groundtrtuhs':
+ tf.shape(data['groundtruth_classes']),
+ 'image_info':
+ image_info,
+ 'boxes':
+ box_utils.denormalize_boxes(data['groundtruth_boxes'], image_shape),
+ 'classes':
+ data['groundtruth_classes'],
+ 'areas':
+ data['groundtruth_area'],
+ 'is_crowds':
+ tf.cast(data['groundtruth_is_crowd'], tf.int32),
+ }
+ groundtruths['source_id'] = process_source_id(groundtruths['source_id'])
+ groundtruths = pad_groundtruths_to_fixed_size(groundtruths,
+ self._max_num_instances)
+
+ # Packs labels for model_fn outputs.
+ labels = {
+ 'cls_targets': cls_targets,
+ 'box_targets': box_targets,
+ 'anchor_boxes': input_anchor.multilevel_boxes,
+ 'num_positives': num_positives,
+ 'image_info': image_info,
+ 'groundtruths': groundtruths,
+ }
+ return image, labels
+
+ def _parse_predict_data(self, data):
+ """Parses data for prediction."""
+ # Gets original image and its size.
+ image = data['image']
+ image_shape = tf.shape(input=image)[0:2]
+
+ # Normalizes image with mean and std pixel values.
+ image = input_utils.normalize_image(image)
+
+ # Resizes and crops image.
+ image, image_info = input_utils.resize_and_crop_image(
+ image,
+ self._output_size,
+ padded_size=input_utils.compute_padded_size(self._output_size,
+ 2**self._max_level),
+ aug_scale_min=1.0,
+ aug_scale_max=1.0)
+ image_height, image_width, _ = image.get_shape().as_list()
+
+ # If bfloat16 is used, casts input image to tf.bfloat16.
+ if self._use_bfloat16:
+ image = tf.cast(image, dtype=tf.bfloat16)
+
+ # Compute Anchor boxes.
+ input_anchor = anchor.Anchor(self._min_level, self._max_level,
+ self._num_scales, self._aspect_ratios,
+ self._anchor_size, (image_height, image_width))
+
+ labels = {
+ 'anchor_boxes': input_anchor.multilevel_boxes,
+ 'image_info': image_info,
+ }
+ # If mode is PREDICT_WITH_GT, returns groundtruths and training targets
+ # in labels.
+ if self._mode == ModeKeys.PREDICT_WITH_GT:
+ # Converts boxes from normalized coordinates to pixel coordinates.
+ boxes = box_utils.denormalize_boxes(data['groundtruth_boxes'],
+ image_shape)
+ groundtruths = {
+ 'source_id': data['source_id'],
+ 'num_detections': tf.shape(data['groundtruth_classes']),
+ 'boxes': boxes,
+ 'classes': data['groundtruth_classes'],
+ 'areas': data['groundtruth_area'],
+ 'is_crowds': tf.cast(data['groundtruth_is_crowd'], tf.int32),
+ }
+ groundtruths['source_id'] = process_source_id(groundtruths['source_id'])
+ groundtruths = pad_groundtruths_to_fixed_size(groundtruths,
+ self._max_num_instances)
+ labels['groundtruths'] = groundtruths
+
+ # Computes training objective for evaluation loss.
+ classes = data['groundtruth_classes']
+
+ image_scale = image_info[2, :]
+ offset = image_info[3, :]
+ boxes = input_utils.resize_and_crop_boxes(boxes, image_scale,
+ image_info[1, :], offset)
+ # Filters out ground truth boxes that are all zeros.
+ indices = box_utils.get_non_empty_box_indices(boxes)
+ boxes = tf.gather(boxes, indices)
+
+ # Assigns anchors.
+ anchor_labeler = anchor.AnchorLabeler(input_anchor, self._match_threshold,
+ self._unmatched_threshold)
+ (cls_targets, box_targets, num_positives) = anchor_labeler.label_anchors(
+ boxes, tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
+ labels['cls_targets'] = cls_targets
+ labels['box_targets'] = box_targets
+ labels['num_positives'] = num_positives
+ return image, labels
diff --git a/modeling/official/legacy/detection/dataloader/shapemask_parser.py b/modeling/official/legacy/detection/dataloader/shapemask_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e1e82777fddfa4ea11ed5ff903bb54eb09ec385
--- /dev/null
+++ b/modeling/official/legacy/detection/dataloader/shapemask_parser.py
@@ -0,0 +1,521 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Data parser and processing.
+
+Parse image and ground truths in a dataset to training targets and package them
+into (image, labels) tuple for ShapeMask.
+
+Weicheng Kuo, Anelia Angelova, Jitendra Malik, Tsung-Yi Lin
+ShapeMask: Learning to Segment Novel Objects by Refining Shape Priors.
+arXiv:1904.03239.
+"""
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.dataloader import anchor
+from official.legacy.detection.dataloader import mode_keys as ModeKeys
+from official.legacy.detection.dataloader import tf_example_decoder
+from official.legacy.detection.utils import box_utils
+from official.legacy.detection.utils import class_utils
+from official.legacy.detection.utils import dataloader_utils
+from official.legacy.detection.utils import input_utils
+
+
+def pad_to_size(input_tensor, size):
+ """Pads data with zeros to a given length at the first dimension if needed.
+
+ Args:
+ input_tensor: `Tensor` with any dimension.
+ size: `int` number for the first dimension of output Tensor.
+
+ Returns:
+ `Tensor` with the first dimension padded to `size` if the first diemsion
+ is less than `size`, otherwise no padding.
+ """
+ input_shape = tf.shape(input_tensor)
+ padding_shape = []
+
+ # Computes the padding length on the first dimension.
+ padding_length = tf.maximum(0, size - tf.shape(input_tensor)[0])
+ assert_length = tf.Assert(
+ tf.greater_equal(padding_length, 0), [padding_length])
+ with tf.control_dependencies([assert_length]):
+ padding_shape.append(padding_length)
+
+ # Copies shapes of the rest of input shape dimensions.
+ for i in range(1, len(input_shape)):
+ padding_shape.append(tf.shape(input=input_tensor)[i])
+
+ # Pads input tensor to the fixed first dimension.
+ paddings = tf.cast(tf.zeros(padding_shape), input_tensor.dtype)
+ padded_tensor = tf.concat([input_tensor, paddings], axis=0)
+ return padded_tensor
+
+
+class Parser(object):
+ """ShapeMask Parser to parse an image and its annotations into a dictionary of tensors."""
+
+ def __init__(self,
+ output_size,
+ min_level,
+ max_level,
+ num_scales,
+ aspect_ratios,
+ anchor_size,
+ use_category=True,
+ outer_box_scale=1.0,
+ box_jitter_scale=0.025,
+ num_sampled_masks=8,
+ mask_crop_size=32,
+ mask_min_level=3,
+ mask_max_level=5,
+ upsample_factor=4,
+ match_threshold=0.5,
+ unmatched_threshold=0.5,
+ aug_rand_hflip=False,
+ aug_scale_min=1.0,
+ aug_scale_max=1.0,
+ skip_crowd_during_training=True,
+ max_num_instances=100,
+ use_bfloat16=True,
+ mask_train_class='all',
+ mode=None):
+ """Initializes parameters for parsing annotations in the dataset.
+
+ Args:
+ output_size: `Tensor` or `list` for [height, width] of output image. The
+ output_size should be divided by the largest feature stride 2^max_level.
+ min_level: `int` number of minimum level of the output feature pyramid.
+ max_level: `int` number of maximum level of the output feature pyramid.
+ num_scales: `int` number representing intermediate scales added
+ on each level. For instances, num_scales=2 adds one additional
+ intermediate anchor scales [2^0, 2^0.5] on each level.
+ aspect_ratios: `list` of float numbers representing the aspect raito
+ anchors added on each level. The number indicates the ratio of width to
+ height. For instances, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors
+ on each scale level.
+ anchor_size: `float` number representing the scale of size of the base
+ anchor to the feature stride 2^level.
+ use_category: if `False`, treat all object in all classes in one
+ foreground category.
+ outer_box_scale: `float` number in a range of [1.0, inf) representing
+ the scale from object box to outer box. The mask branch predicts
+ instance mask enclosed in outer box.
+ box_jitter_scale: `float` number representing the noise magnitude to
+ jitter the training groundtruth boxes for mask branch.
+ num_sampled_masks: `int` number of sampled masks for training.
+ mask_crop_size: `list` for [height, width] of output training masks.
+ mask_min_level: `int` number indicating the minimum feature level to
+ obtain instance features.
+ mask_max_level: `int` number indicating the maximum feature level to
+ obtain instance features.
+ upsample_factor: `int` factor of upsampling the fine mask predictions.
+ match_threshold: `float` number between 0 and 1 representing the
+ lower-bound threshold to assign positive labels for anchors. An anchor
+ with a score over the threshold is labeled positive.
+ unmatched_threshold: `float` number between 0 and 1 representing the
+ upper-bound threshold to assign negative labels for anchors. An anchor
+ with a score below the threshold is labeled negative.
+ aug_rand_hflip: `bool`, if True, augment training with random
+ horizontal flip.
+ aug_scale_min: `float`, the minimum scale applied to `output_size` for
+ data augmentation during training.
+ aug_scale_max: `float`, the maximum scale applied to `output_size` for
+ data augmentation during training.
+ skip_crowd_during_training: `bool`, if True, skip annotations labeled with
+ `is_crowd` equals to 1.
+ max_num_instances: `int` number of maximum number of instances in an
+ image. The groundtruth data will be padded to `max_num_instances`.
+ use_bfloat16: `bool`, if True, cast output image to tf.bfloat16.
+ mask_train_class: a string of experiment mode: `all`, `voc` or `nonvoc`.
+ mode: a ModeKeys. Specifies if this is training, evaluation, prediction
+ or prediction with groundtruths in the outputs.
+ """
+ self._mode = mode
+ self._mask_train_class = mask_train_class
+ self._max_num_instances = max_num_instances
+ self._skip_crowd_during_training = skip_crowd_during_training
+ self._is_training = (mode == ModeKeys.TRAIN)
+
+ self._example_decoder = tf_example_decoder.TfExampleDecoder(
+ include_mask=True)
+
+ # Anchor.
+ self._output_size = output_size
+ self._min_level = min_level
+ self._max_level = max_level
+ self._num_scales = num_scales
+ self._aspect_ratios = aspect_ratios
+ self._anchor_size = anchor_size
+ self._match_threshold = match_threshold
+ self._unmatched_threshold = unmatched_threshold
+
+ # Data augmentation.
+ self._aug_rand_hflip = aug_rand_hflip
+ self._aug_scale_min = aug_scale_min
+ self._aug_scale_max = aug_scale_max
+
+ # Device.
+ self._use_bfloat16 = use_bfloat16
+
+ # ShapeMask specific.
+ # Control of which category to use.
+ self._use_category = use_category
+ self._num_sampled_masks = num_sampled_masks
+ self._mask_crop_size = mask_crop_size
+ self._mask_min_level = mask_min_level
+ self._mask_max_level = mask_max_level
+ self._outer_box_scale = outer_box_scale
+ self._box_jitter_scale = box_jitter_scale
+ self._up_sample_factor = upsample_factor
+
+ # Data is parsed depending on the model Modekey.
+ if mode == ModeKeys.TRAIN:
+ self._parse_fn = self._parse_train_data
+ elif mode == ModeKeys.EVAL:
+ self._parse_fn = self._parse_eval_data
+ elif mode == ModeKeys.PREDICT or mode == ModeKeys.PREDICT_WITH_GT:
+ self._parse_fn = self._parse_predict_data
+ else:
+ raise ValueError('mode is not defined.')
+
+ def __call__(self, value):
+ """Parses data to an image and associated training labels.
+
+ Args:
+ value: a string tensor holding a serialized tf.Example proto.
+
+ Returns:
+ inputs:
+ image: image tensor that is preproessed to have normalized value and
+ dimension [output_size[0], output_size[1], 3]
+ mask_boxes: sampled boxes that tightly enclose the training masks. The
+ box is represented in [y1, x1, y2, x2] format. The tensor is sampled
+ to the fixed dimension [self._num_sampled_masks, 4].
+ mask_outer_boxes: loose box that enclose sampled tight box. The
+ box is represented in [y1, x1, y2, x2] format. The tensor is sampled
+ to the fixed dimension [self._num_sampled_masks, 4].
+ mask_classes: the class ids of sampled training masks. The tensor has
+ shape [self._num_sampled_masks].
+ labels:
+ cls_targets: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, anchors_per_location]. The height_l and
+ width_l represent the dimension of class logits at l-th level.
+ box_targets: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, anchors_per_location * 4]. The height_l and
+ width_l represent the dimension of bounding box regression output at
+ l-th level.
+ num_positives: number of positive anchors in the image.
+ anchor_boxes: ordered dictionary with keys
+ [min_level, min_level+1, ..., max_level]. The values are tensor with
+ shape [height_l, width_l, 4] representing anchor boxes at each level.
+ image_scale: 2D float `Tensor` representing scale factors that apply
+ to [height, width] of input image.
+ mask_targets: training binary mask targets. The tensor has shape
+ [self._num_sampled_masks, self._mask_crop_size, self._mask_crop_size].
+ mask_is_valid: the binary tensor to indicate if the sampled masks are
+ valide. The sampled masks are invalid when no mask annotations are
+ included in the image. The tensor has shape [1].
+ groundtruths:
+ source_id: source image id. Default value -1 if the source id is empty
+ in the groundtruth annotation.
+ boxes: groundtruth bounding box annotations. The box is represented in
+ [y1, x1, y2, x2] format. The tensor is padded with -1 to the fixed
+ dimension [self._max_num_instances, 4].
+ classes: groundtruth classes annotations. The tensor is padded with
+ -1 to the fixed dimension [self._max_num_instances].
+ areas: groundtruth areas annotations. The tensor is padded with -1
+ to the fixed dimension [self._max_num_instances].
+ is_crowds: groundtruth annotations to indicate if an annotation
+ represents a group of instances by value {0, 1}. The tensor is
+ padded with 0 to the fixed dimension [self._max_num_instances].
+ """
+ with tf.name_scope('parser'):
+ data = self._example_decoder.decode(value)
+ return self._parse_fn(data)
+
+ def _parse_train_data(self, data):
+ """Parse data for ShapeMask training."""
+ classes = data['groundtruth_classes']
+ boxes = data['groundtruth_boxes']
+ masks = data['groundtruth_instance_masks']
+ is_crowds = data['groundtruth_is_crowd']
+ # Skips annotations with `is_crowd` = True.
+ if self._skip_crowd_during_training and self._is_training:
+ num_groundtrtuhs = tf.shape(classes)[0]
+ with tf.control_dependencies([num_groundtrtuhs, is_crowds]):
+ indices = tf.cond(
+ tf.greater(tf.size(is_crowds), 0),
+ lambda: tf.where(tf.logical_not(is_crowds))[:, 0],
+ lambda: tf.cast(tf.range(num_groundtrtuhs), tf.int64))
+ classes = tf.gather(classes, indices)
+ boxes = tf.gather(boxes, indices)
+ masks = tf.gather(masks, indices)
+
+ # Gets original image and its size.
+ image = data['image']
+ image_shape = tf.shape(image)[0:2]
+
+ # If not using category, makes all categories with id = 0.
+ if not self._use_category:
+ classes = tf.cast(tf.greater(classes, 0), dtype=tf.float32)
+
+ # Normalizes image with mean and std pixel values.
+ image = input_utils.normalize_image(image)
+
+ # Flips image randomly during training.
+ if self._aug_rand_hflip:
+ image, boxes, masks = input_utils.random_horizontal_flip(
+ image, boxes, masks)
+
+ # Converts boxes from normalized coordinates to pixel coordinates.
+ boxes = box_utils.denormalize_boxes(boxes, image_shape)
+
+ # Resizes and crops image.
+ image, image_info = input_utils.resize_and_crop_image(
+ image,
+ self._output_size,
+ self._output_size,
+ aug_scale_min=self._aug_scale_min,
+ aug_scale_max=self._aug_scale_max)
+ image_scale = image_info[2, :]
+ offset = image_info[3, :]
+
+ # Resizes and crops boxes and masks.
+ boxes = input_utils.resize_and_crop_boxes(
+ boxes, image_scale, image_info[1, :], offset)
+
+ # Filters out ground truth boxes that are all zeros.
+ indices = box_utils.get_non_empty_box_indices(boxes)
+ boxes = tf.gather(boxes, indices)
+ classes = tf.gather(classes, indices)
+ masks = tf.gather(masks, indices)
+
+ # Assigns anchors.
+ input_anchor = anchor.Anchor(
+ self._min_level, self._max_level, self._num_scales,
+ self._aspect_ratios, self._anchor_size, self._output_size)
+ anchor_labeler = anchor.AnchorLabeler(
+ input_anchor, self._match_threshold, self._unmatched_threshold)
+ (cls_targets,
+ box_targets,
+ num_positives) = anchor_labeler.label_anchors(
+ boxes,
+ tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
+
+ # Sample groundtruth masks/boxes/classes for mask branch.
+ num_masks = tf.shape(masks)[0]
+ mask_shape = tf.shape(masks)[1:3]
+
+ # Pad sampled boxes/masks/classes to a constant batch size.
+ padded_boxes = pad_to_size(boxes, self._num_sampled_masks)
+ padded_classes = pad_to_size(classes, self._num_sampled_masks)
+ padded_masks = pad_to_size(masks, self._num_sampled_masks)
+
+ # Randomly sample groundtruth masks for mask branch training. For the image
+ # without groundtruth masks, it will sample the dummy padded tensors.
+ rand_indices = tf.random.shuffle(
+ tf.range(tf.maximum(num_masks, self._num_sampled_masks)))
+ rand_indices = tf.math.mod(rand_indices, tf.maximum(num_masks, 1))
+ rand_indices = rand_indices[0:self._num_sampled_masks]
+ rand_indices = tf.reshape(rand_indices, [self._num_sampled_masks])
+
+ sampled_boxes = tf.gather(padded_boxes, rand_indices)
+ sampled_classes = tf.gather(padded_classes, rand_indices)
+ sampled_masks = tf.gather(padded_masks, rand_indices)
+ # Jitter the sampled boxes to mimic the noisy detections.
+ sampled_boxes = box_utils.jitter_boxes(
+ sampled_boxes, noise_scale=self._box_jitter_scale)
+ sampled_boxes = box_utils.clip_boxes(sampled_boxes, self._output_size)
+ # Compute mask targets in feature crop. A feature crop fully contains a
+ # sampled box.
+ mask_outer_boxes = box_utils.compute_outer_boxes(
+ sampled_boxes, tf.shape(image)[0:2], scale=self._outer_box_scale)
+ mask_outer_boxes = box_utils.clip_boxes(mask_outer_boxes, self._output_size)
+ # Compensate the offset of mask_outer_boxes to map it back to original image
+ # scale.
+ mask_outer_boxes_ori = mask_outer_boxes
+ mask_outer_boxes_ori += tf.tile(tf.expand_dims(offset, axis=0), [1, 2])
+ mask_outer_boxes_ori /= tf.tile(tf.expand_dims(image_scale, axis=0), [1, 2])
+ norm_mask_outer_boxes_ori = box_utils.normalize_boxes(
+ mask_outer_boxes_ori, mask_shape)
+
+ # Set sampled_masks shape to [batch_size, height, width, 1].
+ sampled_masks = tf.cast(tf.expand_dims(sampled_masks, axis=-1), tf.float32)
+ mask_targets = tf.image.crop_and_resize(
+ sampled_masks,
+ norm_mask_outer_boxes_ori,
+ box_indices=tf.range(self._num_sampled_masks),
+ crop_size=[self._mask_crop_size, self._mask_crop_size],
+ method='bilinear',
+ extrapolation_value=0,
+ name='train_mask_targets')
+ mask_targets = tf.where(tf.greater_equal(mask_targets, 0.5),
+ tf.ones_like(mask_targets),
+ tf.zeros_like(mask_targets))
+ mask_targets = tf.squeeze(mask_targets, axis=-1)
+ if self._up_sample_factor > 1:
+ fine_mask_targets = tf.image.crop_and_resize(
+ sampled_masks,
+ norm_mask_outer_boxes_ori,
+ box_indices=tf.range(self._num_sampled_masks),
+ crop_size=[
+ self._mask_crop_size * self._up_sample_factor,
+ self._mask_crop_size * self._up_sample_factor
+ ],
+ method='bilinear',
+ extrapolation_value=0,
+ name='train_mask_targets')
+ fine_mask_targets = tf.where(
+ tf.greater_equal(fine_mask_targets, 0.5),
+ tf.ones_like(fine_mask_targets), tf.zeros_like(fine_mask_targets))
+ fine_mask_targets = tf.squeeze(fine_mask_targets, axis=-1)
+ else:
+ fine_mask_targets = mask_targets
+
+ # If bfloat16 is used, casts input image to tf.bfloat16.
+ if self._use_bfloat16:
+ image = tf.cast(image, dtype=tf.bfloat16)
+
+ valid_image = tf.cast(tf.not_equal(num_masks, 0), tf.int32)
+ if self._mask_train_class == 'all':
+ mask_is_valid = valid_image * tf.ones_like(sampled_classes, tf.int32)
+ else:
+ # Get the intersection of sampled classes with training splits.
+ mask_valid_classes = tf.cast(
+ tf.expand_dims(
+ class_utils.coco_split_class_ids(self._mask_train_class), 1),
+ sampled_classes.dtype)
+ match = tf.reduce_any(
+ tf.equal(tf.expand_dims(sampled_classes, 0), mask_valid_classes), 0)
+ mask_is_valid = valid_image * tf.cast(match, tf.int32)
+
+ # Packs labels for model_fn outputs.
+ labels = {
+ 'cls_targets': cls_targets,
+ 'box_targets': box_targets,
+ 'anchor_boxes': input_anchor.multilevel_boxes,
+ 'num_positives': num_positives,
+ 'image_info': image_info,
+ # For ShapeMask.
+ 'mask_targets': mask_targets,
+ 'fine_mask_targets': fine_mask_targets,
+ 'mask_is_valid': mask_is_valid,
+ }
+
+ inputs = {
+ 'image': image,
+ 'image_info': image_info,
+ 'mask_boxes': sampled_boxes,
+ 'mask_outer_boxes': mask_outer_boxes,
+ 'mask_classes': sampled_classes,
+ }
+ return inputs, labels
+
+ def _parse_predict_data(self, data):
+ """Parse data for ShapeMask training."""
+ classes = data['groundtruth_classes']
+ boxes = data['groundtruth_boxes']
+ masks = data['groundtruth_instance_masks']
+
+ # Gets original image and its size.
+ image = data['image']
+ image_shape = tf.shape(image)[0:2]
+
+ # If not using category, makes all categories with id = 0.
+ if not self._use_category:
+ classes = tf.cast(tf.greater(classes, 0), dtype=tf.float32)
+
+ # Normalizes image with mean and std pixel values.
+ image = input_utils.normalize_image(image)
+
+ # Converts boxes from normalized coordinates to pixel coordinates.
+ boxes = box_utils.denormalize_boxes(boxes, image_shape)
+
+ # Resizes and crops image.
+ image, image_info = input_utils.resize_and_crop_image(
+ image,
+ self._output_size,
+ self._output_size,
+ aug_scale_min=1.0,
+ aug_scale_max=1.0)
+ image_scale = image_info[2, :]
+ offset = image_info[3, :]
+
+ # Resizes and crops boxes and masks.
+ boxes = input_utils.resize_and_crop_boxes(
+ boxes, image_scale, image_info[1, :], offset)
+ masks = input_utils.resize_and_crop_masks(
+ tf.expand_dims(masks, axis=-1), image_scale, self._output_size, offset)
+
+ # Filters out ground truth boxes that are all zeros.
+ indices = box_utils.get_non_empty_box_indices(boxes)
+ boxes = tf.gather(boxes, indices)
+ classes = tf.gather(classes, indices)
+
+ # Assigns anchors.
+ input_anchor = anchor.Anchor(
+ self._min_level, self._max_level, self._num_scales,
+ self._aspect_ratios, self._anchor_size, self._output_size)
+ anchor_labeler = anchor.AnchorLabeler(
+ input_anchor, self._match_threshold, self._unmatched_threshold)
+
+ # If bfloat16 is used, casts input image to tf.bfloat16.
+ if self._use_bfloat16:
+ image = tf.cast(image, dtype=tf.bfloat16)
+
+ labels = {
+ 'anchor_boxes': input_anchor.multilevel_boxes,
+ 'image_info': image_info,
+ }
+ if self._mode == ModeKeys.PREDICT_WITH_GT:
+ # Converts boxes from normalized coordinates to pixel coordinates.
+ groundtruths = {
+ 'source_id': data['source_id'],
+ 'height': data['height'],
+ 'width': data['width'],
+ 'num_detections': tf.shape(data['groundtruth_classes']),
+ 'boxes': box_utils.denormalize_boxes(
+ data['groundtruth_boxes'], image_shape),
+ 'classes': data['groundtruth_classes'],
+ # 'masks': tf.squeeze(masks, axis=-1),
+ 'areas': data['groundtruth_area'],
+ 'is_crowds': tf.cast(data['groundtruth_is_crowd'], tf.int32),
+ }
+ groundtruths['source_id'] = dataloader_utils.process_source_id(
+ groundtruths['source_id'])
+ groundtruths = dataloader_utils.pad_groundtruths_to_fixed_size(
+ groundtruths, self._max_num_instances)
+ # Computes training labels.
+ (cls_targets,
+ box_targets,
+ num_positives) = anchor_labeler.label_anchors(
+ boxes,
+ tf.cast(tf.expand_dims(classes, axis=1), tf.float32))
+ # Packs labels for model_fn outputs.
+ labels.update({
+ 'cls_targets': cls_targets,
+ 'box_targets': box_targets,
+ 'num_positives': num_positives,
+ 'groundtruths': groundtruths,
+ })
+
+ inputs = {
+ 'image': image,
+ 'image_info': image_info,
+ }
+
+ return inputs, labels
diff --git a/modeling/official/legacy/detection/dataloader/tf_example_decoder.py b/modeling/official/legacy/detection/dataloader/tf_example_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d0dbbb6c53764e66a1c797312a0e559db5cd527
--- /dev/null
+++ b/modeling/official/legacy/detection/dataloader/tf_example_decoder.py
@@ -0,0 +1,156 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""Tensorflow Example proto decoder for object detection.
+
+A decoder to decode string tensors containing serialized tensorflow.Example
+protos for object detection.
+"""
+import tensorflow as tf, tf_keras
+
+
+class TfExampleDecoder(object):
+ """Tensorflow Example proto decoder."""
+
+ def __init__(self, include_mask=False):
+ self._include_mask = include_mask
+ self._keys_to_features = {
+ 'image/encoded':
+ tf.io.FixedLenFeature((), tf.string),
+ 'image/source_id':
+ tf.io.FixedLenFeature((), tf.string),
+ 'image/height':
+ tf.io.FixedLenFeature((), tf.int64),
+ 'image/width':
+ tf.io.FixedLenFeature((), tf.int64),
+ 'image/object/bbox/xmin':
+ tf.io.VarLenFeature(tf.float32),
+ 'image/object/bbox/xmax':
+ tf.io.VarLenFeature(tf.float32),
+ 'image/object/bbox/ymin':
+ tf.io.VarLenFeature(tf.float32),
+ 'image/object/bbox/ymax':
+ tf.io.VarLenFeature(tf.float32),
+ 'image/object/class/label':
+ tf.io.VarLenFeature(tf.int64),
+ 'image/object/area':
+ tf.io.VarLenFeature(tf.float32),
+ 'image/object/is_crowd':
+ tf.io.VarLenFeature(tf.int64),
+ }
+ if include_mask:
+ self._keys_to_features.update({
+ 'image/object/mask':
+ tf.io.VarLenFeature(tf.string),
+ })
+
+ def _decode_image(self, parsed_tensors):
+ """Decodes the image and set its static shape."""
+ image = tf.io.decode_image(parsed_tensors['image/encoded'], channels=3)
+ image.set_shape([None, None, 3])
+ return image
+
+ def _decode_boxes(self, parsed_tensors):
+ """Concat box coordinates in the format of [ymin, xmin, ymax, xmax]."""
+ xmin = parsed_tensors['image/object/bbox/xmin']
+ xmax = parsed_tensors['image/object/bbox/xmax']
+ ymin = parsed_tensors['image/object/bbox/ymin']
+ ymax = parsed_tensors['image/object/bbox/ymax']
+ return tf.stack([ymin, xmin, ymax, xmax], axis=-1)
+
+ def _decode_masks(self, parsed_tensors):
+ """Decode a set of PNG masks to the tf.float32 tensors."""
+ def _decode_png_mask(png_bytes):
+ mask = tf.squeeze(
+ tf.io.decode_png(png_bytes, channels=1, dtype=tf.uint8), axis=-1)
+ mask = tf.cast(mask, dtype=tf.float32)
+ mask.set_shape([None, None])
+ return mask
+
+ height = parsed_tensors['image/height']
+ width = parsed_tensors['image/width']
+ masks = parsed_tensors['image/object/mask']
+ return tf.cond(
+ pred=tf.greater(tf.size(input=masks), 0),
+ true_fn=lambda: tf.map_fn(_decode_png_mask, masks, dtype=tf.float32),
+ false_fn=lambda: tf.zeros([0, height, width], dtype=tf.float32))
+
+ def _decode_areas(self, parsed_tensors):
+ xmin = parsed_tensors['image/object/bbox/xmin']
+ xmax = parsed_tensors['image/object/bbox/xmax']
+ ymin = parsed_tensors['image/object/bbox/ymin']
+ ymax = parsed_tensors['image/object/bbox/ymax']
+ return tf.cond(
+ tf.greater(tf.shape(parsed_tensors['image/object/area'])[0], 0),
+ lambda: parsed_tensors['image/object/area'],
+ lambda: (xmax - xmin) * (ymax - ymin))
+
+ def decode(self, serialized_example):
+ """Decode the serialized example.
+
+ Args:
+ serialized_example: a single serialized tf.Example string.
+
+ Returns:
+ decoded_tensors: a dictionary of tensors with the following fields:
+ - image: a uint8 tensor of shape [None, None, 3].
+ - source_id: a string scalar tensor.
+ - height: an integer scalar tensor.
+ - width: an integer scalar tensor.
+ - groundtruth_classes: a int64 tensor of shape [None].
+ - groundtruth_is_crowd: a bool tensor of shape [None].
+ - groundtruth_area: a float32 tensor of shape [None].
+ - groundtruth_boxes: a float32 tensor of shape [None, 4].
+ - groundtruth_instance_masks: a float32 tensor of shape
+ [None, None, None].
+ - groundtruth_instance_masks_png: a string tensor of shape [None].
+ """
+ parsed_tensors = tf.io.parse_single_example(
+ serialized=serialized_example, features=self._keys_to_features)
+ for k in parsed_tensors:
+ if isinstance(parsed_tensors[k], tf.SparseTensor):
+ if parsed_tensors[k].dtype == tf.string:
+ parsed_tensors[k] = tf.sparse.to_dense(
+ parsed_tensors[k], default_value='')
+ else:
+ parsed_tensors[k] = tf.sparse.to_dense(
+ parsed_tensors[k], default_value=0)
+
+ image = self._decode_image(parsed_tensors)
+ boxes = self._decode_boxes(parsed_tensors)
+ areas = self._decode_areas(parsed_tensors)
+ is_crowds = tf.cond(
+ tf.greater(tf.shape(parsed_tensors['image/object/is_crowd'])[0], 0),
+ lambda: tf.cast(parsed_tensors['image/object/is_crowd'], dtype=tf.bool),
+ lambda: tf.zeros_like(parsed_tensors['image/object/class/label'], dtype=tf.bool)) # pylint: disable=line-too-long
+ if self._include_mask:
+ masks = self._decode_masks(parsed_tensors)
+
+ decoded_tensors = {
+ 'image': image,
+ 'source_id': parsed_tensors['image/source_id'],
+ 'height': parsed_tensors['image/height'],
+ 'width': parsed_tensors['image/width'],
+ 'groundtruth_classes': parsed_tensors['image/object/class/label'],
+ 'groundtruth_is_crowd': is_crowds,
+ 'groundtruth_area': areas,
+ 'groundtruth_boxes': boxes,
+ }
+ if self._include_mask:
+ decoded_tensors.update({
+ 'groundtruth_instance_masks': masks,
+ 'groundtruth_instance_masks_png': parsed_tensors['image/object/mask'],
+ })
+ return decoded_tensors
diff --git a/modeling/official/legacy/detection/evaluation/__init__.py b/modeling/official/legacy/detection/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/legacy/detection/evaluation/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/legacy/detection/evaluation/coco_evaluator.py b/modeling/official/legacy/detection/evaluation/coco_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..83f3013585d4989c2f3eb352db6807c0f8140d4a
--- /dev/null
+++ b/modeling/official/legacy/detection/evaluation/coco_evaluator.py
@@ -0,0 +1,847 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The COCO-style evaluator.
+
+The following snippet demonstrates the use of interfaces:
+
+ evaluator = COCOEvaluator(...)
+ for _ in range(num_evals):
+ for _ in range(num_batches_per_eval):
+ predictions, groundtruth = predictor.predict(...) # pop a batch.
+ evaluator.update(predictions, groundtruths) # aggregate internal stats.
+ evaluator.evaluate() # finish one full eval.
+
+See also: https://github.com/cocodataset/cocoapi/
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import atexit
+import copy
+import tempfile
+
+from absl import logging
+import numpy as np
+from pycocotools import cocoeval
+import six
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.evaluation import coco_utils
+from official.legacy.detection.utils import class_utils
+
+
+class OlnCOCOevalWrapper(cocoeval.COCOeval):
+ """COCOeval wrapper class.
+
+ Rewritten based on cocoapi: (pycocotools/cocoeval.py)
+
+ This class wraps COCOEVAL API object, which provides the following additional
+ functionalities:
+ 1. summarze 'all', 'seen', and 'novel' split output print-out, e.g., AR at
+ different K proposals, AR and AP resutls for 'seen' and 'novel' class
+ splits.
+ """
+
+ def __init__(self, coco_gt, coco_dt, iou_type='box'):
+ super(OlnCOCOevalWrapper, self).__init__(
+ cocoGt=coco_gt, cocoDt=coco_dt, iouType=iou_type)
+
+ def summarize(self):
+ """Compute and display summary metrics for evaluation results.
+
+ Delta to the standard cocoapi function:
+ More Averate Recall metrics are produced with different top-K proposals.
+ Note this functin can *only* be applied on the default parameter
+ setting.
+ Raises:
+ Exception: Please run accumulate() first.
+ """
+
+ def _summarize(ap=1, iou_thr=None, area_rng='all', max_dets=100):
+ p = self.params
+ i_str = (' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = '
+ '{:0.3f}')
+ title_str = 'Average Precision' if ap == 1 else 'Average Recall'
+ type_str = '(AP)' if ap == 1 else '(AR)'
+ iou_str = '{:0.2f}:{:0.2f}'.format(
+ p.iouThrs[0],
+ p.iouThrs[-1]) if iou_thr is None else '{:0.2f}'.format(iou_thr)
+
+ aind = [i for i, a_rng in enumerate(p.areaRngLbl) if a_rng == area_rng]
+ mind = [i for i, m_det in enumerate(p.maxDets) if m_det == max_dets]
+ if ap == 1:
+ # dimension of precision: [TxRxKxAxM]
+ s = self.eval['precision']
+ # IoU
+ if iou_thr is not None:
+ t = np.where(iou_thr == p.iouThrs)[0]
+ s = s[t]
+ s = s[:, :, :, aind, mind]
+ else:
+ # dimension of recall: [TxKxAxM]
+ s = self.eval['recall']
+ if iou_thr is not None:
+ t = np.where(iou_thr == p.iouThrs)[0]
+ s = s[t]
+ s = s[:, :, aind, mind]
+
+ if not (s[s > -1]).any():
+ mean_s = -1
+ else:
+ mean_s = np.mean(s[s > -1])
+ print(
+ i_str.format(title_str, type_str, iou_str, area_rng, max_dets,
+ mean_s))
+ return mean_s
+
+ def _summarize_dets():
+ stats = np.zeros((14,))
+ stats[0] = _summarize(1)
+ stats[1] = _summarize(
+ 1,
+ iou_thr=.5,
+ )
+ stats[2] = _summarize(
+ 1,
+ iou_thr=.75,
+ )
+ stats[3] = _summarize(
+ 1,
+ area_rng='small',
+ )
+ stats[4] = _summarize(
+ 1,
+ area_rng='medium',
+ )
+ stats[5] = _summarize(
+ 1,
+ area_rng='large',
+ )
+
+ stats[6] = _summarize(0, max_dets=self.params.maxDets[0]) # 10
+ stats[7] = _summarize(0, max_dets=self.params.maxDets[1]) # 20
+ stats[8] = _summarize(0, max_dets=self.params.maxDets[2]) # 50
+ stats[9] = _summarize(0, max_dets=self.params.maxDets[3]) # 100
+ stats[10] = _summarize(0, max_dets=self.params.maxDets[4]) # 200
+
+ stats[11] = _summarize(0, area_rng='small', max_dets=10)
+ stats[12] = _summarize(0, area_rng='medium', max_dets=10)
+ stats[13] = _summarize(0, area_rng='large', max_dets=10)
+ return stats
+
+ if not self.eval:
+ raise Exception('Please run accumulate() first')
+ summarize = _summarize_dets
+ self.stats = summarize()
+
+
+class OlnCOCOevalXclassWrapper(OlnCOCOevalWrapper):
+ """COCOeval wrapper class.
+
+ Rewritten based on cocoapi: (pycocotools/cocoeval.py)
+ Delta to the standard cocoapi:
+ Detections that hit the 'seen' class objects are ignored in top-K proposals.
+
+ This class wraps COCOEVAL API object, which provides the following additional
+ functionalities:
+ 1. Include ignore-class split (e.g., 'voc' or 'nonvoc').
+ 2. Do not count (or ignore) box proposals hitting ignore-class when
+ evaluating Average Recall at top-K proposals.
+ """
+
+ def __init__(self, coco_gt, coco_dt, iou_type='box'):
+ super(OlnCOCOevalXclassWrapper, self).__init__(
+ coco_gt=coco_gt, coco_dt=coco_dt, iou_type=iou_type)
+
+ def evaluateImg(self, img_id, cat_id, a_rng, max_det):
+ p = self.params
+ if p.useCats:
+ gt = self._gts[img_id, cat_id]
+ dt = self._dts[img_id, cat_id]
+ else:
+ gt, dt = [], []
+ for c_id in p.catIds:
+ gt.extend(self._gts[img_id, c_id])
+ dt.extend(self._dts[img_id, c_id])
+
+ if not gt and not dt:
+ return None
+
+ for g in gt:
+ if g['ignore'] or (g['area'] < a_rng[0] or g['area'] > a_rng[1]):
+ g['_ignore'] = 1
+ else:
+ g['_ignore'] = 0
+ # Class manipulation: ignore the 'ignored_split'.
+ if 'ignored_split' in g and g['ignored_split'] == 1:
+ g['_ignore'] = 1
+
+ # sort dt highest score first, sort gt ignore last
+ gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort')
+ gt = [gt[i] for i in gtind]
+ dtind = np.argsort([-d['score'] for d in dt], kind='mergesort')
+ dt = [dt[i] for i in dtind[0:max_det]]
+ iscrowd = [int(o['iscrowd']) for o in gt]
+ # load computed ious
+ # ious = self.ious[img_id, cat_id][:, gtind] if len(
+ # self.ious[img_id, cat_id]) > 0 else self.ious[img_id, cat_id]
+ if self.ious[img_id, cat_id].any():
+ ious = self.ious[img_id, cat_id][:, gtind]
+ else:
+ ious = self.ious[img_id, cat_id]
+
+ tt = len(p.iouThrs)
+ gg = len(gt)
+ dd = len(dt)
+ gtm = np.zeros((tt, gg))
+ dtm = np.zeros((tt, dd))
+ gt_ig = np.array([g['_ignore'] for g in gt])
+ dt_ig = np.zeros((tt, dd))
+ # indicator of whether the gt object class is of ignored_split or not.
+ gt_ig_split = np.array([g['ignored_split'] for g in gt])
+ dt_ig_split = np.zeros((dd))
+
+ if ious.any():
+ for tind, t in enumerate(p.iouThrs):
+ for dind, d in enumerate(dt):
+ # information about best match so far (m=-1 -> unmatched)
+ iou = min([t, 1 - 1e-10])
+ m = -1
+ for gind, g in enumerate(gt):
+ # if this gt already matched, and not a crowd, continue
+ if gtm[tind, gind] > 0 and not iscrowd[gind]:
+ continue
+ # if dt matched to reg gt, and on ignore gt, stop
+ if m > -1 and gt_ig[m] == 0 and gt_ig[gind] == 1:
+ break
+ # continue to next gt unless better match made
+ if ious[dind, gind] < iou:
+ continue
+ # if match successful and best so far, store appropriately
+ iou = ious[dind, gind]
+ m = gind
+ # if match made store id of match for both dt and gt
+ if m == -1:
+ continue
+ dt_ig[tind, dind] = gt_ig[m]
+ dtm[tind, dind] = gt[m]['id']
+ gtm[tind, m] = d['id']
+
+ # Activate to ignore the seen-class detections.
+ if tind == 0: # Register just only once: tind > 0 is also fine.
+ dt_ig_split[dind] = gt_ig_split[m]
+
+ # set unmatched detections outside of area range to ignore
+ a = np.array([d['area'] < a_rng[0] or d['area'] > a_rng[1] for d in dt
+ ]).reshape((1, len(dt)))
+ dt_ig = np.logical_or(dt_ig, np.logical_and(dtm == 0, np.repeat(a, tt, 0)))
+
+ # Activate to ignore the seen-class detections.
+ # Take only eval_split (eg, nonvoc) and ignore seen_split (eg, voc).
+ if dt_ig_split.sum() > 0:
+ dtm = dtm[:, dt_ig_split == 0]
+ dt_ig = dt_ig[:, dt_ig_split == 0]
+ len_dt = min(max_det, len(dt))
+ dt = [dt[i] for i in range(len_dt) if dt_ig_split[i] == 0]
+
+ # store results for given image and category
+ return {
+ 'image_id': img_id,
+ 'category_id': cat_id,
+ 'aRng': a_rng,
+ 'maxDet': max_det,
+ 'dtIds': [d['id'] for d in dt],
+ 'gtIds': [g['id'] for g in gt],
+ 'dtMatches': dtm,
+ 'gtMatches': gtm,
+ 'dtScores': [d['score'] for d in dt],
+ 'gtIgnore': gt_ig,
+ 'dtIgnore': dt_ig,
+ }
+
+
+class MetricWrapper(object):
+ """Metric Wrapper of the COCO evaluator."""
+ # This is only a wrapper for COCO metric and works on for numpy array. So it
+ # doesn't inherit from tf_keras.layers.Layer or tf_keras.metrics.Metric.
+
+ def __init__(self, evaluator):
+ self._evaluator = evaluator
+
+ def update_state(self, y_true, y_pred):
+ """Update internal states."""
+ labels = tf.nest.map_structure(lambda x: x.numpy(), y_true)
+ outputs = tf.nest.map_structure(lambda x: x.numpy(), y_pred)
+ groundtruths = {}
+ predictions = {}
+ for key, val in outputs.items():
+ if isinstance(val, tuple):
+ val = np.concatenate(val)
+ predictions[key] = val
+ for key, val in labels.items():
+ if isinstance(val, tuple):
+ val = np.concatenate(val)
+ groundtruths[key] = val
+ self._evaluator.update(predictions, groundtruths)
+
+ def result(self):
+ return self._evaluator.evaluate()
+
+ def reset_states(self):
+ return self._evaluator.reset()
+
+
+class COCOEvaluator(object):
+ """COCO evaluation metric class."""
+
+ def __init__(self, annotation_file, include_mask, need_rescale_bboxes=True):
+ """Constructs COCO evaluation class.
+
+ The class provides the interface to metrics_fn in TPUEstimator. The
+ _update_op() takes detections from each image and push them to
+ self.detections. The _evaluate() loads a JSON file in COCO annotation format
+ as the groundtruths and runs COCO evaluation.
+
+ Args:
+ annotation_file: a JSON file that stores annotations of the eval dataset.
+ If `annotation_file` is None, groundtruth annotations will be loaded
+ from the dataloader.
+ include_mask: a boolean to indicate whether or not to include the mask
+ eval.
+ need_rescale_bboxes: If true bboxes in `predictions` will be rescaled back
+ to absolute values (`image_info` is needed in this case).
+ """
+ if annotation_file:
+ if annotation_file.startswith('gs://'):
+ _, local_val_json = tempfile.mkstemp(suffix='.json')
+ tf.io.gfile.remove(local_val_json)
+
+ tf.io.gfile.copy(annotation_file, local_val_json)
+ atexit.register(tf.io.gfile.remove, local_val_json)
+ else:
+ local_val_json = annotation_file
+ self._coco_gt = coco_utils.COCOWrapper(
+ eval_type=('mask' if include_mask else 'box'),
+ annotation_file=local_val_json)
+ self._annotation_file = annotation_file
+ self._include_mask = include_mask
+ self._metric_names = [
+ 'AP', 'AP50', 'AP75', 'APs', 'APm', 'APl', 'ARmax1', 'ARmax10',
+ 'ARmax100', 'ARs', 'ARm', 'ARl'
+ ]
+ self._required_prediction_fields = [
+ 'source_id', 'num_detections', 'detection_classes', 'detection_scores',
+ 'detection_boxes'
+ ]
+ self._need_rescale_bboxes = need_rescale_bboxes
+ if self._need_rescale_bboxes:
+ self._required_prediction_fields.append('image_info')
+ self._required_groundtruth_fields = [
+ 'source_id', 'height', 'width', 'classes', 'boxes'
+ ]
+ if self._include_mask:
+ mask_metric_names = ['mask_' + x for x in self._metric_names]
+ self._metric_names.extend(mask_metric_names)
+ self._required_prediction_fields.extend(['detection_masks'])
+ self._required_groundtruth_fields.extend(['masks'])
+
+ self.reset()
+
+ def reset(self):
+ """Resets internal states for a fresh run."""
+ self._predictions = {}
+ if not self._annotation_file:
+ self._groundtruths = {}
+
+ def evaluate(self):
+ """Evaluates with detections from all images with COCO API.
+
+ Returns:
+ coco_metric: float numpy array with shape [24] representing the
+ coco-style evaluation metrics (box and mask).
+ """
+ if not self._annotation_file:
+ logging.info('Thre is no annotation_file in COCOEvaluator.')
+ gt_dataset = coco_utils.convert_groundtruths_to_coco_dataset(
+ self._groundtruths)
+ coco_gt = coco_utils.COCOWrapper(
+ eval_type=('mask' if self._include_mask else 'box'),
+ gt_dataset=gt_dataset)
+ else:
+ logging.info('Using annotation file: %s', self._annotation_file)
+ coco_gt = self._coco_gt
+ coco_predictions = coco_utils.convert_predictions_to_coco_annotations(
+ self._predictions)
+ coco_dt = coco_gt.loadRes(predictions=coco_predictions)
+ image_ids = [ann['image_id'] for ann in coco_predictions]
+
+ coco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='bbox')
+ coco_eval.params.imgIds = image_ids
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+ coco_metrics = coco_eval.stats
+
+ if self._include_mask:
+ mcoco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='segm')
+ mcoco_eval.params.imgIds = image_ids
+ mcoco_eval.evaluate()
+ mcoco_eval.accumulate()
+ mcoco_eval.summarize()
+ mask_coco_metrics = mcoco_eval.stats
+
+ if self._include_mask:
+ metrics = np.hstack((coco_metrics, mask_coco_metrics))
+ else:
+ metrics = coco_metrics
+
+ # Cleans up the internal variables in order for a fresh eval next time.
+ self.reset()
+
+ metrics_dict = {}
+ for i, name in enumerate(self._metric_names):
+ metrics_dict[name] = metrics[i].astype(np.float32)
+ return metrics_dict
+
+ def _process_predictions(self, predictions):
+ image_scale = np.tile(predictions['image_info'][:, 2:3, :], (1, 1, 2))
+ predictions['detection_boxes'] = (
+ predictions['detection_boxes'].astype(np.float32))
+ predictions['detection_boxes'] /= image_scale
+ if 'detection_outer_boxes' in predictions:
+ predictions['detection_outer_boxes'] = (
+ predictions['detection_outer_boxes'].astype(np.float32))
+ predictions['detection_outer_boxes'] /= image_scale
+
+ def update(self, predictions, groundtruths=None):
+ """Update and aggregate detection results and groundtruth data.
+
+ Args:
+ predictions: a dictionary of numpy arrays including the fields below. See
+ different parsers under `../dataloader` for more details.
+ Required fields:
+ - source_id: a numpy array of int or string of shape [batch_size].
+ - image_info [if `need_rescale_bboxes` is True]: a numpy array of
+ float of shape [batch_size, 4, 2].
+ - num_detections: a numpy array of int of shape [batch_size].
+ - detection_boxes: a numpy array of float of shape [batch_size, K, 4].
+ - detection_classes: a numpy array of int of shape [batch_size, K].
+ - detection_scores: a numpy array of float of shape [batch_size, K].
+ Optional fields:
+ - detection_masks: a numpy array of float of shape [batch_size, K,
+ mask_height, mask_width].
+ groundtruths: a dictionary of numpy arrays including the fields below. See
+ also different parsers under `../dataloader` for more details.
+ Required fields:
+ - source_id: a numpy array of int or string of shape [batch_size].
+ - height: a numpy array of int of shape [batch_size].
+ - width: a numpy array of int of shape [batch_size].
+ - num_detections: a numpy array of int of shape [batch_size].
+ - boxes: a numpy array of float of shape [batch_size, K, 4].
+ - classes: a numpy array of int of shape [batch_size, K].
+ Optional fields:
+ - is_crowds: a numpy array of int of shape [batch_size, K]. If the
+ field is absent, it is assumed that this instance is not crowd.
+ - areas: a numy array of float of shape [batch_size, K]. If the field
+ is absent, the area is calculated using either boxes or masks
+ depending on which one is available.
+ - masks: a numpy array of float of shape [batch_size, K, mask_height,
+ mask_width],
+
+ Raises:
+ ValueError: if the required prediction or groundtruth fields are not
+ present in the incoming `predictions` or `groundtruths`.
+ """
+ for k in self._required_prediction_fields:
+ if k not in predictions:
+ raise ValueError(
+ 'Missing the required key `{}` in predictions!'.format(k))
+ if self._need_rescale_bboxes:
+ self._process_predictions(predictions)
+ for k, v in six.iteritems(predictions):
+ if k not in self._predictions:
+ self._predictions[k] = [v]
+ else:
+ self._predictions[k].append(v)
+
+ if not self._annotation_file:
+ assert groundtruths
+ for k in self._required_groundtruth_fields:
+ if k not in groundtruths:
+ raise ValueError(
+ 'Missing the required key `{}` in groundtruths!'.format(k))
+ for k, v in six.iteritems(groundtruths):
+ if k not in self._groundtruths:
+ self._groundtruths[k] = [v]
+ else:
+ self._groundtruths[k].append(v)
+
+
+class OlnXclassEvaluator(COCOEvaluator):
+ """COCO evaluation metric class."""
+
+ def __init__(self, annotation_file, include_mask, need_rescale_bboxes=True,
+ use_category=True, seen_class='all'):
+ """Constructs COCO evaluation class.
+
+ The class provides the interface to metrics_fn in TPUEstimator. The
+ _update_op() takes detections from each image and push them to
+ self.detections. The _evaluate() loads a JSON file in COCO annotation format
+ as the groundtruths and runs COCO evaluation.
+
+ Args:
+ annotation_file: a JSON file that stores annotations of the eval dataset.
+ If `annotation_file` is None, groundtruth annotations will be loaded
+ from the dataloader.
+ include_mask: a boolean to indicate whether or not to include the mask
+ eval.
+ need_rescale_bboxes: If true bboxes in `predictions` will be rescaled back
+ to absolute values (`image_info` is needed in this case).
+ use_category: if `False`, treat all object in all classes in one
+ foreground category.
+ seen_class: 'all' or 'voc' or 'nonvoc'
+ """
+ super(OlnXclassEvaluator, self).__init__(
+ annotation_file=annotation_file,
+ include_mask=include_mask,
+ need_rescale_bboxes=need_rescale_bboxes)
+ self._use_category = use_category
+ self._seen_class = seen_class
+ self._seen_class_ids = class_utils.coco_split_class_ids(seen_class)
+ self._metric_names = [
+ 'AP', 'AP50', 'AP75',
+ 'APs', 'APm', 'APl',
+ 'ARmax10', 'ARmax20', 'ARmax50', 'ARmax100', 'ARmax200',
+ 'ARmax10s', 'ARmax10m', 'ARmax10l'
+ ]
+ if self._seen_class != 'all':
+ self._metric_names.extend([
+ 'AP_seen', 'AP50_seen', 'AP75_seen',
+ 'APs_seen', 'APm_seen', 'APl_seen',
+ 'ARmax10_seen', 'ARmax20_seen', 'ARmax50_seen',
+ 'ARmax100_seen', 'ARmax200_seen',
+ 'ARmax10s_seen', 'ARmax10m_seen', 'ARmax10l_seen',
+
+ 'AP_novel', 'AP50_novel', 'AP75_novel',
+ 'APs_novel', 'APm_novel', 'APl_novel',
+ 'ARmax10_novel', 'ARmax20_novel', 'ARmax50_novel',
+ 'ARmax100_novel', 'ARmax200_novel',
+ 'ARmax10s_novel', 'ARmax10m_novel', 'ARmax10l_novel',
+ ])
+ if self._include_mask:
+ mask_metric_names = ['mask_' + x for x in self._metric_names]
+ self._metric_names.extend(mask_metric_names)
+ self._required_prediction_fields.extend(['detection_masks'])
+ self._required_groundtruth_fields.extend(['masks'])
+
+ self.reset()
+
+ def evaluate(self):
+ """Evaluates with detections from all images with COCO API.
+
+ Returns:
+ coco_metric: float numpy array with shape [24] representing the
+ coco-style evaluation metrics (box and mask).
+ """
+ if not self._annotation_file:
+ logging.info('Thre is no annotation_file in COCOEvaluator.')
+ gt_dataset = coco_utils.convert_groundtruths_to_coco_dataset(
+ self._groundtruths)
+ coco_gt = coco_utils.COCOWrapper(
+ eval_type=('mask' if self._include_mask else 'box'),
+ gt_dataset=gt_dataset)
+ else:
+ logging.info('Using annotation file: %s', self._annotation_file)
+ coco_gt = self._coco_gt
+
+ coco_predictions = coco_utils.convert_predictions_to_coco_annotations(
+ self._predictions)
+ coco_dt = coco_gt.loadRes(predictions=coco_predictions)
+ image_ids = [ann['image_id'] for ann in coco_predictions]
+ # Class manipulation: 'all' split samples -> ignored_split = 0.
+ for idx, ann in enumerate(coco_gt.dataset['annotations']):
+ coco_gt.dataset['annotations'][idx]['ignored_split'] = 0
+ coco_eval = cocoeval.OlnCOCOevalXclassWrapper(
+ coco_gt, coco_dt, iou_type='bbox')
+ coco_eval.params.maxDets = [10, 20, 50, 100, 200]
+ coco_eval.params.imgIds = image_ids
+ coco_eval.params.useCats = 0 if not self._use_category else 1
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+ coco_metrics = coco_eval.stats
+
+ if self._include_mask:
+ mcoco_eval = cocoeval.OlnCOCOevalXclassWrapper(
+ coco_gt, coco_dt, iou_type='segm')
+ mcoco_eval.params.maxDets = [10, 20, 50, 100, 200]
+ mcoco_eval.params.imgIds = image_ids
+ mcoco_eval.params.useCats = 0 if not self._use_category else 1
+ mcoco_eval.evaluate()
+ mcoco_eval.accumulate()
+ mcoco_eval.summarize()
+ mask_coco_metrics = mcoco_eval.stats
+
+ if self._include_mask:
+ metrics = np.hstack((coco_metrics, mask_coco_metrics))
+ else:
+ metrics = coco_metrics
+
+ if self._seen_class != 'all':
+ # for seen class eval, samples of novel_class are ignored.
+ coco_gt_seen = copy.deepcopy(coco_gt)
+ for idx, ann in enumerate(coco_gt.dataset['annotations']):
+ if ann['category_id'] in self._seen_class_ids:
+ coco_gt_seen.dataset['annotations'][idx]['ignored_split'] = 0
+ else:
+ coco_gt_seen.dataset['annotations'][idx]['ignored_split'] = 1
+ coco_eval_seen = cocoeval.OlnCOCOevalXclassWrapper(
+ coco_gt_seen, coco_dt, iou_type='bbox')
+ coco_eval_seen.params.maxDets = [10, 20, 50, 100, 200]
+ coco_eval_seen.params.imgIds = image_ids
+ coco_eval_seen.params.useCats = 0 if not self._use_category else 1
+ coco_eval_seen.evaluate()
+ coco_eval_seen.accumulate()
+ coco_eval_seen.summarize()
+ coco_metrics_seen = coco_eval_seen.stats
+ if self._include_mask:
+ mcoco_eval_seen = cocoeval.OlnCOCOevalXclassWrapper(
+ coco_gt_seen, coco_dt, iou_type='segm')
+ mcoco_eval_seen.params.maxDets = [10, 20, 50, 100, 200]
+ mcoco_eval_seen.params.imgIds = image_ids
+ mcoco_eval_seen.params.useCats = 0 if not self._use_category else 1
+ mcoco_eval_seen.evaluate()
+ mcoco_eval_seen.accumulate()
+ mcoco_eval_seen.summarize()
+ mask_coco_metrics_seen = mcoco_eval_seen.stats
+
+ # for novel class eval, samples of seen_class are ignored.
+ coco_gt_novel = copy.deepcopy(coco_gt)
+ for idx, ann in enumerate(coco_gt.dataset['annotations']):
+ if ann['category_id'] in self._seen_class_ids:
+ coco_gt_novel.dataset['annotations'][idx]['ignored_split'] = 1
+ else:
+ coco_gt_novel.dataset['annotations'][idx]['ignored_split'] = 0
+ coco_eval_novel = cocoeval.OlnCOCOevalXclassWrapper(
+ coco_gt_novel, coco_dt, iou_type='bbox')
+ coco_eval_novel.params.maxDets = [10, 20, 50, 100, 200]
+ coco_eval_novel.params.imgIds = image_ids
+ coco_eval_novel.params.useCats = 0 if not self._use_category else 1
+ coco_eval_novel.evaluate()
+ coco_eval_novel.accumulate()
+ coco_eval_novel.summarize()
+ coco_metrics_novel = coco_eval_novel.stats
+ if self._include_mask:
+ mcoco_eval_novel = cocoeval.OlnCOCOevalXclassWrapper(
+ coco_gt_novel, coco_dt, iou_type='segm')
+ mcoco_eval_novel.params.maxDets = [10, 20, 50, 100, 200]
+ mcoco_eval_novel.params.imgIds = image_ids
+ mcoco_eval_novel.params.useCats = 0 if not self._use_category else 1
+ mcoco_eval_novel.evaluate()
+ mcoco_eval_novel.accumulate()
+ mcoco_eval_novel.summarize()
+ mask_coco_metrics_novel = mcoco_eval_novel.stats
+
+ # Combine all splits.
+ if self._include_mask:
+ metrics = np.hstack((
+ coco_metrics, coco_metrics_seen, coco_metrics_novel,
+ mask_coco_metrics, mask_coco_metrics_seen, mask_coco_metrics_novel))
+ else:
+ metrics = np.hstack((
+ coco_metrics, coco_metrics_seen, coco_metrics_novel))
+
+ # Cleans up the internal variables in order for a fresh eval next time.
+ self.reset()
+
+ metrics_dict = {}
+ for i, name in enumerate(self._metric_names):
+ metrics_dict[name] = metrics[i].astype(np.float32)
+ return metrics_dict
+
+
+class OlnXdataEvaluator(OlnXclassEvaluator):
+ """COCO evaluation metric class."""
+
+ def __init__(self, annotation_file, include_mask, need_rescale_bboxes=True,
+ use_category=True, seen_class='all'):
+ """Constructs COCO evaluation class.
+
+ The class provides the interface to metrics_fn in TPUEstimator. The
+ _update_op() takes detections from each image and push them to
+ self.detections. The _evaluate() loads a JSON file in COCO annotation format
+ as the groundtruths and runs COCO evaluation.
+
+ Args:
+ annotation_file: a JSON file that stores annotations of the eval dataset.
+ If `annotation_file` is None, groundtruth annotations will be loaded
+ from the dataloader.
+ include_mask: a boolean to indicate whether or not to include the mask
+ eval.
+ need_rescale_bboxes: If true bboxes in `predictions` will be rescaled back
+ to absolute values (`image_info` is needed in this case).
+ use_category: if `False`, treat all object in all classes in one
+ foreground category.
+ seen_class: 'all' or 'voc' or 'nonvoc'
+ """
+ super(OlnXdataEvaluator, self).__init__(
+ annotation_file=annotation_file,
+ include_mask=include_mask,
+ need_rescale_bboxes=need_rescale_bboxes,
+ use_category=False,
+ seen_class='all')
+
+ def evaluate(self):
+ """Evaluates with detections from all images with COCO API.
+
+ Returns:
+ coco_metric: float numpy array with shape [24] representing the
+ coco-style evaluation metrics (box and mask).
+ """
+ if not self._annotation_file:
+ logging.info('Thre is no annotation_file in COCOEvaluator.')
+ gt_dataset = coco_utils.convert_groundtruths_to_coco_dataset(
+ self._groundtruths)
+ coco_gt = coco_utils.COCOWrapper(
+ eval_type=('mask' if self._include_mask else 'box'),
+ gt_dataset=gt_dataset)
+ else:
+ logging.info('Using annotation file: %s', self._annotation_file)
+ coco_gt = self._coco_gt
+ coco_predictions = coco_utils.convert_predictions_to_coco_annotations(
+ self._predictions)
+ coco_dt = coco_gt.loadRes(predictions=coco_predictions)
+ image_ids = [ann['image_id'] for ann in coco_predictions]
+ # Class manipulation: 'all' split samples -> ignored_split = 0.
+ for idx, _ in enumerate(coco_gt.dataset['annotations']):
+ coco_gt.dataset['annotations'][idx]['ignored_split'] = 0
+ coco_eval = cocoeval.OlnCOCOevalWrapper(coco_gt, coco_dt, iou_type='bbox')
+ coco_eval.params.maxDets = [10, 20, 50, 100, 200]
+ coco_eval.params.imgIds = image_ids
+ coco_eval.params.useCats = 0 if not self._use_category else 1
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+ coco_metrics = coco_eval.stats
+
+ if self._include_mask:
+ mcoco_eval = cocoeval.OlnCOCOevalWrapper(coco_gt, coco_dt,
+ iou_type='segm')
+ mcoco_eval.params.maxDets = [10, 20, 50, 100, 200]
+ mcoco_eval.params.imgIds = image_ids
+ mcoco_eval.params.useCats = 0 if not self._use_category else 1
+ mcoco_eval.evaluate()
+ mcoco_eval.accumulate()
+ mcoco_eval.summarize()
+ mask_coco_metrics = mcoco_eval.stats
+
+ if self._include_mask:
+ metrics = np.hstack((coco_metrics, mask_coco_metrics))
+ else:
+ metrics = coco_metrics
+
+ # Cleans up the internal variables in order for a fresh eval next time.
+ self.reset()
+
+ metrics_dict = {}
+ for i, name in enumerate(self._metric_names):
+ metrics_dict[name] = metrics[i].astype(np.float32)
+ return metrics_dict
+
+
+class ShapeMaskCOCOEvaluator(COCOEvaluator):
+ """COCO evaluation metric class for ShapeMask."""
+
+ def __init__(self, mask_eval_class, **kwargs):
+ """Constructs COCO evaluation class.
+
+ The class provides the interface to metrics_fn in TPUEstimator. The
+ _update_op() takes detections from each image and push them to
+ self.detections. The _evaluate() loads a JSON file in COCO annotation format
+ as the groundtruths and runs COCO evaluation.
+
+ Args:
+ mask_eval_class: the set of classes for mask evaluation.
+ **kwargs: other keyword arguments passed to the parent class initializer.
+ """
+ super(ShapeMaskCOCOEvaluator, self).__init__(**kwargs)
+ self._mask_eval_class = mask_eval_class
+ self._eval_categories = class_utils.coco_split_class_ids(mask_eval_class)
+ if mask_eval_class != 'all':
+ self._metric_names = [
+ x.replace('mask', 'novel_mask') for x in self._metric_names
+ ]
+
+ def evaluate(self):
+ """Evaluates with detections from all images with COCO API.
+
+ Returns:
+ coco_metric: float numpy array with shape [24] representing the
+ coco-style evaluation metrics (box and mask).
+ """
+ if not self._annotation_file:
+ gt_dataset = coco_utils.convert_groundtruths_to_coco_dataset(
+ self._groundtruths)
+ coco_gt = coco_utils.COCOWrapper(
+ eval_type=('mask' if self._include_mask else 'box'),
+ gt_dataset=gt_dataset)
+ else:
+ coco_gt = self._coco_gt
+ coco_predictions = coco_utils.convert_predictions_to_coco_annotations(
+ self._predictions)
+ coco_dt = coco_gt.loadRes(predictions=coco_predictions)
+ image_ids = [ann['image_id'] for ann in coco_predictions]
+
+ coco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='bbox')
+ coco_eval.params.imgIds = image_ids
+ coco_eval.evaluate()
+ coco_eval.accumulate()
+ coco_eval.summarize()
+ coco_metrics = coco_eval.stats
+
+ if self._include_mask:
+ mcoco_eval = cocoeval.COCOeval(coco_gt, coco_dt, iouType='segm')
+ mcoco_eval.params.imgIds = image_ids
+ mcoco_eval.evaluate()
+ mcoco_eval.accumulate()
+ mcoco_eval.summarize()
+ if self._mask_eval_class == 'all':
+ metrics = np.hstack((coco_metrics, mcoco_eval.stats))
+ else:
+ mask_coco_metrics = mcoco_eval.category_stats
+ val_catg_idx = np.isin(mcoco_eval.params.catIds, self._eval_categories)
+ # Gather the valid evaluation of the eval categories.
+ if np.any(val_catg_idx):
+ mean_val_metrics = []
+ for mid in range(len(self._metric_names) // 2):
+ mean_val_metrics.append(
+ np.nanmean(mask_coco_metrics[mid][val_catg_idx]))
+
+ mean_val_metrics = np.array(mean_val_metrics)
+ else:
+ mean_val_metrics = np.zeros(len(self._metric_names) // 2)
+ metrics = np.hstack((coco_metrics, mean_val_metrics))
+ else:
+ metrics = coco_metrics
+
+ # Cleans up the internal variables in order for a fresh eval next time.
+ self.reset()
+
+ metrics_dict = {}
+ for i, name in enumerate(self._metric_names):
+ metrics_dict[name] = metrics[i].astype(np.float32)
+ return metrics_dict
diff --git a/modeling/official/legacy/detection/evaluation/coco_utils.py b/modeling/official/legacy/detection/evaluation/coco_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d25814a3a518701e91ed07e1edcbf66c9985ae9d
--- /dev/null
+++ b/modeling/official/legacy/detection/evaluation/coco_utils.py
@@ -0,0 +1,372 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Util functions related to pycocotools and COCO eval."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import json
+
+from absl import logging
+import numpy as np
+from PIL import Image
+from pycocotools import coco
+from pycocotools import mask as mask_api
+import six
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.dataloader import tf_example_decoder
+from official.legacy.detection.utils import box_utils
+from official.legacy.detection.utils import mask_utils
+
+
+class COCOWrapper(coco.COCO):
+ """COCO wrapper class.
+
+ This class wraps COCO API object, which provides the following additional
+ functionalities:
+ 1. Support string type image id.
+ 2. Support loading the groundtruth dataset using the external annotation
+ dictionary.
+ 3. Support loading the prediction results using the external annotation
+ dictionary.
+ """
+
+ def __init__(self, eval_type='box', annotation_file=None, gt_dataset=None):
+ """Instantiates a COCO-style API object.
+
+ Args:
+ eval_type: either 'box' or 'mask'.
+ annotation_file: a JSON file that stores annotations of the eval dataset.
+ This is required if `gt_dataset` is not provided.
+ gt_dataset: the groundtruth eval datatset in COCO API format.
+ """
+ if ((annotation_file and gt_dataset) or
+ ((not annotation_file) and (not gt_dataset))):
+ raise ValueError('One and only one of `annotation_file` and `gt_dataset` '
+ 'needs to be specified.')
+
+ if eval_type not in ['box', 'mask']:
+ raise ValueError('The `eval_type` can only be either `box` or `mask`.')
+
+ coco.COCO.__init__(self, annotation_file=annotation_file)
+ self._eval_type = eval_type
+ if gt_dataset:
+ self.dataset = gt_dataset
+ self.createIndex()
+
+ def loadRes(self, predictions):
+ """Loads result file and return a result api object.
+
+ Args:
+ predictions: a list of dictionary each representing an annotation in COCO
+ format. The required fields are `image_id`, `category_id`, `score`,
+ `bbox`, `segmentation`.
+
+ Returns:
+ res: result COCO api object.
+
+ Raises:
+ ValueError: if the set of image id from predctions is not the subset of
+ the set of image id of the groundtruth dataset.
+ """
+ res = coco.COCO()
+ res.dataset['images'] = copy.deepcopy(self.dataset['images'])
+ res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
+
+ image_ids = [ann['image_id'] for ann in predictions]
+ if set(image_ids) != (set(image_ids) & set(self.getImgIds())):
+ raise ValueError('Results do not correspond to the current dataset!')
+ for ann in predictions:
+ x1, x2, y1, y2 = [ann['bbox'][0], ann['bbox'][0] + ann['bbox'][2],
+ ann['bbox'][1], ann['bbox'][1] + ann['bbox'][3]]
+ if self._eval_type == 'box':
+ ann['area'] = ann['bbox'][2] * ann['bbox'][3]
+ ann['segmentation'] = [
+ [x1, y1, x1, y2, x2, y2, x2, y1]]
+ elif self._eval_type == 'mask':
+ ann['area'] = mask_api.area(ann['segmentation'])
+
+ res.dataset['annotations'] = copy.deepcopy(predictions)
+ res.createIndex()
+ return res
+
+
+def convert_predictions_to_coco_annotations(predictions):
+ """Converts a batch of predictions to annotations in COCO format.
+
+ Args:
+ predictions: a dictionary of lists of numpy arrays including the following
+ fields. K below denotes the maximum number of instances per image.
+ Required fields:
+ - source_id: a list of numpy arrays of int or string of shape
+ [batch_size].
+ - num_detections: a list of numpy arrays of int of shape [batch_size].
+ - detection_boxes: a list of numpy arrays of float of shape
+ [batch_size, K, 4], where coordinates are in the original image
+ space (not the scaled image space).
+ - detection_classes: a list of numpy arrays of int of shape
+ [batch_size, K].
+ - detection_scores: a list of numpy arrays of float of shape
+ [batch_size, K].
+ Optional fields:
+ - detection_masks: a list of numpy arrays of float of shape
+ [batch_size, K, mask_height, mask_width].
+
+ Returns:
+ coco_predictions: prediction in COCO annotation format.
+ """
+ coco_predictions = []
+ num_batches = len(predictions['source_id'])
+ batch_size = predictions['source_id'][0].shape[0]
+ max_num_detections = predictions['detection_classes'][0].shape[1]
+ use_outer_box = 'detection_outer_boxes' in predictions
+ for i in range(num_batches):
+ predictions['detection_boxes'][i] = box_utils.yxyx_to_xywh(
+ predictions['detection_boxes'][i])
+ if use_outer_box:
+ predictions['detection_outer_boxes'][i] = box_utils.yxyx_to_xywh(
+ predictions['detection_outer_boxes'][i])
+ mask_boxes = predictions['detection_outer_boxes']
+ else:
+ mask_boxes = predictions['detection_boxes']
+
+ for j in range(batch_size):
+ if 'detection_masks' in predictions:
+ image_masks = mask_utils.paste_instance_masks(
+ predictions['detection_masks'][i][j],
+ mask_boxes[i][j],
+ int(predictions['image_info'][i][j, 0, 0]),
+ int(predictions['image_info'][i][j, 0, 1]))
+ binary_masks = (image_masks > 0.0).astype(np.uint8)
+ encoded_masks = [
+ mask_api.encode(np.asfortranarray(binary_mask))
+ for binary_mask in list(binary_masks)]
+ for k in range(max_num_detections):
+ ann = {}
+ ann['image_id'] = predictions['source_id'][i][j]
+ ann['category_id'] = predictions['detection_classes'][i][j, k]
+ ann['bbox'] = predictions['detection_boxes'][i][j, k]
+ ann['score'] = predictions['detection_scores'][i][j, k]
+ if 'detection_masks' in predictions:
+ ann['segmentation'] = encoded_masks[k]
+ coco_predictions.append(ann)
+
+ for i, ann in enumerate(coco_predictions):
+ ann['id'] = i + 1
+
+ return coco_predictions
+
+
+def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None):
+ """Converts groundtruths to the dataset in COCO format.
+
+ Args:
+ groundtruths: a dictionary of numpy arrays including the fields below.
+ Note that each element in the list represent the number for a single
+ example without batch dimension. K below denotes the actual number of
+ instances for each image.
+ Required fields:
+ - source_id: a list of numpy arrays of int or string of shape
+ [batch_size].
+ - height: a list of numpy arrays of int of shape [batch_size].
+ - width: a list of numpy arrays of int of shape [batch_size].
+ - num_detections: a list of numpy arrays of int of shape [batch_size].
+ - boxes: a list of numpy arrays of float of shape [batch_size, K, 4],
+ where coordinates are in the original image space (not the
+ normalized coordinates).
+ - classes: a list of numpy arrays of int of shape [batch_size, K].
+ Optional fields:
+ - is_crowds: a list of numpy arrays of int of shape [batch_size, K]. If
+ th field is absent, it is assumed that this instance is not crowd.
+ - areas: a list of numy arrays of float of shape [batch_size, K]. If the
+ field is absent, the area is calculated using either boxes or
+ masks depending on which one is available.
+ - masks: a list of numpy arrays of string of shape [batch_size, K],
+ label_map: (optional) a dictionary that defines items from the category id
+ to the category name. If `None`, collect the category mappping from the
+ `groundtruths`.
+
+ Returns:
+ coco_groundtruths: the groundtruth dataset in COCO format.
+ """
+ source_ids = np.concatenate(groundtruths['source_id'], axis=0)
+ heights = np.concatenate(groundtruths['height'], axis=0)
+ widths = np.concatenate(groundtruths['width'], axis=0)
+ gt_images = [{'id': int(i), 'height': int(h), 'width': int(w)} for i, h, w
+ in zip(source_ids, heights, widths)]
+
+ gt_annotations = []
+ num_batches = len(groundtruths['source_id'])
+ batch_size = groundtruths['source_id'][0].shape[0]
+ for i in range(num_batches):
+ for j in range(batch_size):
+ num_instances = groundtruths['num_detections'][i][j]
+ for k in range(num_instances):
+ ann = {}
+ ann['image_id'] = int(groundtruths['source_id'][i][j])
+ if 'is_crowds' in groundtruths:
+ ann['iscrowd'] = int(groundtruths['is_crowds'][i][j, k])
+ else:
+ ann['iscrowd'] = 0
+ ann['category_id'] = int(groundtruths['classes'][i][j, k])
+ boxes = groundtruths['boxes'][i]
+ ann['bbox'] = [
+ float(boxes[j, k, 1]),
+ float(boxes[j, k, 0]),
+ float(boxes[j, k, 3] - boxes[j, k, 1]),
+ float(boxes[j, k, 2] - boxes[j, k, 0])]
+ if 'areas' in groundtruths:
+ ann['area'] = float(groundtruths['areas'][i][j, k])
+ else:
+ ann['area'] = float(
+ (boxes[j, k, 3] - boxes[j, k, 1]) *
+ (boxes[j, k, 2] - boxes[j, k, 0]))
+ if 'masks' in groundtruths:
+ mask = Image.open(six.BytesIO(groundtruths['masks'][i][j, k]))
+ np_mask = np.array(mask, dtype=np.uint8)
+ np_mask[np_mask > 0] = 255
+ encoded_mask = mask_api.encode(np.asfortranarray(np_mask))
+ ann['segmentation'] = encoded_mask
+ if 'areas' not in groundtruths:
+ ann['area'] = mask_api.area(encoded_mask)
+ gt_annotations.append(ann)
+
+ for i, ann in enumerate(gt_annotations):
+ ann['id'] = i + 1
+
+ if label_map:
+ gt_categories = [{'id': i, 'name': label_map[i]} for i in label_map]
+ else:
+ category_ids = [gt['category_id'] for gt in gt_annotations]
+ gt_categories = [{'id': i} for i in set(category_ids)]
+
+ gt_dataset = {
+ 'images': gt_images,
+ 'categories': gt_categories,
+ 'annotations': copy.deepcopy(gt_annotations),
+ }
+ return gt_dataset
+
+
+class COCOGroundtruthGenerator(object):
+ """Generates the groundtruth annotations from a single example."""
+
+ def __init__(self, file_pattern, num_examples, include_mask):
+ self._file_pattern = file_pattern
+ self._num_examples = num_examples
+ self._include_mask = include_mask
+ self._dataset_fn = tf.data.TFRecordDataset
+
+ def _parse_single_example(self, example):
+ """Parses a single serialized tf.Example proto.
+
+ Args:
+ example: a serialized tf.Example proto string.
+
+ Returns:
+ A dictionary of groundtruth with the following fields:
+ source_id: a scalar tensor of int64 representing the image source_id.
+ height: a scalar tensor of int64 representing the image height.
+ width: a scalar tensor of int64 representing the image width.
+ boxes: a float tensor of shape [K, 4], representing the groundtruth
+ boxes in absolute coordinates with respect to the original image size.
+ classes: a int64 tensor of shape [K], representing the class labels of
+ each instances.
+ is_crowds: a bool tensor of shape [K], indicating whether the instance
+ is crowd.
+ areas: a float tensor of shape [K], indicating the area of each
+ instance.
+ masks: a string tensor of shape [K], containing the bytes of the png
+ mask of each instance.
+ """
+ decoder = tf_example_decoder.TfExampleDecoder(
+ include_mask=self._include_mask)
+ decoded_tensors = decoder.decode(example)
+
+ image = decoded_tensors['image']
+ image_size = tf.shape(image)[0:2]
+ boxes = box_utils.denormalize_boxes(
+ decoded_tensors['groundtruth_boxes'], image_size)
+ groundtruths = {
+ 'source_id': tf.string_to_number(
+ decoded_tensors['source_id'], out_type=tf.int64),
+ 'height': decoded_tensors['height'],
+ 'width': decoded_tensors['width'],
+ 'num_detections': tf.shape(decoded_tensors['groundtruth_classes'])[0],
+ 'boxes': boxes,
+ 'classes': decoded_tensors['groundtruth_classes'],
+ 'is_crowds': decoded_tensors['groundtruth_is_crowd'],
+ 'areas': decoded_tensors['groundtruth_area'],
+ }
+ if self._include_mask:
+ groundtruths.update({
+ 'masks': decoded_tensors['groundtruth_instance_masks_png'],
+ })
+ return groundtruths
+
+ def _build_pipeline(self):
+ """Builds data pipeline to generate groundtruth annotations."""
+ dataset = tf.data.Dataset.list_files(self._file_pattern, shuffle=False)
+ dataset = dataset.apply(
+ tf.data.experimental.parallel_interleave(
+ lambda filename: self._dataset_fn(filename).prefetch(1),
+ cycle_length=32,
+ sloppy=False))
+ dataset = dataset.map(self._parse_single_example, num_parallel_calls=64)
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+ dataset = dataset.batch(1, drop_remainder=False)
+ return dataset
+
+ def __call__(self):
+ with tf.Graph().as_default():
+ dataset = self._build_pipeline()
+ groundtruth = dataset.make_one_shot_iterator().get_next()
+
+ with tf.Session() as sess:
+ for _ in range(self._num_examples):
+ groundtruth_result = sess.run(groundtruth)
+ yield groundtruth_result
+
+
+def scan_and_generator_annotation_file(file_pattern,
+ num_samples,
+ include_mask,
+ annotation_file):
+ """Scans and generate the COCO-style annotation JSON file given a dataset."""
+ groundtruth_generator = COCOGroundtruthGenerator(
+ file_pattern, num_samples, include_mask)
+ generate_annotation_file(groundtruth_generator, annotation_file)
+
+
+def generate_annotation_file(groundtruth_generator,
+ annotation_file):
+ """Generates COCO-style annotation JSON file given a groundtruth generator."""
+ groundtruths = {}
+ logging.info('Loading groundtruth annotations from dataset to memory...')
+ for groundtruth in groundtruth_generator():
+ for k, v in six.iteritems(groundtruth):
+ if k not in groundtruths:
+ groundtruths[k] = [v]
+ else:
+ groundtruths[k].append(v)
+ gt_dataset = convert_groundtruths_to_coco_dataset(groundtruths)
+
+ logging.info('Saving groundtruth annotations to the JSON file...')
+ with tf.io.gfile.GFile(annotation_file, 'w') as f:
+ f.write(json.dumps(gt_dataset))
+ logging.info('Done saving the JSON file...')
diff --git a/modeling/official/legacy/detection/evaluation/factory.py b/modeling/official/legacy/detection/evaluation/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aa51b885126b356c2b9c97a827474099e115618
--- /dev/null
+++ b/modeling/official/legacy/detection/evaluation/factory.py
@@ -0,0 +1,52 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluator factory."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from official.legacy.detection.evaluation import coco_evaluator
+
+
+def evaluator_generator(params):
+ """Generator function for various evaluators."""
+ if params.type == 'box':
+ evaluator = coco_evaluator.COCOEvaluator(
+ annotation_file=params.val_json_file, include_mask=False)
+ elif params.type == 'box_and_mask':
+ evaluator = coco_evaluator.COCOEvaluator(
+ annotation_file=params.val_json_file, include_mask=True)
+ elif params.type == 'oln_xclass_box':
+ evaluator = coco_evaluator.OlnXclassEvaluator(
+ annotation_file=params.val_json_file, include_mask=False,
+ use_category=False, seen_class=params.seen_class,)
+ elif params.type == 'oln_xclass_box_and_mask':
+ evaluator = coco_evaluator.OlnXclassEvaluator(
+ annotation_file=params.val_json_file, include_mask=True,
+ use_category=False, seen_class=params.seen_class,)
+ elif params.type == 'oln_xdata_box':
+ evaluator = coco_evaluator.OlnXdataEvaluator(
+ annotation_file=params.val_json_file, include_mask=False,
+ use_category=False, seen_class='all',)
+ elif params.type == 'shapemask_box_and_mask':
+ evaluator = coco_evaluator.ShapeMaskCOCOEvaluator(
+ mask_eval_class=params.mask_eval_class,
+ annotation_file=params.val_json_file, include_mask=True)
+
+ else:
+ raise ValueError('Evaluator %s is not supported.' % params.type)
+
+ return coco_evaluator.MetricWrapper(evaluator)
diff --git a/modeling/official/legacy/detection/executor/__init__.py b/modeling/official/legacy/detection/executor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/legacy/detection/executor/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/legacy/detection/executor/detection_executor.py b/modeling/official/legacy/detection/executor/detection_executor.py
new file mode 100644
index 0000000000000000000000000000000000000000..506a2947213df3ba149d60b697f3e676f2d32f1f
--- /dev/null
+++ b/modeling/official/legacy/detection/executor/detection_executor.py
@@ -0,0 +1,159 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""An executor class for running model on TensorFlow 2.0."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import logging
+
+import tensorflow as tf, tf_keras
+from official.legacy.detection.executor import distributed_executor as executor
+from official.vision.utils.object_detection import visualization_utils
+
+
+class DetectionDistributedExecutor(executor.DistributedExecutor):
+ """Detection specific customer training loop executor.
+
+ Subclasses the DistributedExecutor and adds support for numpy based metrics.
+ """
+
+ def __init__(self,
+ predict_post_process_fn=None,
+ trainable_variables_filter=None,
+ **kwargs):
+ super(DetectionDistributedExecutor, self).__init__(**kwargs)
+ if predict_post_process_fn:
+ assert callable(predict_post_process_fn)
+ if trainable_variables_filter:
+ assert callable(trainable_variables_filter)
+ self._predict_post_process_fn = predict_post_process_fn
+ self._trainable_variables_filter = trainable_variables_filter
+ self.eval_steps = tf.Variable(
+ 0,
+ trainable=False,
+ dtype=tf.int32,
+ synchronization=tf.VariableSynchronization.ON_READ,
+ aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
+ shape=[])
+
+ def _create_replicated_step(self,
+ strategy,
+ model,
+ loss_fn,
+ optimizer,
+ metric=None):
+ trainable_variables = model.trainable_variables
+ if self._trainable_variables_filter:
+ trainable_variables = self._trainable_variables_filter(
+ trainable_variables)
+ logging.info('Filter trainable variables from %d to %d',
+ len(model.trainable_variables), len(trainable_variables))
+ update_state_fn = lambda labels, outputs: None
+ if isinstance(metric, tf_keras.metrics.Metric):
+ update_state_fn = metric.update_state
+ else:
+ logging.error('Detection: train metric is not an instance of '
+ 'tf_keras.metrics.Metric.')
+
+ def _replicated_step(inputs):
+ """Replicated training step."""
+ inputs, labels = inputs
+
+ with tf.GradientTape() as tape:
+ outputs = model(inputs, training=True)
+ all_losses = loss_fn(labels, outputs)
+ losses = {}
+ for k, v in all_losses.items():
+ losses[k] = tf.reduce_mean(v)
+ per_replica_loss = losses['total_loss'] / strategy.num_replicas_in_sync
+ update_state_fn(labels, outputs)
+
+ grads = tape.gradient(per_replica_loss, trainable_variables)
+ clipped_grads, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)
+ optimizer.apply_gradients(zip(clipped_grads, trainable_variables))
+ return losses
+
+ return _replicated_step
+
+ def _create_test_step(self, strategy, model, metric):
+ """Creates a distributed test step."""
+
+ @tf.function
+ def test_step(iterator, eval_steps):
+ """Calculates evaluation metrics on distributed devices."""
+
+ def _test_step_fn(inputs, eval_steps):
+ """Replicated accuracy calculation."""
+ inputs, labels = inputs
+ model_outputs = model(inputs, training=False)
+ if self._predict_post_process_fn:
+ labels, prediction_outputs = self._predict_post_process_fn(
+ labels, model_outputs)
+ num_remaining_visualizations = (
+ self._params.eval.num_images_to_visualize - eval_steps)
+ # If there are remaining number of visualizations that needs to be
+ # done, add next batch outputs for visualization.
+ #
+ # TODO(hongjunchoi): Once dynamic slicing is supported on TPU, only
+ # write correct slice of outputs to summary file.
+ if num_remaining_visualizations > 0:
+ visualization_utils.visualize_images_with_bounding_boxes(
+ inputs, prediction_outputs['detection_boxes'],
+ self.global_train_step, self.eval_summary_writer)
+
+ return labels, prediction_outputs
+
+ labels, outputs = strategy.run(
+ _test_step_fn, args=(
+ next(iterator),
+ eval_steps,
+ ))
+ outputs = tf.nest.map_structure(strategy.experimental_local_results,
+ outputs)
+ labels = tf.nest.map_structure(strategy.experimental_local_results,
+ labels)
+
+ eval_steps.assign_add(self._params.eval.batch_size)
+ return labels, outputs
+
+ return test_step
+
+ def _run_evaluation(self, test_step, current_training_step, metric,
+ test_iterator):
+ """Runs validation steps and aggregate metrics."""
+ self.eval_steps.assign(0)
+ if not test_iterator or not metric:
+ logging.warning(
+ 'Both test_iterator (%s) and metrics (%s) must not be None.',
+ test_iterator, metric)
+ return None
+ logging.info('Running evaluation after step: %s.', current_training_step)
+ while True:
+ try:
+ labels, outputs = test_step(test_iterator, self.eval_steps)
+ if metric:
+ metric.update_state(labels, outputs)
+ except (StopIteration, tf.errors.OutOfRangeError):
+ break
+
+ metric_result = metric.result()
+ if isinstance(metric, tf_keras.metrics.Metric):
+ metric_result = tf.nest.map_structure(lambda x: x.numpy().astype(float),
+ metric_result)
+ logging.info('Step: [%d] Validation metric = %s', current_training_step,
+ metric_result)
+ return metric_result
diff --git a/modeling/official/legacy/detection/executor/distributed_executor.py b/modeling/official/legacy/detection/executor/distributed_executor.py
new file mode 100644
index 0000000000000000000000000000000000000000..19fcf07113e61139250ea5b5490dd4100012e2c6
--- /dev/null
+++ b/modeling/official/legacy/detection/executor/distributed_executor.py
@@ -0,0 +1,811 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Custom training loop for running TensorFlow 2.0 models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+from typing import Optional, Dict, List, Text, Callable, Union, Iterator, Any
+
+from absl import flags
+from absl import logging
+
+import numpy as np
+import tensorflow as tf, tf_keras
+
+# pylint: disable=unused-import,g-import-not-at-top,redefined-outer-name,reimported
+from official.common import distribute_utils
+from official.modeling.hyperparams import params_dict
+from official.utils import hyperparams_flags
+from official.utils.misc import keras_utils
+
+FLAGS = flags.FLAGS
+
+strategy_flags_dict = hyperparams_flags.strategy_flags_dict
+hparam_flags_dict = hyperparams_flags.hparam_flags_dict
+
+
+def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
+ """Saves model to model_dir with provided checkpoint prefix."""
+
+ checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
+ saved_path = checkpoint.save(checkpoint_path)
+ logging.info('Saving model as TF checkpoint: %s', saved_path)
+
+
+def _steps_to_run(current_step, total_steps, steps_per_loop):
+ """Calculates steps to run on device."""
+ if steps_per_loop <= 0:
+ raise ValueError('steps_per_loop should be positive integer.')
+ return min(total_steps - current_step, steps_per_loop)
+
+
+def _no_metric():
+ return None
+
+
+def metrics_as_dict(metric):
+ """Puts input metric(s) into a list.
+
+ Args:
+ metric: metric(s) to be put into the list. `metric` could be an object, a
+ list, or a dict of tf_keras.metrics.Metric or has the `required_method`.
+
+ Returns:
+ A dictionary of valid metrics.
+ """
+ if isinstance(metric, tf_keras.metrics.Metric):
+ metrics = {metric.name: metric}
+ elif isinstance(metric, list):
+ metrics = {m.name: m for m in metric}
+ elif isinstance(metric, dict):
+ metrics = metric
+ elif not metric:
+ return {}
+ else:
+ metrics = {'metric': metric}
+ return metrics
+
+
+def metric_results(metric):
+ """Collects results from the given metric(s)."""
+ metrics = metrics_as_dict(metric)
+ metric_result = {
+ name: m.result().numpy().astype(float) for name, m in metrics.items()
+ }
+ return metric_result
+
+
+def reset_states(metric):
+ """Resets states of the given metric(s)."""
+ metrics = metrics_as_dict(metric)
+ for m in metrics.values():
+ m.reset_states()
+
+
+class SummaryWriter(object):
+ """Simple SummaryWriter for writing dictionary of metrics.
+
+ Attributes:
+ writer: The tf.SummaryWriter.
+ """
+
+ def __init__(self, model_dir: Text, name: Text):
+ """Inits SummaryWriter with paths.
+
+ Args:
+ model_dir: the model folder path.
+ name: the summary subfolder name.
+ """
+ self.writer = tf.summary.create_file_writer(os.path.join(model_dir, name))
+
+ def __call__(self, metrics: Union[Dict[Text, float], float], step: int):
+ """Write metrics to summary with the given writer.
+
+ Args:
+ metrics: a dictionary of metrics values. Prefer dictionary.
+ step: integer. The training step.
+ """
+ if not isinstance(metrics, dict):
+ # Support scalar metric without name.
+ logging.warning('Warning: summary writer prefer metrics as dictionary.')
+ metrics = {'metric': metrics}
+
+ with self.writer.as_default():
+ for k, v in metrics.items():
+ tf.summary.scalar(k, v, step=step)
+ self.writer.flush()
+
+
+class DistributedExecutor(object):
+ """Interface to train and eval models with tf.distribute.Strategy."""
+
+ def __init__(self, strategy, params, model_fn, loss_fn, is_multi_host=False):
+ """Constructor.
+
+ Args:
+ strategy: an instance of tf.distribute.Strategy.
+ params: Model configuration needed to run distribution strategy.
+ model_fn: Keras model function. Signature:
+ (params: ParamsDict) -> tf_keras.models.Model.
+ loss_fn: loss function. Signature:
+ (y_true: Tensor, y_pred: Tensor) -> Tensor
+ is_multi_host: Set to True when using multi hosts for training, like multi
+ worker GPU or TPU pod (slice). Otherwise, False.
+ """
+
+ self._params = params
+ self._model_fn = model_fn
+ self._loss_fn = loss_fn
+ self._strategy = strategy
+ self._checkpoint_name = 'ctl_step_{step}.ckpt'
+ self._is_multi_host = is_multi_host
+ self.train_summary_writer = None
+ self.eval_summary_writer = None
+ self.global_train_step = None
+
+ @property
+ def checkpoint_name(self):
+ """Returns default checkpoint name."""
+ return self._checkpoint_name
+
+ @checkpoint_name.setter
+ def checkpoint_name(self, name):
+ """Sets default summary writer for the current thread."""
+ self._checkpoint_name = name
+
+ def loss_fn(self):
+ return self._loss_fn()
+
+ def model_fn(self, params):
+ return self._model_fn(params)
+
+ def _save_config(self, model_dir):
+ """Save parameters to config files if model_dir is defined."""
+
+ logging.info('Save config to model_dir %s.', model_dir)
+ if model_dir:
+ if not tf.io.gfile.exists(model_dir):
+ tf.io.gfile.makedirs(model_dir)
+ self._params.lock()
+ params_dict.save_params_dict_to_yaml(self._params,
+ model_dir + '/params.yaml')
+ else:
+ logging.warning('model_dir is empty, so skip the save config.')
+
+ def _get_input_iterator(
+ self, input_fn: Callable[..., tf.data.Dataset],
+ strategy: tf.distribute.Strategy) -> Optional[Iterator[Any]]:
+ """Returns distributed dataset iterator.
+
+ Args:
+ input_fn: (params: dict) -> tf.data.Dataset.
+ strategy: an instance of tf.distribute.Strategy.
+
+ Returns:
+ An iterator that yields input tensors.
+ """
+
+ if input_fn is None:
+ return None
+ # When training with multiple TPU workers, datasets needs to be cloned
+ # across workers. Since Dataset instance cannot be cloned in eager mode,
+ # we instead pass callable that returns a dataset.
+ if self._is_multi_host:
+ return iter(strategy.distribute_datasets_from_function(input_fn))
+ else:
+ input_data = input_fn()
+ return iter(strategy.experimental_distribute_dataset(input_data))
+
+ def _create_replicated_step(self,
+ strategy,
+ model,
+ loss_fn,
+ optimizer,
+ metric=None):
+ """Creates a single training step.
+
+ Args:
+ strategy: an instance of tf.distribute.Strategy.
+ model: (Tensor, bool) -> Tensor. model function.
+ loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
+ optimizer: tf_keras.optimizers.Optimizer.
+ metric: tf_keras.metrics.Metric subclass.
+
+ Returns:
+ The training step callable.
+ """
+ metrics = metrics_as_dict(metric)
+
+ def _replicated_step(inputs):
+ """Replicated training step."""
+ inputs, labels = inputs
+
+ with tf.GradientTape() as tape:
+ outputs = model(inputs, training=True)
+ prediction_loss = loss_fn(labels, outputs)
+ loss = tf.reduce_mean(prediction_loss)
+ loss = loss / strategy.num_replicas_in_sync
+ for m in metrics.values():
+ m.update_state(labels, outputs)
+
+ grads = tape.gradient(loss, model.trainable_variables)
+ optimizer.apply_gradients(zip(grads, model.trainable_variables))
+ return loss
+
+ return _replicated_step
+
+ def _create_train_step(self,
+ strategy,
+ model,
+ loss_fn,
+ optimizer,
+ metric=None):
+ """Creates a distributed training step.
+
+ Args:
+ strategy: an instance of tf.distribute.Strategy.
+ model: (Tensor, bool) -> Tensor. model function.
+ loss_fn: (y_true: Tensor, y_pred: Tensor) -> Tensor.
+ optimizer: tf_keras.optimizers.Optimizer.
+ metric: tf_keras.metrics.Metric subclass.
+
+ Returns:
+ The training step callable.
+ """
+ replicated_step = self._create_replicated_step(strategy, model, loss_fn,
+ optimizer, metric)
+
+ @tf.function
+ def train_step(iterator, num_steps):
+ """Performs a distributed training step.
+
+ Args:
+ iterator: an iterator that yields input tensors.
+ num_steps: the number of steps in the loop.
+
+ Returns:
+ The loss tensor.
+ """
+ if not isinstance(num_steps, tf.Tensor):
+ raise ValueError('steps should be an Tensor. Python object may cause '
+ 'retracing.')
+
+ per_replica_losses = strategy.run(replicated_step, args=(next(iterator),))
+ for _ in tf.range(num_steps - 1):
+ per_replica_losses = strategy.run(
+ replicated_step, args=(next(iterator),))
+
+ # For reporting, we returns the mean of losses.
+ losses = tf.nest.map_structure(
+ lambda x: strategy.reduce(tf.distribute.ReduceOp.MEAN, x, axis=None),
+ per_replica_losses)
+ return losses
+
+ return train_step
+
+ def _create_test_step(self, strategy, model, metric):
+ """Creates a distributed test step."""
+ metrics = metrics_as_dict(metric)
+
+ @tf.function
+ def test_step(iterator):
+ """Calculates evaluation metrics on distributed devices."""
+ if not metric:
+ logging.info('Skip test_step because metric is None (%s)', metric)
+ return None, None
+
+ def _test_step_fn(inputs):
+ """Replicated accuracy calculation."""
+ inputs, labels = inputs
+ model_outputs = model(inputs, training=False)
+ for m in metrics.values():
+ m.update_state(labels, model_outputs)
+ return labels, model_outputs
+
+ return strategy.run(_test_step_fn, args=(next(iterator),))
+
+ return test_step
+
+ def train(
+ self,
+ train_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
+ eval_input_fn: Optional[Callable[[params_dict.ParamsDict],
+ tf.data.Dataset]] = None,
+ model_dir: Optional[Text] = None,
+ total_steps: int = 1,
+ iterations_per_loop: int = 1,
+ train_metric_fn: Optional[Callable[[], Any]] = None,
+ eval_metric_fn: Optional[Callable[[], Any]] = None,
+ summary_writer_fn: Callable[[Text, Text], SummaryWriter] = SummaryWriter,
+ init_checkpoint: Optional[Callable[[tf_keras.Model], Any]] = None,
+ custom_callbacks: Optional[List[tf_keras.callbacks.Callback]] = None,
+ continuous_eval: bool = False,
+ save_config: bool = True):
+ """Runs distributed training.
+
+ Args:
+ train_input_fn: (params: dict) -> tf.data.Dataset training data input
+ function.
+ eval_input_fn: (Optional) same type as train_input_fn. If not None, will
+ trigger evaluating metric on eval data. If None, will not run the eval
+ step.
+ model_dir: the folder path for model checkpoints.
+ total_steps: total training steps.
+ iterations_per_loop: train steps per loop. After each loop, this job will
+ update metrics like loss and save checkpoint.
+ train_metric_fn: metric_fn for evaluation in train_step.
+ eval_metric_fn: metric_fn for evaluation in test_step.
+ summary_writer_fn: function to create summary writer.
+ init_checkpoint: function to load checkpoint.
+ custom_callbacks: A list of Keras Callbacks objects to run during
+ training. More specifically, `on_batch_begin()`, `on_batch_end()`,
+ methods are invoked during training.
+ continuous_eval: If `True`, will continously run evaluation on every
+ available checkpoints. If `False`, will do the evaluation once after the
+ final step.
+ save_config: bool. Whether to save params to model_dir.
+
+ Returns:
+ The training loss and eval metrics.
+ """
+ assert train_input_fn is not None
+ if train_metric_fn and not callable(train_metric_fn):
+ raise ValueError('if `train_metric_fn` is specified, '
+ 'train_metric_fn must be a callable.')
+ if eval_metric_fn and not callable(eval_metric_fn):
+ raise ValueError('if `eval_metric_fn` is specified, '
+ 'eval_metric_fn must be a callable.')
+ train_metric_fn = train_metric_fn or _no_metric
+ eval_metric_fn = eval_metric_fn or _no_metric
+
+ if custom_callbacks and iterations_per_loop != 1:
+ logging.warning(
+ 'It is sematically wrong to run callbacks when '
+ 'iterations_per_loop is not one (%s)', iterations_per_loop)
+
+ custom_callbacks = custom_callbacks or []
+
+ def _run_callbacks_on_batch_begin(batch):
+ """Runs custom callbacks at the start of every step."""
+ if not custom_callbacks:
+ return
+ for callback in custom_callbacks:
+ if callback:
+ callback.on_batch_begin(batch)
+
+ def _run_callbacks_on_batch_end(batch):
+ """Runs custom callbacks at the end of every step."""
+ if not custom_callbacks:
+ return
+ for callback in custom_callbacks:
+ if callback:
+ callback.on_batch_end(batch)
+
+ if save_config:
+ self._save_config(model_dir)
+
+ if FLAGS.save_checkpoint_freq:
+ save_freq = FLAGS.save_checkpoint_freq
+ else:
+ save_freq = iterations_per_loop
+
+ params = self._params
+ strategy = self._strategy
+ # To reduce unnecessary send/receive input pipeline operation, we place
+ # input pipeline ops in worker task.
+ train_iterator = self._get_input_iterator(train_input_fn, strategy)
+ train_loss = None
+ train_metric_result = None
+ eval_metric_result = None
+ tf_keras.backend.set_learning_phase(1)
+ with strategy.scope():
+ # To correctly place the model weights on accelerators,
+ # model and optimizer should be created in scope.
+ model = self.model_fn(params.as_dict())
+ if not hasattr(model, 'optimizer'):
+ raise ValueError('User should set optimizer attribute to model '
+ 'inside `model_fn`.')
+ optimizer = model.optimizer
+
+ # Training loop starts here.
+ checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
+ latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
+ initial_step = 0
+ if latest_checkpoint_file:
+ logging.info(
+ 'Checkpoint file %s found and restoring from '
+ 'checkpoint', latest_checkpoint_file)
+ checkpoint.restore(latest_checkpoint_file)
+ initial_step = optimizer.iterations.numpy()
+ logging.info('Loading from checkpoint file completed. Init step %d',
+ initial_step)
+ elif init_checkpoint:
+ logging.info('Restoring from init checkpoint function')
+ init_checkpoint(model)
+ logging.info('Loading from init checkpoint file completed')
+
+ current_step = optimizer.iterations.numpy()
+ checkpoint_name = self.checkpoint_name
+
+ eval_metric = eval_metric_fn()
+ train_metric = train_metric_fn()
+ train_summary_writer = summary_writer_fn(model_dir, 'eval_train')
+ self.train_summary_writer = train_summary_writer.writer
+
+ test_summary_writer = summary_writer_fn(model_dir, 'eval_test')
+ self.eval_summary_writer = test_summary_writer.writer
+
+ # Use training summary writer in TimeHistory if it's in use
+ for cb in custom_callbacks:
+ if isinstance(cb, keras_utils.TimeHistory):
+ cb.summary_writer = self.train_summary_writer
+
+ # Continue training loop.
+ train_step = self._create_train_step(
+ strategy=strategy,
+ model=model,
+ loss_fn=self.loss_fn(),
+ optimizer=optimizer,
+ metric=train_metric)
+ test_step = None
+ if eval_input_fn and eval_metric:
+ self.global_train_step = model.optimizer.iterations
+ test_step = self._create_test_step(strategy, model, metric=eval_metric)
+
+ # Step-0 operations
+ if current_step == 0 and not latest_checkpoint_file:
+ _save_checkpoint(checkpoint, model_dir,
+ checkpoint_name.format(step=current_step))
+ if test_step:
+ eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
+ eval_metric_result = self._run_evaluation(test_step, current_step,
+ eval_metric, eval_iterator)
+ logging.info('Step: %s evalation metric = %s.', current_step,
+ eval_metric_result)
+ test_summary_writer(metrics=eval_metric_result, step=optimizer.iterations)
+ reset_states(eval_metric)
+
+ logging.info('Training started')
+ last_save_checkpoint_step = current_step
+ while current_step < total_steps:
+
+ num_steps = _steps_to_run(current_step, total_steps, iterations_per_loop)
+ _run_callbacks_on_batch_begin(current_step)
+ train_loss = train_step(train_iterator,
+ tf.convert_to_tensor(num_steps, dtype=tf.int32))
+ current_step += num_steps
+
+ train_loss = tf.nest.map_structure(lambda x: x.numpy().astype(float),
+ train_loss)
+
+ _run_callbacks_on_batch_end(current_step - 1)
+ if not isinstance(train_loss, dict):
+ train_loss = {'total_loss': train_loss}
+ if np.isnan(train_loss['total_loss']):
+ raise ValueError('total loss is NaN.')
+
+ if train_metric:
+ train_metric_result = metric_results(train_metric)
+ train_metric_result.update(train_loss)
+ else:
+ train_metric_result = train_loss
+ if callable(optimizer.lr):
+ train_metric_result.update(
+ {'learning_rate': optimizer.lr(current_step).numpy()})
+ else:
+ train_metric_result.update({'learning_rate': optimizer.lr.numpy()})
+ logging.info('Train Step: %d/%d / loss = %s / training metric = %s',
+ current_step, total_steps, train_loss, train_metric_result)
+
+ train_summary_writer(
+ metrics=train_metric_result, step=optimizer.iterations)
+
+ # Saves model checkpoints and run validation steps at every
+ # iterations_per_loop steps.
+ # To avoid repeated model saving, we do not save after the last
+ # step of training.
+ if save_freq > 0 and current_step < total_steps and (
+ current_step - last_save_checkpoint_step) >= save_freq:
+ _save_checkpoint(checkpoint, model_dir,
+ checkpoint_name.format(step=current_step))
+ last_save_checkpoint_step = current_step
+
+ if continuous_eval and current_step < total_steps and test_step:
+ eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
+ eval_metric_result = self._run_evaluation(test_step, current_step,
+ eval_metric, eval_iterator)
+ logging.info('Step: %s evalation metric = %s.', current_step,
+ eval_metric_result)
+ test_summary_writer(
+ metrics=eval_metric_result, step=optimizer.iterations)
+
+ # Re-initialize evaluation metric, except the last step.
+ if eval_metric and current_step < total_steps:
+ reset_states(eval_metric)
+ if train_metric and current_step < total_steps:
+ reset_states(train_metric)
+
+ # Reaches the end of training and saves the last checkpoint.
+ if last_save_checkpoint_step < total_steps:
+ _save_checkpoint(checkpoint, model_dir,
+ checkpoint_name.format(step=current_step))
+
+ if test_step:
+ logging.info('Running final evaluation after training is complete.')
+ eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
+ eval_metric_result = self._run_evaluation(test_step, current_step,
+ eval_metric, eval_iterator)
+ logging.info('Final evaluation metric = %s.', eval_metric_result)
+ test_summary_writer(metrics=eval_metric_result, step=optimizer.iterations)
+
+ self.train_summary_writer.close()
+ self.eval_summary_writer.close()
+
+ return train_metric_result, eval_metric_result
+
+ def _run_evaluation(self, test_step, current_training_step, metric,
+ test_iterator):
+ """Runs validation steps and aggregate metrics."""
+ if not test_iterator or not metric:
+ logging.warning(
+ 'Both test_iterator (%s) and metrics (%s) must not be None.',
+ test_iterator, metric)
+ return None
+ logging.info('Running evaluation after step: %s.', current_training_step)
+ eval_step = 0
+ while True:
+ try:
+ with tf.experimental.async_scope():
+ test_step(test_iterator)
+ eval_step += 1
+ except (StopIteration, tf.errors.OutOfRangeError):
+ tf.experimental.async_clear_error()
+ break
+
+ metric_result = metric_results(metric)
+ logging.info('Total eval steps: [%d]', eval_step)
+ logging.info('At training step: [%r] Validation metric = %r',
+ current_training_step, metric_result)
+ return metric_result
+
+ def evaluate_from_model_dir(
+ self,
+ model_dir: Text,
+ eval_input_fn: Callable[[params_dict.ParamsDict], tf.data.Dataset],
+ eval_metric_fn: Callable[[], Any],
+ total_steps: int = -1,
+ eval_timeout: Optional[int] = None,
+ min_eval_interval: int = 180,
+ summary_writer_fn: Callable[[Text, Text], SummaryWriter] = SummaryWriter):
+ """Runs distributed evaluation on model folder.
+
+ Args:
+ model_dir: the folder for storing model checkpoints.
+ eval_input_fn: (Optional) same type as train_input_fn. If not None, will
+ trigger evaluting metric on eval data. If None, will not run eval step.
+ eval_metric_fn: metric_fn for evaluation in test_step.
+ total_steps: total training steps. If the current step reaches the
+ total_steps, the evaluation loop will stop.
+ eval_timeout: The maximum number of seconds to wait between checkpoints.
+ If left as None, then the process will wait indefinitely. Used by
+ tf.train.checkpoints_iterator.
+ min_eval_interval: The minimum number of seconds between yielding
+ checkpoints. Used by tf.train.checkpoints_iterator.
+ summary_writer_fn: function to create summary writer.
+
+ Returns:
+ Eval metrics dictionary of the last checkpoint.
+ """
+
+ if not model_dir:
+ raise ValueError('model_dir must be set.')
+
+ def terminate_eval():
+ tf.logging.info('Terminating eval after %d seconds of no checkpoints' %
+ eval_timeout)
+ return True
+
+ summary_writer = summary_writer_fn(model_dir, 'eval')
+ self.eval_summary_writer = summary_writer.writer
+
+ # Read checkpoints from the given model directory
+ # until `eval_timeout` seconds elapses.
+ for checkpoint_path in tf.train.checkpoints_iterator(
+ model_dir,
+ min_interval_secs=min_eval_interval,
+ timeout=eval_timeout,
+ timeout_fn=terminate_eval):
+ eval_metric_result, current_step = self.evaluate_checkpoint(
+ checkpoint_path=checkpoint_path,
+ eval_input_fn=eval_input_fn,
+ eval_metric_fn=eval_metric_fn,
+ summary_writer=summary_writer)
+ if total_steps > 0 and current_step >= total_steps:
+ logging.info('Evaluation finished after training step %d', current_step)
+ break
+ return eval_metric_result
+
+ def evaluate_checkpoint(self,
+ checkpoint_path: Text,
+ eval_input_fn: Callable[[params_dict.ParamsDict],
+ tf.data.Dataset],
+ eval_metric_fn: Callable[[], Any],
+ summary_writer: Optional[SummaryWriter] = None):
+ """Runs distributed evaluation on the one checkpoint.
+
+ Args:
+ checkpoint_path: the checkpoint to evaluate.
+ eval_input_fn: (Optional) same type as train_input_fn. If not None, will
+ trigger evaluting metric on eval data. If None, will not run eval step.
+ eval_metric_fn: metric_fn for evaluation in test_step.
+ summary_writer: function to create summary writer.
+
+ Returns:
+ Eval metrics dictionary of the last checkpoint.
+ """
+ if not callable(eval_metric_fn):
+ raise ValueError('if `eval_metric_fn` is specified, '
+ 'eval_metric_fn must be a callable.')
+
+ old_phase = tf_keras.backend.learning_phase()
+ tf_keras.backend.set_learning_phase(0)
+ params = self._params
+ strategy = self._strategy
+ # To reduce unnecessary send/receive input pipeline operation, we place
+ # input pipeline ops in worker task.
+ with strategy.scope():
+
+ # To correctly place the model weights on accelerators,
+ # model and optimizer should be created in scope.
+ model = self.model_fn(params.as_dict())
+ checkpoint = tf.train.Checkpoint(model=model)
+
+ eval_metric = eval_metric_fn()
+ assert eval_metric, 'eval_metric does not exist'
+ test_step = self._create_test_step(strategy, model, metric=eval_metric)
+
+ logging.info('Starting to evaluate.')
+ if not checkpoint_path:
+ raise ValueError('checkpoint path is empty')
+ reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
+ if reader.has_tensor('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE'):
+ # Legacy keras optimizer iteration.
+ current_step = reader.get_tensor(
+ 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE')
+ else:
+ # New keras optimizer iteration.
+ current_step = reader.get_tensor(
+ 'optimizer/_iterations/.ATTRIBUTES/VARIABLE_VALUE')
+ logging.info('Checkpoint file %s found and restoring from '
+ 'checkpoint', checkpoint_path)
+ status = checkpoint.restore(checkpoint_path)
+ status.expect_partial().assert_existing_objects_matched()
+
+ self.global_train_step = model.optimizer.iterations
+ eval_iterator = self._get_input_iterator(eval_input_fn, strategy)
+ eval_metric_result = self._run_evaluation(test_step, current_step,
+ eval_metric, eval_iterator)
+ logging.info('Step: %s evalation metric = %s.', current_step,
+ eval_metric_result)
+ summary_writer(metrics=eval_metric_result, step=current_step)
+ reset_states(eval_metric)
+
+ tf_keras.backend.set_learning_phase(old_phase)
+ return eval_metric_result, current_step
+
+ def predict(self):
+ return NotImplementedError('Unimplmented function.')
+
+
+class ExecutorBuilder(object):
+ """Builder of DistributedExecutor.
+
+ Example 1: Builds an executor with supported Strategy.
+ builder = ExecutorBuilder(
+ strategy_type='tpu',
+ strategy_config={'tpu': '/bns/xxx'})
+ dist_executor = builder.build_executor(
+ params=params,
+ model_fn=my_model_fn,
+ loss_fn=my_loss_fn,
+ metric_fn=my_metric_fn)
+
+ Example 2: Builds an executor with customized Strategy.
+ builder = ExecutorBuilder()
+ builder.strategy =
+ dist_executor = builder.build_executor(
+ params=params,
+ model_fn=my_model_fn,
+ loss_fn=my_loss_fn,
+ metric_fn=my_metric_fn)
+
+ Example 3: Builds a customized executor with customized Strategy.
+ class MyDistributedExecutor(DistributedExecutor):
+ # implementation ...
+
+ builder = ExecutorBuilder()
+ builder.strategy =
+ dist_executor = builder.build_executor(
+ class_ctor=MyDistributedExecutor,
+ params=params,
+ model_fn=my_model_fn,
+ loss_fn=my_loss_fn,
+ metric_fn=my_metric_fn)
+ """
+
+ def __init__(self, strategy_type=None, strategy_config=None):
+ _ = distribute_utils.configure_cluster(strategy_config.worker_hosts,
+ strategy_config.task_index)
+ """Constructor.
+
+ Args:
+ strategy_type: string. One of 'tpu', 'mirrored', 'multi_worker_mirrored'.
+ If None, the user is responsible to set the strategy before calling
+ build_executor(...).
+ strategy_config: necessary config for constructing the proper Strategy.
+ Check strategy_flags_dict() for examples of the structure.
+ """
+ self._strategy = distribute_utils.get_distribution_strategy(
+ distribution_strategy=strategy_type,
+ num_gpus=strategy_config.num_gpus,
+ all_reduce_alg=strategy_config.all_reduce_alg,
+ num_packs=strategy_config.num_packs,
+ tpu_address=strategy_config.tpu)
+
+ @property
+ def strategy(self):
+ """Returns default checkpoint name."""
+ return self._strategy
+
+ @strategy.setter
+ def strategy(self, new_strategy):
+ """Sets default summary writer for the current thread."""
+ self._strategy = new_strategy
+
+ def build_executor(self,
+ class_ctor=DistributedExecutor,
+ params=None,
+ model_fn=None,
+ loss_fn=None,
+ **kwargs):
+ """Creates an executor according to strategy type.
+
+ See doc string of the DistributedExecutor.__init__ for more information of
+ the
+ input arguments.
+
+ Args:
+ class_ctor: A constructor of executor (default: DistributedExecutor).
+ params: ParamsDict, all the model parameters and runtime parameters.
+ model_fn: Keras model function.
+ loss_fn: loss function.
+ **kwargs: other arguments to the executor constructor.
+
+ Returns:
+ An instance of DistributedExecutor or its subclass.
+ """
+ if self._strategy is None:
+ raise ValueError('`strategy` should not be None. You need to specify '
+ '`strategy_type` in the builder contructor or directly '
+ 'set the `strategy` property of the builder.')
+ return class_ctor(
+ strategy=self._strategy,
+ params=params,
+ model_fn=model_fn,
+ loss_fn=loss_fn,
+ **kwargs)
diff --git a/modeling/official/legacy/detection/main.py b/modeling/official/legacy/detection/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e0274ed2dd0d6f0e5f6a002c6baf1ad78569c48
--- /dev/null
+++ b/modeling/official/legacy/detection/main.py
@@ -0,0 +1,264 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Main function to train various object detection models."""
+
+import functools
+import pprint
+
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf, tf_keras
+
+from official.common import distribute_utils
+from official.legacy.detection.configs import factory as config_factory
+from official.legacy.detection.dataloader import input_reader
+from official.legacy.detection.dataloader import mode_keys as ModeKeys
+from official.legacy.detection.executor import distributed_executor as executor
+from official.legacy.detection.executor.detection_executor import DetectionDistributedExecutor
+from official.legacy.detection.modeling import factory as model_factory
+from official.modeling.hyperparams import params_dict
+from official.utils import hyperparams_flags
+from official.utils.flags import core as flags_core
+from official.utils.misc import keras_utils
+
+hyperparams_flags.initialize_common_flags()
+flags_core.define_log_steps()
+
+flags.DEFINE_bool('enable_xla', default=False, help='Enable XLA for GPU')
+
+flags.DEFINE_string(
+ 'mode',
+ default='train',
+ help='Mode to run: `train`, `eval` or `eval_once`.')
+
+flags.DEFINE_string(
+ 'model', default='retinanet',
+ help='Model to run: `retinanet`, `mask_rcnn` or `shapemask`.')
+
+flags.DEFINE_string('training_file_pattern', None,
+ 'Location of the train data.')
+
+flags.DEFINE_string('eval_file_pattern', None, 'Location of ther eval data')
+
+flags.DEFINE_string(
+ 'checkpoint_path', None,
+ 'The checkpoint path to eval. Only used in eval_once mode.')
+
+FLAGS = flags.FLAGS
+
+
+def run_executor(params,
+ mode,
+ checkpoint_path=None,
+ train_input_fn=None,
+ eval_input_fn=None,
+ callbacks=None,
+ prebuilt_strategy=None):
+ """Runs the object detection model on distribution strategy defined by the user."""
+
+ if params.architecture.use_bfloat16:
+ tf.compat.v2.keras.mixed_precision.set_global_policy('mixed_bfloat16')
+
+ model_builder = model_factory.model_generator(params)
+
+ if prebuilt_strategy is not None:
+ strategy = prebuilt_strategy
+ else:
+ strategy_config = params.strategy_config
+ distribute_utils.configure_cluster(strategy_config.worker_hosts,
+ strategy_config.task_index)
+ strategy = distribute_utils.get_distribution_strategy(
+ distribution_strategy=params.strategy_type,
+ num_gpus=strategy_config.num_gpus,
+ all_reduce_alg=strategy_config.all_reduce_alg,
+ num_packs=strategy_config.num_packs,
+ tpu_address=strategy_config.tpu)
+
+ num_workers = int(strategy.num_replicas_in_sync + 7) // 8
+ is_multi_host = (int(num_workers) >= 2)
+
+ if mode == 'train':
+
+ def _model_fn(params):
+ return model_builder.build_model(params, mode=ModeKeys.TRAIN)
+
+ logging.info(
+ 'Train num_replicas_in_sync %d num_workers %d is_multi_host %s',
+ strategy.num_replicas_in_sync, num_workers, is_multi_host)
+
+ dist_executor = DetectionDistributedExecutor(
+ strategy=strategy,
+ params=params,
+ model_fn=_model_fn,
+ loss_fn=model_builder.build_loss_fn,
+ is_multi_host=is_multi_host,
+ predict_post_process_fn=model_builder.post_processing,
+ trainable_variables_filter=model_builder
+ .make_filter_trainable_variables_fn())
+
+ if is_multi_host:
+ train_input_fn = functools.partial(
+ train_input_fn,
+ batch_size=params.train.batch_size // strategy.num_replicas_in_sync)
+
+ return dist_executor.train(
+ train_input_fn=train_input_fn,
+ model_dir=params.model_dir,
+ iterations_per_loop=params.train.iterations_per_loop,
+ total_steps=params.train.total_steps,
+ init_checkpoint=model_builder.make_restore_checkpoint_fn(),
+ custom_callbacks=callbacks,
+ save_config=True)
+ elif mode == 'eval' or mode == 'eval_once':
+
+ def _model_fn(params):
+ return model_builder.build_model(params, mode=ModeKeys.PREDICT_WITH_GT)
+
+ logging.info('Eval num_replicas_in_sync %d num_workers %d is_multi_host %s',
+ strategy.num_replicas_in_sync, num_workers, is_multi_host)
+
+ if is_multi_host:
+ eval_input_fn = functools.partial(
+ eval_input_fn,
+ batch_size=params.eval.batch_size // strategy.num_replicas_in_sync)
+
+ dist_executor = DetectionDistributedExecutor(
+ strategy=strategy,
+ params=params,
+ model_fn=_model_fn,
+ loss_fn=model_builder.build_loss_fn,
+ is_multi_host=is_multi_host,
+ predict_post_process_fn=model_builder.post_processing,
+ trainable_variables_filter=model_builder
+ .make_filter_trainable_variables_fn())
+
+ if mode == 'eval':
+ results = dist_executor.evaluate_from_model_dir(
+ model_dir=params.model_dir,
+ eval_input_fn=eval_input_fn,
+ eval_metric_fn=model_builder.eval_metrics,
+ eval_timeout=params.eval.eval_timeout,
+ min_eval_interval=params.eval.min_eval_interval,
+ total_steps=params.train.total_steps)
+ else:
+ # Run evaluation once for a single checkpoint.
+ if not checkpoint_path:
+ raise ValueError('checkpoint_path cannot be empty.')
+ if tf.io.gfile.isdir(checkpoint_path):
+ checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
+ summary_writer = executor.SummaryWriter(params.model_dir, 'eval')
+ results, _ = dist_executor.evaluate_checkpoint(
+ checkpoint_path=checkpoint_path,
+ eval_input_fn=eval_input_fn,
+ eval_metric_fn=model_builder.eval_metrics,
+ summary_writer=summary_writer)
+ for k, v in results.items():
+ logging.info('Final eval metric %s: %f', k, v)
+ return results
+ else:
+ raise ValueError('Mode not found: %s.' % mode)
+
+
+def run(callbacks=None):
+ """Runs the experiment."""
+ keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)
+
+ params = config_factory.config_generator(FLAGS.model)
+
+ params = params_dict.override_params_dict(
+ params, FLAGS.config_file, is_strict=True)
+
+ params = params_dict.override_params_dict(
+ params, FLAGS.params_override, is_strict=True)
+ params.override(
+ {
+ 'strategy_type': FLAGS.strategy_type,
+ 'model_dir': FLAGS.model_dir,
+ 'strategy_config': executor.strategy_flags_dict(),
+ },
+ is_strict=False)
+
+ # Make sure use_tpu and strategy_type are in sync.
+ params.use_tpu = (params.strategy_type == 'tpu')
+
+ if not params.use_tpu:
+ params.override({
+ 'architecture': {
+ 'use_bfloat16': False,
+ },
+ 'norm_activation': {
+ 'use_sync_bn': False,
+ },
+ }, is_strict=True)
+
+ params.validate()
+ params.lock()
+ pp = pprint.PrettyPrinter()
+ params_str = pp.pformat(params.as_dict())
+ logging.info('Model Parameters: %s', params_str)
+
+ train_input_fn = None
+ eval_input_fn = None
+ training_file_pattern = FLAGS.training_file_pattern or params.train.train_file_pattern
+ eval_file_pattern = FLAGS.eval_file_pattern or params.eval.eval_file_pattern
+ if not training_file_pattern and not eval_file_pattern:
+ raise ValueError('Must provide at least one of training_file_pattern and '
+ 'eval_file_pattern.')
+
+ if training_file_pattern:
+ # Use global batch size for single host.
+ train_input_fn = input_reader.InputFn(
+ file_pattern=training_file_pattern,
+ params=params,
+ mode=input_reader.ModeKeys.TRAIN,
+ batch_size=params.train.batch_size)
+
+ if eval_file_pattern:
+ eval_input_fn = input_reader.InputFn(
+ file_pattern=eval_file_pattern,
+ params=params,
+ mode=input_reader.ModeKeys.PREDICT_WITH_GT,
+ batch_size=params.eval.batch_size,
+ num_examples=params.eval.eval_samples)
+
+ if callbacks is None:
+ callbacks = []
+
+ if FLAGS.log_steps:
+ callbacks.append(
+ keras_utils.TimeHistory(
+ batch_size=params.train.batch_size,
+ log_steps=FLAGS.log_steps,
+ ))
+
+ return run_executor(
+ params,
+ FLAGS.mode,
+ checkpoint_path=FLAGS.checkpoint_path,
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ callbacks=callbacks)
+
+
+def main(argv):
+ del argv # Unused.
+
+ run()
+
+
+if __name__ == '__main__':
+ tf.config.set_soft_device_placement(True)
+ app.run(main)
diff --git a/modeling/official/legacy/detection/modeling/__init__.py b/modeling/official/legacy/detection/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/legacy/detection/modeling/architecture/__init__.py b/modeling/official/legacy/detection/modeling/architecture/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/architecture/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/legacy/detection/modeling/architecture/factory.py b/modeling/official/legacy/detection/modeling/architecture/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..744dc050795d891274555a5a8b7c7d67d8ae7b4d
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/architecture/factory.py
@@ -0,0 +1,217 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Model architecture factory."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from official.legacy.detection.modeling.architecture import fpn
+from official.legacy.detection.modeling.architecture import heads
+from official.legacy.detection.modeling.architecture import identity
+from official.legacy.detection.modeling.architecture import nn_ops
+from official.legacy.detection.modeling.architecture import resnet
+from official.legacy.detection.modeling.architecture import spinenet
+
+
+def norm_activation_generator(params):
+ return nn_ops.norm_activation_builder(
+ momentum=params.batch_norm_momentum,
+ epsilon=params.batch_norm_epsilon,
+ trainable=params.batch_norm_trainable,
+ activation=params.activation)
+
+
+def backbone_generator(params):
+ """Generator function for various backbone models."""
+ if params.architecture.backbone == 'resnet':
+ resnet_params = params.resnet
+ backbone_fn = resnet.Resnet(
+ resnet_depth=resnet_params.resnet_depth,
+ activation=params.norm_activation.activation,
+ norm_activation=norm_activation_generator(
+ params.norm_activation))
+ elif params.architecture.backbone == 'spinenet':
+ spinenet_params = params.spinenet
+ backbone_fn = spinenet.SpineNetBuilder(model_id=spinenet_params.model_id)
+ else:
+ raise ValueError('Backbone model `{}` is not supported.'
+ .format(params.architecture.backbone))
+
+ return backbone_fn
+
+
+def multilevel_features_generator(params):
+ """Generator function for various FPN models."""
+ if params.architecture.multilevel_features == 'fpn':
+ fpn_params = params.fpn
+ fpn_fn = fpn.Fpn(
+ min_level=params.architecture.min_level,
+ max_level=params.architecture.max_level,
+ fpn_feat_dims=fpn_params.fpn_feat_dims,
+ use_separable_conv=fpn_params.use_separable_conv,
+ activation=params.norm_activation.activation,
+ use_batch_norm=fpn_params.use_batch_norm,
+ norm_activation=norm_activation_generator(
+ params.norm_activation))
+ elif params.architecture.multilevel_features == 'identity':
+ fpn_fn = identity.Identity()
+ else:
+ raise ValueError('The multi-level feature model `{}` is not supported.'
+ .format(params.architecture.multilevel_features))
+ return fpn_fn
+
+
+def retinanet_head_generator(params):
+ """Generator function for RetinaNet head architecture."""
+ head_params = params.retinanet_head
+ anchors_per_location = params.anchor.num_scales * len(
+ params.anchor.aspect_ratios)
+ return heads.RetinanetHead(
+ params.architecture.min_level,
+ params.architecture.max_level,
+ params.architecture.num_classes,
+ anchors_per_location,
+ head_params.num_convs,
+ head_params.num_filters,
+ head_params.use_separable_conv,
+ norm_activation=norm_activation_generator(params.norm_activation))
+
+
+def rpn_head_generator(params):
+ """Generator function for RPN head architecture."""
+ head_params = params.rpn_head
+ anchors_per_location = params.anchor.num_scales * len(
+ params.anchor.aspect_ratios)
+ return heads.RpnHead(
+ params.architecture.min_level,
+ params.architecture.max_level,
+ anchors_per_location,
+ head_params.num_convs,
+ head_params.num_filters,
+ head_params.use_separable_conv,
+ params.norm_activation.activation,
+ head_params.use_batch_norm,
+ norm_activation=norm_activation_generator(params.norm_activation))
+
+
+def oln_rpn_head_generator(params):
+ """Generator function for OLN-proposal (OLN-RPN) head architecture."""
+ head_params = params.rpn_head
+ anchors_per_location = params.anchor.num_scales * len(
+ params.anchor.aspect_ratios)
+ return heads.OlnRpnHead(
+ params.architecture.min_level,
+ params.architecture.max_level,
+ anchors_per_location,
+ head_params.num_convs,
+ head_params.num_filters,
+ head_params.use_separable_conv,
+ params.norm_activation.activation,
+ head_params.use_batch_norm,
+ norm_activation=norm_activation_generator(params.norm_activation))
+
+
+def fast_rcnn_head_generator(params):
+ """Generator function for Fast R-CNN head architecture."""
+ head_params = params.frcnn_head
+ return heads.FastrcnnHead(
+ params.architecture.num_classes,
+ head_params.num_convs,
+ head_params.num_filters,
+ head_params.use_separable_conv,
+ head_params.num_fcs,
+ head_params.fc_dims,
+ params.norm_activation.activation,
+ head_params.use_batch_norm,
+ norm_activation=norm_activation_generator(params.norm_activation))
+
+
+def oln_box_score_head_generator(params):
+ """Generator function for Scoring Fast R-CNN head architecture."""
+ head_params = params.frcnn_head
+ return heads.OlnBoxScoreHead(
+ params.architecture.num_classes,
+ head_params.num_convs,
+ head_params.num_filters,
+ head_params.use_separable_conv,
+ head_params.num_fcs,
+ head_params.fc_dims,
+ params.norm_activation.activation,
+ head_params.use_batch_norm,
+ norm_activation=norm_activation_generator(params.norm_activation))
+
+
+def mask_rcnn_head_generator(params):
+ """Generator function for Mask R-CNN head architecture."""
+ head_params = params.mrcnn_head
+ return heads.MaskrcnnHead(
+ params.architecture.num_classes,
+ params.architecture.mask_target_size,
+ head_params.num_convs,
+ head_params.num_filters,
+ head_params.use_separable_conv,
+ params.norm_activation.activation,
+ head_params.use_batch_norm,
+ norm_activation=norm_activation_generator(params.norm_activation))
+
+
+def oln_mask_score_head_generator(params):
+ """Generator function for Scoring Mask R-CNN head architecture."""
+ head_params = params.mrcnn_head
+ return heads.OlnMaskScoreHead(
+ params.architecture.num_classes,
+ params.architecture.mask_target_size,
+ head_params.num_convs,
+ head_params.num_filters,
+ head_params.use_separable_conv,
+ params.norm_activation.activation,
+ head_params.use_batch_norm,
+ norm_activation=norm_activation_generator(params.norm_activation))
+
+
+def shapeprior_head_generator(params):
+ """Generator function for shape prior head architecture."""
+ head_params = params.shapemask_head
+ return heads.ShapemaskPriorHead(
+ params.architecture.num_classes,
+ head_params.num_downsample_channels,
+ head_params.mask_crop_size,
+ head_params.use_category_for_mask,
+ head_params.shape_prior_path)
+
+
+def coarsemask_head_generator(params):
+ """Generator function for ShapeMask coarse mask head architecture."""
+ head_params = params.shapemask_head
+ return heads.ShapemaskCoarsemaskHead(
+ params.architecture.num_classes,
+ head_params.num_downsample_channels,
+ head_params.mask_crop_size,
+ head_params.use_category_for_mask,
+ head_params.num_convs,
+ norm_activation=norm_activation_generator(params.norm_activation))
+
+
+def finemask_head_generator(params):
+ """Generator function for Shapemask fine mask head architecture."""
+ head_params = params.shapemask_head
+ return heads.ShapemaskFinemaskHead(
+ params.architecture.num_classes,
+ head_params.num_downsample_channels,
+ head_params.mask_crop_size,
+ head_params.use_category_for_mask,
+ head_params.num_convs,
+ head_params.upsample_factor)
diff --git a/modeling/official/legacy/detection/modeling/architecture/fpn.py b/modeling/official/legacy/detection/modeling/architecture/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d14626b2717a732c1cdebae62b7ac8746a287d7
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/architecture/fpn.py
@@ -0,0 +1,151 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Feature Pyramid Networks.
+
+Feature Pyramid Networks were proposed in:
+[1] Tsung-Yi Lin, Piotr Dollar, Ross Girshick, Kaiming He, Bharath Hariharan,
+ , and Serge Belongie
+ Feature Pyramid Networks for Object Detection. CVPR 2017.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.modeling.architecture import nn_ops
+from official.legacy.detection.ops import spatial_transform_ops
+
+
+class Fpn(object):
+ """Feature pyramid networks."""
+
+ def __init__(self,
+ min_level=3,
+ max_level=7,
+ fpn_feat_dims=256,
+ use_separable_conv=False,
+ activation='relu',
+ use_batch_norm=True,
+ norm_activation=nn_ops.norm_activation_builder(
+ activation='relu')):
+ """FPN initialization function.
+
+ Args:
+ min_level: `int` minimum level in FPN output feature maps.
+ max_level: `int` maximum level in FPN output feature maps.
+ fpn_feat_dims: `int` number of filters in FPN layers.
+ use_separable_conv: `bool`, if True use separable convolution for
+ convolution in FPN layers.
+ activation: the activation function.
+ use_batch_norm: 'bool', indicating whether batchnorm layers are added.
+ norm_activation: an operation that includes a normalization layer
+ followed by an optional activation layer.
+ """
+ self._min_level = min_level
+ self._max_level = max_level
+ self._fpn_feat_dims = fpn_feat_dims
+ if use_separable_conv:
+ self._conv2d_op = functools.partial(
+ tf_keras.layers.SeparableConv2D, depth_multiplier=1)
+ else:
+ self._conv2d_op = tf_keras.layers.Conv2D
+ if activation == 'relu':
+ self._activation_op = tf.nn.relu
+ elif activation == 'swish':
+ self._activation_op = tf.nn.swish
+ else:
+ raise ValueError('Unsupported activation `{}`.'.format(activation))
+ self._use_batch_norm = use_batch_norm
+ self._norm_activation = norm_activation
+
+ self._norm_activations = {}
+ self._lateral_conv2d_op = {}
+ self._post_hoc_conv2d_op = {}
+ self._coarse_conv2d_op = {}
+ for level in range(self._min_level, self._max_level + 1):
+ if self._use_batch_norm:
+ self._norm_activations[level] = norm_activation(
+ use_activation=False, name='p%d-bn' % level)
+ self._lateral_conv2d_op[level] = self._conv2d_op(
+ filters=self._fpn_feat_dims,
+ kernel_size=(1, 1),
+ padding='same',
+ name='l%d' % level)
+ self._post_hoc_conv2d_op[level] = self._conv2d_op(
+ filters=self._fpn_feat_dims,
+ strides=(1, 1),
+ kernel_size=(3, 3),
+ padding='same',
+ name='post_hoc_d%d' % level)
+ self._coarse_conv2d_op[level] = self._conv2d_op(
+ filters=self._fpn_feat_dims,
+ strides=(2, 2),
+ kernel_size=(3, 3),
+ padding='same',
+ name='p%d' % level)
+
+ def __call__(self, multilevel_features, is_training=None):
+ """Returns the FPN features for a given multilevel features.
+
+ Args:
+ multilevel_features: a `dict` containing `int` keys for continuous feature
+ levels, e.g., [2, 3, 4, 5]. The values are corresponding features with
+ shape [batch_size, height_l, width_l, num_filters].
+ is_training: `bool` if True, the model is in training mode.
+
+ Returns:
+ a `dict` containing `int` keys for continuous feature levels
+ [min_level, min_level + 1, ..., max_level]. The values are corresponding
+ FPN features with shape [batch_size, height_l, width_l, fpn_feat_dims].
+ """
+ input_levels = list(multilevel_features.keys())
+ if min(input_levels) > self._min_level:
+ raise ValueError(
+ 'The minimum backbone level %d should be '%(min(input_levels)) +
+ 'less or equal to FPN minimum level %d.:'%(self._min_level))
+ backbone_max_level = min(max(input_levels), self._max_level)
+ with tf.name_scope('fpn'):
+ # Adds lateral connections.
+ feats_lateral = {}
+ for level in range(self._min_level, backbone_max_level + 1):
+ feats_lateral[level] = self._lateral_conv2d_op[level](
+ multilevel_features[level])
+
+ # Adds top-down path.
+ feats = {backbone_max_level: feats_lateral[backbone_max_level]}
+ for level in range(backbone_max_level - 1, self._min_level - 1, -1):
+ feats[level] = spatial_transform_ops.nearest_upsampling(
+ feats[level + 1], 2) + feats_lateral[level]
+
+ # Adds post-hoc 3x3 convolution kernel.
+ for level in range(self._min_level, backbone_max_level + 1):
+ feats[level] = self._post_hoc_conv2d_op[level](feats[level])
+
+ # Adds coarser FPN levels introduced for RetinaNet.
+ for level in range(backbone_max_level + 1, self._max_level + 1):
+ feats_in = feats[level - 1]
+ if level > backbone_max_level + 1:
+ feats_in = self._activation_op(feats_in)
+ feats[level] = self._coarse_conv2d_op[level](feats_in)
+ if self._use_batch_norm:
+ # Adds batch_norm layer.
+ for level in range(self._min_level, self._max_level + 1):
+ feats[level] = self._norm_activations[level](
+ feats[level], is_training=is_training)
+ return feats
diff --git a/modeling/official/legacy/detection/modeling/architecture/heads.py b/modeling/official/legacy/detection/modeling/architecture/heads.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf410b1bd3a9cc0ccff90ea2727f633056c3918a
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/architecture/heads.py
@@ -0,0 +1,1273 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Classes to build various prediction heads in all supported models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.modeling.architecture import nn_ops
+from official.legacy.detection.ops import spatial_transform_ops
+
+
+class RpnHead(tf_keras.layers.Layer):
+ """Region Proposal Network head."""
+
+ def __init__(
+ self,
+ min_level,
+ max_level,
+ anchors_per_location,
+ num_convs=2,
+ num_filters=256,
+ use_separable_conv=False,
+ activation='relu',
+ use_batch_norm=True,
+ norm_activation=nn_ops.norm_activation_builder(activation='relu')):
+ """Initialize params to build Region Proposal Network head.
+
+ Args:
+ min_level: `int` number of minimum feature level.
+ max_level: `int` number of maximum feature level.
+ anchors_per_location: `int` number of number of anchors per pixel
+ location.
+ num_convs: `int` number that represents the number of the intermediate
+ conv layers before the prediction.
+ num_filters: `int` number that represents the number of filters of the
+ intermediate conv layers.
+ use_separable_conv: `bool`, indicating whether the separable conv layers
+ is used.
+ activation: activation function. Support 'relu' and 'swish'.
+ use_batch_norm: 'bool', indicating whether batchnorm layers are added.
+ norm_activation: an operation that includes a normalization layer followed
+ by an optional activation layer.
+ """
+ super().__init__(autocast=False)
+
+ self._min_level = min_level
+ self._max_level = max_level
+ self._anchors_per_location = anchors_per_location
+ if activation == 'relu':
+ self._activation_op = tf.nn.relu
+ elif activation == 'swish':
+ self._activation_op = tf.nn.swish
+ else:
+ raise ValueError('Unsupported activation `{}`.'.format(activation))
+ self._use_batch_norm = use_batch_norm
+
+ if use_separable_conv:
+ self._conv2d_op = functools.partial(
+ tf_keras.layers.SeparableConv2D,
+ depth_multiplier=1,
+ bias_initializer=tf.zeros_initializer())
+ else:
+ self._conv2d_op = functools.partial(
+ tf_keras.layers.Conv2D,
+ kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.01),
+ bias_initializer=tf.zeros_initializer())
+
+ self._rpn_conv = self._conv2d_op(
+ num_filters,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ activation=(None if self._use_batch_norm else self._activation_op),
+ padding='same',
+ name='rpn')
+ self._rpn_class_conv = self._conv2d_op(
+ anchors_per_location,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding='valid',
+ name='rpn-class')
+ self._rpn_box_conv = self._conv2d_op(
+ 4 * anchors_per_location,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding='valid',
+ name='rpn-box')
+
+ self._norm_activations = {}
+ if self._use_batch_norm:
+ for level in range(self._min_level, self._max_level + 1):
+ self._norm_activations[level] = norm_activation(name='rpn-l%d-bn' %
+ level)
+
+ def _shared_rpn_heads(self, features, anchors_per_location, level,
+ is_training):
+ """Shared RPN heads."""
+ features = self._rpn_conv(features)
+ if self._use_batch_norm:
+ # The batch normalization layers are not shared between levels.
+ features = self._norm_activations[level](
+ features, is_training=is_training)
+ # Proposal classification scores
+ scores = self._rpn_class_conv(features)
+ # Proposal bbox regression deltas
+ bboxes = self._rpn_box_conv(features)
+
+ return scores, bboxes
+
+ def call(self, features, is_training=None):
+
+ scores_outputs = {}
+ box_outputs = {}
+
+ with tf.name_scope('rpn_head'):
+ for level in range(self._min_level, self._max_level + 1):
+ scores_output, box_output = self._shared_rpn_heads(
+ features[level], self._anchors_per_location, level, is_training)
+ scores_outputs[level] = scores_output
+ box_outputs[level] = box_output
+ return scores_outputs, box_outputs
+
+
+class OlnRpnHead(tf_keras.layers.Layer):
+ """Region Proposal Network for Object Localization Network (OLN)."""
+
+ def __init__(
+ self,
+ min_level,
+ max_level,
+ anchors_per_location,
+ num_convs=2,
+ num_filters=256,
+ use_separable_conv=False,
+ activation='relu',
+ use_batch_norm=True,
+ norm_activation=nn_ops.norm_activation_builder(activation='relu')):
+ """Initialize params to build Region Proposal Network head.
+
+ Args:
+ min_level: `int` number of minimum feature level.
+ max_level: `int` number of maximum feature level.
+ anchors_per_location: `int` number of number of anchors per pixel
+ location.
+ num_convs: `int` number that represents the number of the intermediate
+ conv layers before the prediction.
+ num_filters: `int` number that represents the number of filters of the
+ intermediate conv layers.
+ use_separable_conv: `bool`, indicating whether the separable conv layers
+ is used.
+ activation: activation function. Support 'relu' and 'swish'.
+ use_batch_norm: 'bool', indicating whether batchnorm layers are added.
+ norm_activation: an operation that includes a normalization layer followed
+ by an optional activation layer.
+ """
+ self._min_level = min_level
+ self._max_level = max_level
+ self._anchors_per_location = anchors_per_location
+ if activation == 'relu':
+ self._activation_op = tf.nn.relu
+ elif activation == 'swish':
+ self._activation_op = tf.nn.swish
+ else:
+ raise ValueError('Unsupported activation `{}`.'.format(activation))
+ self._use_batch_norm = use_batch_norm
+
+ if use_separable_conv:
+ self._conv2d_op = functools.partial(
+ tf_keras.layers.SeparableConv2D,
+ depth_multiplier=1,
+ bias_initializer=tf.zeros_initializer())
+ else:
+ self._conv2d_op = functools.partial(
+ tf_keras.layers.Conv2D,
+ kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.01),
+ bias_initializer=tf.zeros_initializer())
+
+ self._rpn_conv = self._conv2d_op(
+ num_filters,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ activation=(None if self._use_batch_norm else self._activation_op),
+ padding='same',
+ name='rpn')
+ self._rpn_class_conv = self._conv2d_op(
+ anchors_per_location,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding='valid',
+ name='rpn-class')
+ self._rpn_box_conv = self._conv2d_op(
+ 4 * anchors_per_location,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding='valid',
+ name='rpn-box-lrtb')
+ self._rpn_center_conv = self._conv2d_op(
+ anchors_per_location,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding='valid',
+ name='rpn-centerness')
+
+ self._norm_activations = {}
+ if self._use_batch_norm:
+ for level in range(self._min_level, self._max_level + 1):
+ self._norm_activations[level] = norm_activation(name='rpn-l%d-bn' %
+ level)
+
+ def _shared_rpn_heads(self, features, anchors_per_location, level,
+ is_training):
+ """Shared RPN heads."""
+ features = self._rpn_conv(features)
+ if self._use_batch_norm:
+ # The batch normalization layers are not shared between levels.
+ features = self._norm_activations[level](
+ features, is_training=is_training)
+ # Feature L2 normalization for training stability
+ features = tf.math.l2_normalize(
+ features,
+ axis=-1,
+ name='rpn-norm',)
+ # Proposal classification scores
+ scores = self._rpn_class_conv(features)
+ # Proposal bbox regression deltas
+ bboxes = self._rpn_box_conv(features)
+ # Proposal centerness scores
+ centers = self._rpn_center_conv(features)
+
+ return scores, bboxes, centers
+
+ def __call__(self, features, is_training=None):
+
+ scores_outputs = {}
+ box_outputs = {}
+ center_outputs = {}
+
+ with tf.name_scope('rpn_head'):
+ for level in range(self._min_level, self._max_level + 1):
+ scores_output, box_output, center_output = self._shared_rpn_heads(
+ features[level], self._anchors_per_location, level, is_training)
+ scores_outputs[level] = scores_output
+ box_outputs[level] = box_output
+ center_outputs[level] = center_output
+ return scores_outputs, box_outputs, center_outputs
+
+
+class FastrcnnHead(tf_keras.layers.Layer):
+ """Fast R-CNN box head."""
+
+ def __init__(
+ self,
+ num_classes,
+ num_convs=0,
+ num_filters=256,
+ use_separable_conv=False,
+ num_fcs=2,
+ fc_dims=1024,
+ activation='relu',
+ use_batch_norm=True,
+ norm_activation=nn_ops.norm_activation_builder(activation='relu')):
+ """Initialize params to build Fast R-CNN box head.
+
+ Args:
+ num_classes: a integer for the number of classes.
+ num_convs: `int` number that represents the number of the intermediate
+ conv layers before the FC layers.
+ num_filters: `int` number that represents the number of filters of the
+ intermediate conv layers.
+ use_separable_conv: `bool`, indicating whether the separable conv layers
+ is used.
+ num_fcs: `int` number that represents the number of FC layers before the
+ predictions.
+ fc_dims: `int` number that represents the number of dimension of the FC
+ layers.
+ activation: activation function. Support 'relu' and 'swish'.
+ use_batch_norm: 'bool', indicating whether batchnorm layers are added.
+ norm_activation: an operation that includes a normalization layer followed
+ by an optional activation layer.
+ """
+ super(FastrcnnHead, self).__init__(autocast=False)
+
+ self._num_classes = num_classes
+
+ self._num_convs = num_convs
+ self._num_filters = num_filters
+ if use_separable_conv:
+ self._conv2d_op = functools.partial(
+ tf_keras.layers.SeparableConv2D,
+ depth_multiplier=1,
+ bias_initializer=tf.zeros_initializer())
+ else:
+ self._conv2d_op = functools.partial(
+ tf_keras.layers.Conv2D,
+ kernel_initializer=tf_keras.initializers.VarianceScaling(
+ scale=2, mode='fan_out', distribution='untruncated_normal'),
+ bias_initializer=tf.zeros_initializer())
+
+ self._num_fcs = num_fcs
+ self._fc_dims = fc_dims
+ if activation == 'relu':
+ self._activation_op = tf.nn.relu
+ elif activation == 'swish':
+ self._activation_op = tf.nn.swish
+ else:
+ raise ValueError('Unsupported activation `{}`.'.format(activation))
+ self._use_batch_norm = use_batch_norm
+ self._norm_activation = norm_activation
+
+ self._conv_ops = []
+ self._conv_bn_ops = []
+ for i in range(self._num_convs):
+ self._conv_ops.append(
+ self._conv2d_op(
+ self._num_filters,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding='same',
+ dilation_rate=(1, 1),
+ activation=(None
+ if self._use_batch_norm else self._activation_op),
+ name='conv_{}'.format(i)))
+ if self._use_batch_norm:
+ self._conv_bn_ops.append(self._norm_activation())
+
+ self._fc_ops = []
+ self._fc_bn_ops = []
+ for i in range(self._num_fcs):
+ self._fc_ops.append(
+ tf_keras.layers.Dense(
+ units=self._fc_dims,
+ activation=(None
+ if self._use_batch_norm else self._activation_op),
+ name='fc{}'.format(i)))
+ if self._use_batch_norm:
+ self._fc_bn_ops.append(self._norm_activation(fused=False))
+
+ self._class_predict = tf_keras.layers.Dense(
+ self._num_classes,
+ kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.01),
+ bias_initializer=tf.zeros_initializer(),
+ name='class-predict')
+ self._box_predict = tf_keras.layers.Dense(
+ self._num_classes * 4,
+ kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.001),
+ bias_initializer=tf.zeros_initializer(),
+ name='box-predict')
+
+ def call(self, roi_features, is_training=None):
+ """Box and class branches for the Mask-RCNN model.
+
+ Args:
+ roi_features: A ROI feature tensor of shape [batch_size, num_rois,
+ height_l, width_l, num_filters].
+ is_training: `boolean`, if True if model is in training mode.
+
+ Returns:
+ class_outputs: a tensor with a shape of
+ [batch_size, num_rois, num_classes], representing the class predictions.
+ box_outputs: a tensor with a shape of
+ [batch_size, num_rois, num_classes * 4], representing the box
+ predictions.
+ """
+
+ with tf.name_scope(
+ 'fast_rcnn_head'):
+ # reshape inputs beofre FC.
+ _, num_rois, height, width, filters = roi_features.get_shape().as_list()
+
+ net = tf.reshape(roi_features, [-1, height, width, filters])
+ for i in range(self._num_convs):
+ net = self._conv_ops[i](net)
+ if self._use_batch_norm:
+ net = self._conv_bn_ops[i](net, is_training=is_training)
+
+ filters = self._num_filters if self._num_convs > 0 else filters
+ net = tf.reshape(net, [-1, num_rois, height * width * filters])
+
+ for i in range(self._num_fcs):
+ net = self._fc_ops[i](net)
+ if self._use_batch_norm:
+ net = self._fc_bn_ops[i](net, is_training=is_training)
+
+ class_outputs = self._class_predict(net)
+ box_outputs = self._box_predict(net)
+ return class_outputs, box_outputs
+
+
+class OlnBoxScoreHead(tf_keras.layers.Layer):
+ """Box head of Object Localization Network (OLN)."""
+
+ def __init__(
+ self,
+ num_classes,
+ num_convs=0,
+ num_filters=256,
+ use_separable_conv=False,
+ num_fcs=2,
+ fc_dims=1024,
+ activation='relu',
+ use_batch_norm=True,
+ norm_activation=nn_ops.norm_activation_builder(activation='relu')):
+ """Initialize params to build OLN box head.
+
+ Args:
+ num_classes: a integer for the number of classes.
+ num_convs: `int` number that represents the number of the intermediate
+ conv layers before the FC layers.
+ num_filters: `int` number that represents the number of filters of the
+ intermediate conv layers.
+ use_separable_conv: `bool`, indicating whether the separable conv layers
+ is used.
+ num_fcs: `int` number that represents the number of FC layers before the
+ predictions.
+ fc_dims: `int` number that represents the number of dimension of the FC
+ layers.
+ activation: activation function. Support 'relu' and 'swish'.
+ use_batch_norm: 'bool', indicating whether batchnorm layers are added.
+ norm_activation: an operation that includes a normalization layer followed
+ by an optional activation layer.
+ """
+ self._num_classes = num_classes
+
+ self._num_convs = num_convs
+ self._num_filters = num_filters
+ if use_separable_conv:
+ self._conv2d_op = functools.partial(
+ tf_keras.layers.SeparableConv2D,
+ depth_multiplier=1,
+ bias_initializer=tf.zeros_initializer())
+ else:
+ self._conv2d_op = functools.partial(
+ tf_keras.layers.Conv2D,
+ kernel_initializer=tf_keras.initializers.VarianceScaling(
+ scale=2, mode='fan_out', distribution='untruncated_normal'),
+ bias_initializer=tf.zeros_initializer())
+
+ self._num_fcs = num_fcs
+ self._fc_dims = fc_dims
+ if activation == 'relu':
+ self._activation_op = tf.nn.relu
+ elif activation == 'swish':
+ self._activation_op = tf.nn.swish
+ else:
+ raise ValueError('Unsupported activation `{}`.'.format(activation))
+ self._use_batch_norm = use_batch_norm
+ self._norm_activation = norm_activation
+
+ self._conv_ops = []
+ self._conv_bn_ops = []
+ for i in range(self._num_convs):
+ self._conv_ops.append(
+ self._conv2d_op(
+ self._num_filters,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding='same',
+ dilation_rate=(1, 1),
+ activation=(None
+ if self._use_batch_norm else self._activation_op),
+ name='conv_{}'.format(i)))
+ if self._use_batch_norm:
+ self._conv_bn_ops.append(self._norm_activation())
+
+ self._fc_ops = []
+ self._fc_bn_ops = []
+ for i in range(self._num_fcs):
+ self._fc_ops.append(
+ tf_keras.layers.Dense(
+ units=self._fc_dims,
+ activation=(None
+ if self._use_batch_norm else self._activation_op),
+ name='fc{}'.format(i)))
+ if self._use_batch_norm:
+ self._fc_bn_ops.append(self._norm_activation(fused=False))
+
+ self._class_predict = tf_keras.layers.Dense(
+ self._num_classes,
+ kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.01),
+ bias_initializer=tf.zeros_initializer(),
+ name='class-predict')
+ self._box_predict = tf_keras.layers.Dense(
+ self._num_classes * 4,
+ kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.001),
+ bias_initializer=tf.zeros_initializer(),
+ name='box-predict')
+ self._score_predict = tf_keras.layers.Dense(
+ 1,
+ kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.01),
+ bias_initializer=tf.zeros_initializer(),
+ name='score-predict')
+
+ def __call__(self, roi_features, is_training=None):
+ """Box and class branches for the Mask-RCNN model.
+
+ Args:
+ roi_features: A ROI feature tensor of shape [batch_size, num_rois,
+ height_l, width_l, num_filters].
+ is_training: `boolean`, if True if model is in training mode.
+
+ Returns:
+ class_outputs: a tensor with a shape of
+ [batch_size, num_rois, num_classes], representing the class predictions.
+ box_outputs: a tensor with a shape of
+ [batch_size, num_rois, num_classes * 4], representing the box
+ predictions.
+ """
+
+ with tf.name_scope('fast_rcnn_head'):
+ # reshape inputs beofre FC.
+ _, num_rois, height, width, filters = roi_features.get_shape().as_list()
+
+ net = tf.reshape(roi_features, [-1, height, width, filters])
+ for i in range(self._num_convs):
+ net = self._conv_ops[i](net)
+ if self._use_batch_norm:
+ net = self._conv_bn_ops[i](net, is_training=is_training)
+
+ filters = self._num_filters if self._num_convs > 0 else filters
+ net = tf.reshape(net, [-1, num_rois, height * width * filters])
+
+ for i in range(self._num_fcs):
+ net = self._fc_ops[i](net)
+ if self._use_batch_norm:
+ net = self._fc_bn_ops[i](net, is_training=is_training)
+
+ class_outputs = self._class_predict(net)
+ box_outputs = self._box_predict(net)
+ score_outputs = self._score_predict(net)
+ return class_outputs, box_outputs, score_outputs
+
+
+class MaskrcnnHead(tf_keras.layers.Layer):
+ """Mask R-CNN head."""
+
+ def __init__(
+ self,
+ num_classes,
+ mask_target_size,
+ num_convs=4,
+ num_filters=256,
+ use_separable_conv=False,
+ activation='relu',
+ use_batch_norm=True,
+ norm_activation=nn_ops.norm_activation_builder(activation='relu')):
+ """Initialize params to build Fast R-CNN head.
+
+ Args:
+ num_classes: a integer for the number of classes.
+ mask_target_size: a integer that is the resolution of masks.
+ num_convs: `int` number that represents the number of the intermediate
+ conv layers before the prediction.
+ num_filters: `int` number that represents the number of filters of the
+ intermediate conv layers.
+ use_separable_conv: `bool`, indicating whether the separable conv layers
+ is used.
+ activation: activation function. Support 'relu' and 'swish'.
+ use_batch_norm: 'bool', indicating whether batchnorm layers are added.
+ norm_activation: an operation that includes a normalization layer followed
+ by an optional activation layer.
+ """
+ super(MaskrcnnHead, self).__init__(autocast=False)
+ self._num_classes = num_classes
+ self._mask_target_size = mask_target_size
+
+ self._num_convs = num_convs
+ self._num_filters = num_filters
+ if use_separable_conv:
+ self._conv2d_op = functools.partial(
+ tf_keras.layers.SeparableConv2D,
+ depth_multiplier=1,
+ bias_initializer=tf.zeros_initializer())
+ else:
+ self._conv2d_op = functools.partial(
+ tf_keras.layers.Conv2D,
+ kernel_initializer=tf_keras.initializers.VarianceScaling(
+ scale=2, mode='fan_out', distribution='untruncated_normal'),
+ bias_initializer=tf.zeros_initializer())
+ if activation == 'relu':
+ self._activation_op = tf.nn.relu
+ elif activation == 'swish':
+ self._activation_op = tf.nn.swish
+ else:
+ raise ValueError('Unsupported activation `{}`.'.format(activation))
+ self._use_batch_norm = use_batch_norm
+ self._norm_activation = norm_activation
+ self._conv2d_ops = []
+ for i in range(self._num_convs):
+ self._conv2d_ops.append(
+ self._conv2d_op(
+ self._num_filters,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding='same',
+ dilation_rate=(1, 1),
+ activation=(None
+ if self._use_batch_norm else self._activation_op),
+ name='mask-conv-l%d' % i))
+ self._mask_conv_transpose = tf_keras.layers.Conv2DTranspose(
+ self._num_filters,
+ kernel_size=(2, 2),
+ strides=(2, 2),
+ padding='valid',
+ activation=(None if self._use_batch_norm else self._activation_op),
+ kernel_initializer=tf_keras.initializers.VarianceScaling(
+ scale=2, mode='fan_out', distribution='untruncated_normal'),
+ bias_initializer=tf.zeros_initializer(),
+ name='conv5-mask')
+
+ with tf.name_scope('mask_head'):
+ self._mask_conv2d_op = self._conv2d_op(
+ self._num_classes,
+ kernel_size=(1, 1),
+ strides=(1, 1),
+ padding='valid',
+ name='mask_fcn_logits')
+
+ def call(self, roi_features, class_indices, is_training=None):
+ """Mask branch for the Mask-RCNN model.
+
+ Args:
+ roi_features: A ROI feature tensor of shape [batch_size, num_rois,
+ height_l, width_l, num_filters].
+ class_indices: a Tensor of shape [batch_size, num_rois], indicating which
+ class the ROI is.
+ is_training: `boolean`, if True if model is in training mode.
+
+ Returns:
+ mask_outputs: a tensor with a shape of
+ [batch_size, num_masks, mask_height, mask_width, num_classes],
+ representing the mask predictions.
+ fg_gather_indices: a tensor with a shape of [batch_size, num_masks, 2],
+ representing the fg mask targets.
+ Raises:
+ ValueError: If boxes is not a rank-3 tensor or the last dimension of
+ boxes is not 4.
+ """
+
+ with tf.name_scope('mask_head'):
+ _, num_rois, height, width, filters = roi_features.get_shape().as_list()
+ net = tf.reshape(roi_features, [-1, height, width, filters])
+
+ for i in range(self._num_convs):
+ net = self._conv2d_ops[i](net)
+ if self._use_batch_norm:
+ net = self._norm_activation()(net, is_training=is_training)
+
+ net = self._mask_conv_transpose(net)
+ if self._use_batch_norm:
+ net = self._norm_activation()(net, is_training=is_training)
+
+ mask_outputs = self._mask_conv2d_op(net)
+ mask_outputs = tf.reshape(mask_outputs, [
+ -1, num_rois, self._mask_target_size, self._mask_target_size,
+ self._num_classes
+ ])
+
+ with tf.name_scope('masks_post_processing'):
+ mask_outputs = tf.gather(
+ mask_outputs,
+ tf.cast(class_indices, tf.int32),
+ axis=-1,
+ batch_dims=2,
+ )
+ return mask_outputs
+
+
+class RetinanetHead(object):
+ """RetinaNet head."""
+
+ def __init__(
+ self,
+ min_level,
+ max_level,
+ num_classes,
+ anchors_per_location,
+ num_convs=4,
+ num_filters=256,
+ use_separable_conv=False,
+ norm_activation=nn_ops.norm_activation_builder(activation='relu')):
+ """Initialize params to build RetinaNet head.
+
+ Args:
+ min_level: `int` number of minimum feature level.
+ max_level: `int` number of maximum feature level.
+ num_classes: `int` number of classification categories.
+ anchors_per_location: `int` number of anchors per pixel location.
+ num_convs: `int` number of stacked convolution before the last prediction
+ layer.
+ num_filters: `int` number of filters used in the head architecture.
+ use_separable_conv: `bool` to indicate whether to use separable
+ convoluation.
+ norm_activation: an operation that includes a normalization layer followed
+ by an optional activation layer.
+ """
+ self._min_level = min_level
+ self._max_level = max_level
+
+ self._num_classes = num_classes
+ self._anchors_per_location = anchors_per_location
+
+ self._num_convs = num_convs
+ self._num_filters = num_filters
+ self._use_separable_conv = use_separable_conv
+ with tf.name_scope('class_net') as scope_name:
+ self._class_name_scope = tf.name_scope(scope_name)
+ with tf.name_scope('box_net') as scope_name:
+ self._box_name_scope = tf.name_scope(scope_name)
+ self._build_class_net_layers(norm_activation)
+ self._build_box_net_layers(norm_activation)
+
+ def _class_net_batch_norm_name(self, i, level):
+ return 'class-%d-%d' % (i, level)
+
+ def _box_net_batch_norm_name(self, i, level):
+ return 'box-%d-%d' % (i, level)
+
+ def _build_class_net_layers(self, norm_activation):
+ """Build re-usable layers for class prediction network."""
+ if self._use_separable_conv:
+ self._class_predict = tf_keras.layers.SeparableConv2D(
+ self._num_classes * self._anchors_per_location,
+ kernel_size=(3, 3),
+ bias_initializer=tf.constant_initializer(-np.log((1 - 0.01) / 0.01)),
+ padding='same',
+ name='class-predict')
+ else:
+ self._class_predict = tf_keras.layers.Conv2D(
+ self._num_classes * self._anchors_per_location,
+ kernel_size=(3, 3),
+ bias_initializer=tf.constant_initializer(-np.log((1 - 0.01) / 0.01)),
+ kernel_initializer=tf_keras.initializers.RandomNormal(stddev=1e-5),
+ padding='same',
+ name='class-predict')
+ self._class_conv = []
+ self._class_norm_activation = {}
+ for i in range(self._num_convs):
+ if self._use_separable_conv:
+ self._class_conv.append(
+ tf_keras.layers.SeparableConv2D(
+ self._num_filters,
+ kernel_size=(3, 3),
+ bias_initializer=tf.zeros_initializer(),
+ activation=None,
+ padding='same',
+ name='class-' + str(i)))
+ else:
+ self._class_conv.append(
+ tf_keras.layers.Conv2D(
+ self._num_filters,
+ kernel_size=(3, 3),
+ bias_initializer=tf.zeros_initializer(),
+ kernel_initializer=tf_keras.initializers.RandomNormal(
+ stddev=0.01),
+ activation=None,
+ padding='same',
+ name='class-' + str(i)))
+ for level in range(self._min_level, self._max_level + 1):
+ name = self._class_net_batch_norm_name(i, level)
+ self._class_norm_activation[name] = norm_activation(name=name)
+
+ def _build_box_net_layers(self, norm_activation):
+ """Build re-usable layers for box prediction network."""
+ if self._use_separable_conv:
+ self._box_predict = tf_keras.layers.SeparableConv2D(
+ 4 * self._anchors_per_location,
+ kernel_size=(3, 3),
+ bias_initializer=tf.zeros_initializer(),
+ padding='same',
+ name='box-predict')
+ else:
+ self._box_predict = tf_keras.layers.Conv2D(
+ 4 * self._anchors_per_location,
+ kernel_size=(3, 3),
+ bias_initializer=tf.zeros_initializer(),
+ kernel_initializer=tf_keras.initializers.RandomNormal(stddev=1e-5),
+ padding='same',
+ name='box-predict')
+ self._box_conv = []
+ self._box_norm_activation = {}
+ for i in range(self._num_convs):
+ if self._use_separable_conv:
+ self._box_conv.append(
+ tf_keras.layers.SeparableConv2D(
+ self._num_filters,
+ kernel_size=(3, 3),
+ activation=None,
+ bias_initializer=tf.zeros_initializer(),
+ padding='same',
+ name='box-' + str(i)))
+ else:
+ self._box_conv.append(
+ tf_keras.layers.Conv2D(
+ self._num_filters,
+ kernel_size=(3, 3),
+ activation=None,
+ bias_initializer=tf.zeros_initializer(),
+ kernel_initializer=tf_keras.initializers.RandomNormal(
+ stddev=0.01),
+ padding='same',
+ name='box-' + str(i)))
+ for level in range(self._min_level, self._max_level + 1):
+ name = self._box_net_batch_norm_name(i, level)
+ self._box_norm_activation[name] = norm_activation(name=name)
+
+ def __call__(self, fpn_features, is_training=None):
+ """Returns outputs of RetinaNet head."""
+ class_outputs = {}
+ box_outputs = {}
+ with tf.name_scope('retinanet_head'):
+ for level in range(self._min_level, self._max_level + 1):
+ features = fpn_features[level]
+
+ class_outputs[level] = self.class_net(
+ features, level, is_training=is_training)
+ box_outputs[level] = self.box_net(
+ features, level, is_training=is_training)
+ return class_outputs, box_outputs
+
+ def class_net(self, features, level, is_training):
+ """Class prediction network for RetinaNet."""
+ with self._class_name_scope:
+ for i in range(self._num_convs):
+ features = self._class_conv[i](features)
+ # The convolution layers in the class net are shared among all levels,
+ # but each level has its batch normlization to capture the statistical
+ # difference among different levels.
+ name = self._class_net_batch_norm_name(i, level)
+ features = self._class_norm_activation[name](
+ features, is_training=is_training)
+
+ classes = self._class_predict(features)
+ return classes
+
+ def box_net(self, features, level, is_training=None):
+ """Box regression network for RetinaNet."""
+ with self._box_name_scope:
+ for i in range(self._num_convs):
+ features = self._box_conv[i](features)
+ # The convolution layers in the box net are shared among all levels, but
+ # each level has its batch normlization to capture the statistical
+ # difference among different levels.
+ name = self._box_net_batch_norm_name(i, level)
+ features = self._box_norm_activation[name](
+ features, is_training=is_training)
+
+ boxes = self._box_predict(features)
+ return boxes
+
+
+# TODO(yeqing): Refactor this class when it is ready for var_scope reuse.
+class ShapemaskPriorHead(object):
+ """ShapeMask Prior head."""
+
+ def __init__(self, num_classes, num_downsample_channels, mask_crop_size,
+ use_category_for_mask, shape_prior_path):
+ """Initialize params to build RetinaNet head.
+
+ Args:
+ num_classes: Number of output classes.
+ num_downsample_channels: number of channels in mask branch.
+ mask_crop_size: feature crop size.
+ use_category_for_mask: use class information in mask branch.
+ shape_prior_path: the path to load shape priors.
+ """
+ self._mask_num_classes = num_classes if use_category_for_mask else 1
+ self._num_downsample_channels = num_downsample_channels
+ self._mask_crop_size = mask_crop_size
+ self._shape_prior_path = shape_prior_path
+ self._use_category_for_mask = use_category_for_mask
+
+ self._shape_prior_fc = tf_keras.layers.Dense(
+ self._num_downsample_channels, name='shape-prior-fc')
+
+ def __call__(self, fpn_features, boxes, outer_boxes, classes, is_training):
+ """Generate the detection priors from the box detections and FPN features.
+
+ This corresponds to the Fig. 4 of the ShapeMask paper at
+ https://arxiv.org/pdf/1904.03239.pdf
+
+ Args:
+ fpn_features: a dictionary of FPN features.
+ boxes: a float tensor of shape [batch_size, num_instances, 4] representing
+ the tight gt boxes from dataloader/detection.
+ outer_boxes: a float tensor of shape [batch_size, num_instances, 4]
+ representing the loose gt boxes from dataloader/detection.
+ classes: a int Tensor of shape [batch_size, num_instances] of instance
+ classes.
+ is_training: training mode or not.
+
+ Returns:
+ instance_features: a float Tensor of shape [batch_size * num_instances,
+ mask_crop_size, mask_crop_size, num_downsample_channels]. This is the
+ instance feature crop.
+ detection_priors: A float Tensor of shape [batch_size * num_instances,
+ mask_size, mask_size, 1].
+ """
+ with tf.name_scope('prior_mask'):
+ batch_size, num_instances, _ = boxes.get_shape().as_list()
+ outer_boxes = tf.cast(outer_boxes, tf.float32)
+ boxes = tf.cast(boxes, tf.float32)
+ instance_features = spatial_transform_ops.multilevel_crop_and_resize(
+ fpn_features, outer_boxes, output_size=self._mask_crop_size)
+ instance_features = self._shape_prior_fc(instance_features)
+
+ shape_priors = self._get_priors()
+
+ # Get uniform priors for each outer box.
+ uniform_priors = tf.ones([
+ batch_size, num_instances, self._mask_crop_size, self._mask_crop_size
+ ])
+ uniform_priors = spatial_transform_ops.crop_mask_in_target_box(
+ uniform_priors, boxes, outer_boxes, self._mask_crop_size)
+
+ # Classify shape priors using uniform priors + instance features.
+ prior_distribution = self._classify_shape_priors(
+ tf.cast(instance_features, tf.float32), uniform_priors, classes)
+
+ instance_priors = tf.gather(shape_priors, classes)
+ instance_priors *= tf.expand_dims(
+ tf.expand_dims(tf.cast(prior_distribution, tf.float32), axis=-1),
+ axis=-1)
+ instance_priors = tf.reduce_sum(instance_priors, axis=2)
+ detection_priors = spatial_transform_ops.crop_mask_in_target_box(
+ instance_priors, boxes, outer_boxes, self._mask_crop_size)
+
+ return instance_features, detection_priors
+
+ def _get_priors(self):
+ """Load shape priors from file."""
+ # loads class specific or agnostic shape priors
+ if self._shape_prior_path:
+ # Priors are loaded into shape [mask_num_classes, num_clusters, 32, 32].
+ priors = np.load(tf.io.gfile.GFile(self._shape_prior_path, 'rb'))
+ priors = tf.convert_to_tensor(priors, dtype=tf.float32)
+ self._num_clusters = priors.get_shape().as_list()[1]
+ else:
+ # If prior path does not exist, do not use priors, i.e., pirors equal to
+ # uniform empty 32x32 patch.
+ self._num_clusters = 1
+ priors = tf.zeros([
+ self._mask_num_classes, self._num_clusters, self._mask_crop_size,
+ self._mask_crop_size
+ ])
+ return priors
+
+ def _classify_shape_priors(self, features, uniform_priors, classes):
+ """Classify the uniform prior by predicting the shape modes.
+
+ Classify the object crop features into K modes of the clusters for each
+ category.
+
+ Args:
+ features: A float Tensor of shape [batch_size, num_instances, mask_size,
+ mask_size, num_channels].
+ uniform_priors: A float Tensor of shape [batch_size, num_instances,
+ mask_size, mask_size] representing the uniform detection priors.
+ classes: A int Tensor of shape [batch_size, num_instances] of detection
+ class ids.
+
+ Returns:
+ prior_distribution: A float Tensor of shape
+ [batch_size, num_instances, num_clusters] representing the classifier
+ output probability over all possible shapes.
+ """
+
+ batch_size, num_instances, _, _, _ = features.get_shape().as_list()
+ features *= tf.expand_dims(uniform_priors, axis=-1)
+ # Reduce spatial dimension of features. The features have shape
+ # [batch_size, num_instances, num_channels].
+ features = tf.reduce_mean(features, axis=(2, 3))
+ logits = tf_keras.layers.Dense(
+ self._mask_num_classes * self._num_clusters,
+ kernel_initializer=tf.random_normal_initializer(stddev=0.01),
+ name='classify-shape-prior-fc')(features)
+ logits = tf.reshape(
+ logits,
+ [batch_size, num_instances, self._mask_num_classes, self._num_clusters])
+ if self._use_category_for_mask:
+ logits = tf.gather(logits, tf.expand_dims(classes, axis=-1), batch_dims=2)
+ logits = tf.squeeze(logits, axis=2)
+ else:
+ logits = logits[:, :, 0, :]
+
+ distribution = tf.nn.softmax(logits, name='shape_prior_weights')
+ return distribution
+
+
+class ShapemaskCoarsemaskHead(object):
+ """ShapemaskCoarsemaskHead head."""
+
+ def __init__(self,
+ num_classes,
+ num_downsample_channels,
+ mask_crop_size,
+ use_category_for_mask,
+ num_convs,
+ norm_activation=nn_ops.norm_activation_builder()):
+ """Initialize params to build ShapeMask coarse and fine prediction head.
+
+ Args:
+ num_classes: `int` number of mask classification categories.
+ num_downsample_channels: `int` number of filters at mask head.
+ mask_crop_size: feature crop size.
+ use_category_for_mask: use class information in mask branch.
+ num_convs: `int` number of stacked convolution before the last prediction
+ layer.
+ norm_activation: an operation that includes a normalization layer followed
+ by an optional activation layer.
+ """
+ self._mask_num_classes = num_classes if use_category_for_mask else 1
+ self._use_category_for_mask = use_category_for_mask
+ self._num_downsample_channels = num_downsample_channels
+ self._mask_crop_size = mask_crop_size
+ self._num_convs = num_convs
+ self._norm_activation = norm_activation
+
+ self._coarse_mask_fc = tf_keras.layers.Dense(
+ self._num_downsample_channels, name='coarse-mask-fc')
+
+ self._class_conv = []
+ self._class_norm_activation = []
+
+ for i in range(self._num_convs):
+ self._class_conv.append(
+ tf_keras.layers.Conv2D(
+ self._num_downsample_channels,
+ kernel_size=(3, 3),
+ bias_initializer=tf.zeros_initializer(),
+ kernel_initializer=tf_keras.initializers.RandomNormal(
+ stddev=0.01),
+ padding='same',
+ name='coarse-mask-class-%d' % i))
+
+ self._class_norm_activation.append(
+ norm_activation(name='coarse-mask-class-%d-bn' % i))
+
+ self._class_predict = tf_keras.layers.Conv2D(
+ self._mask_num_classes,
+ kernel_size=(1, 1),
+ # Focal loss bias initialization to have foreground 0.01 probability.
+ bias_initializer=tf.constant_initializer(-np.log((1 - 0.01) / 0.01)),
+ kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.01),
+ padding='same',
+ name='coarse-mask-class-predict')
+
+ def __call__(self, features, detection_priors, classes, is_training):
+ """Generate instance masks from FPN features and detection priors.
+
+ This corresponds to the Fig. 5-6 of the ShapeMask paper at
+ https://arxiv.org/pdf/1904.03239.pdf
+
+ Args:
+ features: a float Tensor of shape [batch_size, num_instances,
+ mask_crop_size, mask_crop_size, num_downsample_channels]. This is the
+ instance feature crop.
+ detection_priors: a float Tensor of shape [batch_size, num_instances,
+ mask_crop_size, mask_crop_size, 1]. This is the detection prior for the
+ instance.
+ classes: a int Tensor of shape [batch_size, num_instances] of instance
+ classes.
+ is_training: a bool indicating whether in training mode.
+
+ Returns:
+ mask_outputs: instance mask prediction as a float Tensor of shape
+ [batch_size, num_instances, mask_size, mask_size].
+ """
+ with tf.name_scope('coarse_mask'):
+ # Transform detection priors to have the same dimension as features.
+ detection_priors = tf.expand_dims(detection_priors, axis=-1)
+ detection_priors = self._coarse_mask_fc(detection_priors)
+
+ features += detection_priors
+ mask_logits = self.decoder_net(features, is_training)
+ # Gather the logits with right input class.
+ if self._use_category_for_mask:
+ mask_logits = tf.transpose(mask_logits, [0, 1, 4, 2, 3])
+ mask_logits = tf.gather(
+ mask_logits, tf.expand_dims(classes, -1), batch_dims=2)
+ mask_logits = tf.squeeze(mask_logits, axis=2)
+ else:
+ mask_logits = mask_logits[..., 0]
+
+ return mask_logits
+
+ def decoder_net(self, features, is_training=False):
+ """Coarse mask decoder network architecture.
+
+ Args:
+ features: A tensor of size [batch, height_in, width_in, channels_in].
+ is_training: Whether batch_norm layers are in training mode.
+
+ Returns:
+ images: A feature tensor of size [batch, output_size, output_size,
+ num_channels]
+ """
+ (batch_size, num_instances, height, width,
+ num_channels) = features.get_shape().as_list()
+ features = tf.reshape(
+ features, [batch_size * num_instances, height, width, num_channels])
+ for i in range(self._num_convs):
+ features = self._class_conv[i](features)
+ features = self._class_norm_activation[i](
+ features, is_training=is_training)
+
+ mask_logits = self._class_predict(features)
+ mask_logits = tf.reshape(
+ mask_logits,
+ [batch_size, num_instances, height, width, self._mask_num_classes])
+ return mask_logits
+
+
+class ShapemaskFinemaskHead(object):
+ """ShapemaskFinemaskHead head."""
+
+ def __init__(self,
+ num_classes,
+ num_downsample_channels,
+ mask_crop_size,
+ use_category_for_mask,
+ num_convs,
+ upsample_factor,
+ norm_activation=nn_ops.norm_activation_builder()):
+ """Initialize params to build ShapeMask coarse and fine prediction head.
+
+ Args:
+ num_classes: `int` number of mask classification categories.
+ num_downsample_channels: `int` number of filters at mask head.
+ mask_crop_size: feature crop size.
+ use_category_for_mask: use class information in mask branch.
+ num_convs: `int` number of stacked convolution before the last prediction
+ layer.
+ upsample_factor: `int` number of fine mask upsampling factor.
+ norm_activation: an operation that includes a batch normalization layer
+ followed by a relu layer(optional).
+ """
+ self._use_category_for_mask = use_category_for_mask
+ self._mask_num_classes = num_classes if use_category_for_mask else 1
+ self._num_downsample_channels = num_downsample_channels
+ self._mask_crop_size = mask_crop_size
+ self._num_convs = num_convs
+ self.up_sample_factor = upsample_factor
+
+ self._fine_mask_fc = tf_keras.layers.Dense(
+ self._num_downsample_channels, name='fine-mask-fc')
+
+ self._upsample_conv = tf_keras.layers.Conv2DTranspose(
+ self._num_downsample_channels,
+ (self.up_sample_factor, self.up_sample_factor),
+ (self.up_sample_factor, self.up_sample_factor),
+ name='fine-mask-conv2d-tran')
+
+ self._fine_class_conv = []
+ self._fine_class_bn = []
+ for i in range(self._num_convs):
+ self._fine_class_conv.append(
+ tf_keras.layers.Conv2D(
+ self._num_downsample_channels,
+ kernel_size=(3, 3),
+ bias_initializer=tf.zeros_initializer(),
+ kernel_initializer=tf_keras.initializers.RandomNormal(
+ stddev=0.01),
+ activation=None,
+ padding='same',
+ name='fine-mask-class-%d' % i))
+ self._fine_class_bn.append(
+ norm_activation(name='fine-mask-class-%d-bn' % i))
+
+ self._class_predict_conv = tf_keras.layers.Conv2D(
+ self._mask_num_classes,
+ kernel_size=(1, 1),
+ # Focal loss bias initialization to have foreground 0.01 probability.
+ bias_initializer=tf.constant_initializer(-np.log((1 - 0.01) / 0.01)),
+ kernel_initializer=tf_keras.initializers.RandomNormal(stddev=0.01),
+ padding='same',
+ name='fine-mask-class-predict')
+
+ def __call__(self, features, mask_logits, classes, is_training):
+ """Generate instance masks from FPN features and detection priors.
+
+ This corresponds to the Fig. 5-6 of the ShapeMask paper at
+ https://arxiv.org/pdf/1904.03239.pdf
+
+ Args:
+ features: a float Tensor of shape [batch_size, num_instances,
+ mask_crop_size, mask_crop_size, num_downsample_channels]. This is the
+ instance feature crop.
+ mask_logits: a float Tensor of shape [batch_size, num_instances,
+ mask_crop_size, mask_crop_size] indicating predicted mask logits.
+ classes: a int Tensor of shape [batch_size, num_instances] of instance
+ classes.
+ is_training: a bool indicating whether in training mode.
+
+ Returns:
+ mask_outputs: instance mask prediction as a float Tensor of shape
+ [batch_size, num_instances, mask_size, mask_size].
+ """
+ # Extract the foreground mean features
+ # with tf.variable_scope('fine_mask', reuse=tf.AUTO_REUSE):
+ with tf.name_scope('fine_mask'):
+ mask_probs = tf.nn.sigmoid(mask_logits)
+ # Compute instance embedding for hard average.
+ binary_mask = tf.cast(tf.greater(mask_probs, 0.5), features.dtype)
+ instance_embedding = tf.reduce_sum(
+ features * tf.expand_dims(binary_mask, axis=-1), axis=(2, 3))
+ instance_embedding /= tf.expand_dims(
+ tf.reduce_sum(binary_mask, axis=(2, 3)) + 1e-20, axis=-1)
+ # Take the difference between crop features and mean instance features.
+ features -= tf.expand_dims(
+ tf.expand_dims(instance_embedding, axis=2), axis=2)
+
+ features += self._fine_mask_fc(tf.expand_dims(mask_probs, axis=-1))
+
+ # Decoder to generate upsampled segmentation mask.
+ mask_logits = self.decoder_net(features, is_training)
+ if self._use_category_for_mask:
+ mask_logits = tf.transpose(mask_logits, [0, 1, 4, 2, 3])
+ mask_logits = tf.gather(
+ mask_logits, tf.expand_dims(classes, -1), batch_dims=2)
+ mask_logits = tf.squeeze(mask_logits, axis=2)
+ else:
+ mask_logits = mask_logits[..., 0]
+
+ return mask_logits
+
+ def decoder_net(self, features, is_training=False):
+ """Fine mask decoder network architecture.
+
+ Args:
+ features: A tensor of size [batch, height_in, width_in, channels_in].
+ is_training: Whether batch_norm layers are in training mode.
+
+ Returns:
+ images: A feature tensor of size [batch, output_size, output_size,
+ num_channels], where output size is self._gt_upsample_scale times
+ that of input.
+ """
+ (batch_size, num_instances, height, width,
+ num_channels) = features.get_shape().as_list()
+ features = tf.reshape(
+ features, [batch_size * num_instances, height, width, num_channels])
+ for i in range(self._num_convs):
+ features = self._fine_class_conv[i](features)
+ features = self._fine_class_bn[i](features, is_training=is_training)
+
+ if self.up_sample_factor > 1:
+ features = self._upsample_conv(features)
+
+ # Predict per-class instance masks.
+ mask_logits = self._class_predict_conv(features)
+
+ mask_logits = tf.reshape(mask_logits, [
+ batch_size, num_instances, height * self.up_sample_factor,
+ width * self.up_sample_factor, self._mask_num_classes
+ ])
+ return mask_logits
diff --git a/modeling/official/legacy/detection/modeling/architecture/identity.py b/modeling/official/legacy/detection/modeling/architecture/identity.py
new file mode 100644
index 0000000000000000000000000000000000000000..28a2d8d2a94ce8d4a92d072b57c0649df59321e1
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/architecture/identity.py
@@ -0,0 +1,28 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Identity Fn that forwards the input features."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+class Identity(object):
+ """Identity function that forwards the input features."""
+
+ def __call__(self, features, is_training=False):
+ """Only forwards the input features."""
+ return features
+
diff --git a/modeling/official/legacy/detection/modeling/architecture/nn_blocks.py b/modeling/official/legacy/detection/modeling/architecture/nn_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c0cd87d09c95327c34731e0d1054ed24c620407
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/architecture/nn_blocks.py
@@ -0,0 +1,316 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contains common building blocks for neural networks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf, tf_keras
+
+from official.modeling import tf_utils
+
+
+class ResidualBlock(tf_keras.layers.Layer):
+ """A residual block."""
+
+ def __init__(self,
+ filters,
+ strides,
+ use_projection=False,
+ kernel_initializer='VarianceScaling',
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activation='relu',
+ use_sync_bn=False,
+ norm_momentum=0.99,
+ norm_epsilon=0.001,
+ **kwargs):
+ """A residual block with BN after convolutions.
+
+ Args:
+ filters: `int` number of filters for the first two convolutions. Note that
+ the third and final convolution will use 4 times as many filters.
+ strides: `int` block stride. If greater than 1, this block will ultimately
+ downsample the input.
+ use_projection: `bool` for whether this block should use a projection
+ shortcut (versus the default identity shortcut). This is usually `True`
+ for the first block of a block group, which may change the number of
+ filters and the resolution.
+ kernel_initializer: kernel_initializer for convolutional layers.
+ kernel_regularizer: tf_keras.regularizers.Regularizer object for Conv2D.
+ Default to None.
+ bias_regularizer: tf_keras.regularizers.Regularizer object for Conv2d.
+ Default to None.
+ activation: `str` name of the activation function.
+ use_sync_bn: if True, use synchronized batch normalization.
+ norm_momentum: `float` normalization omentum for the moving average.
+ norm_epsilon: `float` small float added to variance to avoid dividing by
+ zero.
+ **kwargs: keyword arguments to be passed.
+ """
+ super(ResidualBlock, self).__init__(**kwargs)
+
+ self._filters = filters
+ self._strides = strides
+ self._use_projection = use_projection
+ self._use_sync_bn = use_sync_bn
+ self._activation = activation
+ self._kernel_initializer = kernel_initializer
+ self._norm_momentum = norm_momentum
+ self._norm_epsilon = norm_epsilon
+ self._kernel_regularizer = kernel_regularizer
+ self._bias_regularizer = bias_regularizer
+
+ if use_sync_bn:
+ self._norm = tf_keras.layers.experimental.SyncBatchNormalization
+ else:
+ self._norm = tf_keras.layers.BatchNormalization
+ if tf_keras.backend.image_data_format() == 'channels_last':
+ self._bn_axis = -1
+ else:
+ self._bn_axis = 1
+ self._activation_fn = tf_utils.get_activation(activation)
+
+ def build(self, input_shape):
+ if self._use_projection:
+ self._shortcut = tf_keras.layers.Conv2D(
+ filters=self._filters,
+ kernel_size=1,
+ strides=self._strides,
+ use_bias=False,
+ kernel_initializer=self._kernel_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer)
+ self._norm0 = self._norm(
+ axis=self._bn_axis,
+ momentum=self._norm_momentum,
+ epsilon=self._norm_epsilon)
+
+ self._conv1 = tf_keras.layers.Conv2D(
+ filters=self._filters,
+ kernel_size=3,
+ strides=self._strides,
+ padding='same',
+ use_bias=False,
+ kernel_initializer=self._kernel_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer)
+ self._norm1 = self._norm(
+ axis=self._bn_axis,
+ momentum=self._norm_momentum,
+ epsilon=self._norm_epsilon)
+
+ self._conv2 = tf_keras.layers.Conv2D(
+ filters=self._filters,
+ kernel_size=3,
+ strides=1,
+ padding='same',
+ use_bias=False,
+ kernel_initializer=self._kernel_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer)
+ self._norm2 = self._norm(
+ axis=self._bn_axis,
+ momentum=self._norm_momentum,
+ epsilon=self._norm_epsilon)
+
+ super(ResidualBlock, self).build(input_shape)
+
+ def get_config(self):
+ config = {
+ 'filters': self._filters,
+ 'strides': self._strides,
+ 'use_projection': self._use_projection,
+ 'kernel_initializer': self._kernel_initializer,
+ 'kernel_regularizer': self._kernel_regularizer,
+ 'bias_regularizer': self._bias_regularizer,
+ 'activation': self._activation,
+ 'use_sync_bn': self._use_sync_bn,
+ 'norm_momentum': self._norm_momentum,
+ 'norm_epsilon': self._norm_epsilon
+ }
+
+ base_config = super(ResidualBlock, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ def call(self, inputs):
+ shortcut = inputs
+ if self._use_projection:
+ shortcut = self._shortcut(shortcut)
+ shortcut = self._norm0(shortcut)
+
+ x = self._conv1(inputs)
+ x = self._norm1(x)
+ x = self._activation_fn(x)
+
+ x = self._conv2(x)
+ x = self._norm2(x)
+
+ return self._activation_fn(x + shortcut)
+
+
+class BottleneckBlock(tf_keras.layers.Layer):
+ """A standard bottleneck block."""
+
+ def __init__(self,
+ filters,
+ strides,
+ use_projection=False,
+ kernel_initializer='VarianceScaling',
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activation='relu',
+ use_sync_bn=False,
+ norm_momentum=0.99,
+ norm_epsilon=0.001,
+ **kwargs):
+ """A standard bottleneck block with BN after convolutions.
+
+ Args:
+ filters: `int` number of filters for the first two convolutions. Note that
+ the third and final convolution will use 4 times as many filters.
+ strides: `int` block stride. If greater than 1, this block will ultimately
+ downsample the input.
+ use_projection: `bool` for whether this block should use a projection
+ shortcut (versus the default identity shortcut). This is usually `True`
+ for the first block of a block group, which may change the number of
+ filters and the resolution.
+ kernel_initializer: kernel_initializer for convolutional layers.
+ kernel_regularizer: tf_keras.regularizers.Regularizer object for Conv2D.
+ Default to None.
+ bias_regularizer: tf_keras.regularizers.Regularizer object for Conv2d.
+ Default to None.
+ activation: `str` name of the activation function.
+ use_sync_bn: if True, use synchronized batch normalization.
+ norm_momentum: `float` normalization omentum for the moving average.
+ norm_epsilon: `float` small float added to variance to avoid dividing by
+ zero.
+ **kwargs: keyword arguments to be passed.
+ """
+ super(BottleneckBlock, self).__init__(**kwargs)
+
+ self._filters = filters
+ self._strides = strides
+ self._use_projection = use_projection
+ self._use_sync_bn = use_sync_bn
+ self._activation = activation
+ self._kernel_initializer = kernel_initializer
+ self._norm_momentum = norm_momentum
+ self._norm_epsilon = norm_epsilon
+ self._kernel_regularizer = kernel_regularizer
+ self._bias_regularizer = bias_regularizer
+ if use_sync_bn:
+ self._norm = tf_keras.layers.experimental.SyncBatchNormalization
+ else:
+ self._norm = tf_keras.layers.BatchNormalization
+ if tf_keras.backend.image_data_format() == 'channels_last':
+ self._bn_axis = -1
+ else:
+ self._bn_axis = 1
+ self._activation_fn = tf_utils.get_activation(activation)
+
+ def build(self, input_shape):
+ if self._use_projection:
+ self._shortcut = tf_keras.layers.Conv2D(
+ filters=self._filters * 4,
+ kernel_size=1,
+ strides=self._strides,
+ use_bias=False,
+ kernel_initializer=self._kernel_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer)
+ self._norm0 = self._norm(
+ axis=self._bn_axis,
+ momentum=self._norm_momentum,
+ epsilon=self._norm_epsilon)
+
+ self._conv1 = tf_keras.layers.Conv2D(
+ filters=self._filters,
+ kernel_size=1,
+ strides=1,
+ use_bias=False,
+ kernel_initializer=self._kernel_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer)
+ self._norm1 = self._norm(
+ axis=self._bn_axis,
+ momentum=self._norm_momentum,
+ epsilon=self._norm_epsilon)
+
+ self._conv2 = tf_keras.layers.Conv2D(
+ filters=self._filters,
+ kernel_size=3,
+ strides=self._strides,
+ padding='same',
+ use_bias=False,
+ kernel_initializer=self._kernel_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer)
+ self._norm2 = self._norm(
+ axis=self._bn_axis,
+ momentum=self._norm_momentum,
+ epsilon=self._norm_epsilon)
+
+ self._conv3 = tf_keras.layers.Conv2D(
+ filters=self._filters * 4,
+ kernel_size=1,
+ strides=1,
+ use_bias=False,
+ kernel_initializer=self._kernel_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer)
+ self._norm3 = self._norm(
+ axis=self._bn_axis,
+ momentum=self._norm_momentum,
+ epsilon=self._norm_epsilon)
+
+ super(BottleneckBlock, self).build(input_shape)
+
+ def get_config(self):
+ config = {
+ 'filters': self._filters,
+ 'strides': self._strides,
+ 'use_projection': self._use_projection,
+ 'kernel_initializer': self._kernel_initializer,
+ 'kernel_regularizer': self._kernel_regularizer,
+ 'bias_regularizer': self._bias_regularizer,
+ 'activation': self._activation,
+ 'use_sync_bn': self._use_sync_bn,
+ 'norm_momentum': self._norm_momentum,
+ 'norm_epsilon': self._norm_epsilon
+ }
+
+ base_config = super(BottleneckBlock, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ def call(self, inputs):
+ shortcut = inputs
+ if self._use_projection:
+ shortcut = self._shortcut(shortcut)
+ shortcut = self._norm0(shortcut)
+
+ x = self._conv1(inputs)
+ x = self._norm1(x)
+ x = self._activation_fn(x)
+
+ x = self._conv2(x)
+ x = self._norm2(x)
+ x = self._activation_fn(x)
+
+ x = self._conv3(x)
+ x = self._norm3(x)
+
+ return self._activation_fn(x + shortcut)
diff --git a/modeling/official/legacy/detection/modeling/architecture/nn_ops.py b/modeling/official/legacy/detection/modeling/architecture/nn_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb4890bf4d0bd9925d449280ce0e24ba78e617ad
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/architecture/nn_ops.py
@@ -0,0 +1,109 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Neural network operations commonly shared by the architectures."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import tensorflow as tf, tf_keras
+
+
+class NormActivation(tf_keras.layers.Layer):
+ """Combined Normalization and Activation layers."""
+
+ def __init__(self,
+ momentum=0.997,
+ epsilon=1e-4,
+ trainable=True,
+ init_zero=False,
+ use_activation=True,
+ activation='relu',
+ fused=True,
+ name=None):
+ """A class to construct layers for a batch normalization followed by a ReLU.
+
+ Args:
+ momentum: momentum for the moving average.
+ epsilon: small float added to variance to avoid dividing by zero.
+ trainable: `bool`, if True also add variables to the graph collection
+ GraphKeys.TRAINABLE_VARIABLES. If False, freeze batch normalization
+ layer.
+ init_zero: `bool` if True, initializes scale parameter of batch
+ normalization with 0. If False, initialize it with 1.
+ use_activation: `bool`, whether to add the optional activation layer after
+ the batch normalization layer.
+ activation: 'string', the type of the activation layer. Currently support
+ `relu` and `swish`.
+ fused: `bool` fused option in batch normalziation.
+ name: `str` name for the operation.
+ """
+ super(NormActivation, self).__init__(trainable=trainable)
+ if init_zero:
+ gamma_initializer = tf_keras.initializers.Zeros()
+ else:
+ gamma_initializer = tf_keras.initializers.Ones()
+ self._normalization_op = tf_keras.layers.BatchNormalization(
+ momentum=momentum,
+ epsilon=epsilon,
+ center=True,
+ scale=True,
+ trainable=trainable,
+ fused=fused,
+ gamma_initializer=gamma_initializer,
+ name=name)
+ self._use_activation = use_activation
+ if activation == 'relu':
+ self._activation_op = tf.nn.relu
+ elif activation == 'swish':
+ self._activation_op = tf.nn.swish
+ else:
+ raise ValueError('Unsupported activation `{}`.'.format(activation))
+
+ def __call__(self, inputs, is_training=None):
+ """Builds the normalization layer followed by an optional activation layer.
+
+ Args:
+ inputs: `Tensor` of shape `[batch, channels, ...]`.
+ is_training: `boolean`, if True if model is in training mode.
+
+ Returns:
+ A normalized `Tensor` with the same `data_format`.
+ """
+ # We will need to keep training=None by default, so that it can be inherit
+ # from keras.Model.training
+ if is_training and self.trainable:
+ is_training = True
+ inputs = self._normalization_op(inputs, training=is_training)
+
+ if self._use_activation:
+ inputs = self._activation_op(inputs)
+ return inputs
+
+
+def norm_activation_builder(momentum=0.997,
+ epsilon=1e-4,
+ trainable=True,
+ activation='relu',
+ **kwargs):
+ return functools.partial(
+ NormActivation,
+ momentum=momentum,
+ epsilon=epsilon,
+ trainable=trainable,
+ activation=activation,
+ **kwargs)
diff --git a/modeling/official/legacy/detection/modeling/architecture/resnet.py b/modeling/official/legacy/detection/modeling/architecture/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..4938b99ff9daad340e72311d161b01fc3550cc27
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/architecture/resnet.py
@@ -0,0 +1,352 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contains definitions for the post-activation form of Residual Networks.
+
+Residual networks (ResNets) were proposed in:
+[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+ Deep Residual Learning for Image Recognition. arXiv:1512.03385
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf, tf_keras
+from official.legacy.detection.modeling.architecture import nn_ops
+
+
+# TODO(b/140112644): Refactor the code with Keras style, i.e. build and call.
+class Resnet(object):
+ """Class to build ResNet family model."""
+
+ def __init__(
+ self,
+ resnet_depth,
+ activation='relu',
+ norm_activation=nn_ops.norm_activation_builder(activation='relu'),
+ data_format='channels_last'):
+ """ResNet initialization function.
+
+ Args:
+ resnet_depth: `int` depth of ResNet backbone model.
+ activation: the activation function.
+ norm_activation: an operation that includes a normalization layer followed
+ by an optional activation layer.
+ data_format: `str` either "channels_first" for `[batch, channels, height,
+ width]` or "channels_last for `[batch, height, width, channels]`.
+ """
+ self._resnet_depth = resnet_depth
+ if activation == 'relu':
+ self._activation_op = tf.nn.relu
+ elif activation == 'swish':
+ self._activation_op = tf.nn.swish
+ else:
+ raise ValueError('Unsupported activation `{}`.'.format(activation))
+ self._norm_activation = norm_activation
+ self._data_format = data_format
+
+ model_params = {
+ 10: {
+ 'block': self.residual_block,
+ 'layers': [1, 1, 1, 1]
+ },
+ 18: {
+ 'block': self.residual_block,
+ 'layers': [2, 2, 2, 2]
+ },
+ 34: {
+ 'block': self.residual_block,
+ 'layers': [3, 4, 6, 3]
+ },
+ 50: {
+ 'block': self.bottleneck_block,
+ 'layers': [3, 4, 6, 3]
+ },
+ 101: {
+ 'block': self.bottleneck_block,
+ 'layers': [3, 4, 23, 3]
+ },
+ 152: {
+ 'block': self.bottleneck_block,
+ 'layers': [3, 8, 36, 3]
+ },
+ 200: {
+ 'block': self.bottleneck_block,
+ 'layers': [3, 24, 36, 3]
+ }
+ }
+
+ if resnet_depth not in model_params:
+ valid_resnet_depths = ', '.join(
+ [str(depth) for depth in sorted(model_params.keys())])
+ raise ValueError(
+ 'The resnet_depth should be in [%s]. Not a valid resnet_depth:' %
+ (valid_resnet_depths), self._resnet_depth)
+ params = model_params[resnet_depth]
+ self._resnet_fn = self.resnet_v1_generator(params['block'],
+ params['layers'])
+
+ def __call__(self, inputs, is_training=None):
+ """Returns the ResNet model for a given size and number of output classes.
+
+ Args:
+ inputs: a `Tesnor` with shape [batch_size, height, width, 3] representing
+ a batch of images.
+ is_training: `bool` if True, the model is in training mode.
+
+ Returns:
+ a `dict` containing `int` keys for continuous feature levels [2, 3, 4, 5].
+ The values are corresponding feature hierarchy in ResNet with shape
+ [batch_size, height_l, width_l, num_filters].
+ """
+ with tf.name_scope('resnet%s' % self._resnet_depth):
+ return self._resnet_fn(inputs, is_training)
+
+ def fixed_padding(self, inputs, kernel_size):
+ """Pads the input along the spatial dimensions independently of input size.
+
+ Args:
+ inputs: `Tensor` of size `[batch, channels, height, width]` or `[batch,
+ height, width, channels]` depending on `data_format`.
+ kernel_size: `int` kernel size to be used for `conv2d` or max_pool2d`
+ operations. Should be a positive integer.
+
+ Returns:
+ A padded `Tensor` of the same `data_format` with size either intact
+ (if `kernel_size == 1`) or padded (if `kernel_size > 1`).
+ """
+ pad_total = kernel_size - 1
+ pad_beg = pad_total // 2
+ pad_end = pad_total - pad_beg
+ if self._data_format == 'channels_first':
+ padded_inputs = tf.pad(
+ tensor=inputs,
+ paddings=[[0, 0], [0, 0], [pad_beg, pad_end], [pad_beg, pad_end]])
+ else:
+ padded_inputs = tf.pad(
+ tensor=inputs,
+ paddings=[[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
+
+ return padded_inputs
+
+ def conv2d_fixed_padding(self, inputs, filters, kernel_size, strides):
+ """Strided 2-D convolution with explicit padding.
+
+ The padding is consistent and is based only on `kernel_size`, not on the
+ dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone).
+
+ Args:
+ inputs: `Tensor` of size `[batch, channels, height_in, width_in]`.
+ filters: `int` number of filters in the convolution.
+ kernel_size: `int` size of the kernel to be used in the convolution.
+ strides: `int` strides of the convolution.
+
+ Returns:
+ A `Tensor` of shape `[batch, filters, height_out, width_out]`.
+ """
+ if strides > 1:
+ inputs = self.fixed_padding(inputs, kernel_size)
+
+ return tf_keras.layers.Conv2D(
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=('SAME' if strides == 1 else 'VALID'),
+ use_bias=False,
+ kernel_initializer=tf.initializers.VarianceScaling(),
+ data_format=self._data_format)(
+ inputs=inputs)
+
+ def residual_block(self,
+ inputs,
+ filters,
+ strides,
+ use_projection=False,
+ is_training=None):
+ """Standard building block for residual networks with BN after convolutions.
+
+ Args:
+ inputs: `Tensor` of size `[batch, channels, height, width]`.
+ filters: `int` number of filters for the first two convolutions. Note that
+ the third and final convolution will use 4 times as many filters.
+ strides: `int` block stride. If greater than 1, this block will ultimately
+ downsample the input.
+ use_projection: `bool` for whether this block should use a projection
+ shortcut (versus the default identity shortcut). This is usually `True`
+ for the first block of a block group, which may change the number of
+ filters and the resolution.
+ is_training: `bool` if True, the model is in training mode.
+
+ Returns:
+ The output `Tensor` of the block.
+ """
+ shortcut = inputs
+ if use_projection:
+ # Projection shortcut in first layer to match filters and strides
+ shortcut = self.conv2d_fixed_padding(
+ inputs=inputs, filters=filters, kernel_size=1, strides=strides)
+ shortcut = self._norm_activation(use_activation=False)(
+ shortcut, is_training=is_training)
+
+ inputs = self.conv2d_fixed_padding(
+ inputs=inputs, filters=filters, kernel_size=3, strides=strides)
+ inputs = self._norm_activation()(inputs, is_training=is_training)
+
+ inputs = self.conv2d_fixed_padding(
+ inputs=inputs, filters=filters, kernel_size=3, strides=1)
+ inputs = self._norm_activation(
+ use_activation=False, init_zero=True)(
+ inputs, is_training=is_training)
+
+ return self._activation_op(inputs + shortcut)
+
+ def bottleneck_block(self,
+ inputs,
+ filters,
+ strides,
+ use_projection=False,
+ is_training=None):
+ """Bottleneck block variant for residual networks with BN after convolutions.
+
+ Args:
+ inputs: `Tensor` of size `[batch, channels, height, width]`.
+ filters: `int` number of filters for the first two convolutions. Note that
+ the third and final convolution will use 4 times as many filters.
+ strides: `int` block stride. If greater than 1, this block will ultimately
+ downsample the input.
+ use_projection: `bool` for whether this block should use a projection
+ shortcut (versus the default identity shortcut). This is usually `True`
+ for the first block of a block group, which may change the number of
+ filters and the resolution.
+ is_training: `bool` if True, the model is in training mode.
+
+ Returns:
+ The output `Tensor` of the block.
+ """
+ shortcut = inputs
+ if use_projection:
+ # Projection shortcut only in first block within a group. Bottleneck
+ # blocks end with 4 times the number of filters.
+ filters_out = 4 * filters
+ shortcut = self.conv2d_fixed_padding(
+ inputs=inputs, filters=filters_out, kernel_size=1, strides=strides)
+ shortcut = self._norm_activation(use_activation=False)(
+ shortcut, is_training=is_training)
+
+ inputs = self.conv2d_fixed_padding(
+ inputs=inputs, filters=filters, kernel_size=1, strides=1)
+ inputs = self._norm_activation()(inputs, is_training=is_training)
+
+ inputs = self.conv2d_fixed_padding(
+ inputs=inputs, filters=filters, kernel_size=3, strides=strides)
+ inputs = self._norm_activation()(inputs, is_training=is_training)
+
+ inputs = self.conv2d_fixed_padding(
+ inputs=inputs, filters=4 * filters, kernel_size=1, strides=1)
+ inputs = self._norm_activation(
+ use_activation=False, init_zero=True)(
+ inputs, is_training=is_training)
+
+ return self._activation_op(inputs + shortcut)
+
+ def block_group(self, inputs, filters, block_fn, blocks, strides, name,
+ is_training):
+ """Creates one group of blocks for the ResNet model.
+
+ Args:
+ inputs: `Tensor` of size `[batch, channels, height, width]`.
+ filters: `int` number of filters for the first convolution of the layer.
+ block_fn: `function` for the block to use within the model
+ blocks: `int` number of blocks contained in the layer.
+ strides: `int` stride to use for the first convolution of the layer. If
+ greater than 1, this layer will downsample the input.
+ name: `str`name for the Tensor output of the block layer.
+ is_training: `bool` if True, the model is in training mode.
+
+ Returns:
+ The output `Tensor` of the block layer.
+ """
+ # Only the first block per block_group uses projection shortcut and strides.
+ inputs = block_fn(
+ inputs, filters, strides, use_projection=True, is_training=is_training)
+
+ for _ in range(1, blocks):
+ inputs = block_fn(inputs, filters, 1, is_training=is_training)
+
+ return tf.identity(inputs, name)
+
+ def resnet_v1_generator(self, block_fn, layers):
+ """Generator for ResNet v1 models.
+
+ Args:
+ block_fn: `function` for the block to use within the model. Either
+ `residual_block` or `bottleneck_block`.
+ layers: list of 4 `int`s denoting the number of blocks to include in each
+ of the 4 block groups. Each group consists of blocks that take inputs of
+ the same resolution.
+
+ Returns:
+ Model `function` that takes in `inputs` and `is_training` and returns the
+ output `Tensor` of the ResNet model.
+ """
+
+ def model(inputs, is_training=None):
+ """Creation of the model graph."""
+ inputs = self.conv2d_fixed_padding(
+ inputs=inputs, filters=64, kernel_size=7, strides=2)
+ inputs = tf.identity(inputs, 'initial_conv')
+ inputs = self._norm_activation()(inputs, is_training=is_training)
+
+ inputs = tf_keras.layers.MaxPool2D(
+ pool_size=3, strides=2, padding='SAME',
+ data_format=self._data_format)(
+ inputs)
+ inputs = tf.identity(inputs, 'initial_max_pool')
+
+ c2 = self.block_group(
+ inputs=inputs,
+ filters=64,
+ block_fn=block_fn,
+ blocks=layers[0],
+ strides=1,
+ name='block_group1',
+ is_training=is_training)
+ c3 = self.block_group(
+ inputs=c2,
+ filters=128,
+ block_fn=block_fn,
+ blocks=layers[1],
+ strides=2,
+ name='block_group2',
+ is_training=is_training)
+ c4 = self.block_group(
+ inputs=c3,
+ filters=256,
+ block_fn=block_fn,
+ blocks=layers[2],
+ strides=2,
+ name='block_group3',
+ is_training=is_training)
+ c5 = self.block_group(
+ inputs=c4,
+ filters=512,
+ block_fn=block_fn,
+ blocks=layers[3],
+ strides=2,
+ name='block_group4',
+ is_training=is_training)
+ return {2: c2, 3: c3, 4: c4, 5: c5}
+
+ return model
diff --git a/modeling/official/legacy/detection/modeling/architecture/spinenet.py b/modeling/official/legacy/detection/modeling/architecture/spinenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b04481e21ee76c40c04598206a8174328ee0f9a3
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/architecture/spinenet.py
@@ -0,0 +1,504 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ==============================================================================
+"""Implementation of SpineNet model.
+
+X. Du, T-Y. Lin, P. Jin, G. Ghiasi, M. Tan, Y. Cui, Q. V. Le, X. Song
+SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization
+https://arxiv.org/abs/1912.05027
+"""
+import math
+
+from absl import logging
+import tensorflow as tf, tf_keras
+from official.legacy.detection.modeling.architecture import nn_blocks
+from official.modeling import tf_utils
+
+layers = tf_keras.layers
+
+FILTER_SIZE_MAP = {
+ 1: 32,
+ 2: 64,
+ 3: 128,
+ 4: 256,
+ 5: 256,
+ 6: 256,
+ 7: 256,
+}
+
+# The fixed SpineNet architecture discovered by NAS.
+# Each element represents a specification of a building block:
+# (block_level, block_fn, (input_offset0, input_offset1), is_output).
+SPINENET_BLOCK_SPECS = [
+ (2, 'bottleneck', (0, 1), False),
+ (4, 'residual', (0, 1), False),
+ (3, 'bottleneck', (2, 3), False),
+ (4, 'bottleneck', (2, 4), False),
+ (6, 'residual', (3, 5), False),
+ (4, 'bottleneck', (3, 5), False),
+ (5, 'residual', (6, 7), False),
+ (7, 'residual', (6, 8), False),
+ (5, 'bottleneck', (8, 9), False),
+ (5, 'bottleneck', (8, 10), False),
+ (4, 'bottleneck', (5, 10), True),
+ (3, 'bottleneck', (4, 10), True),
+ (5, 'bottleneck', (7, 12), True),
+ (7, 'bottleneck', (5, 14), True),
+ (6, 'bottleneck', (12, 14), True),
+]
+
+SCALING_MAP = {
+ '49S': {
+ 'endpoints_num_filters': 128,
+ 'filter_size_scale': 0.65,
+ 'resample_alpha': 0.5,
+ 'block_repeats': 1,
+ },
+ '49': {
+ 'endpoints_num_filters': 256,
+ 'filter_size_scale': 1.0,
+ 'resample_alpha': 0.5,
+ 'block_repeats': 1,
+ },
+ '96': {
+ 'endpoints_num_filters': 256,
+ 'filter_size_scale': 1.0,
+ 'resample_alpha': 0.5,
+ 'block_repeats': 2,
+ },
+ '143': {
+ 'endpoints_num_filters': 256,
+ 'filter_size_scale': 1.0,
+ 'resample_alpha': 1.0,
+ 'block_repeats': 3,
+ },
+ '190': {
+ 'endpoints_num_filters': 512,
+ 'filter_size_scale': 1.3,
+ 'resample_alpha': 1.0,
+ 'block_repeats': 4,
+ },
+}
+
+
+class BlockSpec(object):
+ """A container class that specifies the block configuration for SpineNet."""
+
+ def __init__(self, level, block_fn, input_offsets, is_output):
+ self.level = level
+ self.block_fn = block_fn
+ self.input_offsets = input_offsets
+ self.is_output = is_output
+
+
+def build_block_specs(block_specs=None):
+ """Builds the list of BlockSpec objects for SpineNet."""
+ if not block_specs:
+ block_specs = SPINENET_BLOCK_SPECS
+ logging.info('Building SpineNet block specs: %s', block_specs)
+ return [BlockSpec(*b) for b in block_specs]
+
+
+class SpineNet(tf_keras.Model):
+ """Class to build SpineNet models."""
+
+ def __init__(self,
+ input_specs=tf_keras.layers.InputSpec(shape=[None, 640, 640, 3]),
+ min_level=3,
+ max_level=7,
+ block_specs=None,
+ endpoints_num_filters=256,
+ resample_alpha=0.5,
+ block_repeats=1,
+ filter_size_scale=1.0,
+ kernel_initializer='VarianceScaling',
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activation='relu',
+ use_sync_bn=False,
+ norm_momentum=0.99,
+ norm_epsilon=0.001,
+ **kwargs):
+ """SpineNet model."""
+ self._min_level = min_level
+ self._max_level = max_level
+ self._block_specs = (
+ build_block_specs() if block_specs is None else block_specs
+ )
+ self._endpoints_num_filters = endpoints_num_filters
+ self._resample_alpha = resample_alpha
+ self._block_repeats = block_repeats
+ self._filter_size_scale = filter_size_scale
+ self._kernel_initializer = kernel_initializer
+ self._kernel_regularizer = kernel_regularizer
+ self._bias_regularizer = bias_regularizer
+ self._use_sync_bn = use_sync_bn
+ self._norm_momentum = norm_momentum
+ self._norm_epsilon = norm_epsilon
+ if activation == 'relu':
+ self._activation = tf.nn.relu
+ elif activation == 'swish':
+ self._activation = tf.nn.swish
+ else:
+ raise ValueError('Activation {} not implemented.'.format(activation))
+ self._init_block_fn = 'bottleneck'
+ self._num_init_blocks = 2
+
+ if use_sync_bn:
+ self._norm = layers.experimental.SyncBatchNormalization
+ else:
+ self._norm = layers.BatchNormalization
+
+ if tf_keras.backend.image_data_format() == 'channels_last':
+ self._bn_axis = -1
+ else:
+ self._bn_axis = 1
+
+ # Build SpineNet.
+ inputs = tf_keras.Input(shape=input_specs.shape[1:])
+
+ net = self._build_stem(inputs=inputs)
+ net = self._build_scale_permuted_network(
+ net=net, input_width=input_specs.shape[1])
+ net = self._build_endpoints(net=net)
+
+ super(SpineNet, self).__init__(inputs=inputs, outputs=net)
+
+ def _block_group(self,
+ inputs,
+ filters,
+ strides,
+ block_fn_cand,
+ block_repeats=1,
+ name='block_group'):
+ """Creates one group of blocks for the SpineNet model."""
+ block_fn_candidates = {
+ 'bottleneck': nn_blocks.BottleneckBlock,
+ 'residual': nn_blocks.ResidualBlock,
+ }
+ block_fn = block_fn_candidates[block_fn_cand]
+ _, _, _, num_filters = inputs.get_shape().as_list()
+
+ if block_fn_cand == 'bottleneck':
+ use_projection = not (num_filters == (filters * 4) and strides == 1)
+ else:
+ use_projection = not (num_filters == filters and strides == 1)
+
+ x = block_fn(
+ filters=filters,
+ strides=strides,
+ use_projection=use_projection,
+ kernel_initializer=self._kernel_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activation=self._activation,
+ use_sync_bn=self._use_sync_bn,
+ norm_momentum=self._norm_momentum,
+ norm_epsilon=self._norm_epsilon)(
+ inputs)
+ for _ in range(1, block_repeats):
+ x = block_fn(
+ filters=filters,
+ strides=1,
+ use_projection=False,
+ kernel_initializer=self._kernel_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activation=self._activation,
+ use_sync_bn=self._use_sync_bn,
+ norm_momentum=self._norm_momentum,
+ norm_epsilon=self._norm_epsilon)(
+ x)
+ return tf.identity(x, name=name)
+
+ def _build_stem(self, inputs):
+ """Build SpineNet stem."""
+ x = layers.Conv2D(
+ filters=64,
+ kernel_size=7,
+ strides=2,
+ use_bias=False,
+ padding='same',
+ kernel_initializer=self._kernel_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer)(
+ inputs)
+ x = self._norm(
+ axis=self._bn_axis,
+ momentum=self._norm_momentum,
+ epsilon=self._norm_epsilon)(
+ x)
+ x = tf_utils.get_activation(self._activation)(x)
+ x = layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)
+
+ net = []
+ # Build the initial level 2 blocks.
+ for i in range(self._num_init_blocks):
+ x = self._block_group(
+ inputs=x,
+ filters=int(FILTER_SIZE_MAP[2] * self._filter_size_scale),
+ strides=1,
+ block_fn_cand=self._init_block_fn,
+ block_repeats=self._block_repeats,
+ name='stem_block_{}'.format(i + 1))
+ net.append(x)
+ return net
+
+ def _build_scale_permuted_network(self,
+ net,
+ input_width,
+ weighted_fusion=False):
+ """Build scale-permuted network."""
+ net_sizes = [int(math.ceil(input_width / 2**2))] * len(net)
+ net_block_fns = [self._init_block_fn] * len(net)
+ num_outgoing_connections = [0] * len(net)
+
+ endpoints = {}
+ for i, block_spec in enumerate(self._block_specs):
+ # Find out specs for the target block.
+ target_width = int(math.ceil(input_width / 2**block_spec.level))
+ target_num_filters = int(FILTER_SIZE_MAP[block_spec.level] *
+ self._filter_size_scale)
+ target_block_fn = block_spec.block_fn
+
+ # Resample then merge input0 and input1.
+ parents = []
+ input0 = block_spec.input_offsets[0]
+ input1 = block_spec.input_offsets[1]
+
+ x0 = self._resample_with_alpha(
+ inputs=net[input0],
+ input_width=net_sizes[input0],
+ input_block_fn=net_block_fns[input0],
+ target_width=target_width,
+ target_num_filters=target_num_filters,
+ target_block_fn=target_block_fn,
+ alpha=self._resample_alpha)
+ parents.append(x0)
+ num_outgoing_connections[input0] += 1
+
+ x1 = self._resample_with_alpha(
+ inputs=net[input1],
+ input_width=net_sizes[input1],
+ input_block_fn=net_block_fns[input1],
+ target_width=target_width,
+ target_num_filters=target_num_filters,
+ target_block_fn=target_block_fn,
+ alpha=self._resample_alpha)
+ parents.append(x1)
+ num_outgoing_connections[input1] += 1
+
+ # Merge 0 outdegree blocks to the output block.
+ if block_spec.is_output:
+ for j, (j_feat,
+ j_connections) in enumerate(zip(net, num_outgoing_connections)):
+ if j_connections == 0 and (j_feat.shape[2] == target_width and
+ j_feat.shape[3] == x0.shape[3]):
+ parents.append(j_feat)
+ num_outgoing_connections[j] += 1
+
+ # pylint: disable=g-direct-tensorflow-import
+ if weighted_fusion:
+ dtype = parents[0].dtype
+ parent_weights = [
+ tf.nn.relu(tf.cast(tf.Variable(1.0, name='block{}_fusion{}'.format(
+ i, j)), dtype=dtype)) for j in range(len(parents))]
+ weights_sum = tf.add_n(parent_weights)
+ parents = [
+ parents[i] * parent_weights[i] / (weights_sum + 0.0001)
+ for i in range(len(parents))
+ ]
+
+ # Fuse all parent nodes then build a new block.
+ x = tf_utils.get_activation(self._activation)(tf.add_n(parents))
+ x = self._block_group(
+ inputs=x,
+ filters=target_num_filters,
+ strides=1,
+ block_fn_cand=target_block_fn,
+ block_repeats=self._block_repeats,
+ name='scale_permuted_block_{}'.format(i + 1))
+
+ net.append(x)
+ net_sizes.append(target_width)
+ net_block_fns.append(target_block_fn)
+ num_outgoing_connections.append(0)
+
+ # Save output feats.
+ if block_spec.is_output:
+ if block_spec.level in endpoints:
+ raise ValueError('Duplicate feats found for output level {}.'.format(
+ block_spec.level))
+ if (block_spec.level < self._min_level or
+ block_spec.level > self._max_level):
+ raise ValueError('Output level is out of range [{}, {}]'.format(
+ self._min_level, self._max_level))
+ endpoints[block_spec.level] = x
+
+ return endpoints
+
+ def _build_endpoints(self, net):
+ """Match filter size for endpoints before sharing conv layers."""
+ endpoints = {}
+ for level in range(self._min_level, self._max_level + 1):
+ x = layers.Conv2D(
+ filters=self._endpoints_num_filters,
+ kernel_size=1,
+ strides=1,
+ use_bias=False,
+ kernel_initializer=self._kernel_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer)(
+ net[level])
+ x = self._norm(
+ axis=self._bn_axis,
+ momentum=self._norm_momentum,
+ epsilon=self._norm_epsilon)(
+ x)
+ x = tf_utils.get_activation(self._activation)(x)
+ endpoints[level] = x
+ return endpoints
+
+ def _resample_with_alpha(self,
+ inputs,
+ input_width,
+ input_block_fn,
+ target_width,
+ target_num_filters,
+ target_block_fn,
+ alpha=0.5):
+ """Match resolution and feature dimension."""
+ _, _, _, input_num_filters = inputs.get_shape().as_list()
+ if input_block_fn == 'bottleneck':
+ input_num_filters /= 4
+ new_num_filters = int(input_num_filters * alpha)
+
+ x = layers.Conv2D(
+ filters=new_num_filters,
+ kernel_size=1,
+ strides=1,
+ use_bias=False,
+ kernel_initializer=self._kernel_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer)(
+ inputs)
+ x = self._norm(
+ axis=self._bn_axis,
+ momentum=self._norm_momentum,
+ epsilon=self._norm_epsilon)(
+ x)
+ x = tf_utils.get_activation(self._activation)(x)
+
+ # Spatial resampling.
+ if input_width > target_width:
+ x = layers.Conv2D(
+ filters=new_num_filters,
+ kernel_size=3,
+ strides=2,
+ padding='SAME',
+ use_bias=False,
+ kernel_initializer=self._kernel_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer)(
+ x)
+ x = self._norm(
+ axis=self._bn_axis,
+ momentum=self._norm_momentum,
+ epsilon=self._norm_epsilon)(
+ x)
+ x = tf_utils.get_activation(self._activation)(x)
+ input_width /= 2
+ while input_width > target_width:
+ x = layers.MaxPool2D(pool_size=3, strides=2, padding='SAME')(x)
+ input_width /= 2
+ elif input_width < target_width:
+ scale = target_width // input_width
+ x = layers.UpSampling2D(size=(scale, scale))(x)
+
+ # Last 1x1 conv to match filter size.
+ if target_block_fn == 'bottleneck':
+ target_num_filters *= 4
+ x = layers.Conv2D(
+ filters=target_num_filters,
+ kernel_size=1,
+ strides=1,
+ use_bias=False,
+ kernel_initializer=self._kernel_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer)(
+ x)
+ x = self._norm(
+ axis=self._bn_axis,
+ momentum=self._norm_momentum,
+ epsilon=self._norm_epsilon)(
+ x)
+
+ return x
+
+
+class SpineNetBuilder(object):
+ """SpineNet builder."""
+
+ def __init__(self,
+ model_id,
+ input_specs=tf_keras.layers.InputSpec(shape=[None, 640, 640, 3]),
+ min_level=3,
+ max_level=7,
+ block_specs=None,
+ kernel_initializer='VarianceScaling',
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activation='relu',
+ use_sync_bn=False,
+ norm_momentum=0.99,
+ norm_epsilon=0.001):
+ if model_id not in SCALING_MAP:
+ raise ValueError(
+ 'SpineNet {} is not a valid architecture.'.format(model_id))
+ scaling_params = SCALING_MAP[model_id]
+ self._input_specs = input_specs
+ self._min_level = min_level
+ self._max_level = max_level
+ self._block_specs = block_specs or build_block_specs()
+ self._endpoints_num_filters = scaling_params['endpoints_num_filters']
+ self._resample_alpha = scaling_params['resample_alpha']
+ self._block_repeats = scaling_params['block_repeats']
+ self._filter_size_scale = scaling_params['filter_size_scale']
+ self._kernel_initializer = kernel_initializer
+ self._kernel_regularizer = kernel_regularizer
+ self._bias_regularizer = bias_regularizer
+ self._activation = activation
+ self._use_sync_bn = use_sync_bn
+ self._norm_momentum = norm_momentum
+ self._norm_epsilon = norm_epsilon
+
+ def __call__(self, inputs, is_training=None):
+ model = SpineNet(
+ input_specs=self._input_specs,
+ min_level=self._min_level,
+ max_level=self._max_level,
+ block_specs=self._block_specs,
+ endpoints_num_filters=self._endpoints_num_filters,
+ resample_alpha=self._resample_alpha,
+ block_repeats=self._block_repeats,
+ filter_size_scale=self._filter_size_scale,
+ kernel_initializer=self._kernel_initializer,
+ kernel_regularizer=self._kernel_regularizer,
+ bias_regularizer=self._bias_regularizer,
+ activation=self._activation,
+ use_sync_bn=self._use_sync_bn,
+ norm_momentum=self._norm_momentum,
+ norm_epsilon=self._norm_epsilon)
+ return model(inputs)
diff --git a/modeling/official/legacy/detection/modeling/base_model.py b/modeling/official/legacy/detection/modeling/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d92bc4518bb8429cd659941adfde4747bec0581
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/base_model.py
@@ -0,0 +1,135 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Base Model definition."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import re
+
+import tensorflow as tf, tf_keras
+from official.legacy.detection.modeling import checkpoint_utils
+from official.legacy.detection.modeling import learning_rates
+from official.legacy.detection.modeling import optimizers
+
+
+def _make_filter_trainable_variables_fn(frozen_variable_prefix):
+ """Creates a function for filtering trainable varialbes."""
+
+ def _filter_trainable_variables(variables):
+ """Filters trainable varialbes.
+
+ Args:
+ variables: a list of tf.Variable to be filtered.
+
+ Returns:
+ filtered_variables: a list of tf.Variable filtered out the frozen ones.
+ """
+ # frozen_variable_prefix: a regex string specifing the prefix pattern of
+ # the frozen variables' names.
+ filtered_variables = [
+ v for v in variables if not frozen_variable_prefix or
+ not re.match(frozen_variable_prefix, v.name)
+ ]
+ return filtered_variables
+
+ return _filter_trainable_variables
+
+
+class Model(object):
+ """Base class for model function."""
+
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, params):
+ self._use_bfloat16 = params.architecture.use_bfloat16
+
+ if params.architecture.use_bfloat16:
+ tf.compat.v2.keras.mixed_precision.set_global_policy('mixed_bfloat16')
+
+ # Optimization.
+ self._optimizer_fn = optimizers.OptimizerFactory(params.train.optimizer)
+ self._learning_rate = learning_rates.learning_rate_generator(
+ params.train.total_steps, params.train.learning_rate)
+
+ self._frozen_variable_prefix = params.train.frozen_variable_prefix
+ self._regularization_var_regex = params.train.regularization_variable_regex
+ self._l2_weight_decay = params.train.l2_weight_decay
+
+ # Checkpoint restoration.
+ self._checkpoint = params.train.checkpoint.as_dict()
+
+ # Summary.
+ self._enable_summary = params.enable_summary
+ self._model_dir = params.model_dir
+
+ @abc.abstractmethod
+ def build_outputs(self, inputs, mode):
+ """Build the graph of the forward path."""
+ pass
+
+ @abc.abstractmethod
+ def build_model(self, params, mode):
+ """Build the model object."""
+ pass
+
+ @abc.abstractmethod
+ def build_loss_fn(self):
+ """Build the model object."""
+ pass
+
+ def post_processing(self, labels, outputs):
+ """Post-processing function."""
+ return labels, outputs
+
+ def model_outputs(self, inputs, mode):
+ """Build the model outputs."""
+ return self.build_outputs(inputs, mode)
+
+ def build_optimizer(self):
+ """Returns train_op to optimize total loss."""
+ # Sets up the optimizer.
+ return self._optimizer_fn(self._learning_rate)
+
+ def make_filter_trainable_variables_fn(self):
+ """Creates a function for filtering trainable varialbes."""
+ return _make_filter_trainable_variables_fn(self._frozen_variable_prefix)
+
+ def weight_decay_loss(self, trainable_variables):
+ reg_variables = [
+ v for v in trainable_variables
+ if self._regularization_var_regex is None or
+ re.match(self._regularization_var_regex, v.name)
+ ]
+
+ return self._l2_weight_decay * tf.add_n(
+ [tf.nn.l2_loss(v) for v in reg_variables])
+
+ def make_restore_checkpoint_fn(self):
+ """Returns scaffold function to restore parameters from v1 checkpoint."""
+ if 'skip_checkpoint_variables' in self._checkpoint:
+ skip_regex = self._checkpoint['skip_checkpoint_variables']
+ else:
+ skip_regex = None
+ return checkpoint_utils.make_restore_checkpoint_fn(
+ self._checkpoint['path'],
+ prefix=self._checkpoint['prefix'],
+ skip_regex=skip_regex)
+
+ def eval_metrics(self):
+ """Returns tuple of metric function and its inputs for evaluation."""
+ raise NotImplementedError('Unimplemented eval_metrics')
diff --git a/modeling/official/legacy/detection/modeling/checkpoint_utils.py b/modeling/official/legacy/detection/modeling/checkpoint_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e50a9a289f4b09de55fd675f4da680dae66b44e
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/checkpoint_utils.py
@@ -0,0 +1,142 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Util functions for loading checkpoints.
+
+Especially for loading Tensorflow 1.x
+checkpoint to Tensorflow 2.x (keras) model.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+from absl import logging
+
+import tensorflow as tf, tf_keras
+
+
+def _build_assignment_map(keras_model,
+ prefix='',
+ skip_variables_regex=None,
+ var_to_shape_map=None):
+ """Builds the variable assignment map.
+
+ Compute an assignment mapping for loading older checkpoints into a Keras
+ model. Variable names are remapped from the original TPUEstimator model to
+ the new Keras name.
+
+ Args:
+ keras_model: tf_keras.Model object to provide variables to assign.
+ prefix: prefix in the variable name to be remove for alignment with names in
+ the checkpoint.
+ skip_variables_regex: regular expression to math the names of variables that
+ do not need to be assign.
+ var_to_shape_map: variable name to shape mapping from the checkpoint.
+
+ Returns:
+ The variable assignment map.
+ """
+ assignment_map = {}
+
+ checkpoint_names = []
+ if var_to_shape_map:
+ # pylint: disable=g-long-lambda
+ checkpoint_names = list(
+ filter(
+ lambda x: not x.endswith('Momentum') and not x.endswith(
+ 'global_step'), var_to_shape_map.keys()))
+ # pylint: enable=g-long-lambda
+
+ logging.info('Number of variables in the checkpoint %d',
+ len(checkpoint_names))
+
+ for var in keras_model.variables:
+ var_name = var.name
+
+ if skip_variables_regex and re.match(skip_variables_regex, var_name):
+ continue
+ # Trim the index of the variable.
+ if ':' in var_name:
+ var_name = var_name[:var_name.rindex(':')]
+ if var_name.startswith(prefix):
+ var_name = var_name[len(prefix):]
+
+ if not var_to_shape_map:
+ assignment_map[var_name] = var
+ continue
+
+ # Match name with variables in the checkpoint.
+ # pylint: disable=cell-var-from-loop
+ match_names = list(filter(lambda x: x.endswith(var_name), checkpoint_names))
+ # pylint: enable=cell-var-from-loop
+ try:
+ if match_names:
+ assert len(match_names) == 1, 'more then on matches for {}: {}'.format(
+ var_name, match_names)
+ checkpoint_names.remove(match_names[0])
+ assignment_map[match_names[0]] = var
+ else:
+ logging.info('Error not found var name: %s', var_name)
+ except Exception as e:
+ logging.info('Error removing the match_name: %s', match_names)
+ logging.info('Exception: %s', e)
+ raise
+ logging.info('Found matching variable in checkpoint: %d', len(assignment_map))
+ return assignment_map
+
+
+def _get_checkpoint_map(checkpoint_path):
+ reader = tf.train.load_checkpoint(checkpoint_path)
+ return reader.get_variable_to_shape_map()
+
+
+def make_restore_checkpoint_fn(checkpoint_path, prefix='', skip_regex=None):
+ """Returns scaffold function to restore parameters from v1 checkpoint.
+
+ Args:
+ checkpoint_path: path of the checkpoint folder or file.
+ Example 1: '/path/to/model_dir/'
+ Example 2: '/path/to/model.ckpt-22500'
+ prefix: prefix in the variable name to be remove for alignment with names in
+ the checkpoint.
+ skip_regex: regular expression to math the names of variables that do not
+ need to be assign.
+
+ Returns:
+ Callable[tf.kears.Model] -> void. Fn to load v1 checkpoint to keras model.
+ """
+
+ def _restore_checkpoint_fn(keras_model):
+ """Loads pretrained model through scaffold function."""
+ if not checkpoint_path:
+ logging.info('checkpoint_path is empty')
+ return
+ var_prefix = prefix
+ if prefix and not prefix.endswith('/'):
+ var_prefix += '/'
+ var_to_shape_map = _get_checkpoint_map(checkpoint_path)
+ assert var_to_shape_map, 'var_to_shape_map should not be empty'
+ vars_to_load = _build_assignment_map(
+ keras_model,
+ prefix=var_prefix,
+ skip_variables_regex=skip_regex,
+ var_to_shape_map=var_to_shape_map)
+ if not vars_to_load:
+ raise ValueError('Variables to load is empty.')
+ tf.compat.v1.train.init_from_checkpoint(checkpoint_path, vars_to_load)
+
+ return _restore_checkpoint_fn
diff --git a/modeling/official/legacy/detection/modeling/factory.py b/modeling/official/legacy/detection/modeling/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b82451c3bfc24ebcb4af2c69382d69ce569b364
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/factory.py
@@ -0,0 +1,37 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Factory to build detection model."""
+
+
+from official.legacy.detection.modeling import maskrcnn_model
+from official.legacy.detection.modeling import olnmask_model
+from official.legacy.detection.modeling import retinanet_model
+from official.legacy.detection.modeling import shapemask_model
+
+
+def model_generator(params):
+ """Model function generator."""
+ if params.type == 'retinanet':
+ model_fn = retinanet_model.RetinanetModel(params)
+ elif params.type == 'mask_rcnn':
+ model_fn = maskrcnn_model.MaskrcnnModel(params)
+ elif params.type == 'olnmask':
+ model_fn = olnmask_model.OlnMaskModel(params)
+ elif params.type == 'shapemask':
+ model_fn = shapemask_model.ShapeMaskModel(params)
+ else:
+ raise ValueError('Model %s is not supported.'% params.type)
+
+ return model_fn
diff --git a/modeling/official/legacy/detection/modeling/learning_rates.py b/modeling/official/legacy/detection/modeling/learning_rates.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ef242b89cb4b9b055bb2a5913d85c8128c9b7e5
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/learning_rates.py
@@ -0,0 +1,98 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Learning rate schedule."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf, tf_keras
+from official.modeling.hyperparams import params_dict
+
+
+class StepLearningRateWithLinearWarmup(
+ tf_keras.optimizers.schedules.LearningRateSchedule):
+ """Class to generate learning rate tensor."""
+
+ def __init__(self, total_steps, params):
+ """Creates the step learning rate tensor with linear warmup."""
+ super(StepLearningRateWithLinearWarmup, self).__init__()
+ self._total_steps = total_steps
+ assert isinstance(params, (dict, params_dict.ParamsDict))
+ if isinstance(params, dict):
+ params = params_dict.ParamsDict(params)
+ self._params = params
+
+ def __call__(self, global_step):
+ warmup_lr = self._params.warmup_learning_rate
+ warmup_steps = self._params.warmup_steps
+ init_lr = self._params.init_learning_rate
+ lr_levels = self._params.learning_rate_levels
+ lr_steps = self._params.learning_rate_steps
+ linear_warmup = (
+ warmup_lr + tf.cast(global_step, dtype=tf.float32) / warmup_steps *
+ (init_lr - warmup_lr))
+ learning_rate = tf.where(global_step < warmup_steps, linear_warmup, init_lr)
+
+ for next_learning_rate, start_step in zip(lr_levels, lr_steps):
+ learning_rate = tf.where(global_step >= start_step, next_learning_rate,
+ learning_rate)
+ return learning_rate
+
+ def get_config(self):
+ return {'_params': self._params.as_dict()}
+
+
+class CosineLearningRateWithLinearWarmup(
+ tf_keras.optimizers.schedules.LearningRateSchedule):
+ """Class to generate learning rate tensor."""
+
+ def __init__(self, total_steps, params):
+ """Creates the cosine learning rate tensor with linear warmup."""
+ super(CosineLearningRateWithLinearWarmup, self).__init__()
+ self._total_steps = total_steps
+ assert isinstance(params, (dict, params_dict.ParamsDict))
+ if isinstance(params, dict):
+ params = params_dict.ParamsDict(params)
+ self._params = params
+
+ def __call__(self, global_step):
+ global_step = tf.cast(global_step, dtype=tf.float32)
+ warmup_lr = self._params.warmup_learning_rate
+ warmup_steps = self._params.warmup_steps
+ init_lr = self._params.init_learning_rate
+ total_steps = self._total_steps
+ linear_warmup = (
+ warmup_lr + global_step / warmup_steps * (init_lr - warmup_lr))
+ cosine_learning_rate = (
+ init_lr * (tf.cos(np.pi * (global_step - warmup_steps) /
+ (total_steps - warmup_steps)) + 1.0) / 2.0)
+ learning_rate = tf.where(global_step < warmup_steps, linear_warmup,
+ cosine_learning_rate)
+ return learning_rate
+
+ def get_config(self):
+ return {'_params': self._params.as_dict()}
+
+
+def learning_rate_generator(total_steps, params):
+ """The learning rate function generator."""
+ if params.type == 'step':
+ return StepLearningRateWithLinearWarmup(total_steps, params)
+ elif params.type == 'cosine':
+ return CosineLearningRateWithLinearWarmup(total_steps, params)
+ else:
+ raise ValueError('Unsupported learning rate type: {}.'.format(params.type))
diff --git a/modeling/official/legacy/detection/modeling/losses.py b/modeling/official/legacy/detection/modeling/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..110cd795178882e26cbb536117de8a49a3a89955
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/losses.py
@@ -0,0 +1,725 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Losses used for detection models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl import logging
+import tensorflow as tf, tf_keras
+
+
+def focal_loss(logits, targets, alpha, gamma, normalizer):
+ """Compute the focal loss between `logits` and the golden `target` values.
+
+ Focal loss = -(1-pt)^gamma * log(pt)
+ where pt is the probability of being classified to the true class.
+
+ Args:
+ logits: A float32 tensor of size
+ [batch, height_in, width_in, num_predictions].
+ targets: A float32 tensor of size
+ [batch, height_in, width_in, num_predictions].
+ alpha: A float32 scalar multiplying alpha to the loss from positive examples
+ and (1-alpha) to the loss from negative examples.
+ gamma: A float32 scalar modulating loss from hard and easy examples.
+ normalizer: A float32 scalar normalizes the total loss from all examples.
+
+ Returns:
+ loss: A float32 Tensor of size [batch, height_in, width_in, num_predictions]
+ representing normalized loss on the prediction map.
+ """
+ with tf.name_scope('focal_loss'):
+ positive_label_mask = tf.math.equal(targets, 1.0)
+ cross_entropy = (
+ tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=logits))
+ # Below are comments/derivations for computing modulator.
+ # For brevity, let x = logits, z = targets, r = gamma, and p_t = sigmod(x)
+ # for positive samples and 1 - sigmoid(x) for negative examples.
+ #
+ # The modulator, defined as (1 - P_t)^r, is a critical part in focal loss
+ # computation. For r > 0, it puts more weights on hard examples, and less
+ # weights on easier ones. However if it is directly computed as (1 - P_t)^r,
+ # its back-propagation is not stable when r < 1. The implementation here
+ # resolves the issue.
+ #
+ # For positive samples (labels being 1),
+ # (1 - p_t)^r
+ # = (1 - sigmoid(x))^r
+ # = (1 - (1 / (1 + exp(-x))))^r
+ # = (exp(-x) / (1 + exp(-x)))^r
+ # = exp(log((exp(-x) / (1 + exp(-x)))^r))
+ # = exp(r * log(exp(-x)) - r * log(1 + exp(-x)))
+ # = exp(- r * x - r * log(1 + exp(-x)))
+ #
+ # For negative samples (labels being 0),
+ # (1 - p_t)^r
+ # = (sigmoid(x))^r
+ # = (1 / (1 + exp(-x)))^r
+ # = exp(log((1 / (1 + exp(-x)))^r))
+ # = exp(-r * log(1 + exp(-x)))
+ #
+ # Therefore one unified form for positive (z = 1) and negative (z = 0)
+ # samples is:
+ # (1 - p_t)^r = exp(-r * z * x - r * log(1 + exp(-x))).
+ neg_logits = -1.0 * logits
+ modulator = tf.math.exp(gamma * targets * neg_logits -
+ gamma * tf.math.log1p(tf.math.exp(neg_logits)))
+ loss = modulator * cross_entropy
+ weighted_loss = tf.where(positive_label_mask, alpha * loss,
+ (1.0 - alpha) * loss)
+ weighted_loss /= normalizer
+ return weighted_loss
+
+
+class RpnScoreLoss(object):
+ """Region Proposal Network score loss function."""
+
+ def __init__(self, params):
+ self._rpn_batch_size_per_im = params.rpn_batch_size_per_im
+ self._binary_crossentropy = tf_keras.losses.BinaryCrossentropy(
+ reduction=tf_keras.losses.Reduction.SUM, from_logits=True)
+
+ def __call__(self, score_outputs, labels):
+ """Computes total RPN detection loss.
+
+ Computes total RPN detection loss including box and score from all levels.
+
+ Args:
+ score_outputs: an OrderDict with keys representing levels and values
+ representing scores in [batch_size, height, width, num_anchors].
+ labels: the dictionary that returned from dataloader that includes
+ groundturth targets.
+
+ Returns:
+ rpn_score_loss: a scalar tensor representing total score loss.
+ """
+ with tf.name_scope('rpn_loss'):
+ levels = sorted(score_outputs.keys())
+
+ score_losses = []
+ for level in levels:
+ score_losses.append(
+ self._rpn_score_loss(
+ score_outputs[level],
+ labels[level],
+ normalizer=tf.cast(
+ tf.shape(score_outputs[level])[0] *
+ self._rpn_batch_size_per_im, dtype=tf.float32)))
+
+ # Sums per level losses to total loss.
+ return tf.math.add_n(score_losses)
+
+ def _rpn_score_loss(self, score_outputs, score_targets, normalizer=1.0):
+ """Computes score loss."""
+ # score_targets has three values:
+ # (1) score_targets[i]=1, the anchor is a positive sample.
+ # (2) score_targets[i]=0, negative.
+ # (3) score_targets[i]=-1, the anchor is don't care (ignore).
+ with tf.name_scope('rpn_score_loss'):
+ mask = tf.math.logical_or(tf.math.equal(score_targets, 1),
+ tf.math.equal(score_targets, 0))
+
+ score_targets = tf.math.maximum(score_targets,
+ tf.zeros_like(score_targets))
+
+ score_targets = tf.expand_dims(score_targets, axis=-1)
+ score_outputs = tf.expand_dims(score_outputs, axis=-1)
+ score_loss = self._binary_crossentropy(
+ score_targets, score_outputs, sample_weight=mask)
+
+ score_loss /= normalizer
+ return score_loss
+
+
+class RpnBoxLoss(object):
+ """Region Proposal Network box regression loss function."""
+
+ def __init__(self, params):
+ logging.info('RpnBoxLoss huber_loss_delta %s', params.huber_loss_delta)
+ # The delta is typically around the mean value of regression target.
+ # for instances, the regression targets of 512x512 input with 6 anchors on
+ # P2-P6 pyramid is about [0.1, 0.1, 0.2, 0.2].
+ self._huber_loss = tf_keras.losses.Huber(
+ delta=params.huber_loss_delta, reduction=tf_keras.losses.Reduction.SUM)
+
+ def __call__(self, box_outputs, labels):
+ """Computes total RPN detection loss.
+
+ Computes total RPN detection loss including box and score from all levels.
+
+ Args:
+ box_outputs: an OrderDict with keys representing levels and values
+ representing box regression targets in
+ [batch_size, height, width, num_anchors * 4].
+ labels: the dictionary that returned from dataloader that includes
+ groundturth targets.
+
+ Returns:
+ rpn_box_loss: a scalar tensor representing total box regression loss.
+ """
+ with tf.name_scope('rpn_loss'):
+ levels = sorted(box_outputs.keys())
+
+ box_losses = []
+ for level in levels:
+ box_losses.append(self._rpn_box_loss(box_outputs[level], labels[level]))
+
+ # Sum per level losses to total loss.
+ return tf.add_n(box_losses)
+
+ def _rpn_box_loss(self, box_outputs, box_targets, normalizer=1.0):
+ """Computes box regression loss."""
+ with tf.name_scope('rpn_box_loss'):
+ mask = tf.cast(tf.not_equal(box_targets, 0.0), dtype=tf.float32)
+ box_targets = tf.expand_dims(box_targets, axis=-1)
+ box_outputs = tf.expand_dims(box_outputs, axis=-1)
+ box_loss = self._huber_loss(box_targets, box_outputs, sample_weight=mask)
+ # The loss is normalized by the sum of non-zero weights and additional
+ # normalizer provided by the function caller. Using + 0.01 here to avoid
+ # division by zero.
+ box_loss /= normalizer * (tf.reduce_sum(mask) + 0.01)
+ return box_loss
+
+
+class OlnRpnCenterLoss(object):
+ """Object Localization Network RPN centerness regression loss function."""
+
+ def __init__(self):
+ self._l1_loss = tf_keras.losses.MeanAbsoluteError(
+ reduction=tf_keras.losses.Reduction.SUM)
+
+ def __call__(self, center_outputs, labels):
+ """Computes total RPN centerness regression loss.
+
+ Computes total RPN centerness score regression loss from all levels.
+
+ Args:
+ center_outputs: an OrderDict with keys representing levels and values
+ representing anchor centerness regression targets in
+ [batch_size, height, width, num_anchors * 4].
+ labels: the dictionary that returned from dataloader that includes
+ groundturth targets.
+
+ Returns:
+ rpn_center_loss: a scalar tensor representing total centerness regression
+ loss.
+ """
+ with tf.name_scope('rpn_loss'):
+ # Normalizer.
+ levels = sorted(center_outputs.keys())
+ num_valid = 0
+ # 00, neg=0, ign=-1.
+ mask_ = tf.cast(tf.logical_and(
+ tf.greater(center_targets[level][..., 0], 0.0),
+ tf.greater(tf.reduce_min(labels[level], -1), 0.0)), tf.float32)
+ normalizer += tf.reduce_sum(mask_)
+ normalizer += 1e-8
+ # iou_loss over multi levels.
+ iou_losses = []
+ for level in levels:
+ iou_losses.append(
+ self._rpn_iou_loss(
+ box_outputs[level], labels[level],
+ center_weight=center_targets[level][..., 0],
+ normalizer=normalizer))
+ # Sum per level losses to total loss.
+ return tf.add_n(iou_losses)
+
+ def _rpn_iou_loss(self, box_outputs, box_targets,
+ center_weight=None, normalizer=1.0):
+ """Computes box regression loss."""
+ # for instances, the regression targets of 512x512 input with 6 anchors on
+ # P2-P6 pyramid is about [0.1, 0.1, 0.2, 0.2].
+ with tf.name_scope('rpn_iou_loss'):
+ mask = tf.logical_and(
+ tf.greater(center_weight, 0.0),
+ tf.greater(tf.reduce_min(box_targets, -1), 0.0))
+
+ pred_left = box_outputs[..., 0]
+ pred_right = box_outputs[..., 1]
+ pred_top = box_outputs[..., 2]
+ pred_bottom = box_outputs[..., 3]
+
+ gt_left = box_targets[..., 0]
+ gt_right = box_targets[..., 1]
+ gt_top = box_targets[..., 2]
+ gt_bottom = box_targets[..., 3]
+
+ inter_width = (tf.minimum(pred_left, gt_left) +
+ tf.minimum(pred_right, gt_right))
+ inter_height = (tf.minimum(pred_top, gt_top) +
+ tf.minimum(pred_bottom, gt_bottom))
+ inter_area = inter_width * inter_height
+ union_area = ((pred_left + pred_right) * (pred_top + pred_bottom) +
+ (gt_left + gt_right) * (gt_top + gt_bottom) -
+ inter_area)
+ iou = inter_area / (union_area + 1e-8)
+ mask_ = tf.cast(mask, tf.float32)
+ iou = tf.clip_by_value(iou, clip_value_min=1e-8, clip_value_max=1.0)
+ neg_log_iou = -tf.math.log(iou)
+ iou_loss = tf.reduce_sum(neg_log_iou * mask_)
+ iou_loss /= normalizer
+ return iou_loss
+
+
+class FastrcnnClassLoss(object):
+ """Fast R-CNN classification loss function."""
+
+ def __init__(self):
+ self._categorical_crossentropy = tf_keras.losses.CategoricalCrossentropy(
+ reduction=tf_keras.losses.Reduction.SUM, from_logits=True)
+
+ def __call__(self, class_outputs, class_targets):
+ """Computes the class loss (Fast-RCNN branch) of Mask-RCNN.
+
+ This function implements the classification loss of the Fast-RCNN.
+
+ The classification loss is softmax on all RoIs.
+ Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/fast_rcnn_heads.py # pylint: disable=line-too-long
+
+ Args:
+ class_outputs: a float tensor representing the class prediction for each box
+ with a shape of [batch_size, num_boxes, num_classes].
+ class_targets: a float tensor representing the class label for each box
+ with a shape of [batch_size, num_boxes].
+
+ Returns:
+ a scalar tensor representing total class loss.
+ """
+ with tf.name_scope('fast_rcnn_loss'):
+ batch_size, num_boxes, num_classes = class_outputs.get_shape().as_list()
+ class_targets = tf.cast(class_targets, dtype=tf.int32)
+ class_targets_one_hot = tf.one_hot(class_targets, num_classes)
+ return self._fast_rcnn_class_loss(class_outputs, class_targets_one_hot,
+ normalizer=batch_size * num_boxes / 2.0)
+
+ def _fast_rcnn_class_loss(self, class_outputs, class_targets_one_hot,
+ normalizer):
+ """Computes classification loss."""
+ with tf.name_scope('fast_rcnn_class_loss'):
+ class_loss = self._categorical_crossentropy(class_targets_one_hot,
+ class_outputs)
+
+ class_loss /= normalizer
+ return class_loss
+
+
+class FastrcnnBoxLoss(object):
+ """Fast R-CNN box regression loss function."""
+
+ def __init__(self, params):
+ logging.info('FastrcnnBoxLoss huber_loss_delta %s', params.huber_loss_delta)
+ # The delta is typically around the mean value of regression target.
+ # for instances, the regression targets of 512x512 input with 6 anchors on
+ # P2-P6 pyramid is about [0.1, 0.1, 0.2, 0.2].
+ self._huber_loss = tf_keras.losses.Huber(
+ delta=params.huber_loss_delta, reduction=tf_keras.losses.Reduction.SUM)
+
+ def __call__(self, box_outputs, class_targets, box_targets):
+ """Computes the box loss (Fast-RCNN branch) of Mask-RCNN.
+
+ This function implements the box regression loss of the Fast-RCNN. As the
+ `box_outputs` produces `num_classes` boxes for each RoI, the reference model
+ expands `box_targets` to match the shape of `box_outputs` and selects only
+ the target that the RoI has a maximum overlap. (Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/roi_data/fast_rcnn.py) # pylint: disable=line-too-long
+ Instead, this function selects the `box_outputs` by the `class_targets` so
+ that it doesn't expand `box_targets`.
+
+ The box loss is smooth L1-loss on only positive samples of RoIs.
+ Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/fast_rcnn_heads.py # pylint: disable=line-too-long
+
+ Args:
+ box_outputs: a float tensor representing the box prediction for each box
+ with a shape of [batch_size, num_boxes, num_classes * 4].
+ class_targets: a float tensor representing the class label for each box
+ with a shape of [batch_size, num_boxes].
+ box_targets: a float tensor representing the box label for each box
+ with a shape of [batch_size, num_boxes, 4].
+
+ Returns:
+ box_loss: a scalar tensor representing total box regression loss.
+ """
+ with tf.name_scope('fast_rcnn_loss'):
+ class_targets = tf.cast(class_targets, dtype=tf.int32)
+
+ # Selects the box from `box_outputs` based on `class_targets`, with which
+ # the box has the maximum overlap.
+ (batch_size, num_rois,
+ num_class_specific_boxes) = box_outputs.get_shape().as_list()
+ num_classes = num_class_specific_boxes // 4
+ box_outputs = tf.reshape(box_outputs,
+ [batch_size, num_rois, num_classes, 4])
+
+ box_indices = tf.reshape(
+ class_targets + tf.tile(
+ tf.expand_dims(
+ tf.range(batch_size) * num_rois * num_classes, 1),
+ [1, num_rois]) + tf.tile(
+ tf.expand_dims(tf.range(num_rois) * num_classes, 0),
+ [batch_size, 1]), [-1])
+
+ box_outputs = tf.matmul(
+ tf.one_hot(
+ box_indices,
+ batch_size * num_rois * num_classes,
+ dtype=box_outputs.dtype), tf.reshape(box_outputs, [-1, 4]))
+ box_outputs = tf.reshape(box_outputs, [batch_size, -1, 4])
+
+ return self._fast_rcnn_box_loss(box_outputs, box_targets, class_targets)
+
+ def _fast_rcnn_box_loss(self, box_outputs, box_targets, class_targets,
+ normalizer=1.0):
+ """Computes box regression loss."""
+ with tf.name_scope('fast_rcnn_box_loss'):
+ mask = tf.tile(tf.expand_dims(tf.greater(class_targets, 0), axis=2),
+ [1, 1, 4])
+ mask = tf.cast(mask, dtype=tf.float32)
+ box_targets = tf.expand_dims(box_targets, axis=-1)
+ box_outputs = tf.expand_dims(box_outputs, axis=-1)
+ box_loss = self._huber_loss(box_targets, box_outputs, sample_weight=mask)
+ # The loss is normalized by the number of ones in mask,
+ # additianal normalizer provided by the user and using 0.01 here to avoid
+ # division by 0.
+ box_loss /= normalizer * (tf.reduce_sum(mask) + 0.01)
+ return box_loss
+
+
+class OlnBoxScoreLoss(object):
+ """Object Localization Network Box-Iou scoring function."""
+
+ def __init__(self, params):
+ self._ignore_threshold = params.ignore_threshold
+ self._l1_loss = tf_keras.losses.MeanAbsoluteError(
+ reduction=tf_keras.losses.Reduction.SUM)
+
+ def __call__(self, score_outputs, score_targets):
+ """Computes the class loss (Fast-RCNN branch) of Mask-RCNN.
+
+ This function implements the classification loss of the Fast-RCNN.
+
+ The classification loss is softmax on all RoIs.
+ Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/modeling/fast_rcnn_heads.py # pylint: disable=line-too-long
+
+ Args:
+ score_outputs: a float tensor representing the class prediction for each box
+ with a shape of [batch_size, num_boxes, num_classes].
+ score_targets: a float tensor representing the class label for each box
+ with a shape of [batch_size, num_boxes].
+
+ Returns:
+ a scalar tensor representing total score loss.
+ """
+ with tf.name_scope('fast_rcnn_loss'):
+ score_outputs = tf.squeeze(score_outputs, -1)
+
+ mask = tf.greater(score_targets, self._ignore_threshold)
+ num_valid = tf.reduce_sum(tf.cast(mask, tf.float32))
+ score_targets = tf.maximum(score_targets, tf.zeros_like(score_targets))
+ score_outputs = tf.sigmoid(score_outputs)
+ score_targets = tf.expand_dims(score_targets, -1)
+ score_outputs = tf.expand_dims(score_outputs, -1)
+ mask = tf.cast(mask, dtype=tf.float32)
+ score_loss = self._l1_loss(score_targets, score_outputs,
+ sample_weight=mask)
+ score_loss /= (num_valid + 1e-10)
+ return score_loss
+
+
+class MaskrcnnLoss(object):
+ """Mask R-CNN instance segmentation mask loss function."""
+
+ def __init__(self):
+ self._binary_crossentropy = tf_keras.losses.BinaryCrossentropy(
+ reduction=tf_keras.losses.Reduction.SUM, from_logits=True)
+
+ def __call__(self, mask_outputs, mask_targets, select_class_targets):
+ """Computes the mask loss of Mask-RCNN.
+
+ This function implements the mask loss of Mask-RCNN. As the `mask_outputs`
+ produces `num_classes` masks for each RoI, the reference model expands
+ `mask_targets` to match the shape of `mask_outputs` and selects only the
+ target that the RoI has a maximum overlap. (Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/roi_data/mask_rcnn.py) # pylint: disable=line-too-long
+ Instead, this implementation selects the `mask_outputs` by the `class_targets`
+ so that it doesn't expand `mask_targets`. Note that the selection logic is
+ done in the post-processing of mask_rcnn_fn in mask_rcnn_architecture.py.
+
+ Args:
+ mask_outputs: a float tensor representing the prediction for each mask,
+ with a shape of
+ [batch_size, num_masks, mask_height, mask_width].
+ mask_targets: a float tensor representing the binary mask of ground truth
+ labels for each mask with a shape of
+ [batch_size, num_masks, mask_height, mask_width].
+ select_class_targets: a tensor with a shape of [batch_size, num_masks],
+ representing the foreground mask targets.
+
+ Returns:
+ mask_loss: a float tensor representing total mask loss.
+ """
+ with tf.name_scope('mask_rcnn_loss'):
+ (batch_size, num_masks, mask_height,
+ mask_width) = mask_outputs.get_shape().as_list()
+
+ weights = tf.tile(
+ tf.reshape(tf.greater(select_class_targets, 0),
+ [batch_size, num_masks, 1, 1]),
+ [1, 1, mask_height, mask_width])
+ weights = tf.cast(weights, dtype=tf.float32)
+
+ mask_targets = tf.expand_dims(mask_targets, axis=-1)
+ mask_outputs = tf.expand_dims(mask_outputs, axis=-1)
+ mask_loss = self._binary_crossentropy(mask_targets, mask_outputs,
+ sample_weight=weights)
+
+ # The loss is normalized by the number of 1's in weights and
+ # + 0.01 is used to avoid division by zero.
+ return mask_loss / (tf.reduce_sum(weights) + 0.01)
+
+
+class RetinanetClassLoss(object):
+ """RetinaNet class loss."""
+
+ def __init__(self, params, num_classes):
+ self._num_classes = num_classes
+ self._focal_loss_alpha = params.focal_loss_alpha
+ self._focal_loss_gamma = params.focal_loss_gamma
+
+ def __call__(self, cls_outputs, labels, num_positives):
+ """Computes total detection loss.
+
+ Computes total detection loss including box and class loss from all levels.
+
+ Args:
+ cls_outputs: an OrderDict with keys representing levels and values
+ representing logits in [batch_size, height, width,
+ num_anchors * num_classes].
+ labels: the dictionary that returned from dataloader that includes
+ class groundturth targets.
+ num_positives: number of positive examples in the minibatch.
+
+ Returns:
+ an integar tensor representing total class loss.
+ """
+ # Sums all positives in a batch for normalization and avoids zero
+ # num_positives_sum, which would lead to inf loss during training
+ num_positives_sum = tf.reduce_sum(input_tensor=num_positives) + 1.0
+
+ cls_losses = []
+ for level in cls_outputs.keys():
+ cls_losses.append(self.class_loss(
+ cls_outputs[level], labels[level], num_positives_sum))
+ # Sums per level losses to total loss.
+ return tf.add_n(cls_losses)
+
+ def class_loss(self, cls_outputs, cls_targets, num_positives,
+ ignore_label=-2):
+ """Computes RetinaNet classification loss."""
+ # Onehot encoding for classification labels.
+ cls_targets_one_hot = tf.one_hot(cls_targets, self._num_classes)
+ bs, height, width, _, _ = cls_targets_one_hot.get_shape().as_list()
+ cls_targets_one_hot = tf.reshape(cls_targets_one_hot,
+ [bs, height, width, -1])
+ loss = focal_loss(tf.cast(cls_outputs, dtype=tf.float32),
+ tf.cast(cls_targets_one_hot, dtype=tf.float32),
+ self._focal_loss_alpha,
+ self._focal_loss_gamma,
+ num_positives)
+
+ ignore_loss = tf.where(
+ tf.equal(cls_targets, ignore_label),
+ tf.zeros_like(cls_targets, dtype=tf.float32),
+ tf.ones_like(cls_targets, dtype=tf.float32),
+ )
+ ignore_loss = tf.expand_dims(ignore_loss, -1)
+ ignore_loss = tf.tile(ignore_loss, [1, 1, 1, 1, self._num_classes])
+ ignore_loss = tf.reshape(ignore_loss, tf.shape(input=loss))
+ return tf.reduce_sum(input_tensor=ignore_loss * loss)
+
+
+class RetinanetBoxLoss(object):
+ """RetinaNet box loss."""
+
+ def __init__(self, params):
+ self._huber_loss = tf_keras.losses.Huber(
+ delta=params.huber_loss_delta, reduction=tf_keras.losses.Reduction.SUM)
+
+ def __call__(self, box_outputs, labels, num_positives):
+ """Computes box detection loss.
+
+ Computes total detection loss including box and class loss from all levels.
+
+ Args:
+ box_outputs: an OrderDict with keys representing levels and values
+ representing box regression targets in [batch_size, height, width,
+ num_anchors * 4].
+ labels: the dictionary that returned from dataloader that includes
+ box groundturth targets.
+ num_positives: number of positive examples in the minibatch.
+
+ Returns:
+ an integer tensor representing total box regression loss.
+ """
+ # Sums all positives in a batch for normalization and avoids zero
+ # num_positives_sum, which would lead to inf loss during training
+ num_positives_sum = tf.reduce_sum(input_tensor=num_positives) + 1.0
+
+ box_losses = []
+ for level in box_outputs.keys():
+ box_targets_l = labels[level]
+ box_losses.append(
+ self.box_loss(box_outputs[level], box_targets_l, num_positives_sum))
+ # Sums per level losses to total loss.
+ return tf.add_n(box_losses)
+
+ def box_loss(self, box_outputs, box_targets, num_positives):
+ """Computes RetinaNet box regression loss."""
+ # The delta is typically around the mean value of regression target.
+ # for instances, the regression targets of 512x512 input with 6 anchors on
+ # P3-P7 pyramid is about [0.1, 0.1, 0.2, 0.2].
+ normalizer = num_positives * 4.0
+ mask = tf.cast(tf.not_equal(box_targets, 0.0), dtype=tf.float32)
+ box_targets = tf.expand_dims(box_targets, axis=-1)
+ box_outputs = tf.expand_dims(box_outputs, axis=-1)
+ box_loss = self._huber_loss(box_targets, box_outputs, sample_weight=mask)
+ box_loss /= normalizer
+ return box_loss
+
+
+class ShapemaskMseLoss(object):
+ """ShapeMask mask Mean Squared Error loss function wrapper."""
+
+ def __call__(self, probs, labels, valid_mask):
+ """Compute instance segmentation loss.
+
+ Args:
+ probs: A Tensor of shape [batch_size * num_points, height, width,
+ num_classes]. The logits are not necessarily between 0 and 1.
+ labels: A float32/float16 Tensor of shape [batch_size, num_instances,
+ mask_size, mask_size], where mask_size =
+ mask_crop_size * gt_upsample_scale for fine mask, or mask_crop_size
+ for coarse masks and shape priors.
+ valid_mask: a binary mask indicating valid training masks.
+
+ Returns:
+ loss: an float tensor representing total mask classification loss.
+ """
+ with tf.name_scope('shapemask_prior_loss'):
+ batch_size, num_instances = valid_mask.get_shape().as_list()[:2]
+ diff = (tf.cast(labels, dtype=tf.float32) -
+ tf.cast(probs, dtype=tf.float32))
+ diff *= tf.cast(
+ tf.reshape(valid_mask, [batch_size, num_instances, 1, 1]),
+ tf.float32)
+ # Adding 0.001 in the denominator to avoid division by zero.
+ loss = tf.nn.l2_loss(diff) / (tf.reduce_sum(labels) + 0.001)
+ return loss
+
+
+class ShapemaskLoss(object):
+ """ShapeMask mask loss function wrapper."""
+
+ def __init__(self):
+ self._binary_crossentropy = tf_keras.losses.BinaryCrossentropy(
+ reduction=tf_keras.losses.Reduction.SUM, from_logits=True)
+
+ def __call__(self, logits, labels, valid_mask):
+ """ShapeMask mask cross entropy loss function wrapper.
+
+ Args:
+ logits: A Tensor of shape [batch_size * num_instances, height, width,
+ num_classes]. The logits are not necessarily between 0 and 1.
+ labels: A float16/float32 Tensor of shape [batch_size, num_instances,
+ mask_size, mask_size], where mask_size =
+ mask_crop_size * gt_upsample_scale for fine mask, or mask_crop_size
+ for coarse masks and shape priors.
+ valid_mask: a binary mask of shape [batch_size, num_instances]
+ indicating valid training masks.
+ Returns:
+ loss: an float tensor representing total mask classification loss.
+ """
+ with tf.name_scope('shapemask_loss'):
+ batch_size, num_instances = valid_mask.get_shape().as_list()[:2]
+ labels = tf.cast(labels, tf.float32)
+ logits = tf.cast(logits, tf.float32)
+ loss = self._binary_crossentropy(labels, logits)
+ loss *= tf.cast(tf.reshape(
+ valid_mask, [batch_size, num_instances, 1, 1]), loss.dtype)
+ # Adding 0.001 in the denominator to avoid division by zero.
+ loss = tf.reduce_sum(loss) / (tf.reduce_sum(labels) + 0.001)
+ return loss
diff --git a/modeling/official/legacy/detection/modeling/maskrcnn_model.py b/modeling/official/legacy/detection/modeling/maskrcnn_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..79ae56967735087bcc0b49db95dc8438beb917ca
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/maskrcnn_model.py
@@ -0,0 +1,338 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Model defination for the Mask R-CNN Model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.dataloader import anchor
+from official.legacy.detection.dataloader import mode_keys
+from official.legacy.detection.evaluation import factory as eval_factory
+from official.legacy.detection.modeling import base_model
+from official.legacy.detection.modeling import losses
+from official.legacy.detection.modeling.architecture import factory
+from official.legacy.detection.ops import postprocess_ops
+from official.legacy.detection.ops import roi_ops
+from official.legacy.detection.ops import spatial_transform_ops
+from official.legacy.detection.ops import target_ops
+from official.legacy.detection.utils import box_utils
+
+
+class MaskrcnnModel(base_model.Model):
+ """Mask R-CNN model function."""
+
+ def __init__(self, params):
+ super(MaskrcnnModel, self).__init__(params)
+
+ # For eval metrics.
+ self._params = params
+ self._keras_model = None
+
+ self._include_mask = params.architecture.include_mask
+
+ # Architecture generators.
+ self._backbone_fn = factory.backbone_generator(params)
+ self._fpn_fn = factory.multilevel_features_generator(params)
+ self._rpn_head_fn = factory.rpn_head_generator(params)
+ self._generate_rois_fn = roi_ops.ROIGenerator(params.roi_proposal)
+ self._sample_rois_fn = target_ops.ROISampler(params.roi_sampling)
+ self._sample_masks_fn = target_ops.MaskSampler(
+ params.architecture.mask_target_size,
+ params.mask_sampling.num_mask_samples_per_image)
+
+ self._frcnn_head_fn = factory.fast_rcnn_head_generator(params)
+ if self._include_mask:
+ self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params)
+
+ # Loss function.
+ self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
+ self._rpn_box_loss_fn = losses.RpnBoxLoss(params.rpn_box_loss)
+ self._frcnn_class_loss_fn = losses.FastrcnnClassLoss()
+ self._frcnn_box_loss_fn = losses.FastrcnnBoxLoss(params.frcnn_box_loss)
+ if self._include_mask:
+ self._mask_loss_fn = losses.MaskrcnnLoss()
+
+ self._generate_detections_fn = postprocess_ops.GenericDetectionGenerator(
+ params.postprocess)
+
+ self._transpose_input = params.train.transpose_input
+ assert not self._transpose_input, 'Transpose input is not supportted.'
+
+ def build_outputs(self, inputs, mode):
+ is_training = mode == mode_keys.TRAIN
+ model_outputs = {}
+
+ image = inputs['image']
+ _, image_height, image_width, _ = image.get_shape().as_list()
+ backbone_features = self._backbone_fn(image, is_training)
+ fpn_features = self._fpn_fn(backbone_features, is_training)
+
+ rpn_score_outputs, rpn_box_outputs = self._rpn_head_fn(
+ fpn_features, is_training)
+ model_outputs.update({
+ 'rpn_score_outputs':
+ tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
+ rpn_score_outputs),
+ 'rpn_box_outputs':
+ tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
+ rpn_box_outputs),
+ })
+ input_anchor = anchor.Anchor(self._params.architecture.min_level,
+ self._params.architecture.max_level,
+ self._params.anchor.num_scales,
+ self._params.anchor.aspect_ratios,
+ self._params.anchor.anchor_size,
+ (image_height, image_width))
+ rpn_rois, _ = self._generate_rois_fn(rpn_box_outputs, rpn_score_outputs,
+ input_anchor.multilevel_boxes,
+ inputs['image_info'][:, 1, :],
+ is_training)
+ if is_training:
+ rpn_rois = tf.stop_gradient(rpn_rois)
+
+ # Sample proposals.
+ rpn_rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices = (
+ self._sample_rois_fn(rpn_rois, inputs['gt_boxes'],
+ inputs['gt_classes']))
+
+ # Create bounding box training targets.
+ box_targets = box_utils.encode_boxes(
+ matched_gt_boxes, rpn_rois, weights=[10.0, 10.0, 5.0, 5.0])
+ # If the target is background, the box target is set to all 0s.
+ box_targets = tf.where(
+ tf.tile(
+ tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
+ [1, 1, 4]), tf.zeros_like(box_targets), box_targets)
+ model_outputs.update({
+ 'class_targets': matched_gt_classes,
+ 'box_targets': box_targets,
+ })
+
+ roi_features = spatial_transform_ops.multilevel_crop_and_resize(
+ fpn_features, rpn_rois, output_size=7)
+
+ class_outputs, box_outputs = self._frcnn_head_fn(roi_features, is_training)
+
+ model_outputs.update({
+ 'class_outputs':
+ tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
+ class_outputs),
+ 'box_outputs':
+ tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
+ box_outputs),
+ })
+
+ # Add this output to train to make the checkpoint loadable in predict mode.
+ # If we skip it in train mode, the heads will be out-of-order and checkpoint
+ # loading will fail.
+ boxes, scores, classes, valid_detections = self._generate_detections_fn(
+ box_outputs, class_outputs, rpn_rois, inputs['image_info'][:, 1:2, :])
+ model_outputs.update({
+ 'num_detections': valid_detections,
+ 'detection_boxes': boxes,
+ 'detection_classes': classes,
+ 'detection_scores': scores,
+ })
+
+ if not self._include_mask:
+ return model_outputs
+
+ if is_training:
+ rpn_rois, classes, mask_targets = self._sample_masks_fn(
+ rpn_rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices,
+ inputs['gt_masks'])
+ mask_targets = tf.stop_gradient(mask_targets)
+
+ classes = tf.cast(classes, dtype=tf.int32)
+
+ model_outputs.update({
+ 'mask_targets': mask_targets,
+ 'sampled_class_targets': classes,
+ })
+ else:
+ rpn_rois = boxes
+ classes = tf.cast(classes, dtype=tf.int32)
+
+ mask_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
+ fpn_features, rpn_rois, output_size=14)
+
+ mask_outputs = self._mrcnn_head_fn(mask_roi_features, classes, is_training)
+
+ if is_training:
+ model_outputs.update({
+ 'mask_outputs':
+ tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
+ mask_outputs),
+ })
+ else:
+ model_outputs.update({'detection_masks': tf.nn.sigmoid(mask_outputs)})
+
+ return model_outputs
+
+ def build_loss_fn(self):
+ if self._keras_model is None:
+ raise ValueError('build_loss_fn() must be called after build_model().')
+
+ filter_fn = self.make_filter_trainable_variables_fn()
+ trainable_variables = filter_fn(self._keras_model.trainable_variables)
+
+ def _total_loss_fn(labels, outputs):
+ rpn_score_loss = self._rpn_score_loss_fn(outputs['rpn_score_outputs'],
+ labels['rpn_score_targets'])
+ rpn_box_loss = self._rpn_box_loss_fn(outputs['rpn_box_outputs'],
+ labels['rpn_box_targets'])
+
+ frcnn_class_loss = self._frcnn_class_loss_fn(outputs['class_outputs'],
+ outputs['class_targets'])
+ frcnn_box_loss = self._frcnn_box_loss_fn(outputs['box_outputs'],
+ outputs['class_targets'],
+ outputs['box_targets'])
+
+ if self._include_mask:
+ mask_loss = self._mask_loss_fn(outputs['mask_outputs'],
+ outputs['mask_targets'],
+ outputs['sampled_class_targets'])
+ else:
+ mask_loss = 0.0
+
+ model_loss = (
+ rpn_score_loss + rpn_box_loss + frcnn_class_loss + frcnn_box_loss +
+ mask_loss)
+
+ l2_regularization_loss = self.weight_decay_loss(trainable_variables)
+ total_loss = model_loss + l2_regularization_loss
+ return {
+ 'total_loss': total_loss,
+ 'loss': total_loss,
+ 'fast_rcnn_class_loss': frcnn_class_loss,
+ 'fast_rcnn_box_loss': frcnn_box_loss,
+ 'mask_loss': mask_loss,
+ 'model_loss': model_loss,
+ 'l2_regularization_loss': l2_regularization_loss,
+ 'rpn_score_loss': rpn_score_loss,
+ 'rpn_box_loss': rpn_box_loss,
+ }
+
+ return _total_loss_fn
+
+ def build_input_layers(self, params, mode):
+ is_training = mode == mode_keys.TRAIN
+ input_shape = (
+ params.maskrcnn_parser.output_size +
+ [params.maskrcnn_parser.num_channels])
+ if is_training:
+ batch_size = params.train.batch_size
+ input_layer = {
+ 'image':
+ tf_keras.layers.Input(
+ shape=input_shape,
+ batch_size=batch_size,
+ name='image',
+ dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
+ 'image_info':
+ tf_keras.layers.Input(
+ shape=[4, 2],
+ batch_size=batch_size,
+ name='image_info',
+ ),
+ 'gt_boxes':
+ tf_keras.layers.Input(
+ shape=[params.maskrcnn_parser.max_num_instances, 4],
+ batch_size=batch_size,
+ name='gt_boxes'),
+ 'gt_classes':
+ tf_keras.layers.Input(
+ shape=[params.maskrcnn_parser.max_num_instances],
+ batch_size=batch_size,
+ name='gt_classes',
+ dtype=tf.int64),
+ }
+ if self._include_mask:
+ input_layer['gt_masks'] = tf_keras.layers.Input(
+ shape=[
+ params.maskrcnn_parser.max_num_instances,
+ params.maskrcnn_parser.mask_crop_size,
+ params.maskrcnn_parser.mask_crop_size
+ ],
+ batch_size=batch_size,
+ name='gt_masks')
+ else:
+ batch_size = params.eval.batch_size
+ input_layer = {
+ 'image':
+ tf_keras.layers.Input(
+ shape=input_shape,
+ batch_size=batch_size,
+ name='image',
+ dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
+ 'image_info':
+ tf_keras.layers.Input(
+ shape=[4, 2],
+ batch_size=batch_size,
+ name='image_info',
+ ),
+ }
+ return input_layer
+
+ def build_model(self, params, mode):
+ if self._keras_model is None:
+ input_layers = self.build_input_layers(self._params, mode)
+ outputs = self.model_outputs(input_layers, mode)
+
+ model = tf_keras.models.Model(
+ inputs=input_layers, outputs=outputs, name='maskrcnn')
+ assert model is not None, 'Fail to build tf_keras.Model.'
+ model.optimizer = self.build_optimizer()
+ self._keras_model = model
+
+ return self._keras_model
+
+ def post_processing(self, labels, outputs):
+ required_output_fields = ['class_outputs', 'box_outputs']
+ for field in required_output_fields:
+ if field not in outputs:
+ raise ValueError('"%s" is missing in outputs, requried %s found %s' %
+ (field, required_output_fields, outputs.keys()))
+ predictions = {
+ 'image_info': labels['image_info'],
+ 'num_detections': outputs['num_detections'],
+ 'detection_boxes': outputs['detection_boxes'],
+ 'detection_classes': outputs['detection_classes'],
+ 'detection_scores': outputs['detection_scores'],
+ }
+ if self._include_mask:
+ predictions.update({
+ 'detection_masks': outputs['detection_masks'],
+ })
+
+ if 'groundtruths' in labels:
+ predictions['source_id'] = labels['groundtruths']['source_id']
+ predictions['gt_source_id'] = labels['groundtruths']['source_id']
+ predictions['gt_height'] = labels['groundtruths']['height']
+ predictions['gt_width'] = labels['groundtruths']['width']
+ predictions['gt_image_info'] = labels['image_info']
+ predictions['gt_num_detections'] = (
+ labels['groundtruths']['num_detections'])
+ predictions['gt_boxes'] = labels['groundtruths']['boxes']
+ predictions['gt_classes'] = labels['groundtruths']['classes']
+ predictions['gt_areas'] = labels['groundtruths']['areas']
+ predictions['gt_is_crowds'] = labels['groundtruths']['is_crowds']
+ return labels, predictions
+
+ def eval_metrics(self):
+ return eval_factory.evaluator_generator(self._params.eval)
diff --git a/modeling/official/legacy/detection/modeling/olnmask_model.py b/modeling/official/legacy/detection/modeling/olnmask_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c362289f65bb2177f62dded5d712247e3827307
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/olnmask_model.py
@@ -0,0 +1,432 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Model defination for the Object Localization Network (OLN) Model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.dataloader import anchor
+from official.legacy.detection.dataloader import mode_keys
+from official.legacy.detection.modeling import losses
+from official.legacy.detection.modeling.architecture import factory
+from official.legacy.detection.modeling.maskrcnn_model import MaskrcnnModel
+from official.legacy.detection.ops import postprocess_ops
+from official.legacy.detection.ops import roi_ops
+from official.legacy.detection.ops import spatial_transform_ops
+from official.legacy.detection.ops import target_ops
+from official.legacy.detection.utils import box_utils
+
+
+class OlnMaskModel(MaskrcnnModel):
+ """OLN-Mask model function."""
+
+ def __init__(self, params):
+ super(OlnMaskModel, self).__init__(params)
+
+ self._params = params
+
+ # Different heads and layers.
+ self._include_rpn_class = params.architecture.include_rpn_class
+ self._include_mask = params.architecture.include_mask
+ self._include_frcnn_class = params.architecture.include_frcnn_class
+ self._include_frcnn_box = params.architecture.include_frcnn_box
+ self._include_centerness = params.rpn_head.has_centerness
+ self._include_box_score = (params.frcnn_head.has_scoring and
+ params.architecture.include_frcnn_box)
+ self._include_mask_score = (params.mrcnn_head.has_scoring and
+ params.architecture.include_mask)
+
+ # Architecture generators.
+ self._backbone_fn = factory.backbone_generator(params)
+ self._fpn_fn = factory.multilevel_features_generator(params)
+ self._rpn_head_fn = factory.rpn_head_generator(params)
+ if self._include_centerness:
+ self._rpn_head_fn = factory.oln_rpn_head_generator(params)
+ else:
+ self._rpn_head_fn = factory.rpn_head_generator(params)
+ self._generate_rois_fn = roi_ops.OlnROIGenerator(params.roi_proposal)
+ self._sample_rois_fn = target_ops.ROIScoreSampler(params.roi_sampling)
+ self._sample_masks_fn = target_ops.MaskSampler(
+ params.architecture.mask_target_size,
+ params.mask_sampling.num_mask_samples_per_image)
+
+ if self._include_box_score:
+ self._frcnn_head_fn = factory.oln_box_score_head_generator(params)
+ else:
+ self._frcnn_head_fn = factory.fast_rcnn_head_generator(params)
+
+ if self._include_mask:
+ if self._include_mask_score:
+ self._mrcnn_head_fn = factory.oln_mask_score_head_generator(params)
+ else:
+ self._mrcnn_head_fn = factory.mask_rcnn_head_generator(params)
+
+ # Loss function.
+ self._rpn_score_loss_fn = losses.RpnScoreLoss(params.rpn_score_loss)
+ self._rpn_box_loss_fn = losses.RpnBoxLoss(params.rpn_box_loss)
+ if self._include_centerness:
+ self._rpn_iou_loss_fn = losses.OlnRpnIoULoss()
+ self._rpn_center_loss_fn = losses.OlnRpnCenterLoss()
+ self._frcnn_class_loss_fn = losses.FastrcnnClassLoss()
+ self._frcnn_box_loss_fn = losses.FastrcnnBoxLoss(params.frcnn_box_loss)
+ if self._include_box_score:
+ self._frcnn_box_score_loss_fn = losses.OlnBoxScoreLoss(
+ params.frcnn_box_score_loss)
+ if self._include_mask:
+ self._mask_loss_fn = losses.MaskrcnnLoss()
+
+ self._generate_detections_fn = postprocess_ops.OlnDetectionGenerator(
+ params.postprocess)
+
+ self._transpose_input = params.train.transpose_input
+ assert not self._transpose_input, 'Transpose input is not supportted.'
+
+ def build_outputs(self, inputs, mode):
+ is_training = mode == mode_keys.TRAIN
+ model_outputs = {}
+
+ image = inputs['image']
+ _, image_height, image_width, _ = image.get_shape().as_list()
+ backbone_features = self._backbone_fn(image, is_training)
+ fpn_features = self._fpn_fn(backbone_features, is_training)
+
+ # rpn_centerness.
+ if self._include_centerness:
+ rpn_score_outputs, rpn_box_outputs, rpn_center_outputs = (
+ self._rpn_head_fn(fpn_features, is_training))
+ model_outputs.update({
+ 'rpn_center_outputs':
+ tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
+ rpn_center_outputs),
+ })
+ object_scores = rpn_center_outputs
+ else:
+ rpn_score_outputs, rpn_box_outputs = self._rpn_head_fn(
+ fpn_features, is_training)
+ object_scores = None
+ model_outputs.update({
+ 'rpn_score_outputs':
+ tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
+ rpn_score_outputs),
+ 'rpn_box_outputs':
+ tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
+ rpn_box_outputs),
+ })
+ input_anchor = anchor.Anchor(self._params.architecture.min_level,
+ self._params.architecture.max_level,
+ self._params.anchor.num_scales,
+ self._params.anchor.aspect_ratios,
+ self._params.anchor.anchor_size,
+ (image_height, image_width))
+ rpn_rois, rpn_roi_scores = self._generate_rois_fn(
+ rpn_box_outputs,
+ rpn_score_outputs,
+ input_anchor.multilevel_boxes,
+ inputs['image_info'][:, 1, :],
+ is_training,
+ is_box_lrtb=self._include_centerness,
+ object_scores=object_scores,
+ )
+ if (not self._include_frcnn_class and
+ not self._include_frcnn_box and
+ not self._include_mask):
+ # if not is_training:
+ # For direct RPN detection,
+ # use dummy box_outputs = (dy,dx,dh,dw = 0,0,0,0)
+ box_outputs = tf.zeros_like(rpn_rois)
+ box_outputs = tf.concat([box_outputs, box_outputs], -1)
+ boxes, scores, classes, valid_detections = self._generate_detections_fn(
+ box_outputs, rpn_roi_scores, rpn_rois,
+ inputs['image_info'][:, 1:2, :],
+ is_single_fg_score=True, # if no_background, no softmax is applied.
+ keep_nms=True)
+ model_outputs.update({
+ 'num_detections': valid_detections,
+ 'detection_boxes': boxes,
+ 'detection_classes': classes,
+ 'detection_scores': scores,
+ })
+ return model_outputs
+
+ # ---- OLN-Proposal finishes here. ----
+
+ if is_training:
+ rpn_rois = tf.stop_gradient(rpn_rois)
+ rpn_roi_scores = tf.stop_gradient(rpn_roi_scores)
+
+ # Sample proposals.
+ (rpn_rois, rpn_roi_scores, matched_gt_boxes, matched_gt_classes,
+ matched_gt_indices) = (
+ self._sample_rois_fn(rpn_rois, rpn_roi_scores, inputs['gt_boxes'],
+ inputs['gt_classes']))
+ # Create bounding box training targets.
+ box_targets = box_utils.encode_boxes(
+ matched_gt_boxes, rpn_rois, weights=[10.0, 10.0, 5.0, 5.0])
+ # If the target is background, the box target is set to all 0s.
+ box_targets = tf.where(
+ tf.tile(
+ tf.expand_dims(tf.equal(matched_gt_classes, 0), axis=-1),
+ [1, 1, 4]), tf.zeros_like(box_targets), box_targets)
+ model_outputs.update({
+ 'class_targets': matched_gt_classes,
+ 'box_targets': box_targets,
+ })
+ # Create Box-IoU targets. {
+ box_ious = box_utils.bbox_overlap(
+ rpn_rois, inputs['gt_boxes'])
+ matched_box_ious = tf.reduce_max(box_ious, 2)
+ model_outputs.update({
+ 'box_iou_targets': matched_box_ious,}) # }
+
+ roi_features = spatial_transform_ops.multilevel_crop_and_resize(
+ fpn_features, rpn_rois, output_size=7)
+
+ if not self._include_box_score:
+ class_outputs, box_outputs = self._frcnn_head_fn(
+ roi_features, is_training)
+ else:
+ class_outputs, box_outputs, score_outputs = self._frcnn_head_fn(
+ roi_features, is_training)
+ model_outputs.update({
+ 'box_score_outputs':
+ tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
+ score_outputs),})
+ model_outputs.update({
+ 'class_outputs':
+ tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
+ class_outputs),
+ 'box_outputs':
+ tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
+ box_outputs),
+ })
+
+ # Add this output to train to make the checkpoint loadable in predict mode.
+ # If we skip it in train mode, the heads will be out-of-order and checkpoint
+ # loading will fail.
+ if not self._include_frcnn_box:
+ box_outputs = tf.zeros_like(box_outputs) # dummy zeros.
+
+ if self._include_box_score:
+ score_outputs = tf.cast(tf.squeeze(score_outputs, -1),
+ rpn_roi_scores.dtype)
+
+ # box-score = (rpn-centerness * box-iou)^(1/2)
+ # TR: rpn_roi_scores: b,1000, score_outputs: b,512
+ # TS: rpn_roi_scores: b,1000, score_outputs: b,1000
+ box_scores = tf.pow(
+ rpn_roi_scores * tf.sigmoid(score_outputs), 1/2.)
+
+ if not self._include_frcnn_class:
+ boxes, scores, classes, valid_detections = self._generate_detections_fn(
+ box_outputs,
+ box_scores,
+ rpn_rois,
+ inputs['image_info'][:, 1:2, :],
+ is_single_fg_score=True,
+ keep_nms=True,)
+ else:
+ boxes, scores, classes, valid_detections = self._generate_detections_fn(
+ box_outputs, class_outputs, rpn_rois,
+ inputs['image_info'][:, 1:2, :],
+ keep_nms=True,)
+ model_outputs.update({
+ 'num_detections': valid_detections,
+ 'detection_boxes': boxes,
+ 'detection_classes': classes,
+ 'detection_scores': scores,
+ })
+
+ # ---- OLN-Box finishes here. ----
+
+ if not self._include_mask:
+ return model_outputs
+
+ if is_training:
+ rpn_rois, classes, mask_targets = self._sample_masks_fn(
+ rpn_rois, matched_gt_boxes, matched_gt_classes, matched_gt_indices,
+ inputs['gt_masks'])
+ mask_targets = tf.stop_gradient(mask_targets)
+
+ classes = tf.cast(classes, dtype=tf.int32)
+
+ model_outputs.update({
+ 'mask_targets': mask_targets,
+ 'sampled_class_targets': classes,
+ })
+ else:
+ rpn_rois = boxes
+ classes = tf.cast(classes, dtype=tf.int32)
+
+ mask_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
+ fpn_features, rpn_rois, output_size=14)
+
+ mask_outputs = self._mrcnn_head_fn(mask_roi_features, classes, is_training)
+
+ if is_training:
+ model_outputs.update({
+ 'mask_outputs':
+ tf.nest.map_structure(lambda x: tf.cast(x, tf.float32),
+ mask_outputs),
+ })
+ else:
+ model_outputs.update({'detection_masks': tf.nn.sigmoid(mask_outputs)})
+
+ return model_outputs
+
+ def build_loss_fn(self):
+ if self._keras_model is None:
+ raise ValueError('build_loss_fn() must be called after build_model().')
+
+ filter_fn = self.make_filter_trainable_variables_fn()
+ trainable_variables = filter_fn(self._keras_model.trainable_variables)
+
+ def _total_loss_fn(labels, outputs):
+ if self._include_rpn_class:
+ rpn_score_loss = self._rpn_score_loss_fn(outputs['rpn_score_outputs'],
+ labels['rpn_score_targets'])
+ else:
+ rpn_score_loss = 0.0
+ if self._include_centerness:
+ rpn_center_loss = self._rpn_center_loss_fn(
+ outputs['rpn_center_outputs'], labels['rpn_center_targets'])
+ rpn_box_loss = self._rpn_iou_loss_fn(
+ outputs['rpn_box_outputs'], labels['rpn_box_targets'],
+ labels['rpn_center_targets'])
+ else:
+ rpn_center_loss = 0.0
+ rpn_box_loss = self._rpn_box_loss_fn(
+ outputs['rpn_box_outputs'], labels['rpn_box_targets'])
+
+ if self._include_frcnn_class:
+ frcnn_class_loss = self._frcnn_class_loss_fn(
+ outputs['class_outputs'], outputs['class_targets'])
+ else:
+ frcnn_class_loss = 0.0
+ if self._include_frcnn_box:
+ frcnn_box_loss = self._frcnn_box_loss_fn(
+ outputs['box_outputs'], outputs['class_targets'],
+ outputs['box_targets'])
+ else:
+ frcnn_box_loss = 0.0
+ if self._include_box_score:
+ box_score_loss = self._frcnn_box_score_loss_fn(
+ outputs['box_score_outputs'], outputs['box_iou_targets'])
+ else:
+ box_score_loss = 0.0
+
+ if self._include_mask:
+ mask_loss = self._mask_loss_fn(outputs['mask_outputs'],
+ outputs['mask_targets'],
+ outputs['sampled_class_targets'])
+ else:
+ mask_loss = 0.0
+
+ model_loss = (
+ rpn_score_loss + rpn_box_loss + rpn_center_loss +
+ frcnn_class_loss + frcnn_box_loss + box_score_loss +
+ mask_loss)
+
+ l2_regularization_loss = self.weight_decay_loss(trainable_variables)
+ total_loss = model_loss + l2_regularization_loss
+ return {
+ 'total_loss': total_loss,
+ 'loss': total_loss,
+ 'fast_rcnn_class_loss': frcnn_class_loss,
+ 'fast_rcnn_box_loss': frcnn_box_loss,
+ 'fast_rcnn_box_score_loss': box_score_loss,
+ 'mask_loss': mask_loss,
+ 'model_loss': model_loss,
+ 'l2_regularization_loss': l2_regularization_loss,
+ 'rpn_score_loss': rpn_score_loss,
+ 'rpn_box_loss': rpn_box_loss,
+ 'rpn_center_loss': rpn_center_loss,
+ }
+
+ return _total_loss_fn
+
+ def build_input_layers(self, params, mode):
+ is_training = mode == mode_keys.TRAIN
+ input_shape = (
+ params.olnmask_parser.output_size +
+ [params.olnmask_parser.num_channels])
+ if is_training:
+ batch_size = params.train.batch_size
+ input_layer = {
+ 'image':
+ tf_keras.layers.Input(
+ shape=input_shape,
+ batch_size=batch_size,
+ name='image',
+ dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
+ 'image_info':
+ tf_keras.layers.Input(
+ shape=[4, 2],
+ batch_size=batch_size,
+ name='image_info',
+ ),
+ 'gt_boxes':
+ tf_keras.layers.Input(
+ shape=[params.olnmask_parser.max_num_instances, 4],
+ batch_size=batch_size,
+ name='gt_boxes'),
+ 'gt_classes':
+ tf_keras.layers.Input(
+ shape=[params.olnmask_parser.max_num_instances],
+ batch_size=batch_size,
+ name='gt_classes',
+ dtype=tf.int64),
+ }
+ if self._include_mask:
+ input_layer['gt_masks'] = tf_keras.layers.Input(
+ shape=[
+ params.olnmask_parser.max_num_instances,
+ params.olnmask_parser.mask_crop_size,
+ params.olnmask_parser.mask_crop_size
+ ],
+ batch_size=batch_size,
+ name='gt_masks')
+ else:
+ batch_size = params.eval.batch_size
+ input_layer = {
+ 'image':
+ tf_keras.layers.Input(
+ shape=input_shape,
+ batch_size=batch_size,
+ name='image',
+ dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
+ 'image_info':
+ tf_keras.layers.Input(
+ shape=[4, 2],
+ batch_size=batch_size,
+ name='image_info',
+ ),
+ }
+ return input_layer
+
+ def build_model(self, params, mode):
+ if self._keras_model is None:
+ input_layers = self.build_input_layers(self._params, mode)
+ outputs = self.model_outputs(input_layers, mode)
+
+ model = tf_keras.models.Model(
+ inputs=input_layers, outputs=outputs, name='olnmask')
+ assert model is not None, 'Fail to build tf_keras.Model.'
+ model.optimizer = self.build_optimizer()
+ self._keras_model = model
+
+ return self._keras_model
diff --git a/modeling/official/legacy/detection/modeling/optimizers.py b/modeling/official/legacy/detection/modeling/optimizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad3258bd3fbb61ad97c7254eed5d93eac8e1b401
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/optimizers.py
@@ -0,0 +1,49 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Optimizers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import tensorflow as tf, tf_keras
+
+
+class OptimizerFactory(object):
+ """Class to generate optimizer function."""
+
+ def __init__(self, params):
+ """Creates optimized based on the specified flags."""
+ if params.type == 'momentum':
+ self._optimizer = functools.partial(
+ tf_keras.optimizers.SGD,
+ momentum=params.momentum,
+ nesterov=params.nesterov)
+ elif params.type == 'adam':
+ self._optimizer = tf_keras.optimizers.Adam
+ elif params.type == 'adadelta':
+ self._optimizer = tf_keras.optimizers.Adadelta
+ elif params.type == 'adagrad':
+ self._optimizer = tf_keras.optimizers.Adagrad
+ elif params.type == 'rmsprop':
+ self._optimizer = functools.partial(
+ tf_keras.optimizers.RMSprop, momentum=params.momentum)
+ else:
+ raise ValueError('Unsupported optimizer type `{}`.'.format(params.type))
+
+ def __call__(self, learning_rate):
+ return self._optimizer(learning_rate=learning_rate)
diff --git a/modeling/official/legacy/detection/modeling/retinanet_model.py b/modeling/official/legacy/detection/modeling/retinanet_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..249dbd5938d9375c71f51a14b9884bbf541943ff
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/retinanet_model.py
@@ -0,0 +1,165 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Model defination for the RetinaNet Model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.dataloader import mode_keys
+from official.legacy.detection.evaluation import factory as eval_factory
+from official.legacy.detection.modeling import base_model
+from official.legacy.detection.modeling import losses
+from official.legacy.detection.modeling.architecture import factory
+from official.legacy.detection.ops import postprocess_ops
+
+
+class RetinanetModel(base_model.Model):
+ """RetinaNet model function."""
+
+ def __init__(self, params):
+ super(RetinanetModel, self).__init__(params)
+
+ # For eval metrics.
+ self._params = params
+
+ # Architecture generators.
+ self._backbone_fn = factory.backbone_generator(params)
+ self._fpn_fn = factory.multilevel_features_generator(params)
+ self._head_fn = factory.retinanet_head_generator(params)
+
+ # Loss function.
+ self._cls_loss_fn = losses.RetinanetClassLoss(
+ params.retinanet_loss, params.architecture.num_classes)
+ self._box_loss_fn = losses.RetinanetBoxLoss(params.retinanet_loss)
+ self._box_loss_weight = params.retinanet_loss.box_loss_weight
+ self._keras_model = None
+
+ # Predict function.
+ self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
+ params.architecture.min_level, params.architecture.max_level,
+ params.postprocess)
+
+ self._transpose_input = params.train.transpose_input
+ assert not self._transpose_input, 'Transpose input is not supported.'
+ # Input layer.
+ self._input_layer = tf_keras.layers.Input(
+ shape=(None, None, params.retinanet_parser.num_channels),
+ name='',
+ dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32)
+
+ def build_outputs(self, inputs, mode):
+ # If the input image is transposed (from NHWC to HWCN), we need to revert it
+ # back to the original shape before it's used in the computation.
+ if self._transpose_input:
+ inputs = tf.transpose(inputs, [3, 0, 1, 2])
+
+ backbone_features = self._backbone_fn(
+ inputs, is_training=(mode == mode_keys.TRAIN))
+ fpn_features = self._fpn_fn(
+ backbone_features, is_training=(mode == mode_keys.TRAIN))
+ cls_outputs, box_outputs = self._head_fn(
+ fpn_features, is_training=(mode == mode_keys.TRAIN))
+
+ if self._use_bfloat16:
+ levels = cls_outputs.keys()
+ for level in levels:
+ cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
+ box_outputs[level] = tf.cast(box_outputs[level], tf.float32)
+
+ model_outputs = {
+ 'cls_outputs': cls_outputs,
+ 'box_outputs': box_outputs,
+ }
+ return model_outputs
+
+ def build_loss_fn(self):
+ if self._keras_model is None:
+ raise ValueError('build_loss_fn() must be called after build_model().')
+
+ filter_fn = self.make_filter_trainable_variables_fn()
+ trainable_variables = filter_fn(self._keras_model.trainable_variables)
+
+ def _total_loss_fn(labels, outputs):
+ cls_loss = self._cls_loss_fn(outputs['cls_outputs'],
+ labels['cls_targets'],
+ labels['num_positives'])
+ box_loss = self._box_loss_fn(outputs['box_outputs'],
+ labels['box_targets'],
+ labels['num_positives'])
+ model_loss = cls_loss + self._box_loss_weight * box_loss
+ l2_regularization_loss = self.weight_decay_loss(trainable_variables)
+ total_loss = model_loss + l2_regularization_loss
+ return {
+ 'total_loss': total_loss,
+ 'cls_loss': cls_loss,
+ 'box_loss': box_loss,
+ 'model_loss': model_loss,
+ 'l2_regularization_loss': l2_regularization_loss,
+ }
+
+ return _total_loss_fn
+
+ def build_model(self, params, mode=None):
+ if self._keras_model is None:
+ outputs = self.model_outputs(self._input_layer, mode)
+
+ model = tf_keras.models.Model(
+ inputs=self._input_layer, outputs=outputs, name='retinanet')
+ assert model is not None, 'Fail to build tf_keras.Model.'
+ model.optimizer = self.build_optimizer()
+ self._keras_model = model
+
+ return self._keras_model
+
+ def post_processing(self, labels, outputs):
+ # TODO(yeqing): Moves the output related part into build_outputs.
+ required_output_fields = ['cls_outputs', 'box_outputs']
+ for field in required_output_fields:
+ if field not in outputs:
+ raise ValueError('"%s" is missing in outputs, requried %s found %s' %
+ (field, required_output_fields, outputs.keys()))
+ required_label_fields = ['image_info', 'groundtruths']
+ for field in required_label_fields:
+ if field not in labels:
+ raise ValueError('"%s" is missing in outputs, requried %s found %s' %
+ (field, required_label_fields, labels.keys()))
+ boxes, scores, classes, valid_detections = self._generate_detections_fn(
+ outputs['box_outputs'], outputs['cls_outputs'], labels['anchor_boxes'],
+ labels['image_info'][:, 1:2, :])
+ # Discards the old output tensors to save memory. The `cls_outputs` and
+ # `box_outputs` are pretty big and could potentiall lead to memory issue.
+ outputs = {
+ 'source_id': labels['groundtruths']['source_id'],
+ 'image_info': labels['image_info'],
+ 'num_detections': valid_detections,
+ 'detection_boxes': boxes,
+ 'detection_classes': classes,
+ 'detection_scores': scores,
+ }
+
+ if 'groundtruths' in labels:
+ labels['source_id'] = labels['groundtruths']['source_id']
+ labels['boxes'] = labels['groundtruths']['boxes']
+ labels['classes'] = labels['groundtruths']['classes']
+ labels['areas'] = labels['groundtruths']['areas']
+ labels['is_crowds'] = labels['groundtruths']['is_crowds']
+
+ return labels, outputs
+
+ def eval_metrics(self):
+ return eval_factory.evaluator_generator(self._params.eval)
diff --git a/modeling/official/legacy/detection/modeling/shapemask_model.py b/modeling/official/legacy/detection/modeling/shapemask_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..e698a5a59417a3cd15575e2eef9b086423a5f8c4
--- /dev/null
+++ b/modeling/official/legacy/detection/modeling/shapemask_model.py
@@ -0,0 +1,304 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Model definition for the ShapeMask Model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.dataloader import anchor
+from official.legacy.detection.dataloader import mode_keys
+from official.legacy.detection.evaluation import factory as eval_factory
+from official.legacy.detection.modeling import base_model
+from official.legacy.detection.modeling import losses
+from official.legacy.detection.modeling.architecture import factory
+from official.legacy.detection.ops import postprocess_ops
+from official.legacy.detection.utils import box_utils
+
+
+class ShapeMaskModel(base_model.Model):
+ """ShapeMask model function."""
+
+ def __init__(self, params):
+ super(ShapeMaskModel, self).__init__(params)
+
+ self._params = params
+ self._keras_model = None
+
+ # Architecture generators.
+ self._backbone_fn = factory.backbone_generator(params)
+ self._fpn_fn = factory.multilevel_features_generator(params)
+ self._retinanet_head_fn = factory.retinanet_head_generator(params)
+ self._shape_prior_head_fn = factory.shapeprior_head_generator(params)
+ self._coarse_mask_fn = factory.coarsemask_head_generator(params)
+ self._fine_mask_fn = factory.finemask_head_generator(params)
+
+ # Loss functions.
+ self._cls_loss_fn = losses.RetinanetClassLoss(
+ params.retinanet_loss, params.architecture.num_classes)
+ self._box_loss_fn = losses.RetinanetBoxLoss(params.retinanet_loss)
+ self._box_loss_weight = params.retinanet_loss.box_loss_weight
+
+ # Mask loss function.
+ self._shapemask_prior_loss_fn = losses.ShapemaskMseLoss()
+ self._shapemask_loss_fn = losses.ShapemaskLoss()
+ self._shape_prior_loss_weight = (
+ params.shapemask_loss.shape_prior_loss_weight)
+ self._coarse_mask_loss_weight = (
+ params.shapemask_loss.coarse_mask_loss_weight)
+ self._fine_mask_loss_weight = (params.shapemask_loss.fine_mask_loss_weight)
+
+ # Predict function.
+ self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
+ params.architecture.min_level, params.architecture.max_level,
+ params.postprocess)
+
+ def build_outputs(self, inputs, mode):
+ is_training = mode == mode_keys.TRAIN
+ images = inputs['image']
+
+ if 'anchor_boxes' in inputs:
+ anchor_boxes = inputs['anchor_boxes']
+ else:
+ anchor_boxes = anchor.Anchor(
+ self._params.architecture.min_level,
+ self._params.architecture.max_level, self._params.anchor.num_scales,
+ self._params.anchor.aspect_ratios, self._params.anchor.anchor_size,
+ images.get_shape().as_list()[1:3]).multilevel_boxes
+
+ batch_size = tf.shape(images)[0]
+ for level in anchor_boxes:
+ anchor_boxes[level] = tf.tile(
+ tf.expand_dims(anchor_boxes[level], 0), [batch_size, 1, 1, 1])
+
+ backbone_features = self._backbone_fn(images, is_training=is_training)
+ fpn_features = self._fpn_fn(backbone_features, is_training=is_training)
+ cls_outputs, box_outputs = self._retinanet_head_fn(
+ fpn_features, is_training=is_training)
+
+ valid_boxes, valid_scores, valid_classes, valid_detections = (
+ self._generate_detections_fn(box_outputs, cls_outputs, anchor_boxes,
+ inputs['image_info'][:, 1:2, :]))
+
+ image_size = images.get_shape().as_list()[1:3]
+ valid_outer_boxes = box_utils.compute_outer_boxes(
+ tf.reshape(valid_boxes, [-1, 4]),
+ image_size,
+ scale=self._params.shapemask_parser.outer_box_scale)
+ valid_outer_boxes = tf.reshape(valid_outer_boxes, tf.shape(valid_boxes))
+
+ # Wrapping if else code paths into a layer to make the checkpoint loadable
+ # in prediction mode.
+ class SampledBoxesLayer(tf_keras.layers.Layer):
+ """ShapeMask model function."""
+
+ def call(self, inputs, val_boxes, val_classes, val_outer_boxes, training):
+ if training:
+ boxes = inputs['mask_boxes']
+ outer_boxes = inputs['mask_outer_boxes']
+ classes = inputs['mask_classes']
+ else:
+ boxes = val_boxes
+ classes = val_classes
+ outer_boxes = val_outer_boxes
+ return boxes, classes, outer_boxes
+
+ boxes, classes, outer_boxes = SampledBoxesLayer()(
+ inputs,
+ valid_boxes,
+ valid_classes,
+ valid_outer_boxes,
+ training=is_training)
+
+ instance_features, prior_masks = self._shape_prior_head_fn(
+ fpn_features, boxes, outer_boxes, classes, is_training)
+ coarse_mask_logits = self._coarse_mask_fn(instance_features, prior_masks,
+ classes, is_training)
+ fine_mask_logits = self._fine_mask_fn(instance_features, coarse_mask_logits,
+ classes, is_training)
+
+ model_outputs = {
+ 'cls_outputs': cls_outputs,
+ 'box_outputs': box_outputs,
+ 'fine_mask_logits': fine_mask_logits,
+ 'coarse_mask_logits': coarse_mask_logits,
+ 'prior_masks': prior_masks,
+ }
+
+ if not is_training:
+ model_outputs.update({
+ 'num_detections': valid_detections,
+ 'detection_boxes': valid_boxes,
+ 'detection_outer_boxes': valid_outer_boxes,
+ 'detection_masks': fine_mask_logits,
+ 'detection_classes': valid_classes,
+ 'detection_scores': valid_scores,
+ })
+
+ return model_outputs
+
+ def build_loss_fn(self):
+ if self._keras_model is None:
+ raise ValueError('build_loss_fn() must be called after build_model().')
+
+ filter_fn = self.make_filter_trainable_variables_fn()
+ trainable_variables = filter_fn(self._keras_model.trainable_variables)
+
+ def _total_loss_fn(labels, outputs):
+ cls_loss = self._cls_loss_fn(outputs['cls_outputs'],
+ labels['cls_targets'],
+ labels['num_positives'])
+ box_loss = self._box_loss_fn(outputs['box_outputs'],
+ labels['box_targets'],
+ labels['num_positives'])
+
+ # Adds Shapemask model losses.
+ shape_prior_loss = self._shapemask_prior_loss_fn(outputs['prior_masks'],
+ labels['mask_targets'],
+ labels['mask_is_valid'])
+ coarse_mask_loss = self._shapemask_loss_fn(outputs['coarse_mask_logits'],
+ labels['mask_targets'],
+ labels['mask_is_valid'])
+ fine_mask_loss = self._shapemask_loss_fn(outputs['fine_mask_logits'],
+ labels['fine_mask_targets'],
+ labels['mask_is_valid'])
+
+ model_loss = (
+ cls_loss + self._box_loss_weight * box_loss +
+ shape_prior_loss * self._shape_prior_loss_weight +
+ coarse_mask_loss * self._coarse_mask_loss_weight +
+ fine_mask_loss * self._fine_mask_loss_weight)
+
+ l2_regularization_loss = self.weight_decay_loss(trainable_variables)
+ total_loss = model_loss + l2_regularization_loss
+
+ shapemask_losses = {
+ 'total_loss': total_loss,
+ 'loss': total_loss,
+ 'retinanet_cls_loss': cls_loss,
+ 'l2_regularization_loss': l2_regularization_loss,
+ 'retinanet_box_loss': box_loss,
+ 'shapemask_prior_loss': shape_prior_loss,
+ 'shapemask_coarse_mask_loss': coarse_mask_loss,
+ 'shapemask_fine_mask_loss': fine_mask_loss,
+ 'model_loss': model_loss,
+ }
+ return shapemask_losses
+
+ return _total_loss_fn
+
+ def build_input_layers(self, params, mode):
+ is_training = mode == mode_keys.TRAIN
+ input_shape = (
+ params.shapemask_parser.output_size +
+ [params.shapemask_parser.num_channels])
+ if is_training:
+ batch_size = params.train.batch_size
+ input_layer = {
+ 'image':
+ tf_keras.layers.Input(
+ shape=input_shape,
+ batch_size=batch_size,
+ name='image',
+ dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
+ 'image_info':
+ tf_keras.layers.Input(
+ shape=[4, 2], batch_size=batch_size, name='image_info'),
+ 'mask_classes':
+ tf_keras.layers.Input(
+ shape=[params.shapemask_parser.num_sampled_masks],
+ batch_size=batch_size,
+ name='mask_classes',
+ dtype=tf.int64),
+ 'mask_outer_boxes':
+ tf_keras.layers.Input(
+ shape=[params.shapemask_parser.num_sampled_masks, 4],
+ batch_size=batch_size,
+ name='mask_outer_boxes',
+ dtype=tf.float32),
+ 'mask_boxes':
+ tf_keras.layers.Input(
+ shape=[params.shapemask_parser.num_sampled_masks, 4],
+ batch_size=batch_size,
+ name='mask_boxes',
+ dtype=tf.float32),
+ }
+ else:
+ batch_size = params.eval.batch_size
+ input_layer = {
+ 'image':
+ tf_keras.layers.Input(
+ shape=input_shape,
+ batch_size=batch_size,
+ name='image',
+ dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32),
+ 'image_info':
+ tf_keras.layers.Input(
+ shape=[4, 2], batch_size=batch_size, name='image_info'),
+ }
+ return input_layer
+
+ def build_model(self, params, mode):
+ if self._keras_model is None:
+ input_layers = self.build_input_layers(self._params, mode)
+ outputs = self.model_outputs(input_layers, mode)
+
+ model = tf_keras.models.Model(
+ inputs=input_layers, outputs=outputs, name='shapemask')
+ assert model is not None, 'Fail to build tf_keras.Model.'
+ model.optimizer = self.build_optimizer()
+ self._keras_model = model
+
+ return self._keras_model
+
+ def post_processing(self, labels, outputs):
+ required_output_fields = [
+ 'num_detections', 'detection_boxes', 'detection_classes',
+ 'detection_masks', 'detection_scores'
+ ]
+
+ for field in required_output_fields:
+ if field not in outputs:
+ raise ValueError(
+ '"{}" is missing in outputs, requried {} found {}'.format(
+ field, required_output_fields, outputs.keys()))
+
+ required_label_fields = ['image_info']
+ for field in required_label_fields:
+ if field not in labels:
+ raise ValueError(
+ '"{}" is missing in labels, requried {} found {}'.format(
+ field, required_label_fields, labels.keys()))
+
+ predictions = {
+ 'image_info': labels['image_info'],
+ 'num_detections': outputs['num_detections'],
+ 'detection_boxes': outputs['detection_boxes'],
+ 'detection_outer_boxes': outputs['detection_outer_boxes'],
+ 'detection_classes': outputs['detection_classes'],
+ 'detection_scores': outputs['detection_scores'],
+ 'detection_masks': outputs['detection_masks'],
+ }
+
+ if 'groundtruths' in labels:
+ predictions['source_id'] = labels['groundtruths']['source_id']
+ labels = labels['groundtruths']
+
+ return labels, predictions
+
+ def eval_metrics(self):
+ return eval_factory.evaluator_generator(self._params.eval)
diff --git a/modeling/official/legacy/detection/ops/__init__.py b/modeling/official/legacy/detection/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/legacy/detection/ops/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/legacy/detection/ops/nms.py b/modeling/official/legacy/detection/ops/nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..551771181b37b9f7786e50b1547516e51e6a3aeb
--- /dev/null
+++ b/modeling/official/legacy/detection/ops/nms.py
@@ -0,0 +1,202 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tensorflow implementation of non max suppression."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.utils import box_utils
+
+NMS_TILE_SIZE = 512
+
+
+def _self_suppression(iou, _, iou_sum):
+ batch_size = tf.shape(iou)[0]
+ can_suppress_others = tf.cast(
+ tf.reshape(tf.reduce_max(iou, 1) <= 0.5, [batch_size, -1, 1]), iou.dtype)
+ iou_suppressed = tf.reshape(
+ tf.cast(tf.reduce_max(can_suppress_others * iou, 1) <= 0.5, iou.dtype),
+ [batch_size, -1, 1]) * iou
+ iou_sum_new = tf.reduce_sum(iou_suppressed, [1, 2])
+ return [
+ iou_suppressed,
+ tf.reduce_any(iou_sum - iou_sum_new > 0.5), iou_sum_new
+ ]
+
+
+def _cross_suppression(boxes, box_slice, iou_threshold, inner_idx):
+ batch_size = tf.shape(boxes)[0]
+ new_slice = tf.slice(boxes, [0, inner_idx * NMS_TILE_SIZE, 0],
+ [batch_size, NMS_TILE_SIZE, 4])
+ iou = box_utils.bbox_overlap(new_slice, box_slice)
+ ret_slice = tf.expand_dims(
+ tf.cast(tf.reduce_all(iou < iou_threshold, [1]), box_slice.dtype),
+ 2) * box_slice
+ return boxes, ret_slice, iou_threshold, inner_idx + 1
+
+
+def _suppression_loop_body(boxes, iou_threshold, output_size, idx):
+ """Process boxes in the range [idx*NMS_TILE_SIZE, (idx+1)*NMS_TILE_SIZE).
+
+ Args:
+ boxes: a tensor with a shape of [batch_size, anchors, 4].
+ iou_threshold: a float representing the threshold for deciding whether boxes
+ overlap too much with respect to IOU.
+ output_size: an int32 tensor of size [batch_size]. Representing the number
+ of selected boxes for each batch.
+ idx: an integer scalar representing induction variable.
+
+ Returns:
+ boxes: updated boxes.
+ iou_threshold: pass down iou_threshold to the next iteration.
+ output_size: the updated output_size.
+ idx: the updated induction variable.
+ """
+ boxes_shape = tf.shape(boxes)
+ num_tiles = boxes_shape[1] // NMS_TILE_SIZE
+ batch_size = boxes_shape[0]
+
+ # Iterates over tiles that can possibly suppress the current tile.
+ box_slice = tf.slice(boxes, [0, idx * NMS_TILE_SIZE, 0],
+ [batch_size, NMS_TILE_SIZE, 4])
+ _, box_slice, _, _ = tf.while_loop(
+ lambda _boxes, _box_slice, _threshold, inner_idx: inner_idx < idx,
+ _cross_suppression, [boxes, box_slice, iou_threshold,
+ tf.constant(0)])
+
+ # Iterates over the current tile to compute self-suppression.
+ iou = box_utils.bbox_overlap(box_slice, box_slice)
+ mask = tf.expand_dims(
+ tf.reshape(tf.range(NMS_TILE_SIZE), [1, -1]) > tf.reshape(
+ tf.range(NMS_TILE_SIZE), [-1, 1]), 0)
+ iou *= tf.cast(tf.logical_and(mask, iou >= iou_threshold), iou.dtype)
+ suppressed_iou, _, _ = tf.while_loop(
+ lambda _iou, loop_condition, _iou_sum: loop_condition, _self_suppression,
+ [iou, tf.constant(True),
+ tf.reduce_sum(iou, [1, 2])])
+ suppressed_box = tf.reduce_sum(suppressed_iou, 1) > 0
+ box_slice *= tf.expand_dims(1.0 - tf.cast(suppressed_box, box_slice.dtype), 2)
+
+ # Uses box_slice to update the input boxes.
+ mask = tf.reshape(
+ tf.cast(tf.equal(tf.range(num_tiles), idx), boxes.dtype), [1, -1, 1, 1])
+ boxes = tf.tile(tf.expand_dims(
+ box_slice, [1]), [1, num_tiles, 1, 1]) * mask + tf.reshape(
+ boxes, [batch_size, num_tiles, NMS_TILE_SIZE, 4]) * (1 - mask)
+ boxes = tf.reshape(boxes, boxes_shape)
+
+ # Updates output_size.
+ output_size += tf.reduce_sum(
+ tf.cast(tf.reduce_any(box_slice > 0, [2]), tf.int32), [1])
+ return boxes, iou_threshold, output_size, idx + 1
+
+
+def sorted_non_max_suppression_padded(scores, boxes, max_output_size,
+ iou_threshold):
+ """A wrapper that handles non-maximum suppression.
+
+ Assumption:
+ * The boxes are sorted by scores unless the box is a dot (all coordinates
+ are zero).
+ * Boxes with higher scores can be used to suppress boxes with lower scores.
+
+ The overal design of the algorithm is to handle boxes tile-by-tile:
+
+ boxes = boxes.pad_to_multiply_of(tile_size)
+ num_tiles = len(boxes) // tile_size
+ output_boxes = []
+ for i in range(num_tiles):
+ box_tile = boxes[i*tile_size : (i+1)*tile_size]
+ for j in range(i - 1):
+ suppressing_tile = boxes[j*tile_size : (j+1)*tile_size]
+ iou = bbox_overlap(box_tile, suppressing_tile)
+ # if the box is suppressed in iou, clear it to a dot
+ box_tile *= _update_boxes(iou)
+ # Iteratively handle the diagnal tile.
+ iou = _box_overlap(box_tile, box_tile)
+ iou_changed = True
+ while iou_changed:
+ # boxes that are not suppressed by anything else
+ suppressing_boxes = _get_suppressing_boxes(iou)
+ # boxes that are suppressed by suppressing_boxes
+ suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes)
+ # clear iou to 0 for boxes that are suppressed, as they cannot be used
+ # to suppress other boxes any more
+ new_iou = _clear_iou(iou, suppressed_boxes)
+ iou_changed = (new_iou != iou)
+ iou = new_iou
+ # remaining boxes that can still suppress others, are selected boxes.
+ output_boxes.append(_get_suppressing_boxes(iou))
+ if len(output_boxes) >= max_output_size:
+ break
+
+ Args:
+ scores: a tensor with a shape of [batch_size, anchors].
+ boxes: a tensor with a shape of [batch_size, anchors, 4].
+ max_output_size: a scalar integer `Tensor` representing the maximum number
+ of boxes to be selected by non max suppression.
+ iou_threshold: a float representing the threshold for deciding whether boxes
+ overlap too much with respect to IOU.
+
+ Returns:
+ nms_scores: a tensor with a shape of [batch_size, anchors]. It has same
+ dtype as input scores.
+ nms_proposals: a tensor with a shape of [batch_size, anchors, 4]. It has
+ same dtype as input boxes.
+ """
+ batch_size = tf.shape(boxes)[0]
+ num_boxes = tf.shape(boxes)[1]
+ pad = tf.cast(
+ tf.math.ceil(tf.cast(num_boxes, tf.float32) / NMS_TILE_SIZE),
+ tf.int32) * NMS_TILE_SIZE - num_boxes
+ boxes = tf.pad(tf.cast(boxes, tf.float32), [[0, 0], [0, pad], [0, 0]])
+ scores = tf.pad(
+ tf.cast(scores, tf.float32), [[0, 0], [0, pad]], constant_values=-1)
+ num_boxes += pad
+
+ def _loop_cond(unused_boxes, unused_threshold, output_size, idx):
+ return tf.logical_and(
+ tf.reduce_min(output_size) < max_output_size,
+ idx < num_boxes // NMS_TILE_SIZE)
+
+ selected_boxes, _, output_size, _ = tf.while_loop(
+ _loop_cond, _suppression_loop_body,
+ [boxes, iou_threshold,
+ tf.zeros([batch_size], tf.int32),
+ tf.constant(0)])
+ idx = num_boxes - tf.cast(
+ tf.nn.top_k(
+ tf.cast(tf.reduce_any(selected_boxes > 0, [2]), tf.int32) *
+ tf.expand_dims(tf.range(num_boxes, 0, -1), 0), max_output_size)[0],
+ tf.int32)
+ idx = tf.minimum(idx, num_boxes - 1)
+ idx = tf.reshape(idx + tf.reshape(tf.range(batch_size) * num_boxes, [-1, 1]),
+ [-1])
+ boxes = tf.reshape(
+ tf.gather(tf.reshape(boxes, [-1, 4]), idx),
+ [batch_size, max_output_size, 4])
+ boxes = boxes * tf.cast(
+ tf.reshape(tf.range(max_output_size), [1, -1, 1]) < tf.reshape(
+ output_size, [-1, 1, 1]), boxes.dtype)
+ scores = tf.reshape(
+ tf.gather(tf.reshape(scores, [-1, 1]), idx),
+ [batch_size, max_output_size])
+ scores = scores * tf.cast(
+ tf.reshape(tf.range(max_output_size), [1, -1]) < tf.reshape(
+ output_size, [-1, 1]), scores.dtype)
+ return scores, boxes
diff --git a/modeling/official/legacy/detection/ops/postprocess_ops.py b/modeling/official/legacy/detection/ops/postprocess_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb29d3912950c5ca296836202eae15cfaafd652b
--- /dev/null
+++ b/modeling/official/legacy/detection/ops/postprocess_ops.py
@@ -0,0 +1,497 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Post-processing model outputs to generate detection."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.ops import nms
+from official.legacy.detection.utils import box_utils
+
+
+def generate_detections_factory(params):
+ """Factory to select function to generate detection."""
+ if params.use_batched_nms:
+ func = functools.partial(
+ _generate_detections_batched,
+ max_total_size=params.max_total_size,
+ nms_iou_threshold=params.nms_iou_threshold,
+ score_threshold=params.score_threshold)
+ else:
+ func = functools.partial(
+ _generate_detections,
+ max_total_size=params.max_total_size,
+ nms_iou_threshold=params.nms_iou_threshold,
+ score_threshold=params.score_threshold,
+ pre_nms_num_boxes=params.pre_nms_num_boxes)
+ return func
+
+
+def _select_top_k_scores(scores_in, pre_nms_num_detections):
+ """Select top_k scores and indices for each class.
+
+ Args:
+ scores_in: a Tensor with shape [batch_size, N, num_classes], which stacks
+ class logit outputs on all feature levels. The N is the number of total
+ anchors on all levels. The num_classes is the number of classes predicted
+ by the model.
+ pre_nms_num_detections: Number of candidates before NMS.
+
+ Returns:
+ scores and indices: Tensors with shape [batch_size, pre_nms_num_detections,
+ num_classes].
+ """
+ batch_size, num_anchors, num_class = scores_in.get_shape().as_list()
+ scores_trans = tf.transpose(scores_in, perm=[0, 2, 1])
+ scores_trans = tf.reshape(scores_trans, [-1, num_anchors])
+
+ top_k_scores, top_k_indices = tf.nn.top_k(
+ scores_trans, k=pre_nms_num_detections, sorted=True)
+
+ top_k_scores = tf.reshape(top_k_scores,
+ [batch_size, num_class, pre_nms_num_detections])
+ top_k_indices = tf.reshape(top_k_indices,
+ [batch_size, num_class, pre_nms_num_detections])
+
+ return tf.transpose(top_k_scores,
+ [0, 2, 1]), tf.transpose(top_k_indices, [0, 2, 1])
+
+
+def _generate_detections(boxes,
+ scores,
+ max_total_size=100,
+ nms_iou_threshold=0.3,
+ score_threshold=0.05,
+ pre_nms_num_boxes=5000):
+ """Generate the final detections given the model outputs.
+
+ This uses classes unrolling with while loop based NMS, could be parralled
+ at batch dimension.
+
+ Args:
+ boxes: a tensor with shape [batch_size, N, num_classes, 4] or [batch_size,
+ N, 1, 4], which box predictions on all feature levels. The N is the number
+ of total anchors on all levels.
+ scores: a tensor with shape [batch_size, N, num_classes], which stacks class
+ probability on all feature levels. The N is the number of total anchors on
+ all levels. The num_classes is the number of classes predicted by the
+ model. Note that the class_outputs here is the raw score.
+ max_total_size: a scalar representing maximum number of boxes retained over
+ all classes.
+ nms_iou_threshold: a float representing the threshold for deciding whether
+ boxes overlap too much with respect to IOU.
+ score_threshold: a float representing the threshold for deciding when to
+ remove boxes based on score.
+ pre_nms_num_boxes: an int number of top candidate detections per class
+ before NMS.
+
+ Returns:
+ nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
+ representing top detected boxes in [y1, x1, y2, x2].
+ nms_scores: `float` Tensor of shape [batch_size, max_total_size]
+ representing sorted confidence scores for detected boxes. The values are
+ between [0, 1].
+ nms_classes: `int` Tensor of shape [batch_size, max_total_size] representing
+ classes for detected boxes.
+ valid_detections: `int` Tensor of shape [batch_size] only the top
+ `valid_detections` boxes are valid detections.
+ """
+ with tf.name_scope('generate_detections'):
+ nmsed_boxes = []
+ nmsed_classes = []
+ nmsed_scores = []
+ valid_detections = []
+ batch_size, _, num_classes_for_box, _ = boxes.get_shape().as_list()
+ _, total_anchors, num_classes = scores.get_shape().as_list()
+ # Selects top pre_nms_num scores and indices before NMS.
+ scores, indices = _select_top_k_scores(
+ scores, min(total_anchors, pre_nms_num_boxes))
+ for i in range(num_classes):
+ boxes_i = boxes[:, :, min(num_classes_for_box - 1, i), :]
+ scores_i = scores[:, :, i]
+ # Obtains pre_nms_num_boxes before running NMS.
+ boxes_i = tf.gather(boxes_i, indices[:, :, i], batch_dims=1, axis=1)
+
+ # Filter out scores.
+ boxes_i, scores_i = box_utils.filter_boxes_by_scores(
+ boxes_i, scores_i, min_score_threshold=score_threshold)
+
+ (nmsed_scores_i, nmsed_boxes_i) = nms.sorted_non_max_suppression_padded(
+ tf.cast(scores_i, tf.float32),
+ tf.cast(boxes_i, tf.float32),
+ max_total_size,
+ iou_threshold=nms_iou_threshold)
+ nmsed_classes_i = tf.fill([batch_size, max_total_size], i)
+ nmsed_boxes.append(nmsed_boxes_i)
+ nmsed_scores.append(nmsed_scores_i)
+ nmsed_classes.append(nmsed_classes_i)
+ nmsed_boxes = tf.concat(nmsed_boxes, axis=1)
+ nmsed_scores = tf.concat(nmsed_scores, axis=1)
+ nmsed_classes = tf.concat(nmsed_classes, axis=1)
+ nmsed_scores, indices = tf.nn.top_k(
+ nmsed_scores, k=max_total_size, sorted=True)
+ nmsed_boxes = tf.gather(nmsed_boxes, indices, batch_dims=1, axis=1)
+ nmsed_classes = tf.gather(nmsed_classes, indices, batch_dims=1)
+ valid_detections = tf.reduce_sum(
+ input_tensor=tf.cast(tf.greater(nmsed_scores, -1), tf.int32), axis=1)
+ return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
+
+
+def _generate_detections_per_image(boxes,
+ scores,
+ max_total_size=100,
+ nms_iou_threshold=0.3,
+ score_threshold=0.05,
+ pre_nms_num_boxes=5000):
+ """Generate the final detections per image given the model outputs.
+
+ Args:
+ boxes: a tensor with shape [N, num_classes, 4] or [N, 1, 4], which box
+ predictions on all feature levels. The N is the number of total anchors on
+ all levels.
+ scores: a tensor with shape [N, num_classes], which stacks class probability
+ on all feature levels. The N is the number of total anchors on all levels.
+ The num_classes is the number of classes predicted by the model. Note that
+ the class_outputs here is the raw score.
+ max_total_size: a scalar representing maximum number of boxes retained over
+ all classes.
+ nms_iou_threshold: a float representing the threshold for deciding whether
+ boxes overlap too much with respect to IOU.
+ score_threshold: a float representing the threshold for deciding when to
+ remove boxes based on score.
+ pre_nms_num_boxes: an int number of top candidate detections per class
+ before NMS.
+
+ Returns:
+ nms_boxes: `float` Tensor of shape [max_total_size, 4] representing top
+ detected boxes in [y1, x1, y2, x2].
+ nms_scores: `float` Tensor of shape [max_total_size] representing sorted
+ confidence scores for detected boxes. The values are between [0, 1].
+ nms_classes: `int` Tensor of shape [max_total_size] representing classes for
+ detected boxes.
+ valid_detections: `int` Tensor of shape [1] only the top `valid_detections`
+ boxes are valid detections.
+ """
+ nmsed_boxes = []
+ nmsed_scores = []
+ nmsed_classes = []
+ num_classes_for_box = boxes.get_shape().as_list()[1]
+ num_classes = scores.get_shape().as_list()[1]
+ for i in range(num_classes):
+ boxes_i = boxes[:, min(num_classes_for_box - 1, i)]
+ scores_i = scores[:, i]
+
+ # Obtains pre_nms_num_boxes before running NMS.
+ scores_i, indices = tf.nn.top_k(
+ scores_i, k=tf.minimum(tf.shape(input=scores_i)[-1], pre_nms_num_boxes))
+ boxes_i = tf.gather(boxes_i, indices)
+
+ (nmsed_indices_i, nmsed_num_valid_i) = tf.image.non_max_suppression_padded(
+ tf.cast(boxes_i, tf.float32),
+ tf.cast(scores_i, tf.float32),
+ max_total_size,
+ iou_threshold=nms_iou_threshold,
+ score_threshold=score_threshold,
+ pad_to_max_output_size=True,
+ name='nms_detections_' + str(i))
+ nmsed_boxes_i = tf.gather(boxes_i, nmsed_indices_i)
+ nmsed_scores_i = tf.gather(scores_i, nmsed_indices_i)
+ # Sets scores of invalid boxes to -1.
+ nmsed_scores_i = tf.where(
+ tf.less(tf.range(max_total_size), [nmsed_num_valid_i]), nmsed_scores_i,
+ -tf.ones_like(nmsed_scores_i))
+ nmsed_classes_i = tf.fill([max_total_size], i)
+ nmsed_boxes.append(nmsed_boxes_i)
+ nmsed_scores.append(nmsed_scores_i)
+ nmsed_classes.append(nmsed_classes_i)
+
+ # Concats results from all classes and sort them.
+ nmsed_boxes = tf.concat(nmsed_boxes, axis=0)
+ nmsed_scores = tf.concat(nmsed_scores, axis=0)
+ nmsed_classes = tf.concat(nmsed_classes, axis=0)
+ nmsed_scores, indices = tf.nn.top_k(
+ nmsed_scores, k=max_total_size, sorted=True)
+ nmsed_boxes = tf.gather(nmsed_boxes, indices)
+ nmsed_classes = tf.gather(nmsed_classes, indices)
+ valid_detections = tf.reduce_sum(
+ input_tensor=tf.cast(tf.greater(nmsed_scores, -1), tf.int32))
+ return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
+
+
+def _generate_detections_batched(boxes, scores, max_total_size,
+ nms_iou_threshold, score_threshold):
+ """Generates detected boxes with scores and classes for one-stage detector.
+
+ The function takes output of multi-level ConvNets and anchor boxes and
+ generates detected boxes. Note that this used batched nms, which is not
+ supported on TPU currently.
+
+ Args:
+ boxes: a tensor with shape [batch_size, N, num_classes, 4] or [batch_size,
+ N, 1, 4], which box predictions on all feature levels. The N is the number
+ of total anchors on all levels.
+ scores: a tensor with shape [batch_size, N, num_classes], which stacks class
+ probability on all feature levels. The N is the number of total anchors on
+ all levels. The num_classes is the number of classes predicted by the
+ model. Note that the class_outputs here is the raw score.
+ max_total_size: a scalar representing maximum number of boxes retained over
+ all classes.
+ nms_iou_threshold: a float representing the threshold for deciding whether
+ boxes overlap too much with respect to IOU.
+ score_threshold: a float representing the threshold for deciding when to
+ remove boxes based on score.
+
+ Returns:
+ nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
+ representing top detected boxes in [y1, x1, y2, x2].
+ nms_scores: `float` Tensor of shape [batch_size, max_total_size]
+ representing sorted confidence scores for detected boxes. The values are
+ between [0, 1].
+ nms_classes: `int` Tensor of shape [batch_size, max_total_size] representing
+ classes for detected boxes.
+ valid_detections: `int` Tensor of shape [batch_size] only the top
+ `valid_detections` boxes are valid detections.
+ """
+ with tf.name_scope('generate_detections'):
+ # TODO(tsungyi): Removes normalization/denomalization once the
+ # tf.image.combined_non_max_suppression is coordinate system agnostic.
+ # Normalizes maximum box cooridinates to 1.
+ normalizer = tf.reduce_max(boxes)
+ boxes /= normalizer
+ (nmsed_boxes, nmsed_scores, nmsed_classes,
+ valid_detections) = tf.image.combined_non_max_suppression(
+ boxes,
+ scores,
+ max_output_size_per_class=max_total_size,
+ max_total_size=max_total_size,
+ iou_threshold=nms_iou_threshold,
+ score_threshold=score_threshold,
+ pad_per_class=False,
+ )
+ # De-normalizes box cooridinates.
+ nmsed_boxes *= normalizer
+ nmsed_classes = tf.cast(nmsed_classes, tf.int32)
+ return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
+
+
+class MultilevelDetectionGenerator(tf_keras.layers.Layer):
+ """Generates detected boxes with scores and classes for one-stage detector."""
+
+ def __init__(self, min_level, max_level, params):
+ self._min_level = min_level
+ self._max_level = max_level
+ self._generate_detections = generate_detections_factory(params)
+ super(MultilevelDetectionGenerator, self).__init__(autocast=False)
+
+ def call(self, box_outputs, class_outputs, anchor_boxes, image_shape):
+ # Collects outputs from all levels into a list.
+ boxes = []
+ scores = []
+ for i in range(self._min_level, self._max_level + 1):
+ box_outputs_i_shape = tf.shape(box_outputs[i])
+ batch_size = box_outputs_i_shape[0]
+ num_anchors_per_locations = box_outputs_i_shape[-1] // 4
+ num_classes = tf.shape(class_outputs[i])[-1] // num_anchors_per_locations
+
+ # Applies score transformation and remove the implicit background class.
+ scores_i = tf.sigmoid(
+ tf.reshape(class_outputs[i], [batch_size, -1, num_classes]))
+ scores_i = tf.slice(scores_i, [0, 0, 1], [-1, -1, -1])
+
+ # Box decoding.
+ # The anchor boxes are shared for all data in a batch.
+ # One stage detector only supports class agnostic box regression.
+ anchor_boxes_i = tf.reshape(anchor_boxes[i], [batch_size, -1, 4])
+ box_outputs_i = tf.reshape(box_outputs[i], [batch_size, -1, 4])
+ boxes_i = box_utils.decode_boxes(box_outputs_i, anchor_boxes_i)
+
+ # Box clipping.
+ boxes_i = box_utils.clip_boxes(boxes_i, image_shape)
+
+ boxes.append(boxes_i)
+ scores.append(scores_i)
+ boxes = tf.concat(boxes, axis=1)
+ scores = tf.concat(scores, axis=1)
+
+ nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
+ self._generate_detections(tf.expand_dims(boxes, axis=2), scores))
+
+ # Adds 1 to offset the background class which has index 0.
+ nmsed_classes += 1
+ return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
+
+
+class GenericDetectionGenerator(tf_keras.layers.Layer):
+ """Generates the final detected boxes with scores and classes."""
+
+ def __init__(self, params):
+ super(GenericDetectionGenerator, self).__init__(autocast=False)
+ self._generate_detections = generate_detections_factory(params)
+
+ def call(self, box_outputs, class_outputs, anchor_boxes, image_shape):
+ """Generate final detections.
+
+ Args:
+ box_outputs: a tensor of shape of [batch_size, K, num_classes * 4]
+ representing the class-specific box coordinates relative to anchors.
+ class_outputs: a tensor of shape of [batch_size, K, num_classes]
+ representing the class logits before applying score activiation.
+ anchor_boxes: a tensor of shape of [batch_size, K, 4] representing the
+ corresponding anchor boxes w.r.t `box_outputs`.
+ image_shape: a tensor of shape of [batch_size, 2] storing the image height
+ and width w.r.t. the scaled image, i.e. the same image space as
+ `box_outputs` and `anchor_boxes`.
+
+ Returns:
+ nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
+ representing top detected boxes in [y1, x1, y2, x2].
+ nms_scores: `float` Tensor of shape [batch_size, max_total_size]
+ representing sorted confidence scores for detected boxes. The values are
+ between [0, 1].
+ nms_classes: `int` Tensor of shape [batch_size, max_total_size]
+ representing classes for detected boxes.
+ valid_detections: `int` Tensor of shape [batch_size] only the top
+ `valid_detections` boxes are valid detections.
+ """
+ class_outputs = tf.nn.softmax(class_outputs, axis=-1)
+
+ # Removes the background class.
+ class_outputs_shape = tf.shape(class_outputs)
+ batch_size = class_outputs_shape[0]
+ num_locations = class_outputs_shape[1]
+ num_classes = class_outputs_shape[-1]
+ num_detections = num_locations * (num_classes - 1)
+
+ class_outputs = tf.slice(class_outputs, [0, 0, 1], [-1, -1, -1])
+ box_outputs = tf.reshape(
+ box_outputs,
+ tf.stack([batch_size, num_locations, num_classes, 4], axis=-1))
+ box_outputs = tf.slice(box_outputs, [0, 0, 1, 0], [-1, -1, -1, -1])
+ anchor_boxes = tf.tile(
+ tf.expand_dims(anchor_boxes, axis=2), [1, 1, num_classes - 1, 1])
+ box_outputs = tf.reshape(box_outputs,
+ tf.stack([batch_size, num_detections, 4], axis=-1))
+ anchor_boxes = tf.reshape(
+ anchor_boxes, tf.stack([batch_size, num_detections, 4], axis=-1))
+
+ # Box decoding.
+ decoded_boxes = box_utils.decode_boxes(
+ box_outputs, anchor_boxes, weights=[10.0, 10.0, 5.0, 5.0])
+
+ # Box clipping
+ decoded_boxes = box_utils.clip_boxes(decoded_boxes, image_shape)
+
+ decoded_boxes = tf.reshape(
+ decoded_boxes,
+ tf.stack([batch_size, num_locations, num_classes - 1, 4], axis=-1))
+
+ nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
+ self._generate_detections(decoded_boxes, class_outputs))
+
+ # Adds 1 to offset the background class which has index 0.
+ nmsed_classes += 1
+
+ return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
+
+
+class OlnDetectionGenerator(GenericDetectionGenerator):
+ """Generates the final detected boxes with scores and classes."""
+
+ def __call__(self, box_outputs, class_outputs, anchor_boxes, image_shape,
+ is_single_fg_score=False, keep_nms=True):
+ """Generate final detections for Object Localization Network (OLN).
+
+ Args:
+ box_outputs: a tensor of shape of [batch_size, K, num_classes * 4]
+ representing the class-specific box coordinates relative to anchors.
+ class_outputs: a tensor of shape of [batch_size, K, num_classes]
+ representing the class logits before applying score activiation.
+ anchor_boxes: a tensor of shape of [batch_size, K, 4] representing the
+ corresponding anchor boxes w.r.t `box_outputs`.
+ image_shape: a tensor of shape of [batch_size, 2] storing the image height
+ and width w.r.t. the scaled image, i.e. the same image space as
+ `box_outputs` and `anchor_boxes`.
+ is_single_fg_score: a Bool indicator of whether class_outputs includes the
+ background scores concatenated or not. By default, class_outputs is a
+ concatenation of both scores for the foreground and background. That is,
+ scores_without_bg=False.
+ keep_nms: a Bool indicator of whether to perform NMS or not.
+
+ Returns:
+ nms_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
+ representing top detected boxes in [y1, x1, y2, x2].
+ nms_scores: `float` Tensor of shape [batch_size, max_total_size]
+ representing sorted confidence scores for detected boxes. The values are
+ between [0, 1].
+ nms_classes: `int` Tensor of shape [batch_size, max_total_size]
+ representing classes for detected boxes.
+ valid_detections: `int` Tensor of shape [batch_size] only the top
+ `valid_detections` boxes are valid detections.
+ """
+ if is_single_fg_score:
+ # Concatenates dummy background scores.
+ dummy_bg_scores = tf.zeros_like(class_outputs)
+ class_outputs = tf.stack([dummy_bg_scores, class_outputs], -1)
+ else:
+ class_outputs = tf.nn.softmax(class_outputs, axis=-1)
+
+ # Removes the background class.
+ class_outputs_shape = tf.shape(class_outputs)
+ batch_size = class_outputs_shape[0]
+ num_locations = class_outputs_shape[1]
+ num_classes = class_outputs_shape[-1]
+ num_detections = num_locations * (num_classes - 1)
+
+ class_outputs = tf.slice(class_outputs, [0, 0, 1], [-1, -1, -1])
+ box_outputs = tf.reshape(
+ box_outputs,
+ tf.stack([batch_size, num_locations, num_classes, 4], axis=-1))
+ box_outputs = tf.slice(box_outputs, [0, 0, 1, 0], [-1, -1, -1, -1])
+ anchor_boxes = tf.tile(
+ tf.expand_dims(anchor_boxes, axis=2), [1, 1, num_classes - 1, 1])
+ box_outputs = tf.reshape(box_outputs,
+ tf.stack([batch_size, num_detections, 4], axis=-1))
+ anchor_boxes = tf.reshape(
+ anchor_boxes, tf.stack([batch_size, num_detections, 4], axis=-1))
+
+ # Box decoding. For RPN outputs, box_outputs are all zeros.
+ decoded_boxes = box_utils.decode_boxes(
+ box_outputs, anchor_boxes, weights=[10.0, 10.0, 5.0, 5.0])
+
+ # Box clipping
+ decoded_boxes = box_utils.clip_boxes(decoded_boxes, image_shape)
+
+ decoded_boxes = tf.reshape(
+ decoded_boxes,
+ tf.stack([batch_size, num_locations, num_classes - 1, 4], axis=-1))
+
+ if keep_nms:
+ nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
+ self._generate_detections(decoded_boxes, class_outputs))
+ # Adds 1 to offset the background class which has index 0.
+ nmsed_classes += 1
+ else:
+ nmsed_boxes = decoded_boxes[:, :, 0, :]
+ nmsed_scores = class_outputs[:, :, 0]
+ nmsed_classes = tf.cast(tf.ones_like(nmsed_scores), tf.int32)
+ valid_detections = tf.cast(
+ tf.reduce_sum(tf.ones_like(nmsed_scores), axis=-1), tf.int32)
+
+ return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
diff --git a/modeling/official/legacy/detection/ops/roi_ops.py b/modeling/official/legacy/detection/ops/roi_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..94a192810ef7ac3139a6afa1539d5d77d35d8259
--- /dev/null
+++ b/modeling/official/legacy/detection/ops/roi_ops.py
@@ -0,0 +1,468 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""ROI-related ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.ops import nms
+from official.legacy.detection.utils import box_utils
+
+
+def multilevel_propose_rois(rpn_boxes,
+ rpn_scores,
+ anchor_boxes,
+ image_shape,
+ rpn_pre_nms_top_k=2000,
+ rpn_post_nms_top_k=1000,
+ rpn_nms_threshold=0.7,
+ rpn_score_threshold=0.0,
+ rpn_min_size_threshold=0.0,
+ decode_boxes=True,
+ clip_boxes=True,
+ use_batched_nms=False,
+ apply_sigmoid_to_score=True):
+ """Proposes RoIs given a group of candidates from different FPN levels.
+
+ The following describes the steps:
+ 1. For each individual level:
+ a. Apply sigmoid transform if specified.
+ b. Decode boxes if specified.
+ c. Clip boxes if specified.
+ d. Filter small boxes and those fall outside image if specified.
+ e. Apply pre-NMS filtering including pre-NMS top k and score thresholding.
+ f. Apply NMS.
+ 2. Aggregate post-NMS boxes from each level.
+ 3. Apply an overall top k to generate the final selected RoIs.
+
+ Args:
+ rpn_boxes: a dict with keys representing FPN levels and values representing
+ box tenors of shape [batch_size, feature_h, feature_w, num_anchors * 4].
+ rpn_scores: a dict with keys representing FPN levels and values representing
+ logit tensors of shape [batch_size, feature_h, feature_w, num_anchors].
+ anchor_boxes: a dict with keys representing FPN levels and values
+ representing anchor box tensors of shape [batch_size, feature_h,
+ feature_w, num_anchors * 4].
+ image_shape: a tensor of shape [batch_size, 2] where the last dimension are
+ [height, width] of the scaled image.
+ rpn_pre_nms_top_k: an integer of top scoring RPN proposals *per level* to
+ keep before applying NMS. Default: 2000.
+ rpn_post_nms_top_k: an integer of top scoring RPN proposals *in total* to
+ keep after applying NMS. Default: 1000.
+ rpn_nms_threshold: a float between 0 and 1 representing the IoU threshold
+ used for NMS. If 0.0, no NMS is applied. Default: 0.7.
+ rpn_score_threshold: a float between 0 and 1 representing the minimal box
+ score to keep before applying NMS. This is often used as a pre-filtering
+ step for better performance. If 0, no filtering is applied. Default: 0.
+ rpn_min_size_threshold: a float representing the minimal box size in each
+ side (w.r.t. the scaled image) to keep before applying NMS. This is often
+ used as a pre-filtering step for better performance. If 0, no filtering is
+ applied. Default: 0.
+ decode_boxes: a boolean indicating whether `rpn_boxes` needs to be decoded
+ using `anchor_boxes`. If False, use `rpn_boxes` directly and ignore
+ `anchor_boxes`. Default: True.
+ clip_boxes: a boolean indicating whether boxes are first clipped to the
+ scaled image size before appliying NMS. If False, no clipping is applied
+ and `image_shape` is ignored. Default: True.
+ use_batched_nms: a boolean indicating whether NMS is applied in batch using
+ `tf.image.combined_non_max_suppression`. Currently only available in
+ CPU/GPU. Default: False.
+ apply_sigmoid_to_score: a boolean indicating whether apply sigmoid to
+ `rpn_scores` before applying NMS. Default: True.
+
+ Returns:
+ selected_rois: a tensor of shape [batch_size, rpn_post_nms_top_k, 4],
+ representing the box coordinates of the selected proposals w.r.t. the
+ scaled image.
+ selected_roi_scores: a tensor of shape [batch_size, rpn_post_nms_top_k, 1],
+ representing the scores of the selected proposals.
+ """
+ with tf.name_scope('multilevel_propose_rois'):
+ rois = []
+ roi_scores = []
+ image_shape = tf.expand_dims(image_shape, axis=1)
+ for level in sorted(rpn_scores.keys()):
+ with tf.name_scope('level_%d' % level):
+ _, feature_h, feature_w, num_anchors_per_location = (
+ rpn_scores[level].get_shape().as_list())
+
+ num_boxes = feature_h * feature_w * num_anchors_per_location
+ this_level_scores = tf.reshape(rpn_scores[level], [-1, num_boxes])
+ this_level_boxes = tf.reshape(rpn_boxes[level], [-1, num_boxes, 4])
+ this_level_anchors = tf.cast(
+ tf.reshape(anchor_boxes[level], [-1, num_boxes, 4]),
+ dtype=this_level_scores.dtype)
+
+ if apply_sigmoid_to_score:
+ this_level_scores = tf.sigmoid(this_level_scores)
+
+ if decode_boxes:
+ this_level_boxes = box_utils.decode_boxes(this_level_boxes,
+ this_level_anchors)
+ if clip_boxes:
+ this_level_boxes = box_utils.clip_boxes(this_level_boxes, image_shape)
+
+ if rpn_min_size_threshold > 0.0:
+ this_level_boxes, this_level_scores = box_utils.filter_boxes(
+ this_level_boxes, this_level_scores, image_shape,
+ rpn_min_size_threshold)
+
+ this_level_pre_nms_top_k = min(num_boxes, rpn_pre_nms_top_k)
+ this_level_post_nms_top_k = min(num_boxes, rpn_post_nms_top_k)
+ if rpn_nms_threshold > 0.0:
+ if use_batched_nms:
+ this_level_rois, this_level_roi_scores, _, _ = (
+ tf.image.combined_non_max_suppression(
+ tf.expand_dims(this_level_boxes, axis=2),
+ tf.expand_dims(this_level_scores, axis=-1),
+ max_output_size_per_class=this_level_pre_nms_top_k,
+ max_total_size=this_level_post_nms_top_k,
+ iou_threshold=rpn_nms_threshold,
+ score_threshold=rpn_score_threshold,
+ pad_per_class=False,
+ clip_boxes=False))
+ else:
+ if rpn_score_threshold > 0.0:
+ this_level_boxes, this_level_scores = (
+ box_utils.filter_boxes_by_scores(this_level_boxes,
+ this_level_scores,
+ rpn_score_threshold))
+ this_level_boxes, this_level_scores = box_utils.top_k_boxes(
+ this_level_boxes, this_level_scores, k=this_level_pre_nms_top_k)
+ this_level_roi_scores, this_level_rois = (
+ nms.sorted_non_max_suppression_padded(
+ this_level_scores,
+ this_level_boxes,
+ max_output_size=this_level_post_nms_top_k,
+ iou_threshold=rpn_nms_threshold))
+ else:
+ this_level_rois, this_level_roi_scores = box_utils.top_k_boxes(
+ this_level_rois, this_level_scores, k=this_level_post_nms_top_k)
+
+ rois.append(this_level_rois)
+ roi_scores.append(this_level_roi_scores)
+
+ all_rois = tf.concat(rois, axis=1)
+ all_roi_scores = tf.concat(roi_scores, axis=1)
+
+ with tf.name_scope('top_k_rois'):
+ _, num_valid_rois = all_roi_scores.get_shape().as_list()
+ overall_top_k = min(num_valid_rois, rpn_post_nms_top_k)
+
+ selected_rois, selected_roi_scores = box_utils.top_k_boxes(
+ all_rois, all_roi_scores, k=overall_top_k)
+
+ return selected_rois, selected_roi_scores
+
+
+class ROIGenerator(tf_keras.layers.Layer):
+ """Proposes RoIs for the second stage processing."""
+
+ def __init__(self, params):
+ self._rpn_pre_nms_top_k = params.rpn_pre_nms_top_k
+ self._rpn_post_nms_top_k = params.rpn_post_nms_top_k
+ self._rpn_nms_threshold = params.rpn_nms_threshold
+ self._rpn_score_threshold = params.rpn_score_threshold
+ self._rpn_min_size_threshold = params.rpn_min_size_threshold
+ self._test_rpn_pre_nms_top_k = params.test_rpn_pre_nms_top_k
+ self._test_rpn_post_nms_top_k = params.test_rpn_post_nms_top_k
+ self._test_rpn_nms_threshold = params.test_rpn_nms_threshold
+ self._test_rpn_score_threshold = params.test_rpn_score_threshold
+ self._test_rpn_min_size_threshold = params.test_rpn_min_size_threshold
+ self._use_batched_nms = params.use_batched_nms
+ super(ROIGenerator, self).__init__(autocast=False)
+
+ def call(self, boxes, scores, anchor_boxes, image_shape, is_training):
+ """Generates RoI proposals.
+
+ Args:
+ boxes: a dict with keys representing FPN levels and values representing
+ box tenors of shape [batch_size, feature_h, feature_w, num_anchors * 4].
+ scores: a dict with keys representing FPN levels and values representing
+ logit tensors of shape [batch_size, feature_h, feature_w, num_anchors].
+ anchor_boxes: a dict with keys representing FPN levels and values
+ representing anchor box tensors of shape [batch_size, feature_h,
+ feature_w, num_anchors * 4].
+ image_shape: a tensor of shape [batch_size, 2] where the last dimension
+ are [height, width] of the scaled image.
+ is_training: a bool indicating whether it is in training or inference
+ mode.
+
+ Returns:
+ proposed_rois: a tensor of shape [batch_size, rpn_post_nms_top_k, 4],
+ representing the box coordinates of the proposed RoIs w.r.t. the
+ scaled image.
+ proposed_roi_scores: a tensor of shape
+ [batch_size, rpn_post_nms_top_k, 1], representing the scores of the
+ proposed RoIs.
+
+ """
+ proposed_rois, proposed_roi_scores = multilevel_propose_rois(
+ boxes,
+ scores,
+ anchor_boxes,
+ image_shape,
+ rpn_pre_nms_top_k=(self._rpn_pre_nms_top_k
+ if is_training else self._test_rpn_pre_nms_top_k),
+ rpn_post_nms_top_k=(self._rpn_post_nms_top_k
+ if is_training else self._test_rpn_post_nms_top_k),
+ rpn_nms_threshold=(self._rpn_nms_threshold
+ if is_training else self._test_rpn_nms_threshold),
+ rpn_score_threshold=(self._rpn_score_threshold if is_training else
+ self._test_rpn_score_threshold),
+ rpn_min_size_threshold=(self._rpn_min_size_threshold if is_training else
+ self._test_rpn_min_size_threshold),
+ decode_boxes=True,
+ clip_boxes=True,
+ use_batched_nms=self._use_batched_nms,
+ apply_sigmoid_to_score=True)
+ return proposed_rois, proposed_roi_scores
+
+
+class OlnROIGenerator(ROIGenerator):
+ """Proposes RoIs for the second stage processing."""
+
+ def __call__(self, boxes, scores, anchor_boxes, image_shape, is_training,
+ is_box_lrtb=False, object_scores=None):
+ """Generates RoI proposals.
+
+ Args:
+ boxes: a dict with keys representing FPN levels and values representing
+ box tenors of shape [batch_size, feature_h, feature_w, num_anchors * 4].
+ scores: a dict with keys representing FPN levels and values representing
+ logit tensors of shape [batch_size, feature_h, feature_w, num_anchors].
+ anchor_boxes: a dict with keys representing FPN levels and values
+ representing anchor box tensors of shape [batch_size, feature_h,
+ feature_w, num_anchors * 4].
+ image_shape: a tensor of shape [batch_size, 2] where the last dimension
+ are [height, width] of the scaled image.
+ is_training: a bool indicating whether it is in training or inference
+ mode.
+ is_box_lrtb: a bool indicating whether boxes are in lrtb (=left,right,top,
+ bottom) format.
+ object_scores: another objectness score (e.g., centerness). In OLN, we use
+ object_scores=centerness as a replacement of the scores at each level.
+ A dict with keys representing FPN levels and values representing logit
+ tensors of shape [batch_size, feature_h, feature_w, num_anchors].
+
+ Returns:
+ proposed_rois: a tensor of shape [batch_size, rpn_post_nms_top_k, 4],
+ representing the box coordinates of the proposed RoIs w.r.t. the
+ scaled image.
+ proposed_roi_scores: a tensor of shape
+ [batch_size, rpn_post_nms_top_k, 1], representing the scores of the
+ proposed RoIs.
+
+ """
+ proposed_rois, proposed_roi_scores = self.oln_multilevel_propose_rois(
+ boxes,
+ scores,
+ anchor_boxes,
+ image_shape,
+ rpn_pre_nms_top_k=(self._rpn_pre_nms_top_k
+ if is_training else self._test_rpn_pre_nms_top_k),
+ rpn_post_nms_top_k=(self._rpn_post_nms_top_k
+ if is_training else self._test_rpn_post_nms_top_k),
+ rpn_nms_threshold=(self._rpn_nms_threshold
+ if is_training else self._test_rpn_nms_threshold),
+ rpn_score_threshold=(self._rpn_score_threshold if is_training else
+ self._test_rpn_score_threshold),
+ rpn_min_size_threshold=(self._rpn_min_size_threshold if is_training else
+ self._test_rpn_min_size_threshold),
+ decode_boxes=True,
+ clip_boxes=True,
+ use_batched_nms=self._use_batched_nms,
+ apply_sigmoid_to_score=True,
+ is_box_lrtb=is_box_lrtb,
+ rpn_object_scores=object_scores,)
+ return proposed_rois, proposed_roi_scores
+
+ def oln_multilevel_propose_rois(self,
+ rpn_boxes,
+ rpn_scores,
+ anchor_boxes,
+ image_shape,
+ rpn_pre_nms_top_k=2000,
+ rpn_post_nms_top_k=1000,
+ rpn_nms_threshold=0.7,
+ rpn_score_threshold=0.0,
+ rpn_min_size_threshold=0.0,
+ decode_boxes=True,
+ clip_boxes=True,
+ use_batched_nms=False,
+ apply_sigmoid_to_score=True,
+ is_box_lrtb=False,
+ rpn_object_scores=None,):
+ """Proposes RoIs given a group of candidates from different FPN levels.
+
+ The following describes the steps:
+ 1. For each individual level:
+ a. Adjust scores for each level if specified by rpn_object_scores.
+ b. Apply sigmoid transform if specified.
+ c. Decode boxes (either of xyhw or left-right-top-bottom format) if
+ specified.
+ d. Clip boxes if specified.
+ e. Filter small boxes and those fall outside image if specified.
+ f. Apply pre-NMS filtering including pre-NMS top k and score
+ thresholding.
+ g. Apply NMS.
+ 2. Aggregate post-NMS boxes from each level.
+ 3. Apply an overall top k to generate the final selected RoIs.
+
+ Args:
+ rpn_boxes: a dict with keys representing FPN levels and values
+ representing box tenors of shape [batch_size, feature_h, feature_w,
+ num_anchors * 4].
+ rpn_scores: a dict with keys representing FPN levels and values
+ representing logit tensors of shape [batch_size, feature_h, feature_w,
+ num_anchors].
+ anchor_boxes: a dict with keys representing FPN levels and values
+ representing anchor box tensors of shape [batch_size, feature_h,
+ feature_w, num_anchors * 4].
+ image_shape: a tensor of shape [batch_size, 2] where the last dimension
+ are [height, width] of the scaled image.
+ rpn_pre_nms_top_k: an integer of top scoring RPN proposals *per level* to
+ keep before applying NMS. Default: 2000.
+ rpn_post_nms_top_k: an integer of top scoring RPN proposals *in total* to
+ keep after applying NMS. Default: 1000.
+ rpn_nms_threshold: a float between 0 and 1 representing the IoU threshold
+ used for NMS. If 0.0, no NMS is applied. Default: 0.7.
+ rpn_score_threshold: a float between 0 and 1 representing the minimal box
+ score to keep before applying NMS. This is often used as a pre-filtering
+ step for better performance. If 0, no filtering is applied. Default: 0.
+ rpn_min_size_threshold: a float representing the minimal box size in each
+ side (w.r.t. the scaled image) to keep before applying NMS. This is
+ often used as a pre-filtering step for better performance. If 0, no
+ filtering is applied. Default: 0.
+ decode_boxes: a boolean indicating whether `rpn_boxes` needs to be decoded
+ using `anchor_boxes`. If False, use `rpn_boxes` directly and ignore
+ `anchor_boxes`. Default: True.
+ clip_boxes: a boolean indicating whether boxes are first clipped to the
+ scaled image size before appliying NMS. If False, no clipping is applied
+ and `image_shape` is ignored. Default: True.
+ use_batched_nms: a boolean indicating whether NMS is applied in batch
+ using `tf.image.combined_non_max_suppression`. Currently only available
+ in CPU/GPU. Default: False.
+ apply_sigmoid_to_score: a boolean indicating whether apply sigmoid to
+ `rpn_scores` before applying NMS. Default: True.
+ is_box_lrtb: a bool indicating whether boxes are in lrtb (=left,right,top,
+ bottom) format.
+ rpn_object_scores: a predicted objectness score (e.g., centerness). In
+ OLN, we use object_scores=centerness as a replacement of the scores at
+ each level. A dict with keys representing FPN levels and values
+ representing logit tensors of shape [batch_size, feature_h, feature_w,
+ num_anchors].
+
+ Returns:
+ selected_rois: a tensor of shape [batch_size, rpn_post_nms_top_k, 4],
+ representing the box coordinates of the selected proposals w.r.t. the
+ scaled image.
+ selected_roi_scores: a tensor of shape [batch_size, rpn_post_nms_top_k,
+ 1],representing the scores of the selected proposals.
+ """
+ with tf.name_scope('multilevel_propose_rois'):
+ rois = []
+ roi_scores = []
+ image_shape = tf.expand_dims(image_shape, axis=1)
+ for level in sorted(rpn_scores.keys()):
+ with tf.name_scope('level_%d' % level):
+ _, feature_h, feature_w, num_anchors_per_location = (
+ rpn_scores[level].get_shape().as_list())
+
+ num_boxes = feature_h * feature_w * num_anchors_per_location
+ this_level_scores = tf.reshape(rpn_scores[level], [-1, num_boxes])
+ this_level_boxes = tf.reshape(rpn_boxes[level], [-1, num_boxes, 4])
+ this_level_anchors = tf.cast(
+ tf.reshape(anchor_boxes[level], [-1, num_boxes, 4]),
+ dtype=this_level_scores.dtype)
+
+ if rpn_object_scores:
+ this_level_object_scores = rpn_object_scores[level]
+ this_level_object_scores = tf.reshape(this_level_object_scores,
+ [-1, num_boxes])
+ this_level_object_scores = tf.cast(this_level_object_scores,
+ this_level_scores.dtype)
+ this_level_scores = this_level_object_scores
+
+ if apply_sigmoid_to_score:
+ this_level_scores = tf.sigmoid(this_level_scores)
+
+ if decode_boxes:
+ if is_box_lrtb: # Box in left-right-top-bottom format.
+ this_level_boxes = box_utils.decode_boxes_lrtb(
+ this_level_boxes, this_level_anchors)
+ else: # Box in standard x-y-h-w format.
+ this_level_boxes = box_utils.decode_boxes(
+ this_level_boxes, this_level_anchors)
+
+ if clip_boxes:
+ this_level_boxes = box_utils.clip_boxes(
+ this_level_boxes, image_shape)
+
+ if rpn_min_size_threshold > 0.0:
+ this_level_boxes, this_level_scores = box_utils.filter_boxes(
+ this_level_boxes, this_level_scores, image_shape,
+ rpn_min_size_threshold)
+
+ this_level_pre_nms_top_k = min(num_boxes, rpn_pre_nms_top_k)
+ this_level_post_nms_top_k = min(num_boxes, rpn_post_nms_top_k)
+ if rpn_nms_threshold > 0.0:
+ if use_batched_nms:
+ this_level_rois, this_level_roi_scores, _, _ = (
+ tf.image.combined_non_max_suppression(
+ tf.expand_dims(this_level_boxes, axis=2),
+ tf.expand_dims(this_level_scores, axis=-1),
+ max_output_size_per_class=this_level_pre_nms_top_k,
+ max_total_size=this_level_post_nms_top_k,
+ iou_threshold=rpn_nms_threshold,
+ score_threshold=rpn_score_threshold,
+ pad_per_class=False,
+ clip_boxes=False))
+ else:
+ if rpn_score_threshold > 0.0:
+ this_level_boxes, this_level_scores = (
+ box_utils.filter_boxes_by_scores(this_level_boxes,
+ this_level_scores,
+ rpn_score_threshold))
+ this_level_boxes, this_level_scores = box_utils.top_k_boxes(
+ this_level_boxes, this_level_scores,
+ k=this_level_pre_nms_top_k)
+ this_level_roi_scores, this_level_rois = (
+ nms.sorted_non_max_suppression_padded(
+ this_level_scores,
+ this_level_boxes,
+ max_output_size=this_level_post_nms_top_k,
+ iou_threshold=rpn_nms_threshold))
+ else:
+ this_level_rois, this_level_roi_scores = box_utils.top_k_boxes(
+ this_level_rois, this_level_scores, k=this_level_post_nms_top_k)
+
+ rois.append(this_level_rois)
+ roi_scores.append(this_level_roi_scores)
+
+ all_rois = tf.concat(rois, axis=1)
+ all_roi_scores = tf.concat(roi_scores, axis=1)
+
+ with tf.name_scope('top_k_rois'):
+ _, num_valid_rois = all_roi_scores.get_shape().as_list()
+ overall_top_k = min(num_valid_rois, rpn_post_nms_top_k)
+
+ selected_rois, selected_roi_scores = box_utils.top_k_boxes(
+ all_rois, all_roi_scores, k=overall_top_k)
+
+ return selected_rois, selected_roi_scores
diff --git a/modeling/official/legacy/detection/ops/spatial_transform_ops.py b/modeling/official/legacy/detection/ops/spatial_transform_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..70d3fac496cac41294fb67ea0ccbbe627839a938
--- /dev/null
+++ b/modeling/official/legacy/detection/ops/spatial_transform_ops.py
@@ -0,0 +1,603 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions to performa spatial transformation for Tensor."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf, tf_keras
+
+_EPSILON = 1e-8
+
+
+def nearest_upsampling(data, scale):
+ """Nearest neighbor upsampling implementation.
+
+ Args:
+ data: A tensor with a shape of [batch, height_in, width_in, channels].
+ scale: An integer multiple to scale resolution of input data.
+
+ Returns:
+ data_up: A tensor with a shape of
+ [batch, height_in*scale, width_in*scale, channels]. Same dtype as input
+ data.
+ """
+ with tf.name_scope('nearest_upsampling'):
+ bs, _, _, c = data.get_shape().as_list()
+ shape = tf.shape(input=data)
+ h = shape[1]
+ w = shape[2]
+ bs = -1 if bs is None else bs
+ # Uses reshape to quickly upsample the input. The nearest pixel is selected
+ # implicitly via broadcasting.
+ data = tf.reshape(data, [bs, h, 1, w, 1, c]) * tf.ones(
+ [1, 1, scale, 1, scale, 1], dtype=data.dtype)
+ return tf.reshape(data, [bs, h * scale, w * scale, c])
+
+
+def feature_bilinear_interpolation(features, kernel_y, kernel_x):
+ """Feature bilinear interpolation.
+
+ The RoIAlign feature f can be computed by bilinear interpolation
+ of four neighboring feature points f0, f1, f2, and f3.
+
+ f(y, x) = [hy, ly] * [[f00, f01], * [hx, lx]^T
+ [f10, f11]]
+ f(y, x) = (hy*hx)f00 + (hy*lx)f01 + (ly*hx)f10 + (lx*ly)f11
+ f(y, x) = w00*f00 + w01*f01 + w10*f10 + w11*f11
+ kernel_y = [hy, ly]
+ kernel_x = [hx, lx]
+
+ Args:
+ features: The features are in shape of [batch_size, num_boxes, output_size *
+ 2, output_size * 2, num_filters].
+ kernel_y: Tensor of size [batch_size, boxes, output_size, 2, 1].
+ kernel_x: Tensor of size [batch_size, boxes, output_size, 2, 1].
+
+ Returns:
+ A 5-D tensor representing feature crop of shape
+ [batch_size, num_boxes, output_size, output_size, num_filters].
+
+ """
+ (batch_size, num_boxes, output_size, _,
+ num_filters) = features.get_shape().as_list()
+ output_size = output_size // 2
+ kernel_y = tf.reshape(kernel_y, [batch_size, num_boxes, output_size * 2, 1])
+ kernel_x = tf.reshape(kernel_x, [batch_size, num_boxes, 1, output_size * 2])
+ # Use implicit broadcast to generate the interpolation kernel. The
+ # multiplier `4` is for avg pooling.
+ interpolation_kernel = kernel_y * kernel_x * 4
+
+ # Interpolate the gathered features with computed interpolation kernels.
+ features *= tf.cast(
+ tf.expand_dims(interpolation_kernel, axis=-1), dtype=features.dtype)
+ features = tf.reshape(
+ features,
+ [batch_size * num_boxes, output_size * 2, output_size * 2, num_filters])
+ features = tf.nn.avg_pool(features, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID')
+ features = tf.reshape(
+ features, [batch_size, num_boxes, output_size, output_size, num_filters])
+ return features
+
+
+def compute_grid_positions(boxes, boundaries, output_size, sample_offset):
+ """Compute the grid position w.r.t.
+
+ the corresponding feature map.
+
+ Args:
+ boxes: a 3-D tensor of shape [batch_size, num_boxes, 4] encoding the
+ information of each box w.r.t. the corresponding feature map.
+ boxes[:, :, 0:2] are the grid position in (y, x) (float) of the top-left
+ corner of each box. boxes[:, :, 2:4] are the box sizes in (h, w) (float)
+ in terms of the number of pixels of the corresponding feature map size.
+ boundaries: a 3-D tensor of shape [batch_size, num_boxes, 2] representing
+ the boundary (in (y, x)) of the corresponding feature map for each box.
+ Any resampled grid points that go beyond the bounary will be clipped.
+ output_size: a scalar indicating the output crop size.
+ sample_offset: a float number in [0, 1] indicates the subpixel sample offset
+ from grid point.
+
+ Returns:
+ kernel_y: Tensor of size [batch_size, boxes, output_size, 2, 1].
+ kernel_x: Tensor of size [batch_size, boxes, output_size, 2, 1].
+ box_grid_y0y1: Tensor of size [batch_size, boxes, output_size, 2]
+ box_grid_x0x1: Tensor of size [batch_size, boxes, output_size, 2]
+ """
+ batch_size, num_boxes, _ = boxes.get_shape().as_list()
+ box_grid_x = []
+ box_grid_y = []
+ for i in range(output_size):
+ box_grid_x.append(boxes[:, :, 1] +
+ (i + sample_offset) * boxes[:, :, 3] / output_size)
+ box_grid_y.append(boxes[:, :, 0] +
+ (i + sample_offset) * boxes[:, :, 2] / output_size)
+ box_grid_x = tf.stack(box_grid_x, axis=2)
+ box_grid_y = tf.stack(box_grid_y, axis=2)
+
+ box_grid_y0 = tf.floor(box_grid_y)
+ box_grid_x0 = tf.floor(box_grid_x)
+ box_grid_x0 = tf.maximum(0., box_grid_x0)
+ box_grid_y0 = tf.maximum(0., box_grid_y0)
+
+ box_grid_x0 = tf.minimum(box_grid_x0, tf.expand_dims(boundaries[:, :, 1], -1))
+ box_grid_x1 = tf.minimum(box_grid_x0 + 1,
+ tf.expand_dims(boundaries[:, :, 1], -1))
+ box_grid_y0 = tf.minimum(box_grid_y0, tf.expand_dims(boundaries[:, :, 0], -1))
+ box_grid_y1 = tf.minimum(box_grid_y0 + 1,
+ tf.expand_dims(boundaries[:, :, 0], -1))
+
+ box_gridx0x1 = tf.stack([box_grid_x0, box_grid_x1], axis=-1)
+ box_gridy0y1 = tf.stack([box_grid_y0, box_grid_y1], axis=-1)
+
+ # The RoIAlign feature f can be computed by bilinear interpolation of four
+ # neighboring feature points f0, f1, f2, and f3.
+ # f(y, x) = [hy, ly] * [[f00, f01], * [hx, lx]^T
+ # [f10, f11]]
+ # f(y, x) = (hy*hx)f00 + (hy*lx)f01 + (ly*hx)f10 + (lx*ly)f11
+ # f(y, x) = w00*f00 + w01*f01 + w10*f10 + w11*f11
+ ly = box_grid_y - box_grid_y0
+ lx = box_grid_x - box_grid_x0
+ hy = 1.0 - ly
+ hx = 1.0 - lx
+ kernel_y = tf.reshape(
+ tf.stack([hy, ly], axis=3), [batch_size, num_boxes, output_size, 2, 1])
+ kernel_x = tf.reshape(
+ tf.stack([hx, lx], axis=3), [batch_size, num_boxes, output_size, 2, 1])
+ return kernel_y, kernel_x, box_gridy0y1, box_gridx0x1
+
+
+def get_grid_one_hot(box_gridy0y1, box_gridx0x1, feature_height, feature_width):
+ """Get grid_one_hot from indices and feature_size."""
+ (batch_size, num_boxes, output_size, _) = box_gridx0x1.get_shape().as_list()
+ y_indices = tf.cast(
+ tf.reshape(box_gridy0y1, [batch_size, num_boxes, output_size, 2]),
+ dtype=tf.int32)
+ x_indices = tf.cast(
+ tf.reshape(box_gridx0x1, [batch_size, num_boxes, output_size, 2]),
+ dtype=tf.int32)
+
+ # shape is [batch_size, num_boxes, output_size, 2, height]
+ grid_y_one_hot = tf.one_hot(tf.cast(y_indices, tf.int32), feature_height)
+ # shape is [batch_size, num_boxes, output_size, 2, width]
+ grid_x_one_hot = tf.one_hot(tf.cast(x_indices, tf.int32), feature_width)
+
+ return grid_y_one_hot, grid_x_one_hot
+
+
+def selective_crop_and_resize(features,
+ boxes,
+ box_levels,
+ boundaries,
+ output_size=7,
+ sample_offset=0.5,
+ use_einsum_gather=False):
+ """Crop and resize boxes on a set of feature maps.
+
+ Given multiple features maps indexed by different levels, and a set of boxes
+ where each box is mapped to a certain level, it selectively crops and resizes
+ boxes from the corresponding feature maps to generate the box features.
+
+ We follow the ROIAlign technique (see https://arxiv.org/pdf/1703.06870.pdf,
+ figure 3 for reference). Specifically, for each feature map, we select an
+ (output_size, output_size) set of pixels corresponding to the box location,
+ and then use bilinear interpolation to select the feature value for each
+ pixel.
+
+ For performance, we perform the gather and interpolation on all layers as a
+ single operation. In this op the multi-level features are first stacked and
+ gathered into [2*output_size, 2*output_size] feature points. Then bilinear
+ interpolation is performed on the gathered feature points to generate
+ [output_size, output_size] RoIAlign feature map.
+
+ Here is the step-by-step algorithm:
+ 1. The multi-level features are gathered into a
+ [batch_size, num_boxes, output_size*2, output_size*2, num_filters]
+ Tensor. The Tensor contains four neighboring feature points for each
+ vertice in the output grid.
+ 2. Compute the interpolation kernel of shape
+ [batch_size, num_boxes, output_size*2, output_size*2]. The last 2 axis
+ can be seen as stacking 2x2 interpolation kernels for all vertices in the
+ output grid.
+ 3. Element-wise multiply the gathered features and interpolation kernel.
+ Then apply 2x2 average pooling to reduce spatial dimension to
+ output_size.
+
+ Args:
+ features: a 5-D tensor of shape [batch_size, num_levels, max_height,
+ max_width, num_filters] where cropping and resizing are based.
+ boxes: a 3-D tensor of shape [batch_size, num_boxes, 4] encoding the
+ information of each box w.r.t. the corresponding feature map.
+ boxes[:, :, 0:2] are the grid position in (y, x) (float) of the top-left
+ corner of each box. boxes[:, :, 2:4] are the box sizes in (h, w) (float)
+ in terms of the number of pixels of the corresponding feature map size.
+ box_levels: a 3-D tensor of shape [batch_size, num_boxes, 1] representing
+ the 0-based corresponding feature level index of each box.
+ boundaries: a 3-D tensor of shape [batch_size, num_boxes, 2] representing
+ the boundary (in (y, x)) of the corresponding feature map for each box.
+ Any resampled grid points that go beyond the bounary will be clipped.
+ output_size: a scalar indicating the output crop size.
+ sample_offset: a float number in [0, 1] indicates the subpixel sample offset
+ from grid point.
+ use_einsum_gather: use einsum to replace gather or not. Replacing einsum
+ with gather can improve performance when feature size is not large, einsum
+ is friendly with model partition as well. Gather's performance is better
+ when feature size is very large and there are multiple box levels.
+
+ Returns:
+ features_per_box: a 5-D tensor of shape
+ [batch_size, num_boxes, output_size, output_size, num_filters]
+ representing the cropped features.
+ """
+ (batch_size, num_levels, max_feature_height, max_feature_width,
+ num_filters) = features.get_shape().as_list()
+ _, num_boxes, _ = boxes.get_shape().as_list()
+
+ kernel_y, kernel_x, box_gridy0y1, box_gridx0x1 = compute_grid_positions(
+ boxes, boundaries, output_size, sample_offset)
+ x_indices = tf.cast(
+ tf.reshape(box_gridx0x1, [batch_size, num_boxes, output_size * 2]),
+ dtype=tf.int32)
+ y_indices = tf.cast(
+ tf.reshape(box_gridy0y1, [batch_size, num_boxes, output_size * 2]),
+ dtype=tf.int32)
+
+ if use_einsum_gather:
+ # Blinear interpolation is done during the last two gathers:
+ # f(y, x) = [hy, ly] * [[f00, f01], * [hx, lx]^T
+ # [f10, f11]]
+ # [[f00, f01],
+ # [f10, f11]] = tf.einsum(tf.einsum(features, y_one_hot), x_one_hot)
+ # where [hy, ly] and [hx, lx] are the bilinear interpolation kernel.
+
+ # shape is [batch_size, boxes, output_size, 2, 1]
+ grid_y_one_hot, grid_x_one_hot = get_grid_one_hot(box_gridy0y1,
+ box_gridx0x1,
+ max_feature_height,
+ max_feature_width)
+
+ # shape is [batch_size, num_boxes, output_size, height]
+ grid_y_weight = tf.reduce_sum(
+ tf.multiply(grid_y_one_hot, kernel_y), axis=-2)
+ # shape is [batch_size, num_boxes, output_size, width]
+ grid_x_weight = tf.reduce_sum(
+ tf.multiply(grid_x_one_hot, kernel_x), axis=-2)
+
+ # Gather for y_axis.
+ # shape is [batch_size, num_boxes, output_size, width, features]
+ features_per_box = tf.einsum('bmhwf,bmoh->bmowf', features,
+ tf.cast(grid_y_weight, features.dtype))
+ # Gather for x_axis.
+ # shape is [batch_size, num_boxes, output_size, output_size, features]
+ features_per_box = tf.einsum('bmhwf,bmow->bmhof', features_per_box,
+ tf.cast(grid_x_weight, features.dtype))
+ else:
+ height_dim_offset = max_feature_width
+ level_dim_offset = max_feature_height * height_dim_offset
+ batch_dim_offset = num_levels * level_dim_offset
+
+ batch_size_offset = tf.tile(
+ tf.reshape(
+ tf.range(batch_size) * batch_dim_offset, [batch_size, 1, 1, 1]),
+ [1, num_boxes, output_size * 2, output_size * 2])
+ box_levels_offset = tf.tile(
+ tf.reshape(box_levels * level_dim_offset,
+ [batch_size, num_boxes, 1, 1]),
+ [1, 1, output_size * 2, output_size * 2])
+ y_indices_offset = tf.tile(
+ tf.reshape(y_indices * height_dim_offset,
+ [batch_size, num_boxes, output_size * 2, 1]),
+ [1, 1, 1, output_size * 2])
+ x_indices_offset = tf.tile(
+ tf.reshape(x_indices, [batch_size, num_boxes, 1, output_size * 2]),
+ [1, 1, output_size * 2, 1])
+
+ indices = tf.reshape(
+ batch_size_offset + box_levels_offset + y_indices_offset +
+ x_indices_offset, [-1])
+
+ features = tf.reshape(features, [-1, num_filters])
+ # TODO(wangtao): replace tf.gather with tf.gather_nd and try to get similar
+ # performance.
+ features_per_box = tf.reshape(
+ tf.gather(features, indices),
+ [batch_size, num_boxes, output_size * 2, output_size * 2, num_filters])
+ features_per_box = feature_bilinear_interpolation(features_per_box,
+ kernel_y, kernel_x)
+
+ return features_per_box
+
+
+def multilevel_crop_and_resize(features, boxes, output_size=7):
+ """Crop and resize on multilevel feature pyramid.
+
+ Generate the (output_size, output_size) set of pixels for each input box
+ by first locating the box into the correct feature level, and then cropping
+ and resizing it using the correspoding feature map of that level.
+
+ Args:
+ features: A dictionary with key as pyramid level and value as features. The
+ features are in shape of [batch_size, height_l, width_l, num_filters].
+ boxes: A 3-D Tensor of shape [batch_size, num_boxes, 4]. Each row represents
+ a box with [y1, x1, y2, x2] in un-normalized coordinates.
+ output_size: A scalar to indicate the output crop size.
+
+ Returns:
+ A 5-D tensor representing feature crop of shape
+ [batch_size, num_boxes, output_size, output_size, num_filters].
+ """
+
+ with tf.name_scope('multilevel_crop_and_resize'):
+ levels = list(features.keys())
+ min_level = min(levels)
+ max_level = max(levels)
+ batch_size, max_feature_height, max_feature_width, num_filters = (
+ features[min_level].get_shape().as_list())
+ _, num_boxes, _ = boxes.get_shape().as_list()
+
+ # Stack feature pyramid into a features_all of shape
+ # [batch_size, levels, height, width, num_filters].
+ features_all = []
+ feature_heights = []
+ feature_widths = []
+ for level in range(min_level, max_level + 1):
+ shape = features[level].get_shape().as_list()
+ feature_heights.append(shape[1])
+ feature_widths.append(shape[2])
+ # Concat tensor of [batch_size, height_l * width_l, num_filters] for each
+ # levels.
+ features_all.append(
+ tf.reshape(features[level], [batch_size, -1, num_filters]))
+ features_r2 = tf.reshape(tf.concat(features_all, 1), [-1, num_filters])
+
+ # Calculate height_l * width_l for each level.
+ level_dim_sizes = [
+ feature_widths[i] * feature_heights[i]
+ for i in range(len(feature_widths))
+ ]
+ # level_dim_offsets is accumulated sum of level_dim_size.
+ level_dim_offsets = [0]
+ for i in range(len(feature_widths) - 1):
+ level_dim_offsets.append(level_dim_offsets[i] + level_dim_sizes[i])
+ batch_dim_size = level_dim_offsets[-1] + level_dim_sizes[-1]
+ level_dim_offsets = tf.constant(level_dim_offsets, tf.int32)
+ height_dim_sizes = tf.constant(feature_widths, tf.int32)
+
+ # Assigns boxes to the right level.
+ box_width = boxes[:, :, 3] - boxes[:, :, 1]
+ box_height = boxes[:, :, 2] - boxes[:, :, 0]
+ areas_sqrt = tf.sqrt(box_height * box_width)
+ levels = tf.cast(
+ tf.math.floordiv(
+ tf.math.log(tf.divide(areas_sqrt, 224.0)), tf.math.log(2.0)) + 4.0,
+ dtype=tf.int32)
+ # Maps levels between [min_level, max_level].
+ levels = tf.minimum(max_level, tf.maximum(levels, min_level))
+
+ # Projects box location and sizes to corresponding feature levels.
+ scale_to_level = tf.cast(
+ tf.pow(tf.constant(2.0), tf.cast(levels, tf.float32)),
+ dtype=boxes.dtype)
+ boxes /= tf.expand_dims(scale_to_level, axis=2)
+ box_width /= scale_to_level
+ box_height /= scale_to_level
+ boxes = tf.concat([
+ boxes[:, :, 0:2],
+ tf.expand_dims(box_height, -1),
+ tf.expand_dims(box_width, -1)
+ ],
+ axis=-1)
+
+ # Maps levels to [0, max_level-min_level].
+ levels -= min_level
+ level_strides = tf.pow([[2.0]], tf.cast(levels, tf.float32))
+ boundary = tf.cast(
+ tf.concat([
+ tf.expand_dims(
+ [[tf.cast(max_feature_height, tf.float32)]] / level_strides - 1,
+ axis=-1),
+ tf.expand_dims(
+ [[tf.cast(max_feature_width, tf.float32)]] / level_strides - 1,
+ axis=-1),
+ ],
+ axis=-1), boxes.dtype)
+
+ # Compute grid positions.
+ kernel_y, kernel_x, box_gridy0y1, box_gridx0x1 = compute_grid_positions(
+ boxes, boundary, output_size, sample_offset=0.5)
+
+ x_indices = tf.cast(
+ tf.reshape(box_gridx0x1, [batch_size, num_boxes, output_size * 2]),
+ dtype=tf.int32)
+ y_indices = tf.cast(
+ tf.reshape(box_gridy0y1, [batch_size, num_boxes, output_size * 2]),
+ dtype=tf.int32)
+
+ batch_size_offset = tf.tile(
+ tf.reshape(
+ tf.range(batch_size) * batch_dim_size, [batch_size, 1, 1, 1]),
+ [1, num_boxes, output_size * 2, output_size * 2])
+ # Get level offset for each box. Each box belongs to one level.
+ levels_offset = tf.tile(
+ tf.reshape(
+ tf.gather(level_dim_offsets, levels),
+ [batch_size, num_boxes, 1, 1]),
+ [1, 1, output_size * 2, output_size * 2])
+ y_indices_offset = tf.tile(
+ tf.reshape(
+ y_indices * tf.expand_dims(tf.gather(height_dim_sizes, levels), -1),
+ [batch_size, num_boxes, output_size * 2, 1]),
+ [1, 1, 1, output_size * 2])
+ x_indices_offset = tf.tile(
+ tf.reshape(x_indices, [batch_size, num_boxes, 1, output_size * 2]),
+ [1, 1, output_size * 2, 1])
+ indices = tf.reshape(
+ batch_size_offset + levels_offset + y_indices_offset + x_indices_offset,
+ [-1])
+
+ # TODO(wangtao): replace tf.gather with tf.gather_nd and try to get similar
+ # performance.
+ features_per_box = tf.reshape(
+ tf.gather(features_r2, indices),
+ [batch_size, num_boxes, output_size * 2, output_size * 2, num_filters])
+
+ # Bilinear interpolation.
+ features_per_box = feature_bilinear_interpolation(features_per_box,
+ kernel_y, kernel_x)
+ return features_per_box
+
+
+def single_level_feature_crop(features, level_boxes, detection_prior_levels,
+ min_mask_level, mask_crop_size):
+ """Crop the FPN features at the appropriate levels for each detection.
+
+
+ Args:
+ features: a float tensor of shape [batch_size, num_levels, max_feature_size,
+ max_feature_size, num_downsample_channels].
+ level_boxes: a float Tensor of the level boxes to crop from. [batch_size,
+ num_instances, 4].
+ detection_prior_levels: an int Tensor of instance assigned level of shape
+ [batch_size, num_instances].
+ min_mask_level: minimum FPN level to crop mask feature from.
+ mask_crop_size: an int of mask crop size.
+
+ Returns:
+ crop_features: a float Tensor of shape [batch_size * num_instances,
+ mask_crop_size, mask_crop_size, num_downsample_channels]. This is the
+ instance feature crop.
+ """
+ (batch_size, num_levels, max_feature_size, _,
+ num_downsample_channels) = features.get_shape().as_list()
+ _, num_of_instances, _ = level_boxes.get_shape().as_list()
+ level_boxes = tf.cast(level_boxes, tf.int32)
+ assert num_of_instances == detection_prior_levels.get_shape().as_list()[1]
+
+ x_start_indices = level_boxes[:, :, 1]
+ y_start_indices = level_boxes[:, :, 0]
+ # generate the full indices (not just the starting index)
+ x_idx_list = []
+ y_idx_list = []
+ for i in range(mask_crop_size):
+ x_idx_list.append(x_start_indices + i)
+ y_idx_list.append(y_start_indices + i)
+
+ x_indices = tf.stack(x_idx_list, axis=2)
+ y_indices = tf.stack(y_idx_list, axis=2)
+ levels = detection_prior_levels - min_mask_level
+ height_dim_size = max_feature_size
+ level_dim_size = max_feature_size * height_dim_size
+ batch_dim_size = num_levels * level_dim_size
+ # TODO(weicheng) change this to gather_nd for better readability.
+ indices = tf.reshape(
+ tf.tile(
+ tf.reshape(
+ tf.range(batch_size) * batch_dim_size, [batch_size, 1, 1, 1]),
+ [1, num_of_instances, mask_crop_size, mask_crop_size]) + tf.tile(
+ tf.reshape(levels * level_dim_size,
+ [batch_size, num_of_instances, 1, 1]),
+ [1, 1, mask_crop_size, mask_crop_size]) + tf.tile(
+ tf.reshape(y_indices * height_dim_size,
+ [batch_size, num_of_instances, mask_crop_size, 1]),
+ [1, 1, 1, mask_crop_size]) +
+ tf.tile(
+ tf.reshape(x_indices,
+ [batch_size, num_of_instances, 1, mask_crop_size]),
+ [1, 1, mask_crop_size, 1]), [-1])
+
+ features_r2 = tf.reshape(features, [-1, num_downsample_channels])
+ crop_features = tf.reshape(
+ tf.gather(features_r2, indices), [
+ batch_size * num_of_instances, mask_crop_size, mask_crop_size,
+ num_downsample_channels
+ ])
+
+ return crop_features
+
+
+def crop_mask_in_target_box(masks,
+ boxes,
+ target_boxes,
+ output_size,
+ sample_offset=0,
+ use_einsum=True):
+ """Crop masks in target boxes.
+
+ Args:
+ masks: A tensor with a shape of [batch_size, num_masks, height, width].
+ boxes: a float tensor representing box cooridnates that tightly enclose
+ masks with a shape of [batch_size, num_masks, 4] in un-normalized
+ coordinates. A box is represented by [ymin, xmin, ymax, xmax].
+ target_boxes: a float tensor representing target box cooridnates for masks
+ with a shape of [batch_size, num_masks, 4] in un-normalized coordinates. A
+ box is represented by [ymin, xmin, ymax, xmax].
+ output_size: A scalar to indicate the output crop size. It currently only
+ supports to output a square shape outputs.
+ sample_offset: a float number in [0, 1] indicates the subpixel sample offset
+ from grid point.
+ use_einsum: Use einsum to replace gather in selective_crop_and_resize.
+
+ Returns:
+ A 4-D tensor representing feature crop of shape
+ [batch_size, num_boxes, output_size, output_size].
+ """
+ with tf.name_scope('crop_mask_in_target_box'):
+ batch_size, num_masks, height, width = masks.get_shape().as_list()
+ masks = tf.reshape(masks, [batch_size * num_masks, height, width, 1])
+ # Pad zeros on the boundary of masks.
+ masks = tf.image.pad_to_bounding_box(masks, 2, 2, height + 4, width + 4)
+ masks = tf.reshape(masks, [batch_size, num_masks, height + 4, width + 4, 1])
+
+ # Projects target box locations and sizes to corresponding cropped
+ # mask coordinates.
+ gt_y_min, gt_x_min, gt_y_max, gt_x_max = tf.split(
+ value=boxes, num_or_size_splits=4, axis=2)
+ bb_y_min, bb_x_min, bb_y_max, bb_x_max = tf.split(
+ value=target_boxes, num_or_size_splits=4, axis=2)
+ y_transform = (bb_y_min - gt_y_min) * height / (gt_y_max - gt_y_min +
+ _EPSILON) + 2
+ x_transform = (bb_x_min - gt_x_min) * height / (gt_x_max - gt_x_min +
+ _EPSILON) + 2
+ h_transform = (bb_y_max - bb_y_min) * width / (
+ gt_y_max - gt_y_min + _EPSILON)
+ w_transform = (bb_x_max - bb_x_min) * width / (
+ gt_x_max - gt_x_min + _EPSILON)
+
+ boundaries = tf.concat([
+ tf.cast(
+ tf.ones_like(y_transform) * ((height + 4) - 1), dtype=tf.float32),
+ tf.cast(
+ tf.ones_like(x_transform) * ((width + 4) - 1), dtype=tf.float32)
+ ],
+ axis=-1)
+
+ # Reshape tensors to have the right shape for selective_crop_and_resize.
+ trasnformed_boxes = tf.concat(
+ [y_transform, x_transform, h_transform, w_transform], -1)
+ levels = tf.tile(
+ tf.reshape(tf.range(num_masks), [1, num_masks]), [batch_size, 1])
+
+ cropped_masks = selective_crop_and_resize(
+ masks,
+ trasnformed_boxes,
+ levels,
+ boundaries,
+ output_size,
+ sample_offset=sample_offset,
+ use_einsum_gather=use_einsum)
+ cropped_masks = tf.squeeze(cropped_masks, axis=-1)
+
+ return cropped_masks
diff --git a/modeling/official/legacy/detection/ops/target_ops.py b/modeling/official/legacy/detection/ops/target_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..689e6c3f65ceeb70cc09667d703f9f61570b0cbb
--- /dev/null
+++ b/modeling/official/legacy/detection/ops/target_ops.py
@@ -0,0 +1,571 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Target and sampling related ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.ops import spatial_transform_ops
+from official.legacy.detection.utils import box_utils
+from official.vision.utils.object_detection import balanced_positive_negative_sampler
+
+
+def box_matching(boxes, gt_boxes, gt_classes):
+ """Match boxes to groundtruth boxes.
+
+ Given the proposal boxes and the groundtruth boxes and classes, perform the
+ groundtruth matching by taking the argmax of the IoU between boxes and
+ groundtruth boxes.
+
+ Args:
+ boxes: a tensor of shape of [batch_size, N, 4] representing the box
+ coordiantes to be matched to groundtruth boxes.
+ gt_boxes: a tensor of shape of [batch_size, MAX_INSTANCES, 4] representing
+ the groundtruth box coordinates. It is padded with -1s to indicate the
+ invalid boxes.
+ gt_classes: [batch_size, MAX_INSTANCES] representing the groundtruth box
+ classes. It is padded with -1s to indicate the invalid classes.
+
+ Returns:
+ matched_gt_boxes: a tensor of shape of [batch_size, N, 4], representing
+ the matched groundtruth box coordinates for each input box. If the box
+ does not overlap with any groundtruth boxes, the matched boxes of it
+ will be set to all 0s.
+ matched_gt_classes: a tensor of shape of [batch_size, N], representing
+ the matched groundtruth classes for each input box. If the box does not
+ overlap with any groundtruth boxes, the matched box classes of it will
+ be set to 0, which corresponds to the background class.
+ matched_gt_indices: a tensor of shape of [batch_size, N], representing
+ the indices of the matched groundtruth boxes in the original gt_boxes
+ tensor. If the box does not overlap with any groundtruth boxes, the
+ index of the matched groundtruth will be set to -1.
+ matched_iou: a tensor of shape of [batch_size, N], representing the IoU
+ between the box and its matched groundtruth box. The matched IoU is the
+ maximum IoU of the box and all the groundtruth boxes.
+ iou: a tensor of shape of [batch_size, N, K], representing the IoU matrix
+ between boxes and the groundtruth boxes. The IoU between a box and the
+ invalid groundtruth boxes whose coordinates are [-1, -1, -1, -1] is -1.
+ """
+ # Compute IoU between boxes and gt_boxes.
+ # iou <- [batch_size, N, K]
+ iou = box_utils.bbox_overlap(boxes, gt_boxes)
+
+ # max_iou <- [batch_size, N]
+ # 0.0 -> no match to gt, or -1.0 match to no gt
+ matched_iou = tf.reduce_max(iou, axis=-1)
+
+ # background_box_mask <- bool, [batch_size, N]
+ background_box_mask = tf.less_equal(matched_iou, 0.0)
+
+ argmax_iou_indices = tf.argmax(iou, axis=-1, output_type=tf.int32)
+
+ argmax_iou_indices_shape = tf.shape(argmax_iou_indices)
+ batch_indices = (
+ tf.expand_dims(tf.range(argmax_iou_indices_shape[0]), axis=-1) *
+ tf.ones([1, argmax_iou_indices_shape[-1]], dtype=tf.int32))
+ gather_nd_indices = tf.stack([batch_indices, argmax_iou_indices], axis=-1)
+
+ matched_gt_boxes = tf.gather_nd(gt_boxes, gather_nd_indices)
+ matched_gt_boxes = tf.where(
+ tf.tile(tf.expand_dims(background_box_mask, axis=-1), [1, 1, 4]),
+ tf.zeros_like(matched_gt_boxes, dtype=matched_gt_boxes.dtype),
+ matched_gt_boxes)
+
+ matched_gt_classes = tf.gather_nd(gt_classes, gather_nd_indices)
+ matched_gt_classes = tf.where(background_box_mask,
+ tf.zeros_like(matched_gt_classes),
+ matched_gt_classes)
+
+ matched_gt_indices = tf.where(background_box_mask,
+ -tf.ones_like(argmax_iou_indices),
+ argmax_iou_indices)
+
+ return (matched_gt_boxes, matched_gt_classes, matched_gt_indices, matched_iou,
+ iou)
+
+
+def assign_and_sample_proposals(proposed_boxes,
+ gt_boxes,
+ gt_classes,
+ num_samples_per_image=512,
+ mix_gt_boxes=True,
+ fg_fraction=0.25,
+ fg_iou_thresh=0.5,
+ bg_iou_thresh_hi=0.5,
+ bg_iou_thresh_lo=0.0):
+ """Assigns the proposals with groundtruth classes and performs subsmpling.
+
+ Given `proposed_boxes`, `gt_boxes`, and `gt_classes`, the function uses the
+ following algorithm to generate the final `num_samples_per_image` RoIs.
+ 1. Calculates the IoU between each proposal box and each gt_boxes.
+ 2. Assigns each proposed box with a groundtruth class and box by choosing
+ the largest IoU overlap.
+ 3. Samples `num_samples_per_image` boxes from all proposed boxes, and
+ returns box_targets, class_targets, and RoIs.
+
+ Args:
+ proposed_boxes: a tensor of shape of [batch_size, N, 4]. N is the number of
+ proposals before groundtruth assignment. The last dimension is the box
+ coordinates w.r.t. the scaled images in [ymin, xmin, ymax, xmax] format.
+ gt_boxes: a tensor of shape of [batch_size, MAX_NUM_INSTANCES, 4]. The
+ coordinates of gt_boxes are in the pixel coordinates of the scaled image.
+ This tensor might have padding of values -1 indicating the invalid box
+ coordinates.
+ gt_classes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES]. This
+ tensor might have paddings with values of -1 indicating the invalid
+ classes.
+ num_samples_per_image: a integer represents RoI minibatch size per image.
+ mix_gt_boxes: a bool indicating whether to mix the groundtruth boxes before
+ sampling proposals.
+ fg_fraction: a float represents the target fraction of RoI minibatch that is
+ labeled foreground (i.e., class > 0).
+ fg_iou_thresh: a float represents the IoU overlap threshold for an RoI to be
+ considered foreground (if >= fg_iou_thresh).
+ bg_iou_thresh_hi: a float represents the IoU overlap threshold for an RoI to
+ be considered background (class = 0 if overlap in [LO, HI)).
+ bg_iou_thresh_lo: a float represents the IoU overlap threshold for an RoI to
+ be considered background (class = 0 if overlap in [LO, HI)).
+
+ Returns:
+ sampled_rois: a tensor of shape of [batch_size, K, 4], representing the
+ coordinates of the sampled RoIs, where K is the number of the sampled
+ RoIs, i.e. K = num_samples_per_image.
+ sampled_gt_boxes: a tensor of shape of [batch_size, K, 4], storing the
+ box coordinates of the matched groundtruth boxes of the samples RoIs.
+ sampled_gt_classes: a tensor of shape of [batch_size, K], storing the
+ classes of the matched groundtruth boxes of the sampled RoIs.
+ sampled_gt_indices: a tensor of shape of [batch_size, K], storing the
+ indices of the sampled groudntruth boxes in the original `gt_boxes`
+ tensor, i.e. gt_boxes[sampled_gt_indices[:, i]] = sampled_gt_boxes[:, i].
+ """
+
+ with tf.name_scope('sample_proposals'):
+ if mix_gt_boxes:
+ boxes = tf.concat([proposed_boxes, gt_boxes], axis=1)
+ else:
+ boxes = proposed_boxes
+
+ (matched_gt_boxes, matched_gt_classes, matched_gt_indices, matched_iou,
+ _) = box_matching(boxes, gt_boxes, gt_classes)
+
+ positive_match = tf.greater(matched_iou, fg_iou_thresh)
+ negative_match = tf.logical_and(
+ tf.greater_equal(matched_iou, bg_iou_thresh_lo),
+ tf.less(matched_iou, bg_iou_thresh_hi))
+ ignored_match = tf.less(matched_iou, 0.0)
+
+ # re-assign negatively matched boxes to the background class.
+ matched_gt_classes = tf.where(negative_match,
+ tf.zeros_like(matched_gt_classes),
+ matched_gt_classes)
+ matched_gt_indices = tf.where(negative_match,
+ tf.zeros_like(matched_gt_indices),
+ matched_gt_indices)
+
+ sample_candidates = tf.logical_and(
+ tf.logical_or(positive_match, negative_match),
+ tf.logical_not(ignored_match))
+
+ sampler = (
+ balanced_positive_negative_sampler.BalancedPositiveNegativeSampler(
+ positive_fraction=fg_fraction, is_static=True))
+
+ batch_size, _ = sample_candidates.get_shape().as_list()
+ sampled_indicators = []
+ for i in range(batch_size):
+ sampled_indicator = sampler.subsample(sample_candidates[i],
+ num_samples_per_image,
+ positive_match[i])
+ sampled_indicators.append(sampled_indicator)
+ sampled_indicators = tf.stack(sampled_indicators)
+ _, sampled_indices = tf.nn.top_k(
+ tf.cast(sampled_indicators, dtype=tf.int32),
+ k=num_samples_per_image,
+ sorted=True)
+
+ sampled_indices_shape = tf.shape(sampled_indices)
+ batch_indices = (
+ tf.expand_dims(tf.range(sampled_indices_shape[0]), axis=-1) *
+ tf.ones([1, sampled_indices_shape[-1]], dtype=tf.int32))
+ gather_nd_indices = tf.stack([batch_indices, sampled_indices], axis=-1)
+
+ sampled_rois = tf.gather_nd(boxes, gather_nd_indices)
+ sampled_gt_boxes = tf.gather_nd(matched_gt_boxes, gather_nd_indices)
+ sampled_gt_classes = tf.gather_nd(matched_gt_classes, gather_nd_indices)
+ sampled_gt_indices = tf.gather_nd(matched_gt_indices, gather_nd_indices)
+
+ return (sampled_rois, sampled_gt_boxes, sampled_gt_classes,
+ sampled_gt_indices)
+
+
+def sample_and_crop_foreground_masks(candidate_rois,
+ candidate_gt_boxes,
+ candidate_gt_classes,
+ candidate_gt_indices,
+ gt_masks,
+ num_mask_samples_per_image=128,
+ mask_target_size=28):
+ """Samples and creates cropped foreground masks for training.
+
+ Args:
+ candidate_rois: a tensor of shape of [batch_size, N, 4], where N is the
+ number of candidate RoIs to be considered for mask sampling. It includes
+ both positive and negative RoIs. The `num_mask_samples_per_image` positive
+ RoIs will be sampled to create mask training targets.
+ candidate_gt_boxes: a tensor of shape of [batch_size, N, 4], storing the
+ corresponding groundtruth boxes to the `candidate_rois`.
+ candidate_gt_classes: a tensor of shape of [batch_size, N], storing the
+ corresponding groundtruth classes to the `candidate_rois`. 0 in the tensor
+ corresponds to the background class, i.e. negative RoIs.
+ candidate_gt_indices: a tensor of shape [batch_size, N], storing the
+ corresponding groundtruth instance indices to the `candidate_gt_boxes`,
+ i.e. gt_boxes[candidate_gt_indices[:, i]] = candidate_gt_boxes[:, i] and
+ gt_boxes which is of shape [batch_size, MAX_INSTANCES, 4], M >= N, is
+ the superset of candidate_gt_boxes.
+ gt_masks: a tensor of [batch_size, MAX_INSTANCES, mask_height, mask_width]
+ containing all the groundtruth masks which sample masks are drawn from.
+ num_mask_samples_per_image: an integer which specifies the number of masks
+ to sample.
+ mask_target_size: an integer which specifies the final cropped mask size
+ after sampling. The output masks are resized w.r.t the sampled RoIs.
+
+ Returns:
+ foreground_rois: a tensor of shape of [batch_size, K, 4] storing the RoI
+ that corresponds to the sampled foreground masks, where
+ K = num_mask_samples_per_image.
+ foreground_classes: a tensor of shape of [batch_size, K] storing the classes
+ corresponding to the sampled foreground masks.
+ cropoped_foreground_masks: a tensor of shape of
+ [batch_size, K, mask_target_size, mask_target_size] storing the cropped
+ foreground masks used for training.
+ """
+ with tf.name_scope('sample_and_crop_foreground_masks'):
+ _, fg_instance_indices = tf.nn.top_k(
+ tf.cast(tf.greater(candidate_gt_classes, 0), dtype=tf.int32),
+ k=num_mask_samples_per_image)
+
+ fg_instance_indices_shape = tf.shape(fg_instance_indices)
+ batch_indices = (
+ tf.expand_dims(tf.range(fg_instance_indices_shape[0]), axis=-1) *
+ tf.ones([1, fg_instance_indices_shape[-1]], dtype=tf.int32))
+
+ gather_nd_instance_indices = tf.stack([batch_indices, fg_instance_indices],
+ axis=-1)
+ foreground_rois = tf.gather_nd(candidate_rois, gather_nd_instance_indices)
+ foreground_boxes = tf.gather_nd(candidate_gt_boxes,
+ gather_nd_instance_indices)
+ foreground_classes = tf.gather_nd(candidate_gt_classes,
+ gather_nd_instance_indices)
+ foreground_gt_indices = tf.gather_nd(candidate_gt_indices,
+ gather_nd_instance_indices)
+
+ foreground_gt_indices_shape = tf.shape(foreground_gt_indices)
+ batch_indices = (
+ tf.expand_dims(tf.range(foreground_gt_indices_shape[0]), axis=-1) *
+ tf.ones([1, foreground_gt_indices_shape[-1]], dtype=tf.int32))
+ gather_nd_gt_indices = tf.stack([batch_indices, foreground_gt_indices],
+ axis=-1)
+ foreground_masks = tf.gather_nd(gt_masks, gather_nd_gt_indices)
+
+ cropped_foreground_masks = spatial_transform_ops.crop_mask_in_target_box(
+ foreground_masks,
+ foreground_boxes,
+ foreground_rois,
+ mask_target_size,
+ sample_offset=0.5)
+
+ return foreground_rois, foreground_classes, cropped_foreground_masks
+
+
+class ROISampler(tf_keras.layers.Layer):
+ """Samples RoIs and creates training targets."""
+
+ def __init__(self, params):
+ self._num_samples_per_image = params.num_samples_per_image
+ self._fg_fraction = params.fg_fraction
+ self._fg_iou_thresh = params.fg_iou_thresh
+ self._bg_iou_thresh_hi = params.bg_iou_thresh_hi
+ self._bg_iou_thresh_lo = params.bg_iou_thresh_lo
+ self._mix_gt_boxes = params.mix_gt_boxes
+ super(ROISampler, self).__init__(autocast=False)
+
+ def call(self, rois, gt_boxes, gt_classes):
+ """Sample and assign RoIs for training.
+
+ Args:
+ rois: a tensor of shape of [batch_size, N, 4]. N is the number of
+ proposals before groundtruth assignment. The last dimension is the box
+ coordinates w.r.t. the scaled images in [ymin, xmin, ymax, xmax] format.
+ gt_boxes: a tensor of shape of [batch_size, MAX_NUM_INSTANCES, 4]. The
+ coordinates of gt_boxes are in the pixel coordinates of the scaled
+ image. This tensor might have padding of values -1 indicating the
+ invalid box coordinates.
+ gt_classes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES]. This
+ tensor might have paddings with values of -1 indicating the invalid
+ classes.
+
+ Returns:
+ sampled_rois: a tensor of shape of [batch_size, K, 4], representing the
+ coordinates of the sampled RoIs, where K is the number of the sampled
+ RoIs, i.e. K = num_samples_per_image.
+ sampled_gt_boxes: a tensor of shape of [batch_size, K, 4], storing the
+ box coordinates of the matched groundtruth boxes of the samples RoIs.
+ sampled_gt_classes: a tensor of shape of [batch_size, K], storing the
+ classes of the matched groundtruth boxes of the sampled RoIs.
+ """
+ sampled_rois, sampled_gt_boxes, sampled_gt_classes, sampled_gt_indices = (
+ assign_and_sample_proposals(
+ rois,
+ gt_boxes,
+ gt_classes,
+ num_samples_per_image=self._num_samples_per_image,
+ mix_gt_boxes=self._mix_gt_boxes,
+ fg_fraction=self._fg_fraction,
+ fg_iou_thresh=self._fg_iou_thresh,
+ bg_iou_thresh_hi=self._bg_iou_thresh_hi,
+ bg_iou_thresh_lo=self._bg_iou_thresh_lo))
+ return (sampled_rois, sampled_gt_boxes, sampled_gt_classes,
+ sampled_gt_indices)
+
+
+class ROIScoreSampler(ROISampler):
+ """Samples RoIs, RoI-scores and creates training targets."""
+
+ def __call__(self, rois, roi_scores, gt_boxes, gt_classes):
+ """Sample and assign RoIs for training.
+
+ Args:
+ rois: a tensor of shape of [batch_size, N, 4]. N is the number of
+ proposals before groundtruth assignment. The last dimension is the box
+ coordinates w.r.t. the scaled images in [ymin, xmin, ymax, xmax] format.
+ roi_scores:
+ gt_boxes: a tensor of shape of [batch_size, MAX_NUM_INSTANCES, 4]. The
+ coordinates of gt_boxes are in the pixel coordinates of the scaled
+ image. This tensor might have padding of values -1 indicating the
+ invalid box coordinates.
+ gt_classes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES]. This
+ tensor might have paddings with values of -1 indicating the invalid
+ classes.
+
+ Returns:
+ sampled_rois: a tensor of shape of [batch_size, K, 4], representing the
+ coordinates of the sampled RoIs, where K is the number of the sampled
+ RoIs, i.e. K = num_samples_per_image.
+ sampled_roi_scores:
+ sampled_gt_boxes: a tensor of shape of [batch_size, K, 4], storing the
+ box coordinates of the matched groundtruth boxes of the samples RoIs.
+ sampled_gt_classes: a tensor of shape of [batch_size, K], storing the
+ classes of the matched groundtruth boxes of the sampled RoIs.
+ """
+ (sampled_rois, sampled_roi_scores, sampled_gt_boxes, sampled_gt_classes,
+ sampled_gt_indices) = (
+ self.assign_and_sample_proposals_and_scores(
+ rois,
+ roi_scores,
+ gt_boxes,
+ gt_classes,
+ num_samples_per_image=self._num_samples_per_image,
+ mix_gt_boxes=self._mix_gt_boxes,
+ fg_fraction=self._fg_fraction,
+ fg_iou_thresh=self._fg_iou_thresh,
+ bg_iou_thresh_hi=self._bg_iou_thresh_hi,
+ bg_iou_thresh_lo=self._bg_iou_thresh_lo))
+ return (sampled_rois, sampled_roi_scores, sampled_gt_boxes,
+ sampled_gt_classes, sampled_gt_indices)
+
+ def assign_and_sample_proposals_and_scores(self,
+ proposed_boxes,
+ proposed_scores,
+ gt_boxes,
+ gt_classes,
+ num_samples_per_image=512,
+ mix_gt_boxes=True,
+ fg_fraction=0.25,
+ fg_iou_thresh=0.5,
+ bg_iou_thresh_hi=0.5,
+ bg_iou_thresh_lo=0.0):
+ """Assigns the proposals with groundtruth classes and performs subsmpling.
+
+ Given `proposed_boxes`, `gt_boxes`, and `gt_classes`, the function uses the
+ following algorithm to generate the final `num_samples_per_image` RoIs.
+ 1. Calculates the IoU between each proposal box and each gt_boxes.
+ 2. Assigns each proposed box with a groundtruth class and box by choosing
+ the largest IoU overlap.
+ 3. Samples `num_samples_per_image` boxes from all proposed boxes, and
+ returns box_targets, class_targets, and RoIs.
+
+ Args:
+ proposed_boxes: a tensor of shape of [batch_size, N, 4]. N is the number
+ of proposals before groundtruth assignment. The last dimension is the
+ box coordinates w.r.t. the scaled images in [ymin, xmin, ymax, xmax]
+ format.
+ proposed_scores: a tensor of shape of [batch_size, N]. N is the number of
+ proposals before groundtruth assignment. It is the rpn scores for all
+ proposed boxes which can be either their classification or centerness
+ scores.
+ gt_boxes: a tensor of shape of [batch_size, MAX_NUM_INSTANCES, 4]. The
+ coordinates of gt_boxes are in the pixel coordinates of the scaled
+ image. This tensor might have padding of values -1 indicating the
+ invalid box coordinates.
+ gt_classes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES]. This
+ tensor might have paddings with values of -1 indicating the invalid
+ classes.
+ num_samples_per_image: a integer represents RoI minibatch size per image.
+ mix_gt_boxes: a bool indicating whether to mix the groundtruth boxes
+ before sampling proposals.
+ fg_fraction: a float represents the target fraction of RoI minibatch that
+ is labeled foreground (i.e., class > 0).
+ fg_iou_thresh: a float represents the IoU overlap threshold for an RoI to
+ be considered foreground (if >= fg_iou_thresh).
+ bg_iou_thresh_hi: a float represents the IoU overlap threshold for an RoI
+ to be considered background (class = 0 if overlap in [LO, HI)).
+ bg_iou_thresh_lo: a float represents the IoU overlap threshold for an RoI
+ to be considered background (class = 0 if overlap in [LO, HI)).
+
+ Returns:
+ sampled_rois: a tensor of shape of [batch_size, K, 4], representing the
+ coordinates of the sampled RoIs, where K is the number of the sampled
+ RoIs, i.e. K = num_samples_per_image.
+ sampled_scores: a tensor of shape of [batch_size, K], representing the
+ confidence score of the sampled RoIs, where K is the number of the
+ sampled RoIs, i.e. K = num_samples_per_image.
+ sampled_gt_boxes: a tensor of shape of [batch_size, K, 4], storing the
+ box coordinates of the matched groundtruth boxes of the samples RoIs.
+ sampled_gt_classes: a tensor of shape of [batch_size, K], storing the
+ classes of the matched groundtruth boxes of the sampled RoIs.
+ sampled_gt_indices: a tensor of shape of [batch_size, K], storing the
+ indices of the sampled groudntruth boxes in the original `gt_boxes`
+ tensor, i.e. gt_boxes[sampled_gt_indices[:, i]] =
+ sampled_gt_boxes[:, i].
+ """
+
+ with tf.name_scope('sample_proposals_and_scores'):
+ if mix_gt_boxes:
+ boxes = tf.concat([proposed_boxes, gt_boxes], axis=1)
+ gt_scores = tf.ones_like(gt_boxes[:, :, 0])
+ scores = tf.concat([proposed_scores, gt_scores], axis=1)
+ else:
+ boxes = proposed_boxes
+ scores = proposed_scores
+
+ (matched_gt_boxes, matched_gt_classes, matched_gt_indices, matched_iou,
+ _) = box_matching(boxes, gt_boxes, gt_classes)
+
+ positive_match = tf.greater(matched_iou, fg_iou_thresh)
+ negative_match = tf.logical_and(
+ tf.greater_equal(matched_iou, bg_iou_thresh_lo),
+ tf.less(matched_iou, bg_iou_thresh_hi))
+ ignored_match = tf.less(matched_iou, 0.0)
+
+ # re-assign negatively matched boxes to the background class.
+ matched_gt_classes = tf.where(negative_match,
+ tf.zeros_like(matched_gt_classes),
+ matched_gt_classes)
+ matched_gt_indices = tf.where(negative_match,
+ tf.zeros_like(matched_gt_indices),
+ matched_gt_indices)
+
+ sample_candidates = tf.logical_and(
+ tf.logical_or(positive_match, negative_match),
+ tf.logical_not(ignored_match))
+
+ sampler = (
+ balanced_positive_negative_sampler.BalancedPositiveNegativeSampler(
+ positive_fraction=fg_fraction, is_static=True))
+
+ batch_size, _ = sample_candidates.get_shape().as_list()
+ sampled_indicators = []
+ for i in range(batch_size):
+ sampled_indicator = sampler.subsample(sample_candidates[i],
+ num_samples_per_image,
+ positive_match[i])
+ sampled_indicators.append(sampled_indicator)
+ sampled_indicators = tf.stack(sampled_indicators)
+ _, sampled_indices = tf.nn.top_k(
+ tf.cast(sampled_indicators, dtype=tf.int32),
+ k=num_samples_per_image,
+ sorted=True)
+
+ sampled_indices_shape = tf.shape(sampled_indices)
+ batch_indices = (
+ tf.expand_dims(tf.range(sampled_indices_shape[0]), axis=-1) *
+ tf.ones([1, sampled_indices_shape[-1]], dtype=tf.int32))
+ gather_nd_indices = tf.stack([batch_indices, sampled_indices], axis=-1)
+
+ sampled_rois = tf.gather_nd(boxes, gather_nd_indices)
+ sampled_roi_scores = tf.gather_nd(scores, gather_nd_indices)
+ sampled_gt_boxes = tf.gather_nd(matched_gt_boxes, gather_nd_indices)
+ sampled_gt_classes = tf.gather_nd(matched_gt_classes, gather_nd_indices)
+ sampled_gt_indices = tf.gather_nd(matched_gt_indices, gather_nd_indices)
+
+ return (sampled_rois, sampled_roi_scores, sampled_gt_boxes,
+ sampled_gt_classes, sampled_gt_indices)
+
+
+class MaskSampler(tf_keras.layers.Layer):
+ """Samples and creates mask training targets."""
+
+ def __init__(self, mask_target_size, num_mask_samples_per_image):
+ self._mask_target_size = mask_target_size
+ self._num_mask_samples_per_image = num_mask_samples_per_image
+ super(MaskSampler, self).__init__(autocast=False)
+
+ def call(self,
+ candidate_rois,
+ candidate_gt_boxes,
+ candidate_gt_classes,
+ candidate_gt_indices,
+ gt_masks):
+ """Sample and create mask targets for training.
+
+ Args:
+ candidate_rois: a tensor of shape of [batch_size, N, 4], where N is the
+ number of candidate RoIs to be considered for mask sampling. It includes
+ both positive and negative RoIs. The `num_mask_samples_per_image`
+ positive RoIs will be sampled to create mask training targets.
+ candidate_gt_boxes: a tensor of shape of [batch_size, N, 4], storing the
+ corresponding groundtruth boxes to the `candidate_rois`.
+ candidate_gt_classes: a tensor of shape of [batch_size, N], storing the
+ corresponding groundtruth classes to the `candidate_rois`. 0 in the
+ tensor corresponds to the background class, i.e. negative RoIs.
+ candidate_gt_indices: a tensor of shape [batch_size, N], storing the
+ corresponding groundtruth instance indices to the `candidate_gt_boxes`,
+ i.e. gt_boxes[candidate_gt_indices[:, i]] = candidate_gt_boxes[:, i],
+ where gt_boxes which is of shape [batch_size, MAX_INSTANCES, 4], M >=
+ N, is the superset of candidate_gt_boxes.
+ gt_masks: a tensor of [batch_size, MAX_INSTANCES, mask_height, mask_width]
+ containing all the groundtruth masks which sample masks are drawn from.
+ after sampling. The output masks are resized w.r.t the sampled RoIs.
+
+ Returns:
+ foreground_rois: a tensor of shape of [batch_size, K, 4] storing the RoI
+ that corresponds to the sampled foreground masks, where
+ K = num_mask_samples_per_image.
+ foreground_classes: a tensor of shape of [batch_size, K] storing the
+ classes corresponding to the sampled foreground masks.
+ cropoped_foreground_masks: a tensor of shape of
+ [batch_size, K, mask_target_size, mask_target_size] storing the
+ cropped foreground masks used for training.
+ """
+ foreground_rois, foreground_classes, cropped_foreground_masks = (
+ sample_and_crop_foreground_masks(candidate_rois, candidate_gt_boxes,
+ candidate_gt_classes,
+ candidate_gt_indices, gt_masks,
+ self._num_mask_samples_per_image,
+ self._mask_target_size))
+ return foreground_rois, foreground_classes, cropped_foreground_masks
diff --git a/modeling/official/legacy/detection/utils/__init__.py b/modeling/official/legacy/detection/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/legacy/detection/utils/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/legacy/detection/utils/box_utils.py b/modeling/official/legacy/detection/utils/box_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..04695e2d2b9bad39fe6e4541b7276721d223114c
--- /dev/null
+++ b/modeling/official/legacy/detection/utils/box_utils.py
@@ -0,0 +1,700 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility functions for bounding box processing."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf, tf_keras
+
+EPSILON = 1e-8
+BBOX_XFORM_CLIP = np.log(1000. / 16.)
+
+
+def visualize_images_with_bounding_boxes(images, box_outputs, step,
+ summary_writer):
+ """Records subset of evaluation images with bounding boxes."""
+ image_shape = tf.shape(images[0])
+ image_height = tf.cast(image_shape[0], tf.float32)
+ image_width = tf.cast(image_shape[1], tf.float32)
+ normalized_boxes = normalize_boxes(box_outputs, [image_height, image_width])
+
+ bounding_box_color = tf.constant([[1.0, 1.0, 0.0, 1.0]])
+ image_summary = tf.image.draw_bounding_boxes(images, normalized_boxes,
+ bounding_box_color)
+ with summary_writer.as_default():
+ tf.summary.image('bounding_box_summary', image_summary, step=step)
+ summary_writer.flush()
+
+
+def yxyx_to_xywh(boxes):
+ """Converts boxes from ymin, xmin, ymax, xmax to xmin, ymin, width, height.
+
+ Args:
+ boxes: a numpy array whose last dimension is 4 representing the coordinates
+ of boxes in ymin, xmin, ymax, xmax order.
+
+ Returns:
+ boxes: a numpy array whose shape is the same as `boxes` in new format.
+
+ Raises:
+ ValueError: If the last dimension of boxes is not 4.
+ """
+ if boxes.shape[-1] != 4:
+ raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(
+ boxes.shape[-1]))
+
+ boxes_ymin = boxes[..., 0]
+ boxes_xmin = boxes[..., 1]
+ boxes_width = boxes[..., 3] - boxes[..., 1]
+ boxes_height = boxes[..., 2] - boxes[..., 0]
+ new_boxes = np.stack([boxes_xmin, boxes_ymin, boxes_width, boxes_height],
+ axis=-1)
+
+ return new_boxes
+
+
+def jitter_boxes(boxes, noise_scale=0.025):
+ """Jitter the box coordinates by some noise distribution.
+
+ Args:
+ boxes: a tensor whose last dimension is 4 representing the coordinates of
+ boxes in ymin, xmin, ymax, xmax order.
+ noise_scale: a python float which specifies the magnitude of noise. The rule
+ of thumb is to set this between (0, 0.1]. The default value is found to
+ mimic the noisy detections best empirically.
+
+ Returns:
+ jittered_boxes: a tensor whose shape is the same as `boxes` representing
+ the jittered boxes.
+
+ Raises:
+ ValueError: If the last dimension of boxes is not 4.
+ """
+ if boxes.shape[-1] != 4:
+ raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(
+ boxes.shape[-1]))
+
+ with tf.name_scope('jitter_boxes'):
+ bbox_jitters = tf.random.normal(boxes.get_shape(), stddev=noise_scale)
+ ymin = boxes[..., 0:1]
+ xmin = boxes[..., 1:2]
+ ymax = boxes[..., 2:3]
+ xmax = boxes[..., 3:4]
+ width = xmax - xmin
+ height = ymax - ymin
+ new_center_x = (xmin + xmax) / 2.0 + bbox_jitters[..., 0:1] * width
+ new_center_y = (ymin + ymax) / 2.0 + bbox_jitters[..., 1:2] * height
+ new_width = width * tf.math.exp(bbox_jitters[..., 2:3])
+ new_height = height * tf.math.exp(bbox_jitters[..., 3:4])
+ jittered_boxes = tf.concat([
+ new_center_y - new_height * 0.5, new_center_x - new_width * 0.5,
+ new_center_y + new_height * 0.5, new_center_x + new_width * 0.5
+ ],
+ axis=-1)
+
+ return jittered_boxes
+
+
+def normalize_boxes(boxes, image_shape):
+ """Converts boxes to the normalized coordinates.
+
+ Args:
+ boxes: a tensor whose last dimension is 4 representing the coordinates of
+ boxes in ymin, xmin, ymax, xmax order.
+ image_shape: a list of two integers, a two-element vector or a tensor such
+ that all but the last dimensions are `broadcastable` to `boxes`. The last
+ dimension is 2, which represents [height, width].
+
+ Returns:
+ normalized_boxes: a tensor whose shape is the same as `boxes` representing
+ the normalized boxes.
+
+ Raises:
+ ValueError: If the last dimension of boxes is not 4.
+ """
+ if boxes.shape[-1] != 4:
+ raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(
+ boxes.shape[-1]))
+
+ with tf.name_scope('normalize_boxes'):
+ if isinstance(image_shape, list) or isinstance(image_shape, tuple):
+ height, width = image_shape
+ else:
+ image_shape = tf.cast(image_shape, dtype=boxes.dtype)
+ height = image_shape[..., 0:1]
+ width = image_shape[..., 1:2]
+
+ ymin = boxes[..., 0:1] / height
+ xmin = boxes[..., 1:2] / width
+ ymax = boxes[..., 2:3] / height
+ xmax = boxes[..., 3:4] / width
+
+ normalized_boxes = tf.concat([ymin, xmin, ymax, xmax], axis=-1)
+ return normalized_boxes
+
+
+def denormalize_boxes(boxes, image_shape):
+ """Converts boxes normalized by [height, width] to pixel coordinates.
+
+ Args:
+ boxes: a tensor whose last dimension is 4 representing the coordinates of
+ boxes in ymin, xmin, ymax, xmax order.
+ image_shape: a list of two integers, a two-element vector or a tensor such
+ that all but the last dimensions are `broadcastable` to `boxes`. The last
+ dimension is 2, which represents [height, width].
+
+ Returns:
+ denormalized_boxes: a tensor whose shape is the same as `boxes` representing
+ the denormalized boxes.
+
+ Raises:
+ ValueError: If the last dimension of boxes is not 4.
+ """
+ with tf.name_scope('denormalize_boxes'):
+ if isinstance(image_shape, list) or isinstance(image_shape, tuple):
+ height, width = image_shape
+ else:
+ image_shape = tf.cast(image_shape, dtype=boxes.dtype)
+ height, width = tf.split(image_shape, 2, axis=-1)
+
+ ymin, xmin, ymax, xmax = tf.split(boxes, 4, axis=-1)
+ ymin = ymin * height
+ xmin = xmin * width
+ ymax = ymax * height
+ xmax = xmax * width
+
+ denormalized_boxes = tf.concat([ymin, xmin, ymax, xmax], axis=-1)
+ return denormalized_boxes
+
+
+def clip_boxes(boxes, image_shape):
+ """Clips boxes to image boundaries.
+
+ Args:
+ boxes: a tensor whose last dimension is 4 representing the coordinates of
+ boxes in ymin, xmin, ymax, xmax order.
+ image_shape: a list of two integers, a two-element vector or a tensor such
+ that all but the last dimensions are `broadcastable` to `boxes`. The last
+ dimension is 2, which represents [height, width].
+
+ Returns:
+ clipped_boxes: a tensor whose shape is the same as `boxes` representing the
+ clipped boxes.
+
+ Raises:
+ ValueError: If the last dimension of boxes is not 4.
+ """
+ if boxes.shape[-1] != 4:
+ raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(
+ boxes.shape[-1]))
+
+ with tf.name_scope('clip_boxes'):
+ if isinstance(image_shape, list) or isinstance(image_shape, tuple):
+ height, width = image_shape
+ max_length = [height - 1.0, width - 1.0, height - 1.0, width - 1.0]
+ else:
+ image_shape = tf.cast(image_shape, dtype=boxes.dtype)
+ height, width = tf.unstack(image_shape, axis=-1)
+ max_length = tf.stack(
+ [height - 1.0, width - 1.0, height - 1.0, width - 1.0], axis=-1)
+
+ clipped_boxes = tf.math.maximum(tf.math.minimum(boxes, max_length), 0.0)
+ return clipped_boxes
+
+
+def compute_outer_boxes(boxes, image_shape, scale=1.0):
+ """Compute outer box encloses an object with a margin.
+
+ Args:
+ boxes: a tensor whose last dimension is 4 representing the coordinates of
+ boxes in ymin, xmin, ymax, xmax order.
+ image_shape: a list of two integers, a two-element vector or a tensor such
+ that all but the last dimensions are `broadcastable` to `boxes`. The last
+ dimension is 2, which represents [height, width].
+ scale: a float number specifying the scale of output outer boxes to input
+ `boxes`.
+
+ Returns:
+ outer_boxes: a tensor whose shape is the same as `boxes` representing the
+ outer boxes.
+ """
+ if scale < 1.0:
+ raise ValueError(
+ 'scale is {}, but outer box scale must be greater than 1.0.'.format(
+ scale))
+ centers_y = (boxes[..., 0] + boxes[..., 2]) / 2.0
+ centers_x = (boxes[..., 1] + boxes[..., 3]) / 2.0
+ box_height = (boxes[..., 2] - boxes[..., 0]) * scale
+ box_width = (boxes[..., 3] - boxes[..., 1]) * scale
+ outer_boxes = tf.stack([
+ centers_y - box_height / 2.0, centers_x - box_width / 2.0,
+ centers_y + box_height / 2.0, centers_x + box_width / 2.0
+ ],
+ axis=1)
+ outer_boxes = clip_boxes(outer_boxes, image_shape)
+ return outer_boxes
+
+
+def encode_boxes(boxes, anchors, weights=None):
+ """Encode boxes to targets.
+
+ Args:
+ boxes: a tensor whose last dimension is 4 representing the coordinates of
+ boxes in ymin, xmin, ymax, xmax order.
+ anchors: a tensor whose shape is the same as, or `broadcastable` to `boxes`,
+ representing the coordinates of anchors in ymin, xmin, ymax, xmax order.
+ weights: None or a list of four float numbers used to scale coordinates.
+
+ Returns:
+ encoded_boxes: a tensor whose shape is the same as `boxes` representing the
+ encoded box targets.
+
+ Raises:
+ ValueError: If the last dimension of boxes is not 4.
+ """
+ if boxes.shape[-1] != 4:
+ raise ValueError('boxes.shape[-1] is {:d}, but must be 4.'.format(
+ boxes.shape[-1]))
+
+ with tf.name_scope('encode_boxes'):
+ boxes = tf.cast(boxes, dtype=anchors.dtype)
+ ymin = boxes[..., 0:1]
+ xmin = boxes[..., 1:2]
+ ymax = boxes[..., 2:3]
+ xmax = boxes[..., 3:4]
+ box_h = ymax - ymin + 1.0
+ box_w = xmax - xmin + 1.0
+ box_yc = ymin + 0.5 * box_h
+ box_xc = xmin + 0.5 * box_w
+
+ anchor_ymin = anchors[..., 0:1]
+ anchor_xmin = anchors[..., 1:2]
+ anchor_ymax = anchors[..., 2:3]
+ anchor_xmax = anchors[..., 3:4]
+ anchor_h = anchor_ymax - anchor_ymin + 1.0
+ anchor_w = anchor_xmax - anchor_xmin + 1.0
+ anchor_yc = anchor_ymin + 0.5 * anchor_h
+ anchor_xc = anchor_xmin + 0.5 * anchor_w
+
+ encoded_dy = (box_yc - anchor_yc) / anchor_h
+ encoded_dx = (box_xc - anchor_xc) / anchor_w
+ encoded_dh = tf.math.log(box_h / anchor_h)
+ encoded_dw = tf.math.log(box_w / anchor_w)
+ if weights:
+ encoded_dy *= weights[0]
+ encoded_dx *= weights[1]
+ encoded_dh *= weights[2]
+ encoded_dw *= weights[3]
+
+ encoded_boxes = tf.concat([encoded_dy, encoded_dx, encoded_dh, encoded_dw],
+ axis=-1)
+ return encoded_boxes
+
+
+def decode_boxes(encoded_boxes, anchors, weights=None):
+ """Decode boxes.
+
+ Args:
+ encoded_boxes: a tensor whose last dimension is 4 representing the
+ coordinates of encoded boxes in ymin, xmin, ymax, xmax order.
+ anchors: a tensor whose shape is the same as, or `broadcastable` to `boxes`,
+ representing the coordinates of anchors in ymin, xmin, ymax, xmax order.
+ weights: None or a list of four float numbers used to scale coordinates.
+
+ Returns:
+ encoded_boxes: a tensor whose shape is the same as `boxes` representing the
+ decoded box targets.
+ """
+ if encoded_boxes.shape[-1] != 4:
+ raise ValueError('encoded_boxes.shape[-1] is {:d}, but must be 4.'.format(
+ encoded_boxes.shape[-1]))
+
+ with tf.name_scope('decode_boxes'):
+ encoded_boxes = tf.cast(encoded_boxes, dtype=anchors.dtype)
+ dy = encoded_boxes[..., 0:1]
+ dx = encoded_boxes[..., 1:2]
+ dh = encoded_boxes[..., 2:3]
+ dw = encoded_boxes[..., 3:4]
+ if weights:
+ dy /= weights[0]
+ dx /= weights[1]
+ dh /= weights[2]
+ dw /= weights[3]
+ dh = tf.math.minimum(dh, BBOX_XFORM_CLIP)
+ dw = tf.math.minimum(dw, BBOX_XFORM_CLIP)
+
+ anchor_ymin = anchors[..., 0:1]
+ anchor_xmin = anchors[..., 1:2]
+ anchor_ymax = anchors[..., 2:3]
+ anchor_xmax = anchors[..., 3:4]
+ anchor_h = anchor_ymax - anchor_ymin + 1.0
+ anchor_w = anchor_xmax - anchor_xmin + 1.0
+ anchor_yc = anchor_ymin + 0.5 * anchor_h
+ anchor_xc = anchor_xmin + 0.5 * anchor_w
+
+ decoded_boxes_yc = dy * anchor_h + anchor_yc
+ decoded_boxes_xc = dx * anchor_w + anchor_xc
+ decoded_boxes_h = tf.math.exp(dh) * anchor_h
+ decoded_boxes_w = tf.math.exp(dw) * anchor_w
+
+ decoded_boxes_ymin = decoded_boxes_yc - 0.5 * decoded_boxes_h
+ decoded_boxes_xmin = decoded_boxes_xc - 0.5 * decoded_boxes_w
+ decoded_boxes_ymax = decoded_boxes_ymin + decoded_boxes_h - 1.0
+ decoded_boxes_xmax = decoded_boxes_xmin + decoded_boxes_w - 1.0
+
+ decoded_boxes = tf.concat([
+ decoded_boxes_ymin, decoded_boxes_xmin, decoded_boxes_ymax,
+ decoded_boxes_xmax
+ ],
+ axis=-1)
+ return decoded_boxes
+
+
+def encode_boxes_lrtb(boxes, anchors, weights=None):
+ """Encode boxes to targets on lrtb (=left,right,top,bottom) format.
+
+ Args:
+ boxes: a tensor whose last dimension is 4 representing the coordinates
+ of boxes in ymin, xmin, ymax, xmax order.
+ anchors: a tensor whose shape is the same as, or `broadcastable` to `boxes`,
+ representing the coordinates of anchors in ymin, xmin, ymax, xmax order.
+ weights: None or a list of four float numbers used to scale coordinates.
+
+ Returns:
+ encoded_boxes_lrtb: a tensor whose shape is the same as `boxes` representing
+ the encoded box targets. The box targets encode the left, right, top,
+ bottom distances from an anchor location to the four borders of the
+ matched groundtruth bounding box.
+ center_targets: centerness targets defined by the left, right, top, and
+ bottom distance targets. The centerness is defined as the deviation of the
+ anchor location from the groundtruth object center. Formally, centerness =
+ sqrt(min(left, right)/max(left, right)*min(top, bottom)/max(top, bottom)).
+
+ Raises:
+ ValueError: If the last dimension of boxes is not 4.
+ """
+ if boxes.shape[-1] != 4:
+ raise ValueError(
+ 'boxes.shape[-1] is {:d}, but must be 4.'.format(boxes.shape[-1]))
+
+ with tf.name_scope('encode_boxes_lrtb'):
+ boxes = tf.cast(boxes, dtype=anchors.dtype)
+ ymin = boxes[..., 0:1]
+ xmin = boxes[..., 1:2]
+ ymax = boxes[..., 2:3]
+ xmax = boxes[..., 3:4]
+ # box_h = ymax - ymin + 1.0
+ # box_w = xmax - xmin + 1.0
+ box_h = ymax - ymin
+ box_w = xmax - xmin
+
+ anchor_ymin = anchors[..., 0:1]
+ anchor_xmin = anchors[..., 1:2]
+ anchor_ymax = anchors[..., 2:3]
+ anchor_xmax = anchors[..., 3:4]
+ # anchor_h = anchor_ymax - anchor_ymin + 1.0
+ # anchor_w = anchor_xmax - anchor_xmin + 1.0
+ anchor_h = anchor_ymax - anchor_ymin
+ anchor_w = anchor_xmax - anchor_xmin
+ anchor_yc = anchor_ymin + 0.5 * anchor_h
+ anchor_xc = anchor_xmin + 0.5 * anchor_w
+
+ box_h += EPSILON
+ box_w += EPSILON
+ anchor_h += EPSILON
+ anchor_w += EPSILON
+
+ left = (anchor_xc - xmin) / anchor_w
+ right = (xmax - anchor_xc) / anchor_w
+ top = (anchor_yc - ymin) / anchor_h
+ bottom = (ymax - anchor_yc) / anchor_h
+
+ # Create centerness target. {
+ lrtb_targets = tf.concat([left, right, top, bottom], axis=-1)
+ valid_match = tf.greater(tf.reduce_min(lrtb_targets, -1), 0.0)
+
+ # Centerness score.
+ left_right = tf.concat([left, right], axis=-1)
+
+ left_right = tf.where(tf.stack([valid_match, valid_match], -1),
+ left_right, tf.zeros_like(left_right))
+ top_bottom = tf.concat([top, bottom], axis=-1)
+ top_bottom = tf.where(tf.stack([valid_match, valid_match], -1),
+ top_bottom, tf.zeros_like(top_bottom))
+ center_targets = tf.sqrt(
+ (tf.reduce_min(left_right, -1) /
+ (tf.reduce_max(left_right, -1) + EPSILON)) *
+ (tf.reduce_min(top_bottom, -1) /
+ (tf.reduce_max(top_bottom, -1) + EPSILON)))
+ center_targets = tf.where(valid_match,
+ center_targets,
+ tf.zeros_like(center_targets))
+ if weights:
+ left *= weights[0]
+ right *= weights[1]
+ top *= weights[2]
+ bottom *= weights[3]
+
+ encoded_boxes_lrtb = tf.concat(
+ [left, right, top, bottom],
+ axis=-1)
+
+ return encoded_boxes_lrtb, center_targets
+
+
+def decode_boxes_lrtb(encoded_boxes_lrtb, anchors, weights=None):
+ """Decode boxes.
+
+ Args:
+ encoded_boxes_lrtb: a tensor whose last dimension is 4 representing the
+ coordinates of encoded boxes in left, right, top, bottom order.
+ anchors: a tensor whose shape is the same as, or `broadcastable` to `boxes`,
+ representing the coordinates of anchors in ymin, xmin, ymax, xmax order.
+ weights: None or a list of four float numbers used to scale coordinates.
+
+ Returns:
+ decoded_boxes_lrtb: a tensor whose shape is the same as `boxes` representing
+ the decoded box targets in lrtb (=left,right,top,bottom) format. The box
+ decoded box coordinates represent the left, right, top, and bottom
+ distances from an anchor location to the four borders of the matched
+ groundtruth bounding box.
+ """
+ if encoded_boxes_lrtb.shape[-1] != 4:
+ raise ValueError(
+ 'encoded_boxes_lrtb.shape[-1] is {:d}, but must be 4.'
+ .format(encoded_boxes_lrtb.shape[-1]))
+
+ with tf.name_scope('decode_boxes_lrtb'):
+ encoded_boxes_lrtb = tf.cast(encoded_boxes_lrtb, dtype=anchors.dtype)
+ left = encoded_boxes_lrtb[..., 0:1]
+ right = encoded_boxes_lrtb[..., 1:2]
+ top = encoded_boxes_lrtb[..., 2:3]
+ bottom = encoded_boxes_lrtb[..., 3:4]
+ if weights:
+ left /= weights[0]
+ right /= weights[1]
+ top /= weights[2]
+ bottom /= weights[3]
+
+ anchor_ymin = anchors[..., 0:1]
+ anchor_xmin = anchors[..., 1:2]
+ anchor_ymax = anchors[..., 2:3]
+ anchor_xmax = anchors[..., 3:4]
+
+ anchor_h = anchor_ymax - anchor_ymin
+ anchor_w = anchor_xmax - anchor_xmin
+ anchor_yc = anchor_ymin + 0.5 * anchor_h
+ anchor_xc = anchor_xmin + 0.5 * anchor_w
+ anchor_h += EPSILON
+ anchor_w += EPSILON
+
+ decoded_boxes_ymin = anchor_yc - top * anchor_h
+ decoded_boxes_xmin = anchor_xc - left * anchor_w
+ decoded_boxes_ymax = anchor_yc + bottom * anchor_h
+ decoded_boxes_xmax = anchor_xc + right * anchor_w
+
+ decoded_boxes_lrtb = tf.concat(
+ [decoded_boxes_ymin, decoded_boxes_xmin,
+ decoded_boxes_ymax, decoded_boxes_xmax],
+ axis=-1)
+ return decoded_boxes_lrtb
+
+
+def filter_boxes(boxes, scores, image_shape, min_size_threshold):
+ """Filter and remove boxes that are too small or fall outside the image.
+
+ Args:
+ boxes: a tensor whose last dimension is 4 representing the coordinates of
+ boxes in ymin, xmin, ymax, xmax order.
+ scores: a tensor whose shape is the same as tf.shape(boxes)[:-1]
+ representing the original scores of the boxes.
+ image_shape: a tensor whose shape is the same as, or `broadcastable` to
+ `boxes` except the last dimension, which is 2, representing [height,
+ width] of the scaled image.
+ min_size_threshold: a float representing the minimal box size in each side
+ (w.r.t. the scaled image). Boxes whose sides are smaller than it will be
+ filtered out.
+
+ Returns:
+ filtered_boxes: a tensor whose shape is the same as `boxes` but with
+ the position of the filtered boxes are filled with 0.
+ filtered_scores: a tensor whose shape is the same as 'scores' but with
+ the positinon of the filtered boxes filled with 0.
+ """
+ if boxes.shape[-1] != 4:
+ raise ValueError('boxes.shape[1] is {:d}, but must be 4.'.format(
+ boxes.shape[-1]))
+
+ with tf.name_scope('filter_boxes'):
+ if isinstance(image_shape, list) or isinstance(image_shape, tuple):
+ height, width = image_shape
+ else:
+ image_shape = tf.cast(image_shape, dtype=boxes.dtype)
+ height = image_shape[..., 0]
+ width = image_shape[..., 1]
+
+ ymin = boxes[..., 0]
+ xmin = boxes[..., 1]
+ ymax = boxes[..., 2]
+ xmax = boxes[..., 3]
+
+ h = ymax - ymin + 1.0
+ w = xmax - xmin + 1.0
+ yc = ymin + 0.5 * h
+ xc = xmin + 0.5 * w
+
+ min_size = tf.cast(
+ tf.math.maximum(min_size_threshold, 1.0), dtype=boxes.dtype)
+
+ filtered_size_mask = tf.math.logical_and(
+ tf.math.greater(h, min_size), tf.math.greater(w, min_size))
+ filtered_center_mask = tf.logical_and(
+ tf.math.logical_and(tf.math.greater(yc, 0.0), tf.math.less(yc, height)),
+ tf.math.logical_and(tf.math.greater(xc, 0.0), tf.math.less(xc, width)))
+ filtered_mask = tf.math.logical_and(filtered_size_mask,
+ filtered_center_mask)
+
+ filtered_scores = tf.where(filtered_mask, scores, tf.zeros_like(scores))
+ filtered_boxes = tf.cast(
+ tf.expand_dims(filtered_mask, axis=-1), dtype=boxes.dtype) * boxes
+
+ return filtered_boxes, filtered_scores
+
+
+def filter_boxes_by_scores(boxes, scores, min_score_threshold):
+ """Filter and remove boxes whose scores are smaller than the threshold.
+
+ Args:
+ boxes: a tensor whose last dimension is 4 representing the coordinates of
+ boxes in ymin, xmin, ymax, xmax order.
+ scores: a tensor whose shape is the same as tf.shape(boxes)[:-1]
+ representing the original scores of the boxes.
+ min_score_threshold: a float representing the minimal box score threshold.
+ Boxes whose score are smaller than it will be filtered out.
+
+ Returns:
+ filtered_boxes: a tensor whose shape is the same as `boxes` but with
+ the position of the filtered boxes are filled with -1.
+ filtered_scores: a tensor whose shape is the same as 'scores' but with
+ the
+ """
+ if boxes.shape[-1] != 4:
+ raise ValueError('boxes.shape[1] is {:d}, but must be 4.'.format(
+ boxes.shape[-1]))
+
+ with tf.name_scope('filter_boxes_by_scores'):
+ filtered_mask = tf.math.greater(scores, min_score_threshold)
+ filtered_scores = tf.where(filtered_mask, scores, -tf.ones_like(scores))
+ filtered_boxes = tf.cast(
+ tf.expand_dims(filtered_mask, axis=-1), dtype=boxes.dtype) * boxes
+
+ return filtered_boxes, filtered_scores
+
+
+def top_k_boxes(boxes, scores, k):
+ """Sort and select top k boxes according to the scores.
+
+ Args:
+ boxes: a tensor of shape [batch_size, N, 4] representing the coordiante of
+ the boxes. N is the number of boxes per image.
+ scores: a tensor of shsape [batch_size, N] representing the socre of the
+ boxes.
+ k: an integer or a tensor indicating the top k number.
+
+ Returns:
+ selected_boxes: a tensor of shape [batch_size, k, 4] representing the
+ selected top k box coordinates.
+ selected_scores: a tensor of shape [batch_size, k] representing the selected
+ top k box scores.
+ """
+ with tf.name_scope('top_k_boxes'):
+ selected_scores, top_k_indices = tf.nn.top_k(scores, k=k, sorted=True)
+
+ batch_size, _ = scores.get_shape().as_list()
+ if batch_size == 1:
+ selected_boxes = tf.squeeze(
+ tf.gather(boxes, top_k_indices, axis=1), axis=1)
+ else:
+ top_k_indices_shape = tf.shape(top_k_indices)
+ batch_indices = (
+ tf.expand_dims(tf.range(top_k_indices_shape[0]), axis=-1) *
+ tf.ones([1, top_k_indices_shape[-1]], dtype=tf.int32))
+ gather_nd_indices = tf.stack([batch_indices, top_k_indices], axis=-1)
+ selected_boxes = tf.gather_nd(boxes, gather_nd_indices)
+
+ return selected_boxes, selected_scores
+
+
+def bbox_overlap(boxes, gt_boxes):
+ """Calculates the overlap between proposal and ground truth boxes.
+
+ Some `gt_boxes` may have been padded. The returned `iou` tensor for these
+ boxes will be -1.
+
+ Args:
+ boxes: a tensor with a shape of [batch_size, N, 4]. N is the number of
+ proposals before groundtruth assignment (e.g., rpn_post_nms_topn). The
+ last dimension is the pixel coordinates in [ymin, xmin, ymax, xmax] form.
+ gt_boxes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES, 4]. This
+ tensor might have paddings with a negative value.
+
+ Returns:
+ iou: a tensor with as a shape of [batch_size, N, MAX_NUM_INSTANCES].
+ """
+ with tf.name_scope('bbox_overlap'):
+ bb_y_min, bb_x_min, bb_y_max, bb_x_max = tf.split(
+ value=boxes, num_or_size_splits=4, axis=2)
+ gt_y_min, gt_x_min, gt_y_max, gt_x_max = tf.split(
+ value=gt_boxes, num_or_size_splits=4, axis=2)
+
+ # Calculates the intersection area.
+ i_xmin = tf.math.maximum(bb_x_min, tf.transpose(gt_x_min, [0, 2, 1]))
+ i_xmax = tf.math.minimum(bb_x_max, tf.transpose(gt_x_max, [0, 2, 1]))
+ i_ymin = tf.math.maximum(bb_y_min, tf.transpose(gt_y_min, [0, 2, 1]))
+ i_ymax = tf.math.minimum(bb_y_max, tf.transpose(gt_y_max, [0, 2, 1]))
+ i_area = tf.math.maximum((i_xmax - i_xmin), 0) * tf.math.maximum(
+ (i_ymax - i_ymin), 0)
+
+ # Calculates the union area.
+ bb_area = (bb_y_max - bb_y_min) * (bb_x_max - bb_x_min)
+ gt_area = (gt_y_max - gt_y_min) * (gt_x_max - gt_x_min)
+ # Adds a small epsilon to avoid divide-by-zero.
+ u_area = bb_area + tf.transpose(gt_area, [0, 2, 1]) - i_area + 1e-8
+
+ # Calculates IoU.
+ iou = i_area / u_area
+
+ # Fills -1 for IoU entries between the padded ground truth boxes.
+ gt_invalid_mask = tf.less(
+ tf.reduce_max(gt_boxes, axis=-1, keepdims=True), 0.0)
+ padding_mask = tf.logical_or(
+ tf.zeros_like(bb_x_min, dtype=tf.bool),
+ tf.transpose(gt_invalid_mask, [0, 2, 1]))
+ iou = tf.where(padding_mask, -tf.ones_like(iou), iou)
+
+ return iou
+
+
+def get_non_empty_box_indices(boxes):
+ """Get indices for non-empty boxes."""
+ # Selects indices if box height or width is 0.
+ height = boxes[:, 2] - boxes[:, 0]
+ width = boxes[:, 3] - boxes[:, 1]
+ indices = tf.where(
+ tf.logical_and(tf.greater(height, 0), tf.greater(width, 0)))
+ return indices[:, 0]
diff --git a/modeling/official/legacy/detection/utils/class_utils.py b/modeling/official/legacy/detection/utils/class_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeadfdff50f7f9899a7e72cc2d10b5feb53cee04
--- /dev/null
+++ b/modeling/official/legacy/detection/utils/class_utils.py
@@ -0,0 +1,44 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility functions for handling dataset object categories."""
+
+
+def coco_split_class_ids(split_name):
+ """Return the COCO class split ids based on split name and training mode.
+
+ Args:
+ split_name: The name of dataset split.
+
+ Returns:
+ class_ids: a python list of integer.
+ """
+ if split_name == 'all':
+ return []
+
+ elif split_name == 'voc':
+ return [
+ 1, 2, 3, 4, 5, 6, 7, 9, 16, 17, 18, 19, 20, 21, 44, 62, 63, 64, 67, 72
+ ]
+
+ elif split_name == 'nonvoc':
+ return [
+ 8, 10, 11, 13, 14, 15, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36,
+ 37, 38, 39, 40, 41, 42, 43, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56,
+ 57, 58, 59, 60, 61, 65, 70, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84,
+ 85, 86, 87, 88, 89, 90
+ ]
+
+ else:
+ raise ValueError('Invalid split name {}!!!'.format(split_name))
diff --git a/modeling/official/legacy/detection/utils/dataloader_utils.py b/modeling/official/legacy/detection/utils/dataloader_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..06c2659b8309a32c4c593fa77a439f7c1e1e26ed
--- /dev/null
+++ b/modeling/official/legacy/detection/utils/dataloader_utils.py
@@ -0,0 +1,40 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility functions for dataloader."""
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.utils import input_utils
+
+
+def process_source_id(source_id):
+ """Processes source_id to the right format."""
+ if source_id.dtype == tf.string:
+ source_id = tf.cast(tf.strings.to_number(source_id), tf.int64)
+ with tf.control_dependencies([source_id]):
+ source_id = tf.cond(
+ pred=tf.equal(tf.size(input=source_id), 0),
+ true_fn=lambda: tf.cast(tf.constant(-1), tf.int64),
+ false_fn=lambda: tf.identity(source_id))
+ return source_id
+
+
+def pad_groundtruths_to_fixed_size(gt, n):
+ """Pads the first dimension of groundtruths labels to the fixed size."""
+ gt['boxes'] = input_utils.pad_to_fixed_size(gt['boxes'], n, -1)
+ gt['is_crowds'] = input_utils.pad_to_fixed_size(gt['is_crowds'], n, 0)
+ gt['areas'] = input_utils.pad_to_fixed_size(gt['areas'], n, -1)
+ gt['classes'] = input_utils.pad_to_fixed_size(gt['classes'], n, -1)
+ return gt
diff --git a/modeling/official/legacy/detection/utils/input_utils.py b/modeling/official/legacy/detection/utils/input_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..67ce114600fc8fd4cfaa939f9c68211a11094209
--- /dev/null
+++ b/modeling/official/legacy/detection/utils/input_utils.py
@@ -0,0 +1,359 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility functions for input processing."""
+
+import math
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.detection.utils import box_utils
+from official.vision.utils.object_detection import preprocessor
+
+
+def pad_to_fixed_size(input_tensor, size, constant_values=0):
+ """Pads data to a fixed length at the first dimension.
+
+ Args:
+ input_tensor: `Tensor` with any dimension.
+ size: `int` number for the first dimension of output Tensor.
+ constant_values: `int` value assigned to the paddings.
+
+ Returns:
+ `Tensor` with the first dimension padded to `size`.
+ """
+ input_shape = input_tensor.get_shape().as_list()
+ padding_shape = []
+
+ # Computes the padding length on the first dimension.
+ padding_length = tf.maximum(0, size - tf.shape(input_tensor)[0])
+ assert_length = tf.Assert(
+ tf.greater_equal(padding_length, 0), [padding_length])
+ with tf.control_dependencies([assert_length]):
+ padding_shape.append(padding_length)
+
+ # Copies shapes of the rest of input shape dimensions.
+ for i in range(1, len(input_shape)):
+ padding_shape.append(tf.shape(input=input_tensor)[i])
+
+ # Pads input tensor to the fixed first dimension.
+ paddings = tf.cast(constant_values * tf.ones(padding_shape),
+ input_tensor.dtype)
+ padded_tensor = tf.concat([input_tensor, paddings], axis=0)
+ output_shape = input_shape
+ output_shape[0] = size
+ padded_tensor.set_shape(output_shape)
+ return padded_tensor
+
+
+def normalize_image(image,
+ offset=(0.485, 0.456, 0.406),
+ scale=(0.229, 0.224, 0.225)):
+ """Normalizes the image to zero mean and unit variance."""
+ image = tf.image.convert_image_dtype(image, dtype=tf.float32)
+ offset = tf.constant(offset)
+ offset = tf.expand_dims(offset, axis=0)
+ offset = tf.expand_dims(offset, axis=0)
+ image -= offset
+
+ scale = tf.constant(scale)
+ scale = tf.expand_dims(scale, axis=0)
+ scale = tf.expand_dims(scale, axis=0)
+ image /= scale
+ return image
+
+
+def compute_padded_size(desired_size, stride):
+ """Compute the padded size given the desired size and the stride.
+
+ The padded size will be the smallest rectangle, such that each dimension is
+ the smallest multiple of the stride which is larger than the desired
+ dimension. For example, if desired_size = (100, 200) and stride = 32,
+ the output padded_size = (128, 224).
+
+ Args:
+ desired_size: a `Tensor` or `int` list/tuple of two elements representing
+ [height, width] of the target output image size.
+ stride: an integer, the stride of the backbone network.
+
+ Returns:
+ padded_size: a `Tensor` or `int` list/tuple of two elements representing
+ [height, width] of the padded output image size.
+ """
+ if isinstance(desired_size, list) or isinstance(desired_size, tuple):
+ padded_size = [
+ int(math.ceil(d * 1.0 / stride) * stride) for d in desired_size
+ ]
+ else:
+ padded_size = tf.cast(
+ tf.math.ceil(tf.cast(desired_size, dtype=tf.float32) / stride) * stride,
+ tf.int32)
+ return padded_size
+
+
+def resize_and_crop_image(image,
+ desired_size,
+ padded_size,
+ aug_scale_min=1.0,
+ aug_scale_max=1.0,
+ seed=1,
+ method=tf.image.ResizeMethod.BILINEAR):
+ """Resizes the input image to output size.
+
+ Resize and pad images given the desired output size of the image and
+ stride size.
+
+ Here are the preprocessing steps.
+ 1. For a given image, keep its aspect ratio and rescale the image to make it
+ the largest rectangle to be bounded by the rectangle specified by the
+ `desired_size`.
+ 2. Pad the rescaled image to the padded_size.
+
+ Args:
+ image: a `Tensor` of shape [height, width, 3] representing an image.
+ desired_size: a `Tensor` or `int` list/tuple of two elements representing
+ [height, width] of the desired actual output image size.
+ padded_size: a `Tensor` or `int` list/tuple of two elements representing
+ [height, width] of the padded output image size. Padding will be applied
+ after scaling the image to the desired_size.
+ aug_scale_min: a `float` with range between [0, 1.0] representing minimum
+ random scale applied to desired_size for training scale jittering.
+ aug_scale_max: a `float` with range between [1.0, inf] representing maximum
+ random scale applied to desired_size for training scale jittering.
+ seed: seed for random scale jittering.
+ method: function to resize input image to scaled image.
+
+ Returns:
+ output_image: `Tensor` of shape [height, width, 3] where [height, width]
+ equals to `output_size`.
+ image_info: a 2D `Tensor` that encodes the information of the image and the
+ applied preprocessing. It is in the format of
+ [[original_height, original_width], [desired_height, desired_width],
+ [y_scale, x_scale], [y_offset, x_offset]], where [desired_height,
+ desireed_width] is the actual scaled image size, and [y_scale, x_scale] is
+ the scaling factory, which is the ratio of
+ scaled dimension / original dimension.
+ """
+ with tf.name_scope('resize_and_crop_image'):
+ image_size = tf.cast(tf.shape(input=image)[0:2], tf.float32)
+
+ random_jittering = (aug_scale_min != 1.0 or aug_scale_max != 1.0)
+
+ if random_jittering:
+ random_scale = tf.random.uniform([],
+ aug_scale_min,
+ aug_scale_max,
+ seed=seed)
+ scaled_size = tf.round(random_scale * desired_size)
+ else:
+ scaled_size = desired_size
+
+ scale = tf.minimum(scaled_size[0] / image_size[0],
+ scaled_size[1] / image_size[1])
+ scaled_size = tf.round(image_size * scale)
+
+ # Computes 2D image_scale.
+ image_scale = scaled_size / image_size
+
+ # Selects non-zero random offset (x, y) if scaled image is larger than
+ # desired_size.
+ if random_jittering:
+ max_offset = scaled_size - desired_size
+ max_offset = tf.where(
+ tf.less(max_offset, 0), tf.zeros_like(max_offset), max_offset)
+ offset = max_offset * tf.random.uniform([
+ 2,
+ ], 0, 1, seed=seed)
+ offset = tf.cast(offset, tf.int32)
+ else:
+ offset = tf.zeros((2,), tf.int32)
+
+ scaled_image = tf.image.resize(
+ image, tf.cast(scaled_size, tf.int32), method=method)
+
+ if random_jittering:
+ scaled_image = scaled_image[offset[0]:offset[0] + desired_size[0],
+ offset[1]:offset[1] + desired_size[1], :]
+
+ output_image = tf.image.pad_to_bounding_box(scaled_image, 0, 0,
+ padded_size[0], padded_size[1])
+
+ image_info = tf.stack([
+ image_size,
+ tf.cast(desired_size, dtype=tf.float32), image_scale,
+ tf.cast(offset, tf.float32)
+ ])
+ return output_image, image_info
+
+
+def resize_and_crop_image_v2(image,
+ short_side,
+ long_side,
+ padded_size,
+ aug_scale_min=1.0,
+ aug_scale_max=1.0,
+ seed=1,
+ method=tf.image.ResizeMethod.BILINEAR):
+ """Resizes the input image to output size (Faster R-CNN style).
+
+ Resize and pad images given the specified short / long side length and the
+ stride size.
+
+ Here are the preprocessing steps.
+ 1. For a given image, keep its aspect ratio and first try to rescale the short
+ side of the original image to `short_side`.
+ 2. If the scaled image after 1 has a long side that exceeds `long_side`, keep
+ the aspect ratio and rescal the long side of the image to `long_side`.
+ 2. Pad the rescaled image to the padded_size.
+
+ Args:
+ image: a `Tensor` of shape [height, width, 3] representing an image.
+ short_side: a scalar `Tensor` or `int` representing the desired short side
+ to be rescaled to.
+ long_side: a scalar `Tensor` or `int` representing the desired long side to
+ be rescaled to.
+ padded_size: a `Tensor` or `int` list/tuple of two elements representing
+ [height, width] of the padded output image size. Padding will be applied
+ after scaling the image to the desired_size.
+ aug_scale_min: a `float` with range between [0, 1.0] representing minimum
+ random scale applied to desired_size for training scale jittering.
+ aug_scale_max: a `float` with range between [1.0, inf] representing maximum
+ random scale applied to desired_size for training scale jittering.
+ seed: seed for random scale jittering.
+ method: function to resize input image to scaled image.
+
+ Returns:
+ output_image: `Tensor` of shape [height, width, 3] where [height, width]
+ equals to `output_size`.
+ image_info: a 2D `Tensor` that encodes the information of the image and the
+ applied preprocessing. It is in the format of
+ [[original_height, original_width], [desired_height, desired_width],
+ [y_scale, x_scale], [y_offset, x_offset]], where [desired_height,
+ desired_width] is the actual scaled image size, and [y_scale, x_scale] is
+ the scaling factor, which is the ratio of
+ scaled dimension / original dimension.
+ """
+ with tf.name_scope('resize_and_crop_image_v2'):
+ image_size = tf.cast(tf.shape(image)[0:2], tf.float32)
+
+ scale_using_short_side = (
+ short_side / tf.math.minimum(image_size[0], image_size[1]))
+ scale_using_long_side = (
+ long_side / tf.math.maximum(image_size[0], image_size[1]))
+
+ scaled_size = tf.math.round(image_size * scale_using_short_side)
+ scaled_size = tf.where(
+ tf.math.greater(
+ tf.math.maximum(scaled_size[0], scaled_size[1]), long_side),
+ tf.math.round(image_size * scale_using_long_side), scaled_size)
+ desired_size = scaled_size
+
+ random_jittering = (aug_scale_min != 1.0 or aug_scale_max != 1.0)
+
+ if random_jittering:
+ random_scale = tf.random.uniform([],
+ aug_scale_min,
+ aug_scale_max,
+ seed=seed)
+ scaled_size = tf.math.round(random_scale * scaled_size)
+
+ # Computes 2D image_scale.
+ image_scale = scaled_size / image_size
+
+ # Selects non-zero random offset (x, y) if scaled image is larger than
+ # desired_size.
+ if random_jittering:
+ max_offset = scaled_size - desired_size
+ max_offset = tf.where(
+ tf.math.less(max_offset, 0), tf.zeros_like(max_offset), max_offset)
+ offset = max_offset * tf.random.uniform([
+ 2,
+ ], 0, 1, seed=seed)
+ offset = tf.cast(offset, tf.int32)
+ else:
+ offset = tf.zeros((2,), tf.int32)
+
+ scaled_image = tf.image.resize(
+ image, tf.cast(scaled_size, tf.int32), method=method)
+
+ if random_jittering:
+ scaled_image = scaled_image[offset[0]:offset[0] + desired_size[0],
+ offset[1]:offset[1] + desired_size[1], :]
+
+ output_image = tf.image.pad_to_bounding_box(scaled_image, 0, 0,
+ padded_size[0], padded_size[1])
+
+ image_info = tf.stack([
+ image_size,
+ tf.cast(desired_size, dtype=tf.float32), image_scale,
+ tf.cast(offset, tf.float32)
+ ])
+ return output_image, image_info
+
+
+def resize_and_crop_boxes(boxes, image_scale, output_size, offset):
+ """Resizes boxes to output size with scale and offset.
+
+ Args:
+ boxes: `Tensor` of shape [N, 4] representing ground truth boxes.
+ image_scale: 2D float `Tensor` representing scale factors that apply to
+ [height, width] of input image.
+ output_size: 2D `Tensor` or `int` representing [height, width] of target
+ output image size.
+ offset: 2D `Tensor` representing top-left corner [y0, x0] to crop scaled
+ boxes.
+
+ Returns:
+ boxes: `Tensor` of shape [N, 4] representing the scaled boxes.
+ """
+ # Adjusts box coordinates based on image_scale and offset.
+ boxes *= tf.tile(tf.expand_dims(image_scale, axis=0), [1, 2])
+ boxes -= tf.tile(tf.expand_dims(offset, axis=0), [1, 2])
+ # Clips the boxes.
+ boxes = box_utils.clip_boxes(boxes, output_size)
+ return boxes
+
+
+def resize_and_crop_masks(masks, image_scale, output_size, offset):
+ """Resizes boxes to output size with scale and offset.
+
+ Args:
+ masks: `Tensor` of shape [N, H, W, 1] representing ground truth masks.
+ image_scale: 2D float `Tensor` representing scale factors that apply to
+ [height, width] of input image.
+ output_size: 2D `Tensor` or `int` representing [height, width] of target
+ output image size.
+ offset: 2D `Tensor` representing top-left corner [y0, x0] to crop scaled
+ boxes.
+
+ Returns:
+ masks: `Tensor` of shape [N, H, W, 1] representing the scaled masks.
+ """
+ mask_size = tf.shape(input=masks)[1:3]
+ scaled_size = tf.cast(image_scale * tf.cast(mask_size, image_scale.dtype),
+ tf.int32)
+ scaled_masks = tf.image.resize(
+ masks, scaled_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
+ offset = tf.cast(offset, tf.int32)
+ scaled_masks = scaled_masks[:, offset[0]:offset[0] + output_size[0],
+ offset[1]:offset[1] + output_size[1], :]
+
+ output_masks = tf.image.pad_to_bounding_box(scaled_masks, 0, 0,
+ output_size[0], output_size[1])
+ return output_masks
+
+
+def random_horizontal_flip(image, boxes=None, masks=None):
+ """Randomly flips input image and bounding boxes."""
+ return preprocessor.random_horizontal_flip(image, boxes, masks)
diff --git a/modeling/official/legacy/detection/utils/mask_utils.py b/modeling/official/legacy/detection/utils/mask_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d6ff18d9ee328f97fb2c19dc026fad9d371c6a3
--- /dev/null
+++ b/modeling/official/legacy/detection/utils/mask_utils.py
@@ -0,0 +1,171 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility functions for segmentations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import cv2
+import numpy as np
+
+
+def paste_instance_masks(masks, detected_boxes, image_height, image_width):
+ """Paste instance masks to generate the image segmentation results.
+
+ Args:
+ masks: a numpy array of shape [N, mask_height, mask_width] representing the
+ instance masks w.r.t. the `detected_boxes`.
+ detected_boxes: a numpy array of shape [N, 4] representing the reference
+ bounding boxes.
+ image_height: an integer representing the height of the image.
+ image_width: an integer representing the width of the image.
+
+ Returns:
+ segms: a numpy array of shape [N, image_height, image_width] representing
+ the instance masks *pasted* on the image canvas.
+ """
+
+ def expand_boxes(boxes, scale):
+ """Expands an array of boxes by a given scale."""
+ # Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/boxes.py#L227 # pylint: disable=line-too-long
+ # The `boxes` in the reference implementation is in [x1, y1, x2, y2] form,
+ # whereas `boxes` here is in [x1, y1, w, h] form
+ w_half = boxes[:, 2] * .5
+ h_half = boxes[:, 3] * .5
+ x_c = boxes[:, 0] + w_half
+ y_c = boxes[:, 1] + h_half
+
+ w_half *= scale
+ h_half *= scale
+
+ boxes_exp = np.zeros(boxes.shape)
+ boxes_exp[:, 0] = x_c - w_half
+ boxes_exp[:, 2] = x_c + w_half
+ boxes_exp[:, 1] = y_c - h_half
+ boxes_exp[:, 3] = y_c + h_half
+
+ return boxes_exp
+
+ # Reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/core/test.py#L812 # pylint: disable=line-too-long
+ # To work around an issue with cv2.resize (it seems to automatically pad
+ # with repeated border values), we manually zero-pad the masks by 1 pixel
+ # prior to resizing back to the original image resolution. This prevents
+ # "top hat" artifacts. We therefore need to expand the reference boxes by an
+ # appropriate factor.
+ _, mask_height, mask_width = masks.shape
+ scale = max((mask_width + 2.0) / mask_width,
+ (mask_height + 2.0) / mask_height)
+
+ ref_boxes = expand_boxes(detected_boxes, scale)
+ ref_boxes = ref_boxes.astype(np.int32)
+ padded_mask = np.zeros((mask_height + 2, mask_width + 2), dtype=np.float32)
+ segms = []
+ for mask_ind, mask in enumerate(masks):
+ im_mask = np.zeros((image_height, image_width), dtype=np.uint8)
+ # Process mask inside bounding boxes.
+ padded_mask[1:-1, 1:-1] = mask[:, :]
+
+ ref_box = ref_boxes[mask_ind, :]
+ w = ref_box[2] - ref_box[0] + 1
+ h = ref_box[3] - ref_box[1] + 1
+ w = np.maximum(w, 1)
+ h = np.maximum(h, 1)
+
+ mask = cv2.resize(padded_mask, (w, h))
+ mask = np.array(mask > 0.5, dtype=np.uint8)
+
+ x_0 = min(max(ref_box[0], 0), image_width)
+ x_1 = min(max(ref_box[2] + 1, 0), image_width)
+ y_0 = min(max(ref_box[1], 0), image_height)
+ y_1 = min(max(ref_box[3] + 1, 0), image_height)
+
+ im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - ref_box[1]):(y_1 - ref_box[1]),
+ (x_0 - ref_box[0]):(x_1 - ref_box[0])]
+ segms.append(im_mask)
+
+ segms = np.array(segms)
+ assert masks.shape[0] == segms.shape[0]
+ return segms
+
+
+def paste_instance_masks_v2(masks, detected_boxes, image_height, image_width):
+ """Paste instance masks to generate the image segmentation (v2).
+
+ Args:
+ masks: a numpy array of shape [N, mask_height, mask_width] representing the
+ instance masks w.r.t. the `detected_boxes`.
+ detected_boxes: a numpy array of shape [N, 4] representing the reference
+ bounding boxes.
+ image_height: an integer representing the height of the image.
+ image_width: an integer representing the width of the image.
+
+ Returns:
+ segms: a numpy array of shape [N, image_height, image_width] representing
+ the instance masks *pasted* on the image canvas.
+ """
+ _, mask_height, mask_width = masks.shape
+
+ segms = []
+ for i, mask in enumerate(masks):
+ box = detected_boxes[i, :]
+ xmin = box[0]
+ ymin = box[1]
+ xmax = xmin + box[2]
+ ymax = ymin + box[3]
+
+ # Sample points of the cropped mask w.r.t. the image grid.
+ # Note that these coordinates may fall beyond the image.
+ # Pixel clipping will happen after warping.
+ xmin_int = int(math.floor(xmin))
+ xmax_int = int(math.ceil(xmax))
+ ymin_int = int(math.floor(ymin))
+ ymax_int = int(math.ceil(ymax))
+
+ alpha = box[2] / (1.0 * mask_width)
+ beta = box[3] / (1.0 * mask_height)
+ # pylint: disable=invalid-name
+ # Transformation from mask pixel indices to image coordinate.
+ M_mask_to_image = np.array([[alpha, 0, xmin], [0, beta, ymin], [0, 0, 1]],
+ dtype=np.float32)
+ # Transformation from image to cropped mask coordinate.
+ M_image_to_crop = np.array(
+ [[1, 0, -xmin_int], [0, 1, -ymin_int], [0, 0, 1]], dtype=np.float32)
+ M = np.dot(M_image_to_crop, M_mask_to_image)
+ # Compensate the half pixel offset that OpenCV has in the
+ # warpPerspective implementation: the top-left pixel is sampled
+ # at (0,0), but we want it to be at (0.5, 0.5).
+ M = np.dot(
+ np.dot(
+ np.array([[1, 0, -0.5], [0, 1, -0.5], [0, 0, 1]], np.float32), M),
+ np.array([[1, 0, 0.5], [0, 1, 0.5], [0, 0, 1]], np.float32))
+ # pylint: enable=invalid-name
+ cropped_mask = cv2.warpPerspective(
+ mask.astype(np.float32), M, (xmax_int - xmin_int, ymax_int - ymin_int))
+ cropped_mask = np.array(cropped_mask > 0.5, dtype=np.uint8)
+
+ img_mask = np.zeros((image_height, image_width))
+ x0 = max(min(xmin_int, image_width), 0)
+ x1 = max(min(xmax_int, image_width), 0)
+ y0 = max(min(ymin_int, image_height), 0)
+ y1 = max(min(ymax_int, image_height), 0)
+ img_mask[y0:y1, x0:x1] = cropped_mask[(y0 - ymin_int):(y1 - ymin_int),
+ (x0 - xmin_int):(x1 - xmin_int)]
+
+ segms.append(img_mask)
+
+ segms = np.array(segms)
+ return segms
diff --git a/modeling/official/legacy/image_classification/README.md b/modeling/official/legacy/image_classification/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6d9231b4848f55a69395d047f3ec33674ac59599
--- /dev/null
+++ b/modeling/official/legacy/image_classification/README.md
@@ -0,0 +1,222 @@
+# Image Classification
+
+**Warning:** the features in the `image_classification/` directory have been
+fully integrated into the [new code base](https://github.com/tensorflow/models/tree/benchmark/official/vision/modeling/backbones).
+
+This folder contains TF 2 model examples for image classification:
+
+* [MNIST](#mnist)
+* [Classifier Trainer](#classifier-trainer), a framework that uses the Keras
+compile/fit methods for image classification models, including:
+ * ResNet
+ * EfficientNet[^1]
+
+[^1]: Currently a work in progress. We cannot match "AutoAugment (AA)" in [the original version](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet).
+For more information about other types of models, please refer to this
+[README file](../../README.md).
+
+## Before you begin
+Please make sure that you have the latest version of TensorFlow
+installed and add the models folder to your Python path.
+
+### ImageNet preparation
+
+#### Using TFDS
+`classifier_trainer.py` supports ImageNet with
+[TensorFlow Datasets (TFDS)](https://www.tensorflow.org/datasets/overview).
+
+Please see the following [example snippet](https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/scripts/download_and_prepare.py)
+for more information on how to use TFDS to download and prepare datasets, and
+specifically the [TFDS ImageNet readme](https://github.com/tensorflow/datasets/blob/master/docs/catalog/imagenet2012.md)
+for manual download instructions.
+
+#### Legacy TFRecords
+Download the ImageNet dataset and convert it to TFRecord format.
+The following [script](https://github.com/tensorflow/tpu/blob/master/tools/datasets/imagenet_to_gcs.py)
+and [README](https://github.com/tensorflow/tpu/tree/master/tools/datasets#imagenet_to_gcspy)
+provide a few options.
+
+Note that the legacy ResNet runners, e.g. [resnet/resnet_ctl_imagenet_main.py](resnet/resnet_ctl_imagenet_main.py)
+require TFRecords whereas `classifier_trainer.py` can use both by setting the
+builder to 'records' or 'tfds' in the configurations.
+
+### Running on Cloud TPUs
+
+Note: These models will **not** work with TPUs on Colab.
+
+You can train image classification models on Cloud TPUs using
+[tf.distribute.TPUStrategy](https://www.tensorflow.org/api_docs/python/tf.distribute.TPUStrategy?version=nightly).
+If you are not familiar with Cloud TPUs, it is strongly recommended that you go
+through the
+[quickstart](https://cloud.google.com/tpu/docs/quickstart) to learn how to
+create a TPU and GCE VM.
+
+### Running on multiple GPU hosts
+
+You can also train these models on multiple hosts, each with GPUs, using
+[tf.distribute.experimental.MultiWorkerMirroredStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy).
+
+The easiest way to run multi-host benchmarks is to set the
+[`TF_CONFIG`](https://www.tensorflow.org/guide/distributed_training#TF_CONFIG)
+appropriately at each host. e.g., to run using `MultiWorkerMirroredStrategy` on
+2 hosts, the `cluster` in `TF_CONFIG` should have 2 `host:port` entries, and
+host `i` should have the `task` in `TF_CONFIG` set to `{"type": "worker",
+"index": i}`. `MultiWorkerMirroredStrategy` will automatically use all the
+available GPUs at each host.
+
+## MNIST
+
+To download the data and run the MNIST sample model locally for the first time,
+run one of the following command:
+
+
+```bash
+python3 mnist_main.py \
+ --model_dir=$MODEL_DIR \
+ --data_dir=$DATA_DIR \
+ --train_epochs=10 \
+ --distribution_strategy=one_device \
+ --num_gpus=$NUM_GPUS \
+ --download
+```
+
+
+To train the model on a Cloud TPU, run the following command:
+
+
+```bash
+python3 mnist_main.py \
+ --tpu=$TPU_NAME \
+ --model_dir=$MODEL_DIR \
+ --data_dir=$DATA_DIR \
+ --train_epochs=10 \
+ --distribution_strategy=tpu \
+ --download
+```
+
+
+Note: the `--download` flag is only required the first time you run the model.
+
+## Classifier Trainer
+The classifier trainer is a unified framework for running image classification
+models using Keras's compile/fit methods. Experiments should be provided in the
+form of YAML files, some examples are included within the configs/examples
+folder. Please see [configs/examples](./configs/examples) for more example
+configurations.
+
+The provided configuration files use a per replica batch size and is scaled
+by the number of devices. For instance, if `batch size` = 64, then for 1 GPU
+the global batch size would be 64 * 1 = 64. For 8 GPUs, the global batch size
+would be 64 * 8 = 512. Similarly, for a v3-8 TPU, the global batch size would
+be 64 * 8 = 512, and for a v3-32, the global batch size is 64 * 32 = 2048.
+
+### ResNet50
+
+#### On GPU:
+
+
+```bash
+python3 classifier_trainer.py \
+ --mode=train_and_eval \
+ --model_type=resnet \
+ --dataset=imagenet \
+ --model_dir=$MODEL_DIR \
+ --data_dir=$DATA_DIR \
+ --config_file=configs/examples/resnet/imagenet/gpu.yaml \
+ --params_override='runtime.num_gpus=$NUM_GPUS'
+```
+
+
+To train on multiple hosts, each with GPUs attached using
+[MultiWorkerMirroredStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/MultiWorkerMirroredStrategy)
+please update `runtime` section in gpu.yaml
+(or override using `--params_override`) with:
+
+
+
+```YAML
+# gpu.yaml
+runtime:
+ distribution_strategy: 'multi_worker_mirrored'
+ worker_hosts: '$HOST1:port,$HOST2:port'
+ num_gpus: $NUM_GPUS
+ task_index: 0
+```
+
+
+By having `task_index: 0` on the first host and `task_index: 1` on the second
+and so on. `$HOST1` and `$HOST2` are the IP addresses of the hosts, and `port`
+can be chosen any free port on the hosts. Only the first host will write
+TensorBoard Summaries and save checkpoints.
+
+#### On TPU:
+
+
+```bash
+python3 classifier_trainer.py \
+ --mode=train_and_eval \
+ --model_type=resnet \
+ --dataset=imagenet \
+ --tpu=$TPU_NAME \
+ --model_dir=$MODEL_DIR \
+ --data_dir=$DATA_DIR \
+ --config_file=configs/examples/resnet/imagenet/tpu.yaml
+```
+
+
+
+### VGG-16
+
+#### On GPU:
+
+
+```bash
+python3 classifier_trainer.py \
+ --mode=train_and_eval \
+ --model_type=vgg \
+ --dataset=imagenet \
+ --model_dir=$MODEL_DIR \
+ --data_dir=$DATA_DIR \
+ --config_file=configs/examples/vgg/imagenet/gpu.yaml \
+ --params_override='runtime.num_gpus=$NUM_GPUS'
+```
+
+
+
+### EfficientNet
+**Note: EfficientNet development is a work in progress.**
+#### On GPU:
+
+
+```bash
+python3 classifier_trainer.py \
+ --mode=train_and_eval \
+ --model_type=efficientnet \
+ --dataset=imagenet \
+ --model_dir=$MODEL_DIR \
+ --data_dir=$DATA_DIR \
+ --config_file=configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml \
+ --params_override='runtime.num_gpus=$NUM_GPUS'
+```
+
+
+
+#### On TPU:
+
+
+```bash
+python3 classifier_trainer.py \
+ --mode=train_and_eval \
+ --model_type=efficientnet \
+ --dataset=imagenet \
+ --tpu=$TPU_NAME \
+ --model_dir=$MODEL_DIR \
+ --data_dir=$DATA_DIR \
+ --config_file=configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml
+```
+
+
+Note that the number of GPU devices can be overridden in the command line using
+`--params_overrides`. The TPU does not need this override as the device is fixed
+by providing the TPU address or name with the `--tpu` flag.
+
diff --git a/modeling/official/legacy/image_classification/__init__.py b/modeling/official/legacy/image_classification/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/legacy/image_classification/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/legacy/image_classification/augment.py b/modeling/official/legacy/image_classification/augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b677b08209f6f6a0589d4331d65e4b72668efe9
--- /dev/null
+++ b/modeling/official/legacy/image_classification/augment.py
@@ -0,0 +1,1061 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""AutoAugment and RandAugment policies for enhanced image preprocessing.
+
+AutoAugment Reference: https://arxiv.org/abs/1805.09501
+RandAugment Reference: https://arxiv.org/abs/1909.13719
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+from typing import Any, Dict, List, Optional, Text, Tuple
+
+import tensorflow as tf, tf_keras
+
+
+# This signifies the max integer that the controller RNN could predict for the
+# augmentation scheme.
+_MAX_LEVEL = 10.
+
+
+def to_4d(image: tf.Tensor) -> tf.Tensor:
+ """Converts an input Tensor to 4 dimensions.
+
+ 4D image => [N, H, W, C] or [N, C, H, W]
+ 3D image => [1, H, W, C] or [1, C, H, W]
+ 2D image => [1, H, W, 1]
+
+ Args:
+ image: The 2/3/4D input tensor.
+
+ Returns:
+ A 4D image tensor.
+
+ Raises:
+ `TypeError` if `image` is not a 2/3/4D tensor.
+
+ """
+ shape = tf.shape(image)
+ original_rank = tf.rank(image)
+ left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32)
+ right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32)
+ new_shape = tf.concat(
+ [
+ tf.ones(shape=left_pad, dtype=tf.int32),
+ shape,
+ tf.ones(shape=right_pad, dtype=tf.int32),
+ ],
+ axis=0,
+ )
+ return tf.reshape(image, new_shape)
+
+
+def from_4d(image: tf.Tensor, ndims: tf.Tensor) -> tf.Tensor:
+ """Converts a 4D image back to `ndims` rank."""
+ shape = tf.shape(image)
+ begin = tf.cast(tf.less_equal(ndims, 3), dtype=tf.int32)
+ end = 4 - tf.cast(tf.equal(ndims, 2), dtype=tf.int32)
+ new_shape = shape[begin:end]
+ return tf.reshape(image, new_shape)
+
+
+def _convert_translation_to_transform(translations: tf.Tensor) -> tf.Tensor:
+ """Converts translations to a projective transform.
+
+ The translation matrix looks like this:
+ [[1 0 -dx]
+ [0 1 -dy]
+ [0 0 1]]
+
+ Args:
+ translations: The 2-element list representing [dx, dy], or a matrix of
+ 2-element lists representing [dx dy] to translate for each image. The
+ shape must be static.
+
+ Returns:
+ The transformation matrix of shape (num_images, 8).
+
+ Raises:
+ `TypeError` if
+ - the shape of `translations` is not known or
+ - the shape of `translations` is not rank 1 or 2.
+
+ """
+ translations = tf.convert_to_tensor(translations, dtype=tf.float32)
+ if translations.get_shape().ndims is None:
+ raise TypeError('translations rank must be statically known')
+ elif len(translations.get_shape()) == 1:
+ translations = translations[None]
+ elif len(translations.get_shape()) != 2:
+ raise TypeError('translations should have rank 1 or 2.')
+ num_translations = tf.shape(translations)[0]
+
+ return tf.concat(
+ values=[
+ tf.ones((num_translations, 1), tf.dtypes.float32),
+ tf.zeros((num_translations, 1), tf.dtypes.float32),
+ -translations[:, 0, None],
+ tf.zeros((num_translations, 1), tf.dtypes.float32),
+ tf.ones((num_translations, 1), tf.dtypes.float32),
+ -translations[:, 1, None],
+ tf.zeros((num_translations, 2), tf.dtypes.float32),
+ ],
+ axis=1,
+ )
+
+
+def _convert_angles_to_transform(angles: tf.Tensor, image_width: tf.Tensor,
+ image_height: tf.Tensor) -> tf.Tensor:
+ """Converts an angle or angles to a projective transform.
+
+ Args:
+ angles: A scalar to rotate all images, or a vector to rotate a batch of
+ images. This must be a scalar.
+ image_width: The width of the image(s) to be transformed.
+ image_height: The height of the image(s) to be transformed.
+
+ Returns:
+ A tensor of shape (num_images, 8).
+
+ Raises:
+ `TypeError` if `angles` is not rank 0 or 1.
+
+ """
+ angles = tf.convert_to_tensor(angles, dtype=tf.float32)
+ if len(angles.get_shape()) == 0: # pylint:disable=g-explicit-length-test
+ angles = angles[None]
+ elif len(angles.get_shape()) != 1:
+ raise TypeError('Angles should have a rank 0 or 1.')
+ x_offset = ((image_width - 1) -
+ (tf.math.cos(angles) * (image_width - 1) - tf.math.sin(angles) *
+ (image_height - 1))) / 2.0
+ y_offset = ((image_height - 1) -
+ (tf.math.sin(angles) * (image_width - 1) + tf.math.cos(angles) *
+ (image_height - 1))) / 2.0
+ num_angles = tf.shape(angles)[0]
+ return tf.concat(
+ values=[
+ tf.math.cos(angles)[:, None],
+ -tf.math.sin(angles)[:, None],
+ x_offset[:, None],
+ tf.math.sin(angles)[:, None],
+ tf.math.cos(angles)[:, None],
+ y_offset[:, None],
+ tf.zeros((num_angles, 2), tf.dtypes.float32),
+ ],
+ axis=1,
+ )
+
+
+def apply_transform_to_images(
+ images,
+ transforms,
+ fill_mode='reflect',
+ fill_value=0.0,
+ interpolation='bilinear',
+ output_shape=None,
+ name=None,
+):
+ """Applies the given transform(s) to the image(s).
+
+ Args:
+ images: A tensor of shape `(num_images, num_rows, num_columns,
+ num_channels)` (NHWC). The rank must be statically known (the shape is
+ not `TensorShape(None)`).
+ transforms: Projective transform matrix/matrices. A vector of length 8 or
+ tensor of size N x 8. If one row of transforms is [a0, a1, a2, b0, b1,
+ b2, c0, c1], then it maps the *output* point `(x, y)` to a transformed
+ *input* point `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) /
+ k)`, where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared
+ to the transform mapping input points to output points. Note that
+ gradients are not backpropagated into transformation parameters.
+ fill_mode: Points outside the boundaries of the input are filled according
+ to the given mode (one of `{"constant", "reflect", "wrap", "nearest"}`).
+ fill_value: a float represents the value to be filled outside the
+ boundaries when `fill_mode="constant"`.
+ interpolation: Interpolation mode. Supported values: `"nearest"`,
+ `"bilinear"`.
+ output_shape: Output dimension after the transform, `[height, width]`. If
+ `None`, output is the same size as input image.
+ name: The name of the op. Fill mode behavior for each valid value is as
+ follows
+ - `"reflect"`: `(d c b a | a b c d | d c b a)` The input is extended by
+ reflecting about the edge of the last pixel.
+ - `"constant"`: `(k k k k | a b c d | k k k k)` The input is extended by
+ filling all values beyond the edge with the same constant value k = 0.
+ - `"wrap"`: `(a b c d | a b c d | a b c d)` The input is extended by
+ wrapping around to the opposite edge.
+ - `"nearest"`: `(a a a a | a b c d | d d d d)` The input is extended by
+ the nearest pixel. Input shape: 4D tensor with shape:
+ `(samples, height, width, channels)`, in `"channels_last"` format.
+ Output shape: 4D tensor with shape: `(samples, height, width, channels)`,
+ in `"channels_last"` format.
+
+ Returns:
+ Image(s) with the same type and shape as `images`, with the given
+ transform(s) applied. Transformed coordinates outside of the input image
+ will be filled with zeros.
+ """
+ with tf.name_scope(name or 'transform'):
+ if output_shape is None:
+ output_shape = tf.shape(images)[1:3]
+ if not tf.executing_eagerly():
+ output_shape_value = tf.get_static_value(output_shape)
+ if output_shape_value is not None:
+ output_shape = output_shape_value
+
+ output_shape = tf.convert_to_tensor(
+ output_shape, tf.int32, name='output_shape'
+ )
+
+ if not output_shape.get_shape().is_compatible_with([2]):
+ raise ValueError(
+ 'output_shape must be a 1-D Tensor of 2 elements: '
+ 'new_height, new_width, instead got '
+ f'output_shape={output_shape}'
+ )
+
+ fill_value = tf.convert_to_tensor(fill_value, tf.float32, name='fill_value')
+
+ return tf.raw_ops.ImageProjectiveTransformV3(
+ images=images,
+ output_shape=output_shape,
+ fill_value=fill_value,
+ transforms=transforms,
+ fill_mode=fill_mode.upper(),
+ interpolation=interpolation.upper(),
+ )
+
+
+def transform(image: tf.Tensor, transforms) -> tf.Tensor:
+ """Prepares input data for `image_ops.transform`."""
+ original_ndims = tf.rank(image)
+ transforms = tf.convert_to_tensor(transforms, dtype=tf.float32)
+ if transforms.shape.rank == 1:
+ transforms = transforms[None]
+ image = to_4d(image)
+ image = apply_transform_to_images(
+ images=image, transforms=transforms, interpolation='nearest'
+ )
+ return from_4d(image, original_ndims)
+
+
+def translate(image: tf.Tensor, translations) -> tf.Tensor:
+ """Translates image(s) by provided vectors.
+
+ Args:
+ image: An image Tensor of type uint8.
+ translations: A vector or matrix representing [dx dy].
+
+ Returns:
+ The translated version of the image.
+
+ """
+ transforms = _convert_translation_to_transform(translations) # pytype: disable=wrong-arg-types # always-use-return-annotations
+ return transform(image, transforms=transforms)
+
+
+def rotate(image: tf.Tensor, degrees: float) -> tf.Tensor:
+ """Rotates the image by degrees either clockwise or counterclockwise.
+
+ Args:
+ image: An image Tensor of type uint8.
+ degrees: Float, a scalar angle in degrees to rotate all images by. If
+ degrees is positive the image will be rotated clockwise otherwise it will
+ be rotated counterclockwise.
+
+ Returns:
+ The rotated version of image.
+
+ """
+ # Convert from degrees to radians.
+ degrees_to_radians = math.pi / 180.0
+ radians = tf.cast(degrees * degrees_to_radians, tf.float32)
+
+ original_ndims = tf.rank(image)
+ image = to_4d(image)
+
+ image_height = tf.cast(tf.shape(image)[1], tf.float32)
+ image_width = tf.cast(tf.shape(image)[2], tf.float32)
+ transforms = _convert_angles_to_transform(
+ angles=radians, image_width=image_width, image_height=image_height)
+ # In practice, we should randomize the rotation degrees by flipping
+ # it negatively half the time, but that's done on 'degrees' outside
+ # of the function.
+ image = transform(image, transforms=transforms)
+ return from_4d(image, original_ndims)
+
+
+def blend(image1: tf.Tensor, image2: tf.Tensor, factor: float) -> tf.Tensor:
+ """Blend image1 and image2 using 'factor'.
+
+ Factor can be above 0.0. A value of 0.0 means only image1 is used.
+ A value of 1.0 means only image2 is used. A value between 0.0 and
+ 1.0 means we linearly interpolate the pixel values between the two
+ images. A value greater than 1.0 "extrapolates" the difference
+ between the two pixel values, and we clip the results to values
+ between 0 and 255.
+
+ Args:
+ image1: An image Tensor of type uint8.
+ image2: An image Tensor of type uint8.
+ factor: A floating point value above 0.0.
+
+ Returns:
+ A blended image Tensor of type uint8.
+ """
+ if factor == 0.0:
+ return tf.convert_to_tensor(image1)
+ if factor == 1.0:
+ return tf.convert_to_tensor(image2)
+
+ image1 = tf.cast(image1, tf.float32)
+ image2 = tf.cast(image2, tf.float32)
+
+ difference = image2 - image1
+ scaled = factor * difference
+
+ # Do addition in float.
+ temp = tf.cast(image1, tf.float32) + scaled
+
+ # Interpolate
+ if factor > 0.0 and factor < 1.0:
+ # Interpolation means we always stay within 0 and 255.
+ return tf.cast(temp, tf.uint8)
+
+ # Extrapolate:
+ #
+ # We need to clip and then cast.
+ return tf.cast(tf.clip_by_value(temp, 0.0, 255.0), tf.uint8)
+
+
+def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
+ """Apply cutout (https://arxiv.org/abs/1708.04552) to image.
+
+ This operation applies a (2*pad_size x 2*pad_size) mask of zeros to
+ a random location within `img`. The pixel values filled in will be of the
+ value `replace`. The located where the mask will be applied is randomly
+ chosen uniformly over the whole image.
+
+ Args:
+ image: An image Tensor of type uint8.
+ pad_size: Specifies how big the zero mask that will be generated is that is
+ applied to the image. The mask will be of size (2*pad_size x 2*pad_size).
+ replace: What pixel value to fill in the image in the area that has the
+ cutout mask applied to it.
+
+ Returns:
+ An image Tensor that is of type uint8.
+ """
+ image_height = tf.shape(image)[0]
+ image_width = tf.shape(image)[1]
+
+ # Sample the center location in the image where the zero mask will be applied.
+ cutout_center_height = tf.random.uniform(
+ shape=[], minval=0, maxval=image_height, dtype=tf.int32)
+
+ cutout_center_width = tf.random.uniform(
+ shape=[], minval=0, maxval=image_width, dtype=tf.int32)
+
+ lower_pad = tf.maximum(0, cutout_center_height - pad_size)
+ upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size)
+ left_pad = tf.maximum(0, cutout_center_width - pad_size)
+ right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size)
+
+ cutout_shape = [
+ image_height - (lower_pad + upper_pad),
+ image_width - (left_pad + right_pad)
+ ]
+ padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
+ mask = tf.pad(
+ tf.zeros(cutout_shape, dtype=image.dtype),
+ padding_dims,
+ constant_values=1)
+ mask = tf.expand_dims(mask, -1)
+ mask = tf.tile(mask, [1, 1, 3])
+ image = tf.where(
+ tf.equal(mask, 0),
+ tf.ones_like(image, dtype=image.dtype) * replace, image)
+ return image
+
+
+def solarize(image: tf.Tensor, threshold: int = 128) -> tf.Tensor:
+ # For each pixel in the image, select the pixel
+ # if the value is less than the threshold.
+ # Otherwise, subtract 255 from the pixel.
+ return tf.where(image < threshold, image, 255 - image)
+
+
+def solarize_add(image: tf.Tensor,
+ addition: int = 0,
+ threshold: int = 128) -> tf.Tensor:
+ # For each pixel in the image less than threshold
+ # we add 'addition' amount to it and then clip the
+ # pixel value to be between 0 and 255. The value
+ # of 'addition' is between -128 and 128.
+ added_image = tf.cast(image, tf.int64) + addition
+ added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8)
+ return tf.where(image < threshold, added_image, image)
+
+
+def color(image: tf.Tensor, factor: float) -> tf.Tensor:
+ """Equivalent of PIL Color."""
+ degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))
+ return blend(degenerate, image, factor)
+
+
+def contrast(image: tf.Tensor, factor: float) -> tf.Tensor:
+ """Equivalent of PIL Contrast."""
+ degenerate = tf.image.rgb_to_grayscale(image)
+ # Cast before calling tf.histogram.
+ degenerate = tf.cast(degenerate, tf.int32)
+
+ # Compute the grayscale histogram, then compute the mean pixel value,
+ # and create a constant image size of that value. Use that as the
+ # blending degenerate target of the original image.
+ hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256)
+ mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0
+ degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean
+ degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
+ degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8))
+ return blend(degenerate, image, factor)
+
+
+def brightness(image: tf.Tensor, factor: float) -> tf.Tensor:
+ """Equivalent of PIL Brightness."""
+ degenerate = tf.zeros_like(image)
+ return blend(degenerate, image, factor)
+
+
+def posterize(image: tf.Tensor, bits: int) -> tf.Tensor:
+ """Equivalent of PIL Posterize."""
+ shift = 8 - bits
+ return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift)
+
+
+def wrapped_rotate(image: tf.Tensor, degrees: float, replace: int) -> tf.Tensor:
+ """Applies rotation with wrap/unwrap."""
+ image = rotate(wrap(image), degrees=degrees)
+ return unwrap(image, replace)
+
+
+def translate_x(image: tf.Tensor, pixels: int, replace: int) -> tf.Tensor:
+ """Equivalent of PIL Translate in X dimension."""
+ image = translate(wrap(image), [-pixels, 0])
+ return unwrap(image, replace)
+
+
+def translate_y(image: tf.Tensor, pixels: int, replace: int) -> tf.Tensor:
+ """Equivalent of PIL Translate in Y dimension."""
+ image = translate(wrap(image), [0, -pixels])
+ return unwrap(image, replace)
+
+
+def shear_x(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
+ """Equivalent of PIL Shearing in X dimension."""
+ # Shear parallel to x axis is a projective transform
+ # with a matrix form of:
+ # [1 level
+ # 0 1].
+ image = transform(
+ image=wrap(image), transforms=[1., level, 0., 0., 1., 0., 0., 0.])
+ return unwrap(image, replace)
+
+
+def shear_y(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
+ """Equivalent of PIL Shearing in Y dimension."""
+ # Shear parallel to y axis is a projective transform
+ # with a matrix form of:
+ # [1 0
+ # level 1].
+ image = transform(
+ image=wrap(image), transforms=[1., 0., 0., level, 1., 0., 0., 0.])
+ return unwrap(image, replace)
+
+
+def autocontrast(image: tf.Tensor) -> tf.Tensor:
+ """Implements Autocontrast function from PIL using TF ops.
+
+ Args:
+ image: A 3D uint8 tensor.
+
+ Returns:
+ The image after it has had autocontrast applied to it and will be of type
+ uint8.
+ """
+
+ def scale_channel(image: tf.Tensor) -> tf.Tensor:
+ """Scale the 2D image using the autocontrast rule."""
+ # A possibly cheaper version can be done using cumsum/unique_with_counts
+ # over the histogram values, rather than iterating over the entire image.
+ # to compute mins and maxes.
+ lo = tf.cast(tf.reduce_min(image), tf.float32)
+ hi = tf.cast(tf.reduce_max(image), tf.float32)
+
+ # Scale the image, making the lowest value 0 and the highest value 255.
+ def scale_values(im):
+ scale = 255.0 / (hi - lo)
+ offset = -lo * scale
+ im = tf.cast(im, tf.float32) * scale + offset
+ im = tf.clip_by_value(im, 0.0, 255.0)
+ return tf.cast(im, tf.uint8)
+
+ result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image)
+ return result
+
+ # Assumes RGB for now. Scales each channel independently
+ # and then stacks the result.
+ s1 = scale_channel(image[:, :, 0])
+ s2 = scale_channel(image[:, :, 1])
+ s3 = scale_channel(image[:, :, 2])
+ image = tf.stack([s1, s2, s3], 2)
+ return image
+
+
+def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor:
+ """Implements Sharpness function from PIL using TF ops."""
+ orig_image = image
+ image = tf.cast(image, tf.float32)
+ # Make image 4D for conv operation.
+ image = tf.expand_dims(image, 0)
+ # SMOOTH PIL Kernel.
+ kernel = tf.constant([[1, 1, 1], [1, 5, 1], [1, 1, 1]],
+ dtype=tf.float32,
+ shape=[3, 3, 1, 1]) / 13.
+ # Tile across channel dimension.
+ kernel = tf.tile(kernel, [1, 1, 3, 1])
+ strides = [1, 1, 1, 1]
+ degenerate = tf.nn.depthwise_conv2d(
+ image, kernel, strides, padding='VALID', dilations=[1, 1])
+ degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
+ degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0])
+
+ # For the borders of the resulting image, fill in the values of the
+ # original image.
+ mask = tf.ones_like(degenerate)
+ padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]])
+ padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]])
+ result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image)
+
+ # Blend the final result.
+ return blend(result, orig_image, factor)
+
+
+def equalize(image: tf.Tensor) -> tf.Tensor:
+ """Implements Equalize function from PIL using TF ops."""
+
+ def scale_channel(im, c):
+ """Scale the data in the channel to implement equalize."""
+ im = tf.cast(im[:, :, c], tf.int32)
+ # Compute the histogram of the image channel.
+ histo = tf.histogram_fixed_width(im, [0, 255], nbins=256)
+
+ # For the purposes of computing the step, filter out the nonzeros.
+ nonzero = tf.where(tf.not_equal(histo, 0))
+ nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1])
+ step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255
+
+ def build_lut(histo, step):
+ # Compute the cumulative sum, shifting by step // 2
+ # and then normalization by step.
+ lut = (tf.cumsum(histo) + (step // 2)) // step
+ # Shift lut, prepending with 0.
+ lut = tf.concat([[0], lut[:-1]], 0)
+ # Clip the counts to be in range. This is done
+ # in the C code for image.point.
+ return tf.clip_by_value(lut, 0, 255)
+
+ # If step is zero, return the original image. Otherwise, build
+ # lut from the full histogram and step and then index from it.
+ result = tf.cond(
+ tf.equal(step, 0), lambda: im,
+ lambda: tf.gather(build_lut(histo, step), im))
+
+ return tf.cast(result, tf.uint8)
+
+ # Assumes RGB for now. Scales each channel independently
+ # and then stacks the result.
+ s1 = scale_channel(image, 0)
+ s2 = scale_channel(image, 1)
+ s3 = scale_channel(image, 2)
+ image = tf.stack([s1, s2, s3], 2)
+ return image
+
+
+def invert(image: tf.Tensor) -> tf.Tensor:
+ """Inverts the image pixels."""
+ image = tf.convert_to_tensor(image)
+ return 255 - image
+
+
+def wrap(image: tf.Tensor) -> tf.Tensor:
+ """Returns 'image' with an extra channel set to all 1s."""
+ shape = tf.shape(image)
+ extended_channel = tf.ones([shape[0], shape[1], 1], image.dtype)
+ extended = tf.concat([image, extended_channel], axis=2)
+ return extended
+
+
+def unwrap(image: tf.Tensor, replace: int) -> tf.Tensor:
+ """Unwraps an image produced by wrap.
+
+ Where there is a 0 in the last channel for every spatial position,
+ the rest of the three channels in that spatial dimension are grayed
+ (set to 128). Operations like translate and shear on a wrapped
+ Tensor will leave 0s in empty locations. Some transformations look
+ at the intensity of values to do preprocessing, and we want these
+ empty pixels to assume the 'average' value, rather than pure black.
+
+
+ Args:
+ image: A 3D Image Tensor with 4 channels.
+ replace: A one or three value 1D tensor to fill empty pixels.
+
+ Returns:
+ image: A 3D image Tensor with 3 channels.
+ """
+ image_shape = tf.shape(image)
+ # Flatten the spatial dimensions.
+ flattened_image = tf.reshape(image, [-1, image_shape[2]])
+
+ # Find all pixels where the last channel is zero.
+ alpha_channel = tf.expand_dims(flattened_image[:, 3], axis=-1)
+
+ replace = tf.concat([replace, tf.ones([1], image.dtype)], 0)
+
+ # Where they are zero, fill them in with 'replace'.
+ flattened_image = tf.where(
+ tf.equal(alpha_channel, 0),
+ tf.ones_like(flattened_image, dtype=image.dtype) * replace,
+ flattened_image)
+
+ image = tf.reshape(flattened_image, image_shape)
+ image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3])
+ return image
+
+
+def _randomly_negate_tensor(tensor):
+ """With 50% prob turn the tensor negative."""
+ should_flip = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool)
+ final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor)
+ return final_tensor
+
+
+def _rotate_level_to_arg(level: float):
+ level = (level / _MAX_LEVEL) * 30.
+ level = _randomly_negate_tensor(level)
+ return (level,)
+
+
+def _shrink_level_to_arg(level: float):
+ """Converts level to ratio by which we shrink the image content."""
+ if level == 0:
+ return (1.0,) # if level is zero, do not shrink the image
+ # Maximum shrinking ratio is 2.9.
+ level = 2. / (_MAX_LEVEL / level) + 0.9
+ return (level,)
+
+
+def _enhance_level_to_arg(level: float):
+ return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
+
+
+def _shear_level_to_arg(level: float):
+ level = (level / _MAX_LEVEL) * 0.3
+ # Flip level to negative with 50% chance.
+ level = _randomly_negate_tensor(level)
+ return (level,)
+
+
+def _translate_level_to_arg(level: float, translate_const: float):
+ level = (level / _MAX_LEVEL) * float(translate_const)
+ # Flip level to negative with 50% chance.
+ level = _randomly_negate_tensor(level)
+ return (level,)
+
+
+def _mult_to_arg(level: float, multiplier: float = 1.):
+ return (int((level / _MAX_LEVEL) * multiplier),)
+
+
+def _apply_func_with_prob(func: Any, image: tf.Tensor, args: Any, prob: float):
+ """Apply `func` to image w/ `args` as input with probability `prob`."""
+ assert isinstance(args, tuple)
+
+ # Apply the function with probability `prob`.
+ should_apply_op = tf.cast(
+ tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool)
+ augmented_image = tf.cond(should_apply_op, lambda: func(image, *args),
+ lambda: image)
+ return augmented_image
+
+
+def select_and_apply_random_policy(policies: Any, image: tf.Tensor):
+ """Select a random policy from `policies` and apply it to `image`."""
+ policy_to_select = tf.random.uniform([], maxval=len(policies), dtype=tf.int32)
+ # Note that using tf.case instead of tf.conds would result in significantly
+ # larger graphs and would even break export for some larger policies.
+ for (i, policy) in enumerate(policies):
+ image = tf.cond(
+ tf.equal(i, policy_to_select),
+ lambda selected_policy=policy: selected_policy(image),
+ lambda: image)
+ return image
+
+
+NAME_TO_FUNC = {
+ 'AutoContrast': autocontrast,
+ 'Equalize': equalize,
+ 'Invert': invert,
+ 'Rotate': wrapped_rotate,
+ 'Posterize': posterize,
+ 'Solarize': solarize,
+ 'SolarizeAdd': solarize_add,
+ 'Color': color,
+ 'Contrast': contrast,
+ 'Brightness': brightness,
+ 'Sharpness': sharpness,
+ 'ShearX': shear_x,
+ 'ShearY': shear_y,
+ 'TranslateX': translate_x,
+ 'TranslateY': translate_y,
+ 'Cutout': cutout,
+}
+
+# Functions that have a 'replace' parameter
+REPLACE_FUNCS = frozenset({
+ 'Rotate',
+ 'TranslateX',
+ 'ShearX',
+ 'ShearY',
+ 'TranslateY',
+ 'Cutout',
+})
+
+
+def level_to_arg(cutout_const: float, translate_const: float):
+ """Creates a dict mapping image operation names to their arguments."""
+
+ no_arg = lambda level: ()
+ posterize_arg = lambda level: _mult_to_arg(level, 4)
+ solarize_arg = lambda level: _mult_to_arg(level, 256)
+ solarize_add_arg = lambda level: _mult_to_arg(level, 110)
+ cutout_arg = lambda level: _mult_to_arg(level, cutout_const)
+ translate_arg = lambda level: _translate_level_to_arg(level, translate_const)
+
+ args = {
+ 'AutoContrast': no_arg,
+ 'Equalize': no_arg,
+ 'Invert': no_arg,
+ 'Rotate': _rotate_level_to_arg,
+ 'Posterize': posterize_arg,
+ 'Solarize': solarize_arg,
+ 'SolarizeAdd': solarize_add_arg,
+ 'Color': _enhance_level_to_arg,
+ 'Contrast': _enhance_level_to_arg,
+ 'Brightness': _enhance_level_to_arg,
+ 'Sharpness': _enhance_level_to_arg,
+ 'ShearX': _shear_level_to_arg,
+ 'ShearY': _shear_level_to_arg,
+ 'Cutout': cutout_arg,
+ 'TranslateX': translate_arg,
+ 'TranslateY': translate_arg,
+ }
+ return args
+
+
+def _parse_policy_info(name: Text, prob: float, level: float,
+ replace_value: List[int], cutout_const: float,
+ translate_const: float) -> Tuple[Any, float, Any]:
+ """Return the function that corresponds to `name` and update `level` param."""
+ func = NAME_TO_FUNC[name]
+ args = level_to_arg(cutout_const, translate_const)[name](level)
+
+ if name in REPLACE_FUNCS:
+ # Add in replace arg if it is required for the function that is called.
+ args = tuple(list(args) + [replace_value])
+
+ return func, prob, args
+
+
+class ImageAugment(object):
+ """Image augmentation class for applying image distortions."""
+
+ def distort(self, image: tf.Tensor) -> tf.Tensor:
+ """Given an image tensor, returns a distorted image with the same shape.
+
+ Args:
+ image: `Tensor` of shape [height, width, 3] representing an image.
+
+ Returns:
+ The augmented version of `image`.
+ """
+ raise NotImplementedError()
+
+
+class AutoAugment(ImageAugment):
+ """Applies the AutoAugment policy to images.
+
+ AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.
+ """
+
+ def __init__(self,
+ augmentation_name: Text = 'v0',
+ policies: Optional[Dict[Text, Any]] = None,
+ cutout_const: float = 100,
+ translate_const: float = 250):
+ """Applies the AutoAugment policy to images.
+
+ Args:
+ augmentation_name: The name of the AutoAugment policy to use. The
+ available options are `v0` and `test`. `v0` is the policy used for all
+ of the results in the paper and was found to achieve the best results on
+ the COCO dataset. `v1`, `v2` and `v3` are additional good policies found
+ on the COCO dataset that have slight variation in what operations were
+ used during the search procedure along with how many operations are
+ applied in parallel to a single image (2 vs 3).
+ policies: list of lists of tuples in the form `(func, prob, level)`,
+ `func` is a string name of the augmentation function, `prob` is the
+ probability of applying the `func` operation, `level` is the input
+ argument for `func`.
+ cutout_const: multiplier for applying cutout.
+ translate_const: multiplier for applying translation.
+ """
+ super(AutoAugment, self).__init__()
+
+ if policies is None:
+ self.available_policies = {
+ 'v0': self.policy_v0(),
+ 'test': self.policy_test(),
+ 'simple': self.policy_simple(),
+ }
+
+ if augmentation_name not in self.available_policies:
+ raise ValueError(
+ 'Invalid augmentation_name: {}'.format(augmentation_name))
+
+ self.augmentation_name = augmentation_name
+ self.policies = self.available_policies[augmentation_name]
+ self.cutout_const = float(cutout_const)
+ self.translate_const = float(translate_const)
+
+ def distort(self, image: tf.Tensor) -> tf.Tensor:
+ """Applies the AutoAugment policy to `image`.
+
+ AutoAugment is from the paper: https://arxiv.org/abs/1805.09501.
+
+ Args:
+ image: `Tensor` of shape [height, width, 3] representing an image.
+
+ Returns:
+ A version of image that now has data augmentation applied to it based on
+ the `policies` pass into the function.
+ """
+ input_image_type = image.dtype
+
+ if input_image_type != tf.uint8:
+ image = tf.clip_by_value(image, 0.0, 255.0)
+ image = tf.cast(image, dtype=tf.uint8)
+
+ replace_value = [128] * 3
+
+ # func is the string name of the augmentation function, prob is the
+ # probability of applying the operation and level is the parameter
+ # associated with the tf op.
+
+ # tf_policies are functions that take in an image and return an augmented
+ # image.
+ tf_policies = []
+ for policy in self.policies:
+ tf_policy = []
+ # Link string name to the correct python function and make sure the
+ # correct argument is passed into that function.
+ for policy_info in policy:
+ policy_info = list(policy_info) + [
+ replace_value, self.cutout_const, self.translate_const
+ ]
+ tf_policy.append(_parse_policy_info(*policy_info))
+ # Now build the tf policy that will apply the augmentation procedue
+ # on image.
+ def make_final_policy(tf_policy_):
+
+ def final_policy(image_):
+ for func, prob, args in tf_policy_:
+ image_ = _apply_func_with_prob(func, image_, args, prob)
+ return image_
+
+ return final_policy
+
+ tf_policies.append(make_final_policy(tf_policy))
+
+ image = select_and_apply_random_policy(tf_policies, image)
+ image = tf.cast(image, dtype=input_image_type)
+ return image
+
+ @staticmethod
+ def policy_v0():
+ """Autoaugment policy that was used in AutoAugment Paper.
+
+ Each tuple is an augmentation operation of the form
+ (operation, probability, magnitude). Each element in policy is a
+ sub-policy that will be applied sequentially on the image.
+
+ Returns:
+ the policy.
+ """
+
+ policy = [
+ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
+ [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+ [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
+ [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+ [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+ [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+ [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+ [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
+ [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+ [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
+ [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
+ [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+ [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
+ [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+ [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
+ [('Rotate', 1.0, 7), ('TranslateY', 0.8, 9)],
+ [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
+ [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
+ [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
+ [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+ [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+ [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
+ [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
+ [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+ [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
+ ]
+ return policy
+
+ @staticmethod
+ def policy_simple():
+ """Same as `policy_v0`, except with custom ops removed."""
+
+ policy = [
+ [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+ [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+ [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+ [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+ [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+ [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+ [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+ [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
+ [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+ [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+ [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+ [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)],
+ [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+ ]
+ return policy
+
+ @staticmethod
+ def policy_test():
+ """Autoaugment test policy for debugging."""
+ policy = [
+ [('TranslateX', 1.0, 4), ('Equalize', 1.0, 10)],
+ ]
+ return policy
+
+
+class RandAugment(ImageAugment):
+ """Applies the RandAugment policy to images.
+
+ RandAugment is from the paper https://arxiv.org/abs/1909.13719,
+ """
+
+ def __init__(self,
+ num_layers: int = 2,
+ magnitude: float = 10.,
+ cutout_const: float = 40.,
+ translate_const: float = 100.):
+ """Applies the RandAugment policy to images.
+
+ Args:
+ num_layers: Integer, the number of augmentation transformations to apply
+ sequentially to an image. Represented as (N) in the paper. Usually best
+ values will be in the range [1, 3].
+ magnitude: Integer, shared magnitude across all augmentation operations.
+ Represented as (M) in the paper. Usually best values are in the range
+ [5, 10].
+ cutout_const: multiplier for applying cutout.
+ translate_const: multiplier for applying translation.
+ """
+ super(RandAugment, self).__init__()
+
+ self.num_layers = num_layers
+ self.magnitude = float(magnitude)
+ self.cutout_const = float(cutout_const)
+ self.translate_const = float(translate_const)
+ self.available_ops = [
+ 'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', 'Solarize',
+ 'Color', 'Contrast', 'Brightness', 'Sharpness', 'ShearX', 'ShearY',
+ 'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd'
+ ]
+
+ def distort(self, image: tf.Tensor) -> tf.Tensor:
+ """Applies the RandAugment policy to `image`.
+
+ Args:
+ image: `Tensor` of shape [height, width, 3] representing an image.
+
+ Returns:
+ The augmented version of `image`.
+ """
+ input_image_type = image.dtype
+
+ if input_image_type != tf.uint8:
+ image = tf.clip_by_value(image, 0.0, 255.0)
+ image = tf.cast(image, dtype=tf.uint8)
+
+ replace_value = [128] * 3
+ min_prob, max_prob = 0.2, 0.8
+
+ for _ in range(self.num_layers):
+ op_to_select = tf.random.uniform([],
+ maxval=len(self.available_ops) + 1,
+ dtype=tf.int32)
+
+ branch_fns = []
+ for (i, op_name) in enumerate(self.available_ops):
+ prob = tf.random.uniform([],
+ minval=min_prob,
+ maxval=max_prob,
+ dtype=tf.float32)
+ func, _, args = _parse_policy_info(op_name, prob, self.magnitude,
+ replace_value, self.cutout_const,
+ self.translate_const)
+ branch_fns.append((
+ i,
+ # pylint:disable=g-long-lambda
+ lambda selected_func=func, selected_args=args: selected_func(
+ image, *selected_args)))
+ # pylint:enable=g-long-lambda
+
+ image = tf.switch_case(
+ branch_index=op_to_select,
+ branch_fns=branch_fns,
+ default=lambda: tf.identity(image))
+
+ image = tf.cast(image, dtype=input_image_type)
+ return image
diff --git a/modeling/official/legacy/image_classification/augment_test.py b/modeling/official/legacy/image_classification/augment_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..85cb2a3fdf04de175e6a0e3701d550c280c03f2a
--- /dev/null
+++ b/modeling/official/legacy/image_classification/augment_test.py
@@ -0,0 +1,129 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for autoaugment."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.image_classification import augment
+
+
+def get_dtype_test_cases():
+ return [
+ ('uint8', tf.uint8),
+ ('int32', tf.int32),
+ ('float16', tf.float16),
+ ('float32', tf.float32),
+ ]
+
+
+@parameterized.named_parameters(get_dtype_test_cases())
+class TransformsTest(parameterized.TestCase, tf.test.TestCase):
+ """Basic tests for fundamental transformations."""
+
+ def test_to_from_4d(self, dtype):
+ for shape in [(10, 10), (10, 10, 10), (10, 10, 10, 10)]:
+ original_ndims = len(shape)
+ image = tf.zeros(shape, dtype=dtype)
+ image_4d = augment.to_4d(image)
+ self.assertEqual(4, tf.rank(image_4d))
+ self.assertAllEqual(image, augment.from_4d(image_4d, original_ndims))
+
+ def test_transform(self, dtype):
+ image = tf.constant([[1, 2], [3, 4]], dtype=dtype)
+ self.assertAllEqual(
+ augment.transform(image, transforms=[1] * 8), [[4, 4], [4, 4]])
+
+ def test_translate(self, dtype):
+ image = tf.constant(
+ [[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]], dtype=dtype)
+ translations = [-1, -1]
+ translated = augment.translate(image=image, translations=translations)
+ expected = [[1, 0, 1, 1], [0, 1, 0, 0], [1, 0, 1, 1], [1, 0, 1, 1]]
+ self.assertAllEqual(translated, expected)
+
+ def test_translate_shapes(self, dtype):
+ translation = [0, 0]
+ for shape in [(3, 3), (5, 5), (224, 224, 3)]:
+ image = tf.zeros(shape, dtype=dtype)
+ self.assertAllEqual(image, augment.translate(image, translation))
+
+ def test_translate_invalid_translation(self, dtype):
+ image = tf.zeros((1, 1), dtype=dtype)
+ invalid_translation = [[[1, 1]]]
+ with self.assertRaisesRegex(TypeError, 'rank 1 or 2'):
+ _ = augment.translate(image, invalid_translation)
+
+ def test_rotate(self, dtype):
+ image = tf.reshape(tf.cast(tf.range(9), dtype), (3, 3))
+ rotation = 90.
+ transformed = augment.rotate(image=image, degrees=rotation)
+ expected = [[2, 5, 8], [1, 4, 7], [0, 3, 6]]
+ self.assertAllEqual(transformed, expected)
+
+ def test_rotate_shapes(self, dtype):
+ degrees = 0.
+ for shape in [(3, 3), (5, 5), (224, 224, 3)]:
+ image = tf.zeros(shape, dtype=dtype)
+ self.assertAllEqual(image, augment.rotate(image, degrees))
+
+
+class AutoaugmentTest(tf.test.TestCase):
+
+ def test_autoaugment(self):
+ """Smoke test to be sure there are no syntax errors."""
+ image = tf.zeros((224, 224, 3), dtype=tf.uint8)
+
+ augmenter = augment.AutoAugment()
+ aug_image = augmenter.distort(image)
+
+ self.assertEqual((224, 224, 3), aug_image.shape)
+
+ def test_randaug(self):
+ """Smoke test to be sure there are no syntax errors."""
+ image = tf.zeros((224, 224, 3), dtype=tf.uint8)
+
+ augmenter = augment.RandAugment()
+ aug_image = augmenter.distort(image)
+
+ self.assertEqual((224, 224, 3), aug_image.shape)
+
+ def test_all_policy_ops(self):
+ """Smoke test to be sure all augmentation functions can execute."""
+
+ prob = 1
+ magnitude = 10
+ replace_value = [128] * 3
+ cutout_const = 100
+ translate_const = 250
+
+ image = tf.ones((224, 224, 3), dtype=tf.uint8)
+
+ for op_name in augment.NAME_TO_FUNC:
+ func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
+ replace_value, cutout_const,
+ translate_const)
+ image = func(image, *args)
+
+ self.assertEqual((224, 224, 3), image.shape)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/legacy/image_classification/callbacks.py b/modeling/official/legacy/image_classification/callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0301fcc9beec05750531269dbb8cdb8b563cd61
--- /dev/null
+++ b/modeling/official/legacy/image_classification/callbacks.py
@@ -0,0 +1,255 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common modules for callbacks."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+from typing import Any, List, MutableMapping, Optional, Text
+
+from absl import logging
+import tensorflow as tf, tf_keras
+
+from official.modeling import optimization
+from official.utils.misc import keras_utils
+
+
+def get_callbacks(
+ model_checkpoint: bool = True,
+ include_tensorboard: bool = True,
+ time_history: bool = True,
+ track_lr: bool = True,
+ write_model_weights: bool = True,
+ apply_moving_average: bool = False,
+ initial_step: int = 0,
+ batch_size: int = 0,
+ log_steps: int = 0,
+ model_dir: Optional[str] = None,
+ backup_and_restore: bool = False) -> List[tf_keras.callbacks.Callback]:
+ """Get all callbacks."""
+ model_dir = model_dir or ''
+ callbacks = []
+ if model_checkpoint:
+ ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
+ callbacks.append(
+ tf_keras.callbacks.ModelCheckpoint(
+ ckpt_full_path, save_weights_only=True, verbose=1))
+ if backup_and_restore:
+ backup_dir = os.path.join(model_dir, 'tmp')
+ callbacks.append(
+ tf_keras.callbacks.experimental.BackupAndRestore(backup_dir))
+ if include_tensorboard:
+ callbacks.append(
+ CustomTensorBoard(
+ log_dir=model_dir,
+ track_lr=track_lr,
+ initial_step=initial_step,
+ write_images=write_model_weights,
+ profile_batch=0))
+ if time_history:
+ callbacks.append(
+ keras_utils.TimeHistory(
+ batch_size,
+ log_steps,
+ logdir=model_dir if include_tensorboard else None))
+ if apply_moving_average:
+ # Save moving average model to a different file so that
+ # we can resume training from a checkpoint
+ ckpt_full_path = os.path.join(model_dir, 'average',
+ 'model.ckpt-{epoch:04d}')
+ callbacks.append(
+ AverageModelCheckpoint(
+ update_weights=False,
+ filepath=ckpt_full_path,
+ save_weights_only=True,
+ verbose=1))
+ callbacks.append(MovingAverageCallback())
+ return callbacks
+
+
+def get_scalar_from_tensor(t: tf.Tensor) -> int:
+ """Utility function to convert a Tensor to a scalar."""
+ t = tf_keras.backend.get_value(t)
+ if callable(t):
+ return t()
+ else:
+ return t
+
+
+class CustomTensorBoard(tf_keras.callbacks.TensorBoard):
+ """A customized TensorBoard callback that tracks additional datapoints.
+
+ Metrics tracked:
+ - Global learning rate
+
+ Attributes:
+ log_dir: the path of the directory where to save the log files to be parsed
+ by TensorBoard.
+ track_lr: `bool`, whether or not to track the global learning rate.
+ initial_step: the initial step, used for preemption recovery.
+ **kwargs: Additional arguments for backwards compatibility. Possible key is
+ `period`.
+ """
+
+ # TODO(b/146499062): track params, flops, log lr, l2 loss,
+ # classification loss
+
+ def __init__(self,
+ log_dir: str,
+ track_lr: bool = False,
+ initial_step: int = 0,
+ **kwargs):
+ super(CustomTensorBoard, self).__init__(log_dir=log_dir, **kwargs)
+ self.step = initial_step
+ self._track_lr = track_lr
+
+ def on_batch_begin(self,
+ epoch: int,
+ logs: Optional[MutableMapping[str, Any]] = None) -> None:
+ self.step += 1
+ if logs is None:
+ logs = {}
+ logs.update(self._calculate_metrics())
+ super(CustomTensorBoard, self).on_batch_begin(epoch, logs)
+
+ def on_epoch_begin(self,
+ epoch: int,
+ logs: Optional[MutableMapping[str, Any]] = None) -> None:
+ if logs is None:
+ logs = {}
+ metrics = self._calculate_metrics()
+ logs.update(metrics)
+ for k, v in metrics.items():
+ logging.info('Current %s: %f', k, v)
+ super(CustomTensorBoard, self).on_epoch_begin(epoch, logs)
+
+ def on_epoch_end(self,
+ epoch: int,
+ logs: Optional[MutableMapping[str, Any]] = None) -> None:
+ if logs is None:
+ logs = {}
+ metrics = self._calculate_metrics()
+ logs.update(metrics)
+ super(CustomTensorBoard, self).on_epoch_end(epoch, logs)
+
+ def _calculate_metrics(self) -> MutableMapping[str, Any]:
+ logs = {}
+ # TODO(b/149030439): disable LR reporting.
+ # if self._track_lr:
+ # logs['learning_rate'] = self._calculate_lr()
+ return logs
+
+ def _calculate_lr(self) -> int:
+ """Calculates the learning rate given the current step."""
+ return get_scalar_from_tensor(
+ self._get_base_optimizer()._decayed_lr(var_dtype=tf.float32)) # pylint:disable=protected-access
+
+ def _get_base_optimizer(self) -> tf_keras.optimizers.Optimizer:
+ """Get the base optimizer used by the current model."""
+
+ optimizer = self.model.optimizer
+
+ # The optimizer might be wrapped by another class, so unwrap it
+ while hasattr(optimizer, '_optimizer'):
+ optimizer = optimizer._optimizer # pylint:disable=protected-access
+
+ return optimizer
+
+
+class MovingAverageCallback(tf_keras.callbacks.Callback):
+ """A Callback to be used with a `ExponentialMovingAverage` optimizer.
+
+ Applies moving average weights to the model during validation time to test
+ and predict on the averaged weights rather than the current model weights.
+ Once training is complete, the model weights will be overwritten with the
+ averaged weights (by default).
+
+ Attributes:
+ overwrite_weights_on_train_end: Whether to overwrite the current model
+ weights with the averaged weights from the moving average optimizer.
+ **kwargs: Any additional callback arguments.
+ """
+
+ def __init__(self, overwrite_weights_on_train_end: bool = False, **kwargs):
+ super(MovingAverageCallback, self).__init__(**kwargs)
+ self.overwrite_weights_on_train_end = overwrite_weights_on_train_end
+
+ def set_model(self, model: tf_keras.Model):
+ super(MovingAverageCallback, self).set_model(model)
+ assert isinstance(self.model.optimizer,
+ optimization.ExponentialMovingAverage)
+ self.model.optimizer.shadow_copy(self.model)
+
+ def on_test_begin(self, logs: Optional[MutableMapping[Text, Any]] = None):
+ self.model.optimizer.swap_weights()
+
+ def on_test_end(self, logs: Optional[MutableMapping[Text, Any]] = None):
+ self.model.optimizer.swap_weights()
+
+ def on_train_end(self, logs: Optional[MutableMapping[Text, Any]] = None):
+ if self.overwrite_weights_on_train_end:
+ self.model.optimizer.assign_average_vars(self.model.variables)
+
+
+class AverageModelCheckpoint(tf_keras.callbacks.ModelCheckpoint):
+ """Saves and, optionally, assigns the averaged weights.
+
+ Taken from tfa.callbacks.AverageModelCheckpoint.
+
+ Attributes:
+ update_weights: If True, assign the moving average weights to the model, and
+ save them. If False, keep the old non-averaged weights, but the saved
+ model uses the average weights. See `tf_keras.callbacks.ModelCheckpoint`
+ for the other args.
+ """
+
+ def __init__(self,
+ update_weights: bool,
+ filepath: str,
+ monitor: str = 'val_loss',
+ verbose: int = 0,
+ save_best_only: bool = False,
+ save_weights_only: bool = False,
+ mode: str = 'auto',
+ save_freq: str = 'epoch',
+ **kwargs):
+ self.update_weights = update_weights
+ super().__init__(filepath, monitor, verbose, save_best_only,
+ save_weights_only, mode, save_freq, **kwargs)
+
+ def set_model(self, model):
+ if not isinstance(model.optimizer, optimization.ExponentialMovingAverage):
+ raise TypeError('AverageModelCheckpoint is only used when training'
+ 'with MovingAverage')
+ return super().set_model(model)
+
+ def _save_model(self, epoch, logs):
+ assert isinstance(self.model.optimizer,
+ optimization.ExponentialMovingAverage)
+
+ if self.update_weights:
+ self.model.optimizer.assign_average_vars(self.model.variables)
+ return super()._save_model(epoch, logs) # pytype: disable=attribute-error # typed-keras
+ else:
+ # Note: `model.get_weights()` gives us the weights (non-ref)
+ # whereas `model.variables` returns references to the variables.
+ non_avg_weights = self.model.get_weights()
+ self.model.optimizer.assign_average_vars(self.model.variables)
+ # result is currently None, since `super._save_model` doesn't
+ # return anything, but this may change in the future.
+ result = super()._save_model(epoch, logs) # pytype: disable=attribute-error # typed-keras
+ self.model.set_weights(non_avg_weights)
+ return result
diff --git a/modeling/official/legacy/image_classification/classifier_trainer.py b/modeling/official/legacy/image_classification/classifier_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea432ca150122735afd9329126000f0e6b1231aa
--- /dev/null
+++ b/modeling/official/legacy/image_classification/classifier_trainer.py
@@ -0,0 +1,457 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Runs an Image Classification model."""
+
+import os
+import pprint
+from typing import Any, Mapping, Optional, Text, Tuple
+
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf, tf_keras
+from official.common import distribute_utils
+from official.legacy.image_classification import callbacks as custom_callbacks
+from official.legacy.image_classification import dataset_factory
+from official.legacy.image_classification import optimizer_factory
+from official.legacy.image_classification.configs import base_configs
+from official.legacy.image_classification.configs import configs
+from official.legacy.image_classification.efficientnet import efficientnet_model
+from official.legacy.image_classification.resnet import common
+from official.legacy.image_classification.resnet import resnet_model
+from official.legacy.image_classification.vgg import vgg_model
+from official.modeling import hyperparams
+from official.modeling import performance
+from official.utils import hyperparams_flags
+from official.utils.misc import keras_utils
+
+
+def get_models() -> Mapping[str, tf_keras.Model]:
+ """Returns the mapping from model type name to Keras model."""
+ return {
+ 'efficientnet': efficientnet_model.EfficientNet.from_name,
+ 'resnet': resnet_model.resnet50,
+ 'vgg': vgg_model.vgg16,
+ }
+
+
+def get_dtype_map() -> Mapping[str, tf.dtypes.DType]:
+ """Returns the mapping from dtype string representations to TF dtypes."""
+ return {
+ 'float32': tf.float32,
+ 'bfloat16': tf.bfloat16,
+ 'float16': tf.float16,
+ 'fp32': tf.float32,
+ 'bf16': tf.bfloat16,
+ }
+
+
+def _get_metrics(one_hot: bool) -> Mapping[Text, Any]:
+ """Get a dict of available metrics to track."""
+ if one_hot:
+ return {
+ # (name, metric_fn)
+ 'acc':
+ tf_keras.metrics.CategoricalAccuracy(name='accuracy'),
+ 'accuracy':
+ tf_keras.metrics.CategoricalAccuracy(name='accuracy'),
+ 'top_1':
+ tf_keras.metrics.CategoricalAccuracy(name='accuracy'),
+ 'top_5':
+ tf_keras.metrics.TopKCategoricalAccuracy(
+ k=5, name='top_5_accuracy'),
+ }
+ else:
+ return {
+ # (name, metric_fn)
+ 'acc':
+ tf_keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
+ 'accuracy':
+ tf_keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
+ 'top_1':
+ tf_keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
+ 'top_5':
+ tf_keras.metrics.SparseTopKCategoricalAccuracy(
+ k=5, name='top_5_accuracy'),
+ }
+
+
+def get_image_size_from_model(
+ params: base_configs.ExperimentConfig) -> Optional[int]:
+ """If the given model has a preferred image size, return it."""
+ if params.model_name == 'efficientnet':
+ efficientnet_name = params.model.model_params.model_name
+ if efficientnet_name in efficientnet_model.MODEL_CONFIGS:
+ return efficientnet_model.MODEL_CONFIGS[efficientnet_name].resolution
+ return None
+
+
+def _get_dataset_builders(params: base_configs.ExperimentConfig,
+ strategy: tf.distribute.Strategy,
+ one_hot: bool) -> Tuple[Any, Any]:
+ """Create and return train and validation dataset builders."""
+ if one_hot:
+ logging.warning('label_smoothing > 0, so datasets will be one hot encoded.')
+ else:
+ logging.warning('label_smoothing not applied, so datasets will not be one '
+ 'hot encoded.')
+
+ num_devices = strategy.num_replicas_in_sync if strategy else 1
+
+ image_size = get_image_size_from_model(params)
+
+ dataset_configs = [params.train_dataset, params.validation_dataset]
+ builders = []
+
+ for config in dataset_configs:
+ if config is not None and config.has_data:
+ builder = dataset_factory.DatasetBuilder(
+ config,
+ image_size=image_size or config.image_size,
+ num_devices=num_devices,
+ one_hot=one_hot)
+ else:
+ builder = None
+ builders.append(builder)
+
+ return builders
+
+
+def get_loss_scale(params: base_configs.ExperimentConfig,
+ fp16_default: float = 128.) -> float:
+ """Returns the loss scale for initializations."""
+ loss_scale = params.runtime.loss_scale
+ if loss_scale == 'dynamic':
+ return loss_scale
+ elif loss_scale is not None:
+ return float(loss_scale)
+ elif (params.train_dataset.dtype == 'float32' or
+ params.train_dataset.dtype == 'bfloat16'):
+ return 1.
+ else:
+ assert params.train_dataset.dtype == 'float16'
+ return fp16_default
+
+
+def _get_params_from_flags(flags_obj: flags.FlagValues):
+ """Get ParamsDict from flags."""
+ model = flags_obj.model_type.lower()
+ dataset = flags_obj.dataset.lower()
+ params = configs.get_config(model=model, dataset=dataset)
+
+ flags_overrides = {
+ 'model_dir': flags_obj.model_dir,
+ 'mode': flags_obj.mode,
+ 'model': {
+ 'name': model,
+ },
+ 'runtime': {
+ 'run_eagerly': flags_obj.run_eagerly,
+ 'tpu': flags_obj.tpu,
+ },
+ 'train_dataset': {
+ 'data_dir': flags_obj.data_dir,
+ },
+ 'validation_dataset': {
+ 'data_dir': flags_obj.data_dir,
+ },
+ 'train': {
+ 'time_history': {
+ 'log_steps': flags_obj.log_steps,
+ },
+ },
+ }
+
+ overriding_configs = (flags_obj.config_file, flags_obj.params_override,
+ flags_overrides)
+
+ pp = pprint.PrettyPrinter()
+
+ logging.info('Base params: %s', pp.pformat(params.as_dict()))
+
+ for param in overriding_configs:
+ logging.info('Overriding params: %s', param)
+ params = hyperparams.override_params_dict(params, param, is_strict=True)
+
+ params.validate()
+ params.lock()
+
+ logging.info('Final model parameters: %s', pp.pformat(params.as_dict()))
+ return params
+
+
+def resume_from_checkpoint(model: tf_keras.Model, model_dir: str,
+ train_steps: int) -> int:
+ """Resumes from the latest checkpoint, if possible.
+
+ Loads the model weights and optimizer settings from a checkpoint.
+ This function should be used in case of preemption recovery.
+
+ Args:
+ model: The model whose weights should be restored.
+ model_dir: The directory where model weights were saved.
+ train_steps: The number of steps to train.
+
+ Returns:
+ The epoch of the latest checkpoint, or 0 if not restoring.
+
+ """
+ logging.info('Load from checkpoint is enabled.')
+ latest_checkpoint = tf.train.latest_checkpoint(model_dir)
+ logging.info('latest_checkpoint: %s', latest_checkpoint)
+ if not latest_checkpoint:
+ logging.info('No checkpoint detected.')
+ return 0
+
+ logging.info('Checkpoint file %s found and restoring from '
+ 'checkpoint', latest_checkpoint)
+ model.load_weights(latest_checkpoint)
+ initial_epoch = model.optimizer.iterations // train_steps
+ logging.info('Completed loading from checkpoint.')
+ logging.info('Resuming from epoch %d', initial_epoch)
+ return int(initial_epoch)
+
+
+def initialize(params: base_configs.ExperimentConfig,
+ dataset_builder: dataset_factory.DatasetBuilder):
+ """Initializes backend related initializations."""
+ keras_utils.set_session_config(enable_xla=params.runtime.enable_xla)
+ performance.set_mixed_precision_policy(dataset_builder.dtype)
+ if tf.config.list_physical_devices('GPU'):
+ data_format = 'channels_first'
+ else:
+ data_format = 'channels_last'
+ tf_keras.backend.set_image_data_format(data_format)
+ if params.runtime.run_eagerly:
+ # Enable eager execution to allow step-by-step debugging
+ tf.config.experimental_run_functions_eagerly(True)
+ if tf.config.list_physical_devices('GPU'):
+ if params.runtime.gpu_thread_mode:
+ keras_utils.set_gpu_thread_mode_and_count(
+ per_gpu_thread_count=params.runtime.per_gpu_thread_count,
+ gpu_thread_mode=params.runtime.gpu_thread_mode,
+ num_gpus=params.runtime.num_gpus,
+ datasets_num_private_threads=params.runtime
+ .dataset_num_private_threads) # pylint:disable=line-too-long
+ if params.runtime.batchnorm_spatial_persistent:
+ os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
+
+
+def define_classifier_flags():
+ """Defines common flags for image classification."""
+ hyperparams_flags.initialize_common_flags()
+ flags.DEFINE_string(
+ 'data_dir', default=None, help='The location of the input data.')
+ flags.DEFINE_string(
+ 'mode',
+ default=None,
+ help='Mode to run: `train`, `eval`, `train_and_eval` or `export`.')
+ flags.DEFINE_bool(
+ 'run_eagerly',
+ default=None,
+ help='Use eager execution and disable autograph for debugging.')
+ flags.DEFINE_string(
+ 'model_type',
+ default=None,
+ help='The type of the model, e.g. EfficientNet, etc.')
+ flags.DEFINE_string(
+ 'dataset',
+ default=None,
+ help='The name of the dataset, e.g. ImageNet, etc.')
+ flags.DEFINE_integer(
+ 'log_steps',
+ default=100,
+ help='The interval of steps between logging of batch level stats.')
+
+
+def serialize_config(params: base_configs.ExperimentConfig, model_dir: str):
+ """Serializes and saves the experiment config."""
+ params_save_path = os.path.join(model_dir, 'params.yaml')
+ logging.info('Saving experiment configuration to %s', params_save_path)
+ tf.io.gfile.makedirs(model_dir)
+ hyperparams.save_params_dict_to_yaml(params, params_save_path)
+
+
+def train_and_eval(
+ params: base_configs.ExperimentConfig,
+ strategy_override: tf.distribute.Strategy) -> Mapping[str, Any]:
+ """Runs the train and eval path using compile/fit."""
+ logging.info('Running train and eval.')
+
+ distribute_utils.configure_cluster(params.runtime.worker_hosts,
+ params.runtime.task_index)
+
+ # Note: for TPUs, strategy and scope should be created before the dataset
+ strategy = strategy_override or distribute_utils.get_distribution_strategy(
+ distribution_strategy=params.runtime.distribution_strategy,
+ all_reduce_alg=params.runtime.all_reduce_alg,
+ num_gpus=params.runtime.num_gpus,
+ tpu_address=params.runtime.tpu)
+
+ strategy_scope = distribute_utils.get_strategy_scope(strategy)
+
+ logging.info('Detected %d devices.',
+ strategy.num_replicas_in_sync if strategy else 1)
+
+ label_smoothing = params.model.loss.label_smoothing
+ one_hot = label_smoothing and label_smoothing > 0
+
+ builders = _get_dataset_builders(params, strategy, one_hot)
+ datasets = [
+ builder.build(strategy) if builder else None for builder in builders
+ ]
+
+ # Unpack datasets and builders based on train/val/test splits
+ train_builder, validation_builder = builders # pylint: disable=unbalanced-tuple-unpacking
+ train_dataset, validation_dataset = datasets
+
+ train_epochs = params.train.epochs
+ train_steps = params.train.steps or train_builder.num_steps
+ validation_steps = params.evaluation.steps or validation_builder.num_steps
+
+ initialize(params, train_builder)
+
+ logging.info('Global batch size: %d', train_builder.global_batch_size)
+
+ with strategy_scope:
+ model_params = params.model.model_params.as_dict()
+ model = get_models()[params.model.name](**model_params)
+ learning_rate = optimizer_factory.build_learning_rate(
+ params=params.model.learning_rate,
+ batch_size=train_builder.global_batch_size,
+ train_epochs=train_epochs,
+ train_steps=train_steps)
+ optimizer = optimizer_factory.build_optimizer(
+ optimizer_name=params.model.optimizer.name,
+ base_learning_rate=learning_rate,
+ params=params.model.optimizer.as_dict(),
+ model=model)
+ optimizer = performance.configure_optimizer(
+ optimizer,
+ use_float16=train_builder.dtype == 'float16',
+ loss_scale=get_loss_scale(params))
+
+ metrics_map = _get_metrics(one_hot)
+ metrics = [metrics_map[metric] for metric in params.train.metrics]
+ steps_per_loop = train_steps if params.train.set_epoch_loop else 1
+
+ if one_hot:
+ loss_obj = tf_keras.losses.CategoricalCrossentropy(
+ label_smoothing=params.model.loss.label_smoothing)
+ else:
+ loss_obj = tf_keras.losses.SparseCategoricalCrossentropy()
+ model.compile(
+ optimizer=optimizer,
+ loss=loss_obj,
+ metrics=metrics,
+ steps_per_execution=steps_per_loop)
+
+ initial_epoch = 0
+ if params.train.resume_checkpoint:
+ initial_epoch = resume_from_checkpoint(
+ model=model, model_dir=params.model_dir, train_steps=train_steps)
+
+ callbacks = custom_callbacks.get_callbacks(
+ model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
+ include_tensorboard=params.train.callbacks.enable_tensorboard,
+ time_history=params.train.callbacks.enable_time_history,
+ track_lr=params.train.tensorboard.track_lr,
+ write_model_weights=params.train.tensorboard.write_model_weights,
+ initial_step=initial_epoch * train_steps,
+ batch_size=train_builder.global_batch_size,
+ log_steps=params.train.time_history.log_steps,
+ model_dir=params.model_dir,
+ backup_and_restore=params.train.callbacks.enable_backup_and_restore)
+
+ serialize_config(params=params, model_dir=params.model_dir)
+
+ if params.evaluation.skip_eval:
+ validation_kwargs = {}
+ else:
+ validation_kwargs = {
+ 'validation_data': validation_dataset,
+ 'validation_steps': validation_steps,
+ 'validation_freq': params.evaluation.epochs_between_evals,
+ }
+
+ history = model.fit(
+ train_dataset,
+ epochs=train_epochs,
+ steps_per_epoch=train_steps,
+ initial_epoch=initial_epoch,
+ callbacks=callbacks,
+ verbose=2,
+ **validation_kwargs)
+
+ validation_output = None
+ if not params.evaluation.skip_eval:
+ validation_output = model.evaluate(
+ validation_dataset, steps=validation_steps, verbose=2)
+
+ # TODO(dankondratyuk): eval and save final test accuracy
+ stats = common.build_stats(history, validation_output, callbacks)
+ return stats
+
+
+def export(params: base_configs.ExperimentConfig):
+ """Runs the model export functionality."""
+ logging.info('Exporting model.')
+ model_params = params.model.model_params.as_dict()
+ model = get_models()[params.model.name](**model_params)
+ checkpoint = params.export.checkpoint
+ if checkpoint is None:
+ logging.info('No export checkpoint was provided. Using the latest '
+ 'checkpoint from model_dir.')
+ checkpoint = tf.train.latest_checkpoint(params.model_dir)
+
+ model.load_weights(checkpoint)
+ model.save(params.export.destination)
+
+
+def run(flags_obj: flags.FlagValues,
+ strategy_override: tf.distribute.Strategy = None) -> Mapping[str, Any]:
+ """Runs Image Classification model using native Keras APIs.
+
+ Args:
+ flags_obj: An object containing parsed flag values.
+ strategy_override: A `tf.distribute.Strategy` object to use for model.
+
+ Returns:
+ Dictionary of training/eval stats
+ """
+ params = _get_params_from_flags(flags_obj)
+ if params.mode == 'train_and_eval':
+ return train_and_eval(params, strategy_override)
+ elif params.mode == 'export_only':
+ export(params)
+ else:
+ raise ValueError('{} is not a valid mode.'.format(params.mode))
+
+
+def main(_):
+ stats = run(flags.FLAGS)
+ if stats:
+ logging.info('Run stats:\n%s', stats)
+
+
+if __name__ == '__main__':
+ logging.set_verbosity(logging.INFO)
+ define_classifier_flags()
+ flags.mark_flag_as_required('data_dir')
+ flags.mark_flag_as_required('mode')
+ flags.mark_flag_as_required('model_type')
+ flags.mark_flag_as_required('dataset')
+
+ app.run(main)
diff --git a/modeling/official/legacy/image_classification/classifier_trainer_test.py b/modeling/official/legacy/image_classification/classifier_trainer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..21a21cee84528e5757c0f1998d1147886d2ea256
--- /dev/null
+++ b/modeling/official/legacy/image_classification/classifier_trainer_test.py
@@ -0,0 +1,238 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for the classifier trainer models."""
+
+import functools
+import json
+
+import os
+import sys
+
+from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, Tuple
+
+from absl import flags
+from absl.testing import flagsaver
+from absl.testing import parameterized
+import tensorflow as tf, tf_keras
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.legacy.image_classification import classifier_trainer
+from official.utils.flags import core as flags_core
+
+
+classifier_trainer.define_classifier_flags()
+
+
+def distribution_strategy_combinations() -> Iterable[Tuple[Any, ...]]:
+ """Returns the combinations of end-to-end tests to run."""
+ return combinations.combine(
+ distribution=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.cloud_tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ strategy_combinations.mirrored_strategy_with_two_gpus,
+ ],
+ model=[
+ 'efficientnet',
+ 'resnet',
+ 'vgg',
+ ],
+ dataset=[
+ 'imagenet',
+ ],
+ )
+
+
+def get_params_override(params_override: Mapping[str, Any]) -> str:
+ """Converts params_override dict to string command."""
+ return '--params_override=' + json.dumps(params_override)
+
+
+def basic_params_override(dtype: str = 'float32') -> MutableMapping[str, Any]:
+ """Returns a basic parameter configuration for testing."""
+ return {
+ 'train_dataset': {
+ 'builder': 'synthetic',
+ 'use_per_replica_batch_size': True,
+ 'batch_size': 1,
+ 'image_size': 224,
+ 'dtype': dtype,
+ },
+ 'validation_dataset': {
+ 'builder': 'synthetic',
+ 'batch_size': 1,
+ 'use_per_replica_batch_size': True,
+ 'image_size': 224,
+ 'dtype': dtype,
+ },
+ 'train': {
+ 'steps': 1,
+ 'epochs': 1,
+ 'callbacks': {
+ 'enable_checkpoint_and_export': True,
+ 'enable_tensorboard': False,
+ },
+ },
+ 'evaluation': {
+ 'steps': 1,
+ },
+ }
+
+
+@flagsaver.flagsaver
+def run_end_to_end(main: Callable[[Any], None],
+ extra_flags: Optional[Iterable[str]] = None,
+ model_dir: Optional[str] = None):
+ """Runs the classifier trainer end-to-end."""
+ extra_flags = [] if extra_flags is None else extra_flags
+ args = [sys.argv[0], '--model_dir', model_dir] + extra_flags
+ flags_core.parse_flags(argv=args)
+ main(flags.FLAGS)
+
+
+class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
+ """Unit tests for Keras models."""
+ _tempdir = None
+
+ @classmethod
+ def setUpClass(cls): # pylint: disable=invalid-name
+ super(ClassifierTest, cls).setUpClass()
+
+ def tearDown(self):
+ super(ClassifierTest, self).tearDown()
+ tf.io.gfile.rmtree(self.get_temp_dir())
+
+ @combinations.generate(distribution_strategy_combinations())
+ def test_end_to_end_train_and_eval(self, distribution, model, dataset):
+ """Test train_and_eval and export for Keras classifier models."""
+ # Some parameters are not defined as flags (e.g. cannot run
+ # classifier_train.py --batch_size=...) by design, so use
+ # "--params_override=..." instead
+ model_dir = self.create_tempdir().full_path
+ base_flags = [
+ '--data_dir=not_used',
+ '--model_type=' + model,
+ '--dataset=' + dataset,
+ ]
+ train_and_eval_flags = base_flags + [
+ get_params_override(basic_params_override()),
+ '--mode=train_and_eval',
+ ]
+
+ run = functools.partial(
+ classifier_trainer.run, strategy_override=distribution)
+ run_end_to_end(
+ main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.one_device_strategy_gpu,
+ ],
+ model=[
+ 'efficientnet',
+ 'resnet',
+ 'vgg',
+ ],
+ dataset='imagenet',
+ dtype='float16',
+ ))
+ def test_gpu_train(self, distribution, model, dataset, dtype):
+ """Test train_and_eval and export for Keras classifier models."""
+ # Some parameters are not defined as flags (e.g. cannot run
+ # classifier_train.py --batch_size=...) by design, so use
+ # "--params_override=..." instead
+ model_dir = self.create_tempdir().full_path
+ base_flags = [
+ '--data_dir=not_used',
+ '--model_type=' + model,
+ '--dataset=' + dataset,
+ ]
+ train_and_eval_flags = base_flags + [
+ get_params_override(basic_params_override(dtype)),
+ '--mode=train_and_eval',
+ ]
+
+ export_params = basic_params_override()
+ export_path = os.path.join(model_dir, 'export')
+ export_params['export'] = {}
+ export_params['export']['destination'] = export_path
+ export_flags = base_flags + [
+ '--mode=export_only',
+ get_params_override(export_params)
+ ]
+
+ run = functools.partial(
+ classifier_trainer.run, strategy_override=distribution)
+ run_end_to_end(
+ main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)
+ run_end_to_end(main=run, extra_flags=export_flags, model_dir=model_dir)
+ self.assertTrue(os.path.exists(export_path))
+
+ @combinations.generate(
+ combinations.combine(
+ distribution=[
+ strategy_combinations.cloud_tpu_strategy,
+ ],
+ model=[
+ 'efficientnet',
+ 'resnet',
+ 'vgg',
+ ],
+ dataset='imagenet',
+ dtype='bfloat16',
+ ))
+ def test_tpu_train(self, distribution, model, dataset, dtype):
+ """Test train_and_eval and export for Keras classifier models."""
+ # Some parameters are not defined as flags (e.g. cannot run
+ # classifier_train.py --batch_size=...) by design, so use
+ # "--params_override=..." instead
+ model_dir = self.create_tempdir().full_path
+ base_flags = [
+ '--data_dir=not_used',
+ '--model_type=' + model,
+ '--dataset=' + dataset,
+ ]
+ train_and_eval_flags = base_flags + [
+ get_params_override(basic_params_override(dtype)),
+ '--mode=train_and_eval',
+ ]
+
+ run = functools.partial(
+ classifier_trainer.run, strategy_override=distribution)
+ run_end_to_end(
+ main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)
+
+ @combinations.generate(distribution_strategy_combinations())
+ def test_end_to_end_invalid_mode(self, distribution, model, dataset):
+ """Test the Keras EfficientNet model with `strategy`."""
+ model_dir = self.create_tempdir().full_path
+ extra_flags = [
+ '--data_dir=not_used',
+ '--mode=invalid_mode',
+ '--model_type=' + model,
+ '--dataset=' + dataset,
+ get_params_override(basic_params_override()),
+ ]
+
+ run = functools.partial(
+ classifier_trainer.run, strategy_override=distribution)
+ with self.assertRaises(ValueError):
+ run_end_to_end(main=run, extra_flags=extra_flags, model_dir=model_dir)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/legacy/image_classification/classifier_trainer_util_test.py b/modeling/official/legacy/image_classification/classifier_trainer_util_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..09fdd65bf37b5fb2a9cf1a59299d727a8a9f0078
--- /dev/null
+++ b/modeling/official/legacy/image_classification/classifier_trainer_util_test.py
@@ -0,0 +1,165 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for the classifier trainer models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import os
+
+from absl.testing import parameterized
+import tensorflow as tf, tf_keras
+
+from official.legacy.image_classification import classifier_trainer
+from official.legacy.image_classification import dataset_factory
+from official.legacy.image_classification import test_utils
+from official.legacy.image_classification.configs import base_configs
+
+
+def get_trivial_model(num_classes: int) -> tf_keras.Model:
+ """Creates and compiles trivial model for ImageNet dataset."""
+ model = test_utils.trivial_model(num_classes=num_classes)
+ lr = 0.01
+ optimizer = tf_keras.optimizers.SGD(learning_rate=lr)
+ loss_obj = tf_keras.losses.SparseCategoricalCrossentropy()
+ model.compile(optimizer=optimizer, loss=loss_obj, run_eagerly=True)
+ return model
+
+
+def get_trivial_data() -> tf.data.Dataset:
+ """Gets trivial data in the ImageNet size."""
+
+ def generate_data(_) -> tf.data.Dataset:
+ image = tf.zeros(shape=(224, 224, 3), dtype=tf.float32)
+ label = tf.zeros([1], dtype=tf.int32)
+ return image, label
+
+ dataset = tf.data.Dataset.range(1)
+ dataset = dataset.repeat()
+ dataset = dataset.map(
+ generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ dataset = dataset.prefetch(buffer_size=1).batch(1)
+ return dataset
+
+
+class UtilTests(parameterized.TestCase, tf.test.TestCase):
+ """Tests for individual utility functions within classifier_trainer.py."""
+
+ @parameterized.named_parameters(
+ ('efficientnet-b0', 'efficientnet', 'efficientnet-b0', 224),
+ ('efficientnet-b1', 'efficientnet', 'efficientnet-b1', 240),
+ ('efficientnet-b2', 'efficientnet', 'efficientnet-b2', 260),
+ ('efficientnet-b3', 'efficientnet', 'efficientnet-b3', 300),
+ ('efficientnet-b4', 'efficientnet', 'efficientnet-b4', 380),
+ ('efficientnet-b5', 'efficientnet', 'efficientnet-b5', 456),
+ ('efficientnet-b6', 'efficientnet', 'efficientnet-b6', 528),
+ ('efficientnet-b7', 'efficientnet', 'efficientnet-b7', 600),
+ ('resnet', 'resnet', '', None),
+ )
+ def test_get_model_size(self, model, model_name, expected):
+ config = base_configs.ExperimentConfig(
+ model_name=model,
+ model=base_configs.ModelConfig(
+ model_params={
+ 'model_name': model_name,
+ },))
+ size = classifier_trainer.get_image_size_from_model(config)
+ self.assertEqual(size, expected)
+
+ @parameterized.named_parameters(
+ ('dynamic', 'dynamic', None, 'dynamic'),
+ ('scalar', 128., None, 128.),
+ ('float32', None, 'float32', 1),
+ ('float16', None, 'float16', 128),
+ )
+ def test_get_loss_scale(self, loss_scale, dtype, expected):
+ config = base_configs.ExperimentConfig(
+ runtime=base_configs.RuntimeConfig(loss_scale=loss_scale),
+ train_dataset=dataset_factory.DatasetConfig(dtype=dtype))
+ ls = classifier_trainer.get_loss_scale(config, fp16_default=128)
+ self.assertEqual(ls, expected)
+
+ @parameterized.named_parameters(('float16', 'float16'),
+ ('bfloat16', 'bfloat16'))
+ def test_initialize(self, dtype):
+ config = base_configs.ExperimentConfig(
+ runtime=base_configs.RuntimeConfig(
+ run_eagerly=False,
+ enable_xla=False,
+ per_gpu_thread_count=1,
+ gpu_thread_mode='gpu_private',
+ num_gpus=1,
+ dataset_num_private_threads=1,
+ ),
+ train_dataset=dataset_factory.DatasetConfig(dtype=dtype),
+ model=base_configs.ModelConfig(),
+ )
+
+ class EmptyClass:
+ pass
+
+ fake_ds_builder = EmptyClass()
+ fake_ds_builder.dtype = dtype
+ fake_ds_builder.config = EmptyClass()
+ classifier_trainer.initialize(config, fake_ds_builder)
+
+ def test_resume_from_checkpoint(self):
+ """Tests functionality for resuming from checkpoint."""
+ # Set the keras policy
+ tf_keras.mixed_precision.set_global_policy('mixed_bfloat16')
+
+ # Get the model, datasets, and compile it.
+ model = get_trivial_model(10)
+
+ # Create the checkpoint
+ model_dir = self.create_tempdir().full_path
+ train_epochs = 1
+ train_steps = 10
+ ds = get_trivial_data()
+ callbacks = [
+ tf_keras.callbacks.ModelCheckpoint(
+ os.path.join(model_dir, 'model.ckpt-{epoch:04d}'),
+ save_weights_only=True)
+ ]
+ model.fit(
+ ds,
+ callbacks=callbacks,
+ epochs=train_epochs,
+ steps_per_epoch=train_steps)
+
+ # Test load from checkpoint
+ clean_model = get_trivial_model(10)
+ weights_before_load = copy.deepcopy(clean_model.get_weights())
+ initial_epoch = classifier_trainer.resume_from_checkpoint(
+ model=clean_model, model_dir=model_dir, train_steps=train_steps)
+ self.assertEqual(initial_epoch, 1)
+ self.assertNotAllClose(weights_before_load, clean_model.get_weights())
+
+ tf.io.gfile.rmtree(model_dir)
+
+ def test_serialize_config(self):
+ """Tests functionality for serializing data."""
+ config = base_configs.ExperimentConfig()
+ model_dir = self.create_tempdir().full_path
+ classifier_trainer.serialize_config(params=config, model_dir=model_dir)
+ saved_params_path = os.path.join(model_dir, 'params.yaml')
+ self.assertTrue(os.path.exists(saved_params_path))
+ tf.io.gfile.rmtree(model_dir)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/legacy/image_classification/configs/__init__.py b/modeling/official/legacy/image_classification/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/legacy/image_classification/configs/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/legacy/image_classification/configs/base_configs.py b/modeling/official/legacy/image_classification/configs/base_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..975a1be9de4b49d970f7051a0965abef553fce35
--- /dev/null
+++ b/modeling/official/legacy/image_classification/configs/base_configs.py
@@ -0,0 +1,262 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Definitions for high level configuration groups.."""
+
+import dataclasses
+from typing import Any, List, Optional
+from official.core import config_definitions
+from official.modeling import hyperparams
+
+RuntimeConfig = config_definitions.RuntimeConfig
+
+
+@dataclasses.dataclass
+class TensorBoardConfig(hyperparams.Config):
+ """Configuration for TensorBoard.
+
+ Attributes:
+ track_lr: Whether or not to track the learning rate in TensorBoard. Defaults
+ to True.
+ write_model_weights: Whether or not to write the model weights as images in
+ TensorBoard. Defaults to False.
+ """
+ track_lr: bool = True
+ write_model_weights: bool = False
+
+
+@dataclasses.dataclass
+class CallbacksConfig(hyperparams.Config):
+ """Configuration for Callbacks.
+
+ Attributes:
+ enable_checkpoint_and_export: Whether or not to enable checkpoints as a
+ Callback. Defaults to True.
+ enable_backup_and_restore: Whether or not to add BackupAndRestore
+ callback. Defaults to True.
+ enable_tensorboard: Whether or not to enable TensorBoard as a Callback.
+ Defaults to True.
+ enable_time_history: Whether or not to enable TimeHistory Callbacks.
+ Defaults to True.
+ """
+ enable_checkpoint_and_export: bool = True
+ enable_backup_and_restore: bool = False
+ enable_tensorboard: bool = True
+ enable_time_history: bool = True
+
+
+@dataclasses.dataclass
+class ExportConfig(hyperparams.Config):
+ """Configuration for exports.
+
+ Attributes:
+ checkpoint: the path to the checkpoint to export.
+ destination: the path to where the checkpoint should be exported.
+ """
+ checkpoint: str = None
+ destination: str = None
+
+
+@dataclasses.dataclass
+class MetricsConfig(hyperparams.Config):
+ """Configuration for Metrics.
+
+ Attributes:
+ accuracy: Whether or not to track accuracy as a Callback. Defaults to None.
+ top_5: Whether or not to track top_5_accuracy as a Callback. Defaults to
+ None.
+ """
+ accuracy: bool = None
+ top_5: bool = None
+
+
+@dataclasses.dataclass
+class TimeHistoryConfig(hyperparams.Config):
+ """Configuration for the TimeHistory callback.
+
+ Attributes:
+ log_steps: Interval of steps between logging of batch level stats.
+ """
+ log_steps: int = None
+
+
+@dataclasses.dataclass
+class TrainConfig(hyperparams.Config):
+ """Configuration for training.
+
+ Attributes:
+ resume_checkpoint: Whether or not to enable load checkpoint loading.
+ Defaults to None.
+ epochs: The number of training epochs to run. Defaults to None.
+ steps: The number of steps to run per epoch. If None, then this will be
+ inferred based on the number of images and batch size. Defaults to None.
+ callbacks: An instance of CallbacksConfig.
+ metrics: An instance of MetricsConfig.
+ tensorboard: An instance of TensorBoardConfig.
+ set_epoch_loop: Whether or not to set `steps_per_execution` to
+ equal the number of training steps in `model.compile`. This reduces the
+ number of callbacks run per epoch which significantly improves end-to-end
+ TPU training time.
+ """
+ resume_checkpoint: bool = None
+ epochs: int = None
+ steps: int = None
+ callbacks: CallbacksConfig = dataclasses.field(
+ default_factory=CallbacksConfig
+ )
+ metrics: MetricsConfig = None
+ tensorboard: TensorBoardConfig = dataclasses.field(
+ default_factory=TensorBoardConfig
+ )
+ time_history: TimeHistoryConfig = dataclasses.field(
+ default_factory=TimeHistoryConfig
+ )
+ set_epoch_loop: bool = False
+
+
+@dataclasses.dataclass
+class EvalConfig(hyperparams.Config):
+ """Configuration for evaluation.
+
+ Attributes:
+ epochs_between_evals: The number of train epochs to run between evaluations.
+ Defaults to None.
+ steps: The number of eval steps to run during evaluation. If None, this will
+ be inferred based on the number of images and batch size. Defaults to
+ None.
+ skip_eval: Whether or not to skip evaluation.
+ """
+ epochs_between_evals: int = None
+ steps: int = None
+ skip_eval: bool = False
+
+
+@dataclasses.dataclass
+class LossConfig(hyperparams.Config):
+ """Configuration for Loss.
+
+ Attributes:
+ name: The name of the loss. Defaults to None.
+ label_smoothing: Whether or not to apply label smoothing to the loss. This
+ only applies to 'categorical_cross_entropy'.
+ """
+ name: str = None
+ label_smoothing: float = None
+
+
+@dataclasses.dataclass
+class OptimizerConfig(hyperparams.Config):
+ """Configuration for Optimizers.
+
+ Attributes:
+ name: The name of the optimizer. Defaults to None.
+ decay: Decay or rho, discounting factor for gradient. Defaults to None.
+ epsilon: Small value used to avoid 0 denominator. Defaults to None.
+ momentum: Plain momentum constant. Defaults to None.
+ nesterov: Whether or not to apply Nesterov momentum. Defaults to None.
+ moving_average_decay: The amount of decay to apply. If 0 or None, then
+ exponential moving average is not used. Defaults to None.
+ lookahead: Whether or not to apply the lookahead optimizer. Defaults to
+ None.
+ beta_1: The exponential decay rate for the 1st moment estimates. Used in the
+ Adam optimizers. Defaults to None.
+ beta_2: The exponential decay rate for the 2nd moment estimates. Used in the
+ Adam optimizers. Defaults to None.
+ epsilon: Small value used to avoid 0 denominator. Defaults to 1e-7.
+ """
+ name: str = None
+ decay: float = None
+ epsilon: float = None
+ momentum: float = None
+ nesterov: bool = None
+ moving_average_decay: Optional[float] = None
+ lookahead: Optional[bool] = None
+ beta_1: float = None
+ beta_2: float = None
+ epsilon: float = None
+
+
+@dataclasses.dataclass
+class LearningRateConfig(hyperparams.Config):
+ """Configuration for learning rates.
+
+ Attributes:
+ name: The name of the learning rate. Defaults to None.
+ initial_lr: The initial learning rate. Defaults to None.
+ decay_epochs: The number of decay epochs. Defaults to None.
+ decay_rate: The rate of decay. Defaults to None.
+ warmup_epochs: The number of warmup epochs. Defaults to None.
+ batch_lr_multiplier: The multiplier to apply to the base learning rate, if
+ necessary. Defaults to None.
+ examples_per_epoch: the number of examples in a single epoch. Defaults to
+ None.
+ boundaries: boundaries used in piecewise constant decay with warmup.
+ multipliers: multipliers used in piecewise constant decay with warmup.
+ scale_by_batch_size: Scale the learning rate by a fraction of the batch
+ size. Set to 0 for no scaling (default).
+ staircase: Apply exponential decay at discrete values instead of continuous.
+ """
+ name: str = None
+ initial_lr: float = None
+ decay_epochs: float = None
+ decay_rate: float = None
+ warmup_epochs: int = None
+ examples_per_epoch: int = None
+ boundaries: List[int] = None
+ multipliers: List[float] = None
+ scale_by_batch_size: float = 0.
+ staircase: bool = None
+
+
+@dataclasses.dataclass
+class ModelConfig(hyperparams.Config):
+ """Configuration for Models.
+
+ Attributes:
+ name: The name of the model. Defaults to None.
+ model_params: The parameters used to create the model. Defaults to None.
+ num_classes: The number of classes in the model. Defaults to None.
+ loss: A `LossConfig` instance. Defaults to None.
+ optimizer: An `OptimizerConfig` instance. Defaults to None.
+ """
+ name: str = None
+ model_params: hyperparams.Config = None
+ num_classes: int = None
+ loss: LossConfig = None
+ optimizer: OptimizerConfig = None
+
+
+@dataclasses.dataclass
+class ExperimentConfig(hyperparams.Config):
+ """Base configuration for an image classification experiment.
+
+ Attributes:
+ model_dir: The directory to use when running an experiment.
+ mode: e.g. 'train_and_eval', 'export'
+ runtime: A `RuntimeConfig` instance.
+ train: A `TrainConfig` instance.
+ evaluation: An `EvalConfig` instance.
+ model: A `ModelConfig` instance.
+ export: An `ExportConfig` instance.
+ """
+ model_dir: str = None
+ model_name: str = None
+ mode: str = None
+ runtime: RuntimeConfig = None
+ train_dataset: Any = None
+ validation_dataset: Any = None
+ train: TrainConfig = None
+ evaluation: EvalConfig = None
+ model: ModelConfig = None
+ export: ExportConfig = None
diff --git a/modeling/official/legacy/image_classification/configs/configs.py b/modeling/official/legacy/image_classification/configs/configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..05b6c2eca824113bbe9ce8a31cdd3dc3fe176bbe
--- /dev/null
+++ b/modeling/official/legacy/image_classification/configs/configs.py
@@ -0,0 +1,191 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Configuration utils for image classification experiments."""
+
+import dataclasses
+
+from official.legacy.image_classification import dataset_factory
+from official.legacy.image_classification.configs import base_configs
+from official.legacy.image_classification.efficientnet import efficientnet_config
+from official.legacy.image_classification.resnet import resnet_config
+from official.legacy.image_classification.vgg import vgg_config
+
+
+@dataclasses.dataclass
+class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
+ """Base configuration to train efficientnet-b0 on ImageNet.
+
+ Attributes:
+ export: An `ExportConfig` instance
+ runtime: A `RuntimeConfig` instance.
+ dataset: A `DatasetConfig` instance.
+ train: A `TrainConfig` instance.
+ evaluation: An `EvalConfig` instance.
+ model: A `ModelConfig` instance.
+ """
+ export: base_configs.ExportConfig = dataclasses.field(
+ default_factory=base_configs.ExportConfig
+ )
+ runtime: base_configs.RuntimeConfig = dataclasses.field(
+ default_factory=base_configs.RuntimeConfig
+ )
+ train_dataset: dataset_factory.DatasetConfig = dataclasses.field(
+ default_factory=lambda: dataset_factory.ImageNetConfig(split='train')
+ )
+ validation_dataset: dataset_factory.DatasetConfig = dataclasses.field(
+ default_factory=lambda: dataset_factory.ImageNetConfig(split='validation')
+ )
+ train: base_configs.TrainConfig = dataclasses.field(
+ default_factory=lambda: base_configs.TrainConfig( # pylint: disable=g-long-lambda
+ resume_checkpoint=True,
+ epochs=500,
+ steps=None,
+ callbacks=base_configs.CallbacksConfig(
+ enable_checkpoint_and_export=True, enable_tensorboard=True
+ ),
+ metrics=['accuracy', 'top_5'],
+ time_history=base_configs.TimeHistoryConfig(log_steps=100),
+ tensorboard=base_configs.TensorBoardConfig(
+ track_lr=True, write_model_weights=False
+ ),
+ set_epoch_loop=False,
+ )
+ )
+ evaluation: base_configs.EvalConfig = dataclasses.field(
+ default_factory=lambda: base_configs.EvalConfig( # pylint: disable=g-long-lambda
+ epochs_between_evals=1, steps=None
+ )
+ )
+ model: base_configs.ModelConfig = dataclasses.field(
+ default_factory=efficientnet_config.EfficientNetModelConfig
+ )
+
+
+@dataclasses.dataclass
+class ResNetImagenetConfig(base_configs.ExperimentConfig):
+ """Base configuration to train resnet-50 on ImageNet."""
+ export: base_configs.ExportConfig = dataclasses.field(
+ default_factory=base_configs.ExportConfig
+ )
+ runtime: base_configs.RuntimeConfig = dataclasses.field(
+ default_factory=base_configs.RuntimeConfig
+ )
+ train_dataset: dataset_factory.DatasetConfig = dataclasses.field(
+ default_factory=lambda: dataset_factory.ImageNetConfig( # pylint: disable=g-long-lambda
+ split='train', one_hot=False, mean_subtract=True, standardize=True
+ )
+ )
+ validation_dataset: dataset_factory.DatasetConfig = dataclasses.field(
+ default_factory=lambda: dataset_factory.ImageNetConfig( # pylint: disable=g-long-lambda
+ split='validation',
+ one_hot=False,
+ mean_subtract=True,
+ standardize=True,
+ )
+ )
+ train: base_configs.TrainConfig = dataclasses.field(
+ default_factory=lambda: base_configs.TrainConfig( # pylint: disable=g-long-lambda
+ resume_checkpoint=True,
+ epochs=90,
+ steps=None,
+ callbacks=base_configs.CallbacksConfig(
+ enable_checkpoint_and_export=True, enable_tensorboard=True
+ ),
+ metrics=['accuracy', 'top_5'],
+ time_history=base_configs.TimeHistoryConfig(log_steps=100),
+ tensorboard=base_configs.TensorBoardConfig(
+ track_lr=True, write_model_weights=False
+ ),
+ set_epoch_loop=False,
+ )
+ )
+ evaluation: base_configs.EvalConfig = dataclasses.field(
+ default_factory=lambda: base_configs.EvalConfig( # pylint: disable=g-long-lambda
+ epochs_between_evals=1, steps=None
+ )
+ )
+ model: base_configs.ModelConfig = dataclasses.field(
+ default_factory=resnet_config.ResNetModelConfig
+ )
+
+
+@dataclasses.dataclass
+class VGGImagenetConfig(base_configs.ExperimentConfig):
+ """Base configuration to train vgg-16 on ImageNet."""
+ export: base_configs.ExportConfig = dataclasses.field(
+ default_factory=base_configs.ExportConfig
+ )
+ runtime: base_configs.RuntimeConfig = dataclasses.field(
+ default_factory=base_configs.RuntimeConfig
+ )
+ train_dataset: dataset_factory.DatasetConfig = dataclasses.field(
+ default_factory=lambda: dataset_factory.ImageNetConfig( # pylint: disable=g-long-lambda
+ split='train', one_hot=False, mean_subtract=True, standardize=True
+ )
+ )
+ validation_dataset: dataset_factory.DatasetConfig = dataclasses.field(
+ default_factory=lambda: dataset_factory.ImageNetConfig( # pylint: disable=g-long-lambda
+ split='validation',
+ one_hot=False,
+ mean_subtract=True,
+ standardize=True,
+ )
+ )
+ train: base_configs.TrainConfig = dataclasses.field(
+ default_factory=lambda: base_configs.TrainConfig( # pylint: disable=g-long-lambda
+ resume_checkpoint=True,
+ epochs=90,
+ steps=None,
+ callbacks=base_configs.CallbacksConfig(
+ enable_checkpoint_and_export=True, enable_tensorboard=True
+ ),
+ metrics=['accuracy', 'top_5'],
+ time_history=base_configs.TimeHistoryConfig(log_steps=100),
+ tensorboard=base_configs.TensorBoardConfig(
+ track_lr=True, write_model_weights=False
+ ),
+ set_epoch_loop=False,
+ )
+ )
+ evaluation: base_configs.EvalConfig = dataclasses.field(
+ default_factory=lambda: base_configs.EvalConfig( # pylint: disable=g-long-lambda
+ epochs_between_evals=1, steps=None
+ )
+ )
+ model: base_configs.ModelConfig = dataclasses.field(
+ default_factory=vgg_config.VGGModelConfig
+ )
+
+
+def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig:
+ """Given model and dataset names, return the ExperimentConfig."""
+ dataset_model_config_map = {
+ 'imagenet': {
+ 'efficientnet': EfficientNetImageNetConfig(),
+ 'resnet': ResNetImagenetConfig(),
+ 'vgg': VGGImagenetConfig(),
+ }
+ }
+ try:
+ return dataset_model_config_map[dataset][model]
+ except KeyError:
+ if dataset not in dataset_model_config_map:
+ raise KeyError('Invalid dataset received. Received: {}. Supported '
+ 'datasets include: {}'.format(
+ dataset, ', '.join(dataset_model_config_map.keys())))
+ raise KeyError('Invalid model received. Received: {}. Supported models for'
+ '{} include: {}'.format(
+ model, dataset,
+ ', '.join(dataset_model_config_map[dataset].keys())))
diff --git a/modeling/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml b/modeling/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9b8608922faef69a5f7edba6ab86c5d50c1dfa21
--- /dev/null
+++ b/modeling/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-gpu.yaml
@@ -0,0 +1,52 @@
+# Training configuration for EfficientNet-b0 trained on ImageNet on GPUs.
+# Takes ~32 minutes per epoch for 8 V100s.
+# Reaches ~76.1% within 350 epochs.
+# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
+runtime:
+ distribution_strategy: 'mirrored'
+ num_gpus: 1
+train_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'train'
+ num_classes: 1000
+ num_examples: 1281167
+ batch_size: 32
+ use_per_replica_batch_size: true
+ dtype: 'float32'
+ augmenter:
+ name: 'autoaugment'
+validation_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'validation'
+ num_classes: 1000
+ num_examples: 50000
+ batch_size: 32
+ use_per_replica_batch_size: true
+ dtype: 'float32'
+model:
+ model_params:
+ model_name: 'efficientnet-b0'
+ overrides:
+ num_classes: 1000
+ batch_norm: 'default'
+ dtype: 'float32'
+ activation: 'swish'
+ optimizer:
+ name: 'rmsprop'
+ momentum: 0.9
+ decay: 0.9
+ moving_average_decay: 0.0
+ lookahead: false
+ learning_rate:
+ name: 'exponential'
+ loss:
+ label_smoothing: 0.1
+train:
+ resume_checkpoint: true
+ epochs: 500
+evaluation:
+ epochs_between_evals: 1
diff --git a/modeling/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml b/modeling/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..066ac220734d3ddb6614bcae8f62c3d3ae41e4f4
--- /dev/null
+++ b/modeling/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b0-tpu.yaml
@@ -0,0 +1,52 @@
+# Training configuration for EfficientNet-b0 trained on ImageNet on TPUs.
+# Takes ~2 minutes, 50 seconds per epoch for v3-32.
+# Reaches ~76.1% within 350 epochs.
+# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
+runtime:
+ distribution_strategy: 'tpu'
+train_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'train'
+ num_classes: 1000
+ num_examples: 1281167
+ batch_size: 128
+ use_per_replica_batch_size: true
+ dtype: 'bfloat16'
+ augmenter:
+ name: 'autoaugment'
+validation_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'validation'
+ num_classes: 1000
+ num_examples: 50000
+ batch_size: 128
+ use_per_replica_batch_size: true
+ dtype: 'bfloat16'
+model:
+ model_params:
+ model_name: 'efficientnet-b0'
+ overrides:
+ num_classes: 1000
+ batch_norm: 'tpu'
+ dtype: 'bfloat16'
+ activation: 'swish'
+ optimizer:
+ name: 'rmsprop'
+ momentum: 0.9
+ decay: 0.9
+ moving_average_decay: 0.0
+ lookahead: false
+ learning_rate:
+ name: 'exponential'
+ loss:
+ label_smoothing: 0.1
+train:
+ resume_checkpoint: true
+ epochs: 500
+ set_epoch_loop: true
+evaluation:
+ epochs_between_evals: 1
diff --git a/modeling/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml b/modeling/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..21f3fe3200928cb543bea7fea75361903e5bf40e
--- /dev/null
+++ b/modeling/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-gpu.yaml
@@ -0,0 +1,47 @@
+# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
+runtime:
+ distribution_strategy: 'mirrored'
+ num_gpus: 1
+train_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'train'
+ num_classes: 1000
+ num_examples: 1281167
+ batch_size: 32
+ use_per_replica_batch_size: true
+ dtype: 'float32'
+validation_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'validation'
+ num_classes: 1000
+ num_examples: 50000
+ batch_size: 32
+ use_per_replica_batch_size: true
+ dtype: 'float32'
+model:
+ model_params:
+ model_name: 'efficientnet-b1'
+ overrides:
+ num_classes: 1000
+ batch_norm: 'default'
+ dtype: 'float32'
+ activation: 'swish'
+ optimizer:
+ name: 'rmsprop'
+ momentum: 0.9
+ decay: 0.9
+ moving_average_decay: 0.0
+ lookahead: false
+ learning_rate:
+ name: 'exponential'
+ loss:
+ label_smoothing: 0.1
+train:
+ resume_checkpoint: true
+ epochs: 500
+evaluation:
+ epochs_between_evals: 1
diff --git a/modeling/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-tpu.yaml b/modeling/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-tpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..616ac7c502c73709db04541653023024859a41bd
--- /dev/null
+++ b/modeling/official/legacy/image_classification/configs/examples/efficientnet/imagenet/efficientnet-b1-tpu.yaml
@@ -0,0 +1,51 @@
+# Training configuration for EfficientNet-b1 trained on ImageNet on TPUs.
+# Takes ~3 minutes, 15 seconds per epoch for v3-32.
+# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
+runtime:
+ distribution_strategy: 'tpu'
+train_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'train'
+ num_classes: 1000
+ num_examples: 1281167
+ batch_size: 128
+ use_per_replica_batch_size: true
+ dtype: 'bfloat16'
+ augmenter:
+ name: 'autoaugment'
+validation_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'validation'
+ num_classes: 1000
+ num_examples: 50000
+ batch_size: 128
+ use_per_replica_batch_size: true
+ dtype: 'bfloat16'
+model:
+ model_params:
+ model_name: 'efficientnet-b1'
+ overrides:
+ num_classes: 1000
+ batch_norm: 'tpu'
+ dtype: 'bfloat16'
+ activation: 'swish'
+ optimizer:
+ name: 'rmsprop'
+ momentum: 0.9
+ decay: 0.9
+ moving_average_decay: 0.0
+ lookahead: false
+ learning_rate:
+ name: 'exponential'
+ loss:
+ label_smoothing: 0.1
+train:
+ resume_checkpoint: true
+ epochs: 500
+ set_epoch_loop: true
+evaluation:
+ epochs_between_evals: 1
diff --git a/modeling/official/legacy/image_classification/configs/examples/resnet/imagenet/gpu.yaml b/modeling/official/legacy/image_classification/configs/examples/resnet/imagenet/gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f0c623206c1b43c68792a68450c800f94db7dceb
--- /dev/null
+++ b/modeling/official/legacy/image_classification/configs/examples/resnet/imagenet/gpu.yaml
@@ -0,0 +1,49 @@
+# Training configuration for ResNet trained on ImageNet on GPUs.
+# Reaches > 76.1% within 90 epochs.
+# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
+runtime:
+ distribution_strategy: 'mirrored'
+ num_gpus: 1
+ batchnorm_spatial_persistent: true
+train_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'tfds'
+ split: 'train'
+ image_size: 224
+ num_classes: 1000
+ num_examples: 1281167
+ batch_size: 256
+ use_per_replica_batch_size: true
+ dtype: 'float16'
+ mean_subtract: true
+ standardize: true
+validation_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'tfds'
+ split: 'validation'
+ image_size: 224
+ num_classes: 1000
+ num_examples: 50000
+ batch_size: 256
+ use_per_replica_batch_size: true
+ dtype: 'float16'
+ mean_subtract: true
+ standardize: true
+model:
+ name: 'resnet'
+ model_params:
+ rescale_inputs: false
+ optimizer:
+ name: 'momentum'
+ momentum: 0.9
+ decay: 0.9
+ epsilon: 0.001
+ loss:
+ label_smoothing: 0.1
+train:
+ resume_checkpoint: true
+ epochs: 90
+evaluation:
+ epochs_between_evals: 1
diff --git a/modeling/official/legacy/image_classification/configs/examples/resnet/imagenet/tpu.yaml b/modeling/official/legacy/image_classification/configs/examples/resnet/imagenet/tpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1f88980755e7259a7d39c1636e75e5092089f186
--- /dev/null
+++ b/modeling/official/legacy/image_classification/configs/examples/resnet/imagenet/tpu.yaml
@@ -0,0 +1,55 @@
+# Training configuration for ResNet trained on ImageNet on TPUs.
+# Takes ~4 minutes, 30 seconds seconds per epoch for a v3-32.
+# Reaches > 76.1% within 90 epochs.
+# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
+runtime:
+ distribution_strategy: 'tpu'
+train_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'tfds'
+ split: 'train'
+ one_hot: false
+ image_size: 224
+ num_classes: 1000
+ num_examples: 1281167
+ batch_size: 128
+ use_per_replica_batch_size: true
+ mean_subtract: false
+ standardize: false
+ dtype: 'bfloat16'
+validation_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'tfds'
+ split: 'validation'
+ one_hot: false
+ image_size: 224
+ num_classes: 1000
+ num_examples: 50000
+ batch_size: 128
+ use_per_replica_batch_size: true
+ mean_subtract: false
+ standardize: false
+ dtype: 'bfloat16'
+model:
+ name: 'resnet'
+ model_params:
+ rescale_inputs: true
+ optimizer:
+ name: 'momentum'
+ momentum: 0.9
+ decay: 0.9
+ epsilon: 0.001
+ moving_average_decay: 0.
+ lookahead: false
+ loss:
+ label_smoothing: 0.1
+train:
+ callbacks:
+ enable_checkpoint_and_export: true
+ resume_checkpoint: true
+ epochs: 90
+ set_epoch_loop: true
+evaluation:
+ epochs_between_evals: 1
diff --git a/modeling/official/legacy/image_classification/configs/examples/vgg16/imagenet/gpu.yaml b/modeling/official/legacy/image_classification/configs/examples/vgg16/imagenet/gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..33c5a4e36a71fd975ddb57e298fa87d65c66555a
--- /dev/null
+++ b/modeling/official/legacy/image_classification/configs/examples/vgg16/imagenet/gpu.yaml
@@ -0,0 +1,46 @@
+# Training configuration for VGG-16 trained on ImageNet on GPUs.
+# Reaches > 72.8% within 90 epochs.
+# Note: This configuration uses a scaled per-replica batch size based on the number of devices.
+runtime:
+ distribution_strategy: 'mirrored'
+ num_gpus: 1
+ batchnorm_spatial_persistent: true
+train_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'train'
+ image_size: 224
+ num_classes: 1000
+ num_examples: 1281167
+ batch_size: 128
+ use_per_replica_batch_size: true
+ dtype: 'float32'
+ mean_subtract: true
+ standardize: true
+validation_dataset:
+ name: 'imagenet2012'
+ data_dir: null
+ builder: 'records'
+ split: 'validation'
+ image_size: 224
+ num_classes: 1000
+ num_examples: 50000
+ batch_size: 128
+ use_per_replica_batch_size: true
+ dtype: 'float32'
+ mean_subtract: true
+ standardize: true
+model:
+ name: 'vgg'
+ optimizer:
+ name: 'momentum'
+ momentum: 0.9
+ epsilon: 0.001
+ loss:
+ label_smoothing: 0.0
+train:
+ resume_checkpoint: true
+ epochs: 90
+evaluation:
+ epochs_between_evals: 1
diff --git a/modeling/official/legacy/image_classification/dataset_factory.py b/modeling/official/legacy/image_classification/dataset_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..e414ed6f19cd3cd3e11bd12b7979a4cbd6f8eb41
--- /dev/null
+++ b/modeling/official/legacy/image_classification/dataset_factory.py
@@ -0,0 +1,533 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Dataset utilities for vision tasks using TFDS and tf.data.Dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import dataclasses
+import os
+from typing import Any, List, Mapping, Optional, Tuple, Union
+
+from absl import logging
+import tensorflow as tf, tf_keras
+import tensorflow_datasets as tfds
+from official.legacy.image_classification import augment
+from official.legacy.image_classification import preprocessing
+from official.modeling.hyperparams import base_config
+
+AUGMENTERS = {
+ 'autoaugment': augment.AutoAugment,
+ 'randaugment': augment.RandAugment,
+}
+
+
+@dataclasses.dataclass
+class AugmentConfig(base_config.Config):
+ """Configuration for image augmenters.
+
+ Attributes:
+ name: The name of the image augmentation to use. Possible options are None
+ (default), 'autoaugment', or 'randaugment'.
+ params: Any parameters used to initialize the augmenter.
+ """
+ name: Optional[str] = None
+ params: Optional[Mapping[str, Any]] = None
+
+ def build(self) -> augment.ImageAugment:
+ """Build the augmenter using this config."""
+ params = self.params or {}
+ augmenter = AUGMENTERS.get(self.name, None)
+ return augmenter(**params) if augmenter is not None else None
+
+
+@dataclasses.dataclass
+class DatasetConfig(base_config.Config):
+ """The base configuration for building datasets.
+
+ Attributes:
+ name: The name of the Dataset. Usually should correspond to a TFDS dataset.
+ data_dir: The path where the dataset files are stored, if available.
+ filenames: Optional list of strings representing the TFRecord names.
+ builder: The builder type used to load the dataset. Value should be one of
+ 'tfds' (load using TFDS), 'records' (load from TFRecords), or 'synthetic'
+ (generate dummy synthetic data without reading from files).
+ split: The split of the dataset. Usually 'train', 'validation', or 'test'.
+ image_size: The size of the image in the dataset. This assumes that `width`
+ == `height`. Set to 'infer' to infer the image size from TFDS info. This
+ requires `name` to be a registered dataset in TFDS.
+ num_classes: The number of classes given by the dataset. Set to 'infer' to
+ infer the image size from TFDS info. This requires `name` to be a
+ registered dataset in TFDS.
+ num_channels: The number of channels given by the dataset. Set to 'infer' to
+ infer the image size from TFDS info. This requires `name` to be a
+ registered dataset in TFDS.
+ num_examples: The number of examples given by the dataset. Set to 'infer' to
+ infer the image size from TFDS info. This requires `name` to be a
+ registered dataset in TFDS.
+ batch_size: The base batch size for the dataset.
+ use_per_replica_batch_size: Whether to scale the batch size based on
+ available resources. If set to `True`, the dataset builder will return
+ batch_size multiplied by `num_devices`, the number of device replicas
+ (e.g., the number of GPUs or TPU cores). This setting should be `True` if
+ the strategy argument is passed to `build()` and `num_devices > 1`.
+ num_devices: The number of replica devices to use. This should be set by
+ `strategy.num_replicas_in_sync` when using a distribution strategy.
+ dtype: The desired dtype of the dataset. This will be set during
+ preprocessing.
+ one_hot: Whether to apply one hot encoding. Set to `True` to be able to use
+ label smoothing.
+ augmenter: The augmenter config to use. No augmentation is used by default.
+ download: Whether to download data using TFDS.
+ shuffle_buffer_size: The buffer size used for shuffling training data.
+ file_shuffle_buffer_size: The buffer size used for shuffling raw training
+ files.
+ skip_decoding: Whether to skip image decoding when loading from TFDS.
+ cache: whether to cache to dataset examples. Can be used to avoid re-reading
+ from disk on the second epoch. Requires significant memory overhead.
+ tf_data_service: The URI of a tf.data service to offload preprocessing onto
+ during training. The URI should be in the format "protocol://address",
+ e.g. "grpc://tf-data-service:5050".
+ mean_subtract: whether or not to apply mean subtraction to the dataset.
+ standardize: whether or not to apply standardization to the dataset.
+ """
+ name: Optional[str] = None
+ data_dir: Optional[str] = None
+ filenames: Optional[List[str]] = None
+ builder: str = 'tfds'
+ split: str = 'train'
+ image_size: Union[int, str] = 'infer'
+ num_classes: Union[int, str] = 'infer'
+ num_channels: Union[int, str] = 'infer'
+ num_examples: Union[int, str] = 'infer'
+ batch_size: int = 128
+ use_per_replica_batch_size: bool = True
+ num_devices: int = 1
+ dtype: str = 'float32'
+ one_hot: bool = True
+ augmenter: AugmentConfig = dataclasses.field(default_factory=AugmentConfig)
+ download: bool = False
+ shuffle_buffer_size: int = 10000
+ file_shuffle_buffer_size: int = 1024
+ skip_decoding: bool = True
+ cache: bool = False
+ tf_data_service: Optional[str] = None
+ mean_subtract: bool = False
+ standardize: bool = False
+
+ @property
+ def has_data(self):
+ """Whether this dataset is has any data associated with it."""
+ return self.name or self.data_dir or self.filenames
+
+
+@dataclasses.dataclass
+class ImageNetConfig(DatasetConfig):
+ """The base ImageNet dataset config."""
+ name: str = 'imagenet2012'
+ # Note: for large datasets like ImageNet, using records is faster than tfds
+ builder: str = 'records'
+ image_size: int = 224
+ num_channels: int = 3
+ num_examples: int = 1281167
+ num_classes: int = 1000
+ batch_size: int = 128
+
+
+@dataclasses.dataclass
+class Cifar10Config(DatasetConfig):
+ """The base CIFAR-10 dataset config."""
+ name: str = 'cifar10'
+ image_size: int = 224
+ batch_size: int = 128
+ download: bool = True
+ cache: bool = True
+
+
+class DatasetBuilder:
+ """An object for building datasets.
+
+ Allows building various pipelines fetching examples, preprocessing, etc.
+ Maintains additional state information calculated from the dataset, i.e.,
+ training set split, batch size, and number of steps (batches).
+ """
+
+ def __init__(self, config: DatasetConfig, **overrides: Any):
+ """Initialize the builder from the config."""
+ self.config = config.replace(**overrides)
+ self.builder_info = None
+
+ if self.config.augmenter is not None:
+ logging.info('Using augmentation: %s', self.config.augmenter.name)
+ self.augmenter = self.config.augmenter.build()
+ else:
+ self.augmenter = None
+
+ @property
+ def is_training(self) -> bool:
+ """Whether this is the training set."""
+ return self.config.split == 'train'
+
+ @property
+ def batch_size(self) -> int:
+ """The batch size, multiplied by the number of replicas (if configured)."""
+ if self.config.use_per_replica_batch_size:
+ return self.config.batch_size * self.config.num_devices
+ else:
+ return self.config.batch_size
+
+ @property
+ def global_batch_size(self):
+ """The global batch size across all replicas."""
+ return self.batch_size
+
+ @property
+ def local_batch_size(self):
+ """The base unscaled batch size."""
+ if self.config.use_per_replica_batch_size:
+ return self.config.batch_size
+ else:
+ return self.config.batch_size // self.config.num_devices
+
+ @property
+ def num_steps(self) -> int:
+ """The number of steps (batches) to exhaust this dataset."""
+ # Always divide by the global batch size to get the correct # of steps
+ return self.num_examples // self.global_batch_size
+
+ @property
+ def dtype(self) -> tf.dtypes.DType:
+ """Converts the config's dtype string to a tf dtype.
+
+ Returns:
+ A mapping from string representation of a dtype to the `tf.dtypes.DType`.
+
+ Raises:
+ ValueError if the config's dtype is not supported.
+
+ """
+ dtype_map = {
+ 'float32': tf.float32,
+ 'bfloat16': tf.bfloat16,
+ 'float16': tf.float16,
+ 'fp32': tf.float32,
+ 'bf16': tf.bfloat16,
+ }
+ try:
+ return dtype_map[self.config.dtype]
+ except:
+ raise ValueError('Invalid DType provided. Supported types: {}'.format(
+ dtype_map.keys()))
+
+ @property
+ def image_size(self) -> int:
+ """The size of each image (can be inferred from the dataset)."""
+
+ if self.config.image_size == 'infer':
+ return self.info.features['image'].shape[0]
+ else:
+ return int(self.config.image_size)
+
+ @property
+ def num_channels(self) -> int:
+ """The number of image channels (can be inferred from the dataset)."""
+ if self.config.num_channels == 'infer':
+ return self.info.features['image'].shape[-1]
+ else:
+ return int(self.config.num_channels)
+
+ @property
+ def num_examples(self) -> int:
+ """The number of examples (can be inferred from the dataset)."""
+ if self.config.num_examples == 'infer':
+ return self.info.splits[self.config.split].num_examples
+ else:
+ return int(self.config.num_examples)
+
+ @property
+ def num_classes(self) -> int:
+ """The number of classes (can be inferred from the dataset)."""
+ if self.config.num_classes == 'infer':
+ return self.info.features['label'].num_classes
+ else:
+ return int(self.config.num_classes)
+
+ @property
+ def info(self) -> tfds.core.DatasetInfo:
+ """The TFDS dataset info, if available."""
+ try:
+ if self.builder_info is None:
+ self.builder_info = tfds.builder(self.config.name).info
+ except ConnectionError as e:
+ logging.error('Failed to use TFDS to load info. Please set dataset info '
+ '(image_size, num_channels, num_examples, num_classes) in '
+ 'the dataset config.')
+ raise e
+ return self.builder_info
+
+ def build(
+ self,
+ strategy: Optional[tf.distribute.Strategy] = None) -> tf.data.Dataset:
+ """Construct a dataset end-to-end and return it using an optional strategy.
+
+ Args:
+ strategy: a strategy that, if passed, will distribute the dataset
+ according to that strategy. If passed and `num_devices > 1`,
+ `use_per_replica_batch_size` must be set to `True`.
+
+ Returns:
+ A TensorFlow dataset outputting batched images and labels.
+ """
+ if strategy:
+ if strategy.num_replicas_in_sync != self.config.num_devices:
+ logging.warn(
+ 'Passed a strategy with %d devices, but expected'
+ '%d devices.', strategy.num_replicas_in_sync,
+ self.config.num_devices)
+ dataset = strategy.distribute_datasets_from_function(self._build)
+ else:
+ dataset = self._build()
+
+ return dataset
+
+ def _build(
+ self,
+ input_context: Optional[tf.distribute.InputContext] = None
+ ) -> tf.data.Dataset:
+ """Construct a dataset end-to-end and return it.
+
+ Args:
+ input_context: An optional context provided by `tf.distribute` for
+ cross-replica training.
+
+ Returns:
+ A TensorFlow dataset outputting batched images and labels.
+ """
+ builders = {
+ 'tfds': self.load_tfds,
+ 'records': self.load_records,
+ 'synthetic': self.load_synthetic,
+ }
+
+ builder = builders.get(self.config.builder, None)
+
+ if builder is None:
+ raise ValueError('Unknown builder type {}'.format(self.config.builder))
+
+ self.input_context = input_context
+ dataset = builder()
+ dataset = self.pipeline(dataset)
+
+ return dataset
+
+ def load_tfds(self) -> tf.data.Dataset:
+ """Return a dataset loading files from TFDS."""
+
+ logging.info('Using TFDS to load data.')
+ builder = tfds.builder(self.config.name, data_dir=self.config.data_dir)
+
+ if self.config.download:
+ builder.download_and_prepare()
+
+ decoders = {}
+
+ if self.config.skip_decoding:
+ decoders['image'] = tfds.decode.SkipDecoding()
+
+ read_config = tfds.ReadConfig(
+ interleave_cycle_length=10,
+ interleave_block_length=1,
+ input_context=self.input_context)
+
+ dataset = builder.as_dataset(
+ split=self.config.split,
+ as_supervised=True,
+ shuffle_files=True,
+ decoders=decoders,
+ read_config=read_config)
+
+ return dataset
+
+ def load_records(self) -> tf.data.Dataset:
+ """Return a dataset loading files with TFRecords."""
+ logging.info('Using TFRecords to load data.')
+ if self.config.filenames is None:
+ if self.config.data_dir is None:
+ raise ValueError('Dataset must specify a path for the data files.')
+
+ file_pattern = os.path.join(self.config.data_dir,
+ '{}*'.format(self.config.split))
+ dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False)
+ else:
+ dataset = tf.data.Dataset.from_tensor_slices(self.config.filenames)
+
+ return dataset
+
+ def load_synthetic(self) -> tf.data.Dataset:
+ """Return a dataset generating dummy synthetic data."""
+ logging.info('Generating a synthetic dataset.')
+
+ def generate_data(_):
+ image = tf.zeros([self.image_size, self.image_size, self.num_channels],
+ dtype=self.dtype)
+ label = tf.zeros([1], dtype=tf.int32)
+ return image, label
+
+ dataset = tf.data.Dataset.range(1)
+ dataset = dataset.repeat()
+ dataset = dataset.map(
+ generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+ def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
+ """Build a pipeline fetching, shuffling, and preprocessing the dataset.
+
+ Args:
+ dataset: A `tf.data.Dataset` that loads raw files.
+
+ Returns:
+ A TensorFlow dataset outputting batched images and labels.
+ """
+ if (self.config.builder != 'tfds' and self.input_context and
+ self.input_context.num_input_pipelines > 1):
+ dataset = dataset.shard(self.input_context.num_input_pipelines,
+ self.input_context.input_pipeline_id)
+ logging.info(
+ 'Sharding the dataset: input_pipeline_id=%d '
+ 'num_input_pipelines=%d', self.input_context.num_input_pipelines,
+ self.input_context.input_pipeline_id)
+
+ if self.is_training and self.config.builder == 'records':
+ # Shuffle the input files.
+ dataset.shuffle(buffer_size=self.config.file_shuffle_buffer_size)
+
+ if self.is_training and not self.config.cache:
+ dataset = dataset.repeat()
+
+ if self.config.builder == 'records':
+ # Read the data from disk in parallel
+ dataset = dataset.interleave(
+ tf.data.TFRecordDataset,
+ cycle_length=10,
+ block_length=1,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ if self.config.cache:
+ dataset = dataset.cache()
+
+ if self.is_training:
+ dataset = dataset.shuffle(self.config.shuffle_buffer_size)
+ dataset = dataset.repeat()
+
+ # Parse, pre-process, and batch the data in parallel
+ if self.config.builder == 'records':
+ preprocess = self.parse_record
+ else:
+ preprocess = self.preprocess
+ dataset = dataset.map(
+ preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ if self.input_context and self.config.num_devices > 1:
+ if not self.config.use_per_replica_batch_size:
+ raise ValueError(
+ 'The builder does not support a global batch size with more than '
+ 'one replica. Got {} replicas. Please set a '
+ '`per_replica_batch_size` and enable '
+ '`use_per_replica_batch_size=True`.'.format(
+ self.config.num_devices))
+
+ # The batch size of the dataset will be multiplied by the number of
+ # replicas automatically when strategy.distribute_datasets_from_function
+ # is called, so we use local batch size here.
+ dataset = dataset.batch(
+ self.local_batch_size, drop_remainder=self.is_training)
+ else:
+ dataset = dataset.batch(
+ self.global_batch_size, drop_remainder=self.is_training)
+
+ # Prefetch overlaps in-feed with training
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+
+ if self.config.tf_data_service:
+ if not hasattr(tf.data.experimental, 'service'):
+ raise ValueError('The tf_data_service flag requires Tensorflow version '
+ '>= 2.3.0, but the version is {}'.format(
+ tf.__version__))
+ dataset = dataset.apply(
+ tf.data.experimental.service.distribute(
+ processing_mode='parallel_epochs',
+ service=self.config.tf_data_service,
+ job_name='resnet_train'))
+ dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
+
+ return dataset
+
+ def parse_record(self, record: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+ """Parse an ImageNet record from a serialized string Tensor."""
+ keys_to_features = {
+ 'image/encoded': tf.io.FixedLenFeature((), tf.string, ''),
+ 'image/format': tf.io.FixedLenFeature((), tf.string, 'jpeg'),
+ 'image/class/label': tf.io.FixedLenFeature([], tf.int64, -1),
+ 'image/class/text': tf.io.FixedLenFeature([], tf.string, ''),
+ 'image/object/bbox/xmin': tf.io.VarLenFeature(dtype=tf.float32),
+ 'image/object/bbox/ymin': tf.io.VarLenFeature(dtype=tf.float32),
+ 'image/object/bbox/xmax': tf.io.VarLenFeature(dtype=tf.float32),
+ 'image/object/bbox/ymax': tf.io.VarLenFeature(dtype=tf.float32),
+ 'image/object/class/label': tf.io.VarLenFeature(dtype=tf.int64),
+ }
+
+ parsed = tf.io.parse_single_example(record, keys_to_features)
+
+ label = tf.reshape(parsed['image/class/label'], shape=[1])
+
+ # Subtract one so that labels are in [0, 1000)
+ label -= 1
+
+ image_bytes = tf.reshape(parsed['image/encoded'], shape=[])
+ image, label = self.preprocess(image_bytes, label)
+
+ return image, label
+
+ def preprocess(self, image: tf.Tensor,
+ label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+ """Apply image preprocessing and augmentation to the image and label."""
+ if self.is_training:
+ image = preprocessing.preprocess_for_train(
+ image,
+ image_size=self.image_size,
+ mean_subtract=self.config.mean_subtract,
+ standardize=self.config.standardize,
+ dtype=self.dtype,
+ augmenter=self.augmenter)
+ else:
+ image = preprocessing.preprocess_for_eval(
+ image,
+ image_size=self.image_size,
+ num_channels=self.num_channels,
+ mean_subtract=self.config.mean_subtract,
+ standardize=self.config.standardize,
+ dtype=self.dtype)
+
+ label = tf.cast(label, tf.int32)
+ if self.config.one_hot:
+ label = tf.one_hot(label, self.num_classes)
+ label = tf.reshape(label, [self.num_classes])
+
+ return image, label
+
+ @classmethod
+ def from_params(cls, *args, **kwargs):
+ """Construct a dataset builder from a default config and any overrides."""
+ config = DatasetConfig.from_args(*args, **kwargs)
+ return cls(config)
diff --git a/modeling/official/legacy/image_classification/efficientnet/__init__.py b/modeling/official/legacy/image_classification/efficientnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/legacy/image_classification/efficientnet/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/legacy/image_classification/efficientnet/common_modules.py b/modeling/official/legacy/image_classification/efficientnet/common_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7d07b11c795cb354ba69347131e752963f4b4b2
--- /dev/null
+++ b/modeling/official/legacy/image_classification/efficientnet/common_modules.py
@@ -0,0 +1,116 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common modeling utilities."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from typing import Optional, Text
+import numpy as np
+import tensorflow as tf, tf_keras
+import tensorflow.compat.v1 as tf1
+from tensorflow.python.tpu import tpu_function
+
+
+@tf_keras.utils.register_keras_serializable(package='Vision')
+class TpuBatchNormalization(tf_keras.layers.BatchNormalization):
+ """Cross replica batch normalization."""
+
+ def __init__(self, fused: Optional[bool] = False, **kwargs):
+ if fused in (True, None):
+ raise ValueError('TpuBatchNormalization does not support fused=True.')
+ super(TpuBatchNormalization, self).__init__(fused=fused, **kwargs)
+
+ def _cross_replica_average(self, t: tf.Tensor, num_shards_per_group: int):
+ """Calculates the average value of input tensor across TPU replicas."""
+ num_shards = tpu_function.get_tpu_context().number_of_shards
+ group_assignment = None
+ if num_shards_per_group > 1:
+ if num_shards % num_shards_per_group != 0:
+ raise ValueError(
+ 'num_shards: %d mod shards_per_group: %d, should be 0' %
+ (num_shards, num_shards_per_group))
+ num_groups = num_shards // num_shards_per_group
+ group_assignment = [[
+ x for x in range(num_shards) if x // num_shards_per_group == y
+ ] for y in range(num_groups)]
+ return tf1.tpu.cross_replica_sum(t, group_assignment) / tf.cast(
+ num_shards_per_group, t.dtype)
+
+ def _moments(self, inputs: tf.Tensor, reduction_axes: int, keep_dims: int):
+ """Compute the mean and variance: it overrides the original _moments."""
+ shard_mean, shard_variance = super(TpuBatchNormalization, self)._moments(
+ inputs, reduction_axes, keep_dims=keep_dims)
+
+ num_shards = tpu_function.get_tpu_context().number_of_shards or 1
+ if num_shards <= 8: # Skip cross_replica for 2x2 or smaller slices.
+ num_shards_per_group = 1
+ else:
+ num_shards_per_group = max(8, num_shards // 8)
+ if num_shards_per_group > 1:
+ # Compute variance using: Var[X]= E[X^2] - E[X]^2.
+ shard_square_of_mean = tf.math.square(shard_mean)
+ shard_mean_of_square = shard_variance + shard_square_of_mean
+ group_mean = self._cross_replica_average(shard_mean, num_shards_per_group)
+ group_mean_of_square = self._cross_replica_average(
+ shard_mean_of_square, num_shards_per_group)
+ group_variance = group_mean_of_square - tf.math.square(group_mean)
+ return (group_mean, group_variance)
+ else:
+ return (shard_mean, shard_variance)
+
+
+def get_batch_norm(batch_norm_type: Text) -> tf_keras.layers.BatchNormalization:
+ """A helper to create a batch normalization getter.
+
+ Args:
+ batch_norm_type: The type of batch normalization layer implementation. `tpu`
+ will use `TpuBatchNormalization`.
+
+ Returns:
+ An instance of `tf_keras.layers.BatchNormalization`.
+ """
+ if batch_norm_type == 'tpu':
+ return TpuBatchNormalization
+
+ return tf_keras.layers.BatchNormalization # pytype: disable=bad-return-type # typed-keras
+
+
+def count_params(model, trainable_only=True):
+ """Returns the count of all model parameters, or just trainable ones."""
+ if not trainable_only:
+ return model.count_params()
+ else:
+ return int(
+ np.sum([
+ tf_keras.backend.count_params(p) for p in model.trainable_weights
+ ]))
+
+
+def load_weights(model: tf_keras.Model,
+ model_weights_path: Text,
+ weights_format: Text = 'saved_model'):
+ """Load model weights from the given file path.
+
+ Args:
+ model: the model to load weights into
+ model_weights_path: the path of the model weights
+ weights_format: the model weights format. One of 'saved_model', 'h5', or
+ 'checkpoint'.
+ """
+ if weights_format == 'saved_model':
+ loaded_model = tf_keras.models.load_model(model_weights_path)
+ model.set_weights(loaded_model.get_weights())
+ else:
+ model.load_weights(model_weights_path)
diff --git a/modeling/official/legacy/image_classification/efficientnet/efficientnet_config.py b/modeling/official/legacy/image_classification/efficientnet/efficientnet_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee7526ef0bad494ae44830943656e3558d35de1e
--- /dev/null
+++ b/modeling/official/legacy/image_classification/efficientnet/efficientnet_config.py
@@ -0,0 +1,82 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Configuration definitions for EfficientNet losses, learning rates, and optimizers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import dataclasses
+from official.legacy.image_classification.configs import base_configs
+from official.modeling.hyperparams import base_config
+
+
+@dataclasses.dataclass
+class EfficientNetModelConfig(base_configs.ModelConfig):
+ """Configuration for the EfficientNet model.
+
+ This configuration will default to settings used for training efficientnet-b0
+ on a v3-8 TPU on ImageNet.
+
+ Attributes:
+ name: The name of the model. Defaults to 'EfficientNet'.
+ num_classes: The number of classes in the model.
+ model_params: A dictionary that represents the parameters of the
+ EfficientNet model. These will be passed in to the "from_name" function.
+ loss: The configuration for loss. Defaults to a categorical cross entropy
+ implementation.
+ optimizer: The configuration for optimizations. Defaults to an RMSProp
+ configuration.
+ learning_rate: The configuration for learning rate. Defaults to an
+ exponential configuration.
+ """
+ name: str = 'EfficientNet'
+ num_classes: int = 1000
+ model_params: base_config.Config = dataclasses.field(
+ default_factory=lambda: {
+ 'model_name': 'efficientnet-b0',
+ 'model_weights_path': '',
+ 'weights_format': 'saved_model',
+ 'overrides': {
+ 'batch_norm': 'default',
+ 'rescale_input': True,
+ 'num_classes': 1000,
+ 'activation': 'swish',
+ 'dtype': 'float32',
+ }
+ })
+ loss: base_configs.LossConfig = dataclasses.field(
+ default_factory=lambda: base_configs.LossConfig( # pylint: disable=g-long-lambda
+ name='categorical_crossentropy', label_smoothing=0.1
+ )
+ )
+ optimizer: base_configs.OptimizerConfig = dataclasses.field(
+ default_factory=lambda: base_configs.OptimizerConfig( # pylint: disable=g-long-lambda
+ name='rmsprop',
+ decay=0.9,
+ epsilon=0.001,
+ momentum=0.9,
+ moving_average_decay=None,
+ )
+ )
+ learning_rate: base_configs.LearningRateConfig = dataclasses.field(
+ default_factory=lambda: base_configs.LearningRateConfig( # pylint: disable=g-long-lambda
+ name='exponential',
+ initial_lr=0.008,
+ decay_epochs=2.4,
+ decay_rate=0.97,
+ warmup_epochs=5,
+ scale_by_batch_size=1.0 / 128.0,
+ staircase=True,
+ )
+ )
diff --git a/modeling/official/legacy/image_classification/efficientnet/efficientnet_model.py b/modeling/official/legacy/image_classification/efficientnet/efficientnet_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d060df016cd88fb5d162f9a97d3fa4ba797fea9
--- /dev/null
+++ b/modeling/official/legacy/image_classification/efficientnet/efficientnet_model.py
@@ -0,0 +1,495 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contains definitions for EfficientNet model.
+
+[1] Mingxing Tan, Quoc V. Le
+ EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks.
+ ICML'19, https://arxiv.org/abs/1905.11946
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import dataclasses
+import math
+from typing import Any, Dict, Optional, Text, Tuple
+
+from absl import logging
+import tensorflow as tf, tf_keras
+from official.legacy.image_classification import preprocessing
+from official.legacy.image_classification.efficientnet import common_modules
+from official.modeling import tf_utils
+from official.modeling.hyperparams import base_config
+
+
+@dataclasses.dataclass
+class BlockConfig(base_config.Config):
+ """Config for a single MB Conv Block."""
+ input_filters: int = 0
+ output_filters: int = 0
+ kernel_size: int = 3
+ num_repeat: int = 1
+ expand_ratio: int = 1
+ strides: Tuple[int, int] = (1, 1)
+ se_ratio: Optional[float] = None
+ id_skip: bool = True
+ fused_conv: bool = False
+ conv_type: str = 'depthwise'
+
+
+@dataclasses.dataclass
+class ModelConfig(base_config.Config):
+ """Default Config for Efficientnet-B0."""
+ width_coefficient: float = 1.0
+ depth_coefficient: float = 1.0
+ resolution: int = 224
+ dropout_rate: float = 0.2
+ blocks: Tuple[BlockConfig, ...] = (
+ # (input_filters, output_filters, kernel_size, num_repeat,
+ # expand_ratio, strides, se_ratio)
+ # pylint: disable=bad-whitespace
+ BlockConfig.from_args(32, 16, 3, 1, 1, (1, 1), 0.25),
+ BlockConfig.from_args(16, 24, 3, 2, 6, (2, 2), 0.25),
+ BlockConfig.from_args(24, 40, 5, 2, 6, (2, 2), 0.25),
+ BlockConfig.from_args(40, 80, 3, 3, 6, (2, 2), 0.25),
+ BlockConfig.from_args(80, 112, 5, 3, 6, (1, 1), 0.25),
+ BlockConfig.from_args(112, 192, 5, 4, 6, (2, 2), 0.25),
+ BlockConfig.from_args(192, 320, 3, 1, 6, (1, 1), 0.25),
+ # pylint: enable=bad-whitespace
+ )
+ stem_base_filters: int = 32
+ top_base_filters: int = 1280
+ activation: str = 'simple_swish'
+ batch_norm: str = 'default'
+ bn_momentum: float = 0.99
+ bn_epsilon: float = 1e-3
+ # While the original implementation used a weight decay of 1e-5,
+ # tf.nn.l2_loss divides it by 2, so we halve this to compensate in Keras
+ weight_decay: float = 5e-6
+ drop_connect_rate: float = 0.2
+ depth_divisor: int = 8
+ min_depth: Optional[int] = None
+ use_se: bool = True
+ input_channels: int = 3
+ num_classes: int = 1000
+ model_name: str = 'efficientnet'
+ rescale_input: bool = True
+ data_format: str = 'channels_last'
+ dtype: str = 'float32'
+
+
+MODEL_CONFIGS = {
+ # (width, depth, resolution, dropout)
+ 'efficientnet-b0': ModelConfig.from_args(1.0, 1.0, 224, 0.2),
+ 'efficientnet-b1': ModelConfig.from_args(1.0, 1.1, 240, 0.2),
+ 'efficientnet-b2': ModelConfig.from_args(1.1, 1.2, 260, 0.3),
+ 'efficientnet-b3': ModelConfig.from_args(1.2, 1.4, 300, 0.3),
+ 'efficientnet-b4': ModelConfig.from_args(1.4, 1.8, 380, 0.4),
+ 'efficientnet-b5': ModelConfig.from_args(1.6, 2.2, 456, 0.4),
+ 'efficientnet-b6': ModelConfig.from_args(1.8, 2.6, 528, 0.5),
+ 'efficientnet-b7': ModelConfig.from_args(2.0, 3.1, 600, 0.5),
+ 'efficientnet-b8': ModelConfig.from_args(2.2, 3.6, 672, 0.5),
+ 'efficientnet-l2': ModelConfig.from_args(4.3, 5.3, 800, 0.5),
+}
+
+CONV_KERNEL_INITIALIZER = {
+ 'class_name': 'VarianceScaling',
+ 'config': {
+ 'scale': 2.0,
+ 'mode': 'fan_out',
+ # Note: this is a truncated normal distribution
+ 'distribution': 'normal'
+ }
+}
+
+DENSE_KERNEL_INITIALIZER = {
+ 'class_name': 'VarianceScaling',
+ 'config': {
+ 'scale': 1 / 3.0,
+ 'mode': 'fan_out',
+ 'distribution': 'uniform'
+ }
+}
+
+
+def round_filters(filters: int, config: ModelConfig) -> int:
+ """Round number of filters based on width coefficient."""
+ width_coefficient = config.width_coefficient
+ min_depth = config.min_depth
+ divisor = config.depth_divisor
+ orig_filters = filters
+
+ if not width_coefficient:
+ return filters
+
+ filters *= width_coefficient
+ min_depth = min_depth or divisor
+ new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
+ # Make sure that round down does not go down by more than 10%.
+ if new_filters < 0.9 * filters:
+ new_filters += divisor
+ logging.info('round_filter input=%s output=%s', orig_filters, new_filters)
+ return int(new_filters)
+
+
+def round_repeats(repeats: int, depth_coefficient: float) -> int:
+ """Round number of repeats based on depth coefficient."""
+ return int(math.ceil(depth_coefficient * repeats))
+
+
+def conv2d_block(inputs: tf.Tensor,
+ conv_filters: Optional[int],
+ config: ModelConfig,
+ kernel_size: Any = (1, 1),
+ strides: Any = (1, 1),
+ use_batch_norm: bool = True,
+ use_bias: bool = False,
+ activation: Optional[Any] = None,
+ depthwise: bool = False,
+ name: Optional[Text] = None):
+ """A conv2d followed by batch norm and an activation."""
+ batch_norm = common_modules.get_batch_norm(config.batch_norm)
+ bn_momentum = config.bn_momentum
+ bn_epsilon = config.bn_epsilon
+ data_format = tf_keras.backend.image_data_format()
+ weight_decay = config.weight_decay
+
+ name = name or ''
+
+ # Collect args based on what kind of conv2d block is desired
+ init_kwargs = {
+ 'kernel_size': kernel_size,
+ 'strides': strides,
+ 'use_bias': use_bias,
+ 'padding': 'same',
+ 'name': name + '_conv2d',
+ 'kernel_regularizer': tf_keras.regularizers.l2(weight_decay),
+ 'bias_regularizer': tf_keras.regularizers.l2(weight_decay),
+ }
+
+ if depthwise:
+ conv2d = tf_keras.layers.DepthwiseConv2D
+ init_kwargs.update({'depthwise_initializer': CONV_KERNEL_INITIALIZER})
+ else:
+ conv2d = tf_keras.layers.Conv2D
+ init_kwargs.update({
+ 'filters': conv_filters,
+ 'kernel_initializer': CONV_KERNEL_INITIALIZER
+ })
+
+ x = conv2d(**init_kwargs)(inputs)
+
+ if use_batch_norm:
+ bn_axis = 1 if data_format == 'channels_first' else -1
+ x = batch_norm(
+ axis=bn_axis,
+ momentum=bn_momentum,
+ epsilon=bn_epsilon,
+ name=name + '_bn')(
+ x)
+
+ if activation is not None:
+ x = tf_keras.layers.Activation(activation, name=name + '_activation')(x)
+ return x
+
+
+def mb_conv_block(inputs: tf.Tensor,
+ block: BlockConfig,
+ config: ModelConfig,
+ prefix: Optional[Text] = None):
+ """Mobile Inverted Residual Bottleneck.
+
+ Args:
+ inputs: the Keras input to the block
+ block: BlockConfig, arguments to create a Block
+ config: ModelConfig, a set of model parameters
+ prefix: prefix for naming all layers
+
+ Returns:
+ the output of the block
+ """
+ use_se = config.use_se
+ activation = tf_utils.get_activation(config.activation)
+ drop_connect_rate = config.drop_connect_rate
+ data_format = tf_keras.backend.image_data_format()
+ use_depthwise = block.conv_type != 'no_depthwise'
+ prefix = prefix or ''
+
+ filters = block.input_filters * block.expand_ratio
+
+ x = inputs
+
+ if block.fused_conv:
+ # If we use fused mbconv, skip expansion and use regular conv.
+ x = conv2d_block(
+ x,
+ filters,
+ config,
+ kernel_size=block.kernel_size,
+ strides=block.strides,
+ activation=activation,
+ name=prefix + 'fused')
+ else:
+ if block.expand_ratio != 1:
+ # Expansion phase
+ kernel_size = (1, 1) if use_depthwise else (3, 3)
+ x = conv2d_block(
+ x,
+ filters,
+ config,
+ kernel_size=kernel_size,
+ activation=activation,
+ name=prefix + 'expand')
+
+ # Depthwise Convolution
+ if use_depthwise:
+ x = conv2d_block(
+ x,
+ conv_filters=None,
+ config=config,
+ kernel_size=block.kernel_size,
+ strides=block.strides,
+ activation=activation,
+ depthwise=True,
+ name=prefix + 'depthwise')
+
+ # Squeeze and Excitation phase
+ if use_se:
+ assert block.se_ratio is not None
+ assert 0 < block.se_ratio <= 1
+ num_reduced_filters = max(1, int(block.input_filters * block.se_ratio))
+
+ if data_format == 'channels_first':
+ se_shape = (filters, 1, 1)
+ else:
+ se_shape = (1, 1, filters)
+
+ se = tf_keras.layers.GlobalAveragePooling2D(name=prefix + 'se_squeeze')(x)
+ se = tf_keras.layers.Reshape(se_shape, name=prefix + 'se_reshape')(se)
+
+ se = conv2d_block(
+ se,
+ num_reduced_filters,
+ config,
+ use_bias=True,
+ use_batch_norm=False,
+ activation=activation,
+ name=prefix + 'se_reduce')
+ se = conv2d_block(
+ se,
+ filters,
+ config,
+ use_bias=True,
+ use_batch_norm=False,
+ activation='sigmoid',
+ name=prefix + 'se_expand')
+ x = tf_keras.layers.multiply([x, se], name=prefix + 'se_excite')
+
+ # Output phase
+ x = conv2d_block(
+ x, block.output_filters, config, activation=None, name=prefix + 'project')
+
+ # Add identity so that quantization-aware training can insert quantization
+ # ops correctly.
+ x = tf_keras.layers.Activation(
+ tf_utils.get_activation('identity'), name=prefix + 'id')(
+ x)
+
+ if (block.id_skip and all(s == 1 for s in block.strides) and
+ block.input_filters == block.output_filters):
+ if drop_connect_rate and drop_connect_rate > 0:
+ # Apply dropconnect
+ # The only difference between dropout and dropconnect in TF is scaling by
+ # drop_connect_rate during training. See:
+ # https://github.com/keras-team/keras/pull/9898#issuecomment-380577612
+ x = tf_keras.layers.Dropout(
+ drop_connect_rate, noise_shape=(None, 1, 1, 1), name=prefix + 'drop')(
+ x)
+
+ x = tf_keras.layers.add([x, inputs], name=prefix + 'add')
+
+ return x
+
+
+def efficientnet(image_input: tf_keras.layers.Input, config: ModelConfig): # pytype: disable=invalid-annotation # typed-keras
+ """Creates an EfficientNet graph given the model parameters.
+
+ This function is wrapped by the `EfficientNet` class to make a tf_keras.Model.
+
+ Args:
+ image_input: the input batch of images
+ config: the model config
+
+ Returns:
+ the output of efficientnet
+ """
+ depth_coefficient = config.depth_coefficient
+ blocks = config.blocks
+ stem_base_filters = config.stem_base_filters
+ top_base_filters = config.top_base_filters
+ activation = tf_utils.get_activation(config.activation)
+ dropout_rate = config.dropout_rate
+ drop_connect_rate = config.drop_connect_rate
+ num_classes = config.num_classes
+ input_channels = config.input_channels
+ rescale_input = config.rescale_input
+ data_format = tf_keras.backend.image_data_format()
+ dtype = config.dtype
+ weight_decay = config.weight_decay
+
+ x = image_input
+ if data_format == 'channels_first':
+ # Happens on GPU/TPU if available.
+ x = tf_keras.layers.Permute((3, 1, 2))(x)
+ if rescale_input:
+ x = preprocessing.normalize_images(
+ x, num_channels=input_channels, dtype=dtype, data_format=data_format)
+
+ # Build stem
+ x = conv2d_block(
+ x,
+ round_filters(stem_base_filters, config),
+ config,
+ kernel_size=[3, 3],
+ strides=[2, 2],
+ activation=activation,
+ name='stem')
+
+ # Build blocks
+ num_blocks_total = sum(
+ round_repeats(block.num_repeat, depth_coefficient) for block in blocks)
+ block_num = 0
+
+ for stack_idx, block in enumerate(blocks):
+ assert block.num_repeat > 0
+ # Update block input and output filters based on depth multiplier
+ block = block.replace(
+ input_filters=round_filters(block.input_filters, config),
+ output_filters=round_filters(block.output_filters, config),
+ num_repeat=round_repeats(block.num_repeat, depth_coefficient))
+
+ # The first block needs to take care of stride and filter size increase
+ drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
+ config = config.replace(drop_connect_rate=drop_rate)
+ block_prefix = 'stack_{}/block_0/'.format(stack_idx)
+ x = mb_conv_block(x, block, config, block_prefix)
+ block_num += 1
+ if block.num_repeat > 1:
+ block = block.replace(input_filters=block.output_filters, strides=[1, 1])
+
+ for block_idx in range(block.num_repeat - 1):
+ drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
+ config = config.replace(drop_connect_rate=drop_rate)
+ block_prefix = 'stack_{}/block_{}/'.format(stack_idx, block_idx + 1)
+ x = mb_conv_block(x, block, config, prefix=block_prefix)
+ block_num += 1
+
+ # Build top
+ x = conv2d_block(
+ x,
+ round_filters(top_base_filters, config),
+ config,
+ activation=activation,
+ name='top')
+
+ # Build classifier
+ x = tf_keras.layers.GlobalAveragePooling2D(name='top_pool')(x)
+ if dropout_rate and dropout_rate > 0:
+ x = tf_keras.layers.Dropout(dropout_rate, name='top_dropout')(x)
+ x = tf_keras.layers.Dense(
+ num_classes,
+ kernel_initializer=DENSE_KERNEL_INITIALIZER,
+ kernel_regularizer=tf_keras.regularizers.l2(weight_decay),
+ bias_regularizer=tf_keras.regularizers.l2(weight_decay),
+ name='logits')(
+ x)
+ x = tf_keras.layers.Activation('softmax', name='probs')(x)
+
+ return x
+
+
+class EfficientNet(tf_keras.Model):
+ """Wrapper class for an EfficientNet Keras model.
+
+ Contains helper methods to build, manage, and save metadata about the model.
+ """
+
+ def __init__(self,
+ config: Optional[ModelConfig] = None,
+ overrides: Optional[Dict[Text, Any]] = None):
+ """Create an EfficientNet model.
+
+ Args:
+ config: (optional) the main model parameters to create the model
+ overrides: (optional) a dict containing keys that can override config
+ """
+ overrides = overrides or {}
+ config = config or ModelConfig()
+
+ self.config = config.replace(**overrides)
+
+ input_channels = self.config.input_channels
+ model_name = self.config.model_name
+ input_shape = (None, None, input_channels) # Should handle any size image
+ image_input = tf_keras.layers.Input(shape=input_shape)
+
+ output = efficientnet(image_input, self.config)
+
+ # Cast to float32 in case we have a different model dtype
+ output = tf.cast(output, tf.float32)
+
+ logging.info('Building model %s with params %s', model_name, self.config)
+
+ super(EfficientNet, self).__init__(
+ inputs=image_input, outputs=output, name=model_name)
+
+ @classmethod
+ def from_name(cls,
+ model_name: Text,
+ model_weights_path: Optional[Text] = None,
+ weights_format: Text = 'saved_model',
+ overrides: Optional[Dict[Text, Any]] = None):
+ """Construct an EfficientNet model from a predefined model name.
+
+ E.g., `EfficientNet.from_name('efficientnet-b0')`.
+
+ Args:
+ model_name: the predefined model name
+ model_weights_path: the path to the weights (h5 file or saved model dir)
+ weights_format: the model weights format. One of 'saved_model', 'h5', or
+ 'checkpoint'.
+ overrides: (optional) a dict containing keys that can override config
+
+ Returns:
+ A constructed EfficientNet instance.
+ """
+ model_configs = dict(MODEL_CONFIGS)
+ overrides = dict(overrides) if overrides else {}
+
+ # One can define their own custom models if necessary
+ model_configs.update(overrides.pop('model_config', {}))
+
+ if model_name not in model_configs:
+ raise ValueError('Unknown model name {}'.format(model_name))
+
+ config = model_configs[model_name]
+
+ model = cls(config=config, overrides=overrides)
+
+ if model_weights_path:
+ common_modules.load_weights(
+ model, model_weights_path, weights_format=weights_format)
+
+ return model
diff --git a/modeling/official/legacy/image_classification/efficientnet/tfhub_export.py b/modeling/official/legacy/image_classification/efficientnet/tfhub_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..38f8c6b73787a2d14356e7a0457482fd29755e1f
--- /dev/null
+++ b/modeling/official/legacy/image_classification/efficientnet/tfhub_export.py
@@ -0,0 +1,67 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A script to export TF-Hub SavedModel."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import app
+from absl import flags
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.image_classification.efficientnet import efficientnet_model
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("model_name", None, "EfficientNet model name.")
+flags.DEFINE_string("model_path", None, "File path to TF model checkpoint.")
+flags.DEFINE_string("export_path", None,
+ "TF-Hub SavedModel destination path to export.")
+
+
+def export_tfhub(model_path, hub_destination, model_name):
+ """Restores a tf_keras.Model and saves for TF-Hub."""
+ model_configs = dict(efficientnet_model.MODEL_CONFIGS)
+ config = model_configs[model_name]
+
+ image_input = tf_keras.layers.Input(
+ shape=(None, None, 3), name="image_input", dtype=tf.float32)
+ x = image_input * 255.0
+ outputs = efficientnet_model.efficientnet(x, config)
+ hub_model = tf_keras.Model(image_input, outputs)
+ ckpt = tf.train.Checkpoint(model=hub_model)
+ ckpt.restore(model_path).assert_existing_objects_matched()
+ hub_model.save(
+ os.path.join(hub_destination, "classification"), include_optimizer=False)
+
+ feature_vector_output = hub_model.get_layer(name="top_pool").get_output_at(0)
+ hub_model2 = tf_keras.Model(image_input, feature_vector_output)
+ hub_model2.save(
+ os.path.join(hub_destination, "feature-vector"), include_optimizer=False)
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise app.UsageError("Too many command-line arguments.")
+
+ export_tfhub(FLAGS.model_path, FLAGS.export_path, FLAGS.model_name)
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/modeling/official/legacy/image_classification/learning_rate.py b/modeling/official/legacy/image_classification/learning_rate.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d5a06fd199f64bf983472012bb8bc579a211463
--- /dev/null
+++ b/modeling/official/legacy/image_classification/learning_rate.py
@@ -0,0 +1,116 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Learning rate utilities for vision tasks."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from typing import Any, Mapping, Optional
+
+import numpy as np
+import tensorflow as tf, tf_keras
+
+BASE_LEARNING_RATE = 0.1
+
+
+class WarmupDecaySchedule(tf_keras.optimizers.schedules.LearningRateSchedule):
+ """A wrapper for LearningRateSchedule that includes warmup steps."""
+
+ def __init__(self,
+ lr_schedule: tf_keras.optimizers.schedules.LearningRateSchedule,
+ warmup_steps: int,
+ warmup_lr: Optional[float] = None):
+ """Add warmup decay to a learning rate schedule.
+
+ Args:
+ lr_schedule: base learning rate scheduler
+ warmup_steps: number of warmup steps
+ warmup_lr: an optional field for the final warmup learning rate. This
+ should be provided if the base `lr_schedule` does not contain this
+ field.
+ """
+ super(WarmupDecaySchedule, self).__init__()
+ self._lr_schedule = lr_schedule
+ self._warmup_steps = warmup_steps
+ self._warmup_lr = warmup_lr
+
+ def __call__(self, step: int):
+ lr = self._lr_schedule(step)
+ if self._warmup_steps:
+ if self._warmup_lr is not None:
+ initial_learning_rate = tf.convert_to_tensor(
+ self._warmup_lr, name="initial_learning_rate")
+ else:
+ initial_learning_rate = tf.convert_to_tensor(
+ self._lr_schedule.initial_learning_rate,
+ name="initial_learning_rate")
+ dtype = initial_learning_rate.dtype
+ global_step_recomp = tf.cast(step, dtype)
+ warmup_steps = tf.cast(self._warmup_steps, dtype)
+ warmup_lr = initial_learning_rate * global_step_recomp / warmup_steps
+ lr = tf.cond(global_step_recomp < warmup_steps, lambda: warmup_lr,
+ lambda: lr)
+ return lr
+
+ def get_config(self) -> Mapping[str, Any]:
+ config = self._lr_schedule.get_config()
+ config.update({
+ "warmup_steps": self._warmup_steps,
+ "warmup_lr": self._warmup_lr,
+ })
+ return config
+
+
+class CosineDecayWithWarmup(tf_keras.optimizers.schedules.LearningRateSchedule):
+ """Class to generate learning rate tensor."""
+
+ def __init__(self, batch_size: int, total_steps: int, warmup_steps: int):
+ """Creates the cosine learning rate tensor with linear warmup.
+
+ Args:
+ batch_size: The training batch size used in the experiment.
+ total_steps: Total training steps.
+ warmup_steps: Steps for the warm up period.
+ """
+ super(CosineDecayWithWarmup, self).__init__()
+ base_lr_batch_size = 256
+ self._total_steps = total_steps
+ self._init_learning_rate = BASE_LEARNING_RATE * batch_size / base_lr_batch_size
+ self._warmup_steps = warmup_steps
+
+ def __call__(self, global_step: int):
+ global_step = tf.cast(global_step, dtype=tf.float32)
+ warmup_steps = self._warmup_steps
+ init_lr = self._init_learning_rate
+ total_steps = self._total_steps
+
+ linear_warmup = global_step / warmup_steps * init_lr
+
+ cosine_learning_rate = init_lr * (tf.cos(np.pi *
+ (global_step - warmup_steps) /
+ (total_steps - warmup_steps)) +
+ 1.0) / 2.0
+
+ learning_rate = tf.where(global_step < warmup_steps, linear_warmup,
+ cosine_learning_rate)
+ return learning_rate
+
+ def get_config(self):
+ return {
+ "total_steps": self._total_steps,
+ "warmup_learning_rate": self._warmup_learning_rate,
+ "warmup_steps": self._warmup_steps,
+ "init_learning_rate": self._init_learning_rate,
+ }
diff --git a/modeling/official/legacy/image_classification/learning_rate_test.py b/modeling/official/legacy/image_classification/learning_rate_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fbb1e11681da13dbfc197b531f69189b2b838fb
--- /dev/null
+++ b/modeling/official/legacy/image_classification/learning_rate_test.py
@@ -0,0 +1,60 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for learning_rate."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.image_classification import learning_rate
+
+
+class LearningRateTests(tf.test.TestCase):
+
+ def test_warmup_decay(self):
+ """Basic computational test for warmup decay."""
+ initial_lr = 0.01
+ decay_steps = 100
+ decay_rate = 0.01
+ warmup_steps = 10
+
+ base_lr = tf_keras.optimizers.schedules.ExponentialDecay(
+ initial_learning_rate=initial_lr,
+ decay_steps=decay_steps,
+ decay_rate=decay_rate)
+ lr = learning_rate.WarmupDecaySchedule(
+ lr_schedule=base_lr, warmup_steps=warmup_steps)
+
+ for step in range(warmup_steps - 1):
+ config = lr.get_config()
+ self.assertEqual(config['warmup_steps'], warmup_steps)
+ self.assertAllClose(
+ self.evaluate(lr(step)), step / warmup_steps * initial_lr)
+
+ def test_cosine_decay_with_warmup(self):
+ """Basic computational test for cosine decay with warmup."""
+ expected_lrs = [0.0, 0.1, 0.05, 0.0]
+
+ lr = learning_rate.CosineDecayWithWarmup(
+ batch_size=256, total_steps=3, warmup_steps=1)
+
+ for step in [0, 1, 2, 3]:
+ self.assertAllClose(lr(step), expected_lrs[step])
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/legacy/image_classification/mnist_main.py b/modeling/official/legacy/image_classification/mnist_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..c62b6ab3f4237f128d63aa44c673343e77ae29e5
--- /dev/null
+++ b/modeling/official/legacy/image_classification/mnist_main.py
@@ -0,0 +1,176 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Runs a simple model on the MNIST dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+# Import libraries
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf, tf_keras
+import tensorflow_datasets as tfds
+from official.common import distribute_utils
+from official.legacy.image_classification.resnet import common
+from official.utils.flags import core as flags_core
+from official.utils.misc import model_helpers
+
+FLAGS = flags.FLAGS
+
+
+def build_model():
+ """Constructs the ML model used to predict handwritten digits."""
+
+ image = tf_keras.layers.Input(shape=(28, 28, 1))
+
+ y = tf_keras.layers.Conv2D(filters=32,
+ kernel_size=5,
+ padding='same',
+ activation='relu')(image)
+ y = tf_keras.layers.MaxPooling2D(pool_size=(2, 2),
+ strides=(2, 2),
+ padding='same')(y)
+ y = tf_keras.layers.Conv2D(filters=32,
+ kernel_size=5,
+ padding='same',
+ activation='relu')(y)
+ y = tf_keras.layers.MaxPooling2D(pool_size=(2, 2),
+ strides=(2, 2),
+ padding='same')(y)
+ y = tf_keras.layers.Flatten()(y)
+ y = tf_keras.layers.Dense(1024, activation='relu')(y)
+ y = tf_keras.layers.Dropout(0.4)(y)
+
+ probs = tf_keras.layers.Dense(10, activation='softmax')(y)
+
+ model = tf_keras.models.Model(image, probs, name='mnist')
+
+ return model
+
+
+@tfds.decode.make_decoder(output_dtype=tf.float32)
+def decode_image(example, feature):
+ """Convert image to float32 and normalize from [0, 255] to [0.0, 1.0]."""
+ return tf.cast(feature.decode_example(example), dtype=tf.float32) / 255
+
+
+def run(flags_obj, datasets_override=None, strategy_override=None):
+ """Run MNIST model training and eval loop using native Keras APIs.
+
+ Args:
+ flags_obj: An object containing parsed flag values.
+ datasets_override: A pair of `tf.data.Dataset` objects to train the model,
+ representing the train and test sets.
+ strategy_override: A `tf.distribute.Strategy` object to use for model.
+
+ Returns:
+ Dictionary of training and eval stats.
+ """
+ # Start TF profiler server.
+ tf.profiler.experimental.server.start(flags_obj.profiler_port)
+
+ strategy = strategy_override or distribute_utils.get_distribution_strategy(
+ distribution_strategy=flags_obj.distribution_strategy,
+ num_gpus=flags_obj.num_gpus,
+ tpu_address=flags_obj.tpu)
+
+ strategy_scope = distribute_utils.get_strategy_scope(strategy)
+
+ mnist = tfds.builder('mnist', data_dir=flags_obj.data_dir)
+ if flags_obj.download:
+ mnist.download_and_prepare()
+
+ mnist_train, mnist_test = datasets_override or mnist.as_dataset(
+ split=['train', 'test'],
+ decoders={'image': decode_image()}, # pylint: disable=no-value-for-parameter
+ as_supervised=True)
+ train_input_dataset = mnist_train.cache().repeat().shuffle(
+ buffer_size=50000).batch(flags_obj.batch_size)
+ eval_input_dataset = mnist_test.cache().repeat().batch(flags_obj.batch_size)
+
+ with strategy_scope:
+ lr_schedule = tf_keras.optimizers.schedules.ExponentialDecay(
+ 0.05, decay_steps=100000, decay_rate=0.96)
+ optimizer = tf_keras.optimizers.SGD(learning_rate=lr_schedule)
+
+ model = build_model()
+ model.compile(
+ optimizer=optimizer,
+ loss='sparse_categorical_crossentropy',
+ metrics=['sparse_categorical_accuracy'])
+
+ num_train_examples = mnist.info.splits['train'].num_examples
+ train_steps = num_train_examples // flags_obj.batch_size
+ train_epochs = flags_obj.train_epochs
+
+ ckpt_full_path = os.path.join(flags_obj.model_dir, 'model.ckpt-{epoch:04d}')
+ callbacks = [
+ tf_keras.callbacks.ModelCheckpoint(
+ ckpt_full_path, save_weights_only=True),
+ tf_keras.callbacks.TensorBoard(log_dir=flags_obj.model_dir),
+ ]
+
+ num_eval_examples = mnist.info.splits['test'].num_examples
+ num_eval_steps = num_eval_examples // flags_obj.batch_size
+
+ history = model.fit(
+ train_input_dataset,
+ epochs=train_epochs,
+ steps_per_epoch=train_steps,
+ callbacks=callbacks,
+ validation_steps=num_eval_steps,
+ validation_data=eval_input_dataset,
+ validation_freq=flags_obj.epochs_between_evals)
+
+ export_path = os.path.join(flags_obj.model_dir, 'saved_model')
+ model.save(export_path, include_optimizer=False)
+
+ eval_output = model.evaluate(
+ eval_input_dataset, steps=num_eval_steps, verbose=2)
+
+ stats = common.build_stats(history, eval_output, callbacks)
+ return stats
+
+
+def define_mnist_flags():
+ """Define command line flags for MNIST model."""
+ flags_core.define_base(
+ clean=True,
+ num_gpu=True,
+ train_epochs=True,
+ epochs_between_evals=True,
+ distribution_strategy=True)
+ flags_core.define_device()
+ flags_core.define_distribution()
+ flags.DEFINE_bool('download', True,
+ 'Whether to download data to `--data_dir`.')
+ flags.DEFINE_integer('profiler_port', 9012,
+ 'Port to start profiler server on.')
+ FLAGS.set_default('batch_size', 1024)
+
+
+def main(_):
+ model_helpers.apply_clean(FLAGS)
+ stats = run(flags.FLAGS)
+ logging.info('Run stats:\n%s', stats)
+
+
+if __name__ == '__main__':
+ logging.set_verbosity(logging.INFO)
+ define_mnist_flags()
+ app.run(main)
diff --git a/modeling/official/legacy/image_classification/mnist_test.py b/modeling/official/legacy/image_classification/mnist_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..495d52b3a2bd7d9cb6ecad787b615e7394f3b460
--- /dev/null
+++ b/modeling/official/legacy/image_classification/mnist_test.py
@@ -0,0 +1,89 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test the Keras MNIST model on GPU."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+from absl.testing import parameterized
+import tensorflow as tf, tf_keras
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.legacy.image_classification import mnist_main
+from official.utils.testing import integration
+
+
+mnist_main.define_mnist_flags()
+
+
+def eager_strategy_combinations():
+ return combinations.combine(
+ distribution=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.cloud_tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ ],)
+
+
+class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
+ """Unit tests for sample Keras MNIST model."""
+ _tempdir = None
+
+ @classmethod
+ def setUpClass(cls): # pylint: disable=invalid-name
+ super(KerasMnistTest, cls).setUpClass()
+
+ def tearDown(self):
+ super(KerasMnistTest, self).tearDown()
+ tf.io.gfile.rmtree(self.get_temp_dir())
+
+ @combinations.generate(eager_strategy_combinations())
+ def test_end_to_end(self, distribution):
+ """Test Keras MNIST model with `strategy`."""
+
+ extra_flags = [
+ "-train_epochs",
+ "1",
+ # Let TFDS find the metadata folder automatically
+ "--data_dir="
+ ]
+
+ dummy_data = (
+ tf.ones(shape=(10, 28, 28, 1), dtype=tf.int32),
+ tf.range(10),
+ )
+ datasets = (
+ tf.data.Dataset.from_tensor_slices(dummy_data),
+ tf.data.Dataset.from_tensor_slices(dummy_data),
+ )
+
+ run = functools.partial(
+ mnist_main.run,
+ datasets_override=datasets,
+ strategy_override=distribution)
+
+ integration.run_synthetic(
+ main=run,
+ synth=False,
+ tmp_root=self.create_tempdir().full_path,
+ extra_flags=extra_flags)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/legacy/image_classification/optimizer_factory.py b/modeling/official/legacy/image_classification/optimizer_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..e57e398a03d093f4b4b25638640eeddc10b45837
--- /dev/null
+++ b/modeling/official/legacy/image_classification/optimizer_factory.py
@@ -0,0 +1,335 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Optimizer factory for vision tasks."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from typing import Any, Dict, Optional, Text, Union
+
+from absl import logging
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from official.legacy.image_classification import learning_rate
+from official.legacy.image_classification.configs import base_configs
+from official.modeling import optimization
+from official.modeling.optimization import legacy_adamw
+
+# pylint: disable=protected-access
+
+FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32, np.float64]
+
+
+class Lookahead(tf_keras.optimizers.legacy.Optimizer):
+ """This class allows to extend optimizers with the lookahead mechanism.
+
+ The mechanism is proposed by Michael R. Zhang et.al in the paper [Lookahead
+ Optimizer: k steps forward, 1 step back] (https://arxiv.org/abs/1907.08610v1).
+ The optimizer iteratively updates two sets of weights: the search directions
+ for weights are chosen by the inner optimizer, while the "slow weights" are
+ updated each `k` steps based on the directions of the "fast weights" and the
+ two sets of weights are synchronized. This method improves the learning
+ stability and lowers the variance of its inner optimizer.
+
+ Example of usage:
+
+ ```python
+ opt = tf_keras.optimizers.SGD(learning_rate) opt =
+ tfa.optimizers.Lookahead(opt)
+ ```
+ """
+
+ def __init__(
+ self,
+ optimizer: tf_keras.optimizers.Optimizer,
+ sync_period: int = 6,
+ slow_step_size: FloatTensorLike = 0.5,
+ name: str = 'Lookahead',
+ **kwargs,
+ ):
+ """Wrap optimizer with the lookahead mechanism.
+
+ Args:
+ optimizer: The original optimizer that will be used to compute and apply
+ the gradients.
+ sync_period: An integer. The synchronization period of lookahead. Enable
+ lookahead mechanism by setting it with a positive value.
+ slow_step_size: A floating point value. The ratio for updating the slow
+ weights.
+ name: Optional name for the operations created when applying gradients.
+ Defaults to "Lookahead".
+ **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
+ `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip
+ gradients by value, `decay` is included for backward compatibility to
+ allow time inverse decay of learning rate. `lr` is included for backward
+ compatibility, recommended to use `learning_rate` instead.
+ """
+ super().__init__(name, **kwargs)
+
+ if isinstance(optimizer, str):
+ optimizer = tf_keras.optimizers.get(optimizer)
+ if not isinstance(
+ optimizer,
+ (tf_keras.optimizers.Optimizer, tf_keras.optimizers.legacy.Optimizer),
+ ):
+ raise TypeError(
+ 'optimizer is not an object of tf_keras.optimizers.Optimizer'
+ )
+
+ self._optimizer = optimizer
+ self._set_hyper('sync_period', sync_period)
+ self._set_hyper('slow_step_size', slow_step_size)
+ self._initialized = False
+ self._track_trackable(self._optimizer, 'lh_base_optimizer')
+
+ def _create_slots(self, var_list):
+ self._optimizer._create_slots(var_list=var_list) # pylint: disable=protected-access
+ for var in var_list:
+ self.add_slot(var, 'slow', initializer=var)
+
+ def _create_hypers(self):
+ self._optimizer._create_hypers() # pylint: disable=protected-access
+
+ def _prepare(self, var_list):
+ return self._optimizer._prepare(var_list=var_list) # pylint: disable=protected-access
+
+ def apply_gradients(
+ self, grads_and_vars, name=None, skip_gradients_aggregation=None, **kwargs
+ ):
+ self._optimizer._iterations = self.iterations # pylint: disable=protected-access
+ return super().apply_gradients(grads_and_vars, name, **kwargs)
+
+ def _look_ahead_op(self, var):
+ var_dtype = var.dtype.base_dtype
+ slow_var = self.get_slot(var, 'slow')
+ local_step = tf.cast(self.iterations + 1, tf.dtypes.int64)
+ sync_period = self._get_hyper('sync_period', tf.dtypes.int64)
+ slow_step_size = self._get_hyper('slow_step_size', var_dtype)
+ step_back = slow_var + slow_step_size * (var - slow_var)
+ sync_cond = tf.equal(
+ tf.math.floordiv(local_step, sync_period) * sync_period, local_step
+ )
+ with tf.control_dependencies([step_back]):
+ slow_update = slow_var.assign(
+ tf.where(sync_cond, step_back, slow_var),
+ use_locking=self._use_locking,
+ )
+ var_update = var.assign(
+ tf.where(sync_cond, step_back, var), use_locking=self._use_locking
+ )
+ return tf.group(slow_update, var_update)
+
+ @property
+ def weights(self):
+ return self._weights + self._optimizer.weights
+
+ def _resource_apply_dense(self, grad, var):
+ train_op = self._optimizer._resource_apply_dense(grad, var) # pylint: disable=protected-access
+ with tf.control_dependencies([train_op]):
+ look_ahead_op = self._look_ahead_op(var)
+ return tf.group(train_op, look_ahead_op)
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ train_op = self._optimizer._resource_apply_sparse( # pylint: disable=protected-access
+ grad, var, indices
+ )
+ with tf.control_dependencies([train_op]):
+ look_ahead_op = self._look_ahead_op(var)
+ return tf.group(train_op, look_ahead_op)
+
+ def get_config(self):
+ config = {
+ 'optimizer': tf_keras.optimizers.serialize(self._optimizer),
+ 'sync_period': self._serialize_hyperparameter('sync_period'),
+ 'slow_step_size': self._serialize_hyperparameter('slow_step_size'),
+ }
+ base_config = super().get_config()
+ return {**base_config, **config}
+
+ @property
+ def learning_rate(self):
+ return self._optimizer._get_hyper('learning_rate')
+
+ @learning_rate.setter
+ def learning_rate(self, value):
+ self._optimizer._set_hyper('learning_rate', value)
+
+ @property
+ def lr(self):
+ return self.learning_rate
+
+ @lr.setter
+ def lr(self, lr):
+ self.learning_rate = lr
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ optimizer = tf_keras.optimizers.deserialize(
+ config.pop('optimizer'), custom_objects=custom_objects
+ )
+ return cls(optimizer, **config)
+
+
+def build_optimizer(
+ optimizer_name: Text,
+ base_learning_rate: tf_keras.optimizers.schedules.LearningRateSchedule,
+ params: Dict[Text, Any],
+ model: Optional[tf_keras.Model] = None):
+ """Build the optimizer based on name.
+
+ Args:
+ optimizer_name: String representation of the optimizer name. Examples: sgd,
+ momentum, rmsprop.
+ base_learning_rate: `tf_keras.optimizers.schedules.LearningRateSchedule`
+ base learning rate.
+ params: String -> Any dictionary representing the optimizer params. This
+ should contain optimizer specific parameters such as `base_learning_rate`,
+ `decay`, etc.
+ model: The `tf_keras.Model`. This is used for the shadow copy if using
+ `ExponentialMovingAverage`.
+
+ Returns:
+ A tf_keras.optimizers.legacy.Optimizer.
+
+ Raises:
+ ValueError if the provided optimizer_name is not supported.
+
+ """
+ optimizer_name = optimizer_name.lower()
+ logging.info('Building %s optimizer with params %s', optimizer_name, params)
+
+ if optimizer_name == 'sgd':
+ logging.info('Using SGD optimizer')
+ nesterov = params.get('nesterov', False)
+ optimizer = tf_keras.optimizers.legacy.SGD(
+ learning_rate=base_learning_rate, nesterov=nesterov)
+ elif optimizer_name == 'momentum':
+ logging.info('Using momentum optimizer')
+ nesterov = params.get('nesterov', False)
+ optimizer = tf_keras.optimizers.legacy.SGD(
+ learning_rate=base_learning_rate,
+ momentum=params['momentum'],
+ nesterov=nesterov)
+ elif optimizer_name == 'rmsprop':
+ logging.info('Using RMSProp')
+ rho = params.get('decay', None) or params.get('rho', 0.9)
+ momentum = params.get('momentum', 0.9)
+ epsilon = params.get('epsilon', 1e-07)
+ optimizer = tf_keras.optimizers.legacy.RMSprop(
+ learning_rate=base_learning_rate,
+ rho=rho,
+ momentum=momentum,
+ epsilon=epsilon)
+ elif optimizer_name == 'adam':
+ logging.info('Using Adam')
+ beta_1 = params.get('beta_1', 0.9)
+ beta_2 = params.get('beta_2', 0.999)
+ epsilon = params.get('epsilon', 1e-07)
+ optimizer = tf_keras.optimizers.legacy.Adam(
+ learning_rate=base_learning_rate,
+ beta_1=beta_1,
+ beta_2=beta_2,
+ epsilon=epsilon)
+ elif optimizer_name == 'adamw':
+ logging.info('Using AdamW')
+ weight_decay = params.get('weight_decay', 0.01)
+ beta_1 = params.get('beta_1', 0.9)
+ beta_2 = params.get('beta_2', 0.999)
+ epsilon = params.get('epsilon', 1e-07)
+ optimizer = legacy_adamw.AdamWeightDecay(
+ learning_rate=base_learning_rate,
+ weight_decay_rate=weight_decay,
+ beta_1=beta_1,
+ beta_2=beta_2,
+ epsilon=epsilon,
+ )
+ else:
+ raise ValueError('Unknown optimizer %s' % optimizer_name)
+
+ if params.get('lookahead', None):
+ logging.info('Using lookahead optimizer.')
+ optimizer = Lookahead(optimizer)
+
+ # Moving average should be applied last, as it's applied at test time
+ moving_average_decay = params.get('moving_average_decay', 0.)
+ if moving_average_decay is not None and moving_average_decay > 0.:
+ if model is None:
+ raise ValueError(
+ '`model` must be provided if using `ExponentialMovingAverage`.')
+ logging.info('Including moving average decay.')
+ optimizer = optimization.ExponentialMovingAverage(
+ optimizer=optimizer, average_decay=moving_average_decay)
+ optimizer.shadow_copy(model)
+ return optimizer
+
+
+def build_learning_rate(params: base_configs.LearningRateConfig,
+ batch_size: Optional[int] = None,
+ train_epochs: Optional[int] = None,
+ train_steps: Optional[int] = None):
+ """Build the learning rate given the provided configuration."""
+ decay_type = params.name
+ base_lr = params.initial_lr
+ decay_rate = params.decay_rate
+ if params.decay_epochs is not None:
+ decay_steps = params.decay_epochs * train_steps
+ else:
+ decay_steps = 0
+ if params.warmup_epochs is not None:
+ warmup_steps = params.warmup_epochs * train_steps
+ else:
+ warmup_steps = 0
+
+ lr_multiplier = params.scale_by_batch_size
+
+ if lr_multiplier and lr_multiplier > 0:
+ # Scale the learning rate based on the batch size and a multiplier
+ base_lr *= lr_multiplier * batch_size
+ logging.info(
+ 'Scaling the learning rate based on the batch size '
+ 'multiplier. New base_lr: %f', base_lr)
+
+ if decay_type == 'exponential':
+ logging.info(
+ 'Using exponential learning rate with: '
+ 'initial_learning_rate: %f, decay_steps: %d, '
+ 'decay_rate: %f', base_lr, decay_steps, decay_rate)
+ lr = tf_keras.optimizers.schedules.ExponentialDecay(
+ initial_learning_rate=base_lr,
+ decay_steps=decay_steps,
+ decay_rate=decay_rate,
+ staircase=params.staircase)
+ elif decay_type == 'stepwise':
+ steps_per_epoch = params.examples_per_epoch // batch_size
+ boundaries = [boundary * steps_per_epoch for boundary in params.boundaries]
+ multipliers = [batch_size * multiplier for multiplier in params.multipliers]
+ logging.info(
+ 'Using stepwise learning rate. Parameters: '
+ 'boundaries: %s, values: %s', boundaries, multipliers)
+ lr = tf_keras.optimizers.schedules.PiecewiseConstantDecay(
+ boundaries=boundaries, values=multipliers)
+ elif decay_type == 'cosine_with_warmup':
+ lr = learning_rate.CosineDecayWithWarmup(
+ batch_size=batch_size,
+ total_steps=train_epochs * train_steps,
+ warmup_steps=warmup_steps)
+ if warmup_steps > 0:
+ if decay_type not in ['cosine_with_warmup']:
+ logging.info('Applying %d warmup steps to the learning rate',
+ warmup_steps)
+ lr = learning_rate.WarmupDecaySchedule(
+ lr, warmup_steps, warmup_lr=base_lr)
+ return lr
diff --git a/modeling/official/legacy/image_classification/optimizer_factory_test.py b/modeling/official/legacy/image_classification/optimizer_factory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f96d6da99360fad057bd0642c3e58bf743958e04
--- /dev/null
+++ b/modeling/official/legacy/image_classification/optimizer_factory_test.py
@@ -0,0 +1,120 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for optimizer_factory."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+import tensorflow as tf, tf_keras
+from official.legacy.image_classification import optimizer_factory
+from official.legacy.image_classification.configs import base_configs
+
+
+class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
+
+ def build_toy_model(self) -> tf_keras.Model:
+ """Creates a toy `tf.Keras.Model`."""
+ model = tf_keras.Sequential()
+ model.add(tf_keras.layers.Dense(1, input_shape=(1,)))
+ return model
+
+ @parameterized.named_parameters(
+ ('sgd', 'sgd', 0., False), ('momentum', 'momentum', 0., False),
+ ('rmsprop', 'rmsprop', 0., False), ('adam', 'adam', 0., False),
+ ('adamw', 'adamw', 0., False),
+ ('momentum_lookahead', 'momentum', 0., True),
+ ('sgd_ema', 'sgd', 0.999, False),
+ ('momentum_ema', 'momentum', 0.999, False),
+ ('rmsprop_ema', 'rmsprop', 0.999, False))
+ def test_optimizer(self, optimizer_name, moving_average_decay, lookahead):
+ """Smoke test to be sure no syntax errors."""
+ model = self.build_toy_model()
+ params = {
+ 'learning_rate': 0.001,
+ 'rho': 0.09,
+ 'momentum': 0.,
+ 'epsilon': 1e-07,
+ 'moving_average_decay': moving_average_decay,
+ 'lookahead': lookahead,
+ }
+ optimizer = optimizer_factory.build_optimizer(
+ optimizer_name=optimizer_name,
+ base_learning_rate=params['learning_rate'],
+ params=params,
+ model=model)
+ self.assertTrue(
+ issubclass(type(optimizer), tf_keras.optimizers.legacy.Optimizer)
+ )
+
+ def test_unknown_optimizer(self):
+ with self.assertRaises(ValueError):
+ optimizer_factory.build_optimizer(
+ optimizer_name='this_optimizer_does_not_exist',
+ base_learning_rate=None,
+ params=None)
+
+ def test_learning_rate_without_decay_or_warmups(self):
+ params = base_configs.LearningRateConfig(
+ name='exponential',
+ initial_lr=0.01,
+ decay_rate=0.01,
+ decay_epochs=None,
+ warmup_epochs=None,
+ scale_by_batch_size=0.01,
+ examples_per_epoch=1,
+ boundaries=[0],
+ multipliers=[0, 1])
+ batch_size = 1
+ train_steps = 1
+
+ lr = optimizer_factory.build_learning_rate(
+ params=params, batch_size=batch_size, train_steps=train_steps)
+ self.assertTrue(
+ issubclass(
+ type(lr), tf_keras.optimizers.schedules.LearningRateSchedule))
+
+ @parameterized.named_parameters(('exponential', 'exponential'),
+ ('cosine_with_warmup', 'cosine_with_warmup'))
+ def test_learning_rate_with_decay_and_warmup(self, lr_decay_type):
+ """Basic smoke test for syntax."""
+ params = base_configs.LearningRateConfig(
+ name=lr_decay_type,
+ initial_lr=0.01,
+ decay_rate=0.01,
+ decay_epochs=1,
+ warmup_epochs=1,
+ scale_by_batch_size=0.01,
+ examples_per_epoch=1,
+ boundaries=[0],
+ multipliers=[0, 1])
+ batch_size = 1
+ train_epochs = 1
+ train_steps = 1
+
+ lr = optimizer_factory.build_learning_rate(
+ params=params,
+ batch_size=batch_size,
+ train_epochs=train_epochs,
+ train_steps=train_steps)
+ self.assertTrue(
+ issubclass(
+ type(lr), tf_keras.optimizers.schedules.LearningRateSchedule))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/legacy/image_classification/preprocessing.py b/modeling/official/legacy/image_classification/preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc73cc93fa02a8ff21f51227bb1644f125b51067
--- /dev/null
+++ b/modeling/official/legacy/image_classification/preprocessing.py
@@ -0,0 +1,391 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Preprocessing functions for images."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from typing import List, Optional, Text, Tuple
+import tensorflow as tf, tf_keras
+from official.legacy.image_classification import augment
+
+
+# Calculated from the ImageNet training set
+MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
+STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
+
+IMAGE_SIZE = 224
+CROP_PADDING = 32
+
+
+def mean_image_subtraction(
+ image_bytes: tf.Tensor,
+ means: Tuple[float, ...],
+ num_channels: int = 3,
+ dtype: tf.dtypes.DType = tf.float32,
+) -> tf.Tensor:
+ """Subtracts the given means from each image channel.
+
+ For example:
+ means = [123.68, 116.779, 103.939]
+ image_bytes = mean_image_subtraction(image_bytes, means)
+
+ Note that the rank of `image` must be known.
+
+ Args:
+ image_bytes: a tensor of size [height, width, C].
+ means: a C-vector of values to subtract from each channel.
+ num_channels: number of color channels in the image that will be distorted.
+ dtype: the dtype to convert the images to. Set to `None` to skip conversion.
+
+ Returns:
+ the centered image.
+
+ Raises:
+ ValueError: If the rank of `image` is unknown, if `image` has a rank other
+ than three or if the number of channels in `image` doesn't match the
+ number of values in `means`.
+ """
+ if image_bytes.get_shape().ndims != 3:
+ raise ValueError('Input must be of size [height, width, C>0]')
+
+ if len(means) != num_channels:
+ raise ValueError('len(means) must match the number of channels')
+
+ # We have a 1-D tensor of means; convert to 3-D.
+ # Note(b/130245863): we explicitly call `broadcast` instead of simply
+ # expanding dimensions for better performance.
+ means = tf.broadcast_to(means, tf.shape(image_bytes))
+ if dtype is not None:
+ means = tf.cast(means, dtype=dtype)
+
+ return image_bytes - means
+
+
+def standardize_image(
+ image_bytes: tf.Tensor,
+ stddev: Tuple[float, ...],
+ num_channels: int = 3,
+ dtype: tf.dtypes.DType = tf.float32,
+) -> tf.Tensor:
+ """Divides the given stddev from each image channel.
+
+ For example:
+ stddev = [123.68, 116.779, 103.939]
+ image_bytes = standardize_image(image_bytes, stddev)
+
+ Note that the rank of `image` must be known.
+
+ Args:
+ image_bytes: a tensor of size [height, width, C].
+ stddev: a C-vector of values to divide from each channel.
+ num_channels: number of color channels in the image that will be distorted.
+ dtype: the dtype to convert the images to. Set to `None` to skip conversion.
+
+ Returns:
+ the centered image.
+
+ Raises:
+ ValueError: If the rank of `image` is unknown, if `image` has a rank other
+ than three or if the number of channels in `image` doesn't match the
+ number of values in `stddev`.
+ """
+ if image_bytes.get_shape().ndims != 3:
+ raise ValueError('Input must be of size [height, width, C>0]')
+
+ if len(stddev) != num_channels:
+ raise ValueError('len(stddev) must match the number of channels')
+
+ # We have a 1-D tensor of stddev; convert to 3-D.
+ # Note(b/130245863): we explicitly call `broadcast` instead of simply
+ # expanding dimensions for better performance.
+ stddev = tf.broadcast_to(stddev, tf.shape(image_bytes))
+ if dtype is not None:
+ stddev = tf.cast(stddev, dtype=dtype)
+
+ return image_bytes / stddev
+
+
+def normalize_images(features: tf.Tensor,
+ mean_rgb: Tuple[float, ...] = MEAN_RGB,
+ stddev_rgb: Tuple[float, ...] = STDDEV_RGB,
+ num_channels: int = 3,
+ dtype: tf.dtypes.DType = tf.float32,
+ data_format: Text = 'channels_last') -> tf.Tensor:
+ """Normalizes the input image channels with the given mean and stddev.
+
+ Args:
+ features: `Tensor` representing decoded images in float format.
+ mean_rgb: the mean of the channels to subtract.
+ stddev_rgb: the stddev of the channels to divide.
+ num_channels: the number of channels in the input image tensor.
+ dtype: the dtype to convert the images to. Set to `None` to skip conversion.
+ data_format: the format of the input image tensor
+ ['channels_first', 'channels_last'].
+
+ Returns:
+ A normalized image `Tensor`.
+ """
+ # TODO(allencwang) - figure out how to use mean_image_subtraction and
+ # standardize_image on batches of images and replace the following.
+ if data_format == 'channels_first':
+ stats_shape = [num_channels, 1, 1]
+ else:
+ stats_shape = [1, 1, num_channels]
+
+ if dtype is not None:
+ features = tf.image.convert_image_dtype(features, dtype=dtype)
+
+ if mean_rgb is not None:
+ mean_rgb = tf.constant(mean_rgb,
+ shape=stats_shape,
+ dtype=features.dtype)
+ mean_rgb = tf.broadcast_to(mean_rgb, tf.shape(features))
+ features = features - mean_rgb
+
+ if stddev_rgb is not None:
+ stddev_rgb = tf.constant(stddev_rgb,
+ shape=stats_shape,
+ dtype=features.dtype)
+ stddev_rgb = tf.broadcast_to(stddev_rgb, tf.shape(features))
+ features = features / stddev_rgb
+
+ return features
+
+
+def decode_and_center_crop(image_bytes: tf.Tensor,
+ image_size: int = IMAGE_SIZE,
+ crop_padding: int = CROP_PADDING) -> tf.Tensor:
+ """Crops to center of image with padding then scales image_size.
+
+ Args:
+ image_bytes: `Tensor` representing an image binary of arbitrary size.
+ image_size: image height/width dimension.
+ crop_padding: the padding size to use when centering the crop.
+
+ Returns:
+ A decoded and cropped image `Tensor`.
+ """
+ decoded = image_bytes.dtype != tf.string
+ shape = (tf.shape(image_bytes) if decoded
+ else tf.image.extract_jpeg_shape(image_bytes))
+ image_height = shape[0]
+ image_width = shape[1]
+
+ padded_center_crop_size = tf.cast(
+ ((image_size / (image_size + crop_padding)) *
+ tf.cast(tf.minimum(image_height, image_width), tf.float32)),
+ tf.int32)
+
+ offset_height = ((image_height - padded_center_crop_size) + 1) // 2
+ offset_width = ((image_width - padded_center_crop_size) + 1) // 2
+ crop_window = tf.stack([offset_height, offset_width,
+ padded_center_crop_size, padded_center_crop_size])
+ if decoded:
+ image = tf.image.crop_to_bounding_box(
+ image_bytes,
+ offset_height=offset_height,
+ offset_width=offset_width,
+ target_height=padded_center_crop_size,
+ target_width=padded_center_crop_size)
+ else:
+ image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
+
+ image = resize_image(image_bytes=image,
+ height=image_size,
+ width=image_size)
+
+ return image
+
+
+def decode_crop_and_flip(image_bytes: tf.Tensor) -> tf.Tensor:
+ """Crops an image to a random part of the image, then randomly flips.
+
+ Args:
+ image_bytes: `Tensor` representing an image binary of arbitrary size.
+
+ Returns:
+ A decoded and cropped image `Tensor`.
+
+ """
+ decoded = image_bytes.dtype != tf.string
+ bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
+ shape = (tf.shape(image_bytes) if decoded
+ else tf.image.extract_jpeg_shape(image_bytes))
+ sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
+ shape,
+ bounding_boxes=bbox,
+ min_object_covered=0.1,
+ aspect_ratio_range=[0.75, 1.33],
+ area_range=[0.05, 1.0],
+ max_attempts=100,
+ use_image_if_no_bounding_boxes=True)
+ bbox_begin, bbox_size, _ = sample_distorted_bounding_box
+
+ # Reassemble the bounding box in the format the crop op requires.
+ offset_height, offset_width, _ = tf.unstack(bbox_begin)
+ target_height, target_width, _ = tf.unstack(bbox_size)
+ crop_window = tf.stack([offset_height, offset_width,
+ target_height, target_width])
+ if decoded:
+ cropped = tf.image.crop_to_bounding_box(
+ image_bytes,
+ offset_height=offset_height,
+ offset_width=offset_width,
+ target_height=target_height,
+ target_width=target_width)
+ else:
+ cropped = tf.image.decode_and_crop_jpeg(image_bytes,
+ crop_window,
+ channels=3)
+
+ # Flip to add a little more random distortion in.
+ cropped = tf.image.random_flip_left_right(cropped)
+ return cropped
+
+
+def resize_image(image_bytes: tf.Tensor,
+ height: int = IMAGE_SIZE,
+ width: int = IMAGE_SIZE) -> tf.Tensor:
+ """Resizes an image to a given height and width.
+
+ Args:
+ image_bytes: `Tensor` representing an image binary of arbitrary size.
+ height: image height dimension.
+ width: image width dimension.
+
+ Returns:
+ A tensor containing the resized image.
+
+ """
+ print(height, width)
+ return tf.compat.v1.image.resize(
+ image_bytes,
+ tf.convert_to_tensor([height, width]),
+ method=tf.image.ResizeMethod.BILINEAR,
+ align_corners=False)
+
+
+def preprocess_for_eval(
+ image_bytes: tf.Tensor,
+ image_size: int = IMAGE_SIZE,
+ num_channels: int = 3,
+ mean_subtract: bool = False,
+ standardize: bool = False,
+ dtype: tf.dtypes.DType = tf.float32
+) -> tf.Tensor:
+ """Preprocesses the given image for evaluation.
+
+ Args:
+ image_bytes: `Tensor` representing an image binary of arbitrary size.
+ image_size: image height/width dimension.
+ num_channels: number of image input channels.
+ mean_subtract: whether or not to apply mean subtraction.
+ standardize: whether or not to apply standardization.
+ dtype: the dtype to convert the images to. Set to `None` to skip conversion.
+
+ Returns:
+ A preprocessed and normalized image `Tensor`.
+ """
+ images = decode_and_center_crop(image_bytes, image_size)
+ images = tf.reshape(images, [image_size, image_size, num_channels])
+
+ if mean_subtract:
+ images = mean_image_subtraction(image_bytes=images, means=MEAN_RGB)
+ if standardize:
+ images = standardize_image(image_bytes=images, stddev=STDDEV_RGB)
+ if dtype is not None:
+ images = tf.image.convert_image_dtype(images, dtype=dtype)
+
+ return images
+
+
+def load_eval_image(filename: Text, image_size: int = IMAGE_SIZE) -> tf.Tensor:
+ """Reads an image from the filesystem and applies image preprocessing.
+
+ Args:
+ filename: a filename path of an image.
+ image_size: image height/width dimension.
+
+ Returns:
+ A preprocessed and normalized image `Tensor`.
+ """
+ image_bytes = tf.io.read_file(filename)
+ image = preprocess_for_eval(image_bytes, image_size)
+
+ return image
+
+
+def build_eval_dataset(filenames: List[Text],
+ labels: Optional[List[int]] = None,
+ image_size: int = IMAGE_SIZE,
+ batch_size: int = 1) -> tf.Tensor:
+ """Builds a tf.data.Dataset from a list of filenames and labels.
+
+ Args:
+ filenames: a list of filename paths of images.
+ labels: a list of labels corresponding to each image.
+ image_size: image height/width dimension.
+ batch_size: the batch size used by the dataset
+
+ Returns:
+ A preprocessed and normalized image `Tensor`.
+ """
+ if labels is None:
+ labels = [0] * len(filenames)
+
+ filenames = tf.constant(filenames)
+ labels = tf.constant(labels)
+ dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
+
+ dataset = dataset.map(
+ lambda filename, label: (load_eval_image(filename, image_size), label))
+ dataset = dataset.batch(batch_size)
+
+ return dataset
+
+
+def preprocess_for_train(image_bytes: tf.Tensor,
+ image_size: int = IMAGE_SIZE,
+ augmenter: Optional[augment.ImageAugment] = None,
+ mean_subtract: bool = False,
+ standardize: bool = False,
+ dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor:
+ """Preprocesses the given image for training.
+
+ Args:
+ image_bytes: `Tensor` representing an image binary of
+ arbitrary size of dtype tf.uint8.
+ image_size: image height/width dimension.
+ augmenter: the image augmenter to apply.
+ mean_subtract: whether or not to apply mean subtraction.
+ standardize: whether or not to apply standardization.
+ dtype: the dtype to convert the images to. Set to `None` to skip conversion.
+
+ Returns:
+ A preprocessed and normalized image `Tensor`.
+ """
+ images = decode_crop_and_flip(image_bytes=image_bytes)
+ images = resize_image(images, height=image_size, width=image_size)
+ if augmenter is not None:
+ images = augmenter.distort(images)
+ if mean_subtract:
+ images = mean_image_subtraction(image_bytes=images, means=MEAN_RGB)
+ if standardize:
+ images = standardize_image(image_bytes=images, stddev=STDDEV_RGB)
+ if dtype is not None:
+ images = tf.image.convert_image_dtype(images, dtype)
+
+ return images
diff --git a/modeling/official/legacy/image_classification/resnet/README.md b/modeling/official/legacy/image_classification/resnet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5064523fbdcd4222c2159bdc1c09b7156800bf54
--- /dev/null
+++ b/modeling/official/legacy/image_classification/resnet/README.md
@@ -0,0 +1,125 @@
+This folder contains a
+[custom training loop (CTL)](#resnet-custom-training-loop) implementation for
+ResNet50.
+
+## Before you begin
+Please refer to the [README](../README.md) in the parent directory for
+information on setup and preparing the data.
+
+## ResNet (custom training loop)
+
+Similar to the [estimator implementation](../../../r1/resnet), the Keras
+implementation has code for the ImageNet dataset. The ImageNet
+version uses a ResNet50 model implemented in
+[`resnet_model.py`](./resnet_model.py).
+
+
+### Pretrained Models
+
+* [ResNet50 Checkpoints](https://storage.googleapis.com/cloud-tpu-checkpoints/resnet/resnet50.tar.gz)
+
+* ResNet50 TFHub: [feature vector](https://tfhub.dev/tensorflow/resnet_50/feature_vector/1)
+and [classification](https://tfhub.dev/tensorflow/resnet_50/classification/1)
+
+Again, if you did not download the data to the default directory, specify the
+location with the `--data_dir` flag:
+
+```bash
+python3 resnet_ctl_imagenet_main.py --data_dir=/path/to/imagenet
+```
+
+There are more flag options you can specify. Here are some examples:
+
+- `--use_synthetic_data`: when set to true, synthetic data, rather than real
+data, are used;
+- `--batch_size`: the batch size used for the model;
+- `--model_dir`: the directory to save the model checkpoint;
+- `--train_epochs`: number of epoches to run for training the model;
+- `--train_steps`: number of steps to run for training the model. We now only
+support a number that is smaller than the number of batches in an epoch.
+- `--skip_eval`: when set to true, evaluation as well as validation during
+training is skipped
+
+For example, this is a typical command line to run with ImageNet data with
+batch size 128 per GPU:
+
+```bash
+python3 -m resnet_ctl_imagenet_main.py \
+ --model_dir=/tmp/model_dir/something \
+ --num_gpus=2 \
+ --batch_size=128 \
+ --train_epochs=90 \
+ --train_steps=10 \
+ --use_synthetic_data=false
+```
+
+See [`common.py`](common.py) for full list of options.
+
+### Using multiple GPUs
+
+You can train these models on multiple GPUs using `tf.distribute.Strategy` API.
+You can read more about them in this
+[guide](https://www.tensorflow.org/guide/distribute_strategy).
+
+In this example, we have made it easier to use is with just a command line flag
+`--num_gpus`. By default this flag is 1 if TensorFlow is compiled with CUDA,
+and 0 otherwise.
+
+- --num_gpus=0: Uses tf.distribute.OneDeviceStrategy with CPU as the device.
+- --num_gpus=1: Uses tf.distribute.OneDeviceStrategy with GPU as the device.
+- --num_gpus=2+: Uses tf.distribute.MirroredStrategy to run synchronous
+distributed training across the GPUs.
+
+If you wish to run without `tf.distribute.Strategy`, you can do so by setting
+`--distribution_strategy=off`.
+
+### Running on multiple GPU hosts
+
+You can also train these models on multiple hosts, each with GPUs, using
+`tf.distribute.Strategy`.
+
+The easiest way to run multi-host benchmarks is to set the
+[`TF_CONFIG`](https://www.tensorflow.org/guide/distributed_training#TF_CONFIG)
+appropriately at each host. e.g., to run using `MultiWorkerMirroredStrategy` on
+2 hosts, the `cluster` in `TF_CONFIG` should have 2 `host:port` entries, and
+host `i` should have the `task` in `TF_CONFIG` set to `{"type": "worker",
+"index": i}`. `MultiWorkerMirroredStrategy` will automatically use all the
+available GPUs at each host.
+
+### Running on Cloud TPUs
+
+Note: This model will **not** work with TPUs on Colab.
+
+You can train the ResNet CTL model on Cloud TPUs using
+`tf.distribute.TPUStrategy`. If you are not familiar with Cloud TPUs, it is
+strongly recommended that you go through the
+[quickstart](https://cloud.google.com/tpu/docs/quickstart) to learn how to
+create a TPU and GCE VM.
+
+To run ResNet model on a TPU, you must set `--distribution_strategy=tpu` and
+`--tpu=$TPU_NAME`, where `$TPU_NAME` the name of your TPU in the Cloud Console.
+From a GCE VM, you can run the following command to train ResNet for one epoch
+on a v2-8 or v3-8 TPU by setting `TRAIN_EPOCHS` to 1:
+
+```bash
+python3 resnet_ctl_imagenet_main.py \
+ --tpu=$TPU_NAME \
+ --model_dir=$MODEL_DIR \
+ --data_dir=$DATA_DIR \
+ --batch_size=1024 \
+ --steps_per_loop=500 \
+ --train_epochs=$TRAIN_EPOCHS \
+ --use_synthetic_data=false \
+ --dtype=fp32 \
+ --enable_eager=true \
+ --enable_tensorboard=true \
+ --distribution_strategy=tpu \
+ --log_steps=50 \
+ --single_l2_loss_op=true \
+ --use_tf_function=true
+```
+
+To train the ResNet to convergence, run it for 90 epochs by setting
+`TRAIN_EPOCHS` to 90.
+
+Note: `$MODEL_DIR` and `$DATA_DIR` must be GCS paths.
diff --git a/modeling/official/legacy/image_classification/resnet/__init__.py b/modeling/official/legacy/image_classification/resnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/legacy/image_classification/resnet/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/legacy/image_classification/resnet/common.py b/modeling/official/legacy/image_classification/resnet/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..b97721fb556325436449db4ac0d3c2a42f72eb25
--- /dev/null
+++ b/modeling/official/legacy/image_classification/resnet/common.py
@@ -0,0 +1,423 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common util functions and classes used by both keras cifar and imagenet."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import flags
+import tensorflow as tf, tf_keras
+
+import tensorflow_model_optimization as tfmot
+from official.utils.flags import core as flags_core
+from official.utils.misc import keras_utils
+
+FLAGS = flags.FLAGS
+BASE_LEARNING_RATE = 0.1 # This matches Jing's version.
+TRAIN_TOP_1 = 'training_accuracy_top_1'
+LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
+ (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
+]
+
+
+class PiecewiseConstantDecayWithWarmup(
+ tf_keras.optimizers.schedules.LearningRateSchedule):
+ """Piecewise constant decay with warmup schedule."""
+
+ def __init__(self,
+ batch_size,
+ epoch_size,
+ warmup_epochs,
+ boundaries,
+ multipliers,
+ compute_lr_on_cpu=True,
+ name=None):
+ super(PiecewiseConstantDecayWithWarmup, self).__init__()
+ if len(boundaries) != len(multipliers) - 1:
+ raise ValueError('The length of boundaries must be 1 less than the '
+ 'length of multipliers')
+
+ base_lr_batch_size = 256
+ steps_per_epoch = epoch_size // batch_size
+
+ self.rescaled_lr = BASE_LEARNING_RATE * batch_size / base_lr_batch_size
+ self.step_boundaries = [float(steps_per_epoch) * x for x in boundaries]
+ self.lr_values = [self.rescaled_lr * m for m in multipliers]
+ self.warmup_steps = warmup_epochs * steps_per_epoch
+ self.compute_lr_on_cpu = compute_lr_on_cpu
+ self.name = name
+
+ self.learning_rate_ops_cache = {}
+
+ def __call__(self, step):
+ if tf.executing_eagerly():
+ return self._get_learning_rate(step)
+
+ # In an eager function or graph, the current implementation of optimizer
+ # repeatedly call and thus create ops for the learning rate schedule. To
+ # avoid this, we cache the ops if not executing eagerly.
+ graph = tf.compat.v1.get_default_graph()
+ if graph not in self.learning_rate_ops_cache:
+ if self.compute_lr_on_cpu:
+ with tf.device('/device:CPU:0'):
+ self.learning_rate_ops_cache[graph] = self._get_learning_rate(step)
+ else:
+ self.learning_rate_ops_cache[graph] = self._get_learning_rate(step)
+ return self.learning_rate_ops_cache[graph]
+
+ def _get_learning_rate(self, step):
+ """Compute learning rate at given step."""
+ step = tf.cast(step, dtype=tf.float32)
+ warmup_steps = tf.cast(self.warmup_steps, dtype=tf.float32)
+ with tf.name_scope('PiecewiseConstantDecayWithWarmup'):
+
+ def warmup_lr(step):
+ return self.rescaled_lr * (step / warmup_steps)
+
+ def piecewise_lr(step):
+ return tf.compat.v1.train.piecewise_constant(step, self.step_boundaries,
+ self.lr_values)
+
+ return tf.cond(step < warmup_steps, lambda: warmup_lr(step),
+ lambda: piecewise_lr(step))
+
+ def get_config(self):
+ return {
+ 'rescaled_lr': self.rescaled_lr,
+ 'step_boundaries': self.step_boundaries,
+ 'lr_values': self.lr_values,
+ 'warmup_steps': self.warmup_steps,
+ 'compute_lr_on_cpu': self.compute_lr_on_cpu,
+ 'name': self.name
+ }
+
+
+def get_optimizer(learning_rate=0.1, use_legacy_optimizer=True):
+ """Returns optimizer to use."""
+ # The learning_rate is overwritten at the beginning of each step by callback.
+ if use_legacy_optimizer:
+ return tf_keras.optimizers.legacy.SGD(
+ learning_rate=learning_rate, momentum=0.9)
+ else:
+ return tf_keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
+
+
+def get_callbacks(pruning_method=None,
+ enable_checkpoint_and_export=False,
+ model_dir=None):
+ """Returns common callbacks."""
+ time_callback = keras_utils.TimeHistory(
+ FLAGS.batch_size,
+ FLAGS.log_steps,
+ logdir=FLAGS.model_dir if FLAGS.enable_tensorboard else None)
+ callbacks = [time_callback]
+
+ if FLAGS.enable_tensorboard:
+ tensorboard_callback = tf_keras.callbacks.TensorBoard(
+ log_dir=FLAGS.model_dir, profile_batch=FLAGS.profile_steps)
+ callbacks.append(tensorboard_callback)
+
+ is_pruning_enabled = pruning_method is not None
+ if is_pruning_enabled:
+ callbacks.append(tfmot.sparsity.keras.UpdatePruningStep())
+ if model_dir is not None:
+ callbacks.append(
+ tfmot.sparsity.keras.PruningSummaries(
+ log_dir=model_dir, profile_batch=0))
+
+ if enable_checkpoint_and_export:
+ if model_dir is not None:
+ ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
+ callbacks.append(
+ tf_keras.callbacks.ModelCheckpoint(
+ ckpt_full_path, save_weights_only=True))
+ return callbacks
+
+
+def build_stats(history, eval_output, callbacks):
+ """Normalizes and returns dictionary of stats.
+
+ Args:
+ history: Results of the training step. Supports both categorical_accuracy
+ and sparse_categorical_accuracy.
+ eval_output: Output of the eval step. Assumes first value is eval_loss and
+ second value is accuracy_top_1.
+ callbacks: a list of callbacks which might include a time history callback
+ used during keras.fit.
+
+ Returns:
+ Dictionary of normalized results.
+ """
+ stats = {}
+ if eval_output:
+ stats['accuracy_top_1'] = float(eval_output[1])
+ stats['eval_loss'] = float(eval_output[0])
+ if history and history.history:
+ train_hist = history.history
+ # Gets final loss from training.
+ stats['loss'] = float(train_hist['loss'][-1])
+ # Gets top_1 training accuracy.
+ if 'categorical_accuracy' in train_hist:
+ stats[TRAIN_TOP_1] = float(train_hist['categorical_accuracy'][-1])
+ elif 'sparse_categorical_accuracy' in train_hist:
+ stats[TRAIN_TOP_1] = float(train_hist['sparse_categorical_accuracy'][-1])
+ elif 'accuracy' in train_hist:
+ stats[TRAIN_TOP_1] = float(train_hist['accuracy'][-1])
+
+ if not callbacks:
+ return stats
+
+ # Look for the time history callback which was used during keras.fit
+ for callback in callbacks:
+ if isinstance(callback, keras_utils.TimeHistory):
+ timestamp_log = callback.timestamp_log
+ stats['step_timestamp_log'] = timestamp_log
+ stats['train_finish_time'] = callback.train_finish_time
+ if callback.epoch_runtime_log:
+ stats['avg_exp_per_second'] = callback.average_examples_per_second
+
+ return stats
+
+
+def define_keras_flags(model=False,
+ optimizer=False,
+ pretrained_filepath=False):
+ """Define flags for Keras models."""
+ flags_core.define_base(
+ clean=True,
+ num_gpu=True,
+ run_eagerly=True,
+ train_epochs=True,
+ epochs_between_evals=True,
+ distribution_strategy=True)
+ flags_core.define_performance(
+ num_parallel_calls=False,
+ synthetic_data=True,
+ dtype=True,
+ all_reduce_alg=True,
+ num_packs=True,
+ tf_gpu_thread_mode=True,
+ datasets_num_private_threads=True,
+ loss_scale=True,
+ fp16_implementation=True,
+ tf_data_experimental_slack=True,
+ enable_xla=True,
+ training_dataset_cache=True)
+ flags_core.define_image()
+ flags_core.define_benchmark()
+ flags_core.define_distribution()
+ flags.adopt_module_key_flags(flags_core)
+
+ flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
+ flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
+ # TODO(b/135607288): Remove this flag once we understand the root cause of
+ # slowdown when setting the learning phase in Keras backend.
+ flags.DEFINE_boolean(
+ name='set_learning_phase_to_train',
+ default=True,
+ help='If skip eval, also set Keras learning phase to 1 (training).')
+ flags.DEFINE_boolean(
+ name='explicit_gpu_placement',
+ default=False,
+ help='If not using distribution strategy, explicitly set device scope '
+ 'for the Keras training loop.')
+ flags.DEFINE_boolean(
+ name='use_trivial_model',
+ default=False,
+ help='Whether to use a trivial Keras model.')
+ flags.DEFINE_boolean(
+ name='report_accuracy_metrics',
+ default=True,
+ help='Report metrics during training and evaluation.')
+ flags.DEFINE_boolean(
+ name='use_tensor_lr',
+ default=True,
+ help='Use learning rate tensor instead of a callback.')
+ flags.DEFINE_boolean(
+ name='enable_tensorboard',
+ default=False,
+ help='Whether to enable TensorBoard callback.')
+ flags.DEFINE_string(
+ name='profile_steps',
+ default=None,
+ help='Save profiling data to model dir at given range of global steps. The '
+ 'value must be a comma separated pair of positive integers, specifying '
+ 'the first and last step to profile. For example, "--profile_steps=2,4" '
+ 'triggers the profiler to process 3 steps, starting from the 2nd step. '
+ 'Note that profiler has a non-trivial performance overhead, and the '
+ 'output file can be gigantic if profiling many steps.')
+ flags.DEFINE_integer(
+ name='train_steps',
+ default=None,
+ help='The number of steps to run for training. If it is larger than '
+ '# batches per epoch, then use # batches per epoch. This flag will be '
+ 'ignored if train_epochs is set to be larger than 1. ')
+ flags.DEFINE_boolean(
+ name='batchnorm_spatial_persistent',
+ default=True,
+ help='Enable the spacial persistent mode for CuDNN batch norm kernel.')
+ flags.DEFINE_boolean(
+ name='enable_get_next_as_optional',
+ default=False,
+ help='Enable get_next_as_optional behavior in DistributedIterator.')
+ flags.DEFINE_boolean(
+ name='enable_checkpoint_and_export',
+ default=False,
+ help='Whether to enable a checkpoint callback and export the savedmodel.')
+ flags.DEFINE_string(name='tpu', default='', help='TPU address to connect to.')
+ flags.DEFINE_integer(
+ name='steps_per_loop',
+ default=None,
+ help='Number of steps per training loop. Only training step happens '
+ 'inside the loop. Callbacks will not be called inside. Will be capped at '
+ 'steps per epoch.')
+ flags.DEFINE_boolean(
+ name='use_tf_while_loop',
+ default=True,
+ help='Whether to build a tf.while_loop inside the training loop on the '
+ 'host. Setting it to True is critical to have peak performance on '
+ 'TPU.')
+
+ if model:
+ flags.DEFINE_string('model', 'resnet50_v1.5',
+ 'Name of model preset. (mobilenet, resnet50_v1.5)')
+ if optimizer:
+ flags.DEFINE_string(
+ 'optimizer', 'resnet50_default', 'Name of optimizer preset. '
+ '(mobilenet_default, resnet50_default)')
+ # TODO(kimjaehong): Replace as general hyper-params not only for mobilenet.
+ flags.DEFINE_float(
+ 'initial_learning_rate_per_sample', 0.00007,
+ 'Initial value of learning rate per sample for '
+ 'mobilenet_default.')
+ flags.DEFINE_float('lr_decay_factor', 0.94,
+ 'Learning rate decay factor for mobilenet_default.')
+ flags.DEFINE_float('num_epochs_per_decay', 2.5,
+ 'Number of epochs per decay for mobilenet_default.')
+ if pretrained_filepath:
+ flags.DEFINE_string('pretrained_filepath', '', 'Pretrained file path.')
+
+
+def get_synth_data(height, width, num_channels, num_classes, dtype):
+ """Creates a set of synthetic random data.
+
+ Args:
+ height: Integer height that will be used to create a fake image tensor.
+ width: Integer width that will be used to create a fake image tensor.
+ num_channels: Integer depth that will be used to create a fake image tensor.
+ num_classes: Number of classes that should be represented in the fake labels
+ tensor
+ dtype: Data type for features/images.
+
+ Returns:
+ A tuple of tensors representing the inputs and labels.
+
+ """
+ # Synthetic input should be within [0, 255].
+ inputs = tf.random.truncated_normal([height, width, num_channels],
+ dtype=dtype,
+ mean=127,
+ stddev=60,
+ name='synthetic_inputs')
+ labels = tf.random.uniform([1],
+ minval=0,
+ maxval=num_classes - 1,
+ dtype=tf.int32,
+ name='synthetic_labels')
+ return inputs, labels
+
+
+def define_pruning_flags():
+ """Define flags for pruning methods."""
+ flags.DEFINE_string(
+ 'pruning_method', None, 'Pruning method.'
+ 'None (no pruning) or polynomial_decay.')
+ flags.DEFINE_float('pruning_initial_sparsity', 0.0,
+ 'Initial sparsity for pruning.')
+ flags.DEFINE_float('pruning_final_sparsity', 0.5,
+ 'Final sparsity for pruning.')
+ flags.DEFINE_integer('pruning_begin_step', 0, 'Begin step for pruning.')
+ flags.DEFINE_integer('pruning_end_step', 100000, 'End step for pruning.')
+ flags.DEFINE_integer('pruning_frequency', 100, 'Frequency for pruning.')
+
+
+def define_clustering_flags():
+ """Define flags for clustering methods."""
+ flags.DEFINE_string('clustering_method', None,
+ 'None (no clustering) or selective_clustering '
+ '(cluster last three Conv2D layers of the model).')
+
+
+def get_synth_input_fn(height,
+ width,
+ num_channels,
+ num_classes,
+ dtype=tf.float32,
+ drop_remainder=True):
+ """Returns an input function that returns a dataset with random data.
+
+ This input_fn returns a data set that iterates over a set of random data and
+ bypasses all preprocessing, e.g. jpeg decode and copy. The host to device
+ copy is still included. This used to find the upper throughput bound when
+ tuning the full input pipeline.
+
+ Args:
+ height: Integer height that will be used to create a fake image tensor.
+ width: Integer width that will be used to create a fake image tensor.
+ num_channels: Integer depth that will be used to create a fake image tensor.
+ num_classes: Number of classes that should be represented in the fake labels
+ tensor
+ dtype: Data type for features/images.
+ drop_remainder: A boolean indicates whether to drop the remainder of the
+ batches. If True, the batch dimension will be static.
+
+ Returns:
+ An input_fn that can be used in place of a real one to return a dataset
+ that can be used for iteration.
+ """
+
+ # pylint: disable=unused-argument
+ def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
+ """Returns dataset filled with random data."""
+ inputs, labels = get_synth_data(
+ height=height,
+ width=width,
+ num_channels=num_channels,
+ num_classes=num_classes,
+ dtype=dtype)
+ # Cast to float32 for Keras model.
+ labels = tf.cast(labels, dtype=tf.float32)
+ data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
+
+ # `drop_remainder` will make dataset produce outputs with known shapes.
+ data = data.batch(batch_size, drop_remainder=drop_remainder)
+ data = data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
+ return data
+
+ return input_fn
+
+
+def set_cudnn_batchnorm_mode():
+ """Set CuDNN batchnorm mode for better performance.
+
+ Note: Spatial Persistent mode may lead to accuracy losses for certain
+ models.
+ """
+ if FLAGS.batchnorm_spatial_persistent:
+ os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
+ else:
+ os.environ.pop('TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT', None)
diff --git a/modeling/official/legacy/image_classification/resnet/imagenet_preprocessing.py b/modeling/official/legacy/image_classification/resnet/imagenet_preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..00ad8e61c7675e697de36f50690b8296b6e7a19b
--- /dev/null
+++ b/modeling/official/legacy/image_classification/resnet/imagenet_preprocessing.py
@@ -0,0 +1,574 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Provides utilities to preprocess images.
+
+Training images are sampled using the provided bounding boxes, and subsequently
+cropped to the sampled bounding box. Images are additionally flipped randomly,
+then resized to the target output size (without aspect-ratio preservation).
+
+Images used during evaluation are resized (with aspect-ratio preservation) and
+centrally cropped.
+
+All images undergo mean color subtraction.
+
+Note that these steps are colloquially referred to as "ResNet preprocessing,"
+and they differ from "VGG preprocessing," which does not use bounding boxes
+and instead does an aspect-preserving resize followed by random crop during
+training. (These both differ from "Inception preprocessing," which introduces
+color distortion steps.)
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl import logging
+import tensorflow as tf, tf_keras
+
+DEFAULT_IMAGE_SIZE = 224
+NUM_CHANNELS = 3
+NUM_CLASSES = 1001
+
+NUM_IMAGES = {
+ 'train': 1281167,
+ 'validation': 50000,
+}
+
+_NUM_TRAIN_FILES = 1024
+_SHUFFLE_BUFFER = 10000
+
+_R_MEAN = 123.68
+_G_MEAN = 116.78
+_B_MEAN = 103.94
+CHANNEL_MEANS = [_R_MEAN, _G_MEAN, _B_MEAN]
+
+# The lower bound for the smallest side of the image for aspect-preserving
+# resizing. For example, if an image is 500 x 1000, it will be resized to
+# _RESIZE_MIN x (_RESIZE_MIN * 2).
+_RESIZE_MIN = 256
+
+
+def process_record_dataset(dataset,
+ is_training,
+ batch_size,
+ shuffle_buffer,
+ parse_record_fn,
+ dtype=tf.float32,
+ datasets_num_private_threads=None,
+ drop_remainder=False,
+ tf_data_experimental_slack=False):
+ """Given a Dataset with raw records, return an iterator over the records.
+
+ Args:
+ dataset: A Dataset representing raw records
+ is_training: A boolean denoting whether the input is for training.
+ batch_size: The number of samples per batch.
+ shuffle_buffer: The buffer size to use when shuffling records. A larger
+ value results in better randomness, but smaller values reduce startup time
+ and use less memory.
+ parse_record_fn: A function that takes a raw record and returns the
+ corresponding (image, label) pair.
+ dtype: Data type to use for images/features.
+ datasets_num_private_threads: Number of threads for a private threadpool
+ created for all datasets computation.
+ drop_remainder: A boolean indicates whether to drop the remainder of the
+ batches. If True, the batch dimension will be static.
+ tf_data_experimental_slack: Whether to enable tf.data's `experimental_slack`
+ option.
+
+ Returns:
+ Dataset of (image, label) pairs ready for iteration.
+ """
+ # Defines a specific size thread pool for tf.data operations.
+ if datasets_num_private_threads:
+ options = tf.data.Options()
+ options.experimental_threading.private_threadpool_size = (
+ datasets_num_private_threads)
+ dataset = dataset.with_options(options)
+ logging.info('datasets_num_private_threads: %s',
+ datasets_num_private_threads)
+
+ if is_training:
+ # Shuffles records before repeating to respect epoch boundaries.
+ dataset = dataset.shuffle(buffer_size=shuffle_buffer)
+ # Repeats the dataset for the number of epochs to train.
+ dataset = dataset.repeat()
+
+ # Parses the raw records into images and labels.
+ dataset = dataset.map(
+ lambda value: parse_record_fn(value, is_training, dtype),
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
+
+ # Operations between the final prefetch and the get_next call to the iterator
+ # will happen synchronously during run time. We prefetch here again to
+ # background all of the above processing work and keep it out of the
+ # critical training path. Setting buffer_size to tf.data.experimental.AUTOTUNE
+ # allows DistributionStrategies to adjust how many batches to fetch based
+ # on how many devices are present.
+ dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
+
+ options = tf.data.Options()
+ options.experimental_slack = tf_data_experimental_slack
+ dataset = dataset.with_options(options)
+
+ return dataset
+
+
+def get_filenames(is_training, data_dir):
+ """Return filenames for dataset."""
+ if is_training:
+ return [
+ os.path.join(data_dir, 'train-%05d-of-01024' % i)
+ for i in range(_NUM_TRAIN_FILES)
+ ]
+ else:
+ return [
+ os.path.join(data_dir, 'validation-%05d-of-00128' % i)
+ for i in range(128)
+ ]
+
+
+def parse_example_proto(example_serialized):
+ """Parses an Example proto containing a training example of an image.
+
+ The output of the build_image_data.py image preprocessing script is a dataset
+ containing serialized Example protocol buffers. Each Example proto contains
+ the following fields (values are included as examples):
+
+ image/height: 462
+ image/width: 581
+ image/colorspace: 'RGB'
+ image/channels: 3
+ image/class/label: 615
+ image/class/synset: 'n03623198'
+ image/class/text: 'knee pad'
+ image/object/bbox/xmin: 0.1
+ image/object/bbox/xmax: 0.9
+ image/object/bbox/ymin: 0.2
+ image/object/bbox/ymax: 0.6
+ image/object/bbox/label: 615
+ image/format: 'JPEG'
+ image/filename: 'ILSVRC2012_val_00041207.JPEG'
+ image/encoded:
+
+ Args:
+ example_serialized: scalar Tensor tf.string containing a serialized Example
+ protocol buffer.
+
+ Returns:
+ image_buffer: Tensor tf.string containing the contents of a JPEG file.
+ label: Tensor tf.int32 containing the label.
+ bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
+ where each coordinate is [0, 1) and the coordinates are arranged as
+ [ymin, xmin, ymax, xmax].
+ """
+ # Dense features in Example proto.
+ feature_map = {
+ 'image/encoded':
+ tf.io.FixedLenFeature([], dtype=tf.string, default_value=''),
+ 'image/class/label':
+ tf.io.FixedLenFeature([], dtype=tf.int64, default_value=-1),
+ 'image/class/text':
+ tf.io.FixedLenFeature([], dtype=tf.string, default_value=''),
+ }
+ sparse_float32 = tf.io.VarLenFeature(dtype=tf.float32)
+ # Sparse features in Example proto.
+ feature_map.update({
+ k: sparse_float32 for k in [
+ 'image/object/bbox/xmin', 'image/object/bbox/ymin',
+ 'image/object/bbox/xmax', 'image/object/bbox/ymax'
+ ]
+ })
+
+ features = tf.io.parse_single_example(
+ serialized=example_serialized, features=feature_map)
+ label = tf.cast(features['image/class/label'], dtype=tf.int32)
+
+ xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
+ ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0)
+ xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0)
+ ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0)
+
+ # Note that we impose an ordering of (y, x) just to make life difficult.
+ bbox = tf.concat([ymin, xmin, ymax, xmax], 0)
+
+ # Force the variable number of bounding boxes into the shape
+ # [1, num_boxes, coords].
+ bbox = tf.expand_dims(bbox, 0)
+ bbox = tf.transpose(a=bbox, perm=[0, 2, 1])
+
+ return features['image/encoded'], label, bbox
+
+
+def parse_record(raw_record, is_training, dtype):
+ """Parses a record containing a training example of an image.
+
+ The input record is parsed into a label and image, and the image is passed
+ through preprocessing steps (cropping, flipping, and so on).
+
+ Args:
+ raw_record: scalar Tensor tf.string containing a serialized Example protocol
+ buffer.
+ is_training: A boolean denoting whether the input is for training.
+ dtype: data type to use for images/features.
+
+ Returns:
+ Tuple with processed image tensor in a channel-last format and
+ one-hot-encoded label tensor.
+ """
+ image_buffer, label, bbox = parse_example_proto(raw_record)
+
+ image = preprocess_image(
+ image_buffer=image_buffer,
+ bbox=bbox,
+ output_height=DEFAULT_IMAGE_SIZE,
+ output_width=DEFAULT_IMAGE_SIZE,
+ num_channels=NUM_CHANNELS,
+ is_training=is_training)
+ image = tf.cast(image, dtype)
+
+ # Subtract one so that labels are in [0, 1000), and cast to float32 for
+ # Keras model.
+ label = tf.cast(
+ tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
+ dtype=tf.float32)
+ return image, label
+
+
+def get_parse_record_fn(use_keras_image_data_format=False):
+ """Get a function for parsing the records, accounting for image format.
+
+ This is useful by handling different types of Keras models. For instance,
+ the current resnet_model.resnet50 input format is always channel-last,
+ whereas the keras_applications mobilenet input format depends on
+ tf_keras.backend.image_data_format(). We should set
+ use_keras_image_data_format=False for the former and True for the latter.
+
+ Args:
+ use_keras_image_data_format: A boolean denoting whether data format is keras
+ backend image data format. If False, the image format is channel-last. If
+ True, the image format matches tf_keras.backend.image_data_format().
+
+ Returns:
+ Function to use for parsing the records.
+ """
+
+ def parse_record_fn(raw_record, is_training, dtype):
+ image, label = parse_record(raw_record, is_training, dtype)
+ if use_keras_image_data_format:
+ if tf_keras.backend.image_data_format() == 'channels_first':
+ image = tf.transpose(image, perm=[2, 0, 1])
+ return image, label
+
+ return parse_record_fn
+
+
+def input_fn(is_training,
+ data_dir,
+ batch_size,
+ dtype=tf.float32,
+ datasets_num_private_threads=None,
+ parse_record_fn=parse_record,
+ input_context=None,
+ drop_remainder=False,
+ tf_data_experimental_slack=False,
+ training_dataset_cache=False,
+ filenames=None):
+ """Input function which provides batches for train or eval.
+
+ Args:
+ is_training: A boolean denoting whether the input is for training.
+ data_dir: The directory containing the input data.
+ batch_size: The number of samples per batch.
+ dtype: Data type to use for images/features
+ datasets_num_private_threads: Number of private threads for tf.data.
+ parse_record_fn: Function to use for parsing the records.
+ input_context: A `tf.distribute.InputContext` object passed in by
+ `tf.distribute.Strategy`.
+ drop_remainder: A boolean indicates whether to drop the remainder of the
+ batches. If True, the batch dimension will be static.
+ tf_data_experimental_slack: Whether to enable tf.data's `experimental_slack`
+ option.
+ training_dataset_cache: Whether to cache the training dataset on workers.
+ Typically used to improve training performance when training data is in
+ remote storage and can fit into worker memory.
+ filenames: Optional field for providing the file names of the TFRecords.
+
+ Returns:
+ A dataset that can be used for iteration.
+ """
+ if filenames is None:
+ filenames = get_filenames(is_training, data_dir)
+ dataset = tf.data.Dataset.from_tensor_slices(filenames)
+
+ if input_context:
+ logging.info(
+ 'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
+ input_context.input_pipeline_id, input_context.num_input_pipelines)
+ dataset = dataset.shard(input_context.num_input_pipelines,
+ input_context.input_pipeline_id)
+
+ if is_training:
+ # Shuffle the input files
+ dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES)
+
+ # Convert to individual records.
+ # cycle_length = 10 means that up to 10 files will be read and deserialized in
+ # parallel. You may want to increase this number if you have a large number of
+ # CPU cores.
+ dataset = dataset.interleave(
+ tf.data.TFRecordDataset,
+ cycle_length=10,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ if is_training and training_dataset_cache:
+ # Improve training performance when training data is in remote storage and
+ # can fit into worker memory.
+ dataset = dataset.cache()
+
+ return process_record_dataset(
+ dataset=dataset,
+ is_training=is_training,
+ batch_size=batch_size,
+ shuffle_buffer=_SHUFFLE_BUFFER,
+ parse_record_fn=parse_record_fn,
+ dtype=dtype,
+ datasets_num_private_threads=datasets_num_private_threads,
+ drop_remainder=drop_remainder,
+ tf_data_experimental_slack=tf_data_experimental_slack,
+ )
+
+
+def _decode_crop_and_flip(image_buffer, bbox, num_channels):
+ """Crops the given image to a random part of the image, and randomly flips.
+
+ We use the fused decode_and_crop op, which performs better than the two ops
+ used separately in series, but note that this requires that the image be
+ passed in as an un-decoded string Tensor.
+
+ Args:
+ image_buffer: scalar string Tensor representing the raw JPEG image buffer.
+ bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
+ where each coordinate is [0, 1) and the coordinates are arranged as [ymin,
+ xmin, ymax, xmax].
+ num_channels: Integer depth of the image buffer for decoding.
+
+ Returns:
+ 3-D tensor with cropped image.
+
+ """
+ # A large fraction of image datasets contain a human-annotated bounding box
+ # delineating the region of the image containing the object of interest. We
+ # choose to create a new bounding box for the object which is a randomly
+ # distorted version of the human-annotated bounding box that obeys an
+ # allowed range of aspect ratios, sizes and overlap with the human-annotated
+ # bounding box. If no box is supplied, then we assume the bounding box is
+ # the entire image.
+ sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
+ tf.image.extract_jpeg_shape(image_buffer),
+ bounding_boxes=bbox,
+ min_object_covered=0.1,
+ aspect_ratio_range=[0.75, 1.33],
+ area_range=[0.05, 1.0],
+ max_attempts=100,
+ use_image_if_no_bounding_boxes=True)
+ bbox_begin, bbox_size, _ = sample_distorted_bounding_box
+
+ # Reassemble the bounding box in the format the crop op requires.
+ offset_y, offset_x, _ = tf.unstack(bbox_begin)
+ target_height, target_width, _ = tf.unstack(bbox_size)
+ crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
+
+ # Use the fused decode and crop op here, which is faster than each in series.
+ cropped = tf.image.decode_and_crop_jpeg(
+ image_buffer, crop_window, channels=num_channels)
+
+ # Flip to add a little more random distortion in.
+ cropped = tf.image.random_flip_left_right(cropped)
+ return cropped
+
+
+def _central_crop(image, crop_height, crop_width):
+ """Performs central crops of the given image list.
+
+ Args:
+ image: a 3-D image tensor
+ crop_height: the height of the image following the crop.
+ crop_width: the width of the image following the crop.
+
+ Returns:
+ 3-D tensor with cropped image.
+ """
+ shape = tf.shape(input=image)
+ height, width = shape[0], shape[1]
+
+ amount_to_be_cropped_h = (height - crop_height)
+ crop_top = amount_to_be_cropped_h // 2
+ amount_to_be_cropped_w = (width - crop_width)
+ crop_left = amount_to_be_cropped_w // 2
+ return tf.slice(image, [crop_top, crop_left, 0],
+ [crop_height, crop_width, -1])
+
+
+def _mean_image_subtraction(image, means, num_channels):
+ """Subtracts the given means from each image channel.
+
+ For example:
+ means = [123.68, 116.779, 103.939]
+ image = _mean_image_subtraction(image, means)
+
+ Note that the rank of `image` must be known.
+
+ Args:
+ image: a tensor of size [height, width, C].
+ means: a C-vector of values to subtract from each channel.
+ num_channels: number of color channels in the image that will be distorted.
+
+ Returns:
+ the centered image.
+
+ Raises:
+ ValueError: If the rank of `image` is unknown, if `image` has a rank other
+ than three or if the number of channels in `image` doesn't match the
+ number of values in `means`.
+ """
+ if image.get_shape().ndims != 3:
+ raise ValueError('Input must be of size [height, width, C>0]')
+
+ if len(means) != num_channels:
+ raise ValueError('len(means) must match the number of channels')
+
+ # We have a 1-D tensor of means; convert to 3-D.
+ # Note(b/130245863): we explicitly call `broadcast` instead of simply
+ # expanding dimensions for better performance.
+ means = tf.broadcast_to(means, tf.shape(image))
+
+ return image - means
+
+
+def _smallest_size_at_least(height, width, resize_min):
+ """Computes new shape with the smallest side equal to `smallest_side`.
+
+ Computes new shape with the smallest side equal to `smallest_side` while
+ preserving the original aspect ratio.
+
+ Args:
+ height: an int32 scalar tensor indicating the current height.
+ width: an int32 scalar tensor indicating the current width.
+ resize_min: A python integer or scalar `Tensor` indicating the size of the
+ smallest side after resize.
+
+ Returns:
+ new_height: an int32 scalar tensor indicating the new height.
+ new_width: an int32 scalar tensor indicating the new width.
+ """
+ resize_min = tf.cast(resize_min, tf.float32)
+
+ # Convert to floats to make subsequent calculations go smoothly.
+ height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
+
+ smaller_dim = tf.minimum(height, width)
+ scale_ratio = resize_min / smaller_dim
+
+ # Convert back to ints to make heights and widths that TF ops will accept.
+ new_height = tf.cast(height * scale_ratio, tf.int32)
+ new_width = tf.cast(width * scale_ratio, tf.int32)
+
+ return new_height, new_width
+
+
+def _aspect_preserving_resize(image, resize_min):
+ """Resize images preserving the original aspect ratio.
+
+ Args:
+ image: A 3-D image `Tensor`.
+ resize_min: A python integer or scalar `Tensor` indicating the size of the
+ smallest side after resize.
+
+ Returns:
+ resized_image: A 3-D tensor containing the resized image.
+ """
+ shape = tf.shape(input=image)
+ height, width = shape[0], shape[1]
+
+ new_height, new_width = _smallest_size_at_least(height, width, resize_min)
+
+ return _resize_image(image, new_height, new_width)
+
+
+def _resize_image(image, height, width):
+ """Simple wrapper around tf.resize_images.
+
+ This is primarily to make sure we use the same `ResizeMethod` and other
+ details each time.
+
+ Args:
+ image: A 3-D image `Tensor`.
+ height: The target height for the resized image.
+ width: The target width for the resized image.
+
+ Returns:
+ resized_image: A 3-D tensor containing the resized image. The first two
+ dimensions have the shape [height, width].
+ """
+ return tf.compat.v1.image.resize(
+ image, [height, width],
+ method=tf.image.ResizeMethod.BILINEAR,
+ align_corners=False)
+
+
+def preprocess_image(image_buffer,
+ bbox,
+ output_height,
+ output_width,
+ num_channels,
+ is_training=False):
+ """Preprocesses the given image.
+
+ Preprocessing includes decoding, cropping, and resizing for both training
+ and eval images. Training preprocessing, however, introduces some random
+ distortion of the image to improve accuracy.
+
+ Args:
+ image_buffer: scalar string Tensor representing the raw JPEG image buffer.
+ bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
+ where each coordinate is [0, 1) and the coordinates are arranged as [ymin,
+ xmin, ymax, xmax].
+ output_height: The height of the image after preprocessing.
+ output_width: The width of the image after preprocessing.
+ num_channels: Integer depth of the image buffer for decoding.
+ is_training: `True` if we're preprocessing the image for training and
+ `False` otherwise.
+
+ Returns:
+ A preprocessed image.
+ """
+ if is_training:
+ # For training, we want to randomize some of the distortions.
+ image = _decode_crop_and_flip(image_buffer, bbox, num_channels)
+ image = _resize_image(image, output_height, output_width)
+ else:
+ # For validation, we want to decode, resize, then just crop the middle.
+ image = tf.image.decode_jpeg(image_buffer, channels=num_channels)
+ image = _aspect_preserving_resize(image, _RESIZE_MIN)
+ image = _central_crop(image, output_height, output_width)
+
+ image.set_shape([output_height, output_width, num_channels])
+
+ return _mean_image_subtraction(image, CHANNEL_MEANS, num_channels)
diff --git a/modeling/official/legacy/image_classification/resnet/resnet_config.py b/modeling/official/legacy/image_classification/resnet/resnet_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e0a48fb9977d48790ad6ed7435b7e7abdd72589
--- /dev/null
+++ b/modeling/official/legacy/image_classification/resnet/resnet_config.py
@@ -0,0 +1,63 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Configuration definitions for ResNet losses, learning rates, and optimizers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import dataclasses
+from official.legacy.image_classification.configs import base_configs
+from official.modeling.hyperparams import base_config
+
+
+@dataclasses.dataclass
+class ResNetModelConfig(base_configs.ModelConfig):
+ """Configuration for the ResNet model."""
+ name: str = 'ResNet'
+ num_classes: int = 1000
+ model_params: base_config.Config = dataclasses.field(
+ # pylint: disable=g-long-lambda
+ default_factory=lambda: {
+ 'num_classes': 1000,
+ 'batch_size': None,
+ 'use_l2_regularizer': True,
+ 'rescale_inputs': False,
+ })
+ # pylint: enable=g-long-lambda
+ loss: base_configs.LossConfig = dataclasses.field(
+ default_factory=lambda: base_configs.LossConfig( # pylint: disable=g-long-lambda
+ name='sparse_categorical_crossentropy'
+ )
+ )
+ optimizer: base_configs.OptimizerConfig = dataclasses.field(
+ default_factory=lambda: base_configs.OptimizerConfig( # pylint: disable=g-long-lambda
+ name='momentum',
+ decay=0.9,
+ epsilon=0.001,
+ momentum=0.9,
+ moving_average_decay=None,
+ )
+ )
+ learning_rate: base_configs.LearningRateConfig = dataclasses.field(
+ default_factory=lambda: base_configs.LearningRateConfig( # pylint: disable=g-long-lambda
+ name='stepwise',
+ initial_lr=0.1,
+ examples_per_epoch=1281167,
+ boundaries=[30, 60, 80],
+ warmup_epochs=5,
+ scale_by_batch_size=1.0 / 256.0,
+ multipliers=[0.1 / 256, 0.01 / 256, 0.001 / 256, 0.0001 / 256],
+ )
+ )
diff --git a/modeling/official/legacy/image_classification/resnet/resnet_ctl_imagenet_main.py b/modeling/official/legacy/image_classification/resnet/resnet_ctl_imagenet_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0f1696aa0d2dc6216b6c9959ad4c0a4606f3c8a
--- /dev/null
+++ b/modeling/official/legacy/image_classification/resnet/resnet_ctl_imagenet_main.py
@@ -0,0 +1,195 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Runs a ResNet model on the ImageNet dataset using custom training loops."""
+
+import math
+import os
+
+# Import libraries
+from absl import app
+from absl import flags
+from absl import logging
+import orbit
+import tensorflow as tf, tf_keras
+from official.common import distribute_utils
+from official.legacy.image_classification.resnet import common
+from official.legacy.image_classification.resnet import imagenet_preprocessing
+from official.legacy.image_classification.resnet import resnet_runnable
+from official.modeling import performance
+from official.utils.flags import core as flags_core
+from official.utils.misc import keras_utils
+from official.utils.misc import model_helpers
+
+flags.DEFINE_boolean(name='use_tf_function', default=True,
+ help='Wrap the train and test step inside a '
+ 'tf.function.')
+flags.DEFINE_boolean(name='single_l2_loss_op', default=False,
+ help='Calculate L2_loss on concatenated weights, '
+ 'instead of using Keras per-layer L2 loss.')
+
+
+def build_stats(runnable, time_callback):
+ """Normalizes and returns dictionary of stats.
+
+ Args:
+ runnable: The module containing all the training and evaluation metrics.
+ time_callback: Time tracking callback instance.
+
+ Returns:
+ Dictionary of normalized results.
+ """
+ stats = {}
+
+ if not runnable.flags_obj.skip_eval:
+ stats['eval_loss'] = runnable.test_loss.result().numpy()
+ stats['eval_acc'] = runnable.test_accuracy.result().numpy()
+
+ stats['train_loss'] = runnable.train_loss.result().numpy()
+ stats['train_acc'] = runnable.train_accuracy.result().numpy()
+
+ if time_callback:
+ timestamp_log = time_callback.timestamp_log
+ stats['step_timestamp_log'] = timestamp_log
+ stats['train_finish_time'] = time_callback.train_finish_time
+ if time_callback.epoch_runtime_log:
+ stats['avg_exp_per_second'] = time_callback.average_examples_per_second
+
+ return stats
+
+
+def get_num_train_iterations(flags_obj):
+ """Returns the number of training steps, train and test epochs."""
+ train_steps = (
+ imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
+ train_epochs = flags_obj.train_epochs
+
+ if flags_obj.train_steps:
+ train_steps = min(flags_obj.train_steps, train_steps)
+ train_epochs = 1
+
+ eval_steps = math.ceil(1.0 * imagenet_preprocessing.NUM_IMAGES['validation'] /
+ flags_obj.batch_size)
+
+ return train_steps, train_epochs, eval_steps
+
+
+def run(flags_obj):
+ """Run ResNet ImageNet training and eval loop using custom training loops.
+
+ Args:
+ flags_obj: An object containing parsed flag values.
+
+ Raises:
+ ValueError: If fp16 is passed as it is not currently supported.
+
+ Returns:
+ Dictionary of training and eval stats.
+ """
+ keras_utils.set_session_config()
+ performance.set_mixed_precision_policy(flags_core.get_tf_dtype(flags_obj))
+
+ if tf.config.list_physical_devices('GPU'):
+ if flags_obj.tf_gpu_thread_mode:
+ keras_utils.set_gpu_thread_mode_and_count(
+ per_gpu_thread_count=flags_obj.per_gpu_thread_count,
+ gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
+ num_gpus=flags_obj.num_gpus,
+ datasets_num_private_threads=flags_obj.datasets_num_private_threads)
+ common.set_cudnn_batchnorm_mode()
+
+ data_format = flags_obj.data_format
+ if data_format is None:
+ data_format = ('channels_first' if tf.config.list_physical_devices('GPU')
+ else 'channels_last')
+ tf_keras.backend.set_image_data_format(data_format)
+
+ strategy = distribute_utils.get_distribution_strategy(
+ distribution_strategy=flags_obj.distribution_strategy,
+ num_gpus=flags_obj.num_gpus,
+ all_reduce_alg=flags_obj.all_reduce_alg,
+ num_packs=flags_obj.num_packs,
+ tpu_address=flags_obj.tpu)
+
+ per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
+ flags_obj)
+ if flags_obj.steps_per_loop is None:
+ steps_per_loop = per_epoch_steps
+ elif flags_obj.steps_per_loop > per_epoch_steps:
+ steps_per_loop = per_epoch_steps
+ logging.warn('Setting steps_per_loop to %d to respect epoch boundary.',
+ steps_per_loop)
+ else:
+ steps_per_loop = flags_obj.steps_per_loop
+
+ logging.info(
+ 'Training %d epochs, each epoch has %d steps, '
+ 'total steps: %d; Eval %d steps', train_epochs, per_epoch_steps,
+ train_epochs * per_epoch_steps, eval_steps)
+
+ time_callback = keras_utils.TimeHistory(
+ flags_obj.batch_size,
+ flags_obj.log_steps,
+ logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)
+ with distribute_utils.get_strategy_scope(strategy):
+ runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
+ per_epoch_steps)
+
+ eval_interval = flags_obj.epochs_between_evals * per_epoch_steps
+ checkpoint_interval = (
+ steps_per_loop * 5 if flags_obj.enable_checkpoint_and_export else None)
+ summary_interval = steps_per_loop if flags_obj.enable_tensorboard else None
+
+ checkpoint_manager = tf.train.CheckpointManager(
+ runnable.checkpoint,
+ directory=flags_obj.model_dir,
+ max_to_keep=10,
+ step_counter=runnable.global_step,
+ checkpoint_interval=checkpoint_interval)
+
+ resnet_controller = orbit.Controller(
+ strategy=strategy,
+ trainer=runnable,
+ evaluator=runnable if not flags_obj.skip_eval else None,
+ global_step=runnable.global_step,
+ steps_per_loop=steps_per_loop,
+ checkpoint_manager=checkpoint_manager,
+ summary_interval=summary_interval,
+ summary_dir=flags_obj.model_dir,
+ eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval'))
+
+ time_callback.on_train_begin()
+ if not flags_obj.skip_eval:
+ resnet_controller.train_and_evaluate(
+ train_steps=per_epoch_steps * train_epochs,
+ eval_steps=eval_steps,
+ eval_interval=eval_interval)
+ else:
+ resnet_controller.train(steps=per_epoch_steps * train_epochs)
+ time_callback.on_train_end()
+
+ stats = build_stats(runnable, time_callback)
+ return stats
+
+
+def main(_):
+ model_helpers.apply_clean(flags.FLAGS)
+ stats = run(flags.FLAGS)
+ logging.info('Run stats:\n%s', stats)
+
+
+if __name__ == '__main__':
+ logging.set_verbosity(logging.INFO)
+ common.define_keras_flags()
+ app.run(main)
diff --git a/modeling/official/legacy/image_classification/resnet/resnet_model.py b/modeling/official/legacy/image_classification/resnet/resnet_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ee17ae693db489c5ed27e600695c3f9ff81902f
--- /dev/null
+++ b/modeling/official/legacy/image_classification/resnet/resnet_model.py
@@ -0,0 +1,325 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""ResNet50 model for Keras.
+
+Adapted from tf_keras.applications.resnet50.ResNet50().
+This is ResNet model version 1.5.
+
+Related papers/blogs:
+- https://arxiv.org/abs/1512.03385
+- https://arxiv.org/pdf/1603.05027v2.pdf
+- http://torch.ch/blog/2016/02/04/resnets.html
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf, tf_keras
+from official.legacy.image_classification.resnet import imagenet_preprocessing
+
+layers = tf_keras.layers
+
+
+def _gen_l2_regularizer(use_l2_regularizer=True, l2_weight_decay=1e-4):
+ return tf_keras.regularizers.L2(
+ l2_weight_decay) if use_l2_regularizer else None
+
+
+def identity_block(input_tensor,
+ kernel_size,
+ filters,
+ stage,
+ block,
+ use_l2_regularizer=True,
+ batch_norm_decay=0.9,
+ batch_norm_epsilon=1e-5):
+ """The identity block is the block that has no conv layer at shortcut.
+
+ Args:
+ input_tensor: input tensor
+ kernel_size: default 3, the kernel size of middle conv layer at main path
+ filters: list of integers, the filters of 3 conv layer at main path
+ stage: integer, current stage label, used for generating layer names
+ block: 'a','b'..., current block label, used for generating layer names
+ use_l2_regularizer: whether to use L2 regularizer on Conv layer.
+ batch_norm_decay: Moment of batch norm layers.
+ batch_norm_epsilon: Epsilon of batch borm layers.
+
+ Returns:
+ Output tensor for the block.
+ """
+ filters1, filters2, filters3 = filters
+ if tf_keras.backend.image_data_format() == 'channels_last':
+ bn_axis = 3
+ else:
+ bn_axis = 1
+ conv_name_base = 'res' + str(stage) + block + '_branch'
+ bn_name_base = 'bn' + str(stage) + block + '_branch'
+
+ x = layers.Conv2D(
+ filters1, (1, 1),
+ use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name=conv_name_base + '2a')(
+ input_tensor)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name=bn_name_base + '2a')(
+ x)
+ x = layers.Activation('relu')(x)
+
+ x = layers.Conv2D(
+ filters2,
+ kernel_size,
+ padding='same',
+ use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name=conv_name_base + '2b')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name=bn_name_base + '2b')(
+ x)
+ x = layers.Activation('relu')(x)
+
+ x = layers.Conv2D(
+ filters3, (1, 1),
+ use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name=conv_name_base + '2c')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name=bn_name_base + '2c')(
+ x)
+
+ x = layers.add([x, input_tensor])
+ x = layers.Activation('relu')(x)
+ return x
+
+
+def conv_block(input_tensor,
+ kernel_size,
+ filters,
+ stage,
+ block,
+ strides=(2, 2),
+ use_l2_regularizer=True,
+ batch_norm_decay=0.9,
+ batch_norm_epsilon=1e-5):
+ """A block that has a conv layer at shortcut.
+
+ Note that from stage 3,
+ the second conv layer at main path is with strides=(2, 2)
+ And the shortcut should have strides=(2, 2) as well
+
+ Args:
+ input_tensor: input tensor
+ kernel_size: default 3, the kernel size of middle conv layer at main path
+ filters: list of integers, the filters of 3 conv layer at main path
+ stage: integer, current stage label, used for generating layer names
+ block: 'a','b'..., current block label, used for generating layer names
+ strides: Strides for the second conv layer in the block.
+ use_l2_regularizer: whether to use L2 regularizer on Conv layer.
+ batch_norm_decay: Moment of batch norm layers.
+ batch_norm_epsilon: Epsilon of batch borm layers.
+
+ Returns:
+ Output tensor for the block.
+ """
+ filters1, filters2, filters3 = filters
+ if tf_keras.backend.image_data_format() == 'channels_last':
+ bn_axis = 3
+ else:
+ bn_axis = 1
+ conv_name_base = 'res' + str(stage) + block + '_branch'
+ bn_name_base = 'bn' + str(stage) + block + '_branch'
+
+ x = layers.Conv2D(
+ filters1, (1, 1),
+ use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name=conv_name_base + '2a')(
+ input_tensor)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name=bn_name_base + '2a')(
+ x)
+ x = layers.Activation('relu')(x)
+
+ x = layers.Conv2D(
+ filters2,
+ kernel_size,
+ strides=strides,
+ padding='same',
+ use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name=conv_name_base + '2b')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name=bn_name_base + '2b')(
+ x)
+ x = layers.Activation('relu')(x)
+
+ x = layers.Conv2D(
+ filters3, (1, 1),
+ use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name=conv_name_base + '2c')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name=bn_name_base + '2c')(
+ x)
+
+ shortcut = layers.Conv2D(
+ filters3, (1, 1),
+ strides=strides,
+ use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name=conv_name_base + '1')(
+ input_tensor)
+ shortcut = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name=bn_name_base + '1')(
+ shortcut)
+
+ x = layers.add([x, shortcut])
+ x = layers.Activation('relu')(x)
+ return x
+
+
+def resnet50(num_classes,
+ batch_size=None,
+ use_l2_regularizer=True,
+ rescale_inputs=False,
+ batch_norm_decay=0.9,
+ batch_norm_epsilon=1e-5):
+ """Instantiates the ResNet50 architecture.
+
+ Args:
+ num_classes: `int` number of classes for image classification.
+ batch_size: Size of the batches for each step.
+ use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.
+ rescale_inputs: whether to rescale inputs from 0 to 1.
+ batch_norm_decay: Moment of batch norm layers.
+ batch_norm_epsilon: Epsilon of batch borm layers.
+
+ Returns:
+ A Keras model instance.
+ """
+ input_shape = (224, 224, 3)
+ img_input = layers.Input(shape=input_shape, batch_size=batch_size)
+ if rescale_inputs:
+ # Hub image modules expect inputs in the range [0, 1]. This rescales these
+ # inputs to the range expected by the trained model.
+ x = layers.Lambda(
+ lambda x: x * 255.0 - tf_keras.backend.constant( # pylint: disable=g-long-lambda
+ imagenet_preprocessing.CHANNEL_MEANS,
+ shape=[1, 1, 3],
+ dtype=x.dtype),
+ name='rescale')(
+ img_input)
+ else:
+ x = img_input
+
+ if tf_keras.backend.image_data_format() == 'channels_first':
+ x = layers.Permute((3, 1, 2))(x)
+ bn_axis = 1
+ else: # channels_last
+ bn_axis = 3
+
+ block_config = dict(
+ use_l2_regularizer=use_l2_regularizer,
+ batch_norm_decay=batch_norm_decay,
+ batch_norm_epsilon=batch_norm_epsilon)
+ x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
+ x = layers.Conv2D(
+ 64, (7, 7),
+ strides=(2, 2),
+ padding='valid',
+ use_bias=False,
+ kernel_initializer='he_normal',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='conv1')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name='bn_conv1')(
+ x)
+ x = layers.Activation('relu')(x)
+ x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
+
+ x = conv_block(
+ x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1), **block_config)
+ x = identity_block(x, 3, [64, 64, 256], stage=2, block='b', **block_config)
+ x = identity_block(x, 3, [64, 64, 256], stage=2, block='c', **block_config)
+
+ x = conv_block(x, 3, [128, 128, 512], stage=3, block='a', **block_config)
+ x = identity_block(x, 3, [128, 128, 512], stage=3, block='b', **block_config)
+ x = identity_block(x, 3, [128, 128, 512], stage=3, block='c', **block_config)
+ x = identity_block(x, 3, [128, 128, 512], stage=3, block='d', **block_config)
+
+ x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a', **block_config)
+ x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b', **block_config)
+ x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c', **block_config)
+ x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d', **block_config)
+ x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e', **block_config)
+ x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f', **block_config)
+
+ x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a', **block_config)
+ x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b', **block_config)
+ x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c', **block_config)
+
+ x = layers.GlobalAveragePooling2D()(x)
+ x = layers.Dense(
+ num_classes,
+ kernel_initializer=tf.initializers.random_normal(stddev=0.01),
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ bias_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='fc1000')(
+ x)
+
+ # A softmax that is followed by the model loss must be done cannot be done
+ # in float16 due to numeric issues. So we pass dtype=float32.
+ x = layers.Activation('softmax', dtype='float32')(x)
+
+ # Create model.
+ return tf_keras.Model(img_input, x, name='resnet50')
diff --git a/modeling/official/legacy/image_classification/resnet/resnet_runnable.py b/modeling/official/legacy/image_classification/resnet/resnet_runnable.py
new file mode 100644
index 0000000000000000000000000000000000000000..101a663533bbdf10d5b7c72e97aeb91dd5243262
--- /dev/null
+++ b/modeling/official/legacy/image_classification/resnet/resnet_runnable.py
@@ -0,0 +1,210 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Runs a ResNet model on the ImageNet dataset using custom training loops."""
+
+import orbit
+import tensorflow as tf, tf_keras
+from official.legacy.image_classification.resnet import common
+from official.legacy.image_classification.resnet import imagenet_preprocessing
+from official.legacy.image_classification.resnet import resnet_model
+from official.modeling import grad_utils
+from official.modeling import performance
+from official.utils.flags import core as flags_core
+
+
+class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
+ """Implements the training and evaluation APIs for Resnet model."""
+
+ def __init__(self, flags_obj, time_callback, epoch_steps):
+ self.strategy = tf.distribute.get_strategy()
+ self.flags_obj = flags_obj
+ self.dtype = flags_core.get_tf_dtype(flags_obj)
+ self.time_callback = time_callback
+
+ # Input pipeline related
+ batch_size = flags_obj.batch_size
+ if batch_size % self.strategy.num_replicas_in_sync != 0:
+ raise ValueError(
+ 'Batch size must be divisible by number of replicas : {}'.format(
+ self.strategy.num_replicas_in_sync))
+
+ # As auto rebatching is not supported in
+ # `distribute_datasets_from_function()` API, which is
+ # required when cloning dataset to multiple workers in eager mode,
+ # we use per-replica batch size.
+ self.batch_size = int(batch_size / self.strategy.num_replicas_in_sync)
+
+ if self.flags_obj.use_synthetic_data:
+ self.input_fn = common.get_synth_input_fn(
+ height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
+ width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
+ num_channels=imagenet_preprocessing.NUM_CHANNELS,
+ num_classes=imagenet_preprocessing.NUM_CLASSES,
+ dtype=self.dtype,
+ drop_remainder=True)
+ else:
+ self.input_fn = imagenet_preprocessing.input_fn
+
+ self.model = resnet_model.resnet50(
+ num_classes=imagenet_preprocessing.NUM_CLASSES,
+ use_l2_regularizer=not flags_obj.single_l2_loss_op)
+
+ lr_schedule = common.PiecewiseConstantDecayWithWarmup(
+ batch_size=flags_obj.batch_size,
+ epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
+ warmup_epochs=common.LR_SCHEDULE[0][1],
+ boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
+ multipliers=list(p[0] for p in common.LR_SCHEDULE),
+ compute_lr_on_cpu=True)
+ self.optimizer = common.get_optimizer(lr_schedule)
+ # Make sure iterations variable is created inside scope.
+ self.global_step = self.optimizer.iterations
+ self.optimizer = performance.configure_optimizer(
+ self.optimizer,
+ use_float16=self.dtype == tf.float16,
+ loss_scale=flags_core.get_loss_scale(flags_obj, default_for_fp16=128))
+
+ self.train_loss = tf_keras.metrics.Mean('train_loss', dtype=tf.float32)
+ self.train_accuracy = tf_keras.metrics.SparseCategoricalAccuracy(
+ 'train_accuracy', dtype=tf.float32)
+ self.test_loss = tf_keras.metrics.Mean('test_loss', dtype=tf.float32)
+ self.test_accuracy = tf_keras.metrics.SparseCategoricalAccuracy(
+ 'test_accuracy', dtype=tf.float32)
+
+ self.checkpoint = tf.train.Checkpoint(
+ model=self.model, optimizer=self.optimizer)
+
+ # Handling epochs.
+ self.epoch_steps = epoch_steps
+ self.epoch_helper = orbit.utils.EpochHelper(epoch_steps, self.global_step)
+ train_dataset = orbit.utils.make_distributed_dataset(
+ self.strategy,
+ self.input_fn,
+ is_training=True,
+ data_dir=self.flags_obj.data_dir,
+ batch_size=self.batch_size,
+ parse_record_fn=imagenet_preprocessing.parse_record,
+ datasets_num_private_threads=self.flags_obj
+ .datasets_num_private_threads,
+ dtype=self.dtype,
+ drop_remainder=True,
+ training_dataset_cache=self.flags_obj.training_dataset_cache)
+ orbit.StandardTrainer.__init__(
+ self,
+ train_dataset,
+ options=orbit.StandardTrainerOptions(
+ use_tf_while_loop=flags_obj.use_tf_while_loop,
+ use_tf_function=flags_obj.use_tf_function))
+ if not flags_obj.skip_eval:
+ eval_dataset = orbit.utils.make_distributed_dataset(
+ self.strategy,
+ self.input_fn,
+ is_training=False,
+ data_dir=self.flags_obj.data_dir,
+ batch_size=self.batch_size,
+ parse_record_fn=imagenet_preprocessing.parse_record,
+ dtype=self.dtype)
+ orbit.StandardEvaluator.__init__(
+ self,
+ eval_dataset,
+ options=orbit.StandardEvaluatorOptions(
+ use_tf_function=flags_obj.use_tf_function))
+
+ def train_loop_begin(self):
+ """See base class."""
+ # Reset all metrics
+ self.train_loss.reset_states()
+ self.train_accuracy.reset_states()
+
+ self._epoch_begin()
+ self.time_callback.on_batch_begin(self.epoch_helper.batch_index)
+
+ def train_step(self, iterator):
+ """See base class."""
+
+ def step_fn(inputs):
+ """Function to run on the device."""
+ images, labels = inputs
+ with tf.GradientTape() as tape:
+ logits = self.model(images, training=True)
+
+ prediction_loss = tf_keras.losses.sparse_categorical_crossentropy(
+ labels, logits)
+ loss = tf.reduce_sum(prediction_loss) * (1.0 /
+ self.flags_obj.batch_size)
+ num_replicas = self.strategy.num_replicas_in_sync
+ l2_weight_decay = 1e-4
+ if self.flags_obj.single_l2_loss_op:
+ l2_loss = l2_weight_decay * 2 * tf.add_n([
+ tf.nn.l2_loss(v)
+ for v in self.model.trainable_variables
+ if 'bn' not in v.name
+ ])
+
+ loss += (l2_loss / num_replicas)
+ else:
+ loss += (tf.reduce_sum(self.model.losses) / num_replicas)
+
+ grad_utils.minimize_using_explicit_allreduce(
+ tape, self.optimizer, loss, self.model.trainable_variables)
+ self.train_loss.update_state(loss)
+ self.train_accuracy.update_state(labels, logits)
+ if self.flags_obj.enable_xla:
+ step_fn = tf.function(step_fn, jit_compile=True)
+ self.strategy.run(step_fn, args=(next(iterator),))
+
+ def train_loop_end(self):
+ """See base class."""
+ metrics = {
+ 'train_loss': self.train_loss.result(),
+ 'train_accuracy': self.train_accuracy.result(),
+ }
+ self.time_callback.on_batch_end(self.epoch_helper.batch_index - 1)
+ self._epoch_end()
+ return metrics
+
+ def eval_begin(self):
+ """See base class."""
+ self.test_loss.reset_states()
+ self.test_accuracy.reset_states()
+
+ def eval_step(self, iterator):
+ """See base class."""
+
+ def step_fn(inputs):
+ """Function to run on the device."""
+ images, labels = inputs
+ logits = self.model(images, training=False)
+ loss = tf_keras.losses.sparse_categorical_crossentropy(labels, logits)
+ loss = tf.reduce_sum(loss) * (1.0 / self.flags_obj.batch_size)
+ self.test_loss.update_state(loss)
+ self.test_accuracy.update_state(labels, logits)
+
+ self.strategy.run(step_fn, args=(next(iterator),))
+
+ def eval_end(self):
+ """See base class."""
+ return {
+ 'test_loss': self.test_loss.result(),
+ 'test_accuracy': self.test_accuracy.result()
+ }
+
+ def _epoch_begin(self):
+ if self.epoch_helper.epoch_begin():
+ self.time_callback.on_epoch_begin(self.epoch_helper.current_epoch)
+
+ def _epoch_end(self):
+ if self.epoch_helper.epoch_end():
+ self.time_callback.on_epoch_end(self.epoch_helper.current_epoch)
diff --git a/modeling/official/legacy/image_classification/resnet/tfhub_export.py b/modeling/official/legacy/image_classification/resnet/tfhub_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c319da358970f2c1e30ca0c74828cacb25ba48f
--- /dev/null
+++ b/modeling/official/legacy/image_classification/resnet/tfhub_export.py
@@ -0,0 +1,66 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A script to export TF-Hub SavedModel."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+# Import libraries
+from absl import app
+from absl import flags
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.image_classification.resnet import imagenet_preprocessing
+from official.legacy.image_classification.resnet import resnet_model
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("model_path", None,
+ "File path to TF model checkpoint or H5 file.")
+flags.DEFINE_string("export_path", None,
+ "TF-Hub SavedModel destination path to export.")
+
+
+def export_tfhub(model_path, hub_destination):
+ """Restores a tf_keras.Model and saves for TF-Hub."""
+ model = resnet_model.resnet50(
+ num_classes=imagenet_preprocessing.NUM_CLASSES, rescale_inputs=True)
+ model.load_weights(model_path)
+ model.save(
+ os.path.join(hub_destination, "classification"), include_optimizer=False)
+
+ # Extracts a sub-model to use pooling feature vector as model output.
+ image_input = model.get_layer(index=0).get_output_at(0)
+ feature_vector_output = model.get_layer(name="reduce_mean").get_output_at(0)
+ hub_model = tf_keras.Model(image_input, feature_vector_output)
+
+ # Exports a SavedModel.
+ hub_model.save(
+ os.path.join(hub_destination, "feature-vector"), include_optimizer=False)
+
+
+def main(argv):
+ if len(argv) > 1:
+ raise app.UsageError("Too many command-line arguments.")
+
+ export_tfhub(FLAGS.model_path, FLAGS.export_path)
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/modeling/official/legacy/image_classification/test_utils.py b/modeling/official/legacy/image_classification/test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b3cabd5f09fe054e1100d57d615283b4d26333f
--- /dev/null
+++ b/modeling/official/legacy/image_classification/test_utils.py
@@ -0,0 +1,37 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test utilities for image classification tasks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf, tf_keras
+
+
+def trivial_model(num_classes):
+ """Trivial model for ImageNet dataset."""
+
+ input_shape = (224, 224, 3)
+ img_input = tf_keras.layers.Input(shape=input_shape)
+
+ x = tf_keras.layers.Lambda(
+ lambda x: tf_keras.backend.reshape(x, [-1, 224 * 224 * 3]),
+ name='reshape')(img_input)
+ x = tf_keras.layers.Dense(1, name='fc1')(x)
+ x = tf_keras.layers.Dense(num_classes, name='fc1000')(x)
+ x = tf_keras.layers.Activation('softmax', dtype='float32')(x)
+
+ return tf_keras.models.Model(img_input, x, name='trivial')
diff --git a/modeling/official/legacy/image_classification/vgg/__init__.py b/modeling/official/legacy/image_classification/vgg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f338592c943c69c8ca66bc1f0981a619ea10e27
--- /dev/null
+++ b/modeling/official/legacy/image_classification/vgg/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
diff --git a/modeling/official/legacy/image_classification/vgg/vgg_config.py b/modeling/official/legacy/image_classification/vgg/vgg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..61cd34e2a804e9b5e27a05c80171c64099e25e0c
--- /dev/null
+++ b/modeling/official/legacy/image_classification/vgg/vgg_config.py
@@ -0,0 +1,55 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Configuration definitions for VGG losses, learning rates, and optimizers."""
+
+import dataclasses
+from official.legacy.image_classification.configs import base_configs
+from official.modeling.hyperparams import base_config
+
+
+@dataclasses.dataclass
+class VGGModelConfig(base_configs.ModelConfig):
+ """Configuration for the VGG model."""
+ name: str = 'VGG'
+ num_classes: int = 1000
+ model_params: base_config.Config = dataclasses.field(default_factory=lambda: { # pylint:disable=g-long-lambda
+ 'num_classes': 1000,
+ 'batch_size': None,
+ 'use_l2_regularizer': True
+ })
+ loss: base_configs.LossConfig = dataclasses.field(
+ default_factory=lambda: base_configs.LossConfig( # pylint: disable=g-long-lambda
+ name='sparse_categorical_crossentropy'
+ )
+ )
+ optimizer: base_configs.OptimizerConfig = dataclasses.field(
+ default_factory=lambda: base_configs.OptimizerConfig( # pylint: disable=g-long-lambda
+ name='momentum',
+ epsilon=0.001,
+ momentum=0.9,
+ moving_average_decay=None,
+ )
+ )
+ learning_rate: base_configs.LearningRateConfig = dataclasses.field(
+ default_factory=lambda: base_configs.LearningRateConfig( # pylint: disable=g-long-lambda
+ name='stepwise',
+ initial_lr=0.01,
+ examples_per_epoch=1281167,
+ boundaries=[30, 60],
+ warmup_epochs=0,
+ scale_by_batch_size=1.0 / 256.0,
+ multipliers=[0.01 / 256, 0.001 / 256, 0.0001 / 256],
+ )
+ )
diff --git a/modeling/official/legacy/image_classification/vgg/vgg_model.py b/modeling/official/legacy/image_classification/vgg/vgg_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f48cf08a1f75a0e90f6da73518a6fa15ca628d68
--- /dev/null
+++ b/modeling/official/legacy/image_classification/vgg/vgg_model.py
@@ -0,0 +1,269 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""VGG16 model for Keras.
+
+Adapted from tf_keras.applications.vgg16.VGG16().
+
+Related papers/blogs:
+- https://arxiv.org/abs/1409.1556
+"""
+
+import tensorflow as tf, tf_keras
+
+layers = tf_keras.layers
+
+
+def _gen_l2_regularizer(use_l2_regularizer=True, l2_weight_decay=1e-4):
+ return tf_keras.regularizers.L2(
+ l2_weight_decay) if use_l2_regularizer else None
+
+
+def vgg16(num_classes,
+ batch_size=None,
+ use_l2_regularizer=True,
+ batch_norm_decay=0.9,
+ batch_norm_epsilon=1e-5):
+ """Instantiates the VGG16 architecture.
+
+ Args:
+ num_classes: `int` number of classes for image classification.
+ batch_size: Size of the batches for each step.
+ use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.
+ batch_norm_decay: Moment of batch norm layers.
+ batch_norm_epsilon: Epsilon of batch borm layers.
+
+ Returns:
+ A Keras model instance.
+
+ """
+ input_shape = (224, 224, 3)
+ img_input = layers.Input(shape=input_shape, batch_size=batch_size)
+
+ x = img_input
+
+ if tf_keras.backend.image_data_format() == 'channels_first':
+ x = layers.Permute((3, 1, 2))(x)
+ bn_axis = 1
+ else: # channels_last
+ bn_axis = 3
+ # Block 1
+ x = layers.Conv2D(
+ 64, (3, 3),
+ padding='same',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='block1_conv1')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name='bn_conv1')(
+ x)
+ x = layers.Activation('relu')(x)
+ x = layers.Conv2D(
+ 64, (3, 3),
+ padding='same',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='block1_conv2')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name='bn_conv2')(
+ x)
+ x = layers.Activation('relu')(x)
+ x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(x)
+
+ # Block 2
+ x = layers.Conv2D(
+ 128, (3, 3),
+ padding='same',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='block2_conv1')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name='bn_conv3')(
+ x)
+ x = layers.Activation('relu')(x)
+ x = layers.Conv2D(
+ 128, (3, 3),
+ padding='same',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='block2_conv2')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name='bn_conv4')(
+ x)
+ x = layers.Activation('relu')(x)
+ x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(x)
+
+ # Block 3
+ x = layers.Conv2D(
+ 256, (3, 3),
+ padding='same',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='block3_conv1')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name='bn_conv5')(
+ x)
+ x = layers.Activation('relu')(x)
+ x = layers.Conv2D(
+ 256, (3, 3),
+ padding='same',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='block3_conv2')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name='bn_conv6')(
+ x)
+ x = layers.Activation('relu')(x)
+ x = layers.Conv2D(
+ 256, (3, 3),
+ padding='same',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='block3_conv3')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name='bn_conv7')(
+ x)
+ x = layers.Activation('relu')(x)
+ x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(x)
+
+ # Block 4
+ x = layers.Conv2D(
+ 512, (3, 3),
+ padding='same',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='block4_conv1')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name='bn_conv8')(
+ x)
+ x = layers.Activation('relu')(x)
+ x = layers.Conv2D(
+ 512, (3, 3),
+ padding='same',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='block4_conv2')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name='bn_conv9')(
+ x)
+ x = layers.Activation('relu')(x)
+ x = layers.Conv2D(
+ 512, (3, 3),
+ padding='same',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='block4_conv3')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name='bn_conv10')(
+ x)
+ x = layers.Activation('relu')(x)
+ x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(x)
+
+ # Block 5
+ x = layers.Conv2D(
+ 512, (3, 3),
+ padding='same',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='block5_conv1')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name='bn_conv11')(
+ x)
+ x = layers.Activation('relu')(x)
+ x = layers.Conv2D(
+ 512, (3, 3),
+ padding='same',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='block5_conv2')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name='bn_conv12')(
+ x)
+ x = layers.Activation('relu')(x)
+ x = layers.Conv2D(
+ 512, (3, 3),
+ padding='same',
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='block5_conv3')(
+ x)
+ x = layers.BatchNormalization(
+ axis=bn_axis,
+ momentum=batch_norm_decay,
+ epsilon=batch_norm_epsilon,
+ name='bn_conv13')(
+ x)
+ x = layers.Activation('relu')(x)
+ x = layers.MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool')(x)
+
+ x = layers.Flatten(name='flatten')(x)
+ x = layers.Dense(
+ 4096,
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='fc1')(
+ x)
+ x = layers.Activation('relu')(x)
+ x = layers.Dropout(0.5)(x)
+ x = layers.Dense(
+ 4096,
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='fc2')(
+ x)
+ x = layers.Activation('relu')(x)
+ x = layers.Dropout(0.5)(x)
+ x = layers.Dense(
+ num_classes,
+ kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+ name='fc1000')(
+ x)
+
+ x = layers.Activation('softmax', dtype='float32')(x)
+
+ # Create model.
+ return tf_keras.Model(img_input, x, name='vgg16')
diff --git a/modeling/official/legacy/transformer/README.md b/modeling/official/legacy/transformer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1edf4f85963a8386010321bdc6aec6701b89be22
--- /dev/null
+++ b/modeling/official/legacy/transformer/README.md
@@ -0,0 +1,220 @@
+# Transformer Translation Model
+This is an implementation of the Transformer translation model as described in
+the [Attention is All You Need](https://arxiv.org/abs/1706.03762) paper. The
+implementation leverages tf.keras and makes sure it is compatible with TF 2.x.
+
+**Warning: the features in the `transformer/` folder have been fully intergrated
+into nlp/modeling.
+Due to its dependencies, we will remove this folder after the model
+garden 2.5 release. The model in `nlp/modeling/models/seq2seq_transformer.py` is
+identical to the model in this folder.**
+
+## Contents
+ * [Contents](#contents)
+ * [Walkthrough](#walkthrough)
+ * [Detailed instructions](#detailed-instructions)
+ * [Environment preparation](#environment-preparation)
+ * [Download and preprocess datasets](#download-and-preprocess-datasets)
+ * [Model training and evaluation](#model-training-and-evaluation)
+ * [Implementation overview](#implementation-overview)
+ * [Model Definition](#model-definition)
+ * [Model Trainer](#model-trainer)
+ * [Test dataset](#test-dataset)
+
+## Walkthrough
+
+Below are the commands for running the Transformer model. See the
+[Detailed instructions](#detailed-instructions) for more details on running the
+model.
+
+```
+# Ensure that PYTHONPATH is correctly defined as described in
+# https://github.com/tensorflow/models/tree/master/official#requirements
+export PYTHONPATH="$PYTHONPATH:/path/to/models"
+
+cd /path/to/models/official/legacy/transformer
+
+# Export variables
+PARAM_SET=big
+DATA_DIR=$HOME/transformer/data
+MODEL_DIR=$HOME/transformer/model_$PARAM_SET
+VOCAB_FILE=$DATA_DIR/vocab.ende.32768
+
+# Download training/evaluation/test datasets
+python3 data_download.py --data_dir=$DATA_DIR
+
+# Train the model for 100000 steps and evaluate every 5000 steps on a single GPU.
+# Each train step, takes 4096 tokens as a batch budget with 64 as sequence
+# maximal length.
+python3 transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
+ --vocab_file=$VOCAB_FILE --param_set=$PARAM_SET \
+ --train_steps=100000 --steps_between_evals=5000 \
+ --batch_size=4096 --max_length=64 \
+ --bleu_source=$DATA_DIR/newstest2014.en \
+ --bleu_ref=$DATA_DIR/newstest2014.de \
+ --num_gpus=1 \
+ --enable_time_history=false
+
+# Run during training in a separate process to get continuous updates,
+# or after training is complete.
+tensorboard --logdir=$MODEL_DIR
+```
+
+## Detailed instructions
+
+
+0. ### Environment preparation
+
+ #### Add models repo to PYTHONPATH
+ Follow the instructions described in the [Requirements](https://github.com/tensorflow/models/tree/master/official#requirements) section to add the models folder to the python path.
+
+ #### Export variables (optional)
+
+ Export the following variables, or modify the values in each of the snippets below:
+
+ ```shell
+ PARAM_SET=big
+ DATA_DIR=$HOME/transformer/data
+ MODEL_DIR=$HOME/transformer/model_$PARAM_SET
+ VOCAB_FILE=$DATA_DIR/vocab.ende.32768
+ ```
+
+1. ### Download and preprocess datasets
+
+ [data_download.py](data_download.py) downloads and preprocesses the training and evaluation WMT datasets. After the data is downloaded and extracted, the training data is used to generate a vocabulary of subtokens. The evaluation and training strings are tokenized, and the resulting data is sharded, shuffled, and saved as TFRecords.
+
+ 1.75GB of compressed data will be downloaded. In total, the raw files (compressed, extracted, and combined files) take up 8.4GB of disk space. The resulting TFRecord and vocabulary files are 722MB. The script takes around 40 minutes to run, with the bulk of the time spent downloading and ~15 minutes spent on preprocessing.
+
+ Command to run:
+ ```
+ python3 data_download.py --data_dir=$DATA_DIR
+ ```
+
+ Arguments:
+ * `--data_dir`: Path where the preprocessed TFRecord data, and vocab file will be saved.
+ * Use the `--help` or `-h` flag to get a full list of possible arguments.
+
+2. ### Model training and evaluation
+
+ [transformer_main.py](transformer_main.py) creates a Transformer keras model,
+ and trains it uses keras model.fit().
+
+ Users need to adjust `batch_size` and `num_gpus` to get good performance
+ running multiple GPUs.
+
+ **Note that:**
+ when using multiple GPUs or TPUs, this is the global batch size for all
+ devices. For example, if the batch size is `4096*4` and there are 4 devices,
+ each device will take 4096 tokens as a batch budget.
+
+ Command to run:
+ ```
+ python3 transformer_main.py --data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
+ --vocab_file=$VOCAB_FILE --param_set=$PARAM_SET
+ ```
+
+ Arguments:
+ * `--data_dir`: This should be set to the same directory given to the `data_download`'s `data_dir` argument.
+ * `--model_dir`: Directory to save Transformer model training checkpoints.
+ * `--vocab_file`: Path to subtoken vocabulary file. If data_download was used, you may find the file in `data_dir`.
+ * `--param_set`: Parameter set to use when creating and training the model. Options are `base` and `big` (default).
+ * `--enable_time_history`: Whether add TimeHistory call. If so, --log_steps must be specified.
+ * `--batch_size`: The number of tokens to consider in a batch. Combining with
+ `--max_length`, they decide how many sequences are used per batch.
+ * Use the `--help` or `-h` flag to get a full list of possible arguments.
+
+ #### Using multiple GPUs
+ You can train these models on multiple GPUs using `tf.distribute.Strategy` API.
+ You can read more about them in this
+ [guide](https://www.tensorflow.org/guide/distribute_strategy).
+
+ In this example, we have made it easier to use is with just a command line flag
+ `--num_gpus`. By default this flag is 1 if TensorFlow is compiled with CUDA,
+ and 0 otherwise.
+
+ - --num_gpus=0: Uses tf.distribute.OneDeviceStrategy with CPU as the device.
+ - --num_gpus=1: Uses tf.distribute.OneDeviceStrategy with GPU as the device.
+ - --num_gpus=2+: Uses tf.distribute.MirroredStrategy to run synchronous
+ distributed training across the GPUs.
+
+ #### Using Cloud TPUs
+
+ You can train the Transformer model on Cloud TPUs using
+ `tf.distribute.TPUStrategy`. If you are not familiar with Cloud TPUs, it is
+ strongly recommended that you go through the
+ [quickstart](https://cloud.google.com/tpu/docs/quickstart) to learn how to
+ create a TPU and GCE VM.
+
+ To run the Transformer model on a TPU, you must set
+ `--distribution_strategy=tpu`, `--tpu=$TPU_NAME`, and `--use_ctl=True` where
+ `$TPU_NAME` the name of your TPU in the Cloud Console.
+
+ An example command to run Transformer on a v2-8 or v3-8 TPU would be:
+
+ ```bash
+ python transformer_main.py \
+ --tpu=$TPU_NAME \
+ --model_dir=$MODEL_DIR \
+ --data_dir=$DATA_DIR \
+ --vocab_file=$DATA_DIR/vocab.ende.32768 \
+ --bleu_source=$DATA_DIR/newstest2014.en \
+ --bleu_ref=$DATA_DIR/newstest2014.end \
+ --batch_size=6144 \
+ --train_steps=2000 \
+ --static_batch=true \
+ --use_ctl=true \
+ --param_set=big \
+ --max_length=64 \
+ --decode_batch_size=32 \
+ --decode_max_length=97 \
+ --padded_decode=true \
+ --distribution_strategy=tpu
+ ```
+ Note: `$MODEL_DIR` and `$DATA_DIR` must be GCS paths.
+
+ #### Customizing training schedule
+
+ By default, the model will train for 10 epochs, and evaluate after every epoch. The training schedule may be defined through the flags:
+
+ * Training with steps:
+ * `--train_steps`: sets the total number of training steps to run.
+ * `--steps_between_evals`: Number of training steps to run between evaluations.
+
+ #### Compute BLEU score during model evaluation
+
+ Use these flags to compute the BLEU when the model evaluates:
+
+ * `--bleu_source`: Path to file containing text to translate.
+ * `--bleu_ref`: Path to file containing the reference translation.
+
+ When running `transformer_main.py`, use the flags: `--bleu_source=$DATA_DIR/newstest2014.en --bleu_ref=$DATA_DIR/newstest2014.de`
+
+ #### Tensorboard
+ Training and evaluation metrics (loss, accuracy, approximate BLEU score, etc.) are logged, and can be displayed in the browser using Tensorboard.
+ ```
+ tensorboard --logdir=$MODEL_DIR
+ ```
+ The values are displayed at [localhost:6006](localhost:6006).
+
+## Implementation overview
+
+A brief look at each component in the code:
+
+### Model Definition
+* [transformer.py](transformer.py): Defines a tf.keras.Model: `Transformer`.
+* [embedding_layer.py](embedding_layer.py): Contains the layer that calculates the embeddings. The embedding weights are also used to calculate the pre-softmax probabilities from the decoder output.
+* [attention_layer.py](attention_layer.py): Defines the multi-headed and self attention layers that are used in the encoder/decoder stacks.
+* [ffn_layer.py](ffn_layer.py): Defines the feedforward network that is used in the encoder/decoder stacks. The network is composed of 2 fully connected layers.
+
+Other files:
+* [beam_search.py](beam_search.py) contains the beam search implementation, which is used during model inference to find high scoring translations.
+
+### Model Trainer
+[transformer_main.py](transformer_main.py) creates an `TransformerTask` to train and evaluate the model using tf.keras.
+
+### Test dataset
+The [newstest2014 files](https://storage.googleapis.com/tf-perf-public/official_transformer/test_data/newstest2014.tgz)
+are extracted from the [NMT Seq2Seq tutorial](https://google.github.io/seq2seq/nmt/#download-data).
+The raw text files are converted from the SGM format of the
+[WMT 2016](http://www.statmt.org/wmt16/translation-task.html) test sets. The
+newstest2014 files are put into the `$DATA_DIR` when executing `data_download.py`
diff --git a/modeling/official/legacy/transformer/__init__.py b/modeling/official/legacy/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/legacy/transformer/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/legacy/transformer/attention_layer.py b/modeling/official/legacy/transformer/attention_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..50e11253afe00d2b3f0f576c9586e9970e95d06c
--- /dev/null
+++ b/modeling/official/legacy/transformer/attention_layer.py
@@ -0,0 +1,178 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implementation of multiheaded attention and self-attention layers."""
+import math
+
+import tensorflow as tf, tf_keras
+
+from official.modeling import tf_utils
+
+
+class Attention(tf_keras.layers.Layer):
+ """Multi-headed attention layer."""
+
+ def __init__(self, hidden_size, num_heads, attention_dropout):
+ """Initialize Attention.
+
+ Args:
+ hidden_size: int, output dim of hidden layer.
+ num_heads: int, number of heads to repeat the same attention structure.
+ attention_dropout: float, dropout rate inside attention for training.
+ """
+ if hidden_size % num_heads:
+ raise ValueError(
+ "Hidden size ({}) must be divisible by the number of heads ({})."
+ .format(hidden_size, num_heads))
+
+ super(Attention, self).__init__()
+ self.hidden_size = hidden_size
+ self.num_heads = num_heads
+ self.attention_dropout = attention_dropout
+
+ def build(self, input_shape):
+ """Builds the layer."""
+ # Layers for linearly projecting the queries, keys, and values.
+ size_per_head = self.hidden_size // self.num_heads
+
+ def _glorot_initializer(fan_in, fan_out):
+ limit = math.sqrt(6.0 / (fan_in + fan_out))
+ return tf_keras.initializers.RandomUniform(minval=-limit, maxval=limit)
+
+ attention_initializer = _glorot_initializer(input_shape.as_list()[-1],
+ self.hidden_size)
+ self.query_dense_layer = tf_keras.layers.EinsumDense(
+ "BTE,ENH->BTNH",
+ output_shape=(None, self.num_heads, size_per_head),
+ kernel_initializer=tf_utils.clone_initializer(attention_initializer),
+ bias_axes=None,
+ name="query")
+ self.key_dense_layer = tf_keras.layers.EinsumDense(
+ "BTE,ENH->BTNH",
+ output_shape=(None, self.num_heads, size_per_head),
+ kernel_initializer=tf_utils.clone_initializer(attention_initializer),
+ bias_axes=None,
+ name="key")
+ self.value_dense_layer = tf_keras.layers.EinsumDense(
+ "BTE,ENH->BTNH",
+ output_shape=(None, self.num_heads, size_per_head),
+ kernel_initializer=tf_utils.clone_initializer(attention_initializer),
+ bias_axes=None,
+ name="value")
+
+ output_initializer = _glorot_initializer(self.hidden_size, self.hidden_size)
+ self.output_dense_layer = tf_keras.layers.EinsumDense(
+ "BTNH,NHE->BTE",
+ output_shape=(None, self.hidden_size),
+ kernel_initializer=output_initializer,
+ bias_axes=None,
+ name="output_transform")
+ super(Attention, self).build(input_shape)
+
+ def get_config(self):
+ return {
+ "hidden_size": self.hidden_size,
+ "num_heads": self.num_heads,
+ "attention_dropout": self.attention_dropout,
+ }
+
+ def call(self,
+ query_input,
+ source_input,
+ bias,
+ training,
+ cache=None,
+ decode_loop_step=None):
+ """Apply attention mechanism to query_input and source_input.
+
+ Args:
+ query_input: A tensor with shape [batch_size, length_query, hidden_size].
+ source_input: A tensor with shape [batch_size, length_source,
+ hidden_size].
+ bias: A tensor with shape [batch_size, 1, length_query, length_source],
+ the attention bias that will be added to the result of the dot product.
+ training: A bool, whether in training mode or not.
+ cache: (Used during prediction) A dictionary with tensors containing
+ results of previous attentions. The dictionary must have the items:
+ {"k": tensor with shape [batch_size, i, heads, dim_per_head],
+ "v": tensor with shape [batch_size, i, heads, dim_per_head]} where
+ i is the current decoded length for non-padded decode, or max
+ sequence length for padded decode.
+ decode_loop_step: An integer, step number of the decoding loop. Used only
+ for autoregressive inference on TPU.
+
+ Returns:
+ Attention layer output with shape [batch_size, length_query, hidden_size]
+ """
+ # Linearly project the query, key and value using different learned
+ # projections. Splitting heads is automatically done during the linear
+ # projections --> [batch_size, length, num_heads, dim_per_head].
+ query = self.query_dense_layer(query_input)
+ key = self.key_dense_layer(source_input)
+ value = self.value_dense_layer(source_input)
+
+ if cache is not None:
+ # Combine cached keys and values with new keys and values.
+ if decode_loop_step is not None:
+ cache_k_shape = cache["k"].shape.as_list()
+ indices = tf.reshape(
+ tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
+ [1, cache_k_shape[1], 1, 1])
+ key = cache["k"] + key * indices
+ cache_v_shape = cache["v"].shape.as_list()
+ indices = tf.reshape(
+ tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
+ [1, cache_v_shape[1], 1, 1])
+ value = cache["v"] + value * indices
+ else:
+ key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)
+ value = tf.concat([tf.cast(cache["v"], value.dtype), value], axis=1)
+
+ # Update cache
+ cache["k"] = key
+ cache["v"] = value
+
+ # Scale query to prevent the dot product between query and key from growing
+ # too large.
+ depth = (self.hidden_size // self.num_heads)
+ query *= depth**-0.5
+
+ # Calculate dot product attention
+ logits = tf.einsum("BTNH,BFNH->BNFT", key, query)
+ logits += bias
+ # Note that softmax internally performs math operations using float32
+ # for numeric stability. When training with float16, we keep the input
+ # and output in float16 for better performance.
+ weights = tf.nn.softmax(logits, name="attention_weights")
+ if training:
+ weights = tf.nn.dropout(weights, rate=self.attention_dropout)
+ attention_output = tf.einsum("BNFT,BTNH->BFNH", weights, value)
+
+ # Run the outputs through another linear projection layer. Recombining heads
+ # is automatically done --> [batch_size, length, hidden_size]
+ attention_output = self.output_dense_layer(attention_output)
+ return attention_output
+
+
+class SelfAttention(Attention):
+ """Multiheaded self-attention layer."""
+
+ def call(self,
+ query_input,
+ bias,
+ training,
+ cache=None,
+ decode_loop_step=None):
+ return super(SelfAttention, self).call(query_input, query_input, bias,
+ training, cache, decode_loop_step)
diff --git a/modeling/official/legacy/transformer/beam_search_v1.py b/modeling/official/legacy/transformer/beam_search_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3299fe12a4820494e5249d4ea0e03f766c15798
--- /dev/null
+++ b/modeling/official/legacy/transformer/beam_search_v1.py
@@ -0,0 +1,82 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Beam search to find the translated sequence with the highest probability."""
+
+import tensorflow.compat.v1 as tf
+from official.nlp.modeling.ops import beam_search
+
+_StateKeys = beam_search._StateKeys # pylint: disable=protected-access
+
+
+class SequenceBeamSearch(beam_search.SequenceBeamSearch):
+ """Implementation of beam search loop."""
+
+ def _process_finished_state(self, finished_state):
+ alive_seq = finished_state[_StateKeys.ALIVE_SEQ]
+ alive_log_probs = finished_state[_StateKeys.ALIVE_LOG_PROBS]
+ finished_seq = finished_state[_StateKeys.FINISHED_SEQ]
+ finished_scores = finished_state[_StateKeys.FINISHED_SCORES]
+ finished_flags = finished_state[_StateKeys.FINISHED_FLAGS]
+
+ # Account for corner case where there are no finished sequences for a
+ # particular batch item. In that case, return alive sequences for that batch
+ # item.
+ finished_seq = tf.where(
+ tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
+ finished_scores = tf.where(
+ tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
+ return finished_seq, finished_scores
+
+
+def sequence_beam_search(symbols_to_logits_fn,
+ initial_ids,
+ initial_cache,
+ vocab_size,
+ beam_size,
+ alpha,
+ max_decode_length,
+ eos_id,
+ padded_decode=False):
+ """Search for sequence of subtoken ids with the largest probability.
+
+ Args:
+ symbols_to_logits_fn: A function that takes in ids, index, and cache as
+ arguments. The passed in arguments will have shape: ids -> A tensor with
+ shape [batch_size * beam_size, index]. index -> A scalar. cache -> A
+ nested dictionary of tensors [batch_size * beam_size, ...].
+ The function must return a tuple of logits and new cache: logits -> A
+ tensor with shape [batch * beam_size, vocab_size]. new cache -> A nested
+ dictionary with the same shape/structure as the inputted cache.
+ initial_ids: An int32 tensor with shape [batch_size]. Starting ids for each
+ batch item.
+ initial_cache: A dictionary, containing starting decoder variables
+ information.
+ vocab_size: An integer, the size of the vocabulary, used for topk
+ computation.
+ beam_size: An integer, the number of beams.
+ alpha: A float, defining the strength of length normalization.
+ max_decode_length: An integer, the maximum length to decoded a sequence.
+ eos_id: An integer, ID of eos token, used to determine when a sequence has
+ finished.
+ padded_decode: A bool, indicating if max_sequence_length padding is used for
+ beam search.
+
+ Returns:
+ Top decoded sequences [batch_size, beam_size, max_decode_length]
+ sequence scores [batch_size, beam_size]
+ """
+ sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, beam_size, alpha,
+ max_decode_length, eos_id, padded_decode)
+ return sbs.search(initial_ids, initial_cache)
diff --git a/modeling/official/legacy/transformer/compute_bleu.py b/modeling/official/legacy/transformer/compute_bleu.py
new file mode 100644
index 0000000000000000000000000000000000000000..2df04fb6d025fa3746138c81786a854e943048aa
--- /dev/null
+++ b/modeling/official/legacy/transformer/compute_bleu.py
@@ -0,0 +1,148 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Script to compute official BLEU score.
+
+Source:
+https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py
+"""
+
+import re
+import sys
+import unicodedata
+
+from absl import app
+from absl import flags
+from absl import logging
+import six
+from six.moves import range
+import tensorflow as tf, tf_keras
+
+from official.legacy.transformer.utils import metrics
+from official.legacy.transformer.utils import tokenizer
+from official.utils.flags import core as flags_core
+
+
+class UnicodeRegex(object):
+ """Ad-hoc hack to recognize all punctuation and symbols."""
+
+ def __init__(self):
+ punctuation = self.property_chars("P")
+ self.nondigit_punct_re = re.compile(r"([^\d])([" + punctuation + r"])")
+ self.punct_nondigit_re = re.compile(r"([" + punctuation + r"])([^\d])")
+ self.symbol_re = re.compile("([" + self.property_chars("S") + "])")
+
+ def property_chars(self, prefix):
+ return "".join(
+ six.unichr(x)
+ for x in range(sys.maxunicode)
+ if unicodedata.category(six.unichr(x)).startswith(prefix))
+
+
+uregex = UnicodeRegex()
+
+
+def bleu_tokenize(string):
+ r"""Tokenize a string following the official BLEU implementation.
+
+ See https://github.com/moses-smt/mosesdecoder/'
+ 'blob/master/scripts/generic/mteval-v14.pl#L954-L983
+ In our case, the input string is expected to be just one line
+ and no HTML entities de-escaping is needed.
+ So we just tokenize on punctuation and symbols,
+ except when a punctuation is preceded and followed by a digit
+ (e.g. a comma/dot as a thousand/decimal separator).
+
+ Note that a numer (e.g. a year) followed by a dot at the end of sentence
+ is NOT tokenized,
+ i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g`
+ does not match this case (unless we add a space after each sentence).
+ However, this error is already in the original mteval-v14.pl
+ and we want to be consistent with it.
+
+ Args:
+ string: the input string
+
+ Returns:
+ a list of tokens
+ """
+ string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string)
+ string = uregex.punct_nondigit_re.sub(r" \1 \2", string)
+ string = uregex.symbol_re.sub(r" \1 ", string)
+ return string.split()
+
+
+def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
+ """Compute BLEU for two files (reference and hypothesis translation)."""
+ ref_lines = tokenizer.native_to_unicode(
+ tf.io.gfile.GFile(ref_filename).read()).strip().splitlines()
+ hyp_lines = tokenizer.native_to_unicode(
+ tf.io.gfile.GFile(hyp_filename).read()).strip().splitlines()
+ return bleu_on_list(ref_lines, hyp_lines, case_sensitive)
+
+
+def bleu_on_list(ref_lines, hyp_lines, case_sensitive=False):
+ """Compute BLEU for two list of strings (reference and hypothesis)."""
+ if len(ref_lines) != len(hyp_lines):
+ raise ValueError(
+ "Reference and translation files have different number of "
+ "lines (%d VS %d). If training only a few steps (100-200), the "
+ "translation may be empty." % (len(ref_lines), len(hyp_lines)))
+ if not case_sensitive:
+ ref_lines = [x.lower() for x in ref_lines]
+ hyp_lines = [x.lower() for x in hyp_lines]
+ ref_tokens = [bleu_tokenize(x) for x in ref_lines]
+ hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
+ return metrics.compute_bleu(ref_tokens, hyp_tokens) * 100
+
+
+def main(unused_argv):
+ if FLAGS.bleu_variant in ("both", "uncased"):
+ score = bleu_wrapper(FLAGS.reference, FLAGS.translation, False)
+ logging.info("Case-insensitive results: %f", score)
+
+ if FLAGS.bleu_variant in ("both", "cased"):
+ score = bleu_wrapper(FLAGS.reference, FLAGS.translation, True)
+ logging.info("Case-sensitive results: %f", score)
+
+
+def define_compute_bleu_flags():
+ """Add flags for computing BLEU score."""
+ flags.DEFINE_string(
+ name="translation",
+ default=None,
+ help=flags_core.help_wrap("File containing translated text."))
+ flags.mark_flag_as_required("translation")
+
+ flags.DEFINE_string(
+ name="reference",
+ default=None,
+ help=flags_core.help_wrap("File containing reference translation."))
+ flags.mark_flag_as_required("reference")
+
+ flags.DEFINE_enum(
+ name="bleu_variant",
+ short_name="bv",
+ default="both",
+ enum_values=["both", "uncased", "cased"],
+ case_sensitive=False,
+ help=flags_core.help_wrap(
+ "Specify one or more BLEU variants to calculate. Variants: \"cased\""
+ ", \"uncased\", or \"both\"."))
+
+
+if __name__ == "__main__":
+ define_compute_bleu_flags()
+ FLAGS = flags.FLAGS
+ app.run(main)
diff --git a/modeling/official/legacy/transformer/compute_bleu_test.py b/modeling/official/legacy/transformer/compute_bleu_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9ae2716a13c47369d5f59527ed22030e03dc2af
--- /dev/null
+++ b/modeling/official/legacy/transformer/compute_bleu_test.py
@@ -0,0 +1,72 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test functions in compute_blue.py."""
+
+import tempfile
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.transformer import compute_bleu
+
+
+class ComputeBleuTest(tf.test.TestCase):
+
+ def _create_temp_file(self, text):
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
+ with tf.io.gfile.GFile(temp_file.name, "w") as w:
+ w.write(text)
+ return temp_file.name
+
+ def test_bleu_same(self):
+ ref = self._create_temp_file("test 1 two 3\nmore tests!")
+ hyp = self._create_temp_file("test 1 two 3\nmore tests!")
+
+ uncased_score = compute_bleu.bleu_wrapper(ref, hyp, False)
+ cased_score = compute_bleu.bleu_wrapper(ref, hyp, True)
+ self.assertEqual(100, uncased_score)
+ self.assertEqual(100, cased_score)
+
+ def test_bleu_same_different_case(self):
+ ref = self._create_temp_file("Test 1 two 3\nmore tests!")
+ hyp = self._create_temp_file("test 1 two 3\nMore tests!")
+ uncased_score = compute_bleu.bleu_wrapper(ref, hyp, False)
+ cased_score = compute_bleu.bleu_wrapper(ref, hyp, True)
+ self.assertEqual(100, uncased_score)
+ self.assertLess(cased_score, 100)
+
+ def test_bleu_different(self):
+ ref = self._create_temp_file("Testing\nmore tests!")
+ hyp = self._create_temp_file("Dog\nCat")
+ uncased_score = compute_bleu.bleu_wrapper(ref, hyp, False)
+ cased_score = compute_bleu.bleu_wrapper(ref, hyp, True)
+ self.assertLess(uncased_score, 100)
+ self.assertLess(cased_score, 100)
+
+ def test_bleu_tokenize(self):
+ s = "Test0, 1 two, 3"
+ tokenized = compute_bleu.bleu_tokenize(s)
+ self.assertEqual(["Test0", ",", "1", "two", ",", "3"], tokenized)
+
+ def test_bleu_list(self):
+ ref = ["test 1 two 3", "more tests!"]
+ hyp = ["test 1 two 3", "More tests!"]
+ uncased_score = compute_bleu.bleu_on_list(ref, hyp, False)
+ cased_score = compute_bleu.bleu_on_list(ref, hyp, True)
+ self.assertEqual(uncased_score, 100)
+ self.assertLess(cased_score, 100)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/legacy/transformer/data_download.py b/modeling/official/legacy/transformer/data_download.py
new file mode 100644
index 0000000000000000000000000000000000000000..3da9359725efabd53b35dfe5e8af36b71ac29c18
--- /dev/null
+++ b/modeling/official/legacy/transformer/data_download.py
@@ -0,0 +1,443 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Download and preprocess WMT17 ende training and evaluation datasets."""
+
+import os
+import random
+import tarfile
+
+# pylint: disable=g-bad-import-order
+
+from absl import app
+from absl import flags
+from absl import logging
+import six
+from six.moves import range
+from six.moves import urllib
+from six.moves import zip
+import tensorflow.compat.v1 as tf
+
+from official.legacy.transformer.utils import tokenizer
+from official.utils.flags import core as flags_core
+# pylint: enable=g-bad-import-order
+
+# Data sources for training/evaluating the transformer translation model.
+# If any of the training sources are changed, then either:
+# 1) use the flag `--search` to find the best min count or
+# 2) update the _TRAIN_DATA_MIN_COUNT constant.
+# min_count is the minimum number of times a token must appear in the data
+# before it is added to the vocabulary. "Best min count" refers to the value
+# that generates a vocabulary set that is closest in size to _TARGET_VOCAB_SIZE.
+_TRAIN_DATA_SOURCES = [
+ {
+ "url": "http://data.statmt.org/wmt17/translation-task/"
+ "training-parallel-nc-v12.tgz",
+ "input": "news-commentary-v12.de-en.en",
+ "target": "news-commentary-v12.de-en.de",
+ },
+ {
+ "url": "http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz",
+ "input": "commoncrawl.de-en.en",
+ "target": "commoncrawl.de-en.de",
+ },
+ {
+ "url": "http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz",
+ "input": "europarl-v7.de-en.en",
+ "target": "europarl-v7.de-en.de",
+ },
+]
+# Use pre-defined minimum count to generate subtoken vocabulary.
+_TRAIN_DATA_MIN_COUNT = 6
+
+_EVAL_DATA_SOURCES = [{
+ "url": "http://data.statmt.org/wmt17/translation-task/dev.tgz",
+ "input": "newstest2013.en",
+ "target": "newstest2013.de",
+}]
+
+_TEST_DATA_SOURCES = [{
+ "url": ("https://storage.googleapis.com/cloud-tpu-test-datasets/"
+ "transformer_data/newstest2014.tgz"),
+ "input": "newstest2014.en",
+ "target": "newstest2014.de",
+}]
+
+# Vocabulary constants
+_TARGET_VOCAB_SIZE = 32768 # Number of subtokens in the vocabulary list.
+_TARGET_THRESHOLD = 327 # Accept vocabulary if size is within this threshold
+VOCAB_FILE = "vocab.ende.%d" % _TARGET_VOCAB_SIZE
+
+# Strings to inclue in the generated files.
+_PREFIX = "wmt32k"
+_TRAIN_TAG = "train"
+_EVAL_TAG = "dev" # Following WMT and Tensor2Tensor conventions, in which the
+# evaluation datasets are tagged as "dev" for development.
+
+# Number of files to split train and evaluation data
+_TRAIN_SHARDS = 100
+_EVAL_SHARDS = 1
+
+
+def find_file(path, filename, max_depth=5):
+ """Returns full filepath if the file is in path or a subdirectory."""
+ for root, dirs, files in os.walk(path):
+ if filename in files:
+ return os.path.join(root, filename)
+
+ # Don't search past max_depth
+ depth = root[len(path) + 1:].count(os.sep)
+ if depth > max_depth:
+ del dirs[:] # Clear dirs
+ return None
+
+
+###############################################################################
+# Download and extraction functions
+###############################################################################
+def get_raw_files(raw_dir, data_source):
+ """Return raw files from source.
+
+ Downloads/extracts if needed.
+
+ Args:
+ raw_dir: string directory to store raw files
+ data_source: dictionary with
+ {"url": url of compressed dataset containing input and target files
+ "input": file with data in input language
+ "target": file with data in target language}
+
+ Returns:
+ dictionary with
+ {"inputs": list of files containing data in input language
+ "targets": list of files containing corresponding data in target language
+ }
+ """
+ raw_files = {
+ "inputs": [],
+ "targets": [],
+ } # keys
+ for d in data_source:
+ input_file, target_file = download_and_extract(raw_dir, d["url"],
+ d["input"], d["target"])
+ raw_files["inputs"].append(input_file)
+ raw_files["targets"].append(target_file)
+ return raw_files
+
+
+def download_report_hook(count, block_size, total_size):
+ """Report hook for download progress.
+
+ Args:
+ count: current block number
+ block_size: block size
+ total_size: total size
+ """
+ percent = int(count * block_size * 100 / total_size)
+ print(six.ensure_str("\r%d%%" % percent) + " completed", end="\r")
+
+
+def download_from_url(path, url):
+ """Download content from a url.
+
+ Args:
+ path: string directory where file will be downloaded
+ url: string url
+
+ Returns:
+ Full path to downloaded file
+ """
+ filename = six.ensure_str(url).split("/")[-1]
+ found_file = find_file(path, filename, max_depth=0)
+ if found_file is None:
+ filename = os.path.join(path, filename)
+ logging.info("Downloading from %s to %s.", url, filename)
+ inprogress_filepath = six.ensure_str(filename) + ".incomplete"
+ inprogress_filepath, _ = urllib.request.urlretrieve(
+ url, inprogress_filepath, reporthook=download_report_hook)
+ # Print newline to clear the carriage return from the download progress.
+ print()
+ tf.gfile.Rename(inprogress_filepath, filename)
+ return filename
+ else:
+ logging.info("Already downloaded: %s (at %s).", url, found_file)
+ return found_file
+
+
+def download_and_extract(path, url, input_filename, target_filename):
+ """Extract files from downloaded compressed archive file.
+
+ Args:
+ path: string directory where the files will be downloaded
+ url: url containing the compressed input and target files
+ input_filename: name of file containing data in source language
+ target_filename: name of file containing data in target language
+
+ Returns:
+ Full paths to extracted input and target files.
+
+ Raises:
+ OSError: if the download/extraction fails.
+ """
+ # Check if extracted files already exist in path
+ input_file = find_file(path, input_filename)
+ target_file = find_file(path, target_filename)
+ if input_file and target_file:
+ logging.info("Already downloaded and extracted %s.", url)
+ return input_file, target_file
+
+ # Download archive file if it doesn't already exist.
+ compressed_file = download_from_url(path, url)
+
+ # Extract compressed files
+ logging.info("Extracting %s.", compressed_file)
+ with tarfile.open(compressed_file, "r:gz") as corpus_tar:
+ corpus_tar.extractall(path)
+
+ # Return file paths of the requested files.
+ input_file = find_file(path, input_filename)
+ target_file = find_file(path, target_filename)
+
+ if input_file and target_file:
+ return input_file, target_file
+
+ raise OSError("Download/extraction failed for url %s to path %s" %
+ (url, path))
+
+
+def txt_line_iterator(path):
+ """Iterate through lines of file."""
+ with tf.io.gfile.GFile(path) as f:
+ for line in f:
+ yield line.strip()
+
+
+def compile_files(raw_dir, raw_files, tag):
+ """Compile raw files into a single file for each language.
+
+ Args:
+ raw_dir: Directory containing downloaded raw files.
+ raw_files: Dict containing filenames of input and target data.
+ {"inputs": list of files containing data in input language
+ "targets": list of files containing corresponding data in target language
+ }
+ tag: String to append to the compiled filename.
+
+ Returns:
+ Full path of compiled input and target files.
+ """
+ logging.info("Compiling files with tag %s.", tag)
+ filename = "%s-%s" % (_PREFIX, tag)
+ input_compiled_file = os.path.join(raw_dir,
+ six.ensure_str(filename) + ".lang1")
+ target_compiled_file = os.path.join(raw_dir,
+ six.ensure_str(filename) + ".lang2")
+
+ with tf.io.gfile.GFile(input_compiled_file, mode="w") as input_writer:
+ with tf.io.gfile.GFile(target_compiled_file, mode="w") as target_writer:
+ for i in range(len(raw_files["inputs"])):
+ input_file = raw_files["inputs"][i]
+ target_file = raw_files["targets"][i]
+
+ logging.info("Reading files %s and %s.", input_file, target_file)
+ write_file(input_writer, input_file)
+ write_file(target_writer, target_file)
+ return input_compiled_file, target_compiled_file
+
+
+def write_file(writer, filename):
+ """Write all of lines from file using the writer."""
+ for line in txt_line_iterator(filename):
+ writer.write(line)
+ writer.write("\n")
+
+
+###############################################################################
+# Data preprocessing
+###############################################################################
+def encode_and_save_files(subtokenizer, data_dir, raw_files, tag, total_shards):
+ """Save data from files as encoded Examples in TFrecord format.
+
+ Args:
+ subtokenizer: Subtokenizer object that will be used to encode the strings.
+ data_dir: The directory in which to write the examples
+ raw_files: A tuple of (input, target) data files. Each line in the input and
+ the corresponding line in target file will be saved in a tf.Example.
+ tag: String that will be added onto the file names.
+ total_shards: Number of files to divide the data into.
+
+ Returns:
+ List of all files produced.
+ """
+ # Create a file for each shard.
+ filepaths = [
+ shard_filename(data_dir, tag, n + 1, total_shards)
+ for n in range(total_shards)
+ ]
+
+ if all_exist(filepaths):
+ logging.info("Files with tag %s already exist.", tag)
+ return filepaths
+
+ logging.info("Saving files with tag %s.", tag)
+ input_file = raw_files[0]
+ target_file = raw_files[1]
+
+ # Write examples to each shard in round robin order.
+ tmp_filepaths = [six.ensure_str(fname) + ".incomplete" for fname in filepaths]
+ writers = [tf.python_io.TFRecordWriter(fname) for fname in tmp_filepaths]
+ counter, shard = 0, 0
+ for counter, (input_line, target_line) in enumerate(
+ zip(txt_line_iterator(input_file), txt_line_iterator(target_file))):
+ if counter > 0 and counter % 100000 == 0:
+ logging.info("\tSaving case %d.", counter)
+ example = dict_to_example({
+ "inputs": subtokenizer.encode(input_line, add_eos=True),
+ "targets": subtokenizer.encode(target_line, add_eos=True)
+ })
+ writers[shard].write(example.SerializeToString())
+ shard = (shard + 1) % total_shards
+ for writer in writers:
+ writer.close()
+
+ for tmp_name, final_name in zip(tmp_filepaths, filepaths):
+ tf.gfile.Rename(tmp_name, final_name)
+
+ logging.info("Saved %d Examples", counter + 1)
+ return filepaths
+
+
+def shard_filename(path, tag, shard_num, total_shards):
+ """Create filename for data shard."""
+ return os.path.join(
+ path, "%s-%s-%.5d-of-%.5d" % (_PREFIX, tag, shard_num, total_shards))
+
+
+def shuffle_records(fname):
+ """Shuffle records in a single file."""
+ logging.info("Shuffling records in file %s", fname)
+
+ # Rename file prior to shuffling
+ tmp_fname = six.ensure_str(fname) + ".unshuffled"
+ tf.gfile.Rename(fname, tmp_fname)
+
+ reader = tf.io.tf_record_iterator(tmp_fname)
+ records = []
+ for record in reader:
+ records.append(record)
+ if len(records) % 100000 == 0:
+ logging.info("\tRead: %d", len(records))
+
+ random.shuffle(records)
+
+ # Write shuffled records to original file name
+ with tf.python_io.TFRecordWriter(fname) as w:
+ for count, record in enumerate(records):
+ w.write(record)
+ if count > 0 and count % 100000 == 0:
+ logging.info("\tWriting record: %d", count)
+
+ tf.gfile.Remove(tmp_fname)
+
+
+def dict_to_example(dictionary):
+ """Converts a dictionary of string->int to a tf.Example."""
+ features = {}
+ for k, v in six.iteritems(dictionary):
+ features[k] = tf.train.Feature(int64_list=tf.train.Int64List(value=v))
+ return tf.train.Example(features=tf.train.Features(feature=features))
+
+
+def all_exist(filepaths):
+ """Returns true if all files in the list exist."""
+ for fname in filepaths:
+ if not tf.gfile.Exists(fname):
+ return False
+ return True
+
+
+def make_dir(path):
+ if not tf.gfile.Exists(path):
+ logging.info("Creating directory %s", path)
+ tf.gfile.MakeDirs(path)
+
+
+def main(unused_argv):
+ """Obtain training and evaluation data for the Transformer model."""
+ make_dir(FLAGS.raw_dir)
+ make_dir(FLAGS.data_dir)
+
+ # Download test_data
+ logging.info("Step 1/5: Downloading test data")
+ get_raw_files(FLAGS.data_dir, _TEST_DATA_SOURCES)
+
+ # Get paths of download/extracted training and evaluation files.
+ logging.info("Step 2/5: Downloading data from source")
+ train_files = get_raw_files(FLAGS.raw_dir, _TRAIN_DATA_SOURCES)
+ eval_files = get_raw_files(FLAGS.raw_dir, _EVAL_DATA_SOURCES)
+
+ # Create subtokenizer based on the training files.
+ logging.info("Step 3/5: Creating subtokenizer and building vocabulary")
+ train_files_flat = train_files["inputs"] + train_files["targets"]
+ vocab_file = os.path.join(FLAGS.data_dir, VOCAB_FILE)
+ subtokenizer = tokenizer.Subtokenizer.init_from_files(
+ vocab_file,
+ train_files_flat,
+ _TARGET_VOCAB_SIZE,
+ _TARGET_THRESHOLD,
+ min_count=None if FLAGS.search else _TRAIN_DATA_MIN_COUNT)
+
+ logging.info("Step 4/5: Compiling training and evaluation data")
+ compiled_train_files = compile_files(FLAGS.raw_dir, train_files, _TRAIN_TAG)
+ compiled_eval_files = compile_files(FLAGS.raw_dir, eval_files, _EVAL_TAG)
+
+ # Tokenize and save data as Examples in the TFRecord format.
+ logging.info("Step 5/5: Preprocessing and saving data")
+ train_tfrecord_files = encode_and_save_files(subtokenizer, FLAGS.data_dir,
+ compiled_train_files, _TRAIN_TAG,
+ _TRAIN_SHARDS)
+ encode_and_save_files(subtokenizer, FLAGS.data_dir, compiled_eval_files,
+ _EVAL_TAG, _EVAL_SHARDS)
+
+ for fname in train_tfrecord_files:
+ shuffle_records(fname)
+
+
+def define_data_download_flags():
+ """Add flags specifying data download arguments."""
+ flags.DEFINE_string(
+ name="data_dir",
+ short_name="dd",
+ default="/tmp/translate_ende",
+ help=flags_core.help_wrap(
+ "Directory for where the translate_ende_wmt32k dataset is saved."))
+ flags.DEFINE_string(
+ name="raw_dir",
+ short_name="rd",
+ default="/tmp/translate_ende_raw",
+ help=flags_core.help_wrap(
+ "Path where the raw data will be downloaded and extracted."))
+ flags.DEFINE_bool(
+ name="search",
+ default=False,
+ help=flags_core.help_wrap(
+ "If set, use binary search to find the vocabulary set with size"
+ "closest to the target size (%d)." % _TARGET_VOCAB_SIZE))
+
+
+if __name__ == "__main__":
+ logging.set_verbosity(logging.INFO)
+ define_data_download_flags()
+ FLAGS = flags.FLAGS
+ app.run(main)
diff --git a/modeling/official/legacy/transformer/data_pipeline.py b/modeling/official/legacy/transformer/data_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbcf8694047c167a12b371cef25547e2a85137b6
--- /dev/null
+++ b/modeling/official/legacy/transformer/data_pipeline.py
@@ -0,0 +1,330 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Input pipeline for the transformer model to read, filter, and batch examples.
+
+Two things to note in the pipeline:
+
+1. Batching scheme
+
+ The examples encoded in the TFRecord files contain data in the format:
+ {"inputs": [variable length array of integers],
+ "targets": [variable length array of integers]}
+ Where integers in the arrays refer to tokens in the English and German vocab
+ file (named `vocab.ende.32768`).
+
+ Prior to batching, elements in the dataset are grouped by length (max between
+ "inputs" and "targets" length). Each group is then batched such that:
+ group_batch_size * length <= batch_size.
+
+ Another way to view batch_size is the maximum number of tokens in each batch.
+
+ Once batched, each element in the dataset will have the shape:
+ {"inputs": [group_batch_size, padded_input_length],
+ "targets": [group_batch_size, padded_target_length]}
+ Lengths are padded to the longest "inputs" or "targets" sequence in the batch
+ (padded_input_length and padded_target_length can be different).
+
+ This batching scheme decreases the fraction of padding tokens per training
+ batch, thus improving the training speed significantly.
+
+2. Shuffling
+
+ While training, the dataset is shuffled in two places in the code. The first
+ is the list of training files. Second, while reading records using
+ `parallel_interleave`, the `sloppy` argument is used to generate randomness
+ in the order of the examples.
+"""
+
+import os
+
+from absl import logging
+import tensorflow as tf, tf_keras
+
+from official.utils.misc import model_helpers
+
+# Buffer size for reading records from a TFRecord file. Each training file is
+# 7.2 MB, so 8 MB allows an entire file to be kept in memory.
+_READ_RECORD_BUFFER = 8 * 1000 * 1000
+
+# Example grouping constants. Defines length boundaries for each group.
+# These values are the defaults used in Tensor2Tensor.
+_MIN_BOUNDARY = 8
+_BOUNDARY_SCALE = 1.1
+
+
+def _load_records(filename):
+ """Read file and return a dataset of tf.Examples."""
+ return tf.data.TFRecordDataset(filename, buffer_size=_READ_RECORD_BUFFER)
+
+
+def _parse_example(serialized_example):
+ """Return inputs and targets Tensors from a serialized tf.Example."""
+ data_fields = {
+ "inputs": tf.io.VarLenFeature(tf.int64),
+ "targets": tf.io.VarLenFeature(tf.int64)
+ }
+ parsed = tf.io.parse_single_example(serialized_example, data_fields)
+ inputs = tf.sparse.to_dense(parsed["inputs"])
+ targets = tf.sparse.to_dense(parsed["targets"])
+ return inputs, targets
+
+
+def _filter_max_length(example, max_length=256):
+ """Indicates whether the example's length is lower than the maximum length."""
+ return tf.logical_and(
+ tf.size(example[0]) <= max_length,
+ tf.size(example[1]) <= max_length)
+
+
+def _get_example_length(example):
+ """Returns the maximum length between the example inputs and targets."""
+ length = tf.maximum(tf.shape(example[0])[0], tf.shape(example[1])[0])
+ return length
+
+
+def _create_min_max_boundaries(max_length,
+ min_boundary=_MIN_BOUNDARY,
+ boundary_scale=_BOUNDARY_SCALE):
+ """Create min and max boundary lists up to max_length.
+
+ For example, when max_length=24, min_boundary=4 and boundary_scale=2, the
+ returned values will be:
+ buckets_min = [0, 4, 8, 16, 24]
+ buckets_max = [4, 8, 16, 24, 25]
+
+ Args:
+ max_length: The maximum length of example in dataset.
+ min_boundary: Minimum length in boundary.
+ boundary_scale: Amount to scale consecutive boundaries in the list.
+
+ Returns:
+ min and max boundary lists
+
+ """
+ # Create bucket boundaries list by scaling the previous boundary or adding 1
+ # (to ensure increasing boundary sizes).
+ bucket_boundaries = []
+ x = min_boundary
+ while x < max_length:
+ bucket_boundaries.append(x)
+ x = max(x + 1, int(x * boundary_scale))
+
+ # Create min and max boundary lists from the initial list.
+ buckets_min = [0] + bucket_boundaries
+ buckets_max = bucket_boundaries + [max_length + 1]
+ return buckets_min, buckets_max
+
+
+def _batch_examples(dataset, batch_size, max_length):
+ """Group examples by similar lengths, and return batched dataset.
+
+ Each batch of similar-length examples are padded to the same length, and may
+ have different number of elements in each batch, such that:
+ group_batch_size * padded_length <= batch_size.
+
+ This decreases the number of padding tokens per batch, which improves the
+ training speed.
+
+ Args:
+ dataset: Dataset of unbatched examples.
+ batch_size: Max number of tokens per batch of examples.
+ max_length: Max number of tokens in an example input or target sequence.
+
+ Returns:
+ Dataset of batched examples with similar lengths.
+ """
+ # Get min and max boundary lists for each example. These are used to calculate
+ # the `bucket_id`, which is the index at which:
+ # buckets_min[bucket_id] <= len(example) < buckets_max[bucket_id]
+ # Note that using both min and max lists improves the performance.
+ buckets_min, buckets_max = _create_min_max_boundaries(max_length)
+
+ # Create list of batch sizes for each bucket_id, so that
+ # bucket_batch_size[bucket_id] * buckets_max[bucket_id] <= batch_size
+ bucket_batch_sizes = [int(batch_size) // x for x in buckets_max]
+ # bucket_id will be a tensor, so convert this list to a tensor as well.
+ bucket_batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)
+
+ def example_to_bucket_id(example_input, example_target):
+ """Return int64 bucket id for this example, calculated based on length."""
+ seq_length = _get_example_length((example_input, example_target))
+
+ # TODO(xunkai): investigate if removing code branching improves performance.
+ conditions_c = tf.logical_and(
+ tf.less_equal(buckets_min, seq_length), tf.less(seq_length,
+ buckets_max))
+ bucket_id = tf.reduce_min(tf.where(conditions_c))
+ return bucket_id
+
+ def window_size_fn(bucket_id):
+ """Return number of examples to be grouped when given a bucket id."""
+ return bucket_batch_sizes[bucket_id]
+
+ def batching_fn(bucket_id, grouped_dataset):
+ """Batch and add padding to a dataset of elements with similar lengths."""
+ bucket_batch_size = window_size_fn(bucket_id)
+
+ # Batch the dataset and add padding so that all input sequences in the
+ # examples have the same length, and all target sequences have the same
+ # lengths as well. Resulting lengths of inputs and targets can differ.
+ return grouped_dataset.padded_batch(bucket_batch_size, ([None], [None]))
+
+ return dataset.apply(
+ tf.data.experimental.group_by_window(
+ key_func=example_to_bucket_id,
+ reduce_func=batching_fn,
+ window_size=None,
+ window_size_func=window_size_fn))
+
+
+def _read_and_batch_from_files(file_pattern,
+ batch_size,
+ max_length,
+ max_io_parallelism,
+ shuffle,
+ repeat,
+ static_batch=False,
+ num_replicas=1,
+ ctx=None):
+ """Create dataset where each item is a dict of "inputs" and "targets".
+
+ Args:
+ file_pattern: String used to match the input TFRecord files.
+ batch_size: Maximum number of tokens per global batch of examples.
+ max_length: Maximum number of tokens per example
+ max_io_parallelism: Max number of cpu cores for parallel input processing.
+ shuffle: If true, randomizes order of elements.
+ repeat: Number of times to repeat the dataset. If None, the dataset is
+ repeated forever.
+ static_batch: Whether the batches in the dataset should have static shapes.
+ If True, the input is batched so that every batch has the shape
+ [batch_size // max_length, max_length]. If False, the input is grouped by
+ length, and batched so that batches may have different
+ shapes [N, M], where: N * M <= batch_size M <= max_length In general, this
+ setting should be False. Dynamic shapes allow the inputs to be grouped
+ so that the number of padding tokens is minimized, and helps model
+ training. In cases where the input shape must be static (e.g. running on
+ TPU), this setting should be set to True.
+ num_replicas: Number of GPUs or other workers. We will generate global
+ batches, and each global batch is equally divisible by number of replicas.
+ Currently it is only effective when static_batch==True. TODO: make it
+ effective when static_batch=False.
+ ctx: Input context.
+
+ Returns:
+ tf.data.Dataset object containing examples loaded from the files.
+ """
+ dataset = tf.data.Dataset.list_files(file_pattern, shuffle=shuffle)
+
+ if ctx and ctx.num_input_pipelines > 1:
+ logging.info("Shard %d of the dataset.", ctx.input_pipeline_id)
+ dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
+
+ # Read files and interleave results. When training, the order of the examples
+ # will be non-deterministic.
+ options = tf.data.Options()
+ options.experimental_deterministic = False
+ dataset = dataset.interleave(
+ _load_records,
+ cycle_length=max_io_parallelism,
+ num_parallel_calls=tf.data.experimental.AUTOTUNE).with_options(options)
+
+ # Parse each tf.Example into a dictionary
+ # TODO: Look into prefetch_input_elements for performance optimization. # pylint: disable=g-bad-todo
+ dataset = dataset.map(
+ _parse_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+ # Remove examples where the input or target length exceeds the maximum length,
+ dataset = dataset.filter(lambda x, y: _filter_max_length((x, y), max_length))
+
+ if static_batch:
+ dataset = dataset.padded_batch(
+ # First calculate batch size (token number) per worker, then divide it
+ # into sentences, and finally expand to a global batch. It could prove
+ # the global batch divisble for distribution strategy.
+ int(batch_size // num_replicas // max_length * num_replicas),
+ ([max_length], [max_length]),
+ drop_remainder=True)
+ else:
+ # Group and batch such that each batch has examples of similar length.
+ # TODO(xunkai): _batch_examples might need to do something special for
+ # num_replicas.
+ dataset = _batch_examples(dataset, batch_size, max_length)
+
+ dataset = dataset.repeat(repeat)
+
+ # Prefetch the next element to improve speed of input pipeline.
+ dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
+ return dataset
+
+
+def _generate_synthetic_data(params):
+ """Create synthetic data based on the parameter batch size."""
+ batch_size = int(params["batch_size"] // params["max_length"])
+ length = params["max_length"]
+ dataset = model_helpers.generate_synthetic_data(
+ input_shape=tf.TensorShape([length]),
+ input_value=1,
+ input_dtype=tf.int64,
+ label_shape=tf.TensorShape([length]),
+ label_value=1,
+ label_dtype=tf.int64,
+ )
+ if params["static_batch"]:
+ dataset = dataset.batch(batch_size, drop_remainder=True)
+ else:
+ dataset = dataset.padded_batch(batch_size, ([None], [None]))
+ return dataset
+
+
+def train_input_fn(params, ctx=None):
+ """Load and return dataset of batched examples for use during training."""
+ file_pattern = os.path.join(params["data_dir"] or "", "*train*")
+ if params["use_synthetic_data"]:
+ return _generate_synthetic_data(params)
+ return _read_and_batch_from_files(
+ file_pattern,
+ params["batch_size"],
+ params["max_length"],
+ params["max_io_parallelism"],
+ shuffle=True,
+ repeat=params["repeat_dataset"],
+ static_batch=params["static_batch"],
+ num_replicas=params["num_gpus"],
+ ctx=ctx)
+
+
+def eval_input_fn(params, ctx=None):
+ """Load and return dataset of batched examples for use during evaluation."""
+ file_pattern = os.path.join(params["data_dir"] or "", "*dev*")
+ if params["use_synthetic_data"]:
+ return _generate_synthetic_data(params)
+ return _read_and_batch_from_files(
+ file_pattern,
+ params["batch_size"],
+ params["max_length"],
+ params["max_io_parallelism"],
+ shuffle=False,
+ repeat=1,
+ static_batch=params["static_batch"],
+ num_replicas=params["num_gpus"],
+ ctx=ctx)
+
+
+def map_data_for_transformer_fn(x, y):
+ """Maps data for training, and handles weried behaviors for different vers."""
+ # Will transform input x and targets y into tuple(x, y) as new model inputs.
+ # For TF v2, the 2nd parameter is omitted to make Keras training work.
+ return ((x, y),)
diff --git a/modeling/official/legacy/transformer/embedding_layer.py b/modeling/official/legacy/transformer/embedding_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dbc285552d35b98d99b73d3057cedf6c628bb14
--- /dev/null
+++ b/modeling/official/legacy/transformer/embedding_layer.py
@@ -0,0 +1,102 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implementation of embedding layer with shared weights."""
+
+import tensorflow as tf, tf_keras
+
+
+class EmbeddingSharedWeights(tf_keras.layers.Layer):
+ """Calculates input embeddings and pre-softmax linear with shared weights."""
+
+ def __init__(self, vocab_size, hidden_size):
+ """Specify characteristic parameters of embedding layer.
+
+ Args:
+ vocab_size: Number of tokens in the embedding. (Typically ~32,000)
+ hidden_size: Dimensionality of the embedding. (Typically 512 or 1024)
+ """
+ super(EmbeddingSharedWeights, self).__init__()
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+
+ def build(self, input_shape):
+ """Build embedding layer."""
+ with tf.name_scope("embedding_and_softmax"):
+ # Create and initialize weights. The random normal initializer was chosen
+ # arbitrarily, and works well.
+ self.shared_weights = self.add_weight(
+ "weights",
+ shape=[self.vocab_size, self.hidden_size],
+ dtype=tf.float32,
+ initializer=tf.random_normal_initializer(
+ mean=0., stddev=self.hidden_size**-0.5))
+ super(EmbeddingSharedWeights, self).build(input_shape)
+
+ def get_config(self):
+ return {
+ "vocab_size": self.vocab_size,
+ "hidden_size": self.hidden_size,
+ }
+
+ def call(self, inputs, mode="embedding"):
+ """Get token embeddings of inputs.
+
+ Args:
+ inputs: An int64 tensor with shape [batch_size, length]
+ mode: string, a valid value is one of "embedding" and "linear".
+
+ Returns:
+ outputs: (1) If mode == "embedding", output embedding tensor, float32 with
+ shape [batch_size, length, embedding_size]; (2) mode == "linear", output
+ linear tensor, float32 with shape [batch_size, length, vocab_size].
+ Raises:
+ ValueError: if mode is not valid.
+ """
+ if mode == "embedding":
+ return self._embedding(inputs)
+ elif mode == "linear":
+ return self._linear(inputs)
+ else:
+ raise ValueError("mode {} is not valid.".format(mode))
+
+ def _embedding(self, inputs):
+ """Applies embedding based on inputs tensor."""
+ with tf.name_scope("embedding"):
+ # Create binary mask of size [batch_size, length]
+ embeddings = tf.gather(self.shared_weights, inputs)
+ # mask = tf.cast(tf.not_equal(inputs, 0), embeddings.dtype)
+ # embeddings *= tf.expand_dims(mask, -1)
+ # Scale embedding by the sqrt of the hidden size
+ embeddings *= self.hidden_size**0.5
+
+ return embeddings
+
+ def _linear(self, inputs):
+ """Computes logits by running inputs through a linear layer.
+
+ Args:
+ inputs: A float32 tensor with shape [batch_size, length, hidden_size]
+
+ Returns:
+ float32 tensor with shape [batch_size, length, vocab_size].
+ """
+ with tf.name_scope("presoftmax_linear"):
+ batch_size = tf.shape(inputs)[0]
+ length = tf.shape(inputs)[1]
+
+ x = tf.reshape(inputs, [-1, self.hidden_size])
+ logits = tf.matmul(x, self.shared_weights, transpose_b=True)
+
+ return tf.reshape(logits, [batch_size, length, self.vocab_size])
diff --git a/modeling/official/legacy/transformer/ffn_layer.py b/modeling/official/legacy/transformer/ffn_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..772eb694df6f6f1071ecb663b211260e7cdd7fbe
--- /dev/null
+++ b/modeling/official/legacy/transformer/ffn_layer.py
@@ -0,0 +1,71 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implementation of fully connected network."""
+
+import tensorflow as tf, tf_keras
+
+
+class FeedForwardNetwork(tf_keras.layers.Layer):
+ """Fully connected feedforward network."""
+
+ def __init__(self, hidden_size, filter_size, relu_dropout):
+ """Initialize FeedForwardNetwork.
+
+ Args:
+ hidden_size: int, output dim of hidden layer.
+ filter_size: int, filter size for the inner (first) dense layer.
+ relu_dropout: float, dropout rate for training.
+ """
+ super(FeedForwardNetwork, self).__init__()
+ self.hidden_size = hidden_size
+ self.filter_size = filter_size
+ self.relu_dropout = relu_dropout
+
+ def build(self, input_shape):
+ self.filter_dense_layer = tf_keras.layers.Dense(
+ self.filter_size,
+ use_bias=True,
+ activation=tf.nn.relu,
+ name="filter_layer")
+ self.output_dense_layer = tf_keras.layers.Dense(
+ self.hidden_size, use_bias=True, name="output_layer")
+ super(FeedForwardNetwork, self).build(input_shape)
+
+ def get_config(self):
+ return {
+ "hidden_size": self.hidden_size,
+ "filter_size": self.filter_size,
+ "relu_dropout": self.relu_dropout,
+ }
+
+ def call(self, x, training):
+ """Return outputs of the feedforward network.
+
+ Args:
+ x: tensor with shape [batch_size, length, hidden_size]
+ training: boolean, whether in training mode or not.
+
+ Returns:
+ Output of the feedforward network.
+ tensor with shape [batch_size, length, hidden_size]
+ """
+ # Retrieve dynamically known shapes
+
+ output = self.filter_dense_layer(x)
+ if training:
+ output = tf.nn.dropout(output, rate=self.relu_dropout)
+ output = self.output_dense_layer(output)
+
+ return output
diff --git a/modeling/official/legacy/transformer/metrics.py b/modeling/official/legacy/transformer/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..e70060207c6733d01a53c5ff1dd73a8257798ecd
--- /dev/null
+++ b/modeling/official/legacy/transformer/metrics.py
@@ -0,0 +1,180 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions for calculating loss, accuracy, and other model metrics.
+
+Metrics:
+ - Padded loss, accuracy, and negative log perplexity. Source:
+ https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/metrics.py
+ - BLEU approximation. Source:
+ https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py
+ - ROUGE score. Source:
+ https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/rouge.py
+"""
+
+import functools
+
+import tensorflow as tf, tf_keras
+
+
+def _pad_tensors_to_same_length(x, y):
+ """Pad x and y so that the results have the same length (second dimension)."""
+ with tf.name_scope("pad_to_same_length"):
+ x_length = tf.shape(x)[1]
+ y_length = tf.shape(y)[1]
+
+ max_length = tf.maximum(x_length, y_length)
+
+ x = tf.pad(x, [[0, 0], [0, max_length - x_length], [0, 0]])
+ y = tf.pad(y, [[0, 0], [0, max_length - y_length]])
+ return x, y
+
+
+def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
+ """Calculate cross entropy loss while ignoring padding.
+
+ Args:
+ logits: Tensor of size [batch_size, length_logits, vocab_size]
+ labels: Tensor of size [batch_size, length_labels]
+ smoothing: Label smoothing constant, used to determine the on and off values
+ vocab_size: int size of the vocabulary
+
+ Returns:
+ Returns the cross entropy loss and weight tensors: float32 tensors with
+ shape [batch_size, max(length_logits, length_labels)]
+ """
+ with tf.name_scope("loss"):
+ logits, labels = _pad_tensors_to_same_length(logits, labels)
+
+ # Calculate smoothing cross entropy
+ with tf.name_scope("smoothing_cross_entropy"):
+ confidence = 1.0 - smoothing
+ low_confidence = (1.0 - confidence) / tf.cast(vocab_size - 1, tf.float32)
+ soft_targets = tf.one_hot(
+ tf.cast(labels, tf.int32),
+ depth=vocab_size,
+ on_value=confidence,
+ off_value=low_confidence)
+ xentropy = tf.nn.softmax_cross_entropy_with_logits(
+ logits=logits, labels=soft_targets)
+
+ # Calculate the best (lowest) possible value of cross entropy, and
+ # subtract from the cross entropy loss.
+ normalizing_constant = -(
+ confidence * tf.math.log(confidence) +
+ tf.cast(vocab_size - 1, tf.float32) * low_confidence *
+ tf.math.log(low_confidence + 1e-20))
+ xentropy -= normalizing_constant
+
+ weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
+ return xentropy * weights, weights
+
+
+def padded_accuracy(logits, labels):
+ """Percentage of times that predictions matches labels on non-0s."""
+ with tf.name_scope("padded_accuracy"):
+ logits, labels = _pad_tensors_to_same_length(logits, labels)
+ weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
+ outputs = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
+ padded_labels = tf.cast(labels, tf.int32)
+ return tf.cast(tf.equal(outputs, padded_labels), tf.float32), weights
+
+
+def padded_accuracy_topk(logits, labels, k):
+ """Percentage of times that top-k predictions matches labels on non-0s."""
+ with tf.name_scope("padded_accuracy_topk"):
+ logits, labels = _pad_tensors_to_same_length(logits, labels)
+ weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
+ effective_k = tf.minimum(k, tf.shape(logits)[-1])
+ _, outputs = tf.nn.top_k(logits, k=effective_k)
+ outputs = tf.cast(outputs, tf.int32)
+ padded_labels = tf.cast(labels, tf.int32)
+ padded_labels = tf.expand_dims(padded_labels, axis=-1)
+ padded_labels += tf.zeros_like(outputs) # Pad to same shape.
+ same = tf.cast(tf.equal(outputs, padded_labels), tf.float32)
+ same_topk = tf.reduce_sum(same, axis=-1)
+ return same_topk, weights
+
+
+def padded_accuracy_top5(logits, labels):
+ return padded_accuracy_topk(logits, labels, 5)
+
+
+def padded_sequence_accuracy(logits, labels):
+ """Percentage of times that predictions matches labels everywhere (non-0)."""
+ with tf.name_scope("padded_sequence_accuracy"):
+ logits, labels = _pad_tensors_to_same_length(logits, labels)
+ weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
+ outputs = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
+ padded_labels = tf.cast(labels, tf.int32)
+ not_correct = tf.cast(tf.not_equal(outputs, padded_labels),
+ tf.float32) * weights
+ axis = list(range(1, len(outputs.get_shape())))
+ correct_seq = 1.0 - tf.minimum(1.0, tf.reduce_sum(not_correct, axis=axis))
+ return correct_seq, tf.constant(1.0)
+
+
+def padded_neg_log_perplexity(logits, labels, vocab_size):
+ """Average log-perplexity excluding padding 0s. No smoothing."""
+ num, den = padded_cross_entropy_loss(logits, labels, 0, vocab_size)
+ return -num, den
+
+
+class MetricLayer(tf_keras.layers.Layer):
+ """Custom a layer of metrics for Transformer model."""
+
+ def __init__(self, vocab_size):
+ super(MetricLayer, self).__init__()
+ self.vocab_size = vocab_size
+ self.metric_mean_fns = []
+
+ def build(self, input_shape):
+ """"Builds metric layer."""
+ neg_log_perplexity = functools.partial(
+ padded_neg_log_perplexity, vocab_size=self.vocab_size)
+ self.metric_mean_fns = [
+ (tf_keras.metrics.Mean("accuracy"), padded_accuracy),
+ (tf_keras.metrics.Mean("accuracy_top5"), padded_accuracy_top5),
+ (tf_keras.metrics.Mean("accuracy_per_sequence"),
+ padded_sequence_accuracy),
+ (tf_keras.metrics.Mean("neg_log_perplexity"), neg_log_perplexity),
+ ]
+ super(MetricLayer, self).build(input_shape)
+
+ def get_config(self):
+ return {"vocab_size": self.vocab_size}
+
+ def call(self, inputs):
+ logits, targets = inputs[0], inputs[1]
+ for mean, fn in self.metric_mean_fns:
+ m = mean(*fn(logits, targets))
+ self.add_metric(m)
+ return logits
+
+
+def transformer_loss(logits, labels, smoothing, vocab_size):
+ """Calculates total loss containing cross entropy with padding ignored.
+
+ Args:
+ logits: Tensor of size [batch_size, length_logits, vocab_size]
+ labels: Tensor of size [batch_size, length_labels]
+ smoothing: Label smoothing constant, used to determine the on and off values
+ vocab_size: int size of the vocabulary
+
+ Returns:
+ A scalar float tensor for loss.
+ """
+ xentropy, weights = padded_cross_entropy_loss(logits, labels, smoothing,
+ vocab_size)
+ return tf.reduce_sum(xentropy) / tf.reduce_sum(weights)
diff --git a/modeling/official/legacy/transformer/misc.py b/modeling/official/legacy/transformer/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..66f2e059d376f4253f1443c1e622fb00bd0080b2
--- /dev/null
+++ b/modeling/official/legacy/transformer/misc.py
@@ -0,0 +1,288 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Misc for Transformer."""
+
+# pylint: disable=g-bad-import-order
+
+from absl import flags
+import tensorflow as tf, tf_keras
+
+from official.legacy.transformer import model_params
+from official.utils.flags import core as flags_core
+from official.utils.misc import keras_utils
+
+FLAGS = flags.FLAGS
+
+PARAMS_MAP = {
+ 'tiny': model_params.TINY_PARAMS,
+ 'base': model_params.BASE_PARAMS,
+ 'big': model_params.BIG_PARAMS,
+}
+
+
+def get_model_params(param_set, num_gpus):
+ """Gets predefined model params."""
+ if num_gpus > 1:
+ if param_set == 'big':
+ return model_params.BIG_MULTI_GPU_PARAMS.copy()
+ elif param_set == 'base':
+ return model_params.BASE_MULTI_GPU_PARAMS.copy()
+ else:
+ raise ValueError('Not valid params: param_set={} num_gpus={}'.format(
+ param_set, num_gpus))
+
+ return PARAMS_MAP[param_set].copy()
+
+
+def define_transformer_flags():
+ """Add flags and flag validators for running transformer_main."""
+ # Add common flags (data_dir, model_dir, etc.).
+ flags_core.define_base(num_gpu=True, distribution_strategy=True)
+ flags_core.define_performance(
+ num_parallel_calls=True,
+ inter_op=False,
+ intra_op=False,
+ synthetic_data=True,
+ max_train_steps=False,
+ dtype=True,
+ loss_scale=True,
+ all_reduce_alg=True,
+ num_packs=True,
+ tf_gpu_thread_mode=True,
+ datasets_num_private_threads=True,
+ enable_xla=True,
+ fp16_implementation=True)
+
+ flags_core.define_benchmark()
+ flags_core.define_device(tpu=True)
+
+ flags.DEFINE_integer(
+ name='train_steps',
+ short_name='ts',
+ default=300000,
+ help=flags_core.help_wrap('The number of steps used to train.'))
+ flags.DEFINE_integer(
+ name='steps_between_evals',
+ short_name='sbe',
+ default=5000,
+ help=flags_core.help_wrap(
+ 'The Number of training steps to run between evaluations. This is '
+ 'used if --train_steps is defined.'))
+ flags.DEFINE_boolean(
+ name='enable_time_history',
+ default=True,
+ help='Whether to enable TimeHistory callback.')
+ flags.DEFINE_boolean(
+ name='enable_tensorboard',
+ default=False,
+ help='Whether to enable Tensorboard callback.')
+ flags.DEFINE_boolean(
+ name='enable_metrics_in_training',
+ default=False,
+ help='Whether to enable metrics during training.')
+ flags.DEFINE_boolean(
+ name='enable_mlir_bridge',
+ default=False,
+ help='Whether to enable the TF to XLA bridge.')
+ # Set flags from the flags_core module as 'key flags' so they're listed when
+ # the '-h' flag is used. Without this line, the flags defined above are
+ # only shown in the full `--helpful` help text.
+ flags.adopt_module_key_flags(flags_core)
+
+ # Add transformer-specific flags
+ flags.DEFINE_enum(
+ name='param_set',
+ short_name='mp',
+ default='big',
+ enum_values=PARAMS_MAP.keys(),
+ help=flags_core.help_wrap(
+ 'Parameter set to use when creating and training the model. The '
+ 'parameters define the input shape (batch size and max length), '
+ 'model configuration (size of embedding, # of hidden layers, etc.), '
+ 'and various other settings. The big parameter set increases the '
+ 'default batch size, embedding/hidden size, and filter size. For a '
+ 'complete list of parameters, please see model/model_params.py.'))
+
+ flags.DEFINE_bool(
+ name='static_batch',
+ short_name='sb',
+ default=False,
+ help=flags_core.help_wrap(
+ 'Whether the batches in the dataset should have static shapes. In '
+ 'general, this setting should be False. Dynamic shapes allow the '
+ 'inputs to be grouped so that the number of padding tokens is '
+ 'minimized, and helps model training. In cases where the input shape '
+ 'must be static (e.g. running on TPU), this setting will be ignored '
+ 'and static batching will always be used.'))
+ flags.DEFINE_integer(
+ name='max_length',
+ short_name='ml',
+ default=256,
+ help=flags_core.help_wrap(
+ 'Max sentence length for Transformer. Default is 256. Note: Usually '
+ 'it is more effective to use a smaller max length if static_batch is '
+ 'enabled, e.g. 64.'))
+
+ # Flags for training with steps (may be used for debugging)
+ flags.DEFINE_integer(
+ name='validation_steps',
+ short_name='vs',
+ default=64,
+ help=flags_core.help_wrap('The number of steps used in validation.'))
+
+ # BLEU score computation
+ flags.DEFINE_string(
+ name='bleu_source',
+ short_name='bls',
+ default=None,
+ help=flags_core.help_wrap(
+ 'Path to source file containing text translate when calculating the '
+ 'official BLEU score. Both --bleu_source and --bleu_ref must be set. '
+ ))
+ flags.DEFINE_string(
+ name='bleu_ref',
+ short_name='blr',
+ default=None,
+ help=flags_core.help_wrap(
+ 'Path to source file containing text translate when calculating the '
+ 'official BLEU score. Both --bleu_source and --bleu_ref must be set. '
+ ))
+ flags.DEFINE_string(
+ name='vocab_file',
+ short_name='vf',
+ default=None,
+ help=flags_core.help_wrap(
+ 'Path to subtoken vocabulary file. If data_download.py was used to '
+ 'download and encode the training data, look in the data_dir to find '
+ 'the vocab file.'))
+ flags.DEFINE_string(
+ name='mode',
+ default='train',
+ help=flags_core.help_wrap('mode: train, eval, or predict'))
+ flags.DEFINE_bool(
+ name='use_ctl',
+ default=False,
+ help=flags_core.help_wrap(
+ 'Whether the model runs with custom training loop.'))
+ flags.DEFINE_integer(
+ name='decode_batch_size',
+ default=32,
+ help=flags_core.help_wrap(
+ 'Global batch size used for Transformer autoregressive decoding on '
+ 'TPU.'))
+ flags.DEFINE_integer(
+ name='decode_max_length',
+ default=97,
+ help=flags_core.help_wrap(
+ 'Max sequence length of the decode/eval data. This is used by '
+ 'Transformer autoregressive decoding on TPU to have minimum '
+ 'paddings.'))
+ flags.DEFINE_bool(
+ name='padded_decode',
+ default=False,
+ help=flags_core.help_wrap(
+ 'Whether the autoregressive decoding runs with input data padded to '
+ 'the decode_max_length. For TPU/XLA-GPU runs, this flag has to be '
+ 'set due the static shape requirement. Although CPU/GPU could also '
+ 'use padded_decode, it has not been tested. In addition, this method '
+ 'will introduce unnecessary overheads which grow quadratically with '
+ 'the max sequence length.'))
+ flags.DEFINE_bool(
+ name='enable_checkpointing',
+ default=True,
+ help=flags_core.help_wrap(
+ 'Whether to do checkpointing during training. When running under '
+ 'benchmark harness, we will avoid checkpointing.'))
+ flags.DEFINE_bool(
+ name='save_weights_only',
+ default=True,
+ help=flags_core.help_wrap(
+ 'Only used when above `enable_checkpointing` is True. '
+ 'If True, then only the model\'s weights will be saved '
+ '(`model.save_weights(filepath)`), else the full model is saved '
+ '(`model.save(filepath)`)'))
+
+ flags_core.set_defaults(
+ data_dir='/tmp/translate_ende',
+ model_dir='/tmp/transformer_model',
+ batch_size=None)
+
+ # pylint: disable=unused-variable
+ @flags.multi_flags_validator(
+ ['bleu_source', 'bleu_ref'],
+ message='Both or neither --bleu_source and --bleu_ref must be defined.')
+ def _check_bleu_files(flags_dict):
+ return (flags_dict['bleu_source'] is None) == (
+ flags_dict['bleu_ref'] is None)
+
+ @flags.multi_flags_validator(
+ ['bleu_source', 'bleu_ref', 'vocab_file'],
+ message='--vocab_file must be defined if --bleu_source and --bleu_ref '
+ 'are defined.')
+ def _check_bleu_vocab_file(flags_dict):
+ if flags_dict['bleu_source'] and flags_dict['bleu_ref']:
+ return flags_dict['vocab_file'] is not None
+ return True
+
+ # pylint: enable=unused-variable
+
+
+def get_callbacks():
+ """Returns common callbacks."""
+ callbacks = []
+ if FLAGS.enable_time_history:
+ time_callback = keras_utils.TimeHistory(
+ FLAGS.batch_size,
+ FLAGS.log_steps,
+ logdir=FLAGS.model_dir if FLAGS.enable_tensorboard else None)
+ callbacks.append(time_callback)
+
+ if FLAGS.enable_tensorboard:
+ tensorboard_callback = tf_keras.callbacks.TensorBoard(
+ log_dir=FLAGS.model_dir)
+ callbacks.append(tensorboard_callback)
+
+ return callbacks
+
+
+def update_stats(history, stats, callbacks):
+ """Normalizes and updates dictionary of stats.
+
+ Args:
+ history: Results of the training step.
+ stats: Dict with pre-existing training stats.
+ callbacks: a list of callbacks which might include a time history callback
+ used during keras.fit.
+ """
+
+ if history and history.history:
+ train_hist = history.history
+ # Gets final loss from training.
+ stats['loss'] = float(train_hist['loss'][-1])
+
+ if not callbacks:
+ return
+
+ # Look for the time history callback which was used during keras.fit
+ for callback in callbacks:
+ if isinstance(callback, keras_utils.TimeHistory):
+ timestamp_log = callback.timestamp_log
+ stats['step_timestamp_log'] = timestamp_log
+ stats['train_finish_time'] = callback.train_finish_time
+ if len(timestamp_log) > 1:
+ stats['avg_exp_per_second'] = (
+ callback.batch_size * callback.log_steps *
+ (len(callback.timestamp_log) - 1) /
+ (timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
diff --git a/modeling/official/legacy/transformer/model_params.py b/modeling/official/legacy/transformer/model_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf1ae3f9d0a16b101262118b7cce1e461ce3f46b
--- /dev/null
+++ b/modeling/official/legacy/transformer/model_params.py
@@ -0,0 +1,96 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines Transformer model parameters."""
+
+import collections
+
+
+BASE_PARAMS = collections.defaultdict(
+ lambda: None, # Set default value to None.
+
+ # Input params
+ default_batch_size=2048, # Maximum number of tokens per batch of examples.
+ default_batch_size_tpu=32768,
+ max_length=256, # Maximum number of tokens per example.
+
+ # Model params
+ initializer_gain=1.0, # Used in trainable variable initialization.
+ vocab_size=33708, # Number of tokens defined in the vocabulary file.
+ hidden_size=512, # Model dimension in the hidden layers.
+ num_hidden_layers=6, # Number of layers in the encoder and decoder stacks.
+ num_heads=8, # Number of heads to use in multi-headed attention.
+ filter_size=2048, # Inner layer dimension in the feedforward network.
+
+ # Dropout values (only used when training)
+ layer_postprocess_dropout=0.1,
+ attention_dropout=0.1,
+ relu_dropout=0.1,
+
+ # Training params
+ label_smoothing=0.1,
+ learning_rate=2.0,
+ learning_rate_decay_rate=1.0,
+ learning_rate_warmup_steps=16000,
+
+ # Optimizer params
+ optimizer_adam_beta1=0.9,
+ optimizer_adam_beta2=0.997,
+ optimizer_adam_epsilon=1e-09,
+
+ # Default prediction params
+ extra_decode_length=50,
+ beam_size=4,
+ alpha=0.6, # used to calculate length normalization in beam search
+
+ # TPU specific parameters
+ use_tpu=False,
+ static_batch=False,
+ allow_ffn_pad=True,
+)
+
+BIG_PARAMS = BASE_PARAMS.copy()
+BIG_PARAMS.update(
+ default_batch_size=4096,
+
+ # default batch size is smaller than for BASE_PARAMS due to memory limits.
+ default_batch_size_tpu=16384,
+
+ hidden_size=1024,
+ filter_size=4096,
+ num_heads=16,
+)
+
+# Parameters for running the model in multi gpu. These should not change the
+# params that modify the model shape (such as the hidden_size or num_heads).
+BASE_MULTI_GPU_PARAMS = BASE_PARAMS.copy()
+BASE_MULTI_GPU_PARAMS.update(
+ learning_rate_warmup_steps=8000
+)
+
+BIG_MULTI_GPU_PARAMS = BIG_PARAMS.copy()
+BIG_MULTI_GPU_PARAMS.update(
+ layer_postprocess_dropout=0.3,
+ learning_rate_warmup_steps=8000
+)
+
+# Parameters for testing the model
+TINY_PARAMS = BASE_PARAMS.copy()
+TINY_PARAMS.update(
+ default_batch_size=1024,
+ default_batch_size_tpu=1024,
+ hidden_size=32,
+ num_heads=4,
+ filter_size=256,
+)
diff --git a/modeling/official/legacy/transformer/model_utils.py b/modeling/official/legacy/transformer/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5867a603783200da56efd5c8ab43a7ac08f457ef
--- /dev/null
+++ b/modeling/official/legacy/transformer/model_utils.py
@@ -0,0 +1,121 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Transformer model helper methods."""
+
+import math
+
+import numpy as np
+import tensorflow as tf, tf_keras
+
+# Very low numbers to represent -infinity. We do not actually use -Inf, since we
+# want to be able to multiply these values by zero to get zero. (-Inf * 0 = NaN)
+_NEG_INF_FP32 = -1e9
+_NEG_INF_FP16 = np.finfo(np.float16).min
+
+
+def get_position_encoding(length,
+ hidden_size,
+ min_timescale=1.0,
+ max_timescale=1.0e4):
+ """Return positional encoding.
+
+ Calculates the position encoding as a mix of sine and cosine functions with
+ geometrically increasing wavelengths.
+ Defined and formulized in Attention is All You Need, section 3.5.
+
+ Args:
+ length: Sequence length.
+ hidden_size: Size of the
+ min_timescale: Minimum scale that will be applied at each position
+ max_timescale: Maximum scale that will be applied at each position
+
+ Returns:
+ Tensor with shape [length, hidden_size]
+ """
+ # We compute the positional encoding in float32 even if the model uses
+ # float16, as many of the ops used, like log and exp, are numerically unstable
+ # in float16.
+ position = tf.cast(tf.range(length), tf.float32)
+ num_timescales = hidden_size // 2
+ log_timescale_increment = (
+ math.log(float(max_timescale) / float(min_timescale)) /
+ (tf.cast(num_timescales, tf.float32) - 1))
+ inv_timescales = min_timescale * tf.exp(
+ tf.cast(tf.range(num_timescales), tf.float32) * -log_timescale_increment)
+ scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
+ signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
+ return signal
+
+
+def get_decoder_self_attention_bias(length, dtype=tf.float32):
+ """Calculate bias for decoder that maintains model's autoregressive property.
+
+ Creates a tensor that masks out locations that correspond to illegal
+ connections, so prediction at position i cannot draw information from future
+ positions.
+
+ Args:
+ length: int length of sequences in batch.
+ dtype: The dtype of the return value.
+
+ Returns:
+ float tensor of shape [1, 1, length, length]
+ """
+ neg_inf = _NEG_INF_FP16 if dtype == tf.float16 else _NEG_INF_FP32
+ with tf.name_scope("decoder_self_attention_bias"):
+ valid_locs = tf.linalg.band_part(
+ tf.ones([length, length], dtype=dtype), -1, 0)
+ valid_locs = tf.reshape(valid_locs, [1, 1, length, length])
+ decoder_bias = neg_inf * (1.0 - valid_locs)
+ return decoder_bias
+
+
+def get_padding(x, padding_value=0, dtype=tf.float32):
+ """Return float tensor representing the padding values in x.
+
+ Args:
+ x: int tensor with any shape
+ padding_value: int which represents padded values in input
+ dtype: The dtype of the return value.
+
+ Returns:
+ float tensor with same shape as x containing values 0 or 1.
+ 0 -> non-padding, 1 -> padding
+ """
+ with tf.name_scope("padding"):
+ return tf.cast(tf.equal(x, padding_value), dtype)
+
+
+def get_padding_bias(x, padding_value=0, dtype=tf.float32):
+ """Calculate bias tensor from padding values in tensor.
+
+ Bias tensor that is added to the pre-softmax multi-headed attention logits,
+ which has shape [batch_size, num_heads, length, length]. The tensor is zero at
+ non-padding locations, and -1e9 (negative infinity) at padding locations.
+
+ Args:
+ x: int tensor with shape [batch_size, length]
+ padding_value: int which represents padded values in input
+ dtype: The dtype of the return value
+
+ Returns:
+ Attention bias tensor of shape [batch_size, 1, 1, length].
+ """
+ with tf.name_scope("attention_bias"):
+ padding = get_padding(x, padding_value, dtype)
+ attention_bias = padding * _NEG_INF_FP32
+ attention_bias = tf.expand_dims(
+ tf.expand_dims(attention_bias, axis=1), axis=1)
+ return attention_bias
diff --git a/modeling/official/legacy/transformer/model_utils_test.py b/modeling/official/legacy/transformer/model_utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..195fa8f3599cb16311d9e81692c668815f144c48
--- /dev/null
+++ b/modeling/official/legacy/transformer/model_utils_test.py
@@ -0,0 +1,55 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test Transformer model helper methods."""
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.transformer import model_utils
+
+NEG_INF = -1e9
+
+
+class ModelUtilsTest(tf.test.TestCase):
+
+ def test_get_padding(self):
+ x = tf.constant([[1, 0, 0, 0, 2], [3, 4, 0, 0, 0], [0, 5, 6, 0, 7]])
+ padding = model_utils.get_padding(x, padding_value=0)
+
+ self.assertAllEqual([[0, 1, 1, 1, 0], [0, 0, 1, 1, 1], [1, 0, 0, 1, 0]],
+ padding)
+
+ def test_get_padding_bias(self):
+ x = tf.constant([[1, 0, 0, 0, 2], [3, 4, 0, 0, 0], [0, 5, 6, 0, 7]])
+ bias = model_utils.get_padding_bias(x)
+ bias_shape = tf.shape(bias)
+ flattened_bias = tf.reshape(bias, [3, 5])
+
+ self.assertAllEqual(
+ [[0, NEG_INF, NEG_INF, NEG_INF, 0], [0, 0, NEG_INF, NEG_INF, NEG_INF],
+ [NEG_INF, 0, 0, NEG_INF, 0]], flattened_bias)
+ self.assertAllEqual([3, 1, 1, 5], bias_shape)
+
+ def test_get_decoder_self_attention_bias(self):
+ length = 5
+ bias = model_utils.get_decoder_self_attention_bias(length)
+
+ self.assertAllEqual(
+ [[[[0, NEG_INF, NEG_INF, NEG_INF, NEG_INF],
+ [0, 0, NEG_INF, NEG_INF, NEG_INF], [0, 0, 0, NEG_INF, NEG_INF],
+ [0, 0, 0, 0, NEG_INF], [0, 0, 0, 0, 0]]]], bias)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/legacy/transformer/optimizer.py b/modeling/official/legacy/transformer/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd186e3458e75fd9b19c603421687ec09bf3523f
--- /dev/null
+++ b/modeling/official/legacy/transformer/optimizer.py
@@ -0,0 +1,64 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Optimizer from addons and learning rate scheduler."""
+
+import tensorflow as tf, tf_keras
+
+
+class LearningRateSchedule(tf_keras.optimizers.schedules.LearningRateSchedule):
+ """Learning rate schedule."""
+
+ def __init__(self, initial_learning_rate, hidden_size, warmup_steps):
+ """Initialize configuration of the learning rate schedule.
+
+ Args:
+ initial_learning_rate: A float, the initial learning rate.
+ hidden_size: An integer, the model dimension in the hidden layers.
+ warmup_steps: An integer, the number of steps required for linear warmup.
+ """
+ super(LearningRateSchedule, self).__init__()
+ self.initial_learning_rate = initial_learning_rate
+ self.hidden_size = hidden_size
+ self.warmup_steps = warmup_steps
+ self.warmup_steps_tensor = tf.cast(warmup_steps, tf.float32)
+
+ def __call__(self, global_step):
+ """Calculate learning rate with linear warmup and rsqrt decay.
+
+ Args:
+ global_step: An integer, the current global step used for learning rate
+ calculation.
+
+ Returns:
+ A float, the learning rate needs to be used for current global step.
+ """
+ with tf.name_scope('learning_rate_schedule'):
+ global_step = tf.cast(global_step, tf.float32)
+ learning_rate = self.initial_learning_rate
+ learning_rate *= (self.hidden_size**-0.5)
+ # Apply linear warmup
+ learning_rate *= tf.minimum(1.0, global_step / self.warmup_steps_tensor)
+ # Apply rsqrt decay
+ learning_rate /= tf.sqrt(
+ tf.maximum(global_step, self.warmup_steps_tensor))
+ return learning_rate
+
+ def get_config(self):
+ """Get the configuration of the learning rate schedule."""
+ return {
+ 'initial_learning_rate': self.initial_learning_rate,
+ 'hidden_size': self.hidden_size,
+ 'warmup_steps': self.warmup_steps,
+ }
diff --git a/modeling/official/legacy/transformer/transformer.py b/modeling/official/legacy/transformer/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc38b8989692445e691304dc774f68621aa6aee9
--- /dev/null
+++ b/modeling/official/legacy/transformer/transformer.py
@@ -0,0 +1,550 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines the Transformer model in TF 2.0.
+
+Model paper: https://arxiv.org/pdf/1706.03762.pdf
+Transformer model code source: https://github.com/tensorflow/tensor2tensor
+"""
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.transformer import attention_layer
+from official.legacy.transformer import embedding_layer
+from official.legacy.transformer import ffn_layer
+from official.legacy.transformer import metrics
+from official.legacy.transformer import model_utils
+from official.legacy.transformer.utils.tokenizer import EOS_ID
+from official.nlp.modeling.layers import position_embedding
+from official.nlp.modeling.ops import beam_search
+
+# Disable the not-callable lint error, since it claims many objects are not
+# callable when they actually are.
+# pylint: disable=not-callable
+
+
+def create_model(params, is_train):
+ """Creates transformer model."""
+ with tf.name_scope("model"):
+ if is_train:
+ inputs = tf_keras.layers.Input((None,), dtype="int64", name="inputs")
+ targets = tf_keras.layers.Input((None,), dtype="int64", name="targets")
+ internal_model = Transformer(params, name="transformer_v2")
+ logits = internal_model([inputs, targets], training=is_train)
+ vocab_size = params["vocab_size"]
+ label_smoothing = params["label_smoothing"]
+ if params["enable_metrics_in_training"]:
+ logits = metrics.MetricLayer(vocab_size)([logits, targets])
+ logits = tf_keras.layers.Lambda(
+ lambda x: x, name="logits", dtype=tf.float32)(
+ logits)
+ model = tf_keras.Model([inputs, targets], logits)
+ loss = metrics.transformer_loss(logits, targets, label_smoothing,
+ vocab_size)
+ model.add_loss(loss)
+ return model
+
+ else:
+ inputs = tf_keras.layers.Input((None,), dtype="int64", name="inputs")
+ internal_model = Transformer(params, name="transformer_v2")
+ ret = internal_model([inputs], training=is_train)
+ outputs, scores = ret["outputs"], ret["scores"]
+ return tf_keras.Model(inputs, [outputs, scores])
+
+
+class Transformer(tf_keras.Model):
+ """Transformer model with Keras.
+
+ Implemented as described in: https://arxiv.org/pdf/1706.03762.pdf
+
+ The Transformer model consists of an encoder and decoder. The input is an int
+ sequence (or a batch of sequences). The encoder produces a continuous
+ representation, and the decoder uses the encoder output to generate
+ probabilities for the output sequence.
+ """
+
+ def __init__(self, params, name=None):
+ """Initialize layers to build Transformer model.
+
+ Args:
+ params: hyperparameter object defining layer sizes, dropout values, etc.
+ name: name of the model.
+ """
+ super(Transformer, self).__init__(name=name)
+ self.params = params
+ self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
+ params["vocab_size"], params["hidden_size"])
+ self.encoder_stack = EncoderStack(params)
+ self.decoder_stack = DecoderStack(params)
+ self.position_embedding = position_embedding.RelativePositionEmbedding(
+ hidden_size=self.params["hidden_size"])
+
+ def get_config(self):
+ return {
+ "params": self.params,
+ }
+
+ def call(self, inputs, training):
+ """Calculate target logits or inferred target sequences.
+
+ Args:
+ inputs: input tensor list of size 1 or 2.
+ First item, inputs: int tensor with shape [batch_size, input_length].
+ Second item (optional), targets: None or int tensor with shape
+ [batch_size, target_length].
+ training: boolean, whether in training mode or not.
+
+ Returns:
+ If targets is defined, then return logits for each word in the target
+ sequence. float tensor with shape [batch_size, target_length, vocab_size]
+ If target is none, then generate output sequence one token at a time.
+ returns a dictionary {
+ outputs: int tensor with shape [batch_size, decoded_length]
+ scores: float tensor with shape [batch_size]}
+ Even when float16 is used, the output tensor(s) are always float32.
+
+ Raises:
+ NotImplementedError: If try to use padded decode method on CPU/GPUs.
+ """
+ inputs = inputs if isinstance(inputs, list) else [inputs]
+ if len(inputs) == 2:
+ inputs, targets = inputs[0], inputs[1]
+ else:
+ # Decoding path.
+ inputs, targets = inputs[0], None
+ if self.params["padded_decode"]:
+ if not self.params["num_replicas"]:
+ raise NotImplementedError(
+ "Padded decoding on CPU/GPUs is not supported.")
+ decode_batch_size = int(self.params["decode_batch_size"] /
+ self.params["num_replicas"])
+ inputs.set_shape([decode_batch_size, self.params["decode_max_length"]])
+
+ # Variance scaling is used here because it seems to work in many problems.
+ # Other reasonable initializers may also work just as well.
+ with tf.name_scope("Transformer"):
+ # Calculate attention bias for encoder self-attention and decoder
+ # multi-headed attention layers.
+ attention_bias = model_utils.get_padding_bias(inputs)
+
+ # Run the inputs through the encoder layer to map the symbol
+ # representations to continuous representations.
+ encoder_outputs = self.encode(inputs, attention_bias, training)
+ # Generate output sequence if targets is None, or return logits if target
+ # sequence is known.
+ if targets is None:
+ return self.predict(encoder_outputs, attention_bias, training)
+ else:
+ logits = self.decode(targets, encoder_outputs, attention_bias, training)
+ return logits
+
+ def encode(self, inputs, attention_bias, training):
+ """Generate continuous representation for inputs.
+
+ Args:
+ inputs: int tensor with shape [batch_size, input_length].
+ attention_bias: float tensor with shape [batch_size, 1, 1, input_length].
+ training: boolean, whether in training mode or not.
+
+ Returns:
+ float tensor with shape [batch_size, input_length, hidden_size]
+ """
+ with tf.name_scope("encode"):
+ # Prepare inputs to the layer stack by adding positional encodings and
+ # applying dropout.
+ embedded_inputs = self.embedding_softmax_layer(inputs)
+ embedded_inputs = tf.cast(embedded_inputs, self.params["dtype"])
+ inputs_padding = model_utils.get_padding(inputs)
+ attention_bias = tf.cast(attention_bias, self.params["dtype"])
+
+ with tf.name_scope("add_pos_encoding"):
+ pos_encoding = self.position_embedding(inputs=embedded_inputs)
+ pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
+ encoder_inputs = embedded_inputs + pos_encoding
+
+ if training:
+ encoder_inputs = tf.nn.dropout(
+ encoder_inputs, rate=self.params["layer_postprocess_dropout"])
+
+ return self.encoder_stack(
+ encoder_inputs, attention_bias, inputs_padding, training=training)
+
+ def decode(self, targets, encoder_outputs, attention_bias, training):
+ """Generate logits for each value in the target sequence.
+
+ Args:
+ targets: target values for the output sequence. int tensor with shape
+ [batch_size, target_length]
+ encoder_outputs: continuous representation of input sequence. float tensor
+ with shape [batch_size, input_length, hidden_size]
+ attention_bias: float tensor with shape [batch_size, 1, 1, input_length]
+ training: boolean, whether in training mode or not.
+
+ Returns:
+ float32 tensor with shape [batch_size, target_length, vocab_size]
+ """
+ with tf.name_scope("decode"):
+ # Prepare inputs to decoder layers by shifting targets, adding positional
+ # encoding and applying dropout.
+ with tf.name_scope("shift_targets"):
+ # Shift targets to the right, and remove the last element
+ targets = tf.pad(targets, [[0, 0], [1, 0]])[:, :-1]
+ decoder_inputs = self.embedding_softmax_layer(targets)
+ decoder_inputs = tf.cast(decoder_inputs, self.params["dtype"])
+ attention_bias = tf.cast(attention_bias, self.params["dtype"])
+ with tf.name_scope("add_pos_encoding"):
+ length = tf.shape(decoder_inputs)[1]
+ pos_encoding = self.position_embedding(decoder_inputs)
+ pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
+ decoder_inputs += pos_encoding
+ if training:
+ decoder_inputs = tf.nn.dropout(
+ decoder_inputs, rate=self.params["layer_postprocess_dropout"])
+
+ # Run values
+ decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
+ length, dtype=self.params["dtype"])
+ outputs = self.decoder_stack(
+ decoder_inputs,
+ encoder_outputs,
+ decoder_self_attention_bias,
+ attention_bias,
+ training=training)
+ logits = self.embedding_softmax_layer(outputs, mode="linear")
+ logits = tf.cast(logits, tf.float32)
+ return logits
+
+ def _get_symbols_to_logits_fn(self, max_decode_length, training):
+ """Returns a decoding function that calculates logits of the next tokens."""
+ timing_signal = self.position_embedding(
+ inputs=None, length=max_decode_length + 1)
+ timing_signal = tf.cast(timing_signal, self.params["dtype"])
+ decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
+ max_decode_length, dtype=self.params["dtype"])
+
+ def symbols_to_logits_fn(ids, i, cache):
+ """Generate logits for next potential IDs.
+
+ Args:
+ ids: Current decoded sequences. int tensor with shape [batch_size *
+ beam_size, i + 1].
+ i: Loop index.
+ cache: dictionary of values storing the encoder output, encoder-decoder
+ attention bias, and previous decoder attention values.
+
+ Returns:
+ Tuple of
+ (logits with shape [batch_size * beam_size, vocab_size],
+ updated cache values)
+ """
+ # Set decoder input to the last generated IDs
+ decoder_input = ids[:, -1:]
+
+ # Preprocess decoder input by getting embeddings and adding timing signal.
+ decoder_input = self.embedding_softmax_layer(decoder_input)
+ decoder_input += timing_signal[i]
+ if self.params["padded_decode"]:
+ bias_shape = decoder_self_attention_bias.shape.as_list()
+ self_attention_bias = tf.slice(
+ decoder_self_attention_bias, [0, 0, i, 0],
+ [bias_shape[0], bias_shape[1], 1, bias_shape[3]])
+ else:
+ self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
+
+ decoder_outputs = self.decoder_stack(
+ decoder_input,
+ cache.get("encoder_outputs"),
+ self_attention_bias,
+ cache.get("encoder_decoder_attention_bias"),
+ training=training,
+ cache=cache,
+ decode_loop_step=i if self.params["padded_decode"] else None)
+ logits = self.embedding_softmax_layer(decoder_outputs, mode="linear")
+ logits = tf.squeeze(logits, axis=[1])
+ return logits, cache
+
+ return symbols_to_logits_fn
+
+ def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
+ """Return predicted sequence."""
+ encoder_outputs = tf.cast(encoder_outputs, self.params["dtype"])
+ if self.params["padded_decode"]:
+ batch_size = encoder_outputs.shape.as_list()[0]
+ input_length = encoder_outputs.shape.as_list()[1]
+ else:
+ batch_size = tf.shape(encoder_outputs)[0]
+ input_length = tf.shape(encoder_outputs)[1]
+ max_decode_length = input_length + self.params["extra_decode_length"]
+ encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
+ self.params["dtype"])
+
+ symbols_to_logits_fn = self._get_symbols_to_logits_fn(
+ max_decode_length, training)
+
+ # Create initial set of IDs that will be passed into symbols_to_logits_fn.
+ initial_ids = tf.zeros([batch_size], dtype=tf.int32)
+
+ # Create cache storing decoder attention values for each layer.
+ # pylint: disable=g-complex-comprehension
+ init_decode_length = (
+ max_decode_length if self.params["padded_decode"] else 0)
+ num_heads = self.params["num_heads"]
+ dim_per_head = self.params["hidden_size"] // num_heads
+ cache = {
+ "layer_%d" % layer: {
+ "k":
+ tf.zeros(
+ [batch_size, init_decode_length, num_heads, dim_per_head],
+ dtype=self.params["dtype"]),
+ "v":
+ tf.zeros(
+ [batch_size, init_decode_length, num_heads, dim_per_head],
+ dtype=self.params["dtype"])
+ } for layer in range(self.params["num_hidden_layers"])
+ }
+ # pylint: enable=g-complex-comprehension
+
+ # Add encoder output and attention bias to the cache.
+ cache["encoder_outputs"] = encoder_outputs
+ cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
+
+ # Use beam search to find the top beam_size sequences and scores.
+ decoded_ids, scores = beam_search.sequence_beam_search(
+ symbols_to_logits_fn=symbols_to_logits_fn,
+ initial_ids=initial_ids,
+ initial_cache=cache,
+ vocab_size=self.params["vocab_size"],
+ beam_size=self.params["beam_size"],
+ alpha=self.params["alpha"],
+ max_decode_length=max_decode_length,
+ eos_id=EOS_ID,
+ padded_decode=self.params["padded_decode"],
+ dtype=self.params["dtype"])
+
+ # Get the top sequence for each batch element
+ top_decoded_ids = decoded_ids[:, 0, 1:]
+ top_scores = scores[:, 0]
+
+ return {"outputs": top_decoded_ids, "scores": top_scores}
+
+
+class PrePostProcessingWrapper(tf_keras.layers.Layer):
+ """Wrapper class that applies layer pre-processing and post-processing."""
+
+ def __init__(self, layer, params):
+ super(PrePostProcessingWrapper, self).__init__()
+ self.layer = layer
+ self.params = params
+ self.postprocess_dropout = params["layer_postprocess_dropout"]
+
+ def build(self, input_shape):
+ # Create normalization layer
+ self.layer_norm = tf_keras.layers.LayerNormalization(
+ epsilon=1e-6, dtype="float32")
+ super(PrePostProcessingWrapper, self).build(input_shape)
+
+ def get_config(self):
+ return {
+ "params": self.params,
+ }
+
+ def call(self, x, *args, **kwargs):
+ """Calls wrapped layer with same parameters."""
+ # Preprocessing: apply layer normalization
+ training = kwargs["training"]
+
+ y = self.layer_norm(x)
+
+ # Get layer output
+ y = self.layer(y, *args, **kwargs)
+
+ # Postprocessing: apply dropout and residual connection
+ if training:
+ y = tf.nn.dropout(y, rate=self.postprocess_dropout)
+ return x + y
+
+
+class EncoderStack(tf_keras.layers.Layer):
+ """Transformer encoder stack.
+
+ The encoder stack is made up of N identical layers. Each layer is composed
+ of the sublayers:
+ 1. Self-attention layer
+ 2. Feedforward network (which is 2 fully-connected layers)
+ """
+
+ def __init__(self, params):
+ super(EncoderStack, self).__init__()
+ self.params = params
+ self.layers = []
+
+ def build(self, input_shape):
+ """Builds the encoder stack."""
+ params = self.params
+ for _ in range(params["num_hidden_layers"]):
+ # Create sublayers for each layer.
+ self_attention_layer = attention_layer.SelfAttention(
+ params["hidden_size"], params["num_heads"],
+ params["attention_dropout"])
+ feed_forward_network = ffn_layer.FeedForwardNetwork(
+ params["hidden_size"], params["filter_size"], params["relu_dropout"])
+
+ self.layers.append([
+ PrePostProcessingWrapper(self_attention_layer, params),
+ PrePostProcessingWrapper(feed_forward_network, params)
+ ])
+
+ # Create final layer normalization layer.
+ self.output_normalization = tf_keras.layers.LayerNormalization(
+ epsilon=1e-6, dtype="float32")
+ super(EncoderStack, self).build(input_shape)
+
+ def get_config(self):
+ return {
+ "params": self.params,
+ }
+
+ def call(self, encoder_inputs, attention_bias, inputs_padding, training):
+ """Return the output of the encoder layer stacks.
+
+ Args:
+ encoder_inputs: tensor with shape [batch_size, input_length, hidden_size]
+ attention_bias: bias for the encoder self-attention layer. [batch_size, 1,
+ 1, input_length]
+ inputs_padding: tensor with shape [batch_size, input_length], inputs with
+ zero paddings.
+ training: boolean, whether in training mode or not.
+
+ Returns:
+ Output of encoder layer stack.
+ float32 tensor with shape [batch_size, input_length, hidden_size]
+ """
+ for n, layer in enumerate(self.layers):
+ # Run inputs through the sublayers.
+ self_attention_layer = layer[0]
+ feed_forward_network = layer[1]
+
+ with tf.name_scope("layer_%d" % n):
+ with tf.name_scope("self_attention"):
+ encoder_inputs = self_attention_layer(
+ encoder_inputs, attention_bias, training=training)
+ with tf.name_scope("ffn"):
+ encoder_inputs = feed_forward_network(
+ encoder_inputs, training=training)
+
+ return self.output_normalization(encoder_inputs)
+
+
+class DecoderStack(tf_keras.layers.Layer):
+ """Transformer decoder stack.
+
+ Like the encoder stack, the decoder stack is made up of N identical layers.
+ Each layer is composed of the sublayers:
+ 1. Self-attention layer
+ 2. Multi-headed attention layer combining encoder outputs with results from
+ the previous self-attention layer.
+ 3. Feedforward network (2 fully-connected layers)
+ """
+
+ def __init__(self, params):
+ super(DecoderStack, self).__init__()
+ self.params = params
+ self.layers = []
+
+ def build(self, input_shape):
+ """Builds the decoder stack."""
+ params = self.params
+ for _ in range(params["num_hidden_layers"]):
+ self_attention_layer = attention_layer.SelfAttention(
+ params["hidden_size"], params["num_heads"],
+ params["attention_dropout"])
+ enc_dec_attention_layer = attention_layer.Attention(
+ params["hidden_size"], params["num_heads"],
+ params["attention_dropout"])
+ feed_forward_network = ffn_layer.FeedForwardNetwork(
+ params["hidden_size"], params["filter_size"], params["relu_dropout"])
+
+ self.layers.append([
+ PrePostProcessingWrapper(self_attention_layer, params),
+ PrePostProcessingWrapper(enc_dec_attention_layer, params),
+ PrePostProcessingWrapper(feed_forward_network, params)
+ ])
+ self.output_normalization = tf_keras.layers.LayerNormalization(
+ epsilon=1e-6, dtype="float32")
+ super(DecoderStack, self).build(input_shape)
+
+ def get_config(self):
+ return {
+ "params": self.params,
+ }
+
+ def call(self,
+ decoder_inputs,
+ encoder_outputs,
+ decoder_self_attention_bias,
+ attention_bias,
+ training,
+ cache=None,
+ decode_loop_step=None):
+ """Return the output of the decoder layer stacks.
+
+ Args:
+ decoder_inputs: A tensor with shape [batch_size, target_length,
+ hidden_size].
+ encoder_outputs: A tensor with shape [batch_size, input_length,
+ hidden_size]
+ decoder_self_attention_bias: A tensor with shape [1, 1, target_len,
+ target_length], the bias for decoder self-attention layer.
+ attention_bias: A tensor with shape [batch_size, 1, 1, input_length], the
+ bias for encoder-decoder attention layer.
+ training: A bool, whether in training mode or not.
+ cache: (Used for fast decoding) A nested dictionary storing previous
+ decoder self-attention values. The items are:
+ {layer_n: {"k": A tensor with shape [batch_size, i, key_channels],
+ "v": A tensor with shape [batch_size, i, value_channels]},
+ ...}
+ decode_loop_step: An integer, the step number of the decoding loop. Used
+ only for autoregressive inference on TPU.
+
+ Returns:
+ Output of decoder layer stack.
+ float32 tensor with shape [batch_size, target_length, hidden_size]
+ """
+ for n, layer in enumerate(self.layers):
+ self_attention_layer = layer[0]
+ enc_dec_attention_layer = layer[1]
+ feed_forward_network = layer[2]
+
+ # Run inputs through the sublayers.
+ layer_name = "layer_%d" % n
+ layer_cache = cache[layer_name] if cache is not None else None
+ with tf.name_scope(layer_name):
+ with tf.name_scope("self_attention"):
+ decoder_inputs = self_attention_layer(
+ decoder_inputs,
+ decoder_self_attention_bias,
+ training=training,
+ cache=layer_cache,
+ decode_loop_step=decode_loop_step)
+ with tf.name_scope("encdec_attention"):
+ decoder_inputs = enc_dec_attention_layer(
+ decoder_inputs,
+ encoder_outputs,
+ attention_bias,
+ training=training)
+ with tf.name_scope("ffn"):
+ decoder_inputs = feed_forward_network(
+ decoder_inputs, training=training)
+
+ return self.output_normalization(decoder_inputs)
diff --git a/modeling/official/legacy/transformer/transformer_forward_test.py b/modeling/official/legacy/transformer/transformer_forward_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..20796a93b62207b713adbf21b8c196275bd12107
--- /dev/null
+++ b/modeling/official/legacy/transformer/transformer_forward_test.py
@@ -0,0 +1,156 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Forward pass test for Transformer model refactoring."""
+
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from official.legacy.transformer import metrics
+from official.legacy.transformer import model_params
+from official.legacy.transformer import transformer
+from official.nlp.modeling import models
+
+
+def _count_params(layer, trainable_only=True):
+ """Returns the count of all model parameters, or just trainable ones."""
+ if not trainable_only:
+ return layer.count_params()
+ else:
+ return int(
+ np.sum([
+ tf_keras.backend.count_params(p) for p in layer.trainable_weights
+ ]))
+
+
+def _create_model(params, is_train):
+ """Creates transformer model."""
+
+ encdec_kwargs = dict(
+ num_layers=params["num_hidden_layers"],
+ num_attention_heads=params["num_heads"],
+ intermediate_size=params["filter_size"],
+ activation="relu",
+ dropout_rate=params["relu_dropout"],
+ attention_dropout_rate=params["attention_dropout"],
+ use_bias=False,
+ norm_first=True,
+ norm_epsilon=1e-6,
+ intermediate_dropout=params["relu_dropout"])
+ encoder_layer = models.TransformerEncoder(**encdec_kwargs)
+ decoder_layer = models.TransformerDecoder(**encdec_kwargs)
+
+ model_kwargs = dict(
+ vocab_size=params["vocab_size"],
+ embedding_width=params["hidden_size"],
+ dropout_rate=params["layer_postprocess_dropout"],
+ padded_decode=params["padded_decode"],
+ decode_max_length=params["decode_max_length"],
+ dtype=params["dtype"],
+ extra_decode_length=params["extra_decode_length"],
+ beam_size=params["beam_size"],
+ alpha=params["alpha"],
+ encoder_layer=encoder_layer,
+ decoder_layer=decoder_layer,
+ name="transformer_v2")
+
+ if is_train:
+ inputs = tf_keras.layers.Input((None,), dtype="int64", name="inputs")
+ targets = tf_keras.layers.Input((None,), dtype="int64", name="targets")
+ internal_model = models.Seq2SeqTransformer(**model_kwargs)
+ logits = internal_model(
+ dict(inputs=inputs, targets=targets), training=is_train)
+ vocab_size = params["vocab_size"]
+ label_smoothing = params["label_smoothing"]
+ if params["enable_metrics_in_training"]:
+ logits = metrics.MetricLayer(vocab_size)([logits, targets])
+ logits = tf_keras.layers.Lambda(
+ lambda x: x, name="logits", dtype=tf.float32)(
+ logits)
+ model = tf_keras.Model([inputs, targets], logits)
+ loss = metrics.transformer_loss(logits, targets, label_smoothing,
+ vocab_size)
+ model.add_loss(loss)
+ return model
+
+ batch_size = params["decode_batch_size"] if params["padded_decode"] else None
+ inputs = tf_keras.layers.Input((None,),
+ batch_size=batch_size,
+ dtype="int64",
+ name="inputs")
+ internal_model = models.Seq2SeqTransformer(**model_kwargs)
+ ret = internal_model(dict(inputs=inputs), training=is_train)
+ outputs, scores = ret["outputs"], ret["scores"]
+ return tf_keras.Model(inputs, [outputs, scores])
+
+
+class TransformerForwardTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(TransformerForwardTest, self).setUp()
+ self.params = params = model_params.TINY_PARAMS
+ params["batch_size"] = params["default_batch_size"] = 16
+ params["hidden_size"] = 12
+ params["num_hidden_layers"] = 3
+ params["filter_size"] = 14
+ params["num_heads"] = 2
+ params["vocab_size"] = 41
+ params["extra_decode_length"] = 0
+ params["beam_size"] = 3
+ params["dtype"] = tf.float32
+ params["layer_postprocess_dropout"] = 0.0
+ params["attention_dropout"] = 0.0
+ params["relu_dropout"] = 0.0
+
+ def test_forward_pass_train(self):
+ # Set input_len different from target_len
+ inputs = np.asarray([[5, 2, 1], [7, 5, 0], [1, 4, 0], [7, 5, 11]])
+ targets = np.asarray([[4, 3, 4, 0], [13, 19, 17, 8], [20, 14, 1, 2],
+ [5, 7, 3, 0]])
+
+ # src_model is the original model before refactored.
+ src_model = transformer.create_model(self.params, True)
+ src_num_weights = _count_params(src_model)
+ src_weights = src_model.get_weights()
+ src_model_output = src_model([inputs, targets], training=True)
+
+ # dest_model is the refactored model.
+ dest_model = _create_model(self.params, True)
+ dest_num_weights = _count_params(dest_model)
+ self.assertEqual(src_num_weights, dest_num_weights)
+ dest_model.set_weights(src_weights)
+ dest_model_output = dest_model([inputs, targets], training=True)
+ self.assertAllEqual(src_model_output, dest_model_output)
+
+ def test_forward_pass_not_train(self):
+ inputs = np.asarray([[5, 2, 1], [7, 5, 0], [1, 4, 0], [7, 5, 11]])
+
+ # src_model is the original model before refactored.
+ src_model = transformer.create_model(self.params, False)
+ src_num_weights = _count_params(src_model)
+ src_weights = src_model.get_weights()
+ src_model_output = src_model([inputs], training=False)
+
+ # dest_model is the refactored model.
+ dest_model = _create_model(self.params, False)
+ dest_num_weights = _count_params(dest_model)
+ self.assertEqual(src_num_weights, dest_num_weights)
+ dest_model.set_weights(src_weights)
+ dest_model_output = dest_model([inputs], training=False)
+ self.assertAllEqual(src_model_output[0], dest_model_output[0])
+ self.assertAllEqual(src_model_output[1], dest_model_output[1])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/legacy/transformer/transformer_layers_test.py b/modeling/official/legacy/transformer/transformer_layers_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c07c51d749ddbdbae8d51034bfbf509fc6232283
--- /dev/null
+++ b/modeling/official/legacy/transformer/transformer_layers_test.py
@@ -0,0 +1,125 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for layers in Transformer."""
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.transformer import attention_layer
+from official.legacy.transformer import embedding_layer
+from official.legacy.transformer import ffn_layer
+from official.legacy.transformer import metrics
+
+
+class TransformerLayersTest(tf.test.TestCase):
+
+ def test_attention_layer(self):
+ hidden_size = 64
+ num_heads = 4
+ dropout = 0.5
+ dim_per_head = hidden_size // num_heads
+ layer = attention_layer.SelfAttention(hidden_size, num_heads, dropout)
+ self.assertDictEqual(
+ layer.get_config(), {
+ "hidden_size": hidden_size,
+ "num_heads": num_heads,
+ "attention_dropout": dropout,
+ })
+ length = 2
+ x = tf.ones([1, length, hidden_size])
+ bias = tf.ones([1])
+ cache = {
+ "k": tf.zeros([1, 0, num_heads, dim_per_head]),
+ "v": tf.zeros([1, 0, num_heads, dim_per_head]),
+ }
+ y = layer(x, bias, training=True, cache=cache)
+ self.assertEqual(y.shape, (
+ 1,
+ length,
+ 64,
+ ))
+ self.assertEqual(cache["k"].shape, (
+ 1,
+ length,
+ num_heads,
+ dim_per_head,
+ ))
+ self.assertEqual(cache["v"].shape, (
+ 1,
+ length,
+ num_heads,
+ dim_per_head,
+ ))
+
+ def test_embedding_shared_weights(self):
+ vocab_size = 50
+ hidden_size = 64
+ length = 2
+ layer = embedding_layer.EmbeddingSharedWeights(vocab_size, hidden_size)
+ self.assertDictEqual(layer.get_config(), {
+ "vocab_size": 50,
+ "hidden_size": 64,
+ })
+
+ idx = tf.ones([1, length], dtype="int32")
+ y = layer(idx)
+ self.assertEqual(y.shape, (
+ 1,
+ length,
+ hidden_size,
+ ))
+ x = tf.ones([1, length, hidden_size])
+ output = layer(x, "linear")
+ self.assertEqual(output.shape, (
+ 1,
+ length,
+ vocab_size,
+ ))
+
+ def test_feed_forward_network(self):
+ hidden_size = 64
+ filter_size = 32
+ relu_dropout = 0.5
+ layer = ffn_layer.FeedForwardNetwork(hidden_size, filter_size, relu_dropout)
+ self.assertDictEqual(
+ layer.get_config(), {
+ "hidden_size": hidden_size,
+ "filter_size": filter_size,
+ "relu_dropout": relu_dropout,
+ })
+ length = 2
+ x = tf.ones([1, length, hidden_size])
+ y = layer(x, training=True)
+ self.assertEqual(y.shape, (
+ 1,
+ length,
+ hidden_size,
+ ))
+
+ def test_metric_layer(self):
+ vocab_size = 50
+ logits = tf_keras.layers.Input((None, vocab_size),
+ dtype="float32",
+ name="logits")
+ targets = tf_keras.layers.Input((None,), dtype="int64", name="targets")
+ output_logits = metrics.MetricLayer(vocab_size)([logits, targets])
+ self.assertEqual(output_logits.shape.as_list(), [
+ None,
+ None,
+ vocab_size,
+ ])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/legacy/transformer/transformer_main.py b/modeling/official/legacy/transformer/transformer_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c3e4a93bd2951889cd95e8a2b724885feaf8f12
--- /dev/null
+++ b/modeling/official/legacy/transformer/transformer_main.py
@@ -0,0 +1,485 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Train and evaluate the Transformer model.
+
+See README for description of setting the training schedule and evaluating the
+BLEU score.
+"""
+
+import os
+import tempfile
+
+# Import libraries
+
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf, tf_keras
+
+from official.common import distribute_utils
+from official.legacy.transformer import compute_bleu
+from official.legacy.transformer import data_pipeline
+from official.legacy.transformer import metrics
+from official.legacy.transformer import misc
+from official.legacy.transformer import optimizer
+from official.legacy.transformer import transformer
+from official.legacy.transformer import translate
+from official.legacy.transformer.utils import tokenizer
+from official.modeling import performance
+from official.utils.flags import core as flags_core
+from official.utils.misc import keras_utils
+
+# pylint:disable=logging-format-interpolation
+
+INF = int(1e9)
+BLEU_DIR = "bleu"
+_SINGLE_SAMPLE = 1
+
+
+def translate_and_compute_bleu(model,
+ params,
+ subtokenizer,
+ bleu_source,
+ bleu_ref,
+ distribution_strategy=None):
+ """Translate file and report the cased and uncased bleu scores.
+
+ Args:
+ model: A Keras model, used to generate the translations.
+ params: A dictionary, containing the translation related parameters.
+ subtokenizer: A subtokenizer object, used for encoding and decoding source
+ and translated lines.
+ bleu_source: A file containing source sentences for translation.
+ bleu_ref: A file containing the reference for the translated sentences.
+ distribution_strategy: A platform distribution strategy, used for TPU based
+ translation.
+
+ Returns:
+ uncased_score: A float, the case insensitive BLEU score.
+ cased_score: A float, the case sensitive BLEU score.
+ """
+ # Create temporary file to store translation.
+ tmp = tempfile.NamedTemporaryFile(delete=False)
+ tmp_filename = tmp.name
+
+ translate.translate_file(
+ model,
+ params,
+ subtokenizer,
+ bleu_source,
+ output_file=tmp_filename,
+ print_all_translations=False,
+ distribution_strategy=distribution_strategy)
+
+ # Compute uncased and cased bleu scores.
+ uncased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, False)
+ cased_score = compute_bleu.bleu_wrapper(bleu_ref, tmp_filename, True)
+ os.remove(tmp_filename)
+ return uncased_score, cased_score
+
+
+def evaluate_and_log_bleu(model,
+ params,
+ bleu_source,
+ bleu_ref,
+ vocab_file,
+ distribution_strategy=None):
+ """Calculate and record the BLEU score.
+
+ Args:
+ model: A Keras model, used to generate the translations.
+ params: A dictionary, containing the translation related parameters.
+ bleu_source: A file containing source sentences for translation.
+ bleu_ref: A file containing the reference for the translated sentences.
+ vocab_file: A file containing the vocabulary for translation.
+ distribution_strategy: A platform distribution strategy, used for TPU based
+ translation.
+
+ Returns:
+ uncased_score: A float, the case insensitive BLEU score.
+ cased_score: A float, the case sensitive BLEU score.
+ """
+ subtokenizer = tokenizer.Subtokenizer(vocab_file)
+
+ uncased_score, cased_score = translate_and_compute_bleu(
+ model, params, subtokenizer, bleu_source, bleu_ref, distribution_strategy)
+
+ logging.info("Bleu score (uncased): %s", uncased_score)
+ logging.info("Bleu score (cased): %s", cased_score)
+ return uncased_score, cased_score
+
+
+class TransformerTask(object):
+ """Main entry of Transformer model."""
+
+ def __init__(self, flags_obj):
+ """Init function of TransformerMain.
+
+ Args:
+ flags_obj: Object containing parsed flag values, i.e., FLAGS.
+
+ Raises:
+ ValueError: if not using static batch for input data on TPU.
+ """
+ self.flags_obj = flags_obj
+ self.predict_model = None
+
+ # Add flag-defined parameters to params object
+ num_gpus = flags_core.get_num_gpus(flags_obj)
+ self.params = params = misc.get_model_params(flags_obj.param_set, num_gpus)
+
+ params["num_gpus"] = num_gpus
+ params["use_ctl"] = flags_obj.use_ctl
+ params["data_dir"] = flags_obj.data_dir
+ params["model_dir"] = flags_obj.model_dir
+ params["static_batch"] = flags_obj.static_batch
+ params["max_length"] = flags_obj.max_length
+ params["decode_batch_size"] = flags_obj.decode_batch_size
+ params["decode_max_length"] = flags_obj.decode_max_length
+ params["padded_decode"] = flags_obj.padded_decode
+ params["max_io_parallelism"] = (
+ flags_obj.num_parallel_calls or tf.data.experimental.AUTOTUNE)
+
+ params["use_synthetic_data"] = flags_obj.use_synthetic_data
+ params["batch_size"] = flags_obj.batch_size or params["default_batch_size"]
+ params["repeat_dataset"] = None
+ params["dtype"] = flags_core.get_tf_dtype(flags_obj)
+ params["enable_tensorboard"] = flags_obj.enable_tensorboard
+ params["enable_metrics_in_training"] = flags_obj.enable_metrics_in_training
+ params["steps_between_evals"] = flags_obj.steps_between_evals
+ params["enable_checkpointing"] = flags_obj.enable_checkpointing
+ params["save_weights_only"] = flags_obj.save_weights_only
+
+ self.distribution_strategy = distribute_utils.get_distribution_strategy(
+ distribution_strategy=flags_obj.distribution_strategy,
+ num_gpus=num_gpus,
+ all_reduce_alg=flags_obj.all_reduce_alg,
+ num_packs=flags_obj.num_packs,
+ tpu_address=flags_obj.tpu or "")
+ if self.use_tpu:
+ params["num_replicas"] = self.distribution_strategy.num_replicas_in_sync
+ else:
+ logging.info("Running transformer with num_gpus = %d", num_gpus)
+
+ if self.distribution_strategy:
+ logging.info("For training, using distribution strategy: %s",
+ self.distribution_strategy)
+ else:
+ logging.info("Not using any distribution strategy.")
+
+ performance.set_mixed_precision_policy(params["dtype"])
+
+ @property
+ def use_tpu(self):
+ if self.distribution_strategy:
+ return isinstance(self.distribution_strategy, tf.distribute.TPUStrategy)
+ return False
+
+ def train(self):
+ """Trains the model."""
+ params = self.params
+ flags_obj = self.flags_obj
+ # Sets config options.
+ keras_utils.set_session_config(enable_xla=flags_obj.enable_xla)
+
+ _ensure_dir(flags_obj.model_dir)
+ with distribute_utils.get_strategy_scope(self.distribution_strategy):
+ model = transformer.create_model(params, is_train=True)
+ opt = self._create_optimizer()
+
+ current_step = 0
+ checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
+ latest_checkpoint = tf.train.latest_checkpoint(flags_obj.model_dir)
+ if latest_checkpoint:
+ checkpoint.restore(latest_checkpoint)
+ logging.info("Loaded checkpoint %s", latest_checkpoint)
+ current_step = opt.iterations.numpy()
+
+ if params["use_ctl"]:
+ train_loss_metric = tf_keras.metrics.Mean(
+ "training_loss", dtype=tf.float32)
+ if params["enable_tensorboard"]:
+ summary_writer = tf.summary.create_file_writer(
+ os.path.join(flags_obj.model_dir, "summary"))
+ else:
+ summary_writer = tf.summary.create_noop_writer()
+ train_metrics = [train_loss_metric]
+ if params["enable_metrics_in_training"]:
+ train_metrics = train_metrics + model.metrics
+ else:
+ model.compile(opt)
+
+ model.summary()
+
+ if self.use_tpu:
+ # Different from experimental_distribute_dataset,
+ # distribute_datasets_from_function requires
+ # per-replica/local batch size.
+ params["batch_size"] /= self.distribution_strategy.num_replicas_in_sync
+ train_ds = (
+ self.distribution_strategy.distribute_datasets_from_function(
+ lambda ctx: data_pipeline.train_input_fn(params, ctx)))
+ else:
+ train_ds = data_pipeline.train_input_fn(params)
+ map_data_fn = data_pipeline.map_data_for_transformer_fn
+ train_ds = train_ds.map(
+ map_data_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ if params["use_ctl"]:
+ train_ds_iterator = iter(train_ds)
+
+ callbacks = self._create_callbacks(flags_obj.model_dir, params)
+
+ # Only TimeHistory callback is supported for CTL
+ if params["use_ctl"]:
+ callbacks = [cb for cb in callbacks
+ if isinstance(cb, keras_utils.TimeHistory)]
+
+ @tf.function
+ def train_steps(iterator, steps):
+ """Training steps function for TPU runs.
+
+ Args:
+ iterator: The input iterator of the training dataset.
+ steps: An integer, the number of training steps.
+
+ Returns:
+ A float, the loss value.
+ """
+
+ def _step_fn(inputs):
+ """Per-replica step function."""
+ inputs, targets = inputs
+ with tf.GradientTape() as tape:
+ logits = model([inputs, targets], training=True)
+ loss = metrics.transformer_loss(logits, targets,
+ params["label_smoothing"],
+ params["vocab_size"])
+ # Scales the loss, which results in using the average loss across all
+ # of the replicas for backprop.
+ scaled_loss = loss / self.distribution_strategy.num_replicas_in_sync
+
+ # De-dupes variables due to keras tracking issues.
+ tvars = list({id(v): v for v in model.trainable_variables}.values())
+ grads = tape.gradient(scaled_loss, tvars)
+ opt.apply_gradients(zip(grads, tvars))
+ # For reporting, the metric takes the mean of losses.
+ train_loss_metric.update_state(loss)
+
+ for _ in tf.range(steps):
+ train_loss_metric.reset_states()
+ self.distribution_strategy.run(
+ _step_fn, args=(next(iterator),))
+
+ cased_score, uncased_score = None, None
+ cased_score_history, uncased_score_history = [], []
+ while current_step < flags_obj.train_steps:
+ remaining_steps = flags_obj.train_steps - current_step
+ train_steps_per_eval = (
+ remaining_steps if remaining_steps < flags_obj.steps_between_evals
+ else flags_obj.steps_between_evals)
+ current_iteration = current_step // flags_obj.steps_between_evals
+
+ logging.info(
+ "Start train iteration at global step:{}".format(current_step))
+ history = None
+ if params["use_ctl"]:
+ if not self.use_tpu:
+ raise NotImplementedError(
+ "Custom training loop on GPUs is not implemented.")
+
+ # Runs training steps.
+ with summary_writer.as_default():
+ for cb in callbacks:
+ cb.on_epoch_begin(current_iteration)
+ cb.on_batch_begin(0)
+
+ train_steps(
+ train_ds_iterator,
+ tf.convert_to_tensor(train_steps_per_eval, dtype=tf.int32))
+ current_step += train_steps_per_eval
+ train_loss = train_loss_metric.result().numpy().astype(float)
+ logging.info("Train Step: %d/%d / loss = %s", current_step,
+ flags_obj.train_steps, train_loss)
+
+ for cb in callbacks:
+ cb.on_batch_end(train_steps_per_eval - 1)
+ cb.on_epoch_end(current_iteration)
+
+ if params["enable_tensorboard"]:
+ for metric_obj in train_metrics:
+ tf.summary.scalar(metric_obj.name, metric_obj.result(),
+ current_step)
+ summary_writer.flush()
+
+ for cb in callbacks:
+ cb.on_train_end()
+
+ if flags_obj.enable_checkpointing:
+ # avoid check-pointing when running for benchmarking.
+ checkpoint_name = checkpoint.save(
+ os.path.join(flags_obj.model_dir,
+ "ctl_step_{}.ckpt".format(current_step)))
+ logging.info("Saved checkpoint to %s", checkpoint_name)
+ else:
+ if self.use_tpu:
+ raise NotImplementedError(
+ "Keras model.fit on TPUs is not implemented.")
+ history = model.fit(
+ train_ds,
+ initial_epoch=current_iteration,
+ epochs=current_iteration + 1,
+ steps_per_epoch=train_steps_per_eval,
+ callbacks=callbacks,
+ # If TimeHistory is enabled, progress bar would be messy. Increase
+ # the verbose level to get rid of it.
+ verbose=(2 if flags_obj.enable_time_history else 1))
+ current_step += train_steps_per_eval
+ logging.info("Train history: {}".format(history.history))
+
+ logging.info("End train iteration at global step:{}".format(current_step))
+
+ if (flags_obj.bleu_source and flags_obj.bleu_ref):
+ uncased_score, cased_score = self.eval()
+ cased_score_history.append([current_iteration + 1, cased_score])
+ uncased_score_history.append([current_iteration + 1, uncased_score])
+
+ stats = ({
+ "loss": train_loss
+ } if history is None else {})
+ misc.update_stats(history, stats, callbacks)
+ if uncased_score and cased_score:
+ stats["bleu_uncased"] = uncased_score
+ stats["bleu_cased"] = cased_score
+ stats["bleu_uncased_history"] = uncased_score_history
+ stats["bleu_cased_history"] = cased_score_history
+ return stats
+
+ def eval(self):
+ """Evaluates the model."""
+ distribution_strategy = self.distribution_strategy if self.use_tpu else None
+
+ # We only want to create the model under DS scope for TPU case.
+ # When 'distribution_strategy' is None, a no-op DummyContextManager will
+ # be used.
+ with distribute_utils.get_strategy_scope(distribution_strategy):
+ if not self.predict_model:
+ self.predict_model = transformer.create_model(self.params, False)
+ self._load_weights_if_possible(
+ self.predict_model,
+ tf.train.latest_checkpoint(self.flags_obj.model_dir))
+ self.predict_model.summary()
+ return evaluate_and_log_bleu(
+ self.predict_model, self.params, self.flags_obj.bleu_source,
+ self.flags_obj.bleu_ref, self.flags_obj.vocab_file,
+ distribution_strategy)
+
+ def predict(self):
+ """Predicts result from the model."""
+ params = self.params
+ flags_obj = self.flags_obj
+
+ with tf.name_scope("model"):
+ model = transformer.create_model(params, is_train=False)
+ self._load_weights_if_possible(
+ model, tf.train.latest_checkpoint(self.flags_obj.model_dir))
+ model.summary()
+ subtokenizer = tokenizer.Subtokenizer(flags_obj.vocab_file)
+
+ ds = data_pipeline.eval_input_fn(params)
+ ds = ds.map(lambda x, y: x).take(_SINGLE_SAMPLE)
+ ret = model.predict(ds)
+ val_outputs, _ = ret
+ length = len(val_outputs)
+ for i in range(length):
+ translate.translate_from_input(val_outputs[i], subtokenizer)
+
+ def _create_callbacks(self, cur_log_dir, params):
+ """Creates a list of callbacks."""
+ callbacks = misc.get_callbacks()
+ if params["enable_checkpointing"]:
+ ckpt_full_path = os.path.join(cur_log_dir, "cp-{epoch:04d}.ckpt")
+ callbacks.append(
+ tf_keras.callbacks.ModelCheckpoint(
+ ckpt_full_path, save_weights_only=params["save_weights_only"]))
+ return callbacks
+
+ def _load_weights_if_possible(self, model, init_weight_path=None):
+ """Loads model weights when it is provided."""
+ if init_weight_path:
+ logging.info("Load weights: {}".format(init_weight_path))
+ if self.use_tpu:
+ checkpoint = tf.train.Checkpoint(
+ model=model, optimizer=self._create_optimizer())
+ checkpoint.restore(init_weight_path)
+ else:
+ model.load_weights(init_weight_path)
+ else:
+ logging.info("Weights not loaded from path:{}".format(init_weight_path))
+
+ def _create_optimizer(self):
+ """Creates optimizer."""
+ params = self.params
+ lr_schedule = optimizer.LearningRateSchedule(
+ params["learning_rate"], params["hidden_size"],
+ params["learning_rate_warmup_steps"])
+ opt = tf_keras.optimizers.Adam(
+ lr_schedule,
+ params["optimizer_adam_beta1"],
+ params["optimizer_adam_beta2"],
+ epsilon=params["optimizer_adam_epsilon"])
+
+ opt = performance.configure_optimizer(
+ opt,
+ use_float16=params["dtype"] == tf.float16,
+ loss_scale=flags_core.get_loss_scale(
+ self.flags_obj, default_for_fp16="dynamic"))
+
+ return opt
+
+
+def _ensure_dir(log_dir):
+ """Makes log dir if not existed."""
+ if not tf.io.gfile.exists(log_dir):
+ tf.io.gfile.makedirs(log_dir)
+
+
+def main(_):
+ flags_obj = flags.FLAGS
+ if flags_obj.enable_mlir_bridge:
+ tf.config.experimental.enable_mlir_bridge()
+ task = TransformerTask(flags_obj)
+
+ # Execute flag override logic for better model performance
+ if flags_obj.tf_gpu_thread_mode:
+ keras_utils.set_gpu_thread_mode_and_count(
+ per_gpu_thread_count=flags_obj.per_gpu_thread_count,
+ gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
+ num_gpus=flags_obj.num_gpus,
+ datasets_num_private_threads=flags_obj.datasets_num_private_threads)
+
+ if flags_obj.mode == "train":
+ task.train()
+ elif flags_obj.mode == "predict":
+ task.predict()
+ elif flags_obj.mode == "eval":
+ task.eval()
+ else:
+ raise ValueError("Invalid mode {}".format(flags_obj.mode))
+
+
+if __name__ == "__main__":
+ logging.set_verbosity(logging.INFO)
+ misc.define_transformer_flags()
+ app.run(main)
diff --git a/modeling/official/legacy/transformer/transformer_main_test.py b/modeling/official/legacy/transformer/transformer_main_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a669eba095e10a623f1b9ae5e98ce5d33ef57f6
--- /dev/null
+++ b/modeling/official/legacy/transformer/transformer_main_test.py
@@ -0,0 +1,193 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test Transformer model."""
+
+import os
+import re
+import sys
+import unittest
+
+from absl import flags
+from absl.testing import flagsaver
+import tensorflow as tf, tf_keras
+from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
+from official.legacy.transformer import misc
+from official.legacy.transformer import transformer_main
+
+FLAGS = flags.FLAGS
+FIXED_TIMESTAMP = 'my_time_stamp'
+WEIGHT_PATTERN = re.compile(r'weights-epoch-.+\.hdf5')
+
+
+def _generate_file(filepath, lines):
+ with open(filepath, 'w') as f:
+ for l in lines:
+ f.write('{}\n'.format(l))
+
+
+class TransformerTaskTest(tf.test.TestCase):
+ local_flags = None
+
+ def setUp(self): # pylint: disable=g-missing-super-call
+ temp_dir = self.get_temp_dir()
+ if TransformerTaskTest.local_flags is None:
+ misc.define_transformer_flags()
+ # Loads flags, array cannot be blank.
+ flags.FLAGS(['foo'])
+ TransformerTaskTest.local_flags = flagsaver.save_flag_values()
+ else:
+ flagsaver.restore_flag_values(TransformerTaskTest.local_flags)
+ FLAGS.model_dir = os.path.join(temp_dir, FIXED_TIMESTAMP)
+ FLAGS.param_set = 'tiny'
+ FLAGS.use_synthetic_data = True
+ FLAGS.steps_between_evals = 1
+ FLAGS.train_steps = 1
+ FLAGS.validation_steps = 1
+ FLAGS.batch_size = 4
+ FLAGS.max_length = 1
+ FLAGS.num_gpus = 1
+ FLAGS.distribution_strategy = 'off'
+ FLAGS.dtype = 'fp32'
+ self.model_dir = FLAGS.model_dir
+ self.temp_dir = temp_dir
+ self.vocab_file = os.path.join(temp_dir, 'vocab')
+ self.vocab_size = misc.get_model_params(FLAGS.param_set, 0)['vocab_size']
+ self.bleu_source = os.path.join(temp_dir, 'bleu_source')
+ self.bleu_ref = os.path.join(temp_dir, 'bleu_ref')
+ self.orig_policy = (
+ tf.compat.v2.keras.mixed_precision.global_policy())
+
+ def tearDown(self): # pylint: disable=g-missing-super-call
+ tf.compat.v2.keras.mixed_precision.set_global_policy(self.orig_policy)
+
+ def _assert_exists(self, filepath):
+ self.assertTrue(os.path.exists(filepath))
+
+ def test_train_no_dist_strat(self):
+ if context.num_gpus() >= 2:
+ self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
+ t = transformer_main.TransformerTask(FLAGS)
+ t.train()
+
+ def test_train_save_full_model(self):
+ if context.num_gpus() >= 2:
+ self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
+ FLAGS.save_weights_only = False
+ t = transformer_main.TransformerTask(FLAGS)
+ t.train()
+
+ def test_train_static_batch(self):
+ if context.num_gpus() >= 2:
+ self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
+ FLAGS.distribution_strategy = 'one_device'
+ if tf.test.is_built_with_cuda():
+ FLAGS.num_gpus = 1
+ else:
+ FLAGS.num_gpus = 0
+ FLAGS.static_batch = True
+ t = transformer_main.TransformerTask(FLAGS)
+ t.train()
+
+ @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
+ def test_train_1_gpu_with_dist_strat(self):
+ FLAGS.distribution_strategy = 'one_device'
+ t = transformer_main.TransformerTask(FLAGS)
+ t.train()
+
+ @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
+ def test_train_fp16(self):
+ FLAGS.distribution_strategy = 'one_device'
+ FLAGS.dtype = 'fp16'
+ t = transformer_main.TransformerTask(FLAGS)
+ t.train()
+
+ @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
+ def test_train_2_gpu(self):
+ if context.num_gpus() < 2:
+ self.skipTest(
+ '{} GPUs are not available for this test. {} GPUs are available'
+ .format(2, context.num_gpus()))
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.num_gpus = 2
+ FLAGS.param_set = 'base'
+ t = transformer_main.TransformerTask(FLAGS)
+ t.train()
+
+ @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
+ def test_train_2_gpu_fp16(self):
+ if context.num_gpus() < 2:
+ self.skipTest(
+ '{} GPUs are not available for this test. {} GPUs are available'
+ .format(2, context.num_gpus()))
+ FLAGS.distribution_strategy = 'mirrored'
+ FLAGS.num_gpus = 2
+ FLAGS.param_set = 'base'
+ FLAGS.dtype = 'fp16'
+ t = transformer_main.TransformerTask(FLAGS)
+ t.train()
+
+ def _prepare_files_and_flags(self, *extra_flags):
+ # Make log dir.
+ if not os.path.exists(self.temp_dir):
+ os.makedirs(self.temp_dir)
+
+ # Fake vocab, bleu_source and bleu_ref.
+ tokens = [
+ "''", "''", "'_'", "'a'", "'b'", "'c'", "'d'", "'a_'", "'b_'",
+ "'c_'", "'d_'"
+ ]
+ tokens += ["'{}'".format(i) for i in range(self.vocab_size - len(tokens))]
+ _generate_file(self.vocab_file, tokens)
+ _generate_file(self.bleu_source, ['a b', 'c d'])
+ _generate_file(self.bleu_ref, ['a b', 'd c'])
+
+ # Update flags.
+ update_flags = [
+ 'ignored_program_name',
+ '--vocab_file={}'.format(self.vocab_file),
+ '--bleu_source={}'.format(self.bleu_source),
+ '--bleu_ref={}'.format(self.bleu_ref),
+ ]
+ if extra_flags:
+ update_flags.extend(extra_flags)
+ FLAGS(update_flags)
+
+ def test_predict(self):
+ if context.num_gpus() >= 2:
+ self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
+ self._prepare_files_and_flags()
+ t = transformer_main.TransformerTask(FLAGS)
+ t.predict()
+
+ @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
+ def test_predict_fp16(self):
+ if context.num_gpus() >= 2:
+ self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
+ self._prepare_files_and_flags('--dtype=fp16')
+ t = transformer_main.TransformerTask(FLAGS)
+ t.predict()
+
+ def test_eval(self):
+ if context.num_gpus() >= 2:
+ self.skipTest('No need to test 2+ GPUs without a distribution strategy.')
+ if 'test_xla' in sys.argv[0]:
+ self.skipTest('TODO(xla): Make this test faster under XLA.')
+ self._prepare_files_and_flags()
+ t = transformer_main.TransformerTask(FLAGS)
+ t.eval()
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/legacy/transformer/transformer_test.py b/modeling/official/legacy/transformer/transformer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..5180735925525c2124f5e4846b39793feff77635
--- /dev/null
+++ b/modeling/official/legacy/transformer/transformer_test.py
@@ -0,0 +1,98 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test Transformer model."""
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.transformer import model_params
+from official.legacy.transformer import transformer
+
+
+class TransformerV2Test(tf.test.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.params = params = model_params.TINY_PARAMS
+ params["batch_size"] = params["default_batch_size"] = 16
+ params["use_synthetic_data"] = True
+ params["hidden_size"] = 12
+ params["num_hidden_layers"] = 2
+ params["filter_size"] = 14
+ params["num_heads"] = 2
+ params["vocab_size"] = 41
+ params["extra_decode_length"] = 2
+ params["beam_size"] = 3
+ params["dtype"] = tf.float32
+
+ def test_create_model_train(self):
+ model = transformer.create_model(self.params, True)
+ inputs, outputs = model.inputs, model.outputs
+ self.assertEqual(len(inputs), 2)
+ self.assertEqual(len(outputs), 1)
+ self.assertEqual(inputs[0].shape.as_list(), [None, None])
+ self.assertEqual(inputs[0].dtype, tf.int64)
+ self.assertEqual(inputs[1].shape.as_list(), [None, None])
+ self.assertEqual(inputs[1].dtype, tf.int64)
+ self.assertEqual(outputs[0].shape.as_list(), [None, None, 41])
+ self.assertEqual(outputs[0].dtype, tf.float32)
+
+ def test_create_model_not_train(self):
+ model = transformer.create_model(self.params, False)
+ inputs, outputs = model.inputs, model.outputs
+ self.assertEqual(len(inputs), 1)
+ self.assertEqual(len(outputs), 2)
+ self.assertEqual(inputs[0].shape.as_list(), [None, None])
+ self.assertEqual(inputs[0].dtype, tf.int64)
+ self.assertEqual(outputs[0].shape.as_list(), [None, None])
+ self.assertEqual(outputs[0].dtype, tf.int32)
+ self.assertEqual(outputs[1].shape.as_list(), [None])
+ self.assertEqual(outputs[1].dtype, tf.float32)
+
+ def test_export(self):
+ model = transformer.Transformer(self.params, name="transformer_v2")
+ export_dir = self.get_temp_dir()
+ batch_size = 5
+ max_length = 6
+
+ class SaveModule(tf.Module):
+
+ def __init__(self, model):
+ super(SaveModule, self).__init__()
+ self.model = model
+
+ @tf.function
+ def serve(self, x):
+ return self.model.call([x], training=False)
+
+ save_module = SaveModule(model)
+ tensor_shape = (None, None)
+ sample_input = tf.zeros((batch_size, max_length), dtype=tf.int64)
+ _ = save_module.serve(sample_input)
+ signatures = dict(
+ serving_default=save_module.serve.get_concrete_function(
+ tf.TensorSpec(shape=tensor_shape, dtype=tf.int64, name="x")))
+ tf.saved_model.save(save_module, export_dir, signatures=signatures)
+ imported = tf.saved_model.load(export_dir)
+ serving_fn = imported.signatures["serving_default"]
+ all_outputs = serving_fn(sample_input)
+ output = all_outputs["outputs"]
+ output_shapes = output.shape.as_list()
+ self.assertEqual(output_shapes[0], batch_size)
+ self.assertEqual(output_shapes[1],
+ max_length + model.params["extra_decode_length"])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/legacy/transformer/translate.py b/modeling/official/legacy/transformer/translate.py
new file mode 100644
index 0000000000000000000000000000000000000000..d99c5a731890882fe7f182051aafe4447a0f785b
--- /dev/null
+++ b/modeling/official/legacy/transformer/translate.py
@@ -0,0 +1,190 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Translate text or files using trained transformer model."""
+
+# Import libraries
+from absl import logging
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from official.legacy.transformer.utils import tokenizer
+
+_EXTRA_DECODE_LENGTH = 100
+_BEAM_SIZE = 4
+_ALPHA = 0.6
+
+
+def _get_sorted_inputs(filename):
+ """Read and sort lines from the file sorted by decreasing length.
+
+ Args:
+ filename: String name of file to read inputs from.
+ Returns:
+ Sorted list of inputs, and dictionary mapping original index->sorted index
+ of each element.
+ """
+ with tf.io.gfile.GFile(filename) as f:
+ records = f.read().split("\n")
+ inputs = [record.strip() for record in records]
+ if not inputs[-1]:
+ inputs.pop()
+
+ input_lens = [(i, len(line.split())) for i, line in enumerate(inputs)]
+ sorted_input_lens = sorted(input_lens, key=lambda x: x[1], reverse=True)
+
+ sorted_inputs = [None] * len(sorted_input_lens)
+ sorted_keys = [0] * len(sorted_input_lens)
+ for i, (index, _) in enumerate(sorted_input_lens):
+ sorted_inputs[i] = inputs[index]
+ sorted_keys[index] = i
+ return sorted_inputs, sorted_keys
+
+
+def _encode_and_add_eos(line, subtokenizer):
+ """Encode line with subtokenizer, and add EOS id to the end."""
+ return subtokenizer.encode(line) + [tokenizer.EOS_ID]
+
+
+def _trim_and_decode(ids, subtokenizer):
+ """Trim EOS and PAD tokens from ids, and decode to return a string."""
+ try:
+ index = list(ids).index(tokenizer.EOS_ID)
+ return subtokenizer.decode(ids[:index])
+ except ValueError: # No EOS found in sequence
+ return subtokenizer.decode(ids)
+
+
+def translate_file(model,
+ params,
+ subtokenizer,
+ input_file,
+ output_file=None,
+ print_all_translations=True,
+ distribution_strategy=None):
+ """Translate lines in file, and save to output file if specified.
+
+ Args:
+ model: A Keras model, used to generate the translations.
+ params: A dictionary, containing the translation related parameters.
+ subtokenizer: A subtokenizer object, used for encoding and decoding source
+ and translated lines.
+ input_file: A file containing lines to translate.
+ output_file: A file that stores the generated translations.
+ print_all_translations: A bool. If true, all translations are printed to
+ stdout.
+ distribution_strategy: A distribution strategy, used to perform inference
+ directly with tf.function instead of Keras model.predict().
+
+ Raises:
+ ValueError: if output file is invalid.
+ """
+ batch_size = params["decode_batch_size"]
+
+ # Read and sort inputs by length. Keep dictionary (original index-->new index
+ # in sorted list) to write translations in the original order.
+ sorted_inputs, sorted_keys = _get_sorted_inputs(input_file)
+ total_samples = len(sorted_inputs)
+ num_decode_batches = (total_samples - 1) // batch_size + 1
+
+ def input_generator():
+ """Yield encoded strings from sorted_inputs."""
+ for i in range(num_decode_batches):
+ lines = [
+ sorted_inputs[j + i * batch_size]
+ for j in range(batch_size)
+ if j + i * batch_size < total_samples
+ ]
+ lines = [_encode_and_add_eos(l, subtokenizer) for l in lines]
+ if distribution_strategy:
+ for j in range(batch_size - len(lines)):
+ lines.append([tokenizer.EOS_ID])
+ batch = tf_keras.preprocessing.sequence.pad_sequences(
+ lines,
+ maxlen=params["decode_max_length"],
+ dtype="int32",
+ padding="post")
+ logging.info("Decoding batch %d out of %d.", i, num_decode_batches)
+ yield batch
+
+ @tf.function
+ def predict_step(inputs):
+ """Decoding step function for TPU runs."""
+
+ def _step_fn(inputs):
+ """Per replica step function."""
+ tag = inputs[0]
+ val_inputs = inputs[1]
+ val_outputs, _ = model([val_inputs], training=False)
+ return tag, val_outputs
+
+ return distribution_strategy.run(_step_fn, args=(inputs,))
+
+ translations = []
+ if distribution_strategy:
+ num_replicas = distribution_strategy.num_replicas_in_sync
+ local_batch_size = params["decode_batch_size"] // num_replicas
+ for i, text in enumerate(input_generator()):
+ if distribution_strategy:
+ text = np.reshape(text, [num_replicas, local_batch_size, -1])
+ # Add tag to the input of each replica with the reordering logic after
+ # outputs, to ensure the output order matches the input order.
+ text = tf.constant(text)
+
+ @tf.function
+ def text_as_per_replica():
+ replica_context = tf.distribute.get_replica_context()
+ replica_id = replica_context.replica_id_in_sync_group
+ return replica_id, text[replica_id] # pylint: disable=cell-var-from-loop
+
+ text = distribution_strategy.run(text_as_per_replica)
+ outputs = distribution_strategy.experimental_local_results(
+ predict_step(text))
+ val_outputs = [output for _, output in outputs]
+
+ val_outputs = np.reshape(val_outputs, [params["decode_batch_size"], -1])
+ else:
+ val_outputs, _ = model.predict(text)
+
+ length = len(val_outputs)
+ for j in range(length):
+ if j + i * batch_size < total_samples:
+ translation = _trim_and_decode(val_outputs[j], subtokenizer)
+ translations.append(translation)
+ if print_all_translations:
+ logging.info("Translating:\n\tInput: %s\n\tOutput: %s",
+ sorted_inputs[j + i * batch_size], translation)
+
+ # Write translations in the order they appeared in the original file.
+ if output_file is not None:
+ if tf.io.gfile.isdir(output_file):
+ raise ValueError("File output is a directory, will not save outputs to "
+ "file.")
+ logging.info("Writing to file %s", output_file)
+ with tf.io.gfile.GFile(output_file, "w") as f:
+ for i in sorted_keys:
+ f.write("%s\n" % translations[i])
+
+
+def translate_from_text(model, subtokenizer, txt):
+ encoded_txt = _encode_and_add_eos(txt, subtokenizer)
+ result = model.predict(encoded_txt)
+ outputs = result["outputs"]
+ logging.info("Original: \"%s\"", txt)
+ translate_from_input(outputs, subtokenizer)
+
+
+def translate_from_input(outputs, subtokenizer):
+ translation = _trim_and_decode(outputs, subtokenizer)
+ logging.info("Translation: \"%s\"", translation)
diff --git a/modeling/official/legacy/transformer/utils/__init__.py b/modeling/official/legacy/transformer/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/legacy/transformer/utils/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/legacy/transformer/utils/metrics.py b/modeling/official/legacy/transformer/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..53e252e1e74aed24160e9162f90525862d5f2657
--- /dev/null
+++ b/modeling/official/legacy/transformer/utils/metrics.py
@@ -0,0 +1,491 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions for calculating loss, accuracy, and other model metrics.
+
+Metrics:
+ - Padded loss, accuracy, and negative log perplexity. Source:
+ https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/metrics.py
+ - BLEU approximation. Source:
+ https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py
+ - ROUGE score. Source:
+ https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/rouge.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import math
+
+import numpy as np
+import six
+from six.moves import xrange # pylint: disable=redefined-builtin
+import tensorflow.compat.v1 as tf
+
+
+def _pad_tensors_to_same_length(x, y):
+ """Pad x and y so that the results have the same length (second dimension)."""
+ with tf.name_scope("pad_to_same_length"):
+ x_length = tf.shape(x)[1]
+ y_length = tf.shape(y)[1]
+
+ max_length = tf.maximum(x_length, y_length)
+
+ x = tf.pad(x, [[0, 0], [0, max_length - x_length], [0, 0]])
+ y = tf.pad(y, [[0, 0], [0, max_length - y_length]])
+ return x, y
+
+
+def padded_cross_entropy_loss(logits, labels, smoothing, vocab_size):
+ """Calculate cross entropy loss while ignoring padding.
+
+ Args:
+ logits: Tensor of size [batch_size, length_logits, vocab_size]
+ labels: Tensor of size [batch_size, length_labels]
+ smoothing: Label smoothing constant, used to determine the on and off values
+ vocab_size: int size of the vocabulary
+ Returns:
+ Returns the cross entropy loss and weight tensors: float32 tensors with
+ shape [batch_size, max(length_logits, length_labels)]
+ """
+ with tf.name_scope("loss", values=[logits, labels]):
+ logits, labels = _pad_tensors_to_same_length(logits, labels)
+
+ # Calculate smoothing cross entropy
+ with tf.name_scope("smoothing_cross_entropy", values=[logits, labels]):
+ confidence = 1.0 - smoothing
+ low_confidence = (1.0 - confidence) / tf.cast(vocab_size - 1, tf.float32)
+ soft_targets = tf.one_hot(
+ tf.cast(labels, tf.int32),
+ depth=vocab_size,
+ on_value=confidence,
+ off_value=low_confidence)
+ xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(
+ logits=logits, labels=soft_targets)
+
+ # Calculate the best (lowest) possible value of cross entropy, and
+ # subtract from the cross entropy loss.
+ normalizing_constant = -(
+ confidence * tf.log(confidence) + tf.cast(vocab_size - 1, tf.float32)
+ * low_confidence * tf.log(low_confidence + 1e-20))
+ xentropy -= normalizing_constant
+
+ weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
+ return xentropy * weights, weights
+
+
+def _convert_to_eval_metric(metric_fn):
+ """Wrap a metric fn that returns scores and weights as an eval metric fn.
+
+ The input metric_fn returns values for the current batch. The wrapper
+ aggregates the return values collected over all of the batches evaluated.
+
+ Args:
+ metric_fn: function that returns scores and weights for the current batch's
+ logits and predicted labels.
+
+ Returns:
+ function that aggregates the scores and weights from metric_fn.
+ """
+ def problem_metric_fn(*args):
+ """Returns an aggregation of the metric_fn's returned values."""
+ (scores, weights) = metric_fn(*args)
+
+ # The tf.metrics.mean function assures correct aggregation.
+ return tf.metrics.mean(scores, weights)
+ return problem_metric_fn
+
+
+def get_eval_metrics(logits, labels, params):
+ """Return dictionary of model evaluation metrics."""
+ metrics = {
+ "accuracy": _convert_to_eval_metric(padded_accuracy)(logits, labels),
+ "accuracy_top5": _convert_to_eval_metric(padded_accuracy_top5)(
+ logits, labels),
+ "accuracy_per_sequence": _convert_to_eval_metric(
+ padded_sequence_accuracy)(logits, labels),
+ "neg_log_perplexity": _convert_to_eval_metric(padded_neg_log_perplexity)(
+ logits, labels, params["vocab_size"]),
+ }
+
+ if not params["use_tpu"]:
+ # TPU does not support tf.py_func
+ metrics.update({
+ "approx_bleu_score": _convert_to_eval_metric(
+ bleu_score)(logits, labels),
+ "rouge_2_fscore": _convert_to_eval_metric(
+ rouge_2_fscore)(logits, labels),
+ "rouge_L_fscore": _convert_to_eval_metric(
+ rouge_l_fscore)(logits, labels),
+ })
+
+ # Prefix each of the metric names with "metrics/". This allows the metric
+ # graphs to display under the "metrics" category in TensorBoard.
+ metrics = {"metrics/%s" % k: v for k, v in six.iteritems(metrics)}
+ return metrics
+
+
+def padded_accuracy(logits, labels):
+ """Percentage of times that predictions matches labels on non-0s."""
+ with tf.variable_scope("padded_accuracy", values=[logits, labels]):
+ logits, labels = _pad_tensors_to_same_length(logits, labels)
+ weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
+ outputs = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
+ padded_labels = tf.cast(labels, tf.int32)
+ return tf.cast(tf.equal(outputs, padded_labels), tf.float32), weights
+
+
+def padded_accuracy_topk(logits, labels, k):
+ """Percentage of times that top-k predictions matches labels on non-0s."""
+ with tf.variable_scope("padded_accuracy_topk", values=[logits, labels]):
+ logits, labels = _pad_tensors_to_same_length(logits, labels)
+ weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
+ effective_k = tf.minimum(k, tf.shape(logits)[-1])
+ _, outputs = tf.nn.top_k(logits, k=effective_k)
+ outputs = tf.cast(outputs, tf.int32)
+ padded_labels = tf.cast(labels, tf.int32)
+ padded_labels = tf.expand_dims(padded_labels, axis=-1)
+ padded_labels += tf.zeros_like(outputs) # Pad to same shape.
+ same = tf.cast(tf.equal(outputs, padded_labels), tf.float32)
+ same_topk = tf.reduce_sum(same, axis=-1)
+ return same_topk, weights
+
+
+def padded_accuracy_top5(logits, labels):
+ return padded_accuracy_topk(logits, labels, 5)
+
+
+def padded_sequence_accuracy(logits, labels):
+ """Percentage of times that predictions matches labels everywhere (non-0)."""
+ with tf.variable_scope("padded_sequence_accuracy", values=[logits, labels]):
+ logits, labels = _pad_tensors_to_same_length(logits, labels)
+ weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
+ outputs = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
+ padded_labels = tf.cast(labels, tf.int32)
+ not_correct = (tf.cast(tf.not_equal(outputs, padded_labels), tf.float32) *
+ weights)
+ axis = list(range(1, len(outputs.get_shape())))
+ correct_seq = 1.0 - tf.minimum(1.0, tf.reduce_sum(not_correct, axis=axis))
+ return correct_seq, tf.constant(1.0)
+
+
+def padded_neg_log_perplexity(logits, labels, vocab_size):
+ """Average log-perplexity excluding padding 0s. No smoothing."""
+ num, den = padded_cross_entropy_loss(logits, labels, 0, vocab_size)
+ return -num, den
+
+
+def bleu_score(logits, labels):
+ """Approximate BLEU score computation between labels and predictions.
+
+ An approximate BLEU scoring method since we do not glue word pieces or
+ decode the ids and tokenize the output. By default, we use ngram order of 4
+ and use brevity penalty. Also, this does not have beam search.
+
+ Args:
+ logits: Tensor of size [batch_size, length_logits, vocab_size]
+ labels: Tensor of size [batch-size, length_labels]
+
+ Returns:
+ bleu: int, approx bleu score
+ """
+ predictions = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
+ # TODO: Look into removing use of py_func # pylint: disable=g-bad-todo
+ bleu = tf.py_func(compute_bleu, (labels, predictions), tf.float32)
+ return bleu, tf.constant(1.0)
+
+
+def _get_ngrams_with_counter(segment, max_order):
+ """Extracts all n-grams up to a given maximum order from an input segment.
+
+ Args:
+ segment: text segment from which n-grams will be extracted.
+ max_order: maximum length in tokens of the n-grams returned by this
+ methods.
+
+ Returns:
+ The Counter containing all n-grams upto max_order in segment
+ with a count of how many times each n-gram occurred.
+ """
+ ngram_counts = collections.Counter()
+ for order in xrange(1, max_order + 1):
+ for i in xrange(0, len(segment) - order + 1):
+ ngram = tuple(segment[i:i + order])
+ ngram_counts[ngram] += 1
+ return ngram_counts
+
+
+def compute_bleu(reference_corpus, translation_corpus, max_order=4,
+ use_bp=True):
+ """Computes BLEU score of translated segments against one or more references.
+
+ Args:
+ reference_corpus: list of references for each translation. Each
+ reference should be tokenized into a list of tokens.
+ translation_corpus: list of translations to score. Each translation
+ should be tokenized into a list of tokens.
+ max_order: Maximum n-gram order to use when computing BLEU score.
+ use_bp: boolean, whether to apply brevity penalty.
+
+ Returns:
+ BLEU score.
+ """
+ reference_length = 0
+ translation_length = 0
+ bp = 1.0
+ geo_mean = 0
+
+ matches_by_order = [0] * max_order
+ possible_matches_by_order = [0] * max_order
+ precisions = []
+
+ for (references, translations) in zip(reference_corpus, translation_corpus):
+ reference_length += len(references)
+ translation_length += len(translations)
+ ref_ngram_counts = _get_ngrams_with_counter(references, max_order)
+ translation_ngram_counts = _get_ngrams_with_counter(translations, max_order)
+
+ overlap = dict((ngram,
+ min(count, translation_ngram_counts[ngram]))
+ for ngram, count in ref_ngram_counts.items())
+
+ for ngram in overlap:
+ matches_by_order[len(ngram) - 1] += overlap[ngram]
+ for ngram in translation_ngram_counts:
+ possible_matches_by_order[len(ngram) - 1] += translation_ngram_counts[
+ ngram]
+
+ precisions = [0] * max_order
+ smooth = 1.0
+
+ for i in xrange(0, max_order):
+ if possible_matches_by_order[i] > 0:
+ precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[i]
+ if matches_by_order[i] > 0:
+ precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[
+ i]
+ else:
+ smooth *= 2
+ precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
+ else:
+ precisions[i] = 0.0
+
+ if max(precisions) > 0:
+ p_log_sum = sum(math.log(p) for p in precisions if p)
+ geo_mean = math.exp(p_log_sum / max_order)
+
+ if use_bp:
+ ratio = translation_length / reference_length
+ bp = math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0
+ bleu = geo_mean * bp
+ return np.float32(bleu)
+
+
+def rouge_2_fscore(logits, labels):
+ """ROUGE-2 F1 score computation between labels and predictions.
+
+ This is an approximate ROUGE scoring method since we do not glue word pieces
+ or decode the ids and tokenize the output.
+
+ Args:
+ logits: tensor, model predictions
+ labels: tensor, gold output.
+
+ Returns:
+ rouge2_fscore: approx rouge-2 f1 score.
+ """
+ predictions = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
+ # TODO: Look into removing use of py_func # pylint: disable=g-bad-todo
+ rouge_2_f_score = tf.py_func(rouge_n, (predictions, labels), tf.float32)
+ return rouge_2_f_score, tf.constant(1.0)
+
+
+def _get_ngrams(n, text):
+ """Calculates n-grams.
+
+ Args:
+ n: which n-grams to calculate
+ text: An array of tokens
+
+ Returns:
+ A set of n-grams
+ """
+ ngram_set = set()
+ text_length = len(text)
+ max_index_ngram_start = text_length - n
+ for i in range(max_index_ngram_start + 1):
+ ngram_set.add(tuple(text[i:i + n]))
+ return ngram_set
+
+
+def rouge_n(eval_sentences, ref_sentences, n=2):
+ """Computes ROUGE-N f1 score of two text collections of sentences.
+
+ Source: https://www.microsoft.com/en-us/research/publication/
+ rouge-a-package-for-automatic-evaluation-of-summaries/
+
+ Args:
+ eval_sentences: Predicted sentences.
+ ref_sentences: Sentences from the reference set
+ n: Size of ngram. Defaults to 2.
+
+ Returns:
+ f1 score for ROUGE-N
+ """
+ f1_scores = []
+ for eval_sentence, ref_sentence in zip(eval_sentences, ref_sentences):
+ eval_ngrams = _get_ngrams(n, eval_sentence)
+ ref_ngrams = _get_ngrams(n, ref_sentence)
+ ref_count = len(ref_ngrams)
+ eval_count = len(eval_ngrams)
+
+ # Count the overlapping ngrams between evaluated and reference
+ overlapping_ngrams = eval_ngrams.intersection(ref_ngrams)
+ overlapping_count = len(overlapping_ngrams)
+
+ # Handle edge case. This isn't mathematically correct, but it's good enough
+ if eval_count == 0:
+ precision = 0.0
+ else:
+ precision = float(overlapping_count) / eval_count
+ if ref_count == 0:
+ recall = 0.0
+ else:
+ recall = float(overlapping_count) / ref_count
+ f1_scores.append(2.0 * ((precision * recall) / (precision + recall + 1e-8)))
+
+ # return overlapping_count / reference_count
+ return np.mean(f1_scores, dtype=np.float32)
+
+
+def rouge_l_fscore(predictions, labels):
+ """ROUGE scores computation between labels and predictions.
+
+ This is an approximate ROUGE scoring method since we do not glue word pieces
+ or decode the ids and tokenize the output.
+
+ Args:
+ predictions: tensor, model predictions
+ labels: tensor, gold output.
+
+ Returns:
+ rouge_l_fscore: approx rouge-l f1 score.
+ """
+ outputs = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
+ rouge_l_f_score = tf.py_func(rouge_l_sentence_level, (outputs, labels),
+ tf.float32)
+ return rouge_l_f_score, tf.constant(1.0)
+
+
+def rouge_l_sentence_level(eval_sentences, ref_sentences):
+ """Computes ROUGE-L (sentence level) of two collections of sentences.
+
+ Source: https://www.microsoft.com/en-us/research/publication/
+ rouge-a-package-for-automatic-evaluation-of-summaries/
+
+ Calculated according to:
+ R_lcs = LCS(X,Y)/m
+ P_lcs = LCS(X,Y)/n
+ F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)
+
+ where:
+ X = reference summary
+ Y = Candidate summary
+ m = length of reference summary
+ n = length of candidate summary
+
+ Args:
+ eval_sentences: The sentences that have been picked by the summarizer
+ ref_sentences: The sentences from the reference set
+
+ Returns:
+ A float: F_lcs
+ """
+
+ f1_scores = []
+ for eval_sentence, ref_sentence in zip(eval_sentences, ref_sentences):
+ m = float(len(ref_sentence))
+ n = float(len(eval_sentence))
+ lcs = _len_lcs(eval_sentence, ref_sentence)
+ f1_scores.append(_f_lcs(lcs, m, n))
+ return np.mean(f1_scores, dtype=np.float32)
+
+
+def _len_lcs(x, y):
+ """Returns the length of the Longest Common Subsequence between two seqs.
+
+ Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
+
+ Args:
+ x: sequence of words
+ y: sequence of words
+
+ Returns
+ integer: Length of LCS between x and y
+ """
+ table = _lcs(x, y)
+ n, m = len(x), len(y)
+ return table[n, m]
+
+
+def _lcs(x, y):
+ """Computes the length of the LCS between two seqs.
+
+ The implementation below uses a DP programming algorithm and runs
+ in O(nm) time where n = len(x) and m = len(y).
+ Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence
+
+ Args:
+ x: collection of words
+ y: collection of words
+
+ Returns:
+ Table of dictionary of coord and len lcs
+ """
+ n, m = len(x), len(y)
+ table = dict()
+ for i in range(n + 1):
+ for j in range(m + 1):
+ if i == 0 or j == 0:
+ table[i, j] = 0
+ elif x[i - 1] == y[j - 1]:
+ table[i, j] = table[i - 1, j - 1] + 1
+ else:
+ table[i, j] = max(table[i - 1, j], table[i, j - 1])
+ return table
+
+
+def _f_lcs(llcs, m, n):
+ """Computes the LCS-based F-measure score.
+
+ Source: http://research.microsoft.com/en-us/um/people/cyl/download/papers/
+ rouge-working-note-v1.3.1.pdf
+
+ Args:
+ llcs: Length of LCS
+ m: number of words in reference summary
+ n: number of words in candidate summary
+
+ Returns:
+ Float. LCS-based F-measure score
+ """
+ r_lcs = llcs / m
+ p_lcs = llcs / n
+ beta = p_lcs / (r_lcs + 1e-12)
+ num = (1 + (beta ** 2)) * r_lcs * p_lcs
+ denom = r_lcs + ((beta ** 2) * p_lcs)
+ f_lcs = num / (denom + 1e-12)
+ return f_lcs
diff --git a/modeling/official/legacy/transformer/utils/tokenizer.py b/modeling/official/legacy/transformer/utils/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfb3bd52dd6c4adce743435e076b23bd4ba2eceb
--- /dev/null
+++ b/modeling/official/legacy/transformer/utils/tokenizer.py
@@ -0,0 +1,660 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Defines Subtokenizer class to encode and decode strings."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import re
+import sys
+import unicodedata
+
+from absl import logging
+
+import numpy as np
+import six
+from six.moves import xrange # pylint: disable=redefined-builtin
+import tensorflow as tf, tf_keras
+
+# pylint: disable=g-complex-comprehension
+PAD = ""
+PAD_ID = 0
+EOS = ""
+EOS_ID = 1
+RESERVED_TOKENS = [PAD, EOS]
+
+# Set of characters that will be used in the function _escape_token() (see func
+# docstring for more details).
+# This set is added to the alphabet list to ensure that all escaped tokens can
+# be encoded.
+_ESCAPE_CHARS = set(u"\\_u;0123456789")
+# Regex for the function _unescape_token(), the inverse of _escape_token().
+# This is used to find "\u", "\\", and "\###;" substrings in the token.
+_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
+
+_UNDEFINED_UNICODE = u"\u3013"
+
+
+def alphanumeric_char_set():
+ return set(
+ six.unichr(i)
+ for i in xrange(sys.maxunicode)
+ if (unicodedata.category(six.unichr(i)).startswith("L") or
+ unicodedata.category(six.unichr(i)).startswith("N")))
+
+
+# Set contains all letter and number characters.
+_ALPHANUMERIC_CHAR_SET = alphanumeric_char_set()
+
+# min_count is the minimum number of times a subtoken must appear in the data
+# before before it is added to the vocabulary. The value is found using binary
+# search to obtain the target vocabulary size.
+_MIN_MIN_COUNT = 1 # min value to use when binary searching for min_count
+_MAX_MIN_COUNT = 1000 # max value to use when binary searching for min_count
+
+
+class Subtokenizer(object):
+ """Encodes and decodes strings to/from integer IDs."""
+
+ def __init__(self, vocab_file, reserved_tokens=None, master_char_set=None):
+ """Initializes class, creating a vocab file if data_files is provided."""
+ logging.info("Initializing Subtokenizer from file %s.", vocab_file)
+
+ if master_char_set is None:
+ master_char_set = _ALPHANUMERIC_CHAR_SET
+
+ if reserved_tokens is None:
+ reserved_tokens = RESERVED_TOKENS
+
+ self.subtoken_list = _load_vocab_file(vocab_file, reserved_tokens)
+ self.alphabet = _generate_alphabet_dict(self.subtoken_list)
+ self.subtoken_to_id_dict = _list_to_index_dict(self.subtoken_list)
+
+ self.max_subtoken_length = 0
+ for subtoken in self.subtoken_list:
+ self.max_subtoken_length = max(self.max_subtoken_length, len(subtoken))
+
+ # Create cache to speed up subtokenization
+ self._cache_size = 2**20
+ self._cache = [(None, None)] * self._cache_size
+ self._master_char_set = master_char_set
+
+ @staticmethod
+ def init_from_files(vocab_file,
+ files,
+ target_vocab_size,
+ threshold,
+ min_count=None,
+ file_byte_limit=1e6,
+ reserved_tokens=None,
+ correct_strip=True,
+ master_char_set=None):
+ """Create subtoken vocabulary based on files, and save vocab to file.
+
+ Args:
+ vocab_file: String name of vocab file to store subtoken vocabulary.
+ files: List of file paths that will be used to generate vocabulary.
+ target_vocab_size: target vocabulary size to generate.
+ threshold: int threshold of vocabulary size to accept.
+ min_count: int minimum count to use for generating the vocabulary. The min
+ count is the minimum number of times a subtoken should appear in the
+ files before it is added to the vocabulary. If set to none, this value
+ is found using binary search.
+ file_byte_limit: (Default 1e6) Maximum number of bytes of sample text that
+ will be drawn from the files.
+ reserved_tokens: List of string tokens that are guaranteed to be at the
+ beginning of the subtoken vocabulary list.
+ correct_strip: Whether to convert text to unicode before strip.
+ master_char_set: the char set.
+
+ Returns:
+ Subtokenizer object
+ """
+ if master_char_set is None:
+ master_char_set = _ALPHANUMERIC_CHAR_SET
+ if reserved_tokens is None:
+ reserved_tokens = RESERVED_TOKENS
+
+ if tf.io.gfile.exists(vocab_file):
+ logging.info("Vocab file already exists (%s)", vocab_file)
+ else:
+ logging.info("Begin steps to create subtoken vocabulary...")
+ token_counts = _count_tokens(files, file_byte_limit, correct_strip,
+ master_char_set)
+ alphabet = _generate_alphabet_dict(token_counts)
+ subtoken_list = _generate_subtokens_with_target_vocab_size(
+ token_counts, alphabet, target_vocab_size, threshold, min_count,
+ reserved_tokens)
+ logging.info("Generated vocabulary with %d subtokens.",
+ len(subtoken_list))
+ _save_vocab_file(vocab_file, subtoken_list)
+ return Subtokenizer(vocab_file, master_char_set=master_char_set)
+
+ def encode(self, raw_string, add_eos=False):
+ """Encodes a string into a list of int subtoken ids."""
+ ret = []
+ tokens = _split_string_to_tokens(
+ native_to_unicode(raw_string), self._master_char_set)
+ for token in tokens:
+ ret.extend(self._token_to_subtoken_ids(token))
+ if add_eos:
+ assert EOS in self.subtoken_list, \
+ "Can't append 'EOS' because it is not in list of known subtokens."
+ ret.append(EOS_ID)
+ return ret
+
+ def _token_to_subtoken_ids(self, token):
+ """Encode a single token into a list of subtoken ids."""
+ cache_location = hash(token) % self._cache_size
+ cache_key, cache_value = self._cache[cache_location]
+ if cache_key == token:
+ return cache_value
+
+ ret = _split_token_to_subtokens(
+ _escape_token(token, self.alphabet), self.subtoken_to_id_dict,
+ self.max_subtoken_length)
+ ret = [self.subtoken_to_id_dict[subtoken_id] for subtoken_id in ret]
+
+ self._cache[cache_location] = (token, ret)
+ return ret
+
+ def decode(self, subtokens):
+ """Converts list of int subtokens ids into a string."""
+ if isinstance(subtokens, np.ndarray):
+ # Note that list(subtokens) converts subtokens to a python list, but the
+ # items remain as np.int32. This converts both the array and its items.
+ subtokens = subtokens.tolist()
+
+ if not subtokens:
+ return ""
+
+ assert isinstance(subtokens, list) and isinstance(subtokens[0], int), (
+ "Subtokens argument passed into decode() must be a list of integers.")
+
+ return _unicode_to_native(
+ _join_tokens_to_string(
+ self._subtoken_ids_to_tokens(subtokens), self._master_char_set))
+
+ def _subtoken_ids_to_tokens(self, subtokens):
+ """Convert list of int subtoken ids to a list of string tokens."""
+ escaped_tokens = "".join([
+ self.subtoken_list[s] for s in subtokens if s < len(self.subtoken_list)
+ ])
+ escaped_tokens = escaped_tokens.split("_")
+
+ # All tokens in the vocabulary list have been escaped (see _escape_token())
+ # so each token must be unescaped when decoding.
+ ret = []
+ for token in escaped_tokens:
+ if token:
+ ret.append(_unescape_token(token))
+ return ret
+
+
+def _save_vocab_file(vocab_file, subtoken_list):
+ """Save subtokens to file."""
+ with tf.io.gfile.GFile(vocab_file, mode="w") as f:
+ for subtoken in subtoken_list:
+ f.write("'%s'\n" % _unicode_to_native(subtoken))
+
+
+def _load_vocab_file(vocab_file, reserved_tokens=None):
+ """Load vocabulary while ensuring reserved tokens are at the top."""
+ if reserved_tokens is None:
+ reserved_tokens = RESERVED_TOKENS
+
+ subtoken_list = []
+ with tf.io.gfile.GFile(vocab_file, mode="r") as f:
+ for line in f:
+ subtoken = native_to_unicode(line.strip())
+ subtoken = subtoken[1:-1] # Remove surrounding single-quotes
+ if subtoken in reserved_tokens:
+ continue
+ subtoken_list.append(native_to_unicode(subtoken))
+ return reserved_tokens + subtoken_list
+
+
+def native_to_unicode(s):
+ """Convert string to unicode (required in Python 2)."""
+ try: # Python 2
+ return s if isinstance(s, unicode) else s.decode("utf-8")
+ except NameError: # Python 3
+ return s
+
+
+def _unicode_to_native(s):
+ """Convert string from unicode to native format (required in Python 2)."""
+ try: # Python 2
+ return s.encode("utf-8") if isinstance(s, unicode) else s
+ except NameError: # Python 3
+ return s
+
+
+def _split_string_to_tokens(text, master_char_set):
+ """Splits text to a list of string tokens."""
+ if not text:
+ return []
+ ret = []
+ token_start = 0
+ # Classify each character in the input string
+ is_master = [c in master_char_set for c in text]
+ for pos in xrange(1, len(text)):
+ if is_master[pos] != is_master[pos - 1]:
+ token = text[token_start:pos]
+ if token != u" " or token_start == 0:
+ ret.append(token)
+ token_start = pos
+ final_token = text[token_start:]
+ ret.append(final_token)
+ return ret
+
+
+def _join_tokens_to_string(tokens, master_char_set):
+ """Join a list of string tokens into a single string."""
+ token_is_master = [t[0] in master_char_set for t in tokens]
+ ret = []
+ for i, token in enumerate(tokens):
+ if i > 0 and token_is_master[i - 1] and token_is_master[i]:
+ ret.append(u" ")
+ ret.append(token)
+ return "".join(ret)
+
+
+def _escape_token(token, alphabet):
+ r"""Replace characters that aren't in the alphabet and append "_" to token.
+
+ Apply three transformations to the token:
+ 1. Replace underline character "_" with "\u", and backslash "\" with "\\".
+ 2. Replace characters outside of the alphabet with "\###;", where ### is the
+ character's Unicode code point.
+ 3. Appends "_" to mark the end of a token.
+
+ Args:
+ token: unicode string to be escaped
+ alphabet: list of all known characters
+
+ Returns:
+ escaped string
+ """
+ token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u")
+ ret = [c if c in alphabet and c != u"\n" else r"\%d;" % ord(c) for c in token]
+ return u"".join(ret) + "_"
+
+
+def _unescape_token(token):
+ r"""Replaces escaped characters in the token with their unescaped versions.
+
+ Applies inverse transformations as _escape_token():
+ 1. Replace "\u" with "_", and "\\" with "\".
+ 2. Replace "\###;" with the unicode character the ### refers to.
+
+ Args:
+ token: escaped string
+
+ Returns:
+ unescaped string
+ """
+
+ def match(m):
+ r"""Returns replacement string for matched object.
+
+ Matched objects contain one of the strings that matches the regex pattern:
+ r"\\u|\\\\|\\([0-9]+);"
+ The strings can be '\u', '\\', or '\###;' (### is any digit number).
+
+ m.group(0) refers to the entire matched string ('\u', '\\', or '\###;').
+ m.group(1) refers to the first parenthesized subgroup ('###').
+
+ m.group(0) exists for all match objects, while m.group(1) exists only for
+ the string '\###;'.
+
+ This function looks to see if m.group(1) exists. If it doesn't, then the
+ matched string must be '\u' or '\\' . In this case, the corresponding
+ replacement ('_' and '\') are returned. Note that in python, a single
+ backslash is written as '\\', and double backslash as '\\\\'.
+
+ If m.goup(1) exists, then use the integer in m.group(1) to return a
+ unicode character.
+
+ Args:
+ m: match object
+
+ Returns:
+ String to replace matched object with.
+ """
+ # Check if the matched strings are '\u' or '\\'.
+ if m.group(1) is None:
+ return u"_" if m.group(0) == u"\\u" else u"\\"
+
+ # If m.group(1) exists, try and return unicode character.
+ try:
+ return six.unichr(int(m.group(1)))
+ except (ValueError, OverflowError) as _:
+ return _UNDEFINED_UNICODE
+
+ # Use match function to replace escaped substrings in the token.
+ return _UNESCAPE_REGEX.sub(match, token)
+
+
+def _count_tokens(files,
+ file_byte_limit=1e6,
+ correct_strip=True,
+ master_char_set=None):
+ """Return token counts of words in the files.
+
+ Samples file_byte_limit bytes from each file, and counts the words that appear
+ in the samples. The samples are semi-evenly distributed across the file.
+
+ Args:
+ files: List of filepaths
+ file_byte_limit: Max number of bytes that will be read from each file.
+ correct_strip: Whether to convert text to unicode before strip. This affects
+ vocabulary generation for PY2. Sets correct_strip to False in PY2 to
+ reproduce previous common public result. Sets correct_strip to True will
+ let PY2 and PY3 get a consistent vocabulary.
+ master_char_set: the char set.
+
+ Returns:
+ Dictionary mapping tokens to the number of times they appear in the sampled
+ lines from the files.
+ """
+ if master_char_set is None:
+ master_char_set = _ALPHANUMERIC_CHAR_SET
+
+ token_counts = collections.defaultdict(int)
+
+ for filepath in files:
+ with tf.io.gfile.GFile(filepath, mode="r") as reader:
+ file_byte_budget = file_byte_limit
+ counter = 0
+ lines_to_skip = int(reader.size() / (file_byte_budget * 2))
+ for line in reader:
+ if counter < lines_to_skip:
+ counter += 1
+ else:
+ if file_byte_budget < 0:
+ break
+ if correct_strip:
+ line = native_to_unicode(line)
+ line = line.strip()
+ file_byte_budget -= len(line)
+ counter = 0
+
+ # Add words to token counts
+ for token in _split_string_to_tokens(
+ native_to_unicode(line), master_char_set):
+ token_counts[token] += 1
+ return token_counts
+
+
+def _list_to_index_dict(lst):
+ """Create dictionary mapping list items to their indices in the list."""
+ return {item: n for n, item in enumerate(lst)}
+
+
+def _split_token_to_subtokens(token, subtoken_dict, max_subtoken_length):
+ """Splits a token into subtokens defined in the subtoken dict."""
+ ret = []
+ start = 0
+ token_len = len(token)
+ while start < token_len:
+ # Find the longest subtoken, so iterate backwards.
+ for end in xrange(min(token_len, start + max_subtoken_length), start, -1):
+ subtoken = token[start:end]
+ if subtoken in subtoken_dict:
+ ret.append(subtoken)
+ start = end
+ break
+ else: # Did not break
+ # If there is no possible encoding of the escaped token then one of the
+ # characters in the token is not in the alphabet. This should be
+ # impossible and would be indicative of a bug.
+ raise ValueError("Was unable to split token \"%s\" into subtokens." %
+ token)
+ return ret
+
+
+def _generate_subtokens_with_target_vocab_size(token_counts,
+ alphabet,
+ target_size,
+ threshold,
+ min_count=None,
+ reserved_tokens=None):
+ """Generate subtoken vocabulary close to the target size."""
+ if reserved_tokens is None:
+ reserved_tokens = RESERVED_TOKENS
+
+ if min_count is not None:
+ logging.info("Using min_count=%d to generate vocab with target size %d",
+ min_count, target_size)
+ return _generate_subtokens(
+ token_counts, alphabet, min_count, reserved_tokens=reserved_tokens)
+
+ def bisect(min_val, max_val):
+ """Recursive function to binary search for subtoken vocabulary."""
+ cur_count = (min_val + max_val) // 2
+ logging.info("Binary search: trying min_count=%d (%d %d)", cur_count,
+ min_val, max_val)
+ subtoken_list = _generate_subtokens(
+ token_counts, alphabet, cur_count, reserved_tokens=reserved_tokens)
+
+ val = len(subtoken_list)
+ logging.info("Binary search: min_count=%d resulted in %d tokens", cur_count,
+ val)
+
+ within_threshold = abs(val - target_size) < threshold
+ if within_threshold or min_val >= max_val or cur_count < 2:
+ return subtoken_list
+ if val > target_size:
+ other_subtoken_list = bisect(cur_count + 1, max_val)
+ else:
+ other_subtoken_list = bisect(min_val, cur_count - 1)
+
+ # Return vocabulary dictionary with the closest number of tokens.
+ other_val = len(other_subtoken_list)
+ if abs(other_val - target_size) < abs(val - target_size):
+ return other_subtoken_list
+ return subtoken_list
+
+ logging.info("Finding best min_count to get target size of %d", target_size)
+ return bisect(_MIN_MIN_COUNT, _MAX_MIN_COUNT)
+
+
+def _generate_alphabet_dict(iterable, reserved_tokens=None):
+ """Create set of characters that appear in any element in the iterable."""
+ if reserved_tokens is None:
+ reserved_tokens = RESERVED_TOKENS
+ alphabet = {c for token in iterable for c in token}
+ alphabet |= {c for token in reserved_tokens for c in token}
+ alphabet |= _ESCAPE_CHARS # Add escape characters to alphabet set.
+ return alphabet
+
+
+def _count_and_gen_subtokens(token_counts, alphabet, subtoken_dict,
+ max_subtoken_length):
+ """Count number of times subtokens appear, and generate new subtokens.
+
+ Args:
+ token_counts: dict mapping tokens to the number of times they appear in the
+ original files.
+ alphabet: list of allowed characters. Used to escape the tokens, which
+ guarantees that all tokens can be split into subtokens.
+ subtoken_dict: dict mapping subtokens to ids.
+ max_subtoken_length: maximum length of subtoken in subtoken_dict.
+
+ Returns:
+ A defaultdict mapping subtokens to the number of times they appear in the
+ tokens. The dict may contain new subtokens.
+ """
+ subtoken_counts = collections.defaultdict(int)
+ for token, count in six.iteritems(token_counts):
+ token = _escape_token(token, alphabet)
+ subtokens = _split_token_to_subtokens(token, subtoken_dict,
+ max_subtoken_length)
+
+ # Generate new subtokens by taking substrings from token.
+ start = 0
+ for subtoken in subtokens:
+ for end in xrange(start + 1, len(token) + 1):
+ new_subtoken = token[start:end]
+ subtoken_counts[new_subtoken] += count
+ start += len(subtoken)
+
+ return subtoken_counts
+
+
+def _filter_and_bucket_subtokens(subtoken_counts, min_count):
+ """Return a bucketed list of subtokens that are filtered by count.
+
+ Args:
+ subtoken_counts: defaultdict mapping subtokens to their counts
+ min_count: int count used to filter subtokens
+
+ Returns:
+ List of subtoken sets, where subtokens in set i have the same length=i.
+ """
+ # Create list of buckets, where subtokens in bucket i have length i.
+ subtoken_buckets = []
+ for subtoken, count in six.iteritems(subtoken_counts):
+ if count < min_count: # Filter out subtokens that don't appear enough
+ continue
+ while len(subtoken_buckets) <= len(subtoken):
+ subtoken_buckets.append(set())
+ subtoken_buckets[len(subtoken)].add(subtoken)
+ return subtoken_buckets
+
+
+def _gen_new_subtoken_list(subtoken_counts,
+ min_count,
+ alphabet,
+ reserved_tokens=None):
+ """Generate candidate subtokens ordered by count, and new max subtoken length.
+
+ Add subtokens to the candiate list in order of length (longest subtokens
+ first). When a subtoken is added, the counts of each of its prefixes are
+ decreased. Prefixes that don't appear much outside the subtoken are not added
+ to the candidate list.
+
+ For example:
+ subtoken being added to candidate list: 'translate'
+ subtoken_counts: {'translate':10, 't':40, 'tr':16, 'tra':12, ...}
+ min_count: 5
+
+ When 'translate' is added, subtoken_counts is updated to:
+ {'translate':0, 't':30, 'tr':6, 'tra': 2, ...}
+
+ The subtoken 'tra' will not be added to the candidate list, because it appears
+ twice (less than min_count) outside of 'translate'.
+
+ Args:
+ subtoken_counts: defaultdict mapping str subtokens to int counts
+ min_count: int minumum count requirement for subtokens
+ alphabet: set of characters. Each character is added to the subtoken list to
+ guarantee that all tokens can be encoded.
+ reserved_tokens: list of tokens that will be added to the beginning of the
+ returned subtoken list.
+
+ Returns:
+ List of candidate subtokens in decreasing count order, and maximum subtoken
+ length
+ """
+ if reserved_tokens is None:
+ reserved_tokens = RESERVED_TOKENS
+
+ # Create a list of (count, subtoken) for each candidate subtoken.
+ subtoken_candidates = []
+
+ # Use bucketted list to iterate through subtokens in order of length.
+ # subtoken_buckets[i] = set(subtokens), where each subtoken has length i.
+ subtoken_buckets = _filter_and_bucket_subtokens(subtoken_counts, min_count)
+ max_subtoken_length = len(subtoken_buckets) - 1
+
+ # Go through the list in reverse order to consider longer subtokens first.
+ for subtoken_len in xrange(max_subtoken_length, 0, -1):
+ for subtoken in subtoken_buckets[subtoken_len]:
+ count = subtoken_counts[subtoken]
+
+ # Possible if this subtoken is a prefix of another token.
+ if count < min_count:
+ continue
+
+ # Ignore alphabet/reserved tokens, which will be added manually later.
+ if subtoken not in alphabet and subtoken not in reserved_tokens:
+ subtoken_candidates.append((count, subtoken))
+
+ # Decrement count of the subtoken's prefixes (if a longer subtoken is
+ # added, its prefixes lose priority to be added).
+ for end in xrange(1, subtoken_len):
+ subtoken_counts[subtoken[:end]] -= count
+
+ # Add alphabet subtokens (guarantees that all strings are encodable).
+ subtoken_candidates.extend((subtoken_counts.get(a, 0), a) for a in alphabet)
+
+ # Order subtoken candidates by decreasing count.
+ subtoken_list = [t for _, t in sorted(subtoken_candidates, reverse=True)]
+
+ # Add reserved tokens to beginning of the list.
+ subtoken_list = reserved_tokens + subtoken_list
+ return subtoken_list, max_subtoken_length
+
+
+def _generate_subtokens(token_counts,
+ alphabet,
+ min_count,
+ num_iterations=4,
+ reserved_tokens=None):
+ """Create a list of subtokens in decreasing order of frequency.
+
+ Args:
+ token_counts: dict mapping str tokens -> int count
+ alphabet: set of characters
+ min_count: int minimum number of times a subtoken must appear before it is
+ added to the vocabulary.
+ num_iterations: int number of iterations to generate new tokens.
+ reserved_tokens: list of tokens that will be added to the beginning to the
+ returned subtoken list.
+
+ Returns:
+ Sorted list of subtokens (most frequent first)
+ """
+ if reserved_tokens is None:
+ reserved_tokens = RESERVED_TOKENS
+
+ # Use alphabet set to create initial list of subtokens
+ subtoken_list = reserved_tokens + list(alphabet)
+ max_subtoken_length = 1
+
+ # On each iteration, segment all words using the subtokens defined in
+ # subtoken_dict, count how often the resulting subtokens appear, and update
+ # the dictionary with subtokens w/ high enough counts.
+ for i in xrange(num_iterations):
+ logging.info("\tGenerating subtokens: iteration %d", i)
+ # Generate new subtoken->id dictionary using the new subtoken list.
+ subtoken_dict = _list_to_index_dict(subtoken_list)
+
+ # Create dict mapping subtoken->count, with additional subtokens created
+ # from substrings taken from the tokens.
+ subtoken_counts = _count_and_gen_subtokens(token_counts, alphabet,
+ subtoken_dict,
+ max_subtoken_length)
+
+ # Generate new list of subtokens sorted by subtoken count.
+ subtoken_list, max_subtoken_length = _gen_new_subtoken_list(
+ subtoken_counts, min_count, alphabet, reserved_tokens)
+
+ logging.info("\tVocab size: %d", len(subtoken_list))
+ return subtoken_list
diff --git a/modeling/official/legacy/transformer/utils/tokenizer_test.py b/modeling/official/legacy/transformer/utils/tokenizer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2101ec97043e16924bedb1c6894f91199d9d91d6
--- /dev/null
+++ b/modeling/official/legacy/transformer/utils/tokenizer_test.py
@@ -0,0 +1,204 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test Subtokenizer and string helper methods."""
+
+import collections
+import tempfile
+
+import tensorflow as tf, tf_keras
+
+from official.legacy.transformer.utils import tokenizer
+
+
+class SubtokenizerTest(tf.test.TestCase):
+
+ def _init_subtokenizer(self, vocab_list):
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
+ with tf.io.gfile.GFile(temp_file.name, "w") as w:
+ for subtoken in vocab_list:
+ w.write("'%s'" % subtoken)
+ w.write("\n")
+ return tokenizer.Subtokenizer(temp_file.name, reserved_tokens=[])
+
+ def test_encode(self):
+ vocab_list = ["123_", "test", "ing_"]
+ subtokenizer = self._init_subtokenizer(vocab_list)
+ s = "testing 123"
+ encoded_list = subtokenizer.encode(s)
+ self.assertEqual([1, 2, 0], encoded_list)
+
+ def test_decode(self):
+ vocab_list = ["123_", "test", "ing_"]
+ subtokenizer = self._init_subtokenizer(vocab_list)
+ encoded_list = [1, 2, 0] # testing 123
+ decoded_str = subtokenizer.decode(encoded_list)
+ self.assertEqual("testing 123", decoded_str)
+
+ def test_subtoken_ids_to_tokens(self):
+ vocab_list = ["123_", "test", "ing_"]
+ subtokenizer = self._init_subtokenizer(vocab_list)
+ encoded_list = [1, 2, 0] # testing 123
+ token_list = subtokenizer._subtoken_ids_to_tokens(encoded_list)
+ self.assertEqual([u"testing", u"123"], token_list)
+
+
+class StringHelperTest(tf.test.TestCase):
+
+ def test_split_string_to_tokens(self):
+ text = "test? testing 123."
+
+ tokens = tokenizer._split_string_to_tokens(text,
+ tokenizer._ALPHANUMERIC_CHAR_SET)
+ self.assertEqual(["test", "? ", "testing", "123", "."], tokens)
+
+ def test_join_tokens_to_string(self):
+ tokens = ["test", "? ", "testing", "123", "."]
+
+ s = tokenizer._join_tokens_to_string(tokens,
+ tokenizer._ALPHANUMERIC_CHAR_SET)
+ self.assertEqual("test? testing 123.", s)
+
+ def test_escape_token(self):
+ token = u"abc_\\4"
+ alphabet = set("abc_\\u;")
+
+ escaped_token = tokenizer._escape_token(token, alphabet)
+ self.assertEqual("abc\\u\\\\\\52;_", escaped_token)
+
+ def test_unescape_token(self):
+ escaped_token = u"Underline: \\u, Backslash: \\\\, Unicode: \\52;"
+
+ unescaped_token = tokenizer._unescape_token(escaped_token)
+ self.assertEqual("Underline: _, Backslash: \\, Unicode: 4", unescaped_token)
+
+ def test_list_to_index_dict(self):
+ lst = ["test", "strings"]
+
+ d = tokenizer._list_to_index_dict(lst)
+ self.assertDictEqual({"test": 0, "strings": 1}, d)
+
+ def test_split_token_to_subtokens(self):
+ token = "abc"
+ subtoken_dict = {"a": 0, "b": 1, "c": 2, "ab": 3}
+ max_subtoken_length = 2
+
+ subtokens = tokenizer._split_token_to_subtokens(token, subtoken_dict,
+ max_subtoken_length)
+ self.assertEqual(["ab", "c"], subtokens)
+
+ def test_generate_alphabet_dict(self):
+ s = ["testing", "123"]
+ reserved_tokens = ["???"]
+
+ alphabet = tokenizer._generate_alphabet_dict(s, reserved_tokens)
+ self.assertIn("?", alphabet)
+ self.assertIn("t", alphabet)
+ self.assertIn("e", alphabet)
+ self.assertIn("s", alphabet)
+ self.assertIn("i", alphabet)
+ self.assertIn("n", alphabet)
+ self.assertIn("g", alphabet)
+ self.assertIn("1", alphabet)
+ self.assertIn("2", alphabet)
+ self.assertIn("3", alphabet)
+
+ def test_count_and_gen_subtokens(self):
+ token_counts = {"abc": 5}
+ alphabet = set("abc_")
+ subtoken_dict = {"a": 0, "b": 1, "c": 2, "_": 3}
+ max_subtoken_length = 2
+
+ subtoken_counts = tokenizer._count_and_gen_subtokens(
+ token_counts, alphabet, subtoken_dict, max_subtoken_length)
+
+ self.assertIsInstance(subtoken_counts, collections.defaultdict)
+ self.assertDictEqual(
+ {
+ "a": 5,
+ "b": 5,
+ "c": 5,
+ "_": 5,
+ "ab": 5,
+ "bc": 5,
+ "c_": 5,
+ "abc": 5,
+ "bc_": 5,
+ "abc_": 5
+ }, subtoken_counts)
+
+ def test_filter_and_bucket_subtokens(self):
+ subtoken_counts = collections.defaultdict(int, {
+ "a": 2,
+ "b": 4,
+ "c": 1,
+ "ab": 6,
+ "ac": 3,
+ "abbc": 5
+ })
+ min_count = 3
+
+ subtoken_buckets = tokenizer._filter_and_bucket_subtokens(
+ subtoken_counts, min_count)
+
+ self.assertEqual(len(subtoken_buckets[0]), 0)
+ self.assertEqual(set("b"), subtoken_buckets[1])
+ self.assertEqual(set(["ab", "ac"]), subtoken_buckets[2])
+ self.assertEqual(len(subtoken_buckets[3]), 0)
+ self.assertEqual(set(["abbc"]), subtoken_buckets[4])
+
+ def test_gen_new_subtoken_list(self):
+ subtoken_counts = collections.defaultdict(int, {
+ "translate": 10,
+ "t": 40,
+ "tr": 16,
+ "tra": 12
+ })
+ min_count = 5
+ alphabet = set("translate")
+ reserved_tokens = ["reserved", "tokens"]
+
+ subtoken_list, max_token_length = tokenizer._gen_new_subtoken_list(
+ subtoken_counts, min_count, alphabet, reserved_tokens)
+
+ # Check that "tra" isn"t in the list (its count should be decremented to 2,
+ # so it should not be added to the canddiate list).
+ self.assertNotIn("tra", subtoken_list)
+
+ self.assertIn("tr", subtoken_list)
+ self.assertIn("t", subtoken_list)
+
+ self.assertEqual(len("translate"), max_token_length)
+
+ def test_generate_subtokens(self):
+ token_counts = {"ab": 1, "bc": 3, "abc": 5}
+ alphabet = set("abc_")
+ min_count = 100
+ num_iterations = 1
+ reserved_tokens = ["reserved", "tokens"]
+
+ vocab_list = tokenizer._generate_subtokens(token_counts, alphabet,
+ min_count, num_iterations,
+ reserved_tokens)
+
+ # Check that reserved tokens are at the front of the list
+ self.assertEqual(vocab_list[:2], reserved_tokens)
+
+ # Check that each character in alphabet is in the vocab list
+ for c in alphabet:
+ self.assertIn(c, vocab_list)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/legacy/xlnet/README.md b/modeling/official/legacy/xlnet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..546d1128e2d562fbcb4de32894d89b2200249d2b
--- /dev/null
+++ b/modeling/official/legacy/xlnet/README.md
@@ -0,0 +1,236 @@
+# XLNet: Generalized Autoregressive Pretraining for Language Understanding
+
+The academic paper which describes XLNet in detail and provides full results on
+a number of tasks can be found here: https://arxiv.org/abs/1906.08237.
+
+XLNet is a generalized autoregressive BERT-like pretraining language model that
+enables learning bidirectional contexts by maximizing the expected likelihood
+over all permutations of the factorization order. It can learn dependency beyond
+a fixed length without disrupting temporal coherence by using segment-level
+recurrence mechanism and relative positional encoding scheme introduced in
+[Transformer-XL](https://arxiv.org/pdf/1901.02860.pdf). XLNet outperforms BERT
+on 20 NLP benchmark tasks and achieves state-of-the-art results on 18 tasks
+including question answering, natural language inference, sentiment analysis,
+and document ranking.
+
+## Contents
+
+* [Contents](#contents)
+* [Set Up](#set-up)
+* [Process Datasets](#process-datasets)
+* [Fine-tuning with XLNet](#fine-tuning-with-xlnet)
+
+## Set up
+
+To run XLNet on a Cloud TPU, you can first create a `tf-nightly` TPU with the
+[ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu):
+
+```shell
+ctpu up -name --tf-version=”nightly”
+```
+
+After SSH'ing into the VM (or if you're using an on-prem machine), setup
+continues as follows:
+
+```shell
+export PYTHONPATH="$PYTHONPATH:/path/to/models"
+```
+
+Install `tf-nightly` to get latest updates:
+
+```shell
+pip install tf-nightly-gpu
+```
+
+## Process Datasets
+
+Dataset processing requires a
+[Sentence Piece](https://github.com/google/sentencepiece) model. One can be
+found at the publicly available GCS bucket at:
+`gs://cloud-tpu-checkpoints/xlnet/cased_spiece.model`.
+
+Note that in order to train using Cloud TPUs, data must be stored on a GCS
+bucket.
+
+Setup commands:
+
+```shell
+export SPIECE_DIR=~/cased_spiece/
+export SPIECE_MODEL=${SPIECE_DIR}/cased_spiece.model
+export DATASETS_DIR=gs://some_bucket/datasets
+mkdir -p ${SPIECE_DIR}
+gsutil cp gs://cloud-tpu-checkpoints/xlnet/cased_spiece.model ${SPIECE_DIR}
+```
+
+
+### Pre-training
+
+Pre-training data can be converted into TFRecords using
+[`preprocess_pretrain_data.py`](preprocess_pretrain_data.py). Inputs should
+consist of a plain text file (or a file glob of plain text files) with one
+sentence per line.
+
+To run the script, use the following command:
+
+```shell
+export INPUT_GLOB='path/to/wiki_cased/*.txt'
+
+python3 preprocess_pretrain_data.py --bsz_per_host=32 --num_core_per_host=16
+--seq_len=512 --reuse_len=256 --input_glob='path/to/wiki_cased/*.txt'
+--save_dir=${DATASETS_DIR}/pretrain --bi_data=True --sp_path=${SPIECE_MODEL}
+--mask_alpha=6 --mask_beta=1 --num_predict=85
+```
+
+Note that to make the memory mechanism work correctly, `bsz_per_host` and
+`num_core_per_host` are *strictly specified* when preparing TFRecords. The same
+TPU settings should be used when training.
+
+### Fine-tuning
+
+* Classification
+
+To prepare classification data TFRecords on the IMDB dataset, users can download
+and unpack the [IMDB dataset](https://www.imdb.com/interfaces/) with the
+following command:
+
+```shell
+export IMDB_DIR=~/imdb
+mkdir -p ${IMDB_DIR}
+
+cd ${IMDB_DIR}
+wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
+tar zxvf aclImdb_v1.tar.gz -C ${IMDB_DIR}
+rm aclImdb_v1.tar.gz
+```
+
+Then, the dataset can be converted into TFRecords with the following command:
+
+```shell
+export TASK_NAME=imdb
+
+python3 preprocess_classification_data.py --max_seq_length=512 --spiece_model_file=${SPIECE_MODEL} --output_dir=${DATASETS_DIR}/${TASK_NAME} --data_dir=${IMDB_DIR}/aclImdb --task_name=${TASK_NAME}
+```
+
+Note: To obtain SOTA on the IMDB dataset, using a sequence length of 512 is
+necessary.
+
+* SQUAD
+
+The [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/) contains
+detailed information about the SQuAD datasets and evaluation.
+
+To download the relevant files, use the following command:
+
+```shell
+export SQUAD_DIR=~/squad
+
+mkdir -p ${SQUAD_DIR} && cd ${SQUAD_DIR}
+wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json
+wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json
+```
+
+Then to process the dataset into TFRecords, run the following commands:
+
+```shell
+python3 preprocess_squad_data.py --spiece_model_file=${SPIECE_MODEL} --train_file=${SQUAD_DIR}/train-v2.0.json --predict_file=${SQUAD_DIR}/dev-v2.0.json --output_dir=${DATASETS_DIR}/squad --uncased=False --max_seq_length=512 --num_proc=1 --proc_id=0
+
+gsutil cp ${SQUAD_DIR}/dev-v2.0.json ${DATASETS_DIR}/squad
+```
+
+## Fine-tuning with XLNet
+
+* Cloud Storage
+
+The unzipped pre-trained model files can be found in the Google Cloud Storage
+folder `gs://cloud-tpu-checkpoints/xlnet/keras_xlnet`. For example:
+
+```shell
+export XLNET_DIR=gs:/cloud-tpu-checkpoints/xlnet/keras_xlnet
+export MODEL_DIR=gs://some_bucket/my_output_dir
+```
+
+### Classification task
+
+This example code fine-tunes `XLNet` on the IMDB dataset. For this task, it
+takes around 11 minutes to get the first 500 steps' results, and takes around 1
+hour to complete on a v3-8. It is expected to obtain an accuracy between 96.15
+and 96.33.
+
+To run on a v3-8 TPU:
+
+```shell
+export TPU_NAME=my-tpu
+
+python3 run_classifier.py \
+--strategy_type=tpu \
+--tpu=${TPU_NAME} \
+--init_checkpoint=${XLNET_DIR}/xlnet_model.ckpt \
+--model_dir=${MODEL_DIR} \
+--test_data_size=25024 \
+--train_tfrecord_path=${DATASETS_DIR}/imdb/cased_spiece.model.len-512.train.tf_record \
+--test_tfrecord_path=${DATASETS_DIR}/imdb/cased_spiece.model.len-512.dev.eval.tf_record \
+--train_batch_size=32 \
+--seq_len=512 \
+--n_layer=24 \
+--d_model=1024 \
+--d_embed=1024 \
+--n_head=16 \
+--d_head=64 \
+--d_inner=4096 \
+--untie_r=true \
+--n_class=2 \
+--ff_activation=gelu \
+--learning_rate=2e-5 \
+--train_steps=4000 \
+--warmup_steps=500 \
+--iterations=500 \
+--bi_data=false \
+--summary_type=last
+```
+
+### SQuAD 2.0 Task
+
+The Stanford Question Answering Dataset (SQuAD) is a popular question answering
+benchmark dataset. See more in
+[SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/).
+
+We use `XLNet-LARGE` (cased_L-24_H-1024_A-16) running on a v3-8 as an example to
+run this workflow. It is expected to reach a `best_f1` score of between 88.30
+and 88.80. It should take around 5 minutes to read the pickle file, and then 18
+minutes to get the first 1000 steps' results. It takes around 2 hours to
+complete.
+
+```shell
+export TPU_NAME=my-tpu
+
+python3 run_squad.py \
+ --strategy_type=tpu \
+ --tpu=${TPU_NAME} \
+ --init_checkpoint=${XLNET_DIR}/xlnet_model.ckpt \
+ --model_dir=${MODEL_DIR} \
+ --train_tfrecord_path=${DATASETS_DIR}/squad/squad_cased \
+ --test_tfrecord_path=${DATASETS_DIR}/squad/squad_cased/12048.eval.tf_record \
+ --test_feature_path=${DATASETS_DIR}/squad/spiece.model.slen-512.qlen-64.eval.features.pkl \
+ --predict_dir=${MODEL_DIR} \
+ --predict_file=${DATASETS_DIR}/squad/dev-v2.0.json \
+ --train_batch_size=48 \
+ --seq_len=512 \
+ --reuse_len=256 \
+ --mem_len=0 \
+ --n_layer=24 \
+ --d_model=1024 \
+ --d_embed=1024 \
+ --n_head=16 \
+ --d_head=64 \
+ --d_inner=4096 \
+ --untie_r=true \
+ --ff_activation=gelu \
+ --learning_rate=.00003 \
+ --train_steps=8000 \
+ --warmup_steps=1000 \
+ --iterations=1000 \
+ --bi_data=false \
+ --query_len=64 \
+ --adam_epsilon=.000001 \
+ --lr_layer_decay_rate=0.75
+```
diff --git a/modeling/official/legacy/xlnet/__init__.py b/modeling/official/legacy/xlnet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f338592c943c69c8ca66bc1f0981a619ea10e27
--- /dev/null
+++ b/modeling/official/legacy/xlnet/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
diff --git a/modeling/official/legacy/xlnet/classifier_utils.py b/modeling/official/legacy/xlnet/classifier_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..646a925cb765ec70cf6824a71fe2ea1b8a834b1c
--- /dev/null
+++ b/modeling/official/legacy/xlnet/classifier_utils.py
@@ -0,0 +1,163 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for pre-processing classification data."""
+
+from absl import logging
+
+from official.legacy.xlnet import data_utils
+
+SEG_ID_A = 0
+SEG_ID_B = 1
+
+
+class PaddingInputExample(object):
+ """Fake example so the num input examples is a multiple of the batch size.
+
+ When running eval/predict on the TPU, we need to pad the number of examples
+ to be a multiple of the batch size, because the TPU requires a fixed batch
+ size. The alternative is to drop the last batch, which is bad because it means
+ the entire output data won't be generated.
+ We use this class instead of `None` because treating `None` as padding
+ battches could cause silent errors.
+ """
+
+
+class InputFeatures(object):
+ """A single set of features of data."""
+
+ def __init__(self,
+ input_ids,
+ input_mask,
+ segment_ids,
+ label_id,
+ is_real_example=True):
+ self.input_ids = input_ids
+ self.input_mask = input_mask
+ self.segment_ids = segment_ids
+ self.label_id = label_id
+ self.is_real_example = is_real_example
+
+
+def _truncate_seq_pair(tokens_a, tokens_b, max_length):
+ """Truncates a sequence pair in place to the maximum length."""
+
+ # This is a simple heuristic which will always truncate the longer sequence
+ # one token at a time. This makes more sense than truncating an equal percent
+ # of tokens from each, since if one sequence is very short then each token
+ # that's truncated likely contains more information than a longer sequence.
+ while True:
+ total_length = len(tokens_a) + len(tokens_b)
+ if total_length <= max_length:
+ break
+ if len(tokens_a) > len(tokens_b):
+ tokens_a.pop()
+ else:
+ tokens_b.pop()
+
+
+def convert_single_example(example_index, example, label_list, max_seq_length,
+ tokenize_fn, use_bert_format):
+ """Converts a single `InputExample` into a single `InputFeatures`."""
+
+ if isinstance(example, PaddingInputExample):
+ return InputFeatures(
+ input_ids=[0] * max_seq_length,
+ input_mask=[1] * max_seq_length,
+ segment_ids=[0] * max_seq_length,
+ label_id=0,
+ is_real_example=False)
+
+ if label_list is not None:
+ label_map = {}
+ for (i, label) in enumerate(label_list):
+ label_map[label] = i
+
+ tokens_a = tokenize_fn(example.text_a)
+ tokens_b = None
+ if example.text_b:
+ tokens_b = tokenize_fn(example.text_b)
+
+ if tokens_b:
+ # Modifies `tokens_a` and `tokens_b` in place so that the total
+ # length is less than the specified length.
+ # Account for two [SEP] & one [CLS] with "- 3"
+ _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
+ else:
+ # Account for one [SEP] & one [CLS] with "- 2"
+ if len(tokens_a) > max_seq_length - 2:
+ tokens_a = tokens_a[:max_seq_length - 2]
+
+ tokens = []
+ segment_ids = []
+ for token in tokens_a:
+ tokens.append(token)
+ segment_ids.append(SEG_ID_A)
+ tokens.append(data_utils.SEP_ID)
+ segment_ids.append(SEG_ID_A)
+
+ if tokens_b:
+ for token in tokens_b:
+ tokens.append(token)
+ segment_ids.append(SEG_ID_B)
+ tokens.append(data_utils.SEP_ID)
+ segment_ids.append(SEG_ID_B)
+
+ if use_bert_format:
+ tokens.insert(0, data_utils.CLS_ID)
+ segment_ids.insert(0, data_utils.SEG_ID_CLS)
+ else:
+ tokens.append(data_utils.CLS_ID)
+ segment_ids.append(data_utils.SEG_ID_CLS)
+
+ input_ids = tokens
+
+ # The mask has 0 for real tokens and 1 for padding tokens. Only real
+ # tokens are attended to.
+ input_mask = [0] * len(input_ids)
+
+ # Zero-pad up to the sequence length.
+ if len(input_ids) < max_seq_length:
+ delta_len = max_seq_length - len(input_ids)
+ if use_bert_format:
+ input_ids = input_ids + [0] * delta_len
+ input_mask = input_mask + [1] * delta_len
+ segment_ids = segment_ids + [data_utils.SEG_ID_PAD] * delta_len
+ else:
+ input_ids = [0] * delta_len + input_ids
+ input_mask = [1] * delta_len + input_mask
+ segment_ids = [data_utils.SEG_ID_PAD] * delta_len + segment_ids
+
+ assert len(input_ids) == max_seq_length
+ assert len(input_mask) == max_seq_length
+ assert len(segment_ids) == max_seq_length
+
+ if label_list is not None:
+ label_id = label_map[example.label]
+ else:
+ label_id = example.label
+ if example_index < 5:
+ logging.info("*** Example ***")
+ logging.info("guid: %s", (example.guid))
+ logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
+ logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
+ logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+ logging.info("label: %s (id = %d)", example.label, label_id)
+
+ feature = InputFeatures(
+ input_ids=input_ids,
+ input_mask=input_mask,
+ segment_ids=segment_ids,
+ label_id=label_id)
+ return feature
diff --git a/modeling/official/legacy/xlnet/common_flags.py b/modeling/official/legacy/xlnet/common_flags.py
new file mode 100644
index 0000000000000000000000000000000000000000..83e3ab483b9ae047daccd451ea145dcea0037f90
--- /dev/null
+++ b/modeling/official/legacy/xlnet/common_flags.py
@@ -0,0 +1,142 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common flags used in XLNet model."""
+
+from absl import flags
+
+flags.DEFINE_string("master", default=None, help="master")
+flags.DEFINE_string(
+ "tpu",
+ default=None,
+ help="The Cloud TPU to use for training. This should be "
+ "either the name used when creating the Cloud TPU, or a "
+ "url like grpc://ip.address.of.tpu:8470.")
+flags.DEFINE_bool(
+ "use_tpu", default=True, help="Use TPUs rather than plain CPUs.")
+flags.DEFINE_string("tpu_topology", "2x2", help="TPU topology.")
+flags.DEFINE_integer(
+ "num_core_per_host", default=8, help="number of cores per host")
+
+flags.DEFINE_string("model_dir", default=None, help="Estimator model_dir.")
+flags.DEFINE_string(
+ "init_checkpoint",
+ default=None,
+ help="Checkpoint path for initializing the model.")
+flags.DEFINE_bool(
+ "init_from_transformerxl",
+ default=False,
+ help="Init from a transformerxl model checkpoint. Otherwise, init from the "
+ "entire model checkpoint.")
+
+# Optimization config
+flags.DEFINE_float("learning_rate", default=1e-4, help="Maximum learning rate.")
+flags.DEFINE_float("clip", default=1.0, help="Gradient clipping value.")
+flags.DEFINE_float("weight_decay_rate", default=0.0, help="Weight decay rate.")
+
+# lr decay
+flags.DEFINE_integer(
+ "warmup_steps", default=0, help="Number of steps for linear lr warmup.")
+flags.DEFINE_float("adam_epsilon", default=1e-8, help="Adam epsilon.")
+flags.DEFINE_float(
+ "lr_layer_decay_rate",
+ default=1.0,
+ help="Top layer: lr[L] = FLAGS.learning_rate."
+ "Lower layers: lr[l-1] = lr[l] * lr_layer_decay_rate.")
+flags.DEFINE_float(
+ "min_lr_ratio", default=0.0, help="Minimum ratio learning rate.")
+
+# Training config
+flags.DEFINE_integer(
+ "train_batch_size",
+ default=16,
+ help="Size of the train batch across all hosts.")
+flags.DEFINE_integer(
+ "train_steps", default=100000, help="Total number of training steps.")
+flags.DEFINE_integer(
+ "iterations", default=1000, help="Number of iterations per repeat loop.")
+
+# Data config
+flags.DEFINE_integer(
+ "seq_len", default=0, help="Sequence length for pretraining.")
+flags.DEFINE_integer(
+ "reuse_len",
+ default=0,
+ help="How many tokens to be reused in the next batch. "
+ "Could be half of `seq_len`.")
+flags.DEFINE_bool("uncased", False, help="Use uncased inputs or not.")
+flags.DEFINE_bool(
+ "bi_data",
+ default=False,
+ help="Use bidirectional data streams, "
+ "i.e., forward & backward.")
+flags.DEFINE_integer("n_token", 32000, help="Vocab size")
+
+# Model config
+flags.DEFINE_integer("mem_len", default=0, help="Number of steps to cache")
+flags.DEFINE_bool("same_length", default=False, help="Same length attention")
+flags.DEFINE_integer("clamp_len", default=-1, help="Clamp length")
+
+flags.DEFINE_integer("n_layer", default=6, help="Number of layers.")
+flags.DEFINE_integer("d_model", default=32, help="Dimension of the model.")
+flags.DEFINE_integer("d_embed", default=32, help="Dimension of the embeddings.")
+flags.DEFINE_integer("n_head", default=4, help="Number of attention heads.")
+flags.DEFINE_integer(
+ "d_head", default=8, help="Dimension of each attention head.")
+flags.DEFINE_integer(
+ "d_inner",
+ default=32,
+ help="Dimension of inner hidden size in positionwise "
+ "feed-forward.")
+flags.DEFINE_float("dropout", default=0.1, help="Dropout rate.")
+flags.DEFINE_float("dropout_att", default=0.1, help="Attention dropout rate.")
+flags.DEFINE_bool("untie_r", default=False, help="Untie r_w_bias and r_r_bias")
+flags.DEFINE_string(
+ "ff_activation",
+ default="relu",
+ help="Activation type used in position-wise feed-forward.")
+flags.DEFINE_string(
+ "strategy_type",
+ default="tpu",
+ help="Activation type used in position-wise feed-forward.")
+flags.DEFINE_bool("use_bfloat16", False, help="Whether to use bfloat16.")
+
+# Parameter initialization
+flags.DEFINE_enum(
+ "init_method",
+ default="normal",
+ enum_values=["normal", "uniform"],
+ help="Initialization method.")
+flags.DEFINE_float(
+ "init_std", default=0.02, help="Initialization std when init is normal.")
+flags.DEFINE_float(
+ "init_range", default=0.1, help="Initialization std when init is uniform.")
+
+flags.DEFINE_integer(
+ "test_data_size", default=12048, help="Number of test data samples.")
+flags.DEFINE_string(
+ "train_tfrecord_path",
+ default=None,
+ help="Path to preprocessed training set tfrecord.")
+flags.DEFINE_string(
+ "test_tfrecord_path",
+ default=None,
+ help="Path to preprocessed test set tfrecord.")
+flags.DEFINE_integer(
+ "test_batch_size",
+ default=16,
+ help="Size of the test batch across all hosts.")
+flags.DEFINE_integer(
+ "save_steps", default=1000, help="Number of steps for saving checkpoint.")
+FLAGS = flags.FLAGS
diff --git a/modeling/official/legacy/xlnet/data_utils.py b/modeling/official/legacy/xlnet/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5258ff98bc96cd3814de994c6713015e9a32b91f
--- /dev/null
+++ b/modeling/official/legacy/xlnet/data_utils.py
@@ -0,0 +1,804 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities used for data preparation."""
+
+import collections
+import json
+import os
+
+from absl import logging
+
+import numpy as np
+import tensorflow as tf, tf_keras
+
+special_symbols = {
+ "": 0,
+ "": 1,
+ "": 2,
+ "": 3,
+ "": 4,
+ "": 5,
+ "": 6,
+ "": 7,
+ "": 8,
+}
+
+VOCAB_SIZE = 32000
+UNK_ID = special_symbols[""]
+CLS_ID = special_symbols[""]
+SEP_ID = special_symbols[""]
+MASK_ID = special_symbols[""]
+EOD_ID = special_symbols[""]
+SEG_ID_P = 0
+SEG_ID_Q = 1
+SEG_ID_CLS = 2
+SEG_ID_PAD = 3
+
+OnlineMaskingConfig = collections.namedtuple("OnlineMaskingConfig", [
+ "sample_strategy", "max_num_tokens", "min_num_tokens", "max_num_words",
+ "min_num_words"
+])
+
+
+def file_based_input_fn_builder(input_file, name_to_features, batch_size,
+ is_training):
+ """Creates an `input_fn` closure."""
+
+ logging.info("Input tfrecord file %s", input_file)
+
+ def _decode_record(record, name_to_features):
+ """Decodes a record to a TensorFlow example."""
+ example = tf.io.parse_single_example(record, name_to_features)
+
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+ # So cast all int64 to int32.
+ for name in list(example.keys()):
+ t = example[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ example[name] = t
+
+ return example
+
+ def input_fn():
+ """Returns dataset for training/evaluation."""
+ num_threads = 8
+ if isinstance(input_file, str):
+ d = tf.data.TFRecordDataset(input_file)
+ # For training, we want a lot of parallel reading and shuffling.
+ # For eval, we want no shuffling and parallel reading doesn't matter.
+ if is_training:
+ d = d.shuffle(2048)
+ d = d.repeat()
+ else:
+ cycle_length = min(num_threads, len(input_file))
+ d = tf.data.Dataset.from_tensor_slices(input_file)
+ # file level shuffle
+ d = d.shuffle(len(input_file)).repeat()
+
+ d = d.interleave(
+ tf.data.TFRecordDataset,
+ cycle_length=cycle_length)
+
+ if is_training:
+ # sample level shuffle
+ d = d.shuffle(buffer_size=2048)
+ d = d.map(
+ lambda record: _decode_record(record, name_to_features),
+ num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ d = d.batch(batch_size, drop_remainder=is_training)
+
+ # When `input_file` is a path to a single file or a list
+ # containing a single path, disable auto sharding so that
+ # same input file is sent to all workers.
+ if isinstance(input_file, str) or len(input_file) == 1:
+ options = tf.data.Options()
+ options.experimental_distribute.auto_shard_policy = (
+ tf.data.experimental.AutoShardPolicy.OFF)
+ d = d.with_options(options)
+
+ d = d.prefetch(tf.data.experimental.AUTOTUNE)
+ return d
+
+ return input_fn
+
+
+def create_classification_dataset(file_path, seq_length, batch_size,
+ is_training):
+ """Creates input dataset from (tf)records files for pretraining."""
+ name_to_features = {
+ "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
+ "input_mask": tf.io.FixedLenFeature([seq_length], tf.float32),
+ "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
+ "label_ids": tf.io.FixedLenFeature([], tf.int64),
+ "is_real_example": tf.io.FixedLenFeature([], tf.int64),
+ }
+
+ input_fn = file_based_input_fn_builder(file_path, name_to_features,
+ batch_size, is_training)
+ dataset = input_fn()
+ return dataset
+
+
+def create_squad_dataset(file_path, seq_length, batch_size, is_training):
+ """Creates input dataset from (tf)records files for pretraining."""
+ name_to_features = {
+ "unique_ids": tf.io.FixedLenFeature([], tf.int64),
+ "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
+ "input_mask": tf.io.FixedLenFeature([seq_length], tf.float32),
+ "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
+ "cls_index": tf.io.FixedLenFeature([], tf.int64),
+ "p_mask": tf.io.FixedLenFeature([seq_length], tf.float32)
+ }
+
+ if is_training:
+ name_to_features["start_positions"] = tf.io.FixedLenFeature([], tf.int64)
+ name_to_features["end_positions"] = tf.io.FixedLenFeature([], tf.int64)
+ name_to_features["is_impossible"] = tf.io.FixedLenFeature([], tf.float32)
+
+ input_fn = file_based_input_fn_builder(file_path, name_to_features,
+ batch_size, is_training)
+ dataset = input_fn()
+ return dataset
+
+
+def get_input_iterator(input_fn, strategy):
+ """Returns distributed dataset iterator."""
+
+ # When training with TPU pods, datasets needs to be cloned across
+ # workers. Since Dataset instance cannot be cloned in eager mode, we instead
+ # pass callable that returns a dataset.
+ input_data = input_fn()
+ if callable(input_data):
+ iterator = iter(strategy.distribute_datasets_from_function(input_data))
+ else:
+ iterator = iter(strategy.experimental_distribute_dataset(input_data))
+ return iterator
+
+
+def get_classification_input_data(batch_size, seq_len, strategy, is_training,
+ file_path):
+ """Returns input dataset from input file string."""
+
+ # When using TPU pods, we need to clone dataset across
+ # workers and need to pass in function that returns the dataset rather
+ # than passing dataset instance itself.
+ use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
+ if use_dataset_fn:
+ if batch_size % strategy.num_replicas_in_sync != 0:
+ raise ValueError(
+ "Batch size must be divisible by number of replicas : {}".format(
+ strategy.num_replicas_in_sync))
+
+ # As auto rebatching is not supported in
+ # `distribute_datasets_from_function()` API, which is
+ # required when cloning dataset to multiple workers in eager mode,
+ # we use per-replica batch size.
+ batch_size = int(batch_size / strategy.num_replicas_in_sync)
+
+ def _dataset_fn(ctx=None):
+ del ctx
+
+ train_dataset = create_classification_dataset(
+ file_path=file_path,
+ seq_length=seq_len,
+ batch_size=batch_size,
+ is_training=is_training)
+ return train_dataset
+
+ return _dataset_fn if use_dataset_fn else _dataset_fn()
+
+
+def get_squad_input_data(batch_size, seq_len, q_len, strategy, is_training,
+ file_path):
+ """Returns input dataset from input file string."""
+
+ # When using TPU pods, we need to clone dataset across
+ # workers and need to pass in function that returns the dataset rather
+ # than passing dataset instance itself.
+ use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
+ if use_dataset_fn:
+ if batch_size % strategy.num_replicas_in_sync != 0:
+ raise ValueError(
+ "Batch size must be divisible by number of replicas : {}".format(
+ strategy.num_replicas_in_sync))
+
+ # As auto rebatching is not supported in
+ # `distribute_datasets_from_function()` API, which is
+ # required when cloning dataset to multiple workers in eager mode,
+ # we use per-replica batch size.
+ batch_size = int(batch_size / strategy.num_replicas_in_sync)
+
+ if is_training:
+ input_glob = os.path.join(
+ file_path,
+ "spiece.model.*.slen-{}.qlen-{}.train.tf_record".format(seq_len, q_len))
+
+ global_input_paths = tf.io.gfile.glob(input_glob)
+ else:
+ global_input_paths = file_path
+
+ def _dataset_fn(ctx=None):
+ del ctx
+
+ train_dataset = create_squad_dataset(
+ file_path=global_input_paths,
+ seq_length=seq_len,
+ batch_size=batch_size,
+ is_training=is_training)
+ return train_dataset
+
+ return _dataset_fn if use_dataset_fn else _dataset_fn()
+
+
+def _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len, num_predict):
+ """Turn beg and end indices into actual mask."""
+ non_func_mask = tf.logical_and(
+ tf.not_equal(inputs, SEP_ID), tf.not_equal(inputs, CLS_ID))
+ all_indices = tf.where(non_func_mask, tf.range(tgt_len, dtype=tf.int64),
+ tf.constant(-1, shape=[tgt_len], dtype=tf.int64))
+ candidate_matrix = tf.cast(
+ tf.logical_and(all_indices[None, :] >= beg_indices[:, None],
+ all_indices[None, :] < end_indices[:, None]), tf.float32)
+ cumsum_matrix = tf.reshape(
+ tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, tgt_len])
+ masked_matrix = tf.cast(cumsum_matrix <= num_predict, tf.float32)
+ target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
+ is_masked = tf.cast(target_mask, tf.bool)
+
+ return is_masked, target_mask
+
+
+def _word_span_mask(inputs, tgt_len, num_predict, min_num_words, max_num_words,
+ boundary):
+ """Sample whole word spans as prediction targets."""
+ # Note: 1.2 is the token-to-word ratio
+ mask_alpha = tgt_len / num_predict / 1.2
+ round_to_int = lambda x: tf.cast(tf.round(x), tf.int64)
+
+ # Sample span lengths from a zipf distribution
+ span_len_seq = np.arange(min_num_words, max_num_words + 1)
+ probs = np.array([1.0 / (i + 1) for i in span_len_seq])
+ probs /= np.sum(probs)
+ logits = tf.constant(np.log(probs), dtype=tf.float32)
+
+ # Sample `num_predict` words here: note that this is over sampling
+ span_lens = tf.random.categorical(
+ logits=logits[None],
+ num_samples=num_predict,
+ dtype=tf.int64,
+ )[0] + min_num_words
+
+ # Sample the ratio [0.0, 1.0) of left context lengths
+ span_lens_float = tf.cast(span_lens, tf.float32)
+ left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0)
+ left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
+
+ left_ctx_len = round_to_int(left_ctx_len)
+ right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
+
+ beg_indices = (
+ tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
+ end_indices = beg_indices + span_lens
+
+ # Remove out of range indices
+ max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int64)
+ valid_idx_mask = end_indices < max_boundary_index
+ beg_indices = tf.boolean_mask(beg_indices, valid_idx_mask)
+ end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
+
+ beg_indices = tf.gather(boundary, beg_indices)
+ end_indices = tf.gather(boundary, end_indices)
+
+ # Shuffle valid indices
+ num_valid = tf.cast(tf.shape(beg_indices)[0], tf.int64)
+ order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int64))
+ beg_indices = tf.gather(beg_indices, order)
+ end_indices = tf.gather(end_indices, order)
+
+ return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
+ num_predict)
+
+
+def _token_span_mask(inputs, tgt_len, num_predict, min_num_tokens,
+ max_num_tokens):
+ """Sample token spans as prediction targets."""
+ mask_alpha = tgt_len / num_predict
+ round_to_int = lambda x: tf.cast(tf.round(x), tf.int64)
+
+ # Sample span lengths from a zipf distribution
+ span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
+ probs = np.array([1.0 / (i + 1) for i in span_len_seq])
+
+ probs /= np.sum(probs)
+ logits = tf.constant(np.log(probs), dtype=tf.float32)
+ span_lens = tf.random.categorical(
+ logits=logits[None],
+ num_samples=num_predict,
+ dtype=tf.int64,
+ )[0] + min_num_tokens
+
+ # Sample the ratio [0.0, 1.0) of left context lengths
+ span_lens_float = tf.cast(span_lens, tf.float32)
+ left_ratio = tf.random.uniform(shape=[num_predict], minval=0.0, maxval=1.0)
+ left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
+ left_ctx_len = round_to_int(left_ctx_len)
+
+ # Compute the offset from left start to the right end
+ right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
+
+ # Get the actual begin and end indices
+ beg_indices = (
+ tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
+ end_indices = beg_indices + span_lens
+
+ # Remove out of range indices
+ valid_idx_mask = end_indices < tgt_len
+ beg_indices = tf.boolean_mask(beg_indices, valid_idx_mask)
+ end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
+
+ # Shuffle valid indices
+ num_valid = tf.cast(tf.shape(beg_indices)[0], tf.int64)
+ order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int64))
+ beg_indices = tf.gather(beg_indices, order)
+ end_indices = tf.gather(end_indices, order)
+
+ return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
+ num_predict)
+
+
+def _whole_word_mask(inputs, tgt_len, num_predict, boundary):
+ """Sample whole words as prediction targets."""
+ pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
+ cand_pair_indices = tf.random.shuffle(pair_indices)[:num_predict]
+ beg_indices = cand_pair_indices[:, 0]
+ end_indices = cand_pair_indices[:, 1]
+
+ return _idx_pair_to_mask(beg_indices, end_indices, inputs, tgt_len,
+ num_predict)
+
+
+def _single_token_mask(inputs, tgt_len, num_predict):
+ """Sample individual tokens as prediction targets."""
+ all_indices = tf.range(tgt_len, dtype=tf.int64)
+ non_func_mask = tf.logical_and(
+ tf.not_equal(inputs, SEP_ID), tf.not_equal(inputs, CLS_ID))
+ non_func_indices = tf.boolean_mask(all_indices, non_func_mask)
+
+ masked_pos = tf.random.shuffle(non_func_indices)
+ masked_pos = tf.sort(masked_pos[:num_predict])
+ target_mask = tf.sparse_to_dense(
+ sparse_indices=masked_pos,
+ output_shape=[tgt_len],
+ sparse_values=1.0,
+ default_value=0.0)
+
+ is_masked = tf.cast(target_mask, tf.bool)
+
+ return is_masked, target_mask
+
+
+def _online_sample_masks(inputs,
+ tgt_len,
+ num_predict,
+ online_masking_config,
+ boundary=None):
+ """Sample target positions to predict."""
+ logging.info("Online sample with strategy: `%s`.",
+ online_masking_config.sample_strategy)
+ if online_masking_config.sample_strategy == "single_token":
+ return _single_token_mask(inputs, tgt_len, num_predict)
+ elif online_masking_config.sample_strategy == "whole_word":
+ assert boundary is not None, "whole word sampling requires `boundary`"
+ return _whole_word_mask(inputs, tgt_len, num_predict, boundary)
+ elif online_masking_config.sample_strategy == "token_span":
+ return _token_span_mask(inputs, tgt_len, num_predict,
+ online_masking_config.min_num_tokens,
+ online_masking_config.max_num_tokens)
+ elif online_masking_config.sample_strategy == "word_span":
+ assert boundary is not None, "word span sampling requires `boundary`"
+ return _word_span_mask(inputs, tgt_len, num_predict,
+ online_masking_config.min_num_words,
+ online_masking_config.max_num_words, boundary)
+ else:
+ raise NotImplementedError
+
+
+def create_pretrain_dataset(file_names,
+ bsz_per_core,
+ seq_len,
+ reuse_len,
+ perm_size,
+ leak_ratio,
+ online_masking_config,
+ num_predict=None,
+ input_pipeline_context=None):
+ """Creates pretrain dataset."""
+
+ def parser(record):
+ """Function used to parse tfrecord."""
+
+ record_spec = {
+ "input": tf.io.FixedLenFeature([seq_len], tf.int64),
+ "seg_id": tf.io.FixedLenFeature([seq_len], tf.int64),
+ "label": tf.io.FixedLenFeature([1], tf.int64),
+ }
+
+ if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
+ logging.info("Add `boundary` spec for %s",
+ online_masking_config.sample_strategy)
+ record_spec["boundary"] = tf.io.VarLenFeature(tf.int64)
+
+ # retrieve serialized example
+ example = tf.io.parse_single_example(
+ serialized=record, features=record_spec)
+
+ inputs = example.pop("input")
+ if online_masking_config.sample_strategy in ["whole_word", "word_span"]:
+ boundary = tf.sparse.to_dense(example.pop("boundary"))
+ else:
+ boundary = None
+ is_masked, _ = _online_sample_masks(
+ inputs, seq_len, num_predict, online_masking_config, boundary=boundary)
+
+ if reuse_len > 0:
+ ##### Use memory
+ # permutate the reuse and non-reuse parts separately
+ non_reuse_len = seq_len - reuse_len
+ assert reuse_len % perm_size == 0 and non_reuse_len % perm_size == 0
+
+ # Creates permutation mask and target mask for the first reuse_len tokens.
+ # The tokens in this part are reused from the last sequence.
+ perm_mask_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
+ inputs[:reuse_len], is_masked[:reuse_len], perm_size, reuse_len,
+ leak_ratio)
+
+ # Creates permutation mask and target mask for the rest of tokens in
+ # current example, which are concatentation of two new segments.
+ perm_mask_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
+ inputs[reuse_len:], is_masked[reuse_len:], perm_size, non_reuse_len,
+ leak_ratio)
+
+ perm_mask_0 = tf.concat(
+ [perm_mask_0, tf.ones([reuse_len, non_reuse_len])], axis=1)
+ perm_mask_1 = tf.concat(
+ [tf.zeros([non_reuse_len, reuse_len]), perm_mask_1], axis=1)
+ perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
+ target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
+ input_k = tf.concat([input_k_0, input_k_1], axis=0)
+ input_q = tf.concat([input_q_0, input_q_1], axis=0)
+ else:
+ ##### Do not use memory
+ assert seq_len % perm_size == 0
+ # permutate the entire sequence together
+ perm_mask, target_mask, input_k, input_q = _local_perm(
+ inputs, is_masked, perm_size, seq_len, leak_ratio)
+
+ # reshape back to fixed shape
+ example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
+ example["input_ids"] = tf.reshape(input_k, [seq_len])
+ example["input_q"] = tf.reshape(input_q, [seq_len])
+
+ # Directly use raw inputs as the target
+ target = inputs
+
+ if num_predict is not None:
+ indices = tf.range(seq_len, dtype=tf.int64)
+ bool_target_mask = tf.cast(target_mask, tf.bool)
+ indices = tf.boolean_mask(indices, bool_target_mask)
+
+ ##### extra padding due to CLS/SEP introduced after prepro
+ actual_num_predict = tf.shape(indices)[0]
+ pad_len = num_predict - actual_num_predict
+
+ ##### target_mapping
+ target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)
+ paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
+ target_mapping = tf.concat([target_mapping, paddings], axis=0)
+ example["target_mapping"] = tf.reshape(target_mapping,
+ [num_predict, seq_len])
+
+ ##### target
+ target = tf.boolean_mask(target, bool_target_mask)
+ paddings = tf.zeros([pad_len], dtype=target.dtype)
+ target = tf.concat([target, paddings], axis=0)
+ example["target"] = tf.reshape(target, [num_predict])
+
+ ##### target mask
+ target_mask = tf.concat([
+ tf.ones([actual_num_predict], dtype=tf.float32),
+ tf.zeros([pad_len], dtype=tf.float32)
+ ],
+ axis=0)
+ example["target_mask"] = tf.reshape(target_mask, [num_predict])
+ else:
+ example["target"] = tf.reshape(target, [seq_len])
+ example["target_mask"] = tf.reshape(target_mask, [seq_len])
+
+ for key in list(example.keys()):
+ val = example[key]
+ if tf_keras.backend.is_sparse(val):
+ val = tf.sparse.to_dense(val)
+ if val.dtype == tf.int64:
+ val = tf.cast(val, tf.int32)
+
+ example[key] = val
+
+ for k, v in example.items():
+ logging.info("%s: %s", k, v)
+
+ return example
+
+ dataset = parse_files_to_dataset(
+ parser=parser,
+ file_paths=file_names,
+ bsz_per_core=bsz_per_core,
+ sequential=reuse_len > 0,
+ input_pipeline_context=input_pipeline_context)
+
+ return dataset
+
+
+def format_filename(prefix,
+ suffix,
+ bsz_per_host,
+ seq_len,
+ reuse_len=None,
+ uncased=False):
+ """Generates input file name pattern."""
+ if reuse_len is not None and reuse_len > 0:
+ reuse_str = "reuse-{}.".format(reuse_len)
+ bsz_str = "hostbsz-{}.".format(bsz_per_host)
+ else:
+ reuse_str = ""
+ bsz_str = ""
+
+ if not uncased:
+ case_str = ""
+ else:
+ case_str = "uncased."
+
+ file_name = "{}.seq-{}.{}{}{}{}".format(prefix, seq_len, reuse_str, bsz_str,
+ case_str, suffix)
+
+ return file_name
+
+
+def get_pretrain_input_data(batch_size,
+ seq_len,
+ strategy,
+ file_path,
+ reuse_len,
+ perm_size,
+ leak_ratio,
+ num_predict,
+ uncased,
+ online_masking_config,
+ num_hosts=1):
+ """Returns input dataset from input file string."""
+
+ # When using TPU pods, we need to clone dataset across
+ # workers and need to pass in function that returns the dataset rather
+ # than passing dataset instance itself.
+ use_dataset_fn = isinstance(strategy, tf.distribute.TPUStrategy)
+ split = "train"
+ bsz_per_host = int(batch_size / num_hosts)
+ record_glob_base = format_filename(
+ prefix="meta.{}.pass-*".format(split),
+ suffix="json*",
+ bsz_per_host=bsz_per_host,
+ seq_len=seq_len,
+ reuse_len=reuse_len,
+ uncased=uncased)
+
+ def _get_num_batch(info):
+ if "num_batch" in info:
+ return info["num_batch"]
+ elif "num_example" in info:
+ return info["num_example"] / bsz_per_host
+ else:
+ raise ValueError("Do not have sample info.")
+
+ if use_dataset_fn:
+ if batch_size % strategy.num_replicas_in_sync != 0:
+ raise ValueError(
+ "Batch size must be divisible by number of replicas : {}".format(
+ strategy.num_replicas_in_sync))
+
+ # As auto rebatching is not supported in
+ # `distribute_datasets_from_function()` API, which is
+ # required when cloning dataset to multiple workers in eager mode,
+ # we use per-replica batch size.
+ batch_size = int(batch_size / strategy.num_replicas_in_sync)
+
+ record_info = {"num_batch": 0, "filenames": []}
+
+ tfrecord_dirs = file_path.split(",")
+ logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)
+
+ for idx, record_dir in enumerate(tfrecord_dirs):
+ record_glob = os.path.join(record_dir, record_glob_base)
+ logging.info("[%d] Record glob: %s", idx, record_glob)
+
+ record_paths = sorted(tf.io.gfile.glob(record_glob))
+ logging.info("[%d] Num of record info path: %d", idx, len(record_paths))
+
+ cur_record_info = {"num_batch": 0, "filenames": []}
+
+ for record_info_path in record_paths:
+ with tf.io.gfile.GFile(record_info_path, "r") as fp:
+ info = json.load(fp)
+ cur_record_info["num_batch"] += int(_get_num_batch(info))
+ cur_record_info["filenames"] += info["filenames"]
+
+ # overwrite directory for `cur_record_info`
+ new_filenames = []
+ for filename in cur_record_info["filenames"]:
+ basename = os.path.basename(filename)
+ new_filename = os.path.join(record_dir, basename)
+ new_filenames.append(new_filename)
+ cur_record_info["filenames"] = new_filenames
+
+ logging.info("[Dir %d] Number of chosen batches: %s", idx,
+ cur_record_info["num_batch"])
+ logging.info("[Dir %d] Number of chosen files: %s", idx,
+ len(cur_record_info["filenames"]))
+ logging.info(cur_record_info["filenames"])
+
+ # add `cur_record_info` to global `record_info`
+ record_info["num_batch"] += cur_record_info["num_batch"]
+ record_info["filenames"] += cur_record_info["filenames"]
+
+ logging.info("Total number of batches: %d", record_info["num_batch"])
+ logging.info("Total number of files: %d", len(record_info["filenames"]))
+ logging.info(record_info["filenames"])
+
+ def _dataset_fn(ctx=None):
+ """Function that can create a pretrain dataset."""
+
+ train_dataset = create_pretrain_dataset(
+ file_names=record_info["filenames"],
+ bsz_per_core=batch_size,
+ seq_len=seq_len,
+ reuse_len=reuse_len,
+ perm_size=perm_size,
+ leak_ratio=leak_ratio,
+ online_masking_config=online_masking_config,
+ num_predict=num_predict,
+ input_pipeline_context=ctx)
+ return train_dataset
+
+ return _dataset_fn if use_dataset_fn else _dataset_fn()
+
+
+def parse_files_to_dataset(parser,
+ file_paths,
+ bsz_per_core,
+ sequential,
+ input_pipeline_context=None):
+ """Creates the dataset given file paths."""
+
+ dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+
+ # Note: we cannot perform sample-level shuffle here because this will violate
+ # the consecutive requirement of data stream.
+
+ if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:
+ dataset = dataset.shard(input_pipeline_context.num_input_pipelines,
+ input_pipeline_context.input_pipeline_id)
+ # file-level shuffle
+ if len(file_paths) > 1:
+ dataset = dataset.shuffle(len(file_paths))
+
+ if sequential:
+ # Note: cannot perform sample-level shuffle here because this will violate
+ # the consecutive requirement of data stream.
+ dataset = tf.data.TFRecordDataset(dataset)
+ else:
+ # `cycle_length` is the number of parallel files that get read.
+ cycle_length = min(8, len(file_paths))
+ logging.info("Interleave %d files", cycle_length)
+
+ dataset = dataset.apply(
+ tf.data.experimental.parallel_interleave(
+ tf.data.TFRecordDataset, cycle_length=cycle_length))
+ buffer_size = 2048
+ logging.info("Perform sample-level shuffle with size %d", buffer_size)
+ dataset = dataset.shuffle(buffer_size=buffer_size)
+
+ dataset = dataset.cache().repeat().map(parser)
+ dataset = dataset.batch(bsz_per_core, drop_remainder=True)
+ dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
+
+ return dataset
+
+
+def _local_perm(inputs, is_masked, perm_size, seq_len, leak_ratio):
+ """Samples a permutation of the factorization order.
+
+ Creates perm_mask and target_mask accordingly.
+
+ Args:
+ inputs: int64 Tensor in shape [seq_len], input ids.
+ is_masked: bool Tensor in shape [seq_len]. True means being selected for
+ partial prediction.
+ perm_size: the length of longest permutation. Could be set to be reuse_len.
+ Should not be larger than reuse_len or there will be data leaks.
+ seq_len: int, sequence length.
+ leak_ratio: float, percent of masked tokens that are leaked.
+
+ Returns:
+ perm_mask: float32 Tensor in shape [seq_len, seq_len] consisted of 0 and 1.
+ If perm_mask[i][j] == 1, it means the ith token (in original order) cannot
+ attend to the jth token
+ (in original order). This case will happen only when the ith token's
+ permutated position <= the jth token's permutated position,
+ and the jth token is masked or is func token. If perm_mask[i][j] == 0, it
+ means the ith token (in original order) can attend to the jth token
+ (in original order). Note that non-masked tokens can be attended by all
+ other tokens, which is different from the description in original paper.
+ target_mask: float32 Tensor in shape [seq_len] consisted of 0 and 1. If
+ target_mask[i] == 1,
+ the ith token needs to be predicted and mask will be used as input. This
+ token will count for loss.
+ If target_mask[i] == 0, token (or [SEP], [CLS]) will be used as input. This
+ token will not count for loss.
+ inputs_k: int64 Tensor in shape [seq_len], input ids.
+ inputs_q: float32 Tensor in shape [seq_len], the same as target_mask.
+
+ """
+
+ # Generate permutation indices
+ index = tf.range(seq_len, dtype=tf.int64)
+ index = tf.transpose(tf.reshape(index, [-1, perm_size]))
+ index = tf.random.shuffle(index)
+ index = tf.reshape(tf.transpose(index), [-1])
+
+ # non-functional tokens
+ non_func_tokens = tf.logical_not(
+ tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))
+ masked_tokens = tf.logical_and(is_masked, non_func_tokens)
+ non_masked_or_func_tokens = tf.logical_not(masked_tokens)
+
+ smallest_index = -2 * tf.ones([seq_len], dtype=tf.int64)
+
+ # Similar to BERT, randomly leak some masked tokens
+ if leak_ratio > 0:
+ leak_tokens = tf.logical_and(
+ masked_tokens,
+ tf.random.uniform([seq_len], maxval=1.0) < leak_ratio)
+ can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
+ else:
+ can_attend_self = non_masked_or_func_tokens
+ to_index = tf.where(can_attend_self, smallest_index, index)
+ from_index = tf.where(can_attend_self, to_index + 1, to_index)
+
+ # For masked tokens, can attend if i > j
+ # For context tokens, always can attend each other
+ can_attend = from_index[:, None] > to_index[None, :]
+
+ # In modeling, 1 indicates cannot attend. Hence, reverse the value here.
+ perm_mask = 1.0 - tf.cast(can_attend, tf.float32)
+
+ # Only masked tokens are included in the loss
+ target_mask = tf.cast(masked_tokens, tf.float32)
+
+ # construct inputs_k
+ inputs_k = inputs
+
+ # construct inputs_q
+ inputs_q = masked_tokens
+
+ return perm_mask, target_mask, inputs_k, inputs_q
diff --git a/modeling/official/legacy/xlnet/optimization.py b/modeling/official/legacy/xlnet/optimization.py
new file mode 100644
index 0000000000000000000000000000000000000000..f731ad6807e158d4a29a5c668c63e4f42633d414
--- /dev/null
+++ b/modeling/official/legacy/xlnet/optimization.py
@@ -0,0 +1,98 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions and classes related to optimization (weight updates)."""
+
+from absl import logging
+import tensorflow as tf, tf_keras
+from official.nlp import optimization
+
+
+class WarmUp(tf_keras.optimizers.schedules.LearningRateSchedule):
+ """Applys a warmup schedule on a given learning rate decay schedule."""
+
+ def __init__(self,
+ initial_learning_rate,
+ decay_schedule_fn,
+ warmup_steps,
+ power=1.0,
+ name=None):
+ super(WarmUp, self).__init__()
+ self.initial_learning_rate = initial_learning_rate
+ self.warmup_steps = warmup_steps
+ self.power = power
+ self.decay_schedule_fn = decay_schedule_fn
+ self.name = name
+
+ def __call__(self, step):
+ with tf.name_scope(self.name or "WarmUp") as name:
+ # Implements polynomial warmup. i.e., if global_step < warmup_steps, the
+ # learning rate will be `global_step/num_warmup_steps * init_lr`.
+ global_step_float = tf.cast(step, tf.float32)
+ warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
+ warmup_percent_done = global_step_float / warmup_steps_float
+ warmup_learning_rate = (
+ self.initial_learning_rate *
+ tf.math.pow(warmup_percent_done, self.power))
+ return tf.cond(
+ global_step_float < warmup_steps_float,
+ lambda: warmup_learning_rate,
+ lambda: self.decay_schedule_fn(step - self.warmup_steps),
+ name=name)
+
+ def get_config(self):
+ return {
+ "initial_learning_rate": self.initial_learning_rate,
+ "decay_schedule_fn": self.decay_schedule_fn,
+ "warmup_steps": self.warmup_steps,
+ "power": self.power,
+ "name": self.name
+ }
+
+
+def create_optimizer(init_lr,
+ num_train_steps,
+ num_warmup_steps,
+ min_lr_ratio=0.0,
+ adam_epsilon=1e-8,
+ weight_decay_rate=0.0):
+ """Creates an optimizer with learning rate schedule."""
+ # Implements linear decay of the learning rate.
+ learning_rate_fn = tf_keras.optimizers.schedules.PolynomialDecay(
+ initial_learning_rate=init_lr,
+ decay_steps=num_train_steps - num_warmup_steps,
+ end_learning_rate=init_lr * min_lr_ratio)
+ if num_warmup_steps:
+ learning_rate_fn = WarmUp(
+ initial_learning_rate=init_lr,
+ decay_schedule_fn=learning_rate_fn,
+ warmup_steps=num_warmup_steps)
+ if weight_decay_rate > 0.0:
+ logging.info(
+ "Using AdamWeightDecay with adam_epsilon=%.9f weight_decay_rate=%.3f",
+ adam_epsilon, weight_decay_rate)
+ optimizer = optimization.AdamWeightDecay(
+ learning_rate=learning_rate_fn,
+ weight_decay_rate=weight_decay_rate,
+ beta_1=0.9,
+ beta_2=0.999,
+ epsilon=adam_epsilon,
+ exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
+ include_in_weight_decay=["r_s_bias", "r_r_bias", "r_w_bias"])
+ else:
+ logging.info("Using Adam with adam_epsilon=%.9f", (adam_epsilon))
+ optimizer = tf_keras.optimizers.legacy.Adam(
+ learning_rate=learning_rate_fn, epsilon=adam_epsilon)
+
+ return optimizer, learning_rate_fn
diff --git a/modeling/official/legacy/xlnet/preprocess_classification_data.py b/modeling/official/legacy/xlnet/preprocess_classification_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3a28f7fef2533de1dc85c6b9b327fae2344599c
--- /dev/null
+++ b/modeling/official/legacy/xlnet/preprocess_classification_data.py
@@ -0,0 +1,455 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Script to pre-process classification data into tfrecords."""
+
+import collections
+import csv
+import os
+
+# Import libraries
+from absl import app
+from absl import flags
+from absl import logging
+import numpy as np
+import tensorflow as tf, tf_keras
+
+import sentencepiece as spm
+from official.legacy.xlnet import classifier_utils
+from official.legacy.xlnet import preprocess_utils
+
+
+flags.DEFINE_bool(
+ "overwrite_data",
+ default=False,
+ help="If False, will use cached data if available.")
+flags.DEFINE_string("output_dir", default="", help="Output dir for TF records.")
+flags.DEFINE_string(
+ "spiece_model_file", default="", help="Sentence Piece model path.")
+flags.DEFINE_string("data_dir", default="", help="Directory for input data.")
+
+# task specific
+flags.DEFINE_string("eval_split", default="dev", help="could be dev or test")
+flags.DEFINE_string("task_name", default=None, help="Task name")
+flags.DEFINE_integer(
+ "eval_batch_size", default=64, help="batch size for evaluation")
+flags.DEFINE_integer("max_seq_length", default=128, help="Max sequence length")
+flags.DEFINE_integer(
+ "num_passes",
+ default=1,
+ help="Num passes for processing training data. "
+ "This is use to batch data without loss for TPUs.")
+flags.DEFINE_bool("uncased", default=False, help="Use uncased.")
+flags.DEFINE_bool(
+ "is_regression", default=False, help="Whether it's a regression task.")
+flags.DEFINE_bool(
+ "use_bert_format",
+ default=False,
+ help="Whether to use BERT format to arrange input data.")
+
+FLAGS = flags.FLAGS
+
+
+class InputExample(object):
+ """A single training/test example for simple sequence classification."""
+
+ def __init__(self, guid, text_a, text_b=None, label=None):
+ """Constructs a InputExample.
+
+ Args:
+ guid: Unique id for the example.
+ text_a: string. The untokenized text of the first sequence. For single
+ sequence tasks, only this sequence must be specified.
+ text_b: (Optional) string. The untokenized text of the second sequence.
+ Only must be specified for sequence pair tasks.
+ label: (Optional) string. The label of the example. This should be
+ specified for train and dev examples, but not for test examples.
+ """
+ self.guid = guid
+ self.text_a = text_a
+ self.text_b = text_b
+ self.label = label
+
+
+class DataProcessor(object):
+ """Base class for data converters for sequence classification data sets."""
+
+ def get_train_examples(self, data_dir):
+ """Gets a collection of `InputExample`s for the train set."""
+ raise NotImplementedError()
+
+ def get_dev_examples(self, data_dir):
+ """Gets a collection of `InputExample`s for the dev set."""
+ raise NotImplementedError()
+
+ def get_test_examples(self, data_dir):
+ """Gets a collection of `InputExample`s for prediction."""
+ raise NotImplementedError()
+
+ def get_labels(self):
+ """Gets the list of labels for this data set."""
+ raise NotImplementedError()
+
+ @classmethod
+ def _read_tsv(cls, input_file, quotechar=None):
+ """Reads a tab separated value file."""
+ with tf.io.gfile.GFile(input_file, "r") as f:
+ reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
+ lines = []
+ for line in reader:
+ # pylint: disable=g-explicit-length-test
+ if len(line) == 0:
+ continue
+ lines.append(line)
+ return lines
+
+
+class GLUEProcessor(DataProcessor):
+ """GLUEProcessor."""
+
+ def __init__(self):
+ self.train_file = "train.tsv"
+ self.dev_file = "dev.tsv"
+ self.test_file = "test.tsv"
+ self.label_column = None
+ self.text_a_column = None
+ self.text_b_column = None
+ self.contains_header = True
+ self.test_text_a_column = None
+ self.test_text_b_column = None
+ self.test_contains_header = True
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, self.train_file)), "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, self.dev_file)), "dev")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ if self.test_text_a_column is None:
+ self.test_text_a_column = self.text_a_column
+ if self.test_text_b_column is None:
+ self.test_text_b_column = self.text_b_column
+
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, self.test_file)), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ if i == 0 and self.contains_header and set_type != "test":
+ continue
+ if i == 0 and self.test_contains_header and set_type == "test":
+ continue
+ guid = "%s-%s" % (set_type, i)
+
+ a_column = (
+ self.text_a_column if set_type != "test" else self.test_text_a_column)
+ b_column = (
+ self.text_b_column if set_type != "test" else self.test_text_b_column)
+
+ # there are some incomplete lines in QNLI
+ if len(line) <= a_column:
+ logging.warning("Incomplete line, ignored.")
+ continue
+ text_a = line[a_column]
+
+ if b_column is not None:
+ if len(line) <= b_column:
+ logging.warning("Incomplete line, ignored.")
+ continue
+ text_b = line[b_column]
+ else:
+ text_b = None
+
+ if set_type == "test":
+ label = self.get_labels()[0]
+ else:
+ if len(line) <= self.label_column:
+ logging.warning("Incomplete line, ignored.")
+ continue
+ label = line[self.label_column]
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+class Yelp5Processor(DataProcessor):
+ """Yelp5Processor."""
+
+ def get_train_examples(self, data_dir):
+ return self._create_examples(os.path.join(data_dir, "train.csv"))
+
+ def get_dev_examples(self, data_dir):
+ return self._create_examples(os.path.join(data_dir, "test.csv"))
+
+ def get_labels(self):
+ """See base class."""
+ return ["1", "2", "3", "4", "5"]
+
+ def _create_examples(self, input_file):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ with tf.io.gfile.GFile(input_file) as f:
+ reader = csv.reader(f)
+ for i, line in enumerate(reader):
+
+ label = line[0]
+ text_a = line[1].replace('""', '"').replace('\\"', '"')
+ examples.append(
+ InputExample(guid=str(i), text_a=text_a, text_b=None, label=label))
+ return examples
+
+
+class ImdbProcessor(DataProcessor):
+ """ImdbProcessor."""
+
+ def get_labels(self):
+ return ["neg", "pos"]
+
+ def get_train_examples(self, data_dir):
+ return self._create_examples(os.path.join(data_dir, "train"))
+
+ def get_dev_examples(self, data_dir):
+ return self._create_examples(os.path.join(data_dir, "test"))
+
+ def _create_examples(self, data_dir):
+ """Creates examples."""
+ examples = []
+ for label in ["neg", "pos"]:
+ cur_dir = os.path.join(data_dir, label)
+ for filename in tf.io.gfile.listdir(cur_dir):
+ if not filename.endswith("txt"):
+ continue
+
+ if len(examples) % 1000 == 0:
+ logging.info("Loading dev example %d", len(examples))
+
+ path = os.path.join(cur_dir, filename)
+ with tf.io.gfile.GFile(path) as f:
+ text = f.read().strip().replace(" ", " ")
+ examples.append(
+ InputExample(
+ guid="unused_id", text_a=text, text_b=None, label=label))
+ return examples
+
+
+class MnliMatchedProcessor(GLUEProcessor):
+ """MnliMatchedProcessor."""
+
+ def __init__(self):
+ super(MnliMatchedProcessor, self).__init__()
+ self.dev_file = "dev_matched.tsv"
+ self.test_file = "test_matched.tsv"
+ self.label_column = -1
+ self.text_a_column = 8
+ self.text_b_column = 9
+
+ def get_labels(self):
+ return ["contradiction", "entailment", "neutral"]
+
+
+class MnliMismatchedProcessor(MnliMatchedProcessor):
+
+ def __init__(self):
+ super(MnliMismatchedProcessor, self).__init__()
+ self.dev_file = "dev_mismatched.tsv"
+ self.test_file = "test_mismatched.tsv"
+
+
+class StsbProcessor(GLUEProcessor):
+ """StsbProcessor."""
+
+ def __init__(self):
+ super(StsbProcessor, self).__init__()
+ self.label_column = 9
+ self.text_a_column = 7
+ self.text_b_column = 8
+
+ def get_labels(self):
+ return [0.0]
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ if i == 0 and self.contains_header and set_type != "test":
+ continue
+ if i == 0 and self.test_contains_header and set_type == "test":
+ continue
+ guid = "%s-%s" % (set_type, i)
+
+ a_column = (
+ self.text_a_column if set_type != "test" else self.test_text_a_column)
+ b_column = (
+ self.text_b_column if set_type != "test" else self.test_text_b_column)
+
+ # there are some incomplete lines in QNLI
+ if len(line) <= a_column:
+ logging.warning("Incomplete line, ignored.")
+ continue
+ text_a = line[a_column]
+
+ if b_column is not None:
+ if len(line) <= b_column:
+ logging.warning("Incomplete line, ignored.")
+ continue
+ text_b = line[b_column]
+ else:
+ text_b = None
+
+ if set_type == "test":
+ label = self.get_labels()[0]
+ else:
+ if len(line) <= self.label_column:
+ logging.warning("Incomplete line, ignored.")
+ continue
+ label = float(line[self.label_column])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+
+ return examples
+
+
+def file_based_convert_examples_to_features(examples,
+ label_list,
+ max_seq_length,
+ tokenize_fn,
+ output_file,
+ num_passes=1):
+ """Convert a set of `InputExample`s to a TFRecord file."""
+
+ # do not create duplicated records
+ if tf.io.gfile.exists(output_file) and not FLAGS.overwrite_data:
+ logging.info("Do not overwrite tfrecord %s exists.", output_file)
+ return
+
+ logging.info("Create new tfrecord %s.", output_file)
+
+ writer = tf.io.TFRecordWriter(output_file)
+
+ examples *= num_passes
+
+ for (ex_index, example) in enumerate(examples):
+ if ex_index % 10000 == 0:
+ logging.info("Writing example %d of %d", ex_index, len(examples))
+
+ feature = classifier_utils.convert_single_example(ex_index, example,
+ label_list,
+ max_seq_length,
+ tokenize_fn,
+ FLAGS.use_bert_format)
+
+ def create_int_feature(values):
+ f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
+ return f
+
+ def create_float_feature(values):
+ f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
+ return f
+
+ features = collections.OrderedDict()
+ features["input_ids"] = create_int_feature(feature.input_ids)
+ features["input_mask"] = create_float_feature(feature.input_mask)
+ features["segment_ids"] = create_int_feature(feature.segment_ids)
+ if label_list is not None:
+ features["label_ids"] = create_int_feature([feature.label_id])
+ else:
+ features["label_ids"] = create_float_feature([float(feature.label_id)])
+ features["is_real_example"] = create_int_feature(
+ [int(feature.is_real_example)])
+
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+ writer.write(tf_example.SerializeToString())
+ writer.close()
+
+
+def main(_):
+ logging.set_verbosity(logging.INFO)
+ processors = {
+ "mnli_matched": MnliMatchedProcessor,
+ "mnli_mismatched": MnliMismatchedProcessor,
+ "sts-b": StsbProcessor,
+ "imdb": ImdbProcessor,
+ "yelp5": Yelp5Processor
+ }
+
+ task_name = FLAGS.task_name.lower()
+
+ if task_name not in processors:
+ raise ValueError("Task not found: %s" % (task_name))
+
+ processor = processors[task_name]()
+ label_list = processor.get_labels() if not FLAGS.is_regression else None
+
+ sp = spm.SentencePieceProcessor()
+ sp.Load(FLAGS.spiece_model_file)
+
+ def tokenize_fn(text):
+ text = preprocess_utils.preprocess_text(text, lower=FLAGS.uncased)
+ return preprocess_utils.encode_ids(sp, text)
+
+ spm_basename = os.path.basename(FLAGS.spiece_model_file)
+
+ train_file_base = "{}.len-{}.train.tf_record".format(spm_basename,
+ FLAGS.max_seq_length)
+ train_file = os.path.join(FLAGS.output_dir, train_file_base)
+ logging.info("Use tfrecord file %s", train_file)
+
+ train_examples = processor.get_train_examples(FLAGS.data_dir)
+ np.random.shuffle(train_examples)
+ logging.info("Num of train samples: %d", len(train_examples))
+
+ file_based_convert_examples_to_features(train_examples, label_list,
+ FLAGS.max_seq_length, tokenize_fn,
+ train_file, FLAGS.num_passes)
+ if FLAGS.eval_split == "dev":
+ eval_examples = processor.get_dev_examples(FLAGS.data_dir)
+ else:
+ eval_examples = processor.get_test_examples(FLAGS.data_dir)
+
+ logging.info("Num of eval samples: %d", len(eval_examples))
+
+ # TPU requires a fixed batch size for all batches, therefore the number
+ # of examples must be a multiple of the batch size, or else examples
+ # will get dropped. So we pad with fake examples which are ignored
+ # later on. These do NOT count towards the metric (all tf.metrics
+ # support a per-instance weight, and these get a weight of 0.0).
+ #
+ # Modified in XL: We also adopt the same mechanism for GPUs.
+ while len(eval_examples) % FLAGS.eval_batch_size != 0:
+ eval_examples.append(classifier_utils.PaddingInputExample())
+
+ eval_file_base = "{}.len-{}.{}.eval.tf_record".format(spm_basename,
+ FLAGS.max_seq_length,
+ FLAGS.eval_split)
+ eval_file = os.path.join(FLAGS.output_dir, eval_file_base)
+
+ file_based_convert_examples_to_features(eval_examples, label_list,
+ FLAGS.max_seq_length, tokenize_fn,
+ eval_file)
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/modeling/official/legacy/xlnet/preprocess_pretrain_data.py b/modeling/official/legacy/xlnet/preprocess_pretrain_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..88b875d3fec92198ce75c92e651e290c694097cb
--- /dev/null
+++ b/modeling/official/legacy/xlnet/preprocess_pretrain_data.py
@@ -0,0 +1,1005 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# -*- coding: utf-8 -*-
+"""Script to pre-process pre-training data into tfrecords."""
+
+import json
+import os
+import random
+
+# Import libraries
+from absl import app
+from absl import flags
+from absl import logging
+
+import numpy as np
+
+import tensorflow.compat.v1 as tf
+import sentencepiece as spm
+from official.legacy.xlnet import preprocess_utils
+
+FLAGS = flags.FLAGS
+
+
+special_symbols = {
+ "": 0,
+ "": 1,
+ "": 2,
+ "": 3,
+ "": 4,
+ "": 5,
+ "": 6,
+ "": 7,
+ "": 8,
+}
+
+VOCAB_SIZE = 32000
+UNK_ID = special_symbols[""]
+CLS_ID = special_symbols[""]
+SEP_ID = special_symbols[""]
+MASK_ID = special_symbols[""]
+EOD_ID = special_symbols[""]
+
+
+def _int64_feature(values):
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
+
+
+def _float_feature(values):
+ return tf.train.Feature(float_list=tf.train.FloatList(value=values))
+
+
+def format_filename(prefix, bsz_per_host, seq_len, bi_data, suffix,
+ mask_alpha=5, mask_beta=1, reuse_len=None, uncased=False,
+ fixed_num_predict=None):
+ """docs."""
+ if reuse_len is None:
+ reuse_len_str = ""
+ else:
+ reuse_len_str = "reuse-{}.".format(reuse_len)
+ if not uncased:
+ uncased_str = ""
+ else:
+ uncased_str = "uncased."
+ if bi_data:
+ bi_data_str = "bi"
+ else:
+ bi_data_str = "uni"
+ if fixed_num_predict is not None:
+ fnp_str = "fnp-{}.".format(fixed_num_predict)
+ else:
+ fnp_str = ""
+
+ file_name = "{}.bsz-{}.seqlen-{}.{}{}{}.alpha-{}.beta-{}.{}{}".format(
+ prefix, bsz_per_host, seq_len, reuse_len_str, uncased_str, bi_data_str,
+ mask_alpha, mask_beta, fnp_str, suffix)
+
+ return file_name
+
+
+def _create_data(idx, input_paths):
+ """Creates data."""
+ # Load sentence-piece model
+ sp = spm.SentencePieceProcessor()
+ sp.Load(FLAGS.sp_path)
+
+ input_shards = []
+ total_line_cnt = 0
+ for input_path in input_paths:
+ input_data, sent_ids = [], []
+ sent_id, line_cnt = True, 0
+ logging.info("Processing %s", input_path)
+ for line in tf.gfile.Open(input_path):
+ if line_cnt % 100000 == 0:
+ logging.info("Loading line %d", line_cnt)
+ line_cnt += 1
+
+ if not line.strip():
+ if FLAGS.use_eod:
+ sent_id = not sent_id
+ cur_sent = [EOD_ID]
+ else:
+ continue
+ else:
+ if FLAGS.from_raw_text:
+ cur_sent = preprocess_utils.preprocess_text(
+ line.strip(), lower=FLAGS.uncased)
+ cur_sent = preprocess_utils.encode_ids(sp, cur_sent)
+ else:
+ cur_sent = list(map(int, line.strip().split()))
+
+ input_data.extend(cur_sent)
+ sent_ids.extend([sent_id] * len(cur_sent))
+ sent_id = not sent_id
+
+ logging.info("Finish with line %d", line_cnt)
+ if line_cnt == 0:
+ continue
+
+ input_data = np.array(input_data, dtype=np.int64)
+ sent_ids = np.array(sent_ids, dtype=bool)
+
+ total_line_cnt += line_cnt
+ input_shards.append((input_data, sent_ids))
+
+ logging.info("[Task %d] Total number line: %d", idx, total_line_cnt)
+
+ tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")
+
+ filenames, num_batch = [], 0
+
+ # Randomly shuffle input shards (with a fixed but distinct random seed)
+ np.random.seed(100 * FLAGS.task + FLAGS.pass_id)
+
+ perm_indices = np.random.permutation(len(input_shards))
+ logging.info("Using perm indices %s for pass %d",
+ perm_indices.tolist(), FLAGS.pass_id)
+
+ input_data_list, sent_ids_list = [], []
+ prev_sent_id = None
+ for perm_idx in perm_indices:
+ input_data, sent_ids = input_shards[perm_idx]
+ # make sure the `send_ids[0] == not prev_sent_id`
+ if prev_sent_id is not None and sent_ids[0] == prev_sent_id:
+ sent_ids = np.logical_not(sent_ids)
+
+ # append to temporary list
+ input_data_list.append(input_data)
+ sent_ids_list.append(sent_ids)
+
+ # update `prev_sent_id`
+ prev_sent_id = sent_ids[-1]
+
+ input_data = np.concatenate(input_data_list)
+ sent_ids = np.concatenate(sent_ids_list)
+
+ file_name, cur_num_batch = create_tfrecords(
+ save_dir=tfrecord_dir,
+ basename="{}-{}-{}".format(FLAGS.split, idx, FLAGS.pass_id),
+ data=[input_data, sent_ids],
+ bsz_per_host=FLAGS.bsz_per_host,
+ seq_len=FLAGS.seq_len,
+ bi_data=FLAGS.bi_data,
+ sp=sp,
+ )
+
+ filenames.append(file_name)
+ num_batch += cur_num_batch
+
+ record_info = {
+ "filenames": filenames,
+ "num_batch": num_batch
+ }
+
+ return record_info
+
+
+def create_data(_):
+ """Creates pretrain data."""
+ # Validate FLAGS
+ assert FLAGS.bsz_per_host % FLAGS.num_core_per_host == 0
+ if not FLAGS.use_tpu:
+ FLAGS.num_core_per_host = 1 # forced to be one
+
+ # Make workdirs
+ if not tf.gfile.Exists(FLAGS.save_dir):
+ tf.gfile.MakeDirs(FLAGS.save_dir)
+
+ tfrecord_dir = os.path.join(FLAGS.save_dir, "tfrecords")
+ if not tf.gfile.Exists(tfrecord_dir):
+ tf.gfile.MakeDirs(tfrecord_dir)
+
+ # Create and dump corpus_info from task 0
+ if FLAGS.task == 0 and FLAGS.pass_id == 0:
+ corpus_info = {
+ "vocab_size": VOCAB_SIZE,
+ "bsz_per_host": FLAGS.bsz_per_host,
+ "num_core_per_host": FLAGS.num_core_per_host,
+ "seq_len": FLAGS.seq_len,
+ "reuse_len": FLAGS.reuse_len,
+ "uncased": FLAGS.uncased,
+ "bi_data": FLAGS.bi_data,
+ "mask_alpha": FLAGS.mask_alpha,
+ "mask_beta": FLAGS.mask_beta,
+ "num_predict": FLAGS.num_predict,
+ "use_eod": FLAGS.use_eod,
+ "sp_path": FLAGS.sp_path,
+ "input_glob": FLAGS.input_glob,
+ }
+ corpus_info_path = os.path.join(FLAGS.save_dir, "corpus_info.json")
+ with tf.gfile.Open(corpus_info_path, "w") as fp:
+ json.dump(corpus_info, fp)
+
+ # Interleavely split the work into FLAGS.num_task splits
+ file_paths = sorted(tf.gfile.Glob(FLAGS.input_glob))
+ logging.info("Use glob: %s", FLAGS.input_glob)
+ logging.info("Find %d files: %s", len(file_paths), file_paths)
+
+ task_file_paths = file_paths[FLAGS.task::FLAGS.num_task]
+ if not task_file_paths:
+ logging.info("Exit: task %d has no file to process.", FLAGS.task)
+ return
+
+ logging.info("Task %d process %d files: %s",
+ FLAGS.task, len(task_file_paths), task_file_paths)
+ record_info = _create_data(FLAGS.task, task_file_paths)
+
+ record_prefix = "record_info-{}-{}-{}".format(
+ FLAGS.split, FLAGS.task, FLAGS.pass_id)
+ record_name = format_filename(
+ prefix=record_prefix,
+ bsz_per_host=FLAGS.bsz_per_host,
+ seq_len=FLAGS.seq_len,
+ mask_alpha=FLAGS.mask_alpha,
+ mask_beta=FLAGS.mask_beta,
+ reuse_len=FLAGS.reuse_len,
+ bi_data=FLAGS.bi_data,
+ suffix="json",
+ uncased=FLAGS.uncased,
+ fixed_num_predict=FLAGS.num_predict)
+ record_info_path = os.path.join(tfrecord_dir, record_name)
+
+ with tf.gfile.Open(record_info_path, "w") as fp:
+ json.dump(record_info, fp)
+
+
+def batchify(data, bsz_per_host, sent_ids=None):
+ """Creates batches."""
+ num_step = len(data) // bsz_per_host
+ data = data[:bsz_per_host * num_step]
+ data = data.reshape(bsz_per_host, num_step)
+ if sent_ids is not None:
+ sent_ids = sent_ids[:bsz_per_host * num_step]
+ sent_ids = sent_ids.reshape(bsz_per_host, num_step)
+
+ if sent_ids is not None:
+ return data, sent_ids
+ return data
+
+
+def _split_a_and_b(data, sent_ids, begin_idx, tot_len, extend_target=False):
+ """Split two segments from `data` starting from the index `begin_idx`."""
+
+ data_len = data.shape[0]
+ if begin_idx + tot_len >= data_len:
+ logging.info("[_split_a_and_b] returns None: "
+ "begin_idx %d + tot_len %d >= data_len %d",
+ begin_idx, tot_len, data_len)
+ return None
+
+ end_idx = begin_idx + 1
+ cut_points = []
+ while end_idx < data_len:
+ if sent_ids[end_idx] != sent_ids[end_idx - 1]:
+ if end_idx - begin_idx >= tot_len: break
+ cut_points.append(end_idx)
+ end_idx += 1
+
+ a_begin = begin_idx
+ if len(cut_points) == 0 or random.random() < 0.5: # pylint:disable=g-explicit-length-test
+ label = 0
+ if len(cut_points) == 0: # pylint:disable=g-explicit-length-test
+ a_end = end_idx
+ else:
+ a_end = random.choice(cut_points)
+
+ b_len = max(1, tot_len - (a_end - a_begin))
+ # (zihangd): `data_len - 1` to account for extend_target
+ b_begin = random.randint(0, data_len - 1 - b_len)
+ b_end = b_begin + b_len
+ while b_begin > 0 and sent_ids[b_begin - 1] == sent_ids[b_begin]:
+ b_begin -= 1
+ # (zihangd): `data_len - 1` to account for extend_target
+ while b_end < data_len - 1 and sent_ids[b_end - 1] == sent_ids[b_end]:
+ b_end += 1
+
+ new_begin = a_end
+ else:
+ label = 1
+ a_end = random.choice(cut_points)
+ b_begin = a_end
+ b_end = end_idx
+
+ new_begin = b_end
+
+ while a_end - a_begin + b_end - b_begin > tot_len:
+ if a_end - a_begin > b_end - b_begin:
+ # delete the right side only for the LM objective
+ a_end -= 1
+ else:
+ b_end -= 1
+
+ ret = [data[a_begin: a_end], data[b_begin: b_end], label, new_begin]
+
+ if extend_target:
+ if a_end >= data_len or b_end >= data_len:
+ logging.info("[_split_a_and_b] returns None: "
+ "a_end %d or b_end %d >= data_len %d",
+ a_end, b_end, data_len)
+ return None
+ a_target = data[a_begin + 1: a_end + 1]
+ b_target = data[b_begin: b_end + 1]
+ ret.extend([a_target, b_target])
+
+ return ret
+
+
+def _is_start_piece(piece):
+ special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~'))
+ if (piece.startswith("▁") or piece.startswith("<")
+ or piece in special_pieces):
+ return True
+ else:
+ return False
+
+
+def _sample_mask(sp, seg, reverse=False, max_gram=5, goal_num_predict=None):
+ """Samples `goal_num_predict` tokens for partial prediction."""
+ seg_len = len(seg)
+ mask = np.array([False] * seg_len, dtype=bool)
+
+ num_predict = 0
+
+ ngrams = np.arange(1, max_gram + 1, dtype=np.int64)
+ pvals = 1. / np.arange(1, max_gram + 1)
+ pvals /= pvals.sum(keepdims=True)
+
+ if reverse:
+ seg = np.flip(seg, 0)
+
+ cur_len = 0
+ while cur_len < seg_len:
+ if goal_num_predict is not None and num_predict >= goal_num_predict: break
+
+ n = np.random.choice(ngrams, p=pvals)
+ if goal_num_predict is not None:
+ n = min(n, goal_num_predict - num_predict)
+ ctx_size = (n * FLAGS.mask_alpha) // FLAGS.mask_beta
+ l_ctx = np.random.choice(ctx_size)
+ r_ctx = ctx_size - l_ctx
+
+ # Find the start position of a complete token
+ beg = cur_len + l_ctx
+ while beg < seg_len and not _is_start_piece(sp.IdToPiece(seg[beg].item())):
+ beg += 1
+ if beg >= seg_len:
+ break
+
+ # Find the end position of the n-gram (start pos of the n+1-th gram)
+ end = beg + 1
+ cnt_ngram = 1
+ while end < seg_len:
+ cnt_ngram += 1
+ if cnt_ngram > n:
+ break
+ end += 1
+ if end >= seg_len:
+ break
+
+ # Update
+ mask[beg:end] = True
+ num_predict += end - beg
+
+ cur_len = end + r_ctx
+
+ while goal_num_predict is not None and num_predict < goal_num_predict:
+ i = np.random.randint(seg_len)
+ if not mask[i]:
+ mask[i] = True
+ num_predict += 1
+
+ if reverse:
+ mask = np.flip(mask, 0)
+
+ return mask
+
+
+def _sample_mask_ngram(sp, seg, reverse=False, max_gram=5,
+ goal_num_predict=None):
+ """Sample `goal_num_predict` tokens for partial prediction."""
+
+ seg_len = len(seg)
+ mask = np.array([False] * seg_len, dtype=bool)
+
+ num_predict = 0
+
+ ngrams = np.arange(1, max_gram + 1, dtype=np.int64)
+ pvals = 1. / np.arange(1, max_gram + 1)
+ pvals /= pvals.sum(keepdims=True)
+
+ if reverse:
+ seg = np.flip(seg, 0)
+
+ cur_len = 0
+ while cur_len < seg_len:
+ if goal_num_predict is not None and num_predict >= goal_num_predict: break
+
+ n = np.random.choice(ngrams, p=pvals)
+ if goal_num_predict is not None:
+ n = min(n, goal_num_predict - num_predict)
+ ctx_size = (n * FLAGS.mask_alpha) // FLAGS.mask_beta
+ l_ctx = np.random.choice(ctx_size)
+ r_ctx = ctx_size - l_ctx
+
+ # Find the start position of a complete token
+ beg = cur_len + l_ctx
+ while beg < seg_len and not _is_start_piece(sp.IdToPiece(seg[beg].item())):
+ beg += 1
+ if beg >= seg_len:
+ break
+
+ # Find the end position of the n-gram (start pos of the n+1-th gram)
+ end = beg
+ cnt_ngram = 0
+ while end < seg_len:
+ if _is_start_piece(sp.IdToPiece(seg[end].item())):
+ cnt_ngram += 1
+ if cnt_ngram > n:
+ break
+
+ # select current piece
+ mask[end] = True
+
+ # update the end pointer and increment num_predict
+ end += 1
+ num_predict += 1
+
+ if goal_num_predict is not None and num_predict >= goal_num_predict:
+ break
+
+ cur_len = end + r_ctx
+
+ while goal_num_predict is not None and num_predict < goal_num_predict:
+ i = np.random.randint(seg_len)
+ if not mask[i]:
+ mask[i] = True
+ num_predict += 1
+
+ if reverse:
+ mask = np.flip(mask, 0)
+
+ return mask
+
+
+def create_tfrecords(save_dir, basename, data, bsz_per_host, seq_len,
+ bi_data, sp):
+ """Creates TFRecords."""
+ data, sent_ids = data[0], data[1]
+
+ num_core = FLAGS.num_core_per_host
+ bsz_per_core = bsz_per_host // num_core
+
+ if bi_data:
+ assert bsz_per_host % (2 * FLAGS.num_core_per_host) == 0
+ fwd_data, fwd_sent_ids = batchify(data, bsz_per_host // 2, sent_ids)
+
+ fwd_data = fwd_data.reshape(num_core, 1, bsz_per_core // 2, -1)
+ fwd_sent_ids = fwd_sent_ids.reshape(num_core, 1, bsz_per_core // 2, -1)
+
+ bwd_data = fwd_data[:, :, :, ::-1]
+ bwd_sent_ids = fwd_sent_ids[:, :, :, ::-1]
+
+ data = np.concatenate(
+ [fwd_data, bwd_data], 1).reshape(bsz_per_host, -1)
+ sent_ids = np.concatenate(
+ [fwd_sent_ids, bwd_sent_ids], 1).reshape(bsz_per_host, -1)
+ else:
+ data, sent_ids = batchify(data, bsz_per_host, sent_ids)
+
+ logging.info("Raw data shape %s.", data.shape)
+
+ file_name = format_filename(
+ prefix=basename,
+ bsz_per_host=bsz_per_host,
+ seq_len=seq_len,
+ bi_data=bi_data,
+ suffix="tfrecords",
+ mask_alpha=FLAGS.mask_alpha,
+ mask_beta=FLAGS.mask_beta,
+ reuse_len=FLAGS.reuse_len,
+ uncased=FLAGS.uncased,
+ fixed_num_predict=FLAGS.num_predict
+ )
+ save_path = os.path.join(save_dir, file_name)
+ record_writer = tf.python_io.TFRecordWriter(save_path)
+ logging.info("Start writing %s.", save_path)
+
+ num_batch = 0
+ reuse_len = FLAGS.reuse_len
+
+ # [sep] x 2 + [cls]
+ assert reuse_len < seq_len - 3
+
+ data_len = data.shape[1]
+ sep_array = np.array([SEP_ID], dtype=np.int64)
+ cls_array = np.array([CLS_ID], dtype=np.int64)
+
+ i = 0
+ while i + seq_len <= data_len:
+ if num_batch % 500 == 0:
+ logging.info("Processing batch %d", num_batch)
+
+ all_ok = True
+ features = []
+ for idx in range(bsz_per_host):
+ inp = data[idx, i: i + reuse_len]
+ tgt = data[idx, i + 1: i + reuse_len + 1]
+
+ results = _split_a_and_b(
+ data[idx],
+ sent_ids[idx],
+ begin_idx=i + reuse_len,
+ tot_len=seq_len - reuse_len - 3,
+ extend_target=True)
+ if results is None:
+ logging.info("Break out with seq idx %d", i)
+ all_ok = False
+ break
+
+ # unpack the results
+ (a_data, b_data, label, _, a_target, b_target) = tuple(results)
+
+ # sample ngram spans to predict
+ reverse = bi_data and (idx // (bsz_per_core // 2)) % 2 == 1
+ if FLAGS.num_predict is None:
+ num_predict_0 = num_predict_1 = None
+ else:
+ num_predict_1 = FLAGS.num_predict // 2
+ num_predict_0 = FLAGS.num_predict - num_predict_1
+ mask_0 = _sample_mask(sp, inp, reverse=reverse,
+ goal_num_predict=num_predict_0)
+ mask_1 = _sample_mask(sp, np.concatenate([a_data, sep_array, b_data,
+ sep_array, cls_array]),
+ reverse=reverse, goal_num_predict=num_predict_1)
+
+ # concatenate data
+ cat_data = np.concatenate([inp, a_data, sep_array, b_data,
+ sep_array, cls_array])
+ seg_id = ([0] * (reuse_len + a_data.shape[0]) + [0] +
+ [1] * b_data.shape[0] + [1] + [2])
+ assert cat_data.shape[0] == seq_len
+ assert mask_0.shape[0] == seq_len // 2
+ assert mask_1.shape[0] == seq_len // 2
+
+ # the last two CLS's are not used, just for padding purposes
+ tgt = np.concatenate([tgt, a_target, b_target, cls_array, cls_array])
+ assert tgt.shape[0] == seq_len
+
+ is_masked = np.concatenate([mask_0, mask_1], 0)
+ if FLAGS.num_predict is not None:
+ assert np.sum(is_masked) == FLAGS.num_predict
+
+ feature = {
+ "input": _int64_feature(cat_data),
+ "is_masked": _int64_feature(is_masked),
+ "target": _int64_feature(tgt),
+ "seg_id": _int64_feature(seg_id),
+ "label": _int64_feature([label]),
+ }
+ features.append(feature)
+
+ if all_ok:
+ assert len(features) == bsz_per_host
+ for feature in features:
+ example = tf.train.Example(features=tf.train.Features(feature=feature))
+ record_writer.write(example.SerializeToString())
+ num_batch += 1
+ else:
+ break
+
+ i += reuse_len
+
+ record_writer.close()
+ logging.info("Done writing %s. Num of batches: %d", save_path, num_batch)
+
+ return save_path, num_batch
+
+
+################
+# get_input_fn #
+################
+def _convert_example(example, use_bfloat16):
+ """Cast int64 into int32 and float32 to bfloat16 if use_bfloat16."""
+ for key in list(example.keys()):
+ val = example[key]
+ if tf_keras.backend.is_sparse(val):
+ val = tf.sparse.to_dense(val)
+ if val.dtype == tf.int64:
+ val = tf.cast(val, tf.int32)
+ if use_bfloat16 and val.dtype == tf.float32:
+ val = tf.cast(val, tf.bfloat16)
+
+ example[key] = val
+
+
+def parse_files_to_dataset(parser, file_names, split, num_batch, num_hosts,
+ host_id, num_core_per_host, bsz_per_core):
+ """Parses files to a dataset."""
+ del num_batch
+ # list of file pathes
+ num_files = len(file_names)
+ num_files_per_host = num_files // num_hosts
+ my_start_file_id = host_id * num_files_per_host
+ my_end_file_id = (host_id + 1) * num_files_per_host
+ if host_id == num_hosts - 1:
+ my_end_file_id = num_files
+ file_paths = file_names[my_start_file_id: my_end_file_id]
+ logging.info("Host %d handles %d files", host_id, len(file_paths))
+
+ assert split == "train"
+ dataset = tf.data.Dataset.from_tensor_slices(file_paths)
+
+ # file-level shuffle
+ if len(file_paths) > 1:
+ dataset = dataset.shuffle(len(file_paths))
+
+ # Note: we cannot perform sample-level shuffle here because this will violate
+ # the consecutive requirement of data stream.
+ dataset = tf.data.TFRecordDataset(dataset)
+
+ # Note: since we are doing online preprocessing, the parsed result of
+ # the same input at each time will be different. Thus, cache processed data
+ # is not helpful. It will use a lot of memory and lead to contrainer OOM.
+ # So, change to cache non-parsed raw data instead.
+ dataset = dataset.cache().map(parser).repeat()
+ dataset = dataset.batch(bsz_per_core, drop_remainder=True)
+ dataset = dataset.prefetch(num_core_per_host * bsz_per_core)
+
+ return dataset
+
+
+def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
+ """Samples a permutation of the factorization order, and create a mask.
+
+ Args:
+ inputs: int64 Tensor in shape [seq_len], input ids.
+ targets: int64 Tensor in shape [seq_len], target ids.
+ is_masked: bool Tensor in shape [seq_len]. True means being selected
+ for partial prediction.
+ perm_size: the length of longest permutation. Could be set to be reuse_len.
+ Should not be larger than reuse_len or there will be data leaks.
+ seq_len: int, sequence length.
+
+ Returns:
+ The permutation mask, new targets, target mask, and new inputs.
+
+ """
+
+ # Generate permutation indices
+ index = tf.range(seq_len, dtype=tf.int64)
+ index = tf.transpose(tf.reshape(index, [-1, perm_size]))
+ index = tf.random_shuffle(index)
+ index = tf.reshape(tf.transpose(index), [-1])
+
+ # `perm_mask` and `target_mask`
+ # non-functional tokens
+ non_func_tokens = tf.logical_not(tf.logical_or(
+ tf.equal(inputs, SEP_ID),
+ tf.equal(inputs, CLS_ID)))
+
+ non_mask_tokens = tf.logical_and(tf.logical_not(is_masked), non_func_tokens)
+ masked_or_func_tokens = tf.logical_not(non_mask_tokens)
+
+ # Set the permutation indices of non-masked (& non-funcional) tokens to the
+ # smallest index (-1):
+ # (1) they can be seen by all other positions
+ # (2) they cannot see masked positions, so there won"t be information leak
+ smallest_index = -tf.ones([seq_len], dtype=tf.int64)
+ rev_index = tf.where(non_mask_tokens, smallest_index, index)
+
+ # Create `target_mask`: non-funcional and maksed tokens
+ # 1: use mask as input and have loss
+ # 0: use token (or [SEP], [CLS]) as input and do not have loss
+ target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
+ target_mask = tf.cast(target_tokens, tf.float32)
+
+ # Create `perm_mask`
+ # `target_tokens` cannot see themselves
+ self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)
+
+ # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
+ # 0: can attend if i > j or j is non-masked
+ perm_mask = tf.logical_and(
+ self_rev_index[:, None] <= rev_index[None, :],
+ masked_or_func_tokens)
+ perm_mask = tf.cast(perm_mask, tf.float32)
+
+ # new target: [next token] for LM and [curr token] (self) for PLM
+ new_targets = tf.concat([inputs[0: 1], targets[: -1]],
+ axis=0)
+
+ # construct inputs_k
+ inputs_k = inputs
+
+ # construct inputs_q
+ inputs_q = target_mask
+
+ return perm_mask, new_targets, target_mask, inputs_k, inputs_q
+
+
+def get_dataset(params, num_hosts, num_core_per_host, split, file_names,
+ num_batch, seq_len, reuse_len, perm_size, mask_alpha,
+ mask_beta, use_bfloat16=False, num_predict=None):
+ """Gets the dataset."""
+
+ del mask_alpha
+ del mask_beta
+ bsz_per_core = params["batch_size"]
+ if num_hosts > 1:
+ host_id = params["context"].current_host
+ else:
+ host_id = 0
+
+ #### Function used to parse tfrecord
+ def parser(record):
+ """function used to parse tfrecord."""
+
+ record_spec = {
+ "input": tf.FixedLenFeature([seq_len], tf.int64),
+ "target": tf.FixedLenFeature([seq_len], tf.int64),
+ "seg_id": tf.FixedLenFeature([seq_len], tf.int64),
+ "label": tf.FixedLenFeature([1], tf.int64),
+ "is_masked": tf.FixedLenFeature([seq_len], tf.int64),
+ }
+
+ # retrieve serialized example
+ example = tf.parse_single_example(
+ serialized=record,
+ features=record_spec)
+
+ inputs = example.pop("input")
+ target = example.pop("target")
+ is_masked = tf.cast(example.pop("is_masked"), tf.bool)
+
+ non_reuse_len = seq_len - reuse_len
+ assert perm_size <= reuse_len and perm_size <= non_reuse_len
+
+ perm_mask_0, target_0, target_mask_0, input_k_0, input_q_0 = _local_perm(
+ inputs[:reuse_len],
+ target[:reuse_len],
+ is_masked[:reuse_len],
+ perm_size,
+ reuse_len)
+
+ perm_mask_1, target_1, target_mask_1, input_k_1, input_q_1 = _local_perm(
+ inputs[reuse_len:],
+ target[reuse_len:],
+ is_masked[reuse_len:],
+ perm_size,
+ non_reuse_len)
+
+ perm_mask_0 = tf.concat([perm_mask_0, tf.ones([reuse_len, non_reuse_len])],
+ axis=1)
+ perm_mask_1 = tf.concat([tf.zeros([non_reuse_len, reuse_len]), perm_mask_1],
+ axis=1)
+ perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
+ target = tf.concat([target_0, target_1], axis=0)
+ target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
+ input_k = tf.concat([input_k_0, input_k_1], axis=0)
+ input_q = tf.concat([input_q_0, input_q_1], axis=0)
+
+ if num_predict is not None:
+ indices = tf.range(seq_len, dtype=tf.int64)
+ bool_target_mask = tf.cast(target_mask, tf.bool)
+ indices = tf.boolean_mask(indices, bool_target_mask)
+
+ ##### extra padding due to CLS/SEP introduced after prepro
+ actual_num_predict = tf.shape(indices)[0]
+ pad_len = num_predict - actual_num_predict
+
+ ##### target_mapping
+ target_mapping = tf.one_hot(indices, seq_len, dtype=tf.float32)
+ paddings = tf.zeros([pad_len, seq_len], dtype=target_mapping.dtype)
+ target_mapping = tf.concat([target_mapping, paddings], axis=0)
+ example["target_mapping"] = tf.reshape(target_mapping,
+ [num_predict, seq_len])
+
+ ##### target
+ target = tf.boolean_mask(target, bool_target_mask)
+ paddings = tf.zeros([pad_len], dtype=target.dtype)
+ target = tf.concat([target, paddings], axis=0)
+ example["target"] = tf.reshape(target, [num_predict])
+
+ ##### target mask
+ target_mask = tf.concat(
+ [tf.ones([actual_num_predict], dtype=tf.float32),
+ tf.zeros([pad_len], dtype=tf.float32)],
+ axis=0)
+ example["target_mask"] = tf.reshape(target_mask, [num_predict])
+ else:
+ example["target"] = tf.reshape(target, [seq_len])
+ example["target_mask"] = tf.reshape(target_mask, [seq_len])
+
+ # reshape back to fixed shape
+ example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
+ example["input_k"] = tf.reshape(input_k, [seq_len])
+ example["input_q"] = tf.reshape(input_q, [seq_len])
+
+ _convert_example(example, use_bfloat16)
+
+ for k, v in example.items():
+ logging.info("%s: %s", k, v)
+
+ return example
+
+ # Get dataset
+ dataset = parse_files_to_dataset(
+ parser=parser,
+ file_names=file_names,
+ split=split,
+ num_batch=num_batch,
+ num_hosts=num_hosts,
+ host_id=host_id,
+ num_core_per_host=num_core_per_host,
+ bsz_per_core=bsz_per_core)
+
+ return dataset
+
+
+def get_input_fn(
+ tfrecord_dir,
+ split,
+ bsz_per_host,
+ seq_len,
+ reuse_len,
+ bi_data,
+ num_hosts=1,
+ num_core_per_host=1,
+ perm_size=None,
+ mask_alpha=None,
+ mask_beta=None,
+ uncased=False,
+ num_passes=None,
+ use_bfloat16=False,
+ num_predict=None):
+ """Gets the input function."""
+
+ # Merge all record infos into a single one
+ record_glob_base = format_filename(
+ prefix="record_info-{}-*".format(split),
+ bsz_per_host=bsz_per_host,
+ seq_len=seq_len,
+ bi_data=bi_data,
+ suffix="json",
+ mask_alpha=mask_alpha,
+ mask_beta=mask_beta,
+ reuse_len=reuse_len,
+ uncased=uncased,
+ fixed_num_predict=num_predict)
+
+ record_info = {"num_batch": 0, "filenames": []}
+
+ tfrecord_dirs = tfrecord_dir.split(",")
+ logging.info("Use the following tfrecord dirs: %s", tfrecord_dirs)
+
+ for idx, record_dir in enumerate(tfrecord_dirs):
+ record_glob = os.path.join(record_dir, record_glob_base)
+ logging.info("[%d] Record glob: %s", idx, record_glob)
+
+ record_paths = sorted(tf.gfile.Glob(record_glob))
+ logging.info("[%d] Num of record info path: %d", idx, len(record_paths))
+
+ cur_record_info = {"num_batch": 0, "filenames": []}
+
+ for record_info_path in record_paths:
+ if num_passes is not None:
+ record_info_name = os.path.basename(record_info_path)
+ fields = record_info_name.split(".")[0].split("-")
+ pass_id = int(fields[-1])
+ if len(fields) == 5 and pass_id >= num_passes:
+ logging.info("Skip pass %d: %s", pass_id, record_info_name)
+ continue
+
+ with tf.gfile.Open(record_info_path, "r") as fp:
+ info = json.load(fp)
+ if num_passes is not None:
+ eff_num_passes = min(num_passes, len(info["filenames"]))
+ ratio = eff_num_passes / len(info["filenames"])
+ cur_record_info["num_batch"] += int(info["num_batch"] * ratio)
+ cur_record_info["filenames"] += info["filenames"][:eff_num_passes]
+ else:
+ cur_record_info["num_batch"] += info["num_batch"]
+ cur_record_info["filenames"] += info["filenames"]
+
+ # overwrite directory for `cur_record_info`
+ new_filenames = []
+ for filename in cur_record_info["filenames"]:
+ basename = os.path.basename(filename)
+ new_filename = os.path.join(record_dir, basename)
+ new_filenames.append(new_filename)
+ cur_record_info["filenames"] = new_filenames
+
+ logging.info("[Dir %d] Number of chosen batches: %s",
+ idx, cur_record_info["num_batch"])
+ logging.info("[Dir %d] Number of chosen files: %s",
+ idx, len(cur_record_info["filenames"]))
+ logging.info(cur_record_info["filenames"])
+
+ # add `cur_record_info` to global `record_info`
+ record_info["num_batch"] += cur_record_info["num_batch"]
+ record_info["filenames"] += cur_record_info["filenames"]
+
+ logging.info("Total number of batches: %d", record_info["num_batch"])
+ logging.info("Total number of files: %d", len(record_info["filenames"]))
+ logging.info(record_info["filenames"])
+
+ def input_fn(params):
+ """docs."""
+ assert params["batch_size"] * num_core_per_host == bsz_per_host
+
+ dataset = get_dataset(
+ params=params,
+ num_hosts=num_hosts,
+ num_core_per_host=num_core_per_host,
+ split=split,
+ file_names=record_info["filenames"],
+ num_batch=record_info["num_batch"],
+ seq_len=seq_len,
+ reuse_len=reuse_len,
+ perm_size=perm_size,
+ mask_alpha=mask_alpha,
+ mask_beta=mask_beta,
+ use_bfloat16=use_bfloat16,
+ num_predict=num_predict)
+
+ return dataset
+
+ return input_fn, record_info
+
+
+def define_flags():
+ """Defines relevant flags."""
+ flags.DEFINE_bool("use_tpu", True, help="whether to use TPUs")
+ flags.DEFINE_integer("bsz_per_host", 32, help="batch size per host.")
+ flags.DEFINE_integer("num_core_per_host", 8, help="num TPU cores per host.")
+
+ flags.DEFINE_integer("seq_len", 512,
+ help="Sequence length.")
+ flags.DEFINE_integer("reuse_len", 256,
+ help="Number of token that can be reused as memory. "
+ "Could be half of `seq_len`.")
+ flags.DEFINE_bool("uncased", False, help="Use uncased inputs or not.")
+ flags.DEFINE_bool("bi_data", True,
+ help="whether to create bidirectional data")
+ flags.DEFINE_integer("mask_alpha", default=6,
+ help="How many tokens to form a group.")
+ flags.DEFINE_integer("mask_beta", default=1,
+ help="How many tokens to mask within each group.")
+ flags.DEFINE_bool("use_eod", True,
+ help="whether to append EOD at the end of a doc.")
+ flags.DEFINE_bool("from_raw_text", True,
+ help="Whether the input is raw text or encoded ids.")
+ flags.DEFINE_integer("num_predict", default=85,
+ help="Num of tokens to predict.")
+
+ flags.DEFINE_string("input_glob", "data/example/*.txt",
+ help="Input file glob.")
+ flags.DEFINE_string("sp_path", "", help="Path to the sentence piece model.")
+ flags.DEFINE_string("save_dir", "proc_data/example",
+ help="Directory for saving the processed data.")
+ flags.DEFINE_enum("split", "train", ["train", "dev", "test"],
+ help="Save the data as which split.")
+
+ flags.DEFINE_integer("pass_id", 0, help="ID of the current pass."
+ "Different passes sample different negative segment.")
+ flags.DEFINE_integer("num_task", 1, help="Number of total tasks.")
+ flags.DEFINE_integer("task", 0, help="The Task ID. This value is used when "
+ "using multiple workers to identify each worker.")
+
+
+if __name__ == "__main__":
+ define_flags()
+ logging.set_verbosity(logging.INFO)
+ app.run(create_data)
diff --git a/modeling/official/legacy/xlnet/preprocess_squad_data.py b/modeling/official/legacy/xlnet/preprocess_squad_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..867be798efefe6d70ff84f4cd337923a3ccded1a
--- /dev/null
+++ b/modeling/official/legacy/xlnet/preprocess_squad_data.py
@@ -0,0 +1,108 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding=utf-8
+"""Script to pre-process SQUAD data into tfrecords."""
+
+import os
+import random
+
+# Import libraries
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf, tf_keras
+
+import sentencepiece as spm
+from official.legacy.xlnet import squad_utils
+
+flags.DEFINE_integer(
+ "num_proc", default=1, help="Number of preprocessing processes.")
+flags.DEFINE_integer("proc_id", default=0, help="Process id for preprocessing.")
+
+# I/O paths
+flags.DEFINE_string("output_dir", default="", help="Output dir for TF records.")
+flags.DEFINE_string(
+ "spiece_model_file", default="", help="Sentence Piece model path.")
+flags.DEFINE_string("train_file", default="", help="Path of train file.")
+flags.DEFINE_string("predict_file", default="", help="Path of prediction file.")
+
+# Data preprocessing config
+flags.DEFINE_integer("max_seq_length", default=512, help="Max sequence length")
+flags.DEFINE_integer("max_query_length", default=64, help="Max query length")
+flags.DEFINE_integer("doc_stride", default=128, help="Doc stride")
+flags.DEFINE_bool("uncased", default=False, help="Use uncased data.")
+flags.DEFINE_bool(
+ "create_train_data", default=True, help="Whether to create training data.")
+flags.DEFINE_bool(
+ "create_eval_data", default=False, help="Whether to create eval data.")
+
+FLAGS = flags.FLAGS
+
+
+def preprocess():
+ """Preprocesses SQUAD data."""
+ sp_model = spm.SentencePieceProcessor()
+ sp_model.Load(FLAGS.spiece_model_file)
+ spm_basename = os.path.basename(FLAGS.spiece_model_file)
+ if FLAGS.create_train_data:
+ train_rec_file = os.path.join(
+ FLAGS.output_dir,
+ "{}.{}.slen-{}.qlen-{}.train.tf_record".format(spm_basename,
+ FLAGS.proc_id,
+ FLAGS.max_seq_length,
+ FLAGS.max_query_length))
+
+ logging.info("Read examples from %s", FLAGS.train_file)
+ train_examples = squad_utils.read_squad_examples(
+ FLAGS.train_file, is_training=True)
+ train_examples = train_examples[FLAGS.proc_id::FLAGS.num_proc]
+
+ # Pre-shuffle the input to avoid having to make a very large shuffle
+ # buffer in the `input_fn`.
+ random.shuffle(train_examples)
+ write_to_logging = "Write to " + train_rec_file
+ logging.info(write_to_logging)
+ train_writer = squad_utils.FeatureWriter(
+ filename=train_rec_file, is_training=True)
+ squad_utils.convert_examples_to_features(
+ examples=train_examples,
+ sp_model=sp_model,
+ max_seq_length=FLAGS.max_seq_length,
+ doc_stride=FLAGS.doc_stride,
+ max_query_length=FLAGS.max_query_length,
+ is_training=True,
+ output_fn=train_writer.process_feature,
+ uncased=FLAGS.uncased)
+ train_writer.close()
+ if FLAGS.create_eval_data:
+ eval_examples = squad_utils.read_squad_examples(
+ FLAGS.predict_file, is_training=False)
+ squad_utils.create_eval_data(spm_basename, sp_model, eval_examples,
+ FLAGS.max_seq_length, FLAGS.max_query_length,
+ FLAGS.doc_stride, FLAGS.uncased,
+ FLAGS.output_dir)
+
+
+def main(_):
+ logging.set_verbosity(logging.INFO)
+
+ if not tf.io.gfile.exists(FLAGS.output_dir):
+ tf.io.gfile.mkdir(FLAGS.output_dir)
+
+ preprocess()
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/modeling/official/legacy/xlnet/preprocess_utils.py b/modeling/official/legacy/xlnet/preprocess_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d815f0821d2c8ed963661f3878f9567955cc86e2
--- /dev/null
+++ b/modeling/official/legacy/xlnet/preprocess_utils.py
@@ -0,0 +1,121 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding=utf-8
+"""Utilities for pre-processing."""
+import unicodedata
+
+import six
+
+SPIECE_UNDERLINE = '▁'
+
+
+def printable_text(text):
+ """Returns text encoded in a way suitable for print or `tf.logging`."""
+
+ # These functions want `str` for both Python2 and Python3, but in one case
+ # it's a Unicode string and in the other it's a byte string.
+ if six.PY3:
+ if isinstance(text, str):
+ return text
+ elif isinstance(text, bytes):
+ return text.decode('utf-8', 'ignore')
+ else:
+ raise ValueError('Unsupported string type: %s' % (type(text)))
+ elif six.PY2:
+ if isinstance(text, str):
+ return text
+ elif isinstance(text, unicode): # pylint: disable=undefined-variable
+ return text.encode('utf-8')
+ else:
+ raise ValueError('Unsupported string type: %s' % (type(text)))
+ else:
+ raise ValueError('Not running on Python2 or Python 3?')
+
+
+def print_(*args):
+ new_args = []
+ for arg in args:
+ if isinstance(arg, list):
+ s = [printable_text(i) for i in arg]
+ s = ' '.join(s)
+ new_args.append(s)
+ else:
+ new_args.append(printable_text(arg))
+ print(*new_args)
+
+
+def preprocess_text(inputs, lower=False, remove_space=True, keep_accents=False):
+ """Preprocesses texts."""
+ if remove_space:
+ outputs = ' '.join(inputs.strip().split())
+ else:
+ outputs = inputs
+
+ outputs = outputs.replace('``', '"').replace("''", '"')
+
+ if six.PY2 and isinstance(outputs, str):
+ outputs = outputs.decode('utf-8')
+
+ if not keep_accents:
+ outputs = unicodedata.normalize('NFKD', outputs)
+ outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
+ if lower:
+ outputs = outputs.lower()
+
+ return outputs
+
+
+def encode_pieces(sp_model, text, return_unicode=True, sample=False):
+ """Encodes pieces."""
+ # return_unicode is used only for py2
+
+ if six.PY2 and isinstance(text, unicode): # pylint: disable=undefined-variable
+ text = text.encode('utf-8')
+
+ if not sample:
+ pieces = sp_model.EncodeAsPieces(text)
+ else:
+ pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
+ new_pieces = []
+ for piece in pieces:
+ if len(piece) > 1 and piece[-1] == ',' and piece[-2].isdigit():
+ cur_pieces = sp_model.EncodeAsPieces(piece[:-1].replace(
+ SPIECE_UNDERLINE, ''))
+ if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
+ if len(cur_pieces[0]) == 1:
+ cur_pieces = cur_pieces[1:]
+ else:
+ cur_pieces[0] = cur_pieces[0][1:]
+ cur_pieces.append(piece[-1])
+ new_pieces.extend(cur_pieces)
+ else:
+ new_pieces.append(piece)
+
+ # note(zhiliny): convert back to unicode for py2
+ if six.PY2 and return_unicode:
+ ret_pieces = []
+ for piece in new_pieces:
+ if isinstance(piece, str):
+ piece = piece.decode('utf-8')
+ ret_pieces.append(piece)
+ new_pieces = ret_pieces
+
+ return new_pieces
+
+
+def encode_ids(sp_model, text, sample=False):
+ pieces = encode_pieces(sp_model, text, return_unicode=False, sample=sample)
+ ids = [sp_model.PieceToId(piece) for piece in pieces]
+ return ids
diff --git a/modeling/official/legacy/xlnet/run_classifier.py b/modeling/official/legacy/xlnet/run_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9147c9dc8a116effcbdefbe5e5a0496554f78fd
--- /dev/null
+++ b/modeling/official/legacy/xlnet/run_classifier.py
@@ -0,0 +1,187 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""XLNet classification finetuning runner in tf2.0."""
+
+import functools
+# Import libraries
+from absl import app
+from absl import flags
+from absl import logging
+
+import numpy as np
+import tensorflow as tf, tf_keras
+# pylint: disable=unused-import
+from official.common import distribute_utils
+from official.legacy.xlnet import common_flags
+from official.legacy.xlnet import data_utils
+from official.legacy.xlnet import optimization
+from official.legacy.xlnet import training_utils
+from official.legacy.xlnet import xlnet_config
+from official.legacy.xlnet import xlnet_modeling as modeling
+
+flags.DEFINE_integer("n_class", default=2, help="Number of classes.")
+flags.DEFINE_string(
+ "summary_type",
+ default="last",
+ help="Method used to summarize a sequence into a vector.")
+
+FLAGS = flags.FLAGS
+
+
+def get_classificationxlnet_model(model_config,
+ run_config,
+ n_class,
+ summary_type="last"):
+ model = modeling.ClassificationXLNetModel(
+ model_config, run_config, n_class, summary_type, name="model")
+ return model
+
+
+def run_evaluation(strategy,
+ test_input_fn,
+ eval_steps,
+ model,
+ step,
+ eval_summary_writer=None):
+ """Run evaluation for classification task.
+
+ Args:
+ strategy: distribution strategy.
+ test_input_fn: input function for evaluation data.
+ eval_steps: total number of evaluation steps.
+ model: keras model object.
+ step: current train step.
+ eval_summary_writer: summary writer used to record evaluation metrics. As
+ there are fake data samples in validation set, we use mask to get rid of
+ them when calculating the accuracy. For the reason that there will be
+ dynamic-shape tensor, we first collect logits, labels and masks from TPU
+ and calculate the accuracy via numpy locally.
+
+ Returns:
+ A float metric, accuracy.
+ """
+
+ def _test_step_fn(inputs):
+ """Replicated validation step."""
+
+ inputs["mems"] = None
+ _, logits = model(inputs, training=False)
+ return logits, inputs["label_ids"], inputs["is_real_example"]
+
+ @tf.function
+ def _run_evaluation(test_iterator):
+ """Runs validation steps."""
+ logits, labels, masks = strategy.run(
+ _test_step_fn, args=(next(test_iterator),))
+ return logits, labels, masks
+
+ test_iterator = data_utils.get_input_iterator(test_input_fn, strategy)
+ correct = 0
+ total = 0
+ for _ in range(eval_steps):
+ logits, labels, masks = _run_evaluation(test_iterator)
+ logits = strategy.experimental_local_results(logits)
+ labels = strategy.experimental_local_results(labels)
+ masks = strategy.experimental_local_results(masks)
+ merged_logits = []
+ merged_labels = []
+ merged_masks = []
+
+ for i in range(strategy.num_replicas_in_sync):
+ merged_logits.append(logits[i].numpy())
+ merged_labels.append(labels[i].numpy())
+ merged_masks.append(masks[i].numpy())
+ merged_logits = np.vstack(np.array(merged_logits))
+ merged_labels = np.hstack(np.array(merged_labels))
+ merged_masks = np.hstack(np.array(merged_masks))
+ real_index = np.where(np.equal(merged_masks, 1))
+ correct += np.sum(
+ np.equal(
+ np.argmax(merged_logits[real_index], axis=-1),
+ merged_labels[real_index]))
+ total += np.shape(real_index)[-1]
+ accuracy = float(correct) / float(total)
+ logging.info("Train step: %d / acc = %d/%d = %f", step, correct, total,
+ accuracy)
+ if eval_summary_writer:
+ with eval_summary_writer.as_default():
+ tf.summary.scalar("eval_acc", float(correct) / float(total), step=step)
+ eval_summary_writer.flush()
+ return accuracy
+
+
+def get_metric_fn():
+ train_acc_metric = tf_keras.metrics.SparseCategoricalAccuracy(
+ "acc", dtype=tf.float32)
+ return train_acc_metric
+
+
+def main(unused_argv):
+ del unused_argv
+ strategy = distribute_utils.get_distribution_strategy(
+ distribution_strategy=FLAGS.strategy_type,
+ tpu_address=FLAGS.tpu)
+ if strategy:
+ logging.info("***** Number of cores used : %d",
+ strategy.num_replicas_in_sync)
+ train_input_fn = functools.partial(data_utils.get_classification_input_data,
+ FLAGS.train_batch_size, FLAGS.seq_len,
+ strategy, True, FLAGS.train_tfrecord_path)
+ test_input_fn = functools.partial(data_utils.get_classification_input_data,
+ FLAGS.test_batch_size, FLAGS.seq_len,
+ strategy, False, FLAGS.test_tfrecord_path)
+
+ total_training_steps = FLAGS.train_steps
+ steps_per_loop = FLAGS.iterations
+ eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
+ eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
+ eval_steps)
+ optimizer, learning_rate_fn = optimization.create_optimizer(
+ FLAGS.learning_rate,
+ total_training_steps,
+ FLAGS.warmup_steps,
+ adam_epsilon=FLAGS.adam_epsilon)
+ model_config = xlnet_config.XLNetConfig(FLAGS)
+ run_config = xlnet_config.create_run_config(True, False, FLAGS)
+ model_fn = functools.partial(get_classificationxlnet_model, model_config,
+ run_config, FLAGS.n_class, FLAGS.summary_type)
+ input_meta_data = {}
+ input_meta_data["d_model"] = FLAGS.d_model
+ input_meta_data["mem_len"] = FLAGS.mem_len
+ input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
+ strategy.num_replicas_in_sync)
+ input_meta_data["n_layer"] = FLAGS.n_layer
+ input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
+ input_meta_data["n_class"] = FLAGS.n_class
+
+ training_utils.train(
+ strategy=strategy,
+ model_fn=model_fn,
+ input_meta_data=input_meta_data,
+ eval_fn=eval_fn,
+ metric_fn=get_metric_fn,
+ train_input_fn=train_input_fn,
+ init_checkpoint=FLAGS.init_checkpoint,
+ init_from_transformerxl=FLAGS.init_from_transformerxl,
+ total_training_steps=total_training_steps,
+ steps_per_loop=steps_per_loop,
+ optimizer=optimizer,
+ learning_rate_fn=learning_rate_fn,
+ model_dir=FLAGS.model_dir,
+ save_steps=FLAGS.save_steps)
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/modeling/official/legacy/xlnet/run_pretrain.py b/modeling/official/legacy/xlnet/run_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbb3dce3b3befdea44ebba878b39eb7e03f6e540
--- /dev/null
+++ b/modeling/official/legacy/xlnet/run_pretrain.py
@@ -0,0 +1,146 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""XLNet pretraining runner in tf2.0."""
+
+import functools
+import os
+
+# Import libraries
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf, tf_keras
+# pylint: disable=unused-import
+from official.common import distribute_utils
+from official.legacy.xlnet import common_flags
+from official.legacy.xlnet import data_utils
+from official.legacy.xlnet import optimization
+from official.legacy.xlnet import training_utils
+from official.legacy.xlnet import xlnet_config
+from official.legacy.xlnet import xlnet_modeling as modeling
+
+flags.DEFINE_integer(
+ "num_predict",
+ default=None,
+ help="Number of tokens to predict in partial prediction.")
+
+# FLAGS for pretrain input preprocessing
+flags.DEFINE_integer("perm_size", 0, help="Window size of permutation.")
+flags.DEFINE_float("leak_ratio", default=0.1,
+ help="Percent of masked tokens that are leaked.")
+
+flags.DEFINE_enum("sample_strategy", default="token_span",
+ enum_values=["single_token", "whole_word", "token_span",
+ "word_span"],
+ help="Stragey used to sample prediction targets.")
+flags.DEFINE_integer("max_num_tokens", default=5,
+ help="Maximum number of tokens to sample in a span."
+ "Effective when token_span strategy is used.")
+flags.DEFINE_integer("min_num_tokens", default=1,
+ help="Minimum number of tokens to sample in a span."
+ "Effective when token_span strategy is used.")
+
+flags.DEFINE_integer("max_num_words", default=5,
+ help="Maximum number of whole words to sample in a span."
+ "Effective when word_span strategy is used.")
+flags.DEFINE_integer("min_num_words", default=1,
+ help="Minimum number of whole words to sample in a span."
+ "Effective when word_span strategy is used.")
+FLAGS = flags.FLAGS
+
+
+def get_pretrainxlnet_model(model_config, run_config):
+ return modeling.PretrainingXLNetModel(
+ use_proj=True,
+ xlnet_config=model_config,
+ run_config=run_config,
+ name="model")
+
+
+def main(unused_argv):
+ del unused_argv
+ num_hosts = 1
+ strategy = distribute_utils.get_distribution_strategy(
+ distribution_strategy=FLAGS.strategy_type,
+ tpu_address=FLAGS.tpu)
+ if FLAGS.strategy_type == "tpu":
+ num_hosts = strategy.extended.num_hosts
+ if strategy:
+ logging.info("***** Number of cores used : %d",
+ strategy.num_replicas_in_sync)
+ logging.info("***** Number of hosts used : %d", num_hosts)
+ online_masking_config = data_utils.OnlineMaskingConfig(
+ sample_strategy=FLAGS.sample_strategy,
+ max_num_tokens=FLAGS.max_num_tokens,
+ min_num_tokens=FLAGS.min_num_tokens,
+ max_num_words=FLAGS.max_num_words,
+ min_num_words=FLAGS.min_num_words)
+
+ train_input_fn = functools.partial(
+ data_utils.get_pretrain_input_data, FLAGS.train_batch_size, FLAGS.seq_len,
+ strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size,
+ FLAGS.leak_ratio, FLAGS.num_predict, FLAGS.uncased, online_masking_config,
+ num_hosts)
+
+ total_training_steps = FLAGS.train_steps
+
+ steps_per_loop = FLAGS.iterations
+
+ optimizer, learning_rate_fn = optimization.create_optimizer(
+ init_lr=FLAGS.learning_rate,
+ num_train_steps=total_training_steps,
+ num_warmup_steps=FLAGS.warmup_steps,
+ min_lr_ratio=FLAGS.min_lr_ratio,
+ adam_epsilon=FLAGS.adam_epsilon,
+ weight_decay_rate=FLAGS.weight_decay_rate)
+
+ model_config = xlnet_config.XLNetConfig(FLAGS)
+ run_config = xlnet_config.create_run_config(True, False, FLAGS)
+ input_meta_data = {}
+ input_meta_data["d_model"] = FLAGS.d_model
+ input_meta_data["mem_len"] = FLAGS.mem_len
+ input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
+ strategy.num_replicas_in_sync)
+ input_meta_data["n_layer"] = FLAGS.n_layer
+ input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
+ model_fn = functools.partial(get_pretrainxlnet_model, model_config,
+ run_config)
+
+ model = training_utils.train(
+ strategy=strategy,
+ model_fn=model_fn,
+ input_meta_data=input_meta_data,
+ eval_fn=None,
+ metric_fn=None,
+ train_input_fn=train_input_fn,
+ init_checkpoint=FLAGS.init_checkpoint,
+ init_from_transformerxl=FLAGS.init_from_transformerxl,
+ total_training_steps=total_training_steps,
+ steps_per_loop=steps_per_loop,
+ optimizer=optimizer,
+ learning_rate_fn=learning_rate_fn,
+ model_dir=FLAGS.model_dir,
+ save_steps=FLAGS.save_steps)
+
+ # Export transformer-xl model checkpoint to be used in finetuning.
+ checkpoint = tf.train.Checkpoint(transformer_xl=model.transformerxl_model)
+ saved_path = checkpoint.save(
+ os.path.join(FLAGS.model_dir, "pretrained/transformer_xl.ckpt"))
+ logging.info("Exporting the transformer-xl model as a new TF checkpoint: %s",
+ saved_path)
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/modeling/official/legacy/xlnet/run_squad.py b/modeling/official/legacy/xlnet/run_squad.py
new file mode 100644
index 0000000000000000000000000000000000000000..78dfd361d3f6a80e8d938e93e54a53087dfc2865
--- /dev/null
+++ b/modeling/official/legacy/xlnet/run_squad.py
@@ -0,0 +1,295 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""XLNet SQUAD finetuning runner in tf2.0."""
+
+import functools
+import json
+import os
+import pickle
+
+# Import libraries
+from absl import app
+from absl import flags
+from absl import logging
+
+import tensorflow as tf, tf_keras
+# pylint: disable=unused-import
+import sentencepiece as spm
+from official.common import distribute_utils
+from official.legacy.xlnet import common_flags
+from official.legacy.xlnet import data_utils
+from official.legacy.xlnet import optimization
+from official.legacy.xlnet import squad_utils
+from official.legacy.xlnet import training_utils
+from official.legacy.xlnet import xlnet_config
+from official.legacy.xlnet import xlnet_modeling as modeling
+
+flags.DEFINE_string(
+ "test_feature_path", default=None, help="Path to feature of test set.")
+flags.DEFINE_integer("query_len", default=64, help="Max query length.")
+flags.DEFINE_integer("start_n_top", default=5, help="Beam size for span start.")
+flags.DEFINE_integer("end_n_top", default=5, help="Beam size for span end.")
+flags.DEFINE_string(
+ "predict_dir", default=None, help="Path to write predictions.")
+flags.DEFINE_string(
+ "predict_file", default=None, help="Path to json file of test set.")
+flags.DEFINE_integer(
+ "n_best_size", default=5, help="n best size for predictions.")
+flags.DEFINE_integer("max_answer_length", default=64, help="Max answer length.")
+# Data preprocessing config
+flags.DEFINE_string(
+ "spiece_model_file", default=None, help="Sentence Piece model path.")
+flags.DEFINE_integer("max_seq_length", default=512, help="Max sequence length.")
+flags.DEFINE_integer("max_query_length", default=64, help="Max query length.")
+flags.DEFINE_integer("doc_stride", default=128, help="Doc stride.")
+
+FLAGS = flags.FLAGS
+
+
+class InputFeatures(object):
+ """A single set of features of data."""
+
+ def __init__(self,
+ unique_id,
+ example_index,
+ doc_span_index,
+ tok_start_to_orig_index,
+ tok_end_to_orig_index,
+ token_is_max_context,
+ input_ids,
+ input_mask,
+ p_mask,
+ segment_ids,
+ paragraph_len,
+ cls_index,
+ start_position=None,
+ end_position=None,
+ is_impossible=None):
+ self.unique_id = unique_id
+ self.example_index = example_index
+ self.doc_span_index = doc_span_index
+ self.tok_start_to_orig_index = tok_start_to_orig_index
+ self.tok_end_to_orig_index = tok_end_to_orig_index
+ self.token_is_max_context = token_is_max_context
+ self.input_ids = input_ids
+ self.input_mask = input_mask
+ self.p_mask = p_mask
+ self.segment_ids = segment_ids
+ self.paragraph_len = paragraph_len
+ self.cls_index = cls_index
+ self.start_position = start_position
+ self.end_position = end_position
+ self.is_impossible = is_impossible
+
+
+# pylint: disable=unused-argument
+def run_evaluation(strategy, test_input_fn, eval_examples, eval_features,
+ original_data, eval_steps, input_meta_data, model,
+ current_step, eval_summary_writer):
+ """Run evaluation for SQUAD task.
+
+ Args:
+ strategy: distribution strategy.
+ test_input_fn: input function for evaluation data.
+ eval_examples: tf.Examples of the evaluation set.
+ eval_features: Feature objects of the evaluation set.
+ original_data: The original json data for the evaluation set.
+ eval_steps: total number of evaluation steps.
+ input_meta_data: input meta data.
+ model: keras model object.
+ current_step: current training step.
+ eval_summary_writer: summary writer used to record evaluation metrics.
+
+ Returns:
+ A float metric, F1 score.
+ """
+
+ def _test_step_fn(inputs):
+ """Replicated validation step."""
+
+ inputs["mems"] = None
+ res = model(inputs, training=False)
+ return res, inputs["unique_ids"]
+
+ @tf.function
+ def _run_evaluation(test_iterator):
+ """Runs validation steps."""
+ res, unique_ids = strategy.run(
+ _test_step_fn, args=(next(test_iterator),))
+ return res, unique_ids
+
+ test_iterator = data_utils.get_input_iterator(test_input_fn, strategy)
+ cur_results = []
+ for _ in range(eval_steps):
+ results, unique_ids = _run_evaluation(test_iterator)
+ unique_ids = strategy.experimental_local_results(unique_ids)
+
+ for result_key in results:
+ results[result_key] = (
+ strategy.experimental_local_results(results[result_key]))
+ for core_i in range(strategy.num_replicas_in_sync):
+ bsz = int(input_meta_data["test_batch_size"] /
+ strategy.num_replicas_in_sync)
+ for j in range(bsz):
+ result = {}
+ for result_key in results:
+ result[result_key] = results[result_key][core_i].numpy()[j]
+ result["unique_ids"] = unique_ids[core_i].numpy()[j]
+ # We appended a fake example into dev set to make data size can be
+ # divided by test_batch_size. Ignores this fake example during
+ # evaluation.
+ if result["unique_ids"] == 1000012047:
+ continue
+ unique_id = int(result["unique_ids"])
+
+ start_top_log_probs = ([
+ float(x) for x in result["start_top_log_probs"].flat
+ ])
+ start_top_index = [int(x) for x in result["start_top_index"].flat]
+ end_top_log_probs = ([
+ float(x) for x in result["end_top_log_probs"].flat
+ ])
+ end_top_index = [int(x) for x in result["end_top_index"].flat]
+
+ cls_logits = float(result["cls_logits"].flat[0])
+ cur_results.append(
+ squad_utils.RawResult(
+ unique_id=unique_id,
+ start_top_log_probs=start_top_log_probs,
+ start_top_index=start_top_index,
+ end_top_log_probs=end_top_log_probs,
+ end_top_index=end_top_index,
+ cls_logits=cls_logits))
+ if len(cur_results) % 1000 == 0:
+ logging.info("Processing example: %d", len(cur_results))
+
+ output_prediction_file = os.path.join(input_meta_data["predict_dir"],
+ "predictions.json")
+ output_nbest_file = os.path.join(input_meta_data["predict_dir"],
+ "nbest_predictions.json")
+ output_null_log_odds_file = os.path.join(input_meta_data["predict_dir"],
+ "null_odds.json")
+
+ results = squad_utils.write_predictions(
+ eval_examples, eval_features, cur_results, input_meta_data["n_best_size"],
+ input_meta_data["max_answer_length"], output_prediction_file,
+ output_nbest_file, output_null_log_odds_file, original_data,
+ input_meta_data["start_n_top"], input_meta_data["end_n_top"])
+
+ # Log current results.
+ log_str = "Result | "
+ for key, val in results.items():
+ log_str += "{} {} | ".format(key, val)
+ logging.info(log_str)
+ with eval_summary_writer.as_default():
+ tf.summary.scalar("best_f1", results["best_f1"], step=current_step)
+ tf.summary.scalar("best_exact", results["best_exact"], step=current_step)
+ eval_summary_writer.flush()
+ return results["best_f1"]
+
+
+def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top):
+ model = modeling.QAXLNetModel(
+ model_config,
+ run_config,
+ start_n_top=start_n_top,
+ end_n_top=end_n_top,
+ name="model")
+ return model
+
+
+def main(unused_argv):
+ del unused_argv
+ strategy = distribute_utils.get_distribution_strategy(
+ distribution_strategy=FLAGS.strategy_type,
+ tpu_address=FLAGS.tpu)
+ if strategy:
+ logging.info("***** Number of cores used : %d",
+ strategy.num_replicas_in_sync)
+ train_input_fn = functools.partial(data_utils.get_squad_input_data,
+ FLAGS.train_batch_size, FLAGS.seq_len,
+ FLAGS.query_len, strategy, True,
+ FLAGS.train_tfrecord_path)
+
+ test_input_fn = functools.partial(data_utils.get_squad_input_data,
+ FLAGS.test_batch_size, FLAGS.seq_len,
+ FLAGS.query_len, strategy, False,
+ FLAGS.test_tfrecord_path)
+
+ total_training_steps = FLAGS.train_steps
+ steps_per_loop = FLAGS.iterations
+ eval_steps = int(FLAGS.test_data_size / FLAGS.test_batch_size)
+
+ optimizer, learning_rate_fn = optimization.create_optimizer(
+ FLAGS.learning_rate,
+ total_training_steps,
+ FLAGS.warmup_steps,
+ adam_epsilon=FLAGS.adam_epsilon)
+ model_config = xlnet_config.XLNetConfig(FLAGS)
+ run_config = xlnet_config.create_run_config(True, False, FLAGS)
+ input_meta_data = {}
+ input_meta_data["start_n_top"] = FLAGS.start_n_top
+ input_meta_data["end_n_top"] = FLAGS.end_n_top
+ input_meta_data["lr_layer_decay_rate"] = FLAGS.lr_layer_decay_rate
+ input_meta_data["predict_dir"] = FLAGS.predict_dir
+ input_meta_data["n_best_size"] = FLAGS.n_best_size
+ input_meta_data["max_answer_length"] = FLAGS.max_answer_length
+ input_meta_data["test_batch_size"] = FLAGS.test_batch_size
+ input_meta_data["batch_size_per_core"] = int(FLAGS.train_batch_size /
+ strategy.num_replicas_in_sync)
+ input_meta_data["mem_len"] = FLAGS.mem_len
+ model_fn = functools.partial(get_qaxlnet_model, model_config, run_config,
+ FLAGS.start_n_top, FLAGS.end_n_top)
+ eval_examples = squad_utils.read_squad_examples(
+ FLAGS.predict_file, is_training=False)
+ if FLAGS.test_feature_path:
+ logging.info("start reading pickle file...")
+ with tf.io.gfile.GFile(FLAGS.test_feature_path, "rb") as f:
+ eval_features = pickle.load(f)
+ logging.info("finishing reading pickle file...")
+ else:
+ sp_model = spm.SentencePieceProcessor()
+ sp_model.LoadFromSerializedProto(
+ tf.io.gfile.GFile(FLAGS.spiece_model_file, "rb").read())
+ spm_basename = os.path.basename(FLAGS.spiece_model_file)
+ eval_features = squad_utils.create_eval_data(
+ spm_basename, sp_model, eval_examples, FLAGS.max_seq_length,
+ FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.uncased)
+
+ with tf.io.gfile.GFile(FLAGS.predict_file) as f:
+ original_data = json.load(f)["data"]
+ eval_fn = functools.partial(run_evaluation, strategy, test_input_fn,
+ eval_examples, eval_features, original_data,
+ eval_steps, input_meta_data)
+
+ training_utils.train(
+ strategy=strategy,
+ model_fn=model_fn,
+ input_meta_data=input_meta_data,
+ eval_fn=eval_fn,
+ metric_fn=None,
+ train_input_fn=train_input_fn,
+ init_checkpoint=FLAGS.init_checkpoint,
+ init_from_transformerxl=FLAGS.init_from_transformerxl,
+ total_training_steps=total_training_steps,
+ steps_per_loop=steps_per_loop,
+ optimizer=optimizer,
+ learning_rate_fn=learning_rate_fn,
+ model_dir=FLAGS.model_dir,
+ save_steps=FLAGS.save_steps)
+
+
+if __name__ == "__main__":
+ app.run(main)
diff --git a/modeling/official/legacy/xlnet/squad_utils.py b/modeling/official/legacy/xlnet/squad_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..98e9af557b334a26b7d1b71d739c9b420fc869d4
--- /dev/null
+++ b/modeling/official/legacy/xlnet/squad_utils.py
@@ -0,0 +1,972 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding=utf-8
+"""Utilities used in SQUAD task."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import gc
+import json
+import math
+import os
+import pickle
+import re
+import string
+
+from absl import logging
+import numpy as np
+import six
+import tensorflow as tf, tf_keras
+
+from official.legacy.xlnet import data_utils
+from official.legacy.xlnet import preprocess_utils
+
+SPIECE_UNDERLINE = u"▁"
+
+
+class InputFeatures(object):
+ """A single set of features of data."""
+
+ def __init__(self,
+ unique_id,
+ example_index,
+ doc_span_index,
+ tok_start_to_orig_index,
+ tok_end_to_orig_index,
+ token_is_max_context,
+ input_ids,
+ input_mask,
+ p_mask,
+ segment_ids,
+ paragraph_len,
+ cls_index,
+ start_position=None,
+ end_position=None,
+ is_impossible=None):
+ self.unique_id = unique_id
+ self.example_index = example_index
+ self.doc_span_index = doc_span_index
+ self.tok_start_to_orig_index = tok_start_to_orig_index
+ self.tok_end_to_orig_index = tok_end_to_orig_index
+ self.token_is_max_context = token_is_max_context
+ self.input_ids = input_ids
+ self.input_mask = input_mask
+ self.p_mask = p_mask
+ self.segment_ids = segment_ids
+ self.paragraph_len = paragraph_len
+ self.cls_index = cls_index
+ self.start_position = start_position
+ self.end_position = end_position
+ self.is_impossible = is_impossible
+
+
+def make_qid_to_has_ans(dataset):
+ qid_to_has_ans = {}
+ for article in dataset:
+ for p in article["paragraphs"]:
+ for qa in p["qas"]:
+ qid_to_has_ans[qa["id"]] = bool(qa["answers"])
+ return qid_to_has_ans
+
+
+def get_raw_scores(dataset, preds):
+ """Gets exact scores and f1 scores."""
+ exact_scores = {}
+ f1_scores = {}
+ for article in dataset:
+ for p in article["paragraphs"]:
+ for qa in p["qas"]:
+ qid = qa["id"]
+ gold_answers = [
+ a["text"] for a in qa["answers"] if normalize_answer(a["text"])
+ ]
+ if not gold_answers:
+ # For unanswerable questions, only correct answer is empty string
+ gold_answers = [""]
+ if qid not in preds:
+ print("Missing prediction for %s" % qid)
+ continue
+ a_pred = preds[qid]
+ # Take max over all gold answers
+ exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers)
+ f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers)
+ return exact_scores, f1_scores
+
+
+def normalize_answer(s):
+ """Lower text and remove punctuation, articles and extra whitespace."""
+
+ def remove_articles(text):
+ regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
+ return re.sub(regex, " ", text)
+
+ def white_space_fix(text):
+ return " ".join(text.split())
+
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return "".join(ch for ch in text if ch not in exclude)
+
+ def lower(text):
+ return text.lower()
+
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def compute_exact(a_gold, a_pred):
+ return int(normalize_answer(a_gold) == normalize_answer(a_pred))
+
+
+def get_tokens(s):
+ if not s:
+ return []
+ return normalize_answer(s).split()
+
+
+def compute_f1(a_gold, a_pred):
+ """Computes f1 score."""
+ gold_toks = get_tokens(a_gold)
+ pred_toks = get_tokens(a_pred)
+ common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
+ num_same = sum(common.values())
+ # pylint: disable=g-explicit-length-test
+ if len(gold_toks) == 0 or len(pred_toks) == 0:
+ # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
+ return int(gold_toks == pred_toks)
+ if num_same == 0:
+ return 0
+ precision = 1.0 * num_same / len(pred_toks)
+ recall = 1.0 * num_same / len(gold_toks)
+ f1 = (2 * precision * recall) / (precision + recall)
+ return f1
+
+
+def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
+ """Finds best threshold."""
+ num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
+ cur_score = num_no_ans
+ best_score = cur_score
+ best_thresh = 0.0
+ qid_list = sorted(na_probs, key=lambda k: na_probs[k])
+ for qid in qid_list:
+ if qid not in scores:
+ continue
+ if qid_to_has_ans[qid]:
+ diff = scores[qid]
+ else:
+ if preds[qid]:
+ diff = -1
+ else:
+ diff = 0
+ cur_score += diff
+ if cur_score > best_score:
+ best_score = cur_score
+ best_thresh = na_probs[qid]
+
+ has_ans_score, has_ans_cnt = 0, 0
+ for qid in qid_list:
+ if not qid_to_has_ans[qid]:
+ continue
+ has_ans_cnt += 1
+
+ if qid not in scores:
+ continue
+ has_ans_score += scores[qid]
+
+ return 100.0 * best_score / len(
+ scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
+
+
+def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs,
+ qid_to_has_ans):
+ """Finds all best threshold."""
+ best_exact, exact_thresh, has_ans_exact = find_best_thresh(
+ preds, exact_raw, na_probs, qid_to_has_ans)
+ best_f1, f1_thresh, has_ans_f1 = find_best_thresh(preds, f1_raw, na_probs,
+ qid_to_has_ans)
+ main_eval["best_exact"] = best_exact
+ main_eval["best_exact_thresh"] = exact_thresh
+ main_eval["best_f1"] = best_f1
+ main_eval["best_f1_thresh"] = f1_thresh
+ main_eval["has_ans_exact"] = has_ans_exact
+ main_eval["has_ans_f1"] = has_ans_f1
+
+
+_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
+ "PrelimPrediction", [
+ "feature_index", "start_index", "end_index", "start_log_prob",
+ "end_log_prob"
+ ])
+
+_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
+ "NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
+RawResult = collections.namedtuple("RawResult", [
+ "unique_id", "start_top_log_probs", "start_top_index", "end_top_log_probs",
+ "end_top_index", "cls_logits"
+])
+
+
+def _compute_softmax(scores):
+ """Computes softmax probability over raw logits."""
+ if not scores:
+ return []
+
+ max_score = None
+ for score in scores:
+ if max_score is None or score > max_score:
+ max_score = score
+
+ exp_scores = []
+ total_sum = 0.0
+ for score in scores:
+ x = math.exp(score - max_score)
+ exp_scores.append(x)
+ total_sum += x
+
+ probs = []
+ for score in exp_scores:
+ probs.append(score / total_sum)
+ return probs
+
+
+class SquadExample(object):
+ """A single training/test example for simple sequence classification.
+
+ For examples without an answer, the start and end position are -1.
+ """
+
+ def __init__(self,
+ qas_id,
+ question_text,
+ paragraph_text,
+ orig_answer_text=None,
+ start_position=None,
+ is_impossible=False):
+ self.qas_id = qas_id
+ self.question_text = question_text
+ self.paragraph_text = paragraph_text
+ self.orig_answer_text = orig_answer_text
+ self.start_position = start_position
+ self.is_impossible = is_impossible
+
+ def __str__(self):
+ return self.__repr__()
+
+ def __repr__(self):
+ s = ""
+ s += "qas_id: %s" % (preprocess_utils.printable_text(self.qas_id))
+ s += ", question_text: %s" % (
+ preprocess_utils.printable_text(self.question_text))
+ s += ", paragraph_text: [%s]" % (" ".join(self.paragraph_text))
+ if self.start_position:
+ s += ", start_position: %d" % (self.start_position)
+ if self.start_position:
+ s += ", is_impossible: %r" % (self.is_impossible)
+ return s
+
+
+def write_predictions(all_examples, all_features, all_results, n_best_size,
+ max_answer_length, output_prediction_file,
+ output_nbest_file, output_null_log_odds_file, orig_data,
+ start_n_top, end_n_top):
+ """Writes final predictions to the json file and log-odds of null if needed."""
+ logging.info("Writing predictions to: %s", (output_prediction_file))
+
+ example_index_to_features = collections.defaultdict(list)
+ for feature in all_features:
+ example_index_to_features[feature.example_index].append(feature)
+
+ unique_id_to_result = {}
+ for result in all_results:
+ unique_id_to_result[result.unique_id] = result
+
+ all_predictions = collections.OrderedDict()
+ all_nbest_json = collections.OrderedDict()
+ scores_diff_json = collections.OrderedDict()
+
+ for (example_index, example) in enumerate(all_examples):
+ features = example_index_to_features[example_index]
+
+ prelim_predictions = []
+ # keep track of the minimum score of null start+end of position 0
+ score_null = 1000000 # large and positive
+
+ for (feature_index, feature) in enumerate(features):
+ result = unique_id_to_result[feature.unique_id]
+
+ cur_null_score = result.cls_logits
+
+ # if we could have irrelevant answers, get the min score of irrelevant
+ score_null = min(score_null, cur_null_score)
+
+ for i in range(start_n_top):
+ for j in range(end_n_top):
+ start_log_prob = result.start_top_log_probs[i]
+ start_index = result.start_top_index[i]
+
+ j_index = i * end_n_top + j
+
+ end_log_prob = result.end_top_log_probs[j_index]
+ end_index = result.end_top_index[j_index]
+
+ # We could hypothetically create invalid predictions, e.g., predict
+ # that the start of the span is in the question. We throw out all
+ # invalid predictions.
+ if start_index >= feature.paragraph_len - 1:
+ continue
+ if end_index >= feature.paragraph_len - 1:
+ continue
+
+ if not feature.token_is_max_context.get(start_index, False):
+ continue
+ if end_index < start_index:
+ continue
+ length = end_index - start_index + 1
+ if length > max_answer_length:
+ continue
+
+ prelim_predictions.append(
+ _PrelimPrediction(
+ feature_index=feature_index,
+ start_index=start_index,
+ end_index=end_index,
+ start_log_prob=start_log_prob,
+ end_log_prob=end_log_prob))
+
+ prelim_predictions = sorted(
+ prelim_predictions,
+ key=lambda x: (x.start_log_prob + x.end_log_prob),
+ reverse=True)
+
+ seen_predictions = {}
+ nbest = []
+ for pred in prelim_predictions:
+ if len(nbest) >= n_best_size:
+ break
+ feature = features[pred.feature_index]
+
+ tok_start_to_orig_index = feature.tok_start_to_orig_index
+ tok_end_to_orig_index = feature.tok_end_to_orig_index
+ start_orig_pos = tok_start_to_orig_index[pred.start_index]
+ end_orig_pos = tok_end_to_orig_index[pred.end_index]
+
+ paragraph_text = example.paragraph_text
+ final_text = paragraph_text[start_orig_pos:end_orig_pos + 1].strip()
+
+ if final_text in seen_predictions:
+ continue
+
+ seen_predictions[final_text] = True
+
+ nbest.append(
+ _NbestPrediction(
+ text=final_text,
+ start_log_prob=pred.start_log_prob,
+ end_log_prob=pred.end_log_prob))
+
+ # In very rare edge cases we could have no valid predictions. So we
+ # just create a nonce prediction in this case to avoid failure.
+ if not nbest:
+ nbest.append(
+ _NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
+
+ total_scores = []
+ best_non_null_entry = None
+ for entry in nbest:
+ total_scores.append(entry.start_log_prob + entry.end_log_prob)
+ if not best_non_null_entry:
+ best_non_null_entry = entry
+
+ probs = _compute_softmax(total_scores)
+
+ nbest_json = []
+ for (i, entry) in enumerate(nbest):
+ output = collections.OrderedDict()
+ output["text"] = entry.text
+ output["probability"] = probs[i]
+ output["start_log_prob"] = entry.start_log_prob
+ output["end_log_prob"] = entry.end_log_prob
+ nbest_json.append(output)
+
+ assert len(nbest_json) >= 1
+ assert best_non_null_entry is not None
+
+ score_diff = score_null
+ scores_diff_json[example.qas_id] = score_diff
+
+ all_predictions[example.qas_id] = best_non_null_entry.text
+
+ all_nbest_json[example.qas_id] = nbest_json
+
+ with tf.io.gfile.GFile(output_prediction_file, "w") as writer:
+ writer.write(json.dumps(all_predictions, indent=4) + "\n")
+
+ with tf.io.gfile.GFile(output_nbest_file, "w") as writer:
+ writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
+
+ with tf.io.gfile.GFile(output_null_log_odds_file, "w") as writer:
+ writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
+
+ qid_to_has_ans = make_qid_to_has_ans(orig_data)
+ exact_raw, f1_raw = get_raw_scores(orig_data, all_predictions)
+ out_eval = {}
+
+ find_all_best_thresh(out_eval, all_predictions, exact_raw, f1_raw,
+ scores_diff_json, qid_to_has_ans)
+
+ return out_eval
+
+
+def read_squad_examples(input_file, is_training):
+ """Reads a SQuAD json file into a list of SquadExample."""
+ with tf.io.gfile.GFile(input_file, "r") as reader:
+ input_data = json.load(reader)["data"]
+
+ examples = []
+ for entry in input_data:
+ for paragraph in entry["paragraphs"]:
+ paragraph_text = paragraph["context"]
+
+ for qa in paragraph["qas"]:
+ qas_id = qa["id"]
+ question_text = qa["question"]
+ start_position = None
+ orig_answer_text = None
+ is_impossible = False
+
+ if is_training:
+ is_impossible = qa["is_impossible"]
+ if (len(qa["answers"]) != 1) and (not is_impossible):
+ raise ValueError(
+ "For training, each question should have exactly 1 answer.")
+ if not is_impossible:
+ answer = qa["answers"][0]
+ orig_answer_text = answer["text"]
+ start_position = answer["answer_start"]
+ else:
+ start_position = -1
+ orig_answer_text = ""
+
+ example = SquadExample(
+ qas_id=qas_id,
+ question_text=question_text,
+ paragraph_text=paragraph_text,
+ orig_answer_text=orig_answer_text,
+ start_position=start_position,
+ is_impossible=is_impossible)
+ examples.append(example)
+
+ return examples
+
+
+# pylint: disable=invalid-name
+def _convert_index(index, pos, M=None, is_start=True):
+ """Converts index."""
+ if index[pos] is not None:
+ return index[pos]
+ N = len(index)
+ rear = pos
+ while rear < N - 1 and index[rear] is None:
+ rear += 1
+ front = pos
+ while front > 0 and index[front] is None:
+ front -= 1
+ assert index[front] is not None or index[rear] is not None
+ if index[front] is None:
+ if index[rear] >= 1:
+ if is_start:
+ return 0
+ else:
+ return index[rear] - 1
+ return index[rear]
+ if index[rear] is None:
+ if M is not None and index[front] < M - 1:
+ if is_start:
+ return index[front] + 1
+ else:
+ return M - 1
+ return index[front]
+ if is_start:
+ if index[rear] > index[front] + 1:
+ return index[front] + 1
+ else:
+ return index[rear]
+ else:
+ if index[rear] > index[front] + 1:
+ return index[rear] - 1
+ else:
+ return index[front]
+
+
+def convert_examples_to_features(examples, sp_model, max_seq_length, doc_stride,
+ max_query_length, is_training, output_fn,
+ uncased):
+ """Loads a data file into a list of `InputBatch`s."""
+
+ cnt_pos, cnt_neg = 0, 0
+ unique_id = 1000000000
+ max_N, max_M = 1024, 1024
+ f = np.zeros((max_N, max_M), dtype=np.float32)
+
+ for (example_index, example) in enumerate(examples):
+ # pylint: disable=logging-format-interpolation
+ if example_index % 100 == 0:
+ logging.info("Converting {}/{} pos {} neg {}".format(
+ example_index, len(examples), cnt_pos, cnt_neg))
+
+ query_tokens = preprocess_utils.encode_ids(
+ sp_model,
+ preprocess_utils.preprocess_text(example.question_text, lower=uncased))
+
+ if len(query_tokens) > max_query_length:
+ query_tokens = query_tokens[0:max_query_length]
+
+ paragraph_text = example.paragraph_text
+ para_tokens = preprocess_utils.encode_pieces(
+ sp_model,
+ preprocess_utils.preprocess_text(example.paragraph_text, lower=uncased))
+
+ chartok_to_tok_index = []
+ tok_start_to_chartok_index = []
+ tok_end_to_chartok_index = []
+ char_cnt = 0
+ for i, token in enumerate(para_tokens):
+ chartok_to_tok_index.extend([i] * len(token))
+ tok_start_to_chartok_index.append(char_cnt)
+ char_cnt += len(token)
+ tok_end_to_chartok_index.append(char_cnt - 1)
+
+ tok_cat_text = "".join(para_tokens).replace(SPIECE_UNDERLINE, " ")
+ N, M = len(paragraph_text), len(tok_cat_text)
+
+ if N > max_N or M > max_M:
+ max_N = max(N, max_N)
+ max_M = max(M, max_M)
+ f = np.zeros((max_N, max_M), dtype=np.float32)
+ gc.collect()
+
+ g = {}
+
+ # pylint: disable=cell-var-from-loop
+ def _lcs_match(max_dist):
+ """LCS match."""
+ f.fill(0)
+ g.clear()
+
+ ### longest common sub sequence
+ # f[i, j] = max(f[i - 1, j], f[i, j - 1], f[i - 1, j - 1] + match(i, j))
+ for i in range(N):
+
+ # note(zhiliny):
+ # unlike standard LCS, this is specifically optimized for the setting
+ # because the mismatch between sentence pieces and original text will
+ # be small
+ for j in range(i - max_dist, i + max_dist):
+ if j >= M or j < 0:
+ continue
+
+ if i > 0:
+ g[(i, j)] = 0
+ f[i, j] = f[i - 1, j]
+
+ if j > 0 and f[i, j - 1] > f[i, j]:
+ g[(i, j)] = 1
+ f[i, j] = f[i, j - 1]
+
+ f_prev = f[i - 1, j - 1] if i > 0 and j > 0 else 0
+ if (preprocess_utils.preprocess_text(
+ paragraph_text[i], lower=uncased,
+ remove_space=False) == tok_cat_text[j] and f_prev + 1 > f[i, j]):
+ g[(i, j)] = 2
+ f[i, j] = f_prev + 1
+
+ max_dist = abs(N - M) + 5
+ for _ in range(2):
+ _lcs_match(max_dist)
+ if f[N - 1, M - 1] > 0.8 * N:
+ break
+ max_dist *= 2
+
+ orig_to_chartok_index = [None] * N
+ chartok_to_orig_index = [None] * M
+ i, j = N - 1, M - 1
+ while i >= 0 and j >= 0:
+ if (i, j) not in g:
+ break
+ if g[(i, j)] == 2:
+ orig_to_chartok_index[i] = j
+ chartok_to_orig_index[j] = i
+ i, j = i - 1, j - 1
+ elif g[(i, j)] == 1:
+ j = j - 1
+ else:
+ i = i - 1
+
+ if all(
+ v is None for v in orig_to_chartok_index) or f[N - 1, M - 1] < 0.8 * N:
+ print("MISMATCH DETECTED!")
+ continue
+
+ tok_start_to_orig_index = []
+ tok_end_to_orig_index = []
+ for i in range(len(para_tokens)):
+ start_chartok_pos = tok_start_to_chartok_index[i]
+ end_chartok_pos = tok_end_to_chartok_index[i]
+ start_orig_pos = _convert_index(
+ chartok_to_orig_index, start_chartok_pos, N, is_start=True)
+ end_orig_pos = _convert_index(
+ chartok_to_orig_index, end_chartok_pos, N, is_start=False)
+
+ tok_start_to_orig_index.append(start_orig_pos)
+ tok_end_to_orig_index.append(end_orig_pos)
+
+ if not is_training:
+ tok_start_position = tok_end_position = None
+
+ if is_training and example.is_impossible:
+ tok_start_position = -1
+ tok_end_position = -1
+
+ if is_training and not example.is_impossible:
+ start_position = example.start_position
+ end_position = start_position + len(example.orig_answer_text) - 1
+
+ start_chartok_pos = _convert_index(
+ orig_to_chartok_index, start_position, is_start=True)
+ tok_start_position = chartok_to_tok_index[start_chartok_pos]
+
+ end_chartok_pos = _convert_index(
+ orig_to_chartok_index, end_position, is_start=False)
+ tok_end_position = chartok_to_tok_index[end_chartok_pos]
+ assert tok_start_position <= tok_end_position
+
+ def _piece_to_id(x):
+ if six.PY2 and isinstance(x, unicode): # pylint: disable=undefined-variable
+ x = x.encode("utf-8")
+ return sp_model.PieceToId(x)
+
+ all_doc_tokens = list(map(_piece_to_id, para_tokens))
+
+ # The -3 accounts for [CLS], [SEP] and [SEP]
+ max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
+
+ # We can have documents that are longer than the maximum sequence length.
+ # To deal with this we do a sliding window approach, where we take chunks
+ # of the up to our max length with a stride of `doc_stride`.
+ _DocSpan = collections.namedtuple( # pylint: disable=invalid-name
+ "DocSpan", ["start", "length"])
+ doc_spans = []
+ start_offset = 0
+ while start_offset < len(all_doc_tokens):
+ length = len(all_doc_tokens) - start_offset
+ if length > max_tokens_for_doc:
+ length = max_tokens_for_doc
+ doc_spans.append(_DocSpan(start=start_offset, length=length))
+ if start_offset + length == len(all_doc_tokens):
+ break
+ start_offset += min(length, doc_stride)
+
+ for (doc_span_index, doc_span) in enumerate(doc_spans):
+ tokens = []
+ token_is_max_context = {}
+ segment_ids = []
+ p_mask = []
+
+ cur_tok_start_to_orig_index = []
+ cur_tok_end_to_orig_index = []
+
+ for i in range(doc_span.length):
+ split_token_index = doc_span.start + i
+
+ cur_tok_start_to_orig_index.append(
+ tok_start_to_orig_index[split_token_index])
+ cur_tok_end_to_orig_index.append(
+ tok_end_to_orig_index[split_token_index])
+
+ is_max_context = _check_is_max_context(doc_spans, doc_span_index,
+ split_token_index)
+ token_is_max_context[len(tokens)] = is_max_context
+ tokens.append(all_doc_tokens[split_token_index])
+ segment_ids.append(data_utils.SEG_ID_P)
+ p_mask.append(0)
+
+ paragraph_len = len(tokens)
+
+ tokens.append(data_utils.SEP_ID)
+ segment_ids.append(data_utils.SEG_ID_P)
+ p_mask.append(1)
+
+ # note(zhiliny): we put P before Q
+ # because during pretraining, B is always shorter than A
+ for token in query_tokens:
+ tokens.append(token)
+ segment_ids.append(data_utils.SEG_ID_Q)
+ p_mask.append(1)
+ tokens.append(data_utils.SEP_ID)
+ segment_ids.append(data_utils.SEG_ID_Q)
+ p_mask.append(1)
+
+ cls_index = len(segment_ids)
+ tokens.append(data_utils.CLS_ID)
+ segment_ids.append(data_utils.SEG_ID_CLS)
+ p_mask.append(0)
+
+ input_ids = tokens
+
+ # The mask has 0 for real tokens and 1 for padding tokens. Only real
+ # tokens are attended to.
+ input_mask = [0] * len(input_ids)
+
+ # Zero-pad up to the sequence length.
+ while len(input_ids) < max_seq_length:
+ input_ids.append(0)
+ input_mask.append(1)
+ segment_ids.append(data_utils.SEG_ID_PAD)
+ p_mask.append(1)
+
+ assert len(input_ids) == max_seq_length
+ assert len(input_mask) == max_seq_length
+ assert len(segment_ids) == max_seq_length
+ assert len(p_mask) == max_seq_length
+
+ span_is_impossible = example.is_impossible
+ start_position = None
+ end_position = None
+ if is_training and not span_is_impossible:
+ # For training, if our document chunk does not contain an annotation
+ # we throw it out, since there is nothing to predict.
+ doc_start = doc_span.start
+ doc_end = doc_span.start + doc_span.length - 1
+ out_of_span = False
+ if not (tok_start_position >= doc_start and
+ tok_end_position <= doc_end):
+ out_of_span = True
+ if out_of_span:
+ # continue
+ start_position = 0
+ end_position = 0
+ span_is_impossible = True
+ else:
+ # note: we put P before Q, so doc_offset should be zero.
+ # doc_offset = len(query_tokens) + 2
+ doc_offset = 0
+ start_position = tok_start_position - doc_start + doc_offset
+ end_position = tok_end_position - doc_start + doc_offset
+
+ if is_training and span_is_impossible:
+ start_position = cls_index
+ end_position = cls_index
+
+ if example_index < 20:
+ logging.info("*** Example ***")
+ logging.info("unique_id: %s", unique_id)
+ logging.info("example_index: %s", example_index)
+ logging.info("doc_span_index: %s", doc_span_index)
+ logging.info("tok_start_to_orig_index: %s",
+ " ".join([str(x) for x in cur_tok_start_to_orig_index]))
+ logging.info("tok_end_to_orig_index: %s",
+ " ".join([str(x) for x in cur_tok_end_to_orig_index]))
+ logging.info(
+ "token_is_max_context: %s", " ".join([
+ "%d:%s" % (x, y)
+ for (x, y) in six.iteritems(token_is_max_context)
+ ]))
+ logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
+ logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
+ logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+
+ if is_training and span_is_impossible:
+ logging.info("impossible example span")
+
+ if is_training and not span_is_impossible:
+ pieces = [
+ sp_model.IdToPiece(token)
+ for token in tokens[start_position:(end_position + 1)]
+ ]
+ answer_text = sp_model.DecodePieces(pieces)
+ logging.info("start_position: %d", start_position)
+ logging.info("end_position: %d", end_position)
+ logging.info("answer: %s",
+ preprocess_utils.printable_text(answer_text))
+
+ # With multi processing, the example_index is actually the index
+ # within the current process therefore we use example_index=None to
+ # avoid being used in the future. # The current code does not use
+ # example_index of training data.
+ if is_training:
+ feat_example_index = None
+ else:
+ feat_example_index = example_index
+
+ feature = InputFeatures(
+ unique_id=unique_id,
+ example_index=feat_example_index,
+ doc_span_index=doc_span_index,
+ tok_start_to_orig_index=cur_tok_start_to_orig_index,
+ tok_end_to_orig_index=cur_tok_end_to_orig_index,
+ token_is_max_context=token_is_max_context,
+ input_ids=input_ids,
+ input_mask=input_mask,
+ p_mask=p_mask,
+ segment_ids=segment_ids,
+ paragraph_len=paragraph_len,
+ cls_index=cls_index,
+ start_position=start_position,
+ end_position=end_position,
+ is_impossible=span_is_impossible)
+
+ # Run callback
+ output_fn(feature)
+
+ unique_id += 1
+ if span_is_impossible:
+ cnt_neg += 1
+ else:
+ cnt_pos += 1
+
+ logging.info("Total number of instances: %d = pos %d + neg %d",
+ cnt_pos + cnt_neg, cnt_pos, cnt_neg)
+
+
+def _check_is_max_context(doc_spans, cur_span_index, position):
+ """Check if this is the "max context" doc span for the token."""
+
+ # Because of the sliding window approach taken to scoring documents, a single
+ # token can appear in multiple documents. E.g.
+ # Doc: the man went to the store and bought a gallon of milk
+ # Span A: the man went to the
+ # Span B: to the store and bought
+ # Span C: and bought a gallon of
+ # ...
+ #
+ # Now the word "bought" will have two scores from spans B and C. We only
+ # want to consider the score with "maximum context", which we define as
+ # the *minimum* of its left and right context (the *sum* of left and
+ # right context will always be the same, of course).
+ #
+ # In the example the maximum context for "bought" would be span C since
+ # it has 1 left context and 3 right context, while span B has 4 left context
+ # and 0 right context.
+ best_score = None
+ best_span_index = None
+ for (span_index, doc_span) in enumerate(doc_spans):
+ end = doc_span.start + doc_span.length - 1
+ if position < doc_span.start:
+ continue
+ if position > end:
+ continue
+ num_left_context = position - doc_span.start
+ num_right_context = end - position
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
+ if best_score is None or score > best_score:
+ best_score = score
+ best_span_index = span_index
+
+ return cur_span_index == best_span_index
+
+
+class FeatureWriter(object):
+ """Writes InputFeature to TF example file."""
+
+ def __init__(self, filename, is_training):
+ self.filename = filename
+ self.is_training = is_training
+ self.num_features = 0
+ self._writer = tf.io.TFRecordWriter(filename)
+
+ def process_feature(self, feature):
+ """Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
+ self.num_features += 1
+
+ def create_int_feature(values):
+ feature = tf.train.Feature(
+ int64_list=tf.train.Int64List(value=list(values)))
+ return feature
+
+ def create_float_feature(values):
+ f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
+ return f
+
+ features = collections.OrderedDict()
+ features["unique_ids"] = create_int_feature([feature.unique_id])
+ features["input_ids"] = create_int_feature(feature.input_ids)
+ features["input_mask"] = create_float_feature(feature.input_mask)
+ features["p_mask"] = create_float_feature(feature.p_mask)
+ features["segment_ids"] = create_int_feature(feature.segment_ids)
+
+ features["cls_index"] = create_int_feature([feature.cls_index])
+
+ if self.is_training:
+ features["start_positions"] = create_int_feature([feature.start_position])
+ features["end_positions"] = create_int_feature([feature.end_position])
+ impossible = 0
+ if feature.is_impossible:
+ impossible = 1
+ features["is_impossible"] = create_float_feature([impossible])
+
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+ self._writer.write(tf_example.SerializeToString())
+
+ def close(self):
+ self._writer.close()
+
+
+def create_eval_data(spm_basename,
+ sp_model,
+ eval_examples,
+ max_seq_length,
+ max_query_length,
+ doc_stride,
+ uncased,
+ output_dir=None):
+ """Creates evaluation tfrecords."""
+ eval_features = []
+ eval_writer = None
+ if output_dir:
+ eval_rec_file = os.path.join(
+ output_dir,
+ "{}.slen-{}.qlen-{}.eval.tf_record".format(spm_basename, max_seq_length,
+ max_query_length))
+ eval_feature_file = os.path.join(
+ output_dir,
+ "{}.slen-{}.qlen-{}.eval.features.pkl".format(spm_basename,
+ max_seq_length,
+ max_query_length))
+
+ eval_writer = FeatureWriter(filename=eval_rec_file, is_training=False)
+
+ def append_feature(feature):
+ eval_features.append(feature)
+ if eval_writer:
+ eval_writer.process_feature(feature)
+
+ convert_examples_to_features(
+ examples=eval_examples,
+ sp_model=sp_model,
+ max_seq_length=max_seq_length,
+ doc_stride=doc_stride,
+ max_query_length=max_query_length,
+ is_training=False,
+ output_fn=append_feature,
+ uncased=uncased)
+
+ if eval_writer:
+ eval_writer.close()
+ with tf.io.gfile.GFile(eval_feature_file, "wb") as fout:
+ pickle.dump(eval_features, fout)
+
+ return eval_features
diff --git a/modeling/official/legacy/xlnet/training_utils.py b/modeling/official/legacy/xlnet/training_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf19f0b2e152cd41b97aa4c85530a416397ff186
--- /dev/null
+++ b/modeling/official/legacy/xlnet/training_utils.py
@@ -0,0 +1,305 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""XLNet training utils."""
+
+import os
+import re
+from typing import Any, Callable, Dict, Optional, Text
+
+from absl import logging
+import tensorflow as tf, tf_keras
+
+from official.legacy.bert import model_training_utils
+from official.legacy.xlnet import data_utils
+
+# pytype: disable=attribute-error
+# pylint: disable=g-bare-generic,unused-import
+
+_MIN_SUMMARY_STEPS = 10
+
+
+def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
+ """Saves model to with provided checkpoint prefix."""
+
+ checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
+ saved_path = checkpoint.save(checkpoint_path)
+ logging.info("Saving model as TF checkpoint: %s", saved_path)
+ return
+
+
+def _float_metric_value(metric):
+ """Gets the value of a float-value keras metric."""
+ return metric.result().numpy().astype(float)
+
+
+def train(
+ strategy: tf.distribute.Strategy,
+ model_fn: Callable,
+ input_meta_data: Dict,
+ train_input_fn: Callable,
+ total_training_steps: int,
+ steps_per_loop: int,
+ optimizer: tf_keras.optimizers.Optimizer,
+ learning_rate_fn: tf_keras.optimizers.schedules.LearningRateSchedule,
+ eval_fn: Optional[Callable[[tf_keras.Model, int, tf.summary.SummaryWriter],
+ Any]] = None,
+ metric_fn: Optional[Callable[[], tf_keras.metrics.Metric]] = None,
+ init_checkpoint: Optional[Text] = None,
+ init_from_transformerxl: Optional[bool] = False,
+ model_dir: Optional[Text] = None,
+ save_steps: Optional[int] = None,
+ run_eagerly: Optional[bool] = False):
+ """Runs customized training.
+
+ Args:
+ strategy: Distribution strategy on which to run low level training loop.
+ model_fn: The function returns a keras.Model.
+ input_meta_data: A dictionary of params: `mem_len`, `lr_layer_decay_rate`,
+ `n_layer`, `batch_size_per_core` and `d_model`.
+ train_input_fn: Function returns a tf.data.Dataset used for training.
+ total_training_steps: Number of steps to train in total.
+ steps_per_loop: Number of steps per graph-mode loop. In order to reduce
+ communication in eager context, training logs are printed every
+ steps_per_loop.
+ optimizer: The optimizer for model.
+ learning_rate_fn: the learning rate schedule.
+ eval_fn: A callback of evaluation function, that takes a keras.Model,
+ current step and evaluation summary writer.
+ metric_fn: A metrics function returns a Keras Metric object to record
+ evaluation result using evaluation dataset or with training dataset
+ after every epoch.
+ init_checkpoint: Optional checkpoint to load to `sub_model` returned by
+ `model_fn`.
+ init_from_transformerxl: Whether to load to `transformerxl_model` of
+ `model_fn`.
+ model_dir: The directory of model (checkpoints, summaries).
+ save_steps: The frequency to save checkpoints. Every save_steps, we save a
+ model checkpoint. Model checkpoint will be saved and evaluation will be
+ conducted if evaluation dataset is provided.
+ run_eagerly: Whether to run training eagerly.
+
+ Returns:
+ Last training step logits if training happens, otherwise returns None.
+ Raises:
+ TypeError: if model directory is not specified.
+ """
+ required_arguments = [
+ train_input_fn, total_training_steps, steps_per_loop, optimizer,
+ learning_rate_fn, save_steps
+ ]
+ if [arg for arg in required_arguments if arg is None]:
+ raise ValueError("`train_input_fn`, `total_training_steps`, "
+ "`steps_per_loop`, `optimizer`, `save_steps` and "
+ "`learning_rate_fn` are required parameters.")
+ if not model_dir:
+ raise TypeError("Model directory must be specified.")
+ train_iterator = data_utils.get_input_iterator(train_input_fn, strategy)
+ if not tf.io.gfile.exists(model_dir):
+ tf.io.gfile.mkdir(model_dir)
+ # Create summary writers
+ summary_dir = os.path.join(model_dir, "summaries")
+ if not tf.io.gfile.exists(summary_dir):
+ tf.io.gfile.mkdir(summary_dir)
+ train_summary_writer = None
+ eval_summary_writer = None
+ if eval_fn:
+ eval_summary_writer = tf.summary.create_file_writer(
+ os.path.join(summary_dir, "eval"))
+ if steps_per_loop >= _MIN_SUMMARY_STEPS:
+ # Only writes summary when the stats are collected sufficiently over
+ # enough steps.
+ train_summary_writer = tf.summary.create_file_writer(
+ os.path.join(summary_dir, "train"))
+
+ with strategy.scope():
+ model = model_fn()
+
+ if init_checkpoint:
+ logging.info("restore from %s", init_checkpoint)
+ if init_from_transformerxl:
+ checkpoint = tf.train.Checkpoint(
+ transformer_xl=model.transformerxl_model)
+ else:
+ checkpoint = tf.train.Checkpoint(model=model)
+ checkpoint.restore(init_checkpoint)
+
+ model.optimizer = optimizer
+
+ if not hasattr(model, "optimizer"):
+ raise ValueError("User should set optimizer attribute to model.")
+
+ train_loss_metric = tf_keras.metrics.Mean("training_loss", dtype=tf.float32)
+ train_metric = None
+ if metric_fn:
+ train_metric = metric_fn()
+
+ def _replicated_step(inputs, mem=None):
+ """Replicated training step."""
+
+ inputs["mems"] = mem
+ with tf.GradientTape() as tape:
+ mem, logits = model(inputs, training=True)
+ loss = model.losses
+ train_loss_metric.update_state(loss)
+ if train_metric:
+ train_metric.update_state(inputs["label_ids"], logits)
+ scaled_loss = loss[0] * 1.0 / float(strategy.num_replicas_in_sync)
+
+ # Collects training variables.
+ tvars = model.trainable_variables
+ grads = tape.gradient(scaled_loss, tvars)
+ clipped, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)
+
+ if input_meta_data["lr_layer_decay_rate"] != 1.0:
+ n_layer = 0
+ for i in range(len(clipped)):
+ m = re.search(r"model/transformer/layer_(\d+?)/", tvars[i].name)
+ if not m:
+ continue
+ n_layer = max(n_layer, int(m.group(1)) + 1)
+
+ for i in range(len(clipped)):
+ for l in range(n_layer):
+ if "model/transformer/layer_{}/".format(l) in tvars[i].name:
+ abs_rate = input_meta_data["lr_layer_decay_rate"]**(
+ n_layer - 1 - l)
+ clipped[i] *= abs_rate
+ logging.info("Apply mult {:.4f} to layer-{} grad of {}".format(
+ abs_rate, l, tvars[i].name))
+ break
+
+ optimizer.apply_gradients(zip(clipped, tvars))
+ if input_meta_data["mem_len"] > 0:
+ return mem
+
+ def train_steps(iterator, steps):
+ """Performs distributed training steps in a loop.
+
+ Args:
+ iterator: the distributed iterator of training datasets.
+ steps: an tf.int32 integer tensor to specify number of steps to run
+ inside host training loop.
+
+ Raises:
+ ValueError: Any of the arguments or tensor shapes are invalid.
+
+ Returns:
+ logits: logits computed.
+ """
+ if not isinstance(steps, tf.Tensor):
+ raise ValueError("steps should be an Tensor. Python object may cause "
+ "retracing.")
+
+ def cache_fn():
+ """Initializes memory tensor used in XLNet pretraining."""
+ mems = []
+ if input_meta_data["mem_len"] > 0:
+ for _ in range(input_meta_data["n_layer"]):
+ zeros = tf.zeros([
+ input_meta_data["batch_size_per_core"],
+ input_meta_data["mem_len"],
+ input_meta_data["d_model"]
+ ],
+ dtype=tf.float32)
+ mems.append(zeros)
+ return mems
+
+ if input_meta_data["mem_len"] > 0:
+ mem = strategy.run(cache_fn)
+ for _ in tf.range(steps):
+ mem = strategy.run(
+ _replicated_step, args=(
+ next(iterator),
+ mem,
+ ))
+ else:
+ for _ in tf.range(steps):
+ strategy.run(_replicated_step, args=(next(iterator),))
+
+ if not run_eagerly:
+ train_steps = tf.function(train_steps)
+
+ logging.info("Start training...")
+ checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
+ latest_checkpoint_file = tf.train.latest_checkpoint(model_dir)
+ if latest_checkpoint_file:
+ logging.info("Checkpoint file %s found and restoring from checkpoint",
+ latest_checkpoint_file)
+ checkpoint.restore(latest_checkpoint_file)
+ logging.info("Loading from checkpoint file completed")
+
+ current_step = optimizer.iterations.numpy()
+ checkpoint_name = "xlnet_step_{step}.ckpt"
+
+ while current_step < total_training_steps:
+ train_loss_metric.reset_states()
+ if train_metric:
+ train_metric.reset_states()
+
+ steps = model_training_utils.steps_to_run(current_step, save_steps,
+ steps_per_loop)
+ train_steps(train_iterator, tf.convert_to_tensor(steps, dtype=tf.int32))
+ current_step += steps
+ train_loss = _float_metric_value(train_loss_metric)
+ log_stream = "Train step: %d/%d / lr = %.9f / loss = %.7f" % (
+ current_step, total_training_steps, learning_rate_fn(current_step),
+ train_loss)
+ if train_metric:
+ log_stream += " / %s = %f" % (train_metric.name,
+ _float_metric_value(train_metric))
+ logging.info(log_stream)
+ if train_summary_writer:
+ with train_summary_writer.as_default():
+ tf.summary.scalar(
+ "learning_rate",
+ learning_rate_fn(current_step),
+ step=current_step)
+ tf.summary.scalar(
+ train_loss_metric.name, train_loss, step=current_step)
+ if train_metric:
+ tf.summary.scalar(
+ train_metric.name,
+ _float_metric_value(train_metric),
+ step=current_step)
+ train_summary_writer.flush()
+ if model_dir and current_step % save_steps == 0:
+ _save_checkpoint(checkpoint, model_dir,
+ checkpoint_name.format(step=current_step))
+
+ if eval_fn and current_step % save_steps == 0:
+
+ logging.info("Running evaluation after step: %s.", current_step)
+
+ eval_fn(model, current_step, eval_summary_writer)
+ if model_dir:
+ _save_checkpoint(checkpoint, model_dir,
+ checkpoint_name.format(step=current_step))
+ if eval_fn:
+ logging.info("Running final evaluation after training is complete.")
+ eval_metric = eval_fn(model, current_step, eval_summary_writer)
+
+ training_summary = {
+ "total_training_steps": total_training_steps,
+ "train_loss": _float_metric_value(train_loss_metric),
+ }
+ if train_metric:
+ training_summary["last_train_metrics"] = _float_metric_value(train_metric)
+ if eval_fn:
+ # eval_metric is supposed to be a float.
+ training_summary["eval_metrics"] = eval_metric
+
+ model_training_utils.write_txt_summary(training_summary, summary_dir)
+
+ return model
diff --git a/modeling/official/legacy/xlnet/xlnet_config.py b/modeling/official/legacy/xlnet/xlnet_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..35b7e979af93fdb3cc9b65248cd3a6f9eecf74bf
--- /dev/null
+++ b/modeling/official/legacy/xlnet/xlnet_config.py
@@ -0,0 +1,179 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility functions used in XLNet model."""
+
+import json
+import os
+
+import tensorflow as tf, tf_keras
+
+
+def create_run_config(is_training, is_finetune, flags):
+ """Helper function for creating RunConfig."""
+ kwargs = dict(
+ is_training=is_training,
+ use_tpu=flags.use_tpu,
+ dropout=flags.dropout,
+ dropout_att=flags.dropout_att,
+ init_method=flags.init_method,
+ init_range=flags.init_range,
+ init_std=flags.init_std,
+ clamp_len=flags.clamp_len)
+
+ if not is_finetune:
+ kwargs.update(
+ dict(
+ mem_len=flags.mem_len,
+ reuse_len=flags.reuse_len,
+ bi_data=flags.bi_data,
+ clamp_len=flags.clamp_len,
+ same_length=flags.same_length))
+
+ return RunConfig(**kwargs)
+
+
+# TODO(hongkuny): refactor XLNetConfig and RunConfig.
+class XLNetConfig(object):
+ """Configs for XLNet model.
+
+ XLNetConfig contains hyperparameters that are specific to a model checkpoint;
+ i.e., these hyperparameters should be the same between
+ pretraining and finetuning.
+
+ The following hyperparameters are defined:
+ n_layer: int, the number of layers.
+ d_model: int, the hidden size.
+ n_head: int, the number of attention heads.
+ d_head: int, the dimension size of each attention head.
+ d_inner: int, the hidden size in feed-forward layers.
+ ff_activation: str, "relu" or "gelu".
+ untie_r: bool, whether to untie the biases in attention.
+ n_token: int, the vocab size.
+ """
+
+ def __init__(self, FLAGS=None, json_path=None, args_dict=None):
+ """Constructing an XLNetConfig.
+
+ One of FLAGS or json_path should be provided.
+
+ Args:
+ FLAGS: An FLAGS instance.
+ json_path: A path to a json config file.
+ args_dict: A dict for args.
+ """
+
+ assert FLAGS is not None or json_path is not None or args_dict is not None
+
+ self.keys = [
+ 'n_layer', 'd_model', 'n_head', 'd_head', 'd_inner', 'ff_activation',
+ 'untie_r', 'n_token'
+ ]
+
+ if FLAGS is not None:
+ self.init_from_flags(FLAGS)
+
+ if json_path is not None:
+ self.init_from_json(json_path)
+
+ if args_dict is not None:
+ self.init_from_dict(args_dict)
+
+ def init_from_dict(self, args_dict):
+ """Constructs a `BertConfig` from a Python dictionary of parameters."""
+ for key in self.keys:
+ setattr(self, key, args_dict[key])
+
+ def init_from_flags(self, flags):
+ for key in self.keys:
+ setattr(self, key, getattr(flags, key))
+
+ def init_from_json(self, json_path):
+ with tf.io.gfile.GFile(json_path) as f:
+ json_data = json.load(f)
+ self.init_from_dict(json_data)
+
+ def to_json(self, json_path):
+ """Save XLNetConfig to a json file."""
+ json_data = {}
+ for key in self.keys:
+ json_data[key] = getattr(self, key)
+
+ json_dir = os.path.dirname(json_path)
+ if not tf.io.gfile.exists(json_dir):
+ tf.io.gfile.makedirs(json_dir)
+ with tf.io.gfile.GFile(json_path, 'w') as f:
+ json.dump(json_data, f, indent=4, sort_keys=True)
+
+
+class RunConfig(object):
+ """Class of RunConfig.
+
+ RunConfig contains hyperparameters that could be different
+ between pretraining and finetuning.
+ These hyperparameters can also be changed from run to run.
+ We store them separately from XLNetConfig for flexibility.
+ """
+
+ def __init__(self,
+ is_training,
+ use_tpu,
+ dropout,
+ dropout_att,
+ init_method='normal',
+ init_range=0.1,
+ init_std=0.02,
+ mem_len=None,
+ reuse_len=None,
+ bi_data=False,
+ clamp_len=-1,
+ same_length=False,
+ use_cls_mask=True):
+ """Initializes RunConfig.
+
+ Args:
+ is_training: bool, whether in training mode.
+ use_tpu: bool, whether TPUs are used.
+ dropout: float, dropout rate.
+ dropout_att: float, dropout rate on attention probabilities.
+ init_method: str, the initialization scheme, either "normal" or "uniform".
+ init_range: float, initialize the parameters with a uniform distribution
+ in [-init_range, init_range]. Only effective when init="uniform".
+ init_std: float, initialize the parameters with a normal distribution with
+ mean 0 and stddev init_std. Only effective when init="normal".
+ mem_len: int, the number of tokens to cache.
+ reuse_len: int, the number of tokens in the currect batch to be cached and
+ reused in the future.
+ bi_data: bool, whether to use bidirectional input pipeline. Usually set to
+ True during pretraining and False during finetuning.
+ clamp_len: int, clamp all relative distances larger than clamp_len. -1
+ means no clamping.
+ same_length: bool, whether to use the same attention length for each
+ token.
+ use_cls_mask: bool, whether to introduce cls mask.
+ """
+
+ self.init_method = init_method
+ self.init_range = init_range
+ self.init_std = init_std
+ self.is_training = is_training
+ self.dropout = dropout
+ self.dropout_att = dropout_att
+ self.use_tpu = use_tpu
+ self.mem_len = mem_len
+ self.reuse_len = reuse_len
+ self.bi_data = bi_data
+ self.clamp_len = clamp_len
+ self.same_length = same_length
+ self.use_cls_mask = use_cls_mask
diff --git a/modeling/official/legacy/xlnet/xlnet_modeling.py b/modeling/official/legacy/xlnet/xlnet_modeling.py
new file mode 100644
index 0000000000000000000000000000000000000000..29c3f4f046358587c671ef6bec65cc9bcc9570bd
--- /dev/null
+++ b/modeling/official/legacy/xlnet/xlnet_modeling.py
@@ -0,0 +1,1322 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Keras layers of XLNet model in TF 2.0."""
+
+import copy
+import warnings
+
+import tensorflow as tf, tf_keras
+from official.legacy.xlnet import data_utils
+from official.nlp.modeling import networks
+
+
+def gelu(x):
+ return tf_keras.activations.gelu(x, approximate=True)
+
+
+def _get_initializer(flags):
+ """Get variable initializer."""
+ if flags.init_method == "uniform":
+ initializer = tf_keras.initializers.RandomUniform(
+ minval=-flags.init_range, maxval=flags.init_range)
+ elif flags.init_method == "normal":
+ initializer = tf_keras.initializers.RandomNormal(stddev=flags.init_std)
+ else:
+ raise ValueError("Initializer {} not supported".format(flags.init_method))
+ return initializer
+
+
+def rel_shift(x, klen=-1):
+ """Performs relative shift to form the relative attention score."""
+ x_size = tf.shape(x)
+
+ x = tf.reshape(x, [x_size[1], x_size[0], x_size[2], x_size[3]])
+ x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
+ x = tf.reshape(x, [x_size[0], x_size[1] - 1, x_size[2], x_size[3]])
+ x = tf.slice(x, [0, 0, 0, 0], [-1, klen, -1, -1])
+
+ return x
+
+
+def _create_mask(qlen, mlen, dtype=tf.float32, same_length=False):
+ """Creates attention mask when single-side context allowed only."""
+ attn_mask = tf.ones([qlen, qlen], dtype=dtype)
+ mask_u = tf.linalg.band_part(attn_mask, 0, -1)
+ mask_dia = tf.linalg.band_part(attn_mask, 0, 0)
+ attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
+ ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
+ if same_length:
+ mask_l = tf.linalg.band_part(attn_mask, -1, 0)
+ ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)
+
+ return ret
+
+
+def _cache_mem(curr_out, prev_mem, mem_len, reuse_len=None):
+ """cache hidden states into memory."""
+
+ if mem_len is None or mem_len == 0:
+ return None
+ else:
+ if reuse_len is not None and reuse_len > 0:
+ curr_out = curr_out[:reuse_len]
+
+ if prev_mem is None:
+ new_mem = curr_out[-mem_len:]
+ else:
+ new_mem = tf.concat([prev_mem, curr_out], 0)[-mem_len:]
+
+ return tf_keras.backend.stop_gradient(new_mem)
+
+
+def is_special_none_tensor(tensor):
+ """Checks if a tensor is a special None Tensor."""
+ return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
+
+
+@tf_keras.utils.register_keras_serializable(package="Text")
+class RelativePositionEncoding(tf_keras.layers.Layer):
+ """Creates a relative positional encoding.
+
+ This layer creates a relative positional encoding as described in
+ "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+ (https://arxiv.org/abs/1901.02860).
+
+ Rather than an absolute position embedding as in Transformer, this
+ formulation represents position as the relative distance between tokens using
+ sinusoidal positional embeddings.
+
+ Note: This layer is currently experimental.
+
+ Attributes:
+ hidden_size: The dimensionality of the input embeddings.
+ """
+
+ def __init__(self, hidden_size, **kwargs):
+ super(RelativePositionEncoding, self).__init__(**kwargs)
+ self._hidden_size = hidden_size
+ self._inv_freq = 1.0 / (10000.0**(
+ tf.range(0, self._hidden_size, 2.0) / self._hidden_size))
+
+ def call(self, pos_seq, batch_size=None):
+ """Implements call() for the layer.
+
+ Args:
+ pos_seq: A 1-D `Tensor`
+ batch_size: The optionally provided batch size that tiles the relative
+ positional encoding.
+
+ Returns:
+ The relative positional encoding of shape:
+ [len(pos_seq), batch_size, hidden_size] if batch_size is provided, else
+ [len(pos_seq), 1, hidden_size].
+ """
+ sinusoid_input = tf.einsum("i,d->id", pos_seq, self._inv_freq)
+ pos_emb = tf.concat([tf.sin(sinusoid_input), tf.cos(sinusoid_input)], -1)
+ pos_emb = pos_emb[:, None, :]
+
+ if batch_size is not None:
+ pos_emb = tf.tile(pos_emb, [1, batch_size, 1])
+ return pos_emb
+
+
+class RelativeAttention(tf_keras.layers.Layer):
+ """Core calculations for relative attention."""
+
+ def __init__(self, dropout_att, scale):
+ super(RelativeAttention, self).__init__()
+ self.scale = scale
+ self.dropout_att = dropout_att
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+
+ self.attention_probs_dropout = tf_keras.layers.Dropout(
+ rate=self.dropout_att)
+
+ super(RelativeAttention, self).build(unused_input_shapes)
+
+ def call(self, q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
+ r_w_bias, r_r_bias, r_s_bias, attn_mask):
+ """Implements call() for the layer."""
+
+ # content based attention score
+ ac = tf.einsum("ibnd,jbnd->ijbn", q_head + r_w_bias, k_head_h)
+
+ # position based attention score
+ bd = tf.einsum("ibnd,jbnd->ijbn", q_head + r_r_bias, k_head_r)
+ bd = rel_shift(bd, klen=tf.shape(ac)[1])
+
+ # segment-based attention score
+ if seg_mat is None:
+ ef = 0
+ else:
+ ef = tf.einsum("ibnd,snd->isbn", q_head + r_s_bias, seg_embed)
+ tgt_shape = tf.shape(bd)
+ ef = tf.where(
+ tf.broadcast_to(tf.expand_dims(seg_mat, 3), tgt_shape),
+ tf.broadcast_to(ef[:, 1:, :, :], tgt_shape),
+ tf.broadcast_to(ef[:, :1, :, :], tgt_shape))
+
+ # merges attention scores and performs masking
+ attn_score = (ac + bd + ef) * self.scale
+ if attn_mask is not None:
+ attn_score = attn_score - 1e30 * attn_mask
+
+ # attention probability
+ attn_prob = tf.nn.softmax(attn_score, 1)
+ attn_prob = self.attention_probs_dropout(attn_prob)
+
+ # attention output
+ attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, v_head_h)
+
+ return attn_vec
+
+
+class PositionwiseFF(tf_keras.layers.Layer):
+ """Positionwise feed-forward layer."""
+
+ def __init__(self, d_model, d_inner, dropout, kernel_initializer,
+ activation_type, **kwargs):
+ super(PositionwiseFF, self).__init__(**kwargs)
+ self.d_model = d_model
+ self.d_inner = d_inner
+ self.dropout = dropout
+ self.activation_type = activation_type
+ self.kernel_initializer = kernel_initializer
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ if self.activation_type == "relu":
+ activation = tf.nn.relu
+ elif self.activation_type == "gelu":
+ activation = gelu
+ else:
+ raise (ValueError("Unsupported activation type {}".format(
+ self.activation_type)))
+ self.inner_projection_layer = (
+ tf_keras.layers.Dense(
+ units=self.d_inner,
+ activation=activation,
+ kernel_initializer=self.kernel_initializer,
+ name="layer_1"))
+ self.output_projection_layer = (
+ tf_keras.layers.Dense(
+ units=self.d_model,
+ kernel_initializer=self.kernel_initializer,
+ name="layer_2"))
+ self.output_dropout = tf_keras.layers.Dropout(
+ rate=self.dropout, name="drop_2")
+ self.output_layer_norm = (
+ tf_keras.layers.LayerNormalization(
+ name="LayerNorm", axis=-1, epsilon=1e-12))
+ super(PositionwiseFF, self).build(unused_input_shapes)
+
+ def call(self, inp):
+ """Implements call() for the layer."""
+
+ output = self.inner_projection_layer(inp)
+ output = self.output_projection_layer(output)
+ output = self.output_dropout(output)
+ output = self.output_layer_norm(output + inp)
+ return output
+
+
+class EmbeddingLookup(tf_keras.layers.Layer):
+ """Looks up words embeddings for id tensor."""
+
+ def __init__(self, n_token, d_embed, initializer, **kwargs):
+ super(EmbeddingLookup, self).__init__(**kwargs)
+ self.n_token = n_token
+ self.d_embed = d_embed
+ self.initializer = initializer
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ self.lookup_table = self.add_weight(
+ "lookup_table",
+ shape=[self.n_token, self.d_embed],
+ initializer=self.initializer,
+ dtype=self.dtype)
+
+ super(EmbeddingLookup, self).build(unused_input_shapes)
+
+ def call(self, inputs):
+ return tf.nn.embedding_lookup(self.lookup_table, inputs)
+
+
+class RelativeMultiheadAttention(tf_keras.layers.Layer):
+ """Multi-head attention with relative embedding."""
+
+ def __init__(self, d_model, n_head, d_head, dropout, dropout_att,
+ kernel_initializer, **kwargs):
+ super(RelativeMultiheadAttention, self).__init__(**kwargs)
+ self.d_model = d_model
+ self.n_head = n_head
+ self.d_head = d_head
+ self.dropout = dropout
+ self.dropout_att = dropout_att
+ self.initializer = kernel_initializer
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ self.scale = 1.0 / (self.d_head**0.5)
+
+ self.output_layer_norm = tf_keras.layers.LayerNormalization(
+ name="LayerNorm", axis=-1, epsilon=1e-12)
+
+ self.kh_projection_layer = self.add_weight(
+ "k/kernel",
+ shape=[self.d_model, self.n_head, self.d_head],
+ initializer=self.initializer)
+ self.vh_projection_layer = self.add_weight(
+ "v/kernel",
+ shape=[self.d_model, self.n_head, self.d_head],
+ initializer=self.initializer)
+ self.kr_projection_layer = self.add_weight(
+ "r/kernel",
+ shape=[self.d_model, self.n_head, self.d_head],
+ initializer=self.initializer)
+ self.qh_projection_layer = self.add_weight(
+ "q/kernel",
+ shape=[self.d_model, self.n_head, self.d_head],
+ initializer=self.initializer)
+
+ self.relative_attention_layer = RelativeAttention(
+ dropout_att=self.dropout_att, scale=self.scale)
+
+ self.proj_o = self.add_weight(
+ "o/kernel",
+ shape=[self.d_model, self.n_head, self.d_head],
+ initializer=self.initializer)
+
+ self.attention_dropout = tf_keras.layers.Dropout(rate=self.dropout)
+
+ super(RelativeMultiheadAttention, self).build(unused_input_shapes)
+
+ def call(self, h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
+ attn_mask_h, attn_mask_g, mems, target_mapping):
+ """Implements call() for the layer."""
+
+ if mems is not None and mems.shape.ndims > 1:
+ cat = tf.concat([mems, h], 0)
+ else:
+ cat = h
+
+ # content heads
+ q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.qh_projection_layer)
+ k_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.kh_projection_layer)
+ v_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.vh_projection_layer)
+
+ # positional heads
+ k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.kr_projection_layer)
+
+ # core attention ops
+ attn_vec_h = self.relative_attention_layer(q_head_h, k_head_h, v_head_h,
+ k_head_r, seg_embed, seg_mat,
+ r_w_bias, r_r_bias, r_s_bias,
+ attn_mask_h)
+
+ # post processing
+ output_h = tf.einsum("ibnd,hnd->ibh", attn_vec_h, self.proj_o)
+ output_h = self.attention_dropout(output_h)
+ output_h = self.output_layer_norm(output_h + h)
+
+ output_g = None
+ if g is not None: # enable two-stream attention
+ # g-stream
+ q_head_g = tf.einsum("ibh,hnd->ibnd", g, self.qh_projection_layer)
+ if target_mapping is not None:
+ q_head_g = tf.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
+ attn_vec_g = self.relative_attention_layer(q_head_g, k_head_h, v_head_h,
+ k_head_r, seg_embed, seg_mat,
+ r_w_bias, r_r_bias, r_s_bias,
+ attn_mask_g)
+ attn_vec_g = tf.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
+
+ else:
+ attn_vec_g = self.relative_attention_layer(q_head_g, k_head_h, v_head_h,
+ k_head_r, seg_embed, seg_mat,
+ r_w_bias, r_r_bias, r_s_bias,
+ attn_mask_g)
+
+ # post processing
+ output_g = tf.einsum("ibnd,hnd->ibh", attn_vec_g, self.proj_o)
+ output_g = self.attention_dropout(output_g)
+ output_g = self.output_layer_norm(output_g + g)
+
+ return (output_h, output_g)
+
+
+class TransformerXLModel(tf_keras.layers.Layer):
+ """Defines a Transformer-XL computation graph with additional support for XLNet."""
+
+ def __init__(self,
+ n_token,
+ n_layer,
+ d_model,
+ n_head,
+ d_head,
+ d_inner,
+ dropout,
+ dropout_att,
+ attn_type,
+ bi_data,
+ is_training,
+ initializer,
+ mem_len=None,
+ same_length=False,
+ clamp_len=-1,
+ untie_r=False,
+ use_tpu=True,
+ reuse_len=None,
+ ff_activation="relu",
+ use_cls_mask=False,
+ **kwargs):
+ """Initializes TransformerXLModel.
+
+ Args:
+ n_token: int, the number of tokens in vocabulary.
+ n_layer: int, the number of layers.
+ d_model: int, the hidden size.
+ n_head: int, the number of attention heads.
+ d_head: int, the dimension size of each attention head.
+ d_inner: int, the hidden size in feed-forward layers.
+ dropout: float, dropout rate.
+ dropout_att: float, dropout rate on attention probabilities.
+ attn_type: str, "uni" or "bi".
+ bi_data: bool, whether to use bidirectional input pipeline. Usually set to
+ True during pretraining and False during finetuning.
+ is_training: bool, whether in training mode.
+ initializer: A tf initializer.
+ mem_len: int, the number of tokens to cache.
+ same_length: bool, whether to use the same attention length for each
+ token.
+ clamp_len: int, clamp all relative distances larger than clamp_len. -1
+ means no clamping.
+ untie_r: bool, whether to untie the biases in attention.
+ use_tpu: bool, whether TPUs are used.
+ reuse_len: int, the number of tokens in the currect batch to be cached and
+ reused in the future.
+ ff_activation: str, "relu" or "gelu".
+ use_cls_mask: bool, whether to introduce cls mask.
+ **kwargs: Other parameters.
+ """
+
+ super(TransformerXLModel, self).__init__(**kwargs)
+ warnings.warn(
+ "`TransformerXLModel` is deprecated, please use `XLNetBase` instead",
+ DeprecationWarning, stacklevel=2)
+
+ self.n_token = n_token
+ self.initializer = initializer
+ self.attn_type = attn_type
+ self.n_layer = n_layer
+ self.d_model = d_model
+ self.n_head = n_head
+ self.d_head = d_head
+ self.d_inner = d_inner
+ self.ff_activation = ff_activation
+ self.untie_r = untie_r
+ self.use_tpu = use_tpu
+ self.dropout = dropout
+ self.dropout_att = dropout_att
+
+ self.mem_len = mem_len
+ self.reuse_len = reuse_len
+ self.bi_data = bi_data
+ self.clamp_len = clamp_len
+ self.same_length = same_length
+ self.use_cls_mask = use_cls_mask
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ self.tf_float = tf.float32
+
+ self.embedding_lookup = EmbeddingLookup(
+ n_token=self.n_token,
+ d_embed=self.d_model,
+ initializer=self.initializer,
+ dtype=self.tf_float,
+ name="word_embedding")
+
+ self.h_dropout = tf_keras.layers.Dropout(rate=self.dropout)
+ self.g_dropout = tf_keras.layers.Dropout(rate=self.dropout)
+
+ if self.untie_r:
+ self.r_w_bias = (
+ self.add_weight(
+ "r_w_bias",
+ shape=[self.n_layer, self.n_head, self.d_head],
+ dtype=self.tf_float,
+ initializer=self.initializer))
+ self.r_r_bias = (
+ self.add_weight(
+ "r_r_bias",
+ shape=[self.n_layer, self.n_head, self.d_head],
+ dtype=self.tf_float,
+ initializer=self.initializer))
+ self.r_s_bias = (
+ self.add_weight(
+ "r_s_bias",
+ shape=[self.n_layer, self.n_head, self.d_head],
+ dtype=self.tf_float,
+ initializer=self.initializer))
+ else:
+ self.r_w_bias = (
+ self.add_weight(
+ "r_w_bias",
+ shape=[self.n_head, self.d_head],
+ dtype=self.tf_float,
+ initializer=self.initializer))
+ self.r_r_bias = (
+ self.add_weight(
+ "r_r_bias",
+ shape=[self.n_head, self.d_head],
+ dtype=self.tf_float,
+ initializer=self.initializer))
+ self.r_s_bias = (
+ self.add_weight(
+ "r_s_bias", [self.n_head, self.d_head],
+ dtype=self.tf_float,
+ initializer=self.initializer))
+
+ self.seg_embed = self.add_weight(
+ "seg_embed", [self.n_layer, 2, self.n_head, self.d_head],
+ dtype=self.tf_float,
+ initializer=self.initializer)
+
+ self.mask_emb = self.add_weight(
+ "mask_emb/mask_emb", shape=[1, 1, self.d_model], dtype=self.tf_float)
+
+ self.emb_dropout = tf_keras.layers.Dropout(rate=self.dropout)
+ self.fwd_position_embedding = RelativePositionEncoding(self.d_model)
+ self.bwd_position_embedding = RelativePositionEncoding(self.d_model)
+
+ self.rel_multihead_layers = []
+ self.h_positionwise_ffn_layers = []
+ for i in range(self.n_layer):
+ self.rel_multihead_layers.append(
+ RelativeMultiheadAttention(
+ d_model=self.d_model,
+ dropout=self.dropout,
+ n_head=self.n_head,
+ d_head=self.d_head,
+ dropout_att=self.dropout_att,
+ kernel_initializer=self.initializer,
+ name="layer_%d/rel_attn" % (i)))
+ self.h_positionwise_ffn_layers.append(
+ PositionwiseFF(
+ d_model=self.d_model,
+ d_inner=self.d_inner,
+ dropout=self.dropout,
+ kernel_initializer=self.initializer,
+ activation_type=self.ff_activation,
+ name="layer_%d/ff" % (i)))
+
+ self.output_dropout = tf_keras.layers.Dropout(rate=self.dropout)
+
+ super(TransformerXLModel, self).build(unused_input_shapes)
+
+ def __call__(self,
+ inp_k,
+ seg_id=None,
+ input_mask=None,
+ mems=None,
+ perm_mask=None,
+ target_mapping=None,
+ inp_q=None,
+ **kwargs):
+ # Uses dict to feed inputs into call() in order to keep mems as a python
+ # list.
+ inputs = {
+ "inp_k": inp_k,
+ "seg_id": seg_id,
+ "input_mask": input_mask,
+ "mems": mems,
+ "perm_mask": perm_mask,
+ "target_mapping": target_mapping,
+ "inp_q": inp_q
+ }
+ return super(TransformerXLModel, self).__call__(inputs, **kwargs)
+
+ def call(self, inputs):
+ """Implements call() for the layer."""
+ inp_k = inputs["inp_k"]
+ seg_id = inputs["seg_id"]
+ input_mask = inputs["input_mask"]
+ mems = inputs["mems"]
+ perm_mask = inputs["perm_mask"]
+ target_mapping = inputs["target_mapping"]
+ inp_q = inputs["inp_q"]
+
+ new_mems = []
+
+ bsz = tf.shape(inp_k)[1]
+
+ qlen = inp_k.shape.as_list()[0]
+
+ mlen = mems[0].shape.as_list()[0] if mems is not None else 0
+ klen = mlen + qlen
+
+ ##### Attention mask
+ # causal attention mask
+ if self.attn_type == "uni":
+ attn_mask = _create_mask(qlen, mlen, self.tf_float, self.same_length)
+ # pylint: enable=protected-access
+ attn_mask = attn_mask[:, :, None, None]
+ elif self.attn_type == "bi":
+ attn_mask = None
+ else:
+ raise ValueError("Unsupported attention type: {}".format(self.attn_type))
+
+ # data mask: input mask & perm mask
+ if input_mask is not None and perm_mask is not None:
+ data_mask = input_mask[None] + perm_mask
+
+ elif input_mask is not None and perm_mask is None:
+ data_mask = input_mask[None]
+ elif input_mask is None and perm_mask is not None:
+ data_mask = perm_mask
+ else:
+ data_mask = None
+
+ if data_mask is not None:
+ # all mems can be attended to
+ mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz],
+ dtype=self.tf_float)
+ data_mask = tf.concat([mems_mask, data_mask], 1)
+ if attn_mask is None:
+ attn_mask = data_mask[:, :, :, None]
+ else:
+ attn_mask += data_mask[:, :, :, None]
+
+ if attn_mask is not None:
+ attn_mask = tf.cast(attn_mask > 0, dtype=self.tf_float)
+
+ if attn_mask is not None:
+ non_tgt_mask = -tf.eye(qlen, dtype=self.tf_float)
+ non_tgt_mask = tf.concat(
+ [tf.zeros([qlen, mlen], dtype=self.tf_float), non_tgt_mask], axis=-1)
+ non_tgt_mask = tf.cast(
+ (attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=self.tf_float)
+ else:
+ non_tgt_mask = None
+
+ word_emb_k = self.embedding_lookup(inp_k)
+
+ if inp_q is not None:
+ if target_mapping is not None:
+ word_emb_q = tf.tile(self.mask_emb,
+ [tf.shape(target_mapping)[0], bsz, 1])
+ else:
+ inp_q_ext = inp_q[:, :, None]
+ word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
+
+ output_h = self.h_dropout(word_emb_k)
+ output_g = None
+ if inp_q is not None:
+ output_g = self.g_dropout(word_emb_q)
+
+ ##### Segment embedding
+ if seg_id is not None:
+
+ # Convert `seg_id` to one-hot `seg_mat`
+
+ mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
+
+ cat_id = tf.concat([mem_pad, seg_id], 0)
+
+ if self.use_cls_mask:
+ # `1` indicates not in the same segment [qlen x klen x bsz]
+ # seg_id: [qlen x bsz] & cat_id: [klen x bsz]
+ cls_mat = tf.logical_or(
+ tf.equal(seg_id, tf.constant([data_utils.SEG_ID_CLS]))[:, None],
+ tf.equal(cat_id, tf.constant([data_utils.SEG_ID_CLS]))[None, :])
+ seg_mat = tf.equal(seg_id[:, None], cat_id[None, :])
+ seg_mat = tf.logical_or(cls_mat, seg_mat)
+ else:
+ seg_mat = tf.logical_not(tf.equal(seg_id[:, None], cat_id[None, :]))
+ else:
+ seg_mat = None
+
+ dtype = self.tf_float
+ freq_seq = tf.range(0, self.d_model, 2.0)
+ if dtype is not None and dtype != tf.float32:
+ freq_seq = tf.cast(freq_seq, dtype=self.dtype)
+
+ if self.attn_type == "bi":
+ beg, end = klen, -qlen
+ elif self.attn_type == "uni":
+ beg, end = klen, -1
+ else:
+ raise ValueError("Unknown `attn_type` {}.".format(self.attn_type))
+
+ if self.bi_data:
+ fwd_pos_seq = tf.range(beg, end, -1.0)
+ bwd_pos_seq = tf.range(-beg, -end, 1.0)
+
+ if dtype is not None and dtype != tf.float32:
+ fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
+ bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype)
+
+ if self.clamp_len > 0:
+ fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len,
+ self.clamp_len)
+ bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len,
+ self.clamp_len)
+
+ if bsz is not None:
+ fwd_pos_emb = self.fwd_position_embedding(fwd_pos_seq, bsz // 2)
+ bwd_pos_emb = self.bwd_position_embedding(bwd_pos_seq, bsz // 2)
+ else:
+ fwd_pos_emb = self.fwd_position_embedding(fwd_pos_seq, None)
+ bwd_pos_emb = self.bwd_position_embedding(bwd_pos_seq, None)
+
+ pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
+ else:
+ fwd_pos_seq = tf.range(beg, end, -1.0)
+ if dtype is not None and dtype != tf.float32:
+ fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
+ if self.clamp_len > 0:
+ fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len,
+ self.lamp_len)
+
+ pos_emb = self.fwd_position_embedding(fwd_pos_seq, bsz)
+
+ pos_emb = self.emb_dropout(pos_emb)
+
+ if mems is None:
+ mems = [None] * self.n_layer
+ for i in range(self.n_layer):
+ # cache new mems
+ new_mems.append(
+ _cache_mem(output_h, mems[i], self.mem_len, self.reuse_len))
+ # pylint: enable=protected-access
+
+ # segment bias
+ if seg_id is None:
+ r_s_bias_i = None
+ seg_embed_i = None
+ else:
+ r_s_bias_i = self.r_s_bias if not self.untie_r else self.r_s_bias[i]
+ seg_embed_i = self.seg_embed[i]
+
+ ffn_layer = self.h_positionwise_ffn_layers[i]
+ attention_layer = self.rel_multihead_layers[i]
+ output_h, output_g = attention_layer(
+ h=output_h,
+ g=output_g,
+ r=pos_emb,
+ r_w_bias=self.r_w_bias if not self.untie_r else self.r_w_bias[i],
+ r_r_bias=self.r_r_bias if not self.untie_r else self.r_r_bias[i],
+ seg_mat=seg_mat,
+ r_s_bias=r_s_bias_i,
+ seg_embed=seg_embed_i,
+ attn_mask_h=non_tgt_mask,
+ attn_mask_g=attn_mask,
+ mems=mems[i],
+ target_mapping=target_mapping)
+ output_h = ffn_layer(output_h)
+ if output_g is not None:
+ output_g = ffn_layer(output_g)
+
+ if inp_q is not None:
+ output = output_g
+ else:
+ output = output_h
+
+ return output, new_mems, None
+
+
+class PretrainingXLNetModel(tf_keras.Model):
+ """XLNet keras model combined with pretraining LM loss layer.
+
+ See the original paper: https://arxiv.org/pdf/1906.08237.pdf
+
+ """
+
+ def __init__(self, use_proj, xlnet_config, run_config, use_legacy_mask=True,
+ **kwargs):
+ super(PretrainingXLNetModel, self).__init__(**kwargs)
+ self.run_config = run_config
+ self.initializer = _get_initializer(run_config)
+ self.xlnet_config = copy.deepcopy(xlnet_config)
+ self._use_legacy_mask = use_legacy_mask
+
+ self.xlnet_model = networks.XLNetBase(
+ vocab_size=self.xlnet_config.n_token,
+ initializer=self.initializer,
+ attention_type="bi",
+ num_layers=self.xlnet_config.n_layer,
+ hidden_size=self.xlnet_config.d_model,
+ num_attention_heads=self.xlnet_config.n_head,
+ head_size=self.xlnet_config.d_head,
+ inner_size=self.xlnet_config.d_inner,
+ two_stream=True,
+ tie_attention_biases=not self.xlnet_config.untie_r,
+ inner_activation=self.xlnet_config.ff_activation,
+ dropout_rate=self.run_config.dropout,
+ attention_dropout_rate=self.run_config.dropout_att,
+ memory_length=self.run_config.mem_len,
+ reuse_length=self.run_config.reuse_len,
+ bi_data=self.run_config.bi_data,
+ clamp_length=self.run_config.clamp_len,
+ use_cls_mask=self.run_config.use_cls_mask,
+ name="xlnet_model")
+
+ self.lmloss_layer = LMLossLayer(
+ vocab_size=self.xlnet_config.n_token,
+ hidden_size=self.xlnet_config.d_model,
+ initializer=self.initializer,
+ tie_weight=True,
+ bi_data=self.run_config.bi_data,
+ use_one_hot=self.run_config.use_tpu,
+ use_proj=use_proj,
+ name="lm_loss")
+
+ def call(self, features):
+ """Implements call() for the layer."""
+
+ input_ids = features["input_ids"]
+ masked_tokens = features["input_q"]
+ seg_ids = features["seg_id"]
+ if self._use_legacy_mask:
+ # Legacy input mask assumes `real` values are 0 and `padding`
+ # values are 1.
+ perm_mask = 1 - features["perm_mask"]
+ else:
+ perm_mask = features["perm_mask"]
+ target_mapping = features["target_mapping"]
+
+ # target for LM loss
+ target = features["target"]
+
+ # target mask for LM loss
+ tgt_mask = features["target_mask"]
+
+ mems = features.get("mems", None)
+
+ model_output, self.new_mems = self.xlnet_model(
+ input_ids=input_ids,
+ segment_ids=seg_ids,
+ input_mask=None,
+ state=mems,
+ permutation_mask=perm_mask,
+ target_mapping=target_mapping,
+ masked_tokens=masked_tokens)
+ lm_loss, _ = self.lmloss_layer(
+ hidden=model_output,
+ target=target,
+ lookup_table=self.xlnet_model.get_embedding_lookup_table(),
+ target_mask=tgt_mask)
+ self.add_loss(lm_loss)
+ return self.new_mems, model_output
+
+
+class ClassificationXLNetModel(tf_keras.Model):
+ """XLNet keras model combined with classification loss layer.
+
+ See the original paper: https://arxiv.org/pdf/1906.08237.pdf
+
+ """
+
+ def __init__(self, xlnet_config, run_config, n_class, summary_type,
+ use_legacy_mask=True, **kwargs):
+ super(ClassificationXLNetModel, self).__init__(**kwargs)
+ warnings.warn(
+ "`ClassificationXLNetModel` is deprecated, please use `XLNetClassifier`"
+ "instead.", DeprecationWarning, stacklevel=2)
+ self.run_config = run_config
+ self.initializer = _get_initializer(run_config)
+ self.xlnet_config = copy.deepcopy(xlnet_config)
+ self._use_legacy_mask = use_legacy_mask
+
+ self.xlnet_model = networks.XLNetBase(
+ vocab_size=self.xlnet_config.n_token,
+ initializer=self.initializer,
+ attention_type="bi",
+ num_layers=self.xlnet_config.n_layer,
+ hidden_size=self.xlnet_config.d_model,
+ num_attention_heads=self.xlnet_config.n_head,
+ head_size=self.xlnet_config.d_head,
+ inner_size=self.xlnet_config.d_inner,
+ two_stream=False,
+ tie_attention_biases=not self.xlnet_config.untie_r,
+ inner_activation=self.xlnet_config.ff_activation,
+ dropout_rate=self.run_config.dropout,
+ attention_dropout_rate=self.run_config.dropout_att,
+ memory_length=self.run_config.mem_len,
+ reuse_length=self.run_config.reuse_len,
+ bi_data=self.run_config.bi_data,
+ clamp_length=self.run_config.clamp_len,
+ use_cls_mask=False,
+ name="xlnet_model")
+
+ self.summarization_layer = Summarization(
+ hidden_size=self.xlnet_config.d_model,
+ num_attention_heads=self.xlnet_config.n_head,
+ head_size=self.xlnet_config.d_head,
+ dropout_rate=self.run_config.dropout,
+ attention_dropout_rate=self.run_config.dropout_att,
+ initializer=self.initializer,
+ use_proj=True,
+ summary_type=summary_type,
+ name="sequence_summary")
+
+ self.cl_loss_layer = ClassificationLossLayer(
+ n_class=n_class, initializer=self.initializer, name="classification")
+
+ def call(self, features):
+ """Implements call() for the layer."""
+ batch_size_per_core = tf.shape(features["input_ids"])[0]
+
+ input_ids = features["input_ids"]
+ segment_ids = features["segment_ids"]
+ if self._use_legacy_mask:
+ # Legacy input mask assumes `real` values are 0 and `padding`
+ # values are 1.
+ input_mask = 1 - features["input_mask"]
+ else:
+ input_mask = features["input_mask"]
+
+ label = tf.reshape(features["label_ids"], [batch_size_per_core])
+
+ mems = features.get("mems", None)
+
+ attention_output, new_mems = (
+ self.xlnet_model(input_ids, segment_ids, input_mask, mems))
+
+ summary = self.summarization_layer(attention_output)
+ per_example_loss, logits = self.cl_loss_layer(hidden=summary, labels=label)
+ self.add_loss(tf_keras.backend.mean(per_example_loss))
+ return new_mems, logits
+
+
+class LMLossLayer(tf_keras.layers.Layer):
+ """Layer computing cross entropy loss for language modeling."""
+
+ def __init__(self,
+ vocab_size,
+ hidden_size,
+ initializer,
+ tie_weight=False,
+ bi_data=True,
+ use_one_hot=False,
+ use_proj=False,
+ **kwargs):
+ """Constructs LMLoss layer.
+
+ Args:
+ vocab_size: Number of tokens in vocabulary.
+ hidden_size: The dimension of model hidden state.
+ initializer: Initializer used for parameters.
+ tie_weight: Whether to share weights between embedding lookup layer and
+ next-token prediction layer.
+ bi_data: Whether to use bidirectional input pipeline. Usually set to True
+ during pretraining and False during finetuning.
+ use_one_hot: bool, whether to use one hot encodings. This should be used
+ when TPUs are used.
+ use_proj: bool, whether to add a projection layer before LM prediction.
+ **kwargs: Other parameters.
+ """
+ super(LMLossLayer, self).__init__(**kwargs)
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.initializer = initializer
+
+ self.tie_weight = tie_weight
+ self.bi_data = bi_data
+ self.use_one_hot = use_one_hot
+ self.use_proj = use_proj
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ if self.use_proj:
+ self.proj_layer = tf_keras.layers.Dense(
+ units=self.hidden_size,
+ kernel_initializer=self.initializer,
+ activation=gelu,
+ name="lm_projection/dense")
+ self.proj_layer_norm = tf_keras.layers.LayerNormalization(
+ axis=-1, epsilon=1e-12, name="lm_projection/LayerNorm")
+ if not self.tie_weight:
+ self.softmax_w = self.add_weight(
+ "weight",
+ shape=[self.vocab_size, self.hidden_size],
+ initializer=self.initializer)
+
+ self.softmax_b = self.add_weight(
+ "bias", shape=[self.vocab_size], initializer=tf.zeros_initializer())
+
+ super(LMLossLayer, self).build(unused_input_shapes)
+
+ def call(self, hidden, target, lookup_table, target_mask):
+ """Implements call() for the layer."""
+ if self.use_proj:
+ hidden = self.proj_layer_norm(self.proj_layer(hidden))
+ if self.tie_weight:
+ logits = tf.einsum("ibd,nd->ibn", hidden, lookup_table) + self.softmax_b
+ else:
+ logits = tf.einsum("ibd,nd->ibn", hidden, self.softmax_w) + self.softmax_b
+
+ if self.use_one_hot:
+ one_hot_target = tf.one_hot(target, self.vocab_size, dtype=logits.dtype)
+ loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
+ else:
+ loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+ labels=target, logits=logits)
+
+ total_loss = tf.reduce_sum(loss * target_mask) / tf.reduce_sum(target_mask)
+
+ return total_loss, logits
+
+
+class Summarization(tf_keras.layers.Layer):
+ """The layer to pool the output from XLNet model into a vector."""
+
+ def __init__(self,
+ hidden_size,
+ num_attention_heads,
+ head_size,
+ dropout_rate,
+ attention_dropout_rate,
+ initializer,
+ use_proj=True,
+ summary_type="last",
+ **kwargs):
+ """Constructs Summarization layer.
+
+ Args:
+ hidden_size: int, the dimension of model hidden state.
+ num_attention_heads: int, the number of attention heads.
+ head_size: int, the dimension size of each attention head.
+ dropout_rate: float, dropout rate.
+ attention_dropout_rate: float, dropout rate on attention probabilities.
+ initializer: Initializer used for parameters.
+ use_proj: bool, whether to use projection layer for summarization.
+ summary_type: Method used to summarize a sequence into a compact vector.
+ **kwargs: Other parameters.
+ """
+ super(Summarization, self).__init__(**kwargs)
+ self.hidden_size = hidden_size
+ self.num_attention_heads = num_attention_heads
+ self.head_size = head_size
+ self.initializer = initializer
+
+ self.dropout_rate = dropout_rate
+ self.attention_dropout_rate = attention_dropout_rate
+ self.use_proj = use_proj
+ self.summary_type = summary_type
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ if self.use_proj:
+ self.proj_layer = tf_keras.layers.Dense(
+ units=self.hidden_size,
+ kernel_initializer=self.initializer,
+ activation=tf.nn.tanh,
+ name="summary")
+ self.dropout_layer = tf_keras.layers.Dropout(rate=self.dropout_rate)
+
+ super(Summarization, self).build(unused_input_shapes)
+
+ def call(self, inputs):
+ """Implements call() for the layer."""
+ if self.summary_type == "last":
+ summary = inputs[:, -1, :]
+ elif self.summary_type == "first":
+ summary = inputs[:, 0, :]
+ else:
+ raise ValueError("Invalid summary type provided: %s" % self.summary_type)
+ if self.use_proj:
+ summary = self.proj_layer(summary)
+ summary = self.dropout_layer(summary)
+ return summary
+
+
+class ClassificationLossLayer(tf_keras.layers.Layer):
+ """Layer computing cross entropy loss for classification task."""
+
+ def __init__(self, n_class, initializer, **kwargs):
+ """Constructs Summarization layer.
+
+ Args:
+ n_class: Number of tokens in vocabulary.
+ initializer: Initializer used for parameters.
+ **kwargs: Other parameters.
+ """
+ super(ClassificationLossLayer, self).__init__(**kwargs)
+
+ self.n_class = n_class
+ self.initializer = initializer
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ self.proj_layer = tf_keras.layers.Dense(
+ units=self.n_class, kernel_initializer=self.initializer, name="logit")
+
+ super(ClassificationLossLayer, self).build(unused_input_shapes)
+
+ def call(self, hidden, labels):
+ """Implements call() for the layer."""
+
+ logits = self.proj_layer(hidden)
+ one_hot_target = tf.one_hot(labels, self.n_class, dtype=hidden.dtype) # pytype: disable=attribute-error
+ loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
+
+ return loss, logits
+
+
+class QAXLNetModel(tf_keras.Model):
+ """XLNet keras model combined with question answering loss layer.
+
+ See the original paper: https://arxiv.org/pdf/1906.08237.pdf
+
+ """
+
+ def __init__(self, xlnet_config, run_config, start_n_top, end_n_top,
+ use_legacy_mask=True, **kwargs):
+ super(QAXLNetModel, self).__init__(**kwargs)
+ warnings.warn(
+ "`QAXLNetModel` is deprecated, please use `XLNetSpanLabeler` instead.",
+ DeprecationWarning, stacklevel=2)
+ self.run_config = run_config
+ self.initializer = _get_initializer(run_config)
+ self.xlnet_config = copy.deepcopy(xlnet_config)
+ self._use_legacy_mask = use_legacy_mask
+
+ self.xlnet_model = networks.XLNetBase(
+ vocab_size=self.xlnet_config.n_token,
+ initializer=self.initializer,
+ attention_type="bi",
+ num_layers=self.xlnet_config.n_layer,
+ hidden_size=self.xlnet_config.d_model,
+ num_attention_heads=self.xlnet_config.n_head,
+ head_size=self.xlnet_config.d_head,
+ inner_size=self.xlnet_config.d_inner,
+ tie_attention_biases=not self.xlnet_config.untie_r,
+ inner_activation=self.xlnet_config.ff_activation,
+ dropout_rate=self.run_config.dropout,
+ attention_dropout_rate=self.run_config.dropout_att,
+ two_stream=False,
+ memory_length=self.run_config.mem_len,
+ reuse_length=self.run_config.reuse_len,
+ bi_data=self.run_config.bi_data,
+ clamp_length=self.run_config.clamp_len,
+ use_cls_mask=False,
+ name="xlnet_model")
+
+ self.qa_loss_layer = QALossLayer(
+ hidden_size=self.xlnet_config.d_model,
+ start_n_top=start_n_top,
+ end_n_top=end_n_top,
+ initializer=self.initializer,
+ dropout_rate=self.run_config.dropout,
+ name="qa_loss_layer")
+
+ def call(self, features, training=False):
+ """Implements call() for the layer."""
+
+ input_ids = features["input_ids"]
+ segment_ids = features["segment_ids"]
+ if self._use_legacy_mask:
+ # Legacy input mask assumes `real` values are 0 and `padding`
+ # values are 1.
+ input_mask = 1 - features["input_mask"]
+ else:
+ input_mask = features["input_mask"]
+
+ cls_index = tf.reshape(features["cls_index"], [-1])
+ p_mask = features["p_mask"]
+
+ attention_output, new_mems = (
+ self.xlnet_model(input_ids, segment_ids, input_mask))
+
+ if training:
+ loss, logits = self.qa_loss_layer(
+ hidden=attention_output,
+ p_mask=p_mask,
+ cls_index=cls_index,
+ start_positions=features["start_positions"],
+ end_positions=features["end_positions"],
+ is_impossible=features["is_impossible"])
+ self.add_loss(loss)
+ return new_mems, logits
+ else:
+ results = self.qa_loss_layer(
+ hidden=attention_output, p_mask=p_mask, cls_index=cls_index)
+ return results
+
+
+class QALossLayer(tf_keras.layers.Layer):
+ """Layer computing position and regression loss for question answering task."""
+
+ def __init__(self, hidden_size, start_n_top, end_n_top, initializer,
+ dropout_rate, **kwargs):
+ """Constructs Summarization layer.
+
+ Args:
+ hidden_size: Int, the hidden size.
+ start_n_top: Beam size for span start.
+ end_n_top: Beam size for span end.
+ initializer: Initializer used for parameters.
+ dropout_rate: float, dropout rate.
+ **kwargs: Other parameters.
+ """
+ super(QALossLayer, self).__init__(**kwargs)
+ self.hidden_size = hidden_size
+ self.start_n_top = start_n_top
+ self.end_n_top = end_n_top
+ self.initializer = initializer
+ self.dropout_rate = dropout_rate
+
+ def build(self, unused_input_shapes):
+ """Implements build() for the layer."""
+ self.start_logits_proj_layer = tf_keras.layers.Dense(
+ units=1, kernel_initializer=self.initializer, name="start_logits/dense")
+ self.end_logits_proj_layer0 = tf_keras.layers.Dense(
+ units=self.hidden_size,
+ kernel_initializer=self.initializer,
+ activation=tf.nn.tanh,
+ name="end_logits/dense_0")
+ self.end_logits_proj_layer1 = tf_keras.layers.Dense(
+ units=1, kernel_initializer=self.initializer, name="end_logits/dense_1")
+ self.end_logits_layer_norm = tf_keras.layers.LayerNormalization(
+ axis=-1, epsilon=1e-12, name="end_logits/LayerNorm")
+ self.answer_class_proj_layer0 = tf_keras.layers.Dense(
+ units=self.hidden_size,
+ kernel_initializer=self.initializer,
+ activation=tf.nn.tanh,
+ name="answer_class/dense_0")
+ self.answer_class_proj_layer1 = tf_keras.layers.Dense(
+ units=1,
+ kernel_initializer=self.initializer,
+ use_bias=False,
+ name="answer_class/dense_1")
+ self.ans_feature_dropout = tf_keras.layers.Dropout(rate=self.dropout_rate)
+ super(QALossLayer, self).build(unused_input_shapes)
+
+ def __call__(self, hidden, p_mask, cls_index, **kwargs):
+ return super(QALossLayer, self).__call__(
+ (hidden, p_mask, cls_index, kwargs))
+
+ def call(self, inputs, training=False):
+ """Implements call() for the layer."""
+ hidden, p_mask, cls_index, kwargs = inputs
+ return_dict = {}
+ seq_len = tf.shape(hidden)[1]
+
+ hidden = tf.transpose(hidden, [1, 0, 2])
+ start_logits = self.start_logits_proj_layer(hidden)
+ start_logits = tf.transpose(tf.squeeze(start_logits, -1), [1, 0])
+ start_logits_masked = start_logits * (1 - p_mask) - 1e30 * p_mask
+ start_log_probs = tf.nn.log_softmax(start_logits_masked, -1)
+ if training:
+ start_positions = kwargs["start_positions"]
+ end_positions = kwargs["end_positions"]
+ is_impossible = kwargs["is_impossible"]
+ start_positions = tf.reshape(start_positions, [-1])
+ start_index = tf.one_hot(
+ start_positions, depth=seq_len, axis=-1, dtype=tf.float32)
+ start_features = tf.einsum("lbh,bl->bh", hidden, start_index)
+ start_features = tf.tile(start_features[None], [seq_len, 1, 1])
+ end_logits = self.end_logits_proj_layer0(
+ tf.concat([hidden, start_features], axis=-1))
+
+ end_logits = self.end_logits_layer_norm(end_logits)
+
+ end_logits = self.end_logits_proj_layer1(end_logits)
+ end_logits = tf.transpose(tf.squeeze(end_logits, -1), [1, 0])
+ end_logits_masked = end_logits * (1 - p_mask) - 1e30 * p_mask
+ end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
+ else:
+ # during inference, compute the end logits based on beam search
+
+ start_top_log_probs, start_top_index = tf.nn.top_k(
+ start_log_probs, k=self.start_n_top)
+ start_index = tf.one_hot(
+ start_top_index, depth=seq_len, axis=-1, dtype=tf.float32)
+ start_features = tf.einsum("lbh,bkl->bkh", hidden, start_index)
+ end_input = tf.tile(hidden[:, :, None], [1, 1, self.start_n_top, 1])
+ start_features = tf.tile(start_features[None], [seq_len, 1, 1, 1])
+ end_input = tf.concat([end_input, start_features], axis=-1)
+ end_logits = self.end_logits_proj_layer0(end_input)
+ end_logits = tf.reshape(end_logits, [seq_len, -1, self.hidden_size])
+ end_logits = self.end_logits_layer_norm(end_logits)
+
+ end_logits = tf.reshape(end_logits,
+ [seq_len, -1, self.start_n_top, self.hidden_size])
+
+ end_logits = self.end_logits_proj_layer1(end_logits)
+ end_logits = tf.reshape(end_logits, [seq_len, -1, self.start_n_top])
+ end_logits = tf.transpose(end_logits, [1, 2, 0])
+ end_logits_masked = end_logits * (
+ 1 - p_mask[:, None]) - 1e30 * p_mask[:, None]
+ end_log_probs = tf.nn.log_softmax(end_logits_masked, -1)
+ end_top_log_probs, end_top_index = tf.nn.top_k(
+ end_log_probs, k=self.end_n_top)
+ end_top_log_probs = tf.reshape(end_top_log_probs,
+ [-1, self.start_n_top * self.end_n_top])
+ end_top_index = tf.reshape(end_top_index,
+ [-1, self.start_n_top * self.end_n_top])
+
+ if training:
+ return_dict["start_log_probs"] = start_log_probs
+ return_dict["end_log_probs"] = end_log_probs
+ else:
+ return_dict["start_top_log_probs"] = start_top_log_probs
+ return_dict["start_top_index"] = start_top_index
+ return_dict["end_top_log_probs"] = end_top_log_probs
+ return_dict["end_top_index"] = end_top_index
+ # an additional layer to predict answerability
+
+ # get the representation of CLS
+ cls_index = tf.one_hot(cls_index, seq_len, axis=-1, dtype=tf.float32)
+ cls_feature = tf.einsum("lbh,bl->bh", hidden, cls_index)
+
+ # get the representation of START
+ start_p = tf.nn.softmax(start_logits_masked, axis=-1, name="softmax_start")
+ start_feature = tf.einsum("lbh,bl->bh", hidden, start_p)
+
+ ans_feature = tf.concat([start_feature, cls_feature], -1)
+ ans_feature = self.answer_class_proj_layer0(ans_feature)
+ ans_feature = self.ans_feature_dropout(ans_feature)
+ cls_logits = self.answer_class_proj_layer1(ans_feature)
+ cls_logits = tf.squeeze(cls_logits, -1)
+ return_dict["cls_logits"] = cls_logits
+
+ if not training:
+ return return_dict
+
+ def compute_loss(log_probs, positions):
+ one_hot_positions = tf.one_hot(positions, depth=seq_len, dtype=tf.float32)
+
+ loss = -tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
+ loss = tf.reduce_mean(loss)
+ return loss
+
+ start_loss = compute_loss(start_log_probs, start_positions)
+ end_loss = compute_loss(end_log_probs, end_positions)
+
+ total_loss = (start_loss + end_loss) * 0.5
+
+ is_impossible = tf.reshape(is_impossible, [-1])
+ regression_loss = tf.nn.sigmoid_cross_entropy_with_logits(
+ labels=is_impossible, logits=cls_logits)
+ regression_loss = tf.reduce_mean(regression_loss)
+
+ total_loss += regression_loss * 0.5
+ return total_loss, cls_logits
diff --git a/modeling/official/modeling/__init__.py b/modeling/official/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/modeling/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/modeling/activations/__init__.py b/modeling/official/modeling/activations/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fe76b2901daff05be6438455b6f13097e9d3111
--- /dev/null
+++ b/modeling/official/modeling/activations/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Activations package definition."""
+from official.modeling.activations.gelu import gelu
+from official.modeling.activations.mish import mish
+from official.modeling.activations.relu import relu6
+from official.modeling.activations.sigmoid import hard_sigmoid
+from official.modeling.activations.swish import hard_swish
+from official.modeling.activations.swish import identity
+from official.modeling.activations.swish import simple_swish
diff --git a/modeling/official/modeling/activations/gelu.py b/modeling/official/modeling/activations/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..50e39e40787f52ae62be9bf3372449b6847a1842
--- /dev/null
+++ b/modeling/official/modeling/activations/gelu.py
@@ -0,0 +1,32 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Gaussian error linear unit."""
+
+import tensorflow as tf, tf_keras
+
+
+@tf_keras.utils.register_keras_serializable(package='Text')
+def gelu(x):
+ """Gaussian Error Linear Unit.
+
+ This is a smoother version of the RELU.
+ Original paper: https://arxiv.org/abs/1606.08415
+ Args:
+ x: float Tensor to perform activation.
+
+ Returns:
+ `x` with the GELU activation applied.
+ """
+ return tf_keras.activations.gelu(x, approximate=True)
diff --git a/modeling/official/modeling/activations/gelu_test.py b/modeling/official/modeling/activations/gelu_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..aca4d8b5561261e0e1e60025483de224e326b613
--- /dev/null
+++ b/modeling/official/modeling/activations/gelu_test.py
@@ -0,0 +1,32 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the Gaussian error linear unit."""
+
+import tensorflow as tf, tf_keras
+
+from official.modeling import activations
+
+
+class GeluTest(tf.test.TestCase):
+
+ def test_gelu(self):
+ expected_data = [[0.14967535, 0., -0.10032465],
+ [-0.15880796, -0.04540223, 2.9963627]]
+ gelu_data = activations.gelu([[.25, 0, -.25], [-1, -2, 3]])
+ self.assertAllClose(expected_data, gelu_data)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/modeling/activations/mish.py b/modeling/official/modeling/activations/mish.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea8c72f8a14138b8739093a01d1aa116e720d9fd
--- /dev/null
+++ b/modeling/official/modeling/activations/mish.py
@@ -0,0 +1,36 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Self Regularized Non-Monotonic Activation Function."""
+
+import tensorflow as tf, tf_keras
+
+
+@tf_keras.utils.register_keras_serializable(package='Text')
+def mish(x) -> tf.Tensor:
+ """Mish activation function.
+
+ Mish: A Self Regularized Non-Monotonic Activation Function
+ https://arxiv.org/pdf/1908.08681.pdf
+
+ Mish(x) = x * tanh(ln(1+e^x))
+
+ Args:
+ x: A `Tensor` representing preactivation values.
+
+ Returns:
+ The activation value.
+ """
+ x = tf.convert_to_tensor(x)
+ return x * tf.tanh(tf.nn.softplus(x))
diff --git a/modeling/official/modeling/activations/mish_test.py b/modeling/official/modeling/activations/mish_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a51f270a2c74a5701a474a2ec84d4d003d5941f
--- /dev/null
+++ b/modeling/official/modeling/activations/mish_test.py
@@ -0,0 +1,30 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the customized Mish activation."""
+
+import tensorflow as tf, tf_keras
+
+from official.modeling import activations
+
+
+class MishTest(tf.test.TestCase):
+
+ def test_mish(self):
+ x = tf.constant([1.0, 0.0])
+ self.assertAllClose([0.86509839, 0.0], activations.mish(x))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/modeling/activations/relu.py b/modeling/official/modeling/activations/relu.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6de54b3e22c2d5b76abe1e04bc24a2bd77158c1
--- /dev/null
+++ b/modeling/official/modeling/activations/relu.py
@@ -0,0 +1,31 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Customized Relu activation."""
+
+import tensorflow as tf, tf_keras
+
+
+@tf_keras.utils.register_keras_serializable(package='Text')
+def relu6(features):
+ """Computes the Relu6 activation function.
+
+ Args:
+ features: A `Tensor` representing preactivation values.
+
+ Returns:
+ The activation value.
+ """
+ features = tf.convert_to_tensor(features)
+ return tf.nn.relu6(features)
diff --git a/modeling/official/modeling/activations/relu_test.py b/modeling/official/modeling/activations/relu_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b97a1cdcb9ce7eb6a7ed9de20233c76ae82574f2
--- /dev/null
+++ b/modeling/official/modeling/activations/relu_test.py
@@ -0,0 +1,32 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the customized Relu activation."""
+
+import tensorflow as tf, tf_keras
+
+from official.modeling import activations
+
+
+class CustomizedReluTest(tf.test.TestCase):
+
+ def test_relu6(self):
+ features = [[.25, 0, -.25], [-1, -2, 3]]
+ customized_relu6_data = activations.relu6(features)
+ relu6_data = tf.nn.relu6(features)
+ self.assertAllClose(customized_relu6_data, relu6_data)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/modeling/activations/sigmoid.py b/modeling/official/modeling/activations/sigmoid.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee2464ab2752e1cb1c94b9ab265746b94c2e4210
--- /dev/null
+++ b/modeling/official/modeling/activations/sigmoid.py
@@ -0,0 +1,31 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Customized Sigmoid activation."""
+
+import tensorflow as tf, tf_keras
+
+
+@tf_keras.utils.register_keras_serializable(package='Text')
+def hard_sigmoid(features):
+ """Computes the hard sigmoid activation function.
+
+ Args:
+ features: A `Tensor` representing preactivation values.
+
+ Returns:
+ The activation value.
+ """
+ features = tf.convert_to_tensor(features)
+ return tf.nn.relu6(features + tf.cast(3., features.dtype)) * 0.16667
diff --git a/modeling/official/modeling/activations/sigmoid_test.py b/modeling/official/modeling/activations/sigmoid_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..9813e17e9893fc271f0e07862dde9b525fc7eda4
--- /dev/null
+++ b/modeling/official/modeling/activations/sigmoid_test.py
@@ -0,0 +1,37 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the customized Sigmoid activation."""
+
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from official.modeling import activations
+
+
+class CustomizedSigmoidTest(tf.test.TestCase):
+
+ def _hard_sigmoid_nn(self, x):
+ x = np.float32(x)
+ return tf.nn.relu6(x + 3.) * 0.16667
+
+ def test_hard_sigmoid(self):
+ features = [[.25, 0, -.25], [-1, -2, 3]]
+ customized_hard_sigmoid_data = activations.hard_sigmoid(features)
+ sigmoid_data = self._hard_sigmoid_nn(features)
+ self.assertAllClose(customized_hard_sigmoid_data, sigmoid_data)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/modeling/activations/swish.py b/modeling/official/modeling/activations/swish.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc0f311bcb68e022cd824b2b112048727a5c5752
--- /dev/null
+++ b/modeling/official/modeling/activations/swish.py
@@ -0,0 +1,72 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Customized Swish activation."""
+
+import tensorflow as tf, tf_keras
+
+
+@tf_keras.utils.register_keras_serializable(package='Text')
+def simple_swish(features):
+ """Computes the Swish activation function.
+
+ The tf.nn.swish operation uses a custom gradient to reduce memory usage.
+ Since saving custom gradients in SavedModel is currently not supported, and
+ one would not be able to use an exported TF-Hub module for fine-tuning, we
+ provide this wrapper that can allow to select whether to use the native
+ TensorFlow swish operation, or whether to use a customized operation that
+ has uses default TensorFlow gradient computation.
+
+ Args:
+ features: A `Tensor` representing preactivation values.
+
+ Returns:
+ The activation value.
+ """
+ features = tf.convert_to_tensor(features)
+ return features * tf.nn.sigmoid(features)
+
+
+@tf_keras.utils.register_keras_serializable(package='Text')
+def hard_swish(features):
+ """Computes a hard version of the swish function.
+
+ This operation can be used to reduce computational cost and improve
+ quantization for edge devices.
+
+ Args:
+ features: A `Tensor` representing preactivation values.
+
+ Returns:
+ The activation value.
+ """
+ features = tf.convert_to_tensor(features)
+ fdtype = features.dtype
+ return features * tf.nn.relu6(features + tf.cast(3., fdtype)) * (1. / 6.)
+
+
+@tf_keras.utils.register_keras_serializable(package='Text')
+def identity(features):
+ """Computes the identity function.
+
+ Useful for helping in quantization.
+
+ Args:
+ features: A `Tensor` representing preactivation values.
+
+ Returns:
+ The activation value.
+ """
+ features = tf.convert_to_tensor(features)
+ return tf.identity(features)
diff --git a/modeling/official/modeling/activations/swish_test.py b/modeling/official/modeling/activations/swish_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1fbfc0e4179fe752e1134a273a7ea5bee873c91
--- /dev/null
+++ b/modeling/official/modeling/activations/swish_test.py
@@ -0,0 +1,42 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the customized Swish activation."""
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from official.modeling import activations
+
+
+class CustomizedSwishTest(tf.test.TestCase):
+
+ def _hard_swish_np(self, x):
+ x = np.float32(x)
+ return x * np.clip(x + 3, 0, 6) / 6
+
+ def test_simple_swish(self):
+ features = [[.25, 0, -.25], [-1, -2, 3]]
+ customized_swish_data = activations.simple_swish(features)
+ swish_data = tf.nn.swish(features)
+ self.assertAllClose(customized_swish_data, swish_data)
+
+ def test_hard_swish(self):
+ features = [[.25, 0, -.25], [-1, -2, 3]]
+ customized_swish_data = activations.hard_swish(features)
+ swish_data = self._hard_swish_np(features)
+ self.assertAllClose(customized_swish_data, swish_data)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/modeling/fast_training/experimental/tf2_utils_2x_wide.py b/modeling/official/modeling/fast_training/experimental/tf2_utils_2x_wide.py
new file mode 100644
index 0000000000000000000000000000000000000000..26355bf3d1613c269693d654ebaec89800af2f28
--- /dev/null
+++ b/modeling/official/modeling/fast_training/experimental/tf2_utils_2x_wide.py
@@ -0,0 +1,186 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Stacking model horizontally."""
+
+from absl import logging
+import numpy as np
+import tensorflow as tf, tf_keras
+
+
+def expand_vector(v: np.ndarray) -> np.ndarray:
+ """Expands a vector with batch dimensions.
+
+ Equivalent to expand_1_axis(v, epsilon=0.0, axis=-1)
+
+ Args:
+ v: A vector with shape [..., a].
+
+ Returns:
+ A vector with shape [..., 2 * a].
+ """
+ return np.repeat(v, 2, axis=-1)
+
+
+def expand_1_axis(w: np.ndarray,
+ epsilon: float,
+ axis: int) -> np.ndarray:
+ """Expands either the first dimension or the last dimension of w.
+
+ If `axis = 0`, the following constraint will be satisfied:
+ matmul(x, w) ==
+ matmul(expand_vector(x), expand_1_axis(w, epsilon=0.1, axis=0))
+
+ If `axis = -1`, the following constraint will be satisfied if `epsilon = 0.0`:
+ expand_vector(matmul(x, w)) ==
+ 2 * matmul(x, expand_1_axis(w, epsilon=0.0, axis=-1))
+
+ Args:
+ w: Numpy array of shape [a_0, a_1, ..., a_i-1, a_i].
+ epsilon: Symmetric Noise added to expanded tensor.
+ axis: Must be either 0 or -1.
+
+ Returns:
+ Expanded numpy array.
+ """
+ assert axis in (0, -1), (
+ "Only support expanding the first or the last dimension. "
+ "Got: {}".format(axis))
+
+ rank = len(w.shape)
+
+ d_w = np.random.normal(np.zeros_like(w), np.fabs(w) * epsilon, w.shape)
+ d_w = np.repeat(d_w, 2, axis=axis)
+
+ sign_flip = np.array([1, -1])
+ for _ in range(rank - 1):
+ sign_flip = np.expand_dims(sign_flip, axis=-1 if axis == 0 else 0)
+ sign_flip = np.tile(sign_flip,
+ [w.shape[0]] + [1] * (rank - 2) + [w.shape[-1]])
+
+ d_w *= sign_flip
+ w_expand = (np.repeat(w, 2, axis=axis) + d_w) / 2
+ return w_expand
+
+
+def expand_2_axes(w: np.ndarray,
+ epsilon: float) -> np.ndarray:
+ """Expands the first dimension and the last dimension of w.
+
+ The following constraint will be satisfied:
+ expand_vector(matmul(x, w)) == matmul(expand_vector(x), expand_2_axes(w))
+
+ Args:
+ w: Numpy array of shape [a_0, a_1, ..., a_i-1, a_i].
+ epsilon: Symmetric Noise added to expanded tensor.
+
+ Returns:
+ Expanded numpy array.
+ """
+ rank = len(w.shape)
+
+ d_w = np.random.normal(np.zeros_like(w), np.fabs(w) * epsilon, w.shape)
+ d_w = np.repeat(np.repeat(d_w, 2, axis=0), 2, axis=-1)
+
+ sign_flip = np.array([1, -1])
+ for _ in range(rank - 1):
+ sign_flip = np.expand_dims(sign_flip, axis=-1)
+ sign_flip = np.tile(sign_flip,
+ [w.shape[0]] + [1] * (rank - 2) + [w.shape[-1] * 2])
+ d_w *= sign_flip
+
+ w_expand = (np.repeat(np.repeat(w, 2, axis=0), 2, axis=-1) + d_w) / 2
+ return w_expand
+
+
+def var_to_var(var_from: tf.Variable,
+ var_to: tf.Variable,
+ epsilon: float):
+ """Expands a variable to another variable.
+
+ Assume the shape of `var_from` is (a, b, ..., y, z), the shape of `var_to`
+ can be (a, ..., z * 2), (a * 2, ..., z * 2), (a * 2, ..., z)
+
+ If the shape of `var_to` is (a, ..., 2 * z):
+ For any x, tf.matmul(x, var_to) ~= expand_vector(tf.matmul(x, var_from)) / 2
+ Not that there will be noise added to the left hand side, if epsilon != 0.
+ If the shape of `var_to` is (2 * a, ..., z):
+ For any x, tf.matmul(expand_vector(x), var_to) == tf.matmul(x, var_from)
+ If the shape of `var_to` is (2 * a, ..., 2 * z):
+ For any x, tf.matmul(expand_vector(x), var_to) ==
+ expand_vector(tf.matmul(expand_vector(x), var_from))
+
+ Args:
+ var_from: input variable to expand.
+ var_to: output variable.
+ epsilon: the noise ratio that will be added, when splitting `var_from`.
+ """
+ shape_from = var_from.shape
+ shape_to = var_to.shape
+
+ if shape_from == shape_to:
+ var_to.assign(var_from)
+
+ elif len(shape_from) == 1 and len(shape_to) == 1:
+ var_to.assign(expand_vector(var_from.numpy()))
+
+ elif shape_from[0] * 2 == shape_to[0] and shape_from[-1] == shape_to[-1]:
+ var_to.assign(expand_1_axis(var_from.numpy(), epsilon=epsilon, axis=0))
+
+ elif shape_from[0] == shape_to[0] and shape_from[-1] * 2 == shape_to[-1]:
+ var_to.assign(expand_1_axis(var_from.numpy(), epsilon=epsilon, axis=-1))
+
+ elif shape_from[0] * 2 == shape_to[0] and shape_from[-1] * 2 == shape_to[-1]:
+ var_to.assign(expand_2_axes(var_from.numpy(), epsilon=epsilon))
+
+ else:
+ raise ValueError("Shape not supported, {}, {}".format(shape_from, shape_to))
+
+
+def model_to_model_2x_wide(model_from: tf.Module,
+ model_to: tf.Module,
+ epsilon: float = 0.1):
+ """Expands a model to a wider version.
+
+ Also makes sure that the output of the model is not changed after expanding.
+ For example:
+ ```
+ model_narrow = tf_keras.Sequential()
+ model_narrow.add(tf_keras.Input(shape=(3,)))
+ model_narrow.add(tf_keras.layers.Dense(4))
+ model_narrow.add(tf_keras.layers.Dense(1))
+
+ model_wide = tf_keras.Sequential()
+ model_wide.add(tf_keras.Input(shape=(6,)))
+ model_wide.add(tf_keras.layers.Dense(8))
+ model_wide.add(tf_keras.layers.Dense(1))
+
+ model_to_model_2x_wide(model_narrow, model_wide)
+ assert model_narrow([[1, 2, 3]]) == model_wide([[1, 1, 2, 2, 3, 3]])
+ ```
+
+ We assume that `model_from` and `model_to` has the same architecture and only
+ widths of them differ.
+
+ Args:
+ model_from: input model to expand.
+ model_to: output model whose variables will be assigned expanded values
+ according to `model_from`.
+ epsilon: the noise ratio that will be added, when splitting `var_from`.
+ """
+ for w_from, w_to in zip(model_from.trainable_variables,
+ model_to.trainable_variables):
+ logging.info("expanding %s %s to %s %s",
+ w_from.name, w_from.shape, w_to.name, w_to.shape)
+ var_to_var(w_from, w_to, epsilon=epsilon)
diff --git a/modeling/official/modeling/fast_training/experimental/tf2_utils_2x_wide_test.py b/modeling/official/modeling/fast_training/experimental/tf2_utils_2x_wide_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..59ad118cc6bef753452f3bc7616696b06ec27999
--- /dev/null
+++ b/modeling/official/modeling/fast_training/experimental/tf2_utils_2x_wide_test.py
@@ -0,0 +1,101 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for tf2_utils_2x_wide."""
+
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from official.modeling.fast_training.experimental import tf2_utils_2x_wide
+
+
+class Tf2Utils2XWideTest(tf.test.TestCase):
+
+ def test_expand_vector(self):
+ x = np.array([1, 2])
+ self.assertAllClose(tf2_utils_2x_wide.expand_vector(x),
+ np.array([1, 1, 2, 2]))
+
+ def test_expand_matrix(self):
+ x = np.array([[1, 2], [3, 4]])
+ x = tf2_utils_2x_wide.expand_2_axes(x, epsilon=0.1)
+ self.assertAllClose(x[0, :] + x[1, :], np.array([1, 1, 2, 2]))
+ self.assertAllClose(x[2, :] + x[3, :], np.array([3, 3, 4, 4]))
+
+ def test_expand_matrix_axis_0(self):
+ x = np.array([[1, 2], [3, 4]])
+ x = tf2_utils_2x_wide.expand_1_axis(x, axis=0, epsilon=0.1)
+ self.assertAllClose(x[0, :] + x[1, :], np.array([1, 2]))
+ self.assertAllClose(x[2, :] + x[3, :], np.array([3, 4]))
+
+ def test_expand_matrix_axis_1(self):
+ x = np.array([[1, 2], [3, 4]])
+ x = tf2_utils_2x_wide.expand_1_axis(x, axis=-1, epsilon=0.1)
+ self.assertAllClose(x[:, 0] + x[:, 1], np.array([1, 3]))
+ self.assertAllClose(x[:, 2] + x[:, 3], np.array([2, 4]))
+
+ def test_expand_3d_tensor(self):
+ x0 = np.array([10, 11])
+ x1 = np.array([10, 10, 11, 11])
+ w0 = np.random.rand(2, 2)
+ w1 = tf2_utils_2x_wide.expand_2_axes(w0, epsilon=0.1)
+ o0 = np.matmul(x0, w0)
+ o1 = np.matmul(x1, w1)
+ self.assertAllClose(np.repeat(o0, 2, axis=-1), o1)
+
+ def test_expand_3d_tensor_axis_0(self):
+ x0 = np.array([10, 11])
+ x1 = np.array([10, 10, 11, 11])
+ w0 = np.random.rand(2, 2)
+ w1 = tf2_utils_2x_wide.expand_1_axis(w0, axis=0, epsilon=0.1)
+ o0 = np.matmul(x0, w0)
+ o1 = np.matmul(x1, w1)
+ self.assertAllClose(o0, o1)
+
+ def test_expand_3d_tensor_axis_2(self):
+ x = np.array([10, 11])
+ w0 = np.random.rand(2, 2)
+ w1 = tf2_utils_2x_wide.expand_1_axis(w0, axis=-1, epsilon=0.1)
+ o0 = np.matmul(x, w0)
+ o1 = np.matmul(x, w1)
+ self.assertAllClose(o0, np.sum(o1.reshape(2, 2), axis=-1))
+
+ def test_end_to_end(self):
+ """Covers expand_vector, expand_2_axes, and expand_1_axis."""
+ model_narrow = tf_keras.Sequential()
+ model_narrow.add(tf_keras.Input(shape=(3,)))
+ model_narrow.add(tf_keras.layers.Dense(4))
+ model_narrow.add(tf_keras.layers.Dense(4))
+ model_narrow.add(tf_keras.layers.Dense(1))
+
+ model_wide = tf_keras.Sequential()
+ model_wide.add(tf_keras.Input(shape=(6,)))
+ model_wide.add(tf_keras.layers.Dense(8))
+ model_wide.add(tf_keras.layers.Dense(8))
+ model_wide.add(tf_keras.layers.Dense(1))
+
+ x0 = np.array([[1, 2, 3]])
+ x1 = np.array([[1, 1, 2, 2, 3, 3]])
+
+ # Call model once to build variables first.
+ _, _ = model_narrow(x0), model_wide(x1)
+ tf2_utils_2x_wide.model_to_model_2x_wide(
+ model_narrow, model_wide, epsilon=0.2)
+
+ self.assertAllClose(model_narrow(x0), model_wide(x1),
+ rtol=1e-05, atol=1e-05)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/modeling/fast_training/progressive/policies.py b/modeling/official/modeling/fast_training/progressive/policies.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5b94368d8dd699af1d17d3a34972884ddec68a3
--- /dev/null
+++ b/modeling/official/modeling/fast_training/progressive/policies.py
@@ -0,0 +1,178 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Base ProgressivePolicy definition for progressive training.
+
+To write a progressive model, subclass ProgressivePolicy and implement its
+abstract methods to handle each training stage.
+"""
+
+import abc
+import dataclasses
+from typing import Any, Mapping
+from absl import logging
+import six
+import tensorflow as tf, tf_keras
+
+from official.common import streamz_counters
+from official.modeling.fast_training.progressive import utils
+from official.modeling.hyperparams import base_config
+
+
+@dataclasses.dataclass
+class ProgressiveConfig(base_config.Config):
+ pass
+
+
+@six.add_metaclass(abc.ABCMeta)
+class ProgressivePolicy:
+ """The APIs for handling progressive training stages.
+
+ Attributes:
+ cur_model: The model for the current progressive training stage.
+ cur_train_dataset: The train dataset function for the current stage.
+ cur_eval_dataset: The eval dataset function for the current stage.
+ cur_optimizer: The optimizer for the current stage.
+ cur_checkpoint_items: Items to be saved in and restored from checkpoints,
+ for the progressive trainer.
+ is_last_stage: Whether it is currently in the last stage.
+
+ Interfaces:
+ is_stage_advancing: Returns if progressive training is advancing to the
+ next stage.
+ update_pt_stage: Update progressive training stage.
+ """
+
+ def __init__(self):
+ """Initialize stage policy."""
+ self._cur_train_dataset = None
+ self._cur_eval_dataset = None
+ self._volatiles = utils.VolatileTrackable(optimizer=None, model=None)
+
+ stage_id = 0
+ self._stage_id = tf.Variable(
+ stage_id,
+ trainable=False,
+ dtype=tf.int64,
+ aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
+ shape=[])
+ self._volatiles.reassign_trackable(
+ optimizer=self.get_optimizer(stage_id),
+ model=self.get_model(stage_id, old_model=None)) # pytype: disable=wrong-arg-types # typed-keras
+
+ streamz_counters.progressive_policy_creation_counter.get_cell(
+ ).increase_by(1)
+
+ def compute_stage_id(self, global_step: int) -> int:
+ for stage_id in range(self.num_stages()):
+ global_step -= self.num_steps(stage_id)
+ if global_step < 0:
+ return stage_id
+ logging.error('Global step %d found no matching progressive stages. '
+ 'Default to the last stage.', global_step)
+ return self.num_stages() - 1
+
+ @abc.abstractmethod
+ def num_stages(self) -> int:
+ """Return the total number of progressive stages."""
+ pass
+
+ @abc.abstractmethod
+ def num_steps(self, stage_id: int) -> int:
+ """Return the total number of steps in this stage."""
+ pass
+
+ @abc.abstractmethod
+ def get_model(self,
+ stage_id: int,
+ old_model: tf_keras.Model = None) -> tf_keras.Model: # pytype: disable=annotation-type-mismatch # typed-keras
+ """Return model for this stage. For initialization, `old_model` = None."""
+ pass
+
+ @abc.abstractmethod
+ def get_optimizer(self, stage_id: int) -> tf_keras.optimizers.Optimizer:
+ """Return optimizer for this stage."""
+ pass
+
+ @abc.abstractmethod
+ def get_train_dataset(self, stage_id: int) -> tf.data.Dataset:
+ """Return training Dataset for this stage."""
+ pass
+
+ @abc.abstractmethod
+ def get_eval_dataset(self, stage_id: int) -> tf.data.Dataset:
+ """Return evaluation Dataset for this stage."""
+ pass
+
+ @property
+ def cur_model(self) -> tf_keras.Model:
+ return self._volatiles.model
+
+ @property
+ def cur_train_dataset(self) -> tf.data.Dataset:
+ if self._cur_train_dataset is None:
+ self._cur_train_dataset = self.get_train_dataset(self._stage_id.numpy())
+ return self._cur_train_dataset
+
+ @property
+ def cur_eval_dataset(self) -> tf.data.Dataset:
+ if self._cur_eval_dataset is None:
+ self._cur_eval_dataset = self.get_eval_dataset(self._stage_id.numpy())
+ return self._cur_eval_dataset
+
+ @property
+ def cur_optimizer(self) -> tf_keras.optimizers.Optimizer:
+ return self._volatiles.optimizer
+
+ @property
+ def is_last_stage(self) -> bool:
+ stage_id = self._stage_id.numpy()
+ return stage_id >= self.num_stages() - 1
+
+ @property
+ def cur_checkpoint_items(self) -> Mapping[str, Any]:
+ return dict(stage_id=self._stage_id, volatiles=self._volatiles)
+
+ def is_stage_advancing(self, global_step: int) -> bool:
+ old_stage_id = self._stage_id.numpy()
+ new_stage_id = self.compute_stage_id(global_step)
+ return old_stage_id != new_stage_id
+
+ def update_pt_stage(self, global_step: int, pass_old_model=True) -> None:
+ """Update progressive training internal status.
+
+ Call this after a training loop ends.
+
+ Args:
+ global_step: an integer scalar of the current global step.
+ pass_old_model: whether to pass the old_model to get_model() function.
+ This is set to False if the old_model is irrelevant (e.g, just a default
+ model from stage 0).
+ """
+ old_stage_id = self._stage_id.numpy()
+ new_stage_id = self.compute_stage_id(global_step)
+ logging.info('Switching stage from %d to %d', old_stage_id, new_stage_id)
+
+ # Update stage id.
+ self._stage_id.assign(new_stage_id)
+ # Update dataset function.
+ self._cur_train_dataset = None
+ self._cur_eval_dataset = None
+
+ # Update optimizer and model.
+ new_optimizer = self.get_optimizer(new_stage_id)
+ self._volatiles.reassign_trackable(optimizer=new_optimizer)
+ new_model = self.get_model(
+ new_stage_id, old_model=self.cur_model if pass_old_model else None)
+ self._volatiles.reassign_trackable(model=new_model)
diff --git a/modeling/official/modeling/fast_training/progressive/train.py b/modeling/official/modeling/fast_training/progressive/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..76732a8b0f8fae2a780ac8e2a75d15a485cf4ad0
--- /dev/null
+++ b/modeling/official/modeling/fast_training/progressive/train.py
@@ -0,0 +1,69 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TFM binary for the progressive trainer."""
+
+from absl import app
+from absl import flags
+import gin
+
+from official.common import distribute_utils
+# pylint: disable=unused-import
+from official.common import registry_imports
+# pylint: enable=unused-import
+from official.common import flags as tfm_flags
+from official.core import task_factory
+from official.core import train_utils
+from official.modeling import performance
+from official.modeling.fast_training.progressive import train_lib
+
+FLAGS = flags.FLAGS
+
+
+def main(_):
+ gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
+ params = train_utils.parse_configuration(FLAGS)
+ model_dir = FLAGS.model_dir
+ if 'train' in FLAGS.mode:
+ # Pure eval modes do not output yaml files. Otherwise continuous eval job
+ # may race against the train job for writing the same file.
+ train_utils.serialize_config(params, model_dir)
+
+ # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
+ # can have significant impact on model speeds by utilizing float16 in case of
+ # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
+ # dtype is float16
+ if params.runtime.mixed_precision_dtype:
+ performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
+ distribution_strategy = distribute_utils.get_distribution_strategy(
+ distribution_strategy=params.runtime.distribution_strategy,
+ all_reduce_alg=params.runtime.all_reduce_alg,
+ num_gpus=params.runtime.num_gpus,
+ tpu_address=params.runtime.tpu,
+ **params.runtime.model_parallelism())
+ with distribution_strategy.scope():
+ task = task_factory.get_task(params.task, logging_dir=model_dir)
+
+ train_lib.run_experiment(
+ distribution_strategy=distribution_strategy,
+ task=task,
+ mode=FLAGS.mode,
+ params=params,
+ model_dir=model_dir)
+
+ train_utils.save_gin_config(FLAGS.mode, model_dir)
+
+if __name__ == '__main__':
+ tfm_flags.define_flags()
+ app.run(main)
diff --git a/modeling/official/modeling/fast_training/progressive/train_lib.py b/modeling/official/modeling/fast_training/progressive/train_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..47c42549ff03f4a38f4293a12ecfa7a52091aee4
--- /dev/null
+++ b/modeling/official/modeling/fast_training/progressive/train_lib.py
@@ -0,0 +1,126 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TFM progressive training driver library.
+
+Compared to the common training driver, the only difference is that we use
+prog_trainer_lib.ProgressiveTrainer instead of the base trainer.
+"""
+
+# pytype: disable=attribute-error
+import os
+from typing import Any, Mapping, Tuple
+
+# Import libraries
+from absl import logging
+import orbit
+import tensorflow as tf, tf_keras
+from official.core import base_task
+from official.core import config_definitions
+from official.core import train_lib as base_train_lib
+from official.modeling.fast_training.progressive import trainer as prog_trainer_lib
+
+
+def run_experiment(distribution_strategy: tf.distribute.Strategy,
+ task: base_task.Task,
+ mode: str,
+ params: config_definitions.ExperimentConfig,
+ model_dir: str,
+ run_post_eval: bool = False,
+ save_summary: bool = True) \
+-> Tuple[tf_keras.Model, Mapping[str, Any]]:
+ """Runs train/eval configured by the experiment params.
+
+ Args:
+ distribution_strategy: A distribution distribution_strategy.
+ task: A Task instance.
+ mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
+ or 'continuous_eval'.
+ params: ExperimentConfig instance.
+ model_dir: A 'str', a path to store model checkpoints and summaries.
+ run_post_eval: Whether to run post eval once after training, metrics logs
+ are returned.
+ save_summary: Whether to save train and validation summary.
+
+ Returns:
+ A 2-tuple of (model, eval_logs).
+ model: `tf_keras.Model` instance.
+ eval_logs: returns eval metrics logs when run_post_eval is set to True,
+ otherwise, returns {}.
+ """
+
+ with distribution_strategy.scope():
+ logging.info('Running progressive trainer.')
+ trainer = prog_trainer_lib.ProgressiveTrainer(
+ params, task, ckpt_dir=model_dir,
+ train='train' in mode,
+ evaluate=('eval' in mode) or run_post_eval,
+ checkpoint_exporter=base_train_lib.maybe_create_best_ckpt_exporter(
+ params, model_dir))
+
+ if trainer.checkpoint:
+ checkpoint_manager = tf.train.CheckpointManager(
+ trainer.checkpoint,
+ directory=model_dir,
+ max_to_keep=params.trainer.max_to_keep,
+ step_counter=trainer.global_step,
+ checkpoint_interval=params.trainer.checkpoint_interval,
+ init_fn=trainer.initialize)
+ else:
+ checkpoint_manager = None
+
+ controller = orbit.Controller(
+ strategy=distribution_strategy,
+ trainer=trainer if 'train' in mode else None,
+ evaluator=trainer,
+ global_step=trainer.global_step,
+ steps_per_loop=params.trainer.steps_per_loop,
+ checkpoint_manager=checkpoint_manager,
+ summary_dir=os.path.join(model_dir, 'train') if (save_summary) else None,
+ eval_summary_dir=os.path.join(model_dir, 'validation') if
+ (save_summary) else None,
+ summary_interval=params.trainer.summary_interval if
+ (save_summary) else None)
+
+ logging.info('Starts to execute mode: %s', mode)
+ with distribution_strategy.scope():
+ if mode == 'train':
+ controller.train(steps=params.trainer.train_steps)
+ elif mode == 'train_and_eval':
+ controller.train_and_evaluate(
+ train_steps=params.trainer.train_steps,
+ eval_steps=params.trainer.validation_steps,
+ eval_interval=params.trainer.validation_interval)
+ elif mode == 'eval':
+ controller.evaluate(steps=params.trainer.validation_steps)
+ elif mode == 'continuous_eval':
+
+ def timeout_fn():
+ if trainer.global_step.numpy() >= params.trainer.train_steps:
+ return True
+ return False
+
+ controller.evaluate_continuously(
+ steps=params.trainer.validation_steps,
+ timeout=params.trainer.continuous_eval_timeout,
+ timeout_fn=timeout_fn)
+ else:
+ raise NotImplementedError('The mode is not implemented: %s' % mode)
+
+ if run_post_eval:
+ with distribution_strategy.scope():
+ return trainer.model, trainer.evaluate(
+ tf.convert_to_tensor(params.trainer.validation_steps))
+ else:
+ return trainer.model, {}
diff --git a/modeling/official/modeling/fast_training/progressive/train_lib_test.py b/modeling/official/modeling/fast_training/progressive/train_lib_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cc64c6adbb5c968343b58d1d7eeb976e423c38a
--- /dev/null
+++ b/modeling/official/modeling/fast_training/progressive/train_lib_test.py
@@ -0,0 +1,183 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the progressive train_lib."""
+import os
+
+from absl import flags
+from absl.testing import parameterized
+import dataclasses
+import orbit
+import tensorflow as tf, tf_keras
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.common import flags as tfm_flags
+# pylint: disable=unused-import
+from official.common import registry_imports
+# pylint: enable=unused-import
+from official.core import config_definitions as cfg
+from official.core import task_factory
+from official.modeling import optimization
+from official.modeling.hyperparams import params_dict
+from official.modeling.fast_training.progressive import policies
+from official.modeling.fast_training.progressive import train_lib
+from official.modeling.fast_training.progressive import trainer as prog_trainer_lib
+from official.utils.testing import mock_task
+
+FLAGS = flags.FLAGS
+
+tfm_flags.define_flags()
+
+
+@dataclasses.dataclass
+class ProgTaskConfig(cfg.TaskConfig):
+ pass
+
+
+@task_factory.register_task_cls(ProgTaskConfig)
+class ProgMockTask(policies.ProgressivePolicy, mock_task.MockTask):
+ """Progressive task for testing."""
+
+ def __init__(self, params: cfg.TaskConfig, logging_dir: str = None):
+ mock_task.MockTask.__init__(
+ self, params=params, logging_dir=logging_dir)
+ policies.ProgressivePolicy.__init__(self)
+
+ def num_stages(self):
+ return 2
+
+ def num_steps(self, stage_id):
+ return 2 if stage_id == 0 else 4
+
+ def get_model(self, stage_id, old_model=None):
+ del stage_id, old_model
+ return self.build_model()
+
+ def get_optimizer(self, stage_id):
+ """Build optimizer for each stage."""
+ params = optimization.OptimizationConfig({
+ 'optimizer': {
+ 'type': 'adamw',
+ },
+ 'learning_rate': {
+ 'type': 'polynomial',
+ 'polynomial': {
+ 'initial_learning_rate': 0.01,
+ 'end_learning_rate': 0.0,
+ 'power': 1.0,
+ 'decay_steps': 10,
+ },
+ },
+ 'warmup': {
+ 'polynomial': {
+ 'power': 1,
+ 'warmup_steps': 2,
+ },
+ 'type': 'polynomial',
+ }
+ })
+ opt_factory = optimization.OptimizerFactory(params)
+ optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
+
+ return optimizer
+
+ def get_train_dataset(self, stage_id):
+ del stage_id
+ strategy = tf.distribute.get_strategy()
+ return orbit.utils.make_distributed_dataset(
+ strategy, self.build_inputs, None)
+
+ def get_eval_dataset(self, stage_id):
+ del stage_id
+ strategy = tf.distribute.get_strategy()
+ return orbit.utils.make_distributed_dataset(
+ strategy, self.build_inputs, None)
+
+
+class TrainTest(tf.test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ super(TrainTest, self).setUp()
+ self._test_config = {
+ 'trainer': {
+ 'checkpoint_interval': 10,
+ 'steps_per_loop': 10,
+ 'summary_interval': 10,
+ 'train_steps': 10,
+ 'validation_steps': 5,
+ 'validation_interval': 10,
+ 'continuous_eval_timeout': 1,
+ 'optimizer_config': {
+ 'optimizer': {
+ 'type': 'sgd',
+ },
+ 'learning_rate': {
+ 'type': 'constant'
+ }
+ }
+ },
+ }
+
+ @combinations.generate(
+ combinations.combine(
+ distribution_strategy=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.cloud_tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ ],
+ flag_mode=['train', 'eval', 'train_and_eval'],
+ run_post_eval=[True, False]))
+ def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
+ model_dir = self.get_temp_dir()
+ experiment_config = cfg.ExperimentConfig(
+ trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
+ task=ProgTaskConfig())
+ experiment_config = params_dict.override_params_dict(
+ experiment_config, self._test_config, is_strict=False)
+
+ with distribution_strategy.scope():
+ task = task_factory.get_task(experiment_config.task,
+ logging_dir=model_dir)
+
+ _, logs = train_lib.run_experiment(
+ distribution_strategy=distribution_strategy,
+ task=task,
+ mode=flag_mode,
+ params=experiment_config,
+ model_dir=model_dir,
+ run_post_eval=run_post_eval)
+
+ if run_post_eval:
+ self.assertNotEmpty(logs)
+ else:
+ self.assertEmpty(logs)
+
+ if flag_mode == 'eval':
+ return
+ self.assertNotEmpty(
+ tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
+ # Tests continuous evaluation.
+ _, logs = train_lib.run_experiment(
+ distribution_strategy=distribution_strategy,
+ task=task,
+ mode='continuous_eval',
+ params=experiment_config,
+ model_dir=model_dir,
+ run_post_eval=run_post_eval)
+ print(logs)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/modeling/fast_training/progressive/trainer.py b/modeling/official/modeling/fast_training/progressive/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d562e100bc14e12e5d1fbbfd449399e67963453
--- /dev/null
+++ b/modeling/official/modeling/fast_training/progressive/trainer.py
@@ -0,0 +1,294 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Progressive Trainer implementation.
+
+The trainer implements the Orbit `StandardTrainable` and
+`StandardEvaluable` interfaces. Trainers inside this project should be
+interchangable and independent on model architectures and tasks.
+"""
+
+import dataclasses
+import os
+from typing import Any, Optional
+
+# Import libraries
+from absl import logging
+import gin
+import orbit
+import tensorflow as tf, tf_keras
+from official.core import base_task
+from official.core import base_trainer as trainer_lib
+from official.core import config_definitions
+from official.modeling.fast_training.progressive import policies
+from official.modeling.fast_training.progressive import utils
+
+ExperimentConfig = config_definitions.ExperimentConfig
+
+
+@dataclasses.dataclass
+class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
+ """Configuration for progressive trainer.
+
+ Attributes:
+ progressive: A task-specific config. Users can subclass ProgressiveConfig
+ and define any task-specific settings in their subclass.
+ export_checkpoint: A bool. Whether to export checkpoints in non-progressive
+ manner (without the volatiles wrapper) such that your down-stream tasks
+ can load checkpoints from a progressive trainer as if it is a regular
+ checkpoint.
+ export_checkpoint_interval: A bool. The number of steps between exporting
+ checkpoints. If None (by default), will use the same value as
+ TrainerConfig.checkpoint_interval.
+ export_max_to_keep: The maximum number of exported checkpoints to keep.
+ If None (by default), will use the same value as
+ TrainerConfig.max_to_keep.
+ export_only_final_stage_ckpt: A bool. Whether to just export checkpoints
+ during the final progressive training stage. In other words, whether to
+ not export small, partial models. In many cases, it is not meaningful to
+ finetune a small, partial model in down-stream tasks.
+ """
+ progressive: Optional[policies.ProgressiveConfig] = None
+ export_checkpoint: bool = True
+ export_checkpoint_interval: Optional[int] = None
+ export_max_to_keep: Optional[int] = None
+ export_only_final_stage_ckpt: bool = True
+
+
+@gin.configurable
+class ProgressiveTrainer(trainer_lib.Trainer):
+ """Implements the progressive trainer shared for TensorFlow models."""
+
+ def __init__(
+ self,
+ config: ExperimentConfig,
+ prog_task: base_task.Task, # also implemented ProgressivePolicy.
+ ckpt_dir: str = '',
+ train: bool = True,
+ evaluate: bool = True,
+ checkpoint_exporter: Any = None):
+ """Initialize common trainer for TensorFlow models.
+
+ Args:
+ config: An `ExperimentConfig` instance specifying experiment config.
+ prog_task: An instance both implemented policies.ProgressivePolicy and
+ base_task.Task.
+ ckpt_dir: Checkpoint directory.
+ train: bool, whether or not this trainer will be used for training.
+ default to True.
+ evaluate: bool, whether or not this trainer will be used for evaluation.
+ default to True.
+ checkpoint_exporter: an object that has the `maybe_export_checkpoint`
+ interface.
+ """
+ # Gets the current distribution strategy. If not inside any strategy scope,
+ # it gets a single-replica no-op strategy.
+ self._strategy = tf.distribute.get_strategy()
+ self._config = config
+ self._runtime_options = trainer_lib.get_runtime_options(config)
+ self._task = prog_task
+
+ # Directory for non-progressive checkpoint
+ self._export_ckpt_dir = os.path.join(ckpt_dir, 'exported_ckpts')
+ tf.io.gfile.makedirs(self._export_ckpt_dir)
+ self._export_ckpt_manager = None
+
+ # Receive other checkpoint export, e.g, best checkpoint exporter.
+ # TODO(lehou): unify the checkpoint exporting logic, although the default
+ # setting does not use checkpoint_exporter.
+ self._checkpoint_exporter = checkpoint_exporter
+
+ self._global_step = orbit.utils.create_global_step()
+
+ self._checkpoint = utils.CheckpointWithHooks(
+ before_load_hook=self._update_pt_stage_from_ckpt,
+ global_step=self.global_step,
+ **self._task.cur_checkpoint_items)
+
+ self._train_loss = tf_keras.metrics.Mean('training_loss', dtype=tf.float32)
+ self._validation_loss = tf_keras.metrics.Mean(
+ 'validation_loss', dtype=tf.float32)
+ self._train_metrics = self.task.build_metrics(
+ training=True) + self.model.metrics
+ self._validation_metrics = self.task.build_metrics(
+ training=False) + self.model.metrics
+
+ if train:
+ orbit.StandardTrainer.__init__(
+ self,
+ None, # Manage train_dataset by ourselves, not by StandardTrainer.
+ options=orbit.StandardTrainerOptions(
+ use_tf_while_loop=config.trainer.train_tf_while_loop,
+ use_tf_function=config.trainer.train_tf_function))
+
+ if evaluate:
+ orbit.StandardEvaluator.__init__(
+ self,
+ None, # Manage train_dataset by ourselves, not by StandardEvaluator.
+ options=orbit.StandardEvaluatorOptions(
+ use_tf_function=config.trainer.eval_tf_function))
+
+ @property
+ def model(self):
+ return self._task.cur_model
+
+ @property
+ def optimizer(self):
+ return self._task.cur_optimizer
+
+ # override
+ @property
+ def train_dataset(self):
+ """Overriding StandardTrainer.train_dataset."""
+ return self._task.cur_train_dataset
+
+ # override
+ @train_dataset.setter
+ def train_dataset(self, _):
+ raise SyntaxError('Please do not set train_dataset. Progressive training '
+ 'relies on progressive policy to manager train dataset.')
+
+ # override
+ @property
+ def eval_dataset(self):
+ """Overriding StandardEvaluator.eval_dataset."""
+ return self._task.cur_eval_dataset
+
+ # override
+ @eval_dataset.setter
+ def eval_dataset(self, _):
+ raise SyntaxError('Please do not set eval_dataset. Progressive training '
+ 'relies on progressive policy to manager eval dataset.')
+
+ def train_loop_end(self):
+ """See base class."""
+ logs = {}
+ for metric in self.train_metrics + [self.train_loss]:
+ logs[metric.name] = metric.result()
+ metric.reset_states()
+ if callable(self.optimizer.learning_rate):
+ logs['learning_rate'] = self.optimizer.learning_rate(
+ self.optimizer.iterations)
+ else:
+ logs['learning_rate'] = self.optimizer.learning_rate
+
+ self._maybe_export_non_progressive_checkpoint(self._export_ckpt_dir)
+ if self._task.is_stage_advancing(self.global_step.numpy()):
+ old_train_dataset = self.train_dataset
+
+ # Update progressive properties
+ self._task.update_pt_stage(self.global_step.numpy())
+
+ # Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
+ # rebuild the train and eval functions with the updated model.
+ self._train_loop_fn = None
+ self._eval_loop_fn = None
+
+ if self.train_dataset != old_train_dataset:
+ # Setting `self._train_iter` to None will rebuild the dataset iterator.
+ self._train_iter = None
+
+ # Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
+ # for exporting.
+ self._export_ckpt_manager = None
+
+ return logs
+
+ def _update_pt_stage_from_ckpt(self, ckpt_file):
+ """Update stage properties based on the global_step variable in a ckpt file.
+
+ Before loading variables from a checkpoint file, we need to go to the
+ correct stage and build corresponding model and optimizer, to make sure that
+ we retore variables of the right model and optimizer.
+
+ Args:
+ ckpt_file: Checkpoint file that will be restored/read from.
+ """
+ if not ckpt_file:
+ return
+ ckpt = tf.train.Checkpoint(global_step=self.global_step)
+ ckpt.read(ckpt_file).expect_partial().assert_existing_objects_matched()
+
+ if self._task.is_stage_advancing(self.global_step.numpy()):
+ old_train_dataset = self.train_dataset
+
+ # Update progressive properties
+ self._task.update_pt_stage(self.global_step.numpy(), pass_old_model=False)
+
+ # Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
+ # rebuild the train and eval functions with the updated model.
+ self._train_loop_fn = None
+ self._eval_loop_fn = None
+
+ if self.train_dataset != old_train_dataset:
+ # Setting `self._train_iter` to None will rebuild the dataset iterator.
+ self._train_iter = None
+
+ # Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
+ # for exporting.
+ self._export_ckpt_manager = None
+
+ def _maybe_export_non_progressive_checkpoint(self, export_ckpt_dir):
+ """Export checkpoints in non-progressive format.
+
+ This basically removes the wrapping of self._task.cur_checkpoint_items
+ -- just save the model, optimizer, etc., directly.
+ The purpose is to let your down-stream tasks to use these checkpoints.
+
+ Args:
+ export_ckpt_dir: A str. folder of exported checkpoints.
+ """
+ if not self.config.trainer.export_checkpoint:
+ logging.info('Not exporting checkpoints.')
+ return
+ if not self._task.is_last_stage and (
+ self.config.trainer.export_only_final_stage_ckpt):
+ logging.info('Not exporting checkpoints until the last stage.')
+ return
+
+ if self._export_ckpt_manager is None:
+ # Create a checkpoint object just now, to make sure we use
+ # progressive_policy.cur_model and progressive_policy.cur_optimizer of the
+ # current stage.
+ if hasattr(self.model, 'checkpoint_items'):
+ checkpoint_items = self.model.checkpoint_items
+ else:
+ checkpoint_items = {}
+ checkpoint = tf.train.Checkpoint(
+ global_step=self.global_step,
+ model=self.model,
+ optimizer=self.optimizer,
+ **checkpoint_items)
+
+ max_to_keep = self.config.trainer.export_max_to_keep or (
+ self.config.trainer.max_to_keep)
+ checkpoint_interval = self.config.trainer.export_checkpoint_interval or (
+ self.config.trainer.checkpoint_interval)
+ self._export_ckpt_manager = tf.train.CheckpointManager(
+ checkpoint,
+ directory=export_ckpt_dir,
+ checkpoint_name='ckpt',
+ step_counter=self.global_step,
+ max_to_keep=max_to_keep,
+ checkpoint_interval=checkpoint_interval,
+ )
+
+ # Make sure we export the last checkpoint.
+ last_checkpoint = (
+ self.global_step.numpy() == self._config.trainer.train_steps)
+ checkpoint_path = self._export_ckpt_manager.save(
+ checkpoint_number=self.global_step.numpy(),
+ check_interval=not last_checkpoint)
+ if checkpoint_path:
+ logging.info('Checkpoints exported: %s.', checkpoint_path)
diff --git a/modeling/official/modeling/fast_training/progressive/trainer_test.py b/modeling/official/modeling/fast_training/progressive/trainer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5377105d021ae321ea1029a9b9ea96508aa1fc9
--- /dev/null
+++ b/modeling/official/modeling/fast_training/progressive/trainer_test.py
@@ -0,0 +1,242 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the progressive trainer."""
+# pylint: disable=g-direct-tensorflow-import
+import os
+
+from absl.testing import parameterized
+import orbit
+import tensorflow as tf, tf_keras
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.core import config_definitions as cfg
+from official.modeling import optimization
+from official.modeling.fast_training.progressive import policies
+from official.modeling.fast_training.progressive import trainer as trainer_lib
+from official.nlp.configs import bert
+from official.utils.testing import mock_task
+
+
+def all_strategy_combinations():
+ return combinations.combine(
+ distribution=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.cloud_tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ ],)
+
+
+def get_exp_config():
+ return cfg.ExperimentConfig(
+ task=cfg.TaskConfig(
+ model=bert.PretrainerConfig()),
+ trainer=trainer_lib.ProgressiveTrainerConfig(
+ export_checkpoint=True,
+ export_checkpoint_interval=1,
+ export_only_final_stage_ckpt=False))
+
+
+class TestPolicy(policies.ProgressivePolicy, mock_task.MockTask):
+ """Just for testing purposes."""
+
+ def __init__(self, strategy, task_config, change_train_dataset=True):
+ self._strategy = strategy
+ self._change_train_dataset = change_train_dataset
+ self._my_train_dataset = None
+ mock_task.MockTask.__init__(self, params=task_config, logging_dir=None)
+ policies.ProgressivePolicy.__init__(self)
+
+ def num_stages(self) -> int:
+ return 2
+
+ def num_steps(self, stage_id: int) -> int:
+ return 2 if stage_id == 0 else 4
+
+ def get_model(self,
+ stage_id: int,
+ old_model: tf_keras.Model) -> tf_keras.Model:
+ del stage_id, old_model
+ return self.build_model()
+
+ def get_optimizer(self, stage_id: int) -> tf_keras.optimizers.Optimizer:
+ optimizer_type = 'sgd' if stage_id == 0 else 'adamw'
+ optimizer_config = cfg.OptimizationConfig({
+ 'optimizer': {'type': optimizer_type},
+ 'learning_rate': {'type': 'constant'}})
+ opt_factory = optimization.OptimizerFactory(optimizer_config)
+ return opt_factory.build_optimizer(opt_factory.build_learning_rate())
+
+ def get_train_dataset(self, stage_id: int) -> tf.data.Dataset:
+ if not self._change_train_dataset and self._my_train_dataset:
+ return self._my_train_dataset
+ if self._strategy:
+ self._my_train_dataset = orbit.utils.make_distributed_dataset(
+ self._strategy,
+ self._build_inputs,
+ stage_id)
+ else:
+ self._my_train_dataset = self._build_inputs(stage_id)
+ return self._my_train_dataset
+
+ def get_eval_dataset(self, stage_id: int) -> tf.data.Dataset:
+ if self._strategy:
+ return orbit.utils.make_distributed_dataset(
+ self._strategy,
+ self._build_inputs,
+ stage_id)
+ return self._build_inputs(stage_id)
+
+ def _build_inputs(self, stage_id):
+ def dummy_data(_):
+ batch_size = 2 if stage_id == 0 else 1
+ x = tf.zeros(shape=(batch_size, 2), dtype=tf.float32)
+ label = tf.zeros(shape=(batch_size, 1), dtype=tf.float32)
+ return x, label
+ dataset = tf.data.Dataset.range(1)
+ dataset = dataset.repeat()
+ return dataset.map(
+ dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+
+
+class TrainerTest(tf.test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ super(TrainerTest, self).setUp()
+ self._config = get_exp_config()
+
+ def create_test_trainer(self, distribution, model_dir, change_train_dataset):
+ trainer = trainer_lib.ProgressiveTrainer(
+ self._config,
+ prog_task=TestPolicy(
+ distribution, self._config.task, change_train_dataset),
+ ckpt_dir=model_dir)
+ return trainer
+
+ @combinations.generate(all_strategy_combinations())
+ def test_checkpointing(self, distribution):
+ model_dir = self.get_temp_dir()
+ ckpt_file = os.path.join(model_dir, 'ckpt')
+ with distribution.scope():
+ trainer = self.create_test_trainer(distribution, model_dir, True)
+ self.assertFalse(trainer._task.is_last_stage)
+ trainer.train(tf.convert_to_tensor(4, dtype=tf.int32))
+ self.assertTrue(trainer._task.is_last_stage)
+ trainer.checkpoint.save(ckpt_file)
+
+ trainer = self.create_test_trainer(distribution, model_dir, True)
+ self.assertFalse(trainer._task.is_last_stage)
+ trainer.checkpoint.restore(ckpt_file + '-1')
+ self.assertTrue(trainer._task.is_last_stage)
+
+ @combinations.generate(all_strategy_combinations())
+ def test_train_dataset(self, distribution):
+ model_dir = self.get_temp_dir()
+ with distribution.scope():
+ trainer = self.create_test_trainer(distribution, model_dir, True)
+ # Using dataset of stage == 0
+ train_iter = tf.nest.map_structure(iter, trainer.train_dataset)
+ train_data = train_iter.next()[0]
+ if distribution.num_replicas_in_sync > 1:
+ train_data = train_data.values[0]
+ self.assertEqual(train_data.shape[0], 2)
+
+ trainer.train(tf.convert_to_tensor(4, dtype=tf.int32))
+ # Using dataset of stage == 1
+ train_iter = tf.nest.map_structure(iter, trainer.train_dataset)
+ train_data = train_iter.next()[0]
+ if distribution.num_replicas_in_sync > 1:
+ train_data = train_data.values[0]
+ self.assertEqual(train_data.shape[0], 1)
+
+ with self.assertRaises(SyntaxError):
+ trainer.train_dataset = None
+
+ @combinations.generate(all_strategy_combinations())
+ def test_train_dataset_no_switch(self, distribution):
+ model_dir = self.get_temp_dir()
+ with distribution.scope():
+ trainer = self.create_test_trainer(distribution, model_dir, False)
+ trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
+ # _train_iter is not reset since the dataset is not changed.
+ self.assertIsNotNone(trainer._train_iter)
+ with distribution.scope():
+ trainer = self.create_test_trainer(distribution, model_dir, True)
+ trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
+ # _train_iter is reset since the dataset changed.
+ self.assertIsNone(trainer._train_iter)
+
+
+class TrainerWithMaskedLMTaskTest(tf.test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ super(TrainerWithMaskedLMTaskTest, self).setUp()
+ self._config = get_exp_config()
+
+ def create_test_trainer(self, distribution):
+ trainer = trainer_lib.ProgressiveTrainer(
+ self._config,
+ prog_task=TestPolicy(distribution, self._config.task),
+ ckpt_dir=self.get_temp_dir())
+ return trainer
+
+ @combinations.generate(all_strategy_combinations())
+ def test_trainer_train(self, distribution):
+ with distribution.scope():
+ trainer = self.create_test_trainer(distribution)
+ logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
+ self.assertIn('training_loss', logs)
+ self.assertIn('learning_rate', logs)
+
+ @combinations.generate(all_strategy_combinations())
+ def test_trainer_validate(self, distribution):
+ with distribution.scope():
+ trainer = self.create_test_trainer(distribution)
+ logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
+ self.assertIn('validation_loss', logs)
+ self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
+
+ @combinations.generate(
+ combinations.combine(
+ mixed_precision_dtype=['float32', 'bfloat16', 'float16'],
+ loss_scale=[None, 'dynamic', 128, 256],
+ ))
+ def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
+ config = cfg.ExperimentConfig(
+ task=cfg.TaskConfig(
+ model=bert.PretrainerConfig()),
+ runtime=cfg.RuntimeConfig(
+ mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
+ trainer=trainer_lib.ProgressiveTrainerConfig(
+ export_checkpoint=True,
+ export_checkpoint_interval=1,
+ export_only_final_stage_ckpt=False))
+ task = TestPolicy(None, config.task)
+ trainer = trainer_lib.ProgressiveTrainer(config, task, self.get_temp_dir())
+ if mixed_precision_dtype != 'float16':
+ self.assertIsInstance(
+ trainer.optimizer,
+ (tf_keras.optimizers.SGD, tf_keras.optimizers.legacy.SGD))
+ elif mixed_precision_dtype == 'float16' and loss_scale is None:
+ self.assertIsInstance(
+ trainer.optimizer,
+ (tf_keras.optimizers.SGD, tf_keras.optimizers.legacy.SGD))
+
+ metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
+ self.assertIn('training_loss', metrics)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/modeling/fast_training/progressive/utils.py b/modeling/official/modeling/fast_training/progressive/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca360348a71479c0925e91ac036883e07a4050d0
--- /dev/null
+++ b/modeling/official/modeling/fast_training/progressive/utils.py
@@ -0,0 +1,56 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Util classes and functions."""
+
+from absl import logging
+import tensorflow as tf, tf_keras
+
+# pylint: disable=g-direct-tensorflow-import
+from tensorflow.python.trackable import autotrackable
+
+
+class VolatileTrackable(autotrackable.AutoTrackable):
+ """A util class to keep Trackables that might change instances."""
+
+ def __init__(self, **kwargs):
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+
+ def reassign_trackable(self, **kwargs):
+ for k, v in kwargs.items():
+ delattr(self, k) # untrack this object
+ setattr(self, k, v) # track the new object
+
+
+class CheckpointWithHooks(tf.train.Checkpoint):
+ """Same as tf.train.Checkpoint but supports hooks.
+
+ In progressive training, use this class instead of tf.train.Checkpoint.
+
+ Since the network architecture changes during progressive training, we need to
+ prepare something (like switch to the correct architecture) before loading the
+ checkpoint. This class supports a hook that will be executed before checkpoint
+ loading.
+ """
+
+ def __init__(self, before_load_hook, **kwargs):
+ self._before_load_hook = before_load_hook
+ super(CheckpointWithHooks, self).__init__(**kwargs)
+
+ # override
+ def read(self, save_path, options=None):
+ self._before_load_hook(save_path)
+ logging.info('Ran before_load_hook.')
+ super(CheckpointWithHooks, self).read(save_path=save_path, options=options)
diff --git a/modeling/official/modeling/grad_utils.py b/modeling/official/modeling/grad_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ccce40d79ad625518e47e64b76c000e212a4fb9
--- /dev/null
+++ b/modeling/official/modeling/grad_utils.py
@@ -0,0 +1,151 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Some gradient util functions to help users writing custom training loop."""
+
+from absl import logging
+
+import tensorflow as tf, tf_keras
+
+
+def _filter_grads(grads_and_vars):
+ """Filter out iterable with grad equal to None."""
+ grads_and_vars = tuple(grads_and_vars)
+ if not grads_and_vars:
+ return grads_and_vars
+ filtered = []
+ vars_with_empty_grads = []
+ for grad, var in grads_and_vars:
+ if grad is None:
+ vars_with_empty_grads.append(var)
+ else:
+ filtered.append((grad, var))
+ filtered = tuple(filtered)
+ if not filtered:
+ raise ValueError("No gradients provided for any variable: %s." %
+ ([v.name for _, v in grads_and_vars],))
+ if vars_with_empty_grads:
+ logging.warning(
+ ("Gradients do not exist for variables %s when minimizing the loss."),
+ ([v.name for v in vars_with_empty_grads]))
+ return filtered
+
+
+def _filter_and_allreduce_gradients(grads_and_vars,
+ allreduce_precision="float32",
+ bytes_per_pack=0):
+ """Filter None grads and then allreduce gradients in specified precision.
+
+ This utils function is used when users intent to explicitly allreduce
+ gradients and customize gradients operations before and after allreduce.
+ The allreduced gradients are then passed to optimizer.apply_gradients(
+ experimental_aggregate_gradients=False).
+
+ Args:
+ grads_and_vars: gradients and variables pairs.
+ allreduce_precision: Whether to allreduce gradients in float32 or float16.
+ bytes_per_pack: A non-negative integer. Breaks collective operations into
+ packs of certain size. If it's zero, all gradients are in one pack.
+
+ Returns:
+ pairs of allreduced non-None gradients and variables.
+ """
+ filtered_grads_and_vars = _filter_grads(grads_and_vars)
+ (grads, variables) = zip(*filtered_grads_and_vars)
+ if allreduce_precision == "float16":
+ grads = [tf.cast(grad, "float16") for grad in grads]
+ hints = tf.distribute.experimental.CommunicationOptions(
+ bytes_per_pack=bytes_per_pack)
+ allreduced_grads = tf.distribute.get_strategy( # pylint: disable=protected-access
+ ).extended._replica_ctx_all_reduce(tf.distribute.ReduceOp.SUM, grads, hints)
+ if allreduce_precision == "float16":
+ allreduced_grads = [tf.cast(grad, "float32") for grad in allreduced_grads]
+ return allreduced_grads, variables
+
+
+def _run_callbacks(callbacks, grads_and_vars):
+ for callback in callbacks:
+ grads_and_vars = callback(grads_and_vars)
+ return grads_and_vars
+
+
+def minimize_using_explicit_allreduce(tape,
+ optimizer,
+ loss,
+ trainable_variables,
+ pre_allreduce_callbacks=None,
+ post_allreduce_callbacks=None,
+ allreduce_bytes_per_pack=0):
+ """Minimizes loss for one step by updating `trainable_variables`.
+
+ Minimizes loss for one step by updating `trainable_variables`.
+ This explicitly performs gradient allreduce, instead of relying on implicit
+ allreduce in optimizer.apply_gradients(). If training using FP16 mixed
+ precision, explicit allreduce will aggregate gradients in FP16 format.
+ For TPU and GPU training using FP32, explicit allreduce will aggregate
+ gradients in FP32 format.
+
+ Args:
+ tape: An instance of `tf.GradientTape`.
+ optimizer: An instance of `tf_keras.optimizers.Optimizer`.
+ loss: the loss tensor.
+ trainable_variables: A list of model Variables.
+ pre_allreduce_callbacks: A list of callback functions that takes gradients
+ and model variables pairs as input, manipulate them, and returns a new
+ gradients and model variables pairs. The callback functions will be
+ invoked in the list order and before gradients are allreduced. With
+ mixed precision training, the pre_allreduce_allbacks will be applied on
+ scaled_gradients. Default is no callbacks.
+ post_allreduce_callbacks: A list of callback functions that takes
+ gradients and model variables pairs as input, manipulate them, and
+ returns a new gradients and model variables paris. The callback
+ functions will be invoked in the list order and right before gradients
+ are applied to variables for updates. Default is no callbacks.
+ allreduce_bytes_per_pack: A non-negative integer. Breaks collective
+ operations into packs of certain size. If it's zero, all gradients are
+ in one pack.
+ """
+ if isinstance(optimizer,
+ tf_keras.mixed_precision.LossScaleOptimizer):
+ # FP16 GPU code path
+ with tape:
+ scaled_loss = optimizer.get_scaled_loss(loss)
+ scaled_grads = tape.gradient(scaled_loss, trainable_variables)
+ grads_and_vars = zip(scaled_grads, trainable_variables)
+ if pre_allreduce_callbacks:
+ grads_and_vars = _run_callbacks(pre_allreduce_callbacks, grads_and_vars)
+ (allreduced_scaled_grads,
+ filtered_training_vars) = _filter_and_allreduce_gradients(
+ grads_and_vars,
+ allreduce_precision="float16",
+ bytes_per_pack=allreduce_bytes_per_pack)
+ allreduced_unscaled_grads = optimizer.get_unscaled_gradients(
+ allreduced_scaled_grads)
+ grads_and_vars = zip(allreduced_unscaled_grads, filtered_training_vars)
+ else:
+ # TPU or FP32 GPU code path
+ grads = tape.gradient(loss, trainable_variables)
+ grads_and_vars = zip(grads, trainable_variables)
+ if pre_allreduce_callbacks:
+ grads_and_vars = _run_callbacks(pre_allreduce_callbacks, grads_and_vars)
+ (allreduced_grads,
+ filtered_training_vars) = _filter_and_allreduce_gradients(
+ grads_and_vars,
+ allreduce_precision="float32",
+ bytes_per_pack=allreduce_bytes_per_pack)
+ grads_and_vars = zip(allreduced_grads, filtered_training_vars)
+ if post_allreduce_callbacks:
+ grads_and_vars = _run_callbacks(post_allreduce_callbacks, grads_and_vars)
+ optimizer.apply_gradients(
+ grads_and_vars, experimental_aggregate_gradients=False)
diff --git a/modeling/official/modeling/grad_utils_test.py b/modeling/official/modeling/grad_utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f9fc08a8313dee257b0572ebf794ad94de1773b
--- /dev/null
+++ b/modeling/official/modeling/grad_utils_test.py
@@ -0,0 +1,77 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for grad_utils."""
+
+import tensorflow as tf, tf_keras
+from official.modeling import grad_utils
+from official.modeling import performance
+
+
+class GradUtilsTest(tf.test.TestCase):
+
+ def test_minimize(self):
+
+ optimizer = tf_keras.optimizers.SGD(0.1)
+ with tf.GradientTape() as tape:
+ model = tf_keras.layers.Dense(2)
+ outputs = model(tf.zeros((2, 2), tf.float32))
+ loss = tf.reduce_mean(outputs)
+
+ grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
+ model.trainable_variables)
+
+ def test_minimize_fp16(self):
+
+ optimizer = performance.configure_optimizer(
+ tf_keras.optimizers.SGD(0.1), use_float16=True)
+ performance.set_mixed_precision_policy(tf.float16)
+ with tf.GradientTape() as tape:
+ model = tf_keras.layers.Dense(2)
+ outputs = model(tf.zeros((2, 2), tf.float16))
+ loss = tf.reduce_mean(outputs)
+
+ grad_utils.minimize_using_explicit_allreduce(tape, optimizer, loss,
+ model.trainable_variables)
+
+ # Test other fp16 settings.
+ def _clip_by_global_norm(grads_and_vars):
+ grads, tvars = list(zip(*grads_and_vars))
+ (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
+ return zip(grads, tvars)
+ with tf.GradientTape() as tape:
+ model = tf_keras.layers.Dense(2)
+ outputs = model(tf.zeros((2, 2), tf.float16))
+ loss = tf.reduce_mean(outputs)
+ optimizer = performance.configure_optimizer(
+ tf_keras.optimizers.SGD(0.1), use_float16=True, loss_scale=128)
+ grad_utils.minimize_using_explicit_allreduce(
+ tape,
+ optimizer,
+ loss,
+ model.trainable_variables,
+ pre_allreduce_callbacks=[_clip_by_global_norm],
+ post_allreduce_callbacks=[_clip_by_global_norm])
+
+ def test_set_mixed_precision_policy(self):
+ performance.set_mixed_precision_policy(tf.float16)
+ performance.set_mixed_precision_policy(tf.bfloat16)
+ performance.set_mixed_precision_policy(tf.float32)
+
+ with self.assertRaises(ValueError):
+ performance.set_mixed_precision_policy(tf.int32)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/modeling/hyperparams/__init__.py b/modeling/official/modeling/hyperparams/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..34a0f116bb7a3537bc99a6d1b148bfc40361445b
--- /dev/null
+++ b/modeling/official/modeling/hyperparams/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Hyperparams package definition."""
+# pylint: disable=g-multiple-import
+from official.modeling.hyperparams.base_config import *
+from official.modeling.hyperparams.oneof import *
+from official.modeling.hyperparams.params_dict import *
+
diff --git a/modeling/official/modeling/hyperparams/base_config.py b/modeling/official/modeling/hyperparams/base_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e1a57995f004e532d5c6f79c73394d019700ee3
--- /dev/null
+++ b/modeling/official/modeling/hyperparams/base_config.py
@@ -0,0 +1,350 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Base configurations to standardize experiments."""
+
+import copy
+import dataclasses
+import functools
+import inspect
+import typing
+from typing import Any, List, Mapping, Optional, Type, Union
+
+from absl import logging
+import tensorflow as tf, tf_keras
+import yaml
+
+from official.modeling.hyperparams import params_dict
+
+
+_BOUND = set()
+
+
+def bind(config_cls):
+ """Bind a class to config cls."""
+ if not inspect.isclass(config_cls):
+ raise ValueError('The bind decorator is supposed to apply on the class '
+ f'attribute. Received {config_cls}, not a class.')
+
+ def decorator(builder):
+ if config_cls in _BOUND:
+ raise ValueError('Inside a program, we should not bind the config with a'
+ ' class twice.')
+ if inspect.isclass(builder):
+ config_cls._BUILDER = builder # pylint: disable=protected-access
+ elif inspect.isfunction(builder):
+
+ def _wrapper(self, *args, **kwargs): # pylint: disable=unused-argument
+ return builder(*args, **kwargs)
+
+ config_cls._BUILDER = _wrapper # pylint: disable=protected-access
+ else:
+ raise ValueError(f'The `BUILDER` type is not supported: {builder}')
+ _BOUND.add(config_cls)
+ return builder
+
+ return decorator
+
+
+def _is_optional(field):
+ return typing.get_origin(field) is Union and type(None) in typing.get_args(
+ field)
+
+
+@dataclasses.dataclass
+class Config(params_dict.ParamsDict):
+ """The base configuration class that supports YAML/JSON based overrides.
+
+ Because of YAML/JSON serialization limitations, some semantics of dataclass
+ are not supported:
+ * It recursively enforces a allowlist of basic types and container types, so
+ it avoids surprises with copy and reuse caused by unanticipated types.
+ * Warning: it converts Dict to `Config` even within sequences,
+ e.g. for config = Config({'key': [([{'a': 42}],)]),
+ type(config.key[0][0][0]) is Config rather than dict.
+ If you define/annotate some field as Dict, the field will convert to a
+ `Config` instance and lose the dictionary type.
+ """
+ # The class or method to bind with the params class.
+ _BUILDER = None
+ # It's safe to add bytes and other immutable types here.
+ IMMUTABLE_TYPES = (str, int, float, bool, type(None))
+ # It's safe to add set, frozenset and other collections here.
+ SEQUENCE_TYPES = (list, tuple)
+
+ default_params: dataclasses.InitVar[Optional[Mapping[str, Any]]] = None
+ restrictions: dataclasses.InitVar[Optional[List[str]]] = None
+
+ def __post_init__(self, default_params, restrictions):
+ super().__init__(
+ default_params=default_params,
+ restrictions=restrictions)
+
+ @property
+ def BUILDER(self):
+ return self._BUILDER
+
+ @classmethod
+ def _get_annotations(cls):
+ """Returns valid annotations.
+
+ Note: this is similar to dataclasses.__annotations__ except it also includes
+ annotations from its parent classes.
+ """
+ all_annotations = typing.get_type_hints(cls)
+ # Removes Config class annotation from the value, e.g., default_params,
+ # restrictions, etc.
+ for k in Config.__annotations__:
+ del all_annotations[k]
+ return all_annotations
+
+ @classmethod
+ def _isvalidsequence(cls, v):
+ """Check if the input values are valid sequences.
+
+ Args:
+ v: Input sequence.
+
+ Returns:
+ True if the sequence is valid. Valid sequence includes the sequence
+ type in cls.SEQUENCE_TYPES and element type is in cls.IMMUTABLE_TYPES or
+ is dict or ParamsDict.
+ """
+ if not isinstance(v, cls.SEQUENCE_TYPES):
+ return False
+ return (all(isinstance(e, cls.IMMUTABLE_TYPES) for e in v) or
+ all(isinstance(e, dict) for e in v) or
+ all(isinstance(e, params_dict.ParamsDict) for e in v))
+
+ @classmethod
+ def _import_config(cls, v, subconfig_type):
+ """Returns v with dicts converted to Configs, recursively."""
+ if not issubclass(subconfig_type, params_dict.ParamsDict):
+ raise TypeError(
+ 'Subconfig_type should be subclass of ParamsDict, found {!r}'.format(
+ subconfig_type))
+ if isinstance(v, cls.IMMUTABLE_TYPES):
+ return v
+ elif isinstance(v, cls.SEQUENCE_TYPES):
+ # Only support one layer of sequence.
+ if not cls._isvalidsequence(v):
+ raise TypeError(
+ 'Invalid sequence: only supports single level {!r} of {!r} or '
+ 'dict or ParamsDict found: {!r}'.format(cls.SEQUENCE_TYPES,
+ cls.IMMUTABLE_TYPES, v))
+ import_fn = functools.partial(
+ cls._import_config, subconfig_type=subconfig_type)
+ return type(v)(map(import_fn, v))
+ elif isinstance(v, params_dict.ParamsDict):
+ # Deepcopy here is a temporary solution for preserving type in nested
+ # Config object.
+ return copy.deepcopy(v)
+ elif isinstance(v, dict):
+ return subconfig_type(v)
+ else:
+ raise TypeError('Unknown type: {!r}'.format(type(v)))
+
+ @classmethod
+ def _export_config(cls, v):
+ """Returns v with Configs converted to dicts, recursively."""
+ if isinstance(v, cls.IMMUTABLE_TYPES):
+ return v
+ elif isinstance(v, cls.SEQUENCE_TYPES):
+ return type(v)(map(cls._export_config, v))
+ elif isinstance(v, params_dict.ParamsDict):
+ return v.as_dict()
+ elif isinstance(v, dict):
+ raise TypeError('dict value not supported in converting.')
+ else:
+ raise TypeError('Unknown type: {!r}'.format(type(v)))
+
+ @classmethod
+ def _get_subconfig_type(
+ cls, k, subconfig_type=None
+ ) -> Type[params_dict.ParamsDict]:
+ """Get element type by the field name.
+
+ Args:
+ k: the key/name of the field.
+ subconfig_type: default subconfig_type. If None, it is set to
+ Config.
+
+ Returns:
+ Config as default. If a type annotation is found for `k`,
+ 1) returns the type of the annotation if it is subtype of ParamsDict;
+ 2) returns the element type if the annotation of `k` is List[SubType]
+ or Tuple[SubType].
+ """
+ if not subconfig_type:
+ subconfig_type = Config
+
+ annotations = cls._get_annotations()
+ if k in annotations:
+ # Directly Config subtype.
+ type_annotation = annotations[k]
+ i = 0
+ # Loop for striping the Optional annotation.
+ traverse_in = True
+ while traverse_in:
+ i += 1
+ if (isinstance(type_annotation, type) and
+ issubclass(type_annotation, Config)):
+ subconfig_type = type_annotation
+ break
+ else:
+ # Check if the field is a sequence of subtypes.
+ field_type = typing.get_origin(type_annotation)
+ if (isinstance(field_type, type) and
+ issubclass(field_type, cls.SEQUENCE_TYPES)):
+ element_type = typing.get_args(type_annotation)[0]
+ subconfig_type = (
+ element_type if issubclass(element_type, params_dict.ParamsDict)
+ else subconfig_type)
+ break
+ elif _is_optional(type_annotation):
+ # Strip the `Optional` annotation and process the subtype.
+ type_annotation = typing.get_args(type_annotation)[0]
+ continue
+ traverse_in = False
+ return subconfig_type
+
+ def _set(self, k, v):
+ """Overrides same method in ParamsDict.
+
+ Also called by ParamsDict methods.
+
+ Args:
+ k: key to set.
+ v: value.
+
+ Raises:
+ RuntimeError
+ """
+ subconfig_type = self._get_subconfig_type(k)
+
+ def is_null(k):
+ if k not in self.__dict__ or not self.__dict__[k]:
+ return True
+ return False
+
+ if isinstance(v, dict):
+ if is_null(k):
+ # If the key not exist or the value is None, a new Config-family object
+ # sould be created for the key.
+ self.__dict__[k] = subconfig_type(v)
+ else:
+ self.__dict__[k].override(v)
+ elif not is_null(k) and isinstance(v, self.SEQUENCE_TYPES) and all(
+ [not isinstance(e, self.IMMUTABLE_TYPES) for e in v]):
+ if len(self.__dict__[k]) == len(v):
+ for i in range(len(v)):
+ self.__dict__[k][i].override(v[i])
+ elif not all([isinstance(e, self.IMMUTABLE_TYPES) for e in v]):
+ logging.warning(
+ "The list/tuple don't match the value dictionaries provided. Thus, "
+ 'the list/tuple is determined by the type annotation and '
+ 'values provided. This is error-prone.')
+ self.__dict__[k] = self._import_config(v, subconfig_type)
+ else:
+ self.__dict__[k] = self._import_config(v, subconfig_type)
+ else:
+ self.__dict__[k] = self._import_config(v, subconfig_type)
+
+ def __setattr__(self, k, v):
+ if k == 'BUILDER' or k == '_BUILDER':
+ raise AttributeError('`BUILDER` is a property and `_BUILDER` is the '
+ 'reserved class attribute. We should only assign '
+ '`_BUILDER` at the class level.')
+
+ if k not in self.RESERVED_ATTR:
+ if getattr(self, '_locked', False):
+ raise ValueError('The Config has been locked. ' 'No change is allowed.')
+ self._set(k, v)
+
+ def _override(self, override_dict, is_strict=True):
+ """Overrides same method in ParamsDict.
+
+ Also called by ParamsDict methods.
+
+ Args:
+ override_dict: dictionary to write to .
+ is_strict: If True, not allows to add new keys.
+
+ Raises:
+ KeyError: overriding reserved keys or keys not exist (is_strict=True).
+ """
+ for k, v in sorted(override_dict.items()):
+ if k in self.RESERVED_ATTR:
+ raise KeyError('The key {!r} is internally reserved. '
+ 'Can not be overridden.'.format(k))
+ if k not in self.__dict__:
+ if is_strict:
+ raise KeyError('The key {!r} does not exist in {!r}. '
+ 'To extend the existing keys, use '
+ '`override` with `is_strict` = False.'.format(
+ k, type(self)))
+ else:
+ self._set(k, v)
+ else:
+ if isinstance(v, dict) and self.__dict__[k]:
+ self.__dict__[k]._override(v, is_strict) # pylint: disable=protected-access
+ elif isinstance(v, params_dict.ParamsDict) and self.__dict__[k]:
+ self.__dict__[k]._override(v.as_dict(), is_strict) # pylint: disable=protected-access
+ else:
+ self._set(k, v)
+
+ def as_dict(self):
+ """Returns a dict representation of params_dict.ParamsDict.
+
+ For the nested params_dict.ParamsDict, a nested dict will be returned.
+ """
+ return {
+ k: self._export_config(v)
+ for k, v in self.__dict__.items()
+ if k not in self.RESERVED_ATTR
+ }
+
+ def replace(self, **kwargs):
+ """Overrides/returns a unlocked copy with the current config unchanged."""
+ # pylint: disable=protected-access
+ params = copy.deepcopy(self)
+ params._locked = False
+ params._override(kwargs, is_strict=True)
+ # pylint: enable=protected-access
+ return params
+
+ @classmethod
+ def from_yaml(cls, file_path: str):
+ # Note: This only works if the Config has all default values.
+ with tf.io.gfile.GFile(file_path, 'r') as f:
+ loaded = yaml.load(f, Loader=yaml.FullLoader)
+ config = cls()
+ config.override(loaded)
+ return config
+
+ @classmethod
+ def from_json(cls, file_path: str):
+ """Wrapper for `from_yaml`."""
+ return cls.from_yaml(file_path)
+
+ @classmethod
+ def from_args(cls, *args, **kwargs):
+ """Builds a config from the given list of arguments."""
+ # Note we intend to keep `__annotations__` instead of `_get_annotations`.
+ # Assuming a parent class of (a, b) with the sub-class of (c, d), the
+ # sub-class will take (c, d) for args, rather than starting from (a, b).
+ attributes = list(cls.__annotations__.keys())
+ default_params = {a: p for a, p in zip(attributes, args)}
+ default_params.update(kwargs)
+ return cls(default_params=default_params)
diff --git a/modeling/official/modeling/hyperparams/base_config_test.py b/modeling/official/modeling/hyperparams/base_config_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..efd976f9e3588ec8ea608b9448b2f5a5db317f64
--- /dev/null
+++ b/modeling/official/modeling/hyperparams/base_config_test.py
@@ -0,0 +1,427 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pprint
+import dataclasses
+from typing import List, Optional, Tuple
+
+from absl.testing import parameterized
+import tensorflow as tf, tf_keras
+
+from official.modeling.hyperparams import base_config
+
+
+@dataclasses.dataclass
+class DumpConfig1(base_config.Config):
+ a: int = 1
+ b: str = 'text'
+
+
+@dataclasses.dataclass
+class DumpConfig2(base_config.Config):
+ c: int = 2
+ d: str = 'text'
+ e: DumpConfig1 = dataclasses.field(default_factory=DumpConfig1)
+ optional_e: Optional[DumpConfig1] = None
+
+
+@dataclasses.dataclass
+class DumpConfig3(DumpConfig2):
+ f: int = 2
+ g: str = 'text'
+ h: List[DumpConfig1] = dataclasses.field(
+ default_factory=lambda: [DumpConfig1(), DumpConfig1()])
+ g: Tuple[DumpConfig1, ...] = (DumpConfig1(),)
+
+
+@dataclasses.dataclass
+class DumpConfig4(DumpConfig2):
+ x: int = 3
+
+
+@dataclasses.dataclass
+class DummyConfig5(base_config.Config):
+ y: Tuple[DumpConfig2, ...] = (DumpConfig2(), DumpConfig4())
+ z: Tuple[str] = ('a',)
+
+
+@dataclasses.dataclass
+class DumpConfig6(base_config.Config):
+ test_config1: Optional[DumpConfig1] = None
+
+
+class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
+
+ def assertHasSameTypes(self, c, d, msg=''):
+ """Checks if a Config has the same structure as a given dict.
+
+ Args:
+ c: the Config object to be check.
+ d: the reference dict object.
+ msg: The error message to show when type mismatched.
+ """
+ # Make sure d is not a Config. Assume d is either
+ # dictionary or primitive type and c is the Config or primitive types.
+ self.assertNotIsInstance(d, base_config.Config)
+ if isinstance(d, base_config.Config.IMMUTABLE_TYPES):
+ self.assertEqual(pprint.pformat(c), pprint.pformat(d), msg=msg)
+ elif isinstance(d, base_config.Config.SEQUENCE_TYPES):
+ self.assertEqual(type(c), type(d), msg=msg)
+ for i, v in enumerate(d):
+ self.assertHasSameTypes(c[i], v, msg='{}[{!r}]'.format(msg, i))
+ elif isinstance(d, dict):
+ self.assertIsInstance(c, base_config.Config, msg=msg)
+ for k, v in sorted(d.items()):
+ self.assertHasSameTypes(getattr(c, k), v, msg='{}[{!r}]'.format(msg, k))
+ else:
+ raise TypeError('Unknown type: %r' % type(d))
+
+ def assertImportExport(self, v):
+ config = base_config.Config({'key': v})
+ back = config.as_dict()['key']
+ self.assertEqual(pprint.pformat(back), pprint.pformat(v))
+ self.assertHasSameTypes(config.key, v, msg='=%s v' % pprint.pformat(v))
+
+ def test_invalid_keys(self):
+ params = base_config.Config()
+ with self.assertRaises(AttributeError):
+ _ = params.a
+
+ def test_cls(self):
+ params = base_config.Config()
+ with self.assertRaisesRegex(
+ AttributeError,
+ '`BUILDER` is a property and `_BUILDER` is the reserved'):
+ params.BUILDER = DumpConfig2
+ with self.assertRaisesRegex(
+ AttributeError,
+ '`BUILDER` is a property and `_BUILDER` is the reserved'):
+ params._BUILDER = DumpConfig2
+
+ base_config.bind(DumpConfig1)(DumpConfig2)
+ params = DumpConfig1()
+ self.assertEqual(params.BUILDER, DumpConfig2)
+ with self.assertRaisesRegex(ValueError,
+ 'Inside a program, we should not bind'):
+ base_config.bind(DumpConfig1)(DumpConfig2)
+
+ def _test():
+ return 'test'
+
+ base_config.bind(DumpConfig2)(_test)
+ params = DumpConfig2()
+ self.assertEqual(params.BUILDER(), 'test')
+
+ def test_nested_config_types(self):
+ config = DumpConfig3()
+ self.assertIsInstance(config.e, DumpConfig1)
+ self.assertIsInstance(config.h[0], DumpConfig1)
+ self.assertIsInstance(config.h[1], DumpConfig1)
+ self.assertIsInstance(config.g[0], DumpConfig1)
+
+ config.override({'e': {'a': 2, 'b': 'new text'}})
+ self.assertIsInstance(config.e, DumpConfig1)
+ self.assertEqual(config.e.a, 2)
+ self.assertEqual(config.e.b, 'new text')
+
+ config.override({'h': [{'a': 3, 'b': 'new text 2'}]})
+ self.assertIsInstance(config.h[0], DumpConfig1)
+ self.assertLen(config.h, 1)
+ self.assertEqual(config.h[0].a, 3)
+ self.assertEqual(config.h[0].b, 'new text 2')
+
+ config.override({'g': [{'a': 4, 'b': 'new text 3'}]})
+ self.assertIsInstance(config.g[0], DumpConfig1)
+ self.assertLen(config.g, 1)
+ self.assertEqual(config.g[0].a, 4)
+ self.assertEqual(config.g[0].b, 'new text 3')
+
+ def test_replace(self):
+ config = DumpConfig2()
+ new_config = config.replace(e={'a': 2})
+ self.assertEqual(new_config.e.a, 2)
+ self.assertIsInstance(new_config.e, DumpConfig1)
+
+ config = DumpConfig2(e=DumpConfig2())
+ new_config = config.replace(e={'c': 4})
+ self.assertEqual(new_config.e.c, 4)
+ self.assertIsInstance(new_config.e, DumpConfig2)
+
+ config = DumpConfig3()
+ new_config = config.replace(g=[{'a': 4, 'b': 'new text 3'}])
+ self.assertIsInstance(new_config.g[0], DumpConfig1)
+ self.assertEqual(new_config.g[0].a, 4)
+
+ @parameterized.parameters(
+ ('_locked', "The key '_locked' is internally reserved."),
+ ('_restrictions', "The key '_restrictions' is internally reserved."),
+ ('aa', "The key 'aa' does not exist."),
+ )
+ def test_key_error(self, key, msg):
+ params = base_config.Config()
+ with self.assertRaisesRegex(KeyError, msg):
+ params.override({key: True})
+
+ @parameterized.parameters(
+ ('str data',),
+ (123,),
+ (1.23,),
+ (None,),
+ (['str', 1, 2.3, None],),
+ (('str', 1, 2.3, None),),
+ )
+ def test_import_export_immutable_types(self, v):
+ self.assertImportExport(v)
+ out = base_config.Config({'key': v})
+ self.assertEqual(pprint.pformat(v), pprint.pformat(out.key))
+
+ def test_override_is_strict_true(self):
+ params = base_config.Config({
+ 'a': 'aa',
+ 'b': 2,
+ 'c': {
+ 'c1': 'cc',
+ 'c2': 20
+ }
+ })
+ params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
+ self.assertEqual(params.a, 2)
+ self.assertEqual(params.c.c1, 'ccc')
+ with self.assertRaises(KeyError):
+ params.override({'d': 'ddd'}, is_strict=True)
+ with self.assertRaises(KeyError):
+ params.override({'c': {'c3': 30}}, is_strict=True)
+
+ config = base_config.Config({'key': [{'a': 42}]})
+ with self.assertRaisesRegex(KeyError, "The key 'b' does not exist"):
+ config.override({'key': [{'b': 43}]})
+
+ @parameterized.parameters(
+ (lambda x: x, 'Unknown type'),
+ (object(), 'Unknown type'),
+ (set(), 'Unknown type'),
+ (frozenset(), 'Unknown type'),
+ )
+ def test_import_unsupport_types(self, v, msg):
+ with self.assertRaisesRegex(TypeError, msg):
+ _ = base_config.Config({'key': v})
+
+ @parameterized.parameters(
+ ({
+ 'a': [{
+ 'b': 2,
+ }, {
+ 'c': 3,
+ }]
+ },),
+ ({
+ 'c': [{
+ 'f': 1.1,
+ }, {
+ 'h': [1, 2],
+ }]
+ },),
+ (({
+ 'a': 'aa',
+ 'b': 2,
+ 'c': {
+ 'c1': 10,
+ 'c2': 20,
+ }
+ },),),
+ )
+ def test_import_export_nested_structure(self, d):
+ self.assertImportExport(d)
+
+ @parameterized.parameters(
+ ([{
+ 'a': 42,
+ 'b': 'hello',
+ 'c': 1.2
+ }],),
+ (({
+ 'a': 42,
+ 'b': 'hello',
+ 'c': 1.2
+ },),),
+ )
+ def test_import_export_nested_sequences(self, v):
+ self.assertImportExport(v)
+
+ @parameterized.parameters(
+ ([([{}],)],),
+ ([['str', 1, 2.3, None]],),
+ ((('str', 1, 2.3, None),),),
+ ([
+ ('str', 1, 2.3, None),
+ ],),
+ ([
+ ('str', 1, 2.3, None),
+ ],),
+ ([[{
+ 'a': 42,
+ 'b': 'hello',
+ 'c': 1.2
+ }]],),
+ ([[[{
+ 'a': 42,
+ 'b': 'hello',
+ 'c': 1.2
+ }]]],),
+ ((({
+ 'a': 42,
+ 'b': 'hello',
+ 'c': 1.2
+ },),),),
+ (((({
+ 'a': 42,
+ 'b': 'hello',
+ 'c': 1.2
+ },),),),),
+ ([({
+ 'a': 42,
+ 'b': 'hello',
+ 'c': 1.2
+ },)],),
+ (([{
+ 'a': 42,
+ 'b': 'hello',
+ 'c': 1.2
+ }],),),
+ )
+ def test_import_export_unsupport_sequence(self, v):
+ with self.assertRaisesRegex(TypeError,
+ 'Invalid sequence: only supports single level'):
+ _ = base_config.Config({'key': v})
+
+ def test_construct_subtype(self):
+ pass
+
+ def test_import_config(self):
+ params = base_config.Config({'a': [{'b': 2}, {'c': {'d': 3}}]})
+ self.assertLen(params.a, 2)
+ self.assertEqual(params.a[0].b, 2)
+ self.assertEqual(type(params.a[0]), base_config.Config)
+ self.assertEqual(pprint.pformat(params.a[0].b), '2')
+ self.assertEqual(type(params.a[1]), base_config.Config)
+ self.assertEqual(type(params.a[1].c), base_config.Config)
+ self.assertEqual(pprint.pformat(params.a[1].c.d), '3')
+
+ def test_override(self):
+ params = base_config.Config({'a': [{'b': 2}, {'c': {'d': 3}}]})
+ params.override({'a': [{'b': 4}, {'c': {'d': 5}}]}, is_strict=False)
+ self.assertEqual(type(params.a), list)
+ self.assertEqual(type(params.a[0]), base_config.Config)
+ self.assertEqual(pprint.pformat(params.a[0].b), '4')
+ self.assertEqual(type(params.a[1]), base_config.Config)
+ self.assertEqual(type(params.a[1].c), base_config.Config)
+ self.assertEqual(pprint.pformat(params.a[1].c.d), '5')
+
+ @parameterized.parameters(
+ ([{}],),
+ (({},),),
+ )
+ def test_config_vs_params_dict(self, v):
+ d = {'key': v}
+ self.assertEqual(type(base_config.Config(d).key[0]), base_config.Config)
+ self.assertEqual(type(base_config.params_dict.ParamsDict(d).key[0]), dict)
+
+ def test_ppformat(self):
+ self.assertEqual(
+ pprint.pformat([
+ 's', 1, 1.0, True, None, {}, [], (), {
+ (2,): (3, [4], {
+ 6: 7,
+ }),
+ 8: 9,
+ }
+ ]),
+ "['s', 1, 1.0, True, None, {}, [], (), {8: 9, (2,): (3, [4], {6: 7})}]")
+
+ def test_with_superclass_override(self):
+ config = DumpConfig2()
+ config.override({'optional_e': {'a': 2}})
+ self.assertEqual(
+ config.optional_e.as_dict(),
+ {
+ 'a': 2,
+ 'b': 'text',
+ },
+ )
+
+ # Previously, the following will fail. See b/274696969 for context.
+ config = DumpConfig3()
+ config.override({'optional_e': {'a': 2}})
+ self.assertEqual(
+ config.optional_e.as_dict(),
+ {
+ 'a': 2,
+ 'b': 'text',
+ },
+ )
+
+ def test_get_annotations_without_base_config_leak(self):
+ with self.assertRaisesRegex(
+ KeyError, "The key 'restrictions' does not exist"
+ ):
+ DumpConfig3().override({'restrictions': None})
+
+ def test_with_restrictions(self):
+ restrictions = ['e.a[a-zA-Z][\w\.]*)(?P\[?[0-9]*\]?) # variable name: "var" or "x" followed by optional index: "[0]" or "[23]"
+ \s*=\s*
+ ((?P\'(.*?)\' # single quote
+ |
+ \"(.*?)\" # double quote
+ |
+ [^,\[]* # single value
+ |
+ \[[^\]]*\])) # list of values
+ ($|,\s*)""", re.VERBOSE)
+
+_CONST_VALUE_RE = re.compile(r'(\d.*|-\d.*|None)')
+
+# Yaml LOADER with an implicit resolver to parse float decimal and exponential
+# format. The regular experission parse the following cases:
+# 1- Decimal number with an optional exponential term.
+# 2- Integer number with an exponential term.
+# 3- Decimal number with an optional exponential term.
+# 4- Decimal number.
+
+_LOADER = yaml.FullLoader
+_LOADER.add_implicit_resolver(
+ 'tag:yaml.org,2002:float',
+ re.compile(r'''
+ ^(?:[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
+ |
+ [-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
+ |
+ \\.[0-9_]+(?:[eE][-+][0-9]+)?
+ |
+ [-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*)$''', re.X),
+ list('-+0123456789.'))
+
+
+class ParamsDict(object):
+ """A hyperparameter container class."""
+
+ RESERVED_ATTR = ['_locked', '_restrictions']
+
+ def __init__(self, default_params=None, restrictions=None):
+ """Instantiate a ParamsDict.
+
+ Instantiate a ParamsDict given a set of default parameters and a list of
+ restrictions. Upon initialization, it validates itself by checking all the
+ defined restrictions, and raise error if it finds inconsistency.
+
+ Args:
+ default_params: a Python dict or another ParamsDict object including the
+ default parameters to initialize.
+ restrictions: a list of strings, which define a list of restrictions to
+ ensure the consistency of different parameters internally. Each
+ restriction string is defined as a binary relation with a set of
+ operators, including {'==', '!=', '<', '<=', '>', '>='}.
+ """
+ self._locked = False
+ self._restrictions = []
+ if restrictions:
+ self._restrictions = restrictions
+ if default_params is None:
+ default_params = {}
+ self.override(default_params, is_strict=False)
+
+ def _set(self, k, v):
+ if isinstance(v, dict):
+ self.__dict__[k] = ParamsDict(v)
+ else:
+ self.__dict__[k] = copy.deepcopy(v)
+
+ def __setattr__(self, k, v):
+ """Sets the value of the existing key.
+
+ Note that this does not allow directly defining a new key. Use the
+ `override` method with `is_strict=False` instead.
+
+ Args:
+ k: the key string.
+ v: the value to be used to set the key `k`.
+
+ Raises:
+ KeyError: if k is not defined in the ParamsDict.
+ """
+ if k not in ParamsDict.RESERVED_ATTR:
+ if k not in self.__dict__.keys():
+ raise KeyError('The key `%{}` does not exist. '
+ 'To extend the existing keys, use '
+ '`override` with `is_strict` = True.'.format(k))
+ if self._locked:
+ raise ValueError('The ParamsDict has been locked. '
+ 'No change is allowed.')
+ self._set(k, v)
+
+ def __getattr__(self, k):
+ """Gets the value of the existing key.
+
+ Args:
+ k: the key string.
+
+ Returns:
+ the value of the key.
+
+ Raises:
+ AttributeError: if k is not defined in the ParamsDict.
+ """
+ if k not in self.__dict__.keys():
+ raise AttributeError('The key `{}` does not exist. '.format(k))
+ return self.__dict__[k]
+
+ def __contains__(self, key):
+ """Implements the membership test operator."""
+ return key in self.__dict__
+
+ def get(self, key, value=None):
+ """Accesses through built-in dictionary get method."""
+ return self.__dict__.get(key, value)
+
+ def __delattr__(self, k):
+ """Deletes the key and removes its values.
+
+ Args:
+ k: the key string.
+
+ Raises:
+ AttributeError: if k is reserverd or not defined in the ParamsDict.
+ ValueError: if the ParamsDict instance has been locked.
+ """
+ if k in ParamsDict.RESERVED_ATTR:
+ raise AttributeError(
+ 'The key `{}` is reserved. No change is allowes. '.format(k))
+ if k not in self.__dict__.keys():
+ raise AttributeError('The key `{}` does not exist. '.format(k))
+ if self._locked:
+ raise ValueError('The ParamsDict has been locked. No change is allowed.')
+ del self.__dict__[k]
+
+ def override(self, override_params, is_strict=True):
+ """Override the ParamsDict with a set of given params.
+
+ Args:
+ override_params: a dict or a ParamsDict specifying the parameters to be
+ overridden.
+ is_strict: a boolean specifying whether override is strict or not. If
+ True, keys in `override_params` must be present in the ParamsDict. If
+ False, keys in `override_params` can be different from what is currently
+ defined in the ParamsDict. In this case, the ParamsDict will be extended
+ to include the new keys.
+ """
+ if self._locked:
+ raise ValueError('The ParamsDict has been locked. No change is allowed.')
+ if isinstance(override_params, ParamsDict):
+ override_params = override_params.as_dict()
+ self._override(override_params, is_strict) # pylint: disable=protected-access
+
+ def _override(self, override_dict, is_strict=True):
+ """The implementation of `override`."""
+ for k, v in six.iteritems(override_dict):
+ if k in ParamsDict.RESERVED_ATTR:
+ raise KeyError('The key `%{}` is internally reserved. '
+ 'Can not be overridden.')
+ if k not in self.__dict__.keys():
+ if is_strict:
+ raise KeyError('The key `{}` does not exist. '
+ 'To extend the existing keys, use '
+ '`override` with `is_strict` = False.'.format(k))
+ else:
+ self._set(k, v)
+ else:
+ if isinstance(v, dict):
+ self.__dict__[k]._override(v, is_strict) # pylint: disable=protected-access
+ elif isinstance(v, ParamsDict):
+ self.__dict__[k]._override(v.as_dict(), is_strict) # pylint: disable=protected-access
+ else:
+ self.__dict__[k] = copy.deepcopy(v)
+
+ def lock(self):
+ """Makes the ParamsDict immutable."""
+ self._locked = True
+
+ def as_dict(self):
+ """Returns a dict representation of ParamsDict.
+
+ For the nested ParamsDict, a nested dict will be returned.
+ """
+ params_dict = {}
+ for k, v in six.iteritems(self.__dict__):
+ if k not in ParamsDict.RESERVED_ATTR:
+ if isinstance(v, ParamsDict):
+ params_dict[k] = v.as_dict()
+ else:
+ params_dict[k] = copy.deepcopy(v)
+ return params_dict
+
+ def validate(self):
+ """Validate the parameters consistency based on the restrictions.
+
+ This method validates the internal consistency using the pre-defined list of
+ restrictions. A restriction is defined as a string which specifies a binary
+ operation. The supported binary operations are {'==', '!=', '<', '<=', '>',
+ '>='}. Note that the meaning of these operators are consistent with the
+ underlying Python immplementation. Users should make sure the define
+ restrictions on their type make sense.
+
+ For example, for a ParamsDict like the following
+ ```
+ a:
+ a1: 1
+ a2: 2
+ b:
+ bb:
+ bb1: 10
+ bb2: 20
+ ccc:
+ a1: 1
+ a3: 3
+ ```
+ one can define two restrictions like this
+ ['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']
+
+ What it enforces are:
+ - a.a1 = 1 == b.ccc.a1 = 1
+ - a.a2 = 2 <= b.bb.bb2 = 20
+
+ Raises:
+ KeyError: if any of the following happens
+ (1) any of parameters in any of restrictions is not defined in
+ ParamsDict,
+ (2) any inconsistency violating the restriction is found.
+ ValueError: if the restriction defined in the string is not supported.
+ """
+
+ def _get_kv(dotted_string, params_dict):
+ """Get keys and values indicated by dotted_string."""
+ if _CONST_VALUE_RE.match(dotted_string) is not None:
+ const_str = dotted_string
+ if const_str == 'None':
+ constant = None
+ else:
+ constant = float(const_str)
+ return None, constant
+ else:
+ tokenized_params = dotted_string.split('.')
+ v = params_dict
+ for t in tokenized_params:
+ v = v[t]
+ return tokenized_params[-1], v
+
+ def _get_kvs(tokens, params_dict):
+ if len(tokens) != 2:
+ raise ValueError('Only support binary relation in restriction.')
+ stripped_tokens = [t.strip() for t in tokens]
+ left_k, left_v = _get_kv(stripped_tokens[0], params_dict)
+ right_k, right_v = _get_kv(stripped_tokens[1], params_dict)
+ return left_k, left_v, right_k, right_v
+
+ params_dict = self.as_dict()
+ for restriction in self._restrictions:
+ if '==' in restriction:
+ tokens = restriction.split('==')
+ _, left_v, _, right_v = _get_kvs(tokens, params_dict)
+ if left_v != right_v:
+ raise KeyError(
+ 'Found inconsistency between key `{}` and key `{}`.'.format(
+ tokens[0], tokens[1]))
+ elif '!=' in restriction:
+ tokens = restriction.split('!=')
+ _, left_v, _, right_v = _get_kvs(tokens, params_dict)
+ if left_v == right_v:
+ raise KeyError(
+ 'Found inconsistency between key `{}` and key `{}`.'.format(
+ tokens[0], tokens[1]))
+ elif '<' in restriction:
+ tokens = restriction.split('<')
+ _, left_v, _, right_v = _get_kvs(tokens, params_dict)
+ if left_v >= right_v:
+ raise KeyError(
+ 'Found inconsistency between key `{}` and key `{}`.'.format(
+ tokens[0], tokens[1]))
+ elif '<=' in restriction:
+ tokens = restriction.split('<=')
+ _, left_v, _, right_v = _get_kvs(tokens, params_dict)
+ if left_v > right_v:
+ raise KeyError(
+ 'Found inconsistency between key `{}` and key `{}`.'.format(
+ tokens[0], tokens[1]))
+ elif '>' in restriction:
+ tokens = restriction.split('>')
+ _, left_v, _, right_v = _get_kvs(tokens, params_dict)
+ if left_v <= right_v:
+ raise KeyError(
+ 'Found inconsistency between key `{}` and key `{}`.'.format(
+ tokens[0], tokens[1]))
+ elif '>=' in restriction:
+ tokens = restriction.split('>=')
+ _, left_v, _, right_v = _get_kvs(tokens, params_dict)
+ if left_v < right_v:
+ raise KeyError(
+ 'Found inconsistency between key `{}` and key `{}`.'.format(
+ tokens[0], tokens[1]))
+ else:
+ raise ValueError('Unsupported relation in restriction.')
+
+
+def read_yaml_to_params_dict(file_path: str):
+ """Reads a YAML file to a ParamsDict."""
+ with tf.io.gfile.GFile(file_path, 'r') as f:
+ params_dict = yaml.load(f, Loader=_LOADER)
+ return ParamsDict(params_dict)
+
+
+def save_params_dict_to_yaml(params, file_path):
+ """Saves the input ParamsDict to a YAML file."""
+ with tf.io.gfile.GFile(file_path, 'w') as f:
+
+ def _my_list_rep(dumper, data):
+ # u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
+ return dumper.represent_sequence(
+ u'tag:yaml.org,2002:seq', data, flow_style=True)
+
+ yaml.add_representer(list, _my_list_rep)
+ yaml.dump(params.as_dict(), f, default_flow_style=False)
+
+
+def nested_csv_str_to_json_str(csv_str):
+ """Converts a nested (using '.') comma-separated k=v string to a JSON string.
+
+ Converts a comma-separated string of key/value pairs that supports
+ nesting of keys to a JSON string. Nesting is implemented using
+ '.' between levels for a given key.
+
+ Spacing between commas and = is supported (e.g. there is no difference between
+ "a=1,b=2", "a = 1, b = 2", or "a=1, b=2") but there should be no spaces before
+ keys or after values (e.g. " a=1,b=2" and "a=1,b=2 " are not supported).
+
+ Note that this will only support values supported by CSV, meaning
+ values such as nested lists (e.g. "a=[[1,2,3],[4,5,6]]") are not
+ supported. Strings are supported as well, e.g. "a='hello'".
+
+ An example conversion would be:
+
+ "a=1, b=2, c.a=2, c.b=3, d.a.a=5"
+
+ to
+
+ "{ a: 1, b : 2, c: {a : 2, b : 3}, d: {a: {a : 5}}}"
+
+ Args:
+ csv_str: the comma separated string.
+
+ Returns:
+ the converted JSON string.
+
+ Raises:
+ ValueError: If csv_str is not in a comma separated string or
+ if the string is formatted incorrectly.
+ """
+ if not csv_str:
+ return ''
+
+ array_param_map = collections.defaultdict(str)
+ max_index_map = collections.defaultdict(str)
+ formatted_entries = []
+ nested_map = collections.defaultdict(list)
+ pos = 0
+ while pos < len(csv_str):
+ m = _PARAM_RE.match(csv_str, pos)
+ if not m:
+ raise ValueError('Malformed hyperparameter value while parsing '
+ 'CSV string: %s' % csv_str[pos:])
+ pos = m.end()
+ # Parse the values.
+ m_dict = m.groupdict()
+ name = m_dict['name']
+ v = m_dict['val']
+ bracketed_index = m_dict['bracketed_index']
+ # If we reach the name of the array.
+ if bracketed_index and '.' not in name:
+ # Extract the array's index by removing '[' and ']'
+ index = int(bracketed_index[1:-1])
+ if '.' in v:
+ numeric_val = float(v)
+ else:
+ numeric_val = int(v)
+ # Add the value to the array.
+ if name not in array_param_map:
+ max_index_map[name] = index
+ array_param_map[name] = [None] * (index + 1)
+ array_param_map[name][index] = numeric_val
+ elif index < max_index_map[name]:
+ array_param_map[name][index] = numeric_val
+ else:
+ array_param_map[name] += [None] * (index - max_index_map[name])
+ array_param_map[name][index] = numeric_val
+ max_index_map[name] = index
+ continue
+
+ # If a GCS path (e.g. gs://...) is provided, wrap this in quotes
+ # as yaml.load would otherwise throw an exception
+ if re.match(r'(?=[^\"\'])(?=[gs://])', v):
+ v = '\'{}\''.format(v)
+
+ name_nested = name.split('.')
+ if len(name_nested) > 1:
+ grouping = name_nested[0]
+ if bracketed_index:
+ value = '.'.join(name_nested[1:]) + bracketed_index + '=' + v
+ else:
+ value = '.'.join(name_nested[1:]) + '=' + v
+ nested_map[grouping].append(value)
+ else:
+ formatted_entries.append('%s : %s' % (name, v))
+
+ for grouping, value in nested_map.items():
+ value = ','.join(value)
+ value = nested_csv_str_to_json_str(value)
+ formatted_entries.append('%s : %s' % (grouping, value))
+
+ # Add array parameters and check that the array is fully initialized.
+ for name in array_param_map:
+ if any(v is None for v in array_param_map[name]):
+ raise ValueError('Did not pass all values of array: %s' % name)
+ formatted_entries.append('%s : %s' % (name, array_param_map[name]))
+
+ return '{' + ', '.join(formatted_entries) + '}'
+
+
+def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
+ """Override a given ParamsDict using a dict, JSON/YAML/CSV string or YAML file.
+
+ The logic of the function is outlined below:
+ 1. Test that the input is a dict. If not, proceed to 2.
+ 2. Tests that the input is a string. If not, raise unknown ValueError
+ 2.1. Test if the string is in a CSV format. If so, parse.
+ If not, proceed to 2.2.
+ 2.2. Try loading the string as a YAML/JSON. If successful, parse to
+ dict and use it to override. If not, proceed to 2.3.
+ 2.3. Try using the string as a file path and load the YAML file.
+
+ Args:
+ params: a ParamsDict object to be overridden.
+ dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or path to
+ a YAML file specifying the parameters to be overridden.
+ is_strict: a boolean specifying whether override is strict or not.
+
+ Returns:
+ params: the overridden ParamsDict object.
+
+ Raises:
+ ValueError: if failed to override the parameters.
+ """
+ if not dict_or_string_or_yaml_file:
+ return params
+ if isinstance(dict_or_string_or_yaml_file, dict):
+ params.override(dict_or_string_or_yaml_file, is_strict)
+ elif isinstance(dict_or_string_or_yaml_file, six.string_types):
+ try:
+ dict_or_string_or_yaml_file = (
+ nested_csv_str_to_json_str(dict_or_string_or_yaml_file))
+ except ValueError:
+ pass
+ params_dict = yaml.load(dict_or_string_or_yaml_file, Loader=_LOADER)
+ if isinstance(params_dict, dict):
+ params.override(params_dict, is_strict)
+ else:
+ with tf.io.gfile.GFile(dict_or_string_or_yaml_file) as f:
+ params.override(yaml.load(f, Loader=_LOADER), is_strict)
+ else:
+ raise ValueError('Unknown input type to parse.')
+ return params
diff --git a/modeling/official/modeling/hyperparams/params_dict_test.py b/modeling/official/modeling/hyperparams/params_dict_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..1967d292722c57638f578cd610c426e39fa4b173
--- /dev/null
+++ b/modeling/official/modeling/hyperparams/params_dict_test.py
@@ -0,0 +1,446 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for params_dict.py."""
+
+import os
+
+import tensorflow as tf, tf_keras
+import yaml
+
+from official.modeling.hyperparams import params_dict
+
+
+class ParamsDictTest(tf.test.TestCase):
+
+ def test_init_from_an_empty_dict(self):
+ params = params_dict.ParamsDict()
+ with self.assertRaises(AttributeError):
+ _ = params.a
+
+ with self.assertRaises(KeyError):
+ params.a = 'aa'
+
+ def test_init_from_a_dict(self):
+ params = params_dict.ParamsDict({'a': 'aa', 'b': 2})
+ self.assertEqual(params.a, 'aa')
+ self.assertEqual(params.b, 2)
+
+ def test_init_from_a_param_dict(self):
+ params_init = params_dict.ParamsDict({'a': 'aa', 'b': 2})
+ params = params_dict.ParamsDict(params_init)
+ self.assertEqual(params.a, 'aa')
+ self.assertEqual(params.b, 2)
+
+ def test_lock(self):
+ params = params_dict.ParamsDict({'a': 1, 'b': 2, 'c': 3})
+ params.lock()
+ with self.assertRaises(ValueError):
+ params.a = 10
+ with self.assertRaises(ValueError):
+ params.override({'b': 20})
+ with self.assertRaises(ValueError):
+ del params.c
+
+ def test_setattr(self):
+ params = params_dict.ParamsDict()
+ params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
+ params.c = 'ccc'
+ self.assertEqual(params.a, 'aa')
+ self.assertEqual(params.b, 2)
+ self.assertEqual(params.c, 'ccc')
+
+ def test_getattr(self):
+ params = params_dict.ParamsDict()
+ params.override({'a': 'aa', 'b': 2, 'c': None}, is_strict=False)
+ self.assertEqual(params.a, 'aa')
+ self.assertEqual(params.b, 2)
+ self.assertEqual(params.c, None)
+
+ def test_delattr(self):
+ params = params_dict.ParamsDict()
+ params.override({
+ 'a': 'aa',
+ 'b': 2,
+ 'c': None,
+ 'd': {
+ 'd1': 1,
+ 'd2': 10
+ }
+ },
+ is_strict=False)
+ del params.c
+ self.assertEqual(params.a, 'aa')
+ self.assertEqual(params.b, 2)
+ with self.assertRaises(AttributeError):
+ _ = params.c
+ del params.d
+ with self.assertRaises(AttributeError):
+ _ = params.d.d1
+
+ def test_contains(self):
+ params = params_dict.ParamsDict()
+ params.override({'a': 'aa'}, is_strict=False)
+ self.assertIn('a', params)
+ self.assertNotIn('b', params)
+
+ def test_get(self):
+ params = params_dict.ParamsDict()
+ params.override({'a': 'aa'}, is_strict=False)
+ self.assertEqual(params.get('a'), 'aa')
+ self.assertEqual(params.get('b', 2), 2)
+ self.assertEqual(params.get('b'), None)
+
+ def test_override_is_strict_true(self):
+ params = params_dict.ParamsDict({
+ 'a': 'aa',
+ 'b': 2,
+ 'c': {
+ 'c1': 'cc',
+ 'c2': 20
+ }
+ })
+ params.override({'a': 2, 'c': {'c1': 'ccc'}}, is_strict=True)
+ self.assertEqual(params.a, 2)
+ self.assertEqual(params.c.c1, 'ccc')
+ with self.assertRaises(KeyError):
+ params.override({'d': 'ddd'}, is_strict=True)
+ with self.assertRaises(KeyError):
+ params.override({'c': {'c3': 30}}, is_strict=True)
+
+ def test_override_is_strict_false(self):
+ params = params_dict.ParamsDict({
+ 'a': 'aa',
+ 'b': 2,
+ 'c': {
+ 'c1': 10,
+ 'c2': 20
+ }
+ })
+ params.override({'a': 2, 'c': {'c3': 3000}}, is_strict=False)
+ self.assertEqual(params.a, 2)
+ self.assertEqual(params.c.c3, 3000)
+ params.override({'d': 'ddd'}, is_strict=False)
+ self.assertEqual(params.d, 'ddd')
+ params.override({'c': {'c4': 4444}}, is_strict=False)
+ self.assertEqual(params.c.c4, 4444)
+
+ def test_as_dict(self):
+ params = params_dict.ParamsDict({
+ 'a': 'aa',
+ 'b': 2,
+ 'c': {
+ 'c1': 10,
+ 'c2': 20
+ }
+ })
+ params_d = params.as_dict()
+ self.assertEqual(params_d['a'], 'aa')
+ self.assertEqual(params_d['b'], 2)
+ self.assertEqual(params_d['c']['c1'], 10)
+ self.assertEqual(params_d['c']['c2'], 20)
+
+ def test_validate(self):
+ # Raise error due to the unknown parameter.
+ with self.assertRaises(KeyError):
+ params = params_dict.ParamsDict({'a': 1, 'b': {'a': 11}}, ['a == c'])
+ params.validate()
+
+ # OK to check equality of two nested dicts.
+ params = params_dict.ParamsDict({
+ 'a': 1,
+ 'b': {
+ 'a': 10
+ },
+ 'c': {
+ 'a': 10
+ }
+ }, ['b == c'])
+
+ # Raise error due to inconsistency
+ with self.assertRaises(KeyError):
+ params = params_dict.ParamsDict({'a': 1, 'c': {'a': 10}}, ['a == c.a'])
+ params.validate()
+
+ # Valid rule.
+ params = params_dict.ParamsDict({'a': 1, 'c': {'a': 1}}, ['a == c.a'])
+
+ # Overriding violates the existing rule, raise error upon validate.
+ params.override({'a': 11})
+ with self.assertRaises(KeyError):
+ params.validate()
+
+ # Valid restrictions with constant.
+ params = params_dict.ParamsDict({
+ 'a': None,
+ 'c': {
+ 'a': 1
+ }
+ }, ['a == None', 'c.a == 1'])
+ params.validate()
+ with self.assertRaises(KeyError):
+ params = params_dict.ParamsDict({
+ 'a': 4,
+ 'c': {
+ 'a': 1
+ }
+ }, ['a == None', 'c.a == 1'])
+ params.validate()
+
+
+class ParamsDictIOTest(tf.test.TestCase):
+
+ def write_temp_file(self, filename, text):
+ temp_file = os.path.join(self.get_temp_dir(), filename)
+ with tf.io.gfile.GFile(temp_file, 'w') as writer:
+ writer.write(text)
+ return temp_file
+
+ def test_save_params_dict_to_yaml(self):
+ params = params_dict.ParamsDict({
+ 'a': 'aa',
+ 'b': 2,
+ 'c': {
+ 'c1': 10,
+ 'c2': 20
+ }
+ })
+ output_yaml_file = os.path.join(self.get_temp_dir(), 'params.yaml')
+ params_dict.save_params_dict_to_yaml(params, output_yaml_file)
+
+ with tf.io.gfile.GFile(output_yaml_file, 'r') as f:
+ params_d = yaml.load(f, Loader=yaml.Loader)
+ self.assertEqual(params.a, params_d['a'])
+ self.assertEqual(params.b, params_d['b'])
+ self.assertEqual(params.c.c1, params_d['c']['c1'])
+ self.assertEqual(params.c.c2, params_d['c']['c2'])
+
+ def test_read_yaml_to_params_dict(self):
+ input_yaml_file = self.write_temp_file(
+ 'params.yaml', r"""
+ a: 'aa'
+ b: 2
+ c:
+ c1: 10
+ c2: 20
+ """)
+ params = params_dict.read_yaml_to_params_dict(input_yaml_file)
+
+ self.assertEqual(params.a, 'aa')
+ self.assertEqual(params.b, 2)
+ self.assertEqual(params.c.c1, 10)
+ self.assertEqual(params.c.c2, 20)
+
+ def test_override_params_dict_using_dict(self):
+ params = params_dict.ParamsDict({
+ 'a': 1,
+ 'b': 2.5,
+ 'c': [3, 4],
+ 'd': 'hello',
+ 'e': False
+ })
+ override_dict = {'b': 5.2, 'c': [30, 40]}
+ params = params_dict.override_params_dict(
+ params, override_dict, is_strict=True)
+ self.assertEqual(1, params.a)
+ self.assertEqual(5.2, params.b)
+ self.assertEqual([30, 40], params.c)
+ self.assertEqual('hello', params.d)
+ self.assertEqual(False, params.e)
+
+ def test_override_params_dict_using_yaml_string(self):
+ params = params_dict.ParamsDict({
+ 'a': 1,
+ 'b': 2.5,
+ 'c': [3, 4],
+ 'd': 'hello',
+ 'e': False
+ })
+ override_yaml_string = "'b': 5.2\n'c': [30, 40]"
+ params = params_dict.override_params_dict(
+ params, override_yaml_string, is_strict=True)
+ self.assertEqual(1, params.a)
+ self.assertEqual(5.2, params.b)
+ self.assertEqual([30, 40], params.c)
+ self.assertEqual('hello', params.d)
+ self.assertEqual(False, params.e)
+
+ def test_override_params_dict_using_json_string(self):
+ params = params_dict.ParamsDict({
+ 'a': 1,
+ 'b': {
+ 'b1': 2,
+ 'b2': [2, 3],
+ },
+ 'd': {
+ 'd1': {
+ 'd2': 'hello'
+ }
+ },
+ 'e': False
+ })
+ override_json_string = "{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
+ params = params_dict.override_params_dict(
+ params, override_json_string, is_strict=True)
+ self.assertEqual(1, params.a)
+ self.assertEqual(2, params.b.b1)
+ self.assertEqual([3, 4], params.b.b2)
+ self.assertEqual('hi', params.d.d1.d2)
+ self.assertEqual(False, params.e)
+
+ def test_override_params_dict_using_csv_string(self):
+ params = params_dict.ParamsDict({
+ 'a': 1,
+ 'b': {
+ 'b1': 2,
+ 'b2': [2, 3],
+ },
+ 'd': {
+ 'd1': {
+ 'd2': 'hello'
+ }
+ },
+ 'e': False
+ })
+ override_csv_string = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
+ params = params_dict.override_params_dict(
+ params, override_csv_string, is_strict=True)
+ self.assertEqual(1, params.a)
+ self.assertEqual(2, params.b.b1)
+ self.assertEqual([3, 4], params.b.b2)
+ self.assertEqual('hi, world', params.d.d1.d2)
+ self.assertEqual('gs://test', params.e)
+ # Test different float formats
+ override_csv_string = 'b.b2=-1.e-3, d.d1.d2=+0.001, e=1e+3, a=-1.5E-3'
+ params = params_dict.override_params_dict(
+ params, override_csv_string, is_strict=True)
+ self.assertEqual(-1e-3, params.b.b2)
+ self.assertEqual(0.001, params.d.d1.d2)
+ self.assertEqual(1e3, params.e)
+ self.assertEqual(-1.5e-3, params.a)
+
+ def test_override_params_dict_using_yaml_file(self):
+ params = params_dict.ParamsDict({
+ 'a': 1,
+ 'b': 2.5,
+ 'c': [3, 4],
+ 'd': 'hello',
+ 'e': False
+ })
+ override_yaml_file = self.write_temp_file(
+ 'params.yaml', r"""
+ b: 5.2
+ c: [30, 40]
+ """)
+ params = params_dict.override_params_dict(
+ params, override_yaml_file, is_strict=True)
+ self.assertEqual(1, params.a)
+ self.assertEqual(5.2, params.b)
+ self.assertEqual([30, 40], params.c)
+ self.assertEqual('hello', params.d)
+ self.assertEqual(False, params.e)
+
+
+class IOTest(tf.test.TestCase):
+
+ def test_basic_csv_str_to_json_str(self):
+ csv_str = 'a=1,b=2,c=3'
+ json_str = '{a : 1, b : 2, c : 3}'
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ self.assertEqual(converted_csv_str, json_str)
+
+ def test_basic_csv_str_load(self):
+ csv_str = 'a=1,b=2,c=3'
+ expected_output = {'a': 1, 'b': 2, 'c': 3}
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ converted_dict = yaml.load(converted_csv_str, Loader=yaml.Loader)
+ self.assertDictEqual(converted_dict, expected_output)
+
+ def test_basic_nested_csv_str_to_json_str(self):
+ csv_str = 'a=1,b.b1=2'
+ json_str = '{a : 1, b : {b1 : 2}}'
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ self.assertEqual(converted_csv_str, json_str)
+
+ def test_basic_nested_csv_str_load(self):
+ csv_str = 'a=1,b.b1=2,c.c1=3'
+ expected_output = {'a': 1, 'b': {'b1': 2}, 'c': {'c1': 3}}
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ converted_dict = yaml.load(converted_csv_str, Loader=yaml.Loader)
+ self.assertDictEqual(converted_dict, expected_output)
+
+ def test_complex_nested_csv_str_to_json_str(self):
+ csv_str = 'a.aa.aaa.aaaaa.a=1'
+ json_str = '{a : {aa : {aaa : {aaaaa : {a : 1}}}}}'
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ self.assertEqual(converted_csv_str, json_str)
+
+ def test_complex_nested_csv_str_load(self):
+ csv_str = 'a.aa.aaa.aaaaa.a=1,a.a=2'
+ expected_output = {'a': {'aa': {'aaa': {'aaaaa': {'a': 1}}}, 'a': 2}}
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ converted_dict = yaml.load(converted_csv_str, Loader=yaml.Loader)
+ self.assertDictEqual(converted_dict, expected_output)
+
+ def test_int_array_param_nested_csv_str_to_json_str(self):
+ csv_str = 'a.b[2]=3,a.b[0]=1,a.b[1]=2'
+ json_str = '{a : {b : [1, 2, 3]}}'
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ self.assertEqual(converted_csv_str, json_str)
+
+ def test_float_array_param_nested_csv_str_to_json_str(self):
+ csv_str = 'a.b[1]=3.45,a.b[2]=1.32,a.b[0]=2.232'
+ json_str = '{a : {b : [2.232, 3.45, 1.32]}}'
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ self.assertEqual(converted_csv_str, json_str)
+
+ def test_incomplete_array_param_nested_csv_str_to_json_str(self):
+ csv_str = 'a.b[0]=1,a.b[2]=2'
+ self.assertRaises(ValueError, params_dict.nested_csv_str_to_json_str,
+ csv_str)
+
+ def test_csv_str_load_supported_datatypes(self):
+ csv_str = 'a=1,b=2.,c=[1,2,3],d=\'hello, there\',e=\"Hi.\"'
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ converted_dict = yaml.load(converted_csv_str, Loader=yaml.Loader)
+ self.assertEqual(converted_dict['a'], 1)
+ self.assertEqual(converted_dict['b'], 2.)
+ self.assertEqual(converted_dict['c'], [1, 2, 3])
+ self.assertEqual(converted_dict['d'], 'hello, there')
+ self.assertEqual(converted_dict['e'], 'Hi.')
+
+ def test_csv_str_load_unsupported_datatypes(self):
+ csv_str = 'a=[[1,2,3],[4,5,6]]'
+ self.assertRaises(ValueError, params_dict.nested_csv_str_to_json_str,
+ csv_str)
+
+ def test_csv_str_to_json_str_spacing(self):
+ csv_str1 = 'a=1,b=2,c=3'
+ csv_str2 = 'a = 1, b = 2, c = 3'
+ json_str = '{a : 1, b : 2, c : 3}'
+ converted_csv_str1 = params_dict.nested_csv_str_to_json_str(csv_str1)
+ converted_csv_str2 = params_dict.nested_csv_str_to_json_str(csv_str2)
+ self.assertEqual(converted_csv_str1, converted_csv_str2)
+ self.assertEqual(converted_csv_str1, json_str)
+ self.assertEqual(converted_csv_str2, json_str)
+
+ def test_gcs_added_quotes(self):
+ csv_str = 'a=gs://abc, b=gs://def'
+ expected_output = '{a : \'gs://abc\', b : \'gs://def\'}'
+ converted_csv_str = params_dict.nested_csv_str_to_json_str(csv_str)
+ self.assertEqual(converted_csv_str, expected_output)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/modeling/multitask/__init__.py b/modeling/official/modeling/multitask/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/modeling/multitask/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/modeling/multitask/base_model.py b/modeling/official/modeling/multitask/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..262eec3deb69566f72b2810827c602ed361f4991
--- /dev/null
+++ b/modeling/official/modeling/multitask/base_model.py
@@ -0,0 +1,54 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Abstraction of multi-task model."""
+from typing import Text, Dict
+
+import tensorflow as tf, tf_keras
+
+
+class MultiTaskBaseModel(tf.Module):
+ """Base class that holds multi-task model computation."""
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self._sub_tasks = self._instantiate_sub_tasks()
+
+ def _instantiate_sub_tasks(self) -> Dict[Text, tf_keras.Model]:
+ """Abstract function that sets up the computation for each sub-task.
+
+ Returns:
+ A map from task name (as string) to a tf_keras.Model object that
+ represents the sub-task in the multi-task pool.
+ """
+ raise NotImplementedError(
+ "_instantiate_sub_task_models() is not implemented.")
+
+ @property
+ def sub_tasks(self):
+ """Fetch a map of task name (string) to task model (tf_keras.Model)."""
+ return self._sub_tasks
+
+ def initialize(self):
+ """Optional function that loads a pre-train checkpoint."""
+ return
+
+ def build(self):
+ """Builds the networks for tasks to make sure variables are created."""
+ # Try to build all sub tasks.
+ for task_model in self._sub_tasks.values():
+ # Assumes all the tf.Module models are built because we don't have any
+ # way to check them.
+ if isinstance(task_model, tf_keras.Model) and not task_model.built:
+ _ = task_model(task_model.inputs)
diff --git a/modeling/official/modeling/multitask/base_trainer.py b/modeling/official/modeling/multitask/base_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ac812f2b1f60a10e9212a875f3f331c8e21ded0
--- /dev/null
+++ b/modeling/official/modeling/multitask/base_trainer.py
@@ -0,0 +1,170 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Multitask base trainer implementation.
+
+The trainer derives from the Orbit `StandardTrainer` class.
+"""
+from typing import Union
+
+import gin
+import orbit
+import tensorflow as tf, tf_keras
+
+from official.modeling import optimization
+from official.modeling.multitask import base_model
+from official.modeling.multitask import multitask
+
+
+@gin.configurable
+class MultiTaskBaseTrainer(orbit.StandardTrainer):
+ """Multitask base trainer."""
+
+ def __init__(self,
+ multi_task: multitask.MultiTask,
+ multi_task_model: Union[tf_keras.Model,
+ base_model.MultiTaskBaseModel],
+ optimizer: tf.optimizers.Optimizer,
+ trainer_options=None,
+ train_datasets=None):
+ self._strategy = tf.distribute.get_strategy()
+ self._multi_task = multi_task
+ self._multi_task_model = multi_task_model
+ self._optimizer = optimizer
+
+ self._training_losses = None
+ self._training_metrics = None
+ self._global_step = orbit.utils.create_global_step()
+
+ # Creates a shadow copy of the weights to store weights moving average.
+ if isinstance(self._optimizer, optimization.ExponentialMovingAverage
+ ) and not self._optimizer.has_shadow_copy:
+ self._optimizer.shadow_copy(multi_task_model)
+
+ if hasattr(self.multi_task_model, "checkpoint_items"):
+ checkpoint_items = self.multi_task_model.checkpoint_items
+ else:
+ checkpoint_items = {}
+
+ self._checkpoint = tf.train.Checkpoint(
+ model=self.multi_task_model,
+ optimizer=self.optimizer,
+ global_step=self.global_step,
+ **checkpoint_items)
+
+ if train_datasets is None:
+ train_datasets = {}
+ for name, task in self.multi_task.tasks.items():
+ train_datasets[name] = orbit.utils.make_distributed_dataset(
+ self.strategy, task.build_inputs, task.task_config.train_data)
+
+ super().__init__(
+ train_dataset=train_datasets,
+ options=trainer_options or orbit.StandardTrainerOptions())
+
+ def train_loop_begin(self):
+ """Clean up states that hold losses and metrics."""
+ for _, train_loss_metric in self.training_losses.items():
+ train_loss_metric.reset_states()
+
+ for _, metrics in self.training_metrics.items():
+ for metric in metrics:
+ metric.reset_states()
+
+ def train_loop_end(self):
+ """Record loss and metric values per task."""
+ result = {}
+ for task_name, loss in self.training_losses.items():
+ result[task_name] = {loss.name: loss.result()}
+ for task_name, task_metrics in self.training_metrics.items():
+ result[task_name].update(
+ {metric.name: metric.result() for metric in task_metrics})
+ # Note that, the learning rate schedule is managed by the keras optimizer
+ # internally, which respects the number of backward pass as `iterations`.
+ # The learning rate schedule does not follow the trainer logical global
+ # step of multiple tasks.
+ if callable(self.optimizer.learning_rate):
+ result["learning_rate"] = self.optimizer.learning_rate(
+ self.optimizer.iterations)
+ else:
+ result["learning_rate"] = self.optimizer.learning_rate
+ return result
+
+ @property
+ def checkpoint(self):
+ """Accesses the training checkpoint."""
+ return self._checkpoint
+
+ @property
+ def training_losses(self):
+ """Access training loss metric objects for all tasks."""
+ if self._training_losses is None:
+ # Builds the per-task metrics and losses.
+ # This the total summed training loss of tasks in the joint training.
+ self._training_losses = dict(
+ total_loss=tf_keras.metrics.Mean("training_loss", dtype=tf.float32))
+ for name in self.multi_task.tasks:
+ self._training_losses[name] = tf_keras.metrics.Mean(
+ "training_loss", dtype=tf.float32)
+ return self._training_losses
+
+ @property
+ def training_metrics(self):
+ """Access training metric metric objects for all tasks."""
+ if self._training_metrics is None:
+ # Builds the per-task metrics and losses.
+ self._training_metrics = {}
+ for name, task in self.multi_task.tasks.items():
+ self._training_metrics[name] = task.build_metrics(training=True)
+ return self._training_metrics
+
+ @property
+ def strategy(self):
+ return self._strategy
+
+ @property
+ def multi_task(self):
+ return self._multi_task
+
+ @property
+ def multi_task_model(self):
+ return self._multi_task_model
+
+ @property
+ def optimizer(self):
+ return self._optimizer
+
+ @property
+ def global_step(self):
+ return self._global_step
+
+ def train_step(self, iterator_map):
+ """The default train step calling the multi-task train step.
+
+ Args:
+ iterator_map: a dictionary of task names and per-task dataset iterators.
+ """
+
+ def step_fn(inputs):
+ losses = self.multi_task.joint_train_step(
+ inputs,
+ multi_task_model=self.multi_task_model,
+ optimizer=self.optimizer,
+ task_metrics=self.training_metrics)
+ for key, loss in losses.items():
+ self.training_losses[key].update_state(loss)
+
+ self.strategy.run(
+ step_fn, args=(tf.nest.map_structure(next, iterator_map),))
+ self.global_step.assign_add(1)
diff --git a/modeling/official/modeling/multitask/base_trainer_test.py b/modeling/official/modeling/multitask/base_trainer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..039bdb012135e8bcf0f04711e2bc1f1ccc0a9b8f
--- /dev/null
+++ b/modeling/official/modeling/multitask/base_trainer_test.py
@@ -0,0 +1,90 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for multitask.base_trainer."""
+from absl.testing import parameterized
+import tensorflow as tf, tf_keras
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.modeling.multitask import base_trainer
+from official.modeling.multitask import configs
+from official.modeling.multitask import multitask
+from official.modeling.multitask import test_utils
+
+
+def all_strategy_combinations():
+ return combinations.combine(
+ distribution=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.cloud_tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ ],
+ mode="eager",
+ )
+
+
+class BaseTrainerTest(tf.test.TestCase, parameterized.TestCase):
+
+ @combinations.generate(all_strategy_combinations())
+ def test_multitask_joint_trainer(self, distribution):
+ with distribution.scope():
+ tasks = [
+ test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
+ test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar")
+ ]
+ task_weights = {"foo": 1.0, "bar": 1.0}
+ test_multitask = multitask.MultiTask(
+ tasks=tasks, task_weights=task_weights)
+ test_optimizer = tf_keras.optimizers.SGD(0.1)
+ model = test_utils.MockMultiTaskModel()
+ test_trainer = base_trainer.MultiTaskBaseTrainer(
+ multi_task=test_multitask,
+ multi_task_model=model,
+ optimizer=test_optimizer)
+ results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
+ self.assertContainsSubset(["training_loss", "bar_acc"],
+ results["bar"].keys())
+ self.assertContainsSubset(["training_loss", "foo_acc"],
+ results["foo"].keys())
+
+ def test_trainer_with_configs(self):
+ config = configs.MultiTaskConfig(
+ task_routines=(configs.TaskRoutine(
+ task_name="foo",
+ task_config=test_utils.FooConfig(),
+ task_weight=0.5),
+ configs.TaskRoutine(
+ task_name="bar",
+ task_config=test_utils.BarConfig(),
+ task_weight=0.5)))
+ test_multitask = multitask.MultiTask.from_config(config)
+ test_optimizer = tf_keras.optimizers.SGD(0.1)
+ model = test_utils.MockMultiTaskModel()
+ test_trainer = base_trainer.MultiTaskBaseTrainer(
+ multi_task=test_multitask,
+ multi_task_model=model,
+ optimizer=test_optimizer)
+ results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
+ self.assertContainsSubset(["training_loss", "bar_acc"],
+ results["bar"].keys())
+ self.assertContainsSubset(["training_loss", "foo_acc"],
+ results["foo"].keys())
+ self.assertEqual(test_multitask.task_weight("foo"), 0.5)
+ self.assertEqual(test_trainer.global_step.numpy(), 5)
+ self.assertIn("learning_rate", results)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/modeling/multitask/configs.py b/modeling/official/modeling/multitask/configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcda4a568e174ff9da25cff6124aed9fb51ab2e0
--- /dev/null
+++ b/modeling/official/modeling/multitask/configs.py
@@ -0,0 +1,98 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Configuration definitions for multi-task training."""
+import dataclasses
+from typing import Optional, Tuple
+
+from official.core import config_definitions as cfg
+from official.modeling import hyperparams
+from official.modeling.privacy import configs as dp_configs
+
+
+@dataclasses.dataclass
+class TaskRoutine(hyperparams.Config):
+ # TODO(hongkuny): deprecate the task_name once we migrated client code.
+ task_name: str = ""
+ task_config: cfg.TaskConfig = None
+ eval_steps: Optional[int] = None
+ task_weight: Optional[float] = 1.0
+
+
+@dataclasses.dataclass
+class MultiTaskConfig(hyperparams.Config):
+ init_checkpoint: str = ""
+ model: hyperparams.Config = None
+ task_routines: Tuple[TaskRoutine, ...] = ()
+ # Configs for differential privacy
+ # These configs are only effective if you use create_optimizer in
+ # tensorflow_models/official/core/base_task.py
+ # DEPRECATED b/264611883
+ differential_privacy_config: Optional[
+ dp_configs.DifferentialPrivacyConfig] = None
+
+
+@dataclasses.dataclass
+class ProportionalSampleConfig(hyperparams.Config):
+ alpha: float = 1.0
+
+
+@dataclasses.dataclass
+class AnnealingSampleConfig(hyperparams.Config):
+ steps_per_epoch: int = 5
+ total_steps: int = 20
+
+
+@dataclasses.dataclass
+class TaskSamplingConfig(hyperparams.OneOfConfig):
+ type: str = ""
+ uniform: hyperparams.Config = dataclasses.field(
+ default_factory=hyperparams.Config
+ )
+ proportional: ProportionalSampleConfig = dataclasses.field(
+ default_factory=ProportionalSampleConfig
+ )
+ annealing: AnnealingSampleConfig = dataclasses.field(
+ default_factory=AnnealingSampleConfig
+ )
+
+
+@dataclasses.dataclass
+class MultiTaskTrainerConfig(cfg.TrainerConfig):
+ trainer_type: str = "interleaving"
+ task_sampler: TaskSamplingConfig = dataclasses.field(
+ default_factory=lambda: TaskSamplingConfig(type="proportional")
+ )
+
+
+@dataclasses.dataclass
+class MultiTaskExperimentConfig(hyperparams.Config):
+ """An experiment config for multi-task training and multi-task evaluation."""
+ task: MultiTaskConfig = dataclasses.field(default_factory=MultiTaskConfig)
+ trainer: MultiTaskTrainerConfig = dataclasses.field(
+ default_factory=MultiTaskTrainerConfig
+ )
+ runtime: cfg.RuntimeConfig = dataclasses.field(
+ default_factory=cfg.RuntimeConfig
+ )
+
+
+@dataclasses.dataclass
+class MultiEvalExperimentConfig(cfg.ExperimentConfig):
+ """An experiment config for single-task training and multi-task evaluation.
+
+ Attributes:
+ eval_tasks: individual evaluation tasks.
+ """
+ eval_tasks: Tuple[TaskRoutine, ...] = ()
diff --git a/modeling/official/modeling/multitask/evaluator.py b/modeling/official/modeling/multitask/evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e0ddb546820648598ea3366de256ba16db4c6cc
--- /dev/null
+++ b/modeling/official/modeling/multitask/evaluator.py
@@ -0,0 +1,180 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Multitask Evaluator implementation.
+
+The evaluator implements the Orbit `AbstractEvaluator` interface.
+"""
+from typing import Dict, List, Optional, Union
+import gin
+import orbit
+import tensorflow as tf, tf_keras
+
+from official.core import base_task
+from official.core import train_utils
+from official.modeling.multitask import base_model
+
+
+@gin.configurable
+class MultiTaskEvaluator(orbit.AbstractEvaluator):
+ """Implements the common trainer shared for TensorFlow models."""
+
+ def __init__(
+ self,
+ eval_tasks: List[base_task.Task],
+ model: Union[tf_keras.Model, base_model.MultiTaskBaseModel],
+ global_step: Optional[tf.Variable] = None,
+ eval_steps: Optional[Dict[str, int]] = None,
+ checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None):
+ """Initialize common trainer for TensorFlow models.
+
+ Args:
+ eval_tasks: A list of tasks to evaluate.
+ model: tf_keras.Model instance.
+ global_step: the global step variable.
+ eval_steps: a dictionary of steps to run eval keyed by task names.
+ checkpoint_exporter: an object that has the `maybe_export_checkpoint`
+ interface.
+ """
+ # Gets the current distribution strategy. If not inside any strategy scope,
+ # it gets a single-replica no-op strategy.
+ self._strategy = tf.distribute.get_strategy()
+ self._tasks = eval_tasks
+ self._model = model
+ self._global_step = global_step or orbit.utils.create_global_step()
+ self._checkpoint_exporter = checkpoint_exporter
+ if hasattr(self.model, "checkpoint_items"):
+ checkpoint_items = self.model.checkpoint_items
+ else:
+ checkpoint_items = {}
+
+ self._checkpoint = tf.train.Checkpoint(
+ model=self.model,
+ global_step=self.global_step,
+ **checkpoint_items)
+
+ self._validation_losses = None
+ self._validation_metrics = None
+
+ # Builds per-task datasets.
+ self.eval_datasets = {}
+ self.eval_steps = eval_steps or {}
+ for task in self.tasks:
+ self.eval_datasets[task.name] = orbit.utils.make_distributed_dataset(
+ self.strategy, task.build_inputs, task.task_config.validation_data)
+
+ # Builds per-task validation loops.
+ def get_function(task_name, task):
+
+ task_metrics = self.validation_metrics[task_name]
+ task_loss = self.validation_losses[task_name]
+ if isinstance(self.model, base_model.MultiTaskBaseModel):
+ model = self.model.sub_tasks[task_name]
+ else:
+ model = self.model
+
+ def step_fn(inputs):
+ logs = task.validation_step(inputs, model=model, metrics=task_metrics)
+ task_loss.update_state(logs[task.loss])
+ return logs
+
+ @tf.function
+ def eval_step_fn(iterator):
+ distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
+ return tf.nest.map_structure(self.strategy.experimental_local_results,
+ distributed_outputs)
+
+ return orbit.utils.create_loop_fn(eval_step_fn)
+
+ self.task_fns = {
+ task.name: get_function(task.name, task) for task in self.tasks
+ }
+
+ @property
+ def strategy(self):
+ return self._strategy
+
+ @property
+ def tasks(self):
+ return self._tasks
+
+ @property
+ def model(self):
+ return self._model
+
+ @property
+ def global_step(self):
+ return self._global_step
+
+ @property
+ def validation_losses(self):
+ """Accesses the validation loss metric object."""
+ if self._validation_losses is None:
+ # Builds the per-task metrics and losses.
+ self._validation_losses = {}
+ for task in self.tasks:
+ self._validation_losses[task.name] = tf_keras.metrics.Mean(
+ "validation_loss", dtype=tf.float32)
+ return self._validation_losses
+
+ @property
+ def validation_metrics(self):
+ """Accesses all validation metric metric objects."""
+ if self._validation_metrics is None:
+ # Builds the per-task metrics and losses.
+ self._validation_metrics = {}
+ for task in self.tasks:
+ self._validation_metrics[task.name] = task.build_metrics(training=False)
+ return self._validation_metrics
+
+ @property
+ def checkpoint(self):
+ """Accesses the training checkpoint."""
+ return self._checkpoint
+
+ def evaluate(self, num_steps: tf.Tensor):
+ """Performs evaluation for each `EvalTask`."""
+ for metric in self.validation_losses.values():
+ metric.reset_states()
+ for metrics in self.validation_metrics.values():
+ for metric in metrics:
+ metric.reset_states()
+ results = {}
+ eval_iters = tf.nest.map_structure(iter, self.eval_datasets)
+
+ for task in self.tasks:
+ outputs = None
+ name = task.name
+ eval_iter = eval_iters[name]
+ task_eval_steps = self.eval_steps.get(name, None) or num_steps
+ outputs = self.task_fns[name](
+ eval_iter,
+ task_eval_steps,
+ state=outputs,
+ reduce_fn=task.aggregate_logs)
+ task_metrics = self.validation_metrics[name]
+ task_loss = self.validation_losses[name]
+ logs = {}
+ for metric in task_metrics + [task_loss]:
+ logs[metric.name] = metric.result()
+ if outputs:
+ metrics = task.reduce_aggregated_logs(
+ outputs, global_step=self.global_step)
+ logs.update(metrics)
+ results[name] = logs
+
+ if self._checkpoint_exporter:
+ self._checkpoint_exporter.maybe_export_checkpoint(
+ self.checkpoint, results, self.global_step.numpy())
+ return results
diff --git a/modeling/official/modeling/multitask/evaluator_test.py b/modeling/official/modeling/multitask/evaluator_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e37b12116e1db146a273eb41487ca51e883b014
--- /dev/null
+++ b/modeling/official/modeling/multitask/evaluator_test.py
@@ -0,0 +1,133 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for multitask.evaluator."""
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.core import base_task
+from official.core import config_definitions as cfg
+from official.modeling.multitask import evaluator
+
+
+def all_strategy_combinations():
+ return combinations.combine(
+ distribution=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.cloud_tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ ],
+ mode="eager",
+ )
+
+
+class MockModel(tf_keras.Model):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.dense = tf_keras.layers.Dense(1)
+
+ def call(self, inputs):
+ print(inputs, type(inputs))
+ if "y" in inputs:
+ self.add_loss(tf.zeros((1,), dtype=tf.float32))
+ else:
+ self.add_loss(tf.ones((1,), dtype=tf.float32))
+ return self.dense(inputs["x"])
+
+
+class MockTask(base_task.Task):
+ """Mock task object for testing."""
+
+ def build_metrics(self, training: bool = True):
+ del training
+ return [tf_keras.metrics.Accuracy(name="acc")]
+
+ def build_inputs(self, params):
+
+ def generate_data(_):
+ x = tf.zeros(shape=(2,), dtype=tf.float32)
+ label = tf.zeros([1], dtype=tf.int32)
+ if self.name == "bar":
+ return dict(x=x, y=x), label
+ else:
+ return dict(x=x), label
+
+ dataset = tf.data.Dataset.range(1)
+ dataset = dataset.repeat()
+ dataset = dataset.map(
+ generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset.prefetch(buffer_size=1).batch(2, drop_remainder=True)
+
+ def validation_step(self, inputs, model: tf_keras.Model, metrics=None):
+ logs = super().validation_step(inputs, model, metrics)
+ logs["counter"] = tf.ones((1,), dtype=tf.float32)
+ return logs
+
+ def aggregate_logs(self, state, step_outputs):
+ if state is None:
+ state = {}
+ for key, value in step_outputs.items():
+ if key not in state:
+ state[key] = []
+ state[key].append(
+ np.concatenate([np.expand_dims(v.numpy(), axis=0) for v in value]))
+ return state
+
+ def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
+ for k, v in aggregated_logs.items():
+ aggregated_logs[k] = np.sum(np.stack(v, axis=0))
+ return aggregated_logs
+
+
+class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
+
+ @combinations.generate(all_strategy_combinations())
+ def test_multitask_evaluator(self, distribution):
+ with distribution.scope():
+ tasks = [
+ MockTask(params=cfg.TaskConfig(), name="bar"),
+ MockTask(params=cfg.TaskConfig(), name="foo")
+ ]
+ model = MockModel()
+ test_evaluator = evaluator.MultiTaskEvaluator(
+ eval_tasks=tasks, model=model)
+ results = test_evaluator.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
+ self.assertContainsSubset(["validation_loss", "acc"], results["bar"].keys())
+ self.assertContainsSubset(["validation_loss", "acc"], results["foo"].keys())
+ self.assertEqual(results["bar"]["validation_loss"], 0.0)
+ self.assertEqual(results["foo"]["validation_loss"], 1.0)
+
+ @combinations.generate(all_strategy_combinations())
+ def test_multitask_evaluator_numpy_metrics(self, distribution):
+ with distribution.scope():
+ tasks = [
+ MockTask(params=cfg.TaskConfig(), name="bar"),
+ MockTask(params=cfg.TaskConfig(), name="foo")
+ ]
+ model = MockModel()
+ test_evaluator = evaluator.MultiTaskEvaluator(
+ eval_tasks=tasks, model=model)
+ results = test_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
+ self.assertEqual(results["bar"]["counter"],
+ 5. * distribution.num_replicas_in_sync)
+ self.assertEqual(results["foo"]["counter"],
+ 5. * distribution.num_replicas_in_sync)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/modeling/multitask/interleaving_trainer.py b/modeling/official/modeling/multitask/interleaving_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6347e98e17895255224b71dac09b625c168f227c
--- /dev/null
+++ b/modeling/official/modeling/multitask/interleaving_trainer.py
@@ -0,0 +1,111 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Multitask trainer that interleaves each task's train step."""
+from typing import Union
+import gin
+import orbit
+import tensorflow as tf, tf_keras
+from official.modeling.multitask import base_model
+from official.modeling.multitask import base_trainer
+from official.modeling.multitask import multitask
+from official.modeling.multitask import task_sampler as sampler
+
+
+@gin.configurable
+class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
+ """MultiTask trainer that interleaves task update."""
+
+ def __init__(self,
+ multi_task: multitask.MultiTask,
+ multi_task_model: Union[tf_keras.Model,
+ base_model.MultiTaskBaseModel],
+ optimizer: Union[tf.optimizers.Optimizer,
+ tf_keras.optimizers.experimental.Optimizer,
+ tf_keras.optimizers.legacy.Optimizer],
+ task_sampler: sampler.TaskSampler,
+ trainer_options=None):
+ super().__init__(
+ multi_task=multi_task,
+ multi_task_model=multi_task_model,
+ optimizer=optimizer,
+ trainer_options=trainer_options)
+ self._task_sampler = task_sampler
+
+ # Build per task train step.
+ def _get_task_step(task_name, task):
+
+ def step_fn(inputs):
+ if isinstance(self.multi_task_model, base_model.MultiTaskBaseModel):
+ task_model = self.multi_task_model.sub_tasks[task_name]
+ else:
+ task_model = self.multi_task_model
+ task_logs = task.train_step(
+ inputs,
+ model=task_model,
+ optimizer=self.optimizer,
+ metrics=self.training_metrics[task_name])
+ self.training_losses[task_name].update_state(task_logs[task.loss])
+
+ return step_fn
+
+ self._task_train_step_map = {
+ name: _get_task_step(name, task)
+ for name, task in self.multi_task.tasks.items()
+ }
+
+ # TODO(haozhangthu): Add taskwise step counter to train_loop_end for logging
+ # on TensorBoard.
+ self._task_step_counters = {
+ name: orbit.utils.create_global_step() for name in self.multi_task.tasks
+ }
+
+ # If the new Keras optimizer is used, we require all model variables are
+ # created before the training and let the optimizer to create the slot
+ # variable all together.
+ if isinstance(optimizer, tf_keras.optimizers.experimental.Optimizer):
+ multi_task_model.build()
+ optimizer.build(multi_task_model.trainable_variables)
+
+ def task_step_counter(self, name):
+ return self._task_step_counters[name]
+
+ def train_step(self, iterator_map):
+ # Sample one task to train according to a multinomial distribution
+ rn = tf.random.stateless_uniform(shape=[], seed=(0, self.global_step))
+ cumulative_sample_distribution = self._task_sampler.task_cumulative_distribution(
+ self.global_step)
+ # Prepend a [0.0] for indexing convenience.
+ cumulative_sample_distribution = tf.concat(
+ [tf.constant([0.0], dtype=tf.float32), cumulative_sample_distribution],
+ axis=0)
+
+ for idx, (name, _) in enumerate(self.multi_task.tasks.items()):
+ begin = cumulative_sample_distribution[idx]
+ end = cumulative_sample_distribution[idx + 1]
+ if rn >= begin and rn < end:
+ self._strategy.run(
+ self._task_train_step_map[name], args=(next(iterator_map[name]),))
+ self.global_step.assign_add(1)
+ self.task_step_counter(name).assign_add(1)
+
+ def train_loop_end(self):
+ """Record loss and metric values per task."""
+ result = super().train_loop_end()
+ # Interleaving training does not have a good semantic for `total_loss`. In
+ # fact, it is always zero. To avoid confusion, we filter the `total_loss`
+ # from the result logs.
+ if 'total_loss' in result:
+ result.pop('total_loss')
+ return result
diff --git a/modeling/official/modeling/multitask/interleaving_trainer_test.py b/modeling/official/modeling/multitask/interleaving_trainer_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e47c05376c95ccb19f521b0872ee2058473edb2
--- /dev/null
+++ b/modeling/official/modeling/multitask/interleaving_trainer_test.py
@@ -0,0 +1,102 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for multitask.interleaving_trainer."""
+from absl.testing import parameterized
+import tensorflow as tf, tf_keras
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.modeling.multitask import configs
+from official.modeling.multitask import interleaving_trainer
+from official.modeling.multitask import multitask
+from official.modeling.multitask import task_sampler
+from official.modeling.multitask import test_utils
+
+
+def all_strategy_combinations():
+ return combinations.combine(
+ distribution=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.cloud_tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ ],
+ mode="eager",
+ )
+
+
+class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase):
+
+ @combinations.generate(all_strategy_combinations())
+ def test_multitask_interleaving_trainer(self, distribution):
+ with distribution.scope():
+ tasks = [
+ test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
+ test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar")
+ ]
+ test_multitask = multitask.MultiTask(tasks=tasks)
+ test_optimizer = tf_keras.optimizers.SGD(0.1)
+ model = test_utils.MockMultiTaskModel()
+ sampler = task_sampler.UniformTaskSampler(
+ task_weights=test_multitask.task_weights)
+ test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
+ multi_task=test_multitask,
+ multi_task_model=model,
+ optimizer=test_optimizer,
+ task_sampler=sampler)
+ results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
+ self.assertContainsSubset(["training_loss", "bar_acc"],
+ results["bar"].keys())
+ self.assertContainsSubset(["training_loss", "foo_acc"],
+ results["foo"].keys())
+ self.assertNotIn("total_loss", results)
+
+ @combinations.generate(all_strategy_combinations())
+ def test_trainer_with_configs(self, distribution):
+ config = configs.MultiTaskConfig(
+ task_routines=(configs.TaskRoutine(
+ task_name="foo",
+ task_config=test_utils.FooConfig(),
+ task_weight=3.0),
+ configs.TaskRoutine(
+ task_name="bar",
+ task_config=test_utils.BarConfig(),
+ task_weight=1.0)))
+ with distribution.scope():
+ test_multitask = multitask.MultiTask.from_config(config)
+ test_optimizer = tf_keras.optimizers.SGD(0.1)
+ model = test_utils.MockMultiTaskModel()
+ num_step = 1000
+ sampler = task_sampler.AnnealingTaskSampler(
+ task_weights=test_multitask.task_weights,
+ steps_per_epoch=num_step/5,
+ total_steps=num_step)
+ test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
+ multi_task=test_multitask,
+ multi_task_model=model,
+ optimizer=test_optimizer,
+ task_sampler=sampler)
+ results = test_trainer.train(tf.convert_to_tensor(num_step, dtype=tf.int32))
+ self.assertContainsSubset(["training_loss", "bar_acc"],
+ results["bar"].keys())
+ self.assertContainsSubset(["training_loss", "foo_acc"],
+ results["foo"].keys())
+ self.assertEqual(test_trainer.global_step.numpy(), num_step)
+ bar_sampled_step = test_trainer.task_step_counter("bar").numpy()
+ foo_sampled_step = test_trainer.task_step_counter("foo").numpy()
+ self.assertEqual(bar_sampled_step + foo_sampled_step, num_step)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/modeling/multitask/multitask.py b/modeling/official/modeling/multitask/multitask.py
new file mode 100644
index 0000000000000000000000000000000000000000..ded8aa65113f5e8b59790bd90f7e32b21bc216fe
--- /dev/null
+++ b/modeling/official/modeling/multitask/multitask.py
@@ -0,0 +1,149 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Experimental MultiTask base class for multi-task training/evaluation."""
+import abc
+from typing import Dict, List, Optional, Text, Union
+
+import tensorflow as tf, tf_keras
+from official.core import base_task
+from official.core import config_definitions
+from official.core import task_factory
+from official.modeling import optimization
+from official.modeling.multitask import base_model
+from official.modeling.multitask import configs
+from official.modeling.privacy import configs as dp_configs
+
+OptimizationConfig = optimization.OptimizationConfig
+RuntimeConfig = config_definitions.RuntimeConfig
+DifferentialPrivacyConfig = dp_configs.DifferentialPrivacyConfig
+
+
+class MultiTask(tf.Module, metaclass=abc.ABCMeta):
+ """A multi-task class to manage multiple tasks."""
+
+ def __init__(self,
+ tasks: Union[Dict[Text, base_task.Task], List[base_task.Task]],
+ task_weights: Optional[Dict[str, Union[float, int]]] = None,
+ task_eval_steps: Optional[Dict[str, int]] = None,
+ name: Optional[str] = None):
+ """MultiTask initialization.
+
+ Args:
+ tasks: a list or a flat dict of Task.
+ task_weights: a dict of (task, task weight), task weight can be applied
+ directly during loss summation in a joint backward step, or it can be
+ used to sample task among interleaved backward step.
+ task_eval_steps: a dict of (task, eval steps).
+ name: the instance name of a MultiTask object.
+ """
+ super().__init__(name=name)
+ if isinstance(tasks, list):
+ self._tasks = {}
+ for task in tasks:
+ if task.name in self._tasks:
+ raise ValueError("Duplicated tasks found, task.name is %s" %
+ task.name)
+ self._tasks[task.name] = task
+ elif isinstance(tasks, dict):
+ self._tasks = tasks
+ else:
+ raise ValueError("The tasks argument has an invalid type: %s" %
+ type(tasks))
+ self.task_eval_steps = task_eval_steps or {}
+ self._task_weights = task_weights or {}
+ self._task_weights = dict([
+ (name, self._task_weights.get(name, 1.0)) for name in self.tasks
+ ])
+
+ @classmethod
+ def from_config(cls, config: configs.MultiTaskConfig, logging_dir=None):
+ tasks = {}
+ task_eval_steps = {}
+ task_weights = {}
+ for task_routine in config.task_routines:
+ task_name = task_routine.task_name or task_routine.task_config.name
+ tasks[task_name] = task_factory.get_task(
+ task_routine.task_config, logging_dir=logging_dir, name=task_name)
+ task_eval_steps[task_name] = task_routine.eval_steps
+ task_weights[task_name] = task_routine.task_weight
+ return cls(
+ tasks, task_eval_steps=task_eval_steps, task_weights=task_weights)
+
+ @property
+ def tasks(self):
+ return self._tasks
+
+ def task_weight(self, task_name):
+ return self._task_weights[task_name]
+
+ @property
+ def task_weights(self):
+ return self._task_weights
+
+ @classmethod
+ def create_optimizer(cls,
+ optimizer_config: OptimizationConfig,
+ runtime_config: Optional[RuntimeConfig] = None,
+ dp_config: Optional[DifferentialPrivacyConfig] = None):
+ return base_task.Task.create_optimizer(
+ optimizer_config=optimizer_config, runtime_config=runtime_config,
+ dp_config=dp_config)
+
+ def joint_train_step(self, task_inputs,
+ multi_task_model: base_model.MultiTaskBaseModel,
+ optimizer: tf_keras.optimizers.Optimizer, task_metrics,
+ **kwargs):
+ """The joint train step.
+
+ Args:
+ task_inputs: a dictionary of task names and per-task features.
+ multi_task_model: a MultiTaskBaseModel instance.
+ optimizer: a tf.optimizers.Optimizer.
+ task_metrics: a dictionary of task names and per-task metrics.
+ **kwargs: other arguments to pass through.
+
+ Returns:
+ A dictionary of losses, inculding per-task losses and their weighted sum.
+ """
+ losses = {}
+ with tf.GradientTape() as tape:
+ total_loss = 0.0
+ for name, model in multi_task_model.sub_tasks.items():
+ inputs = task_inputs[name]
+ if isinstance(inputs, tuple) and len(inputs) == 2:
+ features, labels = inputs
+ elif isinstance(inputs, dict):
+ features, labels = inputs, inputs
+ else:
+ raise ValueError("The iterator output is neither a tuple nor a "
+ "dictionary. It is not implemented to support "
+ "such outputs.")
+ outputs = model(features, training=True)
+ task_loss = self.tasks[name].build_losses(labels, outputs)
+ task_weight = self.task_weight(name)
+ total_loss += task_weight * task_loss
+ losses[name] = task_loss
+ self.tasks[name].process_metrics(task_metrics[name], labels, outputs,
+ **kwargs)
+
+ # Scales loss as the default gradients allreduce performs sum inside
+ # the optimizer.
+ scaled_loss = total_loss / tf.distribute.get_strategy(
+ ).num_replicas_in_sync
+ tvars = multi_task_model.trainable_variables
+ grads = tape.gradient(scaled_loss, tvars)
+ optimizer.apply_gradients(list(zip(grads, tvars)))
+ losses["total_loss"] = total_loss
+ return losses
diff --git a/modeling/official/modeling/multitask/task_sampler.py b/modeling/official/modeling/multitask/task_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..620ebce3ddda0ea60ab730e70328c7704a087a41
--- /dev/null
+++ b/modeling/official/modeling/multitask/task_sampler.py
@@ -0,0 +1,128 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utils to sample tasks for interleaved optimization."""
+import abc
+from typing import Union, Dict, Text
+import tensorflow as tf, tf_keras
+
+from official.modeling.multitask import configs
+
+
+class TaskSampler(tf.Module, metaclass=abc.ABCMeta):
+ """An abstract class defining task sampling API for interleaving trainer."""
+
+ def __init__(self, task_weights: Dict[Text, Union[float, int]]):
+ self._task_weights = task_weights
+
+ @property
+ def task_weights(self):
+ return self._task_weights
+
+ @abc.abstractmethod
+ def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
+ """Compute cumulative distribution to sample tasks.
+
+ It calculates the cumulative distribution of the multinomial task
+ distribution with respect to which to be sampled against.
+
+ Args:
+ global_step: A tensor indicating current progess of training.
+
+ Returns:
+ A float tensor with shape (#(task), 1) that represents the cumulative
+ sampling distribution.
+ """
+ pass
+
+
+class UniformTaskSampler(TaskSampler):
+ """Sample all tasks uniformly."""
+
+ def __init__(self, task_weights: Dict[Text, Union[float, int]]):
+ super(UniformTaskSampler, self).__init__(task_weights=task_weights)
+ self._uniform_cumulative = tf.math.cumsum(
+ tf.constant(
+ [1.0 / len(self._task_weights)] * len(self._task_weights),
+ dtype=tf.float32))
+
+ def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
+ del global_step
+ return self._uniform_cumulative
+
+
+class ProportionalTaskSampler(TaskSampler):
+ """Sample tasks proportional to task weights."""
+
+ def __init__(self,
+ task_weights: Dict[Text, Union[float, int]],
+ alpha: float = 1.0):
+ super(ProportionalTaskSampler, self).__init__(task_weights=task_weights)
+ self._alpha = tf.cast(alpha, dtype=tf.float32)
+ task_weight_dict_ordered_list = tf.constant(
+ [weight for _, weight in self._task_weights.items()], dtype=tf.float32)
+ task_sizes = tf.math.pow(task_weight_dict_ordered_list, self._alpha)
+ task_distribution = task_sizes / tf.reduce_sum(task_sizes)
+ self._porportional_cumulative = tf.math.cumsum(task_distribution)
+
+ def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
+ del global_step
+ return self._porportional_cumulative
+
+
+class AnnealingTaskSampler(TaskSampler):
+ """Sample tasks according to task weights as well as training progress.
+
+ See http://proceedings.mlr.press/v97/stickland19a/stickland19a.pdf
+ """
+
+ def __init__(self,
+ task_weights: Dict[Text, Union[float, int]],
+ steps_per_epoch: int,
+ total_steps: int):
+ super(AnnealingTaskSampler, self).__init__(task_weights=task_weights)
+ self._steps_per_epoch = tf.cast(steps_per_epoch, dtype=tf.float32)
+ self._total_epochs = tf.cast(
+ total_steps / self._steps_per_epoch, dtype=tf.float32)
+
+ def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
+ cur_epoch = tf.math.floor(
+ tf.cast(global_step, dtype=tf.float32) / self._steps_per_epoch)
+ alpha = 1.0 - 0.8 * (cur_epoch - 1) / (self._total_epochs - 1 + 1e-10)
+ task_weight_dict_ordered_list = [
+ weight for _, weight in self._task_weights.items()
+ ]
+ task_sizes = tf.math.pow(
+ tf.constant(task_weight_dict_ordered_list, dtype=tf.float32),
+ tf.cast(alpha, dtype=tf.float32))
+ dynamic_task_distribution = task_sizes / tf.reduce_sum(task_sizes)
+ return tf.math.cumsum(dynamic_task_distribution)
+
+
+def get_task_sampler(config: configs.TaskSamplingConfig,
+ task_weights: Dict[Text, float]) -> TaskSampler:
+ """Utils to create task sampler with configuration and task weights."""
+ oneof_config = config.get()
+ if config.type == 'uniform':
+ return UniformTaskSampler(task_weights=task_weights)
+ elif config.type == 'proportional':
+ return ProportionalTaskSampler(
+ task_weights=task_weights, alpha=oneof_config.alpha)
+ elif config.type == 'annealing':
+ return AnnealingTaskSampler(
+ task_weights=task_weights,
+ steps_per_epoch=oneof_config.steps_per_epoch,
+ total_steps=oneof_config.total_steps)
+ else:
+ raise RuntimeError('Task sampler type not supported')
diff --git a/modeling/official/modeling/multitask/task_sampler_test.py b/modeling/official/modeling/multitask/task_sampler_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6d0746914b8edce9404789349f6cc493b40967b
--- /dev/null
+++ b/modeling/official/modeling/multitask/task_sampler_test.py
@@ -0,0 +1,75 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for multitask.task_sampler."""
+import tensorflow as tf, tf_keras
+
+from official.modeling.multitask import configs
+from official.modeling.multitask import task_sampler as sampler
+
+
+class TaskSamplerTest(tf.test.TestCase):
+
+ def setUp(self):
+ super(TaskSamplerTest, self).setUp()
+ self._task_weights = {'A': 1.0, 'B': 2.0, 'C': 3.0}
+
+ def test_uniform_sample_distribution(self):
+ uniform_sampler = sampler.get_task_sampler(
+ configs.TaskSamplingConfig(type='uniform'), self._task_weights)
+ for step in range(5):
+ cumulative_distribution = uniform_sampler.task_cumulative_distribution(
+ tf.constant(step, dtype=tf.int64))
+ self.assertAllClose([0.333333, 0.666666, 1.0],
+ cumulative_distribution.numpy())
+
+ def test_proportional_sample_distribution(self):
+ prop_sampler = sampler.get_task_sampler(
+ configs.TaskSamplingConfig(
+ type='proportional',
+ proportional=configs.ProportionalSampleConfig(alpha=2.0)),
+ self._task_weights)
+ # CucmulativeOf(Normalize([1.0^2, 2.0^2, 3.0^2]))
+ for step in range(5):
+ cumulative_distribution = prop_sampler.task_cumulative_distribution(
+ tf.constant(step, dtype=tf.int64))
+ self.assertAllClose([0.07142857, 0.35714286, 1.0],
+ cumulative_distribution.numpy())
+
+ def test_annealing_sample_distribution(self):
+ num_epoch = 3
+ step_per_epoch = 6
+ annel_sampler = sampler.get_task_sampler(
+ configs.TaskSamplingConfig(
+ type='annealing',
+ annealing=configs.AnnealingSampleConfig(
+ steps_per_epoch=step_per_epoch,
+ total_steps=step_per_epoch * num_epoch)), self._task_weights)
+
+ global_step = tf.Variable(
+ 0, dtype=tf.int64, name='global_step', trainable=False)
+ expected_cumulative_epochs = [[0.12056106, 0.4387236, 1.0],
+ [0.16666667, 0.5, 1.0],
+ [0.22477472, 0.5654695, 1.0]]
+ for epoch in range(num_epoch):
+ for _ in range(step_per_epoch):
+ cumulative_distribution = annel_sampler.task_cumulative_distribution(
+ tf.constant(global_step, dtype=tf.int64))
+ global_step.assign_add(1)
+ self.assertAllClose(expected_cumulative_epochs[epoch],
+ cumulative_distribution.numpy())
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/modeling/multitask/test_utils.py b/modeling/official/modeling/multitask/test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..747e88f26f8bc045b21934db8be0938ea7169cf2
--- /dev/null
+++ b/modeling/official/modeling/multitask/test_utils.py
@@ -0,0 +1,129 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Testing utils for mock models and tasks."""
+from typing import Dict, Text
+import tensorflow as tf, tf_keras
+from official.core import base_task
+from official.core import config_definitions as cfg
+from official.core import task_factory
+from official.modeling.multitask import base_model
+
+
+class MockFooModel(tf_keras.Model):
+ """A mock model can consume 'foo' and 'bar' inputs."""
+
+ def __init__(self, shared_layer, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._share_layer = shared_layer
+ self._foo_specific_layer = tf_keras.layers.Dense(1)
+ self.inputs = {"foo": tf_keras.Input(shape=(2,), dtype=tf.float32),
+ "bar": tf_keras.Input(shape=(2,), dtype=tf.float32)}
+
+ def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
+ self.add_loss(tf.zeros((1,), dtype=tf.float32))
+ if "foo" in inputs:
+ input_tensor = inputs["foo"]
+ else:
+ input_tensor = inputs["bar"]
+ return self._foo_specific_layer(self._share_layer(input_tensor))
+
+
+class MockBarModel(tf_keras.Model):
+ """A mock model can only consume 'bar' inputs."""
+
+ def __init__(self, shared_layer, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._share_layer = shared_layer
+ self._bar_specific_layer = tf_keras.layers.Dense(1)
+ self.inputs = {"bar": tf_keras.Input(shape=(2,), dtype=tf.float32)}
+
+ def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
+ self.add_loss(tf.zeros((2,), dtype=tf.float32))
+ return self._bar_specific_layer(self._share_layer(inputs["bar"]))
+
+
+class MockMultiTaskModel(base_model.MultiTaskBaseModel):
+
+ def __init__(self, *args, **kwargs):
+ self._shared_dense = tf_keras.layers.Dense(1)
+ super().__init__(*args, **kwargs)
+
+ def _instantiate_sub_tasks(self) -> Dict[Text, tf_keras.Model]:
+ return {
+ "foo": MockFooModel(self._shared_dense),
+ "bar": MockBarModel(self._shared_dense)
+ }
+
+
+def mock_data(feature_name):
+ """Mock dataset function."""
+
+ def _generate_data(_):
+ x = tf.zeros(shape=(2,), dtype=tf.float32)
+ label = tf.zeros([1], dtype=tf.int32)
+ return {feature_name: x}, label
+
+ dataset = tf.data.Dataset.range(1)
+ dataset = dataset.repeat()
+ dataset = dataset.map(
+ _generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ return dataset.prefetch(buffer_size=1).batch(2, drop_remainder=True)
+
+
+class FooConfig(cfg.TaskConfig):
+ pass
+
+
+class BarConfig(cfg.TaskConfig):
+ pass
+
+
+@task_factory.register_task_cls(FooConfig)
+class MockFooTask(base_task.Task):
+ """Mock foo task object for testing."""
+
+ def build_metrics(self, training: bool = True):
+ del training
+ return [tf_keras.metrics.Accuracy(name="foo_acc")]
+
+ def build_inputs(self, params):
+ return mock_data("foo")
+
+ def build_model(self) -> tf_keras.Model:
+ return MockFooModel(shared_layer=tf_keras.layers.Dense(1))
+
+ def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
+ loss = tf_keras.losses.mean_squared_error(labels, model_outputs)
+ if aux_losses:
+ loss += tf.add_n(aux_losses)
+ return tf.reduce_mean(loss)
+
+
+@task_factory.register_task_cls(BarConfig)
+class MockBarTask(base_task.Task):
+ """Mock bar task object for testing."""
+
+ def build_metrics(self, training: bool = True):
+ del training
+ return [tf_keras.metrics.Accuracy(name="bar_acc")]
+
+ def build_inputs(self, params):
+ return mock_data("bar")
+
+ def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
+ loss = tf_keras.losses.mean_squared_error(labels, model_outputs)
+ if aux_losses:
+ loss += tf.add_n(aux_losses)
+ return tf.reduce_mean(loss)
diff --git a/modeling/official/modeling/multitask/train_lib.py b/modeling/official/modeling/multitask/train_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd190bd7293a0ba1c032b300ac4001d574c796e4
--- /dev/null
+++ b/modeling/official/modeling/multitask/train_lib.py
@@ -0,0 +1,300 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Multitask training driver library."""
+# pytype: disable=attribute-error
+import os
+from typing import Any, List, Mapping, Optional, Tuple, Union
+from absl import logging
+import orbit
+import tensorflow as tf, tf_keras
+from official.core import base_task
+from official.core import base_trainer as core_lib
+from official.core import train_utils
+from official.modeling.multitask import base_model
+from official.modeling.multitask import base_trainer
+from official.modeling.multitask import configs
+from official.modeling.multitask import evaluator as evaluator_lib
+from official.modeling.multitask import interleaving_trainer
+from official.modeling.multitask import multitask
+from official.modeling.multitask import task_sampler
+
+TRAINERS = {
+ 'interleaving': interleaving_trainer.MultiTaskInterleavingTrainer,
+ 'joint': base_trainer.MultiTaskBaseTrainer
+}
+
+
+def run_experiment(
+ *,
+ distribution_strategy: tf.distribute.Strategy,
+ task: multitask.MultiTask,
+ model: base_model.MultiTaskBaseModel,
+ mode: str,
+ params: configs.MultiTaskExperimentConfig,
+ model_dir: str,
+ run_post_eval: bool = False,
+ trainer: base_trainer.MultiTaskBaseTrainer = None,
+ eval_summary_manager: Optional[orbit.utils.SummaryManagerInterface] = None,
+ best_ckpt_exporter_creator: Optional[Any] = train_utils
+ .maybe_create_best_ckpt_exporter
+) -> Union[base_model.MultiTaskBaseModel, Tuple[base_model.MultiTaskBaseModel,
+ Mapping[Any, Any]]]:
+ """Runs train/eval configured by the experiment params.
+
+ Args:
+ distribution_strategy: A distribution distribution_strategy.
+ task: A MultiTaskTask instance.
+ model: A MultiTaskBaseModel instance.
+ mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
+ or 'continuous_eval'.
+ params: ExperimentConfig instance.
+ model_dir: A 'str', a path to store model checkpoints and summaries.
+ run_post_eval: Whether to run post eval once after training, metrics logs
+ are returned.
+ trainer: (optional) A multi-task trainer to use. If none is provided, a
+ default one will be created based on `params`.
+ eval_summary_manager: Instance of the eval summary manager. If set, the
+ `eval_summary_dir` will be ignored. Otherwise the eval summary manager
+ will be created internally for TensorBoard summaries by default from the
+ `eval_summary_dir`.
+ best_ckpt_exporter_creator: A functor for creating best checkpoint exporter.
+
+ Returns:
+ model: `base_model.MultiTaskBaseModel` instance.
+ """
+
+ is_training = 'train' in mode
+ is_eval = 'eval' in mode
+ with distribution_strategy.scope():
+ optimizer = train_utils.create_optimizer(task, params)
+ kwargs = dict(multi_task=task, multi_task_model=model, optimizer=optimizer)
+ if params.trainer.trainer_type == 'interleaving':
+ sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
+ task.task_weights)
+ kwargs.update(dict(task_sampler=sampler))
+ if trainer is None:
+ trainer = TRAINERS[params.trainer.trainer_type](
+ **kwargs) if is_training else None
+ if is_eval:
+ eval_steps = task.task_eval_steps
+ evaluator = evaluator_lib.MultiTaskEvaluator(
+ eval_tasks=task.tasks.values(),
+ model=model,
+ eval_steps=eval_steps,
+ global_step=trainer.global_step if is_training else None,
+ checkpoint_exporter=best_ckpt_exporter_creator(params, model_dir))
+ else:
+ evaluator = None
+
+ if trainer:
+ checkpoint = trainer.checkpoint
+ global_step = trainer.global_step
+ else:
+ checkpoint = evaluator.checkpoint
+ global_step = evaluator.global_step
+
+ checkpoint_manager = tf.train.CheckpointManager(
+ checkpoint,
+ directory=model_dir,
+ max_to_keep=params.trainer.max_to_keep,
+ step_counter=global_step,
+ checkpoint_interval=params.trainer.checkpoint_interval,
+ init_fn=model.initialize)
+
+ controller = orbit.Controller(
+ strategy=distribution_strategy,
+ trainer=trainer,
+ evaluator=evaluator,
+ global_step=global_step,
+ steps_per_loop=params.trainer.steps_per_loop,
+ checkpoint_manager=checkpoint_manager,
+ summary_dir=os.path.join(model_dir, 'train'),
+ eval_summary_dir=os.path.join(model_dir, 'validation'),
+ eval_summary_manager=eval_summary_manager,
+ summary_interval=params.trainer.summary_interval)
+
+ logging.info('Starts to execute mode: %s', mode)
+ with distribution_strategy.scope():
+ if mode == 'train':
+ controller.train(steps=params.trainer.train_steps)
+ elif mode == 'train_and_eval':
+ controller.train_and_evaluate(
+ train_steps=params.trainer.train_steps,
+ eval_steps=params.trainer.validation_steps,
+ eval_interval=params.trainer.validation_interval)
+ elif mode == 'eval':
+ controller.evaluate(steps=params.trainer.validation_steps)
+ elif mode == 'continuous_eval':
+
+ def timeout_fn():
+ if evaluator.global_step.numpy() >= params.trainer.train_steps:
+ return True
+ return False
+
+ controller.evaluate_continuously(
+ steps=params.trainer.validation_steps,
+ timeout=params.trainer.continuous_eval_timeout,
+ timeout_fn=timeout_fn)
+ else:
+ raise NotImplementedError('The mode is not implemented: %s' % mode)
+
+ if run_post_eval:
+ return model, evaluator.evaluate(
+ tf.convert_to_tensor(params.trainer.validation_steps)) # pytype: disable=bad-return-type # typed-keras
+ else:
+ return model
+
+
+def run_experiment_with_multitask_eval(
+ *,
+ distribution_strategy: tf.distribute.Strategy,
+ train_task: base_task.Task,
+ eval_tasks: List[base_task.Task],
+ mode: str,
+ params: configs.MultiEvalExperimentConfig,
+ model_dir: str,
+ run_post_eval: bool = False,
+ save_summary: bool = True,
+ trainer: Optional[core_lib.Trainer] = None,
+ eval_summary_manager: Optional[orbit.utils.SummaryManagerInterface] = None,
+ best_ckpt_exporter_creator: Optional[Any] = train_utils
+ .maybe_create_best_ckpt_exporter,
+) -> Tuple[Any, Any]:
+ """Runs train/eval configured by the experiment params.
+
+ Args:
+ distribution_strategy: A distribution distribution_strategy.
+ train_task: A base_task.Task instance.
+ eval_tasks: A list of evaluation tasks.
+ mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
+ or 'continuous_eval'.
+ params: MultiEvalExperimentConfig instance.
+ model_dir: A 'str', a path to store model checkpoints and summaries.
+ run_post_eval: Whether to run post eval once after training, metrics logs
+ are returned.
+ save_summary: Whether to save train and validation summary.
+ trainer: the core_lib.Trainer instance. It should be created within the
+ strategy.scope(). If not provided, an instance will be created by default
+ if `mode` contains 'train'.
+ eval_summary_manager: Instance of the eval summary manager. If set, the
+ `eval_summary_dir` will be ignored. Otherwise the eval summary manager
+ will be created internally for TensorBoard summaries by default from the
+ `eval_summary_dir`.
+ best_ckpt_exporter_creator: A functor for creating best checkpoint exporter.
+
+ Returns:
+ model: `tf_keras.Model` instance.
+ """
+
+ is_training = 'train' in mode
+ is_eval = 'eval' in mode
+ with distribution_strategy.scope():
+ if is_training:
+ trainer = trainer or core_lib.Trainer(
+ config=params,
+ task=train_task,
+ model=train_task.build_model(),
+ optimizer=train_utils.create_optimizer(train_task, params),
+ train=True,
+ evaluate=False)
+ else:
+ trainer = None
+
+ # Build the model or fetch the pre-cached one (which could be either
+ # multi-task model or single task model).
+ model = None
+ if trainer is None:
+ if isinstance(train_task, multitask.MultiTask):
+ model = train_task.build_multitask_model()
+ else:
+ model = train_task.build_model()
+ else:
+ if isinstance(trainer, base_trainer.MultiTaskBaseTrainer):
+ model = trainer.multi_task_model
+ else:
+ model = trainer.model
+
+ if is_eval:
+ eval_steps = dict([(task_routine.task_config.name,
+ task_routine.eval_steps)
+ for task_routine in params.eval_tasks])
+ evaluator = evaluator_lib.MultiTaskEvaluator(
+ eval_tasks=eval_tasks,
+ model=model,
+ global_step=trainer.global_step if is_training else None,
+ eval_steps=eval_steps,
+ checkpoint_exporter=best_ckpt_exporter_creator(params, model_dir))
+ else:
+ evaluator = None
+
+ if trainer:
+ checkpoint = trainer.checkpoint
+ global_step = trainer.global_step
+ else:
+ checkpoint = evaluator.checkpoint
+ global_step = evaluator.global_step
+
+ checkpoint_manager = tf.train.CheckpointManager(
+ checkpoint,
+ directory=model_dir,
+ max_to_keep=params.trainer.max_to_keep,
+ step_counter=global_step,
+ checkpoint_interval=params.trainer.checkpoint_interval,
+ init_fn=trainer.initialize if trainer else None)
+
+ controller = orbit.Controller(
+ strategy=distribution_strategy,
+ trainer=trainer,
+ evaluator=evaluator,
+ global_step=global_step,
+ steps_per_loop=params.trainer.steps_per_loop,
+ checkpoint_manager=checkpoint_manager,
+ summary_dir=os.path.join(model_dir, 'train') if save_summary else None,
+ eval_summary_dir=os.path.join(model_dir, 'validation') if
+ (save_summary) else None,
+ eval_summary_manager=eval_summary_manager,
+ summary_interval=params.trainer.summary_interval if
+ (save_summary) else None)
+
+ logging.info('Starts to execute mode: %s', mode)
+ with distribution_strategy.scope():
+ if mode == 'train':
+ controller.train(steps=params.trainer.train_steps)
+ elif mode == 'train_and_eval':
+ controller.train_and_evaluate(
+ train_steps=params.trainer.train_steps,
+ eval_steps=params.trainer.validation_steps,
+ eval_interval=params.trainer.validation_interval)
+ elif mode == 'eval':
+ controller.evaluate(steps=params.trainer.validation_steps)
+ elif mode == 'continuous_eval':
+
+ def timeout_fn():
+ if evaluator.global_step.numpy() >= params.trainer.train_steps:
+ return True
+ return False
+
+ controller.evaluate_continuously(
+ steps=params.trainer.validation_steps,
+ timeout=params.trainer.continuous_eval_timeout,
+ timeout_fn=timeout_fn)
+ else:
+ raise NotImplementedError('The mode is not implemented: %s' % mode)
+
+ if run_post_eval:
+ return model, evaluator.evaluate(
+ tf.convert_to_tensor(params.trainer.validation_steps)) # pytype: disable=bad-return-type # typed-keras
+ else:
+ return model, {} # pytype: disable=bad-return-type # typed-keras
diff --git a/modeling/official/modeling/multitask/train_lib_test.py b/modeling/official/modeling/multitask/train_lib_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6471baec6681876b209cc3564c462ba5af0272d
--- /dev/null
+++ b/modeling/official/modeling/multitask/train_lib_test.py
@@ -0,0 +1,123 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for multitask.train_lib."""
+from absl.testing import parameterized
+import tensorflow as tf, tf_keras
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.core import task_factory
+from official.modeling.hyperparams import params_dict
+from official.modeling.multitask import configs
+from official.modeling.multitask import multitask
+from official.modeling.multitask import test_utils
+from official.modeling.multitask import train_lib
+
+
+class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._test_config = {
+ 'trainer': {
+ 'checkpoint_interval': 10,
+ 'steps_per_loop': 10,
+ 'summary_interval': 10,
+ 'train_steps': 10,
+ 'validation_steps': 5,
+ 'validation_interval': 10,
+ 'continuous_eval_timeout': 1,
+ 'optimizer_config': {
+ 'optimizer': {
+ 'type': 'sgd',
+ },
+ 'learning_rate': {
+ 'type': 'constant'
+ }
+ }
+ },
+ }
+
+ @combinations.generate(
+ combinations.combine(
+ distribution_strategy=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.cloud_tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ ],
+ mode='eager',
+ optimizer=['sgd_experimental', 'sgd'],
+ flag_mode=['train', 'eval', 'train_and_eval']))
+ def test_end_to_end(self, distribution_strategy, optimizer, flag_mode):
+ model_dir = self.get_temp_dir()
+ experiment_config = configs.MultiTaskExperimentConfig(
+ task=configs.MultiTaskConfig(
+ task_routines=(
+ configs.TaskRoutine(
+ task_name='foo', task_config=test_utils.FooConfig()),
+ configs.TaskRoutine(
+ task_name='bar', task_config=test_utils.BarConfig()))))
+ experiment_config = params_dict.override_params_dict(
+ experiment_config, self._test_config, is_strict=False)
+ experiment_config.trainer.optimizer_config.optimizer.type = optimizer
+ with distribution_strategy.scope():
+ test_multitask = multitask.MultiTask.from_config(experiment_config.task)
+ model = test_utils.MockMultiTaskModel()
+ train_lib.run_experiment(
+ distribution_strategy=distribution_strategy,
+ task=test_multitask,
+ model=model,
+ mode=flag_mode,
+ params=experiment_config,
+ model_dir=model_dir)
+
+ @combinations.generate(
+ combinations.combine(
+ distribution_strategy=[
+ strategy_combinations.default_strategy,
+ strategy_combinations.cloud_tpu_strategy,
+ strategy_combinations.one_device_strategy_gpu,
+ ],
+ mode='eager',
+ flag_mode=['train', 'eval', 'train_and_eval']))
+ def test_end_to_end_multi_eval(self, distribution_strategy, flag_mode):
+ model_dir = self.get_temp_dir()
+ experiment_config = configs.MultiEvalExperimentConfig(
+ task=test_utils.FooConfig(),
+ eval_tasks=(configs.TaskRoutine(
+ task_name='foo', task_config=test_utils.FooConfig(), eval_steps=2),
+ configs.TaskRoutine(
+ task_name='bar',
+ task_config=test_utils.BarConfig(),
+ eval_steps=3)))
+ experiment_config = params_dict.override_params_dict(
+ experiment_config, self._test_config, is_strict=False)
+ with distribution_strategy.scope():
+ train_task = task_factory.get_task(experiment_config.task)
+ eval_tasks = [
+ task_factory.get_task(config.task_config, name=config.task_name)
+ for config in experiment_config.eval_tasks
+ ]
+ train_lib.run_experiment_with_multitask_eval(
+ distribution_strategy=distribution_strategy,
+ train_task=train_task,
+ eval_tasks=eval_tasks,
+ mode=flag_mode,
+ params=experiment_config,
+ model_dir=model_dir)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/modeling/optimization/__init__.py b/modeling/official/modeling/optimization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bae0d0ba195e1c6f47ee7bc3afe7b4164a5aa74a
--- /dev/null
+++ b/modeling/official/modeling/optimization/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Optimization package definition."""
+
+# pylint: disable=wildcard-import
+from official.modeling.optimization.configs.learning_rate_config import *
+from official.modeling.optimization.configs.optimization_config import *
+from official.modeling.optimization.configs.optimizer_config import *
+from official.modeling.optimization.ema_optimizer import ExponentialMovingAverage
+from official.modeling.optimization.lr_schedule import *
+from official.modeling.optimization.optimizer_factory import OptimizerFactory
+from official.modeling.optimization.optimizer_factory import register_optimizer_cls
diff --git a/modeling/official/modeling/optimization/adafactor_optimizer.py b/modeling/official/modeling/optimization/adafactor_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b485bf01d61368fd85c635a633c59f2bef72e8ad
--- /dev/null
+++ b/modeling/official/modeling/optimization/adafactor_optimizer.py
@@ -0,0 +1,20 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Adafactor optimizer.
+
+A new optimizer that will be open sourced soon.
+"""
+# pylint: disable=invalid-name, represents an unimplemented class definition.
+Adafactor = "Unimplemented"
diff --git a/modeling/official/modeling/optimization/configs/__init__.py b/modeling/official/modeling/optimization/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/modeling/optimization/configs/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/modeling/optimization/configs/learning_rate_config.py b/modeling/official/modeling/optimization/configs/learning_rate_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f37ed3713bbbc2c2fb32e2cf7669f3851926682a
--- /dev/null
+++ b/modeling/official/modeling/optimization/configs/learning_rate_config.py
@@ -0,0 +1,288 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Dataclasses for learning rate schedule config."""
+from typing import List, Optional
+
+import dataclasses
+from official.modeling.hyperparams import base_config
+
+
+@dataclasses.dataclass
+class ConstantLrConfig(base_config.Config):
+ """Configuration for constant learning rate.
+
+ This class is a containers for the constant learning rate decay configs.
+
+ Attributes:
+ name: The name of the learning rate schedule. Defaults to Constant.
+ learning_rate: A float. The learning rate. Defaults to 0.1.
+ """
+ name: str = 'Constant'
+ learning_rate: float = 0.1
+
+
+@dataclasses.dataclass
+class StepwiseLrConfig(base_config.Config):
+ """Configuration for stepwise learning rate decay.
+
+ This class is a container for the piecewise constant learning rate scheduling
+ configs. It will configure an instance of PiecewiseConstantDecay keras
+ learning rate schedule.
+
+ An example (from keras docs): use a learning rate that's 1.0 for the first
+ 100001 steps, 0.5 for the next 10000 steps, and 0.1 for any additional steps.
+ ```python
+ boundaries: [100000, 110000]
+ values: [1.0, 0.5, 0.1]
+
+ Attributes:
+ name: The name of the learning rate schedule. Defaults to PiecewiseConstant.
+ boundaries: A list of ints of strictly increasing entries. Defaults to None.
+ values: A list of floats that specifies the values for the intervals defined
+ by `boundaries`. It should have one more element than `boundaries`.
+ The learning rate is computed as follows: [0, boundaries[0]] ->
+ values[0] [boundaries[0], boundaries[1]] -> values[1]
+ [boundaries[n-1], boundaries[n]] -> values[n] [boundaries[n],
+ end] -> values[n+1] Defaults to None.
+ offset: An int. The offset applied to steps. Defaults to 0.
+ """
+ name: str = 'PiecewiseConstantDecay'
+ boundaries: Optional[List[int]] = None
+ values: Optional[List[float]] = None
+ offset: int = 0
+
+
+@dataclasses.dataclass
+class ExponentialLrConfig(base_config.Config):
+ """Configuration for exponential learning rate decay.
+
+ This class is a containers for the exponential learning rate decay configs.
+
+ Attributes:
+ name: The name of the learning rate schedule. Defaults to ExponentialDecay.
+ initial_learning_rate: A float. The initial learning rate. Defaults to None.
+ decay_steps: A positive integer that is used for decay computation. Defaults
+ to None.
+ decay_rate: A float. Defaults to None.
+ staircase: A boolean, if true, learning rate is decreased at discreate
+ intervals. Defaults to False.
+ offset: An int. The offset applied to steps. Defaults to 0.
+ """
+ name: str = 'ExponentialDecay'
+ initial_learning_rate: Optional[float] = None
+ decay_steps: Optional[int] = None
+ decay_rate: Optional[float] = None
+ staircase: Optional[bool] = None
+ offset: int = 0
+
+
+@dataclasses.dataclass
+class PolynomialLrConfig(base_config.Config):
+ """Configuration for polynomial learning rate decay.
+
+ This class is a containers for the polynomial learning rate decay configs.
+
+ Attributes:
+ name: The name of the learning rate schedule. Defaults to PolynomialDecay.
+ initial_learning_rate: A float. The initial learning rate. Defaults to None.
+ decay_steps: A positive integer that is used for decay computation. Defaults
+ to None.
+ end_learning_rate: A float. The minimal end learning rate.
+ power: A float. The power of the polynomial. Defaults to linear, 1.0.
+ cycle: A boolean, whether or not it should cycle beyond decay_steps.
+ Defaults to False.
+ offset: An int. The offset applied to steps. Defaults to 0.
+ """
+ name: str = 'PolynomialDecay'
+ initial_learning_rate: Optional[float] = None
+ decay_steps: Optional[int] = None
+ end_learning_rate: float = 0.0001
+ power: float = 1.0
+ cycle: bool = False
+ offset: int = 0
+
+
+@dataclasses.dataclass
+class CosineLrConfig(base_config.Config):
+ """Configuration for Cosine learning rate decay.
+
+ This class is a containers for the cosine learning rate decay configs,
+ tf_keras.experimental.CosineDecay.
+
+ Attributes:
+ name: The name of the learning rate schedule. Defaults to CosineDecay.
+ initial_learning_rate: A float. The initial learning rate. Defaults to None.
+ decay_steps: A positive integer that is used for decay computation. Defaults
+ to None.
+ alpha: A float. Minimum learning rate value as a fraction of
+ initial_learning_rate.
+ offset: An int. The offset applied to steps. Defaults to 0.
+ """
+ name: str = 'CosineDecay'
+ initial_learning_rate: Optional[float] = None
+ decay_steps: Optional[int] = None
+ alpha: float = 0.0
+ offset: int = 0
+
+
+@dataclasses.dataclass
+class DirectPowerLrConfig(base_config.Config):
+ """Configuration for DirectPower learning rate decay.
+
+ This class configures a schedule following follows lr * (step)^power.
+
+ Attributes:
+ name: The name of the learning rate schedule. Defaults to DirectPowerDecay.
+ initial_learning_rate: A float. The initial learning rate. Defaults to None.
+ power: A float. Defaults to -0.5, for sqrt decay.
+ """
+ name: str = 'DirectPowerDecay'
+ initial_learning_rate: Optional[float] = None
+ power: float = -0.5
+
+
+@dataclasses.dataclass
+class PowerAndLinearDecayLrConfig(base_config.Config):
+ """Configuration for DirectPower learning rate decay.
+
+ The schedule has the following behavoir.
+ Let offset_step = step - offset.
+ 1) offset_step < 0, the actual learning rate equals initial_learning_rate.
+ 2) offset_step <= total_decay_steps * (1 - linear_decay_fraction), the
+ actual learning rate equals lr * offset_step^power.
+ 3) total_decay_steps * (1 - linear_decay_fraction) <= offset_step <
+ total_decay_steps, the actual learning rate equals lr * offset_step^power *
+ (total_decay_steps - offset_step) / (total_decay_steps *
+ linear_decay_fraction).
+ 4) offset_step >= total_decay_steps, the actual learning rate equals zero.
+
+ Attributes:
+ name: The name of the learning rate schedule. Defaults to
+ PowerAndLinearDecay.
+ initial_learning_rate: A float. The initial learning rate. Defaults to None.
+ total_decay_steps: An int. The total number of steps for power + linear
+ decay. Defaults to None.
+ power: A float. The order of the polynomial. Defaults to -0.5, for sqrt
+ decay.
+ linear_decay_fraction: A float. In the last `linear_decay_fraction` steps,
+ the learning rate will be multiplied by a linear decay. Defaults to 0.1.
+ offset: An int. The offset applied to steps. Defaults to 0.
+ """
+ name: str = 'PowerAndLinearDecay'
+ initial_learning_rate: Optional[float] = None
+ total_decay_steps: Optional[int] = None
+ power: float = -0.5
+ linear_decay_fraction: float = 0.1
+ offset: int = 0
+
+
+@dataclasses.dataclass
+class PowerDecayWithOffsetLrConfig(base_config.Config):
+ """Configuration for power learning rate decay with step offset.
+
+ Learning rate equals to `pre_offset_learning_rate` if `step` < `offset`.
+ Otherwise, learning rate equals to lr * (step - offset)^power.
+
+ Attributes:
+ name: The name of the learning rate schedule. Defaults to
+ PowerDecayWithOffset.
+ initial_learning_rate: A float. The initial learning rate. Defaults to None.
+ power: A float. Defaults to -0.5, for sqrt decay.
+ offset: An integer. Power decay happens after `offset` steps.
+ pre_offset_learning_rate: A float. The constant learning rate before
+ `offset` steps.
+ """
+ name: str = 'PowerDecayWithOffset'
+ initial_learning_rate: Optional[float] = None
+ power: float = -0.5
+ offset: int = 0
+ pre_offset_learning_rate: float = 1.0e6
+
+
+@dataclasses.dataclass
+class StepCosineLrConfig(base_config.Config):
+ """Configuration for stepwise learning rate decay.
+
+ This class is a container for the piecewise cosine learning rate scheduling
+ configs. It will configure an instance of StepCosineDecayWithOffset keras
+ learning rate schedule.
+
+ ```python
+ boundaries: [100000, 110000]
+ values: [1.0, 0.5]
+ lr_decayed_fn = (
+ lr_schedule.StepCosineDecayWithOffset(
+ boundaries,
+ values))
+ ```
+ from 0 to 100000 step, it will cosine decay from 1.0 to 0.5
+ from 100000 to 110000 step, it cosine decay from 0.5 to 0.0
+
+ Attributes:
+ name: The name of the learning rate schedule. Defaults to PiecewiseConstant.
+ boundaries: A list of ints of strictly increasing entries. Defaults to None.
+ values: A list of floats that specifies the values for the intervals defined
+ by `boundaries`. It should have one more element than `boundaries`.
+ The learning rate is computed as follows:
+ [0, boundaries[0]] -> cosine from values[0] to values[1]
+ [boundaries[0], boundaries[1]] -> values[1] to values[2]
+ ...
+ [boundaries[n-1], boundaries[n]] -> values[n] to values[n+1]
+ [boundaries[n], end] -> values[n+1] to 0.
+ offset: An int. The offset applied to steps. Defaults to 0.
+ """
+ name: str = 'StepCosineDecayWithOffset'
+ boundaries: Optional[List[int]] = None
+ values: Optional[List[float]] = None
+ offset: int = 0
+
+
+@dataclasses.dataclass
+class LinearWarmupConfig(base_config.Config):
+ """Configuration for linear warmup schedule config.
+
+ This class is a container for the linear warmup schedule configs.
+ Warmup_learning_rate is the initial learning rate, the final learning rate of
+ the warmup period is the learning_rate of the optimizer in use. The learning
+ rate at each step linearly increased according to the following formula:
+ warmup_learning_rate = warmup_learning_rate +
+ step / warmup_steps * (final_learning_rate - warmup_learning_rate).
+ Using warmup overrides the learning rate schedule by the number of warmup
+ steps.
+
+ Attributes:
+ name: The name of warmup schedule. Defaults to linear.
+ warmup_learning_rate: Initial learning rate for the warmup. Defaults to 0.
+ warmup_steps: Warmup steps. Defaults to None.
+ """
+ name: str = 'linear'
+ warmup_learning_rate: float = 0
+ warmup_steps: Optional[int] = None
+
+
+@dataclasses.dataclass
+class PolynomialWarmupConfig(base_config.Config):
+ """Configuration for linear warmup schedule config.
+
+ This class is a container for the polynomial warmup schedule configs.
+
+ Attributes:
+ name: The name of warmup schedule. Defaults to Polynomial.
+ power: Polynomial power. Defaults to 1.
+ warmup_steps: Warmup steps. Defaults to None.
+ """
+ name: str = 'polynomial'
+ power: float = 1
+ warmup_steps: Optional[int] = None
diff --git a/modeling/official/modeling/optimization/configs/optimization_config.py b/modeling/official/modeling/optimization/configs/optimization_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e52dec85cf1762a8d9e0c068f128188927bd932c
--- /dev/null
+++ b/modeling/official/modeling/optimization/configs/optimization_config.py
@@ -0,0 +1,171 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Dataclasses for optimization configs.
+
+This file define the dataclass for optimization configs (OptimizationConfig).
+It also has two helper functions get_optimizer_config, and get_lr_config from
+an OptimizationConfig class.
+"""
+from typing import Optional
+
+import dataclasses
+
+from official.modeling.hyperparams import base_config
+from official.modeling.hyperparams import oneof
+from official.modeling.optimization.configs import learning_rate_config as lr_cfg
+from official.modeling.optimization.configs import optimizer_config as opt_cfg
+
+
+@dataclasses.dataclass
+class OptimizerConfig(oneof.OneOfConfig):
+ """Configuration for optimizer.
+
+ Attributes:
+ type: 'str', type of optimizer to be used, on the of fields below.
+ sgd: sgd optimizer config.
+ adam: adam optimizer config.
+ adamw: adam with weight decay.
+ lamb: lamb optimizer.
+ rmsprop: rmsprop optimizer.
+ lars: lars optimizer.
+ adagrad: adagrad optimizer.
+ slide: slide optimizer.
+ adafactor: adafactor optimizer.
+ adafactor_keras: adafactor optimizer.
+ """
+ type: Optional[str] = None
+ sgd: opt_cfg.SGDConfig = dataclasses.field(default_factory=opt_cfg.SGDConfig)
+ sgd_experimental: opt_cfg.SGDExperimentalConfig = dataclasses.field(
+ default_factory=opt_cfg.SGDExperimentalConfig
+ )
+ adam: opt_cfg.AdamConfig = dataclasses.field(
+ default_factory=opt_cfg.AdamConfig
+ )
+ adam_experimental: opt_cfg.AdamExperimentalConfig = dataclasses.field(
+ default_factory=opt_cfg.AdamExperimentalConfig
+ )
+ adamw: opt_cfg.AdamWeightDecayConfig = dataclasses.field(
+ default_factory=opt_cfg.AdamWeightDecayConfig
+ )
+ adamw_experimental: opt_cfg.AdamWeightDecayExperimentalConfig = (
+ dataclasses.field(
+ default_factory=opt_cfg.AdamWeightDecayExperimentalConfig
+ )
+ )
+ lamb: opt_cfg.LAMBConfig = dataclasses.field(
+ default_factory=opt_cfg.LAMBConfig
+ )
+ rmsprop: opt_cfg.RMSPropConfig = dataclasses.field(
+ default_factory=opt_cfg.RMSPropConfig
+ )
+ lars: opt_cfg.LARSConfig = dataclasses.field(
+ default_factory=opt_cfg.LARSConfig
+ )
+ adagrad: opt_cfg.AdagradConfig = dataclasses.field(
+ default_factory=opt_cfg.AdagradConfig
+ )
+ slide: opt_cfg.SLIDEConfig = dataclasses.field(
+ default_factory=opt_cfg.SLIDEConfig
+ )
+ adafactor: opt_cfg.AdafactorConfig = dataclasses.field(
+ default_factory=opt_cfg.AdafactorConfig
+ )
+ adafactor_keras: opt_cfg.AdafactorKerasConfig = dataclasses.field(
+ default_factory=opt_cfg.AdafactorKerasConfig
+ )
+
+
+@dataclasses.dataclass
+class LrConfig(oneof.OneOfConfig):
+ """Configuration for lr schedule.
+
+ Attributes:
+ type: 'str', type of lr schedule to be used, one of the fields below.
+ constant: constant learning rate config.
+ stepwise: stepwise learning rate config.
+ exponential: exponential learning rate config.
+ polynomial: polynomial learning rate config.
+ cosine: cosine learning rate config.
+ power: step^power learning rate config.
+ power_linear: learning rate config of step^power followed by
+ step^power*linear.
+ power_with_offset: power decay with a step offset.
+ step_cosine_with_offset: Step cosine with a step offset.
+ """
+ type: Optional[str] = None
+ constant: lr_cfg.ConstantLrConfig = dataclasses.field(
+ default_factory=lr_cfg.ConstantLrConfig
+ )
+ stepwise: lr_cfg.StepwiseLrConfig = dataclasses.field(
+ default_factory=lr_cfg.StepwiseLrConfig
+ )
+ exponential: lr_cfg.ExponentialLrConfig = dataclasses.field(
+ default_factory=lr_cfg.ExponentialLrConfig
+ )
+ polynomial: lr_cfg.PolynomialLrConfig = dataclasses.field(
+ default_factory=lr_cfg.PolynomialLrConfig
+ )
+ cosine: lr_cfg.CosineLrConfig = dataclasses.field(
+ default_factory=lr_cfg.CosineLrConfig
+ )
+ power: lr_cfg.DirectPowerLrConfig = dataclasses.field(
+ default_factory=lr_cfg.DirectPowerLrConfig
+ )
+ power_linear: lr_cfg.PowerAndLinearDecayLrConfig = dataclasses.field(
+ default_factory=lr_cfg.PowerAndLinearDecayLrConfig
+ )
+ power_with_offset: lr_cfg.PowerDecayWithOffsetLrConfig = dataclasses.field(
+ default_factory=lr_cfg.PowerDecayWithOffsetLrConfig
+ )
+ step_cosine_with_offset: lr_cfg.StepCosineLrConfig = dataclasses.field(
+ default_factory=lr_cfg.StepCosineLrConfig
+ )
+
+
+@dataclasses.dataclass
+class WarmupConfig(oneof.OneOfConfig):
+ """Configuration for lr schedule.
+
+ Attributes:
+ type: 'str', type of warmup schedule to be used, one of the fields below.
+ linear: linear warmup config.
+ polynomial: polynomial warmup config.
+ """
+ type: Optional[str] = None
+ linear: lr_cfg.LinearWarmupConfig = dataclasses.field(
+ default_factory=lr_cfg.LinearWarmupConfig
+ )
+ polynomial: lr_cfg.PolynomialWarmupConfig = dataclasses.field(
+ default_factory=lr_cfg.PolynomialWarmupConfig
+ )
+
+
+@dataclasses.dataclass
+class OptimizationConfig(base_config.Config):
+ """Configuration for optimizer and learning rate schedule.
+
+ Attributes:
+ optimizer: optimizer oneof config.
+ ema: optional exponential moving average optimizer config, if specified, ema
+ optimizer will be used.
+ learning_rate: learning rate oneof config.
+ warmup: warmup oneof config.
+ """
+ optimizer: OptimizerConfig = dataclasses.field(
+ default_factory=OptimizerConfig
+ )
+ ema: Optional[opt_cfg.EMAConfig] = None
+ learning_rate: LrConfig = dataclasses.field(default_factory=LrConfig)
+ warmup: WarmupConfig = dataclasses.field(default_factory=WarmupConfig)
diff --git a/modeling/official/modeling/optimization/configs/optimization_config_test.py b/modeling/official/modeling/optimization/configs/optimization_config_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..902dbcd4dad334e6d1d94a7785c637c60c13eac5
--- /dev/null
+++ b/modeling/official/modeling/optimization/configs/optimization_config_test.py
@@ -0,0 +1,59 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for optimization_config.py."""
+
+import tensorflow as tf, tf_keras
+
+from official.modeling.optimization.configs import learning_rate_config as lr_cfg
+from official.modeling.optimization.configs import optimization_config
+from official.modeling.optimization.configs import optimizer_config as opt_cfg
+
+
+class OptimizerConfigTest(tf.test.TestCase):
+
+ def test_no_optimizer(self):
+ optimizer = optimization_config.OptimizationConfig({}).optimizer.get()
+ self.assertIsNone(optimizer)
+
+ def test_no_lr_schedule(self):
+ lr = optimization_config.OptimizationConfig({}).learning_rate.get()
+ self.assertIsNone(lr)
+
+ def test_no_warmup_schedule(self):
+ warmup = optimization_config.OptimizationConfig({}).warmup.get()
+ self.assertIsNone(warmup)
+
+ def test_config(self):
+ opt_config = optimization_config.OptimizationConfig({
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {} # default config
+ },
+ 'learning_rate': {
+ 'type': 'polynomial',
+ 'polynomial': {}
+ },
+ 'warmup': {
+ 'type': 'linear'
+ }
+ })
+ self.assertEqual(opt_config.optimizer.get(), opt_cfg.SGDConfig())
+ self.assertEqual(opt_config.learning_rate.get(),
+ lr_cfg.PolynomialLrConfig())
+ self.assertEqual(opt_config.warmup.get(), lr_cfg.LinearWarmupConfig())
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/modeling/optimization/configs/optimizer_config.py b/modeling/official/modeling/optimization/configs/optimizer_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb52d6c2b40b401a43c8fe713f0ebdba96e1b8d4
--- /dev/null
+++ b/modeling/official/modeling/optimization/configs/optimizer_config.py
@@ -0,0 +1,374 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Dataclasses for optimizer configs."""
+from typing import List, Optional
+
+import dataclasses
+from official.modeling.hyperparams import base_config
+
+
+@dataclasses.dataclass
+class BaseOptimizerConfig(base_config.Config):
+ """Base optimizer config.
+
+ Attributes:
+ clipnorm: float >= 0 or None. If not None, Gradients will be clipped when
+ their L2 norm exceeds this value.
+ clipvalue: float >= 0 or None. If not None, Gradients will be clipped when
+ their absolute value exceeds this value.
+ global_clipnorm: float >= 0 or None. If not None, gradient of all weights is
+ clipped so that their global norm is no higher than this value
+ """
+ clipnorm: Optional[float] = None
+ clipvalue: Optional[float] = None
+ global_clipnorm: Optional[float] = None
+
+
+@dataclasses.dataclass
+class SGDConfig(BaseOptimizerConfig):
+ """Configuration for SGD optimizer.
+
+ The attributes for this class matches the arguments of tf_keras.optimizer.SGD.
+
+ Attributes:
+ name: name of the optimizer.
+ decay: decay rate for SGD optimizer.
+ nesterov: nesterov for SGD optimizer.
+ momentum: momentum for SGD optimizer.
+ """
+ name: str = "SGD"
+ decay: float = 0.0
+ nesterov: bool = False
+ momentum: float = 0.0
+
+
+# TODO(b/216129465): Merge this config with SGDConfig after the experimental
+# optimizer graduates.
+@dataclasses.dataclass
+class SGDExperimentalConfig(BaseOptimizerConfig):
+ """Configuration for SGD optimizer.
+
+ The attributes for this class matches the arguments of
+ `tf_keras.optimizer.experimental.SGD`.
+
+ Attributes:
+ name: name of the optimizer.
+ nesterov: nesterov for SGD optimizer.
+ momentum: momentum for SGD optimizer.
+ jit_compile: if True, jit compile will be used.
+ """
+ name: str = "SGD"
+ nesterov: bool = False
+ momentum: float = 0.0
+ jit_compile: bool = False
+
+
+@dataclasses.dataclass
+class RMSPropConfig(BaseOptimizerConfig):
+ """Configuration for RMSProp optimizer.
+
+ The attributes for this class matches the arguments of
+ tf_keras.optimizers.RMSprop.
+
+ Attributes:
+ name: name of the optimizer.
+ rho: discounting factor for RMSprop optimizer.
+ momentum: momentum for RMSprop optimizer.
+ epsilon: epsilon value for RMSprop optimizer, help with numerical stability.
+ centered: Whether to normalize gradients or not.
+ """
+ name: str = "RMSprop"
+ rho: float = 0.9
+ momentum: float = 0.0
+ epsilon: float = 1e-7
+ centered: bool = False
+
+
+@dataclasses.dataclass
+class AdagradConfig(BaseOptimizerConfig):
+ """Configuration for Adagrad optimizer.
+
+ The attributes of this class match the arguments of
+ tf_keras.optimizer.Adagrad.
+
+ Attributes:
+ name: name of the optimizer.
+ initial_accumulator_value: A floating point value. Starting value for the
+ accumulators, must be non-negative.
+ epsilon: A small floating point value to avoid zero denominator.
+ """
+ name: str = "Adagrad"
+ initial_accumulator_value: float = 0.1
+ epsilon: float = 1e-07
+
+
+@dataclasses.dataclass
+class AdamConfig(BaseOptimizerConfig):
+ """Configuration for Adam optimizer.
+
+ The attributes for this class matches the arguments of
+ tf_keras.optimizer.Adam.
+
+ Attributes:
+ name: name of the optimizer.
+ beta_1: decay rate for 1st order moments.
+ beta_2: decay rate for 2st order moments.
+ epsilon: epsilon value used for numerical stability in Adam optimizer.
+ amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
+ the paper "On the Convergence of Adam and beyond".
+ """
+ name: str = "Adam"
+ beta_1: float = 0.9
+ beta_2: float = 0.999
+ epsilon: float = 1e-07
+ amsgrad: bool = False
+
+
+@dataclasses.dataclass
+class AdamExperimentalConfig(BaseOptimizerConfig):
+ """Configuration for experimental Adam optimizer.
+
+ The attributes for this class matches the arguments of
+ `tf_keras.optimizer.experimental.Adam`.
+
+ Attributes:
+ name: name of the optimizer.
+ beta_1: decay rate for 1st order moments.
+ beta_2: decay rate for 2st order moments.
+ epsilon: epsilon value used for numerical stability in Adam optimizer.
+ amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
+ the paper "On the Convergence of Adam and beyond".
+ jit_compile: if True, jit compile will be used.
+ """
+ name: str = "Adam"
+ beta_1: float = 0.9
+ beta_2: float = 0.999
+ epsilon: float = 1e-07
+ amsgrad: bool = False
+ jit_compile: bool = False
+
+
+@dataclasses.dataclass
+class AdamWeightDecayConfig(BaseOptimizerConfig):
+ """Configuration for Adam optimizer with weight decay.
+
+ Attributes:
+ name: name of the optimizer.
+ beta_1: decay rate for 1st order moments.
+ beta_2: decay rate for 2st order moments.
+ epsilon: epsilon value used for numerical stability in the optimizer.
+ amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
+ the paper "On the Convergence of Adam and beyond".
+ weight_decay_rate: float. Weight decay rate. Default to 0.
+ include_in_weight_decay: list[str], or None. List of weight names to include
+ in weight decay.
+ exclude_from_weight_decay: list[str], or None. List of weight names to not
+ include in weight decay.
+ gradient_clip_norm: A positive float. Clips the gradients to this maximum
+ L2-norm. Default to 1.0.
+ """
+ name: str = "AdamWeightDecay"
+ beta_1: float = 0.9
+ beta_2: float = 0.999
+ epsilon: float = 1e-07
+ amsgrad: bool = False
+ weight_decay_rate: float = 0.0
+ include_in_weight_decay: Optional[List[str]] = None
+ exclude_from_weight_decay: Optional[List[str]] = None
+ gradient_clip_norm: float = 1.0
+
+
+@dataclasses.dataclass
+class AdamWeightDecayExperimentalConfig(BaseOptimizerConfig):
+ """Configuration for Adam optimizer with weight decay.
+
+ Attributes:
+ name: name of the optimizer.
+ beta_1: decay rate for 1st order moments.
+ beta_2: decay rate for 2st order moments.
+ epsilon: epsilon value used for numerical stability in the optimizer.
+ amsgrad: boolean. Whether to apply AMSGrad variant of this algorithm from
+ the paper "On the Convergence of Adam and beyond".
+ weight_decay: float. Weight decay rate. Default to 0.
+ global_clipnorm: A positive float. Clips the gradients to this maximum
+ L2-norm. Default to 1.0.
+ jit_compile: if True, jit compile will be used.
+ """
+ name: str = "AdamWeightDecayExperimental"
+ beta_1: float = 0.9
+ beta_2: float = 0.999
+ epsilon: float = 1e-07
+ amsgrad: bool = False
+ weight_decay: float = 0.0
+ global_clipnorm: float = 1.0
+ jit_compile: bool = False
+
+
+@dataclasses.dataclass
+class LAMBConfig(BaseOptimizerConfig):
+ """Configuration for LAMB optimizer.
+
+ The attributes for this class matches the arguments of LAMB optimizer.
+
+ Attributes:
+ name: name of the optimizer.
+ beta_1: decay rate for 1st order moments.
+ beta_2: decay rate for 2st order moments.
+ epsilon: epsilon value used for numerical stability in LAMB optimizer.
+ weight_decay_rate: float. Weight decay rate. Default to 0.
+ exclude_from_weight_decay: List of regex patterns of variables excluded from
+ weight decay. Variables whose name contain a substring matching the
+ pattern will be excluded.
+ exclude_from_layer_adaptation: List of regex patterns of variables excluded
+ from layer adaptation. Variables whose name contain a substring matching
+ the pattern will be excluded.
+ """
+ name: str = "LAMB"
+ beta_1: float = 0.9
+ beta_2: float = 0.999
+ epsilon: float = 1e-6
+ weight_decay_rate: float = 0.0
+ exclude_from_weight_decay: Optional[List[str]] = None
+ exclude_from_layer_adaptation: Optional[List[str]] = None
+
+
+@dataclasses.dataclass
+class EMAConfig(BaseOptimizerConfig):
+ """Exponential moving average optimizer config.
+
+ Attributes:
+ name: 'str', name of the optimizer.
+ trainable_weights_only: 'bool', if True, only model trainable weights will
+ be updated. Otherwise, all model weights will be updated. This mainly
+ affects batch normalization parameters.
+ average_decay: 'float', average decay value.
+ start_step: 'int', start step to apply moving average.
+ dynamic_decay: 'bool', whether to apply dynamic decay or not.
+ """
+ name: str = "ExponentialMovingAverage"
+ trainable_weights_only: bool = True
+ average_decay: float = 0.99
+ start_step: int = 0
+ dynamic_decay: bool = True
+
+
+@dataclasses.dataclass
+class LARSConfig(BaseOptimizerConfig):
+ """Layer-wise adaptive rate scaling config.
+
+ Attributes:
+ name: 'str', name of the optimizer.
+ momentum: `float` hyperparameter >= 0 that accelerates gradient descent in
+ the relevant direction and dampens oscillations. Defaults to 0.9.
+ eeta: `float` LARS coefficient as used in the paper. Default set to LARS
+ coefficient from the paper. (eeta / weight_decay) determines the highest
+ scaling factor in LARS..
+ weight_decay_rate: `float` for weight decay.
+ nesterov: 'boolean' for whether to use nesterov momentum.
+ classic_momentum: `boolean` for whether to use classic (or popular)
+ momentum. The learning rate is applied during momentum update in classic
+ momentum, but after momentum for popular momentum.
+ exclude_from_weight_decay: A list of `string` for variable screening, if any
+ of the string appears in a variable's name, the variable will be excluded
+ for computing weight decay. For example, one could specify the list like
+ ['batch_normalization', 'bias'] to exclude BN and bias from weight decay.
+ exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but for
+ layer adaptation. If it is None, it will be defaulted the same as
+ exclude_from_weight_decay.
+ """
+ name: str = "LARS"
+ momentum: float = 0.9
+ eeta: float = 0.001
+ weight_decay_rate: float = 0.0
+ nesterov: bool = False
+ classic_momentum: bool = True
+ exclude_from_weight_decay: Optional[List[str]] = None
+ exclude_from_layer_adaptation: Optional[List[str]] = None
+
+
+@dataclasses.dataclass
+class SLIDEConfig(BaseOptimizerConfig):
+ """Configuration for SLIDE optimizer.
+
+ Details coming soon.
+ """
+ name: str = "SLIDE"
+ beta_1: float = 0.9
+ beta_2: float = 0.999
+ epsilon: float = 1e-6
+ weight_decay_rate: float = 0.0
+ weight_decay_type: str = "inner"
+ exclude_from_weight_decay: Optional[List[str]] = None
+ exclude_from_layer_adaptation: Optional[List[str]] = None
+ include_in_sparse_layer_adaptation: Optional[List[str]] = None
+ sparse_layer_learning_rate: float = 0.1
+ do_gradient_rescaling: bool = True
+ norm_type: str = "layer"
+ ratio_clip_norm: float = 1e5
+
+
+@dataclasses.dataclass
+class AdafactorConfig(BaseOptimizerConfig):
+ """Configuration for Adafactor optimizer.
+
+ The attributes for this class matches the arguments of the Adafactor
+ implementation.
+ """
+ name: str = "Adafactor"
+ factored: bool = True
+ multiply_by_parameter_scale: bool = True
+ beta1: Optional[float] = None
+ decay_rate: float = 0.8
+ step_offset: int = 0
+ clipping_threshold: float = 1.0
+ min_dim_size_to_factor: int = 128
+ epsilon1: float = 1e-30
+ epsilon2: float = 1e-3
+ weight_decay: Optional[float] = None
+ include_in_weight_decay: Optional[str] = None
+
+
+@dataclasses.dataclass
+class AdafactorKerasConfig(BaseOptimizerConfig):
+ """Configuration for AdafactorKeras optimizer.
+
+ The attributes for this class matches the arguments of the Adafactor
+ implementation provided by keras.
+
+ Attributes:
+ learning_rate: Initial value for the learning rate: either a floating
+ point value, or a
+ `tf_keras.optimizers.schedules.LearningRateSchedule` instance.
+ Defaults to 0.001.
+ beta_2_decay: float, defaults to -0.8. The decay rate of `beta_2`.
+ epsilon_1: float, defaults to 1e-30. A small offset to keep denominator
+ away from 0.
+ epsilon_2: float, defaults to 1e-3. A small offset to avoid learning
+ rate becoming too small by time.
+ clip_threshold: float, defaults to 1.0. Clipping threshold. This is a
+ part of Adafactor algorithm, independent from `clipnorm`, `clipvalue`
+ and `global_clipnorm`.
+ relative_step: bool, defaults to True. If `learning_rate` is a constant
+ and `relative_step=True`, learning rate will be adjusted based on
+ current iterations. This is a default learning rate decay in
+ Adafactor.
+ """
+ name: str = "Adafactor"
+ learning_rate: float = 0.001
+ beta_2_decay: float = -0.8
+ epsilon_1: float = 1e-30
+ epsilon_2: float = 1e-3
+ clip_threshold: float = 1.0
+ relative_step: bool = True
diff --git a/modeling/official/modeling/optimization/ema_optimizer.py b/modeling/official/modeling/optimization/ema_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f88842aea80848cc1ecb612cdf1f7e1dd3f149a
--- /dev/null
+++ b/modeling/official/modeling/optimization/ema_optimizer.py
@@ -0,0 +1,296 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Exponential moving average optimizer."""
+
+from typing import List, Optional
+
+import tensorflow as tf, tf_keras
+
+# pylint: disable=protected-access
+
+
+def maybe_merge_call(fn, strategy, *args, **kwargs):
+ """Maybe invoke `fn` via `merge_call` which may or may not be fulfilled.
+
+ The caller of this utility function requests to invoke `fn` via `merge_call`
+ at `tf.distribute.Strategy`'s best efforts. It is `tf.distribute`'s internal
+ whether the request is honored, depending on the `Strategy`. See
+ `tf.distribute.ReplicaContext.merge_call()` for more information.
+
+ This is adapted from tensorflow/python/distribute/merge_call_interim.py.
+
+ Args:
+ fn: the function to be invoked.
+ strategy: the `tf.distribute.Strategy` to call `fn` with.
+ *args: the positional arguments to be passed in to `fn`.
+ **kwargs: the keyword arguments to be passed in to `fn`.
+
+ Returns:
+ The return value of the `fn` call.
+ """
+ if strategy.extended._use_merge_call():
+ return tf.distribute.get_replica_context().merge_call(
+ fn, args=args, kwargs=kwargs
+ )
+ else:
+ return fn(strategy, *args, **kwargs)
+
+
+class ExponentialMovingAverage(tf_keras.optimizers.legacy.Optimizer):
+ """Optimizer that computes an exponential moving average of the variables.
+
+ Empirically it has been found that using the moving average of the trained
+ parameters of a deep network is better than using its trained parameters
+ directly. This optimizer allows you to compute this moving average and swap
+ the variables at save time so that any code outside of the training loop
+ will use by default the average values instead of the original ones.
+
+ Example of usage for training:
+ ```python
+ opt = tf_keras.optimizers.SGD(learning_rate)
+ opt = ExponentialMovingAverage(opt)
+
+ opt.shadow_copy(model)
+ ```
+
+ At test time, swap the shadow variables to evaluate on the averaged weights:
+ ```python
+ opt.swap_weights()
+ # Test eval the model here
+ opt.swap_weights()
+ ```
+ """
+
+ def __init__(self,
+ optimizer: tf_keras.optimizers.Optimizer,
+ trainable_weights_only: bool = True,
+ average_decay: float = 0.99,
+ start_step: int = 0,
+ dynamic_decay: bool = True,
+ name: str = 'ExponentialMovingAverage',
+ **kwargs):
+ """Construct a new ExponentialMovingAverage optimizer.
+
+ Args:
+ optimizer: `tf_keras.optimizers.Optimizer` that will be
+ used to compute and apply gradients.
+ trainable_weights_only: 'bool', if True, only model trainable weights will
+ be updated. Otherwise, all model weights will be updated. This mainly
+ affects batch normalization parameters.
+ average_decay: float. Decay to use to maintain the moving averages
+ of trained variables.
+ start_step: int. What step to start the moving average.
+ dynamic_decay: bool. Whether to change the decay based on the number
+ of optimizer updates. Decay will start at 0.1 and gradually increase
+ up to `average_decay` after each optimizer update. This behavior is
+ similar to `tf.train.ExponentialMovingAverage` in TF 1.x.
+ name: Optional name for the operations created when applying
+ gradients. Defaults to "moving_average".
+ **kwargs: keyword arguments. Allowed to be {`clipnorm`,
+ `clipvalue`, `lr`, `decay`}.
+ """
+ super().__init__(name, **kwargs)
+ self._average_decay = average_decay
+ self._trainable_weights_only = trainable_weights_only
+ self._start_step = tf.constant(start_step, tf.float32)
+ self._dynamic_decay = dynamic_decay
+ self._optimizer = optimizer
+ self._track_trackable(self._optimizer, 'ema_base_optimizer')
+ self._average_weights = None
+ self._model_weights = None
+
+ def shadow_copy(self, model: tf_keras.Model):
+ """Creates shadow variables for the given model weights."""
+
+ if self._trainable_weights_only:
+ self._model_weights = model.trainable_variables
+ else:
+ self._model_weights = model.variables
+ for var in self._model_weights:
+ self.add_slot(var, 'average', initializer='zeros')
+
+ self._average_weights = [
+ self.get_slot(var, 'average') for var in self._model_weights
+ ]
+
+ @property
+ def has_shadow_copy(self):
+ """Whether this optimizer has created shadow variables."""
+ return self._model_weights is not None and self._average_weights is not None
+
+ def _create_slots(self, var_list):
+ self._optimizer._create_slots(var_list=var_list) # pylint: disable=protected-access
+
+ def apply_gradients(self, grads_and_vars, name: Optional[str] = None):
+ result = self._optimizer.apply_gradients(grads_and_vars, name)
+ maybe_merge_call(self.update_average, tf.distribute.get_strategy())
+ return result
+
+ @tf.function
+ def update_average(self, strategy):
+ # Compute current decay value.
+ step = tf.cast(self.iterations, tf.float32)
+ if step < self._start_step:
+ decay = tf.constant(0., tf.float32)
+ elif self._dynamic_decay:
+ decay = step - self._start_step
+ decay = tf.minimum(self._average_decay, (1. + decay) / (10. + decay))
+ else:
+ decay = self._average_decay
+
+ def _apply_moving(average, normal):
+ diff = average - normal
+ average.assign_sub(tf.cast(1.0 - decay, average.dtype) * diff)
+ return average
+
+ # Update moving average with the latest value.
+ for average, normal in zip(self._average_weights, self._model_weights):
+ strategy.extended.update(
+ average, _apply_moving, args=(normal,), group=False
+ )
+
+ def swap_weights(self):
+ """Swap the average and moving weights.
+
+ This is a convenience method to allow one to evaluate the averaged weights
+ at test time. Loads the weights stored in `self._average` into the model,
+ keeping a copy of the original model weights. Swapping twice will return
+ the original weights.
+ """
+ if tf.distribute.in_cross_replica_context():
+ strategy = tf.distribute.get_strategy()
+ strategy.run(self._swap_weights, args=())
+ else:
+ raise ValueError(
+ 'Swapping weights must occur under a tf.distribute.Strategy.'
+ )
+
+ @tf.function
+ def _swap_weights(self):
+ def fn_0(a, b):
+ a.assign_add(b)
+ return a
+ def fn_1(b, a):
+ b.assign(a - b)
+ return b
+ def fn_2(a, b):
+ a.assign_sub(b)
+ return a
+
+ def _swap(strategy, a_and_b):
+ """Swap `a` and `b` and mirror to all devices."""
+ for a, b in a_and_b:
+ strategy.extended.update(a, fn_0, args=(b,)) # a = a + b
+ strategy.extended.update(b, fn_1, args=(a,)) # b = a - b
+ strategy.extended.update(a, fn_2, args=(b,)) # a = a - b
+
+ # Use merge_call if requested by strategy and always for TPUStrategy as
+ # the use of merge_call is not recommended and deprecated for other
+ # strategies such as mirrored strategy (MS) and multi-worker mirrored
+ # strategy (MWMS) if nccl/collective_ops are used, which can operate in
+ # pure replica context.
+ strategy = tf.distribute.get_strategy()
+ if isinstance(strategy, tf.distribute.TPUStrategy):
+ maybe_merge_call(
+ _swap,
+ strategy,
+ zip(self._average_weights, self._model_weights),
+ )
+ else:
+ _swap(
+ strategy,
+ zip(self._average_weights, self._model_weights),
+ )
+
+ def assign_average_vars(self, var_list: List[tf.Variable]):
+ """Assign variables in var_list with their respective averages.
+
+ Args:
+ var_list: List of model variables to be assigned to their average.
+ Returns:
+ assign_op: The op corresponding to the assignment operation of
+ variables to their average.
+ """
+ assign_op = tf.group([
+ var.assign(self.get_slot(var, 'average')) for var in var_list
+ if var.trainable
+ ])
+ return assign_op
+
+ def _create_hypers(self):
+ self._optimizer._create_hypers() # pylint: disable=protected-access
+
+ def _prepare(self, var_list):
+ return self._optimizer._prepare(var_list=var_list) # pylint: disable=protected-access
+
+ @property
+ def iterations(self):
+ return self._optimizer.iterations
+
+ @iterations.setter
+ def iterations(self, variable):
+ self._optimizer.iterations = variable
+
+ @property
+ def weights(self):
+ # return self._weights + self._optimizer.weights
+ return self._optimizer.weights
+
+ def variables(self):
+ return self._weights + [self.iterations]
+
+ @property
+ def lr(self):
+ return self._optimizer._get_hyper('learning_rate')
+
+ @lr.setter
+ def lr(self, lr):
+ self._optimizer._set_hyper('learning_rate', lr)
+
+ @property
+ def learning_rate(self):
+ return self._optimizer._get_hyper('learning_rate')
+
+ @learning_rate.setter
+ def learning_rate(self, learning_rate): # pylint: disable=redefined-outer-name
+ self._optimizer._set_hyper('learning_rate', learning_rate)
+
+ def _resource_apply_dense(self, grad, var):
+ return self._optimizer._resource_apply_dense(grad, var)
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ return self._optimizer._resource_apply_sparse(grad, var, indices)
+
+ def _resource_apply_sparse_duplicate_indices(self, grad, var, indices):
+ return self._optimizer._resource_apply_sparse_duplicate_indices(
+ grad, var, indices)
+
+ def get_config(self):
+ config = {
+ 'optimizer': tf_keras.optimizers.serialize(self._optimizer),
+ 'average_decay': self._average_decay,
+ 'start_step': self._start_step,
+ 'dynamic_decay': self._dynamic_decay,
+ }
+ base_config = super(ExponentialMovingAverage, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ optimizer = tf_keras.optimizers.deserialize(
+ config.pop('optimizer'),
+ custom_objects=custom_objects,
+ )
+ return cls(optimizer, **config)
diff --git a/modeling/official/modeling/optimization/lamb.py b/modeling/official/modeling/optimization/lamb.py
new file mode 100644
index 0000000000000000000000000000000000000000..9524026478dbcbe2be010ffd546f351bb6eb05f8
--- /dev/null
+++ b/modeling/official/modeling/optimization/lamb.py
@@ -0,0 +1,252 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Layer-wise Adaptive Moments (LAMB) optimizer.
+
+See paper [Large Batch Optimization for Deep Learning: Training BERT in
+76 minutes](https://arxiv.org/abs/1904.00962).
+"""
+import re
+from typing import Optional, Union, Callable, List
+
+import numpy as np
+import tensorflow as tf, tf_keras
+
+FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32]
+
+
+@tf_keras.utils.register_keras_serializable(package="Addons")
+class LAMB(tf_keras.optimizers.legacy.Optimizer):
+ """Optimizer that implements the Layer-wise Adaptive Moments (LAMB).
+
+ See paper [Large Batch Optimization for Deep Learning: Training BERT
+ in 76 minutes](https://arxiv.org/abs/1904.00962).
+ """
+
+ def __init__(
+ self,
+ learning_rate: Union[FloatTensorLike, Callable] = 0.001,
+ beta_1: FloatTensorLike = 0.9,
+ beta_2: FloatTensorLike = 0.999,
+ epsilon: FloatTensorLike = 1e-6,
+ weight_decay_rate: FloatTensorLike = 0.0,
+ exclude_from_weight_decay: Optional[List[str]] = None,
+ exclude_from_layer_adaptation: Optional[List[str]] = None,
+ name: str = "LAMB",
+ **kwargs,
+ ):
+ """Construct a new LAMB optimizer.
+
+ Args:
+ learning_rate: A `Tensor` or a floating point value. or a schedule that
+ is a `tf_keras.optimizers.schedules.LearningRateSchedule` The learning
+ rate.
+ beta_1: A `float` value or a constant `float` tensor. The exponential
+ decay rate for the 1st moment estimates.
+ beta_2: A `float` value or a constant `float` tensor. The exponential
+ decay rate for the 2nd moment estimates.
+ epsilon: A small constant for numerical stability.
+ weight_decay_rate: weight decay rate.
+ exclude_from_weight_decay: List of regex patterns of variables excluded
+ from weight decay. Variables whose name contain a substring matching
+ the pattern will be excluded.
+ exclude_from_layer_adaptation: List of regex patterns of variables
+ excluded from layer adaptation. Variables whose name contain a
+ substring matching the pattern will be excluded.
+ name: Optional name for the operations created when applying gradients.
+ Defaults to "LAMB".
+ **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`,
+ `lr`, `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is
+ clip gradients by value, `decay` is included for backward
+ compatibility to allow time inverse decay of learning rate. `lr` is
+ included for backward compatibility, recommended to use
+ `learning_rate` instead.
+ """
+ super().__init__(name, **kwargs)
+
+ # Just adding the square of the weights to the loss function is *not*
+ # the correct way of using L2 regularization/weight decay with Adam,
+ # since that will interact with the m and v parameters in strange ways.
+ #
+ # Instead we want to decay the weights in a manner that doesn't interact
+ # with the m/v parameters.
+ self._set_hyper("weight_decay_rate", weight_decay_rate)
+ self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
+
+ # This is learning rate decay for using keras learning rate schedule.
+ self._set_hyper("decay", self._initial_decay)
+ self._set_hyper("beta_1", beta_1)
+ self._set_hyper("beta_2", beta_2)
+ self.epsilon = epsilon or tf.backend_config.epsilon()
+ self.exclude_from_weight_decay = exclude_from_weight_decay
+ # exclude_from_layer_adaptation is set to exclude_from_weight_decay if
+ # the arg is None.
+ if exclude_from_layer_adaptation:
+ self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
+ else:
+ self.exclude_from_layer_adaptation = exclude_from_weight_decay
+
+ def _create_slots(self, var_list):
+ # Create slots for the first and second moments.
+ # Separate for-loops to respect the ordering of slot variables from v1.
+ for var in var_list:
+ self.add_slot(var, "m")
+ for var in var_list:
+ self.add_slot(var, "v")
+
+ def _prepare_local(self, var_device, var_dtype, apply_state):
+ super()._prepare_local(var_device, var_dtype, apply_state)
+
+ local_step = tf.cast(self.iterations + 1, var_dtype)
+ beta_1_t = tf.identity(self._get_hyper("beta_1", var_dtype))
+ beta_2_t = tf.identity(self._get_hyper("beta_2", var_dtype))
+ weight_decay_rate = tf.identity(
+ self._get_hyper("weight_decay_rate", var_dtype)
+ )
+ beta_1_power = tf.pow(beta_1_t, local_step)
+ beta_2_power = tf.pow(beta_2_t, local_step)
+ apply_state[(var_device, var_dtype)].update(
+ dict(
+ weight_decay_rate=weight_decay_rate,
+ epsilon=tf.convert_to_tensor(self.epsilon, var_dtype),
+ beta_1_t=beta_1_t,
+ beta_1_power=beta_1_power,
+ one_minus_beta_1_t=1 - beta_1_t,
+ beta_2_t=beta_2_t,
+ beta_2_power=beta_2_power,
+ one_minus_beta_2_t=1 - beta_2_t,
+ )
+ )
+
+ def _resource_apply_dense(self, grad, var, apply_state=None):
+ var_device, var_dtype = var.device, var.dtype.base_dtype
+ coefficients = (apply_state or {}).get(
+ (var_device, var_dtype)
+ ) or self._fallback_apply_state(var_device, var_dtype)
+
+ # m_t = beta1 * m + (1 - beta1) * g_t
+ m = self.get_slot(var, "m")
+ m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
+ m_t = m * coefficients["beta_1_t"] + m_scaled_g_values
+ m_t = m.assign(m_t, use_locking=self._use_locking)
+ # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
+ v = self.get_slot(var, "v")
+ v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
+ v_t = v * coefficients["beta_2_t"] + v_scaled_g_values
+ v_t = v.assign(v_t, use_locking=self._use_locking)
+
+ m_t_hat = m_t / (1.0 - coefficients["beta_1_power"])
+ v_t_hat = v_t / (1.0 - coefficients["beta_2_power"])
+
+ v_sqrt = tf.sqrt(v_t_hat)
+ update = m_t_hat / (v_sqrt + coefficients["epsilon"])
+
+ var_name = self._get_variable_name(var.name)
+ if self._do_use_weight_decay(var_name):
+ update += coefficients["weight_decay_rate"] * var
+
+ ratio = 1.0
+ if self._do_layer_adaptation(var_name):
+ w_norm = tf.norm(var, ord=2)
+ g_norm = tf.norm(update, ord=2)
+ ratio = tf.where(
+ tf.greater(w_norm, 0),
+ tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0),
+ 1.0,
+ )
+
+ var_update = var - ratio * coefficients["lr_t"] * update
+ return var.assign(var_update, use_locking=self._use_locking)
+
+ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
+ var_device, var_dtype = var.device, var.dtype.base_dtype
+ coefficients = (apply_state or {}).get(
+ (var_device, var_dtype)
+ ) or self._fallback_apply_state(var_device, var_dtype)
+
+ # m_t = beta1 * m + (1 - beta1) * g_t
+ m = self.get_slot(var, "m")
+ m_scaled_g_values = grad * coefficients["one_minus_beta_1_t"]
+ m_t = m.assign(m * coefficients["beta_1_t"], use_locking=self._use_locking)
+ with tf.control_dependencies([m_t]):
+ m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
+
+ # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
+ v = self.get_slot(var, "v")
+ v_scaled_g_values = (grad * grad) * coefficients["one_minus_beta_2_t"]
+ v_t = v.assign(v * coefficients["beta_2_t"], use_locking=self._use_locking)
+ with tf.control_dependencies([v_t]):
+ v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
+
+ m_t_hat = m_t / (1.0 - coefficients["beta_1_power"])
+ v_t_hat = v_t / (1.0 - coefficients["beta_2_power"])
+
+ v_sqrt = tf.sqrt(v_t_hat)
+ update = m_t_hat / (v_sqrt + coefficients["epsilon"])
+
+ var_name = self._get_variable_name(var.name)
+ if self._do_use_weight_decay(var_name):
+ update += coefficients["weight_decay_rate"] * var
+
+ ratio = 1.0
+ if self._do_layer_adaptation(var_name):
+ w_norm = tf.norm(var, ord=2)
+ g_norm = tf.norm(update, ord=2)
+ ratio = tf.where(
+ tf.greater(w_norm, 0),
+ tf.where(tf.greater(g_norm, 0), (w_norm / g_norm), 1.0),
+ 1.0,
+ )
+
+ var_update = var.assign_sub(
+ ratio * coefficients["lr_t"] * update, use_locking=self._use_locking
+ )
+ return tf.group(*[var_update, m_t, v_t])
+
+ def get_config(self):
+ config = super().get_config()
+ config.update({
+ "learning_rate": self._serialize_hyperparameter("learning_rate"),
+ "weight_decay_rate": self._serialize_hyperparameter(
+ "weight_decay_rate"
+ ),
+ "decay": self._serialize_hyperparameter("decay"),
+ "beta_1": self._serialize_hyperparameter("beta_1"),
+ "beta_2": self._serialize_hyperparameter("beta_2"),
+ "epsilon": self.epsilon,
+ })
+ return config
+
+ def _do_use_weight_decay(self, param_name):
+ """Whether to use L2 weight decay for `param_name`."""
+ if self.exclude_from_weight_decay:
+ for r in self.exclude_from_weight_decay:
+ if re.search(r, param_name) is not None:
+ return False
+ return True
+
+ def _do_layer_adaptation(self, param_name):
+ """Whether to do layer-wise learning rate adaptation for `param_name`."""
+ if self.exclude_from_layer_adaptation:
+ for r in self.exclude_from_layer_adaptation:
+ if re.search(r, param_name) is not None:
+ return False
+ return True
+
+ def _get_variable_name(self, param_name):
+ """Get the variable name from the tensor name."""
+ m = re.match("^(.*):\\d+$", param_name)
+ if m is not None:
+ param_name = m.group(1)
+ return param_name
diff --git a/modeling/official/modeling/optimization/lamb_test.py b/modeling/official/modeling/optimization/lamb_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..56fa513649c36a876c7951e0ec54e621d2ceb0a9
--- /dev/null
+++ b/modeling/official/modeling/optimization/lamb_test.py
@@ -0,0 +1,177 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for LAMB Optimizer."""
+import numpy as np
+from numpy import linalg
+
+import tensorflow as tf, tf_keras
+
+from official.modeling.optimization import lamb
+
+
+def lamb_update_numpy(param,
+ g_t,
+ t,
+ m,
+ v,
+ lr=0.001,
+ lamb_wd=0.0,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=1e-6):
+
+ m_t = beta1 * m + (1 - beta1) * g_t
+ v_t = beta2 * v + (1 - beta2) * g_t * g_t
+
+ m_t_hat = m_t / (1 - beta1**(t + 1))
+ v_t_hat = v_t / (1 - beta2**(t + 1))
+ update = m_t_hat / (np.sqrt(v_t_hat) + epsilon)
+
+ update += lamb_wd * param
+
+ w_norm = linalg.norm(param, ord=2)
+ g_norm = linalg.norm(update, ord=2)
+ ratio = np.where(w_norm > 0, np.where(g_norm > 0, (w_norm / g_norm), 1.0),
+ 1.0)
+
+ param_t = param - ratio * lr * update
+ return param_t, m_t, v_t
+
+
+def get_beta_accumulators(opt, dtype):
+ local_step = tf.cast(opt.iterations + 1, dtype)
+ beta_1_t = tf.cast(opt._get_hyper("beta_1"), dtype)
+ beta_1_power = tf.math.pow(beta_1_t, local_step)
+ beta_2_t = tf.cast(opt._get_hyper("beta_2"), dtype)
+ beta_2_power = tf.math.pow(beta_2_t, local_step)
+ return (beta_1_power, beta_2_power)
+
+
+class LAMBTest(tf.test.TestCase):
+
+ def test_sparse(self):
+ dtype = tf.float32
+ # Initialize tf for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.0, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.0, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = tf.Variable(var0_np)
+ var1 = tf.Variable(var1_np)
+ grads0_np_indices = np.array([0, 2], dtype=np.int32)
+ grads0 = tf.IndexedSlices(
+ tf.constant(grads0_np[grads0_np_indices]),
+ tf.constant(grads0_np_indices),
+ tf.constant([3]),
+ )
+ grads1_np_indices = np.array([0, 2], dtype=np.int32)
+ grads1 = tf.IndexedSlices(
+ tf.constant(grads1_np[grads1_np_indices]),
+ tf.constant(grads1_np_indices),
+ tf.constant([3]),
+ )
+ opt = lamb.LAMB()
+
+ # Fetch params to validate initial values
+ np.testing.assert_allclose(np.asanyarray([1.0, 1.0, 2.0]), var0.numpy())
+ np.testing.assert_allclose(np.asanyarray([3.0, 3.0, 4.0]), var1.numpy())
+
+ # Run 3 steps of LAMB
+ for t in range(3):
+ beta_1_power, beta_2_power = get_beta_accumulators(opt, dtype)
+ self.assertAllClose(0.9 ** (t + 1), beta_1_power)
+ self.assertAllClose(0.999 ** (t + 1), beta_2_power)
+
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ var0_np, m0, v0 = lamb_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = lamb_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllClose(var0_np, var0.numpy())
+ self.assertAllClose(var1_np, var1.numpy())
+
+ def test_basic_with_learning_rate_decay(self):
+ dtype = tf.float32
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = tf.Variable(var0_np, name="var0")
+ var1 = tf.Variable(var1_np, name="var1")
+ grads0 = tf.constant(grads0_np)
+ grads1 = tf.constant(grads1_np)
+
+ learning_rate = 0.001
+ beta_1 = 0.9
+ beta_2 = 0.999
+ epsilon = 1e-7
+ decay = 0.5
+ lamb_wd = 0.01
+
+ opt = lamb.LAMB(
+ learning_rate=learning_rate,
+ beta_1=beta_1,
+ beta_2=beta_2,
+ epsilon=epsilon,
+ weight_decay_rate=lamb_wd,
+ decay=decay,
+ )
+
+ # Run 3 steps of LAMB
+ for t in range(3):
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ lr_np = learning_rate / (1 + decay * t)
+
+ var0_np, m0, v0 = lamb_update_numpy(
+ var0_np, grads0_np, t, m0, v0, lr=lr_np, lamb_wd=lamb_wd)
+ var1_np, m1, v1 = lamb_update_numpy(
+ var1_np, grads1_np, t, m1, v1, lr=lr_np, lamb_wd=lamb_wd)
+
+ # Validate updated params
+ self.assertAllClose(var0_np, var0.numpy())
+ self.assertAllClose(var1_np, var1.numpy())
+
+ def test_exclude_weight_decay(self):
+ opt = lamb.LAMB(
+ 0.01, weight_decay_rate=0.01, exclude_from_weight_decay=["var1"]
+ )
+ assert opt._do_use_weight_decay("var0")
+ assert not opt._do_use_weight_decay("var1")
+ assert not opt._do_use_weight_decay("var1_weight")
+
+ def test_exclude_layer_adaptation(self):
+ opt = lamb.LAMB(0.01, exclude_from_layer_adaptation=["var1"])
+ assert opt._do_layer_adaptation("var0")
+ assert not opt._do_layer_adaptation("var1")
+ assert not opt._do_layer_adaptation("var1_weight")
+
+ def test_serialization(self):
+ optimizer = lamb.LAMB(1e-4)
+ config = tf_keras.optimizers.serialize(optimizer, use_legacy_format=True)
+ new_optimizer = tf_keras.optimizers.deserialize(
+ config, use_legacy_format=True
+ )
+ assert new_optimizer.get_config() == optimizer.get_config()
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/modeling/optimization/lars.py b/modeling/official/modeling/optimization/lars.py
new file mode 100644
index 0000000000000000000000000000000000000000..de083b6254324953422f4d3e82b3efc20ca46e49
--- /dev/null
+++ b/modeling/official/modeling/optimization/lars.py
@@ -0,0 +1,186 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Layer-wise adaptive rate scaling optimizer."""
+import re
+from typing import Text, List, Optional
+
+import tensorflow as tf, tf_keras
+
+
+# pylint: disable=protected-access
+
+
+class LARS(tf_keras.optimizers.legacy.Optimizer):
+ """Layer-wise Adaptive Rate Scaling for large batch training.
+
+ Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
+ I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)
+ """
+
+ def __init__(self,
+ learning_rate: float = 0.01,
+ momentum: float = 0.9,
+ weight_decay_rate: float = 0.0,
+ eeta: float = 0.001,
+ nesterov: bool = False,
+ classic_momentum: bool = True,
+ exclude_from_weight_decay: Optional[List[Text]] = None,
+ exclude_from_layer_adaptation: Optional[List[Text]] = None,
+ name: Text = "LARS",
+ **kwargs):
+ """Constructs a LARSOptimizer.
+
+ Args:
+ learning_rate: `float` for learning rate. Defaults to 0.01.
+ momentum: `float` hyperparameter >= 0 that accelerates gradient descent
+ in the relevant direction and dampens oscillations. Defaults to 0.9.
+ weight_decay_rate: `float` for weight decay.
+ eeta: `float` LARS coefficient as used in the paper. Default set to LARS
+ coefficient from the paper. (eeta / weight_decay) determines the
+ highest scaling factor in LARS..
+ nesterov: 'boolean' for whether to use nesterov momentum.
+ classic_momentum: `boolean` for whether to use classic (or popular)
+ momentum. The learning rate is applied during momentum update in
+ classic momentum, but after momentum for popular momentum.
+ exclude_from_weight_decay: A list of `string` for variable screening, if
+ any of the string appears in a variable's name, the variable will be
+ excluded for computing weight decay. For example, one could specify
+ the list like ['batch_normalization', 'bias'] to exclude BN and bias
+ from weight decay.
+ exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but
+ for layer adaptation. If it is None, it will be defaulted the same as
+ exclude_from_weight_decay.
+ name: `Text` as optional name for the operations created when applying
+ gradients. Defaults to "LARS".
+ **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
+ `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip
+ gradients by value, `decay` is included for backward compatibility to
+ allow time inverse decay of learning rate. `lr` is included for
+ backward compatibility, recommended to use `learning_rate` instead.
+ """
+ super(LARS, self).__init__(name, **kwargs)
+
+ self._set_hyper("learning_rate", learning_rate)
+ self._set_hyper("decay", self._initial_decay)
+ self.momentum = momentum
+ self.weight_decay_rate = weight_decay_rate
+ self.eeta = eeta
+ self.nesterov = nesterov
+ self.classic_momentum = classic_momentum
+ self.exclude_from_weight_decay = exclude_from_weight_decay
+ # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
+ # arg is None.
+ if exclude_from_layer_adaptation:
+ self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
+ else:
+ self.exclude_from_layer_adaptation = exclude_from_weight_decay
+
+ def _create_slots(self, var_list):
+ for v in var_list:
+ self.add_slot(v, "momentum")
+
+ def _resource_apply_dense(self, grad, param, apply_state=None):
+ if grad is None or param is None:
+ return tf.no_op()
+
+ var_device, var_dtype = param.device, param.dtype.base_dtype
+ coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
+ self._fallback_apply_state(var_device, var_dtype))
+ learning_rate = coefficients["lr_t"]
+
+ param_name = param.name
+
+ v = self.get_slot(param, "momentum")
+
+ if self._use_weight_decay(param_name):
+ grad += self.weight_decay_rate * param
+
+ if self.classic_momentum:
+ trust_ratio = 1.0
+ if self._do_layer_adaptation(param_name):
+ w_norm = tf.norm(param, ord=2)
+ g_norm = tf.norm(grad, ord=2)
+ trust_ratio = tf.where(
+ tf.greater(w_norm, 0),
+ tf.where(tf.greater(g_norm, 0), (self.eeta * w_norm / g_norm), 1.0),
+ 1.0)
+ scaled_lr = learning_rate * trust_ratio
+
+ next_v = tf.multiply(self.momentum, v) + scaled_lr * grad
+ if self.nesterov:
+ update = tf.multiply(self.momentum, next_v) + scaled_lr * grad
+ else:
+ update = next_v
+ next_param = param - update
+ else:
+ next_v = tf.multiply(self.momentum, v) + grad
+ if self.nesterov:
+ update = tf.multiply(self.momentum, next_v) + grad
+ else:
+ update = next_v
+
+ trust_ratio = 1.0
+ if self._do_layer_adaptation(param_name):
+ w_norm = tf.norm(param, ord=2)
+ v_norm = tf.norm(update, ord=2)
+ trust_ratio = tf.where(
+ tf.greater(w_norm, 0),
+ tf.where(tf.greater(v_norm, 0), (self.eeta * w_norm / v_norm), 1.0),
+ 1.0)
+ scaled_lr = trust_ratio * learning_rate
+ next_param = param - scaled_lr * update
+
+ return tf.group(*[
+ param.assign(next_param, use_locking=False),
+ v.assign(next_v, use_locking=False)
+ ])
+
+ def _resource_apply_sparse(self, grad, handle, indices, apply_state):
+ raise NotImplementedError("Applying sparse gradients is not implemented.")
+
+ def _use_weight_decay(self, param_name):
+ """Whether to use L2 weight decay for `param_name`."""
+ if not self.weight_decay_rate:
+ return False
+ if self.exclude_from_weight_decay:
+ for r in self.exclude_from_weight_decay:
+ if re.search(r, param_name) is not None:
+ return False
+ return True
+
+ def _do_layer_adaptation(self, param_name):
+ """Whether to do layer-wise learning rate adaptation for `param_name`."""
+ if self.exclude_from_layer_adaptation:
+ for r in self.exclude_from_layer_adaptation:
+ if re.search(r, param_name) is not None:
+ return False
+ return True
+
+ def get_config(self):
+ config = super(LARS, self).get_config()
+ config.update({
+ "learning_rate": self._serialize_hyperparameter("learning_rate"),
+ "decay": self._serialize_hyperparameter("decay"),
+ "momentum": self.momentum,
+ "classic_momentum": self.classic_momentum,
+ "weight_decay_rate": self.weight_decay_rate,
+ "eeta": self.eeta,
+ "nesterov": self.nesterov,
+ })
+ return config
+
+ @classmethod
+ def from_config(cls, config, custom_objects=None):
+ return cls(**config)
diff --git a/modeling/official/modeling/optimization/legacy_adamw.py b/modeling/official/modeling/optimization/legacy_adamw.py
new file mode 100644
index 0000000000000000000000000000000000000000..16bd640cbef833ef5249dea64043cd4c56b96663
--- /dev/null
+++ b/modeling/official/modeling/optimization/legacy_adamw.py
@@ -0,0 +1,139 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Adam optimizer with weight decay that exactly matches the original BERT."""
+
+import re
+
+from absl import logging
+import tensorflow as tf, tf_keras
+
+
+class AdamWeightDecay(tf_keras.optimizers.legacy.Adam):
+ """Adam enables L2 weight decay and clip_by_global_norm on gradients.
+
+ [Warning!]: Keras optimizer supports gradient clipping and has an AdamW
+ implementation. Please consider evaluating the choice in Keras package.
+
+ Just adding the square of the weights to the loss function is *not* the
+ correct way of using L2 regularization/weight decay with Adam, since that will
+ interact with the m and v parameters in strange ways.
+
+ Instead we want to decay the weights in a manner that doesn't interact with
+ the m/v parameters. This is equivalent to adding the square of the weights to
+ the loss with plain (non-momentum) SGD.
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ beta_1=0.9,
+ beta_2=0.999,
+ epsilon=1e-7,
+ amsgrad=False,
+ weight_decay_rate=0.0,
+ include_in_weight_decay=None,
+ exclude_from_weight_decay=None,
+ gradient_clip_norm=1.0,
+ name='AdamWeightDecay',
+ **kwargs):
+ super(AdamWeightDecay, self).__init__(learning_rate, beta_1, beta_2,
+ epsilon, amsgrad, name, **kwargs)
+ self.weight_decay_rate = weight_decay_rate
+ self.gradient_clip_norm = gradient_clip_norm
+ self._include_in_weight_decay = include_in_weight_decay
+ self._exclude_from_weight_decay = exclude_from_weight_decay
+ logging.info('AdamWeightDecay gradient_clip_norm=%f', gradient_clip_norm)
+
+ def _prepare_local(self, var_device, var_dtype, apply_state):
+ super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype, # pytype: disable=attribute-error # typed-keras
+ apply_state)
+ apply_state[(var_device, var_dtype)]['weight_decay_rate'] = tf.constant(
+ self.weight_decay_rate, name='adam_weight_decay_rate')
+
+ def _decay_weights_op(self, var, learning_rate, apply_state):
+ do_decay = self._do_use_weight_decay(var.name)
+ if do_decay:
+ return var.assign_sub(
+ learning_rate * var *
+ apply_state[(var.device, var.dtype.base_dtype)]['weight_decay_rate'],
+ use_locking=self._use_locking)
+ return tf.no_op()
+
+ def apply_gradients(self,
+ grads_and_vars,
+ name=None,
+ experimental_aggregate_gradients=True):
+ grads, tvars = list(zip(*grads_and_vars))
+ if experimental_aggregate_gradients and self.gradient_clip_norm > 0.0:
+ # when experimental_aggregate_gradients = False, apply_gradients() no
+ # longer implicitly allreduce gradients, users manually allreduce gradient
+ # and passed the allreduced grads_and_vars. For now, the
+ # clip_by_global_norm will be moved to before the explicit allreduce to
+ # keep the math the same as TF 1 and pre TF 2.2 implementation.
+ (grads, _) = tf.clip_by_global_norm(
+ grads, clip_norm=self.gradient_clip_norm)
+ return super(AdamWeightDecay, self).apply_gradients(
+ zip(grads, tvars),
+ name=name,
+ experimental_aggregate_gradients=experimental_aggregate_gradients)
+
+ def _get_lr(self, var_device, var_dtype, apply_state):
+ """Retrieves the learning rate with the given state."""
+ if apply_state is None:
+ return self._decayed_lr_t[var_dtype], {}
+
+ apply_state = apply_state or {}
+ coefficients = apply_state.get((var_device, var_dtype))
+ if coefficients is None:
+ coefficients = self._fallback_apply_state(var_device, var_dtype)
+ apply_state[(var_device, var_dtype)] = coefficients
+
+ return coefficients['lr_t'], dict(apply_state=apply_state)
+
+ def _resource_apply_dense(self, grad, var, apply_state=None):
+ lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
+ decay = self._decay_weights_op(var, lr_t, apply_state)
+ with tf.control_dependencies([decay]):
+ return super(AdamWeightDecay,
+ self)._resource_apply_dense(grad, var, **kwargs) # pytype: disable=attribute-error # typed-keras
+
+ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
+ lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
+ decay = self._decay_weights_op(var, lr_t, apply_state)
+ with tf.control_dependencies([decay]):
+ return super(AdamWeightDecay,
+ self)._resource_apply_sparse(grad, var, indices, **kwargs) # pytype: disable=attribute-error # typed-keras
+
+ def get_config(self):
+ config = super(AdamWeightDecay, self).get_config()
+ config.update({
+ 'weight_decay_rate': self.weight_decay_rate,
+ })
+ return config
+
+ def _do_use_weight_decay(self, param_name):
+ """Whether to use L2 weight decay for `param_name`."""
+ if self.weight_decay_rate == 0:
+ return False
+
+ if self._include_in_weight_decay:
+ for r in self._include_in_weight_decay:
+ if re.search(r, param_name) is not None:
+ return True
+
+ if self._exclude_from_weight_decay:
+ for r in self._exclude_from_weight_decay:
+ if re.search(r, param_name) is not None:
+ return False
+ return True
diff --git a/modeling/official/modeling/optimization/lr_schedule.py b/modeling/official/modeling/optimization/lr_schedule.py
new file mode 100644
index 0000000000000000000000000000000000000000..03f0c0da0979ff809aec2ffae04fdd1db3acd551
--- /dev/null
+++ b/modeling/official/modeling/optimization/lr_schedule.py
@@ -0,0 +1,490 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Learning rate schedule classes."""
+
+import math
+from typing import Mapping, Any, Union, Optional
+
+import tensorflow as tf, tf_keras
+
+
+def _make_offset_wrapper(new_class_name: str, base_lr_class):
+ """Generates a offset wrapper of learning rate schedule.
+
+ It will returns a subclass of the `base_lr_class`, the subclass takes an
+ `offset` argument in the constructor. When the new class instance is called,
+ the behavior is:
+ new_class_object(step) = base_lr_class_object(step - offset)
+
+ Example:
+ CosineDecayWithOffset = _make_offset_wrapper(
+ 'CosineDecayWithOffset',
+ tf_keras.optimizers.schedules.CosineDecay)
+ # Use the lr:
+ lr = CosineDecayWithOffset(offset=100, initial_learning_rate=0.1,
+ decay_steps=1000)
+ lr(101) # equals to keras.optimizers.schedules.CosineDecay(...)(101-100)
+
+ Args:
+ new_class_name: the name of the new class.
+ base_lr_class: the base learning rate schedule class. Should be subclass of
+ tf_keras.optimizers.schedules.LearningRateSchedule
+
+ Returns:
+ A new class (subclass of the base_lr_class) that can take an offset.
+ """
+ assert issubclass(base_lr_class,
+ tf_keras.optimizers.schedules.LearningRateSchedule), (
+ "base_lr_class should be subclass of keras "
+ f"LearningRateSchedule, got {base_lr_class}")
+
+ # pylint: disable=protected-access,pointless-statement
+ def offset_learning_rate_init(self, offset=0, **kwargs):
+ """Construct learning rate schedule object.
+
+ When this object is called, its behavior is
+ self.__call__(step) == base_lr_class.__call__(step - offset)
+ Args:
+ self: this object.
+ offset: The offset when computing the learning rate schedule.
+ **kwargs: Pass through to base learning rate class constructor.
+ """
+ base_lr_class.__init__(self, **kwargs)
+ self._offset = offset
+
+ def offset_learning_rate_call(self, step):
+ step = tf.cast(step - self._offset, tf.float32)
+ return base_lr_class.__call__(self, step)
+
+ # pylint: enable=protected-access,pointless-statement
+
+ return type(
+ new_class_name, (base_lr_class,), {
+ "base_lr_class": base_lr_class,
+ "__init__": offset_learning_rate_init,
+ "__call__": offset_learning_rate_call
+ })
+
+
+PiecewiseConstantDecayWithOffset = _make_offset_wrapper(
+ "PiecewiseConstantDecayWithOffset",
+ tf_keras.optimizers.schedules.PiecewiseConstantDecay)
+PolynomialDecayWithOffset = _make_offset_wrapper(
+ "PolynomialDecayWithOffset", tf_keras.optimizers.schedules.PolynomialDecay)
+ExponentialDecayWithOffset = _make_offset_wrapper(
+ "ExponentialDecayWithOffset",
+ tf_keras.optimizers.schedules.ExponentialDecay)
+CosineDecayWithOffset = _make_offset_wrapper(
+ "CosineDecayWithOffset",
+ tf_keras.optimizers.schedules.CosineDecay,
+)
+
+
+class LinearWarmup(tf_keras.optimizers.schedules.LearningRateSchedule):
+ """Linear warmup schedule."""
+
+ def __init__(self,
+ after_warmup_lr_sched: Union[
+ tf_keras.optimizers.schedules.LearningRateSchedule, float],
+ warmup_steps: int,
+ warmup_learning_rate: float,
+ name: Optional[str] = None):
+ """Add linear warmup schedule to a learning rate schedule.
+
+ warmup_lr is the initial learning rate, the final learning rate of the
+ init_warmup period is the initial learning rate of lr_schedule in use.
+ The learning rate at each step linearly increased according to the following
+ formula:
+ learning_rate = warmup_lr + step / warmup_steps
+ * (final_warmup_lr - warmup_lr).
+ Using warmup overrides the learning rate schedule by the number of warmup
+ steps.
+
+ Args:
+ after_warmup_lr_sched: tf_keras.optimizers.schedules .LearningRateSchedule
+ or a constant.
+ warmup_steps: Number of the warmup steps.
+ warmup_learning_rate: Initial learning rate for the warmup.
+ name: Optional, name of warmup schedule.
+ """
+ super().__init__()
+ self._name = name
+ self._after_warmup_lr_sched = after_warmup_lr_sched
+ self._warmup_steps = warmup_steps
+ self._init_warmup_lr = warmup_learning_rate
+ if isinstance(after_warmup_lr_sched,
+ tf_keras.optimizers.schedules.LearningRateSchedule):
+ self._final_warmup_lr = after_warmup_lr_sched(warmup_steps)
+ else:
+ self._final_warmup_lr = tf.cast(after_warmup_lr_sched, dtype=tf.float32)
+
+ def __call__(self, step: int):
+
+ global_step = tf.cast(step, dtype=tf.float32)
+
+ linear_warmup_lr = (
+ self._init_warmup_lr + global_step / self._warmup_steps *
+ (self._final_warmup_lr - self._init_warmup_lr))
+
+ if isinstance(self._after_warmup_lr_sched,
+ tf_keras.optimizers.schedules.LearningRateSchedule):
+ after_warmup_lr = self._after_warmup_lr_sched(step)
+ else:
+ after_warmup_lr = tf.cast(self._after_warmup_lr_sched, dtype=tf.float32)
+
+ lr = tf.cond(global_step < self._warmup_steps,
+ lambda: linear_warmup_lr,
+ lambda: after_warmup_lr)
+ return lr
+
+ def get_config(self) -> Mapping[str, Any]:
+ if isinstance(self._after_warmup_lr_sched,
+ tf_keras.optimizers.schedules.LearningRateSchedule):
+ config = {
+ "after_warmup_lr_sched": self._after_warmup_lr_sched.get_config()} # pytype: disable=attribute-error
+ else:
+ config = {"after_warmup_lr_sched": self._after_warmup_lr_sched} # pytype: disable=attribute-error
+
+ config.update({
+ "warmup_steps": self._warmup_steps,
+ "warmup_learning_rate": self._init_warmup_lr,
+ "name": self._name
+ })
+ return config
+
+
+class PolynomialWarmUp(tf_keras.optimizers.schedules.LearningRateSchedule):
+ """Applies polynomial warmup schedule on a given learning rate decay schedule."""
+
+ def __init__(self,
+ after_warmup_lr_sched: Union[
+ tf_keras.optimizers.schedules.LearningRateSchedule, float],
+ warmup_steps: int,
+ power: float = 1.0,
+ name: str = "PolynomialWarmup"):
+ super().__init__()
+ if isinstance(after_warmup_lr_sched,
+ tf_keras.optimizers.schedules.LearningRateSchedule):
+ self._initial_learning_rate = after_warmup_lr_sched(warmup_steps)
+ else:
+ self._initial_learning_rate = tf.cast(
+ after_warmup_lr_sched, dtype=tf.float32)
+
+ self._warmup_steps = warmup_steps
+ self._power = power
+ self._after_warmup_lr_sched = after_warmup_lr_sched
+ self._name = name
+
+ def __call__(self, step):
+ with tf.name_scope(self._name or "PolynomialWarmUp") as name:
+ # Implements polynomial warmup. i.e., if global_step < warmup_steps, the
+ # learning rate will be `global_step/num_warmup_steps * init_lr`.
+ global_step_float = tf.cast(step, tf.float32)
+ warmup_steps_float = tf.cast(self._warmup_steps, tf.float32)
+
+ if self._warmup_steps <= 0:
+ warmup_percent_done = 1.0
+ else:
+ # A zero `step` may cause Inf. So make `step` positive.
+ step_non_zero = tf.math.maximum(global_step_float, 1.0)
+ warmup_percent_done = step_non_zero / warmup_steps_float
+
+ warmup_learning_rate = (
+ self._initial_learning_rate *
+ tf.math.pow(warmup_percent_done, self._power))
+
+ if isinstance(self._after_warmup_lr_sched,
+ tf_keras.optimizers.schedules.LearningRateSchedule):
+ after_warmup_lr = self._after_warmup_lr_sched(step)
+ else:
+ after_warmup_lr = tf.cast(self._after_warmup_lr_sched, dtype=tf.float32)
+
+ return tf.cond(
+ global_step_float < warmup_steps_float,
+ lambda: warmup_learning_rate,
+ lambda: after_warmup_lr,
+ name=name)
+
+ def get_config(self) -> Mapping[str, Any]:
+ if isinstance(self._after_warmup_lr_sched,
+ tf_keras.optimizers.schedules.LearningRateSchedule):
+ config = {
+ "after_warmup_lr_sched": self._after_warmup_lr_sched.get_config()} # pytype: disable=attribute-error
+ else:
+ config = {"after_warmup_lr_sched": self._after_warmup_lr_sched} # pytype: disable=attribute-error
+
+ config.update({
+ "warmup_steps": self._warmup_steps,
+ "power": self._power,
+ "name": self._name
+ })
+ return config
+
+
+class DirectPowerDecay(tf_keras.optimizers.schedules.LearningRateSchedule):
+ """Learning rate schedule follows lr * (step)^power."""
+
+ def __init__(self,
+ initial_learning_rate: float,
+ power: float = 1.0,
+ name: str = "DirectPowerDecay"):
+ """Initialize configuration of the learning rate schedule.
+
+ Args:
+ initial_learning_rate: The initial learning rate.
+ power: The order of the polynomial.
+ name: Optional, name of learning rate schedule.
+ """
+ super().__init__()
+ self._initial_learning_rate = initial_learning_rate
+ self._power = power
+ self._name = name
+
+ def __call__(self, step):
+ with tf.name_scope(self._name or "DirectPowerDecay"):
+ step = tf.cast(step, tf.float32)
+ learning_rate = self._initial_learning_rate
+ # A zero `step` may cause Inf. So make `step` positive.
+ step_non_zero = tf.math.maximum(step, 1.0)
+ learning_rate *= tf.math.pow(step_non_zero, self._power)
+ return learning_rate
+
+ def get_config(self):
+ """Get the configuration of the learning rate schedule."""
+ return {
+ "initial_learning_rate": self._initial_learning_rate,
+ "power": self._power,
+ "name": self._name,
+ }
+
+
+class PowerAndLinearDecay(tf_keras.optimizers.schedules.LearningRateSchedule):
+ """Learning rate schedule with multiplied by linear decay at the end.
+
+ The schedule has the following behavoir.
+ Let offset_step = step - offset.
+ 1) offset_step < 0, the actual learning rate equals initial_learning_rate.
+ 2) offset_step <= total_decay_steps * (1 - linear_decay_fraction), the
+ actual learning rate equals lr * offset_step^power.
+ 3) total_decay_steps * (1 - linear_decay_fraction) <= offset_step <
+ total_decay_steps, the actual learning rate equals lr * offset_step^power *
+ (total_decay_steps - offset_step) / (total_decay_steps *
+ linear_decay_fraction).
+ 4) offset_step >= total_decay_steps, the actual learning rate equals zero.
+ """
+
+ def __init__(self,
+ initial_learning_rate: float,
+ total_decay_steps: int,
+ power: float = 1.0,
+ linear_decay_fraction: float = 0.1,
+ offset: int = 0,
+ name: str = "PowerAndLinearDecay"):
+ """Initialize configuration of the learning rate schedule.
+
+ Args:
+ initial_learning_rate: The initial learning rate.
+ total_decay_steps: The total number of steps for power + linear decay.
+ power: The order of the polynomial.
+ linear_decay_fraction: In the last `linear_decay_fraction` steps, the
+ learning rate will be multiplied by a linear decay.
+ offset: The offset applied to steps.
+ name: Optional, name of learning rate schedule.
+ """
+ super().__init__()
+ self._initial_learning_rate = initial_learning_rate
+ self._total_decay_steps = total_decay_steps
+ self._power = power
+ self._linear_decay_fraction = linear_decay_fraction
+ self._offset = offset
+ self._name = name
+
+ def __call__(self, step):
+ with tf.name_scope(self._name or "PowerAndLinearDecay"):
+ step = tf.cast(step - self._offset, tf.float32)
+ learning_rate = self._initial_learning_rate
+ # A zero `step` may cause Inf. So make `step` positive.
+ step_non_zero = tf.math.maximum(step, 1.0)
+ learning_rate *= tf.math.pow(step_non_zero, self._power)
+ if self._total_decay_steps * self._linear_decay_fraction > 0:
+ learning_rate *= tf.minimum(
+ 1.0, (self._total_decay_steps - step) /
+ (self._total_decay_steps * self._linear_decay_fraction))
+ learning_rate = tf.maximum(0.0, learning_rate)
+ return learning_rate
+
+ def get_config(self):
+ """Get the configuration of the learning rate schedule."""
+ return {
+ "initial_learning_rate": self._initial_learning_rate,
+ "total_decay_steps": self._total_decay_steps,
+ "power": self._power,
+ "linear_decay_fraction": self._linear_decay_fraction,
+ "offset": self._offset,
+ "name": self._name,
+ }
+
+
+class PowerDecayWithOffset(tf_keras.optimizers.schedules.LearningRateSchedule):
+ """Power learning rate decay with offset.
+
+ Learning rate equals to `pre_offset_learning_rate` if `step` < `offset`.
+ Otherwise, learning rate equals to lr * (step - offset)^power.
+ """
+
+ def __init__(self,
+ initial_learning_rate: float,
+ power: float = 1.0,
+ offset: int = 0,
+ pre_offset_learning_rate: float = 1.0e6,
+ name: str = "PowerDecayWithOffset"):
+ """Initialize configuration of the learning rate schedule.
+
+ Args:
+ initial_learning_rate: The initial learning rate.
+ power: The order of the polynomial.
+ offset: The offset when computing the power decay.
+ pre_offset_learning_rate: The maximum learning rate we'll use.
+ name: Optional, name of learning rate schedule.
+ """
+ super().__init__()
+ self._initial_learning_rate = initial_learning_rate
+ self._power = power
+ self._offset = offset
+ self._pre_offset_lr = pre_offset_learning_rate
+ self._name = name
+
+ def __call__(self, step):
+ with tf.name_scope(self._name or "PowerDecayWithOffset"):
+ step = tf.cast(step, tf.float32)
+ lr_after_offset = tf.math.pow(
+ tf.math.maximum(step - self._offset, 1.0), self._power) * (
+ self._initial_learning_rate)
+
+ sign = tf.cast(step > self._offset, tf.float32)
+ lr_combined = (1.0 - sign) * self._pre_offset_lr + sign * lr_after_offset
+ # Power may give infinitely large LR. So cap it with pre_offset_lr.
+ return tf.math.minimum(lr_combined, self._pre_offset_lr)
+
+ def get_config(self):
+ """Get the configuration of the learning rate schedule."""
+ return {
+ "initial_learning_rate": self._initial_learning_rate,
+ "power": self._power,
+ "offset": self._offset,
+ "pre_offset_learning_rate": self._pre_offset_lr,
+ "name": self._name,
+ }
+
+
+class StepCosineDecayWithOffset(
+ tf_keras.optimizers.schedules.LearningRateSchedule):
+ """Stepwise cosine learning rate decay with offset.
+
+ Learning rate is equivalent to one or more cosine decay(s) starting and
+ ending at each interval.
+
+ ExampleL
+
+ ```python
+ boundaries: [100000, 110000]
+ values: [1.0, 0.5]
+ lr_decayed_fn = (
+ lr_schedule.StepCosineDecayWithOffset(
+ boundaries,
+ values))
+ ```
+
+ from 0 to 100000 step, it will cosine decay from 1.0 to 0.5
+ from 100000 to 110000 step, it cosine decay from 0.5 to 0.0
+ """
+
+ def __init__(self,
+ boundaries,
+ values,
+ offset: int = 0,
+ name: str = "StepCosineDecayWithOffset"):
+ """Initialize configuration of the learning rate schedule.
+
+ Args:
+ boundaries: A list of `Tensor`s or `int`s with strictly
+ increasing entries, and with all elements having the same type as the
+ optimizer step.
+ values: A list of `Tensor`s or `float`s that specifies the
+ values for the intervals defined by `boundaries`. It should have one
+ more element than `boundaries`, and all elements should have the same
+ type.
+ offset: The offset when computing the power decay.
+ name: Optional, name of learning rate schedule.
+ """
+ super().__init__()
+ self.values = values
+ self.boundaries = boundaries
+ self.offset = offset
+ self.name = name
+
+ if len(self.values) < 1:
+ raise ValueError(f"Expect non empty {self.values}")
+ if len(self.boundaries) != len(self.values):
+ raise ValueError(
+ "Boundaries length is equal to learning rate levels length"
+ f"{len(self.boundaries)} != {len(self.values)}")
+
+ self.total_steps = (
+ [boundaries[i + 1] - boundaries[i] for i in range(len(boundaries) - 1)
+ ] + [0])
+
+ def __call__(self, global_step):
+ with tf.name_scope(self.name or "StepCosineDecayWithOffset"):
+ global_step = tf.cast(global_step - self.offset, tf.float32)
+ lr_levels = self.values
+ lr_steps = self.boundaries
+ level_total_steps = self.total_steps
+ num_levels = len(lr_levels)
+
+ init_lr = lr_levels[0]
+ next_init_lr = lr_levels[1] if num_levels > 1 else 0.
+
+ init_total_steps = level_total_steps[0]
+
+ cosine_learning_rate = ((init_lr - next_init_lr) * (tf.cos(
+ tf.constant(math.pi) * (global_step) /
+ (init_total_steps)) + 1.0) / 2.0 + next_init_lr)
+ learning_rate = cosine_learning_rate
+
+ for i in range(1, num_levels):
+ next_init_lr = lr_levels[i]
+ next_start_step = lr_steps[i]
+ next_total_steps = level_total_steps[i]
+ next_next_init_lr = lr_levels[i + 1] if num_levels > i + 1 else 0.
+
+ next_cosine_learning_rate = ((next_init_lr - next_next_init_lr) *
+ (tf.cos(
+ tf.constant(math.pi) *
+ (global_step - next_start_step) /
+ (next_total_steps)) + 1.0) / 2.0 +
+ next_next_init_lr)
+ learning_rate = tf.where(global_step >= next_start_step,
+ next_cosine_learning_rate, learning_rate)
+
+ return learning_rate
+
+ def get_config(self):
+ return {
+ "boundaries": self.boundaries,
+ "values": self.values,
+ "offset": self.offset,
+ "name": self.name
+ }
diff --git a/modeling/official/modeling/optimization/lr_schedule_test.py b/modeling/official/modeling/optimization/lr_schedule_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b8f3b50da7bd7f6499e550fe165093542b3faec
--- /dev/null
+++ b/modeling/official/modeling/optimization/lr_schedule_test.py
@@ -0,0 +1,109 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for lr_schedule."""
+from absl.testing import parameterized
+import tensorflow as tf, tf_keras
+
+from official.modeling.optimization import lr_schedule
+
+
+class PowerAndLinearDecayTest(tf.test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ dict(
+ testcase_name='power_only',
+ init_lr=1.0,
+ power=-1.0,
+ linear_decay_fraction=0.0,
+ total_decay_steps=100,
+ offset=0,
+ expected=[[0, 1.0], [1, 1.0], [40, 1. / 40.], [60, 1. / 60],
+ [100, 1. / 100]]),
+ dict(
+ testcase_name='linear_only',
+ init_lr=1.0,
+ power=0.0,
+ linear_decay_fraction=1.0,
+ total_decay_steps=100,
+ offset=0,
+ expected=[[0, 1.0], [1, 0.99], [40, 0.6], [60, 0.4], [100, 0.0]]),
+ dict(
+ testcase_name='general',
+ init_lr=1.0,
+ power=-1.0,
+ linear_decay_fraction=0.5,
+ total_decay_steps=100,
+ offset=0,
+ expected=[[0, 1.0], [1, 1.0], [40, 1. / 40.],
+ [60, 1. / 60. * 0.8], [100, 0.0]]),
+ dict(
+ testcase_name='offset',
+ init_lr=1.0,
+ power=-1.0,
+ linear_decay_fraction=0.5,
+ total_decay_steps=100,
+ offset=90,
+ expected=[[0, 1.0], [90, 1.0], [91, 1.0], [130, 1. / 40.],
+ [150, 1. / 60. * 0.8], [190, 0.0], [200, 0.0]]),
+ )
+ def test_power_linear_lr_schedule(self, init_lr, power, linear_decay_fraction,
+ total_decay_steps, offset, expected):
+ lr = lr_schedule.PowerAndLinearDecay(
+ initial_learning_rate=init_lr,
+ power=power,
+ linear_decay_fraction=linear_decay_fraction,
+ total_decay_steps=total_decay_steps,
+ offset=offset)
+ for step, value in expected:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+
+class OffsetLearningRateTest(tf.test.TestCase, parameterized.TestCase):
+
+ @parameterized.parameters(
+ dict(class_name=lr_schedule.PiecewiseConstantDecayWithOffset),
+ dict(class_name=lr_schedule.PolynomialDecayWithOffset),
+ dict(class_name=lr_schedule.ExponentialDecayWithOffset),
+ dict(class_name=lr_schedule.CosineDecayWithOffset),
+ )
+ def test_generated_docstring(self, class_name):
+ self.assertNotEmpty(class_name.__init__.__doc__)
+
+ @parameterized.parameters(
+ dict(
+ class_name=lr_schedule.PiecewiseConstantDecayWithOffset,
+ kwarg=dict(boundaries=[50, 80], values=[1.0, 0.5, 0.1])),
+ dict(
+ class_name=lr_schedule.PolynomialDecayWithOffset,
+ kwarg=dict(initial_learning_rate=1.0, decay_steps=100)),
+ dict(
+ class_name=lr_schedule.ExponentialDecayWithOffset,
+ kwarg=dict(
+ initial_learning_rate=1.0, decay_steps=100, decay_rate=0.5)),
+ dict(
+ class_name=lr_schedule.CosineDecayWithOffset,
+ kwarg=dict(initial_learning_rate=1.0, decay_steps=100)),
+ )
+ def test_offset(self, class_name, kwarg):
+ offset = 10
+ offset_lr = class_name(offset=offset, **kwarg)
+ base_lr = class_name.base_lr_class(**kwarg)
+ self.assertIsInstance(offset_lr, class_name)
+ for step in range(10, 101, 10):
+ self.assertEqual(offset_lr(step), base_lr(step - offset))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/modeling/optimization/optimizer_factory.py b/modeling/official/modeling/optimization/optimizer_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c37635537203fc223477cc5e128d55cab18103f
--- /dev/null
+++ b/modeling/official/modeling/optimization/optimizer_factory.py
@@ -0,0 +1,267 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Optimizer factory class."""
+from typing import Callable, List, Optional, Tuple, Union
+
+import gin
+import tensorflow as tf, tf_keras
+
+from official.modeling.optimization import slide_optimizer
+from official.modeling.optimization import adafactor_optimizer
+from official.modeling.optimization import ema_optimizer
+from official.modeling.optimization import lamb
+from official.modeling.optimization import lars
+from official.modeling.optimization import legacy_adamw
+from official.modeling.optimization import lr_schedule
+from official.modeling.optimization.configs import optimization_config as opt_cfg
+
+# Optimizer CLS to be used in both legacy and new path.
+SHARED_OPTIMIZERS = {
+ 'sgd_experimental': tf_keras.optimizers.experimental.SGD,
+ 'adam_experimental': tf_keras.optimizers.experimental.Adam,
+ 'adamw': legacy_adamw.AdamWeightDecay,
+ 'adamw_experimental': tf_keras.optimizers.experimental.AdamW,
+ 'lamb': lamb.LAMB,
+ 'lars': lars.LARS,
+ 'slide': slide_optimizer.SLIDE,
+ 'adafactor': adafactor_optimizer.Adafactor,
+ 'adafactor_keras': tf_keras.optimizers.Adafactor,
+}
+
+LEGACY_OPTIMIZERS_CLS = {
+ 'sgd': tf_keras.optimizers.legacy.SGD,
+ 'adam': tf_keras.optimizers.legacy.Adam,
+ 'rmsprop': tf_keras.optimizers.legacy.RMSprop,
+ 'adagrad': tf_keras.optimizers.legacy.Adagrad,
+}
+LEGACY_OPTIMIZERS_CLS.update(SHARED_OPTIMIZERS)
+
+NEW_OPTIMIZERS_CLS = {
+ 'sgd': tf_keras.optimizers.experimental.SGD,
+ 'adam': tf_keras.optimizers.experimental.Adam,
+ 'rmsprop': tf_keras.optimizers.experimental.RMSprop,
+ 'adagrad': tf_keras.optimizers.experimental.Adagrad,
+}
+NEW_OPTIMIZERS_CLS.update(SHARED_OPTIMIZERS)
+
+LR_CLS = {
+ 'stepwise': lr_schedule.PiecewiseConstantDecayWithOffset,
+ 'polynomial': lr_schedule.PolynomialDecayWithOffset,
+ 'exponential': lr_schedule.ExponentialDecayWithOffset,
+ 'cosine': lr_schedule.CosineDecayWithOffset,
+ 'power': lr_schedule.DirectPowerDecay,
+ 'power_linear': lr_schedule.PowerAndLinearDecay,
+ 'power_with_offset': lr_schedule.PowerDecayWithOffset,
+ 'step_cosine_with_offset': lr_schedule.StepCosineDecayWithOffset,
+}
+
+WARMUP_CLS = {
+ 'linear': lr_schedule.LinearWarmup,
+ 'polynomial': lr_schedule.PolynomialWarmUp
+}
+
+
+def register_optimizer_cls(key: str,
+ optimizer_config_cls: Union[
+ tf_keras.optimizers.Optimizer,
+ tf_keras.optimizers.legacy.Optimizer,
+ tf_keras.optimizers.experimental.Optimizer
+ ],
+ use_legacy_optimizer: bool = True):
+ """Register customize optimizer cls.
+
+ The user will still need to subclass data classes in
+ configs.optimization_config to be used with OptimizerFactory.
+
+ Args:
+ key: A string to that the optimizer_config_cls is registered with.
+ optimizer_config_cls: A class which inherits tf_keras.optimizers.Optimizer.
+ use_legacy_optimizer: A boolean that indicates if using legacy optimizers.
+ """
+ if use_legacy_optimizer:
+ if key in LEGACY_OPTIMIZERS_CLS:
+ raise ValueError('%s already registered in LEGACY_OPTIMIZERS_CLS.' % key)
+ LEGACY_OPTIMIZERS_CLS[key] = optimizer_config_cls
+ else:
+ if key in NEW_OPTIMIZERS_CLS:
+ raise ValueError('%s already registered in NEW_OPTIMIZERS_CLS.' % key)
+ NEW_OPTIMIZERS_CLS[key] = optimizer_config_cls
+
+
+class OptimizerFactory:
+ """Optimizer factory class.
+
+ This class builds learning rate and optimizer based on an optimization config.
+ To use this class, you need to do the following:
+ (1) Define optimization config, this includes optimizer, and learning rate
+ schedule.
+ (2) Initialize the class using the optimization config.
+ (3) Build learning rate.
+ (4) Build optimizer.
+
+ This is a typical example for using this class:
+
+ ```
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {'momentum': 0.9}
+ },
+ 'learning_rate': {
+ 'type': 'stepwise',
+ 'stepwise': {'boundaries': [10000, 20000],
+ 'values': [0.1, 0.01, 0.001]}
+ },
+ 'warmup': {
+ 'type': 'linear',
+ 'linear': {'warmup_steps': 500, 'warmup_learning_rate': 0.01}
+ }
+ }
+ opt_config = OptimizationConfig(params)
+ opt_factory = OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+ optimizer = opt_factory.build_optimizer(lr)
+ ```
+ """
+
+ def __init__(self, config: opt_cfg.OptimizationConfig):
+ """Initializing OptimizerFactory.
+
+ Args:
+ config: OptimizationConfig instance contain optimization config.
+ """
+ self._config = config
+ self._optimizer_config = config.optimizer.get()
+ self._optimizer_type = config.optimizer.type
+
+ self._use_ema = config.ema is not None
+ self._ema_config = config.ema
+
+ if self._optimizer_config is None:
+ raise ValueError('Optimizer type must be specified')
+
+ self._lr_config = config.learning_rate.get()
+ self._lr_type = config.learning_rate.type
+
+ if self._lr_type is None:
+ raise ValueError('Learning rate type must be specified')
+
+ self._warmup_config = config.warmup.get()
+ self._warmup_type = config.warmup.type
+
+ def build_learning_rate(self):
+ """Build learning rate.
+
+ Builds learning rate from config. Learning rate schedule is built according
+ to the learning rate config. If learning rate type is consant,
+ lr_config.learning_rate is returned.
+
+ Returns:
+ tf_keras.optimizers.schedules.LearningRateSchedule instance. If
+ learning rate type is consant, lr_config.learning_rate is returned.
+ """
+ if self._lr_type == 'constant':
+ lr = self._lr_config.learning_rate
+ else:
+ lr = LR_CLS[self._lr_type](**self._lr_config.as_dict())
+
+ if self._warmup_config:
+ lr = WARMUP_CLS[self._warmup_type](lr, **self._warmup_config.as_dict())
+
+ return lr
+
+ @gin.configurable
+ def build_optimizer(
+ self,
+ lr: Union[tf_keras.optimizers.schedules.LearningRateSchedule, float],
+ gradient_aggregator: Optional[Callable[
+ [List[Tuple[tf.Tensor, tf.Tensor]]], List[Tuple[tf.Tensor,
+ tf.Tensor]]]] = None,
+ gradient_transformers: Optional[List[Callable[
+ [List[Tuple[tf.Tensor, tf.Tensor]]], List[Tuple[tf.Tensor,
+ tf.Tensor]]]]] = None,
+ postprocessor: Optional[Callable[[tf_keras.optimizers.Optimizer],
+ tf_keras.optimizers.Optimizer]] = None,
+ use_legacy_optimizer: bool = True):
+ """Build optimizer.
+
+ Builds optimizer from config. It takes learning rate as input, and builds
+ the optimizer according to the optimizer config. Typically, the learning
+ rate built using self.build_lr() is passed as an argument to this method.
+
+ Args:
+ lr: A floating point value, or a
+ tf_keras.optimizers.schedules.LearningRateSchedule instance.
+ gradient_aggregator: Optional function to overwrite gradient aggregation.
+ gradient_transformers: Optional list of functions to use to transform
+ gradients before applying updates to Variables. The functions are
+ applied after gradient_aggregator. The functions should accept and
+ return a list of (gradient, variable) tuples. clipvalue, clipnorm,
+ global_clipnorm should not be set when gradient_transformers is passed.
+ postprocessor: An optional function for postprocessing the optimizer. It
+ takes an optimizer and returns an optimizer.
+ use_legacy_optimizer: A boolean that indicates if using legacy optimizers.
+
+ Returns:
+ `tf_keras.optimizers.legacy.Optimizer` or
+ `tf_keras.optimizers.experimental.Optimizer` instance.
+ """
+
+ optimizer_dict = self._optimizer_config.as_dict()
+ ## Delete clipnorm, clipvalue, global_clipnorm if None
+ if optimizer_dict['clipnorm'] is None:
+ del optimizer_dict['clipnorm']
+ if optimizer_dict['clipvalue'] is None:
+ del optimizer_dict['clipvalue']
+ if optimizer_dict['global_clipnorm'] is None:
+ del optimizer_dict['global_clipnorm']
+
+ optimizer_dict['learning_rate'] = lr
+ if gradient_aggregator is not None:
+ optimizer_dict['gradient_aggregator'] = gradient_aggregator
+ if gradient_transformers is not None:
+ optimizer_dict['gradient_transformers'] = gradient_transformers
+
+ if use_legacy_optimizer:
+ optimizer = LEGACY_OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
+ else:
+ if 'decay' in optimizer_dict:
+ raise ValueError(
+ '`decay` is deprecated in new Keras optimizer, please reflect the '
+ 'decay logic in `lr` or set `use_legacy_optimizer=True` to use the '
+ 'legacy optimizer.')
+ optimizer = NEW_OPTIMIZERS_CLS[self._optimizer_type](**optimizer_dict)
+
+ if self._use_ema:
+ if not use_legacy_optimizer:
+ raise ValueError(
+ 'EMA can only work with the legacy optimizer, please set '
+ '`use_legacy_optimizer=True`.')
+ optimizer = ema_optimizer.ExponentialMovingAverage(
+ optimizer, **self._ema_config.as_dict())
+ if postprocessor:
+ optimizer = postprocessor(optimizer)
+ if isinstance(optimizer, tf_keras.optimizers.Optimizer):
+ return optimizer
+ # The following check makes sure the function won't break in older TF
+ # version because of missing the experimental/legacy package.
+ if hasattr(tf_keras.optimizers, 'experimental'):
+ if isinstance(optimizer, tf_keras.optimizers.experimental.Optimizer):
+ return optimizer
+ if hasattr(tf_keras.optimizers, 'legacy'):
+ if isinstance(optimizer, tf_keras.optimizers.legacy.Optimizer):
+ return optimizer
+ raise TypeError('OptimizerFactory.build_optimizer returning a '
+ 'non-optimizer object: {}'.format(optimizer))
diff --git a/modeling/official/modeling/optimization/optimizer_factory_test.py b/modeling/official/modeling/optimization/optimizer_factory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..95422fafc42d5d2fb2deb91cf3149f0cef826d8c
--- /dev/null
+++ b/modeling/official/modeling/optimization/optimizer_factory_test.py
@@ -0,0 +1,530 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for optimizer_factory.py."""
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from official.modeling.optimization import optimizer_factory
+from official.modeling.optimization.configs import optimization_config
+
+
+class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
+
+ @parameterized.parameters(('sgd'), ('rmsprop'), ('adam'), ('adamw'), ('lamb'),
+ ('lars'), ('adagrad'))
+ def test_optimizers(self, optimizer_type):
+ params = {
+ 'optimizer': {
+ 'type': optimizer_type
+ },
+ 'learning_rate': {
+ 'type': 'constant',
+ 'constant': {
+ 'learning_rate': 0.1
+ }
+ }
+ }
+ optimizer_cls = optimizer_factory.LEGACY_OPTIMIZERS_CLS[optimizer_type]
+ expected_optimizer_config = optimizer_cls().get_config()
+ expected_optimizer_config['learning_rate'] = 0.1
+
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+ optimizer = opt_factory.build_optimizer(lr, postprocessor=lambda x: x)
+
+ self.assertIsInstance(optimizer, optimizer_cls)
+ self.assertEqual(expected_optimizer_config, optimizer.get_config())
+
+ @parameterized.parameters(('sgd'), ('rmsprop'), ('adam'), ('adamw'), ('lamb'),
+ ('lars'), ('adagrad'))
+ def test_new_optimizers(self, optimizer_type):
+ params = {
+ 'optimizer': {
+ 'type': optimizer_type
+ },
+ 'learning_rate': {
+ 'type': 'constant',
+ 'constant': {
+ 'learning_rate': 0.1
+ }
+ }
+ }
+ optimizer_cls = optimizer_factory.NEW_OPTIMIZERS_CLS[optimizer_type]
+ expected_optimizer_config = optimizer_cls().get_config()
+ expected_optimizer_config['learning_rate'] = 0.1
+
+ opt_config = optimization_config.OptimizationConfig(params)
+ if optimizer_type == 'sgd':
+ # Delete unsupported arg `decay` from SGDConfig.
+ delattr(opt_config.optimizer.sgd, 'decay')
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+ optimizer = opt_factory.build_optimizer(
+ lr, postprocessor=lambda x: x, use_legacy_optimizer=False)
+
+ self.assertIsInstance(optimizer, optimizer_cls)
+ self.assertEqual(expected_optimizer_config, optimizer.get_config())
+
+ def test_gradient_aggregator(self):
+ params = {
+ 'optimizer': {
+ 'type': 'adam',
+ },
+ 'learning_rate': {
+ 'type': 'constant',
+ 'constant': {
+ 'learning_rate': 1.0
+ }
+ }
+ }
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ # Dummy function to zero out gradients.
+ zero_grads = lambda gv: [(tf.zeros_like(g), v) for g, v in gv]
+
+ optimizer = opt_factory.build_optimizer(lr, gradient_aggregator=zero_grads)
+ if isinstance(optimizer, tf_keras.optimizers.experimental.Optimizer):
+ self.skipTest('New Keras optimizer does not support '
+ '`gradient_aggregator` arg.')
+
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+
+ grads0 = tf.constant([1.0, 1.0])
+ grads1 = tf.constant([1.0, 1.0])
+
+ grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
+ optimizer.apply_gradients(grads_and_vars)
+
+ self.assertAllClose(np.array([1.0, 2.0]), var0.numpy())
+ self.assertAllClose(np.array([3.0, 4.0]), var1.numpy())
+
+ @parameterized.parameters((None, None), (1.0, None), (None, 1.0))
+ def test_gradient_clipping(self, clipnorm, clipvalue):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {
+ 'clipnorm': clipnorm,
+ 'clipvalue': clipvalue
+ }
+ },
+ 'learning_rate': {
+ 'type': 'constant',
+ 'constant': {
+ 'learning_rate': 1.0
+ }
+ }
+ }
+
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+ optimizer = opt_factory.build_optimizer(lr)
+
+ var0 = tf.Variable([1.0, 2.0])
+ var1 = tf.Variable([3.0, 4.0])
+
+ grads0 = tf.constant([0.1, 0.1])
+ grads1 = tf.constant([2.0, 3.0])
+
+ grads_and_vars = list(zip([grads0, grads1], [var0, var1]))
+ optimizer.apply_gradients(grads_and_vars)
+
+ self.assertAllClose(np.array([0.9, 1.9]), var0.numpy())
+ if clipvalue is not None:
+ self.assertAllClose(np.array([2.0, 3.0]), var1.numpy())
+ elif clipnorm is not None:
+ self.assertAllClose(np.array([2.4452999, 3.1679497]), var1.numpy())
+ else:
+ self.assertAllClose(np.array([1.0, 1.0]), var1.numpy())
+
+ def test_missing_types(self):
+ params = {'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}}}
+ with self.assertRaises(ValueError):
+ optimizer_factory.OptimizerFactory(
+ optimization_config.OptimizationConfig(params))
+ params = {
+ 'learning_rate': {
+ 'type': 'stepwise',
+ 'stepwise': {
+ 'boundaries': [10000, 20000],
+ 'values': [0.1, 0.01, 0.001]
+ }
+ }
+ }
+ with self.assertRaises(ValueError):
+ optimizer_factory.OptimizerFactory(
+ optimization_config.OptimizationConfig(params))
+
+ def test_wrong_return_type(self):
+ optimizer_type = 'sgd'
+ params = {
+ 'optimizer': {
+ 'type': optimizer_type
+ },
+ 'learning_rate': {
+ 'type': 'constant',
+ 'constant': {
+ 'learning_rate': 0.1
+ }
+ }
+ }
+
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ with self.assertRaises(TypeError):
+ _ = opt_factory.build_optimizer(0.1, postprocessor=lambda x: None)
+
+
+# TODO(b/187559334) refactor lr_schedule tests into `lr_schedule_test.py`.
+
+ def test_stepwise_lr_schedule(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {
+ 'momentum': 0.9
+ }
+ },
+ 'learning_rate': {
+ 'type': 'stepwise',
+ 'stepwise': {
+ 'boundaries': [10000, 20000],
+ 'values': [0.1, 0.01, 0.001]
+ }
+ }
+ }
+ expected_lr_step_values = [[0, 0.1], [5000, 0.1], [10000, 0.1],
+ [10001, 0.01], [20000, 0.01], [20001, 0.001]]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+ def test_stepwise_lr_with_warmup_schedule(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {
+ 'momentum': 0.9
+ }
+ },
+ 'learning_rate': {
+ 'type': 'stepwise',
+ 'stepwise': {
+ 'boundaries': [10000, 20000],
+ 'values': [0.1, 0.01, 0.001]
+ }
+ },
+ 'warmup': {
+ 'type': 'linear',
+ 'linear': {
+ 'warmup_steps': 500,
+ 'warmup_learning_rate': 0.01
+ }
+ }
+ }
+ expected_lr_step_values = [[0, 0.01], [250, 0.055], [500, 0.1], [5500, 0.1],
+ [10000, 0.1], [10001, 0.01], [20000, 0.01],
+ [20001, 0.001]]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+ def test_exponential_lr_schedule(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {
+ 'momentum': 0.9
+ }
+ },
+ 'learning_rate': {
+ 'type': 'exponential',
+ 'exponential': {
+ 'initial_learning_rate': 0.1,
+ 'decay_steps': 1000,
+ 'decay_rate': 0.96,
+ 'staircase': True
+ }
+ }
+ }
+ expected_lr_step_values = [
+ [0, 0.1],
+ [999, 0.1],
+ [1000, 0.096],
+ [1999, 0.096],
+ [2000, 0.09216],
+ ]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+ def test_polynomial_lr_schedule(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {
+ 'momentum': 0.9
+ }
+ },
+ 'learning_rate': {
+ 'type': 'polynomial',
+ 'polynomial': {
+ 'initial_learning_rate': 0.1,
+ 'decay_steps': 1000,
+ 'end_learning_rate': 0.001
+ }
+ }
+ }
+
+ expected_lr_step_values = [[0, 0.1], [500, 0.0505], [1000, 0.001]]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+ def test_cosine_lr_schedule(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {
+ 'momentum': 0.9
+ }
+ },
+ 'learning_rate': {
+ 'type': 'cosine',
+ 'cosine': {
+ 'initial_learning_rate': 0.1,
+ 'decay_steps': 1000
+ }
+ }
+ }
+ expected_lr_step_values = [[0, 0.1], [250, 0.08535534], [500, 0.04999999],
+ [750, 0.01464466], [1000, 0]]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+ def test_constant_lr_with_warmup_schedule(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {
+ 'momentum': 0.9
+ }
+ },
+ 'learning_rate': {
+ 'type': 'constant',
+ 'constant': {
+ 'learning_rate': 0.1
+ }
+ },
+ 'warmup': {
+ 'type': 'linear',
+ 'linear': {
+ 'warmup_steps': 500,
+ 'warmup_learning_rate': 0.01
+ }
+ }
+ }
+
+ expected_lr_step_values = [[0, 0.01], [250, 0.055], [500, 0.1], [5000, 0.1],
+ [10000, 0.1], [20000, 0.1]]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+ def test_stepwise_lr_with_polynomial_warmup_schedule(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {
+ 'momentum': 0.9
+ }
+ },
+ 'learning_rate': {
+ 'type': 'stepwise',
+ 'stepwise': {
+ 'boundaries': [10000, 20000],
+ 'values': [0.1, 0.01, 0.001]
+ }
+ },
+ 'warmup': {
+ 'type': 'polynomial',
+ 'polynomial': {
+ 'warmup_steps': 500,
+ 'power': 2.
+ }
+ }
+ }
+ expected_lr_step_values = [[0, 0.0], [250, 0.025], [500, 0.1], [5500, 0.1],
+ [10000, 0.1], [10001, 0.01], [20000, 0.01],
+ [20001, 0.001]]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value, places=6)
+
+ def test_power_lr_schedule(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {
+ 'momentum': 0.9
+ }
+ },
+ 'learning_rate': {
+ 'type': 'power',
+ 'power': {
+ 'initial_learning_rate': 1.0,
+ 'power': -1.0
+ }
+ }
+ }
+ expected_lr_step_values = [[0, 1.0], [1, 1.0], [250, 1. / 250.]]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+ def test_power_linear_lr_schedule(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {
+ 'momentum': 0.9
+ }
+ },
+ 'learning_rate': {
+ 'type': 'power_linear',
+ 'power_linear': {
+ 'initial_learning_rate': 1.0,
+ 'power': -1.0,
+ 'linear_decay_fraction': 0.5,
+ 'total_decay_steps': 100,
+ 'offset': 0,
+ }
+ }
+ }
+ expected_lr_step_values = [[0, 1.0], [1, 1.0], [40, 1. / 40.],
+ [60, 1. / 60. * 0.8]]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+ def test_power_with_offset_lr_schedule(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {
+ 'momentum': 0.9
+ }
+ },
+ 'learning_rate': {
+ 'type': 'power_with_offset',
+ 'power_with_offset': {
+ 'initial_learning_rate': 1.0,
+ 'power': -1.0,
+ 'offset': 10,
+ 'pre_offset_learning_rate': 3.0,
+ }
+ }
+ }
+ expected_lr_step_values = [[1, 3.0], [10, 3.0], [20, 1. / 10.]]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+ def test_step_cosine_lr_schedule_with_warmup(self):
+ params = {
+ 'optimizer': {
+ 'type': 'sgd',
+ 'sgd': {
+ 'momentum': 0.9
+ }
+ },
+ 'learning_rate': {
+ 'type': 'step_cosine_with_offset',
+ 'step_cosine_with_offset': {
+ 'values': (0.0001, 0.00005),
+ 'boundaries': (0, 500000),
+ 'offset': 10000,
+ }
+ },
+ 'warmup': {
+ 'type': 'linear',
+ 'linear': {
+ 'warmup_steps': 10000,
+ 'warmup_learning_rate': 0.0
+ }
+ }
+ }
+ expected_lr_step_values = [[0, 0.0], [5000, 1e-4 / 2.0], [10000, 1e-4],
+ [20000, 9.994863e-05], [499999, 5e-05]]
+ opt_config = optimization_config.OptimizationConfig(params)
+ opt_factory = optimizer_factory.OptimizerFactory(opt_config)
+ lr = opt_factory.build_learning_rate()
+
+ for step, value in expected_lr_step_values:
+ self.assertAlmostEqual(lr(step).numpy(), value)
+
+
+class OptimizerFactoryRegistryTest(tf.test.TestCase):
+
+ def test_registry(self):
+
+ class MyClass():
+ pass
+
+ optimizer_factory.register_optimizer_cls('test', MyClass)
+ self.assertIn('test', optimizer_factory.LEGACY_OPTIMIZERS_CLS)
+ with self.assertRaisesRegex(ValueError, 'test already registered.*'):
+ optimizer_factory.register_optimizer_cls('test', MyClass)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/modeling/optimization/slide_optimizer.py b/modeling/official/modeling/optimization/slide_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a15cbe11296192943e104e7611d7ac704f7939c0
--- /dev/null
+++ b/modeling/official/modeling/optimization/slide_optimizer.py
@@ -0,0 +1,20 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""SLIDE optimizer.
+
+A new optimizer that will be open sourced soon.
+"""
+
+SLIDE = "Unimplemented"
diff --git a/modeling/official/modeling/performance.py b/modeling/official/modeling/performance.py
new file mode 100644
index 0000000000000000000000000000000000000000..821001a7069546ea3f003188e8ba04c8286916e5
--- /dev/null
+++ b/modeling/official/modeling/performance.py
@@ -0,0 +1,53 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions and classes related to training performance."""
+
+from absl import logging
+import tensorflow as tf, tf_keras
+
+
+def configure_optimizer(optimizer,
+ use_float16=False,
+ loss_scale=None,
+ use_graph_rewrite=None):
+ """Configures optimizer object with performance options."""
+ if use_graph_rewrite is not None:
+ logging.warning('`use_graph_rewrite` is deprecated inside '
+ '`configure_optimizer`. Please remove the usage.')
+ del use_graph_rewrite
+ if use_float16:
+ if loss_scale in (None, 'dynamic'):
+ optimizer = tf_keras.mixed_precision.LossScaleOptimizer(optimizer)
+ else:
+ # loss_scale is a number. We interpret that as a fixed loss scale.
+ optimizer = tf_keras.mixed_precision.LossScaleOptimizer(
+ optimizer, dynamic=False, initial_scale=loss_scale)
+ return optimizer
+
+
+def set_mixed_precision_policy(dtype, loss_scale=None):
+ """Sets the global `tf_keras.mixed_precision.Policy`."""
+ # TODO(b/191894773): Remove loss_scale argument
+ assert loss_scale is None, (
+ 'The loss_scale argument must be None. The argument exists for '
+ 'historical reasons and will be removed soon.')
+ if dtype == tf.float16:
+ tf_keras.mixed_precision.set_global_policy('mixed_float16')
+ elif dtype == tf.bfloat16:
+ tf_keras.mixed_precision.set_global_policy('mixed_bfloat16')
+ elif dtype == tf.float32:
+ tf_keras.mixed_precision.set_global_policy('float32')
+ else:
+ raise ValueError('Unexpected dtype: %s' % dtype)
diff --git a/modeling/official/modeling/privacy/__init__.py b/modeling/official/modeling/privacy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/modeling/privacy/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/modeling/privacy/configs.py b/modeling/official/modeling/privacy/configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..91e666d19882ab91bcad2b5a8ea65af96670c499
--- /dev/null
+++ b/modeling/official/modeling/privacy/configs.py
@@ -0,0 +1,26 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Configs for differential privacy."""
+import dataclasses
+
+from official.modeling.hyperparams import base_config
+
+
+@dataclasses.dataclass
+class DifferentialPrivacyConfig(base_config.Config):
+ # Applied to the gradients
+ # Setting to a large number so nothing is clipped.
+ clipping_norm: float = 100000000.0 # 10^9
+ noise_multiplier: float = 0.0
diff --git a/modeling/official/modeling/privacy/configs_test.py b/modeling/official/modeling/privacy/configs_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..dad44dc2a3ee76557b6094505ea3cd63af80acab
--- /dev/null
+++ b/modeling/official/modeling/privacy/configs_test.py
@@ -0,0 +1,41 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for configs."""
+
+import tensorflow as tf, tf_keras
+from official.modeling.privacy import configs
+
+
+class ConfigsTest(tf.test.TestCase):
+
+ def test_clipping_norm_default(self):
+ clipping_norm = configs.DifferentialPrivacyConfig().clipping_norm
+ self.assertEqual(100000000.0, clipping_norm)
+
+ def test_noise_multiplier_default(self):
+ noise_multiplier = configs.DifferentialPrivacyConfig().noise_multiplier
+ self.assertEqual(0.0, noise_multiplier)
+
+ def test_config(self):
+ dp_config = configs.DifferentialPrivacyConfig(
+ clipping_norm=1.0,
+ noise_multiplier=1.0,
+ )
+ self.assertEqual(1.0, dp_config.clipping_norm)
+ self.assertEqual(1.0, dp_config.noise_multiplier)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/modeling/privacy/ops.py b/modeling/official/modeling/privacy/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..85656bf105757ad48f50e5f75c5054c4795b0371
--- /dev/null
+++ b/modeling/official/modeling/privacy/ops.py
@@ -0,0 +1,63 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Ops for differential privacy (gradient) transforms."""
+
+from typing import List, Tuple
+import warnings
+
+import tensorflow as tf, tf_keras
+
+
+def clip_l2_norm(grads_vars: List[Tuple[tf.Tensor, tf.Tensor]],
+ l2_norm_clip: float) -> List[Tuple[tf.Tensor, tf.Tensor]]:
+ """DEPRECATED Clip gradients by global norm.
+
+ Args:
+ grads_vars: List of tuple of gradient and its corresponding variables
+ l2_norm_clip: Float for differential privacy norm
+
+ Returns:
+ List of clipped gradients and its corresponding variables
+ """
+ warnings.warn("`clip_l2_norm` deprecated.",
+ DeprecationWarning)
+
+ gradients = []
+ variables = []
+ for (g, v) in grads_vars:
+ gradients.append(g)
+ variables.append(v)
+ clipped_gradients = tf.clip_by_global_norm(gradients, l2_norm_clip)[0]
+ return list(zip(clipped_gradients, variables))
+
+
+def add_noise(grads_vars: List[Tuple[tf.Tensor, tf.Tensor]],
+ noise_stddev: float) -> List[Tuple[tf.Tensor, tf.Tensor]]:
+ """DEPRECATED Add noise to gradients.
+
+ Args:
+ grads_vars: List of tuple of gradient and its corresponding variables
+ noise_stddev: Noise multiplier
+
+ Returns:
+ List of noised gradients and its corresponding variables
+ """
+ warnings.warn("`add_noise` deprecated.", DeprecationWarning)
+
+ ret = []
+ for (g, v) in grads_vars:
+ noise = tf.random.normal(tf.shape(g), stddev=noise_stddev)
+ ret.append((g + noise, v))
+ return ret
diff --git a/modeling/official/modeling/privacy/ops_test.py b/modeling/official/modeling/privacy/ops_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f52386366d97d5a6305d05c01029cc5d6a6228df
--- /dev/null
+++ b/modeling/official/modeling/privacy/ops_test.py
@@ -0,0 +1,52 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for ops."""
+
+from unittest import mock
+
+import tensorflow as tf, tf_keras
+
+from official.modeling.privacy import ops
+
+
+class OpsTest(tf.test.TestCase):
+
+ def test_clip_l2_norm(self):
+ x = tf.constant([4.0, 3.0])
+ y = tf.constant([[12.0]])
+ tensors = [(x, x), (y, y)]
+ clipped = ops.clip_l2_norm(tensors, 1.0)
+ for a, b in zip(clipped, tensors):
+ self.assertAllClose(a[0], b[0] / 13.0) # sqrt(4^2 + 3^2 + 12 ^3) = 13
+ self.assertAllClose(a[1], b[1])
+
+ @mock.patch.object(tf.random,
+ 'normal',
+ autospec=True)
+ def test_add_noise(self, mock_random):
+ x = tf.constant([0.0, 0.0])
+ y = tf.constant([[0.0]])
+ tensors = [(x, x), (y, y)]
+ mock_random.side_effect = [tf.constant([1.0, 1.0]), tf.constant([[1.0]])]
+ added = ops.add_noise(tensors, 10.0)
+ for a, b in zip(added, tensors):
+ self.assertAllClose(a[0], b[0] + 1.0)
+ self.assertAllClose(a[1], b[1])
+ _, kwargs = mock_random.call_args
+ self.assertEqual(kwargs['stddev'], 10.0)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/modeling/tf_utils.py b/modeling/official/modeling/tf_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c0bd9c8964bb4f1dc4e3a39db7821adb1cffb47
--- /dev/null
+++ b/modeling/official/modeling/tf_utils.py
@@ -0,0 +1,372 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common TF utilities."""
+
+import functools
+import inspect
+import six
+import tensorflow as tf, tf_keras
+
+from tensorflow.python.util import deprecation
+from official.modeling import activations
+
+
+@deprecation.deprecated(
+ None,
+ "tf_keras.layers.Layer supports multiple positional args and kwargs as "
+ "input tensors. pack/unpack inputs to override __call__ is no longer "
+ "needed.")
+def pack_inputs(inputs):
+ """Pack a list of `inputs` tensors to a tuple.
+
+ Args:
+ inputs: a list of tensors.
+
+ Returns:
+ a tuple of tensors. if any input is None, replace it with a special constant
+ tensor.
+ """
+ inputs = tf.nest.flatten(inputs)
+ outputs = []
+ for x in inputs:
+ if x is None:
+ outputs.append(tf.constant(0, shape=[], dtype=tf.int32))
+ else:
+ outputs.append(x)
+ return tuple(outputs)
+
+
+@deprecation.deprecated(
+ None,
+ "tf_keras.layers.Layer supports multiple positional args and kwargs as "
+ "input tensors. pack/unpack inputs to override __call__ is no longer "
+ "needed.")
+def unpack_inputs(inputs):
+ """unpack a tuple of `inputs` tensors to a tuple.
+
+ Args:
+ inputs: a list of tensors.
+
+ Returns:
+ a tuple of tensors. if any input is a special constant tensor, replace it
+ with None.
+ """
+ inputs = tf.nest.flatten(inputs)
+ outputs = []
+ for x in inputs:
+ if is_special_none_tensor(x):
+ outputs.append(None)
+ else:
+ outputs.append(x)
+ x = tuple(outputs)
+
+ # To trick the very pointless 'unbalanced-tuple-unpacking' pylint check
+ # from triggering.
+ if len(x) == 1:
+ return x[0]
+ return tuple(outputs)
+
+
+def is_special_none_tensor(tensor):
+ """Checks if a tensor is a special None Tensor."""
+ return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
+
+
+def get_activation(identifier, use_keras_layer=False, **kwargs):
+ """Maps an identifier to a Python function, e.g., "relu" => `tf.nn.relu`.
+
+ It checks string first and if it is one of customized activation not in TF,
+ the corresponding activation will be returned. For non-customized activation
+ names and callable identifiers, always fallback to tf_keras.activations.get.
+
+ Prefers using keras layers when use_keras_layer=True. Now it only supports
+ 'relu', 'linear', 'identity', 'swish', 'mish', 'leaky_relu', and 'gelu'.
+
+ Args:
+ identifier: String name of the activation function or callable.
+ use_keras_layer: If True, use keras layer if identifier is allow-listed.
+ **kwargs: Keyword arguments to use to instantiate an activation function.
+ Available only for 'leaky_relu' and 'gelu' when using keras layers.
+ For example: get_activation('leaky_relu', use_keras_layer=True, alpha=0.1)
+
+ Returns:
+ A Python function corresponding to the activation function or a keras
+ activation layer when use_keras_layer=True.
+ """
+ if isinstance(identifier, six.string_types):
+ identifier = str(identifier).lower()
+ if use_keras_layer:
+ keras_layer_allowlist = {
+ "relu": "relu",
+ "linear": "linear",
+ "identity": "linear",
+ "swish": "swish",
+ "sigmoid": "sigmoid",
+ "relu6": tf.nn.relu6,
+ "leaky_relu": functools.partial(tf.nn.leaky_relu, **kwargs),
+ "hard_swish": activations.hard_swish,
+ "hard_sigmoid": activations.hard_sigmoid,
+ "mish": activations.mish,
+ "gelu": functools.partial(tf.nn.gelu, **kwargs),
+ }
+ if identifier in keras_layer_allowlist:
+ return tf_keras.layers.Activation(keras_layer_allowlist[identifier])
+ name_to_fn = {
+ "gelu": activations.gelu,
+ "simple_swish": activations.simple_swish,
+ "hard_swish": activations.hard_swish,
+ "relu6": activations.relu6,
+ "hard_sigmoid": activations.hard_sigmoid,
+ "identity": activations.identity,
+ "mish": activations.mish,
+ }
+ if identifier in name_to_fn:
+ return tf_keras.activations.get(name_to_fn[identifier])
+ return tf_keras.activations.get(identifier)
+
+
+def get_shape_list(tensor, expected_rank=None, name=None):
+ """Returns a list of the shape of tensor, preferring static dimensions.
+
+ Args:
+ tensor: A tf.Tensor object to find the shape of.
+ expected_rank: (optional) int. The expected rank of `tensor`. If this is
+ specified and the `tensor` has a different rank, and exception will be
+ thrown.
+ name: Optional name of the tensor for the error message.
+
+ Returns:
+ A list of dimensions of the shape of tensor. All static dimensions will
+ be returned as python integers, and dynamic dimensions will be returned
+ as tf.Tensor scalars.
+ """
+ if expected_rank is not None:
+ assert_rank(tensor, expected_rank, name)
+
+ shape = tensor.shape.as_list()
+
+ non_static_indexes = []
+ for (index, dim) in enumerate(shape):
+ if dim is None:
+ non_static_indexes.append(index)
+
+ if not non_static_indexes:
+ return shape
+
+ dyn_shape = tf.shape(tensor)
+ for index in non_static_indexes:
+ shape[index] = dyn_shape[index]
+ return shape
+
+
+def assert_rank(tensor, expected_rank, name=None):
+ """Raises an exception if the tensor rank is not of the expected rank.
+
+ Args:
+ tensor: A tf.Tensor to check the rank of.
+ expected_rank: Python integer or list of integers, expected rank.
+ name: Optional name of the tensor for the error message.
+
+ Raises:
+ ValueError: If the expected shape doesn't match the actual shape.
+ """
+ expected_rank_dict = {}
+ if isinstance(expected_rank, six.integer_types):
+ expected_rank_dict[expected_rank] = True
+ else:
+ for x in expected_rank:
+ expected_rank_dict[x] = True
+
+ actual_rank = tensor.shape.ndims
+ if actual_rank not in expected_rank_dict:
+ raise ValueError(
+ "For the tensor `%s`, the actual tensor rank `%d` (shape = %s) is not "
+ "equal to the expected tensor rank `%s`" %
+ (name, actual_rank, str(tensor.shape), str(expected_rank)))
+
+
+def safe_mean(losses):
+ """Computes a safe mean of the losses.
+
+ Args:
+ losses: `Tensor` whose elements contain individual loss measurements.
+
+ Returns:
+ A scalar representing the mean of `losses`. If `num_present` is zero,
+ then zero is returned.
+ """
+ total = tf.reduce_sum(losses)
+ num_elements = tf.cast(tf.size(losses), dtype=losses.dtype)
+ return tf.math.divide_no_nan(total, num_elements)
+
+
+def get_replica_id():
+ """Gets replica id depending on the environment."""
+ context = tf.distribute.get_replica_context()
+ if context is not None:
+ return context.replica_id_in_sync_group
+ else:
+ raise RuntimeError("Unknown replica context. The `get_replica_id` method "
+ "relies on TF 2.x tf.distribute API.")
+
+
+def cross_replica_concat(value, axis, name="cross_replica_concat"):
+ """Concatenates the given `value` across (GPU/TPU) cores, along `axis`.
+
+ In general, each core ("replica") will pass a
+ replica-specific value as `value` (corresponding to some element of a
+ data-parallel computation taking place across replicas).
+
+ The resulting concatenated `Tensor` will have the same shape as `value` for
+ all dimensions except `axis`, where it will be larger by a factor of the
+ number of replicas. It will also have the same `dtype` as `value`.
+
+ The position of a given replica's `value` within the resulting concatenation
+ is determined by that replica's replica ID. For
+ example:
+
+ With `value` for replica 0 given as
+
+ 0 0 0
+ 0 0 0
+
+ and `value` for replica 1 given as
+
+ 1 1 1
+ 1 1 1
+
+ the resulting concatenation along axis 0 will be
+
+ 0 0 0
+ 0 0 0
+ 1 1 1
+ 1 1 1
+
+ and this result will be identical across all replicas.
+
+ Note that this API only works in TF2 with `tf.distribute`.
+
+ Args:
+ value: The `Tensor` to concatenate across replicas. Each replica will have a
+ different value for this `Tensor`, and these replica-specific values will
+ be concatenated.
+ axis: The axis along which to perform the concatenation as a Python integer
+ (not a `Tensor`). E.g., `axis=0` to concatenate along the batch dimension.
+ name: A name for the operation (used to create a name scope).
+
+ Returns:
+ The result of concatenating `value` along `axis` across replicas.
+
+ Raises:
+ RuntimeError: when the batch (0-th) dimension is None.
+ """
+ with tf.name_scope(name):
+ context = tf.distribute.get_replica_context()
+ # Typically this could be hit only if the tensor is derived from a
+ # dataset with finite epochs and drop_remainder=False, where the last
+ # batch could of different batch size and then the dim-0 is of dynamic
+ # shape.
+ if value.shape.as_list()[0] is None:
+ raise RuntimeError(f"{value} has unknown batch.")
+ return context.all_gather(value, axis=axis)
+
+
+def clone_initializer(initializer):
+ # Keras initializer is going to be stateless, which mean reusing the same
+ # initializer will produce same init value when the shapes are the same.
+ if isinstance(initializer, tf_keras.initializers.Initializer):
+ return initializer.__class__.from_config(initializer.get_config())
+ # When the input is string/dict or other serialized configs, caller will
+ # create a new keras Initializer instance based on that, and we don't need to
+ # do anything
+ return initializer
+
+
+def serialize_keras_object(obj):
+ if hasattr(tf_keras.utils, "legacy"):
+ return tf_keras.utils.legacy.serialize_keras_object(obj)
+ else:
+ return tf_keras.utils.serialize_keras_object(obj)
+
+
+def deserialize_keras_object(
+ config, module_objects=None, custom_objects=None, printable_module_name=None
+):
+ if hasattr(tf_keras.utils, "legacy"):
+ return tf_keras.utils.legacy.deserialize_keras_object(
+ config, custom_objects, module_objects, printable_module_name
+ )
+ else:
+ return tf_keras.utils.deserialize_keras_object(
+ config, custom_objects, module_objects, printable_module_name
+ )
+
+
+def serialize_layer(layer, use_legacy_format=False):
+ if (
+ "use_legacy_format"
+ in inspect.getfullargspec(tf_keras.layers.serialize).args
+ ):
+ return tf_keras.layers.serialize(layer, use_legacy_format=use_legacy_format)
+ else:
+ return tf_keras.layers.serialize(layer)
+
+
+def serialize_initializer(initializer, use_legacy_format=False):
+ if (
+ "use_legacy_format"
+ in inspect.getfullargspec(tf_keras.initializers.serialize).args
+ ):
+ return tf_keras.initializers.serialize(
+ initializer, use_legacy_format=use_legacy_format
+ )
+ else:
+ return tf_keras.initializers.serialize(initializer)
+
+
+def serialize_regularizer(regularizer, use_legacy_format=False):
+ if (
+ "use_legacy_format"
+ in inspect.getfullargspec(tf_keras.regularizers.serialize).args
+ ):
+ return tf_keras.regularizers.serialize(
+ regularizer, use_legacy_format=use_legacy_format
+ )
+ else:
+ return tf_keras.regularizers.serialize(regularizer)
+
+
+def serialize_constraint(constraint, use_legacy_format=False):
+ if (
+ "use_legacy_format"
+ in inspect.getfullargspec(tf_keras.constraints.serialize).args
+ ):
+ return tf_keras.constraints.serialize(
+ constraint, use_legacy_format=use_legacy_format
+ )
+ else:
+ return tf_keras.constraints.serialize(constraint)
+
+
+def serialize_activation(activation, use_legacy_format=False):
+ if (
+ "use_legacy_format"
+ in inspect.getfullargspec(tf_keras.activations.serialize).args
+ ):
+ return tf_keras.activations.serialize(
+ activation, use_legacy_format=use_legacy_format
+ )
+ else:
+ return tf_keras.activations.serialize(activation)
diff --git a/modeling/official/modeling/tf_utils_test.py b/modeling/official/modeling/tf_utils_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..45a3e4fbbc8a7579c68904be933775d615d06e22
--- /dev/null
+++ b/modeling/official/modeling/tf_utils_test.py
@@ -0,0 +1,108 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for tf_utils."""
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.modeling import tf_utils
+
+
+def all_strategy_combinations():
+ return combinations.combine(
+ strategy=[
+ strategy_combinations.cloud_tpu_strategy,
+ # TODO(b/285797201):disable multi-gpu tests due to hanging.
+ # strategy_combinations.mirrored_strategy_with_two_gpus,
+ ],
+ mode='eager',
+ )
+
+
+class TFUtilsTest(tf.test.TestCase, parameterized.TestCase):
+
+ @combinations.generate(all_strategy_combinations())
+ def test_cross_replica_concat(self, strategy):
+ num_cores = strategy.num_replicas_in_sync
+
+ shape = (2, 3, 4)
+
+ def concat(axis):
+
+ @tf.function
+ def function():
+ replica_value = tf.fill(shape, tf_utils.get_replica_id())
+ return tf_utils.cross_replica_concat(replica_value, axis=axis)
+
+ return function
+
+ def expected(axis):
+ values = [np.full(shape, i) for i in range(num_cores)]
+ return np.concatenate(values, axis=axis)
+
+ per_replica_results = strategy.run(concat(axis=0))
+ replica_0_result = per_replica_results.values[0].numpy()
+ for value in per_replica_results.values[1:]:
+ self.assertAllClose(value.numpy(), replica_0_result)
+ self.assertAllClose(replica_0_result, expected(axis=0))
+
+ replica_0_result = strategy.run(concat(axis=1)).values[0].numpy()
+ self.assertAllClose(replica_0_result, expected(axis=1))
+
+ replica_0_result = strategy.run(concat(axis=2)).values[0].numpy()
+ self.assertAllClose(replica_0_result, expected(axis=2))
+
+ @combinations.generate(all_strategy_combinations())
+ def test_cross_replica_concat_gradient(self, strategy):
+ num_cores = strategy.num_replicas_in_sync
+
+ shape = (10, 5)
+
+ @tf.function
+ def function():
+ replica_value = tf.random.normal(shape)
+ with tf.GradientTape() as tape:
+ tape.watch(replica_value)
+ concat_value = tf_utils.cross_replica_concat(replica_value, axis=0)
+ output = tf.reduce_sum(concat_value)
+ return tape.gradient(output, replica_value)
+
+ per_replica_gradients = strategy.run(function)
+ for gradient in per_replica_gradients.values:
+ self.assertAllClose(gradient, num_cores * tf.ones(shape))
+
+ @parameterized.parameters(('relu', True), ('relu', False),
+ ('leaky_relu', False), ('leaky_relu', True),
+ ('mish', True), ('mish', False), ('gelu', True))
+ def test_get_activations(self, name, use_keras_layer):
+ fn = tf_utils.get_activation(name, use_keras_layer)
+ self.assertIsNotNone(fn)
+
+ @combinations.generate(all_strategy_combinations())
+ def test_get_leaky_relu_layer(self, strategy):
+ @tf.function
+ def forward(x):
+ fn = tf_utils.get_activation(
+ 'leaky_relu', use_keras_layer=True, alpha=0.1)
+ return strategy.run(fn, args=(x,)).values[0]
+
+ got = forward(tf.constant([-1]))
+ self.assertAllClose(got, tf.constant([-0.1]))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/nightly_requirements.txt b/modeling/official/nightly_requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..54e074d61d6261011fb527db68ce868fc996580e
--- /dev/null
+++ b/modeling/official/nightly_requirements.txt
@@ -0,0 +1,31 @@
+six
+google-api-python-client>=1.6.7
+kaggle>=1.3.9
+numpy>=1.20
+oauth2client
+pandas>=0.22.0
+psutil>=5.4.3
+py-cpuinfo>=3.3.0
+scipy>=0.19.1
+tensorflow-hub>=0.6.0
+tensorflow-model-optimization>=0.4.1
+tensorflow-datasets
+tf-keras-nightly
+gin-config
+tf_slim>=1.1.0
+Cython
+matplotlib
+pyyaml
+# CV related dependencies
+opencv-python-headless
+Pillow
+pycocotools
+# NLP related dependencies
+seqeval
+sentencepiece
+sacrebleu
+# Projects/vit dependencies
+immutabledict
+# Fix CI
+wrapt>=1.15
+
diff --git a/modeling/official/nlp/MODEL_GARDEN.md b/modeling/official/nlp/MODEL_GARDEN.md
new file mode 100644
index 0000000000000000000000000000000000000000..e357aafad4aaf0120b6b127c796a453731f34835
--- /dev/null
+++ b/modeling/official/nlp/MODEL_GARDEN.md
@@ -0,0 +1,79 @@
+# TF-NLP Model Garden
+## Introduction
+
+The TF-NLP library provides a collection of scripts for training and
+evaluating transformer-based models, on various tasks such as sentence
+classification, question answering, and translation. Additionally, we provide
+checkpoints of pretrained models which can be finetuned on downstream tasks.
+
+⚠️ Disclaimer: Checkpoints are based on training with publicly available datasets.
+Some datasets contain limitations, including non-commercial use limitations. Please review the terms and conditions made available by third parties before using
+the datasets provided. Checkpoints are licensed under
+[Apache 2.0](https://github.com/tensorflow/models/blob/master/LICENSE).
+
+⚠️ Disclaimer: Datasets hyperlinked from this page are not owned or distributed
+by Google. Such datasets are made available by third parties. Please review the
+terms and conditions made available by the third parties before using the data.
+
+### How to Train Models
+
+Model Garden can be easily installed with
+`pip install tf-models-nightly`. After installation, check out
+[this instruction](https://github.com/tensorflow/models/blob/master/official/nlp/docs/train.md)
+on how to train models with this codebase.
+
+
+By default, the experiment runs on GPUs. To run on TPUs, one should overwrite
+`runtime.distribution_strategy` and set the tpu address. See [RuntimeConfig](https://github.com/tensorflow/models/blob/master/official/core/config_definitions.py) for details.
+
+In general, the experiments can run with the folloing command by setting the
+corresponding `${TASK}`, `${TASK_CONFIG}`, `${MODEL_CONFIG}`.
+```
+EXPERIMENT=???
+TASK_CONFIG=???
+MODEL_CONFIG=???
+EXRTRA_PARAMS=???
+MODEL_DIR=??? # a-folder-to-hold-checkpoints-and-logs
+python3 train.py \
+ --experiment=${EXPERIMENT} \
+ --mode=train_and_eval \
+ --model_dir=${MODEL_DIR} \
+ --config_file=${TASK_CONFIG} \
+ --config_file=${MODEL_CONFIG} \
+ --params_override=${EXRTRA_PARAMS}
+```
+
+* `EXPERIMENT` can be found under `configs/`
+* `TASK_CONFIG` can be found under `configs/experiments/`
+* `MODEL_CONFIG` can be found under `configs/models/`
+
+#### Order of params override:
+1. `train.py` looks up the registered `ExperimentConfig` with `${EXPERIMENT}`
+2. Overrides params in `TaskConfig` in `${TASK_CONFIG}`
+3. Overrides params `model` in `TaskConfig` with `${MODEL_CONFIG}`
+4. Overrides any params in `ExperimentConfig` with `${EXTRA_PARAMS}`
+
+Note that
+1. `${TASK_CONFIG}`, `${MODEL_CONFIG}`, `${EXTRA_PARAMS}` can be optional when EXPERIMENT default is enough.
+2. `${TASK_CONFIG}`, `${MODEL_CONFIG}`, `${EXTRA_PARAMS}` are only guaranteed to be compatible to it's `${EXPERIMENT}` that defines it.
+
+## Experiments
+
+| NAME | EXPERIMENT | TASK_CONFIG | MODEL_CONFIG | EXRTRA_PARAMS |
+| ----------------- | ------------------------ | ------- | -------- | ----------- |
+| BERT-base GLUE/MNLI-matched finetune | [bert/sentence_prediction](https://github.com/tensorflow/models/blob/master/official/nlp/configs/finetuning_experiments.py) | [glue_mnli_matched.yaml](https://github.com/tensorflow/models/blob/master/official/nlp/configs/experiments/glue_mnli_matched.yaml) | [bert_en_uncased_base.yaml](https://github.com/tensorflow/models/blob/master/official/nlp/configs/models/bert_en_uncased_base.yaml) | data and bert-base hub inittask.train_data.input_path=/path-to-your-training-data,task.validation_data.input_path=/path-to-your-val-data,task.hub_module_url=https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4 |
+| BERT-base GLUE/MNLI-matched finetune | [bert/sentence_prediction](https://github.com/tensorflow/models/blob/master/official/nlp/configs/finetuning_experiments.py) | [glue_mnli_matched.yaml](https://github.com/tensorflow/models/blob/master/official/nlp/configs/experiments/glue_mnli_matched.yaml) | [bert_en_uncased_base.yaml](https://github.com/tensorflow/models/blob/master/official/nlp/configs/models/bert_en_uncased_base.yaml) | data and bert-base ckpt inittask.train_data.input_path=/path-to-your-training-data,task.validation_data.input_path=/path-to-your-val-data,task.init_checkpoint=gs://tf_model_garden/nlp/bert/uncased_L-12_H-768_A-12/bert_model.ckpt |
+| BERT-base SQuAD v1.1 finetune | [bert/squad](https://github.com/tensorflow/models/blob/master/official/nlp/configs/finetuning_experiments.py) | [squad_v1.yaml](https://github.com/tensorflow/models/blob/master/official/nlp/configs/experiments/squad_v1.yaml) | [bert_en_uncased_base.yaml](https://github.com/tensorflow/models/blob/master/official/nlp/configs/models/bert_en_uncased_base.yaml) | data and bert-base hub inittask.train_data.input_path=/path-to-your-training-data,task.validation_data.input_path=/path-to-your-val-data,task.hub_module_url=https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4 |
+|ALBERT-base SQuAD v1.1 finetune | [bert/squad](https://github.com/tensorflow/models/blob/master/official/nlp/configs/finetuning_experiments.py) | [squad_v1.yaml](https://github.com/tensorflow/models/blob/master/official/nlp/configs/experiments/squad_v1.yaml) | [albert_base.yaml](https://github.com/tensorflow/models/blob/master/official/nlp/configs/models/albert_base.yaml)| data and albert-base hub inittask.train_data.input_path=/path-to-your-training-data,task.validation_data.input_path=/path-to-your-val-data,task.hub_module_url=https://tfhub.dev/tensorflow/albert_en_base/3 |
+| Transformer-large WMT14/en-de scratch |[wmt_transformer/large](https://github.com/tensorflow/models/blob/master/official/nlp/configs/wmt_transformer_experiments.py)| | | ende-32k sentencepiecetask.sentencepiece_model_path='gs://tf_model_garden/nlp/transformer_wmt/ende_bpe_32k.model' |
+
+
+## Useful links
+
+[How to Train Models](https://github.com/tensorflow/models/blob/master/official/nlp/docs/train.md)
+
+[List of Pretrained Models for finetuning](https://github.com/tensorflow/models/blob/master/official/nlp/docs/pretrained_models.md)
+
+[How to Publish Models](https://github.com/tensorflow/models/blob/master/official/nlp/docs/tfhub.md)
+
+[TensorFlow blog on Model Garden](https://blog.tensorflow.org/2020/03/introducing-model-garden-for-tensorflow-2.html).
diff --git a/modeling/official/nlp/README.md b/modeling/official/nlp/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2fc949f288ab8e44967eba869d6e06c935784ebb
--- /dev/null
+++ b/modeling/official/nlp/README.md
@@ -0,0 +1,70 @@
+# TF-NLP Model Garden
+
+⚠️ Disclaimer: Datasets hyperlinked from this page are not owned or distributed
+by Google. Such datasets are made available by third parties. Please review the
+terms and conditions made available by the third parties before using the data.
+
+This codebase provides a Natural Language Processing modeling toolkit written in
+[TF2](https://www.tensorflow.org/guide/effective_tf2). It allows researchers and
+developers to reproduce state-of-the-art model results and train custom models
+to experiment new research ideas.
+
+## Features
+
+* Reusable and modularized modeling building blocks
+* State-of-the-art reproducible
+* Easy to customize and extend
+* End-to-end training
+* Distributed trainable on both GPUs and TPUs
+
+## Major components
+
+### Libraries
+
+We provide modeling library to allow users to train custom models for new
+research ideas. Detailed instructions can be found in READMEs in each folder.
+
+* [modeling/](modeling): modeling library that provides building blocks
+ (e.g.,Layers, Networks, and Models) that can be assembled into
+ transformer-based architectures.
+* [data/](data): binaries and utils for input preprocessing, tokenization,
+ etc.
+
+### State-of-the-Art models and examples
+
+We provide SoTA model implementations, pre-trained models, training and
+evaluation examples, and command lines. Detail instructions can be found in the
+READMEs for specific papers. Below are some papers implemented in the repository
+and more NLP projects can be found in the
+[`projects`](https://github.com/tensorflow/models/tree/master/official/projects)
+folder:
+
+1. [BERT](MODEL_GARDEN.md#available-model-configs): [BERT: Pre-training of Deep
+ Bidirectional Transformers for Language
+ Understanding](https://arxiv.org/abs/1810.04805) by Devlin et al., 2018
+2. [ALBERT](MODEL_GARDEN.md#available-model-configs):
+ [A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942)
+ by Lan et al., 2019
+3. [XLNet](MODEL_GARDEN.md):
+ [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237)
+ by Yang et al., 2019
+4. [Transformer for translation](MODEL_GARDEN.md#available-model-configs):
+ [Attention Is All You Need](https://arxiv.org/abs/1706.03762) by Vaswani et
+ al., 2017
+
+### Common Training Driver
+
+We provide a single common driver [train.py](train.py) to train above SoTA
+models on popular tasks. Please see [docs/train.md](docs/train.md) for more
+details.
+
+### Pre-trained models with checkpoints and TF-Hub
+
+We provide a large collection of baselines and checkpoints for NLP pre-trained
+models. Please see [docs/pretrained_models.md](docs/pretrained_models.md) for
+more details.
+
+## More Documentations
+
+Please read through the model training tutorials and references in the
+[docs/ folder](docs/README.md).
diff --git a/modeling/official/nlp/__init__.py b/modeling/official/nlp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/nlp/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/nlp/configs/__init__.py b/modeling/official/nlp/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/nlp/configs/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/nlp/configs/bert.py b/modeling/official/nlp/configs/bert.py
new file mode 100644
index 0000000000000000000000000000000000000000..325e7f7b97b33cd88dab7754623f61d5a540a47d
--- /dev/null
+++ b/modeling/official/nlp/configs/bert.py
@@ -0,0 +1,47 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Multi-head BERT encoder network with classification heads.
+
+Includes configurations and instantiation methods.
+"""
+from typing import List, Optional, Text
+
+import dataclasses
+
+from official.modeling.hyperparams import base_config
+from official.nlp.configs import encoders
+
+
+@dataclasses.dataclass
+class ClsHeadConfig(base_config.Config):
+ inner_dim: int = 0
+ num_classes: int = 2
+ activation: Optional[Text] = "tanh"
+ dropout_rate: float = 0.0
+ cls_token_idx: int = 0
+ name: Optional[Text] = None
+
+
+@dataclasses.dataclass
+class PretrainerConfig(base_config.Config):
+ """Pretrainer configuration."""
+ encoder: encoders.EncoderConfig = dataclasses.field(
+ default_factory=encoders.EncoderConfig
+ )
+ cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
+ mlm_activation: str = "gelu"
+ mlm_initializer_range: float = 0.02
+ # Currently only used for mobile bert.
+ mlm_output_weights_use_proj: bool = False
diff --git a/modeling/official/nlp/configs/electra.py b/modeling/official/nlp/configs/electra.py
new file mode 100644
index 0000000000000000000000000000000000000000..16ffa0d3cbb2277ba1ddcd2f6bdae8b13447dec4
--- /dev/null
+++ b/modeling/official/nlp/configs/electra.py
@@ -0,0 +1,40 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""ELECTRA model configurations and instantiation methods."""
+from typing import List
+
+import dataclasses
+
+from official.modeling.hyperparams import base_config
+from official.nlp.configs import bert
+from official.nlp.configs import encoders
+
+
+@dataclasses.dataclass
+class ElectraPretrainerConfig(base_config.Config):
+ """ELECTRA pretrainer configuration."""
+ num_masked_tokens: int = 76
+ sequence_length: int = 512
+ num_classes: int = 2
+ discriminator_loss_weight: float = 50.0
+ tie_embeddings: bool = True
+ disallow_correct: bool = False
+ generator_encoder: encoders.EncoderConfig = dataclasses.field(
+ default_factory=encoders.EncoderConfig
+ )
+ discriminator_encoder: encoders.EncoderConfig = dataclasses.field(
+ default_factory=encoders.EncoderConfig
+ )
+ cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)
diff --git a/modeling/official/nlp/configs/encoders.py b/modeling/official/nlp/configs/encoders.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bf4e665f424fd4d560c2baa46d81d972ce623b4
--- /dev/null
+++ b/modeling/official/nlp/configs/encoders.py
@@ -0,0 +1,773 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Transformer Encoders.
+
+Includes configurations and factory methods.
+"""
+import dataclasses
+from typing import Optional, Sequence, Union
+
+import gin
+import tensorflow as tf, tf_keras
+
+from official.modeling import hyperparams
+from official.modeling import tf_utils
+from official.nlp.modeling import layers
+from official.nlp.modeling import networks
+from official.projects.bigbird import encoder as bigbird_encoder
+
+
+@dataclasses.dataclass
+class BertEncoderConfig(hyperparams.Config):
+ """BERT encoder configuration."""
+ vocab_size: int = 30522
+ hidden_size: int = 768
+ num_layers: int = 12
+ num_attention_heads: int = 12
+ hidden_activation: str = "gelu"
+ intermediate_size: int = 3072
+ dropout_rate: float = 0.1
+ attention_dropout_rate: float = 0.1
+ max_position_embeddings: int = 512
+ type_vocab_size: int = 2
+ initializer_range: float = 0.02
+ embedding_size: Optional[int] = None
+ output_range: Optional[int] = None
+ return_all_encoder_outputs: bool = False
+ return_attention_scores: bool = False
+ # Pre/Post-LN Transformer
+ norm_first: bool = False
+
+
+@dataclasses.dataclass
+class FunnelEncoderConfig(hyperparams.Config):
+ """Funnel encoder configuration."""
+ vocab_size: int = 30522
+ hidden_size: int = 768
+ num_layers: int = 12
+ num_attention_heads: int = 12
+ max_position_embeddings: int = 512
+ type_vocab_size: int = 16
+ inner_dim: int = 3072
+ hidden_activation: str = "gelu"
+ approx_gelu: bool = True
+ dropout_rate: float = 0.1
+ attention_dropout_rate: float = 0.1
+ pool_type: str = "max"
+ pool_stride: Union[int, Sequence[Union[int, float]]] = 2
+ unpool_length: int = 0
+ initializer_range: float = 0.02
+ output_range: Optional[int] = None
+ embedding_width: Optional[int] = None
+ embedding_layer: Optional[tf_keras.layers.Layer] = None
+ norm_first: bool = False
+ share_rezero: bool = False
+ append_dense_inputs: bool = False
+ transformer_cls: str = "TransformerEncoderBlock"
+
+
+@dataclasses.dataclass
+class MobileBertEncoderConfig(hyperparams.Config):
+ """MobileBERT encoder configuration.
+
+ Attributes:
+ word_vocab_size: number of words in the vocabulary.
+ word_embed_size: word embedding size.
+ type_vocab_size: number of word types.
+ max_sequence_length: maximum length of input sequence.
+ num_blocks: number of transformer block in the encoder model.
+ hidden_size: the hidden size for the transformer block.
+ num_attention_heads: number of attention heads in the transformer block.
+ intermediate_size: the size of the "intermediate" (a.k.a., feed forward)
+ layer.
+ hidden_activation: the non-linear activation function to apply to the
+ output of the intermediate/feed-forward layer.
+ hidden_dropout_prob: dropout probability for the hidden layers.
+ attention_probs_dropout_prob: dropout probability of the attention
+ probabilities.
+ intra_bottleneck_size: the size of bottleneck.
+ initializer_range: The stddev of the truncated_normal_initializer for
+ initializing all weight matrices.
+ use_bottleneck_attention: Use attention inputs from the bottleneck
+ transformation. If true, the following `key_query_shared_bottleneck`
+ will be ignored.
+ key_query_shared_bottleneck: whether to share linear transformation for keys
+ and queries.
+ num_feedforward_networks: number of stacked feed-forward networks.
+ normalization_type: the type of normalization_type, only 'no_norm' and
+ 'layer_norm' are supported. 'no_norm' represents the element-wise linear
+ transformation for the student model, as suggested by the original
+ MobileBERT paper. 'layer_norm' is used for the teacher model.
+ classifier_activation: if using the tanh activation for the final
+ representation of the [CLS] token in fine-tuning.
+ """
+ word_vocab_size: int = 30522
+ word_embed_size: int = 128
+ type_vocab_size: int = 2
+ max_sequence_length: int = 512
+ num_blocks: int = 24
+ hidden_size: int = 512
+ num_attention_heads: int = 4
+ intermediate_size: int = 4096
+ hidden_activation: str = "gelu"
+ hidden_dropout_prob: float = 0.1
+ attention_probs_dropout_prob: float = 0.1
+ intra_bottleneck_size: int = 1024
+ initializer_range: float = 0.02
+ use_bottleneck_attention: bool = False
+ key_query_shared_bottleneck: bool = False
+ num_feedforward_networks: int = 1
+ normalization_type: str = "layer_norm"
+ classifier_activation: bool = True
+ input_mask_dtype: str = "int32"
+
+
+@dataclasses.dataclass
+class AlbertEncoderConfig(hyperparams.Config):
+ """ALBERT encoder configuration."""
+ vocab_size: int = 30000
+ embedding_width: int = 128
+ hidden_size: int = 768
+ num_layers: int = 12
+ num_attention_heads: int = 12
+ hidden_activation: str = "gelu"
+ intermediate_size: int = 3072
+ dropout_rate: float = 0.0
+ attention_dropout_rate: float = 0.0
+ max_position_embeddings: int = 512
+ type_vocab_size: int = 2
+ initializer_range: float = 0.02
+
+
+@dataclasses.dataclass
+class BigBirdEncoderConfig(hyperparams.Config):
+ """BigBird encoder configuration."""
+ vocab_size: int = 50358
+ hidden_size: int = 768
+ num_layers: int = 12
+ num_attention_heads: int = 12
+ hidden_activation: str = "gelu"
+ intermediate_size: int = 3072
+ dropout_rate: float = 0.1
+ attention_dropout_rate: float = 0.1
+ # Pre/Post-LN Transformer
+ norm_first: bool = False
+ max_position_embeddings: int = 4096
+ num_rand_blocks: int = 3
+ block_size: int = 64
+ type_vocab_size: int = 16
+ initializer_range: float = 0.02
+ embedding_width: Optional[int] = None
+ use_gradient_checkpointing: bool = False
+
+
+@dataclasses.dataclass
+class KernelEncoderConfig(hyperparams.Config):
+ """Linear encoder configuration."""
+ vocab_size: int = 30522
+ hidden_size: int = 768
+ num_layers: int = 12
+ num_attention_heads: int = 12
+ hidden_activation: str = "gelu"
+ intermediate_size: int = 3072
+ dropout_rate: float = 0.1
+ attention_dropout_rate: float = 0.1
+ # Pre/Post-LN Transformer
+ norm_first: bool = False
+ max_position_embeddings: int = 512
+ type_vocab_size: int = 2
+ initializer_range: float = 0.02
+ embedding_size: Optional[int] = None
+ feature_transform: str = "exp"
+ num_random_features: int = 256
+ redraw: bool = False
+ is_short_seq: bool = False
+ begin_kernel: int = 0
+ scale: Optional[float] = None
+
+
+@dataclasses.dataclass
+class ReuseEncoderConfig(hyperparams.Config):
+ """Reuse encoder configuration."""
+ vocab_size: int = 30522
+ hidden_size: int = 768
+ num_layers: int = 12
+ num_attention_heads: int = 12
+ hidden_activation: str = "gelu"
+ intermediate_size: int = 3072
+ dropout_rate: float = 0.1
+ attention_dropout_rate: float = 0.1
+ max_position_embeddings: int = 512
+ type_vocab_size: int = 2
+ initializer_range: float = 0.02
+ embedding_size: Optional[int] = None
+ output_range: Optional[int] = None
+ return_all_encoder_outputs: bool = False
+ # Pre/Post-LN Transformer
+ norm_first: bool = False
+ # Reuse transformer
+ reuse_attention: int = -1
+ use_relative_pe: bool = False
+ pe_max_seq_length: int = 512
+ max_reuse_layer_idx: int = 6
+
+
+@dataclasses.dataclass
+class XLNetEncoderConfig(hyperparams.Config):
+ """XLNet encoder configuration."""
+ vocab_size: int = 32000
+ num_layers: int = 24
+ hidden_size: int = 1024
+ num_attention_heads: int = 16
+ head_size: int = 64
+ inner_size: int = 4096
+ inner_activation: str = "gelu"
+ dropout_rate: float = 0.1
+ attention_dropout_rate: float = 0.1
+ attention_type: str = "bi"
+ bi_data: bool = False
+ tie_attention_biases: bool = False
+ memory_length: int = 0
+ same_length: bool = False
+ clamp_length: int = -1
+ reuse_length: int = 0
+ use_cls_mask: bool = False
+ embedding_width: int = 1024
+ initializer_range: float = 0.02
+ two_stream: bool = False
+
+
+@dataclasses.dataclass
+class QueryBertConfig(hyperparams.Config):
+ """Query BERT encoder configuration."""
+ vocab_size: int = 30522
+ hidden_size: int = 768
+ num_layers: int = 12
+ num_attention_heads: int = 12
+ hidden_activation: str = "gelu"
+ intermediate_size: int = 3072
+ dropout_rate: float = 0.1
+ attention_dropout_rate: float = 0.1
+ max_position_embeddings: int = 512
+ type_vocab_size: int = 2
+ initializer_range: float = 0.02
+ embedding_size: Optional[int] = None
+ output_range: Optional[int] = None
+ return_all_encoder_outputs: bool = False
+ return_attention_scores: bool = False
+ # Pre/Post-LN Transformer
+ norm_first: bool = False
+
+
+@dataclasses.dataclass
+class FNetEncoderConfig(hyperparams.Config):
+ """FNet encoder configuration."""
+ vocab_size: int = 30522
+ hidden_size: int = 768
+ num_layers: int = 12
+ num_attention_heads: int = 12
+ inner_activation: str = "gelu"
+ inner_dim: int = 3072
+ output_dropout: float = 0.1
+ attention_dropout: float = 0.1
+ max_sequence_length: int = 512
+ type_vocab_size: int = 2
+ initializer_range: float = 0.02
+ embedding_width: Optional[int] = None
+ output_range: Optional[int] = None
+ norm_first: bool = False
+ use_fft: bool = False
+ attention_layers: Sequence[int] = ()
+
+
+@dataclasses.dataclass
+class SparseMixerEncoderConfig(hyperparams.Config):
+ """SparseMixer encoder configuration."""
+ vocab_size: int = 30522
+ hidden_size: int = 768
+ num_layers: int = 14
+ moe_layers: Sequence[int] = (5, 6, 7, 8)
+ attention_layers: Sequence[int] = (10, 11, 12, 13)
+ num_experts: int = 16
+ train_capacity_factor: float = 1.
+ eval_capacity_factor: float = 1.
+ examples_per_group: float = 1.
+ use_fft: bool = False
+ num_attention_heads: int = 8
+ max_sequence_length: int = 512
+ type_vocab_size: int = 2
+ inner_dim: int = 3072
+ inner_activation: str = "gelu"
+ output_dropout: float = 0.1
+ attention_dropout: float = 0.1
+ initializer_range: float = 0.02
+ output_range: Optional[int] = None
+ embedding_width: Optional[int] = None
+ norm_first: bool = False
+
+
+@dataclasses.dataclass
+class EncoderConfig(hyperparams.OneOfConfig):
+ """Encoder configuration."""
+ type: Optional[str] = "bert"
+ albert: AlbertEncoderConfig = dataclasses.field(
+ default_factory=AlbertEncoderConfig
+ )
+ bert: BertEncoderConfig = dataclasses.field(default_factory=BertEncoderConfig)
+ bert_v2: BertEncoderConfig = dataclasses.field(
+ default_factory=BertEncoderConfig
+ )
+ bigbird: BigBirdEncoderConfig = dataclasses.field(
+ default_factory=BigBirdEncoderConfig
+ )
+ funnel: FunnelEncoderConfig = dataclasses.field(
+ default_factory=FunnelEncoderConfig
+ )
+ kernel: KernelEncoderConfig = dataclasses.field(
+ default_factory=KernelEncoderConfig
+ )
+ mobilebert: MobileBertEncoderConfig = dataclasses.field(
+ default_factory=MobileBertEncoderConfig
+ )
+ reuse: ReuseEncoderConfig = dataclasses.field(
+ default_factory=ReuseEncoderConfig
+ )
+ xlnet: XLNetEncoderConfig = dataclasses.field(
+ default_factory=XLNetEncoderConfig
+ )
+ query_bert: QueryBertConfig = dataclasses.field(
+ default_factory=QueryBertConfig
+ )
+ fnet: FNetEncoderConfig = dataclasses.field(default_factory=FNetEncoderConfig)
+ sparse_mixer: SparseMixerEncoderConfig = dataclasses.field(
+ default_factory=SparseMixerEncoderConfig
+ )
+ # If `any` is used, the encoder building relies on any.BUILDER.
+ any: hyperparams.Config = dataclasses.field(
+ default_factory=hyperparams.Config
+ )
+
+
+@gin.configurable
+def build_encoder(config: EncoderConfig,
+ embedding_layer: Optional[tf_keras.layers.Layer] = None,
+ encoder_cls=None,
+ bypass_config: bool = False):
+ """Instantiate a Transformer encoder network from EncoderConfig.
+
+ Args:
+ config: the one-of encoder config, which provides encoder parameters of a
+ chosen encoder.
+ embedding_layer: an external embedding layer passed to the encoder.
+ encoder_cls: an external encoder cls not included in the supported encoders,
+ usually used by gin.configurable.
+ bypass_config: whether to ignore config instance to create the object with
+ `encoder_cls`.
+
+ Returns:
+ An encoder instance.
+ """
+ if bypass_config:
+ return encoder_cls()
+ encoder_type = config.type
+ encoder_cfg = config.get()
+ if encoder_cls and encoder_cls.__name__ == "EncoderScaffold":
+ embedding_cfg = dict(
+ vocab_size=encoder_cfg.vocab_size,
+ type_vocab_size=encoder_cfg.type_vocab_size,
+ hidden_size=encoder_cfg.hidden_size,
+ max_seq_length=encoder_cfg.max_position_embeddings,
+ initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ dropout_rate=encoder_cfg.dropout_rate,
+ )
+ hidden_cfg = dict(
+ num_attention_heads=encoder_cfg.num_attention_heads,
+ intermediate_size=encoder_cfg.intermediate_size,
+ intermediate_activation=tf_utils.get_activation(
+ encoder_cfg.hidden_activation),
+ dropout_rate=encoder_cfg.dropout_rate,
+ attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+ kernel_initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ )
+ kwargs = dict(
+ embedding_cfg=embedding_cfg,
+ hidden_cfg=hidden_cfg,
+ num_hidden_instances=encoder_cfg.num_layers,
+ pooled_output_dim=encoder_cfg.hidden_size,
+ pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ return_all_layer_outputs=encoder_cfg.return_all_encoder_outputs,
+ dict_outputs=True)
+ return encoder_cls(**kwargs)
+
+ if encoder_type == "any":
+ encoder = encoder_cfg.BUILDER(encoder_cfg)
+ if not isinstance(encoder,
+ (tf.Module, tf_keras.Model, tf_keras.layers.Layer)):
+ raise ValueError("The BUILDER returns an unexpected instance. The "
+ "`build_encoder` should returns a tf.Module, "
+ "tf_keras.Model or tf_keras.layers.Layer. However, "
+ f"we get {encoder.__class__}")
+ return encoder
+
+ if encoder_type == "mobilebert":
+ return networks.MobileBERTEncoder(
+ word_vocab_size=encoder_cfg.word_vocab_size,
+ word_embed_size=encoder_cfg.word_embed_size,
+ type_vocab_size=encoder_cfg.type_vocab_size,
+ max_sequence_length=encoder_cfg.max_sequence_length,
+ num_blocks=encoder_cfg.num_blocks,
+ hidden_size=encoder_cfg.hidden_size,
+ num_attention_heads=encoder_cfg.num_attention_heads,
+ intermediate_size=encoder_cfg.intermediate_size,
+ intermediate_act_fn=encoder_cfg.hidden_activation,
+ hidden_dropout_prob=encoder_cfg.hidden_dropout_prob,
+ attention_probs_dropout_prob=encoder_cfg.attention_probs_dropout_prob,
+ intra_bottleneck_size=encoder_cfg.intra_bottleneck_size,
+ initializer_range=encoder_cfg.initializer_range,
+ use_bottleneck_attention=encoder_cfg.use_bottleneck_attention,
+ key_query_shared_bottleneck=encoder_cfg.key_query_shared_bottleneck,
+ num_feedforward_networks=encoder_cfg.num_feedforward_networks,
+ normalization_type=encoder_cfg.normalization_type,
+ classifier_activation=encoder_cfg.classifier_activation,
+ input_mask_dtype=encoder_cfg.input_mask_dtype)
+
+ if encoder_type == "albert":
+ return networks.AlbertEncoder(
+ vocab_size=encoder_cfg.vocab_size,
+ embedding_width=encoder_cfg.embedding_width,
+ hidden_size=encoder_cfg.hidden_size,
+ num_layers=encoder_cfg.num_layers,
+ num_attention_heads=encoder_cfg.num_attention_heads,
+ max_sequence_length=encoder_cfg.max_position_embeddings,
+ type_vocab_size=encoder_cfg.type_vocab_size,
+ intermediate_size=encoder_cfg.intermediate_size,
+ activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
+ dropout_rate=encoder_cfg.dropout_rate,
+ attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+ initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ dict_outputs=True)
+
+ if encoder_type == "bigbird":
+ # TODO(frederickliu): Support use_gradient_checkpointing and update
+ # experiments to use the EncoderScaffold only.
+ if encoder_cfg.use_gradient_checkpointing:
+ return bigbird_encoder.BigBirdEncoder(
+ vocab_size=encoder_cfg.vocab_size,
+ hidden_size=encoder_cfg.hidden_size,
+ num_layers=encoder_cfg.num_layers,
+ num_attention_heads=encoder_cfg.num_attention_heads,
+ intermediate_size=encoder_cfg.intermediate_size,
+ activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
+ dropout_rate=encoder_cfg.dropout_rate,
+ attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+ num_rand_blocks=encoder_cfg.num_rand_blocks,
+ block_size=encoder_cfg.block_size,
+ max_position_embeddings=encoder_cfg.max_position_embeddings,
+ type_vocab_size=encoder_cfg.type_vocab_size,
+ initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ embedding_width=encoder_cfg.embedding_width,
+ use_gradient_checkpointing=encoder_cfg.use_gradient_checkpointing)
+ embedding_cfg = dict(
+ vocab_size=encoder_cfg.vocab_size,
+ type_vocab_size=encoder_cfg.type_vocab_size,
+ hidden_size=encoder_cfg.hidden_size,
+ max_seq_length=encoder_cfg.max_position_embeddings,
+ initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ dropout_rate=encoder_cfg.dropout_rate)
+ attention_cfg = dict(
+ num_heads=encoder_cfg.num_attention_heads,
+ key_dim=int(encoder_cfg.hidden_size // encoder_cfg.num_attention_heads),
+ kernel_initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ max_rand_mask_length=encoder_cfg.max_position_embeddings,
+ num_rand_blocks=encoder_cfg.num_rand_blocks,
+ from_block_size=encoder_cfg.block_size,
+ to_block_size=encoder_cfg.block_size,
+ )
+ hidden_cfg = dict(
+ num_attention_heads=encoder_cfg.num_attention_heads,
+ intermediate_size=encoder_cfg.intermediate_size,
+ intermediate_activation=tf_utils.get_activation(
+ encoder_cfg.hidden_activation),
+ dropout_rate=encoder_cfg.dropout_rate,
+ attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+ norm_first=encoder_cfg.norm_first,
+ kernel_initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ attention_cls=layers.BigBirdAttention,
+ attention_cfg=attention_cfg)
+ kwargs = dict(
+ embedding_cfg=embedding_cfg,
+ hidden_cls=layers.TransformerScaffold,
+ hidden_cfg=hidden_cfg,
+ num_hidden_instances=encoder_cfg.num_layers,
+ mask_cls=layers.BigBirdMasks,
+ mask_cfg=dict(block_size=encoder_cfg.block_size),
+ pooled_output_dim=encoder_cfg.hidden_size,
+ pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ return_all_layer_outputs=False,
+ dict_outputs=True,
+ layer_idx_as_attention_seed=True)
+ return networks.EncoderScaffold(**kwargs)
+
+ if encoder_type == "funnel":
+
+ if encoder_cfg.hidden_activation == "gelu":
+ activation = tf_utils.get_activation(
+ encoder_cfg.hidden_activation,
+ approximate=encoder_cfg.approx_gelu)
+ else:
+ activation = tf_utils.get_activation(encoder_cfg.hidden_activation)
+
+ return networks.FunnelTransformerEncoder(
+ vocab_size=encoder_cfg.vocab_size,
+ hidden_size=encoder_cfg.hidden_size,
+ num_layers=encoder_cfg.num_layers,
+ num_attention_heads=encoder_cfg.num_attention_heads,
+ max_sequence_length=encoder_cfg.max_position_embeddings,
+ type_vocab_size=encoder_cfg.type_vocab_size,
+ inner_dim=encoder_cfg.inner_dim,
+ inner_activation=activation,
+ output_dropout=encoder_cfg.dropout_rate,
+ attention_dropout=encoder_cfg.attention_dropout_rate,
+ pool_type=encoder_cfg.pool_type,
+ pool_stride=encoder_cfg.pool_stride,
+ unpool_length=encoder_cfg.unpool_length,
+ initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ output_range=encoder_cfg.output_range,
+ embedding_width=encoder_cfg.embedding_width,
+ embedding_layer=embedding_layer,
+ norm_first=encoder_cfg.norm_first,
+ share_rezero=encoder_cfg.share_rezero,
+ append_dense_inputs=encoder_cfg.append_dense_inputs,
+ transformer_cls=encoder_cfg.transformer_cls,
+ )
+
+ if encoder_type == "kernel":
+ embedding_cfg = dict(
+ vocab_size=encoder_cfg.vocab_size,
+ type_vocab_size=encoder_cfg.type_vocab_size,
+ hidden_size=encoder_cfg.hidden_size,
+ max_seq_length=encoder_cfg.max_position_embeddings,
+ initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ dropout_rate=encoder_cfg.dropout_rate)
+ attention_cfg = dict(
+ num_heads=encoder_cfg.num_attention_heads,
+ key_dim=int(encoder_cfg.hidden_size // encoder_cfg.num_attention_heads),
+ kernel_initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ feature_transform=encoder_cfg.feature_transform,
+ num_random_features=encoder_cfg.num_random_features,
+ redraw=encoder_cfg.redraw,
+ is_short_seq=encoder_cfg.is_short_seq,
+ begin_kernel=encoder_cfg.begin_kernel,
+ scale=encoder_cfg.scale,
+ )
+ hidden_cfg = dict(
+ num_attention_heads=encoder_cfg.num_attention_heads,
+ intermediate_size=encoder_cfg.intermediate_size,
+ intermediate_activation=tf_utils.get_activation(
+ encoder_cfg.hidden_activation),
+ dropout_rate=encoder_cfg.dropout_rate,
+ attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+ norm_first=encoder_cfg.norm_first,
+ kernel_initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ attention_cls=layers.KernelAttention,
+ attention_cfg=attention_cfg)
+ kwargs = dict(
+ embedding_cfg=embedding_cfg,
+ hidden_cls=layers.TransformerScaffold,
+ hidden_cfg=hidden_cfg,
+ num_hidden_instances=encoder_cfg.num_layers,
+ mask_cls=layers.KernelMask,
+ pooled_output_dim=encoder_cfg.hidden_size,
+ pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ return_all_layer_outputs=False,
+ dict_outputs=True,
+ layer_idx_as_attention_seed=True)
+ return networks.EncoderScaffold(**kwargs)
+
+ if encoder_type == "xlnet":
+ return networks.XLNetBase(
+ vocab_size=encoder_cfg.vocab_size,
+ num_layers=encoder_cfg.num_layers,
+ hidden_size=encoder_cfg.hidden_size,
+ num_attention_heads=encoder_cfg.num_attention_heads,
+ head_size=encoder_cfg.head_size,
+ inner_size=encoder_cfg.inner_size,
+ dropout_rate=encoder_cfg.dropout_rate,
+ attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+ attention_type=encoder_cfg.attention_type,
+ bi_data=encoder_cfg.bi_data,
+ two_stream=encoder_cfg.two_stream,
+ tie_attention_biases=encoder_cfg.tie_attention_biases,
+ memory_length=encoder_cfg.memory_length,
+ clamp_length=encoder_cfg.clamp_length,
+ reuse_length=encoder_cfg.reuse_length,
+ inner_activation=encoder_cfg.inner_activation,
+ use_cls_mask=encoder_cfg.use_cls_mask,
+ embedding_width=encoder_cfg.embedding_width,
+ initializer=tf_keras.initializers.RandomNormal(
+ stddev=encoder_cfg.initializer_range))
+
+ if encoder_type == "reuse":
+ embedding_cfg = dict(
+ vocab_size=encoder_cfg.vocab_size,
+ type_vocab_size=encoder_cfg.type_vocab_size,
+ hidden_size=encoder_cfg.hidden_size,
+ max_seq_length=encoder_cfg.max_position_embeddings,
+ initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ dropout_rate=encoder_cfg.dropout_rate)
+ hidden_cfg = dict(
+ num_attention_heads=encoder_cfg.num_attention_heads,
+ inner_dim=encoder_cfg.intermediate_size,
+ inner_activation=tf_utils.get_activation(
+ encoder_cfg.hidden_activation),
+ output_dropout=encoder_cfg.dropout_rate,
+ attention_dropout=encoder_cfg.attention_dropout_rate,
+ norm_first=encoder_cfg.norm_first,
+ kernel_initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ reuse_attention=encoder_cfg.reuse_attention,
+ use_relative_pe=encoder_cfg.use_relative_pe,
+ pe_max_seq_length=encoder_cfg.pe_max_seq_length,
+ max_reuse_layer_idx=encoder_cfg.max_reuse_layer_idx)
+ kwargs = dict(
+ embedding_cfg=embedding_cfg,
+ hidden_cls=layers.ReuseTransformer,
+ hidden_cfg=hidden_cfg,
+ num_hidden_instances=encoder_cfg.num_layers,
+ pooled_output_dim=encoder_cfg.hidden_size,
+ pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ return_all_layer_outputs=False,
+ dict_outputs=True,
+ feed_layer_idx=True,
+ recursive=True)
+ return networks.EncoderScaffold(**kwargs)
+
+ if encoder_type == "query_bert":
+ embedding_layer = layers.FactorizedEmbedding(
+ vocab_size=encoder_cfg.vocab_size,
+ embedding_width=encoder_cfg.embedding_size,
+ output_dim=encoder_cfg.hidden_size,
+ initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ name="word_embeddings")
+ return networks.BertEncoderV2(
+ vocab_size=encoder_cfg.vocab_size,
+ hidden_size=encoder_cfg.hidden_size,
+ num_layers=encoder_cfg.num_layers,
+ num_attention_heads=encoder_cfg.num_attention_heads,
+ intermediate_size=encoder_cfg.intermediate_size,
+ activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
+ dropout_rate=encoder_cfg.dropout_rate,
+ attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+ max_sequence_length=encoder_cfg.max_position_embeddings,
+ type_vocab_size=encoder_cfg.type_vocab_size,
+ initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ output_range=encoder_cfg.output_range,
+ embedding_layer=embedding_layer,
+ return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
+ return_attention_scores=encoder_cfg.return_attention_scores,
+ dict_outputs=True,
+ norm_first=encoder_cfg.norm_first)
+
+ if encoder_type == "fnet":
+ return networks.FNet(
+ vocab_size=encoder_cfg.vocab_size,
+ hidden_size=encoder_cfg.hidden_size,
+ num_layers=encoder_cfg.num_layers,
+ num_attention_heads=encoder_cfg.num_attention_heads,
+ inner_dim=encoder_cfg.inner_dim,
+ inner_activation=tf_utils.get_activation(encoder_cfg.inner_activation),
+ output_dropout=encoder_cfg.output_dropout,
+ attention_dropout=encoder_cfg.attention_dropout,
+ max_sequence_length=encoder_cfg.max_sequence_length,
+ type_vocab_size=encoder_cfg.type_vocab_size,
+ initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ output_range=encoder_cfg.output_range,
+ embedding_width=encoder_cfg.embedding_width,
+ embedding_layer=embedding_layer,
+ norm_first=encoder_cfg.norm_first,
+ use_fft=encoder_cfg.use_fft,
+ attention_layers=encoder_cfg.attention_layers)
+
+ if encoder_type == "sparse_mixer":
+ return networks.SparseMixer(
+ vocab_size=encoder_cfg.vocab_size,
+ hidden_size=encoder_cfg.hidden_size,
+ num_layers=encoder_cfg.num_layers,
+ moe_layers=encoder_cfg.moe_layers,
+ attention_layers=encoder_cfg.attention_layers,
+ num_experts=encoder_cfg.num_experts,
+ train_capacity_factor=encoder_cfg.train_capacity_factor,
+ eval_capacity_factor=encoder_cfg.eval_capacity_factor,
+ examples_per_group=encoder_cfg.examples_per_group,
+ use_fft=encoder_cfg.use_fft,
+ num_attention_heads=encoder_cfg.num_attention_heads,
+ max_sequence_length=encoder_cfg.max_sequence_length,
+ type_vocab_size=encoder_cfg.type_vocab_size,
+ inner_dim=encoder_cfg.inner_dim,
+ inner_activation=tf_utils.get_activation(encoder_cfg.inner_activation),
+ output_dropout=encoder_cfg.output_dropout,
+ attention_dropout=encoder_cfg.attention_dropout,
+ initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ output_range=encoder_cfg.output_range,
+ embedding_width=encoder_cfg.embedding_width,
+ norm_first=encoder_cfg.norm_first,
+ embedding_layer=embedding_layer)
+
+ bert_encoder_cls = networks.BertEncoder
+ if encoder_type == "bert_v2":
+ bert_encoder_cls = networks.BertEncoderV2
+
+ # Uses the default BERTEncoder configuration schema to create the encoder.
+ # If it does not match, please add a switch branch by the encoder type.
+ return bert_encoder_cls(
+ vocab_size=encoder_cfg.vocab_size,
+ hidden_size=encoder_cfg.hidden_size,
+ num_layers=encoder_cfg.num_layers,
+ num_attention_heads=encoder_cfg.num_attention_heads,
+ intermediate_size=encoder_cfg.intermediate_size,
+ activation=tf_utils.get_activation(encoder_cfg.hidden_activation),
+ dropout_rate=encoder_cfg.dropout_rate,
+ attention_dropout_rate=encoder_cfg.attention_dropout_rate,
+ max_sequence_length=encoder_cfg.max_position_embeddings,
+ type_vocab_size=encoder_cfg.type_vocab_size,
+ initializer=tf_keras.initializers.TruncatedNormal(
+ stddev=encoder_cfg.initializer_range),
+ output_range=encoder_cfg.output_range,
+ embedding_width=encoder_cfg.embedding_size,
+ embedding_layer=embedding_layer,
+ return_all_encoder_outputs=encoder_cfg.return_all_encoder_outputs,
+ return_attention_scores=encoder_cfg.return_attention_scores,
+ dict_outputs=True,
+ norm_first=encoder_cfg.norm_first)
diff --git a/modeling/official/nlp/configs/encoders_test.py b/modeling/official/nlp/configs/encoders_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6cc1c36c38f11798ce1c0419123e5c940d8fc90
--- /dev/null
+++ b/modeling/official/nlp/configs/encoders_test.py
@@ -0,0 +1,52 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for official.nlp.configs.encoders."""
+import os
+
+import tensorflow as tf, tf_keras
+
+from official.modeling import hyperparams
+from official.nlp.configs import encoders
+from official.nlp.modeling import networks
+from official.projects.teams import teams
+
+
+class EncodersTest(tf.test.TestCase):
+
+ def test_encoder_from_yaml(self):
+ config = encoders.EncoderConfig(
+ type="bert", bert=encoders.BertEncoderConfig(num_layers=1))
+ encoder = encoders.build_encoder(config)
+ ckpt = tf.train.Checkpoint(encoder=encoder)
+ ckpt_path = ckpt.save(self.get_temp_dir() + "/ckpt")
+ params_save_path = os.path.join(self.get_temp_dir(), "params.yaml")
+ hyperparams.save_params_dict_to_yaml(config, params_save_path)
+
+ retored_cfg = encoders.EncoderConfig.from_yaml(params_save_path)
+ retored_encoder = encoders.build_encoder(retored_cfg)
+ status = tf.train.Checkpoint(encoder=retored_encoder).restore(ckpt_path)
+ status.assert_consumed()
+
+ def test_build_teams(self):
+ config = encoders.EncoderConfig(
+ type="any", any=teams.TeamsEncoderConfig(num_layers=1))
+ encoder = encoders.build_encoder(config)
+ self.assertIsInstance(encoder, networks.EncoderScaffold)
+ self.assertIsInstance(encoder.embedding_network,
+ networks.PackedSequenceEmbedding)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/nlp/configs/experiment_configs.py b/modeling/official/nlp/configs/experiment_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..849b1acbeada46146c297d0c2b6aae9436aa7bcf
--- /dev/null
+++ b/modeling/official/nlp/configs/experiment_configs.py
@@ -0,0 +1,19 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Experiments definition."""
+# pylint: disable=unused-import
+from official.nlp.configs import finetuning_experiments
+from official.nlp.configs import pretraining_experiments
+from official.nlp.configs import wmt_transformer_experiments
diff --git a/modeling/official/nlp/configs/experiments/glue_mnli_matched.yaml b/modeling/official/nlp/configs/experiments/glue_mnli_matched.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..29dfcb68b9c314d309239c321dde4ec4f439da1d
--- /dev/null
+++ b/modeling/official/nlp/configs/experiments/glue_mnli_matched.yaml
@@ -0,0 +1,49 @@
+task:
+ hub_module_url: ''
+ model:
+ num_classes: 3
+ init_checkpoint: ''
+ metric_type: 'accuracy'
+ train_data:
+ drop_remainder: true
+ global_batch_size: 32
+ input_path: ''
+ is_training: true
+ seq_length: 128
+ label_type: 'int'
+ validation_data:
+ drop_remainder: false
+ global_batch_size: 32
+ input_path: ''
+ is_training: false
+ seq_length: 128
+ label_type: 'int'
+trainer:
+ checkpoint_interval: 3000
+ optimizer_config:
+ learning_rate:
+ polynomial:
+ # 100% of train_steps.
+ decay_steps: 36813
+ end_learning_rate: 0.0
+ initial_learning_rate: 3.0e-05
+ power: 1.0
+ type: polynomial
+ optimizer:
+ type: adamw
+ warmup:
+ polynomial:
+ power: 1
+ # ~10% of train_steps.
+ warmup_steps: 3681
+ type: polynomial
+ steps_per_loop: 1000
+ summary_interval: 1000
+ # Training data size 392,702 examples, 3 epochs.
+ train_steps: 36813
+ validation_interval: 6135
+ # Eval data size = 9815 examples.
+ validation_steps: 307
+ best_checkpoint_export_subdir: 'best_ckpt'
+ best_checkpoint_eval_metric: 'cls_accuracy'
+ best_checkpoint_metric_comp: 'higher'
diff --git a/modeling/official/nlp/configs/experiments/glue_mnli_text.yaml b/modeling/official/nlp/configs/experiments/glue_mnli_text.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..53676092da3df7decb7e972d84ac036b0827804e
--- /dev/null
+++ b/modeling/official/nlp/configs/experiments/glue_mnli_text.yaml
@@ -0,0 +1,50 @@
+task:
+ hub_module_url: ''
+ model:
+ num_classes: 3
+ init_checkpoint: ''
+ train_data:
+ drop_remainder: true
+ global_batch_size: 32
+ is_training: true
+ seq_length: 128
+ shuffle_buffer_size: 100
+ tfds_name: 'glue/mnli'
+ tfds_split: 'train'
+ text_fields: ['premise', 'hypothesis']
+ vocab_file: ''
+ lower_case: true
+ validation_data:
+ drop_remainder: false
+ global_batch_size: 32
+ is_training: false
+ seq_length: 128
+ tfds_name: 'glue/mnli'
+ tfds_split: 'validation_matched'
+ text_fields: ['premise', 'hypothesis']
+ vocab_file: ''
+ lower_case: true
+trainer:
+ checkpoint_interval: 3000
+ max_to_keep: 5
+ optimizer_config:
+ learning_rate:
+ polynomial:
+ cycle: false
+ decay_steps: 36813
+ end_learning_rate: 0.0
+ initial_learning_rate: 3.0e-05
+ power: 1.0
+ type: polynomial
+ optimizer:
+ type: adamw
+ warmup:
+ polynomial:
+ power: 1
+ warmup_steps: 3681
+ type: polynomial
+ steps_per_loop: 1000
+ summary_interval: 1000
+ train_steps: 36813
+ validation_interval: 6135
+ validation_steps: 307
diff --git a/modeling/official/nlp/configs/experiments/squad_v1.yaml b/modeling/official/nlp/configs/experiments/squad_v1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a69710a58f7dfa4e044bceb73c5870701ca39189
--- /dev/null
+++ b/modeling/official/nlp/configs/experiments/squad_v1.yaml
@@ -0,0 +1,50 @@
+task:
+ hub_module_url: ''
+ max_answer_length: 30
+ n_best_size: 20
+ null_score_diff_threshold: 0.0
+ init_checkpoint: ''
+ train_data:
+ drop_remainder: true
+ global_batch_size: 48
+ input_path: ''
+ is_training: true
+ seq_length: 384
+ validation_data:
+ do_lower_case: true
+ doc_stride: 128
+ drop_remainder: false
+ global_batch_size: 48
+ input_path: ''
+ is_training: false
+ query_length: 64
+ seq_length: 384
+ tokenization: WordPiece
+ version_2_with_negative: false
+ vocab_file: ''
+trainer:
+ checkpoint_interval: 1000
+ max_to_keep: 5
+ optimizer_config:
+ learning_rate:
+ polynomial:
+ decay_steps: 3699
+ end_learning_rate: 0.0
+ initial_learning_rate: 8.0e-05
+ power: 1.0
+ type: polynomial
+ optimizer:
+ type: adamw
+ warmup:
+ polynomial:
+ power: 1
+ warmup_steps: 370
+ type: polynomial
+ steps_per_loop: 1000
+ summary_interval: 1000
+ train_steps: 3699
+ validation_interval: 1000
+ validation_steps: 226
+ best_checkpoint_export_subdir: 'best_ckpt'
+ best_checkpoint_eval_metric: 'final_f1'
+ best_checkpoint_metric_comp: 'higher'
diff --git a/modeling/official/nlp/configs/experiments/wiki_books_pretrain.yaml b/modeling/official/nlp/configs/experiments/wiki_books_pretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..be126e0d0bfd680ccab88ccde0d9721def45f548
--- /dev/null
+++ b/modeling/official/nlp/configs/experiments/wiki_books_pretrain.yaml
@@ -0,0 +1,48 @@
+task:
+ init_checkpoint: ''
+ model:
+ cls_heads: [{activation: tanh, cls_token_idx: 0, dropout_rate: 0.1, inner_dim: 768, name: next_sentence, num_classes: 2}]
+ train_data:
+ drop_remainder: true
+ global_batch_size: 512
+ input_path: '[Your processed wiki data path]*,[Your processed books data path]*'
+ is_training: true
+ max_predictions_per_seq: 76
+ seq_length: 512
+ use_next_sentence_label: true
+ use_position_id: false
+ use_v2_feature_names: true
+ validation_data:
+ drop_remainder: false
+ global_batch_size: 512
+ input_path: '[Your processed wiki data path]-00000-of-00500,[Your processed books data path]-00000-of-00500'
+ is_training: false
+ max_predictions_per_seq: 76
+ seq_length: 512
+ use_next_sentence_label: true
+ use_position_id: false
+ use_v2_feature_names: true
+trainer:
+ checkpoint_interval: 20000
+ max_to_keep: 5
+ optimizer_config:
+ learning_rate:
+ polynomial:
+ cycle: false
+ decay_steps: 1000000
+ end_learning_rate: 0.0
+ initial_learning_rate: 0.0001
+ power: 1.0
+ type: polynomial
+ optimizer:
+ type: adamw
+ warmup:
+ polynomial:
+ power: 1
+ warmup_steps: 10000
+ type: polynomial
+ steps_per_loop: 1000
+ summary_interval: 1000
+ train_steps: 1000000
+ validation_interval: 1000
+ validation_steps: 64
diff --git a/modeling/official/nlp/configs/experiments/wiki_tfds_pretrain.yaml b/modeling/official/nlp/configs/experiments/wiki_tfds_pretrain.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..388bce3210aadddb7c7f81176fc34d2ab860ec1e
--- /dev/null
+++ b/modeling/official/nlp/configs/experiments/wiki_tfds_pretrain.yaml
@@ -0,0 +1,50 @@
+task:
+ init_checkpoint: ''
+ model:
+ cls_heads: [{activation: tanh, cls_token_idx: 0, dropout_rate: 0.1, inner_dim: 768, name: next_sentence, num_classes: 2}]
+ train_data:
+ drop_remainder: true
+ global_batch_size: 512
+ is_training: true
+ max_predictions_per_seq: 76
+ seq_length: 512
+ use_next_sentence_label: false
+ use_whole_word_masking: true
+ tfds_name: wikipedia/20201201.en
+ tfds_split: train
+ vocab_file_path: 'Please provide the vocab file path.'
+ validation_data:
+ drop_remainder: true
+ global_batch_size: 32
+ is_training: false
+ max_predictions_per_seq: 76
+ seq_length: 512
+ use_next_sentence_label: false
+ use_whole_word_masking: true
+ tfds_name: wikipedia/20201201.en
+ tfds_split: train
+ vocab_file_path: 'Please provide the vocab file path.'
+trainer:
+ checkpoint_interval: 20000
+ max_to_keep: 5
+ optimizer_config:
+ learning_rate:
+ polynomial:
+ cycle: false
+ decay_steps: 1000000
+ end_learning_rate: 0.0
+ initial_learning_rate: 0.0001
+ power: 1.0
+ type: polynomial
+ optimizer:
+ type: adamw
+ warmup:
+ polynomial:
+ power: 1
+ warmup_steps: 10000
+ type: polynomial
+ steps_per_loop: 1000
+ summary_interval: 1000
+ train_steps: 1000000
+ validation_interval: 1000
+ validation_steps: 64
diff --git a/modeling/official/nlp/configs/finetuning_experiments.py b/modeling/official/nlp/configs/finetuning_experiments.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0e3e0c33b2581321ef41b9755bc18ca34dec0c5
--- /dev/null
+++ b/modeling/official/nlp/configs/finetuning_experiments.py
@@ -0,0 +1,179 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Finetuning experiment configurations."""
+# pylint: disable=g-doc-return-or-yield,line-too-long
+from official.core import config_definitions as cfg
+from official.core import exp_factory
+from official.modeling import optimization
+from official.nlp.data import question_answering_dataloader
+from official.nlp.data import sentence_prediction_dataloader
+from official.nlp.data import tagging_dataloader
+from official.nlp.tasks import question_answering
+from official.nlp.tasks import sentence_prediction
+from official.nlp.tasks import tagging
+
+
+@exp_factory.register_config_factory('bert/sentence_prediction')
+def bert_sentence_prediction() -> cfg.ExperimentConfig:
+ r"""BERT GLUE."""
+ config = cfg.ExperimentConfig(
+ task=sentence_prediction.SentencePredictionConfig(
+ train_data=sentence_prediction_dataloader
+ .SentencePredictionDataConfig(),
+ validation_data=sentence_prediction_dataloader
+ .SentencePredictionDataConfig(
+ is_training=False, drop_remainder=False)),
+ trainer=cfg.TrainerConfig(
+ optimizer_config=optimization.OptimizationConfig({
+ 'optimizer': {
+ 'type': 'adamw',
+ 'adamw': {
+ 'weight_decay_rate':
+ 0.01,
+ 'exclude_from_weight_decay':
+ ['LayerNorm', 'layer_norm', 'bias'],
+ }
+ },
+ 'learning_rate': {
+ 'type': 'polynomial',
+ 'polynomial': {
+ 'initial_learning_rate': 3e-5,
+ 'end_learning_rate': 0.0,
+ }
+ },
+ 'warmup': {
+ 'type': 'polynomial'
+ }
+ })),
+ restrictions=[
+ 'task.train_data.is_training != None',
+ 'task.validation_data.is_training != None'
+ ])
+ return config
+
+
+@exp_factory.register_config_factory('bert/sentence_prediction_text')
+def bert_sentence_prediction_text() -> cfg.ExperimentConfig:
+ r"""BERT sentence prediction with raw text data.
+
+ Example: use tf.text and tfds as input with glue_mnli_text.yaml
+ """
+ config = cfg.ExperimentConfig(
+ task=sentence_prediction.SentencePredictionConfig(
+ train_data=sentence_prediction_dataloader
+ .SentencePredictionTextDataConfig(),
+ validation_data=sentence_prediction_dataloader
+ .SentencePredictionTextDataConfig(
+ is_training=False, drop_remainder=False)),
+ trainer=cfg.TrainerConfig(
+ optimizer_config=optimization.OptimizationConfig({
+ 'optimizer': {
+ 'type': 'adamw',
+ 'adamw': {
+ 'weight_decay_rate':
+ 0.01,
+ 'exclude_from_weight_decay':
+ ['LayerNorm', 'layer_norm', 'bias'],
+ }
+ },
+ 'learning_rate': {
+ 'type': 'polynomial',
+ 'polynomial': {
+ 'initial_learning_rate': 3e-5,
+ 'end_learning_rate': 0.0,
+ }
+ },
+ 'warmup': {
+ 'type': 'polynomial'
+ }
+ })),
+ restrictions=[
+ 'task.train_data.is_training != None',
+ 'task.validation_data.is_training != None'
+ ])
+ return config
+
+
+@exp_factory.register_config_factory('bert/squad')
+def bert_squad() -> cfg.ExperimentConfig:
+ """BERT Squad V1/V2."""
+ config = cfg.ExperimentConfig(
+ task=question_answering.QuestionAnsweringConfig(
+ train_data=question_answering_dataloader.QADataConfig(),
+ validation_data=question_answering_dataloader.QADataConfig()),
+ trainer=cfg.TrainerConfig(
+ optimizer_config=optimization.OptimizationConfig({
+ 'optimizer': {
+ 'type': 'adamw',
+ 'adamw': {
+ 'weight_decay_rate':
+ 0.01,
+ 'exclude_from_weight_decay':
+ ['LayerNorm', 'layer_norm', 'bias'],
+ }
+ },
+ 'learning_rate': {
+ 'type': 'polynomial',
+ 'polynomial': {
+ 'initial_learning_rate': 8e-5,
+ 'end_learning_rate': 0.0,
+ }
+ },
+ 'warmup': {
+ 'type': 'polynomial'
+ }
+ })),
+ restrictions=[
+ 'task.train_data.is_training != None',
+ 'task.validation_data.is_training != None'
+ ])
+ return config
+
+
+@exp_factory.register_config_factory('bert/tagging')
+def bert_tagging() -> cfg.ExperimentConfig:
+ """BERT tagging task."""
+ config = cfg.ExperimentConfig(
+ task=tagging.TaggingConfig(
+ train_data=tagging_dataloader.TaggingDataConfig(),
+ validation_data=tagging_dataloader.TaggingDataConfig(
+ is_training=False, drop_remainder=False)),
+ trainer=cfg.TrainerConfig(
+ optimizer_config=optimization.OptimizationConfig({
+ 'optimizer': {
+ 'type': 'adamw',
+ 'adamw': {
+ 'weight_decay_rate':
+ 0.01,
+ 'exclude_from_weight_decay':
+ ['LayerNorm', 'layer_norm', 'bias'],
+ }
+ },
+ 'learning_rate': {
+ 'type': 'polynomial',
+ 'polynomial': {
+ 'initial_learning_rate': 8e-5,
+ 'end_learning_rate': 0.0,
+ }
+ },
+ 'warmup': {
+ 'type': 'polynomial'
+ }
+ })),
+ restrictions=[
+ 'task.train_data.is_training != None',
+ 'task.validation_data.is_training != None',
+ ])
+ return config
diff --git a/modeling/official/nlp/configs/models/albert_base.yaml b/modeling/official/nlp/configs/models/albert_base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7d1317b2f4113de671967f0dc94b848653c544c9
--- /dev/null
+++ b/modeling/official/nlp/configs/models/albert_base.yaml
@@ -0,0 +1,16 @@
+task:
+ model:
+ encoder:
+ type: albert
+ albert:
+ attention_dropout_rate: 0.0
+ dropout_rate: 0.0
+ hidden_activation: gelu
+ hidden_size: 768
+ initializer_range: 0.02
+ intermediate_size: 3072
+ max_position_embeddings: 512
+ num_attention_heads: 12
+ num_layers: 12
+ type_vocab_size: 2
+ vocab_size: 30000
diff --git a/modeling/official/nlp/configs/models/bert_en_uncased_base.yaml b/modeling/official/nlp/configs/models/bert_en_uncased_base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1e49bc5430ed0135aa6d981421aad623f4f1fac9
--- /dev/null
+++ b/modeling/official/nlp/configs/models/bert_en_uncased_base.yaml
@@ -0,0 +1,16 @@
+task:
+ model:
+ encoder:
+ type: bert
+ bert:
+ attention_dropout_rate: 0.1
+ dropout_rate: 0.1
+ hidden_activation: gelu
+ hidden_size: 768
+ initializer_range: 0.02
+ intermediate_size: 3072
+ max_position_embeddings: 512
+ num_attention_heads: 12
+ num_layers: 12
+ type_vocab_size: 2
+ vocab_size: 30522
diff --git a/modeling/official/nlp/configs/pretraining_experiments.py b/modeling/official/nlp/configs/pretraining_experiments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4108e35cd45b0e79711c5128f4965efcef288853
--- /dev/null
+++ b/modeling/official/nlp/configs/pretraining_experiments.py
@@ -0,0 +1,135 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Pretraining experiment configurations."""
+# pylint: disable=g-doc-return-or-yield,line-too-long
+from official.core import config_definitions as cfg
+from official.core import exp_factory
+from official.modeling import optimization
+from official.nlp.data import pretrain_dataloader
+from official.nlp.data import pretrain_dynamic_dataloader
+from official.nlp.data import pretrain_text_dataloader
+from official.nlp.tasks import electra_task
+from official.nlp.tasks import masked_lm
+
+
+_TRAINER = cfg.TrainerConfig(
+ train_steps=1000000,
+ optimizer_config=optimization.OptimizationConfig({
+ 'optimizer': {
+ 'type': 'adamw',
+ 'adamw': {
+ 'weight_decay_rate':
+ 0.01,
+ 'exclude_from_weight_decay': [
+ 'LayerNorm', 'layer_norm', 'bias'
+ ],
+ }
+ },
+ 'learning_rate': {
+ 'type': 'polynomial',
+ 'polynomial': {
+ 'initial_learning_rate': 1e-4,
+ 'end_learning_rate': 0.0,
+ }
+ },
+ 'warmup': {
+ 'type': 'polynomial'
+ }
+ }))
+
+
+@exp_factory.register_config_factory('bert/pretraining')
+def bert_pretraining() -> cfg.ExperimentConfig:
+ """BERT pretraining experiment."""
+ config = cfg.ExperimentConfig(
+ runtime=cfg.RuntimeConfig(enable_xla=True),
+ task=masked_lm.MaskedLMConfig(
+ train_data=pretrain_dataloader.BertPretrainDataConfig(),
+ validation_data=pretrain_dataloader.BertPretrainDataConfig(
+ is_training=False)),
+ trainer=_TRAINER,
+ restrictions=[
+ 'task.train_data.is_training != None',
+ 'task.validation_data.is_training != None'
+ ])
+ return config
+
+
+@exp_factory.register_config_factory('bert/pretraining_dynamic')
+def bert_dynamic() -> cfg.ExperimentConfig:
+ """BERT base with dynamic input sequences.
+
+ TPU needs to run with tf.data service with round-robin behavior.
+ """
+ config = cfg.ExperimentConfig(
+ runtime=cfg.RuntimeConfig(enable_xla=True),
+ task=masked_lm.MaskedLMConfig(
+ train_data=pretrain_dynamic_dataloader.BertPretrainDataConfig(),
+ validation_data=pretrain_dataloader.BertPretrainDataConfig(
+ is_training=False)),
+ trainer=_TRAINER,
+ restrictions=[
+ 'task.train_data.is_training != None',
+ 'task.validation_data.is_training != None'
+ ])
+ return config
+
+
+@exp_factory.register_config_factory('bert/text_wiki_pretraining')
+def bert_text_wiki_pretraining() -> cfg.ExperimentConfig:
+ r"""BERT with wiki text tfds.
+
+ Note that: only wikipedia english corpus is used. It cannot exactly reproduce
+ BERT training setup because the next sentence sampling is hard to match the
+ implementation with tf ops.
+ """
+ config = cfg.ExperimentConfig(
+ task=masked_lm.MaskedLMConfig(
+ train_data=pretrain_text_dataloader.BertPretrainTextDataConfig(
+ tfds_name='wikipedia/20201201.en',
+ tfds_split='train',
+ vocab_file_path='TODO for users',
+ ),
+ validation_data=pretrain_text_dataloader.BertPretrainTextDataConfig(
+ tfds_name='wikipedia/20201201.en',
+ tfds_split='train',
+ vocab_file_path='TODO for users',
+ is_training=False)),
+ trainer=_TRAINER,
+ restrictions=[
+ 'task.train_data.is_training != None',
+ 'task.validation_data.is_training != None'
+ ])
+ return config
+
+
+@exp_factory.register_config_factory('electra/pretraining')
+def electra_pretrain() -> cfg.ExperimentConfig:
+ """ELECTRA pretraining experiment."""
+ config = cfg.ExperimentConfig(
+ runtime=cfg.RuntimeConfig(enable_xla=True),
+ task=electra_task.ElectraPretrainConfig(
+ train_data=pretrain_dataloader.BertPretrainDataConfig(),
+ validation_data=pretrain_dataloader.BertPretrainDataConfig(
+ is_training=False
+ ),
+ ),
+ trainer=_TRAINER,
+ restrictions=[
+ 'task.train_data.is_training != None',
+ 'task.validation_data.is_training != None',
+ ],
+ )
+ return config
diff --git a/modeling/official/nlp/configs/wmt_transformer_experiments.py b/modeling/official/nlp/configs/wmt_transformer_experiments.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aaa725167c5ec7787fe3deca170e8b79f034177
--- /dev/null
+++ b/modeling/official/nlp/configs/wmt_transformer_experiments.py
@@ -0,0 +1,110 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=g-doc-return-or-yield,line-too-long
+"""WMT translation configurations."""
+
+from official.core import config_definitions as cfg
+from official.core import exp_factory
+from official.modeling import optimization
+from official.nlp.data import wmt_dataloader
+from official.nlp.tasks import translation
+
+
+@exp_factory.register_config_factory('wmt_transformer/large')
+def wmt_transformer_large() -> cfg.ExperimentConfig:
+ """WMT Transformer Large.
+
+ Please refer to
+ tensorflow_models/official/nlp/data/train_sentencepiece.py
+ to generate sentencepiece_model
+ and pass
+ --params_override=task.sentencepiece_model_path='YOUR_PATH'
+ to the train script.
+ """
+ learning_rate = 2.0
+ hidden_size = 1024
+ learning_rate *= (hidden_size**-0.5)
+ warmup_steps = 16000
+ train_steps = 300000
+ token_batch_size = 24576
+ encdecoder = translation.EncDecoder(
+ num_attention_heads=16, intermediate_size=hidden_size * 4)
+ config = cfg.ExperimentConfig(
+ runtime=cfg.RuntimeConfig(enable_xla=True),
+ task=translation.TranslationConfig(
+ model=translation.ModelConfig(
+ encoder=encdecoder,
+ decoder=encdecoder,
+ embedding_width=hidden_size,
+ padded_decode=True,
+ decode_max_length=100),
+ train_data=wmt_dataloader.WMTDataConfig(
+ tfds_name='wmt14_translate/de-en',
+ tfds_split='train',
+ src_lang='en',
+ tgt_lang='de',
+ is_training=True,
+ global_batch_size=token_batch_size,
+ static_batch=True,
+ max_seq_length=64
+ ),
+ validation_data=wmt_dataloader.WMTDataConfig(
+ tfds_name='wmt14_translate/de-en',
+ tfds_split='test',
+ src_lang='en',
+ tgt_lang='de',
+ is_training=False,
+ global_batch_size=32,
+ static_batch=True,
+ max_seq_length=100,
+ ),
+ sentencepiece_model_path=None,
+ ),
+ trainer=cfg.TrainerConfig(
+ train_steps=train_steps,
+ validation_steps=-1,
+ steps_per_loop=1000,
+ summary_interval=1000,
+ checkpoint_interval=5000,
+ validation_interval=5000,
+ max_to_keep=1,
+ optimizer_config=optimization.OptimizationConfig({
+ 'optimizer': {
+ 'type': 'adam',
+ 'adam': {
+ 'beta_2': 0.997,
+ 'epsilon': 1e-9,
+ },
+ },
+ 'learning_rate': {
+ 'type': 'power',
+ 'power': {
+ 'initial_learning_rate': learning_rate,
+ 'power': -0.5,
+ }
+ },
+ 'warmup': {
+ 'type': 'linear',
+ 'linear': {
+ 'warmup_steps': warmup_steps,
+ 'warmup_learning_rate': 0.0
+ }
+ }
+ })),
+ restrictions=[
+ 'task.train_data.is_training != None',
+ 'task.sentencepiece_model_path != None',
+ ])
+ return config
diff --git a/modeling/official/nlp/continuous_finetune_lib.py b/modeling/official/nlp/continuous_finetune_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ad706e2e8e5e69c188763fbccc05a6ac23e0435
--- /dev/null
+++ b/modeling/official/nlp/continuous_finetune_lib.py
@@ -0,0 +1,217 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""TFM continuous finetuning+eval training driver library."""
+import gc
+import os
+import time
+from typing import Any, Mapping, Optional
+
+from absl import logging
+import tensorflow as tf, tf_keras
+
+from official.common import distribute_utils
+from official.core import config_definitions
+from official.core import task_factory
+from official.core import train_lib
+from official.core import train_utils
+from official.modeling import performance
+from official.modeling.multitask import configs
+from official.modeling.multitask import train_lib as multitask_train_lib
+
+
+def _flatten_dict(xs):
+ """Flatten a nested dictionary.
+
+ The nested keys are flattened to a tuple.
+
+ Example::
+
+ xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
+ flat_xs = flatten_dict(xs)
+ print(flat_xs)
+ # {
+ # ('foo',): 1,
+ # ('bar', 'a'): 2,
+ # }
+
+ Note that empty dictionaries are ignored and
+ will not be restored by `unflatten_dict`.
+
+ Args:
+ xs: a nested dictionary
+
+ Returns:
+ The flattened dictionary.
+ """
+ assert isinstance(xs, dict), 'input is not a dict'
+
+ def _flatten(xs, prefix):
+ if not isinstance(xs, dict):
+ return {prefix: xs}
+ result = {}
+ for key, value in xs.items():
+ path = prefix + (key,)
+ result.update(_flatten(value, path))
+ return result
+
+ return _flatten(xs, ())
+
+
+def run_continuous_finetune(
+ mode: str,
+ params: config_definitions.ExperimentConfig,
+ model_dir: str,
+ run_post_eval: bool = False,
+ pretrain_steps: Optional[int] = None,
+) -> Mapping[str, Any]:
+ """Run modes with continuous training.
+
+ Currently only supports continuous_train_and_eval.
+
+ Args:
+ mode: A 'str', specifying the mode. continuous_train_and_eval - monitors a
+ checkpoint directory. Once a new checkpoint is discovered, loads the
+ checkpoint, finetune the model by training it (probably on another dataset
+ or with another task), then evaluate the finetuned model.
+ params: ExperimentConfig instance.
+ model_dir: A 'str', a path to store model checkpoints and summaries.
+ run_post_eval: Whether to run post eval once after training, metrics logs
+ are returned.
+ pretrain_steps: Optional, the number of total training steps for the
+ pretraining job.
+
+ Returns:
+ eval logs: returns eval metrics logs when run_post_eval is set to True,
+ othewise, returns {}.
+ """
+
+ assert mode == 'continuous_train_and_eval', (
+ 'Only continuous_train_and_eval is supported by continuous_finetune. '
+ 'Got mode: {}'.format(mode))
+
+ # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
+ # can have significant impact on model speeds by utilizing float16 in case of
+ # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
+ # dtype is float16
+ if params.runtime.mixed_precision_dtype:
+ performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
+ distribution_strategy = distribute_utils.get_distribution_strategy(
+ distribution_strategy=params.runtime.distribution_strategy,
+ all_reduce_alg=params.runtime.all_reduce_alg,
+ num_gpus=params.runtime.num_gpus,
+ tpu_address=params.runtime.tpu)
+
+ retry_times = 0
+ while not tf.io.gfile.isdir(params.task.init_checkpoint):
+ # Wait for the init_checkpoint directory to be created.
+ if retry_times >= 60:
+ raise ValueError(
+ 'ExperimentConfig.task.init_checkpoint must be a directory for '
+ 'continuous_train_and_eval mode.')
+ retry_times += 1
+ time.sleep(60)
+
+ summary_writer = tf.summary.create_file_writer(
+ os.path.join(model_dir, 'eval'))
+
+ global_step = 0
+
+ def timeout_fn():
+ if pretrain_steps and global_step < pretrain_steps:
+ # Keeps waiting for another timeout period.
+ logging.info(
+ 'Continue waiting for new checkpoint as current pretrain '
+ 'global_step=%d and target is %d.', global_step, pretrain_steps)
+ return False
+ # Quits the loop.
+ return True
+
+ for pretrain_ckpt in tf.train.checkpoints_iterator(
+ checkpoint_dir=params.task.init_checkpoint,
+ min_interval_secs=10,
+ timeout=params.trainer.continuous_eval_timeout,
+ timeout_fn=timeout_fn):
+
+ # If there are checkpoints, they might be the finetune checkpoint of a
+ # different pretrained checkpoint. So we just remove all checkpoints.
+ train_utils.remove_ckpts(model_dir)
+
+ with distribution_strategy.scope():
+ global_step = train_utils.read_global_step_from_checkpoint(pretrain_ckpt)
+ # Replaces params.task.init_checkpoint to make sure that we load
+ # exactly this pretrain checkpoint.
+ if params.trainer.best_checkpoint_export_subdir:
+ best_ckpt_subdir = '{}_{}'.format(
+ params.trainer.best_checkpoint_export_subdir, global_step)
+ params_replaced = params.replace(
+ task={'init_checkpoint': pretrain_ckpt},
+ trainer={'best_checkpoint_export_subdir': best_ckpt_subdir})
+ else:
+ params_replaced = params.replace(task={'init_checkpoint': pretrain_ckpt})
+ params_replaced.lock()
+ logging.info('Running finetuning with params: %s', params_replaced)
+
+ with distribution_strategy.scope():
+ if isinstance(params, configs.MultiEvalExperimentConfig):
+ task = task_factory.get_task(params_replaced.task)
+ eval_tasks = [
+ task_factory.get_task(config.task_config, name=config.task_name)
+ for config in params.eval_tasks
+ ]
+ (_,
+ eval_metrics) = multitask_train_lib.run_experiment_with_multitask_eval(
+ distribution_strategy=distribution_strategy,
+ train_task=task,
+ eval_tasks=eval_tasks,
+ mode='train_and_eval',
+ params=params_replaced,
+ model_dir=model_dir,
+ run_post_eval=True,
+ save_summary=False)
+ else:
+ task = task_factory.get_task(
+ params_replaced.task, logging_dir=model_dir)
+ _, eval_metrics = train_lib.run_experiment(
+ distribution_strategy=distribution_strategy,
+ task=task,
+ mode='train_and_eval',
+ params=params_replaced,
+ model_dir=model_dir,
+ run_post_eval=True,
+ save_summary=False)
+ logging.info('Evaluation finished. Pretrain global_step: %d', global_step)
+ train_utils.write_json_summary(model_dir, global_step, eval_metrics)
+
+ if not os.path.basename(model_dir): # if model_dir.endswith('/')
+ summary_grp = os.path.dirname(model_dir) + '_' + task.name
+ else:
+ summary_grp = os.path.basename(model_dir) + '_' + task.name
+ summaries = {}
+ for name, value in _flatten_dict(eval_metrics).items():
+ summaries[summary_grp + '/' + '-'.join(name)] = value
+ train_utils.write_summary(summary_writer, global_step, summaries)
+
+ train_utils.remove_ckpts(model_dir)
+ # In TF2, the resource life cycle is bound with the python object life
+ # cycle. Force trigger python garbage collection here so those resources
+ # can be deallocated in time, so it doesn't cause OOM when allocating new
+ # objects.
+ # TODO(b/169178664): Fix cycle reference in Keras model and revisit to see
+ # if we need gc here.
+ gc.collect()
+
+ if run_post_eval:
+ return eval_metrics
+ return {}
diff --git a/modeling/official/nlp/continuous_finetune_lib_test.py b/modeling/official/nlp/continuous_finetune_lib_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc4108c7b3fae1ed803d9f3165773d135077c8c1
--- /dev/null
+++ b/modeling/official/nlp/continuous_finetune_lib_test.py
@@ -0,0 +1,98 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from absl import flags
+from absl.testing import flagsaver
+from absl.testing import parameterized
+import tensorflow as tf, tf_keras
+
+# pylint: disable=unused-import
+from official.common import registry_imports
+# pylint: enable=unused-import
+from official.common import flags as tfm_flags
+from official.core import task_factory
+from official.core import train_lib
+from official.core import train_utils
+from official.nlp import continuous_finetune_lib
+
+FLAGS = flags.FLAGS
+
+tfm_flags.define_flags()
+
+
+class ContinuousFinetuneTest(tf.test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self._model_dir = os.path.join(self.get_temp_dir(), 'model_dir')
+
+ def testContinuousFinetune(self):
+ pretrain_steps = 1
+ src_model_dir = self.get_temp_dir()
+ flags_dict = dict(
+ experiment='mock',
+ mode='continuous_train_and_eval',
+ model_dir=self._model_dir,
+ params_override={
+ 'task': {
+ 'init_checkpoint': src_model_dir,
+ },
+ 'trainer': {
+ 'continuous_eval_timeout': 1,
+ 'steps_per_loop': 1,
+ 'train_steps': 1,
+ 'validation_steps': 1,
+ 'best_checkpoint_export_subdir': 'best_ckpt',
+ 'best_checkpoint_eval_metric': 'acc',
+ 'optimizer_config': {
+ 'optimizer': {
+ 'type': 'sgd'
+ },
+ 'learning_rate': {
+ 'type': 'constant'
+ }
+ }
+ }
+ })
+
+ with flagsaver.flagsaver(**flags_dict):
+ # Train and save some checkpoints.
+ params = train_utils.parse_configuration(flags.FLAGS)
+ distribution_strategy = tf.distribute.get_strategy()
+ with distribution_strategy.scope():
+ task = task_factory.get_task(params.task, logging_dir=src_model_dir)
+ _ = train_lib.run_experiment(
+ distribution_strategy=distribution_strategy,
+ task=task,
+ mode='train',
+ params=params,
+ model_dir=src_model_dir)
+
+ params = train_utils.parse_configuration(FLAGS)
+ eval_metrics = continuous_finetune_lib.run_continuous_finetune(
+ FLAGS.mode,
+ params,
+ FLAGS.model_dir,
+ run_post_eval=True,
+ pretrain_steps=pretrain_steps)
+ self.assertIn('best_acc', eval_metrics)
+
+ self.assertFalse(
+ tf.io.gfile.exists(os.path.join(FLAGS.model_dir, 'checkpoint')))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/nlp/data/README.md b/modeling/official/nlp/data/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2a706d7be7f99974a097ca128e890589c8cdb826
--- /dev/null
+++ b/modeling/official/nlp/data/README.md
@@ -0,0 +1,4 @@
+This directory contains binaries and utils required for input preprocessing,
+tokenization, etc that can be used with model building blocks available in
+NLP modeling library [nlp/modelling](https://github.com/tensorflow/models/tree/master/official/nlp/modeling)
+to train custom models and validate new research ideas.
diff --git a/modeling/official/nlp/data/__init__.py b/modeling/official/nlp/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9852eb329122ba340f1d0cfaa3b4b90b04c78930
--- /dev/null
+++ b/modeling/official/nlp/data/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/modeling/official/nlp/data/classifier_data_lib.py b/modeling/official/nlp/data/classifier_data_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb9e6822570bffe95aa3820b445c9a288758ba96
--- /dev/null
+++ b/modeling/official/nlp/data/classifier_data_lib.py
@@ -0,0 +1,1612 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BERT library to process data for classification task."""
+
+import collections
+import csv
+import importlib
+import json
+import os
+
+from absl import logging
+import tensorflow as tf, tf_keras
+import tensorflow_datasets as tfds
+
+from official.nlp.tools import tokenization
+
+
+class InputExample(object):
+ """A single training/test example for simple seq regression/classification."""
+
+ def __init__(self,
+ guid,
+ text_a,
+ text_b=None,
+ label=None,
+ weight=None,
+ example_id=None):
+ """Constructs a InputExample.
+
+ Args:
+ guid: Unique id for the example.
+ text_a: string. The untokenized text of the first sequence. For single
+ sequence tasks, only this sequence must be specified.
+ text_b: (Optional) string. The untokenized text of the second sequence.
+ Only must be specified for sequence pair tasks.
+ label: (Optional) string for classification, float for regression. The
+ label of the example. This should be specified for train and dev
+ examples, but not for test examples.
+ weight: (Optional) float. The weight of the example to be used during
+ training.
+ example_id: (Optional) int. The int identification number of example in
+ the corpus.
+ """
+ self.guid = guid
+ self.text_a = text_a
+ self.text_b = text_b
+ self.label = label
+ self.weight = weight
+ self.example_id = example_id
+
+
+class InputFeatures(object):
+ """A single set of features of data."""
+
+ def __init__(self,
+ input_ids,
+ input_mask,
+ segment_ids,
+ label_id,
+ is_real_example=True,
+ weight=None,
+ example_id=None):
+ self.input_ids = input_ids
+ self.input_mask = input_mask
+ self.segment_ids = segment_ids
+ self.label_id = label_id
+ self.is_real_example = is_real_example
+ self.weight = weight
+ self.example_id = example_id
+
+
+class DataProcessor(object):
+ """Base class for converters for seq regression/classification datasets."""
+
+ def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+ self.process_text_fn = process_text_fn
+ self.is_regression = False
+ self.label_type = None
+
+ def get_train_examples(self, data_dir):
+ """Gets a collection of `InputExample`s for the train set."""
+ raise NotImplementedError()
+
+ def get_dev_examples(self, data_dir):
+ """Gets a collection of `InputExample`s for the dev set."""
+ raise NotImplementedError()
+
+ def get_test_examples(self, data_dir):
+ """Gets a collection of `InputExample`s for prediction."""
+ raise NotImplementedError()
+
+ def get_labels(self):
+ """Gets the list of labels for this data set."""
+ raise NotImplementedError()
+
+ @staticmethod
+ def get_processor_name():
+ """Gets the string identifier of the processor."""
+ raise NotImplementedError()
+
+ @classmethod
+ def _read_tsv(cls, input_file, quotechar=None):
+ """Reads a tab separated value file."""
+ with tf.io.gfile.GFile(input_file, "r") as f:
+ reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
+ lines = []
+ for line in reader:
+ lines.append(line)
+ return lines
+
+ @classmethod
+ def _read_jsonl(cls, input_file):
+ """Reads a json line file."""
+ with tf.io.gfile.GFile(input_file, "r") as f:
+ lines = []
+ for json_str in f:
+ lines.append(json.loads(json_str))
+ return lines
+
+ def featurize_example(self, *kargs, **kwargs):
+ """Converts a single `InputExample` into a single `InputFeatures`."""
+ return convert_single_example(*kargs, **kwargs)
+
+
+class DefaultGLUEDataProcessor(DataProcessor):
+ """Processor for the SuperGLUE dataset."""
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples_tfds("train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples_tfds("validation")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples_tfds("test")
+
+ def _create_examples_tfds(self, set_type):
+ """Creates examples for the training/dev/test sets."""
+ raise NotImplementedError()
+
+
+class AxProcessor(DataProcessor):
+ """Processor for the AX dataset (GLUE diagnostics dataset)."""
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ train_mnli_dataset = tfds.load(
+ "glue/mnli", split="train", try_gcs=True).as_numpy_iterator()
+ return self._create_examples_tfds(train_mnli_dataset, "train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ val_mnli_dataset = tfds.load(
+ "glue/mnli", split="validation_matched",
+ try_gcs=True).as_numpy_iterator()
+ return self._create_examples_tfds(val_mnli_dataset, "validation")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ test_ax_dataset = tfds.load(
+ "glue/ax", split="test", try_gcs=True).as_numpy_iterator()
+ return self._create_examples_tfds(test_ax_dataset, "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["contradiction", "entailment", "neutral"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "AX"
+
+ def _create_examples_tfds(self, dataset, set_type):
+ """Creates examples for the training/dev/test sets."""
+ dataset = list(dataset)
+ dataset.sort(key=lambda x: x["idx"])
+ examples = []
+ for i, example in enumerate(dataset):
+ guid = "%s-%s" % (set_type, i)
+ label = "contradiction"
+ text_a = self.process_text_fn(example["hypothesis"])
+ text_b = self.process_text_fn(example["premise"])
+ if set_type != "test":
+ label = self.get_labels()[example["label"]]
+ examples.append(
+ InputExample(
+ guid=guid, text_a=text_a, text_b=text_b, label=label,
+ weight=None))
+ return examples
+
+
+class ColaProcessor(DefaultGLUEDataProcessor):
+ """Processor for the CoLA data set (GLUE version)."""
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "COLA"
+
+ def _create_examples_tfds(self, set_type):
+ """Creates examples for the training/dev/test sets."""
+ dataset = tfds.load(
+ "glue/cola", split=set_type, try_gcs=True).as_numpy_iterator()
+ dataset = list(dataset)
+ dataset.sort(key=lambda x: x["idx"])
+ examples = []
+ for i, example in enumerate(dataset):
+ guid = "%s-%s" % (set_type, i)
+ label = "0"
+ text_a = self.process_text_fn(example["sentence"])
+ if set_type != "test":
+ label = str(example["label"])
+ examples.append(
+ InputExample(
+ guid=guid, text_a=text_a, text_b=None, label=label, weight=None))
+ return examples
+
+
+class ImdbProcessor(DataProcessor):
+ """Processor for the IMDb dataset."""
+
+ def get_labels(self):
+ return ["neg", "pos"]
+
+ def get_train_examples(self, data_dir):
+ return self._create_examples(os.path.join(data_dir, "train"))
+
+ def get_dev_examples(self, data_dir):
+ return self._create_examples(os.path.join(data_dir, "test"))
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "IMDB"
+
+ def _create_examples(self, data_dir):
+ """Creates examples."""
+ examples = []
+ for label in ["neg", "pos"]:
+ cur_dir = os.path.join(data_dir, label)
+ for filename in tf.io.gfile.listdir(cur_dir):
+ if not filename.endswith("txt"):
+ continue
+
+ if len(examples) % 1000 == 0:
+ logging.info("Loading dev example %d", len(examples))
+
+ path = os.path.join(cur_dir, filename)
+ with tf.io.gfile.GFile(path, "r") as f:
+ text = f.read().strip().replace(" ", " ")
+ examples.append(
+ InputExample(
+ guid="unused_id", text_a=text, text_b=None, label=label))
+ return examples
+
+
+class MnliProcessor(DataProcessor):
+ """Processor for the MultiNLI data set (GLUE version)."""
+
+ def __init__(self,
+ mnli_type="matched",
+ process_text_fn=tokenization.convert_to_unicode):
+ super(MnliProcessor, self).__init__(process_text_fn)
+ self.dataset = tfds.load("glue/mnli", try_gcs=True)
+ if mnli_type not in ("matched", "mismatched"):
+ raise ValueError("Invalid `mnli_type`: %s" % mnli_type)
+ self.mnli_type = mnli_type
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples_tfds("train")
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ if self.mnli_type == "matched":
+ return self._create_examples_tfds("validation_matched")
+ else:
+ return self._create_examples_tfds("validation_mismatched")
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ if self.mnli_type == "matched":
+ return self._create_examples_tfds("test_matched")
+ else:
+ return self._create_examples_tfds("test_mismatched")
+
+ def get_labels(self):
+ """See base class."""
+ return ["contradiction", "entailment", "neutral"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "MNLI"
+
+ def _create_examples_tfds(self, set_type):
+ """Creates examples for the training/dev/test sets."""
+ dataset = tfds.load(
+ "glue/mnli", split=set_type, try_gcs=True).as_numpy_iterator()
+ dataset = list(dataset)
+ dataset.sort(key=lambda x: x["idx"])
+ examples = []
+ for i, example in enumerate(dataset):
+ guid = "%s-%s" % (set_type, i)
+ label = "contradiction"
+ text_a = self.process_text_fn(example["hypothesis"])
+ text_b = self.process_text_fn(example["premise"])
+ if set_type != "test":
+ label = self.get_labels()[example["label"]]
+ examples.append(
+ InputExample(
+ guid=guid, text_a=text_a, text_b=text_b, label=label,
+ weight=None))
+ return examples
+
+
+class MrpcProcessor(DefaultGLUEDataProcessor):
+ """Processor for the MRPC data set (GLUE version)."""
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "MRPC"
+
+ def _create_examples_tfds(self, set_type):
+ """Creates examples for the training/dev/test sets."""
+ dataset = tfds.load(
+ "glue/mrpc", split=set_type, try_gcs=True).as_numpy_iterator()
+ dataset = list(dataset)
+ dataset.sort(key=lambda x: x["idx"])
+ examples = []
+ for i, example in enumerate(dataset):
+ guid = "%s-%s" % (set_type, i)
+ label = "0"
+ text_a = self.process_text_fn(example["sentence1"])
+ text_b = self.process_text_fn(example["sentence2"])
+ if set_type != "test":
+ label = str(example["label"])
+ examples.append(
+ InputExample(
+ guid=guid, text_a=text_a, text_b=text_b, label=label,
+ weight=None))
+ return examples
+
+
+class PawsxProcessor(DataProcessor):
+ """Processor for the PAWS-X data set."""
+ supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
+
+ def __init__(self,
+ language="en",
+ process_text_fn=tokenization.convert_to_unicode):
+ super(PawsxProcessor, self).__init__(process_text_fn)
+ if language == "all":
+ self.languages = PawsxProcessor.supported_languages
+ elif language not in PawsxProcessor.supported_languages:
+ raise ValueError("language %s is not supported for PAWS-X task." %
+ language)
+ else:
+ self.languages = [language]
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ lines = []
+ for language in self.languages:
+ if language == "en":
+ train_tsv = "train.tsv"
+ else:
+ train_tsv = "translated_train.tsv"
+ # Skips the header.
+ lines.extend(
+ self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:])
+
+ examples = []
+ for i, line in enumerate(lines):
+ guid = "train-%d" % i
+ text_a = self.process_text_fn(line[1])
+ text_b = self.process_text_fn(line[2])
+ label = self.process_text_fn(line[3])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ lines = []
+ for lang in PawsxProcessor.supported_languages:
+ lines.extend(
+ self._read_tsv(os.path.join(data_dir, lang, "dev_2k.tsv"))[1:])
+
+ examples = []
+ for i, line in enumerate(lines):
+ guid = "dev-%d" % i
+ text_a = self.process_text_fn(line[1])
+ text_b = self.process_text_fn(line[2])
+ label = self.process_text_fn(line[3])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ examples_by_lang = {k: [] for k in self.supported_languages}
+ for lang in self.supported_languages:
+ lines = self._read_tsv(os.path.join(data_dir, lang, "test_2k.tsv"))[1:]
+ for i, line in enumerate(lines):
+ guid = "test-%d" % i
+ text_a = self.process_text_fn(line[1])
+ text_b = self.process_text_fn(line[2])
+ label = self.process_text_fn(line[3])
+ examples_by_lang[lang].append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples_by_lang
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "XTREME-PAWS-X"
+
+
+class QnliProcessor(DefaultGLUEDataProcessor):
+ """Processor for the QNLI data set (GLUE version)."""
+
+ def get_labels(self):
+ """See base class."""
+ return ["entailment", "not_entailment"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "QNLI"
+
+ def _create_examples_tfds(self, set_type):
+ """Creates examples for the training/dev/test sets."""
+ dataset = tfds.load(
+ "glue/qnli", split=set_type, try_gcs=True).as_numpy_iterator()
+ dataset = list(dataset)
+ dataset.sort(key=lambda x: x["idx"])
+ examples = []
+ for i, example in enumerate(dataset):
+ guid = "%s-%s" % (set_type, i)
+ label = "entailment"
+ text_a = self.process_text_fn(example["question"])
+ text_b = self.process_text_fn(example["sentence"])
+ if set_type != "test":
+ label = self.get_labels()[example["label"]]
+ examples.append(
+ InputExample(
+ guid=guid, text_a=text_a, text_b=text_b, label=label,
+ weight=None))
+ return examples
+
+
+class QqpProcessor(DefaultGLUEDataProcessor):
+ """Processor for the QQP data set (GLUE version)."""
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "QQP"
+
+ def _create_examples_tfds(self, set_type):
+ """Creates examples for the training/dev/test sets."""
+ dataset = tfds.load(
+ "glue/qqp", split=set_type, try_gcs=True).as_numpy_iterator()
+ dataset = list(dataset)
+ dataset.sort(key=lambda x: x["idx"])
+ examples = []
+ for i, example in enumerate(dataset):
+ guid = "%s-%s" % (set_type, i)
+ label = "0"
+ text_a = self.process_text_fn(example["question1"])
+ text_b = self.process_text_fn(example["question2"])
+ if set_type != "test":
+ label = str(example["label"])
+ examples.append(
+ InputExample(
+ guid=guid, text_a=text_a, text_b=text_b, label=label,
+ weight=None))
+ return examples
+
+
+class RteProcessor(DefaultGLUEDataProcessor):
+ """Processor for the RTE data set (GLUE version)."""
+
+ def get_labels(self):
+ """See base class."""
+ # All datasets are converted to 2-class split, where for 3-class datasets we
+ # collapse neutral and contradiction into not_entailment.
+ return ["entailment", "not_entailment"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "RTE"
+
+ def _create_examples_tfds(self, set_type):
+ """Creates examples for the training/dev/test sets."""
+ dataset = tfds.load(
+ "glue/rte", split=set_type, try_gcs=True).as_numpy_iterator()
+ dataset = list(dataset)
+ dataset.sort(key=lambda x: x["idx"])
+ examples = []
+ for i, example in enumerate(dataset):
+ guid = "%s-%s" % (set_type, i)
+ label = "entailment"
+ text_a = self.process_text_fn(example["sentence1"])
+ text_b = self.process_text_fn(example["sentence2"])
+ if set_type != "test":
+ label = self.get_labels()[example["label"]]
+ examples.append(
+ InputExample(
+ guid=guid, text_a=text_a, text_b=text_b, label=label,
+ weight=None))
+ return examples
+
+
+class SstProcessor(DefaultGLUEDataProcessor):
+ """Processor for the SST-2 data set (GLUE version)."""
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "SST-2"
+
+ def _create_examples_tfds(self, set_type):
+ """Creates examples for the training/dev/test sets."""
+ dataset = tfds.load(
+ "glue/sst2", split=set_type, try_gcs=True).as_numpy_iterator()
+ dataset = list(dataset)
+ dataset.sort(key=lambda x: x["idx"])
+ examples = []
+ for i, example in enumerate(dataset):
+ guid = "%s-%s" % (set_type, i)
+ label = "0"
+ text_a = self.process_text_fn(example["sentence"])
+ if set_type != "test":
+ label = str(example["label"])
+ examples.append(
+ InputExample(
+ guid=guid, text_a=text_a, text_b=None, label=label, weight=None))
+ return examples
+
+
+class StsBProcessor(DefaultGLUEDataProcessor):
+ """Processor for the STS-B data set (GLUE version)."""
+
+ def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+ super(StsBProcessor, self).__init__(process_text_fn=process_text_fn)
+ self.is_regression = True
+ self.label_type = float
+ self._labels = None
+
+ def _create_examples_tfds(self, set_type):
+ """Creates examples for the training/dev/test sets."""
+ dataset = tfds.load(
+ "glue/stsb", split=set_type, try_gcs=True).as_numpy_iterator()
+ dataset = list(dataset)
+ dataset.sort(key=lambda x: x["idx"])
+ examples = []
+ for i, example in enumerate(dataset):
+ guid = "%s-%s" % (set_type, i)
+ label = 0.0
+ text_a = self.process_text_fn(example["sentence1"])
+ text_b = self.process_text_fn(example["sentence2"])
+ if set_type != "test":
+ label = self.label_type(example["label"])
+ examples.append(
+ InputExample(
+ guid=guid, text_a=text_a, text_b=text_b, label=label,
+ weight=None))
+ return examples
+
+ def get_labels(self):
+ """See base class."""
+ return self._labels
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "STS-B"
+
+
+class TfdsProcessor(DataProcessor):
+ """Processor for generic text classification and regression TFDS data set.
+
+ The TFDS parameters are expected to be provided in the tfds_params string, in
+ a comma-separated list of parameter assignments.
+ Examples:
+ tfds_params="dataset=scicite,text_key=string"
+ tfds_params="dataset=imdb_reviews,test_split=,dev_split=test"
+ tfds_params="dataset=glue/cola,text_key=sentence"
+ tfds_params="dataset=glue/sst2,text_key=sentence"
+ tfds_params="dataset=glue/qnli,text_key=question,text_b_key=sentence"
+ tfds_params="dataset=glue/mrpc,text_key=sentence1,text_b_key=sentence2"
+ tfds_params="dataset=glue/stsb,text_key=sentence1,text_b_key=sentence2,"
+ "is_regression=true,label_type=float"
+ tfds_params="dataset=snli,text_key=premise,text_b_key=hypothesis,"
+ "skip_label=-1"
+ Possible parameters (please refer to the documentation of Tensorflow Datasets
+ (TFDS) for the meaning of individual parameters):
+ dataset: Required dataset name (potentially with subset and version number).
+ data_dir: Optional TFDS source root directory.
+ module_import: Optional Dataset module to import.
+ train_split: Name of the train split (defaults to `train`).
+ dev_split: Name of the dev split (defaults to `validation`).
+ test_split: Name of the test split (defaults to `test`).
+ text_key: Key of the text_a feature (defaults to `text`).
+ text_b_key: Key of the second text feature if available.
+ label_key: Key of the label feature (defaults to `label`).
+ test_text_key: Key of the text feature to use in test set.
+ test_text_b_key: Key of the second text feature to use in test set.
+ test_label: String to be used as the label for all test examples.
+ label_type: Type of the label key (defaults to `int`).
+ weight_key: Key of the float sample weight (is not used if not provided).
+ is_regression: Whether the task is a regression problem (defaults to False).
+ skip_label: Skip examples with given label (defaults to None).
+ """
+
+ def __init__(self,
+ tfds_params,
+ process_text_fn=tokenization.convert_to_unicode):
+ super(TfdsProcessor, self).__init__(process_text_fn)
+ self._process_tfds_params_str(tfds_params)
+ if self.module_import:
+ importlib.import_module(self.module_import)
+
+ self.dataset, info = tfds.load(
+ self.dataset_name, data_dir=self.data_dir, with_info=True)
+ if self.is_regression:
+ self._labels = None
+ else:
+ self._labels = list(range(info.features[self.label_key].num_classes))
+
+ def _process_tfds_params_str(self, params_str):
+ """Extracts TFDS parameters from a comma-separated assignements string."""
+ dtype_map = {"int": int, "float": float}
+ cast_str_to_bool = lambda s: s.lower() not in ["false", "0"]
+
+ tuples = [x.split("=") for x in params_str.split(",")]
+ d = {k.strip(): v.strip() for k, v in tuples}
+ self.dataset_name = d["dataset"] # Required.
+ self.data_dir = d.get("data_dir", None)
+ self.module_import = d.get("module_import", None)
+ self.train_split = d.get("train_split", "train")
+ self.dev_split = d.get("dev_split", "validation")
+ self.test_split = d.get("test_split", "test")
+ self.text_key = d.get("text_key", "text")
+ self.text_b_key = d.get("text_b_key", None)
+ self.label_key = d.get("label_key", "label")
+ self.test_text_key = d.get("test_text_key", self.text_key)
+ self.test_text_b_key = d.get("test_text_b_key", self.text_b_key)
+ self.test_label = d.get("test_label", "test_example")
+ self.label_type = dtype_map[d.get("label_type", "int")]
+ self.is_regression = cast_str_to_bool(d.get("is_regression", "False"))
+ self.weight_key = d.get("weight_key", None)
+ self.skip_label = d.get("skip_label", None)
+ if self.skip_label is not None:
+ self.skip_label = self.label_type(self.skip_label)
+
+ def get_train_examples(self, data_dir):
+ assert data_dir is None
+ return self._create_examples(self.train_split, "train")
+
+ def get_dev_examples(self, data_dir):
+ assert data_dir is None
+ return self._create_examples(self.dev_split, "dev")
+
+ def get_test_examples(self, data_dir):
+ assert data_dir is None
+ return self._create_examples(self.test_split, "test")
+
+ def get_labels(self):
+ return self._labels
+
+ def get_processor_name(self):
+ return "TFDS_" + self.dataset_name
+
+ def _create_examples(self, split_name, set_type):
+ """Creates examples for the training/dev/test sets."""
+ if split_name not in self.dataset:
+ raise ValueError("Split {} not available.".format(split_name))
+ dataset = self.dataset[split_name].as_numpy_iterator()
+ examples = []
+ text_b, weight = None, None
+ for i, example in enumerate(dataset):
+ guid = "%s-%s" % (set_type, i)
+ if set_type == "test":
+ text_a = self.process_text_fn(example[self.test_text_key])
+ if self.test_text_b_key:
+ text_b = self.process_text_fn(example[self.test_text_b_key])
+ label = self.test_label
+ else:
+ text_a = self.process_text_fn(example[self.text_key])
+ if self.text_b_key:
+ text_b = self.process_text_fn(example[self.text_b_key])
+ label = self.label_type(example[self.label_key])
+ if self.skip_label is not None and label == self.skip_label:
+ continue
+ if self.weight_key:
+ weight = float(example[self.weight_key])
+ examples.append(
+ InputExample(
+ guid=guid,
+ text_a=text_a,
+ text_b=text_b,
+ label=label,
+ weight=weight))
+ return examples
+
+
+class WnliProcessor(DefaultGLUEDataProcessor):
+ """Processor for the WNLI data set (GLUE version)."""
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "WNLI"
+
+ def _create_examples_tfds(self, set_type):
+ """Creates examples for the training/dev/test sets."""
+ dataset = tfds.load(
+ "glue/wnli", split=set_type, try_gcs=True).as_numpy_iterator()
+ dataset = list(dataset)
+ dataset.sort(key=lambda x: x["idx"])
+ examples = []
+ for i, example in enumerate(dataset):
+ guid = "%s-%s" % (set_type, i)
+ label = "0"
+ text_a = self.process_text_fn(example["sentence1"])
+ text_b = self.process_text_fn(example["sentence2"])
+ if set_type != "test":
+ label = str(example["label"])
+ examples.append(
+ InputExample(
+ guid=guid, text_a=text_a, text_b=text_b, label=label,
+ weight=None))
+ return examples
+
+
+class XnliProcessor(DataProcessor):
+ """Processor for the XNLI data set."""
+ supported_languages = [
+ "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
+ "ur", "vi", "zh"
+ ]
+
+ def __init__(self,
+ language="en",
+ process_text_fn=tokenization.convert_to_unicode):
+ super(XnliProcessor, self).__init__(process_text_fn)
+ if language == "all":
+ self.languages = XnliProcessor.supported_languages
+ elif language not in XnliProcessor.supported_languages:
+ raise ValueError("language %s is not supported for XNLI task." % language)
+ else:
+ self.languages = [language]
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ lines = []
+ for language in self.languages:
+ # Skips the header.
+ lines.extend(
+ self._read_tsv(
+ os.path.join(data_dir, "multinli",
+ "multinli.train.%s.tsv" % language))[1:])
+
+ examples = []
+ for i, line in enumerate(lines):
+ guid = "train-%d" % i
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = self.process_text_fn(line[2])
+ if label == self.process_text_fn("contradictory"):
+ label = self.process_text_fn("contradiction")
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ lines = self._read_tsv(os.path.join(data_dir, "xnli.dev.tsv"))
+ examples = []
+ for i, line in enumerate(lines):
+ if i == 0:
+ continue
+ guid = "dev-%d" % i
+ text_a = self.process_text_fn(line[6])
+ text_b = self.process_text_fn(line[7])
+ label = self.process_text_fn(line[1])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ lines = self._read_tsv(os.path.join(data_dir, "xnli.test.tsv"))
+ examples_by_lang = {k: [] for k in XnliProcessor.supported_languages}
+ for i, line in enumerate(lines):
+ if i == 0:
+ continue
+ guid = "test-%d" % i
+ language = self.process_text_fn(line[0])
+ text_a = self.process_text_fn(line[6])
+ text_b = self.process_text_fn(line[7])
+ label = self.process_text_fn(line[1])
+ examples_by_lang[language].append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples_by_lang
+
+ def get_labels(self):
+ """See base class."""
+ return ["contradiction", "entailment", "neutral"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "XNLI"
+
+
+class XtremePawsxProcessor(DataProcessor):
+ """Processor for the XTREME PAWS-X data set."""
+ supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
+
+ def __init__(self,
+ process_text_fn=tokenization.convert_to_unicode,
+ translated_data_dir=None,
+ only_use_en_dev=True):
+ """See base class.
+
+ Args:
+ process_text_fn: See base class.
+ translated_data_dir: If specified, will also include translated data in
+ the training and testing data.
+ only_use_en_dev: If True, only use english dev data. Otherwise, use dev
+ data from all languages.
+ """
+ super(XtremePawsxProcessor, self).__init__(process_text_fn)
+ self.translated_data_dir = translated_data_dir
+ self.only_use_en_dev = only_use_en_dev
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ examples = []
+ if self.translated_data_dir is None:
+ lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
+ for i, line in enumerate(lines):
+ guid = "train-%d" % i
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = self.process_text_fn(line[2])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ else:
+ for lang in self.supported_languages:
+ lines = self._read_tsv(
+ os.path.join(self.translated_data_dir, "translate-train",
+ f"en-{lang}-translated.tsv"))
+ for i, line in enumerate(lines):
+ guid = f"train-{lang}-{i}"
+ text_a = self.process_text_fn(line[2])
+ text_b = self.process_text_fn(line[3])
+ label = self.process_text_fn(line[4])
+ examples.append(
+ InputExample(
+ guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ examples = []
+ if self.only_use_en_dev:
+ lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
+ for i, line in enumerate(lines):
+ guid = "dev-%d" % i
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = self.process_text_fn(line[2])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ else:
+ for lang in self.supported_languages:
+ lines = self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv"))
+ for i, line in enumerate(lines):
+ guid = f"dev-{lang}-{i}"
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = self.process_text_fn(line[2])
+ examples.append(
+ InputExample(
+ guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ examples_by_lang = {}
+ for lang in self.supported_languages:
+ examples_by_lang[lang] = []
+ lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
+ for i, line in enumerate(lines):
+ guid = f"test-{lang}-{i}"
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = "0"
+ examples_by_lang[lang].append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ if self.translated_data_dir is not None:
+ for lang in self.supported_languages:
+ if lang == "en":
+ continue
+ examples_by_lang[f"{lang}-en"] = []
+ lines = self._read_tsv(
+ os.path.join(self.translated_data_dir, "translate-test",
+ f"test-{lang}-en-translated.tsv"))
+ for i, line in enumerate(lines):
+ guid = f"test-{lang}-en-{i}"
+ text_a = self.process_text_fn(line[2])
+ text_b = self.process_text_fn(line[3])
+ label = "0"
+ examples_by_lang[f"{lang}-en"].append(
+ InputExample(
+ guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples_by_lang
+
+ def get_labels(self):
+ """See base class."""
+ return ["0", "1"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "XTREME-PAWS-X"
+
+
+class XtremeXnliProcessor(DataProcessor):
+ """Processor for the XTREME XNLI data set."""
+ supported_languages = [
+ "ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
+ "ur", "vi", "zh"
+ ]
+
+ def __init__(self,
+ process_text_fn=tokenization.convert_to_unicode,
+ translated_data_dir=None,
+ only_use_en_dev=True):
+ """See base class.
+
+ Args:
+ process_text_fn: See base class.
+ translated_data_dir: If specified, will also include translated data in
+ the training data.
+ only_use_en_dev: If True, only use english dev data. Otherwise, use dev
+ data from all languages.
+ """
+ super(XtremeXnliProcessor, self).__init__(process_text_fn)
+ self.translated_data_dir = translated_data_dir
+ self.only_use_en_dev = only_use_en_dev
+
+ def get_train_examples(self, data_dir):
+ """See base class."""
+ lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
+
+ examples = []
+ if self.translated_data_dir is None:
+ for i, line in enumerate(lines):
+ guid = "train-%d" % i
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = self.process_text_fn(line[2])
+ if label == self.process_text_fn("contradictory"):
+ label = self.process_text_fn("contradiction")
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ else:
+ for lang in self.supported_languages:
+ lines = self._read_tsv(
+ os.path.join(self.translated_data_dir, "translate-train",
+ f"en-{lang}-translated.tsv"))
+ for i, line in enumerate(lines):
+ guid = f"train-{lang}-{i}"
+ text_a = self.process_text_fn(line[2])
+ text_b = self.process_text_fn(line[3])
+ label = self.process_text_fn(line[4])
+ if label == self.process_text_fn("contradictory"):
+ label = self.process_text_fn("contradiction")
+ examples.append(
+ InputExample(
+ guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_dev_examples(self, data_dir):
+ """See base class."""
+ examples = []
+ if self.only_use_en_dev:
+ lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
+ for i, line in enumerate(lines):
+ guid = "dev-%d" % i
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = self.process_text_fn(line[2])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ else:
+ for lang in self.supported_languages:
+ lines = self._read_tsv(os.path.join(data_dir, f"dev-{lang}.tsv"))
+ for i, line in enumerate(lines):
+ guid = f"dev-{lang}-{i}"
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = self.process_text_fn(line[2])
+ if label == self.process_text_fn("contradictory"):
+ label = self.process_text_fn("contradiction")
+ examples.append(
+ InputExample(
+ guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ examples_by_lang = {}
+ for lang in self.supported_languages:
+ examples_by_lang[lang] = []
+ lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
+ for i, line in enumerate(lines):
+ guid = f"test-{lang}-{i}"
+ text_a = self.process_text_fn(line[0])
+ text_b = self.process_text_fn(line[1])
+ label = "contradiction"
+ examples_by_lang[lang].append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ if self.translated_data_dir is not None:
+ for lang in self.supported_languages:
+ if lang == "en":
+ continue
+ examples_by_lang[f"{lang}-en"] = []
+ lines = self._read_tsv(
+ os.path.join(self.translated_data_dir, "translate-test",
+ f"test-{lang}-en-translated.tsv"))
+ for i, line in enumerate(lines):
+ guid = f"test-{lang}-en-{i}"
+ text_a = self.process_text_fn(line[2])
+ text_b = self.process_text_fn(line[3])
+ label = "contradiction"
+ examples_by_lang[f"{lang}-en"].append(
+ InputExample(
+ guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples_by_lang
+
+ def get_labels(self):
+ """See base class."""
+ return ["contradiction", "entailment", "neutral"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "XTREME-XNLI"
+
+
+def convert_single_example(ex_index, example, label_list, max_seq_length,
+ tokenizer):
+ """Converts a single `InputExample` into a single `InputFeatures`."""
+ label_map = {}
+ if label_list:
+ for (i, label) in enumerate(label_list):
+ label_map[label] = i
+
+ tokens_a = tokenizer.tokenize(example.text_a)
+ tokens_b = None
+ if example.text_b:
+ tokens_b = tokenizer.tokenize(example.text_b)
+
+ if tokens_b:
+ # Modifies `tokens_a` and `tokens_b` in place so that the total
+ # length is less than the specified length.
+ # Account for [CLS], [SEP], [SEP] with "- 3"
+ _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
+ else:
+ # Account for [CLS] and [SEP] with "- 2"
+ if len(tokens_a) > max_seq_length - 2:
+ tokens_a = tokens_a[0:(max_seq_length - 2)]
+
+ seg_id_a = 0
+ seg_id_b = 1
+ seg_id_cls = 0
+ seg_id_pad = 0
+
+ # The convention in BERT is:
+ # (a) For sequence pairs:
+ # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
+ # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
+ # (b) For single sequences:
+ # tokens: [CLS] the dog is hairy . [SEP]
+ # type_ids: 0 0 0 0 0 0 0
+ #
+ # Where "type_ids" are used to indicate whether this is the first
+ # sequence or the second sequence. The embedding vectors for `type=0` and
+ # `type=1` were learned during pre-training and are added to the wordpiece
+ # embedding vector (and position vector). This is not *strictly* necessary
+ # since the [SEP] token unambiguously separates the sequences, but it makes
+ # it easier for the model to learn the concept of sequences.
+ #
+ # For classification tasks, the first vector (corresponding to [CLS]) is
+ # used as the "sentence vector". Note that this only makes sense because
+ # the entire model is fine-tuned.
+ tokens = []
+ segment_ids = []
+ tokens.append("[CLS]")
+ segment_ids.append(seg_id_cls)
+ for token in tokens_a:
+ tokens.append(token)
+ segment_ids.append(seg_id_a)
+ tokens.append("[SEP]")
+ segment_ids.append(seg_id_a)
+
+ if tokens_b:
+ for token in tokens_b:
+ tokens.append(token)
+ segment_ids.append(seg_id_b)
+ tokens.append("[SEP]")
+ segment_ids.append(seg_id_b)
+
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
+ # tokens are attended to.
+ input_mask = [1] * len(input_ids)
+
+ # Zero-pad up to the sequence length.
+ while len(input_ids) < max_seq_length:
+ input_ids.append(0)
+ input_mask.append(0)
+ segment_ids.append(seg_id_pad)
+
+ assert len(input_ids) == max_seq_length
+ assert len(input_mask) == max_seq_length
+ assert len(segment_ids) == max_seq_length
+
+ label_id = label_map[example.label] if label_map else example.label
+ if ex_index < 5:
+ logging.info("*** Example ***")
+ logging.info("guid: %s", (example.guid))
+ logging.info("tokens: %s",
+ " ".join([tokenization.printable_text(x) for x in tokens]))
+ logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
+ logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
+ logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+ logging.info("label: %s (id = %s)", example.label, str(label_id))
+ logging.info("weight: %s", example.weight)
+ logging.info("example_id: %s", example.example_id)
+
+ feature = InputFeatures(
+ input_ids=input_ids,
+ input_mask=input_mask,
+ segment_ids=segment_ids,
+ label_id=label_id,
+ is_real_example=True,
+ weight=example.weight,
+ example_id=example.example_id)
+
+ return feature
+
+
+class AXgProcessor(DataProcessor):
+ """Processor for the AXg dataset (SuperGLUE diagnostics dataset)."""
+
+ def get_test_examples(self, data_dir):
+ """See base class."""
+ return self._create_examples(
+ self._read_jsonl(os.path.join(data_dir, "AX-g.jsonl")), "test")
+
+ def get_labels(self):
+ """See base class."""
+ return ["entailment", "not_entailment"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "AXg"
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training/dev/test sets."""
+ examples = []
+ for line in lines:
+ guid = "%s-%s" % (set_type, self.process_text_fn(str(line["idx"])))
+ text_a = self.process_text_fn(line["premise"])
+ text_b = self.process_text_fn(line["hypothesis"])
+ label = self.process_text_fn(line["label"])
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+class BoolQProcessor(DefaultGLUEDataProcessor):
+ """Processor for the BoolQ dataset (SuperGLUE diagnostics dataset)."""
+
+ def get_labels(self):
+ """See base class."""
+ return ["True", "False"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "BoolQ"
+
+ def _create_examples_tfds(self, set_type):
+ """Creates examples for the training/dev/test sets."""
+ dataset = tfds.load(
+ "super_glue/boolq", split=set_type, try_gcs=True).as_numpy_iterator()
+ examples = []
+ for example in dataset:
+ guid = "%s-%s" % (set_type, self.process_text_fn(str(example["idx"])))
+ text_a = self.process_text_fn(example["question"])
+ text_b = self.process_text_fn(example["passage"])
+ label = "False"
+ if set_type != "test":
+ label = self.get_labels()[example["label"]]
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+class CBProcessor(DefaultGLUEDataProcessor):
+ """Processor for the CB dataset (SuperGLUE diagnostics dataset)."""
+
+ def get_labels(self):
+ """See base class."""
+ return ["entailment", "neutral", "contradiction"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "CB"
+
+ def _create_examples_tfds(self, set_type):
+ """Creates examples for the training/dev/test sets."""
+ dataset = tfds.load(
+ "super_glue/cb", split=set_type, try_gcs=True).as_numpy_iterator()
+ examples = []
+ for example in dataset:
+ guid = "%s-%s" % (set_type, self.process_text_fn(str(example["idx"])))
+ text_a = self.process_text_fn(example["premise"])
+ text_b = self.process_text_fn(example["hypothesis"])
+ label = "entailment"
+ if set_type != "test":
+ label = self.get_labels()[example["label"]]
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+class SuperGLUERTEProcessor(DefaultGLUEDataProcessor):
+ """Processor for the RTE dataset (SuperGLUE version)."""
+
+ def get_labels(self):
+ """See base class."""
+ # All datasets are converted to 2-class split, where for 3-class datasets we
+ # collapse neutral and contradiction into not_entailment.
+ return ["entailment", "not_entailment"]
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "RTESuperGLUE"
+
+ def _create_examples_tfds(self, set_type):
+ """Creates examples for the training/dev/test sets."""
+ examples = []
+ dataset = tfds.load(
+ "super_glue/rte", split=set_type, try_gcs=True).as_numpy_iterator()
+ for example in dataset:
+ guid = "%s-%s" % (set_type, self.process_text_fn(str(example["idx"])))
+ text_a = self.process_text_fn(example["premise"])
+ text_b = self.process_text_fn(example["hypothesis"])
+ label = "entailment"
+ if set_type != "test":
+ label = self.get_labels()[example["label"]]
+ examples.append(
+ InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
+ return examples
+
+
+class WiCInputExample(InputExample):
+ """Processor for the WiC dataset (SuperGLUE version)."""
+
+ def __init__(self,
+ guid,
+ text_a,
+ text_b=None,
+ label=None,
+ word=None,
+ weight=None,
+ example_id=None):
+ """A single training/test example for simple seq regression/classification."""
+ super(WiCInputExample, self).__init__(guid, text_a, text_b, label, weight,
+ example_id)
+ self.word = word
+
+
+class WiCProcessor(DefaultGLUEDataProcessor):
+ """Processor for the RTE dataset (SuperGLUE version)."""
+
+ def get_labels(self):
+ """Not used."""
+ return []
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "RTESuperGLUE"
+
+ def _create_examples_tfds(self, set_type):
+ """Creates examples for the training/dev/test sets."""
+ examples = []
+ dataset = tfds.load(
+ "super_glue/wic", split=set_type, try_gcs=True).as_numpy_iterator()
+ for example in dataset:
+ guid = "%s-%s" % (set_type, self.process_text_fn(str(example["idx"])))
+ text_a = self.process_text_fn(example["sentence1"])
+ text_b = self.process_text_fn(example["sentence2"])
+ word = self.process_text_fn(example["word"])
+ label = 0
+ if set_type != "test":
+ label = example["label"]
+ examples.append(
+ WiCInputExample(
+ guid=guid, text_a=text_a, text_b=text_b, word=word, label=label))
+ return examples
+
+ def featurize_example(self, ex_index, example, label_list, max_seq_length,
+ tokenizer):
+ """Here we concate sentence1, sentence2, word together with [SEP] tokens."""
+ del label_list
+ tokens_a = tokenizer.tokenize(example.text_a)
+ tokens_b = tokenizer.tokenize(example.text_b)
+ tokens_word = tokenizer.tokenize(example.word)
+
+ # Modifies `tokens_a` and `tokens_b` in place so that the total
+ # length is less than the specified length.
+ # Account for [CLS], [SEP], [SEP], [SEP] with "- 4"
+ # Here we only pop out the first two sentence tokens.
+ _truncate_seq_pair(tokens_a, tokens_b,
+ max_seq_length - 4 - len(tokens_word))
+
+ seg_id_a = 0
+ seg_id_b = 1
+ seg_id_c = 2
+ seg_id_cls = 0
+ seg_id_pad = 0
+
+ tokens = []
+ segment_ids = []
+ tokens.append("[CLS]")
+ segment_ids.append(seg_id_cls)
+ for token in tokens_a:
+ tokens.append(token)
+ segment_ids.append(seg_id_a)
+ tokens.append("[SEP]")
+ segment_ids.append(seg_id_a)
+
+ for token in tokens_b:
+ tokens.append(token)
+ segment_ids.append(seg_id_b)
+
+ tokens.append("[SEP]")
+ segment_ids.append(seg_id_b)
+
+ for token in tokens_word:
+ tokens.append(token)
+ segment_ids.append(seg_id_c)
+
+ tokens.append("[SEP]")
+ segment_ids.append(seg_id_c)
+
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
+ # tokens are attended to.
+ input_mask = [1] * len(input_ids)
+
+ # Zero-pad up to the sequence length.
+ while len(input_ids) < max_seq_length:
+ input_ids.append(0)
+ input_mask.append(0)
+ segment_ids.append(seg_id_pad)
+
+ assert len(input_ids) == max_seq_length
+ assert len(input_mask) == max_seq_length
+ assert len(segment_ids) == max_seq_length
+
+ label_id = example.label
+ if ex_index < 5:
+ logging.info("*** Example ***")
+ logging.info("guid: %s", (example.guid))
+ logging.info("tokens: %s",
+ " ".join([tokenization.printable_text(x) for x in tokens]))
+ logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
+ logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
+ logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+ logging.info("label: %s (id = %s)", example.label, str(label_id))
+ logging.info("weight: %s", example.weight)
+ logging.info("example_id: %s", example.example_id)
+
+ feature = InputFeatures(
+ input_ids=input_ids,
+ input_mask=input_mask,
+ segment_ids=segment_ids,
+ label_id=label_id,
+ is_real_example=True,
+ weight=example.weight,
+ example_id=example.example_id)
+
+ return feature
+
+
+def file_based_convert_examples_to_features(examples,
+ label_list,
+ max_seq_length,
+ tokenizer,
+ output_file,
+ label_type=None,
+ featurize_fn=None):
+ """Convert a set of `InputExample`s to a TFRecord file."""
+
+ tf.io.gfile.makedirs(os.path.dirname(output_file))
+ writer = tf.io.TFRecordWriter(output_file)
+
+ for ex_index, example in enumerate(examples):
+ if ex_index % 10000 == 0:
+ logging.info("Writing example %d of %d", ex_index, len(examples))
+
+ if featurize_fn:
+ feature = featurize_fn(ex_index, example, label_list, max_seq_length,
+ tokenizer)
+ else:
+ feature = convert_single_example(ex_index, example, label_list,
+ max_seq_length, tokenizer)
+
+ def create_int_feature(values):
+ f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
+ return f
+
+ def create_float_feature(values):
+ f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
+ return f
+
+ features = collections.OrderedDict()
+ features["input_ids"] = create_int_feature(feature.input_ids)
+ features["input_mask"] = create_int_feature(feature.input_mask)
+ features["segment_ids"] = create_int_feature(feature.segment_ids)
+ if label_type is not None and label_type == float:
+ features["label_ids"] = create_float_feature([feature.label_id])
+ elif feature.label_id is not None:
+ features["label_ids"] = create_int_feature([feature.label_id])
+ features["is_real_example"] = create_int_feature(
+ [int(feature.is_real_example)])
+ if feature.weight is not None:
+ features["weight"] = create_float_feature([feature.weight])
+ if feature.example_id is not None:
+ features["example_id"] = create_int_feature([feature.example_id])
+ else:
+ features["example_id"] = create_int_feature([ex_index])
+
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+ writer.write(tf_example.SerializeToString())
+ writer.close()
+
+
+def _truncate_seq_pair(tokens_a, tokens_b, max_length):
+ """Truncates a sequence pair in place to the maximum length."""
+
+ # This is a simple heuristic which will always truncate the longer sequence
+ # one token at a time. This makes more sense than truncating an equal percent
+ # of tokens from each, since if one sequence is very short then each token
+ # that's truncated likely contains more information than a longer sequence.
+ while True:
+ total_length = len(tokens_a) + len(tokens_b)
+ if total_length <= max_length:
+ break
+ if len(tokens_a) > len(tokens_b):
+ tokens_a.pop()
+ else:
+ tokens_b.pop()
+
+
+def generate_tf_record_from_data_file(processor,
+ data_dir,
+ tokenizer,
+ train_data_output_path=None,
+ eval_data_output_path=None,
+ test_data_output_path=None,
+ max_seq_length=128):
+ """Generates and saves training data into a tf record file.
+
+ Args:
+ processor: Input processor object to be used for generating data. Subclass
+ of `DataProcessor`.
+ data_dir: Directory that contains train/eval/test data to process.
+ tokenizer: The tokenizer to be applied on the data.
+ train_data_output_path: Output to which processed tf record for training
+ will be saved.
+ eval_data_output_path: Output to which processed tf record for evaluation
+ will be saved.
+ test_data_output_path: Output to which processed tf record for testing
+ will be saved. Must be a pattern template with {} if processor has
+ language specific test data.
+ max_seq_length: Maximum sequence length of the to be generated
+ training/eval data.
+
+ Returns:
+ A dictionary containing input meta data.
+ """
+ assert train_data_output_path or eval_data_output_path
+
+ label_list = processor.get_labels()
+ label_type = getattr(processor, "label_type", None)
+ is_regression = getattr(processor, "is_regression", False)
+ has_sample_weights = getattr(processor, "weight_key", False)
+
+ num_training_data = 0
+ if train_data_output_path:
+ train_input_data_examples = processor.get_train_examples(data_dir)
+ file_based_convert_examples_to_features(train_input_data_examples,
+ label_list, max_seq_length,
+ tokenizer, train_data_output_path,
+ label_type,
+ processor.featurize_example)
+ num_training_data = len(train_input_data_examples)
+
+ if eval_data_output_path:
+ eval_input_data_examples = processor.get_dev_examples(data_dir)
+ file_based_convert_examples_to_features(eval_input_data_examples,
+ label_list, max_seq_length,
+ tokenizer, eval_data_output_path,
+ label_type,
+ processor.featurize_example)
+
+ meta_data = {
+ "processor_type": processor.get_processor_name(),
+ "train_data_size": num_training_data,
+ "max_seq_length": max_seq_length,
+ }
+
+ if test_data_output_path:
+ test_input_data_examples = processor.get_test_examples(data_dir)
+ if isinstance(test_input_data_examples, dict):
+ for language, examples in test_input_data_examples.items():
+ file_based_convert_examples_to_features(
+ examples, label_list, max_seq_length, tokenizer,
+ test_data_output_path.format(language), label_type,
+ processor.featurize_example)
+ meta_data["test_{}_data_size".format(language)] = len(examples)
+ else:
+ file_based_convert_examples_to_features(test_input_data_examples,
+ label_list, max_seq_length,
+ tokenizer, test_data_output_path,
+ label_type,
+ processor.featurize_example)
+ meta_data["test_data_size"] = len(test_input_data_examples)
+
+ if is_regression:
+ meta_data["task_type"] = "bert_regression"
+ meta_data["label_type"] = {int: "int", float: "float"}[label_type]
+ else:
+ meta_data["task_type"] = "bert_classification"
+ meta_data["num_labels"] = len(processor.get_labels())
+ if has_sample_weights:
+ meta_data["has_sample_weights"] = True
+
+ if eval_data_output_path:
+ meta_data["eval_data_size"] = len(eval_input_data_examples)
+
+ return meta_data
diff --git a/modeling/official/nlp/data/classifier_data_lib_test.py b/modeling/official/nlp/data/classifier_data_lib_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef54a6742b125da30a1c9023312ed183e91d7fa2
--- /dev/null
+++ b/modeling/official/nlp/data/classifier_data_lib_test.py
@@ -0,0 +1,95 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for third_party.tensorflow_models.official.nlp.data.classifier_data_lib."""
+
+import os
+import tempfile
+
+from absl.testing import parameterized
+import tensorflow as tf, tf_keras
+import tensorflow_datasets as tfds
+
+from official.nlp.data import classifier_data_lib
+from official.nlp.tools import tokenization
+
+
+def decode_record(record, name_to_features):
+ """Decodes a record to a TensorFlow example."""
+ return tf.io.parse_single_example(record, name_to_features)
+
+
+class BertClassifierLibTest(tf.test.TestCase, parameterized.TestCase):
+
+ def setUp(self):
+ super(BertClassifierLibTest, self).setUp()
+ self.model_dir = self.get_temp_dir()
+ self.processors = {
+ "CB": classifier_data_lib.CBProcessor,
+ "SUPERGLUE-RTE": classifier_data_lib.SuperGLUERTEProcessor,
+ "BOOLQ": classifier_data_lib.BoolQProcessor,
+ "WIC": classifier_data_lib.WiCProcessor,
+ }
+
+ vocab_tokens = [
+ "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
+ "##ing", ","
+ ]
+ with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
+ vocab_writer.write("".join([x + "\n" for x in vocab_tokens
+ ]).encode("utf-8"))
+ vocab_file = vocab_writer.name
+ self.tokenizer = tokenization.FullTokenizer(vocab_file)
+
+ @parameterized.parameters(
+ {"task_type": "CB"},
+ {"task_type": "BOOLQ"},
+ {"task_type": "SUPERGLUE-RTE"},
+ {"task_type": "WIC"},
+ )
+ def test_generate_dataset_from_tfds_processor(self, task_type):
+ with tfds.testing.mock_data(num_examples=5):
+ output_path = os.path.join(self.model_dir, task_type)
+
+ processor = self.processors[task_type]()
+
+ classifier_data_lib.generate_tf_record_from_data_file(
+ processor,
+ None,
+ self.tokenizer,
+ train_data_output_path=output_path,
+ eval_data_output_path=output_path,
+ test_data_output_path=output_path)
+ files = tf.io.gfile.glob(output_path)
+ self.assertNotEmpty(files)
+
+ train_dataset = tf.data.TFRecordDataset(output_path)
+ seq_length = 128
+ label_type = tf.int64
+ name_to_features = {
+ "input_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
+ "input_mask": tf.io.FixedLenFeature([seq_length], tf.int64),
+ "segment_ids": tf.io.FixedLenFeature([seq_length], tf.int64),
+ "label_ids": tf.io.FixedLenFeature([], label_type),
+ }
+ train_dataset = train_dataset.map(
+ lambda record: decode_record(record, name_to_features))
+
+ # If data is retrieved without error, then all requirements
+ # including data type/shapes are met.
+ _ = next(iter(train_dataset))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/nlp/data/create_finetuning_data.py b/modeling/official/nlp/data/create_finetuning_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..631a7e480445d4ecf24c934d2a45875e60d59c88
--- /dev/null
+++ b/modeling/official/nlp/data/create_finetuning_data.py
@@ -0,0 +1,441 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BERT finetuning task dataset generator."""
+
+import functools
+import json
+import os
+
+# Import libraries
+from absl import app
+from absl import flags
+import tensorflow as tf, tf_keras
+from official.nlp.data import classifier_data_lib
+from official.nlp.data import sentence_retrieval_lib
+# word-piece tokenizer based squad_lib
+from official.nlp.data import squad_lib as squad_lib_wp
+# sentence-piece tokenizer based squad_lib
+from official.nlp.data import squad_lib_sp
+from official.nlp.data import tagging_data_lib
+from official.nlp.tools import tokenization
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_enum(
+ "fine_tuning_task_type", "classification",
+ ["classification", "regression", "squad", "retrieval", "tagging"],
+ "The name of the BERT fine tuning task for which data "
+ "will be generated.")
+
+# BERT classification specific flags.
+flags.DEFINE_string(
+ "input_data_dir", None,
+ "The input data dir. Should contain the .tsv files (or other data files) "
+ "for the task.")
+
+flags.DEFINE_enum(
+ "classification_task_name", "MNLI", [
+ "AX", "COLA", "IMDB", "MNLI", "MRPC", "PAWS-X", "QNLI", "QQP", "RTE",
+ "SST-2", "STS-B", "WNLI", "XNLI", "XTREME-XNLI", "XTREME-PAWS-X",
+ "AX-g", "SUPERGLUE-RTE", "CB", "BoolQ", "WIC"
+ ], "The name of the task to train BERT classifier. The "
+ "difference between XTREME-XNLI and XNLI is: 1. the format "
+ "of input tsv files; 2. the dev set for XTREME is english "
+ "only and for XNLI is all languages combined. Same for "
+ "PAWS-X.")
+
+# MNLI task-specific flag.
+flags.DEFINE_enum("mnli_type", "matched", ["matched", "mismatched"],
+ "The type of MNLI dataset.")
+
+# XNLI task-specific flag.
+flags.DEFINE_string(
+ "xnli_language", "en",
+ "Language of training data for XNLI task. If the value is 'all', the data "
+ "of all languages will be used for training.")
+
+# PAWS-X task-specific flag.
+flags.DEFINE_string(
+ "pawsx_language", "en",
+ "Language of training data for PAWS-X task. If the value is 'all', the data "
+ "of all languages will be used for training.")
+
+# XTREME classification specific flags. Only used in XtremePawsx and XtremeXnli.
+flags.DEFINE_string(
+ "translated_input_data_dir", None,
+ "The translated input data dir. Should contain the .tsv files (or other "
+ "data files) for the task.")
+
+# Retrieval task-specific flags.
+flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
+ "The name of sentence retrieval task for scoring")
+
+# Tagging task-specific flags.
+flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
+ "The name of BERT tagging (token classification) task.")
+
+flags.DEFINE_bool("tagging_only_use_en_train", True,
+ "Whether only use english training data in tagging.")
+
+# BERT Squad task-specific flags.
+flags.DEFINE_string(
+ "squad_data_file", None,
+ "The input data file in for generating training data for BERT squad task.")
+
+flags.DEFINE_string(
+ "translated_squad_data_folder", None,
+ "The translated data folder for generating training data for BERT squad "
+ "task.")
+
+flags.DEFINE_integer(
+ "doc_stride", 128,
+ "When splitting up a long document into chunks, how much stride to "
+ "take between chunks.")
+
+flags.DEFINE_integer(
+ "max_query_length", 64,
+ "The maximum number of tokens for the question. Questions longer than "
+ "this will be truncated to this length.")
+
+flags.DEFINE_bool(
+ "version_2_with_negative", False,
+ "If true, the SQuAD examples contain some that do not have an answer.")
+
+flags.DEFINE_bool(
+ "xlnet_format", False,
+ "If true, then data will be preprocessed in a paragraph, query, class order"
+ " instead of the BERT-style class, paragraph, query order.")
+
+# XTREME specific flags.
+flags.DEFINE_bool("only_use_en_dev", True, "Whether only use english dev data.")
+
+# Shared flags across BERT fine-tuning tasks.
+flags.DEFINE_string("vocab_file", None,
+ "The vocabulary file that the BERT model was trained on.")
+
+flags.DEFINE_string(
+ "train_data_output_path", None,
+ "The path in which generated training input data will be written as tf"
+ " records.")
+
+flags.DEFINE_string(
+ "eval_data_output_path", None,
+ "The path in which generated evaluation input data will be written as tf"
+ " records.")
+
+flags.DEFINE_string(
+ "test_data_output_path", None,
+ "The path in which generated test input data will be written as tf"
+ " records. If None, do not generate test data. Must be a pattern template"
+ " as test_{}.tfrecords if processor has language specific test data.")
+
+flags.DEFINE_string("meta_data_file_path", None,
+ "The path in which input meta data will be written.")
+
+flags.DEFINE_bool(
+ "do_lower_case", True,
+ "Whether to lower case the input text. Should be True for uncased "
+ "models and False for cased models.")
+
+flags.DEFINE_integer(
+ "max_seq_length", 128,
+ "The maximum total input sequence length after WordPiece tokenization. "
+ "Sequences longer than this will be truncated, and sequences shorter "
+ "than this will be padded.")
+
+flags.DEFINE_string("sp_model_file", "",
+ "The path to the model used by sentence piece tokenizer.")
+
+flags.DEFINE_enum(
+ "tokenization", "WordPiece", ["WordPiece", "SentencePiece"],
+ "Specifies the tokenizer implementation, i.e., whether to use WordPiece "
+ "or SentencePiece tokenizer. Canonical BERT uses WordPiece tokenizer, "
+ "while ALBERT uses SentencePiece tokenizer.")
+
+flags.DEFINE_string(
+ "tfds_params", "", "Comma-separated list of TFDS parameter assigments for "
+ "generic classfication data import (for more details "
+ "see the TfdsProcessor class documentation).")
+
+
+def generate_classifier_dataset():
+ """Generates classifier dataset and returns input meta data."""
+ if FLAGS.classification_task_name in [
+ "COLA",
+ "WNLI",
+ "SST-2",
+ "MRPC",
+ "QQP",
+ "STS-B",
+ "MNLI",
+ "QNLI",
+ "RTE",
+ "AX",
+ "SUPERGLUE-RTE",
+ "CB",
+ "BoolQ",
+ "WIC",
+ ]:
+ assert not FLAGS.input_data_dir or FLAGS.tfds_params
+ else:
+ assert (FLAGS.input_data_dir and FLAGS.classification_task_name or
+ FLAGS.tfds_params)
+
+ if FLAGS.tokenization == "WordPiece":
+ tokenizer = tokenization.FullTokenizer(
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+ processor_text_fn = tokenization.convert_to_unicode
+ else:
+ assert FLAGS.tokenization == "SentencePiece"
+ tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+ processor_text_fn = functools.partial(
+ tokenization.preprocess_text, lower=FLAGS.do_lower_case)
+
+ if FLAGS.tfds_params:
+ processor = classifier_data_lib.TfdsProcessor(
+ tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
+ return classifier_data_lib.generate_tf_record_from_data_file(
+ processor,
+ None,
+ tokenizer,
+ train_data_output_path=FLAGS.train_data_output_path,
+ eval_data_output_path=FLAGS.eval_data_output_path,
+ test_data_output_path=FLAGS.test_data_output_path,
+ max_seq_length=FLAGS.max_seq_length)
+ else:
+ processors = {
+ "ax":
+ classifier_data_lib.AxProcessor,
+ "cola":
+ classifier_data_lib.ColaProcessor,
+ "imdb":
+ classifier_data_lib.ImdbProcessor,
+ "mnli":
+ functools.partial(
+ classifier_data_lib.MnliProcessor, mnli_type=FLAGS.mnli_type),
+ "mrpc":
+ classifier_data_lib.MrpcProcessor,
+ "qnli":
+ classifier_data_lib.QnliProcessor,
+ "qqp":
+ classifier_data_lib.QqpProcessor,
+ "rte":
+ classifier_data_lib.RteProcessor,
+ "sst-2":
+ classifier_data_lib.SstProcessor,
+ "sts-b":
+ classifier_data_lib.StsBProcessor,
+ "xnli":
+ functools.partial(
+ classifier_data_lib.XnliProcessor,
+ language=FLAGS.xnli_language),
+ "paws-x":
+ functools.partial(
+ classifier_data_lib.PawsxProcessor,
+ language=FLAGS.pawsx_language),
+ "wnli":
+ classifier_data_lib.WnliProcessor,
+ "xtreme-xnli":
+ functools.partial(
+ classifier_data_lib.XtremeXnliProcessor,
+ translated_data_dir=FLAGS.translated_input_data_dir,
+ only_use_en_dev=FLAGS.only_use_en_dev),
+ "xtreme-paws-x":
+ functools.partial(
+ classifier_data_lib.XtremePawsxProcessor,
+ translated_data_dir=FLAGS.translated_input_data_dir,
+ only_use_en_dev=FLAGS.only_use_en_dev),
+ "ax-g":
+ classifier_data_lib.AXgProcessor,
+ "superglue-rte":
+ classifier_data_lib.SuperGLUERTEProcessor,
+ "cb":
+ classifier_data_lib.CBProcessor,
+ "boolq":
+ classifier_data_lib.BoolQProcessor,
+ "wic":
+ classifier_data_lib.WnliProcessor,
+ }
+ task_name = FLAGS.classification_task_name.lower()
+ if task_name not in processors:
+ raise ValueError("Task not found: %s" % (task_name))
+
+ processor = processors[task_name](process_text_fn=processor_text_fn)
+ return classifier_data_lib.generate_tf_record_from_data_file(
+ processor,
+ FLAGS.input_data_dir,
+ tokenizer,
+ train_data_output_path=FLAGS.train_data_output_path,
+ eval_data_output_path=FLAGS.eval_data_output_path,
+ test_data_output_path=FLAGS.test_data_output_path,
+ max_seq_length=FLAGS.max_seq_length)
+
+
+def generate_regression_dataset():
+ """Generates regression dataset and returns input meta data."""
+ if FLAGS.tokenization == "WordPiece":
+ tokenizer = tokenization.FullTokenizer(
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+ processor_text_fn = tokenization.convert_to_unicode
+ else:
+ assert FLAGS.tokenization == "SentencePiece"
+ tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+ processor_text_fn = functools.partial(
+ tokenization.preprocess_text, lower=FLAGS.do_lower_case)
+
+ if FLAGS.tfds_params:
+ processor = classifier_data_lib.TfdsProcessor(
+ tfds_params=FLAGS.tfds_params, process_text_fn=processor_text_fn)
+ return classifier_data_lib.generate_tf_record_from_data_file(
+ processor,
+ None,
+ tokenizer,
+ train_data_output_path=FLAGS.train_data_output_path,
+ eval_data_output_path=FLAGS.eval_data_output_path,
+ test_data_output_path=FLAGS.test_data_output_path,
+ max_seq_length=FLAGS.max_seq_length)
+ else:
+ raise ValueError("No data processor found for the given regression task.")
+
+
+def generate_squad_dataset():
+ """Generates squad training dataset and returns input meta data."""
+ assert FLAGS.squad_data_file
+ if FLAGS.tokenization == "WordPiece":
+ return squad_lib_wp.generate_tf_record_from_json_file(
+ input_file_path=FLAGS.squad_data_file,
+ vocab_file_path=FLAGS.vocab_file,
+ output_path=FLAGS.train_data_output_path,
+ translated_input_folder=FLAGS.translated_squad_data_folder,
+ max_seq_length=FLAGS.max_seq_length,
+ do_lower_case=FLAGS.do_lower_case,
+ max_query_length=FLAGS.max_query_length,
+ doc_stride=FLAGS.doc_stride,
+ version_2_with_negative=FLAGS.version_2_with_negative,
+ xlnet_format=FLAGS.xlnet_format)
+ else:
+ assert FLAGS.tokenization == "SentencePiece"
+ return squad_lib_sp.generate_tf_record_from_json_file(
+ input_file_path=FLAGS.squad_data_file,
+ sp_model_file=FLAGS.sp_model_file,
+ output_path=FLAGS.train_data_output_path,
+ translated_input_folder=FLAGS.translated_squad_data_folder,
+ max_seq_length=FLAGS.max_seq_length,
+ do_lower_case=FLAGS.do_lower_case,
+ max_query_length=FLAGS.max_query_length,
+ doc_stride=FLAGS.doc_stride,
+ xlnet_format=FLAGS.xlnet_format,
+ version_2_with_negative=FLAGS.version_2_with_negative)
+
+
+def generate_retrieval_dataset():
+ """Generate retrieval test and dev dataset and returns input meta data."""
+ assert (FLAGS.input_data_dir and FLAGS.retrieval_task_name)
+ if FLAGS.tokenization == "WordPiece":
+ tokenizer = tokenization.FullTokenizer(
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+ processor_text_fn = tokenization.convert_to_unicode
+ else:
+ assert FLAGS.tokenization == "SentencePiece"
+ tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+ processor_text_fn = functools.partial(
+ tokenization.preprocess_text, lower=FLAGS.do_lower_case)
+
+ processors = {
+ "bucc": sentence_retrieval_lib.BuccProcessor,
+ "tatoeba": sentence_retrieval_lib.TatoebaProcessor,
+ }
+
+ task_name = FLAGS.retrieval_task_name.lower()
+ if task_name not in processors:
+ raise ValueError("Task not found: %s" % task_name)
+
+ processor = processors[task_name](process_text_fn=processor_text_fn)
+
+ return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
+ processor, FLAGS.input_data_dir, tokenizer, FLAGS.eval_data_output_path,
+ FLAGS.test_data_output_path, FLAGS.max_seq_length)
+
+
+def generate_tagging_dataset():
+ """Generates tagging dataset."""
+ processors = {
+ "panx":
+ functools.partial(
+ tagging_data_lib.PanxProcessor,
+ only_use_en_train=FLAGS.tagging_only_use_en_train,
+ only_use_en_dev=FLAGS.only_use_en_dev),
+ "udpos":
+ functools.partial(
+ tagging_data_lib.UdposProcessor,
+ only_use_en_train=FLAGS.tagging_only_use_en_train,
+ only_use_en_dev=FLAGS.only_use_en_dev),
+ }
+ task_name = FLAGS.tagging_task_name.lower()
+ if task_name not in processors:
+ raise ValueError("Task not found: %s" % task_name)
+
+ if FLAGS.tokenization == "WordPiece":
+ tokenizer = tokenization.FullTokenizer(
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+ processor_text_fn = tokenization.convert_to_unicode
+ elif FLAGS.tokenization == "SentencePiece":
+ tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+ processor_text_fn = functools.partial(
+ tokenization.preprocess_text, lower=FLAGS.do_lower_case)
+ else:
+ raise ValueError("Unsupported tokenization: %s" % FLAGS.tokenization)
+
+ processor = processors[task_name]()
+ return tagging_data_lib.generate_tf_record_from_data_file(
+ processor, FLAGS.input_data_dir, tokenizer, FLAGS.max_seq_length,
+ FLAGS.train_data_output_path, FLAGS.eval_data_output_path,
+ FLAGS.test_data_output_path, processor_text_fn)
+
+
+def main(_):
+ if FLAGS.tokenization == "WordPiece":
+ if not FLAGS.vocab_file:
+ raise ValueError(
+ "FLAG vocab_file for word-piece tokenizer is not specified.")
+ else:
+ assert FLAGS.tokenization == "SentencePiece"
+ if not FLAGS.sp_model_file:
+ raise ValueError(
+ "FLAG sp_model_file for sentence-piece tokenizer is not specified.")
+
+ if FLAGS.fine_tuning_task_type != "retrieval":
+ flags.mark_flag_as_required("train_data_output_path")
+
+ if FLAGS.fine_tuning_task_type == "classification":
+ input_meta_data = generate_classifier_dataset()
+ elif FLAGS.fine_tuning_task_type == "regression":
+ input_meta_data = generate_regression_dataset()
+ elif FLAGS.fine_tuning_task_type == "retrieval":
+ input_meta_data = generate_retrieval_dataset()
+ elif FLAGS.fine_tuning_task_type == "squad":
+ input_meta_data = generate_squad_dataset()
+ else:
+ assert FLAGS.fine_tuning_task_type == "tagging"
+ input_meta_data = generate_tagging_dataset()
+
+ tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
+ with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
+ writer.write(json.dumps(input_meta_data, indent=4) + "\n")
+
+
+if __name__ == "__main__":
+ flags.mark_flag_as_required("meta_data_file_path")
+ app.run(main)
diff --git a/modeling/official/nlp/data/create_pretraining_data.py b/modeling/official/nlp/data/create_pretraining_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..55a4b8489a870a4fc88b4c1adb4fc54a086a2131
--- /dev/null
+++ b/modeling/official/nlp/data/create_pretraining_data.py
@@ -0,0 +1,669 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Create masked LM/next sentence masked_lm TF examples for BERT."""
+
+import collections
+import itertools
+import random
+
+# Import libraries
+from absl import app
+from absl import flags
+from absl import logging
+import tensorflow as tf, tf_keras
+
+from official.nlp.tools import tokenization
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string("input_file", None,
+ "Input raw text file (or comma-separated list of files).")
+
+flags.DEFINE_string(
+ "output_file", None,
+ "Output TF example file (or comma-separated list of files).")
+
+flags.DEFINE_string("vocab_file", None,
+ "The vocabulary file that the BERT model was trained on.")
+
+flags.DEFINE_bool(
+ "do_lower_case", True,
+ "Whether to lower case the input text. Should be True for uncased "
+ "models and False for cased models.")
+
+flags.DEFINE_bool(
+ "do_whole_word_mask", False,
+ "Whether to use whole word masking rather than per-WordPiece masking.")
+
+flags.DEFINE_integer(
+ "max_ngram_size", None,
+ "Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
+ "weighting scheme to favor shorter n-grams. "
+ "Note: `--do_whole_word_mask=True` must also be set when n-gram masking.")
+
+flags.DEFINE_bool(
+ "gzip_compress", False,
+ "Whether to use `GZIP` compress option to get compressed TFRecord files.")
+
+flags.DEFINE_bool(
+ "use_v2_feature_names", False,
+ "Whether to use the feature names consistent with the models.")
+
+flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
+
+flags.DEFINE_integer("max_predictions_per_seq", 20,
+ "Maximum number of masked LM predictions per sequence.")
+
+flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
+
+flags.DEFINE_integer(
+ "dupe_factor", 10,
+ "Number of times to duplicate the input data (with different masks).")
+
+flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
+
+flags.DEFINE_float(
+ "short_seq_prob", 0.1,
+ "Probability of creating sequences which are shorter than the "
+ "maximum length.")
+
+
+class TrainingInstance(object):
+ """A single training instance (sentence pair)."""
+
+ def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
+ is_random_next):
+ self.tokens = tokens
+ self.segment_ids = segment_ids
+ self.is_random_next = is_random_next
+ self.masked_lm_positions = masked_lm_positions
+ self.masked_lm_labels = masked_lm_labels
+
+ def __str__(self):
+ s = ""
+ s += "tokens: %s\n" % (" ".join(
+ [tokenization.printable_text(x) for x in self.tokens]))
+ s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
+ s += "is_random_next: %s\n" % self.is_random_next
+ s += "masked_lm_positions: %s\n" % (" ".join(
+ [str(x) for x in self.masked_lm_positions]))
+ s += "masked_lm_labels: %s\n" % (" ".join(
+ [tokenization.printable_text(x) for x in self.masked_lm_labels]))
+ s += "\n"
+ return s
+
+ def __repr__(self):
+ return self.__str__()
+
+
+def write_instance_to_example_files(instances, tokenizer, max_seq_length,
+ max_predictions_per_seq, output_files,
+ gzip_compress, use_v2_feature_names):
+ """Creates TF example files from `TrainingInstance`s."""
+ writers = []
+ for output_file in output_files:
+ writers.append(
+ tf.io.TFRecordWriter(
+ output_file, options="GZIP" if gzip_compress else ""))
+
+ writer_index = 0
+
+ total_written = 0
+ for (inst_index, instance) in enumerate(instances):
+ input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
+ input_mask = [1] * len(input_ids)
+ segment_ids = list(instance.segment_ids)
+ assert len(input_ids) <= max_seq_length
+
+ while len(input_ids) < max_seq_length:
+ input_ids.append(0)
+ input_mask.append(0)
+ segment_ids.append(0)
+
+ assert len(input_ids) == max_seq_length
+ assert len(input_mask) == max_seq_length
+ assert len(segment_ids) == max_seq_length
+
+ masked_lm_positions = list(instance.masked_lm_positions)
+ masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
+ masked_lm_weights = [1.0] * len(masked_lm_ids)
+
+ while len(masked_lm_positions) < max_predictions_per_seq:
+ masked_lm_positions.append(0)
+ masked_lm_ids.append(0)
+ masked_lm_weights.append(0.0)
+
+ next_sentence_label = 1 if instance.is_random_next else 0
+
+ features = collections.OrderedDict()
+ if use_v2_feature_names:
+ features["input_word_ids"] = create_int_feature(input_ids)
+ features["input_type_ids"] = create_int_feature(segment_ids)
+ else:
+ features["input_ids"] = create_int_feature(input_ids)
+ features["segment_ids"] = create_int_feature(segment_ids)
+
+ features["input_mask"] = create_int_feature(input_mask)
+ features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
+ features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
+ features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
+ features["next_sentence_labels"] = create_int_feature([next_sentence_label])
+
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+
+ writers[writer_index].write(tf_example.SerializeToString())
+ writer_index = (writer_index + 1) % len(writers)
+
+ total_written += 1
+
+ if inst_index < 20:
+ logging.info("*** Example ***")
+ logging.info("tokens: %s", " ".join(
+ [tokenization.printable_text(x) for x in instance.tokens]))
+
+ for feature_name in features.keys():
+ feature = features[feature_name]
+ values = []
+ if feature.int64_list.value:
+ values = feature.int64_list.value
+ elif feature.float_list.value:
+ values = feature.float_list.value
+ logging.info("%s: %s", feature_name, " ".join([str(x) for x in values]))
+
+ for writer in writers:
+ writer.close()
+
+ logging.info("Wrote %d total instances", total_written)
+
+
+def create_int_feature(values):
+ feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
+ return feature
+
+
+def create_float_feature(values):
+ feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
+ return feature
+
+
+def create_training_instances(input_files,
+ tokenizer,
+ max_seq_length,
+ dupe_factor,
+ short_seq_prob,
+ masked_lm_prob,
+ max_predictions_per_seq,
+ rng,
+ do_whole_word_mask=False,
+ max_ngram_size=None):
+ """Create `TrainingInstance`s from raw text."""
+ all_documents = [[]]
+
+ # Input file format:
+ # (1) One sentence per line. These should ideally be actual sentences, not
+ # entire paragraphs or arbitrary spans of text. (Because we use the
+ # sentence boundaries for the "next sentence prediction" task).
+ # (2) Blank lines between documents. Document boundaries are needed so
+ # that the "next sentence prediction" task doesn't span between documents.
+ for input_file in input_files:
+ with tf.io.gfile.GFile(input_file, "rb") as reader:
+ while True:
+ line = tokenization.convert_to_unicode(reader.readline())
+ if not line:
+ break
+ line = line.strip()
+
+ # Empty lines are used as document delimiters
+ if not line:
+ all_documents.append([])
+ tokens = tokenizer.tokenize(line)
+ if tokens:
+ all_documents[-1].append(tokens)
+
+ # Remove empty documents
+ all_documents = [x for x in all_documents if x]
+ rng.shuffle(all_documents)
+
+ vocab_words = list(tokenizer.vocab.keys())
+ instances = []
+ for _ in range(dupe_factor):
+ for document_index in range(len(all_documents)):
+ instances.extend(
+ create_instances_from_document(
+ all_documents, document_index, max_seq_length, short_seq_prob,
+ masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
+ do_whole_word_mask, max_ngram_size))
+
+ rng.shuffle(instances)
+ return instances
+
+
+def create_instances_from_document(
+ all_documents, document_index, max_seq_length, short_seq_prob,
+ masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
+ do_whole_word_mask=False,
+ max_ngram_size=None):
+ """Creates `TrainingInstance`s for a single document."""
+ document = all_documents[document_index]
+
+ # Account for [CLS], [SEP], [SEP]
+ max_num_tokens = max_seq_length - 3
+
+ # We *usually* want to fill up the entire sequence since we are padding
+ # to `max_seq_length` anyways, so short sequences are generally wasted
+ # computation. However, we *sometimes*
+ # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
+ # sequences to minimize the mismatch between pre-training and fine-tuning.
+ # The `target_seq_length` is just a rough target however, whereas
+ # `max_seq_length` is a hard limit.
+ target_seq_length = max_num_tokens
+ if rng.random() < short_seq_prob:
+ target_seq_length = rng.randint(2, max_num_tokens)
+
+ # We DON'T just concatenate all of the tokens from a document into a long
+ # sequence and choose an arbitrary split point because this would make the
+ # next sentence prediction task too easy. Instead, we split the input into
+ # segments "A" and "B" based on the actual "sentences" provided by the user
+ # input.
+ instances = []
+ current_chunk = []
+ current_length = 0
+ i = 0
+ while i < len(document):
+ segment = document[i]
+ current_chunk.append(segment)
+ current_length += len(segment)
+ if i == len(document) - 1 or current_length >= target_seq_length:
+ if current_chunk:
+ # `a_end` is how many segments from `current_chunk` go into the `A`
+ # (first) sentence.
+ a_end = 1
+ if len(current_chunk) >= 2:
+ a_end = rng.randint(1, len(current_chunk) - 1)
+
+ tokens_a = []
+ for j in range(a_end):
+ tokens_a.extend(current_chunk[j])
+
+ tokens_b = []
+ # Random next
+ is_random_next = False
+ if len(current_chunk) == 1 or rng.random() < 0.5:
+ is_random_next = True
+ target_b_length = target_seq_length - len(tokens_a)
+
+ # This should rarely go for more than one iteration for large
+ # corpora. However, just to be careful, we try to make sure that
+ # the random document is not the same as the document
+ # we're processing.
+ for _ in range(10):
+ random_document_index = rng.randint(0, len(all_documents) - 1)
+ if random_document_index != document_index:
+ break
+
+ random_document = all_documents[random_document_index]
+ random_start = rng.randint(0, len(random_document) - 1)
+ for j in range(random_start, len(random_document)):
+ tokens_b.extend(random_document[j])
+ if len(tokens_b) >= target_b_length:
+ break
+ # We didn't actually use these segments so we "put them back" so
+ # they don't go to waste.
+ num_unused_segments = len(current_chunk) - a_end
+ i -= num_unused_segments
+ # Actual next
+ else:
+ is_random_next = False
+ for j in range(a_end, len(current_chunk)):
+ tokens_b.extend(current_chunk[j])
+ truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
+
+ assert len(tokens_a) >= 1
+ assert len(tokens_b) >= 1
+
+ tokens = []
+ segment_ids = []
+ tokens.append("[CLS]")
+ segment_ids.append(0)
+ for token in tokens_a:
+ tokens.append(token)
+ segment_ids.append(0)
+
+ tokens.append("[SEP]")
+ segment_ids.append(0)
+
+ for token in tokens_b:
+ tokens.append(token)
+ segment_ids.append(1)
+ tokens.append("[SEP]")
+ segment_ids.append(1)
+
+ (tokens, masked_lm_positions,
+ masked_lm_labels) = create_masked_lm_predictions(
+ tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
+ do_whole_word_mask, max_ngram_size)
+ instance = TrainingInstance(
+ tokens=tokens,
+ segment_ids=segment_ids,
+ is_random_next=is_random_next,
+ masked_lm_positions=masked_lm_positions,
+ masked_lm_labels=masked_lm_labels)
+ instances.append(instance)
+ current_chunk = []
+ current_length = 0
+ i += 1
+
+ return instances
+
+
+MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
+ ["index", "label"])
+
+# A _Gram is a [half-open) interval of token indices which form a word.
+# E.g.,
+# words: ["The", "doghouse"]
+# tokens: ["The", "dog", "##house"]
+# grams: [(0,1), (1,3)]
+_Gram = collections.namedtuple("_Gram", ["begin", "end"])
+
+
+def _window(iterable, size):
+ """Helper to create a sliding window iterator with a given size.
+
+ E.g.,
+ input = [1, 2, 3, 4]
+ _window(input, 1) => [1], [2], [3], [4]
+ _window(input, 2) => [1, 2], [2, 3], [3, 4]
+ _window(input, 3) => [1, 2, 3], [2, 3, 4]
+ _window(input, 4) => [1, 2, 3, 4]
+ _window(input, 5) => None
+
+ Args:
+ iterable: elements to iterate over.
+ size: size of the window.
+
+ Yields:
+ Elements of `iterable` batched into a sliding window of length `size`.
+ """
+ i = iter(iterable)
+ window = []
+ try:
+ for e in range(0, size):
+ window.append(next(i))
+ yield window
+ except StopIteration:
+ # handle the case where iterable's length is less than the window size.
+ return
+ for e in i:
+ window = window[1:] + [e]
+ yield window
+
+
+def _contiguous(sorted_grams):
+ """Test whether a sequence of grams is contiguous.
+
+ Args:
+ sorted_grams: _Grams which are sorted in increasing order.
+ Returns:
+ True if `sorted_grams` are touching each other.
+
+ E.g.,
+ _contiguous([(1, 4), (4, 5), (5, 10)]) == True
+ _contiguous([(1, 2), (4, 5)]) == False
+ """
+ for a, b in _window(sorted_grams, 2):
+ if a.end != b.begin:
+ return False
+ return True
+
+
+def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
+ """Create a list of masking {1, ..., n}-grams from a list of one-grams.
+
+ This is an extention of 'whole word masking' to mask multiple, contiguous
+ words such as (e.g., "the red boat").
+
+ Each input gram represents the token indices of a single word,
+ words: ["the", "red", "boat"]
+ tokens: ["the", "red", "boa", "##t"]
+ grams: [(0,1), (1,2), (2,4)]
+
+ For a `max_ngram_size` of three, possible outputs masks include:
+ 1-grams: (0,1), (1,2), (2,4)
+ 2-grams: (0,2), (1,4)
+ 3-grams; (0,4)
+
+ Output masks will not overlap and contain less than `max_masked_tokens` total
+ tokens. E.g., for the example above with `max_masked_tokens` as three,
+ valid outputs are,
+ [(0,1), (1,2)] # "the", "red" covering two tokens
+ [(1,2), (2,4)] # "red", "boa", "##t" covering three tokens
+
+ The length of the selected n-gram follows a zipf weighting to
+ favor shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).
+
+ Args:
+ grams: List of one-grams.
+ max_ngram_size: Maximum number of contiguous one-grams combined to create
+ an n-gram.
+ max_masked_tokens: Maximum total number of tokens to be masked.
+ rng: `random.Random` generator.
+
+ Returns:
+ A list of n-grams to be used as masks.
+ """
+ if not grams:
+ return None
+
+ grams = sorted(grams)
+ num_tokens = grams[-1].end
+
+ # Ensure our grams are valid (i.e., they don't overlap).
+ for a, b in _window(grams, 2):
+ if a.end > b.begin:
+ raise ValueError("overlapping grams: {}".format(grams))
+
+ # Build map from n-gram length to list of n-grams.
+ ngrams = {i: [] for i in range(1, max_ngram_size+1)}
+ for gram_size in range(1, max_ngram_size+1):
+ for g in _window(grams, gram_size):
+ if _contiguous(g):
+ # Add an n-gram which spans these one-grams.
+ ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))
+
+ # Shuffle each list of n-grams.
+ for v in ngrams.values():
+ rng.shuffle(v)
+
+ # Create the weighting for n-gram length selection.
+ # Stored cummulatively for `random.choices` below.
+ cummulative_weights = list(
+ itertools.accumulate([1./n for n in range(1, max_ngram_size+1)]))
+
+ output_ngrams = []
+ # Keep a bitmask of which tokens have been masked.
+ masked_tokens = [False] * num_tokens
+ # Loop until we have enough masked tokens or there are no more candidate
+ # n-grams of any length.
+ # Each code path should ensure one or more elements from `ngrams` are removed
+ # to guarentee this loop terminates.
+ while (sum(masked_tokens) < max_masked_tokens and
+ sum(len(s) for s in ngrams.values())):
+ # Pick an n-gram size based on our weights.
+ sz = random.choices(range(1, max_ngram_size+1),
+ cum_weights=cummulative_weights)[0]
+
+ # Ensure this size doesn't result in too many masked tokens.
+ # E.g., a two-gram contains _at least_ two tokens.
+ if sum(masked_tokens) + sz > max_masked_tokens:
+ # All n-grams of this length are too long and can be removed from
+ # consideration.
+ ngrams[sz].clear()
+ continue
+
+ # All of the n-grams of this size have been used.
+ if not ngrams[sz]:
+ continue
+
+ # Choose a random n-gram of the given size.
+ gram = ngrams[sz].pop()
+ num_gram_tokens = gram.end-gram.begin
+
+ # Check if this would add too many tokens.
+ if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
+ continue
+
+ # Check if any of the tokens in this gram have already been masked.
+ if sum(masked_tokens[gram.begin:gram.end]):
+ continue
+
+ # Found a usable n-gram! Mark its tokens as masked and add it to return.
+ masked_tokens[gram.begin:gram.end] = [True] * (gram.end-gram.begin)
+ output_ngrams.append(gram)
+ return output_ngrams
+
+
+def _wordpieces_to_grams(tokens):
+ """Reconstitue grams (words) from `tokens`.
+
+ E.g.,
+ tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
+ grams: [ [1,2), [2, 4), [4,5) , [5, 6)]
+
+ Args:
+ tokens: list of wordpieces
+ Returns:
+ List of _Grams representing spans of whole words
+ (without "[CLS]" and "[SEP]").
+ """
+ grams = []
+ gram_start_pos = None
+ for i, token in enumerate(tokens):
+ if gram_start_pos is not None and token.startswith("##"):
+ continue
+ if gram_start_pos is not None:
+ grams.append(_Gram(gram_start_pos, i))
+ if token not in ["[CLS]", "[SEP]"]:
+ gram_start_pos = i
+ else:
+ gram_start_pos = None
+ if gram_start_pos is not None:
+ grams.append(_Gram(gram_start_pos, len(tokens)))
+ return grams
+
+
+def create_masked_lm_predictions(tokens, masked_lm_prob,
+ max_predictions_per_seq, vocab_words, rng,
+ do_whole_word_mask,
+ max_ngram_size=None):
+ """Creates the predictions for the masked LM objective."""
+ if do_whole_word_mask:
+ grams = _wordpieces_to_grams(tokens)
+ else:
+ # Here we consider each token to be a word to allow for sub-word masking.
+ if max_ngram_size:
+ raise ValueError("cannot use ngram masking without whole word masking")
+ grams = [_Gram(i, i+1) for i in range(0, len(tokens))
+ if tokens[i] not in ["[CLS]", "[SEP]"]]
+
+ num_to_predict = min(max_predictions_per_seq,
+ max(1, int(round(len(tokens) * masked_lm_prob))))
+ # Generate masks. If `max_ngram_size` in [0, None] it means we're doing
+ # whole word masking or token level masking. Both of these can be treated
+ # as the `max_ngram_size=1` case.
+ masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
+ num_to_predict, rng)
+ masked_lms = []
+ output_tokens = list(tokens)
+ for gram in masked_grams:
+ # 80% of the time, replace all n-gram tokens with [MASK]
+ if rng.random() < 0.8:
+ replacement_action = lambda idx: "[MASK]"
+ else:
+ # 10% of the time, keep all the original n-gram tokens.
+ if rng.random() < 0.5:
+ replacement_action = lambda idx: tokens[idx]
+ # 10% of the time, replace each n-gram token with a random word.
+ else:
+ replacement_action = lambda idx: rng.choice(vocab_words)
+
+ for idx in range(gram.begin, gram.end):
+ output_tokens[idx] = replacement_action(idx)
+ masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))
+
+ assert len(masked_lms) <= num_to_predict
+ masked_lms = sorted(masked_lms, key=lambda x: x.index)
+
+ masked_lm_positions = []
+ masked_lm_labels = []
+ for p in masked_lms:
+ masked_lm_positions.append(p.index)
+ masked_lm_labels.append(p.label)
+
+ return (output_tokens, masked_lm_positions, masked_lm_labels)
+
+
+def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
+ """Truncates a pair of sequences to a maximum sequence length."""
+ while True:
+ total_length = len(tokens_a) + len(tokens_b)
+ if total_length <= max_num_tokens:
+ break
+
+ trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
+ assert len(trunc_tokens) >= 1
+
+ # We want to sometimes truncate from the front and sometimes from the
+ # back to add more randomness and avoid biases.
+ if rng.random() < 0.5:
+ del trunc_tokens[0]
+ else:
+ trunc_tokens.pop()
+
+
+def main(_):
+ tokenizer = tokenization.FullTokenizer(
+ vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
+
+ input_files = []
+ for input_pattern in FLAGS.input_file.split(","):
+ input_files.extend(tf.io.gfile.glob(input_pattern))
+
+ logging.info("*** Reading from input files ***")
+ for input_file in input_files:
+ logging.info(" %s", input_file)
+
+ rng = random.Random(FLAGS.random_seed)
+ instances = create_training_instances(
+ input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
+ FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
+ rng, FLAGS.do_whole_word_mask, FLAGS.max_ngram_size)
+
+ output_files = FLAGS.output_file.split(",")
+ logging.info("*** Writing to output files ***")
+ for output_file in output_files:
+ logging.info(" %s", output_file)
+
+ write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
+ FLAGS.max_predictions_per_seq, output_files,
+ FLAGS.gzip_compress,
+ FLAGS.use_v2_feature_names)
+
+
+if __name__ == "__main__":
+ flags.mark_flag_as_required("input_file")
+ flags.mark_flag_as_required("output_file")
+ flags.mark_flag_as_required("vocab_file")
+ app.run(main)
diff --git a/modeling/official/nlp/data/create_pretraining_data_test.py b/modeling/official/nlp/data/create_pretraining_data_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ae8b86ca22a46dc51e1e45e3ab5b08d91960672
--- /dev/null
+++ b/modeling/official/nlp/data/create_pretraining_data_test.py
@@ -0,0 +1,128 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for official.nlp.data.create_pretraining_data."""
+import random
+
+import tensorflow as tf, tf_keras
+
+from official.nlp.data import create_pretraining_data as cpd
+
+_VOCAB_WORDS = ["vocab_1", "vocab_2"]
+
+
+class CreatePretrainingDataTest(tf.test.TestCase):
+
+ def assertTokens(self, input_tokens, output_tokens, masked_positions,
+ masked_labels):
+ # Ensure the masked positions are unique.
+ self.assertCountEqual(masked_positions, set(masked_positions))
+
+ # Ensure we can reconstruct the input from the output.
+ reconstructed_tokens = output_tokens
+ for pos, label in zip(masked_positions, masked_labels):
+ reconstructed_tokens[pos] = label
+ self.assertEqual(input_tokens, reconstructed_tokens)
+
+ # Ensure each label is valid.
+ for pos, label in zip(masked_positions, masked_labels):
+ output_token = output_tokens[pos]
+ if (output_token == "[MASK]" or output_token in _VOCAB_WORDS or
+ output_token == input_tokens[pos]):
+ continue
+ self.fail("invalid mask value: {}".format(output_token))
+
+ def test_wordpieces_to_grams(self):
+ tests = [
+ (["That", "cone"], [(0, 1), (1, 2)]),
+ (["That", "cone", "##s"], [(0, 1), (1, 3)]),
+ (["Swit", "##zer", "##land"], [(0, 3)]),
+ (["[CLS]", "Up", "##dog"], [(1, 3)]),
+ (["[CLS]", "Up", "##dog", "[SEP]", "Down"], [(1, 3), (4, 5)]),
+ ]
+ for inp, expected in tests:
+ output = cpd._wordpieces_to_grams(inp)
+ self.assertEqual(expected, output)
+
+ def test_window(self):
+ input_list = [1, 2, 3, 4]
+ window_outputs = [
+ (1, [[1], [2], [3], [4]]),
+ (2, [[1, 2], [2, 3], [3, 4]]),
+ (3, [[1, 2, 3], [2, 3, 4]]),
+ (4, [[1, 2, 3, 4]]),
+ (5, []),
+ ]
+ for window, expected in window_outputs:
+ output = cpd._window(input_list, window)
+ self.assertEqual(expected, list(output))
+
+ def test_create_masked_lm_predictions(self):
+ tokens = ["[CLS]", "a", "##a", "b", "##b", "c", "##c", "[SEP]"]
+ rng = random.Random(123)
+ for _ in range(0, 5):
+ output_tokens, masked_positions, masked_labels = (
+ cpd.create_masked_lm_predictions(
+ tokens=tokens,
+ masked_lm_prob=1.0,
+ max_predictions_per_seq=3,
+ vocab_words=_VOCAB_WORDS,
+ rng=rng,
+ do_whole_word_mask=False,
+ max_ngram_size=None))
+ self.assertEqual(len(masked_positions), 3)
+ self.assertEqual(len(masked_labels), 3)
+ self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
+
+ def test_create_masked_lm_predictions_whole_word(self):
+ tokens = ["[CLS]", "a", "##a", "b", "##b", "c", "##c", "[SEP]"]
+ rng = random.Random(345)
+ for _ in range(0, 5):
+ output_tokens, masked_positions, masked_labels = (
+ cpd.create_masked_lm_predictions(
+ tokens=tokens,
+ masked_lm_prob=1.0,
+ max_predictions_per_seq=3,
+ vocab_words=_VOCAB_WORDS,
+ rng=rng,
+ do_whole_word_mask=True,
+ max_ngram_size=None))
+ # since we can't get exactly three tokens without breaking a word we
+ # only take two.
+ self.assertEqual(len(masked_positions), 2)
+ self.assertEqual(len(masked_labels), 2)
+ self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
+ # ensure that we took an entire word.
+ self.assertIn(masked_labels, [["a", "##a"], ["b", "##b"], ["c", "##c"]])
+
+ def test_create_masked_lm_predictions_ngram(self):
+ tokens = ["[CLS]"] + ["tok{}".format(i) for i in range(0, 512)] + ["[SEP]"]
+ rng = random.Random(345)
+ for _ in range(0, 5):
+ output_tokens, masked_positions, masked_labels = (
+ cpd.create_masked_lm_predictions(
+ tokens=tokens,
+ masked_lm_prob=1.0,
+ max_predictions_per_seq=76,
+ vocab_words=_VOCAB_WORDS,
+ rng=rng,
+ do_whole_word_mask=True,
+ max_ngram_size=3))
+ self.assertEqual(len(masked_positions), 76)
+ self.assertEqual(len(masked_labels), 76)
+ self.assertTokens(tokens, output_tokens, masked_positions, masked_labels)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/nlp/data/create_xlnet_pretraining_data.py b/modeling/official/nlp/data/create_xlnet_pretraining_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5e142b9bd01e8bc055f83f3c8efe369ce314fe8
--- /dev/null
+++ b/modeling/official/nlp/data/create_xlnet_pretraining_data.py
@@ -0,0 +1,721 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Create LM TF examples for XLNet."""
+
+import dataclasses
+import json
+import math
+import os
+
+import random
+from typing import Iterable, Mapping, List, Optional, Tuple
+import unicodedata
+
+# Import libraries
+
+from absl import app
+from absl import flags
+from absl import logging
+
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from official.nlp.tools import tokenization
+
+special_symbols = {
+ "": 0,
+ "": 1,
+ "": 2,
+ "": 3,
+ "": 4,
+ "": 5,
+ "": 6,
+ "": 7,
+ "": 8,
+}
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_integer("seq_length", 512,
+ help="Sequence length.")
+flags.DEFINE_integer("reuse_length", 256,
+ help="Number of token that can be reused as memory. "
+ "Could be half of `seq_len`.")
+flags.DEFINE_string("input_file", None,
+ "Input raw text file (or comma-separated list of files).")
+flags.DEFINE_string(
+ "save_dir", None,
+ "Directory for saving processed data.")
+flags.DEFINE_string("sp_model_file", "",
+ "The path to the model used by sentence piece tokenizer.")
+flags.DEFINE_bool("use_eod_token", True,
+ "Whether or not to include EOD tokens.")
+flags.DEFINE_bool("bi_data", True, "Whether or not to use bi-directional data.")
+flags.DEFINE_bool(
+ "do_lower_case", True,
+ "Whether to lower case the input text. Should be True for uncased "
+ "models and False for cased models.")
+flags.DEFINE_integer("per_host_batch_size", 32, "Batch size per host.")
+flags.DEFINE_integer("num_cores_per_host", 16,
+ "The number of (TPU) cores per host.")
+flags.DEFINE_string("prefix", "", "Filename prefix.")
+flags.DEFINE_string("suffix", "", "Filename suffix.")
+
+flags.DEFINE_integer("task_id", None,
+ "The id of the current task.")
+flags.DEFINE_integer("num_tasks", None,
+ "The total number of tasks.")
+flags.DEFINE_integer("num_passes", 1, "The number of times to run the script.")
+
+
+@dataclasses.dataclass
+class TrainingInstance:
+ """Representation of a single XLNet Pretraining instance."""
+ data: Iterable[int]
+ segment_ids: Iterable[int]
+ boundary_indices: Iterable[int]
+ label: int
+
+ def to_feature(self) -> Mapping[str, tf.train.Feature]:
+ feat = lambda x: tf.train.Feature(int64_list=tf.train.Int64List(value=x))
+ return dict(
+ input_word_ids=feat(self.data),
+ input_type_ids=feat(self.segment_ids),
+ boundary_indices=feat(self.boundary_indices),
+ label=feat([self.label]))
+
+ def to_example(self) -> tf.train.Example:
+ return tf.train.Example(
+ features=tf.train.Features(feature=self.to_feature()))
+
+ def __str__(self):
+ def seq_to_str(seq):
+ return " ".join([str(x) for x in seq])
+
+ s = ""
+ s += "tokens: %s\n" % seq_to_str(self.data)
+ s += "segment_ids: %s\n" % seq_to_str(self.segment_ids)
+ s += "boundary_indices: %s\n" % seq_to_str(self.boundary_indices)
+ s += "label: %s\n" % self.label
+ s += "\n"
+ return s
+
+ def __repr__(self):
+ return self.__str__()
+
+
+def _preprocess_line(line: str, do_lower_case: bool = False) -> str:
+ """Preprocesses an individual raw text line.
+
+ This function will:
+ - Remove extraneous spaces.
+ - Replace `` with ", and '' with ".
+ - Replaces accents.
+ - Applies lower casing.
+
+ Args:
+ line: The input line to preprocess.
+ do_lower_case: Whether or not to lower case the text.
+
+ Returns:
+ The preprocessed line.
+
+ """
+ line = " ".join(line.split())
+ line = line.replace("``", "\"").replace("''", "\"")
+
+ # Replace accents.
+ line = unicodedata.normalize("NFKD", line)
+ line = "".join([c for c in line if not unicodedata.combining(c)])
+
+ if do_lower_case:
+ line = line.lower()
+ return line
+
+
+def preprocess_and_tokenize_input_files(
+ input_files: Iterable[str],
+ tokenizer: tokenization.FullSentencePieceTokenizer,
+ use_eod: bool = True,
+ do_lower_case: bool = False,
+ log_example_freq: int = 100000) -> List[Tuple[np.array, np.array]]:
+ """Preprocesses and encodes raw text from input files.
+
+ This function preprocesses raw text and encodes them into tokens using a
+ `SentencePieceModel` tokenization method. This also provides the sentence
+ indicator for each token.
+
+ Args:
+ input_files: The list of input file names.
+ tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
+ use_eod: Whether or not to use an EOD indicator. If `False`, then EOD is
+ not included.
+ do_lower_case: Whether or not to apply lower casing during raw text
+ preprocessing.
+ log_example_freq: The optional field for how many lines to process before
+ emitting an info log.
+
+ Returns:
+ The preprocessed list. Each entry in the list is a tuple consisting of
+ the token IDs and the sentence IDs.
+
+ """
+ all_data = []
+ eod_symbol = special_symbols[""]
+
+ total_number_of_lines = 0
+
+ # Input file format:
+ # (1) One sentence per line. These should ideally be actual sentences, not
+ # entire paragraphs or arbitrary spans of text. (Because we use the
+ # sentence boundaries for the "next sentence prediction" task).
+ # (2) Blank lines between documents. Document boundaries are needed so
+ # that the "next sentence prediction" task doesn't span between documents.
+ for input_file in input_files:
+ line_count = 0
+ logging.info("Preprocessing %s", input_file)
+
+ all_tokens = []
+ all_sentence_ids = []
+
+ sentence_id = True
+
+ with tf.io.gfile.GFile(input_file, "rb") as reader:
+ while True:
+ line = tokenization.convert_to_unicode(reader.readline())
+ if not line:
+ break
+
+ line_count += 1
+ if line_count % log_example_freq == 0:
+ logging.info("Loading line %d", line_count)
+
+ line = line.strip()
+
+ if not line:
+ if use_eod:
+ token_ids = [eod_symbol]
+ sentence_id = not sentence_id
+ else:
+ continue
+ else:
+ preprocessed_line = _preprocess_line(
+ line=line, do_lower_case=do_lower_case)
+ token_ids = tokenization.encode_ids(
+ sp_model=tokenizer.sp_model, text=preprocessed_line)
+
+ all_tokens.extend(token_ids)
+ all_sentence_ids.extend([sentence_id] * len(token_ids))
+ sentence_id = not sentence_id
+ logging.info("Finished processing %s. Number of lines: %d",
+ input_file, line_count)
+ if line_count == 0:
+ continue
+ total_number_of_lines += line_count
+ all_tokens = np.array(all_tokens, dtype=np.int64)
+ all_sentence_ids = np.array(all_sentence_ids, dtype=bool)
+ all_data.append((all_tokens, all_sentence_ids))
+
+ logging.info("Completed text preprocessing. Total number of lines: %d",
+ total_number_of_lines)
+ return all_data
+
+
+def _reshape_to_batch_dimensions(
+ tokens: np.array,
+ sentence_ids: np.array,
+ per_host_batch_size: int) -> Tuple[np.array, np.array]:
+ """Truncates and reshapes input data with a batch major dimension.
+
+ Args:
+ tokens: The input token ids. This should have the same shape as
+ `sentence_ids`.
+ sentence_ids: The input sentence ids. This should have the same shape as
+ `token_ids`.
+ per_host_batch_size: The target per-host batch size.
+
+ Returns:
+ The tuple of reshaped tokens and sentence_ids.
+ """
+ num_steps = len(tokens) // per_host_batch_size
+ truncated_data_length = num_steps * per_host_batch_size
+
+ logging.info("per_host_batch_size: %d", per_host_batch_size)
+ logging.info("num_steps: %d", num_steps)
+ def truncate_and_reshape(a):
+ return a[:truncated_data_length].reshape((per_host_batch_size, num_steps))
+
+ return (truncate_and_reshape(tokens), truncate_and_reshape(sentence_ids))
+
+
+def _create_a_and_b_segments(
+ tokens: np.array,
+ sentence_ids: np.array,
+ begin_index: int,
+ total_length: int,
+ no_cut_probability: float = 0.5):
+ """Splits segments A and B from a single instance of tokens and sentence ids.
+
+ Args:
+ tokens: The 1D input token ids. This represents an individual entry within a
+ batch.
+ sentence_ids: The 1D input sentence ids. This represents an indivdual entry
+ within a batch. This should be the same length as `tokens`.
+ begin_index: The reference beginning index to split data.
+ total_length: The target combined length of segments A and B.
+ no_cut_probability: The probability of not cutting a segment despite
+ a cut possibly existing.
+
+ Returns:
+ A tuple consisting of A data, B data, and label.
+
+ """
+ data_length = tokens.shape[0]
+ if begin_index + total_length >= data_length:
+ logging.info("[_create_segments]: begin_index %d + total_length %d >= "
+ "data_length %d", begin_index, total_length, data_length)
+ return None
+
+ end_index = begin_index + 1
+ cut_indices = []
+
+ # Identify all indices where sentence IDs change from one to the next.
+ while end_index < data_length:
+ if sentence_ids[end_index] != sentence_ids[end_index - 1]:
+ if end_index - begin_index >= total_length:
+ break
+ cut_indices.append(end_index)
+ end_index += 1
+
+ a_begin = begin_index
+
+ if not cut_indices or random.random() < no_cut_probability:
+ # Segments A and B are contained within the same sentence.
+ label = 0
+ if not cut_indices:
+ a_end = end_index
+ else:
+ a_end = random.choice(cut_indices)
+ b_length = max(1, total_length - (a_end - a_begin))
+ b_begin = random.randint(0, data_length - 1 - b_length)
+ b_end = b_begin + b_length
+
+ while b_begin > 0 and sentence_ids[b_begin - 1] == sentence_ids[b_begin]:
+ b_begin -= 1
+ while (b_end < data_length - 1 and
+ sentence_ids[b_end - 1] == sentence_ids[b_end]):
+ b_end += 1
+ else:
+ # Segments A and B are different sentences.
+ label = 1
+ a_end = random.choice(cut_indices)
+ b_begin = a_end
+ b_end = end_index
+
+ while a_end - a_begin + b_end - b_begin > total_length:
+ if a_end - a_begin > b_end - b_begin:
+ # Delete only the right side for the LM objective.
+ a_end -= 1
+ else:
+ b_end -= 1
+ if a_end >= data_length or b_end >= data_length:
+ logging.info("[_create_segments]: a_end %d or b_end %d >= data_length %d",
+ a_end, b_end, data_length)
+ return None
+
+ a_data = tokens[a_begin: a_end]
+ b_data = tokens[b_begin: b_end]
+ return a_data, b_data, label
+
+
+def _is_functional_piece(piece: str) -> bool:
+ return piece != "" and piece.startswith("<") and piece.endswith(">")
+
+
+def _is_start_piece(piece: str) -> bool:
+ special_pieces = set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~'))
+ if (piece.startswith("▁") or piece in special_pieces):
+ return True
+ else:
+ return False
+
+
+def _get_boundary_indices(
+ data: np.array,
+ tokenizer: tokenization.FullSentencePieceTokenizer) -> np.array:
+ """Gets the boundary indices of whole words."""
+ seq_length = len(data)
+ boundary_indices = []
+ for index, piece in enumerate(tokenizer.convert_ids_to_tokens(data.tolist())):
+ if _is_start_piece(piece) and not _is_functional_piece(piece):
+ boundary_indices.append(index)
+ boundary_indices.append(seq_length)
+ return boundary_indices
+
+
+def _convert_tokens_to_instances(
+ tokens: np.array,
+ sentence_ids: np.array,
+ per_host_batch_size: int,
+ seq_length: int,
+ reuse_length: int,
+ bi_data: bool,
+ tokenizer: tokenization.FullSentencePieceTokenizer,
+ num_cores_per_host: int = 0,
+ logging_frequency: int = 500) -> List[TrainingInstance]:
+ """Converts tokens and sentence IDs into individual training instances.
+
+ The format of data in the XLNet pretraining task is very similar to the
+ BERT pretraining task. Two segments A and B are randomly sampled, and the
+ contatenation of A and B into a single sequence is used to perform
+ language modeling.
+
+ To create an XLNet Pretraining instance from a single long sequence, S:
+ - Create a segment of length `reuse_length`. This first segment represents
+ past tokens. During modeling, this segment is used to cache obtained
+ content representations for the segment recurrence mechanism.
+ - Similar to BERT, create a segment of length `seq_length` - `reuse_length`
+ composed of A and B segments.
+ For XLNet, the order is "A", "SEP", "B", "SEP", "CLS".
+
+ Args:
+ tokens: All tokens concatenated into a single list.
+ sentence_ids: All sentence IDs concatenated into a single list.
+ per_host_batch_size: The target batch size per host.
+ seq_length: The max sequence length.
+ reuse_length: The number of tokens to use from the previous segment.
+ bi_data: Whether or not to use bidirectional data.
+ tokenizer: The SentencePiece tokenizer that has the attribute `sp_model`.
+ num_cores_per_host: The number of cores per host. This is required if
+ `bi_data` = `True`.
+ logging_frequency: The frequency at which to log status updates.
+
+ Returns:
+ A list of `TrainingInstance` objects.
+ """
+ instances = []
+
+ per_core_batch_size = (per_host_batch_size // num_cores_per_host
+ if bi_data else None)
+
+ if bi_data:
+ logging.info("Bi-directional data enabled.")
+ assert per_host_batch_size % (2 * num_cores_per_host) == 0
+ forward_tokens, forward_sentence_ids = _reshape_to_batch_dimensions(
+ tokens=tokens,
+ sentence_ids=sentence_ids,
+ per_host_batch_size=per_host_batch_size // 2)
+ forward_data_shape = (num_cores_per_host, 1, per_core_batch_size // 2, -1)
+
+ forward_tokens = forward_tokens.reshape(forward_data_shape)
+ forward_sentence_ids = forward_sentence_ids.reshape(forward_data_shape)
+
+ backwards_tokens = forward_tokens[:, :, :, ::-1]
+ backwards_sentence_ids = forward_sentence_ids[:, :, :, ::-1]
+
+ tokens = np.concatenate([forward_tokens, backwards_tokens], 1).reshape(
+ per_host_batch_size, -1)
+ sentence_ids = np.concatenate(
+ [forward_sentence_ids, backwards_sentence_ids]).reshape(
+ per_host_batch_size, -1)
+ else:
+ logging.info("Bi-directional data disabled.")
+ tokens, sentence_ids = _reshape_to_batch_dimensions(
+ tokens=tokens,
+ sentence_ids=sentence_ids,
+ per_host_batch_size=per_host_batch_size)
+
+ logging.info("Tokens shape: %s", tokens.shape)
+
+ data_length = tokens.shape[1]
+ sep = np.array([special_symbols[""]], dtype=np.int64)
+ cls = np.array([special_symbols[""]], dtype=np.int64)
+ # 2 sep, 1 cls
+ num_special_tokens = 3
+
+ data_index = 0
+ batch_number = 0
+ step_size = reuse_length if reuse_length else seq_length
+ num_batches = math.ceil(data_length / step_size)
+
+ while data_index + seq_length <= data_length:
+ if batch_number % logging_frequency == 0:
+ logging.info("Processing batch %d of %d", batch_number, num_batches)
+
+ for batch_index in range(per_host_batch_size):
+ previous_segment_tokens = tokens[
+ batch_index, data_index: data_index + reuse_length]
+
+ results = _create_a_and_b_segments(
+ tokens=tokens[batch_index],
+ sentence_ids=sentence_ids[batch_index],
+ begin_index=data_index + reuse_length,
+ total_length=seq_length - reuse_length - num_special_tokens)
+
+ if results is None:
+ logging.info("Stopping at data index: %d", data_index)
+ break
+ a_data, b_data, label = results
+
+ data = np.concatenate(
+ [previous_segment_tokens, a_data, sep, b_data, sep, cls])
+ a_length = a_data.shape[0]
+ b_length = b_data.shape[0]
+ segment_ids = ([0] * (reuse_length + a_length) + [0]
+ + [1] * b_length + [1] + [2])
+ boundary_indices = _get_boundary_indices(tokenizer=tokenizer,
+ data=data)
+ assert len(data) == seq_length
+ assert len(segment_ids) == seq_length
+ assert len(boundary_indices) > 0 # pylint: disable=g-explicit-length-test
+
+ instances.append(TrainingInstance(
+ data=data,
+ segment_ids=segment_ids,
+ boundary_indices=boundary_indices,
+ label=label))
+ batch_number += 1
+ data_index += step_size
+ return instances
+
+
+def write_instances_to_tfrecord(
+ instances: Iterable[TrainingInstance],
+ save_path: str):
+ """Writes instances to TFRecord."""
+ record_writer = tf.io.TFRecordWriter(save_path)
+ logging.info("Start writing to %s.", save_path)
+
+ for i, instance in enumerate(instances):
+ if i < 5:
+ logging.info("Instance %d: %s", i, str(instance))
+ record_writer.write(instance.to_example().SerializeToString())
+
+ record_writer.close()
+ logging.info("Done writing %s.", save_path)
+
+
+def shuffle_and_combine_preprocessed_data(
+ all_data: List[Tuple[np.array, np.array]]) -> Tuple[np.array, np.array]:
+ """Shuffles and combines preprocessed token/sentence IDs from documents."""
+ document_permutation = np.random.permutation(len(all_data))
+
+ previous_sentence_id = None
+
+ all_tokens, all_sentence_ids = [], []
+ for document_index in document_permutation:
+ tokens, sentence_ids = all_data[document_index]
+ # pylint: disable=g-explicit-length-test
+ if len(tokens) == 0:
+ continue
+ if (previous_sentence_id is not None and
+ sentence_ids[0] == previous_sentence_id):
+ sentence_ids = np.logical_not(sentence_ids)
+
+ all_tokens.append(tokens)
+ all_sentence_ids.append(sentence_ids)
+
+ previous_sentence_id = sentence_ids[-1]
+
+ return np.concatenate(all_tokens), np.concatenate(all_sentence_ids)
+
+
+def get_tfrecord_name(
+ per_host_batch_size: int,
+ num_cores_per_host: int,
+ seq_length: int,
+ bi_data: bool,
+ reuse_length: int,
+ do_lower_case: bool,
+ use_eod_token: bool,
+ prefix: str = "",
+ suffix: str = "",
+ pass_id: int = 0,
+ num_passes: int = 1,
+ task_id: int = None,
+ num_tasks: int = None) -> str:
+ """Formats the resulting TFRecord name based on provided inputs."""
+ components = []
+ if prefix:
+ components.append(prefix)
+ components.append("seqlen-{}".format(seq_length))
+ if reuse_length == 0:
+ components.append("memless")
+ else:
+ components.append("reuse-{}".format(reuse_length))
+ components.append("bs-{}".format(per_host_batch_size))
+ components.append("cores-{}".format(num_cores_per_host))
+
+ if do_lower_case:
+ components.append("uncased")
+ else:
+ components.append("cased")
+ if use_eod_token:
+ components.append("eod")
+ if bi_data:
+ components.append("bi")
+ else:
+ components.append("uni")
+
+ if suffix:
+ components.append(suffix)
+
+ s = "_".join(components) + ".tfrecord"
+ if num_passes == 1 and task_id is None:
+ return s
+
+ if task_id is None:
+ num_tasks = 1
+ task_id = 0
+
+ current_shard = task_id * num_passes + pass_id
+ total_shards = num_tasks * num_passes
+ return s + "-{}-of-{}".format(current_shard, total_shards)
+
+
+def create_tfrecords(
+ tokenizer: tokenization.FullSentencePieceTokenizer,
+ input_file_or_files: str,
+ use_eod_token: bool,
+ do_lower_case: bool,
+ per_host_batch_size: int,
+ seq_length: int,
+ reuse_length: int,
+ bi_data: bool,
+ num_cores_per_host: int,
+ save_dir: str,
+ prefix: str = "",
+ suffix: str = "",
+ num_tasks: Optional[int] = None,
+ task_id: Optional[int] = None,
+ num_passes: int = 1):
+ """Runs the end-to-end preprocessing pipeline."""
+
+ logging.info("Input configuration:")
+ logging.info("input file(s): %s", input_file_or_files)
+ logging.info("use_eod_token: %s", use_eod_token)
+ logging.info("do_lower_case: %s", do_lower_case)
+ logging.info("per_host_batch_size: %d", per_host_batch_size)
+ logging.info("seq_length: %d", seq_length)
+ logging.info("reuse_length: %d", reuse_length)
+ logging.info("bi_data: %s", bi_data)
+ logging.info("num_cores_per_host: %d", num_cores_per_host)
+ logging.info("save_dir: %s", save_dir)
+ if task_id is not None and num_tasks is not None:
+ logging.info("task_id: %d", task_id)
+ logging.info("num_tasks: %d", num_tasks)
+
+ input_files = []
+ for input_pattern in input_file_or_files.split(","):
+ input_files.extend(tf.io.gfile.glob(input_pattern))
+
+ logging.info("*** Reading from input files ***")
+ for input_file in input_files:
+ logging.info(" %s", input_file)
+
+ logging.info("Shuffling the files with a fixed random seed.")
+ np.random.shuffle(input_files)
+ if num_tasks is not None:
+ assert task_id is not None
+ logging.info("Total number of input files: %d", len(input_files))
+ logging.info("Splitting into %d shards of %d files each.",
+ num_tasks, len(input_files) // num_tasks)
+ input_files = input_files[task_id::num_tasks]
+
+ all_data = preprocess_and_tokenize_input_files(
+ input_files=input_files,
+ tokenizer=tokenizer,
+ use_eod=use_eod_token,
+ do_lower_case=do_lower_case)
+ for pass_id in range(num_passes):
+ logging.info("Beginning pass %d of %d", pass_id, num_passes)
+ tokens, sentence_ids = shuffle_and_combine_preprocessed_data(all_data)
+
+ assert len(tokens) == len(sentence_ids)
+
+ filename = get_tfrecord_name(
+ per_host_batch_size=per_host_batch_size,
+ num_cores_per_host=num_cores_per_host,
+ seq_length=seq_length,
+ bi_data=bi_data,
+ use_eod_token=use_eod_token,
+ reuse_length=reuse_length,
+ do_lower_case=do_lower_case,
+ prefix=prefix,
+ suffix=suffix,
+ pass_id=pass_id,
+ num_passes=num_passes,
+ num_tasks=num_tasks,
+ task_id=task_id)
+ save_path = os.path.join(save_dir, filename)
+ if os.path.exists(save_path):
+ # If the path already exists, then we were probably preempted but
+ # previously wrote this file.
+ logging.info("%s already exists, skipping this batch.", save_path)
+ else:
+ instances = _convert_tokens_to_instances(
+ tokenizer=tokenizer,
+ tokens=tokens,
+ sentence_ids=sentence_ids,
+ per_host_batch_size=per_host_batch_size,
+ seq_length=seq_length,
+ reuse_length=reuse_length,
+ bi_data=bi_data,
+ num_cores_per_host=num_cores_per_host)
+ write_instances_to_tfrecord(instances=instances, save_path=save_path)
+
+ if task_id is None or task_id == 0:
+ corpus_info = {
+ "vocab_size": 32000,
+ "per_host_batch_size": per_host_batch_size,
+ "num_cores_per_host": num_cores_per_host,
+ "seq_length": seq_length,
+ "reuse_length": reuse_length,
+ "do_lower_case": do_lower_case,
+ "bi_data": bi_data,
+ "use_eod_token": use_eod_token,
+ }
+ corpus_fname = os.path.basename(filename) + ".json"
+ corpus_destination = os.path.join(save_dir, corpus_fname)
+ logging.info("Saving corpus info to %s", corpus_destination)
+
+ with tf.io.gfile.GFile(corpus_destination, "w") as fp:
+ json.dump(corpus_info, fp)
+
+
+def main(_):
+ tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
+ create_tfrecords(
+ tokenizer=tokenizer,
+ input_file_or_files=FLAGS.input_file,
+ use_eod_token=FLAGS.use_eod_token,
+ do_lower_case=FLAGS.do_lower_case,
+ per_host_batch_size=FLAGS.per_host_batch_size,
+ seq_length=FLAGS.seq_length,
+ reuse_length=FLAGS.reuse_length,
+ bi_data=FLAGS.bi_data,
+ num_cores_per_host=FLAGS.num_cores_per_host,
+ save_dir=FLAGS.save_dir,
+ prefix=FLAGS.prefix,
+ suffix=FLAGS.suffix,
+ num_tasks=FLAGS.num_tasks,
+ task_id=FLAGS.task_id,
+ num_passes=FLAGS.num_passes)
+
+
+if __name__ == "__main__":
+ np.random.seed(0)
+ logging.set_verbosity(logging.INFO)
+ app.run(main)
diff --git a/modeling/official/nlp/data/create_xlnet_pretraining_data_test.py b/modeling/official/nlp/data/create_xlnet_pretraining_data_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5a9176e91fefef5c09ddf38c9372826415eb9ba
--- /dev/null
+++ b/modeling/official/nlp/data/create_xlnet_pretraining_data_test.py
@@ -0,0 +1,355 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for official.nlp.data.create_xlnet_pretraining_data."""
+import os
+import tempfile
+from typing import List
+
+from absl import logging
+from absl.testing import parameterized
+
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from official.nlp.data import create_xlnet_pretraining_data as cpd
+
+_VOCAB_WORDS = ["vocab_1", "vocab_2"]
+
+
+# pylint: disable=invalid-name
+def _create_files(
+ temp_dir: str, file_contents: List[List[str]]) -> List[str]:
+ """Writes arbitrary documents into files."""
+ root_dir = tempfile.mkdtemp(dir=temp_dir)
+ files = []
+
+ for i, file_content in enumerate(file_contents):
+ destination = os.path.join(root_dir, "%d.txt" % i)
+ with open(destination, "wb") as f:
+ for line in file_content:
+ f.write(line.encode("utf-8"))
+ files.append(destination)
+ return files
+
+
+def _get_mock_tokenizer():
+ """Creates a mock tokenizer."""
+
+ class MockSpieceModel:
+ """Mock Spiece model for testing."""
+
+ def __init__(self):
+ self._special_piece_to_id = {
+ "": 0,
+ }
+ for piece in set(list('!"#$%&\"()*+,-./:;?@[\\]^_`{|}~')):
+ self._special_piece_to_id[piece] = 1
+
+ def EncodeAsPieces(self, inputs: str) -> List[str]:
+ return inputs
+
+ def SampleEncodeAsPieces(self,
+ inputs: str,
+ nbest_size: int,
+ theta: float) -> List[str]:
+ del nbest_size, theta
+ return inputs
+
+ def PieceToId(self, piece: str) -> int:
+ return ord(piece[0])
+
+ def IdToPiece(self, id_: int) -> str:
+ return chr(id_) * 3
+
+ class Tokenizer:
+ """Mock Tokenizer for testing."""
+
+ def __init__(self):
+ self.sp_model = MockSpieceModel()
+
+ def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
+ return [self.sp_model.IdToPiece(id_) for id_ in ids]
+
+ return Tokenizer()
+
+
+class PreprocessDataTest(tf.test.TestCase):
+
+ def test_remove_extraneous_space(self):
+ line = " abc "
+ output = cpd._preprocess_line(line)
+ self.assertEqual(output, "abc")
+
+ def test_symbol_replacements(self):
+ self.assertEqual(cpd._preprocess_line("``abc``"), "\"abc\"")
+ self.assertEqual(cpd._preprocess_line("''abc''"), "\"abc\"")
+
+ def test_accent_replacements(self):
+ self.assertEqual(cpd._preprocess_line("åbc"), "abc")
+
+ def test_lower_case(self):
+ self.assertEqual(cpd._preprocess_line("ABC", do_lower_case=True), "abc")
+
+ def test_end_to_end(self):
+ self.assertEqual(
+ cpd._preprocess_line("HelLo ``wórLd``", do_lower_case=True),
+ "hello \"world\"")
+
+
+class PreprocessAndTokenizeFilesTest(tf.test.TestCase):
+
+ def test_basic_end_to_end(self):
+ documents = [
+ [
+ "This is sentence 1.\n",
+ "This is sentence 2.\n",
+ "Sentence 3 is what this is.\n",
+ ],
+ [
+ "This is the second document.\n",
+ "This is the second line of the second document.\n"
+ ],
+ ]
+ input_files = _create_files(temp_dir=self.get_temp_dir(),
+ file_contents=documents)
+ all_data = cpd.preprocess_and_tokenize_input_files(
+ input_files=input_files,
+ tokenizer=_get_mock_tokenizer(),
+ log_example_freq=1)
+
+ self.assertEqual(len(all_data), len(documents))
+ for token_ids, sentence_ids in all_data:
+ self.assertEqual(len(token_ids), len(sentence_ids))
+
+ def test_basic_correctness(self):
+ documents = [["a\n", "b\n", "c\n"]]
+ input_files = _create_files(temp_dir=self.get_temp_dir(),
+ file_contents=documents)
+ all_data = cpd.preprocess_and_tokenize_input_files(
+ input_files=input_files,
+ tokenizer=_get_mock_tokenizer(),
+ log_example_freq=1)
+
+ token_ids, sentence_ids = all_data[0]
+
+ self.assertAllClose(token_ids, [97, 98, 99])
+ self.assertAllClose(sentence_ids, [True, False, True])
+
+ def test_correctness_with_spaces_and_accents(self):
+ documents = [[
+ " å \n",
+ "b \n",
+ " c \n",
+ ]]
+ input_files = _create_files(temp_dir=self.get_temp_dir(),
+ file_contents=documents)
+ all_data = cpd.preprocess_and_tokenize_input_files(
+ input_files=input_files,
+ tokenizer=_get_mock_tokenizer(),
+ log_example_freq=1)
+
+ token_ids, sentence_ids = all_data[0]
+
+ self.assertAllClose(token_ids, [97, 98, 99])
+ self.assertAllClose(sentence_ids, [True, False, True])
+
+
+class BatchReshapeTests(tf.test.TestCase):
+
+ def test_basic_functionality(self):
+ per_host_batch_size = 3
+ mock_shape = (20,)
+
+ # Should truncate and reshape.
+ expected_result_shape = (3, 6)
+
+ tokens = np.zeros(mock_shape)
+ sentence_ids = np.zeros(mock_shape)
+
+ reshaped_data = cpd._reshape_to_batch_dimensions(
+ tokens=tokens,
+ sentence_ids=sentence_ids,
+ per_host_batch_size=per_host_batch_size)
+ for values in reshaped_data:
+ self.assertEqual(len(values.flatten()) % per_host_batch_size, 0)
+ self.assertAllClose(values.shape, expected_result_shape)
+
+
+class CreateSegmentsTest(tf.test.TestCase):
+
+ def test_basic_functionality(self):
+ data_length = 10
+ tokens = np.arange(data_length)
+ sentence_ids = np.concatenate([np.zeros(data_length // 2),
+ np.ones(data_length // 2)])
+ begin_index = 0
+ total_length = 8
+ a_data, b_data, label = cpd._create_a_and_b_segments(
+ tokens=tokens,
+ sentence_ids=sentence_ids,
+ begin_index=begin_index,
+ total_length=total_length,
+ no_cut_probability=0.)
+ self.assertAllClose(a_data, [0, 1, 2, 3])
+ self.assertAllClose(b_data, [5, 6, 7, 8])
+ self.assertEqual(label, 1)
+
+ def test_no_cut(self):
+ data_length = 10
+ tokens = np.arange(data_length)
+ sentence_ids = np.zeros(data_length)
+
+ begin_index = 0
+ total_length = 8
+ a_data, b_data, label = cpd._create_a_and_b_segments(
+ tokens=tokens,
+ sentence_ids=sentence_ids,
+ begin_index=begin_index,
+ total_length=total_length,
+ no_cut_probability=0.)
+ self.assertGreater(len(a_data), 0)
+ self.assertGreater(len(b_data), 0)
+ self.assertEqual(label, 0)
+
+ def test_no_cut_with_probability(self):
+ data_length = 10
+ tokens = np.arange(data_length)
+ sentence_ids = np.concatenate([np.zeros(data_length // 2),
+ np.ones(data_length // 2)])
+ begin_index = 0
+ total_length = 8
+ a_data, b_data, label = cpd._create_a_and_b_segments(
+ tokens=tokens,
+ sentence_ids=sentence_ids,
+ begin_index=begin_index,
+ total_length=total_length,
+ no_cut_probability=1.)
+ self.assertGreater(len(a_data), 0)
+ self.assertGreater(len(b_data), 0)
+ self.assertEqual(label, 0)
+
+
+class CreateInstancesTest(tf.test.TestCase):
+ """Tests conversions of Token/Sentence IDs to training instances."""
+
+ def test_basic(self):
+ data_length = 12
+ tokens = np.arange(data_length)
+ sentence_ids = np.zeros(data_length)
+ seq_length = 8
+ instances = cpd._convert_tokens_to_instances(
+ tokens=tokens,
+ sentence_ids=sentence_ids,
+ per_host_batch_size=2,
+ seq_length=seq_length,
+ reuse_length=4,
+ tokenizer=_get_mock_tokenizer(),
+ bi_data=False,
+ num_cores_per_host=1,
+ logging_frequency=1)
+ for instance in instances:
+ self.assertEqual(len(instance.data), seq_length)
+ self.assertEqual(len(instance.segment_ids), seq_length)
+ self.assertIsInstance(instance.label, int)
+ self.assertIsInstance(instance.boundary_indices, list)
+
+
+class TFRecordPathTests(tf.test.TestCase):
+
+ def test_basic(self):
+ base_kwargs = dict(
+ per_host_batch_size=1,
+ num_cores_per_host=1,
+ seq_length=2,
+ reuse_length=1)
+
+ config1 = dict(
+ prefix="test",
+ suffix="",
+ bi_data=True,
+ use_eod_token=False,
+ do_lower_case=True)
+ config1.update(base_kwargs)
+ expectation1 = "test_seqlen-2_reuse-1_bs-1_cores-1_uncased_bi.tfrecord"
+ self.assertEqual(cpd.get_tfrecord_name(**config1), expectation1)
+
+ config2 = dict(
+ prefix="",
+ suffix="test",
+ bi_data=False,
+ use_eod_token=False,
+ do_lower_case=False)
+ config2.update(base_kwargs)
+ expectation2 = "seqlen-2_reuse-1_bs-1_cores-1_cased_uni_test.tfrecord"
+ self.assertEqual(cpd.get_tfrecord_name(**config2), expectation2)
+
+ config3 = dict(
+ prefix="",
+ suffix="",
+ use_eod_token=True,
+ bi_data=False,
+ do_lower_case=True)
+ config3.update(base_kwargs)
+ expectation3 = "seqlen-2_reuse-1_bs-1_cores-1_uncased_eod_uni.tfrecord"
+ self.assertEqual(cpd.get_tfrecord_name(**config3), expectation3)
+
+
+class TestCreateTFRecords(parameterized.TestCase, tf.test.TestCase):
+
+ @parameterized.named_parameters(
+ ("bi_data_only", True, False, False),
+ ("eod_token_only", False, True, True),
+ ("lower_case_only", False, False, True),
+ ("all_enabled", True, True, True),
+ )
+ def test_end_to_end(self,
+ bi_data: bool,
+ use_eod_token: bool,
+ do_lower_case: bool):
+ tokenizer = _get_mock_tokenizer()
+
+ num_documents = 5
+ sentences_per_document = 10
+ document_length = 50
+
+ documents = [
+ ["a " * document_length for _ in range(sentences_per_document)]
+ for _ in range(num_documents)]
+
+ save_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
+ files = _create_files(temp_dir=self.get_temp_dir(), file_contents=documents)
+
+ cpd.create_tfrecords(
+ tokenizer=tokenizer,
+ input_file_or_files=",".join(files),
+ use_eod_token=use_eod_token,
+ do_lower_case=do_lower_case,
+ per_host_batch_size=8,
+ seq_length=8,
+ reuse_length=4,
+ bi_data=bi_data,
+ num_cores_per_host=2,
+ save_dir=save_dir)
+
+ self.assertTrue(any(filter(lambda x: x.endswith(".json"),
+ os.listdir(save_dir))))
+ self.assertTrue(any(filter(lambda x: x.endswith(".tfrecord"),
+ os.listdir(save_dir))))
+
+
+if __name__ == "__main__":
+ np.random.seed(0)
+ logging.set_verbosity(logging.INFO)
+ tf.test.main()
diff --git a/modeling/official/nlp/data/data_loader.py b/modeling/official/nlp/data/data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..e59b4650c6ff0a4dd7989a1a056f680735cdb110
--- /dev/null
+++ b/modeling/official/nlp/data/data_loader.py
@@ -0,0 +1,48 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""An abstraction that NLP models define input pipelines."""
+
+import abc
+from typing import Optional
+
+import tensorflow as tf, tf_keras
+
+
+class DataLoader(metaclass=abc.ABCMeta):
+ """An abstract class defining the APIs for tf.data input pipeline."""
+
+ @abc.abstractmethod
+ def load(
+ self,
+ input_context: Optional[tf.distribute.InputContext] = None
+ ) -> tf.data.Dataset:
+ """Implements DataLoader load method.
+
+ Builds the entire input pipeline inside the load method. Users can define
+ states inside the DataLoader class and returns a tf.data dataset
+ object.
+
+ Args:
+ input_context: This is a context class that is passed to the user's input
+ function and contains information about the compute replicas and input
+ pipelines. This object is used for multi-host inputs and passed by the
+ distribution strategy.
+
+ Returns:
+ A per-host tf.data dataset. Note that, we usually create the distributed
+ dataset through the load method, so we should not directly return a
+ distributed dataset here.
+ """
+ pass
diff --git a/modeling/official/nlp/data/data_loader_factory.py b/modeling/official/nlp/data/data_loader_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ba79eec01fe2903138c34afa651a24d0e942bbe
--- /dev/null
+++ b/modeling/official/nlp/data/data_loader_factory.py
@@ -0,0 +1,58 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A global factory to access NLP registered data loaders."""
+
+from official.core import registry
+
+_REGISTERED_DATA_LOADER_CLS = {}
+
+
+def register_data_loader_cls(data_config_cls):
+ """Decorates a factory of DataLoader for lookup by a subclass of DataConfig.
+
+ This decorator supports registration of data loaders as follows:
+
+ ```
+ @dataclasses.dataclass
+ class MyDataConfig(DataConfig):
+ # Add fields here.
+ pass
+
+ @register_data_loader_cls(MyDataConfig)
+ class MyDataLoader:
+ # Inherits def __init__(self, data_config).
+ pass
+
+ my_data_config = MyDataConfig()
+
+ # Returns MyDataLoader(my_data_config).
+ my_loader = get_data_loader(my_data_config)
+ ```
+
+ Args:
+ data_config_cls: a subclass of DataConfig (*not* an instance
+ of DataConfig).
+
+ Returns:
+ A callable for use as class decorator that registers the decorated class
+ for creation from an instance of data_config_cls.
+ """
+ return registry.register(_REGISTERED_DATA_LOADER_CLS, data_config_cls)
+
+
+def get_data_loader(data_config):
+ """Creates a data_loader from data_config."""
+ return registry.lookup(_REGISTERED_DATA_LOADER_CLS, data_config.__class__)(
+ data_config)
diff --git a/modeling/official/nlp/data/data_loader_factory_test.py b/modeling/official/nlp/data/data_loader_factory_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..41b2c37d1082897842fb803f514609bc38b2d54f
--- /dev/null
+++ b/modeling/official/nlp/data/data_loader_factory_test.py
@@ -0,0 +1,45 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for official.nlp.data.data_loader_factory."""
+
+import dataclasses
+import tensorflow as tf, tf_keras
+
+from official.core import config_definitions as cfg
+from official.nlp.data import data_loader_factory
+
+
+@dataclasses.dataclass
+class MyDataConfig(cfg.DataConfig):
+ is_training: bool = True
+
+
+@data_loader_factory.register_data_loader_cls(MyDataConfig)
+class MyDataLoader:
+
+ def __init__(self, params):
+ self.params = params
+
+
+class DataLoaderFactoryTest(tf.test.TestCase):
+
+ def test_register_and_load(self):
+ train_config = MyDataConfig()
+ train_loader = data_loader_factory.get_data_loader(train_config)
+ self.assertTrue(train_loader.params.is_training)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/nlp/data/dual_encoder_dataloader.py b/modeling/official/nlp/data/dual_encoder_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f4e8adb2d4ef7abaecbb17805071b168d8a6fbf
--- /dev/null
+++ b/modeling/official/nlp/data/dual_encoder_dataloader.py
@@ -0,0 +1,147 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Loads dataset for the dual encoder (retrieval) task."""
+import dataclasses
+import functools
+import itertools
+from typing import Iterable, Mapping, Optional, Tuple
+
+import tensorflow as tf, tf_keras
+import tensorflow_hub as hub
+
+from official.common import dataset_fn
+from official.core import config_definitions as cfg
+from official.core import input_reader
+from official.nlp.data import data_loader
+from official.nlp.data import data_loader_factory
+from official.nlp.modeling import layers
+
+
+@dataclasses.dataclass
+class DualEncoderDataConfig(cfg.DataConfig):
+ """Data config for dual encoder task (tasks/dual_encoder)."""
+ # Either set `input_path`...
+ input_path: str = ''
+ # ...or `tfds_name` and `tfds_split` to specify input.
+ tfds_name: str = ''
+ tfds_split: str = ''
+ global_batch_size: int = 32
+ # Either build preprocessing with Python code by specifying these values...
+ vocab_file: str = ''
+ lower_case: bool = True
+ # ...or load preprocessing from a SavedModel at this location.
+ preprocessing_hub_module_url: str = ''
+
+ left_text_fields: Tuple[str] = ('left_input',)
+ right_text_fields: Tuple[str] = ('right_input',)
+ is_training: bool = True
+ seq_length: int = 128
+ file_type: str = 'tfrecord'
+
+
+@data_loader_factory.register_data_loader_cls(DualEncoderDataConfig)
+class DualEncoderDataLoader(data_loader.DataLoader):
+ """A class to load dataset for dual encoder task (tasks/dual_encoder)."""
+
+ def __init__(self, params):
+ if bool(params.tfds_name) == bool(params.input_path):
+ raise ValueError('Must specify either `tfds_name` and `tfds_split` '
+ 'or `input_path`.')
+ if bool(params.vocab_file) == bool(params.preprocessing_hub_module_url):
+ raise ValueError('Must specify exactly one of vocab_file (with matching '
+ 'lower_case flag) or preprocessing_hub_module_url.')
+ self._params = params
+ self._seq_length = params.seq_length
+ self._left_text_fields = params.left_text_fields
+ self._right_text_fields = params.right_text_fields
+
+ if params.preprocessing_hub_module_url:
+ preprocessing_hub_module = hub.load(params.preprocessing_hub_module_url)
+ self._tokenizer = preprocessing_hub_module.tokenize
+ self._pack_inputs = functools.partial(
+ preprocessing_hub_module.bert_pack_inputs,
+ seq_length=params.seq_length)
+ else:
+ self._tokenizer = layers.BertTokenizer(
+ vocab_file=params.vocab_file, lower_case=params.lower_case)
+ self._pack_inputs = layers.BertPackInputs(
+ seq_length=params.seq_length,
+ special_tokens_dict=self._tokenizer.get_special_tokens_dict())
+
+ def _decode(self, record: tf.Tensor):
+ """Decodes a serialized tf.Example."""
+ name_to_features = {
+ x: tf.io.FixedLenFeature([], tf.string)
+ for x in itertools.chain(
+ *[self._left_text_fields, self._right_text_fields])
+ }
+ example = tf.io.parse_single_example(record, name_to_features)
+
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+ # So cast all int64 to int32.
+ for name in example:
+ t = example[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ example[name] = t
+
+ return example
+
+ def _bert_tokenize(
+ self, record: Mapping[str, tf.Tensor],
+ text_fields: Iterable[str]) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+ """Tokenize the input in text_fields using BERT tokenizer.
+
+ Args:
+ record: A tfexample record contains the features.
+ text_fields: A list of fields to be tokenzied.
+
+ Returns:
+ The tokenized features in a tuple of (input_word_ids, input_mask,
+ input_type_ids).
+ """
+ segments_text = [record[x] for x in text_fields]
+ segments_tokens = [self._tokenizer(s) for s in segments_text]
+ segments = [tf.cast(x.merge_dims(1, 2), tf.int32) for x in segments_tokens]
+ return self._pack_inputs(segments)
+
+ def _bert_preprocess(
+ self, record: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
+ """Perform the bert word piece tokenization for left and right inputs."""
+
+ def _switch_prefix(string, old, new):
+ if string.startswith(old): return new + string[len(old):]
+ raise ValueError('Expected {} to start with {}'.format(string, old))
+
+ def _switch_key_prefix(d, old, new):
+ return {_switch_prefix(key, old, new): value for key, value in d.items()} # pytype: disable=attribute-error # trace-all-classes
+
+ model_inputs = _switch_key_prefix(
+ self._bert_tokenize(record, self._left_text_fields),
+ 'input_', 'left_')
+ model_inputs.update(_switch_key_prefix(
+ self._bert_tokenize(record, self._right_text_fields),
+ 'input_', 'right_'))
+ return model_inputs
+
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
+ """Returns a tf.dataset.Dataset."""
+ reader = input_reader.InputReader(
+ params=self._params,
+ # Skip `decoder_fn` for tfds input.
+ decoder_fn=self._decode if self._params.input_path else None,
+ dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
+ postprocess_fn=self._bert_preprocess)
+ return reader.read(input_context)
diff --git a/modeling/official/nlp/data/dual_encoder_dataloader_test.py b/modeling/official/nlp/data/dual_encoder_dataloader_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..f74c01c81a3b039c8027863a55b4dd54ad2d3345
--- /dev/null
+++ b/modeling/official/nlp/data/dual_encoder_dataloader_test.py
@@ -0,0 +1,131 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for official.nlp.data.dual_encoder_dataloader."""
+import os
+
+from absl.testing import parameterized
+import tensorflow as tf, tf_keras
+
+from official.nlp.data import dual_encoder_dataloader
+
+
+_LEFT_FEATURE_NAME = 'left_input'
+_RIGHT_FEATURE_NAME = 'right_input'
+
+
+def _create_fake_dataset(output_path):
+ """Creates a fake dataset contains examples for training a dual encoder model.
+
+ The created dataset contains examples with two byteslist features keyed by
+ _LEFT_FEATURE_NAME and _RIGHT_FEATURE_NAME.
+
+ Args:
+ output_path: The output path of the fake dataset.
+ """
+ def create_str_feature(values):
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
+
+ with tf.io.TFRecordWriter(output_path) as writer:
+ for _ in range(100):
+ features = {}
+ features[_LEFT_FEATURE_NAME] = create_str_feature([b'hello world.'])
+ features[_RIGHT_FEATURE_NAME] = create_str_feature([b'world hello.'])
+
+ tf_example = tf.train.Example(
+ features=tf.train.Features(feature=features))
+ writer.write(tf_example.SerializeToString())
+
+
+def _make_vocab_file(vocab, output_path):
+ with tf.io.gfile.GFile(output_path, 'w') as f:
+ f.write('\n'.join(vocab + ['']))
+
+
+class DualEncoderDataTest(tf.test.TestCase, parameterized.TestCase):
+
+ def test_load_dataset(self):
+ seq_length = 16
+ batch_size = 10
+ train_data_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
+ vocab_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
+
+ _create_fake_dataset(train_data_path)
+ _make_vocab_file(
+ ['[PAD]', '[UNK]', '[CLS]', '[SEP]', 'he', '#llo', 'world'], vocab_path)
+
+ data_config = dual_encoder_dataloader.DualEncoderDataConfig(
+ input_path=train_data_path,
+ seq_length=seq_length,
+ vocab_file=vocab_path,
+ lower_case=True,
+ left_text_fields=(_LEFT_FEATURE_NAME,),
+ right_text_fields=(_RIGHT_FEATURE_NAME,),
+ global_batch_size=batch_size)
+ dataset = dual_encoder_dataloader.DualEncoderDataLoader(
+ data_config).load()
+ features = next(iter(dataset))
+ self.assertCountEqual(
+ ['left_word_ids', 'left_mask', 'left_type_ids', 'right_word_ids',
+ 'right_mask', 'right_type_ids'],
+ features.keys())
+ self.assertEqual(features['left_word_ids'].shape, (batch_size, seq_length))
+ self.assertEqual(features['left_mask'].shape, (batch_size, seq_length))
+ self.assertEqual(features['left_type_ids'].shape, (batch_size, seq_length))
+ self.assertEqual(features['right_word_ids'].shape, (batch_size, seq_length))
+ self.assertEqual(features['right_mask'].shape, (batch_size, seq_length))
+ self.assertEqual(features['right_type_ids'].shape, (batch_size, seq_length))
+
+ @parameterized.parameters(False, True)
+ def test_load_tfds(self, use_preprocessing_hub):
+ seq_length = 16
+ batch_size = 10
+ if use_preprocessing_hub:
+ vocab_path = ''
+ preprocessing_hub = (
+ 'https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/3')
+ else:
+ vocab_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
+ _make_vocab_file(
+ ['[PAD]', '[UNK]', '[CLS]', '[SEP]', 'he', '#llo', 'world'],
+ vocab_path)
+ preprocessing_hub = ''
+
+ data_config = dual_encoder_dataloader.DualEncoderDataConfig(
+ tfds_name='para_crawl/enmt',
+ tfds_split='train',
+ seq_length=seq_length,
+ vocab_file=vocab_path,
+ lower_case=True,
+ left_text_fields=('en',),
+ right_text_fields=('mt',),
+ preprocessing_hub_module_url=preprocessing_hub,
+ global_batch_size=batch_size)
+ dataset = dual_encoder_dataloader.DualEncoderDataLoader(
+ data_config).load()
+ features = next(iter(dataset))
+ self.assertCountEqual(
+ ['left_word_ids', 'left_mask', 'left_type_ids', 'right_word_ids',
+ 'right_mask', 'right_type_ids'],
+ features.keys())
+ self.assertEqual(features['left_word_ids'].shape, (batch_size, seq_length))
+ self.assertEqual(features['left_mask'].shape, (batch_size, seq_length))
+ self.assertEqual(features['left_type_ids'].shape, (batch_size, seq_length))
+ self.assertEqual(features['right_word_ids'].shape, (batch_size, seq_length))
+ self.assertEqual(features['right_mask'].shape, (batch_size, seq_length))
+ self.assertEqual(features['right_type_ids'].shape, (batch_size, seq_length))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/nlp/data/pretrain_dataloader.py b/modeling/official/nlp/data/pretrain_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..51e268b1b339ba725ac37d41c7543670d770be9e
--- /dev/null
+++ b/modeling/official/nlp/data/pretrain_dataloader.py
@@ -0,0 +1,589 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Loads dataset for the BERT pretraining task."""
+import dataclasses
+from typing import Mapping, Optional
+
+from absl import logging
+
+import numpy as np
+import tensorflow as tf, tf_keras
+from official.common import dataset_fn
+from official.core import config_definitions as cfg
+from official.core import input_reader
+from official.nlp.data import data_loader
+from official.nlp.data import data_loader_factory
+
+
+@dataclasses.dataclass
+class BertPretrainDataConfig(cfg.DataConfig):
+ """Data config for BERT pretraining task (tasks/masked_lm)."""
+ input_path: str = ''
+ global_batch_size: int = 512
+ is_training: bool = True
+ seq_length: int = 512
+ max_predictions_per_seq: int = 76
+ use_next_sentence_label: bool = True
+ use_position_id: bool = False
+ # Historically, BERT implementations take `input_ids` and `segment_ids` as
+ # feature names. Inside the TF Model Garden implementation, the Keras model
+ # inputs are set as `input_word_ids` and `input_type_ids`. When
+ # v2_feature_names is True, the data loader assumes the tf.Examples use
+ # `input_word_ids` and `input_type_ids` as keys.
+ use_v2_feature_names: bool = False
+ file_type: str = 'tfrecord'
+
+
+@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
+class BertPretrainDataLoader(data_loader.DataLoader):
+ """A class to load dataset for bert pretraining task."""
+
+ def __init__(self, params):
+ """Inits `BertPretrainDataLoader` class.
+
+ Args:
+ params: A `BertPretrainDataConfig` object.
+ """
+ self._params = params
+ self._seq_length = params.seq_length
+ self._max_predictions_per_seq = params.max_predictions_per_seq
+ self._use_next_sentence_label = params.use_next_sentence_label
+ self._use_position_id = params.use_position_id
+
+ def _name_to_features(self):
+ name_to_features = {
+ 'input_mask':
+ tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'masked_lm_positions':
+ tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
+ 'masked_lm_ids':
+ tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.int64),
+ 'masked_lm_weights':
+ tf.io.FixedLenFeature([self._max_predictions_per_seq], tf.float32),
+ }
+ if self._params.use_v2_feature_names:
+ name_to_features.update({
+ 'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ })
+ else:
+ name_to_features.update({
+ 'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ })
+ if self._use_next_sentence_label:
+ name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
+ tf.int64)
+ if self._use_position_id:
+ name_to_features['position_ids'] = tf.io.FixedLenFeature(
+ [self._seq_length], tf.int64)
+ return name_to_features
+
+ def _decode(self, record: tf.Tensor):
+ """Decodes a serialized tf.Example."""
+ name_to_features = self._name_to_features()
+ example = tf.io.parse_single_example(record, name_to_features)
+
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+ # So cast all int64 to int32.
+ for name in list(example.keys()):
+ t = example[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ example[name] = t
+
+ return example
+
+ def _parse(self, record: Mapping[str, tf.Tensor]):
+ """Parses raw tensors into a dict of tensors to be consumed by the model."""
+ x = {
+ 'input_mask': record['input_mask'],
+ 'masked_lm_positions': record['masked_lm_positions'],
+ 'masked_lm_ids': record['masked_lm_ids'],
+ 'masked_lm_weights': record['masked_lm_weights'],
+ }
+ if self._params.use_v2_feature_names:
+ x['input_word_ids'] = record['input_word_ids']
+ x['input_type_ids'] = record['input_type_ids']
+ else:
+ x['input_word_ids'] = record['input_ids']
+ x['input_type_ids'] = record['segment_ids']
+ if self._use_next_sentence_label:
+ x['next_sentence_labels'] = record['next_sentence_labels']
+ if self._use_position_id:
+ x['position_ids'] = record['position_ids']
+
+ return x
+
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
+ """Returns a tf.dataset.Dataset."""
+ reader = input_reader.InputReader(
+ params=self._params,
+ dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
+ decoder_fn=self._decode,
+ parser_fn=self._parse)
+ return reader.read(input_context)
+
+
+@dataclasses.dataclass
+class XLNetPretrainDataConfig(cfg.DataConfig):
+ """Data config for XLNet pretraining task.
+
+ Attributes:
+ input_path: See base class.
+ global_batch_size: See base calss.
+ is_training: See base class.
+ seq_length: The length of each sequence.
+ max_predictions_per_seq: The number of predictions per sequence.
+ reuse_length: The number of tokens in a previous segment to reuse. This
+ should be the same value used during pretrain data creation.
+ sample_strategy: The strategy used to sample factorization permutations.
+ Possible values: 'single_token', 'whole_word', 'token_span', 'word_span'.
+ min_num_tokens: The minimum number of tokens to sample in a span. This is
+ used when `sample_strategy` is 'token_span'.
+ max_num_tokens: The maximum number of tokens to sample in a span. This is
+ used when `sample_strategy` is 'token_span'.
+ min_num_words: The minimum number of words to sample in a span. This is used
+ when `sample_strategy` is 'word_span'.
+ max_num_words: The maximum number of words to sample in a span. This is used
+ when `sample_strategy` is 'word_span'.
+ permutation_size: The length of the longest permutation. This can be set to
+ `reuse_length`. This should NOT be greater than `reuse_length`, otherwise
+ this may introduce data leaks.
+ leak_ratio: The percentage of masked tokens that are leaked.
+ segment_sep_id: The ID of the SEP token used when preprocessing the dataset.
+ segment_cls_id: The ID of the CLS token used when preprocessing the dataset.
+ """
+ input_path: str = ''
+ global_batch_size: int = 512
+ is_training: bool = True
+ seq_length: int = 512
+ max_predictions_per_seq: int = 76
+ reuse_length: int = 256
+ sample_strategy: str = 'word_span'
+ min_num_tokens: int = 1
+ max_num_tokens: int = 5
+ min_num_words: int = 1
+ max_num_words: int = 5
+ permutation_size: int = 256
+ leak_ratio: float = 0.1
+ segment_sep_id: int = 4
+ segment_cls_id: int = 3
+
+
+@data_loader_factory.register_data_loader_cls(XLNetPretrainDataConfig)
+class XLNetPretrainDataLoader(data_loader.DataLoader):
+ """A class to load dataset for xlnet pretraining task."""
+
+ def __init__(self, params: XLNetPretrainDataConfig):
+ """Inits `XLNetPretrainDataLoader` class.
+
+ Args:
+ params: A `XLNetPretrainDataConfig` object.
+ """
+ self._params = params
+ self._seq_length = params.seq_length
+ self._max_predictions_per_seq = params.max_predictions_per_seq
+ self._reuse_length = params.reuse_length
+ self._num_replicas_in_sync = None
+ self._permutation_size = params.permutation_size
+ self._sep_id = params.segment_sep_id
+ self._cls_id = params.segment_cls_id
+ self._sample_strategy = params.sample_strategy
+ self._leak_ratio = params.leak_ratio
+
+ def _decode(self, record: tf.Tensor):
+ """Decodes a serialized tf.Example."""
+ name_to_features = {
+ 'input_word_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'input_type_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'boundary_indices': tf.io.VarLenFeature(tf.int64),
+ }
+ example = tf.io.parse_single_example(record, name_to_features)
+
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+ # So cast all int64 to int32.
+ for name in list(example.keys()):
+ t = example[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ example[name] = t
+
+ return example
+
+ def _parse(self, record: Mapping[str, tf.Tensor]):
+ """Parses raw tensors into a dict of tensors to be consumed by the model."""
+ x = {}
+
+ inputs = record['input_word_ids']
+ x['input_type_ids'] = record['input_type_ids']
+
+ if self._sample_strategy in ['whole_word', 'word_span']:
+ boundary = tf.sparse.to_dense(record['boundary_indices'])
+ else:
+ boundary = None
+
+ input_mask = self._online_sample_mask(inputs=inputs, boundary=boundary)
+
+ if self._reuse_length > 0:
+ if self._permutation_size > self._reuse_length:
+ logging.warning(
+ '`permutation_size` is greater than `reuse_length` (%d > %d).'
+ 'This may introduce data leakage.', self._permutation_size,
+ self._reuse_length)
+
+ # Enable the memory mechanism.
+ # Permute the reuse and non-reuse segments separately.
+ non_reuse_len = self._seq_length - self._reuse_length
+ if not (self._reuse_length % self._permutation_size == 0 and
+ non_reuse_len % self._permutation_size == 0):
+ raise ValueError('`reuse_length` and `seq_length` should both be '
+ 'a multiple of `permutation_size`.')
+
+ # Creates permutation mask and target mask for the first reuse_len tokens.
+ # The tokens in this part are reused from the last sequence.
+ perm_mask_0, target_mask_0, tokens_0, masked_0 = self._get_factorization(
+ inputs=inputs[:self._reuse_length],
+ input_mask=input_mask[:self._reuse_length])
+
+ # Creates permutation mask and target mask for the rest of tokens in
+ # current example, which are concatentation of two new segments.
+ perm_mask_1, target_mask_1, tokens_1, masked_1 = self._get_factorization(
+ inputs[self._reuse_length:], input_mask[self._reuse_length:])
+
+ perm_mask_0 = tf.concat([
+ perm_mask_0,
+ tf.zeros([self._reuse_length, non_reuse_len], dtype=tf.int32)
+ ],
+ axis=1)
+ perm_mask_1 = tf.concat([
+ tf.ones([non_reuse_len, self._reuse_length], dtype=tf.int32),
+ perm_mask_1
+ ],
+ axis=1)
+ perm_mask = tf.concat([perm_mask_0, perm_mask_1], axis=0)
+ target_mask = tf.concat([target_mask_0, target_mask_1], axis=0)
+ tokens = tf.concat([tokens_0, tokens_1], axis=0)
+ masked_tokens = tf.concat([masked_0, masked_1], axis=0)
+ else:
+ # Disable the memory mechanism.
+ if self._seq_length % self._permutation_size != 0:
+ raise ValueError('`seq_length` should be a multiple of '
+ '`permutation_size`.')
+ # Permute the entire sequence together
+ perm_mask, target_mask, tokens, masked_tokens = self._get_factorization(
+ inputs=inputs, input_mask=input_mask)
+ x['permutation_mask'] = tf.reshape(perm_mask,
+ [self._seq_length, self._seq_length])
+ x['input_word_ids'] = tokens
+ x['masked_tokens'] = masked_tokens
+
+ target = tokens
+ if self._max_predictions_per_seq is not None:
+ indices = tf.range(self._seq_length, dtype=tf.int32)
+ bool_target_mask = tf.cast(target_mask, tf.bool)
+ indices = tf.boolean_mask(indices, bool_target_mask)
+
+ # account for extra padding due to CLS/SEP.
+ actual_num_predict = tf.shape(indices)[0]
+ pad_len = self._max_predictions_per_seq - actual_num_predict
+
+ target_mapping = tf.one_hot(indices, self._seq_length, dtype=tf.int32)
+ paddings = tf.zeros([pad_len, self._seq_length],
+ dtype=target_mapping.dtype)
+ target_mapping = tf.concat([target_mapping, paddings], axis=0)
+ x['target_mapping'] = tf.reshape(
+ target_mapping, [self._max_predictions_per_seq, self._seq_length])
+
+ target = tf.boolean_mask(target, bool_target_mask)
+ paddings = tf.zeros([pad_len], dtype=target.dtype)
+ target = tf.concat([target, paddings], axis=0)
+ x['target'] = tf.reshape(target, [self._max_predictions_per_seq])
+
+ target_mask = tf.concat([
+ tf.ones([actual_num_predict], dtype=tf.int32),
+ tf.zeros([pad_len], dtype=tf.int32)
+ ],
+ axis=0)
+ x['target_mask'] = tf.reshape(target_mask,
+ [self._max_predictions_per_seq])
+ else:
+ x['target'] = tf.reshape(target, [self._seq_length])
+ x['target_mask'] = tf.reshape(target_mask, [self._seq_length])
+ return x
+
+ def _index_pair_to_mask(self, begin_indices: tf.Tensor,
+ end_indices: tf.Tensor,
+ inputs: tf.Tensor) -> tf.Tensor:
+ """Converts beginning and end indices into an actual mask."""
+ non_func_mask = tf.logical_and(
+ tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
+ all_indices = tf.where(
+ non_func_mask, tf.range(self._seq_length, dtype=tf.int32),
+ tf.constant(-1, shape=[self._seq_length], dtype=tf.int32))
+ candidate_matrix = tf.cast(
+ tf.logical_and(all_indices[None, :] >= begin_indices[:, None],
+ all_indices[None, :] < end_indices[:, None]), tf.float32)
+ cumsum_matrix = tf.reshape(
+ tf.cumsum(tf.reshape(candidate_matrix, [-1])), [-1, self._seq_length])
+ masked_matrix = tf.cast(cumsum_matrix <= self._max_predictions_per_seq,
+ tf.float32)
+ target_mask = tf.reduce_sum(candidate_matrix * masked_matrix, axis=0)
+ return tf.cast(target_mask, tf.bool)
+
+ def _single_token_mask(self, inputs: tf.Tensor) -> tf.Tensor:
+ """Samples individual tokens as prediction targets."""
+ all_indices = tf.range(self._seq_length, dtype=tf.int32)
+ non_func_mask = tf.logical_and(
+ tf.not_equal(inputs, self._sep_id), tf.not_equal(inputs, self._cls_id))
+ non_func_indices = tf.boolean_mask(all_indices, non_func_mask)
+
+ masked_pos = tf.random.shuffle(non_func_indices)
+ masked_pos = tf.sort(masked_pos[:self._max_predictions_per_seq])
+
+ sparse_indices = tf.stack([tf.zeros_like(masked_pos), masked_pos], axis=-1)
+ sparse_indices = tf.cast(sparse_indices, tf.int64)
+
+ sparse_indices = tf.sparse.SparseTensor(
+ sparse_indices,
+ values=tf.ones_like(masked_pos),
+ dense_shape=(1, self._seq_length))
+
+ target_mask = tf.sparse.to_dense(sp_input=sparse_indices, default_value=0)
+
+ return tf.squeeze(tf.cast(target_mask, tf.bool))
+
+ def _whole_word_mask(self, inputs: tf.Tensor,
+ boundary: tf.Tensor) -> tf.Tensor:
+ """Samples whole words as prediction targets."""
+ pair_indices = tf.concat([boundary[:-1, None], boundary[1:, None]], axis=1)
+ cand_pair_indices = tf.random.shuffle(
+ pair_indices)[:self._max_predictions_per_seq]
+ begin_indices = cand_pair_indices[:, 0]
+ end_indices = cand_pair_indices[:, 1]
+
+ return self._index_pair_to_mask(
+ begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)
+
+ def _token_span_mask(self, inputs: tf.Tensor) -> tf.Tensor:
+ """Samples token spans as prediction targets."""
+ min_num_tokens = self._params.min_num_tokens
+ max_num_tokens = self._params.max_num_tokens
+
+ mask_alpha = self._seq_length / self._max_predictions_per_seq
+ round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)
+
+ # Sample span lengths from a zipf distribution
+ span_len_seq = np.arange(min_num_tokens, max_num_tokens + 1)
+ probs = np.array([1.0 / (i + 1) for i in span_len_seq])
+
+ probs /= np.sum(probs)
+ logits = tf.constant(np.log(probs), dtype=tf.float32)
+ span_lens = tf.random.categorical(
+ logits=logits[None],
+ num_samples=self._max_predictions_per_seq,
+ dtype=tf.int32,
+ )[0] + min_num_tokens
+
+ # Sample the ratio [0.0, 1.0) of left context lengths
+ span_lens_float = tf.cast(span_lens, tf.float32)
+ left_ratio = tf.random.uniform(
+ shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
+ left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
+ left_ctx_len = round_to_int(left_ctx_len)
+
+ # Compute the offset from left start to the right end
+ right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
+
+ # Get the actual begin and end indices
+ begin_indices = (
+ tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
+ end_indices = begin_indices + span_lens
+
+ # Remove out of range indices
+ valid_idx_mask = end_indices < self._seq_length
+ begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
+ end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
+
+ # Shuffle valid indices
+ num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
+ order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
+ begin_indices = tf.gather(begin_indices, order)
+ end_indices = tf.gather(end_indices, order)
+
+ return self._index_pair_to_mask(
+ begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)
+
+ def _word_span_mask(self, inputs: tf.Tensor, boundary: tf.Tensor):
+ """Sample whole word spans as prediction targets."""
+ min_num_words = self._params.min_num_words
+ max_num_words = self._params.max_num_words
+
+ # Note: 1.2 is the token-to-word ratio
+ mask_alpha = self._seq_length / self._max_predictions_per_seq / 1.2
+ round_to_int = lambda x: tf.cast(tf.round(x), tf.int32)
+
+ # Sample span lengths from a zipf distribution
+ span_len_seq = np.arange(min_num_words, max_num_words + 1)
+ probs = np.array([1.0 / (i + 1) for i in span_len_seq])
+ probs /= np.sum(probs)
+ logits = tf.constant(np.log(probs), dtype=tf.float32)
+
+ # Sample `num_predict` words here: note that this is over sampling
+ span_lens = tf.random.categorical(
+ logits=logits[None],
+ num_samples=self._max_predictions_per_seq,
+ dtype=tf.int32,
+ )[0] + min_num_words
+
+ # Sample the ratio [0.0, 1.0) of left context lengths
+ span_lens_float = tf.cast(span_lens, tf.float32)
+ left_ratio = tf.random.uniform(
+ shape=[self._max_predictions_per_seq], minval=0.0, maxval=1.0)
+ left_ctx_len = left_ratio * span_lens_float * (mask_alpha - 1)
+
+ left_ctx_len = round_to_int(left_ctx_len)
+ right_offset = round_to_int(span_lens_float * mask_alpha) - left_ctx_len
+
+ begin_indices = (
+ tf.cumsum(left_ctx_len) + tf.cumsum(right_offset, exclusive=True))
+ end_indices = begin_indices + span_lens
+
+ # Remove out of range indices
+ max_boundary_index = tf.cast(tf.shape(boundary)[0] - 1, tf.int32)
+ valid_idx_mask = end_indices < max_boundary_index
+ begin_indices = tf.boolean_mask(begin_indices, valid_idx_mask)
+ end_indices = tf.boolean_mask(end_indices, valid_idx_mask)
+
+ begin_indices = tf.gather(boundary, begin_indices)
+ end_indices = tf.gather(boundary, end_indices)
+
+ # Shuffle valid indices
+ num_valid = tf.cast(tf.shape(begin_indices)[0], tf.int32)
+ order = tf.random.shuffle(tf.range(num_valid, dtype=tf.int32))
+ begin_indices = tf.gather(begin_indices, order)
+ end_indices = tf.gather(end_indices, order)
+
+ return self._index_pair_to_mask(
+ begin_indices=begin_indices, end_indices=end_indices, inputs=inputs)
+
+ def _online_sample_mask(self, inputs: tf.Tensor,
+ boundary: tf.Tensor) -> tf.Tensor:
+ """Samples target positions for predictions.
+
+ Descriptions of each strategy:
+ - 'single_token': Samples individual tokens as prediction targets.
+ - 'token_span': Samples spans of tokens as prediction targets.
+ - 'whole_word': Samples individual words as prediction targets.
+ - 'word_span': Samples spans of words as prediction targets.
+
+ Args:
+ inputs: The input tokens.
+ boundary: The `int` Tensor of indices indicating whole word boundaries.
+ This is used in 'whole_word' and 'word_span'
+
+ Returns:
+ The sampled `bool` input mask.
+
+ Raises:
+ `ValueError`: if `max_predictions_per_seq` is not set or if boundary is
+ not provided for 'whole_word' and 'word_span' sample strategies.
+ """
+ if self._max_predictions_per_seq is None:
+ raise ValueError('`max_predictions_per_seq` must be set.')
+
+ if boundary is None and 'word' in self._sample_strategy:
+ raise ValueError('`boundary` must be provided for {} strategy'.format(
+ self._sample_strategy))
+
+ if self._sample_strategy == 'single_token':
+ return self._single_token_mask(inputs)
+ elif self._sample_strategy == 'token_span':
+ return self._token_span_mask(inputs)
+ elif self._sample_strategy == 'whole_word':
+ return self._whole_word_mask(inputs, boundary)
+ elif self._sample_strategy == 'word_span':
+ return self._word_span_mask(inputs, boundary)
+ else:
+ raise NotImplementedError('Invalid sample strategy.')
+
+ def _get_factorization(self, inputs: tf.Tensor, input_mask: tf.Tensor):
+ """Samples a permutation of the factorization order.
+
+ Args:
+ inputs: the input tokens.
+ input_mask: the `bool` Tensor of the same shape as `inputs`. If `True`,
+ then this means select for partial prediction.
+
+ Returns:
+ perm_mask: An `int32` Tensor of shape [seq_length, seq_length] consisting
+ of 0s and 1s. If perm_mask[i][j] == 0, then this means that the i-th
+ token (in original order) cannot attend to the jth attention token.
+ target_mask: An `int32` Tensor of shape [seq_len] consisting of 0s and 1s.
+ If target_mask[i] == 1, then the i-th token needs to be predicted and
+ the mask will be used as input. This token will be included in the loss.
+ If target_mask[i] == 0, then the token (or [SEP], [CLS]) will be used as
+ input. This token will not be included in the loss.
+ tokens: int32 Tensor of shape [seq_length].
+ masked_tokens: int32 Tensor of shape [seq_length].
+ """
+ factorization_length = tf.shape(inputs)[0]
+ # Generate permutation indices
+ index = tf.range(factorization_length, dtype=tf.int32)
+ index = tf.transpose(tf.reshape(index, [-1, self._permutation_size]))
+ index = tf.random.shuffle(index)
+ index = tf.reshape(tf.transpose(index), [-1])
+
+ input_mask = tf.cast(input_mask, tf.bool)
+
+ # non-functional tokens
+ non_func_tokens = tf.logical_not(
+ tf.logical_or(
+ tf.equal(inputs, self._sep_id), tf.equal(inputs, self._cls_id)))
+ masked_tokens = tf.logical_and(input_mask, non_func_tokens)
+ non_masked_or_func_tokens = tf.logical_not(masked_tokens)
+
+ smallest_index = -2 * tf.ones([factorization_length], dtype=tf.int32)
+
+ # Similar to BERT, randomly leak some masked tokens
+ if self._leak_ratio > 0:
+ leak_tokens = tf.logical_and(
+ masked_tokens,
+ tf.random.uniform([factorization_length], maxval=1.0) <
+ self._leak_ratio)
+ can_attend_self = tf.logical_or(non_masked_or_func_tokens, leak_tokens)
+ else:
+ can_attend_self = non_masked_or_func_tokens
+ to_index = tf.where(can_attend_self, smallest_index, index)
+ from_index = tf.where(can_attend_self, to_index + 1, to_index)
+
+ # For masked tokens, can attend if i > j
+ # For context tokens, always can attend each other
+ can_attend = from_index[:, None] > to_index[None, :]
+
+ perm_mask = tf.cast(can_attend, tf.int32)
+
+ # Only masked tokens are included in the loss
+ target_mask = tf.cast(masked_tokens, tf.int32)
+
+ return perm_mask, target_mask, inputs, masked_tokens
+
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
+ """Returns a tf.dataset.Dataset."""
+ if input_context:
+ self._num_replicas_in_sync = input_context.num_replicas_in_sync
+ reader = input_reader.InputReader(
+ params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
+ return reader.read(input_context)
diff --git a/modeling/official/nlp/data/pretrain_dataloader_test.py b/modeling/official/nlp/data/pretrain_dataloader_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebefbfb2fbf8651568c353a9b199ad6dfb70c132
--- /dev/null
+++ b/modeling/official/nlp/data/pretrain_dataloader_test.py
@@ -0,0 +1,242 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for official.nlp.data.pretrain_dataloader."""
+import itertools
+import os
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from official.nlp.data import pretrain_dataloader
+
+
+def create_int_feature(values):
+ f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
+ return f
+
+
+def _create_fake_bert_dataset(
+ output_path,
+ seq_length,
+ max_predictions_per_seq,
+ use_position_id,
+ use_next_sentence_label,
+ use_v2_feature_names=False):
+ """Creates a fake dataset."""
+ writer = tf.io.TFRecordWriter(output_path)
+
+ def create_float_feature(values):
+ f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
+ return f
+
+ for _ in range(100):
+ features = {}
+ input_ids = np.random.randint(100, size=(seq_length))
+ features["input_mask"] = create_int_feature(np.ones_like(input_ids))
+ if use_v2_feature_names:
+ features["input_word_ids"] = create_int_feature(input_ids)
+ features["input_type_ids"] = create_int_feature(np.ones_like(input_ids))
+ else:
+ features["input_ids"] = create_int_feature(input_ids)
+ features["segment_ids"] = create_int_feature(np.ones_like(input_ids))
+
+ features["masked_lm_positions"] = create_int_feature(
+ np.random.randint(100, size=(max_predictions_per_seq)))
+ features["masked_lm_ids"] = create_int_feature(
+ np.random.randint(100, size=(max_predictions_per_seq)))
+ features["masked_lm_weights"] = create_float_feature(
+ [1.0] * max_predictions_per_seq)
+
+ if use_next_sentence_label:
+ features["next_sentence_labels"] = create_int_feature([1])
+
+ if use_position_id:
+ features["position_ids"] = create_int_feature(range(0, seq_length))
+
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+ writer.write(tf_example.SerializeToString())
+ writer.close()
+
+
+def _create_fake_xlnet_dataset(
+ output_path, seq_length, max_predictions_per_seq):
+ """Creates a fake dataset."""
+ writer = tf.io.TFRecordWriter(output_path)
+ for _ in range(100):
+ features = {}
+ input_ids = np.random.randint(100, size=(seq_length))
+ num_boundary_indices = np.random.randint(1, seq_length)
+
+ if max_predictions_per_seq is not None:
+ input_mask = np.zeros_like(input_ids)
+ input_mask[:max_predictions_per_seq] = 1
+ np.random.shuffle(input_mask)
+ else:
+ input_mask = np.ones_like(input_ids)
+
+ features["input_mask"] = create_int_feature(input_mask)
+ features["input_word_ids"] = create_int_feature(input_ids)
+ features["input_type_ids"] = create_int_feature(np.ones_like(input_ids))
+ features["boundary_indices"] = create_int_feature(
+ sorted(np.random.randint(seq_length, size=(num_boundary_indices))))
+ features["target"] = create_int_feature(input_ids + 1)
+ features["label"] = create_int_feature([1])
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+ writer.write(tf_example.SerializeToString())
+ writer.close()
+
+
+class BertPretrainDataTest(tf.test.TestCase, parameterized.TestCase):
+
+ @parameterized.parameters(itertools.product(
+ (False, True),
+ (False, True),
+ ))
+ def test_load_data(self, use_next_sentence_label, use_position_id):
+ train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
+ seq_length = 128
+ max_predictions_per_seq = 20
+ _create_fake_bert_dataset(
+ train_data_path,
+ seq_length,
+ max_predictions_per_seq,
+ use_next_sentence_label=use_next_sentence_label,
+ use_position_id=use_position_id)
+ data_config = pretrain_dataloader.BertPretrainDataConfig(
+ input_path=train_data_path,
+ max_predictions_per_seq=max_predictions_per_seq,
+ seq_length=seq_length,
+ global_batch_size=10,
+ is_training=True,
+ use_next_sentence_label=use_next_sentence_label,
+ use_position_id=use_position_id)
+
+ dataset = pretrain_dataloader.BertPretrainDataLoader(data_config).load()
+ features = next(iter(dataset))
+ self.assertLen(features,
+ 6 + int(use_next_sentence_label) + int(use_position_id))
+ self.assertIn("input_word_ids", features)
+ self.assertIn("input_mask", features)
+ self.assertIn("input_type_ids", features)
+ self.assertIn("masked_lm_positions", features)
+ self.assertIn("masked_lm_ids", features)
+ self.assertIn("masked_lm_weights", features)
+
+ self.assertEqual("next_sentence_labels" in features,
+ use_next_sentence_label)
+ self.assertEqual("position_ids" in features, use_position_id)
+
+ def test_v2_feature_names(self):
+ train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
+ seq_length = 128
+ max_predictions_per_seq = 20
+ _create_fake_bert_dataset(
+ train_data_path,
+ seq_length,
+ max_predictions_per_seq,
+ use_next_sentence_label=True,
+ use_position_id=False,
+ use_v2_feature_names=True)
+ data_config = pretrain_dataloader.BertPretrainDataConfig(
+ input_path=train_data_path,
+ max_predictions_per_seq=max_predictions_per_seq,
+ seq_length=seq_length,
+ global_batch_size=10,
+ is_training=True,
+ use_next_sentence_label=True,
+ use_position_id=False,
+ use_v2_feature_names=True)
+
+ dataset = pretrain_dataloader.BertPretrainDataLoader(data_config).load()
+ features = next(iter(dataset))
+ self.assertIn("input_word_ids", features)
+ self.assertIn("input_mask", features)
+ self.assertIn("input_type_ids", features)
+ self.assertIn("masked_lm_positions", features)
+ self.assertIn("masked_lm_ids", features)
+ self.assertIn("masked_lm_weights", features)
+
+
+class XLNetPretrainDataTest(parameterized.TestCase, tf.test.TestCase):
+
+ @parameterized.parameters(itertools.product(
+ ("single_token", "whole_word", "token_span"),
+ (0, 64),
+ (20, None),
+ ))
+ def test_load_data(
+ self, sample_strategy, reuse_length, max_predictions_per_seq):
+ train_data_path = os.path.join(self.get_temp_dir(), "train.tf_record")
+ seq_length = 128
+ batch_size = 5
+
+ _create_fake_xlnet_dataset(
+ train_data_path, seq_length, max_predictions_per_seq)
+
+ data_config = pretrain_dataloader.XLNetPretrainDataConfig(
+ input_path=train_data_path,
+ max_predictions_per_seq=max_predictions_per_seq,
+ seq_length=seq_length,
+ global_batch_size=batch_size,
+ is_training=True,
+ reuse_length=reuse_length,
+ sample_strategy=sample_strategy,
+ min_num_tokens=1,
+ max_num_tokens=2,
+ permutation_size=seq_length // 2,
+ leak_ratio=0.1)
+
+ if max_predictions_per_seq is None:
+ with self.assertRaises(ValueError):
+ dataset = pretrain_dataloader.XLNetPretrainDataLoader(
+ data_config).load()
+ features = next(iter(dataset))
+ else:
+ dataset = pretrain_dataloader.XLNetPretrainDataLoader(data_config).load()
+ features = next(iter(dataset))
+
+ self.assertIn("input_word_ids", features)
+ self.assertIn("input_type_ids", features)
+ self.assertIn("permutation_mask", features)
+ self.assertIn("masked_tokens", features)
+ self.assertIn("target", features)
+ self.assertIn("target_mask", features)
+
+ self.assertAllClose(features["input_word_ids"].shape,
+ (batch_size, seq_length))
+ self.assertAllClose(features["input_type_ids"].shape,
+ (batch_size, seq_length))
+ self.assertAllClose(features["permutation_mask"].shape,
+ (batch_size, seq_length, seq_length))
+ self.assertAllClose(features["masked_tokens"].shape,
+ (batch_size, seq_length,))
+ if max_predictions_per_seq is not None:
+ self.assertIn("target_mapping", features)
+ self.assertAllClose(features["target_mapping"].shape,
+ (batch_size, max_predictions_per_seq, seq_length))
+ self.assertAllClose(features["target_mask"].shape,
+ (batch_size, max_predictions_per_seq))
+ self.assertAllClose(features["target"].shape,
+ (batch_size, max_predictions_per_seq))
+ else:
+ self.assertAllClose(features["target_mask"].shape,
+ (batch_size, seq_length))
+ self.assertAllClose(features["target"].shape,
+ (batch_size, seq_length))
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/modeling/official/nlp/data/pretrain_dynamic_dataloader.py b/modeling/official/nlp/data/pretrain_dynamic_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..2168c34f36f33436e507e4233a09a32895789493
--- /dev/null
+++ b/modeling/official/nlp/data/pretrain_dynamic_dataloader.py
@@ -0,0 +1,223 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Dataset loader for the pre-training with dynamic sequence length."""
+from typing import Optional, Tuple
+
+import dataclasses
+import tensorflow as tf, tf_keras
+
+from official.core import config_definitions as cfg
+from official.core import input_reader
+from official.nlp.data import data_loader_factory
+from official.nlp.data import pretrain_dataloader
+
+
+@dataclasses.dataclass
+class BertPretrainDataConfig(cfg.DataConfig):
+ """Data config for BERT pretraining task (tasks/masked_lm)."""
+ input_path: str = ''
+ global_batch_size: int = 512
+ is_training: bool = True
+ seq_bucket_lengths: Tuple[int, ...] = (128, 256, 384, 512,)
+ # TODO(rxsang): `seq_bucket_window_scale` is only useful when round robin
+ # tf.data service is disabled. Deprecate this flag once we always enable round
+ # robin tf.data service.
+ seq_bucket_window_scale: int = 8
+ use_next_sentence_label: bool = True
+ use_position_id: bool = False
+ deterministic: bool = False
+ enable_tf_data_service: bool = False
+ enable_round_robin_tf_data_service: bool = False
+ tf_data_service_job_name: str = 'bert_pretrain'
+ use_v2_feature_names: bool = False
+
+
+@data_loader_factory.register_data_loader_cls(BertPretrainDataConfig)
+class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader):
+ """Dataset loader for bert-style pretraining with dynamic sequenece length.
+
+ Bucketizes the input id features by the seq_bucket_lengths and features are
+ padded to the bucket boundaries. The mask features are usually short than
+ input id features and can also be dynamic. We require the mask feature lengths
+ within a bucket must be the same. For example, with [128, 256] buckets,
+ the mask features for bucket 128 should always have the length as X and
+ features for bucket 256 should always have the length as Y.
+
+ The dataloader does not filter out empty masks. Make sure to handle this
+ in the model.
+ """
+
+ def __init__(self, params):
+ self._params = params
+ if len(params.seq_bucket_lengths) < 1:
+ raise ValueError('The seq_bucket_lengths cannot be empty.')
+ self._seq_bucket_lengths = params.seq_bucket_lengths
+ self._seq_bucket_window_scale = params.seq_bucket_window_scale
+ self._global_batch_size = params.global_batch_size
+ self._use_next_sentence_label = params.use_next_sentence_label
+ self._use_position_id = params.use_position_id
+ self._drop_remainder = params.drop_remainder
+ self._enable_tf_data_service = params.enable_tf_data_service
+ self._enable_round_robin_tf_data_service = (
+ params.enable_round_robin_tf_data_service)
+ self._mask_keys = [
+ 'masked_lm_positions', 'masked_lm_ids', 'masked_lm_weights'
+ ]
+
+ def _decode(self, record: tf.Tensor):
+ """Decodes a serialized tf.Example."""
+ name_to_features = {
+ 'input_mask': tf.io.VarLenFeature(tf.int64),
+ 'masked_lm_positions': tf.io.VarLenFeature(tf.int64),
+ 'masked_lm_ids': tf.io.VarLenFeature(tf.int64),
+ 'masked_lm_weights': tf.io.VarLenFeature(tf.float32),
+ }
+ if self._params.use_v2_feature_names:
+ input_ids_key = 'input_word_ids'
+ segment_key = 'input_type_ids'
+ name_to_features.update({
+ input_ids_key: tf.io.VarLenFeature(tf.int64),
+ segment_key: tf.io.VarLenFeature(tf.int64),
+ })
+ else:
+ input_ids_key = 'input_ids'
+ segment_key = 'segment_ids'
+ name_to_features.update({
+ input_ids_key: tf.io.VarLenFeature(tf.int64),
+ segment_key: tf.io.VarLenFeature(tf.int64),
+ })
+ if self._use_next_sentence_label:
+ name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
+ tf.int64)
+ dynamic_keys = [input_ids_key, 'input_mask', segment_key]
+ if self._use_position_id:
+ name_to_features['position_ids'] = tf.io.VarLenFeature(tf.int64)
+ dynamic_keys.append('position_ids')
+
+ example = tf.io.parse_single_example(record, name_to_features)
+ for key in dynamic_keys + self._mask_keys:
+ example[key] = tf.sparse.to_dense(example[key])
+
+ # Truncate padded data after the first non pad in the
+ # sequence length dimension.
+ # Pad before the first non pad from the back should not be removed.
+ mask = tf.math.greater(
+ tf.math.cumsum(example[input_ids_key], reverse=True), 0)
+ for key in dynamic_keys:
+ example[key] = tf.boolean_mask(example[key], mask)
+
+ # masked_lm_ids should be 0 padded.
+ # Change mask features to -1 padding so that we can differentiate
+ # padding from data or from bucketizing.
+ mask = tf.math.not_equal(example['masked_lm_ids'], 0)
+ example['masked_lm_ids'] = tf.where(
+ mask, example['masked_lm_ids'],
+ -tf.ones(tf.shape(example['masked_lm_ids']), dtype=example[key].dtype))
+
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+ # So cast all int64 to int32.
+ # tf.data service uses dataset graph fingerprint to distinguish input
+ # pipeline jobs, thus we sort the keys here to make sure they are generated
+ # in a deterministic order each time the dataset function is traced.
+ for name in sorted(list(example.keys())):
+ t = example[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ example[name] = t
+
+ return example
+
+ def _bucketize_and_batch(
+ self,
+ dataset,
+ input_context: Optional[tf.distribute.InputContext] = None):
+ """Bucketize by sequence length and batch the datasets."""
+ per_replica_batch_size = input_context.get_per_replica_batch_size(
+ self._global_batch_size) if input_context else self._global_batch_size
+
+ def element_length_func(example, seq_len_dim):
+ return tf.shape(example['input_word_ids'])[seq_len_dim]
+
+ bucket_boundaries = [length + 1 for length in self._seq_bucket_lengths]
+ bucket_batch_sizes = [per_replica_batch_size] * (len(bucket_boundaries) + 1)
+
+ # Bucketize and batch the dataset with per replica batch size first.
+ dataset = dataset.apply(
+ tf.data.experimental.bucket_by_sequence_length(
+ lambda example: tf.cast(element_length_func(example, 0), tf.int32),
+ bucket_boundaries,
+ bucket_batch_sizes,
+ pad_to_bucket_boundary=True,
+ drop_remainder=self._drop_remainder))
+ if input_context:
+ window_size = input_context.num_replicas_in_sync
+ if self._enable_tf_data_service and (
+ not self._enable_round_robin_tf_data_service):
+ # If tf.data service is enabled but round-robin behavior is not enabled,
+ # different TPU workers may fetch data from one tf.data service worker
+ # in different speed. We set the window size to be
+ # `seq_bucket_window_scale` larger to leave buffer if some workers are
+ # fetching data faster than others, so all the data within the same
+ # global batch can still have more chances to be in the same bucket.
+ window_size *= self._seq_bucket_window_scale
+
+ # Group `num_replicas_in_sync` batches from same bucket together, so all
+ # replicas can get the same sequence length for one global step.
+ dataset = dataset.apply(
+ tf.data.experimental.group_by_window(
+ key_func=lambda example: tf.cast( # pylint: disable=g-long-lambda
+ element_length_func(example, 1), tf.int64),
+ reduce_func=lambda _, x: tf.data.Dataset.from_tensors(x),
+ window_size=window_size))
+ dataset = dataset.flat_map(lambda x: x)
+
+ def _remove_pads_from_bucketize(features):
+ # All mask features must have the same effective length.
+ # The real masked ids padding token is -1 and 0 comes from
+ # bucket_by_sequence_length.
+ mask = tf.math.not_equal(features['masked_lm_ids'], 0)
+
+ mask_per_example = tf.math.reduce_sum(tf.cast(mask, tf.int32), axis=1)
+ normalized = tf.cast(
+ mask_per_example / tf.math.reduce_max(mask_per_example), tf.int32)
+ assert_op = tf.debugging.assert_equal(
+ tf.math.reduce_sum(normalized), per_replica_batch_size,
+ 'Number of non padded mask tokens is not the same for each example '
+ 'in the same sequence length.')
+ with tf.control_dependencies([assert_op]):
+ for key in self._mask_keys:
+ features[key] = tf.reshape(
+ tf.boolean_mask(
+ features[key], mask), [per_replica_batch_size, -1])
+ # Revert masked_lm_ids to be 0-padded.
+ mask = tf.math.not_equal(features['masked_lm_ids'], -1)
+ features['masked_lm_ids'] = tf.where(
+ mask, features['masked_lm_ids'],
+ tf.zeros(
+ tf.shape(features['masked_lm_ids']),
+ dtype=features['masked_lm_ids'].dtype))
+ return features
+
+ dataset = dataset.map(_remove_pads_from_bucketize)
+ return dataset
+
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
+ """Returns a tf.dataset.Dataset."""
+ reader = input_reader.InputReader(
+ params=self._params,
+ decoder_fn=self._decode,
+ parser_fn=self._parse,
+ transform_and_batch_fn=self._bucketize_and_batch)
+ return reader.read(input_context)
diff --git a/modeling/official/nlp/data/pretrain_dynamic_dataloader_test.py b/modeling/official/nlp/data/pretrain_dynamic_dataloader_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e559ea7e36fbbb0b7570bcf3c4282d197d8b10b
--- /dev/null
+++ b/modeling/official/nlp/data/pretrain_dynamic_dataloader_test.py
@@ -0,0 +1,245 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for nlp.data.pretrain_dynamic_dataloader."""
+import os
+
+from absl import logging
+from absl.testing import parameterized
+import numpy as np
+import orbit
+import tensorflow as tf, tf_keras
+
+from tensorflow.python.distribute import combinations
+from tensorflow.python.distribute import strategy_combinations
+from official.nlp.configs import bert
+from official.nlp.configs import encoders
+from official.nlp.data import pretrain_dataloader
+from official.nlp.data import pretrain_dynamic_dataloader
+from official.nlp.tasks import masked_lm
+
+
+def _create_fake_dataset(output_path, seq_length, num_masked_tokens,
+ max_seq_length, num_examples):
+ """Creates a fake dataset."""
+ writer = tf.io.TFRecordWriter(output_path)
+
+ def create_int_feature(values):
+ f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
+ return f
+
+ def create_float_feature(values):
+ f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
+ return f
+
+ rng = np.random.default_rng(37)
+ for _ in range(num_examples):
+ features = {}
+ padding = np.zeros(shape=(max_seq_length - seq_length), dtype=np.int32)
+ input_ids = rng.integers(low=1, high=100, size=(seq_length))
+ features['input_ids'] = create_int_feature(
+ np.concatenate((input_ids, padding)))
+ features['input_mask'] = create_int_feature(
+ np.concatenate((np.ones_like(input_ids), padding)))
+ features['segment_ids'] = create_int_feature(
+ np.concatenate((np.ones_like(input_ids), padding)))
+ features['position_ids'] = create_int_feature(
+ np.concatenate((np.ones_like(input_ids), padding)))
+ features['masked_lm_positions'] = create_int_feature(
+ rng.integers(60, size=(num_masked_tokens), dtype=np.int64))
+ features['masked_lm_ids'] = create_int_feature(
+ rng.integers(100, size=(num_masked_tokens), dtype=np.int64))
+ features['masked_lm_weights'] = create_float_feature(
+ np.ones((num_masked_tokens,), dtype=np.float32))
+ features['next_sentence_labels'] = create_int_feature(np.array([0]))
+
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+ writer.write(tf_example.SerializeToString())
+ writer.close()
+
+
+class PretrainDynamicDataLoaderTest(tf.test.TestCase, parameterized.TestCase):
+
+ @combinations.generate(
+ combinations.combine(
+ distribution_strategy=[
+ strategy_combinations.cloud_tpu_strategy,
+ ],
+ mode='eager'))
+ def test_distribution_strategy(self, distribution_strategy):
+ max_seq_length = 128
+ batch_size = 8
+ input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
+ _create_fake_dataset(
+ input_path,
+ seq_length=60,
+ num_masked_tokens=20,
+ max_seq_length=max_seq_length,
+ num_examples=batch_size)
+ data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
+ is_training=False,
+ input_path=input_path,
+ seq_bucket_lengths=[64, 128],
+ global_batch_size=batch_size)
+ dataloader = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
+ data_config)
+ distributed_ds = orbit.utils.make_distributed_dataset(
+ distribution_strategy, dataloader.load)
+ train_iter = iter(distributed_ds)
+ with distribution_strategy.scope():
+ config = masked_lm.MaskedLMConfig(
+ init_checkpoint=self.get_temp_dir(),
+ model=bert.PretrainerConfig(
+ encoders.EncoderConfig(
+ bert=encoders.BertEncoderConfig(
+ vocab_size=30522, num_layers=1)),
+ cls_heads=[
+ bert.ClsHeadConfig(
+ inner_dim=10, num_classes=2, name='next_sentence')
+ ]),
+ train_data=data_config)
+ task = masked_lm.MaskedLMTask(config)
+ model = task.build_model()
+ metrics = task.build_metrics()
+
+ @tf.function
+ def step_fn(features):
+ return task.validation_step(features, model, metrics=metrics)
+
+ distributed_outputs = distribution_strategy.run(
+ step_fn, args=(next(train_iter),))
+ local_results = tf.nest.map_structure(
+ distribution_strategy.experimental_local_results, distributed_outputs)
+ logging.info('Dynamic padding: local_results= %s', str(local_results))
+ dynamic_metrics = {}
+ for metric in metrics:
+ dynamic_metrics[metric.name] = metric.result()
+
+ data_config = pretrain_dataloader.BertPretrainDataConfig(
+ is_training=False,
+ input_path=input_path,
+ seq_length=max_seq_length,
+ max_predictions_per_seq=20,
+ global_batch_size=batch_size)
+ dataloader = pretrain_dataloader.BertPretrainDataLoader(data_config)
+ distributed_ds = orbit.utils.make_distributed_dataset(
+ distribution_strategy, dataloader.load)
+ train_iter = iter(distributed_ds)
+ with distribution_strategy.scope():
+ metrics = task.build_metrics()
+
+ @tf.function
+ def step_fn_b(features):
+ return task.validation_step(features, model, metrics=metrics)
+
+ distributed_outputs = distribution_strategy.run(
+ step_fn_b, args=(next(train_iter),))
+ local_results = tf.nest.map_structure(
+ distribution_strategy.experimental_local_results, distributed_outputs)
+ logging.info('Static padding: local_results= %s', str(local_results))
+ static_metrics = {}
+ for metric in metrics:
+ static_metrics[metric.name] = metric.result()
+ for key in static_metrics:
+ # We need to investigate the differences on losses.
+ if key != 'next_sentence_loss':
+ self.assertEqual(dynamic_metrics[key], static_metrics[key])
+
+ def test_load_dataset(self):
+ tf.random.set_seed(0)
+ max_seq_length = 128
+ batch_size = 2
+ input_path_1 = os.path.join(self.get_temp_dir(), 'train_1.tf_record')
+ _create_fake_dataset(
+ input_path_1,
+ seq_length=60,
+ num_masked_tokens=20,
+ max_seq_length=max_seq_length,
+ num_examples=batch_size)
+ input_path_2 = os.path.join(self.get_temp_dir(), 'train_2.tf_record')
+ _create_fake_dataset(
+ input_path_2,
+ seq_length=100,
+ num_masked_tokens=70,
+ max_seq_length=max_seq_length,
+ num_examples=batch_size)
+ input_paths = ','.join([input_path_1, input_path_2])
+ data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
+ is_training=False,
+ input_path=input_paths,
+ seq_bucket_lengths=[64, 128],
+ use_position_id=True,
+ global_batch_size=batch_size,
+ deterministic=True)
+ dataset = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
+ data_config).load()
+ dataset_it = iter(dataset)
+ features = next(dataset_it)
+ self.assertCountEqual([
+ 'input_word_ids',
+ 'input_mask',
+ 'input_type_ids',
+ 'next_sentence_labels',
+ 'masked_lm_positions',
+ 'masked_lm_ids',
+ 'masked_lm_weights',
+ 'position_ids',
+ ], features.keys())
+ # Sequence length dimension should be bucketized and pad to 64.
+ self.assertEqual(features['input_word_ids'].shape, (batch_size, 64))
+ self.assertEqual(features['input_mask'].shape, (batch_size, 64))
+ self.assertEqual(features['input_type_ids'].shape, (batch_size, 64))
+ self.assertEqual(features['position_ids'].shape, (batch_size, 64))
+ self.assertEqual(features['masked_lm_positions'].shape, (batch_size, 20))
+ features = next(dataset_it)
+ self.assertEqual(features['input_word_ids'].shape, (batch_size, 128))
+ self.assertEqual(features['input_mask'].shape, (batch_size, 128))
+ self.assertEqual(features['input_type_ids'].shape, (batch_size, 128))
+ self.assertEqual(features['position_ids'].shape, (batch_size, 128))
+ self.assertEqual(features['masked_lm_positions'].shape, (batch_size, 70))
+
+ def test_load_dataset_not_same_masks(self):
+ max_seq_length = 128
+ batch_size = 2
+ input_path_1 = os.path.join(self.get_temp_dir(), 'train_3.tf_record')
+ _create_fake_dataset(
+ input_path_1,
+ seq_length=60,
+ num_masked_tokens=20,
+ max_seq_length=max_seq_length,
+ num_examples=batch_size)
+ input_path_2 = os.path.join(self.get_temp_dir(), 'train_4.tf_record')
+ _create_fake_dataset(
+ input_path_2,
+ seq_length=60,
+ num_masked_tokens=15,
+ max_seq_length=max_seq_length,
+ num_examples=batch_size)
+ input_paths = ','.join([input_path_1, input_path_2])
+ data_config = pretrain_dynamic_dataloader.BertPretrainDataConfig(
+ is_training=False,
+ input_path=input_paths,
+ seq_bucket_lengths=[64, 128],
+ use_position_id=True,
+ global_batch_size=batch_size * 2)
+ dataset = pretrain_dynamic_dataloader.PretrainingDynamicDataLoader(
+ data_config).load()
+ dataset_it = iter(dataset)
+ with self.assertRaisesRegex(
+ tf.errors.InvalidArgumentError, '.*Number of non padded mask tokens.*'):
+ next(dataset_it)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/nlp/data/pretrain_text_dataloader.py b/modeling/official/nlp/data/pretrain_text_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c826b2ef153a1559993a1bf0c7e847925ae36e2
--- /dev/null
+++ b/modeling/official/nlp/data/pretrain_text_dataloader.py
@@ -0,0 +1,226 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Loads text dataset for the BERT pretraining task."""
+import dataclasses
+from typing import List, Mapping, Optional, Text
+
+import tensorflow as tf, tf_keras
+import tensorflow_text as tf_text
+
+from official.common import dataset_fn
+from official.core import config_definitions as cfg
+from official.core import input_reader
+from official.nlp.data import data_loader
+from official.nlp.data import data_loader_factory
+from official.nlp.modeling.ops import segment_extractor
+
+
+@dataclasses.dataclass
+class BertPretrainTextDataConfig(cfg.DataConfig):
+ """Data config for BERT pretraining task (tasks/masked_lm) from text."""
+ input_path: str = ""
+ doc_batch_size: int = 8
+ global_batch_size: int = 512
+ is_training: bool = True
+ seq_length: int = 512
+ max_predictions_per_seq: int = 76
+ use_next_sentence_label: bool = True
+ # The name of the text feature fields. The text features will be
+ # concatenated in order.
+ # Note: More than 1 field name is not compatible with NSP.
+ text_field_names: Optional[List[str]] = dataclasses.field(
+ default_factory=lambda: ["text"])
+ vocab_file_path: str = ""
+ masking_rate: float = 0.15
+ use_whole_word_masking: bool = False
+ file_type: str = "tfrecord"
+
+
+_CLS_TOKEN = b"[CLS]"
+_SEP_TOKEN = b"[SEP]"
+_MASK_TOKEN = b"[MASK]"
+_NUM_OOV_BUCKETS = 1
+# Accounts for [CLS] and 2 x [SEP] tokens
+_NUM_SPECIAL_TOKENS = 3
+
+
+@data_loader_factory.register_data_loader_cls(BertPretrainTextDataConfig)
+class BertPretrainTextDataLoader(data_loader.DataLoader):
+ """A class to load text dataset for BERT pretraining task."""
+
+ def __init__(self, params):
+ """Inits `BertPretrainTextDataLoader` class.
+
+ Args:
+ params: A `BertPretrainTextDataConfig` object.
+ """
+ if len(params.text_field_names) > 1 and params.use_next_sentence_label:
+ raise ValueError("Currently there is no support for more than text field "
+ "while generating next sentence labels.")
+
+ self._params = params
+ self._seq_length = params.seq_length
+ self._max_predictions_per_seq = params.max_predictions_per_seq
+ self._use_next_sentence_label = params.use_next_sentence_label
+ self._masking_rate = params.masking_rate
+ self._use_whole_word_masking = params.use_whole_word_masking
+
+ lookup_table_init = tf.lookup.TextFileInitializer(
+ params.vocab_file_path,
+ key_dtype=tf.string,
+ key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
+ value_dtype=tf.int64,
+ value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
+ self._vocab_lookup_table = tf.lookup.StaticVocabularyTable(
+ lookup_table_init,
+ num_oov_buckets=_NUM_OOV_BUCKETS,
+ lookup_key_dtype=tf.string)
+
+ self._cls_token = self._vocab_lookup_table.lookup(tf.constant(_CLS_TOKEN))
+ self._sep_token = self._vocab_lookup_table.lookup(tf.constant(_SEP_TOKEN))
+ self._mask_token = self._vocab_lookup_table.lookup(tf.constant(_MASK_TOKEN))
+
+ # -_NUM_OOV_BUCKETS to offset unused OOV bucket.
+ self._vocab_size = self._vocab_lookup_table.size() - _NUM_OOV_BUCKETS
+
+ def _decode(self, record: tf.Tensor) -> Mapping[Text, tf.Tensor]:
+ """Decodes a serialized tf.Example."""
+ name_to_features = {}
+ for text_field_name in self._params.text_field_names:
+ name_to_features[text_field_name] = tf.io.FixedLenFeature([], tf.string)
+ return tf.io.parse_single_example(record, name_to_features)
+
+ def _tokenize(self, segments):
+ """Tokenize the input segments."""
+ # Tokenize segments
+ tokenizer = tf_text.BertTokenizer(
+ self._vocab_lookup_table, token_out_type=tf.int64)
+
+ if self._use_whole_word_masking:
+ # tokenize the segments which should have the shape:
+ # [num_sentence, (num_words), (num_wordpieces)]
+ segments = [tokenizer.tokenize(s) for s in segments]
+ else:
+ # tokenize the segments and merge out the token dimension so that each
+ # segment has the shape: [num_sentence, (num_wordpieces)]
+ segments = [tokenizer.tokenize(s).merge_dims(-2, -1) for s in segments]
+
+ # Truncate inputs
+ trimmer = tf_text.WaterfallTrimmer(
+ self._seq_length - _NUM_SPECIAL_TOKENS, axis=-1)
+ truncated_segments = trimmer.trim(segments)
+
+ # Combine segments, get segment ids and add special tokens
+ return tf_text.combine_segments(
+ truncated_segments,
+ start_of_sequence_id=self._cls_token,
+ end_of_segment_id=self._sep_token)
+
+ def _bert_preprocess(self, record: Mapping[str, tf.Tensor]):
+ """Parses raw tensors into a dict of tensors to be consumed by the model."""
+ if self._use_next_sentence_label:
+ input_text = record[self._params.text_field_names[0]]
+ # Split sentences
+ sentence_breaker = tf_text.RegexSplitter()
+ sentences = sentence_breaker.split(input_text)
+
+ # Extract next-sentence-prediction labels and segments
+ next_or_random_segment, is_next = (
+ segment_extractor.get_next_sentence_labels(sentences))
+ # merge dims to change shape from [num_docs, (num_segments)] to
+ # [total_num_segments]
+ is_next = is_next.merge_dims(-2, -1)
+
+ # construct segments with shape [(num_sentence)]
+ segments = [
+ sentences.merge_dims(-2, -1),
+ next_or_random_segment.merge_dims(-2, -1)
+ ]
+ else:
+ segments = [record[name] for name in self._params.text_field_names]
+
+ segments_combined, segment_ids = self._tokenize(segments)
+
+ # Dynamic masking
+ item_selector = tf_text.RandomItemSelector(
+ self._max_predictions_per_seq,
+ selection_rate=self._masking_rate,
+ unselectable_ids=[self._cls_token, self._sep_token],
+ shuffle_fn=(tf.identity if self._params.deterministic else None))
+ values_chooser = tf_text.MaskValuesChooser(
+ vocab_size=self._vocab_size, mask_token=self._mask_token)
+ masked_input_ids, masked_lm_positions, masked_lm_ids = (
+ tf_text.mask_language_model(
+ segments_combined,
+ item_selector=item_selector,
+ mask_values_chooser=values_chooser,
+ ))
+
+ # Pad out to fixed shape and get input mask.
+ seq_lengths = {
+ "input_word_ids": self._seq_length,
+ "input_type_ids": self._seq_length,
+ "masked_lm_positions": self._max_predictions_per_seq,
+ "masked_lm_ids": self._max_predictions_per_seq,
+ }
+ model_inputs = {
+ "input_word_ids": masked_input_ids,
+ "input_type_ids": segment_ids,
+ "masked_lm_positions": masked_lm_positions,
+ "masked_lm_ids": masked_lm_ids,
+ }
+ padded_inputs_and_mask = tf.nest.map_structure(tf_text.pad_model_inputs,
+ model_inputs, seq_lengths)
+ model_inputs = {
+ k: padded_inputs_and_mask[k][0] for k in padded_inputs_and_mask
+ }
+ model_inputs["masked_lm_weights"] = tf.cast(
+ padded_inputs_and_mask["masked_lm_ids"][1], tf.float32)
+ model_inputs["input_mask"] = padded_inputs_and_mask["input_word_ids"][1]
+
+ if self._use_next_sentence_label:
+ model_inputs["next_sentence_labels"] = is_next
+
+ for name in model_inputs:
+ t = model_inputs[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ model_inputs[name] = t
+
+ return model_inputs
+
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
+ """Returns a tf.dataset.Dataset."""
+
+ def _batch_docs(dataset, input_context):
+ per_core_doc_batch_size = (
+ input_context.get_per_replica_batch_size(self._params.doc_batch_size)
+ if input_context else self._params.doc_batch_size)
+ return dataset.batch(per_core_doc_batch_size)
+
+ reader = input_reader.InputReader(
+ params=self._params,
+ dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
+ decoder_fn=self._decode if self._params.input_path else None,
+ transform_and_batch_fn=_batch_docs
+ if self._use_next_sentence_label else None,
+ postprocess_fn=self._bert_preprocess)
+ transformed_inputs = reader.read(input_context)
+ per_core_example_batch_size = (
+ input_context.get_per_replica_batch_size(self._params.global_batch_size)
+ if input_context else self._params.global_batch_size)
+ batched_inputs = transformed_inputs.unbatch().batch(
+ per_core_example_batch_size, self._params.drop_remainder)
+ return batched_inputs.prefetch(tf.data.experimental.AUTOTUNE)
diff --git a/modeling/official/nlp/data/question_answering_dataloader.py b/modeling/official/nlp/data/question_answering_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0d7df3774a7ea700869405eac765b7e346b7dc5
--- /dev/null
+++ b/modeling/official/nlp/data/question_answering_dataloader.py
@@ -0,0 +1,115 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Loads dataset for the question answering (e.g, SQuAD) task."""
+import dataclasses
+from typing import Mapping, Optional
+
+import tensorflow as tf, tf_keras
+from official.common import dataset_fn
+from official.core import config_definitions as cfg
+from official.core import input_reader
+from official.nlp.data import data_loader
+from official.nlp.data import data_loader_factory
+
+
+@dataclasses.dataclass
+class QADataConfig(cfg.DataConfig):
+ """Data config for question answering task (tasks/question_answering)."""
+ # For training, `input_path` is expected to be a pre-processed TFRecord file,
+ # while for evaluation, it is expected to be a raw JSON file (b/173814590).
+ input_path: str = ''
+ global_batch_size: int = 48
+ is_training: bool = True
+ seq_length: int = 384
+ # Settings below are question answering specific.
+ version_2_with_negative: bool = False
+ # Settings below are only used for eval mode.
+ input_preprocessed_data_path: str = ''
+ doc_stride: int = 128
+ query_length: int = 64
+ # The path to the vocab file of word piece tokenizer or the
+ # model of the sentence piece tokenizer.
+ vocab_file: str = ''
+ tokenization: str = 'WordPiece' # WordPiece or SentencePiece
+ do_lower_case: bool = True
+ xlnet_format: bool = False
+ file_type: str = 'tfrecord'
+
+
+@data_loader_factory.register_data_loader_cls(QADataConfig)
+class QuestionAnsweringDataLoader(data_loader.DataLoader):
+ """A class to load dataset for sentence prediction (classification) task."""
+
+ def __init__(self, params):
+ self._params = params
+ self._seq_length = params.seq_length
+ self._is_training = params.is_training
+ self._xlnet_format = params.xlnet_format
+
+ def _decode(self, record: tf.Tensor):
+ """Decodes a serialized tf.Example."""
+ name_to_features = {
+ 'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ }
+ if self._xlnet_format:
+ name_to_features['class_index'] = tf.io.FixedLenFeature([], tf.int64)
+ name_to_features['paragraph_mask'] = tf.io.FixedLenFeature(
+ [self._seq_length], tf.int64)
+ if self._is_training:
+ name_to_features['is_impossible'] = tf.io.FixedLenFeature([], tf.int64)
+
+ if self._is_training:
+ name_to_features['start_positions'] = tf.io.FixedLenFeature([], tf.int64)
+ name_to_features['end_positions'] = tf.io.FixedLenFeature([], tf.int64)
+ else:
+ name_to_features['unique_ids'] = tf.io.FixedLenFeature([], tf.int64)
+ example = tf.io.parse_single_example(record, name_to_features)
+
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+ # So cast all int64 to int32.
+ for name in example:
+ t = example[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ example[name] = t
+
+ return example
+
+ def _parse(self, record: Mapping[str, tf.Tensor]):
+ """Parses raw tensors into a dict of tensors to be consumed by the model."""
+ x, y = {}, {}
+ for name, tensor in record.items():
+ if name in ('start_positions', 'end_positions', 'is_impossible'):
+ y[name] = tensor
+ elif name == 'input_ids':
+ x['input_word_ids'] = tensor
+ elif name == 'segment_ids':
+ x['input_type_ids'] = tensor
+ else:
+ x[name] = tensor
+ if name == 'start_positions' and self._xlnet_format:
+ x[name] = tensor
+ return (x, y)
+
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
+ """Returns a tf.dataset.Dataset."""
+ reader = input_reader.InputReader(
+ params=self._params,
+ dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
+ decoder_fn=self._decode,
+ parser_fn=self._parse)
+ return reader.read(input_context)
diff --git a/modeling/official/nlp/data/question_answering_dataloader_test.py b/modeling/official/nlp/data/question_answering_dataloader_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..80b47ef3c71eb3c55671d320d3121bc67039dd83
--- /dev/null
+++ b/modeling/official/nlp/data/question_answering_dataloader_test.py
@@ -0,0 +1,74 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for official.nlp.data.question_answering_dataloader."""
+import os
+
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from official.nlp.data import question_answering_dataloader
+
+
+def _create_fake_dataset(output_path, seq_length):
+ """Creates a fake dataset."""
+ writer = tf.io.TFRecordWriter(output_path)
+
+ def create_int_feature(values):
+ f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
+ return f
+
+ for _ in range(100):
+ features = {}
+ input_ids = np.random.randint(100, size=(seq_length))
+ features['input_ids'] = create_int_feature(input_ids)
+ features['input_mask'] = create_int_feature(np.ones_like(input_ids))
+ features['segment_ids'] = create_int_feature(np.ones_like(input_ids))
+ features['start_positions'] = create_int_feature(np.array([0]))
+ features['end_positions'] = create_int_feature(np.array([10]))
+
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+ writer.write(tf_example.SerializeToString())
+ writer.close()
+
+
+class QuestionAnsweringDataTest(tf.test.TestCase):
+
+ def test_load_dataset(self):
+ seq_length = 128
+ batch_size = 10
+ input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
+ _create_fake_dataset(input_path, seq_length)
+ data_config = question_answering_dataloader.QADataConfig(
+ is_training=True,
+ input_path=input_path,
+ seq_length=seq_length,
+ global_batch_size=batch_size)
+ dataset = question_answering_dataloader.QuestionAnsweringDataLoader(
+ data_config).load()
+ features, labels = next(iter(dataset))
+
+ self.assertCountEqual(['input_word_ids', 'input_mask', 'input_type_ids'],
+ features.keys())
+ self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
+ self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
+ self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
+
+ self.assertCountEqual(['start_positions', 'end_positions'], labels.keys())
+ self.assertEqual(labels['start_positions'].shape, (batch_size,))
+ self.assertEqual(labels['end_positions'].shape, (batch_size,))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/nlp/data/sentence_prediction_dataloader.py b/modeling/official/nlp/data/sentence_prediction_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..e153a0645625bc6e285cefdec96376df0dd0ee27
--- /dev/null
+++ b/modeling/official/nlp/data/sentence_prediction_dataloader.py
@@ -0,0 +1,267 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Loads dataset for the sentence prediction (classification) task."""
+import dataclasses
+import functools
+from typing import List, Mapping, Optional, Tuple
+
+import tensorflow as tf, tf_keras
+import tensorflow_hub as hub
+
+from official.common import dataset_fn
+from official.core import config_definitions as cfg
+from official.core import input_reader
+from official.nlp import modeling
+from official.nlp.data import data_loader
+from official.nlp.data import data_loader_factory
+
+LABEL_TYPES_MAP = {'int': tf.int64, 'float': tf.float32}
+
+
+@dataclasses.dataclass
+class SentencePredictionDataConfig(cfg.DataConfig):
+ """Data config for sentence prediction task (tasks/sentence_prediction)."""
+ input_path: str = ''
+ global_batch_size: int = 32
+ is_training: bool = True
+ seq_length: int = 128
+ label_type: str = 'int'
+ # Whether to include the example id number.
+ include_example_id: bool = False
+ label_field: str = 'label_ids'
+ # Maps the key in TfExample to feature name.
+ # E.g 'label_ids' to 'next_sentence_labels'
+ label_name: Optional[Tuple[str, str]] = None
+ # Either tfrecord, sstable, or recordio.
+ file_type: str = 'tfrecord'
+
+
+@data_loader_factory.register_data_loader_cls(SentencePredictionDataConfig)
+class SentencePredictionDataLoader(data_loader.DataLoader):
+ """A class to load dataset for sentence prediction (classification) task."""
+
+ def __init__(self, params):
+ self._params = params
+ self._seq_length = params.seq_length
+ self._include_example_id = params.include_example_id
+ self._label_field = params.label_field
+ if params.label_name:
+ self._label_name_mapping = dict([params.label_name])
+ else:
+ self._label_name_mapping = dict()
+
+ def name_to_features_spec(self):
+ """Defines features to decode. Subclass may override to append features."""
+ label_type = LABEL_TYPES_MAP[self._params.label_type]
+ name_to_features = {
+ 'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ 'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
+ self._label_field: tf.io.FixedLenFeature([], label_type),
+ }
+ if self._include_example_id:
+ name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
+
+ return name_to_features
+
+ def _decode(self, record: tf.Tensor):
+ """Decodes a serialized tf.Example."""
+ example = tf.io.parse_single_example(record, self.name_to_features_spec())
+
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+ # So cast all int64 to int32.
+ for name in example:
+ t = example[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ example[name] = t
+
+ return example
+
+ def _parse(self, record: Mapping[str, tf.Tensor]):
+ """Parses raw tensors into a dict of tensors to be consumed by the model."""
+ key_mapping = {
+ 'input_ids': 'input_word_ids',
+ 'input_mask': 'input_mask',
+ 'segment_ids': 'input_type_ids'
+ }
+ ret = {}
+ for record_key in record:
+ if record_key in key_mapping:
+ ret[key_mapping[record_key]] = record[record_key]
+ else:
+ ret[record_key] = record[record_key]
+
+ if self._label_field in self._label_name_mapping:
+ ret[self._label_name_mapping[self._label_field]] = record[
+ self._label_field]
+
+ return ret
+
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
+ """Returns a tf.dataset.Dataset."""
+ reader = input_reader.InputReader(
+ dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
+ params=self._params,
+ decoder_fn=self._decode,
+ parser_fn=self._parse)
+ return reader.read(input_context)
+
+
+@dataclasses.dataclass
+class SentencePredictionTextDataConfig(cfg.DataConfig):
+ """Data config for sentence prediction task with raw text."""
+ # Either set `input_path`...
+ input_path: str = ''
+ # Either `int` or `float`.
+ label_type: str = 'int'
+ # ...or `tfds_name` and `tfds_split` to specify input.
+ tfds_name: str = ''
+ tfds_split: str = ''
+ # The name of the text feature fields. The text features will be
+ # concatenated in order.
+ text_fields: Optional[List[str]] = None
+ label_field: str = 'label'
+ global_batch_size: int = 32
+ seq_length: int = 128
+ is_training: bool = True
+ # Either build preprocessing with Python code by specifying these values
+ # for modeling.layers.BertTokenizer()/SentencepieceTokenizer()....
+ tokenization: str = 'WordPiece' # WordPiece or SentencePiece
+ # Text vocab file if tokenization is WordPiece, or sentencepiece.ModelProto
+ # file if tokenization is SentencePiece.
+ vocab_file: str = ''
+ lower_case: bool = True
+ # ...or load preprocessing from a SavedModel at this location.
+ preprocessing_hub_module_url: str = ''
+ # Either tfrecord or sstsable or recordio.
+ file_type: str = 'tfrecord'
+ include_example_id: bool = False
+
+
+class TextProcessor(tf.Module):
+ """Text features processing for sentence prediction task."""
+
+ def __init__(self,
+ seq_length: int,
+ vocab_file: Optional[str] = None,
+ tokenization: Optional[str] = None,
+ lower_case: Optional[bool] = True,
+ preprocessing_hub_module_url: Optional[str] = None):
+ if preprocessing_hub_module_url:
+ self._preprocessing_hub_module = hub.load(preprocessing_hub_module_url)
+ self._tokenizer = self._preprocessing_hub_module.tokenize
+ self._pack_inputs = functools.partial(
+ self._preprocessing_hub_module.bert_pack_inputs,
+ seq_length=seq_length)
+ return
+
+ if tokenization == 'WordPiece':
+ self._tokenizer = modeling.layers.BertTokenizer(
+ vocab_file=vocab_file, lower_case=lower_case)
+ elif tokenization == 'SentencePiece':
+ self._tokenizer = modeling.layers.SentencepieceTokenizer(
+ model_file_path=vocab_file,
+ lower_case=lower_case,
+ strip_diacritics=True) # Strip diacritics to follow ALBERT model
+ else:
+ raise ValueError('Unsupported tokenization: %s' % tokenization)
+
+ self._pack_inputs = modeling.layers.BertPackInputs(
+ seq_length=seq_length,
+ special_tokens_dict=self._tokenizer.get_special_tokens_dict())
+
+ def __call__(self, segments):
+ segments = [self._tokenizer(s) for s in segments]
+ # BertTokenizer returns a RaggedTensor with shape [batch, word, subword],
+ # and SentencepieceTokenizer returns a RaggedTensor with shape
+ # [batch, sentencepiece],
+ segments = [
+ tf.cast(x.merge_dims(1, -1) if x.shape.rank > 2 else x, tf.int32)
+ for x in segments
+ ]
+ return self._pack_inputs(segments)
+
+
+@data_loader_factory.register_data_loader_cls(SentencePredictionTextDataConfig)
+class SentencePredictionTextDataLoader(data_loader.DataLoader):
+ """Loads dataset with raw text for sentence prediction task."""
+
+ def __init__(self, params):
+ if bool(params.tfds_name) != bool(params.tfds_split):
+ raise ValueError('`tfds_name` and `tfds_split` should be specified or '
+ 'unspecified at the same time.')
+ if bool(params.tfds_name) == bool(params.input_path):
+ raise ValueError('Must specify either `tfds_name` and `tfds_split` '
+ 'or `input_path`.')
+ if not params.text_fields:
+ raise ValueError('Unexpected empty text fields.')
+ if bool(params.vocab_file) == bool(params.preprocessing_hub_module_url):
+ raise ValueError('Must specify exactly one of vocab_file (with matching '
+ 'lower_case flag) or preprocessing_hub_module_url.')
+
+ self._params = params
+ self._text_fields = params.text_fields
+ self._label_field = params.label_field
+ self._label_type = params.label_type
+ self._include_example_id = params.include_example_id
+ self._text_processor = TextProcessor(
+ seq_length=params.seq_length,
+ vocab_file=params.vocab_file,
+ tokenization=params.tokenization,
+ lower_case=params.lower_case,
+ preprocessing_hub_module_url=params.preprocessing_hub_module_url)
+
+ def _bert_preprocess(self, record: Mapping[str, tf.Tensor]):
+ """Berts preprocess."""
+ segments = [record[x] for x in self._text_fields]
+ model_inputs = self._text_processor(segments)
+ for key in record:
+ if key not in self._text_fields:
+ model_inputs[key] = record[key]
+ return model_inputs
+
+ def name_to_features_spec(self):
+ name_to_features = {}
+ for text_field in self._text_fields:
+ name_to_features[text_field] = tf.io.FixedLenFeature([], tf.string)
+
+ label_type = LABEL_TYPES_MAP[self._label_type]
+ name_to_features[self._label_field] = tf.io.FixedLenFeature([], label_type)
+ if self._include_example_id:
+ name_to_features['example_id'] = tf.io.FixedLenFeature([], tf.int64)
+ return name_to_features
+
+ def _decode(self, record: tf.Tensor):
+ """Decodes a serialized tf.Example."""
+ example = tf.io.parse_single_example(record, self.name_to_features_spec())
+ # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
+ # So cast all int64 to int32.
+ for name in example:
+ t = example[name]
+ if t.dtype == tf.int64:
+ t = tf.cast(t, tf.int32)
+ example[name] = t
+
+ return example
+
+ def load(self, input_context: Optional[tf.distribute.InputContext] = None):
+ """Returns a tf.dataset.Dataset."""
+ reader = input_reader.InputReader(
+ dataset_fn=dataset_fn.pick_dataset_fn(self._params.file_type),
+ decoder_fn=self._decode if self._params.input_path else None,
+ params=self._params,
+ postprocess_fn=self._bert_preprocess)
+ return reader.read(input_context)
diff --git a/modeling/official/nlp/data/sentence_prediction_dataloader_test.py b/modeling/official/nlp/data/sentence_prediction_dataloader_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fe1916757772099e6ed0d50a85ed5df71d9fc6c
--- /dev/null
+++ b/modeling/official/nlp/data/sentence_prediction_dataloader_test.py
@@ -0,0 +1,290 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for official.nlp.data.sentence_prediction_dataloader."""
+import os
+
+from absl.testing import parameterized
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from sentencepiece import SentencePieceTrainer
+from official.nlp.data import sentence_prediction_dataloader as loader
+
+
+def _create_fake_preprocessed_dataset(output_path, seq_length, label_type):
+ """Creates a fake dataset."""
+ writer = tf.io.TFRecordWriter(output_path)
+
+ def create_int_feature(values):
+ f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
+ return f
+
+ def create_float_feature(values):
+ f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
+ return f
+
+ for _ in range(100):
+ features = {}
+ input_ids = np.random.randint(100, size=(seq_length))
+ features['input_ids'] = create_int_feature(input_ids)
+ features['input_mask'] = create_int_feature(np.ones_like(input_ids))
+ features['segment_ids'] = create_int_feature(np.ones_like(input_ids))
+
+ if label_type == 'int':
+ features['label_ids'] = create_int_feature([1])
+ elif label_type == 'float':
+ features['label_ids'] = create_float_feature([0.5])
+ else:
+ raise ValueError('Unsupported label_type: %s' % label_type)
+
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+ writer.write(tf_example.SerializeToString())
+ writer.close()
+
+
+def _create_fake_raw_dataset(output_path, text_fields, label_type):
+ """Creates a fake tf record file."""
+ writer = tf.io.TFRecordWriter(output_path)
+
+ def create_str_feature(value):
+ f = tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
+ return f
+
+ def create_int_feature(values):
+ f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
+ return f
+
+ def create_float_feature(values):
+ f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
+ return f
+
+ for _ in range(100):
+ features = {}
+ for text_field in text_fields:
+ features[text_field] = create_str_feature([b'hello world'])
+
+ if label_type == 'int':
+ features['label'] = create_int_feature([0])
+ elif label_type == 'float':
+ features['label'] = create_float_feature([0.5])
+ else:
+ raise ValueError('Unexpected label_type: %s' % label_type)
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+ writer.write(tf_example.SerializeToString())
+ writer.close()
+
+
+def _create_fake_sentencepiece_model(output_dir):
+ vocab = ['a', 'b', 'c', 'd', 'e', 'abc', 'def', 'ABC', 'DEF']
+ model_prefix = os.path.join(output_dir, 'spm_model')
+ input_text_file_path = os.path.join(output_dir, 'train_input.txt')
+ with tf.io.gfile.GFile(input_text_file_path, 'w') as f:
+ f.write(' '.join(vocab + ['\n']))
+ # Add 7 more tokens: , , [CLS], [SEP], [MASK], , .
+ full_vocab_size = len(vocab) + 7
+ flags = dict(
+ model_prefix=model_prefix,
+ model_type='word',
+ input=input_text_file_path,
+ pad_id=0,
+ unk_id=1,
+ control_symbols='[CLS],[SEP],[MASK]',
+ vocab_size=full_vocab_size,
+ bos_id=full_vocab_size - 2,
+ eos_id=full_vocab_size - 1)
+ SentencePieceTrainer.Train(' '.join(
+ ['--{}={}'.format(k, v) for k, v in flags.items()]))
+ return model_prefix + '.model'
+
+
+def _create_fake_vocab_file(vocab_file_path):
+ tokens = ['[PAD]']
+ for i in range(1, 100):
+ tokens.append('[unused%d]' % i)
+ tokens.extend(['[UNK]', '[CLS]', '[SEP]', '[MASK]', 'hello', 'world'])
+ with tf.io.gfile.GFile(vocab_file_path, 'w') as outfile:
+ outfile.write('\n'.join(tokens))
+
+
+class SentencePredictionDataTest(tf.test.TestCase, parameterized.TestCase):
+
+ @parameterized.parameters(('int', tf.int32), ('float', tf.float32))
+ def test_load_dataset(self, label_type, expected_label_type):
+ input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
+ batch_size = 10
+ seq_length = 128
+ _create_fake_preprocessed_dataset(input_path, seq_length, label_type)
+ data_config = loader.SentencePredictionDataConfig(
+ input_path=input_path,
+ seq_length=seq_length,
+ global_batch_size=batch_size,
+ label_type=label_type)
+ dataset = loader.SentencePredictionDataLoader(data_config).load()
+ features = next(iter(dataset))
+ self.assertCountEqual(
+ ['input_word_ids', 'input_type_ids', 'input_mask', 'label_ids'],
+ features.keys())
+ self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
+ self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
+ self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
+ self.assertEqual(features['label_ids'].shape, (batch_size,))
+ self.assertEqual(features['label_ids'].dtype, expected_label_type)
+
+ def test_load_dataset_with_label_mapping(self):
+ input_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
+ batch_size = 10
+ seq_length = 128
+ _create_fake_preprocessed_dataset(input_path, seq_length, 'int')
+ data_config = loader.SentencePredictionDataConfig(
+ input_path=input_path,
+ seq_length=seq_length,
+ global_batch_size=batch_size,
+ label_type='int',
+ label_name=('label_ids', 'next_sentence_labels'))
+ dataset = loader.SentencePredictionDataLoader(data_config).load()
+ features = next(iter(dataset))
+ self.assertCountEqual([
+ 'input_word_ids', 'input_mask', 'input_type_ids',
+ 'next_sentence_labels', 'label_ids'
+ ], features.keys())
+ self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
+ self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
+ self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
+ self.assertEqual(features['label_ids'].shape, (batch_size,))
+ self.assertEqual(features['label_ids'].dtype, tf.int32)
+ self.assertEqual(features['next_sentence_labels'].shape, (batch_size,))
+ self.assertEqual(features['next_sentence_labels'].dtype, tf.int32)
+
+
+class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
+ parameterized.TestCase):
+
+ @parameterized.parameters(True, False)
+ def test_python_wordpiece_preprocessing(self, use_tfds):
+ batch_size = 10
+ seq_length = 256 # Non-default value.
+ lower_case = True
+
+ tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
+ text_fields = ['sentence1', 'sentence2']
+ if not use_tfds:
+ _create_fake_raw_dataset(tf_record_path, text_fields, label_type='int')
+
+ vocab_file_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
+ _create_fake_vocab_file(vocab_file_path)
+
+ data_config = loader.SentencePredictionTextDataConfig(
+ input_path='' if use_tfds else tf_record_path,
+ tfds_name='glue/mrpc' if use_tfds else '',
+ tfds_split='train' if use_tfds else '',
+ text_fields=text_fields,
+ global_batch_size=batch_size,
+ seq_length=seq_length,
+ is_training=True,
+ lower_case=lower_case,
+ vocab_file=vocab_file_path)
+ dataset = loader.SentencePredictionTextDataLoader(data_config).load()
+ features = next(iter(dataset))
+ label_field = data_config.label_field
+ expected_keys = [
+ 'input_word_ids', 'input_type_ids', 'input_mask', label_field
+ ]
+ if use_tfds:
+ expected_keys += ['idx']
+ self.assertCountEqual(expected_keys, features.keys())
+ self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
+ self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
+ self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
+ self.assertEqual(features[label_field].shape, (batch_size,))
+
+ @parameterized.parameters(True, False)
+ def test_python_sentencepiece_preprocessing(self, use_tfds):
+ batch_size = 10
+ seq_length = 256 # Non-default value.
+ lower_case = True
+
+ tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
+ text_fields = ['sentence1', 'sentence2']
+ if not use_tfds:
+ _create_fake_raw_dataset(tf_record_path, text_fields, label_type='int')
+
+ sp_model_file_path = _create_fake_sentencepiece_model(self.get_temp_dir())
+ data_config = loader.SentencePredictionTextDataConfig(
+ input_path='' if use_tfds else tf_record_path,
+ tfds_name='glue/mrpc' if use_tfds else '',
+ tfds_split='train' if use_tfds else '',
+ text_fields=text_fields,
+ global_batch_size=batch_size,
+ seq_length=seq_length,
+ is_training=True,
+ lower_case=lower_case,
+ tokenization='SentencePiece',
+ vocab_file=sp_model_file_path,
+ )
+ dataset = loader.SentencePredictionTextDataLoader(data_config).load()
+ features = next(iter(dataset))
+ label_field = data_config.label_field
+ expected_keys = [
+ 'input_word_ids', 'input_type_ids', 'input_mask', label_field
+ ]
+ if use_tfds:
+ expected_keys += ['idx']
+ self.assertCountEqual(expected_keys, features.keys())
+ self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
+ self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
+ self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
+ self.assertEqual(features[label_field].shape, (batch_size,))
+
+ @parameterized.parameters(True, False)
+ def test_saved_model_preprocessing(self, use_tfds):
+ batch_size = 10
+ seq_length = 256 # Non-default value.
+
+ tf_record_path = os.path.join(self.get_temp_dir(), 'train.tf_record')
+ text_fields = ['sentence1', 'sentence2']
+ if not use_tfds:
+ _create_fake_raw_dataset(tf_record_path, text_fields, label_type='float')
+
+ vocab_file_path = os.path.join(self.get_temp_dir(), 'vocab.txt')
+ _create_fake_vocab_file(vocab_file_path)
+ data_config = loader.SentencePredictionTextDataConfig(
+ input_path='' if use_tfds else tf_record_path,
+ tfds_name='glue/mrpc' if use_tfds else '',
+ tfds_split='train' if use_tfds else '',
+ text_fields=text_fields,
+ global_batch_size=batch_size,
+ seq_length=seq_length,
+ is_training=True,
+ preprocessing_hub_module_url=(
+ 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'),
+ label_type='int' if use_tfds else 'float',
+ )
+ dataset = loader.SentencePredictionTextDataLoader(data_config).load()
+ features = next(iter(dataset))
+ label_field = data_config.label_field
+ expected_keys = [
+ 'input_word_ids', 'input_type_ids', 'input_mask', label_field
+ ]
+ if use_tfds:
+ expected_keys += ['idx']
+ self.assertCountEqual(expected_keys, features.keys())
+ self.assertEqual(features['input_word_ids'].shape, (batch_size, seq_length))
+ self.assertEqual(features['input_mask'].shape, (batch_size, seq_length))
+ self.assertEqual(features['input_type_ids'].shape, (batch_size, seq_length))
+ self.assertEqual(features[label_field].shape, (batch_size,))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/modeling/official/nlp/data/sentence_retrieval_lib.py b/modeling/official/nlp/data/sentence_retrieval_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..79c37cab366e1b6cb9e657bb14c421dd548808d3
--- /dev/null
+++ b/modeling/official/nlp/data/sentence_retrieval_lib.py
@@ -0,0 +1,166 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""BERT library to process data for cross lingual sentence retrieval task."""
+
+import os
+
+from absl import logging
+from official.nlp.data import classifier_data_lib
+from official.nlp.tools import tokenization
+
+
+class BuccProcessor(classifier_data_lib.DataProcessor):
+ """Procssor for Xtreme BUCC data set."""
+ supported_languages = ["de", "fr", "ru", "zh"]
+
+ def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+ super(BuccProcessor, self).__init__(process_text_fn)
+ self.languages = BuccProcessor.supported_languages
+
+ def get_dev_examples(self, data_dir, file_pattern):
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, file_pattern.format("dev"))),
+ "sample")
+
+ def get_test_examples(self, data_dir, file_pattern):
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, file_pattern.format("test"))),
+ "test")
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "BUCC"
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ guid = "%s-%s" % (set_type, i)
+ example_id = int(line[0].split("-")[1])
+ text_a = self.process_text_fn(line[1])
+ examples.append(
+ classifier_data_lib.InputExample(
+ guid=guid, text_a=text_a, example_id=example_id))
+ return examples
+
+
+class TatoebaProcessor(classifier_data_lib.DataProcessor):
+ """Procssor for Xtreme Tatoeba data set."""
+ supported_languages = [
+ "af", "ar", "bg", "bn", "de", "el", "es", "et", "eu", "fa", "fi", "fr",
+ "he", "hi", "hu", "id", "it", "ja", "jv", "ka", "kk", "ko", "ml", "mr",
+ "nl", "pt", "ru", "sw", "ta", "te", "th", "tl", "tr", "ur", "vi", "zh"
+ ]
+
+ def __init__(self, process_text_fn=tokenization.convert_to_unicode):
+ super(TatoebaProcessor, self).__init__(process_text_fn)
+ self.languages = TatoebaProcessor.supported_languages
+
+ def get_test_examples(self, data_dir, file_path):
+ return self._create_examples(
+ self._read_tsv(os.path.join(data_dir, file_path)), "test")
+
+ @staticmethod
+ def get_processor_name():
+ """See base class."""
+ return "TATOEBA"
+
+ def _create_examples(self, lines, set_type):
+ """Creates examples for the training and dev sets."""
+ examples = []
+ for (i, line) in enumerate(lines):
+ guid = "%s-%s" % (set_type, i)
+ text_a = self.process_text_fn(line[0])
+ examples.append(
+ classifier_data_lib.InputExample(
+ guid=guid, text_a=text_a, example_id=i))
+ return examples
+
+
+def generate_sentence_retrevial_tf_record(processor,
+ data_dir,
+ tokenizer,
+ eval_data_output_path=None,
+ test_data_output_path=None,
+ max_seq_length=128):
+ """Generates the tf records for retrieval tasks.
+
+ Args:
+ processor: Input processor object to be used for generating data. Subclass
+ of `DataProcessor`.
+ data_dir: Directory that contains train/eval data to process. Data files
+ should be in from.
+ tokenizer: The tokenizer to be applied on the data.
+ eval_data_output_path: Output to which processed tf record for evaluation
+ will be saved.
+ test_data_output_path: Output to which processed tf record for testing
+ will be saved. Must be a pattern template with {} if processor has
+ language specific test data.
+ max_seq_length: Maximum sequence length of the to be generated
+ training/eval data.
+
+ Returns:
+ A dictionary containing input meta data.
+ """
+ assert eval_data_output_path or test_data_output_path
+
+ if processor.get_processor_name() == "BUCC":
+ path_pattern = "{}-en.{{}}.{}"
+
+ if processor.get_processor_name() == "TATOEBA":
+ path_pattern = "{}-en.{}"
+
+ meta_data = {
+ "processor_type": processor.get_processor_name(),
+ "max_seq_length": max_seq_length,
+ "number_eval_data": {},
+ "number_test_data": {},
+ }
+ logging.info("Start to process %s task data", processor.get_processor_name())
+
+ for lang_a in processor.languages:
+ for lang_b in [lang_a, "en"]:
+ if eval_data_output_path:
+ eval_input_data_examples = processor.get_dev_examples(
+ data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))
+
+ num_eval_data = len(eval_input_data_examples)
+ logging.info("Processing %d dev examples of %s-en.%s", num_eval_data,
+ lang_a, lang_b)
+ output_file = os.path.join(
+ eval_data_output_path,
+ "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "dev"))
+ classifier_data_lib.file_based_convert_examples_to_features(
+ eval_input_data_examples, None, max_seq_length, tokenizer,
+ output_file, None)
+ meta_data["number_eval_data"][f"{lang_a}-en.{lang_b}"] = num_eval_data
+
+ if test_data_output_path:
+ test_input_data_examples = processor.get_test_examples(
+ data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))
+
+ num_test_data = len(test_input_data_examples)
+ logging.info("Processing %d test examples of %s-en.%s", num_test_data,
+ lang_a, lang_b)
+ output_file = os.path.join(
+ test_data_output_path,
+ "{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "test"))
+ classifier_data_lib.file_based_convert_examples_to_features(
+ test_input_data_examples, None, max_seq_length, tokenizer,
+ output_file, None)
+ meta_data["number_test_data"][f"{lang_a}-en.{lang_b}"] = num_test_data
+
+ return meta_data
diff --git a/modeling/official/nlp/data/squad_lib.py b/modeling/official/nlp/data/squad_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bf05f478abc0b95a681d193e821f6483569b00b
--- /dev/null
+++ b/modeling/official/nlp/data/squad_lib.py
@@ -0,0 +1,975 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Library to process data for SQuAD 1.1 and SQuAD 2.0."""
+# pylint: disable=g-bad-import-order
+import collections
+import copy
+import json
+import math
+import os
+
+import six
+
+from absl import logging
+import tensorflow as tf, tf_keras
+
+from official.nlp.tools import tokenization
+
+
+class SquadExample(object):
+ """A single training/test example for simple sequence classification.
+
+ For examples without an answer, the start and end position are -1.
+
+ Attributes:
+ qas_id: ID of the question-answer pair.
+ question_text: Original text for the question.
+ doc_tokens: The list of tokens in the context obtained by splitting on
+ whitespace only.
+ orig_answer_text: Original text for the answer.
+ start_position: Starting index of the answer in `doc_tokens`.
+ end_position: Ending index of the answer in `doc_tokens`.
+ is_impossible: Whether the question is impossible to answer given the
+ context. Only used in SQuAD 2.0.
+ """
+
+ def __init__(self,
+ qas_id,
+ question_text,
+ doc_tokens,
+ orig_answer_text=None,
+ start_position=None,
+ end_position=None,
+ is_impossible=False):
+ self.qas_id = qas_id
+ self.question_text = question_text
+ self.doc_tokens = doc_tokens
+ self.orig_answer_text = orig_answer_text
+ self.start_position = start_position
+ self.end_position = end_position
+ self.is_impossible = is_impossible
+
+ def __str__(self):
+ return self.__repr__()
+
+ def __repr__(self):
+ s = ""
+ s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
+ s += ", question_text: %s" % (
+ tokenization.printable_text(self.question_text))
+ s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
+ if self.start_position:
+ s += ", start_position: %d" % (self.start_position)
+ if self.start_position:
+ s += ", end_position: %d" % (self.end_position)
+ if self.start_position:
+ s += ", is_impossible: %r" % (self.is_impossible)
+ return s
+
+
+class InputFeatures(object):
+ """A single set of features of data."""
+
+ def __init__(self,
+ unique_id,
+ example_index,
+ doc_span_index,
+ tokens,
+ token_to_orig_map,
+ token_is_max_context,
+ input_ids,
+ input_mask,
+ segment_ids,
+ paragraph_mask=None,
+ class_index=None,
+ start_position=None,
+ end_position=None,
+ is_impossible=None):
+ self.unique_id = unique_id
+ self.example_index = example_index
+ self.doc_span_index = doc_span_index
+ self.tokens = tokens
+ self.token_to_orig_map = token_to_orig_map
+ self.token_is_max_context = token_is_max_context
+ self.input_ids = input_ids
+ self.input_mask = input_mask
+ self.segment_ids = segment_ids
+ self.start_position = start_position
+ self.end_position = end_position
+ self.is_impossible = is_impossible
+ self.paragraph_mask = paragraph_mask
+ self.class_index = class_index
+
+
+class FeatureWriter(object):
+ """Writes InputFeature to TF example file."""
+
+ def __init__(self, filename, is_training):
+ self.filename = filename
+ self.is_training = is_training
+ self.num_features = 0
+ tf.io.gfile.makedirs(os.path.dirname(filename))
+ self._writer = tf.io.TFRecordWriter(filename)
+
+ def process_feature(self, feature):
+ """Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
+ self.num_features += 1
+
+ def create_int_feature(values):
+ feature = tf.train.Feature(
+ int64_list=tf.train.Int64List(value=list(values)))
+ return feature
+
+ features = collections.OrderedDict()
+ features["unique_ids"] = create_int_feature([feature.unique_id])
+ features["input_ids"] = create_int_feature(feature.input_ids)
+ features["input_mask"] = create_int_feature(feature.input_mask)
+ features["segment_ids"] = create_int_feature(feature.segment_ids)
+
+ if feature.paragraph_mask is not None:
+ features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
+ if feature.class_index is not None:
+ features["class_index"] = create_int_feature([feature.class_index])
+
+ if self.is_training:
+ features["start_positions"] = create_int_feature([feature.start_position])
+ features["end_positions"] = create_int_feature([feature.end_position])
+ impossible = 0
+ if feature.is_impossible:
+ impossible = 1
+ features["is_impossible"] = create_int_feature([impossible])
+
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+ self._writer.write(tf_example.SerializeToString())
+
+ def close(self):
+ self._writer.close()
+
+
+def read_squad_examples(input_file, is_training,
+ version_2_with_negative,
+ translated_input_folder=None):
+ """Read a SQuAD json file into a list of SquadExample."""
+ with tf.io.gfile.GFile(input_file, "r") as reader:
+ input_data = json.load(reader)["data"]
+
+ if translated_input_folder is not None:
+ translated_files = tf.io.gfile.glob(
+ os.path.join(translated_input_folder, "*.json"))
+ for file in translated_files:
+ with tf.io.gfile.GFile(file, "r") as reader:
+ input_data.extend(json.load(reader)["data"])
+
+ def is_whitespace(c):
+ if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
+ return True
+ return False
+
+ examples = []
+ for entry in input_data:
+ for paragraph in entry["paragraphs"]:
+ paragraph_text = paragraph["context"]
+ doc_tokens = []
+ char_to_word_offset = []
+ prev_is_whitespace = True
+ for c in paragraph_text:
+ if is_whitespace(c):
+ prev_is_whitespace = True
+ else:
+ if prev_is_whitespace:
+ doc_tokens.append(c)
+ else:
+ doc_tokens[-1] += c
+ prev_is_whitespace = False
+ char_to_word_offset.append(len(doc_tokens) - 1)
+
+ for qa in paragraph["qas"]:
+ qas_id = qa["id"]
+ question_text = qa["question"]
+ start_position = None
+ end_position = None
+ orig_answer_text = None
+ is_impossible = False
+ if is_training:
+
+ if version_2_with_negative:
+ is_impossible = qa["is_impossible"]
+ if (len(qa["answers"]) != 1) and (not is_impossible):
+ raise ValueError(
+ "For training, each question should have exactly 1 answer.")
+ if not is_impossible:
+ answer = qa["answers"][0]
+ orig_answer_text = answer["text"]
+ answer_offset = answer["answer_start"]
+ answer_length = len(orig_answer_text)
+ start_position = char_to_word_offset[answer_offset]
+ end_position = char_to_word_offset[answer_offset + answer_length -
+ 1]
+ # Only add answers where the text can be exactly recovered from the
+ # document. If this CAN'T happen it's likely due to weird Unicode
+ # stuff so we will just skip the example.
+ #
+ # Note that this means for training mode, every example is NOT
+ # guaranteed to be preserved.
+ actual_text = " ".join(doc_tokens[start_position:(end_position +
+ 1)])
+ cleaned_answer_text = " ".join(
+ tokenization.whitespace_tokenize(orig_answer_text))
+ if actual_text.find(cleaned_answer_text) == -1:
+ logging.warning("Could not find answer: '%s' vs. '%s'",
+ actual_text, cleaned_answer_text)
+ continue
+ else:
+ start_position = -1
+ end_position = -1
+ orig_answer_text = ""
+
+ example = SquadExample(
+ qas_id=qas_id,
+ question_text=question_text,
+ doc_tokens=doc_tokens,
+ orig_answer_text=orig_answer_text,
+ start_position=start_position,
+ end_position=end_position,
+ is_impossible=is_impossible)
+ examples.append(example)
+
+ return examples
+
+
+def convert_examples_to_features(examples,
+ tokenizer,
+ max_seq_length,
+ doc_stride,
+ max_query_length,
+ is_training,
+ output_fn,
+ xlnet_format=False,
+ batch_size=None):
+ """Loads a data file into a list of `InputBatch`s."""
+
+ base_id = 1000000000
+ unique_id = base_id
+ feature = None
+ for (example_index, example) in enumerate(examples):
+ query_tokens = tokenizer.tokenize(example.question_text)
+
+ if len(query_tokens) > max_query_length:
+ query_tokens = query_tokens[0:max_query_length]
+
+ tok_to_orig_index = []
+ orig_to_tok_index = []
+ all_doc_tokens = []
+ for (i, token) in enumerate(example.doc_tokens):
+ orig_to_tok_index.append(len(all_doc_tokens))
+ sub_tokens = tokenizer.tokenize(token)
+ for sub_token in sub_tokens:
+ tok_to_orig_index.append(i)
+ all_doc_tokens.append(sub_token)
+
+ tok_start_position = None
+ tok_end_position = None
+ if is_training and example.is_impossible:
+ tok_start_position = -1
+ tok_end_position = -1
+ if is_training and not example.is_impossible:
+ tok_start_position = orig_to_tok_index[example.start_position]
+ if example.end_position < len(example.doc_tokens) - 1:
+ tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
+ else:
+ tok_end_position = len(all_doc_tokens) - 1
+ (tok_start_position, tok_end_position) = _improve_answer_span(
+ all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
+ example.orig_answer_text)
+
+ # The -3 accounts for [CLS], [SEP] and [SEP]
+ max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
+
+ # We can have documents that are longer than the maximum sequence length.
+ # To deal with this we do a sliding window approach, where we take chunks
+ # of the up to our max length with a stride of `doc_stride`.
+ _DocSpan = collections.namedtuple( # pylint: disable=invalid-name
+ "DocSpan", ["start", "length"])
+ doc_spans = []
+ start_offset = 0
+ while start_offset < len(all_doc_tokens):
+ length = len(all_doc_tokens) - start_offset
+ if length > max_tokens_for_doc:
+ length = max_tokens_for_doc
+ doc_spans.append(_DocSpan(start=start_offset, length=length))
+ if start_offset + length == len(all_doc_tokens):
+ break
+ start_offset += min(length, doc_stride)
+
+ for (doc_span_index, doc_span) in enumerate(doc_spans):
+ tokens = []
+ token_to_orig_map = {}
+ token_is_max_context = {}
+ segment_ids = []
+
+ # Paragraph mask used in XLNet.
+ # 1 represents paragraph and class tokens.
+ # 0 represents query and other special tokens.
+ paragraph_mask = []
+
+ # pylint: disable=cell-var-from-loop
+ def process_query(seg_q):
+ for token in query_tokens:
+ tokens.append(token)
+ segment_ids.append(seg_q)
+ paragraph_mask.append(0)
+ tokens.append("[SEP]")
+ segment_ids.append(seg_q)
+ paragraph_mask.append(0)
+
+ def process_paragraph(seg_p):
+ for i in range(doc_span.length):
+ split_token_index = doc_span.start + i
+ token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
+
+ is_max_context = _check_is_max_context(doc_spans, doc_span_index,
+ split_token_index)
+ token_is_max_context[len(tokens)] = is_max_context
+ tokens.append(all_doc_tokens[split_token_index])
+ segment_ids.append(seg_p)
+ paragraph_mask.append(1)
+ tokens.append("[SEP]")
+ segment_ids.append(seg_p)
+ paragraph_mask.append(0)
+
+ def process_class(seg_class):
+ class_index = len(segment_ids)
+ tokens.append("[CLS]")
+ segment_ids.append(seg_class)
+ paragraph_mask.append(1)
+ return class_index
+
+ if xlnet_format:
+ seg_p, seg_q, seg_class, seg_pad = 0, 1, 2, 3
+ process_paragraph(seg_p)
+ process_query(seg_q)
+ class_index = process_class(seg_class)
+ else:
+ seg_p, seg_q, seg_class, seg_pad = 1, 0, 0, 0
+ class_index = process_class(seg_class)
+ process_query(seg_q)
+ process_paragraph(seg_p)
+
+ input_ids = tokenizer.convert_tokens_to_ids(tokens)
+
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
+ # tokens are attended to.
+ input_mask = [1] * len(input_ids)
+
+ # Zero-pad up to the sequence length.
+ while len(input_ids) < max_seq_length:
+ input_ids.append(0)
+ input_mask.append(0)
+ segment_ids.append(seg_pad)
+ paragraph_mask.append(0)
+
+ assert len(input_ids) == max_seq_length
+ assert len(input_mask) == max_seq_length
+ assert len(segment_ids) == max_seq_length
+ assert len(paragraph_mask) == max_seq_length
+
+ start_position = 0
+ end_position = 0
+ span_contains_answer = False
+
+ if is_training and not example.is_impossible:
+ # For training, if our document chunk does not contain an annotation
+ # we throw it out, since there is nothing to predict.
+ doc_start = doc_span.start
+ doc_end = doc_span.start + doc_span.length - 1
+ span_contains_answer = (tok_start_position >= doc_start and
+ tok_end_position <= doc_end)
+ if span_contains_answer:
+ doc_offset = 0 if xlnet_format else len(query_tokens) + 2
+ start_position = tok_start_position - doc_start + doc_offset
+ end_position = tok_end_position - doc_start + doc_offset
+
+ if example_index < 20:
+ logging.info("*** Example ***")
+ logging.info("unique_id: %s", (unique_id))
+ logging.info("example_index: %s", (example_index))
+ logging.info("doc_span_index: %s", (doc_span_index))
+ logging.info("tokens: %s",
+ " ".join([tokenization.printable_text(x) for x in tokens]))
+ logging.info(
+ "token_to_orig_map: %s", " ".join([
+ "%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)
+ ]))
+ logging.info(
+ "token_is_max_context: %s", " ".join([
+ "%d:%s" % (x, y)
+ for (x, y) in six.iteritems(token_is_max_context)
+ ]))
+ logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
+ logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
+ logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+ logging.info("paragraph_mask: %s", " ".join(
+ [str(x) for x in paragraph_mask]))
+ logging.info("class_index: %d", class_index)
+ if is_training:
+ if span_contains_answer:
+ answer_text = " ".join(tokens[start_position:(end_position + 1)])
+ logging.info("start_position: %d", (start_position))
+ logging.info("end_position: %d", (end_position))
+ logging.info("answer: %s", tokenization.printable_text(answer_text))
+ else:
+ logging.info("document span doesn't contain answer")
+
+ feature = InputFeatures(
+ unique_id=unique_id,
+ example_index=example_index,
+ doc_span_index=doc_span_index,
+ tokens=tokens,
+ paragraph_mask=paragraph_mask,
+ class_index=class_index,
+ token_to_orig_map=token_to_orig_map,
+ token_is_max_context=token_is_max_context,
+ input_ids=input_ids,
+ input_mask=input_mask,
+ segment_ids=segment_ids,
+ start_position=start_position,
+ end_position=end_position,
+ is_impossible=not span_contains_answer)
+
+ # Run callback
+ if is_training:
+ output_fn(feature)
+ else:
+ output_fn(feature, is_padding=False)
+
+ unique_id += 1
+
+ if not is_training and feature:
+ assert batch_size
+ num_padding = 0
+ num_examples = unique_id - base_id
+ if unique_id % batch_size != 0:
+ num_padding = batch_size - (num_examples % batch_size)
+ logging.info("Adding padding examples to make sure no partial batch.")
+ logging.info("Adds %d padding examples for inference.", num_padding)
+ dummy_feature = copy.deepcopy(feature)
+ for _ in range(num_padding):
+ dummy_feature.unique_id = unique_id
+
+ # Run callback
+ output_fn(feature, is_padding=True)
+ unique_id += 1
+ return unique_id - base_id
+
+
+def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
+ orig_answer_text):
+ """Returns tokenized answer spans that better match the annotated answer."""
+
+ # The SQuAD annotations are character based. We first project them to
+ # whitespace-tokenized words. But then after WordPiece tokenization, we can
+ # often find a "better match". For example:
+ #
+ # Question: What year was John Smith born?
+ # Context: The leader was John Smith (1895-1943).
+ # Answer: 1895
+ #
+ # The original whitespace-tokenized answer will be "(1895-1943).". However
+ # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
+ # the exact answer, 1895.
+ #
+ # However, this is not always possible. Consider the following:
+ #
+ # Question: What country is the top exporter of electornics?
+ # Context: The Japanese electronics industry is the lagest in the world.
+ # Answer: Japan
+ #
+ # In this case, the annotator chose "Japan" as a character sub-span of
+ # the word "Japanese". Since our WordPiece tokenizer does not split
+ # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
+ # in SQuAD, but does happen.
+ tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
+
+ for new_start in range(input_start, input_end + 1):
+ for new_end in range(input_end, new_start - 1, -1):
+ text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
+ if text_span == tok_answer_text:
+ return (new_start, new_end)
+
+ return (input_start, input_end)
+
+
+def _check_is_max_context(doc_spans, cur_span_index, position):
+ """Check if this is the 'max context' doc span for the token."""
+
+ # Because of the sliding window approach taken to scoring documents, a single
+ # token can appear in multiple documents. E.g.
+ # Doc: the man went to the store and bought a gallon of milk
+ # Span A: the man went to the
+ # Span B: to the store and bought
+ # Span C: and bought a gallon of
+ # ...
+ #
+ # Now the word 'bought' will have two scores from spans B and C. We only
+ # want to consider the score with "maximum context", which we define as
+ # the *minimum* of its left and right context (the *sum* of left and
+ # right context will always be the same, of course).
+ #
+ # In the example the maximum context for 'bought' would be span C since
+ # it has 1 left context and 3 right context, while span B has 4 left context
+ # and 0 right context.
+ best_score = None
+ best_span_index = None
+ for (span_index, doc_span) in enumerate(doc_spans):
+ end = doc_span.start + doc_span.length - 1
+ if position < doc_span.start:
+ continue
+ if position > end:
+ continue
+ num_left_context = position - doc_span.start
+ num_right_context = end - position
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
+ if best_score is None or score > best_score:
+ best_score = score
+ best_span_index = span_index
+
+ return cur_span_index == best_span_index
+
+
+def write_predictions(all_examples,
+ all_features,
+ all_results,
+ n_best_size,
+ max_answer_length,
+ do_lower_case,
+ output_prediction_file,
+ output_nbest_file,
+ output_null_log_odds_file,
+ version_2_with_negative=False,
+ null_score_diff_threshold=0.0,
+ verbose=False):
+ """Write final predictions to the json file and log-odds of null if needed."""
+ logging.info("Writing predictions to: %s", (output_prediction_file))
+ logging.info("Writing nbest to: %s", (output_nbest_file))
+
+ all_predictions, all_nbest_json, scores_diff_json = (
+ postprocess_output(
+ all_examples=all_examples,
+ all_features=all_features,
+ all_results=all_results,
+ n_best_size=n_best_size,
+ max_answer_length=max_answer_length,
+ do_lower_case=do_lower_case,
+ version_2_with_negative=version_2_with_negative,
+ null_score_diff_threshold=null_score_diff_threshold,
+ verbose=verbose))
+
+ write_to_json_files(all_predictions, output_prediction_file)
+ write_to_json_files(all_nbest_json, output_nbest_file)
+ if version_2_with_negative:
+ write_to_json_files(scores_diff_json, output_null_log_odds_file)
+
+
+def postprocess_output(all_examples,
+ all_features,
+ all_results,
+ n_best_size,
+ max_answer_length,
+ do_lower_case,
+ version_2_with_negative=False,
+ null_score_diff_threshold=0.0,
+ xlnet_format=False,
+ verbose=False):
+ """Postprocess model output, to form predicton results."""
+
+ example_index_to_features = collections.defaultdict(list)
+ for feature in all_features:
+ example_index_to_features[feature.example_index].append(feature)
+ unique_id_to_result = {}
+ for result in all_results:
+ unique_id_to_result[result.unique_id] = result
+
+ _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
+ "PrelimPrediction",
+ ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
+
+ all_predictions = collections.OrderedDict()
+ all_nbest_json = collections.OrderedDict()
+ scores_diff_json = collections.OrderedDict()
+
+ for (example_index, example) in enumerate(all_examples):
+ features = example_index_to_features[example_index]
+
+ prelim_predictions = []
+ # keep track of the minimum score of null start+end of position 0
+ score_null = 1000000 # large and positive
+ min_null_feature_index = 0 # the paragraph slice with min mull score
+ null_start_logit = 0 # the start logit at the slice with min null score
+ null_end_logit = 0 # the end logit at the slice with min null score
+ for (feature_index, feature) in enumerate(features):
+ if feature.unique_id not in unique_id_to_result:
+ logging.info("Skip eval example %s, not in pred.", feature.unique_id)
+ continue
+ result = unique_id_to_result[feature.unique_id]
+
+ # if we could have irrelevant answers, get the min score of irrelevant
+ if version_2_with_negative:
+ if xlnet_format:
+ feature_null_score = result.class_logits
+ else:
+ feature_null_score = result.start_logits[0] + result.end_logits[0]
+ if feature_null_score < score_null:
+ score_null = feature_null_score
+ min_null_feature_index = feature_index
+ null_start_logit = result.start_logits[0]
+ null_end_logit = result.end_logits[0]
+ for (start_index, start_logit,
+ end_index, end_logit) in _get_best_indexes_and_logits(
+ result=result,
+ n_best_size=n_best_size,
+ xlnet_format=xlnet_format):
+ # We could hypothetically create invalid predictions, e.g., predict
+ # that the start of the span is in the question. We throw out all
+ # invalid predictions.
+ if start_index >= len(feature.tokens):
+ continue
+ if end_index >= len(feature.tokens):
+ continue
+ if start_index not in feature.token_to_orig_map:
+ continue
+ if end_index not in feature.token_to_orig_map:
+ continue
+ if not feature.token_is_max_context.get(start_index, False):
+ continue
+ if end_index < start_index:
+ continue
+ length = end_index - start_index + 1
+ if length > max_answer_length:
+ continue
+ prelim_predictions.append(
+ _PrelimPrediction(
+ feature_index=feature_index,
+ start_index=start_index,
+ end_index=end_index,
+ start_logit=start_logit,
+ end_logit=end_logit))
+
+ if version_2_with_negative and not xlnet_format:
+ prelim_predictions.append(
+ _PrelimPrediction(
+ feature_index=min_null_feature_index,
+ start_index=0,
+ end_index=0,
+ start_logit=null_start_logit,
+ end_logit=null_end_logit))
+ prelim_predictions = sorted(
+ prelim_predictions,
+ key=lambda x: (x.start_logit + x.end_logit),
+ reverse=True)
+
+ _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
+ "NbestPrediction", ["text", "start_logit", "end_logit"])
+
+ seen_predictions = {}
+ nbest = []
+ for pred in prelim_predictions:
+ if len(nbest) >= n_best_size:
+ break
+ feature = features[pred.feature_index]
+ if pred.start_index > 0 or xlnet_format: # this is a non-null prediction
+ tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
+ orig_doc_start = feature.token_to_orig_map[pred.start_index]
+ orig_doc_end = feature.token_to_orig_map[pred.end_index]
+ orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
+ tok_text = " ".join(tok_tokens)
+
+ # De-tokenize WordPieces that have been split off.
+ tok_text = tok_text.replace(" ##", "")
+ tok_text = tok_text.replace("##", "")
+
+ # Clean whitespace
+ tok_text = tok_text.strip()
+ tok_text = " ".join(tok_text.split())
+ orig_text = " ".join(orig_tokens)
+
+ final_text = get_final_text(
+ tok_text, orig_text, do_lower_case, verbose=verbose)
+ if final_text in seen_predictions:
+ continue
+
+ seen_predictions[final_text] = True
+ else:
+ final_text = ""
+ seen_predictions[final_text] = True
+
+ nbest.append(
+ _NbestPrediction(
+ text=final_text,
+ start_logit=pred.start_logit,
+ end_logit=pred.end_logit))
+
+ # if we didn't inlude the empty option in the n-best, inlcude it
+ if version_2_with_negative and not xlnet_format:
+ if "" not in seen_predictions:
+ nbest.append(
+ _NbestPrediction(
+ text="", start_logit=null_start_logit,
+ end_logit=null_end_logit))
+ # In very rare edge cases we could have no valid predictions. So we
+ # just create a nonce prediction in this case to avoid failure.
+ if not nbest:
+ nbest.append(
+ _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
+
+ assert len(nbest) >= 1
+
+ total_scores = []
+ best_non_null_entry = None
+ for entry in nbest:
+ total_scores.append(entry.start_logit + entry.end_logit)
+ if not best_non_null_entry:
+ if entry.text:
+ best_non_null_entry = entry
+
+ probs = _compute_softmax(total_scores)
+
+ nbest_json = []
+ for (i, entry) in enumerate(nbest):
+ output = collections.OrderedDict()
+ output["text"] = entry.text
+ output["probability"] = probs[i]
+ output["start_logit"] = entry.start_logit
+ output["end_logit"] = entry.end_logit
+ nbest_json.append(output)
+
+ assert len(nbest_json) >= 1
+
+ if not version_2_with_negative:
+ all_predictions[example.qas_id] = nbest_json[0]["text"]
+ else:
+ # pytype: disable=attribute-error
+ # predict "" iff the null score - the score of best non-null > threshold
+ if best_non_null_entry is not None:
+ if xlnet_format:
+ score_diff = score_null
+ scores_diff_json[example.qas_id] = score_diff
+ all_predictions[example.qas_id] = best_non_null_entry.text
+ else:
+ score_diff = score_null - best_non_null_entry.start_logit - (
+ best_non_null_entry.end_logit)
+ scores_diff_json[example.qas_id] = score_diff
+ if score_diff > null_score_diff_threshold:
+ all_predictions[example.qas_id] = ""
+ else:
+ all_predictions[example.qas_id] = best_non_null_entry.text
+ else:
+ logging.warning("best_non_null_entry is None")
+ scores_diff_json[example.qas_id] = score_null
+ all_predictions[example.qas_id] = ""
+ # pytype: enable=attribute-error
+
+ all_nbest_json[example.qas_id] = nbest_json
+
+ return all_predictions, all_nbest_json, scores_diff_json
+
+
+def write_to_json_files(json_records, json_file):
+ with tf.io.gfile.GFile(json_file, "w") as writer:
+ writer.write(json.dumps(json_records, indent=4) + "\n")
+
+
+def get_final_text(pred_text, orig_text, do_lower_case, verbose=False):
+ """Project the tokenized prediction back to the original text."""
+
+ # When we created the data, we kept track of the alignment between original
+ # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
+ # now `orig_text` contains the span of our original text corresponding to the
+ # span that we predicted.
+ #
+ # However, `orig_text` may contain extra characters that we don't want in
+ # our prediction.
+ #
+ # For example, let's say:
+ # pred_text = steve smith
+ # orig_text = Steve Smith's
+ #
+ # We don't want to return `orig_text` because it contains the extra "'s".
+ #
+ # We don't want to return `pred_text` because it's already been normalized
+ # (the SQuAD eval script also does punctuation stripping/lower casing but
+ # our tokenizer does additional normalization like stripping accent
+ # characters).
+ #
+ # What we really want to return is "Steve Smith".
+ #
+ # Therefore, we have to apply a semi-complicated alignment heruistic between
+ # `pred_text` and `orig_text` to get a character-to-charcter alignment. This
+ # can fail in certain cases in which case we just return `orig_text`.
+
+ def _strip_spaces(text):
+ ns_chars = []
+ ns_to_s_map = collections.OrderedDict()
+ for (i, c) in enumerate(text):
+ if c == " ":
+ continue
+ ns_to_s_map[len(ns_chars)] = i
+ ns_chars.append(c)
+ ns_text = "".join(ns_chars)
+ return (ns_text, ns_to_s_map)
+
+ # We first tokenize `orig_text`, strip whitespace from the result
+ # and `pred_text`, and check if they are the same length. If they are
+ # NOT the same length, the heuristic has failed. If they are the same
+ # length, we assume the characters are one-to-one aligned.
+ tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
+
+ tok_text = " ".join(tokenizer.tokenize(orig_text))
+
+ start_position = tok_text.find(pred_text)
+ if start_position == -1:
+ if verbose:
+ logging.info("Unable to find text: '%s' in '%s'", pred_text, orig_text)
+ return orig_text
+ end_position = start_position + len(pred_text) - 1
+
+ (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
+ (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
+
+ if len(orig_ns_text) != len(tok_ns_text):
+ if verbose:
+ logging.info("Length not equal after stripping spaces: '%s' vs '%s'",
+ orig_ns_text, tok_ns_text)
+ return orig_text
+
+ # We then project the characters in `pred_text` back to `orig_text` using
+ # the character-to-character alignment.
+ tok_s_to_ns_map = {}
+ for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
+ tok_s_to_ns_map[tok_index] = i
+
+ orig_start_position = None
+ if start_position in tok_s_to_ns_map:
+ ns_start_position = tok_s_to_ns_map[start_position]
+ if ns_start_position in orig_ns_to_s_map:
+ orig_start_position = orig_ns_to_s_map[ns_start_position]
+
+ if orig_start_position is None:
+ if verbose:
+ logging.info("Couldn't map start position")
+ return orig_text
+
+ orig_end_position = None
+ if end_position in tok_s_to_ns_map:
+ ns_end_position = tok_s_to_ns_map[end_position]
+ if ns_end_position in orig_ns_to_s_map:
+ orig_end_position = orig_ns_to_s_map[ns_end_position]
+
+ if orig_end_position is None:
+ if verbose:
+ logging.info("Couldn't map end position")
+ return orig_text
+
+ output_text = orig_text[orig_start_position:(orig_end_position + 1)]
+ return output_text
+
+
+def _get_best_indexes_and_logits(result,
+ n_best_size,
+ xlnet_format=False):
+ """Generates the n-best indexes and logits from a list."""
+ if xlnet_format:
+ for i in range(n_best_size):
+ for j in range(n_best_size):
+ j_index = i * n_best_size + j
+ yield (result.start_indexes[i], result.start_logits[i],
+ result.end_indexes[j_index], result.end_logits[j_index])
+ else:
+ start_index_and_score = sorted(enumerate(result.start_logits),
+ key=lambda x: x[1], reverse=True)
+ end_index_and_score = sorted(enumerate(result.end_logits),
+ key=lambda x: x[1], reverse=True)
+ for i in range(len(start_index_and_score)):
+ if i >= n_best_size:
+ break
+ for j in range(len(end_index_and_score)):
+ if j >= n_best_size:
+ break
+ yield (start_index_and_score[i][0], start_index_and_score[i][1],
+ end_index_and_score[j][0], end_index_and_score[j][1])
+
+
+def _compute_softmax(scores):
+ """Compute softmax probability over raw logits."""
+ if not scores:
+ return []
+
+ max_score = None
+ for score in scores:
+ if max_score is None or score > max_score:
+ max_score = score
+
+ exp_scores = []
+ total_sum = 0.0
+ for score in scores:
+ x = math.exp(score - max_score)
+ exp_scores.append(x)
+ total_sum += x
+
+ probs = []
+ for score in exp_scores:
+ probs.append(score / total_sum)
+ return probs
+
+
+def generate_tf_record_from_json_file(input_file_path,
+ vocab_file_path,
+ output_path,
+ translated_input_folder=None,
+ max_seq_length=384,
+ do_lower_case=True,
+ max_query_length=64,
+ doc_stride=128,
+ version_2_with_negative=False,
+ xlnet_format=False):
+ """Generates and saves training data into a tf record file."""
+ train_examples = read_squad_examples(
+ input_file=input_file_path,
+ is_training=True,
+ version_2_with_negative=version_2_with_negative,
+ translated_input_folder=translated_input_folder)
+ tokenizer = tokenization.FullTokenizer(
+ vocab_file=vocab_file_path, do_lower_case=do_lower_case)
+ train_writer = FeatureWriter(filename=output_path, is_training=True)
+ number_of_examples = convert_examples_to_features(
+ examples=train_examples,
+ tokenizer=tokenizer,
+ max_seq_length=max_seq_length,
+ doc_stride=doc_stride,
+ max_query_length=max_query_length,
+ is_training=True,
+ output_fn=train_writer.process_feature,
+ xlnet_format=xlnet_format)
+ train_writer.close()
+
+ meta_data = {
+ "task_type": "bert_squad",
+ "train_data_size": number_of_examples,
+ "max_seq_length": max_seq_length,
+ "max_query_length": max_query_length,
+ "doc_stride": doc_stride,
+ "version_2_with_negative": version_2_with_negative,
+ }
+
+ return meta_data
diff --git a/modeling/official/nlp/data/squad_lib_sp.py b/modeling/official/nlp/data/squad_lib_sp.py
new file mode 100644
index 0000000000000000000000000000000000000000..55394c38a1bafb0839810a6608b334f3820bf8c5
--- /dev/null
+++ b/modeling/official/nlp/data/squad_lib_sp.py
@@ -0,0 +1,976 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Run ALBERT on SQuAD 1.1 and SQuAD 2.0 using sentence piece tokenization.
+
+The file is forked from:
+
+https://github.com/google-research/ALBERT/blob/master/run_squad_sp.py
+"""
+import collections
+import copy
+import json
+import math
+import os
+
+from absl import logging
+import numpy as np
+import tensorflow as tf, tf_keras
+
+from official.nlp.tools import tokenization
+
+
+class SquadExample(object):
+ """A single training/test example for simple sequence classification.
+
+ For examples without an answer, the start and end position are -1.
+ """
+
+ def __init__(self,
+ qas_id,
+ question_text,
+ paragraph_text,
+ orig_answer_text=None,
+ start_position=None,
+ end_position=None,
+ is_impossible=False):
+ self.qas_id = qas_id
+ self.question_text = question_text
+ self.paragraph_text = paragraph_text
+ self.orig_answer_text = orig_answer_text
+ self.start_position = start_position
+ self.end_position = end_position
+ self.is_impossible = is_impossible
+
+ def __str__(self):
+ return self.__repr__()
+
+ def __repr__(self):
+ s = ""
+ s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
+ s += ", question_text: %s" % (
+ tokenization.printable_text(self.question_text))
+ s += ", paragraph_text: [%s]" % (" ".join(self.paragraph_text))
+ if self.start_position:
+ s += ", start_position: %d" % (self.start_position)
+ if self.start_position:
+ s += ", end_position: %d" % (self.end_position)
+ if self.start_position:
+ s += ", is_impossible: %r" % (self.is_impossible)
+ return s
+
+
+class InputFeatures(object):
+ """A single set of features of data."""
+
+ def __init__(self,
+ unique_id,
+ example_index,
+ doc_span_index,
+ tok_start_to_orig_index,
+ tok_end_to_orig_index,
+ token_is_max_context,
+ tokens,
+ input_ids,
+ input_mask,
+ segment_ids,
+ paragraph_len,
+ class_index=None,
+ paragraph_mask=None,
+ start_position=None,
+ end_position=None,
+ is_impossible=None):
+ self.unique_id = unique_id
+ self.example_index = example_index
+ self.doc_span_index = doc_span_index
+ self.tok_start_to_orig_index = tok_start_to_orig_index
+ self.tok_end_to_orig_index = tok_end_to_orig_index
+ self.token_is_max_context = token_is_max_context
+ self.tokens = tokens
+ self.input_ids = input_ids
+ self.input_mask = input_mask
+ self.paragraph_mask = paragraph_mask
+ self.segment_ids = segment_ids
+ self.paragraph_len = paragraph_len
+ self.class_index = class_index
+ self.start_position = start_position
+ self.end_position = end_position
+ self.is_impossible = is_impossible
+
+
+def read_squad_examples(input_file,
+ is_training,
+ version_2_with_negative,
+ translated_input_folder=None):
+ """Read a SQuAD json file into a list of SquadExample."""
+ del version_2_with_negative
+ with tf.io.gfile.GFile(input_file, "r") as reader:
+ input_data = json.load(reader)["data"]
+
+ if translated_input_folder is not None:
+ translated_files = tf.io.gfile.glob(
+ os.path.join(translated_input_folder, "*.json"))
+ for file in translated_files:
+ with tf.io.gfile.GFile(file, "r") as reader:
+ input_data.extend(json.load(reader)["data"])
+
+ examples = []
+ for entry in input_data:
+ for paragraph in entry["paragraphs"]:
+ paragraph_text = paragraph["context"]
+
+ for qa in paragraph["qas"]:
+ qas_id = qa["id"]
+ question_text = qa["question"]
+ start_position = None
+ orig_answer_text = None
+ is_impossible = False
+
+ if is_training:
+ is_impossible = qa.get("is_impossible", False)
+ if (len(qa["answers"]) != 1) and (not is_impossible):
+ raise ValueError(
+ "For training, each question should have exactly 1 answer.")
+ if not is_impossible:
+ answer = qa["answers"][0]
+ orig_answer_text = answer["text"]
+ start_position = answer["answer_start"]
+ else:
+ start_position = -1
+ orig_answer_text = ""
+
+ example = SquadExample(
+ qas_id=qas_id,
+ question_text=question_text,
+ paragraph_text=paragraph_text,
+ orig_answer_text=orig_answer_text,
+ start_position=start_position,
+ is_impossible=is_impossible)
+ examples.append(example)
+
+ return examples
+
+
+def _convert_index(index, pos, m=None, is_start=True):
+ """Converts index."""
+ if index[pos] is not None:
+ return index[pos]
+ n = len(index)
+ rear = pos
+ while rear < n - 1 and index[rear] is None:
+ rear += 1
+ front = pos
+ while front > 0 and index[front] is None:
+ front -= 1
+ assert index[front] is not None or index[rear] is not None
+ if index[front] is None:
+ if index[rear] >= 1: # pytype: disable=unsupported-operands
+ if is_start:
+ return 0
+ else:
+ return index[rear] - 1
+ return index[rear]
+ if index[rear] is None:
+ if m is not None and index[front] < m - 1:
+ if is_start:
+ return index[front] + 1
+ else:
+ return m - 1
+ return index[front]
+ if is_start:
+ if index[rear] > index[front] + 1:
+ return index[front] + 1
+ else:
+ return index[rear]
+ else:
+ if index[rear] > index[front] + 1:
+ return index[rear] - 1
+ else:
+ return index[front]
+
+
+def convert_examples_to_features(examples,
+ tokenizer,
+ max_seq_length,
+ doc_stride,
+ max_query_length,
+ is_training,
+ output_fn,
+ do_lower_case,
+ xlnet_format=False,
+ batch_size=None):
+ """Loads a data file into a list of `InputBatch`s."""
+ cnt_pos, cnt_neg = 0, 0
+ base_id = 1000000000
+ unique_id = base_id
+ max_n, max_m = 1024, 1024
+ f = np.zeros((max_n, max_m), dtype=np.float32)
+
+ for (example_index, example) in enumerate(examples):
+
+ if example_index % 100 == 0:
+ logging.info("Converting %d/%d pos %d neg %d", example_index,
+ len(examples), cnt_pos, cnt_neg)
+
+ query_tokens = tokenization.encode_ids(
+ tokenizer.sp_model,
+ tokenization.preprocess_text(
+ example.question_text, lower=do_lower_case))
+
+ if len(query_tokens) > max_query_length:
+ query_tokens = query_tokens[0:max_query_length]
+
+ paragraph_text = example.paragraph_text
+ para_tokens = tokenization.encode_pieces(
+ tokenizer.sp_model,
+ tokenization.preprocess_text(
+ example.paragraph_text, lower=do_lower_case))
+
+ chartok_to_tok_index = []
+ tok_start_to_chartok_index = []
+ tok_end_to_chartok_index = []
+ char_cnt = 0
+ for i, token in enumerate(para_tokens):
+ new_token = token.replace(tokenization.SPIECE_UNDERLINE, " ")
+ chartok_to_tok_index.extend([i] * len(new_token))
+ tok_start_to_chartok_index.append(char_cnt)
+ char_cnt += len(new_token)
+ tok_end_to_chartok_index.append(char_cnt - 1)
+
+ tok_cat_text = "".join(para_tokens).replace(tokenization.SPIECE_UNDERLINE,
+ " ")
+ n, m = len(paragraph_text), len(tok_cat_text)
+
+ if n > max_n or m > max_m:
+ max_n = max(n, max_n)
+ max_m = max(m, max_m)
+ f = np.zeros((max_n, max_m), dtype=np.float32)
+
+ g = {}
+
+ # pylint: disable=cell-var-from-loop
+ def _lcs_match(max_dist, n=n, m=m):
+ """Longest-common-substring algorithm."""
+ f.fill(0)
+ g.clear()
+
+ ### longest common sub sequence
+ # f[i, j] = max(f[i - 1, j], f[i, j - 1], f[i - 1, j - 1] + match(i, j))
+ for i in range(n):
+
+ # unlike standard LCS, this is specifically optimized for the setting
+ # because the mismatch between sentence pieces and original text will
+ # be small
+ for j in range(i - max_dist, i + max_dist):
+ if j >= m or j < 0:
+ continue
+
+ if i > 0:
+ g[(i, j)] = 0
+ f[i, j] = f[i - 1, j]
+
+ if j > 0 and f[i, j - 1] > f[i, j]:
+ g[(i, j)] = 1
+ f[i, j] = f[i, j - 1]
+
+ f_prev = f[i - 1, j - 1] if i > 0 and j > 0 else 0
+ if (tokenization.preprocess_text(
+ paragraph_text[i], lower=do_lower_case,
+ remove_space=False) == tok_cat_text[j] and f_prev + 1 > f[i, j]):
+ g[(i, j)] = 2
+ f[i, j] = f_prev + 1
+
+ # pylint: enable=cell-var-from-loop
+
+ max_dist = abs(n - m) + 5
+ for _ in range(2):
+ _lcs_match(max_dist)
+ if f[n - 1, m - 1] > 0.8 * n:
+ break
+ max_dist *= 2
+
+ orig_to_chartok_index = [None] * n
+ chartok_to_orig_index = [None] * m
+ i, j = n - 1, m - 1
+ while i >= 0 and j >= 0:
+ if (i, j) not in g:
+ break
+ if g[(i, j)] == 2:
+ orig_to_chartok_index[i] = j
+ chartok_to_orig_index[j] = i
+ i, j = i - 1, j - 1
+ elif g[(i, j)] == 1:
+ j = j - 1
+ else:
+ i = i - 1
+
+ if (all(v is None for v in orig_to_chartok_index) or
+ f[n - 1, m - 1] < 0.8 * n):
+ logging.info("MISMATCH DETECTED!")
+ continue
+
+ tok_start_to_orig_index = []
+ tok_end_to_orig_index = []
+ for i in range(len(para_tokens)):
+ start_chartok_pos = tok_start_to_chartok_index[i]
+ end_chartok_pos = tok_end_to_chartok_index[i]
+ start_orig_pos = _convert_index(
+ chartok_to_orig_index, start_chartok_pos, n, is_start=True)
+ end_orig_pos = _convert_index(
+ chartok_to_orig_index, end_chartok_pos, n, is_start=False)
+
+ tok_start_to_orig_index.append(start_orig_pos)
+ tok_end_to_orig_index.append(end_orig_pos)
+
+ if not is_training:
+ tok_start_position = tok_end_position = None
+
+ if is_training and example.is_impossible:
+ tok_start_position = 0
+ tok_end_position = 0
+
+ if is_training and not example.is_impossible:
+ start_position = example.start_position
+ end_position = start_position + len(example.orig_answer_text) - 1
+
+ start_chartok_pos = _convert_index(
+ orig_to_chartok_index, start_position, is_start=True)
+ tok_start_position = chartok_to_tok_index[start_chartok_pos]
+
+ end_chartok_pos = _convert_index(
+ orig_to_chartok_index, end_position, is_start=False)
+ tok_end_position = chartok_to_tok_index[end_chartok_pos]
+ assert tok_start_position <= tok_end_position
+
+ def _piece_to_id(x):
+ return tokenizer.sp_model.PieceToId(x)
+
+ all_doc_tokens = list(map(_piece_to_id, para_tokens))
+
+ # The -3 accounts for [CLS], [SEP] and [SEP]
+ max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
+
+ # We can have documents that are longer than the maximum sequence length.
+ # To deal with this we do a sliding window approach, where we take chunks
+ # of the up to our max length with a stride of `doc_stride`.
+ _DocSpan = collections.namedtuple( # pylint: disable=invalid-name
+ "DocSpan", ["start", "length"])
+ doc_spans = []
+ start_offset = 0
+
+ while start_offset < len(all_doc_tokens):
+ length = len(all_doc_tokens) - start_offset
+ if length > max_tokens_for_doc:
+ length = max_tokens_for_doc
+ doc_spans.append(_DocSpan(start=start_offset, length=length))
+ if start_offset + length == len(all_doc_tokens):
+ break
+ start_offset += min(length, doc_stride)
+
+ for (doc_span_index, doc_span) in enumerate(doc_spans):
+ tokens = []
+ token_is_max_context = {}
+ segment_ids = []
+
+ # Paragraph mask used in XLNet.
+ # 1 represents paragraph and class tokens.
+ # 0 represents query and other special tokens.
+ paragraph_mask = []
+
+ cur_tok_start_to_orig_index = []
+ cur_tok_end_to_orig_index = []
+
+ # pylint: disable=cell-var-from-loop
+ def process_query(seg_q):
+ for token in query_tokens:
+ tokens.append(token)
+ segment_ids.append(seg_q)
+ paragraph_mask.append(0)
+ tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
+ segment_ids.append(seg_q)
+ paragraph_mask.append(0)
+
+ def process_paragraph(seg_p):
+ for i in range(doc_span.length):
+ split_token_index = doc_span.start + i
+
+ cur_tok_start_to_orig_index.append(
+ tok_start_to_orig_index[split_token_index])
+ cur_tok_end_to_orig_index.append(
+ tok_end_to_orig_index[split_token_index])
+
+ is_max_context = _check_is_max_context(doc_spans, doc_span_index,
+ split_token_index)
+ token_is_max_context[len(tokens)] = is_max_context
+ tokens.append(all_doc_tokens[split_token_index])
+ segment_ids.append(seg_p)
+ paragraph_mask.append(1)
+ tokens.append(tokenizer.sp_model.PieceToId("[SEP]"))
+ segment_ids.append(seg_p)
+ paragraph_mask.append(0)
+ return len(tokens)
+
+ def process_class(seg_class):
+ class_index = len(segment_ids)
+ tokens.append(tokenizer.sp_model.PieceToId("[CLS]"))
+ segment_ids.append(seg_class)
+ paragraph_mask.append(1)
+ return class_index
+
+ if xlnet_format:
+ seg_p, seg_q, seg_class, seg_pad = 0, 1, 2, 3
+ paragraph_len = process_paragraph(seg_p)
+ process_query(seg_q)
+ class_index = process_class(seg_class)
+ else:
+ seg_p, seg_q, seg_class, seg_pad = 1, 0, 0, 0
+ class_index = process_class(seg_class)
+ process_query(seg_q)
+ paragraph_len = process_paragraph(seg_p)
+
+ input_ids = tokens
+
+ # The mask has 1 for real tokens and 0 for padding tokens. Only real
+ # tokens are attended to.
+ input_mask = [1] * len(input_ids)
+
+ # Zero-pad up to the sequence length.
+ while len(input_ids) < max_seq_length:
+ input_ids.append(0)
+ input_mask.append(0)
+ segment_ids.append(seg_pad)
+ paragraph_mask.append(0)
+
+ assert len(input_ids) == max_seq_length
+ assert len(input_mask) == max_seq_length
+ assert len(segment_ids) == max_seq_length
+ assert len(paragraph_mask) == max_seq_length
+
+ span_is_impossible = example.is_impossible
+ start_position = None
+ end_position = None
+ if is_training and not span_is_impossible:
+ # For training, if our document chunk does not contain an annotation
+ # we throw it out, since there is nothing to predict.
+ doc_start = doc_span.start
+ doc_end = doc_span.start + doc_span.length - 1
+ out_of_span = False
+ if not (tok_start_position >= doc_start and
+ tok_end_position <= doc_end):
+ out_of_span = True
+ if out_of_span:
+ # continue
+ start_position = 0
+ end_position = 0
+ span_is_impossible = True
+ else:
+ doc_offset = 0 if xlnet_format else len(query_tokens) + 2
+ start_position = tok_start_position - doc_start + doc_offset
+ end_position = tok_end_position - doc_start + doc_offset
+
+ if is_training and span_is_impossible:
+ start_position = class_index
+ end_position = class_index
+
+ if example_index < 20:
+ logging.info("*** Example ***")
+ logging.info("unique_id: %s", (unique_id))
+ logging.info("example_index: %s", (example_index))
+ logging.info("doc_span_index: %s", (doc_span_index))
+ logging.info("tok_start_to_orig_index: %s",
+ " ".join([str(x) for x in cur_tok_start_to_orig_index]))
+ logging.info("tok_end_to_orig_index: %s",
+ " ".join([str(x) for x in cur_tok_end_to_orig_index]))
+ logging.info(
+ "token_is_max_context: %s", " ".join(
+ ["%d:%s" % (x, y) for (x, y) in token_is_max_context.items()]))
+ logging.info(
+ "input_pieces: %s",
+ " ".join([tokenizer.sp_model.IdToPiece(x) for x in tokens]))
+ logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
+ logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
+ logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
+ logging.info("paragraph_mask: %s", " ".join(
+ [str(x) for x in paragraph_mask]))
+ logging.info("class_index: %d", class_index)
+
+ if is_training and span_is_impossible:
+ logging.info("impossible example span")
+
+ if is_training and not span_is_impossible:
+ pieces = [
+ tokenizer.sp_model.IdToPiece(token)
+ for token in tokens[start_position:(end_position + 1)]
+ ]
+ answer_text = tokenizer.sp_model.DecodePieces(pieces)
+ logging.info("start_position: %d", (start_position))
+ logging.info("end_position: %d", (end_position))
+ logging.info("answer: %s", (tokenization.printable_text(answer_text)))
+
+ # With multi processing, the example_index is actually the index
+ # within the current process therefore we use example_index=None
+ # to avoid being used in the future.
+ # The current code does not use example_index of training data.
+ if is_training:
+ feat_example_index = None
+ else:
+ feat_example_index = example_index
+
+ feature = InputFeatures(
+ unique_id=unique_id,
+ example_index=feat_example_index,
+ doc_span_index=doc_span_index,
+ tok_start_to_orig_index=cur_tok_start_to_orig_index,
+ tok_end_to_orig_index=cur_tok_end_to_orig_index,
+ token_is_max_context=token_is_max_context,
+ tokens=[tokenizer.sp_model.IdToPiece(x) for x in tokens],
+ input_ids=input_ids,
+ input_mask=input_mask,
+ paragraph_mask=paragraph_mask,
+ segment_ids=segment_ids,
+ paragraph_len=paragraph_len,
+ class_index=class_index,
+ start_position=start_position,
+ end_position=end_position,
+ is_impossible=span_is_impossible)
+
+ # Run callback
+ if is_training:
+ output_fn(feature)
+ else:
+ output_fn(feature, is_padding=False)
+
+ unique_id += 1
+ if span_is_impossible:
+ cnt_neg += 1
+ else:
+ cnt_pos += 1
+
+ if not is_training and feature:
+ assert batch_size
+ num_padding = 0
+ num_examples = unique_id - base_id
+ if unique_id % batch_size != 0:
+ num_padding = batch_size - (num_examples % batch_size)
+ dummy_feature = copy.deepcopy(feature)
+ for _ in range(num_padding):
+ dummy_feature.unique_id = unique_id
+
+ # Run callback
+ output_fn(feature, is_padding=True)
+ unique_id += 1
+
+ logging.info("Total number of instances: %d = pos %d neg %d",
+ cnt_pos + cnt_neg, cnt_pos, cnt_neg)
+ return unique_id - base_id
+
+
+def _check_is_max_context(doc_spans, cur_span_index, position):
+ """Check if this is the 'max context' doc span for the token."""
+
+ # Because of the sliding window approach taken to scoring documents, a single
+ # token can appear in multiple documents. E.g.
+ # Doc: the man went to the store and bought a gallon of milk
+ # Span A: the man went to the
+ # Span B: to the store and bought
+ # Span C: and bought a gallon of
+ # ...
+ #
+ # Now the word 'bought' will have two scores from spans B and C. We only
+ # want to consider the score with "maximum context", which we define as
+ # the *minimum* of its left and right context (the *sum* of left and
+ # right context will always be the same, of course).
+ #
+ # In the example the maximum context for 'bought' would be span C since
+ # it has 1 left context and 3 right context, while span B has 4 left context
+ # and 0 right context.
+ best_score = None
+ best_span_index = None
+ for (span_index, doc_span) in enumerate(doc_spans):
+ end = doc_span.start + doc_span.length - 1
+ if position < doc_span.start:
+ continue
+ if position > end:
+ continue
+ num_left_context = position - doc_span.start
+ num_right_context = end - position
+ score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
+ if best_score is None or score > best_score:
+ best_score = score
+ best_span_index = span_index
+
+ return cur_span_index == best_span_index
+
+
+def write_predictions(all_examples,
+ all_features,
+ all_results,
+ n_best_size,
+ max_answer_length,
+ do_lower_case,
+ output_prediction_file,
+ output_nbest_file,
+ output_null_log_odds_file,
+ version_2_with_negative=False,
+ null_score_diff_threshold=0.0,
+ verbose=False):
+ """Write final predictions to the json file and log-odds of null if needed."""
+ logging.info("Writing predictions to: %s", (output_prediction_file))
+ logging.info("Writing nbest to: %s", (output_nbest_file))
+
+ all_predictions, all_nbest_json, scores_diff_json = (
+ postprocess_output(
+ all_examples=all_examples,
+ all_features=all_features,
+ all_results=all_results,
+ n_best_size=n_best_size,
+ max_answer_length=max_answer_length,
+ do_lower_case=do_lower_case,
+ version_2_with_negative=version_2_with_negative,
+ null_score_diff_threshold=null_score_diff_threshold,
+ verbose=verbose))
+
+ write_to_json_files(all_predictions, output_prediction_file)
+ write_to_json_files(all_nbest_json, output_nbest_file)
+ if version_2_with_negative:
+ write_to_json_files(scores_diff_json, output_null_log_odds_file)
+
+
+def postprocess_output(all_examples,
+ all_features,
+ all_results,
+ n_best_size,
+ max_answer_length,
+ do_lower_case,
+ version_2_with_negative=False,
+ null_score_diff_threshold=0.0,
+ xlnet_format=False,
+ verbose=False):
+ """Postprocess model output, to form predicton results."""
+
+ del do_lower_case, verbose
+ example_index_to_features = collections.defaultdict(list)
+ for feature in all_features:
+ example_index_to_features[feature.example_index].append(feature)
+
+ unique_id_to_result = {}
+ for result in all_results:
+ unique_id_to_result[result.unique_id] = result
+
+ _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
+ "PrelimPrediction",
+ ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])
+
+ all_predictions = collections.OrderedDict()
+ all_nbest_json = collections.OrderedDict()
+ scores_diff_json = collections.OrderedDict()
+
+ for (example_index, example) in enumerate(all_examples):
+ features = example_index_to_features[example_index]
+
+ prelim_predictions = []
+ # keep track of the minimum score of null start+end of position 0
+ score_null = 1000000 # large and positive
+ min_null_feature_index = 0 # the paragraph slice with min mull score
+ null_start_logit = 0 # the start logit at the slice with min null score
+ null_end_logit = 0 # the end logit at the slice with min null score
+ for (feature_index, feature) in enumerate(features):
+ if feature.unique_id not in unique_id_to_result:
+ logging.info("Skip eval example %s, not in pred.", feature.unique_id)
+ continue
+ result = unique_id_to_result[feature.unique_id]
+
+ # if we could have irrelevant answers, get the min score of irrelevant
+ if version_2_with_negative:
+ if xlnet_format:
+ feature_null_score = result.class_logits
+ else:
+ feature_null_score = result.start_logits[0] + result.end_logits[0]
+ if feature_null_score < score_null:
+ score_null = feature_null_score
+ min_null_feature_index = feature_index
+ null_start_logit = result.start_logits[0]
+ null_end_logit = result.end_logits[0]
+
+ doc_offset = 0 if xlnet_format else feature.tokens.index("[SEP]") + 1
+
+ for (start_index, start_logit,
+ end_index, end_logit) in _get_best_indexes_and_logits(
+ result=result,
+ n_best_size=n_best_size,
+ xlnet_format=xlnet_format):
+ # We could hypothetically create invalid predictions, e.g., predict
+ # that the start of the span is in the question. We throw out all
+ # invalid predictions.
+ if start_index - doc_offset >= len(feature.tok_start_to_orig_index):
+ continue
+ if end_index - doc_offset >= len(feature.tok_end_to_orig_index):
+ continue
+ if not feature.token_is_max_context.get(start_index, False):
+ continue
+ if end_index < start_index:
+ continue
+ length = end_index - start_index + 1
+ if length > max_answer_length:
+ continue
+ prelim_predictions.append(
+ _PrelimPrediction(
+ feature_index=feature_index,
+ start_index=start_index - doc_offset,
+ end_index=end_index - doc_offset,
+ start_logit=start_logit,
+ end_logit=end_logit))
+
+ if version_2_with_negative and not xlnet_format:
+ prelim_predictions.append(
+ _PrelimPrediction(
+ feature_index=min_null_feature_index,
+ start_index=-1,
+ end_index=-1,
+ start_logit=null_start_logit,
+ end_logit=null_end_logit))
+ prelim_predictions = sorted(
+ prelim_predictions,
+ key=lambda x: (x.start_logit + x.end_logit),
+ reverse=True)
+
+ _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
+ "NbestPrediction", ["text", "start_logit", "end_logit"])
+
+ seen_predictions = {}
+ nbest = []
+ for pred in prelim_predictions:
+ if len(nbest) >= n_best_size:
+ break
+ feature = features[pred.feature_index]
+ if pred.start_index >= 0 or xlnet_format: # this is a non-null prediction
+ tok_start_to_orig_index = feature.tok_start_to_orig_index
+ tok_end_to_orig_index = feature.tok_end_to_orig_index
+ start_orig_pos = tok_start_to_orig_index[pred.start_index]
+ end_orig_pos = tok_end_to_orig_index[pred.end_index]
+
+ paragraph_text = example.paragraph_text
+ final_text = paragraph_text[start_orig_pos:end_orig_pos + 1].strip()
+ if final_text in seen_predictions:
+ continue
+
+ seen_predictions[final_text] = True
+ else:
+ final_text = ""
+ seen_predictions[final_text] = True
+
+ nbest.append(
+ _NbestPrediction(
+ text=final_text,
+ start_logit=pred.start_logit,
+ end_logit=pred.end_logit))
+
+ # if we didn't inlude the empty option in the n-best, include it
+ if version_2_with_negative and not xlnet_format:
+ if "" not in seen_predictions:
+ nbest.append(
+ _NbestPrediction(
+ text="", start_logit=null_start_logit,
+ end_logit=null_end_logit))
+ # In very rare edge cases we could have no valid predictions. So we
+ # just create a nonce prediction in this case to avoid failure.
+ if not nbest:
+ nbest.append(
+ _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
+
+ assert len(nbest) >= 1
+
+ total_scores = []
+ best_non_null_entry = None
+ for entry in nbest:
+ total_scores.append(entry.start_logit + entry.end_logit)
+ if not best_non_null_entry:
+ if entry.text:
+ best_non_null_entry = entry
+
+ probs = _compute_softmax(total_scores)
+
+ nbest_json = []
+ for (i, entry) in enumerate(nbest):
+ output = collections.OrderedDict()
+ output["text"] = entry.text
+ output["probability"] = probs[i]
+ output["start_logit"] = entry.start_logit
+ output["end_logit"] = entry.end_logit
+ nbest_json.append(output)
+
+ assert len(nbest_json) >= 1
+
+ if not version_2_with_negative:
+ all_predictions[example.qas_id] = nbest_json[0]["text"]
+ else:
+ assert best_non_null_entry is not None
+ if xlnet_format:
+ score_diff = score_null
+ scores_diff_json[example.qas_id] = score_diff
+ all_predictions[example.qas_id] = best_non_null_entry.text
+ else:
+ # predict "" iff the null score - the score of best non-null > threshold
+ score_diff = score_null - best_non_null_entry.start_logit - (
+ best_non_null_entry.end_logit)
+ scores_diff_json[example.qas_id] = score_diff
+ if score_diff > null_score_diff_threshold:
+ all_predictions[example.qas_id] = ""
+ else:
+ all_predictions[example.qas_id] = best_non_null_entry.text
+
+ all_nbest_json[example.qas_id] = nbest_json
+
+ return all_predictions, all_nbest_json, scores_diff_json
+
+
+def write_to_json_files(json_records, json_file):
+ with tf.io.gfile.GFile(json_file, "w") as writer:
+ writer.write(json.dumps(json_records, indent=4) + "\n")
+
+
+def _get_best_indexes_and_logits(result,
+ n_best_size,
+ xlnet_format=False):
+ """Generates the n-best indexes and logits from a list."""
+ if xlnet_format:
+ for i in range(n_best_size):
+ for j in range(n_best_size):
+ j_index = i * n_best_size + j
+ yield (result.start_indexes[i], result.start_logits[i],
+ result.end_indexes[j_index], result.end_logits[j_index])
+ else:
+ start_index_and_score = sorted(enumerate(result.start_logits),
+ key=lambda x: x[1], reverse=True)
+ end_index_and_score = sorted(enumerate(result.end_logits),
+ key=lambda x: x[1], reverse=True)
+ for i in range(len(start_index_and_score)):
+ if i >= n_best_size:
+ break
+ for j in range(len(end_index_and_score)):
+ if j >= n_best_size:
+ break
+ yield (start_index_and_score[i][0], start_index_and_score[i][1],
+ end_index_and_score[j][0], end_index_and_score[j][1])
+
+
+def _compute_softmax(scores):
+ """Compute softmax probability over raw logits."""
+ if not scores:
+ return []
+
+ max_score = None
+ for score in scores:
+ if max_score is None or score > max_score:
+ max_score = score
+
+ exp_scores = []
+ total_sum = 0.0
+ for score in scores:
+ x = math.exp(score - max_score)
+ exp_scores.append(x)
+ total_sum += x
+
+ probs = []
+ for score in exp_scores:
+ probs.append(score / total_sum)
+ return probs
+
+
+class FeatureWriter(object):
+ """Writes InputFeature to TF example file."""
+
+ def __init__(self, filename, is_training):
+ self.filename = filename
+ self.is_training = is_training
+ self.num_features = 0
+ tf.io.gfile.makedirs(os.path.dirname(filename))
+ self._writer = tf.io.TFRecordWriter(filename)
+
+ def process_feature(self, feature):
+ """Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
+ self.num_features += 1
+
+ def create_int_feature(values):
+ feature = tf.train.Feature(
+ int64_list=tf.train.Int64List(value=list(values)))
+ return feature
+
+ features = collections.OrderedDict()
+ features["unique_ids"] = create_int_feature([feature.unique_id])
+ features["input_ids"] = create_int_feature(feature.input_ids)
+ features["input_mask"] = create_int_feature(feature.input_mask)
+ features["segment_ids"] = create_int_feature(feature.segment_ids)
+ if feature.paragraph_mask is not None:
+ features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
+ if feature.class_index is not None:
+ features["class_index"] = create_int_feature([feature.class_index])
+
+ if self.is_training:
+ features["start_positions"] = create_int_feature([feature.start_position])
+ features["end_positions"] = create_int_feature([feature.end_position])
+ impossible = 0
+ if feature.is_impossible:
+ impossible = 1
+ features["is_impossible"] = create_int_feature([impossible])
+
+ tf_example = tf.train.Example(features=tf.train.Features(feature=features))
+ self._writer.write(tf_example.SerializeToString())
+
+ def close(self):
+ self._writer.close()
+
+
+def generate_tf_record_from_json_file(input_file_path,
+ sp_model_file,
+ output_path,
+ translated_input_folder=None,
+ max_seq_length=384,
+ do_lower_case=True,
+ max_query_length=64,
+ doc_stride=128,
+ xlnet_format=False,
+ version_2_with_negative=False):
+ """Generates and saves training data into a tf record file."""
+ train_examples = read_squad_examples(
+ input_file=input_file_path,
+ is_training=True,
+ version_2_with_negative=version_2_with_negative,
+ translated_input_folder=translated_input_folder)
+ tokenizer = tokenization.FullSentencePieceTokenizer(
+ sp_model_file=sp_model_file)
+ train_writer = FeatureWriter(
+ filename=output_path, is_training=True)
+ number_of_examples = convert_examples_to_features(
+ examples=train_examples,
+ tokenizer=tokenizer,
+ max_seq_length=max_seq_length,
+ doc_stride=doc_stride,
+ max_query_length=max_query_length,
+ is_training=True,
+ output_fn=train_writer.process_feature,
+ xlnet_format=xlnet_format,
+ do_lower_case=do_lower_case)
+ train_writer.close()
+
+ meta_data = {
+ "task_type": "bert_squad",
+ "train_data_size": number_of_examples,
+ "max_seq_length": max_seq_length,
+ "max_query_length": max_query_length,
+ "doc_stride": doc_stride,
+ "version_2_with_negative": version_2_with_negative,
+ }
+
+ return meta_data
diff --git a/modeling/official/nlp/data/tagging_data_lib.py b/modeling/official/nlp/data/tagging_data_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bc109dce6cf7ad87481e5f111faccd4466854f1
--- /dev/null
+++ b/modeling/official/nlp/data/tagging_data_lib.py
@@ -0,0 +1,426 @@
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Library to process data for tagging task such as NER/POS."""
+import collections
+import os
+
+from absl import logging
+import tensorflow as tf, tf_keras
+
+from official.nlp.data import classifier_data_lib
+from official.nlp.tools import tokenization
+
+# A negative label id for the padding label, which will not contribute
+# to loss/metrics in training.
+_PADDING_LABEL_ID = -1
+
+# The special unknown token, used to substitute a word which has too many
+# subwords after tokenization.
+_UNK_TOKEN = "[UNK]"
+
+
+class InputExample(object):
+ """A single training/test example for token classification."""
+
+ def __init__(self,
+ sentence_id,
+ sub_sentence_id=0,
+ words=None,
+ label_ids=None):
+ """Constructs an InputExample."""
+ self.sentence_id = sentence_id
+ self.sub_sentence_id = sub_sentence_id
+ self.words = words if words else []
+ self.label_ids = label_ids if label_ids else []
+
+ def add_word_and_label_id(self, word, label_id):
+ """Adds word and label_id pair in the example."""
+ self.words.append(word)
+ self.label_ids.append(label_id)
+
+
+def _read_one_file(file_name, label_list):
+ """Reads one file and returns a list of `InputExample` instances."""
+ lines = tf.io.gfile.GFile(file_name, "r").readlines()
+ examples = []
+ label_id_map = {label: i for i, label in enumerate(label_list)}
+ sentence_id = 0
+ example = InputExample(sentence_id=0)
+ for line in lines:
+ line = line.strip("\n")
+ if line:
+ # The format is: \t