root
upload
e676d24
import streamlit as st
import os, time
from app.vdr_session import *
from app.vdr_schemas import *
from st_clickable_images import clickable_images
from app.prompt_template import VDR_PROMPT
def _clear_search_cache():
    """Forget the cached retrieval results held in ``st.session_state``.

    Called whenever the indexed context changes (and right before a new
    search) so that images retrieved from an old index, or for an old
    query / top-k, are never displayed or sent to the VLM again.
    """
    for cache_key in ("last_query", "last_top_k", "result_images"):
        if cache_key in st.session_state:
            del st.session_state[cache_key]


def _render_model_selection(session):
    """Validate the API key and render the model pickers in the sidebar.

    The key is read from the ``GLOBAL_AIFS_API_KEY`` environment variable and
    handed to ``session.set_api_key``; a truthy result is treated as "valid".

    Returns:
        ``(key_ok, selected_llm, selected_embed)``. Both model names are
        ``None`` when the key was rejected.
    """
    selected_llm = selected_embed = None
    with st.sidebar:
        api_key = os.getenv("GLOBAL_AIFS_API_KEY")
        key_ok = session.set_api_key(api_key)
        if key_ok:
            st.success('API Key is valid!', icon='✅')
            avai_llms = session.get_available_vlms()
            avai_embeds = session.get_available_image_embeds()
            selected_llm = st.sidebar.selectbox('Choose VLM models', avai_llms, key='selected_llm')
            selected_embed = st.sidebar.selectbox('Choose Embedding models', avai_embeds, key='selected_embed')
        else:
            st.warning('Please enter valid credentials!', icon='⚠️')
    return key_ok, selected_llm, selected_embed


def _render_context_controls(session, selected_embed):
    """Render the upload / index / remove-context controls.

    Must be called inside a ``with st.sidebar:`` context. Any change to the
    indexed context invalidates the cached search results.
    """
    uploaded_files = st.file_uploader("Upload PDF files", key="uploaded_files", accept_multiple_files=True)
    if st.button("Add selected context", key="add_context", type="primary"):
        if not uploaded_files:
            st.warning('Please upload files first!', icon='⚠️')
        else:
            indexing_bar = st.progress(0, text="Indexing...")
            try:
                # A falsy return from indexing() means nothing usable was indexed.
                if session.indexing(uploaded_files, selected_embed, indexing_bar):
                    # The index changed: results cached for the old index are stale.
                    _clear_search_cache()
                    st.success('Indexing completed!')
                else:
                    st.warning('Files empty or not supported.', icon='⚠️')
            except Exception as e:
                st.error(f"Error during indexing: {e}")
            finally:
                # Never leave a half-filled progress bar behind on failure.
                indexing_bar.empty()

    if st.button("🗑️ Remove all context", key="remove_context"):
        try:
            session.clear_context()
        except Exception as e:
            st.error(f"Error during removing context: {e}")
        else:
            _clear_search_cache()
            st.success("Context removed")
            # Kept outside the try: st.rerun() works by raising a control-flow
            # exception, which must not be routed through the error handler.
            st.rerun()


def _render_retrieval(session, query, top_k_sim, selected_embed):
    """Search for ``query`` (when long enough) and show the retrieved images.

    Results are cached in ``st.session_state`` keyed on ``(query, top_k_sim)``,
    so Streamlit reruns do not repeat the search but changing either value does.
    Clicking an image opens it in a dialog.
    """
    with st.expander(f"**Top {top_k_sim} retrieved contexts**", expanded=True):
        try:
            if len(query.strip()) > 2:
                cached_key = (st.session_state.get("last_query"), st.session_state.get("last_top_k"))
                if (query, top_k_sim) != cached_key:
                    with st.spinner('Searching...'):
                        # Drop the old results first so a failed search leaves no
                        # stale images behind, and record the cache key only after
                        # success so the search is retried on the next rerun.
                        _clear_search_cache()
                        results = session.search_images(query, top_k_sim)
                        st.session_state["result_images"] = results
                        st.session_state["last_query"] = query
                        st.session_state["last_top_k"] = top_k_sim

            images = st.session_state.get("result_images", [])
            if images:
                clicked = clickable_images(
                    images,
                    titles=[f"Image #{str(i)}" for i in range(len(images))],
                    div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
                    img_style={"margin": "5px", "height": "200px"},
                )
                st.write(f"**Retrieved by: {selected_embed}**")

                @st.dialog(" ", width="large")
                def show_selected_image(rank):
                    st.markdown(f"**Similarity rank: {rank}**")
                    st.image(images[rank])

                # clickable_images reports -1 while nothing is clicked. Its last
                # value persists across reruns, so open the dialog only when the
                # selection actually changed.
                if clicked > -1 and clicked != st.session_state.get("clicked", None):
                    show_selected_image(clicked)
                    st.session_state["clicked"] = clicked
        except Exception as e:
            st.error(f"Error during search: {e}")


def _render_answer(session, query, selected_llm, chat_prompt):
    """Offer a "Generate answer" button and stream the VLM response.

    Only shown once retrieval results exist. The retrieved images are passed
    to ``session.ask`` as the context.
    """
    if not st.session_state.get("result_images", None):
        return
    if not st.button("Generate answer", key="ask", type="primary"):
        return
    if len(query.strip()) <= 2:
        st.warning('Please enter query first!', icon='⚠️')
        return
    try:
        with st.spinner('Generating response...'):
            stream_response = session.ask(
                query=query,
                model=selected_llm,
                prompt_template=chat_prompt,
                retrieved_context=st.session_state["result_images"],
                stream=True,
            )
            st.write_stream(stream_response)
        st.write(f"**Answered by: {selected_llm}**")
    except Exception as e:
        st.error(f"Error during asking: {e}")


def page_vdr():
    """Render the Visual Document Retrieval page.

    Flow:
      1. Validate the API key and pick the VLM and image-embedding models.
      2. Upload and index PDF files, or clear the indexed context.
      3. Retrieve the top-k images for a query.
      4. Stream a VLM answer grounded on those images.

    The ``VDRSession`` lives in ``st.session_state["vdr_session"]`` so it
    survives Streamlit reruns. Nothing past the model pickers is rendered
    while the API key is rejected.
    """
    st.header("Visual Document Retrieval")
    # Store session context across reruns.
    if "vdr_session" not in st.session_state:
        st.session_state["vdr_session"] = VDRSession()
    session = st.session_state["vdr_session"]

    key_ok, selected_llm, selected_embed = _render_model_selection(session)
    if not key_ok:
        return

    with st.sidebar:
        _render_context_controls(session, selected_embed)
        top_k_sim = st.slider(label="Top k similarity", min_value=1, max_value=10, value=3, step=1, key="top_k_sim")
        chat_prompt = st.text_area("Prompt template", key="chat_prompt", value=VDR_PROMPT, height=300)

    # Querying is disabled until something has been indexed.
    query = st.text_input(
        label="Query",
        key='query',
        placeholder="Enter your query here",
        label_visibility="hidden",
        disabled=not session.indexed_images,
    )
    _render_retrieval(session, query, top_k_sim, selected_embed)
    _render_answer(session, query, selected_llm, chat_prompt)