Create app.py
app.py
ADDED
import os
import streamlit as st
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.text_splitter import CharacterTextSplitter
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain, ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain.document_loaders import PyPDFLoader
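
# NOTE: These import paths follow the legacy LangChain layout (pre-0.1);
# newer releases moved them into langchain-community / langchain-openai,
# so the pinned LangChain version matters for this file.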

def create_sidebar():
    with st.sidebar:
        st.title("PDF Chat")
        st.markdown("### Quick Demo of RAG")

        api_key = st.text_input("OpenAI API Key:", type="password")

        st.markdown("""
        ### Tools Used
        • OpenAI
        • LangChain
        • ChromaDB

        ### Steps
        1. Add API key
        2. Upload PDF
        3. Chat!
        """)

        return api_key

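# PyPDFLoader reads from a filesystem path, so each upload is first written
# to ./uploads/ and deleted again once its text has been extracted.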
def save_uploaded_file(uploaded_file, path='./uploads/'):
    os.makedirs(path, exist_ok=True)
    file_path = os.path.join(path, uploaded_file.name)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return file_path

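# st.cache_data memoizes the split chunks per set of uploads, so the PDFs
# are not re-parsed on every Streamlit rerun (each chat message causes one).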
@st.cache_data
def load_texts_from_papers(papers):
    all_texts = []
    for paper in papers:
        file_path = save_uploaded_file(paper)
        loader = PyPDFLoader(file_path)
        documents = loader.load()
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        texts = text_splitter.split_documents(documents)
        all_texts.extend(texts)
        os.remove(file_path)
    return all_texts

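# st.cache_resource keeps a single Chroma instance alive across reruns, and
# persist_directory="db" also writes embeddings to disk. Note that main()
# calls add_documents on every rerun, so the same chunks can be inserted
# more than once over the course of a session.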
@st.cache_resource
def initialize_vectorstore():
    embedding = OpenAIEmbeddings(openai_api_key=st.session_state.api_key)
    vectorstore = Chroma(embedding_function=embedding, persist_directory="db")
    return vectorstore

def main():
    st.set_page_config(page_title="PDF Chat", layout="wide")

    # Get API key from sidebar
    api_key = create_sidebar()

    if api_key:
        st.session_state.api_key = api_key

    st.title("Chat with PDF")
    papers = st.file_uploader("Upload PDFs", type=["pdf"], accept_multiple_files=True)

    if "messages" not in st.session_state:
        st.session_state.messages = []

    if not api_key:
        st.warning("Please enter your OpenAI API key")
        return

    try:
        vectorstore = initialize_vectorstore()
        texts = load_texts_from_papers(papers) if papers else []

        if texts:
            vectorstore.add_documents(texts)
            qa_chain = ConversationalRetrievalChain.from_llm(
                ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo"),
                vectorstore.as_retriever(),
                memory=ConversationBufferMemory(
                    memory_key="chat_history",
                    return_messages=True
                )
            )
        else:
            # ConversationChain's default prompt expects the "history" key,
            # so the memory must keep its default memory_key here.
            memory = ConversationBufferMemory()
            qa_chain = ConversationChain(
                llm=ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo"),
                memory=memory
            )

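        # With PDFs indexed, answers come from the retrieval chain over the
        # vector store; otherwise the app falls back to plain chat.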
        # Chat interface
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        if prompt := st.chat_input("Ask about your PDFs"):
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)

            with st.chat_message("assistant"):
                try:
                    if texts:
                        result = qa_chain({"question": prompt})
                        response = result["answer"]
                    else:
                        response = qa_chain.predict(input=prompt)

                    st.session_state.messages.append({"role": "assistant", "content": response})
                    st.markdown(response)

                except Exception as e:
                    st.error(f"Error: {str(e)}")

    except Exception as e:
        st.error(f"Error: {str(e)}")

if __name__ == "__main__":
    main()
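
# To run locally (an assumption, not part of this commit): install
# streamlit, a pre-0.1 langchain, openai, chromadb, and pypdf, then
# launch with `streamlit run app.py`.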