Pijush2023 commited on
Commit
c71d159
·
verified ·
1 Parent(s): 7a7da87

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -199
app.py CHANGED
@@ -1,212 +1,161 @@
1
  import gradio as gr
2
- import pdfplumber
3
  import os
4
- from langchain.schema import Document
5
- from langchain.text_splitter import RecursiveCharacterTextSplitter
6
- from langchain_community.embeddings import OpenAIEmbeddings
7
- from langchain.vectorstores import Pinecone
8
- import pinecone
9
- import pandas as pd
10
- import time
11
- from pinecone.grpc import PineconeGRPC as Pinecone
12
- from pinecone import ServerlessSpec
13
- from langchain_pinecone import PineconeVectorStore
14
- from datetime import datetime
15
- import os
16
- from langchain.document_loaders import PyPDFLoader
17
- from langchain.text_splitter import RecursiveCharacterTextSplitter
18
- from langchain.embeddings.openai import OpenAIEmbeddings
19
- from langchain.vectorstores import Pinecone
20
- from typing import TypedDict,List
21
- from langgraph.graph import StateGraph
22
- from langgraph.prebuilt import ToolNode
23
- from langchain.schema import Document
24
- from langchain.prompts import PromptTemplate
25
- from langchain.tools import Tool
26
- from langchain.llms import OpenAI
27
-
28
- # OpenAI API key
29
- openai_api_key = os.getenv("OPENAI_API_KEY")
30
- # Embedding using OpenAI
31
- embeddings = OpenAIEmbeddings(api_key=openai_api_key)
32
-
33
- # Initialize Pinecone with PineconeGRPC
34
- from pinecone import Pinecone
35
- pc = Pinecone(api_key=os.environ['PINECONE_API_KEY'])
36
- # Define index name and parameters
37
- index_name = "italy-kg"
38
- vectorstore = PineconeVectorStore(index_name=index_name, embedding=embeddings)
39
-
40
-
41
-
42
-
43
-
44
-
45
- llm=OpenAI(temperature=0,openai_api_key=openai_api_key)
46
-
47
-
48
- # Tool functions
49
- def search_vector_db(query: str, k: int = 3) -> List[Document]:
50
- docs = vectorstore.similarity_search(query, k=k)
51
- return docs
52
-
53
- def expand_query(query: str) -> str:
54
- return query
55
-
56
- def summarize_context(context: str) -> str:
57
- prompt = PromptTemplate(template="""Summarize the following Context to provide a concise overview: {context}""")
58
- summary = llm(prompt.format(context=context))
59
- return summary.strip()
60
-
61
- def generate_response(context: str, query: str) -> str:
62
- prompt = PromptTemplate(template="""Question: {question}\nContext: {context}\nAnswer:""")
63
- formatted_prompt = prompt.format(context=context, question=query)
64
- response = llm(formatted_prompt)
65
- return response.strip()
66
-
67
- # Tool objects
68
- expand_tool = Tool(
69
- name="Expand Query",
70
- func=expand_query,
71
- description="Enhance the query with additional terms or context"
72
  )
 
73
 
74
- summarize_tool = Tool(
75
- name="Summarize Context",
76
- func=summarize_context,
77
- description="Summarize the context to provide a concise overview"
 
78
  )
79
 
80
- search_tool = Tool(
81
- name="Search Vector Database",
82
- func=search_vector_db,
83
- description="Search the vector database for relevant information"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  )
85
 
86
- generate_tool = Tool(
87
- name="Generate Response",
88
- func=generate_response,
89
- description="Generate a response based on the context and query"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  )
91
 
92
- # State for the graph
93
- class State(TypedDict):
94
- question: str
95
- context: List[Document]
96
- response: str
97
- expanded_query: str
98
- summarized_context: str
99
-
100
- # Workflow node definitions
101
- def expand(state: State) -> State:
102
- state["expanded_query"] = expand_tool.func(state["question"]) # Expand the query
103
- return state
104
-
105
- def search(state: State) -> State:
106
- results = search_tool.func(state["expanded_query"]) # Search using the expanded query
107
- state["context"] = results
108
- print(f"Retrieved Documents: {[doc.page_content[:100] for doc in results]}")
109
- return state
110
-
111
- def summarize(state: State) -> State:
112
- context = " ".join(doc.page_content for doc in state["context"]) if state["context"] else ""
113
- state["summarized_context"] = summarize_tool.func(context)
114
- print(f"Summarized Context: {state['summarized_context']}")
115
- return state
116
-
117
- def generate(state: State) -> State:
118
- response = generate_tool.func(state["summarized_context"], state["question"])
119
- state["response"] = response
120
- print(f"Generated Response: {state['response']}")
121
- return state
122
-
123
- # Workflow graph
124
- workflow = StateGraph(State)
125
-
126
- workflow.add_node("expand", expand)
127
- workflow.add_node("search", search)
128
- workflow.add_node("summarize", summarize)
129
- workflow.add_node("generate", generate)
130
-
131
- workflow.set_entry_point("expand")
132
- workflow.add_edge("expand", "search")
133
- workflow.add_edge("search", "summarize")
134
- workflow.add_edge("summarize", "generate")
135
- workflow.set_finish_point("generate")
136
-
137
- graph = workflow.compile()
138
-
139
- # Function to run the graph
140
- def run_graph(question: str):
141
- result = graph.invoke({"question": question})
142
- return result["response"]
143
-
144
- # Function to clear the input and response
145
- def clear_inputs():
146
- return "", "" # Return empty strings for both the query input and response output
147
-
148
- # Create a global list to store uploaded document records
149
- uploaded_documents = []
150
-
151
-
152
-
153
- # Function to process PDF, extract text, split it into chunks, and upload to the vector DB
154
- def process_pdf(pdf_file, uploaded_documents):
155
- if pdf_file is None:
156
- return uploaded_documents, "No PDF file uploaded."
157
-
158
- # Open the PDF file and extract text page by page
159
- with pdfplumber.open(pdf_file.name) as pdf:
160
- chunks = []
161
- for page_num, page in enumerate(pdf.pages, start=1):
162
- text = page.extract_text()
163
- if text:
164
- # Split the text into chunks and attach page number metadata to each chunk
165
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
166
- page_chunks = text_splitter.split_text(text)
167
- for chunk in page_chunks:
168
- # Create a Document with the page number as metadata
169
- document = Document(page_content=chunk, metadata={"page_number": page_num})
170
- chunks.append(document)
171
-
172
- # Embed and upload the chunks into the vector database
173
- chunk_ids = vectorstore.add_documents(chunks)
174
-
175
- # Update the upload history
176
- document_record = {
177
- "Document Name": pdf_file.name,
178
- "Upload Time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
179
- "Chunks": len(chunks),
180
- "Pinecone Index": index_name
181
- }
182
 
183
- # Add the record to the global list
184
- uploaded_documents.append(document_record)
185
-
186
- # Convert the list of dictionaries into a list of lists for the dataframe
187
- table_data = [[doc["Document Name"], doc["Upload Time"], doc["Chunks"], doc["Pinecone Index"]] for doc in uploaded_documents]
188
-
189
- return table_data, f"Uploaded {len(chunks)} chunks to the vector database with page numbers included as metadata."
190
-
191
-
192
-
193
- # Gradio Interface
194
- with gr.Blocks() as demo:
195
-
196
- with gr.Row():
197
- with gr.Column():
198
- response_output = gr.Textbox(label="Response:", lines=10, max_lines=10)
199
- query_input = gr.Textbox(label="Enter your query:")
200
- with gr.Row():
201
- query_button = gr.Button("Get Response")
202
- clear_button = gr.Button("Clear") # New Clear button
203
- query_button.click(fn=run_graph, inputs=query_input, outputs=response_output)
204
- clear_button.click(fn=clear_inputs, inputs=[], outputs=[query_input, response_output]) # Clear both input and output
205
- with gr.Column():
206
- file_input = gr.File(label="Upload PDF", file_types=[".pdf"])
207
- document_table = gr.Dataframe(headers=["Document Name", "Upload Time", "Chunks", "Pinecone Index"], interactive=False)
208
- output_textbox = gr.Textbox(label="Result")
209
- process_button = gr.Button("Process PDF and Upload")
210
- process_button.click(fn=process_pdf, inputs=[file_input, gr.State([])], outputs=[document_table, output_textbox])
211
 
 
212
  demo.launch(show_error=True)
 
1
  import gradio as gr
 
2
  import os
3
+ import logging
4
+ from langchain_core.prompts import ChatPromptTemplate
5
+ from langchain_core.output_parsers import StrOutputParser
6
+ from langchain_openai import ChatOpenAI
7
+ from langchain_community.graphs import Neo4jGraph
8
+ from typing import List, Tuple
9
+ from langchain_core.pydantic_v1 import BaseModel, Field
10
+ from langchain_core.messages import AIMessage, HumanMessage
11
+ from langchain_core.runnables import (
12
+ RunnableBranch,
13
+ RunnableLambda,
14
+ RunnablePassthrough,
15
+ RunnableParallel,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  )
17
+ from langchain_core.prompts.prompt import PromptTemplate
18
 
19
+ # Setup Neo4j
20
+ graph = Neo4jGraph(
21
+ url="neo4j+s://6457770f.databases.neo4j.io",
22
+ username="neo4j",
23
+ password="Z10duoPkKCtENuOukw3eIlvl0xJWKtrVSr-_hGX1LQ4"
24
  )
25
 
26
+ # Define entity extraction and retrieval functions
27
+ class Entities(BaseModel):
28
+ names: List[str] = Field(
29
+ ..., description="All the person, organization, or business entities that appear in the text"
30
+ )
31
+
32
+ entity_prompt = ChatPromptTemplate.from_messages([
33
+ ("system", "You are extracting organization and person entities from the text."),
34
+ ("human", "Use the given format to extract information from the following input: {question}"),
35
+ ])
36
+
37
+ chat_model = ChatOpenAI(temperature=0, model_name="gpt-4o", api_key="YOUR_OPENAI_API_KEY")
38
+ entity_chain = entity_prompt | chat_model.with_structured_output(Entities)
39
+
40
+ def remove_lucene_chars(input: str) -> str:
41
+ return input.translate(str.maketrans({
42
+ "\\": r"\\", "+": r"\+", "-": r"\-", "&": r"\&", "|": r"\|", "!": r"\!",
43
+ "(": r"\(", ")": r"\)", "{": r"\{", "}": r"\}", "[": r"\[", "]": r"\]",
44
+ "^": r"\^", "~": r"\~", "*": r"\*", "?": r"\?", ":": r"\:", '"': r'\"',
45
+ ";": r"\;", " ": r"\ "
46
+ }))
47
+
48
+ def generate_full_text_query(input: str) -> str:
49
+ full_text_query = ""
50
+ words = [el for el in remove_lucene_chars(input).split() if el]
51
+ for word in words[:-1]:
52
+ full_text_query += f" {word}~2 AND"
53
+ full_text_query += f" {words[-1]}~2"
54
+ return full_text_query.strip()
55
+
56
+ def structured_retriever(question: str) -> str:
57
+ result = ""
58
+ entities = entity_chain.invoke({"question": question})
59
+ for entity in entities.names:
60
+ response = graph.query(
61
+ """CALL db.index.fulltext.queryNodes('entity', $query, {limit:2})
62
+ YIELD node,score
63
+ CALL {
64
+ WITH node
65
+ MATCH (node)-[r:!MENTIONS]->(neighbor)
66
+ RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
67
+ UNION ALL
68
+ WITH node
69
+ MATCH (node)<-[r:!MENTIONS]-(neighbor)
70
+ RETURN neighbor.id + ' - ' + type(r) + ' -> ' + node.id AS output
71
+ }
72
+ RETURN output LIMIT 50
73
+ """,
74
+ {"query": generate_full_text_query(entity)},
75
+ )
76
+ result += "\n".join([el['output'] for el in response])
77
+ return result
78
+
79
+ def retriever_neo4j(question: str):
80
+ structured_data = structured_retriever(question)
81
+ logging.debug(f"Structured data: {structured_data}")
82
+ return structured_data
83
+
84
+ # Setup for condensing the follow-up questions
85
+ _template = """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question,
86
+ in its original language.
87
+ Chat History:
88
+ {chat_history}
89
+ Follow Up Input: {question}
90
+ Standalone question:"""
91
+
92
+ CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
93
+
94
+ def _format_chat_history(chat_history: list[tuple[str, str]]) -> list:
95
+ buffer = []
96
+ for human, ai in chat_history:
97
+ buffer.append(HumanMessage(content=human))
98
+ buffer.append(AIMessage(content=ai))
99
+ return buffer
100
+
101
+ _search_query = RunnableBranch(
102
+ (
103
+ RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
104
+ run_name="HasChatHistoryCheck"
105
+ ),
106
+ RunnablePassthrough.assign(
107
+ chat_history=lambda x: _format_chat_history(x["chat_history"])
108
+ )
109
+ | CONDENSE_QUESTION_PROMPT
110
+ | ChatOpenAI(temperature=0, api_key=os.environ['OPENAI_API_KEY'])
111
+ | StrOutputParser(),
112
+ ),
113
+ RunnableLambda(lambda x: x["question"]),
114
  )
115
 
116
+ # Define the QA prompt template
117
+ template = """As an expert concierge known for being helpful and a renowned guide for Birmingham, Alabama, I assist visitors in discovering the best that the city has to offer. I also assist the visitors about various sports and activities. Given today's sunny and bright weather on {current_date}, I am well-equipped to provide valuable insights and recommendations without revealing specific locations. I draw upon my extensive knowledge of the area, including perennial events and historical context.
118
+ In light of this, how can I assist you today? Feel free to ask any questions or seek recommendations for your day in Birmingham. If there's anything specific you'd like to know or experience, please share, and I'll be glad to help. Remember, keep the question concise for a quick, short, crisp, and accurate response.
119
+ "It was my pleasure!"
120
+ {context}
121
+ Question: {question}
122
+ Helpful Answer:"""
123
+
124
+ qa_prompt = ChatPromptTemplate.from_template(template)
125
+
126
+ # Define the chain for Neo4j-based retrieval and response generation
127
+ chain_neo4j = (
128
+ RunnableParallel(
129
+ {
130
+ "context": _search_query | retriever_neo4j,
131
+ "question": RunnablePassthrough(),
132
+ }
133
+ )
134
+ | qa_prompt
135
+ | chat_model
136
+ | StrOutputParser()
137
  )
138
 
139
+ # Define the function to get the response
140
+ def get_response(question):
141
+ try:
142
+ return chain_neo4j.invoke({"question": question})
143
+ except Exception as e:
144
+ return f"Error: {str(e)}"
145
+
146
+ # Define the function to clear input and output
147
+ def clear_fields():
148
+ return "", ""
149
+
150
+ # Create the Gradio Blocks interface
151
+ with gr.Blocks() as demo:
152
+ question_input = gr.Textbox(label="Ask a Question", placeholder="Type your question here...")
153
+ response_output = gr.Textbox(label="Response", placeholder="The response will appear here...", interactive=False)
154
+ get_response_btn = gr.Button("Get Response")
155
+ clean_btn = gr.Button("Clean")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
+ get_response_btn.click(fn=get_response, inputs=question_input, outputs=response_output)
158
+ clean_btn.click(fn=clear_fields, inputs=[], outputs=[question_input, response_output])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
+ # Launch the Gradio interface
161
  demo.launch(show_error=True)