Spaces:
Build error
Build error
File size: 7,349 Bytes
783053f b39c220 783053f 3bb44c5 783053f b63fd37 783053f d57c931 3bb44c5 d57c931 783053f b63fd37 783053f b63fd37 783053f b63fd37 d57c931 3bb44c5 d57c931 b63fd37 d57c931 3bb44c5 d57c931 b63fd37 783053f b63fd37 783053f b63fd37 783053f b63fd37 783053f b63fd37 783053f b63fd37 783053f 3bb44c5 783053f b63fd37 783053f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
from fake_face_detection.data.fake_face_dataset import FakeFaceDetectionDataset
from fake_face_detection.metrics.compute_metrics import compute_metrics
from fake_face_detection.utils.smoothest_attention import smooth_attention
from torch.utils.tensorboard import SummaryWriter
from PIL.JpegImagePlugin import JpegImageFile
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torchvision import transforms
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
from typing import *
import pandas as pd
from math import *
import numpy as np
import torch
import os
def get_attention(image: Union[str, JpegImageFile], attention: torch.Tensor, size: tuple, patch_size: tuple, scale: int = 50, head: int = 1, smooth_iter: int = 2, smooth_thres: float = 0.01, smooth_scale: float = 0.2, smooth_size = 5):
    """Overlay the attention of one head on top of an image.

    Args:
        image (Union[str, JpegImageFile]): Path of the image, or an already opened PIL image.
        attention (torch.Tensor): Attention of one sample. It is sliced as ``attention[:, -1, 1:]``
            and then indexed by head, so the first axis is the head axis and the two others are
            the token axes.
        size (tuple): ``(height, width)`` of the returned attention image; the image is resized to it.
        patch_size (tuple): Shape ``(rows, columns)`` of the patch grid. ``rows * columns`` must be
            equal to the number of tokens that remain once the first token is dropped.
        scale (int, optional): Factor multiplying the attention before clipping. Defaults to 50.
        head (int, optional): 1-based number of the head to display (``head - 1`` is used as the
            index, so values below 1 wrap around like any negative index). Defaults to 1.
        smooth_iter (int, optional): Forwarded to `smooth_attention`. Defaults to 2.
        smooth_thres (float, optional): Forwarded to `smooth_attention`. Defaults to 0.01.
        smooth_scale (float, optional): Forwarded to `smooth_attention`. Defaults to 0.2.
        smooth_size (int, optional): Forwarded to `smooth_attention`. Defaults to 5.

    Returns:
        np.ndarray: The image weighted by the attention, with values clipped to [0, 1].
    """
    # recuperate the image as a numpy array resized to the requested size
    resize = transforms.Resize(size)

    if isinstance(image, str):

        with Image.open(image) as opened_image:

            img = np.array(resize(opened_image))

    else:

        img = np.array(resize(image))

    # recuperate the attention given by the last token to every token except the first one
    # (the first token is the extra one added in front of the patches, presumably the
    # classification token) and keep only the requested head
    attention = attention[:, -1, 1:][head - 1]

    # put the attention back on the 2D grid of patches before rescaling it: stretching the
    # flattened sequence would repeat each value along the rows of pixels instead of filling
    # the square covered by the patch, and the attention would not be aligned with the image
    attention = attention.reshape(1, 1, patch_size[0], patch_size[1])

    # rescale the attention with the nearest scaler; giving the target size (and not a float
    # scale factor) guarantees that the result has exactly size[0] x size[1] values
    attention = F.interpolate(attention, size=(int(size[0]), int(size[1])), mode='nearest')

    # let us reshape the attention to the right size (height, width, 1) so that it is
    # broadcast over the channels of the image
    attention = attention.reshape(size[0], size[1], 1)

    # add the smoothest attention
    attention = smooth_attention(attention, smooth_iter, smooth_thres, smooth_scale, smooth_size)

    # recuperate the result: the image scaled to [0, 1] and weighted by the amplified attention
    attention_image = img / 255 * attention.numpy() * scale

    return np.clip(attention_image, 0, 1)
def make_predictions(test_dataset: FakeFaceDetectionDataset,
                     model,
                     log_dir: str = "fake_face_logs",
                     tag: str = "Attentions",
                     batch_size: int = 3,
                     size: tuple = (224, 224),
                     patch_size: tuple = (14, 14),
                     figsize: tuple = (24, 24),
                     attention_scale: int = 50,
                     show: bool = True,
                     head: int = 1,
                     smooth_iter: int = 2,
                     smooth_thres: float = 0.01,
                     smooth_scale: float = 0.2,
                     smooth_size = 5):
    """Make predictions with a vision transformer model

    The model is put in evaluation mode and run on the whole dataset. The metrics are
    computed, and a grid showing the attention of each image is written to TensorBoard
    in the sub directory "attentions" of the log directory.

    Args:
        test_dataset (FakeFaceDetectionDataset): The test dataset
        model (_type_): The model
        log_dir (str, optional): The log directory. Defaults to "fake_face_logs".
        tag (str, optional): The tag. Defaults to "Attentions".
        batch_size (int, optional): The batch size. Defaults to 3.
        size (tuple, optional): The size of the attention image. Defaults to (224, 224).
        patch_size (tuple, optional): The patch size, forwarded to `get_attention`. Defaults to (14, 14).
        figsize (tuple, optional): The figure size. Defaults to (24, 24).
        attention_scale (int, optional): The attention scale. Defaults to 50.
        show (bool, optional): A boolean value indicating if we want to recuperate the figure. Defaults to True.
        head (int, optional): The head number. Defaults to 1.
        smooth_iter (int, optional): The number of iterations for the smoothest attention. Defaults to 2.
        smooth_thres (float, optional): The threshold for the smoothest attention. Defaults to 0.01.
        smooth_scale (float, optional): The scale for the smoothest attention. Defaults to 0.2.
        smooth_size ([type], optional): The size for the smoothest attention. Defaults to 5.

    Raises:
        ValueError: If the dataloader does not provide any batch.

    Returns:
        Union[Tuple[pd.DataFrame, dict], Tuple[pd.DataFrame, dict, figure]]: The true and predicted labels,
        the metrics (with the mean of the batch losses under the key 'loss') and, if `show` is True, the figure
    """
    # no gradient is needed anywhere in the prediction
    with torch.no_grad():

        _ = model.eval()

        # let us recuperate the images and labels
        images = test_dataset.images

        labels = test_dataset.labels

        # let us initialize the predictions
        predictions = {'attentions': [], 'predictions': [], 'true_labels': labels, 'predicted_labels': []}

        # let us initialize the dataloader
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

        # get the loss
        loss = 0

        for data in test_dataloader:

            # recuperate the pixel values
            pixel_values = data['pixel_values'][0]

            # recuperate the labels
            labels_ = data['labels']

            # recuperate the outputs
            outputs = model(pixel_values, labels = labels_, output_attentions = True)

            # recuperate the predictions
            predictions['predictions'].append(torch.softmax(outputs.logits.detach(), dim = -1).numpy())

            # recuperate the attentions of the last encoder layer
            predictions['attentions'].append(outputs.attentions[-1].detach())

            # add the loss
            loss += outputs.loss.detach().item()

        # nothing can be concatenated, averaged or drawn without at least one batch
        if not predictions['predictions']:

            raise ValueError("The test dataset is empty: no prediction can be made.")

        predictions['predictions'] = np.concatenate(predictions['predictions'], axis = 0)

        predictions['attentions'] = torch.cat(predictions['attentions'], dim = 0)

        predictions['predicted_labels'] = np.argmax(predictions['predictions'], axis = -1).tolist()

        # let us calculate the metrics
        metrics = compute_metrics((predictions['predictions'], np.array(predictions['true_labels'])))

        metrics['loss'] = loss / len(test_dataloader)

        # for each image we will visualize his attention on a square grid
        nrows = ceil(sqrt(len(images)))

        # squeeze = False always returns a 2D array of axes, even when the grid has a single cell
        fig, axes = plt.subplots(nrows=nrows, ncols=nrows, figsize = figsize, squeeze = False)

        axes = axes.ravel()

        for i, image in enumerate(images):

            attention_image = get_attention(image, predictions['attentions'][i], size, patch_size, attention_scale, head, smooth_iter, smooth_thres, smooth_scale, smooth_size)

            axis = axes[i]

            axis.imshow(attention_image)

            axis.set_title(f'Image {i + 1}')

            axis.axis('off')

        fig.tight_layout()

        # remove the cells of the grid that do not contain any image
        for unused_axis in axes[len(images):]:

            fig.delaxes(unused_axis)

        # log the figure and always close the writer so that the events are flushed
        writer = SummaryWriter(os.path.join(log_dir, "attentions"))

        try:

            writer.add_figure(tag, fig)

        finally:

            writer.close()

        # let us remove the predictions and the attentions (only the labels go in the data frame)
        del predictions['predictions']

        del predictions['attentions']

        # let us recuperate the metrics and the predictions
        results = pd.DataFrame(predictions)

        # add the figure if necessary
        if show:

            return results, metrics, fig

        return results, metrics
|