Spaces:
Running
Running
#!/usr/bin/env python | |
# encoding: utf-8 | |
''' | |
@author: Xu Yan | |
@file: basic_blocks.py | |
@time: 2021/4/14 22:53 | |
''' | |
import torch.nn as nn | |
try: | |
import torchsparse.nn as spnn | |
except: | |
print('To install torchsparse 1.4.0, please refer to https://github.com/mit-han-lab/torchsparse/tree/74099d10a51c71c14318bce63d6421f698b24f24') | |
class BasicConvolutionBlock(nn.Module): | |
def __init__(self, inc, outc, ks=3, stride=1, dilation=1): | |
super().__init__() | |
self.net = nn.Sequential( | |
spnn.Conv3d( | |
inc, | |
outc, | |
kernel_size=ks, | |
dilation=dilation, | |
stride=stride), spnn.BatchNorm(outc), | |
spnn.ReLU(True)) | |
def forward(self, x): | |
out = self.net(x) | |
return out | |
class BasicDeconvolutionBlock(nn.Module): | |
def __init__(self, inc, outc, ks=3, stride=1): | |
super().__init__() | |
self.net = nn.Sequential( | |
spnn.Conv3d( | |
inc, | |
outc, | |
kernel_size=ks, | |
stride=stride, | |
transposed=True), | |
spnn.BatchNorm(outc), | |
spnn.ReLU(True)) | |
def forward(self, x): | |
return self.net(x) | |
class ResidualBlock(nn.Module): | |
def __init__(self, inc, outc, ks=3, stride=1, dilation=1): | |
super().__init__() | |
self.net = nn.Sequential( | |
spnn.Conv3d( | |
inc, | |
outc, | |
kernel_size=ks, | |
dilation=dilation, | |
stride=stride), spnn.BatchNorm(outc), | |
spnn.ReLU(True), | |
spnn.Conv3d( | |
outc, | |
outc, | |
kernel_size=ks, | |
dilation=dilation, | |
stride=1), | |
spnn.BatchNorm(outc)) | |
self.downsample = nn.Sequential() if (inc == outc and stride == 1) else \ | |
nn.Sequential( | |
spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride), | |
spnn.BatchNorm(outc) | |
) | |
self.ReLU = spnn.ReLU(True) | |
def forward(self, x): | |
out = self.ReLU(self.net(x) + self.downsample(x)) | |
return out | |