# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Vision models export utility function for serving/inference."""

import os
from typing import Optional, List, Union, Text, Dict

from absl import logging
import tensorflow as tf, tf_keras

from official.core import config_definitions as cfg
from official.core import export_base
from official.core import train_utils
from official.vision import configs
from official.vision.serving import detection
from official.vision.serving import image_classification
from official.vision.serving import semantic_segmentation
from official.vision.serving import video_classification


def export_inference_graph(
    input_type: str,
    batch_size: Optional[int],
    input_image_size: List[int],
    params: cfg.ExperimentConfig,
    checkpoint_path: str,
    export_dir: str,
    num_channels: Optional[int] = 3,
    export_module: Optional[export_base.ExportModule] = None,
    export_checkpoint_subdir: Optional[str] = None,
    export_saved_model_subdir: Optional[str] = None,
    save_options: Optional[tf.saved_model.SaveOptions] = None,
    log_model_flops_and_params: bool = False,
    checkpoint: Optional[tf.train.Checkpoint] = None,
    input_name: Optional[str] = None,
    function_keys: Optional[Union[List[Text], Dict[Text, Text]]] = None,
    add_tpu_function_alias: Optional[bool] = False,
):
  """Exports inference graph for the model specified in the exp config.

  The saved model is stored at `export_dir` (or under
  `export_dir/<export_saved_model_subdir>` if a subdirectory is given), a
  checkpoint is saved under `export_dir/<export_checkpoint_subdir>` if a
  checkpoint subdirectory is given, and params are saved at
  `export_dir/params.yaml`.

  Args:
    input_type: One of `image_tensor`, `image_bytes`, `tf_example` or `tflite`.
    batch_size: `int` batch size, or `None` for a dynamic batch dimension.
    input_image_size: List or Tuple of height and width.
    params: Experiment params.
    checkpoint_path: Trained checkpoint path or directory.
    export_dir: Export directory path.
    num_channels: The number of input image channels.
    export_module: Optional export module to be used instead of using params to
      create one. If None, the params will be used to create an export module.
    export_checkpoint_subdir: Optional subdirectory under export_dir to store
      checkpoint.
    export_saved_model_subdir: Optional subdirectory under export_dir to store
      saved model.
    save_options: `SaveOptions` for `tf.saved_model.save`.
    log_model_flops_and_params: If True, writes model FLOPs to model_flops.txt
      and model parameters to model_params.txt.
    checkpoint: An optional tf.train.Checkpoint. If provided, the export module
      will use it to read the weights.
    input_name: The input tensor name; defaults to `None`, which produces the
      input tensor name `inputs`.
    function_keys: A list of string keys to retrieve pre-defined serving
      signatures; the signature keys will be set with defaults. If a
      dictionary is provided, the values will be used as signature keys.
    add_tpu_function_alias: Whether to add a TPU function alias so that the
      model can be converted to a TPU-compatible saved model later. Defaults
      to False.
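
  Example:
    A minimal, illustrative call; `exp_config` stands for a valid
    `cfg.ExperimentConfig` (e.g. an image classification experiment), and the
    paths below are placeholders:

      export_inference_graph(
          input_type='image_tensor',
          batch_size=1,
          input_image_size=[224, 224],
          params=exp_config,
          checkpoint_path='/path/to/ckpt',
          export_dir='/path/to/export_dir')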
  """

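  # Resolve output directories: the checkpoint subdirectory is optional, and
  # the SavedModel defaults to `export_dir` itself when no subdirectory is
  # given.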
  if export_checkpoint_subdir:
    output_checkpoint_directory = os.path.join(
        export_dir, export_checkpoint_subdir)
  else:
    output_checkpoint_directory = None

  if export_saved_model_subdir:
    output_saved_model_directory = os.path.join(
        export_dir, export_saved_model_subdir)
  else:
    output_saved_model_directory = export_dir

  # TODO(arashwan): Offer a direct path to use ExportModule with Task objects.
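  # Build a task-specific export module from the experiment config when one
  # is not supplied explicitly.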
  if not export_module:
    if isinstance(params.task,
                  configs.image_classification.ImageClassificationTask):
      export_module = image_classification.ClassificationModule(
          params=params,
          batch_size=batch_size,
          input_image_size=input_image_size,
          input_type=input_type,
          num_channels=num_channels,
          input_name=input_name)
    elif isinstance(params.task, configs.retinanet.RetinaNetTask) or isinstance(
        params.task, configs.maskrcnn.MaskRCNNTask):
      export_module = detection.DetectionModule(
          params=params,
          batch_size=batch_size,
          input_image_size=input_image_size,
          input_type=input_type,
          num_channels=num_channels,
          input_name=input_name)
    elif isinstance(params.task,
                    configs.semantic_segmentation.SemanticSegmentationTask):
      export_module = semantic_segmentation.SegmentationModule(
          params=params,
          batch_size=batch_size,
          input_image_size=input_image_size,
          input_type=input_type,
          num_channels=num_channels,
          input_name=input_name)
    elif isinstance(params.task,
                    configs.video_classification.VideoClassificationTask):
      export_module = video_classification.VideoClassificationModule(
          params=params,
          batch_size=batch_size,
          input_image_size=input_image_size,
          input_type=input_type,
          num_channels=num_channels,
          input_name=input_name)
    else:
      raise ValueError('Export module not implemented for {} task.'.format(
          type(params.task)))

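  # Attach a `tpu_candidate` function alias to the serving function so that a
  # later TPU conversion step can locate it. Note that this overrides any
  # `save_options` passed in by the caller.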
  if add_tpu_function_alias:
    if input_type == 'image_tensor':
      inference_func = export_module.inference_from_image_tensors
    elif input_type == 'image_bytes':
      inference_func = export_module.inference_from_image_bytes
    elif input_type == 'tf_example':
      inference_func = export_module.inference_from_tf_example
    else:
      raise ValueError(
          'add_tpu_function_alias is only allowed for input_type of:'
          ' image_tensor, image_bytes, tf_example.'
      )
    save_options = tf.saved_model.SaveOptions(
        function_aliases={
            'tpu_candidate': inference_func,
        }
    )

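  # Write the SavedModel. When `function_keys` is not provided, a single
  # serving signature derived from `input_type` is exported.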
  export_base.export(
      export_module,
      function_keys=function_keys if function_keys else [input_type],
      export_savedmodel_dir=output_saved_model_directory,
      checkpoint=checkpoint,
      checkpoint_path=checkpoint_path,
      timestamped=False,
      save_options=save_options)

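  # Optionally save a checkpoint of the exported model's weights, and always
  # serialize the experiment config next to the exported artifacts.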
  if output_checkpoint_directory:
    ckpt = tf.train.Checkpoint(model=export_module.model)
    ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
  train_utils.serialize_config(params, export_dir)

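  # Optionally write model FLOPs and parameter counts to text files under
  # `export_dir`.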
  if log_model_flops_and_params:
    inputs_kwargs = None
    if isinstance(
        params.task,
        (configs.retinanet.RetinaNetTask, configs.maskrcnn.MaskRCNNTask)):
      # We need to create the `inputs_kwargs` argument to specify the input
      # shapes for subclass models that override `model.call` to take multiple
      # inputs, e.g., the RetinaNet model.
      inputs_kwargs = {
          'images':
              tf.TensorSpec([1] + input_image_size + [num_channels],
                            tf.float32),
          'image_shape':
              tf.TensorSpec([1, 2], tf.float32)
      }
      dummy_inputs = {
          k: tf.ones(v.shape.as_list(), tf.float32)
          for k, v in inputs_kwargs.items()
      }
      # Must do forward pass to build the model.
      export_module.model(**dummy_inputs)
    else:
      logging.info(
          'Logging model flops and params not implemented for %s task.',
          type(params.task))
      return
    train_utils.try_count_flops(export_module.model, inputs_kwargs,
                                os.path.join(export_dir, 'model_flops.txt'))
    train_utils.write_model_params(export_module.model,
                                   os.path.join(export_dir, 'model_params.txt'))