---
license: apache-2.0
datasets:
- openai/summarize_from_feedback
- openai/webgpt_comparisons
- berkeley-nest/Nectar
- Dahoas/instruct-synthetic-prompt-responses
- Anthropic/hh-rlhf
- lmsys/chatbot_arena_conversations
- openbmb/UltraFeedback
- argilla/ultrafeedback-binarized-preferences-cleaned
metrics:
- accuracy
tags:
- reward_model
- reward-model
- RLHF
- evaluation
- llm
- instruction
- reranking
language:
- en
---
# Better Implementation of [*PairRM*](https://huggingface.co/llm-blender/PairRM)

## Introduction

This version of PairRM includes several fixes to the training process, which improve the model's Reward-Bench average score by **15%** (0.600 -> 0.690, relative).

### Minor Fixes

- Longer Context Length (2048 -> 3370 tokens)

Thanks to DeBERTa's tokenizer, the original PairRM model already had a sufficient context length.

But the longer, the better :>

---

### Major Fixes

- Change Prompt Format

The original prompt format uses response placeholders like:
```
<Response i + 1> {response}
```

Since there is no clear reason to keep them, I switched to a format based on Vicuna 1.1 (`USER:` / `ASSISTANT:` turns).
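As a rough illustration, a Vicuna-1.1-style layout can be sketched in plain Python. This is a minimal sketch, not the training code: the helper name `render_vicuna_style` and the role-to-prefix mapping are assumptions modeled on the Jinja template in the conversation example further down, and the output is not guaranteed to match that template byte for byte.

```python
from typing import Dict, List

# Hypothetical role -> prefix mapping, modeled on the Jinja template in the example code
ROLE_PREFIX = {
    "system": "SYSTEM MESSAGE",
    "user": "USER",
    "user_context": "USER",
    "assistant": "ASSISTANT",
}


def render_vicuna_style(messages: List[Dict[str, str]], add_generation_prompt: bool = True) -> str:
    """Render chat messages as Vicuna-1.1-style turns separated by blank lines."""
    turns = [f"{ROLE_PREFIX[m['role']]}: {m['content'].strip()}" for m in messages]
    text = "\n\n".join(turns)
    # Open an assistant turn so the candidate response can follow it
    if add_generation_prompt and messages and messages[-1]["role"] != "assistant":
        text += "\n\nASSISTANT: "
    return text


print(render_vicuna_style([{"role": "user", "content": " hello! "}]))
```

Plain role prefixes keep the prompt close to what chat models commonly see, instead of numbered placeholders.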

---

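- Change Truncation Side

The original process used right-side truncation even on the input (source). This can cause serious problems when the input exceeds the model's context length, because right-side truncation cuts off the end of the prompt, which is usually the latest user turn.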
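The difference between the two sides can be illustrated on a plain list of token IDs. This is a minimal sketch: `truncate_ids` is a hypothetical helper, not part of `transformers`. With a Hugging Face tokenizer the same effect comes from setting `tokenizer.truncation_side = "left"`, as the conversation example below does.

```python
from typing import List


def truncate_ids(ids: List[int], max_length: int, side: str = "left") -> List[int]:
    """Keep at most `max_length` tokens, dropping the excess from `side`."""
    if max_length <= 0:
        return []
    if len(ids) <= max_length:
        return list(ids)
    if side == "left":
        # Drop the oldest tokens; the end of the prompt (latest turn) survives
        return list(ids[-max_length:])
    if side == "right":
        # Drop the newest tokens; the end of the prompt is lost
        return list(ids[:max_length])
    raise ValueError(f"side must be 'left' or 'right', got {side!r}")


prompt_ids = [101, 102, 103, 104, 105]  # pretend the last IDs are the latest user turn
print(truncate_ids(prompt_ids, 3, "left"))   # [103, 104, 105]
print(truncate_ids(prompt_ids, 3, "right"))  # [101, 102, 103]
```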

---

- Dataset Filter

The original dataset contained a considerable number of empty assistant responses, so I dropped them.
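A filter of this kind can be sketched as below. The record layout (a `prompt` plus `chosen` / `rejected` response fields) is an assumption for illustration only; the source datasets listed in the metadata use different schemas, so the field names would need adapting.

```python
from typing import Dict, List, Optional, Sequence


def drop_empty_responses(
    records: List[Dict[str, Optional[str]]],
    response_keys: Sequence[str] = ("chosen", "rejected"),  # hypothetical field names
) -> List[Dict[str, Optional[str]]]:
    """Keep only records whose response fields are all non-empty after stripping whitespace."""
    return [r for r in records if all((r.get(k) or "").strip() for k in response_keys)]


records = [
    {"prompt": "hi", "chosen": "hello", "rejected": ""},        # empty -> dropped
    {"prompt": "hi", "chosen": "hello", "rejected": "   "},     # whitespace only -> dropped
    {"prompt": "hey", "chosen": "hi there", "rejected": "go away"},
]
print(drop_empty_responses(records))
```

Whitespace-only responses are treated as empty here, since they carry no content for the reward model to compare.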

---

## Example Code

**The code below is modified from the [PairRM-hf repo](https://huggingface.co/llm-blender/PairRM-hf).**

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from llm_blender.pair_ranker.pairrm import DebertaV2PairRM
from transformers import AutoTokenizer
from typing import List
pairrm = DebertaV2PairRM.from_pretrained("maywell/Better-PairRM", device_map="cuda:0").eval()
tokenizer = AutoTokenizer.from_pretrained("maywell/Better-PairRM")
# Prefixes that mark the source and each candidate in the input sequence
source_prefix = "<|source|>"
cand1_prefix = "<|candidate1|>"
cand2_prefix = "<|candidate2|>"

inputs = ["hello!", "I love you!"]
candidates_A = ["hi!", "I hate you!"]
candidates_B = ["f**k off!", "I love you, too!"]

def tokenize_pair(sources: List[str], candidate1s: List[str], candidate2s: List[str], source_max_length=2030, candidate_max_length=670):
    ids = []
    assert len(sources) == len(candidate1s) == len(candidate2s)
    # Total budget: 2030 + 2 * 670 = 3370 tokens with the defaults
    max_length = source_max_length + 2 * candidate_max_length
    for i in range(len(sources)):
        source_ids = tokenizer.encode(source_prefix + sources[i], max_length=source_max_length, truncation=True)
        # Split the budget the source did not use evenly between the two candidates
        cand_max_length = (max_length - len(source_ids)) // 2
        candidate1_ids = tokenizer.encode(cand1_prefix + candidate1s[i], max_length=cand_max_length, truncation=True)
        candidate2_ids = tokenizer.encode(cand2_prefix + candidate2s[i], max_length=cand_max_length, truncation=True)
        ids.append(source_ids + candidate1_ids + candidate2_ids)
    # Pad every example to the full `max_length`
    encodings = tokenizer.pad({"input_ids": ids}, return_tensors="pt", padding="max_length", max_length=max_length)
    return encodings

encodings = tokenize_pair(inputs, candidates_A, candidates_B)
encodings = {k: v.to(pairrm.device) for k, v in encodings.items()}
outputs = pairrm(**encodings)
logits = outputs.logits.tolist()
# A positive logit means candidate A is preferred over candidate B
comparison_results = outputs.logits > 0
print(logits)
print(comparison_results)
```
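To turn the logits into a readable verdict, a small helper like the hypothetical `pick_winners` below can pair each logit with the preferred candidate. It assumes `logits` is the flat list produced by `outputs.logits.tolist()` and, as the `logits > 0` check implies, that a positive logit favours candidate A. A logit of exactly 0 is counted as a win for B here, which is an arbitrary choice.

```python
from typing import List, Tuple


def pick_winners(
    logits: List[float], candidates_a: List[str], candidates_b: List[str]
) -> List[Tuple[str, float]]:
    """Return (preferred candidate, |logit| as a confidence margin) for each pair."""
    assert len(logits) == len(candidates_a) == len(candidates_b)
    winners = []
    for logit, a, b in zip(logits, candidates_a, candidates_b):
        # Positive logit -> candidate A preferred; zero or negative -> candidate B
        winners.append((a if logit > 0 else b, abs(logit)))
    return winners


# Illustrative logits only, not real model outputs
print(pick_winners([1.2, -0.7], ["hi!", "I hate you!"], ["f**k off!", "I love you, too!"]))
```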

You can also compare two conversations as follows:
```python
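import jinja2
from typing import List
from transformers import AutoTokenizer

# Use the same tokenizer as `tokenize_pair` (defined in the previous example),
# so that truncation and encoding agree with each other.
tokenizer = AutoTokenizer.from_pretrained("maywell/Better-PairRM")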

def truncate_texts(text, max_length, truncate_side):
    # Truncate `text` to `max_length` tokens, dropping tokens from `truncate_side` ("left" or "right")
    tokenizer.truncation_side = truncate_side
    tokens = tokenizer.encode(text, add_special_tokens=False, max_length=max_length, truncation=True)
    truncated_text = tokenizer.decode(tokens, skip_special_tokens=True)
    return truncated_text

MY_JINJA_TEMPLATE = """{% for message in messages -%}
{% if message['role'] == 'user' -%}
USER: {{ message['content']|trim -}}
{% if not loop.last -%}


{% endif %}
{% elif message['role'] == 'assistant' -%}
ASSISTANT: {{ message['content']|trim -}}
{% if not loop.last -%}


{% endif %}
{% elif message['role'] == 'user_context' -%}
USER: {{ message['content']|trim -}}
{% if not loop.last -%}


{% endif %}
{% elif message['role'] == 'system' -%}
SYSTEM MESSAGE: {{ message['content']|trim -}}
{% if not loop.last -%}


{% endif %}
{% endif %}
{% endfor -%}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' -%}
ASSISTANT: {% endif -%}"""

my_jinja2_env = jinja2.Environment()
my_jinja2_template = my_jinja2_env.from_string(MY_JINJA_TEMPLATE)

def tokenize_conv_pair(convAs: List[List[dict]], convBs: List[List[dict]]):

    # check conversations correctness
    assert len(convAs) == len(convBs), "Number of conversations must be the same"
    for c_a, c_b in zip(convAs, convBs):
        assert len(c_a) == len(c_b), "Number of turns in each conversation must be the same"
        assert all([c_a[i]['content'] == c_b[i]['content'] for i in range(0, len(c_a), 2)]), "USER turns must be the same"
    
    inputs = [
        truncate_texts(my_jinja2_template.render(messages=x[:-1], add_generation_prompt=True), 2030, "left") for x in convAs
    ]
    cand1_texts = [
        truncate_texts(x[-1]['content'], 670, "right") for x in convAs
    ]
    cand2_texts = [
        truncate_texts(x[-1]['content'], 670, "right") for x in convBs
    ]
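    encodings = tokenize_pair(inputs, cand1_texts, cand2_texts)
    return encodings

# Example: compare two replies to the same user turn
# (assumes `pairrm` and `tokenize_pair` from the previous example are already defined)
convA = [
    {"role": "user", "content": "hello!"},
    {"role": "assistant", "content": "hi!"},
]
convB = [
    {"role": "user", "content": "hello!"},
    {"role": "assistant", "content": "f**k off!"},
]
encodings = tokenize_conv_pair([convA], [convB])
encodings = {k: v.to(pairrm.device) for k, v in encodings.items()}
print(pairrm(**encodings).logits > 0)  # True means convA's reply is preferred
```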

## Statistics

### Context length
|  PairRanker type  | Source max length (tokens) | Candidate max length (tokens) | Total max length (tokens) |
|:-----------------:|:-----------------:|----------------------|------------------|
| [pair-ranker](https://huggingface.co/llm-blender/pair-ranker)             | 128               | 128                  | 384              |
| [PairRM](https://huggingface.co/llm-blender/pair-reward-model/) | 1224              | 412                  | 2048             |
| [Better-PairRM](https://huggingface.co/maywell/Better-PairRM/) (This model) | 2030              | 670                  | 3370             |
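The totals in the table follow from one source segment plus two candidate segments. The sketch below checks that arithmetic and mirrors how `tokenize_pair` in the example code hands any unused source budget to the candidates; `total_max_length` and `candidate_budget` are hypothetical helper names.

```python
# (source max length, candidate max length) per model, from the table above
CONFIGS = {
    "pair-ranker": (128, 128),
    "PairRM": (1224, 412),
    "Better-PairRM": (2030, 670),
}


def total_max_length(source_max_length: int, candidate_max_length: int) -> int:
    # One source segment plus two candidate segments
    return source_max_length + 2 * candidate_max_length


def candidate_budget(source_len: int, source_max_length: int = 2030, candidate_max_length: int = 670) -> int:
    # Mirrors tokenize_pair: whatever the source does not use is split evenly between the candidates
    return (total_max_length(source_max_length, candidate_max_length) - source_len) // 2


for name, (src, cand) in CONFIGS.items():
    print(f"{name}: {src} + 2 * {cand} = {total_max_length(src, cand)}")

print(candidate_budget(10))    # short source: 1680 tokens per candidate
print(candidate_budget(2030))  # full-length source: 670 tokens per candidate
```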

### Performance

#### Reward-Bench by AllenAI

| Metric                     | llm-blender/PairRM-hf  | maywell/Better-PairRM  |
|----------------------------|------------------------|------------------------|
| model_type                 | Custom Classifier      | Custom Classifier      |
| alpacaeval-length          | 0.758                  | **0.863**                |
| alpacaeval-hard            | 0.979                  | **1.000**                  |
| alpacaeval-easy            | 0.970                  | **0.990**                  |
| donotanswer                | 0.360                  | **0.522**                  |
| hep-cpp                    | 0.628                  | **0.646**                  |
| hep-go                     | 0.689                  | **0.713**                  |
| hep-java                   | 0.628                  | **0.713**                  |
| hep-js                     | 0.604                  | **0.707**                  |
| hep-python                 | 0.646                  | **0.713**                  |
| hep-rust                   | 0.652                  | **0.726**                  |
| llmbar-adver-GPTInst       | **0.304**                  | 0.141                  |
| llmbar-adver-GPTOut        | **0.596**                  | 0.447                  |
| llmbar-adver-manual        | **0.500**                  | 0.261                  |
| llmbar-adver-neighbor      | **0.433**                  | 0.276                  |
| llmbar-natural             | **0.800**                  | 0.720                  |
| math-prm                   | **0.333**                  | 0.295                  |
| mt-bench-hard              | 0.649                  | **0.703**                  |
| mt-bench-med               | 0.900                  | **1.000**                  |
| mt-bench-easy              | **0.964**                  | 0.929                  |
| refusals-dangerous         | 0.080                  | **0.730**                  |
| refusals-offensive         | 0.010                  | **0.940**                  |
| xstest-should-refuse       | 0.370                  | **0.968**                  |
| xstest-should-respond      | **0.952**                  | 0.876                  |
| average                    | 0.600                  | **0.690**                  |

> *Note: LLMBar scores look anomalous for all models on [Reward-Bench](https://huggingface.co/spaces/allenai/reward-bench).*
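The **15%** figure in the introduction appears to correspond to the relative change in the Reward-Bench average shown in the table above. A quick check of that arithmetic:

```python
# Reward-Bench averages from the table above
pairrm_avg = 0.600
better_pairrm_avg = 0.690

absolute_gain = better_pairrm_avg - pairrm_avg
relative_gain = absolute_gain / pairrm_avg

print(f"absolute: {absolute_gain:.3f}, relative: {relative_gain:.1%}")
```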

## Thanks to

- [Sionic AI](https://sionic.ai/) for providing the A100 cluster.

## Contact

- [Discord Server Link](https://discord.gg/MrBt3PXdXc)

## Original Paper 
```
@inproceedings{llm-blender-2023,
    title = "LLM-Blender: Ensembling Large Language Models with Pairwise Comparison and Generative Fusion",
    author = "Jiang, Dongfu and Ren, Xiang and Lin, Bill Yuchen",
    booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (ACL 2023)",
    year = "2023"
}
```