Ray2333 committed on
Commit
bf90245
1 Parent(s): bb02fc9

Upload model.py

Files changed (1)
  1. model.py +208 -0
model.py ADDED
@@ -0,0 +1,208 @@
+# Modified from the Hugging Face TRL package's AutoModelForCausalLMWithValueHead class,
+# enabling better customization for generalizable reward modeling.
+import torch
+import torch.nn as nn
+import os
+from transformers import AutoModelForCausalLM
+from trl import PreTrainedModelWrapper
+from peft import PeftModel, PeftConfig
+from safetensors import safe_open
+
+
+class ValueHead(nn.Module):
+    def __init__(self, config, **kwargs):
+        super().__init__()
+        if not hasattr(config, "summary_dropout_prob"):
+            summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1)
+        else:
+            summary_dropout_prob = config.summary_dropout_prob
+
+        self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity()
+
+        # some models such as OPT have a projection layer before the word embeddings - e.g. OPT-350m
+        if hasattr(config, "hidden_size"):
+            hidden_size = config.hidden_size
+        if hasattr(config, "word_embed_proj_dim"):
+            hidden_size = config.word_embed_proj_dim
+        elif hasattr(config, "is_encoder_decoder"):
+            if config.is_encoder_decoder and hasattr(config, "decoder"):
+                if hasattr(config.decoder, "hidden_size"):
+                    hidden_size = config.decoder.hidden_size
+
+        # read the value-head config (a value set in the config JSON takes precedence over kwargs)
+        if hasattr(config, "vhead_layer_type"):
+            self.layer_type = config.vhead_layer_type
+        else:
+            self.layer_type = kwargs.pop("vhead_layer_type", 'mlp')
+        if hasattr(config, 'vhead_num_neurons'):
+            num_neurons = config.vhead_num_neurons
+        else:
+            num_neurons = kwargs.pop("vhead_num_neurons", 1024)
+        if hasattr(config, 'vhead_num_layers'):
+            num_layers = config.vhead_num_layers
+        else:
+            num_layers = kwargs.pop("vhead_num_layers", 1)
+
+        if self.layer_type == 'linear':
+            self.summary = nn.Linear(hidden_size, 1)
+        else:
+            module_lis = []
+            input_neurons = hidden_size
+            for i in range(num_layers):
+                module_lis.extend([nn.Linear(input_neurons, num_neurons), nn.ReLU()])
+                input_neurons = num_neurons
+
+            module_lis.append(nn.Linear(num_neurons, 1))
+            self.summary = nn.Sequential(*module_lis)
+        self.flatten = nn.Flatten()
+
+    def forward(self, hidden_states):
+        output = self.dropout(hidden_states)
+        if (self.layer_type == 'linear' and output.dtype != self.summary.weight.dtype):
+            output = output.to(self.summary.weight.dtype)
+        elif (self.layer_type != 'linear' and output.dtype != self.summary[0].weight.dtype):
+            output = output.to(self.summary[0].weight.dtype)
+
+        output = self.summary(output)
+        return output
+
+
+class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
+    transformers_parent_class = AutoModelForCausalLM
+    lm_head_namings = ["lm_head", "embed_out"]
+    supported_args = (
+        "summary_dropout_prob",
+        "v_head_initializer_range",
+        "v_head_init_strategy",
+        "layer_type",
+        'num_neurons',
+        'num_layers',
+    )
+
+    def __init__(self, pretrained_model, **kwargs):
+        r"""
+        Initializes the model.
+        """
+        super().__init__(pretrained_model, **kwargs)
+        v_head_kwargs, _, _ = self._split_kwargs(kwargs)
+
+        if not any(hasattr(self.pretrained_model, attribute) for attribute in self.lm_head_namings):
+            raise ValueError("The model does not have a language model head, please use a model that has one.")
+
+        self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs)
+        self._init_weights(**v_head_kwargs)
+
+    def _init_weights(self, **kwargs):
+        r"""
+        Initializes the weights of the value head.
+        """
+        initializer_range = kwargs.pop("v_head_initializer_range", 0.2)
+        # random init by default
+        init_strategy = kwargs.pop("v_head_init_strategy", None)
+        if init_strategy is None:
+            # keep the default random initialization
+            pass
+        elif init_strategy == "normal":
+            self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range)
+            self.v_head.summary.bias.data.zero_()
+
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        **kwargs,
+    ):
+        kwargs["output_hidden_states"] = True  # this had already been set in the LoRA / PEFT examples
+        kwargs["past_key_values"] = past_key_values
+
+        if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING":
+            kwargs.pop("past_key_values")
+
+        base_model_output = self.pretrained_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+
+        last_hidden_state = base_model_output.hidden_states[-1]
+        lm_logits = base_model_output.logits
+        loss = base_model_output.loss
+
+        if (hasattr(self.v_head.summary, 'weight') and last_hidden_state.device != self.v_head.summary.weight.device):
+            last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device)
+        elif not hasattr(self.v_head.summary, 'weight') and (last_hidden_state.device != self.v_head.summary[0].weight.device):
+            last_hidden_state = last_hidden_state.to(self.v_head.summary[0].weight.device)
+
+        # use the value at the last non-padding token as the reward
+        last_index = attention_mask.sum(dim=-1) - 1
+        value = self.v_head(last_hidden_state).squeeze(-1)[torch.arange(len(last_hidden_state)), last_index]
+
+        # force upcast to fp32 if logits are in half-precision
+        if lm_logits.dtype != torch.float32:
+            lm_logits = lm_logits.float()
+
+        return (lm_logits, loss, value)
+
+    def generate(self, *args, **kwargs):
+        return self.pretrained_model.generate(*args, **kwargs)
+
+    def state_dict(self, *args, **kwargs):
+        pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs)
+
+        v_head_state_dict = self.v_head.state_dict(*args, **kwargs)
+        for k, v in v_head_state_dict.items():
+            pretrained_model_state_dict[f"v_head.{k}"] = v
+        return pretrained_model_state_dict
+
+    def push_to_hub(self, *args, **kwargs):
+        setattr(self.pretrained_model, "v_head", self.v_head)
+        return self.pretrained_model.push_to_hub(*args, **kwargs)
+
+
+
+    def post_init(self, state_dict):
+        for k in list(state_dict.keys()):
+            if "v_head." in k:
+                state_dict[k.replace("v_head.", "")] = state_dict.pop(k)
+        self.v_head.load_state_dict(state_dict, strict=False)
+        del state_dict
+
+        if hasattr(self.pretrained_model, "hf_device_map"):
+            if (
+                "cpu" in self.pretrained_model.hf_device_map.values()
+                or "disk" in self.pretrained_model.hf_device_map.values()
+            ):
+                raise ValueError(
+                    "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models."
+                )
+
+            first_device = list(set(self.pretrained_model.hf_device_map.values()))[0]
+
+            self.v_head = self.v_head.to(first_device)
+
+            def set_device_hook(module, input, outputs):
+                new_output = ()
+                for output in outputs:
+                    if isinstance(output, torch.Tensor):
+                        new_output += (output.to(first_device),)
+                    else:
+                        new_output += (output,)
+                return new_output
+
+            self.register_forward_hook(set_device_hook)
+
+            self.is_sequential_parallel = True
+
+    @classmethod
+    def register_for_auto_class(cls, auto_class="AutoModel"):
+        if not isinstance(auto_class, str):
+            auto_class = auto_class.__name__
+
+        import transformers.models.auto as auto_module
+
+        if not hasattr(auto_module, auto_class):
+            raise ValueError(f"{auto_class} is not a valid auto class.")
+
+        cls._auto_class = auto_class
+
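
Usage note (not part of the uploaded file): the snippet below is a minimal sketch of how the wrapper defined in model.py can score a response, assuming model.py is importable from the working directory and using "gpt2" purely as a placeholder backbone. The value head is randomly initialized here, so the printed reward is only illustrative until trained weights are restored (see the following note).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from model import AutoModelForCausalLMWithValueHead  # the class added in this commit

tokenizer = AutoTokenizer.from_pretrained("gpt2")                    # placeholder backbone for illustration
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
reward_model = AutoModelForCausalLMWithValueHead(base_model).eval()  # default MLP value head

inputs = tokenizer("Question: Is the sky blue?\nAnswer: Yes.", return_tensors="pt")
with torch.no_grad():
    lm_logits, loss, reward = reward_model(**inputs)                 # forward returns (lm_logits, loss, value)
print(reward)  # one reward per sequence, read at the last non-padding token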
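
A second hedged sketch, continuing from the snippet above: model.py imports safe_open, but this diff does not show where the value-head weights are loaded from, so the filename "value_head.safetensors" and the "v_head."-prefixed key layout below are assumptions rather than something the commit specifies. post_init() strips the "v_head." prefix, loads the tensors into the value head with strict=False, and, when the backbone has an hf_device_map, moves the head to the first device and registers a hook that keeps outputs on that device.

from safetensors import safe_open

state_dict = {}
with safe_open("value_head.safetensors", framework="pt", device="cpu") as f:  # hypothetical path
    for key in f.keys():
        state_dict[key] = f.get_tensor(key)

reward_model.post_init(state_dict)  # restores the value-head weights saved by state_dict()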