import os

import gradio as gr
import spaces
import torch
from pdf2image import convert_from_path
from torch.utils.data import DataLoader
from tqdm import tqdm

from transformers import ColPaliForRetrieval, ColPaliProcessor


@spaces.GPU
def install_fa2():
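    """Install flash-attn for FlashAttention-2 kernels (optional; the call below is disabled)."""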
    print("Install FA2")
    os.system("pip install flash-attn --no-build-isolation")


# install_fa2()
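# Left disabled: without flash-attn, the model falls back to transformers' default attention.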

model_name = "vidore/colpali-v1.3-hf"


model = ColPaliForRetrieval.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",  # or "mps" if on Apple Silicon
    # attn_implementation="flash_attention_2", # should work on A100
).eval()
processor = ColPaliProcessor.from_pretrained(model_name)
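# The same processor prepares both page images (for indexing) and text queries (for search).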


@spaces.GPU
def search(query: str, ds, images, k):
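    """Embed the query with ColPali and return the k best-matching pages.

    ds holds the precomputed page embeddings and images the corresponding
    PIL pages, both produced by index().
    """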
    k = min(k, len(ds))
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    if device != model.device:
        model.to(device)

    qs = []
    with torch.no_grad():
        batch_query = processor(text=[query]).to(model.device)
        query_embeddings = model(**batch_query).embeddings
        qs.extend(list(torch.unbind(query_embeddings.to("cpu"))))

    scores = processor.score_retrieval(qs, ds)
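    # scores holds one row of late-interaction (MaxSim) page scores per query; here a single query.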

    top_k_indices = scores[0].topk(k).indices.tolist()

    results = []
    for idx in top_k_indices:
        results.append((images[idx], f"Page {idx + 1}"))  # 1-based page labels for display

    return results


def index(files, ds):
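    """Convert the uploaded PDFs to page images, then embed them on the GPU."""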
    print("Converting files")
    images = convert_files(files)
    print(f"Files converted with {len(images)} images.")
    return index_gpu(images, ds)


def convert_files(files):
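    """Render every page of each uploaded PDF to a PIL image (fewer than 150 pages total)."""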
    images = []
    for f in files:
        images.extend(convert_from_path(f, thread_count=4))

    if len(images) >= 150:
        raise gr.Error("The number of images in the dataset should be less than 150.")
    return images


@spaces.GPU
def index_gpu(images, ds):
    """Example script to run inference with ColPali"""

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    if device != model.device:
        model.to(device)

    # Batch the page images; collate on CPU and move each batch to the GPU in the loop below.
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: processor(images=x),
    )

    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc).embeddings
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    return f"Uploaded and converted {len(images)} pages", ds, images


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# ColPali: Efficient Document Retrieval with Vision Language Models πŸ“š"
    )
    gr.Markdown("""Demo to test the Transformers πŸ€— implementation of ColPali on PDF documents.<br>
    ColPali is the model implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).<br>
    This demo allows you to upload PDF files and search for the most relevant pages based on your query.
    Refresh the page if you change documents!<br>
    ⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing english text. Performance is expected to drop for other page formats and languages.<br>
    Other models will be released with better robustness towards different languages and document formats!
    Demo by [manu](https://huggingface.co/spaces/manu/ColPali-demo)
    """)
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            file = gr.File(file_count="multiple", label="Upload PDFs")

            convert_button = gr.Button("πŸ”„ Index documents")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])

        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(
                minimum=1, maximum=10, step=1, label="Number of results", value=5
            )

    # Define the actions
    search_button = gr.Button("πŸ” Search", variant="primary")
    output_gallery = gr.Gallery(
        label="Retrieved Documents", height=600, show_label=True
    )

    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
    search_button.click(
        search, inputs=[query, embeds, imgs, k], outputs=[output_gallery]
    )

if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)
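
# Note: per the `spaces` package docs, @spaces.GPU is a no-op outside ZeroGPU
# Spaces, so this script should also run locally with `python app.py` (a CUDA
# GPU is assumed by the hard-coded device_map above).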