Kazel committed
Commit · d8688d6
Parent(s): 48a037c

change

Files changed:
- .env +0 -8
- README.md +10 -11
- app.py +49 -84
- colpali_manager.py +26 -36
- milvus_manager.py +7 -16
- rag.py +8 -27
- requirements.txt +8 -14
- test.py +30 -0
- uploaded_files.txt +3 -0
.env
DELETED
@@ -1,8 +0,0 @@
-colpali='vidore/colSmol-256M'
-ollama='minicpm-v'
-flashattn='1'
-metrictype='IP'
-mnum='16'
-efnum='500'
-topk='50'
-temperature='0.8'
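Note: the deleted keys above were read at startup through python-dotenv and os.environ, as the dotenv calls removed from app.py, colpali_manager.py, and rag.py in this commit show. A minimal sketch of that pattern, for reference:

import os
import dotenv

# Locate the nearest .env file and load its keys into the process environment
# (mirrors the dotenv calls removed from app.py in this commit).
dotenv_file = dotenv.find_dotenv()
dotenv.load_dotenv(dotenv_file)

model_name = os.environ['colpali']              # e.g. 'vidore/colSmol-256M'
top_k = int(os.environ['topk'])                 # search result limit
temperature = float(os.environ['temperature'])  # LLM sampling temperature

# set_key writes a value back to the .env file so it survives a restart.
dotenv.set_key(dotenv_file, 'metrictype', 'IP')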
README.md
CHANGED
@@ -1,15 +1,14 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
-sdk:
+title: Multimodal
+emoji: 🦀
+colorFrom: purple
+colorTo: gray
+sdk: gradio
+sdk_version: 5.20.1
+app_file: app.py
 pinned: false
+license: cc-by-nc-sa-4.0
+short_description: Demo for Collar's offline multimodal rag system
 ---
 
-
-Code for blog [https://saumitra.me/2024/2024-11-15-colpali-milvus-rag/](https://saumitra.me/2024/2024-11-15-colpali-milvus-rag/) on how to do multimodal RAG with [colpali](https://arxiv.org/abs/2407.01449), [milvus](https://milvus.io/) and a visual LLM (gemini/gpt-4o)
-
-Demo running at [https://huggingface.co/spaces/saumitras/colpali-milvus](https://huggingface.co/spaces/saumitras/colpali-milvus)
-
-Application will allow users to upload a PDF and then perform search or Q&A queries on both the text and visual elements of the document. We will not extract text from the PDF; instead, we will treat it as an image and use colpali to get embeddings for the PDF pages. These embeddings will be indexed to Milvus, and then we will use a visual LLM (gemini/gpt-4o) to facilitate the Q&A queries.
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
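Note: the deleted README text describes the pipeline this Space implements: PDF pages are treated as images, embedded with ColPali, indexed in Milvus, and answered by a visual LLM. A sketch of that flow using the managers in this repo (the PDF name, page paths, and query are hypothetical; the Milvus insert/search calls are elided because MilvusManager's full API is not shown in this diff):

import os
from pdf2image import convert_from_path

from colpali_manager import ColpaliManager
from milvus_manager import MilvusManager
from rag import Rag

# 1. Rasterise the PDF: every page becomes an image, no text extraction.
os.makedirs("pages/report", exist_ok=True)
pages = convert_from_path("report.pdf")
paths = [f"pages/report/page_{i}.png" for i in range(len(pages))]
for page, path in zip(pages, paths):
    page.save(path)

# 2. Embed the page images with ColPali and index them in Milvus.
colpali = ColpaliManager()
milvus = MilvusManager("localhost", "report", create_collection=True)
page_embeddings = colpali.process_images(paths)
# ... insert page_embeddings into the collection via MilvusManager ...

# 3. Embed the query, retrieve the best-matching pages, and let the
#    visual LLM answer from those page images.
query = "What was Q2 revenue?"
query_embedding = colpali.process_text([query])[0]
# ... search Milvus for the top pages, collect their image paths ...
rag = Rag()
print(rag.get_answer_from_gemini(query, paths[:1]))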
app.py
CHANGED
@@ -11,18 +11,6 @@ from rag import Rag
 from pathlib import Path
 import subprocess
 import getpass
-# importing necessary functions from dotenv library
-from dotenv import load_dotenv, dotenv_values
-import dotenv
-import platform
-import time
-
-# loading variables from .env file
-dotenv_file = dotenv.find_dotenv()
-dotenv.load_dotenv(dotenv_file)
-
-#kickstart docker and ollama servers
-
 
 rag = Rag()
 
@@ -62,6 +50,30 @@ class PDFSearchApp:
         pdf_path=file.name
         #if ppt will get replaced with path of ppt!
 
+        #if extension is .ppt or .pptx, convert
+        if ext == ".ppt" or ext == ".pptx": #need to test with a ppt key...
+            '''
+            import comtypes.client
+            powerpoint = comtypes.client.CreateObject("PowerPoint.Application")
+            powerpoint.Visible = 1
+            presentation = powerpoint.Presentations.Open(file)
+            output_file = os.path.splitext(file)[0] + '.pdf'
+            output_directory = os.path.dirname(file)
+            presentation.SaveAs(os.path.join(output_directory, output_file), 32) # 32 is the formatType for PDF
+            presentation.Close()
+            powerpoint.Quit()
+            file = os.path.join(output_directory, output_file) #swap file to be used to the outputted pdf file instead
+            # Extract the last part of the path (file name)
+            name = os.path.basename(file)
+            # Split the base name into name and extension
+            name, ext = os.path.splitext(name)
+            print(name)
+            self.current_pdf = os.path.join(output_directory, output_file)
+            pdf_path = os.path.join(output_directory, output_file)
+            '''
+            print("pptx not supported on spaces")
+
+
         # Replace spaces and hyphens with underscores in the name
         modified_filename = name.replace(" ", "_").replace("-", "_")
 
@@ -142,7 +154,12 @@ class PDFSearchApp:
     def delete(state,choice):
         #delete file in pages, then use middleware to delete collection
         # 1. Create a milvus client
+
         client = MilvusClient(uri="localhost")
+        #client = MilvusClient(
+        #    uri="http://localhost:19530",
+        #    token="root:Milvus"
+        #    )
         path = f"pages/{choice}"
         if os.path.exists(path):
             shutil.rmtree(path)
@@ -151,18 +168,6 @@ class PDFSearchApp:
             return f"Deleted {choice}"
         else:
             return "Directory not found"
-    def dbupdate(state,metric_type,m_num,ef_num,topk):
-        os.environ['metrictype'] = metric_type
-        # Update the .env file with the new value
-        dotenv.set_key(dotenv_file, 'metrictype', metric_type)
-        os.environ['mnum'] = str(m_num)
-        dotenv.set_key(dotenv_file, 'mnum', str(m_num))
-        os.environ['efnum'] = str(ef_num)
-        dotenv.set_key(dotenv_file, 'efnum', str(ef_num))
-        os.environ['topk'] = str(topk)
-        dotenv.set_key(dotenv_file, 'topk', str(topk))
-
-        return "DB Settings Updated, Restart App To Load"
 
     def list_downloaded_hf_models(state):
         # Determine the cache directory
@@ -174,19 +179,18 @@ class PDFSearchApp:
         # Traverse the cache directory
         for repo_dir in hf_cache_dir.glob('models--*'):
             # Extract the model name from the directory structure
-            model_name = repo_dir.name.split('--', 1)[-1].replace('
+            model_name = repo_dir.name.split('--', 1)[-1].replace('-', '/')
             model_names.append(model_name)
 
         return model_names
 
 
-    def list_downloaded_ollama_models(state):
+    def list_downloaded_ollama_models(state,):
         # Retrieve the current user's name
         username = getpass.getuser()
 
         # Construct the target directory path
-
-        base_path = f"NEW_PATH\\manifests\\registry.ollama.ai\\library" #relative to proj dir! (IMPT: OLLAMA PULL COMMAND IN PROJ DIR!!!)
+        base_path = f"C:\\Users\\{username}\\NEW_PATH\\manifests\\registry.ollama.ai\\library"
 
         try:
             # List all entries in the directory
@@ -202,29 +206,18 @@ class PDFSearchApp:
         except Exception as e:
             print(f"An error occurred: {e}")
 
-    def model_settings(state,hfchoice, ollamachoice,
+    def model_settings(state,hfchoice, ollamachoice,tokensize):
         os.environ['colpali'] = hfchoice
-        # Update the .env file with the new value
-        dotenv.set_key(dotenv_file, 'colpali', hfchoice)
         os.environ['ollama'] = ollamachoice
-
-
-        os.environ['flashattn'] = "1"
-        dotenv.set_key(dotenv_file, 'flashattn', "1")
-        else:
-        os.environ['flashattn'] = "0"
-        dotenv.set_key(dotenv_file, 'flashattn', "0")
-        os.environ['temperature'] = str(temp)
-        dotenv.set_key(dotenv_file, 'temperature', str(temp))
-
-        return "Models Updated, Restart App To Use New Settings"
+        os.environ['tokens'] = tokensize
+        return "abc"
 
 
 
 def create_ui():
     app = PDFSearchApp()
 
-    with gr.Blocks(
+    with gr.Blocks(css="footer{display:none !important}") as demo:
         state = gr.State(value={"user_uuid": None})
 
 
@@ -263,47 +256,26 @@ def create_ui():
             with gr.Column():
                 # Button to delete (TBD)
                 choice = gr.Dropdown(list(app.display_file_list()),label="Choice")
-                status1 = gr.Textbox(label="Deletion Status", interactive=False)
                 delete_button = gr.Button("Delete Document From DB")
-
-                # Create the dropdown component with default value as the first option
-                #Milvusindex = gr.Dropdown(["HNSW","FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ", "RHNSW_FLAT"], value="HNSW", label="Select Vector DB Index Parameter")
-                metric_type = gr.Dropdown(choices=["IP", "L2", "COSINE"],value="IP",label="Metric Type (Mathematical function to measure similarity)")
-                m_num = gr.Dropdown(
-                    choices=["8", "16", "32", "64"], value="16",label="M Vectors (Maximum number of neighbors each node can connect to in the graph)")
-                ef_num = gr.Slider(
-                    minimum=50,
-                    maximum=1000,
-                    value=500,
-                    step=10,
-                    label="EF Construction (Number of candidate neighbors considered for connection during index construction)"
-                )
-                topk = gr.Slider(
-                    minimum=1,
-                    maximum=100,
-                    value=50,
-                    step=1,
-                    label="Top-K (Maximum number of entities to return in a single search of a document)"
-                )
-                db_button = gr.Button("Update DB Settings")
-                status3 = gr.Textbox(label="DB Update Status", interactive=False)
-
+                status1 = gr.Textbox(label="Deletion Status", interactive=False)
 
         with gr.Tab("AI Model Settings"): #deletion of collections, changing of model parameters etc
             with gr.Column():
                 # Button to delete (TBD)
-                hfchoice = gr.Dropdown(app.list_downloaded_hf_models(),
-                ollamachoice = gr.Dropdown(app.list_downloaded_ollama_models(),
-
-
-
-
-
-                    label="RAG Temperature"
+                hfchoice = gr.Dropdown(app.list_downloaded_hf_models(),label="Visual Document Retrieval (VDR) Model")
+                ollamachoice = gr.Dropdown(app.list_downloaded_ollama_models(),label="Secondary Visual Retrieval-Augmented Generation (RAG) Model")
+                tokensize = gr.Slider(
+                    minimum=256,
+                    maximum=4096,
+                    value=20,
+                    step=10,
+                    label="Max tokens per response (Reply Length)"
                 )
                 model_button = gr.Button("Update Settings")
                 status2 = gr.Textbox(label="Update Status", interactive=False)
+
+
+
 
         # Event handlers
         file_input.change(
@@ -324,16 +296,10 @@ def create_ui():
             inputs=[choice],
             outputs=[status1]
         )
-
-        db_button.click(
-            fn=app.dbupdate,
-            inputs=[metric_type,m_num,ef_num,topk],
-            outputs=[status3]
-        )
 
         model_button.click(
             fn=app.model_settings,
-            inputs=[hfchoice, ollamachoice,
+            inputs=[hfchoice, ollamachoice,tokensize],
             outputs=[status2]
         )
 
@@ -341,6 +307,5 @@ def create_ui():
 
 if __name__ == "__main__":
     demo = create_ui()
-    #demo.launch(auth=("admin", "pass1234")) for with login page config
     demo.launch()
 
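Note: the commented-out comtypes block in upload_file drives the PowerPoint COM API, so it only works on Windows with Office installed, which is why the Spaces build just prints "pptx not supported on spaces". A portable alternative is to shell out to LibreOffice; this is a sketch, assuming the soffice binary is on PATH (app.py already imports subprocess):

import os
import subprocess

def ppt_to_pdf(path: str) -> str:
    # Convert a .ppt/.pptx to PDF headlessly; the PDF lands next to the input
    # with the same base name.
    out_dir = os.path.dirname(path) or "."
    subprocess.run(
        ["soffice", "--headless", "--convert-to", "pdf", "--outdir", out_dir, path],
        check=True,
    )
    return os.path.splitext(path)[0] + ".pdf"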
colpali_manager.py
CHANGED
@@ -17,16 +17,9 @@ import spaces
 
 
 #this part is for local runs
-torch.cuda.empty_cache()
 
-
-
-# Load the .env file
-dotenv_file = dotenv.find_dotenv()
-dotenv.load_dotenv(dotenv_file)
-
-model_name = os.environ['colpali'] #"vidore/colSmol-256M"
-device = get_torch_device("cuda") #try using cpu instead of cuda?
+model_name = "vidore/colSmol-256M"
+device = get_torch_device("cpu") #try using cpu instead of cuda?
 
 #switch to locally downloading models & loading locally rather than from hf
 #
@@ -35,48 +28,45 @@ current_working_directory = os.getcwd()
 save_directory = model_name # Directory to save the specific model name
 save_directory = os.path.join(current_working_directory, save_directory)
 
-processor_directory =
+processor_directory = 'local_processor' # Directory to save the processor
 processor_directory = os.path.join(current_working_directory, processor_directory)
 
 
+model = ColIdefics3.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    device_map=device,
+    #attn_implementation="flash_attention_2",
+).eval()
+processor = cast(ColIdefics3Processor, ColIdefics3Processor.from_pretrained(model_name))
 
+"""
 if not os.path.exists(save_directory): #download if directory not created/model not loaded
     # Directory does not exist; create it
-    if "colSmol" in model_name: #if colsmol
-        model = ColIdefics3.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            device_map=device,
-            attn_implementation="flash_attention_2",
-        ).eval()
-        processor = cast(ColIdefics3Processor, ColIdefics3Processor.from_pretrained(model_name))
-    else: #if colpali v1.3 etc
-        model = ColPali.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            device_map=device,
-            attn_implementation="flash_attention_2",
-        ).eval()
-        processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
     os.makedirs(save_directory)
     print(f"Directory '{save_directory}' created.")
+    model = ColIdefics3.from_pretrained(
+        model_name,
+        torch_dtype=torch.bfloat16,
+        device_map=device,
+        attn_implementation="flash_attention_2",
+    ).eval()
     model.save_pretrained(save_directory)
     os.makedirs(processor_directory)
+    processor = cast(ColIdefics3Processor, ColIdefics3Processor.from_pretrained(model_name))
+
     processor.save_pretrained(processor_directory)
 
 else:
-
-
-
-else:
-    model = ColPali.from_pretrained(save_directory)
-    processor = ColPaliProcessor.from_pretrained(processor_directory, use_fast=True)
+    model = ColIdefics3.from_pretrained(save_directory)
+    processor = ColIdefics3.from_pretrained(processor_directory, use_fast=True)
+"""
 
 
 class ColpaliManager:
 
 
-    def __init__(self, device = "
+    def __init__(self, device = "cpu", model_name = "vidore/colSmol-256M"): #need to hot potato/use diff gpus between colpali & ollama
 
         print(f"Initializing ColpaliManager with device {device} and model {model_name}")
 
@@ -92,12 +82,12 @@ class ColpaliManager:
 
     @spaces.GPU
     def get_images(self, paths: list[str]) -> List[Image.Image]:
-        model.to("
+        model.to("cpu")
        return [Image.open(path) for path in paths]
 
     @spaces.GPU
     def process_images(self, image_paths:list[str], batch_size=5):
-        model.to("
+        model.to("cpu")
        print(f"Processing {len(image_paths)} image_paths")
 
        images = self.get_images(image_paths)
@@ -123,7 +113,7 @@ class ColpaliManager:
 
     @spaces.GPU
     def process_text(self, texts: list[str]):
-        model.to("
+        model.to("cpu") #ensure this is commented out so ollama/multimodal llm can use gpu! (nah wrong, need to enable so that it can process multiple)
        print(f"Processing {len(texts)} texts")
 
        dataloader = DataLoader(
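Note: with the module-level model and processor loaded above, an embedding pass looks roughly like this (a sketch assuming colpali_engine's processor API, i.e. process_images/process_queries producing batches that model.forward consumes; the page image and query text are hypothetical):

import torch
from PIL import Image

# Embed one page image into ColPali's multi-vector representation.
images = [Image.open("pages/report/page_0.png")]
batch = processor.process_images(images).to(model.device)
with torch.no_grad():
    image_embeddings = model(**batch)  # one multi-vector embedding per page

# Embed a text query the same way for late-interaction scoring.
queries = processor.process_queries(["total revenue in 2024"]).to(model.device)
with torch.no_grad():
    query_embeddings = model(**queries)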
milvus_manager.py
CHANGED
@@ -2,18 +2,11 @@ from pymilvus import MilvusClient, DataType
 import numpy as np
 import concurrent.futures
 from pymilvus import Collection
-import os
 
 class MilvusManager:
     def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
-
-        #
-        import dotenv
-        # Load the .env file
-        dotenv_file = dotenv.find_dotenv()
-        dotenv.load_dotenv(dotenv_file)
-
-        self.client = MilvusClient(uri="localhost")
+        self.client = MilvusClient(uri=milvus_uri)
+        # self.client = MilvusClient(uri="http://localhost:19530", token="root:Milvus")
         self.collection_name = collection_name
         self.dim = dim
 
@@ -47,15 +40,13 @@ class MilvusManager:
 
     def create_index(self):
         index_params = self.client.prepare_index_params()
-
         index_params.add_index(
             field_name="vector",
             index_name="vector_index",
-            index_type="
-            metric_type=
+            index_type="IVF_FLAT", #use HNSW option if got more mem, if not use IVF for faster processing
+            metric_type="IP",
             params={
-                "
-                "efConstruction": int(os.environ["efnum"]), #500 for HNSW
+                "nlist": 1024
             },
         )
 
@@ -68,7 +59,7 @@ class MilvusManager:
         collections = self.client.list_collections()
 
         # Set search parameters (here, using Inner Product metric).
-        search_params = {"metric_type":
+        search_params = {"metric_type": "IP", "params": {}}
 
         # Set to store unique (doc_id, collection_name) pairs across all collections.
         doc_collection_pairs = set()
@@ -80,7 +71,7 @@ class MilvusManager:
             results = self.client.search(
                 collection,
                 data,
-                limit=
+                limit=50, # Adjust limit per collection as needed.
                 output_fields=["vector", "seq_id", "doc_id"],
                 search_params=search_params,
             )
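Note: the inline comment above trades HNSW for IVF_FLAT to save memory. For reference, the HNSW variant would use the same pymilvus calls with graph parameters instead of nlist; this is a sketch where the parameter values mirror the defaults from the deleted .env (mnum=16, efnum=500) and the collection name is hypothetical:

from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")
index_params = client.prepare_index_params()
index_params.add_index(
    field_name="vector",
    index_name="vector_index",
    index_type="HNSW",         # graph index: more memory, better recall/latency
    metric_type="IP",
    params={
        "M": 16,               # max neighbours per graph node (was `mnum`)
        "efConstruction": 500, # candidate pool during index build (was `efnum`)
    },
)
client.create_index(collection_name="pages", index_params=index_params)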
rag.py
CHANGED
@@ -1,17 +1,10 @@
 import requests
 import os
+import google.generativeai as genai
 
 from typing import List
 from utils import encode_image
 from PIL import Image
-from ollama import chat
-import torch
-import subprocess
-import psutil
-import torch
-from transformers import AutoModel, AutoTokenizer
-import google.generativeai as genai
-
 
 
 
@@ -40,17 +33,10 @@ class Rag:
         except Exception as e:
             print(f"An error occurred while querying Gemini: {e}")
             return f"Error: {str(e)}"
-
 
     #os.environ['OPENAI_API_KEY'] = "for the love of Jesus let this work"
-
-    def get_answer_from_openai(self, query, imagesPaths):
-        #import environ variables from .env
-        import dotenv
 
-
-        dotenv_file = dotenv.find_dotenv()
-        dotenv.load_dotenv(dotenv_file)
+    def get_answer_from_openai(self, query, imagesPaths):
         """ #scuffed local hf inference (transformers incompatible to colpali version req, use ollama, more reliable, easier to use plus web server ready)
         print(f"Querying for query={query}, imagesPaths={imagesPaths}")
 
@@ -79,12 +65,8 @@ class Rag:
 
         #ollama method below
 
-
-
-
-        os.environ['OLLAMA_FLASH_ATTENTION'] = os.environ['flashattn'] #int "1"
-        if os.environ['ollama'] == "minicpm-v":
-            os.environ['ollama'] = "minicpm-v:8b-2.6-q8_0" #set to quantized version
+
+        os.environ['OLLAMA_FLASH_ATTENTION'] = '1'
 
 
         # Close model thread (colpali)
@@ -92,14 +74,13 @@ class Rag:
 
         try:
 
-            response = chat(
-
-
+            response = chat(
+                model='minicpm-v:8b-2.6-q8_0',
+                messages=[
                 {
                     'role': 'user',
                     'content': query,
                     'images': imagesPaths,
-                    "temperature":float(os.environ['temperature']), #test if temp makes a diff
                 }
             ],
         )
@@ -155,4 +136,4 @@ class Rag:
 # query = "Based on attached images, how many new cases were reported during second wave peak"
 # imagesPaths = ["covid_slides_page_8.png", "covid_slides_page_8.png"]
 
-# rag.get_answer_from_gemini(query, imagesPaths)
+# rag.get_answer_from_gemini(query, imagesPaths)
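Note: the Ollama chat call above returns its reply under message.content. A standalone sketch of the same call, including the import the call site relies on (assumes the ollama Python package and a pulled minicpm-v:8b-2.6-q8_0 model; the question and image path are hypothetical):

from ollama import chat

# Ask the multimodal model a question about one or more page images.
response = chat(
    model='minicpm-v:8b-2.6-q8_0',
    messages=[
        {
            'role': 'user',
            'content': 'Summarise the chart on this page.',
            'images': ['pages/report/page_0.png'],
        }
    ],
)
print(response['message']['content'])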
requirements.txt
CHANGED
@@ -1,15 +1,9 @@
-gradio
-PyMuPDF
-pdf2image
-pymilvus
-tqdm
-pillow
-spaces
-google-generativeai
 git+https://github.com/illuin-tech/colpali
-
-
-
-
-
-
+gradio==4.25.0
+PyMuPDF==1.24.9
+pdf2image==1.17.0
+pymilvus==2.4.9
+tqdm==4.66.5
+pillow==10.4.0
+spaces==0.30.4
+google-generativeai==0.8.3
test.py
ADDED
@@ -0,0 +1,30 @@
+from pymilvus import MilvusClient
+from pymilvus import (
+    connections,
+    utility,
+    FieldSchema, CollectionSchema, DataType,
+    Collection,
+)
+
+# 1. Create a milvus client
+client = MilvusClient(
+    uri="http://localhost:19530",
+    token="root:Milvus"
+)
+
+# 2. Drop a collection
+client.drop_collection(collection_name="fy2025_budget_statement")
+
+# 3. List collections
+print(client.list_collections())
+
+# ['test_collection']
+
+"""
+res = client.get(
+    collection_name="colpali",
+    ids=[0, 1, 2],
+)
+
+print(res)
+"""
uploaded_files.txt
ADDED
@@ -0,0 +1,3 @@
+EMERGING_MISSILE_THREATS_16382509
+handwriting
+multimediareport
|