import torch
import torch.nn as nn
import torchvision.transforms as transforms
import timm
from PIL import Image


class ViTClassifier(nn.Module):
    """Vision Transformer classifier that wraps a timm backbone and handles
    its own preprocessing, so it can be called directly on a PIL image."""

    def __init__(self, config, device='cuda', dtype=torch.float32):
        super().__init__()
        self.config = config
        self.device = device
        self.dtype = dtype

        # Build the ViT backbone from timm; weights are randomly initialized
        # (pretrained=False) and dropout rates come from the config.
        self.vit = timm.create_model(
            config['model']['variant'],
            pretrained=False,
            num_classes=config['model']['num_classes'],
            drop_rate=config['model']['hidden_dropout_prob'],
            attn_drop_rate=config['model']['attention_probs_dropout_prob'],
        ).to(device)

        # Replace the default classification head so its shape, device, and
        # dtype come from the config rather than from timm's defaults.
        self.vit.head = nn.Linear(
            in_features=config['model']['head']['in_features'],
            out_features=config['model']['head']['out_features'],
            bias=config['model']['head']['bias'],
            device=device,
            dtype=dtype,
        )

        if config['model']['freeze_backbone']:
            # self.vit.parameters() includes the freshly attached head, so
            # freeze everything first, then re-enable gradients for the head;
            # otherwise the assertion below would fail.
            for param in self.vit.parameters():
                param.requires_grad = False
            for param in self.vit.head.parameters():
                param.requires_grad = True

        for param in self.vit.head.parameters():
            assert param.requires_grad, "Model head should be trainable."

    def preprocess_input(self, x: Image.Image) -> torch.Tensor:
        """Convert a single PIL image into a normalized, batched tensor."""
        norm_mean = self.config['preprocessing']['norm_mean']
        norm_std = self.config['preprocessing']['norm_std']
        resize_size = self.config['preprocessing']['resize_size']
        crop_size = self.config['preprocessing']['crop_size']

        augment_list = [
            transforms.Resize(resize_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=norm_mean, std=norm_std),
            # ToTensor already yields a float tensor, so this only casts to
            # the model's target dtype (e.g. for half-precision inference).
            transforms.ConvertImageDtype(self.dtype),
        ]

        preprocess = transforms.Compose(augment_list)
        x = preprocess(x)
        x = x.unsqueeze(0)  # add a batch dimension for single-image inference
        return x

    def forward(self, x):
        x = self.preprocess_input(x).to(self.device)
        logits = self.vit(x)
        # torch.nn.functional.sigmoid is deprecated; torch.sigmoid is equivalent.
        return torch.sigmoid(logits)
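

# --- Usage sketch (illustrative only) ---
# A minimal example of driving ViTClassifier end to end. The config keys
# below mirror the ones this class reads; the concrete values (model variant,
# head feature sizes, normalization stats) are assumptions chosen to make
# the demo runnable, not values taken from any real config file.
if __name__ == '__main__':
    example_config = {
        'model': {
            'variant': 'vit_base_patch16_224',   # assumed timm variant
            'num_classes': 2,
            'hidden_dropout_prob': 0.1,
            'attention_probs_dropout_prob': 0.1,
            'head': {'in_features': 768, 'out_features': 2, 'bias': True},
            'freeze_backbone': True,
        },
        'preprocessing': {
            'norm_mean': [0.485, 0.456, 0.406],  # ImageNet statistics
            'norm_std': [0.229, 0.224, 0.225],
            'resize_size': 256,
            'crop_size': 224,
        },
    }

    model = ViTClassifier(example_config, device='cpu')
    model.eval()  # disable dropout for inference

    # Any RGB PIL image works; a solid-color placeholder keeps this self-contained.
    dummy_image = Image.new('RGB', (320, 320), color=(128, 128, 128))
    with torch.no_grad():
        probs = model(dummy_image)
    print(probs.shape)  # torch.Size([1, 2]) with the config above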