---
license: mit
language:
- en
base_model:
- meta-llama/Llama-3.1-8B-Instruct
pipeline_tag: token-classification
---

<div align="center">
<h1>
  MedSSS-8B-PRM
</h1>
</div>

<div align="center">
<a href="https://github.com/pixas/MedSSS" target="_blank">GitHub</a> | <a href="" target="_blank">Paper</a>
</div>

# <span>Introduction</span>
**MedSSS-PRM** is a process reward model (PRM) designed for slow-thinking medical reasoning. It assigns a float score in `[0, 1]` to every intermediate reasoning step generated by **MedSSS-Policy**.

For more information, visit our GitHub repository: 
[https://github.com/pixas/MedSSS](https://github.com/pixas/MedSSS).




# <span>Usage</span>
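The PRM is released as a LoRA adapter, which keeps the extra memory and storage cost of using it small.
Because the adapter is built on top of `meta-llama/Llama-3.1-8B-Instruct`, you first need to download the base model to your platform. The example below requires `torch`, `transformers`, and `peft`.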

```python
import torch
from itertools import chain
from transformers import AutoModelForTokenClassification, AutoTokenizer
from peft import PeftModel


def obtain_prm_value_for_single_pair(tokenizer, value_model, inputs, outputs):
    # `outputs` is the step-wise response generated by MedSSS-Policy
    response = outputs
    # Split on "\n\nStep" and restore the "Step" prefix that split() removes
    completions = [
        completion if completion.startswith("Step") else "Step" + completion
        for completion in outputs.split("\n\nStep")
    ]

    messages = [
        {"role": "user", "content": inputs},
        {"role": "assistant", "content": response},
    ]
    input_text = tokenizer.apply_chat_template(messages, tokenize=False)

    # Locate the response so the chat-template text around it can be tokenized separately
    response_begin_index = input_text.index(response)
    pre_response_input = input_text[:response_begin_index]
    after_response_input = input_text[response_begin_index + len(response):]

    # Tokenize each step (with its trailing "\n\n") on its own so step boundaries map to known positions
    completion_ids = [
        tokenizer(completion + "\n\n", add_special_tokens=False)["input_ids"]
        for completion in completions
    ]
    response_id = list(chain(*completion_ids))
    pre_response_id = tokenizer(pre_response_input, add_special_tokens=False)["input_ids"]
    after_response_id = tokenizer(after_response_input, add_special_tokens=False)["input_ids"]

    input_ids = pre_response_id + response_id + after_response_id

    with torch.no_grad():
        # Token-classification head output: [1, N, num_labels]
        logits = value_model(
            input_ids=torch.tensor(input_ids).unsqueeze(0).to(value_model.device)
        ).logits

    # Position of the last token of each step in `input_ids`
    completion_index = []
    for i, completion in enumerate(completion_ids):
        if i == 0:
            completion_index.append(len(completion) + len(pre_response_id) - 1)
        else:
            completion_index.append(completion_index[-1] + len(completion))

    # Cast to float32 first: bfloat16 tensors cannot be converted with .numpy().
    # If these are raw logits rather than [0, 1] scores, see the GitHub repository for post-processing.
    step_value = logits[0, completion_index].float().cpu().tolist()
    return step_value


base_model = AutoModelForTokenClassification.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", torch_dtype="auto", device_map="auto"
)
# Attach the MedSSS-PRM LoRA adapter to the base model
model = PeftModel.from_pretrained(base_model, "pixas/MedSSS_PRM")
model.eval()
tokenizer = AutoTokenizer.from_pretrained("pixas/MedSSS_PRM")

input_text = "How to stop a cough?"
step_wise_generation = "Step 0: Let's break down this problem step by step.\n\nStep 1: First [omitted]"

value = obtain_prm_value_for_single_pair(tokenizer, model, input_text, step_wise_generation)
print(value)  # one score per step
```

MedSSS-PRM uses `\n\nStep` to separate intermediate steps. Each step is scored at its last token (including its trailing `\n\n`), that is, right before the next `Step k: ` or the end of the response.
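To make this indexing concrete without loading any model, here is a minimal pure-Python sketch of the two bookkeeping steps in `obtain_prm_value_for_single_pair`: splitting a response on `\n\nStep` and locating the last token of each step. The helper names `split_steps`, `step_end_indices`, and `toy_tokenize` are hypothetical, and the whitespace tokenizer is only a stand-in for the real Llama 3.1 tokenizer (it drops the `\n\n`, whereas the real tokenizer keeps it inside the last token of each step).

```python
from typing import List


def split_steps(response: str) -> List[str]:
    """Split a step-wise response on "\\n\\nStep" and restore the "Step"
    prefix that str.split() removes (mirrors the list comprehension above)."""
    parts = response.split("\n\nStep")
    return [p if p.startswith("Step") else "Step" + p for p in parts]


def step_end_indices(prefix_len: int, step_lengths: List[int]) -> List[int]:
    """Position of the last token of each step in the full token sequence.

    prefix_len: number of tokens before the response (chat-template header).
    step_lengths: token count of each step, including its trailing "\\n\\n".
    """
    indices = []
    position = prefix_len - 1
    for n in step_lengths:
        position += n
        indices.append(position)
    return indices


def toy_tokenize(text: str) -> List[str]:
    # Hypothetical whitespace tokenizer for illustration only;
    # MedSSS-PRM uses the Llama 3.1 tokenizer, whose token boundaries differ.
    return text.split()


if __name__ == "__main__":
    response = "Step 0: Let's break down this problem.\n\nStep 1: Stay hydrated."
    steps = split_steps(response)
    step_tokens = [toy_tokenize(s + "\n\n") for s in steps]
    prefix = ["<header>", "user", "<header>"]  # stand-in for chat-template tokens
    sequence = prefix + [t for tokens in step_tokens for t in tokens]
    ends = step_end_indices(len(prefix), [len(t) for t in step_tokens])
    print(ends)                         # [9, 13]
    print([sequence[i] for i in ends])  # ['problem.', 'hydrated.']
```

The running `position` counter is equivalent to the `completion_index` loop in the usage code: the first index is `len(pre_response_id) + len(first_step) - 1`, and each later index adds the length of the next step.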