import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
class GPTRewardModel(nn.Module):
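    """Pairwise reward model: a causal LM backbone with a scalar value head.

    `forward` expects batches laid out as the first half chosen responses and
    the second half the corresponding rejected responses, with each
    chosen/rejected pair sharing a common prompt prefix.
    """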
    def __init__(self, model_path):
        super().__init__()
        model = AutoModelForCausalLM.from_pretrained(model_path)
        self.config = model.config
        # `gpt-neo(x)` models use the `hidden_size` attribute name instead of `n_embd`
        self.config.n_embd = self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd
        self.transformer = model.transformer
        # Scalar value head: maps each hidden state to a single per-token reward
        self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
        # NOTE: the tokenizer is hardcoded to GPT-J and must match `model_path`'s vocabulary
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.PAD_ID = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        labels=None,
        return_dict=False,
        output_attentions=False,
        output_hidden_states=False,
    ):
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        hidden_states = transformer_outputs[0]
        # Per-token scalar rewards: (batch, seq_len, 1) -> (batch, seq_len)
        rewards = self.v_head(hidden_states).squeeze(-1)
        chosen_end_scores = []
        rejected_end_scores = []

        # Split the inputs and rewards into two halves: the first `bs` rows are the
        # chosen responses, the last `bs` rows the corresponding rejected responses
        assert len(input_ids.shape) == 2
        bs = input_ids.shape[0] // 2
        chosen = input_ids[:bs]
        rejected = input_ids[bs:]
        chosen_rewards = rewards[:bs]
        rejected_rewards = rewards[bs:]

        loss = 0
        inference = False
        for i in range(bs):
            # Identical chosen/rejected sequences mean there is no pair to rank;
            # score the sequence at its last non-padding token (inference mode)
            if torch.all(torch.eq(chosen[i], rejected[i])).item():
                c_inds = (chosen[i] == self.PAD_ID).nonzero()
                c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen.shape[1]
                chosen_end_scores.append(chosen_rewards[i, c_ind - 1])
                inference = True
                continue

            # Check if there is any padding; otherwise take the full sequence length
            c_inds = (chosen[i] == self.PAD_ID).nonzero()
            c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen.shape[1]
            r_inds = (rejected[i] == self.PAD_ID).nonzero()
            r_ind = r_inds[0].item() if len(r_inds) > 0 else rejected.shape[1]
            end_ind = max(c_ind, r_ind)

            # Retrieve the first index where the two trajectories diverge
            # (i.e. the end of the shared prompt prefix)
            divergence_ind = (chosen[i] != rejected[i]).nonzero()[0]
            assert divergence_ind > 0

            # Index into the rewards over the divergent region only
            c_truncated_reward = chosen_rewards[i][divergence_ind:end_ind]
            r_truncated_reward = rejected_rewards[i][divergence_ind:end_ind]

            # Append the last reward of each trajectory to the list of end scores
            chosen_end_scores.append(c_truncated_reward[-1])
            rejected_end_scores.append(r_truncated_reward[-1])

            # Pairwise ranking loss, -log(sigmoid(r_chosen - r_rejected)), computed on
            # the truncated rewards (padding ignored) and averaged over tokens
            loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()
        loss = loss / bs

        if inference:
            chosen_end_scores = torch.stack(chosen_end_scores)
            return {"chosen_end_scores": chosen_end_scores}

        chosen_end_scores = torch.stack(chosen_end_scores)
        rejected_end_scores = torch.stack(rejected_end_scores)
        return {
            "loss": loss,
            "chosen_end_scores": chosen_end_scores,
            "rejected_end_scores": rejected_end_scores,
        }
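

# --- Usage sketch (illustrative, not part of the original file) ---------------
# A minimal example of the calling convention this model assumes: the batch is
# the concatenation of `bs` chosen sequences followed by their `bs` rejected
# counterparts, each pair sharing a common prompt prefix. The checkpoint name
# and prompt/response strings below are placeholders, not prescribed values.
if __name__ == "__main__":
    # Any causal LM with a `.transformer` attribute and a GPT-J-compatible vocab
    model = GPTRewardModel("EleutherAI/gpt-j-6B")
    tok = model.tokenizer

    prompt = "Human: How do I boil an egg?\nAssistant:"
    chosen_text = prompt + " Place the egg in boiling water for about 8 minutes."
    rejected_text = prompt + " I don't know."

    # Pad both sequences to a common length so they stack into one (2, seq_len) batch
    enc = tok([chosen_text, rejected_text], padding=True, return_tensors="pt")
    out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
    print(out["loss"], out["chosen_end_scores"], out["rejected_end_scores"])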