File size: 700 Bytes
d222e8f
 
 
 
f98422c
 
 
 
 
 
 
f16c38b
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
import torchtext
from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER
from torch import nn
def xlmr_base_encoder_model(num_classes: int = 2):
    """Build an XLM-R base encoder with a classification head, fully frozen.

    Args:
        num_classes: Number of output classes for the classification head.
            Defaults to 2 (Bad, Good).

    Returns:
        A ``(model, transforms)`` tuple: the model obtained from
        ``XLMR_BASE_ENCODER.get_model`` with a ``RobertaClassificationHead``
        attached, and the object returned by ``XLMR_BASE_ENCODER.transform()``.
    """
    transforms = XLMR_BASE_ENCODER.transform()

    # Honour the caller's ``num_classes`` rather than a hard-coded 2, so the
    # head's output size matches the requested number of labels.
    # input_dim=768 is expected to match the XLM-R base encoder output width.
    classifier_head = RobertaClassificationHead(
        num_classes=num_classes,
        input_dim=768,
    )
    model = XLMR_BASE_ENCODER.get_model(head=classifier_head)

    # Disable gradients for every parameter of the assembled model.
    # NOTE(review): this loop also freezes the freshly created classifier
    # head, not only the encoder -- confirm that is intended (e.g. weights
    # are loaded from a checkpoint afterwards / inference-only use).
    for param in model.parameters():
        param.requires_grad = False

    return model, transforms