|
import streamlit as st |
|
import os, time |
|
from app.vdr_session import * |
|
from app.vdr_schemas import * |
|
from st_clickable_images import clickable_images |
|
from app.prompt_template import VDR_PROMPT |
|
|
|
# Session-state entries that depend on the currently indexed context. They are
# dropped whenever that context changes so stale results are never displayed.
_SEARCH_STATE_KEYS = ("last_query", "last_top_k", "result_images", "clicked")

# Queries shorter than this (after stripping whitespace) are ignored.
_MIN_QUERY_CHARS = 3


def _reset_search_state() -> None:
    """Forget the cached search so the next run queries the index afresh."""
    for key in _SEARCH_STATE_KEYS:
        if key in st.session_state:
            del st.session_state[key]


def _is_searchable(query: str) -> bool:
    """Return True when *query* has enough non-blank characters to search."""
    return len(query.strip()) >= _MIN_QUERY_CHARS


def _render_context_controls(session: "VDRSession", selected_embed: str) -> None:
    """Render the upload / index / clear controls for the retrieval context.

    Must be called inside the container (here: the sidebar) in which the
    widgets should appear.
    """
    uploaded_files = st.file_uploader(
        "Upload PDF files",
        key="uploaded_files",
        accept_multiple_files=True,
    )

    if st.button("Add selected context", key="add_context", type="primary"):
        if not uploaded_files:
            st.warning('Please upload files first!', icon='⚠️')
        else:
            indexing_bar = st.progress(0, text="Indexing...")
            try:
                indexed = session.indexing(uploaded_files, selected_embed, indexing_bar)
            except Exception as e:
                st.error(f"Error during indexing: {e}")
            else:
                if indexed:
                    # The index changed: a cached search for the same query
                    # would no longer reflect it.
                    _reset_search_state()
                    st.success('Indexing completed!')
                else:
                    st.warning('Files empty or not supported.', icon='⚠️')
            finally:
                # Remove the progress bar on every path, not only on success.
                indexing_bar.empty()

    if st.button("🗑️ Remove all context", key="remove_context"):
        # Keep the try body to the one call whose failure is reported here.
        try:
            session.clear_context()
        except Exception as e:
            st.error(f"Error during removing context: {e}")
        else:
            _reset_search_state()
            st.success("Context removed")
            st.rerun()


def _render_retrieved_contexts(
    session: "VDRSession",
    query: str,
    top_k_sim: int,
    selected_embed: str,
    has_context: bool,
) -> None:
    """Run the search when needed and show the retrieved images.

    The search result is cached in ``st.session_state["result_images"]`` and
    is refreshed whenever the query or ``top_k_sim`` differs from the values
    of the last successful search.
    """
    try:
        # Nothing to search without an index or with a too-short query.
        if not has_context or not _is_searchable(query):
            return

        search_changed = (
            query != st.session_state.get("last_query")
            or top_k_sim != st.session_state.get("last_top_k")
        )
        if search_changed:
            with st.spinner('Searching...'):
                found = session.search_images(query, top_k_sim)
            # Record the search key only after success so a failed search is
            # retried on the next rerun instead of leaving old results shown.
            st.session_state["result_images"] = found
            st.session_state["last_query"] = query
            st.session_state["last_top_k"] = top_k_sim

        images = st.session_state.get("result_images", [])
        if not images:
            return

        clicked = clickable_images(
            images,
            titles=[f"Image #{i}" for i, _ in enumerate(images)],
            div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
            img_style={"margin": "5px", "height": "200px"},
        )
        st.write(f"**Retrieved by: {selected_embed}**")

        @st.dialog(" ", width="large")
        def show_selected_image(rank):
            st.markdown(f"**Similarity rank: {rank}**")
            st.image(images[rank])

        # Open the dialog only when the clicked index changed since the last
        # run, so an unrelated rerun does not reopen it.
        if clicked > -1 and clicked != st.session_state.get("clicked"):
            show_selected_image(clicked)
        # NOTE(review): the source's indentation was lost; the remembered
        # index is updated on every run (including the "nothing clicked"
        # value) so the same image can be opened again after a new search.
        st.session_state["clicked"] = clicked
    except Exception as e:
        st.error(f"Error during search: {e}")


def _render_answer(
    session: "VDRSession",
    query: str,
    selected_llm: str,
    chat_prompt: str,
) -> None:
    """Render the "Generate answer" button and stream the model response."""
    if not st.session_state.get("result_images"):
        return
    if not st.button("Generate answer", key="ask", type="primary"):
        return
    if not _is_searchable(query):
        st.warning('Please enter query first!', icon='⚠️')
        return

    try:
        with st.spinner('Generating response...'):
            stream_response = session.ask(
                query=query,
                model=selected_llm,
                prompt_template=chat_prompt,
                retrieved_context=st.session_state["result_images"],
                stream=True,
            )
        st.write_stream(stream_response)
        st.write(f"**Answered by: {selected_llm}**")
    except Exception as e:
        st.error(f"Error during asking: {e}")


def page_vdr():
    """Render the "Visual Document Retrieval" Streamlit page.

    A ``VDRSession`` is created once and kept in
    ``st.session_state["vdr_session"]``. The API key is read from the
    ``GLOBAL_AIFS_API_KEY`` environment variable; when the session rejects it
    only a warning is shown and the rest of the page is not rendered.

    With a valid key the sidebar offers model selection, file upload /
    indexing, context removal, the top-k slider and the prompt template, and
    the main area offers the query box, the retrieved images and answer
    generation.

    Session-state keys written besides the widget keys: ``vdr_session``,
    ``last_query``, ``last_top_k``, ``result_images`` and ``clicked``.
    """
    st.header("Visual Document Retrieval")

    if "vdr_session" not in st.session_state:
        st.session_state["vdr_session"] = VDRSession()
    session = st.session_state["vdr_session"]

    with st.sidebar:
        api_key = os.getenv("GLOBAL_AIFS_API_KEY")
        check_api_key = session.set_api_key(api_key)
        if not check_api_key:
            st.warning('Please enter valid credentials!', icon='⚠️')
            return

        st.success('API Key is valid!', icon='✅')
        avai_llms = session.get_available_vlms()
        avai_embeds = session.get_available_image_embeds()
        selected_llm = st.selectbox('Choose VLM models', avai_llms, key='selected_llm')
        selected_embed = st.selectbox('Choose Embedding models', avai_embeds, key='selected_embed')

        _render_context_controls(session, selected_embed)

        # NOTE(review): the source's indentation was lost; these two settings
        # are assumed to belong to the sidebar - confirm against the intended
        # layout.
        top_k_sim = st.slider(
            label="Top k similarity",
            min_value=1,
            max_value=10,
            value=3,
            step=1,
            key="top_k_sim",
        )
        chat_prompt = st.text_area(
            "Prompt template",
            key="chat_prompt",
            value=VDR_PROMPT,
            height=300,
        )

    # Evaluated after the sidebar so indexing done in this run is reflected.
    has_context = bool(session.indexed_images)

    query = st.text_input(
        label="Query",
        key='query',
        placeholder="Enter your query here",
        label_visibility="hidden",
        disabled=not has_context,
    )

    with st.expander(f"**Top {top_k_sim} retrieved contexts**", expanded=True):
        _render_retrieved_contexts(session, query, top_k_sim, selected_embed, has_context)

    _render_answer(session, query, selected_llm, chat_prompt)
|
|
|
|
|
|
|
|
|
|