File size: 5,936 Bytes
b6208a3
8329090
 
 
 
 
 
 
 
b6208a3
8329090
b6208a3
8329090
 
 
b6208a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8329090
 
 
b6208a3
8329090
b6208a3
 
8329090
b6208a3
8329090
 
 
 
b6208a3
8329090
 
 
 
b6208a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8329090
b6208a3
8329090
b6208a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8329090
 
b6208a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from operator import index
import streamlit as st
import logging
import os

from annotated_text import annotation
from json import JSONDecodeError
from markdown import markdown
from utils.config import parser
from utils.haystack import start_document_store, query, initialize_pipeline
from utils.ui import reset_results, set_initial_state
import pandas as pd

# Streamlit entry script for the question-answering demo.
#
# Flow of one script run:
#   1. Parse CLI arguments and start the document store.
#   2. Configure the page and sidebar, and build both pipelines.
#   3. If the user pressed "Run" or changed the question, run the selected
#      pipeline and store its output in ``st.session_state``.
#   4. Render the answers of the selected task plus the retrieved documents.
#
# The whole script is wrapped in ``try``/``except SystemExit`` so that an exit
# requested by the argument parser terminates the process.
try:
    args = parser.parse_args()
    document_store = start_document_store(type=args.store)
    st.set_page_config(
        page_title="test",
        layout="centered",
        page_icon=":shark:",
        menu_items={
            'Get Help': 'https://www.extremelycoolapp.com/help',
            'Report a bug': "https://www.extremelycoolapp.com/bug",
            'About': "# This is a header. This is an *extremely* cool app!",
        },
    )
    st.sidebar.image("ml_logo.png", use_column_width=True)

    # Sidebar for Task Selection
    st.sidebar.header('Options:')
    task_selection = st.sidebar.radio('Select the task:', ['Extractive', 'Generative'])

    pipeline_rag = initialize_pipeline("rag", document_store)
    pipeline_extractive = initialize_pipeline("extractive", document_store)

    # Each task maps to the pipeline that serves it and to the session-state
    # key under which its results are stored.
    task_config = {
        'Extractive': (pipeline_extractive, 'results_extractive'),
        'Generative': (pipeline_rag, 'results_generative'),
    }
    selected_pipeline, results_key = task_config[task_selection]

    # NOTE(review): presumably seeds the session-state keys read below
    # (results_extractive / results_generative) - confirm in utils.ui.
    set_initial_state()

    st.write('# ' + args.name)

    if "question" not in st.session_state:
        st.session_state.question = ""
    # Search bar
    question = st.text_input("", value=st.session_state.question, max_chars=100, on_change=reset_results)

    run_pressed = st.button("Run")

    # Re-run the pipeline on an explicit button press or when the question
    # differs from the one stored by the previous query.
    run_query = run_pressed or question != st.session_state.question

    # Get results for query
    if run_query and question:
        reset_results()
        st.session_state.question = question
        with st.spinner("🔎    Running your pipeline"):
            try:
                st.session_state[results_key] = query(selected_pipeline, question)
                st.session_state.task = task_selection
            except JSONDecodeError:
                st.error(
                    "👓    An error occurred reading the results. Is the document store working?"
                )
            except Exception as e:
                logging.exception(e)
                st.error("🐞    An error occurred during the request.")

    # Display results.
    # Only the results of the currently selected task are considered, so a
    # leftover result of the other task can never lead to reading from None.
    results = st.session_state.get(results_key)
    if results and run_query:

        # Handle Extractive Answers
        if task_selection == 'Extractive':
            st.subheader("Extracted Answers:")

            if 'answers' in results:
                answers = results['answers']
                threshold = 0.2
                if not any(ans.score > threshold for ans in answers):
                    # ``:.0%`` renders 0.2 as "20%"; int(0.2) * 100 would render "0".
                    st.markdown(f"<span style='color:red'>Please note none of the answers achieved a score higher then {threshold:.0%}. Which probably means that the desired answer is not in the searched documents.</span>", unsafe_allow_html=True)
                for count, answer in enumerate(answers):
                    if answer.answer:
                        text = answer.answer
                        context = answer.context or ""
                        score = round(answer.score, 3)
                        highlighted = str(annotation(body=text, label=f'SCORE {score}', background='#964448', color='#ffffff'))
                        start_idx = context.find(text)
                        if start_idx == -1:
                            # The answer does not occur verbatim in the context:
                            # slicing with -1 would corrupt the text, so show the
                            # highlighted answer followed by the untouched context.
                            rendered = f"{highlighted} {context}" if context else highlighted
                        else:
                            end_idx = start_idx + len(text)
                            rendered = context[:start_idx] + highlighted + context[end_idx:]
                        st.markdown(f"**Answer {count + 1}:**")
                        st.markdown(rendered, unsafe_allow_html=True)
                    else:
                        st.info(
                            "🤔 &nbsp;&nbsp; Haystack is unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
                        )

        # Handle Generative Answers
        elif task_selection == 'Generative':
            st.subheader("Generated Answer:")
            if 'results' in results:
                generated = results['results']
                if generated:
                    st.markdown("**Answer:**")
                    st.write(generated[0])
                else:
                    # Empty result list: indexing [0] would raise IndexError.
                    st.info("No answer was generated for this question.")

        # Handle Retrieved Documents
        if 'documents' in results:
            st.subheader("Retriever Results:")

            # One table row per document: rank, name (empty if the document
            # has no 'name' metadata) and the content cut to 150 characters.
            data = [
                [
                    rank,
                    document.meta.get('name', ''),
                    (document.content[:150] + '...') if len(document.content) > 150 else document.content,
                ]
                for rank, document in enumerate(results['documents'], start=1)
            ]

            # Convert data to DataFrame and display using Streamlit
            df = pd.DataFrame(data, columns=['Ranked Context', 'Document Name', 'Content'])
            st.table(df)

except SystemExit as e:
    # os._exit needs an int, while SystemExit.code may be None (success) or an
    # arbitrary object (failure); map it the way the interpreter would.
    exit_code = e.code
    if exit_code is None:
        exit_code = 0
    elif not isinstance(exit_code, int):
        exit_code = 1
    # os._exit ends the process immediately, without running cleanup handlers.
    os._exit(exit_code)