import torch.nn as nn
from transformers import AutoModel, PreTrainedModel


class LlamaClassificationModel(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)  # PreTrainedModel already stores `config` on `self.config`
        # Load the pretrained Llama backbone from the checkpoint named by the
        # custom `config.model_path` attribute.
        self.base_model = AutoModel.from_pretrained(config.model_path, config=config)
        # Linear head mapping the pooled hidden state to one logit per label.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        # Zero out padding positions before pooling so padded tokens do not
        # leak into the summed representation.
        mask = attention_mask.unsqueeze(-1).to(outputs.last_hidden_state.dtype)
        summed_representation = (outputs.last_hidden_state * mask).sum(dim=1)
        logits = self.classifier(summed_representation)

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss treats each label as an independent binary
            # target (a multi-label setup) and expects float targets.
            loss_fn = nn.BCEWithLogitsLoss()
            loss = loss_fn(logits, labels.float())
        return {"loss": loss, "logits": logits}
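
# A minimal usage sketch, assuming a Llama checkpoint is reachable
# ("meta-llama/Llama-2-7b-hf" below is a placeholder name, and `model_path`
# is the custom config attribute the class reads, not a standard HF field):
import torch
from transformers import AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf", num_labels=4)
config.model_path = "meta-llama/Llama-2-7b-hf"
model = LlamaClassificationModel(config)

tokenizer = AutoTokenizer.from_pretrained(config.model_path)
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token
batch = tokenizer(["an example input"], return_tensors="pt", padding=True)
labels = torch.zeros(1, config.num_labels)  # multi-label indicator targets in {0, 1}

out = model(batch["input_ids"], batch["attention_mask"], labels=labels)
print(out["loss"], out["logits"].shape)  # scalar loss, logits of shape (1, num_labels)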