|
from transformers import ViTFeatureExtractor, ViTForImageClassification |
|
import warnings |
|
from torchvision import transforms |
|
from datasets import load_dataset |
|
from pytorch_grad_cam import run_dff_on_image, GradCAM |
|
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget |
|
from pytorch_grad_cam.utils.image import show_cam_on_image |
|
from PIL import Image |
|
import numpy as np |
|
import cv2 as cv |
|
import torch |
|
from typing import List, Callable, Optional |
|
import logging |
|
from face_grab import FaceGrabber |
|
|
|
|
|
|
|
|
|
|
|
# Route log records at INFO and above to the root handler for this script.
logging.basicConfig(level=logging.INFO)

# Suppress every Python warning, of any category, for the whole process.
warnings.simplefilter("ignore")
|
|
|
|
|
|
|
class HuggingfaceToTensorModelWrapper(torch.nn.Module): |
|
def __init__(self, model): |
|
super(HuggingfaceToTensorModelWrapper, self).__init__() |
|
self.model = model |
|
|
|
def forward(self, x): |
|
return self.model(x).logits |
|
|
|
|
|
|
|
class GradCam:
    """Helpers for explaining a Hugging Face ViT classifier with pytorch_grad_cam.

    The class keeps no state: every method receives the model it works on.
    """

    def category_name_to_index(self, model, category_name):
        """Return the class index whose label in ``model.config.id2label`` is ``category_name``.

        Raises:
            KeyError: if no class carries that label; the message lists the
                labels that do exist.
        """
        # Invert id -> label; with duplicate labels the last index wins.
        name_to_index = {label: index for index, label in model.config.id2label.items()}
        try:
            return name_to_index[category_name]
        except KeyError:
            raise KeyError(
                f"Unknown category {category_name!r}; "
                f"available labels: {list(name_to_index)}"
            ) from None

    def run_grad_cam_on_image(self, model: torch.nn.Module,
                              target_layer: torch.nn.Module,
                              targets_for_gradcam: List[Callable],
                              reshape_transform: Optional[Callable],
                              input_tensor: torch.Tensor,
                              input_image: Image.Image,
                              method: Callable = GradCAM,
                              threshold: float = 0.5) -> np.ndarray:
        """Render one CAM overlay per target and return them side by side.

        Args:
            model: Hugging Face classifier. It is wrapped so the CAM method
                receives raw logits.
            target_layer: Layer whose activations/gradients the CAM method uses.
            targets_for_gradcam: One CAM target per overlay
                (e.g. ``ClassifierOutputTarget``).
            reshape_transform: Maps the layer's activations to a spatial map
                (e.g. ``reshape_transform_vit_huggingface``), or None.
            input_tensor: A single image tensor without a batch dimension. It
                is repeated once per target.
            input_image: The same image as an RGB PIL image, used as the
                overlay background.
            method: CAM class from pytorch_grad_cam (``GradCAM`` by default).
            threshold: CAM values below this are zeroed before overlaying.

        Returns:
            RGB overlays (``use_rgb=True``), each downscaled by 2, concatenated
            horizontally in target order.

        Raises:
            ValueError: if ``targets_for_gradcam`` is empty.
        """
        if not targets_for_gradcam:
            raise ValueError("targets_for_gradcam must contain at least one target")

        # The overlay background is the same for every target: convert it once.
        background = np.float32(input_image) / 255

        with method(model=HuggingfaceToTensorModelWrapper(model),
                    target_layers=[target_layer],
                    reshape_transform=reshape_transform) as cam:
            # One copy of the image per target, so every CAM comes from a single batched call.
            repeated_tensor = input_tensor[None, :].repeat(len(targets_for_gradcam), 1, 1, 1)
            batch_results = cam(input_tensor=repeated_tensor, targets=targets_for_gradcam)

        results = []
        for grayscale_cam in batch_results:
            grayscale_cam[grayscale_cam < threshold] = 0
            visualization = show_cam_on_image(background, grayscale_cam, use_rgb=True)
            visualization = cv.resize(
                visualization,
                (visualization.shape[1] // 2, visualization.shape[0] // 2),
            )
            results.append(visualization)
        return np.hstack(results)

    def get_top_category(self, model, img_tensor, top_k=5):
        """Return the model's ``top_k`` most probable classes for one image, best first.

        Args:
            model: Hugging Face classifier exposing ``.logits`` and
                ``config.id2label``.
            img_tensor: A single image tensor without a batch dimension.
            top_k: Number of predictions to return. It is clamped to the range
                ``[1, num_classes]``, so at least the top-1 entry is always
                returned.

        Returns:
            A list of ``{"label": str, "score": float}`` dicts. ``score`` is
            the softmax probability. The list is sorted by decreasing score.
        """
        # Inference only: no autograd graph is needed to read the scores.
        with torch.no_grad():
            logits = model(img_tensor.unsqueeze(0)).logits
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

        k = max(1, min(int(top_k), probabilities.shape[0]))
        scores, indices = torch.topk(probabilities, k)
        return [
            {"label": model.config.id2label[index], "score": score}
            for index, score in zip(indices.tolist(), scores.tolist())
        ]

    def reshape_transform_vit_huggingface(self, x):
        """Turn ViT token activations into a channels-first spatial map.

        Args:
            x: Activations of shape (batch, 1 + num_patches, hidden). The first
               token (the [CLS] token in HF ViT) is discarded.

        Returns:
            A tensor of shape (batch, hidden, side, side), where
            ``side * side == num_patches`` (14 for 196 patches).

        Raises:
            ValueError: if the number of patch tokens is not a perfect square.
        """
        import math

        activations = x[:, 1:, :]
        batch, num_patches, hidden = activations.shape
        # Infer the patch-grid side instead of hard-coding 14, so other
        # resolutions / patch sizes with a square grid also work.
        side = math.isqrt(num_patches)
        if side * side != num_patches:
            raise ValueError(
                f"Expected a square patch grid, got {num_patches} patch tokens"
            )
        activations = activations.reshape(batch, side, side, hidden)
        # (B, H, W, C) -> (B, C, H, W)
        return activations.permute(0, 3, 1, 2)
|
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Optional first CLI argument: path of the image to explain.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "Feature-Image-74.jpg"

    face_grabber = FaceGrabber()
    grad_cam = GradCam()

    image = Image.open(image_path).convert("RGB")
    # Use the detected face crop when one is found; otherwise keep the whole image.
    # NOTE(review): assumes grab_faces returns an RGB uint8 array or None — confirm in face_grab.
    face = face_grabber.grab_faces(np.array(image))
    if face is not None:
        image = Image.fromarray(face)

    model = ViTForImageClassification.from_pretrained("ongkn/attraction-classifier")

    targets_for_gradcam = [
        ClassifierOutputTarget(grad_cam.category_name_to_index(model, "pos")),
        ClassifierOutputTarget(grad_cam.category_name_to_index(model, "neg")),
    ]
    target_layer_dff = model.vit.layernorm
    target_layer_gradcam = model.vit.encoder.layer[-2].output

    image_resized = image.resize((224, 224))
    # NOTE(review): ToTensor() only scales pixels to [0, 1]. No ViT mean/std
    # normalization is applied (ViTFeatureExtractor is imported but unused).
    # Confirm this matches the classifier's training preprocessing.
    tensor_resized = transforms.ToTensor()(image_resized)

    dff_image = run_dff_on_image(model=model,
                                 target_layer=target_layer_dff,
                                 classifier=model.classifier,
                                 img_pil=image_resized,
                                 img_tensor=tensor_resized,
                                 reshape_transform=grad_cam.reshape_transform_vit_huggingface,
                                 n_components=5,
                                 top_k=10,
                                 threshold=0,
                                 output_size=None)
    cv.namedWindow("DFF Image", cv.WINDOW_KEEPRATIO)
    # The DFF result is built from an RGB PIL image; OpenCV windows expect BGR.
    cv.imshow("DFF Image", cv.cvtColor(dff_image, cv.COLOR_RGB2BGR))
    cv.resizeWindow("DFF Image", 2500, 700)

    grad_cam_image = grad_cam.run_grad_cam_on_image(model=model,
                                                    target_layer=target_layer_gradcam,
                                                    targets_for_gradcam=targets_for_gradcam,
                                                    input_tensor=tensor_resized,
                                                    input_image=image_resized,
                                                    reshape_transform=grad_cam.reshape_transform_vit_huggingface,
                                                    threshold=0)
    cv.namedWindow("Grad-CAM Image", cv.WINDOW_KEEPRATIO)
    # The overlays are rendered with use_rgb=True, so convert to BGR for display;
    # otherwise red and blue appear swapped.
    cv.imshow("Grad-CAM Image", cv.cvtColor(grad_cam_image, cv.COLOR_RGB2BGR))
    cv.resizeWindow("Grad-CAM Image", 2000, 1250)

    cv.waitKey(0)
    cv.destroyAllWindows()

    # Report the classifier's prediction(s) for the same resized input.
    for prediction in grad_cam.get_top_category(model, tensor_resized):
        logging.info("%s: %.4f", prediction["label"], prediction["score"])