import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = "cuda" if torch.cuda.is_available() else "cpu"


def model_input_wrapper(batch_size, sequence_length, tokenizer):
    """Builds a batch of dummy inputs drawn from the tokenizer's vocabulary."""
    dummy_input_ids = torch.randint(0, tokenizer.vocab_size, (batch_size, sequence_length), dtype=torch.long)
    dummy_attention_mask = torch.ones(batch_size, sequence_length, dtype=torch.long)
    return {'input_ids': dummy_input_ids, 'attention_mask': dummy_attention_mask}


def create_roberta_model(output_shape: int = 10, device=device):
    """Creates a HuggingFace roberta-base model for sequence classification.

    Args:
        output_shape: Number of output classes for the classification head
        device: A torch.device (or device string) to move the model to

    Returns:
        A tuple of the model and tokenizer
    """
    tokenizer = AutoTokenizer.from_pretrained('roberta-base')
    model = AutoModelForSequenceClassification.from_pretrained('roberta-base')

    # Partial freeze to speed up training: only the classification head is trainable
    for param in model.parameters():
        param.requires_grad = False
    for param in model.classifier.parameters():
        param.requires_grad = True

    # Replace the output projection so the head matches the desired number of classes
    model.classifier.out_proj = torch.nn.Linear(in_features=768, out_features=output_shape)
    return model.to(device), tokenizer
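
# Usage sketch (the batch size of 4 and sequence length of 128 below are
# illustrative assumptions, not taken from the original): build the model,
# create a dummy batch with model_input_wrapper, and run a forward pass to
# confirm the logits shape matches output_shape.
if __name__ == "__main__":
    model, tokenizer = create_roberta_model(output_shape=10)
    dummy_batch = model_input_wrapper(batch_size=4, sequence_length=128, tokenizer=tokenizer)
    dummy_batch = {k: v.to(device) for k, v in dummy_batch.items()}
    with torch.no_grad():
        outputs = model(**dummy_batch)
    print(outputs.logits.shape)  # expected: torch.Size([4, 10])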