import os
import glob

import streamlit as st
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.callbacks import get_openai_callback
# Expose the key from Streamlit secrets as an environment variable, which is
# where langchain_openai (ChatOpenAI, OpenAIEmbeddings) looks for it.
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
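# For local runs the key is expected in .streamlit/secrets.toml; a minimal
# layout (the key value is a placeholder):
#   OPENAI_API_KEY = "sk-..."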
# Get all the filenames from the docs folder
files = glob.glob("./docs/*.txt")

# Load files into readable documents
docs = []
for file in files:
    loader = UnstructuredFileLoader(file)
    docs.append(loader.load()[0])
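# Each loader.load() call returns a list of Document objects; for a plain-text
# file loaded in the default single mode it yields one Document, so we keep
# the first entry.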
# Config
with st.sidebar:
    file_list = "\n".join(files)
    st.write(f"Injected documents: {file_list}")
    model = st.selectbox("Model name", ["gpt-3.5-turbo"], disabled=True)
    temperature = st.number_input("Temperature", value=0.0, min_value=0.0, step=0.2, max_value=1.0)
    k = st.number_input("Number of documents to include", value=1, min_value=1, step=1)
    if st.toggle("Splitting", value=False):
        # Chunk size (in characters) into which the files are split; it also
        # bounds how much text each retrieved chunk feeds into the context.
        chunk_size = st.number_input("Chunk size", value=750, step=250)
        chunk_overlap = st.number_input("Chunk overlap", value=0, step=10)
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        splits = text_splitter.split_documents(docs)
        vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings())
    else:
        vectorstore = Chroma.from_documents(documents=docs, embedding=OpenAIEmbeddings())
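# Note: without a persist_directory, Chroma.from_documents builds an ephemeral
# in-memory collection, so the documents are re-embedded on every Streamlit
# rerun; see the delete_collection() cleanup at the bottom of the script.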
prompt_template = """
You are called "Volker". You are an assistant for question-answering tasks.
You only answer questions about Long-Covid (use Post-Covid synonymously) and the Volker-App.
If you don't know the answer, just say that you don't know. Say why you don't know the answer.
Never answer questions about other diseases (e.g. cancer-related fatigue, multiple sclerosis).
Always answer in German. Stay empathetic and positive.
When you use a word such as "Arzt" or "Ärzt", always write it as "Arzt".
Only use the following pieces of retrieved context to answer the question.
Question: {question}
Context: {context}
Answer:
"""  # Source: hub.pull("rlm/rag-prompt")
# (1) Retriever
retriever = vectorstore.as_retriever(search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.3, "k": k})
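# "similarity_score_threshold" drops hits whose relevance score falls below
# 0.3 and returns at most k documents, so low-similarity chunks never reach
# the prompt.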
# (2) Prompt
prompt = ChatPromptTemplate.from_template(prompt_template)

# (3) LLM
# Define the LLM we want to use. Default is "gpt-3.5-turbo" with temperature 0.
# Temperature is a number between 0 and 1: higher values (e.g. 0.8) give more
# random answers, lower values (e.g. 0.2) keep the model focused on the
# retrieved content, and 0 makes the output close to deterministic.
llm = ChatOpenAI(model=model, temperature=temperature)
def format_docs(docs):
    # Join the page contents of the retrieved Documents into one context string.
    return "\n\n".join(doc.page_content for doc in docs)
# rag_chain = (
#     {"context": retriever | format_docs, "question": RunnablePassthrough()}
#     | prompt
#     | llm
#     | StrOutputParser()
# )
rag_chain_from_docs = (
    RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
    | prompt
    | llm
    | StrOutputParser()
)
rag_chain = RunnableParallel(
    {"context": retriever, "question": RunnablePassthrough()}
).assign(answer=rag_chain_from_docs)
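# Invoking rag_chain yields a dict: "context" holds the raw retrieved
# Documents (shown in the expander below), "question" the user input, and
# "answer" the parsed LLM response. Sketch:
#   response = rag_chain.invoke("Was macht die Volker-App?")
#   response["answer"]   # -> str
#   response["context"]  # -> list of Documents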
st.title("🐔 Volker-Chat")

def click_button(prompt):
    # Remember that an example button was clicked and stash its question.
    st.session_state.clicked = True
    st.session_state['prompt'] = prompt
c = st.container()
c.write("Beispielfragen")
col1, col2, col3 = c.columns(3)
col1.button("Mehr zu 'Lernen'", on_click=click_button, args=["Was macht die Säule 'Lernen' aus?"])
col1.button("Was macht die Volker-App?", on_click=click_button, args=["Was macht die Volker-App?"])
col2.button("Mehr zu 'Tracken'", on_click=click_button, args=["Was macht die Säule 'Tracken' aus?"])
col2.button("Welche Krankenkassen erstatten die App?", on_click=click_button, args=["Welche Krankenkassen erstatten die App?"])
col3.button("Mehr zu 'Handeln'", on_click=click_button, args=["Was macht die Säule 'Handeln' aus?"])
if 'clicked' not in st.session_state:
    st.session_state.clicked = False

if "messages" not in st.session_state:
    st.session_state["messages"] = [{"role": "assistant", "content": "Ahoi! Ich bin Volker. Wie kann ich dir helfen?"}]

for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])
def handle_prompt(prompt):
    # Shared handler for example-button clicks and free-form chat input:
    # run the RAG chain, show the answer and the retrieved context, and
    # report token usage / cost from the OpenAI callback in the sidebar.
    st.chat_message("user").write(prompt)
    with get_openai_callback() as cb:
        response = rag_chain.invoke(prompt)
    st.chat_message("assistant").write(response['answer'])
    with st.expander("Kontext ansehen"):
        st.write(response["context"])
    with st.sidebar:
        sidebar_c = st.container()
        sidebar_c.success(cb)

if st.session_state.clicked:
    handle_prompt(st.session_state['prompt'])

if prompt := st.chat_input():
    handle_prompt(prompt)

# cleanup: reset the button flag and drop the ephemeral Chroma collection
# so the next rerun starts from a fresh index
st.session_state.clicked = False
vectorstore.delete_collection()