"""
# author: Kangran ZHAO
# email: [email protected]
# date: 2024-04-01
# description: Class for the Multi-attention Detector
Functions in the Class are summarized as:
1. __init__: Initialization
2. build_backbone: Backbone-building
3. build_loss: Loss-function-building
4. features: Feature-extraction
5. classifier: Classification
6. get_losses: Loss-computation
7. get_train_metrics: Training-metrics-computation
8. get_test_metrics: Testing-metrics-computation
9. forward: Forward-propagation
Reference:
@INPROCEEDINGS{9577592,
author={Zhao, Hanqing and Wei, Tianyi and Zhou, Wenbo and Zhang, Weiming and Chen, Dongdong and Yu, Nenghai},
booktitle={2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
title={Multi-attentional Deepfake Detection},
year={2021},
volume={},
number={},
pages={2185-2194},
keywords={Measurement;Semantics;Feature extraction;Forgery;Pattern recognition;Feeds;Task analysis},
doi={10.1109/CVPR46437.2021.00222}
}
Code is adapted from the GitHub repo https://github.com/yoctta/multiple-attention
"""
import random
import kornia
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from detectors import DETECTOR
from loss import LOSSFUNC
from metrics.base_metrics_class import calculate_metrics_for_train
from networks import BACKBONE
from sklearn import metrics
from .base_detector import AbstractDetector
@DETECTOR.register_module(module_name='multi_attention')
class MultiAttentionDetector(AbstractDetector):
def __init__(self, config):
super().__init__()
self.config = config
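        # Indices of the MBConv blocks whose outputs are tapped as "b1"-"b7" in self.features().
        # (Assumption: these are stage-boundary block indices for the EfficientNet variant used
        # by this repo's backbone; adjust if a different variant is configured.)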
self.block_layer = {"b1": 1, "b2": 5, "b3": 9, "b4": 15, "b5": 21, "b6": 29, "b7": 31}
self.mid_dim = config["mid_dim"]
self.backbone = self.build_backbone(config)
self.loss_func = self.build_loss(config)
self.batch_cnt = 0
with torch.no_grad():
layer_outputs = self.features({"image": torch.zeros(1, 3, config["resolution"], config["resolution"])})
self.feature_layer = config["feature_layer"]
self.attention_layer = config["attention_layer"]
self.num_classes = config["backbone_config"]["num_classes"]
self.num_shallow_features = layer_outputs[self.feature_layer].shape[1]
self.num_attention_features = layer_outputs[self.attention_layer].shape[1]
self.num_final_features = layer_outputs["final"].shape[1]
self.num_attentions = config["num_attentions"]
self.AGDA = AGDA(kernel_size=config["AGDA"]["kernel_size"],
dilation=config["AGDA"]["dilation"],
sigma=config["AGDA"]["sigma"],
threshold=config["AGDA"]["threshold"],
zoom=config["AGDA"]["zoom"],
scale_factor=config["AGDA"]["scale_factor"],
noise_rate=config["AGDA"]["noise_rate"])
self.attention_generation = AttentionMap(self.num_attention_features, self.num_attentions)
self.attention_pooling = AttentionPooling()
self.texture_enhance = TextureEnhanceV1(self.num_shallow_features, self.num_attentions) # Todo
self.num_enhanced_features = self.texture_enhance.output_features
self.num_features_d = self.texture_enhance.output_features_d
self.projection_local = nn.Sequential(nn.Linear(self.num_attentions * self.num_enhanced_features, self.mid_dim),
nn.Hardswish(),
nn.Linear(self.mid_dim, self.mid_dim),
nn.Hardswish())
self.projection_final = nn.Sequential(nn.Linear(self.num_final_features, self.mid_dim),
nn.Hardswish())
self.ensemble_classifier_fc = nn.Sequential(nn.Linear(self.mid_dim * 2, self.mid_dim),
nn.Hardswish(),
nn.Linear(self.mid_dim, self.num_classes))
self.dropout = nn.Dropout(config["dropout_rate"], inplace=True)
self.dropout_final = nn.Dropout(config["dropout_rate_final"], inplace=True)
def build_backbone(self, config):
backbone_class = BACKBONE[config['backbone_name']]
model_config = config['backbone_config']
model_config['pretrained'] = self.config.get('pretrained', None)
backbone = backbone_class(model_config)
return backbone
def build_loss(self, config):
cls_loss_class = LOSSFUNC[config["loss_func"]["cls_loss"]]
ril_loss_class = LOSSFUNC[config["loss_func"]["ril_loss"]]
cls_loss_func = cls_loss_class()
ril_loss_func = ril_loss_class(M=config["num_attentions"],
N=config["loss_func"]["ril_params"]["N"],
alpha=config["loss_func"]["ril_params"]["alpha"],
alpha_decay=config["loss_func"]["ril_params"]["alpha_decay"],
decay_batch=config["batch_per_epoch"],
inter_margin=config["loss_func"]["ril_params"]["inter_margin"],
intra_margin=config["loss_func"]["ril_params"]["intra_margin"])
return {"cls": cls_loss_func, "ril": ril_loss_func, "weights": config["loss_func"]["weights"]}
    def features(self, data_dict: dict) -> dict:
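        """
        Run the EfficientNet backbone and collect intermediate outputs.
        Args:
            data_dict: dict holding the input batch under key "image" ([B, 3, H, W])
        Returns:
            layer_output: dict mapping tapped block names ("b1"-"b7") to their feature maps,
                plus "final" (pre-pooling feature map) and "logit" (backbone classifier output)
        """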
x = data_dict["image"]
layer_output = {}
for name, module in self.backbone.efficientnet.named_children():
if name == "_avg_pooling":
layer_output["final"] = x
break
elif name != "_blocks":
x = module(x)
else:
for i in range(len(module)):
x = module[i](x)
if i == self.block_layer["b1"]:
layer_output["b1"] = x
elif i == self.block_layer["b2"]:
layer_output["b2"] = x
elif i == self.block_layer["b3"]:
layer_output["b3"] = x
elif i == self.block_layer["b4"]:
layer_output["b4"] = x
elif i == self.block_layer["b5"]:
layer_output["b5"] = x
elif i == self.block_layer["b6"]:
layer_output["b6"] = x
elif i == self.block_layer["b7"]:
layer_output["b7"] = x
x = F.adaptive_avg_pool2d(x, (1, 1))
x = x.view(x.size(0), -1)
layer_output["logit"] = self.backbone.last_layer(x)
return layer_output
    def classifier(self, features: torch.tensor) -> torch.tensor:
        # The classifier is implemented inside self.forward(); do not override this method.
        pass
def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
if self.batch_cnt <= self.config["backbone_nEpochs"] * self.config["batch_per_epoch"]:
label = data_dict["label"]
pred = pred_dict["cls"]
ce_loss = self.loss_func["cls"](pred, label)
return {"overall": ce_loss, "ce_loss": ce_loss}
else:
label = data_dict["label"]
pred = pred_dict["cls"]
feature_maps_d = pred_dict["feature_maps_d"]
attention_maps = pred_dict["attentions"]
ce_loss = self.loss_func["cls"](pred, label)
ril_loss = self.loss_func["ril"](feature_maps_d, attention_maps, label)
weights = self.loss_func["weights"]
over_all_loss = weights[0] * ce_loss + weights[1] * ril_loss
return {"overall": over_all_loss, "ce_loss": ce_loss, "ril_loss": ril_loss}
def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
label = data_dict['label']
pred = pred_dict['cls']
auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach())
metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap}
return metric_batch_dict
    def get_test_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        label = data_dict['label']
        pred = pred_dict['cls']
        auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach())
        metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap}
        return metric_batch_dict
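    # For the first backbone_nEpochs * batch_per_epoch batches only the plain EfficientNet
    # classifier is trained; after that the multi-attention branch takes over, with AGDA
    # augmentation applied during training only.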
def forward(self, data_dict: dict, inference=False) -> dict:
self.batch_cnt += 1
if self.batch_cnt <= self.config["backbone_nEpochs"] * self.config["batch_per_epoch"]:
layer_output = self.features(data_dict)
pred = layer_output["logit"]
prob = torch.softmax(pred, dim=1)[:, 1]
pred_dict = {"cls": pred,
"prob": prob,
"feat": layer_output["final"]}
else:
if not inference: # use AGDA when training
with torch.no_grad():
layer_output = self.features(data_dict)
raw_attentions = layer_output[self.attention_layer]
attention_maps = self.attention_generation(raw_attentions)
data_dict["image"], _ = self.AGDA.agda(data_dict["image"], attention_maps)
# Get Attention Maps
layer_output = self.features(data_dict)
raw_attentions = layer_output[self.attention_layer]
attention_maps = self.attention_generation(raw_attentions)
# Get Textural Feature Matrix P
shallow_features = layer_output[self.feature_layer]
enhanced_features, feature_maps_d = self.texture_enhance(shallow_features, attention_maps)
textural_feature_matrix_p = self.attention_pooling(enhanced_features, attention_maps)
B, M, N = textural_feature_matrix_p.size()
feature_matrix = self.dropout(textural_feature_matrix_p).view(B, -1)
feature_matrix = self.projection_local(feature_matrix)
# Get Global Feature G
final = layer_output["final"]
attention_maps2 = attention_maps.sum(dim=1, keepdim=True) # [B, 1, H_A, W_A]
final = self.attention_pooling(final, attention_maps2, norm=1).squeeze(1) # [B, C_F]
final = self.projection_final(final)
final = F.hardswish(final)
# Get the Prediction by Ensemble Classifier
feature_matrix = torch.cat((feature_matrix, final), dim=1) # [B, 2 * mid_dim]
pred = self.ensemble_classifier_fc(feature_matrix) # [B, 2]
# Get probability
prob = torch.softmax(pred, dim=1)[:, 1]
pred_dict = {"cls": pred,
"prob": prob,
"feat": layer_output["final"],
"attentions": attention_maps,
"feature_maps_d": feature_maps_d}
return pred_dict
class AttentionMap(nn.Module):
def __init__(self, in_channels, num_attention):
super(AttentionMap, self).__init__()
self.register_buffer('mask', torch.zeros([1, 1, 24, 24]))
self.mask[0, 0, 2:-2, 2:-2] = 1
self.num_attentions = num_attention
self.conv_extract = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(in_channels)
self.conv2 = nn.Conv2d(in_channels, num_attention, kernel_size=1, bias=False)
self.bn2 = nn.BatchNorm2d(num_attention)
def forward(self, x):
"""
Convert deep feature to attention map
Args:
x: extracted features
Returns:
attention_maps: conventionally 4 attention maps
"""
if self.num_attentions == 0:
return torch.ones([x.shape[0], 1, 1, 1], device=x.device)
x = self.conv_extract(x)
x = self.bn1(x)
x = F.relu(x, inplace=True)
x = self.conv2(x)
x = self.bn2(x)
x = F.elu(x) + 1
mask = F.interpolate(self.mask, (x.shape[2], x.shape[3]), mode='nearest')
return x * mask
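# Minimal usage sketch for AttentionMap (illustrative only; a 32-channel mid-level feature map
# at 56x56 with M = 4 attentions is an assumption, not the repo's actual configuration).
def _demo_attention_map():
    attention_generation = AttentionMap(in_channels=32, num_attention=4)
    features = torch.randn(2, 32, 56, 56)
    # [2, 4, 56, 56]; values are non-negative (ELU + 1) and the border is zeroed by the mask
    attention_maps = attention_generation(features)
    return attention_maps.shape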
class AttentionPooling(nn.Module):
def __init__(self):
super().__init__()
def forward(self, features, attentions, norm=2):
"""
        Bilinear Attention Pooling (BAP)
        Args:
            features: [Tensor in [B, C_F, H_F, W_F]] extracted feature maps, either shallow (textural) or deep (global)
            attentions: [Tensor in [B, M, H_A, W_A]] attention maps, conventionally M = 4
            norm: [int, default=2] 1 for deep features, 2 for shallow features
        Returns:
            feature_matrix: [Tensor in [B, M, C_F]] textural feature matrix P (shallow features, norm=2)
                or global feature G (deep features, norm=1, with M = 1 before squeezing)
"""
feature_size = features.size()[-2:]
attention_size = attentions.size()[-2:]
if feature_size != attention_size:
attentions = F.interpolate(attentions, size=feature_size, mode='bilinear', align_corners=True)
if len(features.shape) == 4:
# In TextureEnhanceV1, in accordance with paper
feature_matrix = torch.einsum('imjk,injk->imn', attentions, features) # [B, M, C_F]
else:
# In TextureEnhanceV2
feature_matrix = torch.einsum('imjk,imnjk->imn', attentions, features)
if norm == 1: # Used for deep feature BAP
w = torch.sum(attentions + 1e-8, dim=(2, 3)).unsqueeze(-1)
feature_matrix /= w
elif norm == 2: # Used for shallow feature BAP
feature_matrix = F.normalize(feature_matrix, p=2, dim=-1)
return feature_matrix
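# Minimal usage sketch for AttentionPooling / BAP (illustrative only; shapes are placeholders).
def _demo_attention_pooling():
    bap = AttentionPooling()
    features = torch.randn(2, 64, 28, 28)      # [B, C_F, H_F, W_F]
    attentions = torch.rand(2, 4, 14, 14)      # [B, M, H_A, W_A], upsampled internally to match
    p = bap(features, attentions, norm=2)      # [2, 4, 64], L2-normalized textural feature matrix P
    return p.shape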
class TextureEnhanceV1(nn.Module):
def __init__(self, num_features, num_attentions):
super().__init__()
# self.output_features=num_features
self.output_features = num_features * 4
self.output_features_d = num_features
self.conv0 = nn.Conv2d(num_features, num_features, 1)
self.conv1 = nn.Conv2d(num_features, num_features, 3, padding=1)
self.bn1 = nn.BatchNorm2d(num_features)
self.conv2 = nn.Conv2d(num_features * 2, num_features, 3, padding=1)
self.bn2 = nn.BatchNorm2d(2 * num_features)
self.conv3 = nn.Conv2d(num_features * 3, num_features, 3, padding=1)
self.bn3 = nn.BatchNorm2d(3 * num_features)
self.conv_last = nn.Conv2d(num_features * 4, num_features * 4, 1)
self.bn4 = nn.BatchNorm2d(4 * num_features)
self.bn_last = nn.BatchNorm2d(num_features * 4)
def forward(self, feature_maps, attention_maps=(1, 1)):
"""
Texture Enhancement Block V1, in accordance with description in paper
1. Local average pooling.
2. Residual local features.
3. Dense Net
Args:
feature_maps: [Tensor in [B, C', H', W']] extracted shallow features
            attention_maps: [Tensor in [B, M, H_A, W_A]] calculated attention maps, or
                [Tuple of two floats] local average grid scale used to conduct local average pooling;
                the local patch size is determined by the attention map size.
Returns:
feature_maps: [Tensor in [B, C_1, H_1, W_1]] enhanced feature maps
feature_maps_d: [Tensor in [B, C', H_A, W_A]] textural information
"""
B, N, H, W = feature_maps.shape
if type(attention_maps) == tuple:
attention_size = (int(H * attention_maps[0]), int(W * attention_maps[1]))
else:
attention_size = (attention_maps.shape[2], attention_maps.shape[3])
feature_maps_d = F.adaptive_avg_pool2d(feature_maps, attention_size)
feature_maps = feature_maps - F.interpolate(feature_maps_d, (feature_maps.shape[2], feature_maps.shape[3]),
mode='nearest')
feature_maps0 = self.conv0(feature_maps)
feature_maps1 = self.conv1(F.relu(self.bn1(feature_maps0), inplace=True))
feature_maps1_ = torch.cat([feature_maps0, feature_maps1], dim=1)
feature_maps2 = self.conv2(F.relu(self.bn2(feature_maps1_), inplace=True))
feature_maps2_ = torch.cat([feature_maps1_, feature_maps2], dim=1)
feature_maps3 = self.conv3(F.relu(self.bn3(feature_maps2_), inplace=True))
feature_maps3_ = torch.cat([feature_maps2_, feature_maps3], dim=1)
feature_maps = self.bn_last(self.conv_last(F.relu(self.bn4(feature_maps3_), inplace=True)))
return feature_maps, feature_maps_d
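# Minimal usage sketch for TextureEnhanceV1 (illustrative only; 32 input channels and the
# spatial sizes below are assumptions).
def _demo_texture_enhance_v1():
    texture_enhance = TextureEnhanceV1(num_features=32, num_attentions=4)
    shallow = torch.randn(2, 32, 56, 56)
    attentions = torch.rand(2, 4, 14, 14)
    enhanced, d = texture_enhance(shallow, attentions)
    return enhanced.shape, d.shape             # ([2, 128, 56, 56], [2, 32, 14, 14])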
class TextureEnhanceV2(nn.Module):
def __init__(self, num_features, num_attentions):
super().__init__()
self.output_features = num_features
self.output_features_d = num_features
self.conv_extract = nn.Conv2d(num_features, num_features, 3, padding=1)
self.conv0 = nn.Conv2d(num_features * num_attentions, num_features * num_attentions, 5, padding=2,
groups=num_attentions)
self.conv1 = nn.Conv2d(num_features * num_attentions, num_features * num_attentions, 3, padding=1,
groups=num_attentions)
self.bn1 = nn.BatchNorm2d(num_features * num_attentions)
self.conv2 = nn.Conv2d(num_features * 2 * num_attentions, num_features * num_attentions, 3, padding=1,
groups=num_attentions)
self.bn2 = nn.BatchNorm2d(2 * num_features * num_attentions)
self.conv3 = nn.Conv2d(num_features * 3 * num_attentions, num_features * num_attentions, 3, padding=1,
groups=num_attentions)
self.bn3 = nn.BatchNorm2d(3 * num_features * num_attentions)
self.conv_last = nn.Conv2d(num_features * 4 * num_attentions, num_features * num_attentions, 1,
groups=num_attentions)
self.bn4 = nn.BatchNorm2d(4 * num_features * num_attentions)
self.bn_last = nn.BatchNorm2d(num_features * num_attentions)
self.M = num_attentions
def cat(self, a, b):
B, C, H, W = a.shape
c = torch.cat([a.reshape(B, self.M, -1, H, W), b.reshape(B, self.M, -1, H, W)], dim=2).reshape(B, -1, H, W)
return c
def forward(self, feature_maps, attention_maps=(1, 1)):
"""
Args:
feature_maps: [Tensor in [B, N, H, W]] extracted feature maps from shallow layer
attention_maps: [Tensor in [B, M, H_A, W_A] or float of (H_ratio, W_ratio)] either extracted attention maps
or average pooling down-sampling ratio
Returns:
feature_maps, feature_maps_d: [Tensor in [B, M, N, H, W], Tensor in [B, N, H, W]] feature maps after dense
network and non-textural feature map D
"""
B, N, H, W = feature_maps.shape
if type(attention_maps) == tuple:
attention_size = (int(H * attention_maps[0]), int(W * attention_maps[1]))
else:
attention_size = (attention_maps.shape[2], attention_maps.shape[3])
feature_maps = self.conv_extract(feature_maps)
feature_maps_d = F.adaptive_avg_pool2d(feature_maps, attention_size)
if feature_maps.size(2) > feature_maps_d.size(2):
feature_maps = feature_maps - F.interpolate(feature_maps_d, (feature_maps.shape[2], feature_maps.shape[3]),
mode='nearest')
        if type(attention_maps) == tuple:
            attention_maps = 1
        else:
            attention_maps = torch.tanh(
                F.interpolate(attention_maps.detach(), (H, W), mode='bilinear', align_corners=True)
            ).unsqueeze(2)
feature_maps = feature_maps.unsqueeze(1)
feature_maps = (feature_maps * attention_maps).reshape(B, -1, H, W)
feature_maps0 = self.conv0(feature_maps)
feature_maps1 = self.conv1(F.relu(self.bn1(feature_maps0), inplace=True))
feature_maps1_ = self.cat(feature_maps0, feature_maps1)
feature_maps2 = self.conv2(F.relu(self.bn2(feature_maps1_), inplace=True))
feature_maps2_ = self.cat(feature_maps1_, feature_maps2)
feature_maps3 = self.conv3(F.relu(self.bn3(feature_maps2_), inplace=True))
feature_maps3_ = self.cat(feature_maps2_, feature_maps3)
feature_maps = F.relu(self.bn_last(self.conv_last(F.relu(self.bn4(feature_maps3_), inplace=True))),
inplace=True)
feature_maps = feature_maps.reshape(B, -1, N, H, W)
return feature_maps, feature_maps_d
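# Minimal usage sketch for TextureEnhanceV2 (illustrative only; sizes are placeholders):
# grouped convolutions produce one enhanced feature map per attention.
def _demo_texture_enhance_v2():
    texture_enhance = TextureEnhanceV2(num_features=32, num_attentions=4)
    shallow = torch.randn(2, 32, 56, 56)
    attentions = torch.rand(2, 4, 14, 14)
    enhanced, d = texture_enhance(shallow, attentions)
    return enhanced.shape, d.shape             # ([2, 4, 32, 56, 56], [2, 32, 14, 14])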
class AGDA(nn.Module):
def __init__(self, kernel_size, dilation, sigma, threshold, zoom, scale_factor, noise_rate):
super().__init__()
self.kernel_size = kernel_size
self.dilation = dilation
self.sigma = sigma
self.noise_rate = noise_rate
self.scale_factor = scale_factor
self.threshold = threshold
self.zoom = zoom
self.filter = kornia.filters.GaussianBlur2d((self.kernel_size, self.kernel_size), (self.sigma, self.sigma))
def mod_func(self, x):
threshold = random.uniform(*self.threshold) if type(self.threshold) == list else self.threshold
zoom = random.uniform(*self.zoom) if type(self.zoom) == list else self.zoom
bottom = torch.sigmoid((torch.tensor(0.) - threshold) * zoom)
return (torch.sigmoid((x - threshold) * zoom) - bottom) / (1 - bottom)
def soft_drop2(self, x, attention_map):
with torch.no_grad():
attention_map = self.mod_func(attention_map)
B, C, H, W = x.size()
xs = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=True)
xs = self.filter(xs)
xs += torch.randn_like(xs) * self.noise_rate
xs = F.interpolate(xs, (H, W), mode='bilinear', align_corners=True)
x = x * (1 - attention_map) + xs * attention_map
return x
def agda(self, X, attention_map):
with torch.no_grad():
attention_weight = torch.sum(attention_map, dim=(2, 3))
attention_map = F.interpolate(attention_map, (X.size(2), X.size(3)), mode="bilinear", align_corners=True)
attention_weight = torch.sqrt(attention_weight + 1)
index = torch.distributions.categorical.Categorical(attention_weight).sample()
index1 = index.view(-1, 1, 1, 1).repeat(1, 1, X.size(2), X.size(3))
attention_map = torch.gather(attention_map, 1, index1)
atten_max = torch.max(attention_map.view(attention_map.shape[0], 1, -1), 2)[0] + 1e-8
attention_map = attention_map / atten_max.view(attention_map.shape[0], 1, 1, 1)
return self.soft_drop2(X, attention_map), index
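# Minimal usage sketch for AGDA (illustrative only; the hyperparameters below are placeholders,
# not the values from the repo's config).
def _demo_agda():
    agda = AGDA(kernel_size=7, dilation=2, sigma=5, threshold=0.6, zoom=3,
                scale_factor=0.5, noise_rate=0.1)
    images = torch.randn(2, 3, 256, 256)
    attention_maps = torch.rand(2, 4, 14, 14)
    # Blurs and adds noise where the sampled attention map is high
    augmented, index = agda.agda(images, attention_maps)
    return augmented.shape, index              # ([2, 3, 256, 256], sampled attention index per sample)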