Spaces:
Running
Running
""" | |
Add additional grasp decoder for Segment Anything model. | |
The structure should follow the grasp decoder structure in GraspDETR. | |
""" | |
import torch | |
import torch.nn as nn | |
from transformers.models.detr.configuration_detr import DetrConfig | |
from transformers.models.detr.modeling_detr import DetrHungarianMatcher, DetrLoss, DetrSegmentationOutput, DetrDecoder, sigmoid_focal_loss, dice_loss | |
from typing import Any, Dict, List, Tuple | |
from transformers.models.detr.modeling_detr import generalized_box_iou | |
from transformers.image_transforms import center_to_corners_format | |
from scipy.optimize import linear_sum_assignment | |
def modify_matcher_forward(self):
    """
    Build a replacement ``forward`` for a ``DetrHungarianMatcher`` that supports rotated grasps.

    The returned closure follows ``DetrHungarianMatcher.forward``, except that the generalized-IoU
    cost is computed on the first four box values only (center format, converted with
    ``center_to_corners_format``). Any extra regressed values, such as the rotation term of a
    grasp, still contribute to the L1 cost.

    Args:
        self: the matcher instance. Its ``class_cost``, ``bbox_cost`` and ``giou_cost`` weights
            are read every time the closure runs.

    Returns:
        ``matcher_forward(outputs, targets)``, intended to be assigned to ``matcher.forward``.
    """

    # Matching is a non-differentiable assignment. Without no_grad the cost matrix inherits
    # requires_grad from the predictions. ``linear_sum_assignment`` converts it to a NumPy array,
    # and that conversion raises on a tensor that still requires grad, so training would fail.
    # The stock DetrHungarianMatcher.forward carries the same decorator.
    @torch.no_grad()
    def matcher_forward(outputs, targets):
        """
        Compute the optimal query-to-target assignment for each image in the batch.

        Args:
            outputs: dict with ``"logits"`` of shape (batch, num_queries, num_classes) and
                ``"pred_boxes"`` of shape (batch, num_queries, D), where D >= 4.
            targets: list with one dict per image, holding ``"class_labels"`` (num_targets,)
                and ``"boxes"`` (num_targets, D).

        Returns:
            List of ``(query_indices, target_indices)`` int64 tensor pairs, one per image.
        """
        batch_size, num_queries = outputs["logits"].shape[:2]
        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, D]
        # Also concat the target labels and boxes
        target_ids = torch.cat([v["class_labels"] for v in targets])
        target_bbox = torch.cat([v["boxes"] for v in targets])
        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be omitted.
        class_cost = -out_prob[:, target_ids]
        # L1 cost between boxes, over all D regressed values (rotation included)
        bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)
        # GIoU cost between boxes, on the (cx, cy, w, h) part only
        giou_cost = -generalized_box_iou(
            center_to_corners_format(out_bbox[:, :4]), center_to_corners_format(target_bbox[:, :4])
        )
        # Final cost matrix
        cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
        cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()
        # Split the target axis per image and solve one assignment per image
        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

    return matcher_forward
def modify_grasp_loss_forward(self):
    """
    Create class-agnostic replacements for ``DetrLoss.loss_labels`` and ``DetrLoss.loss_boxes``.

    Args:
        self: the ``DetrLoss`` criterion. Its ``_get_source_permutation_idx`` helper is used to
            gather the matched prediction slots.

    Returns:
        Tuple ``(loss_labels, loss_boxes)`` of closures with the DETR loss signature
        ``(outputs, targets, indices, num_boxes) -> dict``.
    """

    def modified_loss_labels(outputs, targets, indices, num_boxes):
        """
        Cross-entropy classification loss over every query.

        Matched queries take their target's ``class_labels``. Every other query is assigned
        index 1, the "no object" class, because grasps are class agnostic and there is a single
        real class. ``num_boxes`` is accepted for signature compatibility and is unused.

        NOTE(review): unlike the stock ``DetrLoss.loss_labels``, no ``empty_weight``
        (eos_coef) class weighting is passed to ``cross_entropy``. Confirm this is intended.
        """
        if "logits" not in outputs:
            raise KeyError("No logits were found in the outputs")
        logits = outputs["logits"]
        no_object_class = 1  # model v9 always uses class-agnostic grasps

        matched = self._get_source_permutation_idx(indices)
        matched_classes = torch.cat(
            [target["class_labels"][tgt_idx] for target, (_, tgt_idx) in zip(targets, indices)]
        )
        class_targets = torch.full(
            logits.shape[:2], no_object_class, dtype=torch.int64, device=logits.device
        )
        class_targets[matched] = matched_classes
        # cross_entropy wants the class axis second: (batch, classes, queries)
        return {"loss_ce": nn.functional.cross_entropy(logits.transpose(1, 2), class_targets)}

    def modified_loss_boxes(outputs, targets, indices, num_boxes):
        """
        L1 and generalized-IoU box losses over matched queries, each divided by ``num_boxes``.

        The L1 term uses every regressed value, rotation included. The GIoU term uses only the
        first four values, interpreted as center-format (cx, cy, w, h) boxes.
        """
        if "pred_boxes" not in outputs:
            raise KeyError("No predicted boxes found in outputs")
        matched_pred = outputs["pred_boxes"][self._get_source_permutation_idx(indices)]
        matched_gt = torch.cat(
            [target["boxes"][tgt_idx] for target, (_, tgt_idx) in zip(targets, indices)], dim=0
        )

        l1_total = nn.functional.l1_loss(matched_pred, matched_gt, reduction="none").sum()
        pairwise_giou = generalized_box_iou(
            center_to_corners_format(matched_pred[:, :4]),
            center_to_corners_format(matched_gt[:, :4]),
        )
        # Only the diagonal pairs prediction k with its own matched target k
        giou_total = (1 - torch.diag(pairwise_giou)).sum()
        return {"loss_bbox": l1_total / num_boxes, "loss_giou": giou_total / num_boxes}

    return modified_loss_labels, modified_loss_boxes
def modify_forward(self, num_grasp_queries: int = 20):
    """
    Attach a parallel grasp-detection branch to a SAM model and return its new forward function.

    Modules added to ``self`` (created in this order, each moved to ``self.device``):
    1. ``grasp_decoder``: a ``DetrDecoder`` built from a default ``DetrConfig``, which is stored
       as ``grasp_decoder_config``.
    2. ``grasp_query_position_embeddings``: ``num_grasp_queries`` learned query embeddings.
    3. ``grasp_predictor``: regresses 5 values per query (bbox & rotation); a sigmoid is applied
       in the forward pass.
    4. ``grasp_label_classifier``: 2 logits per query (1 grasp class + "no object").
    5. ``grasp_img_pos_embed``: a learned positional embedding added to the mask decoder's
       64x64 feature map before it is used as the grasp decoder's memory.

    Args:
        self: the SAM model being extended. It must provide ``image_encoder``,
            ``prompt_encoder``, ``mask_decoder`` (returning masks, IoU predictions and a feature
            map ``src``), ``postprocess_masks`` and ``device``.
        num_grasp_queries: number of grasp queries, i.e. grasp predictions per image. Defaults
            to 20.

    Returns:
        ``modified_forward(batched_input, multimask_output)``. The model itself is modified in
        place; it is not returned.
    """
    # Width of every grasp-branch layer. Must equal DetrConfig().d_model (256). The mask-decoder
    # feature map is assumed to have the same number of channels.
    hidden_dim = 256

    # 1. A second (DETR) decoder for grasps, alongside SAM's mask decoder
    self.grasp_decoder_config = DetrConfig()
    self.grasp_decoder = DetrDecoder(self.grasp_decoder_config).to(self.device)
    self.grasp_query_position_embeddings = nn.Embedding(num_grasp_queries, hidden_dim).to(self.device)
    # 2. The base model forward method is not used directly, so nothing else needs patching.
    # 3. Classification & regression heads on the grasp-decoder output.
    # NOTE(review): there is no activation between these Linear layers, so the stack is
    # equivalent to a single affine map (DETR's own box head uses ReLU). Inserting activation
    # modules would shift this Sequential's state_dict keys, so it is left as is to stay
    # compatible with existing checkpoints.
    self.grasp_predictor = torch.nn.Sequential(
        torch.nn.Linear(hidden_dim, hidden_dim),
        torch.nn.Linear(hidden_dim, hidden_dim),
        torch.nn.Linear(hidden_dim, 5),
    ).to(self.device)
    self.grasp_label_classifier = torch.nn.Linear(hidden_dim, 2).to(self.device)

    # 4. Learned positional embedding for the image features, named grasp_img_pos_embed to
    #    avoid a name conflict.
    class ImagePosEmbed(nn.Module):
        """Add a learned (1, img_size, img_size, hidden_dim) embedding to a channels-last map."""

        def __init__(self, img_size=64, hidden_dim=256):
            super().__init__()
            self.pos_embed = nn.Parameter(torch.randn(1, img_size, img_size, hidden_dim))

        def forward(self, x):
            return x + self.pos_embed

    self.grasp_img_pos_embed = ImagePosEmbed(hidden_dim=hidden_dim).to(self.device)

    def modified_forward(
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ):
        """
        Run SAM segmentation plus the grasp branch on a batch, and add losses when labels exist.

        Each record needs ``"image"``. The optional prompt keys are ``"point_coords"`` /
        ``"point_labels"``, ``"boxes"`` and ``"mask_inputs"``. If the first record has
        ``"grasp_labels"``, the DETR grasp losses are computed. If it has ``"masks"``, the
        sigmoid-focal and dice segmentation losses are computed.

        Returns:
            ``DetrSegmentationOutput`` with the summed ``loss`` (0 when no labels are given),
            ``loss_dict``, grasp ``logits`` / ``pred_boxes``, and ``pred_masks`` upsampled to
            1024x1024.
        """
        input_images = torch.stack([x["image"] for x in batched_input], dim=0)
        image_embeddings = self.image_encoder(input_images)

        # Per-image SAM mask decoding. Keep the low-res mask logits and the decoder feature map.
        low_res_logits = []
        srcs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, _iou_predictions, src = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            low_res_logits.append(low_res_masks)
            # Keep only the first entry of src: one feature map per image
            srcs.append(src[0])
        srcs = torch.stack(srcs, dim=0)

        # Forward the grasp decoder
        # 1. Encoder hidden states: channels-last features plus the learned positional embedding
        grasp_encoder_hidden_states = self.grasp_img_pos_embed(srcs.permute(0, 2, 3, 1))
        # 2. Query embeddings, repeated for every image in the batch
        num_queries = self.grasp_query_position_embeddings.num_embeddings
        grasp_query_pe = self.grasp_query_position_embeddings(
            torch.arange(num_queries, device=self.device)
        )
        grasp_query_pe = grasp_query_pe.repeat(len(batched_input), 1, 1)
        # Query content starts at zero; queries differ only through query_position_embeddings.
        grasp_decoder_outputs = self.grasp_decoder(
            inputs_embeds=torch.zeros_like(grasp_query_pe),
            attention_mask=None,
            position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
            query_position_embeddings=grasp_query_pe,
            encoder_hidden_states=grasp_encoder_hidden_states,
            encoder_attention_mask=None,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        grasp_sequence_output = grasp_decoder_outputs[0]
        grasp_logits = self.grasp_label_classifier(grasp_sequence_output)
        pred_grasps = self.grasp_predictor(grasp_sequence_output).sigmoid()

        # 3. Grasp detection loss
        loss, loss_dict = 0, {}
        if "grasp_labels" in batched_input[0]:
            config = self.grasp_decoder_config
            # All grasps share the single class 0 (class agnostic)
            grasp_labels = [{
                "class_labels": torch.zeros([len(x["grasp_labels"])], dtype=torch.long).to(self.device),
                "boxes": x["grasp_labels"],
            } for x in batched_input]
            # First: create the matcher (patched to handle the extra rotation value)
            matcher = DetrHungarianMatcher(
                class_cost=config.class_cost, bbox_cost=config.bbox_cost, giou_cost=config.giou_cost
            )
            matcher.forward = modify_matcher_forward(matcher)
            # Second: create the criterion, with class-agnostic label/box losses patched in
            criterion = DetrLoss(
                matcher=matcher,
                num_classes=config.num_labels,
                eos_coef=config.eos_coefficient,
                losses=["labels", "boxes"],
            )
            criterion.loss_labels, criterion.loss_boxes = modify_grasp_loss_forward(criterion)
            criterion.to(self.device)
            # Third: compute the losses, based on outputs and labels
            outputs_loss = {"logits": grasp_logits, "pred_boxes": pred_grasps}
            grasp_loss_dict = criterion(outputs_loss, grasp_labels)
            # Fourth: compute total loss, as a weighted sum of the various losses
            weight_dict = {
                "loss_ce": 1,
                "loss_bbox": config.bbox_loss_coefficient,
                "loss_giou": config.giou_loss_coefficient,
            }
            if config.auxiliary_loss:
                aux_weight_dict = {}
                for i in range(config.decoder_layers - 1):
                    aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
                weight_dict.update(aux_weight_dict)
            grasp_loss = sum(grasp_loss_dict[k] * weight_dict[k] for k in grasp_loss_dict.keys() if k in weight_dict)
            # Merge the grasp branch loss into loss & loss_dict
            loss += grasp_loss
            loss_dict.update(grasp_loss_dict)

        # All images were stacked, so they share one (H, W) input size.
        pred_masks = self.postprocess_masks(
            torch.cat(low_res_logits, dim=0),
            input_size=input_images.shape[-2:],
            original_size=(1024, 1024),
        )

        # 4. Segmentation loss
        if "masks" in batched_input[0]:
            # NOTE(review): the target has a single channel per image, so this presumably
            # assumes multimask_output=False (one predicted mask per image). Confirm.
            target_masks = torch.stack([x["masks"] for x in batched_input], dim=0)
            target_masks = target_masks.unsqueeze(1).type(torch.float32).flatten(1)
            flat_pred_masks = pred_masks.flatten(1)
            sf_loss = sigmoid_focal_loss(flat_pred_masks, target_masks, len(batched_input))
            d_loss = dice_loss(flat_pred_masks, target_masks, len(batched_input))
            loss += sf_loss + d_loss
            loss_dict["sf_loss"] = sf_loss
            loss_dict["d_loss"] = d_loss

        return DetrSegmentationOutput(
            loss=loss,
            loss_dict=loss_dict,
            logits=grasp_logits,
            pred_boxes=pred_grasps,
            pred_masks=pred_masks,
        )

    return modified_forward
def add_inference_method(self):
    """
    Build an inference function for a SAM model that already has the grasp branch attached.

    Args:
        self: SAM model providing ``image_encoder``, ``prompt_encoder``, ``mask_decoder``
            (returning masks, IoU predictions and a feature map ``src``), ``postprocess_masks``,
            ``device``, and the grasp modules ``grasp_img_pos_embed``,
            ``grasp_query_position_embeddings``, ``grasp_decoder``, ``grasp_label_classifier``
            and ``grasp_predictor``.

    Returns:
        ``infer(batched_input, multimask_output)``.
    """

    def infer(
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ):
        """
        Segment and predict grasps for the FIRST record of ``batched_input`` only.

        Unlike the training forward, the grasp decoder is fed the whole ``src`` returned by the
        mask decoder. Its batch dimension is ``iou_predictions.size(0)``, so one set of grasps
        is predicted per row, presumably one per prompt.

        Returns:
            ``DetrSegmentationOutput`` with ``loss=0``, an empty ``loss_dict``, grasp
            ``logits`` / ``pred_boxes``, and ``pred_masks`` upsampled to 1024x1024.
        """
        # Only the first record is used, so only its image is encoded.
        image_record = batched_input[0]
        curr_embedding = self.image_encoder(image_record["image"].unsqueeze(0))[0]

        if "point_coords" in image_record:
            points = (image_record["point_coords"], image_record["point_labels"])
        else:
            points = None
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=image_record.get("boxes", None),
            masks=image_record.get("mask_inputs", None),
        )
        low_res_masks, iou_predictions, src = self.mask_decoder(
            image_embeddings=curr_embedding.unsqueeze(0),
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )
        # Assumes src has iou_predictions.size(0) rows. TODO confirm against mask_decoder.
        decoder_batch_size = iou_predictions.size(0)

        # Forward the grasp decoder
        # 1. Encoder hidden states: channels-last features plus the learned positional embedding
        grasp_encoder_hidden_states = self.grasp_img_pos_embed(src.permute(0, 2, 3, 1))
        # 2. Query embeddings, repeated for every decoder batch row
        num_grasp_queries = self.grasp_query_position_embeddings.num_embeddings
        grasp_query_pe = self.grasp_query_position_embeddings(
            torch.arange(num_grasp_queries, device=self.device)
        )
        grasp_query_pe = grasp_query_pe.repeat(decoder_batch_size, 1, 1)
        grasp_decoder_outputs = self.grasp_decoder(
            inputs_embeds=torch.zeros_like(grasp_query_pe),
            attention_mask=None,
            position_embeddings=torch.zeros_like(grasp_encoder_hidden_states),
            query_position_embeddings=grasp_query_pe,
            encoder_hidden_states=grasp_encoder_hidden_states,
            encoder_attention_mask=None,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        grasp_sequence_output = grasp_decoder_outputs[0]
        grasp_logits = self.grasp_label_classifier(grasp_sequence_output)
        pred_grasps = self.grasp_predictor(grasp_sequence_output).sigmoid()

        pred_masks = self.postprocess_masks(
            low_res_masks,
            input_size=image_record["image"].shape[-2:],
            original_size=(1024, 1024),
        )
        return DetrSegmentationOutput(
            loss=0,
            loss_dict={},
            logits=grasp_logits,
            pred_boxes=pred_grasps,
            pred_masks=pred_masks,
        )

    return infer