from fake_face_detection.data.fake_face_dataset import FakeFaceDetectionDataset
from fake_face_detection.metrics.compute_metrics import compute_metrics
from torch.utils.tensorboard import SummaryWriter
from PIL.JpegImagePlugin import JpegImageFile
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
from typing import Union
import pandas as pd
from math import ceil, sqrt
import numpy as np
import torch
import os
def get_attention(image: Union[str, JpegImageFile], attention: torch.Tensor, size: tuple, patch_size: tuple):
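    """Weight an image by the attention map produced by the last token.

    Args:
        image: a path to an image or an already loaded PIL image
        attention: an attention tensor of shape (heads, tokens, tokens)
        size: the (height, width) to which the image is resized
        patch_size: the (height, width) of the patch grid

    Returns:
        the image, scaled to [0, 1] and weighted by the upsampled attention map
    """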
    # load the image, resized to the target size, as a numpy array
    if isinstance(image, str):

        with Image.open(image) as img:

            img = np.array(transforms.Resize(size)(img))

    else:

        img = np.array(transforms.Resize(size)(image))

    # take the attention given to the patches by the last token (the first
    # column is dropped because it corresponds to the [CLS] token that the
    # model prepends to the patch sequence)
    attention = attention[:, -1, 1:]

    # average the attention over the heads
    attention = attention.mean(dim=0)

    # calculate the factor needed to upsample the patch grid to the image size
    scale_factor = size[0] * size[1] / (patch_size[0] * patch_size[1])

    # upsample the attention map with nearest-neighbor interpolation
    attention = F.interpolate(attention.reshape(1, 1, -1), scale_factor=scale_factor,
                              mode='nearest')

    # reshape the attention map to the size of the image
    attention = attention.reshape(size[0], size[1], 1)

    # weight the image, scaled to [0, 1], by the attention map
    attention_image = img / 255 * attention.numpy()

    return attention_image
def make_predictions(test_dataset: FakeFaceDetectionDataset,
                     model,
                     log_dir: str = "fake_face_logs",
                     tag: str = "Attentions",
                     batch_size: int = 3,
                     size: tuple = (224, 224),
                     patch_size: tuple = (14, 14),
                     figsize: tuple = (24, 24)):
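    """Evaluate the model on the test dataset and log the attention maps.

    Runs the model on every batch, collects the softmax probabilities and the
    attentions of the last encoder layer, logs a grid of attention-weighted
    images to TensorBoard, and computes the metrics with the mean loss.

    Returns:
        a tuple (dataframe of true and predicted labels, dict of metrics)
    """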
    with torch.no_grad():

        model.eval()

        # initialize the TensorBoard writer
        writer = SummaryWriter(os.path.join(log_dir, "attentions"))

        # retrieve the images and the labels of the test dataset
        images = test_dataset.images
        labels = test_dataset.labels

        # initialize the containers for the predictions
        predictions = {'attentions': [], 'predictions': [], 'true_labels': labels, 'predicted_labels': []}

        # initialize the dataloader
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

        # accumulate the loss over the batches
        loss = 0

        for data in test_dataloader:

            # take the pixel values out of the batch and move them to the GPU
            pixel_values = data['pixel_values'][0].cuda()

            # move the labels to the GPU
            labels_ = data['labels'].cuda()

            # make a forward pass, asking the model to return the attentions
            outputs = model(pixel_values, labels=labels_, output_attentions=True)

            # store the class probabilities
            predictions['predictions'].append(torch.softmax(outputs.logits.detach().cpu(), dim=-1).numpy())

            # store the attentions of the last encoder layer
            predictions['attentions'].append(outputs.attentions[-1].detach().cpu())

            # accumulate the loss
            loss += outputs.loss.detach().cpu().item()

        predictions['predictions'] = np.concatenate(predictions['predictions'], axis=0)

        predictions['attentions'] = torch.cat(predictions['attentions'], dim=0)

        predictions['predicted_labels'] = np.argmax(predictions['predictions'], axis=-1).tolist()

        # calculate the metrics and the mean loss
        metrics = compute_metrics((predictions['predictions'], np.array(predictions['true_labels'])))

        metrics['loss'] = loss / len(test_dataloader)

        # visualize the attention of each image on a square grid
        nrows = ceil(sqrt(len(images)))

        fig, axes = plt.subplots(nrows=nrows, ncols=nrows, figsize=figsize)

        axes = axes.flat

        for i in range(len(images)):

            attention_image = get_attention(images[i], predictions['attentions'][i], size, patch_size)

            axes[i].imshow(attention_image)

            axes[i].set_title(f'Image {i + 1}')

            axes[i].axis('off')

        fig.tight_layout()

        # remove the empty axes of the grid
        for i in range(len(images), nrows * nrows):

            fig.delaxes(axes[i])

        # log the figure to TensorBoard
        writer.add_figure(tag, fig)

        # drop the raw predictions and the attentions before building the frame
        del predictions['predictions']
        del predictions['attentions']

        # return the labels as a dataframe together with the metrics
        return pd.DataFrame(predictions), metrics
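

# A minimal usage sketch, not part of the original module: the checkpoint name
# and the dataset arguments below are assumptions, not values prescribed by
# this project. It loads a ViT classifier from Hugging Face transformers,
# which supports the model(pixel_values, labels=..., output_attentions=True)
# call made above, and runs the evaluation on the GPU.
if __name__ == '__main__':

    from transformers import ViTForImageClassification

    # hypothetical checkpoint; any ViT image classifier with two labels works
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=2)
    model = model.cuda()

    # build the test dataset with the project's dataset class; its constructor
    # arguments are defined in fake_face_dataset.py and are left elided here
    test_dataset = FakeFaceDetectionDataset(...)

    predictions, metrics = make_predictions(test_dataset, model)

    print(metrics)
    print(predictions.head())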