File size: 4,704 Bytes
cc9780d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_mean, scatter_max
from .unet import UNet
from .resnet_block import ResnetBlockFC
import numpy as np

class ConvPointnet_Decoder(nn.Module):
    ''' Tri-plane decoder: refines plane features with a U-Net, samples them at
        the query points and predicts one scalar per query point.

    The input feature map is expected to hold the 'xz', 'xy' and 'yz' planes
    stacked along dimension 2 (in that order); forward() splits it into three
    equal parts after the U-Net.

    Args:
        latent_dim (int): channel count of the incoming plane features. The
            per-block conditioning layers take ``latent_dim * 3`` inputs, so
            the U-Net output is assumed to have ``latent_dim`` channels as
            well - TODO confirm against the UNet configuration.
        query_emb_dim (int): dimension of the query embedding fed to ``fc_p``
        hidden_dim (int): hidden dimension of the network
        unet_kwargs (dict): U-Net parameters; must contain ``'output_dim'``.
            The whole dict is also forwarded to ``UNet`` as keyword arguments.
        plane_resolution (int): defined resolution for plane feature (stored
            as ``reso_plane``; not used by the methods of this class)
        plane_type (list of str): plane names (stored only; forward() always
            assumes the three planes 'xz', 'xy', 'yz')
        padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
        n_blocks (int): number of ResnetBlockFC layers

    Raises:
        TypeError: if ``unet_kwargs`` is None.
    '''

    def __init__(self, latent_dim=32,query_emb_dim=51,hidden_dim=128, unet_kwargs=None,
                 plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5):
        super().__init__()

        if unet_kwargs is None:
            # Same exception type as the former "'NoneType' object is not
            # subscriptable", but with an actionable message.
            raise TypeError("unet_kwargs is required and must contain 'output_dim'")

        # Store the value actually used to size the layers (was hard-coded to 32).
        self.latent_dim = latent_dim
        self.actvn = nn.ReLU()

        self.unet = UNet(unet_kwargs['output_dim'], in_channels=latent_dim, **unet_kwargs)

        self.reso_plane = plane_resolution
        # Copy list arguments so the shared default list is never aliased.
        self.plane_type = list(plane_type) if isinstance(plane_type, list) else plane_type
        self.padding = padding
        self.n_blocks = n_blocks

        # One conditioning projection per ResNet block; the input is the
        # concatenation of the features sampled from the three planes.
        self.fc_c = nn.ModuleList([
            nn.Linear(latent_dim * 3, hidden_dim) for _ in range(n_blocks)
        ])
        self.fc_p = nn.Linear(query_emb_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, 1)

        self.blocks = nn.ModuleList([
            ResnetBlockFC(hidden_dim) for _ in range(n_blocks)
        ])

    def forward(self, plane_features, query, query_emb):
        ''' Decode one scalar per query point from stacked plane features.

        Args:
            plane_features (tensor): 4-D feature map passed to the U-Net; its
                output must have a size along dimension 2 that is a multiple
                of 3 (the three planes stacked in the order xz, xy, yz)
            query (tensor): query coordinates, indexed as (batch, point, axis)
                with at least 3 axes; expected in [-1, 1]
            query_emb (tensor): embedding of the query points with last
                dimension ``query_emb_dim``

        Returns:
            tensor: ``fc_out`` output with its trailing size-1 dimension
                squeezed, i.e. one value per query point

        Raises:
            ValueError: if the U-Net output cannot be split into three equal
                planes along dimension 2.
        '''
        plane_feature = self.unet(plane_features)
        height = plane_feature.shape[2]
        if height < 3 or height % 3 != 0:
            raise ValueError(
                "U-Net output height must be a positive multiple of 3 to hold "
                "the xz/xy/yz planes, got {}".format(height)
            )
        xz_feat, xy_feat, yz_feat = torch.split(plane_feature, height // 3, dim=2)

        # Each sampled tensor is (batch, channels, points).
        xz_sample_feat = self.sample_plane_feature(query, xz_feat, 'xz')
        xy_sample_feat = self.sample_plane_feature(query, xy_feat, 'xy')
        yz_sample_feat = self.sample_plane_feature(query, yz_feat, 'yz')

        # Concatenate along channels, then move channels last for nn.Linear.
        sample_feat = torch.cat([xz_sample_feat, xy_sample_feat, yz_sample_feat], dim=1)
        sample_feat = sample_feat.transpose(1, 2)

        net = self.fc_p(query_emb)
        # The sampled plane features are re-injected before every block.
        for fc_c, block in zip(self.fc_c, self.blocks):
            net = net + fc_c(sample_feat)
            net = block(net)
        out = self.fc_out(self.actvn(net)).squeeze(-1)
        return out

    def normalize_coordinate(self, p, padding=0.1, plane='xz'):
        ''' Project points onto a plane and normalize the result to [0, 1).

        Args:
            p (tensor): points indexed as (batch, point, axis), expected in [-1, 1]
            padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
            plane (str): plane feature type, 'xz' or 'xy'; any other value
                selects the 'yz' plane

        Returns:
            tensor: the two selected coordinates per point, with values >= 1
                replaced by ``1 - 10e-6`` and values < 0 replaced by 0.
                ``p`` itself is not modified.
        '''
        if plane == 'xz':
            axes = [0, 2]
        elif plane == 'xy':
            axes = [0, 1]
        else:
            axes = [1, 2]
        # Indexing with a list copies, so the caller's tensor is never touched.
        xy = p[:, :, axes]

        xy = xy / 2  # input is in [-1, 1] -> [-0.5, 0.5]
        # Note: 10e-6 == 1e-5; kept as written so results stay unchanged.
        xy_new = xy / (1 + padding + 10e-6)  # (-0.5, 0.5)
        xy_new = xy_new + 0.5  # range (0, 1)

        # Handle outliers outside the range. masked_fill avoids the host
        # synchronisation that `if xy_new.max() >= 1` would force on GPU.
        xy_new = xy_new.masked_fill(xy_new >= 1, 1 - 10e-6)
        xy_new = xy_new.masked_fill(xy_new < 0, 0.0)
        return xy_new

    def coordinate2index(self, x, reso):
        ''' Convert normalized 2D coordinates to flattened grid indices.

        Args:
            x (tensor): coordinates indexed as (batch, point, 2); presumably
                already normalized to [0, 1) - verify against caller
            reso (int): defined resolution

        Returns:
            tensor: integer indices ``x0 + reso * x1`` with a singleton
                dimension inserted at position 1
        '''
        x = (x * reso).long()
        index = x[:, :, 0] + reso * x[:, :, 1]
        index = index[:, None, :]
        return index

    def sample_plane_feature(self, query, plane_feature, plane):
        ''' Bilinearly interpolate plane features at the projected query points.

        Args:
            query (tensor): query coordinates indexed as (batch, point, axis)
            plane_feature (tensor): 4-D feature map of a single plane
            plane (str): plane feature type, ['xz', 'xy', 'yz']

        Returns:
            tensor: sampled features of shape (batch, channels, points)
        '''
        # No clone needed: normalize_coordinate does not modify its input.
        xy = self.normalize_coordinate(query, plane=plane, padding=self.padding)
        # Insert a singleton axis so the grid is (batch, points, 1, 2).
        xy = xy[:, :, None].float()
        vgrid = 2.0 * xy - 1.0  # normalize to (-1, 1)
        sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True,
                                     mode='bilinear').squeeze(-1)
        return sampled_feat