Kazel
committed on
Commit
·
065b6ad
1
Parent(s):
4fc6794
all
Browse files- app.py +311 -0
- colpali_manager.py +141 -0
- middleware.py +62 -0
- milvus_manager.py +195 -0
- packages.txt +1 -0
- pdf_manager.py +46 -0
- rag.py +147 -0
- requirements.txt +14 -0
- test.py +30 -0
- uploaded_files.txt +3 -0
- utils.py +5 -0
app.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import tempfile
|
3 |
+
import os
|
4 |
+
import fitz # PyMuPDF
|
5 |
+
import uuid
|
6 |
+
import shutil
|
7 |
+
from pymilvus import MilvusClient
|
8 |
+
|
9 |
+
from middleware import Middleware
|
10 |
+
from rag import Rag
|
11 |
+
from pathlib import Path
|
12 |
+
import subprocess
|
13 |
+
import getpass
|
14 |
+
|
15 |
+
# Module-level RAG client; PDFSearchApp.search_documents calls its get_answer_from_gemini() with the query and the retrieved page image.
rag = Rag()
|
16 |
+
|
17 |
+
|
18 |
+
def generate_uuid(state):
    """Return the UUID stored in the session state, creating it on first use.

    Args:
        state: Mutable per-session dict. The UUID is kept under the
            ``"user_uuid"`` key and is written back into this dict.

    Returns:
        The UUID string stored in ``state["user_uuid"]``.
    """
    # ``.get`` tolerates a state dict created without the key; direct indexing
    # raised KeyError in that case.
    if state.get("user_uuid") is None:
        state["user_uuid"] = str(uuid.uuid4())

    return state["user_uuid"]
|
25 |
+
|
26 |
+
|
27 |
+
class PDFSearchApp:
    """Callbacks behind the Gradio UI: upload/index, query, delete, settings.

    Page images are written under ``pages/<collection>/`` and the matching
    embeddings are stored in a Milvus collection with the same name.
    """

    # Must match the database file that Middleware opens (``milvus_0.db``);
    # can be overridden with the MILVUS_URI environment variable.
    MILVUS_URI = "milvus_0.db"

    def __init__(self):
        self.indexed_docs = {}
        self.current_pdf = None

    @staticmethod
    def _collection_name(name):
        """Turn a file stem into a name Milvus accepts as a collection name.

        Every character other than an ASCII letter, digit or underscore is
        replaced by ``_`` (this covers the spaces and hyphens handled before),
        and a leading underscore is added when the name would otherwise be
        empty or start with a digit.
        """
        cleaned = "".join(
            ch if (ch.isascii() and ch.isalnum()) or ch == "_" else "_"
            for ch in name
        )
        if not cleaned or cleaned[0].isdigit():
            cleaned = "_" + cleaned
        return cleaned

    def upload_and_convert(self, state, files, max_pages):
        """Index every uploaded file and report the total number of pages.

        Args:
            state: Session state dict (currently unused).
            files: Uploaded files from ``gr.Files``; ``None`` or empty when
                nothing is selected. Items may expose the path via ``.name``
                or be plain path strings.
            max_pages: Maximum number of pages to index per document.

        Returns:
            A status string for the "Indexing Status" textbox.
        """
        if not files:
            return "No file uploaded"

        total_pages = 0
        try:
            for file in files:
                pdf_path = getattr(file, "name", file)
                name, ext = os.path.splitext(os.path.basename(pdf_path))

                if ext.lower() in (".ppt", ".pptx"):
                    # PowerPoint conversion needed COM automation, which is not
                    # available here. Skip the file instead of feeding a
                    # non-PDF to the PDF converter.
                    print("pptx not supported on spaces")
                    continue

                self.current_pdf = pdf_path
                collection = self._collection_name(name)

                print(f"Uploading file: {pdf_path}, id: {collection}")
                middleware = Middleware(collection, create_collection=True)
                pages = middleware.index(pdf_path, id=collection, max_pages=max_pages)

                # Accumulate across files; index() may return None.
                total_pages += len(pages or [])
                self.indexed_docs[collection] = True

            return f"Uploaded and extracted {total_pages} pages"
        except Exception as e:
            return f"Error processing PDF: {str(e)}"

    def display_file_list(self):
        """List indexed documents, i.e. the sub-directories of ``./pages``.

        Returns:
            A list of directory names on success, or a string describing the
            problem when the directory cannot be read.
        """
        directory_path = os.path.join(os.getcwd(), "pages")
        try:
            return [
                entry
                for entry in os.listdir(directory_path)
                if os.path.isdir(os.path.join(directory_path, entry))
            ]
        except FileNotFoundError:
            return f"The directory {directory_path} does not exist."
        except PermissionError:
            return f"Permission denied to access {directory_path}."
        except Exception as e:
            return str(e)

    def search_documents(self, state, query, num_results=1):
        """Find the best matching page for *query* and ask the LLM about it.

        Args:
            state: Session state dict (currently unused).
            query: Free-text question.
            num_results: Currently unused; only the top hit is returned.

        Returns:
            A 3-tuple ``(path, image_path, answer)`` matching the three
            outputs wired in the UI. On failure the first element carries the
            message, the image is ``None`` and the answer is ``"--"``.
        """
        print(f"Searching for query: {query}")

        if not query:
            print("Please enter a search query")
            return "Please enter a search query", None, "--"

        try:
            # The collection name is a placeholder: MilvusManager.search
            # iterates over every collection in the database.
            middleware = Middleware("test", create_collection=False)
            search_results = middleware.search([query])[0]

            if not search_results:
                return "No indexed pages matched. Please index documents first", None, "--"

            # Each hit is (score, doc_id, collection_name); doc_id is the
            # zero-based page index while page files are numbered from 1.
            _score, doc_id, coll_num = search_results[0]
            page_num = doc_id + 1

            print(f"Retrieved page number: {page_num}")

            img_path = f"pages/{coll_num}/page_{page_num}.png"
            path = f"pages/{coll_num}/page_{page_num}"

            print(f"Retrieved image path: {img_path}")

            rag_response = rag.get_answer_from_gemini(query, [img_path])

            return path, img_path, rag_response
        except Exception as e:
            return f"Error during search: {str(e)}", None, "--"

    def delete(self, choice):
        """Delete a document's page images and its Milvus collection.

        Args:
            choice: Document/collection name selected in the dropdown.

        Returns:
            A status string for the "Deletion Status" textbox.
        """
        if not choice:
            return "No document selected"

        # ``choice`` ends up in a path handed to shutil.rmtree, so reject
        # anything that is not a plain directory name.
        if os.path.basename(choice) != choice or choice in (".", ".."):
            return "Invalid document name"

        path = f"pages/{choice}"
        if not os.path.exists(path):
            return "Directory not found"

        try:
            # Use the same database as the indexing path so the collection
            # that was created is the one that gets dropped.
            client = MilvusClient(uri=os.getenv("MILVUS_URI", self.MILVUS_URI))
            # Drop the collection first: if this fails the page folder is
            # still there and the deletion can be retried from the UI.
            client.drop_collection(collection_name=choice)
            shutil.rmtree(path)
        except Exception as e:
            return f"Error deleting {choice}: {str(e)}"
        return f"Deleted {choice}"

    def list_downloaded_hf_models(self):
        """Return the repo ids of models found in the Hugging Face hub cache.

        The cache location is ``$HF_HUB_CACHE`` when set, else ``$HF_HOME/hub``
        (``$HF_HOME`` itself is also scanned, as before), else
        ``~/.cache/huggingface/hub``.
        """
        hub_cache = os.getenv("HF_HUB_CACHE")
        hf_home = os.getenv("HF_HOME")
        if hub_cache:
            cache_dirs = [Path(hub_cache)]
        elif hf_home:
            cache_dirs = [Path(hf_home) / "hub", Path(hf_home)]
        else:
            cache_dirs = [Path.home() / ".cache/huggingface/hub"]

        model_names = []
        for cache_dir in cache_dirs:
            for repo_dir in cache_dir.glob("models--*"):
                # Folders are named ``models--<org>--<name>``. Only the double
                # dash is a separator; single dashes belong to the model name.
                model_name = repo_dir.name.split("--", 1)[-1].replace("--", "/")
                if model_name not in model_names:
                    model_names.append(model_name)

        return model_names

    def list_downloaded_ollama_models(self):
        """Return the names of locally available Ollama library models.

        Candidate model roots are tried in order: ``$OLLAMA_MODELS``, the
        original hard-coded Windows location, and ``~/.ollama/models``. The
        first root whose manifest library can be listed wins.

        Returns:
            A list of model names; empty when no manifest directory is readable.
        """
        roots = []
        if os.getenv("OLLAMA_MODELS"):
            roots.append(os.environ["OLLAMA_MODELS"])
        try:
            roots.append(f"C:\\Users\\{getpass.getuser()}\\NEW_PATH")
        except Exception as e:
            # getuser() can fail when no user name can be determined.
            print(f"An error occurred: {e}")
        roots.append(os.path.join(os.path.expanduser("~"), ".ollama", "models"))

        for root in roots:
            base_path = os.path.join(root, "manifests", "registry.ollama.ai", "library")
            try:
                with os.scandir(base_path) as entries:
                    return [entry.name for entry in entries if entry.is_dir()]
            except FileNotFoundError:
                print(f"The directory {base_path} does not exist.")
            except PermissionError:
                print(f"Permission denied to access {base_path}.")
            except Exception as e:
                print(f"An error occurred: {e}")

        # Always hand the UI a list, never None.
        return []

    def model_settings(self, hfchoice, ollamachoice, tokensize):
        """Publish the chosen model settings through environment variables.

        Values left unset (``None``) in the UI are skipped. ``os.environ`` only
        accepts strings, so every value is converted explicitly.

        Returns:
            A status string for the "Update Status" textbox.
        """
        if hfchoice is not None:
            os.environ['colpali'] = str(hfchoice)
        if ollamachoice is not None:
            os.environ['ollama'] = str(ollamachoice)
        if tokensize is not None:
            # The slider yields a number; store a plain integer string.
            os.environ['tokens'] = str(int(tokensize))
        return "Settings updated"
|
214 |
+
|
215 |
+
|
216 |
+
|
217 |
+
def create_ui():
    """Build the Gradio Blocks app and wire its events to ``PDFSearchApp``.

    Returns:
        The ``gr.Blocks`` instance, ready for ``launch()``.
    """
    app = PDFSearchApp()

    # display_file_list() returns a list on success and an error *string* on
    # failure. Calling list() on that string would fill the dropdown with
    # single characters, so the two cases are separated here.
    indexed_docs = app.display_file_list()
    if isinstance(indexed_docs, str):
        doc_choices = []
        file_list_text = indexed_docs
    else:
        doc_choices = list(indexed_docs)
        file_list_text = ", ".join(doc_choices)

    # Either lookup may yield None when nothing can be listed.
    hf_models = app.list_downloaded_hf_models() or []
    ollama_models = app.list_downloaded_ollama_models() or []

    with gr.Blocks(css="footer{display:none !important}") as demo:
        state = gr.State(value={"user_uuid": None})

        gr.Markdown("# Collar Multimodal RAG Demo")
        gr.Markdown("Made by Collar")

        with gr.Tab("Upload PDF"):
            with gr.Column():
                max_pages_input = gr.Slider(
                    minimum=1,
                    maximum=10000,
                    value=20,
                    step=10,
                    label="Max pages to extract and index per document"
                )
                file_input = gr.Files(label="Upload PDFs")
                file_list = gr.Textbox(label="Uploaded Files", interactive=False, value=file_list_text)
                status = gr.Textbox(label="Indexing Status", interactive=False)

        with gr.Tab("Query"):
            with gr.Column():
                query_input = gr.Textbox(label="Enter query")
                search_btn = gr.Button("Query")
                llm_answer = gr.Textbox(label="RAG Response", interactive=False)
                path = gr.Textbox(label="Link To Document Page", interactive=False)
                images = gr.Image(label="Top page matching query")

        with gr.Tab("Data Settings"):  # deletion of indexed documents
            with gr.Column():
                choice = gr.Dropdown(doc_choices, label="Choice")
                delete_button = gr.Button("Delete Document From DB")
                status1 = gr.Textbox(label="Deletion Status", interactive=False)

        with gr.Tab("AI Model Settings"):  # model selection and reply length
            with gr.Column():
                hfchoice = gr.Dropdown(hf_models, label="Visual Document Retrieval (VDR) Model")
                ollamachoice = gr.Dropdown(ollama_models, label="Secondary Visual Retrieval-Augmented Generation (RAG) Model")
                tokensize = gr.Slider(
                    minimum=256,
                    maximum=4096,
                    # The previous default (20) was below the slider minimum.
                    value=256,
                    step=10,
                    label="Max tokens per response (Reply Length)"
                )
                model_button = gr.Button("Update Settings")
                status2 = gr.Textbox(label="Update Status", interactive=False)

        # Event handlers
        file_input.change(
            fn=app.upload_and_convert,
            inputs=[state, file_input, max_pages_input],
            outputs=[status]
        )

        search_btn.click(
            fn=app.search_documents,
            inputs=[state, query_input],
            outputs=[path, images, llm_answer]
        )

        delete_button.click(
            fn=app.delete,
            inputs=[choice],
            outputs=[status1]
        )

        model_button.click(
            fn=app.model_settings,
            inputs=[hfchoice, ollamachoice, tokensize],
            outputs=[status2]
        )

    return demo
|
307 |
+
|
308 |
+
if __name__ == "__main__":
    # Build the interface and start the Gradio server when run as a script.
    demo = create_ui()
    demo.launch()
|
311 |
+
|
colpali_manager.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from colpali_engine.models import ColPali
|
2 |
+
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
|
3 |
+
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
|
4 |
+
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
import torch
|
7 |
+
from typing import List, cast
|
8 |
+
|
9 |
+
#from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
|
10 |
+
from colpali_engine.models import ColIdefics3, ColIdefics3Processor
|
11 |
+
|
12 |
+
from tqdm import tqdm
|
13 |
+
from PIL import Image
|
14 |
+
import os
|
15 |
+
|
16 |
+
import spaces
|
17 |
+
|
18 |
+
|
19 |
+
#this part is for local runs
|
20 |
+
# NOTE(review): ``torch.cpu`` does not appear to document an ``empty_cache``
# helper (this looks like a cuda -> cpu search/replace). The call is guarded so
# that importing this module cannot fail with AttributeError; confirm against
# the installed torch version and drop it if it is indeed a no-op.
_empty_cpu_cache = getattr(torch.cpu, "empty_cache", None)
if callable(_empty_cpu_cache):
    _empty_cpu_cache()

model_name = "vidore/colSmol-256M"
device = get_torch_device("cpu")

# switch to locally downloading models & loading locally rather than from hf
# These directories are only referenced by the disabled caching draft below.
current_working_directory = os.getcwd()
save_directory = os.path.join(current_working_directory, model_name)  # directory for the saved model
processor_directory = os.path.join(current_working_directory, 'local_processor')  # directory for the saved processor

# Model and processor are loaded once at import time and shared by every
# ColpaliManager instance.
model = ColIdefics3.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
    # attn_implementation="flash_attention_2",
).eval()
processor = cast(ColIdefics3Processor, ColIdefics3Processor.from_pretrained(model_name))

# Disabled draft: cache the model/processor on disk and reload them locally.
# Kept as comments (it used to be a bare string literal) for reference.
#
# if not os.path.exists(save_directory):  # download if model not saved yet
#     os.makedirs(save_directory)
#     print(f"Directory '{save_directory}' created.")
#     model = ColIdefics3.from_pretrained(
#         model_name,
#         torch_dtype=torch.bfloat16,
#         device_map=device,
#         attn_implementation="flash_attention_2",
#     ).eval()
#     model.save_pretrained(save_directory)
#     os.makedirs(processor_directory)
#     processor = cast(ColIdefics3Processor, ColIdefics3Processor.from_pretrained(model_name))
#     processor.save_pretrained(processor_directory)
# else:
#     model = ColIdefics3.from_pretrained(save_directory)
#     # NOTE: the draft loaded the processor with ColIdefics3.from_pretrained;
#     # it must be ColIdefics3Processor if this branch is ever re-enabled.
#     processor = ColIdefics3Processor.from_pretrained(processor_directory, use_fast=True)
|
65 |
+
|
66 |
+
|
67 |
+
class ColpaliManager:
    """Embed page images and text queries with the module-level model.

    Every method uses the ``model``, ``processor`` and ``device`` objects
    created when this module is imported; nothing is loaded per instance.
    """

    def __init__(self, device = "cpu", model_name = "vidore/colSmol-256M"): #need to hot potato/use diff gpus between colpali & ollama
        """Record the requested configuration.

        The arguments do not select the model that is used for embedding (the
        module-level ``model``/``processor`` are always used); they are stored
        so the requested values are at least inspectable.
        """
        print(f"Initializing ColpaliManager with device {device} and model {model_name}")
        self.device = device
        self.model_name = model_name

    @staticmethod
    def _embed_batches(batches):
        """Run the model over pre-processed batches and return numpy arrays.

        Args:
            batches: Iterable of dicts mapping input names to tensors, as
                produced by the processor's collate functions.

        Returns:
            One float32 numpy array per input item (the batch dimension is
            unbound so each item gets its own array).
        """
        outputs: List[torch.Tensor] = []
        for batch in batches:
            with torch.no_grad():
                batch = {k: v.to(model.device) for k, v in batch.items()}
                embeddings = model(**batch)
            outputs.extend(torch.unbind(embeddings.to(device)))

        # bfloat16 cannot be converted to numpy directly, hence .float().
        return [t.float().cpu().numpy() for t in outputs]

    @spaces.GPU
    def get_images(self, paths: list[str]) -> List[Image.Image]:
        """Open every path in *paths* as a PIL image (lazily loaded)."""
        model.to("cpu")
        return [Image.open(path) for path in paths]

    @spaces.GPU
    def process_images(self, image_paths:list[str], batch_size=5):
        """Embed the page images stored at *image_paths*.

        Args:
            image_paths: Paths of the page images to embed.
            batch_size: Number of images per forward pass.

        Returns:
            A list with one numpy embedding array per image, in input order.
        """
        model.to("cpu")
        print(f"Processing {len(image_paths)} image_paths")

        images = self.get_images(image_paths)

        dataloader = DataLoader(
            # The dataset holds PIL images, not strings.
            dataset=ListDataset[Image.Image](images),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda x: processor.process_images(x),
        )

        return self._embed_batches(tqdm(dataloader))

    @spaces.GPU
    def process_text(self, texts: list[str]):
        """Embed the query strings in *texts*.

        Returns:
            A list with one numpy embedding array per query, in input order.
        """
        model.to("cpu")
        print(f"Processing {len(texts)} texts")

        dataloader = DataLoader(
            dataset=ListDataset[str](texts),
            batch_size=5,
            shuffle=False,
            collate_fn=lambda x: processor.process_queries(x),
        )

        qs_np = self._embed_batches(dataloader)
        # Kept from the original: move the model back to the CPU once the
        # queries are embedded, before the follow-up LLM call.
        model.to("cpu")

        return qs_np
|
139 |
+
|
140 |
+
|
141 |
+
|
middleware.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from colpali_manager import ColpaliManager
|
2 |
+
from milvus_manager import MilvusManager
|
3 |
+
from pdf_manager import PdfManager
|
4 |
+
import hashlib
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
# Helper instances created once at import time and shared by every Middleware.
pdf_manager = PdfManager()
colpali_manager = ColpaliManager()
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
class Middleware:
    """Glue between PDF rasterisation, embedding and the Milvus store."""

    def __init__(self, id:str, create_collection=True):
        """Open the shared Milvus database with *id* as the collection name.

        Args:
            id: Collection name for this document.
            create_collection: Create the collection (and its index) when it
                does not exist yet.
        """
        #hashed_id = hashlib.md5(id.encode()).hexdigest()[:8]
        hashed_id = 0 #switched to persistent db, shld use diff id for diff accs
        milvus_db_name = f"milvus_{hashed_id}.db"
        self.milvus_manager = MilvusManager(milvus_db_name, id, create_collection) #create collections based on id rather than colpali

    def index(self, pdf_path: str, id:str, max_pages: int, pages: list[int] = None):
        """Rasterise a PDF, embed its pages and store them in Milvus.

        Args:
            pdf_path: Path of the PDF to index; ``None`` means nothing to do.
            id: Document id, used as the page-image folder name.
            max_pages: Maximum number of pages to process.
            pages: Accepted for interface compatibility; currently unused.

        Returns:
            The list of saved page-image paths (empty when *pdf_path* is None).
        """
        # ``type(pdf_path) == None`` was always False, so this guard never
        # fired. Return an empty list so callers can still take len().
        if pdf_path is None: #for direct query without any upload to db
            print("no docs")
            return []

        print(f"Indexing {pdf_path}, id: {id}, max_pages: {max_pages}")

        image_paths = pdf_manager.save_images(id, pdf_path, max_pages)

        print(f"Saved {len(image_paths)} images")

        colbert_vecs = colpali_manager.process_images(image_paths)

        if len(colbert_vecs) != len(image_paths):
            raise ValueError(
                f"Expected {len(image_paths)} embeddings, got {len(colbert_vecs)}"
            )

        images_data = [
            {"colbert_vecs": vecs, "filepath": filepath}
            for vecs, filepath in zip(colbert_vecs, image_paths)
        ]

        print(f"Inserting {len(images_data)} images data to Milvus")

        self.milvus_manager.insert_images_data(images_data)

        print("Indexing completed")

        return image_paths

    def search(self, search_queries: list[str]):
        """Return the top hit for each query.

        Returns:
            One entry per query, each being the list returned by
            ``MilvusManager.search`` with ``topk=1``.
        """
        print(f"Searching for {len(search_queries)} queries")

        final_res = []

        for query in search_queries:
            print(f"Searching for query: {query}")
            query_vec = colpali_manager.process_text([query])[0]
            search_res = self.milvus_manager.search(query_vec, topk=1)
            print(f"Search result: {search_res} for query: {query}")
            final_res.append(search_res)

        return final_res
|
62 |
+
|
milvus_manager.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pymilvus import MilvusClient, DataType
|
2 |
+
import numpy as np
|
3 |
+
import concurrent.futures
|
4 |
+
from pymilvus import Collection
|
5 |
+
|
6 |
+
class MilvusManager:
    """Thin wrapper around ``MilvusClient`` for multi-vector page embeddings.

    Every page ("doc") is stored as many rows, one per embedding vector, that
    share the same ``doc_id``.
    """

    def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
        """Connect to Milvus and load or create *collection_name*.

        Args:
            milvus_uri: URI (or local database file) passed to ``MilvusClient``.
            collection_name: Collection used by ``insert``.
            create_collection: Create the collection and its index when it
                does not exist yet.
            dim: Dimension of the stored vectors.
        """
        self.client = MilvusClient(uri=milvus_uri)
        # self.client = MilvusClient(uri="http://localhost:19530", token="root:Milvus")
        self.collection_name = collection_name
        self.dim = dim

        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(collection_name=self.collection_name)
            print("Loaded existing collection.")
        elif create_collection:
            self.create_collection()
            self.create_index()

    def create_collection(self):
        """Create the collection with the pk/vector/seq_id/doc_id/doc schema."""
        if self.client.has_collection(collection_name=self.collection_name):
            print("Collection already exists.")
            return

        schema = self.client.create_schema(
            auto_id=True,
            # The keyword is singular; the previous ``enable_dynamic_fields``
            # spelling is not a recognised schema option.
            enable_dynamic_field=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        schema.add_field(field_name="seq_id", datatype=DataType.INT16)
        schema.add_field(field_name="doc_id", datatype=DataType.INT64)
        schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)

        self.client.create_collection(
            collection_name=self.collection_name, schema=schema
        )

    def create_index(self):
        """Create the HNSW inner-product index on the ``vector`` field."""
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_name="vector_index",
            index_type="HNSW", #use HNSW option if got more mem, if not use IVF for faster processing
            metric_type="IP",
            params={
                "M": 16, #M:16 for HNSW, capital M
                "efConstruction": 500, #for HNSW
            },
        )

        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def _rerank_single_doc(self, doc_id, data, collection_name):
        """Score one document against the query vectors.

        All vectors stored for *doc_id* are fetched; for every query vector
        the best inner product over those vectors is taken and the maxima are
        summed.

        Returns:
            ``(score, doc_id, collection_name)``, or ``None`` when the
            collection holds no vectors for *doc_id*.
        """
        doc_colbert_vecs = self.client.query(
            collection_name=collection_name,
            # Only this document's vectors: the previous filter also pulled in
            # ``doc_id + 1``, mixing the next page into the score.
            filter=f"doc_id == {doc_id}",
            output_fields=["seq_id", "vector", "doc"],
            limit=16380,
        )
        if not doc_colbert_vecs:
            # np.vstack raises on an empty sequence.
            return None

        doc_vecs = np.vstack([row["vector"] for row in doc_colbert_vecs])
        score = np.dot(data, doc_vecs.T).max(1).sum()
        return (score, doc_id, collection_name)

    def search(self, data, topk):
        """Search every collection and return the best-scoring documents.

        Args:
            data: Query embedding matrix (one row per query vector).
            topk: Maximum number of results to return.

        Returns:
            Up to *topk* tuples ``(score, doc_id, collection_name)`` sorted by
            descending score; empty when nothing is indexed.
        """
        collections = self.client.list_collections()

        # Inner-product metric, matching the index.
        search_params = {"metric_type": "IP", "params": {}}

        # Unique (doc_id, collection_name) candidates across all collections.
        doc_collection_pairs = set()
        loaded = []
        scores = []

        try:
            for collection in collections:
                self.client.load_collection(collection_name=collection)
                loaded.append(collection)
                print("collection loaded:"+ collection)
                results = self.client.search(
                    collection,
                    data,
                    limit=50, # Adjust limit per collection as needed.
                    # Only doc_id is read from the hits.
                    output_fields=["doc_id"],
                    search_params=search_params,
                )
                for hits in results:
                    for hit in hits:
                        doc_collection_pairs.add((hit["entity"]["doc_id"], collection))

            if doc_collection_pairs:
                # One worker per candidate, capped to a modest pool size.
                workers = min(32, len(doc_collection_pairs))
                with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
                    futures = [
                        executor.submit(self._rerank_single_doc, doc_id, data, collection)
                        for doc_id, collection in doc_collection_pairs
                    ]
                    for future in concurrent.futures.as_completed(futures):
                        result = future.result()
                        if result is not None:
                            scores.append(result)
        finally:
            # Release every collection loaded above. Previously only the last
            # loop variable was released, and an empty database raised
            # NameError because the variable was never bound.
            for collection in loaded:
                self.client.release_collection(collection_name=collection)

        scores.sort(key=lambda x: x[0], reverse=True)
        return scores[:topk]

    def insert(self, data):
        """Insert one document's vectors as individual rows.

        Args:
            data: Dict with ``colbert_vecs`` (sequence of vectors), ``doc_id``
                and ``filepath``. The file path is stored on the first row
                only; the remaining rows carry an empty ``doc`` string.
        """
        colbert_vecs = data["colbert_vecs"]
        if len(colbert_vecs) == 0:
            return

        rows = [
            {
                "vector": vec,
                "seq_id": seq_id,
                "doc_id": data["doc_id"],
                "doc": data["filepath"] if seq_id == 0 else "",
            }
            for seq_id, vec in enumerate(colbert_vecs)
        ]
        self.client.insert(self.collection_name, rows)

    def get_images_as_doc(self, images_with_vectors):
        """Attach a sequential ``doc_id`` (position in the list) to each image."""
        return [
            {
                "colbert_vecs": image["colbert_vecs"],
                "doc_id": idx,
                "filepath": image["filepath"],
            }
            for idx, image in enumerate(images_with_vectors)
        ]

    def insert_images_data(self, image_data):
        """Insert every image in *image_data*, numbering them from 0."""
        for item in self.get_images_as_doc(image_data):
            self.insert(item)
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
poppler-utils
|
pdf_manager.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pdf2image import convert_from_path
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
|
5 |
+
class PdfManager:
    """Rasterise PDF pages to PNG files under ``pages/<id>/``."""

    def __init__(self):
        pass

    def clear_and_recreate_dir(self, output_folder):
        """Ensure *output_folder* exists and is empty.

        An existing folder is removed first. It is then always (re)created:
        previously an existing folder was only deleted, so saving pages for a
        re-uploaded document failed because the directory was gone.
        """
        print(f"Clearing output folder {output_folder}")

        if os.path.exists(output_folder):
            shutil.rmtree(output_folder)
        os.makedirs(output_folder)

    def save_images(self, id, pdf_path, max_pages, pages: list[int] = None) -> list[str]:
        """Convert a PDF to page images and save them as PNG files.

        Args:
            id: Document id; images go to ``pages/<id>/``.
            pdf_path: Path of the PDF to convert.
            max_pages: Maximum number of pages to save; falsy means no limit.
            pages: Optional zero-based page indices to keep; others are skipped.

        Returns:
            The paths of the files that were actually written, in page order.
        """
        output_folder = f"pages/{id}"
        page_limit = int(max_pages) if max_pages else None

        # Without a page filter only the first ``page_limit`` pages can ever be
        # saved, so don't rasterise the rest of the document.
        convert_kwargs = {}
        if page_limit and not pages:
            convert_kwargs["last_page"] = page_limit
        images = convert_from_path(pdf_path, **convert_kwargs)

        print(f"Saving images from {pdf_path} to {output_folder}. Max pages: {max_pages}")

        self.clear_and_recreate_dir(output_folder)

        saved_paths = []

        for i, image in enumerate(images):
            if page_limit and len(saved_paths) >= page_limit:
                break

            if pages and i not in pages:
                continue

            full_save_path = f"{output_folder}/page_{i + 1}.png"
            image.save(full_save_path, "PNG")
            # Track the real file names: with a page filter they are not
            # simply page_1..page_n.
            saved_paths.append(full_save_path)

        return saved_paths
|
rag.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
import os
|
3 |
+
|
4 |
+
from typing import List
|
5 |
+
from utils import encode_image
|
6 |
+
from PIL import Image
|
7 |
+
from ollama import chat
|
8 |
+
import torch
|
9 |
+
import subprocess
|
10 |
+
import psutil
|
11 |
+
import torch
|
12 |
+
from transformers import AutoModel, AutoTokenizer
|
13 |
+
import google.generativeai as genai
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
class Rag:
    """Answer a text query about a set of page images with a vision model."""

    def get_answer_from_gemini(self, query, imagePaths):
        """Send ``query`` together with the images to Gemini and return the reply.

        The API key is read from the ``GEMINI_API_KEY`` environment variable
        (falling back to ``GOOGLE_API_KEY``); credentials must never be
        committed to source control.

        Args:
            query: The question to ask.
            imagePaths: Paths of the image files to attach to the message.

        Returns:
            The response text, or a string starting with ``"Error: "`` when
            anything fails (including a missing API key).
        """
        print(f"Querying Gemini for query={query}, imagePaths={imagePaths}")

        images = []

        try:
            api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
            if not api_key:
                raise RuntimeError(
                    "GEMINI_API_KEY (or GOOGLE_API_KEY) environment variable is not set"
                )

            genai.configure(api_key=api_key)
            model = genai.GenerativeModel('gemini-1.5-flash')

            for path in imagePaths:
                images.append(Image.open(path))

            # Named ``session`` so it does not shadow the module-level
            # ``chat`` imported from ollama.
            session = model.start_chat()

            response = session.send_message([*images, query])

            answer = response.text

            print(answer)

            return answer

        except Exception as e:
            print(f"An error occurred while querying Gemini: {e}")
            return f"Error: {str(e)}"

        finally:
            # Release the file handles held by the opened images.
            for image in images:
                image.close()

    def get_answer_from_openai(self, query, imagesPaths):
        """Send ``query`` and the images to the ``minicpm-v`` model through ollama.

        Despite its name, this method calls ollama's ``chat`` rather than an
        OpenAI endpoint.

        Args:
            query: The question to ask.
            imagesPaths: Paths of the image files passed as the message's
                ``images``.

        Returns:
            The content of the response message, or ``None`` if the call
            raised.
        """
        torch.cuda.empty_cache()  # release cuda so that ollama can use gpu!

        # NOTE(review): presumably this only matters to an Ollama server that
        # inherits this process's environment -- confirm.
        os.environ['OLLAMA_FLASH_ATTENTION'] = '1'

        print(f"Querying OpenAI for query={query}, imagesPaths={imagesPaths}")

        try:
            response = chat(
                model='minicpm-v:8b-2.6-q8_0',
                messages=[
                    {
                        'role': 'user',
                        'content': query,
                        'images': imagesPaths,
                    }
                ],
            )

            answer = response.message.content

            print(answer)

            return answer

        except Exception as e:
            print(f"An error occurred while querying OpenAI: {e}")
            return None

    def __get_openai_api_payload(self, query:str, imagesPaths:List[str]):
        """Build a chat-completions style payload with the query and images.

        Each image is base64-encoded with ``encode_image`` and embedded as a
        ``data:`` URL whose media type is guessed from the file extension
        (``image/jpeg`` when it cannot be determined).

        Args:
            query: The text part of the user message.
            imagesPaths: Paths of the images to embed.

        Returns:
            A dict with ``model``, ``messages`` and ``max_tokens`` keys.
        """
        import mimetypes  # local import: only this helper needs it

        image_payload = []

        for imagePath in imagesPaths:
            base64_image = encode_image(imagePath)

            # Label the data URL with the real image type instead of always
            # claiming JPEG.
            mime_type = mimetypes.guess_type(imagePath)[0]
            if not mime_type or not mime_type.startswith("image/"):
                mime_type = "image/jpeg"

            image_payload.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:{mime_type};base64,{base64_image}"
                }
            })

        payload = {
            "model": "Llama3.2-vision",  # change model here as needed
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": query
                        },
                        *image_payload
                    ]
                }
            ],
            "max_tokens": 1024  # reduce token size to reduce processing time
        }

        return payload
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
# if __name__ == "__main__":
|
142 |
+
# rag = Rag()
|
143 |
+
|
144 |
+
# query = "Based on attached images, how many new cases were reported during second wave peak"
|
145 |
+
# imagesPaths = ["covid_slides_page_8.png", "covid_slides_page_8.png"]
|
146 |
+
|
147 |
+
# rag.get_answer_from_gemini(query, imagesPaths)
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
PyMuPDF
|
3 |
+
pdf2image
|
4 |
+
pymilvus
|
5 |
+
tqdm
|
6 |
+
pillow
|
7 |
+
spaces
|
8 |
+
google-generativeai
|
9 |
+
git+https://github.com/illuin-tech/colpali
|
10 |
+
timm==1.0.13
|
11 |
+
transformers
|
12 |
+
https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp311-cp311-win_amd64.whl
|
13 |
+
comtypes
|
14 |
+
python-dotenv
|
test.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pymilvus import MilvusClient
|
2 |
+
from pymilvus import (
|
3 |
+
connections,
|
4 |
+
utility,
|
5 |
+
FieldSchema, CollectionSchema, DataType,
|
6 |
+
Collection,
|
7 |
+
)
|
8 |
+
|
9 |
+
# 1. Connect to the local Milvus server.
client = MilvusClient(uri="http://localhost:19530", token="root:Milvus")

# 2. Drop the collection (this removes it; nothing is created here).
client.drop_collection(collection_name="fy2025_budget_statement")

# 3. List the collections that remain.
print(client.list_collections())

# e.g. ['test_collection']

# Disabled lookup, kept for reference:
# res = client.get(
#     collection_name="colpali",
#     ids=[0, 1, 2],
# )
#
# print(res)
|
uploaded_files.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
EMERGING_MISSILE_THREATS_16382509
|
2 |
+
handwriting
|
3 |
+
multimediareport
|
utils.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
|
3 |
+
def encode_image(image_path):
    """Return the contents of the file at ``image_path`` as base64 text.

    The file is read in binary mode, base64-encoded, and the encoded bytes
    are decoded to a ``str`` using UTF-8.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()

    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
|