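# Utility helpers for visualizing DINO / CLIP attention maps and for extracting
# a U^2-Net saliency mask from an input image.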
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import transforms
from torchvision.utils import make_grid
from skimage.transform import resize
from .u2net import U2NET
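
# Plot the per-head DINO attention maps next to the input image. The grid has
# 2 rows and (n_heads + 2) columns: the first two columns show the input image
# (with the sampled points in red) and the summed attention / threshold maps,
# and the remaining columns show each attention head and its threshold map.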
def plot_attn_dino(attn, threshold_map, inputs, inds, output_path):
    # currently supports one image (and not a batch)
    plt.figure(figsize=(10, 5))

    plt.subplot(2, attn.shape[0] + 2, 1)
    main_im = make_grid(inputs, normalize=True, pad_value=2)
    main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
    plt.imshow(main_im, interpolation='nearest')
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.title("input im")
    plt.axis("off")

    plt.subplot(2, attn.shape[0] + 2, 2)
    plt.imshow(attn.sum(0).numpy(), interpolation='nearest')
    plt.title("attn map sum")
    plt.axis("off")

    plt.subplot(2, attn.shape[0] + 2, attn.shape[0] + 3)
    plt.imshow(threshold_map[-1].numpy(), interpolation='nearest')
    plt.title("prob sum")
    plt.axis("off")

    plt.subplot(2, attn.shape[0] + 2, attn.shape[0] + 4)
    plt.imshow(threshold_map[:-1].sum(0).numpy(), interpolation='nearest')
    plt.title("thresh sum")
    plt.axis("off")

    for i in range(attn.shape[0]):
        plt.subplot(2, attn.shape[0] + 2, i + 3)
        plt.imshow(attn[i].numpy())
        plt.axis("off")
        plt.subplot(2, attn.shape[0] + 2, attn.shape[0] + 1 + i + 4)
        plt.imshow(threshold_map[i].numpy())
        plt.axis("off")

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()
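
# Plot the CLIP-based saliency in a single row: the input image with the
# sampled points, the attention map, and the min-max normalized threshold map.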
def plot_attn_clip(attn, threshold_map, inputs, inds, output_path):
    # currently supports one image (and not a batch)
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 3, 1)
    main_im = make_grid(inputs, normalize=True, pad_value=2)
    main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
    plt.imshow(main_im, interpolation='nearest')
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.title("input im")
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.imshow(attn, interpolation='nearest', vmin=0, vmax=1)
    plt.title("attn map")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    threshold_map_ = (threshold_map - threshold_map.min()) / \
        (threshold_map.max() - threshold_map.min())
    plt.imshow(threshold_map_, interpolation='nearest', vmin=0, vmax=1)
    plt.title("prob softmax")
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    plt.axis("off")

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()
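
# Dispatch to the DINO or CLIP visualization depending on the saliency model.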
def plot_attn(attn, threshold_map, inputs, inds, output_path, saliency_model):
    if saliency_model == "dino":
        plot_attn_dino(attn, threshold_map, inputs, inds, output_path)
    elif saliency_model == "clip":
        plot_attn_clip(attn, threshold_map, inputs, inds, output_path)
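
# Center the image on a square white canvas (longest side + 20 pixels), which
# keeps the aspect ratio when the result is later resized to a square.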
def fix_image_scale(im):
    im_np = np.array(im) / 255
    height, width = im_np.shape[0], im_np.shape[1]
    max_len = max(height, width) + 20
    new_background = np.ones((max_len, max_len, 3))
    y, x = max_len // 2 - height // 2, max_len // 2 - width // 2
    new_background[y: y + height, x: x + width] = im_np
    new_background = (new_background / new_background.max() * 255).astype(np.uint8)
    new_im = Image.fromarray(new_background)
    return new_im
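
# Run U^2-Net on the input image to obtain a binary saliency mask. The mask is
# saved to <output_dir>/mask.png, and the function returns the input image with
# the background painted white, together with the low-resolution binary mask.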
def get_mask_u2net(pil_im, output_dir, u2net_path, device="cpu"):
    # input preprocess
    w, h = pil_im.size[0], pil_im.size[1]
    im_size = min(w, h)
    data_transforms = transforms.Compose([
        transforms.Resize(min(320, im_size), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                             std=(0.26862954, 0.26130258, 0.27577711)),
    ])
    input_im_trans = data_transforms(pil_im).unsqueeze(0).to(device)

    # load U^2-Net model
    net = U2NET(in_ch=3, out_ch=1)
    net.load_state_dict(torch.load(u2net_path, map_location=device))
    net.to(device)
    net.eval()

    # get mask: the first output (d1) is the fused saliency prediction
    with torch.no_grad():
        d1, d2, d3, d4, d5, d6, d7 = net(input_im_trans.detach())

    pred = d1[:, 0, :, :]
    pred = (pred - pred.min()) / (pred.max() - pred.min())

    # binarize the prediction at 0.5
    predict = pred
    predict[predict < 0.5] = 0
    predict[predict >= 0.5] = 1

    # resize the binary mask back to the original image resolution
    mask = torch.cat([predict, predict, predict], dim=0).permute(1, 2, 0)
    mask = mask.cpu().numpy()
    mask = resize(mask, (h, w), anti_aliasing=False)
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1

    # predict_np = predict.clone().cpu().data.numpy()
    im = Image.fromarray((mask[:, :, 0] * 255).astype(np.uint8)).convert('RGB')
    save_path_ = output_dir / "mask.png"
    im.save(save_path_)

    # apply the mask: keep the salient object, fill the background with white
    im_np = np.array(pil_im)
    im_np = im_np / im_np.max()
    im_np = mask * im_np
    im_np[mask == 0] = 1
    im_final = (im_np / im_np.max() * 255).astype(np.uint8)
    im_final = Image.fromarray(im_final)

    # free u2net
    del net
    torch.cuda.empty_cache()
    return im_final, predict
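
# Example usage (a minimal sketch; the file names and output directory below
# are placeholders, not part of this repo):
#
#   from pathlib import Path
#   pil_im = Image.open("input.png").convert("RGB")
#   pil_im = fix_image_scale(pil_im)
#   out_dir = Path("./output")
#   out_dir.mkdir(parents=True, exist_ok=True)
#   masked_im, mask = get_mask_u2net(pil_im, out_dir, "u2net.pth", device="cpu")
#   masked_im.save(out_dir / "masked_input.png")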