Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
import numpy as np | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from maskrcnn_benchmark.structures.bounding_box import BoxList | |
import cv2 | |
# TODO check if want to return a single BoxList or a composite | |
# object | |
class KEPostProcessor(nn.Module):
    """
    Attach the output of the ke head to the detections of each image.

    The head output is split along dim 0 according to the number of boxes
    in each image and stored in the "ke" field of a fresh copy of every
    BoxList. If a keer object is passed, it is called on the split outputs
    (together with the boxes) and its per-image results are stored instead.
    """

    def __init__(self, keer=None):
        super(KEPostProcessor, self).__init__()
        self.keer = keer

    def forward(self, x, boxes):
        """
        Arguments:
            x (Tensor): the ke logits, one row per proposal
            boxes (list[BoxList]): bounding boxes that are used as
                reference, one for each image

        Returns:
            results (list[BoxList]): one BoxList for each image, containing
                the extra field ke
        """
        num_proposals = x.shape[0]

        # The "labels" field is read from every BoxList, but only its device
        # is used below: no per-class channel selection takes place.
        per_image_labels = [bbox.get_field("labels") for bbox in boxes]
        all_labels = torch.cat(per_image_labels)
        rows = torch.arange(num_proposals, device=all_labels.device)

        # Indexing with every row index keeps all proposals and all channels.
        selected = x[rows]

        # One chunk of rows per image, in the same order as `boxes`.
        counts = [len(box) for box in boxes]
        per_image = selected.split(counts, dim=0)

        if self.keer:
            per_image = self.keer(per_image, boxes)

        results = []
        for ke, source in zip(per_image, boxes):
            # Copy the box and all of its existing fields, then add "ke".
            target = BoxList(source.bbox, source.size, mode="xyxy")
            for name in source.fields():
                target.add_field(name, source.get_field(name))
            target.add_field("ke", ke)
            results.append(target)
        return results
class KEPostProcessorCOCOFormat(KEPostProcessor):
    """
    Run the base post-processing, then replace the "ke" field of every
    result with a list of encoded entries whose "counts" value has been
    decoded to a UTF-8 string.
    """

    def forward(self, x, boxes):
        # NOTE(review): `ke_util` is not defined or imported anywhere in the
        # visible part of this file, so this method will raise NameError as
        # written. The original carried a commented-out
        # `import pycocotools.mask as mask_util`, which is presumably the
        # intended encoder -- confirm before relying on this class.
        import numpy as np

        results = super(KEPostProcessorCOCOFormat, self).forward(x, boxes)
        for result in results:
            encoded = []
            for ke in result.get_field("ke").cpu():
                # First channel only, with a trailing axis added, laid out
                # in Fortran order before being handed to the encoder.
                plane = np.array(ke[0, :, :, np.newaxis], order="F")
                rle = ke_util.encode(plane)[0]
                rle["counts"] = rle["counts"].decode("utf-8")
                encoded.append(rle)
            result.add_field("ke", encoded)
        return results
# the next two functions should be merged inside keer | |
# but are kept here for the moment while we need them | |
# temporarily for paste_ke_in_image
def expand_boxes(boxes, scale):
    """
    Scale every box around its own center.

    Arguments:
        boxes (Tensor): 2-D tensor whose columns 0..3 are read as
            x1, y1, x2, y2
        scale: factor applied to the half width and half height

    Returns:
        Tensor: created with zeros_like(boxes), with columns 0..3 holding
            the scaled x1, y1, x2, y2
    """
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]

    center_x = (x2 + x1) * .5
    center_y = (y2 + y1) * .5

    # Half extents, scaled in place on the freshly computed tensors.
    half_w = ((x2 - x1) * .5).mul_(scale)
    half_h = ((y2 - y1) * .5).mul_(scale)

    expanded = torch.zeros_like(boxes)
    expanded[:, 0] = center_x - half_w
    expanded[:, 1] = center_y - half_h
    expanded[:, 2] = center_x + half_w
    expanded[:, 3] = center_y + half_h
    return expanded
def expand_kes(ke, padding):
    """
    Zero-pad a batch of square ke maps on every side.

    Arguments:
        ke (Tensor): maps whose first dimension is the batch size N and
            whose last dimension is the side length M; it must be
            assignable into an (N, 1, M, M) window
        padding (int): number of zero pixels added on each side (0 allowed)

    Returns:
        (Tensor, float): the padded maps of shape
            (N, 1, M + 2 * padding, M + 2 * padding), and the ratio
            (M + 2 * padding) / M by which the side length grew
    """
    N = ke.shape[0]
    M = ke.shape[-1]
    pad2 = 2 * padding
    scale = float(M + pad2) / M
    padded_ke = ke.new_zeros((N, 1, M + pad2, M + pad2))
    # Use explicit end indices: `padding:-padding` collapses to the empty
    # slice `0:0` when padding == 0 and makes the assignment fail.
    end = padding + M
    padded_ke[:, :, padding:end, padding:end] = ke
    return padded_ke, scale
def paste_ke_in_image(ke, box, im_h, im_w, thresh=0.5, padding=1):
    """
    Resize a single ke map to its box and paste it into an empty image.

    Arguments:
        ke (Tensor): a single ke map
        box (Tensor): the box, indexed as x1, y1, x2, y2
        im_h, im_w (int): height and width of the output image
        thresh (float): when >= 0 the resized map is binarised with
            `> thresh`; when negative the resized map is multiplied by 255
            and cast to uint8 instead (visualization / debugging)
        padding (int): padding passed to expand_kes before resizing

    Returns:
        Tensor: uint8 image of shape (im_h, im_w), zero outside the box
    """
    padded, scale = expand_kes(ke[None], padding=padding)
    patch = padded[0, 0]
    # Grow the box by the same factor the map grew through padding.
    box = expand_boxes(box[None], scale)[0].to(dtype=torch.int32)

    TO_REMOVE = 1
    w = max(int(box[2] - box[0] + TO_REMOVE), 1)
    h = max(int(box[3] - box[1] + TO_REMOVE), 1)

    # F.interpolate wants [batch x C x H x W] float input.
    patch = patch.expand((1, 1, -1, -1)).to(torch.float32)
    patch = F.interpolate(
        patch, size=(h, w), mode='bilinear', align_corners=False
    )[0][0]

    if thresh >= 0:
        patch = patch > thresh
    else:
        patch = (patch * 255).to(torch.uint8)

    canvas = torch.zeros((im_h, im_w), dtype=torch.uint8)

    # Clip the (inclusive) box to the image bounds.
    left = max(box[0], 0)
    right = min(box[2] + 1, im_w)
    top = max(box[1], 0)
    bottom = min(box[3] + 1, im_h)

    # Copy the part of the patch that falls inside the image.
    rows = slice(top - box[1], bottom - box[1])
    cols = slice(left - box[0], right - box[0])
    canvas[top:bottom, left:right] = patch[rows, cols]
    return canvas
def scores_to_probs(scores):
    """
    Turn CxHxW scores into per-channel spatial probabilities.

    Each channel is replaced, in place, by the softmax taken over all of
    its spatial positions (the channel maximum is subtracted before
    exponentiating). The same array that was passed in is returned.
    """
    for plane in scores:
        exps = np.exp(plane - plane.max())
        # Write back through the view so `scores` itself is updated.
        plane[...] = exps / np.sum(exps)
    return scores
def heatmaps_to_kes(maps, rois, num_kes=10):
    """
    Convert per-ROI heatmaps into ke coordinates in image space.

    A discrete heatmap coordinate is converted to a continuous ke
    coordinate with the conversion from Heckbert 1990: c = d + 0.5, where
    d is a discrete coordinate and c is a continuous coordinate (kept
    consistent with ke_to_heatmap_labels).

    Arguments:
        maps (Tensor): heatmaps in NCHW layout, one entry per ROI, with at
            least `num_kes` channels
        rois (Tensor): boxes indexed as x1, y1, x2, y2, one row per ROI
        num_kes (int): number of heatmap channels to decode (default 10)

    Returns:
        np.ndarray: float32 array of shape (len(rois), 4, num_kes); along
            axis 1 it holds x, y, the resized heatmap value at the peak,
            and the spatial-softmax probability at the peak
    """
    # detach()/cpu() so that CUDA tensors and tensors that require grad can
    # be converted; a bare .numpy() raises for both.
    rois = rois.detach().cpu().numpy()
    maps = maps.detach().cpu().numpy()

    offset_x = rois[:, 0]
    offset_y = rois[:, 1]

    # Clamp to at least one pixel so the resize target is never empty.
    widths = np.maximum(rois[:, 2] - rois[:, 0], 1)
    heights = np.maximum(rois[:, 3] - rois[:, 1], 1)
    widths_ceil = np.ceil(widths)
    heights_ceil = np.ceil(heights)

    # NCHW to NHWC for use with OpenCV
    maps = np.transpose(maps, [0, 2, 3, 1])

    xy_preds = np.zeros((len(rois), 4, num_kes), dtype=np.float32)
    for i in range(len(rois)):
        roi_map_width = int(widths_ceil[i])
        roi_map_height = int(heights_ceil[i])
        # The heatmap is resized to an integer size; these factors map the
        # integer grid back onto the real-valued box extent.
        width_correction = widths[i] / roi_map_width
        height_correction = heights[i] / roi_map_height

        roi_map = cv2.resize(
            maps[i], (roi_map_width, roi_map_height),
            interpolation=cv2.INTER_CUBIC)
        # Bring back to CHW
        roi_map = np.transpose(roi_map, [2, 0, 1])
        roi_map_probs = scores_to_probs(roi_map.copy())

        w = roi_map.shape[2]
        for k in range(num_kes):
            # argmax is taken over the flattened HxW plane.
            pos = roi_map[k, :, :].argmax()
            y_int, x_int = divmod(pos, w)
            assert (roi_map_probs[k, y_int, x_int] ==
                    roi_map_probs[k, :, :].max())
            x = (x_int + 0.5) * width_correction
            y = (y_int + 0.5) * height_correction
            xy_preds[i, 0, k] = x + offset_x[i]
            xy_preds[i, 1, k] = y + offset_y[i]
            xy_preds[i, 2, k] = roi_map[k, y_int, x_int]
            xy_preds[i, 3, k] = roi_map_probs[k, y_int, x_int]
    return xy_preds
class KEer(object):
    """
    Decodes the ke heatmaps of each image into coordinates, using the
    bounding boxes of that image as the reference frame.

    `threshold` and `padding` are stored on the instance but are not read
    by the methods below.
    """

    def __init__(self, threshold=0.5, padding=1):
        self.threshold = threshold
        self.padding = padding

    def forward_single_image(self, kes, boxes):
        """
        Arguments:
            kes (Tensor): ke heatmaps for the boxes of one image
            boxes (BoxList): the boxes of that image

        Returns:
            Tensor: the output of heatmaps_to_kes as a tensor, or an empty
                (0, 1, H, W) tensor when nothing was decoded
        """
        boxes = boxes.convert("xyxy")
        im_w, im_h = boxes.size
        decoded = heatmaps_to_kes(kes, boxes.bbox)
        if len(decoded) == 0:
            # NOTE(review): this fallback keeps the heatmap layout
            # (0, 1, H, W) rather than the layout of the decoded array.
            res = kes.new_empty((0, 1, kes.shape[-2], kes.shape[-1]))
        else:
            res = torch.from_numpy(decoded)
        print("res inference.py", res.size())
        return res

    def __call__(self, kes, boxes):
        """
        Apply forward_single_image to every (kes, boxes) pair.

        A single BoxList is accepted and treated as a one-image batch.
        Returns a list with one result per image.
        """
        if isinstance(boxes, BoxList):
            boxes = [boxes]

        # Make some sanity check
        assert len(boxes) == len(kes), "kes and boxes should have the same length."

        # TODO: Is this JIT compatible?
        # If not we should make it compatible.
        outputs = []
        for image_kes, image_boxes in zip(kes, boxes):
            assert image_kes.shape[0] == len(image_boxes), "Number of objects should be the same."
            outputs.append(self.forward_single_image(image_kes, image_boxes))
        return outputs
def make_roi_ke_post_processor(cfg):
    """
    Build the ke post-processor described by `cfg`.

    When cfg.MODEL.ROI_KE_HEAD.POSTPROCESS_KES is truthy, the returned
    KEPostProcessor wraps a KEer created with
    cfg.MODEL.ROI_KE_HEAD.POSTPROCESS_KES_THRESHOLD and padding=1;
    otherwise it is created without a keer.
    """
    keer = None
    if cfg.MODEL.ROI_KE_HEAD.POSTPROCESS_KES:
        keer = KEer(
            threshold=cfg.MODEL.ROI_KE_HEAD.POSTPROCESS_KES_THRESHOLD,
            padding=1,
        )
    return KEPostProcessor(keer)