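"""ColPali + Qwen2-VL document QA demo.

Indexes uploaded PDFs as page images with ColPali embeddings, retrieves the
pages most relevant to a query (multi-vector CustomEvaluator scoring, or a
pooled FAISS approximation), then answers the query with Qwen2-VL over the
retrieved pages.
"""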
import os
import spaces
import gradio as gr
import torch
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
from pdf2image import convert_from_path
from PIL import Image, ImageEnhance
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info  # used in model_inference below; was called but never imported
import faiss  # FAISS for fast retrieval
import numpy as np
# Initialize FAISS index for fast similarity search (used only if selected).
# ColPali embeds each token into a 128-dim vector (448 is the model's input
# image resolution, not its embedding size).
embedding_dim = 128
faiss_index = faiss.IndexFlatL2(embedding_dim)
stored_images = []  # Images associated with FAISS embeddings, for retrieval
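# Note: ColPali natively scores pages with multi-vector late interaction
# (MaxSim over all token embeddings), which CustomEvaluator implements
# exactly. A flat L2 FAISS index holds one fixed-size vector per page, so the
# FAISS path in this script stores a mean-pooled page vector instead: a
# faster but coarser approximation of the full late-interaction score.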
def preprocess_image(image_path, grayscale=False):
"""Apply optional grayscale and other enhancements to images."""
img = Image.open(image_path)
if grayscale:
img = img.convert("L") # Apply grayscale if selected
enhancer = ImageEnhance.Sharpness(img)
img = enhancer.enhance(2.0) # Sharpen
return img
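# Sketch of expected usage (the path is hypothetical):
#   preprocess_image("page_1.png", grayscale=True) -> sharpened grayscale PIL image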
@spaces.GPU
def model_inference(images, text, grayscale=False):
"""Qwen2VL-based inference function with optional grayscale processing."""
    # Build the multimodal message content: retrieved page images first, then the query text
    content = [
        {
            "type": "image",
            "image": preprocess_image(image[0], grayscale=grayscale),
            "resized_height": 1344,
            "resized_width": 1344,
        }
        for image in images
    ]
    content.append({"type": "text", "text": text})
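    # The 7B VLM is loaded per call and freed at the end of the function,
    # trading reload latency for GPU memory headroom between requests.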
model = Qwen2VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen2-VL-7B-Instruct",
trust_remote_code=True,
torch_dtype=torch.bfloat16
).to("cuda:0")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    messages = [{"role": "user", "content": content}]
text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, _ = process_vision_info(messages)
inputs = processor(
text=[text_input], images=image_inputs, padding=True, return_tensors="pt"
).to("cuda")
generated_ids = model.generate(**inputs, max_new_tokens=512)
output_text = processor.batch_decode(generated_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
del model, processor
torch.cuda.empty_cache()
return output_text[0]
@spaces.GPU
def search(query: str, ds, images, k, retrieval_method="CustomEvaluator"):
"""Search function with option to choose between CustomEvaluator and FAISS for retrieval."""
model_name = "vidore/colpali-v1.2"
token = os.environ.get("HF_TOKEN")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = ColPali.from_pretrained(
        "vidore/colpaligemma-3b-pt-448-base", torch_dtype=torch.bfloat16, device_map="cuda", token=token
    ).eval().to(device)
    model.load_adapter(model_name)  # attach the ColPali v1.2 retrieval adapter to the base checkpoint
    processor = AutoProcessor.from_pretrained(model_name, token=token)
    mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
    # Embed the query: ColPali returns a multi-vector embedding (n_tokens x 128)
    batch_query = process_queries(processor, [query], mock_image)
    with torch.no_grad():
        embeddings_query = model(**{k: v.to(device) for k, v in batch_query.items()})
    # Cast to float32 before leaving torch: numpy has no bfloat16 dtype
    query_embedding = embeddings_query[0].to(torch.float32).cpu()
    if retrieval_method == "FAISS":
        # FAISS path: mean-pool the token embeddings into a single 128-dim
        # vector to match the pooled page vectors stored at indexing time
        pooled_query = query_embedding.mean(dim=0).numpy().reshape(1, -1)
        distances, indices = faiss_index.search(pooled_query, k)
        results = [stored_images[idx] for idx in indices[0]]
    else:
        # CustomEvaluator path: full multi-vector late-interaction (MaxSim) scoring
        qs = [query_embedding]
        retriever_evaluator = CustomEvaluator(is_multi_vector=True)
        scores = retriever_evaluator.evaluate(qs, ds)
        top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]
        results = [images[idx] for idx in top_k_indices]
del model, processor
torch.cuda.empty_cache()
return results
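# Hypothetical usage once documents are indexed:
#   pages = search("revenue by region", ds=embeds, images=imgs, k=3, retrieval_method="FAISS")
# returns the 3 page images whose embeddings best match the query.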
def index(files, ds):
"""Convert and index PDF files."""
images = convert_files(files)
return index_gpu(images, ds)
def convert_files(files):
"""Convert PDF files to images."""
images = []
for f in files:
images.extend(convert_from_path(f, thread_count=4))
if len(images) >= 150:
raise gr.Error("The number of images in the dataset should be less than 150.")
return images
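# Note: convert_from_path relies on the system Poppler utilities (pdftoppm)
# being installed; thread_count=4 parallelizes page rasterization.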
@spaces.GPU
def index_gpu(images, ds):
"""Index documents using FAISS or store in dataset for CustomEvaluator."""
global stored_images
model_name = "vidore/colpali-v1.2"
token = os.environ.get("HF_TOKEN")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = ColPali.from_pretrained(
        "vidore/colpaligemma-3b-pt-448-base", torch_dtype=torch.bfloat16, device_map="cuda", token=token
    ).eval().to(device)
    model.load_adapter(model_name)  # attach the ColPali v1.2 retrieval adapter to the base checkpoint
    processor = AutoProcessor.from_pretrained(model_name, token=token)
    dataloader = DataLoader(images, batch_size=4, shuffle=False, collate_fn=lambda x: process_images(processor, x))
    all_embeddings = []
    for batch in tqdm(dataloader):
        with torch.no_grad():
            batch = {k: v.to(device) for k, v in batch.items()}
            embeddings_doc = model(**batch)
        # Cast to float32 before leaving torch: numpy has no bfloat16 dtype
        all_embeddings.extend(list(torch.unbind(embeddings_doc.to(torch.float32).cpu())))
    # Keep the full multi-vector embeddings for CustomEvaluator...
    ds.extend(all_embeddings)
    # ...and add one mean-pooled 128-dim vector per page to the FAISS index
    pooled = np.stack([emb.mean(dim=0).numpy() for emb in all_embeddings])
    faiss_index.add(pooled)
    stored_images.extend(images)  # keep images aligned with FAISS row indices
    del model, processor
    torch.cuda.empty_cache()
    return f"Indexed {len(images)} pages", ds, images
def get_example():
return [
[["RAPPORT_DEVELOPPEMENT_DURABLE_2019.pdf"], "Quels sont les 4 axes majeurs des achats?"],
[["RAPPORT_DEVELOPPEMENT_DURABLE_2019.pdf"], "Quelles sont les actions entreprise en Afrique du Sud?"],
[["RAPPORT_DEVELOPPEMENT_DURABLE_2019.pdf"], "fais moi un tableau markdown de la rΓ©partition homme femme"],
]
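# The example queries target the French sustainability report above; e.g. the
# first one asks "What are the four major procurement axes?".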
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
gr.Markdown("# π ColPali + Qwen2VL 7B: Document Retrieval & Analysis App")
# Section 1: File Upload
with gr.Row():
with gr.Column(scale=2):
gr.Markdown("## Step 1: Upload Your Documents π")
file = gr.File(file_types=["pdf"], file_count="multiple", label="Upload PDF Documents")
grayscale_option = gr.Checkbox(label="Convert images to grayscale π€", value=False)
convert_button = gr.Button("π Index Documents", variant="secondary")
message = gr.Textbox("No files uploaded yet", label="Status", interactive=False)
embeds = gr.State(value=[])
imgs = gr.State(value=[])
img_chunk = gr.State(value=[])
# Section 2: Search Options
with gr.Row():
with gr.Column(scale=3):
gr.Markdown("## Step 2: Search the Indexed Documents π")
query = gr.Textbox(placeholder="Enter your query here", label="Query", lines=2)
k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Results", value=1)
retrieval_method = gr.Dropdown(
choices=["CustomEvaluator", "FAISS"],
label="Choose Retrieval Method π",
value="CustomEvaluator"
)
search_button = gr.Button("π Search", variant="primary")
# Displaying Examples
with gr.Row():
gr.Markdown("## π‘ Example Queries")
gr.Examples(examples=get_example(), inputs=[file, query], label="Try These Examples")
# Output Gallery for Search Results
    output_gallery = gr.Gallery(label="Retrieved Documents", height=600)
# Section 3: Answer Retrieval
with gr.Row():
gr.Markdown("## Step 3: Generate Answers with Qwen2-VL π§ ")
answer_button = gr.Button("π¬ Get Answer", variant="primary")
output = gr.Markdown(label="Output")
# Define interactions
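    # Each handler's return values must match its outputs list:
    # index -> (status message, embeddings dataset, page images).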
convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
search_button.click(search, inputs=[query, embeds, imgs, k, retrieval_method], outputs=[output_gallery])
answer_button.click(model_inference, inputs=[output_gallery, query, grayscale_option], outputs=output)
if __name__ == "__main__":
demo.queue(max_size=10).launch(share=True)