# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image segmentation task definition.""" | |
from typing import Any, List, Mapping, Optional, Tuple, Union | |
from absl import logging | |
import tensorflow as tf, tf_keras | |
from official.common import dataset_fn | |
from official.core import base_task | |
from official.core import task_factory | |
from official.vision.configs import semantic_segmentation as exp_cfg | |
from official.vision.dataloaders import input_reader | |
from official.vision.dataloaders import input_reader_factory | |
from official.vision.dataloaders import segmentation_input | |
from official.vision.dataloaders import tfds_factory | |
from official.vision.evaluation import segmentation_metrics | |
from official.vision.losses import segmentation_losses | |
from official.vision.modeling import factory | |
from official.vision.utils.object_detection import visualization_utils | |

@task_factory.register_task_cls(exp_cfg.SemanticSegmentationTask)
class SemanticSegmentationTask(base_task.Task):
  """A task for semantic segmentation."""

  def build_model(self):
    """Builds the segmentation model."""
    input_specs = tf_keras.layers.InputSpec(shape=[None] +
                                            self.task_config.model.input_size)

    l2_weight_decay = self.task_config.losses.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (
        tf_keras.regularizers.l2(l2_weight_decay /
                                 2.0) if l2_weight_decay else None)

    model = factory.build_segmentation_model(
        input_specs=input_specs,
        model_config=self.task_config.model,
        l2_regularizer=l2_regularizer)

    # Builds the model by running a dummy forward pass, so that all variables
    # are created before a checkpoint can be restored.
    dummy_inputs = tf_keras.Input(self.task_config.model.input_size)
    _ = model(dummy_inputs, training=False)
    return model
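
  # Example (hedged sketch, not part of the original module): building the
  # model from a registered experiment config. The experiment name below is an
  # assumption; any registered semantic segmentation experiment works.
  #
  #   from official.core import exp_factory
  #   config = exp_factory.get_exp_config('seg_deeplabv3_pascal')
  #   task = SemanticSegmentationTask(config.task)
  #   model = task.build_model()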

  def initialize(self, model: tf_keras.Model):
    """Loads pretrained checkpoint."""
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if 'all' in self.task_config.init_checkpoint_modules:
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
      ckpt_items = {}
      if 'backbone' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(backbone=model.backbone)
      if 'decoder' in self.task_config.init_checkpoint_modules:
        ckpt_items.update(decoder=model.decoder)

      ckpt = tf.train.Checkpoint(**ckpt_items)
      status = ckpt.read(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)
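
  # Example (hedged sketch): restoring only the backbone from a pretrained
  # checkpoint; the path is a placeholder, not a real artifact.
  #
  #   task.task_config.init_checkpoint = '/path/to/pretrained_ckpt'
  #   task.task_config.init_checkpoint_modules = ['backbone']
  #   task.initialize(model)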

  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds segmentation input."""
    ignore_label = self.task_config.losses.ignore_label
    gt_is_matting_map = self.task_config.losses.gt_is_matting_map

    if params.tfds_name:
      decoder = tfds_factory.get_segmentation_decoder(params.tfds_name)
    else:
      decoder = segmentation_input.Decoder(
          image_feature=params.image_feature,
          additional_dense_features=params.additional_dense_features)

    parser = segmentation_input.Parser(
        output_size=params.output_size,
        crop_size=params.crop_size,
        ignore_label=ignore_label,
        resize_eval_groundtruth=params.resize_eval_groundtruth,
        gt_is_matting_map=gt_is_matting_map,
        groundtruth_padded_size=params.groundtruth_padded_size,
        aug_scale_min=params.aug_scale_min,
        aug_scale_max=params.aug_scale_max,
        aug_rand_hflip=params.aug_rand_hflip,
        preserve_aspect_ratio=params.preserve_aspect_ratio,
        dtype=params.dtype,
        image_feature=params.image_feature,
        additional_dense_features=params.additional_dense_features)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        combine_fn=input_reader.create_combine_fn(params),
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)
    return dataset
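
  # Example (hedged sketch): materializing a training dataset and inspecting
  # one batch; assumes `config` from the `build_model` example above.
  #
  #   dataset = task.build_inputs(config.task.train_data)
  #   features, labels = next(iter(dataset))
  #   # `labels['masks']` holds the ground-truth segmentation masks.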

  def build_losses(self,
                   labels: Mapping[str, tf.Tensor],
                   model_outputs: Union[Mapping[str, tf.Tensor], tf.Tensor],
                   aux_losses: Optional[Any] = None):
    """Computes the segmentation loss.

    Args:
      labels: labels.
      model_outputs: Output logits of the segmentation model.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    loss_params = self._task_config.losses
    segmentation_loss_fn = segmentation_losses.SegmentationLoss(
        loss_params.label_smoothing,
        loss_params.class_weights,
        loss_params.ignore_label,
        use_groundtruth_dimension=loss_params.use_groundtruth_dimension,
        use_binary_cross_entropy=loss_params.use_binary_cross_entropy,
        top_k_percent_pixels=loss_params.top_k_percent_pixels,
        gt_is_matting_map=loss_params.gt_is_matting_map)

    total_loss = segmentation_loss_fn(model_outputs['logits'], labels['masks'])

    if 'mask_scores' in model_outputs:
      mask_scoring_loss_fn = segmentation_losses.MaskScoringLoss(
          loss_params.ignore_label)
      total_loss += loss_params.mask_scoring_weight * mask_scoring_loss_fn(
          model_outputs['mask_scores'],
          model_outputs['logits'],
          labels['masks'])

    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    total_loss = loss_params.loss_weight * total_loss
    return total_loss
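
  # Example (hedged sketch): computing the total loss from a forward pass,
  # using `features` and `labels` from the `build_inputs` example above.
  #
  #   outputs = model(features, training=True)
  #   if isinstance(outputs, tf.Tensor):
  #     outputs = {'logits': outputs}
  #   loss = task.build_losses(
  #       labels=labels, model_outputs=outputs, aux_losses=model.losses)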

  def process_metrics(self, metrics, labels, model_outputs, **kwargs):
    """Processes and updates metrics.

    Called when using custom training loop API.

    Args:
      metrics: a nested structure of metrics objects. The return of function
        self.build_metrics.
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
      **kwargs: other args.
    """
    for metric in metrics:
      if metric.name == 'mask_scores_mse':
        actual_mask_scores = segmentation_losses.get_actual_mask_scores(
            model_outputs['logits'], labels['masks'],
            self.task_config.losses.ignore_label)
        metric.update_state(actual_mask_scores, model_outputs['mask_scores'])
      else:
        metric.update_state(labels, model_outputs['logits'])
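
  # Example (hedged sketch): updating metrics outside the step functions;
  # mirrors the call made inside `train_step` below.
  #
  #   train_metrics = task.build_metrics(training=True)
  #   task.process_metrics(train_metrics, labels, outputs)
  #   results = {m.name: m.result() for m in train_metrics}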

  def build_metrics(self, training: bool = True):
    """Gets streaming metrics for training/validation."""
    metrics = []
    self.iou_metric = None

    if training and self.task_config.evaluation.report_train_mean_iou:
      metrics.append(
          segmentation_metrics.MeanIoU(
              name='mean_iou',
              num_classes=self.task_config.model.num_classes,
              rescale_predictions=False,
              dtype=tf.float32))
      if self.task_config.model.get('mask_scoring_head'):
        metrics.append(
            tf_keras.metrics.MeanSquaredError(name='mask_scores_mse'))

    if not training:
      self.iou_metric = segmentation_metrics.PerClassIoU(
          name='per_class_iou',
          num_classes=self.task_config.model.num_classes,
          rescale_predictions=(
              not self.task_config.validation_data.resize_eval_groundtruth),
          dtype=tf.float32)
      if (self.task_config.validation_data.resize_eval_groundtruth and
          self.task_config.model.get('mask_scoring_head')):
        # The mask scores metric can only be computed if the labels are scaled
        # to match the predicted mask scores.
        metrics.append(
            tf_keras.metrics.MeanSquaredError(name='mask_scores_mse'))

    return metrics
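
  # Example (hedged sketch): in evaluation mode the per-class IoU metric is
  # kept on the task itself rather than returned, and is updated inside
  # `validation_step`.
  #
  #   eval_metrics = task.build_metrics(training=False)
  #   # `eval_metrics` is empty unless a mask scoring head is configured;
  #   # per-class IoU lives on `task.iou_metric`.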

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf_keras.Model,
                 optimizer: tf_keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Does forward and backward.

    Args:
      inputs: a pair of (features, labels) tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    input_partition_dims = self.task_config.train_input_partition_dims
    if input_partition_dims:
      strategy = tf.distribute.get_strategy()
      features = strategy.experimental_split_to_logical_devices(
          features, input_partition_dims)

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      if isinstance(outputs, tf.Tensor):
        outputs = {'logits': outputs}
      # Casting the outputs to float32 is necessary when the mixed-precision
      # policy is mixed_float16 or mixed_bfloat16, so that the loss is
      # computed in full precision.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      loss = self.build_losses(
          model_outputs=outputs, labels=labels, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(optimizer, tf_keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: loss}

    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})

    return logs
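
  # Example (hedged sketch): driving one step under a distribution strategy;
  # normally an Orbit trainer does this. `strategy`, `dist_dataset`,
  # `optimizer`, and `metrics` are assumed to be set up by the caller.
  #
  #   @tf.function
  #   def step_fn(iterator):
  #     inputs = next(iterator)
  #     return strategy.run(
  #         task.train_step, args=(inputs, model, optimizer, metrics))
  #
  #   logs = step_fn(iter(dist_dataset))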

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf_keras.Model,
                      metrics: Optional[List[Any]] = None):
    """Runs the validation step.

    Args:
      inputs: a pair of (features, labels) tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    input_partition_dims = self.task_config.eval_input_partition_dims
    if input_partition_dims:
      strategy = tf.distribute.get_strategy()
      features = strategy.experimental_split_to_logical_devices(
          features, input_partition_dims)

    outputs = self.inference_step(features, model)
    if isinstance(outputs, tf.Tensor):
      outputs = {'logits': outputs}
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

    if self.task_config.validation_data.resize_eval_groundtruth:
      loss = self.build_losses(
          model_outputs=outputs, labels=labels, aux_losses=model.losses)
    else:
      loss = 0

    logs = {self.loss: loss}

    if self.iou_metric is not None:
      self.iou_metric.update_state(labels, outputs['logits'])
    if metrics:
      self.process_metrics(metrics, labels, outputs)

    if (
        hasattr(self.task_config, 'allow_image_summary')
        and self.task_config.allow_image_summary
    ):
      logs.update(
          {'visualization': (tf.cast(features, dtype=tf.float32), outputs)}
      )

    return logs

  def inference_step(self, inputs: tf.Tensor, model: tf_keras.Model):
    """Performs the forward step."""
    return model(inputs, training=False)
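
  # Example (hedged sketch): plain inference on a batch of images. Outputs may
  # be a raw tensor or a dict with a 'logits' key, mirroring the handling in
  # `validation_step`.
  #
  #   outputs = task.inference_step(images, model)
  #   if isinstance(outputs, tf.Tensor):
  #     outputs = {'logits': outputs}
  #   predicted_masks = tf.argmax(outputs['logits'], axis=-1)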

  def aggregate_logs(self, state=None, step_outputs=None):
    if state is None and self.iou_metric is not None:
      self.iou_metric.reset_states()

    if 'visualization' in step_outputs:
      # Updates the segmentation state for writing summaries if there are
      # artifacts for visualization.
      if state is None:
        state = {}
      state.update(visualization_utils.update_segmentation_state(step_outputs))

    if state is None:
      # Create an arbitrary state to indicate it's not the first step in the
      # following calls to this function.
      state = True

    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    logs = {}
    if self.iou_metric is not None:
      ious = self.iou_metric.result()
      # TODO(arashwan): support loading class name from a label map file.
      if self.task_config.evaluation.report_per_class_iou:
        for i, value in enumerate(ious.numpy()):
          logs.update({'iou/{}'.format(i): value})
      # Computes mean IoU.
      logs.update({'mean_iou': tf.reduce_mean(ious)})

    # Adds visualization for summary.
    if isinstance(aggregated_logs, dict) and 'image' in aggregated_logs:
      validation_outputs = visualization_utils.visualize_segmentation_outputs(
          logs=aggregated_logs, task_config=self.task_config
      )
      logs.update(validation_outputs)

    return logs
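
# Example (hedged sketch): how an evaluation loop typically threads these
# hooks together; the Orbit evaluator that normally drives this task follows
# the same pattern.
#
#   state = None
#   for inputs in eval_dataset:
#     logs = task.validation_step(inputs, model, eval_metrics)
#     state = task.aggregate_logs(state, logs)
#   results = task.reduce_aggregated_logs(state)
#   # `results` contains 'mean_iou' and, optionally, per-class 'iou/<i>' keys.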