import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SinusoidalPositionEmbeddings(nn.Module):
    """Standard sinusoidal embeddings for diffusion timesteps (as in DDPM)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        # Log-spaced frequencies from 1 down to 1/10000.
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        # Outer product: (batch,) x (half_dim,) -> (batch, half_dim).
        embeddings = time[:, None] * embeddings[None, :]
        # Concatenate sin and cos channels to get the full `dim`-wide embedding.
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
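
# Hedged usage sketch (not part of the original file): `_demo_time_embeddings`
# is a hypothetical helper, and the batch size and embedding width below are
# illustrative assumptions. It shows the expected shapes: with dim=128 and 16
# integer timesteps, the output is (16, 128) -- 64 sine plus 64 cosine channels.
def _demo_time_embeddings():
    emb = SinusoidalPositionEmbeddings(dim=128)
    t = torch.arange(16)  # a batch of integer diffusion timesteps
    out = emb(t)
    assert out.shape == (16, 128)
    return out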

class DropoutSampler(torch.nn.Module):
    """Two-layer MLP prediction head whose dropout stays active in eval mode,
    so repeated forward passes give stochastic (MC-dropout-style) samples."""

    def __init__(self, num_features, num_outputs, dropout_rate=0.5):
        super().__init__()
        self.linear = nn.Linear(num_features, num_features)
        self.linear2 = nn.Linear(num_features, num_features)
        self.predict = nn.Linear(num_features, num_outputs)
        self.num_features = num_features
        self.num_outputs = num_outputs
        self.dropout_rate = dropout_rate

    def forward(self, x):
        x = F.relu(self.linear(x))
        if self.dropout_rate > 0:
            # F.dropout defaults to training=True, so dropout is applied even
            # when the module is in eval mode; that is what makes this a sampler.
            x = F.dropout(x, self.dropout_rate)
        x = F.relu(self.linear2(x))
        return self.predict(x)
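
# Hedged usage sketch (hypothetical helper and sizes, not from the original
# file): because dropout is active even under .eval(), repeated calls on the
# same input yield different draws; averaging them is the usual MC-dropout
# estimate of the prediction.
def _demo_dropout_sampler():
    sampler = DropoutSampler(num_features=64, num_outputs=10).eval()
    x = torch.randn(1, 64)
    draws = torch.stack([sampler(x) for _ in range(8)])  # (8, 1, 10) stochastic draws
    return draws.mean(dim=0)  # MC-dropout mean prediction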

class EncoderMLP(torch.nn.Module):
    """MLP encoder that optionally fuses a point (e.g. an xyz position) into
    the input feature before projecting down to `out_dim`."""

    def __init__(self, in_dim, out_dim, pt_dim=3, uses_pt=True):
        super().__init__()
        self.uses_pt = uses_pt
        self.output = out_dim
        # Small MLP that lifts the pt_dim point up to the feature width.
        self.encode_position = nn.Sequential(
            nn.Linear(pt_dim, in_dim),
            nn.LayerNorm(in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, in_dim),
            nn.LayerNorm(in_dim),
            nn.ReLU(),
        )
        # Concatenating the feature with the encoded point doubles the width.
        d_in = 2 * in_dim if self.uses_pt else in_dim
        d_hidden = 2 * self.output
        self.fc_block = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.LayerNorm(d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_hidden),
            nn.LayerNorm(d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, self.output),
        )

    def forward(self, x, pt=None):
        if self.uses_pt:
            if pt is None:
                raise RuntimeError('did not provide pt')
            y = self.encode_position(pt)
            x = torch.cat([x, y], dim=-1)
        return self.fc_block(x)
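
# Hedged usage sketch (hypothetical shapes, not from the original file): fuse a
# 256-d feature with a 3-d point and project down to 128 dimensions.
def _demo_encoder_mlp():
    enc = EncoderMLP(in_dim=256, out_dim=128, pt_dim=3, uses_pt=True)
    feats = torch.randn(4, 256)  # per-example features
    pts = torch.randn(4, 3)      # one xyz point per example
    out = enc(feats, pts)
    assert out.shape == (4, 128)
    return out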

class MeanEncoder(torch.nn.Module):
    """Trivial point-cloud encoder that returns the centroid of the points.
    The extra constructor arguments are accepted but unused."""

    def __init__(self, input_channels=3, use_xyz=True, output=512, scale=0.04, factor=1):
        super().__init__()
        self.uses_rgb = False
        self.dim = 3

    def forward(self, xyz, f=None):
        # Fix feature shape to (batch, channels, num_points); the features are
        # otherwise unused by this encoder.
        if f is not None:
            if len(f.shape) < 3:
                f = f.transpose(0, 1).contiguous()
                f = f[None]
            else:
                f = f.transpose(1, 2).contiguous()
        if len(xyz.shape) == 3:
            # (batch, num_points, 3) -> mean over the points.
            center = torch.mean(xyz, dim=1)
        elif len(xyz.shape) == 2:
            # (num_points, 3) -> mean over the points.
            center = torch.mean(xyz, dim=0)
        else:
            raise RuntimeError('not sure what to do with points of shape ' + str(xyz.shape))
        assert xyz.shape[-1] == 3
        assert center.shape[-1] == 3
        # The centroid is returned twice, filling both slots of a
        # (features, positions) encoder interface.
        return center, center
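
# Hedged usage sketch (hypothetical shapes, not from the original file): the
# encoder reduces each cloud to its centroid, returned in both output slots.
def _demo_mean_encoder():
    enc = MeanEncoder()
    cloud = torch.randn(2, 1024, 3)  # batch of two clouds, 1024 points each
    feats, center = enc(cloud)
    assert feats.shape == center.shape == (2, 3)
    return center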