File size: 4,279 Bytes
4a448eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c24940a
4a448eb
 
 
 
 
 
 
 
 
 
c24940a
4a448eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c24940a
4a448eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c24940a
 
 
 
 
 
 
 
 
 
4a448eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import streamlit as st
from haystack import Pipeline
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import Shaper, PromptNode, PromptTemplate, PromptModel, EmbeddingRetriever
from haystack.nodes.retriever.web import WebRetriever


# Sample questions offered to the user; the set_q1..set_q5 callbacks below
# index into this list positionally, so the order matters.
QUERIES = [
    "Did SVB collapse?",
    "Why did SVB collapse?",
    "What does SVB failure mean for our economy?",
    "Who is responsible for SVB collapse?",  # was misspelled "SVC"
    "When did SVB collapse?",
]


@st.cache_resource(show_spinner=False)
def get_plain_pipeline():
    """Build the baseline pipeline that sends the query straight to the LLM.

    There is no retrieval step: a single PromptNode backed by the OpenAI
    ``text-davinci-003`` model fills the ``plain_llm`` template with the
    query. The OpenAI key is read from ``st.secrets["OPENAI_API_KEY"]``.

    Returns:
        A ``Pipeline`` with one node, ``prompt_node``, fed by ``Query``.
    """
    openai_model = PromptModel(
        model_name_or_path="text-davinci-003",
        api_key=st.secrets["OPENAI_API_KEY"],
    )
    question_template = PromptTemplate(
        name="plain_llm",
        prompt_text="Answer the following question: $query",
    )

    plain_pipeline = Pipeline()
    plain_pipeline.add_node(
        component=PromptNode(
            openai_model,
            default_prompt_template=question_template,
            max_length=300,
        ),
        name="prompt_node",
        inputs=["Query"],
    )
    return plain_pipeline


@st.cache_resource(show_spinner=False)
def get_retrieval_augmented_pipeline():
    """Build the pipeline that answers from documents in the local FAISS index.

    Flow: ``Query`` -> embedding retriever (``top_k=2``) -> Shaper running
    ``join_documents`` -> PromptNode using the ``question-answering``
    template. The index and its config are loaded from ``data/``; the OpenAI
    key is read from ``st.secrets["OPENAI_API_KEY"]``.

    Returns:
        A ``Pipeline`` with nodes ``retriever``, ``shaper`` and
        ``prompt_node`` chained in that order.
    """
    document_store = FAISSDocumentStore(
        faiss_index_path="data/my_faiss_index.faiss",
        faiss_config_path="data/my_faiss_index.json",
    )
    embedding_retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        model_format="sentence_transformers",
        top_k=2,
    )
    # The Shaper rewrites the retrieved ``documents`` via ``join_documents``
    # before they reach the prompt template.
    join_docs = Shaper(
        func="join_documents",
        inputs={"documents": "documents"},
        outputs=["documents"],
    )
    qa_template = PromptTemplate(
        name="question-answering",
        prompt_text=(
            "Given the context please answer the question. "
            "Context: $documents; Question: $query; Answer:"
        ),
    )
    answer_node = PromptNode(
        "text-davinci-003",
        default_prompt_template=qa_template,
        api_key=st.secrets["OPENAI_API_KEY"],
        max_length=500,
    )

    # (node name, component, upstream inputs), listed in wiring order.
    stages = (
        ("retriever", embedding_retriever, ["Query"]),
        ("shaper", join_docs, ["retriever"]),
        ("prompt_node", answer_node, ["shaper"]),
    )
    rag_pipeline = Pipeline()
    for stage_name, stage_component, upstream in stages:
        rag_pipeline.add_node(
            component=stage_component, name=stage_name, inputs=upstream
        )
    return rag_pipeline


@st.cache_resource(show_spinner=False)
def get_web_retrieval_augmented_pipeline():
    """Build the pipeline that answers from web search results.

    Flow: ``Query`` -> WebRetriever (``SerperDev`` provider) -> Shaper running
    ``join_documents`` -> PromptNode using the ``question-answering``
    template. Keys are read from ``st.secrets``: ``WEBRET_API_KEY`` for the
    search provider and ``OPENAI_API_KEY`` for the model.

    Returns:
        A ``Pipeline`` with nodes ``retriever``, ``shaper`` and
        ``prompt_node`` chained in that order.
    """
    web_search = WebRetriever(
        api_key=st.secrets["WEBRET_API_KEY"],
        search_engine_provider="SerperDev",
    )
    join_docs = Shaper(
        func="join_documents",
        inputs={"documents": "documents"},
        outputs=["documents"],
    )
    qa_template = PromptTemplate(
        name="question-answering",
        prompt_text=(
            "Given the context please answer the question. "
            "Context: $documents; Question: $query; Answer:"
        ),
    )
    answer_node = PromptNode(
        "text-davinci-003",
        default_prompt_template=qa_template,
        api_key=st.secrets["OPENAI_API_KEY"],
        max_length=500,
    )

    web_pipeline = Pipeline()
    # Each node is fed by the one added just before it, starting at ``Query``.
    previous = "Query"
    for node_name, component in (
        ("retriever", web_search),
        ("shaper", join_docs),
        ("prompt_node", answer_node),
    ):
        web_pipeline.add_node(component=component, name=node_name, inputs=[previous])
        previous = node_name
    return web_pipeline


# @st.cache_resource(show_spinner=False)
# def app_init():
#     print("Loading Pipelines...")
#     p1 = get_plain_pipeline()
#     print("Loaded Plain Pipeline")
#     p2 = get_retrieval_augmented_pipeline()
#     print("Loaded Retrieval Augmented Pipeline")
#     p3 = get_web_retrieval_augmented_pipeline()
#     print("Loaded Web Retrieval Augmented Pipeline")
#     return p1, p2, p3


# Seed the session's query with an empty string the first time the script
# runs, so later reads of st.session_state["query"] never hit a missing key.
if "query" not in st.session_state:
    st.session_state["query"] = ""


def set_question():
    """Copy the value stored under ``q_drop_down`` into the session's query.

    NOTE(review): presumably wired as a callback for a widget keyed
    ``q_drop_down``; that widget is not defined in this part of the file.
    """
    selected = st.session_state["q_drop_down"]
    st.session_state["query"] = selected


def set_q1():
    """Set the session's query to the first sample question, ``QUERIES[0]``."""
    sample_question = QUERIES[0]
    st.session_state["query"] = sample_question


def set_q2():
    """Set the session's query to the second sample question, ``QUERIES[1]``."""
    sample_question = QUERIES[1]
    st.session_state["query"] = sample_question


def set_q3():
    """Set the session's query to the third sample question, ``QUERIES[2]``."""
    sample_question = QUERIES[2]
    st.session_state["query"] = sample_question


def set_q4():
    """Set the session's query to the fourth sample question, ``QUERIES[3]``."""
    sample_question = QUERIES[3]
    st.session_state["query"] = sample_question


def set_q5():
    """Set the session's query to the fifth sample question, ``QUERIES[4]``."""
    sample_question = QUERIES[4]
    st.session_state["query"] = sample_question