import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer


class GPTRewardModel(nn.Module):
    """Reward model: a causal-LM backbone with a scalar value head on top."""

    def __init__(self, model_path):
        super().__init__()
        model = AutoModelForCausalLM.from_pretrained(model_path)
        self.config = model.config
        # `gpt-neo(x)` models use the `hidden_size` attribute name instead of `n_embd`
        self.config.n_embd = self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd
        self.transformer = model.transformer
        # Value head mapping each hidden state to a scalar reward
        self.v_head = nn.Linear(self.config.n_embd, 1, bias=False)
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.PAD_ID = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        mc_token_ids=None,
        labels=None,
        return_dict=False,
        output_attentions=False,
        output_hidden_states=False,
    ):
        loss = None
        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        hidden_states = transformer_outputs[0]

        # Per-token scalar rewards, shape (batch, seq_len)
        rewards = self.v_head(hidden_states).squeeze(-1)
        chosen_end_scores = []
        rejected_end_scores = []

        # Split the inputs and rewards into two parts: chosen and rejected
        assert len(input_ids.shape) == 2
        bs = input_ids.shape[0] // 2
        chosen = input_ids[:bs]
        rejected = input_ids[bs:]
        chosen_rewards = rewards[:bs]
        rejected_rewards = rewards[bs:]

        loss = 0
        inference = False
        for i in range(bs):
            if torch.all(torch.eq(chosen[i], rejected[i])).item():
                c_inds = (chosen[i] == self.PAD_ID).nonzero()
                c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen.shape[1]
                chosen_end_scores.append(chosen_rewards[i, c_ind - 1])
                inference = True
                continue

            # Check if there is any padding; otherwise take the full sequence length
            c_inds = (chosen[i] == self.PAD_ID).nonzero()
            c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen.shape[1]
            r_inds = (rejected[i] == self.PAD_ID).nonzero()
            r_ind = r_inds[0].item() if len(r_inds) > 0 else rejected.shape[1]
            end_ind = max(c_ind, r_ind)

            # Retrieve the first index where the two trajectories diverge
            divergence_ind = (chosen[i] != rejected[i]).nonzero()[0]
            assert divergence_ind > 0

            # Index into the rewards from the divergence point up to the end
            c_truncated_reward = chosen_rewards[i][divergence_ind:end_ind]
            r_truncated_reward = rejected_rewards[i][divergence_ind:end_ind]

            # Append the last reward of each truncated segment to the end scores
            chosen_end_scores.append(c_truncated_reward[-1])
            rejected_end_scores.append(r_truncated_reward[-1])

            # Compute loss on the truncated rewards (padding is ignored)
            loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()
        loss = loss / bs

        if not inference:
            chosen_end_scores = torch.stack(chosen_end_scores)
            rejected_end_scores = torch.stack(rejected_end_scores)

        if inference:
            chosen_end_scores = torch.stack(chosen_end_scores)
            return {"chosen_end_scores": chosen_end_scores}

        return {
            "loss": loss,
            "chosen_end_scores": chosen_end_scores,
            "rejected_end_scores": rejected_end_scores,
        }
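

# --- Illustrative usage sketch (not part of the original module) ---
# This is a minimal, hedged example of how the forward pass expects its batch:
# the chosen sequences stacked on top of their corresponding rejected sequences
# along the batch dimension, padded to a common length. The pairwise loss is
# -log(sigmoid(r_chosen - r_rejected)) averaged over the post-divergence tokens.
# The "gpt2" checkpoint and the example texts below are placeholder assumptions
# kept small for illustration; any causal LM exposing a `.transformer` attribute
# and compatible with the GPT-2/GPT-J byte-pair vocabulary should behave the same.
if __name__ == "__main__":
    reward_model = GPTRewardModel("gpt2")  # placeholder checkpoint path
    tokenizer = reward_model.tokenizer

    chosen_texts = ["Summarize: the cat sat on the mat. TL;DR: A cat sat on a mat."]
    rejected_texts = ["Summarize: the cat sat on the mat. TL;DR: Dogs are great."]

    # Pad both halves to the same length so per-position rewards line up.
    batch = tokenizer(
        chosen_texts + rejected_texts,
        padding="max_length",
        max_length=64,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        out = reward_model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )

    print("pairwise loss:", out["loss"].item())
    print("chosen end scores:", out["chosen_end_scores"])
    print("rejected end scores:", out["rejected_end_scores"])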