Text Generation
Transformers
Safetensors
qwen2
reranker
conversational
text-generation-inference
ptrdvn committed · Commit d765fb0 · verified · 1 Parent(s): 3ae0c6c

Update README.md

Files changed (1)
  1. README.md +367 -52
README.md CHANGED
@@ -1,73 +1,388 @@
  ---
  library_name: transformers
- license: other
- base_model: Qwen/Qwen2.5-0.5B-Instruct
  tags:
- - llama-factory
- - full
- - generated_from_trainer
- model-index:
- - name: reranker_continuous_filt_max7_rev_train
-   results: []
  ---
 
- <!-- This model card has been generated automatically according to the information the Trainer had access to. You
- should probably proofread and complete it, then remove this comment. -->
 
- # reranker_continuous_filt_max7_rev_train
 
- This model is a fine-tuned version of [Qwen/Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) on the reranker_continuous_filt_max7_rev_train dataset.
- It achieves the following results on the evaluation set:
- - Loss: 0.3944
 
- ## Model description
 
- More information needed
 
- ## Intended uses & limitations
 
- More information needed
 
- ## Training and evaluation data
 
- More information needed
 
- ## Training procedure
 
- ### Training hyperparameters
 
- The following hyperparameters were used during training:
- - learning_rate: 1e-05
- - train_batch_size: 1
- - eval_batch_size: 1
- - seed: 42
- - distributed_type: multi-GPU
- - num_devices: 8
- - total_train_batch_size: 8
- - total_eval_batch_size: 8
- - optimizer: Use adamw_torch with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
- - lr_scheduler_type: cosine
- - lr_scheduler_warmup_ratio: 0.01
- - num_epochs: 1.0
 
- ### Training results
 
- | Training Loss | Epoch  | Step | Validation Loss |
- |:-------------:|:------:|:----:|:---------------:|
- | 0.4901        | 0.1001 | 908  | 0.4728          |
- | 0.4144        | 0.2001 | 1816 | 0.4561          |
- | 0.3597        | 0.3002 | 2724 | 0.4485          |
- | 0.4041        | 0.4002 | 3632 | 0.4265          |
- | 0.3714        | 0.5003 | 4540 | 0.4173          |
- | 0.4307        | 0.6003 | 5448 | 0.4078          |
- | 0.34          | 0.7004 | 6356 | 0.4024          |
- | 0.4315        | 0.8004 | 7264 | 0.3974          |
- | 0.4267        | 0.9005 | 8172 | 0.3952          |
 
- ### Framework versions
 
- - Transformers 4.46.1
- - Pytorch 2.5.1+cu124
- - Datasets 3.1.0
- - Tokenizers 0.20.3
  ---
  library_name: transformers
+ license: apache-2.0
+ language:
+ - en
+ - zh
+ - es
+ - de
+ - ar
+ - ru
+ - ja
+ - ko
+ - hi
+ - sk
+ - vi
+ - tr
+ - fi
+ - id
+ - fa
+ - 'no'
+ - th
+ - sv
+ - pt
+ - da
+ - bn
+ - te
+ - ro
+ - it
+ - fr
+ - nl
+ - sw
+ - pl
+ - hu
+ - cs
+ - el
+ - uk
+ - mr
+ - ta
+ - tl
+ - bg
+ - lt
+ - ur
+ - he
+ - gu
+ - kn
+ - am
+ - kk
+ - hr
+ - uz
+ - jv
+ - ca
+ - az
+ - ms
+ - sr
+ - sl
+ - yo
+ - lv
+ - is
+ - ha
+ - ka
+ - et
+ - bs
+ - hy
+ - ml
+ - pa
+ - mt
+ - km
+ - sq
+ - or
+ - as
+ - my
+ - mn
+ - af
+ - be
+ - ga
+ - mk
+ - cy
+ - gl
+ - ceb
+ - la
+ - yi
+ - lb
+ - tg
+ - gd
+ - ne
+ - ps
+ - eu
+ - ky
+ - ku
+ - si
+ - ht
+ - eo
+ - lo
+ - fy
+ - sd
+ - mg
+ - so
+ - ckb
+ - su
+ - nn
+ datasets:
+ - lightblue/reranker_continuous_filt_max7_train
+ base_model:
+ - Qwen/Qwen2.5-0.5B-Instruct
+ pipeline_tag: text-generation
  tags:
+ - reranker
+ widget:
+ - text: "<<<Context>>>\nLB-Reranker has been trained on more than 95 languages.\n\n<<<Query>>>\nHow many languages has LB-Reranker been trained on?"
+   example_title: Positive example (7/7)
+ - text: "<<<Context>>>\nAA-Reranker is applicable to a broad range of use cases.\n\n<<<Query>>>\nHow many languages has LB-Reranker been trained on?"
+   example_title: Negative example (2/7)
+
  ---
 
+ # LB Reranker v1.0
+
+ <div style="width: 100%; height: 160px;
+ display: flex; align-items: center;
+ justify-content: center;
+ border: 8px solid black;
+ font-size: 120px; font-weight: bold;
+ text-align: center;
+ color: #438db8;
+ font-family: 'Helvetica Neue', sans-serif;">
+ LBR
+ </div>
+
+
+ This is a reversed version of the original LB Reranker, [lightblue/lb-reranker-0.5B-v1.0](https://huggingface.co/lightblue/lb-reranker-0.5B-v1.0).
+ With this version, you input the text first and then the query into the reranker, allowing the text (rather than the query) to be cached.
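+
+ Because the (often long) text now forms a fixed prefix of the prompt, serving engines that support automatic prefix caching can reuse the text's KV cache across many queries. As a minimal sketch of this idea (assuming vLLM's `enable_prefix_caching` option; this flag is not part of the scripts later in this card):
+
+ ```python
+ from vllm import LLM, SamplingParams
+
+ # With the reversed input order, every request that scores a query against the
+ # same document shares a long common token prefix (system message + text),
+ # which vLLM's automatic prefix caching can reuse across requests.
+ llm = LLM("lightblue/lb-reranker-0.5B-v1.0-rev", enable_prefix_caching=True)
+
+ system_message = "Given a query and a piece of text, output a score of 1-7 based on how related the query is to the text. 1 means least related and 7 is most related."
+ text = "An apple is a round, edible fruit produced by an apple tree."
+ queries = [
+     "What is the scientific name of apples?",
+     "What is the square root of 999?",
+ ]
+
+ chats = [
+     [
+         {"role": "system", "content": system_message},
+         {"role": "user", "content": f"<<<Context>>>\n{text}\n\n<<<Query>>>\n{q}"},
+     ]
+     for q in queries
+ ]
+ responses = llm.chat(chats, SamplingParams(temperature=0.0, max_tokens=1))
+ print([r.outputs[0].text for r in responses])  # e.g. ["7", "1"]
+ ```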
+
+ The LB Reranker has been trained to determine the relatedness of a given query to a piece of text, allowing it to be used as a ranker or reranker in various retrieval-based tasks.
+
+ This model is fine-tuned from a [Qwen/Qwen2.5-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) checkpoint and was trained for roughly 5.5 hours on an 8 x L20 instance ([ecs.gn8is-8x.32xlarge](https://www.alibabacloud.com/help/en/ecs/user-guide/gpu-accelerated-compute-optimized-and-vgpu-accelerated-instance-families-1)) on [Alibaba Cloud](https://www.alibabacloud.com/).
+
+ The training data for this model can be found at [lightblue/reranker_continuous_filt_max7_train](https://huggingface.co/datasets/lightblue/reranker_continuous_filt_max7_train), and the code for generating this data as well as for training the model can be found on [our Github repo](https://github.com/lightblue-tech/lb-reranker).
+
+ Trained on data in over 95 languages, this model is applicable to a broad range of use cases.
+
+ This model has three main benefits over comparable rerankers.
+ 1. It has shown slightly higher performance on evaluation benchmarks.
+ 2. It has been trained on more languages than any previous model.
+ 3. It is a simple causal LM trained to output a string from "1" to "7".
+
+ This last point means that this model can be used natively with many widely available inference packages, including vLLM and LMDeploy.
+ This in turn allows our reranker to benefit from improvements to inference as and when these packages release them.
+
+ Update: We have also found that this model works pretty well as a code snippet reranker too (P@1 of 96%)! See our [Colab](https://colab.research.google.com/drive/1ABL1xaarekLIlVJKbniYhXgYu6ZNwfBm?usp=sharing) for more details.
+
+ # How to use
+
+ The model was trained to expect an input such as:
+
+ ```
+ <<<Context>>>
+ {your_context_here}
+
+ <<<Query>>>
+ {your_query_here}
+ ```
+
+ It then outputs a string containing a number from 1 to 7.
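+
+ For example, a filled-in input using one of the query-text pairs from the scripts below would look like:
+
+ ```
+ <<<Context>>>
+ An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica).
+
+ <<<Query>>>
+ What is the scientific name of apples?
+ ```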
+
+ In order to produce a continuous score that can be used for reranking query-context pairs (i.e. a score with few ties), we calculate the expected value of the predicted scores.
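+
+ Concretely, we read the probability of each of the score tokens "1" to "7" from the model's logprobs and take the probability-weighted average. A minimal sketch, with made-up probabilities for illustration:
+
+ ```python
+ import numpy as np
+
+ # Hypothetical probabilities of the score tokens "1".."7" for one query-text pair
+ probs = np.array([0.01, 0.01, 0.02, 0.05, 0.11, 0.30, 0.50])
+
+ scores = np.arange(1, 8)               # the seven possible scores
+ expected_val = (probs * scores).sum()  # probability-weighted average
+ print(expected_val)                    # 6.14 -- a continuous score in [1, 7]
+ ```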
+
+ We include scripts to do this in vLLM, LMDeploy, and OpenAI (hosted for free on Huggingface):
+
+
+ <ul>
+ <li><b>vLLM</b>
+
+ Install [vLLM](https://github.com/vllm-project/vllm/) using `pip install vllm`.
+
+ <details open>
+ <summary>Show vLLM code</summary>
+
+ ```python
+ from vllm import LLM, SamplingParams
+ import numpy as np
+
+ def make_reranker_input(t, q):
+     # Text (context) first, then query, matching this reversed model's format
+     return f"<<<Context>>>\n{t}\n\n<<<Query>>>\n{q}"
+
+ def make_reranker_inference_conversation(context, question):
+     system_message = "Given a query and a piece of text, output a score of 1-7 based on how related the query is to the text. 1 means least related and 7 is most related."
+
+     return [
+         {"role": "system", "content": system_message},
+         {"role": "user", "content": make_reranker_input(context, question)},
+     ]
+
+ def get_prob(logprob_dict, tok_id):
+     # Probability of a given score token, or 0 if it is not in the returned logprobs
+     return np.exp(logprob_dict[tok_id].logprob) if tok_id in logprob_dict.keys() else 0
+
+ llm = LLM("lightblue/lb-reranker-0.5B-v1.0-rev")
+ sampling_params = SamplingParams(temperature=0.0, logprobs=14, max_tokens=1)
+ tok = llm.llm_engine.tokenizer.tokenizer
+ idx_tokens = [tok.encode(str(i))[0] for i in range(1, 8)]  # token ids for "1".."7"
+
+ query_texts = [
+     ("What is the scientific name of apples?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
+     ("What is the Chinese word for 'apple'?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
+     ("What is the square root of 999?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
+ ]
+
+ chats = [make_reranker_inference_conversation(c, q) for q, c in query_texts]
+ responses = llm.chat(chats, sampling_params)
+ probs = np.array([[get_prob(r.outputs[0].logprobs[0], y) for y in idx_tokens] for r in responses])
+
+ N = probs.shape[1]
+ M = probs.shape[0]
+ idxs = np.tile(np.arange(1, N + 1), M).reshape(M, N)
+
+ expected_vals = (probs * idxs).sum(axis=1)  # expected score per query-text pair
+ print(expected_vals)
+ # [6.66570732 1.86686378 1.01102923]
+ ```
+
+ </details></li>
+ <li><b>LMDeploy</b>
+
+ Install [LMDeploy](https://github.com/InternLM/lmdeploy) using `pip install lmdeploy`.
+
+ <details>
+ <summary>Show LMDeploy code</summary>
+
+ ```python
+ # Un-comment this if running in a Jupyter notebook, Colab etc.
+ # import nest_asyncio
+ # nest_asyncio.apply()
+
+ from lmdeploy import GenerationConfig, ChatTemplateConfig, pipeline
+ import numpy as np
+
+ def make_reranker_input(t, q):
+     # Text (context) first, then query, matching this reversed model's format
+     return f"<<<Context>>>\n{t}\n\n<<<Query>>>\n{q}"
+
+ def make_reranker_inference_conversation(context, question):
+     system_message = "Given a query and a piece of text, output a score of 1-7 based on how related the query is to the text. 1 means least related and 7 is most related."
+
+     return [
+         {"role": "system", "content": system_message},
+         {"role": "user", "content": make_reranker_input(context, question)},
+     ]
+
+ def get_prob(logprob_dict, tok_id):
+     # LMDeploy returns raw logprobs, so exponentiate to get probabilities
+     return np.exp(logprob_dict[tok_id]) if tok_id in logprob_dict.keys() else 0
+
+ pipe = pipeline(
+     "lightblue/lb-reranker-0.5B-v1.0-rev",
+     chat_template_config=ChatTemplateConfig(
+         model_name='qwen2d5',
+         capability='chat'
+     )
+ )
+ tok = pipe.tokenizer.model
+ idx_tokens = [tok.encode(str(i))[0] for i in range(1, 8)]  # token ids for "1".."7"
+
+ query_texts = [
+     ("What is the scientific name of apples?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
+     ("What is the Chinese word for 'apple'?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
+     ("What is the square root of 999?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
+ ]
+
+ chats = [make_reranker_inference_conversation(c, q) for q, c in query_texts]
+ responses = pipe(
+     chats,
+     gen_config=GenerationConfig(temperature=1.0, logprobs=14, max_new_tokens=1, do_sample=True)
+ )
+ probs = np.array([[get_prob(r.logprobs[0], y) for y in idx_tokens] for r in responses])
+
+ N = probs.shape[1]
+ M = probs.shape[0]
+ idxs = np.tile(np.arange(1, N + 1), M).reshape(M, N)
+
+ expected_vals = (probs * idxs).sum(axis=1)  # expected score per query-text pair
+ print(expected_vals)
+ # [6.66415229 1.84342025 1.01133205]
+ ```
+
+ </details></li>
+ <li><b>OpenAI (Hosted on Huggingface)</b>
+
+ Install [openai](https://github.com/openai/openai-python) using `pip install openai`.
+
+ <details>
+ <summary>Show OpenAI + Huggingface Inference code</summary>
+
+ ```python
+ from openai import OpenAI
+ import numpy as np
+ from multiprocessing import Pool
+ from tqdm.auto import tqdm
+
+ client = OpenAI(
+     base_url="https://api-inference.huggingface.co/v1/",
+     api_key="hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"  # Change this to an access token from https://huggingface.co/settings/tokens
+ )
+
+ def make_reranker_input(t, q):
+     # Text (context) first, then query, matching this reversed model's format
+     return f"<<<Context>>>\n{t}\n\n<<<Query>>>\n{q}"
+
+ def make_reranker_inference_conversation(context, question):
+     system_message = "Given a query and a piece of text, output a score of 1-7 based on how related the query is to the text. 1 means least related and 7 is most related."
+
+     return [
+         {"role": "system", "content": system_message},
+         {"role": "user", "content": make_reranker_input(context, question)},
+     ]
+
+ def get_reranker_score(context_question_tuple):
+     question, context = context_question_tuple
+
+     messages = make_reranker_inference_conversation(context, question)
+
+     completion = client.chat.completions.create(
+         model="lightblue/lb-reranker-0.5B-v1.0-rev",
+         messages=messages,
+         max_tokens=1,
+         temperature=0.0,
+         logprobs=True,
+         top_logprobs=5,  # Max the OpenAI API allows (top_logprobs must be between 0 and 5). If this limit is ever raised, increase it to at least 7.
+     )
+
+     logprobs = completion.choices[0].logprobs.content[0].top_logprobs
+
+     # Expected score over the returned score tokens. Note: only the top 5 tokens
+     # are available here, so this slightly approximates the full 7-way expectation.
+     calculated_score = sum([int(x.token) * np.exp(x.logprob) for x in logprobs])
+
+     return calculated_score
+
+ query_texts = [
+     ("What is the scientific name of apples?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
+     ("What is the Chinese word for 'apple'?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
+     ("What is the square root of 999?", "An apple is a round, edible fruit produced by an apple tree (Malus spp., among them the domestic or orchard apple; Malus domestica)."),
+ ]
+
+ with Pool(processes=16) as p:  # Allows for parallel processing
+     expected_vals = list(tqdm(p.imap(get_reranker_score, query_texts), total=len(query_texts)))
+
+ print(expected_vals)
+ # [6.64866580, 1.85144404, 1.010719508]
+ ```
+
+ </details></li>
+ </ul>
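+
+ Whichever backend you use, the expected scores can be used directly for reranking: sort your candidates by score, highest (most related) first. For instance, continuing from the variables in the scripts above:
+
+ ```python
+ # Rank the scored pairs, best (highest expected score) first
+ ranked = sorted(zip(expected_vals, query_texts), key=lambda x: x[0], reverse=True)
+ ```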
 
+ # Evaluation
 
+ We perform an evaluation on 9 datasets from the [BEIR benchmark](https://github.com/beir-cellar/beir) which, to our knowledge, none of the evaluated models have been trained on:
 
+ * Arguana
+ * Dbpedia-entity
+ * Fiqa
+ * NFcorpus
+ * Scidocs
+ * Scifact
+ * Trec-covid-v2
+ * Vihealthqa
+ * Webis-touche2020
 
+ We evaluate on a subset of all queries (the first 250) to save evaluation time.
 
+ We find that our model performs comparably to or better than many state-of-the-art reranker models in our evaluation, without compromising on inference speed.
 
+ We make our evaluation code and results available [on our Github](https://github.com/lightblue-tech/lb-reranker/blob/main/run_bier.ipynb).
 
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64b63f8ad57e02621dc93c8b/xkNzCABFUmU7UmDXUduiz.png)
 
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64b63f8ad57e02621dc93c8b/P-XCA3TGHqDSX8k6c4hCE.png)
 
+ As we can see, this reranker attains higher IR evaluation metrics than the two baseline rerankers we include at every cutoff apart from @1.
 
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/64b63f8ad57e02621dc93c8b/puhhWseBOcIyOEdW4L-B0.png)
 
+ We also show that our model is, on average, faster than the BGE reranker v2.
 
+ # License
 
+ We share this model under an Apache 2.0 license.
 
+ # Developed by
 
+ <a href="https://www.lightblue-tech.com">
+ <img src="https://www.lightblue-tech.com/wp-content/uploads/2023/08/color_%E6%A8%AA%E5%9E%8B-1536x469.png" alt="Lightblue technology logo" width="400"/>
+ </a>
 
+ This model was trained by Peter Devine ([ptrdvn](https://huggingface.co/ptrdvn)) for Lightblue.