# stable-diffusion-webui-depthmap-script / lib / spvcnn_classsification.py
import torch.nn as nn
import torchsparse.nn as spnn
from torchsparse.point_tensor import PointTensor
from lib.spvcnn_utils import *
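# The wildcard import above is expected to provide the sparse <-> point helpers
# used in forward(): initial_voxelize, point_to_voxel and voxel_to_point.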
__all__ = ['SPVCNN_CLASSIFICATION']
class BasicConvolutionBlock(nn.Module):
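    """Sparse 3D convolution followed by batch normalization and ReLU."""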
def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
super().__init__()
self.net = nn.Sequential(
spnn.Conv3d(inc,
outc,
kernel_size=ks,
dilation=dilation,
stride=stride),
spnn.BatchNorm(outc),
spnn.ReLU(True))
def forward(self, x):
out = self.net(x)
return out
class BasicDeconvolutionBlock(nn.Module):
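    """Transposed (upsampling) sparse 3D convolution followed by batch
    normalization and ReLU."""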
def __init__(self, inc, outc, ks=3, stride=1):
super().__init__()
self.net = nn.Sequential(
spnn.Conv3d(inc,
outc,
kernel_size=ks,
stride=stride,
transpose=True),
spnn.BatchNorm(outc),
spnn.ReLU(True))
def forward(self, x):
return self.net(x)
class ResidualBlock(nn.Module):
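    """Two sparse 3D convolutions with a residual connection; the skip path
    uses a 1x1x1 convolution whenever the channel count or stride changes."""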
def __init__(self, inc, outc, ks=3, stride=1, dilation=1):
super().__init__()
self.net = nn.Sequential(
spnn.Conv3d(inc,
outc,
kernel_size=ks,
dilation=dilation,
stride=stride), spnn.BatchNorm(outc),
spnn.ReLU(True),
spnn.Conv3d(outc,
outc,
kernel_size=ks,
dilation=dilation,
stride=1),
spnn.BatchNorm(outc)
)
self.downsample = nn.Sequential() if (inc == outc and stride == 1) else \
nn.Sequential(
spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride),
spnn.BatchNorm(outc)
)
self.relu = spnn.ReLU(True)
def forward(self, x):
out = self.relu(self.net(x) + self.downsample(x))
return out
class SPVCNN_CLASSIFICATION(nn.Module):
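    """SPVCNN backbone for point-cloud classification: a sparse convolutional
    stem, four strided residual stages, global average pooling and a linear
    classifier.

    Expected kwargs: input_channel, num_classes, and optionally cr (channel
    multiplier, default 1.0) plus pres / vres (point and voxel resolutions
    used by forward()).
    """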
def __init__(self, **kwargs):
super().__init__()
cr = kwargs.get('cr', 1.0)
cs = [32, 32, 64, 128, 256, 256, 128, 96, 96]
cs = [int(cr * x) for x in cs]
        # Point and voxel resolutions used by initial_voxelize() in forward();
        # they must be supplied for the model to run.
        if 'pres' in kwargs and 'vres' in kwargs:
            self.pres = kwargs['pres']
            self.vres = kwargs['vres']
self.stem = nn.Sequential(
spnn.Conv3d(kwargs['input_channel'], cs[0], kernel_size=3, stride=1),
spnn.BatchNorm(cs[0]),
spnn.ReLU(True),
spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1),
spnn.BatchNorm(cs[0]),
spnn.ReLU(True))
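        # Four encoder stages: each halves the spatial resolution with a strided
        # sparse convolution, then applies two residual blocks.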
self.stage1 = nn.Sequential(
BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1),
ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1),
ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1),
)
self.stage2 = nn.Sequential(
BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1),
ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1),
ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1),
)
self.stage3 = nn.Sequential(
BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1),
ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1),
ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1),
)
self.stage4 = nn.Sequential(
BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1),
ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1),
ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1),
)
self.avg_pool = spnn.GlobalAveragePooling()
self.classifier = nn.Sequential(nn.Linear(cs[4], kwargs['num_classes']))
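        # Lifts stem-level point features (cs[0] channels) to the stage-4 width
        # (cs[4]) so they can be added to the final voxel features in forward().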
self.point_transforms = nn.ModuleList([
nn.Sequential(
nn.Linear(cs[0], cs[4]),
nn.BatchNorm1d(cs[4]),
nn.ReLU(True),
),
])
self.weight_initialization()
        # Note: this dropout layer is defined but never applied in forward().
        self.dropout = nn.Dropout(0.3, True)
def weight_initialization(self):
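        # Only the dense BatchNorm1d layers (inside point_transforms) are
        # re-initialized; the sparse spnn.BatchNorm layers keep their defaults.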
for m in self.modules():
if isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x):
        # x: SparseTensor (voxels), z: PointTensor (raw points)
        z = PointTensor(x.F, x.C.float())

        # Voxelize the points, run the stem, then read the features back to points.
        x0 = initial_voxelize(z, self.pres, self.vres)
        x0 = self.stem(x0)
        z0 = voxel_to_point(x0, z, nearest=False)

        # Re-voxelize the stem point features and run the four encoder stages.
        x1 = point_to_voxel(x0, z0)
x1 = self.stage1(x1)
x2 = self.stage2(x1)
x3 = self.stage3(x2)
x4 = self.stage4(x3)
        # Project stage-4 voxel features back to points, fuse them with the
        # transformed stem features, then pool globally and classify.
        z1 = voxel_to_point(x4, z0)
        z1.F = z1.F + self.point_transforms[0](z0.F)
        y1 = point_to_voxel(x4, z1)
        pool = self.avg_pool(y1)
        out = self.classifier(pool)
return out
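

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module. The channel
    # count, class count, resolutions and the coordinate layout assumed below
    # ((x, y, z, batch_idx) with the batch index last) are illustrative
    # assumptions that depend on the installed torchsparse version and on how
    # lib.spvcnn_utils builds its sparse tensors.
    import torch
    from torchsparse import SparseTensor

    model = SPVCNN_CLASSIFICATION(input_channel=4, num_classes=10,
                                  cr=1.0, pres=0.05, vres=0.05)
    coords = torch.cat([torch.randint(0, 100, (1024, 3)),
                        torch.zeros(1024, 1, dtype=torch.long)], dim=1).int()
    feats = torch.rand(1024, 4)
    logits = model(SparseTensor(feats, coords))
    print(logits.shape)  # expected: (1, num_classes)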