File size: 14,050 Bytes
9bf4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Dict, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from mmdet.models.utils import multi_apply
from torch import nn

from mmocr.registry import MODELS
from mmocr.structures import TextDetDataSample
from .seg_based_module_loss import SegBasedModuleLoss


@MODELS.register_module()
class PANModuleLoss(SegBasedModuleLoss):
    """Module loss of PANet. Partially adapted from
    https://github.com/whai362/pan_pp.pytorch and
    https://github.com/WenmuZhou/PAN.pytorch.

    PANet: `Efficient and Accurate Arbitrary-
    Shaped Text Detection with Pixel Aggregation Network
    <https://arxiv.org/abs/1908.05900>`_.

    The total loss consists of three weighted terms: a text-region loss
    (with OHEM sampling), a kernel loss and an embedding loss.

    Args:
        loss_text (dict): The loss config for text map. Defaults to
            dict(type='MaskedSquareDiceLoss').
        loss_kernel (dict): The loss config for kernel map. Defaults to
            dict(type='MaskedSquareDiceLoss').
        loss_embedding (dict): The loss config for embedding map. Defaults
            to dict(type='PANEmbLossV1').
        weight_text (float): The weight of text loss. Defaults to 1.
        weight_kernel (float): The weight of kernel loss. Defaults to 0.5.
        weight_embedding (float): The weight of embedding loss.
            Defaults to 0.25.
        ohem_ratio (float): The negative/positive ratio in ohem. Defaults
            to 3.
        shrink_ratio (tuple[float]): The ratios used to shrink the text
            polygons, one target map is generated per ratio. Defaults
            to (1.0, 0.5).
        max_shrink_dist (int or float): The maximum shrinking distance.
            Defaults to 20.
        reduction (str): The way to reduce the loss. Available options are
            "mean" and "sum". Defaults to 'mean'.
    """

    def __init__(
            self,
            loss_text: Dict = dict(type='MaskedSquareDiceLoss'),
            loss_kernel: Dict = dict(type='MaskedSquareDiceLoss'),
            loss_embedding: Dict = dict(type='PANEmbLossV1'),
            weight_text: float = 1.0,
            weight_kernel: float = 0.5,
            weight_embedding: float = 0.25,
            ohem_ratio: Union[int, float] = 3,  # TODO Find a better name
            shrink_ratio: Sequence[Union[int, float]] = (1.0, 0.5),
            max_shrink_dist: Union[int, float] = 20,
            reduction: str = 'mean') -> None:
        super().__init__()
        assert reduction in ('mean', 'sum'), "reduction must in ['mean','sum']"

        # Target generation / sampling hyper-parameters.
        self.shrink_ratio = shrink_ratio
        self.max_shrink_dist = max_shrink_dist
        self.ohem_ratio = ohem_ratio
        self.reduction = reduction

        # Weights of the three loss terms.
        self.weight_text = weight_text
        self.weight_kernel = weight_kernel
        self.weight_embedding = weight_embedding

        # Sub-losses, built from their configs through the registry.
        self.loss_text = MODELS.build(loss_text)
        self.loss_kernel = MODELS.build(loss_kernel)
        self.loss_embedding = MODELS.build(loss_embedding)

    def forward(self, preds: torch.Tensor,
                data_samples: Sequence[TextDetDataSample]) -> Dict:
        """Compute PAN loss.

        Args:
            preds (Tensor): Raw predictions from model with
                shape :math:`(N, C, H, W)`. Channel 0 is used as the text
                logits, channel 1 as the kernel logits and the remaining
                channels as the instance embedding.
            data_samples (list[TextDetDataSample]): The data samples.

        Returns:
            dict: The dict for pan losses with loss_text, loss_kernel
            and loss_embedding, each already multiplied by its weight.
        """
        gt_kernels, gt_masks = self.get_targets(data_samples)

        # Resize the predictions to the spatial size of the targets.
        preds = F.interpolate(
            preds, size=gt_kernels.size()[2:], mode='bilinear')
        device = preds.device
        gt_kernels = gt_kernels.to(device)
        gt_masks = gt_masks.to(device)

        text_logits = preds[:, 0]
        kernel_logits = preds[:, 1]
        embeddings = preds[:, 2:]

        # The embedding loss must see the targets *before* they are
        # binarized in place below.
        loss_emb = self.loss_embedding(embeddings, gt_kernels[0],
                                       gt_kernels[1], gt_masks)

        # Binarize the targets in place: <= 0.5 -> 0, > 0.5 -> 1.
        gt_kernels.masked_fill_(gt_kernels <= 0.5, 0)
        gt_kernels.masked_fill_(gt_kernels > 0.5, 1)
        text_target = gt_kernels[0]
        kernel_target = gt_kernels[1]

        # Text loss on the OHEM-sampled pixels.
        ohem_mask = self._ohem_batch(text_logits.detach(), text_target,
                                     gt_masks)
        loss_texts = self.loss_text(
            torch.sigmoid(text_logits), text_target, ohem_mask)

        # Kernel loss, restricted to effective pixels inside text regions.
        kernel_mask = (text_target > 0.5).float() * gt_masks
        loss_kernels = self.loss_kernel(
            torch.sigmoid(kernel_logits), kernel_target, kernel_mask)

        reducer = 'mean' if self.reduction == 'mean' else 'sum'
        loss_texts, loss_kernels, loss_emb = (
            getattr(term, reducer)()
            for term in (loss_texts, loss_kernels, loss_emb))

        return dict(
            loss_text=self.weight_text * loss_texts,
            loss_kernel=self.weight_kernel * loss_kernels,
            loss_embedding=self.weight_embedding * loss_emb)

    def get_targets(
        self,
        data_samples: Sequence[TextDetDataSample],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate the batched gt targets for PANet.

        Args:
            data_samples (list[TextDetDataSample]): The data samples.

        Returns:
            tuple(Tensor, Tensor): The stacked kernel targets, with the
            kernel axis moved in front of the batch axis, and the stacked
            effective masks.
        """
        per_img_kernels, per_img_masks = multi_apply(
            self._get_target_single, data_samples)
        # (N, kernel_number, H, W) -> (kernel_number, N, H, W)
        gt_kernels = torch.stack(per_img_kernels, dim=0).transpose(0, 1)
        gt_masks = torch.stack(per_img_masks, dim=0)
        return gt_kernels, gt_masks

    def _get_target_single(self, data_sample: TextDetDataSample
                           ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Generate loss target from a data sample.

        Args:
            data_sample (TextDetDataSample): The data sample.

        Returns:
            tuple(Tensor, Tensor): The float kernel targets stacked over
            ``self.shrink_ratio`` and the float effective mask.
        """
        img_shape = data_sample.img_shape
        polygons = data_sample.gt_instances.polygons
        ignore_flags = data_sample.gt_instances.ignored

        # One target map per shrink ratio.
        # TODO pass `gt_ignored` to `_generate_kernels`
        kernel_maps = [
            self._generate_kernels(
                img_shape,
                polygons,
                ratio,
                ignore_flags=None,
                max_shrink_dist=self.max_shrink_dist)[0]
            for ratio in self.shrink_ratio
        ]

        # The effective mask is generated from the ignored polygons only.
        ignored_polygons = data_sample.gt_instances[ignore_flags].polygons
        effective_mask = self._generate_effective_mask(img_shape,
                                                       ignored_polygons)

        gt_kernels = torch.from_numpy(np.stack(kernel_maps, axis=0)).float()
        gt_mask = torch.from_numpy(effective_mask).float()
        return gt_kernels, gt_mask

    def _ohem_batch(self, text_scores: torch.Tensor, gt_texts: torch.Tensor,
                    gt_mask: torch.Tensor) -> torch.Tensor:
        """OHEM sampling for a batch of imgs.

        Args:
            text_scores (Tensor): The text scores of size
                :math:`(N, H, W)`.
            gt_texts (Tensor): The gt text masks of size :math:`(N, H, W)`.
            gt_mask (Tensor): The gt effective mask of size
                :math:`(N, H, W)`.

        Returns:
            Tensor: The sampled mask of size :math:`(N, H, W)`.
        """
        for tensor in (text_scores, gt_texts, gt_mask):
            assert isinstance(tensor, torch.Tensor)
        assert text_scores.dim() == 3
        assert text_scores.shape == gt_texts.shape
        assert gt_texts.shape == gt_mask.shape

        # Sample every image independently, then re-assemble the batch.
        per_img_masks = [
            self._ohem_single(score, text, mask)
            for score, text, mask in zip(text_scores, gt_texts, gt_mask)
        ]
        return torch.stack(per_img_masks)

    def _ohem_single(self, text_score: torch.Tensor, gt_text: torch.Tensor,
                     gt_mask: torch.Tensor) -> torch.Tensor:
        """Sample the top-k maximal negative samples and all positive samples.

        Args:
            text_score (Tensor): The text score of size :math:`(H, W)`.
            gt_text (Tensor): The ground truth text mask of size
                :math:`(H, W)`.
            gt_mask (Tensor): The effective region mask of size :math:`(H, W)`.

        Returns:
            Tensor: The sampled boolean pixel mask of size :math:`(H, W)`.
            When there is no positive or no negative pixel to sample, the
            effective mask itself is returned (as bool).
        """
        for tensor in (text_score, gt_text, gt_mask):
            assert isinstance(tensor, torch.Tensor)
        assert text_score.dim() == 2
        assert text_score.shape == gt_text.shape
        assert gt_text.shape == gt_mask.shape

        is_text = gt_text > 0.5
        is_background = gt_text <= 0.5

        # Positives are text pixels that are not masked out.
        masked_text_num = int(torch.sum(is_text * (gt_mask <= 0.5)).item())
        pos_num = int(torch.sum(is_text).item()) - masked_text_num
        # Negatives are capped at `ohem_ratio` times the positives.
        neg_num = int(torch.sum(is_background).item())
        neg_num = int(min(pos_num * self.ohem_ratio, neg_num))

        if pos_num == 0 or neg_num == 0:
            warnings.warn('pos_num = 0 or neg_num = 0')
            return gt_mask.bool()

        # The score of the `neg_num`-th hardest negative is the threshold.
        hardest_first, _ = torch.sort(
            text_score[is_background], descending=True)
        threshold = hardest_first[neg_num - 1]

        keep = (text_score >= threshold) | is_text
        return keep & (gt_mask > 0.5)


@MODELS.register_module()
class PANEmbLossV1(nn.Module):
    """The class for implementing EmbLossV1. This was partially adapted from
    https://github.com/whai362/pan_pp.pytorch.

    For every image the loss is the sum of an aggregation term (pulls the
    embeddings of an instance towards the mean embedding of its kernel), a
    discrimination term (pushes the mean embeddings of different instances
    apart) and a small regularization term on the mean embeddings.

    Args:
        feature_dim (int): The dimension of the feature. Defaults to 4.
        delta_aggregation (float): The delta for aggregation. Defaults to 0.5.
        delta_discrimination (float): The delta for discrimination.
            Defaults to 1.5.
    """

    def __init__(self,
                 feature_dim: int = 4,
                 delta_aggregation: float = 0.5,
                 delta_discrimination: float = 1.5) -> None:
        super().__init__()
        self.feature_dim = feature_dim
        self.delta_aggregation = delta_aggregation
        self.delta_discrimination = delta_discrimination
        # Weights of the (aggregation, discrimination) terms.
        self.weights = (1.0, 1.0)

    def _forward_single(self, emb: torch.Tensor, instance: torch.Tensor,
                        kernel: torch.Tensor,
                        training_mask: torch.Tensor) -> torch.Tensor:
        """Compute the loss for a single image.

        Args:
            emb (torch.Tensor): The embedding feature. It is flattened to
                ``(feature_dim, num_pixels)``.
            instance (torch.Tensor): The instance label map, where label 0
                is treated as background.
            kernel (torch.Tensor): The kernel map; values above 0.5 mark
                kernel pixels.
            training_mask (torch.Tensor): The effective mask; values above
                0.5 mark effective pixels.

        Returns:
            torch.Tensor: The scalar loss of this image. It is zero when at
            most one distinct label is found inside the kernels.
        """
        training_mask = (training_mask > 0.5).float()
        kernel = (kernel > 0.5).float()
        instance = instance * training_mask
        # `reshape` instead of `view` so that non-contiguous inputs work too.
        instance_kernel = (instance * kernel).reshape(-1)
        instance = instance.reshape(-1)
        emb = emb.reshape(self.feature_dim, -1)

        unique_labels = torch.unique(instance_kernel, sorted=True)
        num_instance = unique_labels.size(0)
        if num_instance <= 1:
            return emb.new_zeros((), dtype=torch.float32)

        # Foreground instances are selected by value, not by position, so
        # an image without any background pixel is handled correctly.
        is_fg = unique_labels != 0
        fg_indices = torch.nonzero(is_fg).flatten().tolist()

        # Mean embedding of each instance kernel. The background column (if
        # any) is left at zero.
        emb_mean = emb.new_zeros((self.feature_dim, num_instance),
                                 dtype=torch.float32)
        for i in fg_indices:
            ind_k = instance_kernel == unique_labels[i]
            emb_mean[:, i] = torch.mean(emb[:, ind_k], dim=1)

        # Aggregation: distance of every instance pixel to its kernel mean.
        l_agg = emb.new_zeros(num_instance, dtype=torch.float32)
        for i in fg_indices:
            ind = instance == unique_labels[i]
            emb_ = emb[:, ind]
            dist = (emb_ - emb_mean[:, i:i + 1]).norm(p=2, dim=0)
            dist = F.relu(dist - self.delta_aggregation)**2
            l_agg[i] = torch.mean(torch.log(dist + 1.0))
        l_agg = torch.mean(l_agg[is_fg])

        # Discrimination: pairwise distance between foreground means.
        num_fg = len(fg_indices)
        if num_fg >= 2:
            fg_mean = emb_mean[:, is_fg].permute(1, 0)
            # diff[a, b] = mean_b - mean_a
            diff = fg_mean.unsqueeze(0) - fg_mean.unsqueeze(1)
            # Drop the diagonal (self pairs) before taking the norm. The
            # mask lives on the same device as the embeddings.
            off_diag = torch.eye(
                num_fg, dtype=torch.int8, device=emb_mean.device) == 0
            dist = diff[off_diag].norm(p=2, dim=1)
            dist = F.relu(2 * self.delta_discrimination - dist)**2
            l_dis = torch.mean(torch.log(dist + 1.0))
        else:
            l_dis = emb_mean.new_zeros(())

        l_agg = self.weights[0] * l_agg
        l_dis = self.weights[1] * l_dis
        # Small regularization keeping the mean embeddings close to zero.
        l_reg = torch.mean(torch.log(torch.norm(emb_mean, 2, 0) + 1.0)) * 0.001
        loss = l_agg + l_dis + l_reg
        return loss

    def forward(self, emb: torch.Tensor, instance: torch.Tensor,
                kernel: torch.Tensor,
                training_mask: torch.Tensor) -> torch.Tensor:
        """Compute the loss for a batch image.

        All inputs are indexed along their first dimension, one entry per
        image.

        Args:
            emb (torch.Tensor): The embedding feature.
            instance (torch.Tensor): The instance feature.
            kernel (torch.Tensor): The kernel feature.
            training_mask (torch.Tensor): The effective mask.

        Returns:
            torch.Tensor: A float32 tensor holding one loss value per image.
        """
        loss_batch = emb.new_zeros((emb.size(0)), dtype=torch.float32)

        for i in range(loss_batch.size(0)):
            loss_batch[i] = self._forward_single(emb[i], instance[i],
                                                 kernel[i], training_mask[i])

        return loss_batch