Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build segmentation models.""" | |
from typing import Any, Mapping, Union, Optional, Dict | |
# Import libraries | |
import tensorflow as tf, tf_keras | |
# Module-level shorthand for `tf_keras.layers`.
# NOTE(review): not referenced by the code in this chunk (the class below uses
# `tf_keras.layers` directly); confirm whether other code relies on it.
layers = tf_keras.layers
class SegmentationModel(tf_keras.Model):
    """A Segmentation class model.

    Input images are passed through the backbone first. The decoder network is
    then applied, and finally the segmentation head is applied on the output of
    the decoder network. Layers such as ASPP should be part of the decoder. Any
    feature fusion is done as part of the segmentation head (i.e. DeepLabv3+
    feature fusion is not part of the decoder; instead it is part of the
    segmentation head). This way, different feature fusion techniques can be
    combined with different backbones and decoders.
    """

    def __init__(self, backbone: tf_keras.Model,
                 decoder: Optional[tf_keras.Model],
                 head: tf_keras.layers.Layer,
                 mask_scoring_head: Optional[tf_keras.layers.Layer] = None,
                 **kwargs):
        """Segmentation initialization function.

        Args:
          backbone: a backbone network.
          decoder: a decoder network, e.g. FPN. May be None, in which case the
            backbone features are passed to the head in place of decoder
            features.
          head: segmentation head.
          mask_scoring_head: optional mask scoring head, applied to the logits.
          **kwargs: keyword arguments forwarded to `tf_keras.Model.__init__`.
        """
        super().__init__(**kwargs)
        # Constructor arguments kept so `get_config`/`from_config` can rebuild
        # the model. These are the layer objects themselves, not serialized
        # configs, and `**kwargs` (e.g. `name`) is not included.
        self._config_dict = {
            'backbone': backbone,
            'decoder': decoder,
            'head': head,
            'mask_scoring_head': mask_scoring_head,
        }
        self.backbone = backbone
        self.decoder = decoder
        self.head = head
        self.mask_scoring_head = mask_scoring_head

    def call(self, inputs: tf.Tensor, training: Optional[bool] = None  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
             ) -> Dict[str, tf.Tensor]:
        """Runs backbone -> (optional) decoder -> head -> (optional) scoring.

        Args:
          inputs: input tensor passed to the backbone.
          training: Keras training flag. It is not forwarded explicitly to the
            sub-layers. NOTE(review): this relies on Keras propagating the flag
            through its call context; confirm that this is sufficient for the
            sub-layers in use.

        Returns:
          A dict containing 'logits' (the head output) and, when a mask scoring
          head is configured, 'mask_scores'.
        """
        backbone_features = self.backbone(inputs)

        # With no decoder, the head receives the backbone features in both
        # slots of its input tuple.
        if self.decoder is not None:
            decoder_features = self.decoder(backbone_features)
        else:
            decoder_features = backbone_features

        logits = self.head((backbone_features, decoder_features))
        outputs = {'logits': logits}
        if self.mask_scoring_head is not None:
            outputs['mask_scores'] = self.mask_scoring_head(logits)
        return outputs

    def checkpoint_items(
            self) -> Mapping[str, Union[tf_keras.Model, tf_keras.layers.Layer]]:
        """Returns a dictionary of items to be additionally checkpointed.

        Always contains 'backbone' and 'head'. It also contains 'decoder' and
        'mask_scoring_head' when those components are not None.
        """
        items = {'backbone': self.backbone, 'head': self.head}
        if self.decoder is not None:
            items['decoder'] = self.decoder
        if self.mask_scoring_head is not None:
            items['mask_scoring_head'] = self.mask_scoring_head
        return items

    def get_config(self) -> Mapping[str, Any]:
        """Returns the constructor arguments used to build this model.

        A shallow copy is returned so that callers mutating the result (as
        Keras config handling may do) cannot alter this model's stored config.
        """
        return dict(self._config_dict)

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Creates a model from a config produced by `get_config`.

        Args:
          config: mapping of constructor keyword arguments.
          custom_objects: unused here. Kept to match the
            `from_config(config, custom_objects=None)` signature Keras calls.

        Returns:
          A new instance of `cls` built from `config`.
        """
        return cls(**config)