Spaces:
Sleeping
Sleeping
Update model.py
Browse files
model.py
CHANGED
@@ -2,15 +2,11 @@ import torch
|
|
2 |
import torchtext
|
3 |
from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER
|
4 |
from torch import nn
|
5 |
-
def xlmr_base_encoder_model(num_classes:int=2, # default output classes = 2 (Bad, Good)
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
param.requires_grad = False
|
14 |
-
|
15 |
-
|
16 |
-
return model, transforms
|
|
|
2 |
import torchtext
|
3 |
from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER
|
4 |
from torch import nn
|
5 |
+
def xlmr_base_encoder_model(num_classes:int=2, # default output classes = 2 (Bad, Good)
|
6 |
+
):
|
7 |
+
transforms = XLMR_BASE_ENCODER.transform()
|
8 |
+
classifier_head = RobertaClassificationHead(num_classes = 2, input_dim = 768)
|
9 |
+
model = XLMR_BASE_ENCODER.get_model(head = classifier_head)
|
10 |
+
for param in model.parameters():
|
11 |
+
param.requires_grad = False
|
12 |
+
return model, transforms
|
|
|
|
|
|
|
|