LPX55 committed
Commit 22048f8 · verified · 1 parent: c0f6bb1

Create models.py (#3)


- Create models.py (0ff329ac3909664a717a4b4aeb09baac94f4143f)

Files changed (1):
1. models.py (+62, -0)
models.py ADDED
@@ -0,0 +1,62 @@
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import timm
import PIL.Image as Image


class ViTClassifier(nn.Module):
    def __init__(self, config, device='cuda', dtype=torch.float32):
        super().__init__()
        self.config = config
        self.device = device
        self.dtype = dtype

        # Create the ViT model without unsupported arguments
        self.vit = timm.create_model(
            config['model']['variant'],
            pretrained=False,
            num_classes=config['model']['num_classes'],
            drop_rate=config['model']['hidden_dropout_prob'],
            attn_drop_rate=config['model']['attention_probs_dropout_prob']
        ).to(device)

        # Replace the head with a custom head
        self.vit.head = nn.Linear(
            in_features=config['model']['head']['in_features'],
            out_features=config['model']['head']['out_features'],
            bias=config['model']['head']['bias'],
            device=device,
            dtype=dtype
        )

        if config['model']['freeze_backbone']:
            # Freeze the backbone, then re-enable gradients on the head:
            # self.vit.parameters() includes the head, so it must be unfrozen
            # again for the assertion below to hold.
            for param in self.vit.parameters():
                param.requires_grad = False
            for param in self.vit.head.parameters():
                param.requires_grad = True

        for param in self.vit.head.parameters():
            assert param.requires_grad, "Model head should be trainable."

    def preprocess_input(self, x):
        norm_mean = self.config['preprocessing']['norm_mean']
        norm_std = self.config['preprocessing']['norm_std']
        resize_size = self.config['preprocessing']['resize_size']
        crop_size = self.config['preprocessing']['crop_size']

        augment_list = [
            transforms.Resize(resize_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=norm_mean, std=norm_std),
            transforms.ConvertImageDtype(self.dtype),
        ]

        preprocess = transforms.Compose(augment_list)
        x = preprocess(x)
        x = x.unsqueeze(0)  # add batch dimension
        return x

    def forward(self, x):
        x = self.preprocess_input(x).to(self.device)
        x = self.vit(x)
        x = torch.sigmoid(x)
        return x
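
For reference, here is a minimal usage sketch of the new ViTClassifier (not part of this commit). The config values, the 'vit_base_patch16_224' variant, the ImageNet normalization statistics, and the example.jpg path are illustrative assumptions; the repository's actual configuration may differ.

# Usage sketch with an assumed config; keys mirror those read by ViTClassifier.
import torch
from PIL import Image

from models import ViTClassifier

config = {
    'model': {
        'variant': 'vit_base_patch16_224',      # assumed timm variant
        'num_classes': 1,
        'hidden_dropout_prob': 0.1,
        'attention_probs_dropout_prob': 0.1,
        'head': {'in_features': 768, 'out_features': 1, 'bias': True},
        'freeze_backbone': True,
    },
    'preprocessing': {
        'norm_mean': [0.485, 0.456, 0.406],     # ImageNet stats (assumed)
        'norm_std': [0.229, 0.224, 0.225],
        'resize_size': 256,
        'crop_size': 224,
    },
}

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ViTClassifier(config, device=device).eval()

image = Image.open('example.jpg').convert('RGB')  # hypothetical input image
with torch.no_grad():
    probs = model(image)                          # sigmoid scores, shape (1, out_features)
print(probs)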