DrishtiSharma committed
Commit f30ab0b · verified · 1 Parent(s): f70bb9d

Create app.py

Files changed (1):
  app.py +213 -0
app.py ADDED
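"""ColPali + Qwen2-VL document retrieval and QA Space.

PDF pages are converted to images and embedded with ColPali; retrieval runs either
through colpali_engine's CustomEvaluator or a FAISS index, and Qwen2-VL generates
answers over the retrieved pages.
"""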
import os
import spaces
import gradio as gr
import torch
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
from pdf2image import convert_from_path
from PIL import Image, ImageEnhance
from qwen_vl_utils import process_vision_info  # needed by model_inference below
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
import faiss  # FAISS for fast retrieval
import numpy as np

# Initialize FAISS index for fast similarity search (used only if selected)
embedding_dim = 448
faiss_index = faiss.IndexFlatL2(embedding_dim)
stored_images = []  # To store images associated with embeddings for retrieval if using FAISS


def preprocess_image(image_path, grayscale=False):
    """Apply optional grayscale and other enhancements to images."""
    img = Image.open(image_path)
    if grayscale:
        img = img.convert("L")  # Apply grayscale if selected
    enhancer = ImageEnhance.Sharpness(img)
    img = enhancer.enhance(2.0)  # Sharpen
    return img


@spaces.GPU
def model_inference(images, text, grayscale=False):
    """Qwen2VL-based inference function with optional grayscale processing."""
    images = [
        {
            "type": "image",
            "image": preprocess_image(image[0], grayscale=grayscale),
            "resized_height": 1344,
            "resized_width": 1344,
        }
        for image in images
    ]
    images.append({"type": "text", "text": text})

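    # Load Qwen2-VL on demand inside this GPU-decorated call and free it at the end,
    # so the 7B model only occupies GPU memory while a request is being answered.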
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16
    ).to("cuda:0")

    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    messages = [{"role": "user", "content": images}]

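    # Render the chat template to a prompt string and extract the PIL images from the
    # message list in the format the processor expects.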
    text_input = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, _ = process_vision_info(messages)

    inputs = processor(
        text=[text_input], images=image_inputs, padding=True, return_tensors="pt"
    ).to("cuda")

    generated_ids = model.generate(**inputs, max_new_tokens=512)
    output_text = processor.batch_decode(generated_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)

    del model, processor
    torch.cuda.empty_cache()
    return output_text[0]


@spaces.GPU
def search(query: str, ds, images, k, retrieval_method="CustomEvaluator"):
    """Search function with option to choose between CustomEvaluator and FAISS for retrieval."""
    model_name = "vidore/colpali-v1.2"
    token = os.environ.get("HF_TOKEN")
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    model = ColPali.from_pretrained(
        "vidore/colpaligemma-3b-pt-448-base", torch_dtype=torch.bfloat16, device_map="cuda", token=token
    ).eval().to(device)
    processor = AutoProcessor.from_pretrained(model_name, token=token)
    mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

    # Process the query to obtain embeddings
    batch_query = process_queries(processor, [query], mock_image)
    with torch.no_grad():
        embeddings_query = model(**{k: v.to(device) for k, v in batch_query.items()})
    # Keep the query embedding as a float32 CPU tensor: bfloat16 cannot be converted
    # to NumPy directly, and CustomEvaluator expects torch tensors.
    query_embedding = embeddings_query[0].to(torch.float32).cpu()

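    # Two retrieval paths: CustomEvaluator scores ColPali's multi-vector (per-token)
    # embeddings directly, while the FAISS path assumes one flat embedding_dim-sized
    # vector per page in the index.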
    if retrieval_method == "FAISS":
        # Use FAISS for efficient retrieval
        distances, indices = faiss_index.search(np.array([query_embedding]), k)
        results = [stored_images[idx] for idx in indices[0]]
    else:
        # Use CustomEvaluator for retrieval
        qs = [query_embedding]
        retriever_evaluator = CustomEvaluator(is_multi_vector=True)
        scores = retriever_evaluator.evaluate(qs, ds)

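        # argsort is ascending, so the last k indices are the highest-scoring pages;
        # reverse them so results come back best-first.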
        top_k_indices = scores.argsort(axis=1)[0][-k:][::-1]
        results = [images[idx] for idx in top_k_indices]

    del model, processor
    torch.cuda.empty_cache()
    return results


def index(files, ds):
    """Convert and index PDF files."""
    images = convert_files(files)
    return index_gpu(images, ds)


def convert_files(files):
    """Convert PDF files to images."""
    images = []
    for f in files:
        images.extend(convert_from_path(f, thread_count=4))

    if len(images) >= 150:
        raise gr.Error("The number of images in the dataset should be less than 150.")
    return images


@spaces.GPU
def index_gpu(images, ds):
    """Index documents using FAISS or store in dataset for CustomEvaluator."""
    global stored_images
    model_name = "vidore/colpali-v1.2"
    token = os.environ.get("HF_TOKEN")
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    model = ColPali.from_pretrained(
        "vidore/colpaligemma-3b-pt-448-base", torch_dtype=torch.bfloat16, device_map="cuda", token=token
    ).eval().to(device)
    processor = AutoProcessor.from_pretrained(model_name, token=token)
    mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

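    # Embed the page images in batches; process_images prepares each batch for ColPali
    # via the DataLoader's collate function.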
    dataloader = DataLoader(images, batch_size=4, shuffle=False, collate_fn=lambda x: process_images(processor, x))
    all_embeddings = []

    for batch in tqdm(dataloader):
        with torch.no_grad():
            batch = {k: v.to(device) for k, v in batch.items()}
            embeddings_doc = model(**batch)
        all_embeddings.extend(embeddings_doc.cpu().float().numpy())  # cast from bfloat16 before .numpy()

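    # Note: faiss.IndexFlatL2 expects a float32 array of shape (n, embedding_dim),
    # whereas ds keeps one multi-vector tensor per page for CustomEvaluator.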
    # Store embeddings in FAISS index and dataset for respective retrieval options
    embeddings = np.array(all_embeddings)
    faiss_index.add(embeddings)  # Add to FAISS index
    ds.extend(list(torch.unbind(torch.tensor(embeddings))))  # Extend original ds for CustomEvaluator
    stored_images.extend(images)  # Store images to link with FAISS indices

    del model, processor
    torch.cuda.empty_cache()
    return f"Indexed {len(images)} pages", ds, images


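# Example queries (in French) against the bundled sustainability report:
# "What are the 4 major procurement priorities?", "What actions have been taken in
# South Africa?", "Make me a markdown table of the male/female breakdown".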
def get_example():
    return [
        [["RAPPORT_DEVELOPPEMENT_DURABLE_2019.pdf"], "Quels sont les 4 axes majeurs des achats?"],
        [["RAPPORT_DEVELOPPEMENT_DURABLE_2019.pdf"], "Quelles sont les actions entreprise en Afrique du Sud?"],
        [["RAPPORT_DEVELOPPEMENT_DURABLE_2019.pdf"], "fais moi un tableau markdown de la répartition homme femme"],
    ]


with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    gr.Markdown("# 📝 ColPali + Qwen2VL 7B: Enhanced Document Retrieval & Analysis App")

    # Section 1: File Upload
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## Step 1: Upload Your Documents 📄")
            file = gr.File(file_types=["pdf"], file_count="multiple", label="Upload PDF Documents")
            grayscale_option = gr.Checkbox(label="Convert images to grayscale 🖤", value=False)
            convert_button = gr.Button("🔄 Index Documents", variant="secondary")
            message = gr.Textbox("No files uploaded yet", label="Status", interactive=False)
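            # Session state shared across handlers: embeds holds the ColPali page
            # embeddings, imgs the page images; img_chunk is currently unused.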
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])
            img_chunk = gr.State(value=[])

    # Section 2: Search Options
    with gr.Row():
        with gr.Column(scale=3):
            gr.Markdown("## Step 2: Search the Indexed Documents 🔍")
            query = gr.Textbox(placeholder="Enter your query here", label="Query", lines=2)
            k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of Results", value=1)
            retrieval_method = gr.Dropdown(
                choices=["CustomEvaluator", "FAISS"],
                label="Choose Retrieval Method 🔀",
                value="CustomEvaluator"
            )
            search_button = gr.Button("🔍 Search", variant="primary")

    # Displaying Examples
    with gr.Row():
        gr.Markdown("## 💡 Example Queries")
        gr.Examples(examples=get_example(), inputs=[file, query], label="Try These Examples", show_label=True)

    # Output Gallery for Search Results
    output_gallery = gr.Gallery(label="📂 Retrieved Documents", height=600, show_label=True)

    # Section 3: Answer Retrieval
    with gr.Row():
        gr.Markdown("## Step 3: Generate Answers with Qwen2-VL 🧠")
        answer_button = gr.Button("💬 Get Answer", variant="primary")
    output = gr.Markdown(label="Output")

    # Define interactions
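    # index() returns (status, embeddings, images) to fill [message, embeds, imgs];
    # search() populates the gallery; model_inference() answers over the gallery images.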
    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
    search_button.click(search, inputs=[query, embeds, imgs, k, retrieval_method], outputs=[output_gallery])
    answer_button.click(model_inference, inputs=[output_gallery, query, grayscale_option], outputs=output)

if __name__ == "__main__":
    demo.queue(max_size=10).launch(share=True)