Akshayram1 committed on
Commit
eb8bab3
·
verified ·
1 Parent(s): 011aa4c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +229 -0
app.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ from dotenv import load_dotenv
4
+ import streamlit as st
5
+ from langchain.embeddings.openai import OpenAIEmbeddings
6
+ from langchain.chat_models import ChatOpenAI
7
+ from langchain.document_loaders import PyPDFLoader
8
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
+ from langchain.schema import Document
10
+ from langchain.prompts import PromptTemplate
11
+ from langchain.vectorstores import Neo4jVector
12
+ from langchain.graphs import Neo4jGraph
13
+ from langchain_experimental.graph_transformers import LLMGraphTransformer
14
+ from langchain.chains.graph_qa.cypher import GraphCypherQAChain
15
+ from neo4j import GraphDatabase
16
+
17
+ # Add Llama-Index imports
18
+ from llama_index.core import SimpleDirectoryReader, KnowledgeGraphIndex, Settings
19
+ from llama_index.core.graph_stores import SimpleGraphStore
20
+ from llama_index.core import StorageContext
21
+ from llama_index.llms.huggingface import HuggingFaceInferenceAPI
22
+ from langchain.embeddings import HuggingFaceEmbeddings
23
+ from llama_index.embeddings.langchain import LangchainEmbedding
24
+
25
def main():
    """Run the Streamlit GraphRAG app.

    Each script run performs, in order:

    1. Page setup and sidebar description.
    2. OpenAI setup: asks for an API key in the sidebar and builds the
       embeddings / chat models (cached in ``st.session_state``).
    3. Neo4j setup: asks for connection details and opens a ``Neo4jGraph``
       (cached in ``st.session_state``).
    4. PDF ingestion: once connected and a PDF is uploaded, the file is passed
       to ``process_document`` and a ``GraphCypherQAChain`` is stored under
       ``st.session_state['qa']``.
    5. Question form: when a QA chain exists, questions are answered with it.

    The function takes no arguments and returns ``None``; all state lives in
    ``st.session_state``.
    """
    st.set_page_config(
        layout="wide",
        page_title="MayaJal",
        page_icon=":graph:",
    )

    # Debug statement
    st.write("Starting the app...")

    # The logo is treated as optional so that a missing file does not abort
    # the whole script run.
    if os.path.exists('logo.png'):
        st.sidebar.image('logo.png', use_column_width=True)
    with st.sidebar.expander("Expand Me"):
        st.markdown(
            "\n"
            "This application allows you to upload a PDF file, extract its "
            "content into a Neo4j graph database, and perform queries using "
            "natural language.\n"
            "It leverages LangChain and OpenAI's GPT models to generate Cypher "
            "queries that interact with the Neo4j database in real-time.\n"
        )
    st.title("Mayajal: Realtime GraphRAG App")

    load_dotenv()

    # Debug statement
    st.write("Loaded environment variables.")

    # Both values are None until the user has supplied an API key.
    embeddings, llm = _init_openai_models()

    # graph is None until a connection has been established.
    graph, neo4j_url, neo4j_username, neo4j_password = _connect_neo4j()

    if graph is None:
        # Debug statement
        st.write("Neo4j connection not established.")
        st.warning("Please connect to the Neo4j database before you can upload a PDF.")
    else:
        # Debug statement
        st.write("Neo4j connection established.")
        uploaded_file = st.file_uploader("Please select a PDF file.", type="pdf")

        if uploaded_file is not None and 'qa' not in st.session_state:
            if embeddings is None or llm is None:
                # Without this guard the ingestion below would fail on
                # missing models when no key has been entered yet.
                st.warning("Please enter your OpenAI API Key before processing a PDF.")
            else:
                with st.spinner("Processing the PDF..."):
                    _ingest_pdf(
                        uploaded_file,
                        graph,
                        embeddings,
                        llm,
                        neo4j_url,
                        neo4j_username,
                        neo4j_password,
                    )

    if 'qa' in st.session_state:
        st.subheader("Ask a Question")
        with st.form(key='question_form'):
            question = st.text_input("Enter your question:")
            submit_button = st.form_submit_button(label='Submit')

        if submit_button and question:
            with st.spinner("Generating answer..."):
                res = st.session_state['qa'].invoke({"query": question})
                st.write(f"\n**Answer:**\n{res['result']}")


# Prompt used to make the LLM emit a Cypher statement for a user question.
_CYPHER_PROMPT_TEMPLATE = """
Task: Generate a Cypher statement to query the graph database.

Instructions:
Use only relationship types and properties provided in schema.
Do not use other relationship types or properties that are not provided.

schema:
{schema}

Note: Do not include explanations or apologies in your answers.
Do not answer questions that ask anything other than creating Cypher statements.
Do not include any text other than generated Cypher statements.

Question: {question}"""


def _init_openai_models():
    """Return ``(embeddings, llm)``, prompting for an API key when needed.

    Returns ``(None, None)`` while no key has been entered. Once a key is
    entered, the models are built and cached in ``st.session_state`` together
    with the key.
    """
    if 'OPENAI_API_KEY' in st.session_state:
        # Debug statement
        st.write("OpenAI API Key already set.")
        return st.session_state['embeddings'], st.session_state['llm']

    st.sidebar.subheader("OpenAI API Key")
    openai_api_key = st.sidebar.text_input("Enter your OpenAI API Key:", type='password')
    if not openai_api_key:
        return None, None

    os.environ['OPENAI_API_KEY'] = openai_api_key
    embeddings = OpenAIEmbeddings()
    llm = ChatOpenAI(model_name="gpt-4o")  # Use model that supports function calling

    # Store the key only after both models exist: the "already set" branch
    # above reads 'embeddings' and 'llm' whenever the key is present.
    st.session_state['embeddings'] = embeddings
    st.session_state['llm'] = llm
    st.session_state['OPENAI_API_KEY'] = openai_api_key
    st.sidebar.success("OpenAI API Key set successfully.")
    # Debug statement
    st.write("OpenAI API Key set and models initialized.")
    return embeddings, llm


def _connect_neo4j():
    """Return ``(graph, url, username, password)`` for the Neo4j connection.

    ``graph`` is ``None`` until the user has pressed "Connect" with a password
    and the connection succeeded. After that, the graph and the connection
    parameters are read back from ``st.session_state`` on every run.
    """
    if 'neo4j_connected' in st.session_state:
        return (
            st.session_state['graph'],
            st.session_state['neo4j_url'],
            st.session_state['neo4j_username'],
            st.session_state['neo4j_password'],
        )

    st.sidebar.subheader("Connect to Neo4j Database")
    neo4j_url = st.sidebar.text_input("Neo4j URL:", value="neo4j+s://<your-neo4j-url>")
    neo4j_username = st.sidebar.text_input("Neo4j Username:", value="neo4j")
    neo4j_password = st.sidebar.text_input("Neo4j Password:", type='password')
    connect_button = st.sidebar.button("Connect")

    graph = None
    if connect_button and neo4j_password:
        try:
            graph = Neo4jGraph(
                url=neo4j_url,
                username=neo4j_username,
                password=neo4j_password,
            )
        except Exception as e:  # UI boundary: report instead of crashing
            st.error(f"Failed to connect to Neo4j: {e}")
        else:
            st.session_state['graph'] = graph
            st.session_state['neo4j_connected'] = True
            # Store connection parameters for later use
            st.session_state['neo4j_url'] = neo4j_url
            st.session_state['neo4j_username'] = neo4j_username
            st.session_state['neo4j_password'] = neo4j_password
            st.sidebar.success("Connected to Neo4j database.")
    return graph, neo4j_url, neo4j_username, neo4j_password


def _ingest_pdf(uploaded_file, graph, embeddings, llm,
                neo4j_url, neo4j_username, neo4j_password):
    """Load the uploaded PDF into Neo4j and store a QA chain in session state.

    Writes ``st.session_state['index']`` (the value returned by
    ``process_document``) and ``st.session_state['qa']``.
    """
    # Save uploaded file to temporary file. getvalue() is used instead of
    # read() so the result does not depend on the buffer's current position.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
        tmp_file.write(uploaded_file.getvalue())
        tmp_file_path = tmp_file.name
    try:
        # Debug statement
        st.write("PDF file uploaded and saved to temporary file.")

        # Process document using Llama-Index
        st.session_state['index'] = process_document(tmp_file_path, graph)
    finally:
        # delete=False above means nobody else removes the file.
        os.unlink(tmp_file_path)

    # The returned vector store is not used further in this file; the call is
    # kept from the original flow.
    # NOTE(review): presumably it is wanted for the indexes/embeddings it sets
    # up in the database - confirm, and confirm that "Patient" matches the
    # labels actually written by process_document.
    Neo4jVector.from_existing_graph(
        embedding=embeddings,
        url=neo4j_url,
        username=neo4j_username,
        password=neo4j_password,
        database="neo4j",
        node_label="Patient",  # Adjust node_label as needed
        text_node_properties=["id", "text"],
        embedding_node_property="embedding",
        index_name="vector_index",
        keyword_index_name="entity_index",
        search_type="hybrid",
    )

    st.success(f"{uploaded_file.name} preparation is complete.")

    # The graph object was created before the document was ingested, so ask
    # it to re-read the schema before the chain is built from it.
    graph.refresh_schema()

    question_prompt = PromptTemplate(
        template=_CYPHER_PROMPT_TEMPLATE,
        input_variables=["schema", "question"],
    )
    st.session_state['qa'] = GraphCypherQAChain.from_llm(
        llm=llm,
        graph=graph,
        cypher_prompt=question_prompt,
        verbose=True,
        allow_dangerous_requests=True,
    )
191
def process_document(file_path, graph):
    """Build a knowledge-graph index from one file and copy it into Neo4j.

    Args:
        file_path: Path of the document to load (``main`` passes the path of
            a temporary PDF file).
        graph: Object exposing ``query(cypher, params=...)``; ``main`` passes
            a ``Neo4jGraph``.

    Returns:
        The ``KnowledgeGraphIndex`` built from the document.

    Side effects:
        Sets ``Settings.chunk_size`` to 512 and writes one node per graph
        node and one relationship per graph edge through ``graph.query``.
    """
    # Initialize Llama-Index components
    Settings.chunk_size = 512

    # Create graph store
    graph_store = SimpleGraphStore()
    storage_context = StorageContext.from_defaults(graph_store=graph_store)

    # Load document. file_path is a single file, so it must go through
    # input_files; the first positional argument is a directory.
    documents = SimpleDirectoryReader(input_files=[file_path]).load_data()

    # Create Knowledge Graph Index
    index = KnowledgeGraphIndex.from_documents(
        documents=documents,
        max_triplets_per_chunk=3,
        storage_context=storage_context,
        include_embeddings=True,
    )

    # Convert to Neo4j
    g = index.get_networkx_graph()

    # data=True yields (node_id, attribute_dict); plain iteration yields only
    # the node ids, which cannot be indexed by attribute name.
    # NOTE(review): the exported graph appears to carry bare string nodes
    # without 'type'/'text' attributes - confirm against the installed
    # llama-index version. The fallbacks below cover that case.
    for node_id, attrs in g.nodes(data=True):
        label = _safe_cypher_identifier(attrs.get("type"), "Entity")
        # Labels cannot be query parameters, so only the sanitized label is
        # interpolated; all values go through parameters. MERGE keeps a
        # re-ingested node from being duplicated.
        graph.query(
            f"MERGE (n:`{label}` {{id: $id}}) SET n.text = $text",
            params={"id": str(node_id), "text": str(attrs.get("text", node_id))},
        )

    for source, target, attrs in g.edges(data=True):
        # NOTE(review): the relation name is looked up under 'relationship'
        # first, then 'label' - confirm which key the exporter uses.
        raw_type = attrs.get("relationship") or attrs.get("label")
        rel_type = _safe_cypher_identifier(raw_type, "RELATED_TO").upper()
        graph.query(
            "MATCH (a {id: $source}), (b {id: $target}) "
            f"MERGE (a)-[r:`{rel_type}`]->(b)",
            params={"source": str(source), "target": str(target)},
        )

    return index


def _safe_cypher_identifier(raw, fallback):
    """Return ``raw`` reduced to letters, digits and underscores.

    Every other character is replaced by ``_`` and leading/trailing
    underscores are stripped, so the result contains no backticks or quotes.
    ``fallback`` is returned when nothing usable remains.
    """
    text = "" if raw is None else str(raw).strip()
    cleaned = "".join(ch if (ch.isalnum() or ch == "_") else "_" for ch in text)
    return cleaned.strip("_") or fallback
228
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()