"""
# author: Zhiyuan Yan
# email: zhiyuanyan@link.cuhk.edu.cn
# date: 2023-07-06
# description: Class for the ResnetDetector

Functions in the Class are summarized as:
1. __init__: Initialization
2. build_backbone: Backbone-building
3. build_loss: Loss-function-building
4. features: Feature-extraction
5. classifier: Classification
6. get_losses: Loss-computation
7. get_train_metrics: Training-metrics-computation
8. forward: Forward-propagation

Reference:
@inproceedings{wang2020cnn,
  title={CNN-generated images are surprisingly easy to spot... for now},
  author={Wang, Sheng-Yu and Wang, Oliver and Zhang, Richard and Owens, Andrew and Efros, Alexei A},
  booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
  pages={8695--8704},
  year={2020}
}

Notes:
We chose to use ResNet-34 as the backbone instead of ResNet-50 because the
number of parameters in ResNet-34 is relatively similar to that of Xception.
This similarity allows us to make a more meaningful and fair comparison
between different architectures.
"""
import os
import datetime
import logging
from typing import Union
from collections import defaultdict

import numpy as np
# NOTE: this binds the module-level name `metrics` to sklearn.metrics; the
# local `metrics` package imported below is resolved by the import system,
# not through this name.
from sklearn import metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import DataParallel
from torch.utils.tensorboard import SummaryWriter

from metrics.base_metrics_class import calculate_metrics_for_train

from .base_detector import AbstractDetector
from detectors import DETECTOR
from networks import BACKBONE
from loss import LOSSFUNC

logger = logging.getLogger(__name__)


@DETECTOR.register_module(module_name='resnet34')
class ResnetDetector(AbstractDetector):
    """ResNet-based detector, registered in ``DETECTOR`` as ``'resnet34'``.

    Feature extraction and classification are delegated to a backbone taken
    from the ``BACKBONE`` registry; the training criterion is taken from the
    ``LOSSFUNC`` registry.

    Args:
        config: Configuration mapping. Keys read by this class:
            ``'backbone_name'``, ``'backbone_config'`` and ``'loss_func'``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.backbone = self.build_backbone(config)
        self.loss_func = self.build_loss(config)

    def build_backbone(self, config):
        """Instantiate the backbone named by ``config['backbone_name']``.

        The backbone class is looked up in the ``BACKBONE`` registry and
        constructed with ``config['backbone_config']``.

        Returns:
            The constructed backbone module.
        """
        backbone_class = BACKBONE[config['backbone_name']]
        model_config = config['backbone_config']
        backbone = backbone_class(model_config)
        # FIXME: pretrained weights are currently loaded only from inside the
        # backbone, not here. Without pretrained weights the detector fails
        # to reach good results. The disabled snippet below is the intended
        # loading path (drops the `fc` head, prefixes keys with 'resnet.').
        # state_dict = torch.load(config['pretrained'])
        # state_dict = {'resnet.' + k: v for k, v in state_dict.items() if 'fc' not in k}
        # backbone.load_state_dict(state_dict, False)
        # logger.info('Load pretrained model successfully!')
        return backbone

    def build_loss(self, config):
        """Instantiate the loss named by ``config['loss_func']``.

        The loss class is looked up in the ``LOSSFUNC`` registry and
        constructed with no arguments.
        """
        loss_class = LOSSFUNC[config['loss_func']]
        loss_func = loss_class()
        return loss_func

    def features(self, data_dict: dict) -> torch.Tensor:
        """Extract backbone features from ``data_dict['image']``."""
        return self.backbone.features(data_dict['image'])

    def classifier(self, features: torch.Tensor) -> torch.Tensor:
        """Map backbone features to class scores via the backbone's head."""
        return self.backbone.classifier(features)

    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute the training loss.

        Args:
            data_dict: Batch dict; ``data_dict['label']`` holds the targets.
            pred_dict: Output of :meth:`forward`; ``pred_dict['cls']`` holds
                the raw class scores.

        Returns:
            ``{'overall': loss}``.
        """
        label = data_dict['label']
        pred = pred_dict['cls']
        loss = self.loss_func(pred, label)
        loss_dict = {'overall': loss}
        return loss_dict

    def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute per-batch training metrics.

        Labels and predictions are detached from the autograd graph before
        being passed to ``calculate_metrics_for_train``.

        Returns:
            dict with keys ``'acc'``, ``'auc'``, ``'eer'`` and ``'ap'``.
        """
        label = data_dict['label']
        pred = pred_dict['cls']
        # compute metrics for batch data
        auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach())
        metric_batch_dict = {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap}
        return metric_batch_dict

    def forward(self, data_dict: dict, inference=False) -> dict:
        """Run feature extraction and classification on a batch.

        Args:
            data_dict: Batch dict; ``data_dict['image']`` is fed to the
                backbone.
            inference: Accepted for interface compatibility; not used by
                this detector.

        Returns:
            dict with
            ``'cls'``: raw (pre-softmax) classifier outputs,
            ``'prob'``: softmax probability of class index 1,
            ``'feat'``: backbone features.
        """
        features = self.features(data_dict)
        pred = self.classifier(features)
        # Softmax over dim=1 and keep column 1; assumes pred is
        # (batch, num_classes) with index 1 the positive/"fake" class --
        # TODO confirm against the label encoding.
        prob = torch.softmax(pred, dim=1)[:, 1]
        pred_dict = {'cls': pred, 'prob': prob, 'feat': features}
        return pred_dict