import streamlit as st
import os, time
from app.vdr_session import *
from app.vdr_schemas import *
from st_clickable_images import clickable_images
from app.prompt_template import VDR_PROMPT
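
# Streamlit page for Visual Document Retrieval (VDR): index uploaded PDFs, retrieve
# the most similar pages (likely rendered as page images, given search_images below)
# for a query, and answer with a vision-language model. VDRSession comes from the
# wildcard imports above and is assumed to wrap the indexing/search/ask API.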
def page_vdr():
    st.header("Visual Document Retrieval")

    # Store session context
    if "vdr_session" not in st.session_state.keys():
        st.session_state["vdr_session"] = VDRSession()
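    # Sidebar: read the API key from the environment, validate it, and expose the
    # available VLM and image-embedding models for selection.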
    with st.sidebar:
        # api_key = st.text_input('Enter API Key:', type='password')
        api_key = os.getenv("GLOBAL_AIFS_API_KEY")
        check_api_key = st.session_state["vdr_session"].set_api_key(api_key)
        if check_api_key:
            st.success('API Key is valid!', icon='✅')
            avai_llms = st.session_state["vdr_session"].get_available_vlms()
            avai_embeds = st.session_state["vdr_session"].get_available_image_embeds()
            selected_llm = st.sidebar.selectbox('Choose VLM models', avai_llms, key='selected_llm', disabled=not check_api_key)
            selected_embed = st.sidebar.selectbox('Choose Embedding models', avai_embeds, key='selected_embed', disabled=not check_api_key)
            # st.session_state["vdr_session"].set_context(selected_llm, selected_embed)
        else:
            st.warning('Please enter valid credentials!', icon='⚠️')
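    # The rest of the page is only rendered once the API key has been accepted.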
    if check_api_key:
        with st.sidebar:
            uploaded_files = st.file_uploader("Upload PDF files", key="uploaded_files", accept_multiple_files=True, disabled=not check_api_key)
            if st.button("Add selected context", key="add_context", type="primary"):
                if uploaded_files:
                    try:
                        indexing_bar = st.progress(0, text="Indexing...")
                        if st.session_state["vdr_session"].indexing(uploaded_files, selected_embed, indexing_bar):
                            st.success('Indexing completed!')
                            indexing_bar.empty()
                            # st.rerun()
                        else:
                            st.warning('Files empty or not supported.', icon='⚠️')
                    except Exception as e:
                        st.error(f"Error during indexing: {e}")
                else:
                    st.warning('Please upload files first!', icon='⚠️')
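            # Drop the indexed context entirely (clear_context() is expected to
            # discard all stored embeddings) and rerun so the UI resets.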
            if st.button("🗑️ Remove all context", key="remove_context"):
                try:
                    st.session_state["vdr_session"].clear_context()
                    st.success("Context removed")
                    st.rerun()
                except Exception as e:
                    st.error(f"Error during removing context: {e}")
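            # Retrieval and generation settings: how many pages to retrieve and the
            # prompt template passed to the VLM.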
            top_k_sim = st.slider(label="Top k similarity", min_value=1, max_value=10, value=3, step=1, key="top_k_sim")
            # text_only_embed = st.toggle("Text only embedding", key="text_only_embed", value=False)
            chat_prompt = st.text_area("Prompt template", key="chat_prompt", value=VDR_PROMPT, height=300)
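        # Main area: free-text query, disabled until at least one page has been indexed.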
        query = st.text_input(label="Query", key='query', placeholder="Enter your query here", label_visibility="hidden", disabled=not st.session_state.get("vdr_session").indexed_images)

        with st.expander(f"**Top {top_k_sim} retrieved contexts**", expanded=True):
            try:
                if len(query.strip()) > 2:
                    if query != st.session_state.get("last_query", None):
                        with st.spinner('Searching...'):
                            st.session_state["last_query"] = query
                            st.session_state["result_images"] = st.session_state["vdr_session"].search_images(query, top_k_sim)

                if st.session_state.get("result_images", []):
                    images = st.session_state["result_images"]
                    clicked = clickable_images(
                        images,
                        titles=[f"Image #{str(i)}" for i in range(len(images))],
                        div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
                        img_style={"margin": "5px", "height": "200px"},
                    )
                    st.write(f"**Retrieved by: {selected_embed}**")

                    def show_selected_image(id):
                        st.markdown(f"**Similarity rank: {id}**")
                        st.image(images[id])

                    if clicked > -1 and clicked != st.session_state.get("clicked", None):
                        show_selected_image(clicked)
                        st.session_state["clicked"] = clicked
            except Exception as e:
                st.error(f"Error during search: {e}")
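        # Generate an answer grounded in the retrieved pages; the model response is
        # streamed into the page via st.write_stream.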
        if st.session_state.get("result_images", None):
            if st.button("Generate answer", key="ask", type="primary"):
                if len(query.strip()) > 2:
                    try:
                        with st.spinner('Generating response...'):
                            stream_response = st.session_state["vdr_session"].ask(
                                query=query,
                                model=selected_llm,
                                prompt_template=chat_prompt,
                                retrieved_context=st.session_state["result_images"],
                                stream=True,
                            )
                            # print(stream_response)
                            st.write_stream(stream_response)
                            st.write(f"**Answered by: {selected_llm}**")
                    except Exception as e:
                        st.error(f"Error during answer generation: {e}")
                else:
                    st.warning('Please enter a query first!', icon='⚠️')