File size: 2,402 Bytes
a059c46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
import torch.nn as nn
from torchvision.ops.boxes import batched_nms

from util import box_ops


class CondNMSPostProcess(nn.Module):
    """Turn per-patch query predictions into NMS-filtered detections.

    For each image, the flat list of query predictions is split into
    consecutive groups of ``num_queries`` (one group per patch). Within each
    group the highest-scoring queries are kept, de-duplicated with NMS and
    truncated. Each surviving box is tagged with its patch's name and the
    patch's ``(start, end)`` key from ``mask_infos``.
    """

    def __init__(self, num_queries, pre_nms_topk=100, nms_iou_threshold=0.7,
                 post_nms_topk=20):
        """
        Args:
            num_queries: number of queries predicted per patch.
            pre_nms_topk: maximum number of queries per patch kept (by score)
                before NMS; clamped to ``num_queries`` at run time.
            nms_iou_threshold: IoU threshold passed to NMS.
            post_nms_topk: maximum number of detections per patch kept after NMS.
        """
        super().__init__()
        self.num_queries = num_queries
        self.pre_nms_topk = pre_nms_topk
        self.nms_iou_threshold = nms_iou_threshold
        self.post_nms_topk = post_nms_topk

    @staticmethod
    def _empty_result(score_dtype, box_dtype):
        """Return a result dict with no detections (CPU tensors)."""
        return {'scores': torch.zeros(0, dtype=score_dtype),
                'boxes': torch.zeros((0, 4), dtype=box_dtype),
                'names': [], 'start_id': [], 'end_id': []}

    @torch.no_grad()
    def forward(self, outputs, target_sizes, pred_names, mask_infos):
        """Post-process a batch of predictions.

        Args:
            outputs: dict with 'pred_logits' and 'pred_boxes'. For image ``b``,
                ``outputs[...][b][0]`` is read. Assumes logits are
                (num_patch * num_queries, C), with the last column used as the
                score logit, and boxes are (num_patch * num_queries, 4) in
                cxcywh, presumably normalized to [0, 1] — TODO confirm.
            target_sizes: per-image ``(h, w)`` as 0-d tensors; used to scale
                boxes.
            pred_names: per-image sequence with one name per patch.
            mask_infos: per-image dict whose keys are ``(start, end)`` pairs,
                one per patch, in patch order.

        Returns:
            A list with one dict per image, in batch order, with keys
            'scores' (M,), 'boxes' (M, 4) xyxy scaled by the image size,
            'names', 'start_id' and 'end_id' (lists of length M). Images with
            no predictions yield an empty result rather than being skipped,
            so ``results[b]`` always corresponds to image ``b``.
        """
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
        results = []

        for b in range(len(out_logits)):
            # Score of each query: sigmoid of the last logit column.
            prob = out_logits[b][0][:, -1].sigmoid()
            if prob.numel() == 0:
                results.append(self._empty_result(prob.dtype, out_bbox[b].dtype))
                continue

            boxes = box_ops.box_cxcywh_to_xyxy(out_bbox[b][0])
            img_h, img_w = target_sizes[b]
            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=0)
            boxes = boxes * scale_fct[None, :]

            num_patch = len(prob) // self.num_queries
            prob = prob.view(num_patch, self.num_queries)
            boxes = boxes.view(num_patch, self.num_queries, -1)

            spans = list(mask_infos[b].keys())
            # topk raises if k exceeds the number of candidates.
            k = min(self.pre_nms_topk, self.num_queries)

            b_scores, b_boxes = [], []
            b_names, b_start_id, b_end_id = [], [], []
            for t in range(num_patch):
                top_scores, top_ind = prob[t].topk(k)
                top_boxes = boxes[t][top_ind]
                # Every box shares one category, so batched_nms acts as plain NMS.
                labels = torch.zeros_like(top_ind)
                keep = batched_nms(top_boxes, top_scores, labels,
                                   self.nms_iou_threshold)[:self.post_nms_topk]
                b_scores.append(top_scores[keep])
                b_boxes.append(top_boxes[keep])

                # All detections of patch t share that patch's name and span.
                n_keep = len(keep)
                start, end = spans[t]
                b_names += [pred_names[b][t]] * n_keep
                b_start_id += [start] * n_keep
                b_end_id += [end] * n_keep

            results.append({'scores': torch.cat(b_scores).cpu(),
                            'boxes': torch.cat(b_boxes).cpu(),
                            'names': b_names,
                            'start_id': b_start_id,
                            'end_id': b_end_id})
        return results