import base64
import glob
import json
import os

import boto3
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from colpali_engine.interpretability import (
    get_similarity_maps_from_embeddings,
    plot_all_similarity_maps,
)
from opensearchpy import OpenSearch, RequestsHttpConnection
from PIL import Image
from requests.auth import HTTPBasicAuth

model_name = "vidore/colpali-v1.3"

# Local-inference alternative (requires a GPU or Apple Silicon):
# from colpali_engine.models import ColPali, ColPaliProcessor
# colpali_model = ColPali.from_pretrained(
#     model_name,
#     torch_dtype=torch.bfloat16,
#     device_map="cuda:0",  # "cuda:0" for GPU, "cpu" for CPU, or "mps" for Apple Silicon
# ).eval()
# colpali_processor = ColPaliProcessor.from_pretrained(model_name)

awsauth = HTTPBasicAuth('master', st.secrets['ml_search_demo_api_access'])

aos_client = OpenSearch(
    hosts=[{'host': 'search-opensearchservi-shjckef2t7wo-iyv6rajdgxg6jas25aupuxev6i.us-west-2.es.amazonaws.com', 'port': 443}],
    http_auth=awsauth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
    pool_maxsize=20,
)

region_endpoint = "us-east-1"

# SageMaker endpoint hosting the ColPali model
endpoint_name = "colpali-endpoint"

# SageMaker runtime client used to invoke the ColPali endpoint
runtime = boto3.client(
    "sagemaker-runtime",
    aws_access_key_id=st.secrets['user_access_key'],
    aws_secret_access_key=st.secrets['user_secret_key'],
    region_name=region_endpoint,
)

# Session-state defaults for the top-ranked image and its query token data
if 'top_img' not in st.session_state:
    st.session_state['top_img'] = ""
if 'query_token_vectors' not in st.session_state:
    st.session_state['query_token_vectors'] = ""
if 'query_tokens' not in st.session_state:
    st.session_state['query_tokens'] = ""


def call_nova(
    model,
    messages,
    system_message="",
    streaming=False,
    max_tokens=512,
    temp=0.0001,
    top_p=0.99,
    top_k=20,
    tools=None,
    verbose=False,
):
    client = boto3.client(
        'bedrock-runtime',
        aws_access_key_id=st.secrets['user_access_key'],
        aws_secret_access_key=st.secrets['user_secret_key'],
        region_name=region_endpoint,
    )
    system_list = [{"text": system_message}]
    inf_params = {
        "max_new_tokens": max_tokens,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temp,
    }
    request_body = {
        "messages": messages,
        "system": system_list,
        "inferenceConfig": inf_params,
    }
    if tools is not None:
        tool_config = [{"toolSpec": tool} for tool in tools]
        request_body["toolConfig"] = {"tools": tool_config}
    if verbose:
        print("Request Body", request_body)
    if not streaming:
        response = client.invoke_model(modelId=model, body=json.dumps(request_body))
        model_response = json.loads(response["body"].read())
        return model_response, model_response["output"]["message"]["content"][0]["text"]
    response = client.invoke_model_with_response_stream(
        modelId=model, body=json.dumps(request_body)
    )
    return response["body"]


def get_base64_encoded_value(media_path):
    with open(media_path, "rb") as media_file:
        return base64.b64encode(media_file.read()).decode("utf-8")
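# Minimal usage sketch for call_nova (illustrative only; the prompt text here is
# hypothetical, while the model ID matches the one used in generate_ans below):
#
#   messages = [{"role": "user", "content": [{"text": "Summarize this page."}]}]
#   _, answer = call_nova("amazon.nova-pro-v1:0", messages,
#                         system_message="Answer briefly.", max_tokens=100)
#   print(answer)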
def generate_ans(top_result, query):
    print(query)
    system_message = (
        "Given an image of a PDF page, answer the question accurately. "
        "If you don't find the answer in the page, say: I don't know."
    )
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "image": {
                        "format": "jpeg",
                        "source": {"bytes": get_base64_encoded_value(top_result)},
                    }
                },
                {"text": query},  # e.g. "what is the proportion of female new hires 2021-2023?"
            ],
        }
    ]
    model_response, content_text = call_nova(
        "amazon.nova-pro-v1:0", messages, system_message=system_message, max_tokens=300
    )
    print(content_text)
    return content_text


def img_highlight(img, batch_queries, query_tokens):
    img_name = os.path.basename(img)  # e.g., "my_image.png"

    # Reuse previously generated similarity maps for this image if they exist
    search_pattern = f"/home/user/app/similarity_maps/similarity_map_{img_name}_token_*"
    matching_files = glob.glob(search_pattern)
    map_images = []
    if matching_files:
        print("✅ Matching similarity map exists:")
        for file_path in matching_files:
            print(f" - {file_path}")
            map_images.append({'file': file_path})
        return map_images

    # Reference: https://github.com/tonywu71/colpali-cookbooks/blob/main/examples/gen_colpali_similarity_maps.ipynb
    with open(img, "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")

    # Construct a payload with only the image and send it to the ColPali endpoint
    payload = {"images": [img_b64]}
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    img_colpali_res = json.loads(response["Body"].read().decode())

    # Convert outputs to tensors
    image_embeddings = torch.tensor(img_colpali_res["image_embeddings"][0])  # [T, D]
    query_embeddings = torch.tensor(batch_queries)  # [T, D]

    # Access the full 1D mask vector, not a single value
    image_mask_list = img_colpali_res["image_mask"]
    if isinstance(image_mask_list[0], list):
        # Expected case: list of lists
        image_mask = torch.tensor(image_mask_list[0]).bool()
    else:
        # Edge case: already flattened
        image_mask = torch.tensor(image_mask_list).bool()
    print("Valid patch count:", image_mask.sum().item())

    # Add a batch dimension where needed so all tensors are batched
    if query_embeddings.dim() == 2:
        query_embeddings = query_embeddings.unsqueeze(0)  # [1, T, D]
    if image_embeddings.dim() == 2:
        image_embeddings = image_embeddings.unsqueeze(0)  # [1, T, D]
    if image_mask.dim() == 1:
        image_mask = image_mask.unsqueeze(0)  # [1, T]

    print("query_embeddings shape:", query_embeddings.shape)
    print("image_embeddings shape:", image_embeddings.shape)
    print("image_mask shape:", image_mask.shape)

    # Get the number of image patches
    image = Image.open(img)
    n_patches = (
        img_colpali_res["patch_shape"]['height'],
        img_colpali_res["patch_shape"]['width'],
    )
    print(f"Number of image patches: {n_patches}")

    # Generate the similarity maps
    batched_similarity_maps = get_similarity_maps_from_embeddings(
        image_embeddings=image_embeddings,
        query_embeddings=query_embeddings,
        n_patches=n_patches,
        image_mask=image_mask,
    )

    # Get the similarity map for our (only) input image
    similarity_maps = batched_similarity_maps[0]  # (query_length, n_patches_x, n_patches_y)

    query_tokens_from_model = query_tokens[0]['tokens']
    plots = plot_all_similarity_maps(
        image=image,
        query_tokens=query_tokens_from_model,
        similarity_maps=similarity_maps,
        figsize=(8, 8),
        show_colorbar=False,
        add_title=True,
    )

    map_images = []
    for idx, (fig, ax) in enumerate(plots):
        if idx < 3:
            # Skip the first three query tokens (special/prefix tokens)
            continue
        savepath = (
            "/home/user/app/similarity_maps/"
            f"similarity_map_{img_name}_token_{idx}_{query_tokens_from_model[idx]}.png"
        )
        fig.savefig(savepath, bbox_inches="tight")
        map_images.append({'file': savepath})
        print(f"Similarity map for token `{query_tokens_from_model[idx]}` saved at `{savepath}`")
    plt.close("all")
    return map_images
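# For reference, the per-page MaxSim score computed inside colpali_search_rerank
# below can be expressed in one vectorized step (a sketch, assuming NumPy arrays
# of shape [n_query_tokens, dim] and [n_patches, dim]; the helper name is illustrative):
#
#   def maxsim_score(query_vectors, page_vectors):
#       # dot product of every query token with every page patch, then take each
#       # token's best-matching patch and sum those maxima over the tokens
#       return float((query_vectors @ page_vectors.T).max(axis=1).sum())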
"/home/user/app/similarity_maps/similarity_map_"+(img.split("/"))[-1]+"_token_"+str(idx)+"_"+query_tokens_from_model[idx]+".png" fig.savefig(savepath, bbox_inches="tight") map_images.append({'file':savepath}) print(f"Similarity map for token `{query_tokens_from_model[idx]}` saved at `{savepath}`") plt.close("all") return map_images def colpali_search_rerank(query): if(st.session_state.show_columns == True): print("show columns activated------------------------") st.session_state.maxSimImages = img_highlight(st.session_state.top_img, st.session_state.query_token_vectors, st.session_state.query_tokens) st.session_state.show_columns = False return_val = {'text':st.session_state.answers_[0]['answer'],'source':st.session_state.answers_[0]['source'],'image':st.session_state.maxSimImages,'table':[]} st.session_state.input_query = st.session_state.questions_[-1]["question"] st.session_state.answers_.pop() st.session_state.questions_.pop() return return_val # Convert to JSON string payload = { "queries": [query] } body = json.dumps(payload) # Call the endpoint response = runtime.invoke_endpoint( EndpointName=endpoint_name, ContentType="application/json", Body=body ) # Read and print the response result = json.loads(response["Body"].read().decode()) #print(len(result['query_embeddings'][0])) final_docs_sorted_20 = [] for i in result['query_embeddings']: batch_embeddings = i a = np.array(batch_embeddings) vec = a.mean(axis=0) #print(vec) hits = [] #for v in batch_embeddings: query_ = { "size": 200, "query": { "nested": { "path": "page_sub_vectors", "query": { "knn": { "page_sub_vectors.page_sub_vector": { "vector": vec.tolist(), "k": 200 } } } } } } response = aos_client.search( body = query_, index = 'colpali-vs' ) #print(response) query_token_vectors = batch_embeddings final_docs = [] hits += response['hits']['hits'] #print(len(hits)) for ind,j in enumerate(hits): max_score_dict_list = [] doc={"id":j["_id"],"score":j["_score"],"image":j["_source"]["image"]} with_s = j['_source']['page_sub_vectors'] add_score = 0 for index,i in enumerate(query_token_vectors): query_token_vector = np.array(i) scores = [] for m in with_s: doc_token_vector = np.array(m['page_sub_vector']) score = np.dot(query_token_vector,doc_token_vector) scores.append(score) scores.sort(reverse=True) max_score = scores[0] add_score+=max_score doc["total_score"] = add_score final_docs.append(doc) final_docs_sorted = sorted(final_docs, key=lambda d: d['total_score'], reverse=True) final_docs_sorted_20.append(final_docs_sorted[:20]) img = "/home/user/app/vs/"+final_docs_sorted_20[0][0]['image'] ans = generate_ans(img,query) images_highlighted = [{'file':img}] st.session_state.top_img = img st.session_state.query_token_vectors = query_token_vectors st.session_state.query_tokens = result['query_tokens'] return {'text':ans,'source':img,'image':images_highlighted,'table':[]}#[{'file':img}]