DongfuJiang commited on
Commit
bf79ee8
·
1 Parent(s): fcac78f
Files changed (2) hide show
  1. app.py +410 -233
  2. descriptions.py +49 -0
app.py CHANGED
@@ -1,10 +1,14 @@
1
  import gradio as gr
2
  import sys
3
  import os
4
- import zipfile
 
 
5
  from datasets import load_dataset
 
6
  from typing import List
7
 
 
8
  MAX_BASE_LLM_NUM = 20
9
  MIN_BASE_LLM_NUM = 3
10
  SOURCE_MAX_LENGTH = 256
@@ -13,10 +17,6 @@ CANDIDATE_MAX_LENGTH = 256
13
  DEFAULT_CANDIDATE_MAX_LENGTH = 128
14
  FUSER_MAX_NEW_TOKENS = 512
15
  DEFAULT_FUSER_MAX_NEW_TOKENS = 256
16
- DESCRIPTIONS = """# LLM-BLENDER
17
-
18
- LLM-Blender is an innovative ensembling framework to attain consistently superior performance by leveraging the diverse strengths of multiple open-source large language models (LLMs). LLM-Blender cut the weaknesses through ranking and integrate the strengths through fusing generation to enhance the capability of LLMs.
19
- """
20
  EXAMPLES_DATASET = load_dataset("llm-blender/mix-instruct", split='validation', streaming=True)
21
  SHUFFLED_EXAMPLES_DATASET = EXAMPLES_DATASET.shuffle(seed=42, buffer_size=1000)
22
  EXAMPLES = []
@@ -28,47 +28,75 @@ for example in SHUFFLED_EXAMPLES_DATASET.take(100):
28
  ])
29
  CANDIDATE_EXAMPLES[example['instruction']+example['input']] = example['candidates']
30
 
31
- # Download ranker checkpoint
32
- if not os.path.exists("pairranker-deberta-v3-large.zip"):
33
- os.system("gdown https://drive.google.com/uc?id=1EpvFu_qYY0MaIu0BAAhK-sYKHVWtccWg")
34
- if not os.path.exists("pairranker-deberta-v3-large"):
35
- with zipfile.ZipFile("pairranker-deberta-v3-large.zip", 'r') as zip_ref:
36
- zip_ref.extractall(".")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- # Load Blender
39
- import llm_blender
40
- from llm_blender.blender.blender_utils import get_topk_candidates_from_ranks
41
- ranker_config = llm_blender.RankerConfig()
42
- ranker_config.ranker_type = "pairranker"
43
- ranker_config.model_type = "deberta"
44
- ranker_config.model_name = "microsoft/deberta-v3-large" # ranker backbone
45
- ranker_config.load_checkpoint = "./pairranker-deberta-v3-large" # ranker checkpoint <your checkpoint path>
46
- ranker_config.source_maxlength = DEFAULT_SOURCE_MAX_LENGTH
47
- ranker_config.candidate_maxlength = DEFAULT_CANDIDATE_MAX_LENGTH
48
- ranker_config.n_tasks = 1 # number of singal that has been used to train the ranker. This checkpoint is trained using BARTScore only, thus being 1.
49
- fuser_config = llm_blender.GenFuserConfig()
50
- fuser_config.model_name = "llm-blender/gen_fuser_3b" # our pre-trained fuser
51
- fuser_config.max_length = 1024
52
- fuser_config.candidate_maxlength = DEFAULT_CANDIDATE_MAX_LENGTH
53
- blender_config = llm_blender.BlenderConfig()
54
- blender_config.load_in_8bit = True
55
- blender_config.device = "cuda" # blender ranker and fuser device
56
- blender = llm_blender.Blender(blender_config, ranker_config, fuser_config)
 
 
 
 
 
 
 
 
 
 
57
 
58
  def update_base_llms_num(k, llm_outputs):
59
  k = int(k)
60
- return [gr.Dropdown.update(choices=[f"LLM-{i+1}" for i in range(k)],
61
  value=f"LLM-1" if k >= 1 else "", visible=True),
62
  {f"LLM-{i+1}": llm_outputs.get(f"LLM-{i+1}", "") for i in range(k)}]
63
 
64
 
65
  def display_llm_output(llm_outputs, selected_base_llm_name):
66
- return gr.Textbox.update(value=llm_outputs.get(selected_base_llm_name, ""),
67
  label=selected_base_llm_name + " (Click Save to save current content)",
68
  placeholder=f"Enter {selected_base_llm_name} output here", show_label=True)
69
 
70
  def save_llm_output(selected_base_llm_name, selected_base_llm_output, llm_outputs):
71
- llm_outputs.update({selected_base_llm_name: selected_base_llm_output})
72
  return llm_outputs
73
 
74
  def get_preprocess_examples(inst, input):
@@ -131,217 +159,366 @@ def display_fuser_output(fuser_output):
131
 
132
 
133
  with gr.Blocks(theme='ParityError/Anime') as demo:
134
- gr.Markdown(DESCRIPTIONS)
135
- gr.Markdown("## Input and Base LLMs")
136
- with gr.Row():
137
- with gr.Column():
138
- inst_textbox = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
139
- input_textbox = gr.Textbox(lines=4, label="Input Context", placeholder="Enter input context here", show_label=True)
140
- with gr.Column():
141
- saved_llm_outputs = gr.State(value={})
142
- with gr.Group():
143
- selected_base_llm_name_dropdown = gr.Dropdown(label="Base LLM",
144
- choices=[f"LLM-{i+1}" for i in range(MIN_BASE_LLM_NUM)], value="LLM-1", show_label=True)
145
- selected_base_llm_output = gr.Textbox(lines=4, label="LLM-1 (Click Save to save current content)",
146
- placeholder="Enter LLM-1 output here", show_label=True)
147
- with gr.Row():
148
- base_llm_outputs_save_button = gr.Button('Save', variant='primary')
149
-
150
- base_llm_outputs_clear_single_button = gr.Button('Clear Single', variant='primary')
151
-
152
- base_llm_outputs_clear_all_button = gr.Button('Clear All', variant='primary')
153
- base_llms_num = gr.Slider(
154
- label='Number of base llms',
155
- minimum=MIN_BASE_LLM_NUM,
156
- maximum=MAX_BASE_LLM_NUM,
157
- step=1,
158
- value=MIN_BASE_LLM_NUM,
159
- )
160
 
161
- blender_state = gr.State(value={})
162
- saved_rank_outputs = gr.State(value=[])
163
- saved_fuse_outputs = gr.State(value=[])
164
- gr.Markdown("## Blender Outputs")
165
- with gr.Group():
166
- rank_outputs = gr.Textbox(lines=1, label="Ranks of each LLM's output", placeholder="Ranking outputs", show_label=True)
167
- fuser_outputs = gr.Textbox(lines=4, label="Fusing outputs", placeholder="Fusing outputs", show_label=True)
168
- with gr.Row():
169
- rank_button = gr.Button('Rank LLM Outputs', variant='primary')
170
- fuse_button = gr.Button('Fuse Top-K ranked outputs', variant='primary')
171
- clear_button = gr.Button('Clear Blender Outputs', variant='primary')
172
- blender_config = gr.State(value={
173
- "source_max_length": DEFAULT_SOURCE_MAX_LENGTH,
174
- "candidate_max_length": DEFAULT_CANDIDATE_MAX_LENGTH,
175
- "top_k_for_fuser": 3,
176
- "max_new_tokens": DEFAULT_FUSER_MAX_NEW_TOKENS,
177
- "temperature": 0.7,
178
- "top_p": 1.0,
179
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
- with gr.Accordion(label='Advanced options', open=False):
182
- top_k_for_fuser = gr.Slider(
183
- label='Top-k ranked candidates to fuse',
184
- minimum=1,
185
- maximum=3,
186
- step=1,
187
- value=3,
 
188
  )
189
- source_max_length = gr.Slider(
190
- label='Max length of Instruction + Input',
191
- minimum=1,
192
- maximum=SOURCE_MAX_LENGTH,
193
- step=1,
194
- value=DEFAULT_SOURCE_MAX_LENGTH,
195
  )
196
- candidate_max_length = gr.Slider(
197
- label='Max length of LLM-Output Candidate',
198
- minimum=1,
199
- maximum=CANDIDATE_MAX_LENGTH,
200
- step=1,
201
- value=DEFAULT_CANDIDATE_MAX_LENGTH,
202
  )
203
- max_new_tokens = gr.Slider(
204
- label='Max new tokens fuser can generate',
205
- minimum=1,
206
- maximum=FUSER_MAX_NEW_TOKENS,
207
- step=1,
208
- value=DEFAULT_FUSER_MAX_NEW_TOKENS,
209
  )
210
- # temperature = gr.Slider(
211
- # label='Temperature of fuser generation',
212
- # minimum=0.1,
213
- # maximum=2.0,
214
- # step=0.1,
215
- # value=0.7,
216
- # )
217
- # top_p = gr.Slider(
218
- # label='Top-p of fuser generation',
219
- # minimum=0.05,
220
- # maximum=1.0,
221
- # step=0.05,
222
- # value=1.0,
223
- # )
224
- beam_size = gr.Slider(
225
- label='Beam size of fuser generation',
226
- minimum=1,
227
- maximum=10,
228
- step=1,
229
- value=4,
230
  )
231
-
232
- examples_dummy_textbox = gr.Textbox(lines=1, label="", placeholder="", show_label=False, visible=False)
233
- batch_examples = gr.Examples(
234
- examples=EXAMPLES,
235
- fn=get_preprocess_examples,
236
- cache_examples=True,
237
- examples_per_page=5,
238
- inputs=[inst_textbox, input_textbox],
239
- outputs=[inst_textbox, input_textbox, base_llms_num, examples_dummy_textbox],
240
- )
241
 
242
- base_llms_num.change(
243
- fn=update_base_llms_num,
244
- inputs=[base_llms_num, saved_llm_outputs],
245
- outputs=[selected_base_llm_name_dropdown, saved_llm_outputs],
246
- )
247
-
248
- examples_dummy_textbox.change(
249
- fn=update_base_llm_dropdown_along_examples,
250
- inputs=[examples_dummy_textbox],
251
- outputs=[saved_llm_outputs, rank_outputs, fuser_outputs],
252
- ).then(
253
- fn=display_llm_output,
254
- inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
255
- outputs=selected_base_llm_output,
256
- )
257
-
258
- selected_base_llm_name_dropdown.change(
259
- fn=display_llm_output,
260
- inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
261
- outputs=selected_base_llm_output,
262
- )
263
-
264
- base_llm_outputs_save_button.click(
265
- fn=save_llm_output,
266
- inputs=[selected_base_llm_name_dropdown, selected_base_llm_output, saved_llm_outputs],
267
- outputs=saved_llm_outputs,
268
- )
269
- base_llm_outputs_clear_all_button.click(
270
- fn=lambda: [{}, ""],
271
- inputs=[],
272
- outputs=[saved_llm_outputs, selected_base_llm_output],
273
- )
274
- base_llm_outputs_clear_single_button.click(
275
- fn=lambda: "",
276
- inputs=[],
277
- outputs=selected_base_llm_output,
278
- )
279
 
280
-
281
- rank_button.click(
282
- fn=check_save_ranker_inputs,
283
- inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
284
- outputs=blender_state,
285
- ).success(
286
- fn=llms_rank,
287
- inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
288
- outputs=[saved_rank_outputs, rank_outputs],
289
- )
290
-
291
- fuse_button.click(
292
- fn=check_fuser_inputs,
293
- inputs=[blender_state, blender_config, saved_rank_outputs],
294
- outputs=fuser_outputs,
295
- ).success(
296
- fn=llms_fuse,
297
- inputs=[blender_state, blender_config, saved_rank_outputs],
298
- outputs=[saved_fuse_outputs, fuser_outputs],
299
- )
300
-
301
- clear_button.click(
302
- fn=lambda: ["", "", {}, []],
303
- inputs=[],
304
- outputs=[rank_outputs, fuser_outputs, blender_state, saved_rank_outputs],
305
- )
306
-
307
- # update blender config
308
- source_max_length.change(
309
- fn=lambda x, y: y.update({"source_max_length": x}) or y,
310
- inputs=[source_max_length, blender_config],
311
- outputs=blender_config,
312
- )
313
- candidate_max_length.change(
314
- fn=lambda x, y: y.update({"candidate_max_length": x}) or y,
315
- inputs=[candidate_max_length, blender_config],
316
- outputs=blender_config,
317
- )
318
- top_k_for_fuser.change(
319
- fn=lambda x, y: y.update({"top_k_for_fuser": x}) or y,
320
- inputs=[top_k_for_fuser, blender_config],
321
- outputs=blender_config,
322
- )
323
- max_new_tokens.change(
324
- fn=lambda x, y: y.update({"max_new_tokens": x}) or y,
325
- inputs=[max_new_tokens, blender_config],
326
- outputs=blender_config,
327
- )
328
- # temperature.change(
329
- # fn=lambda x, y: y.update({"temperature": x}) or y,
330
- # inputs=[temperature, blender_config],
331
- # outputs=blender_config,
332
- # )
333
- # top_p.change(
334
- # fn=lambda x, y: y.update({"top_p": x}) or y,
335
- # inputs=[top_p, blender_config],
336
- # outputs=blender_config,
337
- # )
338
- beam_size.change(
339
- fn=lambda x, y: y.update({"num_beams": x}) or y,
340
- inputs=[beam_size, blender_config],
341
- outputs=blender_config,
342
- )
343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  demo.queue(max_size=20).launch()
 
1
  import gradio as gr
2
  import sys
3
  import os
4
+ import random
5
+ import llm_blender
6
+ import descriptions
7
  from datasets import load_dataset
8
+ from llm_blender.blender.blender_utils import get_topk_candidates_from_ranks
9
  from typing import List
10
 
11
+
12
  MAX_BASE_LLM_NUM = 20
13
  MIN_BASE_LLM_NUM = 3
14
  SOURCE_MAX_LENGTH = 256
 
17
  DEFAULT_CANDIDATE_MAX_LENGTH = 128
18
  FUSER_MAX_NEW_TOKENS = 512
19
  DEFAULT_FUSER_MAX_NEW_TOKENS = 256
 
 
 
 
20
  EXAMPLES_DATASET = load_dataset("llm-blender/mix-instruct", split='validation', streaming=True)
21
  SHUFFLED_EXAMPLES_DATASET = EXAMPLES_DATASET.shuffle(seed=42, buffer_size=1000)
22
  EXAMPLES = []
 
28
  ])
29
  CANDIDATE_EXAMPLES[example['instruction']+example['input']] = example['candidates']
30
 
31
+ HHH_EXAMPLES = []
32
+ subsets = ['harmless', 'helpful', 'honest', 'other']
33
+ random.seed(42)
34
+ for subset in subsets:
35
+ dataset = load_dataset("HuggingFaceH4/hhh_alignment", subset)
36
+ for example in dataset['test']:
37
+ if random.random() < 0.5:
38
+ HHH_EXAMPLES.append([
39
+ subset,
40
+ example['input'],
41
+ example['targets']['choices'][0],
42
+ example['targets']['choices'][1],
43
+ "Response 1" if example['targets']['labels'][0] == 1 else "Response 2",
44
+ ])
45
+ else:
46
+ HHH_EXAMPLES.append([
47
+ subset,
48
+ example['input'],
49
+ example['targets']['choices'][1],
50
+ example['targets']['choices'][0],
51
+ "Response 2" if example['targets']['labels'][0] == 1 else "Response 1",
52
+ ])
53
+ def get_hhh_examples(subset, instruction, response1, response2, dummy_text):
54
+ return instruction, response1, response2
55
 
56
+ MT_BENCH_HUMAN_JUDGE_EXAMPLES = []
57
+ dataset = load_dataset("lmsys/mt_bench_human_judgments")
58
+ for example in dataset['human']:
59
+ if example['turn'] != 1:
60
+ continue
61
+ MT_BENCH_HUMAN_JUDGE_EXAMPLES.append([
62
+ example['model_a'],
63
+ example['model_b'],
64
+ str(example['conversation_a']),
65
+ str(example['conversation_b']),
66
+ "Model A" if example['winner'] == 'model_a' else "Model B",
67
+ ])
68
+ def get_mt_bench_human_judge_examples(model_a, model_b, conversation_a, conversation_b, dummy_text):
69
+ chat_history_a = []
70
+ chat_history_b = []
71
+ conversation_a = eval(conversation_a)
72
+ conversation_b = eval(conversation_b)
73
+ for i in range(0, len(conversation_a), 2):
74
+ chat_history_a.append((conversation_a[i]['content'], conversation_a[i+1]['content']))
75
+ assert conversation_a[i]['role'] == 'user' and conversation_a[i+1]['role'] == 'assistant'
76
+ for i in range(0, len(conversation_b), 2):
77
+ chat_history_b.append((conversation_b[i]['content'], conversation_b[i+1]['content']))
78
+ assert conversation_b[i]['role'] == 'user' and conversation_b[i+1]['role'] == 'assistant'
79
+ return chat_history_a, chat_history_b
80
+
81
+
82
+ blender = llm_blender.Blender()
83
+ blender.loadranker("llm-blender/PairRM")
84
+ blender.loadfuser("llm-blender/gen_fuser_3b")
85
 
86
  def update_base_llms_num(k, llm_outputs):
87
  k = int(k)
88
+ return [gr.Dropdown(choices=[f"LLM-{i+1}" for i in range(k)],
89
  value=f"LLM-1" if k >= 1 else "", visible=True),
90
  {f"LLM-{i+1}": llm_outputs.get(f"LLM-{i+1}", "") for i in range(k)}]
91
 
92
 
93
  def display_llm_output(llm_outputs, selected_base_llm_name):
94
+ return gr.Textbox(value=llm_outputs.get(selected_base_llm_name, ""),
95
  label=selected_base_llm_name + " (Click Save to save current content)",
96
  placeholder=f"Enter {selected_base_llm_name} output here", show_label=True)
97
 
98
  def save_llm_output(selected_base_llm_name, selected_base_llm_output, llm_outputs):
99
+ llm_outputs({selected_base_llm_name: selected_base_llm_output})
100
  return llm_outputs
101
 
102
  def get_preprocess_examples(inst, input):
 
159
 
160
 
161
  with gr.Blocks(theme='ParityError/Anime') as demo:
162
+
163
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
+ with gr.Tab("LLM-Blender"):
166
+ # llm-blender interface
167
+ with gr.Row():
168
+ gr.Markdown(descriptions.LLM_BLENDER_OVERALL_DESC)
169
+ gr.Image("https://github.com/yuchenlin/LLM-Blender/blob/main/docs/llm_blender.png?raw=true", height=300)
170
+ gr.Markdown("## Input and Base LLMs")
171
+ with gr.Row():
172
+ with gr.Column():
173
+ inst_textbox = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
174
+ input_textbox = gr.Textbox(lines=4, label="Input Context", placeholder="Enter input context here", show_label=True)
175
+ with gr.Column():
176
+ saved_llm_outputs = gr.State(value={})
177
+ with gr.Group():
178
+ selected_base_llm_name_dropdown = gr.Dropdown(label="Base LLM",
179
+ choices=[f"LLM-{i+1}" for i in range(MIN_BASE_LLM_NUM)], value="LLM-1", show_label=True)
180
+ selected_base_llm_output = gr.Textbox(lines=4, label="LLM-1 (Click Save to save current content)",
181
+ placeholder="Enter LLM-1 output here", show_label=True)
182
+ with gr.Row():
183
+ base_llm_outputs_save_button = gr.Button('Save', variant='primary')
184
+
185
+ base_llm_outputs_clear_single_button = gr.Button('Clear Single', variant='primary')
186
+
187
+ base_llm_outputs_clear_all_button = gr.Button('Clear All', variant='primary')
188
+ base_llms_num = gr.Slider(
189
+ label='Number of base llms',
190
+ minimum=MIN_BASE_LLM_NUM,
191
+ maximum=MAX_BASE_LLM_NUM,
192
+ step=1,
193
+ value=MIN_BASE_LLM_NUM,
194
+ )
195
+
196
+ blender_state = gr.State(value={})
197
+ saved_rank_outputs = gr.State(value=[])
198
+ saved_fuse_outputs = gr.State(value=[])
199
+ gr.Markdown("## Blender Outputs")
200
+ with gr.Group():
201
+ rank_outputs = gr.Textbox(lines=1, label="Ranking outputs", placeholder="Ranking outputs", show_label=True)
202
+ fuser_outputs = gr.Textbox(lines=4, label="Fusing outputs", placeholder="Fusing outputs", show_label=True)
203
+ with gr.Row():
204
+ rank_button = gr.Button('Rank LLM Outputs', variant='primary')
205
+ fuse_button = gr.Button('Fuse Top-K ranked outputs', variant='primary')
206
+ clear_button = gr.Button('Clear Blender Outputs', variant='primary')
207
+ blender_config = gr.State(value={
208
+ "source_max_length": DEFAULT_SOURCE_MAX_LENGTH,
209
+ "candidate_max_length": DEFAULT_CANDIDATE_MAX_LENGTH,
210
+ "top_k_for_fuser": 3,
211
+ "max_new_tokens": DEFAULT_FUSER_MAX_NEW_TOKENS,
212
+ "temperature": 0.7,
213
+ "top_p": 1.0,
214
+ })
215
+
216
+ with gr.Accordion(label='Advanced options', open=False):
217
+ source_max_length = gr.Slider(
218
+ label='Max length of Instruction + Input',
219
+ minimum=1,
220
+ maximum=SOURCE_MAX_LENGTH,
221
+ step=1,
222
+ value=DEFAULT_SOURCE_MAX_LENGTH,
223
+ )
224
+ candidate_max_length = gr.Slider(
225
+ label='Max length of LLM-Output Candidate',
226
+ minimum=1,
227
+ maximum=CANDIDATE_MAX_LENGTH,
228
+ step=1,
229
+ value=DEFAULT_CANDIDATE_MAX_LENGTH,
230
+ )
231
+ top_k_for_fuser = gr.Slider(
232
+ label='Top-k ranked candidates to fuse',
233
+ minimum=1,
234
+ maximum=3,
235
+ step=1,
236
+ value=3,
237
+ )
238
+ max_new_tokens = gr.Slider(
239
+ label='Max new tokens fuser can generate',
240
+ minimum=1,
241
+ maximum=FUSER_MAX_NEW_TOKENS,
242
+ step=1,
243
+ value=DEFAULT_FUSER_MAX_NEW_TOKENS,
244
+ )
245
+ temperature = gr.Slider(
246
+ label='Temperature of fuser generation',
247
+ minimum=0.1,
248
+ maximum=2.0,
249
+ step=0.1,
250
+ value=0.7,
251
+ )
252
+ top_p = gr.Slider(
253
+ label='Top-p of fuser generation',
254
+ minimum=0.05,
255
+ maximum=1.0,
256
+ step=0.05,
257
+ value=1.0,
258
+ )
259
+
260
+ examples_dummy_textbox = gr.Textbox(lines=1, label="", placeholder="", show_label=False, visible=False)
261
+ batch_examples = gr.Examples(
262
+ examples=EXAMPLES,
263
+ fn=get_preprocess_examples,
264
+ cache_examples=True,
265
+ examples_per_page=5,
266
+ inputs=[inst_textbox, input_textbox],
267
+ outputs=[inst_textbox, input_textbox, base_llms_num, examples_dummy_textbox],
268
+ )
269
+
270
+ base_llms_num.change(
271
+ fn=update_base_llms_num,
272
+ inputs=[base_llms_num, saved_llm_outputs],
273
+ outputs=[selected_base_llm_name_dropdown, saved_llm_outputs],
274
+ )
275
 
276
+ examples_dummy_textbox.change(
277
+ fn=update_base_llm_dropdown_along_examples,
278
+ inputs=[examples_dummy_textbox],
279
+ outputs=[saved_llm_outputs, rank_outputs, fuser_outputs],
280
+ ).then(
281
+ fn=display_llm_output,
282
+ inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
283
+ outputs=selected_base_llm_output,
284
  )
285
+
286
+ selected_base_llm_name_dropdown.change(
287
+ fn=display_llm_output,
288
+ inputs=[saved_llm_outputs, selected_base_llm_name_dropdown],
289
+ outputs=selected_base_llm_output,
 
290
  )
291
+
292
+ base_llm_outputs_save_button.click(
293
+ fn=save_llm_output,
294
+ inputs=[selected_base_llm_name_dropdown, selected_base_llm_output, saved_llm_outputs],
295
+ outputs=saved_llm_outputs,
 
296
  )
297
+ base_llm_outputs_clear_all_button.click(
298
+ fn=lambda: [{}, ""],
299
+ inputs=[],
300
+ outputs=[saved_llm_outputs, selected_base_llm_output],
 
 
301
  )
302
+ base_llm_outputs_clear_single_button.click(
303
+ fn=lambda: "",
304
+ inputs=[],
305
+ outputs=selected_base_llm_output,
306
+ )
307
+
308
+
309
+ rank_button.click(
310
+ fn=check_save_ranker_inputs,
311
+ inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
312
+ outputs=blender_state,
313
+ ).success(
314
+ fn=llms_rank,
315
+ inputs=[inst_textbox, input_textbox, saved_llm_outputs, blender_config],
316
+ outputs=[saved_rank_outputs, rank_outputs],
 
 
 
 
 
317
  )
 
 
 
 
 
 
 
 
 
 
318
 
319
+ fuse_button.click(
320
+ fn=check_fuser_inputs,
321
+ inputs=[blender_state, blender_config, saved_rank_outputs],
322
+ outputs=[],
323
+ ).success(
324
+ fn=llms_fuse,
325
+ inputs=[blender_state, blender_config, saved_rank_outputs],
326
+ outputs=[saved_fuse_outputs, fuser_outputs],
327
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
 
329
+ clear_button.click(
330
+ fn=lambda: ["", "", {}, []],
331
+ inputs=[],
332
+ outputs=[rank_outputs, fuser_outputs, blender_state, saved_rank_outputs],
333
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
 
335
+ # update blender config
336
+ source_max_length.change(
337
+ fn=lambda x, y: y.update({"source_max_length": x}) or y,
338
+ inputs=[source_max_length, blender_config],
339
+ outputs=blender_config,
340
+ )
341
+ candidate_max_length.change(
342
+ fn=lambda x, y: y.update({"candidate_max_length": x}) or y,
343
+ inputs=[candidate_max_length, blender_config],
344
+ outputs=blender_config,
345
+ )
346
+ top_k_for_fuser.change(
347
+ fn=lambda x, y: y.update({"top_k_for_fuser": x}) or y,
348
+ inputs=[top_k_for_fuser, blender_config],
349
+ outputs=blender_config,
350
+ )
351
+ max_new_tokens.change(
352
+ fn=lambda x, y: y.update({"max_new_tokens": x}) or y,
353
+ inputs=[max_new_tokens, blender_config],
354
+ outputs=blender_config,
355
+ )
356
+ temperature.change(
357
+ fn=lambda x, y: y.update({"temperature": x}) or y,
358
+ inputs=[temperature, blender_config],
359
+ outputs=blender_config,
360
+ )
361
+ top_p.change(
362
+ fn=lambda x, y: y.update({"top_p": x}) or y,
363
+ inputs=[top_p, blender_config],
364
+ outputs=blender_config,
365
+ )
366
 
367
 
368
+ with gr.Tab("PairRM"):
369
+ # PairRM interface
370
+ with gr.Row():
371
+ gr.Markdown(descriptions.PairRM_OVERALL_DESC)
372
+ gr.Image("https://yuchenlin.xyz/LLM-Blender/pairranker.png")
373
+
374
+ with gr.Tab("Compare two responses"):
375
+ instruction = gr.Textbox(lines=1, label="Instruction", placeholder="Enter instruction here", show_label=True)
376
+ with gr.Row():
377
+ response1 = gr.Textbox(lines=4, label="Response 1", placeholder="Enter response 1 here", show_label=True)
378
+ response2 = gr.Textbox(lines=4, label="Response 2", placeholder="Enter response 2 here", show_label=True)
379
+ with gr.Row():
380
+ compare_button = gr.Button('Compare', variant='primary')
381
+ clear_button = gr.Button('Clear', variant='primary')
382
+ with gr.Row():
383
+ compare_result = gr.Textbox(lines=1, label="Compare Result", placeholder="", show_label=True)
384
+ compare_result_prob = gr.Textbox(lines=1, label="PairRM Confidence", placeholder="", show_label=True)
385
+
386
+ def compare_fn(inst, response1, response2):
387
+ if not inst:
388
+ raise gr.Error("Please enter instruction")
389
+ if not response1 or not response2:
390
+ raise gr.Error("Please enter response 1 and response 2")
391
+ comparison_results = blender.compare([inst], [response1], [response2], return_logits=True)
392
+ logit = comparison_results[0]
393
+ if logit > 0:
394
+ result = "Response 1 is better than Response 2"
395
+ prob = f"Confidence: {round(logit, 2)}"
396
+ elif logit < 0:
397
+ result = "Response 2 is better than Response 1"
398
+ prob = f"Cofidence: {round(abs(logit), 2)}"
399
+ else:
400
+ result = "Response 1 and Response 2 are equally good"
401
+ prob = f"No confidence for tie"
402
+
403
+ return [result, prob]
404
+ compare_button.click(
405
+ fn=compare_fn,
406
+ inputs=[instruction, response1, response2],
407
+ outputs=[compare_result, compare_result_prob],
408
+ )
409
+ clear_button.click(
410
+ fn=lambda: ["", ""],
411
+ inputs=[],
412
+ outputs=[compare_result, compare_result_prob],
413
+ )
414
+
415
+ hhh_dummy_textbox1 = gr.Textbox(lines=1, label="subset", placeholder="", show_label=False, visible=False)
416
+ hhh_dummy_textbox2 = gr.Textbox(lines=1, label="Better Response", placeholder="", show_label=False, visible=False)
417
+ gr.Markdown("## Examples from [HuggingFaceH4/hhh_alignment](https://huggingface.co/datasets/HuggingFaceH4/hhh_alignment)")
418
+ gr.Examples(
419
+ HHH_EXAMPLES,
420
+ fn=get_hhh_examples,
421
+ cache_examples=True,
422
+ examples_per_page=5,
423
+ inputs=[hhh_dummy_textbox1, instruction, response1, response2, hhh_dummy_textbox2],
424
+ outputs=[instruction, response1, response2],
425
+ )
426
+
427
+
428
+ with gr.Tab("Compare assistant's response in two multi-turn conversations"):
429
+
430
+ gr.Markdown("NOTE: Comparison of two conversations is based on that the user query in each turn is the same of two conversations.")
431
+ def append_message(message, chat_history):
432
+ if not message:
433
+ return "", chat_history
434
+ if len(chat_history) == 0:
435
+ chat_history.append((message, "(Please enter your bot response)"))
436
+ else:
437
+ if chat_history[-1][1] == "(Please enter your bot response)":
438
+ chat_history[-1] = (chat_history[-1][0], message)
439
+ else:
440
+ chat_history.append((message, "(Please enter your bot response)"))
441
+ return "", chat_history
442
+ with gr.Row():
443
+ with gr.Column():
444
+ gr.Markdown("### Conversation A")
445
+ chatbot1 = gr.Chatbot()
446
+ msg1 = gr.Textbox(lines=1, label="Enter Chat history for Conversation A", placeholder="Enter your message here", show_label=True)
447
+ clear1 = gr.ClearButton([msg1, chatbot1])
448
+ msg1.submit(append_message, [msg1, chatbot1], [msg1, chatbot1])
449
+ with gr.Column():
450
+ gr.Markdown("### Conversation B")
451
+ chatbot2 = gr.Chatbot()
452
+ msg2 = gr.Textbox(lines=1, label="Enter Chat history for Conversation B", placeholder="Enter your message here", show_label=True)
453
+ clear2 = gr.ClearButton([msg2, chatbot2])
454
+ msg2.submit(append_message, [msg2, chatbot2], [msg2, chatbot2])
455
+ with gr.Row():
456
+ compare_button = gr.Button('Compare', variant='primary')
457
+ with gr.Row():
458
+ compare_result = gr.Textbox(lines=1, label="Compare Result", placeholder="", show_label=True)
459
+ compare_result_prob = gr.Textbox(lines=1, label="PairRM Confidence", placeholder="", show_label=True)
460
+
461
+ def compare_conv_fn(chat_history1, chat_history2):
462
+ if len(chat_history1) == 0 or len(chat_history2) == 0:
463
+ raise gr.Error("Please enter chat history for both conversations")
464
+ assert chat_history1[-1][1] != "(Please enter your bot response)" \
465
+ and chat_history2[-1][1] != "(Please enter your bot response)", \
466
+ "Please complete chat history for both conversations"
467
+ chat1_messages = []
468
+ for item in chat_history1:
469
+ chat1_messages.append({
470
+ "role": "USER",
471
+ "content": item[0],
472
+ })
473
+ chat1_messages.append({
474
+ "role": "ASSISTANT",
475
+ "content": item[1],
476
+ })
477
+ chat2_messages = []
478
+ for item in chat_history2:
479
+ chat2_messages.append({
480
+ "role": "USER",
481
+ "content": item[0],
482
+ })
483
+ chat2_messages.append({
484
+ "role": "ASSISTANT",
485
+ "content": item[1],
486
+ })
487
+
488
+ comparison_results = blender.compare_conversations([chat1_messages], [chat2_messages], return_logits=True)
489
+ logit = comparison_results[0]
490
+ if logit > 0:
491
+ result = "Assistant's response in Conversation A is better than Conversation B"
492
+ prob = f"Confidence: {round(logit, 2)}"
493
+ elif logit < 0:
494
+ result = "Assistant's response in Conversation B is better than Conversation A"
495
+ prob = f"Cofidence: {round(abs(logit), 2)}"
496
+ else:
497
+ result = "Assistant's response in Conversation A and Conversation B are equally good"
498
+ prob = f"No confidence for tie"
499
+
500
+ return [result, prob]
501
 
502
+ compare_button.click(
503
+ fn=compare_conv_fn,
504
+ inputs=[chatbot1, chatbot2],
505
+ outputs=[compare_result, compare_result_prob],
506
+ )
507
+
508
+ model_a_dummy_textbox = gr.Textbox(lines=1, label="Model A", placeholder="", show_label=False, visible=False)
509
+ model_b_dummy_textbox = gr.Textbox(lines=1, label="Model B", placeholder="", show_label=False, visible=False)
510
+ winner_dummy_textbox = gr.Textbox(lines=1, label="Better Model in conversation", placeholder="", show_label=False, visible=False)
511
+ chatbot1_dummy_textbox = gr.Textbox(lines=1, label="Conversation A", placeholder="", show_label=False, visible=False)
512
+ chatbot2_dummy_textbox = gr.Textbox(lines=1, label="Conversation B", placeholder="", show_label=False, visible=False)
513
+ gr.Markdown("## Examples from [lmsys/mt_bench_human_judgments](https://huggingface.co/datasets/lmsys/mt_bench_human_judgments)")
514
+ gr.Examples(
515
+ MT_BENCH_HUMAN_JUDGE_EXAMPLES,
516
+ fn=get_mt_bench_human_judge_examples,
517
+ cache_examples=True,
518
+ examples_per_page=5,
519
+ inputs=[model_a_dummy_textbox, model_b_dummy_textbox, chatbot1_dummy_textbox, chatbot2_dummy_textbox, winner_dummy_textbox],
520
+ outputs=[chatbot1, chatbot2],
521
+ )
522
+
523
+ gr.Markdown(descriptions.CITATION)
524
  demo.queue(max_size=20).launch()
descriptions.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LLM_BLENDER_OVERALL_DESC = """
2
+ LLM-Blender is an innovative ensembling framework to attain consistently superior performance by leveraging the diverse strengths of multiple open-source large language models (LLMs).
3
+ LLM-Blender cut the weaknesses through ranking and integrate the strengths through fusing generation to enhance the capability of LLMs.
4
+
5
+ Our framework consists of two complementary modules: **PairRanker** and **GenFuser**,
6
+ addressing the observation that optimal LLMs for different examples can significantly vary.
7
+ **PairRanker** employs a specialized pairwise comparison method to distinguish subtle differences between candidate outputs.
8
+ **GenFuser** aims to merge the top-ranked candidates from the aggregation of PairRanker's pairwise
9
+ comparisons into an improved output by capitalizing on their strengths and mitigating their weaknesses.
10
+
11
+ | [Paper](https://arxiv.org/abs/2306.02561)
12
+ | [Code](https://github.com/yuchenlin/LLM-Blender)
13
+ | [Dataset](https://huggingface.co/datasets/llm-blender/mix-instruct)
14
+ | [Models](https://huggingface.co/llm-blender)
15
+ | [Tweet](https://twitter.com/billyuchenlin/status/1668666357058277377)
16
+
17
+ Try LLM-Blender now! 👇
18
+ """
19
+
20
+ PairRM_OVERALL_DESC = """## 🤗 [PairRM](https://huggingface.co/llm-blender/PairRM)
21
+
22
+ PairRM is a reward model based on PairRanker architecture that has been trained on various high-quality and
23
+ large-scale dataset with human preference annotations and exhibits great correlation with human preferences.
24
+
25
+ **While PairRM is a extremely small model (0.4B), our tests on various human alignment benchmarks show approaching performance of GPT-4.** (See [PairRM](https://huggingface.co/llm-blender/PairRM) for more detail results)
26
+
27
+ PairRM could be easily applied to 3 scenarios:
28
+ 1. [Directly compare two responses or two conversations](https://huggingface.co/llm-blender/PairRM#use-case-1-compare-responses-quality-evaluator)
29
+ 2. [Do best-of-n sampling to enhancing the ability of LLM](https://huggingface.co/llm-blender/PairRM#use-case-2-best-of-n-sampling-decoding-enhancing)
30
+ 3. [RLHF alignment](https://huggingface.co/llm-blender/PairRM#use-case-3-rlhf)
31
+
32
+ This demo allows user to interact with PairRM in the first scenario. Feel free to compare the quality of two responses or two conversations by inputting them in the text box below.
33
+
34
+ Try PairRM now! 👇
35
+ """
36
+
37
+ CITATION = """
38
+
39
+ ## Citation
40
+ ```bibtex
41
+ @inproceedings{llm-blender-2023,
42
+ title = "LLM-Blender: Ensembling Large Language Models with Pairwise Comparison and Generative Fusion",
43
+ author = "Jiang, Dongfu and Ren, Xiang and Lin, Bill Yuchen",
44
+ booktitle = "Proceedings of the 61th Annual Meeting of the Association for Computational Linguistics (ACL 2023)",
45
+ year = "2023"
46
+ }
47
+
48
+ ```
49
+ """