import json
import random
import pprint
import os
from io import BytesIO
import glob
from pathlib import Path
from typing import Optional, cast

import numpy as np
# from datasets import load_dataset
import boto3
from opensearchpy import OpenSearch, RequestsHttpConnection, AWSV4SignerAuth
from requests.auth import HTTPBasicAuth
from requests_aws4auth import AWS4Auth
import matplotlib.pyplot as plt
import requests
import streamlit as st
import base64
from colpali_engine.interpretability import (
    get_similarity_maps_from_embeddings,
    plot_all_similarity_maps,
    plot_similarity_map,
)
import torch
# from colpali_engine.models import ColPali, ColPaliProcessor
# from colpali_engine.utils.torch_utils import get_torch_device
from PIL import Image
import utilities.invoke_models as invoke_models
model_name = "vidore/colpali-v1.3"

# In-process model loading, kept for reference; this app calls a hosted
# SageMaker endpoint instead (see below).
# colpali_model = ColPali.from_pretrained(
#     model_name,
#     torch_dtype=torch.bfloat16,
#     device_map="cuda:0",  # Use "cuda:0" for GPU, "cpu" for CPU, or "mps" for Apple Silicon
# ).eval()
# colpali_processor = ColPaliProcessor.from_pretrained(model_name)
awsauth = HTTPBasicAuth('master', st.secrets['ml_search_demo_api_access'])
headers = {"Content-Type": "application/json"}

aos_client = OpenSearch(
    hosts=[{'host': 'search-opensearchservi-shjckef2t7wo-iyv6rajdgxg6jas25aupuxev6i.us-west-2.es.amazonaws.com', 'port': 443}],
    http_auth=awsauth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
    pool_maxsize=20
)
region_endpoint = "us-east-1"
# Your SageMaker endpoint name
endpoint_name = "colpali-endpoint"

# Create a SageMaker runtime client
runtime = boto3.client(
    "sagemaker-runtime",
    aws_access_key_id=st.secrets['user_access_key'],
    aws_secret_access_key=st.secrets['user_secret_key'],
    region_name=region_endpoint,
)

# Session state: the top retrieved page image, the query's token-level
# embeddings, and the query tokens (used later to render similarity maps).
if 'top_img' not in st.session_state:
    st.session_state['top_img'] = ""
if 'query_token_vectors' not in st.session_state:
    st.session_state['query_token_vectors'] = ""
if 'query_tokens' not in st.session_state:
    st.session_state['query_tokens'] = ""
def call_nova(
    model,
    messages,
    system_message="",
    streaming=False,
    max_tokens=512,
    temp=0.0001,
    top_p=0.99,
    top_k=20,
    tools=None,
    verbose=False,
):
    client = boto3.client(
        'bedrock-runtime',
        aws_access_key_id=st.secrets['user_access_key'],
        aws_secret_access_key=st.secrets['user_secret_key'],
        region_name=region_endpoint,
    )
    system_list = [{"text": system_message}]
    inf_params = {
        "max_new_tokens": max_tokens,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temp,
    }
    request_body = {
        "messages": messages,
        "system": system_list,
        "inferenceConfig": inf_params,
    }
    if tools is not None:
        tool_config = []
        for tool in tools:
            tool_config.append({"toolSpec": tool})
        request_body["toolConfig"] = {"tools": tool_config}
    if verbose:
        print("Request Body", request_body)
    if not streaming:
        response = client.invoke_model(modelId=model, body=json.dumps(request_body))
        model_response = json.loads(response["body"].read())
        return model_response, model_response["output"]["message"]["content"][0]["text"]
    else:
        response = client.invoke_model_with_response_stream(
            modelId=model, body=json.dumps(request_body)
        )
        return response["body"]
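# Example (illustrative only, not executed): a text-only Nova call using the
# same message format as generate_ans below.
# _, answer = call_nova(
#     "amazon.nova-pro-v1:0",
#     [{"role": "user", "content": [{"text": "Summarize this page."}]}],
#     system_message="You are a helpful assistant.",
#     max_tokens=100,
# )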
def get_base64_encoded_value(media_path):
    with open(media_path, "rb") as media_file:
        binary_data = media_file.read()
        base_64_encoded_data = base64.b64encode(binary_data)
        base64_string = base_64_encoded_data.decode("utf-8")
        return base64_string
def generate_ans(top_result, query):
    print(query)
    system_message = "given an image of a PDF page, answer the question. Be accurate to the question. If you don't find the answer in the page, please say, I don't know"
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "image": {
                        "format": "jpeg",
                        "source": {
                            "bytes": get_base64_encoded_value(top_result)
                        },
                    }
                },
                {
                    "text": query  # e.g. "what is the proportion of female new hires 2021-2023?"
                },
            ],
        }
    ]
    model_response, content_text = call_nova(
        "amazon.nova-pro-v1:0", messages, system_message=system_message, max_tokens=300
    )
    print(content_text)
    return content_text
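# Example (illustrative only, not executed): answer a question over a retrieved
# page image. The file name below is hypothetical; the directory matches the
# one used by colpali_search_rerank further down.
# answer = generate_ans("/home/user/app/vs/page_1.png",
#                       "what is the proportion of female new hires 2021-2023?")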
def img_highlight(img, batch_queries, query_tokens):
    img_name = os.path.basename(img)  # e.g., "my_image.png"
    # Construct the search pattern for previously generated similarity maps
    search_pattern = f"/home/user/app/similarity_maps/similarity_map_{img_name}_token_*"
    # Search for matching files
    matching_files = glob.glob(search_pattern)
    # If maps already exist for this image, reuse them
    map_images = []
    if matching_files:
        print("✅ Matching similarity map exists:")
        for file_path in matching_files:
            print(f" - {file_path}")
            map_images.append({'file': file_path})
        return map_images

    # Reference: https://github.com/tonywu71/colpali-cookbooks/blob/main/examples/gen_colpali_similarity_maps.ipynb
    with open(img, "rb") as f:
        img_b64 = base64.b64encode(f.read()).decode("utf-8")

    # Construct payload with only the image
    payload = {
        "images": [img_b64]
    }

    # Send to the ColPali SageMaker endpoint
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload)
    )

    # Read response
    img_colpali_res = json.loads(response["Body"].read().decode())

    # Convert outputs to tensors
    image_embeddings = torch.tensor(img_colpali_res["image_embeddings"][0])  # [T, D]
    query_embeddings = torch.tensor(batch_queries)  # [T, D] (token vectors of the query)

    # Ensure we use the full 1D mask vector, not a single value
    image_mask_list = img_colpali_res["image_mask"]
    if isinstance(image_mask_list[0], list):
        # Expected case: list of lists, take the mask for the single image
        image_mask = torch.tensor(image_mask_list[0]).bool()
    else:
        # Edge case: already flattened
        image_mask = torch.tensor(image_mask_list).bool()
    print("Valid patch count:", image_mask.sum().item())

    # Ensure query_embeddings, image_embeddings, and image_mask are batched
    if query_embeddings.dim() == 2:
        query_embeddings = query_embeddings.unsqueeze(0)  # [1, T, D]
    if image_embeddings.dim() == 2:
        image_embeddings = image_embeddings.unsqueeze(0)  # [1, T, D]
    if image_mask.dim() == 1:
        image_mask = image_mask.unsqueeze(0)  # [1, T]
    print("query_embeddings shape:", query_embeddings.shape)
    print("image_embeddings shape:", image_embeddings.shape)
    print("image_mask shape:", image_mask.shape)

    # Get the number of image patches
    image = Image.open(img)
    n_patches = (img_colpali_res["patch_shape"]['height'], img_colpali_res["patch_shape"]['width'])
    print(f"Number of image patches: {n_patches}")

    # Generate the similarity maps
    batched_similarity_maps = get_similarity_maps_from_embeddings(
        image_embeddings=image_embeddings,
        query_embeddings=query_embeddings,
        n_patches=n_patches,
        image_mask=image_mask,
    )

    # Get the similarity map for our (only) input image
    similarity_maps = batched_similarity_maps[0]  # (query_length, n_patches_x, n_patches_y)
    query_tokens_from_model = query_tokens[0]['tokens']
    plots = plot_all_similarity_maps(
        image=image,
        query_tokens=query_tokens_from_model,
        similarity_maps=similarity_maps,
        figsize=(8, 8),
        show_colorbar=False,
        add_title=True,
    )
    map_images = []
    for idx, (fig, ax) in enumerate(plots):
        # Skip the first few special/prefix tokens of the query
        if idx < 3:
            continue
        savepath = (
            "/home/user/app/similarity_maps/similarity_map_"
            + (img.split("/"))[-1]
            + "_token_" + str(idx) + "_" + query_tokens_from_model[idx] + ".png"
        )
        fig.savefig(savepath, bbox_inches="tight")
        map_images.append({'file': savepath})
        print(f"Similarity map for token `{query_tokens_from_model[idx]}` saved at `{savepath}`")
    plt.close("all")
    return map_images
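# Example (illustrative only, not executed): regenerate or fetch the per-token
# similarity maps for a page. The file name is hypothetical; the query token
# embeddings and tokens come from a previous search stored in session state.
# maps = img_highlight("/home/user/app/vs/page_1.png",
#                      st.session_state['query_token_vectors'],
#                      st.session_state['query_tokens'])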
def colpali_search_rerank(query):
    # If the user asked to see the similarity maps for the previous answer,
    # render those instead of running a new search.
    if st.session_state.show_columns == True:
        print("show columns activated------------------------")
        st.session_state.maxSimImages = img_highlight(
            st.session_state.top_img,
            st.session_state.query_token_vectors,
            st.session_state.query_tokens,
        )
        st.session_state.show_columns = False
        return_val = {
            'text': st.session_state.answers_[0]['answer'],
            'source': st.session_state.answers_[0]['source'],
            'image': st.session_state.maxSimImages,
            'table': [],
        }
        st.session_state.input_query = st.session_state.questions_[-1]["question"]
        st.session_state.answers_.pop()
        st.session_state.questions_.pop()
        return return_val

    # Embed the query with the ColPali endpoint
    payload = {
        "queries": [query]
    }
    body = json.dumps(payload)

    # Call the endpoint
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=body
    )

    # Read the response
    result = json.loads(response["Body"].read().decode())
    # print(len(result['query_embeddings'][0]))

    final_docs_sorted_20 = []
    for i in result['query_embeddings']:
        batch_embeddings = i
        # Stage 1: retrieve candidate pages with a single pooled query vector
        a = np.array(batch_embeddings)
        vec = a.mean(axis=0)
        # print(vec)
        hits = []
        query_ = {
            "size": 200,
            "query": {
                "nested": {
                    "path": "page_sub_vectors",
                    "query": {
                        "knn": {
                            "page_sub_vectors.page_sub_vector": {
                                "vector": vec.tolist(),
                                "k": 200
                            }
                        }
                    }
                }
            }
        }
        response = aos_client.search(
            body=query_,
            index='colpali-vs'
        )
        # print(response)

        # Stage 2: re-rank the candidates with MaxSim (late interaction):
        # for each query token, keep its best dot product over the page's
        # patch vectors, then sum those maxima into the page score.
        query_token_vectors = batch_embeddings
        final_docs = []
        hits += response['hits']['hits']
        # print(len(hits))
        for ind, j in enumerate(hits):
            doc = {"id": j["_id"], "score": j["_score"], "image": j["_source"]["image"]}
            with_s = j['_source']['page_sub_vectors']
            add_score = 0
            for index, i in enumerate(query_token_vectors):
                query_token_vector = np.array(i)
                scores = []
                for m in with_s:
                    doc_token_vector = np.array(m['page_sub_vector'])
                    score = np.dot(query_token_vector, doc_token_vector)
                    scores.append(score)
                scores.sort(reverse=True)
                max_score = scores[0]
                add_score += max_score
            doc["total_score"] = add_score
            final_docs.append(doc)
        final_docs_sorted = sorted(final_docs, key=lambda d: d['total_score'], reverse=True)
        final_docs_sorted_20.append(final_docs_sorted[:20])

    # Answer the question with the top-ranked page image
    img = "/home/user/app/vs/" + final_docs_sorted_20[0][0]['image']
    ans = generate_ans(img, query)
    images_highlighted = [{'file': img}]
    st.session_state.top_img = img
    st.session_state.query_token_vectors = query_token_vectors
    st.session_state.query_tokens = result['query_tokens']
    return {'text': ans, 'source': img, 'image': images_highlighted, 'table': []}
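# Illustrative sketch (not used above): the per-page re-ranking inside
# colpali_search_rerank is a MaxSim / late-interaction score, restated here as
# a standalone helper for clarity.
def maxsim_score(query_token_vectors, page_sub_vectors):
    """Sum over query tokens of the best dot product against any page patch vector."""
    q = np.array(query_token_vectors)  # [num_query_tokens, dim]
    p = np.array(page_sub_vectors)     # [num_page_patches, dim]
    return float((q @ p.T).max(axis=1).sum())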