Update models/grasp_mods.py
models/grasp_mods.py  CHANGED  (+64 -13)
@@ -68,25 +68,65 @@ def modify_grasp_loss_forward(self):
 
         return losses
 
-    def modified_loss_boxes(outputs, targets, indices, num_boxes):
+    def modified_loss_boxes(outputs, targets, indices, num_boxes, ignore_wh=False):
 
         if "pred_boxes" not in outputs:
             raise KeyError("No predicted boxes found in outputs")
         idx = self._get_source_permutation_idx(indices)
         source_boxes = outputs["pred_boxes"][idx]
         target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0)
-
-
+        if not ignore_wh:
+            loss_bbox = nn.functional.l1_loss(source_boxes, target_boxes, reduction="none")
+        else:
+            source_xytheta = source_boxes[:, [0, 1, 4]]
+            target_xytheta = target_boxes[:, [0, 1, 4]]
+            loss_bbox = nn.functional.l1_loss(source_xytheta, target_xytheta, reduction="none") * 5 / 3
 
         losses = {}
         losses["loss_bbox"] = loss_bbox.sum() / num_boxes
-
-
-
-
+        if not ignore_wh:
+            loss_giou = 1 - torch.diag(
+                generalized_box_iou(center_to_corners_format(source_boxes[:, :4]), center_to_corners_format(target_boxes[:, :4]))
+            )
+        else:
+            source_boxes[:, -2:] = target_boxes[:, -2:].clone()
+            source_corners = center_to_corners_format(source_boxes[:, :4])
+            target_corners = center_to_corners_format(target_boxes[:, :4])
+            loss_giou = 1 - torch.diag(generalized_box_iou(source_corners, target_corners))
         losses["loss_giou"] = loss_giou.sum() / num_boxes
         return losses
-    return modified_loss_labels, modified_loss_boxes
+    def modified_forward(outputs, targets, ignore_wh=False):
+        """
+        This performs the loss computation.
+
+        Args:
+            outputs (`dict`, *optional*):
+                Dictionary of tensors, see the output specification of the model for the format.
+            targets (`List[dict]`, *optional*):
+                List of dicts, such that `len(targets) == batch_size`. The expected keys in each dict depends on the
+                losses applied, see each loss' doc.
+        """
+        outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs"}
+
+        # Retrieve the matching between the outputs of the last layer and the targets
+        indices = self.matcher(outputs_without_aux, targets)
+
+        # Compute the average number of target boxes across all nodes, for normalization purposes
+        num_boxes = sum(len(t["class_labels"]) for t in targets)
+        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+        # (Niels): comment out function below, distributed training to be added
+        # if is_dist_avail_and_initialized():
+        #     torch.distributed.all_reduce(num_boxes)
+        # (Niels) in original implementation, num_boxes is divided by get_world_size()
+        num_boxes = torch.clamp(num_boxes, min=1).item()
+
+        # Compute all the requested losses
+        losses = {}
+        losses.update(self.loss_labels(outputs, targets, indices, num_boxes))
+        losses.update(self.loss_boxes(outputs, targets, indices, num_boxes, ignore_wh))
+
+        return losses
+    return modified_loss_labels, modified_loss_boxes, modified_forward
 
 def modify_forward(self):
     """
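A minimal standalone sketch of what the new ignore_wh branch computes. The 5-dim grasp box layout is not stated in this diff; (cx, cy, w, h, theta) is assumed here for illustration only, and the toy numbers are made up.

# Hypothetical toy example, not from the repo: one predicted grasp vs. its matched target.
import torch
import torch.nn as nn

source_boxes = torch.tensor([[0.50, 0.50, 0.20, 0.10, 0.30]])  # assumed (cx, cy, w, h, theta)
target_boxes = torch.tensor([[0.52, 0.48, 0.25, 0.12, 0.35]])

# Keep only columns 0, 1, 4 (center and angle); width/height are ignored.
source_xytheta = source_boxes[:, [0, 1, 4]]
target_xytheta = target_boxes[:, [0, 1, 4]]

# L1 over 3 of the 5 coordinates, rescaled by 5/3 so its magnitude stays
# comparable to the full 5-coordinate L1 used when ignore_wh is False.
loss_bbox = nn.functional.l1_loss(source_xytheta, target_xytheta, reduction="none") * 5 / 3
print(loss_bbox.sum())  # summed before division by num_boxes, as in modified_loss_boxes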
@@ -127,7 +167,7 @@ def modify_forward(self):
     ):
         input_images = torch.stack([x["image"] for x in batched_input], dim=0)
         image_embeddings = self.image_encoder(input_images)
-
+        batch_size = len(batched_input)
         outputs = []
         srcs = []
         for image_record, curr_embedding in zip(batched_input, image_embeddings):
@@ -162,13 +202,17 @@ def modify_forward(self):
         grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
         # repeat to batchsize
         grasp_query_pe = grasp_query_pe.repeat(len(batched_input), 1, 1)
+        pixel_masks = torch.cat([batched_input[i]['pixel_mask'] for i in range(len(batched_input))], dim=0)
+        downsampled_pixel_masks = nn.functional.interpolate(pixel_masks.unsqueeze(1).float(), size=(64, 64), mode='nearest').squeeze(1).bool()
+        downsampled_pixel_masks = downsampled_pixel_masks.view(batch_size, 64*64).contiguous()
+        grasp_encoder_hidden_states = grasp_encoder_hidden_states.view(batch_size, 64*64, 256).contiguous()
         grasp_decoder_outputs = self.grasp_decoder(
             inputs_embeds=torch.zeros_like(grasp_query_pe),
             attention_mask=None,
             position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
             query_position_embeddings=grasp_query_pe,
             encoder_hidden_states=grasp_encoder_hidden_states,
-            encoder_attention_mask=
+            encoder_attention_mask=downsampled_pixel_masks,
             output_attentions=False,
             output_hidden_states=False,
             return_dict=True,
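A hedged, self-contained sketch of the mask handling added above: the per-pixel mask is nearest-neighbour downsampled to the 64x64 encoder grid and flattened into the encoder_attention_mask passed to the grasp decoder. The 1024x1024 input resolution below is an assumption for illustration only.

import torch
import torch.nn as nn

batch_size, H, W = 2, 1024, 1024                      # H, W are assumed, not taken from the diff
pixel_masks = torch.ones(batch_size, H, W, dtype=torch.bool)

downsampled = nn.functional.interpolate(
    pixel_masks.unsqueeze(1).float(),                 # (B, 1, H, W) so interpolate accepts it
    size=(64, 64), mode="nearest",
).squeeze(1).bool()                                   # back to (B, 64, 64)

encoder_attention_mask = downsampled.view(batch_size, 64 * 64).contiguous()
print(encoder_attention_mask.shape)                   # torch.Size([2, 4096])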
@@ -198,14 +242,14 @@ def modify_forward(self):
             eos_coef=config.eos_coefficient,
             losses=losses,
         )
-        criterion.loss_labels, criterion.loss_boxes = modify_grasp_loss_forward(criterion)
+        criterion.loss_labels, criterion.loss_boxes, criterion.forward = modify_grasp_loss_forward(criterion)
         criterion.to(self.device)
         # Third: compute the losses, based on outputs and labels
         outputs_loss = {}
         outputs_loss["logits"] = grasp_logits
         outputs_loss["pred_boxes"] = pred_grasps
 
-        grasp_loss_dict = criterion(outputs_loss, grasp_labels)
+        grasp_loss_dict = criterion(outputs_loss, grasp_labels, ignore_wh=batched_input[0].get("ignore_wh", False))
        # Fourth: compute total loss, as a weighted sum of the various losses
        weight_dict = {"loss_ce": 1, "loss_bbox": config.bbox_loss_coefficient}
        weight_dict["loss_giou"] = config.giou_loss_coefficient
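For reference, a sketch of how a DETR-style weight_dict is typically folded into a scalar loss. The values below are placeholders; the actual reduction performed later in modify_forward may differ.

import torch

# Stand-in values; in the model these come from criterion(outputs_loss, grasp_labels, ...)
grasp_loss_dict = {
    "loss_ce": torch.tensor(0.7),
    "loss_bbox": torch.tensor(0.2),
    "loss_giou": torch.tensor(0.4),
}
weight_dict = {"loss_ce": 1, "loss_bbox": 5, "loss_giou": 2}  # placeholder coefficients

total_grasp_loss = sum(
    grasp_loss_dict[k] * weight_dict[k] for k in grasp_loss_dict if k in weight_dict
)
print(total_grasp_loss)  # tensor(2.5000)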
@@ -282,6 +326,8 @@ def add_inference_method(self):
 
         n_queries = iou_predictions.size(0)
 
+        batch_size = n_queries
+
         # forward grasp decoder here
         # 1. Get encoder hidden states
         grasp_encoder_hidden_states = self.grasp_img_pos_embed(src.permute(0, 2, 3, 1))
@@ -289,13 +335,18 @@ def add_inference_method(self):
         grasp_query_pe = self.grasp_query_position_embeddings(torch.arange(20).to(self.device))
         # repeat to batchsize
         grasp_query_pe = grasp_query_pe.repeat(n_queries, 1, 1)
+        pixel_masks = torch.cat([batched_input[i]['pixel_mask'] for i in range(len(batched_input))], dim=0)
+        downsampled_pixel_masks = nn.functional.interpolate(pixel_masks.unsqueeze(1).float(), size=(64, 64),
+                                                            mode='nearest').squeeze(1).bool()
+        downsampled_pixel_masks = downsampled_pixel_masks.view(batch_size, 64 * 64).contiguous()
+        grasp_encoder_hidden_states = grasp_encoder_hidden_states.view(batch_size, 64 * 64, 256).contiguous()
         grasp_decoder_outputs = self.grasp_decoder(
             inputs_embeds=torch.zeros_like(grasp_query_pe),
             attention_mask=None,
             position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
             query_position_embeddings=grasp_query_pe,
             encoder_hidden_states=grasp_encoder_hidden_states,
-            encoder_attention_mask=
+            encoder_attention_mask=downsampled_pixel_masks,
             output_attentions=False,
             output_hidden_states=False,
             return_dict=True,
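A small sanity check (a debugging aid, not part of the diff) for the inference path, where batch_size is set to n_queries: the flattened encoder states and the downsampled mask must agree on (batch, sequence length) before the grasp decoder is called.

import torch

batch_size = 3                                        # stands in for n_queries
grasp_encoder_hidden_states = torch.randn(batch_size, 64 * 64, 256)
downsampled_pixel_masks = torch.ones(batch_size, 64 * 64, dtype=torch.bool)

assert grasp_encoder_hidden_states.shape[:2] == downsampled_pixel_masks.shape, \
    "encoder_attention_mask must match encoder_hidden_states on (batch, seq_len)"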