import sys

# Make the parent directory importable so that `datasets`, `configs` and `models` resolve
# when this file is run from its own directory.
sys.path.append("..")

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from configs.config_utils import CONFIG
from datasets import build_dataset
from models.modules import ConvPointnet_Decoder, ConvPointnet_Encoder, PointEmbed

class TriplaneVAE(nn.Module):
    '''VAE that encodes a surface point cloud into plane features and decodes them at query points.'''

    def __init__(self, opt):
        super().__init__()
        # One embedder is shared by the surface points and the query points.
        self.point_embedder = PointEmbed(hidden_dim=opt['point_emb_dim'])

        encoder_args = opt['encoder']
        decoder_args = opt['decoder']
        self.encoder = ConvPointnet_Encoder(
            c_dim=encoder_args['plane_latent_dim'],
            dim=opt['point_emb_dim'],
            latent_dim=encoder_args['latent_dim'],
            plane_resolution=encoder_args['plane_reso'],
            unet_kwargs=encoder_args['unet'],
            unet=True,
            padding=opt['padding'],
        )
        self.decoder = ConvPointnet_Decoder(
            latent_dim=decoder_args['latent_dim'],
            query_emb_dim=decoder_args['query_emb_dim'],
            hidden_dim=decoder_args['hidden_dim'],
            unet_kwargs=decoder_args['unet'],
            n_blocks=decoder_args['n_blocks'],
            plane_resolution=decoder_args['plane_reso'],
            padding=opt['padding'],
        )

    def forward(self, p, query):
        '''
        :param p: surface point cloud of shape B,N,3
        :param query: sampled query points of shape B,N,3
        :return: dict with 'logits' (decoder output at the query points) and 'kl' (KL term from the encoder)
        '''
        point_emb = self.point_embedder(p)
        query_emb = self.point_embedder(query)
        kl, plane_feat, means, logvars = self.encoder(p, point_emb)
        if self.training:
            if np.random.random() < 0.5:
                # Randomly rescale the triplane (down by 2x, then back up) to promote robustness;
                # triplane diffusion is conducted on the 64x64x64 plane.
                plane_feat = torch.nn.functional.interpolate(plane_feat, scale_factor=0.5, mode="bilinear")
                plane_feat = torch.nn.functional.interpolate(plane_feat, scale_factor=2, mode="bilinear")
        # if self.training:
        #     if np.random.random()<0.5:
        #         means = torch.nn.functional.interpolate(means, scale_factor=0.5, mode="bilinear")
        #         vars=torch.exp(logvars)
        #         vars = torch.nn.functional.interpolate(vars, scale_factor=0.5, mode="bilinear")
        #         new_logvars=torch.log(vars)
        #         posterior = DiagonalGaussianDistribution(means, new_logvars)
        #         plane_feat=posterior.sample()
        #         plane_feat=torch.nn.functional.interpolate(plane_feat,scale_factor=2,mode='bilinear')

        # mean_scale=torch.nn.functional.interpolate(means, scale_factor=0.5, mode="bilinear")
        # vars = torch.exp(logvars)
        # vars_scale = torch.nn.functional.interpolate(vars, scale_factor=0.5, mode="bilinear")/4
        # logvars_scale=torch.log(vars_scale)
        # scale_noise=torch.randn(mean_scale.shape).to(mean_scale.device)
        # plane_feat_scale2=mean_scale+torch.exp(0.5*logvars_scale)*scale_noise
        # plane_feat=torch.nn.functional.interpolate(plane_feat_scale2,scale_factor=2,mode='bilinear')
        o=self.decoder(plane_feat,query,query_emb)

        return {'logits':o,'kl':kl}


    def decode(self, plane_feature, query):
        '''Decode logits at the query points (B,N,3) from already-computed plane features.'''
        query_embedding = self.point_embedder(query)
        o = self.decoder(plane_feature, query, query_embedding)

        return o

    def encode(self, p):
        '''p is a point cloud of shape B,N,3'''
        point_emb = self.point_embedder(p)
        kl, plane_feat, mean, logvar = self.encoder(p, point_emb)
        return plane_feat, kl, mean, logvar

if __name__=="__main__":
    configs=CONFIG("../configs/train_triplane_vae_64.yaml")
    config=configs.config
    dataset_config=config['datasets']
    model_config=config["model"]
    dataset=build_dataset("train",dataset_config)
    dataset[0]  # check that a single sample loads before building the DataLoader
    dataloader=DataLoader(
        dataset=dataset,
        batch_size=10,
        shuffle=True,
        num_workers=2,
    )
    net=TriplaneVAE(model_config).float().cuda()
    for idx,data_batch in enumerate(dataloader):
        if idx==1:
            break
        surface=data_batch['surface'].float().cuda()
        query=data_batch['points'].float().cuda()
        net(surface,query)