Spaces:
Build error
Build error
from fake_face_detection.data.fake_face_dataset import FakeFaceDetectionDataset | |
from fake_face_detection.metrics.compute_metrics import compute_metrics | |
from fake_face_detection.utils.smoothest_attention import smooth_attention | |
from torch.utils.tensorboard import SummaryWriter | |
from PIL.JpegImagePlugin import JpegImageFile | |
from torch.utils.data import DataLoader | |
from torch.nn import functional as F | |
from torchvision import transforms | |
import matplotlib.pyplot as plt | |
from glob import glob | |
from PIL import Image | |
from typing import * | |
import pandas as pd | |
from math import * | |
import numpy as np | |
import torch | |
import os | |
def get_attention(image: Union[str, JpegImageFile], attention: torch.Tensor, size: tuple, patch_size: tuple, scale: int = 50, head: int = 1, smooth_iter: int = 2, smooth_thres: float = 0.01, smooth_scale: float = 0.2, smooth_size = 5):
    """Overlay the attention of one head on a resized image.

    Args:
        image (Union[str, JpegImageFile]): A path to an image or an already opened image.
        attention (torch.Tensor): A 3-axis attention tensor. The first axis is indexed by
            `head`; the last row of the second axis and columns 1.. of the third axis are
            used (presumably [head, query token, key token] with a leading class token --
            TODO confirm against the model).
        size (tuple): The (height, width) the image and the attention map are resized to.
        patch_size (tuple): The (rows, columns) of the patch grid; rows * columns must equal
            the number of attention values kept for the selected head.
        scale (int, optional): The multiplier applied to the weighted image. Defaults to 50.
        head (int, optional): The 1-based head number. Defaults to 1.
        smooth_iter (int, optional): Forwarded to `smooth_attention`. Defaults to 2.
        smooth_thres (float, optional): Forwarded to `smooth_attention`. Defaults to 0.01.
        smooth_scale (float, optional): Forwarded to `smooth_attention`. Defaults to 0.2.
        smooth_size (optional): Forwarded to `smooth_attention`. Defaults to 5.

    Returns:
        np.ndarray: The image scaled to [0, 1], weighted by the attention map and clipped to [0, 1].
    """
    # recuperate the resized image as a numpy array (a path is opened and closed here)
    resize = transforms.Resize(size)
    if isinstance(image, str):
        with Image.open(image) as opened:
            img = np.array(resize(opened))
    else:
        img = np.array(resize(image))

    # keep the attention row of the last token and drop the first column
    # (the extra token that is not an image patch)
    attention = attention[:, -1, 1:]

    # select the requested head (1-based)
    attention = attention[head - 1]

    # Lay the patch values out on their 2-D grid BEFORE upsampling. Interpolating the
    # flattened sequence in 1-D and reshaping afterwards turns every patch into a run of
    # consecutive raster pixels instead of a 2-D tile, which misplaces the attention.
    grid = attention.reshape(1, 1, patch_size[0], patch_size[1])
    attention = F.interpolate(grid, size=(size[0], size[1]), mode='nearest')

    # trailing axis of length 1 so the map broadcasts over the image's last axis
    attention = attention.reshape(size[0], size[1], 1)

    # add the smoothest attention
    attention = smooth_attention(attention, smooth_iter, smooth_thres, smooth_scale, smooth_size)

    # weight the normalized image by the attention and amplify it
    attention_image = img / 255 * attention.numpy() * scale

    return np.clip(attention_image, 0, 1)
def make_predictions(test_dataset: FakeFaceDetectionDataset,
                     model,
                     log_dir: str = "fake_face_logs",
                     tag: str = "Attentions",
                     batch_size: int = 3,
                     size: tuple = (224, 224),
                     patch_size: tuple = (14, 14),
                     figsize: tuple = (24, 24),
                     attention_scale: int = 50,
                     show: bool = True,
                     head: int = 1,
                     smooth_iter: int = 2,
                     smooth_thres: float = 0.01,
                     smooth_scale: float = 0.2,
                     smooth_size = 5):
    """Make predictions with a vision transformer model and log the attention maps

    The model is called as `model(pixel_values, labels=..., output_attentions=True)` and its
    output must expose `logits`, `attentions` and `loss`. One attention image per dataset
    image is drawn on a square grid of subplots and written to TensorBoard under `tag`.

    Args:
        test_dataset (FakeFaceDetectionDataset): The test dataset (its `images` and `labels` attributes are read)
        model (_type_): The model
        log_dir (str, optional): The log directory. Defaults to "fake_face_logs".
        tag (str, optional): The tag. Defaults to "Attentions".
        batch_size (int, optional): The batch size. Defaults to 3.
        size (tuple, optional): The size of the attention image. Defaults to (224, 224).
        patch_size (tuple, optional): The patch size. Defaults to (14, 14).
        figsize (tuple, optional): The figure size. Defaults to (24, 24).
        attention_scale (int, optional): The attention scale. Defaults to 50.
        show (bool, optional): A boolean value indicating if we want to recuperate the figure. Defaults to True.
        head (int, optional): The head number. Defaults to 1.
        smooth_iter (int, optional): The number of iterations for the smoothest attention. Defaults to 2.
        smooth_thres (float, optional): The threshold for the smoothest attention. Defaults to 0.01.
        smooth_scale (float, optional): The scale for the smoothest attention. Defaults to 0.2.
        smooth_size ([type], optional): The size for the smoothest attention. Defaults to 5.

    Returns:
        Union[Tuple[pd.DataFrame, dict], Tuple[pd.DataFrame, dict, figure]]: The predictions
        (true and predicted labels), the metrics (including the mean batch loss) and,
        when `show` is True, the figure
    """
    with torch.no_grad():

        _ = model.eval()

        # initialize the logger
        writer = SummaryWriter(os.path.join(log_dir, "attentions"))

        # the writer is closed in the `finally` clause so the event file is always flushed
        try:

            # let us recuperate the images and labels
            images = test_dataset.images
            labels = test_dataset.labels

            # let us initialize the predictions
            predictions = {'attentions': [], 'predictions': [], 'true_labels': labels, 'predicted_labels': []}

            # let us initialize the dataloader
            test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

            # get the loss
            loss = 0

            for data in test_dataloader:

                # recuperate the pixel values
                # NOTE(review): `[0]` presumably unwraps a one-element list produced by the
                # dataset / collate function -- confirm it does not drop batch items
                pixel_values = data['pixel_values'][0]

                # recuperate the labels
                labels_ = data['labels']

                # recuperate the outputs
                outputs = model(pixel_values, labels=labels_, output_attentions=True)

                # recuperate the predictions (moved to the CPU so `.numpy()` works on any device)
                probabilities = torch.softmax(outputs.logits.detach(), dim=-1)
                predictions['predictions'].append(probabilities.cpu().numpy())

                # recuperate the attentions of the last encoder layer
                predictions['attentions'].append(outputs.attentions[-1].detach().cpu())

                # add the loss
                loss += outputs.loss.detach().item()

            predictions['predictions'] = np.concatenate(predictions['predictions'], axis=0)

            # `torch.cat` is available in every torch release, unlike its `concatenate` alias
            predictions['attentions'] = torch.cat(predictions['attentions'], dim=0)

            predictions['predicted_labels'] = np.argmax(predictions['predictions'], axis=-1).tolist()

            # let us calculate the metrics
            metrics = compute_metrics((predictions['predictions'], np.array(predictions['true_labels'])))
            metrics['loss'] = loss / len(test_dataloader)

            # for each image we will visualize its attention on a square grid
            nrows = ceil(sqrt(len(images)))

            # `squeeze=False` always returns a 2-D array of axes, even for a 1 x 1 grid
            fig, axes = plt.subplots(nrows=nrows, ncols=nrows, figsize=figsize, squeeze=False)
            axes = axes.ravel()

            for i, image in enumerate(images):

                attention_image = get_attention(image, predictions['attentions'][i], size, patch_size, attention_scale, head, smooth_iter, smooth_thres, smooth_scale, smooth_size)

                axes[i].imshow(attention_image)
                axes[i].set_title(f'Image {i + 1}')
                axes[i].axis('off')

            fig.tight_layout()

            # remove the unused cells of the grid
            for unused_axis in axes[len(images):]:
                fig.delaxes(unused_axis)

            writer.add_figure(tag, fig)

        finally:

            writer.close()

        # let us remove the predictions and the attentions
        del predictions['predictions']
        del predictions['attentions']

        # return the figure if necessary
        if show:
            return pd.DataFrame(predictions), metrics, fig

        # let us recuperate the metrics and the predictions
        return pd.DataFrame(predictions), metrics