import torch import tops import cv2 import torchvision.transforms.functional as F from typing import Optional, List, Union, Tuple from .cse import from_E_to_vertex import numpy as np from tops import download_file from .torch_utils import ( denormalize_img, binary_dilation, binary_erosion, remove_pad, crop_box) from torchvision.utils import _generate_color_palette from PIL import Image, ImageColor, ImageDraw def get_coco_keypoints(): # From: https://github.com/facebookresearch/Detectron/blob/main/detectron/utils/keypoints.py keypoints = [ 'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear', 'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow', 'left_wrist', 'right_wrist', 'left_hip', 'right_hip', 'left_knee', 'right_knee', 'left_ankle', 'right_ankle' ] keypoint_flip_map = { 'left_eye': 'right_eye', 'left_ear': 'right_ear', 'left_shoulder': 'right_shoulder', 'left_elbow': 'right_elbow', 'left_wrist': 'right_wrist', 'left_hip': 'right_hip', 'left_knee': 'right_knee', 'left_ankle': 'right_ankle' } connectivity = { "nose": "left_eye", "left_eye": "right_eye", "right_eye": "nose", "left_ear": "left_eye", "right_ear": "right_eye", "left_shoulder": "nose", "right_shoulder": "nose", "left_elbow": "left_shoulder", "right_elbow": "right_shoulder", "left_wrist": "left_elbow", "right_wrist": "right_elbow", "left_hip": "left_shoulder", "right_hip": "right_shoulder", "left_knee": "left_hip", "right_knee": "right_hip", "left_ankle": "left_knee", "right_ankle": "right_knee" } connectivity_indices = [ (sidx, keypoints.index(connectivity[kp])) for sidx, kp in enumerate(keypoints) ] return keypoints, keypoint_flip_map, connectivity_indices def get_coco_colors(): return [ *["red"]*5, "blue", "green", "blue", "green", "blue", "green", "purple", "orange", "purple", "orange", "purple", "orange", ] @torch.no_grad() def draw_keypoints( image: torch.Tensor, keypoints: torch.Tensor, connectivity: Optional[List[Tuple[int, int]]] = None, visible: Optional[List[List[bool]]] = None, colors: Optional[Union[str, Tuple[int, int, int]]] = None, radius: int = None, width: int = None, ) -> torch.Tensor: """ Function taken from torchvision source code. Added in torchvision 0.12 Draws Keypoints on given RGB image. The values of the input image should be uint8 between 0 and 255. Args: image (Tensor): Tensor of shape (3, H, W) and dtype uint8. keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, in the format [x, y]. connectivity (List[Tuple[int, int]]]): A List of tuple where, each tuple contains pair of keypoints to be connected. colors (str, Tuple): The color can be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. radius (int): Integer denoting radius of keypoint. width (int): Integer denoting width of line connecting keypoints. Returns: img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. """ if not isinstance(image, torch.Tensor): raise TypeError(f"The image must be a tensor, got {type(image)}") elif image.dtype != torch.uint8: raise ValueError(f"The image dtype must be uint8, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size()[0] != 3: raise ValueError("Pass an RGB image. Other Image formats are not supported") if keypoints.ndim != 3: raise ValueError("keypoints must be of shape (num_instances, K, 2)") if width is None: width = int(max(max(image.shape[-2:]) * 0.01, 1)) if radius is None: radius = int(max(max(image.shape[-2:]) * 0.01, 1)) ndarr = image.permute(1, 2, 0).cpu().numpy() img_to_draw = Image.fromarray(ndarr) draw = ImageDraw.Draw(img_to_draw) if isinstance(keypoints, torch.Tensor): img_kpts = keypoints.to(torch.int64).tolist() else: assert isinstance(keypoints, np.ndarray) img_kpts = keypoints.astype(int).tolist() colors = get_coco_colors() for inst_id, kpt_inst in enumerate(img_kpts): for kpt_id, kpt in enumerate(kpt_inst): if visible is not None and int(visible[inst_id][kpt_id]) == 0: continue x1 = kpt[0] - radius x2 = kpt[0] + radius y1 = kpt[1] - radius y2 = kpt[1] + radius draw.ellipse([x1, y1, x2, y2], fill=colors[kpt_id], outline=None, width=0) if connectivity is not None: for connection in connectivity: if connection[1] >= len(kpt_inst) or connection[0] >= len(kpt_inst): continue if visible is not None and int(visible[inst_id][connection[1]]) == 0 or int(visible[inst_id][connection[0]]) == 0: continue start_pt_x = kpt_inst[connection[0]][0] start_pt_y = kpt_inst[connection[0]][1] end_pt_x = kpt_inst[connection[1]][0] end_pt_y = kpt_inst[connection[1]][1] draw.line( ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)), width=width, fill=colors[connection[1]] ) return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) def visualize_keypoints(img, keypoints): img = img.clone() keypoints = keypoints.clone() keypoints[:, :, 0] *= img.shape[-1] keypoints[:, :, 1] *= img.shape[-2] _, _, connectivity = get_coco_keypoints() connectivity = np.array(connectivity) visible = None if keypoints.shape[-1] == 3: visible = keypoints[:, :, 2] > 0 for idx in range(img.shape[0]): img[idx] = draw_keypoints( img[idx], keypoints[idx:idx+1].long(), colors="red", connectivity=connectivity, visible=visible[idx:idx+1]) return img def visualize_batch( img: torch.Tensor, mask: torch.Tensor, vertices: torch.Tensor = None, E_mask: torch.Tensor = None, embed_map: torch.Tensor = None, semantic_mask: torch.Tensor = None, embedding: torch.Tensor = None, keypoints: torch.Tensor = None, maskrcnn_mask: torch.Tensor = None, **kwargs) -> torch.ByteTensor: img = denormalize_img(img).mul(255).round().clamp(0, 255).byte() img = draw_mask(img, mask) if maskrcnn_mask is not None and maskrcnn_mask.shape == mask.shape: img = draw_mask(img, maskrcnn_mask) if vertices is not None or embedding is not None: assert E_mask is not None assert embed_map is not None img, E_mask, embedding, embed_map, vertices = tops.to_cpu([ img, E_mask, embedding, embed_map, vertices ]) img = draw_cse(img, E_mask, embedding, embed_map, vertices) elif semantic_mask is not None: img = draw_segmentation_masks(img, semantic_mask) if keypoints is not None: img = visualize_keypoints(img, keypoints) return img @torch.no_grad() def draw_cse( img: torch.Tensor, E_seg: torch.Tensor, embedding: torch.Tensor = None, embed_map: torch.Tensor = None, vertices: torch.Tensor = None, t=0.7 ): """ E_seg: 1 for areas with embedding """ assert img.dtype == torch.uint8 img = img.view(-1, *img.shape[-3:]) E_seg = E_seg.view(-1, 1, *E_seg.shape[-2:]) if vertices is None: assert embedding is not None assert embed_map is not None embedding = embedding.view(-1, *embedding.shape[-3:]) vertices = torch.stack( [from_E_to_vertex(e[None], e_seg[None].logical_not().float(), embed_map) for e, e_seg in zip(embedding, E_seg)]) i = np.arange(0, 256, dtype=np.uint8).reshape(1, -1) colormap_JET = torch.from_numpy(cv2.applyColorMap(i, cv2.COLORMAP_JET)[0]) color_embed_map, _ = np.load(download_file( "https://dl.fbaipublicfiles.com/densepose/data/cse/mds_d=256.npy"), allow_pickle=True) color_embed_map = torch.from_numpy(color_embed_map).float()[:, 0] color_embed_map -= color_embed_map.min() color_embed_map /= color_embed_map.max() vertx2idx = (color_embed_map*255).long() vertx2colormap = colormap_JET[vertx2idx] vertices = vertices.view(-1, *vertices.shape[-2:]) E_seg = E_seg.view(-1, 1, *E_seg.shape[-2:]) # This operation might be good to do on cpu... E_color = vertx2colormap[vertices.long()] E_color = E_color.to(E_seg.device) E_color = E_color.permute(0, 3, 1, 2) E_color = E_color*E_seg.byte() m = E_seg.bool().repeat(1, 3, 1, 1) img[m] = (img[m] * (1-t) + t * E_color[m]).byte() return img def draw_cse_all( embedding: List[torch.Tensor], E_mask: List[torch.Tensor], im: torch.Tensor, boxes_XYXY: list, embed_map: torch.Tensor, t=0.7): """ E_seg: 1 for areas with embedding """ assert len(im.shape) == 3, im.shape assert im.dtype == torch.uint8 N = len(E_mask) im = im.clone() for i in range(N): assert len(E_mask[i].shape) == 2 assert len(embedding[i].shape) == 3 assert embed_map.shape[1] == embedding[i].shape[0] assert len(boxes_XYXY[i]) == 4 E = embedding[i] x0, y0, x1, y1 = boxes_XYXY[i] E = F.resize(E, (y1-y0, x1-x0), antialias=True) s = E_mask[i].float() s = (F.resize(s.squeeze()[None], (y1-y0, x1-x0), antialias=True) > 0).float() box = boxes_XYXY[i] im_ = crop_box(im, box) s = remove_pad(s, box, im.shape[1:]) E = remove_pad(E, box, im.shape[1:]) E_color = draw_cse(img=im_, E_seg=s[None], embedding=E[None], embed_map=embed_map)[0] E_color = E_color.to(im.device) s = s.bool().repeat(3, 1, 1) crop_box(im, box)[s] = (im_[s] * (1-t) + t * E_color[s]).byte() return im @torch.no_grad() def draw_segmentation_masks( image: torch.Tensor, masks: torch.Tensor, alpha: float = 0.8, colors: Optional[List[Union[str, Tuple[int, int, int]]]] = None, ) -> torch.Tensor: """ Draws segmentation masks on given RGB image. The values of the input image should be uint8 between 0 and 255. Args: image (Tensor): Tensor of shape (3, H, W) and dtype uint8. masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. alpha (float): Float number between 0 and 1 denoting the transparency of the masks. 0 means full transparency, 1 means no transparency. colors (list or None): List containing the colors of the masks. The colors can be represented as PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. When ``masks`` has a single entry of shape (H, W), you can pass a single color instead of a list with one element. By default, random colors are generated for each mask. Returns: img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top. """ if not isinstance(image, torch.Tensor): raise TypeError(f"The image must be a tensor, got {type(image)}") elif image.dtype != torch.uint8: raise ValueError(f"The image dtype must be uint8, got {image.dtype}") elif image.dim() != 3: raise ValueError("Pass individual images, not batches") elif image.size()[0] != 3: raise ValueError("Pass an RGB image. Other Image formats are not supported") if masks.ndim == 2: masks = masks[None, :, :] if masks.ndim != 3: raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)") if masks.dtype != torch.bool: raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}") if masks.shape[-2:] != image.shape[-2:]: raise ValueError("The image and the masks must have the same height and width") num_masks = masks.size()[0] if num_masks == 0: return image if colors is None: colors = _generate_color_palette(num_masks) if not isinstance(colors[0], (Tuple, List)): colors = [colors for i in range(num_masks)] if colors is not None and num_masks > len(colors): raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") if not isinstance(colors, list): colors = [colors] if not isinstance(colors[0], (tuple, str)): raise ValueError("colors must be a tuple or a string, or a list thereof") if isinstance(colors[0], tuple) and len(colors[0]) != 3: raise ValueError("It seems that you passed a tuple of colors instead of a list of colors") out_dtype = torch.uint8 colors_ = [] for color in colors: if isinstance(color, str): color = ImageColor.getrgb(color) color = torch.tensor(color, dtype=out_dtype, device=masks.device) colors_.append(color) img_to_draw = image.detach().clone() # TODO: There might be a way to vectorize this for mask, color in zip(masks, colors_): img_to_draw[:, mask] = color[:, None] out = image * (1 - alpha) + img_to_draw * alpha return out.to(out_dtype) def draw_mask(im: torch.Tensor, mask: torch.Tensor, t=0.2, color=(255, 255, 255), visualize_instances=True): """ Visualize mask where mask = 0. Supports multiple instances. mask shape: [N, C, H, W], where C is different instances in same image. """ orig_imshape = im.shape if mask.numel() == 0: return im assert len(mask.shape) in (3, 4), mask.shape mask = mask.view(-1, *mask.shape[-3:]) im = im.view(-1, *im.shape[-3:]) assert im.dtype == torch.uint8, im.dtype assert 0 <= t <= 1 if not visualize_instances: mask = mask.any(dim=1, keepdim=True) mask = mask.bool() kernel = torch.ones((3, 3), dtype=mask.dtype, device=mask.device) outer_border = binary_dilation(mask, kernel).logical_xor(mask) outer_border = outer_border.any(dim=1, keepdim=True).repeat(1, 3, 1, 1) > 0 inner_border = binary_erosion(mask, kernel).logical_xor(mask) inner_border = inner_border.any(dim=1, keepdim=True).repeat(1, 3, 1, 1) > 0 mask = (mask == 0).any(dim=1, keepdim=True).repeat(1, 3, 1, 1) color = torch.tensor(color).to(im.device).byte().view(1, 3, 1, 1) # .repeat(1, *im.shape[1:]) color = color.repeat(im.shape[0], 1, *im.shape[-2:]) im[mask] = (im[mask] * (1-t) + t * color[mask]).byte() im[outer_border] = 255 im[inner_border] = 0 return im.view(*orig_imshape) def draw_cropped_masks(im: torch.Tensor, mask: torch.Tensor, boxes: torch.Tensor, **kwargs): for i, box in enumerate(boxes): x0, y0, x1, y1 = boxes[i] orig_shape = (y1-y0, x1-x0) m = F.resize(mask[i], orig_shape, F.InterpolationMode.NEAREST).squeeze()[None] m = remove_pad(m, boxes[i], im.shape[-2:]) crop_box(im, boxes[i]).set_(draw_mask(crop_box(im, boxes[i]), m)) return im def draw_cropped_keypoints(im: torch.Tensor, all_keypoints: torch.Tensor, boxes: torch.Tensor, **kwargs): n_boxes = boxes.shape[0] tops.assert_shape(all_keypoints, (n_boxes, 17, 3)) im = im.clone() for i, box in enumerate(boxes): x0, y0, x1, y1 = boxes[i] orig_shape = (y1-y0, x1-x0) keypoints = all_keypoints[i].clone() keypoints[:, 0] *= orig_shape[1] keypoints[:, 1] *= orig_shape[0] keypoints = keypoints.long() _, _, connectivity = get_coco_keypoints() connectivity = np.array(connectivity) visible = (keypoints[:, 2] > .5) # Remove padding from keypoints before visualization keypoints[:, 0] += min(x0, 0) keypoints[:, 1] += min(y0, 0) im_with_kp = draw_keypoints( crop_box(im, box), keypoints[None], colors="red", connectivity=connectivity, visible=visible[None]) crop_box(im, box).copy_(im_with_kp) return im