import os
import tempfile

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from FlagEmbedding.visual.modeling import Visualized_BGE
from pdf2image import convert_from_path
from PIL import Image


# Initialize the Visualized-BGE model
def load_bge_model(model_name: str, model_weight_path: str):
    model = Visualized_BGE(model_name_bge=model_name, model_weight=model_weight_path)
    model.eval()
    return model


# Load the BGE model (ensure you have downloaded the weights and provide the correct path)
model_name = "BAAI/bge-base-en-v1.5"  # or "BAAI/bge-m3" for multilingual
model_weight_path = "./Visualized_base_en_v1.5.pth"
model = load_bge_model(model_name, model_weight_path)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)


# Function to encode images
def encode_image(image_input):
    """
    Encodes an image for retrieval.

    Args:
        image_input: Can be a file path (str), a NumPy array, or a PIL Image.

    Returns:
        torch.Tensor: The image embedding.
    """
    delete_temp_file = False
    if isinstance(image_input, str):
        image_path = image_input
    else:
        # NumPy arrays and PIL images are written to a temporary PNG, since
        # the model's encode() expects a file path.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
            if isinstance(image_input, np.ndarray):
                image = Image.fromarray(image_input)
            elif isinstance(image_input, Image.Image):
                image = image_input
            else:
                raise ValueError("Unsupported image input type for image encoding.")
            image.save(tmp_file.name)
            image_path = tmp_file.name
            delete_temp_file = True  # Mark that we need to delete this temp file
    try:
        with torch.no_grad():
            embed = model.encode(image=image_path)
            embed = embed.squeeze(0)
    finally:
        if delete_temp_file:
            # Remove the temporary file
            os.remove(image_path)
    return embed.cpu()


# Function to encode text
def encode_text(text):
    with torch.no_grad():
        embed = model.encode(text=text)  # Assuming encode returns [1, D]
        embed = embed.squeeze(0)  # Remove the batch dimension if present
    return embed.cpu()


# Function to index uploaded files (PDFs or images)
def index_files(files, embeddings_state, metadata_state):
    print("Indexing files...")
    if not files:
        raise gr.Error("Please upload at least one file to index.")
    embeddings = []
    metadata = []
    for file in files:
        if file.name.lower().endswith(".pdf"):
            # Render each PDF page to an image and embed it
            images = convert_from_path(file.name, thread_count=4)
            for idx, img in enumerate(images):
                img_path = f"{file.name}_page_{idx}.png"
                img.save(img_path)
                embed = encode_image(img_path)
                print(f"Embedding shape after encoding image: {embed.shape}")  # Should be [768]
                embeddings.append(embed)
                metadata.append({"type": "image", "path": img_path, "info": f"Page {idx}"})
        elif file.name.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".gif")):
            img_path = file.name
            embed = encode_image(img_path)
            print(f"Embedding shape after encoding image: {embed.shape}")  # Should be [768]
            embeddings.append(embed)
            metadata.append({"type": "image", "path": img_path, "info": "Uploaded Image"})
        else:
            raise gr.Error("Unsupported file type. Please upload PDFs or image files.")
    embeddings = torch.stack(embeddings).to(device)  # Should result in shape [N, 768]
    print(f"Stacked embeddings shape: {embeddings.shape}")
    embeddings_state = embeddings
    metadata_state = metadata
    return f"Indexed {len(embeddings)} items.", embeddings_state, metadata_state


def search(query_text, query_image, k, embeddings_state, metadata_state):
    embeddings = embeddings_state
    metadata = metadata_state
    if embeddings is None or embeddings.size(0) == 0:
        raise gr.Error("No embeddings indexed. Please upload and index files first.")
    query_emb = None
    if query_text and query_image is not None:
        gr.Warning("Please provide either a text query or an image query, not both. Using text query by default.")
        # text_emb = encode_text(query_text)  # [D]
        # image_emb = encode_image(query_image)  # [D]
        # query_emb = (text_emb + image_emb) / 2  # [D]
        # print("Combined text and image embeddings for query.")
    if query_text:
        query_emb = encode_text(query_text)  # [D]
        print("Encoded text query.")
    elif query_image is not None:
        query_emb = encode_image(query_image)  # [D]
        print("Encoded image query.")
    else:
        raise gr.Error("Please provide at least a text query or an image query.")
    # Ensure query_emb has shape [1, D]
    if query_emb.dim() == 1:
        query_emb = query_emb.unsqueeze(0)  # [1, D]
    # Normalize embeddings for cosine similarity
    query_emb = F.normalize(query_emb.to(device), p=2, dim=1)  # [1, D]
    indexed_emb = F.normalize(embeddings.to(device), p=2, dim=1)  # [N, D]
    print(f"Query embedding shape: {query_emb.shape}")  # Should be [1, 768]
    print(f"Indexed embeddings shape: {indexed_emb.shape}")  # Should be [N, 768]
    # Compute cosine similarities
    similarities = torch.matmul(query_emb, indexed_emb.T).squeeze(0)  # [N]
    print(f"Similarities shape: {similarities.shape}")
    # Get top-k results (k cannot exceed the number of indexed items)
    k = min(int(k), similarities.size(0))
    topk = torch.topk(similarities, k)
    topk_indices = topk.indices.cpu().numpy()
    topk_scores = topk.values.cpu().numpy()
    print(f"Top-{k} indices: {topk_indices}")
    print(f"Top-{k} scores: {topk_scores}")
    results = []
    for idx, score in zip(topk_indices, topk_scores):
        item = metadata[idx]
        if item["type"] == "image":
            # Load image from path
            img = Image.open(item["path"]).convert("RGB")
            results.append((img, f"Score: {score:.4f} | {item['info']}"))
        else:
            # Handle text data if applicable
            results.append((item["data"], f"Score: {score:.4f} | {item['info']}"))
    return results


# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Visualized-BGE: Multimodal Retrieval Demo 🎉")
    gr.Markdown("""
    Upload PDF or image files to index them. Then, perform searches using text, images, or both to retrieve the most relevant items.

    **Note:** Ensure that you have indexed the files before performing a search.
    """)

    # Initialize state variables
    embeddings_state = gr.State(None)
    metadata_state = gr.State(None)

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload and Index Files")
            file_input = gr.File(
                file_types=[".pdf", ".png", ".jpg", ".jpeg", ".bmp", ".gif"],
                file_count="multiple",
                label="Upload Files"
            )
            index_button = gr.Button("🔄 Index Files")
            index_status = gr.Textbox("No files indexed yet.", label="Indexing Status")
        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Perform Search")
            with gr.Row():
                query_text = gr.Textbox(placeholder="Enter your text query here...", label="Text Query")
                query_image = gr.Image(label="Image Query (Optional)")
            k = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Results", value=5)
            search_button = gr.Button("🔍 Search")
            output_gallery = gr.Gallery(label="Retrieved Results", show_label=True, columns=2)

    # Define button actions
    index_button.click(
        index_files,
        inputs=[file_input, embeddings_state, metadata_state],
        outputs=[index_status, embeddings_state, metadata_state]
    )
    search_button.click(
        search,
        inputs=[query_text, query_image, k, embeddings_state, metadata_state],
        outputs=output_gallery
    )

    gr.Markdown("""
    ---
    ## About
    This demo uses the **Visualized-BGE** model for efficient multimodal retrieval tasks.
    Upload your documents or images, index them, and perform searches using text, images, or a combination of both.

    **References:**
    - [Visualized-BGE Paper](https://arxiv.org/abs/2406.04292)
    - [FlagEmbedding GitHub](https://github.com/FlagOpen/FlagEmbedding)
    """)


if __name__ == "__main__":
    demo.launch(debug=True, share=True)
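
# Note on the checkpoint referenced by `model_weight_path`: the Visualized-BGE weights are
# distributed through the FlagEmbedding project (see the GitHub link in the About section),
# with the base English checkpoint published on the Hugging Face Hub. A minimal, optional
# sketch for fetching it with huggingface_hub is shown below; the repo id and filename are
# assumptions based on that documentation, not something this script verifies:
#
#     from huggingface_hub import hf_hub_download
#     path = hf_hub_download(repo_id="BAAI/bge-visualized",
#                            filename="Visualized_base_en_v1.5.pth",
#                            local_dir=".")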