{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fit a volume via raymarching\n",
    "\n",
    "This tutorial shows how to fit a volume given a set of views of a scene using differentiable volumetric rendering.\n",
    "\n",
    "More specifically, this tutorial will explain how to:\n",
    "1. Create a differentiable volumetric renderer.\n",
    "2. Create a Volumetric model (including how to use the `Volumes` class).\n",
    "3. Fit the volume based on the images using the differentiable volumetric renderer. \n",
    "4. Visualize the predicted volume."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Install and Import modules\n",
    "Ensure `torch` and `torchvision` are installed. If `pytorch3d` is not installed, install it using the following cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import torch\n",
    "import subprocess\n",
    "need_pytorch3d=False\n",
    "try:\n",
    "    import pytorch3d\n",
    "except ModuleNotFoundError:\n",
    "    need_pytorch3d=True\n",
    "if need_pytorch3d:\n",
    "    pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
    "    version_str=\"\".join([\n",
    "        f\"py3{sys.version_info.minor}_cu\",\n",
    "        torch.version.cuda.replace(\".\",\"\"),\n",
    "        f\"_pyt{pyt_version_str}\"\n",
    "    ])\n",
    "    !pip install fvcore iopath\n",
    "    if sys.platform.startswith(\"linux\"):\n",
    "        print(\"Trying to install wheel for PyTorch3D\")\n",
    "        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
    "        pip_list = !pip freeze\n",
    "        need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for  i in pip_list)\n",
    "    if need_pytorch3d:\n",
    "        print(f\"failed to find/install wheel for {version_str}\")\n",
    "if need_pytorch3d:\n",
    "    print(\"Installing PyTorch3D from source\")\n",
    "    !pip install ninja\n",
    "    !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import time\n",
    "import json\n",
    "import glob\n",
    "import torch\n",
    "import math\n",
    "from tqdm.notebook import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from IPython import display\n",
    "\n",
    "# Data structures and functions for rendering\n",
    "from pytorch3d.structures import Volumes\n",
    "from pytorch3d.renderer import (\n",
    "    FoVPerspectiveCameras, \n",
    "    VolumeRenderer,\n",
    "    NDCMultinomialRaysampler,\n",
    "    EmissionAbsorptionRaymarcher\n",
    ")\n",
    "from pytorch3d.transforms import so3_exp_map\n",
    "\n",
    "# obtain the utilized device\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "    torch.cuda.set_device(device)\n",
    "else:\n",
    "    device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget https://raw.githubusercontent.com/facebookresearch/pytorch3d/main/docs/tutorials/utils/plot_image_grid.py\n",
    "!wget https://raw.githubusercontent.com/facebookresearch/pytorch3d/main/docs/tutorials/utils/generate_cow_renders.py\n",
    "from plot_image_grid import image_grid\n",
    "from generate_cow_renders import generate_cow_renders"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "OR if running locally uncomment and run the following cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from utils.generate_cow_renders import generate_cow_renders\n",
    "# from utils import image_grid"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Generate images of the scene and masks\n",
    "\n",
    "The following cell generates our training data.\n",
    "It renders the cow mesh from the `fit_textured_mesh.ipynb` tutorial from several viewpoints and returns:\n",
    "1. A batch of image and silhouette tensors that are produced by the cow mesh renderer.\n",
    "2. A set of cameras corresponding to each render.\n",
    "\n",
    "Note: For the purpose of this tutorial, which aims at explaining the details of volumetric rendering, we do not explain how the mesh rendering, implemented in the `generate_cow_renders` function, works. Please refer to `fit_textured_mesh.ipynb` for a detailed explanation of mesh rendering."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_cameras, target_images, target_silhouettes = generate_cow_renders(num_views=40)\n",
    "print(f'Generated {len(target_images)} images/silhouettes/cameras.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Initialize the volumetric renderer\n",
    "\n",
    "The following initializes a volumetric renderer that emits a ray from each pixel of a target image and samples a set of uniformly-spaced points along the ray. At each ray-point, the corresponding density and color value is obtained by querying the corresponding location in the volumetric model of the scene (the model is described & instantiated in a later cell).\n",
    "\n",
    "The renderer is composed of a *raymarcher* and a *raysampler*.\n",
    "- The *raysampler* is responsible for emitting rays from image pixels and sampling the points along them. Here, we use the `NDCMultinomialRaysampler` which follows the standard PyTorch3D coordinate grid convention (+X from right to left; +Y from bottom to top; +Z away from the user).\n",
    "- The *raymarcher* takes the densities and colors sampled along each ray and renders each ray into a color and an opacity value of the ray's source pixel. Here we use the `EmissionAbsorptionRaymarcher` which implements the standard Emission-Absorption raymarching algorithm."
   ]
  },
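  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The compositing performed by the raymarcher can be illustrated for a single ray with the plain-Python sketch below. It assumes that each sampled density in [0, 1] acts directly as the absorption of its ray point, which is our reading of how `EmissionAbsorptionRaymarcher` treats densities; `composite_ray` is a hypothetical helper written for illustration, whereas the actual raymarcher operates on batched tensors.\n",
    "\n",
    "```python\n",
    "from typing import List, Sequence, Tuple\n",
    "\n",
    "\n",
    "def composite_ray(\n",
    "    densities: Sequence[float], colors: Sequence[Sequence[float]]\n",
    ") -> Tuple[List[float], float]:\n",
    "    # Front-to-back emission-absorption compositing along one ray.\n",
    "    transmittance = 1.0  # fraction of light not yet absorbed in front of point i\n",
    "    rgb = [0.0, 0.0, 0.0]\n",
    "    for density, color in zip(densities, colors):\n",
    "        weight = density * transmittance  # contribution of point i to the pixel\n",
    "        for c in range(3):\n",
    "            rgb[c] += weight * color[c]\n",
    "        transmittance *= 1.0 - density\n",
    "    # The opacity equals the sum of all weights, i.e. 1 - final transmittance.\n",
    "    opacity = 1.0 - transmittance\n",
    "    return rgb, opacity\n",
    "\n",
    "\n",
    "# A transparent red point, a half-opaque green point, and an opaque blue point.\n",
    "rgb, opacity = composite_ray(\n",
    "    [0.0, 0.5, 1.0], [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]\n",
    ")\n",
    "print(rgb, opacity)\n",
    "```\n",
    "\n",
    "The red point contributes nothing, the green point takes half of the weight, and the opaque blue point takes the rest, so the pixel becomes an even mix of green and blue with full opacity. The renderer returns such a color and opacity for every pixel, concatenated along the last dimension, which is why the training loop splits its output into 3 color channels and 1 opacity channel."
   ]
  },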
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# render_size describes the size of both sides of the \n",
    "# rendered images in pixels. We set this to the same size\n",
    "# as the target images. I.e. we render at the same\n",
    "# size as the ground truth images.\n",
    "render_size = target_images.shape[1]\n",
    "\n",
    "# Our rendered scene is centered around (0,0,0) \n",
    "# and is enclosed inside a bounding box\n",
    "# whose side is roughly equal to 3.0 (world units).\n",
    "volume_extent_world = 3.0\n",
    "\n",
    "# 1) Instantiate the raysampler.\n",
    "# Here, NDCMultinomialRaysampler generates a rectangular image\n",
    "# grid of rays whose coordinates follow the PyTorch3D\n",
    "# coordinate conventions.\n",
    "# Since we use a volume of size 128^3, we sample n_pts_per_ray=150,\n",
    "# which roughly corresponds to a one ray-point per voxel.\n",
    "# We further set the min_depth=0.1 since there is no surface within\n",
    "# 0.1 units of any camera plane.\n",
    "raysampler = NDCMultinomialRaysampler(\n",
    "    image_width=render_size,\n",
    "    image_height=render_size,\n",
    "    n_pts_per_ray=150,\n",
    "    min_depth=0.1,\n",
    "    max_depth=volume_extent_world,\n",
    ")\n",
    "\n",
    "\n",
    "# 2) Instantiate the raymarcher.\n",
    "# Here, we use the standard EmissionAbsorptionRaymarcher \n",
    "# which marches along each ray in order to render\n",
    "# each ray into a single 3D color vector \n",
    "# and an opacity scalar.\n",
    "raymarcher = EmissionAbsorptionRaymarcher()\n",
    "\n",
    "# Finally, instantiate the volumetric render\n",
    "# with the raysampler and raymarcher objects.\n",
    "renderer = VolumeRenderer(\n",
    "    raysampler=raysampler, raymarcher=raymarcher,\n",
    ")"
   ]
  },
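  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a back-of-the-envelope check of the comment above about roughly one ray point per voxel, the following sketch compares the voxel size of the 128^3 volume created in Section 4 with the spacing between consecutive ray samples. It assumes that the raysampler places its `n_pts_per_ray` depths evenly between `min_depth` and `max_depth`, both endpoints included.\n",
    "\n",
    "```python\n",
    "volume_extent_world = 3.0  # side of the bounding box, as in the cell above\n",
    "volume_size = 128  # voxels per side, as used in Section 4\n",
    "min_depth, max_depth, n_pts_per_ray = 0.1, volume_extent_world, 150\n",
    "\n",
    "voxel_size = volume_extent_world / volume_size\n",
    "# Assumption: evenly spaced depths from min_depth to max_depth, endpoints included.\n",
    "sample_spacing = (max_depth - min_depth) / (n_pts_per_ray - 1)\n",
    "points_per_voxel = voxel_size / sample_spacing\n",
    "\n",
    "print(f'voxel size: {voxel_size:.4f}, sample spacing: {sample_spacing:.4f}')\n",
    "print(f'ray points per voxel width: {points_per_voxel:.2f}')\n",
    "```\n",
    "\n",
    "Under this assumption, there are about 1.2 samples per voxel width along the depth direction, which is what the comment means by roughly one ray point per voxel."
   ]
  },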
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Initialize the volumetric model\n",
    "\n",
    "Next we instantiate a volumetric model of the scene. This quantizes the 3D space to cubical voxels, where each voxel is described with a 3D vector representing the voxel's RGB color and a density scalar which describes the opacity of the voxel (ranging between [0-1], the higher the more opaque).\n",
    "\n",
    "In order to ensure the range of densities and colors is between [0-1], we represent both volume colors and densities in the logarithmic space. During the forward function of the model, the log-space values are passed through the sigmoid function to bring the log-space values to the correct range.\n",
    "\n",
    "Additionally, `VolumeModel` contains the renderer object. This object stays unaltered throughout the optimization.\n",
    "\n",
    "In this cell we also define the `huber` loss function which computes the discrepancy between the rendered colors and masks."
   ]
  },
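  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The initial values used by `VolumeModel` below can be checked with a plain-Python sigmoid. In this sketch, `sigmoid` and `logit` are hypothetical stand-ins for `torch.sigmoid` and `torch.logit`; `logit` shows how a desired density or color in (0, 1) maps back to the unconstrained value that the model stores.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def sigmoid(x: float) -> float:\n",
    "    # Maps any real number into the open interval (0, 1).\n",
    "    return 1.0 / (1.0 + math.exp(-x))\n",
    "\n",
    "\n",
    "def logit(p: float) -> float:\n",
    "    # Inverse of sigmoid, defined for p in (0, 1).\n",
    "    return math.log(p / (1.0 - p))\n",
    "\n",
    "\n",
    "initial_density = sigmoid(-4.0)  # log_densities are initialized to -4.0\n",
    "initial_color = sigmoid(0.0)  # log_colors are initialized to 0.0\n",
    "print(f'initial density: {initial_density:.3f}, initial color: {initial_color:.3f}')\n",
    "print(f'stored value for a target density of 0.9: {logit(0.9):.3f}')\n",
    "```\n",
    "\n",
    "A density of about 0.018 per voxel leaves the initial volume nearly transparent, and a value of 0.5 in every color channel is a neutral gray. Since the sigmoid maps every real number into (0, 1), the optimizer can update the stored parameters freely without any explicit clamping."
   ]
  },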
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VolumeModel(torch.nn.Module):\n",
    "    def __init__(self, renderer, volume_size=[64] * 3, voxel_size=0.1):\n",
    "        super().__init__()\n",
    "        # After evaluating torch.sigmoid(self.log_colors), we get \n",
    "        # densities close to zero.\n",
    "        self.log_densities = torch.nn.Parameter(-4.0 * torch.ones(1, *volume_size))\n",
    "        # After evaluating torch.sigmoid(self.log_colors), we get \n",
    "        # a neutral gray color everywhere.\n",
    "        self.log_colors = torch.nn.Parameter(torch.zeros(3, *volume_size))\n",
    "        self._voxel_size = voxel_size\n",
    "        # Store the renderer module as well.\n",
    "        self._renderer = renderer\n",
    "        \n",
    "    def forward(self, cameras):\n",
    "        batch_size = cameras.R.shape[0]\n",
    "\n",
    "        # Convert the log-space values to the densities/colors\n",
    "        densities = torch.sigmoid(self.log_densities)\n",
    "        colors = torch.sigmoid(self.log_colors)\n",
    "        \n",
    "        # Instantiate the Volumes object, making sure\n",
    "        # the densities and colors are correctly\n",
    "        # expanded batch_size-times.\n",
    "        volumes = Volumes(\n",
    "            densities = densities[None].expand(\n",
    "                batch_size, *self.log_densities.shape),\n",
    "            features = colors[None].expand(\n",
    "                batch_size, *self.log_colors.shape),\n",
    "            voxel_size=self._voxel_size,\n",
    "        )\n",
    "        \n",
    "        # Given cameras and volumes, run the renderer\n",
    "        # and return only the first output value \n",
    "        # (the 2nd output is a representation of the sampled\n",
    "        # rays which can be omitted for our purpose).\n",
    "        return self._renderer(cameras=cameras, volumes=volumes)[0]\n",
    "    \n",
    "# A helper function for evaluating the smooth L1 (huber) loss\n",
    "# between the rendered silhouettes and colors.\n",
    "def huber(x, y, scaling=0.1):\n",
    "    diff_sq = (x - y) ** 2\n",
    "    loss = ((1 + diff_sq / (scaling**2)).clamp(1e-4).sqrt() - 1) * float(scaling)\n",
    "    return loss"
   ]
  },
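  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `huber` helper above implements the pseudo-Huber loss `scaling * (sqrt(1 + (x - y)**2 / scaling**2) - 1)`. The scalar sketch below (the name `pseudo_huber` is ours, for illustration only) compares it with its two limiting forms: it behaves like `r**2 / (2 * scaling)` for residuals `r` much smaller than `scaling` and like `abs(r) - scaling` for residuals much larger than it. Large errors are therefore penalized linearly rather than quadratically, while the loss stays smooth at zero.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def pseudo_huber(residual: float, scaling: float = 0.1) -> float:\n",
    "    # Scalar version of huber(x, y, scaling) with residual = x - y.\n",
    "    return (math.sqrt(1.0 + (residual / scaling) ** 2) - 1.0) * scaling\n",
    "\n",
    "\n",
    "scaling = 0.1\n",
    "print('residual | loss | r**2 / (2 * s) | abs(r) - s')\n",
    "for residual in (0.001, 0.01, 0.1, 1.0, 10.0):\n",
    "    quadratic = residual ** 2 / (2 * scaling)  # good approximation when abs(r) << s\n",
    "    linear = abs(residual) - scaling  # good approximation when abs(r) >> s\n",
    "    loss = pseudo_huber(residual, scaling)\n",
    "    print(residual, round(loss, 6), round(quadratic, 6), round(linear, 6))\n",
    "```"
   ]
  },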
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Fit the volume\n",
    "\n",
    "Here we carry out the volume fitting with differentiable rendering.\n",
    "\n",
    "In order to fit the volume, we render it from the viewpoints of the `target_cameras`\n",
    "and compare the resulting renders with the observed `target_images` and `target_silhouettes`.\n",
    "\n",
    "The comparison is done by evaluating the mean huber (smooth-l1) error between corresponding\n",
    "pairs of `target_images`/`rendered_images` and `target_silhouettes`/`rendered_silhouettes`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First move all relevant variables to the correct device.\n",
    "target_cameras = target_cameras.to(device)\n",
    "target_images = target_images.to(device)\n",
    "target_silhouettes = target_silhouettes.to(device)\n",
    "\n",
    "# Instantiate the volumetric model.\n",
    "# We use a cubical volume with the size of \n",
    "# one side = 128. The size of each voxel of the volume \n",
    "# is set to volume_extent_world / volume_size s.t. the\n",
    "# volume represents the space enclosed in a 3D bounding box\n",
    "# centered at (0, 0, 0) with the size of each side equal to 3.\n",
    "volume_size = 128\n",
    "volume_model = VolumeModel(\n",
    "    renderer,\n",
    "    volume_size=[volume_size] * 3, \n",
    "    voxel_size = volume_extent_world / volume_size,\n",
    ").to(device)\n",
    "\n",
    "# Instantiate the Adam optimizer. We set its master learning rate to 0.1.\n",
    "lr = 0.1\n",
    "optimizer = torch.optim.Adam(volume_model.parameters(), lr=lr)\n",
    "\n",
    "# We do 300 Adam iterations and sample 10 random images in each minibatch.\n",
    "batch_size = 10\n",
    "n_iter = 300\n",
    "for iteration in range(n_iter):\n",
    "\n",
    "    # In case we reached the last 75% of iterations,\n",
    "    # decrease the learning rate of the optimizer 10-fold.\n",
    "    if iteration == round(n_iter * 0.75):\n",
    "        print('Decreasing LR 10-fold ...')\n",
    "        optimizer = torch.optim.Adam(\n",
    "            volume_model.parameters(), lr=lr * 0.1\n",
    "        )\n",
    "    \n",
    "    # Zero the optimizer gradient.\n",
    "    optimizer.zero_grad()\n",
    "    \n",
    "    # Sample random batch indices.\n",
    "    batch_idx = torch.randperm(len(target_cameras))[:batch_size]\n",
    "    \n",
    "    # Sample the minibatch of cameras.\n",
    "    batch_cameras = FoVPerspectiveCameras(\n",
    "        R = target_cameras.R[batch_idx], \n",
    "        T = target_cameras.T[batch_idx], \n",
    "        znear = target_cameras.znear[batch_idx],\n",
    "        zfar = target_cameras.zfar[batch_idx],\n",
    "        aspect_ratio = target_cameras.aspect_ratio[batch_idx],\n",
    "        fov = target_cameras.fov[batch_idx],\n",
    "        device = device,\n",
    "    )\n",
    "    \n",
    "    # Evaluate the volumetric model.\n",
    "    rendered_images, rendered_silhouettes = volume_model(\n",
    "        batch_cameras\n",
    "    ).split([3, 1], dim=-1)\n",
    "    \n",
    "    # Compute the silhouette error as the mean huber\n",
    "    # loss between the predicted masks and the\n",
    "    # target silhouettes.\n",
    "    sil_err = huber(\n",
    "        rendered_silhouettes[..., 0], target_silhouettes[batch_idx],\n",
    "    ).abs().mean()\n",
    "\n",
    "    # Compute the color error as the mean huber\n",
    "    # loss between the rendered colors and the\n",
    "    # target ground truth images.\n",
    "    color_err = huber(\n",
    "        rendered_images, target_images[batch_idx],\n",
    "    ).abs().mean()\n",
    "    \n",
    "    # The optimization loss is a simple\n",
    "    # sum of the color and silhouette errors.\n",
    "    loss = color_err + sil_err \n",
    "    \n",
    "    # Print the current values of the losses.\n",
    "    if iteration % 10 == 0:\n",
    "        print(\n",
    "            f'Iteration {iteration:05d}:'\n",
    "            + f' color_err = {float(color_err):1.2e}'\n",
    "            + f' mask_err = {float(sil_err):1.2e}'\n",
    "        )\n",
    "    \n",
    "    # Take the optimization step.\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    \n",
    "    # Visualize the renders every 40 iterations.\n",
    "    if iteration % 40 == 0:\n",
    "        # Visualize only a single randomly selected element of the batch.\n",
    "        im_show_idx = int(torch.randint(low=0, high=batch_size, size=(1,)))\n",
    "        fig, ax = plt.subplots(2, 2, figsize=(10, 10))\n",
    "        ax = ax.ravel()\n",
    "        clamp_and_detach = lambda x: x.clamp(0.0, 1.0).cpu().detach().numpy()\n",
    "        ax[0].imshow(clamp_and_detach(rendered_images[im_show_idx]))\n",
    "        ax[1].imshow(clamp_and_detach(target_images[batch_idx[im_show_idx], ..., :3]))\n",
    "        ax[2].imshow(clamp_and_detach(rendered_silhouettes[im_show_idx, ..., 0]))\n",
    "        ax[3].imshow(clamp_and_detach(target_silhouettes[batch_idx[im_show_idx]]))\n",
    "        for ax_, title_ in zip(\n",
    "            ax, \n",
    "            (\"rendered image\", \"target image\", \"rendered silhouette\", \"target silhouette\")\n",
    "        ):\n",
    "            ax_.grid(\"off\")\n",
    "            ax_.axis(\"off\")\n",
    "            ax_.set_title(title_)\n",
    "        fig.canvas.draw(); fig.show()\n",
    "        display.clear_output(wait=True)\n",
    "        display.display(fig)"
   ]
  },
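  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop above uses a simple step schedule for the learning rate. The sketch below (`learning_rate` is a hypothetical helper, not part of the notebook) reproduces that schedule in plain Python to make the iteration counts explicit. Note that the loop applies the decay by creating a new Adam optimizer, which also discards Adam's running moment estimates; stepping `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[225], gamma=0.1)` once per iteration would be one way to apply the same decay while keeping the optimizer state.\n",
    "\n",
    "```python\n",
    "def learning_rate(iteration: int, base_lr: float = 0.1, n_iter: int = 300,\n",
    "                  decay_fraction: float = 0.75, decay_factor: float = 0.1) -> float:\n",
    "    # base_lr until iteration round(n_iter * decay_fraction), then base_lr * decay_factor.\n",
    "    if iteration < round(n_iter * decay_fraction):\n",
    "        return base_lr\n",
    "    return base_lr * decay_factor\n",
    "\n",
    "\n",
    "schedule = [learning_rate(i) for i in range(300)]\n",
    "n_full = sum(1 for lr in schedule if lr == 0.1)\n",
    "print(n_full, 'iterations at lr = 0.1,', len(schedule) - n_full, 'at the decayed rate')\n",
    "```"
   ]
  },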
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Visualizing the optimized volume\n",
    "\n",
    "Finally, we visualize the optimized volume by rendering from multiple viewpoints that rotate around the volume's y-axis."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_rotating_volume(volume_model, n_frames = 50):\n",
    "    logRs = torch.zeros(n_frames, 3, device=device)\n",
    "    logRs[:, 1] = torch.linspace(0.0, 2.0 * 3.14, n_frames, device=device)\n",
    "    Rs = so3_exp_map(logRs)\n",
    "    Ts = torch.zeros(n_frames, 3, device=device)\n",
    "    Ts[:, 2] = 2.7\n",
    "    frames = []\n",
    "    print('Generating rotating volume ...')\n",
    "    for R, T in zip(tqdm(Rs), Ts):\n",
    "        camera = FoVPerspectiveCameras(\n",
    "            R=R[None], \n",
    "            T=T[None], \n",
    "            znear = target_cameras.znear[0],\n",
    "            zfar = target_cameras.zfar[0],\n",
    "            aspect_ratio = target_cameras.aspect_ratio[0],\n",
    "            fov = target_cameras.fov[0],\n",
    "            device=device,\n",
    "        )\n",
    "        frames.append(volume_model(camera)[..., :3].clamp(0.0, 1.0))\n",
    "    return torch.cat(frames)\n",
    "    \n",
    "with torch.no_grad():\n",
    "    rotating_volume_frames = generate_rotating_volume(volume_model, n_frames=7*4)\n",
    "\n",
    "image_grid(rotating_volume_frames.clamp(0., 1.).cpu().numpy(), rows=4, cols=7, rgb=True, fill=True)\n",
    "plt.show()"
   ]
  },
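  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see where the cameras of the rotating visualization are placed, the sketch below recomputes their centers in plain Python. It assumes PyTorch3D's row-vector convention `X_view = X_world @ R + T` for world-to-view transforms, and that `so3_exp_map` turns the axis-angle vector (0, theta, 0) into the standard rotation by theta about the y-axis; `rotation_about_y` and `camera_center` are hypothetical helpers.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def rotation_about_y(theta: float):\n",
    "    # Rodrigues' formula for the axis-angle vector (0, theta, 0).\n",
    "    c, s = math.cos(theta), math.sin(theta)\n",
    "    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]\n",
    "\n",
    "\n",
    "def camera_center(R, T):\n",
    "    # With X_view = X_world @ R + T, the center C solves C @ R + T = 0,\n",
    "    # so C = -T @ R^T, i.e. C[j] = -sum_k T[k] * R[j][k].\n",
    "    return [-sum(T[k] * R[j][k] for k in range(3)) for j in range(3)]\n",
    "\n",
    "\n",
    "n_frames = 7 * 4\n",
    "T = [0.0, 0.0, 2.7]\n",
    "# Evenly spaced angles over a full turn, both endpoints included (as torch.linspace does).\n",
    "angles = [2.0 * math.pi * i / (n_frames - 1) for i in range(n_frames)]\n",
    "centers = [camera_center(rotation_about_y(a), T) for a in angles]\n",
    "for C in centers[:3]:\n",
    "    distance = math.sqrt(sum(v * v for v in C))\n",
    "    print([round(v, 3) for v in C], round(distance, 3))\n",
    "```\n",
    "\n",
    "Under these assumptions, every camera center lies in the y = 0 plane at distance 2.7 from the origin, and the world origin maps to view coordinates (0, 0, 2.7), i.e. it lies on each camera's optical axis. The cameras therefore orbit the volume around its y-axis while looking at its center."
   ]
  },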
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Conclusion\n",
    "\n",
    "In this tutorial, we have shown how to optimize a 3D volumetric representation of a scene such that the renders of the volume from known viewpoints match the observed images for each viewpoint. The rendering was carried out using the PyTorch3D's volumetric renderer composed of an `NDCMultinomialRaysampler` and an `EmissionAbsorptionRaymarcher`."
   ]
  }
 ],
 "metadata": {
  "bento_stylesheets": {
   "bento/extensions/flow/main.css": true,
   "bento/extensions/kernel_selector/main.css": true,
   "bento/extensions/kernel_ui/main.css": true,
   "bento/extensions/new_kernel/main.css": true,
   "bento/extensions/system_usage/main.css": true,
   "bento/extensions/theme/main.css": true
  },
  "kernelspec": {
   "display_name": "pytorch3d_etc (local)",
   "language": "python",
   "name": "pytorch3d_etc_local"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5+"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}