# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models._utils import IntermediateLayerGetter

from doctr.file_utils import CLASS_NAME
from doctr.models.classification import resnet18, resnet34, resnet50

from ...utils import _bf16_to_float32, load_pretrained_params
from .base import LinkNetPostProcessor, _LinkNet

__all__ = ["LinkNet", "linknet_resnet18", "linknet_resnet34", "linknet_resnet50"]


# Per-architecture defaults: expected input shape, normalization statistics and checkpoint URL
default_cfgs: Dict[str, Dict[str, Any]] = {
    "linknet_resnet18": {
        "input_shape": (3, 1024, 1024),
        "mean": (0.798, 0.785, 0.772),
        "std": (0.264, 0.2749, 0.287),
        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet18-e47a14dc.pt&src=0",
    },
    "linknet_resnet34": {
        "input_shape": (3, 1024, 1024),
        "mean": (0.798, 0.785, 0.772),
        "std": (0.264, 0.2749, 0.287),
        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet34-9ca2df3e.pt&src=0",
    },
    "linknet_resnet50": {
        "input_shape": (3, 1024, 1024),
        "mean": (0.798, 0.785, 0.772),
        "std": (0.264, 0.2749, 0.287),
        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/linknet_resnet50-6cf565c1.pt&src=0",
    },
}


class LinkNetFPN(nn.Module):
    """Decoder pyramid of LinkNet: one decoder block per pair of consecutive feature levels.

    Args:
    ----
        layer_shapes: (C, H, W) shape of each feature level, ordered from the shallowest to the deepest
    """

    def __init__(self, layer_shapes: List[Tuple[int, int, int]]) -> None:
        super().__init__()
        # A decoder only upsamples (stride 2) when the width differs between two consecutive levels
        strides = [
            1 if (in_shape[-1] == out_shape[-1]) else 2
            for in_shape, out_shape in zip(layer_shapes[:-1], layer_shapes[1:])
        ]

        chans = [shape[0] for shape in layer_shapes]

        # Decoders run from deep to shallow, hence the swapped (ochan, ichan) order:
        # decoder i maps the channels of level i + 1 back to the channels of level i
        _decoder_layers = [
            self.decoder_block(ochan, ichan, stride) for ichan, ochan, stride in zip(chans[:-1], chans[1:], strides)
        ]

        self.decoders = nn.ModuleList(_decoder_layers)

    @staticmethod
    def decoder_block(in_chan: int, out_chan: int, stride: int) -> nn.Sequential:
        """Creates a LinkNet decoder block

        The block is a 1x1 convolution reducing the channels to `in_chan // 4`, a 3x3 transposed convolution
        scaling the spatial size by `stride`, and a 1x1 convolution expanding to `out_chan`. Each convolution
        is followed by batch normalization and ReLU.

        Args:
        ----
            in_chan: number of input channels
            out_chan: number of output channels
            stride: spatial upsampling factor of the transposed convolution

        Returns:
        -------
            the decoder block
        """
        mid_chan = in_chan // 4
        return nn.Sequential(
            nn.Conv2d(in_chan, mid_chan, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_chan),
            nn.ReLU(inplace=True),
            # With kernel 3, padding 1 and output_padding = stride - 1, the output size is exactly input * stride
            nn.ConvTranspose2d(mid_chan, mid_chan, 3, padding=1, output_padding=stride - 1, stride=stride, bias=False),
            nn.BatchNorm2d(mid_chan),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_chan, out_chan, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_chan),
            nn.ReLU(inplace=True),
        )

    def forward(self, feats: List[torch.Tensor]) -> torch.Tensor:
        """Decode the feature maps from the deepest to the shallowest level.

        Args:
        ----
            feats: feature maps ordered from the shallowest to the deepest level

        Returns:
        -------
            the output of the first decoder
        """
        out = feats[-1]
        # Walk back up the pyramid, adding the skip connection of each shallower level.
        # `zip` stops once the skip connections are exhausted, which leaves `decoders[0]` unused here.
        for decoder, fmap in zip(self.decoders[::-1], feats[:-1][::-1]):
            out = decoder(out) + fmap

        # The first decoder has no skip connection to merge with
        out = self.decoders[0](out)

        return out


class LinkNet(nn.Module, _LinkNet):
    """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
    <https://arxiv.org/pdf/1707.03718.pdf>`_.

    Args:
    ----
        feat_extractor: the backbone serving as feature extractor
        bin_thresh: threshold for binarization of the output feature map
        box_thresh: minimal objectness score to consider a box
        head_chans: number of channels in the head layers
        assume_straight_pages: if True, fit straight bounding boxes only
        exportable: onnx exportable returns only logits
        cfg: the configuration dict of the model
        class_names: list of class names, defaults to `[CLASS_NAME]` when None
    """

    def __init__(
        self,
        feat_extractor: IntermediateLayerGetter,
        bin_thresh: float = 0.1,
        box_thresh: float = 0.1,
        head_chans: int = 32,
        assume_straight_pages: bool = True,
        exportable: bool = False,
        cfg: Optional[Dict[str, Any]] = None,
        class_names: Optional[List[str]] = None,
    ) -> None:
        super().__init__()
        # Build a fresh list per instance rather than sharing a mutable default across all models
        self.class_names = [CLASS_NAME] if class_names is None else class_names
        num_classes: int = len(self.class_names)
        self.cfg = cfg
        self.exportable = exportable
        self.assume_straight_pages = assume_straight_pages

        self.feat_extractor = feat_extractor
        # Identify the number of channels for the FPN initialization
        self.feat_extractor.eval()
        with torch.no_grad():
            in_shape = (3, 512, 512)
            out = self.feat_extractor(torch.zeros((1, *in_shape)))
            # Get the shapes of the extracted feature maps
            _shapes = [v.shape[1:] for _, v in out.items()]
            # Prepend the expected shapes of the first encoder
            _shapes = [(_shapes[0][0], in_shape[1] // 4, in_shape[2] // 4)] + _shapes
        self.feat_extractor.train()

        self.fpn = LinkNetFPN(_shapes)

        # Both transposed convolutions double the spatial size, so the head upsamples by a factor of 4
        self.classifier = nn.Sequential(
            nn.ConvTranspose2d(
                _shapes[0][0], head_chans, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False
            ),
            nn.BatchNorm2d(head_chans),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_chans, head_chans, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(head_chans),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(head_chans, num_classes, kernel_size=2, stride=2),
        )

        self.postprocessor = LinkNetPostProcessor(
            assume_straight_pages=self.assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
        )

        for n, m in self.named_modules():
            # Don't override the initialization of the backbone
            if n.startswith("feat_extractor."):
                continue
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight.data, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1.0)
                m.bias.data.zero_()

    def forward(
        self,
        x: torch.Tensor,
        target: Optional[List[np.ndarray]] = None,
        return_model_output: bool = False,
        return_preds: bool = False,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Run the model and, depending on the arguments, the post-processing and the loss computation.

        Args:
        ----
            x: input tensor fed to the feature extractor
            target: optional ground truth; when provided, the loss is computed
            return_model_output: if True, include the probability map in the output
            return_preds: if True, include the post-processed predictions even when a target is provided
            **kwargs: additional keyword arguments (unused)

        Returns:
        -------
            a dictionary holding only "logits" when the model is exportable, otherwise any of "out_map"
            (probability map), "preds" (one dictionary per sample, keyed by class name) and "loss"
        """
        feats = self.feat_extractor(x)
        # The extracted feature maps are keyed by their stringified index
        logits = self.fpn([feats[str(idx)] for idx in range(len(feats))])
        logits = self.classifier(logits)

        out: Dict[str, Any] = {}
        if self.exportable:
            out["logits"] = logits
            return out

        if return_model_output or target is None or return_preds:
            prob_map = _bf16_to_float32(torch.sigmoid(logits))
        if return_model_output:
            out["out_map"] = prob_map

        if target is None or return_preds:
            # Post-process boxes
            out["preds"] = [
                dict(zip(self.class_names, preds))
                for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy())
            ]

        if target is not None:
            loss = self.compute_loss(logits, target)
            out["loss"] = loss

        return out

    def compute_loss(
        self,
        out_map: torch.Tensor,
        target: List[np.ndarray],
        gamma: float = 2.0,
        alpha: float = 0.5,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        """Compute linknet loss: the equal sum of a focal loss and a dice loss.

        The unreduced focal loss is `alpha_t * (1 - p_t) ** gamma * BCE`, where `p_t` is the probability
        assigned to the target value and `alpha_t` the matching balancing weight.

        Args:
        ----
            out_map: output feature map of the model of shape (N, num_classes, H, W)
            target: ground truth annotations, passed unchanged to `build_target`
                (NOTE(review): annotated as a list of arrays, confirm the expected format against `build_target`)
            gamma: modulating factor in the focal loss formula
            alpha: balancing factor in the focal loss formula
            eps: epsilon factor in dice loss

        Returns:
        -------
            A loss tensor

        Raises:
        ------
            ValueError: if `gamma` is negative
        """
        # Validate the hyperparameters before doing any work
        if gamma < 0:
            raise ValueError("Value of gamma should be greater than or equal to zero.")

        _target, _mask = self.build_target(target, out_map.shape[1:], False)  # type: ignore[arg-type]

        seg_target, seg_mask = torch.from_numpy(_target).to(dtype=out_map.dtype), torch.from_numpy(_mask)
        seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
        seg_mask = seg_mask.to(dtype=torch.float32)

        bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none")
        proba_map = torch.sigmoid(out_map)

        # Focal loss
        p_t = proba_map * seg_target + (1 - proba_map) * (1 - seg_target)
        alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target)
        # Unreduced version
        focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
        # Class reduced
        focal_loss = (seg_mask * focal_loss).sum((0, 1, 2, 3)) / seg_mask.sum((0, 1, 2, 3))

        # Compute dice loss for each class
        dice_map = torch.softmax(out_map, dim=1) if len(self.class_names) > 1 else proba_map
        # Class reduced
        inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
        cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
        dice_loss = (1 - 2 * inter / (cardinality + eps)).mean()

        # Return the full loss (equal sum of focal loss and dice loss)
        return focal_loss + dice_loss


def _linknet(
    arch: str,
    pretrained: bool,
    backbone_fn: Callable[[bool], nn.Module],
    fpn_layers: List[str],
    pretrained_backbone: bool = True,
    ignore_keys: Optional[List[str]] = None,
    **kwargs: Any,
) -> LinkNet:
    """Build a LinkNet model on top of the given backbone.

    Args:
    ----
        arch: key of the architecture in `default_cfgs`
        pretrained: if True, load the checkpoint referenced by the configuration of `arch`
        backbone_fn: callable building the backbone, called with the `pretrained_backbone` flag
        fpn_layers: names of the backbone layers whose outputs are fed to the decoder
        pretrained_backbone: passed to `backbone_fn`, forced to False when `pretrained` is True
        ignore_keys: parameter keys passed to `load_pretrained_params` when the class names differ
            from the default ones
        **kwargs: keyword arguments of the LinkNet architecture

    Returns:
    -------
        the LinkNet model
    """
    # The full checkpoint is loaded below when `pretrained` is set, so the backbone flag is cleared
    pretrained_backbone = pretrained_backbone and not pretrained

    # Build the feature extractor
    backbone = backbone_fn(pretrained_backbone)
    feat_extractor = IntermediateLayerGetter(
        backbone,
        {layer_name: str(idx) for idx, layer_name in enumerate(fpn_layers)},
    )
    if not kwargs.get("class_names", None):
        kwargs["class_names"] = default_cfgs[arch].get("class_names", [CLASS_NAME])
    else:
        kwargs["class_names"] = sorted(kwargs["class_names"])

    # Build the model
    model = LinkNet(feat_extractor, cfg=default_cfgs[arch], **kwargs)
    # Load pretrained parameters
    if pretrained:
        # The number of class_names is not the same as the number of classes in the pretrained model =>
        # remove the layer weights
        _ignore_keys = (
            ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
        )
        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

    return model


def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet:
    """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
    <https://arxiv.org/pdf/1707.03718.pdf>`_.

    >>> import torch
    >>> from doctr.models import linknet_resnet18
    >>> model = linknet_resnet18(pretrained=True).eval()
    >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the LinkNet architecture

    Returns:
    -------
        text detection architecture
    """
    return _linknet(
        "linknet_resnet18",
        pretrained,
        resnet18,
        ["layer1", "layer2", "layer3", "layer4"],
        ignore_keys=[
            "classifier.6.weight",
            "classifier.6.bias",
        ],
        **kwargs,
    )


def linknet_resnet34(pretrained: bool = False, **kwargs: Any) -> LinkNet:
    """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
    <https://arxiv.org/pdf/1707.03718.pdf>`_.

    >>> import torch
    >>> from doctr.models import linknet_resnet34
    >>> model = linknet_resnet34(pretrained=True).eval()
    >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the LinkNet architecture

    Returns:
    -------
        text detection architecture
    """
    return _linknet(
        "linknet_resnet34",
        pretrained,
        resnet34,
        ["layer1", "layer2", "layer3", "layer4"],
        ignore_keys=[
            "classifier.6.weight",
            "classifier.6.bias",
        ],
        **kwargs,
    )


def linknet_resnet50(pretrained: bool = False, **kwargs: Any) -> LinkNet:
    """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation"
    <https://arxiv.org/pdf/1707.03718.pdf>`_.

    >>> import torch
    >>> from doctr.models import linknet_resnet50
    >>> model = linknet_resnet50(pretrained=True).eval()
    >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the LinkNet architecture

    Returns:
    -------
        text detection architecture
    """
    return _linknet(
        "linknet_resnet50",
        pretrained,
        resnet50,
        ["layer1", "layer2", "layer3", "layer4"],
        ignore_keys=[
            "classifier.6.weight",
            "classifier.6.bias",
        ],
        **kwargs,
    )