"""
Add additional grasp decoder for Segment Anything model.
The structure should follow the grasp decoder structure in GraspDETR.
"""
import torch
import torch.nn as nn
from typing import Any, Dict, List, Tuple
from scipy.optimize import linear_sum_assignment
from transformers.image_transforms import center_to_corners_format
from transformers.models.detr.configuration_detr import DetrConfig
from transformers.models.detr.modeling_detr import (
    DetrDecoder,
    DetrHungarianMatcher,
    DetrLoss,
    DetrSegmentationOutput,
    dice_loss,
    generalized_box_iou,
    sigmoid_focal_loss,
)

def modify_matcher_forward(self):
    @torch.no_grad()
    def matcher_forward(outputs, targets):
        """
        Hungarian matching for 5-value grasp boxes (bbox & rotation): the L1 cost uses all
        5 values, while the GIoU cost only uses the first 4 (the axis-aligned box).
        """
        batch_size, num_queries = outputs["logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        target_ids = torch.cat([v["class_labels"] for v in targets])
        target_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it with 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, so it can be omitted.
        class_cost = -out_prob[:, target_ids]

        # Compute the L1 cost between boxes
        bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)

        # Compute the GIoU cost between boxes (only the first 4 values, i.e. the axis-aligned
        # box; the rotation value contributes through the L1 cost above)
        giou_cost = -generalized_box_iou(
            center_to_corners_format(out_bbox[:, :4]), center_to_corners_format(target_bbox[:, :4])
        )

        # Final cost matrix
        cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
        cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()

        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
    return matcher_forward
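
# A shape sketch of what the patched matcher consumes (hypothetical tensors, matching the
# 20-query / 2-logit / 5-value grasp head set up in modify_forward below; for illustration only):
#
#   outputs = {
#       "logits":     FloatTensor[batch_size, 20, 2],   # grasp vs. no-object scores
#       "pred_boxes": FloatTensor[batch_size, 20, 5],   # (cx, cy, w, h, rotation), each in [0, 1]
#   }
#   targets = [{"class_labels": LongTensor[n_i], "boxes": FloatTensor[n_i, 5]}, ...]  # one dict per image
#
# matcher_forward returns one (pred_indices, target_indices) pair of int64 tensors per image.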

def modify_grasp_loss_forward(self):
    def modified_loss_labels(outputs, targets, indices, num_boxes):
        """
        Classification loss (NLL). Target dicts must contain the key "class_labels", a tensor
        of dim [nb_target_boxes].
        """
        num_classes = 1  # model v9 always uses class-agnostic grasps
        if "logits" not in outputs:
            raise KeyError("No logits were found in the outputs")
        source_logits = outputs["logits"]

        idx = self._get_source_permutation_idx(indices)
        target_classes_o = torch.cat([t["class_labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(
            source_logits.shape[:2], num_classes, dtype=torch.int64, device=source_logits.device
        )
        target_classes[idx] = target_classes_o

        loss_ce = nn.functional.cross_entropy(source_logits.transpose(1, 2), target_classes)
        losses = {"loss_ce": loss_ce}

        return losses

    def modified_loss_boxes(outputs, targets, indices, num_boxes):
        """
        L1 regression loss over the full 5-value grasp boxes, plus a GIoU loss computed on
        their first 4 (axis-aligned box) values. Target dicts must contain the key "boxes".
        """
        if "pred_boxes" not in outputs:
            raise KeyError("No predicted boxes found in outputs")
        idx = self._get_source_permutation_idx(indices)
        source_boxes = outputs["pred_boxes"][idx]
        target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")

        losses = {}
        losses["loss_bbox"] = loss_bbox.sum() / num_boxes

        loss_giou = 1 - torch.diag(
            generalized_box_iou(center_to_corners_format(source_boxes[:, :4]), center_to_corners_format(target_boxes[:, :4]))
        )
        losses["loss_giou"] = loss_giou.sum() / num_boxes
        return losses
    return modified_loss_labels, modified_loss_boxes

def modify_forward(self):
    """
    Modify SAM so that it performs grasp detection in addition to segmentation:
        1. Add a parallel decoder for grasp detection: 1 (+1 no-object) class, 5 values to
           regress per query (bbox & rotation)
    Returns:
        The modified forward method, to be bound to the SAM instance
    """
    # 1. We instantiate a new module in self.base_model, as another decoder
    self.grasp_decoder_config = DetrConfig()
    self.grasp_decoder = DetrDecoder(self.grasp_decoder_config).to(self.device)
    # 20 learnable grasp queries, hidden size 256
    self.grasp_query_position_embeddings = nn.Embedding(20, 256).to(self.device)
    # 2. Base model forward method is not directly used, no modification needs to be done
    # self.detr.model.forward = modify_base_model_forward(self.detr.model)
    # 3. Add additional classification head & bbox regression head for grasp_decoder output
    self.grasp_predictor = torch.nn.Sequential(
        torch.nn.Linear(256, 256),
        torch.nn.Linear(256, 256),
        torch.nn.Linear(256, 5)  # 5 outputs: bbox (4) + rotation (1)
    ).to(self.device)
    self.grasp_label_classifier = torch.nn.Linear(256, 2).to(self.device)  # grasp vs. no-object
    # 4. Add positional embedding
    # name it as grasp_img_pos_embed to avoid name conflict
    class ImagePosEmbed(nn.Module):
        def __init__(self, img_size=64, hidden_dim=256):
            super().__init__()
            self.pos_embed = nn.Parameter(
                torch.randn(1, img_size, img_size, hidden_dim)
            )
        def forward(self, x):
            return x + self.pos_embed

    self.grasp_img_pos_embed = ImagePosEmbed().to(self.device)

    def modified_forward(
            batched_input: List[Dict[str, Any]],
            multimask_output: bool,
    ):
        """
        SAM segmentation followed by the parallel grasp decoder; grasp and mask losses are
        computed when "grasp_labels" / "masks" are present in `batched_input`.
        """
        input_images = torch.stack([x["image"] for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)

        outputs = []
        srcs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions, src = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            outputs.append(
                {
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
            srcs.append(src[0])
        srcs = torch.stack(srcs, dim=0)
        # forward grasp decoder here
        # 1. Get encoder hidden states
        grasp_encoder_hidden_states = self.grasp_img_pos_embed(srcs.permute(0, 2, 3, 1))
        # 2. Get query embeddings
        grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
        # repeat to batch size
        grasp_query_pe = grasp_query_pe.repeat(len(batched_input), 1, 1)
        grasp_decoder_outputs = self.grasp_decoder(
            inputs_embeds=torch.zeros_like(grasp_query_pe),
            attention_mask=None,
            position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
            query_position_embeddings=grasp_query_pe,
            encoder_hidden_states=grasp_encoder_hidden_states,
            encoder_attention_mask=None,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        grasp_sequence_output = grasp_decoder_outputs[0]
        grasp_logits = self.grasp_label_classifier(grasp_sequence_output)
        pred_grasps = self.grasp_predictor(grasp_sequence_output).sigmoid()

        # 3. Calculate loss
        loss, loss_dict = 0, {}
        if "grasp_labels" in batched_input[0]:
            config = self.grasp_decoder_config
            grasp_labels = [{
                "class_labels": torch.zeros([len(x["grasp_labels"])], dtype=torch.long).to(self.device),
                "boxes": x["grasp_labels"],
            } for x in batched_input]
            # First: create the matcher
            matcher = DetrHungarianMatcher(
                class_cost=config.class_cost, bbox_cost=config.bbox_cost, giou_cost=config.giou_cost
            )
            matcher.forward = modify_matcher_forward(matcher)
            # Second: create the criterion
            losses = ["labels", "boxes"]
            criterion = DetrLoss(
                matcher=matcher,
                num_classes=config.num_labels,
                eos_coef=config.eos_coefficient,
                losses=losses,
            )
            criterion.loss_labels, criterion.loss_boxes = modify_grasp_loss_forward(criterion)
            criterion.to(self.device)
            # Third: compute the losses, based on outputs and labels
            outputs_loss = {}
            outputs_loss["logits"] = grasp_logits
            outputs_loss["pred_boxes"] = pred_grasps

            grasp_loss_dict = criterion(outputs_loss, grasp_labels)
            # Fourth: compute total loss, as a weighted sum of the various losses
            weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
            weight_dict["loss_giou"] = config.giou_loss_coefficient
            if config.auxiliary_loss:
                aux_weight_dict = {}
                for i in range(config.decoder_layers - 1):
                    aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
                weight_dict.update(aux_weight_dict)
            grasp_loss = sum(grasp_loss_dict[k] * weight_dict[k] for k in grasp_loss_dict.keys() if k in weight_dict)

            # merge grasp branch loss into variable loss & loss_dict
            loss += grasp_loss
            loss_dict.update(grasp_loss_dict)
        pred_masks = self.postprocess_masks(
            torch.cat([x['low_res_logits'] for x in outputs], dim=0),
            input_size=image_record["image"].shape[-2:],
            original_size=(1024, 1024),
        )
        if 'masks' in batched_input[0]:
            # 4. Calculate segmentation loss (sigmoid focal + dice, as in DETR's segmentation head)
            mask_targets = torch.stack([x['masks'] for x in batched_input], dim=0).unsqueeze(1).type(torch.float32)
            sf_loss = sigmoid_focal_loss(pred_masks.flatten(1), mask_targets.flatten(1), len(batched_input))
            d_loss = dice_loss(pred_masks.flatten(1), mask_targets.flatten(1), len(batched_input))
            loss += sf_loss + d_loss
            loss_dict["sf_loss"] = sf_loss
            loss_dict["d_loss"] = d_loss
        return DetrSegmentationOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=grasp_logits,
            pred_boxes=pred_grasps,
            pred_masks=pred_masks,
        )

    return modified_forward
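
# Example of attaching the modified forward to a SAM instance (a sketch; `sam_model_registry`
# and the checkpoint path are assumptions about the surrounding project, and the SAM mask
# decoder is assumed to additionally return its image features `src`, as relied on above):
#
#   from segment_anything import sam_model_registry
#   sam = sam_model_registry["vit_h"](checkpoint="path/to/sam_vit_h.pth")
#   sam.forward = modify_forward(sam)  # also registers the grasp decoder, heads and pos. embedding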

def add_inference_method(self):
    def infer(
            batched_input: List[Dict[str, Any]],
            multimask_output: bool,
    ):
        """
        Inference counterpart of the modified forward: run SAM and the grasp decoder on the
        first element of `batched_input` only, without computing any losses.
        """
        input_images = torch.stack([x["image"] for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)

        outputs = []
        srcs = []
        curr_embedding = image_embeddings[0]
        image_record = batched_input[0]

        if "point_coords" in image_record:
            points = (image_record["point_coords"], image_record["point_labels"])
        else:
            points = None
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=image_record.get("boxes", None),
            masks=image_record.get("mask_inputs", None),
        )
        low_res_masks, iou_predictions, src = self.mask_decoder(
            image_embeddings=curr_embedding.unsqueeze(0),
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )
        outputs.append(
            {
                "iou_predictions": iou_predictions,
                "low_res_logits": low_res_masks,
            }
        )
        srcs.append(src[0])

        n_queries = iou_predictions.size(0)  # used as the batch size for the grasp decoder

        # forward grasp decoder here
        # 1. Get encoder hidden states (note: uses `src` directly rather than the stacked `srcs`)
        grasp_encoder_hidden_states = self.grasp_img_pos_embed(src.permute(0, 2, 3, 1))
        # 2. Get query embeddings
        grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
        # repeat to the grasp-decoder batch size
        grasp_query_pe = grasp_query_pe.repeat(n_queries, 1, 1)
        grasp_decoder_outputs = self.grasp_decoder(
            inputs_embeds=torch.zeros_like(grasp_query_pe),
            attention_mask=None,
            position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
            query_position_embeddings=grasp_query_pe,
            encoder_hidden_states=grasp_encoder_hidden_states,
            encoder_attention_mask=None,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        grasp_sequence_output = grasp_decoder_outputs[0]
        grasp_logits = self.grasp_label_classifier(grasp_sequence_output)
        pred_grasps = self.grasp_predictor(grasp_sequence_output).sigmoid()
        pred_masks = self.postprocess_masks(
            torch.cat([x['low_res_logits'] for x in outputs], dim=0),
            input_size=image_record["image"].shape[-2:],
            original_size=(1024, 1024),
        )
        return DetrSegmentationOutput(
            loss=0,
            loss_dict={},
            logits=grasp_logits,
            pred_boxes=pred_grasps,
            pred_masks=pred_masks,
        )
    return infer
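
if __name__ == "__main__":
    # Minimal smoke test of the patched matcher and losses on random tensors (a sketch; the
    # cost/loss weights below are DETR-style illustrative defaults, not values from this project).
    matcher = DetrHungarianMatcher(class_cost=1, bbox_cost=5, giou_cost=2)
    matcher.forward = modify_matcher_forward(matcher)
    criterion = DetrLoss(matcher=matcher, num_classes=1, eos_coef=0.1, losses=["labels", "boxes"])
    criterion.loss_labels, criterion.loss_boxes = modify_grasp_loss_forward(criterion)

    outputs = {
        "logits": torch.randn(2, 20, 2),     # [batch, num_queries, grasp / no-object]
        "pred_boxes": torch.rand(2, 20, 5),  # [batch, num_queries, (cx, cy, w, h, rotation)]
    }
    targets = [
        {"class_labels": torch.zeros(3, dtype=torch.long), "boxes": torch.rand(3, 5)},
        {"class_labels": torch.zeros(1, dtype=torch.long), "boxes": torch.rand(1, 5)},
    ]
    print(matcher.forward(outputs, targets))  # one (pred_idx, target_idx) pair per image
    print(criterion(outputs, targets))        # {'loss_ce': ..., 'loss_bbox': ..., 'loss_giou': ...}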