Update README.md
Let's move on nonetheless to see how it actually performs with LLM sampling.
Without delving into further reinforcement learning, can we directly apply PRM with our LLMs? The answer is YES!

Here we experiment with a simplistic variant of MCTS sampling, namely sequential rejection sampling (only one path is fully explored in the end), with the help of the `continue_final_message` feature and a vLLM server. The code snippet is provided as follows:

```python
def direct_proba(x):
    s = sum(x)
    return [e/s for e in x]


async def _guided_generation(sample, sampling_size: int):
    import time
    start_time = time.time()
    outlines = []

    # ...

        )
        return response

    # sequential rejection sampling
    for i in range(sample["n_chapter"]):
        history = "\n".join(outlines)
        if i > 0:
            # ...

        messages = [
            # ...
            {"role": "assistant", "content": history + assistant_prefix}
        ]

        # Perform parallel requests (didn't apply n/best_of parameter to prevent server OOM)
        responses = await asyncio.gather(*[request_single_response(messages) for _ in range(sampling_size)])
        responses_content = [response["content"] for response in responses]

        # sampling based on rewards
        # ...
```