import os

import streamlit as st

os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
# +
st.set_page_config(page_title="Protected Areas Database Chat", page_icon="🦆")
st.title("Protected Areas Database Chat")

st.markdown('''
This chatbot answers questions based only on the [PAD Technical How-Tos](https://www.protectedlands.net/pad-us-technical-how-tos/) and will decline questions outside that context.

Example queries:

- What is the difference between Fee and Easements?
- What do the gap status codes mean?
''')
# -

# Optional: enable LangSmith tracing.
# import getpass
# os.environ["LANGCHAIN_TRACING_V2"] = "true"
# os.environ["LANGCHAIN_API_KEY"] = getpass.getpass()
import bs4
from langchain import hub
from langchain_chroma import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
# +
# Set up the available LLMs; the sidebar choice below selects which one answers.
models = {
    "chatgpt3.5": ChatOpenAI(model="gpt-3.5-turbo", temperature=0, api_key=st.secrets["OPENAI_API_KEY"], streaming=True),
    # "chatgpt4": ChatOpenAI(model="gpt-4", temperature=0, api_key=st.secrets["OPENAI_API_KEY"]),
    "phi3": Ollama(model="phi3", temperature=0),
    "duckdb-nsql": Ollama(model="duckdb-nsql", temperature=0),
    "command-r-plus": Ollama(model="command-r-plus", temperature=0),
    "mistral": Ollama(model="mistral", temperature=0),
    "wizardlm2": Ollama(model="wizardlm2", temperature=0),
    "sqlcoder": Ollama(model="sqlcoder", temperature=0),
    "zephyr": Ollama(model="zephyr", temperature=0),
    "gemma": Ollama(model="gemma", temperature=0),
    "llama3": Ollama(model="llama3", temperature=0),
}
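# The Ollama-backed options assume the corresponding model has already been
# pulled locally, e.g.:
#   ollama pull llama3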
with st.sidebar:
    st.caption("Non-ChatGPT models assume you are running the app locally with ollama installed.")
    choice = st.radio("Select an LLM:", models)
    llm = models[choice]
# -

# Load, chunk, and index the contents of the PAD-US documentation pages.
loader = WebBaseLoader(
    web_paths=[
        "https://www.protectedlands.net/pad-us-technical-how-tos/",
        "https://www.protectedlands.net/uses-of-pad-us/",
        # "https://www.protectedlands.net/pad-us-data-structure-attributes/"
    ],
    # Keep only each page's main-body content; skip navigation, headers, and footers.
    bs_kwargs=dict(parse_only=bs4.SoupStrainer(class_="main-body")),
)
docs = loader.load()

text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=400)
splits = text_splitter.split_documents(docs)
vectorstore = Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings())
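# Note: as written, the pages are re-fetched and re-embedded on every Streamlit
# rerun. A minimal sketch of caching the index with Streamlit's st.cache_resource
# (the build_vectorstore name is hypothetical, not part of the original script):
#
# @st.cache_resource
# def build_vectorstore():
#     docs = loader.load()
#     splits = text_splitter.split_documents(docs)
#     return Chroma.from_documents(documents=splits, embedding=OpenAIEmbeddings())
#
# vectorstore = build_vectorstore()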
# Retrieve and generate using the relevant snippets of the documentation.
retriever = vectorstore.as_retriever()
prompt = hub.pull("rlm/rag-prompt")

def format_docs(docs):
    # Concatenate retrieved documents into a single context string for the prompt.
    return "\n\n".join(doc.page_content for doc in docs)
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

# Notebook-style smoke test; commented out so it doesn't fire on every Streamlit rerun.
# rag_chain.invoke("What is the difference between Fee and Easement?")
# +
from langchain_core.runnables import RunnableParallel

# Variant chain that returns the retrieved source documents alongside the answer.
rag_chain_from_docs = (
    RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
    | prompt
    | llm
    | StrOutputParser()
)

rag_chain_with_source = RunnableParallel(
    {"context": retriever, "question": RunnablePassthrough()}
).assign(answer=rag_chain_from_docs)

# rag_chain_with_source.invoke("What is the difference between Fee and Easement?")
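# The with-source chain returns a dict assembled from the keys above; a
# hypothetical usage sketch (not part of the original app):
#
# result = rag_chain_with_source.invoke("What do the gap status codes mean?")
# result["question"]  # the original query
# result["context"]   # the retrieved Documents
# result["answer"]    # the generated answer string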
# +
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain

# Set up memory for contextual conversation.
msgs = StreamlitChatMessageHistory()
memory = ConversationBufferMemory(memory_key="chat_history", chat_memory=msgs, return_messages=True)
# qa_chain = ConversationalRetrievalChain.from_llm(llm, retriever=retriever, memory=memory, verbose=True)
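# Note: rag_chain above does not consume this memory; the commented-out qa_chain
# shows one way to wire chat history into retrieval. A rough sketch of using it
# in place of rag_chain in the handler below (assuming qa_chain is uncommented):
#
# result = qa_chain({"question": user_query})
# response = result["answer"]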
if len(msgs.messages) == 0 or st.sidebar.button("Clear message history"):
    msgs.clear()
    msgs.add_ai_message("How can I help you?")

avatars = {"human": "user", "ai": "assistant"}
for msg in msgs.messages:
    st.chat_message(avatars[msg.type]).write(msg.content)
if user_query := st.chat_input(placeholder="Ask me about PAD!"):
    st.chat_message("user").write(user_query)
    msgs.add_user_message(user_query)
    with st.chat_message("assistant"):
        response = rag_chain.invoke(user_query)
        st.write(response)
        msgs.add_ai_message(response)
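# Since the ChatGPT option is constructed with streaming=True, the answer could
# also be streamed token by token. A minimal sketch using the chain's .stream()
# method with st.write_stream (Streamlit >= 1.31), otherwise equivalent:
#
# with st.chat_message("assistant"):
#     response = st.write_stream(rag_chain.stream(user_query))
#     msgs.add_ai_message(response)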