# Source: OneLLM / lib/pointnet2/pointnet2_modules.py
# Uploaded by csuhan via huggingface_hub ("Upload folder using huggingface_hub"), commit 8b54513, 6.34 kB.
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import pointnet2_utils
from . import pytorch_utils as pt_utils
class _PointnetSAModuleBase(nn.Module):
    """Shared forward pass for the PointNet++ set-abstraction (SA) layers.

    Subclasses fill in ``npoint``, ``groupers`` and ``mlps`` (one grouper/MLP
    pair per scale, paired by position) and may override ``pool_method``.
    """

    def __init__(self):
        super().__init__()
        self.npoint = None  # centroids to sample; None skips sampling in forward()
        self.groupers = None  # per-scale grouping modules, paired by position with self.mlps
        self.mlps = None  # per-scale shared MLPs applied to the grouped features
        self.pool_method = 'max_pool'  # 'max_pool' or 'avg_pool' over the nsample axis

    def forward(self, xyz: torch.Tensor, features: torch.Tensor = None,
                new_xyz: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Group, encode and pool local neighbourhoods at every scale.

        :param xyz: (B, N, 3) tensor of the xyz coordinates of the features
        :param features: tensor of the point descriptors, or None. Channel-first
            (B, C, N) is assumed, matching the (B, C, npoint, nsample) grouper
            output. NOTE(review): an earlier docstring said (B, N, C); confirm
            against the groupers in pointnet2_utils.
        :param new_xyz: optional (B, npoint, 3) centroids. When None and
            ``self.npoint`` is set, they are picked by furthest point sampling;
            when both are None, new_xyz stays None.
        :return:
            new_xyz: (B, npoint, 3) tensor of the new features' xyz (or None, see above)
            new_features: (B, \sum_k(mlps[k][-1]), npoint) tensor of the new descriptors,
                the per-scale outputs concatenated along the channel axis
        :raises NotImplementedError: if ``self.pool_method`` is not 'max_pool' or 'avg_pool'
        """
        # Resolve the pooling op once, so a bad setting fails before any grouping work.
        if self.pool_method == 'max_pool':
            pool = F.max_pool2d
        elif self.pool_method == 'avg_pool':
            pool = F.avg_pool2d
        else:
            raise NotImplementedError(f"unsupported pool_method: {self.pool_method!r}")

        if new_xyz is None and self.npoint is not None:
            # gather_operation is fed (B, 3, N) coordinates; transpose back to (B, npoint, 3).
            xyz_flipped = xyz.transpose(1, 2).contiguous()
            fps_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
            new_xyz = pointnet2_utils.gather_operation(xyz_flipped, fps_idx).transpose(1, 2).contiguous()

        new_features_list = []
        for grouper, mlp in zip(self.groupers, self.mlps):
            grouped = grouper(xyz, new_xyz, features)  # (B, C, npoint, nsample)
            grouped = mlp(grouped)  # (B, mlp[-1], npoint, nsample)
            pooled = pool(grouped, kernel_size=[1, grouped.size(3)])  # (B, mlp[-1], npoint, 1)
            new_features_list.append(pooled.squeeze(-1))  # (B, mlp[-1], npoint)
        return new_xyz, torch.cat(new_features_list, dim=1)
class PointnetSAModuleMSG(_PointnetSAModuleBase):
    """Pointnet set abstraction layer with multiscale grouping"""

    def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
                 use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
        """
        :param npoint: int, number of centroids to sample; if None, every scale
            uses pointnet2_utils.GroupAll instead of a radius query
        :param radii: list of float, list of radii to group with
        :param nsamples: list of int, number of samples in each ball query
        :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale.
            The caller's lists are left untouched; when use_xyz is True a copy of each
            spec has 3 added to its first entry.
        :param bn: whether to use batchnorm
        :param use_xyz: passed to the groupers; also widens each MLP's input by 3 channels
        :param pool_method: max_pool / avg_pool
        :param instance_norm: whether to use instance_norm
        :raises ValueError: if radii, nsamples and mlps have different lengths
        """
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not (len(radii) == len(nsamples) == len(mlps)):
            raise ValueError(
                f"radii, nsamples and mlps must have equal lengths, "
                f"got {len(radii)}, {len(nsamples)} and {len(mlps)}"
            )
        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for radius, nsample, spec in zip(radii, nsamples, mlps):
            self.groupers.append(
                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
                if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
            )
            # Copy before widening: editing the caller's list in place would add another
            # +3 every time a module is built from the same config.
            mlp_spec = list(spec)
            if use_xyz:
                mlp_spec[0] += 3  # the groupers contribute the 3 xyz channels
            self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
        self.pool_method = pool_method
class PointnetSAModule(PointnetSAModuleMSG):
    """Single-scale Pointnet set abstraction layer.

    Packs one (radius, nsample, mlp) triple into one-element lists and hands
    them to PointnetSAModuleMSG.
    """

    def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
                 bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
        """
        :param mlp: list of int, channel spec of the pointnet applied before pooling;
            the parent adds 3 to its first entry when use_xyz is True
        :param npoint: int, number of centroids to sample (None groups all points)
        :param radius: float, radius of the ball query
        :param nsample: int, number of samples in the ball query
        :param bn: whether to use batchnorm
        :param use_xyz: whether the groupers include xyz in the grouped features
        :param pool_method: max_pool / avg_pool
        :param instance_norm: whether to use instance_norm
        """
        single_scale = dict(radii=[radius], nsamples=[nsample], mlps=[mlp])
        super().__init__(
            npoint=npoint,
            bn=bn,
            use_xyz=use_xyz,
            pool_method=pool_method,
            instance_norm=instance_norm,
            **single_scale,
        )
class PointnetFPModule(nn.Module):
    r"""PointNet++ feature-propagation (FP) layer.

    Propagates the features of a "known" point set onto an "unknown" point set
    and refines the result with a shared MLP.
    """

    def __init__(self, *, mlp: List[int], bn: bool = True):
        """
        :param mlp: list of int, channel spec of the shared MLP applied after propagation
        :param bn: whether to use batchnorm
        """
        super().__init__()
        self.mlp = pt_utils.SharedMLP(mlp, bn=bn)

    def forward(
            self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
    ) -> torch.Tensor:
        """
        :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
        :param known: (B, m, 3) tensor of the xyz positions of the known features, or None
        :param unknow_feats: (B, C1, n) tensor of the features to be propagated to, or None
        :param known_feats: (B, C2, m) tensor of the features to be propagated
        :return:
            new_features: (B, mlp[-1], n) tensor of the features of the unknown points
        """
        if known is None:
            # No known positions: broadcast known_feats along its last axis to all n
            # unknown points (that axis is presumably size 1 here).
            batch_size, channels = known_feats.size(0), known_feats.size(1)
            propagated = known_feats.expand(batch_size, channels, unknown.size(1))
        else:
            # Inverse-distance weights over the neighbours returned by three_nn;
            # the 1e-8 keeps coincident points from dividing by zero.
            dist, idx = pointnet2_utils.three_nn(unknown, known)
            inv_dist = 1.0 / (dist + 1e-8)
            weight = inv_dist / inv_dist.sum(dim=2, keepdim=True)
            propagated = pointnet2_utils.three_interpolate(known_feats, idx, weight)

        # Skip-link concat: (B, C2 + C1, n) when unknow_feats is given.
        stacked = propagated if unknow_feats is None else torch.cat([propagated, unknow_feats], dim=1)
        # SharedMLP works on 4-D input, so add a trailing singleton axis and strip it again.
        return self.mlp(stacked.unsqueeze(-1)).squeeze(-1)
if __name__ == "__main__":
    # Library module: nothing to run when executed directly.
    ...