import base64
import os
from io import BytesIO

import gradio as gr
import torch
from pdf2image import convert_from_path
from PIL import Image
from pqdm.processes import pqdm
from torch.utils.data import DataLoader
from tqdm import tqdm

from colpali_engine.models import ColQwen2, ColQwen2Processor

# Load the ColQwen2 retriever in bfloat16. flash_attention_2 requires a CUDA
# GPU with the flash-attn package installed; drop the argument to fall back to
# the default attention implementation.
model = ColQwen2.from_pretrained(
    "vidore/colqwen2-v1.0",
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
    attn_implementation="flash_attention_2",
).eval()
processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")
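
# Retrieval sketch: ColQwen2 embeds each page image as a sequence of
# patch-level vectors and each query as a sequence of token-level vectors;
# relevance is computed by late interaction (a MaxSim sum over query tokens),
# following the ColPali paper (https://arxiv.org/abs/2407.01449).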


def encode_image_to_base64(image: Image.Image) -> str:
    """Encodes a PIL image to a base64 JPEG string (assumes an RGB image)."""
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


DEFAULT_SYSTEM_PROMPT = """
You are a smart assistant designed to answer questions about a PDF document.
You are given relevant information in the form of PDF pages preceded by their metadata (PDF title, page number, surrounding context).
Use them to construct a short response to the question, and cite your sources (page number, PDF title).
If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
Give detailed and extensive answers, only containing info in the pages you are given.
You can answer using information contained in plots and figures if necessary.
Answer in the same language as the query.
"""


def query_gpt4o_mini(query, images, api_key=None, system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Calls OpenAI's GPT-4o-mini with the query and (image, caption) pairs."""

    if api_key and api_key.startswith("sk-"):
        try:
            from openai import OpenAI

            client = OpenAI(api_key=api_key.strip())
            prompt = f"""
{system_prompt}
Query: {query}
PDF pages:
"""

            # Interleave each page's caption (title, page number, context)
            # with the page image itself, encoded as a base64 data URI.
            content = [{"type": "text", "text": prompt}]
            for im, capt in images:
                if capt is not None:
                    content.append({
                        "type": "text",
                        "text": capt,
                    })
                content.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{encode_image_to_base64(im)}"
                    },
                })

            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {
                        "role": "user",
                        "content": content,
                    }
                ],
                max_tokens=500,
            )
            return response.choices[0].message.content
        except Exception:
            return "OpenAI API connection failure. Verify the provided key is correct (sk-***)."

    return "Enter your OpenAI API key to get a custom response"


def search(query: str, ds, images, metadatas, k, api_key):
    k = min(k, len(ds))
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    if model.device != torch.device(device):
        model.to(device)

    # Embed the query into its multi-vector representation.
    qs = []
    with torch.no_grad():
        batch_query = processor.process_queries([query]).to(model.device)
        embeddings_query = model(**batch_query)
        qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

    # Late-interaction scores between the query and every indexed page.
    scores = processor.score(qs, ds, device=device)

    top_k_indices = scores[0].topk(k).indices.tolist()

    results = []
    for idx in top_k_indices:
        img = images[idx]
        meta = metadatas[idx]
        results.append((img, f"Document: {meta['title']}, Page: {meta['page']}, Context: {meta['context']}"))

    ai_response = query_gpt4o_mini(query, results, api_key)

    return results, ai_response
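
# processor.score returns an [n_queries, n_docs] matrix of late-interaction
# scores; with a single query, scores[0] ranks every indexed page.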


def index(files, ds, api_key):
    print("Converting files")
    images, metadatas = convert_files(files, api_key)
    print(f"Files converted with {len(images)} images.")
    ds = index_gpu(images, ds)
    print(f"Indexed {len(ds)} images.")
    return f"Uploaded and converted {len(images)} pages", ds, images, metadatas


DEFAULT_CONTEXT_PROMPT = """
You are a smart assistant designed to extract the context of PDF pages.
Give detailed and extensive answers, only containing info in the pages you are given.
You can answer using information contained in plots and figures if necessary.
Answer in the same language as the query.
"""


def extract_context(images, api_key, window=10):
    """Extracts a shared context string for each window of pages."""
    prompt = "Give the general context about these pages."

    # One API call per non-overlapping window of `window` pages, dispatched to
    # parallel worker processes. Pairs are (image, caption); captions are None
    # because the pages carry no metadata yet.
    args = [
        {
            'query': prompt,
            'images': [(im, None) for im in images[i:i + window]],
            'api_key': api_key,
            'system_prompt': DEFAULT_CONTEXT_PROMPT,
        } for i in range(0, len(images), window)
    ]
    window_contexts = pqdm(args, query_gpt4o_mini, n_jobs=8, argument_type='kwargs')

    # Every page inherits the context of the window it belongs to.
    contexts = [window_contexts[i // window] for i in range(len(images))]

    print(f"Example context: {contexts[0]}")

    assert len(contexts) == len(images)
    return contexts


def extract_metadata(file, images, api_key, window=10):
    """Extracts per-page metadata: PDF title (file name), page number, and window context."""
    title = os.path.basename(file)
    contexts = extract_context(images, api_key, window=window)
    return [{"title": title, "page": i + 1, "context": contexts[i]} for i in range(len(images))]


def convert_files(files, api_key):
    images = []
    metadatas = []

    for f in files:
        file_images = convert_from_path(f, thread_count=4)
        file_metadatas = extract_metadata(f, file_images, api_key)
        images.extend(file_images)
        metadatas.extend(file_metadatas)

    if len(images) >= 500:
        raise gr.Error("The number of images in the dataset should be less than 500.")
    return images, metadatas
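
# Note: pdf2image shells out to Poppler (pdftoppm/pdfinfo), which must be
# installed on the host, e.g. via `apt-get install poppler-utils`.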


def index_gpu(images, ds):
    """Embeds page images with ColQwen2 and appends the embeddings to ds."""

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    if model.device != torch.device(device):
        model.to(device)

    # Batch pages through the processor; tensors are moved to the device
    # inside the loop, so the collate_fn stays device-agnostic.
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: processor.process_images(x),
    )

    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
            ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    return ds
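
# Embeddings are kept on the CPU as one variable-length tensor per page, so
# index memory grows with both page count and per-page visual token count;
# batch_size=4 trades throughput for GPU memory headroom.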


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) 📚")
    gr.Markdown("""Demo to test ColQwen2 (ColPali) on PDF documents.
    ColPali is a model introduced in the [ColPali paper](https://arxiv.org/abs/2407.01449).

    This demo allows you to upload PDF files and search for the most relevant pages based on your query.
    Refresh the page if you change documents!

    ⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing English text. Performance is expected to drop for other page formats and languages.
    Other models will be released with better robustness towards different languages and document formats!
    """)
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            file = gr.File(file_types=["pdf"], file_count="multiple", label="Upload PDFs")

            convert_button = gr.Button("🔄 Index documents")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            api_key = gr.Textbox(placeholder="Enter your OpenAI key here (optional)", label="API key")

            # Per-session state: page embeddings, page images, and metadata.
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])
            metadatas = gr.State(value=[])

        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)

            search_button = gr.Button("🔍 Search", variant="primary")
            output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
            output_text = gr.Textbox(label="AI Response", placeholder="Generated response based on retrieved documents")

    convert_button.click(index, inputs=[file, embeds, api_key], outputs=[message, embeds, imgs, metadatas])
    search_button.click(search, inputs=[query, embeds, imgs, metadatas, k, api_key], outputs=[output_gallery, output_text])


if __name__ == "__main__":
    demo.queue(max_size=5).launch(debug=True, server_name="0.0.0.0")
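
# Deployment note: server_name="0.0.0.0" binds on all network interfaces (as
# for a hosted Space); queue(max_size=5) caps the number of pending requests.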