import streamlit as st
import os
import time
from app.vdr_session import *
from app.vdr_schemas import *
from st_clickable_images import clickable_images
from app.prompt_template import VDR_PROMPT


def page_vdr():
    st.header("Visual Document Retrieval")

    # Keep one VDR session object per Streamlit session
    if "vdr_session" not in st.session_state:
        st.session_state["vdr_session"] = VDRSession()

    # --- Sidebar: credentials and model selection ---
    with st.sidebar:
        # api_key = st.text_input('Enter API Key:', type='password')
        api_key = os.getenv("GLOBAL_AIFS_API_KEY")
        check_api_key = st.session_state["vdr_session"].set_api_key(api_key)
        if check_api_key:
            st.success('API Key is valid!', icon='✅')
            avai_llms = st.session_state["vdr_session"].get_available_vlms()
            avai_embeds = st.session_state["vdr_session"].get_available_image_embeds()
            selected_llm = st.sidebar.selectbox('Choose VLM models', avai_llms,
                                                key='selected_llm', disabled=not check_api_key)
            selected_embed = st.sidebar.selectbox('Choose Embedding models', avai_embeds,
                                                  key='selected_embed', disabled=not check_api_key)
            # st.session_state["vdr_session"].set_context(selected_llm, selected_embed)
        else:
            st.warning('Please enter valid credentials!', icon='⚠️')

    if check_api_key:
        # --- Sidebar: document upload, indexing and retrieval settings ---
        with st.sidebar:
            uploaded_files = st.file_uploader("Upload PDF files", key="uploaded_files",
                                              accept_multiple_files=True, disabled=not check_api_key)
            if st.button("Add selected context", key="add_context", type="primary"):
                if uploaded_files:
                    try:
                        indexing_bar = st.progress(0, text="Indexing...")
                        if st.session_state["vdr_session"].indexing(uploaded_files, selected_embed, indexing_bar):
                            st.success('Indexing completed!')
                            indexing_bar.empty()
                            # st.rerun()
                        else:
                            st.warning('Files empty or not supported.', icon='⚠️')
                    except Exception as e:
                        st.error(f"Error during indexing: {e}")
                else:
                    st.warning('Please upload files first!', icon='⚠️')

            if st.button("🗑️ Remove all context", key="remove_context"):
                try:
                    st.session_state["vdr_session"].clear_context()
                    st.success("Context removed")
                    st.rerun()
                except Exception as e:
                    st.error(f"Error during removing context: {e}")

            top_k_sim = st.slider(label="Top k similarity", min_value=1, max_value=10,
                                  value=3, step=1, key="top_k_sim")
            # text_only_embed = st.toggle("Text only embedding", key="text_only_embed", value=False)
            chat_prompt = st.text_area("Prompt template", key="chat_prompt", value=VDR_PROMPT, height=300)

        # --- Main area: query input, retrieval preview and answer generation ---
        query = st.text_input(label="Query", key='query', placeholder="Enter your query here",
                              label_visibility="hidden",
                              disabled=not st.session_state.get("vdr_session").indexed_images)

        with st.expander(f"**Top {top_k_sim} retrieved contexts**", expanded=True):
            try:
                if len(query.strip()) > 2:
                    # Only run a new similarity search when the query actually changed
                    if query != st.session_state.get("last_query", None):
                        with st.spinner('Searching...'):
                            st.session_state["last_query"] = query
                            st.session_state["result_images"] = st.session_state["vdr_session"].search_images(query, top_k_sim)
                    if st.session_state.get("result_images", []):
                        images = st.session_state["result_images"]
                        clicked = clickable_images(
                            images,
                            titles=[f"Image #{i}" for i in range(len(images))],
                            div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
                            img_style={"margin": "5px", "height": "200px"},
                        )
                        st.write(f"**Retrieved by: {selected_embed}**")

                        @st.dialog(" ", width="large")
                        def show_selected_image(rank):
                            st.markdown(f"**Similarity rank: {rank}**")
                            st.image(images[rank])

                        # Open the dialog only for a newly clicked thumbnail
                        if clicked > -1 and clicked != st.session_state.get("clicked", None):
                            show_selected_image(clicked)
                            st.session_state["clicked"] = clicked
            except Exception as e:
                st.error(f"Error during search: {e}")

        if st.session_state.get("result_images", None):
            if st.button("Generate answer", key="ask", type="primary"):
                if len(query.strip()) > 2:
                    try:
                        with st.spinner('Generating response...'):
                            stream_response = st.session_state["vdr_session"].ask(
                                query=query,
                                model=selected_llm,
                                prompt_template=chat_prompt,
                                retrieved_context=st.session_state["result_images"],
                                stream=True,
                            )
                            # print(stream_response)
                            st.write_stream(stream_response)
                            st.write(f"**Answered by: {selected_llm}**")
                    except Exception as e:
                        st.error(f"Error during asking: {e}")
                else:
                    st.warning('Please enter query first!', icon='⚠️')