multimodal_rag / middleware.py
ej68okap
new code added
241c492
raw
history blame
3.87 kB
# Import necessary modules and classes
from colpali_manager import ColpaliManager # Manages processing of images and text with the ColPali model
from milvus_manager import MilvusManager # Manages interactions with the Milvus database
from pdf_manager import PdfManager # Handles PDF processing tasks
import hashlib # Library for creating hashed identifiers
# Shared manager instances used by every Middleware object.
# They are built once, when this module is imported, PDF manager first.
pdf_manager: PdfManager = PdfManager()  # turns PDF pages into saved image files
colpali_manager: ColpaliManager = ColpaliManager()  # embeds page images and query text with ColPali
class Middleware:
    """
    Middleware class that integrates PDF processing, image embedding, and database indexing/searching.

    PDF-to-image conversion and ColPali embedding go through the module-level
    ``pdf_manager`` and ``colpali_manager`` instances; each Middleware owns its
    own ``MilvusManager`` for storage and retrieval.
    """

    def __init__(self, id: str, create_collection: bool = True):
        """
        Initialize the Middleware with a unique identifier and Milvus database setup.

        Args:
            id (str): Unique identifier for the user/session. The first 8 hex
                characters of its MD5 digest name the Milvus database file
                (``milvus_<hash>.db``).
            create_collection (bool): Whether to create a new collection in the
                Milvus database; forwarded to ``MilvusManager``.
        """
        # MD5 is used only to derive a short, stable file name, not for
        # security. ``usedforsecurity=False`` (Python 3.9+) keeps this working
        # on FIPS-mode builds, where a plain ``hashlib.md5()`` is rejected.
        # The digest value is unchanged.
        hashed_id = hashlib.md5(id.encode(), usedforsecurity=False).hexdigest()[:8]
        milvus_db_name = f"milvus_{hashed_id}.db"
        # NOTE(review): "colpali" is presumably the collection name expected by
        # MilvusManager; confirm against its signature.
        self.milvus_manager = MilvusManager(milvus_db_name, "colpali", create_collection)

    def index(self, pdf_path: str, id: str, max_pages: int, pages: "list[int] | None" = None):
        """
        Index the content of a PDF file into the Milvus database.

        PDF pages are saved as images. Each image is embedded with ColPali and
        inserted into Milvus together with its file path.

        Args:
            pdf_path (str): Path to the PDF file.
            id (str): Unique identifier for the session; forwarded to
                ``pdf_manager.save_images``.
            max_pages (int): Maximum number of pages to extract and index.
            pages (list[int] | None, optional): Intended to select specific
                pages. NOTE(review): currently ignored, because it is not
                forwarded to ``pdf_manager.save_images``. Wire it through once
                that API supports page selection.

        Returns:
            list[str]: List of paths to the saved image files.

        Raises:
            ValueError: If ColPali returns a different number of embeddings
                than the number of saved images.
        """
        print(f"Indexing {pdf_path}, id: {id}, max_pages: {max_pages}")

        # Convert PDF pages into image files and save them.
        image_paths = pdf_manager.save_images(id, pdf_path, max_pages)
        print(f"Saved {len(image_paths)} images")

        # Generate one embedding per image with the ColPali model.
        colbert_vecs = colpali_manager.process_images(image_paths)

        # Embeddings are paired with image paths purely by position. A length
        # mismatch means some page would be stored with the wrong vector or
        # silently dropped, so fail loudly before writing anything.
        if len(colbert_vecs) != len(image_paths):
            raise ValueError(
                f"ColPali returned {len(colbert_vecs)} embeddings "
                f"for {len(image_paths)} images"
            )

        images_data = [
            {
                "colbert_vecs": vecs,  # Image embeddings
                "filepath": path,  # Corresponding image file path
            }
            for vecs, path in zip(colbert_vecs, image_paths)
        ]

        print(f"Inserting {len(images_data)} images data to Milvus")
        self.milvus_manager.insert_images_data(images_data)
        print("Indexing completed")

        return image_paths

    def search(self, search_queries: list[str], *, topk: int = 1):
        """
        Search for matching results in the indexed database based on text queries.

        Each query is embedded separately with ColPali and searched in Milvus.

        Args:
            search_queries (list[str]): List of search queries.
            topk (int): Number of results to request per query. Keyword-only;
                defaults to 1, the value that was previously hard-coded.

        Returns:
            list: One entry per query, in input order, each being the value
            returned by ``MilvusManager.search`` for that query. An empty
            query list yields an empty result list.
        """
        print(f"Searching for {len(search_queries)} queries")
        final_res = []
        for query in search_queries:
            print(f"Searching for query: {query}")
            # process_text takes a batch; embed this single query and unwrap it.
            query_vec = colpali_manager.process_text([query])[0]
            search_res = self.milvus_manager.search(query_vec, topk=topk)
            print(f"Search result: {search_res} for query: {query}")
            final_res.append(search_res)
        return final_res