# Web-page header residue (uploader, commit message, commit hash), kept as comments:
# LukeOLuck's picture
# Remove torchinfo
# 6c038be
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Default compute device for this module: use the GPU when torch reports that
# CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def model_input_wrapper(batch_size, sequence_length, tokenizer):
    """Build a dummy model input batch of random token ids.

    Args:
        batch_size: Number of rows in each generated tensor.
        sequence_length: Number of columns in each generated tensor.
        tokenizer: Object exposing ``vocab_size``; used as the exclusive
            upper bound for the random token ids.

    Returns:
        A dict with ``'input_ids'`` (random integers in ``[0, vocab_size)``)
        and ``'attention_mask'`` (all ones), both ``torch.long`` tensors of
        shape ``(batch_size, sequence_length)``.
    """
    shape = (batch_size, sequence_length)
    token_ids = torch.randint(
        low=0,
        high=tokenizer.vocab_size,
        size=shape,
        dtype=torch.long,
    )
    # Every position is marked as a real (non-padding) token.
    mask = torch.ones(shape, dtype=torch.long)
    return dict(input_ids=token_ids, attention_mask=mask)
def create_roberta_model(output_shape: int = 10, device=device):
    """Creates a HuggingFace roberta-base classifier with a trainable head.

    The pretrained ``roberta-base`` checkpoint is loaded, every parameter is
    frozen, and only the classification head (``model.classifier``) is left
    trainable. The head's final projection is replaced by a fresh
    ``torch.nn.Linear`` producing ``output_shape`` logits.

    Args:
        output_shape: Number of output classes (logits) of the classifier.
        device: Device the model is moved to via ``model.to``. Defaults to
            the module-level ``device`` value.

    Returns:
        A tuple of the model (already moved to ``device``) and the tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained('roberta-base')
    # Pass num_labels so the model's own label count (config.num_labels /
    # model.num_labels) matches the width of the head built below; otherwise
    # the model keeps its default label count while emitting `output_shape`
    # logits, which breaks the built-in loss when `labels` are supplied.
    model = AutoModelForSequenceClassification.from_pretrained(
        'roberta-base',
        num_labels=output_shape,
    )

    # Partial freeze to speed up training: freeze everything first ...
    for param in model.parameters():
        param.requires_grad = False

    # ... then install a fresh output projection. Its input width is read from
    # the layer it replaces instead of hard-coding the hidden size.
    head = model.classifier
    head.out_proj = torch.nn.Linear(
        in_features=head.out_proj.in_features,
        out_features=output_shape,
    )

    # Only the classification head (including the new projection) is trained.
    for param in head.parameters():
        param.requires_grad = True

    return model.to(device), tokenizer