Spaces:
Runtime error
Runtime error
File size: 13,534 Bytes
753fd9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 |
# part of the code from
# https://github.com/benjiebob/SMALify/blob/master/smal_fitter/p3d_renderer.py
import torch
import torch.nn.functional as F
from scipy.io import loadmat
import numpy as np
# import config
import pytorch3d
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
PerspectiveCameras, look_at_view_transform, look_at_rotation,
RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
PointLights, HardPhongShader, SoftSilhouetteShader, Materials, Textures,
DirectionalLights
)
from pytorch3d.renderer import TexturesVertex, SoftPhongShader
from pytorch3d.io import load_objs_as_meshes
# RGB triplets used as uniform per-vertex mesh colours. SilhRenderer registers
# them as float buffers and multiplies them into a ones-tensor shaped like the
# vertices to build the vertex textures.
# NOTE(review): the values look like 0-255 RGB and are used unscaled -- confirm
# whether callers normalise the rendered images afterwards.
MESH_COLOR_0 = [0, 172, 223]  # referred to as "blue" in get_visualization_nograd
MESH_COLOR_1 = [172, 223, 0]
'''
Explanation of the shift between projection results from opendr and pytorch3d:
(0, 0, ?) will be projected to 127.5 (pytorch3d) instead of 128 (opendr)
imagine you have an image of size 4:
middle of the first pixel is 0
middle of the last pixel is 3
=> middle of the image would be 1.5 and not 2!
so in order to go from pytorch3d predictions to opendr we would calculate: p_odr = p_p3d * (128/127.5)
To reproject points (p3d) by hand according to this pytorch3d renderer we would do the following steps:
1.) build camera matrix
K = np.array([[flength, 0, c_x],
[0, flength, c_y],
[0, 0, 1]], np.float64)
2.) we don't need to add extrinsics, as the mesh comes with translation (which is
added within smal_pytorch). all 3d points are already in the camera coordinate system.
-> projection reduces to p2d_proj = K*p3d
3.) convert to pytorch3d conventions (0 in the middle of the first pixel)
p2d_proj_pytorch3d = p2d_proj / image_size * (image_size-1.)
See project_points_p3d in renderer.py for an example of the steps described above;
note that it uses the same focal length for the whole batch.
'''
class SilhRenderer(torch.nn.Module):
    """Silhouette / keypoint / visualization renderer built on PyTorch3D.

    The module keeps fixed camera extrinsics (``R``, ``T``), the image size and
    the principal point as buffers and builds a batch of ``PerspectiveCameras``
    from per-sample focal lengths on every call. It can

    * render soft silhouettes (``forward``),
    * project 3D points into the image (``project_points``),
    * render a shaded colour image without gradients
      (``get_visualization_nograd``),
    * determine which vertices are hit by the rasterizer
      (``calculate_vertex_visibility``).

    Only PyTorch3D versions '0.2.5' and '0.6.1' are handled; any other version
    raises ``ValueError`` as soon as cameras are built or points are projected.
    """

    # Message used whenever an unhandled PyTorch3D version is detected.
    _VERSION_MSG = ('this part depends on the version of pytorch3d, '
                    'code was developed with 0.2.5')

    def __init__(self, image_size, adapt_R_wldo=False):
        """Set up constant camera parameters and rasterization settings.

        Args:
            image_size: one integer; the rendered images are square.
            adapt_R_wldo: if True only the x axis of the identity rotation is
                negated, otherwise both the x and the y axis are negated.
        """
        super().__init__()
        # see: https://pytorch3d.org/files/fit_textured_mesh.py, line 315
        # -----
        # mesh colours live in buffers so they follow the module across devices
        self.register_buffer('mesh_color_0', torch.FloatTensor(MESH_COLOR_0))
        self.register_buffer('mesh_color_1', torch.FloatTensor(MESH_COLOR_1))
        # prepare extrinsics, which in our case don't change
        R = torch.Tensor(np.eye(3)).float()[None, :, :]
        T = torch.Tensor(np.zeros((1, 3))).float()
        if adapt_R_wldo:
            R[0, 0, 0] = -1
        else:  # used for all my own experiments
            R[0, 0, 0] = -1
            R[0, 1, 1] = -1
        self.register_buffer('R', R)
        self.register_buffer('T', T)
        # prepare that part of the intrinsics which does not change either
        self.img_size_scalar = image_size
        self.register_buffer(
            'image_size',
            torch.Tensor([image_size, image_size]).float()[None, :].float())
        self.register_buffer(
            'principal_point',
            torch.Tensor([image_size / 2., image_size / 2.]).float()[None, :].float())
        # Rasterization settings for differentiable rendering, where the
        # blur_radius initialization is based on Liu et al, 'Soft Rasterizer:
        # A Differentiable Renderer for Image-based 3D Reasoning', ICCV 2019
        self.blend_params = BlendParams(sigma=1e-4, gamma=1e-4)
        self.raster_settings_soft = RasterizationSettings(
            image_size=image_size,  # 128
            blur_radius=np.log(1. / 1e-4 - 1.) * self.blend_params.sigma,
            faces_per_pixel=100)  # 50,
        # same idea with a larger sigma, for body part segmentation
        self.blend_params_parts = BlendParams(sigma=2 * 1e-4, gamma=1e-4)
        self.raster_settings_soft_parts = RasterizationSettings(
            image_size=image_size,  # 128
            blur_radius=np.log(1. / 1e-4 - 1.) * self.blend_params_parts.sigma,
            faces_per_pixel=60)  # 50,
        # settings for visualization renderer (hard rasterization, one face per pixel)
        self.raster_settings_vis = RasterizationSettings(
            image_size=image_size,
            blur_radius=0.0,
            faces_per_pixel=1)

    @classmethod
    def _unsupported_version(cls, version):
        """Build the ValueError raised for an unhandled PyTorch3D version."""
        return ValueError('{} (found {})'.format(cls._VERSION_MSG, version))

    @staticmethod
    def _make_mesh(vertices, faces, mesh_color):
        """Build a ``Meshes`` batch whose vertices all carry ``mesh_color``."""
        tex = torch.ones_like(vertices) * mesh_color  # (bs, V, 3)
        textures = Textures(verts_rgb=tex)
        return Meshes(verts=vertices, faces=faces, textures=textures)

    @staticmethod
    def _visible_vertex_mask(pix_to_face, packed_faces, num_verts):
        """Return a float mask of length ``num_verts`` with 1.0 at visible vertices.

        Args:
            pix_to_face: integer tensor of packed face indices per pixel, with
                negative entries marking pixels/slots without a face.
            packed_faces: (F, 3) integer tensor of packed vertex indices.
            num_verts: total number of packed vertices.
        """
        visible_faces = pix_to_face.unique()
        # Drop the negative background padding: indexing with -1 would silently
        # select the LAST face and mark its vertices as visible.
        visible_faces = visible_faces[visible_faces >= 0]
        # Allocate on the faces' device so the indexed assignment below also
        # works when the mesh lives on the GPU.
        mask = torch.zeros(num_verts, device=packed_faces.device)
        visible_verts = torch.unique(packed_faces[visible_faces])
        mask[visible_verts] = 1.0
        return mask

    def _get_cam(self, focal_lengths):
        """Create a batch of ``PerspectiveCameras`` for the given focal lengths.

        Args:
            focal_lengths: torch.Size([bs, 1]); the same value is used for x and y.

        Raises:
            ValueError: if the installed PyTorch3D version is neither '0.2.5'
                nor '0.6.1'.
        """
        version = pytorch3d.__version__
        if version not in ('0.2.5', '0.6.1'):
            raise self._unsupported_version(version)
        bs = focal_lengths.shape[0]
        # 0.6.1 additionally needs in_ndc=False; everything else is shared.
        extra = {'in_ndc': False} if version == '0.6.1' else {}
        return PerspectiveCameras(
            device=focal_lengths.device,
            focal_length=focal_lengths.repeat((1, 2)),
            principal_point=self.principal_point.repeat((bs, 1)),
            R=self.R.repeat((bs, 1, 1)),
            T=self.T.repeat((bs, 1)),
            image_size=self.image_size.repeat((bs, 1)),
            **extra)

    def _get_visualization_from_mesh(self, mesh, cameras, lights=None):
        """Render a colour image of ``mesh`` without tracking gradients.

        Args:
            mesh: PyTorch3D ``Meshes`` with textures.
            cameras: PyTorch3D cameras used for rasterization and shading.
            lights: optional PyTorch3D lights; defaults to a point light at
                (0, 0, 3).

        Returns:
            Rendered image with channels moved to dim 1 and only the first
            three channels kept.
        """
        with torch.no_grad():
            device = mesh.device
            if lights is None:
                lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]])
            vis_renderer = MeshRenderer(
                rasterizer=MeshRasterizer(
                    cameras=cameras,
                    raster_settings=self.raster_settings_vis),
                shader=HardPhongShader(
                    device=device,
                    cameras=cameras,
                    lights=lights))
            # (bs, H, W, C) -> (bs, C, H, W), drop everything after channel 3
            visualization = vis_renderer(mesh).permute(0, 3, 1, 2)[:, :3, :, :]
        return visualization

    def calculate_vertex_visibility(self, vertices, faces, focal_lengths, soft=False):
        """Use the rasterizer to find out which vertices are visible.

        see: https://github.com/facebookresearch/pytorch3d/issues/126

        Args:
            vertices: torch.Size([bs, V, 3])
            faces: torch.Size([bs, F, 3]), int
            focal_lengths: torch.Size([bs, 1])
            soft: rasterize with the soft silhouette settings instead of the
                hard visualization settings.

        Returns:
            Tuple ``(pix_to_face, visibility)``: the rasterizer's pixel-to-face
            index map and a (bs, V) float tensor holding 1.0 for every vertex
            that belongs to at least one rasterized face and 0.0 otherwise.
        """
        mesh = self._make_mesh(vertices, faces, self.mesh_color_0)
        cameras = self._get_cam(focal_lengths)
        raster_settings = (self.raster_settings_soft if soft
                           else self.raster_settings_vis)
        rasterizer = MeshRasterizer(cameras=cameras,
                                    raster_settings=raster_settings)
        # Get the output from rasterization
        fragments = rasterizer(mesh)
        # pix_to_face is of shape (N, H, W, faces_per_pixel)
        pix_to_face = fragments.pix_to_face
        # (F, 3) where F is the total number of faces across all meshes in the batch
        packed_faces = mesh.faces_packed()
        # (V, 3) where V is the total number of verts across all meshes in the batch
        packed_verts = mesh.verts_packed()
        vertex_visibility_map = self._visible_vertex_mask(
            pix_to_face, packed_faces, packed_verts.shape[0])
        # since all meshes have the same amount of vertices, we can reshape the result
        bs = vertices.shape[0]
        vertex_visibility_map_resh = vertex_visibility_map.reshape((bs, -1))
        return pix_to_face, vertex_visibility_map_resh

    def get_torch_meshes(self, vertices, faces, color=0):
        """Create a textured PyTorch3D mesh batch.

        Args:
            vertices: torch.Size([bs, V, 3])
            faces: torch.Size([bs, F, 3]), int
            color: 0 selects ``mesh_color_0``, anything else ``mesh_color_1``.
        """
        mesh_color = self.mesh_color_0 if color == 0 else self.mesh_color_1
        return self._make_mesh(vertices, faces, mesh_color)

    def get_visualization_nograd(self, vertices, faces, focal_lengths, color=0):
        """Render a shaded colour image of the mesh (no gradients).

        Args:
            vertices: torch.Size([bs, 3889, 3])
            faces: torch.Size([bs, 7774, 3]), int
            focal_lengths: torch.Size([bs, 1])
            color: 0 -> ``mesh_color_0`` (blue), 1 -> ``mesh_color_1``,
                3 -> (166, 173, 164), any other value (including 2) ->
                (240, 250, 240) (white).

        Returns:
            The image produced by ``_get_visualization_from_mesh`` using a
            directional light.
        """
        device = vertices.device
        cameras = self._get_cam(focal_lengths)
        if color == 0:
            mesh_color = self.mesh_color_0  # blue
        elif color == 1:
            mesh_color = self.mesh_color_1
        elif color == 3:
            mesh_color = torch.FloatTensor([166, 173, 164]).to(device)
        else:  # color == 2 and every unknown value: white
            mesh_color = torch.FloatTensor([240, 250, 240]).to(device)
        mesh = self._make_mesh(vertices, faces, mesh_color)
        lights = DirectionalLights(device=device, direction=[[0.0, -5.0, -10.0]])
        return self._get_visualization_from_mesh(mesh, cameras, lights=lights)

    def project_points(self, points, focal_lengths=None, cameras=None):
        """Project 3D points to screen coordinates.

        Args:
            points: torch.Size([bs, n_points, 3])
            focal_lengths: torch.Size([bs, 1]); only used when ``cameras`` is None.
            cameras: PyTorch3D cameras, for example ``PerspectiveCameras()``.

        Returns:
            torch.Size([bs, n_points, 2]) with the first two screen coordinates
            in the order returned by ``transform_points_screen``.

        Raises:
            ValueError: if neither ``focal_lengths`` nor ``cameras`` is given,
                or if the installed PyTorch3D version is not handled.
        """
        if cameras is None and focal_lengths is None:
            raise ValueError('either focal_lengths or cameras is needed')
        if cameras is None:
            cameras = self._get_cam(focal_lengths)
        version = pytorch3d.__version__
        if version == '0.2.5':
            # used in the original virtual environment (for cvpr BARC submission)
            screen_size = self.image_size.repeat((points.shape[0], 1))
            screen_points = cameras.transform_points_screen(points, screen_size)
        elif version == '0.6.1':
            screen_points = cameras.transform_points_screen(points)
        else:
            raise self._unsupported_version(version)
        proj_points_orig = screen_points[:, :, [1, 0]]
        # flip, otherwise the 1st and 2nd row are exchanged compared to the
        # ground truth (selecting [1, 0] and flipping again restores the
        # original order of the first two coordinates)
        proj_points = torch.flip(proj_points_orig, [2])
        return proj_points

    def forward(self, vertices, points, faces, focal_lengths, color=None):
        """Render soft silhouettes and optionally project points / visualize.

        Important: results are around 0.5 pixels off compared to chumpy!
        (have a look at renderer.py for an explanation)

        Args:
            vertices: torch.Size([bs, 3889, 3])
            points: torch.Size([bs, n_points, 3]) (or None)
            faces: torch.Size([bs, 7774, 3]), int
            focal_lengths: torch.Size([bs, 1])
            color: if None no visualization is rendered, else it should be
                either 0 or 1 and selects the mesh colour.

        Returns:
            ``(silh_images, proj_points)`` if ``color`` is None, otherwise
            ``(silh_images, proj_points, visualization)``. ``silh_images`` is
            the last channel of the silhouette rendering with a channel dim
            inserted at position 1; ``proj_points`` is None when ``points`` is
            None.
        """
        cameras = self._get_cam(focal_lengths)
        if color is None or color == 0:
            mesh_color = self.mesh_color_0
        else:
            mesh_color = self.mesh_color_1
        mesh = self._make_mesh(vertices, faces, mesh_color)
        # silhouette renderer
        renderer_silh = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=cameras,
                raster_settings=self.raster_settings_soft),
            shader=SoftSilhouetteShader(blend_params=self.blend_params))
        # the silhouette is stored in the last channel
        silh_images = renderer_silh(mesh)[..., -1].unsqueeze(1)
        # project points
        if points is None:
            proj_points = None
        else:
            proj_points = self.project_points(points=points, cameras=cameras)
        if color is None:
            return silh_images, proj_points
        # color renderer for visualization (no gradients)
        visualization = self._get_visualization_from_mesh(mesh, cameras)
        return silh_images, proj_points, visualization
|