ItsNotRohit commited on
Commit
b549ce0
·
1 Parent(s): b290c65

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +9 -12
model.py CHANGED
@@ -4,20 +4,17 @@ import torchvision
4
  from torch import nn
5
 
6
 
7
- def create_effnetv2(num_classes:int=101,
8
  seed:int=42):
9
 
10
- weights = torchvision.models.EfficientNet_V2_M_Weights.DEFAULT
11
- transforms = weights.transforms()
12
- model = torchvision.models.efficientnet_v2_m(weights=weights)
13
 
14
- for param in model.parameters():
15
- param.requires_grad = False
16
 
17
- torch.manual_seed(seed)
18
- model.classifier = nn.Sequential(
19
- nn.Dropout(p=0.3, inplace=True),
20
- nn.Linear(in_features=1280, out_features=num_classes),
21
- )
22
 
23
- return model, transforms
 
4
  from torch import nn
5
 
6
 
7
+ def create_ViT(num_classes:int=126,
8
  seed:int=42):
9
 
10
+ weights = torchvision.models.ViT_B_16_Weights.DEFAULT
11
+ transforms = weights.transforms()
12
+ model = torchvision.models.vit_b_16(weights=weights)
13
 
14
+ for param in model.parameters():
15
+ param.requires_grad = False
16
 
17
+ torch.manual_seed(seed)
18
+ model.heads = nn.Sequential(nn.Linear(in_features=768, out_features=num_classes))
 
 
 
19
 
20
+ return model, transforms