import torch
import torchtext
from torch import nn
from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER


def xlmr_base_encoder_model(
    num_classes: int = 2,  # default output classes = 2 (Bad, Good)
):
    """Build an XLM-R base encoder with a classification head, fully frozen.

    Args:
        num_classes: Number of output classes for the classification head.
            Defaults to 2 (Bad, Good).

    Returns:
        A ``(model, transforms)`` tuple: ``model`` is the object returned by
        ``XLMR_BASE_ENCODER.get_model`` with the classification head
        attached, and ``transforms`` is the object returned by
        ``XLMR_BASE_ENCODER.transform()``.
    """
    transforms = XLMR_BASE_ENCODER.transform()

    # Honour the caller's ``num_classes`` instead of a hard-coded 2.
    # input_dim=768 is expected to match the encoder's embedding size
    # -- TODO confirm against the torchtext bundle if the encoder changes.
    classifier_head = RobertaClassificationHead(
        num_classes=num_classes,
        input_dim=768,
    )
    model = XLMR_BASE_ENCODER.get_model(head=classifier_head)

    # Disable gradients for every parameter of the assembled model.
    # NOTE(review): this also freezes the freshly created classification
    # head, so nothing is trainable until the caller re-enables gradients
    # on the parameters it wants to fine-tune -- confirm this is intended.
    for param in model.parameters():
        param.requires_grad = False

    return model, transforms