In [1]:
import random
from statistics import mean
from datetime import datetime
from typing import List, Tuple
import copy

import torch as th
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from jaxtyping import Float, Float16, Int

import trimesh as tm
import numpy as np
import numba

from torch_geometric.nn.conv import GATv2Conv

import h5py

# Clone SAP from original repo https://github.com/autonomousvision/shape_as_points.git
from SAP.dpsr import DPSR
from SAP.model import PSR2Mesh

  from .autonotebook import tqdm as notebook_tqdm


# Constants

In [2]:
# Seed every RNG this notebook draws from: Python's `random` (vertex drop ratio in
# Model.training_step), NumPy, and torch (all devices).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
th.manual_seed(SEED)

In [3]:
# Debug switch: enables Lightning's fast_dev_run in the training cell and turns off visual logging.
IS_DEBUG: bool = True

In [4]:
# --- Optimisation ---
BATCH_SIZE = 1   # BS
LR = 1e-3

# --- FormOptimizer widths (features per voxel node) ---
IN_DIM = 1
OUT_DIM = 1
LATENT_DIM = 32

DROPOUT_PROB = 0.1   # passed as `dropout` to both GATv2Conv layers

PADDING = 1.2   # scaling

# --- DPSR grid ---
GRID_SIZE = 128   # resolution per axis
SIGMA = 5.0       # passed to DPSR as `sig`

In [5]:
# --- Dataset selection and visual logging ---
DATASET = "Synthetic"
LOG_IDX = 14                 # training batch index whose reconstruction is logged each epoch
LOG_VISUALS = not IS_DEBUG   # HDF5 visual logs are written only outside debug runs

CHECKPOINTS_PATH = "./checkpoints/"

# --- HDF5 inputs, derived from DATASET ---
FIELDS_H5_PATH = f"./Standart_fields/{DATASET}_fields_32_512.h5"
PATH_ORIG_H5 = f"./Standart_h5/{DATASET}.h5"
PATH_NOISY_H5 = f"./Standart_h5/{DATASET}_noisy.h5"

# Inclusive bounds on the noisy-mesh vertex count accepted by StandartH5DataSet
MIN_V_NUMBER = 1000
MAX_V_NUMBER = 100000

# Data Preparation

In [6]:
@numba.njit
def generate_grid_edge_list(gs: int = 128):
    """Build the directed edge list of a gs x gs x gs voxel lattice (6-neighbourhood).

    Voxel (i, j, k) has flat id ``i + gs*j + gs*gs*k`` (i varies fastest). Each voxel is
    linked to every axis neighbour that exists, so every lattice edge appears in both
    directions. The result is a list of ``6 * gs*gs * (gs - 1)`` ``[src, dst]`` pairs.
    """
    edges = []
    plane = gs * gs  # id stride along k
    for k in range(gs):
        for j in range(gs):
            for i in range(gs):
                node = i + gs * j + plane * k
                # neighbours along i (id stride 1)
                if i > 0:
                    edges.append([node, node - 1])
                if i < gs - 1:
                    edges.append([node, node + 1])
                # neighbours along j (id stride gs)
                if j > 0:
                    edges.append([node, node - gs])
                if j < gs - 1:
                    edges.append([node, node + gs])
                # neighbours along k (id stride gs*gs)
                if k > 0:
                    edges.append([node, node - plane])
                if k < gs - 1:
                    edges.append([node, node + plane])
    return edges


# Placeholder. The training cell replaces it with a (2, E) tensor on the GPU, and
# FormOptimizer.forward reads it as a module-level global.
GRID_EDGE_LIST = None

In [7]:
class StandartH5DataSet(th.utils.data.Dataset):
    """Paired noisy / ground-truth mesh dataset backed by two open HDF5 groups.

    Only noisy meshes whose vertex count lies in ``[min_verts, max_verts]`` are indexed.
    Items are loaded on first access and cached in memory. They are returned as deep
    copies, so callers can modify them without corrupting the cache.
    """

    # Substrings removed from a noisy-mesh key to obtain the matching ground-truth key.
    # They are applied in this order, and every occurrence is removed (`str.replace`).
    _NOISY_KEY_SUFFIXES = ("_n1", "_n2", "_n3", "_noisy")

    @staticmethod
    def _make_dpsr(grid_size: int):
        """Build the DPSR rasteriser for a cubic grid with `grid_size` cells per axis."""
        return DPSR([grid_size, grid_size, grid_size], sig=SIGMA)

    def _load_data(self, key: str):
        """Read one noisy mesh and its ground-truth counterpart, and rasterise the GT field.

        Returns:
            (vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj).
            The first five are float32 tensors; ``adj`` is the noisy mesh's int64 ``edge_index``.
        """
        key_orig = key
        for suffix in self._NOISY_KEY_SUFFIXES:
            key_orig = key_orig.replace(suffix, "")

        noisy = self._noisy_meshes_h5[key]
        orig = self._orig_meshes_h5[key_orig]

        vertices            = th.tensor(noisy["vertices"][:],         dtype=th.float)
        vertices_normals    = th.tensor(noisy["vertices_normals"][:], dtype=th.float)
        vertices_gt         = th.tensor(orig["vertices"][:],          dtype=th.float)
        vertices_normals_gt = th.tensor(orig["vertices_normals"][:],  dtype=th.float)
        # DPSR is called with a leading batch dim of 1, which is squeezed away for storage.
        field_gt            = self.dpsr(vertices_gt.unsqueeze(0), vertices_normals_gt.unsqueeze(0)).squeeze(0)

        adj = th.from_numpy(np.asarray(noisy["edge_index"][:], dtype=np.int64))

        return vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj

    def __init__(self,
                 orig_meshes_h5: h5py.Group,
                 noisy_meshes_h5: h5py.Group,
                 fields_grid_size: int,
                 min_verts: int,
                 max_verts: int) -> None:
        """
        Args:
            orig_meshes_h5: group of ground-truth meshes, keyed by the noisy key with the
                noise suffix removed.
            noisy_meshes_h5: group of noisy meshes; its keys define the dataset items.
            fields_grid_size: per-axis resolution of the DPSR grid used to build ``field_gt``.
            min_verts, max_verts: inclusive bounds on the noisy mesh's vertex count.
        """
        super().__init__()

        self._orig_meshes_h5 = orig_meshes_h5
        self._noisy_meshes_h5 = noisy_meshes_h5

        # Kept as str (existing convention); exposed as int through `fields_grid_size`.
        self._fields_grid_size = str(fields_grid_size)
        # The requested grid size is honoured instead of the global GRID_SIZE.
        self.dpsr = self._make_dpsr(fields_grid_size)

        self._min_verts = min_verts
        self._max_verts = max_verts

        # Lazy cache: index -> loaded tuple, with `_loaded` marking which entries are valid.
        self._data = {}
        self._keys = np.array(
            [key for key in self._noisy_meshes_h5.keys()
             if self._min_verts <= self._noisy_meshes_h5[key]["vertices"].shape[0] <= self._max_verts],
            dtype=str,
        )
        self._loaded = np.full(shape=self._keys.shape, fill_value=False, dtype=bool)

    def __len__(self) -> int:
        return self._keys.shape[0]

    def __getitem__(self, index: int) -> Tuple[Float[th.Tensor, "N 3"],
                                               Float[th.Tensor, "N 3"],
                                               Float[th.Tensor, "N 3"],
                                               Float[th.Tensor, "N 3"],
                                               Float[th.Tensor, "GR GR GR"],
                                               Int[th.Tensor, "2 E"]]:
        if not self._loaded[index]:
            self._data[index] = self._load_data(self._keys[index])
            self._loaded[index] = True
        # Deep copy so downstream modifications never reach the cached tensors.
        return copy.deepcopy(self._data[index])

    @property
    def fields_grid_size(self):
        return int(self._fields_grid_size)

    def renew_grid_size(self, new_grid_size: int):
        """Switch the DPSR grid resolution and drop all cached items so they are recomputed."""
        self._fields_grid_size = str(new_grid_size)
        self.dpsr = self._make_dpsr(new_grid_size)
        self._data = {}
        self._loaded = np.full(shape=self._keys.shape, fill_value=False, dtype=bool)

# Model

### Form Optimizer 

In [8]:
class FormOptimizer(th.nn.Module):
    """Refines a DPSR field by predicting a residual with two GATv2 layers over the voxel lattice.

    Each voxel of a ``GS x GS x GS`` grid is a graph node carrying ``IN_DIM`` features.
    Connectivity defaults to the module-level ``GRID_EDGE_LIST``. That global must be
    initialised (a (2, E) index tensor on the field's device) before ``forward`` is called.
    """

    def __init__(self) -> None:
        super().__init__()

        self.gconv1 = GATv2Conv(in_channels=IN_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)
        self.gconv2 = GATv2Conv(in_channels=LATENT_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)

        self.actv = th.nn.Sigmoid()
        self.head = th.nn.Linear(in_features=LATENT_DIM, out_features=OUT_DIM)

    def forward(self,
                field: Float[th.Tensor, "BS GS GS GS"],
                edge_index=None) -> Float[th.Tensor, "BS GS GS GS"]:
        """
        Args:
            field (Tensor [BS, GS, GS, GS] or [GS, GS, GS]): indicator field(s) to refine.
            edge_index (Tensor [2, E], optional): lattice connectivity of ONE grid.
                Defaults to the module-level ``GRID_EDGE_LIST``.

        Returns:
            Tensor [BS, GS, GS, GS] (BS = 1 for an unbatched input): ``field`` plus the
            predicted residual, clamped to [-0.5, 0.5].

        Raises:
            RuntimeError: if no edge list was given and ``GRID_EDGE_LIST`` is still unset.
        """
        if edge_index is None:
            edge_index = GRID_EDGE_LIST
        if edge_index is None:
            raise RuntimeError("GRID_EDGE_LIST is not initialised; build it with generate_grid_edge_list() first.")

        gs = field.shape[-1]
        # Collapse leading dims into a batch. The edge list describes a single grid, so
        # each sample goes through the graph layers separately.
        grids = field.reshape(-1, gs, gs, gs)

        deltas = []
        for grid in grids:
            node_features = grid.reshape(gs * gs * gs, IN_DIM)
            node_features = self.gconv1(x=node_features, edge_index=edge_index)
            node_features = self.gconv2(x=node_features, edge_index=edge_index)
            deltas.append(self.head(self.actv(node_features)).reshape(gs, gs, gs))

        # Residual connection. It is done out of place so the add cannot downcast the
        # input field to the head's (possibly fp16) dtype.
        refined = th.stack(deltas) + grids
        return th.clamp(refined, min=-0.5, max=0.5)

### Full Model

In [9]:
class Model(pl.LightningModule):
    """Point cloud -> DPSR field -> FormOptimizer refinement -> mesh via PSR2Mesh.

    Training minimises the MSE between the refined field and the ``field_gt`` in each batch.
    """

    def __init__(self):
        super().__init__()
        self.form_optimizer = FormOptimizer()

        self.dpsr = DPSR([GRID_SIZE, GRID_SIZE, GRID_SIZE], sig=SIGMA)
        self.field2mesh = PSR2Mesh().apply

        self.metric = th.nn.MSELoss()

        # Visual logging: training_step writes one frame (batch LOG_IDX) per epoch.
        if LOG_VISUALS:
            import os  # local import: only needed when visual logging is enabled
            os.makedirs("./logs", exist_ok=True)  # h5py.File does not create parent directories
            start_time = datetime.now().strftime("%d-%b-%Y_%H-%M-%S")
            self.h5_frame = 0
            self.log_points_file = h5py.File(f"./logs/points_{start_time}", "w")
            self.log_normals_file = h5py.File(f"./logs/normals_{start_time}", "w")

        # Per-step losses; the epoch-end hooks average them and then clear them.
        self.val_losses = []
        self.train_losses = []

    def log_h5(self, points, normals):
        """Store one frame of points and normals as gzip-compressed float16 datasets named by frame index."""
        for h5_file, data in ((self.log_points_file, points), (self.log_normals_file, normals)):
            dset = h5_file.create_dataset(
                name=str(self.h5_frame),
                shape=data.shape,
                dtype=np.float16,
                compression="gzip")
            dset[:] = data
        self.h5_frame += 1

    def forward(self,
                v: Float[th.Tensor, "BS N 3"],
                n: Float[th.Tensor, "BS N 3"]) -> Tuple[Float[th.Tensor, "BS N 3"],  # v - vertices
                                                        Int[th.Tensor, "2 E"],  # f - faces; NOTE(review): shape unverified, confirm against PSR2Mesh
                                                        Float[th.Tensor, "BS N 3"],  # n - vertices normals
                                                        Float[th.Tensor, "BS GR GR GR"]]:  # refined field
        """Rasterise points/normals with DPSR, refine the field, and extract a mesh from it."""
        field = self.dpsr(v, n)
        field = self.form_optimizer(field)
        v, f, n = self.field2mesh(field)
        return v, f, n, field

    def training_step(self, batch, batch_idx) -> Float[th.Tensor, "1"]:
        """Randomly drop noisy points, reconstruct, and regress the refined field onto the GT field."""
        vertices, vertices_normals, _, _, field_gt, _ = batch

        # Augmentation: each point is kept with a probability drawn uniformly from [0.5, 1).
        # The mask is created on the input's device rather than a hard-coded CUDA device.
        keep_ratio = random.random() / 2.0 + 0.5
        mask = th.rand((vertices.shape[1],), device=vertices.device) < keep_ratio
        vertices = vertices[:, mask]
        vertices_normals = vertices_normals[:, mask]

        # Call this module, not a notebook-global `model`.
        vr, _, nr, field_r = self(vertices, vertices_normals)

        loss = self.metric(field_r, field_gt)
        if LOG_VISUALS and batch_idx == LOG_IDX:
            self.log_h5(vr.squeeze(0).detach().cpu().numpy(), nr.squeeze(0).detach().cpu().numpy())
        self.train_losses.append(loss.item())

        return loss

    def on_train_epoch_end(self):
        # Guard: statistics.mean raises on an empty list (e.g. an epoch with no steps).
        if self.train_losses:
            self.log("mean_train_per_epoch_loss", mean(self.train_losses), on_step=False, on_epoch=True)
        self.train_losses = []

    def validation_step(self, batch, batch_idx):
        """Reconstruct from the full noisy point set and record the field MSE."""
        vertices, vertices_normals, _, _, field_gt, _ = batch

        _, _, _, field_r = self(vertices, vertices_normals)

        loss = self.metric(field_r, field_gt)
        self.val_losses.append(loss.item())
        return loss

    def on_validation_epoch_end(self):
        if self.val_losses:
            self.log("mean_val_per_epoch_loss", mean(self.val_losses), on_step=False, on_epoch=True)
        self.val_losses = []

    def configure_optimizers(self):
        """Adam with ReduceLROnPlateau, which halves the LR after 5 epochs without val-loss improvement."""
        optimizer = th.optim.Adam(self.parameters(), lr=LR)
        scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

        return {
                "optimizer": optimizer,
                "lr_scheduler": {
                                 "scheduler": scheduler,
                                 "monitor": "mean_val_per_epoch_loss",
                                 "interval": "epoch",
                                 "frequency": 1,
                                 "strict": True,
                                 "name": None,
                                }
                }


# Training Loop

In [10]:
# Keep only the single best checkpoint (lowest mean validation loss logged by Model).
# It is written to the configured CHECKPOINTS_PATH, which was previously declared but
# never used, so checkpoints landed in Lightning's default directory.
checkpoint_callback = ModelCheckpoint(
    dirpath=CHECKPOINTS_PATH,
    monitor="mean_val_per_epoch_loss",
    mode="min",
    save_top_k=1,
)

In [11]:
if __name__ == "__main__":

    # Voxel-lattice graph read by FormOptimizer.forward as a global, shape (2, E).
    # PyG expects `edge_index` as int64 (torch.long).
    GRID_EDGE_LIST = th.tensor(generate_grid_edge_list(GRID_SIZE), dtype=th.long).T
    GRID_EDGE_LIST = GRID_EDGE_LIST.to(th.device("cuda"))

    # Opened read-only and deliberately left open: the datasets read items from these
    # groups lazily, on first access, throughout training.
    noisy_meshes_h5 = h5py.File(PATH_NOISY_H5, "r")
    orig_meshes_h5 = h5py.File(PATH_ORIG_H5, "r")

    train_dataset = StandartH5DataSet(orig_meshes_h5=orig_meshes_h5["train"],
                                      noisy_meshes_h5=noisy_meshes_h5["train"],
                                      fields_grid_size=GRID_SIZE,
                                      min_verts=MIN_V_NUMBER,
                                      max_verts=MAX_V_NUMBER)
    test_dataset = StandartH5DataSet(orig_meshes_h5=orig_meshes_h5["test"],
                                     noisy_meshes_h5=noisy_meshes_h5["test"],
                                     fields_grid_size=GRID_SIZE,
                                     min_verts=MIN_V_NUMBER,
                                     max_verts=MAX_V_NUMBER)

    train_dataloader = th.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
    test_dataloader = th.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    trainer = pl.Trainer(accelerator="gpu",
                         callbacks=[checkpoint_callback],
                         log_every_n_steps=len(train_dataset)+len(test_dataset),
                         fast_dev_run=(300 if IS_DEBUG else False),
                         max_epochs=200,
                         precision=16)

    model = Model()
    try:
        trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)
    finally:
        # Close the write-mode visual logs even if training fails, so they are flushed to disk.
        if LOG_VISUALS:
            model.log_points_file.close()
            model.log_normals_file.close()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  rank_zero_warn(
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 300 batch(es). Logging and checkpointing is suppressed.
You are using a CUDA device ('A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type          | Params
-------------------------------------------------
0 | form_optimizer | For

Epoch 0: 100%|██████████| 60/60 [00:17<00:00,  3.52it/s]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/84 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/84 [00:00<?, ?it/s][A
Validation DataLoader 0:   1%|          | 1/84 [00:00<00:09,  8.99it/s][A
Validation DataLoader 0:   2%|▏         | 2/84 [00:00<00:10,  7.88it/s][A
Validation DataLoader 0:   4%|▎         | 3/84 [00:00<00:10,  7.47it/s][A
Validation DataLoader 0:   5%|▍         | 4/84 [00:00<00:11,  7.22it/s][A
Validation DataLoader 0:   6%|▌         | 5/84 [00:00<00:11,  7.13it/s][A
Validation DataLoader 0:   7%|▋         | 6/84 [00:00<00:10,  7.09it/s][A
Validation DataLoader 0:   8%|▊         | 7/84 [00:01<00:11,  6.95it/s][A
Validation DataLoader 0:  10%|▉         | 8/84 [00:01<00:11,  6.86it/s][A
Validation DataLoader 0:  11%|█         | 9/84 [00:01<00:11,  6.81it/s][A
Validation DataLoader 0:  12%|█▏        | 10/84 [00:01<00:11,  6.71it/s][A
Validation DataLoader 0:  13%|█▎     

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 60/60 [00:30<00:00,  1.94it/s]
