Asaad Almutareb committed
Commit · 3dfd099 · 1 Parent(s): 8eb79b5

added an agent as a tool
Files changed:
- .devcontainer/Dockerfile +3 -1
- rag_app/agents/kb_retriever_agent.py +73 -0
- rag_app/agents/react_agent.py +5 -3
- rag_app/loading_data/load_S3_vector_stores.py +3 -36
- rag_app/structured_tools/agent_tools.py +0 -0
- rag_app/structured_tools/structured_tools.py +15 -3
- rag_app/templates/react_json_ger.py +47 -0
.devcontainer/Dockerfile
CHANGED
@@ -44,4 +44,6 @@ RUN echo "done 0" \
     && pyenv global ${PYTHON_VERSION} \
     && echo "done 3" \
     && curl -sSL https://install.python-poetry.org | python3 - \
-    && poetry config virtualenvs.in-project true
+    && poetry config virtualenvs.in-project true \
+    && echo "done 4" \
+    && pip install -r requirements.txt
rag_app/agents/kb_retriever_agent.py
ADDED
@@ -0,0 +1,73 @@
+# HF libraries
+from langchain_huggingface import HuggingFaceEndpoint
+from langchain.agents import AgentExecutor
+from langchain.agents.format_scratchpad import format_log_to_str
+from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
+# Import things that are needed generically
+from langchain.tools.render import render_text_description
+import os
+from dotenv import load_dotenv
+from rag_app.structured_tools.structured_tools import (
+    google_search, knowledgeBase_search
+)
+
+from langchain.prompts import PromptTemplate
+from rag_app.templates.react_json_with_memory import template_system
+# from innovation_pathfinder_ai.utils import logger
+# from langchain.globals import set_llm_cache
+# from langchain.cache import SQLiteCache
+
+# set_llm_cache(SQLiteCache(database_path=".cache.db"))
+# logger = logger.get_console_logger("hf_mixtral_agent")
+
+config = load_dotenv(".env")
+HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
+GOOGLE_CSE_ID = os.getenv('GOOGLE_CSE_ID')
+GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
+
+# Load the model from the Hugging Face Hub
+llm = HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
+                          temperature=0.1,
+                          max_new_tokens=1024,
+                          repetition_penalty=1.2,
+                          return_full_text=False
+)
+
+
+tools = [
+    knowledgeBase_search,
+    google_search,
+]
+
+prompt = PromptTemplate.from_template(
+    template=template_system
+)
+prompt = prompt.partial(
+    tools=render_text_description(tools),
+    tool_names=", ".join([t.name for t in tools]),
+)
+
+
+# define the agent
+chat_model_with_stop = llm.bind(stop=["\nObservation"])
+agent = (
+    {
+        "input": lambda x: x["input"],
+        "agent_scratchpad": lambda x: format_log_to_str(x["intermediate_steps"]),
+        #"chat_history": lambda x: x["chat_history"],
+    }
+    | prompt
+    | chat_model_with_stop
+    | ReActJsonSingleInputOutputParser()
+)
+
+# instantiate AgentExecutor
+agent_executor = AgentExecutor(
+    agent=agent,
+    tools=tools,
+    verbose=True,
+    max_iterations=10, # cap number of iterations
+    #max_execution_time=60, # timeout at 60 sec
+    #return_intermediate_steps=True,
+    handle_parsing_errors=True,
+)
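A quick smoke test of the new retriever agent in isolation could look like the sketch below (illustrative, not part of the commit); it assumes .env provides HUGGINGFACEHUB_API_TOKEN plus the Google CSE credentials, and the question string is arbitrary. Note that the module imports template_system from react_json_with_memory rather than the react_json_ger template added in this commit; if that template references {chat_history}, the commented-out "chat_history" mapping above would need to be restored before the prompt can render.

# Illustrative smoke test, not part of the commit
from rag_app.agents.kb_retriever_agent import agent_executor

result = agent_executor.invoke({"input": "What does the household insurance cover?"})
print(result["output"])  # AgentExecutor.invoke returns a dict with an "output" key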
rag_app/agents/react_agent.py
CHANGED
@@ -8,7 +8,8 @@ from langchain.tools.render import render_text_description
 import os
 from dotenv import load_dotenv
 from rag_app.structured_tools.structured_tools import (
-    google_search, knowledgeBase_search
+    #google_search, knowledgeBase_search,
+    web_research
 )
 
 from langchain.prompts import PromptTemplate
@@ -39,8 +40,9 @@ llm = HuggingFaceEndpoint(repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
 
 
 tools = [
-    knowledgeBase_search,
-    google_search,
+    #knowledgeBase_search,
+    #google_search,
+    web_research
 ]
 
 prompt = PromptTemplate.from_template(
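With this swap, the outer ReAct agent's action space shrinks to a single delegating tool; the nested retriever agent is what the commit message calls an agent as a tool. A small sketch (illustrative, not in the commit) of what the outer agent now sees in its prompt:

# Illustrative: inspect the single tool exposed to the outer agent
from langchain.tools.render import render_text_description
from rag_app.structured_tools.structured_tools import web_research

print(web_research.name)                        # "web_research", fills {tool_names}
print(render_text_description([web_research]))  # "web_research: <docstring>", fills {tools}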
rag_app/loading_data/load_S3_vector_stores.py
CHANGED
@@ -10,7 +10,6 @@ from dotenv import load_dotenv
 import os
 import sys
 import logging
-from pathlib import Path
 
 # Load environment variables from a .env file
 config = load_dotenv(".env")
@@ -39,7 +38,6 @@ def get_faiss_vs():
 
     # Define the destination for the downloaded file
     VS_DESTINATION = FAISS_INDEX_PATH + ".zip"
-
    try:
         # Download the pre-prepared vectorized index from the S3 bucket
         print("Downloading the pre-prepared FAISS vectorized index from S3...")
@@ -49,36 +47,11 @@ def get_faiss_vs():
         with zipfile.ZipFile(VS_DESTINATION, 'r') as zip_ref:
             zip_ref.extractall('./vectorstore/')
         print("Download and extraction completed.")
-        return FAISS.load_local(FAISS_INDEX_PATH, embeddings,allow_dangerous_deserialization=True)
+        return FAISS.load_local(FAISS_INDEX_PATH, embeddings, allow_dangerous_deserialization=True)
 
     except Exception as e:
         print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
-        #
-
-
-def get_faiss_vs_from_s3(s3_loc:str,
-                         s3_vs_name:str,
-                         vs_dir:str='vectorstore') -> None:
-    """ Download the FAISS vector store from S3 bucket
-
-    Args:
-        s3_loc (str): Name of the S3 bucket
-        s3_vs_name (str): Name of the file to be downloaded
-        vs_dir (str): The name of the directory where the file is to be saved
-    """
-    # Initialize an S3 client with unsigned configuration for public access
-    s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
-    # Destination folder
-    vs_dir_path = Path("..") / vs_dir
-    assert vs_dir_path.is_dir(), "Cannot find vs_dir folder"
-    try:
-        vs_destination = Path("..") / vs_dir / "faiss-insurance-agent-500.zip"
-        s3.download_file(s3_loc, s3_vs_name, vs_destination)
-        # Extract the downloaded zip file
-        with zipfile.ZipFile(file=vs_destination, mode='r') as zip_ref:
-            zip_ref.extractall(path=vs_dir_path.as_posix())
-    except Exception as e:
-        print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
+        #faissdb = FAISS.load_local(FAISS_INDEX_PATH, embeddings)
 
 
 ## Chroma DB
@@ -97,10 +70,4 @@ def get_chroma_vs():
         chromadb = Chroma(persist_directory=CHROMA_DIRECTORY, embedding_function=embeddings)
         chromadb.get()
     except Exception as e:
-        print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
-
-
-if __name__ == "__main__":
-    # get_faiss_vs_from_s3(s3_loc=S3_LOCATION, s3_vs_name=FAISS_VS_NAME)
-    pass
-
+        print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
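Since get_faiss_vs() returns the loaded FAISS store, a caller can go straight to retrieval. A sketch (illustrative, not part of the commit), assuming the S3 download succeeds and the module's embeddings are configured:

# Illustrative: consume the FAISS store returned by get_faiss_vs()
from rag_app.loading_data.load_S3_vector_stores import get_faiss_vs

faissdb = get_faiss_vs()  # downloads, unzips and loads the index
for doc in faissdb.similarity_search("household insurance coverage", k=3):
    print(doc.page_content[:200])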
rag_app/structured_tools/agent_tools.py
ADDED
File without changes
rag_app/structured_tools/structured_tools.py
CHANGED
@@ -17,6 +17,7 @@ from rag_app.utils.utils import (
 from rag_app.database.db_handler import (
     add_many
 )
+from rag_app.agents.kb_retriever_agent import agent_executor
 
 import os
 # from innovation_pathfinder_ai.utils import create_wikipedia_urls_from_text
@@ -71,7 +72,7 @@ def knowledgeBase_search(query:str) -> str:
         embedding_function=embedding_function,
     )
 
-    retriever = vector_db.as_retriever()
+    retriever = vector_db.as_retriever(search_kwargs={'k':1})
     # This is deprecated, changed to invoke
     # LangChainDeprecationWarning: The method `BaseRetriever.get_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 0.3.0. Use invoke instead.
     docs = retriever.invoke(query)
@@ -83,7 +84,6 @@ def knowledgeBase_search(query:str) -> str:
 @tool
 def google_search(query: str) -> str:
     """Improve the results with a search of the insurance company's website. Formulate a new search query to improve the chances of success."""
-    global all_sources
 
     websearch = GoogleSearchAPIWrapper()
     search_results:dict = websearch.results(query, 3)
@@ -95,4 +95,16 @@ def google_search(query: str) -> str:
     else:
         cleaner_sources = search_results
 
-    return cleaner_sources.__str__()
+    return cleaner_sources.__str__()
+
+@tool
+def web_research(query: str) -> str:
+    """Improve the results with a search of the insurance company's website. Formulate a new search query to improve the chances of success."""
+
+    result = agent_executor.invoke(
+        {
+            "input": query
+        }
+    )
+    print(result)
+    return result.__str__()
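Because web_research is decorated with @tool, it is itself a LangChain tool and can be exercised on its own before wiring it into the outer agent. A sketch (illustrative, not part of the commit), assuming valid credentials in .env:

# Illustrative: call the delegating tool directly
from rag_app.structured_tools.structured_tools import web_research

answer = web_research.invoke("What does the household insurance cover?")  # runs the nested agent
print(answer)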
rag_app/templates/react_json_ger.py
ADDED
@@ -0,0 +1,47 @@
+template_system = """
+Answer the following questions as best you can. You have access to the following tools:
+
+<TOOLS>
+{tools}
+</TOOLS>
+
+The way you use the tools is by specifying a json blob.
+Specifically, this json should have an `action` key (with the name of the tool to use) and an `action_input` key (with the input to the tool going here).
+
+The only values that should be in the "action" field are: {tool_names}
+
+The $JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB:
+
+```
+{{
+    "action": $TOOL_NAME,
+    "action_input": $INPUT
+}}
+```
+
+ALWAYS use the following format:
+
+Question: the input question you must answer
+Thought: you should always think about what to do
+Action:
+```
+$JSON_BLOB
+```
+Observation: the result of the action
+... (this Thought/Action/Observation can repeat N times)
+Thought: I now know the final answer
+Final Answer: the final answer to the original input question
+
+Begin! Reminder to always use the exact characters `Final Answer` when responding.
+
+Previous conversation history:
+<CONVERSATION_HISTORY>
+{chat_history}
+</CONVERSATION_HISTORY>
+
+<NEW_INPUT>
+{input}
+</NEW_INPUT>
+
+{agent_scratchpad}
+"""
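The new template carries the same placeholders the retriever agent fills in kb_retriever_agent.py, so swapping it in is mechanical. A sketch (illustrative, not part of the commit) of wiring react_json_ger into an agent prompt:

# Illustrative: render the new template the same way kb_retriever_agent.py does
from langchain.prompts import PromptTemplate
from langchain.tools.render import render_text_description
from rag_app.structured_tools.structured_tools import google_search, knowledgeBase_search
from rag_app.templates.react_json_ger import template_system

tools = [knowledgeBase_search, google_search]
prompt = PromptTemplate.from_template(template_system).partial(
    tools=render_text_description(tools),
    tool_names=", ".join([t.name for t in tools]),
)
# {chat_history}, {input} and {agent_scratchpad} remain to be supplied at run time.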