File size: 1,083 Bytes
be55357 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
from transformers import PreTrainedModel,LlamaConfig,LlamaModel
import torch.nn as nn
import torch
from typing import Optional
class LlamaRewardModel(PreTrainedModel):
    """Scalar reward model: a ``LlamaModel`` backbone plus a linear value head.

    The value head reads the backbone's final hidden state at the last real
    (non-padding) token of each sequence and projects it to a single score.
    """

    config_class = LlamaConfig

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.model = LlamaModel(config)
        # Projects one token's final hidden state to a scalar reward.
        self.value_head = nn.Linear(config.hidden_size, 1)
        # Standard HF post-construction hook: applies ``_init_weights`` below
        # so the value head (created here, outside the backbone) gets a
        # well-defined initialization.
        self.post_init()

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize *module* with the Llama-style scheme.

        ``from_pretrained`` uses this hook for parameters that are absent from
        the checkpoint (e.g. ``value_head`` when starting from a plain Llama
        checkpoint). The bare ``PreTrainedModel`` hook may be a no-op in some
        transformers versions, which could leave such weights uninitialized.
        """
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(self,
                input_ids: torch.LongTensor,
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Score every sequence in the batch.

        Args:
            input_ids: Token ids, ``(batch, seq_len)``.
            attention_mask: Optional mask of the same shape; non-zero entries
                mark real tokens. Left- and right-padded batches are both
                supported. If omitted, the last position of each row is used.

        Returns:
            Tensor of shape ``(batch,)`` holding one scalar reward per row.
        """
        outputs = self.model(input_ids, attention_mask=attention_mask)
        # ``last_hidden_state`` is the backbone's final output; asking for
        # ``output_hidden_states`` just to take ``[-1]`` would additionally
        # keep every intermediate layer's activations alive.
        hidden = outputs.last_hidden_state  # (batch, seq_len, hidden_size)

        if attention_mask is None:
            last_token_states = hidden[:, -1]
        else:
            # The running count of real tokens first reaches its maximum at
            # the last real token, and argmax returns the first maximal index,
            # so this picks the last real position for left or right padding.
            # A row with no real tokens falls back to position 0.
            last_index = attention_mask.cumsum(dim=1).argmax(dim=1)
            # Under device_map / model parallelism the backbone output may be
            # on a different device than the mask.
            last_index = last_index.to(hidden.device)
            batch_index = torch.arange(hidden.size(0), device=hidden.device)
            last_token_states = hidden[batch_index, last_index]  # (batch, hidden_size)

        values = self.value_head(last_token_states).squeeze(-1)  # (batch,)
        return values
|