"""Streamlit page comparing plain GPT answers with retrieval-augmented answers.

The user asks a question about the collapse of the Silicon Valley Bank (SVB).
When "Run" is pressed with a non-empty query, the plain pipeline's first
result is rendered, followed by the first result of the retrieval-augmented
pipeline selected with the "Answer Type:" radio (static news dataset or web
search). Pipelines are only fetched after "Run" is pressed.
"""
import streamlit as st

from backend_utils import (
    get_plain_pipeline,
    get_retrieval_augmented_pipeline,
    get_web_retrieval_augmented_pipeline,
    set_q1,
    set_q2,
    set_q3,
    set_q4,
    set_q5,
    QUERIES,
    PLAIN_GPT_ANS,
    GPT_WEB_RET_AUG_ANS,
    GPT_LOCAL_RET_AUG_ANS,
)

# Labels of the "Answer Type:" radio. They are shared by the widget, the
# answer heading and the pipeline selection so the three cannot drift apart.
STATIC_DATASET_OPTION = "Retrieval Augmented (Static news dataset)"
WEB_SEARCH_OPTION = "Retrieval Augmented with Web Search"

# Suffix appended to every spinner message.
SLOW_API_NOTICE = (
    "\n This may take a few mins and might also fail if OpenAI API server is down."
)

st.set_page_config(
    page_title="Retrieval Augmentation with Haystack",
)

# NOTE(review): unsafe_allow_html=True suggests the markdown strings on this
# page originally carried HTML markup; only the text visible in the source is
# reproduced here -- confirm against version control.
st.markdown(
    "\n\nReduce Hallucinations with Retrieval Augmentation\n\n",
    unsafe_allow_html=True,
)
st.markdown(
    "Ask a question about the collapse of the Silicon Valley Bank (SVB).",
    unsafe_allow_html=True,
)

# Query input and the "Run" button, side by side.
placeholder = st.empty()
with placeholder:
    search_bar, button = st.columns([3, 1])
    with search_bar:
        # The typed text is read back through st.session_state['query'].
        st.text_area(" ", max_chars=200, key="query")
    with button:
        # Two blank writes push the button down, next to the text area.
        st.write(" ")
        st.write(" ")
        run_pressed = st.button("Run")

st.markdown("\nExample questions\n", unsafe_allow_html=True)
st.write(" ")
st.write(" ")

# One button per example question; each button is wired to its own callback
# imported from backend_utils.
question_setters = (set_q1, set_q2, set_q3, set_q4, set_q5)
question_columns = st.columns(len(question_setters))
for index, (column, setter) in enumerate(zip(question_columns, question_setters)):
    with column:
        st.button(QUERIES[index], on_click=setter)

st.write(" ")
st.radio(
    "Answer Type:",
    (STATIC_DATASET_OPTION, WEB_SEARCH_OPTION),
    key="query_type",
)
query_type = st.session_state.get("query_type", STATIC_DATASET_OPTION)
use_static_dataset = query_type == STATIC_DATASET_OPTION

# Heading and output slot for the plain GPT answer.
st.markdown(f"\n{PLAIN_GPT_ANS}\n", unsafe_allow_html=True)
placeholder_plain_gpt = st.empty()
st.text(" ")
st.text(" ")

# Heading and output slot for the retrieval-augmented answer.
if use_static_dataset:
    st.markdown(f"\n{GPT_LOCAL_RET_AUG_ANS}\n", unsafe_allow_html=True)
else:
    st.markdown(f"\n{GPT_WEB_RET_AUG_ANS}\n", unsafe_allow_html=True)
placeholder_retrieval_augmented = st.empty()

query = st.session_state.get("query")
if query and run_pressed:
    with st.spinner("Loading pipelines... " + SLOW_API_NOTICE):
        plain_pipeline = get_plain_pipeline()
    with st.spinner("Fetching answers from plain GPT... " + SLOW_API_NOTICE):
        plain_answers = plain_pipeline.run(query)
    placeholder_plain_gpt.markdown(plain_answers["results"][0])

    # The selection is compared with the full radio label. The previous
    # check against the bare string "Retrieval Augmented" never matched
    # either option, so the web pipeline was used even when the static
    # dataset was selected.
    if use_static_dataset:
        load_pipeline = get_retrieval_augmented_pipeline
        loading_message = (
            "Loading Retrieval Augmented pipeline that can fetch relevant "
            "documents from local data store... "
        )
        running_message = (
            "Getting relevant documents from documented stores and "
            "calculating answers... "
        )
    else:
        load_pipeline = get_web_retrieval_augmented_pipeline
        loading_message = (
            "Loading Retrieval Augmented pipeline that can fetch relevant "
            "documents from the web... "
        )
        running_message = (
            "Getting relevant documents from the Web and calculating answers... "
        )

    with st.spinner(loading_message + SLOW_API_NOTICE):
        augmented_pipeline = load_pipeline()
    with st.spinner(running_message + SLOW_API_NOTICE):
        augmented_answers = augmented_pipeline.run(query)
    placeholder_retrieval_augmented.markdown(augmented_answers["results"][0])