# Zeta / app.py
# (Hugging Face Space page header preserved from extraction:
#  Ritvik19's picture · Upload 2 files · 60e8923 verified ·
#  raw · history · blame · 2.84 kB)
import os
from pathlib import Path
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.llms.openai import OpenAIChat
from langchain.document_loaders import PyPDFLoader, WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
import streamlit as st
# Directory (sibling of this file) where the Chroma vector store is persisted.
LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent.joinpath("vector_store")
def load_documents():
    """Fetch and load every source document listed in session state.

    URLs ending in ``.pdf`` are read with ``PyPDFLoader``; anything else is
    treated as a web page and read with ``WebBaseLoader``. Returns the
    concatenated list of loaded documents.
    """
    docs = []
    for url in st.session_state.source_doc_urls:
        loader = PyPDFLoader(url) if url.endswith(".pdf") else WebBaseLoader(url)
        docs.extend(loader.load())
    return docs
def split_documents(documents):
    """Chunk *documents* into overlapping text pieces suitable for embedding."""
    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=256)
    return splitter.split_documents(documents)
def embeddings_on_local_vectordb(texts):
    """Embed *texts* into a persistent on-disk Chroma store.

    Returns a retriever configured to fetch the top 3 matching chunks.
    """
    store = Chroma.from_documents(
        texts,
        embedding=OpenAIEmbeddings(),
        persist_directory=LOCAL_VECTOR_STORE_DIR.as_posix(),
    )
    store.persist()
    return store.as_retriever(search_kwargs={"k": 3})
def query_llm(retriever, query):
    """Answer *query* against *retriever* and record the exchange.

    Runs a ConversationalRetrievalChain with the session's chat history,
    appends the (query, answer) pair to ``st.session_state.messages``,
    and returns the answer text.
    """
    chain = ConversationalRetrievalChain.from_llm(
        llm=OpenAIChat(),
        retriever=retriever,
        return_source_documents=True,
    )
    outcome = chain({"question": query, "chat_history": st.session_state.messages})
    answer = outcome["answer"]
    st.session_state.messages.append((query, answer))
    return answer
def input_fields():
    """Render the sidebar inputs: OpenAI API key and source-document URLs.

    SECURITY FIX: the original hard-coded a live OpenAI API key into source
    control. The key is now read from the environment and can be overridden
    via a masked sidebar input; it is never stored in code.

    Side effects: may set ``os.environ["OPENAI_API_KEY"]`` and always sets
    ``st.session_state.source_doc_urls`` (empty entries are dropped so a
    blank field no longer yields ``[""]``).
    """
    api_key = st.sidebar.text_input(
        "OpenAI API Key",
        type="password",
        value=os.environ.get("OPENAI_API_KEY", ""),
    )
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key
    st.session_state.source_doc_urls = [
        url.strip()
        for url in st.sidebar.text_input("Source Document URLs").split(",")
        if url.strip()  # "".split(",") == [""] — filter out empties
    ]
def process_documents():
    """Load, split, embed, and index the configured source documents.

    Stores the resulting retriever in ``st.session_state.retriever``.
    Robustness fix: warns and returns early when no URLs were provided
    instead of attempting to load nothing. Any other failure is surfaced
    in the UI (boundary-level handler) rather than crashing the app.
    """
    if not st.session_state.get("source_doc_urls"):
        st.warning("Please provide at least one source document URL.")
        return
    try:
        documents = load_documents()
        texts = split_documents(documents)
        st.session_state.retriever = embeddings_on_local_vectordb(texts)
    except Exception as e:
        # UI boundary: report the error to the user instead of raising.
        st.error(f"An error occurred: {e}")
def boot():
    """Main entry point: render the chat UI and wire up document processing.

    Bug fix: the original raised ``AttributeError`` when a question was
    asked before any documents were submitted (``st.session_state.retriever``
    unset); we now guard and show a friendly error instead.
    """
    st.title("Enigma Chatbot")
    input_fields()
    st.sidebar.button("Submit Documents", on_click=process_documents)
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # Replay the conversation so far.
    for human_msg, ai_msg in st.session_state.messages:
        st.chat_message("human").write(human_msg)
        st.chat_message("ai").write(ai_msg)
    if query := st.chat_input():
        st.chat_message("human").write(query)
        if "retriever" not in st.session_state:
            st.error("Please submit documents before asking questions.")
            return
        response = query_llm(st.session_state.retriever, query)
        st.chat_message("ai").write(response)
# Script entry point: launch the Streamlit app.
if __name__ == "__main__":
    boot()