File size: 4,023 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Semantic segmentation input and model functions for serving/inference."""

import tensorflow as tf, tf_keras

from official.vision.modeling import factory
from official.vision.ops import preprocess_ops
from official.vision.serving import export_base


class SegmentationModule(export_base.ExportModule):
  """Segmentation Module."""

  def _build_model(self):
    input_specs = tf_keras.layers.InputSpec(
        shape=[self._batch_size] + self._input_image_size + [3])

    return factory.build_segmentation_model(
        input_specs=input_specs,
        model_config=self.params.task.model,
        l2_regularizer=None)

  def _build_inputs(self, image):
    """Builds classification model inputs for serving."""

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)

    if self.params.task.train_data.preserve_aspect_ratio:
      image, image_info = preprocess_ops.resize_and_crop_image(
          image,
          self._input_image_size,
          padded_size=self._input_image_size,
          aug_scale_min=1.0,
          aug_scale_max=1.0)
    else:
      image, image_info = preprocess_ops.resize_image(image,
                                                      self._input_image_size)
    return image, image_info

  def serve(self, images):
    """Cast image to float and run inference.

    Args:
      images: uint8 Tensor of shape [batch_size, None, None, 3]
    Returns:
      Tensor holding classification output logits.
    """
    # Skip image preprocessing when input_type is tflite so it is compatible
    # with TFLite quantization.
    image_info = None
    if self._input_type != 'tflite':
      with tf.device('cpu:0'):
        images = tf.cast(images, dtype=tf.float32)
        images_spec = tf.TensorSpec(
            shape=self._input_image_size + [3], dtype=tf.float32)
        image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)

        images, image_info = tf.nest.map_structure(
            tf.identity,
            tf.map_fn(
                self._build_inputs,
                elems=images,
                fn_output_signature=(images_spec, image_info_spec),
                parallel_iterations=32))

    outputs = self.inference_step(images)

    # Optionally resize prediction to the input image size.
    if self.params.task.export_config.rescale_output:
      logits = outputs['logits']
      if logits.shape[0] != 1:
        raise ValueError('Batch size cannot be more than 1.')

      image_shape = tf.cast(image_info[0, 0, :], tf.int32)
      if self.params.task.train_data.preserve_aspect_ratio:
        rescale_size = tf.cast(
            tf.math.ceil(image_info[0, 1, :] / image_info[0, 2, :]), tf.int32)
        offsets = tf.cast(image_info[0, 3, :], tf.int32)
        logits = tf.image.resize(logits, rescale_size, method='bilinear')
        outputs['logits'] = tf.image.crop_to_bounding_box(
            logits, offsets[0], offsets[1], image_shape[0], image_shape[1])
      else:
        outputs['logits'] = tf.image.resize(
            logits, [image_shape[0], image_shape[1]], method='bilinear')
    else:
      outputs['logits'] = tf.image.resize(
          outputs['logits'], self._input_image_size, method='bilinear')

    if image_info is not None:
      outputs.update({'image_info': image_info})

    return outputs