""" |
|
This file provides the definition of the convolutional heads used to predict masks, as well as the losses used to train them.
|
""" |
|
import io |
|
from collections import defaultdict |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from PIL import Image |
|
|
|
import util.box_ops as box_ops |
|
from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list |
|
|
|
try: |
|
from panopticapi.utils import id2rgb, rgb2id |
|
except ImportError: |
|
pass |
|
|
|
|
|
class DETRsegm(nn.Module): |
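    """DETR wrapper that adds a mask prediction head on top of a detector.

    It reuses the DETR backbone and transformer, computes multi-head attention maps
    between the decoder output and the encoder memory, and feeds them to a small
    FPN-style convolutional head to predict one binary mask per object query.
    """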
|
def __init__(self, detr, freeze_detr=False): |
|
super().__init__() |
|
self.detr = detr |
|
|
|
if freeze_detr: |
|
for p in self.parameters(): |
|
p.requires_grad_(False) |
|
|
|
hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead |
|
self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0) |
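        # The attention module above yields one heatmap per head and per query; the mask head
        # consumes the projected features concatenated with these heatmaps, hence the
        # hidden_dim + nheads input channels.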
|
self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim) |
|
|
|
def forward(self, samples: NestedTensor): |
|
if not isinstance(samples, NestedTensor): |
|
samples = nested_tensor_from_tensor_list(samples) |
|
features, pos = self.detr.backbone(samples) |
|
|
|
bs = features[-1].tensors.shape[0] |
|
|
|
src, mask = features[-1].decompose() |
|
src_proj = self.detr.input_proj(src) |
|
hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1]) |
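        # hs holds the decoder output for every decoder layer (the last one is used for the
        # final predictions); memory is the encoder output kept in its 2D spatial layout.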
|
|
|
outputs_class = self.detr.class_embed(hs) |
|
outputs_coord = self.detr.bbox_embed(hs).sigmoid() |
|
out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} |
|
if self.detr.aux_loss: |
|
out["aux_outputs"] = [ |
|
{"pred_logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1]) |
|
] |
|
|
|
|
|
bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask) |
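        # Attention maps of shape (batch, num_queries, nheads, H, W): one spatial heatmap
        # per query and per head, used to seed the mask head.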
|
|
|
seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors]) |
|
outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1]) |
|
|
|
out["pred_masks"] = outputs_seg_masks |
|
return out |
|
|
|
|
|
class MaskHeadSmallConv(nn.Module): |
|
""" |
|
Simple convolutional head, using group norm. |
|
    Upsampling is done using an FPN approach.
|
""" |
|
|
|
def __init__(self, dim, fpn_dims, context_dim): |
|
super().__init__() |
|
|
|
inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64] |
|
self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1) |
|
self.gn1 = torch.nn.GroupNorm(8, dim) |
|
self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1) |
|
self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) |
|
self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) |
|
self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) |
|
self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) |
|
self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) |
|
self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1) |
|
self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) |
|
self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1) |
|
|
|
self.dim = dim |
|
|
|
self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) |
|
self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) |
|
self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) |
|
|
|
for m in self.modules(): |
|
if isinstance(m, nn.Conv2d): |
|
nn.init.kaiming_uniform_(m.weight, a=1) |
|
nn.init.constant_(m.bias, 0) |
|
|
|
def forward(self, x, bbox_mask, fpns): |
|
def expand(tensor, length): |
|
return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1) |
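        # Replicate the image-level features once per query and concatenate them with the
        # corresponding attention heatmaps along the channel dimension.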
|
|
|
x = torch.cat([expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1) |
|
|
|
x = self.lay1(x) |
|
x = self.gn1(x) |
|
x = F.relu(x) |
|
x = self.lay2(x) |
|
x = self.gn2(x) |
|
x = F.relu(x) |
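        # FPN-style refinement: upsample the current map to the resolution of the next
        # (1x1-adapted) backbone feature map, sum them, and convolve again.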
|
|
|
cur_fpn = self.adapter1(fpns[0]) |
|
if cur_fpn.size(0) != x.size(0): |
|
            cur_fpn = expand(cur_fpn, x.size(0) // cur_fpn.size(0))
|
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") |
|
x = self.lay3(x) |
|
x = self.gn3(x) |
|
x = F.relu(x) |
|
|
|
cur_fpn = self.adapter2(fpns[1]) |
|
if cur_fpn.size(0) != x.size(0): |
|
            cur_fpn = expand(cur_fpn, x.size(0) // cur_fpn.size(0))
|
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") |
|
x = self.lay4(x) |
|
x = self.gn4(x) |
|
x = F.relu(x) |
|
|
|
cur_fpn = self.adapter3(fpns[2]) |
|
if cur_fpn.size(0) != x.size(0): |
|
            cur_fpn = expand(cur_fpn, x.size(0) // cur_fpn.size(0))
|
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") |
|
x = self.lay5(x) |
|
x = self.gn5(x) |
|
x = F.relu(x) |
|
|
|
x = self.out_lay(x) |
|
return x |
|
|
|
|
|
class MHAttentionMap(nn.Module): |
|
"""This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" |
|
|
|
def __init__(self, query_dim, hidden_dim, num_heads, dropout=0, bias=True): |
|
super().__init__() |
|
self.num_heads = num_heads |
|
self.hidden_dim = hidden_dim |
|
self.dropout = nn.Dropout(dropout) |
|
|
|
self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias) |
|
self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias) |
|
|
|
nn.init.zeros_(self.k_linear.bias) |
|
nn.init.zeros_(self.q_linear.bias) |
|
nn.init.xavier_uniform_(self.k_linear.weight) |
|
nn.init.xavier_uniform_(self.q_linear.weight) |
|
self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5 |
|
|
|
def forward(self, q, k, mask=None): |
|
q = self.q_linear(q) |
|
k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias) |
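        # The key projection is applied as a 1x1 convolution so the encoder memory keeps its
        # spatial layout; the resulting weights have shape (batch, num_queries, num_heads, H, W).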
|
qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads) |
|
kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1]) |
|
weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh) |
|
|
|
if mask is not None: |
|
weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf")) |
|
weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights) |
|
weights = self.dropout(weights) |
|
return weights |
|
|
|
|
|
def dice_loss(inputs, targets, num_boxes): |
|
""" |
|
Compute the DICE loss, similar to generalized IOU for masks |
|
Args: |
|
inputs: A float tensor of arbitrary shape. |
|
The predictions for each example. |
|
targets: A float tensor with the same shape as inputs. Stores the binary |
|
classification label for each element in inputs |
|
(0 for the negative class and 1 for the positive class). |
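        num_boxes: Normalization factor applied to the summed loss (typically the number
                   of target boxes).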
|
""" |
|
inputs = inputs.sigmoid() |
|
inputs = inputs.flatten(1) |
|
numerator = 2 * (inputs * targets).sum(1) |
|
denominator = inputs.sum(-1) + targets.sum(-1) |
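    # The +1 terms smooth the ratio and avoid division by zero for empty masks.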
|
loss = 1 - (numerator + 1) / (denominator + 1) |
|
return loss.sum() / num_boxes |
|
|
|
|
|
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): |
|
""" |
|
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. |
|
Args: |
|
inputs: A float tensor of arbitrary shape. |
|
The predictions for each example. |
|
targets: A float tensor with the same shape as inputs. Stores the binary |
|
classification label for each element in inputs |
|
(0 for the negative class and 1 for the positive class). |
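        num_boxes: Normalization factor applied to the final loss (typically the number
                   of target boxes).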
|
        alpha: (optional) Weighting factor in range (0, 1) to balance
               positive vs negative examples. Default = 0.25; a negative
               value disables the weighting.
|
gamma: Exponent of the modulating factor (1 - p_t) to |
|
balance easy vs hard examples. |
|
Returns: |
|
Loss tensor |
|
""" |
|
prob = inputs.sigmoid() |
|
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") |
|
p_t = prob * targets + (1 - prob) * (1 - targets) |
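    # p_t is the model's probability for the ground-truth class; (1 - p_t) ** gamma
    # down-weights well-classified examples.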
|
loss = ce_loss * ((1 - p_t) ** gamma) |
|
|
|
if alpha >= 0: |
|
alpha_t = alpha * targets + (1 - alpha) * (1 - targets) |
|
loss = alpha_t * loss |
|
|
|
return loss.mean(1).sum() / num_boxes |
|
|
|
|
|
class PostProcessSegm(nn.Module): |
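    """Converts the raw mask logits into binary masks resized to the original image sizes."""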
|
def __init__(self, threshold=0.5): |
|
super().__init__() |
|
self.threshold = threshold |
|
|
|
@torch.no_grad() |
|
def forward(self, results, outputs, orig_target_sizes, max_target_sizes): |
|
assert len(orig_target_sizes) == len(max_target_sizes) |
|
max_h, max_w = max_target_sizes.max(0)[0].tolist() |
|
outputs_masks = outputs["pred_masks"].squeeze(2) |
|
outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False) |
|
outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu() |
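        # Masks are first upsampled to the padded size of the largest image in the batch,
        # then each one is cropped to its own image size and resized to the original resolution.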
|
|
|
for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)): |
|
img_h, img_w = t[0], t[1] |
|
results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) |
|
results[i]["masks"] = F.interpolate( |
|
results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" |
|
).byte() |
|
|
|
return results |
|
|
|
|
|
class PostProcessPanoptic(nn.Module): |
|
"""This class converts the output of the model to the final panoptic result, in the format expected by the |
|
    COCO panoptic API."""
|
|
|
def __init__(self, is_thing_map, threshold=0.85): |
|
""" |
|
Parameters: |
|
            is_thing_map: This is a dict whose keys are the class ids, and whose values are a boolean
                          indicating whether the class is a thing (True) or a stuff (False) class
|
threshold: confidence threshold: segments with confidence lower than this will be deleted |
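            Example (purely illustrative ids): is_thing_map = {1: True, 2: False} marks class 1
            as a thing and class 2 as stuff.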
|
""" |
|
super().__init__() |
|
self.threshold = threshold |
|
self.is_thing_map = is_thing_map |
|
|
|
def forward(self, outputs, processed_sizes, target_sizes=None): |
|
""" This function computes the panoptic prediction from the model's predictions. |
|
Parameters: |
|
outputs: This is a dict coming directly from the model. See the model doc for the content. |
|
processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the |
|
                             model, i.e. the size after data augmentation but before batching.
|
target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size |
|
of each prediction. If left to None, it will default to the processed_sizes |
|
""" |
|
if target_sizes is None: |
|
target_sizes = processed_sizes |
|
assert len(processed_sizes) == len(target_sizes) |
|
out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"] |
|
assert len(out_logits) == len(raw_masks) == len(target_sizes) |
|
preds = [] |
|
|
|
def to_tuple(tup): |
|
if isinstance(tup, tuple): |
|
return tup |
|
return tuple(tup.cpu().tolist()) |
|
|
|
for cur_logits, cur_masks, cur_boxes, size, target_size in zip( |
|
out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes |
|
): |
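            # Keep only the queries that are not predicted as "no object" (the last class index)
            # and whose confidence exceeds the threshold.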
|
|
|
            cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
            keep = cur_classes.ne(outputs["pred_logits"].shape[-1] - 1) & (cur_scores > self.threshold)
|
cur_scores = cur_scores[keep] |
|
cur_classes = cur_classes[keep] |
|
cur_masks = cur_masks[keep] |
|
cur_masks = interpolate(cur_masks[None], to_tuple(size), mode="bilinear").squeeze(0) |
|
cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep]) |
|
|
|
h, w = cur_masks.shape[-2:] |
|
assert len(cur_boxes) == len(cur_classes) |
|
|
|
|
|
|
|
cur_masks = cur_masks.flatten(1) |
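            # Several queries may predict the same "stuff" class; record their indices so that
            # their masks can later be merged into a single segment.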
|
stuff_equiv_classes = defaultdict(lambda: []) |
|
for k, label in enumerate(cur_classes): |
|
if not self.is_thing_map[label.item()]: |
|
stuff_equiv_classes[label.item()].append(k) |
|
|
|
def get_ids_area(masks, scores, dedup=False): |
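                # Assigns each pixel to the query whose mask logit is highest at that location,
                # resizes the resulting id map to the target size, and returns the pixel area
                # of every segment.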
|
|
|
|
|
|
|
m_id = masks.transpose(0, 1).softmax(-1) |
|
|
|
if m_id.shape[-1] == 0: |
|
|
|
m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device) |
|
else: |
|
m_id = m_id.argmax(-1).view(h, w) |
|
|
|
if dedup: |
|
|
|
for equiv in stuff_equiv_classes.values(): |
|
if len(equiv) > 1: |
|
for eq_id in equiv: |
|
m_id.masked_fill_(m_id.eq(eq_id), equiv[0]) |
|
|
|
final_h, final_w = to_tuple(target_size) |
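                # Encode the id map as an RGB image so it can be resized with PIL (nearest
                # neighbour keeps the ids intact), then decode it back to integer ids.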
|
|
|
seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy())) |
|
seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST) |
|
|
|
np_seg_img = ( |
|
torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy() |
|
) |
|
m_id = torch.from_numpy(rgb2id(np_seg_img)) |
|
|
|
area = [] |
|
for i in range(len(scores)): |
|
area.append(m_id.eq(i).sum().item()) |
|
return area, seg_img |
|
|
|
area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True) |
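            # Iteratively drop segments whose area is at most 4 pixels and recompute the id map,
            # since removing a mask can change the pixel assignment of the remaining ones.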
|
if cur_classes.numel() > 0: |
|
|
|
while True: |
|
filtered_small = torch.as_tensor( |
|
[area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device |
|
) |
|
if filtered_small.any().item(): |
|
cur_scores = cur_scores[~filtered_small] |
|
cur_classes = cur_classes[~filtered_small] |
|
cur_masks = cur_masks[~filtered_small] |
|
area, seg_img = get_ids_area(cur_masks, cur_scores) |
|
else: |
|
break |
|
|
|
else: |
|
cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device) |
|
|
|
segments_info = [] |
|
for i, a in enumerate(area): |
|
cat = cur_classes[i].item() |
|
segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a}) |
|
del cur_classes |
|
|
|
with io.BytesIO() as out: |
|
seg_img.save(out, format="PNG") |
|
predictions = {"png_string": out.getvalue(), "segments_info": segments_info} |
|
preds.append(predictions) |
|
return preds |
|
|