Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import imageio | |
import numpy as np | |
import torch | |
from tqdm import tqdm | |
from pytorch3d.renderer import ( | |
PerspectiveCameras, | |
TexturesVertex, | |
PointLights, | |
Materials, | |
RasterizationSettings, | |
MeshRenderer, | |
MeshRasterizer, | |
SoftPhongShader, | |
) | |
from pytorch3d.renderer.mesh.shader import ShaderBase | |
from pytorch3d.structures import Meshes | |
class NormalShader(ShaderBase): | |
def __init__(self, device = "cpu", **kwargs): | |
super().__init__(device=device, **kwargs) | |
def forward(self, fragments, meshes, **kwargs): | |
blend_params = kwargs.get("blend_params", self.blend_params) | |
texels = fragments.bary_coords.clone() | |
texels = texels.permute(0, 3, 1, 2, 4) | |
texels = texels * 2 - 1 # 将 bary_coords 映射到 [-1, 1] | |
# 获取法线 | |
verts_normals = meshes.verts_normals_packed() | |
faces_normals = verts_normals[meshes.faces_packed()] | |
bary_coords = fragments.bary_coords | |
pixel_normals = (bary_coords[..., None] * faces_normals[fragments.pix_to_face]).sum(dim=-2) | |
pixel_normals = pixel_normals / pixel_normals.norm(dim=-1, keepdim=True) | |
# 将法线映射到颜色空间 | |
# colors = (pixel_normals + 1) / 2 # 将法线映射到 [0, 1] | |
colors = torch.clamp(pixel_normals, -1, 1) | |
print(colors.shape) | |
mask = (fragments.pix_to_face > 0).float() | |
colors = torch.cat([colors, mask.unsqueeze(-1)], dim=-1) | |
# colors[fragments.pix_to_face < 0] = 0 | |
# 混合颜色 | |
# images = self.blend(texels, colors, fragments, blend_params) | |
return colors | |
def overlay_image_onto_background(image, mask, bbox, background): | |
if isinstance(image, torch.Tensor): | |
image = image.detach().cpu().numpy() | |
if isinstance(mask, torch.Tensor): | |
mask = mask.detach().cpu().numpy() | |
out_image = background.copy() | |
bbox = bbox[0].int().cpu().numpy().copy() | |
roi_image = out_image[bbox[1]:bbox[3], bbox[0]:bbox[2]] | |
if len(roi_image) < 1 or len(roi_image[1]) < 1: | |
return out_image | |
try: | |
roi_image[mask] = image[mask] | |
except Exception as e: | |
raise e | |
out_image[bbox[1]:bbox[3], bbox[0]:bbox[2]] = roi_image | |
return out_image | |
def update_intrinsics_from_bbox(K_org, bbox): | |
''' | |
update intrinsics for cropped images | |
''' | |
device, dtype = K_org.device, K_org.dtype | |
K = torch.zeros((K_org.shape[0], 4, 4) | |
).to(device=device, dtype=dtype) | |
K[:, :3, :3] = K_org.clone() | |
K[:, 2, 2] = 0 | |
K[:, 2, -1] = 1 | |
K[:, -1, 2] = 1 | |
image_sizes = [] | |
for idx, bbox in enumerate(bbox): | |
left, upper, right, lower = bbox | |
cx, cy = K[idx, 0, 2], K[idx, 1, 2] | |
new_cx = cx - left | |
new_cy = cy - upper | |
new_height = max(lower - upper, 1) | |
new_width = max(right - left, 1) | |
new_cx = new_width - new_cx | |
new_cy = new_height - new_cy | |
K[idx, 0, 2] = new_cx | |
K[idx, 1, 2] = new_cy | |
image_sizes.append((int(new_height), int(new_width))) | |
return K, image_sizes | |
def perspective_projection(x3d, K, R=None, T=None): | |
if R != None: | |
x3d = torch.matmul(R, x3d.transpose(1, 2)).transpose(1, 2) | |
if T != None: | |
x3d = x3d + T.transpose(1, 2) | |
x2d = torch.div(x3d, x3d[..., 2:]) | |
x2d = torch.matmul(K, x2d.transpose(-1, -2)).transpose(-1, -2)[..., :2] | |
return x2d | |
def compute_bbox_from_points(X, img_w, img_h, scaleFactor=1.2): | |
left = torch.clamp(X.min(1)[0][:, 0], min=0, max=img_w) | |
right = torch.clamp(X.max(1)[0][:, 0], min=0, max=img_w) | |
top = torch.clamp(X.min(1)[0][:, 1], min=0, max=img_h) | |
bottom = torch.clamp(X.max(1)[0][:, 1], min=0, max=img_h) | |
cx = (left + right) / 2 | |
cy = (top + bottom) / 2 | |
width = (right - left) | |
height = (bottom - top) | |
new_left = torch.clamp(cx - width/2 * scaleFactor, min=0, max=img_w-1) | |
new_right = torch.clamp(cx + width/2 * scaleFactor, min=1, max=img_w) | |
new_top = torch.clamp(cy - height / 2 * scaleFactor, min=0, max=img_h-1) | |
new_bottom = torch.clamp(cy + height / 2 * scaleFactor, min=1, max=img_h) | |
bbox = torch.stack((new_left.detach(), new_top.detach(), | |
new_right.detach(), new_bottom.detach())).int().float().T | |
return bbox | |
class Renderer(): | |
def __init__(self, width, height, K, device, faces=None): | |
self.width = width | |
self.height = height | |
self.K = K | |
self.device = device | |
if faces is not None: | |
self.faces = torch.from_numpy( | |
(faces).astype('int') | |
).unsqueeze(0).to(self.device) | |
self.initialize_camera_params() | |
self.lights = PointLights(device=device, location=[[0.0, 0.0, -10.0]]) | |
self.create_renderer() | |
def create_camera(self, R=None, T=None): | |
if R is not None: | |
self.R = R.clone().view(1, 3, 3).to(self.device) | |
if T is not None: | |
self.T = T.clone().view(1, 3).to(self.device) | |
return PerspectiveCameras( | |
device=self.device, | |
R=self.R.mT, | |
T=self.T, | |
K=self.K_full, | |
image_size=self.image_sizes, | |
in_ndc=False) | |
def create_renderer(self): | |
self.renderer = MeshRenderer( | |
rasterizer=MeshRasterizer( | |
raster_settings=RasterizationSettings( | |
image_size=self.image_sizes[0], | |
blur_radius=1e-5,), | |
), | |
shader=SoftPhongShader( | |
device=self.device, | |
lights=self.lights, | |
) | |
) | |
def create_normal_renderer(self): | |
normal_renderer = MeshRenderer( | |
rasterizer=MeshRasterizer( | |
cameras=self.cameras, | |
raster_settings=RasterizationSettings( | |
image_size=self.image_sizes[0], | |
), | |
), | |
shader=NormalShader(device=self.device), | |
) | |
return normal_renderer | |
def initialize_camera_params(self): | |
"""Hard coding for camera parameters | |
TODO: Do some soft coding""" | |
# Extrinsics | |
self.R = torch.diag( | |
torch.tensor([1, 1, 1]) | |
).float().to(self.device).unsqueeze(0) | |
self.T = torch.tensor( | |
[0, 0, 0] | |
).unsqueeze(0).float().to(self.device) | |
# Intrinsics | |
self.K = self.K.unsqueeze(0).float().to(self.device) | |
self.bboxes = torch.tensor([[0, 0, self.width, self.height]]).float() | |
self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, self.bboxes) | |
self.cameras = self.create_camera() | |
def render_normal(self, vertices): | |
vertices = vertices.unsqueeze(0) | |
mesh = Meshes(verts=vertices, faces=self.faces) | |
normal_renderer = self.create_normal_renderer() | |
results = normal_renderer(mesh) | |
results = torch.flip(results, [1, 2]) | |
return results | |
def render_mesh(self, vertices, background, colors=[0.8, 0.8, 0.8]): | |
self.update_bbox(vertices[::50], scale=1.2) | |
vertices = vertices.unsqueeze(0) | |
if colors[0] > 1: colors = [c / 255. for c in colors] | |
verts_features = torch.tensor(colors).reshape(1, 1, 3).to(device=vertices.device, dtype=vertices.dtype) | |
verts_features = verts_features.repeat(1, vertices.shape[1], 1) | |
textures = TexturesVertex(verts_features=verts_features) | |
mesh = Meshes(verts=vertices, | |
faces=self.faces, | |
textures=textures,) | |
materials = Materials( | |
device=self.device, | |
specular_color=(colors, ), | |
shininess=0 | |
) | |
results = torch.flip( | |
self.renderer(mesh, materials=materials, cameras=self.cameras, lights=self.lights), | |
[1, 2] | |
) | |
image = results[0, ..., :3] * 255 | |
mask = results[0, ..., -1] > 1e-3 | |
image = overlay_image_onto_background(image, mask, self.bboxes, background.copy()) | |
self.reset_bbox() | |
return image | |
def update_bbox(self, x3d, scale=2.0, mask=None): | |
""" Update bbox of cameras from the given 3d points | |
x3d: input 3D keypoints (or vertices), (num_frames, num_points, 3) | |
""" | |
if x3d.size(-1) != 3: | |
x2d = x3d.unsqueeze(0) | |
else: | |
x2d = perspective_projection(x3d.unsqueeze(0), self.K, self.R, self.T.reshape(1, 3, 1)) | |
if mask is not None: | |
x2d = x2d[:, ~mask] | |
bbox = compute_bbox_from_points(x2d, self.width, self.height, scale) | |
self.bboxes = bbox | |
self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox) | |
self.cameras = self.create_camera() | |
self.create_renderer() | |
def reset_bbox(self,): | |
bbox = torch.zeros((1, 4)).float().to(self.device) | |
bbox[0, 2] = self.width | |
bbox[0, 3] = self.height | |
self.bboxes = bbox | |
self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox) | |
self.cameras = self.create_camera() | |
self.create_renderer() | |
class RendererUtil(): | |
def __init__(self, K, w, h, device, faces, keep_origin=True): | |
self.keep_origin = keep_origin | |
self.default_R = torch.eye(3) | |
self.default_T = torch.zeros(3) | |
self.device = device | |
self.renderer = Renderer(w, h, K, device, faces) | |
def set_extrinsic(self, R, T): | |
self.default_R = R | |
self.default_T = T | |
def render_normal(self, verts_list): | |
if not len(verts_list) == 1: | |
return None | |
self.renderer.create_camera(self.default_R, self.default_T) | |
normal_map = self.renderer.render_normal(verts_list[0]) | |
return normal_map[0, :, :, 0] | |
def render_frame(self, humans, pred_rend_array, verts_list=None, color_list=None): | |
if not isinstance(pred_rend_array, np.ndarray): | |
pred_rend_array = np.asarray(pred_rend_array) | |
self.renderer.create_camera(self.default_R, self.default_T) | |
_img = pred_rend_array | |
if humans is not None: | |
for human in humans: | |
_img = self.renderer.render_mesh(human['v3d'].to(self.device), _img) | |
else: | |
for i, verts in enumerate(verts_list): | |
if color_list is None: | |
_img = self.renderer.render_mesh(verts.to(self.device), _img) | |
else: | |
_img = self.renderer.render_mesh(verts.to(self.device), _img, color_list[i]) | |
if self.keep_origin: | |
_img = np.concatenate([np.asarray(pred_rend_array), _img],1).astype(np.uint8) | |
return _img | |
def render_video(self, results, pil_bis_frames, fps, out_path): | |
writer = imageio.get_writer( | |
out_path, | |
fps=fps, mode='I', format='FFMPEG', macro_block_size=1 | |
) | |
for i, humans in enumerate(tqdm(results)): | |
pred_rend_array = pil_bis_frames[i] | |
_img = self.render_frame( humans, pred_rend_array) | |
try: | |
writer.append_data(_img) | |
except: | |
print('Error in writing video') | |
print(type(_img)) | |
writer.close() | |
def render_frame(renderer, humans, pred_rend_array, default_R, default_T, device, keep_origin=True): | |
if not isinstance(pred_rend_array, np.ndarray): | |
pred_rend_array = np.asarray(pred_rend_array) | |
renderer.create_camera(default_R, default_T) | |
_img = pred_rend_array | |
if humans is None: | |
humans = [] | |
if isinstance(humans, dict): | |
humans = [humans] | |
for human in humans: | |
if isinstance(human, dict): | |
v3d = human['v3d'].to(device) | |
else: | |
v3d = human | |
_img = renderer.render_mesh(v3d, _img) | |
if keep_origin: | |
_img = np.concatenate([np.asarray(pred_rend_array), _img],1).astype(np.uint8) | |
return _img | |
def render_video(results, faces, K, pil_bis_frames, fps, out_path, device, keep_origin=True): | |
# results [F, N, ...] | |
if isinstance(pil_bis_frames[0], np.ndarray): | |
height, width, _ = pil_bis_frames[0].shape | |
else: | |
shape = pil_bis_frames[0].size | |
width, height = shape[1], shape[0] | |
renderer = Renderer(width, height, K[0], device, faces) | |
# build default camera | |
default_R, default_T = torch.eye(3), torch.zeros(3) | |
writer = imageio.get_writer( | |
out_path, | |
fps=fps, mode='I', format='FFMPEG', macro_block_size=1 | |
) | |
for i, humans in enumerate(tqdm(results)): | |
pred_rend_array = pil_bis_frames[i] | |
_img = render_frame(renderer, humans, pred_rend_array, default_R, default_T, device, keep_origin) | |
try: | |
writer.append_data(_img) | |
except: | |
print('Error in writing video') | |
print(type(_img)) | |
writer.close() | |