# -*- coding: utf-8 -*-
# @Time    : 2023/5/6 4:29 p.m.
# @Author  : JianingWang
# @File    : reward_model.py

from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig
from loss.rl_loss import LogSigLoss, LogExpLoss
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel
from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel, GPT2Model
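
# Note (assumption): LogSigLoss is expected to implement the standard pairwise
# ranking objective for reward modelling, roughly
#     loss = -log(sigmoid(r_chosen - r_rejected)).mean()
# and LogExpLoss the equivalent log(1 + exp(r_rejected - r_chosen)) form;
# see loss/rl_loss.py in this repository for the actual definitions.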

"""
RoBERTa for Reward Model
"""
class RobertaForReward(RobertaPreTrainedModel):
    """
    Reward model base class.

    Args:
        model (nn.Module): Reward model.
        value_head (nn.Module): Value head to get reward score.
    """

    def __init__(self, config) -> None:
        super().__init__(config)
        self.config = config
        self.roberta = RobertaModel(config)
        self.value_head = nn.Linear(config.hidden_size, 1)
        self.init_weights()

    def forward(
            self, 
            chosen_sequences: torch.LongTensor, 
            chosen_attention_mask: Optional[torch.Tensor],
            rejected_sequences: Optional[torch.LongTensor] = None,
            rejected_attention_mask: Optional[torch.Tensor] = None,
        ) -> Dict[str, torch.Tensor]:
        # obtain the reward value of the chosen sequence
        chosen_outputs = self.roberta(chosen_sequences, attention_mask=chosen_attention_mask)
        chosen_last_hidden_states = chosen_outputs['last_hidden_state']
        # project hidden states to per-token values, drop the value at the final
        # position, then average over tokens to get one scalar reward per sequence
        chosen_values = self.value_head(chosen_last_hidden_states)[:, :-1]
        chosen_values = chosen_values.mean(dim=1).squeeze(1)    # ensure shape is (B,)

        return_dict = {
            "chosen_values": chosen_values,
        }
        # if rejected sequences are provided, score them as well and compute the pairwise ranking loss
        if rejected_sequences is not None:
            rejected_outputs = self.roberta(rejected_sequences, attention_mask=rejected_attention_mask)
            rejected_last_hidden_states = rejected_outputs['last_hidden_state']
            rejected_values = self.value_head(rejected_last_hidden_states)[:, :-1]
            rejected_values = rejected_values.mean(dim=1).squeeze(1)    # ensure shape is (B)
            return_dict["rejected_values"] = rejected_values
            
            loss_fn = LogSigLoss()
            loss = loss_fn(chosen_values, rejected_values)
            
            return_dict["loss"] = loss
        
        return return_dict
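
# Usage sketch (illustrative only; the checkpoint name and prompts below are
# assumptions, not values used elsewhere in this repository):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("roberta-base")
#   model = RobertaForReward.from_pretrained("roberta-base")  # value head is newly initialized
#   chosen = tokenizer("Q: ... A: a helpful answer", return_tensors="pt")
#   rejected = tokenizer("Q: ... A: an unhelpful answer", return_tensors="pt")
#   out = model(
#       chosen_sequences=chosen["input_ids"],
#       chosen_attention_mask=chosen["attention_mask"],
#       rejected_sequences=rejected["input_ids"],
#       rejected_attention_mask=rejected["attention_mask"],
#   )
#   # out["chosen_values"] / out["rejected_values"]: per-sequence rewards, shape (B,)
#   # out["loss"]: pairwise ranking loss over the batch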


"""
GPT2 for Reward Model
"""
class GPT2ForReward(GPT2PreTrainedModel):
    """
    GPT-2-based reward model: a `GPT2Model` backbone followed by a linear value
    head that maps the final hidden states to a scalar reward per sequence.

    Args:
        config (GPT2Config): Model configuration.
    """

    _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]

    def __init__(self, config) -> None:
        super().__init__(config)
        self.config = config
        self.transformer = GPT2Model(config)
        self.value_head = nn.Linear(self.config.n_embd, 1)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        self.post_init()

    def forward(
            self, 
            chosen_sequences: torch.LongTensor, 
            chosen_attention_mask: Optional[torch.Tensor],
            rejected_sequences: Optional[torch.LongTensor] = None,
            rejected_attention_mask: Optional[torch.Tensor] = None,
        ) -> Dict[str, torch.Tensor]:
        # obtain the reward value of the chosen sequence
        chosen_outputs = self.transformer(chosen_sequences, attention_mask=chosen_attention_mask)
        chosen_last_hidden_states = chosen_outputs['last_hidden_state']
        # project hidden states to per-token values, drop the value at the final
        # position, then average over tokens to get one scalar reward per sequence
        chosen_values = self.value_head(chosen_last_hidden_states)[:, :-1]
        chosen_values = chosen_values.mean(dim=1).squeeze(1)    # ensure shape is (B,)

        return_dict = {
            "chosen_values": chosen_values,
        }
        # if rejected sequences are provided, score them as well and compute the pairwise ranking loss
        if rejected_sequences is not None:
            rejected_outputs = self.transformer(rejected_sequences, attention_mask=rejected_attention_mask)
            rejected_last_hidden_states = rejected_outputs['last_hidden_state']
            rejected_values = self.value_head(rejected_last_hidden_states)[:, :-1]
            rejected_values = rejected_values.mean(dim=1).squeeze(1)    # ensure shape is (B)
            return_dict["rejected_values"] = rejected_values
            loss_fn = LogSigLoss()
            loss = loss_fn(chosen_values, rejected_values)
            
            return_dict["loss"] = loss
        
        return return_dict
    
    @staticmethod
    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past
        )