# Spaces: Running on T4
# Streamlit app: Chat with PDFs using OpenSearch, RAG, and ColPali | |
import streamlit as st | |
import uuid | |
import os | |
import sys | |
import warnings | |
import boto3 | |
import json | |
import random | |
import string | |
import pandas as pd | |
from PIL import Image | |
from requests.auth import HTTPBasicAuth | |
# Suppress DeprecationWarning noise (e.g. from Streamlit) for the whole process.
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Make the sibling project folders importable.
# ``base_path`` is the directory two levels above this file's resolved
# location; os.path is used instead of splitting on "/" so the result is
# correct regardless of the platform's path separator.
base_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# Each folder is inserted at index 1, so the last one listed ends up first
# among the three on sys.path (same order as inserting them one by one).
for _module_dir in ("semantic_search", "RAG", "utilities"):
    sys.path.insert(1, os.path.join(base_path, _module_dir))
# Local modules | |
import rag_DocumentLoader | |
import rag_DocumentSearcher | |
import colpali | |
# --- AWS clients and search credentials ---------------------------------
region = "us-east-1"
s3_bucket_ = "pdf-repo-uploads"

# Bedrock runtime client: no explicit keys are passed here.
bedrock_runtime_client = boto3.client("bedrock-runtime", region_name=region)

# Polly client (used for speech synthesis of answers) is built with the
# key pair stored in Streamlit secrets.
polly_client = boto3.client(
    "polly",
    region_name=region,
    aws_access_key_id=st.secrets["user_access_key"],
    aws_secret_access_key=st.secrets["user_secret_key"],
)

credentials = boto3.Session().get_credentials()

# HTTP basic-auth object handed to the document searcher.
awsauth = HTTPBasicAuth("master", st.secrets["ml_search_demo_api_access"])
# --- Page configuration and static assets -------------------------------
st.set_page_config(
    page_icon="images/opensearch_mark_default.png",
    layout="wide",
)

# Parent of the directory that contains this file.
parent_dirname = "/".join(os.path.dirname(__file__).split("/")[:-1])

# Icons used when rendering the chat transcript.
USER_ICON = "images/user.png"
AI_ICON = "images/opensearch-twitter-card.png"
REGENERATE_ICON = "images/regenerate.png"
# --- Session state defaults ----------------------------------------------
# A user id is generated once per session; everything else is seeded from
# the table below only when the key is not present yet.
if 'user_id' not in st.session_state:
    st.session_state.user_id = str(uuid.uuid4())

_SESSION_DEFAULTS = {
    'session_id': "",
    'chats': [{'id': 0, 'question': '', 'answer': ''}],
    'questions_': [],
    'answers_': [],
    'show_columns': False,
    'input_index': "hpijan2024hometrack",
    'input_is_rerank': True,
    'input_is_colpali': False,
    'input_copali_rerank': False,
    'input_table_with_sql': False,
    'input_query': "which city has the highest average housing price in UK ?",
    'input_rag_searchType': ["Vector Search"],
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    st.session_state.setdefault(_state_key, _state_default)
# --- Custom styling --------------------------------------------------------
# Sets a zero gap on the vertical blocks of the first and second columns.
_COLUMN_GAP_CSS = """
<style>
[data-testid=column]:nth-of-type(1) [data-testid=stVerticalBlock],
[data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock] {
gap: 0rem;
}
</style>
"""
st.markdown(_COLUMN_GAP_CSS, unsafe_allow_html=True)
def write_top_bar():
    """Draw the page header and the "Clear" button.

    Returns:
        The value returned by the "Clear" ``st.button`` call for this run.
    """
    title_col, button_col = st.columns([77, 23])
    title_col.header("Chat with your data", divider='rainbow')
    clear_clicked = button_col.button("Clear")
    # NOTE(review): placed below both columns; the original nesting of this
    # spacer was not recoverable from the source -- confirm.
    st.write("")  # spacing
    return clear_clicked
# Draw the top bar; when "Clear" was pressed, drop the chat history and
# empty the query box.
clear_requested = write_top_bar()
if clear_requested:
    st.session_state["questions_"] = []
    st.session_state["answers_"] = []
    st.session_state["input_query"] = ""
def handle_input():
    """Run the current query and append the question/answer pair to the history.

    Does nothing when ``input_query`` is empty. Otherwise it:

    * collects every ``input_*`` session key (prefix stripped) into a dict
      and stores it as ``st.session_state.inputs_``;
    * appends the question to ``questions_``;
    * runs either the ColPali search (when ``input_is_colpali`` is set) or
      ``rag_DocumentSearcher.query_``;
    * appends the answer to ``answers_`` and clears the query box.

    Raises:
        Whatever the search call raises, or ``KeyError`` when the search
        result lacks one of ``text``/``source``/``image``/``table``. In that
        case the question appended above is removed again so ``questions_``
        and ``answers_`` stay the same length (they are zipped when rendered).
    """
    state = st.session_state
    if state.input_query == '':
        return

    inputs = {
        key.removeprefix('input_'): state[key]
        for key in state
        if key.startswith('input_')
    }
    state.inputs_ = inputs

    # The id is evaluated before the append, i.e. it is the 0-based position.
    state.questions_.append({
        'question': inputs["query"],
        'id': len(state.questions_)
    })

    try:
        if state.input_is_colpali:
            out_ = colpali.colpali_search_rerank(state.input_query)
        else:
            out_ = rag_DocumentSearcher.query_(
                awsauth,
                inputs,
                state['session_id'],
                state.input_rag_searchType
            )
        answer = {
            'answer': out_['text'],
            'source': out_['source'],
            'id': len(state.questions_),
            'image': out_['image'],
            'table': out_['table']
        }
    except Exception:
        # Roll back the question so the two history lists do not drift out
        # of step, then let the error surface as before.
        state.questions_.pop()
        raise

    state.answers_.append(answer)
    state.input_query = ""
def write_user_message(msg):
    """Render one user question next to the user icon.

    Args:
        msg: dict with a ``'question'`` string.

    The question is HTML-escaped before being placed in the styled ``<div>``:
    it is free text typed by the user and the markup is rendered with
    ``unsafe_allow_html=True``, so characters such as ``<`` or ``&`` must not
    be interpreted as HTML.
    """
    import html  # local import: only this function needs it

    col1, col2 = st.columns([3, 97])
    with col1:
        st.image(USER_ICON, use_container_width=True)
    with col2:
        safe_question = html.escape(msg['question'])
        st.markdown(
            f"<div style='color:#e28743;font-size:18px;padding:3px 7px;border-radius:10px;font-style:italic;'>{safe_question}</div>",
            unsafe_allow_html=True
        )
def write_chat_message(response, question, index):
    """Render one assistant answer with audio, optional ColPali action and sources.

    Args:
        response: answer dict; ``'answer'`` is required, ``'image'``,
            ``'table'`` and ``'source'`` are optional. The synthesized audio
            is cached on this dict under ``'audio'``.
        question: the matching question dict (not used here).
        index: position of the Q&A pair, used to build a unique button key.
    """
    col1, col2, col3 = st.columns([4, 74, 22])
    with col1:
        st.image(AI_ICON, use_container_width=True)
    with col2:
        answer_text = response['answer']
        st.write(answer_text)

        # Synthesize speech only once per answer; this function runs again
        # for every answer each time the transcript is rendered.
        audio_bytes = response.get('audio')
        if audio_bytes is None:
            polly_response = polly_client.synthesize_speech(
                VoiceId='Joanna', OutputFormat='ogg_vorbis', Text=answer_text, Engine='neural')
            audio_bytes = polly_response['AudioStream'].read()
            response['audio'] = audio_bytes
        st.audio(audio_bytes, format="audio/ogg")

        if st.session_state.input_is_colpali:
            if st.button("Show similarity map", key=f"simmap_{index}"):
                st.session_state.show_columns = True
                st.session_state.maxSimImages = colpali.img_highlight(
                    st.session_state.top_img,
                    st.session_state.query_token_vectors,
                    st.session_state.query_tokens
                )
                handle_input()
                # Start a fresh run instead of rendering the transcript a
                # second time inside this one: a nested render would create
                # the "simmap_*" buttons again with the same keys.
                st.rerun()

        # NOTE(review): the expander is assumed to live inside the answer
        # column; the original nesting was not recoverable -- confirm.
        with st.expander("Relevant Sources"):
            for img in response.get('image') or []:
                if isinstance(img, dict) and 'file' in img:
                    st.image(img['file'])
            for tbl in response.get('table') or []:
                try:
                    df = pd.read_csv(tbl['name'], skipinitialspace=True, on_bad_lines='skip', delimiter='`')
                    # Forward-fill gaps; replaces the deprecated
                    # fillna(method='pad').
                    df = df.ffill()
                    st.table(df)
                except Exception as e:
                    # Best effort: a broken table must not hide the answer.
                    st.warning(f"Failed to load table: {e}")
            st.write(response.get("source", ""))
def render_all():
    """Render the chat transcript: each stored question followed by its answer.

    Questions and answers are paired positionally; pairs are numbered from 1
    and that number is passed on to ``write_chat_message``.
    """
    qa_pairs = zip(st.session_state.questions_, st.session_state.answers_)
    for position, (user_msg, ai_msg) in enumerate(qa_pairs, start=1):
        write_user_message(user_msg)
        write_chat_message(ai_msg, user_msg, position)
# Transcript area: a single placeholder whose container holds all Q&A pairs.
placeholder = st.empty()
with placeholder.container():
    render_all()

# Query box and "GO" button; the button submits through handle_input.
col_2, col_3 = st.columns([75, 20])
col_2.text_input("Ask here", label_visibility="collapsed", key="input_query")
col_3.button("GO", on_click=handle_input, key="play")
# Sidebar: sample-data picker, retriever options and ColPali toggle.
with st.sidebar:
    st.page_link("app.py", label=":orange[Home]", icon="🏠")

    st.subheader(":blue[Sample Data]")
    coln_1, coln_2 = st.columns([70, 30])
    with coln_1:
        st.radio("Choose one index", ["UK Housing", "Global Warming stats", "Covid19 impacts on Ireland"], key="input_rad_index")
    with coln_2:
        st.markdown("<p style='font-size:15px'>Preview file</p>", unsafe_allow_html=True)
        # One preview link per sample PDF, in the same order as the radio options.
        _sample_pdf_base_url = "https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs"
        for _sample_pdf in ("HPI-Jan-2024-Hometrack.pdf", "global_warming.pdf", "covid19_ie.pdf"):
            st.write(f"[:eyes:]({_sample_pdf_base_url}/{_sample_pdf})")

    st.subheader(":blue[Retriever]")
    # No default=/value= here: the initial values of these two keys are
    # seeded in st.session_state earlier in this file, and passing a widget
    # default as well makes Streamlit warn about the conflicting sources.
    st.multiselect("Select the Retriever(s)", ["Keyword Search", "Vector Search", "Sparse Search"], key="input_rag_searchType")
    st.checkbox("Re-rank results", key="input_is_rerank")

    st.subheader(":blue[Multi-vector retrieval]")
    colpali_search_rerank = st.checkbox('Try Colpali multi-vector retrieval on the [sample dataset](https://huggingface.co/datasets/vespa-engine/gpfg-QA)',
                                        key='input_colpali',
                                        disabled=False,
                                        value=False,
                                        help="Checking this box will use colpali as the embedding model and retrieval is performed using multi-vectors followed by re-ranking using MaxSim")
    # Mirror the checkbox under the key that the query handler reads.
    st.session_state.input_is_colpali = colpali_search_rerank

    with st.expander("Sample questions for Colpali retriever:"):
        st.write("""
1. Proportion of female new hires 2021-2023?
2. First-half 2021 return on unlisted real estate investments?
3. Trend of the fund's expected absolute volatility between January 2014 and January 2016?
4. Fund return percentage in 2017?
5. Annualized gross return of the fund from 1997 to 2008?
""")