Plachta committed on
Commit b298540
1 Parent(s): 8e9f709

Update models/grasp_mods.py

Files changed (1)
  1. models/grasp_mods.py +64 -13
models/grasp_mods.py CHANGED
@@ -68,25 +68,65 @@ def modify_grasp_loss_forward(self):
 
         return losses
 
-    def modified_loss_boxes(outputs, targets, indices, num_boxes):
+    def modified_loss_boxes(outputs, targets, indices, num_boxes, ignore_wh=False):
 
         if "pred_boxes" not in outputs:
             raise KeyError("No predicted boxes found in outputs")
         idx = self._get_source_permutation_idx(indices)
         source_boxes = outputs["pred_boxes"][idx]
         target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
-
-        loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
+        if not ignore_wh:
+            loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
+        else:
+            source_xytheta = source_boxes[:, [0, 1, 4]]
+            target_xytheta = target_boxes[:, [0, 1, 4]]
+            loss_bbox = nn.functional.l1_loss(source_xytheta, target_xytheta, reduction="none") * 5 / 3
 
         losses = {}
         losses["loss_bbox"] = loss_bbox.sum() / num_boxes
-
-        loss_giou = 1 - torch.diag(
-            generalized_box_iou(center_to_corners_format(source_boxes[:, :4]), center_to_corners_format(target_boxes[:, :4]))
-        )
+        if not ignore_wh:
+            loss_giou = 1 - torch.diag(
+                generalized_box_iou(center_to_corners_format(source_boxes[:, :4]), center_to_corners_format(target_boxes[:, :4]))
+            )
+        else:
+            source_boxes[:, -2:] = target_boxes[:, -2:].clone()
+            source_corners = center_to_corners_format(source_boxes[:, :4])
+            target_corners = center_to_corners_format(target_boxes[:, :4])
+            loss_giou = 1 - torch.diag(generalized_box_iou(source_corners, target_corners))
         losses["loss_giou"] = loss_giou.sum() / num_boxes
         return losses
-    return modified_loss_labels, modified_loss_boxes
+    def modified_forward(outputs, targets, ignore_wh=False):
+        """
+        This performs the loss computation.
+
+        Args:
+            outputs (`dict`, *optional*):
+                Dictionary of tensors, see the output specification of the model for the format.
+            targets (`List[dict]`, *optional*):
+                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
+                losses applied, see each loss' doc.
+        """
+        outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
+
+        # Retrieve the matching between the outputs of the last layer and the targets
+        indices = self.matcher(outputs_without_aux, targets)
+
+        # Compute the average number of target boxes across all nodes, for normalization purposes
+        num_boxes = sum(len(t["class_labels"]) for t in targets)
+        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+        # (Niels): comment out function below, distributed training to be added
+        # if is_dist_avail_and_initialized():
+        #     torch.distributed.all_reduce(num_boxes)
+        # (Niels) in original implementation, num_boxes is divided by get_world_size()
+        num_boxes = torch.clamp(num_boxes, min=1).item()
+
+        # Compute all the requested losses
+        losses = {}
+        losses.update(self.loss_labels(outputs, targets, indices, num_boxes))
+        losses.update(self.loss_boxes(outputs, targets, indices, num_boxes, ignore_wh))
+
+        return losses
+    return modified_loss_labels, modified_loss_boxes, modified_forward
 
 def modify_forward(self):
     """
@@ -127,7 +167,7 @@ def modify_forward(self):
     ):
         input_images = torch.stack([x["image"] for x in batched_input], dim=0)
         image_embeddings = self.image_encoder(input_images)
-
+        batch_size = len(batched_input)
         outputs = []
         srcs = []
         for image_record, curr_embedding in zip(batched_input, image_embeddings):
@@ -162,13 +202,17 @@ def modify_forward(self):
         grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
         # repeat to batchsize
         grasp_query_pe = grasp_query_pe.repeat(len(batched_input), 1, 1)
+        pixel_masks = torch.cat([batched_input[i]['pixel_mask'] for i in range(len(batched_input))], dim=0)
+        downsampled_pixel_masks = nn.functional.interpolate(pixel_masks.unsqueeze(1).float(), size=(64, 64), mode='nearest').squeeze(1).bool()
+        downsampled_pixel_masks = downsampled_pixel_masks.view(batch_size, 64*64).contiguous()
+        grasp_encoder_hidden_states = grasp_encoder_hidden_states.view(batch_size, 64*64, 256).contiguous()
         grasp_decoder_outputs = self.grasp_decoder(
             inputs_embeds=torch.zeros_like(grasp_query_pe),
             attention_mask=None,
             position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
             query_position_embeddings=grasp_query_pe,
             encoder_hidden_states=grasp_encoder_hidden_states,
-            encoder_attention_mask=None,
+            encoder_attention_mask=downsampled_pixel_masks,
             output_attentions=False,
             output_hidden_states=False,
             return_dict=True,
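
Note on the masking added above: a minimal sketch of the shape bookkeeping, assuming the DETR-style convention that pixel_mask is 1 for real pixels and 0 for padding, one full-resolution mask per image, and the 64x64x256 encoder feature map used in the diff (the 1024x1024 input size below is an assumption). The flattened boolean mask ends up with one entry per encoder token, so it can be passed as encoder_attention_mask.

import torch
import torch.nn as nn

batch_size, H, W = 2, 1024, 1024                     # input resolution is an assumption
pixel_masks = torch.ones(batch_size, H, W, dtype=torch.bool)
pixel_masks[1, :, 512:] = False                      # e.g. right half of image 2 is padding

# Nearest-neighbour downsampling to the 64x64 feature grid, then flatten so there is
# one mask entry per encoder token.
down = nn.functional.interpolate(pixel_masks.unsqueeze(1).float(), size=(64, 64), mode="nearest")
down = down.squeeze(1).bool().view(batch_size, 64 * 64).contiguous()

features = torch.randn(batch_size, 64, 64, 256)                  # stand-in for grasp_encoder_hidden_states
features = features.view(batch_size, 64 * 64, 256).contiguous()  # one 256-d token per spatial position

print(down.shape, features.shape)  # torch.Size([2, 4096]) torch.Size([2, 4096, 256])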
@@ -198,14 +242,14 @@ def modify_forward(self):
             eos_coef=config.eos_coefficient,
             losses=losses,
         )
-        criterion.loss_labels, criterion.loss_boxes = modify_grasp_loss_forward(criterion)
+        criterion.loss_labels, criterion.loss_boxes, criterion.forward = modify_grasp_loss_forward(criterion)
         criterion.to(self.device)
         # Third: compute the losses, based on outputs and labels
         outputs_loss = {}
         outputs_loss["logits"] = grasp_logits
         outputs_loss["pred_boxes"] = pred_grasps
 
-        grasp_loss_dict = criterion(outputs_loss, grasp_labels)
+        grasp_loss_dict = criterion(outputs_loss, grasp_labels, ignore_wh=batched_input[0].get("ignore_wh", False))
         # Fourth: compute total loss, as a weighted sum of the various losses
         weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
         weight_dict["loss_giou"] = config.giou_loss_coefficient
@@ -282,6 +326,8 @@ def add_inference_method(self):
 
         n_queries = iou_predictions.size(0)
 
+        batch_size = n_queries
+
         # forward grasp decoder here
         # 1. Get encoder hidden states
         grasp_encoder_hidden_states = self.grasp_img_pos_embed(src.permute(0, 2, 3, 1))
@@ -289,13 +335,18 @@ def add_inference_method(self):
         grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
         # repeat to batchsize
         grasp_query_pe = grasp_query_pe.repeat(n_queries, 1, 1)
+        pixel_masks = torch.cat([batched_input[i]['pixel_mask'] for i in range(len(batched_input))], dim=0)
+        downsampled_pixel_masks = nn.functional.interpolate(pixel_masks.unsqueeze(1).float(), size=(64, 64),
+                                                            mode='nearest').squeeze(1).bool()
+        downsampled_pixel_masks = downsampled_pixel_masks.view(batch_size, 64 * 64).contiguous()
+        grasp_encoder_hidden_states = grasp_encoder_hidden_states.view(batch_size, 64 * 64, 256).contiguous()
         grasp_decoder_outputs = self.grasp_decoder(
             inputs_embeds=torch.zeros_like(grasp_query_pe),
             attention_mask=None,
             position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
             query_position_embeddings=grasp_query_pe,
             encoder_hidden_states=grasp_encoder_hidden_states,
-            encoder_attention_mask=None,
+            encoder_attention_mask=downsampled_pixel_masks,
             output_attentions=False,
             output_hidden_states=False,
             return_dict=True,
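
Note on the context lines around the decoder call: the 20 grasp queries come from a learned embedding table that is looked up once and repeated per batch element (per mask proposal at inference, since batch_size = n_queries there). A small sketch of that lookup, assuming a hidden size of 256 to match the 256-channel encoder states; the module name mirrors the one in the diff but the sizes are assumptions.

import torch
import torch.nn as nn

grasp_query_position_embeddings = nn.Embedding(20, 256)   # 20 learned grasp queries

n_queries = 3                                             # e.g. number of mask proposals at inference
grasp_query_pe = grasp_query_position_embeddings(torch.arange(20))   # (20, 256)
grasp_query_pe = grasp_query_pe.repeat(n_queries, 1, 1)              # (3, 20, 256)
print(grasp_query_pe.shape)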
 