File size: 3,979 Bytes
70b2fc9
c24940a
 
4a448eb
c24940a
 
 
4a448eb
c24940a
 
 
 
 
 
 
 
 
4a448eb
 
 
 
 
c24940a
4a448eb
 
c24940a
 
4a448eb
 
c24940a
4a448eb
c24940a
 
4a448eb
 
c24940a
4a448eb
c24940a
4a448eb
c24940a
4a448eb
c24940a
4a448eb
c24940a
 
 
 
 
 
 
 
 
4a448eb
c24940a
4a448eb
c24940a
 
3842297
 
 
 
4a448eb
 
 
3842297
c24940a
 
 
 
3842297
4a448eb
65935d6
4a448eb
c24940a
 
 
 
 
 
3842297
65935d6
3842297
 
 
 
 
 
 
4a448eb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import streamlit as st
from backend_utils import (get_plain_pipeline, get_retrieval_augmented_pipeline,
                           get_web_retrieval_augmented_pipeline, set_q1, set_q2, set_q3, set_q4, set_q5, QUERIES)

# Page chrome: browser-tab title, centered heading, and a one-line prompt.
st.set_page_config(page_title="Retrieval Augmentation with Haystack")

st.markdown("<center> <h2> Reduce Hallucinations with Retrieval Augmentation </h2> </center>", unsafe_allow_html=True)

st.markdown("Ask a question about the collapse of the Silicon Valley Bank (SVB).", unsafe_allow_html=True)

# Eager pipeline loading is disabled; pipelines are built lazily when the
# user presses "Run" further down the page.
# if not st.session_state.get('pipelines_loaded', False):
#     with st.spinner('Loading pipelines... \n This may take a few mins and might also fail if OpenAI API server is down.'):
#         p1, p2, p3 = app_init()
#         st.success('Pipelines are loaded', icon="✅")
#         st.session_state['pipelines_loaded'] = True
# Query input row: a text area bound to st.session_state['query'] (read later
# by the run logic) and a "Run" button beside it.
placeholder = st.empty()
with placeholder:
    search_bar, button = st.columns([3, 1])
    with search_bar:
        # Label is a single space to render an effectively unlabeled box.
        # (Was an f-string with no placeholders — a plain literal suffices.)
        query_text = st.text_area(" ", max_chars=200, key='query')

    with button:
        st.write(" ")  # vertical spacers so the button aligns with the text area
        st.write(" ")
        run_pressed = st.button("Run")

st.markdown("<center> <h5> Example questions </h5> </center>", unsafe_allow_html=True)

st.write(" ")
st.write(" ")
# One column per canned example question; clicking a button fills the query
# box via the matching set_qN callback from backend_utils.
_setters = (set_q1, set_q2, set_q3, set_q4, set_q5)
for _col, _question, _setter in zip(st.columns(5), QUERIES, _setters):
    with _col:
        st.button(_question, on_click=_setter)

st.write(" ")
# Choice between the two retrieval-augmented answer modes; the selected value
# is read back from st.session_state['query_type'] elsewhere on the page.
st.radio(
    "Answer Type:",
    ("Retrieval Augmented (Static news dataset)", "Retrieval Augmented with Web Search"),
    key="query_type",
)

# Earlier sidebar-based question picker, kept for reference:
# st.sidebar.selectbox(
#      "Example Questions:",
#      QUERIES,
#      key='q_drop_down', on_change=set_question)

# Output area: one placeholder for GPT's internal-knowledge answer, one for
# the retrieval-augmented answer; the second heading tracks the radio choice.
st.markdown("<h5> Answer with GPT's Internal Knowledge </h5>", unsafe_allow_html=True)
placeholder_plain_gpt = st.empty()
st.text(" ")
st.text(" ")
_static_mode = (
    st.session_state.get("query_type", "Retrieval Augmented (Static news dataset)")
    == "Retrieval Augmented (Static news dataset)"
)
if _static_mode:
    st.markdown("<h5> Answer with Retrieval Augmented GPT (Static news dataset) </h5>", unsafe_allow_html=True)
else:
    st.markdown("<h5> Answer with Retrieval Augmented GPT (Web Search) </h5>", unsafe_allow_html=True)
placeholder_retrieval_augmented = st.empty()

# Run the pipelines once the user has entered a query and pressed "Run".
if st.session_state.get('query') and run_pressed:
    ip = st.session_state['query']

    # First: GPT's answer from internal knowledge alone.
    with st.spinner('Loading pipelines... \n This may take a few mins and might also fail if OpenAI API server is down.'):
        p1 = get_plain_pipeline()
    with st.spinner('Fetching answers from GPT\'s internal knowledge... '
                    '\n This may take a few mins and might also fail if OpenAI API server is down.'):
        answers = p1.run(ip)
    placeholder_plain_gpt.markdown(answers['results'][0])

    # Then: the retrieval-augmented answer for the mode chosen in the
    # "Answer Type:" radio.  BUGFIX: the comparison value (and the
    # session_state default) must match the radio option string exactly —
    # the old value "Retrieval Augmented" never matched either option, so
    # the static-dataset branch was unreachable.  Also fixed spinner
    # messages where a '\n' escape had been split across a backslash line
    # continuation, yielding a stray literal 'n' and embedded indentation.
    if st.session_state.get("query_type", "Retrieval Augmented (Static news dataset)") == "Retrieval Augmented (Static news dataset)":
        with st.spinner('Loading Retrieval Augmented pipeline... '
                        '\n This may take a few mins and might also fail if OpenAI API server is down.'):
            p2 = get_retrieval_augmented_pipeline()
        with st.spinner('Fetching relevant documents from documented stores and calculating answers... '
                        '\n This may take a few mins and might also fail if OpenAI API server is down.'):
            answers_2 = p2.run(ip)
    else:
        with st.spinner('Loading Retrieval Augmented pipeline... '
                        '\n This may take a few mins and might also fail if OpenAI API server is down.'):
            p3 = get_web_retrieval_augmented_pipeline()
        with st.spinner('Fetching relevant documents from the Web and calculating answers... '
                        '\n This may take a few mins and might also fail if OpenAI API server is down.'):
            answers_2 = p3.run(ip)
    placeholder_retrieval_augmented.markdown(answers_2['results'][0])