Spaces:
Build error
Build error
import torch | |
import copy | |
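# Iteratively smooth an attention map: on every pass each pixel is replaced by
# (1 - scale) * pixel + scale * mean, where mean is the average of the window of
# pixels around it (with the defaults: 0.8 of the pixel and 0.2 of the mean, for
# at most 1000 passes).
# The window reaches up to `size` pixels out from the current pixel and is
# clipped at the borders of the map.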
def smooth_attention(attention: torch.Tensor, iters: int = 1000, threshold: float = 0.1, scale: float = 0.2, size: int = 3):
    """Iteratively smooth an attention map by blending each pixel with its local mean.

    On every pass each pixel is replaced by
    ``(1 - scale) * pixel + scale * mean``, where ``mean`` is the average of
    the window spanning ``size`` pixels on each side of the pixel (clipped at
    the borders). Pixels are updated in place in row-major order, so pixels
    visited later in a pass already see the smoothed values of earlier ones.

    Args:
        attention: Attention map. Singleton dimensions are squeezed away; the
            loops below index the first two remaining dimensions, so the
            result is expected to be 2-D. The input is never modified.
        iters: Maximum number of smoothing passes.
        threshold: Smoothing stops early after a pass in which every pixel
            differed from its window mean by less than this value (measured
            before the pixel was updated).
        scale: Weight given to the window mean; the pixel keeps ``1 - scale``.
        size: Number of pixels the window reaches on each side of the pixel.

    Returns:
        The smoothed map as a new tensor with a trailing singleton dimension
        appended. It is detached from the autograd graph.
    """
    # Work on a detached private copy: the caller's tensor stays untouched and
    # the in-place writes below cannot trip autograd's leaf / view checks.
    attention = attention.detach().squeeze().clone()
    shape = attention.shape
    for _ in range(iters):
        # stays True only if every pixel was already within `threshold`
        converged = True
        for i in range(shape[0]):
            row_lo = max(0, i - size)
            # +1 because the slice end is exclusive: keep the window symmetric
            row_hi = min(shape[0], i + size + 1)
            for j in range(shape[1]):
                col_lo = max(0, j - size)
                col_hi = min(shape[1], j + size + 1)
                # NOTE: this is a view into `attention`, not a copy
                pixel = attention[i, j]
                mean = attention[row_lo:row_hi, col_lo:col_hi].mean()
                # measure the gap BEFORE writing: `pixel` aliases the storage,
                # so after the write it would hold the updated value
                gap = abs(pixel - mean)
                attention[i, j] = (1 - scale) * pixel + scale * mean
                # `not (gap < threshold)` also treats NaN as "not converged"
                if not (gap < threshold).all():
                    converged = False
        if converged:
            break
    # append the trailing singleton dimension expected by callers
    return attention.unsqueeze(-1)