[docs] update readme: add more base model and rm evaluation
README.md CHANGED
@@ -24,374 +24,228 @@ We are excited to announce the release of the Skywork o1 Open model series, deve
Different from mere reproductions of the OpenAI o1 model, the Skywork o1 Open model series not only exhibits innate thinking, planning, and reflecting capabilities in its outputs, but also shows significant improvements in reasoning skills on standard benchmarks. This series represents a strategic advancement in AI capabilities, moving a previously weaker base model towards the state-of-the-art (SOTA) in reasoning tasks.

The Skywork o1 Open series' remarkable cognitive abilities are developed through a three-stage training scheme:

- Reflective Reasoning Training: Utilizing a proprietary multi-agent system to generate high-quality, diverse data for long-thinking tasks, followed by continuous pre-training and supervised fine-tuning.

- Reinforcement Learning for Reasoning Capabilities: Introduction of the Skywork o1 Process Reward Model (PRM), tailored to enhance step-by-step reasoning. Combined with proprietary reasoning reinforcement algorithms, our experiments confirm that the Skywork-PRM effectively captures the influence of intermediate reasoning steps on final outcomes.

- Reasoning Planning: Deploying Tiangong's proprietary Q* online reasoning algorithm alongside model-based thinking to search for optimal reasoning paths. This marks the first implementation and public release of a Q* algorithm, significantly boosting the model's online reasoning capabilities.

The Skywork o1 Open series features:

- Enhanced model thinking and planning capabilities.
- Advanced self-reflection and self-verification abilities.

Compared to previous large models, the Skywork o1 Open series adeptly handles a variety of reasoning challenges, including common-sense, logical, mathematical, ethical decision-making, and logical trap problems.

## Skywork o1 Open 8B

The [Skywork o1 Open 8B](https://huggingface.co/Skywork/Skywork-o1-Open-Llama3.1-8B) model shows notable improvements across various mathematical and coding benchmarks, pushing the performance of Llama-3.1-8B to the forefront of its category and outperforming the prior SOTA model of similar size, Qwen2.5-7B-Instruct.

<img src="misc/main_result_math.png" width="1000"/>
<img src="misc/main_result_code.png" width="1000"/>

A minimal chat example with `transformers`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Skywork/Skywork-o1-Open-Llama3.1-8B"

system_prompt = """You are Skywork-o1, a thinking model developed by Skywork AI, specializing in solving complex problems involving mathematics, coding, and logical reasoning through deep thought. When faced with a user's request, you first engage in a lengthy and in-depth thinking process to explore possible solutions to the problem. After completing your thoughts, you then provide a detailed explanation of the solution process in your response."""

problem = "Jane has 12 apples. She gives 4 apples to her friend Mark, then buys 1 more apple, and finally splits all her apples equally among herself and her 2 siblings. How many apples does each person get?"
user_message = problem

conversation = [
    {
        "role": "system",
        "content": system_prompt
    },
    {
        "role": "user",
        "content": user_message
    }
]

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

input_ids = tokenizer.apply_chat_template(
    conversation,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt").to(model.device)

generation = model.generate(
    input_ids=input_ids,
    max_new_tokens=2048,
    do_sample=False,
    pad_token_id=128009,
    temperature=0)

# Decode only the newly generated tokens.
completion = tokenizer.decode(
    generation[0][len(input_ids[0]):],
    skip_special_tokens=True)
print(completion)
```

## Reasoning Examples

Here are a few examples demonstrating how Skywork-o1-Open-Llama3.1-8B excels in planning, reasoning, and reflection.

```
Input: A project takes the Shandong team 15 days to complete working alone, and the Jiangsu team 10 days working alone. How many days does it take the two teams working together?
```
```
Output: OK, let's solve this interesting work-rate problem. I'll analyze it step by step in a relaxed tone.

1. The Shandong team needs 15 days to complete the project alone
2. The Jiangsu team needs 10 days to complete the project alone

[...]

1. The Shandong team's rate:
   Completing 1 project in 15 days means completing 1/15 of the project per day

[...]

Hold on, let's check that calculation again... right, it is correct.

Now, when the two teams work together, the amount of work completed per day is the sum of their rates:
1/15 + 1/10 = 2/30 + 3/30 = 5/30 = 1/6

[...]

Next, we need to work out how many days it takes to complete 1 project. We can use the equation:
(1/6) × x = 1

[...]

Therefore, the final answer is: working together, the two teams need 6 days to complete the project.
```
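The work-rate arithmetic in this transcript is easy to double-check with Python's `fractions` module:

```python
from fractions import Fraction

# Daily work rates: each team finishes this fraction of the project per day.
shandong = Fraction(1, 15)
jiangsu = Fraction(1, 10)

combined = shandong + jiangsu
print(combined)      # 1/6 of the project per day
print(1 / combined)  # 6 days for the whole project
```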

### Mathematical Problem in English
```
Input: Determine the value of $y$ in its simplest form, where $y = \frac{p + q\sqrt{r}}{s}$, given that $\frac{7y}{5} - 2 = \frac{4}{y}$, and $p, q, r, s$ are integers. What is $\frac{pqr}{s}$?
```
```
Output: To determine the value of \( y \) in the form \( y = \frac{p + q\sqrt{r}}{s} \), we start with the given equation:

\[
\frac{7y}{5} - 2 = \frac{4}{y}
\]

First, we eliminate the denominators by multiplying both sides of the equation by \( 5y \):

\[
7y^2 - 10y = 20
\]

Next, we rearrange the equation to set it to zero:

\[
7y^2 - 10y - 20 = 0
\]

We solve this quadratic equation using the quadratic formula \( y = \frac{-b \pm \sqrt{b^2 - 4ac}}{2a} \), where \( a = 7 \), \( b = -10 \), and \( c = -20 \). Plugging in these values, we get:

\[
y = \frac{-(-10) \pm \sqrt{(-10)^2 - 4 \cdot 7 \cdot (-20)}}{2 \cdot 7} = \frac{10 \pm \sqrt{100 + 560}}{14} = \frac{10 \pm \sqrt{660}}{14}
\]

We simplify \( \sqrt{660} \) by factoring out the largest perfect square:

\[
\sqrt{660} = \sqrt{4 \cdot 165} = 2\sqrt{165}
\]

Substituting back, we get:

\[
y = \frac{10 \pm 2\sqrt{165}}{14} = \frac{5 \pm \sqrt{165}}{7}
\]

Since we are asked for the value of \( y \) in the form \( \frac{p + q\sqrt{r}}{s} \), we can choose either the positive or negative root. Let's take the positive root for simplicity:

\[
y = \frac{5 + \sqrt{165}}{7}
\]

Here, \( p = 5 \), \( q = 1 \), \( r = 165 \), and \( s = 7 \). We need to find the value of \( \frac{pqr}{s} \):

\[
\frac{pqr}{s} = \frac{5 \cdot 1 \cdot 165}{7} = \frac{825}{7}
\]

Therefore, the value of \( \frac{pqr}{s} \) is \( \frac{825}{7} \).
```
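A quick numeric sanity check (not part of the model output) confirms that the positive root derived above satisfies the original equation:

```python
import math

# Positive root y = (5 + sqrt(165)) / 7 from the derivation above.
y = (5 + math.sqrt(165)) / 7

# It should satisfy 7y/5 - 2 = 4/y.
print(abs((7 * y / 5 - 2) - 4 / y) < 1e-9)  # True

# p*q*r/s = 5*1*165/7, i.e. 825/7
print(5 * 1 * 165 / 7)  # ≈ 117.857
```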

```
Input: Gunpowder : firecracker : er-ti-jiao (double-bang firecracker) corresponds to which of the following options: A. Wrought iron : knives : three-edged knife  B. Red brick : buildings : siheyuan courtyard  C. Clear water : baijiu : Wuliangye  D. Fir wood : musical instruments : six-stringed qin
```
```
Output: OK, let's work through this logical-relation question together. We need to find the relationship most similar to "gunpowder : firecracker : er-ti-jiao".

[...]

Now, let's check the options one by one:

[...]

This relationship looks very similar!

[...]

Clear water is the raw material for making baijiu, and baijiu is made from clear water.
But Wuliangye is a kind of baijiu, not a raw material of baijiu.
This relationship is not a good match.

[...]
```

```
Input: How many `r`s in strawberry?
```
```
Output: To determine how many `r`s are in the word "strawberry," we will follow a step-by-step approach:

1. **Identify the word**: The word we need to analyze is "strawberry."
2. **List each letter**: Write down each letter in the word to ensure we don't miss any.
   - s
   - t
   - r
   - a
   - w
   - b
   - e
   - r
   - r
   - y
3. **Count the occurrences of `r`**: Now, we count how many times the letter `r` appears in the list.
   - The first `r` is in the third position.
   - The second `r` is in the eighth position.
   - The third `r` is in the ninth position.
4. **Verify the count**: Double-check that we have counted all the `r`s correctly and haven't missed any or counted any other letters as `r`.

After following these steps, we find that there are three `r`s in the word "strawberry."

Therefore, the number of `r`s in strawberry is \(\boxed{3}\).
```
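The letter count is trivial to verify in Python:

```python
word = "strawberry"
print(word.count("r"))  # 3
# 1-based positions of each `r`
print([i + 1 for i, ch in enumerate(word) if ch == "r"])  # [3, 8, 9]
```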

```
[...]

Example 2:
Input: `nums = [1,2,3,4,4,4,4,5,6,7]`, `k = 5`
Output: `false`

Hints:
- `2 <= nums.length <= 1000`
- `1 <= 2 * k <= nums.length`
- `-1000 <= nums[i] <= 1000`

Please output the answer in the following format:
```
[...]
```
Output: ### Planning Phase

#### Key Steps:
1. **Iterate through the array**: Start from the first element and check every possible starting point for a subarray of length `k`.
2. **Check for strictly increasing subarray**: For each starting point, check if the subarray of length `k` is strictly increasing.
3. **Check the next subarray**: If the current subarray is strictly increasing, check if the next subarray of length `k` starting from the next element is also strictly increasing.
4. **Return the result**: If both subarrays are strictly increasing and adjacent, return `true`. If no such pair is found after checking all possible starting points, return `false`.

#### Important Considerations:
- **Edge Cases**:
  - Arrays with fewer than `2*k` elements.
  - Arrays where no subarray of length `k` is strictly increasing.
  - Arrays where multiple subarrays of length `k` are strictly increasing, but none are adjacent.
- **Pitfalls**:
  - Ensure that the subarrays are strictly increasing, not just non-decreasing.
  - Handle cases where the array length is exactly `2*k`.

### Code Solution

```python
class Solution(object):
    def hasincreasingSubarrays(self, nums, k):
        # Iterate through the array, stopping where the second subarray would go out of bounds
        for i in range(len(nums) - 2 * k + 1):
            # Check if the current subarray is strictly increasing
            is_increasing = True
            for j in range(1, k):
                if nums[i + j] <= nums[i + j - 1]:
                    is_increasing = False
                    break

            # If the current subarray is strictly increasing, check the next subarray
            if is_increasing:
                # Check if the next subarray is strictly increasing
                for j in range(1, k):
                    if nums[i + k + j] <= nums[i + k + j - 1]:
                        break
                else:
                    # If the next subarray is also strictly increasing, return True
                    return True

        # If no such pair of subarrays is found, return False
        return False
```

### Solution Explanation

#### Detailed Explanation:
1. **Iterate through the array**:
   - We use a loop to iterate through the array from the first element to the last position where two adjacent subarrays of length `k` still fit (`len(nums) - 2*k + 1` starting points).

2. **Check for strictly increasing subarray**:
   - For each starting index `i`, we check if the subarray `nums[i:i+k]` is strictly increasing.
   - We use a nested loop to compare each element with the previous one. If any element is not greater than the previous one, we set `is_increasing` to `False` and break out of the loop.

3. **Check the next subarray**:
   - If the current subarray is strictly increasing, we then check the next subarray starting from `i + k`.
   - We use another nested loop to check if `nums[i+k:i+2*k]` is strictly increasing.
   - If the next subarray is also strictly increasing, we return `True`.

4. **Return the result**:
   - If we complete the loop without finding any such pair of subarrays, we return `False`.

#### Clarity and Breakdown:
- **Outer Loop**: The outer loop iterates through each possible starting point for a subarray of length `k`.
- **Inner Loop for Current Subarray**: The inner loop checks if the current subarray is strictly increasing.
- **Inner Loop for Next Subarray**: If the current subarray is strictly increasing, the inner loop checks if the next subarray is also strictly increasing.
- **Return True/False**: If both subarrays are strictly increasing and adjacent, we return `True`. If no such pair is found, we return `False`.

This approach ensures that we efficiently check all possible pairs of adjacent strictly increasing subarrays of length `k` in the given array.
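The approach above can be exercised with a compact, self-contained sketch (the helper name is hypothetical; the loop bound stops where the second window would run past the end of the array):

```python
def has_adjacent_increasing_subarrays(nums, k):
    # True if two adjacent length-k windows are each strictly increasing.
    def strictly_increasing(start):
        return all(nums[j] > nums[j - 1] for j in range(start + 1, start + k))

    return any(
        strictly_increasing(i) and strictly_increasing(i + k)
        for i in range(len(nums) - 2 * k + 1)
    )

print(has_adjacent_increasing_subarrays([1, 2, 3, 4, 4, 4, 4, 5, 6, 7], 5))  # False (Example 2)
print(has_adjacent_increasing_subarrays([2, 5, 7, 8, 9, 2, 3, 4, 3, 1], 3))  # True
```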

## Skywork o1 Open PRM

We introduce two advanced Process Reward Models (PRMs) for reasoning tasks:

- Skywork o1 Open-PRM-Qwen2.5-1.5B: Achieves the performance of 8B models, competing with advanced models like RLHFlow's Llama3.1-8B-PRM-Deepseek-Data and OpenR's Math-psa-7B.
- Skywork o1 Open-PRM-Qwen2.5-7B: Matches or surpasses larger-scale models like Qwen2.5-Math-RM-72B on most benchmarks, setting a new standard for AI reasoning.

<img src="misc/prm_result_math.png" width="1000"/>
<img src="misc/prm_result_code.png" width="1000"/>

The inference code is publicly available at https://github.com/SkyworkAI/skywork-o1-prm-inference.

# Contact

If you have any questions, please feel free to reach us at {zifeng.cao, liang.zhao, liang.zeng, tianwen.wei}@kunlun-inc.com.

# LICENSE

The community usage of Skywork models requires the Skywork Community License. The Skywork models support commercial use. If you plan to use the Skywork models or their derivatives for commercial purposes, you must abide by the terms and conditions within the Skywork Community License.

@@ -408,7 +262,7 @@ If you find our work helpful, please feel free to cite us using the following BibTeX entry:

    title={Skywork-o1 Open Series},
    author={Skywork-o1 Team},
    year={2024},
    month={November},
    howpublished={\url{https://huggingface.co/Skywork}},
    url={https://huggingface.co/Skywork},
}

If you are interested in the Skywork o1 Open model series, please check out the [o1-llama-3.1-8b](https://huggingface.co/Skywork/o1-llama-3.1-8b) model.

# Model Information

The Skywork-o1-Open-PRM series is trained on [**Qwen2.5-Math-1.5B-Instruct**](https://huggingface.co/Qwen/Qwen2.5-Math-1.5B-Instruct) and [**Qwen2.5-Math-7B-Instruct**](https://huggingface.co/Qwen/Qwen2.5-Math-7B-Instruct).

# PRM Evaluation

## Evaluation Settings

### Mathematical Evaluation

We utilized the evaluation scripts from [Qwen2.5-Math](https://github.com/QwenLM/Qwen2.5-Math) and followed their configuration to ensure consistency. The selected datasets include **GSM8K**, **MATH**, **GaoKao**, **CN-Middle School 24**, **OlympiadBench**, **AMC-23**, and **AIME-24**. Among these, **GaoKao** and **CN-Middle School 24** are Chinese datasets, while the remaining datasets are in English. Notably, **OlympiadBench**, **AIME-24**, and **AMC-23** are competition-level datasets.

### Code Evaluation

For code evaluation, we adopted the evaluation scripts from [Qwen2.5-Coder](https://github.com/QwenLM/Qwen2.5-Coder), maintaining the same configuration. The selected datasets include **HumanEval**, **MBPP**, and **LiveCodeBench**, with **LiveCodeBench** specifically using the **2024.01-2024.11** version.

## Evaluation Base Models

We evaluated the performance of RMs on three base models: **Qwen2.5-7B-Instruct**, **Llama3.1-8B-Instruct**, and **Skywork-o1-Open-8B**. Data sampling was conducted to verify the performance of the RMs across different models. The sampling temperature was set to **0.7** for mathematical problems and **1.0** for code-related tasks.

## Compared RMs

- [Qwen2.5-Math-RM-72B](https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B): An open-source ORM provided by the Qwen team.
- [OpenR-MATH-psa-PRM-7B](https://huggingface.co/openreasoner/Math-psa): An open-source PRM from the OpenR project.
- [RLHFlow-Deepseek-Data-PRM-8B](https://huggingface.co/RLHFlow/Llama3.1-8B-PRM-Deepseek-Data): An open-source PRM from the RLHFlow project.

## Evaluation Metrics

- **Greedy Sampling Pass@1**: Uses greedy decoding to generate a single answer.
- **Majority Voting@64**: Randomly samples 64 answers and takes the majority vote.
- **Best-of-N@64**: Ranks 64 sampled answers by the scores assigned by the Reward Model (RM). The weighting methods differ for ORM and PRM:
  - For **ORM**, only the reward from the final step is used.
  - For **PRM**, the average reward across all steps is used for weighting.
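As an illustration, the two weighting schemes can be sketched as follows (a minimal sketch with hypothetical candidate data, not the evaluation code itself):

```python
def best_of_n(candidates, use_prm=True):
    """Pick the answer whose RM score is highest.

    candidates: list of (answer, step_rewards) pairs, where step_rewards
    holds the per-step scores produced by the reward model.
    """
    def score(step_rewards):
        if use_prm:
            # PRM: average reward across all steps
            return sum(step_rewards) / len(step_rewards)
        # ORM: only the reward of the final step
        return step_rewards[-1]

    return max(candidates, key=lambda c: score(c[1]))[0]

candidates = [
    ("answer A", [0.9, 0.8, 0.7]),   # avg 0.8, final 0.7
    ("answer B", [0.5, 0.6, 0.95]),  # avg ~0.68, final 0.95
]
print(best_of_n(candidates, use_prm=True))   # PRM-style scoring picks "answer A"
print(best_of_n(candidates, use_prm=False))  # ORM-style scoring picks "answer B"
```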

## Evaluation Results

### Mathematical Evaluation

#### Skywork-o1-Open-8B

| Reward Model | Method | GSM8K | MATH | GaoKao | CN-Middle School 24 | OlympiadBench | AIME-24 | AMC-23 | Avg |
|--------------------------------|-------------------------|--------|--------|--------|---------------------|---------------|---------|--------|------|
| N/A | Greedy Sampling Pass@1 | 91.6 | 78.1 | 63.6 | 67.3 | 43.1 | 13.3 | 55.0 | 58.9 |
| N/A | Majority Voting@64 | 93.9 | 84.3 | 69.5 | 73.3 | 50.4 | 16.7 | 52.5 | 62.9 |
| OpenR-MATH-psa-PRM-**7B** | Best-of-N@64 | 95.1 | 82.7 | 67.1 | 70.3 | 47.6 | 20.0 | 57.5 | 62.9 |
| RLHFlow-Deepseek-Data-PRM-**8B** | Best-of-N@64 | 94.4 | 80.1 | 59.1 | 74.3 | 43.0 | 20.0 | 50.0 | 60.1 |
| Qwen2.5-Math-RM-**72B** | Best-of-N@64 | 96.1 | 86.9 | **76.0** | 76.2 | **53.3** | **26.7** | **65.0** | **68.6** |
| Skywork-o1-Open-PRM-**1.5B** | Best-of-N@64 | 94.5 | 85.0 | 65.6 | 73.3 | 49.9 | 16.7 | 62.5 | 63.9 |
| Skywork-o1-Open-PRM-**7B** | Best-of-N@64 | **96.7** | **87.0** | 70.3 | **76.2** | 52.3 | 23.3 | **65.0** | 67.3 |

#### Qwen2.5-7B-Instruct

| Reward Model | Method | GSM8K | MATH | GaoKao | CN-Middle School 24 | OlympiadBench | AIME-24 | AMC-23 | Avg |
|--------------------------------|-------------------------|--------|--------|--------|---------------------|---------------|---------|--------|------|
| N/A | Greedy Sampling Pass@1 | 91.9 | 75.2 | 55.6 | 75.2 | 39.1 | 13.3 | 45.0 | 56.5 |
| N/A | Majority Voting@64 | 93.5 | 78.4 | 55.3 | 78.2 | 40.1 | 13.3 | 62.5 | 60.2 |
| OpenR-MATH-psa-PRM-**7B** | Best-of-N@64 | 93.9 | 77.9 | 52.4 | 73.3 | 40.7 | 10.0 | 55.0 | 57.6 |
| RLHFlow-Deepseek-Data-PRM-**8B** | Best-of-N@64 | 94.1 | 78.1 | 53.2 | 75.2 | 39.1 | 16.7 | 55.0 | 58.8 |
| Qwen2.5-Math-RM-**72B** | Best-of-N@64 | 94.8 | **82.4** | **65.2** | **80.2** | **45.0** | 13.3 | 62.5 | **63.4** |
| Skywork-o1-Open-PRM-**1.5B** | Best-of-N@64 | 93.3 | 79.8 | 56.1 | 74.3 | 43.9 | 10.0 | **67.5** | 60.7 |
| Skywork-o1-Open-PRM-**7B** | Best-of-N@64 | **94.9** | 81.9 | 56.3 | 75.2 | 44.9 | 13.3 | 65.0 | 61.6 |

#### Llama3.1-8B-Instruct

| Reward Model | Method | GSM8K | MATH | GaoKao | CN-Middle School 24 | OlympiadBench | AIME-24 | AMC-23 | Avg |
|--------------------------------|-------------------------|--------|--------|--------|---------------------|---------------|---------|--------|------|
| N/A | Greedy Sampling Pass@1 | 85.3 | 49.7 | 25.3 | 47.5 | 16.6 | 6.7 | 27.5 | 36.9 |
| N/A | Majority Voting@64 | 90.9 | 62.9 | 28.0 | 56.4 | 26.4 | 13.3 | 37.5 | 45.1 |
| OpenR-MATH-psa-PRM-**7B** | Best-of-N@64 | 91.8 | 59.4 | 24.7 | 47.5 | 23.0 | 13.3 | 35.0 | 42.1 |
| RLHFlow-Deepseek-Data-PRM-**8B** | Best-of-N@64 | 89.8 | 56.1 | 24.0 | 40.6 | 20.4 | 0.0 | 35.0 | 38.0 |
| Qwen2.5-Math-RM-**72B** | Best-of-N@64 | **94.9** | **72.5** | **44.9** | **65.3** | **34.4** | **23.3** | **60.0** | **56.5** |
| Skywork-o1-Open-PRM-**1.5B** | Best-of-N@64 | 91.7 | 65.6 | 26.8 | 49.5 | 27.0 | 16.7 | **60.0** | 48.2 |
| Skywork-o1-Open-PRM-**7B** | Best-of-N@64 | 94.0 | 69.8 | 32.0 | 56.4 | 29.9 | 16.7 | 52.5 | 50.2 |

### Code Evaluation

Since the compared PRMs have not been trained on code-related tasks, this section focuses solely on the performance of Skywork-o1-Open-PRM.

#### Skywork-o1-Open-8B

| Reward Model | Method | MBPP | MBPP+ | HumanEval | HumanEval+ | LiveCodeBench-2024.01-2024.11 |
|--------------------------|-------------------------|-------|-------|-----------|------------|-------------------------------|
| N/A | Greedy Sampling Pass@1 | 77.9 | 65.2 | **82.3** | **78.7** | 26.0 |
| Skywork-o1-Open-PRM-7B | Best-of-N@64 | **81.2** | **68.5** | 81.1 | 74.4 | **31.3** |

#### Qwen2.5-7B-Instruct

| Reward Model | Method | MBPP | MBPP+ | HumanEval | HumanEval+ | LiveCodeBench-2024.01-2024.11 |
|--------------------------|-------------------------|-------|-------|-----------|------------|-------------------------------|
| N/A | Greedy Sampling Pass@1 | 81.7 | 69.3 | **84.8** | **78.0** | 25.3 |
| Skywork-o1-Open-PRM-7B | Best-of-N@64 | **84.9** | **72.5** | 83.5 | **78.0** | |

#### Llama3.1-8B-Instruct

| Reward Model | Method | MBPP | MBPP+ | HumanEval | HumanEval+ | LiveCodeBench-2024.01-2024.11 |
|--------------------------|-------------------------|-------|-------|-----------|------------|-------------------------------|
| N/A | Greedy Sampling Pass@1 | 75.7 | 64.0 | 68.3 | 63.4 | 16.1 |
| Skywork-o1-Open-PRM-7B | Best-of-N@64 | **77.8** | **66.1** | **73.8** | **67.7** | **23.3** |

# Quick Start

For detailed instructions on using the PRM, please refer to the [Skywork-o1-PRM-Inference Repository](https://github.com/SkyworkAI/skywork-o1-prm-inference).

## Hugging Face Inference

1. Clone the Skywork PRM inference repo:
```shell
git clone https://github.com/SkyworkAI/skywork-o1-prm-inference.git
cd skywork-o1-prm-inference
```
2. Run PRM inference:

```python
from transformers import AutoTokenizer
from model_utils.prm_model import PRM_MODEL
from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards

prm_model_path = "/path/to/prm_model"
tokenizer = AutoTokenizer.from_pretrained(prm_model_path, trust_remote_code=True)
datas = [
    {
        "problem" : "Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
        "response" : "To determine how much money Janet makes every day at the farmers' market, we need to follow these steps:\n1. **Calculate the total number of eggs laid by the ducks per day.**\n Janet's ducks lay 16 eggs per day.\n2. **Determine the number of eggs Janet uses each day.**\n - She eats 3 eggs for breakfast every morning.\n - She bakes muffins for her friends every day with 4 eggs.\n So, the total number of eggs used per day is:\n \\[\n 3 + 4 = 7 \\text{ eggs}\n \\]\n3. **Calculate the number of eggs Janet sells at the farmers' market each day.**\n Subtract the number of eggs used from the total number of eggs laid:\n \\[\n 16 - 7 = 9 \\text{ eggs}\n \\]\n4. **Determine how much money Janet makes from selling the eggs.**\n She sells each egg for $2, so the total amount of money she makes is:\n \\[\n 9 \\times 2 = 18 \\text{ dollars}\n \\]\nTherefore, the amount of money Janet makes every day at the farmers' market is \\(\\boxed{18}\\)."
    },
    {
        "problem" : "Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
        "response" : "To find out how much Janet makes every day at the farmers' market, we need to follow these steps:\n1. Calculate the total number of eggs laid by the ducks: \n - Janet's ducks lay 16 eggs per day.\n2. Calculate the total number of eggs consumed by Janet and used for baking:\n - Janet eats 3 eggs for breakfast.\n - She bakes 4 eggs for muffins.\n - Total eggs used: 3 (breakfast) + 4 (baking) = 7 eggs.\n3. Calculate the remaining number of eggs for sale:\n - Total eggs laid: 16\n - Eggs used: 7\n - Remaining eggs: 16 - 7 = 9 eggs\n4. Calculate the total amount of money made at the farmers' market:\n - Price per egg: $2\n - Number of eggs sold: 9\n - Total money made: 9 * $2 = $18\nTherefore, Janet makes $\\boxed{18}$ dollars every day at the farmers' market."
    }
]

# Split each (problem, response) pair into per-step token ids and reward flags.
processed_data = [prepare_input(d["problem"], d["response"], tokenizer=tokenizer, step_token="\n") for d in datas]
input_ids, steps, reward_flags = zip(*processed_data)

model = PRM_MODEL.from_pretrained(prm_model_path, device_map="auto").eval()
input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(input_ids, reward_flags, tokenizer.pad_token_id)
_, _, rewards = model(input_ids=input_ids, attention_mask=attention_mask, return_probs=True)
step_rewards = derive_step_rewards(rewards, reward_flags)
print("step_rewards:", step_rewards[0])
print("step_rewards:", step_rewards[1])
```

## vLLM Server for Inference

1. Install vLLM and the vLLM PRM plugin:
```shell
pip install vllm==v0.6.4.post1
git clone https://github.com/SkyworkAI/skywork-o1-prm-inference.git
cd skywork-o1-prm-inference
pip install -e .
```

2. Start the vLLM server:
```shell
CUDA_VISIBLE_DEVICES=0,1,2,3 vllm serve /path/to/prm_model \
    --host 0.0.0.0 \
    --port 8081 \
    --tensor-parallel-size 4 \
    --gpu-memory-utilization 0.9 \
    --enable-prefix-caching \
    --dtype auto
```

3. Request the server for inference:

```python
from openai import OpenAI
from transformers import AutoTokenizer
from model_utils.io_utils import prepare_input, derive_step_rewards_vllm

prm_model_path = "/path/to/prm_model"
tokenizer = AutoTokenizer.from_pretrained(prm_model_path, trust_remote_code=True)
datas = [
    {
        "problem" : "Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
        "response" : "To determine how much money Janet makes every day at the farmers' market, we need to follow these steps:\n1. **Calculate the total number of eggs laid by the ducks per day.**\n Janet's ducks lay 16 eggs per day.\n2. **Determine the number of eggs Janet uses each day.**\n - She eats 3 eggs for breakfast every morning.\n - She bakes muffins for her friends every day with 4 eggs.\n So, the total number of eggs used per day is:\n \\[\n 3 + 4 = 7 \\text{ eggs}\n \\]\n3. **Calculate the number of eggs Janet sells at the farmers' market each day.**\n Subtract the number of eggs used from the total number of eggs laid:\n \\[\n 16 - 7 = 9 \\text{ eggs}\n \\]\n4. **Determine how much money Janet makes from selling the eggs.**\n She sells each egg for $2, so the total amount of money she makes is:\n \\[\n 9 \\times 2 = 18 \\text{ dollars}\n \\]\nTherefore, the amount of money Janet makes every day at the farmers' market is \\(\\boxed{18}\\)."
    },
    {
        "problem" : "Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
        "response" : "To find out how much Janet makes every day at the farmers' market, we need to follow these steps:\n1. Calculate the total number of eggs laid by the ducks: \n - Janet's ducks lay 16 eggs per day.\n2. Calculate the total number of eggs consumed by Janet and used for baking:\n - Janet eats 3 eggs for breakfast.\n - She bakes 4 eggs for muffins.\n - Total eggs used: 3 (breakfast) + 4 (baking) = 7 eggs.\n3. Calculate the remaining number of eggs for sale:\n - Total eggs laid: 16\n - Eggs used: 7\n - Remaining eggs: 16 - 7 = 9 eggs\n4. Calculate the total amount of money made at the farmers' market:\n - Price per egg: $2\n - Number of eggs sold: 9\n - Total money made: 9 * $2 = $18\nTherefore, Janet makes $\\boxed{18}$ dollars every day at the farmers' market."
    }
]

# data preprocessing
processed_data = [prepare_input(d["problem"], d["response"], tokenizer=tokenizer, step_token="\n") for d in datas]
input_ids, steps, reward_flags = zip(*processed_data)

openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8081/v1"
client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)
models = client.models.list()
model = models.data[0].id
rewards = client.embeddings.create(
    input=input_ids,
    model=model,
)

step_rewards = derive_step_rewards_vllm(rewards, reward_flags)
print("step_rewards:", step_rewards[0])
print("step_rewards:", step_rewards[1])
```

# TODO

- Add more results for step-wise RM evaluation methods.
- Adjust the RM architecture to enhance compatibility with vLLM/sglang inference.
- Expand RM use cases by incorporating more types of reasoning tasks.
- Mitigate performance conflicts across different reasoning tasks.