panda1835 committed on
Commit
0ddbb98
·
verified ·
1 Parent(s): dbc83fe

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +8 -46
models.py CHANGED
@@ -4,56 +4,18 @@ import torchvision
4
  import torchvision.transforms as T
5
  from PIL import Image
6
 
7
- class DinoVisionTransformerClassifier(nn.Module):
8
- def __init__(self, num_classes):
9
- super(DinoVisionTransformerClassifier, self).__init__()
 
10
 
11
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
-
13
- # Workaround to bypass HTTP Error 403 rate limit exceeded
14
- torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
15
- self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_lc")
16
-
17
- self.model.linear_head = nn.Sequential(
18
- nn.Linear(3840, 512, bias=True),
19
  nn.ReLU(),
20
  nn.Linear(512, 256, bias=True),
21
  nn.ReLU(),
22
- nn.Linear(256, num_classes, bias=True)
23
  )
24
 
25
- self.model.to(self.device)
26
-
27
- self.transform_image = T.Compose([
28
- T.Resize((224, 224)),
29
- T.ToTensor(),
30
- T.Normalize(mean=[0.485, 0.456, 0.406],
31
- std=[0.229, 0.224, 0.225])
32
- ])
33
-
34
- self.model_name = "dinov2"
35
-
36
- def load_image_from_filepath(self, img: str) -> torch.Tensor:
37
- """
38
- Load an image as filepath and return a tensor that can be used as an input to model.
39
- """
40
- img = Image.open(img).convert('RGB')
41
-
42
- transformed_img = self.transform_image(img)[:3].unsqueeze(0).to(self.device)
43
-
44
- return transformed_img
45
-
46
- def load_image_from_pillowimage(self, img: Image.Image) -> torch.Tensor:
47
- """
48
- Load an image as Pillow Image and return a tensor that can be used as an input to model.
49
- """
50
- transformed_img = self.transform_image(img)[:3].unsqueeze(0).to(self.device)
51
-
52
- return transformed_img
53
-
54
  def forward(self, x):
55
- if isinstance(x, str):
56
- x = self.load_image_from_filepath(x)
57
- if isinstance(x, Image.Image):
58
- x = self.load_image_from_pillowimage(x)
59
- return self.model(x)
 
4
  import torchvision.transforms as T
5
  from PIL import Image
6
 
7
+ class LinearClassifier(torch.nn.Module):
8
+ def __init__(self, input_dim, output_dim):
9
+ super(LinearClassifier, self).__init__()
10
+ num_classes = len(index_to_species.keys())
11
 
12
+ self.linear_head = nn.Sequential(
13
+ nn.Linear(input_dim, 512, bias=True),
 
 
 
 
 
 
14
  nn.ReLU(),
15
  nn.Linear(512, 256, bias=True),
16
  nn.ReLU(),
17
+ nn.Linear(256, output_dim, bias=True)
18
  )
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def forward(self, x):
21
+ return self.linear_head(x)