Spaces:
Sleeping
Sleeping
mriusero
commited on
Commit
·
cdb8feb
1
Parent(s):
3992c41
core: clear status
Browse files- prompt.md +6 -4
- src/inference.py +28 -23
- src/tools/retrieve_knowledge.py +10 -7
- src/utils/__init__.py +2 -1
- src/utils/vector_store.py +51 -1
- src/workflow.py +17 -2
prompt.md
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
You are a general AI assistant. I will ask you a question. Report your thoughts, and finish
|
2 |
your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
|
|
|
|
|
3 |
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of
|
4 |
numbers and/or strings.
|
5 |
If you are asked for a number, don’t use comma to write your number neither use units such as $ or percent
|
6 |
sign unless specified otherwise.
|
7 |
-
If you are asked for a string, don’t use articles, neither abbreviations (e.g. for cities)
|
8 |
-
|
9 |
-
If
|
10 |
-
in the list is a number or a string.
|
|
|
1 |
You are a general AI assistant. I will ask you a question. Report your thoughts, and finish
|
2 |
your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
|
3 |
+
First, search in your knowledge base the final answer of the question by query the full question.
|
4 |
+
Then, if you find the answer, report it. If you do not find the answer, think about the question and try to solve it step by step with other tools.
|
5 |
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of
|
6 |
numbers and/or strings.
|
7 |
If you are asked for a number, don’t use comma to write your number neither use units such as $ or percent
|
8 |
sign unless specified otherwise.
|
9 |
+
If you are asked for a string, don’t use articles, neither abbreviations (e.g. for cities).
|
10 |
+
If the final answer is just one word, put the first letter in uppercase and the rest in lowercase (e.g. "hello" -> "Hello").
|
11 |
+
If the final answer is a list of strings, write all in lowercase and separate words with comma and space (e.g. "Fruits,Vegetables,Drinks" -> "fruits, vegetables, drinks" or "85,62,18" -> "85, 62, 18").
|
12 |
+
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
|
src/inference.py
CHANGED
@@ -49,7 +49,8 @@ class Agent:
|
|
49 |
|
50 |
}
|
51 |
self.log = []
|
52 |
-
self.
|
|
|
53 |
|
54 |
@staticmethod
|
55 |
def save_log(messages, task_id, truth, final_answer=None):
|
@@ -62,26 +63,30 @@ class Agent:
|
|
62 |
)
|
63 |
|
64 |
@staticmethod
|
65 |
-
def get_tools():
|
66 |
"""Generate the tools.json file with the tools to be used by the agent."""
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
85 |
|
86 |
def make_initial_request(self, input):
|
87 |
"""Make the initial request to the agent with the given input."""
|
@@ -92,7 +97,7 @@ class Agent:
|
|
92 |
{"role": "user", "content": input},
|
93 |
{
|
94 |
"role": "assistant",
|
95 |
-
"content": "Let's tackle this problem,
|
96 |
"prefix": True,
|
97 |
},
|
98 |
]
|
@@ -104,7 +109,7 @@ class Agent:
|
|
104 |
"stop": None,
|
105 |
"random_seed": None,
|
106 |
"response_format": None,
|
107 |
-
"tools": self.
|
108 |
"tool_choice": 'auto',
|
109 |
"presence_penalty": 0,
|
110 |
"frequency_penalty": 0,
|
@@ -203,6 +208,6 @@ class Agent:
|
|
203 |
response = self.client.agents.complete(
|
204 |
agent_id=self.agent_id,
|
205 |
messages=messages,
|
206 |
-
tools=self.
|
207 |
tool_choice='auto',
|
208 |
)
|
|
|
49 |
|
50 |
}
|
51 |
self.log = []
|
52 |
+
self.first_tools = self.get_tools(first=True)
|
53 |
+
self.all_tools = self.get_tools(first=False)
|
54 |
|
55 |
@staticmethod
|
56 |
def save_log(messages, task_id, truth, final_answer=None):
|
|
|
63 |
)
|
64 |
|
65 |
@staticmethod
|
66 |
+
def get_tools(first=None):
|
67 |
"""Generate the tools.json file with the tools to be used by the agent."""
|
68 |
+
if first:
|
69 |
+
return generate_tools_json(
|
70 |
+
[retrieve_knowledge]
|
71 |
+
).get('tools')
|
72 |
+
else:
|
73 |
+
return generate_tools_json(
|
74 |
+
[
|
75 |
+
web_search,
|
76 |
+
visit_webpage,
|
77 |
+
retrieve_knowledge,
|
78 |
+
# load_file,
|
79 |
+
reverse_text,
|
80 |
+
analyze_chess,
|
81 |
+
# analyze_document,
|
82 |
+
classify_foods,
|
83 |
+
transcribe_audio,
|
84 |
+
execute_code,
|
85 |
+
analyze_excel,
|
86 |
+
analyze_youtube_video,
|
87 |
+
calculate_sum,
|
88 |
+
]
|
89 |
+
).get('tools')
|
90 |
|
91 |
def make_initial_request(self, input):
|
92 |
"""Make the initial request to the agent with the given input."""
|
|
|
97 |
{"role": "user", "content": input},
|
98 |
{
|
99 |
"role": "assistant",
|
100 |
+
"content": "Let's tackle this problem, ",
|
101 |
"prefix": True,
|
102 |
},
|
103 |
]
|
|
|
109 |
"stop": None,
|
110 |
"random_seed": None,
|
111 |
"response_format": None,
|
112 |
+
"tools": self.all_tools,
|
113 |
"tool_choice": 'auto',
|
114 |
"presence_penalty": 0,
|
115 |
"frequency_penalty": 0,
|
|
|
208 |
response = self.client.agents.complete(
|
209 |
agent_id=self.agent_id,
|
210 |
messages=messages,
|
211 |
+
tools=self.all_tools,
|
212 |
tool_choice='auto',
|
213 |
)
|
src/tools/retrieve_knowledge.py
CHANGED
@@ -2,7 +2,7 @@ from src.utils.tooling import tool
|
|
2 |
|
3 |
def format_the(query, results):
|
4 |
|
5 |
-
if results == "No relevant data found in the knowledge database. Have you checked any webpages? If so, please try to find more relevant data.":
|
6 |
return results
|
7 |
else:
|
8 |
formatted_text = f"# Knowledge for '{query}' \n\n"
|
@@ -10,9 +10,8 @@ def format_the(query, results):
|
|
10 |
try:
|
11 |
for i in range(len(results['documents'])):
|
12 |
formatted_text += f"## Document {i + 1} ---\n"
|
13 |
-
formatted_text += f"- Title: {results['metadatas'][i]['title']}\n"
|
14 |
-
formatted_text += f"- URL: {results['metadatas'][i]['url']}\n"
|
15 |
formatted_text += f"- Content: '''\n{results['documents'][i]}\n'''\n"
|
|
|
16 |
formatted_text += f"---\n\n"
|
17 |
except Exception as e:
|
18 |
return f"Error: Index out of range. Please check the results structure. {str(e)}"
|
@@ -28,15 +27,19 @@ def retrieve_knowledge(query: str, n_results: int = 2) -> str:
|
|
28 |
"""
|
29 |
try:
|
30 |
from src.utils.vector_store import retrieve_from_database
|
31 |
-
distance_threshold = 0.
|
32 |
results = retrieve_from_database(
|
33 |
query=query,
|
34 |
n_results=n_results,
|
35 |
distance_threshold=distance_threshold
|
36 |
)
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
39 |
|
40 |
except Exception as e:
|
41 |
-
|
|
|
42 |
|
|
|
2 |
|
3 |
def format_the(query, results):
|
4 |
|
5 |
+
if results == "No relevant data found in the knowledge database. Have you checked any webpages or use any tools? If so, please try to find more relevant data.":
|
6 |
return results
|
7 |
else:
|
8 |
formatted_text = f"# Knowledge for '{query}' \n\n"
|
|
|
10 |
try:
|
11 |
for i in range(len(results['documents'])):
|
12 |
formatted_text += f"## Document {i + 1} ---\n"
|
|
|
|
|
13 |
formatted_text += f"- Content: '''\n{results['documents'][i]}\n'''\n"
|
14 |
+
formatted_text += f"- Metadata: {results['metadatas'][i]}\n"
|
15 |
formatted_text += f"---\n\n"
|
16 |
except Exception as e:
|
17 |
return f"Error: Index out of range. Please check the results structure. {str(e)}"
|
|
|
27 |
"""
|
28 |
try:
|
29 |
from src.utils.vector_store import retrieve_from_database
|
30 |
+
distance_threshold = 0.4
|
31 |
results = retrieve_from_database(
|
32 |
query=query,
|
33 |
n_results=n_results,
|
34 |
distance_threshold=distance_threshold
|
35 |
)
|
36 |
+
results_formatted = format_the(query, results)
|
37 |
+
if results_formatted:
|
38 |
+
return results_formatted
|
39 |
+
else:
|
40 |
+
return "No relevant data found in the knowledge database. Have you checked any webpages or use any tools? If so, please try to find more relevant data."
|
41 |
|
42 |
except Exception as e:
|
43 |
+
print(f"Error retrieving knowledge: {e}")
|
44 |
+
return f"No relevant data found in the knowledge database. Have you checked any webpages or use any tools? If so, please try to find more relevant data."
|
45 |
|
src/utils/__init__.py
CHANGED
@@ -1 +1,2 @@
|
|
1 |
-
from .api import fetch_questions, submit_answers, get_file
|
|
|
|
1 |
+
from .api import fetch_questions, submit_answers, get_file
|
2 |
+
from .vector_store import load_in_vector_db
|
src/utils/vector_store.py
CHANGED
@@ -4,6 +4,7 @@ from mistralai import Mistral
|
|
4 |
import numpy as np
|
5 |
import time
|
6 |
import chromadb
|
|
|
7 |
import json
|
8 |
import hashlib
|
9 |
|
@@ -171,7 +172,7 @@ def retrieve_from_database(query, collection_name=COLLECTION_NAME, n_results=5,
|
|
171 |
"documents": []
|
172 |
}
|
173 |
for i, distance in enumerate(raw_results['distances'][0]):
|
174 |
-
if distance
|
175 |
filtered_results['ids'].append(raw_results['ids'][0][i])
|
176 |
filtered_results['distances'].append(distance)
|
177 |
filtered_results['metadatas'].append(raw_results['metadatas'][0][i])
|
@@ -184,3 +185,52 @@ def retrieve_from_database(query, collection_name=COLLECTION_NAME, n_results=5,
|
|
184 |
return results
|
185 |
else:
|
186 |
return raw_results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import numpy as np
|
5 |
import time
|
6 |
import chromadb
|
7 |
+
from chromadb.config import Settings
|
8 |
import json
|
9 |
import hashlib
|
10 |
|
|
|
172 |
"documents": []
|
173 |
}
|
174 |
for i, distance in enumerate(raw_results['distances'][0]):
|
175 |
+
if distance <= distance_threshold:
|
176 |
filtered_results['ids'].append(raw_results['ids'][0][i])
|
177 |
filtered_results['distances'].append(distance)
|
178 |
filtered_results['metadatas'].append(raw_results['metadatas'][0][i])
|
|
|
185 |
return results
|
186 |
else:
|
187 |
return raw_results
|
188 |
+
|
189 |
+
|
190 |
+
def search_documents(collection_name=COLLECTION_NAME, query=None, query_embedding=None, metadata_filter=None, n_results=10):
|
191 |
+
"""
|
192 |
+
Search for documents in a ChromaDB collection.
|
193 |
+
|
194 |
+
:param collection_name: The name of the collection to search within.
|
195 |
+
:param query: The text query to search for (optional).
|
196 |
+
:param query_embedding: The embedding query to search for (optional).
|
197 |
+
:param metadata_filter: A filter to apply to the metadata (optional).
|
198 |
+
:param n_results: The number of results to return (default is 10).
|
199 |
+
:return: The search results.
|
200 |
+
"""
|
201 |
+
client = chromadb.PersistentClient(path=PERSIST_DIRECTORY)
|
202 |
+
collection = client.get_collection(collection_name)
|
203 |
+
|
204 |
+
if query:
|
205 |
+
query_embedding = vectorize([query])[0]
|
206 |
+
|
207 |
+
if query_embedding:
|
208 |
+
results = collection.query(query_embeddings=[query_embedding], n_results=n_results, where=metadata_filter)
|
209 |
+
else:
|
210 |
+
results = collection.get(where=metadata_filter, limit=n_results)
|
211 |
+
|
212 |
+
return results
|
213 |
+
|
214 |
+
|
215 |
+
def delete_documents(collection_name=COLLECTION_NAME, ids=None):
|
216 |
+
"""
|
217 |
+
Delete documents from a ChromaDB collection based on their IDs.
|
218 |
+
|
219 |
+
:param collection_name: The name of the collection.
|
220 |
+
:param ids: A list of IDs of the documents to delete.
|
221 |
+
"""
|
222 |
+
client = chromadb.PersistentClient(path=PERSIST_DIRECTORY)
|
223 |
+
collection = client.get_collection(collection_name)
|
224 |
+
|
225 |
+
collection.delete(ids=ids)
|
226 |
+
print(f"Documents with IDs {ids} have been deleted from the collection {collection_name}.")
|
227 |
+
|
228 |
+
def delete_collection(collection_name=COLLECTION_NAME):
|
229 |
+
"""
|
230 |
+
Delete a ChromaDB collection.
|
231 |
+
|
232 |
+
:param collection_name: The name of the collection to delete.
|
233 |
+
"""
|
234 |
+
client = chromadb.PersistentClient(path=PERSIST_DIRECTORY)
|
235 |
+
client.delete_collection(collection_name)
|
236 |
+
print(f"Collection {collection_name} has been deleted.")
|
src/workflow.py
CHANGED
@@ -11,6 +11,7 @@ from src.utils import (
|
|
11 |
fetch_questions,
|
12 |
submit_answers,
|
13 |
get_file,
|
|
|
14 |
)
|
15 |
from src.inference import Agent
|
16 |
|
@@ -36,7 +37,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
|
|
36 |
results_log = []
|
37 |
answers_payload = []
|
38 |
|
39 |
-
#chosen_task_id = "
|
40 |
#questions_data = [item for item in questions_data if item.get("task_id") == chosen_task_id]
|
41 |
|
42 |
for item in questions_data:
|
@@ -71,6 +72,20 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
|
|
71 |
truth=final_answer
|
72 |
)
|
73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
console.print(Panel(f"[bold green]Submitted Answer[/bold green]\n{submitted_answer}", expand=False))
|
75 |
console.print(Panel(f"The correct final answer is: [bold]{final_answer}[/bold]"))
|
76 |
|
@@ -88,7 +103,6 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
|
|
88 |
|
89 |
submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
|
90 |
result_data = submit_answers(submission_data)
|
91 |
-
|
92 |
if result_data:
|
93 |
final_status = (
|
94 |
f"Submission Successful!\n"
|
@@ -101,3 +115,4 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
|
|
101 |
return final_status, results_df
|
102 |
else:
|
103 |
return "Submission Failed.", pd.DataFrame(results_log)
|
|
|
|
11 |
fetch_questions,
|
12 |
submit_answers,
|
13 |
get_file,
|
14 |
+
load_in_vector_db,
|
15 |
)
|
16 |
from src.inference import Agent
|
17 |
|
|
|
37 |
results_log = []
|
38 |
answers_payload = []
|
39 |
|
40 |
+
#chosen_task_id = "f918266a-b3e0-4914-865d-4faa564f1aef"
|
41 |
#questions_data = [item for item in questions_data if item.get("task_id") == chosen_task_id]
|
42 |
|
43 |
for item in questions_data:
|
|
|
72 |
truth=final_answer
|
73 |
)
|
74 |
|
75 |
+
if submitted_answer == final_answer:
|
76 |
+
try:
|
77 |
+
load_in_vector_db(
|
78 |
+
markdown_content=f"{question_text}{file_context}\n\nFINAL ANSWER:{submitted_answer}",
|
79 |
+
#metadatas={
|
80 |
+
# "task_id": task_id,
|
81 |
+
# "question": question_text,
|
82 |
+
# "file_name": file_name,
|
83 |
+
#},
|
84 |
+
)
|
85 |
+
console.print(f"Correct answer vectorized and stored")
|
86 |
+
except Exception as e:
|
87 |
+
console.print(f"Error loading in vector DB: {e}", style="bold red")
|
88 |
+
|
89 |
console.print(Panel(f"[bold green]Submitted Answer[/bold green]\n{submitted_answer}", expand=False))
|
90 |
console.print(Panel(f"The correct final answer is: [bold]{final_answer}[/bold]"))
|
91 |
|
|
|
103 |
|
104 |
submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
|
105 |
result_data = submit_answers(submission_data)
|
|
|
106 |
if result_data:
|
107 |
final_status = (
|
108 |
f"Submission Successful!\n"
|
|
|
115 |
return final_status, results_df
|
116 |
else:
|
117 |
return "Submission Failed.", pd.DataFrame(results_log)
|
118 |
+
|