Kazel committed on
Commit d8688d6 · 1 Parent(s): 48a037c
Files changed (9)
  1. .env +0 -8
  2. README.md +10 -11
  3. app.py +49 -84
  4. colpali_manager.py +26 -36
  5. milvus_manager.py +7 -16
  6. rag.py +8 -27
  7. requirements.txt +8 -14
  8. test.py +30 -0
  9. uploaded_files.txt +3 -0
.env DELETED
@@ -1,8 +0,0 @@
-colpali='vidore/colSmol-256M'
-ollama='minicpm-v'
-flashattn='1'
-metrictype='IP'
-mnum='16'
-efnum='500'
-topk='50'
-temperature='0.8'
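Note: these keys were consumed elsewhere in the repo via python-dotenv before this commit inlined the values. A minimal sketch of the old lookup pattern, with illustrative fallback defaults (the defaults are assumptions, not part of the commit):

```python
import os
import dotenv

# Locate and load the nearest .env file (how the removed keys used to be read)
dotenv.load_dotenv(dotenv.find_dotenv())

model_name = os.environ.get('colpali', 'vidore/colSmol-256M')
metric_type = os.environ.get('metrictype', 'IP')
top_k = int(os.environ.get('topk', '50'))
temperature = float(os.environ.get('temperature', '0.8'))
```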
README.md CHANGED
@@ -1,15 +1,14 @@
 ---
-title: Test
-emoji:
-colorFrom: pink
-colorTo: blue
-sdk: static
+title: Multimodal
+emoji: 🦀
+colorFrom: purple
+colorTo: gray
+sdk: gradio
+sdk_version: 5.20.1
+app_file: app.py
 pinned: false
+license: cc-by-nc-sa-4.0
+short_description: Demo for Collar's offline multimodal rag system
 ---
 
-
-Code for blog [https://saumitra.me/2024/2024-11-15-colpali-milvus-rag/](https://saumitra.me/2024/2024-11-15-colpali-milvus-rag/) on how to do multimodal RAG with [colpali](https://arxiv.org/abs/2407.01449), [milvus](https://milvus.io/) and a visual LLM (gemini/gpt-4o)
-
-Demo running at [https://huggingface.co/spaces/saumitras/colpali-milvus](https://huggingface.co/spaces/saumitras/colpali-milvus)
-
-Application will allow users to upload a PDF and then perform search or Q&A queries on both the text and visual elements of the document. We will not extract text from the PDF; instead, we will treat it as an image and use colpali to get embeddings for the PDF pages. These embeddings will be indexed to Milvus, and then we will use a visual LLM (gemini/gpt-4o) to facilitate the Q&A queries.
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -11,18 +11,6 @@ from rag import Rag
 from pathlib import Path
 import subprocess
 import getpass
-# importing necessary functions from dotenv library
-from dotenv import load_dotenv, dotenv_values
-import dotenv
-import platform
-import time
-
-# loading variables from .env file
-dotenv_file = dotenv.find_dotenv()
-dotenv.load_dotenv(dotenv_file)
-
-#kickstart docker and ollama servers
-
 
 rag = Rag()
 
@@ -62,6 +50,30 @@ class PDFSearchApp:
             pdf_path=file.name
             #if ppt will get replaced with path of ppt!
 
+            #if extension is .ppt or .pptx, convert
+            if ext == ".ppt" or ext == ".pptx": #need to test with a ppt key...
+                '''
+                import comtypes.client
+                powerpoint = comtypes.client.CreateObject("PowerPoint.Application")
+                powerpoint.Visible = 1
+                presentation = powerpoint.Presentations.Open(file)
+                output_file = os.path.splitext(file)[0] + '.pdf'
+                output_directory = os.path.dirname(file)
+                presentation.SaveAs(os.path.join(output_directory, output_file), 32) # 32 is the formatType for PDF
+                presentation.Close()
+                powerpoint.Quit()
+                file = os.path.join(output_directory, output_file) #swap file to be used to the outputted pdf file instead
+                # Extract the last part of the path (file name)
+                name = os.path.basename(file)
+                # Split the base name into name and extension
+                name, ext = os.path.splitext(name)
+                print(name)
+                self.current_pdf = os.path.join(output_directory, output_file)
+                pdf_path = os.path.join(output_directory, output_file)
+                '''
+                print("pptx not supported on spaces")
+
+
             # Replace spaces and hyphens with underscores in the name
             modified_filename = name.replace(" ", "_").replace("-", "_")
 
@@ -142,7 +154,12 @@ class PDFSearchApp:
     def delete(state,choice):
         #delete file in pages, then use middleware to delete collection
         # 1. Create a milvus client
+
         client = MilvusClient(uri="localhost")
+        #client = MilvusClient(
+        #    uri="http://localhost:19530",
+        #    token="root:Milvus"
+        #    )
         path = f"pages/{choice}"
         if os.path.exists(path):
             shutil.rmtree(path)
@@ -151,18 +168,6 @@ class PDFSearchApp:
             return f"Deleted {choice}"
         else:
             return "Directory not found"
-    def dbupdate(state,metric_type,m_num,ef_num,topk):
-        os.environ['metrictype'] = metric_type
-        # Update the .env file with the new value
-        dotenv.set_key(dotenv_file, 'metrictype', metric_type)
-        os.environ['mnum'] = str(m_num)
-        dotenv.set_key(dotenv_file, 'mnum', str(m_num))
-        os.environ['efnum'] = str(ef_num)
-        dotenv.set_key(dotenv_file, 'efnum', str(ef_num))
-        os.environ['topk'] = str(topk)
-        dotenv.set_key(dotenv_file, 'topk', str(topk))
-
-        return "DB Settings Updated, Restart App To Load"
 
     def list_downloaded_hf_models(state):
         # Determine the cache directory
@@ -174,19 +179,18 @@ class PDFSearchApp:
         # Traverse the cache directory
         for repo_dir in hf_cache_dir.glob('models--*'):
             # Extract the model name from the directory structure
-            model_name = repo_dir.name.split('--', 1)[-1].replace('--', '/')
+            model_name = repo_dir.name.split('--', 1)[-1].replace('-', '/')
             model_names.append(model_name)
 
         return model_names
 
 
-    def list_downloaded_ollama_models(state):
+    def list_downloaded_ollama_models(state,):
         # Retrieve the current user's name
         username = getpass.getuser()
 
         # Construct the target directory path
-        #base_path = f"C:\\Users\\{username}\\NEW_PATH\\manifests\\registry.ollama.ai\\library" #this is for if ollama pull is called from C://, if ollama pulls are called from the proj dir, use the NEW_PATH in the proj dir!
-        base_path = f"NEW_PATH\\manifests\\registry.ollama.ai\\library" #relative to proj dir! (IMPT: OLLAMA PULL COMMAND IN PROJ DIR!!!)
+        base_path = f"C:\\Users\\{username}\\NEW_PATH\\manifests\\registry.ollama.ai\\library"
 
         try:
             # List all entries in the directory
@@ -202,29 +206,18 @@ class PDFSearchApp:
         except Exception as e:
             print(f"An error occurred: {e}")
 
-    def model_settings(state,hfchoice, ollamachoice,flash, temp):
+    def model_settings(state,hfchoice, ollamachoice,tokensize):
         os.environ['colpali'] = hfchoice
-        # Update the .env file with the new value
-        dotenv.set_key(dotenv_file, 'colpali', hfchoice)
         os.environ['ollama'] = ollamachoice
-        dotenv.set_key(dotenv_file, 'ollama', ollamachoice)
-        if flash == "Enabled":
-            os.environ['flashattn'] = "1"
-            dotenv.set_key(dotenv_file, 'flashattn', "1")
-        else:
-            os.environ['flashattn'] = "0"
-            dotenv.set_key(dotenv_file, 'flashattn', "0")
-        os.environ['temperature'] = str(temp)
-        dotenv.set_key(dotenv_file, 'temperature', str(temp))
-
-        return "Models Updated, Restart App To Use New Settings"
+        os.environ['tokens'] = str(tokensize)
+        return "abc"
 
 
 
 def create_ui():
     app = PDFSearchApp()
 
-    with gr.Blocks(theme=gr.themes.Ocean(),css ="footer{display:none !important}") as demo:
+    with gr.Blocks(css="footer{display:none !important}") as demo:
         state = gr.State(value={"user_uuid": None})
 
 
@@ -263,47 +256,26 @@ def create_ui():
         with gr.Column():
             # Button to delete (TBD)
            choice = gr.Dropdown(list(app.display_file_list()),label="Choice")
-            status1 = gr.Textbox(label="Deletion Status", interactive=False)
             delete_button = gr.Button("Delete Document From DB")
-
-            # Create the dropdown component with default value as the first option
-            #Milvusindex = gr.Dropdown(["HNSW","FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ", "RHNSW_FLAT"], value="HNSW", label="Select Vector DB Index Parameter")
-            metric_type = gr.Dropdown(choices=["IP", "L2", "COSINE"],value="IP",label="Metric Type (Mathematical function to measure similarity)")
-            m_num = gr.Dropdown(
-                choices=["8", "16", "32", "64"], value="16",label="M Vectors (Maximum number of neighbors each node can connect to in the graph)")
-            ef_num = gr.Slider(
-                minimum=50,
-                maximum=1000,
-                value=500,
-                step=10,
-                label="EF Construction (Number of candidate neighbors considered for connection during index construction)"
-            )
-            topk = gr.Slider(
-                minimum=1,
-                maximum=100,
-                value=50,
-                step=1,
-                label="Top-K (Maximum number of entities to return in a single search of a document)"
-            )
-            db_button = gr.Button("Update DB Settings")
-            status3 = gr.Textbox(label="DB Update Status", interactive=False)
-
+            status1 = gr.Textbox(label="Deletion Status", interactive=False)
 
         with gr.Tab("AI Model Settings"): #deletion of collections, changing of model parameters etc
             with gr.Column():
                 # Button to delete (TBD)
-                hfchoice = gr.Dropdown(app.list_downloaded_hf_models(),value=os.environ['colpali'], label="Visual Document Retrieval (VDR) Model")
-                ollamachoice = gr.Dropdown(app.list_downloaded_ollama_models(),value=os.environ['ollama'],label="Secondary Visual Retrieval-Augmented Generation (RAG) Model")
-                flash = gr.Dropdown(["Enabled","Disabled"], value = "Enabled",label ="Flash Attention 2.0 Acceleration")
-                temp = gr.Slider(
-                    minimum=0.1,
-                    maximum=1,
-                    value=0.8,
-                    step=0.1,
-                    label="RAG Temperature"
+                hfchoice = gr.Dropdown(app.list_downloaded_hf_models(),label="Visual Document Retrieval (VDR) Model")
+                ollamachoice = gr.Dropdown(app.list_downloaded_ollama_models(),label="Secondary Visual Retrieval-Augmented Generation (RAG) Model")
+                tokensize = gr.Slider(
+                    minimum=256,
+                    maximum=4096,
+                    value=20,
+                    step=10,
+                    label="Max tokens per response (Reply Length)"
                 )
                 model_button = gr.Button("Update Settings")
                 status2 = gr.Textbox(label="Update Status", interactive=False)
+
+
 
         # Event handlers
         file_input.change(
@@ -324,16 +296,10 @@ def create_ui():
             inputs=[choice],
             outputs=[status1]
         )
-
-        db_button.click(
-            fn=app.dbupdate,
-            inputs=[metric_type,m_num,ef_num,topk],
-            outputs=[status3]
-        )
 
         model_button.click(
             fn=app.model_settings,
-            inputs=[hfchoice, ollamachoice,flash,temp],
+            inputs=[hfchoice, ollamachoice,tokensize],
             outputs=[status2]
         )
 
@@ -341,6 +307,5 @@ def create_ui():
 
 if __name__ == "__main__":
     demo = create_ui()
-    #demo.launch(auth=("admin", "pass1234")) for with login page config
     demo.launch()
 
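Note: the comtypes PowerPoint automation quoted above is Windows-only, which is why this commit stubs it out with "pptx not supported on spaces". A hedged sketch of a cross-platform alternative, assuming LibreOffice's `soffice` binary is installed on the host (not part of this commit):

```python
import os
import subprocess

def convert_ppt_to_pdf(path: str) -> str:
    """Convert a .ppt/.pptx file to PDF via headless LibreOffice; return the PDF path."""
    out_dir = os.path.dirname(path) or "."
    # --convert-to pdf writes <basename>.pdf into --outdir
    subprocess.run(
        ["soffice", "--headless", "--convert-to", "pdf", "--outdir", out_dir, path],
        check=True,
    )
    return os.path.splitext(path)[0] + ".pdf"
```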
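Note also that the new `.replace('-', '/')` in `list_downloaded_hf_models` rewrites every hyphen, so a cache entry like `vidore--colSmol-256M` becomes `vidore//colSmol/256M` rather than `vidore/colSmol-256M` as the old `.replace('--', '/')` produced. A sketch that sidesteps the manual cache parsing entirely, assuming `huggingface_hub` is available:

```python
from huggingface_hub import scan_cache_dir

def list_downloaded_hf_models() -> list[str]:
    # scan_cache_dir() walks the default HF cache and yields one entry per cached repo
    return [r.repo_id for r in scan_cache_dir().repos if r.repo_type == "model"]
```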
colpali_manager.py CHANGED
@@ -17,16 +17,9 @@ import spaces
 
 
 #this part is for local runs
-torch.cuda.empty_cache()
 
-#get model name from .env variable & set directory & processor dir as the model names!
-import dotenv
-# Load the .env file
-dotenv_file = dotenv.find_dotenv()
-dotenv.load_dotenv(dotenv_file)
-
-model_name = os.environ['colpali'] #"vidore/colSmol-256M"
-device = get_torch_device("cuda") #try using cpu instead of cuda?
+model_name = "vidore/colSmol-256M"
+device = get_torch_device("cpu") #try using cuda instead of cpu?
 
 #switch to locally downloading models & loading locally rather than from hf
 #
@@ -35,48 +28,45 @@ current_working_directory = os.getcwd()
 save_directory = model_name # Directory to save the specific model name
 save_directory = os.path.join(current_working_directory, save_directory)
 
-processor_directory = model_name+'_processor' # Directory to save the processor
+processor_directory = 'local_processor' # Directory to save the processor
 processor_directory = os.path.join(current_working_directory, processor_directory)
 
 
+model = ColIdefics3.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    device_map=device,
+    #attn_implementation="flash_attention_2",
+).eval()
+processor = cast(ColIdefics3Processor, ColIdefics3Processor.from_pretrained(model_name))
 
+"""
 if not os.path.exists(save_directory): #download if directory not created/model not loaded
     # Directory does not exist; create it
-    if "colSmol" in model_name: #if colsmol
-        model = ColIdefics3.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            device_map=device,
-            attn_implementation="flash_attention_2",
-        ).eval()
-        processor = cast(ColIdefics3Processor, ColIdefics3Processor.from_pretrained(model_name))
-    else: #if colpali v1.3 etc
-        model = ColPali.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            device_map=device,
-            attn_implementation="flash_attention_2",
-        ).eval()
-        processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
     os.makedirs(save_directory)
     print(f"Directory '{save_directory}' created.")
+    model = ColIdefics3.from_pretrained(
+        model_name,
+        torch_dtype=torch.bfloat16,
+        device_map=device,
+        attn_implementation="flash_attention_2",
+    ).eval()
     model.save_pretrained(save_directory)
     os.makedirs(processor_directory)
+    processor = cast(ColIdefics3Processor, ColIdefics3Processor.from_pretrained(model_name))
+
     processor.save_pretrained(processor_directory)
 
 else:
-    if "colSmol" in model_name:
-        model = ColIdefics3.from_pretrained(save_directory)
-        processor = ColIdefics3Processor.from_pretrained(processor_directory, use_fast=True)
-    else:
-        model = ColPali.from_pretrained(save_directory)
-        processor = ColPaliProcessor.from_pretrained(processor_directory, use_fast=True)
+    model = ColIdefics3.from_pretrained(save_directory)
+    processor = ColIdefics3Processor.from_pretrained(processor_directory, use_fast=True)
+"""
 
 
 class ColpaliManager:
 
 
-    def __init__(self, device = "cuda", model_name = model_name): #need to hot potato/use diff gpus between colpali & ollama
+    def __init__(self, device = "cpu", model_name = "vidore/colSmol-256M"): #need to hot potato/use diff gpus between colpali & ollama
 
         print(f"Initializing ColpaliManager with device {device} and model {model_name}")
 
@@ -92,12 +82,12 @@ class ColpaliManager:
 
     @spaces.GPU
     def get_images(self, paths: list[str]) -> List[Image.Image]:
-        model.to("cuda")
+        model.to("cpu")
         return [Image.open(path) for path in paths]
 
     @spaces.GPU
     def process_images(self, image_paths:list[str], batch_size=5):
-        model.to("cuda")
+        model.to("cpu")
         print(f"Processing {len(image_paths)} image_paths")
 
         images = self.get_images(image_paths)
@@ -123,7 +113,7 @@ class ColpaliManager:
 
     @spaces.GPU
     def process_text(self, texts: list[str]):
-        model.to("cuda") #ensure this is commented out so ollama/multimodal llm can use gpu! (nah wrong, need to enable so that it can process multiple)
+        model.to("cpu") #ensure this is commented out so ollama/multimodal llm can use gpu! (nah wrong, need to enable so that it can process multiple)
         print(f"Processing {len(texts)} texts")
 
         dataloader = DataLoader(
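Note: this commit hard-codes `cpu` for both the module-level load and the per-call `model.to(...)` moves. A minimal device-selection sketch that keeps the old cuda path available where a GPU actually exists (illustrative, not part of the commit):

```python
import torch

def pick_device(prefer: str = "cuda") -> str:
    # Use CUDA only when it is actually present (local GPU box vs. CPU-only Space)
    return "cuda" if prefer == "cuda" and torch.cuda.is_available() else "cpu"

device = pick_device()
# model = ColIdefics3.from_pretrained(model_name, device_map=device).eval()
# ...and inside each @spaces.GPU method: model.to(device)
```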
milvus_manager.py CHANGED
@@ -2,18 +2,11 @@ from pymilvus import MilvusClient, DataType
 import numpy as np
 import concurrent.futures
 from pymilvus import Collection
-import os
 
 class MilvusManager:
     def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
-
-        #import environ variables from .env
-        import dotenv
-        # Load the .env file
-        dotenv_file = dotenv.find_dotenv()
-        dotenv.load_dotenv(dotenv_file)
-
-        self.client = MilvusClient(uri="localhost")
+        self.client = MilvusClient(uri=milvus_uri)
+        # self.client = MilvusClient(uri="http://localhost:19530", token="root:Milvus")
         self.collection_name = collection_name
         self.dim = dim
 
@@ -47,15 +40,13 @@ class MilvusManager:
 
     def create_index(self):
         index_params = self.client.prepare_index_params()
-
         index_params.add_index(
             field_name="vector",
             index_name="vector_index",
-            index_type="HNSW", #use HNSW option if got more mem, if not use IVF for faster processing
-            metric_type=os.environ["metrictype"], #"IP"
+            index_type="IVF_FLAT", #use HNSW option if got more mem, if not use IVF for faster processing
+            metric_type="IP",
             params={
-                "M": int(os.environ["mnum"]), #M:16 for HNSW, capital M
-                "efConstruction": int(os.environ["efnum"]), #500 for HNSW
+                "nlist": 1024
             },
         )
 
@@ -68,7 +59,7 @@ class MilvusManager:
         collections = self.client.list_collections()
 
         # Set search parameters (here, using Inner Product metric).
-        search_params = {"metric_type": os.environ["metrictype"], "params": {}} #default metric type is "IP"
+        search_params = {"metric_type": "IP", "params": {}}
 
         # Set to store unique (doc_id, collection_name) pairs across all collections.
         doc_collection_pairs = set()
@@ -80,7 +71,7 @@ class MilvusManager:
         results = self.client.search(
             collection,
             data,
-            limit=int(os.environ["topk"]), # Adjust limit per collection as needed. (default is 50)
+            limit=50, # Adjust limit per collection as needed.
             output_fields=["vector", "seq_id", "doc_id"],
             search_params=search_params,
         )
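Note: this commit swaps the HNSW index for IVF_FLAT and pins the metric to "IP". A side-by-side sketch of the two index configurations with pymilvus (parameter values mirror the old defaults from this repo and are assumptions, not tuning advice):

```python
from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")
index_params = client.prepare_index_params()

# IVF_FLAT (this commit): cheaper to build and hold in memory; recall governed by nlist/nprobe
index_params.add_index(
    field_name="vector",
    index_name="vector_index",
    index_type="IVF_FLAT",
    metric_type="IP",
    params={"nlist": 1024},
)

# HNSW (previous behaviour): more memory, typically better recall/latency at query time
# index_params.add_index(
#     field_name="vector",
#     index_name="vector_index",
#     index_type="HNSW",
#     metric_type="IP",
#     params={"M": 16, "efConstruction": 500},
# )
```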
rag.py CHANGED
@@ -1,17 +1,10 @@
 import requests
 import os
+import google.generativeai as genai
 
 from typing import List
 from utils import encode_image
 from PIL import Image
-from ollama import chat
-import torch
-import subprocess
-import psutil
-import torch
-from transformers import AutoModel, AutoTokenizer
-import google.generativeai as genai
-
 
 
 
@@ -40,17 +33,10 @@ class Rag:
         except Exception as e:
             print(f"An error occurred while querying Gemini: {e}")
             return f"Error: {str(e)}"
-
 
     #os.environ['OPENAI_API_KEY'] = "for the love of Jesus let this work"
-
-    def get_answer_from_openai(self, query, imagesPaths):
-        #import environ variables from .env
-        import dotenv
 
-        # Load the .env file
-        dotenv_file = dotenv.find_dotenv()
-        dotenv.load_dotenv(dotenv_file)
+    def get_answer_from_openai(self, query, imagesPaths):
         """ #scuffed local hf inference (transformers incompatible to colpali version req, use ollama, more reliable, easier to use plus web server ready)
         print(f"Querying for query={query}, imagesPaths={imagesPaths}")
 
@@ -79,12 +65,8 @@ class Rag:
 
         #ollama method below
 
-        torch.cuda.empty_cache() #release cuda so that ollama can use gpu!
-
-
-        os.environ['OLLAMA_FLASH_ATTENTION'] = os.environ['flashattn'] #int "1"
-        if os.environ['ollama'] == "minicpm-v":
-            os.environ['ollama'] = "minicpm-v:8b-2.6-q8_0" #set to quantized version
+
+        os.environ['OLLAMA_FLASH_ATTENTION'] = '1'
 
 
         # Close model thread (colpali)
@@ -92,14 +74,13 @@ class Rag:
 
         try:
 
-            response = chat(
-                model=os.environ['ollama'],
-                messages=[
+            response = chat(
+                model='minicpm-v:8b-2.6-q8_0',
+                messages=[
                     {
                         'role': 'user',
                         'content': query,
                         'images': imagesPaths,
-                        "temperature":float(os.environ['temperature']), #test if temp makes a diff
                     }
                 ],
             )
@@ -155,4 +136,4 @@ class Rag:
 # query = "Based on attached images, how many new cases were reported during second wave peak"
 # imagesPaths = ["covid_slides_page_8.png", "covid_slides_page_8.png"]
 
-# rag.get_answer_from_gemini(query, imagesPaths)
+# rag.get_answer_from_gemini(query, imagesPaths)
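Note: this commit deletes `from ollama import chat` while `get_answer_from_openai` still calls `chat(...)`, so the ollama path only runs if that import survives. A minimal sketch of the call with the ollama Python client, using the model name pinned in the commit (error handling omitted for brevity):

```python
import os
from ollama import chat  # dropped from the imports in this commit, but still required below

os.environ['OLLAMA_FLASH_ATTENTION'] = '1'

def answer_with_ollama(query: str, image_paths: list[str]) -> str:
    # The client reads the local image files and ships them with the prompt
    response = chat(
        model='minicpm-v:8b-2.6-q8_0',
        messages=[{'role': 'user', 'content': query, 'images': image_paths}],
    )
    return response['message']['content']
```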
requirements.txt CHANGED
@@ -1,15 +1,9 @@
1
- gradio
2
- PyMuPDF
3
- pdf2image
4
- pymilvus
5
- tqdm
6
- pillow
7
- spaces
8
- google-generativeai
9
  git+https://github.com/illuin-tech/colpali
10
- timm==1.0.13
11
- transformers
12
- https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp311-cp311-win_amd64.wh
13
- comtypes
14
- python-dotenv
15
- colpali-engine[interpretability]
 
 
 
 
 
 
 
 
 
 
 
1
  git+https://github.com/illuin-tech/colpali
2
+ gradio==4.25.0
3
+ PyMuPDF==1.24.9
4
+ pdf2image==1.17.0
5
+ pymilvus==2.4.9
6
+ tqdm==4.66.5
7
+ pillow==10.4.0
8
+ spaces==0.30.4
9
+ google-generativeai==0.8.3
test.py ADDED
@@ -0,0 +1,30 @@
+from pymilvus import MilvusClient
+from pymilvus import (
+    connections,
+    utility,
+    FieldSchema, CollectionSchema, DataType,
+    Collection,
+)
+
+# 1. Create a milvus client
+client = MilvusClient(
+    uri="http://localhost:19530",
+    token="root:Milvus"
+)
+
+# 2. Drop a collection
+client.drop_collection(collection_name="fy2025_budget_statement")
+
+# 3. List collections
+print(client.list_collections())
+
+# ['test_collection']
+
+"""
+res = client.get(
+    collection_name="colpali",
+    ids=[0, 1, 2],
+)
+
+print(res)
+"""
uploaded_files.txt ADDED
@@ -0,0 +1,3 @@
+EMERGING_MISSILE_THREATS_16382509
+handwriting
+multimediareport