Spaces:
Runtime error
Runtime error
DongfuJiang
commited on
Commit
·
bf79ee8
1
Parent(s):
fcac78f
update
Browse files- app.py +410 -233
- descriptions.py +49 -0
app.py
CHANGED
@@ -1,10 +1,14 @@
|
|
1 |
import gradio as gr
|
2 |
import sys
|
3 |
import os
|
4 |
-
import
|
|
|
|
|
5 |
from datasets import load_dataset
|
|
|
6 |
from typing import List
|
7 |
|
|
|
8 |
MAX_BASE_LLM_NUM = 20
|
9 |
MIN_BASE_LLM_NUM = 3
|
10 |
SOURCE_MAX_LENGTH = 256
|
@@ -13,10 +17,6 @@ CANDIDATE_MAX_LENGTH = 256
|
|
13 |
DEFAULT_CANDIDATE_MAX_LENGTH = 128
|
14 |
FUSER_MAX_NEW_TOKENS = 512
|
15 |
DEFAULT_FUSER_MAX_NEW_TOKENS = 256
|
16 |
-
DESCRIPTIONS = """# LLM-BLENDER
|
17 |
-
|
18 |
-
LLM-Blender is an innovative ensembling framework to attain consistently superior performance by leveraging the diverse strengths of multiple open-source large language models (LLMs). LLM-Blender cut the weaknesses through ranking and integrate the strengths through fusing generation to enhance the capability of LLMs.
|
19 |
-
"""
|
20 |
EXAMPLES_DATASET = load_dataset("llm-blender/mix-instruct", split='validation', streaming=True)
|
21 |
SHUFFLED_EXAMPLES_DATASET = EXAMPLES_DATASET.shuffle(seed=42, buffer_size=1000)
|
22 |
EXAMPLES = []
|
@@ -28,47 +28,75 @@ for example in SHUFFLED_EXAMPLES_DATASET.take(100):
|
|
28 |
])
|
29 |
CANDIDATE_EXAMPLES[example['instruction']+example['input']] = example['candidates']
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
def update_base_llms_num(k, llm_outputs):
|
59 |
k = int(k)
|
60 |
-
return [gr.Dropdown
|
61 |
value=f"LLM-1" if k >= 1 else "", visible=True),
|
62 |
{f"LLM-{i+1}": llm_outputs.get(f"LLM-{i+1}", "") for i in range(k)}]
|
63 |
|
64 |
|
65 |
def display_llm_output(llm_outputs, selected_base_llm_name):
|
66 |
-
return gr.Textbox
|
67 |
label=selected_base_llm_name + " (Click Save to save current content)",
|
68 |
placeholder=f"Enter {selected_base_llm_name} output here", show_label=True)
|
69 |
|
70 |
def save_llm_output(selected_base_llm_name, selected_base_llm_output, llm_outputs):
|
71 |
-
llm_outputs
|
72 |
return llm_outputs
|
73 |
|
74 |
def get_preprocess_examples(inst, input):
|
@@ -131,217 +159,366 @@ def display_fuser_output(fuser_output):
|
|
131 |
|
132 |
|
133 |
with gr.Blocks(theme='ParityError/Anime') as demo:
|
134 |
-
|
135 |
-
|
136 |
-
with gr.Row():
|
137 |
-
with gr.Column():
|
138 |
-
inst_textbox = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
|
139 |
-
input_textbox = gr.Textbox(lines=4, label="Input Context", placeholder="Enter input context here", show_label=True)
|
140 |
-
with gr.Column():
|
141 |
-
saved_llm_outputs = gr.State(value={})
|
142 |
-
with gr.Group():
|
143 |
-
selected_base_llm_name_dropdown = gr.Dropdown(label="Base LLM",
|
144 |
-
choices=[f"LLM-{i+1}" for i in range(MIN_BASE_LLM_NUM)], value="LLM-1", show_label=True)
|
145 |
-
selected_base_llm_output = gr.Textbox(lines=4, label="LLM-1 (Click Save to save current content)",
|
146 |
-
placeholder="Enter LLM-1 output here", show_label=True)
|
147 |
-
with gr.Row():
|
148 |
-
base_llm_outputs_save_button = gr.Button('Save', variant='primary')
|
149 |
-
|
150 |
-
base_llm_outputs_clear_single_button = gr.Button('Clear Single', variant='primary')
|
151 |
-
|
152 |
-
base_llm_outputs_clear_all_button = gr.Button('Clear All', variant='primary')
|
153 |
-
base_llms_num = gr.Slider(
|
154 |
-
label='Number of base llms',
|
155 |
-
minimum=MIN_BASE_LLM_NUM,
|
156 |
-
maximum=MAX_BASE_LLM_NUM,
|
157 |
-
step=1,
|
158 |
-
value=MIN_BASE_LLM_NUM,
|
159 |
-
)
|
160 |
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
|
|
188 |
)
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
value=DEFAULT_SOURCE_MAX_LENGTH,
|
195 |
)
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
value=DEFAULT_CANDIDATE_MAX_LENGTH,
|
202 |
)
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
step=1,
|
208 |
-
value=DEFAULT_FUSER_MAX_NEW_TOKENS,
|
209 |
)
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
label='Beam size of fuser generation',
|
226 |
-
minimum=1,
|
227 |
-
maximum=10,
|
228 |
-
step=1,
|
229 |
-
value=4,
|
230 |
)
|
231 |
-
|
232 |
-
examples_dummy_textbox = gr.Textbox(lines=1, label="", placeholder="", show_label=False, visible=False)
|
233 |
-
batch_examples = gr.Examples(
|
234 |
-
examples=EXAMPLES,
|
235 |
-
fn=get_preprocess_examples,
|
236 |
-
cache_examples=True,
|
237 |
-
examples_per_page=5,
|
238 |
-
inputs=[inst_textbox, input_textbox],
|
239 |
-
outputs=[inst_textbox, input_textbox, base_llms_num, examples_dummy_textbox],
|
240 |
-
)
|
241 |
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
outputs=[saved_llm_outputs, rank_outputs, fuser_outputs],
|
252 |
-
).then(
|
253 |
-
fn=display_llm_output,
|
254 |
-
inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
|
255 |
-
outputs=selected_base_llm_output,
|
256 |
-
)
|
257 |
-
|
258 |
-
selected_base_llm_name_dropdown.change(
|
259 |
-
fn=display_llm_output,
|
260 |
-
inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
|
261 |
-
outputs=selected_base_llm_output,
|
262 |
-
)
|
263 |
-
|
264 |
-
base_llm_outputs_save_button.click(
|
265 |
-
fn=save_llm_output,
|
266 |
-
inputs=[selected_base_llm_name_dropdown, selected_base_llm_output, saved_llm_outputs],
|
267 |
-
outputs=saved_llm_outputs,
|
268 |
-
)
|
269 |
-
base_llm_outputs_clear_all_button.click(
|
270 |
-
fn=lambda: [{}, ""],
|
271 |
-
inputs=[],
|
272 |
-
outputs=[saved_llm_outputs, selected_base_llm_output],
|
273 |
-
)
|
274 |
-
base_llm_outputs_clear_single_button.click(
|
275 |
-
fn=lambda: "",
|
276 |
-
inputs=[],
|
277 |
-
outputs=selected_base_llm_output,
|
278 |
-
)
|
279 |
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
).success(
|
286 |
-
fn=llms_rank,
|
287 |
-
inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
|
288 |
-
outputs=[saved_rank_outputs, rank_outputs],
|
289 |
-
)
|
290 |
-
|
291 |
-
fuse_button.click(
|
292 |
-
fn=check_fuser_inputs,
|
293 |
-
inputs=[blender_state, blender_config, saved_rank_outputs],
|
294 |
-
outputs=fuser_outputs,
|
295 |
-
).success(
|
296 |
-
fn=llms_fuse,
|
297 |
-
inputs=[blender_state, blender_config, saved_rank_outputs],
|
298 |
-
outputs=[saved_fuse_outputs, fuser_outputs],
|
299 |
-
)
|
300 |
-
|
301 |
-
clear_button.click(
|
302 |
-
fn=lambda: ["", "", {}, []],
|
303 |
-
inputs=[],
|
304 |
-
outputs=[rank_outputs, fuser_outputs, blender_state, saved_rank_outputs],
|
305 |
-
)
|
306 |
-
|
307 |
-
# update blender config
|
308 |
-
source_max_length.change(
|
309 |
-
fn=lambda x, y: y.update({"source_max_length": x}) or y,
|
310 |
-
inputs=[source_max_length, blender_config],
|
311 |
-
outputs=blender_config,
|
312 |
-
)
|
313 |
-
candidate_max_length.change(
|
314 |
-
fn=lambda x, y: y.update({"candidate_max_length": x}) or y,
|
315 |
-
inputs=[candidate_max_length, blender_config],
|
316 |
-
outputs=blender_config,
|
317 |
-
)
|
318 |
-
top_k_for_fuser.change(
|
319 |
-
fn=lambda x, y: y.update({"top_k_for_fuser": x}) or y,
|
320 |
-
inputs=[top_k_for_fuser, blender_config],
|
321 |
-
outputs=blender_config,
|
322 |
-
)
|
323 |
-
max_new_tokens.change(
|
324 |
-
fn=lambda x, y: y.update({"max_new_tokens": x}) or y,
|
325 |
-
inputs=[max_new_tokens, blender_config],
|
326 |
-
outputs=blender_config,
|
327 |
-
)
|
328 |
-
# temperature.change(
|
329 |
-
# fn=lambda x, y: y.update({"temperature": x}) or y,
|
330 |
-
# inputs=[temperature, blender_config],
|
331 |
-
# outputs=blender_config,
|
332 |
-
# )
|
333 |
-
# top_p.change(
|
334 |
-
# fn=lambda x, y: y.update({"top_p": x}) or y,
|
335 |
-
# inputs=[top_p, blender_config],
|
336 |
-
# outputs=blender_config,
|
337 |
-
# )
|
338 |
-
beam_size.change(
|
339 |
-
fn=lambda x, y: y.update({"num_beams": x}) or y,
|
340 |
-
inputs=[beam_size, blender_config],
|
341 |
-
outputs=blender_config,
|
342 |
-
)
|
343 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
|
345 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
346 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
demo.queue(max_size=20).launch()
|
|
|
1 |
import gradio as gr
|
2 |
import sys
|
3 |
import os
|
4 |
+
import random
|
5 |
+
import llm_blender
|
6 |
+
import descriptions
|
7 |
from datasets import load_dataset
|
8 |
+
from llm_blender.blender.blender_utils import get_topk_candidates_from_ranks
|
9 |
from typing import List
|
10 |
|
11 |
+
|
12 |
MAX_BASE_LLM_NUM = 20
|
13 |
MIN_BASE_LLM_NUM = 3
|
14 |
SOURCE_MAX_LENGTH = 256
|
|
|
17 |
DEFAULT_CANDIDATE_MAX_LENGTH = 128
|
18 |
FUSER_MAX_NEW_TOKENS = 512
|
19 |
DEFAULT_FUSER_MAX_NEW_TOKENS = 256
|
|
|
|
|
|
|
|
|
20 |
EXAMPLES_DATASET = load_dataset("llm-blender/mix-instruct", split='validation', streaming=True)
|
21 |
SHUFFLED_EXAMPLES_DATASET = EXAMPLES_DATASET.shuffle(seed=42, buffer_size=1000)
|
22 |
EXAMPLES = []
|
|
|
28 |
])
|
29 |
CANDIDATE_EXAMPLES[example['instruction']+example['input']] = example['candidates']
|
30 |
|
31 |
+
HHH_EXAMPLES = []
|
32 |
+
subsets = ['harmless', 'helpful', 'honest', 'other']
|
33 |
+
random.seed(42)
|
34 |
+
for subset in subsets:
|
35 |
+
dataset = load_dataset("HuggingFaceH4/hhh_alignment", subset)
|
36 |
+
for example in dataset['test']:
|
37 |
+
if random.random() < 0.5:
|
38 |
+
HHH_EXAMPLES.append([
|
39 |
+
subset,
|
40 |
+
example['input'],
|
41 |
+
example['targets']['choices'][0],
|
42 |
+
example['targets']['choices'][1],
|
43 |
+
"Response 1" if example['targets']['labels'][0] == 1 else "Response 2",
|
44 |
+
])
|
45 |
+
else:
|
46 |
+
HHH_EXAMPLES.append([
|
47 |
+
subset,
|
48 |
+
example['input'],
|
49 |
+
example['targets']['choices'][1],
|
50 |
+
example['targets']['choices'][0],
|
51 |
+
"Response 2" if example['targets']['labels'][0] == 1 else "Response 1",
|
52 |
+
])
|
53 |
+
def get_hhh_examples(subset, instruction, response1, response2, dummy_text):
|
54 |
+
return instruction, response1, response2
|
55 |
|
56 |
+
MT_BENCH_HUMAN_JUDGE_EXAMPLES = []
|
57 |
+
dataset = load_dataset("lmsys/mt_bench_human_judgments")
|
58 |
+
for example in dataset['human']:
|
59 |
+
if example['turn'] != 1:
|
60 |
+
continue
|
61 |
+
MT_BENCH_HUMAN_JUDGE_EXAMPLES.append([
|
62 |
+
example['model_a'],
|
63 |
+
example['model_b'],
|
64 |
+
str(example['conversation_a']),
|
65 |
+
str(example['conversation_b']),
|
66 |
+
"Model A" if example['winner'] == 'model_a' else "Model B",
|
67 |
+
])
|
68 |
+
def get_mt_bench_human_judge_examples(model_a, model_b, conversation_a, conversation_b, dummy_text):
|
69 |
+
chat_history_a = []
|
70 |
+
chat_history_b = []
|
71 |
+
conversation_a = eval(conversation_a)
|
72 |
+
conversation_b = eval(conversation_b)
|
73 |
+
for i in range(0, len(conversation_a), 2):
|
74 |
+
chat_history_a.append((conversation_a[i]['content'], conversation_a[i+1]['content']))
|
75 |
+
assert conversation_a[i]['role'] == 'user' and conversation_a[i+1]['role'] == 'assistant'
|
76 |
+
for i in range(0, len(conversation_b), 2):
|
77 |
+
chat_history_b.append((conversation_b[i]['content'], conversation_b[i+1]['content']))
|
78 |
+
assert conversation_b[i]['role'] == 'user' and conversation_b[i+1]['role'] == 'assistant'
|
79 |
+
return chat_history_a, chat_history_b
|
80 |
+
|
81 |
+
|
82 |
+
blender = llm_blender.Blender()
|
83 |
+
blender.loadranker("llm-blender/PairRM")
|
84 |
+
blender.loadfuser("llm-blender/gen_fuser_3b")
|
85 |
|
86 |
def update_base_llms_num(k, llm_outputs):
|
87 |
k = int(k)
|
88 |
+
return [gr.Dropdown(choices=[f"LLM-{i+1}" for i in range(k)],
|
89 |
value=f"LLM-1" if k >= 1 else "", visible=True),
|
90 |
{f"LLM-{i+1}": llm_outputs.get(f"LLM-{i+1}", "") for i in range(k)}]
|
91 |
|
92 |
|
93 |
def display_llm_output(llm_outputs, selected_base_llm_name):
|
94 |
+
return gr.Textbox(value=llm_outputs.get(selected_base_llm_name, ""),
|
95 |
label=selected_base_llm_name + " (Click Save to save current content)",
|
96 |
placeholder=f"Enter {selected_base_llm_name} output here", show_label=True)
|
97 |
|
98 |
def save_llm_output(selected_base_llm_name, selected_base_llm_output, llm_outputs):
|
99 |
+
llm_outputs({selected_base_llm_name: selected_base_llm_output})
|
100 |
return llm_outputs
|
101 |
|
102 |
def get_preprocess_examples(inst, input):
|
|
|
159 |
|
160 |
|
161 |
with gr.Blocks(theme='ParityError/Anime') as demo:
|
162 |
+
|
163 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
+
with gr.Tab("LLM-Blender"):
|
166 |
+
# llm-blender interface
|
167 |
+
with gr.Row():
|
168 |
+
gr.Markdown(descriptions.LLM_BLENDER_OVERALL_DESC)
|
169 |
+
gr.Image("https://github.com/yuchenlin/LLM-Blender/blob/main/docs/llm_blender.png?raw=true", height=300)
|
170 |
+
gr.Markdown("## Input and Base LLMs")
|
171 |
+
with gr.Row():
|
172 |
+
with gr.Column():
|
173 |
+
inst_textbox = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
|
174 |
+
input_textbox = gr.Textbox(lines=4, label="Input Context", placeholder="Enter input context here", show_label=True)
|
175 |
+
with gr.Column():
|
176 |
+
saved_llm_outputs = gr.State(value={})
|
177 |
+
with gr.Group():
|
178 |
+
selected_base_llm_name_dropdown = gr.Dropdown(label="Base LLM",
|
179 |
+
choices=[f"LLM-{i+1}" for i in range(MIN_BASE_LLM_NUM)], value="LLM-1", show_label=True)
|
180 |
+
selected_base_llm_output = gr.Textbox(lines=4, label="LLM-1 (Click Save to save current content)",
|
181 |
+
placeholder="Enter LLM-1 output here", show_label=True)
|
182 |
+
with gr.Row():
|
183 |
+
base_llm_outputs_save_button = gr.Button('Save', variant='primary')
|
184 |
+
|
185 |
+
base_llm_outputs_clear_single_button = gr.Button('Clear Single', variant='primary')
|
186 |
+
|
187 |
+
base_llm_outputs_clear_all_button = gr.Button('Clear All', variant='primary')
|
188 |
+
base_llms_num = gr.Slider(
|
189 |
+
label='Number of base llms',
|
190 |
+
minimum=MIN_BASE_LLM_NUM,
|
191 |
+
maximum=MAX_BASE_LLM_NUM,
|
192 |
+
step=1,
|
193 |
+
value=MIN_BASE_LLM_NUM,
|
194 |
+
)
|
195 |
+
|
196 |
+
blender_state = gr.State(value={})
|
197 |
+
saved_rank_outputs = gr.State(value=[])
|
198 |
+
saved_fuse_outputs = gr.State(value=[])
|
199 |
+
gr.Markdown("## Blender Outputs")
|
200 |
+
with gr.Group():
|
201 |
+
rank_outputs = gr.Textbox(lines=1, label="Ranking outputs", placeholder="Ranking outputs", show_label=True)
|
202 |
+
fuser_outputs = gr.Textbox(lines=4, label="Fusing outputs", placeholder="Fusing outputs", show_label=True)
|
203 |
+
with gr.Row():
|
204 |
+
rank_button = gr.Button('Rank LLM Outputs', variant='primary')
|
205 |
+
fuse_button = gr.Button('Fuse Top-K ranked outputs', variant='primary')
|
206 |
+
clear_button = gr.Button('Clear Blender Outputs', variant='primary')
|
207 |
+
blender_config = gr.State(value={
|
208 |
+
"source_max_length": DEFAULT_SOURCE_MAX_LENGTH,
|
209 |
+
"candidate_max_length": DEFAULT_CANDIDATE_MAX_LENGTH,
|
210 |
+
"top_k_for_fuser": 3,
|
211 |
+
"max_new_tokens": DEFAULT_FUSER_MAX_NEW_TOKENS,
|
212 |
+
"temperature": 0.7,
|
213 |
+
"top_p": 1.0,
|
214 |
+
})
|
215 |
+
|
216 |
+
with gr.Accordion(label='Advanced options', open=False):
|
217 |
+
source_max_length = gr.Slider(
|
218 |
+
label='Max length of Instruction + Input',
|
219 |
+
minimum=1,
|
220 |
+
maximum=SOURCE_MAX_LENGTH,
|
221 |
+
step=1,
|
222 |
+
value=DEFAULT_SOURCE_MAX_LENGTH,
|
223 |
+
)
|
224 |
+
candidate_max_length = gr.Slider(
|
225 |
+
label='Max length of LLM-Output Candidate',
|
226 |
+
minimum=1,
|
227 |
+
maximum=CANDIDATE_MAX_LENGTH,
|
228 |
+
step=1,
|
229 |
+
value=DEFAULT_CANDIDATE_MAX_LENGTH,
|
230 |
+
)
|
231 |
+
top_k_for_fuser = gr.Slider(
|
232 |
+
label='Top-k ranked candidates to fuse',
|
233 |
+
minimum=1,
|
234 |
+
maximum=3,
|
235 |
+
step=1,
|
236 |
+
value=3,
|
237 |
+
)
|
238 |
+
max_new_tokens = gr.Slider(
|
239 |
+
label='Max new tokens fuser can generate',
|
240 |
+
minimum=1,
|
241 |
+
maximum=FUSER_MAX_NEW_TOKENS,
|
242 |
+
step=1,
|
243 |
+
value=DEFAULT_FUSER_MAX_NEW_TOKENS,
|
244 |
+
)
|
245 |
+
temperature = gr.Slider(
|
246 |
+
label='Temperature of fuser generation',
|
247 |
+
minimum=0.1,
|
248 |
+
maximum=2.0,
|
249 |
+
step=0.1,
|
250 |
+
value=0.7,
|
251 |
+
)
|
252 |
+
top_p = gr.Slider(
|
253 |
+
label='Top-p of fuser generation',
|
254 |
+
minimum=0.05,
|
255 |
+
maximum=1.0,
|
256 |
+
step=0.05,
|
257 |
+
value=1.0,
|
258 |
+
)
|
259 |
+
|
260 |
+
examples_dummy_textbox = gr.Textbox(lines=1, label="", placeholder="", show_label=False, visible=False)
|
261 |
+
batch_examples = gr.Examples(
|
262 |
+
examples=EXAMPLES,
|
263 |
+
fn=get_preprocess_examples,
|
264 |
+
cache_examples=True,
|
265 |
+
examples_per_page=5,
|
266 |
+
inputs=[inst_textbox, input_textbox],
|
267 |
+
outputs=[inst_textbox, input_textbox, base_llms_num, examples_dummy_textbox],
|
268 |
+
)
|
269 |
+
|
270 |
+
base_llms_num.change(
|
271 |
+
fn=update_base_llms_num,
|
272 |
+
inputs=[base_llms_num, saved_llm_outputs],
|
273 |
+
outputs=[selected_base_llm_name_dropdown, saved_llm_outputs],
|
274 |
+
)
|
275 |
|
276 |
+
examples_dummy_textbox.change(
|
277 |
+
fn=update_base_llm_dropdown_along_examples,
|
278 |
+
inputs=[examples_dummy_textbox],
|
279 |
+
outputs=[saved_llm_outputs, rank_outputs, fuser_outputs],
|
280 |
+
).then(
|
281 |
+
fn=display_llm_output,
|
282 |
+
inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
|
283 |
+
outputs=selected_base_llm_output,
|
284 |
)
|
285 |
+
|
286 |
+
selected_base_llm_name_dropdown.change(
|
287 |
+
fn=display_llm_output,
|
288 |
+
inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
|
289 |
+
outputs=selected_base_llm_output,
|
|
|
290 |
)
|
291 |
+
|
292 |
+
base_llm_outputs_save_button.click(
|
293 |
+
fn=save_llm_output,
|
294 |
+
inputs=[selected_base_llm_name_dropdown, selected_base_llm_output, saved_llm_outputs],
|
295 |
+
outputs=saved_llm_outputs,
|
|
|
296 |
)
|
297 |
+
base_llm_outputs_clear_all_button.click(
|
298 |
+
fn=lambda: [{}, ""],
|
299 |
+
inputs=[],
|
300 |
+
outputs=[saved_llm_outputs, selected_base_llm_output],
|
|
|
|
|
301 |
)
|
302 |
+
base_llm_outputs_clear_single_button.click(
|
303 |
+
fn=lambda: "",
|
304 |
+
inputs=[],
|
305 |
+
outputs=selected_base_llm_output,
|
306 |
+
)
|
307 |
+
|
308 |
+
|
309 |
+
rank_button.click(
|
310 |
+
fn=check_save_ranker_inputs,
|
311 |
+
inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
|
312 |
+
outputs=blender_state,
|
313 |
+
).success(
|
314 |
+
fn=llms_rank,
|
315 |
+
inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
|
316 |
+
outputs=[saved_rank_outputs, rank_outputs],
|
|
|
|
|
|
|
|
|
|
|
317 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
318 |
|
319 |
+
fuse_button.click(
|
320 |
+
fn=check_fuser_inputs,
|
321 |
+
inputs=[blender_state, blender_config, saved_rank_outputs],
|
322 |
+
outputs=[],
|
323 |
+
).success(
|
324 |
+
fn=llms_fuse,
|
325 |
+
inputs=[blender_state, blender_config, saved_rank_outputs],
|
326 |
+
outputs=[saved_fuse_outputs, fuser_outputs],
|
327 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
328 |
|
329 |
+
clear_button.click(
|
330 |
+
fn=lambda: ["", "", {}, []],
|
331 |
+
inputs=[],
|
332 |
+
outputs=[rank_outputs, fuser_outputs, blender_state, saved_rank_outputs],
|
333 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
|
335 |
+
# update blender config
|
336 |
+
source_max_length.change(
|
337 |
+
fn=lambda x, y: y.update({"source_max_length": x}) or y,
|
338 |
+
inputs=[source_max_length, blender_config],
|
339 |
+
outputs=blender_config,
|
340 |
+
)
|
341 |
+
candidate_max_length.change(
|
342 |
+
fn=lambda x, y: y.update({"candidate_max_length": x}) or y,
|
343 |
+
inputs=[candidate_max_length, blender_config],
|
344 |
+
outputs=blender_config,
|
345 |
+
)
|
346 |
+
top_k_for_fuser.change(
|
347 |
+
fn=lambda x, y: y.update({"top_k_for_fuser": x}) or y,
|
348 |
+
inputs=[top_k_for_fuser, blender_config],
|
349 |
+
outputs=blender_config,
|
350 |
+
)
|
351 |
+
max_new_tokens.change(
|
352 |
+
fn=lambda x, y: y.update({"max_new_tokens": x}) or y,
|
353 |
+
inputs=[max_new_tokens, blender_config],
|
354 |
+
outputs=blender_config,
|
355 |
+
)
|
356 |
+
temperature.change(
|
357 |
+
fn=lambda x, y: y.update({"temperature": x}) or y,
|
358 |
+
inputs=[temperature, blender_config],
|
359 |
+
outputs=blender_config,
|
360 |
+
)
|
361 |
+
top_p.change(
|
362 |
+
fn=lambda x, y: y.update({"top_p": x}) or y,
|
363 |
+
inputs=[top_p, blender_config],
|
364 |
+
outputs=blender_config,
|
365 |
+
)
|
366 |
|
367 |
|
368 |
+
with gr.Tab("PairRM"):
|
369 |
+
# PairRM interface
|
370 |
+
with gr.Row():
|
371 |
+
gr.Markdown(descriptions.PairRM_OVERALL_DESC)
|
372 |
+
gr.Image("https://yuchenlin.xyz/LLM-Blender/pairranker.png")
|
373 |
+
|
374 |
+
with gr.Tab("Compare two responses"):
|
375 |
+
instruction = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
|
376 |
+
with gr.Row():
|
377 |
+
response1 = gr.Textbox(lines=4, label="Response 1", placeholder="Enter response 1 here", show_label=True)
|
378 |
+
response2 = gr.Textbox(lines=4, label="Response 2", placeholder="Enter response 2 here", show_label=True)
|
379 |
+
with gr.Row():
|
380 |
+
compare_button = gr.Button('Compare', variant='primary')
|
381 |
+
clear_button = gr.Button('Clear', variant='primary')
|
382 |
+
with gr.Row():
|
383 |
+
compare_result = gr.Textbox(lines=1, label="Compare Result", placeholder="", show_label=True)
|
384 |
+
compare_result_prob = gr.Textbox(lines=1, label="PairRM Confidence", placeholder="", show_label=True)
|
385 |
+
|
386 |
+
def compare_fn(inst, response1, response2):
|
387 |
+
if not inst:
|
388 |
+
raise gr.Error("Please enter instruction")
|
389 |
+
if not response1 or not response2:
|
390 |
+
raise gr.Error("Please enter response 1 and response 2")
|
391 |
+
comparison_results = blender.compare([inst], [response1], [response2], return_logits=True)
|
392 |
+
logit = comparison_results[0]
|
393 |
+
if logit > 0:
|
394 |
+
result = "Response 1 is better than Response 2"
|
395 |
+
prob = f"Confidence: {round(logit, 2)}"
|
396 |
+
elif logit < 0:
|
397 |
+
result = "Response 2 is better than Response 1"
|
398 |
+
prob = f"Cofidence: {round(abs(logit), 2)}"
|
399 |
+
else:
|
400 |
+
result = "Response 1 and Response 2 are equally good"
|
401 |
+
prob = f"No confidence for tie"
|
402 |
+
|
403 |
+
return [result, prob]
|
404 |
+
compare_button.click(
|
405 |
+
fn=compare_fn,
|
406 |
+
inputs=[instruction, response1, response2],
|
407 |
+
outputs=[compare_result, compare_result_prob],
|
408 |
+
)
|
409 |
+
clear_button.click(
|
410 |
+
fn=lambda: ["", ""],
|
411 |
+
inputs=[],
|
412 |
+
outputs=[compare_result, compare_result_prob],
|
413 |
+
)
|
414 |
+
|
415 |
+
hhh_dummy_textbox1 = gr.Textbox(lines=1, label="subset", placeholder="", show_label=False, visible=False)
|
416 |
+
hhh_dummy_textbox2 = gr.Textbox(lines=1, label="Better Response", placeholder="", show_label=False, visible=False)
|
417 |
+
gr.Markdown("## Examples from [HuggingFaceH4/hhh_alignment](https://huggingface.co/datasets/HuggingFaceH4/hhh_alignment)")
|
418 |
+
gr.Examples(
|
419 |
+
HHH_EXAMPLES,
|
420 |
+
fn=get_hhh_examples,
|
421 |
+
cache_examples=True,
|
422 |
+
examples_per_page=5,
|
423 |
+
inputs=[hhh_dummy_textbox1, instruction, response1, response2, hhh_dummy_textbox2],
|
424 |
+
outputs=[instruction, response1, response2],
|
425 |
+
)
|
426 |
+
|
427 |
+
|
428 |
+
with gr.Tab("Compare assistant's response in two multi-turn conversations"):
|
429 |
+
|
430 |
+
gr.Markdown("NOTE: Comparison of two conversations is based on that the user query in each turn is the same of two conversations.")
|
431 |
+
def append_message(message, chat_history):
|
432 |
+
if not message:
|
433 |
+
return "", chat_history
|
434 |
+
if len(chat_history) == 0:
|
435 |
+
chat_history.append((message, "(Please enter your bot response)"))
|
436 |
+
else:
|
437 |
+
if chat_history[-1][1] == "(Please enter your bot response)":
|
438 |
+
chat_history[-1] = (chat_history[-1][0], message)
|
439 |
+
else:
|
440 |
+
chat_history.append((message, "(Please enter your bot response)"))
|
441 |
+
return "", chat_history
|
442 |
+
with gr.Row():
|
443 |
+
with gr.Column():
|
444 |
+
gr.Markdown("### Conversation A")
|
445 |
+
chatbot1 = gr.Chatbot()
|
446 |
+
msg1 = gr.Textbox(lines=1, label="Enter Chat history for Conversation A", placeholder="Enter your message here", show_label=True)
|
447 |
+
clear1 = gr.ClearButton([msg1, chatbot1])
|
448 |
+
msg1.submit(append_message, [msg1, chatbot1], [msg1, chatbot1])
|
449 |
+
with gr.Column():
|
450 |
+
gr.Markdown("### Conversation B")
|
451 |
+
chatbot2 = gr.Chatbot()
|
452 |
+
msg2 = gr.Textbox(lines=1, label="Enter Chat history for Conversation B", placeholder="Enter your message here", show_label=True)
|
453 |
+
clear2 = gr.ClearButton([msg2, chatbot2])
|
454 |
+
msg2.submit(append_message, [msg2, chatbot2], [msg2, chatbot2])
|
455 |
+
with gr.Row():
|
456 |
+
compare_button = gr.Button('Compare', variant='primary')
|
457 |
+
with gr.Row():
|
458 |
+
compare_result = gr.Textbox(lines=1, label="Compare Result", placeholder="", show_label=True)
|
459 |
+
compare_result_prob = gr.Textbox(lines=1, label="PairRM Confidence", placeholder="", show_label=True)
|
460 |
+
|
461 |
+
def compare_conv_fn(chat_history1, chat_history2):
|
462 |
+
if len(chat_history1) == 0 or len(chat_history2) == 0:
|
463 |
+
raise gr.Error("Please enter chat history for both conversations")
|
464 |
+
assert chat_history1[-1][1] != "(Please enter your bot response)" \
|
465 |
+
and chat_history2[-1][1] != "(Please enter your bot response)", \
|
466 |
+
"Please complete chat history for both conversations"
|
467 |
+
chat1_messages = []
|
468 |
+
for item in chat_history1:
|
469 |
+
chat1_messages.append({
|
470 |
+
"role": "USER",
|
471 |
+
"content": item[0],
|
472 |
+
})
|
473 |
+
chat1_messages.append({
|
474 |
+
"role": "ASSISTANT",
|
475 |
+
"content": item[1],
|
476 |
+
})
|
477 |
+
chat2_messages = []
|
478 |
+
for item in chat_history2:
|
479 |
+
chat2_messages.append({
|
480 |
+
"role": "USER",
|
481 |
+
"content": item[0],
|
482 |
+
})
|
483 |
+
chat2_messages.append({
|
484 |
+
"role": "ASSISTANT",
|
485 |
+
"content": item[1],
|
486 |
+
})
|
487 |
+
|
488 |
+
comparison_results = blender.compare_conversations([chat1_messages], [chat2_messages], return_logits=True)
|
489 |
+
logit = comparison_results[0]
|
490 |
+
if logit > 0:
|
491 |
+
result = "Assistant's response in Conversation A is better than Conversation B"
|
492 |
+
prob = f"Confidence: {round(logit, 2)}"
|
493 |
+
elif logit < 0:
|
494 |
+
result = "Assistant's response in Conversation B is better than Conversation A"
|
495 |
+
prob = f"Cofidence: {round(abs(logit), 2)}"
|
496 |
+
else:
|
497 |
+
result = "Assistant's response in Conversation A and Conversation B are equally good"
|
498 |
+
prob = f"No confidence for tie"
|
499 |
+
|
500 |
+
return [result, prob]
|
501 |
|
502 |
+
compare_button.click(
|
503 |
+
fn=compare_conv_fn,
|
504 |
+
inputs=[chatbot1, chatbot2],
|
505 |
+
outputs=[compare_result, compare_result_prob],
|
506 |
+
)
|
507 |
+
|
508 |
+
model_a_dummy_textbox = gr.Textbox(lines=1, label="Model A", placeholder="", show_label=False, visible=False)
|
509 |
+
model_b_dummy_textbox = gr.Textbox(lines=1, label="Model B", placeholder="", show_label=False, visible=False)
|
510 |
+
winner_dummy_textbox = gr.Textbox(lines=1, label="Better Model in conversation", placeholder="", show_label=False, visible=False)
|
511 |
+
chatbot1_dummy_textbox = gr.Textbox(lines=1, label="Conversation A", placeholder="", show_label=False, visible=False)
|
512 |
+
chatbot2_dummy_textbox = gr.Textbox(lines=1, label="Conversation B", placeholder="", show_label=False, visible=False)
|
513 |
+
gr.Markdown("## Examples from [lmsys/mt_bench_human_judgments](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments)")
|
514 |
+
gr.Examples(
|
515 |
+
MT_BENCH_HUMAN_JUDGE_EXAMPLES,
|
516 |
+
fn=get_mt_bench_human_judge_examples,
|
517 |
+
cache_examples=True,
|
518 |
+
examples_per_page=5,
|
519 |
+
inputs=[model_a_dummy_textbox, model_b_dummy_textbox, chatbot1_dummy_textbox, chatbot2_dummy_textbox, winner_dummy_textbox],
|
520 |
+
outputs=[chatbot1, chatbot2],
|
521 |
+
)
|
522 |
+
|
523 |
+
gr.Markdown(descriptions.CITATION)
|
524 |
demo.queue(max_size=20).launch()
|
descriptions.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
LLM_BLENDER_OVERALL_DESC = """
|
2 |
+
LLM-Blender is an innovative ensembling framework to attain consistently superior performance by leveraging the diverse strengths of multiple open-source large language models (LLMs).
|
3 |
+
LLM-Blender cut the weaknesses through ranking and integrate the strengths through fusing generation to enhance the capability of LLMs.
|
4 |
+
|
5 |
+
Our framework consists of two complementary modules: **PairRanker** and **GenFuser**,
|
6 |
+
addressing the observation that optimal LLMs for different examples can significantly vary.
|
7 |
+
**PairRanker** employs a specialized pairwise comparison method to distinguish subtle differences between candidate outputs.
|
8 |
+
**GenFuser** aims to merge the top-ranked candidates from the aggregation of PairRanker's pairwise
|
9 |
+
comparisons into an improved output by capitalizing on their strengths and mitigating their weaknesses.
|
10 |
+
|
11 |
+
| [Paper](https://arxiv.org/abs/2306.02561)
|
12 |
+
| [Code](https://github.com/yuchenlin/LLM-Blender)
|
13 |
+
| [Dataset](https://huggingface.co/datasets/llm-blender/mix-instruct)
|
14 |
+
| [Models](https://huggingface.co/llm-blender)
|
15 |
+
| [Tweet](https://twitter.com/billyuchenlin/status/1668666357058277377)
|
16 |
+
|
17 |
+
Try LLM-Blender now! 👇
|
18 |
+
"""
|
19 |
+
|
20 |
+
PairRM_OVERALL_DESC = """## 🤗 [PairRM](https://huggingface.co/llm-blender/PairRM)
|
21 |
+
|
22 |
+
PairRM is a reward model based on PairRanker architecture that has been trained on various high-quality and
|
23 |
+
large-scale dataset with human preference annotations and exhibits great correlation with human preferences.
|
24 |
+
|
25 |
+
**While PairRM is a extremely small model (0.4B), our tests on various human alignment benchmarks show approaching performance of GPT-4.** (See [PairRM](https://huggingface.co/llm-blender/PairRM) for more detail results)
|
26 |
+
|
27 |
+
PairRM could be easily applied to 3 scenarios:
|
28 |
+
1. [Directly compare two responses or two conversations](https://huggingface.co/llm-blender/PairRM#use-case-1-compare-responses-quality-evaluator)
|
29 |
+
2. [Do best-of-n sampling to enhancing the ability of LLM](https://huggingface.co/llm-blender/PairRM#use-case-2-best-of-n-sampling-decoding-enhancing)
|
30 |
+
3. [RLHF alignment](https://huggingface.co/llm-blender/PairRM#use-case-3-rlhf)
|
31 |
+
|
32 |
+
This demo allows user to interact with PairRM in the first scenario. Feel free to compare the quality of two responses or two conversations by inputting them in the text box below.
|
33 |
+
|
34 |
+
Try PairRM now! 👇
|
35 |
+
"""
|
36 |
+
|
37 |
+
CITATION = """
|
38 |
+
|
39 |
+
## Citation
|
40 |
+
```bibtex
|
41 |
+
@inproceedings{llm-blender-2023,
|
42 |
+
title = "LLM-Blender: Ensembling Large Language Models with Pairwise Comparison and Generative Fusion",
|
43 |
+
author = "Jiang, Dongfu and Ren, Xiang and Lin, Bill Yuchen",
|
44 |
+
booktitle = "Proceedings of the 61th Annual Meeting of the Association for Computational Linguistics (ACL 2023)",
|
45 |
+
year = "2023"
|
46 |
+
}
|
47 |
+
|
48 |
+
```
|
49 |
+
"""
|