rajesh1729 committed on
Commit
53d8e52
·
verified ·
1 Parent(s): 91b39b8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from langchain.embeddings.openai import OpenAIEmbeddings
4
+ from langchain.vectorstores import Chroma
5
+ from langchain.text_splitter import CharacterTextSplitter
6
+ from langchain.chat_models import ChatOpenAI
7
+ from langchain.chains import ConversationalRetrievalChain, ConversationChain
8
+ from langchain.memory import ConversationBufferMemory
9
+ from langchain.document_loaders import PyPDFLoader
10
+
11
def create_sidebar():
    """Render the sidebar and hand back the API key the user typed.

    The sidebar contains the app title, a subtitle, a masked text input
    for the OpenAI API key, and a short markdown guide listing the tools
    used and the steps to follow.

    Returns:
        The value of the "OpenAI API Key:" password input.
    """
    # Static help text shown underneath the key input.
    guide_markdown = """
    ### Tools Used
    • OpenAI
    • LangChain
    • ChromaDB

    ### Steps
    1. Add API key
    2. Upload PDF
    3. Chat!
    """

    with st.sidebar:
        st.title("PDF Chat")
        st.markdown("### Quick Demo of RAG")

        entered_key = st.text_input("OpenAI API Key:", type="password")

        st.markdown(guide_markdown)

    return entered_key
31
+
32
def save_uploaded_file(uploaded_file, path='./uploads/'):
    """Write an uploaded file into *path* and return the path of the copy.

    Only the final component of ``uploaded_file.name`` is used as the file
    name, so a client-supplied name such as ``../../x.pdf`` cannot place
    the file outside *path*.

    Args:
        uploaded_file: Object exposing ``name`` (the file name) and
            ``getbuffer()`` (the bytes to write).
        path: Directory to write into; created if it does not exist.

    Returns:
        The path of the written file, i.e. *path* joined with the
        sanitised file name.

    Raises:
        ValueError: If the name has no usable final component
            (empty, ``.`` or ``..``).
    """
    # Treat both separator styles as separators before taking the last
    # component, so the result is the same on every platform.
    normalised_name = str(uploaded_file.name).replace("\\", "/")
    safe_name = os.path.basename(normalised_name)
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Invalid upload file name: {uploaded_file.name!r}")

    os.makedirs(path, exist_ok=True)
    file_path = os.path.join(path, safe_name)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return file_path
38
+
39
@st.cache_data
def load_texts_from_papers(papers):
    """Load every uploaded PDF and split it into text chunks.

    Each upload is first written to disk with ``save_uploaded_file``
    because ``PyPDFLoader`` is constructed from a file path. The on-disk
    copy is deleted again once the PDF has been loaded, whether or not
    loading succeeded.

    Args:
        papers: Iterable of uploaded file objects accepted by
            ``save_uploaded_file``.

    Returns:
        A list with the chunks of all papers, in upload order, produced by
        ``CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)``.
    """
    # The splitter configuration is the same for every file; build it once.
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

    all_texts = []
    for paper in papers:
        file_path = save_uploaded_file(paper)
        try:
            documents = PyPDFLoader(file_path).load()
        finally:
            # Clean up the temporary copy even when the PDF cannot be parsed,
            # otherwise failed uploads accumulate in the uploads directory.
            os.remove(file_path)
        all_texts.extend(text_splitter.split_documents(documents))
    return all_texts
51
+
52
@st.cache_resource
def _vectorstore_for_key(api_key):
    """Build the Chroma store for *api_key*; cached once per distinct key."""
    embedding = OpenAIEmbeddings(openai_api_key=api_key)
    return Chroma(embedding_function=embedding, persist_directory="db")


def initialize_vectorstore():
    """Return the Chroma vector store for the current session's API key.

    The store is persisted in the ``db`` directory and embeds with
    ``OpenAIEmbeddings`` using ``st.session_state.api_key``.

    The cached builder takes the key as an argument so that the cache entry
    depends on it. Caching a zero-argument function instead would keep
    serving the store created with the first key ever entered, even after
    the user types a different one.

    Returns:
        The ``Chroma`` instance associated with the current API key.
    """
    return _vectorstore_for_key(st.session_state.api_key)
57
+
58
def _session_memory(state_key, **memory_kwargs):
    """Return the conversation memory kept under *state_key*, creating it once per session."""
    if state_key not in st.session_state:
        st.session_state[state_key] = ConversationBufferMemory(**memory_kwargs)
    return st.session_state[state_key]


def _index_new_papers(vectorstore, papers):
    """Add uploads not yet indexed in this session; return True if any current upload has chunks."""
    if "indexed_files" not in st.session_state:
        # Maps (file name, file size) -> number of chunks added for that file.
        st.session_state.indexed_files = {}
    indexed = st.session_state.indexed_files

    for paper in papers:
        file_key = (paper.name, paper.size)
        if file_key in indexed:
            continue
        chunks = load_texts_from_papers([paper])
        if chunks:
            vectorstore.add_documents(chunks)
        indexed[file_key] = len(chunks)

    return any(indexed[(paper.name, paper.size)] for paper in papers)


def main():
    """Run the "Chat with PDF" Streamlit page.

    Flow:
      1. Draw the sidebar and read the OpenAI API key from it.
      2. Show the PDF uploader; stop with a warning while no key is set.
      3. Index newly uploaded PDFs in the vector store.
      4. Build a retrieval chain when the current uploads produced text
         chunks, otherwise a plain conversation chain.
      5. Replay the stored chat messages and answer a new prompt, if any.

    Streamlit re-executes this function on every interaction, so anything
    that must survive between turns (messages, conversation memory, the
    record of indexed files) is kept in ``st.session_state``.
    """
    st.set_page_config(page_title="PDF Chat", layout="wide")

    # Get API key from sidebar
    api_key = create_sidebar()

    if api_key:
        st.session_state.api_key = api_key

    st.title("Chat with PDF")
    papers = st.file_uploader("Upload PDFs", type=["pdf"], accept_multiple_files=True)

    if "messages" not in st.session_state:
        st.session_state.messages = []

    if not api_key:
        st.warning("Please enter your OpenAI API key")
        return

    try:
        vectorstore = initialize_vectorstore()

        # Only files not seen before in this session are embedded and added;
        # re-adding on every rerun would fill the persisted store with
        # duplicate chunks.
        use_retrieval = bool(papers) and _index_new_papers(vectorstore, papers)

        # Pass the key from the sidebar explicitly; without it the model
        # would not use the key the user entered.
        llm = ChatOpenAI(
            temperature=0,
            model_name="gpt-3.5-turbo",
            openai_api_key=api_key,
        )

        if use_retrieval:
            qa_chain = ConversationalRetrievalChain.from_llm(
                llm,
                vectorstore.as_retriever(),
                memory=_session_memory(
                    "retrieval_memory",
                    memory_key="chat_history",
                    return_messages=True,
                ),
            )
        else:
            # ConversationChain's default prompt uses the "history" variable,
            # which is ConversationBufferMemory's default memory key, so the
            # memory is created with its defaults here.
            qa_chain = ConversationChain(
                llm=llm,
                memory=_session_memory("chat_memory"),
            )

        # Chat interface
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        if prompt := st.chat_input("Ask about your PDFs"):
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)

            with st.chat_message("assistant"):
                try:
                    if use_retrieval:
                        response = qa_chain({"question": prompt})["answer"]
                    else:
                        response = qa_chain.predict(input=prompt)

                    st.session_state.messages.append({"role": "assistant", "content": response})
                    st.markdown(response)

                except Exception as e:
                    # Keep the page usable when a single model call fails.
                    st.error(f"Error: {str(e)}")

    except Exception as e:
        # Top-level UI boundary: surface setup/indexing failures in the page.
        st.error(f"Error: {str(e)}")


if __name__ == "__main__":
    main()