Spaces:
Runtime error
Runtime error
File size: 1,193 Bytes
e418f09 6c038be e418f09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Default device string: "cuda" when torch reports CUDA as available, otherwise "cpu".
# Used as the default for create_roberta_model's `device` argument.
device = "cuda" if torch.cuda.is_available() else "cpu"
def model_input_wrapper(batch_size, sequence_length, tokenizer):
    """Build a random dummy batch shaped like tokenized model input.

    Args:
        batch_size: Number of sequences in the batch.
        sequence_length: Number of token ids per sequence.
        tokenizer: Object exposing a ``vocab_size`` attribute; token ids are
            drawn at random from ``[0, vocab_size)``.

    Returns:
        A dict with ``'input_ids'`` (random token ids) and
        ``'attention_mask'`` (all ones), both ``torch.long`` tensors of shape
        ``(batch_size, sequence_length)``.
    """
    shape = (batch_size, sequence_length)
    return {
        'input_ids': torch.randint(
            low=0, high=tokenizer.vocab_size, size=shape, dtype=torch.long
        ),
        'attention_mask': torch.ones(shape, dtype=torch.long),
    }
def create_roberta_model(output_shape: int = 10, device=device):
    """Creates a HuggingFace roberta-base classifier with a frozen encoder.

    Only the classification head is left trainable; every other parameter is
    frozen to speed up training.

    Args:
        output_shape: Number of output classes, i.e. the width of the
            classification head's final projection.
        device: Target device for the model (a torch.device or device string).
            Defaults to the module-level ``device``.

    Returns:
        A tuple of the model (moved to ``device``) and its tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained('roberta-base')

    # Let from_pretrained size the head via num_labels instead of swapping
    # classifier.out_proj afterwards: this keeps model.num_labels and
    # model.config.num_labels in agreement with the head (the built-in loss
    # in forward(labels=...) relies on num_labels) and avoids hard-coding
    # the encoder hidden size.
    model = AutoModelForSequenceClassification.from_pretrained(
        'roberta-base', num_labels=output_shape
    )

    # Partial freeze to speed up training: freeze everything, then re-enable
    # gradients for the classification head only.
    model.requires_grad_(False)
    model.classifier.requires_grad_(True)

    return model.to(device), tokenizer
|