Mghao committed on
Commit 9f91572
1 Parent(s): 6ed4edd

Update README.md

Files changed (1)
  1. README.md +109 -1
README.md CHANGED
@@ -33,6 +33,115 @@ We evaluate our model on [RewardBench](https://huggingface.co/spaces/allenai/rew

## Demo Code

+ We provide example usage of the INF-ORM reward model below.
+ Below is an example of obtaining the reward scores of two conversations.
+
+ ```python
+ from typing import List, Optional
+
+ import torch
+ import torch.nn as nn
+ from transformers import LlamaPreTrainedModel, LlamaModel, PreTrainedTokenizerFast
+ from transformers.modeling_outputs import SequenceClassifierOutputWithPast
+
+
+ class INFORMForSequenceClassification(LlamaPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.model = LlamaModel(config)
+         # Two-layer scoring head on top of the final hidden states
+         self.score = nn.Sequential(
+             nn.Linear(config.hidden_size, config.hidden_size),
+             nn.ReLU(),
+             nn.Linear(config.hidden_size, self.num_labels),
+         )
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         transformer_outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+         )
+         hidden_states = transformer_outputs[0]
+         logits = self.score(hidden_states)
+
+         if input_ids is not None:
+             batch_size = input_ids.shape[0]
+         else:
+             batch_size = inputs_embeds.shape[0]
+
+         if self.config.pad_token_id is None and batch_size != 1:
+             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+         if self.config.pad_token_id is None:
+             sequence_lengths = -1
+         else:
+             if input_ids is not None:
+                 # If no pad token is found, use modulo instead of reverse indexing
+                 # for ONNX compatibility.
+                 sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                 sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                 sequence_lengths = sequence_lengths.to(logits.device)
+             else:
+                 sequence_lengths = -1
+
+         # Pool the logit at the last non-padding token of each sequence
+         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+         loss = None
+         return SequenceClassifierOutputWithPast(
+             loss=loss,
+             logits=pooled_logits,
+             past_key_values=transformer_outputs.past_key_values,
+             hidden_states=transformer_outputs.hidden_states,
+             attentions=transformer_outputs.attentions,
+         )
+
+ # Load model
+ model_name = "infly/INF-ORM-Llama3.1-70B"
+ orm = INFORMForSequenceClassification.from_pretrained(
+     model_name,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+     attn_implementation="flash_attention_2",
+     num_labels=1,
+ )
+
+ # Load tokenizer
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name)
+
+ conv1 = [{"content": "Tom has 20 oranges. He gave 5 oranges to his friend Lisa, then he bought 3 more oranges. Finally, he divided all the oranges equally among himself and his 4 friends. How many oranges does each person get?", "role": "user"}, {"content": "Tom has 20 oranges. He gave 5 oranges to Lisa. 20 - 5 = 15. Tom now has 15 oranges. Tom bought 3 more oranges. 15 + 3 = 18. Tom now has 18 oranges. Tom divides the 18 oranges equally among himself and his 4 friends (a total of 5 people). 18 ÷ 5 = 3.6 oranges. Each person gets 3.6 oranges.", "role": "assistant"}]
+ conv2 = [{"content": "Tom has 20 oranges. He gave 5 oranges to his friend Lisa, then he bought 3 more oranges. Finally, he divided all the oranges equally among himself and his 4 friends. How many oranges does each person get?", "role": "user"}, {"content": "Tom has 20 oranges. He gave 5 oranges to his friend Lisa. 20 - 5 = 15. Tom now has 15 oranges. Tom bought 3 more oranges. 15 + 3 = 18. Tom now has 18 oranges. Tom divides the 18 oranges equally among his 4 friends (a total of 4 people). 18 ÷ 4 = 4.5 oranges. Each person gets 4.5 oranges.", "role": "assistant"}]
+
+ conv1_tokenized = tokenizer.apply_chat_template(conv1, tokenize=True, return_tensors="pt").to("cuda")
+ conv2_tokenized = tokenizer.apply_chat_template(conv2, tokenize=True, return_tensors="pt").to("cuda")
+
+ # Inference
+ with torch.no_grad():
+     score1 = orm(conv1_tokenized).logits[0][0].item()
+     score2 = orm(conv2_tokenized).logits[0][0].item()
+ print(f"Score for response 1: {score1}")
+ print(f"Score for response 2: {score2}")
+
+ # Output:
+ # Score for response 1: 4.96875
+ # Score for response 2: 2.890625
+ ```
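+
+ The index arithmetic in the pooling step is worth unpacking: `torch.eq(input_ids, pad_token_id).int().argmax(-1)` finds the first padding position in each row, subtracting 1 points at the last real token, and the modulo keeps the index valid when a row contains no padding at all (argmax over all zeros returns 0). Here is a toy illustration with hypothetical token ids, not part of the original demo:
+
+ ```python
+ import torch
+
+ pad_id = 0  # hypothetical pad token id, chosen for illustration only
+ input_ids = torch.tensor([
+     [5, 6, 7, 0, 0],   # right-padded row: last real token at index 2
+     [5, 6, 7, 8, 9],   # full row: no pad token anywhere
+ ])
+ lengths = torch.eq(input_ids, pad_id).int().argmax(-1) - 1  # tensor([2, -1])
+ lengths = lengths % input_ids.shape[-1]                     # tensor([2, 4])
+ print(lengths)  # index of the last non-padding token per row
+ ```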
+
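+ The demo scores each conversation with a separate forward pass. Building on it, below is a small best-of-n selection sketch (our illustration, not part of the original model card); it reuses `orm`, `tokenizer`, `conv1`, and `conv2` from above, and keeps one conversation per pass to sidestep padding subtleties in batched inference.
+
+ ```python
+ # Best-of-n selection sketch; assumes `orm`, `tokenizer`, `conv1`,
+ # and `conv2` from the demo above are already in scope.
+ convs = [conv1, conv2]
+ scores = []
+ with torch.no_grad():
+     for conv in convs:
+         ids = tokenizer.apply_chat_template(conv, tokenize=True, return_tensors="pt").to("cuda")
+         # The pooled logit of a single sequence is its scalar reward
+         scores.append(orm(ids).logits[0][0].item())
+
+ # Keep the candidate whose conversation earned the highest reward
+ best = max(range(len(convs)), key=lambda i: scores[i])
+ print(f"Scores: {scores}; preferred response: #{best + 1}")
+ ```
+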
## Declaration and License Agreement

### Declaration
@@ -40,7 +149,6 @@ We evaluate our model on [RewardBench](https://huggingface.co/spaces/allenai/rew
### License Agreement

## Contact
-
If you have any questions, please feel free to reach us at <[email protected]>.
## Citation