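# NOTE(review): the next three lines look like web-page residue (avatar
# caption, commit message, short commit hash) rather than source; they are
# kept as comments so the module parses. TODO confirm and remove.
# deanna-emery's picture
# updates
# 93528c6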
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains definitions of ROI aligner."""
from typing import Mapping
import tensorflow as tf, tf_keras
from official.vision.ops import spatial_transform_ops
@tf_keras.utils.register_keras_serializable(package='Vision')
class MultilevelROIAligner(tf_keras.layers.Layer):
  """Performs ROIAlign for the second stage processing.

  For every box, a fixed-size feature crop is extracted from a multilevel
  feature dictionary by delegating to
  `spatial_transform_ops.multilevel_crop_and_resize`.
  """

  def __init__(self, crop_size: int = 7, sample_offset: float = 0.5, **kwargs):
    """Initializes a ROI aligner.

    Args:
      crop_size: An `int` of the output size of the cropped features.
      sample_offset: A `float` in [0, 1] of the subpixel sample offset.
      **kwargs: Additional keyword arguments passed to Layer.
    """
    # The constructor arguments are stored in a single dict that is read by
    # both `call` and `get_config`.
    self._config_dict = {
        'crop_size': crop_size,
        'sample_offset': sample_offset,
    }
    super().__init__(**kwargs)

  def call(self,
           features: Mapping[str, tf.Tensor],
           boxes: tf.Tensor,
           training: bool = None):
    """Generates ROI features by cropping and resizing multilevel features.

    Args:
      features: A dictionary with key as pyramid level and value as features.
        The features are in shape of
        [batch_size, height_l, width_l, num_filters].
      boxes: A 3-D `tf.Tensor` of shape [batch_size, num_boxes, 4]. Each row
        represents a box with [y1, x1, y2, x2] in un-normalized coordinates.
      training: A `bool` of whether it is in training mode, or `None`.
        Currently unused by this layer.

    Returns:
      A 5-D `tf.Tensor` representing feature crop of shape
      [batch_size, num_boxes, crop_size, crop_size, num_filters].
    """
    return spatial_transform_ops.multilevel_crop_and_resize(
        features,
        boxes,
        output_size=self._config_dict['crop_size'],
        sample_offset=self._config_dict['sample_offset'])

  def get_config(self):
    """Returns the layer configuration as a new dictionary.

    Only the `crop_size` and `sample_offset` constructor arguments are
    included.

    Returns:
      A `dict` that can be passed to `from_config` to rebuild the layer.
    """
    # Hand out a copy so that callers editing the returned config cannot
    # change the values this layer reads in `call`.
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Builds a layer from a configuration dictionary.

    Args:
      config: A `dict` of keyword arguments for the constructor, such as the
        output of `get_config`.
      custom_objects: Unused by this implementation.

    Returns:
      A new instance of `cls` constructed with `config` as keyword arguments.
    """
    return cls(**config)