{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "dd07a8e6-5809-4bb7-ba3a-bd6c15b22ff2", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/user/conda/envs/senv/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import random\n", "from statistics import mean\n", "from datetime import datetime\n", "from typing import List, Tuple\n", "import copy\n", "\n", "import torch as th\n", "import pytorch_lightning as pl\n", "from pytorch_lightning.callbacks import ModelCheckpoint\n", "from jaxtyping import Float, Float16, Int\n", "\n", "import trimesh as tm\n", "import numpy as np\n", "import numba\n", "\n", "from torch_geometric.nn.conv import GATv2Conv\n", "\n", "import h5py\n", "\n", "# Clone SAP from original repo https://github.com/autonomousvision/shape_as_points.git\n", "from SAP.dpsr import DPSR\n", "from SAP.model import PSR2Mesh" ] }, { "cell_type": "markdown", "id": "59c87491-5650-4c59-8d33-5153d29fb1a9", "metadata": { "tags": [] }, "source": [ "# Constants" ] }, { "cell_type": "code", "execution_count": 2, "id": "26d62fb9-dae9-406b-ba30-3fec1a43a29a", "metadata": { "tags": [] }, "outputs": [], "source": [ "th.manual_seed(0)\n", "np.random.seed(0)" ] }, { "cell_type": "code", "execution_count": 3, "id": "9ab9502f-e822-4475-9c90-019ff28f12d0", "metadata": {}, "outputs": [], "source": [ "IS_DEBUG = True" ] }, { "cell_type": "code", "execution_count": 4, "id": "7095231b-e8ed-4c4d-997f-8f58664e9877", "metadata": {}, "outputs": [], "source": [ "BATCH_SIZE = 1 # BS\n", "LR = 0.001\n", "\n", "IN_DIM = 1 \n", "OUT_DIM = 1\n", "LATENT_DIM = 32\n", "\n", "DROPOUT_PROB = 0.1\n", "\n", "PADDING = 1.2 # Scaling\n", "\n", "GRID_SIZE = 128\n", "SIGMA = 5.0" ] }, { "cell_type": "code", "execution_count": 5, "id": "27b7a406-cbb0-4a36-be1e-a8d8aa82c702", "metadata": { "tags": [] }, "outputs": [], "source": [ "DATASET = \"Synthetic\"\n", "LOG_IDX = 14\n", "LOG_VISUALS = not IS_DEBUG\n", "\n", "CHECKPOINTS_PATH = \"./checkpoints/\"\n", "\n", "FIELDS_H5_PATH = f\"./Standart_fields/{DATASET}_fields_32_512.h5\"\n", "PATH_ORIG_H5 = f\"./Standart_h5/{DATASET}.h5\"\n", "PATH_NOISY_H5 = f\"./Standart_h5/{DATASET}_noisy.h5\"\n", "MIN_V_NUMBER = 1_000\n", "MAX_V_NUMBER = 100_000" ] }, { "cell_type": "markdown", "id": "1690b667-0af4-465a-8e3c-4a29622e9e66", "metadata": { "tags": [] }, "source": [ "# Data Preparation" ] }, { "cell_type": "code", "execution_count": 6, "id": "2e774809-1293-4f80-8350-59ae7fc86cbb", "metadata": {}, "outputs": [], "source": [ "@numba.njit\n", "def generate_grid_edge_list(gs: int = 128):\n", " grid_edge_list = []\n", "\n", " for k in range(gs):\n", " for j in range(gs):\n", " for i in range(gs):\n", " current_idx = i + gs*j + k*gs*gs\n", " if (i - 1) >= 0:\n", " grid_edge_list.append([current_idx, i-1 + gs*j + k*gs*gs])\n", " if (i + 1) < gs:\n", " grid_edge_list.append([current_idx, i+1 + gs*j + k*gs*gs])\n", " if (j - 1) >= 0:\n", " grid_edge_list.append([current_idx, i + gs*(j-1) + k*gs*gs])\n", " if (j + 1) < gs:\n", " grid_edge_list.append([current_idx, i + gs*(j+1) + k*gs*gs])\n", " if (k - 1) >= 0:\n", " grid_edge_list.append([current_idx, i + gs*j + (k-1)*gs*gs])\n", " if (k + 1) < gs:\n", " grid_edge_list.append([current_idx, i + gs*j + (k+1)*gs*gs])\n", " return grid_edge_list\n", "\n", "GRID_EDGE_LIST = None" ] }, { 
"cell_type": "code", "execution_count": 7, "id": "4486968b-3416-41c5-9ecd-429f7cf193de", "metadata": {}, "outputs": [], "source": [ "class StandartH5DataSet(th.utils.data.Dataset):\n", " \n", " def _load_data(self, key: str):\n", " key_orig = key.replace(\"_n1\", \"\")\n", " key_orig = key_orig.replace(\"_n2\", \"\")\n", " key_orig = key_orig.replace(\"_n3\", \"\")\n", " key_orig = key_orig.replace(\"_noisy\", \"\")\n", "\n", " vertices = th.tensor(self._noisy_meshes_h5[key][\"vertices\"][:], dtype=th.float)\n", " vertices_normals = th.tensor(self._noisy_meshes_h5[key][\"vertices_normals\"][:], dtype=th.float)\n", " vertices_gt = th.tensor(self._orig_meshes_h5[key_orig][\"vertices\"][:], dtype=th.float)\n", " vertices_normals_gt = th.tensor(self._orig_meshes_h5[key_orig][\"vertices_normals\"][:], dtype=th.float)\n", " field_gt = self.dpsr(vertices_gt.unsqueeze(0), vertices_normals_gt.unsqueeze(0)).squeeze(0)\n", "\n", " adj = np.array(self._noisy_meshes_h5[key][\"edge_index\"][:], dtype=np.int64)\n", " adj = th.tensor(adj, dtype=th.int64)\n", " \n", " return vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj\n", " \n", " def __init__(self, \n", " orig_meshes_h5: h5py.Group,\n", " noisy_meshes_h5: h5py.Group,\n", " fields_grid_size: int,\n", " min_verts: int,\n", " max_verts: int) -> None:\n", " super().__init__()\n", " \n", " self.dpsr = DPSR([GRID_SIZE, GRID_SIZE, GRID_SIZE], sig=SIGMA)\n", " \n", " self._orig_meshes_h5 = orig_meshes_h5\n", " self._noisy_meshes_h5 = noisy_meshes_h5\n", " \n", " self._fields_grid_size = str(fields_grid_size)\n", " self._min_verts = min_verts\n", " self._max_verts = max_verts\n", " \n", " self._data = {}\n", " self._keys = []\n", " \n", " # filter keys to load only meshes with requested amount of vertices\n", " for key in self._noisy_meshes_h5.keys():\n", " v_number = self._noisy_meshes_h5[key][\"vertices\"].shape[0]\n", " if (v_number >= self._min_verts) and (v_number <= self._max_verts):\n", " self._keys.append(key)\n", " self._keys = np.array(self._keys, dtype=str)\n", " self._loaded = np.full(shape=self._keys.shape, fill_value=False, dtype=bool)\n", " \n", " def __len__(self) -> int:\n", " return self._keys.shape[0]\n", " \n", " def __getitem__(self, index: int) -> Tuple[Float[th.Tensor, \"N 3\"],\n", " Float[th.Tensor, \"N 3\"],\n", " Float[th.Tensor, \"N 3\"],\n", " Float[th.Tensor, \"N 3\"],\n", " Float[th.Tensor, \"GR GR GR\"],\n", " Float[th.Tensor, \"2 E\"]]:\n", " if self._loaded[index] == False:\n", " data = self._load_data(self._keys[index])\n", " self._data[index] = data\n", " self._loaded[index] = True\n", " return copy.deepcopy(self._data[index])\n", " \n", " @property\n", " def fields_grid_size(self):\n", " return int(self._fields_grid_size)\n", " \n", " def renew_grid_size(self, new_grid_size: int):\n", " self._fields_grid_size = str(new_grid_size)\n", " self._loaded = np.full(shape=self._keys.shape, fill_value=False, dtype=bool)" ] }, { "cell_type": "markdown", "id": "13c69a49-5107-4d3e-9b14-1d456768f128", "metadata": { "tags": [] }, "source": [ "# Model" ] }, { "cell_type": "markdown", "id": "1d9a9aac-d229-489a-844d-a1d1cbd34c56", "metadata": { "tags": [] }, "source": [ "### Form Optimizer " ] }, { "cell_type": "code", "execution_count": 8, "id": "940babdc-3e4f-4310-8bfd-48b23d0758dc", "metadata": { "tags": [] }, "outputs": [], "source": [ "class FormOptimizer(th.nn.Module):\n", " def __init__(self) -> None:\n", " super().__init__()\n", " \n", " layers = []\n", " \n", " self.gconv1 = GATv2Conv(in_channels=IN_DIM, 
"cell_type": "markdown", "id": "13c69a49-5107-4d3e-9b14-1d456768f128", "metadata": { "tags": [] }, "source": [ "# Model" ] }, { "cell_type": "markdown", "id": "1d9a9aac-d229-489a-844d-a1d1cbd34c56", "metadata": { "tags": [] }, "source": [ "### Form Optimizer" ] }, { "cell_type": "code", "execution_count": 8, "id": "940babdc-3e4f-4310-8bfd-48b23d0758dc", "metadata": { "tags": [] }, "outputs": [], "source": [ "class FormOptimizer(th.nn.Module):\n", "    def __init__(self) -> None:\n", "        super().__init__()\n", "\n", "        self.gconv1 = GATv2Conv(in_channels=IN_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)\n", "        self.gconv2 = GATv2Conv(in_channels=LATENT_DIM, out_channels=LATENT_DIM, heads=1, dropout=DROPOUT_PROB)\n", "\n", "        self.actv = th.nn.Sigmoid()\n", "        self.head = th.nn.Linear(in_features=LATENT_DIM, out_features=OUT_DIM)\n", "\n", "    def forward(self,\n", "                field: Float[th.Tensor, \"BS GS GS GS\"]) -> Float[th.Tensor, \"BS GS GS GS\"]:\n", "        \"\"\"\n", "        Args:\n", "            field (Tensor [BS, GS, GS, GS]): PSR indicator field produced by DPSR.\n", "        \"\"\"\n", "        # one graph node per grid cell (assumes BATCH_SIZE == 1)\n", "        vertex_features = field.clone()\n", "        vertex_features = vertex_features.reshape(GRID_SIZE*GRID_SIZE*GRID_SIZE, IN_DIM)\n", "\n", "        vertex_features = self.gconv1(x=vertex_features, edge_index=GRID_EDGE_LIST)\n", "        vertex_features = self.gconv2(x=vertex_features, edge_index=GRID_EDGE_LIST)\n", "        field_delta = self.head(self.actv(vertex_features))\n", "\n", "        # residual update: the GNN predicts a correction to the input field,\n", "        # clamped to the valid indicator range [-0.5, 0.5]\n", "        field_delta = field_delta.reshape(BATCH_SIZE, GRID_SIZE, GRID_SIZE, GRID_SIZE)\n", "        field_delta += field\n", "        field_delta = th.clamp(field_delta, min=-0.5, max=0.5)\n", "\n", "        return field_delta" ] }, {
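"cell_type": "markdown", "id": "ee54b712-5f60-4182-cd34-4e5f60718293", "metadata": {}, "source": [ "A quick shape check for `FormOptimizer`, added as a sketch: it assumes `GRID_EDGE_LIST` has already been built as sketched earlier, and at `GRID_SIZE = 128` the grid graph has about 2M nodes and 12.5M edges, so it is only practical on the GPU." ] }, { "cell_type": "code", "execution_count": null, "id": "ff65c823-6071-4293-de45-5f6071829304", "metadata": {}, "outputs": [], "source": [ "# Sketch: verify that FormOptimizer preserves the field shape.\n", "dev = GRID_EDGE_LIST.device\n", "fo = FormOptimizer().to(dev)\n", "with th.no_grad():\n", "    out = fo(th.zeros(BATCH_SIZE, GRID_SIZE, GRID_SIZE, GRID_SIZE, device=dev))\n", "assert out.shape == (BATCH_SIZE, GRID_SIZE, GRID_SIZE, GRID_SIZE)" ] }, {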
"cell_type": "markdown", "id": "67b40c5b-ff1b-416d-b892-c544386eaa95", "metadata": { "toc-hr-collapsed": true }, "source": [ "### Full" ] }, { "cell_type": "code", "execution_count": 9, "id": "bce3aa63-9bd7-4ac8-939d-395d63dd3cad", "metadata": { "scrolled": true, "tags": [] }, "outputs": [], "source": [ "class Model(pl.LightningModule):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.form_optimizer = FormOptimizer()\n", "\n", "        self.dpsr = DPSR([GRID_SIZE, GRID_SIZE, GRID_SIZE], sig=SIGMA)\n", "        self.field2mesh = PSR2Mesh().apply\n", "\n", "        self.metric = th.nn.MSELoss()\n", "\n", "        # HDF5 databases for logging reconstructed points/normals per frame\n", "        dateTimeObj = datetime.now()\n", "        start_time = dateTimeObj.strftime(\"%d-%b-%Y_%H-%M-%S\")\n", "\n", "        if LOG_VISUALS:\n", "            self.h5_frame = 0\n", "            self.log_points_file = h5py.File(f\"./logs/points_{start_time}\", \"w\")\n", "            self.log_normals_file = h5py.File(f\"./logs/normals_{start_time}\", \"w\")\n", "\n", "        self.val_losses = []\n", "        self.train_losses = []\n", "\n", "    def log_h5(self, points, normals):\n", "        dset = self.log_points_file.create_dataset(\n", "            name=str(self.h5_frame),\n", "            shape=points.shape,\n", "            dtype=np.float16,\n", "            compression=\"gzip\")\n", "        dset[:] = points\n", "        dset = self.log_normals_file.create_dataset(\n", "            name=str(self.h5_frame),\n", "            shape=normals.shape,\n", "            dtype=np.float16,\n", "            compression=\"gzip\")\n", "        dset[:] = normals\n", "        self.h5_frame += 1\n", "\n", "    def forward(self,\n", "                v: Float[th.Tensor, \"BS N 3\"],\n", "                n: Float[th.Tensor, \"BS N 3\"]) -> Tuple[Float[th.Tensor, \"BS N 3\"],       # v - vertices\n", "                                                        Int[th.Tensor, \"BS F 3\"],         # f - faces\n", "                                                        Float[th.Tensor, \"BS N 3\"],       # n - vertices normals\n", "                                                        Float[th.Tensor, \"BS GR GR GR\"]]: # field\n", "        field = self.dpsr(v, n)\n", "        field = self.form_optimizer(field)\n", "        v, f, n = self.field2mesh(field)\n", "        return v, f, n, field\n", "\n", "    def training_step(self, batch, batch_idx) -> Float[th.Tensor, \"\"]:\n", "        vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch\n", "\n", "        # subsampling augmentation: keep a random 50-100% of the input points\n", "        mask = th.rand((vertices.shape[1],), device=vertices.device) < (random.random() / 2.0 + 0.5)\n", "        vertices = vertices[:, mask]\n", "        vertices_normals = vertices_normals[:, mask]\n", "\n", "        vr, fr, nr, field_r = self(vertices, vertices_normals)\n", "\n", "        loss = self.metric(field_r, field_gt)\n", "        if LOG_VISUALS and (LOG_IDX == batch_idx):\n", "            self.log_h5(vr.squeeze(0).detach().cpu().numpy(), nr.squeeze(0).detach().cpu().numpy())\n", "        train_per_step_loss = loss.item()\n", "        self.train_losses.append(train_per_step_loss)\n", "\n", "        return loss\n", "\n", "    def on_train_epoch_end(self):\n", "        mean_train_per_epoch_loss = mean(self.train_losses)\n", "        self.log(\"mean_train_per_epoch_loss\", mean_train_per_epoch_loss, on_step=False, on_epoch=True)\n", "        self.train_losses = []\n", "\n", "    def validation_step(self, batch, batch_idx):\n", "        vertices, vertices_normals, vertices_gt, vertices_normals_gt, field_gt, adj = batch\n", "\n", "        vr, fr, nr, field_r = self(vertices, vertices_normals)\n", "\n", "        loss = self.metric(field_r, field_gt)\n", "        val_per_step_loss = loss.item()\n", "        self.val_losses.append(val_per_step_loss)\n", "        return loss\n", "\n", "    def on_validation_epoch_end(self):\n", "        mean_val_per_epoch_loss = mean(self.val_losses)\n", "        self.log(\"mean_val_per_epoch_loss\", mean_val_per_epoch_loss, on_step=False, on_epoch=True)\n", "        self.val_losses = []\n", "\n", "    def configure_optimizers(self):\n", "        optimizer = th.optim.Adam(self.parameters(), lr=LR)\n", "        scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)\n", "\n", "        return {\n", "            \"optimizer\": optimizer,\n", "            \"lr_scheduler\": {\n", "                \"scheduler\": scheduler,\n", "                \"monitor\": \"mean_val_per_epoch_loss\",\n", "                \"interval\": \"epoch\",\n", "                \"frequency\": 1,\n", "                \"strict\": True,\n", "                \"name\": None,\n", "            }\n", "        }\n" ] }, { "cell_type": "markdown", "id": "1fb2c5a5-43ee-4a4e-be08-0dcfcb6816de", "metadata": { "tags": [] }, "source": [ "# Loop" ] }, { "cell_type": "code", "execution_count": 10, "id": "c94c6a68-3986-48af-9da5-cab8c02a8b7b", "metadata": {}, "outputs": [], "source": [ "checkpoint_callback = ModelCheckpoint(\n", "    monitor='mean_val_per_epoch_loss',  # monitor the validation loss\n", "    mode='min',                         # 'min' keeps the lowest monitored value\n", "    save_top_k=1,                       # save only the best checkpoint (top 1)\n", ")" ] }, { "cell_type": "code", "execution_count": 11, "id": "03cdddbc-223e-4d40-9fb0-e663beddefda", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/user/conda/envs/senv/lib/python3.9/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525551200/work/aten/src/ATen/native/TensorShape.cpp:3190.)\n", "  return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n", "/home/user/conda/envs/senv/lib/python3.9/site-packages/lightning_fabric/connector.py:554: UserWarning: 16 is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!\n", "  rank_zero_warn(\n", "Using 16bit Automatic Mixed Precision (AMP)\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "Running in `fast_dev_run` mode: will run the requested loop using 300 batch(es). Logging and checkpointing is suppressed.\n", "You are using a CUDA device ('A100-PCIE-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. 
For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", "/home/user/conda/envs/senv/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:617: UserWarning: Checkpoint directory /home/jovyan/Mashurov/GINSAP/checkpoints exists and is not empty.\n", " rank_zero_warn(f\"Checkpoint directory {dirpath} exists and is not empty.\")\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params\n", "-------------------------------------------------\n", "0 | form_optimizer | FormOptimizer | 2.4 K \n", "1 | dpsr | DPSR | 0 \n", "2 | metric | MSELoss | 0 \n", "-------------------------------------------------\n", "2.4 K Trainable params\n", "0 Non-trainable params\n", "2.4 K Total params\n", "0.010 Total estimated model params size (MB)\n", "/home/user/conda/envs/senv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:442: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n", "/home/user/conda/envs/senv/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:442: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 64 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", " rank_zero_warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: 100%|██████████| 60/60 [00:17<00:00, 3.52it/s]\n", "Validation: 0it [00:00, ?it/s]\u001b[A\n", "Validation: 0%| | 0/84 [00:00