# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
from typing import Callable, Dict, Optional, Tuple, Union
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F

from detectron2.config import configurable
from detectron2.layers import Conv2d, ShapeSpec, get_norm
from detectron2.structures import ImageList
from detectron2.utils.registry import Registry

from ..backbone import Backbone, build_backbone
from ..postprocessing import sem_seg_postprocess
from .build import META_ARCH_REGISTRY

__all__ = [
    "SemanticSegmentor",
    "SEM_SEG_HEADS_REGISTRY",
    "SemSegFPNHead",
    "build_sem_seg_head",
]


SEM_SEG_HEADS_REGISTRY = Registry("SEM_SEG_HEADS")
SEM_SEG_HEADS_REGISTRY.__doc__ = """
Registry for semantic segmentation heads, which make semantic segmentation predictions
from feature maps.
"""


@META_ARCH_REGISTRY.register()
class SemanticSegmentor(nn.Module):
    """
    Main class for semantic segmentation architectures.

    It normalizes and batches the input images, runs them through ``backbone``,
    and hands the resulting features (plus optional ground truth) to
    ``sem_seg_head``.
    """

    @configurable
    def __init__(
        self,
        *,
        backbone: Backbone,
        sem_seg_head: nn.Module,
        pixel_mean: Tuple[float, ...],
        pixel_std: Tuple[float, ...],
    ):
        """
        Args:
            backbone: a backbone module, must follow detectron2's backbone interface
            sem_seg_head: a module that predicts semantic segmentation from backbone features
            pixel_mean, pixel_std: list or tuple with #channels element, representing
                the per-channel mean and std to be used to normalize the input image
        """
        super().__init__()
        self.backbone = backbone
        self.sem_seg_head = sem_seg_head
        # Stored as (C, 1, 1) so they broadcast over (C, H, W) images. The third
        # positional argument is ``persistent=False``: the buffers follow the
        # module across devices but are not saved in the state dict.
        self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False)
        self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)

    @classmethod
    def from_config(cls, cfg):
        """
        Build the keyword arguments of ``__init__`` from a config object.

        The head is built from the backbone's output shape, so the backbone
        has to be constructed first.
        """
        backbone = build_backbone(cfg)
        sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
        return {
            "backbone": backbone,
            "sem_seg_head": sem_seg_head,
            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
            "pixel_std": cfg.MODEL.PIXEL_STD,
        }

    @property
    def device(self):
        """The device this model lives on, read from the ``pixel_mean`` buffer."""
        return self.pixel_mean.device

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
                Each item in the list contains the inputs for one image.

                For now, each item in the list is a dict that contains:

                   * "image": Tensor, image in (C, H, W) format.
                   * "sem_seg": semantic segmentation ground truth
                   * Other information that's included in the original dicts, such as:
                     "height", "width" (int): the output resolution of the model (may be
                     different from input resolution), used in inference.


        Returns:
            In training mode, the losses returned by the head.

            Otherwise, list[dict]:
              Each dict is the output for one input image.
              The dict contains one key "sem_seg" whose value is a
              Tensor that represents the
              per-pixel segmentation predicted by the head.
              The prediction has shape KxHxW that represents the logits of
              each class for each pixel.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
        images = ImageList.from_tensors(images, self.backbone.size_divisibility)

        features = self.backbone(images.tensor)

        # Ground truth is optional; its presence is decided from the first item only.
        if "sem_seg" in batched_inputs[0]:
            targets = [x["sem_seg"].to(self.device) for x in batched_inputs]
            # Pad the targets like the images, filling the padded area with the
            # head's ignore value.
            targets = ImageList.from_tensors(
                targets, self.backbone.size_divisibility, self.sem_seg_head.ignore_value
            ).tensor
        else:
            targets = None
        results, losses = self.sem_seg_head(features, targets)

        if self.training:
            return losses

        processed_results = []
        for result, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes):
            # Fall back to the (unpadded) input size when no output size is requested.
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            r = sem_seg_postprocess(result, image_size, height, width)
            processed_results.append({"sem_seg": r})
        return processed_results


def build_sem_seg_head(cfg, input_shape):
    """
    Build a semantic segmentation head from `cfg.MODEL.SEM_SEG_HEAD.NAME`.

    The name is looked up in ``SEM_SEG_HEADS_REGISTRY`` and the registered
    object is called with ``(cfg, input_shape)``.
    """
    name = cfg.MODEL.SEM_SEG_HEAD.NAME
    return SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)


@SEM_SEG_HEADS_REGISTRY.register()
class SemSegFPNHead(nn.Module):
    """
    A semantic segmentation head described in :paper:`PanopticFPN`.
    It takes a list of FPN features as input, and applies a sequence of
    3x3 convs and upsampling to scale all of them to the stride defined by
    ``common_stride``. Then these features are added and used to make final
    predictions by another 1x1 conv layer.
    """

    @configurable
    def __init__(
        self,
        input_shape: Dict[str, ShapeSpec],
        *,
        num_classes: int,
        conv_dims: int,
        common_stride: int,
        loss_weight: float = 1.0,
        norm: Optional[Union[str, Callable]] = None,
        ignore_value: int = -1,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            input_shape: shapes (channels and stride) of the input features
            num_classes: number of classes to predict
            conv_dims: number of output channels for the intermediate conv layers.
            common_stride: the common stride that all features will be upscaled to
            loss_weight: loss weight
            norm (str or callable): normalization for all conv layers
            ignore_value: category id to be ignored during training.

        Raises:
            ValueError: if ``input_shape`` is empty.
        """
        super().__init__()
        # Process the features from the finest (smallest stride) to the coarsest.
        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
        if not input_shape:
            raise ValueError("SemSegFPNHead(input_shape=) cannot be empty!")
        self.in_features = [k for k, v in input_shape]
        feature_strides = [v.stride for k, v in input_shape]
        feature_channels = [v.channels for k, v in input_shape]

        self.ignore_value = ignore_value
        self.common_stride = common_stride
        self.loss_weight = loss_weight

        # A plain list, not nn.ModuleList: each head is registered below with
        # add_module() under the name of its input feature.
        self.scale_heads = []
        for in_feature, stride, channels in zip(
            self.in_features, feature_strides, feature_channels
        ):
            head_ops = []
            # One conv stage per factor of two between this feature's stride and
            # the common stride, and at least one stage.
            head_length = max(1, int(np.log2(stride) - np.log2(self.common_stride)))
            for k in range(head_length):
                norm_module = get_norm(norm, conv_dims)
                conv = Conv2d(
                    channels if k == 0 else conv_dims,
                    conv_dims,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    bias=not norm,  # no conv bias when a norm layer is configured
                    norm=norm_module,
                    activation=F.relu,
                )
                weight_init.c2_msra_fill(conv)
                head_ops.append(conv)
                # Features already at the common stride are not upsampled.
                if stride != self.common_stride:
                    head_ops.append(
                        nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
                    )
            self.scale_heads.append(nn.Sequential(*head_ops))
            self.add_module(in_feature, self.scale_heads[-1])
        self.predictor = Conv2d(conv_dims, num_classes, kernel_size=1, stride=1, padding=0)
        weight_init.c2_msra_fill(self.predictor)

    @classmethod
    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
        """
        Build the keyword arguments of ``__init__`` from ``cfg.MODEL.SEM_SEG_HEAD``.

        Only the features listed in ``cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES`` are
        kept from ``input_shape``.
        """
        return {
            "input_shape": {
                k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
            },
            "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
            "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
            "conv_dims": cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM,
            "common_stride": cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE,
            "norm": cfg.MODEL.SEM_SEG_HEAD.NORM,
            "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
        }

    def forward(self, features, targets=None):
        """
        Args:
            features: mapping from feature name to feature tensor; it must
                contain every name in ``self.in_features``.
            targets: ground-truth segmentation, required in training mode.

        Returns:
            In training, returns (None, dict of losses)
            In inference, returns (CxHxW logits, {})

        Raises:
            ValueError: if the module is in training mode and ``targets`` is None.
        """
        if self.training and targets is None:
            # Fail early with a clear message instead of an opaque error from
            # F.cross_entropy after the whole head has already been computed.
            raise ValueError(
                "SemSegFPNHead requires `targets` in training mode, but got None."
            )
        x = self.layers(features)
        if self.training:
            return None, self.losses(x, targets)
        else:
            # Predictions are made at `common_stride`; upsample them by that factor.
            x = F.interpolate(
                x, scale_factor=self.common_stride, mode="bilinear", align_corners=False
            )
            return x, {}

    def layers(self, features):
        """
        Run every input feature through its scale head, sum the results and
        apply the final predictor conv.

        Returns:
            The predictor output, before the final upsampling by ``common_stride``.
        """
        for i, f in enumerate(self.in_features):
            if i == 0:
                x = self.scale_heads[i](features[f])
            else:
                x = x + self.scale_heads[i](features[f])
        x = self.predictor(x)
        return x

    def losses(self, predictions, targets):
        """
        Compute the weighted cross-entropy loss of ``predictions`` against ``targets``.

        The predictions are upsampled by ``common_stride`` first; pixels whose
        target equals ``ignore_value`` do not contribute to the loss.

        Returns:
            dict with the single key "loss_sem_seg".
        """
        predictions = predictions.float()  # https://github.com/pytorch/pytorch/issues/48163
        predictions = F.interpolate(
            predictions,
            scale_factor=self.common_stride,
            mode="bilinear",
            align_corners=False,
        )
        loss = F.cross_entropy(
            predictions, targets, reduction="mean", ignore_index=self.ignore_value
        )
        losses = {"loss_sem_seg": loss * self.loss_weight}
        return losses