Spaces:
Running
Running
import torch | |
import torchtext | |
from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER | |
from torch import nn | |
def xlmr_base_encoder_model(num_classes: int = 2):
    """Build an XLM-R base encoder with a classification head attached.

    Args:
        num_classes: Number of output classes for the classification head.
            Defaults to 2 (Bad, Good).

    Returns:
        A ``(model, transforms)`` tuple: the model produced by
        ``XLMR_BASE_ENCODER.get_model`` with the classification head
        attached, and the text transform returned by
        ``XLMR_BASE_ENCODER.transform()``.
    """
    # 1. Text transform that ships with the XLM-R base bundle.
    transforms = XLMR_BASE_ENCODER.transform()

    # 2. Classification head sized from ``num_classes`` (previously hard-coded
    #    to 2, which silently ignored the argument). ``input_dim`` stays at
    #    768 as in the original code.
    classifier_head = RobertaClassificationHead(
        num_classes=num_classes,
        input_dim=768,
    )

    # 3. Assemble the encoder bundle with the head attached.
    model = XLMR_BASE_ENCODER.get_model(head=classifier_head)

    # 4. Disable gradients for every parameter of the assembled model.
    # NOTE(review): this runs after the head is attached, so the head's
    # parameters are presumably frozen as well -- fine for inference, but
    # confirm before using this model for fine-tuning.
    for param in model.parameters():
        param.requires_grad = False

    return model, transforms