"""Plotting helpers for saliency/attention maps and U^2-Net foreground masking."""

from pathlib import Path

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from torchvision.utils import make_grid
from skimage.transform import resize

from .u2net import U2NET


def plot_attn_dino(attn, threshold_map, inputs, inds, output_path):
    """Save a two-row overview figure of per-map attention and threshold maps.

    Currently supports one image (and not a batch).

    Row 1 shows the input image with the sampled points, the sum of all
    attention maps, and then each attention map. Row 2 shows
    ``threshold_map[-1]`` ("prob sum"), the sum of ``threshold_map[:-1]``
    ("thresh sum"), and then ``threshold_map[i]`` under each attention map.

    Args:
        attn: Stack of attention maps indexed along dim 0. ``.numpy()`` is
            called on it directly, so it is presumably a CPU tensor that does
            not require grad -- TODO confirm against the caller.
        threshold_map: Stack indexed the same way as ``attn`` plus a final
            entry at ``[-1]``; presumably ``attn.shape[0] + 1`` maps in
            total -- verify against the caller.
        inputs: Image tensor handed to ``torchvision.utils.make_grid``.
        inds: 2-column array of points; column 0 is plotted as y and column 1
            as x.
        output_path: Destination passed to ``plt.savefig``.
    """
    n_maps = attn.shape[0]
    # Two leading columns (input / sums) followed by one column per map.
    n_cols = n_maps + 2

    fig = plt.figure(figsize=(10, 5))
    try:
        plt.subplot(2, n_cols, 1)
        main_im = make_grid(inputs, normalize=True, pad_value=2)
        # make_grid returns channels-first; imshow wants channels-last.
        main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
        plt.imshow(main_im, interpolation='nearest')
        plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
        plt.title("input im")
        plt.axis("off")

        plt.subplot(2, n_cols, 2)
        plt.imshow(attn.sum(0).numpy(), interpolation='nearest')
        plt.title("atn map sum")
        plt.axis("off")

        # First cell of the second row.
        plt.subplot(2, n_cols, n_cols + 1)
        plt.imshow(threshold_map[-1].numpy(), interpolation='nearest')
        plt.title("prob sum")
        plt.axis("off")

        plt.subplot(2, n_cols, n_cols + 2)
        plt.imshow(threshold_map[:-1].sum(0).numpy(), interpolation='nearest')
        plt.title("thresh sum")
        plt.axis("off")

        for i in range(n_maps):
            # Attention map i in row 1, its threshold map directly below.
            plt.subplot(2, n_cols, i + 3)
            plt.imshow(attn[i].numpy())
            plt.axis("off")
            plt.subplot(2, n_cols, n_cols + i + 3)
            plt.imshow(threshold_map[i].numpy())
            plt.axis("off")

        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # Close even when plotting/saving fails so figures do not accumulate.
        plt.close(fig)


def plot_attn_clip(attn, threshold_map, inputs, inds, output_path):
    """Save a three-panel figure: input image, attention map, rescaled map.

    Currently supports one image (and not a batch).

    Args:
        attn: 2-D map shown with a fixed colour range of [0, 1].
        threshold_map: Map that is min-max rescaled to [0, 1] before being
            shown under the title "prob softmax".
        inputs: Image tensor handed to ``torchvision.utils.make_grid``.
        inds: 2-column array of points; column 0 is plotted as y and column 1
            as x. Drawn on both the input image and the rescaled map.
        output_path: Destination passed to ``plt.savefig``.
    """
    fig = plt.figure(figsize=(10, 5))
    try:
        plt.subplot(1, 3, 1)
        main_im = make_grid(inputs, normalize=True, pad_value=2)
        # make_grid returns channels-first; imshow wants channels-last.
        main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
        plt.imshow(main_im, interpolation='nearest')
        plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
        plt.title("input im")
        plt.axis("off")

        plt.subplot(1, 3, 2)
        plt.imshow(attn, interpolation='nearest', vmin=0, vmax=1)
        plt.title("attn map")
        plt.axis("off")

        plt.subplot(1, 3, 3)
        lowest = threshold_map.min()
        highest = threshold_map.max()
        threshold_map_ = (threshold_map - lowest) / (highest - lowest)
        plt.imshow(threshold_map_, interpolation='nearest', vmin=0, vmax=1)
        plt.title("prob softmax")
        plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
        plt.axis("off")

        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # Close even when plotting/saving fails so figures do not accumulate.
        plt.close(fig)


def plot_attn(attn, threshold_map, inputs, inds, output_path, saliency_model):
    """Dispatch to the plotting routine matching ``saliency_model``.

    ``"dino"`` selects :func:`plot_attn_dino` and ``"clip"`` selects
    :func:`plot_attn_clip`. Any other value is ignored and nothing is
    plotted.
    """
    if saliency_model == "dino":
        plot_attn_dino(attn, threshold_map, inputs, inds, output_path)
    elif saliency_model == "clip":
        plot_attn_clip(attn, threshold_map, inputs, inds, output_path)


def fix_image_scale(im):
    """Centre ``im`` on a white square canvas with a 20-pixel size margin.

    The canvas side is ``max(height, width) + 20`` and the canvas has three
    channels, so ``im`` must convert to an ``(H, W, 3)`` array; pixel values
    are divided by 255 before pasting.

    Args:
        im: Image (anything ``np.array`` accepts) with three colour channels.

    Returns:
        A ``PIL.Image`` holding the padded uint8 image.
    """
    im_np = np.array(im) / 255
    height, width = im_np.shape[0], im_np.shape[1]
    max_len = max(height, width) + 20

    # White (all-ones) square background.
    new_background = np.ones((max_len, max_len, 3))
    y = max_len // 2 - height // 2
    x = max_len // 2 - width // 2
    new_background[y: y + height, x: x + width] = im_np

    new_background = (new_background / new_background.max() * 255).astype(np.uint8)
    return Image.fromarray(new_background)


def _scale_by_peak(arr):
    """Return ``arr`` divided by its maximum, or a float copy if max <= 0."""
    peak = arr.max()
    if peak > 0:
        return arr / peak
    # All-zero input: dividing would yield NaNs, so keep the zeros.
    return arr.astype(np.float64)


def get_mask_u2net(pil_im, output_dir, u2net_path, device="cpu"):
    """Compute a binary foreground mask with U^2-Net and whiten the background.

    The mask is written to ``output_dir / "mask.png"`` as a side effect.

    Args:
        pil_im: Input ``PIL.Image``. It is multiplied by a 3-channel mask, so
            it is expected to be RGB.
        output_dir: Directory for ``mask.png``; a ``pathlib.Path`` or a plain
            string path. The directory must already exist.
        u2net_path: Path to the U^2-Net ``state_dict`` checkpoint.
        device: Torch device used for inference.

    Returns:
        ``(im_final, predict)``: the masked ``PIL.Image`` with background
        pixels set to white, and the thresholded network prediction tensor
        (values 0/1) at network resolution, still on ``device``.
    """
    # input preprocess
    w, h = pil_im.size[0], pil_im.size[1]
    im_size = min(w, h)
    data_transforms = transforms.Compose([
        # Shorter side is resized to at most 320 px (never upscaled).
        transforms.Resize(min(320, im_size),
                          interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                             std=(0.26862954, 0.26130258, 0.27577711)),
    ])
    input_im_trans = data_transforms(pil_im).unsqueeze(0).to(device)

    # load U^2 Net model
    net = U2NET(in_ch=3, out_ch=1)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only load trusted weights.
    net.load_state_dict(torch.load(u2net_path, map_location=device))
    net.to(device)
    net.eval()

    # get mask
    with torch.no_grad():
        # Only the first of the network's outputs is used.
        d1 = net(input_im_trans.detach())[0]
    pred = d1[:, 0, :, :]

    # Min-max normalise; a constant map has no foreground, so avoid 0/0 NaNs
    # that would otherwise survive both thresholds below.
    pred_min = pred.min()
    value_range = pred.max() - pred_min
    if value_range > 0:
        pred = (pred - pred_min) / value_range
    else:
        pred = torch.zeros_like(pred)

    predict = pred
    predict[predict < 0.5] = 0
    predict[predict >= 0.5] = 1

    # Replicate to three channels, channels-last, then bring the mask back to
    # the original image resolution and re-binarise after interpolation.
    mask = torch.cat([predict, predict, predict], dim=0).permute(1, 2, 0)
    mask = mask.cpu().numpy()
    mask = resize(mask, (h, w), anti_aliasing=False)
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1

    im = Image.fromarray((mask[:, :, 0] * 255).astype(np.uint8)).convert('RGB')
    save_path_ = Path(output_dir) / "mask.png"
    im.save(save_path_)

    im_np = _scale_by_peak(np.array(pil_im))
    im_np = mask * im_np
    # Paint the background white.
    im_np[mask == 0] = 1
    im_final = (_scale_by_peak(im_np) * 255).astype(np.uint8)
    im_final = Image.fromarray(im_final)

    # free u2net
    del net
    torch.cuda.empty_cache()

    return im_final, predict