#!/usr/bin/python
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Functions for computing image layouts from object vectors, bounding boxes,
and segmentation masks. These are used to compute coarse scene layouts which
are then fed as input to the cascaded refinement network.
"""
import torch
import torch.nn.functional as F

from . import box_utils

def boxes_to_layout(vecs, boxes, obj_to_img, H, W=None, pooling='sum'):
    """
    Inputs:
    - vecs: Tensor of shape (O, D) giving vectors
    - boxes: Tensor of shape (O, 4) giving bounding boxes in the format
      [x0, y0, x1, y1] in the [0, 1] coordinate space
    - obj_to_img: LongTensor of shape (O,) mapping each element of vecs to
      an image, where each element is in the range [0, N). If obj_to_img[i] = j
      then vecs[i] belongs to image j.
    - H, W: Size of the output

    Returns:
    - out: Tensor of shape (N, D, H, W)
    """
    O, D = vecs.size()
    if W is None:
        W = H

    grid = _boxes_to_grid(boxes, H, W)

    # If we don't add extra spatial dimensions here then out-of-bounds
    # elements won't be automatically set to 0
    img_in = vecs.view(O, D, 1, 1).expand(O, D, 8, 8)
    # align_corners=True matches the grid_sample behavior this code was
    # written against (the default before PyTorch 1.3).
    sampled = F.grid_sample(img_in, grid, align_corners=True)  # (O, D, H, W)

    # Explicitly masking makes everything quite a bit slower.
    # If we rely on implicit masking the interpolated boxes end up
    # blurred around the edges, but it should be fine.
    # mask = ((X < 0) + (X > 1) + (Y < 0) + (Y > 1)).clamp(max=1)
    # sampled[mask[:, None]] = 0

    out = _pool_samples(sampled, obj_to_img, pooling=pooling)
    return out

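# Example (a sketch; the shapes below are illustrative, not from the original):
# with O = 3 objects across N = 2 images, each carrying a D = 64 vector,
#   vecs: (3, 64), boxes: (3, 4), obj_to_img: LongTensor([0, 0, 1])
#   boxes_to_layout(vecs, boxes, obj_to_img, H=64)  # -> (2, 64, 64, 64)
# each object's vector is painted uniformly inside its box, and the results
# are then pooled per image.
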
def masks_to_layout(vecs, boxes, masks, obj_to_img, H, W=None, pooling='sum'):
    """
    Inputs:
    - vecs: Tensor of shape (O, D) giving vectors
    - boxes: Tensor of shape (O, 4) giving bounding boxes in the format
      [x0, y0, x1, y1] in the [0, 1] coordinate space
    - masks: Tensor of shape (O, M, M) giving binary masks for each object
    - obj_to_img: LongTensor of shape (O,) mapping objects to images
    - H, W: Size of the output image.

    Returns:
    - out: Tensor of shape (N, D, H, W)
    """
    O, D = vecs.size()
    M = masks.size(1)
    assert masks.size() == (O, M, M)
    if W is None:
        W = H

    grid = _boxes_to_grid(boxes, H, W)

    # Gate each object's vector spatially by its mask before warping.
    img_in = vecs.view(O, D, 1, 1) * masks.float().view(O, 1, M, M)
    sampled = F.grid_sample(img_in, grid, align_corners=True)

    out = _pool_samples(sampled, obj_to_img, pooling=pooling)
    return out

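# Example (a sketch): if vecs[k] = [2.0, 5.0] and masks[k] is 1 at some cell,
# then img_in[k] holds [2.0, 5.0] at that cell and [0.0, 0.0] wherever the
# mask is 0, so only the masked region carries the object's vector into the
# layout.
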
def _boxes_to_grid(boxes, H, W):
    """
    Input:
    - boxes: FloatTensor of shape (O, 4) giving boxes; box_utils.centers_to_extents
      is applied first to obtain [x0, y0, x1, y1] extents in the [0, 1]
      coordinate space
    - H, W: Scalars giving size of output

    Returns:
    - grid: FloatTensor of shape (O, H, W, 2) suitable for passing to grid_sample
    """
    O = boxes.size(0)

    boxes = box_utils.centers_to_extents(boxes)
    boxes = boxes.view(O, 4, 1, 1)

    # All of these are (O, 1, 1)
    x0, y0 = boxes[:, 0], boxes[:, 1]
    ww, hh = boxes[:, 2] - x0, boxes[:, 3] - y0

    X = torch.linspace(0, 1, steps=W).view(1, 1, W).to(boxes)
    Y = torch.linspace(0, 1, steps=H).view(1, H, 1).to(boxes)

    X = (X - x0) / ww  # (O, 1, W)
    Y = (Y - y0) / hh  # (O, H, 1)

    # Stack does not broadcast its arguments so we need to expand explicitly
    X = X.expand(O, H, W)
    Y = Y.expand(O, H, W)
    grid = torch.stack([X, Y], dim=3)  # (O, H, W, 2)

    # Right now grid is in [0, 1] space; transform to [-1, 1]
    grid = grid.mul(2).sub(1)
    return grid

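# Worked example (a sketch, not from the original): for a post-conversion box
# [x0, y0, x1, y1] = [0.25, 0.25, 0.75, 0.75], the output pixel at normalized
# coordinates (0.5, 0.5) maps to
#   X = (0.5 - 0.25) / 0.5 = 0.5,  Y = 0.5
# in the box's local frame, i.e. grid value (0, 0) after the [0, 1] -> [-1, 1]
# transform, so grid_sample reads the center of the object's patch. Output
# pixels outside the box map outside [-1, 1] and are zero-padded away.
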
def _pool_samples(samples, obj_to_img, pooling='sum'):
    """
    Input:
    - samples: FloatTensor of shape (O, D, H, W)
    - obj_to_img: LongTensor of shape (O,) with each element in the range
      [0, N) mapping elements of samples to output images

    Output:
    - pooled: FloatTensor of shape (N, D, H, W)
    """
    dtype, device = samples.dtype, samples.device
    O, D, H, W = samples.size()
    N = obj_to_img.max().item() + 1

    # Use scatter_add to sum the sampled outputs for each image
    out = torch.zeros(N, D, H, W, dtype=dtype, device=device)
    idx = obj_to_img.view(O, 1, 1, 1).expand(O, D, H, W)

    if pooling == 'avg':
        # Divide each output mask by the number of objects; use scatter_add
        # again to count the number of objects per image.
        out = out.scatter_add(0, idx, samples)
        ones = torch.ones(O, dtype=dtype, device=device)
        obj_counts = torch.zeros(N, dtype=dtype, device=device)
        obj_counts = obj_counts.scatter_add(0, obj_to_img, ones)
        obj_counts = obj_counts.clamp(min=1)
        out = out / obj_counts.view(N, 1, 1, 1)
    elif pooling == 'max':
        # This branch assumes the objects of each image occupy a contiguous
        # range of obj_to_img, and that every image has at least one object.
        all_out = []
        obj_to_img_list = [i.item() for i in list(obj_to_img)]
        for i in range(N):
            start = obj_to_img_list.index(i)
            end = len(obj_to_img_list) - obj_to_img_list[::-1].index(i)
            all_out.append(torch.max(samples[start:end, :, :, :], dim=0)[0])
        out = torch.stack(all_out)
    elif pooling == 'sum':
        out = out.scatter_add(0, idx, samples)
    else:
        raise ValueError('Invalid pooling "%s"' % pooling)
    return out

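# Example (a sketch): with obj_to_img = [0, 0, 1] and pooling='sum', the first
# two sampled maps are added into output image 0 and the third is copied into
# image 1; pooling='avg' additionally divides image 0 by its object count of 2.
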
def masks_to_seg(boxes, masks, objs, obj_to_img, H, W=None, num_classes=15):
    """
    Inputs:
    - boxes: Tensor of shape (O, 4) giving bounding boxes in the format
      [x0, y0, x1, y1] in the [0, 1] coordinate space
    - masks: Tensor of shape (O, M, M) giving binary masks for each object
    - objs: LongTensor of shape (O,) giving the class index of each object
    - obj_to_img: LongTensor of shape (O,) mapping each object to an image,
      where each element is in the range [0, N). If obj_to_img[i] = j
      then object i belongs to image j.
    - H, W: Size of the output
    - num_classes: Number of object classes

    Returns:
    - seg: Tensor of shape (N, num_classes, H, W) giving soft segmentation maps
    """
    dtype, device = boxes.dtype, boxes.device
    O, D = boxes.size()
    M = masks.size(1)
    assert masks.size() == (O, M, M)
    if W is None:
        W = H
    N = obj_to_img.max().item() + 1

    grid = _boxes_to_grid(boxes, H, W)
    mask_sampled = F.grid_sample(masks.float().view(O, 1, M, M), grid,
                                 align_corners=True)

    # Accumulate each object's warped mask into the channel of its class.
    seg = torch.zeros(N, num_classes, H, W, device=device)
    for i in range(N):
        obj_to_i = (obj_to_img == i).nonzero().view(-1)
        for j in obj_to_i:
            obj = objs[j]
            seg[i, obj] = seg[i, obj] + mask_sampled[j]
    return seg

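# Hypothetical usage (the objs values below are illustrative, not from the
# original):
#   objs = torch.LongTensor([3, 7])                  # class index per object
#   seg = masks_to_seg(boxes, masks, objs, obj_to_img, H=64)
#   # seg: (N, num_classes, 64, 64); overlapping objects of the same class
#   # accumulate in the same channel.
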
def boxes_to_seg(boxes, objs, obj_to_img, H, W=None, num_classes=15):
    """
    Inputs:
    - boxes: Tensor of shape (O, 4) giving bounding boxes in the format
      [x0, y0, x1, y1] in the [0, 1] coordinate space
    - objs: LongTensor of shape (O,) giving the class index of each object
    - obj_to_img: LongTensor of shape (O,) mapping each object to an image,
      where each element is in the range [0, N). If obj_to_img[i] = j
      then object i belongs to image j.
    - H, W: Size of the output
    - num_classes: Number of object classes

    Returns:
    - seg: Tensor of shape (N, num_classes, H, W) giving soft segmentation maps
    """
    dtype, device = boxes.dtype, boxes.device
    O, D = boxes.size()
    if W is None:
        W = H
    N = obj_to_img.max().item() + 1

    grid = _boxes_to_grid(boxes, H, W)
    # Every object uses an all-ones mask, so each warped box fills its extent.
    mask_sampled = F.grid_sample(torch.ones(O, 1, 8, 8).to(boxes), grid,
                                 align_corners=True)

    # This loop assumes the objects of each image occupy a contiguous range
    # of obj_to_img.
    seg = torch.zeros(N, num_classes, H, W, device=device)
    obj_to_img_list = [i.item() for i in list(obj_to_img)]
    for i in range(N):
        start = obj_to_img_list.index(i)
        end = len(obj_to_img_list) - obj_to_img_list[::-1].index(i)
        for j in range(start, end):
            obj = objs[j]
            seg[i, obj] = seg[i, obj] + mask_sampled[j]
    return seg

if __name__ == '__main__':
    # Smoke test / visualization; assumes a CUDA device is available.
    vecs = torch.FloatTensor([
        [1, 0, 0], [0, 1, 0], [0, 0, 1],
        [1, 0, 0], [0, 1, 0], [0, 0, 1],
    ])
    boxes = torch.FloatTensor([
        [0.25, 0.125, 0.5, 0.875],
        [0, 0, 1, 0.25],
        [0.6125, 0, 0.875, 1],
        [0, 0.8, 1, 1.0],
        [0.25, 0.125, 0.5, 0.875],
        [0.6125, 0, 0.875, 1],
    ])
    obj_to_img = torch.LongTensor([0, 0, 0, 1, 1, 1])
    vecs, boxes = vecs.cuda(), boxes.cuda()
    obj_to_img = obj_to_img.cuda()
    out = boxes_to_layout(vecs, boxes, obj_to_img, 256, pooling='sum')

    from torchvision.utils import save_image
    save_image(out.data, 'out.png')
    masks = torch.FloatTensor([
        [
            [0, 0, 1, 0, 0],
            [0, 1, 1, 1, 0],
            [1, 1, 1, 1, 1],
            [0, 1, 1, 1, 0],
            [0, 0, 1, 0, 0],
        ],
        [
            [0, 0, 1, 0, 0],
            [0, 1, 0, 1, 0],
            [1, 0, 0, 0, 1],
            [0, 1, 0, 1, 0],
            [0, 0, 1, 0, 0],
        ],
        [
            [0, 0, 1, 0, 0],
            [0, 1, 1, 1, 0],
            [1, 1, 1, 1, 1],
            [0, 1, 1, 1, 0],
            [0, 0, 1, 0, 0],
        ],
        [
            [0, 0, 1, 0, 0],
            [0, 1, 1, 1, 0],
            [1, 1, 1, 1, 1],
            [0, 1, 1, 1, 0],
            [0, 0, 1, 0, 0],
        ],
        [
            [0, 0, 1, 0, 0],
            [0, 1, 1, 1, 0],
            [1, 1, 1, 1, 1],
            [0, 1, 1, 1, 0],
            [0, 0, 1, 0, 0],
        ],
        [
            [0, 0, 1, 0, 0],
            [0, 1, 1, 1, 0],
            [1, 1, 1, 1, 1],
            [0, 1, 1, 1, 0],
            [0, 0, 1, 0, 0],
        ],
    ])
    masks = masks.cuda()
    out = masks_to_layout(vecs, boxes, masks, obj_to_img, 256)
    save_image(out.data, 'out_masks.png')
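    # A hedged extension of the demo above (not in the original): visualize
    # boxes_to_seg with a hypothetical per-object class assignment.
    objs = torch.LongTensor([1, 2, 3, 1, 2, 3]).cuda()
    seg = boxes_to_seg(boxes, objs, obj_to_img, 256)
    # Collapse the class dimension so the (N, num_classes, H, W) tensor can be
    # saved as one grayscale image per sample.
    save_image(seg.sum(dim=1, keepdim=True).data, 'out_seg.png')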