Spaces:
Runtime error
Runtime error
import torch | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Default compute device: the GPU when PyTorch can see one, else the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
def model_input_wrapper(batch_size, sequence_length, tokenizer, *, device=None):
    """Build a random dummy batch shaped like tokenizer output.

    Handy, e.g., to smoke-test a model forward pass without real text.

    Args:
        batch_size: Number of sequences in the batch.
        sequence_length: Number of token positions per sequence.
        tokenizer: Any object exposing ``vocab_size``; token ids are drawn
            uniformly from ``[0, tokenizer.vocab_size)``.
        device: Optional device for the returned tensors. ``None`` (the
            default) uses PyTorch's default device. Pass the model's device
            (e.g. the one given to ``create_roberta_model``) so that
            ``model(**batch)`` does not hit a CPU/GPU device mismatch.

    Returns:
        A dict with ``'input_ids'`` (random token ids) and
        ``'attention_mask'`` (all ones, so every position is attended), both
        ``torch.long`` tensors of shape ``(batch_size, sequence_length)``.
    """
    shape = (batch_size, sequence_length)
    dummy_input_ids = torch.randint(
        0, tokenizer.vocab_size, shape, dtype=torch.long, device=device
    )
    dummy_attention_mask = torch.ones(shape, dtype=torch.long, device=device)
    return {'input_ids': dummy_input_ids, 'attention_mask': dummy_attention_mask}
def create_roberta_model(output_shape: int = 10, device=device):
    """Create a roberta-base sequence classifier with a frozen encoder.

    Only the classification head (``model.classifier``) is left trainable;
    every other parameter has ``requires_grad=False`` to speed up training.
    Weights and tokenizer are loaded with ``from_pretrained('roberta-base')``
    (which may download them from the HuggingFace Hub on first use).

    Args:
        output_shape: Number of output logits (labels) produced by the
            classification head.
        device: Target device for the model; anything accepted by
            ``torch.nn.Module.to`` (e.g. ``"cuda"``, ``"cpu"`` or a
            ``torch.device``). Defaults to the module-level ``device``.

    Returns:
        A ``(model, tokenizer)`` tuple: the classifier moved to ``device``
        and the matching ``roberta-base`` tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained('roberta-base')
    # Size the head through `num_labels` instead of swapping
    # `classifier.out_proj` after loading. That keeps `model.config.num_labels`
    # equal to the real logits width. When `labels=` is passed, the model's
    # built-in loss reshapes logits with `num_labels`, so a mismatch (the
    # config default is 2) breaks training for any other output_shape.
    model = AutoModelForSequenceClassification.from_pretrained(
        'roberta-base', num_labels=output_shape
    )
    # Partial freeze to speed up training: freeze everything, then re-enable
    # gradients for the classification head only.
    for param in model.parameters():
        param.requires_grad = False
    for param in model.classifier.parameters():
        param.requires_grad = True
    return model.to(device), tokenizer