mriusero committed
Commit 3bcd8f6 · 1 Parent(s): 439e39d

feat: retrieval (1st version)

.gitignore CHANGED
@@ -11,9 +11,13 @@ my-traffic-analysis-441217-32bda1474a0f.json
 # Python
 *__pycache__/
 
-# Model
+# Project
 llm/
 attachments/
 logs/
 1st_run/
-metadata.jsonl
+metadata.jsonl
+tests.py
+
+chroma_db/
+*.bin
prompt.md CHANGED
@@ -2,7 +2,7 @@ You are a general and precise AI assistant. I will ask you a question.
 Report your thoughts, and finish
 your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
 If a tool provide an error, use the tool differently.
-For web searching, ensure your answer by cross-checking data with several sources.
+For web searching, first search in your knowledge and if necessary complete them with web_search and ensure your answer by cross-checking data with several sources.
 YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of
 numbers and/or strings.
 If you are asked for a number, don’t use comma to write your number neither use units such as $ or percent
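
Because prompt.md pins every completion to the `FINAL ANSWER: [...]` template, the caller can recover the answer with a single pattern match. A minimal sketch of such a parser, assuming the completion text is already in hand (the helper name is illustrative, not part of this commit):

import re

def extract_final_answer(completion: str) -> str:
    """Return the text after the 'FINAL ANSWER:' marker mandated by prompt.md."""
    match = re.search(r"FINAL ANSWER:\s*(.+)", completion, re.DOTALL)
    # Fall back to the full completion if the model ignored the template
    return match.group(1).strip() if match else completion.strip()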
src/inference.py CHANGED
@@ -19,6 +19,7 @@ from src.tools import (
     analyze_excel,
     analyze_youtube_video,
     calculate_sum,
+    retrieve_knowledge,
 )
 
 load_dotenv()
@@ -43,6 +44,7 @@ class Agent:
             "analyze_excel": analyze_excel,
             "analyze_youtube_video": analyze_youtube_video,
             "calculate_sum": calculate_sum,
+            "retrieve_knowledge": retrieve_knowledge,
         }
         self.log = []
         self.tools = self.get_tools()
@@ -74,6 +76,7 @@ class Agent:
                 analyze_excel,
                 analyze_youtube_video,
                 calculate_sum,
+                retrieve_knowledge,
             ]
         ).get('tools')
 
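The name-to-callable map registered in `Agent.__init__` pairs each schema name with its Python function, so executing a model-issued tool call reduces to a dictionary lookup plus argument unpacking. A rough sketch of that dispatch, assuming arguments arrive as the JSON string the Mistral API returns (the helper and parameter names are illustrative):

import json

def dispatch_tool_call(functions: dict, name: str, arguments: str) -> str:
    """Look up a registered tool by name and call it with JSON-decoded kwargs."""
    func = functions.get(name)
    if func is None:
        return f"Unknown tool: {name}"
    return func(**json.loads(arguments))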
src/tools/__init__.py CHANGED
@@ -9,4 +9,5 @@ from .transcribe_audio import transcribe_audio
 from .execute_code import execute_code
 from .analyze_excel import analyze_excel
 from .analyze_youtube_video import analyze_youtube_video
-from .calculator import calculate_sum
+from .calculator import calculate_sum
+from .retrieve_knowledge import retrieve_knowledge
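
With this re-export, call sites can pull every tool from the package root rather than from individual modules, which is what src/inference.py relies on:

from src.tools import calculate_sum, retrieve_knowledge  # package-level imports enabled by __init__.py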
src/tools/retrieve_knowledge.py ADDED
@@ -0,0 +1,22 @@
+from src.utils.tooling import tool
+
+@tool
+def retrieve_knowledge(query: str, n_results: int = 5, distance_threshold: float = 0.5) -> str:
+    """
+    Retrieves knowledge from a database with a provided query.
+    Args:
+        query (str): The query to search for in the vector store.
+        n_results (int, optional): The number of results to return. Default is 5.
+        distance_threshold (float, optional): The minimum distance score for results. Default is 0.5.
+    """
+    try:
+        from src.utils.vector_store import retrieve_from_database
+        results = retrieve_from_database(
+            query=query,
+            n_results=n_results,
+            distance_threshold=distance_threshold
+        )
+        return str(results)
+
+    except Exception as e:
+        return f"An unexpected error occurred: {str(e)}"
src/tools/visit_webpage.py CHANGED
@@ -1,6 +1,5 @@
-import re
-
 from src.utils.tooling import tool
+from src.utils.vector_store import vectorize, load_in_vector_db
 
 @tool
 def visit_webpage(url: str) -> str:
@@ -11,6 +10,7 @@ def visit_webpage(url: str) -> str:
         url (str): The URL of the webpage to visit.
     """
     try:
+        import re
         import requests
         from markdownify import markdownify
         from requests.exceptions import RequestException
@@ -28,6 +28,14 @@ def visit_webpage(url: str) -> str:
         markdown_content = markdownify(response.text).strip()  # Convert the HTML content to Markdown
         markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)  # Remove multiple line breaks
 
+        # Adding metadata
+        metadatas = {
+            "url": url,
+        }
+
+        text_embeddings, chunks = vectorize(markdown_content)  # Vectorize the content
+        load_in_vector_db(text_embeddings, chunks, metadatas=metadatas)  # Load the embeddings into the ChromaDB collection
+
         return truncate_content(markdown_content, 10000)
 
     except requests.exceptions.Timeout:
@@ -37,4 +45,4 @@ def visit_webpage(url: str) -> str:
         return f"Error fetching the webpage: {str(e)}"
 
     except Exception as e:
-        return f"An unexpected error occurred: {str(e)}"
+        return f"An unexpected error occurred: {str(e)}"
src/utils/vector_store.py ADDED
@@ -0,0 +1,124 @@
+import os
+from dotenv import load_dotenv
+from mistralai import Mistral
+import numpy as np
+import time
+import chromadb
+import json
+
+load_dotenv()
+MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
+COLLECTION_NAME = "webpages_collection"
+PERSIST_DIRECTORY = "./chroma_db"
+
+def get_text_embeddings(input_texts):
+    """
+    Get the text embeddings for the given inputs using Mistral API.
+    """
+    client = Mistral(api_key=MISTRAL_API_KEY)
+    while True:
+        try:
+            embeddings_batch_response = client.embeddings.create(
+                model="mistral-embed",
+                inputs=input_texts
+            )
+            return [data.embedding for data in embeddings_batch_response.data]
+        except Exception as e:
+            if "rate limit exceeded" in str(e).lower():
+                print("Rate limit exceeded. Retrying after 1 second...")
+                time.sleep(1)
+            else:
+                raise
+
+def vectorize(markdown_content, chunk_size=2048):
+    """
+    Vectorizes the given markdown content into chunks of specified size.
+    """
+    chunks = [markdown_content[i:i + chunk_size] for i in range(0, len(markdown_content), chunk_size)]
+    text_embeddings = get_text_embeddings(chunks)
+    return np.array(text_embeddings), chunks
+
+def load_in_vector_db(text_embeddings, chunks, metadatas=None, collection_name=COLLECTION_NAME):
+    """
+    Load the text embeddings into a ChromaDB collection for efficient similarity search.
+    """
+    client = chromadb.PersistentClient(path=PERSIST_DIRECTORY)
+
+    # Check if the collection exists, if not, create it
+    if collection_name not in [col.name for col in client.list_collections()]:
+        collection = client.create_collection(collection_name)
+    else:
+        collection = client.get_collection(collection_name)
+
+    for embedding, chunk in zip(text_embeddings, chunks):
+        collection.add(
+            embeddings=[embedding],
+            documents=[chunk],
+            metadatas=[metadatas],
+            ids=[str(hash(chunk))]
+        )
+
+
+def see_database(collection_name=COLLECTION_NAME):
+    """
+    Load the ChromaDB collection and text chunks.
+    """
+    client = chromadb.PersistentClient(path=PERSIST_DIRECTORY)
+
+    if collection_name not in [col.name for col in client.list_collections()]:
+        print("Collection not found. Please ensure it is created.")
+        return
+
+    collection = client.get_collection(collection_name)
+
+    items = collection.get()
+
+    print(f"Type of items: {type(items)}")
+    print(f"Items: {items}")
+
+    for item in items:
+        print(f"Type of item: {type(item)}")
+        print(f"Item: {item}")
+
+        if isinstance(item, dict):
+            print(f"ID: {item.get('ids')}")
+            print(f"Document: {item.get('document')}")
+            print(f"Metadata: {item.get('metadata')}")
+        else:
+            print("Item is not a dictionary")
+
+        print("---")
+
+def retrieve_from_database(query, collection_name=COLLECTION_NAME, n_results=5, distance_threshold=None):
+    """
+    Retrieve the most similar documents from the vector store based on the query.
+    """
+    client = chromadb.PersistentClient(path=PERSIST_DIRECTORY)
+    collection = client.get_collection(collection_name)
+    query_embeddings = get_text_embeddings([query])
+    raw_results = collection.query(
+        query_embeddings=query_embeddings,
+        n_results=n_results,
+        include=["documents", "metadatas", "distances"]
+    )
+    if distance_threshold is not None:
+        filtered_results = {
+            "ids": [],
+            "distances": [],
+            "metadatas": [],
+            "documents": []
+        }
+        for i, distance in enumerate(raw_results['distances'][0]):
+            if distance >= distance_threshold:
+                filtered_results['ids'].append(raw_results['ids'][0][i])
+                filtered_results['distances'].append(distance)
+                filtered_results['metadatas'].append(raw_results['metadatas'][0][i])
+                filtered_results['documents'].append(raw_results['documents'][0][i])
+        results = filtered_results
+
+        if len(results['documents']) == 0:
+            return "No relevant data found in knowledge database, have you visited webpages?"
+        else:
+            return json.dumps(results, indent=4)
+    else:
+        return json.dumps(raw_results, indent=4)
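
End to end, the module slices markdown into fixed 2048-character chunks (a ~7,200-character page therefore yields four chunks), embeds each chunk with mistral-embed, and persists everything under ./chroma_db. A minimal smoke test of that pipeline, assuming MISTRAL_API_KEY is set (the sample text is illustrative):

from src.utils.vector_store import vectorize, load_in_vector_db, retrieve_from_database

embeddings, chunks = vectorize("lorem ipsum " * 600)   # ~7,200 chars -> 4 chunks of <= 2048 chars
load_in_vector_db(embeddings, chunks, metadatas={"url": "https://example.com"})
print(retrieve_from_database("lorem", n_results=2))    # JSON string of the nearest chunks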
tools.json CHANGED
@@ -253,5 +253,32 @@
         ]
       }
     }
+  },
+  {
+    "type": "function",
+    "function": {
+      "name": "retrieve_knowledge",
+      "description": "Retrieves knowledge from a database with a provided query.",
+      "parameters": {
+        "type": "object",
+        "properties": {
+          "query": {
+            "type": "string",
+            "description": "The query to search for in the vector store."
+          },
+          "n_results": {
+            "type": "integer",
+            "description": "The number of results to return. Default is 5."
+          },
+          "distance_threshold": {
+            "type": "number",
+            "description": "The minimum distance score for results. Default is 0.5."
+          }
+        },
+        "required": [
+          "query"
+        ]
+      }
+    }
   }
 ]
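
This schema is what the model sees when deciding whether to call `retrieve_knowledge`. A hedged sketch of wiring tools.json into a Mistral chat request (the model name is an assumption; `chat.complete` matches the `mistralai` v1 client used elsewhere in this commit):

import json
import os
from mistralai import Mistral

client = Mistral(api_key=os.getenv("MISTRAL_API_KEY"))
with open("tools.json") as f:
    tools = json.load(f)

response = client.chat.complete(
    model="mistral-large-latest",  # assumption: any tool-capable Mistral model works here
    messages=[{"role": "user", "content": "What do the visited pages say about X?"}],
    tools=tools,
)
print(response.choices[0].message.tool_calls)  # tool calls the model decided to make, if any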