Nguyen Quang Truong commited on
Commit
5c054cc
·
1 Parent(s): 58106c5
Files changed (2) hide show
  1. Agent/utils.py +2 -2
  2. utils.py +67 -0
Agent/utils.py CHANGED
@@ -11,12 +11,12 @@ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
11
 
12
  def config():
13
  load_dotenv()
14
-
15
  # Set up Neo4J & Gemini API
16
  os.environ["NEO4J_URI"] = os.getenv("NEO4J_URI")
17
  os.environ["NEO4J_USERNAME"] = os.getenv("NEO4J_USERNAME")
18
  os.environ["NEO4J_PASSWORD"] = os.getenv("NEO4J_PASSWORD")
19
- os.environ["GOOGLE_API_KEY"] = os.getenv("GEMINI_API_KEY")
20
 
21
  def load_prompt(filepath):
22
  with open(filepath, "r") as file:
 
11
 
12
  def config():
13
  load_dotenv()
14
+
15
  # Set up Neo4J & Gemini API
16
  os.environ["NEO4J_URI"] = os.getenv("NEO4J_URI")
17
  os.environ["NEO4J_USERNAME"] = os.getenv("NEO4J_USERNAME")
18
  os.environ["NEO4J_PASSWORD"] = os.getenv("NEO4J_PASSWORD")
19
+ os.environ["GEMINI_API_KEY"] = os.getenv("GEMINI_API_KEY")
20
 
21
  def load_prompt(filepath):
22
  with open(filepath, "r") as file:
utils.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ from dotenv import load_dotenv
4
+ from langchain_google_genai import ChatGoogleGenerativeAI
5
+ from langchain_community.graphs import Neo4jGraph
6
+ from langchain_core.prompts.prompt import PromptTemplate
7
+ from langchain.chains import GraphCypherQAChain
8
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
9
+
10
+ def config():
11
+ load_dotenv()
12
+
13
+ # Set up Neo4J & Gemini API
14
+ os.environ["NEO4J_URI"] = os.getenv("NEO4J_URI")
15
+ os.environ["NEO4J_USERNAME"] = os.getenv("NEO4J_USERNAME")
16
+ os.environ["NEO4J_PASSWORD"] = os.getenv("NEO4J_PASSWORD")
17
+ os.environ["GOOGLE_API_KEY"] = os.getenv("GEMINI_API_KEY")
18
+
19
+ def load_prompt(filepath):
20
+ with open(filepath, "r") as file:
21
+ prompt = yaml.safe_load(file)
22
+
23
+ return prompt
24
+
25
+ def init_():
26
+ config()
27
+ graph = Neo4jGraph(enhanced_schema= True)
28
+ llm = ChatGoogleGenerativeAI(
29
+ model= "gemini-1.5-flash-latest",
30
+ temperature = 0
31
+ )
32
+
33
+ return graph, llm
34
+
35
+ def get_llm_response(query):
36
+ # Connect to Neo4J Knowledge Graph
37
+ knowledge_graph, llm_chat = init_()
38
+ cypher_prompt = load_prompt("prompts/cypher_prompt.yaml")
39
+ qa_prompt = load_prompt("prompts/qa_prompt.yaml")
40
+
41
+ CYPHER_GENERATION_PROMPT = PromptTemplate(**cypher_prompt)
42
+ QA_GENERATION_PROMPT = PromptTemplate(**qa_prompt)
43
+
44
+ chain = GraphCypherQAChain.from_llm(
45
+ llm_chat, graph=knowledge_graph, verbose=True,
46
+ cypher_prompt= CYPHER_GENERATION_PROMPT,
47
+ qa_prompt= QA_GENERATION_PROMPT
48
+ )
49
+
50
+ return chain.invoke({"query": query})["result"]
51
+
52
+ def llm_answer(message, history):
53
+
54
+
55
+ try:
56
+ response = get_llm_response(message["text"])
57
+ except Exception:
58
+ response = "Exception"
59
+ except Error:
60
+ response = "Error"
61
+ return response
62
+
63
+ # if __name__ == "__main__":
64
+ # message = "Have any company recruiting jobs about Machine Learning and coresponding job titles?"
65
+ # history = [("What's your name?", "My name is Gemini")]
66
+ # resp = llm_answer(message, history)
67
+ # print(resp)