File size: 32,586 Bytes
7088d16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "_Ip8kp4TfBLZ"
   },
   "outputs": [],
   "source": [
    "# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "kuXHJv44fBLe"
   },
   "source": [
    "# Fit a mesh via rendering\n",
    "\n",
    "This tutorial shows how to:\n",
    "- Load a mesh and textures from an `.obj` file. \n",
    "- Create a synthetic dataset by rendering a textured mesh from multiple viewpoints\n",
    "- Fit a mesh to the observed synthetic images using differentiable silhouette rendering\n",
    "- Fit a mesh and its textures using differentiable textured rendering"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Bnj3THhzfBLf"
   },
   "source": [
    "## 0. Install and Import modules"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "okLalbR_g7NS"
   },
   "source": [
    "Ensure `torch` and `torchvision` are installed. If `pytorch3d` is not installed, install it using the following cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "musUWTglgxSB"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import torch\n",
    "import subprocess\n",
    "need_pytorch3d=False\n",
    "try:\n",
    "    import pytorch3d\n",
    "except ModuleNotFoundError:\n",
    "    need_pytorch3d=True\n",
    "if need_pytorch3d:\n",
    "    pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
    "    version_str=\"\".join([\n",
    "        f\"py3{sys.version_info.minor}_cu\",\n",
    "        torch.version.cuda.replace(\".\",\"\"),\n",
    "        f\"_pyt{pyt_version_str}\"\n",
    "    ])\n",
    "    !pip install fvcore iopath\n",
    "    if sys.platform.startswith(\"linux\"):\n",
    "        print(\"Trying to install wheel for PyTorch3D\")\n",
    "        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
    "        pip_list = !pip freeze\n",
    "        need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for  i in pip_list)\n",
    "    if need_pytorch3d:\n",
    "        print(f\"failed to find/install wheel for {version_str}\")\n",
    "if need_pytorch3d:\n",
    "    print(\"Installing PyTorch3D from source\")\n",
    "    !pip install ninja\n",
    "    !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "nX99zdoffBLg"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from pytorch3d.utils import ico_sphere\n",
    "import numpy as np\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# Util function for loading meshes\n",
    "from pytorch3d.io import load_objs_as_meshes, save_obj\n",
    "\n",
    "from pytorch3d.loss import (\n",
    "    chamfer_distance, \n",
    "    mesh_edge_loss, \n",
    "    mesh_laplacian_smoothing, \n",
    "    mesh_normal_consistency,\n",
    ")\n",
    "\n",
    "# Data structures and functions for rendering\n",
    "from pytorch3d.structures import Meshes\n",
    "from pytorch3d.renderer import (\n",
    "    look_at_view_transform,\n",
    "    FoVPerspectiveCameras, \n",
    "    PointLights, \n",
    "    DirectionalLights, \n",
    "    Materials, \n",
    "    RasterizationSettings, \n",
    "    MeshRenderer, \n",
    "    MeshRasterizer,  \n",
    "    SoftPhongShader,\n",
    "    SoftSilhouetteShader,\n",
    "    SoftPhongShader,\n",
    "    TexturesVertex\n",
    ")\n",
    "\n",
    "# add path for demo utils functions \n",
    "import sys\n",
    "import os\n",
    "sys.path.append(os.path.abspath(''))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Lxmehq6Zhrzv"
   },
   "source": [
    "If using **Google Colab**, fetch the utils file for plotting image grids:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "HZozr3Pmho-5"
   },
   "outputs": [],
   "source": [
    "!wget https://raw.githubusercontent.com/facebookresearch/pytorch3d/main/docs/tutorials/utils/plot_image_grid.py\n",
    "from plot_image_grid import image_grid"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "g4B62MzYiJUM"
   },
   "source": [
    "OR if running **locally** uncomment and run the following cell:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "paJ4Im8ahl7O"
   },
   "outputs": [],
   "source": [
    "#  from utils.plot_image_grid import image_grid"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "collapsed": true,
    "id": "5jGq772XfBLk"
   },
   "source": [
    "## 1. Load a mesh and texture file\n",
    "\n",
    "Load an `.obj` file and its associated `.mtl` file and create a **Textures** and **Meshes** object. \n",
    "\n",
    "**Meshes** is a unique data structure provided in PyTorch3D for working with batches of meshes of different sizes. \n",
    "\n",
    "**TexturesVertex** is an auxiliary data structure for storing per-vertex RGB texture information about meshes. \n",
    "\n",
    "**Meshes** has several class methods which are used throughout the rendering pipeline."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "a8eU4zo5jd_H"
   },
   "source": [
    "If running this notebook using **Google Colab**, run the following cell to fetch the mesh `.obj`, `.mtl` and texture files and save them at the path `data/cow_mesh`.\n",
    "If running locally, the data is already available at the correct path. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "tTm0cVuOjb1W"
   },
   "outputs": [],
   "source": [
    "!mkdir -p data/cow_mesh\n",
    "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj\n",
    "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl\n",
    "!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "gi5Kd0GafBLl"
   },
   "outputs": [],
   "source": [
    "# Setup\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "    torch.cuda.set_device(device)\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "# Set paths\n",
    "DATA_DIR = \"./data\"\n",
    "obj_filename = os.path.join(DATA_DIR, \"cow_mesh/cow.obj\")\n",
    "\n",
    "# Load obj file\n",
    "mesh = load_objs_as_meshes([obj_filename], device=device)\n",
    "\n",
    "# We center the target mesh at (0,0,0) and normalize its scale so that its \n",
    "# largest absolute vertex coordinate is 1. (scale, center) will be used to \n",
    "# bring the predicted mesh to its original center and scale.  Note that \n",
    "# normalizing the target mesh speeds up the optimization but is not necessary!\n",
    "verts = mesh.verts_packed()\n",
    "N = verts.shape[0]\n",
    "center = verts.mean(0)\n",
    "scale = max((verts - center).abs().max(0)[0])\n",
    "mesh.offset_verts_(-center)\n",
    "mesh.scale_verts_((1.0 / float(scale)));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "17c4xmtyfBMH"
   },
   "source": [
    "## 2. Dataset Creation\n",
    "\n",
    "We sample different camera positions that encode multiple viewpoints of the cow.  We create a renderer with a shader that performs texture map interpolation.  We render a synthetic dataset of images of the textured cow mesh from multiple viewpoints.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "CDQKebNNfBMI"
   },
   "outputs": [],
   "source": [
    "# the number of different viewpoints from which we want to render the mesh.\n",
    "num_views = 20\n",
    "\n",
    "# Get a batch of viewing angles. \n",
    "elev = torch.linspace(0, 360, num_views)\n",
    "azim = torch.linspace(-180, 180, num_views)\n",
    "\n",
    "# Place a point light in front of the object. The front of the cow is \n",
    "# facing the -z direction. \n",
    "lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])\n",
    "\n",
    "# Initialize an OpenGL perspective camera that represents a batch of different \n",
    "# viewing angles. All the cameras helper methods support mixed type inputs and \n",
    "# broadcasting. So we can view the object from a distance of dist=2.7, and \n",
    "# then specify elevation and azimuth angles for each viewpoint as tensors. \n",
    "R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)\n",
    "cameras = FoVPerspectiveCameras(device=device, R=R, T=T)\n",
    "\n",
    "# We arbitrarily choose one particular view that will be used to visualize \n",
    "# results\n",
    "camera = FoVPerspectiveCameras(device=device, R=R[None, 1, ...], \n",
    "                                  T=T[None, 1, ...]) \n",
    "\n",
    "# Define the settings for rasterization and shading. Here we set the output \n",
    "# image to be of size 128X128. As we are rendering images for visualization \n",
    "# purposes only we will set faces_per_pixel=1 and blur_radius=0.0. Refer to \n",
    "# rasterize_meshes.py for explanations of these parameters.  We also leave \n",
    "# bin_size and max_faces_per_bin to their default values of None, which sets \n",
    "# their values using heuristics and ensures that the faster coarse-to-fine \n",
    "# rasterization method is used.  Refer to docs/notes/renderer.md for an \n",
    "# explanation of the difference between naive and coarse-to-fine rasterization. \n",
    "raster_settings = RasterizationSettings(\n",
    "    image_size=128, \n",
    "    blur_radius=0.0, \n",
    "    faces_per_pixel=1, \n",
    ")\n",
    "\n",
    "# Create a Phong renderer by composing a rasterizer and a shader. The textured \n",
    "# Phong shader will interpolate the texture uv coordinates for each vertex, \n",
    "# sample from a texture image and apply the Phong lighting model\n",
    "renderer = MeshRenderer(\n",
    "    rasterizer=MeshRasterizer(\n",
    "        cameras=camera, \n",
    "        raster_settings=raster_settings\n",
    "    ),\n",
    "    shader=SoftPhongShader(\n",
    "        device=device, \n",
    "        cameras=camera,\n",
    "        lights=lights\n",
    "    )\n",
    ")\n",
    "\n",
    "# Create a batch of meshes by repeating the cow mesh and associated textures. \n",
    "# Meshes has a useful `extend` method which allows us do this very easily. \n",
    "# This also extends the textures. \n",
    "meshes = mesh.extend(num_views)\n",
    "\n",
    "# Render the cow mesh from each viewing angle\n",
    "target_images = renderer(meshes, cameras=cameras, lights=lights)\n",
    "\n",
    "# Our multi-view cow dataset will be represented by these 2 lists of tensors,\n",
    "# each of length num_views.\n",
    "target_rgb = [target_images[i, ..., :3] for i in range(num_views)]\n",
    "target_cameras = [FoVPerspectiveCameras(device=device, R=R[None, i, ...], \n",
    "                                           T=T[None, i, ...]) for i in range(num_views)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "TppB4PVmR1Rc"
   },
   "source": [
    "Visualize the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "HHE0CnbVR1Rd"
   },
   "outputs": [],
   "source": [
    "# RGB images\n",
    "image_grid(target_images.cpu().numpy(), rows=4, cols=5, rgb=True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "gOb4rYx65E8z"
   },
   "source": [
    "Later in this tutorial, we will fit a mesh to the rendered RGB images, as well as to images of just the cow silhouette.  For the latter case, we will render a dataset of silhouette images.  Most shaders in PyTorch3D will output an alpha channel along with the RGB image as a 4th channel in an RGBA image.  The alpha channel encodes the probability that each pixel belongs to the foreground of the object.  We construct a soft silhouette shader to render this alpha channel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "iP_g-nwX4exM"
   },
   "outputs": [],
   "source": [
    "# Rasterization settings for silhouette rendering  \n",
    "sigma = 1e-4\n",
    "raster_settings_silhouette = RasterizationSettings(\n",
    "    image_size=128, \n",
    "    blur_radius=np.log(1. / 1e-4 - 1.)*sigma, \n",
    "    faces_per_pixel=50, \n",
    ")\n",
    "\n",
    "# Silhouette renderer \n",
    "renderer_silhouette = MeshRenderer(\n",
    "    rasterizer=MeshRasterizer(\n",
    "        cameras=camera, \n",
    "        raster_settings=raster_settings_silhouette\n",
    "    ),\n",
    "    shader=SoftSilhouetteShader()\n",
    ")\n",
    "\n",
    "# Render silhouette images.  The 4th channel (index 3) of the rendering \n",
    "# output is the alpha/silhouette channel\n",
    "silhouette_images = renderer_silhouette(meshes, cameras=cameras, lights=lights)\n",
    "target_silhouette = [silhouette_images[i, ..., 3] for i in range(num_views)]\n",
    "\n",
    "# Visualize silhouette images\n",
    "image_grid(silhouette_images.cpu().numpy(), rows=4, cols=5, rgb=False)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "t3qphI1ElUb5"
   },
   "source": [
    "## 3. Mesh prediction via silhouette rendering\n",
    "In the previous section, we created a dataset of images of multiple viewpoints of a cow.  In this section, we predict a mesh by observing those target images without any knowledge of the ground truth cow mesh.  We assume we know the position of the cameras and lighting.\n",
    "\n",
    "We first define some helper functions to visualize the results of our mesh prediction:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "eeWYHROrR1Rh"
   },
   "outputs": [],
   "source": [
    "# Show a visualization comparing the rendered predicted mesh to the ground truth \n",
    "# mesh\n",
    "def visualize_prediction(predicted_mesh, renderer=renderer_silhouette, \n",
    "                         target_image=target_rgb[1], title='', \n",
    "                         silhouette=False):\n",
    "    inds = 3 if silhouette else range(3)\n",
    "    with torch.no_grad():\n",
    "        predicted_images = renderer(predicted_mesh)\n",
    "    plt.figure(figsize=(20, 10))\n",
    "    plt.subplot(1, 2, 1)\n",
    "    plt.imshow(predicted_images[0, ..., inds].cpu().detach().numpy())\n",
    "\n",
    "    plt.subplot(1, 2, 2)\n",
    "    plt.imshow(target_image.cpu().detach().numpy())\n",
    "    plt.title(title)\n",
    "    plt.axis(\"off\")\n",
    "\n",
    "# Plot losses as a function of optimization iteration\n",
    "def plot_losses(losses):\n",
    "    fig = plt.figure(figsize=(13, 5))\n",
    "    ax = fig.gca()\n",
    "    for k, l in losses.items():\n",
    "        ax.plot(l['values'], label=k + \" loss\")\n",
    "    ax.legend(fontsize=\"16\")\n",
    "    ax.set_xlabel(\"Iteration\", fontsize=\"16\")\n",
    "    ax.set_ylabel(\"Loss\", fontsize=\"16\")\n",
    "    ax.set_title(\"Loss vs iterations\", fontsize=\"16\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "PpsvBpuMR1Ri"
   },
   "source": [
    "Starting from a sphere mesh, we will learn offsets of each vertex such that the predicted mesh silhouette is more similar to the target silhouette image at each optimization step.  We begin by loading our initial sphere mesh:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "i989ARH1R1Rj"
   },
   "outputs": [],
   "source": [
    "# We initialize the source shape to be a sphere of radius 1.  \n",
    "src_mesh = ico_sphere(4, device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "f5xVtgLNDvC5"
   },
   "source": [
    "We create a new differentiable renderer for rendering the silhouette of our predicted mesh:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "sXfjzgG4DsDJ"
   },
   "outputs": [],
   "source": [
    "# Rasterization settings for differentiable rendering, where the blur_radius\n",
    "# initialization is based on Liu et al, 'Soft Rasterizer: A Differentiable \n",
    "# Renderer for Image-based 3D Reasoning', ICCV 2019\n",
    "sigma = 1e-4\n",
    "raster_settings_soft = RasterizationSettings(\n",
    "    image_size=128, \n",
    "    blur_radius=np.log(1. / 1e-4 - 1.)*sigma, \n",
    "    faces_per_pixel=50, \n",
    ")\n",
    "\n",
    "# Silhouette renderer \n",
    "renderer_silhouette = MeshRenderer(\n",
    "    rasterizer=MeshRasterizer(\n",
    "        cameras=camera, \n",
    "        raster_settings=raster_settings_soft\n",
    "    ),\n",
    "    shader=SoftSilhouetteShader()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "SGJKbCB6R1Rk"
   },
   "source": [
    "We initialize settings, losses, and the optimizer that will be used to iteratively fit our mesh to the target silhouettes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "0sLrKv_MEULh"
   },
   "outputs": [],
   "source": [
    "# Number of views to optimize over in each SGD iteration\n",
    "num_views_per_iteration = 2\n",
    "# Number of optimization steps\n",
    "Niter = 2000\n",
    "# Plot period for the losses\n",
    "plot_period = 250\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "# Optimize using rendered silhouette image loss, mesh edge loss, mesh normal \n",
    "# consistency, and mesh laplacian smoothing\n",
    "losses = {\"silhouette\": {\"weight\": 1.0, \"values\": []},\n",
    "          \"edge\": {\"weight\": 1.0, \"values\": []},\n",
    "          \"normal\": {\"weight\": 0.01, \"values\": []},\n",
    "          \"laplacian\": {\"weight\": 1.0, \"values\": []},\n",
    "         }\n",
    "\n",
    "# Losses to smooth / regularize the mesh shape\n",
    "def update_mesh_shape_prior_losses(mesh, loss):\n",
    "    # mesh edge length loss\n",
    "    loss[\"edge\"] = mesh_edge_loss(mesh)\n",
    "    \n",
    "    # mesh normal consistency\n",
    "    loss[\"normal\"] = mesh_normal_consistency(mesh)\n",
    "    \n",
    "    # mesh laplacian smoothing\n",
    "    loss[\"laplacian\"] = mesh_laplacian_smoothing(mesh, method=\"uniform\")\n",
    "\n",
    "# We will learn to deform the source mesh by offsetting its vertices\n",
    "# The shape of the deform parameters is equal to the total number of vertices in\n",
    "# src_mesh\n",
    "verts_shape = src_mesh.verts_packed().shape\n",
    "deform_verts = torch.full(verts_shape, 0.0, device=device, requires_grad=True)\n",
    "\n",
    "# The optimizer\n",
    "optimizer = torch.optim.SGD([deform_verts], lr=1.0, momentum=0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "QLc9zK8lEqFS"
   },
   "source": [
    "We write an optimization loop to iteratively refine our predicted mesh from the sphere mesh into a mesh that matches the silhouettes of the target images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "gCfepfOoR1Rl"
   },
   "outputs": [],
   "source": [
    "loop = tqdm(range(Niter))\n",
    "\n",
    "for i in loop:\n",
    "    # Reset the gradients from the previous iteration\n",
    "    optimizer.zero_grad()\n",
    "    \n",
    "    # Deform the mesh\n",
    "    new_src_mesh = src_mesh.offset_verts(deform_verts)\n",
    "    \n",
    "    # Losses to smooth /regularize the mesh shape\n",
    "    loss = {k: torch.tensor(0.0, device=device) for k in losses}\n",
    "    update_mesh_shape_prior_losses(new_src_mesh, loss)\n",
    "    \n",
    "    # Compute the average silhouette loss over two random views, as the average \n",
    "    # squared L2 distance between the predicted silhouette and the target \n",
    "    # silhouette from our dataset\n",
    "    for j in np.random.permutation(num_views).tolist()[:num_views_per_iteration]:\n",
    "        images_predicted = renderer_silhouette(new_src_mesh, cameras=target_cameras[j], lights=lights)\n",
    "        predicted_silhouette = images_predicted[..., 3]\n",
    "        loss_silhouette = ((predicted_silhouette - target_silhouette[j]) ** 2).mean()\n",
    "        loss[\"silhouette\"] += loss_silhouette / num_views_per_iteration\n",
    "    \n",
    "    # Weighted sum of the losses\n",
    "    sum_loss = torch.tensor(0.0, device=device)\n",
    "    for k, l in loss.items():\n",
    "        sum_loss += l * losses[k][\"weight\"]\n",
    "        losses[k][\"values\"].append(float(l.detach().cpu()))\n",
    "\n",
    "    \n",
    "    # Print the losses\n",
    "    loop.set_description(\"total_loss = %.6f\" % sum_loss)\n",
    "    \n",
    "    # Plot mesh\n",
    "    if i % plot_period == 0:\n",
    "        visualize_prediction(new_src_mesh, title=\"iter: %d\" % i, silhouette=True,\n",
    "                             target_image=target_silhouette[1])\n",
    "        \n",
    "    # Optimization step\n",
    "    sum_loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "CX4huayKR1Rm",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "visualize_prediction(new_src_mesh, silhouette=True, \n",
    "                     target_image=target_silhouette[1])\n",
    "plot_losses(losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "XJDsJQmrR1Ro"
   },
   "source": [
    "## 4. Mesh and texture prediction via textured rendering\n",
    "We can predict both the mesh and its texture if we add an additional loss based on comparing the predicted rendered RGB image to the target image. As before, we start with a sphere mesh.  We learn both translational offsets and RGB texture colors for each vertex in the sphere mesh.  Since our loss is based on rendered RGB pixel values instead of just the silhouette, we use a **SoftPhongShader** instead of a **SoftSilhouetteShader**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "aZObyIt9R1Ro"
   },
   "outputs": [],
   "source": [
    "# Rasterization settings for differentiable rendering, where the blur_radius\n",
    "# initialization is based on Liu et al, 'Soft Rasterizer: A Differentiable \n",
    "# Renderer for Image-based 3D Reasoning', ICCV 2019\n",
    "sigma = 1e-4\n",
    "raster_settings_soft = RasterizationSettings(\n",
    "    image_size=128, \n",
    "    blur_radius=np.log(1. / 1e-4 - 1.)*sigma, \n",
    "    faces_per_pixel=50, \n",
    "    perspective_correct=False, \n",
    ")\n",
    "\n",
    "# Differentiable soft renderer using per vertex RGB colors for texture\n",
    "renderer_textured = MeshRenderer(\n",
    "    rasterizer=MeshRasterizer(\n",
    "        cameras=camera, \n",
    "        raster_settings=raster_settings_soft\n",
    "    ),\n",
    "    shader=SoftPhongShader(device=device, \n",
    "        cameras=camera,\n",
    "        lights=lights)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "NM7gJux8GMQX"
   },
   "source": [
    "We initialize settings, losses, and the optimizer that will be used to iteratively fit our mesh to the target RGB images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "BS6LAQquF3wq"
   },
   "outputs": [],
   "source": [
    "# Number of views to optimize over in each SGD iteration\n",
    "num_views_per_iteration = 2\n",
    "# Number of optimization steps\n",
    "Niter = 2000\n",
    "# Plot period for the losses\n",
    "plot_period = 250\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "# Optimize using rendered RGB image loss, rendered silhouette image loss, mesh \n",
    "# edge loss, mesh normal consistency, and mesh laplacian smoothing\n",
    "losses = {\"rgb\": {\"weight\": 1.0, \"values\": []},\n",
    "          \"silhouette\": {\"weight\": 1.0, \"values\": []},\n",
    "          \"edge\": {\"weight\": 1.0, \"values\": []},\n",
    "          \"normal\": {\"weight\": 0.01, \"values\": []},\n",
    "          \"laplacian\": {\"weight\": 1.0, \"values\": []},\n",
    "         }\n",
    "\n",
    "# We will learn to deform the source mesh by offsetting its vertices\n",
    "# The shape of the deform parameters is equal to the total number of vertices in \n",
    "# src_mesh\n",
    "verts_shape = src_mesh.verts_packed().shape\n",
    "deform_verts = torch.full(verts_shape, 0.0, device=device, requires_grad=True)\n",
    "\n",
    "# We will also learn per vertex colors for our sphere mesh that define texture \n",
    "# of the mesh\n",
    "sphere_verts_rgb = torch.full([1, verts_shape[0], 3], 0.5, device=device, requires_grad=True)\n",
    "\n",
    "# The optimizer\n",
    "optimizer = torch.optim.SGD([deform_verts, sphere_verts_rgb], lr=1.0, momentum=0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "tzIAycuUR1Rq"
   },
   "source": [
    "We write an optimization loop to iteratively refine our predicted mesh and its vertex colors from the sphere mesh into a mesh that matches the target images:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "EKEH2p8-R1Rr"
   },
   "outputs": [],
   "source": [
    "loop = tqdm(range(Niter))\n",
    "\n",
    "for i in loop:\n",
    "    # Initialize optimizer\n",
    "    optimizer.zero_grad()\n",
    "    \n",
    "    # Deform the mesh\n",
    "    new_src_mesh = src_mesh.offset_verts(deform_verts)\n",
    "    \n",
    "    # Add per vertex colors to texture the mesh\n",
    "    new_src_mesh.textures = TexturesVertex(verts_features=sphere_verts_rgb) \n",
    "    \n",
    "    # Losses to smooth / regularize the mesh shape\n",
    "    loss = {k: torch.tensor(0.0, device=device) for k in losses}\n",
    "    update_mesh_shape_prior_losses(new_src_mesh, loss)\n",
    "    \n",
    "    # Randomly select num_views_per_iteration views to optimize over in this\n",
    "    # iteration.  Compared to using just one view, this helps resolve\n",
    "    # ambiguities between updating mesh shape vs. updating mesh texture\n",
    "    for j in np.random.permutation(num_views).tolist()[:num_views_per_iteration]:\n",
    "        images_predicted = renderer_textured(new_src_mesh, cameras=target_cameras[j], lights=lights)\n",
    "\n",
    "        # Squared L2 distance between the predicted silhouette and the target \n",
    "        # silhouette from our dataset\n",
    "        predicted_silhouette = images_predicted[..., 3]\n",
    "        loss_silhouette = ((predicted_silhouette - target_silhouette[j]) ** 2).mean()\n",
    "        loss[\"silhouette\"] += loss_silhouette / num_views_per_iteration\n",
    "        \n",
    "        # Squared L2 distance between the predicted RGB image and the target \n",
    "        # image from our dataset\n",
    "        predicted_rgb = images_predicted[..., :3]\n",
    "        loss_rgb = ((predicted_rgb - target_rgb[j]) ** 2).mean()\n",
    "        loss[\"rgb\"] += loss_rgb / num_views_per_iteration\n",
    "    \n",
    "    # Weighted sum of the losses\n",
    "    sum_loss = torch.tensor(0.0, device=device)\n",
    "    for k, l in loss.items():\n",
    "        sum_loss += l * losses[k][\"weight\"]\n",
    "        losses[k][\"values\"].append(float(l.detach().cpu()))\n",
    "    \n",
    "    # Print the losses\n",
    "    loop.set_description(\"total_loss = %.6f\" % sum_loss)\n",
    "    \n",
    "    # Plot mesh\n",
    "    if i % plot_period == 0:\n",
    "        visualize_prediction(new_src_mesh, renderer=renderer_textured, title=\"iter: %d\" % i, silhouette=False)\n",
    "        \n",
    "    # Optimization step\n",
    "    sum_loss.backward()\n",
    "    optimizer.step()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "2qTcHO4rR1Rs",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "visualize_prediction(new_src_mesh, renderer=renderer_textured, silhouette=False)\n",
    "plot_losses(losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "akBOm_xcNUms"
   },
   "source": [
    "Save the final predicted mesh:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "dXoIsGyhxRyK"
   },
   "source": [
    "## 4. Save the final predicted mesh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "OQGhV-psKna8"
   },
   "outputs": [],
   "source": [
    "# Fetch the verts and faces of the final predicted mesh\n",
    "final_verts, final_faces = new_src_mesh.get_mesh_verts_faces(0)\n",
    "\n",
    "# Undo the normalization: rescale and recenter to the original target size\n",
    "final_verts = final_verts * scale + center\n",
    "\n",
    "# Store the predicted mesh using save_obj\n",
    "final_obj = os.path.join('./', 'final_model.obj')\n",
    "save_obj(final_obj, final_verts, final_faces)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MtKYp0B6R1Ru"
   },
   "source": [
    "## 5. Conclusion\n",
    "In this tutorial, we learned how to load a textured mesh from an obj file and create a synthetic dataset by rendering the mesh from multiple viewpoints.  We showed how to set up an optimization loop to fit a mesh to the observed dataset images based on a rendered silhouette loss.  We then augmented this optimization loop with an additional loss based on rendered RGB images, which allowed us to predict both a mesh and its texture."
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "anp_metadata": {
   "path": "fbsource/fbcode/vision/fair/pytorch3d/docs/tutorials/fit_textured_mesh.ipynb"
  },
  "bento_stylesheets": {
   "bento/extensions/flow/main.css": true,
   "bento/extensions/kernel_selector/main.css": true,
   "bento/extensions/kernel_ui/main.css": true,
   "bento/extensions/new_kernel/main.css": true,
   "bento/extensions/system_usage/main.css": true,
   "bento/extensions/theme/main.css": true
  },
  "colab": {
   "name": "fit_textured_mesh.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "disseminate_notebook_info": {
   "backup_notebook_id": "781874812352022"
  },
  "kernelspec": {
   "display_name": "intro_to_cv",
   "language": "python",
   "name": "bento_kernel_intro_to_cv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5+"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}