import torch
import torch.nn as nn
import torchvision.transforms as transforms
import timm
import PIL.Image as Image


class ViTClassifier(nn.Module):
    """Vision Transformer classifier built from a timm backbone and a custom linear head.

    The constructor reads the following keys from ``config``:

    * ``config['model']``: ``variant``, ``num_classes``, ``hidden_dropout_prob``,
      ``attention_probs_dropout_prob``, ``freeze_backbone`` and
      ``head`` (``in_features``, ``out_features``, ``bias``).
    * ``config['preprocessing']``: ``norm_mean``, ``norm_std``, ``resize_size``,
      ``crop_size`` (read on every call to :meth:`preprocess_input`).

    Args:
        config: Nested mapping with the keys listed above.
        device: Device the backbone, the head and the inputs are placed on.
        dtype: Floating-point dtype used for the backbone, the head and the
            preprocessed input.
    """

    def __init__(self, config, device='cuda', dtype=torch.float32):
        super().__init__()
        self.config = config
        # NOTE: device/dtype are remembered as passed in; they are not updated
        # if the module is later moved with ``.to(...)``.
        self.device = device
        self.dtype = dtype

        model_cfg = config['model']

        # Create the ViT model without unsupported arguments.
        # The backbone is moved to *both* device and dtype so that it matches
        # the head and the tensors produced by ``preprocess_input``.
        self.vit = timm.create_model(
            model_cfg['variant'],
            pretrained=False,
            num_classes=model_cfg['num_classes'],
            drop_rate=model_cfg['hidden_dropout_prob'],
            attn_drop_rate=model_cfg['attention_probs_dropout_prob'],
        ).to(device=device, dtype=dtype)

        # Replace the head with a custom head.
        head_cfg = model_cfg['head']
        self.vit.head = nn.Linear(
            in_features=head_cfg['in_features'],
            out_features=head_cfg['out_features'],
            bias=head_cfg['bias'],
            device=device,
            dtype=dtype,
        )

        if model_cfg['freeze_backbone']:
            # ``self.vit.parameters()`` also yields the head's parameters, so
            # freeze everything first and then explicitly unfreeze the head.
            for param in self.vit.parameters():
                param.requires_grad = False
            for param in self.vit.head.parameters():
                param.requires_grad = True

        # Internal invariant: the head is always trainable.
        for param in self.vit.head.parameters():
            assert param.requires_grad, "Model head should be trainable."

    def preprocess_input(self, x):
        """Turn a single image into a normalized batch of size one.

        The pipeline is: resize, center crop, ``ToTensor``, normalize with the
        configured mean/std, convert to ``self.dtype``, then add a leading
        batch dimension.

        Args:
            x: A single image accepted by ``transforms.ToTensor``
                (presumably a ``PIL.Image`` -- confirm against callers).

        Returns:
            The preprocessed tensor with an extra leading dimension of size 1.
        """
        prep_cfg = self.config['preprocessing']
        norm_mean = prep_cfg['norm_mean']
        norm_std = prep_cfg['norm_std']
        resize_size = prep_cfg['resize_size']
        crop_size = prep_cfg['crop_size']

        preprocess = transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=norm_mean, std=norm_std),
            transforms.ConvertImageDtype(self.dtype),
        ])

        x = preprocess(x)
        # The model consumes batches; wrap the single image in a batch of one.
        return x.unsqueeze(0)

    def forward(self, x):
        """Preprocess one image, run the ViT and return sigmoid activations.

        Args:
            x: A single raw image; see :meth:`preprocess_input`.

        Returns:
            Element-wise sigmoid of the model output for a batch of one
            (presumably shaped ``(1, out_features)`` -- confirm against the
            chosen timm variant).
        """
        x = self.preprocess_input(x).to(self.device)
        x = self.vit(x)
        # torch.sigmoid is equivalent to torch.nn.functional.sigmoid and does
        # not emit the deprecation warning raised by some PyTorch releases.
        return torch.sigmoid(x)