= committed on
Commit
b39c220
·
1 Parent(s): d57c931

adding correct smooth attention

Browse files
fake_face_detection/metrics/make_predictions.py CHANGED
@@ -1,7 +1,7 @@
1
 
2
  from fake_face_detection.data.fake_face_dataset import FakeFaceDetectionDataset
3
  from fake_face_detection.metrics.compute_metrics import compute_metrics
4
- from fake_face_detection.smoothest_attention import smooth_attention
5
  from torch.utils.tensorboard import SummaryWriter
6
  from PIL.JpegImagePlugin import JpegImageFile
7
  from torch.utils.data import DataLoader
 
1
 
2
  from fake_face_detection.data.fake_face_dataset import FakeFaceDetectionDataset
3
  from fake_face_detection.metrics.compute_metrics import compute_metrics
4
+ from fake_face_detection.utils.smoothest_attention import smooth_attention
5
  from torch.utils.tensorboard import SummaryWriter
6
  from PIL.JpegImagePlugin import JpegImageFile
7
  from torch.utils.data import DataLoader
fake_face_detection/utils/smoothest_attention.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import copy
3
+
4
# Blend each pixel with scale * (mean of its local neighborhood), repeated
# until convergence or until `iters` passes have been made.
def smooth_attention(attention: torch.Tensor, iters: int = 1000, threshold: float = 0.1, scale: float = 0.2, size: int = 3):
    """Iteratively smooth a 2D attention map.

    Each pass replaces every pixel with
    ``(1 - scale) * pixel + scale * neighborhood_mean``, where the
    neighborhood is a ``(2*size + 1) x (2*size + 1)`` window clipped at the
    borders. Passes stop early once every per-pixel change falls below
    ``threshold``.

    Args:
        attention: attention map; squeezed to 2D (assumes at most singleton
            extra dims — TODO confirm expected input shape against callers).
        iters: maximum number of smoothing passes.
        threshold: early-stopping bound on the per-pixel |pixel - mean| gap.
        scale: blending weight given to the neighborhood mean.
        size: neighborhood radius in pixels.

    Returns:
        torch.Tensor: the smoothed map with a trailing singleton dim added.
    """
    # Work on an independent copy so the caller's tensor is never mutated
    # (clone() is the idiomatic tensor copy; deepcopy is unnecessary).
    attention = attention.squeeze().clone()

    for _ in range(iters):
        # Per-pixel |old_pixel - mean| gap of this pass; every entry is
        # written below, so no fill value is needed.
        difference = torch.empty_like(attention)

        for i in range(attention.shape[0]):
            for j in range(attention.shape[1]):
                # Snapshot the scalar value: integer indexing yields a 0-dim
                # VIEW in PyTorch, so without .item() the difference below
                # would be computed against the already-updated pixel.
                pixel = attention[i, j].item()

                # Mean over the clipped symmetric window centered at (i, j).
                # Upper bounds use ``+ size + 1`` because slicing is
                # end-exclusive (the original ``i + size`` dropped the last
                # row/column of the window).
                mean = attention[max(0, i - size): min(attention.shape[0], i + size + 1),
                                 max(0, j - size): min(attention.shape[1], j + size + 1)].mean()

                # In-place (Gauss-Seidel style) update: later pixels in the
                # same pass already see the smoothed values of earlier ones.
                attention[i, j] = (1 - scale) * pixel + scale * mean

                # Record the true pre-update gap for the convergence test.
                difference[i, j] = abs(pixel - mean)

        # Converged everywhere -> no further pass can change much; stop.
        if (difference < threshold).all():
            break

    # Restore a trailing channel dimension for downstream consumers.
    return attention.unsqueeze(-1)