=
add attention
3bb44c5
from fake_face_detection.data.fake_face_dataset import FakeFaceDetectionDataset
from fake_face_detection.metrics.compute_metrics import compute_metrics
from fake_face_detection.utils.smoothest_attention import smooth_attention
from torch.utils.tensorboard import SummaryWriter
from PIL.JpegImagePlugin import JpegImageFile
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torchvision import transforms
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
from typing import *
import pandas as pd
from math import *
import numpy as np
import torch
import os
def upsample_patch_attention(attention: torch.Tensor, size: tuple, patch_size: tuple, head: int = 1) -> torch.Tensor:
    """Turn one head's patch attention row into a pixel-level ``(H, W, 1)`` map.

    Args:
        attention (torch.Tensor): Attention weights of a single sample, indexed as
            ``attention[head, query, key]``.
        size (tuple): Target ``(height, width)`` of the map.
        patch_size (tuple): Shape ``(rows, cols)`` of the patch grid. Its product must
            equal the number of key positions left after dropping the first one.
        head (int, optional): 1-based head number. Defaults to 1.

    Returns:
        torch.Tensor: Tensor of shape ``(size[0], size[1], 1)``. Each patch's weight
        is replicated (nearest neighbour) over the pixel tile that the patch covers.
    """
    # Attention of the last query token over every key except the first one
    # (presumably the [CLS] token).
    # NOTE(review): ViT visualisations usually read the [CLS] query row (index 0)
    # instead of the last token's row -- confirm this choice is intended.
    row = attention[head - 1, -1, 1:]
    # Lay the patch weights out on their 2-D grid (row-major) before upsampling, so
    # every patch fills its own spatial tile rather than a run of flattened pixels.
    grid = row.reshape(1, 1, patch_size[0], patch_size[1])
    upsampled = F.interpolate(grid, size=tuple(size), mode='nearest')
    return upsampled.reshape(size[0], size[1], 1)


def get_attention(image: Union[str, "JpegImageFile"], attention: torch.Tensor, size: tuple, patch_size: tuple, scale: int = 50, head: int = 1, smooth_iter: int = 2, smooth_thres: float = 0.01, smooth_scale: float = 0.2, smooth_size = 5):
    """Overlay one attention head's map on an image.

    Args:
        image (Union[str, JpegImageFile]): Path to an image file, or an opened PIL image.
        attention (torch.Tensor): Attention weights of a single sample, indexed as
            ``attention[head, query, key]``.
        size (tuple): ``(height, width)`` to resize the image and the attention map to.
        patch_size (tuple): Shape ``(rows, cols)`` of the patch grid. Its product must
            equal the number of patch tokens.
        scale (int, optional): Multiplier applied to the overlay before clipping. Defaults to 50.
        head (int, optional): 1-based head number. Defaults to 1.
        smooth_iter (int, optional): Forwarded to ``smooth_attention``. Defaults to 2.
        smooth_thres (float, optional): Forwarded to ``smooth_attention``. Defaults to 0.01.
        smooth_scale (float, optional): Forwarded to ``smooth_attention``. Defaults to 0.2.
        smooth_size (optional): Forwarded to ``smooth_attention``. Defaults to 5.

    Returns:
        np.ndarray: ``image / 255 * attention * scale`` clipped to ``[0, 1]``.
    """
    # Load (if a path was given) and resize the image, then convert it to a numpy array.
    if isinstance(image, str):
        with Image.open(image) as img:
            img = np.array(transforms.Resize(size)(img))
    else:
        img = np.array(transforms.Resize(size)(image))

    # Pixel-level (H, W, 1) attention map for the requested head.
    attention = upsample_patch_attention(attention, size, patch_size, head)

    # Smooth the map with the project helper (behaviour defined in smooth_attention).
    attention = smooth_attention(attention, smooth_iter, smooth_thres, smooth_scale, smooth_size)

    # The trailing singleton axis of the map broadcasts over the image channels.
    attention_image = img / 255 * attention.numpy() * scale
    return np.clip(attention_image, 0, 1)
def _collect_predictions(model, test_dataset, batch_size: int) -> Tuple[np.ndarray, torch.Tensor, float]:
    """Run ``model`` over ``test_dataset`` and gather its outputs.

    Returns:
        Tuple[np.ndarray, torch.Tensor, float]: The softmax probabilities of every
        sample (concatenated over batches), the last encoder layer's attentions
        (concatenated along dim 0) and the sample-weighted mean loss.
    """
    # shuffle defaults to False, so outputs stay aligned with test_dataset.images / .labels
    dataloader = DataLoader(test_dataset, batch_size=batch_size)
    probabilities, attentions = [], []
    total_loss, n_samples = 0.0, 0
    for batch in dataloader:
        # NOTE(review): indexing [0] presumably unwraps a one-element list that the
        # dataset wraps pixel_values in, leaving the collated batch tensor -- confirm
        # against FakeFaceDetectionDataset.
        pixel_values = batch['pixel_values'][0]
        batch_labels = batch['labels']
        outputs = model(pixel_values, labels=batch_labels, output_attentions=True)
        # .cpu() is a no-op for CPU tensors and keeps .numpy() working on any device
        probabilities.append(torch.softmax(outputs.logits.detach(), dim=-1).cpu().numpy())
        attentions.append(outputs.attentions[-1].detach().cpu())
        # Weight each batch by its size so a smaller final batch does not skew the mean
        # (assumes outputs.loss is a per-batch mean -- the usual reduction).
        n_batch = len(batch_labels)
        total_loss += outputs.loss.detach().item() * n_batch
        n_samples += n_batch
    return (
        np.concatenate(probabilities, axis=0),
        torch.cat(attentions, dim=0),
        total_loss / n_samples,
    )


def _plot_attentions(images, attentions, size, patch_size, figsize, attention_scale,
                     head, smooth_iter, smooth_thres, smooth_scale, smooth_size):
    """Draw every image's attention overlay on a square grid and return the figure."""
    n_images = len(images)
    nrows = ceil(sqrt(n_images))
    # squeeze=False always yields a 2-D array of Axes; plt.subplots(1, 1) would
    # otherwise return a bare Axes (no `.flat`) when there is a single image.
    fig, axes = plt.subplots(nrows=nrows, ncols=nrows, figsize=figsize, squeeze=False)
    axes = axes.ravel()
    for i, image in enumerate(images):
        attention_image = get_attention(image, attentions[i], size, patch_size, attention_scale,
                                        head, smooth_iter, smooth_thres, smooth_scale, smooth_size)
        axes[i].imshow(attention_image)
        axes[i].set_title(f'Image {i + 1}')
        axes[i].axis('off')
    fig.tight_layout()
    # Remove the unused cells of the square grid.
    for ax in axes[n_images:]:
        fig.delaxes(ax)
    return fig


def make_predictions(test_dataset: FakeFaceDetectionDataset,
                     model,
                     log_dir: str = "fake_face_logs",
                     tag: str = "Attentions",
                     batch_size: int = 3,
                     size: tuple = (224, 224),
                     patch_size: tuple = (14, 14),
                     figsize: tuple = (24, 24),
                     attention_scale: int = 50,
                     show: bool = True,
                     head: int = 1,
                     smooth_iter: int = 2,
                     smooth_thres: float = 0.01,
                     smooth_scale: float = 0.2,
                     smooth_size = 5):
    """Make predictions with a vision transformer model and log its attention maps.

    Args:
        test_dataset (FakeFaceDetectionDataset): The test dataset (``.images`` and ``.labels`` are read).
        model (_type_): The model; called with ``output_attentions=True`` and expected to
            return ``logits``, ``loss`` and ``attentions``.
        log_dir (str, optional): The log directory. The figure is written to
            ``<log_dir>/attentions``. Defaults to "fake_face_logs".
        tag (str, optional): The TensorBoard tag of the figure. Defaults to "Attentions".
        batch_size (int, optional): The batch size. Defaults to 3.
        size (tuple, optional): The size of the attention image. Defaults to (224, 224).
        patch_size (tuple, optional): The shape of the patch grid. Defaults to (14, 14).
        figsize (tuple, optional): The figure size. Defaults to (24, 24).
        attention_scale (int, optional): The attention scale. Defaults to 50.
        show (bool, optional): Whether to also return the figure. Defaults to True.
        head (int, optional): The 1-based head number. Defaults to 1.
        smooth_iter (int, optional): The number of iterations for the attention smoothing. Defaults to 2.
        smooth_thres (float, optional): The threshold for the attention smoothing. Defaults to 0.01.
        smooth_scale (float, optional): The scale for the attention smoothing. Defaults to 0.2.
        smooth_size ([type], optional): The size for the attention smoothing. Defaults to 5.

    Returns:
        Union[Tuple[pd.DataFrame, dict], Tuple[pd.DataFrame, dict, plt.Figure]]: A data frame
        with ``true_labels`` and ``predicted_labels`` columns, the metrics (including the
        mean ``loss``) and, when ``show`` is True, the attention figure.
    """
    with torch.no_grad():
        model.eval()
        images = test_dataset.images
        labels = test_dataset.labels

        probabilities, attentions, mean_loss = _collect_predictions(model, test_dataset, batch_size)
        predicted_labels = np.argmax(probabilities, axis=-1).tolist()

        metrics = compute_metrics((probabilities, np.array(labels)))
        metrics['loss'] = mean_loss

        fig = _plot_attentions(images, attentions, size, patch_size, figsize, attention_scale,
                               head, smooth_iter, smooth_thres, smooth_scale, smooth_size)

    # Always close the writer so the logged figure is flushed to disk.
    writer = SummaryWriter(os.path.join(log_dir, "attentions"))
    try:
        writer.add_figure(tag, fig)
    finally:
        writer.close()

    results = pd.DataFrame({'true_labels': labels, 'predicted_labels': predicted_labels})
    if show:
        return results, metrics, fig
    return results, metrics