File size: 6,703 Bytes
b6208a3 8329090 b6208a3 8329090 b6208a3 90ba1bf 8329090 b6208a3 90ba1bf b6208a3 90ba1bf b6208a3 90ba1bf b6208a3 90ba1bf 8329090 b6208a3 8329090 b6208a3 8329090 b6208a3 8329090 b6208a3 8329090 b6208a3 90ba1bf b6208a3 8329090 b6208a3 8329090 b6208a3 8329090 b6208a3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
from operator import index
import streamlit as st
import logging
import os
from annotated_text import annotation
from json import JSONDecodeError
from markdown import markdown
from utils.config import parser
from utils.haystack import start_document_store, query, initialize_pipeline
from utils.ui import reset_results, set_initial_state
import pandas as pd
import haystack
# Minimum reader score (0-1) an extractive answer must exceed; if no answer
# does, the user is warned that the answer is probably not in the documents.
SCORE_THRESHOLD = 0.2
# Number of characters of each retrieved document shown in the results table.
CONTENT_PREVIEW_CHARS = 150


def _run_pipeline(pipeline, question, task, result_key, check_api_key=False):
    """Run ``pipeline`` for ``question`` and store its output in session state.

    Previous results are cleared via ``reset_results()`` and the question is
    remembered in ``st.session_state.question``. On success the pipeline
    output is stored under ``st.session_state[result_key]`` and ``task`` under
    ``st.session_state.task``. Failures are logged and reported with
    ``st.error`` instead of being raised.

    Args:
        pipeline: Pipeline object returned by ``initialize_pipeline``.
        question: The user's query text.
        task: Task label ('Extractive' or 'Generative') recorded on success.
        result_key: Session-state key that receives the pipeline output.
        check_api_key: If True, an exception whose message contains
            "API key is invalid" is reported with a dedicated OpenAI-key hint.
    """
    reset_results()
    st.session_state.question = question
    with st.spinner("π Running your pipeline"):
        try:
            st.session_state[result_key] = query(pipeline, question)
            st.session_state.task = task
        except JSONDecodeError:
            st.error(
                "π An error occurred reading the results. Is the document store working?"
            )
        except Exception as e:
            logging.exception(e)
            if check_api_key and "API key is invalid" in str(e):
                st.error("π incorrect API key provided. You can find your API key at https://platform.openai.com/account/api-keys.")
            else:
                st.error("π An error occurred during the request.")


def _render_extractive_answers(results, threshold=SCORE_THRESHOLD):
    """Render extractive answers with each answer span highlighted in context.

    Shows a red warning when no answer scores above ``threshold``. Answers
    with empty text get an info box suggesting the user rephrase the question.
    Does nothing beyond the subheader when ``results`` is empty or has no
    ``'answers'`` entry.
    """
    st.subheader("Extracted Answers:")
    if not results or 'answers' not in results:
        return
    answers = results['answers']
    if not any(ans.score > threshold for ans in answers):
        # ``:.0%`` renders 0.2 as "20%" (the old int(0.2) * 100 always gave 0).
        st.markdown(f"<span style='color:red'>Please note none of the answers achieved a score higher then {threshold:.0%}. Which probably means that the desired answer is not in the searched documents.</span>", unsafe_allow_html=True)
    for count, answer in enumerate(answers):
        if not answer.answer:
            st.info(
                "π€ Haystack is unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
            )
            continue
        text = answer.answer
        context = answer.context or ""
        score = round(answer.score, 3)
        highlight = str(annotation(body=text, label=f'SCORE {score}', background='#964448', color='#ffffff'))
        start_idx = context.find(text)
        if start_idx == -1:
            # The answer is not a verbatim substring of the context; splicing
            # at index -1 would drop the context's last character and repeat
            # the context, so show the context followed by the answer instead.
            rendered = context + " " + highlight
        else:
            end_idx = start_idx + len(text)
            rendered = context[:start_idx] + highlight + context[end_idx:]
        st.markdown(f"**Answer {count + 1}:**")
        st.markdown(rendered, unsafe_allow_html=True)


def _render_generative_answer(results):
    """Render the first generated answer, if the results contain any."""
    st.subheader("Generated Answer:")
    if results and 'results' in results and results['results']:
        st.markdown("**Answer:**")
        st.write(results['results'][0])


def _render_retrieved_documents(results, preview_chars=CONTENT_PREVIEW_CHARS):
    """Show retrieved documents as a ranked table with truncated content.

    Documents whose metadata lacks a ``'name'`` entry are listed with an
    empty name instead of raising ``KeyError``.
    """
    if not results or 'documents' not in results:
        return
    st.subheader("Retriever Results:")
    rows = []
    for rank, document in enumerate(results['documents'], start=1):
        content = document.content
        if len(content) > preview_chars:
            content = content[:preview_chars] + '...'
        rows.append([rank, document.meta.get('name', ''), content])
    df = pd.DataFrame(rows, columns=['Ranked Context', 'Document Name', 'Content'])
    st.table(df)


try:
    args = parser.parse_args()
    document_store = start_document_store(type=args.store)
    st.set_page_config(
        page_title="MLReplySearch",
        layout="centered",
        page_icon=":shark:",
        menu_items={
            'Get Help': 'https://www.extremelycoolapp.com/help',
            'Report a bug': "https://www.extremelycoolapp.com/bug",
            'About': "# This is a header. This is an *extremely* cool app!"
        }
    )
    st.sidebar.image("ml_logo.png", use_column_width=True)

    # Sidebar for task selection.
    st.sidebar.header('Options:')
    openai_key = st.sidebar.text_input("Enter OpenAI Key:", type="password")
    # The generative (RAG) task is only offered once an OpenAI key is entered.
    task_options = ['Extractive', 'Generative'] if openai_key else ['Extractive']
    task_selection = st.sidebar.radio('Select the task:', task_options)

    # Build only the pipeline needed for the selected task.
    if task_selection == 'Extractive':
        pipeline_extractive = initialize_pipeline("extractive", document_store)
    elif task_selection == 'Generative' and openai_key:
        pipeline_rag = initialize_pipeline("rag", document_store, openai_key=openai_key)

    set_initial_state()
    st.write('# ' + args.name)
    if "question" not in st.session_state:
        st.session_state.question = ""

    # Search bar: a query runs when "Run" is pressed or the question changed.
    question = st.text_input("", value=st.session_state.question, max_chars=100, on_change=reset_results)
    run_pressed = st.button("Run")
    run_query = run_pressed or question != st.session_state.question

    # Get results for the query.
    if run_query and question:
        if task_selection == 'Extractive':
            _run_pipeline(pipeline_extractive, question, task_selection, "results_extractive")
        elif task_selection == 'Generative':
            _run_pipeline(pipeline_rag, question, task_selection, "results_generative", check_api_key=True)

    # Display results.
    if (st.session_state.results_extractive or st.session_state.results_generative) and run_query:
        # task_options only ever contains 'Extractive' and 'Generative'.
        if task_selection == 'Extractive':
            results = st.session_state.results_extractive
            _render_extractive_answers(results)
        else:
            results = st.session_state.results_generative
            _render_generative_answer(results)
        _render_retrieved_documents(results)
except SystemExit as e:
    # parse_args() raises SystemExit (e.g. for --help or invalid arguments).
    # The process is terminated with os._exit, presumably because Streamlit
    # would otherwise intercept the SystemExit (TODO confirm). os._exit only
    # accepts ints, so normalize the code as the interpreter would:
    # None -> 0, int -> itself, anything else -> 1.
    exit_code = e.code
    if exit_code is None:
        exit_code = 0
    elif not isinstance(exit_code, int):
        exit_code = 1
    os._exit(exit_code)
|