#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii Inc. All rights reserved.

import megengine.functional as F
import megengine.module as M

from .network_blocks import BaseConv, DWConv


def meshgrid(x, y):
    """meshgrid wrapper for megengine"""
    assert len(x.shape) == 1
    assert len(y.shape) == 1
    mesh_shape = (y.shape[0], x.shape[0])
    mesh_x = F.broadcast_to(x, mesh_shape)
    mesh_y = F.broadcast_to(y.reshape(-1, 1), mesh_shape)
    return mesh_x, mesh_y


class YOLOXHead(M.Module):
    def __init__(
        self, num_classes, width=1.0, strides=[8, 16, 32],
        in_channels=[256, 512, 1024], act="silu", depthwise=False
    ):
        """
        Args:
            act (str): activation type of conv. Default value: "silu".
            depthwise (bool): whether apply depthwise conv in conv branch. Default value: False.
        """
        super().__init__()

        self.n_anchors = 1
        self.num_classes = num_classes
        self.decode_in_inference = True  # save for matching

        self.cls_convs = []
        self.reg_convs = []
        self.cls_preds = []
        self.reg_preds = []
        self.obj_preds = []
        self.stems = []
        Conv = DWConv if depthwise else BaseConv

        for i in range(len(in_channels)):
            self.stems.append(
                BaseConv(
                    in_channels=int(in_channels[i] * width),
                    out_channels=int(256 * width),
                    ksize=1,
                    stride=1,
                    act=act,
                )
            )
            self.cls_convs.append(
                M.Sequential(
                    *[
                        Conv(
                            in_channels=int(256 * width),
                            out_channels=int(256 * width),
                            ksize=3,
                            stride=1,
                            act=act,
                        ),
                        Conv(
                            in_channels=int(256 * width),
                            out_channels=int(256 * width),
                            ksize=3,
                            stride=1,
                            act=act,
                        ),
                    ]
                )
            )
            self.reg_convs.append(
                M.Sequential(
                    *[
                        Conv(
                            in_channels=int(256 * width),
                            out_channels=int(256 * width),
                            ksize=3,
                            stride=1,
                            act=act,
                        ),
                        Conv(
                            in_channels=int(256 * width),
                            out_channels=int(256 * width),
                            ksize=3,
                            stride=1,
                            act=act,
                        ),
                    ]
                )
            )
            self.cls_preds.append(
                M.Conv2d(
                    in_channels=int(256 * width),
                    out_channels=self.n_anchors * self.num_classes,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
            )
            self.reg_preds.append(
                M.Conv2d(
                    in_channels=int(256 * width),
                    out_channels=4,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
            )
            self.obj_preds.append(
                M.Conv2d(
                    in_channels=int(256 * width),
                    out_channels=self.n_anchors * 1,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
            )

        self.use_l1 = False
        self.strides = strides
        self.grids = [F.zeros(1)] * len(in_channels)

    def forward(self, xin, labels=None, imgs=None):
        outputs = []
        assert not self.training  # this port only supports inference

        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
            zip(self.cls_convs, self.reg_convs, self.strides, xin)
        ):
            x = self.stems[k](x)
            cls_x = x
            reg_x = x

            cls_feat = cls_conv(cls_x)
            cls_output = self.cls_preds[k](cls_feat)
            reg_feat = reg_conv(reg_x)
            reg_output = self.reg_preds[k](reg_feat)
            obj_output = self.obj_preds[k](reg_feat)
            output = F.concat(
                [reg_output, F.sigmoid(obj_output), F.sigmoid(cls_output)], 1
            )
            outputs.append(output)

        self.hw = [x.shape[-2:] for x in outputs]
        # [batch, n_anchors_all, 85]
        outputs = F.concat([F.flatten(x, start_axis=2) for x in outputs], axis=2)
        outputs = F.transpose(outputs, (0, 2, 1))
        if self.decode_in_inference:
            return self.decode_outputs(outputs)
        else:
            return outputs

    def get_output_and_grid(self, output, k, stride, dtype):
        grid = self.grids[k]

        batch_size = output.shape[0]
        n_ch = 5 + self.num_classes
        hsize, wsize = output.shape[-2:]
        if grid.shape[2:4] != output.shape[2:4]:
            # x varies along the width axis, y along the height axis
            xv, yv = meshgrid(F.arange(wsize), F.arange(hsize))
            grid = F.stack((xv, yv), 2).reshape(1, 1, hsize, wsize, 2).astype(dtype)
            self.grids[k] = grid

        output = output.reshape(batch_size, self.n_anchors, n_ch, hsize, wsize)
        output = F.transpose(output, (0, 1, 3, 4, 2)).reshape(
            batch_size, self.n_anchors * hsize * wsize, -1
        )
        grid = grid.reshape(1, -1, 2)
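        # Decode the regression output in place:
        #   xy: predicted offsets are added to the grid cell position, then scaled by the stride
        #   wh: width/height are predicted in log-space, so exponentiate and scale by the stride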
        output[..., :2] = (output[..., :2] + grid) * stride
        output[..., 2:4] = F.exp(output[..., 2:4]) * stride
        return output, grid

    def decode_outputs(self, outputs):
        grids = []
        strides = []
        for (hsize, wsize), stride in zip(self.hw, self.strides):
            # x varies along the width axis, y along the height axis
            xv, yv = meshgrid(F.arange(wsize), F.arange(hsize))
            grid = F.stack((xv, yv), 2).reshape(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
            strides.append(F.full((*shape, 1), stride))

        grids = F.concat(grids, axis=1)
        strides = F.concat(strides, axis=1)

        outputs[..., :2] = (outputs[..., :2] + grids) * strides
        outputs[..., 2:4] = F.exp(outputs[..., 2:4]) * strides
        return outputs
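

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): builds a YOLOXHead and
# runs an inference forward pass on dummy FPN features for a 640x640 input.
# The feature shapes and the expected (1, 8400, 5 + num_classes) output are
# assumptions based on the default strides [8, 16, 32]; run it with the package
# on the import path (the module uses a relative import), e.g. via `python -m`.
if __name__ == "__main__":
    import numpy as np
    import megengine as mge

    head = YOLOXHead(num_classes=80)
    head.eval()  # the forward pass asserts inference mode

    # dummy FPN outputs: strides 8 / 16 / 32 on a 640x640 image
    feats = [
        mge.Tensor(np.random.randn(1, c, s, s).astype("float32"))
        for c, s in zip([256, 512, 1024], [80, 40, 20])
    ]
    preds = head(feats)
    # 80*80 + 40*40 + 20*20 = 8400 anchor points, 4 box + 1 obj + 80 cls channels
    print(preds.shape)  # expected: (1, 8400, 85)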