File size: 10,557 Bytes
769af53
a140209
77e791d
613df31
d5aceae
77e791d
 
 
 
982eb2a
77e791d
 
f35434e
dd156bd
 
 
982eb2a
77e791d
982eb2a
 
b3d4a61
8e00393
 
065afe6
dd156bd
288ca36
d5aceae
dd156bd
 
 
065afe6
 
 
 
 
 
dd156bd
982eb2a
065afe6
 
 
 
 
 
dd156bd
dde02d9
d5aceae
 
 
 
dd156bd
d5aceae
 
dd156bd
d5aceae
 
 
982eb2a
 
 
 
 
 
 
 
 
 
1994eb7
b2f3ea1
d5aceae
 
 
dd156bd
1994eb7
b2f3ea1
dd156bd
77e791d
 
b3d4a61
 
 
 
 
d1ecef9
b3d4a61
 
 
 
 
 
 
 
065afe6
 
b3d4a61
 
 
 
 
 
 
 
 
 
 
 
 
982eb2a
b3d4a61
 
 
 
 
 
 
 
77e791d
 
b3d4a61
77e791d
 
 
 
 
 
 
 
 
 
 
 
 
f35434e
 
77e791d
 
 
 
 
f35434e
 
 
 
e16892f
a140209
2990c41
dde02d9
2990c41
dde02d9
f35434e
dde02d9
 
dd156bd
b2f3ea1
 
 
 
 
dd156bd
 
b2f3ea1
dd156bd
 
dde02d9
 
 
 
a140209
77e791d
a140209
 
 
 
613df31
 
 
 
 
 
dde02d9
2990c41
dde02d9
 
 
b3d4a61
 
 
 
 
 
 
f35434e
982eb2a
 
 
 
 
 
25a1d6b
d1ecef9
 
 
982eb2a
dde02d9
 
 
f35434e
dde02d9
b3d4a61
a140209
 
e16892f
 
b3d4a61
 
 
 
 
 
f35434e
982eb2a
 
 
 
 
 
25a1d6b
d1ecef9
 
 
982eb2a
dde02d9
 
 
f35434e
 
 
 
 
dde02d9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
import streamlit as st
from openai import OpenAI
import glob
import time
import pickle
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter, CharacterTextSplitter
from langchain.callbacks import get_openai_callback
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel
from langchain import VectorDBQAWithSourcesChain
from langchain.chains import RetrievalQA
import json
from documents import read_documents_from_file, create_documents, store_documents, create_faq_documents, html_to_chunks

#store_documents(html_to_chunks(), path="./docs/langchain_semantic_documents.json")
#store_documents(create_documents())
#create_faq_documents()
OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]

    
#vectorstore = Chroma(persist_directory=directory, embedding_function=OpenAIEmbeddings())
st.set_page_config(initial_sidebar_state="collapsed")

data_source = st.radio("Data source", options=['FAQ', 'Blog articles'])
if data_source == 'FAQ':
    docs=read_documents_from_file("./docs/faq_docs.json")
    def_model = "gpt-3.5-turbo"
    def_temperature = 0.0
    def_k = 1
    def_chunk_size = 500
    def_chunk_overlap = 0
    directory = "./chroma_db"
elif data_source == 'Blog articles':
    docs=read_documents_from_file("./docs/langchain_semantic_documents.json")
    def_model = "gpt-3.5-turbo"
    def_temperature = 0.0
    def_k = 3
    def_chunk_size = 500
    def_chunk_overlap = 0
    directory = "./chroma_db"

with st.sidebar:
    if st.toggle("Experimental"):
        disabled = False
    else:
        disabled = True
    model = def_model
    temperature = st.number_input("Temperature", value=def_temperature, min_value=0.0, step=0.2, max_value=1.0, placeholder=def_temperature, disabled=disabled)
    k = st.number_input("Number of documents to include", value=def_k, min_value=1, step=1, placeholder=def_k, disabled=disabled)

    if st.toggle("Splitting", value=True, disabled=disabled):
        chunk_size = st.number_input("Chunk size", value=def_chunk_size, step=250, placeholder=def_chunk_size, disabled=disabled) # Defines the chunks in amount of tokens in which the files are split. Also defines the amount of tokens that are feeded into the context. 
        chunk_overlap = st.number_input("Chunk overlap", value=def_chunk_overlap, step=10, placeholder=def_chunk_overlap, disabled=disabled)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, 
            chunk_overlap=chunk_overlap, 
            separators=[
                "\n\n",
                "\n",
                " ",
                ". "
                ]
            )
        splits = text_splitter.split_documents(docs)
        vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings())
        if chunk_size != def_chunk_size | chunk_overlap != def_chunk_overlap:
            splits = text_splitter.split_documents(docs)
            vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings())
            print("Created new vectordb for this session.")
    else:
        vectorstore = Chroma.from_documents(documents=docs, embedding=OpenAIEmbeddings())
        print("Used vectordb with all blog articles.")


if data_source == 'FAQ':
    prompt_template ="""
    Du bist Volker. Du beantwortest häufig gestellte Fragen. 
    Beantworte nur die Fragen zu denen du einen Kontext hast. Wenn du die Antwort nicht weißt, sag dass du es nicht weißt.
    Antworte immer in ganzen Sätzen und verwende korrekte Grammatik und Rechtschreibung. Antworte nur auf Deutsch. 
    Antworte kurz und nur im Sinne des vorhandenen Kontext.
    
    Frage: {question} 
    Kontext: {context} 
    Antwort:
    """
else:
    prompt_template ="""
    ###
    Wenn du gefragt wirst wer du bist: Du heißt "Volker". Du bist ein Assistent für die Beantwortung von Fragen zu Long-Covid (Post-Covid synonym verwenden).
    Deine Quellen sind Blogartikel von Fimo Health. 

    ###
    Du weißt nichts über Krankheiten wie 'tumorbedingte Fatigue', 'Multiple Sklerose', 'Hashimoto-Thyreoditis' oder 'Krebs'. 
    Werden Fragen zu diesen Erkrankungen gestellt, beantworte sie mit "Dazu fehlen mir Informationen".

    ###
    Du beantwortest keine Fragen zu 'Tod', 'Suizid', 'Selbstverletzung', oder anderen potenziell schädigenden Themen.
    Werden Fragen zum 'Tod' gestellt, verweise auf den behandelnden Arzt.
    Bei Fragen zu Suizid verweise auf die Telefonseelsorge: 0800 1110111

    ###
    Du gibst keine Ratschläge zur Diagnose, Behandlung oder Therapie. 
    Wenn du die Antwort nicht weißt oder du keinen Kontext hast, sage dass du es nicht weißt.
    Wenn du allgemeine unspezifische Fragen gestellt bekommst, antworte, dass du die Frage nicht verstehst und frage nach einer präziseren Fragestellung.
    Antworte immer in ganzen Sätzen und verwende korrekte Grammatik und Rechtschreibung. Antworte nur auf Deutsch. 
    Antworte kurz mit maximal fünf Sätzen außer es wird von dir eine ausführlichere Antwort verlangt.
    Verwende zur Beantwortung der Frage nur den vorhandenen Kontext.
    
    Frage: {question} 
    Kontext: {context} 
    Antwort:
    """ # Source: hub.pull("rlm/rag-prompt")

# (1) Retriever
retriever = vectorstore.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.75, "k": k})

# (2) Prompt
prompt = ChatPromptTemplate.from_template(prompt_template)

# (3) LLM
# Define the LLM we want to use. Default is "gpt-3.5-turbo" with temperature 0. 
# Temperature is a number between 0 and 1. With 0.8 it generates more random answers, with 0.2 it is more focused on the retrieved content. With temperature = 0 it uses log-probabilities depending on the content.

llm = ChatOpenAI(model_name=model, temperature=temperature)

def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

rag_chain_from_docs = (
    RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
    | prompt
    | llm
    | StrOutputParser()
)

rag_chain = RunnableParallel(
    {"context": retriever, "question": RunnablePassthrough()}
).assign(answer=rag_chain_from_docs)

st.title("🐔 Volker-Chat")

def click_button(prompt):
    st.session_state.clicked = True
    st.session_state['prompt'] = prompt


c = st.container()
c.write("Beispielfragen")
if data_source == 'Blog articles':
    examples = ['Was ist Pacing?', 'Wie funktioniert die Wiedereingliederung?', 'Sollte ich eine Reha machen?']
    for i, col in enumerate(c.columns(len(examples))):
        question = examples[i]
        col.button(question, on_click=click_button, args=[question])

elif data_source == 'FAQ':
    examples = ['Wie komme ich an meinen PDF-Report?', 'Wer steckt hinter den Kurs-Inhalten?', 'Wozu dient der Check-Out?']
    for i, col in enumerate(c.columns(len(examples))):
        question = examples[i]
        col.button(question, on_click=click_button, args=[question])

if 'clicked' not in st.session_state:
    st.session_state.clicked = False

if "messages" not in st.session_state:
    st.session_state["messages"] = [{"role": "assistant", "content": "Ahoi! Ich bin Volker. Wie kann ich dir helfen?"}]

for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])

# Streamed response emulator
def response_generator(response):
    for word in response.split():
        yield word + " "
        time.sleep(0.05)

if st.session_state.clicked:
    prompt = st.session_state['prompt']
    st.chat_message("user").write(prompt)
    with get_openai_callback() as cb:
        response = rag_chain.invoke(prompt)
        print(response)
        if response['context'] != []:
            response_stream = response_generator(response['answer'])
            st.chat_message("assistant").write_stream(response_stream)
        else:
            response_stream = response_generator("Dazu kann ich dir leider keine Antwort geben. Bitte versuche eine andere Frage.")
            st.chat_message("assistant").write_stream(response_stream)
        with st.expander("Kontext ansehen"):
            if len(response['context'][0].page_content) > 50:
                for i, citation in enumerate(response["context"]):
                    print(citation.metadata)
                    st.write(f"[{i+1}] ", str(citation.page_content))
                    st.write(str(citation.metadata['source']))
                    section = ""
                    if (len(list(citation.metadata.values())) > 1) & (data_source=='Blog articles'):
                        for chapter in list(citation.metadata.values())[:-1]:
                            section += f"{chapter} "
                        st.write(f"Abschnitt: '{section}'")
                    st.write(str("---")*20)
        with st.sidebar:
            sidebar_c = st.container()
            sidebar_c.success(cb)



if prompt := st.chat_input():
    st.chat_message("user").write(prompt)
    with get_openai_callback() as cb:
        response = rag_chain.invoke(prompt)
        if response['context'] != []:
            response_stream = response_generator(response['answer'])
            st.chat_message("assistant").write_stream(response_stream)
        else:
            response_stream = response_generator("Dazu kann ich dir leider keine Antwort geben. Bitte versuche eine andere Frage.")
            st.chat_message("assistant").write_stream(response_stream)
        with st.expander("Kontext ansehen"):
            if len(response['context'][0].page_content) > 50:
                for i, citation in enumerate(response["context"]):
                    print(citation.metadata)
                    st.write(f"[{i+1}] ", str(citation.page_content))
                    st.write(str(citation.metadata['source']))
                    section = ""
                    if (len(list(citation.metadata.values())) > 1) & (data_source=='Blog articles'):
                        for chapter in list(citation.metadata.values())[:-1]:
                            section += f"{chapter} "
                        st.write(f"Abschnitt: '{section}'")
                    st.write(str("---")*20)
        with st.sidebar:
            sidebar_c = st.container()
            sidebar_c.success(cb)


# cleanup
st.session_state.clicked = False
vectorstore.delete_collection()