Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Contains definitions of ROI aligner.""" | |
from typing import Mapping | |
import tensorflow as tf, tf_keras | |
from official.vision.ops import spatial_transform_ops | |
class MultilevelROIAligner(tf_keras.layers.Layer): | |
"""Performs ROIAlign for the second stage processing.""" | |
def __init__(self, crop_size: int = 7, sample_offset: float = 0.5, **kwargs): | |
"""Initializes a ROI aligner. | |
Args: | |
crop_size: An `int` of the output size of the cropped features. | |
sample_offset: A `float` in [0, 1] of the subpixel sample offset. | |
**kwargs: Additional keyword arguments passed to Layer. | |
""" | |
self._config_dict = { | |
'crop_size': crop_size, | |
'sample_offset': sample_offset, | |
} | |
super(MultilevelROIAligner, self).__init__(**kwargs) | |
def call(self, | |
features: Mapping[str, tf.Tensor], | |
boxes: tf.Tensor, | |
training: bool = None): | |
"""Generates ROIs. | |
Args: | |
features: A dictionary with key as pyramid level and value as features. | |
The features are in shape of | |
[batch_size, height_l, width_l, num_filters]. | |
boxes: A 3-D `tf.Tensor` of shape [batch_size, num_boxes, 4]. Each row | |
represents a box with [y1, x1, y2, x2] in un-normalized coordinates. | |
from grid point. | |
training: A `bool` of whether it is in training mode. | |
Returns: | |
A 5-D `tf.Tensor` representing feature crop of shape | |
[batch_size, num_boxes, crop_size, crop_size, num_filters]. | |
""" | |
roi_features = spatial_transform_ops.multilevel_crop_and_resize( | |
features, | |
boxes, | |
output_size=self._config_dict['crop_size'], | |
sample_offset=self._config_dict['sample_offset']) | |
return roi_features | |
def get_config(self): | |
return self._config_dict | |
def from_config(cls, config, custom_objects=None): | |
return cls(**config) | |