from fake_face_detection.data.fake_face_dataset import FakeFaceDetectionDataset
from fake_face_detection.metrics.compute_metrics import compute_metrics
from torch.utils.tensorboard import SummaryWriter
from PIL.JpegImagePlugin import JpegImageFile
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torchvision import transforms
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
from typing import *
import pandas as pd
from math import *
import numpy as np
import torch
import os


def get_attention(image: Union[str, JpegImageFile], attention: torch.Tensor, size: tuple, patch_size: tuple) -> np.ndarray:
    """Weight an image by the attention that the last token pays to each patch.

    Args:
        image: Path to an image file, or an already opened PIL image.
        attention: Attention weights of one image. The indexing below requires
            at least three dimensions; presumably (heads, tokens, tokens) with
            a leading extra token -- TODO confirm against the model.
        size: Target (height, width) the image is resized to.
        patch_size: Shape (rows, cols) of the patch grid; rows * cols must
            equal the number of attention values left after dropping the first
            token (e.g. (14, 14) for 196 patches).

    Returns:
        The resized image scaled to [0, 1] and multiplied, pixel by pixel, by
        the attention weight of the patch that covers the pixel.
    """
    resize = transforms.Resize(size)

    # recuperate the image as a numpy array
    if isinstance(image, str):
        with Image.open(image) as opened:
            img = np.array(resize(opened))
    else:
        img = np.array(resize(image))

    # keep the attention emitted by the last token towards every token but the
    # first one (presumably the extra [CLS] token), then average over axis 0
    attention = attention[:, -1, 1:].mean(dim=0)

    # Lay the weights out on the 2-D patch grid BEFORE upsampling: a 1-D
    # interpolation followed by a reshape would smear every patch over a run
    # of consecutive flattened pixels instead of over its own tile.
    grid = attention.reshape(1, 1, patch_size[0], patch_size[1])
    grid = F.interpolate(grid, size=(size[0], size[1]), mode='nearest')

    # add a trailing axis so the map broadcasts over the colour channels
    attention = grid.reshape(size[0], size[1], 1)

    # NOTE(review): the raw weights are not re-normalised, so the result can
    # look very dark -- confirm this is the intended rendering.
    return img / 255 * attention.numpy()


def make_predictions(test_dataset: FakeFaceDetectionDataset, model, log_dir: str = "fake_face_logs", tag: str = "Attentions", batch_size: int = 3, size: tuple = (224, 224), patch_size: tuple = (14, 14), figsize: tuple = (24, 24)) -> Tuple[pd.DataFrame, dict]:
    """Evaluate the model on a dataset and log its attention maps to TensorBoard.

    Args:
        test_dataset: Dataset exposing `images` and `labels` and yielding
            items with 'pixel_values' and 'labels' entries.
        model: Model called as `model(pixel_values, labels=..., output_attentions=True)`
            whose output exposes `logits`, `loss` and `attentions`.
        log_dir: Root directory of the logs; the figure is written under
            `<log_dir>/attentions`.
        tag: Tag under which the figure is stored.
        batch_size: Number of samples per evaluation batch.
        size: (height, width) used to render each image.
        patch_size: Shape of the patch grid handed to `get_attention`.
        figsize: Size of the matplotlib figure.

    Returns:
        A data frame with the 'true_labels' and 'predicted_labels' columns, and
        the metrics dictionary returned by `compute_metrics` extended with the
        mean batch loss under the 'loss' key.
    """
    # let us recuperate the images and labels
    images = test_dataset.images
    labels = test_dataset.labels

    # let us initialize the dataloader
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

    # Send the inputs where the model lives instead of forcing CUDA; fall back
    # to CUDA (the historical behaviour) when the device cannot be inferred.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")

    probabilities = []
    attentions = []
    loss = 0

    with torch.no_grad():
        _ = model.eval()

        for data in test_dataloader:
            # assumes the collated 'pixel_values' entry holds the batch tensor
            # as its first element -- TODO confirm against the dataset
            pixel_values = data['pixel_values'][0].to(device)
            labels_ = data['labels'].to(device)

            outputs = model(pixel_values, labels=labels_, output_attentions=True)

            # recuperate the class probabilities
            probabilities.append(torch.softmax(outputs.logits.detach().cpu(), dim=-1).numpy())

            # recuperate the attentions of the last encoder layer
            attentions.append(outputs.attentions[-1].detach().cpu())

            # add the loss
            loss += outputs.loss.detach().cpu().item()

    probabilities = np.concatenate(probabilities, axis=0)
    attentions = torch.cat(attentions, dim=0)
    predicted_labels = np.argmax(probabilities, axis=-1).tolist()

    # let us calculate the metrics
    metrics = compute_metrics((probabilities, np.array(labels)))
    metrics['loss'] = loss / len(test_dataloader)

    # for each image we will visualize its attention on a square grid;
    # squeeze=False keeps `axes` a 2-D array even when the grid is 1 x 1
    nrows = ceil(sqrt(len(images)))
    fig, axes = plt.subplots(nrows=nrows, ncols=nrows, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    for i, image in enumerate(images):
        attention_image = get_attention(image, attentions[i], size, patch_size)
        axes[i].imshow(attention_image)
        axes[i].set_title(f'Image {i + 1}')
        axes[i].axis('off')

    fig.tight_layout()

    # remove the cells of the grid that received no image
    for unused_axis in axes[len(images):]:
        fig.delaxes(unused_axis)

    # the context manager closes the writer so the event file is flushed
    with SummaryWriter(os.path.join(log_dir, "attentions")) as writer:
        writer.add_figure(tag, fig)

    # only the labels are returned; probabilities and attentions are dropped
    return pd.DataFrame({'true_labels': labels, 'predicted_labels': predicted_labels}), metrics