mrzjy committed
Commit 16e2463 · verified · 1 Parent(s): 397573a

Update README.md

Files changed (1)
  1. README.md +92 -0
README.md CHANGED
@@ -310,6 +310,98 @@ Let's move on nonetheless to see how it actually performs with LLM sampling.
 
Without delving into further reinforcement learning, can we directly apply PRM with our LLMs? The answer is YES!

+ Here we experiment with a simple variant of MCTS-style sampling, with the help of the `continue_final_message` feature. The code snippet is as follows:
+
+ ```python
+ import asyncio
+ import random
+ import time
+
+ # a_request_vllm_chat (async vLLM chat client), model_name and pipe (the PRM
+ # token-classification pipeline) are assumed to be defined elsewhere
+
+
+ def direct_proba(x):
+     """Normalize raw rewards into a probability distribution."""
+     s = sum(x)
+     return [e / s for e in x]
+
+
+ async def _guided_generation(sample):
+     start_time = time.time()
+     outlines = []
+     mcts = [{"children": [], "scores": [0], "chosen": 0}]
+     prompt, steps = sample["messages"][0]["content"], []
+
+     async def request_single_response(messages):
+         response = await a_request_vllm_chat(
+             messages, model_name, temperature=0.7,
+             stop="\n",
+             logit_bias={9: -1e4, 353: -1e4, 334: -1e4, 3070: -1e4},  # prevent some unexpected tokens
+             max_tokens=200,
+             extra_body={
+                 "continue_final_message": True,
+                 "add_generation_prompt": False,
+                 "min_tokens": 5
+             },
+         )
+         return response
+
+     for i in range(sample["n_chapter"]):
+         history = "\n".join(outlines)
+         if i > 0:
+             history += "\n"
+         if sample["lang"] == "zh":
+             assistant_prefix = f"第{i+1}章:"
+         else:
+             assistant_prefix = f"Chapter {i+1}:"
+         messages = [
+             {"role": "system", "content": sample["messages"][0]["content"]},
+             {"role": "assistant", "content": history + assistant_prefix}
+         ]
+
+         # Perform 4 parallel requests (sampling size = 4)
+         responses = await asyncio.gather(*[request_single_response(messages) for _ in range(4)])
+         responses_content = [response["content"] for response in responses]
+
+         # sampling based on rewards
+         batch_steps = [outlines + [assistant_prefix + res] for res in responses_content]
+         batch_prompt = [prompt] * len(batch_steps)
+         raw_scores = evaluate_reward(batch_prompt, batch_steps)
+         scores = direct_proba(raw_scores)
+         chosen = random.choices(
+             population=[assistant_prefix + res for res in responses_content],
+             weights=scores,
+             k=1
+         )[0]
+         mcts.append(
+             {
+                 "children": [assistant_prefix + res for res in responses_content],
+                 "scores": scores,
+                 "raw_scores": raw_scores,
+                 "chosen": chosen
+             }
+         )
+         current_outline = chosen
+         outlines.append(current_outline)
+
+     return outlines, mcts, time.time() - start_time
+
+
+ def evaluate_reward(batch_prompt, batch_steps, separator="\n"):
+     """pipe: assumes the PRM pipeline has already been loaded from the model checkpoint."""
+     # Add a separator between the prompt and each step
+     assert len(batch_prompt) == len(batch_steps)
+     batch_text = [separator.join((prompt, *steps)) + separator for prompt, steps in zip(batch_prompt, batch_steps)]
+     preds = [res[-1] for res in pipe(batch_text)]
+     scores = []
+     for pred in preds:
+         score, pred_entity = pred["score"], pred["entity"]
+         # the pipeline returns the probability of the predicted class,
+         # so convert LABEL_0 predictions to the probability of the positive class
+         if pred_entity == "LABEL_0":
+             score = 1 - score
+         scores.append(score)
+     return scores
+ ```
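+
+ For illustration, a minimal driver sketch (assuming a `sample` dict with the `messages`, `n_chapter` and `lang` fields used above, and that the vLLM client and PRM pipeline are already set up) might look like:
+
+ ```python
+ # Hypothetical driver: run guided generation for a single sample and inspect the tree
+ sample = {
+     "messages": [{"role": "user", "content": "Write a 5-chapter fantasy story outline about a lost lighthouse keeper."}],
+     "n_chapter": 5,
+     "lang": "en",
+ }
+
+ outlines, mcts, elapsed = asyncio.run(_guided_generation(sample))
+ print(f"Generated {len(outlines)} chapters in {elapsed:.1f}s")
+ for node in mcts[1:]:
+     print(node["chosen"], "| raw rewards:", [round(s, 3) for s in node["raw_scores"]])
+ ```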
+
+ - Case
+
+ |Prompt|Outline Generation with Sequential Rejection Sampling|
+ |--|--|
+ |![sequential_rejection_sampling_zh_prompt.png](image%2Fsequential_rejection_sampling_zh_prompt.png)|![sequential_rejection_sampling_zh.png](image%2Fsequential_rejection_sampling_zh.png)|
+
+
- Test-Time Scaling

Since this experiment does not aim to achieve O1-like reasoning behavior, the test-time compute here can be defined simply as a function of `rejection_sampling_size`. Increasing the sampling size during inference leads to higher computational cost, but as expected, it also improves performance according to our PRM.
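
As a rough illustration of this scaling knob, the loop above could be parameterized by its sampling size and the PRM reward of each chosen step compared across sizes. A sketch follows, assuming `_guided_generation` is refactored to take a hypothetical `rejection_sampling_size` argument in place of the hard-coded 4 parallel requests:

```python
# Hypothetical scaling sweep: more candidates per step means more compute,
# but also (per the PRM) better chosen steps on average.
async def sweep_sampling_sizes(sample, sizes=(1, 2, 4, 8)):
    results = {}
    for size in sizes:
        # assumes _guided_generation(sample, rejection_sampling_size=size) replaces
        # the hard-coded "4 parallel requests" in the snippet above
        outlines, mcts, elapsed = await _guided_generation(sample, rejection_sampling_size=size)
        chosen_rewards = [
            node["raw_scores"][node["children"].index(node["chosen"])]
            for node in mcts[1:]
        ]
        results[size] = {
            "seconds": elapsed,
            "mean_chosen_reward": sum(chosen_rewards) / len(chosen_rewards),
        }
    return results
```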