ahamedddd commited on
Commit
f98422c
·
1 Parent(s): fd092c5

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +8 -12
model.py CHANGED
@@ -2,15 +2,11 @@ import torch
2
  import torchtext
3
  from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER
4
  from torch import nn
5
- def xlmr_base_encoder_model(num_classes:int=2, # default output classes = 2 (Bad, Good)):
6
- # 1, 2, 3 Create EffNetB2 pretrained weights, transforms and model
7
- transforms = torchtext.models.XLMR_BASE_ENCODER.transform()
8
- classifier_head = RobertaClassificationHead(num_classes = 2, input_dim = 768)
9
- model = XLMR_BASE_ENCODER.get_model(head = classifier_head)
10
-
11
- # 4. Freeze all layers in the base model
12
- for param in model.parameters():
13
- param.requires_grad = False
14
-
15
-
16
- return model, transforms
 
2
  import torchtext
3
  from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER
4
  from torch import nn
5
+ def xlmr_base_encoder_model(num_classes:int=2, # default output classes = 2 (Bad, Good)
6
+ ):
7
+ transforms = XLMR_BASE_ENCODER.transform()
8
+ classifier_head = RobertaClassificationHead(num_classes = 2, input_dim = 768)
9
+ model = XLMR_BASE_ENCODER.get_model(head = classifier_head)
10
+ for param in model.parameters():
11
+ param.requires_grad = False
12
+ return model, transforms