panda1835 committed on
Commit
a47307e
·
1 Parent(s): a230b7d

Create models.py

Browse files
Files changed (1) hide show
  1. models.py +57 -0
models.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision
4
+ import torchvision.transforms as T
5
+ from PIL import Image
6
+
7
class DinoVisionTransformerClassifier(nn.Module):
    """Image classifier built on the DINOv2 ``dinov2_vitb14_lc`` hub model.

    The hub model's ``linear_head`` is replaced by a three-layer MLP whose
    final layer emits ``num_classes`` outputs. The instance also carries the
    preprocessing pipeline, so ``forward`` accepts a file path, a Pillow
    image, or an already prepared tensor.

    Attributes:
        device: ``cuda`` when available at construction time, otherwise ``cpu``.
        model: The hub model with the replaced ``linear_head``.
        transform_image: Resize to 224x224, tensor conversion, and
            normalisation with the mean/std values listed below.
        model_name: Fixed identifier string ``"dinov2"``.
    """

    def __init__(self, num_classes):
        """Load the hub model and attach a fresh head with ``num_classes`` outputs.

        Args:
            num_classes: Output width of the final linear layer.
        """
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # NOTE: torch.hub.load resolves the repository at call time, so
        # construction may need network access or a populated hub cache.
        self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_lc")

        # The input width (3840) must equal the feature size the hub wrapper
        # feeds into its linear head.
        # NOTE(review): presumably (1 + 4 layers) * 768 for ViT-B/14 - confirm
        # against the dinov2 hub code before changing the backbone.
        self.model.linear_head = nn.Sequential(
            nn.Linear(3840, 512, bias=True),
            nn.ReLU(),
            nn.Linear(512, 256, bias=True),
            nn.ReLU(),
            nn.Linear(256, num_classes, bias=True),
        )

        # Moved after the head is swapped so the new layers land on the
        # same device as the rest of the model.
        self.model.to(self.device)

        self.transform_image = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])

        self.model_name = "dinov2"

    def load_image_from_filepath(self, img: str) -> torch.Tensor:
        """Load an image from a file path and return a model-ready tensor.

        Args:
            img: Path of the image file to open.

        Returns:
            The transformed image with a leading batch dimension, placed on
            ``self.device``.
        """
        # The context manager releases the file handle promptly; convert()
        # returns an independent image, so it stays usable after the close.
        with Image.open(img) as opened:
            rgb_img = opened.convert('RGB')

        return self.load_image_from_pillowimage(rgb_img)

    def load_image_from_pillowimage(self, img: Image.Image) -> torch.Tensor:
        """Turn a Pillow image into a model-ready tensor.

        Args:
            img: Pillow image in any mode; non-RGB images are converted.

        Returns:
            The transformed image with a leading batch dimension, placed on
            ``self.device``.
        """
        # Normalisation is configured for three channels, so grayscale,
        # palette or RGBA inputs must be converted first (the file-path
        # loader always did this; the Pillow path previously did not).
        if img.mode != 'RGB':
            img = img.convert('RGB')

        # The [:3] slice is kept as a guard in case transform_image is
        # swapped for a pipeline that yields extra channels.
        transformed_img = self.transform_image(img)[:3].unsqueeze(0).to(self.device)

        return transformed_img

    def forward(self, x):
        """Run the model on a file path, a Pillow image, or a tensor.

        Args:
            x: A ``str`` path or a ``PIL.Image.Image`` (both are preprocessed
                here), or any other value, which is passed to the model
                unchanged.

        Returns:
            Whatever the wrapped hub model returns for the prepared input.
        """
        if isinstance(x, str):
            x = self.load_image_from_filepath(x)
        elif isinstance(x, Image.Image):
            x = self.load_image_from_pillowimage(x)
        return self.model(x)