Spaces:
Runtime error
Runtime error
import os | |
import streamlit as st | |
from haystack import Pipeline | |
from haystack.document_stores import FAISSDocumentStore | |
from haystack.nodes import Shaper, PromptNode, PromptTemplate, PromptModel, EmbeddingRetriever | |
from haystack.nodes.retriever.web import WebRetriever | |
# Preset example questions; set_q1 ... set_q5 copy one of these into
# st.session_state['query'] by index, so order matters.
QUERIES = [
    "Did SVB collapse?",
    "Why did SVB collapse?",
    "What does SVB failure mean for our economy?",
    "Who is responsible for SVB collapse?",
    "When did SVB collapse?",
]
def ChangeWidgetFontSize(wgt_txt, wch_font_size = '12px'): | |
htmlstr = """<script>var elements = window.parent.document.querySelectorAll('*'), i; | |
for (i = 0; i < elements.length; ++i) { if (elements[i].innerText == |wgt_txt|) | |
{ elements[i].style.fontSize='""" + wch_font_size + """';} } </script> """ | |
htmlstr = htmlstr.replace('|wgt_txt|', "'" + wgt_txt + "'") | |
def get_plain_pipeline():
    """Build a pipeline that sends the user's query straight to an OpenAI model.

    There is no retrieval step: a single PromptNode (``prompt_node``) fills
    the ``plain_llm`` template with ``$query`` and generates up to 300 tokens
    of output (``max_length=300``).

    Returns:
        A ``Pipeline`` with one node fed directly from ``Query``.
    """
    openai_model = PromptModel(
        model_name_or_path="text-davinci-003",
        api_key=st.secrets["OPENAI_API_KEY"],
    )
    question_only = PromptTemplate(
        name="plain_llm",
        prompt_text="Answer the following question: $query",
    )
    llm_node = PromptNode(
        openai_model,
        default_prompt_template=question_only,
        max_length=300,
    )

    plain = Pipeline()
    plain.add_node(component=llm_node, name="prompt_node", inputs=["Query"])
    return plain
def _build_question_answering_pipeline(retriever):
    """Chain ``retriever -> shaper -> prompt_node`` into a new Pipeline.

    Shared by the local-index and web retrieval pipelines so both use the
    same prompt, model and output length.

    Args:
        retriever: Any Haystack retriever node; it receives the ``Query``.

    Returns:
        A ``Pipeline`` whose final node is the OpenAI ``prompt_node``.
    """
    # Shaper's join_documents merges the retrieved documents so the prompt
    # receives them as a single $documents context value.
    shaper = Shaper(func="join_documents", inputs={"documents": "documents"}, outputs=["documents"])
    qa_template = PromptTemplate(
        name="question-answering",
        prompt_text="Given the context please answer the question. Context: $documents; Question: "
        "$query; Answer:",
    )
    prompt_node = PromptNode(
        "text-davinci-003",
        default_prompt_template=qa_template,
        api_key=st.secrets["OPENAI_API_KEY"],
        max_length=500,
    )

    pipeline = Pipeline()
    pipeline.add_node(component=retriever, name='retriever', inputs=['Query'])
    pipeline.add_node(component=shaper, name="shaper", inputs=["retriever"])
    pipeline.add_node(component=prompt_node, name="prompt_node", inputs=["shaper"])
    return pipeline


def get_retrieval_augmented_pipeline():
    """Build a pipeline that answers from documents in a local FAISS index.

    Loads the index and its config from ``data/``, retrieves the top 2
    documents by embedding similarity, and passes them as context to the
    OpenAI prompt node.
    """
    ds = FAISSDocumentStore(faiss_index_path="data/my_faiss_index.faiss",
                            faiss_config_path="data/my_faiss_index.json")
    # NOTE(review): presumably this must be the same embedding model that was
    # used to build the stored index; confirm before changing it.
    retriever = EmbeddingRetriever(
        document_store=ds,
        embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        model_format="sentence_transformers",
        top_k=2,
    )
    return _build_question_answering_pipeline(retriever)


def get_web_retrieval_augmented_pipeline():
    """Build a pipeline that answers from live web search results (SerperDev)."""
    search_key = st.secrets["WEBRET_API_KEY"]
    web_retriever = WebRetriever(api_key=search_key, search_engine_provider="SerperDev")
    return _build_question_answering_pipeline(web_retriever)
def app_init():
    """Export the OpenAI key to the environment and build all three pipelines.

    Returns:
        A 3-tuple ``(plain, retrieval_augmented, web_retrieval_augmented)``,
        built in that order.
    """
    # Mirror the secret into the process environment as well; presumably some
    # component reads OPENAI_API_KEY from there — verify before removing.
    os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
    return (
        get_plain_pipeline(),
        get_retrieval_augmented_pipeline(),
        get_web_retrieval_augmented_pipeline(),
    )
# On the first script run, seed the shared query text with an empty string so
# later reads of st.session_state['query'] find a value.
if "query" not in st.session_state:
    st.session_state.query = ""
def set_question():
    """Copy the current ``q_drop_down`` selection into ``query`` in session state.

    Presumably registered as a widget callback (e.g. ``on_change``) for the
    dropdown keyed ``q_drop_down`` — confirm against the caller.
    """
    chosen = st.session_state['q_drop_down']
    st.session_state['query'] = chosen
def _use_preset_query(position):
    """Store ``QUERIES[position]`` as the current ``query`` in session state."""
    st.session_state['query'] = QUERIES[position]


def set_q1():
    """Select the first preset question."""
    _use_preset_query(0)


def set_q2():
    """Select the second preset question."""
    _use_preset_query(1)


def set_q3():
    """Select the third preset question."""
    _use_preset_query(2)


def set_q4():
    """Select the fourth preset question."""
    _use_preset_query(3)


def set_q5():
    """Select the fifth preset question."""
    _use_preset_query(4)