---
license: apache-2.0
tags:
- vidore
- reranker
- qwen2_vl
---

# MonoQwen2-VL-2B-LoRA-Reranker

## Model Overview

**MonoQwen2-VL-v0.1** is a LoRA adapter for the Qwen2-VL-2B model, fine-tuned for reranking (i.e., assessing pointwise image-query relevance) with the [MonoT5](https://arxiv.org/pdf/2101.05667) objective. Given an image and a query in the VLM prompt, the model is trained to generate "True" if the image is relevant to the query and "False" otherwise.

At inference time, a relevance score is obtained by comparing the logits of these two tokens. This score can be used to rerank the candidates returned by a first-stage retriever (such as DSE or ColPali) or to filter them with a threshold.

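A minimal sketch of that post-processing step, assuming each first-stage candidate has already been scored with the "True" probability computed as in the usage example. The `rerank` helper, the page IDs, the scores, and the 0.5 threshold are hypothetical and not part of the model or any library API.

```python
from typing import List, Optional, Tuple


def rerank(
    candidates: List[Tuple[str, float]],
    threshold: Optional[float] = None,
) -> List[Tuple[str, float]]:
    """Sort (doc_id, score) pairs by descending reranker score.

    `score` is assumed to be the "True" probability from the reranker.
    If `threshold` is given, candidates scoring below it are dropped.
    """
    kept = [c for c in candidates if threshold is None or c[1] >= threshold]
    # sorted() is stable even with reverse=True, so ties keep first-stage order
    return sorted(kept, key=lambda c: c[1], reverse=True)


# Hypothetical first-stage candidates (e.g. from DSE or ColPali) with reranker scores
first_stage = [("page_a", 0.12), ("page_b", 0.91), ("page_c", 0.55)]
print(rerank(first_stage))                 # [('page_b', 0.91), ('page_c', 0.55), ('page_a', 0.12)]
print(rerank(first_stage, threshold=0.5))  # [('page_b', 0.91), ('page_c', 0.55)]
```
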
This model was trained on the [ColPali train set](https://huggingface.co/datasets/vidore/colpali_train_set), with negatives mined using DSE.

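The card does not describe the training pipeline beyond the dataset and the DSE-mined negatives. Under the MonoT5 objective each (query, page) pair has a "True" or "False" target, so one hypothetical way to assemble such pointwise examples is sketched below. `PointwiseExample`, `build_examples`, the page IDs, and the choice to drop the annotated positive from the mined list are illustrative assumptions, not the authors' code.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class PointwiseExample:
    query: str
    image_id: str  # hypothetical identifier for a page image
    target: str    # "True" for the relevant page, "False" for a mined negative


def build_examples(query: str, positive_id: str, negative_ids: Iterable[str]) -> List[PointwiseExample]:
    """Turn one query, its relevant page, and retriever-mined negatives into
    MonoT5-style pointwise examples (hypothetical helper)."""
    examples = [PointwiseExample(query, positive_id, "True")]
    for neg in negative_ids:
        # Assumption: skip a mined "negative" that is actually the annotated positive
        if neg != positive_id:
            examples.append(PointwiseExample(query, neg, "False"))
    return examples


# Hypothetical DSE top-k for the query, which happens to include the positive
mined = ["p7", "p2", "p9"]
for ex in build_examples("What is ColPali?", "p2", mined):
    print(ex.image_id, ex.target)
# p2 True
# p7 False
# p9 False
```
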
## How to Use the Model

Below is a quick example that scores the relevance of a single image to a user query with this model:

```python
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

# Load processor and model.
# The checkpoint is a LoRA adapter on top of Qwen2-VL-2B; loading it directly with
# from_pretrained relies on the transformers PEFT integration (the `peft` package must be installed).
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "lightonai/MonoQwen2-VL-v0.1",
    device_map="auto",
    # attn_implementation="flash_attention_2",
    # torch_dtype=torch.bfloat16,
)

# Define query and load image
query = "What is ColPali?"
image_path = "your/path/to/image.png"
image = Image.open(image_path)

# Construct the prompt and prepare input
prompt = (
    "Assert the relevance of the previous image document to the following query, "
    "answer True or False. The query is: {query}"
).format(query=query)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }
]

# Apply chat template and tokenize; move inputs to wherever device_map placed the model
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)

# Run inference to obtain the logits at the last position
with torch.no_grad():
    outputs = model(**inputs)
    logits_for_last_token = outputs.logits[:, -1, :]

# Look up the "True"/"False" token ids and softmax over just those two logits
true_token_id = processor.tokenizer.convert_tokens_to_ids("True")
false_token_id = processor.tokenizer.convert_tokens_to_ids("False")
relevance_score = torch.softmax(logits_for_last_token[:, [true_token_id, false_token_id]], dim=-1)

# Extract and display probabilities
true_prob = relevance_score[0, 0].item()
false_prob = relevance_score[0, 1].item()

print(f"True probability: {true_prob:.4f}, False probability: {false_prob:.4f}")
```

This example outputs the probability that the image is relevant ("True") or not relevant ("False") to the query. The two values sum to 1, so the "True" probability can be used directly as the relevance score for reranking.

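Because the softmax in the usage example is taken over only the "True" and "False" logits, the "True" probability equals the logistic sigmoid of the logit difference, p(True) = σ(z_True − z_False). The pure-Python sketch below checks this identity with made-up logit values; `two_way_softmax` and `sigmoid` are illustrative helpers, not part of transformers.

```python
import math


def two_way_softmax(z_true: float, z_false: float) -> float:
    """P("True") from a softmax restricted to the "True" and "False" logits."""
    m = max(z_true, z_false)  # subtract the max for numerical stability
    e_true, e_false = math.exp(z_true - m), math.exp(z_false - m)
    return e_true / (e_true + e_false)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


# Made-up last-position logits standing in for the model's outputs
z_true, z_false = 3.2, 1.1
print(round(two_way_softmax(z_true, z_false), 4), round(sigmoid(z_true - z_false), 4))
```

Since σ is monotonic, ranking candidates by z_True − z_False gives the same order as ranking by the probability.
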
## Performance Metrics

The model was evaluated on the [ViDoRe benchmark](https://huggingface.co/spaces/vidore/vidore-leaderboard) by retrieving the top 10 candidates per query with [MrLight_dse-qwen2-2b-mrl-v1](https://huggingface.co/MrLight/dse-qwen2-2b-mrl-v1) and reranking them with this model. The table below reports NDCG@5 before and after reranking:

| Dataset                                            | NDCG@5 Before Reranking | NDCG@5 After Reranking |
|----------------------------------------------------|-------------------------|------------------------|
| vidore/arxivqa_test_subsampled                     | 85.6                    | 89.0                   |
| vidore/docvqa_test_subsampled                      | 57.1                    | 59.7                   |
| vidore/infovqa_test_subsampled                     | 88.1                    | 93.2                   |
| vidore/tabfquad_test_subsampled                    | 93.1                    | 96.0                   |
| vidore/shiftproject_test                           | 82.0                    | 93.0                   |
| vidore/syntheticDocQA_artificial_intelligence_test | 97.5                    | 100.0                  |
| vidore/syntheticDocQA_energy_test                  | 92.9                    | 97.7                   |
| vidore/syntheticDocQA_government_reports_test      | 96.0                    | 98.0                   |
| vidore/syntheticDocQA_healthcare_industry_test     | 96.4                    | 99.3                   |
| vidore/tatdqa_test                                 | 69.4                    | 79.0                   |
| **Mean**                                           | 85.8                    | **90.5**               |

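NDCG@5 rewards placing relevant pages near the top of the first five results, which is what reranking the DSE top 10 can improve. The sketch below assumes the standard formulation with a log2 rank discount and binary relevance labels; the card does not state the exact evaluation code. The helper computes the ideal ranking from the same candidate list, which matches collection-level NDCG only when every relevant page was retrieved. The relevance lists are made up.

```python
import math
from typing import Sequence


def dcg_at_k(relevances: Sequence[float], k: int) -> float:
    """Discounted cumulative gain over the first k ranked items (log2 discount)."""
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances[:k]))


def ndcg_at_k(relevances: Sequence[float], k: int) -> float:
    """DCG@k normalized by the DCG@k of the ideal (sorted) ordering of the same list."""
    ideal = dcg_at_k(sorted(relevances, reverse=True), k)
    return dcg_at_k(relevances, k) / ideal if ideal > 0 else 0.0


# Binary relevance of a hypothetical top-10 list for one query (1 = the relevant page)
before = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]  # first-stage order: relevant page at rank 3
after = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]   # after reranking: relevant page at rank 1
print(round(ndcg_at_k(before, 5), 3), round(ndcg_at_k(after, 5), 3))  # 0.5 1.0
```

Reranking only reorders the retrieved candidates, so it cannot recover a relevant page the first-stage retriever missed.
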
## License

This LoRA model is licensed under the Apache 2.0 license.