# -*- coding: utf-8 -*-
# @Time : 2023/5/6 4:29 p.m.
# @Author : JianingWang
# @File : reward_model.py
from typing import Optional, Tuple
import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig
from loss.rl_loss import LogSigLoss, LogExpLoss
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel
from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel, GPT2Model
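
# Note: LogSigLoss / LogExpLoss come from this repository's loss/rl_loss module (not
# shown here); they are assumed to implement the standard pairwise ranking objectives
# -log(sigmoid(r_chosen - r_rejected)) and log(1 + exp(r_rejected - r_chosen)),
# which are mathematically equivalent.
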
"""
RoBERTa for Reward Model
"""
class RobertaForReward(RobertaPreTrainedModel):
    """
    RoBERTa-based reward model.

    A `RobertaModel` backbone is followed by a linear value head that maps each
    hidden state to a scalar reward score.
    """

    def __init__(self, config) -> None:
        super().__init__(config)
        self.config = config
        self.roberta = RobertaModel(config)
        # RoBERTa configs expose the hidden dimension as `hidden_size`, not `n_embd`.
        self.value_head = nn.Linear(self.config.hidden_size, 1)
        self.init_weights()

    def forward(
        self,
        chosen_sequences: torch.LongTensor,
        chosen_attention_mask: Optional[torch.Tensor],
        rejected_sequences: Optional[torch.LongTensor] = None,
        rejected_attention_mask: Optional[torch.Tensor] = None,
    ) -> dict:
        # obtain reward value of chosen sequence
        chosen_outputs = self.roberta(chosen_sequences, attention_mask=chosen_attention_mask)
        chosen_last_hidden_states = chosen_outputs['last_hidden_state']
        # per-token scores of shape (B, L - 1, 1) after dropping the last position
        chosen_values = self.value_head(chosen_last_hidden_states)[:, :-1]
        chosen_values = chosen_values.mean(dim=1).squeeze(1)  # ensure shape is (B,)
        return_dict = {
            "chosen_values": chosen_values,
        }
        # if a rejected sequence is given, obtain its reward as well and compute the pairwise ranking loss
        if rejected_sequences is not None:
            rejected_outputs = self.roberta(rejected_sequences, attention_mask=rejected_attention_mask)
            rejected_last_hidden_states = rejected_outputs['last_hidden_state']
            rejected_values = self.value_head(rejected_last_hidden_states)[:, :-1]
            rejected_values = rejected_values.mean(dim=1).squeeze(1)  # ensure shape is (B,)
            return_dict["rejected_values"] = rejected_values
            loss_fn = LogSigLoss()
            loss = loss_fn(chosen_values, rejected_values)
            return_dict["loss"] = loss

        return return_dict


"""
GPT2 for Reward Model
"""
class GPT2ForReward(GPT2PreTrainedModel):
    """
    GPT-2-based reward model.

    A `GPT2Model` backbone is followed by a linear value head that maps each
    hidden state to a scalar reward score.
    """

    _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]

    def __init__(self, config) -> None:
        super().__init__(config)
        self.config = config
        self.transformer = GPT2Model(config)
        self.value_head = nn.Linear(self.config.n_embd, 1)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        self.post_init()

    def forward(
        self,
        chosen_sequences: torch.LongTensor,
        chosen_attention_mask: Optional[torch.Tensor],
        rejected_sequences: Optional[torch.LongTensor] = None,
        rejected_attention_mask: Optional[torch.Tensor] = None,
    ) -> dict:
        # obtain reward value of chosen sequence
        chosen_outputs = self.transformer(chosen_sequences, attention_mask=chosen_attention_mask)
        chosen_last_hidden_states = chosen_outputs['last_hidden_state']
        # per-token scores of shape (B, L - 1, 1) after dropping the last position
        chosen_values = self.value_head(chosen_last_hidden_states)[:, :-1]
        chosen_values = chosen_values.mean(dim=1).squeeze(1)  # ensure shape is (B,)
        return_dict = {
            "chosen_values": chosen_values,
        }
        # if a rejected sequence is given, obtain its reward as well and compute the pairwise ranking loss
        if rejected_sequences is not None:
            rejected_outputs = self.transformer(rejected_sequences, attention_mask=rejected_attention_mask)
            rejected_last_hidden_states = rejected_outputs['last_hidden_state']
            rejected_values = self.value_head(rejected_last_hidden_states)[:, :-1]
            rejected_values = rejected_values.mean(dim=1).squeeze(1)  # ensure shape is (B,)
            return_dict["rejected_values"] = rejected_values
            loss_fn = LogSigLoss()
            loss = loss_fn(chosen_values, rejected_values)
            return_dict["loss"] = loss

        return return_dict

    @staticmethod
    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past
        )
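

# --- Hypothetical usage sketch (not part of the original training pipeline) ---
# Shows how a chosen/rejected pair would flow through GPT2ForReward. The model
# name, tokenizer, and example texts below are placeholders for illustration only.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = AutoConfig.from_pretrained("gpt2")
    model = GPT2ForReward(config)  # randomly initialized value head

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    chosen = tokenizer(["Q: What is 2 + 2? A: 4."], return_tensors="pt", padding=True)
    rejected = tokenizer(["Q: What is 2 + 2? A: 5."], return_tensors="pt", padding=True)

    with torch.no_grad():
        outputs = model(
            chosen_sequences=chosen["input_ids"],
            chosen_attention_mask=chosen["attention_mask"],
            rejected_sequences=rejected["input_ids"],
            rejected_attention_mask=rejected["attention_mask"],
        )

    # One scalar reward per sequence, plus the pairwise ranking loss.
    print(outputs["chosen_values"], outputs["rejected_values"], outputs["loss"])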