File size: 6,175 Bytes
769af53
a140209
77e791d
769af53
77e791d
 
 
 
 
 
 
f35434e
a140209
e5f6996
 
 
77e791d
8e00393
 
77e791d
e5f6996
e16892f
77e791d
e5f6996
 
 
 
77e791d
 
dde02d9
f35434e
1994eb7
36dcf79
e5f6996
f35434e
 
1994eb7
 
 
 
 
77e791d
 
 
e5f6996
 
 
 
 
 
 
 
 
 
 
 
77e791d
 
 
36dcf79
77e791d
 
 
 
 
 
 
 
 
 
 
 
 
 
f35434e
 
 
 
 
 
 
 
 
 
 
77e791d
 
 
 
 
f35434e
 
 
 
dde02d9
e16892f
a140209
2990c41
dde02d9
2990c41
dde02d9
f35434e
dde02d9
 
 
1994eb7
a7d95d1
1994eb7
a7d95d1
1994eb7
dde02d9
 
 
 
a140209
77e791d
a140209
 
 
 
dde02d9
2990c41
dde02d9
 
 
f35434e
 
e5f6996
 
 
 
dde02d9
 
 
f35434e
dde02d9
a140209
 
e16892f
 
f35434e
 
e5f6996
 
 
 
dde02d9
 
 
f35434e
 
 
 
 
dde02d9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import streamlit as st
from openai import OpenAI
import glob

from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.callbacks import get_openai_callback
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel

from documents import documents

# Corpus that gets embedded into the vector store; supplied by the local
# `documents` module.
docs = documents

# Read the OpenAI key from Streamlit's secrets store.
# NOTE(review): the constant is not referenced again in the visible code --
# presumably the LangChain OpenAI clients obtain the key from the environment;
# confirm before removing this lookup.
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]

# Get all the filenames from the docs folder
# files = glob.glob("./docs/*.txt")

# Load files into readable documents
# docs = []
# for file in files:
#     loader = UnstructuredFileLoader(file)
#     docs.append(loader.load()[0])

# Config
# Sidebar controls for the retrieval/generation settings. The Chroma vector
# store is built from `docs` inside this block, using the values chosen here.
with st.sidebar:
    model = st.selectbox("Model name", ["gpt-3.5-turbo"], disabled=True)
    # `placeholder` is documented as a string, so the hints are passed as text.
    temperature = st.number_input(
        "Temperature",
        value=0.0,
        min_value=0.0,
        step=0.2,
        max_value=1.0,
        placeholder="0.0",
    )
    k = st.number_input(
        "Number of documents to include",
        value=1,
        min_value=1,
        step=1,
        placeholder="1",
    )
    if st.toggle("Splitting", value=True):
        # Chunk size and overlap are measured in characters (the splitter's
        # default length function is `len`), not tokens. The retrieved chunks
        # are what ends up in the prompt context, so the chunk size also
        # bounds how much text each retrieved document contributes.
        chunk_size = st.number_input(
            "Chunk size", value=750, min_value=1, step=250, placeholder="750"
        )
        chunk_overlap = st.number_input(
            "Chunk overlap", value=0, min_value=0, step=10, placeholder="0"
        )
        # The splitter rejects an overlap larger than the chunk size; clamp it
        # and tell the user instead of crashing the whole page.
        if chunk_overlap > chunk_size:
            st.warning("Chunk overlap cannot exceed chunk size; using the chunk size instead.")
            chunk_overlap = chunk_size
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        splits = text_splitter.split_documents(docs)
        store_documents = splits
    else:
        # Splitting disabled: embed the documents as they are.
        store_documents = docs
    vectorstore = Chroma.from_documents(documents=store_documents, embedding=OpenAIEmbeddings())


# Instruction prompt (written in German) for the assistant persona "Volker".
# It limits the assistant to Long-/Post-Covid questions, names diseases it
# must decline to answer about, forbids diagnosis/treatment advice, asks for
# short German answers, and tells the model to rely only on the retrieved
# context. `{question}` and `{context}` are the template placeholders that
# get filled in when the prompt is rendered.
prompt_template ="""
Du heißt "Volker". Du bist ein Assistent für die Beantwortung von Fragen zu Long-Covid (Post-Covid synonym verwenden). 
Du weißt nichts über Krankheiten wie 'tumorbedingte Fatigue', 'Multiple Sklerose', 'Hashimoto-Thyreoditis' oder 'Krebs'. 
Werden Fragen zu diesen Erkrankungen gestellt, beantworte sie mit "Dazu fehlen mir Informationen".
Du gibst keine Ratschläge zur Diagnose, Behandlung oder Therapie. 
Wenn du die Antwort nicht weißt, sag einfach, dass du es nicht weißt.
Antworte immer in ganzen Sätzen und verwende korrekte Grammatik und Rechtschreibung. Antworte nur auf Deutsch. 
Antworte kurz mit maximal fünf Sätzen außer es wird von dir eine ausführlichere Antwort verlangt.
Verwende zur Beantwortung der Frage nur den retriever Kontext.
 
Frage: {question} 
Kontext: {context} 
Antwort:
""" # Source: hub.pull("rlm/rag-prompt")

# (1) Retriever
# Uses the `similarity_score_threshold` search type with a score threshold of
# 0.3; `k` comes from the sidebar setting.
retriever = vectorstore.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={
        "score_threshold": 0.3,
        "k": k,
    },
)

# (2) Prompt
# Turn the raw template string into a chat prompt with `question` and
# `context` inputs.
prompt = ChatPromptTemplate.from_template(prompt_template)

# (3) LLM
# Chat model and temperature are the values selected in the sidebar (the
# default is "gpt-3.5-turbo" at temperature 0). The temperature ranges from
# 0 to 1: higher values (e.g. 0.8) give more random answers, lower values
# (e.g. 0.2) keep the answer more focused on the retrieved content.
llm = ChatOpenAI(
    model_name=model,
    temperature=temperature,
)

def format_docs(docs):
    """Concatenate the text of several documents into one string.

    Args:
        docs: Iterable of objects exposing a ``page_content`` attribute.

    Returns:
        The ``page_content`` values in input order, separated by one blank
        line. An empty input yields an empty string.
    """
    separator = "\n\n"
    contents = [document.page_content for document in docs]
    return separator.join(contents)


# rag_chain = (
#     {"context": retriever | format_docs, "question": RunnablePassthrough()}
#     | prompt
#     | llm
#     | StrOutputParser()
# )



def _context_as_text(inputs):
    """Flatten the retrieved documents under ``inputs["context"]`` into text."""
    return format_docs(inputs["context"])


# Answer-generation branch: replace the document list under "context" with
# plain text, render the prompt, call the model, and keep only the string
# output.
_with_text_context = RunnablePassthrough.assign(context=_context_as_text)
rag_chain_from_docs = _with_text_context | prompt | llm | StrOutputParser()

# Full chain: the input is sent to the retriever ("context") and passed
# through unchanged ("question"); the generated text is then added under
# "answer", so the result carries "context", "question" and "answer".
_retrieve_and_forward = RunnableParallel(
    {"context": retriever, "question": RunnablePassthrough()}
)
rag_chain = _retrieve_and_forward.assign(answer=rag_chain_from_docs)


# Page heading rendered above the example questions and the chat.
st.title(body="🐔 Volker-Chat")

def click_button(prompt):
    """Button callback that records which example question was chosen.

    Stores the question text under ``"prompt"`` and sets the ``"clicked"``
    flag in ``st.session_state``.

    Args:
        prompt: The question text associated with the clicked button.
    """
    st.session_state["prompt"] = prompt
    st.session_state["clicked"] = True


# Example questions, laid out in three columns. Each entry is
# (target column, button label, question handed to the callback).
c = st.container()
c.write("Beispielfragen")
col1, col2, col3 = c.columns(3)
example_buttons = (
    (col1, "Mehr zu 'Lernen'", "Was macht die Säule 'Lernen' aus?"),
    (col1, "Was macht die Fimo Health App?", "Was macht die Fimo Health App?"),
    (col2, "Mehr zu 'Tracken'", "Was macht die Säule 'Tracken' aus?"),
    (col2, "Was ist Pacing?", "Was ist Pacing?"),
    (col3, "Mehr zu 'Handeln'", "Was macht die Säule 'Handeln' aus?"),
)
for target_column, button_label, example_question in example_buttons:
    target_column.button(button_label, on_click=click_button, args=[example_question])

# Seed the session state the first time the script runs in a session.
if "clicked" not in st.session_state:
    st.session_state["clicked"] = False

if "messages" not in st.session_state:
    st.session_state["messages"] = [{"role": "assistant", "content": "Ahoi! Ich bin Volker. Wie kann ich dir helfen?"}]

# Replay every stored message as a chat bubble.
for entry in st.session_state["messages"]:
    with st.chat_message(entry["role"]):
        st.write(entry["content"])

def _answer_question(question):
    """Run the RAG chain for ``question`` and render the exchange.

    Shows the question and the generated answer as chat messages, appends
    both to ``st.session_state["messages"]`` so they are replayed on later
    reruns, lists the retrieved context in an expander, and shows the OpenAI
    usage summary collected by the callback in the sidebar.

    Args:
        question: The user's question as plain text.
    """
    st.chat_message("user").write(question)
    st.session_state["messages"].append({"role": "user", "content": question})

    # Only the chain invocation needs to run under the usage-tracking callback.
    with get_openai_callback() as cb:
        response = rag_chain.invoke(question)

    answer = response["answer"]
    st.chat_message("assistant").write(answer)
    st.session_state["messages"].append({"role": "assistant", "content": answer})

    with st.expander("Kontext ansehen"):
        for citation in response["context"]:
            st.write("[...] ", str(citation.page_content), " [...]")
            # Not every document is guaranteed to carry a "source" entry.
            st.write(str(citation.metadata.get("source", "Quelle unbekannt")))
            st.write("---" * 20)

    with st.sidebar:
        st.container().success(str(cb))


# An example button was clicked: its callback stored the question in the
# session state.
if st.session_state.clicked:
    _answer_question(st.session_state["prompt"])

# Free-text question from the chat box. A dedicated name is used so the
# module-level `prompt` (the chat prompt template) is not overwritten.
if user_question := st.chat_input():
    _answer_question(user_question)


# cleanup
# Clear the example-button flag so the stored question is not answered again
# on the next rerun, then delete the vector store's collection.
st.session_state["clicked"] = False
vectorstore.delete_collection()