File size: 849 Bytes
3e60839
fd94ff6
3e60839
fd94ff6
 
 
 
 
 
3e60839
 
fd94ff6
 
 
3e60839
 
 
 
fd94ff6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel

class LlamaClassificationModel(PreTrainedModel):
    """Sequence classifier: pretrained backbone + sum pooling + linear head.

    The backbone is loaded with
    ``AutoModel.from_pretrained(config.model_path, config=config)``. Its last
    hidden states are summed over the (non-padding) sequence positions and
    projected to ``config.num_labels`` logits by a linear layer. When labels
    are supplied, the loss is ``nn.BCEWithLogitsLoss`` (an independent sigmoid
    per label).
    """

    # The backbone must NOT be stored as ``self.base_model``. PreTrainedModel
    # defines ``base_model`` as a read-only property returning
    # ``getattr(self, self.base_model_prefix, self)``. nn.Module would still
    # register the submodule, but attribute lookup hits the property first.
    # ``self.base_model(...)`` therefore resolved to ``self`` and recursed
    # forever inside ``forward``. Storing the backbone as ``self.model`` and
    # pointing ``base_model_prefix`` at it keeps ``model.base_model``
    # returning the backbone.
    base_model_prefix = "model"

    def __init__(self, config):
        """Build the backbone and classification head from ``config``.

        Reads ``config.model_path``, ``config.hidden_size`` and
        ``config.num_labels``.
        """
        super().__init__(config)  # PreTrainedModel stores self.config here.
        # NOTE(review): state_dict keys are now "model.*" (previously
        # "base_model.*"). Rename keys when loading checkpoints saved by the
        # old version of this class.
        self.model = AutoModel.from_pretrained(config.model_path, config=config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Compute classification logits, and the loss when labels are given.

        Args:
            input_ids: token ids, passed straight to the backbone.
            attention_mask: passed to the backbone and used to exclude padded
                positions from the pooled sum. Assumes the HF convention:
                shape (batch, seq_len), 1 = real token, 0 = padding.
                ``None`` sums over every position.
            labels: optional targets with the same shape as the logits.
                They are cast to float for ``BCEWithLogitsLoss``.

        Returns:
            dict with ``"loss"`` (``None`` when ``labels`` is None) and
            ``"logits"``.
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Assumed shape: (batch, seq_len, hidden_size).
        hidden = outputs.last_hidden_state
        if attention_mask is not None:
            # Padded positions still receive hidden states. Zero them so the
            # pooled vector does not depend on how much padding the batch has.
            hidden = hidden * attention_mask.unsqueeze(-1).to(hidden.dtype)
        summed_representation = hidden.sum(dim=1)
        logits = self.classifier(summed_representation)
        loss = None
        if labels is not None:
            loss_fn = nn.BCEWithLogitsLoss()
            # Compute the loss in float32 even if the model runs in half
            # precision. This is a no-op for float32 logits.
            loss = loss_fn(logits.float(), labels.float())
        return {"loss": loss, "logits": logits}