|
r'''
# author: Zhiyuan Yan
# email: [email protected]
# date: 2023-07-06
# description: Class for the XceptionDetector

Functions in the Class are summarized as:
1. __init__: Initialization
2. build_backbone: Backbone-building
3. build_loss: Loss-function-building
4. features: Feature-extraction
5. classifier: Classification
6. get_losses: Loss-computation
7. get_train_metrics: Training-metrics-computation
8. get_test_metrics: Testing-metrics-computation
9. forward: Forward-propagation

Reference:
@inproceedings{rossler2019faceforensics++,
  title={Faceforensics++: Learning to detect manipulated facial images},
  author={Rossler, Andreas and Cozzolino, Davide and Verdoliva, Luisa and Riess, Christian and Thies, Justus and Nie{\ss}ner, Matthias},
  booktitle={Proceedings of the IEEE/CVF international conference on computer vision},
  pages={1--11},
  year={2019}
}
'''
|
|
|
|
import os
|
|
import datetime
|
|
import logging
|
|
import numpy as np
|
|
from sklearn import metrics
|
|
from typing import Union
|
|
from collections import defaultdict
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.optim as optim
|
|
from torch.nn import DataParallel
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
from metrics.base_metrics_class import calculate_metrics_for_train
|
|
|
|
from .base_detector import AbstractDetector
|
|
from detectors import DETECTOR
|
|
from networks import BACKBONE
|
|
from loss import LOSSFUNC
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
@DETECTOR.register_module(module_name='xception')
class XceptionDetector(AbstractDetector):
    """Detector built around a backbone looked up in the ``BACKBONE`` registry.

    The class is registered in ``DETECTOR`` under the name ``'xception'``.
    It wires together a backbone (feature extractor + classifier head) and a
    loss function, both selected by name from ``config``.

    Config keys read by this class:
        backbone_name: key into the ``BACKBONE`` registry.
        backbone_config: object passed to the backbone constructor.
        pretrained: path (or file-like) handed to ``torch.load``.
        loss_func: key into the ``LOSSFUNC`` registry.
    """

    def __init__(self, config):
        """Build the backbone and the loss function described by ``config``.

        Args:
            config: mapping holding the keys listed in the class docstring.
        """
        super().__init__()
        self.config = config
        self.backbone = self.build_backbone(config)
        self.loss_func = self.build_loss(config)
        # Bookkeeping containers; within this class only ``video_names`` is
        # touched again (it is reset in ``get_train_metrics``).
        self.prob, self.label = [], []
        self.video_names = []
        self.correct, self.total = 0, 0

    def build_backbone(self, config):
        """Instantiate the backbone and load its pretrained weights.

        Args:
            config: mapping providing ``backbone_name``, ``backbone_config``
                and ``pretrained``.

        Returns:
            The backbone instance with the pretrained weights loaded
            non-strictly (missing / unexpected keys are tolerated).
        """
        backbone_class = BACKBONE[config['backbone_name']]
        model_config = config['backbone_config']
        backbone = backbone_class(model_config)

        # Deserialize onto the CPU so that a checkpoint saved from a GPU can
        # still be read on a CPU-only host; ``load_state_dict`` copies the
        # values into the backbone's own parameters afterwards.
        state_dict = torch.load(config['pretrained'], map_location='cpu')

        # Build a fresh dict instead of rewriting ``state_dict`` while
        # iterating over it.
        adapted = {}
        for name, weights in state_dict.items():
            # Entries whose name contains 'fc' are dropped, so the backbone's
            # own head is not overwritten by the checkpoint.
            if 'fc' in name:
                continue
            # Presumably the checkpoint stores pointwise conv kernels as 2-D
            # matrices; append two trailing singleton dims. Kernels that are
            # not 2-D (e.g. already 4-D) are passed through unchanged.
            if 'pointwise' in name and weights.dim() == 2:
                weights = weights.unsqueeze(-1).unsqueeze(-1)
            adapted[name] = weights

        # Second positional argument is ``strict=False``.
        backbone.load_state_dict(adapted, False)
        logger.info('Load pretrained model successfully!')
        return backbone

    def build_loss(self, config):
        """Instantiate the loss function named by ``config['loss_func']``.

        Args:
            config: mapping providing ``loss_func``, a key into ``LOSSFUNC``.

        Returns:
            A loss object constructed with no arguments.
        """
        loss_class = LOSSFUNC[config['loss_func']]
        return loss_class()

    def features(self, data_dict: dict) -> torch.Tensor:
        """Run the backbone's feature extractor on ``data_dict['image']``."""
        return self.backbone.features(data_dict['image'])

    def classifier(self, features: torch.Tensor) -> torch.Tensor:
        """Run the backbone's classifier head on ``features``."""
        return self.backbone.classifier(features)

    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute the loss between ``pred_dict['cls']`` and the labels.

        Args:
            data_dict: batch mapping providing ``'label'``.
            pred_dict: prediction mapping providing ``'cls'``.

        Returns:
            Dict with keys ``'overall'`` and ``'cls'``; both hold the same
            loss value because there is only one loss term.
        """
        label = data_dict['label']
        pred = pred_dict['cls']
        loss = self.loss_func(pred, label)
        return {'overall': loss, 'cls': loss}

    def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute per-batch training metrics.

        Args:
            data_dict: batch mapping providing ``'label'``.
            pred_dict: prediction mapping providing ``'cls'``.

        Returns:
            Dict with keys ``'acc'``, ``'auc'``, ``'eer'`` and ``'ap'`` as
            returned by ``calculate_metrics_for_train``.

        Side effects:
            Resets ``self.video_names`` to an empty list.
        """
        label = data_dict['label']
        pred = pred_dict['cls']

        # Detach so the metric computation does not hold on to the graph.
        auc, eer, acc, ap = calculate_metrics_for_train(
            label.detach(), pred.detach()
        )
        metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap}

        self.video_names = []
        return metric_batch_dict

    def forward(self, data_dict: dict, inference=False) -> dict:
        """Extract features, classify them and derive a probability.

        Args:
            data_dict: batch mapping providing ``'image'``.
            inference: accepted for interface compatibility; not used here.

        Returns:
            Dict with keys ``'cls'`` (classifier output), ``'prob'`` (column 1
            of the softmax taken over dim 1 of the classifier output —
            presumably the "fake" class, confirm against the label
            convention) and ``'feat'`` (backbone features).
        """
        features = self.features(data_dict)
        pred = self.classifier(features)
        prob = torch.softmax(pred, dim=1)[:, 1]
        return {'cls': pred, 'prob': prob, 'feat': features}
|
|
|