import cv2
import numpy as np

# Class index -> RGB colour palette. The position of a class in this dict
# (its enumeration order) is the channel it occupies in a onehot label map.
# BG is the background and VB the vitreous body (see fix_pred_label).
color_dict = {
    0: (0, 0, 0),        # BG
    1: (239, 164, 0),    # EX
    2: (0, 186, 127),    # HE
    3: (0, 185, 255),    # SE
    4: (34, 80, 242),    # MA
    5: (73, 73, 73),     # OD
    6: (255, 255, 255),  # VB
}


def rgb_to_onehot(rgb_arr, color_dict):
    """
    Converts a rgb label map to onehot label map defined by color_dict

    Channel ``i`` of the result is set where a pixel equals the colour of the
    ``i``-th entry of ``color_dict`` (in iteration order). Pixels whose colour
    is not in ``color_dict`` are zero in every channel.

    Parameters:
        rgb_arr (array): rgb label mask with shape (H x W x 3); extra leading
            dimensions (e.g. a batch axis) are accepted as well
        color_dict (dict): dictionary mapping of class to colour

    Returns:
        arr (array): int8 onehot label map of shape (H x W x n_classes)

    Raises:
        ValueError: if the last axis of rgb_arr does not have length 3
    """
    rgb_arr = np.asarray(rgb_arr)
    if rgb_arr.ndim < 1 or rgb_arr.shape[-1] != 3:
        raise ValueError(
            "rgb_arr must have 3 colour channels on its last axis, "
            "got shape {}".format(rgb_arr.shape)
        )

    arr = np.zeros(rgb_arr.shape[:-1] + (len(color_dict),), dtype=np.int8)
    for i, cls in enumerate(color_dict):
        # Look the colour up by key (cls), not by position (i), so palettes
        # whose keys are not exactly 0..n-1 are handled correctly.
        arr[..., i] = np.all(rgb_arr == np.asarray(color_dict[cls]), axis=-1)
    return arr


def onehot_to_rgb(onehot_arr, color_dict):
    """
    Converts an onehot label map to rgb label map defined by color_dict

    Each pixel takes the colour of the class whose channel holds the maximum
    value (ties resolve to the lowest channel, as with np.argmax). Channel
    ``i`` corresponds to the ``i``-th entry of ``color_dict``; channels with
    no matching entry are left black.

    Parameters:
        onehot_arr (array): onehot label mask with shape (H x W x n_classes);
            extra leading dimensions are accepted as well
        color_dict (dict): dictionary mapping of class to colour

    Returns:
        arr (array): uint8 rgb label map of shape (H x W x 3)
    """
    mask = np.argmax(onehot_arr, axis=-1)
    arr = np.zeros(mask.shape + (3,), dtype=np.uint8)
    for i, cls in enumerate(color_dict):
        # mask holds channel positions, so compare against the position i and
        # fetch the colour by key. Assigning in place keeps the uint8 dtype.
        arr[mask == i] = color_dict[cls]
    return arr


def fix_pred_label(labels):
    """
    Post-processing fixes for the prediction of VB and BG label class,
    the Vitreous Body should be consistently spherical on a black background

    The first (BG) and last (VB) channels of ``labels`` are discarded and
    rebuilt from a filled circle centred in the image with a radius of half
    the shorter side: BG is everything outside the circle, VB is the inside
    of the circle minus the sum of the remaining channels, floored at zero.
    The remaining channels are passed through unchanged.

    Parameters:
        labels (tensor): A 4-D array of predicted label with shape
            (batch x H x W x 7)

    Returns:
        fixed_labels (array): shape (batch x H x W x 7)

    Raises:
        ValueError: if labels is not 4-dimensional
    """
    labels = np.asarray(labels)
    if labels.ndim != 4:
        raise ValueError(
            "labels must be 4-D (batch x H x W x classes), "
            "got shape {}".format(labels.shape)
        )
    height, width = labels.shape[1:-1]

    # cv2.circle expects the centre as (x, y) = (column, row), so the width
    # comes first. Thickness -1 draws a filled disc with value 1.
    disc = cv2.circle(
        np.zeros((height, width), dtype=np.uint8),
        (width // 2, height // 2),
        min(height, width) // 2,
        1,
        -1,
    )[..., None]
    background = np.uint8(disc == 0)

    lesions = labels[..., 1:-1]
    lesion_sum = np.sum(lesions, axis=-1)[..., None]
    # VB is what is left of the disc after the other classes. Where those
    # classes already cover (or exceed) the disc value the result is forced to
    # zero, which avoids negative values and unsigned wrap-around.
    remainder = disc - lesion_sum
    vitreous = np.where(disc > lesion_sum, remainder, np.zeros_like(remainder))

    background = np.broadcast_to(background, vitreous.shape)
    fixed_labels = np.concatenate([background, lesions, vitreous], axis=-1)
    return fixed_labels