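"""Gradio demo: retrieval over PDF pages with ColQwen2 (ColPali) embeddings.

Pipeline: PDFs are converted to page images (pdf2image), each page gets a
GPT-4o-mini-generated context summary, pages are embedded with ColQwen2, and
queries retrieve the top-k pages, optionally with a GPT-4o-mini answer
grounded in those pages.
"""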
import os
import base64
from io import BytesIO
import gradio as gr
import torch
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from functools import partial
from pqdm.processes import pqdm
from colpali_engine.models import ColQwen2, ColQwen2Processor
model = ColQwen2.from_pretrained(
    "vidore/colqwen2-v1.0",
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",  # or "mps" if on Apple Silicon
    attn_implementation="flash_attention_2",
).eval()
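# Note: attn_implementation="flash_attention_2" requires the flash-attn
# package and an NVIDIA GPU; drop the argument on Apple Silicon or CPU to
# fall back to the default attention implementation.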
processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")
def encode_image_to_base64(image):
    """Encodes a PIL image to a base64 string."""
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
DEFAULT_SYSTEM_PROMPT = """
You are a smart assistant designed to answer questions about a PDF document.
You are given relevant information in the form of PDF pages preceded by their metadata (PDF title, page number, surrounding context).
Use them to construct a short response to the question, and cite your sources (page number, pdf title).
If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
Give detailed and extensive answers, only containing info in the pages you are given.
You can answer using information contained in plots and figures if necessary.
Answer in the same language as the query.
"""
def query_gpt4o_mini(query, images, api_key=None, system_prompt=DEFAULT_SYSTEM_PROMPT):
    """Calls OpenAI's GPT-4o-mini with the query and image data."""
    if api_key and api_key.startswith("sk-"):
        try:
            from openai import OpenAI

            client = OpenAI(api_key=api_key.strip())
            prompt = f"""
{system_prompt}
Query: {query}
PDF pages:
"""
            # Build a multimodal content list: the prompt first, then each
            # page's caption (if any) followed by the page image as a data URL.
            messages = [{"type": "text", "text": prompt}]
            for im, capt in images:
                if capt is not None:
                    messages.append({
                        "type": "text",
                        "text": capt,
                    })
                messages.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{encode_image_to_base64(im)}"
                    },
                })
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {
                        "role": "user",
                        "content": messages,
                    }
                ],
                max_tokens=500,
            )
            return response.choices[0].message.content
        except Exception as e:
            print(f"OpenAI API error: {e}")
            return "OpenAI API connection failure. Verify that the provided key is correct (sk-***)."
    return "Enter your OpenAI API key to get a custom response"
def search(query: str, ds, images, metadatas, k, api_key):
    k = min(k, len(ds))
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    if str(model.device) != device:
        model.to(device)
    # Embed the query with ColQwen2.
    qs = []
    with torch.no_grad():
        batch_query = processor.process_queries([query]).to(model.device)
        embeddings_query = model(**batch_query)
        qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
    # Late-interaction scoring of the query against all indexed pages.
    scores = processor.score(qs, ds, device=device)
    top_k_indices = scores[0].topk(k).indices.tolist()
    results = []
    for idx in top_k_indices:
        img = images[idx]
        meta = metadatas[idx]
        results.append((img, f"Document: {meta['title']}, Page: {meta['page']}, Context: {meta['context']}"))
    # Generate a response from GPT-4o-mini, grounded in the retrieved pages.
    ai_response = query_gpt4o_mini(query, results, api_key)
    return results, ai_response
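# For reference: `processor.score` implements ColPali-style late interaction
# (MaxSim). An unbatched, single-pair sketch of that scoring rule:
def _maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> float:
    """q_emb: (n_query_tokens, dim); d_emb: (n_page_tokens, dim)."""
    # Best-matching page token for each query token, summed over the query.
    return (q_emb.float() @ d_emb.float().T).max(dim=1).values.sum().item()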
def index(files, ds, api_key):
    print("Converting files")
    images, metadatas = convert_files(files, api_key)
    print(f"Files converted with {len(images)} images.")
    ds = index_gpu(images, ds)
    print(f"Indexed {len(ds)} images.")
    return f"Uploaded and converted {len(images)} pages", ds, images, metadatas
DEFAULT_CONTEXT_PROMPT = """
You are a smart assistant designed to extract context of PDF pages.
Give detailed and extensive answers, only containing info in the pages you are given.
You can answer using information contained in plots and figures if necessary.
Answer in the same language as the query.
"""
def extract_context(images, api_key, window=10):
    """Extracts a context summary for each page: one GPT-4o-mini call per
    chunk of `window` consecutive pages."""
    prompt = "Give the general context about these pages."
    args = [
        {
            'query': prompt,
            'images': [(image, None) for image in images[i:i + window]],
            'api_key': api_key,
            'system_prompt': DEFAULT_CONTEXT_PROMPT,
        }
        for i in range(0, len(images), window)
    ]
    # Query the chunks in parallel.
    window_contexts = pqdm(args, query_gpt4o_mini, n_jobs=8, argument_type='kwargs')
    # Each page inherits the context of the chunk it belongs to.
    contexts = [window_contexts[i // window] for i in range(len(images))]
    print(f"Example context: {contexts[0]}")
    assert len(contexts) == len(images)
    return contexts
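# Example: with 25 pages and window=10, extract_context summarizes chunks
# [0:10], [10:20], [20:25]; page i is then assigned window_contexts[i // window],
# i.e. the summary of the chunk that contains it.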
def extract_metadata(file, images, api_key, window=10):
    """Extracts metadata for each page of a PDF: file title, page number,
    and a GPT-generated context summary."""
    title = os.path.basename(file)
    contexts = extract_context(images, api_key, window=window)
    return [{"title": title, "page": i + 1, "context": contexts[i]} for i in range(len(images))]
def convert_files(files, api_key):
    images = []
    metadatas = []
    for f in files:
        # pdf2image requires poppler to be installed on the system.
        file_images = convert_from_path(f, thread_count=4)
        images.extend(file_images)
        # Enforce the size cap before the (paid) context-extraction calls.
        if len(images) >= 500:
            raise gr.Error("The number of images in the dataset should be less than 500.")
        file_metadatas = extract_metadata(f, file_images, api_key)
        metadatas.extend(file_metadatas)
    return images, metadatas
def index_gpu(images, ds):
    """Embeds page images with ColQwen2 and appends them to the index `ds`."""
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    if str(model.device) != device:
        model.to(device)
    # Run inference over the page images in batches.
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: processor.process_images(x).to(model.device),
    )
    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    return ds
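# Note: `ds` holds one multi-vector embedding per page, i.e. a CPU tensor of
# shape (n_page_tokens, dim) per page, as produced by torch.unbind above.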
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) 📚")
    gr.Markdown("""Demo to test ColQwen2 (ColPali) on PDF documents.
    ColPali is a model implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).

    This demo allows you to upload PDF files and search for the most relevant pages based on your query.
    Refresh the page if you change documents!

    ⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing English text. Performance is expected to drop for other page formats and languages.
    Other models will be released with better robustness towards different languages and document formats!
    """)
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            file = gr.File(file_types=["pdf"], file_count="multiple", label="Upload PDFs")
            convert_button = gr.Button("🔄 Index documents")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            api_key = gr.Textbox(placeholder="Enter your OpenAI KEY here (optional)", label="API key")
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])
            metadatas = gr.State(value=[])
        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=5)

    # Define the actions
    search_button = gr.Button("🔍 Search", variant="primary")
    output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
    output_text = gr.Textbox(label="AI Response", placeholder="Generated response based on retrieved documents")

    convert_button.click(index, inputs=[file, embeds, api_key], outputs=[message, embeds, imgs, metadatas])
    search_button.click(search, inputs=[query, embeds, imgs, metadatas, k, api_key], outputs=[output_gallery, output_text])

if __name__ == "__main__":
    demo.queue(max_size=5).launch(debug=True, server_name="0.0.0.0")