Spaces:
Running
Running
from PIL import Image | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
from torchvision import transforms | |
from torchvision.utils import make_grid | |
from skimage.transform import resize | |
from .u2net import U2NET | |
def plot_attn_dino(attn, threshold_map, inputs, inds, output_path):
    """Save a two-row overview figure of attention maps and their thresholds.

    The figure is a grid of 2 rows by ``attn.shape[0] + 2`` columns:

    * top row: the input image with ``inds`` drawn as red dots, the sum of
      ``attn`` over its first axis, then every ``attn[i]``;
    * bottom row: ``threshold_map[-1]``, the sum of ``threshold_map[:-1]``
      over its first axis, then every ``threshold_map[i]``.

    Args:
        attn: stack of maps indexed along the first axis. ``.numpy()`` is
            called on it directly, so presumably a CPU torch tensor -
            TODO confirm against the caller.
        threshold_map: stack indexed like ``attn``, plus one extra trailing
            entry that is shown as "prob sum".
        inputs: image tensor(s) handed to ``make_grid``.
        inds: 2-D array of points; column 1 is used as x and column 0 as y.
        output_path: destination passed to ``plt.savefig``.
    """
    # currently supports one image (and not a batch)
    n_maps = attn.shape[0]
    n_rows, n_cols = 2, n_maps + 2

    def open_panel(position):
        # Select the subplot at the given 1-based grid position.
        plt.subplot(n_rows, n_cols, position)

    def close_panel(title=None):
        # Optionally title the current subplot, then hide its axes.
        if title is not None:
            plt.title(title)
        plt.axis("off")

    plt.figure(figsize=(10, 5))

    # Panel 1: the input image with the sampled points on top of it.
    open_panel(1)
    grid = make_grid(inputs, normalize=True, pad_value=2)
    grid_hwc = np.transpose(grid.cpu().numpy(), (1, 2, 0))
    plt.imshow(grid_hwc, interpolation='nearest')
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c='red', marker='o')
    close_panel("input im")

    # Aggregated maps: (grid position, tensor to show, panel title).
    summary_panels = (
        (2, attn.sum(0), "atn map sum"),
        (n_maps + 3, threshold_map[-1], "prob sum"),
        (n_maps + 4, threshold_map[:-1].sum(0), "thresh sum"),
    )
    for position, tensor, title in summary_panels:
        open_panel(position)
        plt.imshow(tensor.numpy(), interpolation='nearest')
        close_panel(title)

    # One column per map: attention on top, its threshold map right below.
    for idx, attn_map in enumerate(attn):
        open_panel(idx + 3)
        plt.imshow(attn_map.numpy())
        close_panel()

        # Same column, second row (one full row of n_cols panels further).
        open_panel(n_cols + idx + 3)
        plt.imshow(threshold_map[idx].numpy())
        close_panel()

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()
def plot_attn_clip(attn, threshold_map, inputs, inds, output_path):
    """Save a three-panel figure: input image, attention map, rescaled map.

    Panels, left to right:

    1. the input image with ``inds`` drawn as red dots;
    2. ``attn`` displayed with a fixed colour range of [0, 1];
    3. ``threshold_map`` min-max rescaled to [0, 1], with the same red dots.

    Args:
        attn: 2-D map passed straight to ``plt.imshow``.
        threshold_map: array-like supporting ``.min()`` / ``.max()`` and
            arithmetic; shown after min-max rescaling.
        inputs: image tensor(s) handed to ``make_grid``.
        inds: 2-D array of points; column 1 is used as x and column 0 as y.
        output_path: destination passed to ``plt.savefig``.
    """
    # currently supports one image (and not a batch)
    point_x, point_y = inds[:, 1], inds[:, 0]
    dot_style = dict(s=10, c='red', marker='o')

    plt.figure(figsize=(10, 5))

    # Panel 1: input image plus the sampled points.
    plt.subplot(1, 3, 1)
    grid = make_grid(inputs, normalize=True, pad_value=2)
    plt.imshow(np.transpose(grid.cpu().numpy(), (1, 2, 0)),
               interpolation='nearest')
    plt.scatter(point_x, point_y, **dot_style)
    plt.title("input im")
    plt.axis("off")

    # Panel 2: raw attention map.
    plt.subplot(1, 3, 2)
    plt.imshow(attn, interpolation='nearest', vmin=0, vmax=1)
    plt.title("attn map")
    plt.axis("off")

    # Panel 3: threshold map stretched to the full [0, 1] range.
    plt.subplot(1, 3, 3)
    lowest, highest = threshold_map.min(), threshold_map.max()
    rescaled = (threshold_map - lowest) / (highest - lowest)
    plt.imshow(rescaled, interpolation='nearest', vmin=0, vmax=1)
    plt.title("prob softmax")
    plt.scatter(point_x, point_y, **dot_style)
    plt.axis("off")

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()
def plot_attn(attn, threshold_map, inputs, inds, output_path, saliency_model):
    """Dispatch to the plotting routine that matches ``saliency_model``.

    ``"dino"`` goes to ``plot_attn_dino`` and ``"clip"`` goes to
    ``plot_attn_clip``; all other arguments are forwarded unchanged.
    Any other ``saliency_model`` value is ignored and nothing is plotted.
    """
    if saliency_model == "clip":
        plot_attn_clip(attn, threshold_map, inputs, inds, output_path)
        return
    if saliency_model == "dino":
        plot_attn_dino(attn, threshold_map, inputs, inds, output_path)
def fix_image_scale(im):
    """Centre ``im`` on a square canvas 20 pixels longer than its longer side.

    The canvas is filled with ones, which becomes white (255) once the
    result is scaled back to 8-bit.

    Args:
        im: image convertible with ``np.array``. The canvas has three
            channels, so presumably an RGB image - TODO confirm that callers
            never pass grayscale or RGBA.

    Returns:
        A new PIL image built from the padded ``uint8`` array.
    """
    pixels = np.array(im) / 255
    rows, cols = pixels.shape[0], pixels.shape[1]

    side = max(rows, cols) + 20
    canvas = np.ones((side, side, 3))

    # Top-left corner that puts the image in the middle of the canvas.
    top = side // 2 - rows // 2
    left = side // 2 - cols // 2
    canvas[top: top + rows, left: left + cols] = pixels

    # Normalise by the peak value before converting back to 8-bit.
    canvas_u8 = (canvas / canvas.max() * 255).astype(np.uint8)
    return Image.fromarray(canvas_u8)
def get_mask_u2net(pil_im, output_dir, u2net_path, device="cpu"):
    """Mask the salient object of ``pil_im`` with U^2-Net and whiten the rest.

    The image is resized (shorter side capped at 320), normalised and run
    through U^2-Net. The first network output is min-max rescaled and
    binarised at 0.5, resized back to the original resolution, saved as
    ``mask.png`` and used to replace every background pixel with white.

    Args:
        pil_im: PIL image. The mask is built with three channels, so
            presumably an RGB image - TODO confirm against callers.
        output_dir: directory object supporting the ``/`` operator (e.g. a
            ``pathlib.Path``); ``mask.png`` is written inside it.
        u2net_path: path of the U^2-Net checkpoint given to ``torch.load``.
        device: torch device on which the model and the input are placed.

    Returns:
        A tuple ``(im_final, predict)``: the masked PIL image and the
        binarised prediction tensor, which has a leading batch axis of
        size 1 and lives on ``device``.
    """
    # input preprocess
    w, h = pil_im.size[0], pil_im.size[1]
    im_size = min(w, h)
    data_transforms = transforms.Compose([
        transforms.Resize(min(320, im_size),
                          interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                             std=(0.26862954, 0.26130258, 0.27577711)),
    ])
    input_im_trans = data_transforms(pil_im).unsqueeze(0).to(device)

    # load U^2 Net model
    net = U2NET(in_ch=3, out_ch=1)
    # map_location remaps the stored tensors onto the requested device;
    # without it a checkpoint saved from CUDA tensors cannot be loaded on a
    # CPU-only machine, even though `device` defaults to "cpu".
    net.load_state_dict(torch.load(u2net_path, map_location=device))
    net.to(device)
    net.eval()

    # get mask (only the first of the network outputs is used)
    with torch.no_grad():
        d1, *_ = net(input_im_trans.detach())
    pred = d1[:, 0, :, :]
    pred = (pred - pred.min()) / (pred.max() - pred.min())
    predict = pred
    predict[predict < 0.5] = 0
    predict[predict >= 0.5] = 1

    # Replicate the single-channel mask to 3 channels, channels-last.
    mask = torch.cat([predict, predict, predict], dim=0).permute(1, 2, 0)
    mask = mask.cpu().numpy()
    mask = resize(mask, (h, w), anti_aliasing=False)
    # Resizing interpolates, so binarise again.
    mask[mask < 0.5] = 0
    mask[mask >= 0.5] = 1

    im = Image.fromarray((mask[:, :, 0] * 255).astype(np.uint8)).convert('RGB')
    save_path_ = output_dir / "mask.png"
    im.save(save_path_)

    im_np = np.array(pil_im)
    im_np = im_np / im_np.max()
    im_np = mask * im_np
    # Background pixels become white.
    im_np[mask == 0] = 1
    im_final = (im_np / im_np.max() * 255).astype(np.uint8)
    im_final = Image.fromarray(im_final)

    # free u2net
    del net
    torch.cuda.empty_cache()
    return im_final, predict