sayakpaul's picture
sayakpaul HF Staff
add files
c4b2b37
import argparse
import os
import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from data.Imagenet import Imagenet_Segmentation
from numpy import *
from PIL import Image
from sklearn.metrics import precision_recall_curve
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import render
from utils.iou import IoU
from utils.metrices import *
from utils.saver import Saver
from ViT_explanation_generator import LRP, Baselines
from ViT_LRP import vit_base_patch16_224 as vit_LRP
from ViT_new import vit_base_patch16_224
from ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP
# hyperparameters
# NOTE: `num_workers` is declared here, but the DataLoader call in this script
# passes a literal worker count instead of reading it.
batch_size = 1
num_workers = 0

# Select matplotlib's non-interactive "agg" backend; figures are only ever
# written to files by this script.
plt.switch_backend("agg")

# Object-category names. Not referenced anywhere else in the visible part of
# this script; the spelling of every entry is preserved as-is.
cls = [
    "airplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "dining table", "dog", "horse", "motobike", "person",
    "potted plant", "sheep", "sofa", "train", "tv",
]
def _str2bool(value):
    """Parse a command-line boolean such as "true"/"false", "1"/"0", "yes"/"no".

    ``type=bool`` must not be used for this: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so the flag could never be
    switched off from the command line.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string keeps its historical meaning (bool("") is False).
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


# Args
parser = argparse.ArgumentParser(description="Training multi-class classifier")
parser.add_argument(
    "--arc", type=str, default="vgg", metavar="N", help="Model architecture"
)
parser.add_argument(
    "--train_dataset", type=str, default="imagenet", metavar="N", help="Testing Dataset"
)
parser.add_argument(
    "--method",
    type=str,
    # argparse does not validate defaults against `choices`; the previous
    # default ("grad_rollout") was not a listed choice and matched no branch
    # of eval_batch, so running without --method failed on the first batch.
    default="transformer_attribution",
    choices=[
        "rollout",
        "lrp",
        "transformer_attribution",
        "full_lrp",
        "lrp_last_layer",
        "attn_last_layer",
        "attn_gradcam",
    ],
    help="",
)
parser.add_argument("--thr", type=float, default=0.0, help="threshold")
parser.add_argument("--K", type=int, default=1, help="new - top K results")
parser.add_argument("--save-img", action="store_true", default=False, help="")
parser.add_argument("--no-ia", action="store_true", default=False, help="")
parser.add_argument("--no-fx", action="store_true", default=False, help="")
parser.add_argument("--no-fgx", action="store_true", default=False, help="")
parser.add_argument("--no-m", action="store_true", default=False, help="")
parser.add_argument("--no-reg", action="store_true", default=False, help="")
# Still takes an explicit value ("--is-ablation True"), but now parses it.
parser.add_argument("--is-ablation", type=_str2bool, default=False, help="")
parser.add_argument("--imagenet-seg-path", type=str, required=True)
args = parser.parse_args()

# Name under which the Saver files this experiment.
args.checkname = args.method + "_" + args.arc

alpha = 2

# Prefer the GPU when one is present, otherwise fall back to the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
# Define Saver
saver = Saver(args)

# Everything this script writes lives under <experiment_dir>/results.
saver.results_dir = os.path.join(saver.experiment_dir, "results")
args.exp_img_path = os.path.join(saver.results_dir, "explain/img")
args.exp_np_path = os.path.join(saver.results_dir, "explain/np")

# exist_ok=True replaces the exists()-then-makedirs() pairs, which could fail
# if another process created the directory between the check and the call.
for _out_dir in (
    saver.results_dir,
    os.path.join(saver.results_dir, "input"),
    os.path.join(saver.results_dir, "explain"),
    args.exp_img_path,
    args.exp_np_path,
):
    os.makedirs(_out_dir, exist_ok=True)
# Data
_INPUT_SIZE = (224, 224)

# Image pipeline: resize, convert to a tensor, then normalise every channel
# with mean 0.5 and std 0.5.
normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
test_img_trans = transforms.Compose(
    [transforms.Resize(_INPUT_SIZE), transforms.ToTensor(), normalize]
)

# Label pipeline: resize only, using nearest-neighbour resampling.
test_lbl_trans = transforms.Compose([transforms.Resize(_INPUT_SIZE, Image.NEAREST)])

ds = Imagenet_Segmentation(
    args.imagenet_seg_path,
    transform=test_img_trans,
    target_transform=test_lbl_trans,
)

# Samples are visited in dataset order and none are dropped. The worker count
# is fixed at 1 here rather than taken from the `num_workers` hyperparameter.
dl = DataLoader(
    ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=1,
    drop_last=False,
)
# Model
# Every network is moved to `device` (chosen above from CUDA availability)
# instead of an unconditional .cuda(), which raises on a CPU-only host.
model = vit_base_patch16_224(pretrained=True).to(device)
baselines = Baselines(model)

# LRP
model_LRP = vit_LRP(pretrained=True).to(device)
model_LRP.eval()
lrp = LRP(model_LRP)

# orig LRP
model_orig_LRP = vit_orig_LRP(pretrained=True).to(device)
model_orig_LRP.eval()
orig_lrp = LRP(model_orig_LRP)

# NOTE: `metric` is not used by the evaluation loop visible in this script.
metric = IoU(2, ignore_index=-1)

# Progress bar over the data loader; the loop updates its description.
iterator = tqdm(dl)

model.eval()
def compute_pred(output, num_classes=1000):
    """Return a one-hot float tensor marking the top-scoring class of each row.

    Args:
        output: 2-D tensor of scores, one row per sample and one column per
            class (the maximum is taken along dimension 1).
        num_classes: width of the returned encoding. Defaults to 1000, the
            value that used to be hard-coded.

    Returns:
        A ``float32`` tensor with one row per row of ``output`` and
        ``num_classes`` columns, holding 1.0 at the predicted class index and
        0.0 elsewhere. A predicted index outside ``range(num_classes)`` gives
        an all-zero row. The result is on the CUDA device when one is
        available and on the CPU otherwise.
    """
    # Index of the highest score in each row, kept as a column vector so it
    # broadcasts against the row of class ids below for any batch size.
    pred = output.data.max(1, keepdim=True)[1]
    class_ids = torch.arange(num_classes, device=pred.device)
    one_hot = pred == class_ids

    # The original always called .cuda(); fall back to the CPU when no GPU is
    # present instead of raising.
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return one_hot.to(device=target_device, dtype=torch.float32)
def _save_input_images(image, labels, index):
    """Write the first image of the batch and its label mask as PNG files."""
    img = image[0].permute(1, 2, 0).data.cpu().numpy()
    # Stretch the values to the 0-255 range so the image is viewable.
    img = 255 * (img - img.min()) / (img.max() - img.min())
    img = img.astype("uint8")
    Image.fromarray(img, "RGB").save(
        os.path.join(saver.results_dir, "input/{}_input.png".format(index))
    )
    # Replicate the mask over three channels and scale it by 255.
    mask = labels.repeat(3, 1, 1).permute(1, 2, 0).data.cpu().numpy() * 255
    Image.fromarray(mask.astype("uint8"), "RGB").save(
        os.path.join(saver.results_dir, "input/{}_mask.png".format(index))
    )


def _generate_relevance(image):
    """Run the explanation method named by ``args.method`` on ``image``.

    Returns the relevance map reshaped to ``(batch_size, 1, 14, 14)``, or to
    ``(batch_size, 1, 224, 224)`` for ``"full_lrp"``.

    Raises:
        ValueError: if ``args.method`` has no implementation here.
    """
    method = args.method
    # segmentation test for the rollout baseline
    if method == "rollout":
        res = baselines.generate_rollout(image, start_layer=1)
        return res.reshape(batch_size, 1, 14, 14)
    # segmentation test for the LRP baseline (this is full LRP, not partial)
    if method == "full_lrp":
        res = orig_lrp.generate_LRP(image, method="full")
        return res.reshape(batch_size, 1, 224, 224)
    # segmentation test for our method
    if method == "transformer_attribution":
        res = lrp.generate_LRP(image, start_layer=1, method="transformer_attribution")
        return res.reshape(batch_size, 1, 14, 14)
    # segmentation test for the partial LRP baseline (last attn layer)
    if method == "lrp_last_layer":
        res = orig_lrp.generate_LRP(
            image, method="last_layer", is_ablation=args.is_ablation
        )
        return res.reshape(batch_size, 1, 14, 14)
    # segmentation test for the raw attention baseline (last attn layer)
    if method == "attn_last_layer":
        res = orig_lrp.generate_LRP(
            image, method="last_layer_attn", is_ablation=args.is_ablation
        )
        return res.reshape(batch_size, 1, 14, 14)
    # segmentation test for the GradCam baseline (last attn layer)
    if method == "attn_gradcam":
        return baselines.generate_cam_attn(image).reshape(batch_size, 1, 14, 14)
    # Previously an unhandled method surfaced later as an UnboundLocalError.
    raise ValueError(
        "Unsupported --method %r for segmentation evaluation" % (method,)
    )


def _save_explanations(mask_fg, relevance_map, index):
    """Write the thresholded mask and the relevance heatmap as 64x64 JPEGs."""
    # Save predicted mask
    mask = F.interpolate(mask_fg, [64, 64], mode="bilinear")
    mask = mask[0].squeeze().data.cpu().numpy()
    mask = (255 * mask).astype("uint8")
    imageio.imsave(
        os.path.join(args.exp_img_path, "mask_" + str(index) + ".jpg"), mask
    )

    relevance = F.interpolate(relevance_map, [64, 64], mode="bilinear")
    relevance = relevance[0].permute(1, 2, 0).data.cpu().numpy()
    hm = np.sum(relevance, axis=-1)
    maps = (render.hm_to_rgb(hm, scaling=3, sigma=1, cmap="seismic") * 255).astype(
        np.uint8
    )
    imageio.imsave(
        os.path.join(args.exp_img_path, "heatmap_" + str(index) + ".jpg"), maps
    )


def eval_batch(image, labels, evaluator, index):
    """Explain one batch and score the explanation as a segmentation.

    The relevance map produced by ``args.method`` is min-max normalised and
    split at its mean into a foreground and a background mask. Those masks are
    compared against ``labels``.

    Args:
        image: input batch handed to the model and the explanation generator.
        labels: ground-truth masks for the batch.
        evaluator: model whose gradients are cleared and which is run once on
            ``image``.
        index: batch index, used only in the names of saved files.

    Returns:
        A tuple ``(correct, labeled, inter, union, ap, f1, pred, target)``:
        the results of ``batch_pix_accuracy`` and ``batch_intersection_union``,
        the NaN-cleaned ``get_ap_scores`` and ``get_f1_scores`` values, the
        flattened relevance scores, and the flattened labels (the last two as
        NumPy arrays).

    Raises:
        ValueError: if ``args.method`` is not implemented.
    """
    evaluator.zero_grad()

    # Save input image
    if args.save_img:
        _save_input_images(image, labels, index)

    image = image.requires_grad_()
    # The forward output is not consumed here; the explanation generators
    # below are given the image directly.
    evaluator(image)

    Res = _generate_relevance(image.to(device))
    if args.method != "full_lrp":
        # interpolate to full image size (224,224)
        Res = F.interpolate(Res, scale_factor=16, mode="bilinear").to(device)

    # threshold between FG and BG is the mean
    Res = (Res - Res.min()) / (Res.max() - Res.min())
    ret = Res.mean()
    # The masks are built BEFORE NaNs are cleared: a NaN map (e.g. a constant
    # map, where max == min) compares False both ways and gives empty masks.
    Res_1 = Res.gt(ret).type(Res.type())
    Res_0 = Res.le(ret).type(Res.type())
    Res_0_AP = 1 - Res

    # Replace NaNs (x != x) with zeros. Res is cleaned in place because it is
    # also the foreground score map used for AP.
    Res_1[Res_1 != Res_1] = 0
    Res_0[Res_0 != Res_0] = 0
    Res[Res != Res] = 0
    Res_0_AP[Res_0_AP != Res_0_AP] = 0
    Res_1_AP = Res

    # Scores for the precision/recall curve. If the cleaned map is all zeros,
    # dividing by its maximum gives NaN/inf, which precision_recall_curve
    # rejects, so emit zero scores in that case.
    res_max = Res.max()
    if res_max > 0:
        pred = Res.clamp(min=args.thr) / res_max
    else:
        pred = torch.zeros_like(Res)
    pred = pred.view(-1).data.cpu().numpy()
    target = labels.view(-1).data.cpu().numpy()

    # Channel 0 is background, channel 1 is foreground.
    output = torch.cat((Res_0, Res_1), 1)
    output_AP = torch.cat((Res_0_AP, Res_1_AP), 1)

    if args.save_img:
        _save_explanations(Res_1, Res, index)

    # Evaluate Segmentation
    correct, labeled = batch_pix_accuracy(output[0].data.cpu(), labels[0])
    inter, union = batch_intersection_union(output[0].data.cpu(), labels[0], 2)
    ap = np.nan_to_num(get_ap_scores(output_AP, labels))
    f1 = np.nan_to_num(get_f1_scores(output[0, 1].data.cpu(), labels[0]))

    return correct, labeled, inter, union, ap, f1, pred, target
# Running totals over the whole dataset.
total_inter = np.int64(0)
total_union = np.int64(0)
total_correct = np.int64(0)
total_label = np.int64(0)
total_ap, total_f1 = [], []
predictions, targets = [], []

for batch_idx, (image, labels) in enumerate(iterator):
    # Move the batch to the device chosen at start-up. The former "blur"
    # branch is gone: no accepted --method value can select it.
    images = image.to(device)
    labels = labels.to(device)

    correct, labeled, inter, union, ap, f1, pred, target = eval_batch(
        images, labels, model, batch_idx
    )

    predictions.append(pred)
    targets.append(target)

    total_correct += correct.astype("int64")
    total_label += labeled.astype("int64")
    total_inter += inter.astype("int64")
    total_union += union.astype("int64")
    total_ap.append(ap)
    total_f1.append(f1)

    # np.spacing(1) keeps the denominators non-zero.
    pixAcc = (
        np.float64(1.0)
        * total_correct
        / (np.spacing(1, dtype=np.float64) + total_label)
    )
    # Named so that it does not shadow the imported IoU class.
    class_iou = (
        np.float64(1.0) * total_inter / (np.spacing(1, dtype=np.float64) + total_union)
    )
    mIoU = class_iou.mean()
    mAp = np.mean(total_ap)
    mF1 = np.mean(total_f1)
    iterator.set_description(
        "pixAcc: %.4f, mIoU: %.4f, mAP: %.4f, mF1: %.4f" % (pixAcc, mIoU, mAp, mF1)
    )

# Without this guard an empty dataset surfaces as a NameError or an opaque
# concatenate error further down.
if not predictions:
    raise SystemExit(
        "No samples were evaluated; check --imagenet-seg-path (%r)"
        % (args.imagenet_seg_path,)
    )

predictions = np.concatenate(predictions)
targets = np.concatenate(targets)

# Precision/recall over all pixels of all evaluated images.
pr, rc, _ = precision_recall_curve(targets, predictions)
np.save(os.path.join(saver.experiment_dir, "precision.npy"), pr)
np.save(os.path.join(saver.experiment_dir, "recall.npy"), rc)

fig = plt.figure()
plt.plot(rc, pr)
plt.savefig(os.path.join(saver.experiment_dir, "PR_curve_{}.png".format(args.method)))
plt.close(fig)

# Build the report once so the console and the file cannot diverge.
report_lines = [
    "Mean IoU over %d classes: %.4f\n" % (2, mIoU),
    "Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100),
    "Mean AP over %d classes: %.4f\n" % (2, mAp),
    "Mean F1 over %d classes: %.4f\n" % (2, mF1),
]
for line in report_lines:
    print(line)

txtfile = os.path.join(saver.experiment_dir, "result_mIoU_%.4f.txt" % mIoU)
# The context manager closes the file even if a write fails.
with open(txtfile, "w") as fh:
    fh.writelines(report_lines)