|
'''
|
|
# author: Zhiyuan Yan
|
|
# email: [email protected]
|
|
# date: 2023-07-06
|
|
# description: Class for the ResnetDetector
|
|
|
|
Functions in the Class are summarized as:
|
|
1. __init__: Initialization
|
|
2. build_backbone: Backbone-building
|
|
3. build_loss: Loss-function-building
|
|
4. features: Feature-extraction
|
|
5. classifier: Classification
|
|
6. get_losses: Loss-computation
|
|
7. get_train_metrics: Training-metrics-computation
|
|
8. get_test_metrics: Testing-metrics-computation
|
|
9. forward: Forward-propagation
|
|
|
|
Reference:
|
|
@inproceedings{wang2020cnn,
|
|
title={CNN-generated images are surprisingly easy to spot... for now},
|
|
author={Wang, Sheng-Yu and Wang, Oliver and Zhang, Richard and Owens, Andrew and Efros, Alexei A},
|
|
booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
|
|
pages={8695--8704},
|
|
year={2020}
|
|
}
|
|
|
|
Notes:
|
|
We chose to use ResNet-34 as the backbone instead of ResNet-50 because the number of parameters in ResNet-34 is relatively similar to that of Xception. This similarity allows us to make a more meaningful and fair comparison between different architectures.
|
|
'''
|
|
|
|
import os
|
|
import datetime
|
|
import logging
|
|
import numpy as np
|
|
from sklearn import metrics
|
|
from typing import Union
|
|
from collections import defaultdict
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.optim as optim
|
|
from torch.nn import DataParallel
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
from metrics.base_metrics_class import calculate_metrics_for_train
|
|
|
|
from .base_detector import AbstractDetector
|
|
from detectors import DETECTOR
|
|
from networks import BACKBONE
|
|
from loss import LOSSFUNC
|
|
|
|
# Logger for this module; its name is the module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
@DETECTOR.register_module(module_name='resnet34')
class ResnetDetector(AbstractDetector):
    """Detector registered under the name ``'resnet34'``.

    A thin wrapper that combines a backbone taken from the ``BACKBONE``
    registry with a loss function taken from the ``LOSSFUNC`` registry.

    Args:
        config: Configuration mapping. The keys read by this class are
            ``'backbone_name'``, ``'backbone_config'`` and ``'loss_func'``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.backbone = self.build_backbone(config)
        self.loss_func = self.build_loss(config)

    def build_backbone(self, config):
        """Instantiate the backbone named by ``config['backbone_name']``.

        The backbone class is looked up in the ``BACKBONE`` registry and
        constructed with ``config['backbone_config']``. This method does not
        load any weights itself. Presumably the backbone constructor handles
        pretrained weights; confirm against the backbone implementation.

        Returns:
            The constructed backbone module.
        """
        backbone_class = BACKBONE[config['backbone_name']]
        model_config = config['backbone_config']
        backbone = backbone_class(model_config)
        return backbone

    def build_loss(self, config):
        """Instantiate the loss named by ``config['loss_func']``.

        The loss class is looked up in the ``LOSSFUNC`` registry and
        constructed with no arguments.

        Returns:
            The constructed loss callable.
        """
        loss_class = LOSSFUNC[config['loss_func']]
        loss_func = loss_class()
        return loss_func

    def features(self, data_dict: dict) -> torch.Tensor:
        """Run the backbone's feature extractor on ``data_dict['image']``."""
        return self.backbone.features(data_dict['image'])

    def classifier(self, features: torch.Tensor) -> torch.Tensor:
        """Map backbone features to class scores with the backbone's head."""
        return self.backbone.classifier(features)

    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute the training loss.

        Args:
            data_dict: Batch dict; ``data_dict['label']`` holds the targets.
            pred_dict: Output of :meth:`forward`; ``pred_dict['cls']`` holds
                the raw classifier outputs.

        Returns:
            ``{'overall': loss}``, where ``loss`` is the value returned by
            ``self.loss_func(pred, label)``.
        """
        label = data_dict['label']
        pred = pred_dict['cls']
        loss = self.loss_func(pred, label)
        return {'overall': loss}

    def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute per-batch training metrics.

        Labels and predictions are detached from the autograd graph before
        they are handed to ``calculate_metrics_for_train``.

        Returns:
            Dict with keys ``'acc'``, ``'auc'``, ``'eer'`` and ``'ap'``.
        """
        label = data_dict['label']
        pred = pred_dict['cls']
        auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach())
        return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap}

    def forward(self, data_dict: dict, inference=False) -> dict:
        """Run feature extraction and classification on a batch.

        Args:
            data_dict: Batch dict; ``data_dict['image']`` is the model input.
            inference: Not used by this implementation. It is kept in the
                signature for caller compatibility.

        Returns:
            Dict with these keys:

            - ``'cls'``: raw classifier outputs.
            - ``'prob'``: softmax over dim 1, taking the column at index 1.
            - ``'feat'``: the backbone features.
        """
        features = self.features(data_dict)
        pred = self.classifier(features)
        # Column 1 of the softmax is taken as the positive-class score. This
        # assumes (batch, num_classes) outputs with index 1 as the positive
        # (presumably "fake") class. TODO: confirm against the label convention.
        prob = torch.softmax(pred, dim=1)[:, 1]
        return {'cls': pred, 'prob': prob, 'feat': features}
|
|
|
|
|