Spaces:
Sleeping
Sleeping
import torch | |
import torchtext | |
from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER | |
from torch import nn | |
def xlmr_base_encoder_model(num_classes: int = 2):  # default output classes = 2 (Bad, Good)
    """Build an XLM-R base encoder with a classification head, all parameters frozen.

    Args:
        num_classes: Number of output classes for the classification head.
            Defaults to 2 (Bad, Good).

    Returns:
        A ``(model, transforms)`` tuple. ``model`` is the XLM-R base encoder
        with a ``RobertaClassificationHead`` attached. ``transforms`` is the
        text transform returned by ``XLMR_BASE_ENCODER.transform()``.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    transforms = XLMR_BASE_ENCODER.transform()
    # Previously num_classes was hard-coded to 2 here, silently ignoring the
    # argument. input_dim=768 is the hidden size of the XLM-R *base* encoder.
    classifier_head = RobertaClassificationHead(num_classes=num_classes, input_dim=768)
    model = XLMR_BASE_ENCODER.get_model(head=classifier_head)

    # NOTE(review): this freezes every parameter, including the newly attached
    # head, so nothing is trainable as returned. Presumably the model is used
    # for inference with separately loaded weights. Confirm before fine-tuning.
    for param in model.parameters():
        param.requires_grad = False

    return model, transforms