import streamlit as st
from backend_utils import (get_plain_pipeline, get_retrieval_augmented_pipeline,
get_web_retrieval_augmented_pipeline, set_q1, set_q2, set_q3, set_q4, set_q5, QUERIES,
PLAIN_GPT_ANS, GPT_WEB_RET_AUG_ANS, GPT_LOCAL_RET_AUG_ANS)
# Page setup: browser-tab title, then the headline and a one-line prompt shown
# above the query box.
st.set_page_config(
    page_title="Retrieval Augmentation with Haystack",
)
# Kept on a single line: a plain "..." literal cannot span lines in Python.
# HTML is used for the heading because the call opts into unsafe_allow_html.
st.markdown(
    "<h2>Reduce Hallucinations with Retrieval Augmentation</h2>",
    unsafe_allow_html=True,
)
st.markdown(
    "Ask a question about the collapse of the Silicon Valley Bank (SVB).",
    unsafe_allow_html=True,
)
# Query input: a text area (3/4 of the width) next to a "Run" button (1/4).
placeholder = st.empty()
with placeholder:
    search_bar, button = st.columns([3, 1])
    with search_bar:
        # The typed text is stored under st.session_state['query'] and read
        # from there when the pipelines run, so the return value is not kept.
        st.text_area(" ", max_chars=200, key='query')
    with button:
        # Two blank rows push the button down to line up with the text area.
        st.write(" ")
        st.write(" ")
        run_pressed = st.button("Run")

st.markdown("<h5>Example questions</h5>", unsafe_allow_html=True)
st.write(" ")
st.write(" ")

# One button per example query (first five entries of QUERIES), each wired to
# its set_qN callback from backend_utils. Presumably the callbacks copy the
# example into st.session_state['query'] -- TODO confirm in backend_utils.
_EXAMPLE_CALLBACKS = (set_q1, set_q2, set_q3, set_q4, set_q5)
for column, example_query, callback in zip(st.columns(5), QUERIES, _EXAMPLE_CALLBACKS):
    with column:
        st.button(example_query, on_click=callback)

st.write(" ")
# The selected option is stored under st.session_state['query_type'] and
# decides which retrieval-augmented pipeline runs.
st.radio(
    "Answer Type:",
    ("Retrieval Augmented (Static news dataset)", "Retrieval Augmented with Web Search"),
    key="query_type",
)
# Output area: each answer gets a label followed by an st.empty() placeholder
# that is filled in once the corresponding pipeline has produced a result.
st.markdown(f"<h5>{PLAIN_GPT_ANS}</h5>", unsafe_allow_html=True)
placeholder_plain_gpt = st.empty()
st.text(" ")
st.text(" ")

# Label the second answer after the answer type chosen in the radio widget.
if (st.session_state.get("query_type", "Retrieval Augmented (Static news dataset)")
        == "Retrieval Augmented (Static news dataset)"):
    retrieval_answer_label = GPT_LOCAL_RET_AUG_ANS
else:
    retrieval_answer_label = GPT_WEB_RET_AUG_ANS
st.markdown(f"<h5>{retrieval_answer_label}</h5>", unsafe_allow_html=True)
placeholder_retrieval_augmented = st.empty()
# Radio option (see the "Answer Type:" widget) that selects the local,
# static-news-dataset pipeline; any other value selects web search.
_STATIC_DATASET_OPTION = "Retrieval Augmented (Static news dataset)"
# Appended to every spinner message; begins with a real newline escape.
_SLOW_WARNING = '\n This may take a few mins and might also fail if OpenAI API server is down.'

# Run only on an explicit "Run" click with a non-empty query in the text area.
if st.session_state.get('query') and run_pressed:
    query = st.session_state['query']

    # Baseline answer from the plain (no retrieval) pipeline.
    with st.spinner('Loading pipelines... ' + _SLOW_WARNING):
        p1 = get_plain_pipeline()
    with st.spinner('Fetching answers from plain GPT... ' + _SLOW_WARNING):
        answers = p1.run(query)
    placeholder_plain_gpt.markdown(answers['results'][0])

    # Compare against the exact radio option text; the previous check against
    # "Retrieval Augmented" never matched, so the local pipeline never ran.
    if st.session_state.get("query_type", _STATIC_DATASET_OPTION) == _STATIC_DATASET_OPTION:
        load_pipeline = get_retrieval_augmented_pipeline
        load_message = ('Loading Retrieval Augmented pipeline that can fetch relevant '
                        'documents from local data store... ')
        run_message = 'Getting relevant documents from document stores and calculating answers... '
    else:
        load_pipeline = get_web_retrieval_augmented_pipeline
        load_message = ('Loading Retrieval Augmented pipeline that can fetch relevant '
                        'documents from the web... ')
        run_message = 'Getting relevant documents from the Web and calculating answers... '

    with st.spinner(load_message + _SLOW_WARNING):
        retrieval_pipeline = load_pipeline()
    with st.spinner(run_message + _SLOW_WARNING):
        answers_2 = retrieval_pipeline.run(query)
    placeholder_retrieval_augmented.markdown(answers_2['results'][0])