ej68okap committed on
Commit
241c492
·
1 Parent(s): b9e672c

new code added

Files changed (8)
  1. README.md +76 -9
  2. app.py +133 -56
  3. colpali_manager.py +76 -41
  4. middleware.py +61 -20
  5. milvus_manager.py +101 -44
  6. pdf_manager.py +47 -10
  7. rag.py +139 -55
  8. utils.py +13 -2
README.md CHANGED
@@ -1,12 +1,79 @@
  ---
- title: Multimodal Rag
- emoji: 🐨
- colorFrom: indigo
- colorTo: blue
- sdk: gradio
- sdk_version: 5.12.0
- app_file: app.py
- pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Multimodal RAG with ColPali, Milvus, and Visual Language Models
+
+ This repository demonstrates how to build a **Multimodal Retrieval-Augmented Generation (RAG)** application using **ColPali**, **Milvus**, and **Visual Language Models (VLMs)** such as Gemini or GPT-4o. The application lets users upload a PDF and run Q&A queries over both the textual and visual elements of the document.
+
+ ---
+
+ ## Features
+
+ - **Multimodal Q&A**: Combines visual and textual embeddings for robust query answering.
+ - **PDF as Images**: Treats PDF pages as images to preserve layout and visual context.
+ - **Efficient Retrieval**: Uses Milvus for fast and accurate vector search.
+ - **Advanced Query Processing**: Integrates ColPali and VLMs for embeddings and response generation.
+
  ---
+
+ ## Architecture Overview
+
+ 1. **ColPali**:
+    - Generates embeddings for images (PDF pages) and text (user queries).
+    - Processes visual and textual data seamlessly.
+
+ 2. **Milvus**:
+    - A vector database used for indexing and retrieving embeddings.
+    - Supports HNSW-based indexing for efficient similarity search.
+
+ 3. **Visual Language Models**:
+    - Gemini or GPT-4o performs context-aware Q&A over the retrieved pages.
+
  ---

+ ## Installation
+
+ ### Prerequisites
+ - Python 3.8 or higher
+ - CUDA-compatible GPU for acceleration
+ - Milvus installed and running ([Installation Guide](https://milvus.io/docs/install_standalone.md))
+ - Required Python packages (see `requirements.txt`)
+
+ ### Steps to Run the Application Locally
+ 1. Clone the repository.
+ 2. Install the dependencies: **pip install -r requirements.txt**
+ 3. Set up environment variables. Add the following to your `.env` file or environment:
+    GEMINI_API_KEY=<Your_Gemini_API_Key>
+ 4. Launch the Gradio app: **python app.py**
+
+ ### Deploying the Gradio App on Hugging Face Spaces
+ 1. Prepare the repository:
+    git clone https://github.com/saumitras/colpali-milvus-rag.git
+    cd colpali-milvus-rag
+ 2. Organize the repository: ensure the app file (e.g., app.py) contains the Gradio application code, and include requirements.txt for the dependencies.
+ 3. Configure secrets: add the necessary environment variables, such as GEMINI_API_KEY or OPENAI_API_KEY, to the Hugging Face Spaces secrets. Navigate to your Space, open the Settings tab, and add them under Repository secrets.
+ 4. Create a new Space: visit Hugging Face Spaces, click New Space, and fill in the details (Name: a unique name such as multimodal_rag; SDK: Gradio; Visibility: Public or Private). Then click Create Space.
+ 5. Push the code to Hugging Face:
+    git remote add hf https://huggingface.co/spaces/ultron1996/multimodal_rag
+    git push hf main
+ 6. Wait for the Space to build and deploy the application.
+
+ The app is deployed on Hugging Face Spaces; a live demo runs at https://huggingface.co/spaces/ultron1996/multimodal_rag
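For convenience, the local-run steps above condense into the following shell session (a sketch assembled only from commands already listed in this README; adjust paths and the key placeholder to your environment):

```bash
git clone https://github.com/saumitras/colpali-milvus-rag.git
cd colpali-milvus-rag
pip install -r requirements.txt

# rag.py reads the key from the environment (os.environ['GEMINI_API_KEY'])
echo "GEMINI_API_KEY=<Your_Gemini_API_Key>" > .env

python app.py
```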
app.py CHANGED
@@ -1,97 +1,171 @@
  import gradio as gr
  import tempfile
  import os
- import fitz  # PyMuPDF
  import uuid

-
  from middleware import Middleware
  from rag import Rag

- rag = Rag()

  def generate_uuid(state):
      # Check if UUID already exists in session state
      if state["user_uuid"] is None:
          # Generate a new UUID if not already set
          state["user_uuid"] = str(uuid.uuid4())
-
      return state["user_uuid"]


  class PDFSearchApp:
      def __init__(self):
-         self.indexed_docs = {}
-         self.current_pdf = None
-
-
      def upload_and_convert(self, state, file, max_pages):
-         id = generate_uuid(state)

-         if file is None:
              return "No file uploaded"

          print(f"Uploading file: {file.name}, id: {id}")
-
          try:
-             self.current_pdf = file.name

              middleware = Middleware(id, create_collection=True)

              pages = middleware.index(pdf_path=file.name, id=id, max_pages=max_pages)

              self.indexed_docs[id] = True
-
              return f"Uploaded and extracted {len(pages)} pages"
-         except Exception as e:
              return f"Error processing PDF: {str(e)}"
-
-
-     def search_documents(self, state, query, num_results=1):
          print(f"Searching for query: {query}")
-         id = generate_uuid(state)
-
-         if not self.indexed_docs[id]:
              print("Please index documents first")
-             return "Please index documents first", "--"
          if not query:
              print("Please enter a search query")
-             return "Please enter a search query", "--"
-
-         try:
              middleware = Middleware(id, create_collection=False)
-
-             search_results = middleware.search([query])[0]

-             page_num = search_results[0][1] + 1

-             print(f"Retrieved page number: {page_num}")

-             img_path = f"pages/{id}/page_{page_num}.png"

-             print(f"Retrieved image path: {img_path}")

-             rag_response = rag.get_answer_from_gemini(query, [img_path])

-             return img_path, rag_response
-
          except Exception as e:
-             return f"Error during search: {str(e)}", "--"

- def create_ui():
-     app = PDFSearchApp()

      with gr.Blocks() as demo:
-         state = gr.State(value={"user_uuid": None})

          gr.Markdown("# Colpali Milvus Multimodal RAG Demo")
-         gr.Markdown("This demo showcases how to use [Colpali](https://github.com/illuin-tech/colpali) embeddings with [Milvus](https://milvus.io/) and utilizing Gemini/OpenAI multimodal RAG for pdf search and Q&A.")
-
          with gr.Tab("Upload PDF"):
              with gr.Column():
                  file_input = gr.File(label="Upload PDF")
-
                  max_pages_input = gr.Slider(
                      minimum=1,
                      maximum=50,
@@ -99,38 +173,41 @@ def create_ui():
                      step=10,
                      label="Max pages to extract and index"
                  )
-
                  status = gr.Textbox(label="Indexing Status", interactive=False)
-
          with gr.Tab("Query"):
              with gr.Column():
                  query_input = gr.Textbox(label="Enter query")
-                 # num_results = gr.Slider(
-                 #     minimum=1,
-                 #     maximum=10,
-                 #     value=5,
-                 #     step=1,
-                 #     label="Number of results"
-                 # )
                  search_btn = gr.Button("Query")
                  llm_answer = gr.Textbox(label="RAG Response", interactive=False)
                  images = gr.Image(label="Top page matching query")
-
-         # Event handlers
          file_input.change(
              fn=app.upload_and_convert,
              inputs=[state, file_input, max_pages_input],
              outputs=[status]
          )
-
          search_btn.click(
              fn=app.search_documents,
              inputs=[state, query_input],
              outputs=[images, llm_answer]
          )
-
-     return demo

  if __name__ == "__main__":
-     demo = create_ui()
-     demo.launch()
  import gradio as gr
  import tempfile
  import os
+ import fitz  # PyMuPDF for working with PDF files
  import uuid

+ # Importing middleware and RAG (Retrieval-Augmented Generation) components
  from middleware import Middleware
  from rag import Rag

+ rag = Rag()  # Initializing RAG for question-answering functionality

+ # Function to generate a unique UUID for each user session
  def generate_uuid(state):
      # Check if UUID already exists in session state
      if state["user_uuid"] is None:
          # Generate a new UUID if not already set
          state["user_uuid"] = str(uuid.uuid4())
      return state["user_uuid"]


  class PDFSearchApp:
+     """Class to manage PDF upload, indexing, and querying."""
+
      def __init__(self):
+         self.indexed_docs = {}  # Dictionary to track indexed documents by user ID
+         self.current_pdf = None  # Store the currently processed PDF
+
+     # Function to handle file uploads and convert PDFs into searchable data
      def upload_and_convert(self, state, file, max_pages):
+         id = generate_uuid(state)  # Get unique user ID

+         if file is None:  # Check if a file was uploaded
              return "No file uploaded"

          print(f"Uploading file: {file.name}, id: {id}")
+
          try:
+             self.current_pdf = file.name  # Store the name of the uploaded file

+             # Initialize Middleware for indexing the PDF content
              middleware = Middleware(id, create_collection=True)

+             # Index the specified number of pages from the PDF
              pages = middleware.index(pdf_path=file.name, id=id, max_pages=max_pages)

+             # Mark the document as indexed for this user
              self.indexed_docs[id] = True
+
              return f"Uploaded and extracted {len(pages)} pages"
+         except Exception as e:  # Handle errors during processing
              return f"Error processing PDF: {str(e)}"
+
+     def search_documents(self, state, query, num_results=3):  # num_results controls how many pages are returned
+         """
+         Search for a query within indexed PDF documents and return multiple matching pages.
+
+         Args:
+             state (dict): Session state containing user-specific data.
+             query (str): The user's search query.
+             num_results (int): Number of top results to return (default is 3).
+
+         Returns:
+             tuple: (list of image paths, RAG response) or an error message if no match is found.
+         """
          print(f"Searching for query: {query}")
+         id = generate_uuid(state)  # Get unique user ID
+
+         # Check if the document has been indexed
+         if not self.indexed_docs.get(id, False):
              print("Please index documents first")
+             return "Please index documents first", None
+
+         # Check if a query was provided
          if not query:
              print("Please enter a search query")
+             return "Please enter a search query", None

+         try:
+             # Initialize Middleware for searching
              middleware = Middleware(id, create_collection=False)

+             # Perform the search and retrieve the top results
+             search_results = middleware.search([query])  # Returns matches for each query

+             # Check if there are valid search results
+             if not search_results or not search_results[0]:
+                 print("No relevant matches found in the PDF")
+                 return "No relevant matches found in the PDF", None

+             # Extract multiple matching pages (up to num_results)
+             image_paths = []
+             for i in range(min(len(search_results[0]), num_results)):  # Limit to num_results
+                 page_num = search_results[0][i][1] + 1  # Convert zero-based index to one-based
+                 img_path = f"pages/{id}/page_{page_num}.png"
+                 image_paths.append(img_path)

+             print(f"Retrieved image paths: {image_paths}")

+             # Get an answer from the RAG model using multiple images
+             rag_response = rag.get_answer_from_gemini(query, image_paths)
+
+             return image_paths, rag_response  # Return multiple image paths and RAG response

          except Exception as e:
+             # Handle and log any errors that occur
+             print(f"Error during search: {e}")
+             return f"Error during search: {str(e)}", None

+     # # Previous single-result version of search_documents, kept for reference
+     # def search_documents(self, state, query, num_results=1):
+     #     print(f"Searching for query: {query}")
+     #     id = generate_uuid(state)  # Get unique user ID
+
+     #     # Check if the document has been indexed
+     #     if not self.indexed_docs.get(id, False):
+     #         print("Please index documents first")
+     #         return "Please index documents first", "--"
+
+     #     # Check if a query was provided
+     #     if not query:
+     #         print("Please enter a search query")
+     #         return "Please enter a search query", "--"
+
+     #     try:
+     #         # Initialize Middleware for searching
+     #         middleware = Middleware(id, create_collection=False)
+
+     #         # Perform the search and retrieve the top result
+     #         search_results = middleware.search([query])[0]
+
+     #         # Extract the page number from the search results
+     #         page_num = search_results[0][1] + 1
+
+     #         print(f"Retrieved page number: {page_num}")
+
+     #         # Construct the image path for the retrieved page
+     #         img_path = f"pages/{id}/page_{page_num}.png"
+     #         print(f"Retrieved image path: {img_path}")
+
+     #         # Get an answer from the RAG model using the query and associated image
+     #         rag_response = rag.get_answer_from_gemini(query, [img_path])
+
+     #         return img_path, rag_response
+     #     except Exception as e:  # Handle errors during the search process
+     #         return f"Error during search: {str(e)}", "--"

+ # Function to create the Gradio user interface
+ def create_ui():
+     app = PDFSearchApp()  # Instantiate the PDFSearchApp class
+
      with gr.Blocks() as demo:
+         state = gr.State(value={"user_uuid": None})  # Initialize session state

+         # Header and introduction markdown
          gr.Markdown("# Colpali Milvus Multimodal RAG Demo")
+         gr.Markdown(
+             "This demo showcases how to use [Colpali](https://github.com/illuin-tech/colpali) embeddings with [Milvus](https://milvus.io/) and utilizing Gemini/OpenAI multimodal RAG for pdf search and Q&A."
+         )
+
+         # Upload PDF tab
          with gr.Tab("Upload PDF"):
              with gr.Column():
+                 # Input for uploading files
                  file_input = gr.File(label="Upload PDF")
+
+                 # Slider to select the maximum number of pages to index
                  max_pages_input = gr.Slider(
                      minimum=1,
                      maximum=50,
                      step=10,
                      label="Max pages to extract and index"
                  )
+
+                 # Textbox to display indexing status
                  status = gr.Textbox(label="Indexing Status", interactive=False)
+
+         # Query tab for searching documents
          with gr.Tab("Query"):
              with gr.Column():
+                 # Textbox for entering search queries
                  query_input = gr.Textbox(label="Enter query")
+
+                 # Button to trigger the search
                  search_btn = gr.Button("Query")
+
+                 # Textbox to display the response from RAG
                  llm_answer = gr.Textbox(label="RAG Response", interactive=False)
+
+                 # Image display for the top-matching page
                  images = gr.Image(label="Top page matching query")
+
+         # Event handlers to connect UI components with backend functions
          file_input.change(
              fn=app.upload_and_convert,
              inputs=[state, file_input, max_pages_input],
              outputs=[status]
          )
+
          search_btn.click(
              fn=app.search_documents,
              inputs=[state, query_input],
              outputs=[images, llm_answer]
          )

+     return demo  # Return the constructed UI
+
+ # Entry point to launch the application
  if __name__ == "__main__":
+     demo = create_ui()  # Create the Gradio interface
+     demo.launch()  # Launch the app
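One thing to note about the new `search_documents`: it returns a *list* of image paths, but the UI still routes that output into a single `gr.Image` component, and `Middleware.search` caps retrieval at `topk=1` (see middleware.py below), so at most one page ever comes back. A minimal sketch of one way to display several pages, assuming Gradio's `gr.Gallery` component:

```python
# Sketch: swap the single-image output for a gallery so the list returned
# by search_documents can be rendered directly.
images = gr.Gallery(label="Top pages matching query")

search_btn.click(
    fn=app.search_documents,
    inputs=[state, query_input],
    outputs=[images, llm_answer],  # (image_paths, rag_response) -> (Gallery, Textbox)
)
```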
colpali_manager.py CHANGED
@@ -1,97 +1,132 @@
- from colpali_engine.models import ColPali
- from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
- from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
- from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
- from torch.utils.data import DataLoader
- import torch
- from typing import List, cast
-
- from tqdm import tqdm
- from PIL import Image
- import os
-
- import spaces
-
  model_name = "vidore/colpali-v1.2"
- device = get_torch_device("cuda")

  model = ColPali.from_pretrained(
      model_name,
-     torch_dtype=torch.bfloat16,
-     device_map=device,
- ).eval()

  processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))

- class ColpaliManager:

-
-     def __init__(self, device = "cuda", model_name = "vidore/colpali-v1.2"):

          print(f"Initializing ColpaliManager with device {device} and model {model_name}")

          # self.device = get_torch_device(device)
-
          # self.model = ColPali.from_pretrained(
          #     model_name,
          #     torch_dtype=torch.bfloat16,
          #     device_map=self.device,
          # ).eval()
-
          # self.processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))

      @spaces.GPU
      def get_images(self, paths: list[str]) -> List[Image.Image]:
          return [Image.open(path) for path in paths]

      @spaces.GPU
-     def process_images(self, image_paths:list[str], batch_size=5):

          print(f"Processing {len(image_paths)} image_paths")
-
          images = self.get_images(image_paths)

          dataloader = DataLoader(
              dataset=ListDataset[str](images),
              batch_size=batch_size,
              shuffle=False,
-             collate_fn=lambda x: processor.process_images(x),
          )

-         ds: List[torch.Tensor] = []
-         for batch_doc in tqdm(dataloader):
-             with torch.no_grad():
                  batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
                  embeddings_doc = model(**batch_doc)
-                 ds.extend(list(torch.unbind(embeddings_doc.to(device))))
-
          ds_np = [d.float().cpu().numpy() for d in ds]

          return ds_np
-

      @spaces.GPU
      def process_text(self, texts: list[str]):
          print(f"Processing {len(texts)} texts")

          dataloader = DataLoader(
              dataset=ListDataset[str](texts),
-             batch_size=1,
              shuffle=False,
-             collate_fn=lambda x: processor.process_queries(x),
          )

-         qs: List[torch.Tensor] = []
-         for batch_query in dataloader:
-             with torch.no_grad():
                  batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
                  embeddings_query = model(**batch_query)

-                 qs.extend(list(torch.unbind(embeddings_query.to(device))))

          qs_np = [q.float().cpu().numpy() for q in qs]

          return qs_np
-
-
-
+ # Importing required modules and libraries
2
+ from colpali_engine.models import ColPali # Main ColPali model for embeddings
3
+ from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor # Preprocessing for ColPali
4
+ from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor # Base processor utility
5
+ from colpali_engine.utils.torch_utils import ListDataset, get_torch_device # Torch utilities for dataset and device management
6
+ from torch.utils.data import DataLoader # PyTorch DataLoader for batching
7
+ import torch # PyTorch library
8
+ from typing import List, cast # Type annotations
9
+ from tqdm import tqdm # Progress bar utility
10
+ from PIL import Image # Image processing library
11
+ import os # OS module for file path handling
12
+ import spaces # Custom decorator module for GPU management
13
+
14
+ # Setting model name and initializing device
15
  model_name = "vidore/colpali-v1.2"
16
+ device = get_torch_device("cuda") # Get the available CUDA device
17
 
18
+ # Load the ColPali model with the specified configuration
19
  model = ColPali.from_pretrained(
20
  model_name,
21
+ torch_dtype=torch.bfloat16, # Use bfloat16 for reduced precision
22
+ device_map=device, # Map the model to the selected device
23
+ ).eval() # Set the model to evaluation mode
24
 
25
+ # Initialize the processor for handling image and text inputs
26
  processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
27
 
 
28
 
29
+ class ColpaliManager:
30
+ """
31
+ A class to manage the processing of images and text using the ColPali model.
32
+ """
33
 
34
+ def __init__(self, device="cuda", model_name="vidore/colpali-v1.2"):
35
  print(f"Initializing ColpaliManager with device {device} and model {model_name}")
36
 
37
+ # Uncomment the below lines if the class should initialize its own model and processor
38
  # self.device = get_torch_device(device)
 
39
  # self.model = ColPali.from_pretrained(
40
  # model_name,
41
  # torch_dtype=torch.bfloat16,
42
  # device_map=self.device,
43
  # ).eval()
 
44
  # self.processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
45
 
46
  @spaces.GPU
47
  def get_images(self, paths: list[str]) -> List[Image.Image]:
48
+ """
49
+ Load images from the given file paths.
50
+
51
+ Args:
52
+ paths (list[str]): List of file paths to images.
53
+
54
+ Returns:
55
+ List[Image.Image]: List of loaded PIL Image objects.
56
+ """
57
  return [Image.open(path) for path in paths]
58
 
59
  @spaces.GPU
60
+ def process_images(self, image_paths: list[str], batch_size=5):
61
+ """
62
+ Process a list of image paths to generate embeddings.
63
+
64
+ Args:
65
+ image_paths (list[str]): List of image file paths.
66
+ batch_size (int): Batch size for processing images.
67
 
68
+ Returns:
69
+ list: List of image embeddings as NumPy arrays.
70
+ """
71
  print(f"Processing {len(image_paths)} image_paths")
72
+
73
+ # Load images
74
  images = self.get_images(image_paths)
75
 
76
+ # Create a DataLoader for batching the images
77
  dataloader = DataLoader(
78
  dataset=ListDataset[str](images),
79
  batch_size=batch_size,
80
  shuffle=False,
81
+ collate_fn=lambda x: processor.process_images(x), # Process images using the processor
82
  )
83
 
84
+ ds: List[torch.Tensor] = [] # Initialize a list to store embeddings
85
+ for batch_doc in tqdm(dataloader): # Iterate through batches with a progress bar
86
+ with torch.no_grad(): # Disable gradient calculations for inference
87
+ # Move batch to the model's device
88
  batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
89
+ # Generate embeddings
90
  embeddings_doc = model(**batch_doc)
91
+ ds.extend(list(torch.unbind(embeddings_doc.to(device)))) # Append each embedding to the list
92
+
93
+ # Convert embeddings to NumPy arrays
94
  ds_np = [d.float().cpu().numpy() for d in ds]
95
 
96
  return ds_np
 
97
 
98
  @spaces.GPU
99
  def process_text(self, texts: list[str]):
100
+ """
101
+ Process a list of text inputs to generate embeddings.
102
+
103
+ Args:
104
+ texts (list[str]): List of text inputs.
105
+
106
+ Returns:
107
+ list: List of text embeddings as NumPy arrays.
108
+ """
109
  print(f"Processing {len(texts)} texts")
110
 
111
+ # Create a DataLoader for batching the texts
112
  dataloader = DataLoader(
113
  dataset=ListDataset[str](texts),
114
+ batch_size=1, # Process texts one at a time
115
  shuffle=False,
116
+ collate_fn=lambda x: processor.process_queries(x), # Process texts using the processor
117
  )
118
 
119
+ qs: List[torch.Tensor] = [] # Initialize a list to store text embeddings
120
+ for batch_query in dataloader: # Iterate through batches
121
+ with torch.no_grad(): # Disable gradient calculations for inference
122
+ # Move batch to the model's device
123
  batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
124
+ # Generate embeddings
125
  embeddings_query = model(**batch_query)
126
 
127
+ qs.extend(list(torch.unbind(embeddings_query.to(device)))) # Append each embedding to the list
128
 
129
+ # Convert embeddings to NumPy arrays
130
  qs_np = [q.float().cpu().numpy() for q in qs]
131
 
132
  return qs_np
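As a quick illustration of how these two methods are consumed elsewhere in the app (the file paths and query are hypothetical; the per-token counts depend on ColPali's patch/token layout, while the embedding dimension of 128 matches the `dim=128` used in milvus_manager.py):

```python
manager = ColpaliManager()

# Multi-vector embeddings: one (num_patches, 128) array per page image
page_vecs = manager.process_images(["pages/demo/page_1.png", "pages/demo/page_2.png"])

# One (num_query_tokens, 128) array per query string
query_vec = manager.process_text(["What was the total revenue?"])[0]
```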
 
 
 
middleware.py CHANGED
@@ -1,56 +1,97 @@
- from colpali_manager import ColpaliManager
- from milvus_manager import MilvusManager
- from pdf_manager import PdfManager
- import hashlib
-
-
- pdf_manager = PdfManager()
- colpali_manager = ColpaliManager()


  class Middleware:
-     def __init__(self, id:str, create_collection=True):
          hashed_id = hashlib.md5(id.encode()).hexdigest()[:8]
          milvus_db_name = f"milvus_{hashed_id}.db"
          self.milvus_manager = MilvusManager(milvus_db_name, "colpali", create_collection)

-     def index(self, pdf_path: str, id:str, max_pages: int, pages: list[int] = None):
-
          print(f"Indexing {pdf_path}, id: {id}, max_pages: {max_pages}")

          image_paths = pdf_manager.save_images(id, pdf_path, max_pages)
-
          print(f"Saved {len(image_paths)} images")

          colbert_vecs = colpali_manager.process_images(image_paths)

          images_data = [{
-             "colbert_vecs": colbert_vecs[i],
-             "filepath": image_paths[i]
          } for i in range(len(image_paths))]

          print(f"Inserting {len(images_data)} images data to Milvus")

          self.milvus_manager.insert_images_data(images_data)

          print("Indexing completed")

-         return image_paths
-

-
      def search(self, search_queries: list[str]):
          print(f"Searching for {len(search_queries)} queries")

-         final_res = []

          for query in search_queries:
              print(f"Searching for query: {query}")
              query_vec = colpali_manager.process_text([query])[0]
              search_res = self.milvus_manager.search(query_vec, topk=1)
              print(f"Search result: {search_res} for query: {query}")
-             final_res.append(search_res)

-         return final_res
+ # Import necessary modules and classes
+ from colpali_manager import ColpaliManager  # Manages processing of images and text with the ColPali model
+ from milvus_manager import MilvusManager  # Manages interactions with the Milvus database
+ from pdf_manager import PdfManager  # Handles PDF processing tasks
+ import hashlib  # Library for creating hashed identifiers

+ # Initialize managers
+ pdf_manager = PdfManager()  # PDF manager instance for handling PDF-related operations
+ colpali_manager = ColpaliManager()  # ColPali manager instance for processing images and text


  class Middleware:
+     """
+     Middleware class that integrates PDF processing, image embedding, and database indexing/searching.
+     """
+
+     def __init__(self, id: str, create_collection=True):
+         """
+         Initialize the Middleware with a unique identifier and Milvus database setup.
+
+         Args:
+             id (str): Unique identifier for the user/session.
+             create_collection (bool): Whether to create a new collection in the Milvus database.
+         """
+         # Generate a hashed ID for the Milvus database name
          hashed_id = hashlib.md5(id.encode()).hexdigest()[:8]
          milvus_db_name = f"milvus_{hashed_id}.db"
+
+         # Initialize the Milvus manager with the generated database name
          self.milvus_manager = MilvusManager(milvus_db_name, "colpali", create_collection)

+     def index(self, pdf_path: str, id: str, max_pages: int, pages: list[int] = None):
+         """
+         Index the content of a PDF file into the Milvus database.
+
+         Args:
+             pdf_path (str): Path to the PDF file.
+             id (str): Unique identifier for the session.
+             max_pages (int): Maximum number of pages to extract and index.
+             pages (list[int], optional): Specific pages to extract (default is None for all).
+
+         Returns:
+             list[str]: List of paths to the saved image files.
+         """
          print(f"Indexing {pdf_path}, id: {id}, max_pages: {max_pages}")

+         # Convert PDF pages into image files and save them
          image_paths = pdf_manager.save_images(id, pdf_path, max_pages)
          print(f"Saved {len(image_paths)} images")

+         # Generate image embeddings using the ColPali model
          colbert_vecs = colpali_manager.process_images(image_paths)

+         # Prepare data for insertion into Milvus
          images_data = [{
+             "colbert_vecs": colbert_vecs[i],  # Image embeddings
+             "filepath": image_paths[i]  # Corresponding image file path
          } for i in range(len(image_paths))]

          print(f"Inserting {len(images_data)} images data to Milvus")

+         # Insert the image data into the Milvus database
          self.milvus_manager.insert_images_data(images_data)

          print("Indexing completed")

+         return image_paths  # Return the list of saved image paths

      def search(self, search_queries: list[str]):
+         """
+         Search for matching results in the indexed database based on text queries.
+
+         Args:
+             search_queries (list[str]): List of search queries.
+
+         Returns:
+             list: Search results for each query.
+         """
          print(f"Searching for {len(search_queries)} queries")

+         final_res = []  # List to store the final search results

          for query in search_queries:
              print(f"Searching for query: {query}")
+
+             # Process the query text to generate an embedding
              query_vec = colpali_manager.process_text([query])[0]
+
+             # Perform the search in the Milvus database
              search_res = self.milvus_manager.search(query_vec, topk=1)
+
              print(f"Search result: {search_res} for query: {query}")

+             # Append the search results to the final results list
+             final_res.append(search_res)

+         return final_res  # Return all search results
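The shape of what `search` returns matters downstream: one list per query, where each entry is a `(score, doc_id)` tuple from `MilvusManager.search`, and `doc_id` is a zero-based page index that app.py converts to a one-based page number. Note also that `topk=1` is hardcoded here, so app.py's `num_results=3` path still only ever sees one match; raising `topk` here would be needed to return more. A small sketch of consuming the result (user id and query are hypothetical):

```python
middleware = Middleware("user-123", create_collection=False)
results = middleware.search(["total revenue"])  # [[(score, doc_id), ...]]
top_score, top_doc_id = results[0][0]
page_num = top_doc_id + 1  # one-based page number, as used in app.py
```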
milvus_manager.py CHANGED
@@ -1,69 +1,99 @@
- from pymilvus import MilvusClient, DataType
- import numpy as np
- import concurrent.futures
-

  class MilvusManager:
      def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
-         self.client = MilvusClient(uri=milvus_uri)
          self.collection_name = collection_name
          if self.client.has_collection(collection_name=self.collection_name):
              self.client.load_collection(collection_name)
-         self.dim = dim

          if create_collection:
-             self.create_collection()
-             self.create_index()
-

      def create_collection(self):
          if self.client.has_collection(collection_name=self.collection_name):
              self.client.drop_collection(collection_name=self.collection_name)
          schema = self.client.create_schema(
-             auto_id=True,
-             enable_dynamic_fields=True,
          )
-         schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
          schema.add_field(
-             field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
          )
-         schema.add_field(field_name="seq_id", datatype=DataType.INT16)
-         schema.add_field(field_name="doc_id", datatype=DataType.INT64)
-         schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)

          self.client.create_collection(
              collection_name=self.collection_name, schema=schema
          )

      def create_index(self):
          self.client.release_collection(collection_name=self.collection_name)
-         self.client.drop_index(
-             collection_name=self.collection_name, index_name="vector"
-         )
          index_params = self.client.prepare_index_params()
          index_params.add_index(
              field_name="vector",
              index_name="vector_index",
-             index_type="HNSW",
-             metric_type="IP",
              params={
-                 "M": 16,
-                 "efConstruction": 500,
              },
          )

          self.client.create_index(
              collection_name=self.collection_name, index_params=index_params, sync=True
          )

      def create_scalar_index(self):
          self.client.release_collection(collection_name=self.collection_name)

          index_params = self.client.prepare_index_params()
          index_params.add_index(
              field_name="doc_id",
              index_name="int32_index",
-             index_type="INVERTED",
          )

          self.client.create_index(
@@ -71,14 +101,26 @@ class MilvusManager:
          )

      def search(self, data, topk):
-         search_params = {"metric_type": "IP", "params": {}}
          results = self.client.search(
              self.collection_name,
              data,
-             limit=int(50),
-             output_fields=["vector", "seq_id", "doc_id"],
              search_params=search_params,
          )
          doc_ids = set()
          for r_id in range(len(results)):
              for r in range(len(results[r_id])):
@@ -86,19 +128,22 @@ class MilvusManager:

          scores = []

          def rerank_single_doc(doc_id, data, client, collection_name):
              doc_colbert_vecs = client.query(
                  collection_name=collection_name,
-                 filter=f"doc_id in [{doc_id}, {doc_id + 1}]",
-                 output_fields=["seq_id", "vector", "doc"],
-                 limit=1000,
              )
              doc_vecs = np.vstack(
                  [doc_colbert_vecs[i]["vector"] for i in range(len(doc_colbert_vecs))]
              )
              score = np.dot(data, doc_vecs.T).max(1).sum()
              return (score, doc_id)

          with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
              futures = {
                  executor.submit(
@@ -110,20 +155,25 @@ class MilvusManager:
              score, doc_id = future.result()
              scores.append((score, doc_id))

          scores.sort(key=lambda x: x[0], reverse=True)
-         if len(scores) >= topk:
-             return scores[:topk]
-         else:
-             return scores

      def insert(self, data):
          colbert_vecs = [vec for vec in data["colbert_vecs"]]
          seq_length = len(colbert_vecs)
          doc_ids = [data["doc_id"] for i in range(seq_length)]
          seq_ids = list(range(seq_length))
          docs = [""] * seq_length
-         docs[0] = data["filepath"]

          self.client.insert(
              self.collection_name,
              [
@@ -137,11 +187,17 @@ class MilvusManager:
              ],
          )

-     def get_images_as_doc(self, images_with_vectors:list):
-
-         images_data = []
          for i in range(len(images_with_vectors)):
              data = {
                  "colbert_vecs": images_with_vectors[i]["colbert_vecs"],
@@ -149,14 +205,15 @@ class MilvusManager:
                  "filepath": images_with_vectors[i]["filepath"],
              }
              images_data.append(data)
-
          return images_data

-
      def insert_images_data(self, image_data):
-         data = self.get_images_as_doc(image_data)

          for i in range(len(data)):
-             self.insert(data[i])
-
-
+ # Import necessary modules
+ from pymilvus import MilvusClient, DataType  # Milvus client and data type definitions
+ import numpy as np  # For numerical operations
+ import concurrent.futures  # For concurrent execution of tasks

  class MilvusManager:
+     """
+     A manager class for interacting with the Milvus database, handling collection creation,
+     data insertion, and search functionality.
+     """
+
      def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
+         """
+         Initialize the MilvusManager.
+
+         Args:
+             milvus_uri (str): URI for connecting to the Milvus server.
+             collection_name (str): Name of the collection in Milvus.
+             create_collection (bool): Whether to create a new collection.
+             dim (int): Dimensionality of the vector embeddings (default is 128).
+         """
+         self.client = MilvusClient(uri=milvus_uri)  # Initialize the Milvus client
          self.collection_name = collection_name
+         self.dim = dim
+
+         # Load the collection if it exists, otherwise create it
          if self.client.has_collection(collection_name=self.collection_name):
              self.client.load_collection(collection_name)

          if create_collection:
+             self.create_collection()  # Create a new collection
+             self.create_index()  # Create an index for the collection

      def create_collection(self):
+         """
+         Create a new collection in Milvus with a predefined schema.
+         """
+         # Drop the collection if it already exists
          if self.client.has_collection(collection_name=self.collection_name):
              self.client.drop_collection(collection_name=self.collection_name)
+
+         # Define the schema for the collection
          schema = self.client.create_schema(
+             auto_id=True,  # Enable automatic ID assignment
+             enable_dynamic_fields=True,  # Allow dynamic fields
          )
+         schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)  # Primary key
          schema.add_field(
+             field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim  # Vector field
          )
+         schema.add_field(field_name="seq_id", datatype=DataType.INT16)  # Sequence ID
+         schema.add_field(field_name="doc_id", datatype=DataType.INT64)  # Document ID
+         schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)  # Document path

+         # Create the collection with the specified schema
          self.client.create_collection(
              collection_name=self.collection_name, schema=schema
          )

      def create_index(self):
+         """
+         Create an HNSW index for the vector field in the collection.
+         """
+         # Release the collection before updating the index
          self.client.release_collection(collection_name=self.collection_name)
+         self.client.drop_index(collection_name=self.collection_name, index_name="vector")
+
+         # Define the HNSW index parameters
          index_params = self.client.prepare_index_params()
          index_params.add_index(
              field_name="vector",
              index_name="vector_index",
+             index_type="HNSW",  # Hierarchical Navigable Small World graph index
+             metric_type="IP",  # Inner Product (dot product) as the similarity metric
              params={
+                 "M": 16,  # Maximum number of graph connections per node
+                 "efConstruction": 500,  # Candidate-list size during index construction
              },
          )

+         # Create the index and synchronize with the server
          self.client.create_index(
              collection_name=self.collection_name, index_params=index_params, sync=True
          )

      def create_scalar_index(self):
+         """
+         Create an inverted index for scalar fields such as document IDs.
+         """
          self.client.release_collection(collection_name=self.collection_name)

          index_params = self.client.prepare_index_params()
          index_params.add_index(
              field_name="doc_id",
              index_name="int32_index",
+             index_type="INVERTED",  # Inverted index for scalar data
          )

          self.client.create_index(
          )

      def search(self, data, topk):
+         """
+         Search for the top-k most similar vectors in the collection.
+
+         Args:
+             data (array-like): Query vector.
+             topk (int): Number of top results to return.
+
+         Returns:
+             list: Sorted list of top-k results.
+         """
+         search_params = {"metric_type": "IP", "params": {}}  # Search parameters for Inner Product
          results = self.client.search(
              self.collection_name,
              data,
+             limit=50,  # Initial retrieval limit
+             output_fields=["vector", "seq_id", "doc_id"],  # Fields to include in the output
              search_params=search_params,
          )
+
+         # Collect unique document IDs from the search results
          doc_ids = set()
          for r_id in range(len(results)):
              for r in range(len(results[r_id])):

          scores = []

+         # Rerank a single document based on its relevance to the query
          def rerank_single_doc(doc_id, data, client, collection_name):
              doc_colbert_vecs = client.query(
                  collection_name=collection_name,
+                 filter=f"doc_id in [{doc_id}, {doc_id + 1}]",  # Query documents by ID
+                 output_fields=["seq_id", "vector", "doc"],  # Fields to retrieve
+                 limit=1000,  # Retrieve at most 1000 vectors per document
              )
+             # Compute the late-interaction (MaxSim) score for the document
              doc_vecs = np.vstack(
                  [doc_colbert_vecs[i]["vector"] for i in range(len(doc_colbert_vecs))]
              )
              score = np.dot(data, doc_vecs.T).max(1).sum()
              return (score, doc_id)

+         # Use multithreading to rerank documents in parallel
          with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
              futures = {
                  executor.submit(
              score, doc_id = future.result()
              scores.append((score, doc_id))

+         # Sort scores in descending order and return the top-k results
          scores.sort(key=lambda x: x[0], reverse=True)
+         return scores[:topk] if len(scores) >= topk else scores

      def insert(self, data):
+         """
+         Insert a batch of data into the collection.
+
+         Args:
+             data (dict): Dictionary containing vector embeddings and metadata.
+         """
          colbert_vecs = [vec for vec in data["colbert_vecs"]]
          seq_length = len(colbert_vecs)
          doc_ids = [data["doc_id"] for i in range(seq_length)]
          seq_ids = list(range(seq_length))
          docs = [""] * seq_length
+         docs[0] = data["filepath"]  # Store the file path in the first entry

+         # Insert the data into the collection
          self.client.insert(
              self.collection_name,
              [
              ],
          )

+     def get_images_as_doc(self, images_with_vectors: list):
+         """
+         Convert image data with vectors into a document-like format for insertion.
+
+         Args:
+             images_with_vectors (list): List of dictionaries containing image vectors and file paths.
+
+         Returns:
+             list: Transformed data ready for insertion.
+         """
+         images_data = []
          for i in range(len(images_with_vectors)):
              data = {
                  "colbert_vecs": images_with_vectors[i]["colbert_vecs"],
                  "filepath": images_with_vectors[i]["filepath"],
              }
              images_data.append(data)
          return images_data

      def insert_images_data(self, image_data):
+         """
+         Insert processed image data into the collection.
+
+         Args:
+             image_data (list): List of image data dictionaries.
+         """
+         data = self.get_images_as_doc(image_data)
          for i in range(len(data)):
+             self.insert(data[i])  # Insert each item individually
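The reranking line `np.dot(data, doc_vecs.T).max(1).sum()` is the ColBERT-style late-interaction (MaxSim) score: for every query-token vector, take the similarity of its best-matching document vector, then sum over query tokens. A tiny self-contained illustration with toy shapes (random data, not real embeddings):

```python
import numpy as np

query_vecs = np.random.rand(4, 128)   # 4 query-token vectors, dim 128
doc_vecs = np.random.rand(300, 128)   # 300 patch vectors for one page

sim = query_vecs @ doc_vecs.T         # (4, 300) token-to-patch similarities
score = sim.max(axis=1).sum()         # best match per query token, summed
```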
 
 
pdf_manager.py CHANGED
@@ -1,42 +1,79 @@
- from pdf2image import convert_from_path
- import os
- import shutil

  class PdfManager:
      def __init__(self):
          pass
-
      def clear_and_recreate_dir(self, output_folder):
          print(f"Clearing output folder {output_folder}")

          if os.path.exists(output_folder):
-             shutil.rmtree(output_folder)

          os.makedirs(output_folder)

      def save_images(self, id, pdf_path, max_pages, pages: list[int] = None) -> list[str]:
          output_folder = f"pages/{id}/"
-         images = convert_from_path(pdf_path)

          print(f"Saving images from {pdf_path} to {output_folder}. Max pages: {max_pages}")

          self.clear_and_recreate_dir(output_folder)

-         num_page_processed = 0

          for i, image in enumerate(images):
              if max_pages and num_page_processed >= max_pages:
                  break

              if pages and i not in pages:
                  continue

              full_save_path = f"{output_folder}/page_{i + 1}.png"

-             #print(f"Saving image to {full_save_path}")
-
              image.save(full_save_path, "PNG")

-             num_page_processed += 1

          return [f"{output_folder}/page_{i + 1}.png" for i in range(num_page_processed)]
+ # Import necessary modules
+ from pdf2image import convert_from_path  # Convert PDF pages to images
+ import os  # For file and directory operations
+ import shutil  # For removing and recreating directories

  class PdfManager:
+     """
+     A manager class for handling PDF-related operations, such as converting pages to images
+     and managing output directories.
+     """
+
      def __init__(self):
+         """
+         Initialize the PdfManager.
+         Currently, no attributes are set during initialization.
+         """
          pass
+
      def clear_and_recreate_dir(self, output_folder):
+         """
+         Clear the specified directory and recreate it.
+
+         Args:
+             output_folder (str): Path to the directory to be cleared and recreated.
+         """
          print(f"Clearing output folder {output_folder}")

+         # Remove the directory if it exists
          if os.path.exists(output_folder):
+             shutil.rmtree(output_folder)  # Delete the folder and its contents

+         # Recreate the directory
          os.makedirs(output_folder)

      def save_images(self, id, pdf_path, max_pages, pages: list[int] = None) -> list[str]:
+         """
+         Convert PDF pages to images and save them to a specified directory.
+
+         Args:
+             id (str): Unique identifier for the output folder.
+             pdf_path (str): Path to the PDF file to be processed.
+             max_pages (int): Maximum number of pages to convert and save.
+             pages (list[int], optional): Specific page numbers to convert (default is None for all).
+
+         Returns:
+             list[str]: List of paths to the saved images.
+         """
+         # Define the output folder for the images
          output_folder = f"pages/{id}/"

+         # Convert the PDF pages to images
+         images = convert_from_path(pdf_path)
          print(f"Saving images from {pdf_path} to {output_folder}. Max pages: {max_pages}")

+         # Clear the existing directory and recreate it
          self.clear_and_recreate_dir(output_folder)

+         num_page_processed = 0  # Counter for the number of pages processed

+         # Iterate through the converted images
          for i, image in enumerate(images):
+             # Stop processing once the maximum number of pages is reached
              if max_pages and num_page_processed >= max_pages:
                  break

+             # Skip pages not in the specified list (if provided)
              if pages and i not in pages:
                  continue

+             # Define the save path for the current page
              full_save_path = f"{output_folder}/page_{i + 1}.png"

+             # Save the image in PNG format
              image.save(full_save_path, "PNG")

+             num_page_processed += 1  # Increment the processed-page counter

+         # Return the paths of the saved images
          return [f"{output_folder}/page_{i + 1}.png" for i in range(num_page_processed)]
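Two practical notes on this file: `pdf2image.convert_from_path` requires the system `poppler` utilities to be installed, and since `output_folder` already ends with a slash, the `f"{output_folder}/page_{i + 1}.png"` paths contain a harmless double slash. A usage sketch (file names hypothetical):

```python
manager = PdfManager()

# Render and save up to 10 pages of sample.pdf under pages/demo-user/
paths = manager.save_images("demo-user", "sample.pdf", max_pages=10)
print(paths)  # e.g. ['pages/demo-user//page_1.png', ...]
```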
rag.py CHANGED
@@ -1,104 +1,188 @@
- import requests
- import os
- import google.generativeai as genai
-
- from typing import List
- from utils import encode_image
- from PIL import Image

  class Rag:

-     def get_answer_from_gemini(self, query, imagePaths):

-         print(f"Querying Gemini for query={query}, imagePaths={imagePaths}")

-         try:
-             genai.configure(api_key=os.environ['GEMINI_API_KEY'])
-             model = genai.GenerativeModel('gemini-1.5-flash')
-
-             images = [Image.open(path) for path in imagePaths]
-
-             chat = model.start_chat()

-             response = chat.send_message([*images, query])

-             answer = response.text

-             print(answer)
-
-             return answer
-
-         except Exception as e:
-             print(f"An error occurred while querying Gemini: {e}")
-             return f"Error: {str(e)}"
-

-     def get_answer_from_openai(self, query, imagesPaths):
          print(f"Querying OpenAI for query={query}, imagesPaths={imagesPaths}")

-         try:
              payload = self.__get_openai_api_payload(query, imagesPaths)

              headers = {
                  "Content-Type": "application/json",
-                 "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
              }
-
              response = requests.post(
                  url="https://api.openai.com/v1/chat/completions",
                  headers=headers,
                  json=payload
              )
-             response.raise_for_status()  # Raise an HTTPError for bad responses
-
              answer = response.json()["choices"][0]["message"]["content"]
-
-             print(answer)
-
              return answer
-
          except Exception as e:
              print(f"An error occurred while querying OpenAI: {e}")
              return None


-     def __get_openai_api_payload(self, query:str, imagesPaths:List[str]):
-         image_payload = []

          for imagePath in imagesPaths:
-             base64_image = encode_image(imagePath)
              image_payload.append({
                  "type": "image_url",
                  "image_url": {
-                     "url": f"data:image/jpeg;base64,{base64_image}"
                  }
              })

          payload = {
-             "model": "gpt-4o",
              "messages": [
                  {
-                     "role": "user",
                      "content": [
                          {
                              "type": "text",
-                             "text": query
                          },
-                         *image_payload
                      ]
                  }
              ],
-             "max_tokens": 1024
          }

          return payload
-
-
-
- # if __name__ == "__main__":
- #     rag = Rag()
-
- #     query = "Based on attached images, how many new cases were reported during second wave peak"
- #     imagesPaths = ["covid_slides_page_8.png", "covid_slides_page_8.png"]
-
- #     rag.get_answer_from_gemini(query, imagesPaths)
+ # Import required libraries
+ import requests  # For making HTTP requests
+ import os  # For accessing environment variables
+ import google.generativeai as genai  # For interacting with Google's Generative AI APIs
+ from typing import List  # For type annotations
+ from utils import encode_image  # Utility function to encode images as base64
+ from PIL import Image  # For image processing

  class Rag:
+     """
+     A class for interacting with Generative AI models (Gemini and OpenAI) to retrieve answers
+     based on user queries and associated images.
+     """

+     # Previous version of get_answer_from_gemini, kept for reference
+     # def get_answer_from_gemini(self, query: str, imagePaths: List[str]) -> str:
+     #     """
+     #     Query the Gemini model with a text query and associated images.
+
+     #     Args:
+     #         query (str): The user's query.
+     #         imagePaths (List[str]): List of file paths to images.
+
+     #     Returns:
+     #         str: The response text from the Gemini model.
+     #     """
+     #     print(f"Querying Gemini for query={query}, imagePaths={imagePaths}")

+     #     try:
+     #         # Configure the Gemini API client using the API key from environment variables
+     #         genai.configure(api_key=os.environ['GEMINI_API_KEY'])

+     #         # Initialize the Gemini generative model
+     #         model = genai.GenerativeModel('gemini-1.5-flash')

+     #         # Load images from the given paths
+     #         images = [Image.open(path) for path in imagePaths]

+     #         # Start a new chat session
+     #         chat = model.start_chat()
+
+     #         # Send the query and images to the model
+     #         response = chat.send_message([*images, query])
+
+     #         # Extract the response text
+     #         answer = response.text
+
+     #         print(answer)  # Log the answer
+
+     #         return answer
+
+     #     except Exception as e:
+     #         # Handle and log any errors that occur
+     #         print(f"An error occurred while querying Gemini: {e}")
+     #         return f"Error: {str(e)}"
+
+     def get_answer_from_openai(self, query: str, imagesPaths: List[str]) -> str:
+         """
+         Query OpenAI's GPT model with a text query and associated images.
+
+         Args:
+             query (str): The user's query.
+             imagesPaths (List[str]): List of file paths to images.
+
+         Returns:
+             str: The response text from OpenAI.
+         """
          print(f"Querying OpenAI for query={query}, imagesPaths={imagesPaths}")

+         try:
+             # Prepare the API payload with the query and images
              payload = self.__get_openai_api_payload(query, imagesPaths)

+             # Define the HTTP headers for the OpenAI API request
              headers = {
                  "Content-Type": "application/json",
+                 "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"  # API key from environment variables
              }
+
+             # Send a POST request to the OpenAI API
              response = requests.post(
                  url="https://api.openai.com/v1/chat/completions",
                  headers=headers,
                  json=payload
              )
+             response.raise_for_status()  # Raise an error for unsuccessful requests
+
+             # Extract the content of the response
              answer = response.json()["choices"][0]["message"]["content"]
+
+             print(answer)  # Log the answer
+
              return answer
+
          except Exception as e:
+             # Handle and log any errors that occur
              print(f"An error occurred while querying OpenAI: {e}")
              return None

+     def get_answer_from_gemini(self, query: str, imagePaths: List[str]) -> str:
+         """
+         Query the Gemini model with a text query and associated images.

+         Args:
+             query (str): The user's query.
+             imagePaths (List[str]): List of file paths to images.

+         Returns:
+             str: The response text from the Gemini model.
+         """
+         print(f"Querying Gemini for query={query}, imagePaths={imagePaths}")

+         try:
+             # Configure the Gemini API client using the API key from environment variables
+             genai.configure(api_key=os.environ['GEMINI_API_KEY'])
+
+             # Initialize the Gemini generative model
+             model = genai.GenerativeModel('gemini-1.5-flash')
+
+             # Load images from the given paths (skip missing files)
+             images = []
+             for path in imagePaths:
+                 if os.path.exists(path):
+                     images.append(Image.open(path))
+                 else:
+                     print(f"Warning: Image not found {path}, skipping.")
+
+             # Start a new chat session
+             chat = model.start_chat()
+
+             # Construct the input for the model (handle cases with and without images)
+             input_data = [query] if not images else [*images, query]
+
+             # Send the query (and images, if any) to the model
+             response = chat.send_message(input_data)
+
+             # Extract the response text
+             answer = response.text
+
+             print(answer)  # Log the answer
+
+             return answer
+
+         except Exception as e:
+             # Handle and log any errors that occur
+             print(f"An error occurred while querying Gemini: {e}")
+             return f"Error: {str(e)}"
+
+     def __get_openai_api_payload(self, query: str, imagesPaths: List[str]) -> dict:
+         """
+         Prepare the payload for the OpenAI API request.
+
+         Args:
+             query (str): The user's query.
+             imagesPaths (List[str]): List of file paths to images.
+
+         Returns:
+             dict: The payload for the OpenAI API request.
+         """
+         image_payload = []  # List to store encoded image data
+
+         # Encode each image as base64 and prepare the payload
          for imagePath in imagesPaths:
+             base64_image = encode_image(imagePath)  # Encode image in base64
              image_payload.append({
                  "type": "image_url",
                  "image_url": {
+                     "url": f"data:image/jpeg;base64,{base64_image}"  # Embed image data as a data URL
                  }
              })

+         # Create the complete payload for the API request
          payload = {
+             "model": "gpt-4o",  # Specify the OpenAI model
              "messages": [
                  {
+                     "role": "user",  # Role of the message sender
                      "content": [
                          {
                              "type": "text",
+                             "text": query  # Include the user's query
                          },
+                         *image_payload  # Include the image data
                      ]
                  }
              ],
+             "max_tokens": 1024  # Limit the response length
          }

          return payload
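A minimal usage sketch, adapted from the example that was commented out at the bottom of the old file (the page image path is hypothetical; GEMINI_API_KEY must be set in the environment):

```python
rag = Rag()

answer = rag.get_answer_from_gemini(
    "Based on the attached images, how many new cases were reported during the second wave peak?",
    ["pages/demo-user/page_8.png"],
)
print(answer)
```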
 
 
 
 
 
 
 
 
 
 
utils.py CHANGED
@@ -1,5 +1,16 @@
- import base64

- def encode_image(image_path):
      with open(image_path, "rb") as image_file:
          return base64.b64encode(image_file.read()).decode('utf-8')
+ import base64  # Library for encoding and decoding data in base64 format

+ def encode_image(image_path: str) -> str:
+     """
+     Encode an image file to a base64 string.
+
+     Args:
+         image_path (str): The file path of the image to be encoded.
+
+     Returns:
+         str: The base64-encoded string representation of the image.
+     """
+     # Open the image file in binary read mode
      with open(image_path, "rb") as image_file:
+         # Read the image content, encode it to base64, and decode it to a UTF-8 string
          return base64.b64encode(image_file.read()).decode('utf-8')
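Usage matches the payload construction in rag.py (path hypothetical). Note that the pages are saved as PNG, so `data:image/png;base64,` would be the more accurate MIME prefix than the `image/jpeg` used there, though the mismatch may still be accepted in practice:

```python
b64 = encode_image("pages/demo-user/page_1.png")
data_url = f"data:image/png;base64,{b64}"
```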