Kazel committed on
Commit d9fa664 · 1 Parent(s): d8688d6
Files changed (6)
  1. .env +8 -0
  2. app.py +88 -49
  3. colpali_manager.py +36 -26
  4. milvus_manager.py +16 -7
  5. rag.py +25 -8
  6. requirements.txt +14 -8
.env ADDED
@@ -0,0 +1,8 @@
+colpali='vidore/colpali-v1.3'
+ollama='minicpm-v'
+flashattn='1'
+metrictype='IP'
+mnum='16'
+efnum='500'
+topk='50'
+temperature='0.8'
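These keys are read back through python-dotenv elsewhere in this commit; a minimal sketch of the load pattern (the int/float casts mirror how milvus_manager.py and rag.py consume the string values):

```python
# Minimal sketch: how the .env entries above are consumed at runtime.
import os
import dotenv

dotenv_file = dotenv.find_dotenv()  # locate .env in the project tree
dotenv.load_dotenv(dotenv_file)     # export its entries into os.environ

model_name = os.environ["colpali"]              # 'vidore/colpali-v1.3'
top_k = int(os.environ["topk"])                 # '50'  -> 50
temperature = float(os.environ["temperature"])  # '0.8' -> 0.8
```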
app.py CHANGED
@@ -11,10 +11,23 @@ from rag import Rag
 from pathlib import Path
 import subprocess
 import getpass
+# importing necessary functions from dotenv library
+from dotenv import load_dotenv, dotenv_values
+import dotenv
+import platform
+import time
+
+# loading variables from .env file
+dotenv_file = dotenv.find_dotenv()
+dotenv.load_dotenv(dotenv_file)
+
+#kickstart docker and ollama servers
+

 rag = Rag()


+
 def generate_uuid(state):
     # Check if UUID already exists in session state
     if state["user_uuid"] is None:
@@ -50,29 +63,6 @@ class PDFSearchApp:
         pdf_path=file.name
         #if ppt will get replaced with path of ppt!

-        #if extension is .ppt or .pptx, convert
-        if ext == ".ppt" or ext == ".pptx": #need to test with a ppt key...
-            '''
-            import comtypes.client
-            powerpoint = comtypes.client.CreateObject("PowerPoint.Application")
-            powerpoint.Visible = 1
-            presentation = powerpoint.Presentations.Open(file)
-            output_file = os.path.splitext(file)[0] + '.pdf'
-            output_directory = os.path.dirname(file)
-            presentation.SaveAs(os.path.join(output_directory, output_file), 32) # 32 is the formatType for PDF
-            presentation.Close()
-            powerpoint.Quit()
-            file = os.path.join(output_directory, output_file) #swap file to be used to the outputted pdf file instead
-            # Extract the last part of the path (file name)
-            name = os.path.basename(file)
-            # Split the base name into name and extension
-            name, ext = os.path.splitext(name)
-            print(name)
-            self.current_pdf = os.path.join(output_directory, output_file)
-            pdf_path = os.path.join(output_directory, output_file)'
-            '''
-            print("pptx not supported on spaces")
-

         # Replace spaces and hyphens with underscores in the name
         modified_filename = name.replace(" ", "_").replace("-", "_")
@@ -154,12 +144,9 @@ class PDFSearchApp:
     def delete(state,choice):
         #delete file in pages, then use middleware to delete collection
         # 1. Create a milvus client
-
-        client = MilvusClient(uri="localhost")
-        #client = MilvusClient(
-        #    uri="http://localhost:19530",
-        #    token="root:Milvus"
-        #    )
+        client = MilvusClient(
+            uri="localhost"
+        )
         path = f"pages/{choice}"
         if os.path.exists(path):
             shutil.rmtree(path)
@@ -168,6 +155,18 @@ class PDFSearchApp:
             return f"Deleted {choice}"
         else:
             return "Directory not found"
+    def dbupdate(state,metric_type,m_num,ef_num,topk):
+        os.environ['metrictype'] = metric_type
+        # Update the .env file with the new value
+        dotenv.set_key(dotenv_file, 'metrictype', metric_type)
+        os.environ['mnum'] = str(m_num)
+        dotenv.set_key(dotenv_file, 'mnum', str(m_num))
+        os.environ['efnum'] = str(ef_num)
+        dotenv.set_key(dotenv_file, 'efnum', str(ef_num))
+        os.environ['topk'] = str(topk)
+        dotenv.set_key(dotenv_file, 'topk', str(topk))
+
+        return "DB Settings Updated, Restart App To Load"

     def list_downloaded_hf_models(state):
         # Determine the cache directory
@@ -179,18 +178,19 @@ class PDFSearchApp:
         # Traverse the cache directory
         for repo_dir in hf_cache_dir.glob('models--*'):
             # Extract the model name from the directory structure
-            model_name = repo_dir.name.split('--', 1)[-1].replace('-', '/')
+            model_name = repo_dir.name.split('--', 1)[-1].replace('--', '/')
             model_names.append(model_name)

         return model_names


-    def list_downloaded_ollama_models(state,):
+    def list_downloaded_ollama_models(state):
         # Retrieve the current user's name
         username = getpass.getuser()

         # Construct the target directory path
-        base_path = f"C:\\Users\\{username}\\NEW_PATH\\manifests\\registry.ollama.ai\\library"
+        #base_path = f"C:\\Users\\{username}\\NEW_PATH\\manifests\\registry.ollama.ai\\library" #this is for if ollama pull is called from C://, if ollama pulls are called from the proj dir, use the NEW_PATH in the proj dir!
+        base_path = f"NEW_PATH\\manifests\\registry.ollama.ai\\library" #relative to proj dir! (IMPT: OLLAMA PULL COMMAND IN PROJ DIR!!!)

         try:
             # List all entries in the directory
@@ -206,18 +206,29 @@ class PDFSearchApp:
         except Exception as e:
             print(f"An error occurred: {e}")

-    def model_settings(state,hfchoice, ollamachoice,tokensize):
+    def model_settings(state,hfchoice, ollamachoice,flash, temp):
         os.environ['colpali'] = hfchoice
+        # Update the .env file with the new value
+        dotenv.set_key(dotenv_file, 'colpali', hfchoice)
         os.environ['ollama'] = ollamachoice
-        os.environ['tokens'] = tokensize
-        return "abc"
+        dotenv.set_key(dotenv_file, 'ollama', ollamachoice)
+        if flash == "Enabled":
+            os.environ['flashattn'] = "1"
+            dotenv.set_key(dotenv_file, 'flashattn', "1")
+        else:
+            os.environ['flashattn'] = "0"
+            dotenv.set_key(dotenv_file, 'flashattn', "0")
+        os.environ['temperature'] = str(temp)
+        dotenv.set_key(dotenv_file, 'temperature', str(temp))
+
+        return "Models Updated, Restart App To Use New Settings"



 def create_ui():
     app = PDFSearchApp()

-    with gr.Blocks(css="footer{display:none !important}") as demo:
+    with gr.Blocks(theme=gr.themes.Ocean(),css ="footer{display:none !important}") as demo:
         state = gr.State(value={"user_uuid": None})


@@ -256,26 +267,47 @@ def create_ui():
             with gr.Column():
                 # Button to delete (TBD)
                 choice = gr.Dropdown(list(app.display_file_list()),label="Choice")
-                delete_button = gr.Button("Delete Document From DB")
                 status1 = gr.Textbox(label="Deletion Status", interactive=False)
+                delete_button = gr.Button("Delete Document From DB")
+
+                # Create the dropdown component with default value as the first option
+                #Milvusindex = gr.Dropdown(["HNSW","FLAT", "IVF_FLAT", "IVF_SQ8", "IVF_PQ", "RHNSW_FLAT"], value="HNSW", label="Select Vector DB Index Parameter")
+                metric_type = gr.Dropdown(choices=["IP", "L2", "COSINE"],value="IP",label="Metric Type (Mathematical function to measure similarity)")
+                m_num = gr.Dropdown(
+                    choices=["8", "16", "32", "64"], value="16",label="M Vectors (Maximum number of neighbors each node can connect to in the graph)")
+                ef_num = gr.Slider(
+                    minimum=50,
+                    maximum=1000,
+                    value=500,
+                    step=10,
+                    label="EF Construction (Number of candidate neighbors considered for connection during index construction)"
+                )
+                topk = gr.Slider(
+                    minimum=1,
+                    maximum=100,
+                    value=50,
+                    step=1,
+                    label="Top-K (Maximum number of entities to return in a single search of a document)"
+                )
+                db_button = gr.Button("Update DB Settings")
+                status3 = gr.Textbox(label="DB Update Status", interactive=False)
+

         with gr.Tab("AI Model Settings"): #deletion of collections, changing of model parameters etc
             with gr.Column():
                 # Button to delete (TBD)
-                hfchoice = gr.Dropdown(app.list_downloaded_hf_models(),label="Visual Document Retrieval (VDR) Model")
-                ollamachoice = gr.Dropdown(app.list_downloaded_ollama_models(),label="Secondary Visual Retrieval-Augmented Generation (RAG) Model")
-                tokensize = gr.Slider(
-                    minimum=256,
-                    maximum=4096,
-                    value=20,
-                    step=10,
-                    label="Max tokens per response (Reply Length)"
+                hfchoice = gr.Dropdown(app.list_downloaded_hf_models(),value=os.environ['colpali'], label="Visual Document Retrieval (VDR) Model")
+                ollamachoice = gr.Dropdown(app.list_downloaded_ollama_models(),value=os.environ['ollama'],label="Secondary Visual Retrieval-Augmented Generation (RAG) Model")
+                flash = gr.Dropdown(["Enabled","Disabled"], value = "Enabled",label ="Flash Attention 2.0 Acceleration")
+                temp = gr.Slider(
+                    minimum=0.1,
+                    maximum=1,
+                    value=0.8,
+                    step=0.1,
+                    label="RAG Temperature"
                 )
                 model_button = gr.Button("Update Settings")
                 status2 = gr.Textbox(label="Update Status", interactive=False)
-
-
-

         # Event handlers
         file_input.change(
@@ -296,10 +328,16 @@ def create_ui():
             inputs=[choice],
             outputs=[status1]
         )
+
+        db_button.click(
+            fn=app.dbupdate,
+            inputs=[metric_type,m_num,ef_num,topk],
+            outputs=[status3]
+        )

         model_button.click(
             fn=app.model_settings,
-            inputs=[hfchoice, ollamachoice,tokensize],
+            inputs=[hfchoice, ollamachoice,flash,temp],
            outputs=[status2]
         )
@@ -307,5 +345,6 @@

 if __name__ == "__main__":
     demo = create_ui()
+    #demo.launch(auth=("admin", "pass1234")) for with login page config
     demo.launch()

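Both new handlers, model_settings() and dbupdate(), follow the same persist-then-restart pattern: mutate os.environ for the current process and write the value back to .env with dotenv.set_key so it survives the restart the status messages ask for. A condensed sketch of that pattern (the set_setting helper name is illustrative):

```python
# Condensed sketch of the persist-then-restart pattern used by
# model_settings() and dbupdate(); set_setting() is an illustrative helper.
import os
import dotenv

dotenv_file = dotenv.find_dotenv()

def set_setting(key: str, value) -> None:
    os.environ[key] = str(value)                  # visible to this process now
    dotenv.set_key(dotenv_file, key, str(value))  # persisted for the next launch

set_setting("topk", 50)          # e.g. the Top-K slider
set_setting("metrictype", "IP")  # e.g. the metric dropdown
```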
colpali_manager.py CHANGED
@@ -17,9 +17,16 @@ import spaces


 #this part is for local runs
+torch.cuda.empty_cache()

-model_name = "vidore/colSmol-256M"
-device = get_torch_device("cpu") #try using cpu instead of cpu?
+#get model name from .env variable & set directory & processor dir as the model names!
+import dotenv
+# Load the .env file
+dotenv_file = dotenv.find_dotenv()
+dotenv.load_dotenv(dotenv_file)
+
+model_name = os.environ['colpali'] #"vidore/colSmol-256M"
+device = get_torch_device("cuda") #try using cpu instead of cuda?

 #switch to locally downloading models & loading locally rather than from hf
 #
@@ -28,45 +35,48 @@ current_working_directory = os.getcwd()
 save_directory = model_name # Directory to save the specific model name
 save_directory = os.path.join(current_working_directory, save_directory)

-processor_directory = 'local_processor' # Directory to save the processor
+processor_directory = model_name+'_processor' # Directory to save the processor
 processor_directory = os.path.join(current_working_directory, processor_directory)


-model = ColIdefics3.from_pretrained(
-    model_name,
-    torch_dtype=torch.bfloat16,
-    device_map=device,
-    #attn_implementation="flash_attention_2",
-).eval()
-processor = cast(ColIdefics3Processor, ColIdefics3Processor.from_pretrained(model_name))

-"""
 if not os.path.exists(save_directory): #download if directory not created/model not loaded
     # Directory does not exist; create it
+    if "colSmol" in model_name: #if colsmol
+        model = ColIdefics3.from_pretrained(
+            model_name,
+            torch_dtype=torch.bfloat16,
+            device_map=device,
+            attn_implementation="flash_attention_2",
+        ).eval()
+        processor = cast(ColIdefics3Processor, ColIdefics3Processor.from_pretrained(model_name))
+    else: #if colpali v1.3 etc
+        model = ColPali.from_pretrained(
+            model_name,
+            torch_dtype=torch.bfloat16,
+            device_map=device,
+            attn_implementation="flash_attention_2",
+        ).eval()
+        processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
     os.makedirs(save_directory)
     print(f"Directory '{save_directory}' created.")
-    model = ColIdefics3.from_pretrained(
-        model_name,
-        torch_dtype=torch.bfloat16,
-        device_map=device,
-        attn_implementation="flash_attention_2",
-    ).eval()
     model.save_pretrained(save_directory)
     os.makedirs(processor_directory)
-    processor = cast(ColIdefics3Processor, ColIdefics3Processor.from_pretrained(model_name))
-
     processor.save_pretrained(processor_directory)

 else:
-    model = ColIdefics3.from_pretrained(save_directory)
-    processor = ColIdefics3.from_pretrained(processor_directory, use_fast=True)
-"""
+    if "colSmol" in model_name:
+        model = ColIdefics3.from_pretrained(save_directory)
+        processor = ColIdefics3Processor.from_pretrained(processor_directory, use_fast=True)
+    else:
+        model = ColPali.from_pretrained(save_directory)
+        processor = ColPaliProcessor.from_pretrained(processor_directory, use_fast=True)


 class ColpaliManager:


-    def __init__(self, device = "cpu", model_name = "vidore/colSmol-256M"): #need to hot potato/use diff gpus between colpali & ollama
+    def __init__(self, device = "cuda", model_name = model_name): #need to hot potato/use diff gpus between colpali & ollama

         print(f"Initializing ColpaliManager with device {device} and model {model_name}")

@@ -82,12 +92,12 @@ class ColpaliManager:

     @spaces.GPU
     def get_images(self, paths: list[str]) -> List[Image.Image]:
-        model.to("cpu")
+        model.to("cuda")
         return [Image.open(path) for path in paths]

     @spaces.GPU
     def process_images(self, image_paths:list[str], batch_size=5):
-        model.to("cpu")
+        model.to("cuda")
         print(f"Processing {len(image_paths)} image_paths")

         images = self.get_images(image_paths)
@@ -113,7 +123,7 @@ class ColpaliManager:

     @spaces.GPU
     def process_text(self, texts: list[str]):
-        model.to("cpu") #ensure this is commented out so ollama/multimodal llm can use gpu! (nah wrong, need to enable so that it can process multiple)
+        model.to("cuda") #ensure this is commented out so ollama/multimodal llm can use gpu! (nah wrong, need to enable so that it can process multiple)
         print(f"Processing {len(texts)} texts")

         dataloader = DataLoader(
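The module now dispatches on the checkpoint name: colSmol checkpoints load through ColIdefics3, everything else (e.g. vidore/colpali-v1.3) through ColPali. A condensed sketch of that dispatch outside the local-caching logic (assumes a CUDA device and flash-attn installed, as this commit does):

```python
# Condensed sketch of the name-based model dispatch added in this commit.
import os
import torch
from colpali_engine.models import (
    ColIdefics3, ColIdefics3Processor, ColPali, ColPaliProcessor,
)

model_name = os.environ["colpali"]  # e.g. 'vidore/colpali-v1.3'
if "colSmol" in model_name:
    model_cls, processor_cls = ColIdefics3, ColIdefics3Processor
else:
    model_cls, processor_cls = ColPali, ColPaliProcessor

model = model_cls.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    attn_implementation="flash_attention_2",  # commit assumes flash-attn is present
).eval()
processor = processor_cls.from_pretrained(model_name)
```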
milvus_manager.py CHANGED
@@ -2,11 +2,18 @@ from pymilvus import MilvusClient, DataType
 import numpy as np
 import concurrent.futures
 from pymilvus import Collection
+import os

 class MilvusManager:
     def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
-        self.client = MilvusClient(uri=milvus_uri)
-        # self.client = MilvusClient(uri="http://localhost:19530", token="root:Milvus")
+
+        #import environ variables from .env
+        import dotenv
+        # Load the .env file
+        dotenv_file = dotenv.find_dotenv()
+        dotenv.load_dotenv(dotenv_file)
+
+        self.client = MilvusClient(uri="http://localhost:19530", token="root:Milvus")
         self.collection_name = collection_name
         self.dim = dim

@@ -40,13 +47,15 @@ class MilvusManager:

     def create_index(self):
         index_params = self.client.prepare_index_params()
+
         index_params.add_index(
             field_name="vector",
             index_name="vector_index",
-            index_type="IVF_FLAT", #use HNSW option if got more mem, if not use IVF for faster processing
-            metric_type="IP",
+            index_type="HNSW", #use HNSW option if got more mem, if not use IVF for faster processing
+            metric_type=os.environ["metrictype"], #"IP"
             params={
-                "nlist": 1024
+                "M": int(os.environ["mnum"]), #M:16 for HNSW, capital M
+                "efConstruction": int(os.environ["efnum"]), #500 for HNSW
             },
         )

@@ -59,7 +68,7 @@ class MilvusManager:
         collections = self.client.list_collections()

         # Set search parameters (here, using Inner Product metric).
-        search_params = {"metric_type": "IP", "params": {}}
+        search_params = {"metric_type": os.environ["metrictype"], "params": {}} #default metric type is "IP"

         # Set to store unique (doc_id, collection_name) pairs across all collections.
         doc_collection_pairs = set()
@@ -71,7 +80,7 @@ class MilvusManager:
             results = self.client.search(
                 collection,
                 data,
-                limit=50, # Adjust limit per collection as needed.
+                limit=int(os.environ["topk"]), # Adjust limit per collection as needed. (default is 50)
                 output_fields=["vector", "seq_id", "doc_id"],
                 search_params=search_params,
             )
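For reference, M and efConstruction trade recall against memory and build time: larger values give a denser HNSW graph and better recall at higher cost. A standalone sketch of the index call with the .env defaults inlined (URI/token as used above; the collection name is illustrative):

```python
# Standalone sketch: the HNSW index this commit builds, .env defaults inlined.
from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530", token="root:Milvus")

index_params = client.prepare_index_params()
index_params.add_index(
    field_name="vector",
    index_name="vector_index",
    index_type="HNSW",
    metric_type="IP",                        # or "L2" / "COSINE"
    params={"M": 16, "efConstruction": 500}, # graph degree / build-time candidate pool
)
client.create_index("pages_demo", index_params)  # collection name illustrative
```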
rag.py CHANGED
@@ -1,10 +1,16 @@
 import requests
 import os
-import google.generativeai as genai

 from typing import List
 from utils import encode_image
 from PIL import Image
+from ollama import chat
+import torch
+import subprocess
+import psutil
+import torch
+from transformers import AutoModel, AutoTokenizer
+import google.generativeai as genai



@@ -35,8 +41,14 @@ class Rag:
         return f"Error: {str(e)}"

     #os.environ['OPENAI_API_KEY'] = "for the love of Jesus let this work"
-
+
     def get_answer_from_openai(self, query, imagesPaths):
+        #import environ variables from .env
+        import dotenv
+
+        # Load the .env file
+        dotenv_file = dotenv.find_dotenv()
+        dotenv.load_dotenv(dotenv_file)
         """ #scuffed local hf inference (transformers incompatible to colpali version req, use ollama, more reliable, easier to use plus web server ready)
         print(f"Querying for query={query}, imagesPaths={imagesPaths}")
@@ -65,8 +77,12 @@ class Rag:

         #ollama method below

-
-        os.environ['OLLAMA_FLASH_ATTENTION'] = '1'
+        torch.cuda.empty_cache() #release cuda so that ollama can use gpu!
+
+
+        os.environ['OLLAMA_FLASH_ATTENTION'] = os.environ['flashattn'] #int "1"
+        if os.environ['ollama'] == "minicpm-v":
+            os.environ['ollama'] = "minicpm-v:8b-2.6-q8_0" #set to quantized version


         # Close model thread (colpali)
@@ -74,13 +90,14 @@ class Rag:

         try:

             response = chat(
-                model='minicpm-v:8b-2.6-q8_0',
+                model=os.environ['ollama'],
                 messages=[
                     {
                         'role': 'user',
                         'content': query,
                         'images': imagesPaths,
+                        "temperature":float(os.environ['temperature']), #test if temp makes a diff
                     }
                 ],
             )
@@ -136,4 +153,4 @@
 # query = "Based on attached images, how many new cases were reported during second wave peak"
 # imagesPaths = ["covid_slides_page_8.png", "covid_slides_page_8.png"]

-# rag.get_answer_from_gemini(query, imagesPaths)
+# rag.get_answer_from_gemini(query, imagesPaths)
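One caveat worth flagging: in the ollama Python client, sampling parameters are normally passed through the options argument of chat(), not as keys inside a message, so the in-message "temperature" above may be ignored (the author's own comment says "test if temp makes a diff"). A hedged variant using options (query and image path illustrative):

```python
# Hedged variant: pass temperature via `options`, which is where
# ollama-python expects sampling parameters.
import os
from ollama import chat

response = chat(
    model=os.environ.get("ollama", "minicpm-v:8b-2.6-q8_0"),
    messages=[{
        "role": "user",
        "content": "Summarize the attached page.",  # illustrative query
        "images": ["pages/doc/page_1.png"],         # illustrative path
    }],
    options={"temperature": float(os.environ.get("temperature", "0.8"))},
)
print(response["message"]["content"])
```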
requirements.txt CHANGED
@@ -1,9 +1,15 @@
+gradio
+PyMuPDF
+pdf2image
+pymilvus
+tqdm
+pillow
+spaces
+google-generativeai
 git+https://github.com/illuin-tech/colpali
-gradio==4.25.0
-PyMuPDF==1.24.9
-pdf2image==1.17.0
-pymilvus==2.4.9
-tqdm==4.66.5
-pillow==10.4.0
-spaces==0.30.4
-google-generativeai==0.8.3
+timm==1.0.13
+transformers
+https://github.com/woct0rdho/triton-windows/releases/download/v3.2.0-windows.post10/triton-3.2.0-cp311-cp311-win_amd64.wh
+comtypes
+python-dotenv
+colpali-engine[interpretability]