File size: 7,349 Bytes
783053f
 
 
b39c220
783053f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bb44c5
783053f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b63fd37
783053f
 
 
 
 
 
 
 
 
 
 
 
 
d57c931
3bb44c5
d57c931
783053f
b63fd37
783053f
b63fd37
783053f
 
 
 
 
 
 
 
 
b63fd37
 
 
d57c931
 
3bb44c5
d57c931
 
b63fd37
 
 
 
 
 
 
 
 
 
 
 
 
 
d57c931
3bb44c5
d57c931
 
b63fd37
 
 
 
783053f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b63fd37
783053f
 
b63fd37
783053f
 
 
 
 
b63fd37
783053f
 
b63fd37
783053f
 
b63fd37
783053f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bb44c5
783053f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b63fd37
 
 
 
 
783053f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184

from fake_face_detection.data.fake_face_dataset import FakeFaceDetectionDataset
from fake_face_detection.metrics.compute_metrics import compute_metrics
from fake_face_detection.utils.smoothest_attention import smooth_attention
from torch.utils.tensorboard import SummaryWriter
from PIL.JpegImagePlugin import JpegImageFile
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torchvision import transforms
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
from typing import *
import pandas as pd
from math import *
import numpy as np
import torch
import os

def get_attention(image: Union[str, JpegImageFile], attention: torch.Tensor, size: tuple, patch_size: tuple, scale: int = 50, head: int = 1, smooth_iter: int = 2, smooth_thres: float = 0.01, smooth_scale: float = 0.2, smooth_size = 5):
    """Build an image weighted by the attention of one head

    The attention that the last token pays to every other token (the first
    key position is dropped) is laid out on the patch grid, upsampled to the
    image resolution, smoothed and multiplied with the resized image.

    Args:
        image (Union[str, JpegImageFile]): The image or the path to the image
        attention (torch.Tensor): The attention of one sample. It is indexed as [head, query token, key token];
            presumably the last-layer attention of a vision transformer -- confirm against the caller.
        size (tuple): The (height, width) to which the image and the attention map are resized
        patch_size (tuple): The number of patches along (height, width). Their product must be equal
            to the number of key tokens remaining after the first one is dropped.
        scale (int, optional): The factor multiplying the attention-weighted image before clipping. Defaults to 50.
        head (int, optional): The head number (counted from 1). Defaults to 1.
        smooth_iter (int, optional): The number of iterations for the smoothest attention. Defaults to 2.
        smooth_thres (float, optional): The threshold for the smoothest attention. Defaults to 0.01.
        smooth_scale (float, optional): The scale for the smoothest attention. Defaults to 0.2.
        smooth_size (int, optional): The size for the smoothest attention. Defaults to 5.

    Returns:
        np.ndarray: The attention-weighted image with values clipped to [0, 1]
    """

    # recuperate the resized image as a numpy array (the file is closed as soon as it is read)
    if isinstance(image, str):

        with Image.open(image) as img:

            img = np.array(transforms.Resize(size)(img))

    else:

        img = np.array(transforms.Resize(size)(image))

    # recuperate the attention provided by the last token to all the tokens except the first one
    attention = attention[:, -1, 1:]

    # keep only the requested head (heads are counted from 1)
    attention = attention[head - 1]

    # lay the attention out on the 2-D patch grid: interpolating the flattened vector in 1-D
    # and reshaping afterwards would spread each patch over a run of consecutive pixels
    # instead of over the square region of the image that the patch covers
    attention = attention.reshape(1, 1, patch_size[0], patch_size[1])

    # rescale the attention to the image resolution with the nearest scaler
    attention = F.interpolate(attention, size=(size[0], size[1]), mode='nearest')

    # let us reshape the attention to (height, width, 1) so that it broadcasts over the channels
    attention = attention.reshape(size[0], size[1], 1)

    # add the smoothest attention
    attention = smooth_attention(attention, smooth_iter, smooth_thres, smooth_scale, smooth_size)

    # weight the normalized image by the attention and amplify the result
    attention_image = img / 255 * attention.numpy() * scale

    return np.clip(attention_image, 0, 1)


def make_predictions(test_dataset: FakeFaceDetectionDataset,
                     model,
                     log_dir: str = "fake_face_logs",
                     tag: str = "Attentions",
                     batch_size: int = 3,
                     size: tuple = (224, 224), 
                     patch_size: tuple = (14, 14),
                     figsize: tuple = (24, 24),
                     attention_scale: int = 50,
                     show: bool = True, 
                     head: int = 1,
                     smooth_iter: int = 2,
                     smooth_thres: float = 0.01,
                     smooth_scale: float = 0.2,
                     smooth_size = 5):
    """Make predictions with a vision transformer model

    The model is evaluated on the whole test dataset, the metrics are computed
    and a grid of attention images (one per test image) is written to
    TensorBoard under `<log_dir>/attentions`.

    Args:
        test_dataset (FakeFaceDetectionDataset): The test dataset. Its `images` and `labels` attributes are read.
        model (_type_): The model. It is called with `labels` and `output_attentions=True` and is left in evaluation mode.
        log_dir (str, optional): The log directory. Defaults to "fake_face_logs".
        tag (str, optional): The tag of the figure in TensorBoard. Defaults to "Attentions".
        batch_size (int, optional): The batch size. Defaults to 3.
        size (tuple, optional): The size of the attention image. Defaults to (224, 224).
        patch_size (tuple, optional): The patch size forwarded to `get_attention`. Defaults to (14, 14).
        figsize (tuple, optional): The figure size. Defaults to (24, 24).
        attention_scale (int, optional): The attention scale. Defaults to 50.
        show (bool, optional): A boolean value indicating if we want to recuperate the figure. Defaults to True.
        head (int, optional): The head number. Defaults to 1.
        smooth_iter (int, optional): The number of iterations for the smoothest attention. Defaults to 2.
        smooth_thres (float, optional): The threshold for the smoothest attention. Defaults to 0.01.
        smooth_scale (float, optional): The scale for the smoothest attention. Defaults to 0.2.
        smooth_size (int, optional): The size for the smoothest attention. Defaults to 5.

    Raises:
        ValueError: If the test dataset does not contain any image.

    Returns:
        Union[Tuple[pd.DataFrame, dict], Tuple[pd.DataFrame, dict, figure]]: The true and predicted labels,
            the metrics (including the mean loss per batch) and, if `show` is True, the figure
    """
    
    # let us recuperate the images and labels
    images = test_dataset.images
    
    labels = test_dataset.labels
    
    # fail early with a clear message instead of deep inside numpy or matplotlib
    if len(images) == 0:
        raise ValueError("The test dataset does not contain any image.")
    
    with torch.no_grad():
        
        _ = model.eval()
        
        # let us initialize the dataloader
        test_dataloader = DataLoader(test_dataset, batch_size=batch_size)
        
        # let us initialize the per-batch probabilities and attentions
        batch_probabilities = []
        
        batch_attentions = []
        
        # get the loss
        loss = 0
        
        for data in test_dataloader:
            
            # recuperate the pixel values
            # NOTE(review): index 0 presumably unwraps a one-element list produced by the
            # feature extractor and the default collate function -- confirm against the dataset
            pixel_values = data['pixel_values'][0]
            
            # recuperate the labels
            labels_ = data['labels']
            
            # recuperate the outputs
            outputs = model(pixel_values, labels = labels_, output_attentions = True)
            
            # recuperate the predictions (moved to the cpu before the conversion to numpy)
            batch_probabilities.append(torch.softmax(outputs.logits.detach(), dim = -1).cpu().numpy())
            
            # recuperate the attentions of the last encoder layer
            batch_attentions.append(outputs.attentions[-1].detach().cpu())
            
            # add the loss
            loss += outputs.loss.detach().item()
        
        probabilities = np.concatenate(batch_probabilities, axis = 0)
        
        attentions = torch.cat(batch_attentions, dim = 0)
        
        predicted_labels = np.argmax(probabilities, axis = -1).tolist()
        
        # let us calculate the metrics
        metrics = compute_metrics((probabilities, np.array(labels)))
        metrics['loss'] = loss / len(test_dataloader)
        
        # for each image we will visualize his attention on the smallest square grid holding all the images
        nrows = ceil(sqrt(len(images)))
        
        # squeeze=False always returns a 2-D array of axes, even for a single image
        fig, axes = plt.subplots(nrows=nrows, ncols=nrows, figsize = figsize, squeeze = False)
        
        axes = axes.flat
        
        for i in range(len(images)):
            
            attention_image = get_attention(images[i], attentions[i], size, patch_size, attention_scale, head, smooth_iter, smooth_thres, smooth_scale, smooth_size)
        
            axes[i].imshow(attention_image)
            
            axes[i].set_title(f'Image {i + 1}')
            
            axes[i].axis('off')
            
        fig.tight_layout()
        
        # remove the cells of the grid that do not contain any image
        for i in range(len(images), nrows * nrows):
            
            fig.delaxes(axes[i])
        
        # log the figure; the context manager closes the writer so that the event file is flushed
        with SummaryWriter(os.path.join(log_dir, "attentions")) as writer:
            
            writer.add_figure(tag, fig)
        
        # only the labels are returned (the probabilities and the attentions are dropped)
        results = pd.DataFrame({'true_labels': labels, 'predicted_labels': predicted_labels})
        
        # recuperate the figure if necessary
        if show:
            return results, metrics, fig
        
        # let us recuperate the metrics and the predictions
        return results, metrics