Spaces:
Runtime error
Runtime error
deanna-emery
commited on
Commit
•
5672777
1
Parent(s):
9e6df20
updates
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- modeling/official/README-TPU.md +32 -0
- modeling/official/README.md +166 -0
- modeling/official/__init__.py +14 -0
- modeling/official/common/__init__.py +15 -0
- modeling/official/common/dataset_fn.py +44 -0
- modeling/official/common/distribute_utils.py +233 -0
- modeling/official/common/distribute_utils_test.py +124 -0
- modeling/official/common/flags.py +114 -0
- modeling/official/common/registry_imports.py +20 -0
- modeling/official/common/streamz_counters.py +27 -0
- modeling/official/core/__init__.py +31 -0
- modeling/official/core/actions.py +236 -0
- modeling/official/core/actions_test.py +131 -0
- modeling/official/core/base_task.py +360 -0
- modeling/official/core/base_trainer.py +498 -0
- modeling/official/core/base_trainer_test.py +363 -0
- modeling/official/core/config_definitions.py +309 -0
- modeling/official/core/exp_factory.py +32 -0
- modeling/official/core/export_base.py +182 -0
- modeling/official/core/export_base_test.py +133 -0
- modeling/official/core/file_writers.py +80 -0
- modeling/official/core/file_writers_test.py +53 -0
- modeling/official/core/input_reader.py +591 -0
- modeling/official/core/registry.py +101 -0
- modeling/official/core/registry_test.py +88 -0
- modeling/official/core/savedmodel_checkpoint_manager.py +258 -0
- modeling/official/core/savedmodel_checkpoint_manager_test.py +125 -0
- modeling/official/core/task_factory.py +70 -0
- modeling/official/core/test_utils.py +59 -0
- modeling/official/core/tf_example_builder.py +144 -0
- modeling/official/core/tf_example_builder_test.py +165 -0
- modeling/official/core/tf_example_feature_key.py +62 -0
- modeling/official/core/tf_example_feature_key_test.py +49 -0
- modeling/official/core/train_lib.py +372 -0
- modeling/official/core/train_lib_test.py +280 -0
- modeling/official/core/train_utils.py +610 -0
- modeling/official/core/train_utils_test.py +215 -0
- modeling/official/legacy/README.md +5 -0
- modeling/official/legacy/__init__.py +14 -0
- modeling/official/legacy/albert/README.md +4 -0
- modeling/official/legacy/albert/__init__.py +14 -0
- modeling/official/legacy/albert/configs.py +50 -0
- modeling/official/legacy/bert/README.md +395 -0
- modeling/official/legacy/bert/__init__.py +15 -0
- modeling/official/legacy/bert/bert_cloud_tpu.md +110 -0
- modeling/official/legacy/bert/bert_models.py +365 -0
- modeling/official/legacy/bert/bert_models_test.py +106 -0
- modeling/official/legacy/bert/common_flags.py +125 -0
- modeling/official/legacy/bert/configs.py +104 -0
- modeling/official/legacy/bert/export_tfhub.py +139 -0
modeling/official/README-TPU.md
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Offically Supported TensorFlow 2.1+ Models on Cloud TPU
|
2 |
+
|
3 |
+
## Natural Language Processing
|
4 |
+
|
5 |
+
* [bert](nlp/bert): A powerful pre-trained language representation model:
|
6 |
+
BERT, which stands for Bidirectional Encoder Representations from
|
7 |
+
Transformers.
|
8 |
+
[BERT FineTuning with Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/bert-2.x) provides step by step instructions on Cloud TPU training. You can look [Bert MNLI Tensorboard.dev metrics](https://tensorboard.dev/experiment/LijZ1IrERxKALQfr76gndA) for MNLI fine tuning task.
|
9 |
+
* [transformer](nlp/transformer): A transformer model to translate the WMT
|
10 |
+
English to German dataset.
|
11 |
+
[Training transformer on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/transformer-2.x) for step by step instructions on Cloud TPU training.
|
12 |
+
|
13 |
+
## Computer Vision
|
14 |
+
|
15 |
+
* [efficientnet](vision/image_classification): A family of convolutional
|
16 |
+
neural networks that scale by balancing network depth, width, and
|
17 |
+
resolution and can be used to classify ImageNet's dataset of 1000 classes.
|
18 |
+
See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/KnaWjrq5TXGfv0NW5m7rpg/#scalars).
|
19 |
+
* [mnist](vision/image_classification): A basic model to classify digits
|
20 |
+
from the MNIST dataset. See [Running MNIST on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/mnist-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/mIah5lppTASvrHqWrdr6NA).
|
21 |
+
* [mask-rcnn](vision/detection): An object detection and instance segmentation model. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/LH7k0fMsRwqUAcE09o9kPA).
|
22 |
+
* [resnet](vision/image_classification): A deep residual network that can
|
23 |
+
be used to classify ImageNet's dataset of 1000 classes.
|
24 |
+
See [Training ResNet on Cloud TPU](https://cloud.google.com/tpu/docs/tutorials/resnet-2.x) tutorial and [Tensorboard.dev metrics](https://tensorboard.dev/experiment/CxlDK8YMRrSpYEGtBRpOhg).
|
25 |
+
* [retinanet](vision/detection): A fast and powerful object detector. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/b8NRnWU3TqG6Rw0UxueU6Q).
|
26 |
+
* [shapemask](vision/detection): An object detection and instance segmentation model using shape priors. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/ZbXgVoc6Rf6mBRlPj0JpLA).
|
27 |
+
|
28 |
+
## Recommendation
|
29 |
+
* [dlrm](recommendation/ranking): [Deep Learning Recommendation Model for
|
30 |
+
Personalization and Recommendation Systems](https://arxiv.org/abs/1906.00091).
|
31 |
+
* [dcn v2](recommendation/ranking): [Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535).
|
32 |
+
* [ncf](recommendation): Neural Collaborative Filtering. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/0k3gKjZlR1ewkVTRyLB6IQ).
|
modeling/official/README.md
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
<img src="https://storage.googleapis.com/tf_model_garden/tf_model_garden_logo.png">
|
3 |
+
</div>
|
4 |
+
|
5 |
+
# TensorFlow Official Models
|
6 |
+
|
7 |
+
The TensorFlow official models are a collection of models
|
8 |
+
that use TensorFlow’s high-level APIs.
|
9 |
+
They are intended to be well-maintained, tested, and kept up to date
|
10 |
+
with the latest TensorFlow API.
|
11 |
+
|
12 |
+
They should also be reasonably optimized for fast performance while still
|
13 |
+
being easy to read.
|
14 |
+
These models are used as end-to-end tests, ensuring that the models run
|
15 |
+
with the same or improved speed and performance with each new TensorFlow build.
|
16 |
+
|
17 |
+
The API documentation of the latest stable release is published to
|
18 |
+
[tensorflow.org](https://www.tensorflow.org/api_docs/python/tfm).
|
19 |
+
|
20 |
+
## More models to come!
|
21 |
+
|
22 |
+
The team is actively developing new models.
|
23 |
+
In the near future, we will add:
|
24 |
+
|
25 |
+
* State-of-the-art language understanding models.
|
26 |
+
* State-of-the-art image classification models.
|
27 |
+
* State-of-the-art object detection and instance segmentation models.
|
28 |
+
* State-of-the-art video classification models.
|
29 |
+
|
30 |
+
## Table of Contents
|
31 |
+
|
32 |
+
- [Models and Implementations](#models-and-implementations)
|
33 |
+
* [Computer Vision](#computer-vision)
|
34 |
+
+ [Image Classification](#image-classification)
|
35 |
+
+ [Object Detection and Segmentation](#object-detection-and-segmentation)
|
36 |
+
+ [Video Classification](#video-classification)
|
37 |
+
* [Natural Language Processing](#natural-language-processing)
|
38 |
+
* [Recommendation](#recommendation)
|
39 |
+
- [How to get started with the official models](#how-to-get-started-with-the-official-models)
|
40 |
+
- [Contributions](#contributions)
|
41 |
+
|
42 |
+
## Models and Implementations
|
43 |
+
|
44 |
+
### [Computer Vision](vision/README.md)
|
45 |
+
|
46 |
+
#### Image Classification
|
47 |
+
|
48 |
+
| Model | Reference (Paper) |
|
49 |
+
|-------|-------------------|
|
50 |
+
| [ResNet](vision/MODEL_GARDEN.md) | [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) |
|
51 |
+
| [ResNet-RS](vision/MODEL_GARDEN.md) | [Revisiting ResNets: Improved Training and Scaling Strategies](https://arxiv.org/abs/2103.07579) |
|
52 |
+
| [EfficientNet](vision/MODEL_GARDEN.md) | [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](https://arxiv.org/abs/1905.11946) |
|
53 |
+
| [Vision Transformer](vision/MODEL_GARDEN.md) | [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) |
|
54 |
+
|
55 |
+
#### Object Detection and Segmentation
|
56 |
+
|
57 |
+
| Model | Reference (Paper) |
|
58 |
+
|-------|-------------------|
|
59 |
+
| [RetinaNet](vision/MODEL_GARDEN.md) | [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002) |
|
60 |
+
| [Mask R-CNN](vision/MODEL_GARDEN.md) | [Mask R-CNN](https://arxiv.org/abs/1703.06870) |
|
61 |
+
| [YOLO](projects/yolo/README.md) | [YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors](https://arxiv.org/abs/2207.02696) |
|
62 |
+
| [SpineNet](vision/MODEL_GARDEN.md) | [SpineNet: Learning Scale-Permuted Backbone for Recognition and Localization](https://arxiv.org/abs/1912.05027) |
|
63 |
+
| [Cascade RCNN-RS and RetinaNet-RS](vision/MODEL_GARDEN.md) | [Simple Training Strategies and Model Scaling for Object Detection](https://arxiv.org/abs/2107.00057)|
|
64 |
+
|
65 |
+
#### Video Classification
|
66 |
+
|
67 |
+
| Model | Reference (Paper) |
|
68 |
+
|-------|-------------------|
|
69 |
+
| [Mobile Video Networks (MoViNets)](projects/movinet) | [MoViNets: Mobile Video Networks for Efficient Video Recognition](https://arxiv.org/abs/2103.11511) |
|
70 |
+
|
71 |
+
### [Natural Language Processing](nlp/README.md)
|
72 |
+
|
73 |
+
#### Pre-trained Language Model
|
74 |
+
|
75 |
+
| Model | Reference (Paper) |
|
76 |
+
|-------|-------------------|
|
77 |
+
| [ALBERT](nlp/MODEL_GARDEN.md#available-model-configs) | [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942) |
|
78 |
+
| [BERT](nlp/MODEL_GARDEN.md#available-model-configs) | [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) |
|
79 |
+
| [ELECTRA](nlp/tasks/electra_task.py) | [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://arxiv.org/abs/2003.10555) |
|
80 |
+
|
81 |
+
|
82 |
+
#### Neural Machine Translation
|
83 |
+
|
84 |
+
| Model | Reference (Paper) |
|
85 |
+
|-------|-------------------|
|
86 |
+
| [Transformer](nlp/MODEL_GARDEN.md#available-model-configs) | [Attention Is All You Need](https://arxiv.org/abs/1706.03762) |
|
87 |
+
|
88 |
+
#### Natural Language Generation
|
89 |
+
|
90 |
+
| Model | Reference (Paper) |
|
91 |
+
|-------|-------------------|
|
92 |
+
| [NHNet (News Headline generation model)](projects/nhnet) | [Generating Representative Headlines for News Stories](https://arxiv.org/abs/2001.09386) |
|
93 |
+
|
94 |
+
|
95 |
+
#### Knowledge Distillation
|
96 |
+
|
97 |
+
| Model | Reference (Paper) |
|
98 |
+
|-------|-------------------|
|
99 |
+
| [MobileBERT](projects/mobilebert) | [MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices](https://arxiv.org/abs/2004.02984) |
|
100 |
+
|
101 |
+
### Recommendation
|
102 |
+
|
103 |
+
Model | Reference (Paper)
|
104 |
+
-------------------------------- | -----------------
|
105 |
+
[DLRM](recommendation/ranking) | [Deep Learning Recommendation Model for Personalization and Recommendation Systems](https://arxiv.org/abs/1906.00091)
|
106 |
+
[DCN v2](recommendation/ranking) | [Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535)
|
107 |
+
[NCF](recommendation) | [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031)
|
108 |
+
|
109 |
+
## How to get started with the official models
|
110 |
+
|
111 |
+
* The official models in the master branch are developed using
|
112 |
+
[master branch of TensorFlow 2](https://github.com/tensorflow/tensorflow/tree/master).
|
113 |
+
When you clone (the repository) or download (`pip` binary) master branch of
|
114 |
+
official models , master branch of TensorFlow gets downloaded as a
|
115 |
+
dependency. This is equivalent to the following.
|
116 |
+
|
117 |
+
```shell
|
118 |
+
pip3 install tf-models-nightly
|
119 |
+
pip3 install tensorflow-text-nightly # when model uses `nlp` packages
|
120 |
+
```
|
121 |
+
|
122 |
+
* Incase of stable versions, targeting a specific release, Tensorflow-models
|
123 |
+
repository version numbers match with the target TensorFlow release. For
|
124 |
+
example, [TensorFlow-models v2.8.x](https://github.com/tensorflow/models/releases/tag/v2.8.0)
|
125 |
+
is compatible with [TensorFlow v2.8.x](https://github.com/tensorflow/tensorflow/releases/tag/v2.8.0).
|
126 |
+
This is equivalent to the following:
|
127 |
+
|
128 |
+
```shell
|
129 |
+
pip3 install tf-models-official==2.8.0
|
130 |
+
pip3 install tensorflow-text==2.8.0 # when models in uses `nlp` packages
|
131 |
+
```
|
132 |
+
|
133 |
+
Starting from 2.9.x release, we release the modeling library as
|
134 |
+
`tensorflow_models` package and users can `import tensorflow_models` directly to
|
135 |
+
access to the exported symbols. If you are
|
136 |
+
using the latest nightly version or github code directly, please follow the
|
137 |
+
docstrings in the github.
|
138 |
+
|
139 |
+
Please follow the below steps before running models in this repository.
|
140 |
+
|
141 |
+
### Requirements
|
142 |
+
|
143 |
+
* The latest TensorFlow Model Garden release and the latest TensorFlow 2
|
144 |
+
* If you are on a version of TensorFlow earlier than 2.2, please
|
145 |
+
upgrade your TensorFlow to [the latest TensorFlow 2](https://www.tensorflow.org/install/).
|
146 |
+
* Python 3.7+
|
147 |
+
|
148 |
+
Our integration tests run with Python 3.7. Although Python 3.6 should work, we
|
149 |
+
don't recommend earlier versions.
|
150 |
+
|
151 |
+
### Installation
|
152 |
+
|
153 |
+
Please check [here](https://github.com/tensorflow/models#Installation) for the
|
154 |
+
instructions.
|
155 |
+
|
156 |
+
Available pypi packages:
|
157 |
+
|
158 |
+
* [tf-models-official](https://pypi.org/project/tf-models-official/)
|
159 |
+
* [tf-models-nightly](https://pypi.org/project/tf-models-nightly/): nightly
|
160 |
+
release with the latest changes.
|
161 |
+
* [tf-models-no-deps](https://pypi.org/project/tf-models-no-deps/): without
|
162 |
+
`tensorflow` and `tensorflow-text` in the `install_requires` list.
|
163 |
+
|
164 |
+
## Contributions
|
165 |
+
|
166 |
+
If you want to contribute, please review the [contribution guidelines](https://github.com/tensorflow/models/wiki/How-to-contribute).
|
modeling/official/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
modeling/official/common/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
modeling/official/common/dataset_fn.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
16 |
+
#
|
17 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
18 |
+
# you may not use this file except in compliance with the License.
|
19 |
+
# You may obtain a copy of the License at
|
20 |
+
#
|
21 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
22 |
+
#
|
23 |
+
# Unless required by applicable law or agreed to in writing, software
|
24 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
25 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
26 |
+
# See the License for the specific language governing permissions and
|
27 |
+
# limitations under the License.
|
28 |
+
# ==============================================================================
|
29 |
+
"""Utility library for picking an appropriate dataset function."""
|
30 |
+
|
31 |
+
import functools
|
32 |
+
from typing import Any, Callable, Type, Union
|
33 |
+
|
34 |
+
import tensorflow as tf, tf_keras
|
35 |
+
|
36 |
+
PossibleDatasetType = Union[Type[tf.data.Dataset], Callable[[tf.Tensor], Any]]
|
37 |
+
|
38 |
+
|
39 |
+
def pick_dataset_fn(file_type: str) -> PossibleDatasetType:
|
40 |
+
if file_type == 'tfrecord':
|
41 |
+
return tf.data.TFRecordDataset
|
42 |
+
if file_type == 'tfrecord_compressed':
|
43 |
+
return functools.partial(tf.data.TFRecordDataset, compression_type='GZIP')
|
44 |
+
raise ValueError('Unrecognized file_type: {}'.format(file_type))
|
modeling/official/common/distribute_utils.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Helper functions for running models in a distributed setting."""
|
16 |
+
|
17 |
+
import json
|
18 |
+
import os
|
19 |
+
import tensorflow as tf, tf_keras
|
20 |
+
|
21 |
+
|
22 |
+
def _collective_communication(all_reduce_alg):
|
23 |
+
"""Return a CollectiveCommunication based on all_reduce_alg.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
all_reduce_alg: a string specifying which collective communication to pick,
|
27 |
+
or None.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
tf.distribute.experimental.CollectiveCommunication object
|
31 |
+
|
32 |
+
Raises:
|
33 |
+
ValueError: if `all_reduce_alg` not in [None, "ring", "nccl"]
|
34 |
+
"""
|
35 |
+
collective_communication_options = {
|
36 |
+
None: tf.distribute.experimental.CollectiveCommunication.AUTO,
|
37 |
+
"ring": tf.distribute.experimental.CollectiveCommunication.RING,
|
38 |
+
"nccl": tf.distribute.experimental.CollectiveCommunication.NCCL
|
39 |
+
}
|
40 |
+
if all_reduce_alg not in collective_communication_options:
|
41 |
+
raise ValueError(
|
42 |
+
"When used with `multi_worker_mirrored`, valid values for "
|
43 |
+
"all_reduce_alg are [`ring`, `nccl`]. Supplied value: {}".format(
|
44 |
+
all_reduce_alg))
|
45 |
+
return collective_communication_options[all_reduce_alg]
|
46 |
+
|
47 |
+
|
48 |
+
def _mirrored_cross_device_ops(all_reduce_alg, num_packs):
|
49 |
+
"""Return a CrossDeviceOps based on all_reduce_alg and num_packs.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
all_reduce_alg: a string specifying which cross device op to pick, or None.
|
53 |
+
num_packs: an integer specifying number of packs for the cross device op.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
tf.distribute.CrossDeviceOps object or None.
|
57 |
+
|
58 |
+
Raises:
|
59 |
+
ValueError: if `all_reduce_alg` not in [None, "nccl", "hierarchical_copy"].
|
60 |
+
"""
|
61 |
+
if all_reduce_alg is None:
|
62 |
+
return None
|
63 |
+
mirrored_all_reduce_options = {
|
64 |
+
"nccl": tf.distribute.NcclAllReduce,
|
65 |
+
"hierarchical_copy": tf.distribute.HierarchicalCopyAllReduce
|
66 |
+
}
|
67 |
+
if all_reduce_alg not in mirrored_all_reduce_options:
|
68 |
+
raise ValueError(
|
69 |
+
"When used with `mirrored`, valid values for all_reduce_alg are "
|
70 |
+
"[`nccl`, `hierarchical_copy`]. Supplied value: {}".format(
|
71 |
+
all_reduce_alg))
|
72 |
+
cross_device_ops_class = mirrored_all_reduce_options[all_reduce_alg]
|
73 |
+
return cross_device_ops_class(num_packs=num_packs)
|
74 |
+
|
75 |
+
|
76 |
+
def tpu_initialize(tpu_address):
|
77 |
+
"""Initializes TPU for TF 2.x training.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
tpu_address: string, bns address of master TPU worker.
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
A TPUClusterResolver.
|
84 |
+
"""
|
85 |
+
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
|
86 |
+
tpu=tpu_address)
|
87 |
+
if tpu_address not in ("", "local"):
|
88 |
+
tf.config.experimental_connect_to_cluster(cluster_resolver)
|
89 |
+
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
|
90 |
+
return cluster_resolver
|
91 |
+
|
92 |
+
|
93 |
+
def get_distribution_strategy(distribution_strategy="mirrored",
|
94 |
+
num_gpus=0,
|
95 |
+
all_reduce_alg=None,
|
96 |
+
num_packs=1,
|
97 |
+
tpu_address=None,
|
98 |
+
**kwargs):
|
99 |
+
"""Return a Strategy for running the model.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
distribution_strategy: a string specifying which distribution strategy to
|
103 |
+
use. Accepted values are "off", "one_device", "mirrored",
|
104 |
+
"parameter_server", "multi_worker_mirrored", and "tpu" -- case
|
105 |
+
insensitive. "tpu" means to use TPUStrategy using `tpu_address`.
|
106 |
+
"off" means to use the default strategy which is obtained from
|
107 |
+
tf.distribute.get_strategy (for details on the default strategy, see
|
108 |
+
https://www.tensorflow.org/guide/distributed_training#default_strategy).
|
109 |
+
num_gpus: Number of GPUs to run this model.
|
110 |
+
all_reduce_alg: Optional. Specifies which algorithm to use when performing
|
111 |
+
all-reduce. For `MirroredStrategy`, valid values are "nccl" and
|
112 |
+
"hierarchical_copy". For `MultiWorkerMirroredStrategy`, valid values are
|
113 |
+
"ring" and "nccl". If None, DistributionStrategy will choose based on
|
114 |
+
device topology.
|
115 |
+
num_packs: Optional. Sets the `num_packs` in `tf.distribute.NcclAllReduce`
|
116 |
+
or `tf.distribute.HierarchicalCopyAllReduce` for `MirroredStrategy`.
|
117 |
+
tpu_address: Optional. String that represents TPU to connect to. Must not be
|
118 |
+
None if `distribution_strategy` is set to `tpu`.
|
119 |
+
**kwargs: Additional kwargs for internal usages.
|
120 |
+
|
121 |
+
Returns:
|
122 |
+
tf.distribute.Strategy object.
|
123 |
+
Raises:
|
124 |
+
ValueError: if `distribution_strategy` is "off" or "one_device" and
|
125 |
+
`num_gpus` is larger than 1; or `num_gpus` is negative or if
|
126 |
+
`distribution_strategy` is `tpu` but `tpu_address` is not specified.
|
127 |
+
"""
|
128 |
+
del kwargs
|
129 |
+
if num_gpus < 0:
|
130 |
+
raise ValueError("`num_gpus` can not be negative.")
|
131 |
+
|
132 |
+
if not isinstance(distribution_strategy, str):
|
133 |
+
msg = ("distribution_strategy must be a string but got: %s." %
|
134 |
+
(distribution_strategy,))
|
135 |
+
if distribution_strategy == False: # pylint: disable=singleton-comparison,g-explicit-bool-comparison
|
136 |
+
msg += (" If you meant to pass the string 'off', make sure you add "
|
137 |
+
"quotes around 'off' so that yaml interprets it as a string "
|
138 |
+
"instead of a bool.")
|
139 |
+
raise ValueError(msg)
|
140 |
+
|
141 |
+
distribution_strategy = distribution_strategy.lower()
|
142 |
+
if distribution_strategy == "off":
|
143 |
+
if num_gpus > 1:
|
144 |
+
raise ValueError(f"When {num_gpus} GPUs are specified, "
|
145 |
+
"distribution_strategy flag cannot be set to `off`.")
|
146 |
+
# Return the default distribution strategy.
|
147 |
+
return tf.distribute.get_strategy()
|
148 |
+
|
149 |
+
if distribution_strategy == "tpu":
|
150 |
+
# When tpu_address is an empty string, we communicate with local TPUs.
|
151 |
+
cluster_resolver = tpu_initialize(tpu_address)
|
152 |
+
return tf.distribute.TPUStrategy(cluster_resolver)
|
153 |
+
|
154 |
+
if distribution_strategy == "multi_worker_mirrored":
|
155 |
+
return tf.distribute.experimental.MultiWorkerMirroredStrategy(
|
156 |
+
communication=_collective_communication(all_reduce_alg))
|
157 |
+
|
158 |
+
if distribution_strategy == "one_device":
|
159 |
+
if num_gpus == 0:
|
160 |
+
return tf.distribute.OneDeviceStrategy("device:CPU:0")
|
161 |
+
if num_gpus > 1:
|
162 |
+
raise ValueError("`OneDeviceStrategy` can not be used for more than "
|
163 |
+
"one device.")
|
164 |
+
return tf.distribute.OneDeviceStrategy("device:GPU:0")
|
165 |
+
|
166 |
+
if distribution_strategy == "mirrored":
|
167 |
+
if num_gpus == 0:
|
168 |
+
devices = ["device:CPU:0"]
|
169 |
+
else:
|
170 |
+
devices = ["device:GPU:%d" % i for i in range(num_gpus)]
|
171 |
+
return tf.distribute.MirroredStrategy(
|
172 |
+
devices=devices,
|
173 |
+
cross_device_ops=_mirrored_cross_device_ops(all_reduce_alg, num_packs))
|
174 |
+
|
175 |
+
if distribution_strategy == "parameter_server":
|
176 |
+
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
|
177 |
+
return tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)
|
178 |
+
|
179 |
+
raise ValueError("Unrecognized Distribution Strategy: %r" %
|
180 |
+
distribution_strategy)
|
181 |
+
|
182 |
+
|
183 |
+
def configure_cluster(worker_hosts=None, task_index=-1):
|
184 |
+
"""Set multi-worker cluster spec in TF_CONFIG environment variable.
|
185 |
+
|
186 |
+
Args:
|
187 |
+
worker_hosts: comma-separated list of worker ip:port pairs.
|
188 |
+
task_index: index of the worker.
|
189 |
+
|
190 |
+
Returns:
|
191 |
+
Number of workers in the cluster.
|
192 |
+
"""
|
193 |
+
tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
|
194 |
+
if tf_config:
|
195 |
+
num_workers = (
|
196 |
+
len(tf_config["cluster"].get("chief", [])) +
|
197 |
+
len(tf_config["cluster"].get("worker", [])))
|
198 |
+
elif worker_hosts:
|
199 |
+
workers = worker_hosts.split(",")
|
200 |
+
num_workers = len(workers)
|
201 |
+
if num_workers > 1 and task_index < 0:
|
202 |
+
raise ValueError("Must specify task_index when number of workers > 1")
|
203 |
+
task_index = 0 if num_workers == 1 else task_index
|
204 |
+
os.environ["TF_CONFIG"] = json.dumps({
|
205 |
+
"cluster": {
|
206 |
+
"worker": workers
|
207 |
+
},
|
208 |
+
"task": {
|
209 |
+
"type": "worker",
|
210 |
+
"index": task_index
|
211 |
+
}
|
212 |
+
})
|
213 |
+
else:
|
214 |
+
num_workers = 1
|
215 |
+
return num_workers
|
216 |
+
|
217 |
+
|
218 |
+
def get_strategy_scope(strategy):
|
219 |
+
if strategy:
|
220 |
+
strategy_scope = strategy.scope()
|
221 |
+
else:
|
222 |
+
strategy_scope = DummyContextManager()
|
223 |
+
|
224 |
+
return strategy_scope
|
225 |
+
|
226 |
+
|
227 |
+
class DummyContextManager(object):
|
228 |
+
|
229 |
+
def __enter__(self):
|
230 |
+
pass
|
231 |
+
|
232 |
+
def __exit__(self, *args):
|
233 |
+
pass
|
modeling/official/common/distribute_utils_test.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Tests for distribution util functions."""
|
16 |
+
|
17 |
+
import sys
|
18 |
+
import tensorflow as tf, tf_keras
|
19 |
+
|
20 |
+
from official.common import distribute_utils
|
21 |
+
|
22 |
+
TPU_TEST = 'test_tpu' in sys.argv[0]
|
23 |
+
|
24 |
+
|
25 |
+
class DistributeUtilsTest(tf.test.TestCase):
|
26 |
+
"""Tests for distribute util functions."""
|
27 |
+
|
28 |
+
def test_invalid_args(self):
|
29 |
+
with self.assertRaisesRegex(ValueError, '`num_gpus` can not be negative.'):
|
30 |
+
_ = distribute_utils.get_distribution_strategy(num_gpus=-1)
|
31 |
+
|
32 |
+
with self.assertRaisesRegex(ValueError,
|
33 |
+
'.*If you meant to pass the string .*'):
|
34 |
+
_ = distribute_utils.get_distribution_strategy(
|
35 |
+
distribution_strategy=False, num_gpus=0)
|
36 |
+
with self.assertRaisesRegex(ValueError, 'When 2 GPUs are specified.*'):
|
37 |
+
_ = distribute_utils.get_distribution_strategy(
|
38 |
+
distribution_strategy='off', num_gpus=2)
|
39 |
+
with self.assertRaisesRegex(ValueError,
|
40 |
+
'`OneDeviceStrategy` can not be used.*'):
|
41 |
+
_ = distribute_utils.get_distribution_strategy(
|
42 |
+
distribution_strategy='one_device', num_gpus=2)
|
43 |
+
|
44 |
+
def test_one_device_strategy_cpu(self):
|
45 |
+
ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
|
46 |
+
self.assertEquals(ds.num_replicas_in_sync, 1)
|
47 |
+
self.assertEquals(len(ds.extended.worker_devices), 1)
|
48 |
+
self.assertIn('CPU', ds.extended.worker_devices[0])
|
49 |
+
|
50 |
+
def test_one_device_strategy_gpu(self):
|
51 |
+
ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=1)
|
52 |
+
self.assertEquals(ds.num_replicas_in_sync, 1)
|
53 |
+
self.assertEquals(len(ds.extended.worker_devices), 1)
|
54 |
+
self.assertIn('GPU', ds.extended.worker_devices[0])
|
55 |
+
|
56 |
+
def test_mirrored_strategy(self):
|
57 |
+
# CPU only.
|
58 |
+
_ = distribute_utils.get_distribution_strategy(num_gpus=0)
|
59 |
+
# 5 GPUs.
|
60 |
+
ds = distribute_utils.get_distribution_strategy(num_gpus=5)
|
61 |
+
self.assertEquals(ds.num_replicas_in_sync, 5)
|
62 |
+
self.assertEquals(len(ds.extended.worker_devices), 5)
|
63 |
+
for device in ds.extended.worker_devices:
|
64 |
+
self.assertIn('GPU', device)
|
65 |
+
|
66 |
+
_ = distribute_utils.get_distribution_strategy(
|
67 |
+
distribution_strategy='mirrored',
|
68 |
+
num_gpus=2,
|
69 |
+
all_reduce_alg='nccl',
|
70 |
+
num_packs=2)
|
71 |
+
with self.assertRaisesRegex(
|
72 |
+
ValueError,
|
73 |
+
'When used with `mirrored`, valid values for all_reduce_alg are.*'):
|
74 |
+
_ = distribute_utils.get_distribution_strategy(
|
75 |
+
distribution_strategy='mirrored',
|
76 |
+
num_gpus=2,
|
77 |
+
all_reduce_alg='dummy',
|
78 |
+
num_packs=2)
|
79 |
+
|
80 |
+
def test_mwms(self):
|
81 |
+
distribute_utils.configure_cluster(worker_hosts=None, task_index=-1)
|
82 |
+
ds = distribute_utils.get_distribution_strategy(
|
83 |
+
'multi_worker_mirrored', all_reduce_alg='nccl')
|
84 |
+
self.assertIsInstance(
|
85 |
+
ds, tf.distribute.experimental.MultiWorkerMirroredStrategy)
|
86 |
+
|
87 |
+
with self.assertRaisesRegex(
|
88 |
+
ValueError,
|
89 |
+
'When used with `multi_worker_mirrored`, valid values.*'):
|
90 |
+
_ = distribute_utils.get_distribution_strategy(
|
91 |
+
'multi_worker_mirrored', all_reduce_alg='dummy')
|
92 |
+
|
93 |
+
def test_no_strategy(self):
|
94 |
+
ds = distribute_utils.get_distribution_strategy('off')
|
95 |
+
self.assertIs(ds, tf.distribute.get_strategy())
|
96 |
+
|
97 |
+
def test_tpu_strategy(self):
|
98 |
+
if not TPU_TEST:
|
99 |
+
self.skipTest('Only Cloud TPU VM instances can have local TPUs.')
|
100 |
+
with self.assertRaises(ValueError):
|
101 |
+
_ = distribute_utils.get_distribution_strategy('tpu')
|
102 |
+
|
103 |
+
ds = distribute_utils.get_distribution_strategy('tpu', tpu_address='local')
|
104 |
+
self.assertIsInstance(
|
105 |
+
ds, tf.distribute.TPUStrategy)
|
106 |
+
|
107 |
+
def test_invalid_strategy(self):
|
108 |
+
with self.assertRaisesRegexp(
|
109 |
+
ValueError,
|
110 |
+
'distribution_strategy must be a string but got: False. If'):
|
111 |
+
distribute_utils.get_distribution_strategy(False)
|
112 |
+
with self.assertRaisesRegexp(
|
113 |
+
ValueError, 'distribution_strategy must be a string but got: 1'):
|
114 |
+
distribute_utils.get_distribution_strategy(1)
|
115 |
+
|
116 |
+
def test_get_strategy_scope(self):
|
117 |
+
ds = distribute_utils.get_distribution_strategy('one_device', num_gpus=0)
|
118 |
+
with distribute_utils.get_strategy_scope(ds):
|
119 |
+
self.assertIs(tf.distribute.get_strategy(), ds)
|
120 |
+
with distribute_utils.get_strategy_scope(None):
|
121 |
+
self.assertIsNot(tf.distribute.get_strategy(), ds)
|
122 |
+
|
123 |
+
if __name__ == '__main__':
|
124 |
+
tf.test.main()
|
modeling/official/common/flags.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""The central place to define flags."""
|
16 |
+
|
17 |
+
from absl import flags
|
18 |
+
|
19 |
+
|
20 |
+
def define_flags():
|
21 |
+
"""Defines flags.
|
22 |
+
|
23 |
+
All flags are defined as optional, but in practice most models use some of
|
24 |
+
these flags and so mark_flags_as_required() should be called after calling
|
25 |
+
this function. Typically, 'experiment', 'mode', and 'model_dir' are required.
|
26 |
+
For example:
|
27 |
+
|
28 |
+
```
|
29 |
+
from absl import flags
|
30 |
+
from official.common import flags as tfm_flags # pylint: disable=line-too-long
|
31 |
+
...
|
32 |
+
tfm_flags.define_flags()
|
33 |
+
flags.mark_flags_as_required(['experiment', 'mode', 'model_dir'])
|
34 |
+
```
|
35 |
+
|
36 |
+
The reason all flags are optional is because unit tests often do not set or
|
37 |
+
use any of the flags.
|
38 |
+
"""
|
39 |
+
flags.DEFINE_string(
|
40 |
+
'experiment', default=None, help=
|
41 |
+
'The experiment type registered, specifying an ExperimentConfig.')
|
42 |
+
|
43 |
+
flags.DEFINE_enum(
|
44 |
+
'mode',
|
45 |
+
default=None,
|
46 |
+
enum_values=[
|
47 |
+
'train', 'eval', 'train_and_eval', 'continuous_eval',
|
48 |
+
'continuous_train_and_eval', 'train_and_validate',
|
49 |
+
'train_and_post_eval'
|
50 |
+
],
|
51 |
+
help='Mode to run: `train`, `eval`, `train_and_eval`, '
|
52 |
+
'`continuous_eval`, `continuous_train_and_eval` and '
|
53 |
+
'`train_and_validate` (which is not implemented in '
|
54 |
+
'the open source version).')
|
55 |
+
|
56 |
+
flags.DEFINE_string(
|
57 |
+
'model_dir',
|
58 |
+
default=None,
|
59 |
+
help='The directory where the model and training/evaluation summaries'
|
60 |
+
'are stored.')
|
61 |
+
|
62 |
+
flags.DEFINE_multi_string(
|
63 |
+
'config_file',
|
64 |
+
default=None,
|
65 |
+
help='YAML/JSON files which specifies overrides. The override order '
|
66 |
+
'follows the order of args. Note that each file '
|
67 |
+
'can be used as an override template to override the default parameters '
|
68 |
+
'specified in Python. If the same parameter is specified in both '
|
69 |
+
'`--config_file` and `--params_override`, `config_file` will be used '
|
70 |
+
'first, followed by params_override.')
|
71 |
+
|
72 |
+
flags.DEFINE_string(
|
73 |
+
'params_override',
|
74 |
+
default=None,
|
75 |
+
help='a YAML/JSON string or a YAML file which specifies additional '
|
76 |
+
'overrides over the default parameters and those specified in '
|
77 |
+
'`--config_file`. Note that this is supposed to be used only to override '
|
78 |
+
'the model parameters, but not the parameters like TPU specific flags. '
|
79 |
+
'One canonical use case of `--config_file` and `--params_override` is '
|
80 |
+
'users first define a template config file using `--config_file`, then '
|
81 |
+
'use `--params_override` to adjust the minimal set of tuning parameters, '
|
82 |
+
'for example setting up different `train_batch_size`. The final override '
|
83 |
+
'order of parameters: default_model_params --> params from config_file '
|
84 |
+
'--> params in params_override. See also the help message of '
|
85 |
+
'`--config_file`.')
|
86 |
+
|
87 |
+
# The libraries rely on gin often make mistakes that include flags inside
|
88 |
+
# the library files which causes conflicts.
|
89 |
+
try:
|
90 |
+
flags.DEFINE_multi_string(
|
91 |
+
'gin_file', default=None, help='List of paths to the config files.')
|
92 |
+
except flags.DuplicateFlagError:
|
93 |
+
pass
|
94 |
+
|
95 |
+
try:
|
96 |
+
flags.DEFINE_multi_string(
|
97 |
+
'gin_params',
|
98 |
+
default=None,
|
99 |
+
help='Newline separated list of Gin parameter bindings.')
|
100 |
+
except flags.DuplicateFlagError:
|
101 |
+
pass
|
102 |
+
|
103 |
+
flags.DEFINE_string(
|
104 |
+
'tpu',
|
105 |
+
default=None,
|
106 |
+
help='The Cloud TPU to use for training. This should be either the name '
|
107 |
+
'used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 '
|
108 |
+
'url.')
|
109 |
+
|
110 |
+
flags.DEFINE_string(
|
111 |
+
'tf_data_service', default=None, help='The tf.data service address')
|
112 |
+
|
113 |
+
flags.DEFINE_string(
|
114 |
+
'tpu_platform', default=None, help='TPU platform type.')
|
modeling/official/common/registry_imports.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""All necessary imports for registration."""
|
16 |
+
# pylint: disable=unused-import
|
17 |
+
from official import vision
|
18 |
+
from official.nlp import tasks
|
19 |
+
from official.nlp.configs import experiment_configs
|
20 |
+
from official.utils.testing import mock_task
|
modeling/official/common/streamz_counters.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Global streamz counters."""
|
16 |
+
|
17 |
+
from tensorflow.python.eager import monitoring
|
18 |
+
|
19 |
+
|
20 |
+
progressive_policy_creation_counter = monitoring.Counter(
|
21 |
+
"/tensorflow/training/fast_training/progressive_policy_creation",
|
22 |
+
"Counter for the number of ProgressivePolicy creations.")
|
23 |
+
|
24 |
+
|
25 |
+
stack_vars_to_vars_call_counter = monitoring.Counter(
|
26 |
+
"/tensorflow/training/fast_training/tf_vars_to_vars",
|
27 |
+
"Counter for the number of low-level stacking API calls.")
|
modeling/official/core/__init__.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Core is shared by both `nlp` and `vision`."""
|
16 |
+
|
17 |
+
from official.core import actions
|
18 |
+
from official.core import base_task
|
19 |
+
from official.core import base_trainer
|
20 |
+
from official.core import config_definitions
|
21 |
+
from official.core import exp_factory
|
22 |
+
from official.core import export_base
|
23 |
+
from official.core import file_writers
|
24 |
+
from official.core import input_reader
|
25 |
+
from official.core import registry
|
26 |
+
from official.core import savedmodel_checkpoint_manager
|
27 |
+
from official.core import task_factory
|
28 |
+
from official.core import tf_example_builder
|
29 |
+
from official.core import tf_example_feature_key
|
30 |
+
from official.core import train_lib
|
31 |
+
from official.core import train_utils
|
modeling/official/core/actions.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Provides TFM orbit actions and associated helper functions/classes."""
|
16 |
+
|
17 |
+
import os
|
18 |
+
from typing import List
|
19 |
+
from absl import logging
|
20 |
+
|
21 |
+
import gin
|
22 |
+
import orbit
|
23 |
+
import tensorflow as tf, tf_keras
|
24 |
+
|
25 |
+
from official.core import base_trainer
|
26 |
+
from official.core import config_definitions
|
27 |
+
from official.modeling import optimization
|
28 |
+
|
29 |
+
|
30 |
+
class PruningAction:
|
31 |
+
"""Train action to updates pruning related information.
|
32 |
+
|
33 |
+
This action updates pruning steps at the end of trainig loop, and log
|
34 |
+
pruning metrics to tensorboard.
|
35 |
+
|
36 |
+
This action must be used when training a pruned model to avoid pruning error.
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
export_dir: str,
|
42 |
+
model: tf_keras.Model,
|
43 |
+
optimizer: tf_keras.optimizers.Optimizer,
|
44 |
+
):
|
45 |
+
"""Initializes the instance.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
export_dir: `str` for the export directory of the pruning summaries.
|
49 |
+
model: `tf_keras.Model` model instance used for training. This will be
|
50 |
+
used to assign a pruning step to each prunable weight.
|
51 |
+
optimizer: `tf_keras.optimizers.Optimizer` optimizer instance used for
|
52 |
+
training. This will be used to find the current training steps.
|
53 |
+
"""
|
54 |
+
# TODO(b/221490190): Avoid local import when the bug is fixed.
|
55 |
+
import tensorflow_model_optimization as tfmot # pylint: disable=g-import-not-at-top
|
56 |
+
self._optimizer = optimizer
|
57 |
+
self.update_pruning_step = tfmot.sparsity.keras.UpdatePruningStep()
|
58 |
+
self.update_pruning_step.set_model(model)
|
59 |
+
self.update_pruning_step.on_train_begin()
|
60 |
+
|
61 |
+
self.pruning_summaries = tfmot.sparsity.keras.PruningSummaries(
|
62 |
+
log_dir=export_dir)
|
63 |
+
model.optimizer = optimizer
|
64 |
+
self.pruning_summaries.set_model(model)
|
65 |
+
|
66 |
+
def __call__(self, output: orbit.runner.Output):
|
67 |
+
"""Update pruning step and log pruning summaries.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
output: The train output.
|
71 |
+
"""
|
72 |
+
self.update_pruning_step.on_epoch_end(batch=None)
|
73 |
+
self.pruning_summaries.on_epoch_begin(epoch=None)
|
74 |
+
|
75 |
+
|
76 |
+
class EMACheckpointing:
|
77 |
+
"""Eval action to save checkpoint with average weights when EMA is used.
|
78 |
+
|
79 |
+
This action swaps the weights of the model with the average weights, then it
|
80 |
+
saves the checkpoint under export_dir/ema_checkpoints. Checkpointing is
|
81 |
+
expensive for large models, so doing this action in eval is more efficient
|
82 |
+
than training.
|
83 |
+
"""
|
84 |
+
|
85 |
+
def __init__(self,
|
86 |
+
export_dir: str,
|
87 |
+
optimizer: tf_keras.optimizers.Optimizer,
|
88 |
+
checkpoint: tf.train.Checkpoint,
|
89 |
+
max_to_keep: int = 1):
|
90 |
+
"""Initializes the instance.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
export_dir: `str` for the export directory of the EMA average weights.
|
94 |
+
optimizer: `tf_keras.optimizers.Optimizer` optimizer instance used for
|
95 |
+
training. This will be used to swap the model weights with the average
|
96 |
+
weigths.
|
97 |
+
checkpoint: `tf.train.Checkpoint` instance.
|
98 |
+
max_to_keep: `int` for max checkpoints to keep in ema_checkpoints subdir.
|
99 |
+
"""
|
100 |
+
if not isinstance(optimizer, optimization.ExponentialMovingAverage):
|
101 |
+
raise ValueError('Optimizer has to be instance of'
|
102 |
+
'optimization.ExponentialMovingAverage for'
|
103 |
+
'EMACheckpointing action')
|
104 |
+
|
105 |
+
export_dir = os.path.join(export_dir, 'ema_checkpoints')
|
106 |
+
tf.io.gfile.makedirs(os.path.dirname(export_dir))
|
107 |
+
self._optimizer = optimizer
|
108 |
+
self._checkpoint = checkpoint
|
109 |
+
self._checkpoint_manager = tf.train.CheckpointManager(
|
110 |
+
checkpoint,
|
111 |
+
directory=export_dir,
|
112 |
+
max_to_keep=max_to_keep,
|
113 |
+
checkpoint_name='average_weights')
|
114 |
+
|
115 |
+
def __call__(self, output: orbit.runner.Output):
|
116 |
+
"""Swaps model weights, and saves the checkpoint.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
output: The train or eval output.
|
120 |
+
"""
|
121 |
+
self._optimizer.swap_weights()
|
122 |
+
self._checkpoint_manager.save(checkpoint_number=self._optimizer.iterations)
|
123 |
+
self._optimizer.swap_weights()
|
124 |
+
|
125 |
+
|
126 |
+
class RecoveryAction:
|
127 |
+
"""Train action to recover from loss blowup.
|
128 |
+
|
129 |
+
Checks the loss value by the given threshold. If applicable, recover the
|
130 |
+
model by reading the checkpoint on disk.
|
131 |
+
"""
|
132 |
+
|
133 |
+
def __init__(self, checkpoint_manager: tf.train.CheckpointManager):
|
134 |
+
self.checkpoint_manager = checkpoint_manager
|
135 |
+
|
136 |
+
def __call__(self, _):
|
137 |
+
"""Recovers the training by triggering checkpoint restoration."""
|
138 |
+
# Loads the previous good checkpoint.
|
139 |
+
checkpoint_path = self.checkpoint_manager.restore_or_initialize()
|
140 |
+
logging.warning('Recovering the model from checkpoint: %s.',
|
141 |
+
checkpoint_path)
|
142 |
+
|
143 |
+
|
144 |
+
class RecoveryCondition:
|
145 |
+
"""Recovery Condition."""
|
146 |
+
|
147 |
+
def __init__(self,
|
148 |
+
global_step: tf.Variable,
|
149 |
+
loss_upper_bound: float,
|
150 |
+
recovery_begin_steps: int = 0,
|
151 |
+
recovery_max_trials: int = 3):
|
152 |
+
self.recover_counter = 0
|
153 |
+
self.recovery_begin_steps = recovery_begin_steps
|
154 |
+
self.recovery_max_trials = recovery_max_trials
|
155 |
+
self.loss_upper_bound = loss_upper_bound
|
156 |
+
self.global_step = global_step
|
157 |
+
|
158 |
+
def __call__(self, outputs: orbit.runner.Output):
|
159 |
+
loss_value = outputs['training_loss']
|
160 |
+
if tf.math.is_nan(loss_value):
|
161 |
+
self.recover_counter += 1
|
162 |
+
if self.recover_counter > self.recovery_max_trials:
|
163 |
+
raise RuntimeError(
|
164 |
+
'The loss value is NaN after training loop and it happens %d times.'
|
165 |
+
% self.recover_counter)
|
166 |
+
return True
|
167 |
+
if (self.global_step >= self.recovery_begin_steps and
|
168 |
+
loss_value > self.loss_upper_bound):
|
169 |
+
self.recover_counter += 1
|
170 |
+
if self.recover_counter > self.recovery_max_trials:
|
171 |
+
raise RuntimeError(
|
172 |
+
f'The loss value is {loss_value}, which is larger than the bound {self.loss_upper_bound}, happens {self.recover_counter} times.'
|
173 |
+
)
|
174 |
+
return True
|
175 |
+
return False
|
176 |
+
|
177 |
+
|
178 |
+
@gin.configurable
|
179 |
+
def get_eval_actions(params: config_definitions.ExperimentConfig,
|
180 |
+
trainer: base_trainer.Trainer,
|
181 |
+
model_dir: str) -> List[orbit.Action]:
|
182 |
+
"""Gets eval actions for TFM trainer."""
|
183 |
+
eval_actions = []
|
184 |
+
# Adds ema checkpointing action to save the average weights under
|
185 |
+
# ema_checkpoints subdir.
|
186 |
+
if isinstance(trainer.optimizer, optimization.ExponentialMovingAverage):
|
187 |
+
eval_actions.append(
|
188 |
+
EMACheckpointing(
|
189 |
+
export_dir=model_dir,
|
190 |
+
optimizer=trainer.optimizer,
|
191 |
+
checkpoint=trainer.checkpoint,
|
192 |
+
max_to_keep=params.trainer.max_to_keep))
|
193 |
+
|
194 |
+
return eval_actions
|
195 |
+
|
196 |
+
|
197 |
+
@gin.configurable
|
198 |
+
def get_train_actions(
|
199 |
+
params: config_definitions.ExperimentConfig, trainer: base_trainer.Trainer,
|
200 |
+
model_dir: str,
|
201 |
+
checkpoint_manager: tf.train.CheckpointManager) -> List[orbit.Action]:
|
202 |
+
"""Gets train actions for TFM trainer."""
|
203 |
+
train_actions = []
|
204 |
+
# Adds pruning callback actions.
|
205 |
+
if hasattr(params.task, 'pruning') and params.task.pruning:
|
206 |
+
train_actions.append(
|
207 |
+
PruningAction(
|
208 |
+
export_dir=model_dir,
|
209 |
+
model=trainer.model,
|
210 |
+
optimizer=trainer.optimizer))
|
211 |
+
|
212 |
+
if params.trainer.recovery_max_trials >= 0:
|
213 |
+
recovery_condition = RecoveryCondition(
|
214 |
+
global_step=trainer.global_step,
|
215 |
+
loss_upper_bound=params.trainer.loss_upper_bound,
|
216 |
+
recovery_begin_steps=params.trainer.recovery_begin_steps,
|
217 |
+
recovery_max_trials=params.trainer.recovery_max_trials,
|
218 |
+
)
|
219 |
+
recover_action = orbit.actions.ConditionalAction(
|
220 |
+
condition=recovery_condition,
|
221 |
+
action=RecoveryAction(checkpoint_manager),
|
222 |
+
)
|
223 |
+
train_actions.append(recover_action)
|
224 |
+
|
225 |
+
if (
|
226 |
+
params.trainer.preemption_on_demand_checkpoint
|
227 |
+
and trainer.strategy.cluster_resolver
|
228 |
+
):
|
229 |
+
on_demand_checkpoint_action = orbit.actions.SaveCheckpointIfPreempted(
|
230 |
+
trainer.strategy.cluster_resolver,
|
231 |
+
checkpoint_manager,
|
232 |
+
trainer.global_step,
|
233 |
+
keep_running_after_save=True,
|
234 |
+
)
|
235 |
+
train_actions.append(on_demand_checkpoint_action)
|
236 |
+
return train_actions
|
modeling/official/core/actions_test.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Tests for TFM actions."""
|
16 |
+
|
17 |
+
import os
|
18 |
+
|
19 |
+
from absl.testing import parameterized
|
20 |
+
import numpy as np
|
21 |
+
import orbit
|
22 |
+
import tensorflow as tf, tf_keras
|
23 |
+
|
24 |
+
from tensorflow.python.distribute import combinations
|
25 |
+
from tensorflow.python.distribute import strategy_combinations
|
26 |
+
from official.core import actions
|
27 |
+
from official.modeling import optimization
|
28 |
+
|
29 |
+
|
30 |
+
class TestModel(tf_keras.Model):
|
31 |
+
|
32 |
+
def __init__(self):
|
33 |
+
super().__init__()
|
34 |
+
self.value = tf.Variable(0.0)
|
35 |
+
self.dense = tf_keras.layers.Dense(2)
|
36 |
+
_ = self.dense(tf.zeros((2, 2), tf.float32))
|
37 |
+
|
38 |
+
def call(self, x, training=None):
|
39 |
+
return self.value + x
|
40 |
+
|
41 |
+
|
42 |
+
class ActionsTest(tf.test.TestCase, parameterized.TestCase):
|
43 |
+
|
44 |
+
@combinations.generate(
|
45 |
+
combinations.combine(
|
46 |
+
distribution=[
|
47 |
+
strategy_combinations.cloud_tpu_strategy,
|
48 |
+
strategy_combinations.one_device_strategy,
|
49 |
+
],))
|
50 |
+
def test_ema_checkpointing(self, distribution):
|
51 |
+
with distribution.scope():
|
52 |
+
directory = self.create_tempdir()
|
53 |
+
model = TestModel()
|
54 |
+
optimizer = tf_keras.optimizers.SGD()
|
55 |
+
optimizer = optimization.ExponentialMovingAverage(
|
56 |
+
optimizer, trainable_weights_only=False)
|
57 |
+
|
58 |
+
# Creats average weights for the model variables. Average weights are
|
59 |
+
# initialized to zero.
|
60 |
+
optimizer.shadow_copy(model)
|
61 |
+
checkpoint = tf.train.Checkpoint(model=model)
|
62 |
+
|
63 |
+
# Changes model.value to 3, average value is still 0.
|
64 |
+
model.value.assign(3)
|
65 |
+
|
66 |
+
# Checks model.value is 3
|
67 |
+
self.assertEqual(model(0.), 3)
|
68 |
+
ema_action = actions.EMACheckpointing(directory, optimizer, checkpoint)
|
69 |
+
|
70 |
+
ema_action({})
|
71 |
+
self.assertNotEmpty(
|
72 |
+
tf.io.gfile.glob(os.path.join(directory, 'ema_checkpoints')))
|
73 |
+
|
74 |
+
checkpoint.read(
|
75 |
+
tf.train.latest_checkpoint(
|
76 |
+
os.path.join(directory, 'ema_checkpoints')))
|
77 |
+
|
78 |
+
# Checks model.value is 0 after swapping.
|
79 |
+
self.assertEqual(model(0.), 0)
|
80 |
+
|
81 |
+
# Raises an error for a normal optimizer.
|
82 |
+
with self.assertRaisesRegex(ValueError,
|
83 |
+
'Optimizer has to be instance of.*'):
|
84 |
+
_ = actions.EMACheckpointing(directory, tf_keras.optimizers.SGD(),
|
85 |
+
checkpoint)
|
86 |
+
|
87 |
+
@combinations.generate(
|
88 |
+
combinations.combine(
|
89 |
+
distribution=[
|
90 |
+
strategy_combinations.default_strategy,
|
91 |
+
strategy_combinations.cloud_tpu_strategy,
|
92 |
+
strategy_combinations.one_device_strategy_gpu,
|
93 |
+
],))
|
94 |
+
def test_recovery_condition(self, distribution):
|
95 |
+
with distribution.scope():
|
96 |
+
global_step = orbit.utils.create_global_step()
|
97 |
+
recover_condition = actions.RecoveryCondition(
|
98 |
+
global_step, loss_upper_bound=0.5, recovery_max_trials=2)
|
99 |
+
outputs = {'training_loss': 0.6}
|
100 |
+
self.assertTrue(recover_condition(outputs))
|
101 |
+
self.assertTrue(recover_condition(outputs))
|
102 |
+
with self.assertRaises(RuntimeError):
|
103 |
+
recover_condition(outputs)
|
104 |
+
|
105 |
+
global_step = orbit.utils.create_global_step()
|
106 |
+
recover_condition = actions.RecoveryCondition(
|
107 |
+
global_step, loss_upper_bound=0.5, recovery_max_trials=2)
|
108 |
+
outputs = {'training_loss': tf.constant([np.nan], tf.float32)}
|
109 |
+
self.assertTrue(recover_condition(outputs))
|
110 |
+
self.assertTrue(recover_condition(outputs))
|
111 |
+
with self.assertRaises(RuntimeError):
|
112 |
+
recover_condition(outputs)
|
113 |
+
|
114 |
+
@combinations.generate(
|
115 |
+
combinations.combine(
|
116 |
+
distribution=[
|
117 |
+
strategy_combinations.one_device_strategy_gpu,
|
118 |
+
strategy_combinations.one_device_strategy,
|
119 |
+
],))
|
120 |
+
def test_pruning(self, distribution):
|
121 |
+
with distribution.scope():
|
122 |
+
directory = self.get_temp_dir()
|
123 |
+
model = TestModel()
|
124 |
+
optimizer = tf_keras.optimizers.SGD()
|
125 |
+
pruning = actions.PruningAction(directory, model, optimizer)
|
126 |
+
|
127 |
+
pruning({})
|
128 |
+
|
129 |
+
|
130 |
+
if __name__ == '__main__':
|
131 |
+
tf.test.main()
|
modeling/official/core/base_task.py
ADDED
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Defines the base task abstraction."""
|
16 |
+
import abc
|
17 |
+
import functools
|
18 |
+
from typing import Optional
|
19 |
+
|
20 |
+
from absl import logging
|
21 |
+
import tensorflow as tf, tf_keras
|
22 |
+
|
23 |
+
from official.core import config_definitions
|
24 |
+
from official.modeling import optimization
|
25 |
+
from official.modeling import performance
|
26 |
+
from official.modeling.privacy import configs
|
27 |
+
from official.modeling.privacy import ops
|
28 |
+
|
29 |
+
OptimizationConfig = optimization.OptimizationConfig
|
30 |
+
RuntimeConfig = config_definitions.RuntimeConfig
|
31 |
+
DifferentialPrivacyConfig = configs.DifferentialPrivacyConfig
|
32 |
+
|
33 |
+
|
34 |
+
class Task(tf.Module, metaclass=abc.ABCMeta):
|
35 |
+
"""A single-replica view of training procedure.
|
36 |
+
|
37 |
+
Tasks provide artifacts for training/validation procedures, including
|
38 |
+
loading/iterating over Datasets, training/validation steps, calculating the
|
39 |
+
loss and customized metrics with reduction.
|
40 |
+
"""
|
41 |
+
|
42 |
+
# Special keys in train/validate step returned logs.
|
43 |
+
loss = "loss"
|
44 |
+
|
45 |
+
def __init__(self,
|
46 |
+
params,
|
47 |
+
logging_dir: Optional[str] = None,
|
48 |
+
name: Optional[str] = None):
|
49 |
+
"""Task initialization.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
params: the task configuration instance, which can be any of dataclass,
|
53 |
+
ConfigDict, namedtuple, etc.
|
54 |
+
logging_dir: a string pointing to where the model, summaries etc. will be
|
55 |
+
saved. You can also write additional stuff in this directory.
|
56 |
+
name: the task name.
|
57 |
+
"""
|
58 |
+
super().__init__(name=name)
|
59 |
+
self._task_config = params
|
60 |
+
self._logging_dir = (
|
61 |
+
logging_dir or ""
|
62 |
+
) # Empty directory hints current working dir.
|
63 |
+
|
64 |
+
@property
|
65 |
+
def task_config(self):
|
66 |
+
return self._task_config
|
67 |
+
|
68 |
+
@property
|
69 |
+
def logging_dir(self) -> str:
|
70 |
+
return self._logging_dir
|
71 |
+
|
72 |
+
@classmethod
|
73 |
+
def create_optimizer(cls, optimizer_config: OptimizationConfig,
|
74 |
+
runtime_config: Optional[RuntimeConfig] = None,
|
75 |
+
dp_config: Optional[DifferentialPrivacyConfig] = None):
|
76 |
+
"""Creates an TF optimizer from configurations.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
optimizer_config: the parameters of the Optimization settings.
|
80 |
+
runtime_config: the parameters of the runtime.
|
81 |
+
dp_config: the parameter of differential privacy.
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
A tf.optimizers.Optimizer object.
|
85 |
+
"""
|
86 |
+
gradient_transformers = None
|
87 |
+
if dp_config is not None:
|
88 |
+
logging.info("Adding differential privacy transform with config %s.",
|
89 |
+
dp_config.as_dict())
|
90 |
+
noise_stddev = dp_config.clipping_norm * dp_config.noise_multiplier
|
91 |
+
gradient_transformers = [
|
92 |
+
functools.partial(
|
93 |
+
ops.clip_l2_norm, l2_norm_clip=dp_config.clipping_norm),
|
94 |
+
functools.partial(
|
95 |
+
ops.add_noise, noise_stddev=noise_stddev)
|
96 |
+
]
|
97 |
+
|
98 |
+
opt_factory = optimization.OptimizerFactory(optimizer_config)
|
99 |
+
optimizer = opt_factory.build_optimizer(
|
100 |
+
opt_factory.build_learning_rate(),
|
101 |
+
gradient_transformers=gradient_transformers
|
102 |
+
)
|
103 |
+
# Configuring optimizer when loss_scale is set in runtime config. This helps
|
104 |
+
# avoiding overflow/underflow for float16 computations.
|
105 |
+
if runtime_config:
|
106 |
+
optimizer = performance.configure_optimizer(
|
107 |
+
optimizer,
|
108 |
+
use_float16=runtime_config.mixed_precision_dtype == "float16",
|
109 |
+
loss_scale=runtime_config.loss_scale)
|
110 |
+
|
111 |
+
return optimizer
|
112 |
+
|
113 |
+
def initialize(self, model: tf_keras.Model):
|
114 |
+
"""[Optional] A callback function used as CheckpointManager's init_fn.
|
115 |
+
|
116 |
+
This function will be called when no checkpoint is found for the model.
|
117 |
+
If there is a checkpoint, the checkpoint will be loaded and this function
|
118 |
+
will not be called. You can use this callback function to load a pretrained
|
119 |
+
checkpoint, saved under a directory other than the model_dir.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
model: The keras.Model built or used by this task.
|
123 |
+
"""
|
124 |
+
ckpt_dir_or_file = self.task_config.init_checkpoint
|
125 |
+
logging.info("Trying to load pretrained checkpoint from %s",
|
126 |
+
ckpt_dir_or_file)
|
127 |
+
if ckpt_dir_or_file and tf.io.gfile.isdir(ckpt_dir_or_file):
|
128 |
+
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
|
129 |
+
if not ckpt_dir_or_file:
|
130 |
+
logging.info("No checkpoint file found from %s. Will not load.",
|
131 |
+
ckpt_dir_or_file)
|
132 |
+
return
|
133 |
+
|
134 |
+
if hasattr(model, "checkpoint_items"):
|
135 |
+
checkpoint_items = model.checkpoint_items
|
136 |
+
else:
|
137 |
+
checkpoint_items = dict(model=model)
|
138 |
+
ckpt = tf.train.Checkpoint(**checkpoint_items)
|
139 |
+
status = ckpt.read(ckpt_dir_or_file)
|
140 |
+
status.expect_partial().assert_existing_objects_matched()
|
141 |
+
logging.info("Finished loading pretrained checkpoint from %s",
|
142 |
+
ckpt_dir_or_file)
|
143 |
+
|
144 |
+
def build_model(self) -> tf_keras.Model:
|
145 |
+
"""[Optional] Creates model architecture.
|
146 |
+
|
147 |
+
Returns:
|
148 |
+
A model instance.
|
149 |
+
""" # pytype: disable=bad-return-type # typed-keras
|
150 |
+
|
151 |
+
@abc.abstractmethod
|
152 |
+
def build_inputs(self,
|
153 |
+
params,
|
154 |
+
input_context: Optional[tf.distribute.InputContext] = None):
|
155 |
+
"""Returns a dataset or a nested structure of dataset functions.
|
156 |
+
|
157 |
+
Dataset functions define per-host datasets with the per-replica batch size.
|
158 |
+
With distributed training, this method runs on remote hosts.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
params: hyperparams to create input pipelines, which can be any of
|
162 |
+
dataclass, ConfigDict, namedtuple, etc.
|
163 |
+
input_context: optional distribution input pipeline context.
|
164 |
+
|
165 |
+
Returns:
|
166 |
+
A nested structure of per-replica input functions.
|
167 |
+
"""
|
168 |
+
|
169 |
+
def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
|
170 |
+
"""Standard interface to compute losses.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
labels: optional label tensors.
|
174 |
+
model_outputs: a nested structure of output tensors.
|
175 |
+
aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
The total loss tensor.
|
179 |
+
"""
|
180 |
+
del model_outputs, labels
|
181 |
+
|
182 |
+
if aux_losses is None:
|
183 |
+
losses = [tf.constant(0.0, dtype=tf.float32)]
|
184 |
+
else:
|
185 |
+
losses = aux_losses
|
186 |
+
total_loss = tf.add_n(losses)
|
187 |
+
return total_loss
|
188 |
+
|
189 |
+
def build_metrics(self, training: bool = True):
|
190 |
+
"""Gets streaming metrics for training/validation."""
|
191 |
+
del training
|
192 |
+
return []
|
193 |
+
|
194 |
+
def process_metrics(self, metrics, labels, model_outputs, **kwargs):
|
195 |
+
"""Process and update metrics.
|
196 |
+
|
197 |
+
Called when using custom training loop API.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
metrics: a nested structure of metrics objects. The return of function
|
201 |
+
self.build_metrics.
|
202 |
+
labels: a tensor or a nested structure of tensors.
|
203 |
+
model_outputs: a tensor or a nested structure of tensors. For example,
|
204 |
+
output of the keras model built by self.build_model.
|
205 |
+
**kwargs: other args.
|
206 |
+
"""
|
207 |
+
for metric in metrics:
|
208 |
+
metric.update_state(labels, model_outputs)
|
209 |
+
|
210 |
+
def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
|
211 |
+
"""Process and update compiled_metrics.
|
212 |
+
|
213 |
+
call when using compile/fit API.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
compiled_metrics: the compiled metrics (model.compiled_metrics).
|
217 |
+
labels: a tensor or a nested structure of tensors.
|
218 |
+
model_outputs: a tensor or a nested structure of tensors. For example,
|
219 |
+
output of the keras model built by self.build_model.
|
220 |
+
"""
|
221 |
+
compiled_metrics.update_state(labels, model_outputs)
|
222 |
+
|
223 |
+
def train_step(self,
|
224 |
+
inputs,
|
225 |
+
model: tf_keras.Model,
|
226 |
+
optimizer: tf_keras.optimizers.Optimizer,
|
227 |
+
metrics=None):
|
228 |
+
"""Does forward and backward.
|
229 |
+
|
230 |
+
With distribution strategies, this method runs on devices.
|
231 |
+
|
232 |
+
Args:
|
233 |
+
inputs: a dictionary of input tensors.
|
234 |
+
model: the model, forward pass definition.
|
235 |
+
optimizer: the optimizer for this training step.
|
236 |
+
metrics: a nested structure of metrics objects.
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
A dictionary of logs.
|
240 |
+
"""
|
241 |
+
if isinstance(inputs, tuple) and len(inputs) == 2:
|
242 |
+
features, labels = inputs
|
243 |
+
else:
|
244 |
+
features, labels = inputs, inputs
|
245 |
+
with tf.GradientTape() as tape:
|
246 |
+
outputs = model(features, training=True)
|
247 |
+
# Computes per-replica loss.
|
248 |
+
if model.compiled_loss:
|
249 |
+
loss = model.compiled_loss(
|
250 |
+
labels, outputs, regularization_losses=model.losses)
|
251 |
+
loss += self.build_losses(
|
252 |
+
labels=labels, model_outputs=outputs, aux_losses=None)
|
253 |
+
else:
|
254 |
+
loss = self.build_losses(
|
255 |
+
labels=labels, model_outputs=outputs, aux_losses=model.losses)
|
256 |
+
# Scales loss as the default gradients allreduce performs sum inside the
|
257 |
+
# optimizer.
|
258 |
+
scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
|
259 |
+
|
260 |
+
# For mixed precision, when a LossScaleOptimizer is used, the loss is
|
261 |
+
# scaled to avoid numeric underflow.
|
262 |
+
if isinstance(optimizer,
|
263 |
+
tf_keras.mixed_precision.LossScaleOptimizer):
|
264 |
+
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
|
265 |
+
|
266 |
+
tvars = model.trainable_variables
|
267 |
+
grads = tape.gradient(scaled_loss, tvars)
|
268 |
+
|
269 |
+
if isinstance(optimizer,
|
270 |
+
tf_keras.mixed_precision.LossScaleOptimizer):
|
271 |
+
grads = optimizer.get_unscaled_gradients(grads)
|
272 |
+
optimizer.apply_gradients(list(zip(grads, tvars)))
|
273 |
+
logs = {self.loss: loss}
|
274 |
+
if metrics:
|
275 |
+
self.process_metrics(metrics, labels, outputs)
|
276 |
+
if model.compiled_metrics:
|
277 |
+
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
|
278 |
+
logs.update({m.name: m.result() for m in metrics or []})
|
279 |
+
logs.update({m.name: m.result() for m in model.metrics})
|
280 |
+
return logs
|
281 |
+
|
282 |
+
def validation_step(self, inputs, model: tf_keras.Model, metrics=None):
|
283 |
+
"""Validation step.
|
284 |
+
|
285 |
+
With distribution strategies, this method runs on devices.
|
286 |
+
|
287 |
+
Args:
|
288 |
+
inputs: a dictionary of input tensors.
|
289 |
+
model: the keras.Model.
|
290 |
+
metrics: a nested structure of metrics objects.
|
291 |
+
|
292 |
+
Returns:
|
293 |
+
A dictionary of logs.
|
294 |
+
"""
|
295 |
+
if isinstance(inputs, tuple) and len(inputs) == 2:
|
296 |
+
features, labels = inputs
|
297 |
+
else:
|
298 |
+
features, labels = inputs, inputs
|
299 |
+
outputs = self.inference_step(features, model)
|
300 |
+
loss = self.build_losses(
|
301 |
+
labels=labels, model_outputs=outputs, aux_losses=model.losses)
|
302 |
+
logs = {self.loss: loss}
|
303 |
+
if metrics:
|
304 |
+
self.process_metrics(metrics, labels, outputs)
|
305 |
+
if model.compiled_metrics:
|
306 |
+
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
|
307 |
+
logs.update({m.name: m.result() for m in metrics or []})
|
308 |
+
logs.update({m.name: m.result() for m in model.metrics})
|
309 |
+
return logs
|
310 |
+
|
311 |
+
def inference_step(self, inputs, model: tf_keras.Model):
|
312 |
+
"""Performs the forward step.
|
313 |
+
|
314 |
+
With distribution strategies, this method runs on devices.
|
315 |
+
|
316 |
+
Args:
|
317 |
+
inputs: a dictionary of input tensors.
|
318 |
+
model: the keras.Model.
|
319 |
+
|
320 |
+
Returns:
|
321 |
+
Model outputs.
|
322 |
+
"""
|
323 |
+
return model(inputs, training=False)
|
324 |
+
|
325 |
+
def aggregate_logs(self, state, step_logs):
|
326 |
+
"""Optional aggregation over logs returned from a validation step.
|
327 |
+
|
328 |
+
Given step_logs from a validation step, this function aggregates the logs
|
329 |
+
after each eval_step() (see eval_reduce() function in
|
330 |
+
official/core/base_trainer.py). It runs on CPU and can be used to aggregate
|
331 |
+
metrics during validation, when there are too many metrics that cannot fit
|
332 |
+
into TPU memory. Note that this may increase latency due to data transfer
|
333 |
+
between TPU and CPU. Also, the step output from a validation step may be a
|
334 |
+
tuple with elements from replicas, and a concatenation of the elements is
|
335 |
+
needed in such case.
|
336 |
+
|
337 |
+
Args:
|
338 |
+
state: The current state of training, for example, it can be a sequence of
|
339 |
+
metrics.
|
340 |
+
step_logs: Logs from a validation step. Can be a dictionary.
|
341 |
+
"""
|
342 |
+
pass
|
343 |
+
|
344 |
+
def reduce_aggregated_logs(self,
|
345 |
+
aggregated_logs,
|
346 |
+
global_step: Optional[tf.Tensor] = None):
|
347 |
+
"""Optional reduce of aggregated logs over validation steps.
|
348 |
+
|
349 |
+
This function reduces aggregated logs at the end of validation, and can be
|
350 |
+
used to compute the final metrics. It runs on CPU and in each eval_end() in
|
351 |
+
base trainer (see eval_end() function in official/core/base_trainer.py).
|
352 |
+
|
353 |
+
Args:
|
354 |
+
aggregated_logs: Aggregated logs over multiple validation steps.
|
355 |
+
global_step: An optional variable of global step.
|
356 |
+
|
357 |
+
Returns:
|
358 |
+
A dictionary of reduced results.
|
359 |
+
"""
|
360 |
+
return {}
|
modeling/official/core/base_trainer.py
ADDED
@@ -0,0 +1,498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Standard Trainer implementation.
|
16 |
+
|
17 |
+
The base trainer implements the Orbit `StandardTrainable` and
|
18 |
+
`StandardEvaluable` interfaces. Trainers inside this project should be
|
19 |
+
interchangable and independent on model architectures and tasks.
|
20 |
+
"""
|
21 |
+
import functools
|
22 |
+
from typing import Union, Optional
|
23 |
+
from absl import logging
|
24 |
+
import gin
|
25 |
+
import orbit
|
26 |
+
import tensorflow as tf, tf_keras
|
27 |
+
|
28 |
+
from official.core import base_task
|
29 |
+
from official.core import config_definitions
|
30 |
+
from official.modeling import optimization
|
31 |
+
|
32 |
+
ExperimentConfig = config_definitions.ExperimentConfig
|
33 |
+
TrainerConfig = config_definitions.TrainerConfig
|
34 |
+
|
35 |
+
|
36 |
+
class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
|
37 |
+
"""Trainer class for both sync and async Strategy."""
|
38 |
+
|
39 |
+
def init_async(self):
|
40 |
+
"""Initializes the Async Trainer base class."""
|
41 |
+
assert isinstance(self._strategy, tf.distribute.Strategy)
|
42 |
+
self._is_async = isinstance(
|
43 |
+
self._strategy, tf.distribute.experimental.ParameterServerStrategy)
|
44 |
+
self._coordinator = None
|
45 |
+
if self._is_async:
|
46 |
+
self._coordinator = (
|
47 |
+
tf.distribute.experimental.coordinator.ClusterCoordinator(
|
48 |
+
self._strategy))
|
49 |
+
|
50 |
+
def coordinator_for_async(
|
51 |
+
self,
|
52 |
+
) -> tf.distribute.experimental.coordinator.ClusterCoordinator:
|
53 |
+
if not self._coordinator:
|
54 |
+
raise ValueError(
|
55 |
+
"Coordinator uninitialized for async run. Call init_async() first."
|
56 |
+
)
|
57 |
+
return self._coordinator
|
58 |
+
|
59 |
+
def join(self):
|
60 |
+
"""Join all async steps. Only useful in aysnc training."""
|
61 |
+
if getattr(self, "_is_async", False):
|
62 |
+
self.coordinator_for_async().join()
|
63 |
+
|
64 |
+
def create_train_loop_fn(self):
|
65 |
+
"""Creates a eval loop from the given step function and options."""
|
66 |
+
train_loop_fn = super().create_train_loop_fn()
|
67 |
+
if getattr(self, "_is_async", False):
|
68 |
+
|
69 |
+
def _async_loop_fn(iterator, num_steps):
|
70 |
+
self.coordinator_for_async().schedule(
|
71 |
+
train_loop_fn, args=(iterator, num_steps)
|
72 |
+
)
|
73 |
+
|
74 |
+
return _async_loop_fn
|
75 |
+
else:
|
76 |
+
return train_loop_fn
|
77 |
+
|
78 |
+
def create_eval_loop_fn(self, has_state: bool):
|
79 |
+
"""Creates a training loop from the given step function and options."""
|
80 |
+
eval_loop_fn = super().create_eval_loop_fn(has_state)
|
81 |
+
|
82 |
+
if getattr(self, "_is_async", False):
|
83 |
+
if has_state:
|
84 |
+
raise ValueError(
|
85 |
+
"Stateful eval loop is not supported in async training.")
|
86 |
+
|
87 |
+
def _async_loop_fn(iterator, num_steps, state=None, reduce_fn=None):
|
88 |
+
assert state is None
|
89 |
+
assert reduce_fn is None
|
90 |
+
self.coordinator_for_async().schedule(
|
91 |
+
eval_loop_fn, args=(iterator, num_steps)
|
92 |
+
)
|
93 |
+
|
94 |
+
return _async_loop_fn
|
95 |
+
else:
|
96 |
+
return eval_loop_fn
|
97 |
+
|
98 |
+
def distribute_dataset(self, dataset_or_fn, *args, **kwargs):
|
99 |
+
"""A utility function to help create a `tf.distribute.DistributedDataset`.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
dataset_or_fn: A instance of `tf.data.Dataset`, or a "dataset function"
|
103 |
+
returning a `tf.data.Dataset`. If it is a function, it may optionally
|
104 |
+
have an argument named `input_context` which will be passed a
|
105 |
+
`tf.distribute.InputContext` instance.
|
106 |
+
*args: Any positional arguments to pass through to `dataset_or_fn`.
|
107 |
+
**kwargs: Any keyword arguments to pass through to `dataset_or_fn`.
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
A distributed Dataset.
|
111 |
+
"""
|
112 |
+
if getattr(self, "_is_async", False):
|
113 |
+
per_worker_dataset_fn = functools.partial(
|
114 |
+
orbit.utils.make_distributed_dataset, self._strategy, dataset_or_fn,
|
115 |
+
*args, **kwargs)
|
116 |
+
per_worker_dataset_fn = tf.function(per_worker_dataset_fn)
|
117 |
+
|
118 |
+
return self.coordinator_for_async().create_per_worker_dataset(
|
119 |
+
per_worker_dataset_fn
|
120 |
+
)
|
121 |
+
else:
|
122 |
+
return orbit.utils.make_distributed_dataset(self._strategy, dataset_or_fn,
|
123 |
+
*args, **kwargs)
|
124 |
+
|
125 |
+
|
126 |
+
def get_runtime_options(config: ExperimentConfig):
|
127 |
+
"""Get tf.distribute.RunOptions from config."""
|
128 |
+
xla_options = {}
|
129 |
+
if config.runtime.tpu_enable_xla_dynamic_padder is not None:
|
130 |
+
xla_options["enable_xla_dynamic_padder"] = (
|
131 |
+
config.runtime.tpu_enable_xla_dynamic_padder)
|
132 |
+
return tf.distribute.RunOptions(
|
133 |
+
experimental_xla_options=tf.tpu.XLAOptions(**xla_options))
|
134 |
+
|
135 |
+
|
136 |
+
@gin.configurable
|
137 |
+
class Trainer(_AsyncTrainer):
|
138 |
+
"""Implements the common trainer shared for TensorFlow models."""
|
139 |
+
|
140 |
+
# pylint: disable=super-init-not-called
|
141 |
+
def __init__(
|
142 |
+
self,
|
143 |
+
config: ExperimentConfig,
|
144 |
+
task: base_task.Task,
|
145 |
+
model: tf_keras.Model,
|
146 |
+
optimizer: tf.optimizers.Optimizer,
|
147 |
+
train: bool = True,
|
148 |
+
evaluate: bool = True,
|
149 |
+
train_dataset: Optional[Union[tf.data.Dataset,
|
150 |
+
tf.distribute.DistributedDataset]] = None,
|
151 |
+
validation_dataset: Optional[Union[
|
152 |
+
tf.data.Dataset, tf.distribute.DistributedDataset]] = None,
|
153 |
+
checkpoint_exporter=None):
|
154 |
+
"""Initialize common trainer for TensorFlow models.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
config: An `ExperimentConfig` instance specifying experiment config.
|
158 |
+
task: A base_task.Task instance.
|
159 |
+
model: The model instance, e.g. a tf_keras.Model instance.
|
160 |
+
optimizer: tf.optimizers.Optimizer instance.
|
161 |
+
train: bool, whether or not this trainer will be used for training.
|
162 |
+
default to True.
|
163 |
+
evaluate: bool, whether or not this trainer will be used for evaluation.
|
164 |
+
default to True.
|
165 |
+
train_dataset: a dataset object created for training. With tf.distribute,
|
166 |
+
it needs to be a `DistributedDataset`.
|
167 |
+
validation_dataset: a dataset object created for evaluation. With
|
168 |
+
tf.distribute, it needs to be a `DistributedDataset`. The evaluator will
|
169 |
+
create a dataset iterator for each eval round, so the dataset does not
|
170 |
+
need to repeat.
|
171 |
+
checkpoint_exporter: an object that has the `maybe_export_checkpoint`
|
172 |
+
interface.
|
173 |
+
"""
|
174 |
+
# Gets the current distribution strategy. If not inside any strategy scope,
|
175 |
+
# it gets a single-replica no-op strategy.
|
176 |
+
self._strategy = tf.distribute.get_strategy()
|
177 |
+
self._validate_params(
|
178 |
+
config,
|
179 |
+
check_train_data=train_dataset is None,
|
180 |
+
check_validation_data=validation_dataset is None)
|
181 |
+
self._config = config
|
182 |
+
self._task = task
|
183 |
+
self._model = model
|
184 |
+
self._optimizer = optimizer
|
185 |
+
self._checkpoint_exporter = checkpoint_exporter
|
186 |
+
self._recovery = None
|
187 |
+
# Runtime options are only applied to train_step.
|
188 |
+
# We use default for eval_step.
|
189 |
+
self._runtime_options = get_runtime_options(config)
|
190 |
+
|
191 |
+
# Creates a shadow copy of the weights to store weights moving average.
|
192 |
+
if isinstance(self._optimizer, optimization.ExponentialMovingAverage
|
193 |
+
) and not self._optimizer.has_shadow_copy:
|
194 |
+
self._optimizer.shadow_copy(self._model)
|
195 |
+
|
196 |
+
# global_step increases by 1 after each training iteration.
|
197 |
+
# We should have global_step.numpy() == self.optimizer.iterations.numpy()
|
198 |
+
# when there is only 1 optimizer.
|
199 |
+
self._global_step = orbit.utils.create_global_step()
|
200 |
+
if hasattr(self.model, "checkpoint_items"):
|
201 |
+
checkpoint_items = self.model.checkpoint_items
|
202 |
+
else:
|
203 |
+
checkpoint_items = {}
|
204 |
+
self._checkpoint = tf.train.Checkpoint(
|
205 |
+
global_step=self.global_step,
|
206 |
+
model=self.model,
|
207 |
+
optimizer=self.optimizer,
|
208 |
+
**checkpoint_items)
|
209 |
+
|
210 |
+
self._train_loss = tf_keras.metrics.Mean("training_loss", dtype=tf.float32)
|
211 |
+
self._validation_loss = tf_keras.metrics.Mean(
|
212 |
+
"validation_loss", dtype=tf.float32)
|
213 |
+
model_metrics = model.metrics if hasattr(model, "metrics") else []
|
214 |
+
|
215 |
+
self.init_async()
|
216 |
+
|
217 |
+
if train:
|
218 |
+
self._train_metrics = self.task.build_metrics(
|
219 |
+
training=True) + model_metrics
|
220 |
+
train_dataset = train_dataset or self.distribute_dataset(
|
221 |
+
self.task.build_inputs, self.config.task.train_data)
|
222 |
+
orbit.StandardTrainer.__init__(
|
223 |
+
self,
|
224 |
+
train_dataset,
|
225 |
+
options=orbit.StandardTrainerOptions(
|
226 |
+
use_tf_while_loop=config.trainer.train_tf_while_loop,
|
227 |
+
use_tf_function=config.trainer.train_tf_function,
|
228 |
+
use_tpu_summary_optimization=config.trainer.allow_tpu_summary))
|
229 |
+
|
230 |
+
if evaluate:
|
231 |
+
self._validation_metrics = self.task.build_metrics(
|
232 |
+
training=False) + model_metrics
|
233 |
+
validation_dataset = validation_dataset or self.distribute_dataset(
|
234 |
+
self.task.build_inputs, self.config.task.validation_data)
|
235 |
+
orbit.StandardEvaluator.__init__(
|
236 |
+
self,
|
237 |
+
validation_dataset,
|
238 |
+
options=orbit.StandardEvaluatorOptions(
|
239 |
+
use_tf_function=config.trainer.eval_tf_function,
|
240 |
+
use_tf_while_loop=config.trainer.eval_tf_while_loop))
|
241 |
+
|
242 |
+
def _validate_params(self,
|
243 |
+
config,
|
244 |
+
check_train_data=True,
|
245 |
+
check_validation_data=True):
|
246 |
+
r"""Validates if the configuration object passed to the Trainer.
|
247 |
+
|
248 |
+
The experiment configuration should be structured as:
|
249 |
+
\trainer
|
250 |
+
\task
|
251 |
+
\train_data
|
252 |
+
\validation_data
|
253 |
+
|
254 |
+
Args:
|
255 |
+
config: a namedtuple, dataclass, ConfigDict, etc.
|
256 |
+
check_train_data: whether to check task.train_data field.
|
257 |
+
check_validation_data: whether to check task.validation_data field.
|
258 |
+
"""
|
259 |
+
if not hasattr(config, "trainer"):
|
260 |
+
raise AttributeError("The trainer requires the configuration contains an"
|
261 |
+
" attribute `trainer`.")
|
262 |
+
|
263 |
+
if not hasattr(config, "task"):
|
264 |
+
raise AttributeError("The trainer requires the configuration contains an"
|
265 |
+
" attribute `task`.")
|
266 |
+
|
267 |
+
if check_train_data and not hasattr(config.task, "train_data"):
|
268 |
+
raise AttributeError("The trainer requires the configuration contains an"
|
269 |
+
" attribute `task.train_data`.")
|
270 |
+
|
271 |
+
if check_validation_data and not hasattr(config.task, "validation_data"):
|
272 |
+
raise AttributeError("The trainer requires the configuration contains an"
|
273 |
+
" attribute `task.validation_data`.")
|
274 |
+
|
275 |
+
@property
|
276 |
+
def strategy(self):
|
277 |
+
return self._strategy
|
278 |
+
|
279 |
+
@property
|
280 |
+
def config(self):
|
281 |
+
return self._config
|
282 |
+
|
283 |
+
@property
|
284 |
+
def task(self):
|
285 |
+
return self._task
|
286 |
+
|
287 |
+
@property
|
288 |
+
def model(self):
|
289 |
+
return self._model
|
290 |
+
|
291 |
+
@property
|
292 |
+
def optimizer(self):
|
293 |
+
if hasattr(self, "_optimizer"):
|
294 |
+
return self._optimizer
|
295 |
+
else:
|
296 |
+
return None
|
297 |
+
|
298 |
+
@property
|
299 |
+
def global_step(self):
|
300 |
+
return self._global_step
|
301 |
+
|
302 |
+
@property
|
303 |
+
def train_loss(self):
|
304 |
+
"""Accesses the training loss metric object."""
|
305 |
+
return self._train_loss
|
306 |
+
|
307 |
+
@property
|
308 |
+
def validation_loss(self):
|
309 |
+
"""Accesses the validation loss metric object."""
|
310 |
+
return self._validation_loss
|
311 |
+
|
312 |
+
@property
|
313 |
+
def train_metrics(self):
|
314 |
+
"""Accesses all training metric objects."""
|
315 |
+
return self._train_metrics
|
316 |
+
|
317 |
+
@property
|
318 |
+
def validation_metrics(self):
|
319 |
+
"""Accesses all validation metric metric objects."""
|
320 |
+
return self._validation_metrics
|
321 |
+
|
322 |
+
def initialize(self):
|
323 |
+
"""A callback function.
|
324 |
+
|
325 |
+
This function will be called when no checkpoint found for the model.
|
326 |
+
If there is a checkpoint, the checkpoint will be loaded and this function
|
327 |
+
will not be called. Tasks may use this callback function to load a
|
328 |
+
pretrained checkpoint, saved under a directory other than the model_dir.
|
329 |
+
"""
|
330 |
+
self.task.initialize(self.model)
|
331 |
+
|
332 |
+
@property
|
333 |
+
def checkpoint(self):
|
334 |
+
"""Accesses the training checkpoint."""
|
335 |
+
return self._checkpoint
|
336 |
+
|
337 |
+
@property
|
338 |
+
def checkpoint_exporter(self):
|
339 |
+
"""Accesses the checkpoint exporter."""
|
340 |
+
return self._checkpoint_exporter
|
341 |
+
|
342 |
+
def train_loop_end(self):
|
343 |
+
"""See base class."""
|
344 |
+
self.join()
|
345 |
+
logs = {}
|
346 |
+
for metric in self.train_metrics + [self.train_loss]:
|
347 |
+
logs[metric.name] = metric.result()
|
348 |
+
metric.reset_states()
|
349 |
+
if callable(self.optimizer.learning_rate):
|
350 |
+
# Maybe a self-implemented optimizer does not have `optimizer.iterations`.
|
351 |
+
# So just to be safe here.
|
352 |
+
if hasattr(self.optimizer, "iterations"):
|
353 |
+
logs["learning_rate"] = self.optimizer.learning_rate(
|
354 |
+
self.optimizer.iterations)
|
355 |
+
else:
|
356 |
+
logs["learning_rate"] = self.optimizer.learning_rate(self.global_step)
|
357 |
+
else:
|
358 |
+
logs["learning_rate"] = self.optimizer.learning_rate
|
359 |
+
return logs
|
360 |
+
|
361 |
+
def next_train_inputs(self, iterator):
|
362 |
+
"""Fetches the next inputs for the model during train.
|
363 |
+
|
364 |
+
This method consumes the input iterator and returns the next inputs for the
|
365 |
+
model.
|
366 |
+
|
367 |
+
This method provides a way to control how to fetch the next model input, and
|
368 |
+
what data to send to the model.
|
369 |
+
|
370 |
+
Note: This function runs on the host side when accelerators are used.
|
371 |
+
|
372 |
+
Note: Depending on the training setup this may or may not run in eager mode.
|
373 |
+
In most cases it will be run in graph mode.
|
374 |
+
|
375 |
+
Args:
|
376 |
+
iterator: Dataset iterator to generate the next inputs from.
|
377 |
+
|
378 |
+
Returns:
|
379 |
+
The inputs to the model.
|
380 |
+
"""
|
381 |
+
return next(iterator)
|
382 |
+
|
383 |
+
def train_step(self, iterator):
|
384 |
+
"""See base class."""
|
385 |
+
|
386 |
+
def step_fn(inputs):
|
387 |
+
if self.config.runtime.enable_xla and (self.config.runtime.num_gpus > 0):
|
388 |
+
task_train_step = tf.function(self.task.train_step, jit_compile=True)
|
389 |
+
else:
|
390 |
+
task_train_step = self.task.train_step
|
391 |
+
logs = task_train_step(
|
392 |
+
inputs,
|
393 |
+
model=self.model,
|
394 |
+
optimizer=self.optimizer,
|
395 |
+
metrics=self.train_metrics)
|
396 |
+
self._train_loss.update_state(logs[self.task.loss])
|
397 |
+
self.global_step.assign_add(1)
|
398 |
+
|
399 |
+
inputs = self.next_train_inputs(iterator)
|
400 |
+
self.strategy.run(step_fn, args=(inputs,), options=self._runtime_options)
|
401 |
+
|
402 |
+
def eval_begin(self):
|
403 |
+
"""Sets up metrics."""
|
404 |
+
for metric in self.validation_metrics + [self.validation_loss]:
|
405 |
+
metric.reset_states()
|
406 |
+
# Swaps weights to test on weights moving average.
|
407 |
+
if self.optimizer and isinstance(self.optimizer,
|
408 |
+
optimization.ExponentialMovingAverage):
|
409 |
+
self.optimizer.swap_weights()
|
410 |
+
|
411 |
+
def next_eval_inputs(self, iterator):
|
412 |
+
"""Fetches the next inputs for the model during eval.
|
413 |
+
|
414 |
+
This method consumes the input iterator and returns the next inputs for the
|
415 |
+
model and an additional logs dict. The output dict remains in the host (not
|
416 |
+
sent to GPUs/TPUs) and is merged with the model outputs which will be
|
417 |
+
processed later in `aggregate_logs`. This is useful for sending extra logs
|
418 |
+
downstream that are not compatible with the accelerators.
|
419 |
+
|
420 |
+
Note: This function runs on the host side when accelerators are used.
|
421 |
+
|
422 |
+
Note: Depending on the training setup this may or may not run in eager mode.
|
423 |
+
In most cases it will be run in graph mode.
|
424 |
+
|
425 |
+
Args:
|
426 |
+
iterator: Dataset iterator to generate the next inputs from.
|
427 |
+
|
428 |
+
Returns:
|
429 |
+
The inputs to the model, and an additional logs dictionnary. The logs
|
430 |
+
are not passed to the model, instead they are merged with model output
|
431 |
+
logs.
|
432 |
+
"""
|
433 |
+
passthrough_logs = dict()
|
434 |
+
return next(iterator), passthrough_logs
|
435 |
+
|
436 |
+
def eval_step(self, iterator):
|
437 |
+
"""See base class."""
|
438 |
+
|
439 |
+
def step_fn(inputs):
|
440 |
+
logs = self.task.validation_step(
|
441 |
+
inputs, model=self.model, metrics=self.validation_metrics)
|
442 |
+
if self.task.loss in logs:
|
443 |
+
self._validation_loss.update_state(logs[self.task.loss])
|
444 |
+
return logs
|
445 |
+
|
446 |
+
inputs, passthrough_logs = self.next_eval_inputs(iterator)
|
447 |
+
distributed_outputs = self.strategy.run(step_fn, args=(inputs,))
|
448 |
+
logs = tf.nest.map_structure(
|
449 |
+
self.strategy.experimental_local_results, distributed_outputs
|
450 |
+
)
|
451 |
+
|
452 |
+
if set(logs.keys()) & set(passthrough_logs.keys()):
|
453 |
+
logging.warning(
|
454 |
+
(
|
455 |
+
"Conflict between the pasthrough log keys and the returned model"
|
456 |
+
" log keys. Found %r keys in the passthrough logs and %r keys in"
|
457 |
+
" the model logs. Model log keys takes precedence."
|
458 |
+
),
|
459 |
+
logs.keys(),
|
460 |
+
passthrough_logs.keys(),
|
461 |
+
)
|
462 |
+
|
463 |
+
return passthrough_logs | logs
|
464 |
+
|
465 |
+
def eval_end(self, aggregated_logs=None):
|
466 |
+
"""Processes evaluation results."""
|
467 |
+
self.join()
|
468 |
+
logs = {}
|
469 |
+
for metric in self.validation_metrics:
|
470 |
+
logs[metric.name] = metric.result()
|
471 |
+
if self.validation_loss.count.numpy() != 0:
|
472 |
+
logs[self.validation_loss.name] = self.validation_loss.result()
|
473 |
+
else:
|
474 |
+
# `self.validation_loss` metric was not updated, because the validation
|
475 |
+
# loss was not returned from the task's `validation_step` method.
|
476 |
+
logging.info("The task did not report validation loss.")
|
477 |
+
if aggregated_logs:
|
478 |
+
metrics = self.task.reduce_aggregated_logs(
|
479 |
+
aggregated_logs, global_step=self.global_step)
|
480 |
+
logs.update(metrics)
|
481 |
+
|
482 |
+
if self._checkpoint_exporter:
|
483 |
+
self._checkpoint_exporter.maybe_export_checkpoint(
|
484 |
+
self.checkpoint, logs, self.global_step.numpy())
|
485 |
+
metric_name = self.config.trainer.best_checkpoint_eval_metric
|
486 |
+
logs["best_" +
|
487 |
+
metric_name] = self._checkpoint_exporter.best_ckpt_logs[metric_name]
|
488 |
+
|
489 |
+
# Swaps back weights after testing when EMA is used.
|
490 |
+
# This happens after best checkpoint export so that average weights used for
|
491 |
+
# eval are exported instead of regular weights.
|
492 |
+
if self.optimizer and isinstance(self.optimizer,
|
493 |
+
optimization.ExponentialMovingAverage):
|
494 |
+
self.optimizer.swap_weights()
|
495 |
+
return logs
|
496 |
+
|
497 |
+
def eval_reduce(self, state=None, step_outputs=None):
|
498 |
+
return self.task.aggregate_logs(state, step_outputs)
|
modeling/official/core/base_trainer_test.py
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Tests for tensorflow_models.core.trainers.trainer."""
|
16 |
+
# pylint: disable=g-direct-tensorflow-import
|
17 |
+
import gc
|
18 |
+
import multiprocessing
|
19 |
+
import os
|
20 |
+
import sys
|
21 |
+
|
22 |
+
from absl.testing import parameterized
|
23 |
+
import orbit
|
24 |
+
import portpicker
|
25 |
+
import tensorflow as tf, tf_keras
|
26 |
+
|
27 |
+
from tensorflow.python.distribute import combinations
|
28 |
+
from tensorflow.python.distribute import strategy_combinations
|
29 |
+
from official.core import base_trainer as trainer_lib
|
30 |
+
from official.core import config_definitions as cfg
|
31 |
+
from official.core import train_lib
|
32 |
+
from official.utils.testing import mock_task
|
33 |
+
|
34 |
+
TPU_TEST = 'test_tpu' in sys.argv[0]
|
35 |
+
GPU_TEST = 'test_gpu' in sys.argv[0]
|
36 |
+
|
37 |
+
|
38 |
+
def all_strategy_combinations():
|
39 |
+
return combinations.combine(
|
40 |
+
distribution=[
|
41 |
+
strategy_combinations.default_strategy,
|
42 |
+
strategy_combinations.cloud_tpu_strategy,
|
43 |
+
strategy_combinations.one_device_strategy_gpu,
|
44 |
+
],)
|
45 |
+
|
46 |
+
|
47 |
+
def create_in_process_cluster(num_workers, num_ps):
|
48 |
+
"""Creates and starts local servers and returns the cluster_resolver."""
|
49 |
+
worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
|
50 |
+
ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
|
51 |
+
|
52 |
+
cluster_dict = {}
|
53 |
+
cluster_dict['worker'] = ['localhost:%s' % port for port in worker_ports]
|
54 |
+
if num_ps > 0:
|
55 |
+
cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
|
56 |
+
|
57 |
+
cluster_spec = tf.train.ClusterSpec(cluster_dict)
|
58 |
+
|
59 |
+
# Workers need some inter_ops threads to work properly.
|
60 |
+
worker_config = tf.compat.v1.ConfigProto()
|
61 |
+
if multiprocessing.cpu_count() < num_workers + 1:
|
62 |
+
worker_config.inter_op_parallelism_threads = num_workers + 1
|
63 |
+
|
64 |
+
for i in range(num_workers):
|
65 |
+
tf.distribute.Server(
|
66 |
+
cluster_spec,
|
67 |
+
job_name='worker',
|
68 |
+
task_index=i,
|
69 |
+
config=worker_config,
|
70 |
+
protocol='grpc')
|
71 |
+
|
72 |
+
for i in range(num_ps):
|
73 |
+
tf.distribute.Server(
|
74 |
+
cluster_spec, job_name='ps', task_index=i, protocol='grpc')
|
75 |
+
|
76 |
+
cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
|
77 |
+
cluster_spec, rpc_layer='grpc')
|
78 |
+
return cluster_resolver
|
79 |
+
|
80 |
+
|
81 |
+
def dataset_fn(input_context=None):
|
82 |
+
del input_context
|
83 |
+
|
84 |
+
def dummy_data(_):
|
85 |
+
return tf.zeros((1, 1), dtype=tf.float32)
|
86 |
+
|
87 |
+
dataset = tf.data.Dataset.range(1)
|
88 |
+
dataset = dataset.repeat()
|
89 |
+
dataset = dataset.map(
|
90 |
+
dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
91 |
+
return dataset
|
92 |
+
|
93 |
+
|
94 |
+
class MockAsyncTrainer(trainer_lib._AsyncTrainer):
|
95 |
+
"""Mock AsyncTrainer to test the _AsyncTrainer class."""
|
96 |
+
|
97 |
+
def __init__(self):
|
98 |
+
self._strategy = tf.distribute.get_strategy()
|
99 |
+
self.init_async()
|
100 |
+
|
101 |
+
self.global_step = tf.Variable(
|
102 |
+
0,
|
103 |
+
dtype=tf.int64,
|
104 |
+
name='global_step',
|
105 |
+
trainable=False,
|
106 |
+
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
|
107 |
+
self.eval_global_step = tf.Variable(
|
108 |
+
0,
|
109 |
+
dtype=tf.int64,
|
110 |
+
name='eval_global_step',
|
111 |
+
trainable=False,
|
112 |
+
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
|
113 |
+
|
114 |
+
train_dataset = self.distribute_dataset(dataset_fn)
|
115 |
+
orbit.StandardTrainer.__init__(
|
116 |
+
self, train_dataset, options=orbit.StandardTrainerOptions())
|
117 |
+
|
118 |
+
validation_dataset = self.distribute_dataset(dataset_fn)
|
119 |
+
orbit.StandardEvaluator.__init__(
|
120 |
+
self,
|
121 |
+
validation_dataset,
|
122 |
+
options=orbit.StandardEvaluatorOptions(use_tf_while_loop=True))
|
123 |
+
|
124 |
+
def train_loop_begin(self):
|
125 |
+
self.global_step.assign(0)
|
126 |
+
|
127 |
+
def train_step(self, iterator):
|
128 |
+
|
129 |
+
def replica_step(_):
|
130 |
+
self.global_step.assign_add(1)
|
131 |
+
|
132 |
+
self._strategy.run(replica_step, args=(next(iterator),))
|
133 |
+
|
134 |
+
def train_loop_end(self):
|
135 |
+
self.join()
|
136 |
+
return self.global_step.numpy()
|
137 |
+
|
138 |
+
def eval_begin(self):
|
139 |
+
self.eval_global_step.assign(0)
|
140 |
+
|
141 |
+
def eval_step(self, iterator):
|
142 |
+
|
143 |
+
def replica_step(_):
|
144 |
+
self.eval_global_step.assign_add(1)
|
145 |
+
|
146 |
+
self._strategy.run(replica_step, args=(next(iterator),))
|
147 |
+
|
148 |
+
def eval_end(self):
|
149 |
+
self.join()
|
150 |
+
return self.eval_global_step.numpy()
|
151 |
+
|
152 |
+
|
153 |
+
class TrainerTest(tf.test.TestCase, parameterized.TestCase):
|
154 |
+
|
155 |
+
def setUp(self):
|
156 |
+
super().setUp()
|
157 |
+
self._config = cfg.ExperimentConfig(
|
158 |
+
trainer=cfg.TrainerConfig(
|
159 |
+
optimizer_config=cfg.OptimizationConfig({
|
160 |
+
'optimizer': {
|
161 |
+
'type': 'sgd'
|
162 |
+
},
|
163 |
+
'learning_rate': {
|
164 |
+
'type': 'constant'
|
165 |
+
}
|
166 |
+
})))
|
167 |
+
|
168 |
+
def tearDown(self):
|
169 |
+
gc.collect()
|
170 |
+
# This will only contain uncollectable garbage, i.e. reference cycles
|
171 |
+
# involving objects with __del__ defined.
|
172 |
+
self.assertEmpty(gc.garbage)
|
173 |
+
super().tearDown()
|
174 |
+
|
175 |
+
def create_test_trainer(self, config, model_dir=None, task=None):
|
176 |
+
task = task or mock_task.MockTask(config.task, logging_dir=model_dir)
|
177 |
+
ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
|
178 |
+
trainer = trainer_lib.Trainer(
|
179 |
+
config,
|
180 |
+
task,
|
181 |
+
model=task.build_model(),
|
182 |
+
optimizer=task.create_optimizer(config.trainer.optimizer_config,
|
183 |
+
config.runtime),
|
184 |
+
checkpoint_exporter=ckpt_exporter)
|
185 |
+
return trainer
|
186 |
+
|
187 |
+
@combinations.generate(all_strategy_combinations())
|
188 |
+
def test_trainer_train(self, distribution):
|
189 |
+
with distribution.scope():
|
190 |
+
trainer = self.create_test_trainer(self._config)
|
191 |
+
logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
|
192 |
+
self.assertIn('training_loss', logs)
|
193 |
+
self.assertIn('learning_rate', logs)
|
194 |
+
|
195 |
+
@combinations.generate(all_strategy_combinations())
|
196 |
+
def test_trainer_passing_datasets(self, distribution):
|
197 |
+
with distribution.scope():
|
198 |
+
task = mock_task.MockTask(self._config)
|
199 |
+
train_dataset = orbit.utils.make_distributed_dataset(
|
200 |
+
distribution, task.build_inputs, self._config.task.train_data)
|
201 |
+
validation_dataset = orbit.utils.make_distributed_dataset(
|
202 |
+
distribution, task.build_inputs, self._config.task.validation_data)
|
203 |
+
self._config.task.train_data = None
|
204 |
+
self._config.task.validation_data = None
|
205 |
+
trainer = trainer_lib.Trainer(
|
206 |
+
self._config,
|
207 |
+
task,
|
208 |
+
model=task.build_model(),
|
209 |
+
optimizer=task.create_optimizer(self._config.trainer.optimizer_config,
|
210 |
+
self._config.runtime),
|
211 |
+
train_dataset=train_dataset,
|
212 |
+
validation_dataset=validation_dataset)
|
213 |
+
logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
|
214 |
+
self.assertIn('training_loss', logs)
|
215 |
+
self.assertIn('learning_rate', logs)
|
216 |
+
logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
|
217 |
+
self.assertIn('validation_loss', logs)
|
218 |
+
|
219 |
+
def test_base_async_trainer(self):
|
220 |
+
if TPU_TEST or GPU_TEST:
|
221 |
+
self.skipTest('Aysnc training is not available on GPU/GPU.')
|
222 |
+
num_workers = 3
|
223 |
+
num_ps = 2
|
224 |
+
cluster_resolver = create_in_process_cluster(num_workers, num_ps)
|
225 |
+
distribution = tf.distribute.experimental.ParameterServerStrategy(
|
226 |
+
cluster_resolver)
|
227 |
+
with distribution.scope():
|
228 |
+
trainer = MockAsyncTrainer()
|
229 |
+
trainer.init_async()
|
230 |
+
self.assertIsInstance(
|
231 |
+
trainer._coordinator,
|
232 |
+
tf.distribute.experimental.coordinator.ClusterCoordinator)
|
233 |
+
self.assertEqual(trainer.train(tf.constant(10)), 10)
|
234 |
+
self.assertEqual(trainer.evaluate(tf.constant(11)), 11)
|
235 |
+
|
236 |
+
def test_async_trainer_train(self):
|
237 |
+
if TPU_TEST or GPU_TEST:
|
238 |
+
self.skipTest('Aysnc training is not available on GPU/TPU.')
|
239 |
+
num_workers = 3
|
240 |
+
num_ps = 2
|
241 |
+
cluster_resolver = create_in_process_cluster(num_workers, num_ps)
|
242 |
+
distribution = tf.distribute.experimental.ParameterServerStrategy(
|
243 |
+
cluster_resolver)
|
244 |
+
with distribution.scope():
|
245 |
+
config = cfg.ExperimentConfig(**self._config.as_dict())
|
246 |
+
config.trainer.eval_tf_while_loop = True
|
247 |
+
trainer = self.create_test_trainer(config)
|
248 |
+
logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
|
249 |
+
self.assertIn('training_loss', logs)
|
250 |
+
self.assertIn('learning_rate', logs)
|
251 |
+
|
252 |
+
def test_async_trainer_validate(self):
|
253 |
+
if TPU_TEST or GPU_TEST:
|
254 |
+
self.skipTest('Aysnc training is not available on GPU/GPU.')
|
255 |
+
num_workers = 3
|
256 |
+
num_ps = 2
|
257 |
+
cluster_resolver = create_in_process_cluster(num_workers, num_ps)
|
258 |
+
distribution = tf.distribute.experimental.ParameterServerStrategy(
|
259 |
+
cluster_resolver)
|
260 |
+
with distribution.scope():
|
261 |
+
config = cfg.ExperimentConfig(**self._config.as_dict())
|
262 |
+
config.trainer.eval_tf_while_loop = True
|
263 |
+
trainer = self.create_test_trainer(config)
|
264 |
+
logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
|
265 |
+
self.assertIn('acc', logs)
|
266 |
+
self.assertIn('validation_loss', logs)
|
267 |
+
|
268 |
+
@combinations.generate(all_strategy_combinations())
|
269 |
+
def test_trainer_validate(self, distribution):
|
270 |
+
with distribution.scope():
|
271 |
+
trainer = self.create_test_trainer(self._config)
|
272 |
+
logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
|
273 |
+
self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
|
274 |
+
self.assertIn('validation_loss', logs)
|
275 |
+
|
276 |
+
@combinations.generate(all_strategy_combinations())
|
277 |
+
def test_trainer_validate_without_loss(self, distribution):
|
278 |
+
|
279 |
+
class MockTaskWithoutValidationLoss(mock_task.MockTask):
|
280 |
+
|
281 |
+
def validation_step(self, inputs, model, metrics=None):
|
282 |
+
# Disable validation loss.
|
283 |
+
logs = super().validation_step(inputs, model)
|
284 |
+
del logs[self.loss]
|
285 |
+
return logs
|
286 |
+
|
287 |
+
with distribution.scope():
|
288 |
+
task = MockTaskWithoutValidationLoss()
|
289 |
+
trainer = self.create_test_trainer(self._config, task=task)
|
290 |
+
logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
|
291 |
+
self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
|
292 |
+
self.assertNotIn('validation_loss', logs)
|
293 |
+
|
294 |
+
@combinations.generate(
|
295 |
+
combinations.combine(
|
296 |
+
mixed_precision_dtype=['float32', 'bfloat16', 'float16'],
|
297 |
+
loss_scale=[None, 'dynamic', 128, 256],
|
298 |
+
))
|
299 |
+
def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
|
300 |
+
config = cfg.ExperimentConfig(
|
301 |
+
runtime=cfg.RuntimeConfig(
|
302 |
+
mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
|
303 |
+
trainer=cfg.TrainerConfig(
|
304 |
+
optimizer_config=cfg.OptimizationConfig({
|
305 |
+
'optimizer': {
|
306 |
+
'type': 'sgd'
|
307 |
+
},
|
308 |
+
'learning_rate': {
|
309 |
+
'type': 'constant'
|
310 |
+
},
|
311 |
+
})))
|
312 |
+
trainer = self.create_test_trainer(config)
|
313 |
+
if mixed_precision_dtype == 'float16':
|
314 |
+
self.assertIsInstance(trainer.optimizer,
|
315 |
+
tf_keras.mixed_precision.LossScaleOptimizer)
|
316 |
+
if loss_scale in (None, 'dynamic'):
|
317 |
+
self.assertTrue(trainer.optimizer.dynamic)
|
318 |
+
else:
|
319 |
+
self.assertFalse(trainer.optimizer.dynamic)
|
320 |
+
self.assertEqual(trainer.optimizer.initial_scale, loss_scale)
|
321 |
+
else:
|
322 |
+
self.assertIsInstance(
|
323 |
+
trainer.optimizer,
|
324 |
+
(tf_keras.optimizers.SGD, tf_keras.optimizers.legacy.SGD))
|
325 |
+
|
326 |
+
metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
|
327 |
+
self.assertIn('training_loss', metrics)
|
328 |
+
|
329 |
+
def test_export_best_ckpt(self):
|
330 |
+
config = cfg.ExperimentConfig(
|
331 |
+
trainer=cfg.TrainerConfig(
|
332 |
+
best_checkpoint_export_subdir='best_ckpt',
|
333 |
+
best_checkpoint_eval_metric='acc',
|
334 |
+
optimizer_config=cfg.OptimizationConfig({
|
335 |
+
'optimizer': {
|
336 |
+
'type': 'sgd'
|
337 |
+
},
|
338 |
+
'learning_rate': {
|
339 |
+
'type': 'constant'
|
340 |
+
}
|
341 |
+
})))
|
342 |
+
model_dir = self.get_temp_dir()
|
343 |
+
trainer = self.create_test_trainer(config, model_dir=model_dir)
|
344 |
+
trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
|
345 |
+
trainer.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
|
346 |
+
self.assertTrue(
|
347 |
+
tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))
|
348 |
+
|
349 |
+
def test_model_with_compiled_loss(self):
|
350 |
+
task = mock_task.MockTask()
|
351 |
+
model = task.build_model()
|
352 |
+
model.compile(loss=tf_keras.losses.CategoricalCrossentropy())
|
353 |
+
trainer = trainer_lib.Trainer(
|
354 |
+
self._config,
|
355 |
+
task,
|
356 |
+
model=model,
|
357 |
+
optimizer=task.create_optimizer(self._config.trainer.optimizer_config))
|
358 |
+
logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
|
359 |
+
self.assertIn('training_loss', logs)
|
360 |
+
|
361 |
+
|
362 |
+
if __name__ == '__main__':
|
363 |
+
tf.test.main()
|
modeling/official/core/config_definitions.py
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Common configuration settings."""
|
16 |
+
|
17 |
+
import dataclasses
|
18 |
+
from typing import Optional, Sequence, Union
|
19 |
+
|
20 |
+
from official.modeling.hyperparams import base_config
|
21 |
+
from official.modeling.optimization.configs import optimization_config
|
22 |
+
from official.modeling.privacy import configs as dp_configs
|
23 |
+
|
24 |
+
OptimizationConfig = optimization_config.OptimizationConfig
|
25 |
+
|
26 |
+
|
27 |
+
@dataclasses.dataclass
|
28 |
+
class DataConfig(base_config.Config):
|
29 |
+
"""The base configuration for building datasets.
|
30 |
+
|
31 |
+
Attributes:
|
32 |
+
input_path: The path to the input. It can be either (1) a str indicating a
|
33 |
+
file path/pattern, or (2) a str indicating multiple file paths/patterns
|
34 |
+
separated by comma (e.g "a, b, c" or no spaces "a,b,c"), or (3) a list of
|
35 |
+
str, each of which is a file path/pattern or multiple file paths/patterns
|
36 |
+
separated by comma, or (4) a dictionary of the previous three approaches
|
37 |
+
for more advanced data mixing using named access. It should not be
|
38 |
+
specified when the following `tfds_name` is specified.
|
39 |
+
tfds_name: The name of the tensorflow dataset (TFDS). It should not be
|
40 |
+
specified when the above `input_path` is specified.
|
41 |
+
tfds_split: A str indicating which split of the data to load from TFDS. It
|
42 |
+
is required when above `tfds_name` is specified.
|
43 |
+
global_batch_size: The global batch size across all replicas.
|
44 |
+
is_training: Whether this data is used for training or not. This flag is
|
45 |
+
useful for consumers of this object to determine whether the data should
|
46 |
+
be repeated or shuffled.
|
47 |
+
drop_remainder: Whether the last batch should be dropped in the case it has
|
48 |
+
fewer than `global_batch_size` elements.
|
49 |
+
shuffle_buffer_size: The buffer size used for shuffling training data.
|
50 |
+
cache: Whether to cache dataset examples. If `True`, we will cache the
|
51 |
+
dataset after applying the decode_fn and parse_fn. It can be used to avoid
|
52 |
+
re-reading from disk, re-decoding and re-parsing the example on the second
|
53 |
+
epoch, but it requires significant memory overhead.
|
54 |
+
cycle_length: The number of files that will be processed concurrently when
|
55 |
+
interleaving files.
|
56 |
+
block_length: The number of consecutive elements to produce from each input
|
57 |
+
element before cycling to another input element when interleaving files.
|
58 |
+
deterministic: A boolean controlling whether determinism should be enforced.
|
59 |
+
sharding: Whether sharding is used in the input pipeline.
|
60 |
+
enable_tf_data_service: A boolean indicating whether to enable tf.data
|
61 |
+
service for the input pipeline.
|
62 |
+
tf_data_service_address: The URI of a tf.data service to offload
|
63 |
+
preprocessing onto during training. The URI should be in the format
|
64 |
+
"protocol://address", e.g. "grpc://tf-data-service:5050". It can be
|
65 |
+
overridden by `FLAGS.tf_data_service` flag in the binary.
|
66 |
+
tf_data_service_job_name: The name of the tf.data service job. This argument
|
67 |
+
makes it possible for multiple datasets to share the same job. The default
|
68 |
+
behavior is that the dataset creates anonymous, exclusively owned jobs.
|
69 |
+
tfds_data_dir: A str specifying the directory to read/write TFDS data.
|
70 |
+
tfds_as_supervised: A bool. When loading dataset from TFDS, if True, the
|
71 |
+
returned tf.data.Dataset will have a 2-tuple structure (input, label)
|
72 |
+
according to builder.info.supervised_keys; if False, the default, the
|
73 |
+
returned tf.data.Dataset will have a dictionary with all the features.
|
74 |
+
tfds_skip_decoding_feature: A str to indicate which features are skipped for
|
75 |
+
decoding when loading dataset from TFDS. Use comma to separate multiple
|
76 |
+
features. The main use case is to skip the image/video decoding for better
|
77 |
+
performance.
|
78 |
+
enable_shared_tf_data_service_between_parallel_trainers: A bool. When set to
|
79 |
+
true, only a single tf.data service will be started, and it will be shared
|
80 |
+
between all the trainer run simultaneously, e.g. using vizier to tune
|
81 |
+
hyperparameters. This will save CPU and RAM resources compared to running
|
82 |
+
separate tf.data service for each trainer. Notice that if batch size is
|
83 |
+
different for different trainers, the field
|
84 |
+
apply_tf_data_service_before_batching also needs to be true so that only a
|
85 |
+
single tf.data service instance will be created. In this case, tf.data
|
86 |
+
service will be applied before batching operation. So make sure to not
|
87 |
+
apply any processing steps after batching (e.g. in postprocess_fn) since
|
88 |
+
they wouldn't be paralleled by tf.data service and may slow down your
|
89 |
+
tf.data pipeline. When using shared tf.data service, the tf.data dataset
|
90 |
+
must be infinite, and slow trainer may skip certain training examples.
|
91 |
+
More details about shared tf.data service can be found at:
|
92 |
+
https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers.
|
93 |
+
apply_tf_data_service_before_batching: A bool. If set to True, tf.data
|
94 |
+
service will be applied before batching operation. This is useful to make
|
95 |
+
sure only a single tf.data service instance is created when
|
96 |
+
enable_shared_tf_data_service_between_parallel_trainers is true and batch
|
97 |
+
size is changing between parallel trainers.
|
98 |
+
trainer_id: A string. The id of the trainer if there are multiple parallel
|
99 |
+
trainer running at the same time, e.g. in vizier tuning case. It will be
|
100 |
+
automatically set if this field is needed. Users does not need to set it
|
101 |
+
when creating experiment configs.
|
102 |
+
seed: An optional seed to use for deterministic shuffling/preprocessing.
|
103 |
+
prefetch_buffer_size: An int specifying the buffer size of prefetch
|
104 |
+
datasets. If None, the buffer size is autotuned. Specifying this is useful
|
105 |
+
in case autotuning uses up too much memory by making the buffer size too
|
106 |
+
high.
|
107 |
+
autotune_algorithm: If specified, use this algorithm for AUTOTUNE. See:
|
108 |
+
https://www.tensorflow.org/api_docs/python/tf/data/experimental/AutotuneAlgorithm
|
109 |
+
"""
|
110 |
+
input_path: Union[Sequence[str], str, base_config.Config] = ""
|
111 |
+
tfds_name: Union[str, base_config.Config] = ""
|
112 |
+
tfds_split: str = ""
|
113 |
+
global_batch_size: int = 0
|
114 |
+
is_training: Optional[bool] = None
|
115 |
+
drop_remainder: bool = True
|
116 |
+
shuffle_buffer_size: int = 100
|
117 |
+
cache: bool = False
|
118 |
+
cycle_length: Optional[int] = None
|
119 |
+
block_length: int = 1
|
120 |
+
deterministic: Optional[bool] = None
|
121 |
+
sharding: bool = True
|
122 |
+
enable_tf_data_service: bool = False
|
123 |
+
tf_data_service_address: Optional[str] = None
|
124 |
+
tf_data_service_job_name: Optional[str] = None
|
125 |
+
tfds_data_dir: str = ""
|
126 |
+
tfds_as_supervised: bool = False
|
127 |
+
tfds_skip_decoding_feature: str = ""
|
128 |
+
enable_shared_tf_data_service_between_parallel_trainers: bool = False
|
129 |
+
apply_tf_data_service_before_batching: bool = False
|
130 |
+
trainer_id: Optional[str] = None
|
131 |
+
seed: Optional[int] = None
|
132 |
+
prefetch_buffer_size: Optional[int] = None
|
133 |
+
autotune_algorithm: Optional[str] = None
|
134 |
+
|
135 |
+
|
136 |
+
@dataclasses.dataclass
|
137 |
+
class RuntimeConfig(base_config.Config):
|
138 |
+
"""High-level configurations for Runtime.
|
139 |
+
|
140 |
+
These include parameters that are not directly related to the experiment,
|
141 |
+
e.g. directories, accelerator type, etc.
|
142 |
+
|
143 |
+
Attributes:
|
144 |
+
distribution_strategy: e.g. 'mirrored', 'tpu', etc.
|
145 |
+
enable_xla: Whether or not to enable XLA.
|
146 |
+
per_gpu_thread_count: thread count per GPU.
|
147 |
+
gpu_thread_mode: Whether and how the GPU device uses its own threadpool.
|
148 |
+
dataset_num_private_threads: Number of threads for a private threadpool
|
149 |
+
created for all datasets computation.
|
150 |
+
tpu: The address of the TPU to use, if any.
|
151 |
+
num_gpus: The number of GPUs to use, if any.
|
152 |
+
worker_hosts: comma-separated list of worker ip:port pairs for running
|
153 |
+
multi-worker models with DistributionStrategy.
|
154 |
+
task_index: If multi-worker training, the task index of this worker.
|
155 |
+
all_reduce_alg: Defines the algorithm for performing all-reduce.
|
156 |
+
num_packs: Sets `num_packs` in the cross device ops used in
|
157 |
+
MirroredStrategy. For details, see tf.distribute.NcclAllReduce.
|
158 |
+
mixed_precision_dtype: dtype of mixed precision policy. It can be 'float32',
|
159 |
+
'float16', or 'bfloat16'.
|
160 |
+
loss_scale: The type of loss scale, or 'float' value. This is used when
|
161 |
+
setting the mixed precision policy.
|
162 |
+
run_eagerly: Whether or not to run the experiment eagerly.
|
163 |
+
batchnorm_spatial_persistent: Whether or not to enable the spatial
|
164 |
+
persistent mode for CuDNN batch norm kernel for improved GPU performance.
|
165 |
+
"""
|
166 |
+
distribution_strategy: str = "mirrored"
|
167 |
+
enable_xla: bool = False
|
168 |
+
gpu_thread_mode: Optional[str] = None
|
169 |
+
dataset_num_private_threads: Optional[int] = None
|
170 |
+
per_gpu_thread_count: int = 0
|
171 |
+
tpu: Optional[str] = None
|
172 |
+
num_gpus: int = 0
|
173 |
+
worker_hosts: Optional[str] = None
|
174 |
+
task_index: int = -1
|
175 |
+
all_reduce_alg: Optional[str] = None
|
176 |
+
num_packs: int = 1
|
177 |
+
mixed_precision_dtype: Optional[str] = None
|
178 |
+
loss_scale: Optional[Union[str, float]] = None
|
179 |
+
run_eagerly: bool = False
|
180 |
+
batchnorm_spatial_persistent: bool = False
|
181 |
+
|
182 |
+
# XLA runtime params.
|
183 |
+
# XLA params are only applied to the train_step.
|
184 |
+
# These augments can improve training speed. They can also improve eval, but
|
185 |
+
# may reduce usability and users would need to make changes to code.
|
186 |
+
|
187 |
+
# Whether to enable XLA dynamic padder
|
188 |
+
# infrastructure to handle dynamic shapes inputs inside XLA. True by
|
189 |
+
# default. Disabling this may cause correctness issues with dynamic shapes
|
190 |
+
# inputs, as XLA will just assume the inputs are with padded shapes. However
|
191 |
+
# users can optionally set it to False to improve device time if masking is
|
192 |
+
# already handled in the user side.
|
193 |
+
# If None, will respect XLA default.
|
194 |
+
tpu_enable_xla_dynamic_padder: Optional[bool] = None
|
195 |
+
|
196 |
+
# Global model parallelism configurations.
|
197 |
+
num_cores_per_replica: int = 1
|
198 |
+
default_shard_dim: int = -1
|
199 |
+
use_tpu_mp_strategy: bool = False
|
200 |
+
|
201 |
+
def model_parallelism(self):
|
202 |
+
return dict(
|
203 |
+
num_cores_per_replica=self.num_cores_per_replica,
|
204 |
+
default_shard_dim=self.default_shard_dim)
|
205 |
+
|
206 |
+
|
207 |
+
@dataclasses.dataclass
|
208 |
+
class TrainerConfig(base_config.Config):
|
209 |
+
"""Configuration for trainer.
|
210 |
+
|
211 |
+
Attributes:
|
212 |
+
optimizer_config: optimizer config, it includes optimizer, learning rate,
|
213 |
+
and warmup schedule configs.
|
214 |
+
train_tf_while_loop: whether or not to use tf while loop.
|
215 |
+
train_tf_function: whether or not to use tf_function for training loop.
|
216 |
+
eval_tf_function: whether or not to use tf_function for eval.
|
217 |
+
eval_tf_while_loop: whether or not to use tf while loop for eval.
|
218 |
+
allow_tpu_summary: Whether to allow summary happen inside the XLA program
|
219 |
+
runs on TPU through automatic outside compilation.
|
220 |
+
steps_per_loop: number of steps per loop to report training metrics. This
|
221 |
+
can also be used to reduce host worker communication in a TPU setup.
|
222 |
+
summary_interval: number of steps between each summary.
|
223 |
+
checkpoint_interval: number of steps between checkpoints.
|
224 |
+
max_to_keep: max checkpoints to keep.
|
225 |
+
continuous_eval_timeout: maximum number of seconds to wait between
|
226 |
+
checkpoints, if set to None, continuous eval will wait indefinitely. This
|
227 |
+
is only used continuous_train_and_eval and continuous_eval modes. Default
|
228 |
+
value is 1 hrs.
|
229 |
+
train_steps: number of train steps.
|
230 |
+
validation_steps: number of eval steps. If -1, the entire eval dataset is
|
231 |
+
used.
|
232 |
+
validation_interval: number of training steps to run between evaluations.
|
233 |
+
best_checkpoint_export_subdir: if set, the trainer will keep track of the
|
234 |
+
best evaluation metric, and export the corresponding best checkpoint under
|
235 |
+
`model_dir/best_checkpoint_export_subdir`. Note that this only works if
|
236 |
+
mode contains eval (such as `train_and_eval`, `continuous_eval`, and
|
237 |
+
`continuous_train_and_eval`).
|
238 |
+
best_checkpoint_eval_metric: for exporting the best checkpoint, which
|
239 |
+
evaluation metric the trainer should monitor. This can be any evaluation
|
240 |
+
metric appears on tensorboard.
|
241 |
+
best_checkpoint_metric_comp: for exporting the best checkpoint, how the
|
242 |
+
trainer should compare the evaluation metrics. This can be either `higher`
|
243 |
+
(higher the better) or `lower` (lower the better).
|
244 |
+
validation_summary_subdir: A 'str', sub directory for saving eval summary.
|
245 |
+
preemption_on_demand_checkpoint: whether or not to save on-demand
|
246 |
+
checkpoints after a preemption.
|
247 |
+
"""
|
248 |
+
optimizer_config: OptimizationConfig = dataclasses.field(
|
249 |
+
default_factory=OptimizationConfig
|
250 |
+
)
|
251 |
+
# Orbit settings.
|
252 |
+
train_tf_while_loop: bool = True
|
253 |
+
train_tf_function: bool = True
|
254 |
+
eval_tf_function: bool = True
|
255 |
+
eval_tf_while_loop: bool = False
|
256 |
+
allow_tpu_summary: bool = False
|
257 |
+
# Trainer intervals.
|
258 |
+
steps_per_loop: int = 1000
|
259 |
+
summary_interval: int = 1000
|
260 |
+
checkpoint_interval: int = 1000
|
261 |
+
# Checkpoint manager.
|
262 |
+
max_to_keep: int = 5
|
263 |
+
continuous_eval_timeout: int = 60 * 60
|
264 |
+
# Train/Eval routines.
|
265 |
+
train_steps: int = 0
|
266 |
+
# Sets validation steps to be -1 to evaluate the entire dataset.
|
267 |
+
validation_steps: int = -1
|
268 |
+
validation_interval: int = 1000
|
269 |
+
# Best checkpoint export.
|
270 |
+
best_checkpoint_export_subdir: str = ""
|
271 |
+
best_checkpoint_eval_metric: str = ""
|
272 |
+
best_checkpoint_metric_comp: str = "higher"
|
273 |
+
# Blowup recovery.
|
274 |
+
loss_upper_bound: float = 1e6
|
275 |
+
recovery_begin_steps: int = 0 # Enforcing the loss bound after these steps.
|
276 |
+
# When max trials < 0, no recovery module; max trials = 0, we will check
|
277 |
+
# the condition and fail the job if the condition happens; max trials > 0,
|
278 |
+
# we will retore the model states.
|
279 |
+
recovery_max_trials: int = 0
|
280 |
+
validation_summary_subdir: str = "validation"
|
281 |
+
# Preemption on-demand checkpoint.
|
282 |
+
preemption_on_demand_checkpoint: bool = True # copybara-replace
|
283 |
+
|
284 |
+
|
285 |
+
@dataclasses.dataclass
|
286 |
+
class TaskConfig(base_config.Config):
|
287 |
+
"""Config passed to task."""
|
288 |
+
init_checkpoint: str = ""
|
289 |
+
model: Optional[base_config.Config] = None
|
290 |
+
train_data: DataConfig = dataclasses.field(default_factory=DataConfig)
|
291 |
+
validation_data: DataConfig = dataclasses.field(default_factory=DataConfig)
|
292 |
+
name: Optional[str] = None
|
293 |
+
# Configs for differential privacy
|
294 |
+
# These configs are only effective if you use create_optimizer in
|
295 |
+
# tensorflow_models/official/core/base_task.py
|
296 |
+
# DEPRECATED b/264611883
|
297 |
+
differential_privacy_config: Optional[
|
298 |
+
dp_configs.DifferentialPrivacyConfig] = None
|
299 |
+
# Whether to show image summary. Useful to visualize model predictions. Only
|
300 |
+
# work for vision tasks.
|
301 |
+
allow_image_summary: bool = False
|
302 |
+
|
303 |
+
|
304 |
+
@dataclasses.dataclass
|
305 |
+
class ExperimentConfig(base_config.Config):
|
306 |
+
"""Top-level configuration."""
|
307 |
+
task: TaskConfig = dataclasses.field(default_factory=TaskConfig)
|
308 |
+
trainer: TrainerConfig = dataclasses.field(default_factory=TrainerConfig)
|
309 |
+
runtime: RuntimeConfig = dataclasses.field(default_factory=RuntimeConfig)
|
modeling/official/core/exp_factory.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Experiment factory methods."""
|
16 |
+
|
17 |
+
from official.core import config_definitions as cfg
|
18 |
+
from official.core import registry
|
19 |
+
|
20 |
+
|
21 |
+
_REGISTERED_CONFIGS = {}
|
22 |
+
|
23 |
+
|
24 |
+
def register_config_factory(name):
|
25 |
+
"""Register ExperimentConfig factory method."""
|
26 |
+
return registry.register(_REGISTERED_CONFIGS, name)
|
27 |
+
|
28 |
+
|
29 |
+
def get_exp_config(exp_name: str) -> cfg.ExperimentConfig:
|
30 |
+
"""Looks up the `ExperimentConfig` according to the `exp_name`."""
|
31 |
+
exp_creater = registry.lookup(_REGISTERED_CONFIGS, exp_name)
|
32 |
+
return exp_creater()
|
modeling/official/core/export_base.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Base class for model export."""
|
16 |
+
|
17 |
+
import abc
|
18 |
+
import functools
|
19 |
+
import time
|
20 |
+
from typing import Any, Callable, Dict, Mapping, List, Optional, Text, Union
|
21 |
+
|
22 |
+
from absl import logging
|
23 |
+
import tensorflow as tf, tf_keras
|
24 |
+
|
25 |
+
MAX_DIRECTORY_CREATION_ATTEMPTS = 10
|
26 |
+
|
27 |
+
|
28 |
+
class ExportModule(tf.Module, metaclass=abc.ABCMeta):
|
29 |
+
"""Base Export Module."""
|
30 |
+
|
31 |
+
def __init__(self,
|
32 |
+
params,
|
33 |
+
model: Union[tf.Module, tf_keras.Model],
|
34 |
+
inference_step: Optional[Callable[..., Any]] = None,
|
35 |
+
*,
|
36 |
+
preprocessor: Optional[Callable[..., Any]] = None,
|
37 |
+
postprocessor: Optional[Callable[..., Any]] = None):
|
38 |
+
"""Instantiates an ExportModel.
|
39 |
+
|
40 |
+
Examples:
|
41 |
+
|
42 |
+
`inference_step` must be a function that has `model` as an kwarg or the
|
43 |
+
second positional argument.
|
44 |
+
```
|
45 |
+
def _inference_step(inputs, model=None):
|
46 |
+
return model(inputs, training=False)
|
47 |
+
|
48 |
+
module = ExportModule(params, model, inference_step=_inference_step)
|
49 |
+
```
|
50 |
+
|
51 |
+
`preprocessor` and `postprocessor` could be either functions or `tf.Module`.
|
52 |
+
The usages of preprocessor and postprocessor are managed by the
|
53 |
+
implementation of `serve()` method.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
params: A dataclass for parameters to the module.
|
57 |
+
model: A model instance which contains weights and forward computation.
|
58 |
+
inference_step: An optional callable to forward-pass the model. If not
|
59 |
+
specified, it creates a parital function with `model` as an required
|
60 |
+
kwarg.
|
61 |
+
preprocessor: An optional callable to preprocess the inputs.
|
62 |
+
postprocessor: An optional callable to postprocess the model outputs.
|
63 |
+
"""
|
64 |
+
super().__init__(name=None)
|
65 |
+
self.model = model
|
66 |
+
self.params = params
|
67 |
+
|
68 |
+
if inference_step is not None:
|
69 |
+
self.inference_step = functools.partial(inference_step, model=self.model)
|
70 |
+
else:
|
71 |
+
if issubclass(type(model), tf_keras.Model):
|
72 |
+
# Default to self.model.call instead of self.model.__call__ to avoid
|
73 |
+
# keras tracing logic designed for training.
|
74 |
+
# Since most of Model Garden's call doesn't not have training kwargs
|
75 |
+
# or the default is False, we don't pass anything here.
|
76 |
+
# Please pass custom inference step if your model has training=True as
|
77 |
+
# default.
|
78 |
+
self.inference_step = self.model.call
|
79 |
+
else:
|
80 |
+
self.inference_step = functools.partial(
|
81 |
+
self.model.__call__, training=False)
|
82 |
+
self.preprocessor = preprocessor
|
83 |
+
self.postprocessor = postprocessor
|
84 |
+
|
85 |
+
@abc.abstractmethod
|
86 |
+
def serve(self) -> Mapping[Text, tf.Tensor]:
|
87 |
+
"""The bare inference function which should run on all devices.
|
88 |
+
|
89 |
+
Expecting tensors are passed in through keyword arguments. Returns a
|
90 |
+
dictionary of tensors, when the keys will be used inside the SignatureDef.
|
91 |
+
"""
|
92 |
+
|
93 |
+
@abc.abstractmethod
|
94 |
+
def get_inference_signatures(
|
95 |
+
self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
|
96 |
+
"""Get defined function signatures."""
|
97 |
+
|
98 |
+
|
99 |
+
def export(export_module: ExportModule,
|
100 |
+
function_keys: Union[List[Text], Dict[Text, Text]],
|
101 |
+
export_savedmodel_dir: Text,
|
102 |
+
checkpoint_path: Optional[Text] = None,
|
103 |
+
timestamped: bool = True,
|
104 |
+
save_options: Optional[tf.saved_model.SaveOptions] = None,
|
105 |
+
checkpoint: Optional[tf.train.Checkpoint] = None) -> Text:
|
106 |
+
"""Exports to SavedModel format.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
export_module: a ExportModule with the keras Model and serving tf.functions.
|
110 |
+
function_keys: a list of string keys to retrieve pre-defined serving
|
111 |
+
signatures. The signaute keys will be set with defaults. If a dictionary
|
112 |
+
is provided, the values will be used as signature keys.
|
113 |
+
export_savedmodel_dir: Output saved model directory.
|
114 |
+
checkpoint_path: Object-based checkpoint path or directory.
|
115 |
+
timestamped: Whether to export the savedmodel to a timestamped directory.
|
116 |
+
save_options: `SaveOptions` for `tf.saved_model.save`.
|
117 |
+
checkpoint: An optional tf.train.Checkpoint. If provided, the export module
|
118 |
+
will use it to read the weights.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
The savedmodel directory path.
|
122 |
+
"""
|
123 |
+
ckpt_dir_or_file = checkpoint_path
|
124 |
+
if ckpt_dir_or_file is not None and tf.io.gfile.isdir(ckpt_dir_or_file):
|
125 |
+
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
|
126 |
+
if ckpt_dir_or_file:
|
127 |
+
if checkpoint is None:
|
128 |
+
checkpoint = tf.train.Checkpoint(model=export_module.model)
|
129 |
+
checkpoint.read(
|
130 |
+
ckpt_dir_or_file).assert_existing_objects_matched().expect_partial()
|
131 |
+
if isinstance(function_keys, list):
|
132 |
+
if len(function_keys) == 1:
|
133 |
+
function_keys = {
|
134 |
+
function_keys[0]: tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
135 |
+
}
|
136 |
+
else:
|
137 |
+
raise ValueError(
|
138 |
+
'If the function_keys is a list, it must contain a single element. %s'
|
139 |
+
% function_keys)
|
140 |
+
|
141 |
+
signatures = export_module.get_inference_signatures(function_keys)
|
142 |
+
if timestamped:
|
143 |
+
export_dir = get_timestamped_export_dir(export_savedmodel_dir).decode(
|
144 |
+
'utf-8')
|
145 |
+
else:
|
146 |
+
export_dir = export_savedmodel_dir
|
147 |
+
tf.saved_model.save(
|
148 |
+
export_module, export_dir, signatures=signatures, options=save_options)
|
149 |
+
return export_dir
|
150 |
+
|
151 |
+
|
152 |
+
def get_timestamped_export_dir(export_dir_base):
|
153 |
+
"""Builds a path to a new subdirectory within the base directory.
|
154 |
+
|
155 |
+
Args:
|
156 |
+
export_dir_base: A string containing a directory to write the exported graph
|
157 |
+
and checkpoints.
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
The full path of the new subdirectory (which is not actually created yet).
|
161 |
+
|
162 |
+
Raises:
|
163 |
+
RuntimeError: if repeated attempts fail to obtain a unique timestamped
|
164 |
+
directory name.
|
165 |
+
"""
|
166 |
+
attempts = 0
|
167 |
+
while attempts < MAX_DIRECTORY_CREATION_ATTEMPTS:
|
168 |
+
timestamp = int(time.time())
|
169 |
+
|
170 |
+
result_dir = tf.io.gfile.join(
|
171 |
+
tf.compat.as_bytes(export_dir_base), tf.compat.as_bytes(str(timestamp)))
|
172 |
+
if not tf.io.gfile.exists(result_dir):
|
173 |
+
# Collisions are still possible (though extremely unlikely): this
|
174 |
+
# directory is not actually created yet, but it will be almost
|
175 |
+
# instantly on return from this function.
|
176 |
+
return result_dir
|
177 |
+
time.sleep(1)
|
178 |
+
attempts += 1
|
179 |
+
logging.warning('Directory %s already exists; retrying (attempt %s/%s)',
|
180 |
+
str(result_dir), attempts, MAX_DIRECTORY_CREATION_ATTEMPTS)
|
181 |
+
raise RuntimeError('Failed to obtain a unique export directory name after '
|
182 |
+
f'{MAX_DIRECTORY_CREATION_ATTEMPTS} attempts.')
|
modeling/official/core/export_base_test.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Tests for official.core.export_base."""
|
16 |
+
import os
|
17 |
+
from typing import Any, Dict, Mapping, Text
|
18 |
+
|
19 |
+
import tensorflow as tf, tf_keras
|
20 |
+
|
21 |
+
from official.core import export_base
|
22 |
+
|
23 |
+
|
24 |
+
class TestModule(export_base.ExportModule):
|
25 |
+
|
26 |
+
@tf.function
|
27 |
+
def serve(self, inputs: tf.Tensor) -> Mapping[Text, tf.Tensor]:
|
28 |
+
x = inputs if self.preprocessor is None else self.preprocessor(
|
29 |
+
inputs=inputs)
|
30 |
+
x = self.inference_step(x)
|
31 |
+
x = self.postprocessor(x) if self.postprocessor else x
|
32 |
+
return {'outputs': x}
|
33 |
+
|
34 |
+
def get_inference_signatures(
|
35 |
+
self, function_keys: Dict[Text, Text]) -> Mapping[Text, Any]:
|
36 |
+
input_signature = tf.TensorSpec(shape=[None, None], dtype=tf.float32)
|
37 |
+
return {'foo': self.serve.get_concrete_function(input_signature)}
|
38 |
+
|
39 |
+
|
40 |
+
class ExportBaseTest(tf.test.TestCase):
|
41 |
+
|
42 |
+
def test_export_module(self):
|
43 |
+
tmp_dir = self.get_temp_dir()
|
44 |
+
model = tf_keras.layers.Dense(2)
|
45 |
+
inputs = tf.ones([2, 4], tf.float32)
|
46 |
+
expected_output = model(inputs, training=False)
|
47 |
+
module = TestModule(params=None, model=model)
|
48 |
+
ckpt_path = tf.train.Checkpoint(model=model).save(
|
49 |
+
os.path.join(tmp_dir, 'ckpt'))
|
50 |
+
export_dir = export_base.export(
|
51 |
+
module, ['foo'],
|
52 |
+
export_savedmodel_dir=tmp_dir,
|
53 |
+
checkpoint_path=ckpt_path,
|
54 |
+
timestamped=True)
|
55 |
+
self.assertTrue(os.path.exists(os.path.join(export_dir, 'saved_model.pb')))
|
56 |
+
self.assertTrue(
|
57 |
+
os.path.exists(
|
58 |
+
os.path.join(export_dir, 'variables', 'variables.index')))
|
59 |
+
self.assertTrue(
|
60 |
+
os.path.exists(
|
61 |
+
os.path.join(export_dir, 'variables',
|
62 |
+
'variables.data-00000-of-00001')))
|
63 |
+
|
64 |
+
imported = tf.saved_model.load(export_dir)
|
65 |
+
output = imported.signatures['foo'](inputs)
|
66 |
+
self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())
|
67 |
+
|
68 |
+
def test_custom_inference_step(self):
|
69 |
+
tmp_dir = self.get_temp_dir()
|
70 |
+
model = tf_keras.layers.Dense(2)
|
71 |
+
inputs = tf.ones([2, 4], tf.float32)
|
72 |
+
|
73 |
+
def _inference_step(inputs, model):
|
74 |
+
return tf.nn.softmax(model(inputs, training=False))
|
75 |
+
|
76 |
+
module = TestModule(
|
77 |
+
params=None, model=model, inference_step=_inference_step)
|
78 |
+
expected_output = _inference_step(inputs, model)
|
79 |
+
ckpt_path = tf.train.Checkpoint(model=model).save(
|
80 |
+
os.path.join(tmp_dir, 'ckpt'))
|
81 |
+
export_dir = export_base.export(
|
82 |
+
module, ['foo'],
|
83 |
+
export_savedmodel_dir=tmp_dir,
|
84 |
+
checkpoint_path=ckpt_path,
|
85 |
+
timestamped=False)
|
86 |
+
imported = tf.saved_model.load(export_dir)
|
87 |
+
output = imported.signatures['foo'](inputs)
|
88 |
+
self.assertAllClose(output['outputs'].numpy(), expected_output.numpy())
|
89 |
+
|
90 |
+
def test_processors(self):
|
91 |
+
model = tf.Module()
|
92 |
+
inputs = tf.zeros((), tf.float32)
|
93 |
+
|
94 |
+
def _inference_step(inputs, model):
|
95 |
+
del model
|
96 |
+
return inputs + 1.0
|
97 |
+
|
98 |
+
def _preprocessor(inputs):
|
99 |
+
print(inputs)
|
100 |
+
return inputs + 0.1
|
101 |
+
|
102 |
+
module = TestModule(
|
103 |
+
params=None,
|
104 |
+
model=model,
|
105 |
+
inference_step=_inference_step,
|
106 |
+
preprocessor=_preprocessor)
|
107 |
+
output = module.serve(inputs)
|
108 |
+
self.assertAllClose(output['outputs'].numpy(), 1.1)
|
109 |
+
|
110 |
+
class _PostProcessor(tf.Module):
|
111 |
+
|
112 |
+
def __call__(self, inputs):
|
113 |
+
return inputs + 0.01
|
114 |
+
|
115 |
+
module = TestModule(
|
116 |
+
params=None,
|
117 |
+
model=model,
|
118 |
+
inference_step=_inference_step,
|
119 |
+
preprocessor=_preprocessor,
|
120 |
+
postprocessor=_PostProcessor())
|
121 |
+
output = module.serve(inputs)
|
122 |
+
self.assertAllClose(output['outputs'].numpy(), 1.11)
|
123 |
+
|
124 |
+
def test_get_timestamped_export_dir(self):
|
125 |
+
export_dir = self.get_temp_dir()
|
126 |
+
timed_dir = export_base.get_timestamped_export_dir(
|
127 |
+
export_dir_base=export_dir)
|
128 |
+
self.assertFalse(tf.io.gfile.exists(timed_dir))
|
129 |
+
self.assertIn(export_dir, str(timed_dir))
|
130 |
+
|
131 |
+
|
132 |
+
if __name__ == '__main__':
|
133 |
+
tf.test.main()
|
modeling/official/core/file_writers.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""File writer functions for dataset preparation, infra validation, and unit tests."""
|
16 |
+
|
17 |
+
import io
|
18 |
+
from typing import Optional, Sequence, Union
|
19 |
+
|
20 |
+
import tensorflow as tf, tf_keras
|
21 |
+
|
22 |
+
|
23 |
+
def write_small_dataset(examples: Sequence[Union[tf.train.Example,
|
24 |
+
tf.train.SequenceExample]],
|
25 |
+
output_path: str,
|
26 |
+
file_type: str = 'tfrecord') -> None:
|
27 |
+
"""Writes `examples` to a file at `output_path` with type `file_type`.
|
28 |
+
|
29 |
+
CAVEAT: This function is not recommended for writing large datasets, since it
|
30 |
+
will loop through `examples` and perform write operation sequentially.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
examples: List of tf.train.Example or tf.train.SequenceExample.
|
34 |
+
output_path: Output path for the dataset.
|
35 |
+
file_type: A string indicating the file format, could be: 'tfrecord',
|
36 |
+
'tfrecords', 'tfrecord_compressed', 'tfrecords_gzip', 'riegeli'. The
|
37 |
+
string is case insensitive.
|
38 |
+
"""
|
39 |
+
file_type = file_type.lower()
|
40 |
+
|
41 |
+
if file_type == 'tfrecord' or file_type == 'tfrecords':
|
42 |
+
_write_tfrecord(examples, output_path)
|
43 |
+
elif file_type == 'tfrecord_compressed' or file_type == 'tfrecords_gzip':
|
44 |
+
_write_tfrecord(examples, output_path,
|
45 |
+
tf.io.TFRecordOptions(compression_type='GZIP'))
|
46 |
+
elif file_type == 'riegeli':
|
47 |
+
_write_riegeli(examples, output_path)
|
48 |
+
else:
|
49 |
+
raise ValueError(f'Unknown file_type: {file_type}')
|
50 |
+
|
51 |
+
|
52 |
+
def _write_tfrecord(examples: Sequence[Union[tf.train.Example,
|
53 |
+
tf.train.SequenceExample]],
|
54 |
+
output_path: str,
|
55 |
+
options: Optional[tf.io.TFRecordOptions] = None) -> None:
|
56 |
+
"""Writes `examples` to a TFRecord file at `output_path`.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
examples: A list of tf.train.Example.
|
60 |
+
output_path: Output path for the dataset.
|
61 |
+
options: Options used for manipulating TFRecord files.
|
62 |
+
"""
|
63 |
+
with tf.io.TFRecordWriter(output_path, options) as writer:
|
64 |
+
for example in examples:
|
65 |
+
writer.write(example.SerializeToString())
|
66 |
+
|
67 |
+
|
68 |
+
def _write_riegeli(examples: Sequence[Union[tf.train.Example,
|
69 |
+
tf.train.SequenceExample]],
|
70 |
+
output_path: str) -> None:
|
71 |
+
"""Writes `examples` to a Riegeli file at `output_path`.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
examples: A list of tf.train.Example.
|
75 |
+
output_path: Output path for the dataset.
|
76 |
+
"""
|
77 |
+
with io.FileIO(output_path, 'wb') as fileio:
|
78 |
+
import riegeli # pylint: disable=g-import-not-at-top
|
79 |
+
with riegeli.RecordWriter(fileio) as writer:
|
80 |
+
writer.write_messages(examples)
|
modeling/official/core/file_writers_test.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Tests for file_writers."""
|
16 |
+
|
17 |
+
import os
|
18 |
+
from absl.testing import parameterized
|
19 |
+
import tensorflow as tf, tf_keras
|
20 |
+
|
21 |
+
from official.core import file_writers
|
22 |
+
from official.core import tf_example_builder
|
23 |
+
|
24 |
+
|
25 |
+
class FileWritersTest(tf.test.TestCase, parameterized.TestCase):
|
26 |
+
|
27 |
+
def setUp(self):
|
28 |
+
super().setUp()
|
29 |
+
example_builder = tf_example_builder.TfExampleBuilder()
|
30 |
+
example_builder.add_bytes_feature('foo', 'Hello World!')
|
31 |
+
self._example = example_builder.example
|
32 |
+
|
33 |
+
@parameterized.parameters('tfrecord', 'TFRecord', 'tfrecords',
|
34 |
+
'tfrecord_compressed', 'TFRecord_Compressed',
|
35 |
+
'tfrecords_gzip')
|
36 |
+
def test_write_small_dataset_success(self, file_type):
|
37 |
+
temp_dir = self.create_tempdir()
|
38 |
+
temp_dataset_file = os.path.join(temp_dir.full_path, 'train')
|
39 |
+
file_writers.write_small_dataset([self._example], temp_dataset_file,
|
40 |
+
file_type)
|
41 |
+
self.assertTrue(os.path.exists(temp_dataset_file))
|
42 |
+
|
43 |
+
def test_write_small_dataset_unrecognized_format(self):
|
44 |
+
file_type = 'bar'
|
45 |
+
temp_dir = self.create_tempdir()
|
46 |
+
temp_dataset_file = os.path.join(temp_dir.full_path, 'train')
|
47 |
+
with self.assertRaises(ValueError):
|
48 |
+
file_writers.write_small_dataset([self._example], temp_dataset_file,
|
49 |
+
file_type)
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == '__main__':
|
53 |
+
tf.test.main()
|
modeling/official/core/input_reader.py
ADDED
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""A common dataset reader."""
|
16 |
+
import dataclasses
|
17 |
+
import random
|
18 |
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Text, Union
|
19 |
+
|
20 |
+
from absl import logging
|
21 |
+
import tensorflow as tf, tf_keras
|
22 |
+
import tensorflow_datasets as tfds
|
23 |
+
|
24 |
+
from official.core import config_definitions as cfg
|
25 |
+
|
26 |
+
|
27 |
+
def _get_random_integer():
|
28 |
+
return random.randint(0, (1 << 31) - 1)
|
29 |
+
|
30 |
+
|
31 |
+
def _maybe_map_fn(dataset: tf.data.Dataset,
|
32 |
+
fn: Optional[Callable[..., Any]] = None) -> tf.data.Dataset:
|
33 |
+
"""Calls dataset.map if a valid function is passed in."""
|
34 |
+
return dataset if fn is None else dataset.map(
|
35 |
+
fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
36 |
+
|
37 |
+
|
38 |
+
def match_files(input_path: Union[Sequence[str], str]) -> List[str]:
|
39 |
+
"""Matches files from an input_path."""
|
40 |
+
matched_files = []
|
41 |
+
# Read dataset from files.
|
42 |
+
usage = ('`input_path` should be either (1) a str indicating a file '
|
43 |
+
'path/pattern, or (2) a str indicating multiple file '
|
44 |
+
'paths/patterns separated by comma (e.g "a, b, c" or no spaces '
|
45 |
+
'"a,b,c", or (3) a list of str, each of which is a file '
|
46 |
+
'path/pattern or multiple file paths/patterns separated by '
|
47 |
+
'comma, but got: %s')
|
48 |
+
if isinstance(input_path, str):
|
49 |
+
input_path_list = [input_path]
|
50 |
+
elif isinstance(input_path, (list, tuple)):
|
51 |
+
if any(not isinstance(x, str) for x in input_path):
|
52 |
+
raise ValueError(usage % input_path)
|
53 |
+
input_path_list = input_path
|
54 |
+
else:
|
55 |
+
raise ValueError(usage % input_path)
|
56 |
+
|
57 |
+
for input_path in input_path_list:
|
58 |
+
input_patterns = input_path.strip().split(',')
|
59 |
+
for input_pattern in input_patterns:
|
60 |
+
input_pattern = input_pattern.strip()
|
61 |
+
if not input_pattern:
|
62 |
+
continue
|
63 |
+
if '*' in input_pattern or '?' in input_pattern:
|
64 |
+
tmp_matched_files = tf.io.gfile.glob(input_pattern)
|
65 |
+
if not tmp_matched_files:
|
66 |
+
raise ValueError('%s does not match any files.' % input_pattern)
|
67 |
+
matched_files.extend(tmp_matched_files)
|
68 |
+
else:
|
69 |
+
matched_files.append(input_pattern)
|
70 |
+
|
71 |
+
if not matched_files:
|
72 |
+
raise ValueError('%s does not match any files.' % input_path)
|
73 |
+
|
74 |
+
return matched_files
|
75 |
+
|
76 |
+
|
77 |
+
def _read_files_then_shard(matched_files: List[str],
|
78 |
+
dataset_fn,
|
79 |
+
input_context: Optional[
|
80 |
+
tf.distribute.InputContext] = None,
|
81 |
+
sharding: bool = False,
|
82 |
+
repeat: bool = False) -> tf.data.Dataset:
|
83 |
+
"""Sends all data files to every worker and then shard by data."""
|
84 |
+
dataset = dataset_fn(matched_files)
|
85 |
+
|
86 |
+
# When `input_file` is a path to a single file or the number of files is
|
87 |
+
# less than the number of input pipelines, disable auto sharding
|
88 |
+
# so that same input file is sent to all workers.
|
89 |
+
options = tf.data.Options()
|
90 |
+
options.experimental_distribute.auto_shard_policy = (
|
91 |
+
tf.data.experimental.AutoShardPolicy.OFF)
|
92 |
+
dataset = dataset.with_options(options)
|
93 |
+
# Do not enable sharding if tf.data service is enabled, as sharding will be
|
94 |
+
# handled inside tf.data service.
|
95 |
+
if sharding and input_context and (input_context.num_input_pipelines > 1):
|
96 |
+
dataset = dataset.shard(input_context.num_input_pipelines,
|
97 |
+
input_context.input_pipeline_id)
|
98 |
+
|
99 |
+
if repeat:
|
100 |
+
dataset = dataset.repeat()
|
101 |
+
return dataset
|
102 |
+
|
103 |
+
|
104 |
+
def _shard_files_then_read(matched_files: List[str],
|
105 |
+
dataset_fn,
|
106 |
+
input_context: Optional[
|
107 |
+
tf.distribute.InputContext] = None,
|
108 |
+
seed: Optional[Union[int, tf.Tensor]] = None,
|
109 |
+
is_training: bool = False,
|
110 |
+
sharding: bool = False,
|
111 |
+
cache: bool = False,
|
112 |
+
cycle_length: Optional[int] = None,
|
113 |
+
block_length: Optional[int] = None,
|
114 |
+
deterministic: bool = False) -> tf.data.Dataset:
|
115 |
+
"""Shards the data files and then sent a split to every worker to read."""
|
116 |
+
dataset = tf.data.Dataset.from_tensor_slices(matched_files)
|
117 |
+
|
118 |
+
# Shuffle and repeat at file level.
|
119 |
+
# If cache is enabled, `reshuffle_each_iteration` is set to False,
|
120 |
+
# because we will read the same cached data in every iteration anyway.
|
121 |
+
if is_training:
|
122 |
+
# We need a seed to shuffle the files so that when each TPU workers gets
|
123 |
+
# its own shard the files do not overlap.
|
124 |
+
if sharding and seed is None:
|
125 |
+
seed = _get_random_integer()
|
126 |
+
dataset = dataset.shuffle(
|
127 |
+
len(matched_files),
|
128 |
+
seed=seed,
|
129 |
+
reshuffle_each_iteration=True if not cache else False)
|
130 |
+
|
131 |
+
# Do not enable sharding if tf.data service is enabled, as sharding will be
|
132 |
+
# handled inside tf.data service.
|
133 |
+
if sharding and input_context and (input_context.num_input_pipelines > 1):
|
134 |
+
dataset = dataset.shard(input_context.num_input_pipelines,
|
135 |
+
input_context.input_pipeline_id)
|
136 |
+
|
137 |
+
# If cache is enabled, we will call `repeat()` later after `cache()`.
|
138 |
+
if is_training and not cache:
|
139 |
+
dataset = dataset.repeat()
|
140 |
+
|
141 |
+
dataset = dataset.interleave(
|
142 |
+
map_func=dataset_fn,
|
143 |
+
cycle_length=cycle_length,
|
144 |
+
block_length=block_length,
|
145 |
+
num_parallel_calls=(cycle_length
|
146 |
+
if cycle_length else tf.data.experimental.AUTOTUNE),
|
147 |
+
deterministic=deterministic)
|
148 |
+
return dataset
|
149 |
+
|
150 |
+
|
151 |
+
def _read_tfds(tfds_name: Text,
|
152 |
+
tfds_data_dir: Text,
|
153 |
+
tfds_split: Text,
|
154 |
+
tfds_skip_decoding_feature: Text,
|
155 |
+
tfds_as_supervised: bool,
|
156 |
+
input_context: Optional[tf.distribute.InputContext] = None,
|
157 |
+
seed: Optional[Union[int, tf.Tensor]] = None,
|
158 |
+
is_training: bool = False,
|
159 |
+
cache: bool = False,
|
160 |
+
cycle_length: Optional[int] = None,
|
161 |
+
block_length: Optional[int] = None) -> tf.data.Dataset:
|
162 |
+
"""Reads a dataset from tfds."""
|
163 |
+
repeat_filenames = is_training and not cache
|
164 |
+
read_config = tfds.ReadConfig(
|
165 |
+
interleave_cycle_length=cycle_length,
|
166 |
+
interleave_block_length=block_length,
|
167 |
+
input_context=input_context,
|
168 |
+
shuffle_seed=seed,
|
169 |
+
repeat_filenames=repeat_filenames,
|
170 |
+
# Only assert cardinality when we have a finite dataset.
|
171 |
+
assert_cardinality=not repeat_filenames,
|
172 |
+
skip_prefetch=True)
|
173 |
+
|
174 |
+
decoders = {}
|
175 |
+
if tfds_skip_decoding_feature:
|
176 |
+
for skip_feature in tfds_skip_decoding_feature.split(','):
|
177 |
+
decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
|
178 |
+
|
179 |
+
if tfds_name.startswith('mldataset.'):
|
180 |
+
dataset = tfds.load(name=tfds_name,
|
181 |
+
split=tfds_split,
|
182 |
+
as_supervised=tfds_as_supervised,
|
183 |
+
decoders=decoders if decoders else None,
|
184 |
+
read_config=read_config)
|
185 |
+
else:
|
186 |
+
builder = tfds.builder(tfds_name, data_dir=tfds_data_dir)
|
187 |
+
if builder.info.splits:
|
188 |
+
num_shards = len(builder.info.splits[tfds_split].file_instructions)
|
189 |
+
else:
|
190 |
+
# The tfds mock path often does not provide splits.
|
191 |
+
num_shards = 1
|
192 |
+
load_kwargs = dict(
|
193 |
+
name=tfds_name, download=True, split=tfds_split,
|
194 |
+
shuffle_files=is_training, as_supervised=tfds_as_supervised,
|
195 |
+
decoders=decoders if decoders else None)
|
196 |
+
if tfds_data_dir:
|
197 |
+
load_kwargs.update({'data_dir': tfds_data_dir})
|
198 |
+
|
199 |
+
if input_context and num_shards < input_context.num_input_pipelines:
|
200 |
+
# The number of files in the dataset split is smaller than the number of
|
201 |
+
# input pipelines. We read the entire dataset first and then shard in the
|
202 |
+
# host memory.
|
203 |
+
read_config = dataclasses.replace(read_config, input_context=None)
|
204 |
+
load_kwargs.update({'read_config': read_config})
|
205 |
+
dataset = tfds.load(**load_kwargs)
|
206 |
+
dataset = dataset.shard(input_context.num_input_pipelines,
|
207 |
+
input_context.input_pipeline_id)
|
208 |
+
else:
|
209 |
+
load_kwargs.update({'read_config': read_config})
|
210 |
+
dataset = tfds.load(**load_kwargs)
|
211 |
+
return dataset
|
212 |
+
|
213 |
+
|
214 |
+
class InputReader:
|
215 |
+
"""Input reader that returns a tf.data.Dataset instance."""
|
216 |
+
|
217 |
+
# A static random number which is the same across different InputReader
|
218 |
+
# instances.
|
219 |
+
static_randnum = _get_random_integer()
|
220 |
+
|
221 |
+
def __init__(
|
222 |
+
self,
|
223 |
+
params: cfg.DataConfig,
|
224 |
+
dataset_fn=tf.data.TFRecordDataset,
|
225 |
+
decoder_fn: Optional[Callable[..., Any]] = None,
|
226 |
+
combine_fn: Optional[Callable[..., Any]] = None,
|
227 |
+
sample_fn: Optional[Callable[..., Any]] = None,
|
228 |
+
parser_fn: Optional[Callable[..., Any]] = None,
|
229 |
+
filter_fn: Optional[Callable[..., tf.Tensor]] = None,
|
230 |
+
transform_and_batch_fn: Optional[
|
231 |
+
Callable[
|
232 |
+
[tf.data.Dataset, Optional[tf.distribute.InputContext]],
|
233 |
+
tf.data.Dataset,
|
234 |
+
]
|
235 |
+
] = None,
|
236 |
+
postprocess_fn: Optional[Callable[..., Any]] = None,
|
237 |
+
):
|
238 |
+
"""Initializes an InputReader instance.
|
239 |
+
|
240 |
+
Args:
|
241 |
+
params: A config_definitions.DataConfig object.
|
242 |
+
dataset_fn: A `tf.data.Dataset` that consumes the input files. For
|
243 |
+
example, it can be `tf.data.TFRecordDataset`.
|
244 |
+
decoder_fn: An optional `callable` that takes the serialized data string
|
245 |
+
and decodes them into the raw tensor dictionary.
|
246 |
+
combine_fn: An optional `callable` that takes a dictionarty of
|
247 |
+
`tf.data.Dataset` objects as input and outputs a combined dataset. It
|
248 |
+
will be executed after the decoder_fn and before the sample_fn.
|
249 |
+
sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
|
250 |
+
input and outputs the transformed dataset. It performs sampling on the
|
251 |
+
decoded raw tensors dict before the parser_fn.
|
252 |
+
parser_fn: An optional `callable` that takes the decoded raw tensors dict
|
253 |
+
and parse them into a dictionary of tensors that can be consumed by the
|
254 |
+
model. It will be executed after decoder_fn.
|
255 |
+
filter_fn: An optional `callable` mapping a dataset element to a boolean.
|
256 |
+
It will be executed after parser_fn.
|
257 |
+
transform_and_batch_fn: An optional `callable` that takes a
|
258 |
+
`tf.data.Dataset` object and an optional `tf.distribute.InputContext` as
|
259 |
+
input, and returns a `tf.data.Dataset` object. It will be executed after
|
260 |
+
`parser_fn` to transform and batch the dataset; if None, after
|
261 |
+
`parser_fn` is executed, the dataset will be batched into per-replica
|
262 |
+
batch size.
|
263 |
+
postprocess_fn: A optional `callable` that processes batched tensors. It
|
264 |
+
will be executed after batching.
|
265 |
+
"""
|
266 |
+
if params.input_path and params.tfds_name:
|
267 |
+
raise ValueError('At most one of `input_path` and `tfds_name` can be '
|
268 |
+
'specified, but got %s and %s.' %
|
269 |
+
(params.input_path, params.tfds_name))
|
270 |
+
|
271 |
+
if (isinstance(params.input_path, cfg.base_config.Config) or
|
272 |
+
isinstance(params.tfds_name, cfg.base_config.Config)
|
273 |
+
) and combine_fn is None:
|
274 |
+
raise ValueError(
|
275 |
+
'A combine_fn is required if `input_path` or `tfds_name` is a dict.')
|
276 |
+
|
277 |
+
self._tfds_name = params.tfds_name
|
278 |
+
self._tfds_data_dir = params.tfds_data_dir
|
279 |
+
self._matched_files = None
|
280 |
+
if not params.input_path:
|
281 |
+
# Read dataset from TFDS.
|
282 |
+
if not params.tfds_split:
|
283 |
+
raise ValueError(
|
284 |
+
'`tfds_name` is %s, but `tfds_split` is not specified.' %
|
285 |
+
params.tfds_name)
|
286 |
+
else:
|
287 |
+
self._matched_files = self.get_files(params.input_path)
|
288 |
+
|
289 |
+
self._global_batch_size = params.global_batch_size
|
290 |
+
self._is_training = params.is_training
|
291 |
+
self._drop_remainder = params.drop_remainder
|
292 |
+
self._shuffle_buffer_size = params.shuffle_buffer_size
|
293 |
+
self._cache = params.cache
|
294 |
+
self._cycle_length = params.cycle_length
|
295 |
+
self._block_length = params.block_length
|
296 |
+
self._deterministic = params.deterministic
|
297 |
+
self._sharding = params.sharding
|
298 |
+
self._tfds_split = params.tfds_split
|
299 |
+
self._tfds_as_supervised = params.tfds_as_supervised
|
300 |
+
self._tfds_skip_decoding_feature = params.tfds_skip_decoding_feature
|
301 |
+
|
302 |
+
self._dataset_fn = dataset_fn
|
303 |
+
self._decoder_fn = decoder_fn
|
304 |
+
self._combine_fn = combine_fn
|
305 |
+
self._sample_fn = sample_fn
|
306 |
+
self._parser_fn = parser_fn
|
307 |
+
self._transform_and_batch_fn = transform_and_batch_fn
|
308 |
+
self._postprocess_fn = postprocess_fn
|
309 |
+
self._filter_fn = filter_fn
|
310 |
+
self._seed = params.seed
|
311 |
+
self._prefetch_buffer_size = (
|
312 |
+
params.prefetch_buffer_size or tf.data.experimental.AUTOTUNE)
|
313 |
+
self._autotune_algorithm = params.autotune_algorithm
|
314 |
+
|
315 |
+
# When tf.data service is enabled, each data service worker should get
|
316 |
+
# different random seeds. Thus, we set `seed` to None.
|
317 |
+
# Sharding should also be disabled because tf data service handles how
|
318 |
+
# each worker shard data with `processing_mode` in distribute method.
|
319 |
+
if params.enable_tf_data_service:
|
320 |
+
self._seed = None
|
321 |
+
self._sharding = False
|
322 |
+
|
323 |
+
self._enable_tf_data_service = (
|
324 |
+
params.enable_tf_data_service and params.tf_data_service_address)
|
325 |
+
self._tf_data_service_address = params.tf_data_service_address
|
326 |
+
self._enable_shared_tf_data_service_between_parallel_trainers = (
|
327 |
+
params.enable_shared_tf_data_service_between_parallel_trainers)
|
328 |
+
self._apply_tf_data_service_before_batching = (
|
329 |
+
params.apply_tf_data_service_before_batching)
|
330 |
+
self._trainer_id = params.trainer_id
|
331 |
+
if self._enable_tf_data_service:
|
332 |
+
# Add a random seed as the tf.data service job name suffix, so tf.data
|
333 |
+
# service doesn't reuse the previous state if TPU worker gets preempted.
|
334 |
+
# It's necessary to add global batch size into the tf data service job
|
335 |
+
# name because when tuning batch size with vizier and tf data service is
|
336 |
+
# also enable, the tf data servce job name should be different for
|
337 |
+
# different vizier trials since once batch size is changed, from the
|
338 |
+
# tf.data perspective, the dataset is a different instance, and a
|
339 |
+
# different job name should be used for tf data service. Otherwise, the
|
340 |
+
# model would read tensors from the incorrect tf data service job, which
|
341 |
+
# would causes dimension mismatch on the batch size dimension.
|
342 |
+
self._tf_data_service_job_name = (
|
343 |
+
f'{params.tf_data_service_job_name}_bs{params.global_batch_size}_'
|
344 |
+
f'{self.static_randnum}')
|
345 |
+
self._enable_round_robin_tf_data_service = params.get(
|
346 |
+
'enable_round_robin_tf_data_service', False)
|
347 |
+
if self._enable_shared_tf_data_service_between_parallel_trainers:
|
348 |
+
# When shared tf.data service is enabled, only a single tf.data service
|
349 |
+
# instance should be created and shared between parallel trainers. If
|
350 |
+
# the global batch size is different across trainers,
|
351 |
+
# params.apply_tf_data_service_before_batching should be set to true
|
352 |
+
# because tf.data service with different batch sizes will be considered
|
353 |
+
# separate tf.data service instances.
|
354 |
+
self._tf_data_service_job_name = (
|
355 |
+
f'{params.tf_data_service_job_name}_{self.static_randnum}')
|
356 |
+
|
357 |
+
def get_files(self, input_path):
|
358 |
+
"""Gets matched files. Can be overridden by subclasses."""
|
359 |
+
if not input_path:
|
360 |
+
return None
|
361 |
+
# we want to combine / mix datasets
|
362 |
+
if isinstance(input_path, cfg.base_config.Config):
|
363 |
+
matched_files = {}
|
364 |
+
for k, v in input_path.as_dict().items():
|
365 |
+
matched_files[k] = match_files(v)
|
366 |
+
# single dataset
|
367 |
+
else:
|
368 |
+
matched_files = match_files(input_path)
|
369 |
+
return matched_files
|
370 |
+
|
371 |
+
def _read_data_source(
|
372 |
+
self,
|
373 |
+
matched_files: Union[Dict[str, List[str]], List[str]],
|
374 |
+
dataset_fn,
|
375 |
+
input_context: Optional[tf.distribute.InputContext] = None,
|
376 |
+
):
|
377 |
+
"""Reads the data source (files/tfds) to a dataset."""
|
378 |
+
|
379 |
+
def _files_to_dataset(files: List[str]) -> tf.data.Dataset:
|
380 |
+
if len(files) > 1:
|
381 |
+
if input_context and (len(files) < input_context.num_input_pipelines):
|
382 |
+
logging.warn(
|
383 |
+
(
|
384 |
+
'The number of files %d is less than the number of input '
|
385 |
+
'pipelines %d. We will send all input files to every worker. '
|
386 |
+
'Please consider sharding your data into more files.'
|
387 |
+
),
|
388 |
+
len(files),
|
389 |
+
input_context.num_input_pipelines,
|
390 |
+
)
|
391 |
+
return _read_files_then_shard(
|
392 |
+
files,
|
393 |
+
dataset_fn,
|
394 |
+
input_context,
|
395 |
+
sharding=self._sharding,
|
396 |
+
repeat=self._is_training and not self._cache)
|
397 |
+
else:
|
398 |
+
return _shard_files_then_read(
|
399 |
+
files,
|
400 |
+
dataset_fn,
|
401 |
+
input_context,
|
402 |
+
seed=self._seed,
|
403 |
+
is_training=self._is_training,
|
404 |
+
sharding=self._sharding,
|
405 |
+
cache=self._cache,
|
406 |
+
cycle_length=self._cycle_length,
|
407 |
+
block_length=self._block_length,
|
408 |
+
deterministic=self._deterministic)
|
409 |
+
elif len(files) == 1:
|
410 |
+
return _read_files_then_shard(
|
411 |
+
files,
|
412 |
+
dataset_fn,
|
413 |
+
input_context,
|
414 |
+
sharding=self._sharding,
|
415 |
+
repeat=self._is_training and not self._cache)
|
416 |
+
else:
|
417 |
+
raise ValueError('It is unexpected that `tfds_builder` is None and '
|
418 |
+
'there is also no `files`.')
|
419 |
+
|
420 |
+
if self._tfds_name:
|
421 |
+
if isinstance(self._tfds_name, cfg.base_config.Config):
|
422 |
+
dataset = {}
|
423 |
+
for k, tfds_name in self._tfds_name.as_dict().items():
|
424 |
+
dataset[k] = _read_tfds(
|
425 |
+
tfds_name=tfds_name,
|
426 |
+
tfds_data_dir=self._tfds_data_dir,
|
427 |
+
tfds_split=self._tfds_split,
|
428 |
+
tfds_skip_decoding_feature=self._tfds_skip_decoding_feature,
|
429 |
+
tfds_as_supervised=self._tfds_as_supervised,
|
430 |
+
input_context=input_context,
|
431 |
+
seed=self._seed,
|
432 |
+
is_training=self._is_training,
|
433 |
+
cache=self._cache,
|
434 |
+
cycle_length=self._cycle_length,
|
435 |
+
block_length=self._block_length)
|
436 |
+
else:
|
437 |
+
dataset = _read_tfds(
|
438 |
+
tfds_name=self._tfds_name,
|
439 |
+
tfds_data_dir=self._tfds_data_dir,
|
440 |
+
tfds_split=self._tfds_split,
|
441 |
+
tfds_skip_decoding_feature=self._tfds_skip_decoding_feature,
|
442 |
+
tfds_as_supervised=self._tfds_as_supervised,
|
443 |
+
input_context=input_context,
|
444 |
+
seed=self._seed,
|
445 |
+
is_training=self._is_training,
|
446 |
+
cache=self._cache,
|
447 |
+
cycle_length=self._cycle_length,
|
448 |
+
block_length=self._block_length)
|
449 |
+
elif isinstance(matched_files, (list, tuple)):
|
450 |
+
dataset = _files_to_dataset(matched_files)
|
451 |
+
elif isinstance(matched_files, dict):
|
452 |
+
dataset = {}
|
453 |
+
for k, fs in matched_files.items():
|
454 |
+
dataset[k] = _files_to_dataset(fs)
|
455 |
+
else:
|
456 |
+
raise ValueError('`matched_files` should be a list or dict.')
|
457 |
+
|
458 |
+
return dataset
|
459 |
+
|
460 |
+
def _decode_and_parse_dataset(
|
461 |
+
self,
|
462 |
+
dataset: Union[tf.data.Dataset, Dict[Text, tf.data.Dataset]],
|
463 |
+
batch_size: int,
|
464 |
+
input_context: Optional[tf.distribute.InputContext] = None
|
465 |
+
) -> tf.data.Dataset:
|
466 |
+
"""Returns a tf.data.Dataset object after shuffling, decoding, and parsing."""
|
467 |
+
|
468 |
+
def _shuffle_and_decode(ds):
|
469 |
+
# If cache is enabled, we will call `shuffle()` later after `cache()`.
|
470 |
+
if self._is_training and not self._cache:
|
471 |
+
ds = ds.shuffle(self._shuffle_buffer_size, seed=self._seed)
|
472 |
+
# Decode
|
473 |
+
ds = _maybe_map_fn(ds, self._decoder_fn)
|
474 |
+
return ds
|
475 |
+
|
476 |
+
dataset = tf.nest.map_structure(_shuffle_and_decode, dataset)
|
477 |
+
if tf.nest.is_nested(dataset):
|
478 |
+
dataset = self._combine_fn(dataset)
|
479 |
+
|
480 |
+
if self._sample_fn is not None:
|
481 |
+
dataset = dataset.apply(self._sample_fn)
|
482 |
+
dataset = _maybe_map_fn(dataset, self._parser_fn)
|
483 |
+
|
484 |
+
if self._filter_fn is not None:
|
485 |
+
dataset = dataset.filter(self._filter_fn)
|
486 |
+
|
487 |
+
if self._cache:
|
488 |
+
dataset = dataset.cache()
|
489 |
+
if self._is_training:
|
490 |
+
dataset = dataset.repeat()
|
491 |
+
dataset = dataset.shuffle(self._shuffle_buffer_size, seed=self._seed)
|
492 |
+
|
493 |
+
# Applies tf.data service before batching operations. This is useful when
|
494 |
+
# tf.data service is shared between parallel trainers, and batch size is
|
495 |
+
# changing between parallel trainers. Then batch size is changing, tf.data
|
496 |
+
# services will be considered different instances if applied after batching
|
497 |
+
# operations, which make it difficult to share between parallel trainers.
|
498 |
+
# However, if there are additional expensive operations in
|
499 |
+
# self._transform_and_batch_fn and self._postprocess_fn, the entire tf.data
|
500 |
+
# pipeline could be slowed down. In this case, try to move these dataset
|
501 |
+
# operations into early stages if possible.
|
502 |
+
if (self._enable_shared_tf_data_service_between_parallel_trainers and
|
503 |
+
self._apply_tf_data_service_before_batching):
|
504 |
+
dataset = self._maybe_apply_data_service(dataset, input_context)
|
505 |
+
|
506 |
+
if self._transform_and_batch_fn is not None:
|
507 |
+
dataset = self._transform_and_batch_fn(dataset, input_context)
|
508 |
+
else:
|
509 |
+
per_replica_batch_size = input_context.get_per_replica_batch_size(
|
510 |
+
batch_size) if input_context else batch_size
|
511 |
+
dataset = dataset.batch(
|
512 |
+
per_replica_batch_size, drop_remainder=self._drop_remainder)
|
513 |
+
|
514 |
+
return dataset
|
515 |
+
|
516 |
+
def _maybe_apply_data_service(
|
517 |
+
self,
|
518 |
+
dataset: tf.data.Dataset,
|
519 |
+
input_context: Optional[tf.distribute.InputContext] = None
|
520 |
+
) -> tf.data.Dataset:
|
521 |
+
"""Potentially distributes a dataset."""
|
522 |
+
if self._enable_tf_data_service and input_context:
|
523 |
+
if self._enable_round_robin_tf_data_service:
|
524 |
+
replicas_per_input_pipeline = input_context.num_replicas_in_sync // (
|
525 |
+
input_context.num_input_pipelines)
|
526 |
+
base_consumer_index = input_context.input_pipeline_id * (
|
527 |
+
replicas_per_input_pipeline)
|
528 |
+
num_consumers = input_context.num_input_pipelines * (
|
529 |
+
replicas_per_input_pipeline)
|
530 |
+
range_dataset = tf.data.Dataset.range(replicas_per_input_pipeline)
|
531 |
+
tfds_kwargs = {
|
532 |
+
'processing_mode': 'parallel_epochs',
|
533 |
+
'service': self._tf_data_service_address,
|
534 |
+
'job_name': self._tf_data_service_job_name,
|
535 |
+
'num_consumers': num_consumers
|
536 |
+
}
|
537 |
+
if self._enable_shared_tf_data_service_between_parallel_trainers:
|
538 |
+
raise ValueError('Shared tf.data service does not support round-robin'
|
539 |
+
' tf.data service.')
|
540 |
+
dataset = range_dataset.map(lambda i: dataset.apply( # pylint: disable=g-long-lambda
|
541 |
+
tf.data.experimental.service.distribute(
|
542 |
+
consumer_index=base_consumer_index + i, **tfds_kwargs)))
|
543 |
+
# Use parallel interleave to read multiple batches from a tf.data
|
544 |
+
# service worker in parallel.
|
545 |
+
dataset = dataset.interleave(
|
546 |
+
lambda x: x,
|
547 |
+
cycle_length=replicas_per_input_pipeline,
|
548 |
+
num_parallel_calls=replicas_per_input_pipeline,
|
549 |
+
deterministic=True)
|
550 |
+
else:
|
551 |
+
tfds_kwargs = {
|
552 |
+
'processing_mode': 'parallel_epochs',
|
553 |
+
'service': self._tf_data_service_address,
|
554 |
+
'job_name': self._tf_data_service_job_name,
|
555 |
+
}
|
556 |
+
if self._enable_shared_tf_data_service_between_parallel_trainers:
|
557 |
+
tfds_kwargs.update({
|
558 |
+
'processing_mode':
|
559 |
+
tf.data.experimental.service.ShardingPolicy.OFF,
|
560 |
+
'cross_trainer_cache':
|
561 |
+
tf.data.experimental.service.CrossTrainerCache(
|
562 |
+
trainer_id=self._trainer_id)
|
563 |
+
})
|
564 |
+
dataset = dataset.apply(
|
565 |
+
tf.data.experimental.service.distribute(**tfds_kwargs))
|
566 |
+
return dataset
|
567 |
+
|
568 |
+
def read(self,
|
569 |
+
input_context: Optional[tf.distribute.InputContext] = None,
|
570 |
+
dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
|
571 |
+
"""Generates a tf.data.Dataset object."""
|
572 |
+
if dataset is None:
|
573 |
+
dataset = self._read_data_source(self._matched_files, self._dataset_fn,
|
574 |
+
input_context)
|
575 |
+
dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
|
576 |
+
input_context)
|
577 |
+
dataset = _maybe_map_fn(dataset, self._postprocess_fn)
|
578 |
+
if not (self._enable_shared_tf_data_service_between_parallel_trainers and
|
579 |
+
self._apply_tf_data_service_before_batching):
|
580 |
+
dataset = self._maybe_apply_data_service(dataset, input_context)
|
581 |
+
|
582 |
+
if self._deterministic is not None:
|
583 |
+
options = tf.data.Options()
|
584 |
+
options.deterministic = self._deterministic
|
585 |
+
dataset = dataset.with_options(options)
|
586 |
+
if self._autotune_algorithm:
|
587 |
+
options = tf.data.Options()
|
588 |
+
options.autotune.autotune_algorithm = (
|
589 |
+
tf.data.experimental.AutotuneAlgorithm[self._autotune_algorithm])
|
590 |
+
dataset = dataset.with_options(options)
|
591 |
+
return dataset.prefetch(self._prefetch_buffer_size)
|
modeling/official/core/registry.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Registry utility."""
|
16 |
+
|
17 |
+
|
18 |
+
def register(registered_collection, reg_key):
|
19 |
+
"""Register decorated function or class to collection.
|
20 |
+
|
21 |
+
Register decorated function or class into registered_collection, in a
|
22 |
+
hierarchical order. For example, when reg_key="my_model/my_exp/my_config_0"
|
23 |
+
the decorated function or class is stored under
|
24 |
+
registered_collection["my_model"]["my_exp"]["my_config_0"].
|
25 |
+
This decorator is supposed to be used together with the lookup() function in
|
26 |
+
this file.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
registered_collection: a dictionary. The decorated function or class will be
|
30 |
+
put into this collection.
|
31 |
+
reg_key: The key for retrieving the registered function or class. If reg_key
|
32 |
+
is a string, it can be hierarchical like my_model/my_exp/my_config_0
|
33 |
+
Returns:
|
34 |
+
A decorator function
|
35 |
+
Raises:
|
36 |
+
KeyError: when function or class to register already exists.
|
37 |
+
"""
|
38 |
+
def decorator(fn_or_cls):
|
39 |
+
"""Put fn_or_cls in the dictionary."""
|
40 |
+
if isinstance(reg_key, str):
|
41 |
+
hierarchy = reg_key.split("/")
|
42 |
+
collection = registered_collection
|
43 |
+
for h_idx, entry_name in enumerate(hierarchy[:-1]):
|
44 |
+
if entry_name not in collection:
|
45 |
+
collection[entry_name] = {}
|
46 |
+
collection = collection[entry_name]
|
47 |
+
if not isinstance(collection, dict):
|
48 |
+
raise KeyError(
|
49 |
+
"Collection path {} at position {} already registered as "
|
50 |
+
"a function or class.".format(entry_name, h_idx))
|
51 |
+
leaf_reg_key = hierarchy[-1]
|
52 |
+
else:
|
53 |
+
collection = registered_collection
|
54 |
+
leaf_reg_key = reg_key
|
55 |
+
|
56 |
+
if leaf_reg_key in collection:
|
57 |
+
raise KeyError("Function or class {} registered multiple times.".format(
|
58 |
+
leaf_reg_key))
|
59 |
+
|
60 |
+
collection[leaf_reg_key] = fn_or_cls
|
61 |
+
return fn_or_cls
|
62 |
+
return decorator
|
63 |
+
|
64 |
+
|
65 |
+
def lookup(registered_collection, reg_key):
|
66 |
+
"""Lookup and return decorated function or class in the collection.
|
67 |
+
|
68 |
+
Lookup decorated function or class in registered_collection, in a
|
69 |
+
hierarchical order. For example, when
|
70 |
+
reg_key="my_model/my_exp/my_config_0",
|
71 |
+
this function will return
|
72 |
+
registered_collection["my_model"]["my_exp"]["my_config_0"].
|
73 |
+
|
74 |
+
Args:
|
75 |
+
registered_collection: a dictionary. The decorated function or class will be
|
76 |
+
retrieved from this collection.
|
77 |
+
reg_key: The key for retrieving the registered function or class. If reg_key
|
78 |
+
is a string, it can be hierarchical like my_model/my_exp/my_config_0
|
79 |
+
Returns:
|
80 |
+
The registered function or class.
|
81 |
+
Raises:
|
82 |
+
LookupError: when reg_key cannot be found.
|
83 |
+
"""
|
84 |
+
if isinstance(reg_key, str):
|
85 |
+
hierarchy = reg_key.split("/")
|
86 |
+
collection = registered_collection
|
87 |
+
for h_idx, entry_name in enumerate(hierarchy):
|
88 |
+
if entry_name not in collection:
|
89 |
+
raise LookupError(
|
90 |
+
f"collection path {entry_name} at position {h_idx} is never "
|
91 |
+
f"registered. Please make sure the {entry_name} and its library is "
|
92 |
+
"imported and linked to the trainer binary.")
|
93 |
+
collection = collection[entry_name]
|
94 |
+
return collection
|
95 |
+
else:
|
96 |
+
if reg_key not in registered_collection:
|
97 |
+
raise LookupError(
|
98 |
+
f"registration key {reg_key} is never "
|
99 |
+
f"registered. Please make sure the {reg_key} and its library is "
|
100 |
+
"imported and linked to the trainer binary.")
|
101 |
+
return registered_collection[reg_key]
|
modeling/official/core/registry_test.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Tests for registry."""
|
16 |
+
|
17 |
+
import tensorflow as tf, tf_keras
|
18 |
+
from official.core import registry
|
19 |
+
|
20 |
+
|
21 |
+
class RegistryTest(tf.test.TestCase):
|
22 |
+
|
23 |
+
def test_register(self):
|
24 |
+
collection = {}
|
25 |
+
|
26 |
+
@registry.register(collection, 'functions/func_0')
|
27 |
+
def func_test():
|
28 |
+
pass
|
29 |
+
|
30 |
+
self.assertEqual(registry.lookup(collection, 'functions/func_0'), func_test)
|
31 |
+
|
32 |
+
@registry.register(collection, 'classes/cls_0')
|
33 |
+
class ClassRegistryKey:
|
34 |
+
pass
|
35 |
+
|
36 |
+
self.assertEqual(
|
37 |
+
registry.lookup(collection, 'classes/cls_0'), ClassRegistryKey)
|
38 |
+
|
39 |
+
@registry.register(collection, ClassRegistryKey)
|
40 |
+
class ClassRegistryValue:
|
41 |
+
pass
|
42 |
+
|
43 |
+
self.assertEqual(
|
44 |
+
registry.lookup(collection, ClassRegistryKey), ClassRegistryValue)
|
45 |
+
|
46 |
+
def test_register_hierarchy(self):
|
47 |
+
collection = {}
|
48 |
+
|
49 |
+
@registry.register(collection, 'functions/func_0')
|
50 |
+
def func_test0():
|
51 |
+
pass
|
52 |
+
|
53 |
+
@registry.register(collection, 'func_1')
|
54 |
+
def func_test1():
|
55 |
+
pass
|
56 |
+
|
57 |
+
@registry.register(collection, func_test1)
|
58 |
+
def func_test2():
|
59 |
+
pass
|
60 |
+
|
61 |
+
expected_collection = {
|
62 |
+
'functions': {
|
63 |
+
'func_0': func_test0,
|
64 |
+
},
|
65 |
+
'func_1': func_test1,
|
66 |
+
func_test1: func_test2,
|
67 |
+
}
|
68 |
+
self.assertEqual(collection, expected_collection)
|
69 |
+
|
70 |
+
def test_register_error(self):
|
71 |
+
collection = {}
|
72 |
+
|
73 |
+
@registry.register(collection, 'functions/func_0')
|
74 |
+
def func_test0(): # pylint: disable=unused-variable
|
75 |
+
pass
|
76 |
+
|
77 |
+
with self.assertRaises(KeyError):
|
78 |
+
|
79 |
+
@registry.register(collection, 'functions/func_0/sub_func')
|
80 |
+
def func_test1(): # pylint: disable=unused-variable
|
81 |
+
pass
|
82 |
+
|
83 |
+
with self.assertRaises(LookupError):
|
84 |
+
registry.lookup(collection, 'non-exist')
|
85 |
+
|
86 |
+
|
87 |
+
if __name__ == '__main__':
|
88 |
+
tf.test.main()
|
modeling/official/core/savedmodel_checkpoint_manager.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Custom checkpoint manager that also exports saved models."""
|
16 |
+
|
17 |
+
import os
|
18 |
+
import re
|
19 |
+
import time
|
20 |
+
from typing import Callable, List, Mapping, Optional, Union
|
21 |
+
|
22 |
+
from absl import logging
|
23 |
+
import tensorflow as tf, tf_keras
|
24 |
+
|
25 |
+
SAVED_MODULES_PATH_SUFFIX = 'saved_modules'
|
26 |
+
|
27 |
+
|
28 |
+
def make_saved_modules_directory_name(checkpoint_name: str) -> str:
|
29 |
+
return f'{checkpoint_name}_{SAVED_MODULES_PATH_SUFFIX}'
|
30 |
+
|
31 |
+
|
32 |
+
class SavedModelCheckpointManager(tf.train.CheckpointManager):
|
33 |
+
"""A CheckpointManager that also exports `SavedModel`s."""
|
34 |
+
|
35 |
+
def __init__(self,
|
36 |
+
checkpoint: tf.train.Checkpoint,
|
37 |
+
directory: str,
|
38 |
+
max_to_keep: int,
|
39 |
+
modules_to_export: Optional[Mapping[str, tf.Module]] = None,
|
40 |
+
keep_checkpoint_every_n_hours: Optional[int] = None,
|
41 |
+
checkpoint_name: str = 'ckpt',
|
42 |
+
step_counter: Optional[tf.Variable] = None,
|
43 |
+
checkpoint_interval: Optional[int] = None,
|
44 |
+
init_fn: Optional[Callable[[], None]] = None):
|
45 |
+
"""See base class."""
|
46 |
+
super().__init__(
|
47 |
+
checkpoint=checkpoint,
|
48 |
+
directory=directory,
|
49 |
+
max_to_keep=max_to_keep,
|
50 |
+
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
|
51 |
+
checkpoint_name=checkpoint_name,
|
52 |
+
step_counter=step_counter,
|
53 |
+
checkpoint_interval=checkpoint_interval,
|
54 |
+
init_fn=init_fn)
|
55 |
+
self._modules_to_export = modules_to_export
|
56 |
+
self._savedmodels = self.get_existing_savedmodels()
|
57 |
+
|
58 |
+
def save(self,
|
59 |
+
checkpoint_number: Optional[int] = None,
|
60 |
+
check_interval: bool = True,
|
61 |
+
options: Optional[tf.train.CheckpointOptions] = None):
|
62 |
+
"""See base class."""
|
63 |
+
checkpoint_path = super().save(
|
64 |
+
checkpoint_number=checkpoint_number,
|
65 |
+
check_interval=check_interval,
|
66 |
+
options=options)
|
67 |
+
if not checkpoint_path: # Nothing got written.
|
68 |
+
return
|
69 |
+
if not self._modules_to_export: # No modules to export.
|
70 |
+
logging.info('Skip saving SavedModel due to empty modules_to_export.')
|
71 |
+
return checkpoint_path
|
72 |
+
|
73 |
+
# Save the models for the checkpoint that just got written.
|
74 |
+
saved_modules_directory = make_saved_modules_directory_name(checkpoint_path)
|
75 |
+
# Atomic export of SavedModel. Write into a temporary direcotory and then
|
76 |
+
# rename as the final direcotory after finishing the writing.
|
77 |
+
# This can avoid trying to read an unfinished savedmodel.
|
78 |
+
saved_modules_directory_tmp = saved_modules_directory + '_temp'
|
79 |
+
for model_name, model in self._modules_to_export.items():
|
80 |
+
signatures = getattr(model, 'saved_model_signatures', None)
|
81 |
+
if signatures is not None:
|
82 |
+
tf.saved_model.save(
|
83 |
+
obj=model,
|
84 |
+
export_dir=os.path.join(saved_modules_directory_tmp, model_name),
|
85 |
+
signatures=signatures)
|
86 |
+
if tf.io.gfile.exists(saved_modules_directory_tmp):
|
87 |
+
tf.io.gfile.rename(saved_modules_directory_tmp, saved_modules_directory)
|
88 |
+
|
89 |
+
saved_modules_directories_to_keep = [
|
90 |
+
make_saved_modules_directory_name(ckpt) for ckpt in self.checkpoints
|
91 |
+
]
|
92 |
+
existing_saved_modules_dirs = self.get_existing_savedmodels()
|
93 |
+
|
94 |
+
self._savedmodels = []
|
95 |
+
# Keep savedmodels in the same order as checkpoints (from oldest to newest).
|
96 |
+
for saved_modules_dir_to_keep in saved_modules_directories_to_keep:
|
97 |
+
if saved_modules_dir_to_keep in existing_saved_modules_dirs:
|
98 |
+
self._savedmodels.append(saved_modules_dir_to_keep)
|
99 |
+
|
100 |
+
for existing_saved_modules_dir in existing_saved_modules_dirs:
|
101 |
+
if existing_saved_modules_dir not in self._savedmodels:
|
102 |
+
tf.io.gfile.rmtree(existing_saved_modules_dir)
|
103 |
+
|
104 |
+
return checkpoint_path
|
105 |
+
|
106 |
+
def get_existing_savedmodels(self) -> List[str]:
|
107 |
+
"""Gets a list of all existing SavedModel paths in `directory`.
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
A list of all existing SavedModel paths.
|
111 |
+
"""
|
112 |
+
saved_modules_glob = make_saved_modules_directory_name(
|
113 |
+
self._checkpoint_prefix + '-*')
|
114 |
+
savedmodels = tf.io.gfile.glob(saved_modules_glob)
|
115 |
+
# Filter out temporary savedmodel.
|
116 |
+
savedmodels = [
|
117 |
+
savedmodel
|
118 |
+
for savedmodel in savedmodels
|
119 |
+
if savedmodel.endswith(SAVED_MODULES_PATH_SUFFIX)
|
120 |
+
]
|
121 |
+
return savedmodels
|
122 |
+
|
123 |
+
@property
|
124 |
+
def latest_savedmodel(self) -> Union[str, None]:
|
125 |
+
"""The path of the most recent SavedModel in `directory`.
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
The latest SavedModel path. If there are no SavedModels, returns `None`.
|
129 |
+
"""
|
130 |
+
if self._savedmodels:
|
131 |
+
return self._savedmodels[-1]
|
132 |
+
return None
|
133 |
+
|
134 |
+
@property
|
135 |
+
def savedmodels(self) -> List[str]:
|
136 |
+
"""A list of managed SavedModels.
|
137 |
+
|
138 |
+
Returns:
|
139 |
+
A list of SavedModel paths, sorted from oldest to newest.
|
140 |
+
"""
|
141 |
+
return self._savedmodels
|
142 |
+
|
143 |
+
@property
|
144 |
+
def modules_to_export(self) -> Union[Mapping[str, tf.Module], None]:
|
145 |
+
return self._modules_to_export
|
146 |
+
|
147 |
+
def get_savedmodel_number_from_path(self,
|
148 |
+
savedmodel_path: str) -> Union[int, None]:
|
149 |
+
"""Gets the savedmodel_number/checkpoint_number from savedmodel filepath.
|
150 |
+
|
151 |
+
The savedmodel_number is global step when using with orbit controller.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
savedmodel_path: savedmodel directory path.
|
155 |
+
|
156 |
+
Returns:
|
157 |
+
Savedmodel number or None if no matched pattern found in savedmodel path.
|
158 |
+
"""
|
159 |
+
pattern = rf'\d+_{SAVED_MODULES_PATH_SUFFIX}$'
|
160 |
+
savedmodel_number = re.search(pattern, savedmodel_path)
|
161 |
+
if savedmodel_number:
|
162 |
+
savedmodel_number = savedmodel_number.group()
|
163 |
+
return int(savedmodel_number[:-len(SAVED_MODULES_PATH_SUFFIX) - 1])
|
164 |
+
return None
|
165 |
+
|
166 |
+
def savedmodels_iterator(self,
|
167 |
+
min_interval_secs: float = 0,
|
168 |
+
timeout: Optional[float] = None,
|
169 |
+
timeout_fn: Optional[Callable[[], bool]] = None):
|
170 |
+
"""Continuously yield new SavedModel files as they appear.
|
171 |
+
|
172 |
+
The iterator only checks for new savedmodels when control flow has been
|
173 |
+
reverted to it. The logic is same to the `train.checkpoints_iterator`.
|
174 |
+
|
175 |
+
Args:
|
176 |
+
min_interval_secs: The minimum number of seconds between yielding
|
177 |
+
savedmodels.
|
178 |
+
timeout: The maximum number of seconds to wait between savedmodels. If
|
179 |
+
left as `None`, then the process will wait indefinitely.
|
180 |
+
timeout_fn: Optional function to call after a timeout. If the function
|
181 |
+
returns True, then it means that no new savedmodels will be generated
|
182 |
+
and the iterator will exit. The function is called with no arguments.
|
183 |
+
|
184 |
+
Yields:
|
185 |
+
String paths to latest SavedModel files as they arrive.
|
186 |
+
"""
|
187 |
+
savedmodel_path = None
|
188 |
+
while True:
|
189 |
+
new_savedmodel_path = self.wait_for_new_savedmodel(
|
190 |
+
savedmodel_path, timeout=timeout)
|
191 |
+
if new_savedmodel_path is None:
|
192 |
+
if not timeout_fn:
|
193 |
+
# timed out
|
194 |
+
logging.info('Timed-out waiting for a savedmodel.')
|
195 |
+
return
|
196 |
+
if timeout_fn():
|
197 |
+
# The timeout_fn indicated that we are truly done.
|
198 |
+
return
|
199 |
+
else:
|
200 |
+
# The timeout_fn indicated that more savedmodels may come.
|
201 |
+
continue
|
202 |
+
start = time.time()
|
203 |
+
savedmodel_path = new_savedmodel_path
|
204 |
+
yield savedmodel_path
|
205 |
+
time_to_next_eval = start + min_interval_secs - time.time()
|
206 |
+
if time_to_next_eval > 0:
|
207 |
+
time.sleep(time_to_next_eval)
|
208 |
+
|
209 |
+
def wait_for_new_savedmodel(
|
210 |
+
self,
|
211 |
+
last_savedmodel: Optional[str] = None,
|
212 |
+
seconds_to_sleep: float = 1.0,
|
213 |
+
timeout: Optional[float] = None) -> Union[str, None]:
|
214 |
+
"""Waits until a new savedmodel file is found.
|
215 |
+
|
216 |
+
Args:
|
217 |
+
last_savedmodel: The last savedmodel path used or `None` if we're
|
218 |
+
expecting a savedmodel for the first time.
|
219 |
+
seconds_to_sleep: The number of seconds to sleep for before looking for a
|
220 |
+
new savedmodel.
|
221 |
+
timeout: The maximum number of seconds to wait. If left as `None`, then
|
222 |
+
the process will wait indefinitely.
|
223 |
+
|
224 |
+
Returns:
|
225 |
+
A new savedmodel path, or None if the timeout was reached.
|
226 |
+
"""
|
227 |
+
logging.info('Waiting for new savedmodel at %s', self._directory)
|
228 |
+
stop_time = time.time() + timeout if timeout is not None else None
|
229 |
+
|
230 |
+
last_savedmodel_number = -1
|
231 |
+
if last_savedmodel:
|
232 |
+
last_savedmodel_number = self.get_savedmodel_number_from_path(
|
233 |
+
last_savedmodel)
|
234 |
+
|
235 |
+
while True:
|
236 |
+
if stop_time is not None and time.time() + seconds_to_sleep > stop_time:
|
237 |
+
return None
|
238 |
+
|
239 |
+
existing_savedmodels = {}
|
240 |
+
for savedmodel_path in self.get_existing_savedmodels():
|
241 |
+
savedmodel_number = self.get_savedmodel_number_from_path(
|
242 |
+
savedmodel_path)
|
243 |
+
if savedmodel_number is not None:
|
244 |
+
existing_savedmodels[savedmodel_number] = savedmodel_path
|
245 |
+
|
246 |
+
# Find the first savedmodel with larger step number as next savedmodel.
|
247 |
+
savedmodel_path = None
|
248 |
+
existing_savedmodels = dict(sorted(existing_savedmodels.items()))
|
249 |
+
for savedmodel_number in existing_savedmodels:
|
250 |
+
if savedmodel_number > last_savedmodel_number:
|
251 |
+
savedmodel_path = existing_savedmodels[savedmodel_number]
|
252 |
+
break
|
253 |
+
|
254 |
+
if savedmodel_path:
|
255 |
+
logging.info('Found new savedmodel at %s', savedmodel_path)
|
256 |
+
return savedmodel_path
|
257 |
+
else:
|
258 |
+
time.sleep(seconds_to_sleep)
|
modeling/official/core/savedmodel_checkpoint_manager_test.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import time
|
17 |
+
from typing import Iterable
|
18 |
+
|
19 |
+
import tensorflow as tf, tf_keras
|
20 |
+
|
21 |
+
from official.core import savedmodel_checkpoint_manager
|
22 |
+
|
23 |
+
|
24 |
+
def _models_exist(checkpoint_path: str, models: Iterable[str]) -> bool:
|
25 |
+
for model_name in models:
|
26 |
+
if not tf.io.gfile.isdir(
|
27 |
+
os.path.join(
|
28 |
+
savedmodel_checkpoint_manager.make_saved_modules_directory_name(
|
29 |
+
checkpoint_path), model_name)):
|
30 |
+
return False
|
31 |
+
return True
|
32 |
+
|
33 |
+
|
34 |
+
class _ModelForTest(tf_keras.Model):
|
35 |
+
def __init__(self, hidden_size: int = 8):
|
36 |
+
super().__init__()
|
37 |
+
self.dense = tf_keras.layers.Dense(hidden_size)
|
38 |
+
|
39 |
+
@tf.function(input_signature=[tf.TensorSpec([None, 16])])
|
40 |
+
def call(self, inputs):
|
41 |
+
return self.dense(inputs)
|
42 |
+
|
43 |
+
@property
|
44 |
+
def saved_model_signatures(self):
|
45 |
+
# Build SavedModel signatures.
|
46 |
+
return dict(serving_default=self.call)
|
47 |
+
|
48 |
+
|
49 |
+
class CheckpointManagerTest(tf.test.TestCase):
|
50 |
+
|
51 |
+
def _create_manager(self, max_to_keep: int = 1) -> tf.train.CheckpointManager:
|
52 |
+
"""Sets up SavedModelCheckpointManager object.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
max_to_keep: max number of savedmodels to keep.
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
created savedmodel manager.
|
59 |
+
"""
|
60 |
+
models = {
|
61 |
+
'model_1': _ModelForTest(12),
|
62 |
+
'model_2': _ModelForTest(14),
|
63 |
+
}
|
64 |
+
checkpoint = tf.train.Checkpoint()
|
65 |
+
manager = savedmodel_checkpoint_manager.SavedModelCheckpointManager(
|
66 |
+
checkpoint=checkpoint,
|
67 |
+
directory=self.get_temp_dir(),
|
68 |
+
max_to_keep=max_to_keep,
|
69 |
+
modules_to_export=models)
|
70 |
+
return manager
|
71 |
+
|
72 |
+
def test_max_to_keep(self):
|
73 |
+
manager = self._create_manager()
|
74 |
+
models = manager.modules_to_export
|
75 |
+
first_path = manager.save()
|
76 |
+
second_path = manager.save()
|
77 |
+
|
78 |
+
savedmodel = savedmodel_checkpoint_manager.make_saved_modules_directory_name(
|
79 |
+
manager.latest_checkpoint)
|
80 |
+
self.assertEqual(savedmodel, manager.latest_savedmodel)
|
81 |
+
self.assertTrue(_models_exist(second_path, models.keys()))
|
82 |
+
self.assertFalse(_models_exist(first_path, models.keys()))
|
83 |
+
|
84 |
+
def test_returns_none_after_timeout(self):
|
85 |
+
manager = self._create_manager()
|
86 |
+
start = time.time()
|
87 |
+
ret = manager.wait_for_new_savedmodel(
|
88 |
+
None, timeout=1.0, seconds_to_sleep=0.5)
|
89 |
+
end = time.time()
|
90 |
+
self.assertIsNone(ret)
|
91 |
+
# We've waited 0.5 second.
|
92 |
+
self.assertGreater(end, start + 0.5)
|
93 |
+
# The timeout kicked in.
|
94 |
+
self.assertLess(end, start + 0.6)
|
95 |
+
|
96 |
+
def test_saved_model_iterator(self):
|
97 |
+
manager = self._create_manager(max_to_keep=2)
|
98 |
+
self.assertIsNotNone(manager.save(checkpoint_number=1))
|
99 |
+
self.assertIsNotNone(manager.save(checkpoint_number=2))
|
100 |
+
self.assertIsNotNone(manager.save(checkpoint_number=3))
|
101 |
+
|
102 |
+
# Savedmodels are in time order.
|
103 |
+
expected_savedmodels = manager.savedmodels
|
104 |
+
# Order not guaranteed.
|
105 |
+
existing_savedmodels = manager.get_existing_savedmodels()
|
106 |
+
savedmodels = list(manager.savedmodels_iterator(timeout=3.0))
|
107 |
+
self.assertEqual(savedmodels, expected_savedmodels)
|
108 |
+
self.assertEqual(set(savedmodels), set(existing_savedmodels))
|
109 |
+
|
110 |
+
def test_saved_model_iterator_timeout_fn(self):
|
111 |
+
manager = self._create_manager()
|
112 |
+
timeout_fn_calls = [0]
|
113 |
+
|
114 |
+
def timeout_fn():
|
115 |
+
timeout_fn_calls[0] += 1
|
116 |
+
return timeout_fn_calls[0] > 3
|
117 |
+
|
118 |
+
results = list(
|
119 |
+
manager.savedmodels_iterator(timeout=0.1, timeout_fn=timeout_fn))
|
120 |
+
self.assertEqual([], results)
|
121 |
+
self.assertEqual(4, timeout_fn_calls[0])
|
122 |
+
|
123 |
+
|
124 |
+
if __name__ == '__main__':
|
125 |
+
tf.test.main()
|
modeling/official/core/task_factory.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""A global factory to register and access all registered tasks."""
|
16 |
+
|
17 |
+
from official.core import registry
|
18 |
+
|
19 |
+
_REGISTERED_TASK_CLS = {}
|
20 |
+
|
21 |
+
|
22 |
+
# TODO(b/158741360): Add type annotations once pytype checks across modules.
|
23 |
+
def register_task_cls(task_config_cls):
|
24 |
+
"""Decorates a factory of Tasks for lookup by a subclass of TaskConfig.
|
25 |
+
|
26 |
+
This decorator supports registration of tasks as follows:
|
27 |
+
|
28 |
+
```
|
29 |
+
@dataclasses.dataclass
|
30 |
+
class MyTaskConfig(TaskConfig):
|
31 |
+
# Add fields here.
|
32 |
+
pass
|
33 |
+
|
34 |
+
@register_task_cls(MyTaskConfig)
|
35 |
+
class MyTask(Task):
|
36 |
+
# Inherits def __init__(self, task_config).
|
37 |
+
pass
|
38 |
+
|
39 |
+
my_task_config = MyTaskConfig()
|
40 |
+
my_task = get_task(my_task_config) # Returns MyTask(my_task_config).
|
41 |
+
```
|
42 |
+
|
43 |
+
Besisdes a class itself, other callables that create a Task from a TaskConfig
|
44 |
+
can be decorated by the result of this function, as long as there is at most
|
45 |
+
one registration for each config class.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
task_config_cls: a subclass of TaskConfig (*not* an instance of TaskConfig).
|
49 |
+
Each task_config_cls can only be used for a single registration.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
A callable for use as class decorator that registers the decorated class
|
53 |
+
for creation from an instance of task_config_cls.
|
54 |
+
"""
|
55 |
+
return registry.register(_REGISTERED_TASK_CLS, task_config_cls)
|
56 |
+
|
57 |
+
|
58 |
+
def get_task(task_config, **kwargs):
|
59 |
+
"""Creates a Task (of suitable subclass type) from task_config."""
|
60 |
+
# TODO(hongkuny): deprecate the task factory to use config.BUILDER.
|
61 |
+
if task_config.BUILDER is not None:
|
62 |
+
return task_config.BUILDER(task_config, **kwargs)
|
63 |
+
return get_task_cls(task_config.__class__)(task_config, **kwargs)
|
64 |
+
|
65 |
+
|
66 |
+
# The user-visible get_task() is defined after classes have been registered.
|
67 |
+
# TODO(b/158741360): Add type annotations once pytype checks across modules.
|
68 |
+
def get_task_cls(task_config_cls):
|
69 |
+
task_cls = registry.lookup(_REGISTERED_TASK_CLS, task_config_cls)
|
70 |
+
return task_cls
|
modeling/official/core/test_utils.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Utils for testing."""
|
16 |
+
|
17 |
+
import tensorflow as tf, tf_keras
|
18 |
+
|
19 |
+
|
20 |
+
class FakeKerasModel(tf_keras.Model):
|
21 |
+
"""Fake keras model for testing."""
|
22 |
+
|
23 |
+
def __init__(self):
|
24 |
+
super().__init__()
|
25 |
+
self.dense = tf_keras.layers.Dense(4, activation=tf.nn.relu)
|
26 |
+
self.dense2 = tf_keras.layers.Dense(4, activation=tf.nn.relu)
|
27 |
+
|
28 |
+
def call(self, inputs): # pytype: disable=signature-mismatch # overriding-parameter-count-checks
|
29 |
+
return self.dense2(self.dense(inputs))
|
30 |
+
|
31 |
+
|
32 |
+
class _Dense(tf.Module):
|
33 |
+
"""A dense layer."""
|
34 |
+
|
35 |
+
def __init__(self, input_dim, output_size, name=None):
|
36 |
+
super().__init__(name=name)
|
37 |
+
with self.name_scope:
|
38 |
+
self.w = tf.Variable(
|
39 |
+
tf.random.normal([input_dim, output_size]), name='w')
|
40 |
+
self.b = tf.Variable(tf.zeros([output_size]), name='b')
|
41 |
+
|
42 |
+
@tf.Module.with_name_scope
|
43 |
+
def __call__(self, x):
|
44 |
+
y = tf.matmul(x, self.w) + self.b
|
45 |
+
return tf.nn.relu(y)
|
46 |
+
|
47 |
+
|
48 |
+
class FakeModule(tf.Module):
|
49 |
+
"""Fake model using tf.Module for testing."""
|
50 |
+
|
51 |
+
def __init__(self, input_size, name=None):
|
52 |
+
super().__init__(name=name)
|
53 |
+
with self.name_scope:
|
54 |
+
self.dense = _Dense(input_size, 4, name='dense')
|
55 |
+
self.dense2 = _Dense(4, 4, name='dense_1')
|
56 |
+
|
57 |
+
@tf.Module.with_name_scope
|
58 |
+
def __call__(self, x):
|
59 |
+
return self.dense2(self.dense(x))
|
modeling/official/core/tf_example_builder.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Builder class for preparing tf.train.Example."""
|
16 |
+
|
17 |
+
# https://www.python.org/dev/peps/pep-0563/#enabling-the-future-behavior-in-python-3-7
|
18 |
+
from __future__ import annotations
|
19 |
+
|
20 |
+
from typing import Mapping, Sequence, Union
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import tensorflow as tf, tf_keras
|
24 |
+
|
25 |
+
BytesValueType = Union[bytes, Sequence[bytes], str, Sequence[str]]
|
26 |
+
|
27 |
+
_to_array = lambda v: [v] if not isinstance(v, (list, np.ndarray)) else v
|
28 |
+
_to_bytes = lambda v: v.encode() if isinstance(v, str) else v
|
29 |
+
_to_bytes_array = lambda v: list(map(_to_bytes, _to_array(v)))
|
30 |
+
|
31 |
+
|
32 |
+
class TfExampleBuilder(object):
|
33 |
+
"""Builder class for preparing tf.train.Example.
|
34 |
+
|
35 |
+
Read API doc at https://www.tensorflow.org/api_docs/python/tf/train/Example.
|
36 |
+
|
37 |
+
Example usage:
|
38 |
+
>>> example_builder = TfExampleBuilder()
|
39 |
+
>>> example = (
|
40 |
+
example_builder.add_bytes_feature('feature_a', 'foobarbaz')
|
41 |
+
.add_ints_feature('feature_b', [1, 2, 3])
|
42 |
+
.example)
|
43 |
+
"""
|
44 |
+
|
45 |
+
def __init__(self) -> None:
|
46 |
+
self._example = tf.train.Example()
|
47 |
+
|
48 |
+
@property
|
49 |
+
def example(self) -> tf.train.Example:
|
50 |
+
"""Returns a copy of the generated tf.train.Example proto."""
|
51 |
+
return self._example
|
52 |
+
|
53 |
+
@property
|
54 |
+
def serialized_example(self) -> str:
|
55 |
+
"""Returns a serialized string of the generated tf.train.Example proto."""
|
56 |
+
return self._example.SerializeToString()
|
57 |
+
|
58 |
+
def set(self, example: tf.train.Example) -> TfExampleBuilder:
|
59 |
+
"""Sets the example."""
|
60 |
+
self._example = example
|
61 |
+
return self
|
62 |
+
|
63 |
+
def reset(self) -> TfExampleBuilder:
|
64 |
+
"""Resets the example to an empty proto."""
|
65 |
+
self._example = tf.train.Example()
|
66 |
+
return self
|
67 |
+
|
68 |
+
###### Basic APIs for primitive data types ######
|
69 |
+
def add_feature_dict(
|
70 |
+
self, feature_dict: Mapping[str, tf.train.Feature]) -> TfExampleBuilder:
|
71 |
+
"""Adds the predefined `feature_dict` to the example.
|
72 |
+
|
73 |
+
Note: Please prefer to using feature-type-specific methods.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
feature_dict: A dictionary from tf.Example feature key to
|
77 |
+
tf.train.Feature.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
The builder object for subsequent method calls.
|
81 |
+
"""
|
82 |
+
for k, v in feature_dict.items():
|
83 |
+
self._example.features.feature[k].CopyFrom(v)
|
84 |
+
return self
|
85 |
+
|
86 |
+
def add_feature(self, key: str,
|
87 |
+
feature: tf.train.Feature) -> TfExampleBuilder:
|
88 |
+
"""Adds predefined `feature` with `key` to the example.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
key: String key of the feature.
|
92 |
+
feature: The feature to be added to the example.
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
The builder object for subsequent method calls.
|
96 |
+
"""
|
97 |
+
self._example.features.feature[key].CopyFrom(feature)
|
98 |
+
return self
|
99 |
+
|
100 |
+
def add_bytes_feature(self, key: str,
|
101 |
+
value: BytesValueType) -> TfExampleBuilder:
|
102 |
+
"""Adds byte(s) or string(s) with `key` to the example.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
key: String key of the feature.
|
106 |
+
value: The byte(s) or string(s) to be added to the example.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
The builder object for subsequent method calls.
|
110 |
+
"""
|
111 |
+
return self.add_feature(
|
112 |
+
key,
|
113 |
+
tf.train.Feature(
|
114 |
+
bytes_list=tf.train.BytesList(value=_to_bytes_array(value))))
|
115 |
+
|
116 |
+
def add_ints_feature(self, key: str,
|
117 |
+
value: Union[int, Sequence[int]]) -> TfExampleBuilder:
|
118 |
+
"""Adds integer(s) with `key` to the example.
|
119 |
+
|
120 |
+
Args:
|
121 |
+
key: String key of the feature.
|
122 |
+
value: The integer(s) to be added to the example.
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
The builder object for subsequent method calls.
|
126 |
+
"""
|
127 |
+
return self.add_feature(
|
128 |
+
key,
|
129 |
+
tf.train.Feature(int64_list=tf.train.Int64List(value=_to_array(value))))
|
130 |
+
|
131 |
+
def add_floats_feature(
|
132 |
+
self, key: str, value: Union[float, Sequence[float]]) -> TfExampleBuilder:
|
133 |
+
"""Adds float(s) with `key` to the example.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
key: String key of the feature.
|
137 |
+
value: The float(s) to be added to the example.
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
The builder object for subsequent method calls.
|
141 |
+
"""
|
142 |
+
return self.add_feature(
|
143 |
+
key,
|
144 |
+
tf.train.Feature(float_list=tf.train.FloatList(value=_to_array(value))))
|
modeling/official/core/tf_example_builder_test.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Tests for tf_example_builder.
|
16 |
+
|
17 |
+
See `test_add_image_matrix_feature_with_fake_image` for the typical structure of
|
18 |
+
a unit test.
|
19 |
+
"""
|
20 |
+
|
21 |
+
from absl.testing import parameterized
|
22 |
+
import tensorflow as tf, tf_keras
|
23 |
+
from official.core import tf_example_builder
|
24 |
+
|
25 |
+
|
26 |
+
class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
|
27 |
+
|
28 |
+
def test_init_an_empty_example(self):
|
29 |
+
example_builder = tf_example_builder.TfExampleBuilder()
|
30 |
+
example = example_builder.example
|
31 |
+
self.assertProtoEquals('', example)
|
32 |
+
|
33 |
+
def test_init_an_empty_serialized_example(self):
|
34 |
+
example_builder = tf_example_builder.TfExampleBuilder()
|
35 |
+
example = example_builder.serialized_example
|
36 |
+
self.assertProtoEquals('', example)
|
37 |
+
|
38 |
+
def test_add_feature(self):
|
39 |
+
example_builder = tf_example_builder.TfExampleBuilder()
|
40 |
+
example_builder.add_feature(
|
41 |
+
'foo',
|
42 |
+
tf.train.Feature(
|
43 |
+
bytes_list=tf.train.BytesList(value=[b'Hello World!'])))
|
44 |
+
example = example_builder.example
|
45 |
+
# Use proto text to show how the entire proto would look like.
|
46 |
+
self.assertProtoEquals(
|
47 |
+
"""
|
48 |
+
features: {
|
49 |
+
feature: {
|
50 |
+
key: "foo"
|
51 |
+
value: {
|
52 |
+
bytes_list: {
|
53 |
+
value: "Hello World!"
|
54 |
+
}
|
55 |
+
}
|
56 |
+
}
|
57 |
+
}""", example)
|
58 |
+
|
59 |
+
def test_add_feature_dict(self):
|
60 |
+
example_builder = tf_example_builder.TfExampleBuilder()
|
61 |
+
example_builder.add_feature_dict({
|
62 |
+
'foo':
|
63 |
+
tf.train.Feature(
|
64 |
+
bytes_list=tf.train.BytesList(value=[b'Hello World!'])),
|
65 |
+
'bar':
|
66 |
+
tf.train.Feature(
|
67 |
+
int64_list=tf.train.Int64List(value=[299, 792, 458]))
|
68 |
+
})
|
69 |
+
example = example_builder.example
|
70 |
+
# Use proto text to show how the entire proto would look like.
|
71 |
+
self.assertProtoEquals(
|
72 |
+
"""
|
73 |
+
features: {
|
74 |
+
feature: {
|
75 |
+
key: "foo"
|
76 |
+
value: {
|
77 |
+
bytes_list: {
|
78 |
+
value: "Hello World!"
|
79 |
+
}
|
80 |
+
}
|
81 |
+
}
|
82 |
+
feature: {
|
83 |
+
key: "bar"
|
84 |
+
value: {
|
85 |
+
int64_list: {
|
86 |
+
value: 299
|
87 |
+
value: 792
|
88 |
+
value: 458
|
89 |
+
}
|
90 |
+
}
|
91 |
+
}
|
92 |
+
}""", example)
|
93 |
+
|
94 |
+
@parameterized.named_parameters(
|
95 |
+
('single_bytes', b'Hello World!', b'Hello World!'),
|
96 |
+
('single_string', 'Hello World!', b'Hello World!'))
|
97 |
+
def test_add_single_byte_feature(self, value, expected_value):
|
98 |
+
example_builder = tf_example_builder.TfExampleBuilder()
|
99 |
+
example_builder.add_bytes_feature('foo', value)
|
100 |
+
example = example_builder.example
|
101 |
+
# Use constructor to easily work with test parameters.
|
102 |
+
self.assertProtoEquals(
|
103 |
+
tf.train.Example(
|
104 |
+
features=tf.train.Features(
|
105 |
+
feature={
|
106 |
+
'foo':
|
107 |
+
tf.train.Feature(
|
108 |
+
bytes_list=tf.train.BytesList(
|
109 |
+
value=[expected_value]))
|
110 |
+
})), example)
|
111 |
+
|
112 |
+
@parameterized.named_parameters(
|
113 |
+
('multiple_bytes', [b'Hello World!', b'Good Morning!'
|
114 |
+
], [b'Hello World!', b'Good Morning!']),
|
115 |
+
('multiple_sring', ['Hello World!', 'Good Morning!'
|
116 |
+
], [b'Hello World!', b'Good Morning!']))
|
117 |
+
def test_add_multiple_bytes_feature(self, values, expected_values):
|
118 |
+
example_builder = tf_example_builder.TfExampleBuilder()
|
119 |
+
example_builder.add_bytes_feature('foo', values)
|
120 |
+
example = example_builder.example
|
121 |
+
self.assertProtoEquals(
|
122 |
+
tf.train.Example(
|
123 |
+
features=tf.train.Features(
|
124 |
+
feature={
|
125 |
+
'foo':
|
126 |
+
tf.train.Feature(
|
127 |
+
bytes_list=tf.train.BytesList(
|
128 |
+
value=expected_values))
|
129 |
+
})), example)
|
130 |
+
|
131 |
+
@parameterized.named_parameters(
|
132 |
+
('single_integer', 123, [123]),
|
133 |
+
('multiple_integers', [123, 456, 789], [123, 456, 789]))
|
134 |
+
def test_add_ints_feature(self, value, expected_value):
|
135 |
+
example_builder = tf_example_builder.TfExampleBuilder()
|
136 |
+
example_builder.add_ints_feature('bar', value)
|
137 |
+
example = example_builder.example
|
138 |
+
self.assertProtoEquals(
|
139 |
+
tf.train.Example(
|
140 |
+
features=tf.train.Features(
|
141 |
+
feature={
|
142 |
+
'bar':
|
143 |
+
tf.train.Feature(
|
144 |
+
int64_list=tf.train.Int64List(value=expected_value))
|
145 |
+
})), example)
|
146 |
+
|
147 |
+
@parameterized.named_parameters(
|
148 |
+
('single_float', 3.14, [3.14]),
|
149 |
+
('multiple_floats', [3.14, 1.57, 6.28], [3.14, 1.57, 6.28]))
|
150 |
+
def test_add_floats_feature(self, value, expected_value):
|
151 |
+
example_builder = tf_example_builder.TfExampleBuilder()
|
152 |
+
example_builder.add_floats_feature('baz', value)
|
153 |
+
example = example_builder.example
|
154 |
+
self.assertProtoEquals(
|
155 |
+
tf.train.Example(
|
156 |
+
features=tf.train.Features(
|
157 |
+
feature={
|
158 |
+
'baz':
|
159 |
+
tf.train.Feature(
|
160 |
+
float_list=tf.train.FloatList(value=expected_value))
|
161 |
+
})), example)
|
162 |
+
|
163 |
+
|
164 |
+
if __name__ == '__main__':
|
165 |
+
tf.test.main()
|
modeling/official/core/tf_example_feature_key.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Data classes for tf.Example proto feature keys.
|
16 |
+
|
17 |
+
Feature keys are grouped by feature types. Key names follow conventions in
|
18 |
+
go/tf-example.
|
19 |
+
"""
|
20 |
+
import dataclasses
|
21 |
+
import functools
|
22 |
+
from typing import Optional
|
23 |
+
|
24 |
+
# Disable init function to use the one defined in base class.
|
25 |
+
dataclass = functools.partial(dataclasses.dataclass(init=False))
|
26 |
+
|
27 |
+
|
28 |
+
@dataclass
|
29 |
+
class TfExampleFeatureKeyBase:
|
30 |
+
"""Base dataclass for defining tf.Example proto feature keys.
|
31 |
+
|
32 |
+
This class defines the logic of adding prefix to feature keys. Subclasses
|
33 |
+
will define feature keys for a specific feature type in data fields.
|
34 |
+
|
35 |
+
NOTE: Please follow subclass examples in this module to define feature keys
|
36 |
+
for a new feature type.
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, prefix: Optional[str] = None):
|
40 |
+
"""Instantiates the feature key class.
|
41 |
+
|
42 |
+
Adds a string prefix to all fields of a feature key instance if `prefix` is
|
43 |
+
not None nor empty.
|
44 |
+
|
45 |
+
Example usage:
|
46 |
+
|
47 |
+
>>> test_key = EncodedImageFeatureKey()
|
48 |
+
>>> test_key.encoded
|
49 |
+
image/encoded
|
50 |
+
>>> test_key = EncodedImageFeatureKey('prefix')
|
51 |
+
>>> test_key.encoded
|
52 |
+
prefix/image/encoded
|
53 |
+
|
54 |
+
Args:
|
55 |
+
prefix: A prefix string that will be added before the feature key string
|
56 |
+
with a trailing slash '/'.
|
57 |
+
"""
|
58 |
+
if prefix:
|
59 |
+
for field in dataclasses.fields(self): # pytype: disable=wrong-arg-types # re-none
|
60 |
+
key_name = field.name
|
61 |
+
key_value = getattr(self, key_name)
|
62 |
+
setattr(self, key_name, f'{prefix}/{key_value}')
|
modeling/official/core/tf_example_feature_key_test.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Tests for tf_example_feature_key."""
|
16 |
+
import dataclasses
|
17 |
+
import inspect
|
18 |
+
from absl.testing import absltest
|
19 |
+
from absl.testing import parameterized
|
20 |
+
|
21 |
+
from official.core import tf_example_feature_key
|
22 |
+
|
23 |
+
|
24 |
+
@tf_example_feature_key.dataclass
|
25 |
+
class TestFeatureKey(tf_example_feature_key.TfExampleFeatureKeyBase):
|
26 |
+
test: str = 'foo/bar'
|
27 |
+
|
28 |
+
|
29 |
+
class TfExampleFeatureKeyTest(parameterized.TestCase):
|
30 |
+
|
31 |
+
def test_add_prefix_success(self):
|
32 |
+
test_key = TestFeatureKey('prefix')
|
33 |
+
self.assertEqual(test_key.test, 'prefix/foo/bar')
|
34 |
+
|
35 |
+
@parameterized.parameters(None, '')
|
36 |
+
def test_add_prefix_skip_success(self, prefix):
|
37 |
+
test_key = TestFeatureKey(prefix)
|
38 |
+
self.assertEqual(test_key.test, 'foo/bar')
|
39 |
+
|
40 |
+
def test_all_feature_key_classes_are_valid(self):
|
41 |
+
for _, obj in inspect.getmembers(tf_example_feature_key):
|
42 |
+
if inspect.isclass(obj):
|
43 |
+
self.assertTrue(dataclasses.is_dataclass(obj))
|
44 |
+
self.assertTrue(
|
45 |
+
issubclass(obj, tf_example_feature_key.TfExampleFeatureKeyBase))
|
46 |
+
|
47 |
+
|
48 |
+
if __name__ == '__main__':
|
49 |
+
absltest.main()
|
modeling/official/core/train_lib.py
ADDED
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""TFM common training driver library."""
|
16 |
+
# pytype: disable=attribute-error
|
17 |
+
import os
|
18 |
+
import tempfile
|
19 |
+
from typing import Any, List, Mapping, Optional, Tuple
|
20 |
+
|
21 |
+
# Import libraries
|
22 |
+
|
23 |
+
from absl import logging
|
24 |
+
import orbit
|
25 |
+
import tensorflow as tf, tf_keras
|
26 |
+
|
27 |
+
from official.core import actions
|
28 |
+
from official.core import base_task
|
29 |
+
from official.core import base_trainer
|
30 |
+
from official.core import config_definitions
|
31 |
+
from official.core import train_utils
|
32 |
+
|
33 |
+
maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter
|
34 |
+
|
35 |
+
|
36 |
+
class OrbitExperimentRunner:
|
37 |
+
"""Runs experiment with Orbit training loop.
|
38 |
+
|
39 |
+
The default experiment runner for model garden experiments. User can
|
40 |
+
customize the experiment pipeline by subclassing this class and replacing
|
41 |
+
components or functions.
|
42 |
+
|
43 |
+
For example, an experiment runner with customized checkpoint manager:
|
44 |
+
|
45 |
+
```python
|
46 |
+
class MyExpRunnerWithExporter(OrbitExperimentRunner):
|
47 |
+
def _maybe_build_checkpoint_manager(sefl):
|
48 |
+
# Replaces the default CheckpointManger with a customized one.
|
49 |
+
return MyCheckpointManager(*args)
|
50 |
+
|
51 |
+
# In user code, instead of the orginal
|
52 |
+
# `OrbitExperimentRunner(..).run(mode)`, now user can do:
|
53 |
+
MyExpRunnerWithExporter(**needed_kwargs).run(mode)
|
54 |
+
```
|
55 |
+
|
56 |
+
Similar override can be done to other components.
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __init__(
|
60 |
+
self,
|
61 |
+
distribution_strategy: tf.distribute.Strategy,
|
62 |
+
task: base_task.Task,
|
63 |
+
mode: str,
|
64 |
+
params: config_definitions.ExperimentConfig,
|
65 |
+
model_dir: str,
|
66 |
+
run_post_eval: bool = False,
|
67 |
+
save_summary: bool = True,
|
68 |
+
train_actions: Optional[List[orbit.Action]] = None,
|
69 |
+
eval_actions: Optional[List[orbit.Action]] = None,
|
70 |
+
trainer: Optional[base_trainer.Trainer] = None,
|
71 |
+
controller_cls=orbit.Controller,
|
72 |
+
summary_manager: Optional[orbit.utils.SummaryManager] = None,
|
73 |
+
eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
|
74 |
+
enable_async_checkpointing: bool = False,
|
75 |
+
):
|
76 |
+
"""Constructor.
|
77 |
+
|
78 |
+
Args:
|
79 |
+
distribution_strategy: A distribution strategy.
|
80 |
+
task: A Task instance.
|
81 |
+
mode: A 'str', specifying the mode. Can be 'train', 'eval',
|
82 |
+
'train_and_eval' or 'continuous_eval'.
|
83 |
+
params: ExperimentConfig instance.
|
84 |
+
model_dir: A 'str', a path to store model checkpoints and summaries.
|
85 |
+
run_post_eval: Whether to run post eval once after training, metrics logs
|
86 |
+
are returned.
|
87 |
+
save_summary: Whether to save train and validation summary.
|
88 |
+
train_actions: Optional list of Orbit train actions.
|
89 |
+
eval_actions: Optional list of Orbit eval actions.
|
90 |
+
trainer: the base_trainer.Trainer instance. It should be created within
|
91 |
+
the strategy.scope().
|
92 |
+
controller_cls: The controller class to manage the train and eval process.
|
93 |
+
Must be a orbit.Controller subclass.
|
94 |
+
summary_manager: Instance of the summary manager to override default
|
95 |
+
summary manager.
|
96 |
+
eval_summary_manager: Instance of the eval summary manager to override
|
97 |
+
default eval summary manager.
|
98 |
+
enable_async_checkpointing: Optional boolean indicating whether to enable
|
99 |
+
async checkpoint saving.
|
100 |
+
"""
|
101 |
+
self.strategy = distribution_strategy or tf.distribute.get_strategy()
|
102 |
+
self._params = params
|
103 |
+
self._model_dir = model_dir
|
104 |
+
self._mode = mode
|
105 |
+
self._run_post_eval = run_post_eval
|
106 |
+
|
107 |
+
self._trainer = trainer or self._build_trainer(
|
108 |
+
task,
|
109 |
+
train='train' in mode,
|
110 |
+
evaluate=('eval' in mode) or run_post_eval)
|
111 |
+
assert self.trainer is not None
|
112 |
+
self._checkpoint_manager = self._maybe_build_checkpoint_manager()
|
113 |
+
self._summary_manager = summary_manager
|
114 |
+
self._eval_summary_manager = eval_summary_manager
|
115 |
+
self._controller = self._build_controller(
|
116 |
+
trainer=self.trainer if 'train' in mode else None,
|
117 |
+
evaluator=self.trainer,
|
118 |
+
save_summary=save_summary,
|
119 |
+
train_actions=train_actions,
|
120 |
+
eval_actions=eval_actions,
|
121 |
+
controller_cls=controller_cls,
|
122 |
+
enable_async_checkpointing=enable_async_checkpointing)
|
123 |
+
|
124 |
+
@property
|
125 |
+
def params(self) -> config_definitions.ExperimentConfig:
|
126 |
+
"""The whole experiment parameters object."""
|
127 |
+
return self._params
|
128 |
+
|
129 |
+
@property
|
130 |
+
def model_dir(self) -> str:
|
131 |
+
"""Path to the model folder, which stores checkpoints, params, log, etc."""
|
132 |
+
return self._model_dir
|
133 |
+
|
134 |
+
@property
|
135 |
+
def trainer(self) -> base_trainer.Trainer:
|
136 |
+
"""The underlying Orbit Trainer object."""
|
137 |
+
return self._trainer
|
138 |
+
|
139 |
+
@property
|
140 |
+
def checkpoint_manager(self) -> Optional[tf.train.CheckpointManager]:
|
141 |
+
"""The CheckpointManager that stores the checkpoints in a train job."""
|
142 |
+
return self._checkpoint_manager
|
143 |
+
|
144 |
+
@property
|
145 |
+
def controller(self) -> orbit.Controller:
|
146 |
+
"""The Orbit controller object."""
|
147 |
+
return self._controller
|
148 |
+
|
149 |
+
def _build_trainer(self, task: base_task.Task, train: bool,
|
150 |
+
evaluate: bool) -> base_trainer.Trainer:
|
151 |
+
"""Create trainer."""
|
152 |
+
with self.strategy.scope():
|
153 |
+
trainer = train_utils.create_trainer(
|
154 |
+
self.params,
|
155 |
+
task,
|
156 |
+
train=train,
|
157 |
+
evaluate=evaluate,
|
158 |
+
checkpoint_exporter=self._build_best_checkpoint_exporter())
|
159 |
+
return trainer
|
160 |
+
|
161 |
+
def _build_best_checkpoint_exporter(self):
|
162 |
+
return maybe_create_best_ckpt_exporter(self.params, self.model_dir)
|
163 |
+
|
164 |
+
def _maybe_build_checkpoint_manager(
|
165 |
+
self) -> Optional[tf.train.CheckpointManager]:
|
166 |
+
"""Maybe create a CheckpointManager."""
|
167 |
+
assert self.trainer is not None
|
168 |
+
if self.trainer.checkpoint:
|
169 |
+
if self.model_dir is None:
|
170 |
+
raise ValueError('model_dir must be specified, but got None')
|
171 |
+
|
172 |
+
if (not self.strategy) or self.strategy.extended.should_checkpoint:
|
173 |
+
ckpt_path = self.model_dir
|
174 |
+
max_to_keep = self.params.trainer.max_to_keep
|
175 |
+
else:
|
176 |
+
# In multi worker training we need every worker to save checkpoint,
|
177 |
+
# because variables can trigger synchronization on read and
|
178 |
+
# synchronization needs all workers to participate. To avoid workers
|
179 |
+
# overriding each other we save to a temporary directory on non-chief
|
180 |
+
# workers.
|
181 |
+
ckpt_path = tempfile.mkdtemp()
|
182 |
+
max_to_keep = 1
|
183 |
+
|
184 |
+
checkpoint_manager = tf.train.CheckpointManager(
|
185 |
+
self.trainer.checkpoint,
|
186 |
+
directory=ckpt_path,
|
187 |
+
max_to_keep=max_to_keep,
|
188 |
+
step_counter=self.trainer.global_step,
|
189 |
+
checkpoint_interval=self.params.trainer.checkpoint_interval,
|
190 |
+
init_fn=self.trainer.initialize)
|
191 |
+
else:
|
192 |
+
checkpoint_manager = None
|
193 |
+
return checkpoint_manager
|
194 |
+
|
195 |
+
def _build_controller(
|
196 |
+
self,
|
197 |
+
trainer,
|
198 |
+
evaluator,
|
199 |
+
save_summary: bool = True,
|
200 |
+
train_actions: Optional[List[orbit.Action]] = None,
|
201 |
+
eval_actions: Optional[List[orbit.Action]] = None,
|
202 |
+
controller_cls=orbit.Controller,
|
203 |
+
enable_async_checkpointing: bool = False,
|
204 |
+
) -> orbit.Controller:
|
205 |
+
"""Builds a Orbit controler."""
|
206 |
+
train_actions = [] if not train_actions else train_actions
|
207 |
+
if trainer:
|
208 |
+
checkpoint_manager = self.checkpoint_manager
|
209 |
+
assert checkpoint_manager, 'Checkpoint manager required but undefined.'
|
210 |
+
train_actions += actions.get_train_actions(
|
211 |
+
self.params,
|
212 |
+
trainer,
|
213 |
+
self.model_dir,
|
214 |
+
checkpoint_manager=checkpoint_manager,
|
215 |
+
)
|
216 |
+
|
217 |
+
eval_actions = [] if not eval_actions else eval_actions
|
218 |
+
if evaluator:
|
219 |
+
eval_actions += actions.get_eval_actions(self.params, evaluator,
|
220 |
+
self.model_dir)
|
221 |
+
|
222 |
+
if save_summary:
|
223 |
+
eval_summary_dir = os.path.join(
|
224 |
+
self.model_dir, self.params.trainer.validation_summary_subdir
|
225 |
+
)
|
226 |
+
else:
|
227 |
+
eval_summary_dir = None
|
228 |
+
|
229 |
+
controller = controller_cls(
|
230 |
+
strategy=self.strategy,
|
231 |
+
trainer=trainer,
|
232 |
+
evaluator=evaluator,
|
233 |
+
global_step=self.trainer.global_step,
|
234 |
+
steps_per_loop=self.params.trainer.steps_per_loop,
|
235 |
+
checkpoint_manager=self.checkpoint_manager,
|
236 |
+
enable_async_checkpointing=enable_async_checkpointing,
|
237 |
+
summary_dir=os.path.join(self.model_dir, 'train')
|
238 |
+
if (save_summary)
|
239 |
+
else None,
|
240 |
+
eval_summary_dir=eval_summary_dir,
|
241 |
+
summary_interval=self.params.trainer.summary_interval
|
242 |
+
if (save_summary)
|
243 |
+
else None,
|
244 |
+
train_actions=train_actions,
|
245 |
+
eval_actions=eval_actions,
|
246 |
+
summary_manager=self._summary_manager
|
247 |
+
if hasattr(self, '_summary_manager')
|
248 |
+
else None,
|
249 |
+
eval_summary_manager=self._eval_summary_manager
|
250 |
+
if hasattr(self, '_eval_summary_manager')
|
251 |
+
else None,
|
252 |
+
)
|
253 |
+
return controller
|
254 |
+
|
255 |
+
def run(self) -> Tuple[tf_keras.Model, Mapping[str, Any]]:
|
256 |
+
"""Run experiments by mode.
|
257 |
+
|
258 |
+
Returns:
|
259 |
+
A 2-tuple of (model, eval_logs).
|
260 |
+
model: `tf_keras.Model` instance.
|
261 |
+
eval_logs: returns eval metrics logs when run_post_eval is set to True,
|
262 |
+
otherwise, returns {}.
|
263 |
+
"""
|
264 |
+
mode = self._mode
|
265 |
+
params = self.params
|
266 |
+
logging.info('Starts to execute mode: %s', mode)
|
267 |
+
with self.strategy.scope():
|
268 |
+
if mode == 'train' or mode == 'train_and_post_eval':
|
269 |
+
self.controller.train(steps=params.trainer.train_steps)
|
270 |
+
elif mode == 'train_and_eval':
|
271 |
+
self.controller.train_and_evaluate(
|
272 |
+
train_steps=params.trainer.train_steps,
|
273 |
+
eval_steps=params.trainer.validation_steps,
|
274 |
+
eval_interval=params.trainer.validation_interval)
|
275 |
+
elif mode == 'eval':
|
276 |
+
self.controller.evaluate(steps=params.trainer.validation_steps)
|
277 |
+
elif mode == 'continuous_eval':
|
278 |
+
|
279 |
+
def timeout_fn():
|
280 |
+
if self.trainer.global_step.numpy() >= params.trainer.train_steps:
|
281 |
+
return True
|
282 |
+
return False
|
283 |
+
|
284 |
+
self.controller.evaluate_continuously(
|
285 |
+
steps=params.trainer.validation_steps,
|
286 |
+
timeout=params.trainer.continuous_eval_timeout,
|
287 |
+
timeout_fn=timeout_fn)
|
288 |
+
else:
|
289 |
+
raise NotImplementedError('The mode is not implemented: %s' % mode)
|
290 |
+
|
291 |
+
num_params = train_utils.try_count_params(self.trainer.model)
|
292 |
+
if num_params is not None:
|
293 |
+
logging.info('Number of trainable params in model: %f Millions.',
|
294 |
+
num_params / 10.**6)
|
295 |
+
|
296 |
+
flops = train_utils.try_count_flops(self.trainer.model)
|
297 |
+
if flops is not None:
|
298 |
+
logging.info('FLOPs (multi-adds) in model: %f Billions.',
|
299 |
+
flops / 10.**9 / 2)
|
300 |
+
|
301 |
+
if self._run_post_eval or mode == 'train_and_post_eval':
|
302 |
+
with self.strategy.scope():
|
303 |
+
return self.trainer.model, self.controller.evaluate(
|
304 |
+
steps=params.trainer.validation_steps)
|
305 |
+
else:
|
306 |
+
return self.trainer.model, {}
|
307 |
+
|
308 |
+
|
309 |
+
def run_experiment(
|
310 |
+
distribution_strategy: tf.distribute.Strategy,
|
311 |
+
task: base_task.Task,
|
312 |
+
mode: str,
|
313 |
+
params: config_definitions.ExperimentConfig,
|
314 |
+
model_dir: str,
|
315 |
+
run_post_eval: bool = False,
|
316 |
+
save_summary: bool = True,
|
317 |
+
train_actions: Optional[List[orbit.Action]] = None,
|
318 |
+
eval_actions: Optional[List[orbit.Action]] = None,
|
319 |
+
trainer: Optional[base_trainer.Trainer] = None,
|
320 |
+
controller_cls=orbit.Controller,
|
321 |
+
summary_manager: Optional[orbit.utils.SummaryManager] = None,
|
322 |
+
eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
|
323 |
+
enable_async_checkpointing: bool = False,
|
324 |
+
) -> Tuple[tf_keras.Model, Mapping[str, Any]]:
|
325 |
+
"""Runs train/eval configured by the experiment params.
|
326 |
+
|
327 |
+
Args:
|
328 |
+
distribution_strategy: A distribution distribution_strategy.
|
329 |
+
task: A Task instance.
|
330 |
+
mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
|
331 |
+
or 'continuous_eval'.
|
332 |
+
params: ExperimentConfig instance.
|
333 |
+
model_dir: A 'str', a path to store model checkpoints and summaries.
|
334 |
+
run_post_eval: Whether to run post eval once after training, metrics logs
|
335 |
+
are returned.
|
336 |
+
save_summary: Whether to save train and validation summary.
|
337 |
+
train_actions: Optional list of Orbit train actions.
|
338 |
+
eval_actions: Optional list of Orbit eval actions.
|
339 |
+
trainer: the base_trainer.Trainer instance. It should be created within the
|
340 |
+
strategy.scope().
|
341 |
+
controller_cls: The controller class to manage the train and eval process.
|
342 |
+
Must be a orbit.Controller subclass.
|
343 |
+
summary_manager: Instance of the summary manager to override default summary
|
344 |
+
manager.
|
345 |
+
eval_summary_manager: Instance of the eval summary manager to override
|
346 |
+
default eval summary manager.
|
347 |
+
enable_async_checkpointing: Optional boolean indicating whether to enable
|
348 |
+
async checkpoint saving.
|
349 |
+
|
350 |
+
Returns:
|
351 |
+
A 2-tuple of (model, eval_logs).
|
352 |
+
model: `tf_keras.Model` instance.
|
353 |
+
eval_logs: returns eval metrics logs when run_post_eval is set to True,
|
354 |
+
otherwise, returns {}.
|
355 |
+
"""
|
356 |
+
runner = OrbitExperimentRunner(
|
357 |
+
distribution_strategy=distribution_strategy,
|
358 |
+
task=task,
|
359 |
+
mode=mode,
|
360 |
+
params=params,
|
361 |
+
model_dir=model_dir,
|
362 |
+
run_post_eval=run_post_eval,
|
363 |
+
save_summary=save_summary,
|
364 |
+
train_actions=train_actions,
|
365 |
+
eval_actions=eval_actions,
|
366 |
+
trainer=trainer,
|
367 |
+
controller_cls=controller_cls,
|
368 |
+
summary_manager=summary_manager,
|
369 |
+
eval_summary_manager=eval_summary_manager,
|
370 |
+
enable_async_checkpointing=enable_async_checkpointing,
|
371 |
+
)
|
372 |
+
return runner.run()
|
modeling/official/core/train_lib_test.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Tests for train_ctl_lib."""
|
16 |
+
import json
|
17 |
+
import os
|
18 |
+
|
19 |
+
from absl import flags
|
20 |
+
from absl.testing import flagsaver
|
21 |
+
from absl.testing import parameterized
|
22 |
+
import numpy as np
|
23 |
+
import tensorflow as tf, tf_keras
|
24 |
+
|
25 |
+
from tensorflow.python.distribute import combinations
|
26 |
+
from tensorflow.python.distribute import strategy_combinations
|
27 |
+
from official.common import flags as tfm_flags
|
28 |
+
# pylint: disable=unused-import
|
29 |
+
from official.common import registry_imports
|
30 |
+
# pylint: enable=unused-import
|
31 |
+
from official.core import task_factory
|
32 |
+
from official.core import train_lib
|
33 |
+
from official.core import train_utils
|
34 |
+
from official.utils.testing import mock_task
|
35 |
+
|
36 |
+
FLAGS = flags.FLAGS
|
37 |
+
|
38 |
+
tfm_flags.define_flags()
|
39 |
+
|
40 |
+
|
41 |
+
class TrainTest(tf.test.TestCase, parameterized.TestCase):
|
42 |
+
|
43 |
+
def setUp(self):
|
44 |
+
super(TrainTest, self).setUp()
|
45 |
+
self._test_config = {
|
46 |
+
'trainer': {
|
47 |
+
'checkpoint_interval': 10,
|
48 |
+
'steps_per_loop': 10,
|
49 |
+
'summary_interval': 10,
|
50 |
+
'train_steps': 10,
|
51 |
+
'validation_steps': 5,
|
52 |
+
'validation_interval': 10,
|
53 |
+
'continuous_eval_timeout': 1,
|
54 |
+
'validation_summary_subdir': 'validation',
|
55 |
+
'optimizer_config': {
|
56 |
+
'optimizer': {
|
57 |
+
'type': 'sgd',
|
58 |
+
},
|
59 |
+
'learning_rate': {
|
60 |
+
'type': 'constant'
|
61 |
+
}
|
62 |
+
}
|
63 |
+
},
|
64 |
+
}
|
65 |
+
|
66 |
+
@combinations.generate(
|
67 |
+
combinations.combine(
|
68 |
+
distribution_strategy=[
|
69 |
+
strategy_combinations.default_strategy,
|
70 |
+
strategy_combinations.cloud_tpu_strategy,
|
71 |
+
strategy_combinations.one_device_strategy_gpu,
|
72 |
+
],
|
73 |
+
flag_mode=['train', 'eval', 'train_and_eval'],
|
74 |
+
run_post_eval=[True, False]))
|
75 |
+
def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
|
76 |
+
model_dir = self.get_temp_dir()
|
77 |
+
flags_dict = dict(
|
78 |
+
experiment='mock',
|
79 |
+
mode=flag_mode,
|
80 |
+
model_dir=model_dir,
|
81 |
+
params_override=json.dumps(self._test_config))
|
82 |
+
with flagsaver.flagsaver(**flags_dict):
|
83 |
+
params = train_utils.parse_configuration(flags.FLAGS)
|
84 |
+
train_utils.serialize_config(params, model_dir)
|
85 |
+
with distribution_strategy.scope():
|
86 |
+
task = task_factory.get_task(params.task, logging_dir=model_dir)
|
87 |
+
|
88 |
+
_, logs = train_lib.run_experiment(
|
89 |
+
distribution_strategy=distribution_strategy,
|
90 |
+
task=task,
|
91 |
+
mode=flag_mode,
|
92 |
+
params=params,
|
93 |
+
model_dir=model_dir,
|
94 |
+
run_post_eval=run_post_eval)
|
95 |
+
|
96 |
+
if 'eval' in flag_mode:
|
97 |
+
self.assertTrue(
|
98 |
+
tf.io.gfile.exists(
|
99 |
+
os.path.join(model_dir,
|
100 |
+
params.trainer.validation_summary_subdir)))
|
101 |
+
if run_post_eval:
|
102 |
+
self.assertNotEmpty(logs)
|
103 |
+
else:
|
104 |
+
self.assertEmpty(logs)
|
105 |
+
self.assertNotEmpty(
|
106 |
+
tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
|
107 |
+
if flag_mode == 'eval':
|
108 |
+
return
|
109 |
+
self.assertNotEmpty(
|
110 |
+
tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
|
111 |
+
# Tests continuous evaluation.
|
112 |
+
_, logs = train_lib.run_experiment(
|
113 |
+
distribution_strategy=distribution_strategy,
|
114 |
+
task=task,
|
115 |
+
mode='continuous_eval',
|
116 |
+
params=params,
|
117 |
+
model_dir=model_dir,
|
118 |
+
run_post_eval=run_post_eval)
|
119 |
+
|
120 |
+
@combinations.generate(
|
121 |
+
combinations.combine(
|
122 |
+
distribution_strategy=[
|
123 |
+
strategy_combinations.default_strategy,
|
124 |
+
strategy_combinations.cloud_tpu_strategy,
|
125 |
+
strategy_combinations.one_device_strategy_gpu,
|
126 |
+
],
|
127 |
+
flag_mode=['train', 'eval', 'train_and_eval'],
|
128 |
+
run_post_eval=[True, False]))
|
129 |
+
def test_end_to_end_class(self, distribution_strategy, flag_mode,
|
130 |
+
run_post_eval):
|
131 |
+
model_dir = self.get_temp_dir()
|
132 |
+
flags_dict = dict(
|
133 |
+
experiment='mock',
|
134 |
+
mode=flag_mode,
|
135 |
+
model_dir=model_dir,
|
136 |
+
params_override=json.dumps(self._test_config))
|
137 |
+
with flagsaver.flagsaver(**flags_dict):
|
138 |
+
params = train_utils.parse_configuration(flags.FLAGS)
|
139 |
+
train_utils.serialize_config(params, model_dir)
|
140 |
+
with distribution_strategy.scope():
|
141 |
+
task = task_factory.get_task(params.task, logging_dir=model_dir)
|
142 |
+
|
143 |
+
_, logs = train_lib.OrbitExperimentRunner(
|
144 |
+
distribution_strategy=distribution_strategy,
|
145 |
+
task=task,
|
146 |
+
mode=flag_mode,
|
147 |
+
params=params,
|
148 |
+
model_dir=model_dir,
|
149 |
+
run_post_eval=run_post_eval).run()
|
150 |
+
|
151 |
+
if 'eval' in flag_mode:
|
152 |
+
self.assertTrue(
|
153 |
+
tf.io.gfile.exists(
|
154 |
+
os.path.join(model_dir,
|
155 |
+
params.trainer.validation_summary_subdir)))
|
156 |
+
if run_post_eval:
|
157 |
+
self.assertNotEmpty(logs)
|
158 |
+
else:
|
159 |
+
self.assertEmpty(logs)
|
160 |
+
self.assertNotEmpty(
|
161 |
+
tf.io.gfile.glob(os.path.join(model_dir, 'params.yaml')))
|
162 |
+
if flag_mode == 'eval':
|
163 |
+
return
|
164 |
+
self.assertNotEmpty(
|
165 |
+
tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
|
166 |
+
# Tests continuous evaluation.
|
167 |
+
_, logs = train_lib.OrbitExperimentRunner(
|
168 |
+
distribution_strategy=distribution_strategy,
|
169 |
+
task=task,
|
170 |
+
mode='continuous_eval',
|
171 |
+
params=params,
|
172 |
+
model_dir=model_dir,
|
173 |
+
run_post_eval=run_post_eval).run()
|
174 |
+
|
175 |
+
@combinations.generate(
|
176 |
+
combinations.combine(
|
177 |
+
distribution_strategy=[
|
178 |
+
strategy_combinations.default_strategy,
|
179 |
+
strategy_combinations.cloud_tpu_strategy,
|
180 |
+
strategy_combinations.one_device_strategy_gpu,
|
181 |
+
],
|
182 |
+
flag_mode=['train', 'train_and_eval'],
|
183 |
+
))
|
184 |
+
def test_recovery_nan_error(self, distribution_strategy, flag_mode):
|
185 |
+
model_dir = self.get_temp_dir()
|
186 |
+
flags_dict = dict(
|
187 |
+
experiment='mock',
|
188 |
+
mode=flag_mode,
|
189 |
+
model_dir=model_dir,
|
190 |
+
params_override=json.dumps(self._test_config))
|
191 |
+
with flagsaver.flagsaver(**flags_dict):
|
192 |
+
params = train_utils.parse_configuration(flags.FLAGS)
|
193 |
+
train_utils.serialize_config(params, model_dir)
|
194 |
+
with distribution_strategy.scope():
|
195 |
+
# task = task_factory.get_task(params.task, logging_dir=model_dir)
|
196 |
+
task = mock_task.MockTask(params.task, logging_dir=model_dir)
|
197 |
+
|
198 |
+
# Set the loss to NaN to trigger RunTimeError.
|
199 |
+
def build_losses(labels, model_outputs, aux_losses=None):
|
200 |
+
del labels, model_outputs
|
201 |
+
return tf.constant([np.nan], tf.float32) + aux_losses
|
202 |
+
|
203 |
+
task.build_losses = build_losses
|
204 |
+
|
205 |
+
with self.assertRaises(RuntimeError):
|
206 |
+
train_lib.OrbitExperimentRunner(
|
207 |
+
distribution_strategy=distribution_strategy,
|
208 |
+
task=task,
|
209 |
+
mode=flag_mode,
|
210 |
+
params=params,
|
211 |
+
model_dir=model_dir).run()
|
212 |
+
|
213 |
+
@combinations.generate(
|
214 |
+
combinations.combine(
|
215 |
+
distribution_strategy=[
|
216 |
+
strategy_combinations.default_strategy,
|
217 |
+
strategy_combinations.cloud_tpu_strategy,
|
218 |
+
strategy_combinations.one_device_strategy_gpu,
|
219 |
+
],
|
220 |
+
flag_mode=['train'],
|
221 |
+
))
|
222 |
+
def test_recovery(self, distribution_strategy, flag_mode):
|
223 |
+
loss_threshold = 1.0
|
224 |
+
model_dir = self.get_temp_dir()
|
225 |
+
flags_dict = dict(
|
226 |
+
experiment='mock',
|
227 |
+
mode=flag_mode,
|
228 |
+
model_dir=model_dir,
|
229 |
+
params_override=json.dumps(self._test_config))
|
230 |
+
with flagsaver.flagsaver(**flags_dict):
|
231 |
+
params = train_utils.parse_configuration(flags.FLAGS)
|
232 |
+
params.trainer.loss_upper_bound = loss_threshold
|
233 |
+
params.trainer.recovery_max_trials = 1
|
234 |
+
train_utils.serialize_config(params, model_dir)
|
235 |
+
with distribution_strategy.scope():
|
236 |
+
task = task_factory.get_task(params.task, logging_dir=model_dir)
|
237 |
+
|
238 |
+
# Saves a checkpoint for reference.
|
239 |
+
model = task.build_model()
|
240 |
+
checkpoint = tf.train.Checkpoint(model=model)
|
241 |
+
checkpoint_manager = tf.train.CheckpointManager(
|
242 |
+
checkpoint, self.get_temp_dir(), max_to_keep=2)
|
243 |
+
checkpoint_manager.save()
|
244 |
+
before_weights = model.get_weights()
|
245 |
+
|
246 |
+
def build_losses(labels, model_outputs, aux_losses=None):
|
247 |
+
del labels, model_outputs
|
248 |
+
return tf.constant([loss_threshold], tf.float32) + aux_losses
|
249 |
+
|
250 |
+
task.build_losses = build_losses
|
251 |
+
|
252 |
+
model, _ = train_lib.OrbitExperimentRunner(
|
253 |
+
distribution_strategy=distribution_strategy,
|
254 |
+
task=task,
|
255 |
+
mode=flag_mode,
|
256 |
+
params=params,
|
257 |
+
model_dir=model_dir).run()
|
258 |
+
after_weights = model.get_weights()
|
259 |
+
for left, right in zip(before_weights, after_weights):
|
260 |
+
self.assertAllEqual(left, right)
|
261 |
+
|
262 |
+
def test_parse_configuration(self):
|
263 |
+
model_dir = self.get_temp_dir()
|
264 |
+
flags_dict = dict(
|
265 |
+
experiment='mock',
|
266 |
+
mode='train',
|
267 |
+
model_dir=model_dir,
|
268 |
+
params_override=json.dumps(self._test_config))
|
269 |
+
with flagsaver.flagsaver(**flags_dict):
|
270 |
+
params = train_utils.parse_configuration(flags.FLAGS, lock_return=True)
|
271 |
+
with self.assertRaises(ValueError):
|
272 |
+
params.override({'task': {'init_checkpoint': 'Foo'}})
|
273 |
+
|
274 |
+
params = train_utils.parse_configuration(flags.FLAGS, lock_return=False)
|
275 |
+
params.override({'task': {'init_checkpoint': 'Bar'}})
|
276 |
+
self.assertEqual(params.task.init_checkpoint, 'Bar')
|
277 |
+
|
278 |
+
|
279 |
+
if __name__ == '__main__':
|
280 |
+
tf.test.main()
|
modeling/official/core/train_utils.py
ADDED
@@ -0,0 +1,610 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Training utils."""
|
16 |
+
|
17 |
+
import dataclasses
|
18 |
+
import inspect
|
19 |
+
import json
|
20 |
+
import os
|
21 |
+
import pprint
|
22 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
23 |
+
|
24 |
+
from absl import logging
|
25 |
+
import gin
|
26 |
+
import numpy as np
|
27 |
+
import orbit
|
28 |
+
import tensorflow as tf, tf_keras
|
29 |
+
|
30 |
+
# pylint: disable=g-direct-tensorflow-import
|
31 |
+
from tensorflow.python.framework import ops
|
32 |
+
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
|
33 |
+
# pylint: enable=g-direct-tensorflow-import
|
34 |
+
from official.core import base_task
|
35 |
+
from official.core import base_trainer
|
36 |
+
from official.core import config_definitions
|
37 |
+
from official.core import exp_factory
|
38 |
+
from official.modeling import hyperparams
|
39 |
+
|
40 |
+
|
41 |
+
BEST_CHECKPOINT_NAME = 'best_ckpt'
|
42 |
+
|
43 |
+
|
44 |
+
def get_leaf_nested_dict(d: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
|
45 |
+
"""Get leaf from a dictionary with arbitrary depth with a list of keys.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
d: The dictionary to extract value from.
|
49 |
+
keys: The list of keys to extract values recursively.
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
The value of the leaf.
|
53 |
+
|
54 |
+
Raises:
|
55 |
+
KeyError: If the value of keys extracted is a dictionary.
|
56 |
+
"""
|
57 |
+
leaf = d
|
58 |
+
for k in keys:
|
59 |
+
if not isinstance(leaf, dict) or k not in leaf:
|
60 |
+
raise KeyError(
|
61 |
+
'Path not exist while traversing the dictionary: d with keys'
|
62 |
+
': %s.' % keys)
|
63 |
+
leaf = leaf[k]
|
64 |
+
|
65 |
+
if isinstance(leaf, dict):
|
66 |
+
raise KeyError('The value extracted with keys: %s is not a leaf of the '
|
67 |
+
'dictionary: %s.' % (keys, d))
|
68 |
+
return leaf
|
69 |
+
|
70 |
+
|
71 |
+
def cast_leaf_nested_dict(d: Dict[str, Any],
|
72 |
+
cast_fn: Callable[[Any], Any]) -> Dict[str, Any]:
|
73 |
+
"""Cast the leaves of a dictionary with arbitrary depth in place.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
d: The dictionary to extract value from.
|
77 |
+
cast_fn: The casting function.
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
A dictionray with the same structure as d.
|
81 |
+
"""
|
82 |
+
for key, value in d.items():
|
83 |
+
if isinstance(value, dict):
|
84 |
+
d[key] = cast_leaf_nested_dict(value, cast_fn)
|
85 |
+
else:
|
86 |
+
d[key] = cast_fn(value)
|
87 |
+
return d
|
88 |
+
|
89 |
+
|
90 |
+
def _filter_leaf_nested_dict(
|
91 |
+
d: Dict[str, Any], predicate: Callable[[Any], bool]
|
92 |
+
) -> Dict[str, Any]:
|
93 |
+
"""Filters the leaves of a dictionary with arbitrary depth in place.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
d: The dictionary to extract value from.
|
97 |
+
predicate: A function that will be called on every leave item. When the
|
98 |
+
function returns True the leave will be kept. Otherwise the leave will be
|
99 |
+
dropped.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
A new dictionray with filtered result.
|
103 |
+
"""
|
104 |
+
result = {}
|
105 |
+
for key, value in d.items():
|
106 |
+
if isinstance(value, dict):
|
107 |
+
result[key] = _filter_leaf_nested_dict(value, predicate)
|
108 |
+
elif predicate(value):
|
109 |
+
result[key] = value
|
110 |
+
return result
|
111 |
+
|
112 |
+
|
113 |
+
def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
|
114 |
+
data_dir: str) -> Any:
|
115 |
+
"""Maybe create a BestCheckpointExporter object, according to the config."""
|
116 |
+
export_subdir = params.trainer.best_checkpoint_export_subdir
|
117 |
+
metric_name = params.trainer.best_checkpoint_eval_metric
|
118 |
+
metric_comp = params.trainer.best_checkpoint_metric_comp
|
119 |
+
if data_dir and export_subdir and metric_name:
|
120 |
+
best_ckpt_dir = os.path.join(data_dir, export_subdir)
|
121 |
+
best_ckpt_exporter = BestCheckpointExporter(best_ckpt_dir, metric_name,
|
122 |
+
metric_comp)
|
123 |
+
logging.info(
|
124 |
+
'Created the best checkpoint exporter. '
|
125 |
+
'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
|
126 |
+
export_subdir, metric_name)
|
127 |
+
else:
|
128 |
+
best_ckpt_exporter = None
|
129 |
+
|
130 |
+
return best_ckpt_exporter
|
131 |
+
|
132 |
+
|
133 |
+
class BestCheckpointExporter:
|
134 |
+
"""Keeps track of the best result, and saves its checkpoint.
|
135 |
+
|
136 |
+
Orbit will support an API for checkpoint exporter. This class will be used
|
137 |
+
together with orbit once this functionality is ready.
|
138 |
+
"""
|
139 |
+
|
140 |
+
def __init__(self, export_dir: str, metric_name: str, metric_comp: str):
|
141 |
+
"""Initialization.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
export_dir: The directory that will contain exported checkpoints.
|
145 |
+
metric_name: Indicates which metric to look at, when determining which
|
146 |
+
result is better. If eval_logs being passed to maybe_export_checkpoint
|
147 |
+
is a nested dictionary, use `|` as a seperator for different layers.
|
148 |
+
metric_comp: Indicates how to compare results. Either `lower` or `higher`.
|
149 |
+
"""
|
150 |
+
self._export_dir = export_dir
|
151 |
+
self._metric_name = metric_name.split('|')
|
152 |
+
self._metric_comp = metric_comp
|
153 |
+
if self._metric_comp not in ('lower', 'higher'):
|
154 |
+
raise ValueError('best checkpoint metric comp must be one of '
|
155 |
+
'higher, lower. Got: {}'.format(self._metric_comp))
|
156 |
+
tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
|
157 |
+
self._best_ckpt_logs = self._maybe_load_best_eval_metric()
|
158 |
+
self._checkpoint_manager = None
|
159 |
+
|
160 |
+
def _get_checkpoint_manager(self, checkpoint):
|
161 |
+
"""Gets an existing checkpoint manager or creates a new one."""
|
162 |
+
if self._checkpoint_manager is None or (self._checkpoint_manager.checkpoint
|
163 |
+
!= checkpoint):
|
164 |
+
logging.info('Creates a new checkpoint manager.')
|
165 |
+
self._checkpoint_manager = tf.train.CheckpointManager(
|
166 |
+
checkpoint,
|
167 |
+
directory=self._export_dir,
|
168 |
+
max_to_keep=1,
|
169 |
+
checkpoint_name=BEST_CHECKPOINT_NAME)
|
170 |
+
|
171 |
+
return self._checkpoint_manager
|
172 |
+
|
173 |
+
def maybe_export_checkpoint(
|
174 |
+
self, checkpoint, eval_logs, global_step, write_logs=True) -> bool:
|
175 |
+
"""Compare eval_logs with past eval_logs and export checkpoint if better."""
|
176 |
+
logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
|
177 |
+
eval_logs, global_step)
|
178 |
+
if self._best_ckpt_logs is None or self._new_metric_is_better(
|
179 |
+
self._best_ckpt_logs, eval_logs):
|
180 |
+
self._best_ckpt_logs = eval_logs
|
181 |
+
if write_logs:
|
182 |
+
self.export_best_eval_metric(self._best_ckpt_logs, global_step)
|
183 |
+
self._get_checkpoint_manager(checkpoint).save()
|
184 |
+
return True
|
185 |
+
return False
|
186 |
+
|
187 |
+
def _maybe_load_best_eval_metric(self):
|
188 |
+
if not tf.io.gfile.exists(self.best_ckpt_logs_path):
|
189 |
+
return None
|
190 |
+
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'r') as reader:
|
191 |
+
return json.loads(reader.read())
|
192 |
+
|
193 |
+
def _new_metric_is_better(self, old_logs, new_logs):
|
194 |
+
"""Check if the metric in new_logs is better than the metric in old_logs."""
|
195 |
+
old_value = float(
|
196 |
+
orbit.utils.get_value(
|
197 |
+
get_leaf_nested_dict(old_logs, self._metric_name)))
|
198 |
+
new_value = float(
|
199 |
+
orbit.utils.get_value(
|
200 |
+
get_leaf_nested_dict(new_logs, self._metric_name)))
|
201 |
+
|
202 |
+
logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
|
203 |
+
old_value, new_value)
|
204 |
+
if self._metric_comp == 'higher':
|
205 |
+
if new_value > old_value:
|
206 |
+
logging.info('[BestCheckpointExporter] '
|
207 |
+
'the new number is better since it is higher.')
|
208 |
+
return True
|
209 |
+
else: # self._metric_comp == 'lower':
|
210 |
+
if new_value < old_value:
|
211 |
+
logging.info('[BestCheckpointExporter] '
|
212 |
+
'the new number is better since it is lower.')
|
213 |
+
return True
|
214 |
+
return False
|
215 |
+
|
216 |
+
def export_best_eval_metric(self, eval_logs, global_step):
|
217 |
+
"""Export evaluation results of the best checkpoint into a json file."""
|
218 |
+
# eval_log_ext may contains non-scalar tensors, such as image data when
|
219 |
+
# `allow_image_summary` is True. Here we only keep scalar tensors.
|
220 |
+
eval_logs_ext = _filter_leaf_nested_dict(
|
221 |
+
eval_logs, lambda x: tf.rank(x) <= 1
|
222 |
+
)
|
223 |
+
eval_logs_ext['best_ckpt_global_step'] = global_step
|
224 |
+
eval_logs_ext = cast_leaf_nested_dict(
|
225 |
+
eval_logs_ext, lambda x: float(orbit.utils.get_value(x)))
|
226 |
+
# Saving json file is very fast.
|
227 |
+
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
|
228 |
+
writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
|
229 |
+
|
230 |
+
@property
|
231 |
+
def best_ckpt_logs(self):
|
232 |
+
return self._best_ckpt_logs
|
233 |
+
|
234 |
+
@property
|
235 |
+
def best_ckpt_logs_path(self):
|
236 |
+
return os.path.join(self._export_dir, 'info.json')
|
237 |
+
|
238 |
+
@property
|
239 |
+
def best_ckpt_path(self):
|
240 |
+
"""Returns the best ckpt path or None if there is no ckpt yet."""
|
241 |
+
return tf.train.latest_checkpoint(self._export_dir)
|
242 |
+
|
243 |
+
|
244 |
+
def create_optimizer(task: base_task.Task,
|
245 |
+
params: config_definitions.ExperimentConfig
|
246 |
+
) -> tf_keras.optimizers.Optimizer:
|
247 |
+
"""A create optimizer util to be backward compatability with new args."""
|
248 |
+
if 'dp_config' in inspect.signature(task.create_optimizer).parameters:
|
249 |
+
dp_config = None
|
250 |
+
if hasattr(params.task, 'differential_privacy_config'):
|
251 |
+
dp_config = params.task.differential_privacy_config
|
252 |
+
optimizer = task.create_optimizer(
|
253 |
+
params.trainer.optimizer_config, params.runtime,
|
254 |
+
dp_config=dp_config)
|
255 |
+
else:
|
256 |
+
if hasattr(params.task, 'differential_privacy_config'
|
257 |
+
) and params.task.differential_privacy_config is not None:
|
258 |
+
raise ValueError('Differential privacy config is specified but '
|
259 |
+
'task.create_optimizer api does not accept it.')
|
260 |
+
optimizer = task.create_optimizer(
|
261 |
+
params.trainer.optimizer_config,
|
262 |
+
params.runtime)
|
263 |
+
return optimizer
|
264 |
+
|
265 |
+
|
266 |
+
@gin.configurable
|
267 |
+
def create_trainer(params: config_definitions.ExperimentConfig,
|
268 |
+
task: base_task.Task,
|
269 |
+
train: bool,
|
270 |
+
evaluate: bool,
|
271 |
+
checkpoint_exporter: Optional[BestCheckpointExporter] = None,
|
272 |
+
trainer_cls=base_trainer.Trainer) -> base_trainer.Trainer:
|
273 |
+
"""Create trainer."""
|
274 |
+
logging.info('Running default trainer.')
|
275 |
+
model = task.build_model()
|
276 |
+
optimizer = create_optimizer(task, params)
|
277 |
+
return trainer_cls(
|
278 |
+
params,
|
279 |
+
task,
|
280 |
+
model=model,
|
281 |
+
optimizer=optimizer,
|
282 |
+
train=train,
|
283 |
+
evaluate=evaluate,
|
284 |
+
checkpoint_exporter=checkpoint_exporter)
|
285 |
+
|
286 |
+
|
287 |
+
@dataclasses.dataclass
|
288 |
+
class ParseConfigOptions:
|
289 |
+
"""Use this dataclass instead of FLAGS to customize parse_configuration()."""
|
290 |
+
experiment: str
|
291 |
+
config_file: List[str]
|
292 |
+
tpu: str = ''
|
293 |
+
tf_data_service: str = ''
|
294 |
+
params_override: str = ''
|
295 |
+
|
296 |
+
def __contains__(self, name):
|
297 |
+
return name in dataclasses.asdict(self)
|
298 |
+
|
299 |
+
|
300 |
+
class ExperimentParser:
|
301 |
+
"""Constructs the Experiment config from Flags or equivalent object.
|
302 |
+
|
303 |
+
Most of the cases, users only need to call the `parse()` function:
|
304 |
+
```
|
305 |
+
builder = ExperimentParser(FLAGS)
|
306 |
+
params = builder.parse()
|
307 |
+
```
|
308 |
+
|
309 |
+
The advanced users can modify the flow by calling the parse_*() functions
|
310 |
+
separately.
|
311 |
+
"""
|
312 |
+
|
313 |
+
def __init__(self, flags_obj):
|
314 |
+
self._flags_obj = flags_obj
|
315 |
+
|
316 |
+
def parse(self):
|
317 |
+
"""Overrall process of constructing Experiment config."""
|
318 |
+
params = self.base_experiment()
|
319 |
+
params = self.parse_config_file(params)
|
320 |
+
params = self.parse_runtime(params)
|
321 |
+
params = self.parse_data_service(params)
|
322 |
+
params = self.parse_params_override(params)
|
323 |
+
return params
|
324 |
+
|
325 |
+
def base_experiment(self):
|
326 |
+
"""Get the base experiment config from --experiment field."""
|
327 |
+
if self._flags_obj.experiment is None:
|
328 |
+
raise ValueError('The flag --experiment must be specified.')
|
329 |
+
return exp_factory.get_exp_config(self._flags_obj.experiment)
|
330 |
+
|
331 |
+
def parse_config_file(self, params):
|
332 |
+
"""Override the configs of params from the config_file."""
|
333 |
+
for config_file in self._flags_obj.config_file or []:
|
334 |
+
params = hyperparams.override_params_dict(
|
335 |
+
params, config_file, is_strict=True)
|
336 |
+
return params
|
337 |
+
|
338 |
+
def parse_runtime(self, params):
|
339 |
+
"""Override the runtime configs of params from flags."""
|
340 |
+
# Override the TPU address and tf.data service address.
|
341 |
+
params.override({
|
342 |
+
'runtime': {
|
343 |
+
'tpu': self._flags_obj.tpu,
|
344 |
+
},
|
345 |
+
})
|
346 |
+
return params
|
347 |
+
|
348 |
+
def parse_data_service(self, params):
|
349 |
+
"""Override the data service configs of params from flags."""
|
350 |
+
if ('tf_data_service' in self._flags_obj and
|
351 |
+
self._flags_obj.tf_data_service and
|
352 |
+
isinstance(params.task, config_definitions.TaskConfig)):
|
353 |
+
params.override({
|
354 |
+
'task': {
|
355 |
+
'train_data': {
|
356 |
+
'tf_data_service_address': self._flags_obj.tf_data_service,
|
357 |
+
},
|
358 |
+
'validation_data': {
|
359 |
+
'tf_data_service_address': self._flags_obj.tf_data_service,
|
360 |
+
}
|
361 |
+
}
|
362 |
+
})
|
363 |
+
return params
|
364 |
+
|
365 |
+
def parse_params_override(self, params):
|
366 |
+
# Get the second level of override from `--params_override`.
|
367 |
+
# `--params_override` is typically used as a further override over the
|
368 |
+
# template. For example, one may define a particular template for training
|
369 |
+
# ResNet50 on ImageNet in a config file and pass it via `--config_file`,
|
370 |
+
# then define different learning rates and pass it via `--params_override`.
|
371 |
+
if self._flags_obj.params_override:
|
372 |
+
params = hyperparams.override_params_dict(
|
373 |
+
params, self._flags_obj.params_override, is_strict=True)
|
374 |
+
return params
|
375 |
+
|
376 |
+
|
377 |
+
def parse_configuration(flags_obj, lock_return=True, print_return=True):
|
378 |
+
"""Parses ExperimentConfig from flags."""
|
379 |
+
|
380 |
+
params = ExperimentParser(flags_obj).parse()
|
381 |
+
|
382 |
+
params.validate()
|
383 |
+
if lock_return:
|
384 |
+
params.lock()
|
385 |
+
|
386 |
+
if print_return:
|
387 |
+
pp = pprint.PrettyPrinter()
|
388 |
+
logging.info('Final experiment parameters:\n%s',
|
389 |
+
pp.pformat(params.as_dict()))
|
390 |
+
|
391 |
+
return params
|
392 |
+
|
393 |
+
|
394 |
+
def serialize_config(params: config_definitions.ExperimentConfig,
|
395 |
+
model_dir: str):
|
396 |
+
"""Serializes and saves the experiment config."""
|
397 |
+
if model_dir is None:
|
398 |
+
raise ValueError('model_dir must be specified, but got None')
|
399 |
+
params_save_path = os.path.join(model_dir, 'params.yaml')
|
400 |
+
logging.info('Saving experiment configuration to %s', params_save_path)
|
401 |
+
tf.io.gfile.makedirs(model_dir)
|
402 |
+
hyperparams.save_params_dict_to_yaml(params, params_save_path)
|
403 |
+
|
404 |
+
|
405 |
+
def save_gin_config(filename_suffix: str, model_dir: str):
|
406 |
+
"""Serializes and saves the experiment config."""
|
407 |
+
gin_save_path = os.path.join(
|
408 |
+
model_dir, 'operative_config.{}.gin'.format(filename_suffix))
|
409 |
+
logging.info('Saving gin configurations to %s', gin_save_path)
|
410 |
+
tf.io.gfile.makedirs(model_dir)
|
411 |
+
with tf.io.gfile.GFile(gin_save_path, 'w') as f:
|
412 |
+
f.write(gin.operative_config_str())
|
413 |
+
|
414 |
+
|
415 |
+
def read_global_step_from_checkpoint(ckpt_file_path):
|
416 |
+
"""Read global step from checkpoint, or get global step from its filename."""
|
417 |
+
global_step = tf.Variable(-1, dtype=tf.int64)
|
418 |
+
ckpt = tf.train.Checkpoint(global_step=global_step)
|
419 |
+
try:
|
420 |
+
ckpt.restore(ckpt_file_path).expect_partial()
|
421 |
+
global_step_maybe_restored = global_step.numpy()
|
422 |
+
except tf.errors.InvalidArgumentError:
|
423 |
+
global_step_maybe_restored = -1
|
424 |
+
|
425 |
+
if global_step_maybe_restored == -1:
|
426 |
+
raise ValueError('global_step not found in checkpoint {}. '
|
427 |
+
'If you want to run finetune eval jobs, you need to '
|
428 |
+
'make sure that your pretrain model writes '
|
429 |
+
'global_step in its checkpoints.'.format(ckpt_file_path))
|
430 |
+
global_step_restored = global_step.numpy()
|
431 |
+
logging.info('get global_step %d from checkpoint %s', global_step_restored,
|
432 |
+
ckpt_file_path)
|
433 |
+
return global_step_restored
|
434 |
+
|
435 |
+
|
436 |
+
def write_json_summary(log_dir, global_step, eval_metrics):
|
437 |
+
"""Dump evaluation metrics to json file."""
|
438 |
+
serializable_dict = {}
|
439 |
+
for name, value in eval_metrics.items():
|
440 |
+
if hasattr(value, 'numpy'):
|
441 |
+
serializable_dict[name] = str(value.numpy())
|
442 |
+
else:
|
443 |
+
serializable_dict[name] = str(value)
|
444 |
+
output_json = os.path.join(log_dir, 'metrics-{}.json'.format(global_step))
|
445 |
+
logging.info('Evaluation results at pretrain step %d: %s', global_step,
|
446 |
+
serializable_dict)
|
447 |
+
with tf.io.gfile.GFile(output_json, 'w') as writer:
|
448 |
+
writer.write(json.dumps(serializable_dict, indent=4) + '\n')
|
449 |
+
|
450 |
+
|
451 |
+
def write_summary(summary_writer, global_step, eval_metrics):
|
452 |
+
"""Write evaluation metrics to TF summary."""
|
453 |
+
numeric_dict = {}
|
454 |
+
for name, value in eval_metrics.items():
|
455 |
+
numeric_dict[name] = float(orbit.utils.get_value(value))
|
456 |
+
with summary_writer.as_default():
|
457 |
+
for name, value in numeric_dict.items():
|
458 |
+
tf.summary.scalar(name, value, step=global_step)
|
459 |
+
summary_writer.flush()
|
460 |
+
|
461 |
+
|
462 |
+
def remove_ckpts(model_dir):
|
463 |
+
"""Remove model checkpoints, so we can restart."""
|
464 |
+
ckpts = os.path.join(model_dir, 'ckpt-*')
|
465 |
+
logging.info('removing checkpoint files %s', ckpts)
|
466 |
+
for file_to_remove in tf.io.gfile.glob(ckpts):
|
467 |
+
tf.io.gfile.rmtree(file_to_remove)
|
468 |
+
|
469 |
+
file_to_remove = os.path.join(model_dir, 'checkpoint')
|
470 |
+
if tf.io.gfile.exists(file_to_remove):
|
471 |
+
tf.io.gfile.remove(file_to_remove)
|
472 |
+
|
473 |
+
|
474 |
+
def write_model_params(model: Union[tf.Module, tf_keras.Model],
|
475 |
+
output_path: str) -> None:
|
476 |
+
"""Writes the model parameters and shapes to a file.
|
477 |
+
|
478 |
+
Args:
|
479 |
+
model: A model instance.
|
480 |
+
output_path: Output file path.
|
481 |
+
"""
|
482 |
+
with tf.io.gfile.GFile(output_path, 'w') as f:
|
483 |
+
total_params = 0
|
484 |
+
for var in model.variables:
|
485 |
+
shape = tf.shape(var)
|
486 |
+
total_params += tf.math.reduce_prod(shape).numpy()
|
487 |
+
f.write(f'{var.name} {shape.numpy().tolist()}\n')
|
488 |
+
f.write(f'\nTotal params: {total_params}\n')
|
489 |
+
|
490 |
+
|
491 |
+
def try_count_params(
|
492 |
+
model: Union[tf.Module, tf_keras.Model],
|
493 |
+
trainable_only: bool = False):
|
494 |
+
"""Count the number of parameters if model is possible.
|
495 |
+
|
496 |
+
Args:
|
497 |
+
model: Try to count the number of params in this model.
|
498 |
+
trainable_only: Whether to calculate trainable params only. This flag is
|
499 |
+
not used when the model has `count_params` attribute.
|
500 |
+
|
501 |
+
Returns:
|
502 |
+
The number of parameters or None.
|
503 |
+
"""
|
504 |
+
if hasattr(model, 'count_params'):
|
505 |
+
try:
|
506 |
+
return model.count_params()
|
507 |
+
except ValueError:
|
508 |
+
logging.info('Number of trainable params unknown, because the build() '
|
509 |
+
'methods in keras layers were not called. This is probably '
|
510 |
+
'because the model was not feed any input, e.g., the max '
|
511 |
+
'train step already reached before this run.')
|
512 |
+
return None
|
513 |
+
else:
|
514 |
+
total_params = 0
|
515 |
+
variables = model.trainable_variables if trainable_only else model.variables
|
516 |
+
for var in variables:
|
517 |
+
shape = tf.shape(var)
|
518 |
+
total_params += tf.math.reduce_prod(shape).numpy()
|
519 |
+
return total_params
|
520 |
+
|
521 |
+
|
522 |
+
def try_count_flops(model: Union[tf.Module, tf_keras.Model],
|
523 |
+
inputs_kwargs: Optional[Dict[str, Any]] = None,
|
524 |
+
output_path: Optional[str] = None):
|
525 |
+
"""Counts and returns model FLOPs.
|
526 |
+
|
527 |
+
Args:
|
528 |
+
model: A model instance.
|
529 |
+
inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
|
530 |
+
shape specifications to getting corresponding concrete function.
|
531 |
+
output_path: A file path to write the profiling results to.
|
532 |
+
|
533 |
+
Returns:
|
534 |
+
The model's FLOPs.
|
535 |
+
"""
|
536 |
+
if hasattr(model, 'inputs'):
|
537 |
+
try:
|
538 |
+
# Get input shape and set batch size to 1.
|
539 |
+
if model.inputs:
|
540 |
+
inputs = [
|
541 |
+
tf.TensorSpec([1] + input.shape[1:], input.dtype)
|
542 |
+
for input in model.inputs
|
543 |
+
]
|
544 |
+
concrete_func = tf.function(model).get_concrete_function(inputs)
|
545 |
+
# If model.inputs is invalid, try to use the input to get concrete
|
546 |
+
# function for model.call (subclass model).
|
547 |
+
else:
|
548 |
+
concrete_func = tf.function(model.call).get_concrete_function(
|
549 |
+
**inputs_kwargs)
|
550 |
+
frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)
|
551 |
+
|
552 |
+
# Calculate FLOPs.
|
553 |
+
run_meta = tf.compat.v1.RunMetadata()
|
554 |
+
opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
|
555 |
+
if output_path is not None:
|
556 |
+
opts['output'] = f'file:outfile={output_path}'
|
557 |
+
else:
|
558 |
+
opts['output'] = 'none'
|
559 |
+
flops = tf.compat.v1.profiler.profile(
|
560 |
+
graph=frozen_func.graph, run_meta=run_meta, options=opts)
|
561 |
+
return flops.total_float_ops
|
562 |
+
except Exception as e: # pylint: disable=broad-except
|
563 |
+
logging.info(
|
564 |
+
'Failed to count model FLOPs with error %s, because the build() '
|
565 |
+
'methods in keras layers were not called. This is probably because '
|
566 |
+
'the model was not feed any input, e.g., the max train step already '
|
567 |
+
'reached before this run.', e)
|
568 |
+
return None
|
569 |
+
return None
|
570 |
+
|
571 |
+
|
572 |
+
@ops.RegisterStatistics('Einsum', 'flops')
|
573 |
+
def _einsum_flops(graph, node):
|
574 |
+
"""Calculates the compute resources needed for Einsum."""
|
575 |
+
assert len(node.input) == 2
|
576 |
+
x_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
|
577 |
+
graph, node.input[0])
|
578 |
+
y_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
|
579 |
+
graph, node.input[1])
|
580 |
+
x_shape.assert_is_fully_defined()
|
581 |
+
y_shape.assert_is_fully_defined()
|
582 |
+
x_shape = x_shape.as_list()
|
583 |
+
y_shape = y_shape.as_list()
|
584 |
+
equation = str(node.attr['equation'])
|
585 |
+
equation = (
|
586 |
+
equation.replace('s:', '')
|
587 |
+
.replace('"', '')
|
588 |
+
.replace(' ', '')
|
589 |
+
.replace('\n', '')
|
590 |
+
)
|
591 |
+
x_str = equation.split(',')[0]
|
592 |
+
y_r_str = equation.split(',')[1]
|
593 |
+
y_str = y_r_str.split('->')[0]
|
594 |
+
r_str = y_r_str.split('->')[1]
|
595 |
+
shape_dic = {}
|
596 |
+
contracted = set()
|
597 |
+
for indice in x_str + y_str:
|
598 |
+
if indice in x_str:
|
599 |
+
indice_dim = x_shape[x_str.find(indice)]
|
600 |
+
elif indice in y_str:
|
601 |
+
indice_dim = y_shape[y_str.find(indice)]
|
602 |
+
else:
|
603 |
+
raise ValueError('indice {} not found in inputs'.format(indice))
|
604 |
+
shape_dic[indice] = indice_dim
|
605 |
+
if indice not in r_str:
|
606 |
+
contracted.add(indice)
|
607 |
+
madds = np.prod([shape_dic[indice] for indice in r_str]) * (
|
608 |
+
np.prod([shape_dic[indice] for indice in contracted]))
|
609 |
+
flops = 2 * madds
|
610 |
+
return ops.OpStats('flops', flops)
|
modeling/official/core/train_utils_test.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Tests for official.core.train_utils."""
|
16 |
+
import json
|
17 |
+
import os
|
18 |
+
import pprint
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import tensorflow as tf, tf_keras
|
22 |
+
|
23 |
+
from official.core import exp_factory
|
24 |
+
from official.core import test_utils
|
25 |
+
from official.core import train_utils
|
26 |
+
from official.modeling import hyperparams
|
27 |
+
|
28 |
+
|
29 |
+
@exp_factory.register_config_factory('foo')
|
30 |
+
def foo():
|
31 |
+
"""Multitask experiment for test."""
|
32 |
+
experiment_config = hyperparams.Config(
|
33 |
+
default_params={
|
34 |
+
'runtime': {
|
35 |
+
'tpu': 'fake',
|
36 |
+
},
|
37 |
+
'task': {
|
38 |
+
'model': {
|
39 |
+
'model_id': 'bar',
|
40 |
+
},
|
41 |
+
},
|
42 |
+
'trainer': {
|
43 |
+
'train_steps': -1,
|
44 |
+
'validation_steps': -1,
|
45 |
+
},
|
46 |
+
})
|
47 |
+
return experiment_config
|
48 |
+
|
49 |
+
|
50 |
+
class TrainUtilsTest(tf.test.TestCase):
|
51 |
+
|
52 |
+
def test_get_leaf_nested_dict(self):
|
53 |
+
d = {'a': {'i': {'x': 5}}}
|
54 |
+
self.assertEqual(train_utils.get_leaf_nested_dict(d, ['a', 'i', 'x']), 5)
|
55 |
+
|
56 |
+
def test_get_leaf_nested_dict_not_leaf(self):
|
57 |
+
with self.assertRaisesRegex(KeyError, 'The value extracted with keys.*'):
|
58 |
+
d = {'a': {'i': {'x': 5}}}
|
59 |
+
train_utils.get_leaf_nested_dict(d, ['a', 'i'])
|
60 |
+
|
61 |
+
def test_get_leaf_nested_dict_path_not_exist_missing_key(self):
|
62 |
+
with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
|
63 |
+
d = {'a': {'i': {'x': 5}}}
|
64 |
+
train_utils.get_leaf_nested_dict(d, ['a', 'i', 'y'])
|
65 |
+
|
66 |
+
def test_get_leaf_nested_dict_path_not_exist_out_of_range(self):
|
67 |
+
with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
|
68 |
+
d = {'a': {'i': {'x': 5}}}
|
69 |
+
train_utils.get_leaf_nested_dict(d, ['a', 'i', 'z'])
|
70 |
+
|
71 |
+
def test_get_leaf_nested_dict_path_not_exist_meets_leaf(self):
|
72 |
+
with self.assertRaisesRegex(KeyError, 'Path not exist while traversing .*'):
|
73 |
+
d = {'a': {'i': 5}}
|
74 |
+
train_utils.get_leaf_nested_dict(d, ['a', 'i', 'z'])
|
75 |
+
|
76 |
+
def test_cast_leaf_nested_dict(self):
|
77 |
+
d = {'a': {'i': {'x': '123'}}, 'b': 456.5}
|
78 |
+
d = train_utils.cast_leaf_nested_dict(d, int)
|
79 |
+
self.assertEqual(d['a']['i']['x'], 123)
|
80 |
+
self.assertEqual(d['b'], 456)
|
81 |
+
|
82 |
+
def test_write_model_params_keras_model(self):
|
83 |
+
inputs = np.zeros([2, 3])
|
84 |
+
model = test_utils.FakeKerasModel()
|
85 |
+
model(inputs) # Must do forward pass to build the model.
|
86 |
+
|
87 |
+
filepath = os.path.join(self.create_tempdir(), 'model_params.txt')
|
88 |
+
train_utils.write_model_params(model, filepath)
|
89 |
+
actual = tf.io.gfile.GFile(filepath, 'r').read().splitlines()
|
90 |
+
|
91 |
+
expected = [
|
92 |
+
'fake_keras_model/dense/kernel:0 [3, 4]',
|
93 |
+
'fake_keras_model/dense/bias:0 [4]',
|
94 |
+
'fake_keras_model/dense_1/kernel:0 [4, 4]',
|
95 |
+
'fake_keras_model/dense_1/bias:0 [4]',
|
96 |
+
'',
|
97 |
+
'Total params: 36',
|
98 |
+
]
|
99 |
+
self.assertEqual(actual, expected)
|
100 |
+
|
101 |
+
def test_write_model_params_module(self):
|
102 |
+
inputs = np.zeros([2, 3], dtype=np.float32)
|
103 |
+
model = test_utils.FakeModule(3, name='fake_module')
|
104 |
+
model(inputs) # Must do forward pass to build the model.
|
105 |
+
|
106 |
+
filepath = os.path.join(self.create_tempdir(), 'model_params.txt')
|
107 |
+
train_utils.write_model_params(model, filepath)
|
108 |
+
actual = tf.io.gfile.GFile(filepath, 'r').read().splitlines()
|
109 |
+
|
110 |
+
expected = [
|
111 |
+
'fake_module/dense/b:0 [4]',
|
112 |
+
'fake_module/dense/w:0 [3, 4]',
|
113 |
+
'fake_module/dense_1/b:0 [4]',
|
114 |
+
'fake_module/dense_1/w:0 [4, 4]',
|
115 |
+
'',
|
116 |
+
'Total params: 36',
|
117 |
+
]
|
118 |
+
self.assertEqual(actual, expected)
|
119 |
+
|
120 |
+
def test_construct_experiment_from_flags(self):
|
121 |
+
options = train_utils.ParseConfigOptions(
|
122 |
+
experiment='foo',
|
123 |
+
config_file=[],
|
124 |
+
tpu='bar',
|
125 |
+
tf_data_service='',
|
126 |
+
params_override='task.model.model_id=new,'
|
127 |
+
'trainer.train_steps=10,'
|
128 |
+
'trainer.validation_steps=11')
|
129 |
+
builder = train_utils.ExperimentParser(options)
|
130 |
+
params_from_obj = builder.parse()
|
131 |
+
params_from_func = train_utils.parse_configuration(options)
|
132 |
+
pp = pprint.PrettyPrinter()
|
133 |
+
self.assertEqual(
|
134 |
+
pp.pformat(params_from_obj.as_dict()),
|
135 |
+
pp.pformat(params_from_func.as_dict()))
|
136 |
+
self.assertEqual(params_from_obj.runtime.tpu, 'bar')
|
137 |
+
self.assertEqual(params_from_obj.task.model.model_id, 'new')
|
138 |
+
self.assertEqual(params_from_obj.trainer.train_steps, 10)
|
139 |
+
self.assertEqual(params_from_obj.trainer.validation_steps, 11)
|
140 |
+
|
141 |
+
|
142 |
+
class BestCheckpointExporterTest(tf.test.TestCase):
|
143 |
+
|
144 |
+
def test_maybe_export(self):
|
145 |
+
model_dir = self.create_tempdir().full_path
|
146 |
+
best_ckpt_path = os.path.join(model_dir, 'best_ckpt-1')
|
147 |
+
metric_name = 'test_metric|metric_1'
|
148 |
+
exporter = train_utils.BestCheckpointExporter(
|
149 |
+
model_dir, metric_name, 'higher')
|
150 |
+
v = tf.Variable(1.0)
|
151 |
+
checkpoint = tf.train.Checkpoint(v=v)
|
152 |
+
ret = exporter.maybe_export_checkpoint(
|
153 |
+
checkpoint, {'test_metric': {'metric_1': 5.0}}, 100)
|
154 |
+
with self.subTest(name='Successful first save.'):
|
155 |
+
self.assertEqual(ret, True)
|
156 |
+
v_2 = tf.Variable(2.0)
|
157 |
+
checkpoint_2 = tf.train.Checkpoint(v=v_2)
|
158 |
+
checkpoint_2.restore(best_ckpt_path)
|
159 |
+
self.assertEqual(v_2.numpy(), 1.0)
|
160 |
+
|
161 |
+
v = tf.Variable(3.0)
|
162 |
+
checkpoint = tf.train.Checkpoint(v=v)
|
163 |
+
ret = exporter.maybe_export_checkpoint(
|
164 |
+
checkpoint, {'test_metric': {'metric_1': 6.0}}, 200)
|
165 |
+
with self.subTest(name='Successful better metic save.'):
|
166 |
+
self.assertEqual(ret, True)
|
167 |
+
v_2 = tf.Variable(2.0)
|
168 |
+
checkpoint_2 = tf.train.Checkpoint(v=v_2)
|
169 |
+
checkpoint_2.restore(best_ckpt_path)
|
170 |
+
self.assertEqual(v_2.numpy(), 3.0)
|
171 |
+
|
172 |
+
v = tf.Variable(5.0)
|
173 |
+
checkpoint = tf.train.Checkpoint(v=v)
|
174 |
+
ret = exporter.maybe_export_checkpoint(
|
175 |
+
checkpoint, {'test_metric': {'metric_1': 1.0}}, 300)
|
176 |
+
with self.subTest(name='Worse metic no save.'):
|
177 |
+
self.assertEqual(ret, False)
|
178 |
+
v_2 = tf.Variable(2.0)
|
179 |
+
checkpoint_2 = tf.train.Checkpoint(v=v_2)
|
180 |
+
checkpoint_2.restore(best_ckpt_path)
|
181 |
+
self.assertEqual(v_2.numpy(), 3.0)
|
182 |
+
|
183 |
+
def test_export_best_eval_metric(self):
|
184 |
+
model_dir = self.create_tempdir().full_path
|
185 |
+
metric_name = 'test_metric|metric_1'
|
186 |
+
exporter = train_utils.BestCheckpointExporter(model_dir, metric_name,
|
187 |
+
'higher')
|
188 |
+
exporter.export_best_eval_metric({'test_metric': {'metric_1': 5.0}}, 100)
|
189 |
+
with tf.io.gfile.GFile(os.path.join(model_dir, 'info.json'),
|
190 |
+
'rb') as reader:
|
191 |
+
metric = json.loads(reader.read())
|
192 |
+
self.assertAllEqual(
|
193 |
+
metric,
|
194 |
+
{'test_metric': {'metric_1': 5.0}, 'best_ckpt_global_step': 100.0})
|
195 |
+
|
196 |
+
def test_export_best_eval_metric_skips_non_scalar_values(self):
|
197 |
+
model_dir = self.create_tempdir().full_path
|
198 |
+
metric_name = 'test_metric|metric_1'
|
199 |
+
exporter = train_utils.BestCheckpointExporter(model_dir, metric_name,
|
200 |
+
'higher')
|
201 |
+
image = tf.zeros(shape=[16, 8, 1])
|
202 |
+
eval_logs = {'test_metric': {'metric_1': 5.0, 'image': image}}
|
203 |
+
|
204 |
+
exporter.export_best_eval_metric(eval_logs, 100)
|
205 |
+
|
206 |
+
with tf.io.gfile.GFile(os.path.join(model_dir, 'info.json'),
|
207 |
+
'rb') as reader:
|
208 |
+
metric = json.loads(reader.read())
|
209 |
+
self.assertAllEqual(
|
210 |
+
metric,
|
211 |
+
{'test_metric': {'metric_1': 5.0}, 'best_ckpt_global_step': 100.0})
|
212 |
+
|
213 |
+
|
214 |
+
if __name__ == '__main__':
|
215 |
+
tf.test.main()
|
modeling/official/legacy/README.md
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Models in this `legacy` directory are mainly are used for benchmarking the
|
2 |
+
models.
|
3 |
+
|
4 |
+
Please note that the models in this `legacy` directory are not supported like
|
5 |
+
the models in official/nlp and official/vision.
|
modeling/official/legacy/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
modeling/official/legacy/albert/README.md
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ALBERT (ALBERT: A Lite BERT for Self-supervised Learning of Language Representations)
|
2 |
+
|
3 |
+
**WARNING**: This directory is deprecated.
|
4 |
+
See `nlp/docs/MODEL_GARDEN.md` for the new ALBERT implementation.
|
modeling/official/legacy/albert/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
modeling/official/legacy/albert/configs.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""The ALBERT configurations."""
|
16 |
+
|
17 |
+
import six
|
18 |
+
|
19 |
+
from official.legacy.bert import configs
|
20 |
+
|
21 |
+
|
22 |
+
class AlbertConfig(configs.BertConfig):
|
23 |
+
"""Configuration for `ALBERT`."""
|
24 |
+
|
25 |
+
def __init__(self, num_hidden_groups=1, inner_group_num=1, **kwargs):
|
26 |
+
"""Constructs AlbertConfig.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
num_hidden_groups: Number of group for the hidden layers, parameters in
|
30 |
+
the same group are shared. Note that this value and also the following
|
31 |
+
'inner_group_num' has to be 1 for now, because all released ALBERT
|
32 |
+
models set them to 1. We may support arbitary valid values in future.
|
33 |
+
inner_group_num: Number of inner repetition of attention and ffn.
|
34 |
+
**kwargs: The remaining arguments are the same as above 'BertConfig'.
|
35 |
+
"""
|
36 |
+
super(AlbertConfig, self).__init__(**kwargs)
|
37 |
+
|
38 |
+
# TODO(chendouble): 'inner_group_num' and 'num_hidden_groups' are always 1
|
39 |
+
# in the released ALBERT. Support other values in AlbertEncoder if needed.
|
40 |
+
if inner_group_num != 1 or num_hidden_groups != 1:
|
41 |
+
raise ValueError("We only support 'inner_group_num' and "
|
42 |
+
"'num_hidden_groups' as 1.")
|
43 |
+
|
44 |
+
@classmethod
|
45 |
+
def from_dict(cls, json_object):
|
46 |
+
"""Constructs a `AlbertConfig` from a Python dictionary of parameters."""
|
47 |
+
config = AlbertConfig(vocab_size=None)
|
48 |
+
for (key, value) in six.iteritems(json_object):
|
49 |
+
config.__dict__[key] = value
|
50 |
+
return config
|
modeling/official/legacy/bert/README.md
ADDED
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# BERT (Bidirectional Encoder Representations from Transformers)
|
2 |
+
|
3 |
+
**WARNING**: We are on the way to deprecating most of the code in this directory.
|
4 |
+
Please see
|
5 |
+
[this link](../g3doc/tutorials/bert_new.md)
|
6 |
+
for the new tutorial and use the new code in `nlp/modeling`. This README is
|
7 |
+
still correct for this legacy implementation.
|
8 |
+
|
9 |
+
The academic paper which describes BERT in detail and provides full results on a
|
10 |
+
number of tasks can be found here: https://arxiv.org/abs/1810.04805.
|
11 |
+
|
12 |
+
This repository contains TensorFlow 2.x implementation for BERT.
|
13 |
+
|
14 |
+
## Contents
|
15 |
+
* [Contents](#contents)
|
16 |
+
* [Pre-trained Models](#pre-trained-models)
|
17 |
+
* [Restoring from Checkpoints](#restoring-from-checkpoints)
|
18 |
+
* [Set Up](#set-up)
|
19 |
+
* [Process Datasets](#process-datasets)
|
20 |
+
* [Fine-tuning with BERT](#fine-tuning-with-bert)
|
21 |
+
* [Cloud GPUs and TPUs](#cloud-gpus-and-tpus)
|
22 |
+
* [Sentence and Sentence-pair Classification Tasks](#sentence-and-sentence-pair-classification-tasks)
|
23 |
+
* [SQuAD 1.1](#squad-1.1)
|
24 |
+
|
25 |
+
|
26 |
+
## Pre-trained Models
|
27 |
+
|
28 |
+
We released both checkpoints and tf.hub modules as the pretrained models for
|
29 |
+
fine-tuning. They are TF 2.x compatible and are converted from the checkpoints
|
30 |
+
released in TF 1.x official BERT repository
|
31 |
+
[google-research/bert](https://github.com/google-research/bert)
|
32 |
+
in order to keep consistent with BERT paper.
|
33 |
+
|
34 |
+
|
35 |
+
### Access to Pretrained Checkpoints
|
36 |
+
|
37 |
+
Pretrained checkpoints can be found in the following links:
|
38 |
+
|
39 |
+
**Note: We have switched BERT implementation
|
40 |
+
to use Keras functional-style networks in [nlp/modeling](../modeling).
|
41 |
+
The new checkpoints are:**
|
42 |
+
|
43 |
+
* **[`BERT-Large, Uncased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_uncased_L-24_H-1024_A-16.tar.gz)**:
|
44 |
+
24-layer, 1024-hidden, 16-heads, 340M parameters
|
45 |
+
* **[`BERT-Large, Cased (Whole Word Masking)`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/wwm_cased_L-24_H-1024_A-16.tar.gz)**:
|
46 |
+
24-layer, 1024-hidden, 16-heads, 340M parameters
|
47 |
+
* **[`BERT-Base, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-12_H-768_A-12.tar.gz)**:
|
48 |
+
12-layer, 768-hidden, 12-heads, 110M parameters
|
49 |
+
* **[`BERT-Large, Uncased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16.tar.gz)**:
|
50 |
+
24-layer, 1024-hidden, 16-heads, 340M parameters
|
51 |
+
* **[`BERT-Base, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-12_H-768_A-12.tar.gz)**:
|
52 |
+
12-layer, 768-hidden, 12-heads , 110M parameters
|
53 |
+
* **[`BERT-Large, Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/cased_L-24_H-1024_A-16.tar.gz)**:
|
54 |
+
24-layer, 1024-hidden, 16-heads, 340M parameters
|
55 |
+
* **[`BERT-Base, Multilingual Cased`](https://storage.googleapis.com/cloud-tpu-checkpoints/bert/keras_bert/multi_cased_L-12_H-768_A-12.tar.gz)**:
|
56 |
+
104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
|
57 |
+
|
58 |
+
We recommend to host checkpoints on Google Cloud Storage buckets when you use
|
59 |
+
Cloud GPU/TPU.
|
60 |
+
|
61 |
+
### Restoring from Checkpoints
|
62 |
+
|
63 |
+
`tf.train.Checkpoint` is used to manage model checkpoints in TF 2. To restore
|
64 |
+
weights from provided pre-trained checkpoints, you can use the following code:
|
65 |
+
|
66 |
+
```python
|
67 |
+
init_checkpoint='the pretrained model checkpoint path.'
|
68 |
+
model=tf.keras.Model() # Bert pre-trained model as feature extractor.
|
69 |
+
checkpoint = tf.train.Checkpoint(model=model)
|
70 |
+
checkpoint.restore(init_checkpoint)
|
71 |
+
```
|
72 |
+
|
73 |
+
Checkpoints featuring native serialized Keras models
|
74 |
+
(i.e. model.load()/load_weights()) will be available soon.
|
75 |
+
|
76 |
+
### Access to Pretrained hub modules.
|
77 |
+
|
78 |
+
Pretrained tf.hub modules in TF 2.x SavedModel format can be found in the
|
79 |
+
following links:
|
80 |
+
|
81 |
+
* **[`BERT-Large, Uncased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_uncased_L-24_H-1024_A-16/)**:
|
82 |
+
24-layer, 1024-hidden, 16-heads, 340M parameters
|
83 |
+
* **[`BERT-Large, Cased (Whole Word Masking)`](https://tfhub.dev/tensorflow/bert_en_wwm_cased_L-24_H-1024_A-16/)**:
|
84 |
+
24-layer, 1024-hidden, 16-heads, 340M parameters
|
85 |
+
* **[`BERT-Base, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/)**:
|
86 |
+
12-layer, 768-hidden, 12-heads, 110M parameters
|
87 |
+
* **[`BERT-Large, Uncased`](https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/)**:
|
88 |
+
24-layer, 1024-hidden, 16-heads, 340M parameters
|
89 |
+
* **[`BERT-Base, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/)**:
|
90 |
+
12-layer, 768-hidden, 12-heads , 110M parameters
|
91 |
+
* **[`BERT-Large, Cased`](https://tfhub.dev/tensorflow/bert_en_cased_L-24_H-1024_A-16/)**:
|
92 |
+
24-layer, 1024-hidden, 16-heads, 340M parameters
|
93 |
+
* **[`BERT-Base, Multilingual Cased`](https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/)**:
|
94 |
+
104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
|
95 |
+
* **[`BERT-Base, Chinese`](https://tfhub.dev/tensorflow/bert_zh_L-12_H-768_A-12/)**:
|
96 |
+
Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads,
|
97 |
+
110M parameters
|
98 |
+
|
99 |
+
## Set Up
|
100 |
+
|
101 |
+
```shell
|
102 |
+
export PYTHONPATH="$PYTHONPATH:/path/to/models"
|
103 |
+
```
|
104 |
+
|
105 |
+
Install `tf-nightly` to get latest updates:
|
106 |
+
|
107 |
+
```shell
|
108 |
+
pip install tf-nightly-gpu
|
109 |
+
```
|
110 |
+
|
111 |
+
With TPU, GPU support is not necessary. First, you need to create a `tf-nightly`
|
112 |
+
TPU with [ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu):
|
113 |
+
|
114 |
+
```shell
|
115 |
+
ctpu up -name <instance name> --tf-version=”nightly”
|
116 |
+
```
|
117 |
+
|
118 |
+
Second, you need to install TF 2 `tf-nightly` on your VM:
|
119 |
+
|
120 |
+
```shell
|
121 |
+
pip install tf-nightly
|
122 |
+
```
|
123 |
+
|
124 |
+
## Process Datasets
|
125 |
+
|
126 |
+
### Pre-training
|
127 |
+
|
128 |
+
There is no change to generate pre-training data. Please use the script
|
129 |
+
[`../data/create_pretraining_data.py`](../data/create_pretraining_data.py)
|
130 |
+
which is essentially branched from the [BERT research repo](https://github.com/google-research/bert)
|
131 |
+
to get processed pre-training data and it adapts to TF2 symbols and python3
|
132 |
+
compatibility.
|
133 |
+
|
134 |
+
Running the pre-training script requires an input and output directory, as well as a vocab file. Note that max_seq_length will need to match the sequence length parameter you specify when you run pre-training.
|
135 |
+
|
136 |
+
Example shell script to call create_pretraining_data.py
|
137 |
+
```
|
138 |
+
export WORKING_DIR='local disk or cloud location'
|
139 |
+
export BERT_DIR='local disk or cloud location'
|
140 |
+
python models/official/nlp/data/create_pretraining_data.py \
|
141 |
+
--input_file=$WORKING_DIR/input/input.txt \
|
142 |
+
--output_file=$WORKING_DIR/output/tf_examples.tfrecord \
|
143 |
+
--vocab_file=$BERT_DIR/wwm_uncased_L-24_H-1024_A-16/vocab.txt \
|
144 |
+
--do_lower_case=True \
|
145 |
+
--max_seq_length=512 \
|
146 |
+
--max_predictions_per_seq=76 \
|
147 |
+
--masked_lm_prob=0.15 \
|
148 |
+
--random_seed=12345 \
|
149 |
+
--dupe_factor=5
|
150 |
+
```
|
151 |
+
|
152 |
+
### Fine-tuning
|
153 |
+
|
154 |
+
To prepare the fine-tuning data for final model training, use the
|
155 |
+
[`../data/create_finetuning_data.py`](../data/create_finetuning_data.py) script.
|
156 |
+
Resulting datasets in `tf_record` format and training meta data should be later
|
157 |
+
passed to training or evaluation scripts. The task-specific arguments are
|
158 |
+
described in the following sections:
|
159 |
+
|
160 |
+
* GLUE
|
161 |
+
|
162 |
+
Users can download the
|
163 |
+
[GLUE data](https://gluebenchmark.com/tasks) by running
|
164 |
+
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
|
165 |
+
and unpack it to some directory `$GLUE_DIR`.
|
166 |
+
Also, users can download [Pretrained Checkpoint](#access-to-pretrained-checkpoints) and locate it on some directory `$BERT_DIR` instead of using checkpoints on Google Cloud Storage.
|
167 |
+
|
168 |
+
```shell
|
169 |
+
export GLUE_DIR=~/glue
|
170 |
+
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
|
171 |
+
|
172 |
+
export TASK_NAME=MNLI
|
173 |
+
export OUTPUT_DIR=gs://some_bucket/datasets
|
174 |
+
python ../data/create_finetuning_data.py \
|
175 |
+
--input_data_dir=${GLUE_DIR}/${TASK_NAME}/ \
|
176 |
+
--vocab_file=${BERT_DIR}/vocab.txt \
|
177 |
+
--train_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_train.tf_record \
|
178 |
+
--eval_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_eval.tf_record \
|
179 |
+
--meta_data_file_path=${OUTPUT_DIR}/${TASK_NAME}_meta_data \
|
180 |
+
--fine_tuning_task_type=classification --max_seq_length=128 \
|
181 |
+
--classification_task_name=${TASK_NAME}
|
182 |
+
```
|
183 |
+
|
184 |
+
* SQUAD
|
185 |
+
|
186 |
+
The [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/) contains
|
187 |
+
detailed information about the SQuAD datasets and evaluation.
|
188 |
+
|
189 |
+
The necessary files can be found here:
|
190 |
+
|
191 |
+
* [train-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json)
|
192 |
+
* [dev-v1.1.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json)
|
193 |
+
* [evaluate-v1.1.py](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py)
|
194 |
+
* [train-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json)
|
195 |
+
* [dev-v2.0.json](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json)
|
196 |
+
* [evaluate-v2.0.py](https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/)
|
197 |
+
|
198 |
+
```shell
|
199 |
+
export SQUAD_DIR=~/squad
|
200 |
+
export SQUAD_VERSION=v1.1
|
201 |
+
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
|
202 |
+
export OUTPUT_DIR=gs://some_bucket/datasets
|
203 |
+
|
204 |
+
python ../data/create_finetuning_data.py \
|
205 |
+
--squad_data_file=${SQUAD_DIR}/train-${SQUAD_VERSION}.json \
|
206 |
+
--vocab_file=${BERT_DIR}/vocab.txt \
|
207 |
+
--train_data_output_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
|
208 |
+
--meta_data_file_path=${OUTPUT_DIR}/squad_${SQUAD_VERSION}_meta_data \
|
209 |
+
--fine_tuning_task_type=squad --max_seq_length=384
|
210 |
+
```
|
211 |
+
|
212 |
+
Note: To create fine-tuning data with SQUAD 2.0, you need to add flag `--version_2_with_negative=True`.
|
213 |
+
|
214 |
+
## Fine-tuning with BERT
|
215 |
+
|
216 |
+
### Cloud GPUs and TPUs
|
217 |
+
|
218 |
+
* Cloud Storage
|
219 |
+
|
220 |
+
The unzipped pre-trained model files can also be found in the Google Cloud
|
221 |
+
Storage folder `gs://cloud-tpu-checkpoints/bert/keras_bert`. For example:
|
222 |
+
|
223 |
+
```shell
|
224 |
+
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
|
225 |
+
export MODEL_DIR=gs://some_bucket/my_output_dir
|
226 |
+
```
|
227 |
+
|
228 |
+
Currently, users are able to access to `tf-nightly` TPUs and the following TPU
|
229 |
+
script should run with `tf-nightly`.
|
230 |
+
|
231 |
+
* GPU -> TPU
|
232 |
+
|
233 |
+
Just add the following flags to `run_classifier.py` or `run_squad.py`:
|
234 |
+
|
235 |
+
```shell
|
236 |
+
--distribution_strategy=tpu
|
237 |
+
--tpu=grpc://${TPU_IP_ADDRESS}:8470
|
238 |
+
```
|
239 |
+
|
240 |
+
### Sentence and Sentence-pair Classification Tasks
|
241 |
+
|
242 |
+
This example code fine-tunes `BERT-Large` on the Microsoft Research Paraphrase
|
243 |
+
Corpus (MRPC) corpus, which only contains 3,600 examples and can fine-tune in a
|
244 |
+
few minutes on most GPUs.
|
245 |
+
|
246 |
+
We use the `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
|
247 |
+
workflow.
|
248 |
+
For GPU memory of 16GB or smaller, you may try to use `BERT-Base`
|
249 |
+
(uncased_L-12_H-768_A-12).
|
250 |
+
|
251 |
+
```shell
|
252 |
+
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
|
253 |
+
export MODEL_DIR=gs://some_bucket/my_output_dir
|
254 |
+
export GLUE_DIR=gs://some_bucket/datasets
|
255 |
+
export TASK=MRPC
|
256 |
+
|
257 |
+
python run_classifier.py \
|
258 |
+
--mode='train_and_eval' \
|
259 |
+
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
|
260 |
+
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
|
261 |
+
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
|
262 |
+
--bert_config_file=${BERT_DIR}/bert_config.json \
|
263 |
+
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
|
264 |
+
--train_batch_size=4 \
|
265 |
+
--eval_batch_size=4 \
|
266 |
+
--steps_per_loop=1 \
|
267 |
+
--learning_rate=2e-5 \
|
268 |
+
--num_train_epochs=3 \
|
269 |
+
--model_dir=${MODEL_DIR} \
|
270 |
+
--distribution_strategy=mirrored
|
271 |
+
```
|
272 |
+
|
273 |
+
Alternatively, instead of specifying `init_checkpoint`, you can specify
|
274 |
+
`hub_module_url` to employ a pre-trained BERT hub module, e.g.,
|
275 |
+
` --hub_module_url=https://tfhub.dev/tensorflow/bert_en_uncased_L-24_H-1024_A-16/1`.
|
276 |
+
|
277 |
+
After training a model, to get predictions from the classifier, you can set the
|
278 |
+
`--mode=predict` and offer the test set tfrecords to `--eval_data_path`.
|
279 |
+
The output will be created in file called test_results.tsv in the output folder.
|
280 |
+
Each line will contain output for each sample, columns are the class
|
281 |
+
probabilities.
|
282 |
+
|
283 |
+
```shell
|
284 |
+
python run_classifier.py \
|
285 |
+
--mode='predict' \
|
286 |
+
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
|
287 |
+
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
|
288 |
+
--bert_config_file=${BERT_DIR}/bert_config.json \
|
289 |
+
--eval_batch_size=4 \
|
290 |
+
--model_dir=${MODEL_DIR} \
|
291 |
+
--distribution_strategy=mirrored
|
292 |
+
```
|
293 |
+
|
294 |
+
To use TPU, you only need to switch the distribution strategy type to `tpu` with TPU
|
295 |
+
information and use remote storage for model checkpoints.
|
296 |
+
|
297 |
+
```shell
|
298 |
+
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
|
299 |
+
export TPU_IP_ADDRESS='???'
|
300 |
+
export MODEL_DIR=gs://some_bucket/my_output_dir
|
301 |
+
export GLUE_DIR=gs://some_bucket/datasets
|
302 |
+
export TASK=MRPC
|
303 |
+
|
304 |
+
python run_classifier.py \
|
305 |
+
--mode='train_and_eval' \
|
306 |
+
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
|
307 |
+
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
|
308 |
+
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
|
309 |
+
--bert_config_file=${BERT_DIR}/bert_config.json \
|
310 |
+
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
|
311 |
+
--train_batch_size=32 \
|
312 |
+
--eval_batch_size=32 \
|
313 |
+
--steps_per_loop=1000 \
|
314 |
+
--learning_rate=2e-5 \
|
315 |
+
--num_train_epochs=3 \
|
316 |
+
--model_dir=${MODEL_DIR} \
|
317 |
+
--distribution_strategy=tpu \
|
318 |
+
--tpu=grpc://${TPU_IP_ADDRESS}:8470
|
319 |
+
```
|
320 |
+
|
321 |
+
Note that, we specify `steps_per_loop=1000` for TPU, because running a loop of
|
322 |
+
training steps inside a `tf.function` can significantly increase TPU utilization
|
323 |
+
and callbacks will not be called inside the loop.
|
324 |
+
|
325 |
+
### SQuAD 1.1
|
326 |
+
|
327 |
+
The Stanford Question Answering Dataset (SQuAD) is a popular question answering
|
328 |
+
benchmark dataset. See more on [SQuAD website](https://rajpurkar.github.io/SQuAD-explorer/).
|
329 |
+
|
330 |
+
We use the `BERT-Large` (uncased_L-24_H-1024_A-16) as an example throughout the
|
331 |
+
workflow.
|
332 |
+
For GPU memory of 16GB or smaller, you may try to use `BERT-Base`
|
333 |
+
(uncased_L-12_H-768_A-12).
|
334 |
+
|
335 |
+
```shell
|
336 |
+
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
|
337 |
+
export SQUAD_DIR=gs://some_bucket/datasets
|
338 |
+
export MODEL_DIR=gs://some_bucket/my_output_dir
|
339 |
+
export SQUAD_VERSION=v1.1
|
340 |
+
|
341 |
+
python run_squad.py \
|
342 |
+
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
|
343 |
+
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
|
344 |
+
--predict_file=${SQUAD_DIR}/dev-v1.1.json \
|
345 |
+
--vocab_file=${BERT_DIR}/vocab.txt \
|
346 |
+
--bert_config_file=${BERT_DIR}/bert_config.json \
|
347 |
+
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
|
348 |
+
--train_batch_size=4 \
|
349 |
+
--predict_batch_size=4 \
|
350 |
+
--learning_rate=8e-5 \
|
351 |
+
--num_train_epochs=2 \
|
352 |
+
--model_dir=${MODEL_DIR} \
|
353 |
+
--distribution_strategy=mirrored
|
354 |
+
```
|
355 |
+
|
356 |
+
Similarly, you can replace `init_checkpoint` FLAG with `hub_module_url` to
|
357 |
+
specify a hub module path.
|
358 |
+
|
359 |
+
`run_squad.py` writes the prediction for `--predict_file` by default. If you set
|
360 |
+
the `--model=predict` and offer the SQuAD test data, the scripts will generate
|
361 |
+
the prediction json file.
|
362 |
+
|
363 |
+
To use TPU, you need to switch the distribution strategy type to `tpu` with TPU
|
364 |
+
information.
|
365 |
+
|
366 |
+
```shell
|
367 |
+
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
|
368 |
+
export TPU_IP_ADDRESS='???'
|
369 |
+
export MODEL_DIR=gs://some_bucket/my_output_dir
|
370 |
+
export SQUAD_DIR=gs://some_bucket/datasets
|
371 |
+
export SQUAD_VERSION=v1.1
|
372 |
+
|
373 |
+
python run_squad.py \
|
374 |
+
--input_meta_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_meta_data \
|
375 |
+
--train_data_path=${SQUAD_DIR}/squad_${SQUAD_VERSION}_train.tf_record \
|
376 |
+
--predict_file=${SQUAD_DIR}/dev-v1.1.json \
|
377 |
+
--vocab_file=${BERT_DIR}/vocab.txt \
|
378 |
+
--bert_config_file=${BERT_DIR}/bert_config.json \
|
379 |
+
--init_checkpoint=${BERT_DIR}/bert_model.ckpt \
|
380 |
+
--train_batch_size=32 \
|
381 |
+
--learning_rate=8e-5 \
|
382 |
+
--num_train_epochs=2 \
|
383 |
+
--model_dir=${MODEL_DIR} \
|
384 |
+
--distribution_strategy=tpu \
|
385 |
+
--tpu=grpc://${TPU_IP_ADDRESS}:8470
|
386 |
+
```
|
387 |
+
|
388 |
+
The dev set predictions will be saved into a file called predictions.json in the
|
389 |
+
model_dir:
|
390 |
+
|
391 |
+
```shell
|
392 |
+
python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ./squad/predictions.json
|
393 |
+
```
|
394 |
+
|
395 |
+
|
modeling/official/legacy/bert/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
modeling/official/legacy/bert/bert_cloud_tpu.md
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# BERT FineTuning with Cloud TPU: Sentence and Sentence-Pair Classification Tasks (TF 2.1)
|
2 |
+
This tutorial shows you how to train the Bidirectional Encoder Representations from Transformers (BERT) model on Cloud TPU.
|
3 |
+
|
4 |
+
|
5 |
+
## Set up Cloud Storage and Compute Engine VM
|
6 |
+
1. [Open a cloud shell window](https://console.cloud.google.com/?cloudshell=true&_ga=2.11844148.-1612541229.1552429951)
|
7 |
+
2. Create a variable for the project's id:
|
8 |
+
```
|
9 |
+
export PROJECT_ID=your-project_id
|
10 |
+
```
|
11 |
+
3. Configure `gcloud` command-line tool to use the project where you want to create Cloud TPU.
|
12 |
+
```
|
13 |
+
gcloud config set project ${PROJECT_ID}
|
14 |
+
```
|
15 |
+
4. Create a Cloud Storage bucket using the following command:
|
16 |
+
```
|
17 |
+
gsutil mb -p ${PROJECT_ID} -c standard -l europe-west4 -b on gs://your-bucket-name
|
18 |
+
```
|
19 |
+
This Cloud Storage bucket stores the data you use to train your model and the training results.
|
20 |
+
5. Launch a Compute Engine VM and Cloud TPU using the ctpu up command.
|
21 |
+
```
|
22 |
+
ctpu up --tpu-size=v3-8 \
|
23 |
+
--machine-type=n1-standard-8 \
|
24 |
+
--zone=europe-west4-a \
|
25 |
+
--tf-version=2.1 [optional flags: --project, --name]
|
26 |
+
```
|
27 |
+
6. The configuration you specified appears. Enter y to approve or n to cancel.
|
28 |
+
7. When the ctpu up command has finished executing, verify that your shell prompt has changed from username@project to username@tpuname. This change shows that you are now logged into your Compute Engine VM.
|
29 |
+
```
|
30 |
+
gcloud compute ssh vm-name --zone=europe-west4-a
|
31 |
+
(vm)$ export TPU_NAME=vm-name
|
32 |
+
```
|
33 |
+
As you continue these instructions, run each command that begins with `(vm)$` in your VM session window.
|
34 |
+
|
35 |
+
## Prepare the Dataset
|
36 |
+
1. From your Compute Engine virtual machine (VM), install requirements.txt.
|
37 |
+
```
|
38 |
+
(vm)$ cd /usr/share/models
|
39 |
+
(vm)$ sudo pip3 install -r official/requirements.txt
|
40 |
+
```
|
41 |
+
2. Optional: download download_glue_data.py
|
42 |
+
|
43 |
+
This tutorial uses the General Language Understanding Evaluation (GLUE) benchmark to evaluate and analyze the performance of the model. The GLUE data is provided for this tutorial at gs://cloud-tpu-checkpoints/bert/classification.
|
44 |
+
|
45 |
+
## Define parameter values
|
46 |
+
Next, define several parameter values that are required when you train and evaluate your model:
|
47 |
+
|
48 |
+
```
|
49 |
+
(vm)$ export PYTHONPATH="$PYTHONPATH:/usr/share/tpu/models"
|
50 |
+
(vm)$ export STORAGE_BUCKET=gs://your-bucket-name
|
51 |
+
(vm)$ export BERT_BASE_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16
|
52 |
+
(vm)$ export MODEL_DIR=${STORAGE_BUCKET}/bert-output
|
53 |
+
(vm)$ export GLUE_DIR=gs://cloud-tpu-checkpoints/bert/classification
|
54 |
+
(vm)$ export TASK=mnli
|
55 |
+
```
|
56 |
+
|
57 |
+
## Train the model
|
58 |
+
From your Compute Engine VM, run the following command.
|
59 |
+
|
60 |
+
```
|
61 |
+
(vm)$ python3 official/nlp/bert/run_classifier.py \
|
62 |
+
--mode='train_and_eval' \
|
63 |
+
--input_meta_data_path=${GLUE_DIR}/${TASK}_meta_data \
|
64 |
+
--train_data_path=${GLUE_DIR}/${TASK}_train.tf_record \
|
65 |
+
--eval_data_path=${GLUE_DIR}/${TASK}_eval.tf_record \
|
66 |
+
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
|
67 |
+
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
|
68 |
+
--train_batch_size=32 \
|
69 |
+
--eval_batch_size=32 \
|
70 |
+
--learning_rate=2e-5 \
|
71 |
+
--num_train_epochs=3 \
|
72 |
+
--model_dir=${MODEL_DIR} \
|
73 |
+
--distribution_strategy=tpu \
|
74 |
+
--tpu=${TPU_NAME}
|
75 |
+
```
|
76 |
+
|
77 |
+
## Verify your results
|
78 |
+
The training takes approximately 1 hour on a v3-8 TPU. When script completes, you should see results similar to the following:
|
79 |
+
```
|
80 |
+
Training Summary:
|
81 |
+
{'train_loss': 0.28142181038856506,
|
82 |
+
'last_train_metrics': 0.9467429518699646,
|
83 |
+
'eval_metrics': 0.8599063158035278,
|
84 |
+
'total_training_steps': 36813}
|
85 |
+
```
|
86 |
+
|
87 |
+
## Clean up
|
88 |
+
To avoid incurring charges to your GCP account for the resources used in this topic:
|
89 |
+
1. Disconnect from the Compute Engine VM:
|
90 |
+
```
|
91 |
+
(vm)$ exit
|
92 |
+
```
|
93 |
+
2. In your Cloud Shell, run ctpu delete with the --zone flag you used when you set up the Cloud TPU to delete your Compute Engine VM and your Cloud TPU:
|
94 |
+
```
|
95 |
+
$ ctpu delete --zone=your-zone
|
96 |
+
```
|
97 |
+
3. Run ctpu status specifying your zone to make sure you have no instances allocated to avoid unnecessary charges for TPU usage. The deletion might take several minutes. A response like the one below indicates there are no more allocated instances:
|
98 |
+
```
|
99 |
+
$ ctpu status --zone=your-zone
|
100 |
+
```
|
101 |
+
4. Run gsutil as shown, replacing your-bucket with the name of the Cloud Storage bucket you created for this tutorial:
|
102 |
+
```
|
103 |
+
$ gsutil rm -r gs://your-bucket
|
104 |
+
```
|
105 |
+
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
|
modeling/official/legacy/bert/bert_models.py
ADDED
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""BERT models that are compatible with TF 2.0."""
|
16 |
+
|
17 |
+
import gin
|
18 |
+
import tensorflow as tf, tf_keras
|
19 |
+
import tensorflow_hub as hub
|
20 |
+
from official.legacy.albert import configs as albert_configs
|
21 |
+
from official.legacy.bert import configs
|
22 |
+
from official.modeling import tf_utils
|
23 |
+
from official.nlp.modeling import models
|
24 |
+
from official.nlp.modeling import networks
|
25 |
+
|
26 |
+
|
27 |
+
class BertPretrainLossAndMetricLayer(tf_keras.layers.Layer):
|
28 |
+
"""Returns layer that computes custom loss and metrics for pretraining."""
|
29 |
+
|
30 |
+
def __init__(self, vocab_size, **kwargs):
|
31 |
+
super(BertPretrainLossAndMetricLayer, self).__init__(**kwargs)
|
32 |
+
self._vocab_size = vocab_size
|
33 |
+
self.config = {
|
34 |
+
'vocab_size': vocab_size,
|
35 |
+
}
|
36 |
+
|
37 |
+
def _add_metrics(self, lm_output, lm_labels, lm_label_weights,
|
38 |
+
lm_example_loss, sentence_output, sentence_labels,
|
39 |
+
next_sentence_loss):
|
40 |
+
"""Adds metrics."""
|
41 |
+
masked_lm_accuracy = tf_keras.metrics.sparse_categorical_accuracy(
|
42 |
+
lm_labels, lm_output)
|
43 |
+
numerator = tf.reduce_sum(masked_lm_accuracy * lm_label_weights)
|
44 |
+
denominator = tf.reduce_sum(lm_label_weights) + 1e-5
|
45 |
+
masked_lm_accuracy = numerator / denominator
|
46 |
+
self.add_metric(
|
47 |
+
masked_lm_accuracy, name='masked_lm_accuracy', aggregation='mean')
|
48 |
+
|
49 |
+
self.add_metric(lm_example_loss, name='lm_example_loss', aggregation='mean')
|
50 |
+
|
51 |
+
if sentence_labels is not None:
|
52 |
+
next_sentence_accuracy = tf_keras.metrics.sparse_categorical_accuracy(
|
53 |
+
sentence_labels, sentence_output)
|
54 |
+
self.add_metric(
|
55 |
+
next_sentence_accuracy,
|
56 |
+
name='next_sentence_accuracy',
|
57 |
+
aggregation='mean')
|
58 |
+
|
59 |
+
if next_sentence_loss is not None:
|
60 |
+
self.add_metric(
|
61 |
+
next_sentence_loss, name='next_sentence_loss', aggregation='mean')
|
62 |
+
|
63 |
+
def call(self,
|
64 |
+
lm_output_logits,
|
65 |
+
sentence_output_logits,
|
66 |
+
lm_label_ids,
|
67 |
+
lm_label_weights,
|
68 |
+
sentence_labels=None):
|
69 |
+
"""Implements call() for the layer."""
|
70 |
+
lm_label_weights = tf.cast(lm_label_weights, tf.float32)
|
71 |
+
lm_output_logits = tf.cast(lm_output_logits, tf.float32)
|
72 |
+
|
73 |
+
lm_prediction_losses = tf_keras.losses.sparse_categorical_crossentropy(
|
74 |
+
lm_label_ids, lm_output_logits, from_logits=True)
|
75 |
+
lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
|
76 |
+
lm_denominator_loss = tf.reduce_sum(lm_label_weights)
|
77 |
+
mask_label_loss = tf.math.divide_no_nan(lm_numerator_loss,
|
78 |
+
lm_denominator_loss)
|
79 |
+
|
80 |
+
if sentence_labels is not None:
|
81 |
+
sentence_output_logits = tf.cast(sentence_output_logits, tf.float32)
|
82 |
+
sentence_loss = tf_keras.losses.sparse_categorical_crossentropy(
|
83 |
+
sentence_labels, sentence_output_logits, from_logits=True)
|
84 |
+
sentence_loss = tf.reduce_mean(sentence_loss)
|
85 |
+
loss = mask_label_loss + sentence_loss
|
86 |
+
else:
|
87 |
+
sentence_loss = None
|
88 |
+
loss = mask_label_loss
|
89 |
+
|
90 |
+
batch_shape = tf.slice(tf.shape(lm_label_ids), [0], [1])
|
91 |
+
# TODO(hongkuny): Avoids the hack and switches add_loss.
|
92 |
+
final_loss = tf.fill(batch_shape, loss)
|
93 |
+
|
94 |
+
self._add_metrics(lm_output_logits, lm_label_ids, lm_label_weights,
|
95 |
+
mask_label_loss, sentence_output_logits, sentence_labels,
|
96 |
+
sentence_loss)
|
97 |
+
return final_loss
|
98 |
+
|
99 |
+
|
100 |
+
@gin.configurable
|
101 |
+
def get_transformer_encoder(bert_config,
|
102 |
+
sequence_length=None,
|
103 |
+
transformer_encoder_cls=None,
|
104 |
+
output_range=None):
|
105 |
+
"""Gets a 'TransformerEncoder' object.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
bert_config: A 'modeling.BertConfig' or 'modeling.AlbertConfig' object.
|
109 |
+
sequence_length: [Deprecated].
|
110 |
+
transformer_encoder_cls: A EncoderScaffold class. If it is None, uses the
|
111 |
+
default BERT encoder implementation.
|
112 |
+
output_range: the sequence output range, [0, output_range). Default setting
|
113 |
+
is to return the entire sequence output.
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
A encoder object.
|
117 |
+
"""
|
118 |
+
del sequence_length
|
119 |
+
if transformer_encoder_cls is not None:
|
120 |
+
# TODO(hongkuny): evaluate if it is better to put cfg definition in gin.
|
121 |
+
embedding_cfg = dict(
|
122 |
+
vocab_size=bert_config.vocab_size,
|
123 |
+
type_vocab_size=bert_config.type_vocab_size,
|
124 |
+
hidden_size=bert_config.hidden_size,
|
125 |
+
max_seq_length=bert_config.max_position_embeddings,
|
126 |
+
initializer=tf_keras.initializers.TruncatedNormal(
|
127 |
+
stddev=bert_config.initializer_range),
|
128 |
+
dropout_rate=bert_config.hidden_dropout_prob,
|
129 |
+
)
|
130 |
+
hidden_cfg = dict(
|
131 |
+
num_attention_heads=bert_config.num_attention_heads,
|
132 |
+
intermediate_size=bert_config.intermediate_size,
|
133 |
+
intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
|
134 |
+
dropout_rate=bert_config.hidden_dropout_prob,
|
135 |
+
attention_dropout_rate=bert_config.attention_probs_dropout_prob,
|
136 |
+
kernel_initializer=tf_keras.initializers.TruncatedNormal(
|
137 |
+
stddev=bert_config.initializer_range),
|
138 |
+
)
|
139 |
+
kwargs = dict(
|
140 |
+
embedding_cfg=embedding_cfg,
|
141 |
+
hidden_cfg=hidden_cfg,
|
142 |
+
num_hidden_instances=bert_config.num_hidden_layers,
|
143 |
+
pooled_output_dim=bert_config.hidden_size,
|
144 |
+
pooler_layer_initializer=tf_keras.initializers.TruncatedNormal(
|
145 |
+
stddev=bert_config.initializer_range))
|
146 |
+
|
147 |
+
# Relies on gin configuration to define the Transformer encoder arguments.
|
148 |
+
return transformer_encoder_cls(**kwargs)
|
149 |
+
|
150 |
+
kwargs = dict(
|
151 |
+
vocab_size=bert_config.vocab_size,
|
152 |
+
hidden_size=bert_config.hidden_size,
|
153 |
+
num_layers=bert_config.num_hidden_layers,
|
154 |
+
num_attention_heads=bert_config.num_attention_heads,
|
155 |
+
intermediate_size=bert_config.intermediate_size,
|
156 |
+
activation=tf_utils.get_activation(bert_config.hidden_act),
|
157 |
+
dropout_rate=bert_config.hidden_dropout_prob,
|
158 |
+
attention_dropout_rate=bert_config.attention_probs_dropout_prob,
|
159 |
+
max_sequence_length=bert_config.max_position_embeddings,
|
160 |
+
type_vocab_size=bert_config.type_vocab_size,
|
161 |
+
embedding_width=bert_config.embedding_size,
|
162 |
+
initializer=tf_keras.initializers.TruncatedNormal(
|
163 |
+
stddev=bert_config.initializer_range))
|
164 |
+
if isinstance(bert_config, albert_configs.AlbertConfig):
|
165 |
+
return networks.AlbertEncoder(**kwargs)
|
166 |
+
else:
|
167 |
+
assert isinstance(bert_config, configs.BertConfig)
|
168 |
+
kwargs['output_range'] = output_range
|
169 |
+
return networks.BertEncoder(**kwargs)
|
170 |
+
|
171 |
+
|
172 |
+
def pretrain_model(bert_config,
|
173 |
+
seq_length,
|
174 |
+
max_predictions_per_seq,
|
175 |
+
initializer=None,
|
176 |
+
use_next_sentence_label=True,
|
177 |
+
return_core_pretrainer_model=False):
|
178 |
+
"""Returns model to be used for pre-training.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
bert_config: Configuration that defines the core BERT model.
|
182 |
+
seq_length: Maximum sequence length of the training data.
|
183 |
+
max_predictions_per_seq: Maximum number of tokens in sequence to mask out
|
184 |
+
and use for pretraining.
|
185 |
+
initializer: Initializer for weights in BertPretrainer.
|
186 |
+
use_next_sentence_label: Whether to use the next sentence label.
|
187 |
+
return_core_pretrainer_model: Whether to also return the `BertPretrainer`
|
188 |
+
object.
|
189 |
+
|
190 |
+
Returns:
|
191 |
+
A Tuple of (1) Pretraining model, (2) core BERT submodel from which to
|
192 |
+
save weights after pretraining, and (3) optional core `BertPretrainer`
|
193 |
+
object if argument `return_core_pretrainer_model` is True.
|
194 |
+
"""
|
195 |
+
input_word_ids = tf_keras.layers.Input(
|
196 |
+
shape=(seq_length,), name='input_word_ids', dtype=tf.int32)
|
197 |
+
input_mask = tf_keras.layers.Input(
|
198 |
+
shape=(seq_length,), name='input_mask', dtype=tf.int32)
|
199 |
+
input_type_ids = tf_keras.layers.Input(
|
200 |
+
shape=(seq_length,), name='input_type_ids', dtype=tf.int32)
|
201 |
+
masked_lm_positions = tf_keras.layers.Input(
|
202 |
+
shape=(max_predictions_per_seq,),
|
203 |
+
name='masked_lm_positions',
|
204 |
+
dtype=tf.int32)
|
205 |
+
masked_lm_ids = tf_keras.layers.Input(
|
206 |
+
shape=(max_predictions_per_seq,), name='masked_lm_ids', dtype=tf.int32)
|
207 |
+
masked_lm_weights = tf_keras.layers.Input(
|
208 |
+
shape=(max_predictions_per_seq,),
|
209 |
+
name='masked_lm_weights',
|
210 |
+
dtype=tf.int32)
|
211 |
+
|
212 |
+
if use_next_sentence_label:
|
213 |
+
next_sentence_labels = tf_keras.layers.Input(
|
214 |
+
shape=(1,), name='next_sentence_labels', dtype=tf.int32)
|
215 |
+
else:
|
216 |
+
next_sentence_labels = None
|
217 |
+
|
218 |
+
transformer_encoder = get_transformer_encoder(bert_config, seq_length)
|
219 |
+
if initializer is None:
|
220 |
+
initializer = tf_keras.initializers.TruncatedNormal(
|
221 |
+
stddev=bert_config.initializer_range)
|
222 |
+
pretrainer_model = models.BertPretrainer(
|
223 |
+
network=transformer_encoder,
|
224 |
+
embedding_table=transformer_encoder.get_embedding_table(),
|
225 |
+
num_classes=2, # The next sentence prediction label has two classes.
|
226 |
+
activation=tf_utils.get_activation(bert_config.hidden_act),
|
227 |
+
num_token_predictions=max_predictions_per_seq,
|
228 |
+
initializer=initializer,
|
229 |
+
output='logits')
|
230 |
+
|
231 |
+
outputs = pretrainer_model(
|
232 |
+
[input_word_ids, input_mask, input_type_ids, masked_lm_positions])
|
233 |
+
lm_output = outputs['masked_lm']
|
234 |
+
sentence_output = outputs['classification']
|
235 |
+
pretrain_loss_layer = BertPretrainLossAndMetricLayer(
|
236 |
+
vocab_size=bert_config.vocab_size)
|
237 |
+
output_loss = pretrain_loss_layer(lm_output, sentence_output, masked_lm_ids,
|
238 |
+
masked_lm_weights, next_sentence_labels)
|
239 |
+
inputs = {
|
240 |
+
'input_word_ids': input_word_ids,
|
241 |
+
'input_mask': input_mask,
|
242 |
+
'input_type_ids': input_type_ids,
|
243 |
+
'masked_lm_positions': masked_lm_positions,
|
244 |
+
'masked_lm_ids': masked_lm_ids,
|
245 |
+
'masked_lm_weights': masked_lm_weights,
|
246 |
+
}
|
247 |
+
if use_next_sentence_label:
|
248 |
+
inputs['next_sentence_labels'] = next_sentence_labels
|
249 |
+
|
250 |
+
keras_model = tf_keras.Model(inputs=inputs, outputs=output_loss)
|
251 |
+
if return_core_pretrainer_model:
|
252 |
+
return keras_model, transformer_encoder, pretrainer_model
|
253 |
+
else:
|
254 |
+
return keras_model, transformer_encoder
|
255 |
+
|
256 |
+
|
257 |
+
def squad_model(bert_config,
|
258 |
+
max_seq_length,
|
259 |
+
initializer=None,
|
260 |
+
hub_module_url=None,
|
261 |
+
hub_module_trainable=True):
|
262 |
+
"""Returns BERT Squad model along with core BERT model to import weights.
|
263 |
+
|
264 |
+
Args:
|
265 |
+
bert_config: BertConfig, the config defines the core Bert model.
|
266 |
+
max_seq_length: integer, the maximum input sequence length.
|
267 |
+
initializer: Initializer for the final dense layer in the span labeler.
|
268 |
+
Defaulted to TruncatedNormal initializer.
|
269 |
+
hub_module_url: TF-Hub path/url to Bert module.
|
270 |
+
hub_module_trainable: True to finetune layers in the hub module.
|
271 |
+
|
272 |
+
Returns:
|
273 |
+
A tuple of (1) keras model that outputs start logits and end logits and
|
274 |
+
(2) the core BERT transformer encoder.
|
275 |
+
"""
|
276 |
+
if initializer is None:
|
277 |
+
initializer = tf_keras.initializers.TruncatedNormal(
|
278 |
+
stddev=bert_config.initializer_range)
|
279 |
+
if not hub_module_url:
|
280 |
+
bert_encoder = get_transformer_encoder(bert_config, max_seq_length)
|
281 |
+
return models.BertSpanLabeler(
|
282 |
+
network=bert_encoder, initializer=initializer), bert_encoder
|
283 |
+
|
284 |
+
input_word_ids = tf_keras.layers.Input(
|
285 |
+
shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
|
286 |
+
input_mask = tf_keras.layers.Input(
|
287 |
+
shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
|
288 |
+
input_type_ids = tf_keras.layers.Input(
|
289 |
+
shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
|
290 |
+
core_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
|
291 |
+
pooled_output, sequence_output = core_model(
|
292 |
+
[input_word_ids, input_mask, input_type_ids])
|
293 |
+
bert_encoder = tf_keras.Model(
|
294 |
+
inputs={
|
295 |
+
'input_word_ids': input_word_ids,
|
296 |
+
'input_mask': input_mask,
|
297 |
+
'input_type_ids': input_type_ids,
|
298 |
+
},
|
299 |
+
outputs=[sequence_output, pooled_output],
|
300 |
+
name='core_model')
|
301 |
+
return models.BertSpanLabeler(
|
302 |
+
network=bert_encoder, initializer=initializer), bert_encoder
|
303 |
+
|
304 |
+
|
305 |
+
def classifier_model(bert_config,
|
306 |
+
num_labels,
|
307 |
+
max_seq_length=None,
|
308 |
+
final_layer_initializer=None,
|
309 |
+
hub_module_url=None,
|
310 |
+
hub_module_trainable=True):
|
311 |
+
"""BERT classifier model in functional API style.
|
312 |
+
|
313 |
+
Construct a Keras model for predicting `num_labels` outputs from an input with
|
314 |
+
maximum sequence length `max_seq_length`.
|
315 |
+
|
316 |
+
Args:
|
317 |
+
bert_config: BertConfig or AlbertConfig, the config defines the core BERT or
|
318 |
+
ALBERT model.
|
319 |
+
num_labels: integer, the number of classes.
|
320 |
+
max_seq_length: integer, the maximum input sequence length.
|
321 |
+
final_layer_initializer: Initializer for final dense layer. Defaulted
|
322 |
+
TruncatedNormal initializer.
|
323 |
+
hub_module_url: TF-Hub path/url to Bert module.
|
324 |
+
hub_module_trainable: True to finetune layers in the hub module.
|
325 |
+
|
326 |
+
Returns:
|
327 |
+
Combined prediction model (words, mask, type) -> (one-hot labels)
|
328 |
+
BERT sub-model (words, mask, type) -> (bert_outputs)
|
329 |
+
"""
|
330 |
+
if final_layer_initializer is not None:
|
331 |
+
initializer = final_layer_initializer
|
332 |
+
else:
|
333 |
+
initializer = tf_keras.initializers.TruncatedNormal(
|
334 |
+
stddev=bert_config.initializer_range)
|
335 |
+
|
336 |
+
if not hub_module_url:
|
337 |
+
bert_encoder = get_transformer_encoder(
|
338 |
+
bert_config, max_seq_length, output_range=1)
|
339 |
+
return models.BertClassifier(
|
340 |
+
bert_encoder,
|
341 |
+
num_classes=num_labels,
|
342 |
+
dropout_rate=bert_config.hidden_dropout_prob,
|
343 |
+
initializer=initializer), bert_encoder
|
344 |
+
|
345 |
+
input_word_ids = tf_keras.layers.Input(
|
346 |
+
shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
|
347 |
+
input_mask = tf_keras.layers.Input(
|
348 |
+
shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
|
349 |
+
input_type_ids = tf_keras.layers.Input(
|
350 |
+
shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')
|
351 |
+
bert_model = hub.KerasLayer(hub_module_url, trainable=hub_module_trainable)
|
352 |
+
pooled_output, _ = bert_model([input_word_ids, input_mask, input_type_ids])
|
353 |
+
output = tf_keras.layers.Dropout(rate=bert_config.hidden_dropout_prob)(
|
354 |
+
pooled_output)
|
355 |
+
|
356 |
+
output = tf_keras.layers.Dense(
|
357 |
+
num_labels, kernel_initializer=initializer, name='output')(
|
358 |
+
output)
|
359 |
+
return tf_keras.Model(
|
360 |
+
inputs={
|
361 |
+
'input_word_ids': input_word_ids,
|
362 |
+
'input_mask': input_mask,
|
363 |
+
'input_type_ids': input_type_ids
|
364 |
+
},
|
365 |
+
outputs=output), bert_model
|
modeling/official/legacy/bert/bert_models_test.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import tensorflow as tf, tf_keras
|
16 |
+
|
17 |
+
from official.legacy.bert import bert_models
|
18 |
+
from official.legacy.bert import configs as bert_configs
|
19 |
+
from official.nlp.modeling import networks
|
20 |
+
|
21 |
+
|
22 |
+
class BertModelsTest(tf.test.TestCase):
|
23 |
+
|
24 |
+
def setUp(self):
|
25 |
+
super(BertModelsTest, self).setUp()
|
26 |
+
self._bert_test_config = bert_configs.BertConfig(
|
27 |
+
attention_probs_dropout_prob=0.0,
|
28 |
+
hidden_act='gelu',
|
29 |
+
hidden_dropout_prob=0.0,
|
30 |
+
hidden_size=16,
|
31 |
+
initializer_range=0.02,
|
32 |
+
intermediate_size=32,
|
33 |
+
max_position_embeddings=128,
|
34 |
+
num_attention_heads=2,
|
35 |
+
num_hidden_layers=2,
|
36 |
+
type_vocab_size=2,
|
37 |
+
vocab_size=30522)
|
38 |
+
|
39 |
+
def test_pretrain_model(self):
|
40 |
+
model, encoder = bert_models.pretrain_model(
|
41 |
+
self._bert_test_config,
|
42 |
+
seq_length=5,
|
43 |
+
max_predictions_per_seq=2,
|
44 |
+
initializer=None,
|
45 |
+
use_next_sentence_label=True)
|
46 |
+
self.assertIsInstance(model, tf_keras.Model)
|
47 |
+
self.assertIsInstance(encoder, networks.BertEncoder)
|
48 |
+
|
49 |
+
# model has one scalar output: loss value.
|
50 |
+
self.assertEqual(model.output.shape.as_list(), [
|
51 |
+
None,
|
52 |
+
])
|
53 |
+
|
54 |
+
# Expect two output from encoder: sequence and classification output.
|
55 |
+
self.assertIsInstance(encoder.output, list)
|
56 |
+
self.assertLen(encoder.output, 2)
|
57 |
+
# shape should be [batch size, hidden_size]
|
58 |
+
self.assertEqual(encoder.output[1].shape.as_list(), [None, 16])
|
59 |
+
|
60 |
+
def test_squad_model(self):
|
61 |
+
model, core_model = bert_models.squad_model(
|
62 |
+
self._bert_test_config,
|
63 |
+
max_seq_length=5,
|
64 |
+
initializer=None,
|
65 |
+
hub_module_url=None,
|
66 |
+
hub_module_trainable=None)
|
67 |
+
self.assertIsInstance(model, tf_keras.Model)
|
68 |
+
self.assertIsInstance(core_model, tf_keras.Model)
|
69 |
+
|
70 |
+
# Expect two output from model: start positions and end positions
|
71 |
+
self.assertIsInstance(model.output, list)
|
72 |
+
self.assertLen(model.output, 2)
|
73 |
+
|
74 |
+
# Expect two output from core_model: sequence and classification output.
|
75 |
+
self.assertIsInstance(core_model.output, list)
|
76 |
+
self.assertLen(core_model.output, 2)
|
77 |
+
# shape should be [batch size, None, hidden_size]
|
78 |
+
self.assertEqual(core_model.output[0].shape.as_list(), [None, None, 16])
|
79 |
+
# shape should be [batch size, hidden_size]
|
80 |
+
self.assertEqual(core_model.output[1].shape.as_list(), [None, 16])
|
81 |
+
|
82 |
+
def test_classifier_model(self):
|
83 |
+
model, core_model = bert_models.classifier_model(
|
84 |
+
self._bert_test_config,
|
85 |
+
num_labels=3,
|
86 |
+
max_seq_length=5,
|
87 |
+
final_layer_initializer=None,
|
88 |
+
hub_module_url=None,
|
89 |
+
hub_module_trainable=None)
|
90 |
+
self.assertIsInstance(model, tf_keras.Model)
|
91 |
+
self.assertIsInstance(core_model, tf_keras.Model)
|
92 |
+
|
93 |
+
# model has one classification output with num_labels=3.
|
94 |
+
self.assertEqual(model.output.shape.as_list(), [None, 3])
|
95 |
+
|
96 |
+
# Expect two output from core_model: sequence and classification output.
|
97 |
+
self.assertIsInstance(core_model.output, list)
|
98 |
+
self.assertLen(core_model.output, 2)
|
99 |
+
# shape should be [batch size, None, hidden_size]
|
100 |
+
self.assertEqual(core_model.output[0].shape.as_list(), [None, None, 16])
|
101 |
+
# shape should be [batch size, hidden_size]
|
102 |
+
self.assertEqual(core_model.output[1].shape.as_list(), [None, 16])
|
103 |
+
|
104 |
+
|
105 |
+
if __name__ == '__main__':
|
106 |
+
tf.test.main()
|
modeling/official/legacy/bert/common_flags.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Defining common flags used across all BERT models/applications."""
|
16 |
+
|
17 |
+
from absl import flags
|
18 |
+
import tensorflow as tf, tf_keras
|
19 |
+
|
20 |
+
from official.utils import hyperparams_flags
|
21 |
+
from official.utils.flags import core as flags_core
|
22 |
+
|
23 |
+
|
24 |
+
def define_common_bert_flags():
|
25 |
+
"""Define common flags for BERT tasks."""
|
26 |
+
flags_core.define_base(
|
27 |
+
data_dir=False,
|
28 |
+
model_dir=True,
|
29 |
+
clean=False,
|
30 |
+
train_epochs=False,
|
31 |
+
epochs_between_evals=False,
|
32 |
+
stop_threshold=False,
|
33 |
+
batch_size=False,
|
34 |
+
num_gpu=True,
|
35 |
+
export_dir=False,
|
36 |
+
distribution_strategy=True,
|
37 |
+
run_eagerly=True)
|
38 |
+
flags_core.define_distribution()
|
39 |
+
flags.DEFINE_string('bert_config_file', None,
|
40 |
+
'Bert configuration file to define core bert layers.')
|
41 |
+
flags.DEFINE_string(
|
42 |
+
'model_export_path', None,
|
43 |
+
'Path to the directory, where trainined model will be '
|
44 |
+
'exported.')
|
45 |
+
flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
|
46 |
+
flags.DEFINE_string(
|
47 |
+
'init_checkpoint', None,
|
48 |
+
'Initial checkpoint (usually from a pre-trained BERT model).')
|
49 |
+
flags.DEFINE_integer('num_train_epochs', 3,
|
50 |
+
'Total number of training epochs to perform.')
|
51 |
+
flags.DEFINE_integer(
|
52 |
+
'steps_per_loop', None,
|
53 |
+
'Number of steps per graph-mode loop. Only training step '
|
54 |
+
'happens inside the loop. Callbacks will not be called '
|
55 |
+
'inside. If not set the value will be configured depending on the '
|
56 |
+
'devices available.')
|
57 |
+
flags.DEFINE_float('learning_rate', 5e-5,
|
58 |
+
'The initial learning rate for Adam.')
|
59 |
+
flags.DEFINE_float('end_lr', 0.0,
|
60 |
+
'The end learning rate for learning rate decay.')
|
61 |
+
flags.DEFINE_string('optimizer_type', 'adamw',
|
62 |
+
'The type of optimizer to use for training (adamw|lamb)')
|
63 |
+
flags.DEFINE_boolean(
|
64 |
+
'scale_loss', False,
|
65 |
+
'Whether to divide the loss by number of replica inside the per-replica '
|
66 |
+
'loss function.')
|
67 |
+
flags.DEFINE_boolean(
|
68 |
+
'use_keras_compile_fit', False,
|
69 |
+
'If True, uses Keras compile/fit() API for training logic. Otherwise '
|
70 |
+
'use custom training loop.')
|
71 |
+
flags.DEFINE_string(
|
72 |
+
'hub_module_url', None, 'TF-Hub path/url to Bert module. '
|
73 |
+
'If specified, init_checkpoint flag should not be used.')
|
74 |
+
flags.DEFINE_bool('hub_module_trainable', True,
|
75 |
+
'True to make keras layers in the hub module trainable.')
|
76 |
+
flags.DEFINE_string(
|
77 |
+
'sub_model_export_name', None,
|
78 |
+
'If set, `sub_model` checkpoints are exported into '
|
79 |
+
'FLAGS.model_dir/FLAGS.sub_model_export_name.')
|
80 |
+
flags.DEFINE_bool('explicit_allreduce', False,
|
81 |
+
'True to use explicit allreduce instead of the implicit '
|
82 |
+
'allreduce in optimizer.apply_gradients(). If fp16 mixed '
|
83 |
+
'precision training is used, this also enables allreduce '
|
84 |
+
'gradients in fp16.')
|
85 |
+
flags.DEFINE_integer('allreduce_bytes_per_pack', 0,
|
86 |
+
'Number of bytes of a gradient pack for allreduce. '
|
87 |
+
'Should be positive integer, if set to 0, all '
|
88 |
+
'gradients are in one pack. Breaking gradient into '
|
89 |
+
'packs could enable overlap between allreduce and '
|
90 |
+
'backprop computation. This flag only takes effect '
|
91 |
+
'when explicit_allreduce is set to True.')
|
92 |
+
|
93 |
+
flags_core.define_log_steps()
|
94 |
+
|
95 |
+
# Adds flags for mixed precision and multi-worker training.
|
96 |
+
flags_core.define_performance(
|
97 |
+
num_parallel_calls=False,
|
98 |
+
inter_op=False,
|
99 |
+
intra_op=False,
|
100 |
+
synthetic_data=False,
|
101 |
+
max_train_steps=False,
|
102 |
+
dtype=True,
|
103 |
+
loss_scale=True,
|
104 |
+
all_reduce_alg=True,
|
105 |
+
num_packs=False,
|
106 |
+
tf_gpu_thread_mode=True,
|
107 |
+
datasets_num_private_threads=True,
|
108 |
+
enable_xla=True,
|
109 |
+
fp16_implementation=True,
|
110 |
+
)
|
111 |
+
|
112 |
+
# Adds gin configuration flags.
|
113 |
+
hyperparams_flags.define_gin_flags()
|
114 |
+
|
115 |
+
|
116 |
+
def dtype():
|
117 |
+
return flags_core.get_tf_dtype(flags.FLAGS)
|
118 |
+
|
119 |
+
|
120 |
+
def use_float16():
|
121 |
+
return flags_core.get_tf_dtype(flags.FLAGS) == tf.float16
|
122 |
+
|
123 |
+
|
124 |
+
def get_loss_scale():
|
125 |
+
return flags_core.get_loss_scale(flags.FLAGS, default_for_fp16='dynamic')
|
modeling/official/legacy/bert/configs.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""The main BERT model and related functions."""
|
16 |
+
|
17 |
+
import copy
|
18 |
+
import json
|
19 |
+
|
20 |
+
import six
|
21 |
+
import tensorflow as tf, tf_keras
|
22 |
+
|
23 |
+
|
24 |
+
class BertConfig(object):
|
25 |
+
"""Configuration for `BertModel`."""
|
26 |
+
|
27 |
+
def __init__(self,
|
28 |
+
vocab_size,
|
29 |
+
hidden_size=768,
|
30 |
+
num_hidden_layers=12,
|
31 |
+
num_attention_heads=12,
|
32 |
+
intermediate_size=3072,
|
33 |
+
hidden_act="gelu",
|
34 |
+
hidden_dropout_prob=0.1,
|
35 |
+
attention_probs_dropout_prob=0.1,
|
36 |
+
max_position_embeddings=512,
|
37 |
+
type_vocab_size=16,
|
38 |
+
initializer_range=0.02,
|
39 |
+
embedding_size=None,
|
40 |
+
backward_compatible=True):
|
41 |
+
"""Constructs BertConfig.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
|
45 |
+
hidden_size: Size of the encoder layers and the pooler layer.
|
46 |
+
num_hidden_layers: Number of hidden layers in the Transformer encoder.
|
47 |
+
num_attention_heads: Number of attention heads for each attention layer in
|
48 |
+
the Transformer encoder.
|
49 |
+
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
|
50 |
+
layer in the Transformer encoder.
|
51 |
+
hidden_act: The non-linear activation function (function or string) in the
|
52 |
+
encoder and pooler.
|
53 |
+
hidden_dropout_prob: The dropout probability for all fully connected
|
54 |
+
layers in the embeddings, encoder, and pooler.
|
55 |
+
attention_probs_dropout_prob: The dropout ratio for the attention
|
56 |
+
probabilities.
|
57 |
+
max_position_embeddings: The maximum sequence length that this model might
|
58 |
+
ever be used with. Typically set this to something large just in case
|
59 |
+
(e.g., 512 or 1024 or 2048).
|
60 |
+
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
|
61 |
+
`BertModel`.
|
62 |
+
initializer_range: The stdev of the truncated_normal_initializer for
|
63 |
+
initializing all weight matrices.
|
64 |
+
embedding_size: (Optional) width of the factorized word embeddings.
|
65 |
+
backward_compatible: Boolean, whether the variables shape are compatible
|
66 |
+
with checkpoints converted from TF 1.x BERT.
|
67 |
+
"""
|
68 |
+
self.vocab_size = vocab_size
|
69 |
+
self.hidden_size = hidden_size
|
70 |
+
self.num_hidden_layers = num_hidden_layers
|
71 |
+
self.num_attention_heads = num_attention_heads
|
72 |
+
self.hidden_act = hidden_act
|
73 |
+
self.intermediate_size = intermediate_size
|
74 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
75 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
76 |
+
self.max_position_embeddings = max_position_embeddings
|
77 |
+
self.type_vocab_size = type_vocab_size
|
78 |
+
self.initializer_range = initializer_range
|
79 |
+
self.embedding_size = embedding_size
|
80 |
+
self.backward_compatible = backward_compatible
|
81 |
+
|
82 |
+
@classmethod
|
83 |
+
def from_dict(cls, json_object):
|
84 |
+
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
|
85 |
+
config = BertConfig(vocab_size=None)
|
86 |
+
for (key, value) in six.iteritems(json_object):
|
87 |
+
config.__dict__[key] = value
|
88 |
+
return config
|
89 |
+
|
90 |
+
@classmethod
|
91 |
+
def from_json_file(cls, json_file):
|
92 |
+
"""Constructs a `BertConfig` from a json file of parameters."""
|
93 |
+
with tf.io.gfile.GFile(json_file, "r") as reader:
|
94 |
+
text = reader.read()
|
95 |
+
return cls.from_dict(json.loads(text))
|
96 |
+
|
97 |
+
def to_dict(self):
|
98 |
+
"""Serializes this instance to a Python dictionary."""
|
99 |
+
output = copy.deepcopy(self.__dict__)
|
100 |
+
return output
|
101 |
+
|
102 |
+
def to_json_string(self):
|
103 |
+
"""Serializes this instance to a JSON string."""
|
104 |
+
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
modeling/official/legacy/bert/export_tfhub.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""A script to export BERT as a TF-Hub SavedModel.
|
16 |
+
|
17 |
+
This script is **DEPRECATED** for exporting BERT encoder models;
|
18 |
+
see the error message in by main() for details.
|
19 |
+
"""
|
20 |
+
|
21 |
+
from typing import Text
|
22 |
+
|
23 |
+
# Import libraries
|
24 |
+
from absl import app
|
25 |
+
from absl import flags
|
26 |
+
from absl import logging
|
27 |
+
import tensorflow as tf, tf_keras
|
28 |
+
from official.legacy.bert import bert_models
|
29 |
+
from official.legacy.bert import configs
|
30 |
+
|
31 |
+
FLAGS = flags.FLAGS
|
32 |
+
|
33 |
+
flags.DEFINE_string("bert_config_file", None,
|
34 |
+
"Bert configuration file to define core bert layers.")
|
35 |
+
flags.DEFINE_string("model_checkpoint_path", None,
|
36 |
+
"File path to TF model checkpoint.")
|
37 |
+
flags.DEFINE_string("export_path", None, "TF-Hub SavedModel destination path.")
|
38 |
+
flags.DEFINE_string("vocab_file", None,
|
39 |
+
"The vocabulary file that the BERT model was trained on.")
|
40 |
+
flags.DEFINE_bool(
|
41 |
+
"do_lower_case", None, "Whether to lowercase. If None, "
|
42 |
+
"do_lower_case will be enabled if 'uncased' appears in the "
|
43 |
+
"name of --vocab_file")
|
44 |
+
flags.DEFINE_enum("model_type", "encoder", ["encoder", "squad"],
|
45 |
+
"What kind of BERT model to export.")
|
46 |
+
|
47 |
+
|
48 |
+
def create_bert_model(bert_config: configs.BertConfig) -> tf_keras.Model:
|
49 |
+
"""Creates a BERT keras core model from BERT configuration.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
bert_config: A `BertConfig` to create the core model.
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
A keras model.
|
56 |
+
"""
|
57 |
+
# Adds input layers just as placeholders.
|
58 |
+
input_word_ids = tf_keras.layers.Input(
|
59 |
+
shape=(None,), dtype=tf.int32, name="input_word_ids")
|
60 |
+
input_mask = tf_keras.layers.Input(
|
61 |
+
shape=(None,), dtype=tf.int32, name="input_mask")
|
62 |
+
input_type_ids = tf_keras.layers.Input(
|
63 |
+
shape=(None,), dtype=tf.int32, name="input_type_ids")
|
64 |
+
transformer_encoder = bert_models.get_transformer_encoder(
|
65 |
+
bert_config, sequence_length=None)
|
66 |
+
sequence_output, pooled_output = transformer_encoder(
|
67 |
+
[input_word_ids, input_mask, input_type_ids])
|
68 |
+
# To keep consistent with legacy hub modules, the outputs are
|
69 |
+
# "pooled_output" and "sequence_output".
|
70 |
+
return tf_keras.Model(
|
71 |
+
inputs=[input_word_ids, input_mask, input_type_ids],
|
72 |
+
outputs=[pooled_output, sequence_output]), transformer_encoder
|
73 |
+
|
74 |
+
|
75 |
+
def export_bert_tfhub(bert_config: configs.BertConfig,
|
76 |
+
model_checkpoint_path: Text,
|
77 |
+
hub_destination: Text,
|
78 |
+
vocab_file: Text,
|
79 |
+
do_lower_case: bool = None):
|
80 |
+
"""Restores a tf_keras.Model and saves for TF-Hub."""
|
81 |
+
# If do_lower_case is not explicit, default to checking whether "uncased" is
|
82 |
+
# in the vocab file name
|
83 |
+
if do_lower_case is None:
|
84 |
+
do_lower_case = "uncased" in vocab_file
|
85 |
+
logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
|
86 |
+
do_lower_case, vocab_file)
|
87 |
+
core_model, encoder = create_bert_model(bert_config)
|
88 |
+
checkpoint = tf.train.Checkpoint(
|
89 |
+
model=encoder, # Legacy checkpoints.
|
90 |
+
encoder=encoder)
|
91 |
+
checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
|
92 |
+
core_model.vocab_file = tf.saved_model.Asset(vocab_file)
|
93 |
+
core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
|
94 |
+
core_model.save(hub_destination, include_optimizer=False, save_format="tf")
|
95 |
+
|
96 |
+
|
97 |
+
def export_bert_squad_tfhub(bert_config: configs.BertConfig,
|
98 |
+
model_checkpoint_path: Text,
|
99 |
+
hub_destination: Text,
|
100 |
+
vocab_file: Text,
|
101 |
+
do_lower_case: bool = None):
|
102 |
+
"""Restores a tf_keras.Model for BERT with SQuAD and saves for TF-Hub."""
|
103 |
+
# If do_lower_case is not explicit, default to checking whether "uncased" is
|
104 |
+
# in the vocab file name
|
105 |
+
if do_lower_case is None:
|
106 |
+
do_lower_case = "uncased" in vocab_file
|
107 |
+
logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
|
108 |
+
do_lower_case, vocab_file)
|
109 |
+
span_labeling, _ = bert_models.squad_model(bert_config, max_seq_length=None)
|
110 |
+
checkpoint = tf.train.Checkpoint(model=span_labeling)
|
111 |
+
checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
|
112 |
+
span_labeling.vocab_file = tf.saved_model.Asset(vocab_file)
|
113 |
+
span_labeling.do_lower_case = tf.Variable(do_lower_case, trainable=False)
|
114 |
+
span_labeling.save(hub_destination, include_optimizer=False, save_format="tf")
|
115 |
+
|
116 |
+
|
117 |
+
def main(_):
|
118 |
+
bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
|
119 |
+
if FLAGS.model_type == "encoder":
|
120 |
+
deprecation_note = (
|
121 |
+
"nlp/bert/export_tfhub is **DEPRECATED** for exporting BERT encoder "
|
122 |
+
"models. Please switch to nlp/tools/export_tfhub for exporting BERT "
|
123 |
+
"(and other) encoders with dict inputs/outputs conforming to "
|
124 |
+
"https://www.tensorflow.org/hub/common_saved_model_apis/text#transformer-encoders"
|
125 |
+
)
|
126 |
+
logging.error(deprecation_note)
|
127 |
+
print("\n\nNOTICE:", deprecation_note, "\n")
|
128 |
+
export_bert_tfhub(bert_config, FLAGS.model_checkpoint_path,
|
129 |
+
FLAGS.export_path, FLAGS.vocab_file, FLAGS.do_lower_case)
|
130 |
+
elif FLAGS.model_type == "squad":
|
131 |
+
export_bert_squad_tfhub(bert_config, FLAGS.model_checkpoint_path,
|
132 |
+
FLAGS.export_path, FLAGS.vocab_file,
|
133 |
+
FLAGS.do_lower_case)
|
134 |
+
else:
|
135 |
+
raise ValueError("Unsupported model_type %s." % FLAGS.model_type)
|
136 |
+
|
137 |
+
|
138 |
+
if __name__ == "__main__":
|
139 |
+
app.run(main)
|