import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import wandb
from PIL import Image

from unik3d.utils.distributed import get_rank
from unik3d.utils.misc import ssi_helper


def colorize(
    value: np.ndarray, vmin: float = None, vmax: float = None, cmap: str = "magma_r"
):
    # if already RGB, do nothing
    if value.ndim > 2:
        if value.shape[-1] > 1:
            return value
        value = value[..., 0]
    invalid_mask = value < 0.0001

    # normalize to [0, 1]; guard against a flat range to avoid division by zero
    vmin = value.min() if vmin is None else vmin
    vmax = value.max() if vmax is None else vmax
    value = (value - vmin) / max(vmax - vmin, 1e-8)

    # apply colormap and zero-out invalid pixels
    cmapper = plt.get_cmap(cmap)
    value = cmapper(value, bytes=True)  # (n, m, 4)
    value[invalid_mask] = 0
    img = value[..., :3]
    return img


def image_grid(imgs: list[np.ndarray], rows: int, cols: int) -> np.ndarray:
    if not len(imgs):
        return None
    assert len(imgs) == rows * cols
    h, w = imgs[0].shape[:2]
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(
            Image.fromarray(img.astype(np.uint8)).resize(
                (w, h), resample=Image.BILINEAR
            ),
            box=(i % cols * w, i // cols * h),
        )

    return np.array(grid)


def get_pointcloud_from_rgbd(
    image: np.ndarray,
    depth: np.ndarray,
    mask: np.ndarray,
    intrinsic_matrix: np.ndarray,
    extrinsic_matrix: np.ndarray = None,
):
    depth = np.array(depth).squeeze()
    mask = np.array(mask).squeeze()
    # Mask out invalid pixels of the depth array
    masked_depth = np.ma.masked_where(~mask.astype(bool), depth)
    # masked_depth = np.ma.masked_greater(masked_depth, 8000)

    # Create pixel index arrays (u = column, v = row)
    idxs = np.indices(masked_depth.shape)
    u_idxs = idxs[1]
    v_idxs = idxs[0]

    # Keep only non-masked depth values and their indices
    z = masked_depth[~masked_depth.mask]
    compressed_u_idxs = u_idxs[~masked_depth.mask]
    compressed_v_idxs = v_idxs[~masked_depth.mask]
    image = np.stack(
        [image[..., i][~masked_depth.mask] for i in range(image.shape[-1])], axis=-1
    )

    # Back-project to camera coordinates with the pinhole model,
    # applying vectorized math to the compressed arrays
    cx = intrinsic_matrix[0, 2]
    fx = intrinsic_matrix[0, 0]
    x = (compressed_u_idxs - cx) * z / fx
    cy = intrinsic_matrix[1, 2]
    fy = intrinsic_matrix[1, 1]
    # Flip y as we want +y pointing up, not down
    y = -((compressed_v_idxs - cy) * z / fy)

    # # Apply camera_matrix to the pointcloud to get it in world coords
    # if extrinsic_matrix is not None:
    #     # Calculate camera pose from the extrinsic matrix
    #     camera_matrix = np.linalg.inv(extrinsic_matrix)
    #     # Create homogeneous vectors by adding a 4th entry of 1.
    #     # At the same time flip z, as in eye space the camera looks down the -z axis.
    #     w = np.ones(z.shape)
    #     x_y_z_eye_hom = np.vstack((x, y, -z, w))
    #     # Transform the points from eye space to world space
    #     x_y_z_world = np.dot(camera_matrix, x_y_z_eye_hom)[:3]
    #     return x_y_z_world.T
    # else:
    x_y_z_local = np.stack((x, y, z), axis=-1)
    return np.concatenate([x_y_z_local, image], axis=-1)


def save_file_ply(xyz, rgb, pc_file):
    # Colors may arrive normalized to [0, 1]; convert to 8-bit if so
    if rgb.max() < 1.001:
        rgb = rgb * 255.0
    rgb = rgb.astype(np.uint8)
    with open(pc_file, "w") as f:
        # headers
        f.writelines(
            [
                "ply\n",
                "format ascii 1.0\n",
                "element vertex {}\n".format(xyz.shape[0]),
                "property float x\n",
                "property float y\n",
                "property float z\n",
                "property uchar red\n",
                "property uchar green\n",
                "property uchar blue\n",
                "end_header\n",
            ]
        )
        for i in range(xyz.shape[0]):
            str_v = "{:10.6f} {:10.6f} {:10.6f} {:d} {:d} {:d}\n".format(
                xyz[i, 0], xyz[i, 1], xyz[i, 2], rgb[i, 0], rgb[i, 1], rgb[i, 2]
            )
            f.write(str_v)
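
# --- Hypothetical usage sketch (not part of the original module) -------------
# A minimal example of how the helpers above could be combined to export a
# colored point cloud from an RGB-D frame. The function name, the synthetic
# frame, and the pinhole intrinsics below are illustrative assumptions only.
def _example_pointcloud_export(out_file: str = "example.ply") -> None:
    h, w = 48, 64
    # Synthetic data: mid-gray image, constant 2 m depth, one invalid pixel
    image = np.full((h, w, 3), 128, dtype=np.uint8)
    depth = np.full((h, w), 2.0, dtype=np.float32)
    depth[0, 0] = 0.0
    mask = depth > 0
    # Assumed intrinsics: fx = fy = 50 px, principal point at the image center
    K = np.array([[50.0, 0.0, w / 2], [0.0, 50.0, h / 2], [0.0, 0.0, 1.0]])
    points = get_pointcloud_from_rgbd(image, depth, mask, K)  # (N, 6): xyz + rgb
    save_file_ply(points[:, :3], points[:, 3:], out_file)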
# FIXME: convoluted helper, candidate for refactoring
def train_artifacts(rgbs, gts, preds, infos={}):
    # FIXME/TODO: interpolating GTs to the RGB shape distorts them
    shape = rgbs[0].shape[-2:]
    gts = F.interpolate(gts, shape, mode="nearest-exact")

    rgbs = [
        (127.5 * (rgb + 1))
        .clip(0, 255)
        .to(torch.uint8)
        .cpu()
        .detach()
        .permute(1, 2, 0)
        .numpy()
        for rgb in rgbs
    ]

    new_gts, new_preds = [], []
    num_additional, additionals = 0, []
    if len(gts) > 0:
        for i, gt in enumerate(gts):
            # scale, shift = ssi_helper(
            #     gts[i][gts[i] > 0].cpu().detach(), preds[i][gts[i] > 0].cpu().detach()
            # )
            scale, shift = 1, 0
            up = torch.quantile(
                torch.log(1 + gts[i][gts[i] > 0]).float().cpu().detach(), 0.98
            ).item()
            down = torch.quantile(
                torch.log(1 + gts[i][gts[i] > 0]).float().cpu().detach(), 0.02
            ).item()
            gt = gts[i].cpu().detach().squeeze().numpy()
            pred = (preds[i].cpu().detach() * scale + shift).squeeze().numpy()
            new_gts.append(colorize(np.log(1.0 + gt), vmin=down, vmax=up))
            new_preds.append(colorize(np.log(1.0 + pred), vmin=down, vmax=up))
        gts, preds = new_gts, new_preds
    else:
        preds = [
            colorize(pred.cpu().detach().squeeze().numpy(), 0.0, 80.0)
            for pred in preds
        ]

    for name, info in infos.items():
        num_additional += 1
        if info.shape[1] == 3:
            # RGB-like tensors in [-1, 1] -> uint8 images
            additionals.extend(
                [
                    (127.5 * (x + 1))
                    .clip(0, 255)
                    .to(torch.uint8)
                    .cpu()
                    .detach()
                    .permute(1, 2, 0)
                    .numpy()
                    for x in info
                ]
            )
        else:  # single-channel tensors are treated as depth
            additionals.extend(
                [colorize(x.cpu().detach().squeeze().numpy()) for x in info]
            )

    num_rows = 2 + int(len(gts) > 0) + num_additional
    artifacts_grid = image_grid(
        [*rgbs, *gts, *preds, *additionals], num_rows, len(rgbs)
    )
    return artifacts_grid


def log_train_artifacts(rgbs, gts, preds, step, infos={}):
    artifacts_grid = train_artifacts(rgbs, gts, preds, infos)
    try:
        wandb.log({"training": [wandb.Image(artifacts_grid)]}, step=step)
    except Exception:
        # Fall back to saving the grid to disk if wandb logging fails
        Image.fromarray(artifacts_grid).save(
            os.path.join(
                os.environ.get("TMPDIR", "/tmp"),
                f"{get_rank()}_art_grid{step}.png",
            )
        )
        print("Logging training images failed")


def plot_quiver(flow, spacing, margin=0, **kwargs):
    """Plots a sparser quiver field.

    Args:
        flow: motion vectors, shape (H, W, 2)
        spacing: space (px) between each arrow in the grid
        margin: width (px) of the enclosing region without arrows
        kwargs: quiver kwargs (default: angles="xy", scale_units="xy")

    Returns:
        The created Matplotlib figure and axis.
    """
    h, w, *_ = flow.shape

    nx = int((w - 2 * margin) / spacing)
    ny = int((h - 2 * margin) / spacing)

    x = np.linspace(margin, w - margin - 1, nx, dtype=np.int64)
    y = np.linspace(margin, h - margin - 1, ny, dtype=np.int64)

    flow = flow[np.ix_(y, x)]
    u = flow[:, :, 0]
    v = flow[:, :, 1]

    kwargs = {**dict(angles="xy", scale_units="xy"), **kwargs}
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.quiver(x, y, u, v, **kwargs)
    # ax.set_ylim(sorted(ax.get_ylim(), reverse=True))

    return fig, ax
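
# --- Hypothetical smoke test (assumption: run as a script, not on import) ----
# A minimal sketch exercising colorize/image_grid and plot_quiver on synthetic
# data; the output file names and image sizes below are illustrative only.
if __name__ == "__main__":
    rng = np.random.default_rng(0)

    # Colorize a random depth map (linear and log scale) and tile into a grid
    depth = rng.uniform(0.1, 10.0, size=(120, 160)).astype(np.float32)
    grid = image_grid(
        [colorize(depth), colorize(np.log1p(depth))], rows=1, cols=2
    )
    Image.fromarray(grid).save("demo_grid.png")

    # Sub-sampled quiver plot of a random flow field
    flow = rng.normal(size=(120, 160, 2)).astype(np.float32)
    fig, ax = plot_quiver(flow, spacing=16, margin=8)
    fig.savefig("demo_quiver.png")
    plt.close(fig)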