ej68okap committed · Commit 241c492 · Parent(s): b9e672c
new code added
Files changed:
- README.md +76 -9
- app.py +133 -56
- colpali_manager.py +76 -41
- middleware.py +61 -20
- milvus_manager.py +101 -44
- pdf_manager.py +47 -10
- rag.py +139 -55
- utils.py +13 -2
README.md
CHANGED
@@ -1,12 +1,79 @@
# Multimodal RAG with Colpali, Milvus, and Visual Language Models

This repository demonstrates how to build a **Multimodal Retrieval-Augmented Generation (RAG)** application using **Colpali**, **Milvus**, and **Visual Language Models (VLMs)** such as Gemini or GPT-4o. The application lets users upload a PDF and run Q&A queries over both the textual and visual elements of the document.

---

## Features

- **Multimodal Q&A**: Combines visual and textual embeddings for robust query answering.
- **PDF as Images**: Treats PDF pages as images to preserve layout and visual context.
- **Efficient Retrieval**: Uses Milvus for fast and accurate vector search.
- **Advanced Query Processing**: Integrates Colpali and VLMs for embeddings and response generation.

---

## Architecture Overview

1. **Colpali**:
   - Generates embeddings for images (PDF pages) and text (user queries).
   - Processes visual and textual data seamlessly.

2. **Milvus**:
   - A vector database used for indexing and retrieving embeddings.
   - Supports HNSW-based indexing for efficient similarity searches.

3. **Visual Language Models**:
   - Gemini or GPT-4o performs context-aware Q&A using the retrieved pages.

---

## Installation

### Prerequisites

- Python 3.8 or higher
- CUDA-compatible GPU for acceleration
- Milvus installed and running ([Installation Guide](https://milvus.io/docs/install_standalone.md))
- Required Python packages (see `requirements.txt`)

### Steps to Run the Application Locally

1. Clone the repository.
2. Install dependencies: **pip install -r requirements.txt**
3. Set up environment variables. Add the following to your `.env` file or environment:
   GEMINI_API_KEY=<Your_Gemini_API_Key>
4. Launch the Gradio app: **python app.py**

### Deploying the Gradio App on Hugging Face Spaces

1. Prepare the repository:
   git clone https://github.com/saumitras/colpali-milvus-rag.git
   cd colpali-milvus-rag
2. Organize the repository: make sure the app file (e.g., app.py) contains the Gradio application code, and include the requirements.txt file for dependencies.
3. Update the Hugging Face configuration: add the necessary environment variables, such as GEMINI_API_KEY or OPENAI_API_KEY, to the Hugging Face Spaces secrets. Navigate to your Space, open the Settings tab, and add the required secrets under Repository secrets.
4. Create a new Space: visit Hugging Face Spaces and click New Space. Fill in the details: Name (give your Space a unique name, e.g., multimodal_rag), SDK (select Gradio), and Visibility (choose Public or Private). Then click Create Space.
5. Push the code to Hugging Face:
   git remote add hf https://huggingface.co/spaces/ultron1996/multimodal_rag
   git push hf main
6. Wait for the Hugging Face Space to build and deploy the application.

The app has been deployed on Hugging Face Spaces and the demo is running at https://huggingface.co/spaces/ultron1996/multimodal_rag
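For readers who want to see how the pieces fit together outside the Gradio UI, here is a minimal end-to-end sketch built on the modules in this commit; the PDF path, user id, and query are illustrative placeholders.

```python
from middleware import Middleware
from rag import Rag

user_id = "demo-user"  # any stable per-session identifier (illustrative)

# Convert the PDF to page images, embed them with ColPali, and index them in Milvus
middleware = Middleware(user_id, create_collection=True)
middleware.index(pdf_path="sample.pdf", id=user_id, max_pages=10)

# Embed the query, search Milvus, and map the hits back to page images
query = "What does the chart on the revenue page show?"
hits = middleware.search([query])[0]  # list of (score, doc_id) tuples
top_pages = [f"pages/{user_id}/page_{doc_id + 1}.png" for _, doc_id in hits]

# Ask Gemini to answer using the retrieved page images
print(Rag().get_answer_from_gemini(query, top_pages))
```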
app.py
CHANGED
@@ -1,97 +1,171 @@
import gradio as gr
import tempfile
import os
import fitz  # PyMuPDF for working with PDF files
import uuid

# Importing middleware and RAG (Retrieval-Augmented Generation) components
from middleware import Middleware
from rag import Rag

rag = Rag()  # Initializing RAG for question-answering functionality

# Function to generate a unique UUID for each user session
def generate_uuid(state):
    # Check if a UUID already exists in the session state
    if state["user_uuid"] is None:
        # Generate a new UUID if not already set
        state["user_uuid"] = str(uuid.uuid4())
    return state["user_uuid"]


class PDFSearchApp:
    """Class to manage PDF upload, indexing, and querying."""

    def __init__(self):
        self.indexed_docs = {}  # Dictionary to track indexed documents by user ID
        self.current_pdf = None  # Store the currently processed PDF

    # Handle file uploads and convert PDFs into searchable data
    def upload_and_convert(self, state, file, max_pages):
        id = generate_uuid(state)  # Get the unique user ID

        if file is None:  # Check if a file was uploaded
            return "No file uploaded"

        print(f"Uploading file: {file.name}, id: {id}")

        try:
            self.current_pdf = file.name  # Store the name of the uploaded file

            # Initialize Middleware for indexing the PDF content
            middleware = Middleware(id, create_collection=True)

            # Index the specified number of pages from the PDF
            pages = middleware.index(pdf_path=file.name, id=id, max_pages=max_pages)

            # Mark the document as indexed for this user
            self.indexed_docs[id] = True

            return f"Uploaded and extracted {len(pages)} pages"
        except Exception as e:  # Handle errors during processing
            return f"Error processing PDF: {str(e)}"

    def search_documents(self, state, query, num_results=3):
        """
        Search for a query within indexed PDF documents and return multiple matching pages.

        Args:
            state (dict): Session state containing user-specific data.
            query (str): The user's search query.
            num_results (int): Number of top results to return (default is 3).

        Returns:
            tuple: (list of image paths, RAG response), or (None, error message) on failure.
        """
        print(f"Searching for query: {query}")
        id = generate_uuid(state)  # Get the unique user ID

        # Check if the document has been indexed
        if not self.indexed_docs.get(id, False):
            print("Please index documents first")
            return None, "Please index documents first"  # Image output first, message second

        # Check if a query was provided
        if not query:
            print("Please enter a search query")
            return None, "Please enter a search query"

        try:
            # Initialize Middleware for searching
            middleware = Middleware(id, create_collection=False)

            # Perform the search and retrieve the top results
            search_results = middleware.search([query])  # One result list per query

            # Check if there are valid search results
            if not search_results or not search_results[0]:
                print("No relevant matches found in the PDF")
                return None, "No relevant matches found in the PDF"

            # Extract multiple matching pages (up to num_results)
            image_paths = []
            for i in range(min(len(search_results[0]), num_results)):
                page_num = search_results[0][i][1] + 1  # Convert zero-based index to one-based
                img_path = f"pages/{id}/page_{page_num}.png"
                image_paths.append(img_path)

            print(f"Retrieved image paths: {image_paths}")

            # Get an answer from the RAG model using the retrieved images
            rag_response = rag.get_answer_from_gemini(query, image_paths)

            return image_paths, rag_response  # Return the image paths and the RAG response

        except Exception as e:
            # Handle and log any errors that occur
            print(f"Error during search: {e}")
            return None, f"Error during search: {str(e)}"


# Function to create the Gradio user interface
def create_ui():
    app = PDFSearchApp()  # Instantiate the PDFSearchApp class

    with gr.Blocks() as demo:
        state = gr.State(value={"user_uuid": None})  # Initialize the session state

        # Header and introduction markdown
        gr.Markdown("# Colpali Milvus Multimodal RAG Demo")
        gr.Markdown(
            "This demo shows how to use [Colpali](https://github.com/illuin-tech/colpali) embeddings with [Milvus](https://milvus.io/) and Gemini/OpenAI multimodal RAG for PDF search and Q&A."
        )

        # Upload PDF tab
        with gr.Tab("Upload PDF"):
            with gr.Column():
                # Input for uploading files
                file_input = gr.File(label="Upload PDF")

                # Slider to select the maximum number of pages to index
                max_pages_input = gr.Slider(
                    minimum=1,
                    maximum=50,
                    step=10,
                    label="Max pages to extract and index"
                )

                # Textbox to display indexing status
                status = gr.Textbox(label="Indexing Status", interactive=False)

        # Query tab for searching documents
        with gr.Tab("Query"):
            with gr.Column():
                # Textbox for entering search queries
                query_input = gr.Textbox(label="Enter query")

                # Button to trigger the search
                search_btn = gr.Button("Query")

                # Textbox to display the response from RAG
                llm_answer = gr.Textbox(label="RAG Response", interactive=False)

                # Gallery for the matching pages (search_documents returns a list of paths)
                images = gr.Gallery(label="Top pages matching query")

        # Event handlers to connect UI components with backend functions
        file_input.change(
            fn=app.upload_and_convert,
            inputs=[state, file_input, max_pages_input],
            outputs=[status]
        )

        search_btn.click(
            fn=app.search_documents,
            inputs=[state, query_input],
            outputs=[images, llm_answer]
        )

    return demo

# Entry point to launch the application
if __name__ == "__main__":
    demo = create_ui()  # Create the Gradio interface
    demo.launch()  # Launch the app
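A note on the session handling above: `gr.State` gives each browser session its own copy of the initial value, and `generate_uuid` mutates that dict in place, so concurrent visitors end up with separate UUIDs (and, via the hash in `Middleware`, separate Milvus database files). A minimal sketch of the same pattern in isolation; the component names are illustrative.

```python
import uuid

import gradio as gr

def whoami(state):
    # Each session holds its own dict; in-place mutation persists across events
    if state["user_uuid"] is None:
        state["user_uuid"] = str(uuid.uuid4())
    return f"Your session id: {state['user_uuid']}"

with gr.Blocks() as demo:
    state = gr.State(value={"user_uuid": None})
    btn = gr.Button("Who am I?")
    out = gr.Textbox()
    btn.click(fn=whoami, inputs=[state], outputs=[out])

if __name__ == "__main__":
    demo.launch()
```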
colpali_manager.py
CHANGED
@@ -1,97 +1,132 @@
# Importing required modules and libraries
from colpali_engine.models import ColPali  # Main ColPali model for embeddings
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor  # Preprocessing for ColPali
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor  # Base processor utility
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device  # Torch utilities for dataset and device management
from torch.utils.data import DataLoader  # PyTorch DataLoader for batching
import torch  # PyTorch library
from typing import List, cast  # Type annotations
from tqdm import tqdm  # Progress bar utility
from PIL import Image  # Image processing library
import os  # OS module for file path handling
import spaces  # Hugging Face `spaces` package, provides the GPU decorator

# Set the model name and initialize the device
model_name = "vidore/colpali-v1.2"
device = get_torch_device("cuda")  # Get the available CUDA device

# Load the ColPali model with the specified configuration
model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # Use bfloat16 for reduced precision
    device_map=device,  # Map the model to the selected device
).eval()  # Set the model to evaluation mode

# Initialize the processor for handling image and text inputs
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))


class ColpaliManager:
    """
    A class to manage the processing of images and text using the ColPali model.
    """

    def __init__(self, device="cuda", model_name="vidore/colpali-v1.2"):
        print(f"Initializing ColpaliManager with device {device} and model {model_name}")

        # Uncomment the lines below if the class should initialize its own model and processor
        # self.device = get_torch_device(device)
        # self.model = ColPali.from_pretrained(
        #     model_name,
        #     torch_dtype=torch.bfloat16,
        #     device_map=self.device,
        # ).eval()
        # self.processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))

    @spaces.GPU
    def get_images(self, paths: list[str]) -> List[Image.Image]:
        """
        Load images from the given file paths.

        Args:
            paths (list[str]): List of file paths to images.

        Returns:
            List[Image.Image]: List of loaded PIL Image objects.
        """
        return [Image.open(path) for path in paths]

    @spaces.GPU
    def process_images(self, image_paths: list[str], batch_size=5):
        """
        Process a list of image paths to generate embeddings.

        Args:
            image_paths (list[str]): List of image file paths.
            batch_size (int): Batch size for processing images.

        Returns:
            list: List of image embeddings as NumPy arrays.
        """
        print(f"Processing {len(image_paths)} image_paths")

        # Load the images
        images = self.get_images(image_paths)

        # Create a DataLoader for batching the images
        dataloader = DataLoader(
            dataset=ListDataset[str](images),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda x: processor.process_images(x),  # Preprocess images with the processor
        )

        ds: List[torch.Tensor] = []  # List to store embeddings
        for batch_doc in tqdm(dataloader):  # Iterate through batches with a progress bar
            with torch.no_grad():  # Disable gradient calculations for inference
                # Move the batch to the model's device
                batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
                # Generate the embeddings
                embeddings_doc = model(**batch_doc)
            ds.extend(list(torch.unbind(embeddings_doc.to(device))))  # Append each embedding

        # Convert the embeddings to NumPy arrays
        ds_np = [d.float().cpu().numpy() for d in ds]

        return ds_np

    @spaces.GPU
    def process_text(self, texts: list[str]):
        """
        Process a list of text inputs to generate embeddings.

        Args:
            texts (list[str]): List of text inputs.

        Returns:
            list: List of text embeddings as NumPy arrays.
        """
        print(f"Processing {len(texts)} texts")

        # Create a DataLoader for batching the texts
        dataloader = DataLoader(
            dataset=ListDataset[str](texts),
            batch_size=1,  # Process texts one at a time
            shuffle=False,
            collate_fn=lambda x: processor.process_queries(x),  # Preprocess queries with the processor
        )

        qs: List[torch.Tensor] = []  # List to store text embeddings
        for batch_query in dataloader:
            with torch.no_grad():  # Disable gradient calculations for inference
                # Move the batch to the model's device
                batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
                # Generate the embeddings
                embeddings_query = model(**batch_query)

            qs.extend(list(torch.unbind(embeddings_query.to(device))))  # Append each embedding

        # Convert the embeddings to NumPy arrays
        qs_np = [q.float().cpu().numpy() for q in qs]

        return qs_np
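ColPali is a late-interaction model: each page and each query is represented by many 128-dimensional vectors (one per image patch or query token) rather than a single pooled embedding, which is why `MilvusManager` below is created with `dim=128` and stores one row per vector. An illustrative shape check; the path is a placeholder and the exact vector counts depend on the processor configuration.

```python
manager = ColpaliManager()

page_vecs = manager.process_images(["pages/demo/page_1.png"])    # illustrative path
query_vecs = manager.process_text(["What is the total revenue?"])

# One array per page/query; each row is a 128-d patch/token vector
print(page_vecs[0].shape)   # e.g. (1030, 128)
print(query_vecs[0].shape)  # e.g. (22, 128)
```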
middleware.py
CHANGED
@@ -1,56 +1,97 @@
# Import necessary modules and classes
from colpali_manager import ColpaliManager  # Processes images and text with the ColPali model
from milvus_manager import MilvusManager  # Manages interactions with the Milvus database
from pdf_manager import PdfManager  # Handles PDF processing tasks
import hashlib  # For creating hashed identifiers

# Initialize the managers
pdf_manager = PdfManager()  # PDF manager instance for handling PDF-related operations
colpali_manager = ColpaliManager()  # ColPali manager instance for processing images and text


class Middleware:
    """
    Middleware class that integrates PDF processing, image embedding, and database indexing/searching.
    """

    def __init__(self, id: str, create_collection=True):
        """
        Initialize the Middleware with a unique identifier and Milvus database setup.

        Args:
            id (str): Unique identifier for the user/session.
            create_collection (bool): Whether to create a new collection in the Milvus database.
        """
        # Generate a hashed ID for the Milvus database name
        hashed_id = hashlib.md5(id.encode()).hexdigest()[:8]
        milvus_db_name = f"milvus_{hashed_id}.db"

        # Initialize the Milvus manager with the generated database name
        self.milvus_manager = MilvusManager(milvus_db_name, "colpali", create_collection)

    def index(self, pdf_path: str, id: str, max_pages: int, pages: list[int] = None):
        """
        Index the content of a PDF file into the Milvus database.

        Args:
            pdf_path (str): Path to the PDF file.
            id (str): Unique identifier for the session.
            max_pages (int): Maximum number of pages to extract and index.
            pages (list[int], optional): Specific pages to extract (default is None for all).

        Returns:
            list[str]: List of paths to the saved image files.
        """
        print(f"Indexing {pdf_path}, id: {id}, max_pages: {max_pages}")

        # Convert PDF pages into image files and save them
        image_paths = pdf_manager.save_images(id, pdf_path, max_pages)
        print(f"Saved {len(image_paths)} images")

        # Generate image embeddings using the ColPali model
        colbert_vecs = colpali_manager.process_images(image_paths)

        # Prepare the data for insertion into Milvus
        images_data = [{
            "colbert_vecs": colbert_vecs[i],  # Image embeddings
            "filepath": image_paths[i]  # Corresponding image file path
        } for i in range(len(image_paths))]

        print(f"Inserting {len(images_data)} images data to Milvus")

        # Insert the image data into the Milvus database
        self.milvus_manager.insert_images_data(images_data)

        print("Indexing completed")

        return image_paths  # Return the list of saved image paths

    def search(self, search_queries: list[str]):
        """
        Search for matching results in the indexed database based on text queries.

        Args:
            search_queries (list[str]): List of search queries.

        Returns:
            list: Search results for each query.
        """
        print(f"Searching for {len(search_queries)} queries")

        final_res = []  # List to store the final search results

        for query in search_queries:
            print(f"Searching for query: {query}")

            # Process the query text to generate an embedding
            query_vec = colpali_manager.process_text([query])[0]

            # Perform the search in the Milvus database
            search_res = self.milvus_manager.search(query_vec, topk=1)  # topk controls how many (score, doc_id) tuples come back

            print(f"Search result: {search_res} for query: {query}")

            # Append the search results to the final results list
            final_res.append(search_res)

        return final_res  # Return all search results
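Because the URI handed to `MilvusManager` ends in `.db`, pymilvus runs Milvus Lite against a local file rather than connecting to a server, so the MD5-hashed name gives every session its own isolated database. The derivation in isolation; the UUID below is illustrative.

```python
import hashlib

user_uuid = "b9e672c4-0000-4000-8000-deadbeef0001"  # illustrative session UUID
hashed_id = hashlib.md5(user_uuid.encode()).hexdigest()[:8]
print(f"milvus_{hashed_id}.db")  # e.g. milvus_0a1b2c3d.db
```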
milvus_manager.py
CHANGED
@@ -1,69 +1,99 @@
# Import necessary modules
from pymilvus import MilvusClient, DataType  # Milvus client and data type definitions
import numpy as np  # For numerical operations
import concurrent.futures  # For concurrent execution of tasks

class MilvusManager:
    """
    A manager class for interacting with the Milvus database, handling collection creation,
    data insertion, and search functionality.
    """

    def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
        """
        Initialize the MilvusManager.

        Args:
            milvus_uri (str): URI for connecting to the Milvus server.
            collection_name (str): Name of the collection in Milvus.
            create_collection (bool): Whether to create a new collection.
            dim (int): Dimensionality of the vector embeddings (default is 128).
        """
        self.client = MilvusClient(uri=milvus_uri)  # Initialize the Milvus client
        self.collection_name = collection_name
        self.dim = dim

        # Load the collection if it already exists
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(collection_name)

        if create_collection:
            self.create_collection()  # Create a new collection
            self.create_index()  # Create an index for the collection

    def create_collection(self):
        """
        Create a new collection in Milvus with a predefined schema.
        """
        # Drop the collection if it already exists
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.drop_collection(collection_name=self.collection_name)

        # Define the schema for the collection
        schema = self.client.create_schema(
            auto_id=True,  # Enable automatic ID assignment
            enable_dynamic_fields=True,  # Allow dynamic fields
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)  # Primary key
        schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim  # Vector field
        )
        schema.add_field(field_name="seq_id", datatype=DataType.INT16)  # Sequence ID
        schema.add_field(field_name="doc_id", datatype=DataType.INT64)  # Document ID
        schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)  # Document path

        # Create the collection with the specified schema
        self.client.create_collection(
            collection_name=self.collection_name, schema=schema
        )

    def create_index(self):
        """
        Create an HNSW index for the vector field in the collection.
        """
        # Release the collection before updating the index
        self.client.release_collection(collection_name=self.collection_name)
        self.client.drop_index(collection_name=self.collection_name, index_name="vector")

        # Define the HNSW index parameters
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_name="vector_index",
            index_type="HNSW",  # Hierarchical Navigable Small World graph index
            metric_type="IP",  # Inner Product (dot product) as the similarity metric
            params={
                "M": 16,  # Number of candidate connections per node
                "efConstruction": 500,  # Construction-time search breadth
            },
        )

        # Create the index and synchronize with the server
        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def create_scalar_index(self):
        """
        Create an inverted index for scalar fields such as document IDs.
        """
        self.client.release_collection(collection_name=self.collection_name)

        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="doc_id",
            index_name="int32_index",
            index_type="INVERTED",  # Inverted index for scalar data
        )

        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def search(self, data, topk):
        """
        Search for the top-k most similar documents in the collection.

        Args:
            data (array-like): Query vectors.
            topk (int): Number of top results to return.

        Returns:
            list: Sorted list of (score, doc_id) tuples for the top-k documents.
        """
        search_params = {"metric_type": "IP", "params": {}}  # Search parameters for Inner Product
        results = self.client.search(
            self.collection_name,
            data,
            limit=50,  # Initial retrieval limit
            output_fields=["vector", "seq_id", "doc_id"],  # Fields to include in the output
            search_params=search_params,
        )

        # Collect unique document IDs from the search results
        doc_ids = set()
        for r_id in range(len(results)):
            for r in range(len(results[r_id])):
                doc_ids.add(results[r_id][r]["entity"]["doc_id"])

        scores = []

        # Rerank a single document based on its relevance to the query
        def rerank_single_doc(doc_id, data, client, collection_name):
            doc_colbert_vecs = client.query(
                collection_name=collection_name,
                filter=f"doc_id in [{doc_id}, {doc_id + 1}]",  # Query documents by ID
                output_fields=["seq_id", "vector", "doc"],  # Fields to retrieve
                limit=1000,  # Retrieve a maximum of 1000 vectors per document
            )
            # Compute the maximum similarity score for the document
            doc_vecs = np.vstack(
                [doc_colbert_vecs[i]["vector"] for i in range(len(doc_colbert_vecs))]
            )
            score = np.dot(data, doc_vecs.T).max(1).sum()
            return (score, doc_id)

        # Use multithreading to rerank documents in parallel
        with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
            futures = {
                executor.submit(
                    rerank_single_doc, doc_id, data, self.client, self.collection_name
                ): doc_id
                for doc_id in doc_ids
            }
            for future in concurrent.futures.as_completed(futures):
                score, doc_id = future.result()
                scores.append((score, doc_id))

        # Sort scores in descending order and return the top-k results
        scores.sort(key=lambda x: x[0], reverse=True)
        return scores[:topk] if len(scores) >= topk else scores

    def insert(self, data):
        """
        Insert a batch of data into the collection.

        Args:
            data (dict): Dictionary containing vector embeddings and metadata.
        """
        colbert_vecs = [vec for vec in data["colbert_vecs"]]
        seq_length = len(colbert_vecs)
        doc_ids = [data["doc_id"] for i in range(seq_length)]
        seq_ids = list(range(seq_length))
        docs = [""] * seq_length
        docs[0] = data["filepath"]  # Store the file path in the first entry

        # Insert the data into the collection, one row per vector
        self.client.insert(
            self.collection_name,
            [
                {
                    "vector": colbert_vecs[i],
                    "seq_id": seq_ids[i],
                    "doc_id": doc_ids[i],
                    "doc": docs[i],
                }
                for i in range(seq_length)
            ],
        )

    def get_images_as_doc(self, images_with_vectors: list):
        """
        Convert image data with vectors into a document-like format for insertion.

        Args:
            images_with_vectors (list): List of dictionaries containing image vectors and file paths.

        Returns:
            list: Transformed data ready for insertion.
        """
        images_data = []
        for i in range(len(images_with_vectors)):
            data = {
                "colbert_vecs": images_with_vectors[i]["colbert_vecs"],
                "doc_id": i,
                "filepath": images_with_vectors[i]["filepath"],
            }
            images_data.append(data)
        return images_data

    def insert_images_data(self, image_data):
        """
        Insert processed image data into the collection.

        Args:
            image_data (list): List of image data dictionaries.
        """
        data = self.get_images_as_doc(image_data)
        for i in range(len(data)):
            self.insert(data[i])  # Insert each item individually
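The reranking score in `rerank_single_doc` is the standard ColBERT-style MaxSim: for every query vector, take the best-matching document vector, then sum over query vectors. A toy NumPy illustration with made-up sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
query_vecs = rng.random((4, 128))   # 4 query-token vectors
doc_vecs = rng.random((30, 128))    # 30 page-patch vectors

sim = np.dot(query_vecs, doc_vecs.T)  # (4, 30) token-vs-patch similarities
score = sim.max(1).sum()              # best patch per token, summed over tokens
print(score)
```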
pdf_manager.py
CHANGED
@@ -1,42 +1,79 @@
# Import necessary modules
from pdf2image import convert_from_path  # Convert PDF pages to images
import os  # For file and directory operations
import shutil  # For removing and recreating directories

class PdfManager:
    """
    A manager class for handling PDF-related operations, such as converting pages to images
    and managing output directories.
    """

    def __init__(self):
        """
        Initialize the PdfManager.
        Currently, no attributes are set during initialization.
        """
        pass

    def clear_and_recreate_dir(self, output_folder):
        """
        Clear the specified directory and recreate it.

        Args:
            output_folder (str): Path to the directory to be cleared and recreated.
        """
        print(f"Clearing output folder {output_folder}")

        # Remove the directory if it exists
        if os.path.exists(output_folder):
            shutil.rmtree(output_folder)  # Delete the folder and its contents

        # Recreate the directory
        os.makedirs(output_folder)

    def save_images(self, id, pdf_path, max_pages, pages: list[int] = None) -> list[str]:
        """
        Convert PDF pages to images and save them to a specified directory.

        Args:
            id (str): Unique identifier for the output folder.
            pdf_path (str): Path to the PDF file to be processed.
            max_pages (int): Maximum number of pages to convert and save.
            pages (list[int], optional): Specific page numbers to convert (default is None for all).

        Returns:
            list[str]: List of paths to the saved images.
        """
        # Define the output folder for the images
        output_folder = f"pages/{id}/"

        # Convert the PDF pages to images
        images = convert_from_path(pdf_path)
        print(f"Saving images from {pdf_path} to {output_folder}. Max pages: {max_pages}")

        # Clear the existing directory and recreate it
        self.clear_and_recreate_dir(output_folder)

        num_page_processed = 0  # Counter for the number of pages processed

        # Iterate through the converted images
        for i, image in enumerate(images):
            # Stop processing if the maximum number of pages is reached
            if max_pages and num_page_processed >= max_pages:
                break

            # Skip pages not in the specified list (if provided)
            if pages and i not in pages:
                continue

            # Define the save path for the current page
            full_save_path = f"{output_folder}/page_{i + 1}.png"

            # Save the image in PNG format
            image.save(full_save_path, "PNG")

            num_page_processed += 1  # Increment the processed page counter

        # Return the paths of the saved images
        return [f"{output_folder}/page_{i + 1}.png" for i in range(num_page_processed)]
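Note that `pdf2image` shells out to the Poppler utilities, which must be installed on the host (e.g. `apt-get install poppler-utils` locally, or a `packages.txt` entry on Spaces). Rendering resolution is a speed/quality trade-off: `convert_from_path` defaults to 200 DPI and accepts page-range arguments, so a variant like the following (parameters illustrative) can avoid rasterizing an entire large PDF up front:

```python
from pdf2image import convert_from_path

# Render only the first 10 pages at a lower DPI to save time and memory
images = convert_from_path("sample.pdf", dpi=144, first_page=1, last_page=10)
```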
rag.py
CHANGED
@@ -1,104 +1,188 @@
# Import required libraries
import requests  # For making HTTP requests
import os  # For accessing environment variables
import google.generativeai as genai  # For interacting with Google's Generative AI APIs
from typing import List  # For type annotations
from utils import encode_image  # Utility function to encode images as base64
from PIL import Image  # For image processing

class Rag:
    """
    A class for interacting with Generative AI models (Gemini and OpenAI) to retrieve answers
    based on user queries and associated images.
    """

    def get_answer_from_openai(self, query: str, imagesPaths: List[str]) -> str:
        """
        Query OpenAI's GPT model with a text query and associated images.

        Args:
            query (str): The user's query.
            imagesPaths (List[str]): List of file paths to images.

        Returns:
            str: The response text from OpenAI.
        """
        print(f"Querying OpenAI for query={query}, imagesPaths={imagesPaths}")

        try:
            # Prepare the API payload with the query and images
            payload = self.__get_openai_api_payload(query, imagesPaths)

            # Define the HTTP headers for the OpenAI API request
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"  # API key from environment variables
            }

            # Send a POST request to the OpenAI API
            response = requests.post(
                url="https://api.openai.com/v1/chat/completions",
                headers=headers,
                json=payload
            )
            response.raise_for_status()  # Raise an error for unsuccessful requests

            # Extract the content of the response
            answer = response.json()["choices"][0]["message"]["content"]

            print(answer)  # Log the answer

            return answer

        except Exception as e:
            # Handle and log any errors that occur
            print(f"An error occurred while querying OpenAI: {e}")
            return None

    def get_answer_from_gemini(self, query: str, imagePaths: List[str]) -> str:
        """
        Query the Gemini model with a text query and associated images.

        Args:
            query (str): The user's query.
            imagePaths (List[str]): List of file paths to images.

        Returns:
            str: The response text from the Gemini model.
        """
        print(f"Querying Gemini for query={query}, imagePaths={imagePaths}")

        try:
            # Configure the Gemini API client using the API key from environment variables
            genai.configure(api_key=os.environ['GEMINI_API_KEY'])

            # Initialize the Gemini generative model
            model = genai.GenerativeModel('gemini-1.5-flash')

            # Load images from the given paths (skip missing files)
            images = []
            for path in imagePaths:
                if os.path.exists(path):
                    images.append(Image.open(path))
                else:
                    print(f"Warning: Image not found {path}, skipping.")

            # Start a new chat session
            chat = model.start_chat()

            # Construct the input for the model (handle cases with and without images)
            input_data = [query] if not images else [*images, query]

            # Send the query (and images, if any) to the model
            response = chat.send_message(input_data)

            # Extract the response text
            answer = response.text

            print(answer)  # Log the answer

            return answer

        except Exception as e:
            # Handle and log any errors that occur
            print(f"An error occurred while querying Gemini: {e}")
            return f"Error: {str(e)}"

    def __get_openai_api_payload(self, query: str, imagesPaths: List[str]) -> dict:
        """
        Prepare the payload for the OpenAI API request.

        Args:
            query (str): The user's query.
            imagesPaths (List[str]): List of file paths to images.

        Returns:
            dict: The payload for the OpenAI API request.
        """
        image_payload = []  # List to store encoded image data

        # Encode each image as base64 and prepare the payload
        for imagePath in imagesPaths:
            base64_image = encode_image(imagePath)  # Encode the image in base64
            image_payload.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"  # Embed the image data as a data URL
                }
            })

        # Create the complete payload for the API request
        payload = {
            "model": "gpt-4o",  # Specify the OpenAI model
            "messages": [
                {
                    "role": "user",  # Role of the message sender
                    "content": [
                        {
                            "type": "text",
                            "text": query  # Include the user's query
                        },
                        *image_payload  # Include the image data
                    ]
                }
            ],
            "max_tokens": 1024  # Limit the response length
        }

        return payload
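A quick smoke test of the Gemini path, assuming `GEMINI_API_KEY` is set and a rendered page image exists on disk; the query and filename mirror the commented-out example from an earlier revision of this file.

```python
import os

from rag import Rag

os.environ.setdefault("GEMINI_API_KEY", "<your-key>")  # normally set via .env or Space secrets

rag = Rag()
answer = rag.get_answer_from_gemini(
    "Based on the attached images, how many new cases were reported during the second wave peak?",
    ["covid_slides_page_8.png"],
)
print(answer)
```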
utils.py
CHANGED
@@ -1,5 +1,16 @@
import base64  # Library for encoding and decoding data in base64 format

def encode_image(image_path: str) -> str:
    """
    Encode an image file to a base64 string.

    Args:
        image_path (str): The file path of the image to be encoded.

    Returns:
        str: The base64-encoded string representation of the image.
    """
    # Open the image file in binary read mode
    with open(image_path, "rb") as image_file:
        # Read the content, encode it to base64, and decode it to a UTF-8 string
        return base64.b64encode(image_file.read()).decode('utf-8')
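One detail worth noting when reusing `encode_image`: the pages are saved as PNG, while the OpenAI payload above declares the data URL as `image/jpeg`. If you want the MIME type to match the file, a small variant (a hypothetical helper, not part of this commit) could look like this:

```python
import base64
import mimetypes

def encode_image_as_data_url(image_path: str) -> str:
    """Return a data URL whose MIME type matches the file extension (hypothetical helper)."""
    mime, _ = mimetypes.guess_type(image_path)
    with open(image_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime or 'application/octet-stream'};base64,{b64}"
```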