# OpenSight-CommunityForensics-Deepfake-ViT / modeling_vit_classifier.py
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import timm
import PIL.Image as Image


class ViTClassifier(nn.Module):
    """ViT-based binary classifier; takes a PIL image and returns a sigmoid probability."""

    def __init__(self, config, device='cuda', dtype=torch.float32):
        super().__init__()
        self.config = config
        self.device = device
        self.dtype = dtype

        # Create the ViT backbone without unsupported arguments.
        self.vit = timm.create_model(
            config['model']['variant'],
            pretrained=False,
            num_classes=config['model']['num_classes'],
            drop_rate=config['model']['hidden_dropout_prob'],
            attn_drop_rate=config['model']['attention_probs_dropout_prob'],
        ).to(device)

        # Replace the head with a custom linear head.
        self.vit.head = nn.Linear(
            in_features=config['model']['head']['in_features'],
            out_features=config['model']['head']['out_features'],
            bias=config['model']['head']['bias'],
            device=device,
            dtype=dtype,
        )

        if config['model']['freeze_backbone']:
            # Freeze every parameter, then re-enable gradients on the head.
            # The head is part of self.vit.parameters(), so freezing everything
            # and then asserting the head is trainable would always fail.
            for param in self.vit.parameters():
                param.requires_grad = False
            for param in self.vit.head.parameters():
                param.requires_grad = True

    def preprocess_input(self, x: Image.Image) -> torch.Tensor:
        """Resize, center-crop, normalize, and batch a single PIL image."""
        norm_mean = self.config['preprocessing']['norm_mean']
        norm_std = self.config['preprocessing']['norm_std']
        resize_size = self.config['preprocessing']['resize_size']
        crop_size = self.config['preprocessing']['crop_size']
        preprocess = transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=norm_mean, std=norm_std),
            transforms.ConvertImageDtype(self.dtype),
        ])
        x = preprocess(x)
        # Add a batch dimension: (C, H, W) -> (1, C, H, W).
        x = x.unsqueeze(0)
        return x

    def forward(self, x: Image.Image) -> torch.Tensor:
        x = self.preprocess_input(x).to(self.device)
        x = self.vit(x)
        # Map the raw logit(s) to probabilities in [0, 1].
        x = torch.sigmoid(x)
        return x
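

# A minimal usage sketch, not part of the original file. The config dict below is
# a hypothetical example of the schema this class reads; the backbone variant,
# head dimensions, and normalization statistics are assumptions (standard
# vit_base_patch16_224 / ImageNet values), and the repository's actual config
# may differ.
if __name__ == "__main__":
    example_config = {
        'model': {
            'variant': 'vit_base_patch16_224',   # assumed timm model name
            'num_classes': 1,
            'hidden_dropout_prob': 0.0,
            'attention_probs_dropout_prob': 0.0,
            # vit_base has an embedding dimension of 768.
            'head': {'in_features': 768, 'out_features': 1, 'bias': True},
            'freeze_backbone': True,
        },
        'preprocessing': {
            'norm_mean': [0.485, 0.456, 0.406],  # assumed ImageNet statistics
            'norm_std': [0.229, 0.224, 0.225],
            'resize_size': 256,
            'crop_size': 224,
        },
    }

    model = ViTClassifier(example_config, device='cpu')
    model.eval()

    # Placeholder input; in practice, load a real image with Image.open(path).
    image = Image.new('RGB', (512, 512))
    with torch.no_grad():
        prob = model(image)
    print(prob.shape, float(prob.squeeze()))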