import torch.nn as nn
import torchsparse.nn as spnn
from torchsparse.point_tensor import PointTensor

# NOTE: the star import is expected to provide `initial_voxelize`,
# `voxel_to_point` and `point_to_voxel`, which are used in
# `SPVCNN_CLASSIFICATION.forward` below.
from lib.spvcnn_utils import *

__all__ = ['SPVCNN_CLASSIFICATION']


class BasicConvolutionBlock(nn.Module):
    """Sparse ``Conv3d`` -> ``BatchNorm`` -> ``ReLU``.

    Args:
        inc: Number of input channels.
        outc: Number of output channels.
        ks: Kernel size forwarded to ``spnn.Conv3d``.
        stride: Stride forwarded to ``spnn.Conv3d``.
        dilation: Dilation forwarded to ``spnn.Conv3d``.
    """

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        self.net = nn.Sequential(
            spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation,
                        stride=stride),
            spnn.BatchNorm(outc),
            spnn.ReLU(True),
        )

    def forward(self, x):
        """Apply the conv/norm/activation stack to ``x``."""
        return self.net(x)


class BasicDeconvolutionBlock(nn.Module):
    """Transposed sparse ``Conv3d`` -> ``BatchNorm`` -> ``ReLU``.

    Args:
        inc: Number of input channels.
        outc: Number of output channels.
        ks: Kernel size forwarded to ``spnn.Conv3d``.
        stride: Stride forwarded to ``spnn.Conv3d``.
    """

    def __init__(self, inc, outc, ks=3, stride=1):
        super().__init__()
        self.net = nn.Sequential(
            spnn.Conv3d(inc, outc, kernel_size=ks, stride=stride,
                        transpose=True),
            spnn.BatchNorm(outc),
            spnn.ReLU(True),
        )

    def forward(self, x):
        """Apply the transposed-conv/norm/activation stack to ``x``."""
        return self.net(x)


class ResidualBlock(nn.Module):
    """Two-convolution residual block with an optional projection shortcut.

    The main branch is ``Conv3d -> BatchNorm -> ReLU -> Conv3d -> BatchNorm``.
    The shortcut is an empty ``nn.Sequential`` when ``inc == outc`` and
    ``stride == 1``; otherwise it is a kernel-size-1 ``Conv3d`` followed by
    ``BatchNorm`` so that both branches produce ``outc`` channels.

    Args:
        inc: Number of input channels.
        outc: Number of output channels.
        ks: Kernel size of both main-branch convolutions.
        stride: Stride of the first main-branch convolution and of the
            projection shortcut (the second convolution always uses 1).
        dilation: Dilation of both main-branch convolutions.
    """

    def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
        super().__init__()
        self.net = nn.Sequential(
            spnn.Conv3d(inc, outc, kernel_size=ks, dilation=dilation,
                        stride=stride),
            spnn.BatchNorm(outc),
            spnn.ReLU(True),
            spnn.Conv3d(outc, outc, kernel_size=ks, dilation=dilation,
                        stride=1),
            spnn.BatchNorm(outc),
        )

        if inc == outc and stride == 1:
            # Shapes already match: empty Sequential as the shortcut.
            self.downsample = nn.Sequential()
        else:
            self.downsample = nn.Sequential(
                spnn.Conv3d(inc, outc, kernel_size=1, dilation=1,
                            stride=stride),
                spnn.BatchNorm(outc),
            )

        self.relu = spnn.ReLU(True)

    def forward(self, x):
        """Return ``relu(net(x) + downsample(x))``."""
        return self.relu(self.net(x) + self.downsample(x))


class SPVCNN_CLASSIFICATION(nn.Module):
    """Sparse point-voxel encoder with a global-pooling classification head.

    Keyword Args:
        input_channel: Number of input feature channels (required).
        num_classes: Output size of the final linear layer (required).
        cr: Channel ratio; every base channel width is multiplied by it and
            truncated with ``int``. Defaults to ``1.0``.
        pres: Stored as ``self.pres`` and passed to ``initial_voxelize``.
        vres: Stored as ``self.vres`` and passed to ``initial_voxelize``.

    ``pres`` and ``vres`` are stored only when *both* are supplied. If they
    are omitted at construction they must be assigned on the instance before
    ``forward`` is called.
    """

    def __init__(self, **kwargs):
        super().__init__()

        cr = kwargs.get('cr', 1.0)
        # Only cs[0]..cs[4] are used by this model; the trailing entries are
        # kept so the width table stays unchanged.
        cs = [32, 32, 64, 128, 256, 256, 128, 96, 96]
        cs = [int(cr * x) for x in cs]

        if 'pres' in kwargs and 'vres' in kwargs:
            self.pres = kwargs['pres']
            self.vres = kwargs['vres']

        self.stem = nn.Sequential(
            spnn.Conv3d(kwargs['input_channel'], cs[0], kernel_size=3,
                        stride=1),
            spnn.BatchNorm(cs[0]),
            spnn.ReLU(True),
            spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1),
            spnn.BatchNorm(cs[0]),
            spnn.ReLU(True),
        )

        # Four encoder stages, built in order so parameter creation order and
        # state-dict keys (stageN.0 / .1 / .2) are unchanged.
        self.stage1 = self._make_stage(cs[0], cs[1])
        self.stage2 = self._make_stage(cs[1], cs[2])
        self.stage3 = self._make_stage(cs[2], cs[3])
        self.stage4 = self._make_stage(cs[3], cs[4])

        self.avg_pool = spnn.GlobalAveragePooling()

        self.classifier = nn.Sequential(
            nn.Linear(cs[4], kwargs['num_classes']),
        )

        # Point-branch MLP lifting stem-width features (cs[0]) to the width
        # of the last stage (cs[4]) so they can be added in `forward`.
        self.point_transforms = nn.ModuleList([
            nn.Sequential(
                nn.Linear(cs[0], cs[4]),
                nn.BatchNorm1d(cs[4]),
                nn.ReLU(True),
            ),
        ])

        self.weight_initialization()
        # NOTE: this dropout module is created but never applied in
        # `forward`; it is kept so the attribute remains available.
        self.dropout = nn.Dropout(0.3, True)

    @staticmethod
    def _make_stage(inc, outc):
        """Build one encoder stage.

        The stage is a kernel-size-2, stride-2 ``BasicConvolutionBlock`` that
        keeps ``inc`` channels, followed by two ``ResidualBlock``s
        (``inc -> outc`` and ``outc -> outc``).
        """
        return nn.Sequential(
            BasicConvolutionBlock(inc, inc, ks=2, stride=2, dilation=1),
            ResidualBlock(inc, outc, ks=3, stride=1, dilation=1),
            ResidualBlock(outc, outc, ks=3, stride=1, dilation=1),
        )

    def weight_initialization(self):
        """Set weight to 1 and bias to 0 on every ``nn.BatchNorm1d`` submodule."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the network on ``x`` and return the classifier output.

        Args:
            x: Input exposing ``.F`` and ``.C`` (a SparseTensor according to
                the original author's note).

        Returns:
            The output of ``self.classifier`` applied to the globally
            average-pooled features of the last stage.

        Raises:
            AttributeError: If ``pres``/``vres`` were neither passed to the
                constructor nor assigned on the instance afterwards.
        """
        # Without this guard the lookup below fails with an opaque
        # "object has no attribute 'pres'" error; same exception type, but
        # with an actionable message.
        if not (hasattr(self, 'pres') and hasattr(self, 'vres')):
            raise AttributeError(
                "SPVCNN_CLASSIFICATION requires both 'pres' and 'vres'; pass "
                "them as keyword arguments to the constructor or assign "
                "them on the instance before calling forward()."
            )

        # x: SparseTensor z: PointTensor
        z = PointTensor(x.F, x.C.float())

        x0 = initial_voxelize(z, self.pres, self.vres)
        x0 = self.stem(x0)
        z0 = voxel_to_point(x0, z, nearest=False)

        x1 = point_to_voxel(x0, z0)
        x1 = self.stage1(x1)
        x2 = self.stage2(x1)
        x3 = self.stage3(x2)
        x4 = self.stage4(x3)

        # Fuse last-stage voxel features (gathered at the points) with the
        # transformed stem-level point features.
        z1 = voxel_to_point(x4, z0)
        z1.F = z1.F + self.point_transforms[0](z0.F)

        y1 = point_to_voxel(x4, z1)
        pool = self.avg_pool(y1)
        out = self.classifier(pool)

        return out