deanna-emery commited on
Commit
5672777
1 Parent(s): 9e6df20
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. modeling/official/README-TPU.md +32 -0
  2. modeling/official/README.md +166 -0
  3. modeling/official/__init__.py +14 -0
  4. modeling/official/common/__init__.py +15 -0
  5. modeling/official/common/dataset_fn.py +44 -0
  6. modeling/official/common/distribute_utils.py +233 -0
  7. modeling/official/common/distribute_utils_test.py +124 -0
  8. modeling/official/common/flags.py +114 -0
  9. modeling/official/common/registry_imports.py +20 -0
  10. modeling/official/common/streamz_counters.py +27 -0
  11. modeling/official/core/__init__.py +31 -0
  12. modeling/official/core/actions.py +236 -0
  13. modeling/official/core/actions_test.py +131 -0
  14. modeling/official/core/base_task.py +360 -0
  15. modeling/official/core/base_trainer.py +498 -0
  16. modeling/official/core/base_trainer_test.py +363 -0
  17. modeling/official/core/config_definitions.py +309 -0
  18. modeling/official/core/exp_factory.py +32 -0
  19. modeling/official/core/export_base.py +182 -0
  20. modeling/official/core/export_base_test.py +133 -0
  21. modeling/official/core/file_writers.py +80 -0
  22. modeling/official/core/file_writers_test.py +53 -0
  23. modeling/official/core/input_reader.py +591 -0
  24. modeling/official/core/registry.py +101 -0
  25. modeling/official/core/registry_test.py +88 -0
  26. modeling/official/core/savedmodel_checkpoint_manager.py +258 -0
  27. modeling/official/core/savedmodel_checkpoint_manager_test.py +125 -0
  28. modeling/official/core/task_factory.py +70 -0
  29. modeling/official/core/test_utils.py +59 -0
  30. modeling/official/core/tf_example_builder.py +144 -0
  31. modeling/official/core/tf_example_builder_test.py +165 -0
  32. modeling/official/core/tf_example_feature_key.py +62 -0
  33. modeling/official/core/tf_example_feature_key_test.py +49 -0
  34. modeling/official/core/train_lib.py +372 -0
  35. modeling/official/core/train_lib_test.py +280 -0
  36. modeling/official/core/train_utils.py +610 -0
  37. modeling/official/core/train_utils_test.py +215 -0
  38. modeling/official/legacy/README.md +5 -0
  39. modeling/official/legacy/__init__.py +14 -0
  40. modeling/official/legacy/albert/README.md +4 -0
  41. modeling/official/legacy/albert/__init__.py +14 -0
  42. modeling/official/legacy/albert/configs.py +50 -0
  43. modeling/official/legacy/bert/README.md +395 -0
  44. modeling/official/legacy/bert/__init__.py +15 -0
  45. modeling/official/legacy/bert/bert_cloud_tpu.md +110 -0
  46. modeling/official/legacy/bert/bert_models.py +365 -0
  47. modeling/official/legacy/bert/bert_models_test.py +106 -0
  48. modeling/official/legacy/bert/common_flags.py +125 -0
  49. modeling/official/legacy/bert/configs.py +104 -0
  50. modeling/official/legacy/bert/export_tfhub.py +139 -0
modeling/official/README-TPU.md ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Offically Supported TensorFlow 2.1+ Models on Cloud TPU
2
+
3
+ ## Natural Language Processing
4
+
5
+ * [bert](nlp/bert): A powerful pre-trained language representation model:
6
+ BERT, which stands for Bidirectional Encoder Representations from
7
+ Transformers.
8
+ [BERT FineTuning with Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/bert-2.x) provides step by step instructions on Cloud TPU training. You can look [Bert MNLI Tensorboard.dev metrics](https://tensorboard.dev/experiment/LijZ1IrERxKALQfr76gndA) for MNLI fine tuning task.
9
+ * [transformer](nlp/transformer): A transformer model to translate the WMT
10
+ English to German dataset.
11
+ [Training transformer on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/transformer-2.x) for step by step instructions on Cloud TPU training.
12
+
13
+ ## Computer Vision
14
+
15
+ * [efficientnet](vision/image_classification): A family of convolutional
16
+ neural networks that scale by balancing network depth, width, and
17
+ resolution and can be used to classify ImageNet's dataset of 1000 classes.
18
+ See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/KnaWjrq5TXGfv0NW5m7rpg/#scalars).
19
+ * [mnist](vision/image_classification): A basic model to classify digits
20
+ from the MNIST dataset. See [Running MNIST on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/mnist-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/mIah5lppTASvrHqWrdr6NA).
21
+ * [mask-rcnn](vision/detection): An object detection and instance segmentation model. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/LH7k0fMsRwqUAcE09o9kPA).
22
+ * [resnet](vision/image_classification): A deep residual network that can
23
+ be used to classify ImageNet's dataset of 1000 classes.
24
+ See [Training ResNet on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/resnet-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/CxlDK8YMRrSpYEGtBRpOhg).
25
+ * [retinanet](vision/detection): A fast and powerful object detector. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/b8NRnWU3TqG6Rw0UxueU6Q).
26
+ * [shapemask](vision/detection): An object detection and instance segmentation model using shape priors. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/ZbXgVoc6Rf6mBRlPj0JpLA).
27
+
28
+ ## Recommendation
29
+ * [dlrm](recommendation/ranking): [Deep Learning Recommendation Model for
30
+ Personalization and Recommendation Systems](https://arxiv.org/abs/1906.00091).
31
+ * [dcn v2](recommendation/ranking): [Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535).
32
+ * [ncf](recommendation): Neural Collaborative Filtering. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/0k3gKjZlR1ewkVTRyLB6IQ).
modeling/official/README.md ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <img src="https://storage.googleapis.com/tf_model_garden/tf_model_garden_logo.png">
3
+ </div>
4
+
5
+ # TensorFlow Official Models
6
+
7
+ The TensorFlow official models are a collection of models
8
+ that use TensorFlow’s high-level APIs.
9
+ They are intended to be well-maintained, tested, and kept up to date
10
+ with the latest TensorFlow API.
11
+
12
+ They should also be reasonably optimized for fast performance while still
13
+ being easy to read.
14
+ These models are used as end-to-end tests, ensuring that the models run
15
+ with the same or improved speed and performance with each new TensorFlow build.
16
+
17
+ The API documentation of the latest stable release is published to
18
+ [tensorflow.org](https://www.tensorflow.org/api_docs/python/tfm).
19
+
20
+ ## More models to come!
21
+
22
+ The team is actively developing new models.
23
+ In the near future, we will add:
24
+
25
+ * State-of-the-art language understanding models.
26
+ * State-of-the-art image classification models.
27
+ * State-of-the-art object detection and instance segmentation models.
28
+ * State-of-the-art video classification models.
29
+
30
+ ## Table of Contents
31
+
32
+ - [Models and Implementations](#models-and-implementations)
33
+ * [Computer Vision](#computer-vision)
34
+ + [Image Classification](#image-classification)
35
+ + [Object Detection and Segmentation](#object-detection-and-segmentation)
36
+ + [Video Classification](#video-classification)
37
+ * [Natural Language Processing](#natural-language-processing)
38
+ * [Recommendation](#recommendation)
39
+ - [How to get started with the official models](#how-to-get-started-with-the-official-models)
40
+ - [Contributions](#contributions)
41
+
42
+ ## Models and Implementations
43
+
44
+ ### [Computer Vision](vision/README.md)
45
+
46
+ #### Image Classification
47
+
48
+ | Model | Reference (Paper) |
49
+ |-------|-------------------|
50
+ | [ResNet](vision/MODEL_GARDEN.md) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
51
+ | [ResNet-RS](vision/MODEL_GARDEN.md) | [Revisiting ResNets: Improved Training and Scaling Strategies](https://arxiv.org/abs/2103.07579) |
52
+ | [EfficientNet](vision/MODEL_GARDEN.md) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) |
53
+ | [Vision Transformer](vision/MODEL_GARDEN.md) | [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) |
54
+
55
+ #### Object Detection and Segmentation
56
+
57
+ | Model | Reference (Paper) |
58
+ |-------|-------------------|
59
+ | [RetinaNet](vision/MODEL_GARDEN.md) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
60
+ | [Mask R-CNN](vision/MODEL_GARDEN.md) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
61
+ | [YOLO](projects/yolo/README.md) | [YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors](https://arxiv.org/abs/2207.02696) |
62
+ | [SpineNet](vision/MODEL_GARDEN.md) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
63
+ | [Cascade RCNN-RS and RetinaNet-RS](vision/MODEL_GARDEN.md) | [Simple Training Strategies and Model Scaling for Object Detection](https://arxiv.org/abs/2107.00057)|
64
+
65
+ #### Video Classification
66
+
67
+ | Model | Reference (Paper) |
68
+ |-------|-------------------|
69
+ | [Mobile Video Networks (MoViNets)](projects/movinet) | [MoViNets: Mobile Video Networks for Efficient Video Recognition](https://arxiv.org/abs/2103.11511) |
70
+
71
+ ### [Natural Language Processing](nlp/README.md)
72
+
73
+ #### Pre-trained Language Model
74
+
75
+ | Model | Reference (Paper) |
76
+ |-------|-------------------|
77
+ | [ALBERT](nlp/MODEL_GARDEN.md#available-model-configs) | [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) |
78
+ | [BERT](nlp/MODEL_GARDEN.md#available-model-configs) | [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) |
79
+ | [ELECTRA](nlp/tasks/electra_task.py) | [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://arxiv.org/abs/2003.10555) |
80
+
81
+
82
+ #### Neural Machine Translation
83
+
84
+ | Model | Reference (Paper) |
85
+ |-------|-------------------|
86
+ | [Transformer](nlp/MODEL_GARDEN.md#available-model-configs) | [Attention Is All You Need](https://arxiv.org/abs/1706.03762) |
87
+
88
+ #### Natural Language Generation
89
+
90
+ | Model | Reference (Paper) |
91
+ |-------|-------------------|
92
+ | [NHNet (News Headline generation model)](projects/nhnet) | [Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386) |
93
+
94
+
95
+ #### Knowledge Distillation
96
+
97
+ | Model | Reference (Paper) |
98
+ |-------|-------------------|
99
+ | [MobileBERT](projects/mobilebert) | [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) |
100
+
101
+ ### Recommendation
102
+
103
+ Model | Reference (Paper)
104
+ -------------------------------- | -----------------
105
+ [DLRM](recommendation/ranking) | [Deep Learning Recommendation Model for Personalization and Recommendation Systems](https://arxiv.org/abs/1906.00091)
106
+ [DCN v2](recommendation/ranking) | [Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535)
107
+ [NCF](recommendation) | [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031)
108
+
109
+ ## How to get started with the official models
110
+
111
+ * The official models in the master branch are developed using
112
+ [master branch of TensorFlow 2](https://github.com/tensorflow/tensorflow/tree/master).
113
+ When you clone (the repository) or download (`pip` binary) master branch of
114
+ official models , master branch of TensorFlow gets downloaded as a
115
+ dependency. This is equivalent to the following.
116
+
117
+ ```shell
118
+ pip3 install tf-models-nightly
119
+ pip3 install tensorflow-text-nightly # when model uses `nlp` packages
120
+ ```
121
+
122
+ * Incase of stable versions, targeting a specific release, Tensorflow-models
123
+ repository version numbers match with the target TensorFlow release. For
124
+ example, [TensorFlow-models v2.8.x](https://github.com/tensorflow/models/releases/tag/v2.8.0)
125
+ is compatible with [TensorFlow v2.8.x](https://github.com/tensorflow/tensorflow/releases/tag/v2.8.0).
126
+ This is equivalent to the following:
127
+
128
+ ```shell
129
+ pip3 install tf-models-official==2.8.0
130
+ pip3 install tensorflow-text==2.8.0 # when models in uses `nlp` packages
131
+ ```
132
+
133
+ Starting from 2.9.x release, we release the modeling library as
134
+ `tensorflow_models` package and users can `import tensorflow_models` directly to
135
+ access to the exported symbols. If you are
136
+ using the latest nightly version or github code directly, please follow the
137
+ docstrings in the github.
138
+
139
+ Please follow the below steps before running models in this repository.
140
+
141
+ ### Requirements
142
+
143
+ * The latest TensorFlow Model Garden release and the latest TensorFlow 2
144
+ * If you are on a version of TensorFlow earlier than 2.2, please
145
+ upgrade your TensorFlow to [the latest TensorFlow 2](https://www.tensorflow.org/install/).
146
+ * Python 3.7+
147
+
148
+ Our integration tests run with Python 3.7. Although Python 3.6 should work, we
149
+ don't recommend earlier versions.
150
+
151
+ ### Installation
152
+
153
+ Please check [here](https://github.com/tensorflow/models#Installation) for the
154
+ instructions.
155
+
156
+ Available pypi packages:
157
+
158
+ * [tf-models-official](https://pypi.org/project/tf-models-official/)
159
+ * [tf-models-nightly](https://pypi.org/project/tf-models-nightly/): nightly
160
+ release with the latest changes.
161
+ * [tf-models-no-deps](https://pypi.org/project/tf-models-no-deps/): without
162
+ `tensorflow` and `tensorflow-text` in the `install_requires` list.
163
+
164
+ ## Contributions
165
+
166
+ If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
modeling/official/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
modeling/official/common/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
modeling/official/common/dataset_fn.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Copyright 2020 The TensorFlow Authors. All Rights Reserved.
16
+ #
17
+ # Licensed under the Apache License, Version 2.0 (the "License");
18
+ # you may not use this file except in compliance with the License.
19
+ # You may obtain a copy of the License at
20
+ #
21
+ # http://www.apache.org/licenses/LICENSE-2.0
22
+ #
23
+ # Unless required by applicable law or agreed to in writing, software
24
+ # distributed under the License is distributed on an "AS IS" BASIS,
25
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26
+ # See the License for the specific language governing permissions and
27
+ # limitations under the License.
28
+ # ==============================================================================
29
+ """Utility library for picking an appropriate dataset function."""
30
+
31
+ import functools
32
+ from typing import Any, Callable, Type, Union
33
+
34
+ import tensorflow as tf, tf_keras
35
+
36
+ PossibleDatasetType = Union[Type[tf.data.Dataset], Callable[[tf.Tensor], Any]]
37
+
38
+
39
+ def pick_dataset_fn(file_type: str) -> PossibleDatasetType:
40
+ if file_type == 'tfrecord':
41
+ return tf.data.TFRecordDataset
42
+ if file_type == 'tfrecord_compressed':
43
+ return functools.partial(tf.data.TFRecordDataset, compression_type='GZIP')
44
+ raise ValueError('Unrecognized file_type: {}'.format(file_type))
modeling/official/common/distribute_utils.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Helper functions for running models in a distributed setting."""
16
+
17
+ import json
18
+ import os
19
+ import tensorflow as tf, tf_keras
20
+
21
+
22
+ def _collective_communication(all_reduce_alg):
23
+ """Return a CollectiveCommunication based on all_reduce_alg.
24
+
25
+ Args:
26
+ all_reduce_alg: a string specifying which collective communication to pick,
27
+ or None.
28
+
29
+ Returns:
30
+ tf.distribute.experimental.CollectiveCommunication object
31
+
32
+ Raises:
33
+ ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"]
34
+ """
35
+ collective_communication_options = {
36
+ None: tf.distribute.experimental.CollectiveCommunication.AUTO,
37
+ "ring": tf.distribute.experimental.CollectiveCommunication.RING,
38
+ "nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
39
+ }
40
+ if all_reduce_alg not in collective_communication_options:
41
+ raise ValueError(
42
+ "When used with `multi_worker_mirrored`, valid values for "
43
+ "all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
44
+ all_reduce_alg))
45
+ return collective_communication_options[all_reduce_alg]
46
+
47
+
48
+ def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
49
+ """Return a CrossDeviceOps based on all_reduce_alg and num_packs.
50
+
51
+ Args:
52
+ all_reduce_alg: a string specifying which cross device op to pick, or None.
53
+ num_packs: an integer specifying number of packs for the cross device op.
54
+
55
+ Returns:
56
+ tf.distribute.CrossDeviceOps object or None.
57
+
58
+ Raises:
59
+ ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
60
+ """
61
+ if all_reduce_alg is None:
62
+ return None
63
+ mirrored_all_reduce_options = {
64
+ "nccl": tf.distribute.NcclAllReduce,
65
+ "hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
66
+ }
67
+ if all_reduce_alg not in mirrored_all_reduce_options:
68
+ raise ValueError(
69
+ "When used with `mirrored`, valid values for all_reduce_alg are "
70
+ "[`nccl`, `hierarchical_copy`]. Supplied value: {}".format(
71
+ all_reduce_alg))
72
+ cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
73
+ return cross_device_ops_class(num_packs=num_packs)
74
+
75
+
76
+ def tpu_initialize(tpu_address):
77
+ """Initializes TPU for TF 2.x training.
78
+
79
+ Args:
80
+ tpu_address: string, bns address of master TPU worker.
81
+
82
+ Returns:
83
+ A TPUClusterResolver.
84
+ """
85
+ cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
86
+ tpu=tpu_address)
87
+ if tpu_address not in ("", "local"):
88
+ tf.config.experimental_connect_to_cluster(cluster_resolver)
89
+ tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
90
+ return cluster_resolver
91
+
92
+
93
+ def get_distribution_strategy(distribution_strategy="mirrored",
94
+ num_gpus=0,
95
+ all_reduce_alg=None,
96
+ num_packs=1,
97
+ tpu_address=None,
98
+ **kwargs):
99
+ """Return a Strategy for running the model.
100
+
101
+ Args:
102
+ distribution_strategy: a string specifying which distribution strategy to
103
+ use. Accepted values are "off", "one_device", "mirrored",
104
+ "parameter_server", "multi_worker_mirrored", and "tpu" -- case
105
+ insensitive. "tpu" means to use TPUStrategy using `tpu_address`.
106
+ "off" means to use the default strategy which is obtained from
107
+ tf.distribute.get_strategy (for details on the default strategy, see
108
+ https://www.tensorflow.org/guide/distributed_training#default_strategy).
109
+ num_gpus: Number of GPUs to run this model.
110
+ all_reduce_alg: Optional. Specifies which algorithm to use when performing
111
+ all-reduce. For `MirroredStrategy`, valid values are "nccl" and
112
+ "hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
113
+ "ring" and "nccl". If None, DistributionStrategy will choose based on
114
+ device topology.
115
+ num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
116
+ or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
117
+ tpu_address: Optional. String that represents TPU to connect to. Must not be
118
+ None if `distribution_strategy` is set to `tpu`.
119
+ **kwargs: Additional kwargs for internal usages.
120
+
121
+ Returns:
122
+ tf.distribute.Strategy object.
123
+ Raises:
124
+ ValueError: if `distribution_strategy` is "off" or "one_device" and
125
+ `num_gpus` is larger than 1; or `num_gpus` is negative or if
126
+ `distribution_strategy` is `tpu` but `tpu_address` is not specified.
127
+ """
128
+ del kwargs
129
+ if num_gpus < 0:
130
+ raise ValueError("`num_gpus` can not be negative.")
131
+
132
+ if not isinstance(distribution_strategy, str):
133
+ msg = ("distribution_strategy must be a string but got: %s." %
134
+ (distribution_strategy,))
135
+ if distribution_strategy == False: # pylint: disable=singleton-comparison,g-explicit-bool-comparison
136
+ msg += (" If you meant to pass the string 'off', make sure you add "
137
+ "quotes around 'off' so that yaml interprets it as a string "
138
+ "instead of a bool.")
139
+ raise ValueError(msg)
140
+
141
+ distribution_strategy = distribution_strategy.lower()
142
+ if distribution_strategy == "off":
143
+ if num_gpus > 1:
144
+ raise ValueError(f"When {num_gpus} GPUs are specified, "
145
+ "distribution_strategy flag cannot be set to `off`.")
146
+ # Return the default distribution strategy.
147
+ return tf.distribute.get_strategy()
148
+
149
+ if distribution_strategy == "tpu":
150
+ # When tpu_address is an empty string, we communicate with local TPUs.
151
+ cluster_resolver = tpu_initialize(tpu_address)
152
+ return tf.distribute.TPUStrategy(cluster_resolver)
153
+
154
+ if distribution_strategy == "multi_worker_mirrored":
155
+ return tf.distribute.experimental.MultiWorkerMirroredStrategy(
156
+ communication=_collective_communication(all_reduce_alg))
157
+
158
+ if distribution_strategy == "one_device":
159
+ if num_gpus == 0:
160
+ return tf.distribute.OneDeviceStrategy("device:CPU:0")
161
+ if num_gpus > 1:
162
+ raise ValueError("`OneDeviceStrategy` can not be used for more than "
163
+ "one device.")
164
+ return tf.distribute.OneDeviceStrategy("device:GPU:0")
165
+
166
+ if distribution_strategy == "mirrored":
167
+ if num_gpus == 0:
168
+ devices = ["device:CPU:0"]
169
+ else:
170
+ devices = ["device:GPU:%d" % i for i in range(num_gpus)]
171
+ return tf.distribute.MirroredStrategy(
172
+ devices=devices,
173
+ cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))
174
+
175
+ if distribution_strategy == "parameter_server":
176
+ cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
177
+ return tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
178
+
179
+ raise ValueError("Unrecognized Distribution Strategy: %r" %
180
+ distribution_strategy)
181
+
182
+
183
+ def configure_cluster(worker_hosts=None, task_index=-1):
184
+ """Set multi-worker cluster spec in TF_CONFIG environment variable.
185
+
186
+ Args:
187
+ worker_hosts: comma-separated list of worker ip:port pairs.
188
+ task_index: index of the worker.
189
+
190
+ Returns:
191
+ Number of workers in the cluster.
192
+ """
193
+ tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
194
+ if tf_config:
195
+ num_workers = (
196
+ len(tf_config["cluster"].get("chief", [])) +
197
+ len(tf_config["cluster"].get("worker", [])))
198
+ elif worker_hosts:
199
+ workers = worker_hosts.split(",")
200
+ num_workers = len(workers)
201
+ if num_workers > 1 and task_index < 0:
202
+ raise ValueError("Must specify task_index when number of workers > 1")
203
+ task_index = 0 if num_workers == 1 else task_index
204
+ os.environ["TF_CONFIG"] = json.dumps({
205
+ "cluster": {
206
+ "worker": workers
207
+ },
208
+ "task": {
209
+ "type": "worker",
210
+ "index": task_index
211
+ }
212
+ })
213
+ else:
214
+ num_workers = 1
215
+ return num_workers
216
+
217
+
218
+ def get_strategy_scope(strategy):
219
+ if strategy:
220
+ strategy_scope = strategy.scope()
221
+ else:
222
+ strategy_scope = DummyContextManager()
223
+
224
+ return strategy_scope
225
+
226
+
227
+ class DummyContextManager(object):
228
+
229
+ def __enter__(self):
230
+ pass
231
+
232
+ def __exit__(self, *args):
233
+ pass
modeling/official/common/distribute_utils_test.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for distribution util functions."""
16
+
17
+ import sys
18
+ import tensorflow as tf, tf_keras
19
+
20
+ from official.common import distribute_utils
21
+
22
+ TPU_TEST = 'test_tpu' in sys.argv[0]
23
+
24
+
25
+ class DistributeUtilsTest(tf.test.TestCase):
26
+ """Tests for distribute util functions."""
27
+
28
+ def test_invalid_args(self):
29
+ with self.assertRaisesRegex(ValueError, '`num_gpus` can not be negative.'):
30
+ _ = distribute_utils.get_distribution_strategy(num_gpus=-1)
31
+
32
+ with self.assertRaisesRegex(ValueError,
33
+ '.*If you meant to pass the string .*'):
34
+ _ = distribute_utils.get_distribution_strategy(
35
+ distribution_strategy=False, num_gpus=0)
36
+ with self.assertRaisesRegex(ValueError, 'When 2 GPUs are specified.*'):
37
+ _ = distribute_utils.get_distribution_strategy(
38
+ distribution_strategy='off', num_gpus=2)
39
+ with self.assertRaisesRegex(ValueError,
40
+ '`OneDeviceStrategy` can not be used.*'):
41
+ _ = distribute_utils.get_distribution_strategy(
42
+ distribution_strategy='one_device', num_gpus=2)
43
+
44
+ def test_one_device_strategy_cpu(self):
45
+ ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
46
+ self.assertEquals(ds.num_replicas_in_sync, 1)
47
+ self.assertEquals(len(ds.extended.worker_devices), 1)
48
+ self.assertIn('CPU', ds.extended.worker_devices[0])
49
+
50
+ def test_one_device_strategy_gpu(self):
51
+ ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=1)
52
+ self.assertEquals(ds.num_replicas_in_sync, 1)
53
+ self.assertEquals(len(ds.extended.worker_devices), 1)
54
+ self.assertIn('GPU', ds.extended.worker_devices[0])
55
+
56
+ def test_mirrored_strategy(self):
57
+ # CPU only.
58
+ _ = distribute_utils.get_distribution_strategy(num_gpus=0)
59
+ # 5 GPUs.
60
+ ds = distribute_utils.get_distribution_strategy(num_gpus=5)
61
+ self.assertEquals(ds.num_replicas_in_sync, 5)
62
+ self.assertEquals(len(ds.extended.worker_devices), 5)
63
+ for device in ds.extended.worker_devices:
64
+ self.assertIn('GPU', device)
65
+
66
+ _ = distribute_utils.get_distribution_strategy(
67
+ distribution_strategy='mirrored',
68
+ num_gpus=2,
69
+ all_reduce_alg='nccl',
70
+ num_packs=2)
71
+ with self.assertRaisesRegex(
72
+ ValueError,
73
+ 'When used with `mirrored`, valid values for all_reduce_alg are.*'):
74
+ _ = distribute_utils.get_distribution_strategy(
75
+ distribution_strategy='mirrored',
76
+ num_gpus=2,
77
+ all_reduce_alg='dummy',
78
+ num_packs=2)
79
+
80
+ def test_mwms(self):
81
+ distribute_utils.configure_cluster(worker_hosts=None, task_index=-1)
82
+ ds = distribute_utils.get_distribution_strategy(
83
+ 'multi_worker_mirrored', all_reduce_alg='nccl')
84
+ self.assertIsInstance(
85
+ ds, tf.distribute.experimental.MultiWorkerMirroredStrategy)
86
+
87
+ with self.assertRaisesRegex(
88
+ ValueError,
89
+ 'When used with `multi_worker_mirrored`, valid values.*'):
90
+ _ = distribute_utils.get_distribution_strategy(
91
+ 'multi_worker_mirrored', all_reduce_alg='dummy')
92
+
93
+ def test_no_strategy(self):
94
+ ds = distribute_utils.get_distribution_strategy('off')
95
+ self.assertIs(ds, tf.distribute.get_strategy())
96
+
97
+ def test_tpu_strategy(self):
98
+ if not TPU_TEST:
99
+ self.skipTest('Only Cloud TPU VM instances can have local TPUs.')
100
+ with self.assertRaises(ValueError):
101
+ _ = distribute_utils.get_distribution_strategy('tpu')
102
+
103
+ ds = distribute_utils.get_distribution_strategy('tpu', tpu_address='local')
104
+ self.assertIsInstance(
105
+ ds, tf.distribute.TPUStrategy)
106
+
107
+ def test_invalid_strategy(self):
108
+ with self.assertRaisesRegexp(
109
+ ValueError,
110
+ 'distribution_strategy must be a string but got: False. If'):
111
+ distribute_utils.get_distribution_strategy(False)
112
+ with self.assertRaisesRegexp(
113
+ ValueError, 'distribution_strategy must be a string but got: 1'):
114
+ distribute_utils.get_distribution_strategy(1)
115
+
116
+ def test_get_strategy_scope(self):
117
+ ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
118
+ with distribute_utils.get_strategy_scope(ds):
119
+ self.assertIs(tf.distribute.get_strategy(), ds)
120
+ with distribute_utils.get_strategy_scope(None):
121
+ self.assertIsNot(tf.distribute.get_strategy(), ds)
122
+
123
+ if __name__ == '__main__':
124
+ tf.test.main()
modeling/official/common/flags.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """The central place to define flags."""
16
+
17
+ from absl import flags
18
+
19
+
20
+ def define_flags():
21
+ """Defines flags.
22
+
23
+ All flags are defined as optional, but in practice most models use some of
24
+ these flags and so mark_flags_as_required() should be called after calling
25
+ this function. Typically, 'experiment', 'mode', and 'model_dir' are required.
26
+ For example:
27
+
28
+ ```
29
+ from absl import flags
30
+ from official.common import flags as tfm_flags # pylint: disable=line-too-long
31
+ ...
32
+ tfm_flags.define_flags()
33
+ flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
34
+ ```
35
+
36
+ The reason all flags are optional is because unit tests often do not set or
37
+ use any of the flags.
38
+ """
39
+ flags.DEFINE_string(
40
+ 'experiment', default=None, help=
41
+ 'The experiment type registered, specifying an ExperimentConfig.')
42
+
43
+ flags.DEFINE_enum(
44
+ 'mode',
45
+ default=None,
46
+ enum_values=[
47
+ 'train', 'eval', 'train_and_eval', 'continuous_eval',
48
+ 'continuous_train_and_eval', 'train_and_validate',
49
+ 'train_and_post_eval'
50
+ ],
51
+ help='Mode to run: `train`, `eval`, `train_and_eval`, '
52
+ '`continuous_eval`, `continuous_train_and_eval` and '
53
+ '`train_and_validate` (which is not implemented in '
54
+ 'the open source version).')
55
+
56
+ flags.DEFINE_string(
57
+ 'model_dir',
58
+ default=None,
59
+ help='The directory where the model and training/evaluation summaries'
60
+ 'are stored.')
61
+
62
+ flags.DEFINE_multi_string(
63
+ 'config_file',
64
+ default=None,
65
+ help='YAML/JSON files which specifies overrides. The override order '
66
+ 'follows the order of args. Note that each file '
67
+ 'can be used as an override template to override the default parameters '
68
+ 'specified in Python. If the same parameter is specified in both '
69
+ '`--config_file` and `--params_override`, `config_file` will be used '
70
+ 'first, followed by params_override.')
71
+
72
+ flags.DEFINE_string(
73
+ 'params_override',
74
+ default=None,
75
+ help='a YAML/JSON string or a YAML file which specifies additional '
76
+ 'overrides over the default parameters and those specified in '
77
+ '`--config_file`. Note that this is supposed to be used only to override '
78
+ 'the model parameters, but not the parameters like TPU specific flags. '
79
+ 'One canonical use case of `--config_file` and `--params_override` is '
80
+ 'users first define a template config file using `--config_file`, then '
81
+ 'use `--params_override` to adjust the minimal set of tuning parameters, '
82
+ 'for example setting up different `train_batch_size`. The final override '
83
+ 'order of parameters: default_model_params --> params from config_file '
84
+ '--> params in params_override. See also the help message of '
85
+ '`--config_file`.')
86
+
87
+ # The libraries rely on gin often make mistakes that include flags inside
88
+ # the library files which causes conflicts.
89
+ try:
90
+ flags.DEFINE_multi_string(
91
+ 'gin_file', default=None, help='List of paths to the config files.')
92
+ except flags.DuplicateFlagError:
93
+ pass
94
+
95
+ try:
96
+ flags.DEFINE_multi_string(
97
+ 'gin_params',
98
+ default=None,
99
+ help='Newline separated list of Gin parameter bindings.')
100
+ except flags.DuplicateFlagError:
101
+ pass
102
+
103
+ flags.DEFINE_string(
104
+ 'tpu',
105
+ default=None,
106
+ help='The Cloud TPU to use for training. This should be either the name '
107
+ 'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
108
+ 'url.')
109
+
110
+ flags.DEFINE_string(
111
+ 'tf_data_service', default=None, help='The tf.data service address')
112
+
113
+ flags.DEFINE_string(
114
+ 'tpu_platform', default=None, help='TPU platform type.')
modeling/official/common/registry_imports.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """All necessary imports for registration."""
16
+ # pylint: disable=unused-import
17
+ from official import vision
18
+ from official.nlp import tasks
19
+ from official.nlp.configs import experiment_configs
20
+ from official.utils.testing import mock_task
modeling/official/common/streamz_counters.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Global streamz counters."""
16
+
17
+ from tensorflow.python.eager import monitoring
18
+
19
+
20
+ progressive_policy_creation_counter = monitoring.Counter(
21
+ "/tensorflow/training/fast_training/progressive_policy_creation",
22
+ "Counter for the number of ProgressivePolicy creations.")
23
+
24
+
25
+ stack_vars_to_vars_call_counter = monitoring.Counter(
26
+ "/tensorflow/training/fast_training/tf_vars_to_vars",
27
+ "Counter for the number of low-level stacking API calls.")
modeling/official/core/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Core is shared by both `nlp` and `vision`."""
16
+
17
+ from official.core import actions
18
+ from official.core import base_task
19
+ from official.core import base_trainer
20
+ from official.core import config_definitions
21
+ from official.core import exp_factory
22
+ from official.core import export_base
23
+ from official.core import file_writers
24
+ from official.core import input_reader
25
+ from official.core import registry
26
+ from official.core import savedmodel_checkpoint_manager
27
+ from official.core import task_factory
28
+ from official.core import tf_example_builder
29
+ from official.core import tf_example_feature_key
30
+ from official.core import train_lib
31
+ from official.core import train_utils
modeling/official/core/actions.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Provides TFM orbit actions and associated helper functions/classes."""
16
+
17
+ import os
18
+ from typing import List
19
+ from absl import logging
20
+
21
+ import gin
22
+ import orbit
23
+ import tensorflow as tf, tf_keras
24
+
25
+ from official.core import base_trainer
26
+ from official.core import config_definitions
27
+ from official.modeling import optimization
28
+
29
+
30
+ class PruningAction:
31
+ """Train action to updates pruning related information.
32
+
33
+ This action updates pruning steps at the end of trainig loop, and log
34
+ pruning metrics to tensorboard.
35
+
36
+ This action must be used when training a pruned model to avoid pruning error.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ export_dir: str,
42
+ model: tf_keras.Model,
43
+ optimizer: tf_keras.optimizers.Optimizer,
44
+ ):
45
+ """Initializes the instance.
46
+
47
+ Args:
48
+ export_dir: `str` for the export directory of the pruning summaries.
49
+ model: `tf_keras.Model` model instance used for training. This will be
50
+ used to assign a pruning step to each prunable weight.
51
+ optimizer: `tf_keras.optimizers.Optimizer` optimizer instance used for
52
+ training. This will be used to find the current training steps.
53
+ """
54
+ # TODO(b/221490190): Avoid local import when the bug is fixed.
55
+ import tensorflow_model_optimization as tfmot # pylint: disable=g-import-not-at-top
56
+ self._optimizer = optimizer
57
+ self.update_pruning_step = tfmot.sparsity.keras.UpdatePruningStep()
58
+ self.update_pruning_step.set_model(model)
59
+ self.update_pruning_step.on_train_begin()
60
+
61
+ self.pruning_summaries = tfmot.sparsity.keras.PruningSummaries(
62
+ log_dir=export_dir)
63
+ model.optimizer = optimizer
64
+ self.pruning_summaries.set_model(model)
65
+
66
+ def __call__(self, output: orbit.runner.Output):
67
+ """Update pruning step and log pruning summaries.
68
+
69
+ Args:
70
+ output: The train output.
71
+ """
72
+ self.update_pruning_step.on_epoch_end(batch=None)
73
+ self.pruning_summaries.on_epoch_begin(epoch=None)
74
+
75
+
76
+ class EMACheckpointing:
77
+ """Eval action to save checkpoint with average weights when EMA is used.
78
+
79
+ This action swaps the weights of the model with the average weights, then it
80
+ saves the checkpoint under export_dir/ema_checkpoints. Checkpointing is
81
+ expensive for large models, so doing this action in eval is more efficient
82
+ than training.
83
+ """
84
+
85
+ def __init__(self,
86
+ export_dir: str,
87
+ optimizer: tf_keras.optimizers.Optimizer,
88
+ checkpoint: tf.train.Checkpoint,
89
+ max_to_keep: int = 1):
90
+ """Initializes the instance.
91
+
92
+ Args:
93
+ export_dir: `str` for the export directory of the EMA average weights.
94
+ optimizer: `tf_keras.optimizers.Optimizer` optimizer instance used for
95
+ training. This will be used to swap the model weights with the average
96
+ weigths.
97
+ checkpoint: `tf.train.Checkpoint` instance.
98
+ max_to_keep: `int` for max checkpoints to keep in ema_checkpoints subdir.
99
+ """
100
+ if not isinstance(optimizer, optimization.ExponentialMovingAverage):
101
+ raise ValueError('Optimizer has to be instance of'
102
+ 'optimization.ExponentialMovingAverage for'
103
+ 'EMACheckpointing action')
104
+
105
+ export_dir = os.path.join(export_dir, 'ema_checkpoints')
106
+ tf.io.gfile.makedirs(os.path.dirname(export_dir))
107
+ self._optimizer = optimizer
108
+ self._checkpoint = checkpoint
109
+ self._checkpoint_manager = tf.train.CheckpointManager(
110
+ checkpoint,
111
+ directory=export_dir,
112
+ max_to_keep=max_to_keep,
113
+ checkpoint_name='average_weights')
114
+
115
+ def __call__(self, output: orbit.runner.Output):
116
+ """Swaps model weights, and saves the checkpoint.
117
+
118
+ Args:
119
+ output: The train or eval output.
120
+ """
121
+ self._optimizer.swap_weights()
122
+ self._checkpoint_manager.save(checkpoint_number=self._optimizer.iterations)
123
+ self._optimizer.swap_weights()
124
+
125
+
126
+ class RecoveryAction:
127
+ """Train action to recover from loss blowup.
128
+
129
+ Checks the loss value by the given threshold. If applicable, recover the
130
+ model by reading the checkpoint on disk.
131
+ """
132
+
133
+ def __init__(self, checkpoint_manager: tf.train.CheckpointManager):
134
+ self.checkpoint_manager = checkpoint_manager
135
+
136
+ def __call__(self, _):
137
+ """Recovers the training by triggering checkpoint restoration."""
138
+ # Loads the previous good checkpoint.
139
+ checkpoint_path = self.checkpoint_manager.restore_or_initialize()
140
+ logging.warning('Recovering the model from checkpoint: %s.',
141
+ checkpoint_path)
142
+
143
+
144
+ class RecoveryCondition:
145
+ """Recovery Condition."""
146
+
147
+ def __init__(self,
148
+ global_step: tf.Variable,
149
+ loss_upper_bound: float,
150
+ recovery_begin_steps: int = 0,
151
+ recovery_max_trials: int = 3):
152
+ self.recover_counter = 0
153
+ self.recovery_begin_steps = recovery_begin_steps
154
+ self.recovery_max_trials = recovery_max_trials
155
+ self.loss_upper_bound = loss_upper_bound
156
+ self.global_step = global_step
157
+
158
+ def __call__(self, outputs: orbit.runner.Output):
159
+ loss_value = outputs['training_loss']
160
+ if tf.math.is_nan(loss_value):
161
+ self.recover_counter += 1
162
+ if self.recover_counter > self.recovery_max_trials:
163
+ raise RuntimeError(
164
+ 'The loss value is NaN after training loop and it happens %d times.'
165
+ % self.recover_counter)
166
+ return True
167
+ if (self.global_step >= self.recovery_begin_steps and
168
+ loss_value > self.loss_upper_bound):
169
+ self.recover_counter += 1
170
+ if self.recover_counter > self.recovery_max_trials:
171
+ raise RuntimeError(
172
+ f'The loss value is {loss_value}, which is larger than the bound {self.loss_upper_bound}, happens {self.recover_counter} times.'
173
+ )
174
+ return True
175
+ return False
176
+
177
+
178
+ @gin.configurable
179
+ def get_eval_actions(params: config_definitions.ExperimentConfig,
180
+ trainer: base_trainer.Trainer,
181
+ model_dir: str) -> List[orbit.Action]:
182
+ """Gets eval actions for TFM trainer."""
183
+ eval_actions = []
184
+ # Adds ema checkpointing action to save the average weights under
185
+ # ema_checkpoints subdir.
186
+ if isinstance(trainer.optimizer, optimization.ExponentialMovingAverage):
187
+ eval_actions.append(
188
+ EMACheckpointing(
189
+ export_dir=model_dir,
190
+ optimizer=trainer.optimizer,
191
+ checkpoint=trainer.checkpoint,
192
+ max_to_keep=params.trainer.max_to_keep))
193
+
194
+ return eval_actions
195
+
196
+
197
+ @gin.configurable
198
+ def get_train_actions(
199
+ params: config_definitions.ExperimentConfig, trainer: base_trainer.Trainer,
200
+ model_dir: str,
201
+ checkpoint_manager: tf.train.CheckpointManager) -> List[orbit.Action]:
202
+ """Gets train actions for TFM trainer."""
203
+ train_actions = []
204
+ # Adds pruning callback actions.
205
+ if hasattr(params.task, 'pruning') and params.task.pruning:
206
+ train_actions.append(
207
+ PruningAction(
208
+ export_dir=model_dir,
209
+ model=trainer.model,
210
+ optimizer=trainer.optimizer))
211
+
212
+ if params.trainer.recovery_max_trials >= 0:
213
+ recovery_condition = RecoveryCondition(
214
+ global_step=trainer.global_step,
215
+ loss_upper_bound=params.trainer.loss_upper_bound,
216
+ recovery_begin_steps=params.trainer.recovery_begin_steps,
217
+ recovery_max_trials=params.trainer.recovery_max_trials,
218
+ )
219
+ recover_action = orbit.actions.ConditionalAction(
220
+ condition=recovery_condition,
221
+ action=RecoveryAction(checkpoint_manager),
222
+ )
223
+ train_actions.append(recover_action)
224
+
225
+ if (
226
+ params.trainer.preemption_on_demand_checkpoint
227
+ and trainer.strategy.cluster_resolver
228
+ ):
229
+ on_demand_checkpoint_action = orbit.actions.SaveCheckpointIfPreempted(
230
+ trainer.strategy.cluster_resolver,
231
+ checkpoint_manager,
232
+ trainer.global_step,
233
+ keep_running_after_save=True,
234
+ )
235
+ train_actions.append(on_demand_checkpoint_action)
236
+ return train_actions
modeling/official/core/actions_test.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for TFM actions."""
16
+
17
+ import os
18
+
19
+ from absl.testing import parameterized
20
+ import numpy as np
21
+ import orbit
22
+ import tensorflow as tf, tf_keras
23
+
24
+ from tensorflow.python.distribute import combinations
25
+ from tensorflow.python.distribute import strategy_combinations
26
+ from official.core import actions
27
+ from official.modeling import optimization
28
+
29
+
30
+ class TestModel(tf_keras.Model):
31
+
32
+ def __init__(self):
33
+ super().__init__()
34
+ self.value = tf.Variable(0.0)
35
+ self.dense = tf_keras.layers.Dense(2)
36
+ _ = self.dense(tf.zeros((2, 2), tf.float32))
37
+
38
+ def call(self, x, training=None):
39
+ return self.value + x
40
+
41
+
42
+ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
43
+
44
+ @combinations.generate(
45
+ combinations.combine(
46
+ distribution=[
47
+ strategy_combinations.cloud_tpu_strategy,
48
+ strategy_combinations.one_device_strategy,
49
+ ],))
50
+ def test_ema_checkpointing(self, distribution):
51
+ with distribution.scope():
52
+ directory = self.create_tempdir()
53
+ model = TestModel()
54
+ optimizer = tf_keras.optimizers.SGD()
55
+ optimizer = optimization.ExponentialMovingAverage(
56
+ optimizer, trainable_weights_only=False)
57
+
58
+ # Creats average weights for the model variables. Average weights are
59
+ # initialized to zero.
60
+ optimizer.shadow_copy(model)
61
+ checkpoint = tf.train.Checkpoint(model=model)
62
+
63
+ # Changes model.value to 3, average value is still 0.
64
+ model.value.assign(3)
65
+
66
+ # Checks model.value is 3
67
+ self.assertEqual(model(0.), 3)
68
+ ema_action = actions.EMACheckpointing(directory, optimizer, checkpoint)
69
+
70
+ ema_action({})
71
+ self.assertNotEmpty(
72
+ tf.io.gfile.glob(os.path.join(directory, 'ema_checkpoints')))
73
+
74
+ checkpoint.read(
75
+ tf.train.latest_checkpoint(
76
+ os.path.join(directory, 'ema_checkpoints')))
77
+
78
+ # Checks model.value is 0 after swapping.
79
+ self.assertEqual(model(0.), 0)
80
+
81
+ # Raises an error for a normal optimizer.
82
+ with self.assertRaisesRegex(ValueError,
83
+ 'Optimizer has to be instance of.*'):
84
+ _ = actions.EMACheckpointing(directory, tf_keras.optimizers.SGD(),
85
+ checkpoint)
86
+
87
+ @combinations.generate(
88
+ combinations.combine(
89
+ distribution=[
90
+ strategy_combinations.default_strategy,
91
+ strategy_combinations.cloud_tpu_strategy,
92
+ strategy_combinations.one_device_strategy_gpu,
93
+ ],))
94
+ def test_recovery_condition(self, distribution):
95
+ with distribution.scope():
96
+ global_step = orbit.utils.create_global_step()
97
+ recover_condition = actions.RecoveryCondition(
98
+ global_step, loss_upper_bound=0.5, recovery_max_trials=2)
99
+ outputs = {'training_loss': 0.6}
100
+ self.assertTrue(recover_condition(outputs))
101
+ self.assertTrue(recover_condition(outputs))
102
+ with self.assertRaises(RuntimeError):
103
+ recover_condition(outputs)
104
+
105
+ global_step = orbit.utils.create_global_step()
106
+ recover_condition = actions.RecoveryCondition(
107
+ global_step, loss_upper_bound=0.5, recovery_max_trials=2)
108
+ outputs = {'training_loss': tf.constant([np.nan], tf.float32)}
109
+ self.assertTrue(recover_condition(outputs))
110
+ self.assertTrue(recover_condition(outputs))
111
+ with self.assertRaises(RuntimeError):
112
+ recover_condition(outputs)
113
+
114
+ @combinations.generate(
115
+ combinations.combine(
116
+ distribution=[
117
+ strategy_combinations.one_device_strategy_gpu,
118
+ strategy_combinations.one_device_strategy,
119
+ ],))
120
+ def test_pruning(self, distribution):
121
+ with distribution.scope():
122
+ directory = self.get_temp_dir()
123
+ model = TestModel()
124
+ optimizer = tf_keras.optimizers.SGD()
125
+ pruning = actions.PruningAction(directory, model, optimizer)
126
+
127
+ pruning({})
128
+
129
+
130
+ if __name__ == '__main__':
131
+ tf.test.main()
modeling/official/core/base_task.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Defines the base task abstraction."""
16
+ import abc
17
+ import functools
18
+ from typing import Optional
19
+
20
+ from absl import logging
21
+ import tensorflow as tf, tf_keras
22
+
23
+ from official.core import config_definitions
24
+ from official.modeling import optimization
25
+ from official.modeling import performance
26
+ from official.modeling.privacy import configs
27
+ from official.modeling.privacy import ops
28
+
29
+ OptimizationConfig = optimization.OptimizationConfig
30
+ RuntimeConfig = config_definitions.RuntimeConfig
31
+ DifferentialPrivacyConfig = configs.DifferentialPrivacyConfig
32
+
33
+
34
+ class Task(tf.Module, metaclass=abc.ABCMeta):
35
+ """A single-replica view of training procedure.
36
+
37
+ Tasks provide artifacts for training/validation procedures, including
38
+ loading/iterating over Datasets, training/validation steps, calculating the
39
+ loss and customized metrics with reduction.
40
+ """
41
+
42
+ # Special keys in train/validate step returned logs.
43
+ loss = "loss"
44
+
45
+ def __init__(self,
46
+ params,
47
+ logging_dir: Optional[str] = None,
48
+ name: Optional[str] = None):
49
+ """Task initialization.
50
+
51
+ Args:
52
+ params: the task configuration instance, which can be any of dataclass,
53
+ ConfigDict, namedtuple, etc.
54
+ logging_dir: a string pointing to where the model, summaries etc. will be
55
+ saved. You can also write additional stuff in this directory.
56
+ name: the task name.
57
+ """
58
+ super().__init__(name=name)
59
+ self._task_config = params
60
+ self._logging_dir = (
61
+ logging_dir or ""
62
+ ) # Empty directory hints current working dir.
63
+
64
+ @property
65
+ def task_config(self):
66
+ return self._task_config
67
+
68
+ @property
69
+ def logging_dir(self) -> str:
70
+ return self._logging_dir
71
+
72
+ @classmethod
73
+ def create_optimizer(cls, optimizer_config: OptimizationConfig,
74
+ runtime_config: Optional[RuntimeConfig] = None,
75
+ dp_config: Optional[DifferentialPrivacyConfig] = None):
76
+ """Creates an TF optimizer from configurations.
77
+
78
+ Args:
79
+ optimizer_config: the parameters of the Optimization settings.
80
+ runtime_config: the parameters of the runtime.
81
+ dp_config: the parameter of differential privacy.
82
+
83
+ Returns:
84
+ A tf.optimizers.Optimizer object.
85
+ """
86
+ gradient_transformers = None
87
+ if dp_config is not None:
88
+ logging.info("Adding differential privacy transform with config %s.",
89
+ dp_config.as_dict())
90
+ noise_stddev = dp_config.clipping_norm * dp_config.noise_multiplier
91
+ gradient_transformers = [
92
+ functools.partial(
93
+ ops.clip_l2_norm, l2_norm_clip=dp_config.clipping_norm),
94
+ functools.partial(
95
+ ops.add_noise, noise_stddev=noise_stddev)
96
+ ]
97
+
98
+ opt_factory = optimization.OptimizerFactory(optimizer_config)
99
+ optimizer = opt_factory.build_optimizer(
100
+ opt_factory.build_learning_rate(),
101
+ gradient_transformers=gradient_transformers
102
+ )
103
+ # Configuring optimizer when loss_scale is set in runtime config. This helps
104
+ # avoiding overflow/underflow for float16 computations.
105
+ if runtime_config:
106
+ optimizer = performance.configure_optimizer(
107
+ optimizer,
108
+ use_float16=runtime_config.mixed_precision_dtype == "float16",
109
+ loss_scale=runtime_config.loss_scale)
110
+
111
+ return optimizer
112
+
113
+ def initialize(self, model: tf_keras.Model):
114
+ """[Optional] A callback function used as CheckpointManager's init_fn.
115
+
116
+ This function will be called when no checkpoint is found for the model.
117
+ If there is a checkpoint, the checkpoint will be loaded and this function
118
+ will not be called. You can use this callback function to load a pretrained
119
+ checkpoint, saved under a directory other than the model_dir.
120
+
121
+ Args:
122
+ model: The keras.Model built or used by this task.
123
+ """
124
+ ckpt_dir_or_file = self.task_config.init_checkpoint
125
+ logging.info("Trying to load pretrained checkpoint from %s",
126
+ ckpt_dir_or_file)
127
+ if ckpt_dir_or_file and tf.io.gfile.isdir(ckpt_dir_or_file):
128
+ ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
129
+ if not ckpt_dir_or_file:
130
+ logging.info("No checkpoint file found from %s. Will not load.",
131
+ ckpt_dir_or_file)
132
+ return
133
+
134
+ if hasattr(model, "checkpoint_items"):
135
+ checkpoint_items = model.checkpoint_items
136
+ else:
137
+ checkpoint_items = dict(model=model)
138
+ ckpt = tf.train.Checkpoint(**checkpoint_items)
139
+ status = ckpt.read(ckpt_dir_or_file)
140
+ status.expect_partial().assert_existing_objects_matched()
141
+ logging.info("Finished loading pretrained checkpoint from %s",
142
+ ckpt_dir_or_file)
143
+
144
+ def build_model(self) -> tf_keras.Model:
145
+ """[Optional] Creates model architecture.
146
+
147
+ Returns:
148
+ A model instance.
149
+ """ # pytype: disable=bad-return-type # typed-keras
150
+
151
+ @abc.abstractmethod
152
+ def build_inputs(self,
153
+ params,
154
+ input_context: Optional[tf.distribute.InputContext] = None):
155
+ """Returns a dataset or a nested structure of dataset functions.
156
+
157
+ Dataset functions define per-host datasets with the per-replica batch size.
158
+ With distributed training, this method runs on remote hosts.
159
+
160
+ Args:
161
+ params: hyperparams to create input pipelines, which can be any of
162
+ dataclass, ConfigDict, namedtuple, etc.
163
+ input_context: optional distribution input pipeline context.
164
+
165
+ Returns:
166
+ A nested structure of per-replica input functions.
167
+ """
168
+
169
+ def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
170
+ """Standard interface to compute losses.
171
+
172
+ Args:
173
+ labels: optional label tensors.
174
+ model_outputs: a nested structure of output tensors.
175
+ aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
176
+
177
+ Returns:
178
+ The total loss tensor.
179
+ """
180
+ del model_outputs, labels
181
+
182
+ if aux_losses is None:
183
+ losses = [tf.constant(0.0, dtype=tf.float32)]
184
+ else:
185
+ losses = aux_losses
186
+ total_loss = tf.add_n(losses)
187
+ return total_loss
188
+
189
+ def build_metrics(self, training: bool = True):
190
+ """Gets streaming metrics for training/validation."""
191
+ del training
192
+ return []
193
+
194
+ def process_metrics(self, metrics, labels, model_outputs, **kwargs):
195
+ """Process and update metrics.
196
+
197
+ Called when using custom training loop API.
198
+
199
+ Args:
200
+ metrics: a nested structure of metrics objects. The return of function
201
+ self.build_metrics.
202
+ labels: a tensor or a nested structure of tensors.
203
+ model_outputs: a tensor or a nested structure of tensors. For example,
204
+ output of the keras model built by self.build_model.
205
+ **kwargs: other args.
206
+ """
207
+ for metric in metrics:
208
+ metric.update_state(labels, model_outputs)
209
+
210
+ def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
211
+ """Process and update compiled_metrics.
212
+
213
+ call when using compile/fit API.
214
+
215
+ Args:
216
+ compiled_metrics: the compiled metrics (model.compiled_metrics).
217
+ labels: a tensor or a nested structure of tensors.
218
+ model_outputs: a tensor or a nested structure of tensors. For example,
219
+ output of the keras model built by self.build_model.
220
+ """
221
+ compiled_metrics.update_state(labels, model_outputs)
222
+
223
+ def train_step(self,
224
+ inputs,
225
+ model: tf_keras.Model,
226
+ optimizer: tf_keras.optimizers.Optimizer,
227
+ metrics=None):
228
+ """Does forward and backward.
229
+
230
+ With distribution strategies, this method runs on devices.
231
+
232
+ Args:
233
+ inputs: a dictionary of input tensors.
234
+ model: the model, forward pass definition.
235
+ optimizer: the optimizer for this training step.
236
+ metrics: a nested structure of metrics objects.
237
+
238
+ Returns:
239
+ A dictionary of logs.
240
+ """
241
+ if isinstance(inputs, tuple) and len(inputs) == 2:
242
+ features, labels = inputs
243
+ else:
244
+ features, labels = inputs, inputs
245
+ with tf.GradientTape() as tape:
246
+ outputs = model(features, training=True)
247
+ # Computes per-replica loss.
248
+ if model.compiled_loss:
249
+ loss = model.compiled_loss(
250
+ labels, outputs, regularization_losses=model.losses)
251
+ loss += self.build_losses(
252
+ labels=labels, model_outputs=outputs, aux_losses=None)
253
+ else:
254
+ loss = self.build_losses(
255
+ labels=labels, model_outputs=outputs, aux_losses=model.losses)
256
+ # Scales loss as the default gradients allreduce performs sum inside the
257
+ # optimizer.
258
+ scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
259
+
260
+ # For mixed precision, when a LossScaleOptimizer is used, the loss is
261
+ # scaled to avoid numeric underflow.
262
+ if isinstance(optimizer,
263
+ tf_keras.mixed_precision.LossScaleOptimizer):
264
+ scaled_loss = optimizer.get_scaled_loss(scaled_loss)
265
+
266
+ tvars = model.trainable_variables
267
+ grads = tape.gradient(scaled_loss, tvars)
268
+
269
+ if isinstance(optimizer,
270
+ tf_keras.mixed_precision.LossScaleOptimizer):
271
+ grads = optimizer.get_unscaled_gradients(grads)
272
+ optimizer.apply_gradients(list(zip(grads, tvars)))
273
+ logs = {self.loss: loss}
274
+ if metrics:
275
+ self.process_metrics(metrics, labels, outputs)
276
+ if model.compiled_metrics:
277
+ self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
278
+ logs.update({m.name: m.result() for m in metrics or []})
279
+ logs.update({m.name: m.result() for m in model.metrics})
280
+ return logs
281
+
282
+ def validation_step(self, inputs, model: tf_keras.Model, metrics=None):
283
+ """Validation step.
284
+
285
+ With distribution strategies, this method runs on devices.
286
+
287
+ Args:
288
+ inputs: a dictionary of input tensors.
289
+ model: the keras.Model.
290
+ metrics: a nested structure of metrics objects.
291
+
292
+ Returns:
293
+ A dictionary of logs.
294
+ """
295
+ if isinstance(inputs, tuple) and len(inputs) == 2:
296
+ features, labels = inputs
297
+ else:
298
+ features, labels = inputs, inputs
299
+ outputs = self.inference_step(features, model)
300
+ loss = self.build_losses(
301
+ labels=labels, model_outputs=outputs, aux_losses=model.losses)
302
+ logs = {self.loss: loss}
303
+ if metrics:
304
+ self.process_metrics(metrics, labels, outputs)
305
+ if model.compiled_metrics:
306
+ self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
307
+ logs.update({m.name: m.result() for m in metrics or []})
308
+ logs.update({m.name: m.result() for m in model.metrics})
309
+ return logs
310
+
311
+ def inference_step(self, inputs, model: tf_keras.Model):
312
+ """Performs the forward step.
313
+
314
+ With distribution strategies, this method runs on devices.
315
+
316
+ Args:
317
+ inputs: a dictionary of input tensors.
318
+ model: the keras.Model.
319
+
320
+ Returns:
321
+ Model outputs.
322
+ """
323
+ return model(inputs, training=False)
324
+
325
+ def aggregate_logs(self, state, step_logs):
326
+ """Optional aggregation over logs returned from a validation step.
327
+
328
+ Given step_logs from a validation step, this function aggregates the logs
329
+ after each eval_step() (see eval_reduce() function in
330
+ official/core/base_trainer.py). It runs on CPU and can be used to aggregate
331
+ metrics during validation, when there are too many metrics that cannot fit
332
+ into TPU memory. Note that this may increase latency due to data transfer
333
+ between TPU and CPU. Also, the step output from a validation step may be a
334
+ tuple with elements from replicas, and a concatenation of the elements is
335
+ needed in such case.
336
+
337
+ Args:
338
+ state: The current state of training, for example, it can be a sequence of
339
+ metrics.
340
+ step_logs: Logs from a validation step. Can be a dictionary.
341
+ """
342
+ pass
343
+
344
+ def reduce_aggregated_logs(self,
345
+ aggregated_logs,
346
+ global_step: Optional[tf.Tensor] = None):
347
+ """Optional reduce of aggregated logs over validation steps.
348
+
349
+ This function reduces aggregated logs at the end of validation, and can be
350
+ used to compute the final metrics. It runs on CPU and in each eval_end() in
351
+ base trainer (see eval_end() function in official/core/base_trainer.py).
352
+
353
+ Args:
354
+ aggregated_logs: Aggregated logs over multiple validation steps.
355
+ global_step: An optional variable of global step.
356
+
357
+ Returns:
358
+ A dictionary of reduced results.
359
+ """
360
+ return {}
modeling/official/core/base_trainer.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Standard Trainer implementation.
16
+
17
+ The base trainer implements the Orbit `StandardTrainable` and
18
+ `StandardEvaluable` interfaces. Trainers inside this project should be
19
+ interchangable and independent on model architectures and tasks.
20
+ """
21
+ import functools
22
+ from typing import Union, Optional
23
+ from absl import logging
24
+ import gin
25
+ import orbit
26
+ import tensorflow as tf, tf_keras
27
+
28
+ from official.core import base_task
29
+ from official.core import config_definitions
30
+ from official.modeling import optimization
31
+
32
+ ExperimentConfig = config_definitions.ExperimentConfig
33
+ TrainerConfig = config_definitions.TrainerConfig
34
+
35
+
36
+ class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
37
+ """Trainer class for both sync and async Strategy."""
38
+
39
+ def init_async(self):
40
+ """Initializes the Async Trainer base class."""
41
+ assert isinstance(self._strategy, tf.distribute.Strategy)
42
+ self._is_async = isinstance(
43
+ self._strategy, tf.distribute.experimental.ParameterServerStrategy)
44
+ self._coordinator = None
45
+ if self._is_async:
46
+ self._coordinator = (
47
+ tf.distribute.experimental.coordinator.ClusterCoordinator(
48
+ self._strategy))
49
+
50
+ def coordinator_for_async(
51
+ self,
52
+ ) -> tf.distribute.experimental.coordinator.ClusterCoordinator:
53
+ if not self._coordinator:
54
+ raise ValueError(
55
+ "Coordinator uninitialized for async run. Call init_async() first."
56
+ )
57
+ return self._coordinator
58
+
59
+ def join(self):
60
+ """Join all async steps. Only useful in aysnc training."""
61
+ if getattr(self, "_is_async", False):
62
+ self.coordinator_for_async().join()
63
+
64
+ def create_train_loop_fn(self):
65
+ """Creates a eval loop from the given step function and options."""
66
+ train_loop_fn = super().create_train_loop_fn()
67
+ if getattr(self, "_is_async", False):
68
+
69
+ def _async_loop_fn(iterator, num_steps):
70
+ self.coordinator_for_async().schedule(
71
+ train_loop_fn, args=(iterator, num_steps)
72
+ )
73
+
74
+ return _async_loop_fn
75
+ else:
76
+ return train_loop_fn
77
+
78
+ def create_eval_loop_fn(self, has_state: bool):
79
+ """Creates a training loop from the given step function and options."""
80
+ eval_loop_fn = super().create_eval_loop_fn(has_state)
81
+
82
+ if getattr(self, "_is_async", False):
83
+ if has_state:
84
+ raise ValueError(
85
+ "Stateful eval loop is not supported in async training.")
86
+
87
+ def _async_loop_fn(iterator, num_steps, state=None, reduce_fn=None):
88
+ assert state is None
89
+ assert reduce_fn is None
90
+ self.coordinator_for_async().schedule(
91
+ eval_loop_fn, args=(iterator, num_steps)
92
+ )
93
+
94
+ return _async_loop_fn
95
+ else:
96
+ return eval_loop_fn
97
+
98
+ def distribute_dataset(self, dataset_or_fn, *args, **kwargs):
99
+ """A utility function to help create a `tf.distribute.DistributedDataset`.
100
+
101
+ Args:
102
+ dataset_or_fn: A instance of `tf.data.Dataset`, or a "dataset function"
103
+ returning a `tf.data.Dataset`. If it is a function, it may optionally
104
+ have an argument named `input_context` which will be passed a
105
+ `tf.distribute.InputContext` instance.
106
+ *args: Any positional arguments to pass through to `dataset_or_fn`.
107
+ **kwargs: Any keyword arguments to pass through to `dataset_or_fn`.
108
+
109
+ Returns:
110
+ A distributed Dataset.
111
+ """
112
+ if getattr(self, "_is_async", False):
113
+ per_worker_dataset_fn = functools.partial(
114
+ orbit.utils.make_distributed_dataset, self._strategy, dataset_or_fn,
115
+ *args, **kwargs)
116
+ per_worker_dataset_fn = tf.function(per_worker_dataset_fn)
117
+
118
+ return self.coordinator_for_async().create_per_worker_dataset(
119
+ per_worker_dataset_fn
120
+ )
121
+ else:
122
+ return orbit.utils.make_distributed_dataset(self._strategy, dataset_or_fn,
123
+ *args, **kwargs)
124
+
125
+
126
+ def get_runtime_options(config: ExperimentConfig):
127
+ """Get tf.distribute.RunOptions from config."""
128
+ xla_options = {}
129
+ if config.runtime.tpu_enable_xla_dynamic_padder is not None:
130
+ xla_options["enable_xla_dynamic_padder"] = (
131
+ config.runtime.tpu_enable_xla_dynamic_padder)
132
+ return tf.distribute.RunOptions(
133
+ experimental_xla_options=tf.tpu.XLAOptions(**xla_options))
134
+
135
+
136
+ @gin.configurable
137
+ class Trainer(_AsyncTrainer):
138
+ """Implements the common trainer shared for TensorFlow models."""
139
+
140
+ # pylint: disable=super-init-not-called
141
+ def __init__(
142
+ self,
143
+ config: ExperimentConfig,
144
+ task: base_task.Task,
145
+ model: tf_keras.Model,
146
+ optimizer: tf.optimizers.Optimizer,
147
+ train: bool = True,
148
+ evaluate: bool = True,
149
+ train_dataset: Optional[Union[tf.data.Dataset,
150
+ tf.distribute.DistributedDataset]] = None,
151
+ validation_dataset: Optional[Union[
152
+ tf.data.Dataset, tf.distribute.DistributedDataset]] = None,
153
+ checkpoint_exporter=None):
154
+ """Initialize common trainer for TensorFlow models.
155
+
156
+ Args:
157
+ config: An `ExperimentConfig` instance specifying experiment config.
158
+ task: A base_task.Task instance.
159
+ model: The model instance, e.g. a tf_keras.Model instance.
160
+ optimizer: tf.optimizers.Optimizer instance.
161
+ train: bool, whether or not this trainer will be used for training.
162
+ default to True.
163
+ evaluate: bool, whether or not this trainer will be used for evaluation.
164
+ default to True.
165
+ train_dataset: a dataset object created for training. With tf.distribute,
166
+ it needs to be a `DistributedDataset`.
167
+ validation_dataset: a dataset object created for evaluation. With
168
+ tf.distribute, it needs to be a `DistributedDataset`. The evaluator will
169
+ create a dataset iterator for each eval round, so the dataset does not
170
+ need to repeat.
171
+ checkpoint_exporter: an object that has the `maybe_export_checkpoint`
172
+ interface.
173
+ """
174
+ # Gets the current distribution strategy. If not inside any strategy scope,
175
+ # it gets a single-replica no-op strategy.
176
+ self._strategy = tf.distribute.get_strategy()
177
+ self._validate_params(
178
+ config,
179
+ check_train_data=train_dataset is None,
180
+ check_validation_data=validation_dataset is None)
181
+ self._config = config
182
+ self._task = task
183
+ self._model = model
184
+ self._optimizer = optimizer
185
+ self._checkpoint_exporter = checkpoint_exporter
186
+ self._recovery = None
187
+ # Runtime options are only applied to train_step.
188
+ # We use default for eval_step.
189
+ self._runtime_options = get_runtime_options(config)
190
+
191
+ # Creates a shadow copy of the weights to store weights moving average.
192
+ if isinstance(self._optimizer, optimization.ExponentialMovingAverage
193
+ ) and not self._optimizer.has_shadow_copy:
194
+ self._optimizer.shadow_copy(self._model)
195
+
196
+ # global_step increases by 1 after each training iteration.
197
+ # We should have global_step.numpy() == self.optimizer.iterations.numpy()
198
+ # when there is only 1 optimizer.
199
+ self._global_step = orbit.utils.create_global_step()
200
+ if hasattr(self.model, "checkpoint_items"):
201
+ checkpoint_items = self.model.checkpoint_items
202
+ else:
203
+ checkpoint_items = {}
204
+ self._checkpoint = tf.train.Checkpoint(
205
+ global_step=self.global_step,
206
+ model=self.model,
207
+ optimizer=self.optimizer,
208
+ **checkpoint_items)
209
+
210
+ self._train_loss = tf_keras.metrics.Mean("training_loss", dtype=tf.float32)
211
+ self._validation_loss = tf_keras.metrics.Mean(
212
+ "validation_loss", dtype=tf.float32)
213
+ model_metrics = model.metrics if hasattr(model, "metrics") else []
214
+
215
+ self.init_async()
216
+
217
+ if train:
218
+ self._train_metrics = self.task.build_metrics(
219
+ training=True) + model_metrics
220
+ train_dataset = train_dataset or self.distribute_dataset(
221
+ self.task.build_inputs, self.config.task.train_data)
222
+ orbit.StandardTrainer.__init__(
223
+ self,
224
+ train_dataset,
225
+ options=orbit.StandardTrainerOptions(
226
+ use_tf_while_loop=config.trainer.train_tf_while_loop,
227
+ use_tf_function=config.trainer.train_tf_function,
228
+ use_tpu_summary_optimization=config.trainer.allow_tpu_summary))
229
+
230
+ if evaluate:
231
+ self._validation_metrics = self.task.build_metrics(
232
+ training=False) + model_metrics
233
+ validation_dataset = validation_dataset or self.distribute_dataset(
234
+ self.task.build_inputs, self.config.task.validation_data)
235
+ orbit.StandardEvaluator.__init__(
236
+ self,
237
+ validation_dataset,
238
+ options=orbit.StandardEvaluatorOptions(
239
+ use_tf_function=config.trainer.eval_tf_function,
240
+ use_tf_while_loop=config.trainer.eval_tf_while_loop))
241
+
242
+ def _validate_params(self,
243
+ config,
244
+ check_train_data=True,
245
+ check_validation_data=True):
246
+ r"""Validates if the configuration object passed to the Trainer.
247
+
248
+ The experiment configuration should be structured as:
249
+ \trainer
250
+ \task
251
+ \train_data
252
+ \validation_data
253
+
254
+ Args:
255
+ config: a namedtuple, dataclass, ConfigDict, etc.
256
+ check_train_data: whether to check task.train_data field.
257
+ check_validation_data: whether to check task.validation_data field.
258
+ """
259
+ if not hasattr(config, "trainer"):
260
+ raise AttributeError("The trainer requires the configuration contains an"
261
+ " attribute `trainer`.")
262
+
263
+ if not hasattr(config, "task"):
264
+ raise AttributeError("The trainer requires the configuration contains an"
265
+ " attribute `task`.")
266
+
267
+ if check_train_data and not hasattr(config.task, "train_data"):
268
+ raise AttributeError("The trainer requires the configuration contains an"
269
+ " attribute `task.train_data`.")
270
+
271
+ if check_validation_data and not hasattr(config.task, "validation_data"):
272
+ raise AttributeError("The trainer requires the configuration contains an"
273
+ " attribute `task.validation_data`.")
274
+
275
+ @property
276
+ def strategy(self):
277
+ return self._strategy
278
+
279
+ @property
280
+ def config(self):
281
+ return self._config
282
+
283
+ @property
284
+ def task(self):
285
+ return self._task
286
+
287
+ @property
288
+ def model(self):
289
+ return self._model
290
+
291
+ @property
292
+ def optimizer(self):
293
+ if hasattr(self, "_optimizer"):
294
+ return self._optimizer
295
+ else:
296
+ return None
297
+
298
+ @property
299
+ def global_step(self):
300
+ return self._global_step
301
+
302
+ @property
303
+ def train_loss(self):
304
+ """Accesses the training loss metric object."""
305
+ return self._train_loss
306
+
307
+ @property
308
+ def validation_loss(self):
309
+ """Accesses the validation loss metric object."""
310
+ return self._validation_loss
311
+
312
+ @property
313
+ def train_metrics(self):
314
+ """Accesses all training metric objects."""
315
+ return self._train_metrics
316
+
317
+ @property
318
+ def validation_metrics(self):
319
+ """Accesses all validation metric metric objects."""
320
+ return self._validation_metrics
321
+
322
+ def initialize(self):
323
+ """A callback function.
324
+
325
+ This function will be called when no checkpoint found for the model.
326
+ If there is a checkpoint, the checkpoint will be loaded and this function
327
+ will not be called. Tasks may use this callback function to load a
328
+ pretrained checkpoint, saved under a directory other than the model_dir.
329
+ """
330
+ self.task.initialize(self.model)
331
+
332
+ @property
333
+ def checkpoint(self):
334
+ """Accesses the training checkpoint."""
335
+ return self._checkpoint
336
+
337
+ @property
338
+ def checkpoint_exporter(self):
339
+ """Accesses the checkpoint exporter."""
340
+ return self._checkpoint_exporter
341
+
342
+ def train_loop_end(self):
343
+ """See base class."""
344
+ self.join()
345
+ logs = {}
346
+ for metric in self.train_metrics + [self.train_loss]:
347
+ logs[metric.name] = metric.result()
348
+ metric.reset_states()
349
+ if callable(self.optimizer.learning_rate):
350
+ # Maybe a self-implemented optimizer does not have `optimizer.iterations`.
351
+ # So just to be safe here.
352
+ if hasattr(self.optimizer, "iterations"):
353
+ logs["learning_rate"] = self.optimizer.learning_rate(
354
+ self.optimizer.iterations)
355
+ else:
356
+ logs["learning_rate"] = self.optimizer.learning_rate(self.global_step)
357
+ else:
358
+ logs["learning_rate"] = self.optimizer.learning_rate
359
+ return logs
360
+
361
+ def next_train_inputs(self, iterator):
362
+ """Fetches the next inputs for the model during train.
363
+
364
+ This method consumes the input iterator and returns the next inputs for the
365
+ model.
366
+
367
+ This method provides a way to control how to fetch the next model input, and
368
+ what data to send to the model.
369
+
370
+ Note: This function runs on the host side when accelerators are used.
371
+
372
+ Note: Depending on the training setup this may or may not run in eager mode.
373
+ In most cases it will be run in graph mode.
374
+
375
+ Args:
376
+ iterator: Dataset iterator to generate the next inputs from.
377
+
378
+ Returns:
379
+ The inputs to the model.
380
+ """
381
+ return next(iterator)
382
+
383
+ def train_step(self, iterator):
384
+ """See base class."""
385
+
386
+ def step_fn(inputs):
387
+ if self.config.runtime.enable_xla and (self.config.runtime.num_gpus > 0):
388
+ task_train_step = tf.function(self.task.train_step, jit_compile=True)
389
+ else:
390
+ task_train_step = self.task.train_step
391
+ logs = task_train_step(
392
+ inputs,
393
+ model=self.model,
394
+ optimizer=self.optimizer,
395
+ metrics=self.train_metrics)
396
+ self._train_loss.update_state(logs[self.task.loss])
397
+ self.global_step.assign_add(1)
398
+
399
+ inputs = self.next_train_inputs(iterator)
400
+ self.strategy.run(step_fn, args=(inputs,), options=self._runtime_options)
401
+
402
+ def eval_begin(self):
403
+ """Sets up metrics."""
404
+ for metric in self.validation_metrics + [self.validation_loss]:
405
+ metric.reset_states()
406
+ # Swaps weights to test on weights moving average.
407
+ if self.optimizer and isinstance(self.optimizer,
408
+ optimization.ExponentialMovingAverage):
409
+ self.optimizer.swap_weights()
410
+
411
+ def next_eval_inputs(self, iterator):
412
+ """Fetches the next inputs for the model during eval.
413
+
414
+ This method consumes the input iterator and returns the next inputs for the
415
+ model and an additional logs dict. The output dict remains in the host (not
416
+ sent to GPUs/TPUs) and is merged with the model outputs which will be
417
+ processed later in `aggregate_logs`. This is useful for sending extra logs
418
+ downstream that are not compatible with the accelerators.
419
+
420
+ Note: This function runs on the host side when accelerators are used.
421
+
422
+ Note: Depending on the training setup this may or may not run in eager mode.
423
+ In most cases it will be run in graph mode.
424
+
425
+ Args:
426
+ iterator: Dataset iterator to generate the next inputs from.
427
+
428
+ Returns:
429
+ The inputs to the model, and an additional logs dictionnary. The logs
430
+ are not passed to the model, instead they are merged with model output
431
+ logs.
432
+ """
433
+ passthrough_logs = dict()
434
+ return next(iterator), passthrough_logs
435
+
436
+ def eval_step(self, iterator):
437
+ """See base class."""
438
+
439
+ def step_fn(inputs):
440
+ logs = self.task.validation_step(
441
+ inputs, model=self.model, metrics=self.validation_metrics)
442
+ if self.task.loss in logs:
443
+ self._validation_loss.update_state(logs[self.task.loss])
444
+ return logs
445
+
446
+ inputs, passthrough_logs = self.next_eval_inputs(iterator)
447
+ distributed_outputs = self.strategy.run(step_fn, args=(inputs,))
448
+ logs = tf.nest.map_structure(
449
+ self.strategy.experimental_local_results, distributed_outputs
450
+ )
451
+
452
+ if set(logs.keys()) & set(passthrough_logs.keys()):
453
+ logging.warning(
454
+ (
455
+ "Conflict between the pasthrough log keys and the returned model"
456
+ " log keys. Found %r keys in the passthrough logs and %r keys in"
457
+ " the model logs. Model log keys takes precedence."
458
+ ),
459
+ logs.keys(),
460
+ passthrough_logs.keys(),
461
+ )
462
+
463
+ return passthrough_logs | logs
464
+
465
+ def eval_end(self, aggregated_logs=None):
466
+ """Processes evaluation results."""
467
+ self.join()
468
+ logs = {}
469
+ for metric in self.validation_metrics:
470
+ logs[metric.name] = metric.result()
471
+ if self.validation_loss.count.numpy() != 0:
472
+ logs[self.validation_loss.name] = self.validation_loss.result()
473
+ else:
474
+ # `self.validation_loss` metric was not updated, because the validation
475
+ # loss was not returned from the task's `validation_step` method.
476
+ logging.info("The task did not report validation loss.")
477
+ if aggregated_logs:
478
+ metrics = self.task.reduce_aggregated_logs(
479
+ aggregated_logs, global_step=self.global_step)
480
+ logs.update(metrics)
481
+
482
+ if self._checkpoint_exporter:
483
+ self._checkpoint_exporter.maybe_export_checkpoint(
484
+ self.checkpoint, logs, self.global_step.numpy())
485
+ metric_name = self.config.trainer.best_checkpoint_eval_metric
486
+ logs["best_" +
487
+ metric_name] = self._checkpoint_exporter.best_ckpt_logs[metric_name]
488
+
489
+ # Swaps back weights after testing when EMA is used.
490
+ # This happens after best checkpoint export so that average weights used for
491
+ # eval are exported instead of regular weights.
492
+ if self.optimizer and isinstance(self.optimizer,
493
+ optimization.ExponentialMovingAverage):
494
+ self.optimizer.swap_weights()
495
+ return logs
496
+
497
+ def eval_reduce(self, state=None, step_outputs=None):
498
+ return self.task.aggregate_logs(state, step_outputs)
modeling/official/core/base_trainer_test.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for tensorflow_models.core.trainers.trainer."""
16
+ # pylint: disable=g-direct-tensorflow-import
17
+ import gc
18
+ import multiprocessing
19
+ import os
20
+ import sys
21
+
22
+ from absl.testing import parameterized
23
+ import orbit
24
+ import portpicker
25
+ import tensorflow as tf, tf_keras
26
+
27
+ from tensorflow.python.distribute import combinations
28
+ from tensorflow.python.distribute import strategy_combinations
29
+ from official.core import base_trainer as trainer_lib
30
+ from official.core import config_definitions as cfg
31
+ from official.core import train_lib
32
+ from official.utils.testing import mock_task
33
+
34
+ TPU_TEST = 'test_tpu' in sys.argv[0]
35
+ GPU_TEST = 'test_gpu' in sys.argv[0]
36
+
37
+
38
+ def all_strategy_combinations():
39
+ return combinations.combine(
40
+ distribution=[
41
+ strategy_combinations.default_strategy,
42
+ strategy_combinations.cloud_tpu_strategy,
43
+ strategy_combinations.one_device_strategy_gpu,
44
+ ],)
45
+
46
+
47
+ def create_in_process_cluster(num_workers, num_ps):
48
+ """Creates and starts local servers and returns the cluster_resolver."""
49
+ worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
50
+ ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
51
+
52
+ cluster_dict = {}
53
+ cluster_dict['worker'] = ['localhost:%s' % port for port in worker_ports]
54
+ if num_ps > 0:
55
+ cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
56
+
57
+ cluster_spec = tf.train.ClusterSpec(cluster_dict)
58
+
59
+ # Workers need some inter_ops threads to work properly.
60
+ worker_config = tf.compat.v1.ConfigProto()
61
+ if multiprocessing.cpu_count() < num_workers + 1:
62
+ worker_config.inter_op_parallelism_threads = num_workers + 1
63
+
64
+ for i in range(num_workers):
65
+ tf.distribute.Server(
66
+ cluster_spec,
67
+ job_name='worker',
68
+ task_index=i,
69
+ config=worker_config,
70
+ protocol='grpc')
71
+
72
+ for i in range(num_ps):
73
+ tf.distribute.Server(
74
+ cluster_spec, job_name='ps', task_index=i, protocol='grpc')
75
+
76
+ cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
77
+ cluster_spec, rpc_layer='grpc')
78
+ return cluster_resolver
79
+
80
+
81
+ def dataset_fn(input_context=None):
82
+ del input_context
83
+
84
+ def dummy_data(_):
85
+ return tf.zeros((1, 1), dtype=tf.float32)
86
+
87
+ dataset = tf.data.Dataset.range(1)
88
+ dataset = dataset.repeat()
89
+ dataset = dataset.map(
90
+ dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
91
+ return dataset
92
+
93
+
94
+ class MockAsyncTrainer(trainer_lib._AsyncTrainer):
95
+ """Mock AsyncTrainer to test the _AsyncTrainer class."""
96
+
97
+ def __init__(self):
98
+ self._strategy = tf.distribute.get_strategy()
99
+ self.init_async()
100
+
101
+ self.global_step = tf.Variable(
102
+ 0,
103
+ dtype=tf.int64,
104
+ name='global_step',
105
+ trainable=False,
106
+ aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
107
+ self.eval_global_step = tf.Variable(
108
+ 0,
109
+ dtype=tf.int64,
110
+ name='eval_global_step',
111
+ trainable=False,
112
+ aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
113
+
114
+ train_dataset = self.distribute_dataset(dataset_fn)
115
+ orbit.StandardTrainer.__init__(
116
+ self, train_dataset, options=orbit.StandardTrainerOptions())
117
+
118
+ validation_dataset = self.distribute_dataset(dataset_fn)
119
+ orbit.StandardEvaluator.__init__(
120
+ self,
121
+ validation_dataset,
122
+ options=orbit.StandardEvaluatorOptions(use_tf_while_loop=True))
123
+
124
+ def train_loop_begin(self):
125
+ self.global_step.assign(0)
126
+
127
+ def train_step(self, iterator):
128
+
129
+ def replica_step(_):
130
+ self.global_step.assign_add(1)
131
+
132
+ self._strategy.run(replica_step, args=(next(iterator),))
133
+
134
+ def train_loop_end(self):
135
+ self.join()
136
+ return self.global_step.numpy()
137
+
138
+ def eval_begin(self):
139
+ self.eval_global_step.assign(0)
140
+
141
+ def eval_step(self, iterator):
142
+
143
+ def replica_step(_):
144
+ self.eval_global_step.assign_add(1)
145
+
146
+ self._strategy.run(replica_step, args=(next(iterator),))
147
+
148
+ def eval_end(self):
149
+ self.join()
150
+ return self.eval_global_step.numpy()
151
+
152
+
153
+ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
154
+
155
+ def setUp(self):
156
+ super().setUp()
157
+ self._config = cfg.ExperimentConfig(
158
+ trainer=cfg.TrainerConfig(
159
+ optimizer_config=cfg.OptimizationConfig({
160
+ 'optimizer': {
161
+ 'type': 'sgd'
162
+ },
163
+ 'learning_rate': {
164
+ 'type': 'constant'
165
+ }
166
+ })))
167
+
168
+ def tearDown(self):
169
+ gc.collect()
170
+ # This will only contain uncollectable garbage, i.e. reference cycles
171
+ # involving objects with __del__ defined.
172
+ self.assertEmpty(gc.garbage)
173
+ super().tearDown()
174
+
175
+ def create_test_trainer(self, config, model_dir=None, task=None):
176
+ task = task or mock_task.MockTask(config.task, logging_dir=model_dir)
177
+ ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
178
+ trainer = trainer_lib.Trainer(
179
+ config,
180
+ task,
181
+ model=task.build_model(),
182
+ optimizer=task.create_optimizer(config.trainer.optimizer_config,
183
+ config.runtime),
184
+ checkpoint_exporter=ckpt_exporter)
185
+ return trainer
186
+
187
+ @combinations.generate(all_strategy_combinations())
188
+ def test_trainer_train(self, distribution):
189
+ with distribution.scope():
190
+ trainer = self.create_test_trainer(self._config)
191
+ logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
192
+ self.assertIn('training_loss', logs)
193
+ self.assertIn('learning_rate', logs)
194
+
195
+ @combinations.generate(all_strategy_combinations())
196
+ def test_trainer_passing_datasets(self, distribution):
197
+ with distribution.scope():
198
+ task = mock_task.MockTask(self._config)
199
+ train_dataset = orbit.utils.make_distributed_dataset(
200
+ distribution, task.build_inputs, self._config.task.train_data)
201
+ validation_dataset = orbit.utils.make_distributed_dataset(
202
+ distribution, task.build_inputs, self._config.task.validation_data)
203
+ self._config.task.train_data = None
204
+ self._config.task.validation_data = None
205
+ trainer = trainer_lib.Trainer(
206
+ self._config,
207
+ task,
208
+ model=task.build_model(),
209
+ optimizer=task.create_optimizer(self._config.trainer.optimizer_config,
210
+ self._config.runtime),
211
+ train_dataset=train_dataset,
212
+ validation_dataset=validation_dataset)
213
+ logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
214
+ self.assertIn('training_loss', logs)
215
+ self.assertIn('learning_rate', logs)
216
+ logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
217
+ self.assertIn('validation_loss', logs)
218
+
219
+ def test_base_async_trainer(self):
220
+ if TPU_TEST or GPU_TEST:
221
+ self.skipTest('Aysnc training is not available on GPU/GPU.')
222
+ num_workers = 3
223
+ num_ps = 2
224
+ cluster_resolver = create_in_process_cluster(num_workers, num_ps)
225
+ distribution = tf.distribute.experimental.ParameterServerStrategy(
226
+ cluster_resolver)
227
+ with distribution.scope():
228
+ trainer = MockAsyncTrainer()
229
+ trainer.init_async()
230
+ self.assertIsInstance(
231
+ trainer._coordinator,
232
+ tf.distribute.experimental.coordinator.ClusterCoordinator)
233
+ self.assertEqual(trainer.train(tf.constant(10)), 10)
234
+ self.assertEqual(trainer.evaluate(tf.constant(11)), 11)
235
+
236
+ def test_async_trainer_train(self):
237
+ if TPU_TEST or GPU_TEST:
238
+ self.skipTest('Aysnc training is not available on GPU/TPU.')
239
+ num_workers = 3
240
+ num_ps = 2
241
+ cluster_resolver = create_in_process_cluster(num_workers, num_ps)
242
+ distribution = tf.distribute.experimental.ParameterServerStrategy(
243
+ cluster_resolver)
244
+ with distribution.scope():
245
+ config = cfg.ExperimentConfig(**self._config.as_dict())
246
+ config.trainer.eval_tf_while_loop = True
247
+ trainer = self.create_test_trainer(config)
248
+ logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
249
+ self.assertIn('training_loss', logs)
250
+ self.assertIn('learning_rate', logs)
251
+
252
+ def test_async_trainer_validate(self):
253
+ if TPU_TEST or GPU_TEST:
254
+ self.skipTest('Aysnc training is not available on GPU/GPU.')
255
+ num_workers = 3
256
+ num_ps = 2
257
+ cluster_resolver = create_in_process_cluster(num_workers, num_ps)
258
+ distribution = tf.distribute.experimental.ParameterServerStrategy(
259
+ cluster_resolver)
260
+ with distribution.scope():
261
+ config = cfg.ExperimentConfig(**self._config.as_dict())
262
+ config.trainer.eval_tf_while_loop = True
263
+ trainer = self.create_test_trainer(config)
264
+ logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
265
+ self.assertIn('acc', logs)
266
+ self.assertIn('validation_loss', logs)
267
+
268
+ @combinations.generate(all_strategy_combinations())
269
+ def test_trainer_validate(self, distribution):
270
+ with distribution.scope():
271
+ trainer = self.create_test_trainer(self._config)
272
+ logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
273
+ self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
274
+ self.assertIn('validation_loss', logs)
275
+
276
+ @combinations.generate(all_strategy_combinations())
277
+ def test_trainer_validate_without_loss(self, distribution):
278
+
279
+ class MockTaskWithoutValidationLoss(mock_task.MockTask):
280
+
281
+ def validation_step(self, inputs, model, metrics=None):
282
+ # Disable validation loss.
283
+ logs = super().validation_step(inputs, model)
284
+ del logs[self.loss]
285
+ return logs
286
+
287
+ with distribution.scope():
288
+ task = MockTaskWithoutValidationLoss()
289
+ trainer = self.create_test_trainer(self._config, task=task)
290
+ logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
291
+ self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
292
+ self.assertNotIn('validation_loss', logs)
293
+
294
+ @combinations.generate(
295
+ combinations.combine(
296
+ mixed_precision_dtype=['float32', 'bfloat16', 'float16'],
297
+ loss_scale=[None, 'dynamic', 128, 256],
298
+ ))
299
+ def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
300
+ config = cfg.ExperimentConfig(
301
+ runtime=cfg.RuntimeConfig(
302
+ mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
303
+ trainer=cfg.TrainerConfig(
304
+ optimizer_config=cfg.OptimizationConfig({
305
+ 'optimizer': {
306
+ 'type': 'sgd'
307
+ },
308
+ 'learning_rate': {
309
+ 'type': 'constant'
310
+ },
311
+ })))
312
+ trainer = self.create_test_trainer(config)
313
+ if mixed_precision_dtype == 'float16':
314
+ self.assertIsInstance(trainer.optimizer,
315
+ tf_keras.mixed_precision.LossScaleOptimizer)
316
+ if loss_scale in (None, 'dynamic'):
317
+ self.assertTrue(trainer.optimizer.dynamic)
318
+ else:
319
+ self.assertFalse(trainer.optimizer.dynamic)
320
+ self.assertEqual(trainer.optimizer.initial_scale, loss_scale)
321
+ else:
322
+ self.assertIsInstance(
323
+ trainer.optimizer,
324
+ (tf_keras.optimizers.SGD, tf_keras.optimizers.legacy.SGD))
325
+
326
+ metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
327
+ self.assertIn('training_loss', metrics)
328
+
329
+ def test_export_best_ckpt(self):
330
+ config = cfg.ExperimentConfig(
331
+ trainer=cfg.TrainerConfig(
332
+ best_checkpoint_export_subdir='best_ckpt',
333
+ best_checkpoint_eval_metric='acc',
334
+ optimizer_config=cfg.OptimizationConfig({
335
+ 'optimizer': {
336
+ 'type': 'sgd'
337
+ },
338
+ 'learning_rate': {
339
+ 'type': 'constant'
340
+ }
341
+ })))
342
+ model_dir = self.get_temp_dir()
343
+ trainer = self.create_test_trainer(config, model_dir=model_dir)
344
+ trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
345
+ trainer.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
346
+ self.assertTrue(
347
+ tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))
348
+
349
+ def test_model_with_compiled_loss(self):
350
+ task = mock_task.MockTask()
351
+ model = task.build_model()
352
+ model.compile(loss=tf_keras.losses.CategoricalCrossentropy())
353
+ trainer = trainer_lib.Trainer(
354
+ self._config,
355
+ task,
356
+ model=model,
357
+ optimizer=task.create_optimizer(self._config.trainer.optimizer_config))
358
+ logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
359
+ self.assertIn('training_loss', logs)
360
+
361
+
362
+ if __name__ == '__main__':
363
+ tf.test.main()
modeling/official/core/config_definitions.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Common configuration settings."""
16
+
17
+ import dataclasses
18
+ from typing import Optional, Sequence, Union
19
+
20
+ from official.modeling.hyperparams import base_config
21
+ from official.modeling.optimization.configs import optimization_config
22
+ from official.modeling.privacy import configs as dp_configs
23
+
24
+ OptimizationConfig = optimization_config.OptimizationConfig
25
+
26
+
27
+ @dataclasses.dataclass
28
+ class DataConfig(base_config.Config):
29
+ """The base configuration for building datasets.
30
+
31
+ Attributes:
32
+ input_path: The path to the input. It can be either (1) a str indicating a
33
+ file path/pattern, or (2) a str indicating multiple file paths/patterns
34
+ separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or (3) a list of
35
+ str, each of which is a file path/pattern or multiple file paths/patterns
36
+ separated by comma, or (4) a dictionary of the previous three approaches
37
+ for more advanced data mixing using named access. It should not be
38
+ specified when the following `tfds_name` is specified.
39
+ tfds_name: The name of the tensorflow dataset (TFDS). It should not be
40
+ specified when the above `input_path` is specified.
41
+ tfds_split: A str indicating which split of the data to load from TFDS. It
42
+ is required when above `tfds_name` is specified.
43
+ global_batch_size: The global batch size across all replicas.
44
+ is_training: Whether this data is used for training or not. This flag is
45
+ useful for consumers of this object to determine whether the data should
46
+ be repeated or shuffled.
47
+ drop_remainder: Whether the last batch should be dropped in the case it has
48
+ fewer than `global_batch_size` elements.
49
+ shuffle_buffer_size: The buffer size used for shuffling training data.
50
+ cache: Whether to cache dataset examples. If `True`, we will cache the
51
+ dataset after applying the decode_fn and parse_fn. It can be used to avoid
52
+ re-reading from disk, re-decoding and re-parsing the example on the second
53
+ epoch, but it requires significant memory overhead.
54
+ cycle_length: The number of files that will be processed concurrently when
55
+ interleaving files.
56
+ block_length: The number of consecutive elements to produce from each input
57
+ element before cycling to another input element when interleaving files.
58
+ deterministic: A boolean controlling whether determinism should be enforced.
59
+ sharding: Whether sharding is used in the input pipeline.
60
+ enable_tf_data_service: A boolean indicating whether to enable tf.data
61
+ service for the input pipeline.
62
+ tf_data_service_address: The URI of a tf.data service to offload
63
+ preprocessing onto during training. The URI should be in the format
64
+ "protocol://address", e.g. "grpc://tf-data-service:5050". It can be
65
+ overridden by `FLAGS.tf_data_service` flag in the binary.
66
+ tf_data_service_job_name: The name of the tf.data service job. This argument
67
+ makes it possible for multiple datasets to share the same job. The default
68
+ behavior is that the dataset creates anonymous, exclusively owned jobs.
69
+ tfds_data_dir: A str specifying the directory to read/write TFDS data.
70
+ tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
71
+ returned tf.data.Dataset will have a 2-tuple structure (input, label)
72
+ according to builder.info.supervised_keys; if False, the default, the
73
+ returned tf.data.Dataset will have a dictionary with all the features.
74
+ tfds_skip_decoding_feature: A str to indicate which features are skipped for
75
+ decoding when loading dataset from TFDS. Use comma to separate multiple
76
+ features. The main use case is to skip the image/video decoding for better
77
+ performance.
78
+ enable_shared_tf_data_service_between_parallel_trainers: A bool. When set to
79
+ true, only a single tf.data service will be started, and it will be shared
80
+ between all the trainer run simultaneously, e.g. using vizier to tune
81
+ hyperparameters. This will save CPU and RAM resources compared to running
82
+ separate tf.data service for each trainer. Notice that if batch size is
83
+ different for different trainers, the field
84
+ apply_tf_data_service_before_batching also needs to be true so that only a
85
+ single tf.data service instance will be created. In this case, tf.data
86
+ service will be applied before batching operation. So make sure to not
87
+ apply any processing steps after batching (e.g. in postprocess_fn) since
88
+ they wouldn't be paralleled by tf.data service and may slow down your
89
+ tf.data pipeline. When using shared tf.data service, the tf.data dataset
90
+ must be infinite, and slow trainer may skip certain training examples.
91
+ More details about shared tf.data service can be found at:
92
+ https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers.
93
+ apply_tf_data_service_before_batching: A bool. If set to True, tf.data
94
+ service will be applied before batching operation. This is useful to make
95
+ sure only a single tf.data service instance is created when
96
+ enable_shared_tf_data_service_between_parallel_trainers is true and batch
97
+ size is changing between parallel trainers.
98
+ trainer_id: A string. The id of the trainer if there are multiple parallel
99
+ trainer running at the same time, e.g. in vizier tuning case. It will be
100
+ automatically set if this field is needed. Users does not need to set it
101
+ when creating experiment configs.
102
+ seed: An optional seed to use for deterministic shuffling/preprocessing.
103
+ prefetch_buffer_size: An int specifying the buffer size of prefetch
104
+ datasets. If None, the buffer size is autotuned. Specifying this is useful
105
+ in case autotuning uses up too much memory by making the buffer size too
106
+ high.
107
+ autotune_algorithm: If specified, use this algorithm for AUTOTUNE. See:
108
+ https://www.tensorflow.org/api_docs/python/tf/data/experimental/AutotuneAlgorithm
109
+ """
110
+ input_path: Union[Sequence[str], str, base_config.Config] = ""
111
+ tfds_name: Union[str, base_config.Config] = ""
112
+ tfds_split: str = ""
113
+ global_batch_size: int = 0
114
+ is_training: Optional[bool] = None
115
+ drop_remainder: bool = True
116
+ shuffle_buffer_size: int = 100
117
+ cache: bool = False
118
+ cycle_length: Optional[int] = None
119
+ block_length: int = 1
120
+ deterministic: Optional[bool] = None
121
+ sharding: bool = True
122
+ enable_tf_data_service: bool = False
123
+ tf_data_service_address: Optional[str] = None
124
+ tf_data_service_job_name: Optional[str] = None
125
+ tfds_data_dir: str = ""
126
+ tfds_as_supervised: bool = False
127
+ tfds_skip_decoding_feature: str = ""
128
+ enable_shared_tf_data_service_between_parallel_trainers: bool = False
129
+ apply_tf_data_service_before_batching: bool = False
130
+ trainer_id: Optional[str] = None
131
+ seed: Optional[int] = None
132
+ prefetch_buffer_size: Optional[int] = None
133
+ autotune_algorithm: Optional[str] = None
134
+
135
+
136
+ @dataclasses.dataclass
137
+ class RuntimeConfig(base_config.Config):
138
+ """High-level configurations for Runtime.
139
+
140
+ These include parameters that are not directly related to the experiment,
141
+ e.g. directories, accelerator type, etc.
142
+
143
+ Attributes:
144
+ distribution_strategy: e.g. 'mirrored', 'tpu', etc.
145
+ enable_xla: Whether or not to enable XLA.
146
+ per_gpu_thread_count: thread count per GPU.
147
+ gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
148
+ dataset_num_private_threads: Number of threads for a private threadpool
149
+ created for all datasets computation.
150
+ tpu: The address of the TPU to use, if any.
151
+ num_gpus: The number of GPUs to use, if any.
152
+ worker_hosts: comma-separated list of worker ip:port pairs for running
153
+ multi-worker models with DistributionStrategy.
154
+ task_index: If multi-worker training, the task index of this worker.
155
+ all_reduce_alg: Defines the algorithm for performing all-reduce.
156
+ num_packs: Sets `num_packs` in the cross device ops used in
157
+ MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
158
+ mixed_precision_dtype: dtype of mixed precision policy. It can be 'float32',
159
+ 'float16', or 'bfloat16'.
160
+ loss_scale: The type of loss scale, or 'float' value. This is used when
161
+ setting the mixed precision policy.
162
+ run_eagerly: Whether or not to run the experiment eagerly.
163
+ batchnorm_spatial_persistent: Whether or not to enable the spatial
164
+ persistent mode for CuDNN batch norm kernel for improved GPU performance.
165
+ """
166
+ distribution_strategy: str = "mirrored"
167
+ enable_xla: bool = False
168
+ gpu_thread_mode: Optional[str] = None
169
+ dataset_num_private_threads: Optional[int] = None
170
+ per_gpu_thread_count: int = 0
171
+ tpu: Optional[str] = None
172
+ num_gpus: int = 0
173
+ worker_hosts: Optional[str] = None
174
+ task_index: int = -1
175
+ all_reduce_alg: Optional[str] = None
176
+ num_packs: int = 1
177
+ mixed_precision_dtype: Optional[str] = None
178
+ loss_scale: Optional[Union[str, float]] = None
179
+ run_eagerly: bool = False
180
+ batchnorm_spatial_persistent: bool = False
181
+
182
+ # XLA runtime params.
183
+ # XLA params are only applied to the train_step.
184
+ # These augments can improve training speed. They can also improve eval, but
185
+ # may reduce usability and users would need to make changes to code.
186
+
187
+ # Whether to enable XLA dynamic padder
188
+ # infrastructure to handle dynamic shapes inputs inside XLA. True by
189
+ # default. Disabling this may cause correctness issues with dynamic shapes
190
+ # inputs, as XLA will just assume the inputs are with padded shapes. However
191
+ # users can optionally set it to False to improve device time if masking is
192
+ # already handled in the user side.
193
+ # If None, will respect XLA default.
194
+ tpu_enable_xla_dynamic_padder: Optional[bool] = None
195
+
196
+ # Global model parallelism configurations.
197
+ num_cores_per_replica: int = 1
198
+ default_shard_dim: int = -1
199
+ use_tpu_mp_strategy: bool = False
200
+
201
+ def model_parallelism(self):
202
+ return dict(
203
+ num_cores_per_replica=self.num_cores_per_replica,
204
+ default_shard_dim=self.default_shard_dim)
205
+
206
+
207
+ @dataclasses.dataclass
208
+ class TrainerConfig(base_config.Config):
209
+ """Configuration for trainer.
210
+
211
+ Attributes:
212
+ optimizer_config: optimizer config, it includes optimizer, learning rate,
213
+ and warmup schedule configs.
214
+ train_tf_while_loop: whether or not to use tf while loop.
215
+ train_tf_function: whether or not to use tf_function for training loop.
216
+ eval_tf_function: whether or not to use tf_function for eval.
217
+ eval_tf_while_loop: whether or not to use tf while loop for eval.
218
+ allow_tpu_summary: Whether to allow summary happen inside the XLA program
219
+ runs on TPU through automatic outside compilation.
220
+ steps_per_loop: number of steps per loop to report training metrics. This
221
+ can also be used to reduce host worker communication in a TPU setup.
222
+ summary_interval: number of steps between each summary.
223
+ checkpoint_interval: number of steps between checkpoints.
224
+ max_to_keep: max checkpoints to keep.
225
+ continuous_eval_timeout: maximum number of seconds to wait between
226
+ checkpoints, if set to None, continuous eval will wait indefinitely. This
227
+ is only used continuous_train_and_eval and continuous_eval modes. Default
228
+ value is 1 hrs.
229
+ train_steps: number of train steps.
230
+ validation_steps: number of eval steps. If -1, the entire eval dataset is
231
+ used.
232
+ validation_interval: number of training steps to run between evaluations.
233
+ best_checkpoint_export_subdir: if set, the trainer will keep track of the
234
+ best evaluation metric, and export the corresponding best checkpoint under
235
+ `model_dir/best_checkpoint_export_subdir`. Note that this only works if
236
+ mode contains eval (such as `train_and_eval`, `continuous_eval`, and
237
+ `continuous_train_and_eval`).
238
+ best_checkpoint_eval_metric: for exporting the best checkpoint, which
239
+ evaluation metric the trainer should monitor. This can be any evaluation
240
+ metric appears on tensorboard.
241
+ best_checkpoint_metric_comp: for exporting the best checkpoint, how the
242
+ trainer should compare the evaluation metrics. This can be either `higher`
243
+ (higher the better) or `lower` (lower the better).
244
+ validation_summary_subdir: A 'str', sub directory for saving eval summary.
245
+ preemption_on_demand_checkpoint: whether or not to save on-demand
246
+ checkpoints after a preemption.
247
+ """
248
+ optimizer_config: OptimizationConfig = dataclasses.field(
249
+ default_factory=OptimizationConfig
250
+ )
251
+ # Orbit settings.
252
+ train_tf_while_loop: bool = True
253
+ train_tf_function: bool = True
254
+ eval_tf_function: bool = True
255
+ eval_tf_while_loop: bool = False
256
+ allow_tpu_summary: bool = False
257
+ # Trainer intervals.
258
+ steps_per_loop: int = 1000
259
+ summary_interval: int = 1000
260
+ checkpoint_interval: int = 1000
261
+ # Checkpoint manager.
262
+ max_to_keep: int = 5
263
+ continuous_eval_timeout: int = 60 * 60
264
+ # Train/Eval routines.
265
+ train_steps: int = 0
266
+ # Sets validation steps to be -1 to evaluate the entire dataset.
267
+ validation_steps: int = -1
268
+ validation_interval: int = 1000
269
+ # Best checkpoint export.
270
+ best_checkpoint_export_subdir: str = ""
271
+ best_checkpoint_eval_metric: str = ""
272
+ best_checkpoint_metric_comp: str = "higher"
273
+ # Blowup recovery.
274
+ loss_upper_bound: float = 1e6
275
+ recovery_begin_steps: int = 0 # Enforcing the loss bound after these steps.
276
+ # When max trials < 0, no recovery module; max trials = 0, we will check
277
+ # the condition and fail the job if the condition happens; max trials > 0,
278
+ # we will retore the model states.
279
+ recovery_max_trials: int = 0
280
+ validation_summary_subdir: str = "validation"
281
+ # Preemption on-demand checkpoint.
282
+ preemption_on_demand_checkpoint: bool = True # copybara-replace
283
+
284
+
285
+ @dataclasses.dataclass
286
+ class TaskConfig(base_config.Config):
287
+ """Config passed to task."""
288
+ init_checkpoint: str = ""
289
+ model: Optional[base_config.Config] = None
290
+ train_data: DataConfig = dataclasses.field(default_factory=DataConfig)
291
+ validation_data: DataConfig = dataclasses.field(default_factory=DataConfig)
292
+ name: Optional[str] = None
293
+ # Configs for differential privacy
294
+ # These configs are only effective if you use create_optimizer in
295
+ # tensorflow_models/official/core/base_task.py
296
+ # DEPRECATED b/264611883
297
+ differential_privacy_config: Optional[
298
+ dp_configs.DifferentialPrivacyConfig] = None
299
+ # Whether to show image summary. Useful to visualize model predictions. Only
300
+ # work for vision tasks.
301
+ allow_image_summary: bool = False
302
+
303
+
304
+ @dataclasses.dataclass
305
+ class ExperimentConfig(base_config.Config):
306
+ """Top-level configuration."""
307
+ task: TaskConfig = dataclasses.field(default_factory=TaskConfig)
308
+ trainer: TrainerConfig = dataclasses.field(default_factory=TrainerConfig)
309
+ runtime: RuntimeConfig = dataclasses.field(default_factory=RuntimeConfig)
modeling/official/core/exp_factory.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Experiment factory methods."""
16
+
17
+ from official.core import config_definitions as cfg
18
+ from official.core import registry
19
+
20
+
21
+ _REGISTERED_CONFIGS = {}
22
+
23
+
24
+ def register_config_factory(name):
25
+ """Register ExperimentConfig factory method."""
26
+ return registry.register(_REGISTERED_CONFIGS, name)
27
+
28
+
29
+ def get_exp_config(exp_name: str) -> cfg.ExperimentConfig:
30
+ """Looks up the `ExperimentConfig` according to the `exp_name`."""
31
+ exp_creater = registry.lookup(_REGISTERED_CONFIGS, exp_name)
32
+ return exp_creater()
modeling/official/core/export_base.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Base class for model export."""
16
+
17
+ import abc
18
+ import functools
19
+ import time
20
+ from typing import Any, Callable, Dict, Mapping, List, Optional, Text, Union
21
+
22
+ from absl import logging
23
+ import tensorflow as tf, tf_keras
24
+
25
+ MAX_DIRECTORY_CREATION_ATTEMPTS = 10
26
+
27
+
28
+ class ExportModule(tf.Module, metaclass=abc.ABCMeta):
29
+ """Base Export Module."""
30
+
31
+ def __init__(self,
32
+ params,
33
+ model: Union[tf.Module, tf_keras.Model],
34
+ inference_step: Optional[Callable[..., Any]] = None,
35
+ *,
36
+ preprocessor: Optional[Callable[..., Any]] = None,
37
+ postprocessor: Optional[Callable[..., Any]] = None):
38
+ """Instantiates an ExportModel.
39
+
40
+ Examples:
41
+
42
+ `inference_step` must be a function that has `model` as an kwarg or the
43
+ second positional argument.
44
+ ```
45
+ def _inference_step(inputs, model=None):
46
+ return model(inputs, training=False)
47
+
48
+ module = ExportModule(params, model, inference_step=_inference_step)
49
+ ```
50
+
51
+ `preprocessor` and `postprocessor` could be either functions or `tf.Module`.
52
+ The usages of preprocessor and postprocessor are managed by the
53
+ implementation of `serve()` method.
54
+
55
+ Args:
56
+ params: A dataclass for parameters to the module.
57
+ model: A model instance which contains weights and forward computation.
58
+ inference_step: An optional callable to forward-pass the model. If not
59
+ specified, it creates a parital function with `model` as an required
60
+ kwarg.
61
+ preprocessor: An optional callable to preprocess the inputs.
62
+ postprocessor: An optional callable to postprocess the model outputs.
63
+ """
64
+ super().__init__(name=None)
65
+ self.model = model
66
+ self.params = params
67
+
68
+ if inference_step is not None:
69
+ self.inference_step = functools.partial(inference_step, model=self.model)
70
+ else:
71
+ if issubclass(type(model), tf_keras.Model):
72
+ # Default to self.model.call instead of self.model.__call__ to avoid
73
+ # keras tracing logic designed for training.
74
+ # Since most of Model Garden's call doesn't not have training kwargs
75
+ # or the default is False, we don't pass anything here.
76
+ # Please pass custom inference step if your model has training=True as
77
+ # default.
78
+ self.inference_step = self.model.call
79
+ else:
80
+ self.inference_step = functools.partial(
81
+ self.model.__call__, training=False)
82
+ self.preprocessor = preprocessor
83
+ self.postprocessor = postprocessor
84
+
85
+ @abc.abstractmethod
86
+ def serve(self) -> Mapping[Text, tf.Tensor]:
87
+ """The bare inference function which should run on all devices.
88
+
89
+ Expecting tensors are passed in through keyword arguments. Returns a
90
+ dictionary of tensors, when the keys will be used inside the SignatureDef.
91
+ """
92
+
93
+ @abc.abstractmethod
94
+ def get_inference_signatures(
95
+ self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
96
+ """Get defined function signatures."""
97
+
98
+
99
+ def export(export_module: ExportModule,
100
+ function_keys: Union[List[Text], Dict[Text, Text]],
101
+ export_savedmodel_dir: Text,
102
+ checkpoint_path: Optional[Text] = None,
103
+ timestamped: bool = True,
104
+ save_options: Optional[tf.saved_model.SaveOptions] = None,
105
+ checkpoint: Optional[tf.train.Checkpoint] = None) -> Text:
106
+ """Exports to SavedModel format.
107
+
108
+ Args:
109
+ export_module: a ExportModule with the keras Model and serving tf.functions.
110
+ function_keys: a list of string keys to retrieve pre-defined serving
111
+ signatures. The signaute keys will be set with defaults. If a dictionary
112
+ is provided, the values will be used as signature keys.
113
+ export_savedmodel_dir: Output saved model directory.
114
+ checkpoint_path: Object-based checkpoint path or directory.
115
+ timestamped: Whether to export the savedmodel to a timestamped directory.
116
+ save_options: `SaveOptions` for `tf.saved_model.save`.
117
+ checkpoint: An optional tf.train.Checkpoint. If provided, the export module
118
+ will use it to read the weights.
119
+
120
+ Returns:
121
+ The savedmodel directory path.
122
+ """
123
+ ckpt_dir_or_file = checkpoint_path
124
+ if ckpt_dir_or_file is not None and tf.io.gfile.isdir(ckpt_dir_or_file):
125
+ ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
126
+ if ckpt_dir_or_file:
127
+ if checkpoint is None:
128
+ checkpoint = tf.train.Checkpoint(model=export_module.model)
129
+ checkpoint.read(
130
+ ckpt_dir_or_file).assert_existing_objects_matched().expect_partial()
131
+ if isinstance(function_keys, list):
132
+ if len(function_keys) == 1:
133
+ function_keys = {
134
+ function_keys[0]: tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
135
+ }
136
+ else:
137
+ raise ValueError(
138
+ 'If the function_keys is a list, it must contain a single element. %s'
139
+ % function_keys)
140
+
141
+ signatures = export_module.get_inference_signatures(function_keys)
142
+ if timestamped:
143
+ export_dir = get_timestamped_export_dir(export_savedmodel_dir).decode(
144
+ 'utf-8')
145
+ else:
146
+ export_dir = export_savedmodel_dir
147
+ tf.saved_model.save(
148
+ export_module, export_dir, signatures=signatures, options=save_options)
149
+ return export_dir
150
+
151
+
152
+ def get_timestamped_export_dir(export_dir_base):
153
+ """Builds a path to a new subdirectory within the base directory.
154
+
155
+ Args:
156
+ export_dir_base: A string containing a directory to write the exported graph
157
+ and checkpoints.
158
+
159
+ Returns:
160
+ The full path of the new subdirectory (which is not actually created yet).
161
+
162
+ Raises:
163
+ RuntimeError: if repeated attempts fail to obtain a unique timestamped
164
+ directory name.
165
+ """
166
+ attempts = 0
167
+ while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
168
+ timestamp = int(time.time())
169
+
170
+ result_dir = tf.io.gfile.join(
171
+ tf.compat.as_bytes(export_dir_base), tf.compat.as_bytes(str(timestamp)))
172
+ if not tf.io.gfile.exists(result_dir):
173
+ # Collisions are still possible (though extremely unlikely): this
174
+ # directory is not actually created yet, but it will be almost
175
+ # instantly on return from this function.
176
+ return result_dir
177
+ time.sleep(1)
178
+ attempts += 1
179
+ logging.warning('Directory %s already exists; retrying (attempt %s/%s)',
180
+ str(result_dir), attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)
181
+ raise RuntimeError('Failed to obtain a unique export directory name after '
182
+ f'{MAX_DIRECTORY_CREATION_ATTEMPTS} attempts.')
modeling/official/core/export_base_test.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for official.core.export_base."""
16
+ import os
17
+ from typing import Any, Dict, Mapping, Text
18
+
19
+ import tensorflow as tf, tf_keras
20
+
21
+ from official.core import export_base
22
+
23
+
24
+ class TestModule(export_base.ExportModule):
25
+
26
+ @tf.function
27
+ def serve(self, inputs: tf.Tensor) -> Mapping[Text, tf.Tensor]:
28
+ x = inputs if self.preprocessor is None else self.preprocessor(
29
+ inputs=inputs)
30
+ x = self.inference_step(x)
31
+ x = self.postprocessor(x) if self.postprocessor else x
32
+ return {'outputs': x}
33
+
34
+ def get_inference_signatures(
35
+ self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
36
+ input_signature = tf.TensorSpec(shape=[None, None], dtype=tf.float32)
37
+ return {'foo': self.serve.get_concrete_function(input_signature)}
38
+
39
+
40
+ class ExportBaseTest(tf.test.TestCase):
41
+
42
+ def test_export_module(self):
43
+ tmp_dir = self.get_temp_dir()
44
+ model = tf_keras.layers.Dense(2)
45
+ inputs = tf.ones([2, 4], tf.float32)
46
+ expected_output = model(inputs, training=False)
47
+ module = TestModule(params=None, model=model)
48
+ ckpt_path = tf.train.Checkpoint(model=model).save(
49
+ os.path.join(tmp_dir, 'ckpt'))
50
+ export_dir = export_base.export(
51
+ module, ['foo'],
52
+ export_savedmodel_dir=tmp_dir,
53
+ checkpoint_path=ckpt_path,
54
+ timestamped=True)
55
+ self.assertTrue(os.path.exists(os.path.join(export_dir, 'saved_model.pb')))
56
+ self.assertTrue(
57
+ os.path.exists(
58
+ os.path.join(export_dir, 'variables', 'variables.index')))
59
+ self.assertTrue(
60
+ os.path.exists(
61
+ os.path.join(export_dir, 'variables',
62
+ 'variables.data-00000-of-00001')))
63
+
64
+ imported = tf.saved_model.load(export_dir)
65
+ output = imported.signatures['foo'](inputs)
66
+ self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())
67
+
68
+ def test_custom_inference_step(self):
69
+ tmp_dir = self.get_temp_dir()
70
+ model = tf_keras.layers.Dense(2)
71
+ inputs = tf.ones([2, 4], tf.float32)
72
+
73
+ def _inference_step(inputs, model):
74
+ return tf.nn.softmax(model(inputs, training=False))
75
+
76
+ module = TestModule(
77
+ params=None, model=model, inference_step=_inference_step)
78
+ expected_output = _inference_step(inputs, model)
79
+ ckpt_path = tf.train.Checkpoint(model=model).save(
80
+ os.path.join(tmp_dir, 'ckpt'))
81
+ export_dir = export_base.export(
82
+ module, ['foo'],
83
+ export_savedmodel_dir=tmp_dir,
84
+ checkpoint_path=ckpt_path,
85
+ timestamped=False)
86
+ imported = tf.saved_model.load(export_dir)
87
+ output = imported.signatures['foo'](inputs)
88
+ self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())
89
+
90
+ def test_processors(self):
91
+ model = tf.Module()
92
+ inputs = tf.zeros((), tf.float32)
93
+
94
+ def _inference_step(inputs, model):
95
+ del model
96
+ return inputs + 1.0
97
+
98
+ def _preprocessor(inputs):
99
+ print(inputs)
100
+ return inputs + 0.1
101
+
102
+ module = TestModule(
103
+ params=None,
104
+ model=model,
105
+ inference_step=_inference_step,
106
+ preprocessor=_preprocessor)
107
+ output = module.serve(inputs)
108
+ self.assertAllClose(output['outputs'].numpy(), 1.1)
109
+
110
+ class _PostProcessor(tf.Module):
111
+
112
+ def __call__(self, inputs):
113
+ return inputs + 0.01
114
+
115
+ module = TestModule(
116
+ params=None,
117
+ model=model,
118
+ inference_step=_inference_step,
119
+ preprocessor=_preprocessor,
120
+ postprocessor=_PostProcessor())
121
+ output = module.serve(inputs)
122
+ self.assertAllClose(output['outputs'].numpy(), 1.11)
123
+
124
+ def test_get_timestamped_export_dir(self):
125
+ export_dir = self.get_temp_dir()
126
+ timed_dir = export_base.get_timestamped_export_dir(
127
+ export_dir_base=export_dir)
128
+ self.assertFalse(tf.io.gfile.exists(timed_dir))
129
+ self.assertIn(export_dir, str(timed_dir))
130
+
131
+
132
+ if __name__ == '__main__':
133
+ tf.test.main()
modeling/official/core/file_writers.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """File writer functions for dataset preparation, infra validation, and unit tests."""
16
+
17
+ import io
18
+ from typing import Optional, Sequence, Union
19
+
20
+ import tensorflow as tf, tf_keras
21
+
22
+
23
+ def write_small_dataset(examples: Sequence[Union[tf.train.Example,
24
+ tf.train.SequenceExample]],
25
+ output_path: str,
26
+ file_type: str = 'tfrecord') -> None:
27
+ """Writes `examples` to a file at `output_path` with type `file_type`.
28
+
29
+ CAVEAT: This function is not recommended for writing large datasets, since it
30
+ will loop through `examples` and perform write operation sequentially.
31
+
32
+ Args:
33
+ examples: List of tf.train.Example or tf.train.SequenceExample.
34
+ output_path: Output path for the dataset.
35
+ file_type: A string indicating the file format, could be: 'tfrecord',
36
+ 'tfrecords', 'tfrecord_compressed', 'tfrecords_gzip', 'riegeli'. The
37
+ string is case insensitive.
38
+ """
39
+ file_type = file_type.lower()
40
+
41
+ if file_type == 'tfrecord' or file_type == 'tfrecords':
42
+ _write_tfrecord(examples, output_path)
43
+ elif file_type == 'tfrecord_compressed' or file_type == 'tfrecords_gzip':
44
+ _write_tfrecord(examples, output_path,
45
+ tf.io.TFRecordOptions(compression_type='GZIP'))
46
+ elif file_type == 'riegeli':
47
+ _write_riegeli(examples, output_path)
48
+ else:
49
+ raise ValueError(f'Unknown file_type: {file_type}')
50
+
51
+
52
+ def _write_tfrecord(examples: Sequence[Union[tf.train.Example,
53
+ tf.train.SequenceExample]],
54
+ output_path: str,
55
+ options: Optional[tf.io.TFRecordOptions] = None) -> None:
56
+ """Writes `examples` to a TFRecord file at `output_path`.
57
+
58
+ Args:
59
+ examples: A list of tf.train.Example.
60
+ output_path: Output path for the dataset.
61
+ options: Options used for manipulating TFRecord files.
62
+ """
63
+ with tf.io.TFRecordWriter(output_path, options) as writer:
64
+ for example in examples:
65
+ writer.write(example.SerializeToString())
66
+
67
+
68
+ def _write_riegeli(examples: Sequence[Union[tf.train.Example,
69
+ tf.train.SequenceExample]],
70
+ output_path: str) -> None:
71
+ """Writes `examples` to a Riegeli file at `output_path`.
72
+
73
+ Args:
74
+ examples: A list of tf.train.Example.
75
+ output_path: Output path for the dataset.
76
+ """
77
+ with io.FileIO(output_path, 'wb') as fileio:
78
+ import riegeli # pylint: disable=g-import-not-at-top
79
+ with riegeli.RecordWriter(fileio) as writer:
80
+ writer.write_messages(examples)
modeling/official/core/file_writers_test.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for file_writers."""
16
+
17
+ import os
18
+ from absl.testing import parameterized
19
+ import tensorflow as tf, tf_keras
20
+
21
+ from official.core import file_writers
22
+ from official.core import tf_example_builder
23
+
24
+
25
+ class FileWritersTest(tf.test.TestCase, parameterized.TestCase):
26
+
27
+ def setUp(self):
28
+ super().setUp()
29
+ example_builder = tf_example_builder.TfExampleBuilder()
30
+ example_builder.add_bytes_feature('foo', 'Hello World!')
31
+ self._example = example_builder.example
32
+
33
+ @parameterized.parameters('tfrecord', 'TFRecord', 'tfrecords',
34
+ 'tfrecord_compressed', 'TFRecord_Compressed',
35
+ 'tfrecords_gzip')
36
+ def test_write_small_dataset_success(self, file_type):
37
+ temp_dir = self.create_tempdir()
38
+ temp_dataset_file = os.path.join(temp_dir.full_path, 'train')
39
+ file_writers.write_small_dataset([self._example], temp_dataset_file,
40
+ file_type)
41
+ self.assertTrue(os.path.exists(temp_dataset_file))
42
+
43
+ def test_write_small_dataset_unrecognized_format(self):
44
+ file_type = 'bar'
45
+ temp_dir = self.create_tempdir()
46
+ temp_dataset_file = os.path.join(temp_dir.full_path, 'train')
47
+ with self.assertRaises(ValueError):
48
+ file_writers.write_small_dataset([self._example], temp_dataset_file,
49
+ file_type)
50
+
51
+
52
+ if __name__ == '__main__':
53
+ tf.test.main()
modeling/official/core/input_reader.py ADDED
@@ -0,0 +1,591 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """A common dataset reader."""
16
+ import dataclasses
17
+ import random
18
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Text, Union
19
+
20
+ from absl import logging
21
+ import tensorflow as tf, tf_keras
22
+ import tensorflow_datasets as tfds
23
+
24
+ from official.core import config_definitions as cfg
25
+
26
+
27
+ def _get_random_integer():
28
+ return random.randint(0, (1 << 31) - 1)
29
+
30
+
31
+ def _maybe_map_fn(dataset: tf.data.Dataset,
32
+ fn: Optional[Callable[..., Any]] = None) -> tf.data.Dataset:
33
+ """Calls dataset.map if a valid function is passed in."""
34
+ return dataset if fn is None else dataset.map(
35
+ fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
36
+
37
+
38
+ def match_files(input_path: Union[Sequence[str], str]) -> List[str]:
39
+ """Matches files from an input_path."""
40
+ matched_files = []
41
+ # Read dataset from files.
42
+ usage = ('`input_path` should be either (1) a str indicating a file '
43
+ 'path/pattern, or (2) a str indicating multiple file '
44
+ 'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
45
+ '"a,b,c", or (3) a list of str, each of which is a file '
46
+ 'path/pattern or multiple file paths/patterns separated by '
47
+ 'comma, but got: %s')
48
+ if isinstance(input_path, str):
49
+ input_path_list = [input_path]
50
+ elif isinstance(input_path, (list, tuple)):
51
+ if any(not isinstance(x, str) for x in input_path):
52
+ raise ValueError(usage % input_path)
53
+ input_path_list = input_path
54
+ else:
55
+ raise ValueError(usage % input_path)
56
+
57
+ for input_path in input_path_list:
58
+ input_patterns = input_path.strip().split(',')
59
+ for input_pattern in input_patterns:
60
+ input_pattern = input_pattern.strip()
61
+ if not input_pattern:
62
+ continue
63
+ if '*' in input_pattern or '?' in input_pattern:
64
+ tmp_matched_files = tf.io.gfile.glob(input_pattern)
65
+ if not tmp_matched_files:
66
+ raise ValueError('%s does not match any files.' % input_pattern)
67
+ matched_files.extend(tmp_matched_files)
68
+ else:
69
+ matched_files.append(input_pattern)
70
+
71
+ if not matched_files:
72
+ raise ValueError('%s does not match any files.' % input_path)
73
+
74
+ return matched_files
75
+
76
+
77
+ def _read_files_then_shard(matched_files: List[str],
78
+ dataset_fn,
79
+ input_context: Optional[
80
+ tf.distribute.InputContext] = None,
81
+ sharding: bool = False,
82
+ repeat: bool = False) -> tf.data.Dataset:
83
+ """Sends all data files to every worker and then shard by data."""
84
+ dataset = dataset_fn(matched_files)
85
+
86
+ # When `input_file` is a path to a single file or the number of files is
87
+ # less than the number of input pipelines, disable auto sharding
88
+ # so that same input file is sent to all workers.
89
+ options = tf.data.Options()
90
+ options.experimental_distribute.auto_shard_policy = (
91
+ tf.data.experimental.AutoShardPolicy.OFF)
92
+ dataset = dataset.with_options(options)
93
+ # Do not enable sharding if tf.data service is enabled, as sharding will be
94
+ # handled inside tf.data service.
95
+ if sharding and input_context and (input_context.num_input_pipelines > 1):
96
+ dataset = dataset.shard(input_context.num_input_pipelines,
97
+ input_context.input_pipeline_id)
98
+
99
+ if repeat:
100
+ dataset = dataset.repeat()
101
+ return dataset
102
+
103
+
104
+ def _shard_files_then_read(matched_files: List[str],
105
+ dataset_fn,
106
+ input_context: Optional[
107
+ tf.distribute.InputContext] = None,
108
+ seed: Optional[Union[int, tf.Tensor]] = None,
109
+ is_training: bool = False,
110
+ sharding: bool = False,
111
+ cache: bool = False,
112
+ cycle_length: Optional[int] = None,
113
+ block_length: Optional[int] = None,
114
+ deterministic: bool = False) -> tf.data.Dataset:
115
+ """Shards the data files and then sent a split to every worker to read."""
116
+ dataset = tf.data.Dataset.from_tensor_slices(matched_files)
117
+
118
+ # Shuffle and repeat at file level.
119
+ # If cache is enabled, `reshuffle_each_iteration` is set to False,
120
+ # because we will read the same cached data in every iteration anyway.
121
+ if is_training:
122
+ # We need a seed to shuffle the files so that when each TPU workers gets
123
+ # its own shard the files do not overlap.
124
+ if sharding and seed is None:
125
+ seed = _get_random_integer()
126
+ dataset = dataset.shuffle(
127
+ len(matched_files),
128
+ seed=seed,
129
+ reshuffle_each_iteration=True if not cache else False)
130
+
131
+ # Do not enable sharding if tf.data service is enabled, as sharding will be
132
+ # handled inside tf.data service.
133
+ if sharding and input_context and (input_context.num_input_pipelines > 1):
134
+ dataset = dataset.shard(input_context.num_input_pipelines,
135
+ input_context.input_pipeline_id)
136
+
137
+ # If cache is enabled, we will call `repeat()` later after `cache()`.
138
+ if is_training and not cache:
139
+ dataset = dataset.repeat()
140
+
141
+ dataset = dataset.interleave(
142
+ map_func=dataset_fn,
143
+ cycle_length=cycle_length,
144
+ block_length=block_length,
145
+ num_parallel_calls=(cycle_length
146
+ if cycle_length else tf.data.experimental.AUTOTUNE),
147
+ deterministic=deterministic)
148
+ return dataset
149
+
150
+
151
+ def _read_tfds(tfds_name: Text,
152
+ tfds_data_dir: Text,
153
+ tfds_split: Text,
154
+ tfds_skip_decoding_feature: Text,
155
+ tfds_as_supervised: bool,
156
+ input_context: Optional[tf.distribute.InputContext] = None,
157
+ seed: Optional[Union[int, tf.Tensor]] = None,
158
+ is_training: bool = False,
159
+ cache: bool = False,
160
+ cycle_length: Optional[int] = None,
161
+ block_length: Optional[int] = None) -> tf.data.Dataset:
162
+ """Reads a dataset from tfds."""
163
+ repeat_filenames = is_training and not cache
164
+ read_config = tfds.ReadConfig(
165
+ interleave_cycle_length=cycle_length,
166
+ interleave_block_length=block_length,
167
+ input_context=input_context,
168
+ shuffle_seed=seed,
169
+ repeat_filenames=repeat_filenames,
170
+ # Only assert cardinality when we have a finite dataset.
171
+ assert_cardinality=not repeat_filenames,
172
+ skip_prefetch=True)
173
+
174
+ decoders = {}
175
+ if tfds_skip_decoding_feature:
176
+ for skip_feature in tfds_skip_decoding_feature.split(','):
177
+ decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
178
+
179
+ if tfds_name.startswith('mldataset.'):
180
+ dataset = tfds.load(name=tfds_name,
181
+ split=tfds_split,
182
+ as_supervised=tfds_as_supervised,
183
+ decoders=decoders if decoders else None,
184
+ read_config=read_config)
185
+ else:
186
+ builder = tfds.builder(tfds_name, data_dir=tfds_data_dir)
187
+ if builder.info.splits:
188
+ num_shards = len(builder.info.splits[tfds_split].file_instructions)
189
+ else:
190
+ # The tfds mock path often does not provide splits.
191
+ num_shards = 1
192
+ load_kwargs = dict(
193
+ name=tfds_name, download=True, split=tfds_split,
194
+ shuffle_files=is_training, as_supervised=tfds_as_supervised,
195
+ decoders=decoders if decoders else None)
196
+ if tfds_data_dir:
197
+ load_kwargs.update({'data_dir': tfds_data_dir})
198
+
199
+ if input_context and num_shards < input_context.num_input_pipelines:
200
+ # The number of files in the dataset split is smaller than the number of
201
+ # input pipelines. We read the entire dataset first and then shard in the
202
+ # host memory.
203
+ read_config = dataclasses.replace(read_config, input_context=None)
204
+ load_kwargs.update({'read_config': read_config})
205
+ dataset = tfds.load(**load_kwargs)
206
+ dataset = dataset.shard(input_context.num_input_pipelines,
207
+ input_context.input_pipeline_id)
208
+ else:
209
+ load_kwargs.update({'read_config': read_config})
210
+ dataset = tfds.load(**load_kwargs)
211
+ return dataset
212
+
213
+
214
+ class InputReader:
215
+ """Input reader that returns a tf.data.Dataset instance."""
216
+
217
+ # A static random number which is the same across different InputReader
218
+ # instances.
219
+ static_randnum = _get_random_integer()
220
+
221
+ def __init__(
222
+ self,
223
+ params: cfg.DataConfig,
224
+ dataset_fn=tf.data.TFRecordDataset,
225
+ decoder_fn: Optional[Callable[..., Any]] = None,
226
+ combine_fn: Optional[Callable[..., Any]] = None,
227
+ sample_fn: Optional[Callable[..., Any]] = None,
228
+ parser_fn: Optional[Callable[..., Any]] = None,
229
+ filter_fn: Optional[Callable[..., tf.Tensor]] = None,
230
+ transform_and_batch_fn: Optional[
231
+ Callable[
232
+ [tf.data.Dataset, Optional[tf.distribute.InputContext]],
233
+ tf.data.Dataset,
234
+ ]
235
+ ] = None,
236
+ postprocess_fn: Optional[Callable[..., Any]] = None,
237
+ ):
238
+ """Initializes an InputReader instance.
239
+
240
+ Args:
241
+ params: A config_definitions.DataConfig object.
242
+ dataset_fn: A `tf.data.Dataset` that consumes the input files. For
243
+ example, it can be `tf.data.TFRecordDataset`.
244
+ decoder_fn: An optional `callable` that takes the serialized data string
245
+ and decodes them into the raw tensor dictionary.
246
+ combine_fn: An optional `callable` that takes a dictionarty of
247
+ `tf.data.Dataset` objects as input and outputs a combined dataset. It
248
+ will be executed after the decoder_fn and before the sample_fn.
249
+ sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
250
+ input and outputs the transformed dataset. It performs sampling on the
251
+ decoded raw tensors dict before the parser_fn.
252
+ parser_fn: An optional `callable` that takes the decoded raw tensors dict
253
+ and parse them into a dictionary of tensors that can be consumed by the
254
+ model. It will be executed after decoder_fn.
255
+ filter_fn: An optional `callable` mapping a dataset element to a boolean.
256
+ It will be executed after parser_fn.
257
+ transform_and_batch_fn: An optional `callable` that takes a
258
+ `tf.data.Dataset` object and an optional `tf.distribute.InputContext` as
259
+ input, and returns a `tf.data.Dataset` object. It will be executed after
260
+ `parser_fn` to transform and batch the dataset; if None, after
261
+ `parser_fn` is executed, the dataset will be batched into per-replica
262
+ batch size.
263
+ postprocess_fn: A optional `callable` that processes batched tensors. It
264
+ will be executed after batching.
265
+ """
266
+ if params.input_path and params.tfds_name:
267
+ raise ValueError('At most one of `input_path` and `tfds_name` can be '
268
+ 'specified, but got %s and %s.' %
269
+ (params.input_path, params.tfds_name))
270
+
271
+ if (isinstance(params.input_path, cfg.base_config.Config) or
272
+ isinstance(params.tfds_name, cfg.base_config.Config)
273
+ ) and combine_fn is None:
274
+ raise ValueError(
275
+ 'A combine_fn is required if `input_path` or `tfds_name` is a dict.')
276
+
277
+ self._tfds_name = params.tfds_name
278
+ self._tfds_data_dir = params.tfds_data_dir
279
+ self._matched_files = None
280
+ if not params.input_path:
281
+ # Read dataset from TFDS.
282
+ if not params.tfds_split:
283
+ raise ValueError(
284
+ '`tfds_name` is %s, but `tfds_split` is not specified.' %
285
+ params.tfds_name)
286
+ else:
287
+ self._matched_files = self.get_files(params.input_path)
288
+
289
+ self._global_batch_size = params.global_batch_size
290
+ self._is_training = params.is_training
291
+ self._drop_remainder = params.drop_remainder
292
+ self._shuffle_buffer_size = params.shuffle_buffer_size
293
+ self._cache = params.cache
294
+ self._cycle_length = params.cycle_length
295
+ self._block_length = params.block_length
296
+ self._deterministic = params.deterministic
297
+ self._sharding = params.sharding
298
+ self._tfds_split = params.tfds_split
299
+ self._tfds_as_supervised = params.tfds_as_supervised
300
+ self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature
301
+
302
+ self._dataset_fn = dataset_fn
303
+ self._decoder_fn = decoder_fn
304
+ self._combine_fn = combine_fn
305
+ self._sample_fn = sample_fn
306
+ self._parser_fn = parser_fn
307
+ self._transform_and_batch_fn = transform_and_batch_fn
308
+ self._postprocess_fn = postprocess_fn
309
+ self._filter_fn = filter_fn
310
+ self._seed = params.seed
311
+ self._prefetch_buffer_size = (
312
+ params.prefetch_buffer_size or tf.data.experimental.AUTOTUNE)
313
+ self._autotune_algorithm = params.autotune_algorithm
314
+
315
+ # When tf.data service is enabled, each data service worker should get
316
+ # different random seeds. Thus, we set `seed` to None.
317
+ # Sharding should also be disabled because tf data service handles how
318
+ # each worker shard data with `processing_mode` in distribute method.
319
+ if params.enable_tf_data_service:
320
+ self._seed = None
321
+ self._sharding = False
322
+
323
+ self._enable_tf_data_service = (
324
+ params.enable_tf_data_service and params.tf_data_service_address)
325
+ self._tf_data_service_address = params.tf_data_service_address
326
+ self._enable_shared_tf_data_service_between_parallel_trainers = (
327
+ params.enable_shared_tf_data_service_between_parallel_trainers)
328
+ self._apply_tf_data_service_before_batching = (
329
+ params.apply_tf_data_service_before_batching)
330
+ self._trainer_id = params.trainer_id
331
+ if self._enable_tf_data_service:
332
+ # Add a random seed as the tf.data service job name suffix, so tf.data
333
+ # service doesn't reuse the previous state if TPU worker gets preempted.
334
+ # It's necessary to add global batch size into the tf data service job
335
+ # name because when tuning batch size with vizier and tf data service is
336
+ # also enable, the tf data servce job name should be different for
337
+ # different vizier trials since once batch size is changed, from the
338
+ # tf.data perspective, the dataset is a different instance, and a
339
+ # different job name should be used for tf data service. Otherwise, the
340
+ # model would read tensors from the incorrect tf data service job, which
341
+ # would causes dimension mismatch on the batch size dimension.
342
+ self._tf_data_service_job_name = (
343
+ f'{params.tf_data_service_job_name}_bs{params.global_batch_size}_'
344
+ f'{self.static_randnum}')
345
+ self._enable_round_robin_tf_data_service = params.get(
346
+ 'enable_round_robin_tf_data_service', False)
347
+ if self._enable_shared_tf_data_service_between_parallel_trainers:
348
+ # When shared tf.data service is enabled, only a single tf.data service
349
+ # instance should be created and shared between parallel trainers. If
350
+ # the global batch size is different across trainers,
351
+ # params.apply_tf_data_service_before_batching should be set to true
352
+ # because tf.data service with different batch sizes will be considered
353
+ # separate tf.data service instances.
354
+ self._tf_data_service_job_name = (
355
+ f'{params.tf_data_service_job_name}_{self.static_randnum}')
356
+
357
+ def get_files(self, input_path):
358
+ """Gets matched files. Can be overridden by subclasses."""
359
+ if not input_path:
360
+ return None
361
+ # we want to combine / mix datasets
362
+ if isinstance(input_path, cfg.base_config.Config):
363
+ matched_files = {}
364
+ for k, v in input_path.as_dict().items():
365
+ matched_files[k] = match_files(v)
366
+ # single dataset
367
+ else:
368
+ matched_files = match_files(input_path)
369
+ return matched_files
370
+
371
+ def _read_data_source(
372
+ self,
373
+ matched_files: Union[Dict[str, List[str]], List[str]],
374
+ dataset_fn,
375
+ input_context: Optional[tf.distribute.InputContext] = None,
376
+ ):
377
+ """Reads the data source (files/tfds) to a dataset."""
378
+
379
+ def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
380
+ if len(files) > 1:
381
+ if input_context and (len(files) < input_context.num_input_pipelines):
382
+ logging.warn(
383
+ (
384
+ 'The number of files %d is less than the number of input '
385
+ 'pipelines %d. We will send all input files to every worker. '
386
+ 'Please consider sharding your data into more files.'
387
+ ),
388
+ len(files),
389
+ input_context.num_input_pipelines,
390
+ )
391
+ return _read_files_then_shard(
392
+ files,
393
+ dataset_fn,
394
+ input_context,
395
+ sharding=self._sharding,
396
+ repeat=self._is_training and not self._cache)
397
+ else:
398
+ return _shard_files_then_read(
399
+ files,
400
+ dataset_fn,
401
+ input_context,
402
+ seed=self._seed,
403
+ is_training=self._is_training,
404
+ sharding=self._sharding,
405
+ cache=self._cache,
406
+ cycle_length=self._cycle_length,
407
+ block_length=self._block_length,
408
+ deterministic=self._deterministic)
409
+ elif len(files) == 1:
410
+ return _read_files_then_shard(
411
+ files,
412
+ dataset_fn,
413
+ input_context,
414
+ sharding=self._sharding,
415
+ repeat=self._is_training and not self._cache)
416
+ else:
417
+ raise ValueError('It is unexpected that `tfds_builder` is None and '
418
+ 'there is also no `files`.')
419
+
420
+ if self._tfds_name:
421
+ if isinstance(self._tfds_name, cfg.base_config.Config):
422
+ dataset = {}
423
+ for k, tfds_name in self._tfds_name.as_dict().items():
424
+ dataset[k] = _read_tfds(
425
+ tfds_name=tfds_name,
426
+ tfds_data_dir=self._tfds_data_dir,
427
+ tfds_split=self._tfds_split,
428
+ tfds_skip_decoding_feature=self._tfds_skip_decoding_feature,
429
+ tfds_as_supervised=self._tfds_as_supervised,
430
+ input_context=input_context,
431
+ seed=self._seed,
432
+ is_training=self._is_training,
433
+ cache=self._cache,
434
+ cycle_length=self._cycle_length,
435
+ block_length=self._block_length)
436
+ else:
437
+ dataset = _read_tfds(
438
+ tfds_name=self._tfds_name,
439
+ tfds_data_dir=self._tfds_data_dir,
440
+ tfds_split=self._tfds_split,
441
+ tfds_skip_decoding_feature=self._tfds_skip_decoding_feature,
442
+ tfds_as_supervised=self._tfds_as_supervised,
443
+ input_context=input_context,
444
+ seed=self._seed,
445
+ is_training=self._is_training,
446
+ cache=self._cache,
447
+ cycle_length=self._cycle_length,
448
+ block_length=self._block_length)
449
+ elif isinstance(matched_files, (list, tuple)):
450
+ dataset = _files_to_dataset(matched_files)
451
+ elif isinstance(matched_files, dict):
452
+ dataset = {}
453
+ for k, fs in matched_files.items():
454
+ dataset[k] = _files_to_dataset(fs)
455
+ else:
456
+ raise ValueError('`matched_files` should be a list or dict.')
457
+
458
+ return dataset
459
+
460
+ def _decode_and_parse_dataset(
461
+ self,
462
+ dataset: Union[tf.data.Dataset, Dict[Text, tf.data.Dataset]],
463
+ batch_size: int,
464
+ input_context: Optional[tf.distribute.InputContext] = None
465
+ ) -> tf.data.Dataset:
466
+ """Returns a tf.data.Dataset object after shuffling, decoding, and parsing."""
467
+
468
+ def _shuffle_and_decode(ds):
469
+ # If cache is enabled, we will call `shuffle()` later after `cache()`.
470
+ if self._is_training and not self._cache:
471
+ ds = ds.shuffle(self._shuffle_buffer_size, seed=self._seed)
472
+ # Decode
473
+ ds = _maybe_map_fn(ds, self._decoder_fn)
474
+ return ds
475
+
476
+ dataset = tf.nest.map_structure(_shuffle_and_decode, dataset)
477
+ if tf.nest.is_nested(dataset):
478
+ dataset = self._combine_fn(dataset)
479
+
480
+ if self._sample_fn is not None:
481
+ dataset = dataset.apply(self._sample_fn)
482
+ dataset = _maybe_map_fn(dataset, self._parser_fn)
483
+
484
+ if self._filter_fn is not None:
485
+ dataset = dataset.filter(self._filter_fn)
486
+
487
+ if self._cache:
488
+ dataset = dataset.cache()
489
+ if self._is_training:
490
+ dataset = dataset.repeat()
491
+ dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
492
+
493
+ # Applies tf.data service before batching operations. This is useful when
494
+ # tf.data service is shared between parallel trainers, and batch size is
495
+ # changing between parallel trainers. Then batch size is changing, tf.data
496
+ # services will be considered different instances if applied after batching
497
+ # operations, which make it difficult to share between parallel trainers.
498
+ # However, if there are additional expensive operations in
499
+ # self._transform_and_batch_fn and self._postprocess_fn, the entire tf.data
500
+ # pipeline could be slowed down. In this case, try to move these dataset
501
+ # operations into early stages if possible.
502
+ if (self._enable_shared_tf_data_service_between_parallel_trainers and
503
+ self._apply_tf_data_service_before_batching):
504
+ dataset = self._maybe_apply_data_service(dataset, input_context)
505
+
506
+ if self._transform_and_batch_fn is not None:
507
+ dataset = self._transform_and_batch_fn(dataset, input_context)
508
+ else:
509
+ per_replica_batch_size = input_context.get_per_replica_batch_size(
510
+ batch_size) if input_context else batch_size
511
+ dataset = dataset.batch(
512
+ per_replica_batch_size, drop_remainder=self._drop_remainder)
513
+
514
+ return dataset
515
+
516
+ def _maybe_apply_data_service(
517
+ self,
518
+ dataset: tf.data.Dataset,
519
+ input_context: Optional[tf.distribute.InputContext] = None
520
+ ) -> tf.data.Dataset:
521
+ """Potentially distributes a dataset."""
522
+ if self._enable_tf_data_service and input_context:
523
+ if self._enable_round_robin_tf_data_service:
524
+ replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
525
+ input_context.num_input_pipelines)
526
+ base_consumer_index = input_context.input_pipeline_id * (
527
+ replicas_per_input_pipeline)
528
+ num_consumers = input_context.num_input_pipelines * (
529
+ replicas_per_input_pipeline)
530
+ range_dataset = tf.data.Dataset.range(replicas_per_input_pipeline)
531
+ tfds_kwargs = {
532
+ 'processing_mode': 'parallel_epochs',
533
+ 'service': self._tf_data_service_address,
534
+ 'job_name': self._tf_data_service_job_name,
535
+ 'num_consumers': num_consumers
536
+ }
537
+ if self._enable_shared_tf_data_service_between_parallel_trainers:
538
+ raise ValueError('Shared tf.data service does not support round-robin'
539
+ ' tf.data service.')
540
+ dataset = range_dataset.map(lambda i: dataset.apply( # pylint: disable=g-long-lambda
541
+ tf.data.experimental.service.distribute(
542
+ consumer_index=base_consumer_index + i, **tfds_kwargs)))
543
+ # Use parallel interleave to read multiple batches from a tf.data
544
+ # service worker in parallel.
545
+ dataset = dataset.interleave(
546
+ lambda x: x,
547
+ cycle_length=replicas_per_input_pipeline,
548
+ num_parallel_calls=replicas_per_input_pipeline,
549
+ deterministic=True)
550
+ else:
551
+ tfds_kwargs = {
552
+ 'processing_mode': 'parallel_epochs',
553
+ 'service': self._tf_data_service_address,
554
+ 'job_name': self._tf_data_service_job_name,
555
+ }
556
+ if self._enable_shared_tf_data_service_between_parallel_trainers:
557
+ tfds_kwargs.update({
558
+ 'processing_mode':
559
+ tf.data.experimental.service.ShardingPolicy.OFF,
560
+ 'cross_trainer_cache':
561
+ tf.data.experimental.service.CrossTrainerCache(
562
+ trainer_id=self._trainer_id)
563
+ })
564
+ dataset = dataset.apply(
565
+ tf.data.experimental.service.distribute(**tfds_kwargs))
566
+ return dataset
567
+
568
+ def read(self,
569
+ input_context: Optional[tf.distribute.InputContext] = None,
570
+ dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
571
+ """Generates a tf.data.Dataset object."""
572
+ if dataset is None:
573
+ dataset = self._read_data_source(self._matched_files, self._dataset_fn,
574
+ input_context)
575
+ dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
576
+ input_context)
577
+ dataset = _maybe_map_fn(dataset, self._postprocess_fn)
578
+ if not (self._enable_shared_tf_data_service_between_parallel_trainers and
579
+ self._apply_tf_data_service_before_batching):
580
+ dataset = self._maybe_apply_data_service(dataset, input_context)
581
+
582
+ if self._deterministic is not None:
583
+ options = tf.data.Options()
584
+ options.deterministic = self._deterministic
585
+ dataset = dataset.with_options(options)
586
+ if self._autotune_algorithm:
587
+ options = tf.data.Options()
588
+ options.autotune.autotune_algorithm = (
589
+ tf.data.experimental.AutotuneAlgorithm[self._autotune_algorithm])
590
+ dataset = dataset.with_options(options)
591
+ return dataset.prefetch(self._prefetch_buffer_size)
modeling/official/core/registry.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Registry utility."""
16
+
17
+
18
+ def register(registered_collection, reg_key):
19
+ """Register decorated function or class to collection.
20
+
21
+ Register decorated function or class into registered_collection, in a
22
+ hierarchical order. For example, when reg_key="my_model/my_exp/my_config_0"
23
+ the decorated function or class is stored under
24
+ registered_collection["my_model"]["my_exp"]["my_config_0"].
25
+ This decorator is supposed to be used together with the lookup() function in
26
+ this file.
27
+
28
+ Args:
29
+ registered_collection: a dictionary. The decorated function or class will be
30
+ put into this collection.
31
+ reg_key: The key for retrieving the registered function or class. If reg_key
32
+ is a string, it can be hierarchical like my_model/my_exp/my_config_0
33
+ Returns:
34
+ A decorator function
35
+ Raises:
36
+ KeyError: when function or class to register already exists.
37
+ """
38
+ def decorator(fn_or_cls):
39
+ """Put fn_or_cls in the dictionary."""
40
+ if isinstance(reg_key, str):
41
+ hierarchy = reg_key.split("/")
42
+ collection = registered_collection
43
+ for h_idx, entry_name in enumerate(hierarchy[:-1]):
44
+ if entry_name not in collection:
45
+ collection[entry_name] = {}
46
+ collection = collection[entry_name]
47
+ if not isinstance(collection, dict):
48
+ raise KeyError(
49
+ "Collection path {} at position {} already registered as "
50
+ "a function or class.".format(entry_name, h_idx))
51
+ leaf_reg_key = hierarchy[-1]
52
+ else:
53
+ collection = registered_collection
54
+ leaf_reg_key = reg_key
55
+
56
+ if leaf_reg_key in collection:
57
+ raise KeyError("Function or class {} registered multiple times.".format(
58
+ leaf_reg_key))
59
+
60
+ collection[leaf_reg_key] = fn_or_cls
61
+ return fn_or_cls
62
+ return decorator
63
+
64
+
65
+ def lookup(registered_collection, reg_key):
66
+ """Lookup and return decorated function or class in the collection.
67
+
68
+ Lookup decorated function or class in registered_collection, in a
69
+ hierarchical order. For example, when
70
+ reg_key="my_model/my_exp/my_config_0",
71
+ this function will return
72
+ registered_collection["my_model"]["my_exp"]["my_config_0"].
73
+
74
+ Args:
75
+ registered_collection: a dictionary. The decorated function or class will be
76
+ retrieved from this collection.
77
+ reg_key: The key for retrieving the registered function or class. If reg_key
78
+ is a string, it can be hierarchical like my_model/my_exp/my_config_0
79
+ Returns:
80
+ The registered function or class.
81
+ Raises:
82
+ LookupError: when reg_key cannot be found.
83
+ """
84
+ if isinstance(reg_key, str):
85
+ hierarchy = reg_key.split("/")
86
+ collection = registered_collection
87
+ for h_idx, entry_name in enumerate(hierarchy):
88
+ if entry_name not in collection:
89
+ raise LookupError(
90
+ f"collection path {entry_name} at position {h_idx} is never "
91
+ f"registered. Please make sure the {entry_name} and its library is "
92
+ "imported and linked to the trainer binary.")
93
+ collection = collection[entry_name]
94
+ return collection
95
+ else:
96
+ if reg_key not in registered_collection:
97
+ raise LookupError(
98
+ f"registration key {reg_key} is never "
99
+ f"registered. Please make sure the {reg_key} and its library is "
100
+ "imported and linked to the trainer binary.")
101
+ return registered_collection[reg_key]
modeling/official/core/registry_test.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for registry."""
16
+
17
+ import tensorflow as tf, tf_keras
18
+ from official.core import registry
19
+
20
+
21
+ class RegistryTest(tf.test.TestCase):
22
+
23
+ def test_register(self):
24
+ collection = {}
25
+
26
+ @registry.register(collection, 'functions/func_0')
27
+ def func_test():
28
+ pass
29
+
30
+ self.assertEqual(registry.lookup(collection, 'functions/func_0'), func_test)
31
+
32
+ @registry.register(collection, 'classes/cls_0')
33
+ class ClassRegistryKey:
34
+ pass
35
+
36
+ self.assertEqual(
37
+ registry.lookup(collection, 'classes/cls_0'), ClassRegistryKey)
38
+
39
+ @registry.register(collection, ClassRegistryKey)
40
+ class ClassRegistryValue:
41
+ pass
42
+
43
+ self.assertEqual(
44
+ registry.lookup(collection, ClassRegistryKey), ClassRegistryValue)
45
+
46
+ def test_register_hierarchy(self):
47
+ collection = {}
48
+
49
+ @registry.register(collection, 'functions/func_0')
50
+ def func_test0():
51
+ pass
52
+
53
+ @registry.register(collection, 'func_1')
54
+ def func_test1():
55
+ pass
56
+
57
+ @registry.register(collection, func_test1)
58
+ def func_test2():
59
+ pass
60
+
61
+ expected_collection = {
62
+ 'functions': {
63
+ 'func_0': func_test0,
64
+ },
65
+ 'func_1': func_test1,
66
+ func_test1: func_test2,
67
+ }
68
+ self.assertEqual(collection, expected_collection)
69
+
70
+ def test_register_error(self):
71
+ collection = {}
72
+
73
+ @registry.register(collection, 'functions/func_0')
74
+ def func_test0(): # pylint: disable=unused-variable
75
+ pass
76
+
77
+ with self.assertRaises(KeyError):
78
+
79
+ @registry.register(collection, 'functions/func_0/sub_func')
80
+ def func_test1(): # pylint: disable=unused-variable
81
+ pass
82
+
83
+ with self.assertRaises(LookupError):
84
+ registry.lookup(collection, 'non-exist')
85
+
86
+
87
+ if __name__ == '__main__':
88
+ tf.test.main()
modeling/official/core/savedmodel_checkpoint_manager.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Custom checkpoint manager that also exports saved models."""
16
+
17
+ import os
18
+ import re
19
+ import time
20
+ from typing import Callable, List, Mapping, Optional, Union
21
+
22
+ from absl import logging
23
+ import tensorflow as tf, tf_keras
24
+
25
+ SAVED_MODULES_PATH_SUFFIX = 'saved_modules'
26
+
27
+
28
+ def make_saved_modules_directory_name(checkpoint_name: str) -> str:
29
+ return f'{checkpoint_name}_{SAVED_MODULES_PATH_SUFFIX}'
30
+
31
+
32
+ class SavedModelCheckpointManager(tf.train.CheckpointManager):
33
+ """A CheckpointManager that also exports `SavedModel`s."""
34
+
35
+ def __init__(self,
36
+ checkpoint: tf.train.Checkpoint,
37
+ directory: str,
38
+ max_to_keep: int,
39
+ modules_to_export: Optional[Mapping[str, tf.Module]] = None,
40
+ keep_checkpoint_every_n_hours: Optional[int] = None,
41
+ checkpoint_name: str = 'ckpt',
42
+ step_counter: Optional[tf.Variable] = None,
43
+ checkpoint_interval: Optional[int] = None,
44
+ init_fn: Optional[Callable[[], None]] = None):
45
+ """See base class."""
46
+ super().__init__(
47
+ checkpoint=checkpoint,
48
+ directory=directory,
49
+ max_to_keep=max_to_keep,
50
+ keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
51
+ checkpoint_name=checkpoint_name,
52
+ step_counter=step_counter,
53
+ checkpoint_interval=checkpoint_interval,
54
+ init_fn=init_fn)
55
+ self._modules_to_export = modules_to_export
56
+ self._savedmodels = self.get_existing_savedmodels()
57
+
58
+ def save(self,
59
+ checkpoint_number: Optional[int] = None,
60
+ check_interval: bool = True,
61
+ options: Optional[tf.train.CheckpointOptions] = None):
62
+ """See base class."""
63
+ checkpoint_path = super().save(
64
+ checkpoint_number=checkpoint_number,
65
+ check_interval=check_interval,
66
+ options=options)
67
+ if not checkpoint_path: # Nothing got written.
68
+ return
69
+ if not self._modules_to_export: # No modules to export.
70
+ logging.info('Skip saving SavedModel due to empty modules_to_export.')
71
+ return checkpoint_path
72
+
73
+ # Save the models for the checkpoint that just got written.
74
+ saved_modules_directory = make_saved_modules_directory_name(checkpoint_path)
75
+ # Atomic export of SavedModel. Write into a temporary direcotory and then
76
+ # rename as the final direcotory after finishing the writing.
77
+ # This can avoid trying to read an unfinished savedmodel.
78
+ saved_modules_directory_tmp = saved_modules_directory + '_temp'
79
+ for model_name, model in self._modules_to_export.items():
80
+ signatures = getattr(model, 'saved_model_signatures', None)
81
+ if signatures is not None:
82
+ tf.saved_model.save(
83
+ obj=model,
84
+ export_dir=os.path.join(saved_modules_directory_tmp, model_name),
85
+ signatures=signatures)
86
+ if tf.io.gfile.exists(saved_modules_directory_tmp):
87
+ tf.io.gfile.rename(saved_modules_directory_tmp, saved_modules_directory)
88
+
89
+ saved_modules_directories_to_keep = [
90
+ make_saved_modules_directory_name(ckpt) for ckpt in self.checkpoints
91
+ ]
92
+ existing_saved_modules_dirs = self.get_existing_savedmodels()
93
+
94
+ self._savedmodels = []
95
+ # Keep savedmodels in the same order as checkpoints (from oldest to newest).
96
+ for saved_modules_dir_to_keep in saved_modules_directories_to_keep:
97
+ if saved_modules_dir_to_keep in existing_saved_modules_dirs:
98
+ self._savedmodels.append(saved_modules_dir_to_keep)
99
+
100
+ for existing_saved_modules_dir in existing_saved_modules_dirs:
101
+ if existing_saved_modules_dir not in self._savedmodels:
102
+ tf.io.gfile.rmtree(existing_saved_modules_dir)
103
+
104
+ return checkpoint_path
105
+
106
+ def get_existing_savedmodels(self) -> List[str]:
107
+ """Gets a list of all existing SavedModel paths in `directory`.
108
+
109
+ Returns:
110
+ A list of all existing SavedModel paths.
111
+ """
112
+ saved_modules_glob = make_saved_modules_directory_name(
113
+ self._checkpoint_prefix + '-*')
114
+ savedmodels = tf.io.gfile.glob(saved_modules_glob)
115
+ # Filter out temporary savedmodel.
116
+ savedmodels = [
117
+ savedmodel
118
+ for savedmodel in savedmodels
119
+ if savedmodel.endswith(SAVED_MODULES_PATH_SUFFIX)
120
+ ]
121
+ return savedmodels
122
+
123
+ @property
124
+ def latest_savedmodel(self) -> Union[str, None]:
125
+ """The path of the most recent SavedModel in `directory`.
126
+
127
+ Returns:
128
+ The latest SavedModel path. If there are no SavedModels, returns `None`.
129
+ """
130
+ if self._savedmodels:
131
+ return self._savedmodels[-1]
132
+ return None
133
+
134
+ @property
135
+ def savedmodels(self) -> List[str]:
136
+ """A list of managed SavedModels.
137
+
138
+ Returns:
139
+ A list of SavedModel paths, sorted from oldest to newest.
140
+ """
141
+ return self._savedmodels
142
+
143
+ @property
144
+ def modules_to_export(self) -> Union[Mapping[str, tf.Module], None]:
145
+ return self._modules_to_export
146
+
147
+ def get_savedmodel_number_from_path(self,
148
+ savedmodel_path: str) -> Union[int, None]:
149
+ """Gets the savedmodel_number/checkpoint_number from savedmodel filepath.
150
+
151
+ The savedmodel_number is global step when using with orbit controller.
152
+
153
+ Args:
154
+ savedmodel_path: savedmodel directory path.
155
+
156
+ Returns:
157
+ Savedmodel number or None if no matched pattern found in savedmodel path.
158
+ """
159
+ pattern = rf'\d+_{SAVED_MODULES_PATH_SUFFIX}$'
160
+ savedmodel_number = re.search(pattern, savedmodel_path)
161
+ if savedmodel_number:
162
+ savedmodel_number = savedmodel_number.group()
163
+ return int(savedmodel_number[:-len(SAVED_MODULES_PATH_SUFFIX) - 1])
164
+ return None
165
+
166
+ def savedmodels_iterator(self,
167
+ min_interval_secs: float = 0,
168
+ timeout: Optional[float] = None,
169
+ timeout_fn: Optional[Callable[[], bool]] = None):
170
+ """Continuously yield new SavedModel files as they appear.
171
+
172
+ The iterator only checks for new savedmodels when control flow has been
173
+ reverted to it. The logic is same to the `train.checkpoints_iterator`.
174
+
175
+ Args:
176
+ min_interval_secs: The minimum number of seconds between yielding
177
+ savedmodels.
178
+ timeout: The maximum number of seconds to wait between savedmodels. If
179
+ left as `None`, then the process will wait indefinitely.
180
+ timeout_fn: Optional function to call after a timeout. If the function
181
+ returns True, then it means that no new savedmodels will be generated
182
+ and the iterator will exit. The function is called with no arguments.
183
+
184
+ Yields:
185
+ String paths to latest SavedModel files as they arrive.
186
+ """
187
+ savedmodel_path = None
188
+ while True:
189
+ new_savedmodel_path = self.wait_for_new_savedmodel(
190
+ savedmodel_path, timeout=timeout)
191
+ if new_savedmodel_path is None:
192
+ if not timeout_fn:
193
+ # timed out
194
+ logging.info('Timed-out waiting for a savedmodel.')
195
+ return
196
+ if timeout_fn():
197
+ # The timeout_fn indicated that we are truly done.
198
+ return
199
+ else:
200
+ # The timeout_fn indicated that more savedmodels may come.
201
+ continue
202
+ start = time.time()
203
+ savedmodel_path = new_savedmodel_path
204
+ yield savedmodel_path
205
+ time_to_next_eval = start + min_interval_secs - time.time()
206
+ if time_to_next_eval > 0:
207
+ time.sleep(time_to_next_eval)
208
+
209
+ def wait_for_new_savedmodel(
210
+ self,
211
+ last_savedmodel: Optional[str] = None,
212
+ seconds_to_sleep: float = 1.0,
213
+ timeout: Optional[float] = None) -> Union[str, None]:
214
+ """Waits until a new savedmodel file is found.
215
+
216
+ Args:
217
+ last_savedmodel: The last savedmodel path used or `None` if we're
218
+ expecting a savedmodel for the first time.
219
+ seconds_to_sleep: The number of seconds to sleep for before looking for a
220
+ new savedmodel.
221
+ timeout: The maximum number of seconds to wait. If left as `None`, then
222
+ the process will wait indefinitely.
223
+
224
+ Returns:
225
+ A new savedmodel path, or None if the timeout was reached.
226
+ """
227
+ logging.info('Waiting for new savedmodel at %s', self._directory)
228
+ stop_time = time.time() + timeout if timeout is not None else None
229
+
230
+ last_savedmodel_number = -1
231
+ if last_savedmodel:
232
+ last_savedmodel_number = self.get_savedmodel_number_from_path(
233
+ last_savedmodel)
234
+
235
+ while True:
236
+ if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
237
+ return None
238
+
239
+ existing_savedmodels = {}
240
+ for savedmodel_path in self.get_existing_savedmodels():
241
+ savedmodel_number = self.get_savedmodel_number_from_path(
242
+ savedmodel_path)
243
+ if savedmodel_number is not None:
244
+ existing_savedmodels[savedmodel_number] = savedmodel_path
245
+
246
+ # Find the first savedmodel with larger step number as next savedmodel.
247
+ savedmodel_path = None
248
+ existing_savedmodels = dict(sorted(existing_savedmodels.items()))
249
+ for savedmodel_number in existing_savedmodels:
250
+ if savedmodel_number > last_savedmodel_number:
251
+ savedmodel_path = existing_savedmodels[savedmodel_number]
252
+ break
253
+
254
+ if savedmodel_path:
255
+ logging.info('Found new savedmodel at %s', savedmodel_path)
256
+ return savedmodel_path
257
+ else:
258
+ time.sleep(seconds_to_sleep)
modeling/official/core/savedmodel_checkpoint_manager_test.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import time
17
+ from typing import Iterable
18
+
19
+ import tensorflow as tf, tf_keras
20
+
21
+ from official.core import savedmodel_checkpoint_manager
22
+
23
+
24
+ def _models_exist(checkpoint_path: str, models: Iterable[str]) -> bool:
25
+ for model_name in models:
26
+ if not tf.io.gfile.isdir(
27
+ os.path.join(
28
+ savedmodel_checkpoint_manager.make_saved_modules_directory_name(
29
+ checkpoint_path), model_name)):
30
+ return False
31
+ return True
32
+
33
+
34
+ class _ModelForTest(tf_keras.Model):
35
+ def __init__(self, hidden_size: int = 8):
36
+ super().__init__()
37
+ self.dense = tf_keras.layers.Dense(hidden_size)
38
+
39
+ @tf.function(input_signature=[tf.TensorSpec([None, 16])])
40
+ def call(self, inputs):
41
+ return self.dense(inputs)
42
+
43
+ @property
44
+ def saved_model_signatures(self):
45
+ # Build SavedModel signatures.
46
+ return dict(serving_default=self.call)
47
+
48
+
49
+ class CheckpointManagerTest(tf.test.TestCase):
50
+
51
+ def _create_manager(self, max_to_keep: int = 1) -> tf.train.CheckpointManager:
52
+ """Sets up SavedModelCheckpointManager object.
53
+
54
+ Args:
55
+ max_to_keep: max number of savedmodels to keep.
56
+
57
+ Returns:
58
+ created savedmodel manager.
59
+ """
60
+ models = {
61
+ 'model_1': _ModelForTest(12),
62
+ 'model_2': _ModelForTest(14),
63
+ }
64
+ checkpoint = tf.train.Checkpoint()
65
+ manager = savedmodel_checkpoint_manager.SavedModelCheckpointManager(
66
+ checkpoint=checkpoint,
67
+ directory=self.get_temp_dir(),
68
+ max_to_keep=max_to_keep,
69
+ modules_to_export=models)
70
+ return manager
71
+
72
+ def test_max_to_keep(self):
73
+ manager = self._create_manager()
74
+ models = manager.modules_to_export
75
+ first_path = manager.save()
76
+ second_path = manager.save()
77
+
78
+ savedmodel = savedmodel_checkpoint_manager.make_saved_modules_directory_name(
79
+ manager.latest_checkpoint)
80
+ self.assertEqual(savedmodel, manager.latest_savedmodel)
81
+ self.assertTrue(_models_exist(second_path, models.keys()))
82
+ self.assertFalse(_models_exist(first_path, models.keys()))
83
+
84
+ def test_returns_none_after_timeout(self):
85
+ manager = self._create_manager()
86
+ start = time.time()
87
+ ret = manager.wait_for_new_savedmodel(
88
+ None, timeout=1.0, seconds_to_sleep=0.5)
89
+ end = time.time()
90
+ self.assertIsNone(ret)
91
+ # We've waited 0.5 second.
92
+ self.assertGreater(end, start + 0.5)
93
+ # The timeout kicked in.
94
+ self.assertLess(end, start + 0.6)
95
+
96
+ def test_saved_model_iterator(self):
97
+ manager = self._create_manager(max_to_keep=2)
98
+ self.assertIsNotNone(manager.save(checkpoint_number=1))
99
+ self.assertIsNotNone(manager.save(checkpoint_number=2))
100
+ self.assertIsNotNone(manager.save(checkpoint_number=3))
101
+
102
+ # Savedmodels are in time order.
103
+ expected_savedmodels = manager.savedmodels
104
+ # Order not guaranteed.
105
+ existing_savedmodels = manager.get_existing_savedmodels()
106
+ savedmodels = list(manager.savedmodels_iterator(timeout=3.0))
107
+ self.assertEqual(savedmodels, expected_savedmodels)
108
+ self.assertEqual(set(savedmodels), set(existing_savedmodels))
109
+
110
+ def test_saved_model_iterator_timeout_fn(self):
111
+ manager = self._create_manager()
112
+ timeout_fn_calls = [0]
113
+
114
+ def timeout_fn():
115
+ timeout_fn_calls[0] += 1
116
+ return timeout_fn_calls[0] > 3
117
+
118
+ results = list(
119
+ manager.savedmodels_iterator(timeout=0.1, timeout_fn=timeout_fn))
120
+ self.assertEqual([], results)
121
+ self.assertEqual(4, timeout_fn_calls[0])
122
+
123
+
124
+ if __name__ == '__main__':
125
+ tf.test.main()
modeling/official/core/task_factory.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """A global factory to register and access all registered tasks."""
16
+
17
+ from official.core import registry
18
+
19
+ _REGISTERED_TASK_CLS = {}
20
+
21
+
22
+ # TODO(b/158741360): Add type annotations once pytype checks across modules.
23
+ def register_task_cls(task_config_cls):
24
+ """Decorates a factory of Tasks for lookup by a subclass of TaskConfig.
25
+
26
+ This decorator supports registration of tasks as follows:
27
+
28
+ ```
29
+ @dataclasses.dataclass
30
+ class MyTaskConfig(TaskConfig):
31
+ # Add fields here.
32
+ pass
33
+
34
+ @register_task_cls(MyTaskConfig)
35
+ class MyTask(Task):
36
+ # Inherits def __init__(self, task_config).
37
+ pass
38
+
39
+ my_task_config = MyTaskConfig()
40
+ my_task = get_task(my_task_config) # Returns MyTask(my_task_config).
41
+ ```
42
+
43
+ Besisdes a class itself, other callables that create a Task from a TaskConfig
44
+ can be decorated by the result of this function, as long as there is at most
45
+ one registration for each config class.
46
+
47
+ Args:
48
+ task_config_cls: a subclass of TaskConfig (*not* an instance of TaskConfig).
49
+ Each task_config_cls can only be used for a single registration.
50
+
51
+ Returns:
52
+ A callable for use as class decorator that registers the decorated class
53
+ for creation from an instance of task_config_cls.
54
+ """
55
+ return registry.register(_REGISTERED_TASK_CLS, task_config_cls)
56
+
57
+
58
+ def get_task(task_config, **kwargs):
59
+ """Creates a Task (of suitable subclass type) from task_config."""
60
+ # TODO(hongkuny): deprecate the task factory to use config.BUILDER.
61
+ if task_config.BUILDER is not None:
62
+ return task_config.BUILDER(task_config, **kwargs)
63
+ return get_task_cls(task_config.__class__)(task_config, **kwargs)
64
+
65
+
66
+ # The user-visible get_task() is defined after classes have been registered.
67
+ # TODO(b/158741360): Add type annotations once pytype checks across modules.
68
+ def get_task_cls(task_config_cls):
69
+ task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
70
+ return task_cls
modeling/official/core/test_utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Utils for testing."""
16
+
17
+ import tensorflow as tf, tf_keras
18
+
19
+
20
+ class FakeKerasModel(tf_keras.Model):
21
+ """Fake keras model for testing."""
22
+
23
+ def __init__(self):
24
+ super().__init__()
25
+ self.dense = tf_keras.layers.Dense(4, activation=tf.nn.relu)
26
+ self.dense2 = tf_keras.layers.Dense(4, activation=tf.nn.relu)
27
+
28
+ def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
29
+ return self.dense2(self.dense(inputs))
30
+
31
+
32
+ class _Dense(tf.Module):
33
+ """A dense layer."""
34
+
35
+ def __init__(self, input_dim, output_size, name=None):
36
+ super().__init__(name=name)
37
+ with self.name_scope:
38
+ self.w = tf.Variable(
39
+ tf.random.normal([input_dim, output_size]), name='w')
40
+ self.b = tf.Variable(tf.zeros([output_size]), name='b')
41
+
42
+ @tf.Module.with_name_scope
43
+ def __call__(self, x):
44
+ y = tf.matmul(x, self.w) + self.b
45
+ return tf.nn.relu(y)
46
+
47
+
48
+ class FakeModule(tf.Module):
49
+ """Fake model using tf.Module for testing."""
50
+
51
+ def __init__(self, input_size, name=None):
52
+ super().__init__(name=name)
53
+ with self.name_scope:
54
+ self.dense = _Dense(input_size, 4, name='dense')
55
+ self.dense2 = _Dense(4, 4, name='dense_1')
56
+
57
+ @tf.Module.with_name_scope
58
+ def __call__(self, x):
59
+ return self.dense2(self.dense(x))
modeling/official/core/tf_example_builder.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Builder class for preparing tf.train.Example."""
16
+
17
+ # https://www.python.org/dev/peps/pep-0563/#enabling-the-future-behavior-in-python-3-7
18
+ from __future__ import annotations
19
+
20
+ from typing import Mapping, Sequence, Union
21
+
22
+ import numpy as np
23
+ import tensorflow as tf, tf_keras
24
+
25
+ BytesValueType = Union[bytes, Sequence[bytes], str, Sequence[str]]
26
+
27
+ _to_array = lambda v: [v] if not isinstance(v, (list, np.ndarray)) else v
28
+ _to_bytes = lambda v: v.encode() if isinstance(v, str) else v
29
+ _to_bytes_array = lambda v: list(map(_to_bytes, _to_array(v)))
30
+
31
+
32
+ class TfExampleBuilder(object):
33
+ """Builder class for preparing tf.train.Example.
34
+
35
+ Read API doc at https://www.tensorflow.org/api_docs/python/tf/train/Example.
36
+
37
+ Example usage:
38
+ >>> example_builder = TfExampleBuilder()
39
+ >>> example = (
40
+ example_builder.add_bytes_feature('feature_a', 'foobarbaz')
41
+ .add_ints_feature('feature_b', [1, 2, 3])
42
+ .example)
43
+ """
44
+
45
+ def __init__(self) -> None:
46
+ self._example = tf.train.Example()
47
+
48
+ @property
49
+ def example(self) -> tf.train.Example:
50
+ """Returns a copy of the generated tf.train.Example proto."""
51
+ return self._example
52
+
53
+ @property
54
+ def serialized_example(self) -> str:
55
+ """Returns a serialized string of the generated tf.train.Example proto."""
56
+ return self._example.SerializeToString()
57
+
58
+ def set(self, example: tf.train.Example) -> TfExampleBuilder:
59
+ """Sets the example."""
60
+ self._example = example
61
+ return self
62
+
63
+ def reset(self) -> TfExampleBuilder:
64
+ """Resets the example to an empty proto."""
65
+ self._example = tf.train.Example()
66
+ return self
67
+
68
+ ###### Basic APIs for primitive data types ######
69
+ def add_feature_dict(
70
+ self, feature_dict: Mapping[str, tf.train.Feature]) -> TfExampleBuilder:
71
+ """Adds the predefined `feature_dict` to the example.
72
+
73
+ Note: Please prefer to using feature-type-specific methods.
74
+
75
+ Args:
76
+ feature_dict: A dictionary from tf.Example feature key to
77
+ tf.train.Feature.
78
+
79
+ Returns:
80
+ The builder object for subsequent method calls.
81
+ """
82
+ for k, v in feature_dict.items():
83
+ self._example.features.feature[k].CopyFrom(v)
84
+ return self
85
+
86
+ def add_feature(self, key: str,
87
+ feature: tf.train.Feature) -> TfExampleBuilder:
88
+ """Adds predefined `feature` with `key` to the example.
89
+
90
+ Args:
91
+ key: String key of the feature.
92
+ feature: The feature to be added to the example.
93
+
94
+ Returns:
95
+ The builder object for subsequent method calls.
96
+ """
97
+ self._example.features.feature[key].CopyFrom(feature)
98
+ return self
99
+
100
+ def add_bytes_feature(self, key: str,
101
+ value: BytesValueType) -> TfExampleBuilder:
102
+ """Adds byte(s) or string(s) with `key` to the example.
103
+
104
+ Args:
105
+ key: String key of the feature.
106
+ value: The byte(s) or string(s) to be added to the example.
107
+
108
+ Returns:
109
+ The builder object for subsequent method calls.
110
+ """
111
+ return self.add_feature(
112
+ key,
113
+ tf.train.Feature(
114
+ bytes_list=tf.train.BytesList(value=_to_bytes_array(value))))
115
+
116
+ def add_ints_feature(self, key: str,
117
+ value: Union[int, Sequence[int]]) -> TfExampleBuilder:
118
+ """Adds integer(s) with `key` to the example.
119
+
120
+ Args:
121
+ key: String key of the feature.
122
+ value: The integer(s) to be added to the example.
123
+
124
+ Returns:
125
+ The builder object for subsequent method calls.
126
+ """
127
+ return self.add_feature(
128
+ key,
129
+ tf.train.Feature(int64_list=tf.train.Int64List(value=_to_array(value))))
130
+
131
+ def add_floats_feature(
132
+ self, key: str, value: Union[float, Sequence[float]]) -> TfExampleBuilder:
133
+ """Adds float(s) with `key` to the example.
134
+
135
+ Args:
136
+ key: String key of the feature.
137
+ value: The float(s) to be added to the example.
138
+
139
+ Returns:
140
+ The builder object for subsequent method calls.
141
+ """
142
+ return self.add_feature(
143
+ key,
144
+ tf.train.Feature(float_list=tf.train.FloatList(value=_to_array(value))))
modeling/official/core/tf_example_builder_test.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for tf_example_builder.
16
+
17
+ See `test_add_image_matrix_feature_with_fake_image` for the typical structure of
18
+ a unit test.
19
+ """
20
+
21
+ from absl.testing import parameterized
22
+ import tensorflow as tf, tf_keras
23
+ from official.core import tf_example_builder
24
+
25
+
26
+ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
27
+
28
+ def test_init_an_empty_example(self):
29
+ example_builder = tf_example_builder.TfExampleBuilder()
30
+ example = example_builder.example
31
+ self.assertProtoEquals('', example)
32
+
33
+ def test_init_an_empty_serialized_example(self):
34
+ example_builder = tf_example_builder.TfExampleBuilder()
35
+ example = example_builder.serialized_example
36
+ self.assertProtoEquals('', example)
37
+
38
+ def test_add_feature(self):
39
+ example_builder = tf_example_builder.TfExampleBuilder()
40
+ example_builder.add_feature(
41
+ 'foo',
42
+ tf.train.Feature(
43
+ bytes_list=tf.train.BytesList(value=[b'Hello World!'])))
44
+ example = example_builder.example
45
+ # Use proto text to show how the entire proto would look like.
46
+ self.assertProtoEquals(
47
+ """
48
+ features: {
49
+ feature: {
50
+ key: "foo"
51
+ value: {
52
+ bytes_list: {
53
+ value: "Hello World!"
54
+ }
55
+ }
56
+ }
57
+ }""", example)
58
+
59
+ def test_add_feature_dict(self):
60
+ example_builder = tf_example_builder.TfExampleBuilder()
61
+ example_builder.add_feature_dict({
62
+ 'foo':
63
+ tf.train.Feature(
64
+ bytes_list=tf.train.BytesList(value=[b'Hello World!'])),
65
+ 'bar':
66
+ tf.train.Feature(
67
+ int64_list=tf.train.Int64List(value=[299, 792, 458]))
68
+ })
69
+ example = example_builder.example
70
+ # Use proto text to show how the entire proto would look like.
71
+ self.assertProtoEquals(
72
+ """
73
+ features: {
74
+ feature: {
75
+ key: "foo"
76
+ value: {
77
+ bytes_list: {
78
+ value: "Hello World!"
79
+ }
80
+ }
81
+ }
82
+ feature: {
83
+ key: "bar"
84
+ value: {
85
+ int64_list: {
86
+ value: 299
87
+ value: 792
88
+ value: 458
89
+ }
90
+ }
91
+ }
92
+ }""", example)
93
+
94
+ @parameterized.named_parameters(
95
+ ('single_bytes', b'Hello World!', b'Hello World!'),
96
+ ('single_string', 'Hello World!', b'Hello World!'))
97
+ def test_add_single_byte_feature(self, value, expected_value):
98
+ example_builder = tf_example_builder.TfExampleBuilder()
99
+ example_builder.add_bytes_feature('foo', value)
100
+ example = example_builder.example
101
+ # Use constructor to easily work with test parameters.
102
+ self.assertProtoEquals(
103
+ tf.train.Example(
104
+ features=tf.train.Features(
105
+ feature={
106
+ 'foo':
107
+ tf.train.Feature(
108
+ bytes_list=tf.train.BytesList(
109
+ value=[expected_value]))
110
+ })), example)
111
+
112
+ @parameterized.named_parameters(
113
+ ('multiple_bytes', [b'Hello World!', b'Good Morning!'
114
+ ], [b'Hello World!', b'Good Morning!']),
115
+ ('multiple_sring', ['Hello World!', 'Good Morning!'
116
+ ], [b'Hello World!', b'Good Morning!']))
117
+ def test_add_multiple_bytes_feature(self, values, expected_values):
118
+ example_builder = tf_example_builder.TfExampleBuilder()
119
+ example_builder.add_bytes_feature('foo', values)
120
+ example = example_builder.example
121
+ self.assertProtoEquals(
122
+ tf.train.Example(
123
+ features=tf.train.Features(
124
+ feature={
125
+ 'foo':
126
+ tf.train.Feature(
127
+ bytes_list=tf.train.BytesList(
128
+ value=expected_values))
129
+ })), example)
130
+
131
+ @parameterized.named_parameters(
132
+ ('single_integer', 123, [123]),
133
+ ('multiple_integers', [123, 456, 789], [123, 456, 789]))
134
+ def test_add_ints_feature(self, value, expected_value):
135
+ example_builder = tf_example_builder.TfExampleBuilder()
136
+ example_builder.add_ints_feature('bar', value)
137
+ example = example_builder.example
138
+ self.assertProtoEquals(
139
+ tf.train.Example(
140
+ features=tf.train.Features(
141
+ feature={
142
+ 'bar':
143
+ tf.train.Feature(
144
+ int64_list=tf.train.Int64List(value=expected_value))
145
+ })), example)
146
+
147
+ @parameterized.named_parameters(
148
+ ('single_float', 3.14, [3.14]),
149
+ ('multiple_floats', [3.14, 1.57, 6.28], [3.14, 1.57, 6.28]))
150
+ def test_add_floats_feature(self, value, expected_value):
151
+ example_builder = tf_example_builder.TfExampleBuilder()
152
+ example_builder.add_floats_feature('baz', value)
153
+ example = example_builder.example
154
+ self.assertProtoEquals(
155
+ tf.train.Example(
156
+ features=tf.train.Features(
157
+ feature={
158
+ 'baz':
159
+ tf.train.Feature(
160
+ float_list=tf.train.FloatList(value=expected_value))
161
+ })), example)
162
+
163
+
164
+ if __name__ == '__main__':
165
+ tf.test.main()
modeling/official/core/tf_example_feature_key.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Data classes for tf.Example proto feature keys.
16
+
17
+ Feature keys are grouped by feature types. Key names follow conventions in
18
+ go/tf-example.
19
+ """
20
+ import dataclasses
21
+ import functools
22
+ from typing import Optional
23
+
24
+ # Disable init function to use the one defined in base class.
25
+ dataclass = functools.partial(dataclasses.dataclass(init=False))
26
+
27
+
28
+ @dataclass
29
+ class TfExampleFeatureKeyBase:
30
+ """Base dataclass for defining tf.Example proto feature keys.
31
+
32
+ This class defines the logic of adding prefix to feature keys. Subclasses
33
+ will define feature keys for a specific feature type in data fields.
34
+
35
+ NOTE: Please follow subclass examples in this module to define feature keys
36
+ for a new feature type.
37
+ """
38
+
39
+ def __init__(self, prefix: Optional[str] = None):
40
+ """Instantiates the feature key class.
41
+
42
+ Adds a string prefix to all fields of a feature key instance if `prefix` is
43
+ not None nor empty.
44
+
45
+ Example usage:
46
+
47
+ >>> test_key = EncodedImageFeatureKey()
48
+ >>> test_key.encoded
49
+ image/encoded
50
+ >>> test_key = EncodedImageFeatureKey('prefix')
51
+ >>> test_key.encoded
52
+ prefix/image/encoded
53
+
54
+ Args:
55
+ prefix: A prefix string that will be added before the feature key string
56
+ with a trailing slash '/'.
57
+ """
58
+ if prefix:
59
+ for field in dataclasses.fields(self): # pytype: disable=wrong-arg-types # re-none
60
+ key_name = field.name
61
+ key_value = getattr(self, key_name)
62
+ setattr(self, key_name, f'{prefix}/{key_value}')
modeling/official/core/tf_example_feature_key_test.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for tf_example_feature_key."""
16
+ import dataclasses
17
+ import inspect
18
+ from absl.testing import absltest
19
+ from absl.testing import parameterized
20
+
21
+ from official.core import tf_example_feature_key
22
+
23
+
24
+ @tf_example_feature_key.dataclass
25
+ class TestFeatureKey(tf_example_feature_key.TfExampleFeatureKeyBase):
26
+ test: str = 'foo/bar'
27
+
28
+
29
+ class TfExampleFeatureKeyTest(parameterized.TestCase):
30
+
31
+ def test_add_prefix_success(self):
32
+ test_key = TestFeatureKey('prefix')
33
+ self.assertEqual(test_key.test, 'prefix/foo/bar')
34
+
35
+ @parameterized.parameters(None, '')
36
+ def test_add_prefix_skip_success(self, prefix):
37
+ test_key = TestFeatureKey(prefix)
38
+ self.assertEqual(test_key.test, 'foo/bar')
39
+
40
+ def test_all_feature_key_classes_are_valid(self):
41
+ for _, obj in inspect.getmembers(tf_example_feature_key):
42
+ if inspect.isclass(obj):
43
+ self.assertTrue(dataclasses.is_dataclass(obj))
44
+ self.assertTrue(
45
+ issubclass(obj, tf_example_feature_key.TfExampleFeatureKeyBase))
46
+
47
+
48
+ if __name__ == '__main__':
49
+ absltest.main()
modeling/official/core/train_lib.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """TFM common training driver library."""
16
+ # pytype: disable=attribute-error
17
+ import os
18
+ import tempfile
19
+ from typing import Any, List, Mapping, Optional, Tuple
20
+
21
+ # Import libraries
22
+
23
+ from absl import logging
24
+ import orbit
25
+ import tensorflow as tf, tf_keras
26
+
27
+ from official.core import actions
28
+ from official.core import base_task
29
+ from official.core import base_trainer
30
+ from official.core import config_definitions
31
+ from official.core import train_utils
32
+
33
+ maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
34
+
35
+
36
+ class OrbitExperimentRunner:
37
+ """Runs experiment with Orbit training loop.
38
+
39
+ The default experiment runner for model garden experiments. User can
40
+ customize the experiment pipeline by subclassing this class and replacing
41
+ components or functions.
42
+
43
+ For example, an experiment runner with customized checkpoint manager:
44
+
45
+ ```python
46
+ class MyExpRunnerWithExporter(OrbitExperimentRunner):
47
+ def _maybe_build_checkpoint_manager(sefl):
48
+ # Replaces the default CheckpointManger with a customized one.
49
+ return MyCheckpointManager(*args)
50
+
51
+ # In user code, instead of the orginal
52
+ # `OrbitExperimentRunner(..).run(mode)`, now user can do:
53
+ MyExpRunnerWithExporter(**needed_kwargs).run(mode)
54
+ ```
55
+
56
+ Similar override can be done to other components.
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ distribution_strategy: tf.distribute.Strategy,
62
+ task: base_task.Task,
63
+ mode: str,
64
+ params: config_definitions.ExperimentConfig,
65
+ model_dir: str,
66
+ run_post_eval: bool = False,
67
+ save_summary: bool = True,
68
+ train_actions: Optional[List[orbit.Action]] = None,
69
+ eval_actions: Optional[List[orbit.Action]] = None,
70
+ trainer: Optional[base_trainer.Trainer] = None,
71
+ controller_cls=orbit.Controller,
72
+ summary_manager: Optional[orbit.utils.SummaryManager] = None,
73
+ eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
74
+ enable_async_checkpointing: bool = False,
75
+ ):
76
+ """Constructor.
77
+
78
+ Args:
79
+ distribution_strategy: A distribution strategy.
80
+ task: A Task instance.
81
+ mode: A 'str', specifying the mode. Can be 'train', 'eval',
82
+ 'train_and_eval' or 'continuous_eval'.
83
+ params: ExperimentConfig instance.
84
+ model_dir: A 'str', a path to store model checkpoints and summaries.
85
+ run_post_eval: Whether to run post eval once after training, metrics logs
86
+ are returned.
87
+ save_summary: Whether to save train and validation summary.
88
+ train_actions: Optional list of Orbit train actions.
89
+ eval_actions: Optional list of Orbit eval actions.
90
+ trainer: the base_trainer.Trainer instance. It should be created within
91
+ the strategy.scope().
92
+ controller_cls: The controller class to manage the train and eval process.
93
+ Must be a orbit.Controller subclass.
94
+ summary_manager: Instance of the summary manager to override default
95
+ summary manager.
96
+ eval_summary_manager: Instance of the eval summary manager to override
97
+ default eval summary manager.
98
+ enable_async_checkpointing: Optional boolean indicating whether to enable
99
+ async checkpoint saving.
100
+ """
101
+ self.strategy = distribution_strategy or tf.distribute.get_strategy()
102
+ self._params = params
103
+ self._model_dir = model_dir
104
+ self._mode = mode
105
+ self._run_post_eval = run_post_eval
106
+
107
+ self._trainer = trainer or self._build_trainer(
108
+ task,
109
+ train='train' in mode,
110
+ evaluate=('eval' in mode) or run_post_eval)
111
+ assert self.trainer is not None
112
+ self._checkpoint_manager = self._maybe_build_checkpoint_manager()
113
+ self._summary_manager = summary_manager
114
+ self._eval_summary_manager = eval_summary_manager
115
+ self._controller = self._build_controller(
116
+ trainer=self.trainer if 'train' in mode else None,
117
+ evaluator=self.trainer,
118
+ save_summary=save_summary,
119
+ train_actions=train_actions,
120
+ eval_actions=eval_actions,
121
+ controller_cls=controller_cls,
122
+ enable_async_checkpointing=enable_async_checkpointing)
123
+
124
+ @property
125
+ def params(self) -> config_definitions.ExperimentConfig:
126
+ """The whole experiment parameters object."""
127
+ return self._params
128
+
129
+ @property
130
+ def model_dir(self) -> str:
131
+ """Path to the model folder, which stores checkpoints, params, log, etc."""
132
+ return self._model_dir
133
+
134
+ @property
135
+ def trainer(self) -> base_trainer.Trainer:
136
+ """The underlying Orbit Trainer object."""
137
+ return self._trainer
138
+
139
+ @property
140
+ def checkpoint_manager(self) -> Optional[tf.train.CheckpointManager]:
141
+ """The CheckpointManager that stores the checkpoints in a train job."""
142
+ return self._checkpoint_manager
143
+
144
+ @property
145
+ def controller(self) -> orbit.Controller:
146
+ """The Orbit controller object."""
147
+ return self._controller
148
+
149
+ def _build_trainer(self, task: base_task.Task, train: bool,
150
+ evaluate: bool) -> base_trainer.Trainer:
151
+ """Create trainer."""
152
+ with self.strategy.scope():
153
+ trainer = train_utils.create_trainer(
154
+ self.params,
155
+ task,
156
+ train=train,
157
+ evaluate=evaluate,
158
+ checkpoint_exporter=self._build_best_checkpoint_exporter())
159
+ return trainer
160
+
161
+ def _build_best_checkpoint_exporter(self):
162
+ return maybe_create_best_ckpt_exporter(self.params, self.model_dir)
163
+
164
+ def _maybe_build_checkpoint_manager(
165
+ self) -> Optional[tf.train.CheckpointManager]:
166
+ """Maybe create a CheckpointManager."""
167
+ assert self.trainer is not None
168
+ if self.trainer.checkpoint:
169
+ if self.model_dir is None:
170
+ raise ValueError('model_dir must be specified, but got None')
171
+
172
+ if (not self.strategy) or self.strategy.extended.should_checkpoint:
173
+ ckpt_path = self.model_dir
174
+ max_to_keep = self.params.trainer.max_to_keep
175
+ else:
176
+ # In multi worker training we need every worker to save checkpoint,
177
+ # because variables can trigger synchronization on read and
178
+ # synchronization needs all workers to participate. To avoid workers
179
+ # overriding each other we save to a temporary directory on non-chief
180
+ # workers.
181
+ ckpt_path = tempfile.mkdtemp()
182
+ max_to_keep = 1
183
+
184
+ checkpoint_manager = tf.train.CheckpointManager(
185
+ self.trainer.checkpoint,
186
+ directory=ckpt_path,
187
+ max_to_keep=max_to_keep,
188
+ step_counter=self.trainer.global_step,
189
+ checkpoint_interval=self.params.trainer.checkpoint_interval,
190
+ init_fn=self.trainer.initialize)
191
+ else:
192
+ checkpoint_manager = None
193
+ return checkpoint_manager
194
+
195
+ def _build_controller(
196
+ self,
197
+ trainer,
198
+ evaluator,
199
+ save_summary: bool = True,
200
+ train_actions: Optional[List[orbit.Action]] = None,
201
+ eval_actions: Optional[List[orbit.Action]] = None,
202
+ controller_cls=orbit.Controller,
203
+ enable_async_checkpointing: bool = False,
204
+ ) -> orbit.Controller:
205
+ """Builds a Orbit controler."""
206
+ train_actions = [] if not train_actions else train_actions
207
+ if trainer:
208
+ checkpoint_manager = self.checkpoint_manager
209
+ assert checkpoint_manager, 'Checkpoint manager required but undefined.'
210
+ train_actions += actions.get_train_actions(
211
+ self.params,
212
+ trainer,
213
+ self.model_dir,
214
+ checkpoint_manager=checkpoint_manager,
215
+ )
216
+
217
+ eval_actions = [] if not eval_actions else eval_actions
218
+ if evaluator:
219
+ eval_actions += actions.get_eval_actions(self.params, evaluator,
220
+ self.model_dir)
221
+
222
+ if save_summary:
223
+ eval_summary_dir = os.path.join(
224
+ self.model_dir, self.params.trainer.validation_summary_subdir
225
+ )
226
+ else:
227
+ eval_summary_dir = None
228
+
229
+ controller = controller_cls(
230
+ strategy=self.strategy,
231
+ trainer=trainer,
232
+ evaluator=evaluator,
233
+ global_step=self.trainer.global_step,
234
+ steps_per_loop=self.params.trainer.steps_per_loop,
235
+ checkpoint_manager=self.checkpoint_manager,
236
+ enable_async_checkpointing=enable_async_checkpointing,
237
+ summary_dir=os.path.join(self.model_dir, 'train')
238
+ if (save_summary)
239
+ else None,
240
+ eval_summary_dir=eval_summary_dir,
241
+ summary_interval=self.params.trainer.summary_interval
242
+ if (save_summary)
243
+ else None,
244
+ train_actions=train_actions,
245
+ eval_actions=eval_actions,
246
+ summary_manager=self._summary_manager
247
+ if hasattr(self, '_summary_manager')
248
+ else None,
249
+ eval_summary_manager=self._eval_summary_manager
250
+ if hasattr(self, '_eval_summary_manager')
251
+ else None,
252
+ )
253
+ return controller
254
+
255
+ def run(self) -> Tuple[tf_keras.Model, Mapping[str, Any]]:
256
+ """Run experiments by mode.
257
+
258
+ Returns:
259
+ A 2-tuple of (model, eval_logs).
260
+ model: `tf_keras.Model` instance.
261
+ eval_logs: returns eval metrics logs when run_post_eval is set to True,
262
+ otherwise, returns {}.
263
+ """
264
+ mode = self._mode
265
+ params = self.params
266
+ logging.info('Starts to execute mode: %s', mode)
267
+ with self.strategy.scope():
268
+ if mode == 'train' or mode == 'train_and_post_eval':
269
+ self.controller.train(steps=params.trainer.train_steps)
270
+ elif mode == 'train_and_eval':
271
+ self.controller.train_and_evaluate(
272
+ train_steps=params.trainer.train_steps,
273
+ eval_steps=params.trainer.validation_steps,
274
+ eval_interval=params.trainer.validation_interval)
275
+ elif mode == 'eval':
276
+ self.controller.evaluate(steps=params.trainer.validation_steps)
277
+ elif mode == 'continuous_eval':
278
+
279
+ def timeout_fn():
280
+ if self.trainer.global_step.numpy() >= params.trainer.train_steps:
281
+ return True
282
+ return False
283
+
284
+ self.controller.evaluate_continuously(
285
+ steps=params.trainer.validation_steps,
286
+ timeout=params.trainer.continuous_eval_timeout,
287
+ timeout_fn=timeout_fn)
288
+ else:
289
+ raise NotImplementedError('The mode is not implemented: %s' % mode)
290
+
291
+ num_params = train_utils.try_count_params(self.trainer.model)
292
+ if num_params is not None:
293
+ logging.info('Number of trainable params in model: %f Millions.',
294
+ num_params / 10.**6)
295
+
296
+ flops = train_utils.try_count_flops(self.trainer.model)
297
+ if flops is not None:
298
+ logging.info('FLOPs (multi-adds) in model: %f Billions.',
299
+ flops / 10.**9 / 2)
300
+
301
+ if self._run_post_eval or mode == 'train_and_post_eval':
302
+ with self.strategy.scope():
303
+ return self.trainer.model, self.controller.evaluate(
304
+ steps=params.trainer.validation_steps)
305
+ else:
306
+ return self.trainer.model, {}
307
+
308
+
309
+ def run_experiment(
310
+ distribution_strategy: tf.distribute.Strategy,
311
+ task: base_task.Task,
312
+ mode: str,
313
+ params: config_definitions.ExperimentConfig,
314
+ model_dir: str,
315
+ run_post_eval: bool = False,
316
+ save_summary: bool = True,
317
+ train_actions: Optional[List[orbit.Action]] = None,
318
+ eval_actions: Optional[List[orbit.Action]] = None,
319
+ trainer: Optional[base_trainer.Trainer] = None,
320
+ controller_cls=orbit.Controller,
321
+ summary_manager: Optional[orbit.utils.SummaryManager] = None,
322
+ eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
323
+ enable_async_checkpointing: bool = False,
324
+ ) -> Tuple[tf_keras.Model, Mapping[str, Any]]:
325
+ """Runs train/eval configured by the experiment params.
326
+
327
+ Args:
328
+ distribution_strategy: A distribution distribution_strategy.
329
+ task: A Task instance.
330
+ mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
331
+ or 'continuous_eval'.
332
+ params: ExperimentConfig instance.
333
+ model_dir: A 'str', a path to store model checkpoints and summaries.
334
+ run_post_eval: Whether to run post eval once after training, metrics logs
335
+ are returned.
336
+ save_summary: Whether to save train and validation summary.
337
+ train_actions: Optional list of Orbit train actions.
338
+ eval_actions: Optional list of Orbit eval actions.
339
+ trainer: the base_trainer.Trainer instance. It should be created within the
340
+ strategy.scope().
341
+ controller_cls: The controller class to manage the train and eval process.
342
+ Must be a orbit.Controller subclass.
343
+ summary_manager: Instance of the summary manager to override default summary
344
+ manager.
345
+ eval_summary_manager: Instance of the eval summary manager to override
346
+ default eval summary manager.
347
+ enable_async_checkpointing: Optional boolean indicating whether to enable
348
+ async checkpoint saving.
349
+
350
+ Returns:
351
+ A 2-tuple of (model, eval_logs).
352
+ model: `tf_keras.Model` instance.
353
+ eval_logs: returns eval metrics logs when run_post_eval is set to True,
354
+ otherwise, returns {}.
355
+ """
356
+ runner = OrbitExperimentRunner(
357
+ distribution_strategy=distribution_strategy,
358
+ task=task,
359
+ mode=mode,
360
+ params=params,
361
+ model_dir=model_dir,
362
+ run_post_eval=run_post_eval,
363
+ save_summary=save_summary,
364
+ train_actions=train_actions,
365
+ eval_actions=eval_actions,
366
+ trainer=trainer,
367
+ controller_cls=controller_cls,
368
+ summary_manager=summary_manager,
369
+ eval_summary_manager=eval_summary_manager,
370
+ enable_async_checkpointing=enable_async_checkpointing,
371
+ )
372
+ return runner.run()
modeling/official/core/train_lib_test.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for train_ctl_lib."""
16
+ import json
17
+ import os
18
+
19
+ from absl import flags
20
+ from absl.testing import flagsaver
21
+ from absl.testing import parameterized
22
+ import numpy as np
23
+ import tensorflow as tf, tf_keras
24
+
25
+ from tensorflow.python.distribute import combinations
26
+ from tensorflow.python.distribute import strategy_combinations
27
+ from official.common import flags as tfm_flags
28
+ # pylint: disable=unused-import
29
+ from official.common import registry_imports
30
+ # pylint: enable=unused-import
31
+ from official.core import task_factory
32
+ from official.core import train_lib
33
+ from official.core import train_utils
34
+ from official.utils.testing import mock_task
35
+
36
+ FLAGS = flags.FLAGS
37
+
38
+ tfm_flags.define_flags()
39
+
40
+
41
+ class TrainTest(tf.test.TestCase, parameterized.TestCase):
42
+
43
+ def setUp(self):
44
+ super(TrainTest, self).setUp()
45
+ self._test_config = {
46
+ 'trainer': {
47
+ 'checkpoint_interval': 10,
48
+ 'steps_per_loop': 10,
49
+ 'summary_interval': 10,
50
+ 'train_steps': 10,
51
+ 'validation_steps': 5,
52
+ 'validation_interval': 10,
53
+ 'continuous_eval_timeout': 1,
54
+ 'validation_summary_subdir': 'validation',
55
+ 'optimizer_config': {
56
+ 'optimizer': {
57
+ 'type': 'sgd',
58
+ },
59
+ 'learning_rate': {
60
+ 'type': 'constant'
61
+ }
62
+ }
63
+ },
64
+ }
65
+
66
+ @combinations.generate(
67
+ combinations.combine(
68
+ distribution_strategy=[
69
+ strategy_combinations.default_strategy,
70
+ strategy_combinations.cloud_tpu_strategy,
71
+ strategy_combinations.one_device_strategy_gpu,
72
+ ],
73
+ flag_mode=['train', 'eval', 'train_and_eval'],
74
+ run_post_eval=[True, False]))
75
+ def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
76
+ model_dir = self.get_temp_dir()
77
+ flags_dict = dict(
78
+ experiment='mock',
79
+ mode=flag_mode,
80
+ model_dir=model_dir,
81
+ params_override=json.dumps(self._test_config))
82
+ with flagsaver.flagsaver(**flags_dict):
83
+ params = train_utils.parse_configuration(flags.FLAGS)
84
+ train_utils.serialize_config(params, model_dir)
85
+ with distribution_strategy.scope():
86
+ task = task_factory.get_task(params.task, logging_dir=model_dir)
87
+
88
+ _, logs = train_lib.run_experiment(
89
+ distribution_strategy=distribution_strategy,
90
+ task=task,
91
+ mode=flag_mode,
92
+ params=params,
93
+ model_dir=model_dir,
94
+ run_post_eval=run_post_eval)
95
+
96
+ if 'eval' in flag_mode:
97
+ self.assertTrue(
98
+ tf.io.gfile.exists(
99
+ os.path.join(model_dir,
100
+ params.trainer.validation_summary_subdir)))
101
+ if run_post_eval:
102
+ self.assertNotEmpty(logs)
103
+ else:
104
+ self.assertEmpty(logs)
105
+ self.assertNotEmpty(
106
+ tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
107
+ if flag_mode == 'eval':
108
+ return
109
+ self.assertNotEmpty(
110
+ tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
111
+ # Tests continuous evaluation.
112
+ _, logs = train_lib.run_experiment(
113
+ distribution_strategy=distribution_strategy,
114
+ task=task,
115
+ mode='continuous_eval',
116
+ params=params,
117
+ model_dir=model_dir,
118
+ run_post_eval=run_post_eval)
119
+
120
+ @combinations.generate(
121
+ combinations.combine(
122
+ distribution_strategy=[
123
+ strategy_combinations.default_strategy,
124
+ strategy_combinations.cloud_tpu_strategy,
125
+ strategy_combinations.one_device_strategy_gpu,
126
+ ],
127
+ flag_mode=['train', 'eval', 'train_and_eval'],
128
+ run_post_eval=[True, False]))
129
+ def test_end_to_end_class(self, distribution_strategy, flag_mode,
130
+ run_post_eval):
131
+ model_dir = self.get_temp_dir()
132
+ flags_dict = dict(
133
+ experiment='mock',
134
+ mode=flag_mode,
135
+ model_dir=model_dir,
136
+ params_override=json.dumps(self._test_config))
137
+ with flagsaver.flagsaver(**flags_dict):
138
+ params = train_utils.parse_configuration(flags.FLAGS)
139
+ train_utils.serialize_config(params, model_dir)
140
+ with distribution_strategy.scope():
141
+ task = task_factory.get_task(params.task, logging_dir=model_dir)
142
+
143
+ _, logs = train_lib.OrbitExperimentRunner(
144
+ distribution_strategy=distribution_strategy,
145
+ task=task,
146
+ mode=flag_mode,
147
+ params=params,
148
+ model_dir=model_dir,
149
+ run_post_eval=run_post_eval).run()
150
+
151
+ if 'eval' in flag_mode:
152
+ self.assertTrue(
153
+ tf.io.gfile.exists(
154
+ os.path.join(model_dir,
155
+ params.trainer.validation_summary_subdir)))
156
+ if run_post_eval:
157
+ self.assertNotEmpty(logs)
158
+ else:
159
+ self.assertEmpty(logs)
160
+ self.assertNotEmpty(
161
+ tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
162
+ if flag_mode == 'eval':
163
+ return
164
+ self.assertNotEmpty(
165
+ tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
166
+ # Tests continuous evaluation.
167
+ _, logs = train_lib.OrbitExperimentRunner(
168
+ distribution_strategy=distribution_strategy,
169
+ task=task,
170
+ mode='continuous_eval',
171
+ params=params,
172
+ model_dir=model_dir,
173
+ run_post_eval=run_post_eval).run()
174
+
175
+ @combinations.generate(
176
+ combinations.combine(
177
+ distribution_strategy=[
178
+ strategy_combinations.default_strategy,
179
+ strategy_combinations.cloud_tpu_strategy,
180
+ strategy_combinations.one_device_strategy_gpu,
181
+ ],
182
+ flag_mode=['train', 'train_and_eval'],
183
+ ))
184
+ def test_recovery_nan_error(self, distribution_strategy, flag_mode):
185
+ model_dir = self.get_temp_dir()
186
+ flags_dict = dict(
187
+ experiment='mock',
188
+ mode=flag_mode,
189
+ model_dir=model_dir,
190
+ params_override=json.dumps(self._test_config))
191
+ with flagsaver.flagsaver(**flags_dict):
192
+ params = train_utils.parse_configuration(flags.FLAGS)
193
+ train_utils.serialize_config(params, model_dir)
194
+ with distribution_strategy.scope():
195
+ # task = task_factory.get_task(params.task, logging_dir=model_dir)
196
+ task = mock_task.MockTask(params.task, logging_dir=model_dir)
197
+
198
+ # Set the loss to NaN to trigger RunTimeError.
199
+ def build_losses(labels, model_outputs, aux_losses=None):
200
+ del labels, model_outputs
201
+ return tf.constant([np.nan], tf.float32) + aux_losses
202
+
203
+ task.build_losses = build_losses
204
+
205
+ with self.assertRaises(RuntimeError):
206
+ train_lib.OrbitExperimentRunner(
207
+ distribution_strategy=distribution_strategy,
208
+ task=task,
209
+ mode=flag_mode,
210
+ params=params,
211
+ model_dir=model_dir).run()
212
+
213
+ @combinations.generate(
214
+ combinations.combine(
215
+ distribution_strategy=[
216
+ strategy_combinations.default_strategy,
217
+ strategy_combinations.cloud_tpu_strategy,
218
+ strategy_combinations.one_device_strategy_gpu,
219
+ ],
220
+ flag_mode=['train'],
221
+ ))
222
+ def test_recovery(self, distribution_strategy, flag_mode):
223
+ loss_threshold = 1.0
224
+ model_dir = self.get_temp_dir()
225
+ flags_dict = dict(
226
+ experiment='mock',
227
+ mode=flag_mode,
228
+ model_dir=model_dir,
229
+ params_override=json.dumps(self._test_config))
230
+ with flagsaver.flagsaver(**flags_dict):
231
+ params = train_utils.parse_configuration(flags.FLAGS)
232
+ params.trainer.loss_upper_bound = loss_threshold
233
+ params.trainer.recovery_max_trials = 1
234
+ train_utils.serialize_config(params, model_dir)
235
+ with distribution_strategy.scope():
236
+ task = task_factory.get_task(params.task, logging_dir=model_dir)
237
+
238
+ # Saves a checkpoint for reference.
239
+ model = task.build_model()
240
+ checkpoint = tf.train.Checkpoint(model=model)
241
+ checkpoint_manager = tf.train.CheckpointManager(
242
+ checkpoint, self.get_temp_dir(), max_to_keep=2)
243
+ checkpoint_manager.save()
244
+ before_weights = model.get_weights()
245
+
246
+ def build_losses(labels, model_outputs, aux_losses=None):
247
+ del labels, model_outputs
248
+ return tf.constant([loss_threshold], tf.float32) + aux_losses
249
+
250
+ task.build_losses = build_losses
251
+
252
+ model, _ = train_lib.OrbitExperimentRunner(
253
+ distribution_strategy=distribution_strategy,
254
+ task=task,
255
+ mode=flag_mode,
256
+ params=params,
257
+ model_dir=model_dir).run()
258
+ after_weights = model.get_weights()
259
+ for left, right in zip(before_weights, after_weights):
260
+ self.assertAllEqual(left, right)
261
+
262
+ def test_parse_configuration(self):
263
+ model_dir = self.get_temp_dir()
264
+ flags_dict = dict(
265
+ experiment='mock',
266
+ mode='train',
267
+ model_dir=model_dir,
268
+ params_override=json.dumps(self._test_config))
269
+ with flagsaver.flagsaver(**flags_dict):
270
+ params = train_utils.parse_configuration(flags.FLAGS, lock_return=True)
271
+ with self.assertRaises(ValueError):
272
+ params.override({'task': {'init_checkpoint': 'Foo'}})
273
+
274
+ params = train_utils.parse_configuration(flags.FLAGS, lock_return=False)
275
+ params.override({'task': {'init_checkpoint': 'Bar'}})
276
+ self.assertEqual(params.task.init_checkpoint, 'Bar')
277
+
278
+
279
+ if __name__ == '__main__':
280
+ tf.test.main()
modeling/official/core/train_utils.py ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Training utils."""
16
+
17
+ import dataclasses
18
+ import inspect
19
+ import json
20
+ import os
21
+ import pprint
22
+ from typing import Any, Callable, Dict, List, Optional, Union
23
+
24
+ from absl import logging
25
+ import gin
26
+ import numpy as np
27
+ import orbit
28
+ import tensorflow as tf, tf_keras
29
+
30
+ # pylint: disable=g-direct-tensorflow-import
31
+ from tensorflow.python.framework import ops
32
+ from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
33
+ # pylint: enable=g-direct-tensorflow-import
34
+ from official.core import base_task
35
+ from official.core import base_trainer
36
+ from official.core import config_definitions
37
+ from official.core import exp_factory
38
+ from official.modeling import hyperparams
39
+
40
+
41
+ BEST_CHECKPOINT_NAME = 'best_ckpt'
42
+
43
+
44
+ def get_leaf_nested_dict(d: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
45
+ """Get leaf from a dictionary with arbitrary depth with a list of keys.
46
+
47
+ Args:
48
+ d: The dictionary to extract value from.
49
+ keys: The list of keys to extract values recursively.
50
+
51
+ Returns:
52
+ The value of the leaf.
53
+
54
+ Raises:
55
+ KeyError: If the value of keys extracted is a dictionary.
56
+ """
57
+ leaf = d
58
+ for k in keys:
59
+ if not isinstance(leaf, dict) or k not in leaf:
60
+ raise KeyError(
61
+ 'Path not exist while traversing the dictionary: d with keys'
62
+ ': %s.' % keys)
63
+ leaf = leaf[k]
64
+
65
+ if isinstance(leaf, dict):
66
+ raise KeyError('The value extracted with keys: %s is not a leaf of the '
67
+ 'dictionary: %s.' % (keys, d))
68
+ return leaf
69
+
70
+
71
+ def cast_leaf_nested_dict(d: Dict[str, Any],
72
+ cast_fn: Callable[[Any], Any]) -> Dict[str, Any]:
73
+ """Cast the leaves of a dictionary with arbitrary depth in place.
74
+
75
+ Args:
76
+ d: The dictionary to extract value from.
77
+ cast_fn: The casting function.
78
+
79
+ Returns:
80
+ A dictionray with the same structure as d.
81
+ """
82
+ for key, value in d.items():
83
+ if isinstance(value, dict):
84
+ d[key] = cast_leaf_nested_dict(value, cast_fn)
85
+ else:
86
+ d[key] = cast_fn(value)
87
+ return d
88
+
89
+
90
+ def _filter_leaf_nested_dict(
91
+ d: Dict[str, Any], predicate: Callable[[Any], bool]
92
+ ) -> Dict[str, Any]:
93
+ """Filters the leaves of a dictionary with arbitrary depth in place.
94
+
95
+ Args:
96
+ d: The dictionary to extract value from.
97
+ predicate: A function that will be called on every leave item. When the
98
+ function returns True the leave will be kept. Otherwise the leave will be
99
+ dropped.
100
+
101
+ Returns:
102
+ A new dictionray with filtered result.
103
+ """
104
+ result = {}
105
+ for key, value in d.items():
106
+ if isinstance(value, dict):
107
+ result[key] = _filter_leaf_nested_dict(value, predicate)
108
+ elif predicate(value):
109
+ result[key] = value
110
+ return result
111
+
112
+
113
+ def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
114
+ data_dir: str) -> Any:
115
+ """Maybe create a BestCheckpointExporter object, according to the config."""
116
+ export_subdir = params.trainer.best_checkpoint_export_subdir
117
+ metric_name = params.trainer.best_checkpoint_eval_metric
118
+ metric_comp = params.trainer.best_checkpoint_metric_comp
119
+ if data_dir and export_subdir and metric_name:
120
+ best_ckpt_dir = os.path.join(data_dir, export_subdir)
121
+ best_ckpt_exporter = BestCheckpointExporter(best_ckpt_dir, metric_name,
122
+ metric_comp)
123
+ logging.info(
124
+ 'Created the best checkpoint exporter. '
125
+ 'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
126
+ export_subdir, metric_name)
127
+ else:
128
+ best_ckpt_exporter = None
129
+
130
+ return best_ckpt_exporter
131
+
132
+
133
+ class BestCheckpointExporter:
134
+ """Keeps track of the best result, and saves its checkpoint.
135
+
136
+ Orbit will support an API for checkpoint exporter. This class will be used
137
+ together with orbit once this functionality is ready.
138
+ """
139
+
140
+ def __init__(self, export_dir: str, metric_name: str, metric_comp: str):
141
+ """Initialization.
142
+
143
+ Args:
144
+ export_dir: The directory that will contain exported checkpoints.
145
+ metric_name: Indicates which metric to look at, when determining which
146
+ result is better. If eval_logs being passed to maybe_export_checkpoint
147
+ is a nested dictionary, use `|` as a seperator for different layers.
148
+ metric_comp: Indicates how to compare results. Either `lower` or `higher`.
149
+ """
150
+ self._export_dir = export_dir
151
+ self._metric_name = metric_name.split('|')
152
+ self._metric_comp = metric_comp
153
+ if self._metric_comp not in ('lower', 'higher'):
154
+ raise ValueError('best checkpoint metric comp must be one of '
155
+ 'higher, lower. Got: {}'.format(self._metric_comp))
156
+ tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
157
+ self._best_ckpt_logs = self._maybe_load_best_eval_metric()
158
+ self._checkpoint_manager = None
159
+
160
+ def _get_checkpoint_manager(self, checkpoint):
161
+ """Gets an existing checkpoint manager or creates a new one."""
162
+ if self._checkpoint_manager is None or (self._checkpoint_manager.checkpoint
163
+ != checkpoint):
164
+ logging.info('Creates a new checkpoint manager.')
165
+ self._checkpoint_manager = tf.train.CheckpointManager(
166
+ checkpoint,
167
+ directory=self._export_dir,
168
+ max_to_keep=1,
169
+ checkpoint_name=BEST_CHECKPOINT_NAME)
170
+
171
+ return self._checkpoint_manager
172
+
173
+ def maybe_export_checkpoint(
174
+ self, checkpoint, eval_logs, global_step, write_logs=True) -> bool:
175
+ """Compare eval_logs with past eval_logs and export checkpoint if better."""
176
+ logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
177
+ eval_logs, global_step)
178
+ if self._best_ckpt_logs is None or self._new_metric_is_better(
179
+ self._best_ckpt_logs, eval_logs):
180
+ self._best_ckpt_logs = eval_logs
181
+ if write_logs:
182
+ self.export_best_eval_metric(self._best_ckpt_logs, global_step)
183
+ self._get_checkpoint_manager(checkpoint).save()
184
+ return True
185
+ return False
186
+
187
+ def _maybe_load_best_eval_metric(self):
188
+ if not tf.io.gfile.exists(self.best_ckpt_logs_path):
189
+ return None
190
+ with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'r') as reader:
191
+ return json.loads(reader.read())
192
+
193
+ def _new_metric_is_better(self, old_logs, new_logs):
194
+ """Check if the metric in new_logs is better than the metric in old_logs."""
195
+ old_value = float(
196
+ orbit.utils.get_value(
197
+ get_leaf_nested_dict(old_logs, self._metric_name)))
198
+ new_value = float(
199
+ orbit.utils.get_value(
200
+ get_leaf_nested_dict(new_logs, self._metric_name)))
201
+
202
+ logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
203
+ old_value, new_value)
204
+ if self._metric_comp == 'higher':
205
+ if new_value > old_value:
206
+ logging.info('[BestCheckpointExporter] '
207
+ 'the new number is better since it is higher.')
208
+ return True
209
+ else: # self._metric_comp == 'lower':
210
+ if new_value < old_value:
211
+ logging.info('[BestCheckpointExporter] '
212
+ 'the new number is better since it is lower.')
213
+ return True
214
+ return False
215
+
216
+ def export_best_eval_metric(self, eval_logs, global_step):
217
+ """Export evaluation results of the best checkpoint into a json file."""
218
+ # eval_log_ext may contains non-scalar tensors, such as image data when
219
+ # `allow_image_summary` is True. Here we only keep scalar tensors.
220
+ eval_logs_ext = _filter_leaf_nested_dict(
221
+ eval_logs, lambda x: tf.rank(x) <= 1
222
+ )
223
+ eval_logs_ext['best_ckpt_global_step'] = global_step
224
+ eval_logs_ext = cast_leaf_nested_dict(
225
+ eval_logs_ext, lambda x: float(orbit.utils.get_value(x)))
226
+ # Saving json file is very fast.
227
+ with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
228
+ writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
229
+
230
+ @property
231
+ def best_ckpt_logs(self):
232
+ return self._best_ckpt_logs
233
+
234
+ @property
235
+ def best_ckpt_logs_path(self):
236
+ return os.path.join(self._export_dir, 'info.json')
237
+
238
+ @property
239
+ def best_ckpt_path(self):
240
+ """Returns the best ckpt path or None if there is no ckpt yet."""
241
+ return tf.train.latest_checkpoint(self._export_dir)
242
+
243
+
244
+ def create_optimizer(task: base_task.Task,
245
+ params: config_definitions.ExperimentConfig
246
+ ) -> tf_keras.optimizers.Optimizer:
247
+ """A create optimizer util to be backward compatability with new args."""
248
+ if 'dp_config' in inspect.signature(task.create_optimizer).parameters:
249
+ dp_config = None
250
+ if hasattr(params.task, 'differential_privacy_config'):
251
+ dp_config = params.task.differential_privacy_config
252
+ optimizer = task.create_optimizer(
253
+ params.trainer.optimizer_config, params.runtime,
254
+ dp_config=dp_config)
255
+ else:
256
+ if hasattr(params.task, 'differential_privacy_config'
257
+ ) and params.task.differential_privacy_config is not None:
258
+ raise ValueError('Differential privacy config is specified but '
259
+ 'task.create_optimizer api does not accept it.')
260
+ optimizer = task.create_optimizer(
261
+ params.trainer.optimizer_config,
262
+ params.runtime)
263
+ return optimizer
264
+
265
+
266
+ @gin.configurable
267
+ def create_trainer(params: config_definitions.ExperimentConfig,
268
+ task: base_task.Task,
269
+ train: bool,
270
+ evaluate: bool,
271
+ checkpoint_exporter: Optional[BestCheckpointExporter] = None,
272
+ trainer_cls=base_trainer.Trainer) -> base_trainer.Trainer:
273
+ """Create trainer."""
274
+ logging.info('Running default trainer.')
275
+ model = task.build_model()
276
+ optimizer = create_optimizer(task, params)
277
+ return trainer_cls(
278
+ params,
279
+ task,
280
+ model=model,
281
+ optimizer=optimizer,
282
+ train=train,
283
+ evaluate=evaluate,
284
+ checkpoint_exporter=checkpoint_exporter)
285
+
286
+
287
+ @dataclasses.dataclass
288
+ class ParseConfigOptions:
289
+ """Use this dataclass instead of FLAGS to customize parse_configuration()."""
290
+ experiment: str
291
+ config_file: List[str]
292
+ tpu: str = ''
293
+ tf_data_service: str = ''
294
+ params_override: str = ''
295
+
296
+ def __contains__(self, name):
297
+ return name in dataclasses.asdict(self)
298
+
299
+
300
+ class ExperimentParser:
301
+ """Constructs the Experiment config from Flags or equivalent object.
302
+
303
+ Most of the cases, users only need to call the `parse()` function:
304
+ ```
305
+ builder = ExperimentParser(FLAGS)
306
+ params = builder.parse()
307
+ ```
308
+
309
+ The advanced users can modify the flow by calling the parse_*() functions
310
+ separately.
311
+ """
312
+
313
+ def __init__(self, flags_obj):
314
+ self._flags_obj = flags_obj
315
+
316
+ def parse(self):
317
+ """Overrall process of constructing Experiment config."""
318
+ params = self.base_experiment()
319
+ params = self.parse_config_file(params)
320
+ params = self.parse_runtime(params)
321
+ params = self.parse_data_service(params)
322
+ params = self.parse_params_override(params)
323
+ return params
324
+
325
+ def base_experiment(self):
326
+ """Get the base experiment config from --experiment field."""
327
+ if self._flags_obj.experiment is None:
328
+ raise ValueError('The flag --experiment must be specified.')
329
+ return exp_factory.get_exp_config(self._flags_obj.experiment)
330
+
331
+ def parse_config_file(self, params):
332
+ """Override the configs of params from the config_file."""
333
+ for config_file in self._flags_obj.config_file or []:
334
+ params = hyperparams.override_params_dict(
335
+ params, config_file, is_strict=True)
336
+ return params
337
+
338
+ def parse_runtime(self, params):
339
+ """Override the runtime configs of params from flags."""
340
+ # Override the TPU address and tf.data service address.
341
+ params.override({
342
+ 'runtime': {
343
+ 'tpu': self._flags_obj.tpu,
344
+ },
345
+ })
346
+ return params
347
+
348
+ def parse_data_service(self, params):
349
+ """Override the data service configs of params from flags."""
350
+ if ('tf_data_service' in self._flags_obj and
351
+ self._flags_obj.tf_data_service and
352
+ isinstance(params.task, config_definitions.TaskConfig)):
353
+ params.override({
354
+ 'task': {
355
+ 'train_data': {
356
+ 'tf_data_service_address': self._flags_obj.tf_data_service,
357
+ },
358
+ 'validation_data': {
359
+ 'tf_data_service_address': self._flags_obj.tf_data_service,
360
+ }
361
+ }
362
+ })
363
+ return params
364
+
365
+ def parse_params_override(self, params):
366
+ # Get the second level of override from `--params_override`.
367
+ # `--params_override` is typically used as a further override over the
368
+ # template. For example, one may define a particular template for training
369
+ # ResNet50 on ImageNet in a config file and pass it via `--config_file`,
370
+ # then define different learning rates and pass it via `--params_override`.
371
+ if self._flags_obj.params_override:
372
+ params = hyperparams.override_params_dict(
373
+ params, self._flags_obj.params_override, is_strict=True)
374
+ return params
375
+
376
+
377
+ def parse_configuration(flags_obj, lock_return=True, print_return=True):
378
+ """Parses ExperimentConfig from flags."""
379
+
380
+ params = ExperimentParser(flags_obj).parse()
381
+
382
+ params.validate()
383
+ if lock_return:
384
+ params.lock()
385
+
386
+ if print_return:
387
+ pp = pprint.PrettyPrinter()
388
+ logging.info('Final experiment parameters:\n%s',
389
+ pp.pformat(params.as_dict()))
390
+
391
+ return params
392
+
393
+
394
+ def serialize_config(params: config_definitions.ExperimentConfig,
395
+ model_dir: str):
396
+ """Serializes and saves the experiment config."""
397
+ if model_dir is None:
398
+ raise ValueError('model_dir must be specified, but got None')
399
+ params_save_path = os.path.join(model_dir, 'params.yaml')
400
+ logging.info('Saving experiment configuration to %s', params_save_path)
401
+ tf.io.gfile.makedirs(model_dir)
402
+ hyperparams.save_params_dict_to_yaml(params, params_save_path)
403
+
404
+
405
+ def save_gin_config(filename_suffix: str, model_dir: str):
406
+ """Serializes and saves the experiment config."""
407
+ gin_save_path = os.path.join(
408
+ model_dir, 'operative_config.{}.gin'.format(filename_suffix))
409
+ logging.info('Saving gin configurations to %s', gin_save_path)
410
+ tf.io.gfile.makedirs(model_dir)
411
+ with tf.io.gfile.GFile(gin_save_path, 'w') as f:
412
+ f.write(gin.operative_config_str())
413
+
414
+
415
+ def read_global_step_from_checkpoint(ckpt_file_path):
416
+ """Read global step from checkpoint, or get global step from its filename."""
417
+ global_step = tf.Variable(-1, dtype=tf.int64)
418
+ ckpt = tf.train.Checkpoint(global_step=global_step)
419
+ try:
420
+ ckpt.restore(ckpt_file_path).expect_partial()
421
+ global_step_maybe_restored = global_step.numpy()
422
+ except tf.errors.InvalidArgumentError:
423
+ global_step_maybe_restored = -1
424
+
425
+ if global_step_maybe_restored == -1:
426
+ raise ValueError('global_step not found in checkpoint {}. '
427
+ 'If you want to run finetune eval jobs, you need to '
428
+ 'make sure that your pretrain model writes '
429
+ 'global_step in its checkpoints.'.format(ckpt_file_path))
430
+ global_step_restored = global_step.numpy()
431
+ logging.info('get global_step %d from checkpoint %s', global_step_restored,
432
+ ckpt_file_path)
433
+ return global_step_restored
434
+
435
+
436
+ def write_json_summary(log_dir, global_step, eval_metrics):
437
+ """Dump evaluation metrics to json file."""
438
+ serializable_dict = {}
439
+ for name, value in eval_metrics.items():
440
+ if hasattr(value, 'numpy'):
441
+ serializable_dict[name] = str(value.numpy())
442
+ else:
443
+ serializable_dict[name] = str(value)
444
+ output_json = os.path.join(log_dir, 'metrics-{}.json'.format(global_step))
445
+ logging.info('Evaluation results at pretrain step %d: %s', global_step,
446
+ serializable_dict)
447
+ with tf.io.gfile.GFile(output_json, 'w') as writer:
448
+ writer.write(json.dumps(serializable_dict, indent=4) + '\n')
449
+
450
+
451
+ def write_summary(summary_writer, global_step, eval_metrics):
452
+ """Write evaluation metrics to TF summary."""
453
+ numeric_dict = {}
454
+ for name, value in eval_metrics.items():
455
+ numeric_dict[name] = float(orbit.utils.get_value(value))
456
+ with summary_writer.as_default():
457
+ for name, value in numeric_dict.items():
458
+ tf.summary.scalar(name, value, step=global_step)
459
+ summary_writer.flush()
460
+
461
+
462
+ def remove_ckpts(model_dir):
463
+ """Remove model checkpoints, so we can restart."""
464
+ ckpts = os.path.join(model_dir, 'ckpt-*')
465
+ logging.info('removing checkpoint files %s', ckpts)
466
+ for file_to_remove in tf.io.gfile.glob(ckpts):
467
+ tf.io.gfile.rmtree(file_to_remove)
468
+
469
+ file_to_remove = os.path.join(model_dir, 'checkpoint')
470
+ if tf.io.gfile.exists(file_to_remove):
471
+ tf.io.gfile.remove(file_to_remove)
472
+
473
+
474
+ def write_model_params(model: Union[tf.Module, tf_keras.Model],
475
+ output_path: str) -> None:
476
+ """Writes the model parameters and shapes to a file.
477
+
478
+ Args:
479
+ model: A model instance.
480
+ output_path: Output file path.
481
+ """
482
+ with tf.io.gfile.GFile(output_path, 'w') as f:
483
+ total_params = 0
484
+ for var in model.variables:
485
+ shape = tf.shape(var)
486
+ total_params += tf.math.reduce_prod(shape).numpy()
487
+ f.write(f'{var.name} {shape.numpy().tolist()}\n')
488
+ f.write(f'\nTotal params: {total_params}\n')
489
+
490
+
491
+ def try_count_params(
492
+ model: Union[tf.Module, tf_keras.Model],
493
+ trainable_only: bool = False):
494
+ """Count the number of parameters if model is possible.
495
+
496
+ Args:
497
+ model: Try to count the number of params in this model.
498
+ trainable_only: Whether to calculate trainable params only. This flag is
499
+ not used when the model has `count_params` attribute.
500
+
501
+ Returns:
502
+ The number of parameters or None.
503
+ """
504
+ if hasattr(model, 'count_params'):
505
+ try:
506
+ return model.count_params()
507
+ except ValueError:
508
+ logging.info('Number of trainable params unknown, because the build() '
509
+ 'methods in keras layers were not called. This is probably '
510
+ 'because the model was not feed any input, e.g., the max '
511
+ 'train step already reached before this run.')
512
+ return None
513
+ else:
514
+ total_params = 0
515
+ variables = model.trainable_variables if trainable_only else model.variables
516
+ for var in variables:
517
+ shape = tf.shape(var)
518
+ total_params += tf.math.reduce_prod(shape).numpy()
519
+ return total_params
520
+
521
+
522
+ def try_count_flops(model: Union[tf.Module, tf_keras.Model],
523
+ inputs_kwargs: Optional[Dict[str, Any]] = None,
524
+ output_path: Optional[str] = None):
525
+ """Counts and returns model FLOPs.
526
+
527
+ Args:
528
+ model: A model instance.
529
+ inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
530
+ shape specifications to getting corresponding concrete function.
531
+ output_path: A file path to write the profiling results to.
532
+
533
+ Returns:
534
+ The model's FLOPs.
535
+ """
536
+ if hasattr(model, 'inputs'):
537
+ try:
538
+ # Get input shape and set batch size to 1.
539
+ if model.inputs:
540
+ inputs = [
541
+ tf.TensorSpec([1] + input.shape[1:], input.dtype)
542
+ for input in model.inputs
543
+ ]
544
+ concrete_func = tf.function(model).get_concrete_function(inputs)
545
+ # If model.inputs is invalid, try to use the input to get concrete
546
+ # function for model.call (subclass model).
547
+ else:
548
+ concrete_func = tf.function(model.call).get_concrete_function(
549
+ **inputs_kwargs)
550
+ frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)
551
+
552
+ # Calculate FLOPs.
553
+ run_meta = tf.compat.v1.RunMetadata()
554
+ opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
555
+ if output_path is not None:
556
+ opts['output'] = f'file:outfile={output_path}'
557
+ else:
558
+ opts['output'] = 'none'
559
+ flops = tf.compat.v1.profiler.profile(
560
+ graph=frozen_func.graph, run_meta=run_meta, options=opts)
561
+ return flops.total_float_ops
562
+ except Exception as e: # pylint: disable=broad-except
563
+ logging.info(
564
+ 'Failed to count model FLOPs with error %s, because the build() '
565
+ 'methods in keras layers were not called. This is probably because '
566
+ 'the model was not feed any input, e.g., the max train step already '
567
+ 'reached before this run.', e)
568
+ return None
569
+ return None
570
+
571
+
572
+ @ops.RegisterStatistics('Einsum', 'flops')
573
+ def _einsum_flops(graph, node):
574
+ """Calculates the compute resources needed for Einsum."""
575
+ assert len(node.input) == 2
576
+ x_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
577
+ graph, node.input[0])
578
+ y_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
579
+ graph, node.input[1])
580
+ x_shape.assert_is_fully_defined()
581
+ y_shape.assert_is_fully_defined()
582
+ x_shape = x_shape.as_list()
583
+ y_shape = y_shape.as_list()
584
+ equation = str(node.attr['equation'])
585
+ equation = (
586
+ equation.replace('s:', '')
587
+ .replace('"', '')
588
+ .replace(' ', '')
589
+ .replace('\n', '')
590
+ )
591
+ x_str = equation.split(',')[0]
592
+ y_r_str = equation.split(',')[1]
593
+ y_str = y_r_str.split('->')[0]
594
+ r_str = y_r_str.split('->')[1]
595
+ shape_dic = {}
596
+ contracted = set()
597
+ for indice in x_str + y_str:
598
+ if indice in x_str:
599
+ indice_dim = x_shape[x_str.find(indice)]
600
+ elif indice in y_str:
601
+ indice_dim = y_shape[y_str.find(indice)]
602
+ else:
603
+ raise ValueError('indice {} not found in inputs'.format(indice))
604
+ shape_dic[indice] = indice_dim
605
+ if indice not in r_str:
606
+ contracted.add(indice)
607
+ madds = np.prod([shape_dic[indice] for indice in r_str]) * (
608
+ np.prod([shape_dic[indice] for indice in contracted]))
609
+ flops = 2 * madds
610
+ return ops.OpStats('flops', flops)
modeling/official/core/train_utils_test.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tests for official.core.train_utils."""
16
+ import json
17
+ import os
18
+ import pprint
19
+
20
+ import numpy as np
21
+ import tensorflow as tf, tf_keras
22
+
23
+ from official.core import exp_factory
24
+ from official.core import test_utils
25
+ from official.core import train_utils
26
+ from official.modeling import hyperparams
27
+
28
+
29
+ @exp_factory.register_config_factory('foo')
30
+ def foo():
31
+ """Multitask experiment for test."""
32
+ experiment_config = hyperparams.Config(
33
+ default_params={
34
+ 'runtime': {
35
+ 'tpu': 'fake',
36
+ },
37
+ 'task': {
38
+ 'model': {
39
+ 'model_id': 'bar',
40
+ },
41
+ },
42
+ 'trainer': {
43
+ 'train_steps': -1,
44
+ 'validation_steps': -1,
45
+ },
46
+ })
47
+ return experiment_config
48
+
49
+
50
+ class TrainUtilsTest(tf.test.TestCase):
51
+
52
+ def test_get_leaf_nested_dict(self):
53
+ d = {'a': {'i': {'x': 5}}}
54
+ self.assertEqual(train_utils.get_leaf_nested_dict(d, ['a', 'i', 'x']), 5)
55
+
56
+ def test_get_leaf_nested_dict_not_leaf(self):
57
+ with self.assertRaisesRegex(KeyError, 'The value extracted with keys.*'):
58
+ d = {'a': {'i': {'x': 5}}}
59
+ train_utils.get_leaf_nested_dict(d, ['a', 'i'])
60
+
61
+ def test_get_leaf_nested_dict_path_not_exist_missing_key(self):
62
+ with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
63
+ d = {'a': {'i': {'x': 5}}}
64
+ train_utils.get_leaf_nested_dict(d, ['a', 'i', 'y'])
65
+
66
+ def test_get_leaf_nested_dict_path_not_exist_out_of_range(self):
67
+ with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
68
+ d = {'a': {'i': {'x': 5}}}
69
+ train_utils.get_leaf_nested_dict(d, ['a', 'i', 'z'])
70
+
71
+ def test_get_leaf_nested_dict_path_not_exist_meets_leaf(self):
72
+ with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
73
+ d = {'a': {'i': 5}}
74
+ train_utils.get_leaf_nested_dict(d, ['a', 'i', 'z'])
75
+
76
+ def test_cast_leaf_nested_dict(self):
77
+ d = {'a': {'i': {'x': '123'}}, 'b': 456.5}
78
+ d = train_utils.cast_leaf_nested_dict(d, int)
79
+ self.assertEqual(d['a']['i']['x'], 123)
80
+ self.assertEqual(d['b'], 456)
81
+
82
+ def test_write_model_params_keras_model(self):
83
+ inputs = np.zeros([2, 3])
84
+ model = test_utils.FakeKerasModel()
85
+ model(inputs) # Must do forward pass to build the model.
86
+
87
+ filepath = os.path.join(self.create_tempdir(), 'model_params.txt')
88
+ train_utils.write_model_params(model, filepath)
89
+ actual = tf.io.gfile.GFile(filepath, 'r').read().splitlines()
90
+
91
+ expected = [
92
+ 'fake_keras_model/dense/kernel:0 [3, 4]',
93
+ 'fake_keras_model/dense/bias:0 [4]',
94
+ 'fake_keras_model/dense_1/kernel:0 [4, 4]',
95
+ 'fake_keras_model/dense_1/bias:0 [4]',
96
+ '',
97
+ 'Total params: 36',
98
+ ]
99
+ self.assertEqual(actual, expected)
100
+
101
+ def test_write_model_params_module(self):
102
+ inputs = np.zeros([2, 3], dtype=np.float32)
103
+ model = test_utils.FakeModule(3, name='fake_module')
104
+ model(inputs) # Must do forward pass to build the model.
105
+
106
+ filepath = os.path.join(self.create_tempdir(), 'model_params.txt')
107
+ train_utils.write_model_params(model, filepath)
108
+ actual = tf.io.gfile.GFile(filepath, 'r').read().splitlines()
109
+
110
+ expected = [
111
+ 'fake_module/dense/b:0 [4]',
112
+ 'fake_module/dense/w:0 [3, 4]',
113
+ 'fake_module/dense_1/b:0 [4]',
114
+ 'fake_module/dense_1/w:0 [4, 4]',
115
+ '',
116
+ 'Total params: 36',
117
+ ]
118
+ self.assertEqual(actual, expected)
119
+
120
+ def test_construct_experiment_from_flags(self):
121
+ options = train_utils.ParseConfigOptions(
122
+ experiment='foo',
123
+ config_file=[],
124
+ tpu='bar',
125
+ tf_data_service='',
126
+ params_override='task.model.model_id=new,'
127
+ 'trainer.train_steps=10,'
128
+ 'trainer.validation_steps=11')
129
+ builder = train_utils.ExperimentParser(options)
130
+ params_from_obj = builder.parse()
131
+ params_from_func = train_utils.parse_configuration(options)
132
+ pp = pprint.PrettyPrinter()
133
+ self.assertEqual(
134
+ pp.pformat(params_from_obj.as_dict()),
135
+ pp.pformat(params_from_func.as_dict()))
136
+ self.assertEqual(params_from_obj.runtime.tpu, 'bar')
137
+ self.assertEqual(params_from_obj.task.model.model_id, 'new')
138
+ self.assertEqual(params_from_obj.trainer.train_steps, 10)
139
+ self.assertEqual(params_from_obj.trainer.validation_steps, 11)
140
+
141
+
142
+ class BestCheckpointExporterTest(tf.test.TestCase):
143
+
144
+ def test_maybe_export(self):
145
+ model_dir = self.create_tempdir().full_path
146
+ best_ckpt_path = os.path.join(model_dir, 'best_ckpt-1')
147
+ metric_name = 'test_metric|metric_1'
148
+ exporter = train_utils.BestCheckpointExporter(
149
+ model_dir, metric_name, 'higher')
150
+ v = tf.Variable(1.0)
151
+ checkpoint = tf.train.Checkpoint(v=v)
152
+ ret = exporter.maybe_export_checkpoint(
153
+ checkpoint, {'test_metric': {'metric_1': 5.0}}, 100)
154
+ with self.subTest(name='Successful first save.'):
155
+ self.assertEqual(ret, True)
156
+ v_2 = tf.Variable(2.0)
157
+ checkpoint_2 = tf.train.Checkpoint(v=v_2)
158
+ checkpoint_2.restore(best_ckpt_path)
159
+ self.assertEqual(v_2.numpy(), 1.0)
160
+
161
+ v = tf.Variable(3.0)
162
+ checkpoint = tf.train.Checkpoint(v=v)
163
+ ret = exporter.maybe_export_checkpoint(
164
+ checkpoint, {'test_metric': {'metric_1': 6.0}}, 200)
165
+ with self.subTest(name='Successful better metic save.'):
166
+ self.assertEqual(ret, True)
167
+ v_2 = tf.Variable(2.0)
168
+ checkpoint_2 = tf.train.Checkpoint(v=v_2)
169
+ checkpoint_2.restore(best_ckpt_path)
170
+ self.assertEqual(v_2.numpy(), 3.0)
171
+
172
+ v = tf.Variable(5.0)
173
+ checkpoint = tf.train.Checkpoint(v=v)
174
+ ret = exporter.maybe_export_checkpoint(
175
+ checkpoint, {'test_metric': {'metric_1': 1.0}}, 300)
176
+ with self.subTest(name='Worse metic no save.'):
177
+ self.assertEqual(ret, False)
178
+ v_2 = tf.Variable(2.0)
179
+ checkpoint_2 = tf.train.Checkpoint(v=v_2)
180
+ checkpoint_2.restore(best_ckpt_path)
181
+ self.assertEqual(v_2.numpy(), 3.0)
182
+
183
+ def test_export_best_eval_metric(self):
184
+ model_dir = self.create_tempdir().full_path
185
+ metric_name = 'test_metric|metric_1'
186
+ exporter = train_utils.BestCheckpointExporter(model_dir, metric_name,
187
+ 'higher')
188
+ exporter.export_best_eval_metric({'test_metric': {'metric_1': 5.0}}, 100)
189
+ with tf.io.gfile.GFile(os.path.join(model_dir, 'info.json'),
190
+ 'rb') as reader:
191
+ metric = json.loads(reader.read())
192
+ self.assertAllEqual(
193
+ metric,
194
+ {'test_metric': {'metric_1': 5.0}, 'best_ckpt_global_step': 100.0})
195
+
196
+ def test_export_best_eval_metric_skips_non_scalar_values(self):
197
+ model_dir = self.create_tempdir().full_path
198
+ metric_name = 'test_metric|metric_1'
199
+ exporter = train_utils.BestCheckpointExporter(model_dir, metric_name,
200
+ 'higher')
201
+ image = tf.zeros(shape=[16, 8, 1])
202
+ eval_logs = {'test_metric': {'metric_1': 5.0, 'image': image}}
203
+
204
+ exporter.export_best_eval_metric(eval_logs, 100)
205
+
206
+ with tf.io.gfile.GFile(os.path.join(model_dir, 'info.json'),
207
+ 'rb') as reader:
208
+ metric = json.loads(reader.read())
209
+ self.assertAllEqual(
210
+ metric,
211
+ {'test_metric': {'metric_1': 5.0}, 'best_ckpt_global_step': 100.0})
212
+
213
+
214
+ if __name__ == '__main__':
215
+ tf.test.main()
modeling/official/legacy/README.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ Models in this `legacy` directory are mainly are used for benchmarking the
2
+ models.
3
+
4
+ Please note that the models in this `legacy` directory are not supported like
5
+ the models in official/nlp and official/vision.
modeling/official/legacy/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
modeling/official/legacy/albert/README.md ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # ALBERT (ALBERT: A Lite BERT for Self-supervised Learning of Language Representations)
2
+
3
+ **WARNING**: This directory is deprecated.
4
+ See `nlp/docs/MODEL_GARDEN.md` for the new ALBERT implementation.
modeling/official/legacy/albert/__init__.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
modeling/official/legacy/albert/configs.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """The ALBERT configurations."""
16
+
17
+ import six
18
+
19
+ from official.legacy.bert import configs
20
+
21
+
22
+ class AlbertConfig(configs.BertConfig):
23
+ """Configuration for `ALBERT`."""
24
+
25
+ def __init__(self, num_hidden_groups=1, inner_group_num=1, **kwargs):
26
+ """Constructs AlbertConfig.
27
+
28
+ Args:
29
+ num_hidden_groups: Number of group for the hidden layers, parameters in
30
+ the same group are shared. Note that this value and also the following
31
+ 'inner_group_num' has to be 1 for now, because all released ALBERT
32
+ models set them to 1. We may support arbitary valid values in future.
33
+ inner_group_num: Number of inner repetition of attention and ffn.
34
+ **kwargs: The remaining arguments are the same as above 'BertConfig'.
35
+ """
36
+ super(AlbertConfig, self).__init__(**kwargs)
37
+
38
+ # TODO(chendouble): 'inner_group_num' and 'num_hidden_groups' are always 1
39
+ # in the released ALBERT. Support other values in AlbertEncoder if needed.
40
+ if inner_group_num != 1 or num_hidden_groups != 1:
41
+ raise ValueError("We only support 'inner_group_num' and "
42
+ "'num_hidden_groups' as 1.")
43
+
44
+ @classmethod
45
+ def from_dict(cls, json_object):
46
+ """Constructs a `AlbertConfig` from a Python dictionary of parameters."""
47
+ config = AlbertConfig(vocab_size=None)
48
+ for (key, value) in six.iteritems(json_object):
49
+ config.__dict__[key] = value
50
+ return config
modeling/official/legacy/bert/README.md ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # BERT (Bidirectional Encoder Representations from Transformers)
2
+
3
+ **WARNING**: We are on the way to deprecating most of the code in this directory.
4
+ Please see
5
+ [this link](../g3doc/tutorials/bert_new.md)
6
+ for the new tutorial and use the new code in `nlp/modeling`. This README is
7
+ still correct for this legacy implementation.
8
+
9
+ The academic paper which describes BERT in detail and provides full results on a
10
+ number of tasks can be found here: https://arxiv.org/abs/1810.04805.
11
+
12
+ This repository contains TensorFlow 2.x implementation for BERT.
13
+
14
+ ## Contents
15
+ * [Contents](#contents)
16
+ * [Pre-trained Models](#pre-trained-models)
17
+ * [Restoring from Checkpoints](#restoring-from-checkpoints)
18
+ * [Set Up](#set-up)
19
+ * [Process Datasets](#process-datasets)
20
+ * [Fine-tuning with BERT](#fine-tuning-with-bert)
21
+ * [Cloud GPUs and TPUs](#cloud-gpus-and-tpus)
22
+ * [Sentence and Sentence-pair Classification Tasks](#sentence-and-sentence-pair-classification-tasks)
23
+ * [SQuAD 1.1](#squad-1.1)
24
+
25
+
26
+ ## Pre-trained Models
27
+
28
+ We released both checkpoints and tf.hub modules as the pretrained models for
29
+ fine-tuning. They are TF 2.x compatible and are converted from the checkpoints
30
+ released in TF 1.x official BERT repository
31
+ [google-research/bert](https://github.com/google-research/bert)
32
+ in order to keep consistent with BERT paper.
33
+
34
+
35
+ ### Access to Pretrained Checkpoints
36
+
37
+ Pretrained checkpoints can be found in the following links:
38
+
39
+ **Note: We have switched BERT implementation
40
+ to use Keras functional-style networks in [nlp/modeling](../modeling).
41
+ The new checkpoints are:**
42
+
43
+ * **[`BERT-Large, Uncased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_uncased_L-24_H-1024_A-16.tar.gz)**:
44
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
45
+ * **[`BERT-Large, Cased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_cased_L-24_H-1024_A-16.tar.gz)**:
46
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
47
+ * **[`BERT-Base, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12.tar.gz)**:
48
+ 12-layer, 768-hidden, 12-heads, 110M parameters
49
+ * **[`BERT-Large, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16.tar.gz)**:
50
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
51
+ * **[`BERT-Base, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-12_H-768_A-12.tar.gz)**:
52
+ 12-layer, 768-hidden, 12-heads , 110M parameters
53
+ * **[`BERT-Large, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-24_H-1024_A-16.tar.gz)**:
54
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
55
+ * **[`BERT-Base, Multilingual Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/multi_cased_L-12_H-768_A-12.tar.gz)**:
56
+ 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
57
+
58
+ We recommend to host checkpoints on Google Cloud Storage buckets when you use
59
+ Cloud GPU/TPU.
60
+
61
+ ### Restoring from Checkpoints
62
+
63
+ `tf.train.Checkpoint` is used to manage model checkpoints in TF 2. To restore
64
+ weights from provided pre-trained checkpoints, you can use the following code:
65
+
66
+ ```python
67
+ init_checkpoint='the pretrained model checkpoint path.'
68
+ model=tf.keras.Model() # Bert pre-trained model as feature extractor.
69
+ checkpoint = tf.train.Checkpoint(model=model)
70
+ checkpoint.restore(init_checkpoint)
71
+ ```
72
+
73
+ Checkpoints featuring native serialized Keras models
74
+ (i.e. model.load()/load_weights()) will be available soon.
75
+
76
+ ### Access to Pretrained hub modules.
77
+
78
+ Pretrained tf.hub modules in TF 2.x SavedModel format can be found in the
79
+ following links:
80
+
81
+ * **[`BERT-Large, Uncased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/)**:
82
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
83
+ * **[`BERT-Large, Cased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_cased_L-24_H-1024_A-16/)**:
84
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
85
+ * **[`BERT-Base, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/)**:
86
+ 12-layer, 768-hidden, 12-heads, 110M parameters
87
+ * **[`BERT-Large, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/)**:
88
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
89
+ * **[`BERT-Base, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/)**:
90
+ 12-layer, 768-hidden, 12-heads , 110M parameters
91
+ * **[`BERT-Large, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-24_H-1024_A-16/)**:
92
+ 24-layer, 1024-hidden, 16-heads, 340M parameters
93
+ * **[`BERT-Base, Multilingual Cased`](https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/)**:
94
+ 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
95
+ * **[`BERT-Base, Chinese`](https://tfhub.dev/tensorflow/bert_zh_L-12_H-768_A-12/)**:
96
+ Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads,
97
+ 110M parameters
98
+
99
+ ## Set Up
100
+
101
+ ```shell
102
+ export PYTHONPATH="$PYTHONPATH:/path/to/models"
103
+ ```
104
+
105
+ Install `tf-nightly` to get latest updates:
106
+
107
+ ```shell
108
+ pip install tf-nightly-gpu
109
+ ```
110
+
111
+ With TPU, GPU support is not necessary. First, you need to create a `tf-nightly`
112
+ TPU with [ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu):
113
+
114
+ ```shell
115
+ ctpu up -name <instance name> --tf-version=”nightly”
116
+ ```
117
+
118
+ Second, you need to install TF 2 `tf-nightly` on your VM:
119
+
120
+ ```shell
121
+ pip install tf-nightly
122
+ ```
123
+
124
+ ## Process Datasets
125
+
126
+ ### Pre-training
127
+
128
+ There is no change to generate pre-training data. Please use the script
129
+ [`../data/create_pretraining_data.py`](../data/create_pretraining_data.py)
130
+ which is essentially branched from the [BERT research repo](https://github.com/google-research/bert)
131
+ to get processed pre-training data and it adapts to TF2 symbols and python3
132
+ compatibility.
133
+
134
+ Running the pre-training script requires an input and output directory, as well as a vocab file. Note that max_seq_length will need to match the sequence length parameter you specify when you run pre-training.
135
+
136
+ Example shell script to call create_pretraining_data.py
137
+ ```
138
+ export WORKING_DIR='local disk or cloud location'
139
+ export BERT_DIR='local disk or cloud location'
140
+ python models/official/nlp/data/create_pretraining_data.py \
141
+ --input_file=$WORKING_DIR/input/input.txt \
142
+ --output_file=$WORKING_DIR/output/tf_examples.tfrecord \
143
+ --vocab_file=$BERT_DIR/wwm_uncased_L-24_H-1024_A-16/vocab.txt \
144
+ --do_lower_case=True \
145
+ --max_seq_length=512 \
146
+ --max_predictions_per_seq=76 \
147
+ --masked_lm_prob=0.15 \
148
+ --random_seed=12345 \
149
+ --dupe_factor=5
150
+ ```
151
+
152
+ ### Fine-tuning
153
+
154
+ To prepare the fine-tuning data for final model training, use the
155
+ [`../data/create_finetuning_data.py`](../data/create_finetuning_data.py) script.
156
+ Resulting datasets in `tf_record` format and training meta data should be later
157
+ passed to training or evaluation scripts. The task-specific arguments are
158
+ described in the following sections:
159
+
160
+ * GLUE
161
+
162
+ Users can download the
163
+ [GLUE data](https://gluebenchmark.com/tasks) by running
164
+ [this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
165
+ and unpack it to some directory `$GLUE_DIR`.
166
+ Also, users can download [Pretrained Checkpoint](#access-to-pretrained-checkpoints) and locate it on some directory `$BERT_DIR` instead of using checkpoints on Google Cloud Storage.
167
+
168
+ ```shell
169
+ export GLUE_DIR=~/glue
170
+ export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
171
+
172
+ export TASK_NAME=MNLI
173
+ export OUTPUT_DIR=gs://some_bucket/datasets
174
+ python ../data/create_finetuning_data.py \
175
+ --input_data_dir=${GLUE_DIR}/${TASK_NAME}/ \
176
+ --vocab_file=${BERT_DIR}/vocab.txt \
177
+ --train_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_train.tf_record \
178
+ --eval_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_eval.tf_record \
179
+ --meta_data_file_path=${OUTPUT_DIR}/${TASK_NAME}_meta_data \
180
+ --fine_tuning_task_type=classification --max_seq_length=128 \
181
+ --classification_task_name=${TASK_NAME}
182
+ ```
183
+
184
+ * SQUAD
185
+
186
+ The [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/) contains
187
+ detailed information about the SQuAD datasets and evaluation.
188
+
189
+ The necessary files can be found here:
190
+
191
+ * [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
192
+ * [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
193
+ * [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
194
+ * [train-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json)
195
+ * [dev-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json)
196
+ * [evaluate-v2.0.py](https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/)
197
+
198
+ ```shell
199
+ export SQUAD_DIR=~/squad
200
+ export SQUAD_VERSION=v1.1
201
+ export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
202
+ export OUTPUT_DIR=gs://some_bucket/datasets
203
+
204
+ python ../data/create_finetuning_data.py \
205
+ --squad_data_file=${SQUAD_DIR}/train-${SQUAD_VERSION}.json \
206
+ --vocab_file=${BERT_DIR}/vocab.txt \
207
+ --train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
208
+ --meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \
209
+ --fine_tuning_task_type=squad --max_seq_length=384
210
+ ```
211
+
212
+ Note: To create fine-tuning data with SQUAD 2.0, you need to add flag `--version_2_with_negative=True`.
213
+
214
+ ## Fine-tuning with BERT
215
+
216
+ ### Cloud GPUs and TPUs
217
+
218
+ * Cloud Storage
219
+
220
+ The unzipped pre-trained model files can also be found in the Google Cloud
221
+ Storage folder `gs://cloud-tpu-checkpoints/bert/keras_bert`. For example:
222
+
223
+ ```shell
224
+ export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
225
+ export MODEL_DIR=gs://some_bucket/my_output_dir
226
+ ```
227
+
228
+ Currently, users are able to access to `tf-nightly` TPUs and the following TPU
229
+ script should run with `tf-nightly`.
230
+
231
+ * GPU -> TPU
232
+
233
+ Just add the following flags to `run_classifier.py` or `run_squad.py`:
234
+
235
+ ```shell
236
+ --distribution_strategy=tpu
237
+ --tpu=grpc://${TPU_IP_ADDRESS}:8470
238
+ ```
239
+
240
+ ### Sentence and Sentence-pair Classification Tasks
241
+
242
+ This example code fine-tunes `BERT-Large` on the Microsoft Research Paraphrase
243
+ Corpus (MRPC) corpus, which only contains 3,600 examples and can fine-tune in a
244
+ few minutes on most GPUs.
245
+
246
+ We use the `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
247
+ workflow.
248
+ For GPU memory of 16GB or smaller, you may try to use `BERT-Base`
249
+ (uncased_L-12_H-768_A-12).
250
+
251
+ ```shell
252
+ export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
253
+ export MODEL_DIR=gs://some_bucket/my_output_dir
254
+ export GLUE_DIR=gs://some_bucket/datasets
255
+ export TASK=MRPC
256
+
257
+ python run_classifier.py \
258
+ --mode='train_and_eval' \
259
+ --input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
260
+ --train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
261
+ --eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
262
+ --bert_config_file=${BERT_DIR}/bert_config.json \
263
+ --init_checkpoint=${BERT_DIR}/bert_model.ckpt \
264
+ --train_batch_size=4 \
265
+ --eval_batch_size=4 \
266
+ --steps_per_loop=1 \
267
+ --learning_rate=2e-5 \
268
+ --num_train_epochs=3 \
269
+ --model_dir=${MODEL_DIR} \
270
+ --distribution_strategy=mirrored
271
+ ```
272
+
273
+ Alternatively, instead of specifying `init_checkpoint`, you can specify
274
+ `hub_module_url` to employ a pre-trained BERT hub module, e.g.,
275
+ ` --hub_module_url=https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/1`.
276
+
277
+ After training a model, to get predictions from the classifier, you can set the
278
+ `--mode=predict` and offer the test set tfrecords to `--eval_data_path`.
279
+ The output will be created in file called test_results.tsv in the output folder.
280
+ Each line will contain output for each sample, columns are the class
281
+ probabilities.
282
+
283
+ ```shell
284
+ python run_classifier.py \
285
+ --mode='predict' \
286
+ --input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
287
+ --eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
288
+ --bert_config_file=${BERT_DIR}/bert_config.json \
289
+ --eval_batch_size=4 \
290
+ --model_dir=${MODEL_DIR} \
291
+ --distribution_strategy=mirrored
292
+ ```
293
+
294
+ To use TPU, you only need to switch the distribution strategy type to `tpu` with TPU
295
+ information and use remote storage for model checkpoints.
296
+
297
+ ```shell
298
+ export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
299
+ export TPU_IP_ADDRESS='???'
300
+ export MODEL_DIR=gs://some_bucket/my_output_dir
301
+ export GLUE_DIR=gs://some_bucket/datasets
302
+ export TASK=MRPC
303
+
304
+ python run_classifier.py \
305
+ --mode='train_and_eval' \
306
+ --input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
307
+ --train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
308
+ --eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
309
+ --bert_config_file=${BERT_DIR}/bert_config.json \
310
+ --init_checkpoint=${BERT_DIR}/bert_model.ckpt \
311
+ --train_batch_size=32 \
312
+ --eval_batch_size=32 \
313
+ --steps_per_loop=1000 \
314
+ --learning_rate=2e-5 \
315
+ --num_train_epochs=3 \
316
+ --model_dir=${MODEL_DIR} \
317
+ --distribution_strategy=tpu \
318
+ --tpu=grpc://${TPU_IP_ADDRESS}:8470
319
+ ```
320
+
321
+ Note that, we specify `steps_per_loop=1000` for TPU, because running a loop of
322
+ training steps inside a `tf.function` can significantly increase TPU utilization
323
+ and callbacks will not be called inside the loop.
324
+
325
+ ### SQuAD 1.1
326
+
327
+ The Stanford Question Answering Dataset (SQuAD) is a popular question answering
328
+ benchmark dataset. See more on [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/).
329
+
330
+ We use the `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
331
+ workflow.
332
+ For GPU memory of 16GB or smaller, you may try to use `BERT-Base`
333
+ (uncased_L-12_H-768_A-12).
334
+
335
+ ```shell
336
+ export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
337
+ export SQUAD_DIR=gs://some_bucket/datasets
338
+ export MODEL_DIR=gs://some_bucket/my_output_dir
339
+ export SQUAD_VERSION=v1.1
340
+
341
+ python run_squad.py \
342
+ --input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
343
+ --train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
344
+ --predict_file=${SQUAD_DIR}/dev-v1.1.json \
345
+ --vocab_file=${BERT_DIR}/vocab.txt \
346
+ --bert_config_file=${BERT_DIR}/bert_config.json \
347
+ --init_checkpoint=${BERT_DIR}/bert_model.ckpt \
348
+ --train_batch_size=4 \
349
+ --predict_batch_size=4 \
350
+ --learning_rate=8e-5 \
351
+ --num_train_epochs=2 \
352
+ --model_dir=${MODEL_DIR} \
353
+ --distribution_strategy=mirrored
354
+ ```
355
+
356
+ Similarly, you can replace `init_checkpoint` FLAG with `hub_module_url` to
357
+ specify a hub module path.
358
+
359
+ `run_squad.py` writes the prediction for `--predict_file` by default. If you set
360
+ the `--model=predict` and offer the SQuAD test data, the scripts will generate
361
+ the prediction json file.
362
+
363
+ To use TPU, you need to switch the distribution strategy type to `tpu` with TPU
364
+ information.
365
+
366
+ ```shell
367
+ export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
368
+ export TPU_IP_ADDRESS='???'
369
+ export MODEL_DIR=gs://some_bucket/my_output_dir
370
+ export SQUAD_DIR=gs://some_bucket/datasets
371
+ export SQUAD_VERSION=v1.1
372
+
373
+ python run_squad.py \
374
+ --input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
375
+ --train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
376
+ --predict_file=${SQUAD_DIR}/dev-v1.1.json \
377
+ --vocab_file=${BERT_DIR}/vocab.txt \
378
+ --bert_config_file=${BERT_DIR}/bert_config.json \
379
+ --init_checkpoint=${BERT_DIR}/bert_model.ckpt \
380
+ --train_batch_size=32 \
381
+ --learning_rate=8e-5 \
382
+ --num_train_epochs=2 \
383
+ --model_dir=${MODEL_DIR} \
384
+ --distribution_strategy=tpu \
385
+ --tpu=grpc://${TPU_IP_ADDRESS}:8470
386
+ ```
387
+
388
+ The dev set predictions will be saved into a file called predictions.json in the
389
+ model_dir:
390
+
391
+ ```shell
392
+ python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ./squad/predictions.json
393
+ ```
394
+
395
+
modeling/official/legacy/bert/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
modeling/official/legacy/bert/bert_cloud_tpu.md ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # BERT FineTuning with Cloud TPU: Sentence and Sentence-Pair Classification Tasks (TF 2.1)
2
+ This tutorial shows you how to train the Bidirectional Encoder Representations from Transformers (BERT) model on Cloud TPU.
3
+
4
+
5
+ ## Set up Cloud Storage and Compute Engine VM
6
+ 1. [Open a cloud shell window](https://console.cloud.google.com/?cloudshell=true&_ga=2.11844148.-1612541229.1552429951)
7
+ 2. Create a variable for the project's id:
8
+ ```
9
+ export PROJECT_ID=your-project_id
10
+ ```
11
+ 3. Configure `gcloud` command-line tool to use the project where you want to create Cloud TPU.
12
+ ```
13
+ gcloud config set project ${PROJECT_ID}
14
+ ```
15
+ 4. Create a Cloud Storage bucket using the following command:
16
+ ```
17
+ gsutil mb -p ${PROJECT_ID} -c standard -l europe-west4 -b on gs://your-bucket-name
18
+ ```
19
+ This Cloud Storage bucket stores the data you use to train your model and the training results.
20
+ 5. Launch a Compute Engine VM and Cloud TPU using the ctpu up command.
21
+ ```
22
+ ctpu up --tpu-size=v3-8 \
23
+ --machine-type=n1-standard-8 \
24
+ --zone=europe-west4-a \
25
+ --tf-version=2.1 [optional flags: --project, --name]
26
+ ```
27
+ 6. The configuration you specified appears. Enter y to approve or n to cancel.
28
+ 7. When the ctpu up command has finished executing, verify that your shell prompt has changed from username@project to username@tpuname. This change shows that you are now logged into your Compute Engine VM.
29
+ ```
30
+ gcloud compute ssh vm-name --zone=europe-west4-a
31
+ (vm)$ export TPU_NAME=vm-name
32
+ ```
33
+ As you continue these instructions, run each command that begins with `(vm)$` in your VM session window.
34
+
35
+ ## Prepare the Dataset
36
+ 1. From your Compute Engine virtual machine (VM), install requirements.txt.
37
+ ```
38
+ (vm)$ cd /usr/share/models
39
+ (vm)$ sudo pip3 install -r official/requirements.txt
40
+ ```
41
+ 2. Optional: download download_glue_data.py
42
+
43
+ This tutorial uses the General Language Understanding Evaluation (GLUE) benchmark to evaluate and analyze the performance of the model. The GLUE data is provided for this tutorial at gs://cloud-tpu-checkpoints/bert/classification.
44
+
45
+ ## Define parameter values
46
+ Next, define several parameter values that are required when you train and evaluate your model:
47
+
48
+ ```
49
+ (vm)$ export PYTHONPATH="$PYTHONPATH:/usr/share/tpu/models"
50
+ (vm)$ export STORAGE_BUCKET=gs://your-bucket-name
51
+ (vm)$ export BERT_BASE_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
52
+ (vm)$ export MODEL_DIR=${STORAGE_BUCKET}/bert-output
53
+ (vm)$ export GLUE_DIR=gs://cloud-tpu-checkpoints/bert/classification
54
+ (vm)$ export TASK=mnli
55
+ ```
56
+
57
+ ## Train the model
58
+ From your Compute Engine VM, run the following command.
59
+
60
+ ```
61
+ (vm)$ python3 official/nlp/bert/run_classifier.py \
62
+ --mode='train_and_eval' \
63
+ --input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
64
+ --train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
65
+ --eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
66
+ --bert_config_file=$BERT_BASE_DIR/bert_config.json \
67
+ --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
68
+ --train_batch_size=32 \
69
+ --eval_batch_size=32 \
70
+ --learning_rate=2e-5 \
71
+ --num_train_epochs=3 \
72
+ --model_dir=${MODEL_DIR} \
73
+ --distribution_strategy=tpu \
74
+ --tpu=${TPU_NAME}
75
+ ```
76
+
77
+ ## Verify your results
78
+ The training takes approximately 1 hour on a v3-8 TPU. When script completes, you should see results similar to the following:
79
+ ```
80
+ Training Summary:
81
+ {'train_loss': 0.28142181038856506,
82
+ 'last_train_metrics': 0.9467429518699646,
83
+ 'eval_metrics': 0.8599063158035278,
84
+ 'total_training_steps': 36813}
85
+ ```
86
+
87
+ ## Clean up
88
+ To avoid incurring charges to your GCP account for the resources used in this topic:
89
+ 1. Disconnect from the Compute Engine VM:
90
+ ```
91
+ (vm)$ exit
92
+ ```
93
+ 2. In your Cloud Shell, run ctpu delete with the --zone flag you used when you set up the Cloud TPU to delete your Compute Engine VM and your Cloud TPU:
94
+ ```
95
+ $ ctpu delete --zone=your-zone
96
+ ```
97
+ 3. Run ctpu status specifying your zone to make sure you have no instances allocated to avoid unnecessary charges for TPU usage. The deletion might take several minutes. A response like the one below indicates there are no more allocated instances:
98
+ ```
99
+ $ ctpu status --zone=your-zone
100
+ ```
101
+ 4. Run gsutil as shown, replacing your-bucket with the name of the Cloud Storage bucket you created for this tutorial:
102
+ ```
103
+ $ gsutil rm -r gs://your-bucket
104
+ ```
105
+
106
+
107
+
108
+
109
+
110
+
modeling/official/legacy/bert/bert_models.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """BERT models that are compatible with TF 2.0."""
16
+
17
+ import gin
18
+ import tensorflow as tf, tf_keras
19
+ import tensorflow_hub as hub
20
+ from official.legacy.albert import configs as albert_configs
21
+ from official.legacy.bert import configs
22
+ from official.modeling import tf_utils
23
+ from official.nlp.modeling import models
24
+ from official.nlp.modeling import networks
25
+
26
+
27
+ class BertPretrainLossAndMetricLayer(tf_keras.layers.Layer):
28
+ """Returns layer that computes custom loss and metrics for pretraining."""
29
+
30
+ def __init__(self, vocab_size, **kwargs):
31
+ super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
32
+ self._vocab_size = vocab_size
33
+ self.config = {
34
+ 'vocab_size': vocab_size,
35
+ }
36
+
37
+ def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
38
+ lm_example_loss, sentence_output, sentence_labels,
39
+ next_sentence_loss):
40
+ """Adds metrics."""
41
+ masked_lm_accuracy = tf_keras.metrics.sparse_categorical_accuracy(
42
+ lm_labels, lm_output)
43
+ numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
44
+ denominator = tf.reduce_sum(lm_label_weights) + 1e-5
45
+ masked_lm_accuracy = numerator / denominator
46
+ self.add_metric(
47
+ masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')
48
+
49
+ self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')
50
+
51
+ if sentence_labels is not None:
52
+ next_sentence_accuracy = tf_keras.metrics.sparse_categorical_accuracy(
53
+ sentence_labels, sentence_output)
54
+ self.add_metric(
55
+ next_sentence_accuracy,
56
+ name='next_sentence_accuracy',
57
+ aggregation='mean')
58
+
59
+ if next_sentence_loss is not None:
60
+ self.add_metric(
61
+ next_sentence_loss, name='next_sentence_loss', aggregation='mean')
62
+
63
+ def call(self,
64
+ lm_output_logits,
65
+ sentence_output_logits,
66
+ lm_label_ids,
67
+ lm_label_weights,
68
+ sentence_labels=None):
69
+ """Implements call() for the layer."""
70
+ lm_label_weights = tf.cast(lm_label_weights, tf.float32)
71
+ lm_output_logits = tf.cast(lm_output_logits, tf.float32)
72
+
73
+ lm_prediction_losses = tf_keras.losses.sparse_categorical_crossentropy(
74
+ lm_label_ids, lm_output_logits, from_logits=True)
75
+ lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
76
+ lm_denominator_loss = tf.reduce_sum(lm_label_weights)
77
+ mask_label_loss = tf.math.divide_no_nan(lm_numerator_loss,
78
+ lm_denominator_loss)
79
+
80
+ if sentence_labels is not None:
81
+ sentence_output_logits = tf.cast(sentence_output_logits, tf.float32)
82
+ sentence_loss = tf_keras.losses.sparse_categorical_crossentropy(
83
+ sentence_labels, sentence_output_logits, from_logits=True)
84
+ sentence_loss = tf.reduce_mean(sentence_loss)
85
+ loss = mask_label_loss + sentence_loss
86
+ else:
87
+ sentence_loss = None
88
+ loss = mask_label_loss
89
+
90
+ batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
91
+ # TODO(hongkuny): Avoids the hack and switches add_loss.
92
+ final_loss = tf.fill(batch_shape, loss)
93
+
94
+ self._add_metrics(lm_output_logits, lm_label_ids, lm_label_weights,
95
+ mask_label_loss, sentence_output_logits, sentence_labels,
96
+ sentence_loss)
97
+ return final_loss
98
+
99
+
100
+ @gin.configurable
101
+ def get_transformer_encoder(bert_config,
102
+ sequence_length=None,
103
+ transformer_encoder_cls=None,
104
+ output_range=None):
105
+ """Gets a 'TransformerEncoder' object.
106
+
107
+ Args:
108
+ bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
109
+ sequence_length: [Deprecated].
110
+ transformer_encoder_cls: A EncoderScaffold class. If it is None, uses the
111
+ default BERT encoder implementation.
112
+ output_range: the sequence output range, [0, output_range). Default setting
113
+ is to return the entire sequence output.
114
+
115
+ Returns:
116
+ A encoder object.
117
+ """
118
+ del sequence_length
119
+ if transformer_encoder_cls is not None:
120
+ # TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
121
+ embedding_cfg = dict(
122
+ vocab_size=bert_config.vocab_size,
123
+ type_vocab_size=bert_config.type_vocab_size,
124
+ hidden_size=bert_config.hidden_size,
125
+ max_seq_length=bert_config.max_position_embeddings,
126
+ initializer=tf_keras.initializers.TruncatedNormal(
127
+ stddev=bert_config.initializer_range),
128
+ dropout_rate=bert_config.hidden_dropout_prob,
129
+ )
130
+ hidden_cfg = dict(
131
+ num_attention_heads=bert_config.num_attention_heads,
132
+ intermediate_size=bert_config.intermediate_size,
133
+ intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
134
+ dropout_rate=bert_config.hidden_dropout_prob,
135
+ attention_dropout_rate=bert_config.attention_probs_dropout_prob,
136
+ kernel_initializer=tf_keras.initializers.TruncatedNormal(
137
+ stddev=bert_config.initializer_range),
138
+ )
139
+ kwargs = dict(
140
+ embedding_cfg=embedding_cfg,
141
+ hidden_cfg=hidden_cfg,
142
+ num_hidden_instances=bert_config.num_hidden_layers,
143
+ pooled_output_dim=bert_config.hidden_size,
144
+ pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
145
+ stddev=bert_config.initializer_range))
146
+
147
+ # Relies on gin configuration to define the Transformer encoder arguments.
148
+ return transformer_encoder_cls(**kwargs)
149
+
150
+ kwargs = dict(
151
+ vocab_size=bert_config.vocab_size,
152
+ hidden_size=bert_config.hidden_size,
153
+ num_layers=bert_config.num_hidden_layers,
154
+ num_attention_heads=bert_config.num_attention_heads,
155
+ intermediate_size=bert_config.intermediate_size,
156
+ activation=tf_utils.get_activation(bert_config.hidden_act),
157
+ dropout_rate=bert_config.hidden_dropout_prob,
158
+ attention_dropout_rate=bert_config.attention_probs_dropout_prob,
159
+ max_sequence_length=bert_config.max_position_embeddings,
160
+ type_vocab_size=bert_config.type_vocab_size,
161
+ embedding_width=bert_config.embedding_size,
162
+ initializer=tf_keras.initializers.TruncatedNormal(
163
+ stddev=bert_config.initializer_range))
164
+ if isinstance(bert_config, albert_configs.AlbertConfig):
165
+ return networks.AlbertEncoder(**kwargs)
166
+ else:
167
+ assert isinstance(bert_config, configs.BertConfig)
168
+ kwargs['output_range'] = output_range
169
+ return networks.BertEncoder(**kwargs)
170
+
171
+
172
+ def pretrain_model(bert_config,
173
+ seq_length,
174
+ max_predictions_per_seq,
175
+ initializer=None,
176
+ use_next_sentence_label=True,
177
+ return_core_pretrainer_model=False):
178
+ """Returns model to be used for pre-training.
179
+
180
+ Args:
181
+ bert_config: Configuration that defines the core BERT model.
182
+ seq_length: Maximum sequence length of the training data.
183
+ max_predictions_per_seq: Maximum number of tokens in sequence to mask out
184
+ and use for pretraining.
185
+ initializer: Initializer for weights in BertPretrainer.
186
+ use_next_sentence_label: Whether to use the next sentence label.
187
+ return_core_pretrainer_model: Whether to also return the `BertPretrainer`
188
+ object.
189
+
190
+ Returns:
191
+ A Tuple of (1) Pretraining model, (2) core BERT submodel from which to
192
+ save weights after pretraining, and (3) optional core `BertPretrainer`
193
+ object if argument `return_core_pretrainer_model` is True.
194
+ """
195
+ input_word_ids = tf_keras.layers.Input(
196
+ shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
197
+ input_mask = tf_keras.layers.Input(
198
+ shape=(seq_length,), name='input_mask', dtype=tf.int32)
199
+ input_type_ids = tf_keras.layers.Input(
200
+ shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
201
+ masked_lm_positions = tf_keras.layers.Input(
202
+ shape=(max_predictions_per_seq,),
203
+ name='masked_lm_positions',
204
+ dtype=tf.int32)
205
+ masked_lm_ids = tf_keras.layers.Input(
206
+ shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
207
+ masked_lm_weights = tf_keras.layers.Input(
208
+ shape=(max_predictions_per_seq,),
209
+ name='masked_lm_weights',
210
+ dtype=tf.int32)
211
+
212
+ if use_next_sentence_label:
213
+ next_sentence_labels = tf_keras.layers.Input(
214
+ shape=(1,), name='next_sentence_labels', dtype=tf.int32)
215
+ else:
216
+ next_sentence_labels = None
217
+
218
+ transformer_encoder = get_transformer_encoder(bert_config, seq_length)
219
+ if initializer is None:
220
+ initializer = tf_keras.initializers.TruncatedNormal(
221
+ stddev=bert_config.initializer_range)
222
+ pretrainer_model = models.BertPretrainer(
223
+ network=transformer_encoder,
224
+ embedding_table=transformer_encoder.get_embedding_table(),
225
+ num_classes=2, # The next sentence prediction label has two classes.
226
+ activation=tf_utils.get_activation(bert_config.hidden_act),
227
+ num_token_predictions=max_predictions_per_seq,
228
+ initializer=initializer,
229
+ output='logits')
230
+
231
+ outputs = pretrainer_model(
232
+ [input_word_ids, input_mask, input_type_ids, masked_lm_positions])
233
+ lm_output = outputs['masked_lm']
234
+ sentence_output = outputs['classification']
235
+ pretrain_loss_layer = BertPretrainLossAndMetricLayer(
236
+ vocab_size=bert_config.vocab_size)
237
+ output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
238
+ masked_lm_weights, next_sentence_labels)
239
+ inputs = {
240
+ 'input_word_ids': input_word_ids,
241
+ 'input_mask': input_mask,
242
+ 'input_type_ids': input_type_ids,
243
+ 'masked_lm_positions': masked_lm_positions,
244
+ 'masked_lm_ids': masked_lm_ids,
245
+ 'masked_lm_weights': masked_lm_weights,
246
+ }
247
+ if use_next_sentence_label:
248
+ inputs['next_sentence_labels'] = next_sentence_labels
249
+
250
+ keras_model = tf_keras.Model(inputs=inputs, outputs=output_loss)
251
+ if return_core_pretrainer_model:
252
+ return keras_model, transformer_encoder, pretrainer_model
253
+ else:
254
+ return keras_model, transformer_encoder
255
+
256
+
257
+ def squad_model(bert_config,
258
+ max_seq_length,
259
+ initializer=None,
260
+ hub_module_url=None,
261
+ hub_module_trainable=True):
262
+ """Returns BERT Squad model along with core BERT model to import weights.
263
+
264
+ Args:
265
+ bert_config: BertConfig, the config defines the core Bert model.
266
+ max_seq_length: integer, the maximum input sequence length.
267
+ initializer: Initializer for the final dense layer in the span labeler.
268
+ Defaulted to TruncatedNormal initializer.
269
+ hub_module_url: TF-Hub path/url to Bert module.
270
+ hub_module_trainable: True to finetune layers in the hub module.
271
+
272
+ Returns:
273
+ A tuple of (1) keras model that outputs start logits and end logits and
274
+ (2) the core BERT transformer encoder.
275
+ """
276
+ if initializer is None:
277
+ initializer = tf_keras.initializers.TruncatedNormal(
278
+ stddev=bert_config.initializer_range)
279
+ if not hub_module_url:
280
+ bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
281
+ return models.BertSpanLabeler(
282
+ network=bert_encoder, initializer=initializer), bert_encoder
283
+
284
+ input_word_ids = tf_keras.layers.Input(
285
+ shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
286
+ input_mask = tf_keras.layers.Input(
287
+ shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
288
+ input_type_ids = tf_keras.layers.Input(
289
+ shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
290
+ core_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
291
+ pooled_output, sequence_output = core_model(
292
+ [input_word_ids, input_mask, input_type_ids])
293
+ bert_encoder = tf_keras.Model(
294
+ inputs={
295
+ 'input_word_ids': input_word_ids,
296
+ 'input_mask': input_mask,
297
+ 'input_type_ids': input_type_ids,
298
+ },
299
+ outputs=[sequence_output, pooled_output],
300
+ name='core_model')
301
+ return models.BertSpanLabeler(
302
+ network=bert_encoder, initializer=initializer), bert_encoder
303
+
304
+
305
+ def classifier_model(bert_config,
306
+ num_labels,
307
+ max_seq_length=None,
308
+ final_layer_initializer=None,
309
+ hub_module_url=None,
310
+ hub_module_trainable=True):
311
+ """BERT classifier model in functional API style.
312
+
313
+ Construct a Keras model for predicting `num_labels` outputs from an input with
314
+ maximum sequence length `max_seq_length`.
315
+
316
+ Args:
317
+ bert_config: BertConfig or AlbertConfig, the config defines the core BERT or
318
+ ALBERT model.
319
+ num_labels: integer, the number of classes.
320
+ max_seq_length: integer, the maximum input sequence length.
321
+ final_layer_initializer: Initializer for final dense layer. Defaulted
322
+ TruncatedNormal initializer.
323
+ hub_module_url: TF-Hub path/url to Bert module.
324
+ hub_module_trainable: True to finetune layers in the hub module.
325
+
326
+ Returns:
327
+ Combined prediction model (words, mask, type) -> (one-hot labels)
328
+ BERT sub-model (words, mask, type) -> (bert_outputs)
329
+ """
330
+ if final_layer_initializer is not None:
331
+ initializer = final_layer_initializer
332
+ else:
333
+ initializer = tf_keras.initializers.TruncatedNormal(
334
+ stddev=bert_config.initializer_range)
335
+
336
+ if not hub_module_url:
337
+ bert_encoder = get_transformer_encoder(
338
+ bert_config, max_seq_length, output_range=1)
339
+ return models.BertClassifier(
340
+ bert_encoder,
341
+ num_classes=num_labels,
342
+ dropout_rate=bert_config.hidden_dropout_prob,
343
+ initializer=initializer), bert_encoder
344
+
345
+ input_word_ids = tf_keras.layers.Input(
346
+ shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
347
+ input_mask = tf_keras.layers.Input(
348
+ shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
349
+ input_type_ids = tf_keras.layers.Input(
350
+ shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
351
+ bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
352
+ pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
353
+ output = tf_keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
354
+ pooled_output)
355
+
356
+ output = tf_keras.layers.Dense(
357
+ num_labels, kernel_initializer=initializer, name='output')(
358
+ output)
359
+ return tf_keras.Model(
360
+ inputs={
361
+ 'input_word_ids': input_word_ids,
362
+ 'input_mask': input_mask,
363
+ 'input_type_ids': input_type_ids
364
+ },
365
+ outputs=output), bert_model
modeling/official/legacy/bert/bert_models_test.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import tensorflow as tf, tf_keras
16
+
17
+ from official.legacy.bert import bert_models
18
+ from official.legacy.bert import configs as bert_configs
19
+ from official.nlp.modeling import networks
20
+
21
+
22
+ class BertModelsTest(tf.test.TestCase):
23
+
24
+ def setUp(self):
25
+ super(BertModelsTest, self).setUp()
26
+ self._bert_test_config = bert_configs.BertConfig(
27
+ attention_probs_dropout_prob=0.0,
28
+ hidden_act='gelu',
29
+ hidden_dropout_prob=0.0,
30
+ hidden_size=16,
31
+ initializer_range=0.02,
32
+ intermediate_size=32,
33
+ max_position_embeddings=128,
34
+ num_attention_heads=2,
35
+ num_hidden_layers=2,
36
+ type_vocab_size=2,
37
+ vocab_size=30522)
38
+
39
+ def test_pretrain_model(self):
40
+ model, encoder = bert_models.pretrain_model(
41
+ self._bert_test_config,
42
+ seq_length=5,
43
+ max_predictions_per_seq=2,
44
+ initializer=None,
45
+ use_next_sentence_label=True)
46
+ self.assertIsInstance(model, tf_keras.Model)
47
+ self.assertIsInstance(encoder, networks.BertEncoder)
48
+
49
+ # model has one scalar output: loss value.
50
+ self.assertEqual(model.output.shape.as_list(), [
51
+ None,
52
+ ])
53
+
54
+ # Expect two output from encoder: sequence and classification output.
55
+ self.assertIsInstance(encoder.output, list)
56
+ self.assertLen(encoder.output, 2)
57
+ # shape should be [batch size, hidden_size]
58
+ self.assertEqual(encoder.output[1].shape.as_list(), [None, 16])
59
+
60
+ def test_squad_model(self):
61
+ model, core_model = bert_models.squad_model(
62
+ self._bert_test_config,
63
+ max_seq_length=5,
64
+ initializer=None,
65
+ hub_module_url=None,
66
+ hub_module_trainable=None)
67
+ self.assertIsInstance(model, tf_keras.Model)
68
+ self.assertIsInstance(core_model, tf_keras.Model)
69
+
70
+ # Expect two output from model: start positions and end positions
71
+ self.assertIsInstance(model.output, list)
72
+ self.assertLen(model.output, 2)
73
+
74
+ # Expect two output from core_model: sequence and classification output.
75
+ self.assertIsInstance(core_model.output, list)
76
+ self.assertLen(core_model.output, 2)
77
+ # shape should be [batch size, None, hidden_size]
78
+ self.assertEqual(core_model.output[0].shape.as_list(), [None, None, 16])
79
+ # shape should be [batch size, hidden_size]
80
+ self.assertEqual(core_model.output[1].shape.as_list(), [None, 16])
81
+
82
+ def test_classifier_model(self):
83
+ model, core_model = bert_models.classifier_model(
84
+ self._bert_test_config,
85
+ num_labels=3,
86
+ max_seq_length=5,
87
+ final_layer_initializer=None,
88
+ hub_module_url=None,
89
+ hub_module_trainable=None)
90
+ self.assertIsInstance(model, tf_keras.Model)
91
+ self.assertIsInstance(core_model, tf_keras.Model)
92
+
93
+ # model has one classification output with num_labels=3.
94
+ self.assertEqual(model.output.shape.as_list(), [None, 3])
95
+
96
+ # Expect two output from core_model: sequence and classification output.
97
+ self.assertIsInstance(core_model.output, list)
98
+ self.assertLen(core_model.output, 2)
99
+ # shape should be [batch size, None, hidden_size]
100
+ self.assertEqual(core_model.output[0].shape.as_list(), [None, None, 16])
101
+ # shape should be [batch size, hidden_size]
102
+ self.assertEqual(core_model.output[1].shape.as_list(), [None, 16])
103
+
104
+
105
+ if __name__ == '__main__':
106
+ tf.test.main()
modeling/official/legacy/bert/common_flags.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Defining common flags used across all BERT models/applications."""
16
+
17
+ from absl import flags
18
+ import tensorflow as tf, tf_keras
19
+
20
+ from official.utils import hyperparams_flags
21
+ from official.utils.flags import core as flags_core
22
+
23
+
24
+ def define_common_bert_flags():
25
+ """Define common flags for BERT tasks."""
26
+ flags_core.define_base(
27
+ data_dir=False,
28
+ model_dir=True,
29
+ clean=False,
30
+ train_epochs=False,
31
+ epochs_between_evals=False,
32
+ stop_threshold=False,
33
+ batch_size=False,
34
+ num_gpu=True,
35
+ export_dir=False,
36
+ distribution_strategy=True,
37
+ run_eagerly=True)
38
+ flags_core.define_distribution()
39
+ flags.DEFINE_string('bert_config_file', None,
40
+ 'Bert configuration file to define core bert layers.')
41
+ flags.DEFINE_string(
42
+ 'model_export_path', None,
43
+ 'Path to the directory, where trainined model will be '
44
+ 'exported.')
45
+ flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
46
+ flags.DEFINE_string(
47
+ 'init_checkpoint', None,
48
+ 'Initial checkpoint (usually from a pre-trained BERT model).')
49
+ flags.DEFINE_integer('num_train_epochs', 3,
50
+ 'Total number of training epochs to perform.')
51
+ flags.DEFINE_integer(
52
+ 'steps_per_loop', None,
53
+ 'Number of steps per graph-mode loop. Only training step '
54
+ 'happens inside the loop. Callbacks will not be called '
55
+ 'inside. If not set the value will be configured depending on the '
56
+ 'devices available.')
57
+ flags.DEFINE_float('learning_rate', 5e-5,
58
+ 'The initial learning rate for Adam.')
59
+ flags.DEFINE_float('end_lr', 0.0,
60
+ 'The end learning rate for learning rate decay.')
61
+ flags.DEFINE_string('optimizer_type', 'adamw',
62
+ 'The type of optimizer to use for training (adamw|lamb)')
63
+ flags.DEFINE_boolean(
64
+ 'scale_loss', False,
65
+ 'Whether to divide the loss by number of replica inside the per-replica '
66
+ 'loss function.')
67
+ flags.DEFINE_boolean(
68
+ 'use_keras_compile_fit', False,
69
+ 'If True, uses Keras compile/fit() API for training logic. Otherwise '
70
+ 'use custom training loop.')
71
+ flags.DEFINE_string(
72
+ 'hub_module_url', None, 'TF-Hub path/url to Bert module. '
73
+ 'If specified, init_checkpoint flag should not be used.')
74
+ flags.DEFINE_bool('hub_module_trainable', True,
75
+ 'True to make keras layers in the hub module trainable.')
76
+ flags.DEFINE_string(
77
+ 'sub_model_export_name', None,
78
+ 'If set, `sub_model` checkpoints are exported into '
79
+ 'FLAGS.model_dir/FLAGS.sub_model_export_name.')
80
+ flags.DEFINE_bool('explicit_allreduce', False,
81
+ 'True to use explicit allreduce instead of the implicit '
82
+ 'allreduce in optimizer.apply_gradients(). If fp16 mixed '
83
+ 'precision training is used, this also enables allreduce '
84
+ 'gradients in fp16.')
85
+ flags.DEFINE_integer('allreduce_bytes_per_pack', 0,
86
+ 'Number of bytes of a gradient pack for allreduce. '
87
+ 'Should be positive integer, if set to 0, all '
88
+ 'gradients are in one pack. Breaking gradient into '
89
+ 'packs could enable overlap between allreduce and '
90
+ 'backprop computation. This flag only takes effect '
91
+ 'when explicit_allreduce is set to True.')
92
+
93
+ flags_core.define_log_steps()
94
+
95
+ # Adds flags for mixed precision and multi-worker training.
96
+ flags_core.define_performance(
97
+ num_parallel_calls=False,
98
+ inter_op=False,
99
+ intra_op=False,
100
+ synthetic_data=False,
101
+ max_train_steps=False,
102
+ dtype=True,
103
+ loss_scale=True,
104
+ all_reduce_alg=True,
105
+ num_packs=False,
106
+ tf_gpu_thread_mode=True,
107
+ datasets_num_private_threads=True,
108
+ enable_xla=True,
109
+ fp16_implementation=True,
110
+ )
111
+
112
+ # Adds gin configuration flags.
113
+ hyperparams_flags.define_gin_flags()
114
+
115
+
116
+ def dtype():
117
+ return flags_core.get_tf_dtype(flags.FLAGS)
118
+
119
+
120
+ def use_float16():
121
+ return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16
122
+
123
+
124
+ def get_loss_scale():
125
+ return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
modeling/official/legacy/bert/configs.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """The main BERT model and related functions."""
16
+
17
+ import copy
18
+ import json
19
+
20
+ import six
21
+ import tensorflow as tf, tf_keras
22
+
23
+
24
+ class BertConfig(object):
25
+ """Configuration for `BertModel`."""
26
+
27
+ def __init__(self,
28
+ vocab_size,
29
+ hidden_size=768,
30
+ num_hidden_layers=12,
31
+ num_attention_heads=12,
32
+ intermediate_size=3072,
33
+ hidden_act="gelu",
34
+ hidden_dropout_prob=0.1,
35
+ attention_probs_dropout_prob=0.1,
36
+ max_position_embeddings=512,
37
+ type_vocab_size=16,
38
+ initializer_range=0.02,
39
+ embedding_size=None,
40
+ backward_compatible=True):
41
+ """Constructs BertConfig.
42
+
43
+ Args:
44
+ vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
45
+ hidden_size: Size of the encoder layers and the pooler layer.
46
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
47
+ num_attention_heads: Number of attention heads for each attention layer in
48
+ the Transformer encoder.
49
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
50
+ layer in the Transformer encoder.
51
+ hidden_act: The non-linear activation function (function or string) in the
52
+ encoder and pooler.
53
+ hidden_dropout_prob: The dropout probability for all fully connected
54
+ layers in the embeddings, encoder, and pooler.
55
+ attention_probs_dropout_prob: The dropout ratio for the attention
56
+ probabilities.
57
+ max_position_embeddings: The maximum sequence length that this model might
58
+ ever be used with. Typically set this to something large just in case
59
+ (e.g., 512 or 1024 or 2048).
60
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
61
+ `BertModel`.
62
+ initializer_range: The stdev of the truncated_normal_initializer for
63
+ initializing all weight matrices.
64
+ embedding_size: (Optional) width of the factorized word embeddings.
65
+ backward_compatible: Boolean, whether the variables shape are compatible
66
+ with checkpoints converted from TF 1.x BERT.
67
+ """
68
+ self.vocab_size = vocab_size
69
+ self.hidden_size = hidden_size
70
+ self.num_hidden_layers = num_hidden_layers
71
+ self.num_attention_heads = num_attention_heads
72
+ self.hidden_act = hidden_act
73
+ self.intermediate_size = intermediate_size
74
+ self.hidden_dropout_prob = hidden_dropout_prob
75
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
76
+ self.max_position_embeddings = max_position_embeddings
77
+ self.type_vocab_size = type_vocab_size
78
+ self.initializer_range = initializer_range
79
+ self.embedding_size = embedding_size
80
+ self.backward_compatible = backward_compatible
81
+
82
+ @classmethod
83
+ def from_dict(cls, json_object):
84
+ """Constructs a `BertConfig` from a Python dictionary of parameters."""
85
+ config = BertConfig(vocab_size=None)
86
+ for (key, value) in six.iteritems(json_object):
87
+ config.__dict__[key] = value
88
+ return config
89
+
90
+ @classmethod
91
+ def from_json_file(cls, json_file):
92
+ """Constructs a `BertConfig` from a json file of parameters."""
93
+ with tf.io.gfile.GFile(json_file, "r") as reader:
94
+ text = reader.read()
95
+ return cls.from_dict(json.loads(text))
96
+
97
+ def to_dict(self):
98
+ """Serializes this instance to a Python dictionary."""
99
+ output = copy.deepcopy(self.__dict__)
100
+ return output
101
+
102
+ def to_json_string(self):
103
+ """Serializes this instance to a JSON string."""
104
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
modeling/official/legacy/bert/export_tfhub.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """A script to export BERT as a TF-Hub SavedModel.
16
+
17
+ This script is **DEPRECATED** for exporting BERT encoder models;
18
+ see the error message in by main() for details.
19
+ """
20
+
21
+ from typing import Text
22
+
23
+ # Import libraries
24
+ from absl import app
25
+ from absl import flags
26
+ from absl import logging
27
+ import tensorflow as tf, tf_keras
28
+ from official.legacy.bert import bert_models
29
+ from official.legacy.bert import configs
30
+
31
+ FLAGS = flags.FLAGS
32
+
33
+ flags.DEFINE_string("bert_config_file", None,
34
+ "Bert configuration file to define core bert layers.")
35
+ flags.DEFINE_string("model_checkpoint_path", None,
36
+ "File path to TF model checkpoint.")
37
+ flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
38
+ flags.DEFINE_string("vocab_file", None,
39
+ "The vocabulary file that the BERT model was trained on.")
40
+ flags.DEFINE_bool(
41
+ "do_lower_case", None, "Whether to lowercase. If None, "
42
+ "do_lower_case will be enabled if 'uncased' appears in the "
43
+ "name of --vocab_file")
44
+ flags.DEFINE_enum("model_type", "encoder", ["encoder", "squad"],
45
+ "What kind of BERT model to export.")
46
+
47
+
48
+ def create_bert_model(bert_config: configs.BertConfig) -> tf_keras.Model:
49
+ """Creates a BERT keras core model from BERT configuration.
50
+
51
+ Args:
52
+ bert_config: A `BertConfig` to create the core model.
53
+
54
+ Returns:
55
+ A keras model.
56
+ """
57
+ # Adds input layers just as placeholders.
58
+ input_word_ids = tf_keras.layers.Input(
59
+ shape=(None,), dtype=tf.int32, name="input_word_ids")
60
+ input_mask = tf_keras.layers.Input(
61
+ shape=(None,), dtype=tf.int32, name="input_mask")
62
+ input_type_ids = tf_keras.layers.Input(
63
+ shape=(None,), dtype=tf.int32, name="input_type_ids")
64
+ transformer_encoder = bert_models.get_transformer_encoder(
65
+ bert_config, sequence_length=None)
66
+ sequence_output, pooled_output = transformer_encoder(
67
+ [input_word_ids, input_mask, input_type_ids])
68
+ # To keep consistent with legacy hub modules, the outputs are
69
+ # "pooled_output" and "sequence_output".
70
+ return tf_keras.Model(
71
+ inputs=[input_word_ids, input_mask, input_type_ids],
72
+ outputs=[pooled_output, sequence_output]), transformer_encoder
73
+
74
+
75
+ def export_bert_tfhub(bert_config: configs.BertConfig,
76
+ model_checkpoint_path: Text,
77
+ hub_destination: Text,
78
+ vocab_file: Text,
79
+ do_lower_case: bool = None):
80
+ """Restores a tf_keras.Model and saves for TF-Hub."""
81
+ # If do_lower_case is not explicit, default to checking whether "uncased" is
82
+ # in the vocab file name
83
+ if do_lower_case is None:
84
+ do_lower_case = "uncased" in vocab_file
85
+ logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
86
+ do_lower_case, vocab_file)
87
+ core_model, encoder = create_bert_model(bert_config)
88
+ checkpoint = tf.train.Checkpoint(
89
+ model=encoder, # Legacy checkpoints.
90
+ encoder=encoder)
91
+ checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
92
+ core_model.vocab_file = tf.saved_model.Asset(vocab_file)
93
+ core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
94
+ core_model.save(hub_destination, include_optimizer=False, save_format="tf")
95
+
96
+
97
+ def export_bert_squad_tfhub(bert_config: configs.BertConfig,
98
+ model_checkpoint_path: Text,
99
+ hub_destination: Text,
100
+ vocab_file: Text,
101
+ do_lower_case: bool = None):
102
+ """Restores a tf_keras.Model for BERT with SQuAD and saves for TF-Hub."""
103
+ # If do_lower_case is not explicit, default to checking whether "uncased" is
104
+ # in the vocab file name
105
+ if do_lower_case is None:
106
+ do_lower_case = "uncased" in vocab_file
107
+ logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
108
+ do_lower_case, vocab_file)
109
+ span_labeling, _ = bert_models.squad_model(bert_config, max_seq_length=None)
110
+ checkpoint = tf.train.Checkpoint(model=span_labeling)
111
+ checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
112
+ span_labeling.vocab_file = tf.saved_model.Asset(vocab_file)
113
+ span_labeling.do_lower_case = tf.Variable(do_lower_case, trainable=False)
114
+ span_labeling.save(hub_destination, include_optimizer=False, save_format="tf")
115
+
116
+
117
+ def main(_):
118
+ bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
119
+ if FLAGS.model_type == "encoder":
120
+ deprecation_note = (
121
+ "nlp/bert/export_tfhub is **DEPRECATED** for exporting BERT encoder "
122
+ "models. Please switch to nlp/tools/export_tfhub for exporting BERT "
123
+ "(and other) encoders with dict inputs/outputs conforming to "
124
+ "https://www.tensorflow.org/hub/common_saved_model_apis/text#transformer-encoders"
125
+ )
126
+ logging.error(deprecation_note)
127
+ print("\n\nNOTICE:", deprecation_note, "\n")
128
+ export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path,
129
+ FLAGS.export_path, FLAGS.vocab_file, FLAGS.do_lower_case)
130
+ elif FLAGS.model_type == "squad":
131
+ export_bert_squad_tfhub(bert_config, FLAGS.model_checkpoint_path,
132
+ FLAGS.export_path, FLAGS.vocab_file,
133
+ FLAGS.do_lower_case)
134
+ else:
135
+ raise ValueError("Unsupported model_type %s." % FLAGS.model_type)
136
+
137
+
138
+ if __name__ == "__main__":
139
+ app.run(main)