import torch
from torchinfo import summary
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = "cuda" if torch.cuda.is_available() else "cpu"

def model_input_wrapper(batch_size, sequence_length, tokenizer):
  """Builds dummy input_ids and attention_mask tensors for a forward pass (e.g. for torchinfo's summary)."""
  dummy_input_ids = torch.randint(0, tokenizer.vocab_size, (batch_size, sequence_length), dtype=torch.long)
  dummy_attention_mask = torch.ones(batch_size, sequence_length, dtype=torch.long)
  return {'input_ids': dummy_input_ids, 'attention_mask': dummy_attention_mask}

def create_roberta_model(output_shape: int = 10, device=device, print_summary=True):
  """Creates a HuggingFace roberta-base sequence-classification model with a partially frozen backbone.

  Args:
    output_shape: Number of output classes for the classification head
    device: A torch.device (or device string) to move the model to
    print_summary: A boolean to print the model summary

  Returns:
    A tuple of the model and tokenizer
  """
  tokenizer = AutoTokenizer.from_pretrained('roberta-base')
  model = AutoModelForSequenceClassification.from_pretrained('roberta-base')

  # Partially freeze the model to speed up training: freeze every parameter,
  # then unfreeze only the classification head.
  for param in model.parameters():
    param.requires_grad = False

  for param in model.classifier.parameters():
    param.requires_grad = True

  # Replace the output projection so the head matches the desired number of classes
  # (the new Linear layer's parameters are trainable by default).
  model.classifier.out_proj = torch.nn.Linear(in_features=768, out_features=output_shape)

  if print_summary:
    sample_inputs = model_input_wrapper(1, 128, tokenizer)
    print(summary(model,
                  input_data=sample_inputs,
                  verbose=0,
                  col_names=["input_size", "output_size", "num_params", "trainable"],
                  col_width=20,
                  row_settings=["var_names"]))

  return model.to(device), tokenizer
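

# --- Example usage: a minimal sketch of running inference with the returned model and tokenizer.
# The texts, batch size, max_length, and class count below are illustrative assumptions, not part of the original file.
if __name__ == "__main__":
  model, tokenizer = create_roberta_model(output_shape=10, device=device, print_summary=False)

  # Tokenize a couple of sample sentences and move the batch to the chosen device.
  texts = ["This movie was great!", "I did not enjoy this at all."]
  encoded = tokenizer(texts, padding=True, truncation=True, max_length=128, return_tensors="pt").to(device)

  # Forward pass without gradients; logits have shape (batch_size, output_shape).
  model.eval()
  with torch.no_grad():
    logits = model(**encoded).logits
  print(logits.shape)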