Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# -*- coding:utf-8 -*- | |
# Copyright (c) Megvii Inc. All rights reserved. | |
import megengine.functional as F | |
import megengine.module as M | |
from .network_blocks import BaseConv, DWConv | |
def meshgrid(x, y):
    """Build 2-D coordinate matrices from two 1-D MegEngine tensors.

    Behaves like ``numpy.meshgrid(x, y)`` with the default ``"xy"`` indexing:
    both returned tensors have shape ``(len(y), len(x))``, where
    ``mesh_x[i, j] == x[j]`` and ``mesh_y[i, j] == y[i]``.

    Args:
        x: 1-D tensor of column coordinates.
        y: 1-D tensor of row coordinates.

    Returns:
        Tuple ``(mesh_x, mesh_y)`` of broadcast tensors.
    """
    assert len(x.shape) == 1
    assert len(y.shape) == 1
    n_rows, n_cols = y.shape[0], x.shape[0]
    target_shape = (n_rows, n_cols)
    # x varies along the columns: a 1-D vector broadcasts as a repeated row.
    grid_x = F.broadcast_to(x, target_shape)
    # y varies along the rows: turn it into a column vector first.
    y_column = y.reshape(-1, 1)
    grid_y = F.broadcast_to(y_column, target_shape)
    return grid_x, grid_y
class YOLOXHead(M.Module):
    """YOLOX decoupled detection head (MegEngine port, inference path).

    For every input feature level the head applies a 1x1 stem conv, then two
    separate 3x3 conv towers: one feeds the classification predictor, the
    other feeds the box-regression and objectness predictors.
    """

    def __init__(
        self, num_classes, width=1.0, strides=[8, 16, 32],
        in_channels=[256, 512, 1024], act="silu", depthwise=False
    ):
        """
        Args:
            num_classes (int): number of object classes predicted per anchor.
            width (float): channel multiplier applied to ``in_channels`` and to
                the head's 256-channel hidden width. Default value: 1.0.
            strides (Sequence[int]): stride of each input feature level; used to
                scale grid offsets and box sizes back to input-image pixels.
                Default value: [8, 16, 32].
            in_channels (Sequence[int]): channel count of each input feature
                level before scaling by ``width``. Default value: [256, 512, 1024].
            act (str): activation type of conv. Default value: "silu".
            depthwise (bool): whether to apply depthwise conv in the conv
                branches. Default value: False.
        """
        super().__init__()
        self.n_anchors = 1
        self.num_classes = num_classes
        self.decode_in_inference = True  # save for matching

        # One entry per feature level in each list.
        self.cls_convs = []
        self.reg_convs = []
        self.cls_preds = []
        self.reg_preds = []
        self.obj_preds = []
        self.stems = []
        Conv = DWConv if depthwise else BaseConv
        hidden = int(256 * width)

        def conv_tower():
            # Two stacked 3x3 convs at the hidden width (same structure and
            # construction order as before).
            return M.Sequential(
                *[
                    Conv(
                        in_channels=hidden,
                        out_channels=hidden,
                        ksize=3,
                        stride=1,
                        act=act,
                    )
                    for _ in range(2)
                ]
            )

        for level_channels in in_channels:
            self.stems.append(
                BaseConv(
                    in_channels=int(level_channels * width),
                    out_channels=hidden,
                    ksize=1,
                    stride=1,
                    act=act,
                )
            )
            self.cls_convs.append(conv_tower())
            self.reg_convs.append(conv_tower())
            self.cls_preds.append(
                M.Conv2d(
                    in_channels=hidden,
                    out_channels=self.n_anchors * self.num_classes,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
            )
            self.reg_preds.append(
                M.Conv2d(
                    in_channels=hidden,
                    out_channels=4,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
            )
            self.obj_preds.append(
                M.Conv2d(
                    in_channels=hidden,
                    out_channels=self.n_anchors * 1,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
            )

        self.use_l1 = False
        # Copy so the caller's list (or the shared default list) is never aliased.
        self.strides = list(strides)
        # Per-level grid cache used by get_output_and_grid; the 1-element
        # placeholders are replaced once a level's spatial size is known.
        self.grids = [F.zeros(1)] * len(in_channels)

    def forward(self, xin, labels=None, imgs=None):
        """Run the head on one feature map per level (inference only).

        Args:
            xin: sequence of feature maps, one per level, in the same order as
                ``strides`` (assumed NCHW -- the last two axes are used as H, W).
            labels: unused; kept for interface compatibility.
            imgs: unused; kept for interface compatibility.

        Returns:
            Tensor of shape (batch, n_anchors_all, 5 + num_classes) laid out as
            [reg(4), sigmoid(obj), sigmoid(cls)...] per anchor. If
            ``decode_in_inference`` is true, the reg part is decoded to
            input-image coordinates by ``decode_outputs``.
        """
        outputs = []
        assert not self.training
        # self.strides stays in the zip so iteration stops at the shortest of
        # the per-level sequences, exactly as before.
        for k, (cls_conv, reg_conv, _stride, x) in enumerate(
            zip(self.cls_convs, self.reg_convs, self.strides, xin)
        ):
            x = self.stems[k](x)

            cls_feat = cls_conv(x)
            cls_output = self.cls_preds[k](cls_feat)

            reg_feat = reg_conv(x)
            reg_output = self.reg_preds[k](reg_feat)
            obj_output = self.obj_preds[k](reg_feat)

            output = F.concat(
                [reg_output, F.sigmoid(obj_output), F.sigmoid(cls_output)], 1
            )
            outputs.append(output)

        # Remember each level's (H, W) so decode_outputs can rebuild the grids.
        self.hw = [out.shape[-2:] for out in outputs]
        # [batch, n_anchors_all, 5 + num_classes]
        outputs = F.concat([F.flatten(out, start_axis=2) for out in outputs], axis=2)
        outputs = F.transpose(outputs, (0, 2, 1))
        if self.decode_in_inference:
            return self.decode_outputs(outputs)
        else:
            return outputs

    def get_output_and_grid(self, output, k, stride, dtype):
        """Flatten one level's raw output and decode its boxes with a cached grid.

        Not called by ``forward``.

        Args:
            output: tensor of shape (batch, n_anchors * (5 + num_classes), H, W).
            k: feature-level index into ``self.grids``.
            stride: stride of this level.
            dtype: dtype the grid is cast to.

        Returns:
            Tuple ``(output, grid)``:
            - ``output`` has shape (batch, n_anchors * H * W, 5 + num_classes),
              with xy offset by the cell coordinates and wh exponentiated, both
              multiplied by ``stride``.
            - ``grid`` has shape (1, H * W, 2) and holds (x, y) cell indices.
        """
        grid = self.grids[k]

        batch_size = output.shape[0]
        n_ch = 5 + self.num_classes
        hsize, wsize = output.shape[-2:]
        if grid.shape[2:4] != output.shape[2:4]:
            # x runs along the width and y along the height, so the mesh has
            # shape (H, W); grid[h, w] == (w, h).
            xv, yv = meshgrid(F.arange(wsize), F.arange(hsize))
            grid = F.stack((xv, yv), 2).reshape(1, 1, hsize, wsize, 2).astype(dtype)
            self.grids[k] = grid

        output = output.reshape(batch_size, self.n_anchors, n_ch, hsize, wsize)
        output = F.transpose(output, (0, 1, 3, 4, 2)).reshape(
            batch_size, self.n_anchors * hsize * wsize, -1
        )
        grid = grid.reshape(1, -1, 2)
        output[..., :2] = (output[..., :2] + grid) * stride
        output[..., 2:4] = F.exp(output[..., 2:4]) * stride
        return output, grid

    def decode_outputs(self, outputs):
        """Decode flattened predictions into input-image coordinates, in place.

        Args:
            outputs: tensor of shape (batch, n_anchors_all, 5 + num_classes),
                with anchors ordered level by level and row-major within each
                level, as produced by ``forward``.

        Returns:
            The same tensor, with:
            - ``[..., :2]`` replaced by ``(offset + cell_xy) * stride``;
            - ``[..., 2:4]`` replaced by ``exp(wh) * stride``.
        """
        grids = []
        strides = []
        for (hsize, wsize), stride in zip(self.hw, self.strides):
            # Mesh of shape (H, W) with xv[h, w] == w and yv[h, w] == h. Its
            # row-major flattening matches the (H, W) flattening in forward,
            # which also holds for non-square feature maps.
            xv, yv = meshgrid(F.arange(wsize), F.arange(hsize))
            grid = F.stack((xv, yv), 2).reshape(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
            strides.append(F.full((*shape, 1), stride))

        grids = F.concat(grids, axis=1)
        strides = F.concat(strides, axis=1)

        outputs[..., :2] = (outputs[..., :2] + grids) * strides
        outputs[..., 2:4] = F.exp(outputs[..., 2:4]) * strides
        return outputs