mrzjy committed on
Commit 11676a6 · verified · 1 Parent(s): 16e2463

Update README.md

Update README.md

Files changed (1): README.md +6 -4
README.md CHANGED
@@ -310,14 +310,15 @@ Let's move on nonetheless to see how it actually performs with LLM sampling.
 
 Without delving into further reinforcement learning, can we directly apply PRM with our LLMs? The answer is YES!
 
-Here we experiment with a simplistic variant of MCTS sampling, with the help of the `continue_final_message` feature. The code snippet is provided as follows:
+Here we experiment with a simplistic variant of MCTS sampling, namely sequential rejection sampling (only one path is fully explored in the end), with the help of the `continue_final_message` feature and a vLLM server. The code snippet is as follows:
 
 ```python
 def direct_proba(x):
     s = sum(x)
     return [e/s for e in x]
 
-async def _guided_generation(sample):
+
+async def _guided_generation(sample, sampling_size: int):
     import time
     start_time = time.time()
     outlines = []
@@ -338,6 +339,7 @@ async def _guided_generation(sample):
         )
         return response
 
+    # sequential rejection sampling
    for i in range(sample["n_chapter"]):
        history = "\n".join(outlines)
        if i > 0:
@@ -351,8 +353,8 @@ async def _guided_generation(sample):
            {"role": "assistant", "content": history + assistant_prefix}
        ]
 
-        # Perform 4 parallel requests (sampling size = 4)
-        responses = await asyncio.gather(*[request_single_response(messages) for _ in range(4)])
+        # Perform parallel requests (the n/best_of parameters are not used, to prevent server OOM)
+        responses = await asyncio.gather(*[request_single_response(messages) for _ in range(sampling_size)])
        responses_content = [response["content"] for response in responses]
 
        # sampling based on rewards
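As a sketch of what the `# sampling based on rewards` step could look like: assuming the PRM has already scored each candidate continuation (the scores and outline strings below are made up for illustration), the normalized rewards from `direct_proba` can drive a weighted draw; the other candidates are rejected, which is why only one path is explored to the end.

```python
import random

def direct_proba(x):
    # normalize raw PRM rewards into a probability distribution
    s = sum(x)
    return [e / s for e in x]

# hypothetical candidates and PRM rewards for one chapter step
responses_content = ["Outline A ...", "Outline B ...", "Outline C ...", "Outline D ..."]
rewards = [0.81, 0.42, 0.67, 0.55]

# keep one continuation in proportion to its reward; reject the rest
probs = direct_proba(rewards)
chosen = random.choices(responses_content, weights=probs, k=1)[0]
```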
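For context on the `continue_final_message` trick: vLLM's OpenAI-compatible chat endpoint accepts flags to continue a pre-filled assistant message instead of opening a fresh assistant turn. A hypothetical request body that a helper like `request_single_response` might POST (the model name and prompt are placeholders; the two flags follow vLLM's chat API, not anything defined in this repo) could look like:

```python
# hypothetical chat-completions payload for a vLLM server
payload = {
    "model": "my-model",  # placeholder model name
    "messages": [
        {"role": "user", "content": "Write a chapter-by-chapter outline."},
        # pre-filled assistant prefix that the model should continue
        {"role": "assistant", "content": "Chapter 1:"},
    ],
    # vLLM-specific flags: continue the final assistant message
    # rather than starting a new assistant turn
    "add_generation_prompt": False,
    "continue_final_message": True,
    "temperature": 0.8,
}
```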