---
license: apache-2.0
datasets:
- openai/summarize_from_feedback
- openai/webgpt_comparisons
- berkeley-nest/Nectar
- Dahoas/instruct-synthetic-prompt-responses
- Anthropic/hh-rlhf
- lmsys/chatbot_arena_conversations
- openbmb/UltraFeedback
- argilla/ultrafeedback-binarized-preferences-cleaned
metrics:
- accuracy
tags:
- reward_model
- reward-model
- RLHF
- evaluation
- llm
- instruction
- reranking
language:
- en
---
# Better Implementation of [*PairRM*](https://huggingface.co/llm-blender/PairRM)
## Introduction
This version of PairRM includes several fixes to the training process, which improve the model's performance by **15%**.
### Minor Fixes
- Longer Context Length (2048 -> 3370)
Thanks to DeBERTa's tokenizer, the original PairRM already had enough context length for most inputs. But the longer, the better :>
The new limit is a 2030-token source budget plus two 670-token candidate budgets: 2030 + 2 × 670 = 3370 (see the statistics below).
---
### Major Fixes
- Change Prompt Format
The original format wrapped each response in an artificial tag:
```
<Response i + 1> {response}
```
There is little reason to use a format like this, so I changed it to one based on Vicuna 1.1 (sketched below).
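For reference, the Vicuna 1.1-style layout, as rendered by the Jinja template in the example code further down, looks roughly like this (the placeholder text is illustrative):
```
SYSTEM MESSAGE: {system prompt}

USER: {user message}

ASSISTANT: {assistant response}
```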
---
- Change Truncate Side
The original pipeline truncated from the right side even on the input. This can cause serious problems when the input exceeds the model's context length, because the most recent turns, the ones the candidates actually respond to, are cut off first. A sketch of the difference follows.
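A minimal sketch of left- vs. right-side truncation with the `microsoft/deberta-v3-large` tokenizer (the sample text and length are illustrative):
```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")
text = "old turn ... newer turn ... newest turn, which the candidates answer"

def truncate(text, side, max_length=8):
    tokenizer.truncation_side = side
    ids = tokenizer.encode(text, max_length=max_length, truncation=True, add_special_tokens=False)
    return tokenizer.decode(ids)

print(truncate(text, "right"))  # keeps the start, loses the newest turn
print(truncate(text, "left"))   # keeps the end, which the candidates depend on
```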
---
- Dataset Filter
There was a decent number of empty assistant responses in the original dataset, so I dropped them. A sketch of such a filter is shown below.
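A minimal sketch of that kind of filter, assuming conversations are stored as lists of `{"role", "content"}` dicts (the toy data and field names are illustrative, not the exact preprocessing code):
```python
# Toy examples; the real data comes from the preference datasets listed above.
dataset = [
    [{"role": "user", "content": "hello!"}, {"role": "assistant", "content": "hi!"}],
    [{"role": "user", "content": "hello!"}, {"role": "assistant", "content": ""}],
]

def has_empty_assistant_turn(conversation):
    """True if any assistant turn is empty or whitespace-only."""
    return any(
        msg["role"] == "assistant" and not msg["content"].strip()
        for msg in conversation
    )

dataset = [conv for conv in dataset if not has_empty_assistant_turn(conv)]
print(len(dataset))  # 1, the example with the empty response was dropped
```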
---
## Example Code
**The code below is modified from the [PairRM-hf repo](https://huggingface.co/llm-blender/PairRM-hf).**
```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from typing import List

from llm_blender.pair_ranker.pairrm import DebertaV2PairRM
from transformers import AutoTokenizer

pairrm = DebertaV2PairRM.from_pretrained("maywell/Better-PairRM", device_map="cuda:0").eval()
tokenizer = AutoTokenizer.from_pretrained("maywell/Better-PairRM")

# Special prefixes marking the prompt and the two candidate responses.
source_prefix = "<|source|>"
cand1_prefix = "<|candidate1|>"
cand2_prefix = "<|candidate2|>"

inputs = ["hello!", "I love you!"]
candidates_A = ["hi!", "I hate you!"]
candidates_B = ["f**k off!", "I love you, too!"]

def tokenize_pair(sources: List[str], candidate1s: List[str], candidate2s: List[str],
                  source_max_length=2030, candidate_max_length=670):
    ids = []
    assert len(sources) == len(candidate1s) == len(candidate2s)
    max_length = source_max_length + 2 * candidate_max_length  # 2030 + 2 * 670 = 3370
    for i in range(len(sources)):
        source_ids = tokenizer.encode(source_prefix + sources[i], max_length=source_max_length, truncation=True)
        # Split the budget left over after the source evenly between the two candidates.
        candidate_max_length = (max_length - len(source_ids)) // 2
        candidate1_ids = tokenizer.encode(cand1_prefix + candidate1s[i], max_length=candidate_max_length, truncation=True)
        candidate2_ids = tokenizer.encode(cand2_prefix + candidate2s[i], max_length=candidate_max_length, truncation=True)
        ids.append(source_ids + candidate1_ids + candidate2_ids)
    encodings = tokenizer.pad({"input_ids": ids}, return_tensors="pt", padding="max_length", max_length=max_length)
    return encodings

encodings = tokenize_pair(inputs, candidates_A, candidates_B)
encodings = {k: v.to(pairrm.device) for k, v in encodings.items()}
outputs = pairrm(**encodings)
logits = outputs.logits.tolist()
# A positive logit means candidate A is preferred over candidate B.
comparison_results = outputs.logits > 0
print(logits)
print(comparison_results)
```
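For these toy pairs you would expect a positive logit for the first comparison and a negative one for the second, i.e. `comparison_results` roughly `[True, False]`: the friendly "hi!" should beat the insult, while "I love you, too!" should beat "I hate you!".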
You can also easily compare two multi-turn conversations like the following:
```python
import jinja2
from typing import List

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-large")

def truncate_texts(text, max_length, truncate_side):
    # Truncate from the requested side ("left" keeps the end of the text).
    tokenizer.truncation_side = truncate_side
    tokens = tokenizer.encode(text, add_special_tokens=False, max_length=max_length, truncation=True)
    truncated_text = tokenizer.decode(tokens, skip_special_tokens=True)
    return truncated_text

MY_JINJA_TEMPLATE = """{% for message in messages -%}
{% if message['role'] == 'user' -%}
USER: {{ message['content']|trim -}}
{% if not loop.last -%}


{% endif %}
{% elif message['role'] == 'assistant' -%}
ASSISTANT: {{ message['content']|trim -}}
{% if not loop.last -%}


{% endif %}
{% elif message['role'] == 'user_context' -%}
USER: {{ message['content']|trim -}}
{% if not loop.last -%}


{% endif %}
{% elif message['role'] == 'system' -%}
SYSTEM MESSAGE: {{ message['content']|trim -}}
{% if not loop.last -%}


{% endif %}
{% endif %}
{% endfor -%}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' -%}
ASSISTANT: {% endif -%}"""

my_jinja2_env = jinja2.Environment()
my_jinja2_template = my_jinja2_env.from_string(MY_JINJA_TEMPLATE)

def tokenize_conv_pair(convAs: List[List[dict]], convBs: List[List[dict]]):
    # Check that the two conversation lists line up turn for turn.
    assert len(convAs) == len(convBs), "Number of conversations must be the same"
    for c_a, c_b in zip(convAs, convBs):
        assert len(c_a) == len(c_b), "Number of turns in each conversation must be the same"
        assert all([c_a[i]['content'] == c_b[i]['content'] for i in range(0, len(c_a), 2)]), "USER turns must be the same"

    # Everything except the final assistant turn is the source; truncate from
    # the left so the most recent turns survive.
    inputs = [
        truncate_texts(my_jinja2_template.render(messages=x[:-1], add_generation_prompt=True), 2030, "left") for x in convAs
    ]
    # The final assistant turns are the two candidates; truncate from the right.
    cand1_texts = [
        truncate_texts(x[-1]['content'], 670, "right") for x in convAs
    ]
    cand2_texts = [
        truncate_texts(x[-1]['content'], 670, "right") for x in convBs
    ]
    # tokenize_pair comes from the first snippet above.
    encodings = tokenize_pair(inputs, cand1_texts, cand2_texts)
    return encodings
```
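A minimal usage sketch, reusing `pairrm` and `tokenize_pair` from the first snippet (the conversations are illustrative):
```python
convA = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
]
convB = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "I don't know."},
]

encodings = tokenize_conv_pair([convA], [convB])
encodings = {k: v.to(pairrm.device) for k, v in encodings.items()}
outputs = pairrm(**encodings)
print(outputs.logits > 0)  # True means conversation A's final answer is preferred
```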
## Statistics
### Context length
| PairRanker type | Source max length | Candidate max length | Total max length |
|:-----------------:|:-----------------:|----------------------|------------------|
| [pair-ranker](https://huggingface.co/llm-blender/pair-ranker) | 128 | 128 | 384 |
| [PairRM](https://huggingface.co/llm-blender/pair-reward-model/) | 1224 | 412 | 2048 |
| [Better-PairRM](https://huggingface.co/maywell/Better-PairRM/) (This model) | 2030 | 670 | 3370 |
### Performance
#### Reward-Bench by AllenAI
| Metric                     | llm-blender/PairRM-hf  | maywell/Better-PairRM  |
|----------------------------|------------------------|------------------------|
| model_type                 | Custom Classifier      | Custom Classifier      |
| alpacaeval-length | 0.758 | **0.863** |
| alpacaeval-hard | 0.979 | **1.000** |
| alpacaeval-easy | 0.970 | **0.990** |
| donotanswer | 0.360 | **0.522** |
| hep-cpp | 0.628 | **0.646** |
| hep-go | 0.689 | **0.713** |
| hep-java | 0.628 | **0.713** |
| hep-js | 0.604 | **0.707** |
| hep-python | 0.646 | **0.713** |
| hep-rust | 0.652 | **0.726** |
| llmbar-adver-GPTInst | **0.304** | 0.141 |
| llmbar-adver-GPTOut | **0.596** | 0.447 |
| llmbar-adver-manual | **0.500** | 0.261 |
| llmbar-adver-neighbor | **0.433** | 0.276 |
| llmbar-natural | **0.800** | 0.720 |
| math-prm | **0.333** | 0.295 |
| mt-bench-hard | 0.649 | **0.703** |
| mt-bench-med | 0.900 | **1.000** |
| mt-bench-easy | **0.964** | 0.929 |
| refusals-dangerous | 0.080 | **0.730** |
| refusals-offensive | 0.010 | **0.940** |
| xstest-should-refuse | 0.370 | **0.968** |
| xstest-should-respond | **0.952** | 0.876 |
| average | 0.600 | **0.690** |
> *Note - the llmbar test scores are a bit weird across all models on [Reward-Bench](https://huggingface.co/spaces/allenai/reward-bench).*
## Thanks to
- [Sionic AI](https://sionic.ai/) for providing the A100 cluster.
## Contact
- [Discord Server Link](https://discord.gg/MrBt3PXdXc)
## Original Paper
```
@inproceedings{llm-blender-2023,
title = "LLM-Blender: Ensembling Large Language Models with Pairwise Comparison and Generative Fusion",
author = "Jiang, Dongfu and Ren, Xiang and Lin, Bill Yuchen",
booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (ACL 2023)",
year = "2023"
}
``` |