import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = "cuda" if torch.cuda.is_available() else "cpu"

def model_input_wrapper(batch_size, sequence_length, tokenizer):
  """Builds a batch of random dummy inputs sized to the tokenizer's vocabulary.

  Useful for tracing, profiling, or smoke-testing the model without real data.
  """
  dummy_input_ids = torch.randint(0, tokenizer.vocab_size, (batch_size, sequence_length), dtype=torch.long)
  dummy_attention_mask = torch.ones(batch_size, sequence_length, dtype=torch.long)
  return {'input_ids': dummy_input_ids, 'attention_mask': dummy_attention_mask}
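
# Example usage (a sketch; assumes the roberta-base tokenizer can be downloaded):
#   tok = AutoTokenizer.from_pretrained('roberta-base')
#   batch = model_input_wrapper(batch_size=2, sequence_length=8, tokenizer=tok)
#   batch['input_ids'].shape  # torch.Size([2, 8])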

def create_roberta_model(output_shape: int = 10, device=device):
  """Creates a HuggingFace roberta-base sequence-classification model.

  Args:
    output_shape: Number of output classes for the classification head.
    device: The device (string or torch.device) to move the model to.

  Returns:
    A tuple of (model, tokenizer).
  """
  tokenizer = AutoTokenizer.from_pretrained('roberta-base')
  model = AutoModelForSequenceClassification.from_pretrained('roberta-base')

  # Partial freeze to speed up training: the pretrained backbone is frozen
  # so that only the classification head receives gradient updates.
  for param in model.parameters():
    param.requires_grad = False

  for param in model.classifier.parameters():
    param.requires_grad = True

  # Replace the output projection to match the desired number of classes.
  # model.config.hidden_size is 768 for roberta-base; the fresh Linear layer
  # has requires_grad=True by default, so it stays trainable.
  model.classifier.out_proj = torch.nn.Linear(in_features=model.config.hidden_size, out_features=output_shape)

  return model.to(device), tokenizer
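
# Minimal smoke test (a sketch): builds the model, feeds it a dummy batch from
# model_input_wrapper, and checks that the replaced head emits `output_shape`
# logits per example.
if __name__ == '__main__':
  model, tokenizer = create_roberta_model(output_shape=10)
  model.eval()
  batch = model_input_wrapper(batch_size=2, sequence_length=16, tokenizer=tokenizer)
  batch = {k: v.to(device) for k, v in batch.items()}
  with torch.no_grad():
    logits = model(**batch).logits
  assert logits.shape == (2, 10)  # (batch_size, output_shape)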