Asaad Almutareb committed on
Commit
3dfd099
·
1 Parent(s): 8eb79b5

added an agent as a tool

Browse files
.devcontainer/Dockerfile CHANGED
@@ -44,4 +44,6 @@ RUN echo "done 0" \
44
  && pyenv global ${PYTHON_VERSION} \
45
  && echo "done 3" \
46
  && curl -sSL https://install.python-poetry.org | python3 - \
47
- && poetry config virtualenvs.in-project true
 
 
 
44
  && pyenv global ${PYTHON_VERSION} \
45
  && echo "done 3" \
46
  && curl -sSL https://install.python-poetry.org | python3 - \
47
+ && poetry config virtualenvs.in-project true \
48
+ && echo "done 4" \
49
+ && pip install -r requirements.txt
rag_app/agents/kb_retriever_agent.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HF libraries
2
+ from langchain_huggingface import HuggingFaceEndpoint
3
+ from langchain.agents import AgentExecutor
4
+ from langchain.agents.format_scratchpad import format_log_to_str
5
+ from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
6
+ # Import things that are needed generically
7
+ from langchain.tools.render import render_text_description
8
+ import os
9
+ from dotenv import load_dotenv
10
+ from rag_app.structured_tools.structured_tools import (
11
+ google_search, knowledgeBase_search
12
+ )
13
+
14
+ from langchain.prompts import PromptTemplate
15
+ from rag_app.templates.react_json_with_memory import template_system
16
+ # from innovation_pathfinder_ai.utils import logger
17
+ # from langchain.globals import set_llm_cache
18
+ # from langchain.cache import SQLiteCache
19
+
20
+ # set_llm_cache(SQLiteCache(database_path=".cache.db"))
21
+ # logger = logger.get_console_logger("hf_mixtral_agent")
22
+
23
+ config = load_dotenv(".env")
24
+ HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
25
+ GOOGLE_CSE_ID = os.getenv('GOOGLE_CSE_ID')
26
+ GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
27
+
28
+ # Load the model from the Hugging Face Hub
29
+ llm = HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
30
+ temperature=0.1,
31
+ max_new_tokens=1024,
32
+ repetition_penalty=1.2,
33
+ return_full_text=False
34
+ )
35
+
36
+
37
+ tools = [
38
+ knowledgeBase_search,
39
+ google_search,
40
+ ]
41
+
42
+ prompt = PromptTemplate.from_template(
43
+ template=template_system
44
+ )
45
+ prompt = prompt.partial(
46
+ tools=render_text_description(tools),
47
+ tool_names=", ".join([t.name for t in tools]),
48
+ )
49
+
50
+
51
+ # define the agent
52
+ chat_model_with_stop = llm.bind(stop=["\nObservation"])
53
+ agent = (
54
+ {
55
+ "input": lambda x: x["input"],
56
+ "agent_scratchpad": lambda x: format_log_to_str(x["intermediate_steps"]),
57
+ #"chat_history": lambda x: x["chat_history"],
58
+ }
59
+ | prompt
60
+ | chat_model_with_stop
61
+ | ReActJsonSingleInputOutputParser()
62
+ )
63
+
64
+ # instantiate AgentExecutor
65
+ agent_executor = AgentExecutor(
66
+ agent=agent,
67
+ tools=tools,
68
+ verbose=True,
69
+ max_iterations=10, # cap number of iterations
70
+ #max_execution_time=60, # timout at 60 sec
71
+ #return_intermediate_steps=True,
72
+ handle_parsing_errors=True,
73
+ )
rag_app/agents/react_agent.py CHANGED
@@ -8,7 +8,8 @@ from langchain.tools.render import render_text_description
8
  import os
9
  from dotenv import load_dotenv
10
  from rag_app.structured_tools.structured_tools import (
11
- google_search, knowledgeBase_search
 
12
  )
13
 
14
  from langchain.prompts import PromptTemplate
@@ -39,8 +40,9 @@ llm = HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
39
 
40
 
41
  tools = [
42
- knowledgeBase_search,
43
- google_search,
 
44
  ]
45
 
46
  prompt = PromptTemplate.from_template(
 
8
  import os
9
  from dotenv import load_dotenv
10
  from rag_app.structured_tools.structured_tools import (
11
+ #google_search, knowledgeBase_search,
12
+ web_research
13
  )
14
 
15
  from langchain.prompts import PromptTemplate
 
40
 
41
 
42
  tools = [
43
+ #knowledgeBase_search,
44
+ #google_search,
45
+ web_research
46
  ]
47
 
48
  prompt = PromptTemplate.from_template(
rag_app/loading_data/load_S3_vector_stores.py CHANGED
@@ -10,7 +10,6 @@ from dotenv import load_dotenv
10
  import os
11
  import sys
12
  import logging
13
- from pathlib import Path
14
 
15
  # Load environment variables from a .env file
16
  config = load_dotenv(".env")
@@ -39,7 +38,6 @@ def get_faiss_vs():
39
 
40
  # Define the destination for the downloaded file
41
  VS_DESTINATION = FAISS_INDEX_PATH + ".zip"
42
-
43
  try:
44
  # Download the pre-prepared vectorized index from the S3 bucket
45
  print("Downloading the pre-prepared FAISS vectorized index from S3...")
@@ -49,36 +47,11 @@ def get_faiss_vs():
49
  with zipfile.ZipFile(VS_DESTINATION, 'r') as zip_ref:
50
  zip_ref.extractall('./vectorstore/')
51
  print("Download and extraction completed.")
52
- return FAISS.load_local(FAISS_INDEX_PATH, embeddings,allow_dangerous_deserialization=True)
53
 
54
  except Exception as e:
55
  print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
56
- # faissdb = FAISS.load_local(FAISS_INDEX_PATH, embeddings)
57
-
58
-
59
- def get_faiss_vs_from_s3(s3_loc:str,
60
- s3_vs_name:str,
61
- vs_dir:str='vectorstore') -> None:
62
- """ Download the FAISS vector store from S3 bucket
63
-
64
- Args:
65
- s3_loc (str): Name of the S3 bucket
66
- s3_vs_name (str): Name of the file to be downloaded
67
- vs_dir (str): The name of the directory where the file is to be saved
68
- """
69
- # Initialize an S3 client with unsigned configuration for public access
70
- s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
71
- # Destination folder
72
- vs_dir_path = Path("..") / vs_dir
73
- assert vs_dir_path.is_dir(), "Cannot find vs_dir folder"
74
- try:
75
- vs_destination = Path("..") / vs_dir / "faiss-insurance-agent-500.zip"
76
- s3.download_file(s3_loc, s3_vs_name, vs_destination)
77
- # Extract the downloaded zip file
78
- with zipfile.ZipFile(file=vs_destination, mode='r') as zip_ref:
79
- zip_ref.extractall(path=vs_dir_path.as_posix())
80
- except Exception as e:
81
- print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
82
 
83
 
84
  ## Chroma DB
@@ -97,10 +70,4 @@ def get_chroma_vs():
97
  chromadb = Chroma(persist_directory=CHROMA_DIRECTORY, embedding_function=embeddings)
98
  chromadb.get()
99
  except Exception as e:
100
- print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
101
-
102
-
103
- if __name__ == "__main__":
104
- # get_faiss_vs_from_s3(s3_loc=S3_LOCATION, s3_vs_name=FAISS_VS_NAME)
105
- pass
106
-
 
10
  import os
11
  import sys
12
  import logging
 
13
 
14
  # Load environment variables from a .env file
15
  config = load_dotenv(".env")
 
38
 
39
  # Define the destination for the downloaded file
40
  VS_DESTINATION = FAISS_INDEX_PATH + ".zip"
 
41
  try:
42
  # Download the pre-prepared vectorized index from the S3 bucket
43
  print("Downloading the pre-prepared FAISS vectorized index from S3...")
 
47
  with zipfile.ZipFile(VS_DESTINATION, 'r') as zip_ref:
48
  zip_ref.extractall('./vectorstore/')
49
  print("Download and extraction completed.")
50
+ return FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True)
51
 
52
  except Exception as e:
53
  print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
54
+ #faissdb = FAISS.load_local(FAISS_INDEX_PATH, embeddings)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
 
57
  ## Chroma DB
 
70
  chromadb = Chroma(persist_directory=CHROMA_DIRECTORY, embedding_function=embeddings)
71
  chromadb.get()
72
  except Exception as e:
73
+ print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
 
 
 
 
 
 
rag_app/structured_tools/agent_tools.py ADDED
File without changes
rag_app/structured_tools/structured_tools.py CHANGED
@@ -17,6 +17,7 @@ from rag_app.utils.utils import (
17
  from rag_app.database.db_handler import (
18
  add_many
19
  )
 
20
 
21
  import os
22
  # from innovation_pathfinder_ai.utils import create_wikipedia_urls_from_text
@@ -71,7 +72,7 @@ def knowledgeBase_search(query:str) -> str:
71
  embedding_function=embedding_function,
72
  )
73
 
74
- retriever = vector_db.as_retriever()
75
  # This is deprecated, changed to invoke
76
  # LangChainDeprecationWarning: The method `BaseRetriever.get_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 0.3.0. Use invoke instead.
77
  docs = retriever.invoke(query)
@@ -83,7 +84,6 @@ def knowledgeBase_search(query:str) -> str:
83
  @tool
84
  def google_search(query: str) -> str:
85
  """Verbessere die Ergebnisse durch eine Suche über die Webseite der Versicherung. Erstelle eine neue Suchanfrage, um die Erfolgschancen zu verbesseren."""
86
- global all_sources
87
 
88
  websearch = GoogleSearchAPIWrapper()
89
  search_results:dict = websearch.results(query, 3)
@@ -95,4 +95,16 @@ def google_search(query: str) -> str:
95
  else:
96
  cleaner_sources = search_results
97
 
98
- return cleaner_sources.__str__()
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  from rag_app.database.db_handler import (
18
  add_many
19
  )
20
+ from rag_app.agents.kb_retriever_agent import agent_executor
21
 
22
  import os
23
  # from innovation_pathfinder_ai.utils import create_wikipedia_urls_from_text
 
72
  embedding_function=embedding_function,
73
  )
74
 
75
+ retriever = vector_db.as_retriever(search_kwargs={'k':1})
76
  # This is deprecated, changed to invoke
77
  # LangChainDeprecationWarning: The method `BaseRetriever.get_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 0.3.0. Use invoke instead.
78
  docs = retriever.invoke(query)
 
84
  @tool
85
  def google_search(query: str) -> str:
86
  """Verbessere die Ergebnisse durch eine Suche über die Webseite der Versicherung. Erstelle eine neue Suchanfrage, um die Erfolgschancen zu verbesseren."""
 
87
 
88
  websearch = GoogleSearchAPIWrapper()
89
  search_results:dict = websearch.results(query, 3)
 
95
  else:
96
  cleaner_sources = search_results
97
 
98
+ return cleaner_sources.__str__()
99
+
100
+ @tool
101
+ def web_research(query: str) -> str:
102
+ """Verbessere die Ergebnisse durch eine Suche über die Webseite der Versicherung. Erstelle eine neue Suchanfrage, um die Erfolgschancen zu verbesseren."""
103
+
104
+ result = agent_executor.invoke(
105
+ {
106
+ "input": query
107
+ }
108
+ )
109
+ print(result)
110
+ return result.__str__()
rag_app/templates/react_json_ger.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ template_system = """
2
+ Answer the following questions as best you can. You have access to the following tools:
3
+
4
+ <TOOLS>
5
+ {tools}
6
+ </TOOLS>
7
+
8
+ The way you use the tools is by specifying a json blob.
9
+ Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here).
10
+
11
+ The only values that should be in the "action" field are: {tool_names}
12
+
13
+ The $JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB:
14
+
15
+ ```
16
+ {{
17
+ "action": $TOOL_NAME,
18
+ "action_input": $INPUT
19
+ }}
20
+ ```
21
+
22
+ ALWAYS use the following format:
23
+
24
+ Question: the input question you must answer
25
+ Thought: you should always think about what to do
26
+ Action:
27
+ ```
28
+ $JSON_BLOB
29
+ ```
30
+ Observation: the result of the action
31
+ ... (this Thought/Action/Observation can repeat N times)
32
+ Thought: I now know the final answer
33
+ Final Answer: the final answer to the original input question
34
+
35
+ Begin! Reminder to always use the exact characters `Final Answer` when responding.
36
+
37
+ Previous conversation history:
38
+ <CONVERSATION_HISTORY>
39
+ {chat_history}
40
+ </CONVERSATION_HISTORY>
41
+
42
+ <NEW_INPUT>
43
+ {input}
44
+ </NEW_INPUT>
45
+
46
+ {agent_scratchpad}
47
+ """