|
from transformers import PreTrainedModel,LlamaConfig,LlamaModel |
|
import torch.nn as nn |
|
import torch |
|
from typing import Optional |
|
|
|
class LlamaRewardModel(PreTrainedModel):
    """Scalar reward model: a Llama backbone followed by a linear value head.

    The backbone's final hidden state at the last attended position of each
    sequence is projected to a single score per sequence.
    """

    config_class = LlamaConfig

    def __init__(self, config):
        """Build the backbone and value head from ``config``.

        Args:
            config: A ``LlamaConfig``; ``config.hidden_size`` sets the input
                width of the value head.
        """
        super().__init__(config)
        self.model = LlamaModel(config)
        # Projects one hidden vector to one scalar score.
        self.value_head = nn.Linear(config.hidden_size, 1)
        # NOTE(review): Hugging Face models conventionally end __init__ with
        # self.post_init(); confirm that omitting it here is intentional.

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Score each sequence in the batch.

        Args:
            input_ids: Token ids, assumed shape ``(batch, seq_len)``.
            attention_mask: Optional mask aligned with ``input_ids`` along
                dim 1 (nonzero for real tokens, 0 for padding). When omitted,
                every sequence is scored at its final position.

        Returns:
            Tensor of shape ``(batch,)`` with one score per sequence.
        """
        outputs = self.model(input_ids, attention_mask=attention_mask)
        # In HF LlamaModel, the first output is last_hidden_state (the normed
        # final hidden state, i.e. what hidden_states[-1] holds). Reading it
        # directly avoids asking the backbone to retain every layer's
        # activations, and indexing by position works for both ModelOutput
        # and plain-tuple returns.
        hidden = outputs[0]

        if attention_mask is None:
            last_hidden = hidden[:, -1]
        else:
            # cumsum peaks at the last attended token and stays flat through
            # any trailing padding; argmax returns the first peak, so this
            # selects the last real token under both left and right padding.
            # A row with no attended tokens falls back to position 0.
            last_index = attention_mask.cumsum(dim=1).argmax(dim=1)
            # Keep indices on the same device as the activations (matters
            # when the backbone is sharded across devices).
            last_index = last_index.to(hidden.device)
            batch_index = torch.arange(hidden.size(0), device=hidden.device)
            last_hidden = hidden[batch_index, last_index]

        return self.value_head(last_hidden).squeeze(-1)
|
|