import random
from statistics import mean
from typing import List, Tuple

import torch as th
import pytorch_lightning as pl
from jaxtyping import Float, Int
import numpy as np
from torch_geometric.nn.conv import GATv2Conv

from models.SAP.dpsr import DPSR
from models.SAP.model import PSR2Mesh

# Constants

th.manual_seed(0)
np.random.seed(0)

BATCH_SIZE = 1 # BS

IN_DIM = 1 
OUT_DIM = 1
LATENT_DIM = 32

DROPOUT_PROB = 0.1
GRID_SIZE = 128
LR = 1e-3  # learning rate used in configure_optimizers; value assumed, adjust as needed

def generate_grid_edge_list(gs: int = 128) -> List[List[int]]:
    """Build the 6-neighbourhood edge list of a gs x gs x gs voxel grid.

    Voxels are indexed as i + gs*j + gs*gs*k; each voxel is connected to its
    direct neighbours along the x, y and z axes, in both directions, so the
    result contains directed edges in both orientations.
    """
    grid_edge_list = []

    for k in range(gs):
        for j in range(gs):
            for i in range(gs):
                current_idx = i + gs*j + k*gs*gs
                if (i - 1) >= 0:  # -x neighbour
                    grid_edge_list.append([current_idx, i-1 + gs*j + k*gs*gs])
                if (i + 1) < gs:  # +x neighbour
                    grid_edge_list.append([current_idx, i+1 + gs*j + k*gs*gs])
                if (j - 1) >= 0:  # -y neighbour
                    grid_edge_list.append([current_idx, i + gs*(j-1) + k*gs*gs])
                if (j + 1) < gs:  # +y neighbour
                    grid_edge_list.append([current_idx, i + gs*(j+1) + k*gs*gs])
                if (k - 1) >= 0:  # -z neighbour
                    grid_edge_list.append([current_idx, i + gs*j + (k-1)*gs*gs])
                if (k + 1) < gs:  # +z neighbour
                    grid_edge_list.append([current_idx, i + gs*j + (k+1)*gs*gs])
    return grid_edge_list
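
# Quick sanity check (illustrative): a gs x gs x gs grid with 6-neighbour
# connectivity has 6 * gs^2 * (gs - 1) directed edges, e.g. 288 for gs = 4.
assert len(generate_grid_edge_list(4)) == 6 * 4 * 4 * (4 - 1)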

GRID_EDGE_LIST = generate_grid_edge_list(GRID_SIZE)
GRID_EDGE_LIST = th.tensor(GRID_EDGE_LIST, dtype=th.long)  # PyG expects int64 edge indices
GRID_EDGE_LIST = GRID_EDGE_LIST.T  # shape [2, E], as required by GATv2Conv
GRID_EDGE_LIST.requires_grad = False  # static connectivity; never trained


class FormOptimizer(th.nn.Module):
    """Graph-attention network that predicts a residual correction to the PSR field."""

    def __init__(self) -> None:
        super().__init__()

        self.gconv1 = GATv2Conv(in_channels=IN_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)
        self.gconv2 = GATv2Conv(in_channels=LATENT_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)

        self.actv = th.nn.Sigmoid()
        self.head = th.nn.Linear(in_features=LATENT_DIM, out_features=OUT_DIM)

    def forward(self,
                field: Float[th.Tensor, "BS GS GS GS"]) -> Float[th.Tensor, "BS GS GS GS"]:
        """Refine the PSR indicator field with two graph-attention layers.

        Args:
            field (Tensor [BS, GS, GS, GS]): indicator field produced by DPSR.

        Returns:
            Tensor [BS, GS, GS, GS]: refined field, clamped to [-0.5, 0.5].
        """
        # Flatten the grid to (GS^3, IN_DIM) node features (assumes BATCH_SIZE == 1).
        vertex_features = field.clone()
        vertex_features = vertex_features.reshape(GRID_SIZE*GRID_SIZE*GRID_SIZE, IN_DIM)

        # Keep the static edge list on the same device as the node features.
        edge_index = GRID_EDGE_LIST.to(vertex_features.device)
        vertex_features = self.gconv1(x=vertex_features, edge_index=edge_index)
        vertex_features = self.gconv2(x=vertex_features, edge_index=edge_index)
        field_delta = self.head(self.actv(vertex_features))

        # Predict a residual on top of the input field, then clamp to the valid range.
        field_delta = field_delta.reshape(BATCH_SIZE, GRID_SIZE, GRID_SIZE, GRID_SIZE)
        field_delta += field  # field_delta carries the gradient
        field_delta = th.clamp(field_delta, min=-0.5, max=0.5)

        return field_delta
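
# Illustrative shape check (comments only; a full 128^3 grid is heavy on memory):
#   form_opt = FormOptimizer()
#   field = th.zeros(BATCH_SIZE, GRID_SIZE, GRID_SIZE, GRID_SIZE)
#   refined = form_opt(field)  # -> same (BS, GS, GS, GS) shape, values in [-0.5, 0.5]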

class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.form_optimizer = FormOptimizer()

        self.dpsr = DPSR([GRID_SIZE, GRID_SIZE, GRID_SIZE], sig=0.0)
        self.field2mesh = PSR2Mesh().apply

        self.metric = th.nn.MSELoss()

        self.val_losses = []
        self.train_losses = []

    def log_h5(self, points, normals):
        """Append one frame of points/normals to the open HDF5 log files.

        Expects `self.log_points_file`, `self.log_normals_file` (open h5py files)
        and `self.h5_frame` (int frame counter) to be set up by the caller.
        """
        dset = self.log_points_file.create_dataset(
                                    name=str(self.h5_frame),
                                    shape=points.shape,
                                    dtype=np.float16,
                                    compression="gzip")
        dset[:] = points
        dset = self.log_normals_file.create_dataset(
                                     name=str(self.h5_frame),
                                     shape=normals.shape,
                                     dtype=np.float16,
                                     compression="gzip")
        dset[:] = normals
        self.h5_frame += 1
    
    def forward(self,
                v: Float[th.Tensor, "BS N 3"],
                n: Float[th.Tensor, "BS N 3"]) -> Tuple[Float[th.Tensor, "BS N 3"],        # v - vertices
                                                        Int[th.Tensor, "2 E"],             # f - faces
                                                        Float[th.Tensor, "BS N 3"],        # n - vertex normals
                                                        Float[th.Tensor, "BS GR GR GR"]]:  # field
        # Points + normals -> PSR field -> refined field -> mesh.
        field = self.dpsr(v, n)
        field = self.form_optimizer(field)
        v, f, n = self.field2mesh(field)
        return v, f, n, field
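
    # Illustrative forward call (comments only; a full 128^3 DPSR plus a 2M-node
    # graph is too heavy to run casually). Points are assumed to lie in [0, 1]^3:
    #   pts = th.rand(BATCH_SIZE, 1024, 3)
    #   nrm = th.nn.functional.normalize(th.randn(BATCH_SIZE, 1024, 3), dim=-1)
    #   v, f, n, field = Model()(pts, nrm)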

    def training_step(self, batch, batch_idx) -> Float[th.Tensor, "1"]:
        vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch

        # Randomly drop up to 50% of the input points as augmentation.
        keep_prob = random.random() / 2.0 + 0.5
        mask = th.rand((vertices.shape[1], ), device=vertices.device) < keep_prob
        vertices = vertices[:, mask]
        vertices_normals = vertices_normals[:, mask]

        vr, fr, nr, field_r = self(vertices, vertices_normals)

        loss = self.metric(field_r, field_gt)
        train_per_step_loss = loss.item()
        self.train_losses.append(train_per_step_loss)

        return loss
    
    def on_train_epoch_end(self):
        mean_train_per_epoch_loss = mean(self.train_losses)
        self.log("mean_train_per_epoch_loss", mean_train_per_epoch_loss, on_step=False, on_epoch=True)
        self.train_losses = []
    
    def validation_step(self, batch, batch_idx):
        vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch

        vr, fr, nr, field_r = self(vertices, vertices_normals)

        loss = self.metric(field_r, field_gt)
        val_per_step_loss = loss.item()
        self.val_losses.append(val_per_step_loss)
        return loss
    
    def on_validation_epoch_end(self):
        mean_val_per_epoch_loss = mean(self.val_losses)
        self.log("mean_val_per_epoch_loss", mean_val_per_epoch_loss, on_step=False, on_epoch=True)
        self.val_losses = []

    def configure_optimizers(self):
        optimizer = th.optim.Adam(self.parameters(), lr=LR)
        scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
        
        return {
                "optimizer": optimizer,
                "lr_scheduler": {
                                 "scheduler": scheduler, 
                                 "monitor": "mean_val_per_epoch_loss",
                                 "interval": "epoch",
                                 "frequency": 1,
                                 # If set to `True`, will enforce that the value specified 'monitor'
                                 # is available when the scheduler is updated, thus stopping
                                 # training if not found. If set to `False`, it will only produce a warning
                                 "strict": True,
                                 # If using the `LearningRateMonitor` callback to monitor the
                                 # learning rate progress, this keyword can be used to specify
                                 # a custom logged name
                                 "name": None,
                                }
                }
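

if __name__ == "__main__":
    # Minimal wiring sketch (illustrative only). The dataloaders are placeholders:
    # the dataset yielding (vertices, normals, vertices_gt, normals_gt, field_gt, adj)
    # batches is not defined in this file, so plug in your own loaders before fitting.
    model = Model()
    trainer = pl.Trainer(max_epochs=100, accelerator="auto", devices=1)
    # trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)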