Safetensors
gemma
yifAI commited on
Commit
3ba1e99
·
verified ·
1 Parent(s): 5a11484

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +215 -1
README.md CHANGED
@@ -34,7 +34,221 @@ The GPM is evaluated using the [RewardBench](https://github.com/allenai/reward-b
34
  To use this model, please refer to the [General Preference Model Code Repository](https://github.com/general-preference/general-preference-model). The repository includes detailed instructions for finetuning, evaluation, and integration of the GPM with downstream tasks. Below is an example code snippet:
35
 
36
  ```python
37
- TODO
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  ```
39
 
40
  ## Citation
 
34
  To use this model, please refer to the [General Preference Model Code Repository](https://github.com/general-preference/general-preference-model). The repository includes detailed instructions for finetuning, evaluation, and integration of the GPM with downstream tasks. Below is an example code snippet:
35
 
36
  ```python
37
+ from typing import Optional, List, Dict
38
+ import torch
39
+ import torch.nn as nn
40
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
41
+ import torch.nn.functional as F
42
+ from transformers import AutoTokenizer
43
+
44
+ def get_tokenizer(pretrain, model, padding_side="left", use_fast=True):
45
+ tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast)
46
+ tokenizer.padding_side = padding_side
47
+ if tokenizer.pad_token is None:
48
+ tokenizer.pad_token = tokenizer.eos_token
49
+ tokenizer.pad_token_id = tokenizer.eos_token_id
50
+ model.config.pad_token_id = tokenizer.pad_token_id
51
+ return tokenizer
52
+
53
+ def get_reward_model(base_causal_model, base_llm_model, is_general_preference: bool=False, add_prompt_head: bool=False, value_head_dim: int=2):
54
+ class CustomRewardModel(base_causal_model):
55
+
56
+ def __init__(self, config: AutoConfig):
57
+ super().__init__(config)
58
+ setattr(self, self.base_model_prefix, base_llm_model(config))
59
+ if not is_general_preference:
60
+ self.value_head = nn.Linear(config.hidden_size, 1, bias=False)
61
+ else:
62
+ self.value_head = nn.Linear(config.hidden_size, value_head_dim, bias=False)
63
+ if add_prompt_head:
64
+ self.prompt_head = nn.Linear(config.hidden_size, value_head_dim // 2, bias=False)
65
+
66
+ self.is_general_preference = is_general_preference
67
+
68
+ self.post_init()
69
+
70
+ def custom_forward(
71
+ self,
72
+ input_ids: torch.LongTensor = None,
73
+ attention_mask: Optional[torch.Tensor] = None,
74
+ return_output=False,
75
+ ) -> torch.Tensor:
76
+ position_ids = attention_mask.long().cumsum(-1) - 1
77
+ position_ids.masked_fill_(attention_mask == 0, 1)
78
+ outputs = getattr(self, self.base_model_prefix)(
79
+ input_ids, attention_mask=attention_mask, position_ids=position_ids
80
+ )
81
+ last_hidden_states = outputs["last_hidden_state"]
82
+
83
+ if not self.is_general_preference:
84
+ values = self.value_head(last_hidden_states).squeeze(-1)
85
+ # left padding in training mode
86
+ if self.training:
87
+ reward = values[:, -1]
88
+ else:
89
+ eos_indices = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(dim=1, keepdim=True)
90
+ reward = values.gather(dim=1, index=eos_indices).squeeze(1)
91
+ if return_output:
92
+ return reward, outputs
93
+ else:
94
+ return reward, None
95
+ else:
96
+ values = self.value_head(last_hidden_states)
97
+ # left padding in training mode
98
+ if self.training:
99
+ reward = values[:, -1, :]
100
+ reward = F.normalize(reward, p=2, dim=-1) # Shape will be [batch_size, value_head_dim]
101
+ else:
102
+ eos_indices = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(dim=1)
103
+ eos_indices = eos_indices.unsqueeze(1) # Change shape to [batch_size, 1]
104
+ reward_list = []
105
+ for dim in range(value_head_dim):
106
+ reward_list.append(values[:,:,dim].gather(dim=1, index=eos_indices))
107
+ reward = torch.cat(reward_list, dim=1)
108
+ reward = F.normalize(reward, p=2, dim=-1) # Shape will be [batch_size, value_head_dim]
109
+ if return_output:
110
+ return reward, outputs
111
+ else:
112
+ return reward, None
113
+
114
+ def create_skew_symmetric_block_matrix(self, dim, device, dtype, prompt_hidden_states):
115
+ """
116
+ Create a batch of skew-symmetric block matrices where each matrix is data-dependent on
117
+ the corresponding prompt_hidden_states. Only the relevant block diagonal parts are generated.
118
+
119
+ Args:
120
+ - dim: Dimension of the square matrix (must be even).
121
+ - prompt_hidden_states: Tensor of shape [batch_size, hidden_dim].
122
+
123
+ Returns:
124
+ - batch_R_matrices: Tensor of shape [batch_size, dim, dim], with skew-symmetric block entries.
125
+ """
126
+ if hasattr(self, 'prompt_head'):
127
+ batch_size = prompt_hidden_states.shape[0]
128
+
129
+ # Ensure that dim is even, as we're creating blocks of size 2x2
130
+ assert dim % 2 == 0, "dim must be even for skew-symmetric block generation"
131
+
132
+ # Pass through the linear layer to get the block diagonal entries (half of the matrix's off-diagonal blocks)
133
+ block_values = self.prompt_head(prompt_hidden_states).view(batch_size, dim // 2)
134
+ block_values = torch.softmax(block_values, dim=-1)
135
+
136
+ # Create a batch of zero matrices [batch_size, dim, dim]
137
+ batch_R_matrices = torch.zeros((batch_size, dim, dim), device=device, dtype=dtype)
138
+
139
+ # Fill only the block diagonal entries with the learned values
140
+ for i in range(0, dim, 2):
141
+ batch_R_matrices[:, i, i + 1] = -block_values[:, i // 2]
142
+ batch_R_matrices[:, i + 1, i] = block_values[:, i // 2] # Skew-symmetric condition
143
+ else:
144
+ raise AttributeError("prompt_head is not defined. Ensure 'add_prompt_head' is set to True during initialization.")
145
+
146
+ return batch_R_matrices
147
+
148
+ return CustomRewardModel
149
+
150
+ def generate_high_dim_result_with_prompt(model, value_head_dim, chosen_reward, rejected_reward, prompt_hidden_states):
151
+ R_matrix = model.create_skew_symmetric_block_matrix(value_head_dim, chosen_reward.device, chosen_reward.dtype, prompt_hidden_states)
152
+ if chosen_reward.device == rejected_reward.device == R_matrix.device:
153
+ transformed_chosen = torch.bmm(chosen_reward.view(chosen_reward.shape[0], 1, value_head_dim), R_matrix.transpose(1, 2))
154
+ result = torch.bmm(transformed_chosen, rejected_reward.view(rejected_reward.shape[0], value_head_dim, 1))
155
+ result = result.view(chosen_reward.shape[0])
156
+ return result
157
+
158
+ class GPMPipeline:
159
+ def __init__(self, model_name_or_path, device=torch.device("cuda:0"), is_general_preference: bool=True, add_prompt_head: bool=True, value_head_dim: int=2, bf16: bool=True, truncation: bool=True, max_length: int=4096, padding: bool=True, tau: float=0.1):
160
+ self.device = device
161
+ self.is_general_preference = is_general_preference
162
+ self.add_prompt_head = add_prompt_head
163
+ self.value_head_dim = value_head_dim
164
+ self.truncation = truncation
165
+ self.max_length = max_length
166
+ self.padding = padding
167
+ self.tau = 0.1
168
+
169
+ config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
170
+ config._attn_implementation = "flash_attention_2"
171
+ base_class = AutoModel._model_mapping[type(config)]
172
+ base_causal_class = AutoModelForCausalLM._model_mapping.get(type(config), None)
173
+ cls_class = get_reward_model(base_causal_class, base_class, is_general_preference, add_prompt_head, value_head_dim)
174
+
175
+ # configure model
176
+ self.model = cls_class.from_pretrained(
177
+ model_name_or_path,
178
+ config=config,
179
+ trust_remote_code=True,
180
+ torch_dtype=torch.bfloat16 if bf16 else "auto",
181
+ )
182
+ # configure tokenizer
183
+ self.tokenizer = get_tokenizer(model_name_or_path, self.model, "left", use_fast=True)
184
+ self.tokenizer.truncation_side = "right"
185
+
186
+ # prepare model
187
+ self.model.to(device)
188
+ self.model.eval()
189
+
190
+ def __call__(self, samples: List[List[Dict[str, str]]], return_prompt=False):
191
+ input_texts = [self.tokenizer.apply_chat_template(sample, tokenize=False) for sample in samples]
192
+
193
+ inputs = self.tokenizer(
194
+ input_texts,
195
+ truncation=True,
196
+ max_length=self.max_length,
197
+ padding=True,
198
+ return_tensors="pt",
199
+ ).to(self.device)
200
+
201
+ inputs["input_ids"][:, -1] = self.tokenizer.eos_token_id
202
+ inputs["attention_mask"][:, -1] = 1
203
+
204
+ with torch.no_grad():
205
+ rewards, outputs = self.model.custom_forward(**inputs, return_output=return_prompt)
206
+
207
+ if return_prompt:
208
+ # Compute prompt hidden states
209
+ prompt_texts = [self.tokenizer.apply_chat_template([sample[0]], tokenize=False) for sample in samples]
210
+ prompt_lengths = [len(self.tokenizer(prompt_text, padding=False, return_tensors="pt")["input_ids"][0]) for prompt_text in prompt_texts]
211
+ prompt_lengths = torch.tensor(prompt_lengths, device=self.device)
212
+ prompt_end_indices = prompt_lengths - 1
213
+
214
+ last_hidden_states = outputs.last_hidden_state
215
+ prompt_hidden_states = last_hidden_states[torch.arange(len(samples)), prompt_end_indices, :]
216
+
217
+ return rewards, prompt_hidden_states
218
+
219
+ return rewards
220
+
221
+
222
+ prompt_text = "Describe the importance of reading books in today's digital age."
223
+ response1 = "Books remain crucial in the digital era, offering in-depth knowledge and fostering critical thinking. They provide a unique, immersive experience that digital media can't replicate, contributing significantly to personal and intellectual growth."
224
+ response2 = "Books are still useful for learning new things. They help you relax and can be a good break from screens."
225
+
226
+ context1 = [
227
+ {"role": "user", "content": prompt_text},
228
+ {"role": "assistant", "content": response1}
229
+ ]
230
+
231
+ context2 = [
232
+ {"role": "user", "content": prompt_text},
233
+ {"role": "assistant", "content": response2}
234
+ ]
235
+
236
+ rm = GPMPipeline("general-preference/GPM-Llama-3.1-8B", value_head_dim=4)
237
+
238
+ reward1, prompt_hidden_state = rm([context1], return_prompt=True)
239
+ reward2 = rm([context2])
240
+
241
+ result = generate_high_dim_result_with_prompt(rm.model, rm.value_head_dim, reward1, reward2, prompt_hidden_state)
242
+
243
+ result_batch = result.float().cpu().detach().numpy().tolist()
244
+
245
+ results = []
246
+ [
247
+ results.append(1) if result > 0 else results.append(0)
248
+ for result in result_batch
249
+ ]
250
+
251
+ print(result_batch)
252
  ```
253
 
254
  ## Citation