Kazel committed on
Commit
065b6ad
·
1 Parent(s): 4fc6794
Files changed (11) hide show
  1. app.py +311 -0
  2. colpali_manager.py +141 -0
  3. middleware.py +62 -0
  4. milvus_manager.py +195 -0
  5. packages.txt +1 -0
  6. pdf_manager.py +46 -0
  7. rag.py +147 -0
  8. requirements.txt +14 -0
  9. test.py +30 -0
  10. uploaded_files.txt +3 -0
  11. utils.py +5 -0
app.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tempfile
3
+ import os
4
+ import fitz # PyMuPDF
5
+ import uuid
6
+ import shutil
7
+ from pymilvus import MilvusClient
8
+
9
+ from middleware import Middleware
10
+ from rag import Rag
11
+ from pathlib import Path
12
+ import subprocess
13
+ import getpass
14
+
15
+ rag = Rag()
16
+
17
+
18
def generate_uuid(state):
    """Return the session's user UUID, creating and storing one on first use.

    Args:
        state: Mutable per-session dict; the UUID lives under "user_uuid".

    Returns:
        The UUID string stored in ``state["user_uuid"]``.
    """
    # .get() tolerates a state dict that was never seeded with the key.
    if state.get("user_uuid") is None:
        state["user_uuid"] = str(uuid.uuid4())

    return state["user_uuid"]
25
+
26
+
27
class PDFSearchApp:
    """Gradio callbacks for indexing, searching and managing documents."""

    # Must match the database file that Middleware opens ("milvus_0.db");
    # otherwise delete() would drop collections from a different Milvus
    # instance than the one documents were indexed into.
    MILVUS_URI = "milvus_0.db"

    def __init__(self):
        # Maps document id -> True once that document has been indexed.
        self.indexed_docs = {}
        # Path of the most recently indexed upload.
        self.current_pdf = None

    def upload_and_convert(self, state, files, max_pages):
        """Index every uploaded file and report the number of pages extracted.

        Args:
            state: Per-session Gradio state (currently unused).
            files: Uploaded files (objects with a ``.name`` path, or plain
                path strings), or None when nothing was uploaded.
            max_pages: Maximum number of pages to index per document.

        Returns:
            A human-readable status message.
        """
        if not files:
            return "No file uploaded"

        total_pages = 0
        skipped = 0
        try:
            for file in files:
                # Accept both file wrappers exposing .name and bare paths.
                pdf_path = getattr(file, "name", file)
                name, ext = os.path.splitext(os.path.basename(pdf_path))

                if ext.lower() in (".ppt", ".pptx"):
                    # No converter is available, and feeding the file to the
                    # PDF pipeline would only fail, so skip it.
                    print("pptx not supported on spaces")
                    skipped += 1
                    continue

                self.current_pdf = pdf_path

                # The sanitised file name doubles as collection name and
                # as the folder name under pages/.
                doc_id = name.replace(" ", "_").replace("-", "_")

                print(f"Uploading file: {doc_id}, id: abc")
                middleware = Middleware(doc_id, create_collection=True)
                pages = middleware.index(pdf_path, id=doc_id, max_pages=max_pages)

                total_pages += len(pages or [])
                self.indexed_docs[doc_id] = True
        except Exception as e:
            return f"Error processing PDF: {str(e)}"

        message = f"Uploaded and extracted {total_pages} pages"
        if skipped:
            message += f" ({skipped} PowerPoint file(s) skipped: not supported)"
        return message

    def display_file_list(self):
        """List the per-document folders under ``<cwd>/pages``.

        Returns:
            A list of directory names on success, or an error message string
            when the directory cannot be read.
        """
        directory_path = os.path.join(os.getcwd(), "pages")
        try:
            return [
                entry
                for entry in os.listdir(directory_path)
                if os.path.isdir(os.path.join(directory_path, entry))
            ]
        except FileNotFoundError:
            return f"The directory {directory_path} does not exist."
        except PermissionError:
            return f"Permission denied to access {directory_path}."
        except Exception as e:
            return str(e)

    def search_documents(self, state, query, num_results=1):
        """Find the best-matching page for *query* and ask the LLM about it.

        Args:
            state: Per-session Gradio state (currently unused).
            query: The user's search text.
            num_results: Currently unused.

        Returns:
            A 3-tuple ``(path, image_path, answer)``. On failure the first
            element carries the message and the image is None, so the tuple
            always matches the three output components wired in the UI.
        """
        print(f"Searching for query: {query}")

        if not query:
            print("Please enter a search query")
            return "Please enter a search query", None, "--"

        try:
            # The collection name passed here is a placeholder: the search
            # iterates over every collection in the database.
            middleware = Middleware("test", create_collection=False)
            search_results = middleware.search([query])[0]

            if not search_results:
                return "No matching pages found", None, "--"

            # Each result is (score, doc_id, collection_name); doc_id is the
            # zero-based page index while page files are numbered from 1.
            top_hit = search_results[0]
            page_num = top_hit[1] + 1
            coll_num = top_hit[2]

            print(f"Retrieved page number: {page_num}")

            img_path = f"pages/{coll_num}/page_{page_num}.png"
            path = f"pages/{coll_num}/page_{page_num}"

            print(f"Retrieved image path: {img_path}")

            rag_response = rag.get_answer_from_gemini(query, [img_path])

            return path, img_path, rag_response

        except Exception as e:
            return f"Error during search: {str(e)}", None, "--"

    def delete(self, choice):
        """Delete a document's page images and its Milvus collection.

        Args:
            choice: Document/collection name selected in the UI.

        Returns:
            A status message.
        """
        # An empty choice would make the path "pages/" and wipe every
        # document, so refuse it explicitly.
        if not choice:
            return "No document selected"

        path = f"pages/{choice}"
        if not os.path.exists(path):
            return "Directory not found"

        # Connect before deleting anything so a connection failure cannot
        # leave the images removed but the collection still present.
        client = MilvusClient(uri=self.MILVUS_URI)
        shutil.rmtree(path)
        client.drop_collection(collection_name=choice)
        self.indexed_docs.pop(choice, None)
        return f"Deleted {choice}"

    def list_downloaded_hf_models(self):
        """Return the model ids found in the local Hugging Face hub cache."""
        hf_home = os.getenv('HF_HOME')
        if hf_home:
            hf_cache_dir = Path(hf_home)
            # Prefer the "hub" sub-folder of HF_HOME when it exists; fall
            # back to HF_HOME itself for setups pointing it at the hub dir.
            if (hf_cache_dir / 'hub').is_dir():
                hf_cache_dir = hf_cache_dir / 'hub'
        else:
            hf_cache_dir = Path.home() / '.cache/huggingface/hub'

        # Folders are named "models--<org>--<name>"; only the double dash is
        # a separator, single dashes belong to the model name.
        return sorted(
            repo_dir.name.split('--', 1)[-1].replace('--', '/')
            for repo_dir in hf_cache_dir.glob('models--*')
        )

    def list_downloaded_ollama_models(self):
        """Return the Ollama model folder names, or [] when unavailable."""
        username = getpass.getuser()
        base_path = f"C:\\Users\\{username}\\NEW_PATH\\manifests\\registry.ollama.ai\\library"

        try:
            with os.scandir(base_path) as entries:
                return [entry.name for entry in entries if entry.is_dir()]
        except FileNotFoundError:
            print(f"The directory {base_path} does not exist.")
        except PermissionError:
            print(f"Permission denied to access {base_path}.")
        except Exception as e:
            print(f"An error occurred: {e}")
        # Always hand the dropdown a list, even on failure.
        return []

    def model_settings(self, hfchoice, ollamachoice, tokensize):
        """Publish the chosen model settings through environment variables.

        Args:
            hfchoice: Selected retrieval model, or None.
            ollamachoice: Selected RAG model, or None.
            tokensize: Maximum tokens per response (numeric).

        Returns:
            The status string shown in the UI.
        """
        settings = (
            ('colpali', hfchoice),
            ('ollama', ollamachoice),
            ('tokens', tokensize),
        )
        for key, value in settings:
            # os.environ only accepts strings; unset selections are skipped.
            if value is not None:
                os.environ[key] = str(value)
        return "abc"
214
+
215
+
216
+
217
def create_ui():
    """Build the Gradio Blocks UI and wire its events to a PDFSearchApp.

    Returns:
        The constructed ``gr.Blocks`` demo (not yet launched).
    """
    app = PDFSearchApp()

    # display_file_list() returns a list on success and an error string on
    # failure; normalise both shapes for the widgets below.
    uploaded = app.display_file_list()
    if isinstance(uploaded, list):
        uploaded_names = uploaded
        uploaded_text = "\n".join(uploaded)
    else:
        uploaded_names = []
        uploaded_text = uploaded

    with gr.Blocks(css="footer{display:none !important}") as demo:
        state = gr.State(value={"user_uuid": None})

        gr.Markdown("# Collar Multimodal RAG Demo")
        gr.Markdown("Made by Collar")

        with gr.Tab("Upload PDF"):
            with gr.Column():
                max_pages_input = gr.Slider(
                    minimum=1,
                    maximum=10000,
                    value=20,
                    step=10,
                    label="Max pages to extract and index per document",
                )
                file_input = gr.Files(label="Upload PDFs")
                file_list = gr.Textbox(
                    label="Uploaded Files", interactive=False, value=uploaded_text
                )
                status = gr.Textbox(label="Indexing Status", interactive=False)

        with gr.Tab("Query"):
            with gr.Column():
                query_input = gr.Textbox(label="Enter query")
                search_btn = gr.Button("Query")
                llm_answer = gr.Textbox(label="RAG Response", interactive=False)
                path = gr.Textbox(label="Link To Document Page", interactive=False)
                images = gr.Image(label="Top page matching query")

        # Deletion of indexed documents/collections.
        with gr.Tab("Data Settings"):
            with gr.Column():
                choice = gr.Dropdown(uploaded_names, label="Choice")
                delete_button = gr.Button("Delete Document From DB")
                status1 = gr.Textbox(label="Deletion Status", interactive=False)

        # Selection of retrieval / generation models and reply length.
        with gr.Tab("AI Model Settings"):
            with gr.Column():
                hfchoice = gr.Dropdown(
                    app.list_downloaded_hf_models() or [],
                    label="Visual Document Retrieval (VDR) Model",
                )
                ollamachoice = gr.Dropdown(
                    app.list_downloaded_ollama_models() or [],
                    label="Secondary Visual Retrieval-Augmented Generation (RAG) Model",
                )
                tokensize = gr.Slider(
                    minimum=256,
                    maximum=4096,
                    # The default must lie inside [minimum, maximum].
                    value=256,
                    step=10,
                    label="Max tokens per response (Reply Length)",
                )
                model_button = gr.Button("Update Settings")
                status2 = gr.Textbox(label="Update Status", interactive=False)

        # Event handlers
        file_input.change(
            fn=app.upload_and_convert,
            inputs=[state, file_input, max_pages_input],
            outputs=[status],
        )

        # Searching works without a prior upload in this session because
        # previously indexed collections are queried directly.
        search_btn.click(
            fn=app.search_documents,
            inputs=[state, query_input],
            outputs=[path, images, llm_answer],
        )

        delete_button.click(
            fn=app.delete,
            inputs=[choice],
            outputs=[status1],
        )

        model_button.click(
            fn=app.model_settings,
            inputs=[hfchoice, ollamachoice, tokensize],
            outputs=[status2],
        )

    return demo
307
+
308
+ if __name__ == "__main__":
309
+ demo = create_ui()
310
+ demo.launch()
311
+
colpali_manager.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from colpali_engine.models import ColPali
2
+ from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
3
+ from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
4
+ from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
5
+ from torch.utils.data import DataLoader
6
+ import torch
7
+ from typing import List, cast
8
+
9
+ #from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
10
+ from colpali_engine.models import ColIdefics3, ColIdefics3Processor
11
+
12
+ from tqdm import tqdm
13
+ from PIL import Image
14
+ import os
15
+
16
+ import spaces
17
+
18
+
19
+ #this part is for local runs
20
+ torch.cpu.empty_cache()
21
+
22
+ model_name = "vidore/colSmol-256M"
23
+ device = get_torch_device("cpu") #try using cpu instead of cpu?
24
+
25
+ #switch to locally downloading models & loading locally rather than from hf
26
+ #
27
+
28
+ current_working_directory = os.getcwd()
29
+ save_directory = model_name # Directory to save the specific model name
30
+ save_directory = os.path.join(current_working_directory, save_directory)
31
+
32
+ processor_directory = 'local_processor' # Directory to save the processor
33
+ processor_directory = os.path.join(current_working_directory, processor_directory)
34
+
35
+
36
+ model = ColIdefics3.from_pretrained(
37
+ model_name,
38
+ torch_dtype=torch.bfloat16,
39
+ device_map=device,
40
+ #attn_implementation="flash_attention_2",
41
+ ).eval()
42
+ processor = cast(ColIdefics3Processor, ColIdefics3Processor.from_pretrained(model_name))
43
+
44
+ """
45
+ if not os.path.exists(save_directory): #download if directory not created/model not loaded
46
+ # Directory does not exist; create it
47
+ os.makedirs(save_directory)
48
+ print(f"Directory '{save_directory}' created.")
49
+ model = ColIdefics3.from_pretrained(
50
+ model_name,
51
+ torch_dtype=torch.bfloat16,
52
+ device_map=device,
53
+ attn_implementation="flash_attention_2",
54
+ ).eval()
55
+ model.save_pretrained(save_directory)
56
+ os.makedirs(processor_directory)
57
+ processor = cast(ColIdefics3Processor, ColIdefics3Processor.from_pretrained(model_name))
58
+
59
+ processor.save_pretrained(processor_directory)
60
+
61
+ else:
62
+ model = ColIdefics3.from_pretrained(save_directory)
63
+ processor = ColIdefics3.from_pretrained(processor_directory, use_fast=True)
64
+ """
65
+
66
+
67
class ColpaliManager:
    """Embed page images and text queries with the module-level model.

    The model and processor are loaded once at import time as the module
    globals ``model`` and ``processor``; instances carry no state.
    """

    def __init__(self, device="cpu", model_name="vidore/colSmol-256M"):
        # The arguments are only logged: the module-level model/processor
        # are what the methods below actually use.
        print(f"Initializing ColpaliManager with device {device} and model {model_name}")

    def _embed(self, dataloader, show_progress=False):
        """Run the model over every batch and return one numpy array per item."""
        embeddings: List[torch.Tensor] = []
        batches = tqdm(dataloader) if show_progress else dataloader
        for batch in batches:
            with torch.no_grad():
                batch = {k: v.to(model.device) for k, v in batch.items()}
                output = model(**batch)
            # Split the batch dimension so each item gets its own tensor.
            embeddings.extend(list(torch.unbind(output.to(device))))

        # bfloat16 has no numpy equivalent, hence the cast to float32 first.
        return [e.float().cpu().numpy() for e in embeddings]

    @spaces.GPU
    def get_images(self, paths: list[str]) -> List[Image.Image]:
        """Open every path in *paths* as a PIL image."""
        model.to("cpu")
        return [Image.open(path) for path in paths]

    @spaces.GPU
    def process_images(self, image_paths: list[str], batch_size=5):
        """Return one multi-vector embedding (numpy array) per image path."""
        model.to("cpu")
        print(f"Processing {len(image_paths)} image_paths")

        images = self.get_images(image_paths)

        dataloader = DataLoader(
            dataset=ListDataset[Image.Image](images),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda x: processor.process_images(x),
        )

        return self._embed(dataloader, show_progress=True)

    @spaces.GPU
    def process_text(self, texts: list[str]):
        """Return one multi-vector embedding (numpy array) per query text."""
        model.to("cpu")
        print(f"Processing {len(texts)} texts")

        dataloader = DataLoader(
            dataset=ListDataset[str](texts),
            batch_size=5,
            shuffle=False,
            collate_fn=lambda x: processor.process_queries(x),
        )

        qs_np = self._embed(dataloader)
        # Keep the model on the CPU afterwards so the LLM call that follows
        # a query can use the accelerator.
        model.to("cpu")

        return qs_np
139
+
140
+
141
+
middleware.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from colpali_manager import ColpaliManager
2
+ from milvus_manager import MilvusManager
3
+ from pdf_manager import PdfManager
4
+ import hashlib
5
+
6
+
7
+
8
+ pdf_manager = PdfManager()
9
+ colpali_manager = ColpaliManager()
10
+
11
+
12
+
13
class Middleware:
    """Glue between PDF rasterisation, embedding and the Milvus store."""

    def __init__(self, id: str, create_collection=True):
        """Open the shared Milvus database with *id* as the collection name.

        Args:
            id: Collection name for this document.
            create_collection: Create the collection when it does not exist.
        """
        # A single persistent database is shared by all documents; each
        # document gets its own collection named after *id*.
        hashed_id = 0
        milvus_db_name = f"milvus_{hashed_id}.db"
        self.milvus_manager = MilvusManager(milvus_db_name, id, create_collection)

    def index(self, pdf_path: str, id: str, max_pages: int, pages: list[int] = None):
        """Rasterise a PDF, embed its pages and insert them into Milvus.

        Args:
            pdf_path: Path of the PDF, or None when there is nothing to index.
            id: Document id; also the folder name under ``pages/``.
            max_pages: Maximum number of pages to process.
            pages: Optional zero-based page indices to restrict indexing to.

        Returns:
            The list of saved page image paths (empty when *pdf_path* is None).
        """
        # `type(x) == None` is never true; test the value itself.
        if pdf_path is None:
            print("no docs")
            return []

        print(f"Indexing {pdf_path}, id: {id}, max_pages: {max_pages}")

        image_paths = pdf_manager.save_images(id, pdf_path, max_pages, pages)

        print(f"Saved {len(image_paths)} images")

        colbert_vecs = colpali_manager.process_images(image_paths)

        images_data = [
            {"colbert_vecs": vecs, "filepath": filepath}
            for vecs, filepath in zip(colbert_vecs, image_paths)
        ]

        print(f"Inserting {len(images_data)} images data to Milvus")

        self.milvus_manager.insert_images_data(images_data)

        print("Indexing completed")

        return image_paths

    def search(self, search_queries: list[str]):
        """Return, for each query, the top-ranked results from Milvus.

        Args:
            search_queries: Query strings to embed and search.

        Returns:
            One result list per query, in input order.
        """
        print(f"Searching for {len(search_queries)} queries")

        final_res = []

        for query in search_queries:
            print(f"Searching for query: {query}")
            query_vec = colpali_manager.process_text([query])[0]
            search_res = self.milvus_manager.search(query_vec, topk=1)
            print(f"Search result: {search_res} for query: {query}")
            final_res.append(search_res)

        return final_res
62
+
milvus_manager.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pymilvus import MilvusClient, DataType
2
+ import numpy as np
3
+ import concurrent.futures
4
+ from pymilvus import Collection
5
+
6
class MilvusManager:
    """Thin wrapper around a Milvus collection of multi-vector page embeddings."""

    def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
        """Connect to Milvus and load or create the collection.

        Args:
            milvus_uri: URI (or local database file) passed to MilvusClient.
            collection_name: Collection this manager inserts into.
            create_collection: Create collection and index when missing.
            dim: Dimension of each embedding vector.
        """
        self.client = MilvusClient(uri=milvus_uri)
        self.collection_name = collection_name
        self.dim = dim

        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(collection_name=self.collection_name)
            print("Loaded existing collection.")
        elif create_collection:
            self.create_collection()
            self.create_index()

    def create_collection(self):
        """Create the collection with its schema unless it already exists."""
        if self.client.has_collection(collection_name=self.collection_name):
            print("Collection already exists.")
            return

        # NOTE(review): pymilvus may expect `enable_dynamic_field` (singular);
        # left unchanged here - confirm against the installed version.
        schema = self.client.create_schema(
            auto_id=True,
            enable_dynamic_fields=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        schema.add_field(field_name="seq_id", datatype=DataType.INT16)
        schema.add_field(field_name="doc_id", datatype=DataType.INT64)
        schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)

        self.client.create_collection(
            collection_name=self.collection_name, schema=schema
        )

    def create_index(self):
        """Create an HNSW inner-product index on the vector field."""
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_name="vector_index",
            index_type="HNSW",
            metric_type="IP",
            params={
                "M": 16,
                "efConstruction": 500,
            },
        )

        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def _rerank_single_doc(self, doc_id, data, collection_name):
        """Score one candidate document against the query vectors.

        Returns:
            ``(score, doc_id, collection_name)``, or None when the document
            has no stored vectors.
        """
        # NOTE(review): the filter also pulls in the vectors of doc_id + 1,
        # which mixes the following page into this page's score - confirm
        # whether that is intended.
        doc_colbert_vecs = self.client.query(
            collection_name=collection_name,
            filter=f"doc_id in [{doc_id}, {doc_id + 1}]",
            output_fields=["seq_id", "vector", "doc"],
            limit=16380,
        )
        if not doc_colbert_vecs:
            # np.vstack raises on an empty sequence.
            return None

        doc_vecs = np.vstack([row["vector"] for row in doc_colbert_vecs])
        # For every query vector take its best-matching document vector,
        # then sum those maxima.
        score = np.dot(data, doc_vecs.T).max(1).sum()
        return (score, doc_id, collection_name)

    def search(self, data, topk):
        """Search every collection and return the best-scoring documents.

        Args:
            data: Query embedding, one row per query token vector.
            topk: Maximum number of results to return.

        Returns:
            A list of ``(score, doc_id, collection_name)`` tuples sorted by
            descending score; doc_id is the zero-based page index.
        """
        collections = self.client.list_collections()
        search_params = {"metric_type": "IP", "params": {}}

        # Unique (doc_id, collection_name) candidates across all collections.
        doc_collection_pairs = set()
        loaded = []
        scores = []

        try:
            for collection in collections:
                self.client.load_collection(collection_name=collection)
                loaded.append(collection)
                print("collection loaded:" + collection)
                results = self.client.search(
                    collection,
                    data,
                    limit=50,
                    output_fields=["vector", "seq_id", "doc_id"],
                    search_params=search_params,
                )
                for hits in results:
                    for hit in hits:
                        doc_collection_pairs.add((hit["entity"]["doc_id"], collection))

            # Rerank the candidates concurrently.
            with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
                futures = [
                    executor.submit(self._rerank_single_doc, doc_id, data, collection)
                    for doc_id, collection in doc_collection_pairs
                ]
                for future in concurrent.futures.as_completed(futures):
                    result = future.result()
                    if result is not None:
                        scores.append(result)
        finally:
            # Release every collection that was loaded, not only the last
            # one iterated (and nothing at all when there are none).
            for collection in loaded:
                self.client.release_collection(collection_name=collection)

        scores.sort(key=lambda x: x[0], reverse=True)
        return scores[:topk]

    def insert(self, data):
        """Insert one document's vectors, one row per vector.

        Args:
            data: Dict with "colbert_vecs", "doc_id" and "filepath".
        """
        colbert_vecs = data["colbert_vecs"]
        seq_length = len(colbert_vecs)
        if seq_length == 0:
            return

        rows = [
            {
                "vector": colbert_vecs[i],
                "seq_id": i,
                "doc_id": data["doc_id"],
                # The file path is stored on the first row only.
                "doc": data["filepath"] if i == 0 else "",
            }
            for i in range(seq_length)
        ]
        self.client.insert(self.collection_name, rows)

    def get_images_as_doc(self, images_with_vectors):
        """Attach a sequential doc_id (the list position) to each image record."""
        return [
            {
                "colbert_vecs": image["colbert_vecs"],
                "doc_id": idx,
                "filepath": image["filepath"],
            }
            for idx, image in enumerate(images_with_vectors)
        ]

    def insert_images_data(self, image_data):
        """Insert every image record, assigning doc_ids by position."""
        for item in self.get_images_as_doc(image_data):
            self.insert(item)
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ poppler-utils
pdf_manager.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pdf2image import convert_from_path
2
+ import os
3
+ import shutil
4
+
5
class PdfManager:
    """Rasterise PDF pages into PNG files under ``pages/<id>/``."""

    def __init__(self):
        pass

    def clear_and_recreate_dir(self, output_folder):
        """Leave *output_folder* existing and empty."""
        print(f"Clearing output folder {output_folder}")

        if os.path.exists(output_folder):
            shutil.rmtree(output_folder)
        # Always recreate: removing an existing folder without recreating it
        # makes the subsequent image.save() calls fail on re-upload.
        os.makedirs(output_folder)

    def save_images(self, id, pdf_path, max_pages, pages: list[int] = None) -> list[str]:
        """Convert a PDF to PNG page images and return the saved paths.

        Args:
            id: Document id; images are written to ``pages/<id>/``.
            pdf_path: Path of the PDF to convert.
            max_pages: Maximum number of pages to save (falsy = no limit).
            pages: Optional zero-based page indices to keep.

        Returns:
            Paths of the files actually written, in page order.
        """
        output_folder = f"pages/{id}"

        # Without a page filter only the first max_pages pages are ever
        # saved, so there is no need to rasterise the rest of the PDF.
        convert_kwargs = {}
        if max_pages and not pages:
            convert_kwargs["last_page"] = int(max_pages)
        images = convert_from_path(pdf_path, **convert_kwargs)

        print(f"Saving images from {pdf_path} to {output_folder}. Max pages: {max_pages}")

        self.clear_and_recreate_dir(output_folder)

        saved_paths = []
        for i, image in enumerate(images):
            if max_pages and len(saved_paths) >= max_pages:
                break

            if pages and i not in pages:
                continue

            full_save_path = f"{output_folder}/page_{i + 1}.png"
            image.save(full_save_path, "PNG")
            # Record the real file name: with a page filter the saved
            # numbers are not simply 1..N.
            saved_paths.append(full_save_path)

        return saved_paths
rag.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import os
3
+
4
+ from typing import List
5
+ from utils import encode_image
6
+ from PIL import Image
7
+ from ollama import chat
8
+ import torch
9
+ import subprocess
10
+ import psutil
11
+ import torch
12
+ from transformers import AutoModel, AutoTokenizer
13
+ import google.generativeai as genai
14
+
15
+
16
+
17
class Rag:
    """Answer a query about page images with a multimodal LLM."""

    def get_answer_from_gemini(self, query, imagePaths):
        """Send the query and page images to Gemini and return its answer.

        The API key is read from the GEMINI_API_KEY (or GOOGLE_API_KEY)
        environment variable.

        Args:
            query: The user's question.
            imagePaths: Paths of the page images to attach.

        Returns:
            The answer text, or a string starting with "Error:" on failure.
        """
        print(f"Querying Gemini for query={query}, imagePaths={imagePaths}")

        # NOTE: a key used to be hard-coded here and was committed to version
        # control; that key should be revoked and a new one supplied through
        # the environment.
        api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        if not api_key:
            message = "GEMINI_API_KEY environment variable is not set"
            print(f"An error occurred while querying Gemini: {message}")
            return f"Error: {message}"

        images = []
        try:
            genai.configure(api_key=api_key)
            model = genai.GenerativeModel('gemini-1.5-flash')

            for path in imagePaths:
                images.append(Image.open(path))

            # Named `session` so it does not shadow the imported ollama `chat`.
            session = model.start_chat()
            response = session.send_message([*images, query])

            answer = response.text

            print(answer)

            return answer

        except Exception as e:
            print(f"An error occurred while querying Gemini: {e}")
            return f"Error: {str(e)}"
        finally:
            # Release the file handles held by the opened images.
            for image in images:
                image.close()

    def get_answer_from_openai(self, query, imagesPaths):
        """Ask a local Ollama vision model about the given page images.

        Despite its name this method calls Ollama's ``chat`` with the
        ``minicpm-v:8b-2.6-q8_0`` model.

        Args:
            query: The user's question.
            imagesPaths: Paths of the page images to attach.

        Returns:
            The answer text, or None when the call fails.
        """
        # Release cached CUDA memory so Ollama can use the GPU.
        torch.cuda.empty_cache()

        os.environ['OLLAMA_FLASH_ATTENTION'] = '1'

        print(f"Querying OpenAI for query={query}, imagesPaths={imagesPaths}")

        try:
            response = chat(
                model='minicpm-v:8b-2.6-q8_0',
                messages=[
                    {
                        'role': 'user',
                        'content': query,
                        'images': imagesPaths,
                    }
                ],
            )

            answer = response.message.content

            print(answer)

            return answer

        except Exception as e:
            print(f"An error occurred while querying OpenAI: {e}")
            return None

    def __get_openai_api_payload(self, query: str, imagesPaths: List[str]):
        """Build an OpenAI-style chat payload with the images inlined as data URLs.

        Args:
            query: The user's question.
            imagesPaths: Paths of the images to embed as base64.

        Returns:
            The request payload dict.
        """
        import mimetypes

        image_payload = []

        for imagePath in imagesPaths:
            base64_image = encode_image(imagePath)
            # Label the data URL with the file's real type (pages are saved
            # as PNG); fall back to JPEG when the type cannot be guessed.
            mime_type = mimetypes.guess_type(imagePath)[0] or "image/jpeg"
            image_payload.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:{mime_type};base64,{base64_image}"
                }
            })

        payload = {
            "model": "Llama3.2-vision",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query
                        },
                        *image_payload
                    ]
                }
            ],
            # A modest limit keeps processing time down.
            "max_tokens": 1024
        }

        return payload
138
+
139
+
140
+
141
+ # if __name__ == "__main__":
142
+ # rag = Rag()
143
+
144
+ # query = "Based on attached images, how many new cases were reported during second wave peak"
145
+ # imagesPaths = ["covid_slides_page_8.png", "covid_slides_page_8.png"]
146
+
147
+ # rag.get_answer_from_gemini(query, imagesPaths)
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ PyMuPDF
3
+ pdf2image
4
+ pymilvus
5
+ tqdm
6
+ pillow
7
+ spaces
8
+ google-generativeai
9
+ git+https://github.com/illuin-tech/colpali
10
+ timm==1.0.13
11
+ transformers
12
+ https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp311-cp311-win_amd64.whl
13
+ comtypes
14
+ python-dotenv
test.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pymilvus import MilvusClient
2
+ from pymilvus import (
3
+ connections,
4
+ utility,
5
+ FieldSchema, CollectionSchema, DataType,
6
+ Collection,
7
+ )
8
+
9
+ # 1. Create a milvus client
10
+ client = MilvusClient(
11
+ uri="http://localhost:19530",
12
+ token="root:Milvus"
13
+ )
14
+
15
+ # 2. Drop a collection
16
+ client.drop_collection(collection_name="fy2025_budget_statement")
17
+
18
+ # 3. List collections
19
+ print(client.list_collections() )
20
+
21
+ # ['test_collection']
22
+
23
+ """
24
+ res = client.get(
25
+ collection_name="colpali",
26
+ ids=[0, 1, 2],
27
+ )
28
+
29
+ print(res)
30
+ """
uploaded_files.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ EMERGING_MISSILE_THREATS_16382509
2
+ handwriting
3
+ multimediareport
utils.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import base64
2
+
3
def encode_image(image_path):
    """Return the contents of the file at *image_path* as base64 text."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')