# model.py — uploaded by ppak10, commit fd94ff6 ("Updates model.py."), 849 bytes.
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel
class LlamaClassificationModel(PreTrainedModel):
    """Sequence classifier: pretrained backbone + masked sum pooling + linear head.

    The backbone is loaded with
    ``AutoModel.from_pretrained(config.model_path, config=config)``. Its
    ``last_hidden_state`` is summed over the sequence dimension, with padding
    positions excluded when an attention mask is given. A single ``nn.Linear``
    maps the pooled vector to ``config.num_labels`` logits.

    When ``labels`` are supplied, the loss is ``BCEWithLogitsLoss``: every
    label is an independent binary target (a multi-label setup).

    ``config`` must provide ``model_path``, ``hidden_size`` and ``num_labels``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Registered under the name "base_model" so checkpoint keys stay
        # "base_model.*". NOTE: transformers' PreTrainedModel defines a
        # read-only ``base_model`` property (returns
        # ``getattr(self, self.base_model_prefix, self)``). Being a data
        # descriptor, it wins over this registered submodule on attribute
        # reads, so ``self.base_model`` would NOT return the backbone.
        # ``forward`` therefore goes through ``_backbone`` instead.
        self.base_model = AutoModel.from_pretrained(config.model_path, config=config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.config = config

    @property
    def _backbone(self):
        """Return the pretrained encoder registered under the name ``"base_model"``."""
        return self._modules["base_model"]

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Encode ``input_ids`` and return classification logits and optional loss.

        Args:
            input_ids: token ids forwarded to the backbone; presumably
                (batch, seq_len) — TODO confirm against the data collator.
            attention_mask: forwarded to the backbone and also used to drop
                padding positions from the pooled sum; assumes a
                (batch, seq_len) 0/1 (or bool) mask. ``None`` sums over every
                position.
            labels: optional multi-hot targets for ``BCEWithLogitsLoss``,
                shaped like the logits. A 1-D (batch,) tensor is also accepted
                when ``num_labels == 1``.

        Returns:
            dict with ``"loss"`` (``None`` when ``labels`` is ``None``) and
            ``"logits"`` of shape (batch, num_labels).
        """
        outputs = self._backbone(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state  # (batch, seq_len, hidden_size)

        if attention_mask is not None:
            # Zero padding positions before summing; otherwise pad-token states
            # leak into the representation, and a sequence's logits change with
            # the amount of padding its batch required.
            hidden = hidden * attention_mask.unsqueeze(-1).to(hidden.dtype)
        pooled = hidden.sum(dim=1)

        # Guard against a dtype mismatch between backbone output and head
        # (e.g. bf16 backbone, float32 head); a no-op when they already agree.
        logits = self.classifier(pooled.to(self.classifier.weight.dtype))

        loss = None
        if labels is not None:
            targets = labels.float()
            if targets.dim() == logits.dim() - 1:
                # (batch,) -> (batch, 1): only matches logits when num_labels == 1;
                # any other mismatch still raises inside BCEWithLogitsLoss.
                targets = targets.unsqueeze(-1)
            loss_fn = nn.BCEWithLogitsLoss()
            # Compute the loss in float32 for numerical stability.
            loss = loss_fn(logits.float(), targets)
        return {"loss": loss, "logits": logits}