leafzs commited on
Commit
ef7fa22
1 Parent(s): 025915e

[docs] update readme: add more base model and rm evaluation

Browse files
Files changed (1) hide show
  1. README.md +163 -309
README.md CHANGED
@@ -24,374 +24,228 @@ We are excited to announce the release of the Skywork o1 Open model series, deve
24
 
25
  Different from mere reproductions of the OpenAI o1 model, the Skywork o1 Open model series not only exhibits innate thinking, planning, and reflecting capabilities in its outputs, but also shows significant improvements in reasoning skills on standard benchmarks. This series represents a strategic advancement in AI capabilities, moving a previously weaker base model towards the state-of-the-art (SOTA) in reasoning tasks.
26
 
27
- # Methods
28
- The Skywork o1 Open series' remarkable cognitive abilities are developed through a three-stage training scheme:
29
- - Reflective Reasoning Training: Utilizing a proprietary multi-agent system to generate high-quality, diverse data for long-thinking tasks, followed by continuous pre-training and supervised fine-tuning.
30
 
31
- - Reinforcement Learning for Reasoning Capabilities: Introduction of the Skywork o1 Process Reward Model (PRM), tailored to enhance step-by-step reasoning. Our experiments confirm that the Skywork-PRM effectively captures the influence of intermediate reasoning steps on final outcomes, combined with proprietary reasoning reinforcement algorithms.
32
 
33
- - Reasoning Planning: Deploying Tiangong's proprietary Q* online reasoning algorithm alongside model-based thinking, searching for optimal reasoning paths. This marks the first implementation and public release of a Q* algorithm, significantly boosting the model's online reasoning capabilities.
34
 
35
- # Highlights
36
- The Skywork o1 Open series stands out with the following capabilities:
37
- - Enhanced model thinking and planning capabilities.
38
- - Advanced self-reflection and self-verification abilities.
39
 
40
- Compared to previous large models, the Skywork o1 Open series adeptly handles a variety of reasoning challenges, including common-sense, logical, mathematical, ethical decision-making, and logical trap problems.
41
 
42
- <img src="misc/demo_case.png" width="1000"/>
43
 
44
- # Models
45
- ## Skywork o1 Open 8B
46
- The [Skywork o1 Open 8B](https://huggingface.co/Skywork/Skywork-o1-Open-Llama3.1-8B) model shows notable improvements across various mathematical and coding benchmarks, pushing the performance of Llama-3.1-8B to the forefront of its category, outperforming prior SOTA models (with a similar size) Qwen-2.5-7B instruct.
47
- <img src="misc/main_result_math.png" width="1000"/>
48
- <img src="misc/main_result_code.png" width="1000"/>
49
 
50
- ## Quickstart
51
- To run inference with Skywork-o1-Open-Llama3.1-8B, simply provide a few lines of code as shown below.
52
 
53
- ```python
54
- import torch
55
- from transformers import AutoModelForCausalLM, AutoTokenizer
56
 
57
- system_prompt = """You are Skywork-o1, a thinking model developed by Skywork AI, specializing in solving complex problems involving mathematics, coding, and logical reasoning through deep thought. When faced with a user's request, you first engage in a lengthy and in-depth thinking process to explore possible solutions to the problem. After completing your thoughts, you then provide a detailed explanation of the solution process in your response."""
58
 
59
- # An Example Case
60
- problem = "Jane has 12 apples. She gives 4 apples to her friend Mark, then buys 1 more apple, and finally splits all her apples equally among herself and her 2 siblings. How many apples does each person get?"
61
 
62
- user_message = problem
63
 
64
- conversation = [
65
- {
66
- "role": "system",
67
- "content": system_prompt
68
- },
69
- {
70
- "role": "user",
71
- "content": user_message
72
- }
73
- ]
74
 
75
- model_name = "Skywork-o1-Open-Llama3.1-8B"
76
- model = AutoModelForCausalLM.from_pretrained(
77
- model_name,
78
- torch_dtype="auto",
79
- device_map="auto"
80
- )
81
 
82
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
83
 
84
- input_ids = tokenizer.apply_chat_template(
85
- conversation,
86
- tokenize=True,
87
- add_generation_prompt=True,
88
- return_tensors="pt").to(model.device)
89
 
90
- generation = model.generate(
91
- input_ids=input_ids,
92
- max_new_tokens=2048,
93
- do_sample=False,
94
- pad_token_id=128009,
95
- temperature=0)
96
 
97
- completion = tokenizer.decode(
98
- generation[0][len(input_ids[0]):],
99
- skip_special_tokens=True,
100
- clean_up_tokenization_spaces=True)
 
101
 
102
- print(completion)
103
- ```
104
- ## Reasoning Examples
105
- Here are a few examples demonstrating how Skywork-o1-Open-Llama3.1-8B excels in planning, reasoning, and reflection.
106
 
107
- ### Mathematical Problem in Chinese
108
- ```
109
- Input: 一项工程,山东队独做需15天完成,江苏队独做需10天完成。两队合作,几天可以完成?
110
- ```
111
- ```
112
- Output: 好的,让我们来解决这个有趣的工程问题。我会用轻松的语气,一步步地分析这个问题。
113
 
114
- 首先,让我们理解一下题目给出的信息:
115
- 1. 山东队独做需要15天完成工程
116
- 2. 江苏队独做需要10天完成工程
117
 
118
- 现在,我们需要找出两队合作时的完成时间。让我们从每队的工作效率开始分析:
 
 
 
 
 
 
 
 
 
119
 
120
- 1. 山东队的效率:
121
- 15天完成1个工程,所以每天完成的工程量是 1/15
122
 
123
- 2. 江苏队的效率:
124
- 10天完成1个工程,所以每天完成的工程量是 1/10
 
 
 
 
 
 
 
 
125
 
126
- 稍等,让我们再检查一下这个计算...没错,这个计算是正确的。
127
 
128
- 现在,两队合作时,每天完成的工程量是两队效率的和:
129
- 1/15 + 1/10 = 2/30 + 3/30 = 5/30 = 1/6
130
 
131
- 这意味着每天两队合作完成的工程量是1/6。
 
 
 
 
 
 
 
 
 
132
 
133
- 接下来,我们需要计算完成1个工程需要多少天。我们可以用以下等式:
134
- (1/6) × x = 1
135
 
136
- 解这个等式:
137
- x = 1 ÷ (1/6) = 6
138
 
139
- 所以,两队合作完成工程需要6天。
140
 
141
- 让我们再验证一下:
142
- - 6天 × (1/6) = 1
143
- - 这确实等于1,说明我们的计算是正确的。
 
144
 
145
- 因此,最终答案是:两队合作完成工程需要6天。
146
- ```
147
- ### Mathematical Problem in English
148
- ```
149
- Input: Determine the value of $y$ in its simplest form, where $y = \frac{p + q\sqrt{r}}{s}$, given that $\frac{7y}{5} - 2 = \frac{4}{y}$, and $p, q, r, s$ are integers. What is $\frac{pqr}{s}$?
150
- ```
151
- ```
152
- Output: To determine the value of \( y \) in the form \( y = \frac{p + q\sqrt{r}}{s} \), we start with the given equation:
153
-
154
- \[
155
- \frac{7y}{5} - 2 = \frac{4}{y}
156
- \]
157
-
158
- First, we eliminate the denominators by multiplying both sides of the equation by \( 5y \):
159
-
160
- \[
161
- 7y^2 - 10y = 20
162
- \]
163
-
164
- Next, we rearrange the equation to set it to zero:
165
-
166
- \[
167
- 7y^2 - 10y - 20 = 0
168
- \]
169
-
170
- We solve this quadratic equation using the quadratic formula \( y = \frac{-b \pm \sqrt{b^2 - 4ac}}{2a} \), where \( a = 7 \), \( b = -10 \), and \( c = -20 \). Plugging in these values, we get:
171
 
172
- \[
173
- y = \frac{-(-10) \pm \sqrt{(-10)^2 - 4 \cdot 7 \cdot (-20)}}{2 \cdot 7} = \frac{10 \pm \sqrt{100 + 560}}{14} = \frac{10 \pm \sqrt{660}}{14}
174
- \]
175
 
176
- We simplify \( \sqrt{660} \) as follows:
 
 
 
177
 
178
- \[
179
- \sqrt{660} = \sqrt{4 \cdot 165} = 2\sqrt{165}
180
- \]
181
 
182
- Substituting this back into the quadratic formula, we have:
 
 
 
183
 
184
- \[
185
- y = \frac{10 \pm 2\sqrt{165}}{14} = \frac{5 \pm \sqrt{165}}{7}
186
- \]
187
 
188
- Since we are asked for the value of \( y \) in the form \( \frac{p + q\sqrt{r}}{s} \), we can choose either the positive or negative root. Let's take the positive root for simplicity:
189
 
190
- \[
191
- y = \frac{5 + \sqrt{165}}{7}
192
- \]
193
 
194
- Here, \( p = 5 \), \( q = 1 \), \( r = 165 \), and \( s = 7 \). We need to find the value of \( \frac{pqr}{s} \):
195
 
196
- \[
197
- \frac{pqr}{s} = \frac{5 \cdot 1 \cdot 165}{7} = \frac{825}{7}
198
- \]
199
 
200
- Thus, the final answer is:
201
 
202
- \[
203
- \boxed{\frac{825}{7}}
204
- \]
 
205
  ```
206
- ### Logical Problem in Chinese
207
- ```
208
- Input: 火药∶鞭炮∶二踢脚 可以对应下列选项中的:A.乌铁∶刀具∶三棱刀\nB.红砖∶建筑∶四合院\nC.清水∶白酒∶五粮液\nD.杉木∶乐器∶六弦琴
209
- ```
210
-
211
- ```
212
- Output: 好的,让我们一起来解决这个逻辑关系题。我们需要找出与"火药∶鞭炮∶二踢脚"最相似的关系。
213
 
214
- 首先,让我们分析给定的关系:
215
- 火药 鞭炮 → 二踢脚
 
 
216
 
217
- 1. 火药是制作鞭炮的原料。
218
- 2. 鞭炮是由火药制成的。
219
- 3. 二踢脚是鞭炮的一种。
 
 
 
 
 
 
 
 
 
220
 
221
- 现在,让我们逐个检查选项:
222
 
223
- A. 乌铁∶刀具∶三棱刀
224
- 乌铁是制作刀具的原料,刀具是由乌铁制成的,三棱刀是刀具的一种。
225
- 这个关系看起来很相似!
226
 
227
- B. 红砖∶建筑∶四合院
228
- 红砖是建筑的一种材料,但不是所有建筑都用红砖。
229
- 四合院是建筑的一种,但不是所有建筑都叫四合院。
230
- 这个关系不太匹配。
 
 
 
231
 
232
- C. 清水∶白酒∶五粮液
233
- 清水是制作白酒的原料,白酒是由清水制成的。
234
- 但是五粮液是白酒的一种,而不是白酒的原料。
235
- 这个关系不太匹配。
236
 
237
- D. 杉木∶乐器∶六弦琴
238
- 杉木是制作乐器的原料,乐器是由杉木制成的。
239
- 但是六弦琴是乐器的一种,而不是乐器的原料。
240
- 这个关系也不太匹配。
 
 
 
241
 
242
- 稍等,让我再仔细思考一下...
 
 
 
 
 
 
 
 
 
243
 
244
- 在重新审视这些选项后,我更加确信A选项是最佳答案。让我解释一下为什么:
245
 
246
- 1. 乌铁作为原料,刀具作为产品,三棱刀作为具体产品的关系,与火药、鞭炮、二踢脚的关系最为相似。
247
- 2. 每个选项都遵循了"原料 产品 → 具体产品"的模式。
248
- 3. 其他选项要么缺少了原料和产品的关系(如B和C),要么将产品和具体产品的关系颠倒了(如D)。
 
249
 
250
- 因此,最终答案是:A. 乌铁∶刀具∶三棱刀。
 
 
 
 
 
 
 
 
 
 
 
251
 
252
- 这个选项最准确地反映了原始关系中的逻辑结构。
253
- ```
254
- ### Logical Problem in English
255
- ```
256
- Input: How many `r`s in strawberry?
257
- ```
258
- ```
259
- Ouput: To determine how many `r`s are in the word "strawberry," we will follow a step-by-step approach:
260
-
261
- 1. **Identify the word**: The word we need to analyze is "strawberry."
262
- 2. **List each letter**: Write down each letter in the word to ensure we don't miss any.
263
- - s
264
- - t
265
- - r
266
- - a
267
- - w
268
- - b
269
- - e
270
- - r
271
- - r
272
- - y
273
- 3. **Count the occurrences of `r`**: Now, we count how many times the letter `r` appears in the list.
274
- - The first `r` is in the third position.
275
- - The second `r` is in the seventh position.
276
- - The third `r` is in the ninth position.
277
- 4. **Verify the count**: Double-check that we have counted all the `r`s correctly and haven't missed any or counted any other letters as `r`.
278
-
279
- After following these steps, we find that there are three `r`s in the word "strawberry."
280
-
281
- Therefore, the number of `r`s in strawberry is \(\boxed{3}\).
282
- ```
283
 
284
- ### Code Problem
285
- ```
286
- Input: Given an array of integers `nums` with a length of `n` and an integer `k`, please determine if there exist two adjacent subarrays of length `k` that are strictly increasing. Specifically, you need to check if there are two subarrays starting from indices `a` and `b` (where `a < b`) that meet the following conditions:
287
- - Both subarrays `nums[a..a + k - 1]` and `nums[b..b + k - 1]` are strictly increasing.
288
- - These two subarrays must be adjacent, meaning `b = a + k`.
289
- If such two subarrays can be found, return `true`; otherwise, return `false`.
290
- A subarray is a contiguous sequence of elements within an array.
291
-
292
- Example 1:
293
- Input: `nums = [2,5,7,8,9,2,3,4,3,1]`, `k = 3`
294
- Output: `true`
295
- Explanation:
296
- The subarray starting from index 2 is `[7,8,9]`, which is strictly increasing. The subarray starting from index 5 is `[2,3,4]`, which is also strictly increasing. The two subarrays are adjacent, so the result is `true`.
297
-
298
- Example 2:
299
- Input: `nums = [1,2,3,4,4,4,4,5,6,7]`, `k = 5`
300
- Output: `false`
301
-
302
- Hints:
303
- - `2 <= nums.length <= 1000`
304
- - `1 <= 2 * k <= nums.length`
305
- - `-1000 <= nums[i] <= 1000`
306
-
307
- Please output the answer in the following format:
308
- ```
309
 
 
 
 
310
  ```
311
- Output: ### Planning Phase
312
-
313
- #### Key Steps:
314
- 1. **Iterate through the array**: Start from the first element and check every possible starting point for a subarray of length `k`.
315
- 2. **Check for strictly increasing subarray**: For each starting point, check if the subarray of length `k` is strictly increasing.
316
- 3. **Check the next subarray**: If the current subarray is strictly increasing, check if the next subarray of length `k` starting from the next element is also strictly increasing.
317
- 4. **Return the result**: If both subarrays are strictly increasing and adjacent, return `true`. If no such pair is found after checking all possible starting points, return `false`.
318
-
319
- #### Important Considerations:
320
- - **Edge Cases**:
321
- - Arrays with fewer than `2*k` elements.
322
- - Arrays where no subarray of length `k` is strictly increasing.
323
- - Arrays where multiple subarrays of length `k` are strictly increasing, but none are adjacent.
324
- - **Pitfalls**:
325
- - Ensure that the subarrays are strictly increasing, not just non-decreasing.
326
- - Handle cases where the array length is exactly `2*k`.
327
-
328
- ### Code Solution
329
-
330
-
331
- class Solution(object):
332
- def hasincreasingSubarrays(self, nums, k):
333
- # Iterate through the array, stopping at the point where the next subarray would go out of bounds
334
- for i in range(len(nums) - k + 1):
335
- # Check if the current subarray is strictly increasing
336
- is_increasing = True
337
- for j in range(1, k):
338
- if nums[i + j] <= nums[i + j - 1]:
339
- is_increasing = False
340
- break
341
-
342
- # If the current subarray is strictly increasing, check the next subarray
343
- if is_increasing:
344
- # Check if the next subarray is strictly increasing
345
- for j in range(1, k):
346
- if nums[i + k + j] <= nums[i + k + j - 1]:
347
- break
348
- else:
349
- # If the next subarray is also strictly increasing, return True
350
- return True
351
-
352
- # If no such pair of subarrays is found, return False
353
- return False
354
-
355
- ### Solution Explanation
356
-
357
- #### Detailed Explanation:
358
- 1. **Iterate through the array**:
359
- - We use a loop to iterate through the array from the first element to the point where the next subarray of length `k` would go out of bounds (`len(nums) - k + 1`).
360
-
361
- 2. **Check for strictly increasing subarray**:
362
- - For each starting index `i`, we check if the subarray `nums[i:i+k]` is strictly increasing.
363
- - We use a nested loop to compare each element with the previous one. If any element is not greater than the previous one, we set `is_increasing` to `False` and break out of the loop.
364
-
365
- 3. **Check the next subarray**:
366
- - If the current subarray is strictly increasing, we then check the next subarray starting from `i + k`.
367
- - We use another nested loop to check if `nums[i+k:i+2*k]` is strictly increasing.
368
- - If the next subarray is also strictly increasing, we return `True`.
369
-
370
- 4. **Return the result**:
371
- - If we complete the loop without finding any such pair of subarrays, we return `False`.
372
-
373
- #### Clarity and Breakdown:
374
- - **Outer Loop**: The outer loop iterates through each possible starting point for a subarray of length `k`.
375
- - **Inner Loop for Current Subarray**: The inner loop checks if the current subarray is strictly increasing.
376
- - **Inner Loop for Next Subarray**: If the current subarray is strictly increasing, the inner loop checks if the next subarray is also strictly increasing.
377
- - **Return True/False**: If both subarrays are strictly increasing and adjacent, we return `True`. If no such pair is found, we return `False`.
378
-
379
- This approach ensures that we efficiently check all possible pairs of adjacent strictly increasing subarrays of length `k` in the given array.
380
- ```
381
-
382
- ## Skywork o1 Open PRM
383
- We introduce two advanced Process-Reward-Models (PRM) for reasoning tasks:
384
- - Skywork o1 Open-PRM-Qwen2.5-1.5B: Achieves the performance of 8B models, competing with advanced models like RLHFlow's Llama3.1-8B-PRM-Deepseek-Data and OpenR's Math-psa-7B.
385
- - Skywork o1 Open-PRM-Qwen2.5-7B: Matches or surpasses larger scale models like Qwen2.5-Math-RM-72B on most benchmarks, setting a new standard for AI reasoning.
386
-
387
- <img src="misc/prm_result_math.png" width="1000"/>
388
- <img src="misc/prm_result_code.png" width="1000"/>
389
-
390
- The inference code is publicly available at https://github.com/SkyworkAI/skywork-o1-prm-inference.
391
 
392
- # Contact
 
 
 
 
393
 
394
- If you have any questions, please feel free to reach us at {zifeng.cao, liang.zhao, liang.zeng, tianwen.wei}@kunlun-inc.com.
395
 
396
  # LICENSE
397
  The community usage of Skywork models require Skywork Community License. The Skywork models support commercial use. If you plan to use the Skywork models or its derivatives for commercial purposes, you must abide by terms and conditions within Skywork Community License.
@@ -408,7 +262,7 @@ If you find our work helpful, please feel free to cite us using the following Bi
408
  title={Skywork-o1 Open Series},
409
  author={Skywork-o1 Team},
410
  year={2024},
411
- month={September},
412
  howpublished={\url{https://huggingface.co/Skywork}},
413
  url={https://huggingface.co/Skywork},
414
  }
 
24
 
25
  Different from mere reproductions of the OpenAI o1 model, the Skywork o1 Open model series not only exhibits innate thinking, planning, and reflecting capabilities in its outputs, but also shows significant improvements in reasoning skills on standard benchmarks. This series represents a strategic advancement in AI capabilities, moving a previously weaker base model towards the state-of-the-art (SOTA) in reasoning tasks.
26
 
27
+ If you are interested in the Skywork o1 Open model series, please check out the [o1-llama-3.1-8b](https://huggingface.co/Skywork/o1-llama-3.1-8b) model.
 
 
28
 
 
29
 
 
30
 
31
+ # Model Information
32
+ The Skywork-o1-Open-PRM series are trained on [**Qwen2.5-Math-1.5B-Instruct**](https://huggingface.co/Qwen/Qwen2.5-Math-1.5B-Instruct) and [**Qwen2.5-Math-7B-Instruct**](https://huggingface.co/Qwen/Qwen2.5-Math-7B-Instruct).
 
 
33
 
 
34
 
35
+ # PRM Evaluation
36
 
37
+ ## Evaluation Settings
 
 
 
 
38
 
39
+ ### Mathematical Evaluation
40
+ We utilized the evaluation scripts from [Qwen2.5-Math](https://github.com/QwenLM/Qwen2.5-Math) and followed their configuration to ensure consistency. The selected datasets include **GSM8K**, **MATH**, **GaoKao**, **CN-Middle School 24**, **OlympiadBench**, **AMC-23**, and **AIME-24**. Among these, **GaoKao** and **CN-Middle School 24** are Chinese datasets, while the remaining datasets are in English. Notably, **OlympiadBench**, **AIME-24**, and **AMC-23** are competition-level datasets.
41
 
42
+ ### Code Evaluation
43
+ For code evaluation, we adopted the evaluation scripts from [Qwen2.5-Coder](https://github.com/QwenLM/Qwen2.5-Coder), maintaining the same configuration. The selected datasets include **HumanEval**, **MBPP**, and **LiveCodeBench**, with **LiveCodeBench** specifically using the version **2024.01-2024-11**.
 
44
 
 
45
 
46
+ ## Evaluation Base Models
 
47
 
48
+ We evaluated the performance of RMs on three base models: **Qwen2.5-7B-Instruct**, **Llama3.1-8B-Instruct**, and **Skywork-o1-Open-8B**. Data sampling was conducted to verify the performance of the RMs across different models. The sampling temperature was set to **0.7** for mathematical problems and **1.0** for code-related tasks.
49
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ ## Compared RMs
 
 
 
 
 
52
 
53
+ - [Qwen2.5-Math-RM-72B](https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B): An open-source ORM provided by the Qwen team.
54
+ - [OpenR-MATH-psa-PRM-7B](https://huggingface.co/openreasoner/Math-psa): An open-source PRM from the OpenR project.
55
+ - [RLHFlow-Deepseek-Data-PRM-8B](https://huggingface.co/RLHFlow/Llama3.1-8B-PRM-Deepseek-Data): An open-source PRM from the RLHFlow project.
56
 
 
 
 
 
 
57
 
58
+ ## Evaluation Metrics
 
 
 
 
 
59
 
60
+ - **Greedy Sampling Pass@1**: Uses greedy sampling for generating answers.
61
+ - **Majority Voting@64**: Randomly samples 64 answers.
62
+ - **Best-of-N@64**: Ranks 64 answers based on output values provided by the Reward Model (RM). The weighting methods differ for ORM and PRM:
63
+ - For **ORM**, only the reward from the final step is used.
64
+ - For **PRM**, the average reward across all steps is used for weighting.
65
 
 
 
 
 
66
 
67
+ ## Evaluation Results
 
 
 
 
 
68
 
69
+ ### Mathematical Evaluation
 
 
70
 
71
+ #### Skywork-o1-Open-8B
72
+ | Reward Model | Method | GSM8K | MATH | GaoKao | CN-Middle School 24 | OlympiadBench | AIME-24 | AMC-23 | Avg |
73
+ |--------------------------------|-------------------------|--------|--------|--------|---------------------|---------------|---------|--------|------|
74
+ | N/A | Greedy Sampling Pass@1 | 91.6 | 78.1 | 63.6 | 67.3 | 43.1 | 13.3 | 55.0 | 58.9 |
75
+ | N/A | Majority Voting@64 | 93.9 | 84.3 | 69.5 | 73.3 | 50.4 | 16.7 | 52.5 | 62.9 |
76
+ | OpenR-MATH-psa-PRM-**7B** | Best-of-N@64 | 95.1 | 82.7 | 67.1 | 70.3 | 47.6 | 20.0 | 57.5 | 62.9 |
77
+ | RLHFlow-Deepseek-Data-PRM-**8B** | Best-of-N@64 | 94.4 | 80.1 | 59.1 | 74.3 | 43.0 | 20.0 | 50.0 | 60.1 |
78
+ | Qwen2.5-Math-RM-**72B** | Best-of-N@64 | 96.1 | 86.9 | **76.0** | 76.2 | **53.3** | **26.7** | **65.0** | **68.6** |
79
+ | Skywork-o1-Open-PRM-**1.5B** | Best-of-N@64 | 94.5 | 85.0 | 65.6 | 73.3 | 49.9 | 16.7 | 62.5 | 63.9 |
80
+ | Skywork-o1-Open-PRM-**7B** | Best-of-N@64 | **96.7** | **87.0** | 70.3 | **76.2** | 52.3 | 23.3 | **65.0** | 67.3 |
81
 
 
 
82
 
83
+ #### Qwen2.5-7B-Instruct
84
+ | Reward Model | Method | GSM8K | MATH | GaoKao | CN-Middle School 24 | OlympiadBench | AIME-24 | AMC-23 | Avg |
85
+ |--------------------------------|-------------------------|--------|--------|--------|---------------------|---------------|---------|--------|------|
86
+ | N/A | Greedy Sampling Pass@1 | 91.9 | 75.2 | 55.6 | 75.2 | 39.1 | 13.3 | 45.0 | 56.5 |
87
+ | N/A | Majority Voting@64 | 93.5 | 78.4 | 55.3 | 78.2 | 40.1 | 13.3 | 62.5 | 60.2 |
88
+ | OpenR-MATH-psa-PRM-**7B** | Best-of-N@64 | 93.9 | 77.9 | 52.4 | 73.3 | 40.7 | 10.0 | 55.0 | 57.6 |
89
+ | RLHFlow-Deepseek-Data-PRM-**8B** | Best-of-N@64 | 94.1 | 78.1 | 53.2 | 75.2 | 39.1 | 16.7 | 55.0 | 58.8 |
90
+ | Qwen2.5-Math-RM-**72B** | Best-of-N@64 | 94.8 | **82.4** | **65.2** | **80.2** | **45.0** | **13.3** | 62.5 | 63.4 |
91
+ | Skywork-o1-Open-PRM-**1.5B** | Best-of-N@64 | 93.3 | 79.8 | 56.1 | 74.3 | 43.9 | 10.0 | 67.5 | 60.7 |
92
+ | Skywork-o1-Open-PRM-**7B** | Best-of-N@64 | **94.9** | 81.9 | 56.3 | 75.2 | 44.9 | **13.3** | **65.0** | 61.6 |
93
 
 
94
 
 
 
95
 
96
+ #### Llama3.1-8B-Instruct
97
+ | Reward Model | Method | GSM8K | MATH | GaoKao | CN-Middle School 24 | OlympiadBench | AIME-24 | AMC-23 | Avg |
98
+ |--------------------------------|-------------------------|--------|--------|--------|---------------------|---------------|---------|--------|------|
99
+ | N/A | Greedy Sampling Pass@1 | 85.3 | 49.7 | 25.3 | 47.5 | 16.6 | 6.7 | 27.5 | 36.9 |
100
+ | N/A | Majority Voting@64 | 90.9 | 62.9 | 28.0 | 56.4 | 26.4 | 13.3 | 37.5 | 45.1 |
101
+ | OpenR-MATH-psa-PRM-**7B** | Best-of-N@64 | 91.8 | 59.4 | 24.7 | 47.5 | 23.0 | 13.3 | 35.0 | 42.1 |
102
+ | RLHFlow-Deepseek-Data-PRM-**8B** | Best-of-N@64 | 89.8 | 56.1 | 24.0 | 40.6 | 20.4 | 0.0 | 35.0 | 38.0 |
103
+ | Qwen2.5-Math-RM-**72B** | Best-of-N@64 | **94.9** | **72.5** | **44.9** | **65.3** | **34.4** | **23.3** | **60.0** | 56.5 |
104
+ | Skywork-o1-Open-PRM-**1.5B** | Best-of-N@64 | 91.7 | 65.6 | 26.8 | 49.5 | 27.0 | 16.7 | **60.0** | 48.2 |
105
+ | Skywork-o1-Open-PRM-**7B** | Best-of-N@64 | 94.0 | 69.8 | 32.0 | 56.4 | 29.9 | 16.7 | 52.5 | 50.2 |
106
 
 
 
107
 
108
+ ### Code Evaluation
109
+ Since the compared PRMs have not been trained on code-related tasks, this section focuses solely on the performance of Skywork-o1-Open-PRM.
110
 
111
+ #### Skywork-o1-Open-8B
112
 
113
+ | Reward Model | Method | MBPP | MBPP+ | HumanEval | HumanEval+ | LiveCodeBench-2024.01-2024-11 |
114
+ |--------------------------|-------------------------|-------|-------|-----------|------------|-------------------------------|
115
+ | N/A | Greedy Sampling Pass@1 | 77.9 | 65.2 | **82.3** | **78.7** | 26.0 |
116
+ | Skywork-o1-Open-PRM-7B | Best-of-N@64 | **81.2** | **68.5** | 81.1 | 74.4 | **31.3** |
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
+ #### Qwen2.5-7B-Instruct
 
 
120
 
121
+ | Reward Model | Method | MBPP | MBPP+ | HumanEval | HumanEval+ | LiveCodeBench-2024.01-2024-11 |
122
+ |--------------------------|-------------------------|-------|-------|-----------|------------|-------------------------------|
123
+ | N/A | Greedy Sampling Pass@1 | 81.7 | 69.3 | **84.8** | **78.0** | 25.3 |
124
+ | Skywork-o1-Open-PRM-7B | Best-of-N@64 | **84.9** | **72.5** | 83.5 | **78.0** | |
125
 
126
+ #### Llama3.1-8B-Instruct
 
 
127
 
128
+ | Reward Model | Method | MBPP | MBPP+ | HumanEval | HumanEval+ | LiveCodeBench-2024.01-2024-11 |
129
+ |--------------------------|-------------------------|-------|-------|-----------|------------|-------------------------------|
130
+ | N/A | Greedy Sampling Pass@1 | 75.7 | 64.0 | 68.3 | 63.4 | 16.1 |
131
+ | Skywork-o1-Open-PRM-7B | Best-of-N@64 | **77.8** | **66.1** | **73.8** | **67.7** | **23.3** |
132
 
 
 
 
133
 
 
134
 
135
+ # Quick Start
136
+ For detailed instructions on using PRM, please refer to the [Skywork-o1-PRM-Inference Repository](https://github.com/SkyworkAI/skywork-o1-prm-inference).
 
137
 
 
138
 
 
 
 
139
 
140
+ ## huggingface inference
141
 
142
+ 1. clone skywork prm inference repo
143
+ ```shell
144
+ git clone https://github.com/SkyworkAI/skywork-o1-prm-inference.git
145
+ cd skywork-o1-prm-inference
146
  ```
147
+ 1. run_prm_inference
 
 
 
 
 
 
148
 
149
+ ```python
150
+ from transformers import AutoTokenizer
151
+ from model_utils.prm_model import PRM_MODEL
152
+ from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards
153
 
154
+ prm_model_path = "/path/to/prm_model"
155
+ tokenizer = AutoTokenizer.from_pretrained(prm_model_path, trust_remote_code=True)
156
+ datas = [
157
+ {
158
+ "problem" : "Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
159
+ "response" : "To determine how much money Janet makes every day at the farmers' market, we need to follow these steps:\n1. **Calculate the total number of eggs laid by the ducks per day.**\n Janet's ducks lay 16 eggs per day.\n2. **Determine the number of eggs Janet uses each day.**\n - She eats 3 eggs for breakfast every morning.\n - She bakes muffins for her friends every day with 4 eggs.\n So, the total number of eggs used per day is:\n \\[\n 3 + 4 = 7 \\text{ eggs}\n \\]\n3. **Calculate the number of eggs Janet sells at the farmers' market each day.**\n Subtract the number of eggs used from the total number of eggs laid:\n \\[\n 16 - 7 = 9 \\text{ eggs}\n \\]\n4. **Determine how much money Janet makes from selling the eggs.**\n She sells each egg for $2, so the total amount of money she makes is:\n \\[\n 9 \\times 2 = 18 \\text{ dollars}\n \\]\nTherefore, the amount of money Janet makes every day at the farmers' market is \\(\\boxed{18}\\)."
160
+ },
161
+ {
162
+ "problem" : "Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
163
+ "response" : "To find out how much Janet makes every day at the farmers' market, we need to follow these steps:\n1. Calculate the total number of eggs laid by the ducks: \n - Janet's ducks lay 16 eggs per day.\n2. Calculate the total number of eggs consumed by Janet and used for baking:\n - Janet eats 3 eggs for breakfast.\n - She bakes 4 eggs for muffins.\n - Total eggs used: 3 (breakfast) + 4 (baking) = 7 eggs.\n3. Calculate the remaining number of eggs for sale:\n - Total eggs laid: 16\n - Eggs used: 7\n - Remaining eggs: 16 - 7 = 9 eggs\n4. Calculate the total amount of money made at the farmers' market:\n - Price per egg: $2\n - Number of eggs sold: 9\n - Total money made: 9 * $2 = $18\nTherefore, Janet makes $\\boxed{18}$ dollars every day at the farmers' market."
164
+ }
165
+ ]
166
 
 
167
 
168
+ processed_data = [prepare_input(d["problem"], d["response"], tokenizer=tokenizer, step_token="\n") for d in datas]
169
+ input_ids, steps, reward_flags = zip(*processed_data)
 
170
 
171
+ model = PRM_MODEL.from_pretrained(prm_model_path, device_map="auto").eval()
172
+ input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(input_ids, reward_flags, tokenizer.pad_token_id)
173
+ _, _, rewards = model(input_ids=input_ids, attention_mask=attention_mask, return_probs=True)
174
+ step_rewards = derive_step_rewards(rewards, reward_flags)
175
+ print("step_rewards:",step_rewards[0])
176
+ print("step_rewards:",step_rewards[1])
177
+ ```
178
 
179
+ ## vllm server for inference
 
 
 
180
 
181
+ 1. install vllm and install vllm prm plugin
182
+ ```shell
183
+ pip install vllm==v0.6.4.post1
184
+ git clone https://github.com/SkyworkAI/skywork-o1-prm-inference.git
185
+ cd skywork-o1-prm-inference
186
+ pip install -e .
187
+ ```
188
 
189
+ 2. start vllm server
190
+ ```shell
191
+ CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve /path/to/prm_model \
192
+ --host 0.0.0.0 \
193
+ --port 8081 \
194
+ --tensor-parallel-size 4 \
195
+ --gpu-memory-utilization 0.9 \
196
+ --enable-prefix-caching \
197
+ --dtype auto
198
+ ```
199
 
200
+ 3. request server for inference
201
 
202
+ ```python
203
+ from openai import OpenAI
204
+ from transformers import AutoTokenizer
205
+ from model_utils.io_utils import prepare_input, derive_step_rewards_vllm
206
 
207
+ prm_model_path = "/path/to/prm_model"
208
+ tokenizer = AutoTokenizer.from_pretrained(prm_model_path, trust_remote_code=True)
209
+ datas = [
210
+ {
211
+ "problem" : "Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
212
+ "response" : "To determine how much money Janet makes every day at the farmers' market, we need to follow these steps:\n1. **Calculate the total number of eggs laid by the ducks per day.**\n Janet's ducks lay 16 eggs per day.\n2. **Determine the number of eggs Janet uses each day.**\n - She eats 3 eggs for breakfast every morning.\n - She bakes muffins for her friends every day with 4 eggs.\n So, the total number of eggs used per day is:\n \\[\n 3 + 4 = 7 \\text{ eggs}\n \\]\n3. **Calculate the number of eggs Janet sells at the farmers' market each day.**\n Subtract the number of eggs used from the total number of eggs laid:\n \\[\n 16 - 7 = 9 \\text{ eggs}\n \\]\n4. **Determine how much money Janet makes from selling the eggs.**\n She sells each egg for $2, so the total amount of money she makes is:\n \\[\n 9 \\times 2 = 18 \\text{ dollars}\n \\]\nTherefore, the amount of money Janet makes every day at the farmers' market is \\(\\boxed{18}\\)."
213
+ },
214
+ {
215
+ "problem" : "Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
216
+ "response" : "To find out how much Janet makes every day at the farmers' market, we need to follow these steps:\n1. Calculate the total number of eggs laid by the ducks: \n - Janet's ducks lay 16 eggs per day.\n2. Calculate the total number of eggs consumed by Janet and used for baking:\n - Janet eats 3 eggs for breakfast.\n - She bakes 4 eggs for muffins.\n - Total eggs used: 3 (breakfast) + 4 (baking) = 7 eggs.\n3. Calculate the remaining number of eggs for sale:\n - Total eggs laid: 16\n - Eggs used: 7\n - Remaining eggs: 16 - 7 = 9 eggs\n4. Calculate the total amount of money made at the farmers' market:\n - Price per egg: $2\n - Number of eggs sold: 9\n - Total money made: 9 * $2 = $18\nTherefore, Janet makes $\\boxed{18}$ dollars every day at the farmers' market."
217
+ }
218
+ ]
219
 
220
+ # data preprocessing
221
+ processed_data = [prepare_input(d["problem"], d["response"], tokenizer=tokenizer, step_token="\n") for d in datas]
222
+ input_ids, steps, reward_flags = zip(*processed_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
+ openai_api_key = "EMPTY"
225
+ openai_api_base = "http://localhost:8081/v1"
226
+ client = OpenAI(
227
+ # defaults to os.environ.get("OPENAI_API_KEY")
228
+ api_key=openai_api_key,
229
+ base_url=openai_api_base,
230
+ )
231
+ models = client.models.list()
232
+ model = models.data[0].id
233
+ rewards = client.embeddings.create(
234
+ input=input_ids,
235
+ model=model,
236
+ )
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
+ step_rewards = derive_step_rewards_vllm(rewards, reward_flags)
239
+ print("step_rewards:",step_rewards[0])
240
+ print("step_rewards:",step_rewards[1])
241
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
 
243
+ # TODO
244
+ - Add more results for step-wise RM evaluation methods.
245
+ - Adjust the RM architecture to enhance compatibility with vLLM/sglang inference.
246
+ - Expand RM use cases by incorporating more types of reasoning tasks.
247
+ - Mitigate performance conflicts across different reasoning tasks.
248
 
 
249
 
250
  # LICENSE
251
  The community usage of Skywork models require Skywork Community License. The Skywork models support commercial use. If you plan to use the Skywork models or its derivatives for commercial purposes, you must abide by terms and conditions within Skywork Community License.
 
262
  title={Skywork-o1 Open Series},
263
  author={Skywork-o1 Team},
264
  year={2024},
265
+ month={November},
266
  howpublished={\url{https://huggingface.co/Skywork}},
267
  url={https://huggingface.co/Skywork},
268
  }