File size: 3,242 Bytes
8c02843
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SinusoidalPositionEmbeddings(nn.Module):
    """Fixed (parameter-free) sinusoidal embedding of scalar positions/timesteps.

    Each input value ``t`` is mapped to
    ``[sin(t*w_0), ..., sin(t*w_{h-1}), cos(t*w_0), ..., cos(t*w_{h-1})]``
    with ``h = dim // 2`` geometrically spaced frequencies
    ``w_i = 10000 ** (-i / (h - 1))``. When ``h == 1`` the single frequency is 1.

    Args:
        dim: requested embedding width. The output has ``2 * (dim // 2)``
            columns, i.e. ``dim - 1`` when ``dim`` is odd.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed ``time``.

        Args:
            time: tensor of positions, expected to be 1-D of shape ``(B,)``;
                it is indexed as ``time[:, None]``, so a 0-d tensor is rejected.

        Returns:
            Tensor of shape ``(B, 2 * (dim // 2))`` for 1-D input: sines first,
            then cosines.
        """
        half_dim = self.dim // 2
        # For dim 2 or 3 there is only one frequency and half_dim - 1 == 0;
        # clamping avoids ZeroDivisionError and yields w_0 = exp(0) = 1.
        denom = max(half_dim - 1, 1)
        scale = math.log(10000) / denom
        freqs = torch.exp(torch.arange(half_dim, device=time.device) * -scale)
        # Outer product of positions and frequencies: (B, 1) * (1, h) -> (B, h).
        args = time[:, None] * freqs[None, :]
        return torch.cat((args.sin(), args.cos()), dim=-1)


class DropoutSampler(torch.nn.Module):
    """Three-layer MLP head with dropout after the first hidden layer.

    Dropout is applied through ``F.dropout`` with ``training=True`` (its
    default), so it stays active even after ``.eval()``. Whenever
    ``dropout_rate > 0``, repeated forward passes on the same input give
    different outputs. Presumably this is intentional given the class name
    (sampling via dropout) — confirm with callers before changing it.

    Args:
        num_features: width of the input and of both hidden layers.
        num_outputs: width of the output.
        dropout_rate: dropout probability; a non-positive value disables dropout.
    """

    def __init__(self, num_features, num_outputs, dropout_rate=0.5):
        super().__init__()
        self.num_features = num_features
        self.num_outputs = num_outputs
        self.dropout_rate = dropout_rate
        # Layers are created in a fixed order so that seeded initialisation
        # and state_dict keys stay the same.
        self.linear = nn.Linear(num_features, num_features)
        self.linear2 = nn.Linear(num_features, num_features)
        self.predict = nn.Linear(num_features, num_outputs)

    def forward(self, x):
        """Map ``x`` (last dim ``num_features``) to ``num_outputs`` values."""
        hidden = F.relu(self.linear(x))
        if self.dropout_rate > 0:
            # training=True is F.dropout's default; written out so the
            # always-on behaviour is visible.
            hidden = F.dropout(hidden, p=self.dropout_rate, training=True)
        hidden = F.relu(self.linear2(hidden))
        return self.predict(hidden)


class EncoderMLP(torch.nn.Module):
    """MLP mapping features ``x``, optionally joined with an encoded point, to ``out_dim``.

    Args:
        in_dim: feature width of ``x``; also the width of the point encoding.
        out_dim: output width; both hidden layers are ``2 * out_dim`` wide.
        pt_dim: width of the raw ``pt`` input (default 3).
        uses_pt: if True, ``forward`` requires ``pt`` and concatenates its
            encoding to ``x`` before the MLP.
    """

    def __init__(self, in_dim, out_dim, pt_dim=3, uses_pt=True):
        super().__init__()
        self.uses_pt = uses_pt
        self.output = out_dim
        hidden_dim = int(2 * out_dim)
        fc_in_dim = int(2 * in_dim if uses_pt else in_dim)

        # NOTE: the point encoder is built even when uses_pt is False, so the
        # state_dict keys do not depend on uses_pt (layer order also fixes
        # seeded initialisation).
        self.encode_position = nn.Sequential(
            nn.Linear(pt_dim, in_dim),
            nn.LayerNorm(in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, in_dim),
            nn.LayerNorm(in_dim),
            nn.ReLU(),
        )
        self.fc_block = nn.Sequential(
            nn.Linear(fc_in_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x, pt=None):
        """Encode ``x`` (and ``pt`` when ``uses_pt``); ``pt`` is ignored otherwise.

        Raises:
            RuntimeError: if ``uses_pt`` is True and ``pt`` is None.
        """
        if not self.uses_pt:
            return self.fc_block(x)
        if pt is None:
            raise RuntimeError('did not provide pt')
        pos_code = self.encode_position(pt)
        return self.fc_block(torch.cat([x, pos_code], dim=-1))


class MeanEncoder(torch.nn.Module):
    """Parameter-free encoder that summarises a point set by its centroid.

    The constructor arguments are accepted but ignored. Presumably they keep
    this class call-compatible with other encoders it can be swapped for —
    verify against callers.
    """

    def __init__(self, input_channels=3, use_xyz=True, output=512, scale=0.04, factor=1):
        super().__init__()
        self.uses_rgb = False
        self.dim = 3  # width of the returned center (xyz coordinates)

    def forward(self, xyz, f=None):
        """Return the centroid of ``xyz``.

        Args:
            xyz: points of shape ``(N, 3)`` or ``(B, N, 3)``.
            f: optional per-point features; accepted for interface
                compatibility and ignored.

        Returns:
            ``(center, center)`` — the same tensor twice, of shape ``(3,)`` for
            unbatched input or ``(B, 3)`` for batched input.

        Raises:
            RuntimeError: if ``xyz`` is not 2-D or 3-D, or its last dimension
                is not 3.
        """
        # `f` used to be transposed/reshaped here but the result was never
        # used. That dead code only crashed on 1-D features, so it was removed.
        if xyz.dim() not in (2, 3):
            raise RuntimeError('not sure what to do with points of shape ' + str(xyz.shape))
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if xyz.shape[-1] != 3:
            raise RuntimeError('expected points with 3 coordinates, got shape ' + str(xyz.shape))
        # The point axis is second-to-last for both (N, 3) and (B, N, 3).
        center = torch.mean(xyz, dim=-2)
        return center, center