|
import torch |
|
|
|
|
|
|
|
def decode_infer(output, stride, gt_per_grid=None, numclass=None):
    """Decode a raw detection-head map into image-space boxes and scores.

    ``output`` is read as a 4-D tensor ``(B, A * (5 + numclass), H, W)``
    where ``A`` is ``gt_per_grid``.  For every grid cell and each of its
    ``A`` predictions, the channel vector is split as
    ``[x1y1 (2), x2y2 (2), conf (1), class logits (numclass)]`` and decoded as::

        x1y1 = (cell_xy + 0.5 - exp(x1y1)) * stride
        x2y2 = (cell_xy + 0.5 + exp(x2y2)) * stride
        conf = sigmoid(conf)
        prob = sigmoid(prob)

    where ``cell_xy`` is the (column, row) index of the cell.

    Args:
        output: raw prediction tensor of shape
            ``(B, A * (5 + numclass), H, W)``; ``H`` and ``W`` may differ.
        stride: multiplier from grid units to output units (presumably the
            feature map's downsampling factor w.r.t. the input image).
        gt_per_grid: number of predictions per cell (``A``).
        numclass: number of class logits per prediction.
            At least one of ``gt_per_grid`` / ``numclass`` must be given;
            a missing one is derived from the channel count.

    Returns:
        Tensor of shape ``(B, H * W * A, 5 + numclass)`` holding
        ``[x1, y1, x2, y2, conf, class probs...]`` per prediction, ordered
        row-major over ``(H, W, A)``, on the same device as ``output``.

    Raises:
        TypeError: if neither ``gt_per_grid`` nor ``numclass`` is given.
        ValueError: if the channel count is not
            ``gt_per_grid * (5 + numclass)``.
    """
    bz, channels, height, width = output.shape

    # These two values used to be read from an undefined ``self``; they are
    # explicit arguments now, and a missing one is inferred from ``channels``.
    if gt_per_grid is None and numclass is None:
        raise TypeError("decode_infer() requires gt_per_grid and/or numclass")
    if gt_per_grid is None:
        gt_per_grid = channels // (5 + numclass)
    elif numclass is None:
        numclass = channels // gt_per_grid - 5
    n_attr = 5 + numclass
    if gt_per_grid <= 0 or numclass < 0 or channels != gt_per_grid * n_attr:
        raise ValueError(
            f"channel dim {channels} does not match gt_per_grid * (5 + numclass)"
            f" = {gt_per_grid} * {n_attr}"
        )

    # (B, A * n_attr, H, W) -> (B, H, W, A, n_attr)
    pred = output.permute(0, 2, 3, 1).reshape(
        bz, height, width, gt_per_grid, n_attr
    )
    x1y1, x2y2, conf, prob = torch.split(pred, [2, 2, 1, numclass], dim=4)

    # Cell offsets are created on output's device (not a hard-coded .cuda()),
    # so CPU and non-default-GPU inputs work.  With 'ij' indexing grid_y
    # varies along rows (H) and grid_x along columns (W).
    rows = torch.arange(height, dtype=torch.float32, device=output.device)
    cols = torch.arange(width, dtype=torch.float32, device=output.device)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing="ij")
    # (H, W, 1, 2) as (x, y); broadcasts over the batch and prediction dims.
    xy_grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(2)

    x1y1 = (xy_grid + 0.5 - torch.exp(x1y1)) * stride
    x2y2 = (xy_grid + 0.5 + torch.exp(x2y2)) * stride
    conf = torch.sigmoid(conf)
    prob = torch.sigmoid(prob)

    decoded = torch.cat((x1y1, x2y2, conf, prob), dim=4)
    # Explicit sizes (not -1) so an empty batch still reshapes cleanly.
    return decoded.reshape(bz, height * width * gt_per_grid, n_attr)
|
|