chendl's picture
Add application file
0b7b08a
raw
history blame
6.51 kB
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.
import megengine.functional as F
import megengine.module as M
from .network_blocks import BaseConv, DWConv
def meshgrid(x, y):
    """Expand two 1-D tensors into a pair of 2-D coordinate grids (MegEngine helper).

    Both returned tensors have shape ``(len(y), len(x))``: the first repeats
    ``x`` along every row, the second repeats ``y`` down every column.

    Args:
        x: 1-D tensor whose values vary along the last axis of the result.
        y: 1-D tensor whose values vary along the first axis of the result.

    Returns:
        tuple: ``(grid_x, grid_y)``, each of shape ``(len(y), len(x))``.
    """
    assert len(x.shape) == 1
    assert len(y.shape) == 1
    n_rows, n_cols = y.shape[0], x.shape[0]
    target_shape = (n_rows, n_cols)
    grid_x = F.broadcast_to(x, target_shape)
    # y becomes a column vector so that broadcasting replicates it across columns.
    y_column = y.reshape(-1, 1)
    grid_y = F.broadcast_to(y_column, target_shape)
    return grid_x, grid_y
class YOLOXHead(M.Module):
    """Decoupled YOLOX detection head (inference-only MegEngine implementation).

    For every input feature level the head applies a 1x1 stem, then two
    separate 3x3 conv towers: one feeding the class prediction, the other
    feeding the box-regression and objectness predictions.
    """

    def __init__(
        self, num_classes, width=1.0, strides=[8, 16, 32],
        in_channels=[256, 512, 1024], act="silu", depthwise=False
    ):
        """
        Args:
            num_classes (int): number of class scores predicted per location.
            width (float): multiplier applied to every channel count. Default value: 1.0.
            strides (list[int]): stride used to decode each level. Default value: [8, 16, 32].
            in_channels (list[int]): unscaled input channels of each level.
                Default value: [256, 512, 1024].
            act (str): activation type of conv. Default value: "silu".
            depthwise (bool): whether apply depthwise conv in conv branch. Default value: False.
        """
        super().__init__()

        self.n_anchors = 1
        self.num_classes = num_classes
        self.decode_in_inference = True  # save for matching

        self.cls_convs = []
        self.reg_convs = []
        self.cls_preds = []
        self.reg_preds = []
        self.obj_preds = []
        self.stems = []

        Conv = DWConv if depthwise else BaseConv
        hidden = int(256 * width)

        def conv_tower():
            # Two stacked 3x3 convs that keep the channel count unchanged.
            return M.Sequential(
                Conv(in_channels=hidden, out_channels=hidden, ksize=3, stride=1, act=act),
                Conv(in_channels=hidden, out_channels=hidden, ksize=3, stride=1, act=act),
            )

        def pred_conv(out_channels):
            # 1x1 conv projecting tower features to raw predictions.
            return M.Conv2d(
                in_channels=hidden,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )

        # Modules are created in the same per-level order as before:
        # stem, cls tower, reg tower, cls pred, reg pred, obj pred.
        for level_channels in in_channels:
            self.stems.append(
                BaseConv(
                    in_channels=int(level_channels * width),
                    out_channels=hidden,
                    ksize=1,
                    stride=1,
                    act=act,
                )
            )
            self.cls_convs.append(conv_tower())
            self.reg_convs.append(conv_tower())
            self.cls_preds.append(pred_conv(self.n_anchors * self.num_classes))
            self.reg_preds.append(pred_conv(4))
            self.obj_preds.append(pred_conv(self.n_anchors * 1))

        self.use_l1 = False
        # Copy so the instance never aliases the (mutable) default argument.
        self.strides = list(strides)
        # Placeholder grids, replaced lazily by get_output_and_grid.
        self.grids = [F.zeros(1)] * len(in_channels)

    def forward(self, xin, labels=None, imgs=None):
        """Run the head on the per-level features ``xin``.

        Args:
            xin: sequence of feature maps, one per level.
            labels: unused by this implementation.
            imgs: unused by this implementation.

        Returns:
            Tensor of shape [batch, n_anchors_all, 5 + num_classes] (with
            ``n_anchors == 1``); box entries are decoded through
            ``decode_outputs`` when ``self.decode_in_inference`` is true.
        """
        outputs = []
        assert not self.training, "YOLOXHead.forward only supports inference mode"

        # self.strides stays in the zip so the number of processed levels is
        # bounded exactly as before, even though the stride is not used here.
        for k, (cls_conv, reg_conv, _stride, x) in enumerate(
            zip(self.cls_convs, self.reg_convs, self.strides, xin)
        ):
            x = self.stems[k](x)

            cls_feat = cls_conv(x)
            cls_output = self.cls_preds[k](cls_feat)

            reg_feat = reg_conv(x)
            reg_output = self.reg_preds[k](reg_feat)
            obj_output = self.obj_preds[k](reg_feat)

            # Channel layout: 4 raw box values, objectness, class scores.
            output = F.concat([reg_output, F.sigmoid(obj_output), F.sigmoid(cls_output)], 1)
            outputs.append(output)

        # Remember each level's (height, width) for decode_outputs.
        self.hw = [x.shape[-2:] for x in outputs]
        # [batch, n_anchors_all, 5 + num_classes]
        outputs = F.concat([F.flatten(x, start_axis=2) for x in outputs], axis=2)
        outputs = F.transpose(outputs, (0, 2, 1))
        if self.decode_in_inference:
            return self.decode_outputs(outputs)
        else:
            return outputs

    def get_output_and_grid(self, output, k, stride, dtype):
        """Decode one level's raw output and return it with its cell grid.

        Args:
            output: raw level output; reshaped here as
                (batch, n_anchors * (5 + num_classes), height, width).
            k (int): level index, used to cache the grid in ``self.grids``.
            stride: scale applied to the decoded centre and size values.
            dtype: dtype the grid is cast to (passed to ``Tensor.astype``).

        Returns:
            tuple: ``(output, grid)`` where ``output`` has shape
            (batch, n_anchors * height * width, 5 + num_classes) and ``grid``
            has shape (1, height * width, 2) holding (x, y) cell indices.
        """
        grid = self.grids[k]

        batch_size = output.shape[0]
        n_ch = 5 + self.num_classes
        hsize, wsize = output.shape[-2:]
        if grid.shape[2:4] != output.shape[2:4]:
            # x runs over the width, y over the height -> grids of shape (h, w).
            xv, yv = meshgrid(F.arange(wsize), F.arange(hsize))
            grid = F.stack((xv, yv), 2).reshape(1, 1, hsize, wsize, 2).astype(dtype)
            self.grids[k] = grid

        output = output.reshape(batch_size, self.n_anchors, n_ch, hsize, wsize)
        output = F.transpose(output, (0, 1, 3, 4, 2)).reshape(
            batch_size, self.n_anchors * hsize * wsize, -1
        )
        grid = grid.reshape(1, -1, 2)
        output[..., :2] = (output[..., :2] + grid) * stride
        output[..., 2:4] = F.exp(output[..., 2:4]) * stride
        return output, grid

    def decode_outputs(self, outputs):
        """Convert raw box predictions to scaled centre and size values, in place.

        Uses the per-level sizes stored in ``self.hw`` by ``forward``.

        Args:
            outputs: tensor of shape [batch, n_anchors_all, 5 + num_classes].

        Returns:
            The same tensor with ``[..., :2]`` replaced by
            ``(value + grid) * stride`` and ``[..., 2:4]`` replaced by
            ``exp(value) * stride``.
        """
        grids = []
        strides = []
        for (hsize, wsize), stride in zip(self.hw, self.strides):
            # The width goes first: meshgrid(x, y) yields (len(y), len(x)) grids,
            # so this produces (hsize, wsize) grids whose row-major flattening
            # matches the flattened feature map, including non-square maps.
            xv, yv = meshgrid(F.arange(wsize), F.arange(hsize))
            grid = F.stack((xv, yv), 2).reshape(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
            strides.append(F.full((*shape, 1), stride))

        grids = F.concat(grids, axis=1)
        strides = F.concat(strides, axis=1)

        outputs[..., :2] = (outputs[..., :2] + grids) * strides
        outputs[..., 2:4] = F.exp(outputs[..., 2:4]) * strides
        return outputs