ItsNotRohit commited on
Commit
c903d33
·
1 Parent(s): b183f68

Update model.py

Browse files

Previous
import torch
import torchvision

from torch import nn


def create_ViT(num_classes:int=126,
seed:int=42):

weights = torchvision.models.ViT_B_16_Weights.DEFAULT
transforms = weights.transforms()
model = torchvision.models.vit_b_16(weights=weights)

for param in model.parameters():
param.requires_grad = False

torch.manual_seed(seed)
model.heads = nn.Sequential(nn.Linear(in_features=768, out_features=num_classes))

return model, transforms

Files changed (1) hide show
  1. model.py +5 -17
model.py CHANGED
@@ -1,20 +1,8 @@
1
- import torch
2
- import torchvision
3
 
4
- from torch import nn
5
 
 
 
6
 
7
- def create_ViT(num_classes:int=126,
8
- seed:int=42):
9
-
10
- weights = torchvision.models.ViT_B_16_Weights.DEFAULT
11
- transforms = weights.transforms()
12
- model = torchvision.models.vit_b_16(weights=weights)
13
-
14
- for param in model.parameters():
15
- param.requires_grad = False
16
-
17
- torch.manual_seed(seed)
18
- model.heads = nn.Sequential(nn.Linear(in_features=768, out_features=num_classes))
19
-
20
- return model, transforms
 
1
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
 
2
 
3
+ def create_ViT():
4
 
5
+ extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
6
+ model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
7
 
8
+ return model