File size: 6,703 Bytes
b6208a3
8329090
 
 
 
 
 
 
 
b6208a3
8329090
b6208a3
90ba1bf
8329090
 
 
b6208a3
 
90ba1bf
b6208a3
90ba1bf
b6208a3
90ba1bf
 
 
 
 
b6208a3
 
 
 
90ba1bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8329090
 
 
b6208a3
8329090
b6208a3
 
8329090
b6208a3
8329090
 
 
 
b6208a3
8329090
 
 
 
b6208a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90ba1bf
 
 
 
 
 
b6208a3
 
 
 
 
 
8329090
b6208a3
8329090
b6208a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8329090
 
b6208a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from operator import index
import streamlit as st
import logging
import os

from annotated_text import annotation
from json import JSONDecodeError
from markdown import markdown
from utils.config import parser
from utils.haystack import start_document_store, query, initialize_pipeline
from utils.ui import reset_results, set_initial_state
import pandas as pd
import haystack

# Streamlit entry script.
#
# Flow: parse CLI args -> start the document store -> build the page and the
# sidebar -> initialise the pipeline for the selected task -> run the query
# for the typed question -> render the answers and the retrieved documents.
#
# The whole script sits inside one try-block so that a SystemExit raised while
# parsing arguments (or anywhere else below) ends up in the handler at the
# bottom, which terminates the process with a well-formed exit status.
try:
    args = parser.parse_args()
    document_store = start_document_store(type=args.store)
    st.set_page_config(
        page_title="MLReplySearch",
        layout="centered",
        page_icon=":shark:",
        menu_items={
            'Get Help': 'https://www.extremelycoolapp.com/help',
            'Report a bug': "https://www.extremelycoolapp.com/bug",
            'About': "# This is a header. This is an *extremely* cool app!"
        }
    )
    st.sidebar.image("ml_logo.png", use_column_width=True)

    # Sidebar for task selection
    st.sidebar.header('Options:')

    # OpenAI key input; the generative task is only offered once a key is typed.
    openai_key = st.sidebar.text_input("Enter OpenAI Key:", type="password")

    if openai_key:
        task_options = ['Extractive', 'Generative']
    else:
        task_options = ['Extractive']

    task_selection = st.sidebar.radio('Select the task:', task_options)
    is_generative = task_selection == 'Generative'

    # Initialise only the pipeline needed for the selected task.
    if task_selection == 'Extractive':
        pipeline_extractive = initialize_pipeline("extractive", document_store)
    elif is_generative and openai_key:  # never build the RAG pipeline without a key
        pipeline_rag = initialize_pipeline("rag", document_store, openai_key=openai_key)

    # NOTE(review): presumably seeds the st.session_state keys read below
    # (results_extractive / results_generative) - confirm in utils.ui.
    set_initial_state()

    st.write('# ' + args.name)

    if "question" not in st.session_state:
        st.session_state.question = ""

    # Search bar
    question = st.text_input("", value=st.session_state.question, max_chars=100, on_change=reset_results)

    run_pressed = st.button("Run")

    # Run when the button was pressed or the text differs from the last
    # question that was actually executed.
    run_query = run_pressed or question != st.session_state.question

    # Get results for query
    if run_query and question:
        reset_results()
        st.session_state.question = question
        with st.spinner("🔎    Running your pipeline"):
            try:
                if is_generative:
                    st.session_state.results_generative = query(pipeline_rag, question)
                else:
                    st.session_state.results_extractive = query(pipeline_extractive, question)
                st.session_state.task = task_selection
            except JSONDecodeError:
                st.error(
                    "👓    An error occurred reading the results. Is the document store working?"
                )
            except Exception as e:
                # UI boundary: log the full traceback, show a short message.
                logging.exception(e)
                if is_generative and "API key is invalid" in str(e):
                    st.error("🐞    incorrect API key provided. You can find your API key at https://platform.openai.com/account/api-keys.")
                else:
                    st.error("🐞    An error occurred during the request.")

    # Display results.
    # Only the result set of the *currently selected* task is considered, so a
    # leftover result of the other task can no longer lead to dereferencing an
    # empty/None result object.
    if is_generative:
        results = st.session_state.results_generative
    else:
        results = st.session_state.results_extractive

    if results and run_query:

        # Handle extractive answers
        if task_selection == 'Extractive':
            st.subheader("Extracted Answers:")

            if 'answers' in results:
                answers = results['answers']
                threshold = 0.2
                if not any(ans.score > threshold for ans in answers):
                    # Scale first, then round: int(0.2) * 100 would print 0%.
                    st.markdown(f"<span style='color:red'>Please note none of the answers achieved a score higher than {round(threshold * 100)}%. Which probably means that the desired answer is not in the searched documents.</span>", unsafe_allow_html=True)
                for position, answer in enumerate(answers, start=1):
                    if answer.answer:
                        text = answer.answer
                        context = answer.context or ""
                        score = round(answer.score, 3)
                        highlighted = str(annotation(body=text, label=f'SCORE {score}', background='#964448', color='#ffffff'))
                        start_idx = context.find(text)
                        if start_idx >= 0:
                            end_idx = start_idx + len(text)
                            rendered = context[:start_idx] + highlighted + context[end_idx:]
                        else:
                            # str.find returned -1: the answer is not a verbatim
                            # substring of the context. Slicing with -1 would
                            # garble the text, so show the highlighted answer
                            # followed by the untouched context instead.
                            rendered = (highlighted + " " + context).rstrip()
                        st.markdown(f"**Answer {position}:**")
                        st.markdown(rendered, unsafe_allow_html=True)
                    else:
                        st.info(
                            "🤔 &nbsp;&nbsp; Haystack is unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
                        )

        # Handle generative answers
        elif is_generative:
            st.subheader("Generated Answer:")
            # Guard against an empty list before indexing the first entry.
            if 'results' in results and results['results']:
                st.markdown("**Answer:**")
                st.write(results['results'][0])

        # Handle retrieved documents (shown for both tasks)
        if 'documents' in results:
            retrieved_documents = results['documents']
            st.subheader("Retriever Results:")

            data = []
            for rank, document in enumerate(retrieved_documents, start=1):
                # Truncate long contents to keep the table readable
                content = document.content
                truncated_content = (content[:150] + '...') if len(content) > 150 else content
                data.append([rank, document.meta['name'], truncated_content])

            # Convert data to DataFrame and display using Streamlit
            df = pd.DataFrame(data, columns=['Ranked Context', 'Document Name', 'Content'])
            st.table(df)

except SystemExit as e:
    # SystemExit.code may be None (success), an int, or any other object
    # (conventionally an error message, meaning status 1). os._exit only
    # accepts an int, so normalise before terminating the process.
    exit_code = e.code
    if exit_code is None:
        exit_code = 0
    elif not isinstance(exit_code, int):
        exit_code = 1
    os._exit(exit_code)