Commit ae1d0b9
Parent(s): v1
- .gitattributes +36 -0
- .gitignore +11 -0
- README.md +16 -0
- app.py +238 -0
- build/lib/src/__init__.py +0 -0
- build/lib/src/agent.py +371 -0
- build/lib/src/chess.py +235 -0
- build/lib/src/final_answer.py +210 -0
- build/lib/src/tools.py +675 -0
- build/lib/tools/__init__.py +0 -0
- build/lib/tools/chess.py +241 -0
- build/lib/tools/model.py +27 -0
- build/lib/tools/model_chess.py +188 -0
- build/lib/tools/test.py +15 -0
- build/lib/tools/test_1.py +12 -0
- pytest.ini +3 -0
- requirements.txt +2 -0
- setup.py +15 -0
- src.egg-info/PKG-INFO +8 -0
- src.egg-info/SOURCES.txt +20 -0
- src.egg-info/dependency_links.txt +1 -0
- src.egg-info/requires.txt +2 -0
- src.egg-info/top_level.txt +2 -0
- src/__init__.py +0 -0
- src/agent.py +367 -0
- src/final_answer.py +210 -0
- src/tools.py +751 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,11 @@
+# Folders
+
+__pycache__/
+huggingface_env/
+.git/
+
+# Files
+.env
+.pylintrc
+pyrightconfig.json
+
README.md
ADDED
@@ -0,0 +1,16 @@
+---
+title: Template Final Assignment
+emoji: 🕵🏻♂️
+colorFrom: indigo
+colorTo: indigo
+sdk: gradio
+sdk_version: 5.25.2
+app_file: app.py
+pinned: false
+hf_oauth: true
+# optional, default duration is 8 hours/480 minutes. Max duration is 30 days/43200 minutes.
+hf_oauth_expiration_minutes: 480
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
app.py
ADDED
@@ -0,0 +1,238 @@
+import os
+
+import gradio as gr
+import pandas as pd
+import requests
+from dotenv import load_dotenv
+
+from src.agent import BasicAgent
+
+load_dotenv()
+
+# (Keep Constants as is)
+# --- Constants ---
+DEFAULT_API_URL = os.getenv("DEFAULT_API_URL")
+
+
+def run_and_submit_all(profile: gr.OAuthProfile):
+    """
+    Fetches all questions, runs the BasicAgent on them, submits all answers,
+    and displays the results.
+    """
+    # --- Determine HF Space Runtime URL and Repo URL ---
+    space_id = os.getenv("SPACE_ID")  # Get the SPACE_ID for sending link to the code
+
+    if profile:
+        username = f"{profile.username}"
+        print(f"User logged in: {username}")
+    else:
+        print("User not logged in.")
+        return "Please Login to Hugging Face with the button.", None
+
+    api_url = DEFAULT_API_URL
+    questions_url = f"{api_url}/questions"
+    submit_url = f"{api_url}/submit"
+
+    # 1. Instantiate Agent (modify this part to create your agent)
+    try:
+        agent = BasicAgent()
+    except Exception as e:
+        print(f"Error instantiating agent: {e}")
+        return f"Error initializing agent: {e}", None
+    # In the case of an app running as a Hugging Face Space, this link points toward your codebase (useful for others, so please keep it public)
+    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
+    print(agent_code)
+
+    # 2. Fetch Questions
+    print(f"Fetching questions from: {questions_url}")
+    response = None
+    try:
+        response = requests.get(questions_url, timeout=15)
+        response.raise_for_status()
+        questions_data = response.json()
+        if not questions_data:
+            print("Fetched questions list is empty.")
+            return "Fetched questions list is empty or invalid format.", None
+        print(f"Fetched {len(questions_data)} questions.")
+    except requests.exceptions.JSONDecodeError as e:
+        print(f"Error decoding JSON response from questions endpoint: {e}")
+        if response is not None:
+            print(f"Response text: {response.text[:500]}")
+        return f"Error decoding server response for questions: {e}", None
+    except requests.exceptions.RequestException as e:
+        print(f"Error fetching questions: {e}")
+        return f"Error fetching questions: {e}", None
+    except Exception as e:
+        print(f"An unexpected error occurred fetching questions: {e}")
+        return f"An unexpected error occurred fetching questions: {e}", None
+
+    # 3. Run your Agent
+    results_log = []
+    answers_payload = []
+    print(f"Running agent on {len(questions_data)} questions...")
+    # Limit the number of questions to process to avoid timeouts
+    max_questions = 20  # Process only 20 questions at a time
+
+    tasks_to_process = [
+        # "99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3",
+        # "1f975693-876d-457b-a649-393859e79bf3",
+        # "840bfca7-4f7b-481a-8794-c560c340185d",
+        # "7bd855d8-463d-4ed5-93ca-5fe35145f733",
+        # "f918266a-b3e0-4914-865d-4faa564f1aef"
+    ]
+
+    questions_to_process = questions_data[:max_questions]
+
+    if tasks_to_process:
+        questions_to_process = [
+            x for x in questions_data if x.get("task_id") in tasks_to_process
+        ]
+    else:
+        questions_to_process = questions_data[:max_questions]
+
+    count = 0
+    # for item in questions_data:
+    for item in questions_to_process:
+        task_id = item.get("task_id")
+        question_text = item.get("question")
+        count += 1
+        if not task_id or question_text is None:
+            print(f"Skipping item with missing task_id or question: {item}")
+            continue
+        try:
+            print(f"{count}. Processing Task {task_id}: {question_text}")
+            submitted_answer = agent(question_text, task_id)
+            answers_payload.append(
+                {"task_id": task_id, "submitted_answer": submitted_answer}
+            )
+            results_log.append(
+                {
+                    "Task ID": task_id,
+                    "Question": question_text,
+                    "Submitted Answer": submitted_answer,
+                }
+            )
+        except Exception as e:
+            print(f"Error running agent on task {task_id}: {e}")
+            results_log.append(
+                {
+                    "Task ID": task_id,
+                    "Question": question_text,
+                    "Submitted Answer": f"AGENT ERROR: {e}",
+                }
+            )
+
+    if not answers_payload:
+        print("Agent did not produce any answers to submit.")
+        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
+
+    # 4. Prepare Submission
+    submission_data = {
+        "username": username.strip(),
+        "agent_code": agent_code,
+        "answers": answers_payload,
+    }
+    status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
+    print(status_update)
+
+    # 5. Submit
+    print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
+    try:
+        response = requests.post(submit_url, json=submission_data, timeout=60)
+        response.raise_for_status()
+        result_data = response.json()
+        final_status = (
+            f"Submission Successful!\n"
+            f"User: {result_data.get('username')}\n"
+            f"Overall Score: {result_data.get('score', 'N/A')}% "
+            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
+            f"Message: {result_data.get('message', 'No message received.')}"
+        )
+        print("Submission successful.")
+        results_df = pd.DataFrame(results_log)
+        return final_status, results_df
+    except requests.exceptions.HTTPError as e:
+        error_detail = f"Server responded with status {e.response.status_code}."
+        try:
+            error_json = e.response.json()
+            error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
+        except requests.exceptions.JSONDecodeError:
+            error_detail += f" Response: {e.response.text[:500]}"
+        status_message = f"Submission Failed: {error_detail}"
+        print(status_message)
+        results_df = pd.DataFrame(results_log)
+        return status_message, results_df
+    except requests.exceptions.Timeout:
+        status_message = "Submission Failed: The request timed out."
+        print(status_message)
+        results_df = pd.DataFrame(results_log)
+        return status_message, results_df
+    except requests.exceptions.RequestException as e:
+        status_message = f"Submission Failed: Network error - {e}"
+        print(status_message)
+        results_df = pd.DataFrame(results_log)
+        return status_message, results_df
+    except Exception as e:
+        status_message = f"An unexpected error occurred during submission: {e}"
+        print(status_message)
+        results_df = pd.DataFrame(results_log)
+        return status_message, results_df
+
+
+# --- Build Gradio Interface using Blocks ---
+with gr.Blocks() as demo:
+    gr.Markdown("# Basic Agent Evaluation Runner")
+    gr.Markdown(
+        """
+        **Instructions:**
+
+        1. Please clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc.
+        2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
+        3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
+
+        ---
+        **Disclaimers:**
+        Once you click the "submit" button, it can take quite some time (this is the time the agent needs to go through all the questions).
+        This space provides a basic setup and is intentionally sub-optimal to encourage you to develop your own, more robust solution. For instance, to avoid the long wait on the submit button, you could cache the answers and submit them in a separate action, or even answer the questions asynchronously.
+        """
+    )
+
+    gr.LoginButton()
+
+    run_button = gr.Button("Run Evaluation & Submit All Answers")
+
+    status_output = gr.Textbox(
+        label="Run Status / Submission Result", lines=5, interactive=False
+    )
+    # Removed max_rows=10 from DataFrame constructor
+    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
+
+    run_button.click(fn=run_and_submit_all, outputs=[status_output, results_table])
+
+if __name__ == "__main__":
+    print("\n" + "-" * 30 + " App Starting " + "-" * 30)
+    # Check for SPACE_HOST and SPACE_ID at startup for information
+    space_host_startup = os.getenv("SPACE_HOST")
+    space_id_startup = os.getenv("SPACE_ID")  # Get SPACE_ID at startup
+
+    if space_host_startup:
+        print(f"✅ SPACE_HOST found: {space_host_startup}")
+        print(f"   Runtime URL should be: https://{space_host_startup}.hf.space")
+    else:
+        print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
+
+    if space_id_startup:  # Print repo URLs if SPACE_ID is found
+        print(f"✅ SPACE_ID found: {space_id_startup}")
+        print(f"   Repo URL: https://huggingface.co/spaces/{space_id_startup}")
+        print(
+            f"   Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main"
+        )
+    else:
+        print(
+            "ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined."
+        )
+
+    print("-" * (60 + len(" App Starting ")) + "\n")
+
+    print("Launching Gradio Interface for Basic Agent Evaluation...")
+    demo.launch(debug=True, share=False)
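For local debugging, a minimal sketch of the same fetch-and-answer flow without the Gradio UI or the submission step; it assumes DEFAULT_API_URL is set in .env and that src.agent.BasicAgent can reach its Ollama backend:

```python
# Minimal local sketch (assumptions: DEFAULT_API_URL is set in .env and
# BasicAgent can reach its Ollama backend). Fetches one question, runs the
# agent, prints the answer, and submits nothing.
import os

import requests
from dotenv import load_dotenv

from src.agent import BasicAgent

load_dotenv()
api_url = os.getenv("DEFAULT_API_URL")

questions = requests.get(f"{api_url}/questions", timeout=15).json()
item = questions[0]  # try a single task first

agent = BasicAgent()
print(f"Task {item['task_id']}: {agent(item['question'], item['task_id'])}")
```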
build/lib/src/__init__.py
ADDED
File without changes
build/lib/src/agent.py
ADDED
@@ -0,0 +1,371 @@
+import gc
+import logging
+import os
+import tempfile
+from typing import Optional
+
+import torch
+from dotenv import load_dotenv
+from langchain.agents import AgentExecutor, create_tool_calling_agent
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.rate_limiters import InMemoryRateLimiter
+from langchain_core.tools import Tool
+from langchain_experimental.utilities import PythonREPL
+
+# from langchain_community.tools import DuckDuckGoSearchResults
+# from langchain_community.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
+# from langchain_google_community import GoogleSearchAPIWrapper, GoogleSearchResults
+from langchain_ollama import ChatOllama
+
+from src.final_answer import create_final_answer_graph, validate_answer
+from src.tools import (
+    analyze_csv_file,
+    analyze_excel_file,
+    download_file_from_url,
+    duckduckgo_search,
+    extract_text_from_image,
+    read_file,
+    reverse_decoder,
+    review_youtube_video,
+    transcribe_audio,
+    transcribe_youtube,
+    use_vision_model,
+    video_frames_to_images,
+    website_scrape,
+)
+
+logger = logging.getLogger(__name__)
+
+load_dotenv()
+
+base_url = os.getenv("OLLAMA_BASE_URL")
+
+rate_limiter = InMemoryRateLimiter(requests_per_second=0.1)
+
+
+class BasicAgent:
+    def __init__(self):
+        try:
+            logger.info("Initializing BasicAgent")
+
+            # Create the prompt template
+            prompt = ChatPromptTemplate.from_messages(
+                [
+                    (
+                        "system",
+                        """You are a general AI assistant. I will ask you a
+                        question. Report your thoughts, and finish your answer
+                        with the following template: FINAL ANSWER: [YOUR FINAL
+                        ANSWER]. YOUR FINAL ANSWER should be a number OR as few
+                        words as possible OR a comma separated list of numbers
+                        and/or strings. If you are asked for a number, don't
+                        use comma to write your number neither use units such
+                        as $ or percent sign unless specified otherwise. If you
+                        are asked for a string, don't use articles, neither
+                        abbreviations (e.g. for cities), and write the digits
+                        in plain text unless specified otherwise. If you are
+                        asked for a comma separated list, apply the above rules
+                        depending of whether the element to be put in the list
+                        is a number or a string.
+                        """,
+                    ),
+                    ("placeholder", "{chat_history}"),
+                    ("human", "{input}"),
+                    ("placeholder", "{agent_scratchpad}"),
+                ]
+            )
+            logger.info("Created prompt template")
+
+            llm = ChatOllama(
+                model="hf.co/lmstudio-community/Qwen2.5-14B-Instruct-GGUF:Q6_K",
+                base_url=base_url,
+                temperature=0.2,
+            )
+            logger.info("Created model successfully")
+
+            # Define available tools
+            tools = [
+                Tool(
+                    name="DuckDuckGoSearchResults",
+                    description="""Performs a live search using DuckDuckGo
+                    and analyzes the top results. Returns a summary including
+                    result titles, URLs, brief snippets, and ranking
+                    positions. Use this to quickly assess the relevance,
+                    diversity, and quality of information retrieved from a
+                    privacy-focused search engine, without personalized or
+                    biased filtering.""",
+                    # func=DuckDuckGoSearchResults(
+                    #     api_wrapper=DuckDuckGoSearchAPIWrapper()
+                    # ).run,
+                    func=duckduckgo_search,
+                ),
+                # Tool(
+                #     name="GoogleSearchResults",
+                #     description="""Performs a live Google search and analyzes
+                #     the top results. Returns a summary including result titles,
+                #     URLs, brief snippets, and ranking positions. Use this to
+                #     quickly understand the relevance, variety, and quality of
+                #     search results for a given query before deeper research or
+                #     content planning.""",
+                #     func=GoogleSearchResults(
+                #         api_wrapper=GoogleSearchAPIWrapper(
+                #             google_api_key=os.getenv("GOOGLE_SEARCH_API_KEY"),
+                #             google_cse_id=os.getenv("GOOGLE_CSE_ID"),
+                #             k=5,  # Number of results to return
+                #         )
+                #     ).run,
+                # ),
+                Tool(
+                    name="analyze csv file",
+                    description="""Only read and analyze the contents of a CSV
+                    file if one is explicitly referenced or uploaded in the
+                    question. When a CSV file is provided, return a summary of
+                    the dataset, including column names, data types, missing
+                    value counts, basic statistics for numeric fields, and a
+                    preview of the data. Use this only to quickly understand
+                    the structure and quality of the dataset before performing
+                    any further analysis.""",
+                    func=analyze_csv_file,
+                ),
+                Tool(
+                    name="analyze excel file",
+                    description="""Reads and analyzes the contents of an Excel
+                    file (.xlsx or .xls). Returns structured summaries
+                    for each sheet, including column names, data types, missing
+                    value counts, basic statistics for numeric columns, and
+                    sample rows. Use this to quickly explore the structure and
+                    quality of Excel datasets.""",
+                    func=analyze_excel_file,
+                ),
+                Tool(
+                    name="download file from url",
+                    description="""Downloads a file from a given URL and saves
+                    it locally. Supports various file types such as CSV, Excel,
+                    images, and PDFs. Use this to retrieve external resources
+                    for processing or analysis.""",
+                    func=download_file_from_url,
+                ),
+                Tool(
+                    name="extract_text_from_image",
+                    description="""Performs Optical Character Recognition (OCR)
+                    on an image to extract readable text after downloading it.
+                    Supports common image formats (e.g., PNG, JPG). Use this to
+                    digitize printed or handwritten content from images for
+                    search, analysis, or storage.""",
+                    func=extract_text_from_image,
+                ),
+                Tool(
+                    name="read_file",
+                    description="""Reads the raw content of a local text file.
+                    Supports formats such as .txt, .json, .xml, and markdown.
+                    Use this to load unstructured or semi-structured file
+                    content for display, parsing, or further
+                    processing—excluding CSV and Excel formats.""",
+                    func=read_file,
+                ),
+                Tool(
+                    name="review_youtube_video",
+                    description="""Analyzes a YouTube video by extracting key
+                    information such as title, description, view count, likes,
+                    comments, and transcript (if available). Use this to
+                    generate summaries, insights, or sentiment analysis based
+                    on video content and engagement.""",
+                    func=review_youtube_video,
+                ),
+                Tool(
+                    name="transcribe_audio",
+                    description="""Converts spoken words in an audio file into
+                    written text using speech-to-text technology. Supports
+                    common audio formats like MP3, WAV, and FLAC. Use this to
+                    create transcripts for meetings, interviews, podcasts, or
+                    any spoken content.""",
+                    func=transcribe_audio,
+                ),
+                Tool(
+                    name="transcribe_youtube",
+                    description="""Extracts and converts the audio from a
+                    YouTube video into text using speech-to-text technology.
+                    Supports generating transcripts for videos without captions
+                    or subtitles. Use this to obtain searchable, readable text
+                    from YouTube content.""",
+                    func=transcribe_youtube,
+                ),
+                Tool(
+                    name="use_vision_model",
+                    description="""Processes images using a computer vision
+                    model to perform tasks such as object detection, image
+                    classification, or segmentation. Use this to analyze visual
+                    content and extract meaningful information from images.""",
+                    func=use_vision_model,
+                ),
+                Tool(
+                    name="video_frames_to_images",
+                    description="""Extracts individual frames from a video file
+                    and saves them as separate image files. Use this to
+                    analyze, process, or visualize specific moments within
+                    video content. Use this to Youtube Videos""",
+                    func=video_frames_to_images,
+                ),
+                Tool(
+                    name="website_scrape",
+                    description="""It is mandatory to use duckduckgo_search
+                    tool before invoking this tool. Fetches and extracts
+                    content from a specified website URL. Supports retrieving
+                    text, images, links, and other page elements.""",
+                    func=website_scrape,
+                ),
+                Tool(
+                    name="python_repl",
+                    description="""Write full, valid Python code using proper
+                    multi-line code blocks Do not escape newlines (\n)
+                    instead, write each line of code on a separate line Always
+                    use proper indentation and syntax Return results using
+                    print() or return if using a function Avoid partial or
+                    inline code snippets — all code should be runnable in a
+                    Python REPL If the input is a function, include example
+                    usage at the end to ensure output is shown.""",
+                    func=PythonREPL().run,
+                    return_direct=True,
+                ),
+                # Tool(
+                #     name="wiki",
+                #     description="""Retrieves summarized information or
+                #     detailed content from Wikipedia based on a user query.
+                #     Use this to quickly access encyclopedic knowledge and
+                #     relevant facts on a wide range of topics.""",
+                #     func=wiki,
+                # ),
+                Tool(
+                    name="reverse decoder",
+                    description="""Decodes a reversed sentence if the input
+                    appears to be written backward.""",
+                    func=reverse_decoder,
+                ),
+            ]
+            # tools = [wrap_tool_with_limit(tool, max_calls=3) for tool in raw_tools]
+            logger.info("Tools: %s", tools)
+
+            # Create the agent
+            agent = create_tool_calling_agent(llm, tools, prompt)
+            logger.info("Created tool calling agent")
+
+            # Create the agent executor
+            self.agent_executor = AgentExecutor(
+                agent=agent,
+                tools=tools,
+                return_intermediate_steps=True,
+                verbose=True,
+                max_iterations=5,
+            )
+            logger.info("Created agent executor")
+
+            # Create the graph
+            self.validation_graph = create_final_answer_graph()
+
+        except Exception as e:
+            logger.error("Error initializing agent: %s", e, exc_info=True)
+            raise
+
+    def __call__(self, question: str, task_id: str) -> str:
+        """Execute the agent with the given question and optional file.
+        Args:
+            question (str): The question to answer
+            task_id (str): The task ID to fetch the file
+        Returns:
+            str: The final validated answer
+        Raises:
+            Exception: If no valid answer is found after max retries
+        """
+        max_retries = 3
+        attempt = 0
+
+        previous_steps = set()
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            while attempt < max_retries:
+                default_api_url = os.getenv("DEFAULT_API_URL")
+                file_url = f"{default_api_url}/files/{task_id}"
+
+                file: Optional[dict] = None
+                try:
+                    # Download file to temporary directory
+                    file = download_file_from_url.invoke(
+                        {
+                            "url": file_url,
+                            "directory": temp_dir,
+                        }
+                    )
+                    logger.info("Downloaded file: %s", file_url)
+                except Exception:
+                    logger.error(f"no download file available for {task_id} ")
+                    file = None
+
+                try:
+                    attempt += 1
+                    logger.info("Attempt %d of %d", attempt, max_retries)
+
+                    # Prepare input with file information
+                    input_data = {
+                        "input": question
+                        + (
+                            f" [File: type={file.get('type', 'None')}, path={file.get('path', 'None')}]"
+                            if file and file.get("type") != "error"
+                            else ""
+                        ),
+                    }
+
+                    # Run the agent to get the answer
+                    result = self.agent_executor.invoke(input_data)
+                    answer = result.get("output", "")
+                    intermediate_steps = result.get("intermediate_steps", [])
+
+                    steps_str = str(intermediate_steps)
+                    if steps_str in previous_steps:
+                        logger.warning(
+                            f"Detected repeated reasoning steps on attempt {attempt}. Breaking loop to avoid infinite retry."
+                        )
+                        break  # or raise Exception to stop retries
+                    previous_steps.add(steps_str)
+
+                    logger.info("Attempt %d result: %s", attempt, result)
+
+                    # Run validation (self.validation_graph is now StateGraph)
+                    validation_result = validate_answer(
+                        self.validation_graph,  # type: ignore
+                        answer,
+                        [result.get("intermediate_steps", [])],
+                    )
+
+                    valid_answer = validation_result.get("valid_answer", False)
+                    final_answer = validation_result.get("final_answer", "")
+
+                    if valid_answer:
+                        logger.info("Valid answer found on attempt %d", attempt)
+                        torch.cuda.empty_cache()
+                        return final_answer
+
+                    logger.warning(
+                        "Validation failed on attempt %d: %s", attempt, final_answer
+                    )
+                    if attempt >= max_retries:
+                        raise Exception(
+                            "Failed to get valid answer after %d attempts. Last error: %s",
+                            max_retries,
+                            final_answer,
+                        )
+
+                except Exception as e:
+                    logger.error("Error in attempt %d: %s", attempt, e, exc_info=True)
+                    if attempt >= max_retries:
+                        raise Exception(
+                            "Failed after %d attempts. Last error: %s",
+                            max_retries,
+                            str(e),
+                        )
+                    continue
+
+        # Fallback in case loop exits unexpectedly
+
+        torch.cuda.empty_cache()
+        gc.collect()
+        raise Exception("No valid answer found after processing")
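A small sketch of driving the executor directly, which can help when iterating on the prompt or the tool list; it assumes OLLAMA_BASE_URL points at a server that already hosts the Qwen2.5-14B-Instruct GGUF model named above:

```python
# Sketch: call the AgentExecutor without the retry/validation loop.
# Assumes OLLAMA_BASE_URL is set and the Qwen2.5-14B-Instruct GGUF model is served.
from src.agent import BasicAgent

agent = BasicAgent()
result = agent.agent_executor.invoke({"input": "What is the capital of France?"})

print(result["output"])              # expected to end with "FINAL ANSWER: ..."
print(result["intermediate_steps"])  # populated because return_intermediate_steps=True
```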
build/lib/src/chess.py
ADDED
@@ -0,0 +1,235 @@
+import gc
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from torchvision import models, transforms
+
+# Set device
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Load the pre-trained model
+try:
+    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
+    model.fc = torch.nn.Linear(model.fc.in_features, 13)  # 13 classes including 'empty'
+    model.load_state_dict(torch.load("best_chess_piece_model.pth", map_location=device))
+    model.eval()
+    model = model.to(device)
+except Exception as e:
+    print(f"Error loading model: {e}")
+    exit(1)
+
+# Mapping chess piece indices
+piece_labels = [
+    "black_bishop",
+    "black_king",
+    "black_knight",
+    "black_pawn",
+    "black_queen",
+    "black_rook",
+    "empty",
+    "white_bishop",
+    "white_king",
+    "white_knight",
+    "white_pawn",
+    "white_queen",
+    "white_rook",
+]
+
+# Define chessboard coordinates (0,0) is top-left (a8), (7,7) is bottom-right (h1)
+coordinates = [(i, j) for i in range(8) for j in range(8)]
+
+# Define a transformation to prepare images for the model
+transform = transforms.Compose(
+    [
+        transforms.Resize((224, 224)),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    ]
+)
+
+
+# Function to process and predict the piece type at each square
+def predict_piece(image, model, device):
+    try:
+        if len(image.shape) == 2 or image.shape[2] == 1:
+            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+        else:
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+        image = Image.fromarray(image)
+        image_tensor = transform(image).unsqueeze(0).to(device)
+        with torch.no_grad():
+            output = model(image_tensor)
+            _, predicted = torch.max(output, 1)
+            return piece_labels[predicted.item()]
+    except Exception as e:
+        print(f"Error predicting piece: {e}")
+        return "unknown"
+
+
+# Function to detect chessboard grid using edge detection and Hough lines
+def detect_chessboard_grid(image):
+    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    # Enhance contrast
+    gray = cv2.convertScaleAbs(gray, alpha=1.2, beta=20)
+    # Apply Gaussian blur to reduce noise
+    gray = cv2.GaussianBlur(gray, (5, 5), 0)
+    # Edge detection with Canny
+    edges = cv2.Canny(gray, 50, 150, apertureSize=3)
+
+    # Detect lines using Hough Transform
+    lines = cv2.HoughLinesP(
+        edges, 1, np.pi / 180, threshold=80, minLineLength=50, maxLineGap=10
+    )
+
+    if lines is None:
+        print("No lines detected.")
+        return None, edges
+
+    # Separate horizontal and vertical lines
+    h_lines = []
+    v_lines = []
+    for line in lines:
+        x1, y1, x2, y2 = line[0]
+        if abs(x2 - x1) > abs(y2 - y1):  # Horizontal line
+            h_lines.append((y1, x1, x2))
+        else:  # Vertical line
+            v_lines.append((x1, y1, y2))
+
+    # Sort and filter to get exactly 9 lines for each
+    h_lines = sorted(h_lines, key=lambda x: x[0])[:9]  # Top 9 horizontal lines
+    v_lines = sorted(v_lines, key=lambda x: x[0])[:9]  # Top 9 vertical lines
+
+    if len(h_lines) < 9 or len(v_lines) < 9:
+        print(
+            f"Insufficient lines detected: {len(h_lines)} horizontal, {len(v_lines)} vertical"
+        )
+        return None, edges
+
+    # Find intersections to get 8x8 grid corners
+    corners = []
+    for h in h_lines:
+        y = h[0]
+        for v in v_lines:
+            x = v[0]
+            corners.append([x, y])
+
+    # Ensure exactly 64 corners (8x8 grid)
+    if len(corners) != 64:
+        print(f"Expected 64 corners, got {len(corners)}")
+        return None, edges
+
+    corners = np.array(corners, dtype=np.float32).reshape(8, 8, 2)
+
+    # Visualize detected lines for debugging
+    debug_image = image.copy()
+    for y, x1, x2 in h_lines:
+        cv2.line(debug_image, (x1, y), (x2, y), (0, 255, 0), 2)
+    for x, y1, y2 in v_lines:
+        cv2.line(debug_image, (x, y1), (x, y2), (0, 0, 255), 2)
+    cv2.imwrite("lines_debug.png", debug_image)
+
+    return corners, edges
+
+
+# Function to extract coordinates of chess pieces from an image
+def extract_chessboard_coordinates(image_path):
+    try:
+        image = cv2.imread(image_path)
+        if image is None:
+            print(f"Failed to load image: {image_path}")
+            return []
+    except Exception as e:
+        print(f"Error loading image {image_path}: {e}")
+        return []
+
+    # Try OpenCV's chessboard detection first
+    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    ret, corners = cv2.findChessboardCorners(gray, (8, 8), None)
+
+    if ret:
+        corners = cv2.cornerSubPix(
+            gray,
+            corners,
+            (11, 11),
+            (-1, -1),
+            criteria=(cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 30, 0.1),
+        )
+        corners = corners.reshape(8, 8, 2)
+    else:
+        print("OpenCV chessboard detection failed. Attempting edge-based detection.")
+        corners, edges = detect_chessboard_grid(image)
+        if corners is None:
+            # Save edges for debugging
+            cv2.imwrite("edges_debug.png", edges)
+            print("Saved edge detection output to edges_debug.png")
+            return []
+        # Save debug image with detected corners
+        debug_image = image.copy()
+        for h in range(8):
+            for v in range(8):
+                x, y = int(corners[h, v, 0]), int(corners[h, v, 1])
+                cv2.circle(debug_image, (x, y), 5, (0, 255, 0), -1)
+        cv2.imwrite("grid_debug.png", debug_image)
+        print("Saved grid detection debug image to grid_debug.png")
+
+    # Calculate square size dynamically
+    square_width = np.mean(
+        [
+            np.linalg.norm(corners[i, j] - corners[i, j + 1])
+            for i in range(8)
+            for j in range(7)
+        ]
+    )
+    square_height = np.mean(
+        [
+            np.linalg.norm(corners[i, j] - corners[i + 1, j])
+            for i in range(7)
+            for j in range(8)
+        ]
+    )
+    square_size = int(min(square_width, square_height))
+
+    # Create a blank grid to store coordinates
+    piece_coordinates = []
+
+    # Loop through all coordinates and detect pieces
+    for i, j in coordinates:
+        try:
+            x = int(corners[i, j, 0])
+            y = int(corners[i, j, 1])
+            w = h = square_size
+
+            x = max(0, x)
+            y = max(0, y)
+            x_end = min(image.shape[1], x + w)
+            y_end = min(image.shape[0], y + h)
+            roi = image[y:y_end, x:x_end]
+
+            if roi.shape[0] == 0 or roi.shape[1] == 0:
+                print(f"Invalid ROI at square ({i}, {j})")
+                piece_coordinates.append(((i, j), "unknown"))
+                continue
+
+            predicted_piece = predict_piece(roi, model, device)
+            piece_coordinates.append(((i, j), predicted_piece))
+        except Exception as e:
+            print(f"Error processing square ({i}, {j}): {e}")
+            piece_coordinates.append(((i, j), "unknown"))
+
+    return piece_coordinates
+
+
+# # Example usage
+# IMAGE_PATH = "test.png"
+# coordinates = extract_chessboard_coordinates(IMAGE_PATH)
+# for coord, piece in coordinates:
+#     print(f"Piece at {coord}: {piece}")
+
+# Clean up
+del model
+if torch.cuda.is_available():
+    torch.cuda.empty_cache()
+gc.collect()
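A short sketch of consuming the ((row, col), label) pairs returned by extract_chessboard_coordinates; the import path, the board.png input, and the presence of best_chess_piece_model.pth are assumptions for illustration:

```python
# Sketch: render the classifier output as an 8x8 grid.
# Assumptions: this file is importable as src.chess (it ships under build/lib/src/),
# "board.png" exists locally, and best_chess_piece_model.pth is on disk.
from src.chess import extract_chessboard_coordinates

pieces = extract_chessboard_coordinates("board.png")  # [((row, col), label), ...]

board = [["empty"] * 8 for _ in range(8)]
for (row, col), label in pieces:
    board[row][col] = label

for rank in board:
    print(" | ".join(f"{square:12}" for square in rank))
```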
build/lib/src/final_answer.py
ADDED
@@ -0,0 +1,210 @@
+import os
+from typing import Any, Dict, Optional
+
+import torch
+from dotenv import load_dotenv
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_ollama import ChatOllama
+from langgraph.graph import END, START, Graph, StateGraph
+from typing_extensions import TypedDict
+
+load_dotenv()
+
+base_url = os.getenv("OLLAMA_BASE_URL")
+
+
+class AgentState(TypedDict):
+    """State for the final answer validation graph."""
+
+    question: str
+    answer: str
+    final_answer: str | None
+    agent_memory: Any
+    valid_answer: bool
+
+
+def extract_answer(state: AgentState) -> Dict:
+    """Extract and format the final answer from the state.
+    Args:
+        state: The state of the agent.
+    Returns:
+        A dictionary with the formatted final answer.
+    """
+    # Extract the final answer from the state
+    sep_token = "FINAL ANSWER:"
+    raw_answer = state["answer"]
+
+    # Extract the answer after the separator if it exists
+    if sep_token in raw_answer:
+        formatted_answer = raw_answer.split(sep_token)[1].strip()
+    else:
+        formatted_answer = raw_answer.strip()
+
+    # Remove any brackets from lists
+    formatted_answer = formatted_answer.replace("[", "").replace("]", "")
+
+    # Remove units unless specified
+    if not any(
+        unit in formatted_answer.lower() for unit in ["$", "%", "dollars", "percent"]
+    ):
+        formatted_answer = formatted_answer.replace("$", "").replace("%", "")
+
+    # Remove commas from numbers
+    parts = formatted_answer.split(",")
+    formatted_parts = []
+    for part in parts:
+        part = part.strip()
+        if part.replace(".", "").isdigit():  # Check if it's a number
+            part = part.replace(",", "")
+        formatted_parts.append(part)
+    formatted_answer = ", ".join(formatted_parts)
+
+    return {"final_answer": formatted_answer}
+
+
+def reasoning_check(state: AgentState) -> Dict:
+    """
+    Node that checks the reasoning of the final answer.
+    Args:
+        state: The state of the agent.
+    Returns:
+        A dictionary with the reasoning check result.
+    """
+    model = ChatOllama(
+        model="hf.co/lmstudio-community/Qwen2.5-14B-Instruct-GGUF:Q6_K",
+        base_url=base_url,
+        temperature=0.2,
+    )
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                """You are a strict validator of answers. Your job is to check if the reasoning and results are correct.
+                You should have >90% confidence that the answer is correct to pass it.
+                First list reasons why yes/no, then write your final decision: PASS in caps lock if it is satisfactory, FAIL if it is not.""",
+            ),
+            (
+                "human",
+                """
+                Here is a user-given task and the agent steps: {agent_memory}
+                Now here is the answer that was given: {final_answer}
+                Please check that the reasoning process and results are correct: do they correctly answer the given task?
+                """,
+            ),
+        ]
+    )
+
+    chain = prompt | model | StrOutputParser()
+    output = chain.invoke(
+        {
+            "agent_memory": state["agent_memory"],
+            "final_answer": state["final_answer"],
+        }
+    )
+
+    print("Reasoning Feedback: ", output)
+    if "FAIL" in output:
+        return {"valid_answer": False}
+    torch.cuda.empty_cache()
+    return {"valid_answer": True}
+
+
+def formatting_check(state: AgentState) -> Dict:
+    """
+    Node that checks the formatting of the final answer.
+    Args:
+        state: The state of the agent.
+    Returns:
+        A dictionary with the formatting check result.
+    """
+    model = ChatOllama(
+        model="hf.co/lmstudio-community/Qwen2.5-14B-Instruct-GGUF:Q6_K",
+        base_url=base_url,
+        temperature=0.2,
+    )
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                """You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
+                """,
+            ),
+            (
+                "human",
+                """
+                Here is a user-given task and the agent steps: {agent_memory}
+                Now here is the FINAL ANSWER that was given: {final_answer}
+                Ensure the FINAL ANSWER is in the right format as asked for by the task.
+                """,
+            ),
+        ]
+    )
+
+    chain = prompt | model | StrOutputParser()
+    output = chain.invoke(
+        {
+            "agent_memory": state["agent_memory"],
+            "final_answer": state["final_answer"],
+        }
+    )
+
+    print("Formatting Feedback: ", output)
+    if "FAIL" in output:
+        return {"valid_answer": False}
+
+    torch.cuda.empty_cache()
+    return {"valid_answer": True}
+
+
+def create_final_answer_graph() -> Graph:
+    """Create a graph that validates the final answer.
+    Returns:
+        A graph that validates the final answer.
+    """
+    # Create the graph
+    workflow = StateGraph(AgentState)
+
+    # Add nodes
+    workflow.add_node("extract_answer", extract_answer)
+    workflow.add_node("reasoning_check", reasoning_check)
+    workflow.add_node("formatting_check", formatting_check)
+
+    # Add edges
+    workflow.add_edge(START, "extract_answer")
+    workflow.add_edge("extract_answer", "reasoning_check")
+    workflow.add_edge("reasoning_check", "formatting_check")
+    workflow.add_edge("formatting_check", END)
+
+    # Compile the graph
+    return workflow.compile()  # type: ignore
+
+
+def validate_answer(graph: StateGraph, answer: str, agent_memory: Any) -> Dict:
+    """Validate the answer using the LangGraph workflow.
+    Args:
+        graph: The validation graph (LangGraph StateGraph).
+        answer: The answer to validate.
+        agent_memory: The agent's memory.
+    Returns:
+        A dictionary with validation results.
+    """
+    try:
+        # Initialize state
+        initial_state = {
+            "answer": answer,
+            "final_answer": None,
+            "agent_memory": agent_memory,
+            "valid_answer": False,
+        }
+
+        # Run the graph
+        result = graph.invoke(initial_state)  # type:ignore
+
+        return {
+            "valid_answer": result.get("valid_answer", False),
+            "final_answer": result.get("final_answer", None),
+        }
+    except Exception as e:
+        print(f"Validation failed: {e}")
+        return {"valid_answer": False, "final_answer": None}
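A minimal sketch of the validation flow exactly as agent.py consumes it; the raw answer string is made up for illustration and a reachable OLLAMA_BASE_URL is assumed:

```python
# Sketch: compile the validation graph once, then validate one raw agent answer.
# Assumes OLLAMA_BASE_URL points at an Ollama server hosting the Qwen2.5 model;
# the answer text below is a made-up example.
from src.final_answer import create_final_answer_graph, validate_answer

graph = create_final_answer_graph()
raw_answer = "The population rounds to 42. FINAL ANSWER: 42"
result = validate_answer(graph, raw_answer, agent_memory=["<agent steps go here>"])

print(result["final_answer"])  # "42" once extract_answer strips the template
print(result["valid_answer"])  # False when a validator LLM responds with FAIL
```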
build/lib/src/tools.py
ADDED
@@ -0,0 +1,675 @@
+# pylint: disable=no-member
+import base64
+import gc
+import mimetypes
+import os
+import re
+import tempfile
+import time
+import uuid
+from datetime import timedelta
+from typing import Dict, List, Optional, TypedDict, Union
+from urllib.parse import urlparse
+
+import cv2
+import imageio
+import pandas as pd
+import pytesseract
+import requests
+import torch
+import whisper
+import yt_dlp
+from bs4 import BeautifulSoup, Tag
+from dotenv import load_dotenv
+from duckduckgo_search import DDGS
+from langchain_core.messages import HumanMessage
+from langchain_core.tools import tool
+from langchain_ollama import ChatOllama
+from PIL import Image
+from playwright.sync_api import sync_playwright
+from youtube_transcript_api import (
+    NoTranscriptFound,
+    TranscriptsDisabled,
+    YouTubeTranscriptApi,
+)
+
+load_dotenv()
+base_url = os.getenv("OLLAMA_BASE_URL")
+model_vision = ChatOllama(
+    model="gemma3:latest",
+    base_url=base_url,
+)
+model_text = ChatOllama(model="gemma3:latest", base_url=base_url)
+
+
+@tool
+def use_vision_model(question: str) -> str:
+    """
+    A multimodal reasoning model that combines image and text input to answer
+    questions using the image.
+    """
+    # Extract image paths
+    image_paths = re.findall(r"[\w\-/\.]+\.(?:png|jpg|jpeg|webp)", question)
+    image_paths = [p for p in image_paths if os.path.exists(p)]
+
+    if not image_paths:
+        return "No valid image file found in the question."
+
+    image_path = image_paths[0]
+
+    # Preprocess the image using OpenCV
+    image = cv2.imread(image_path)
+    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    gray = cv2.convertScaleAbs(gray, alpha=1.2, beta=20)
+    gray = cv2.GaussianBlur(gray, (5, 5), 0)
+    edges = cv2.Canny(gray, 50, 150, apertureSize=3)
+
+    # Create a temporary file for the processed image
+    with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as tmp_file:
+        temp_image_path = tmp_file.name
+        cv2.imwrite(temp_image_path, edges)
+
+        # Encode the temp image
+        mime_type, _ = mimetypes.guess_type(temp_image_path)
+        mime_type = mime_type or "image/png"
+        with open(temp_image_path, "rb") as f:
+            encoded = base64.b64encode(f.read()).decode("utf-8")
+
+    # Prepare the prompt and image for the model
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": question},
+                {
+                    "type": "image_url",
+                    "image_url": {"url": f"data:{mime_type};base64,{encoded}"},
+                },
+            ],
+        }
+    ]
+
+    # Invoke the vision model
+    response = model_vision.invoke(messages)
+
+    # Clean up
+    del messages, encoded, image_path
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    return str(response.content) if hasattr(response, "content") else str(response)
+
+
+# YouTube Video Review Tool
+@tool
+def review_youtube_video(url: str) -> str:
+    """Reviews a YouTube video and answers a specific question about that video.
+
+    Args:
+        url (str): the URL to the YouTube video.
+        question (str): The question you are asking about the video.
+
+    Returns:
+        str: The answer to the question
+    """
+    # Extract video ID from URL (assuming it is in the format https://youtube.com/watch?v=VIDEO_ID)
+    video_id = url.split("v=")[1]
+    transcript_url = (
+        f"https://www.youtube.com/api/timedtext?v={video_id}"  # Getting transcript data
+    )
+
+    response = requests.get(transcript_url, timeout=200)
+
+    transcript = response.text  # This is the transcript (XML or SRT format)
+
+    # Prepare the content (just the transcript, no question needed)
+    transcript_content = f"Here is the transcript of the video: {transcript}"
+
+    # Return the transcript content so the main LLM can handle question generation
+    return transcript_content
+
+
+# YouTube Frames to Images Tool
+@tool
+def video_frames_to_images(
+    url: str,
+    folder_name: str,
+    sample_interval_seconds: int = 5,
+) -> List[str]:
+    """Extracts frames from a video at specified intervals and saves them as images.
+    Args:
+        url (str): the URL to the video.
+        folder_name (str): the name of the folder to save the images to.
+        sample_interval_seconds (int): the interval between frames to sample.
+    Returns:
+        List[str]: A list of paths to the saved image files.
+    """
+    # Create a subdirectory for the frames
+    frames_dir = os.path.join(folder_name, "frames")
+    os.makedirs(frames_dir, exist_ok=True)
+
+    ydl_opts = {
+        "format": "bestvideo[height<=1080]+bestaudio/best[height<=1080]/best",
+        "outtmpl": os.path.join(folder_name, "video.%(ext)s"),
+        "quiet": True,
+        "noplaylist": True,
+        "merge_output_format": "mp4",
+        "force_ipv4": True,
+    }
+
+    info_extracted = []
+    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+        info = ydl.extract_info(url, download=True)
+        info_extracted.append(info)
+    video_path = next(
+        (
+            os.path.join(folder_name, f)
+            for f in os.listdir(folder_name)
+            if f.endswith(".mp4")
+        ),
+        None,
+    )
+
+    if not video_path:
+        raise RuntimeError("Failed to download video as mp4")
+
+    reader = imageio.get_reader(video_path)
+    metadata = reader.get_meta_data()
+    fps = metadata.get("fps")
+
+    if fps is None:
+        reader.close()
+        raise RuntimeError("Unable to determine FPS from video metadata")
+
+    frame_interval = int(fps * sample_interval_seconds)
+    num_frames = reader.get_length()
+    image_paths: List[str] = []
+
+    for idx in range(num_frames):
+        if idx % frame_interval == 0:
+            # Save frame as image
+            frame = reader.get_data(idx)
+            image_path = os.path.join(frames_dir, f"frame_{idx:06d}.jpg")
+            imageio.imwrite(image_path, frame)
+            image_paths.append(image_path)
+
+    reader.close()
+    return image_paths
+
+
+# File Reading Tool
+@tool
+def read_file(filepath: str) -> str:
+    """Reads the content of a text file.
+    Args:
+        filepath (str): the path to the file to read.
+    Returns:
+        str: The content of the file.
+    """
+    try:
+        with open(filepath, "r", encoding="utf-8") as file:
+            content = file.read()
+            return content
+    except FileNotFoundError:
+        return f"File not found: {filepath}"
+    except IOError as e:
+        return f"Error reading file: {str(e)}"
+
+
+# File Download Tool
+@tool
+def download_file_from_url(url: str, directory: str) -> Dict[str, Union[str, None]]:
+    """Downloads a file from a URL and saves it to a directory.
+    Args:
+        url (str): the URL to download the file from.
+        directory (str): the directory to save the file to.
+    Returns:
+        Dict[str, Union[str, None]]: A dictionary containing the file type and path.
+    """
+
+    response = requests.get(url, stream=True, timeout=10)
|
231 |
+
response.raise_for_status()
|
232 |
+
|
233 |
+
content_type = response.headers.get("content-type", "").lower()
|
234 |
+
|
235 |
+
# Try to get filename from headers
|
236 |
+
filename = None
|
237 |
+
cd = response.headers.get("content-disposition", "")
|
238 |
+
match = re.search(r"filename\*=UTF-8\'\'(.+)", cd) or re.search(
|
239 |
+
r'filename="?([^"]+)"?', cd
|
240 |
+
)
|
241 |
+
if match:
|
242 |
+
filename = match.group(1)
|
243 |
+
|
244 |
+
# If not in headers, try URL
|
245 |
+
if not filename:
|
246 |
+
filename = os.path.basename(url.split("?")[0])
|
247 |
+
|
248 |
+
# Fallback to generated filename
|
249 |
+
if not filename:
|
250 |
+
extension = {
|
251 |
+
"image/jpeg": ".jpg",
|
252 |
+
"image/png": ".png",
|
253 |
+
"image/gif": ".gif",
|
254 |
+
"audio/wav": ".wav",
|
255 |
+
"audio/mpeg": ".mp3",
|
256 |
+
"video/mp4": ".mp4",
|
257 |
+
"text/plain": ".txt",
|
258 |
+
"text/csv": ".csv",
|
259 |
+
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
|
260 |
+
"application/vnd.ms-excel": ".xls",
|
261 |
+
"application/octet-stream": ".bin",
|
262 |
+
}.get(content_type, ".bin")
|
263 |
+
filename = f"downloaded_{uuid.uuid4().hex[:8]}{extension}"
|
264 |
+
|
265 |
+
os.makedirs(directory, exist_ok=True)
|
266 |
+
file_path = os.path.join(directory, filename)
|
267 |
+
|
268 |
+
with open(file_path, "wb") as f:
|
269 |
+
for chunk in response.iter_content(chunk_size=8192):
|
270 |
+
f.write(chunk)
|
271 |
+
|
272 |
+
# shutil.copy(file_path, os.getcwd())
|
273 |
+
|
274 |
+
return {"type": content_type, "path": file_path}
|
275 |
+
|
276 |
+
|
277 |
+
# Text Extraction from Image Tool
|
278 |
+
@tool
|
279 |
+
def extract_text_from_image(image_path: str) -> str:
|
280 |
+
"""Extracts text from an image using OCR.
|
281 |
+
Args:
|
282 |
+
image_path (str): the path to the image to extract text from.
|
283 |
+
Returns:
|
284 |
+
str: The text extracted from the image.
|
285 |
+
"""
|
286 |
+
|
287 |
+
image = Image.open(image_path)
|
288 |
+
text = pytesseract.image_to_string(image)
|
289 |
+
return f"Extracted text from image:\n\n{text}"
|
290 |
+
|
291 |
+
|
292 |
+
# CSV Analysis Tool
|
293 |
+
@tool
|
294 |
+
def analyze_csv_file(file_path: str, query: str) -> str:
|
295 |
+
"""Analyzes a CSV file and answers questions about its contents using an Ollama model.
|
296 |
+
|
297 |
+
Args:
|
298 |
+
file_path (str): The path to the CSV file to analyze.
|
299 |
+
query (str): The question to answer about the CSV file.
|
300 |
+
|
301 |
+
Returns:
|
302 |
+
str: The result of the analysis.
|
303 |
+
"""
|
304 |
+
# Load the CSV file
|
305 |
+
df = pd.read_csv(file_path)
|
306 |
+
df_str = df.to_string(index=False)
|
307 |
+
|
308 |
+
# Compose the prompt
|
309 |
+
prompt = f"""
|
310 |
+
You are a data analyst. Analyze the following CSV data and answer the question provided.
|
311 |
+
|
312 |
+
CSV Dimensions: {df.shape[0]} rows × {df.shape[1]} columns
|
313 |
+
|
314 |
+
CSV Data:
|
315 |
+
{df_str}
|
316 |
+
|
317 |
+
Please provide:
|
318 |
+
1. A summary of the data structure and content
|
319 |
+
2. Key patterns and insights
|
320 |
+
3. Potential data quality issues
|
321 |
+
4. Suggestions for analysis
|
322 |
+
|
323 |
+
User Query:
|
324 |
+
{query}
|
325 |
+
|
326 |
+
Format your response in markdown with sections and bullet points.
|
327 |
+
"""
|
328 |
+
|
329 |
+
model = model_text
|
330 |
+
|
331 |
+
# Call the model
|
332 |
+
response = model.invoke([HumanMessage(content=prompt)])
|
333 |
+
del df
|
334 |
+
torch.cuda.empty_cache()
|
335 |
+
gc.collect()
|
336 |
+
|
337 |
+
# Return the result
|
338 |
+
if hasattr(response, "content") and isinstance(response.content, str):
|
339 |
+
return response.content
|
340 |
+
return str(response)
|
341 |
+
|
342 |
+
|
343 |
+
# Excel Analysis Tool
|
344 |
+
@tool
|
345 |
+
def analyze_excel_file(file_path: str) -> str:
|
346 |
+
"""Analyzes an Excel file and answers questions about its contents using Ollama backed LLM
|
347 |
+
Args:
|
348 |
+
file_path (str): the path to the Excel file to analyze.
|
349 |
+
question (str): the question to answer about the Excel file.
|
350 |
+
Returns:
|
351 |
+
str: The result of the analysis.
|
352 |
+
"""
|
353 |
+
llm = model_text
|
354 |
+
|
355 |
+
# Read all sheets from the Excel file
|
356 |
+
excel_file = pd.ExcelFile(file_path)
|
357 |
+
sheet_names = excel_file.sheet_names
|
358 |
+
|
359 |
+
result = f"Excel file loaded with {len(sheet_names)} sheets: {', '.join(sheet_names)}\n\n"
|
360 |
+
|
361 |
+
for sheet_name in sheet_names:
|
362 |
+
df = pd.read_excel(file_path, sheet_name=sheet_name)
|
363 |
+
df_str = df.to_string()
|
364 |
+
|
365 |
+
# Build the prompt
|
366 |
+
prompt = f"""Analyze the following Excel sheet data and answer the user's query.
|
367 |
+
Sheet Name: {sheet_name}
|
368 |
+
Dimensions: {len(df)} rows × {len(df.columns)} columns
|
369 |
+
|
370 |
+
Data:
|
371 |
+
{df_str}
|
372 |
+
|
373 |
+
Please provide:
|
374 |
+
1. A summary of the data structure and content
|
375 |
+
2. Key patterns and insights
|
376 |
+
3. Potential data quality issues
|
377 |
+
4. Suggestions for analysis
|
378 |
+
|
379 |
+
Format the response clearly using headings and bullet points."""
|
380 |
+
|
381 |
+
# Call the LLM with the prompt
|
382 |
+
response = llm.invoke([HumanMessage(content=prompt)])
|
383 |
+
|
384 |
+
result += f"=== Sheet: {sheet_name} ===\n"
|
385 |
+
result += str(response.content) + "\n"
|
386 |
+
result += "=" * 50 + "\n\n"
|
387 |
+
del df
|
388 |
+
gc.collect()
|
389 |
+
|
390 |
+
excel_file.close()
|
391 |
+
torch.cuda.empty_cache()
|
392 |
+
|
393 |
+
return result
|
394 |
+
|
395 |
+
|
396 |
+
# Audio Transcription Tool
|
397 |
+
def transcribe_audio(audio_file_path: str) -> str:
|
398 |
+
"""Transcribes an audio file using Whisper's audio capabilities.
|
399 |
+
Args:
|
400 |
+
audio_file_path (str): The path to the audio file to transcribe.
|
401 |
+
Returns:
|
402 |
+
str: The transcript of the audio file.
|
406 |
+
"""
|
407 |
+
|
408 |
+
model = whisper.load_model("base")
|
409 |
+
result = model.transcribe(audio_file_path)
|
410 |
+
assert isinstance(result["text"], str)
|
411 |
+
|
412 |
+
del model
|
413 |
+
torch.cuda.empty_cache()
|
414 |
+
gc.collect()
|
415 |
+
return result["text"]
|
416 |
+
|
417 |
+
|
418 |
+
def _extract_video_id(url: str) -> Optional[str]:
|
419 |
+
"""Extract video ID from YouTube URL.
|
420 |
+
Args:
|
421 |
+
url (str): the URL to the YouTube video.
|
422 |
+
Returns:
|
423 |
+
str: The video ID of the YouTube video.
|
424 |
+
"""
|
425 |
+
patterns = [
|
426 |
+
r"(?:youtube\.com\/watch\?v=|youtube\.com\/embed\/|youtu\.be\/)([^&\n?#]+)",
|
427 |
+
r"(?:youtube\.com\/v\/|youtube\.com\/e\/|youtube\.com\/user\/[^\/]+\/|youtube\.com\/[^\/]+\/|youtube\.com\/embed\/|youtu\.be\/)([^&\n?#]+)",
|
428 |
+
]
|
429 |
+
|
430 |
+
for pattern in patterns:
|
431 |
+
match = re.search(pattern, url)
|
432 |
+
if match:
|
433 |
+
return match.group(1)
|
434 |
+
return None
|
435 |
+
|
436 |
+
|
437 |
+
@tool
|
438 |
+
def transcribe_youtube(url: str) -> str:
|
439 |
+
"""
|
440 |
+
Transcribes a YouTube video using YouTube Transcript API or ChatOllama with Whisper as fallback.
|
441 |
+
|
442 |
+
This function first tries to fetch the transcript of a YouTube video using the YouTube Transcript API.
|
443 |
+
If the transcript is unavailable (e.g., due to captions being disabled), it falls back to using
|
444 |
+
ChatOllama integrated with Whisper to transcribe the audio.
|
445 |
+
|
446 |
+
Args:
|
447 |
+
url (str): The URL to the YouTube video.
|
448 |
+
|
449 |
+
Returns:
|
450 |
+
str: The transcript of the YouTube video, or an error message if transcription fails.
|
451 |
+
"""
|
452 |
+
|
453 |
+
try:
|
454 |
+
# Try using YouTube Transcript API
|
455 |
+
video_id = _extract_video_id(url)
|
456 |
+
transcript = ""
|
457 |
+
transcript_chunks = YouTubeTranscriptApi.get_transcript(
|
458 |
+
video_id, languages=["en"]
|
459 |
+
)
|
460 |
+
for chunk in transcript_chunks:
|
461 |
+
timestamp = str(timedelta(seconds=int(chunk["start"])))
|
462 |
+
transcript += f"[{timestamp}] {chunk['text']}\n"
|
463 |
+
|
464 |
+
# Return API transcript if available
|
465 |
+
if transcript.strip():
|
466 |
+
return transcript
|
467 |
+
|
468 |
+
except (TranscriptsDisabled, NoTranscriptFound, Exception) as exc:  # fall back to Whisper on any failure
|
469 |
+
try:
|
470 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
471 |
+
# Download audio from YouTube
|
472 |
+
ydl_opts = {
|
473 |
+
"format": "bestaudio/best",
|
474 |
+
"outtmpl": os.path.join(tmpdir, "audio.%(ext)s"),
|
475 |
+
"quiet": True,
|
476 |
+
"noplaylist": True,
|
477 |
+
"postprocessors": [
|
478 |
+
{
|
479 |
+
"key": "FFmpegExtractAudio",
|
480 |
+
"preferredcodec": "wav",
|
481 |
+
"preferredquality": "192",
|
482 |
+
}
|
483 |
+
],
|
484 |
+
}
|
485 |
+
|
486 |
+
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
487 |
+
info = ydl.extract_info(url, download=True)
|
488 |
+
|
489 |
+
if info is not None:
|
490 |
+
title = info.get("title", "Unknown Title") # Type:None
|
491 |
+
duration = info.get("duration", 0) # in seconds
|
492 |
+
uploader = info.get("uploader", "Unknown Uploader")
|
493 |
+
else:
|
494 |
+
title = "Unknown Title"
|
495 |
+
duration = 0
|
496 |
+
uploader = "Unknown Uploader"
|
497 |
+
|
498 |
+
audio_path = next(
|
499 |
+
(
|
500 |
+
os.path.join(tmpdir, f)
|
501 |
+
for f in os.listdir(tmpdir)
|
502 |
+
if f.endswith(".wav")
|
503 |
+
),
|
504 |
+
None,
|
505 |
+
)
|
506 |
+
if not audio_path:
|
507 |
+
raise RuntimeError("Failed to download or convert audio") from exec
|
508 |
+
|
509 |
+
# Use Whisper for initial transcription
|
510 |
+
whisper_model = whisper.load_model("base")
|
511 |
+
transcription = whisper_model.transcribe(audio_path, verbose=False)
|
512 |
+
raw_transcript = transcription["text"]
|
513 |
+
del whisper_model
|
514 |
+
gc.collect()
|
515 |
+
|
516 |
+
# Use ChatOllama to format transcript with timestamps
|
517 |
+
ollama = model_text
|
518 |
+
prompt = (
|
519 |
+
"Please format the following raw transcript into a structured format with timestamps "
|
520 |
+
f"The following transcript was generated from a YouTube video titled '{title}' "
|
521 |
+
f"uploaded by {uploader}. The total video duration is approximately {duration}.\n\n"
|
522 |
+
"Use the video’s length to help guide timestamp estimation.\n\n"
|
523 |
+
"(e.g., [00:00:00] text). Estimate timestamps based on the natural flow of the text."
|
524 |
+
f"Raw transcript:\n{raw_transcript}"
|
525 |
+
)
|
526 |
+
response = ollama.invoke([HumanMessage(content=prompt)])
|
527 |
+
formatted_transcript = str(
|
528 |
+
response.content
|
529 |
+
) # Ensure response is a string
|
530 |
+
|
531 |
+
torch.cuda.empty_cache()
|
532 |
+
|
533 |
+
return formatted_transcript
|
534 |
+
except Exception as fallback_exc:
|
535 |
+
raise RuntimeError("Fallback Transcription failed") from fallback_exc
|
536 |
+
return "Transcription failed unexpectedly."
|
537 |
+
|
538 |
+
|
539 |
+
@tool
|
540 |
+
def website_scrape(url: str) -> str:
|
541 |
+
"""scrapes a website and returns the text.
|
542 |
+
args:
|
543 |
+
url (str): the url to the website to scrape.
|
544 |
+
returns:
|
545 |
+
str: the text of the website.
|
546 |
+
"""
|
547 |
+
try:
|
548 |
+
parsed_url = urlparse(url)
|
549 |
+
if not parsed_url.scheme or not parsed_url.netloc:
|
550 |
+
raise ValueError(
|
551 |
+
f"Invalid URL: '{url}'. Call `duckduckgo_search` first to get a valid URL."
|
552 |
+
)
|
553 |
+
with sync_playwright() as p:
|
554 |
+
browser = p.chromium.launch(headless=True)
|
555 |
+
page = browser.new_page()
|
556 |
+
page.goto(url, wait_until="networkidle", timeout=60000)
|
557 |
+
page.wait_for_load_state("domcontentloaded")
|
558 |
+
html_content = page.content()
|
559 |
+
browser.close()
|
560 |
+
|
561 |
+
soup = BeautifulSoup(html_content, "html.parser")
|
562 |
+
|
563 |
+
relevant_text = ""
|
564 |
+
for header in soup.find_all(["h2", "h3"]):
|
565 |
+
heading_text = header.get_text().strip().lower()
|
566 |
+
if "discography" in heading_text or "studio albums" in heading_text:
|
567 |
+
section_texts = []
|
568 |
+
tag = header.find_next_sibling()
|
569 |
+
while tag and (
|
570 |
+
not isinstance(tag, Tag) or tag.name not in ["h2", "h3"]
|
571 |
+
):
|
572 |
+
section_texts.append(tag.get_text(separator=" ", strip=True))
|
573 |
+
tag = tag.find_next_sibling()
|
574 |
+
relevant_text = "\n\n".join(section_texts)
|
575 |
+
break
|
576 |
+
if not relevant_text:
|
577 |
+
article = soup.find("article")
|
578 |
+
if article:
|
579 |
+
relevant_text = article.get_text(separator=" ", strip=True)
|
580 |
+
if not relevant_text:
|
581 |
+
relevant_text = soup.get_text(separator=" ", strip=True)
|
582 |
+
|
583 |
+
# step 2: chunk the text (optional but recommended)
|
584 |
+
def chunk_text(text, max_length=1000):
|
585 |
+
words = text.split()
|
586 |
+
chunks = []
|
587 |
+
for i in range(0, len(words), max_length):
|
588 |
+
chunks.append(" ".join(words[i : i + max_length]))
|
589 |
+
return chunks
|
590 |
+
|
591 |
+
chunks = chunk_text(relevant_text)
|
592 |
+
|
593 |
+
# return at most the first 100 chunks (~100,000 words) to bound the output size
|
594 |
+
return "\n\n".join(chunks[:100])
|
595 |
+
except ValueError as e:
|
596 |
+
# Catch URL validation errors
|
597 |
+
return str(e)
|
598 |
+
except Exception as e:
|
599 |
+
# Catch other unexpected errors
|
600 |
+
return f"Scraping failed: {str(e)}"
|
601 |
+
|
602 |
+
|
603 |
+
class SearchResult(TypedDict):
|
604 |
+
query: str
|
605 |
+
status: str
|
606 |
+
attempt: int
|
607 |
+
results: Optional[List[dict]]
|
608 |
+
error: Optional[str]
|
609 |
+
|
610 |
+
|
611 |
+
@tool
|
612 |
+
def duckduckgo_search(query: str, max_results: int = 10) -> SearchResult:
|
613 |
+
"""
|
614 |
+
Perform a DuckDuckGo search with retry and backoff.
|
615 |
+
Use this FIRST before invoking and scraping tools.
|
616 |
+
Args:
|
617 |
+
query: The search query string.
|
618 |
+
max_results: Max number of results to return (default 10).
|
619 |
+
Returns:
|
620 |
+
A dict with the query, results, status, attempt count, and any error.
|
621 |
+
"""
|
622 |
+
max_retries = 3
|
623 |
+
base_delay = 2
|
624 |
+
backoff_factor = 2
|
625 |
+
|
626 |
+
for attempt in range(max_retries):
|
627 |
+
try:
|
628 |
+
with DDGS() as ddgs:
|
629 |
+
results = ddgs.text(keywords=query, max_results=max_results)
|
630 |
+
if results:
|
631 |
+
formatted_results = [
|
632 |
+
{
|
633 |
+
"title": result.get("title", ""),
|
634 |
+
"url": result.get("href", ""),
|
635 |
+
"body": result.get("body", ""),
|
636 |
+
}
|
637 |
+
for result in results
|
638 |
+
]
|
639 |
+
return {
|
640 |
+
"query": query,
|
641 |
+
"status": "success",
|
642 |
+
"attempt": attempt + 1,
|
643 |
+
"results": formatted_results,
|
644 |
+
"error": None,
|
645 |
+
}
|
646 |
+
except Exception as e:
|
647 |
+
print(f"[DuckDuckGo Tool] Attempt {attempt + 1} failed: {e}")
|
648 |
+
time.sleep(base_delay * (backoff_factor**attempt))
|
649 |
+
|
650 |
+
return {
|
651 |
+
"query": query,
|
652 |
+
"status": "failed",
|
653 |
+
"attempt": max_retries,
|
654 |
+
"results": None,
|
655 |
+
"error": "Max retries exceeded or request failed.",
|
656 |
+
}
|
657 |
+
|
658 |
+
|
659 |
+
@tool
|
660 |
+
def reverse_decoder(question: str) -> str:
|
661 |
+
"""Decodes a reversed sentence if the input appears to be written backward.
|
662 |
+
|
663 |
+
Args:
|
664 |
+
question (str): The possibly reversed question string.
|
665 |
+
|
666 |
+
Returns:
|
667 |
+
str: The decoded sentence.
|
668 |
+
"""
|
669 |
+
# Remove leading punctuation if present
|
670 |
+
cleaned = question.strip().strip(".!?")
|
671 |
+
|
672 |
+
# Check if it's likely reversed (simple heuristic: mostly lowercase, reversed word order)
|
673 |
+
reversed_text = cleaned[::-1]
|
674 |
+
|
675 |
+
return reversed_text
|
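For reference, a minimal sketch (not part of the committed file) of how two of the @tool functions above can be exercised directly, assuming the package is installed from this repo with pip install -e .; the query string is purely illustrative:

# Smoke-test sketch for the search + scrape pair defined above.
from src.tools import duckduckgo_search, website_scrape

# @tool-decorated callables are invoked with a dict of their declared arguments.
search = duckduckgo_search.invoke({"query": "artist discography studio albums", "max_results": 5})
print(search["status"], search["attempt"])

if search["status"] == "success" and search["results"]:
    first_url = search["results"][0]["url"]
    # website_scrape expects a fully qualified URL, which is why the search runs first.
    print(website_scrape.invoke({"url": first_url})[:500])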
build/lib/tools/__init__.py
ADDED
File without changes
|
build/lib/tools/chess.py
ADDED
@@ -0,0 +1,241 @@
1 |
+
import gc
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from torchvision import models, transforms
|
8 |
+
|
9 |
+
# Set device
|
10 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
11 |
+
|
12 |
+
# Load the pre-trained model
|
13 |
+
try:
|
14 |
+
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
|
15 |
+
model.fc = torch.nn.Linear(model.fc.in_features, 13) # 13 classes including 'empty'
|
16 |
+
model.load_state_dict(torch.load("best_chess_piece_model.pth", map_location=device))
|
17 |
+
model.eval()
|
18 |
+
model = model.to(device)
|
19 |
+
except Exception as e:
|
20 |
+
print(f"Error loading model: {e}")
|
21 |
+
exit(1)
|
22 |
+
|
23 |
+
# Mapping chess piece indices
|
24 |
+
piece_labels = [
|
25 |
+
"black_bishop",
|
26 |
+
"black_king",
|
27 |
+
"black_knight",
|
28 |
+
"black_pawn",
|
29 |
+
"black_queen",
|
30 |
+
"black_rook",
|
31 |
+
"empty",
|
32 |
+
"white_bishop",
|
33 |
+
"white_king",
|
34 |
+
"white_knight",
|
35 |
+
"white_pawn",
|
36 |
+
"white_queen",
|
37 |
+
"white_rook",
|
38 |
+
]
|
39 |
+
|
40 |
+
# Define chessboard coordinates (0,0) is top-left (a8), (7,7) is bottom-right (h1)
|
41 |
+
coordinates = [(i, j) for i in range(8) for j in range(8)]
|
42 |
+
|
43 |
+
# Define a transformation to prepare images for the model
|
44 |
+
transform = transforms.Compose(
|
45 |
+
[
|
46 |
+
transforms.Resize((224, 224)),
|
47 |
+
transforms.ToTensor(),
|
48 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
49 |
+
]
|
50 |
+
)
|
51 |
+
|
52 |
+
|
53 |
+
# Function to process and predict the piece type at each square
|
54 |
+
def predict_piece(image, model, device):
|
55 |
+
try:
|
56 |
+
if len(image.shape) == 2 or image.shape[2] == 1:
|
57 |
+
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
58 |
+
else:
|
59 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
60 |
+
|
61 |
+
image = Image.fromarray(image)
|
62 |
+
image_tensor = transform(image).unsqueeze(0).to(device)
|
63 |
+
with torch.no_grad():
|
64 |
+
output = model(image_tensor)
|
65 |
+
_, predicted = torch.max(output, 1)
|
66 |
+
return piece_labels[predicted.item()]
|
67 |
+
except Exception as e:
|
68 |
+
print(f"Error predicting piece: {e}")
|
69 |
+
return "unknown"
|
70 |
+
|
71 |
+
|
72 |
+
# Function to detect chessboard grid using edge detection and Hough lines
|
73 |
+
def detect_chessboard_grid(image):
|
74 |
+
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
75 |
+
# Enhance contrast
|
76 |
+
gray = cv2.convertScaleAbs(gray, alpha=1.2, beta=20)
|
77 |
+
# Apply Gaussian blur to reduce noise
|
78 |
+
gray = cv2.GaussianBlur(gray, (5, 5), 0)
|
79 |
+
# Edge detection with Canny
|
80 |
+
edges = cv2.Canny(gray, 50, 150, apertureSize=3)
|
81 |
+
|
82 |
+
# Detect lines using Hough Transform
|
83 |
+
lines = cv2.HoughLinesP(
|
84 |
+
edges, 1, np.pi / 180, threshold=80, minLineLength=50, maxLineGap=10
|
85 |
+
)
|
86 |
+
|
87 |
+
if lines is None:
|
88 |
+
print("No lines detected.")
|
89 |
+
return None, edges
|
90 |
+
|
91 |
+
# Separate horizontal and vertical lines
|
92 |
+
h_lines = []
|
93 |
+
v_lines = []
|
94 |
+
for line in lines:
|
95 |
+
x1, y1, x2, y2 = line[0]
|
96 |
+
if abs(x2 - x1) > abs(y2 - y1): # Horizontal line
|
97 |
+
h_lines.append((y1, x1, x2))
|
98 |
+
else: # Vertical line
|
99 |
+
v_lines.append((x1, y1, y2))
|
100 |
+
|
101 |
+
# Sort and filter to get exactly 9 lines for each
|
102 |
+
h_lines = sorted(h_lines, key=lambda x: x[0])[:9] # Top 9 horizontal lines
|
103 |
+
v_lines = sorted(v_lines, key=lambda x: x[0])[:9] # Top 9 vertical lines
|
104 |
+
|
105 |
+
if len(h_lines) < 9 or len(v_lines) < 9:
|
106 |
+
print(
|
107 |
+
f"Insufficient lines detected: {len(h_lines)} horizontal, {len(v_lines)} vertical"
|
108 |
+
)
|
109 |
+
return None, edges
|
110 |
+
|
111 |
+
# Find intersections to get 8x8 grid corners
|
112 |
+
corners = []
|
113 |
+
for h in h_lines[:8]:  # first 8 horizontal lines (square top edges); the 9th is the board's bottom edge
|
114 |
+
y = h[0]
|
115 |
+
for v in v_lines[:8]:  # first 8 vertical lines (square left edges), so exactly 64 corners are produced
|
116 |
+
x = v[0]
|
117 |
+
corners.append([x, y])
|
118 |
+
|
119 |
+
# corners = []
|
120 |
+
# for i in range(8):
|
121 |
+
# for j in range(8):
|
122 |
+
# x = int((v_lines[j][0] + v_lines[j + 1][0]) / 2)
|
123 |
+
# y = int((h_lines[i][1] + h_lines[i + 1][1]) / 2)
|
124 |
+
# corners.append([x, y])
|
125 |
+
|
126 |
+
# Ensure exactly 64 corners (8x8 grid)
|
127 |
+
if len(corners) != 64:
|
128 |
+
print(f"Expected 64 corners, got {len(corners)}")
|
129 |
+
return None, edges
|
130 |
+
|
131 |
+
corners = np.array(corners, dtype=np.float32).reshape(8, 8, 2)
|
132 |
+
|
133 |
+
# Visualize detected lines for debugging
|
134 |
+
debug_image = image.copy()
|
135 |
+
for y, x1, x2 in h_lines:
|
136 |
+
cv2.line(debug_image, (x1, y), (x2, y), (0, 255, 0), 2)
|
137 |
+
for x, y1, y2 in v_lines:
|
138 |
+
cv2.line(debug_image, (x, y1), (x, y2), (0, 0, 255), 2)
|
139 |
+
cv2.imwrite("lines_debug.png", debug_image)
|
140 |
+
|
141 |
+
return corners, edges
|
142 |
+
|
143 |
+
|
144 |
+
# Function to extract coordinates of chess pieces from an image
|
145 |
+
def extract_chessboard_coordinates(image_path):
|
146 |
+
try:
|
147 |
+
image = cv2.imread(image_path)
|
148 |
+
if image is None:
|
149 |
+
print(f"Failed to load image: {image_path}")
|
150 |
+
return []
|
151 |
+
except Exception as e:
|
152 |
+
print(f"Error loading image {image_path}: {e}")
|
153 |
+
return []
|
154 |
+
|
155 |
+
# Try OpenCV's chessboard detection first
|
156 |
+
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
157 |
+
ret, corners = cv2.findChessboardCorners(gray, (8, 8), None)
|
158 |
+
|
159 |
+
if ret:
|
160 |
+
corners = cv2.cornerSubPix(
|
161 |
+
gray,
|
162 |
+
corners,
|
163 |
+
(11, 11),
|
164 |
+
(-1, -1),
|
165 |
+
criteria=(cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 30, 0.1),
|
166 |
+
)
|
167 |
+
corners = corners.reshape(8, 8, 2)
|
168 |
+
else:
|
169 |
+
print("OpenCV chessboard detection failed. Attempting edge-based detection.")
|
170 |
+
corners, edges = detect_chessboard_grid(image)
|
171 |
+
if corners is None:
|
172 |
+
# Save edges for debugging
|
173 |
+
cv2.imwrite("edges_debug.png", edges)
|
174 |
+
print("Saved edge detection output to edges_debug.png")
|
175 |
+
return []
|
176 |
+
# Save debug image with detected corners
|
177 |
+
debug_image = image.copy()
|
178 |
+
for h in range(8):
|
179 |
+
for v in range(8):
|
180 |
+
x, y = int(corners[h, v, 0]), int(corners[h, v, 1])
|
181 |
+
cv2.circle(debug_image, (x, y), 5, (0, 255, 0), -1)
|
182 |
+
cv2.imwrite("grid_debug.png", debug_image)
|
183 |
+
print("Saved grid detection debug image to grid_debug.png")
|
184 |
+
|
185 |
+
# Calculate square size dynamically
|
186 |
+
square_width = np.mean(
|
187 |
+
[
|
188 |
+
np.linalg.norm(corners[i, j] - corners[i, j + 1])
|
189 |
+
for i in range(8)
|
190 |
+
for j in range(7)
|
191 |
+
]
|
192 |
+
)
|
193 |
+
square_height = np.mean(
|
194 |
+
[
|
195 |
+
np.linalg.norm(corners[i, j] - corners[i + 1, j])
|
196 |
+
for i in range(7)
|
197 |
+
for j in range(8)
|
198 |
+
]
|
199 |
+
)
|
200 |
+
square_size = int(min(square_width, square_height))
|
201 |
+
|
202 |
+
# Create a blank grid to store coordinates
|
203 |
+
piece_coordinates = []
|
204 |
+
|
205 |
+
# Loop through all coordinates and detect pieces
|
206 |
+
for i, j in coordinates:
|
207 |
+
try:
|
208 |
+
x = int(corners[i, j, 0])
|
209 |
+
y = int(corners[i, j, 1])
|
210 |
+
w = h = square_size
|
211 |
+
|
212 |
+
x = max(0, x)
|
213 |
+
y = max(0, y)
|
214 |
+
x_end = min(image.shape[1], x + w)
|
215 |
+
y_end = min(image.shape[0], y + h)
|
216 |
+
roi = image[y:y_end, x:x_end]
|
217 |
+
|
218 |
+
if roi.shape[0] == 0 or roi.shape[1] == 0:
|
219 |
+
print(f"Invalid ROI at square ({i}, {j})")
|
220 |
+
piece_coordinates.append(((i, j), "unknown"))
|
221 |
+
continue
|
222 |
+
|
223 |
+
predicted_piece = predict_piece(roi, model, device)
|
224 |
+
piece_coordinates.append(((i, j), predicted_piece))
|
225 |
+
except Exception as e:
|
226 |
+
print(f"Error processing square ({i}, {j}): {e}")
|
227 |
+
piece_coordinates.append(((i, j), "unknown"))
|
228 |
+
|
229 |
+
return piece_coordinates
|
230 |
+
|
231 |
+
|
232 |
+
# Example usage
|
233 |
+
IMAGE_PATH = "test.png"
|
234 |
+
coordinates = extract_chessboard_coordinates(IMAGE_PATH)
|
235 |
+
for coord, piece in coordinates:
|
236 |
+
print(f"Piece at {coord}: {piece}")
|
237 |
+
|
238 |
+
del model
|
239 |
+
if torch.cuda.is_available():
|
240 |
+
torch.cuda.empty_cache()
|
241 |
+
gc.collect()
|
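A small illustrative sketch (not part of the committed file) of how the ((row, col), label) pairs returned by extract_chessboard_coordinates could be rendered as a printable 8x8 board; the single-letter piece codes below are an assumption made purely for display:

# Sketch: turn [((row, col), label), ...] into a printable 8x8 grid.
# Assumes labels come from piece_labels above ("white_pawn", "empty", "unknown", ...).
def render_board(piece_coordinates):
    symbols = {
        "pawn": "p", "knight": "n", "bishop": "b",
        "rook": "r", "queen": "q", "king": "k",
    }
    board = [["." for _ in range(8)] for _ in range(8)]
    for (row, col), label in piece_coordinates:
        if label in ("empty", "unknown"):
            continue
        color, _, piece = label.partition("_")
        letter = symbols.get(piece, "?")
        board[row][col] = letter.upper() if color == "white" else letter
    return "\n".join(" ".join(rank) for rank in board)

# Example: print(render_board(extract_chessboard_coordinates("test.png")))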
build/lib/tools/model.py
ADDED
@@ -0,0 +1,27 @@
1 |
+
import torch
|
2 |
+
from ollama import chat
|
3 |
+
|
4 |
+
# from pathlib import Path
|
5 |
+
|
6 |
+
# Pass in the path to the image
|
7 |
+
path = input("Please enter the path to the image: ")
|
8 |
+
|
9 |
+
# You can also pass in base64 encoded image data
|
10 |
+
# img = base64.b64encode(Path(path).read_bytes()).decode()
|
11 |
+
# or the raw bytes
|
12 |
+
# img = Path(path).read_bytes()
|
13 |
+
|
14 |
+
response = chat(
|
15 |
+
model="gemma3:latest",
|
16 |
+
messages=[
|
17 |
+
{
|
18 |
+
"role": "user",
|
19 |
+
"content": "Review the chess position provided in the image. It is black's turn. Provide the correct next move for black which guarantees a win. Please provide your response in algebraic notation.",
|
20 |
+
"images": [path],
|
21 |
+
}
|
22 |
+
],
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
print(response.message.content)
|
27 |
+
torch.cuda.empty_cache()
|
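The commented-out lines in model.py above note that the image can also be passed as base64-encoded data or raw bytes; a minimal sketch of the base64 variant, assuming the same local Ollama server and gemma3 model:

import base64
from pathlib import Path

from ollama import chat

# Sketch: send the image as base64-encoded data instead of a file path.
img_b64 = base64.b64encode(Path("test.png").read_bytes()).decode()

response = chat(
    model="gemma3:latest",
    messages=[
        {
            "role": "user",
            "content": "Describe the chess position shown in the image.",
            "images": [img_b64],
        }
    ],
)
print(response.message.content)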
build/lib/tools/model_chess.py
ADDED
@@ -0,0 +1,188 @@
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
from torch import nn, optim
|
7 |
+
from torch.utils.data import DataLoader, Dataset, random_split
|
8 |
+
from torchvision import models, transforms
|
9 |
+
|
10 |
+
# Define data transformations for training and validation
|
11 |
+
transform = transforms.Compose(
|
12 |
+
[
|
13 |
+
transforms.Resize((224, 224)), # Ensure all images are 224x224
|
14 |
+
transforms.ToTensor(), # Convert to tensor
|
15 |
+
transforms.Normalize(
|
16 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
17 |
+
), # Standard for ResNet
|
18 |
+
]
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
# Custom dataset class for loading chess piece images
|
23 |
+
class ChessPieceDataset(Dataset):
|
24 |
+
def __init__(self, root_dir, transform=None):
|
25 |
+
"""
|
26 |
+
Args:
|
27 |
+
root_dir (str): Directory with all the images and subdirectories (class labels).
|
28 |
+
transform (callable, optional): Optional transform to be applied on an image.
|
29 |
+
"""
|
30 |
+
self.root_dir = root_dir
|
31 |
+
self.transform = transform
|
32 |
+
self.classes = sorted(
|
33 |
+
[
|
34 |
+
d
|
35 |
+
for d in os.listdir(root_dir)
|
36 |
+
if os.path.isdir(os.path.join(root_dir, d))
|
37 |
+
]
|
38 |
+
)
|
39 |
+
self.image_paths = []
|
40 |
+
self.labels = []
|
41 |
+
|
42 |
+
for label, class_name in enumerate(self.classes):
|
43 |
+
class_folder = os.path.join(root_dir, class_name)
|
44 |
+
for image_name in os.listdir(class_folder):
|
45 |
+
img_path = os.path.join(class_folder, image_name)
|
46 |
+
# Only include valid image files
|
47 |
+
if img_path.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".gif")):
|
48 |
+
try:
|
49 |
+
# Verify the image can be opened
|
50 |
+
with Image.open(img_path) as img:
|
51 |
+
img.verify() # Verify image integrity
|
52 |
+
self.image_paths.append(img_path)
|
53 |
+
self.labels.append(label)
|
54 |
+
except Exception as e:
|
55 |
+
print(f"Skipping corrupted image {img_path}: {e}")
|
56 |
+
|
57 |
+
def __len__(self):
|
58 |
+
return len(self.image_paths)
|
59 |
+
|
60 |
+
def __getitem__(self, idx):
|
61 |
+
img_path = self.image_paths[idx]
|
62 |
+
try:
|
63 |
+
image = Image.open(img_path).convert("RGB")
|
64 |
+
except Exception as e:
|
65 |
+
print(f"Error loading image {img_path}: {e}")
|
66 |
+
# Return a dummy image and label to avoid crashing
|
67 |
+
image = Image.new("RGB", (224, 224), (0, 0, 0))
|
68 |
+
label = self.labels[idx]
|
69 |
+
else:
|
70 |
+
label = self.labels[idx]
|
71 |
+
|
72 |
+
if self.transform:
|
73 |
+
try:
|
74 |
+
image = self.transform(image)
|
75 |
+
# Verify the image size after transformation
|
76 |
+
if image.shape != (3, 224, 224):
|
77 |
+
print(
|
78 |
+
f"Unexpected image size after transform for {img_path}: {image.shape}"
|
79 |
+
)
|
80 |
+
except Exception as e:
|
81 |
+
print(f"Error applying transform to {img_path}: {e}")
|
82 |
+
image = self.transform(Image.new("RGB", (224, 224), (0, 0, 0)))
|
83 |
+
|
84 |
+
return image, label
|
85 |
+
|
86 |
+
|
87 |
+
# Define training function (unchanged)
|
88 |
+
def train_model(
|
89 |
+
model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device="cpu"
|
90 |
+
):
|
91 |
+
best_accuracy = 0.0
|
92 |
+
|
93 |
+
for epoch in range(num_epochs):
|
94 |
+
model.train()
|
95 |
+
running_loss = 0.0
|
96 |
+
correct = 0
|
97 |
+
total = 0
|
98 |
+
|
99 |
+
for inputs, labels in train_loader:
|
100 |
+
inputs, labels = inputs.to(device), labels.to(device)
|
101 |
+
optimizer.zero_grad()
|
102 |
+
outputs = model(inputs)
|
103 |
+
loss = criterion(outputs, labels)
|
104 |
+
loss.backward()
|
105 |
+
optimizer.step()
|
106 |
+
|
107 |
+
running_loss += loss.item()
|
108 |
+
_, predicted = torch.max(outputs, 1)
|
109 |
+
correct += (predicted == labels).sum().item()
|
110 |
+
total += labels.size(0)
|
111 |
+
|
112 |
+
model.eval()
|
113 |
+
val_correct = 0
|
114 |
+
val_total = 0
|
115 |
+
|
116 |
+
with torch.no_grad():
|
117 |
+
for inputs, labels in val_loader:
|
118 |
+
inputs, labels = inputs.to(device), labels.to(device)
|
119 |
+
outputs = model(inputs)
|
120 |
+
_, predicted = torch.max(outputs, 1)
|
121 |
+
val_correct += (predicted == labels).sum().item()
|
122 |
+
val_total += labels.size(0)
|
123 |
+
|
124 |
+
epoch_loss = running_loss / len(train_loader)
|
125 |
+
epoch_train_accuracy = 100 * correct / total
|
126 |
+
epoch_val_accuracy = 100 * val_correct / val_total
|
127 |
+
|
128 |
+
print(
|
129 |
+
f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, "
|
130 |
+
f"Train Accuracy: {epoch_train_accuracy:.2f}%, "
|
131 |
+
f"Validation Accuracy: {epoch_val_accuracy:.2f}%"
|
132 |
+
)
|
133 |
+
|
134 |
+
if epoch_val_accuracy > best_accuracy:
|
135 |
+
best_accuracy = epoch_val_accuracy
|
136 |
+
torch.save(model.state_dict(), "best_chess_piece_model.pth")
|
137 |
+
|
138 |
+
print("Training completed.")
|
139 |
+
|
140 |
+
|
141 |
+
# Path to dataset folder
|
142 |
+
dataset_path = "train" # Ensure this path is correct
|
143 |
+
|
144 |
+
# Create dataset
|
145 |
+
full_dataset = ChessPieceDataset(dataset_path, transform=transform)
|
146 |
+
|
147 |
+
# Check if dataset is empty
|
148 |
+
if len(full_dataset) == 0:
|
149 |
+
raise ValueError(
|
150 |
+
"Dataset is empty. Check dataset_path and ensure it contains valid images."
|
151 |
+
)
|
152 |
+
|
153 |
+
# Split the dataset into training and validation sets
|
154 |
+
train_size = int(0.8 * len(full_dataset))
|
155 |
+
val_size = len(full_dataset) - train_size
|
156 |
+
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
|
157 |
+
|
158 |
+
# Create DataLoaders
|
159 |
+
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
160 |
+
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
|
161 |
+
|
162 |
+
# Set device
|
163 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
164 |
+
|
165 |
+
# Load the pre-trained ResNet18 model and modify the final layer
|
166 |
+
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
|
167 |
+
model.fc = nn.Linear(model.fc.in_features, len(full_dataset.classes))
|
168 |
+
model = model.to(device)
|
169 |
+
|
170 |
+
# Define loss function and optimizer
|
171 |
+
criterion = nn.CrossEntropyLoss()
|
172 |
+
optimizer = optim.Adam(model.parameters(), lr=0.0001)
|
173 |
+
|
174 |
+
# Train the model
|
175 |
+
train_model(
|
176 |
+
model, train_loader, val_loader, criterion, optimizer, num_epochs=10, device=device
|
177 |
+
)
|
178 |
+
|
179 |
+
# After training, load the best model for inference
|
180 |
+
model.load_state_dict(torch.load("best_chess_piece_model.pth", map_location=device))
|
181 |
+
model.eval()
|
182 |
+
|
183 |
+
gc.collect()
|
184 |
+
|
185 |
+
del model
|
186 |
+
torch.cuda.empty_cache()
|
187 |
+
|
188 |
+
gc.collect()
|
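For reference, a hedged sketch of how the checkpoint written by train_model above (best_chess_piece_model.pth) can be reloaded for single-image inference; it assumes the class list matches the sorted training folder names and reuses the same preprocessing:

import torch
from PIL import Image
from torchvision import models, transforms

# Same preprocessing as the training transform above.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def load_classifier(num_classes, weights_path="best_chess_piece_model.pth", device="cpu"):
    # Rebuild the ResNet18 head exactly as in training, then load the best weights.
    model = models.resnet18(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(torch.load(weights_path, map_location=device))
    return model.eval().to(device)

def classify_square(image_path, model, classes, device="cpu"):
    # Classify one cropped board square and return its label.
    image = Image.open(image_path).convert("RGB")
    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        predicted = model(batch).argmax(dim=1).item()
    return classes[predicted]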
build/lib/tools/test.py
ADDED
@@ -0,0 +1,15 @@
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
# Load the colorful image
|
6 |
+
image_path = "test.png"
|
7 |
+
image = cv2.imread(image_path)
|
8 |
+
|
9 |
+
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
10 |
+
gray = cv2.convertScaleAbs(gray, alpha=1.2, beta=20)
|
11 |
+
gray = cv2.GaussianBlur(gray, (5, 5), 0)
|
12 |
+
edges = cv2.Canny(gray, 50, 150, apertureSize=3)
|
13 |
+
|
14 |
+
|
15 |
+
cv2.imwrite("new.png", edges)
|
build/lib/tools/test_1.py
ADDED
@@ -0,0 +1,12 @@
1 |
+
import os
|
2 |
+
|
3 |
+
from src.tools import use_vision_model
|
4 |
+
|
5 |
+
# Test question with a valid image path (make sure the image file exists)
|
6 |
+
question = "Review the chess position provided in the image. It is black's turn. Provide the correct next move for black which guarantees a win. Please provide your response in algebraic notation. test.png"
|
7 |
+
|
8 |
+
# Run the function (assuming your model and required environment are set up)
|
9 |
+
response = use_vision_model(question)
|
10 |
+
|
11 |
+
# Print the response
|
12 |
+
print(response)
|
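Since pytest.ini (next file) sets pythonpath to the repo root, the same smoke test can be written as a pytest case; a sketch, assuming test.png is present in the working directory:

import os

import pytest

from src.tools import use_vision_model

@pytest.mark.skipif(not os.path.exists("test.png"), reason="sample image not available")
def test_use_vision_model_returns_text():
    question = "Review the chess position in the image and suggest black's best move. test.png"
    answer = use_vision_model.invoke({"question": question})
    assert isinstance(answer, str) and answer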
pytest.ini
ADDED
@@ -0,0 +1,3 @@
1 |
+
[pytest]
|
2 |
+
pythonpath = .
|
3 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
1 |
+
gradio
|
2 |
+
requests
|
setup.py
ADDED
@@ -0,0 +1,15 @@
1 |
+
from setuptools import setup, find_packages
|
2 |
+
|
3 |
+
|
4 |
+
def read_requirements():
|
5 |
+
with open("requirements.txt") as f:
|
6 |
+
return [line.strip() for line in f if line.strip() and not line.startswith("#")]
|
7 |
+
|
8 |
+
|
9 |
+
setup(
|
10 |
+
name="src",
|
11 |
+
version="0.1",
|
12 |
+
packages=find_packages(),
|
13 |
+
install_requires=read_requirements(),
|
14 |
+
python_requires=">=3.8",
|
15 |
+
)
|
src.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,8 @@
1 |
+
Metadata-Version: 2.4
|
2 |
+
Name: src
|
3 |
+
Version: 0.1
|
4 |
+
Requires-Python: >=3.8
|
5 |
+
Requires-Dist: gradio
|
6 |
+
Requires-Dist: requests
|
7 |
+
Dynamic: requires-dist
|
8 |
+
Dynamic: requires-python
|
src.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,20 @@
1 |
+
README.md
|
2 |
+
setup.py
|
3 |
+
src/__init__.py
|
4 |
+
src/agent.py
|
5 |
+
src/final_answer.py
|
6 |
+
src/tools.py
|
7 |
+
src.egg-info/PKG-INFO
|
8 |
+
src.egg-info/SOURCES.txt
|
9 |
+
src.egg-info/dependency_links.txt
|
10 |
+
src.egg-info/requires.txt
|
11 |
+
src.egg-info/top_level.txt
|
12 |
+
tests/test.py
|
13 |
+
tests/test_test.py
|
14 |
+
tests/test_tools.py
|
15 |
+
tools/__init__.py
|
16 |
+
tools/chess.py
|
17 |
+
tools/model.py
|
18 |
+
tools/model_chess.py
|
19 |
+
tools/test.py
|
20 |
+
tools/test_1.py
|
src.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
1 |
+
|
src.egg-info/requires.txt
ADDED
@@ -0,0 +1,2 @@
1 |
+
gradio
|
2 |
+
requests
|
src.egg-info/top_level.txt
ADDED
@@ -0,0 +1,2 @@
1 |
+
src
|
2 |
+
tools
|
src/__init__.py
ADDED
File without changes
|
src/agent.py
ADDED
@@ -0,0 +1,367 @@
1 |
+
import gc
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import tempfile
|
5 |
+
import time
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from dotenv import load_dotenv
|
10 |
+
from langchain.agents import AgentExecutor, create_tool_calling_agent
|
11 |
+
from langchain_core.prompts import ChatPromptTemplate
|
12 |
+
from langchain_core.rate_limiters import InMemoryRateLimiter
|
13 |
+
from langchain_core.tools import Tool
|
14 |
+
from langchain_experimental.utilities import PythonREPL
|
15 |
+
|
16 |
+
# from langchain_google_community import GoogleSearchAPIWrapper, GoogleSearchResults
|
17 |
+
from langchain_ollama import ChatOllama
|
18 |
+
|
19 |
+
from src.final_answer import create_final_answer_graph, validate_answer
|
20 |
+
from src.tools import analyze_csv_file # run_code_from_file,
|
21 |
+
from src.tools import (
|
22 |
+
analyze_excel_file,
|
23 |
+
download_file_from_url,
|
24 |
+
duckduckgo_search,
|
25 |
+
extract_text_from_image,
|
26 |
+
read_file,
|
27 |
+
reverse_decoder,
|
28 |
+
review_youtube_video,
|
29 |
+
transcribe_audio,
|
30 |
+
transcribe_youtube,
|
31 |
+
use_vision_model,
|
32 |
+
video_frames_to_images,
|
33 |
+
website_scrape,
|
34 |
+
)
|
35 |
+
|
36 |
+
logger = logging.getLogger(__name__)
|
37 |
+
|
38 |
+
load_dotenv()
|
39 |
+
|
40 |
+
base_url = os.getenv("OLLAMA_BASE_URL")
|
41 |
+
|
42 |
+
rate_limiter = InMemoryRateLimiter(requests_per_second=0.1)
|
43 |
+
|
44 |
+
|
45 |
+
class BasicAgent:
|
46 |
+
def __init__(self):
|
47 |
+
try:
|
48 |
+
logger.info("Initializing BasicAgent")
|
49 |
+
|
50 |
+
# Create the prompt template
|
51 |
+
prompt = ChatPromptTemplate.from_messages(
|
52 |
+
[
|
53 |
+
(
|
54 |
+
"system",
|
55 |
+
"""You are a general AI assistant. I will ask you a
|
56 |
+
question. Report your thoughts, and finish your answer
|
57 |
+
with the following template: FINAL ANSWER: [YOUR FINAL
|
58 |
+
ANSWER]. YOUR FINAL ANSWER should be a number OR as few
|
59 |
+
words as possible OR a comma separated list of numbers
|
60 |
+
and/or strings. If you are asked for a number, don't
|
61 |
+
use comma to write your number neither use units such
|
62 |
+
as $ or percent sign unless specified otherwise. If you
|
63 |
+
are asked for a string, don't use articles, neither
|
64 |
+
abbreviations (e.g. for cities), and write the digits
|
65 |
+
in plain text unless specified otherwise. If you are
|
66 |
+
asked for a comma separated list, apply the above rules
|
67 |
+
depending of whether the element to be put in the list
|
68 |
+
is a number or a string.
|
69 |
+
""",
|
70 |
+
),
|
71 |
+
("placeholder", "{chat_history}"),
|
72 |
+
("human", "{input}"),
|
73 |
+
("placeholder", "{agent_scratchpad}"),
|
74 |
+
]
|
75 |
+
)
|
76 |
+
logger.info("Created prompt template")
|
77 |
+
|
78 |
+
llm = ChatOllama(
|
79 |
+
model="hf.co/lmstudio-community/Qwen2.5-14B-Instruct-GGUF:Q6_K",
|
80 |
+
base_url=base_url,
|
81 |
+
temperature=0.2,
|
82 |
+
)
|
83 |
+
logger.info("Created model successfully")
|
84 |
+
|
85 |
+
# Define available tools
|
86 |
+
tools = [
|
87 |
+
# Tool(
|
88 |
+
# name="run_code_from_file",
|
89 |
+
# func=run_code_from_file,
|
90 |
+
# description="Executes a full Python script from a file. Use for multi-line code, loops, and class/function definitions.",
|
91 |
+
# ),
|
92 |
+
Tool(
|
93 |
+
name="DuckDuckGoSearchResults",
|
94 |
+
description="""Performs a live search using DuckDuckGo
|
95 |
+
and analyzes the top results. Returns a summary including
|
96 |
+
result titles, URLs, brief snippets, and ranking
|
97 |
+
positions. Use this to quickly assess the relevance,
|
98 |
+
diversity, and quality of information retrieved from a
|
99 |
+
privacy-focused search engine, without personalized or
|
100 |
+
biased filtering.""",
|
101 |
+
func=duckduckgo_search,
|
102 |
+
),
|
103 |
+
# Tool(
|
104 |
+
# name="GoogleSearchResults",
|
105 |
+
# description="""Performs a live Google search and analyzes
|
106 |
+
# the top results. Returns a summary including result titles,
|
107 |
+
# URLs, brief snippets, and ranking positions. Use this to
|
108 |
+
# quickly understand the relevance, variety, and quality of
|
109 |
+
# search results for a given query before deeper research or
|
110 |
+
# content planning.""",
|
111 |
+
# func=GoogleSearchResults(
|
112 |
+
# api_wrapper=GoogleSearchAPIWrapper(
|
113 |
+
# google_api_key=os.getenv("GOOGLE_SEARCH_API_KEY"),
|
114 |
+
# google_cse_id=os.getenv("GOOGLE_CSE_ID"),
|
115 |
+
# k=5, # Number of results to return
|
116 |
+
# )
|
117 |
+
# ).run,
|
118 |
+
# ),
|
119 |
+
Tool(
|
120 |
+
name="analyze csv file",
|
121 |
+
description="""Only read and analyze the contents of a CSV
|
122 |
+
file if one is explicitly referenced or uploaded in the
|
123 |
+
question. When a CSV file is provided, return a summary of
|
124 |
+
the dataset, including column names, data types, missing
|
125 |
+
value counts, basic statistics for numeric fields, and a
|
126 |
+
preview of the data. Use this only to quickly understand
|
127 |
+
the structure and quality of the dataset before performing
|
128 |
+
any further analysis.Do not invoke this tool for any URL""",
|
129 |
+
func=analyze_csv_file,
|
130 |
+
),
|
131 |
+
Tool(
|
132 |
+
name="analyze excel file",
|
133 |
+
description="""Reads and analyzes the contents of an Excel
|
134 |
+
file (.xlsx or .xls). Returns structured summaries
|
135 |
+
for each sheet, including column names, data types, missing
|
136 |
+
value counts, basic statistics for numeric columns, and
|
137 |
+
sample rows. Use this to quickly explore the structure and
|
138 |
+
quality of Excel datasets.Dont try to generate new names of
|
139 |
+
a file""",
|
140 |
+
func=analyze_excel_file,
|
141 |
+
),
|
142 |
+
Tool(
|
143 |
+
name="download file from url",
|
144 |
+
description="""Downloads a file from a given URL and saves
|
145 |
+
it locally. Supports various file types such as CSV, Excel,
|
146 |
+
images, and PDFs. Use this to retrieve external resources
|
147 |
+
for processing or analysis.""",
|
148 |
+
func=download_file_from_url,
|
149 |
+
),
|
150 |
+
Tool(
|
151 |
+
name="extract_text_from_image",
|
152 |
+
description="""Performs Optical Character Recognition (OCR)
|
153 |
+
on an image to extract readable text after downloading it.
|
154 |
+
Supports common image formats (e.g., PNG, JPG). Use this to
|
155 |
+
digitize printed or handwritten content from images for
|
156 |
+
search, analysis, or storage.""",
|
157 |
+
func=extract_text_from_image,
|
158 |
+
),
|
159 |
+
Tool(
|
160 |
+
name="read_file",
|
161 |
+
description="""Executes a full Python script from a file. Use for multi-line code, loops, and class/function definitions. IT IS EXTREMELY IMPORTANT THAT YOU USE THIS FOR A PYTHON FILE""",
|
162 |
+
func=read_file,
|
163 |
+
),
|
164 |
+
Tool(
|
165 |
+
name="review_youtube_video",
|
166 |
+
description="""Analyzes a YouTube video by extracting key
|
167 |
+
information such as title, description, view count, likes,
|
168 |
+
comments, and transcript (if available). Use this to
|
169 |
+
generate summaries, insights, or sentiment analysis based
|
170 |
+
on video content and engagement.""",
|
171 |
+
func=review_youtube_video,
|
172 |
+
),
|
173 |
+
Tool(
|
174 |
+
name="transcribe_audio",
|
175 |
+
description="""Converts spoken words in an audio file into
|
176 |
+
written text using speech-to-text technology. Supports
|
177 |
+
common audio formats like MP3, WAV, and FLAC. Use this to
|
178 |
+
create transcripts for meetings, interviews, podcasts, or
|
179 |
+
any spoken content. If asked for pages just give page number as an output nothing else.
|
180 |
+
Change "vanilla extract" to "pure vanilla extract" in the final answer.
|
181 |
+
Dont try to generate new file paths when invoking this tool""",
|
182 |
+
func=transcribe_audio,
|
183 |
+
),
|
184 |
+
Tool(
|
185 |
+
name="transcribe_youtube",
|
186 |
+
description="""Extracts and converts the audio from a
|
187 |
+
YouTube video into text using speech-to-text technology.
|
188 |
+
Supports generating transcripts for videos without captions
|
189 |
+
or subtitles. Use this to obtain searchable, readable text
|
190 |
+
from YouTube content.""",
|
191 |
+
func=transcribe_youtube,
|
192 |
+
),
|
193 |
+
Tool(
|
194 |
+
name="use_vision_model",
|
195 |
+
description="""Processes images using a computer vision
|
196 |
+
model to perform tasks such as object detection, image
|
197 |
+
classification, or segmentation. Use this to analyze visual
|
198 |
+
content and extract meaningful information from images.""",
|
199 |
+
func=use_vision_model,
|
200 |
+
),
|
201 |
+
Tool(
|
202 |
+
name="video_frames_to_images",
|
203 |
+
description="""Extracts individual frames from a video file
|
204 |
+
and saves them as separate image files. Use this to
|
205 |
+
analyze, process, or visualize specific moments within
|
206 |
+
video content. Use this to Youtube Videos""",
|
207 |
+
func=video_frames_to_images,
|
208 |
+
),
|
209 |
+
Tool(
|
210 |
+
name="website_scrape",
|
211 |
+
description="""It is mandatory to use duckduckgo_search
|
212 |
+
tool before invoking this tool .Use this tool only to scrap from websites.
|
213 |
+
Fetches and extracts content from a specified website URL. Supports retrieving text, images, links, and other page elements.""",
|
214 |
+
func=website_scrape,
|
215 |
+
),
|
216 |
+
Tool(
|
217 |
+
name="python_repl",
|
218 |
+
# description="""Use this tool to execute Python code read from a file. Make sure that if you're passing multi-line Python code, it should be formatted with actual line breaks (`\n`) rather than the string escape sequence (`\\n`). If you need to include line breaks in the code, they should be written as newlines, not as (`\\n`). Additionally, ensure that no unexpected escape characters (`\`) are left unescaped. If you want to see the output of a value, always use `print(...)` to display results. Do not return values as strings. For example, use `print(f'{total_sales_food:.2f}')` instead of returning `f'{total_sales_food:.2f}'`. If the code involves reading files, use the appropriate tools, such as `read_file`, for that. """,
|
219 |
+
description="""Use this tool to execute Python code read from a file. Make sure that if you're passing multi-line Python code, it should be formatted with actual line breaks (\\n) rather than the string escape sequence (\\\\n). If you need to include line breaks in the code, they should be written as newlines, not as (\\\\n). Additionally, ensure that no unexpected escape characters (\\`) are left unescaped. If you want to see the output of a value, always use `print(...)` to display results. Do not return values as strings. For example, use `print(f'{total_sales_food:.2f}')` instead of returning `f'{total_sales_food:.2f}'`. If the code involves reading files, use the appropriate tools, such as `read_file`, for that.""",
|
220 |
+
func=PythonREPL().run,
|
221 |
+
return_direct=True,
|
222 |
+
),
|
223 |
+
# Tool(
|
224 |
+
# name="wiki",
|
225 |
+
# description="""Retrieves summarized information or
|
226 |
+
# detailed content from Wikipedia based on a user query.
|
227 |
+
# Use this to quickly access encyclopedic knowledge and
|
228 |
+
# relevant facts on a wide range of topics.""",
|
229 |
+
# func=wiki,
|
230 |
+
# ),
|
231 |
+
Tool(
|
232 |
+
name="reverse decoder",
|
233 |
+
description="""Decodes a reversed sentence if the input
|
234 |
+
appears to be written backward.""",
|
235 |
+
func=reverse_decoder,
|
236 |
+
),
|
237 |
+
]
|
238 |
+
# tools = [wrap_tool_with_limit(tool, max_calls=3) for tool in raw_tools]
|
239 |
+
logger.info("Tools: %s", tools)
|
240 |
+
|
241 |
+
# Create the agent
|
242 |
+
agent = create_tool_calling_agent(llm, tools, prompt)
|
243 |
+
logger.info("Created tool calling agent")
|
244 |
+
|
245 |
+
# Create the agent executor
|
246 |
+
self.agent_executor = AgentExecutor(
|
247 |
+
agent=agent,
|
248 |
+
tools=tools,
|
249 |
+
return_intermediate_steps=True,
|
250 |
+
verbose=True,
|
251 |
+
max_iterations=5,
|
252 |
+
)
|
253 |
+
logger.info("Created agent executor")
|
254 |
+
|
255 |
+
# Create the graph
|
256 |
+
self.validation_graph = create_final_answer_graph()
|
257 |
+
|
258 |
+
except Exception as e:
|
259 |
+
logger.error("Error initializing agent: %s", e, exc_info=True)
|
260 |
+
raise
|
261 |
+
|
262 |
+
def __call__(self, question: str, task_id: str) -> str:
|
263 |
+
"""Execute the agent with the given question and optional file.
|
264 |
+
Args:
|
265 |
+
question (str): The question to answer
|
266 |
+
task_id (str): The task ID to fetch the file
|
267 |
+
Returns:
|
268 |
+
str: The final validated answer
|
269 |
+
Raises:
|
270 |
+
Exception: If no valid answer is found after max retries
|
271 |
+
"""
|
272 |
+
max_retries = 3
|
273 |
+
attempt = 0
|
274 |
+
|
275 |
+
previous_steps = set()
|
276 |
+
|
277 |
+
with tempfile.TemporaryDirectory() as temp_dir:
|
278 |
+
while attempt < max_retries:
|
279 |
+
default_api_url = os.getenv("DEFAULT_API_URL")
|
280 |
+
file_url = f"{default_api_url}/files/{task_id}"
|
281 |
+
|
282 |
+
file: Optional[dict] = None
|
283 |
+
try:
|
284 |
+
# Download file to temporary directory
|
285 |
+
file = download_file_from_url.invoke(
|
286 |
+
{
|
287 |
+
"url": file_url,
|
288 |
+
"directory": temp_dir,
|
289 |
+
}
|
290 |
+
)
|
291 |
+
time.sleep(1)
|
292 |
+
logger.info(f"Downloaded file for {task_id}")
|
293 |
+
except Exception as download_error:
|
294 |
+
logger.error(f"File download failed: {str(download_error)}")
|
295 |
+
file = None
|
296 |
+
|
297 |
+
try:
|
298 |
+
attempt += 1
|
299 |
+
logger.info("Attempt %d of %d", attempt, max_retries)
|
300 |
+
|
301 |
+
# Prepare input with file information
|
302 |
+
input_data = {
|
303 |
+
"input": question
|
304 |
+
+ (
|
305 |
+
f" [File: type={file.get('type', 'None')}, path={file.get('path', 'None')}]"
|
306 |
+
if file and file.get("type") != "error"
|
307 |
+
else ""
|
308 |
+
),
|
309 |
+
}
|
310 |
+
|
311 |
+
# Run the agent to get the answer
|
312 |
+
result = self.agent_executor.invoke(input_data)
|
313 |
+
answer = result.get("output", "")
|
314 |
+
intermediate_steps = result.get("intermediate_steps", [])
|
315 |
+
|
316 |
+
steps_str = str(intermediate_steps)
|
317 |
+
if steps_str in previous_steps:
|
318 |
+
logger.warning(
|
319 |
+
f"Detected repeated reasoning steps on attempt {attempt}. Breaking loop to avoid infinite retry."
|
320 |
+
)
|
321 |
+
break # or raise Exception to stop retries
|
322 |
+
previous_steps.add(steps_str)
|
323 |
+
|
324 |
+
logger.info("Attempt %d result: %s", attempt, result)
|
325 |
+
|
326 |
+
# Run validation (self.validation_graph is now StateGraph)
|
327 |
+
validation_result = validate_answer(
|
328 |
+
self.validation_graph, # type: ignore
|
329 |
+
answer,
|
330 |
+
[result.get("intermediate_steps", [])],
|
331 |
+
)
|
332 |
+
|
333 |
+
valid_answer = validation_result.get("valid_answer", False)
|
334 |
+
final_answer = validation_result.get("final_answer", "")
|
335 |
+
|
336 |
+
if valid_answer:
|
337 |
+
logger.info("Valid answer found on attempt %d", attempt)
|
338 |
+
torch.cuda.empty_cache()
|
339 |
+
return final_answer
|
340 |
+
|
341 |
+
logger.warning(
|
342 |
+
"Validation failed on attempt %d: %s", attempt, final_answer
|
343 |
+
)
|
344 |
+
if attempt >= max_retries:
|
345 |
+
raise Exception(
|
346 |
+
"Failed to get valid answer after %d attempts. Last error: %s",
|
347 |
+
max_retries,
|
348 |
+
final_answer,
|
349 |
+
)
|
350 |
+
|
351 |
+
except Exception as e:
|
352 |
+
logger.error("Error in attempt %d: %s", attempt, e, exc_info=True)
|
353 |
+
if attempt >= max_retries:
|
354 |
+
raise Exception(
|
355 |
+
"Failed after %d attempts. Last error: %s",
|
356 |
+
max_retries,
|
357 |
+
str(e),
|
358 |
+
)
|
359 |
+
continue
|
360 |
+
finally:
|
361 |
+
logger.info("cleaning up temp_dir")
|
362 |
+
torch.cuda.empty_cache()
|
363 |
+
gc.collect()
|
364 |
+
|
365 |
+
# Fallback in case loop exits unexpectedly
|
366 |
+
|
367 |
+
raise Exception("No valid answer found after processing")
|
src/final_answer.py
ADDED
@@ -0,0 +1,210 @@
1 |
+
import os
|
2 |
+
from typing import Any, Dict, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
from langchain_core.output_parsers import StrOutputParser
|
7 |
+
from langchain_core.prompts import ChatPromptTemplate
|
8 |
+
from langchain_ollama import ChatOllama
|
9 |
+
from langgraph.graph import END, START, Graph, StateGraph
|
10 |
+
from typing_extensions import TypedDict
|
11 |
+
|
12 |
+
load_dotenv()
|
13 |
+
|
14 |
+
base_url = os.getenv("OLLAMA_BASE_URL")
|
15 |
+
|
16 |
+
|
17 |
+
class AgentState(TypedDict):
|
18 |
+
"""State for the final answer validation graph."""
|
19 |
+
|
20 |
+
question: str
|
21 |
+
answer: str
|
22 |
+
final_answer: str | None
|
23 |
+
agent_memory: Any
|
24 |
+
valid_answer: bool
|
25 |
+
|
26 |
+
|
27 |
+
def extract_answer(state: AgentState) -> Dict:
|
28 |
+
"""Extract and format the final answer from the state.
|
29 |
+
Args:
|
30 |
+
state: The state of the agent.
|
31 |
+
Returns:
|
32 |
+
A dictionary with the formatted final answer.
|
33 |
+
"""
|
34 |
+
# Extract the final answer from the state
|
35 |
+
sep_token = "FINAL ANSWER:"
|
36 |
+
raw_answer = state["answer"]
|
37 |
+
|
38 |
+
# Extract the answer after the separator if it exists
|
39 |
+
if sep_token in raw_answer:
|
40 |
+
formatted_answer = raw_answer.split(sep_token)[1].strip()
|
41 |
+
else:
|
42 |
+
formatted_answer = raw_answer.strip()
|
43 |
+
|
44 |
+
# Remove any brackets from lists
|
45 |
+
formatted_answer = formatted_answer.replace("[", "").replace("]", "")
|
46 |
+
|
47 |
+
# Remove units unless specified
|
48 |
+
if not any(
|
49 |
+
unit in formatted_answer.lower() for unit in ["$", "%", "dollars", "percent"]
|
50 |
+
):
|
51 |
+
formatted_answer = formatted_answer.replace("$", "").replace("%", "")
|
52 |
+
|
53 |
+
# Remove commas from numbers
|
54 |
+
parts = formatted_answer.split(",")
|
55 |
+
formatted_parts = []
|
56 |
+
for part in parts:
|
57 |
+
part = part.strip()
|
58 |
+
if part.replace(".", "").isdigit(): # Check if it's a number
|
59 |
+
part = part.replace(",", "")
|
60 |
+
formatted_parts.append(part)
|
61 |
+
formatted_answer = ", ".join(formatted_parts)
|
62 |
+
|
63 |
+
return {"final_answer": formatted_answer}
|
64 |
+
|
65 |
+
|
66 |
+
def reasoning_check(state: AgentState) -> Dict:
|
67 |
+
"""
|
68 |
+
Node that checks the reasoning of the final answer.
|
69 |
+
Args:
|
70 |
+
state: The state of the agent.
|
71 |
+
Returns:
|
72 |
+
A dictionary with the reasoning check result.
|
73 |
+
"""
|
74 |
+
model = ChatOllama(
|
75 |
+
model="hf.co/lmstudio-community/Qwen2.5-14B-Instruct-GGUF:Q6_K",
|
76 |
+
base_url=base_url,
|
77 |
+
temperature=0.2,
|
78 |
+
)
|
79 |
+
prompt = ChatPromptTemplate.from_messages(
|
80 |
+
[
|
81 |
+
(
|
82 |
+
"system",
|
83 |
+
"""You are a strict validator of answers. Your job is to check if the reasoning and results are correct.
|
84 |
+
You should have >90% confidence that the answer is correct to pass it.
|
85 |
+
First list the reasons for and against, then write your final decision in capital letters: PASS if the answer is satisfactory, FAIL if it is not.""",
|
86 |
+
),
|
87 |
+
(
|
88 |
+
"human",
|
89 |
+
"""
|
90 |
+
Here is a user-given task and the agent steps: {agent_memory}
|
91 |
+
Now here is the answer that was given: {final_answer}
|
92 |
+
Please check that the reasoning process and results are correct: do they correctly answer the given task?
|
93 |
+
""",
|
94 |
+
),
|
95 |
+
]
|
96 |
+
)
|
97 |
+
|
98 |
+
chain = prompt | model | StrOutputParser()
|
99 |
+
output = chain.invoke(
|
100 |
+
{
|
101 |
+
"agent_memory": state["agent_memory"],
|
102 |
+
"final_answer": state["final_answer"],
|
103 |
+
}
|
104 |
+
)
|
105 |
+
|
106 |
+
print("Reasoning Feedback: ", output)
|
107 |
+
if "FAIL" in output:
|
108 |
+
return {"valid_answer": False}
|
109 |
+
torch.cuda.empty_cache()
|
110 |
+
return {"valid_answer": True}
|
111 |
+
|
112 |
+
|
113 |
+
def formatting_check(state: AgentState) -> Dict:
|
114 |
+
"""
|
115 |
+
Node that checks the formatting of the final answer.
|
116 |
+
Args:
|
117 |
+
state: The state of the agent.
|
118 |
+
Returns:
|
119 |
+
A dictionary with the formatting check result.
|
120 |
+
"""
|
121 |
+
model = ChatOllama(
|
122 |
+
model="hf.co/lmstudio-community/Qwen2.5-14B-Instruct-GGUF:Q6_K",
|
123 |
+
base_url=base_url,
|
124 |
+
temperature=0.2,
|
125 |
+
)
|
126 |
+
prompt = ChatPromptTemplate.from_messages(
|
127 |
+
[
|
128 |
+
(
|
129 |
+
"system",
|
130 |
+
"""You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
|
131 |
+
""",
|
132 |
+
),
|
133 |
+
(
|
134 |
+
"human",
|
135 |
+
"""
|
136 |
+
Here is a user-given task and the agent steps: {agent_memory}
|
137 |
+
Now here is the FINAL ANSWER that was given: {final_answer}
|
138 |
+
Check that the FINAL ANSWER is in the format the task asks for. Write PASS if it is, or FAIL if it is not.
|
139 |
+
""",
|
140 |
+
),
|
141 |
+
]
|
142 |
+
)
|
143 |
+
|
144 |
+
chain = prompt | model | StrOutputParser()
|
145 |
+
output = chain.invoke(
|
146 |
+
{
|
147 |
+
"agent_memory": state["agent_memory"],
|
148 |
+
"final_answer": state["final_answer"],
|
149 |
+
}
|
150 |
+
)
|
151 |
+
|
152 |
+
print("Formatting Feedback: ", output)
|
153 |
+
if "FAIL" in output:
|
154 |
+
return {"valid_answer": False}
|
155 |
+
|
156 |
+
torch.cuda.empty_cache()
|
157 |
+
return {"valid_answer": True}
|
158 |
+
|
159 |
+
|
160 |
+
def create_final_answer_graph() -> Graph:
|
161 |
+
"""Create a graph that validates the final answer.
|
162 |
+
Returns:
|
163 |
+
A graph that validates the final answer.
|
164 |
+
"""
|
165 |
+
# Create the graph
|
166 |
+
workflow = StateGraph(AgentState)
|
167 |
+
|
168 |
+
# Add nodes
|
169 |
+
workflow.add_node("extract_answer", extract_answer)
|
170 |
+
workflow.add_node("reasoning_check", reasoning_check)
|
171 |
+
workflow.add_node("formatting_check", formatting_check)
|
172 |
+
|
173 |
+
# Add edges
|
174 |
+
workflow.add_edge(START, "extract_answer")
|
175 |
+
workflow.add_edge("extract_answer", "reasoning_check")
|
176 |
+
workflow.add_edge("reasoning_check", "formatting_check")
|
177 |
+
workflow.add_edge("formatting_check", END)
|
178 |
+
|
179 |
+
# Compile the graph
|
180 |
+
return workflow.compile() # type: ignore
|
181 |
+
|
182 |
+
|
183 |
+
def validate_answer(graph: StateGraph, answer: str, agent_memory: Any) -> Dict:
|
184 |
+
"""Validate the answer using the LangGraph workflow.
|
185 |
+
Args:
|
186 |
+
graph: The validation graph (LangGraph StateGraph).
|
187 |
+
answer: The answer to validate.
|
188 |
+
agent_memory: The agent's memory.
|
189 |
+
Returns:
|
190 |
+
A dictionary with validation results.
|
191 |
+
"""
|
192 |
+
try:
|
193 |
+
# Initialize state
|
194 |
+
initial_state = {
|
195 |
+
"answer": answer,
|
196 |
+
"final_answer": None,
|
197 |
+
"agent_memory": agent_memory,
|
198 |
+
"valid_answer": False,
|
199 |
+
}
|
200 |
+
|
201 |
+
# Run the graph
|
202 |
+
result = graph.invoke(initial_state) # type:ignore
|
203 |
+
|
204 |
+
return {
|
205 |
+
"valid_answer": result.get("valid_answer", False),
|
206 |
+
"final_answer": result.get("final_answer", None),
|
207 |
+
}
|
208 |
+
except Exception as e:
|
209 |
+
print(f"Validation failed: {e}")
|
210 |
+
return {"valid_answer": False, "final_answer": None}
|
src/tools.py
ADDED
@@ -0,0 +1,751 @@
1 |
+
# pylint: disable=no-member
|
2 |
+
import base64
|
3 |
+
import gc
|
4 |
+
import math
|
5 |
+
import mimetypes
|
6 |
+
import multiprocessing
|
7 |
+
import os
|
8 |
+
import re
|
9 |
+
import tempfile
|
10 |
+
import time
|
11 |
+
import uuid
|
12 |
+
from datetime import timedelta
|
13 |
+
from typing import Dict, List, Optional, TypedDict, Union
|
14 |
+
from urllib.parse import urlparse
|
15 |
+
|
16 |
+
import cv2
|
17 |
+
import imageio
|
18 |
+
import pandas as pd
|
19 |
+
import pytesseract
|
20 |
+
import requests
|
21 |
+
import torch
|
22 |
+
import whisper
|
23 |
+
import yt_dlp
|
24 |
+
from bs4 import BeautifulSoup, Tag
|
25 |
+
from dotenv import load_dotenv
|
26 |
+
from duckduckgo_search import DDGS
|
27 |
+
from langchain_core.messages import HumanMessage
|
28 |
+
from langchain_core.tools import tool
|
29 |
+
from langchain_ollama import ChatOllama
|
30 |
+
from PIL import Image
|
31 |
+
from playwright.sync_api import sync_playwright
|
32 |
+
from youtube_transcript_api import (
|
33 |
+
NoTranscriptFound,
|
34 |
+
TranscriptsDisabled,
|
35 |
+
YouTubeTranscriptApi,
|
36 |
+
)
|
37 |
+
|
38 |
+
load_dotenv()
|
39 |
+
base_url = os.getenv("OLLAMA_BASE_URL")
|
40 |
+
model_vision = ChatOllama(
|
41 |
+
model="gemma3:latest",
|
42 |
+
base_url=base_url,
|
43 |
+
)
|
44 |
+
model_text = ChatOllama(
|
45 |
+
model="hf.co/lmstudio-community/Qwen2.5-14B-Instruct-GGUF:Q6_K", base_url=base_url
|
46 |
+
)
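|
47 |
+
# Two Ollama models are shared by the tools below: gemma3 for vision and Qwen2.5-14B for text.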
|
47 |
+
|
48 |
+
|
49 |
+
@tool
|
50 |
+
def use_vision_model(question: str) -> str:
|
51 |
+
"""
|
52 |
+
Answers a question about an image by sending the question text and the image
|
53 |
+
(whose file path must appear in the question) to a multimodal vision model.
|
54 |
+
"""
|
55 |
+
# Extract image paths
|
56 |
+
image_paths = re.findall(r"[\w\-/\.]+\.(?:png|jpg|jpeg|webp)", question)
|
57 |
+
image_paths = [p for p in image_paths if os.path.exists(p)]
|
58 |
+
|
59 |
+
if not image_paths:
|
60 |
+
return "No valid image file found in the question."
|
61 |
+
|
62 |
+
image_path = image_paths[0]
|
63 |
+
|
64 |
+
# # Preprocess the image using OpenCV
|
65 |
+
# image = cv2.imread(image_path)
|
66 |
+
# gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
67 |
+
# gray = cv2.convertScaleAbs(gray, alpha=1.2, beta=20)
|
68 |
+
# gray = cv2.GaussianBlur(gray, (5, 5), 0)
|
69 |
+
# edges = cv2.Canny(gray, 50, 150, apertureSize=3)
|
70 |
+
|
71 |
+
# # Create a temporary file for the processed image
|
72 |
+
# with tempfile.NamedTemporaryFile(suffix=".png", delete=True) as tmp_file:
|
73 |
+
# temp_image_path = tmp_file.name
|
74 |
+
# cv2.imwrite(temp_image_path, image)
|
75 |
+
|
76 |
+
# Encode the image (this code previously lived inside the tempfile block above)
|
77 |
+
mime_type, _ = mimetypes.guess_type(image_path)
|
78 |
+
mime_type = mime_type or "image/png"
|
79 |
+
with open(image_path, "rb") as f:
|
80 |
+
encoded = base64.b64encode(f.read()).decode("utf-8")
|
81 |
+
|
82 |
+
# Prepare the prompt and image for the model
|
83 |
+
messages = [
|
84 |
+
{
|
85 |
+
"role": "user",
|
86 |
+
"content": [
|
87 |
+
{"type": "text", "text": question},
|
88 |
+
{
|
89 |
+
"type": "image_url",
|
90 |
+
"image_url": {"url": f"data:{mime_type};base64,{encoded}"},
|
91 |
+
},
|
92 |
+
],
|
93 |
+
}
|
94 |
+
]
|
95 |
+
|
96 |
+
# Invoke the vision model
|
97 |
+
response = model_vision.invoke(messages)
|
98 |
+
|
99 |
+
# Clean up
|
100 |
+
del messages, encoded, image_path
|
101 |
+
gc.collect()
|
102 |
+
torch.cuda.empty_cache()
|
103 |
+
|
104 |
+
return str(response.content) if hasattr(response, "content") else str(response)
|
105 |
+
|
106 |
+
|
107 |
+
# YouTube Video Review Tool
|
108 |
+
@tool
|
109 |
+
def review_youtube_video(url: str) -> str:
|
110 |
+
"""Reviews a YouTube video and answers a specific question about that video.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
url (str): the URL to the YouTube video.
|
114 |
+
question (str): The question you are asking about the video.
|
115 |
+
|
116 |
+
Returns:
|
117 |
+
str: The answer to the question
|
118 |
+
"""
|
119 |
+
# Extract video ID from URL (assuming it is in the format https://youtube.com/watch?v=VIDEO_ID)
|
120 |
+
video_id = url.split("v=")[1]
|
121 |
+
transcript_url = (
|
122 |
+
f"https://www.youtube.com/api/timedtext?v={video_id}" # Getting transcript data
|
123 |
+
)
|
124 |
+
|
125 |
+
response = requests.get(transcript_url, timeout=200)
|
126 |
+
|
127 |
+
transcript = response.text # This is the transcript (XML or SRT format)
|
128 |
+
|
129 |
+
# Prepare the content (just the transcript, no question needed)
|
130 |
+
transcript_content = f"Here is the transcript of the video: {transcript}"
|
131 |
+
|
132 |
+
# Return the transcript content so the main LLM can handle question generation
|
133 |
+
return transcript_content
|
134 |
+
|
135 |
+
|
136 |
+
# YouTube Frames to Images Tool
|
137 |
+
@tool
|
138 |
+
def video_frames_to_images(
|
139 |
+
url: str,
|
140 |
+
sample_interval_seconds: int = 5,
|
141 |
+
) -> List[str]:
|
142 |
+
"""Extracts frames from a video at specified intervals and saves them as images.
|
143 |
+
Args:
|
144 |
+
url (str): the URL to the video.
|
145 |
+
sample_interval_seconds (int): the interval, in seconds, between sampled
|
146 |
+
frames; extracted frames are saved under the hard-coded "./frames" directory.
|
147 |
+
Returns:
|
148 |
+
List[str]: A list of paths to the saved image files.
|
149 |
+
"""
|
150 |
+
folder_name = "./frames"
|
151 |
+
# Create a subdirectory for the frames
|
152 |
+
frames_dir = os.path.join(folder_name, "frames")
|
153 |
+
os.makedirs(frames_dir, exist_ok=True)
|
154 |
+
|
155 |
+
ydl_opts = {
|
156 |
+
"format": "bestvideo[height<=1080]+bestaudio/best[height<=1080]/best",
|
157 |
+
"outtmpl": os.path.join(folder_name, "video.%(ext)s"),
|
158 |
+
"quiet": True,
|
159 |
+
"noplaylist": True,
|
160 |
+
"merge_output_format": "mp4",
|
161 |
+
"force_ipv4": True,
|
162 |
+
}
|
163 |
+
|
164 |
+
info_extracted = []
|
165 |
+
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
166 |
+
info = ydl.extract_info(url, download=True)
|
167 |
+
info_extracted.append(info)
|
168 |
+
video_path = next(
|
169 |
+
(
|
170 |
+
os.path.join(folder_name, f)
|
171 |
+
for f in os.listdir(folder_name)
|
172 |
+
if f.endswith(".mp4")
|
173 |
+
),
|
174 |
+
None,
|
175 |
+
)
|
176 |
+
|
177 |
+
if not video_path:
|
178 |
+
raise RuntimeError("Failed to download video as mp4")
|
179 |
+
|
180 |
+
reader = imageio.get_reader(video_path)
|
181 |
+
# metadata = reader.get_meta_data()
|
182 |
+
fps = 25
|
183 |
+
duration_seconds = 120
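|
184 |
+
# fps and duration are hard-coded instead of read from imageio metadata (see the commented-out handling below).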
|
184 |
+
|
185 |
+
frame_interval = int(fps * sample_interval_seconds)
|
186 |
+
num_frames = int(fps * duration_seconds)
|
187 |
+
# if num_frames is None or math.isinf(num_frames):
|
188 |
+
# num_frames = int(fps * duration_seconds)
|
189 |
+
# Handle case where the number of frames is infinite or invalid
|
190 |
+
# if num_frames == float("inf") or not isinstance(num_frames, int):
|
191 |
+
# reader.close()
|
192 |
+
# raise RuntimeError("Invalid video length (infinite or not an integer)")
|
193 |
+
|
194 |
+
image_paths: List[str] = []
|
195 |
+
|
196 |
+
for idx in range(num_frames):
|
197 |
+
if idx % frame_interval == 0:
|
198 |
+
# Save frame as image
|
199 |
+
frame = reader.get_data(idx)
|
200 |
+
image_path = os.path.join(frames_dir, f"frame_{idx:06d}.jpg")
|
201 |
+
imageio.imwrite(image_path, frame)
|
202 |
+
image_paths.append(image_path)
|
203 |
+
|
204 |
+
reader.close()
|
205 |
+
return image_paths
|
206 |
+
|
207 |
+
|
208 |
+
# File Reading Tool
|
209 |
+
@tool
|
210 |
+
def read_file(filepath: str) -> str:
|
211 |
+
"""Reads the content of a PYTHON file.
|
212 |
+
Args:
|
213 |
+
filepath (str): the path to the file to read.
|
214 |
+
Returns:
|
215 |
+
str: The content of the file.
|
216 |
+
"""
|
217 |
+
try:
|
218 |
+
with open(filepath, "r", encoding="utf-8") as file:
|
219 |
+
content = file.read()
|
220 |
+
# Calculate metadata for the prompt
|
221 |
+
filename = os.path.basename(filepath)
|
222 |
+
line_count = content.count("\n") + 1
|
223 |
+
code_str = content.strip()
|
224 |
+
# Compose the prompt
|
225 |
+
prompt = f"""
|
226 |
+
You are a Python expert and code reviewer. Analyze the following Python script and answer the question provided.
|
227 |
+
Give the Final Answer: the output the code produces when run.
|
228 |
+
Script Length: {line_count} lines
|
229 |
+
Filename: {filename}
|
230 |
+
|
231 |
+
Python Code:
|
232 |
+
```python
|
233 |
+
{code_str}
|
234 |
+
```
|
235 |
+
"""
|
236 |
+
|
237 |
+
model = model_text
|
238 |
+
|
239 |
+
# Call the model
|
240 |
+
message = HumanMessage(content=prompt)
|
241 |
+
response = model.invoke([message])
|
242 |
+
torch.cuda.empty_cache()
|
243 |
+
gc.collect()
|
244 |
+
# Return the result
|
245 |
+
if hasattr(response, "content") and isinstance(response.content, str):
|
246 |
+
return response.content
|
247 |
+
return str(response)
|
248 |
+
|
249 |
+
except FileNotFoundError:
|
250 |
+
return f"File not found: {filepath}"
|
251 |
+
except IOError as e:
|
252 |
+
return f"Error reading file: {str(e)}"
|
253 |
+
|
254 |
+
|
255 |
+
# To run python code
|
256 |
+
|
257 |
+
|
258 |
+
def execute_code(code: str):
|
259 |
+
"""Helper function to execute the code in a separate process."""
|
260 |
+
try:
|
261 |
+
exec(code)
|
262 |
+
except Exception as e:
|
263 |
+
raise RuntimeError(f"Error executing the code: {str(e)}") from e
|
264 |
+
|
265 |
+
|
266 |
+
@tool
|
267 |
+
def run_code_from_file(file_path: str, timeout: int = 10):
|
268 |
+
"""
|
269 |
+
Reads a Python file and executes it, with timeout handling.
|
270 |
+
|
271 |
+
Args:
|
272 |
+
file_path (str): The full path to the Python file to execute.
|
273 |
+
timeout (int): The timeout in seconds before forcefully stopping the execution.
|
274 |
+
"""
|
275 |
+
# Check if the file exists
|
276 |
+
if not os.path.exists(file_path):
|
277 |
+
raise FileNotFoundError(f"The file {file_path} does not exist.")
|
278 |
+
|
279 |
+
# Read the file and get the code to execute
|
280 |
+
with open(file_path, "r", encoding="utf-8") as file:
|
281 |
+
code = file.read()
|
282 |
+
|
283 |
+
# Start a process to execute the code
|
284 |
+
process = multiprocessing.Process(target=execute_code, args=(code,))
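|
285 |
+
# Running the code in a separate process lets it be terminated if it exceeds the timeout.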
|
285 |
+
process.start()
|
286 |
+
|
287 |
+
# Wait for the process to finish or timeout
|
288 |
+
process.join(timeout)
|
289 |
+
|
290 |
+
# If the process is still alive after the timeout, terminate it
|
291 |
+
if process.is_alive():
|
292 |
+
process.terminate() # Stop the execution
|
293 |
+
raise TimeoutError(
|
294 |
+
f"The code execution took longer than {timeout} seconds and was terminated."
|
295 |
+
)
|
296 |
+
|
297 |
+
|
298 |
+
# File Download Tool
|
299 |
+
@tool
|
300 |
+
def download_file_from_url(url: str, directory: str) -> Dict[str, Union[str, None]]:
|
301 |
+
"""Downloads a file from a URL and saves it to a directory.
|
302 |
+
Args:
|
303 |
+
url (str): the URL to download the file from.
|
304 |
+
directory (str): the directory to save the file to.
|
305 |
+
Returns:
|
306 |
+
Dict[str, Union[str, None]]: A dictionary containing the file type and path.
|
307 |
+
"""
|
308 |
+
|
309 |
+
response = requests.get(url, stream=True, timeout=10)
|
310 |
+
response.raise_for_status()
|
311 |
+
|
312 |
+
content_type = response.headers.get("content-type", "").lower()
|
313 |
+
|
314 |
+
# Try to get filename from headers
|
315 |
+
filename = None
|
316 |
+
cd = response.headers.get("content-disposition", "")
|
317 |
+
match = re.search(r"filename\*=UTF-8\'\'(.+)", cd) or re.search(
|
318 |
+
r'filename="?([^"]+)"?', cd
|
319 |
+
)
|
320 |
+
if match:
|
321 |
+
filename = match.group(1)
|
322 |
+
|
323 |
+
# If not in headers, try URL
|
324 |
+
if not filename:
|
325 |
+
filename = os.path.basename(url.split("?")[0])
|
326 |
+
|
327 |
+
# Fallback to generated filename
|
328 |
+
if not filename:
|
329 |
+
extension = {
|
330 |
+
"image/jpeg": ".jpg",
|
331 |
+
"image/png": ".png",
|
332 |
+
"image/gif": ".gif",
|
333 |
+
"audio/wav": ".wav",
|
334 |
+
"audio/mpeg": ".mp3",
|
335 |
+
"video/mp4": ".mp4",
|
336 |
+
"text/plain": ".txt",
|
337 |
+
"text/csv": ".csv",
|
338 |
+
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx",
|
339 |
+
"application/vnd.ms-excel": ".xls",
|
340 |
+
"application/octet-stream": ".bin",
|
341 |
+
}.get(content_type, ".bin")
|
342 |
+
filename = f"downloaded_file{extension}"
|
343 |
+
|
344 |
+
os.makedirs(directory, exist_ok=True)
|
345 |
+
file_path = os.path.join(directory, filename)
|
346 |
+
print(file_path)
|
347 |
+
|
348 |
+
with open(file_path, "wb") as f:
|
349 |
+
for chunk in response.iter_content(chunk_size=8192):
|
350 |
+
f.write(chunk)
|
351 |
+
|
352 |
+
# shutil.copy(file_path, os.getcwd())
|
353 |
+
|
354 |
+
return {
|
355 |
+
"type": content_type,
|
356 |
+
"filename": filename,
|
357 |
+
"path": file_path,
|
358 |
+
}
|
359 |
+
|
360 |
+
|
361 |
+
# Text Extraction from Image Tool
|
362 |
+
@tool
|
363 |
+
def extract_text_from_image(image_path: str) -> str:
|
364 |
+
"""Extracts text from an image using OCR.
|
365 |
+
Args:
|
366 |
+
image_path (str): the path to the image to extract text from.
|
367 |
+
Returns:
|
368 |
+
str: The text extracted from the image.
|
369 |
+
"""
|
370 |
+
|
371 |
+
image = Image.open(image_path)
|
372 |
+
text = pytesseract.image_to_string(image)
|
373 |
+
return f"Extracted text from image:\n\n{text}"
|
374 |
+
|
375 |
+
|
376 |
+
# CSV Analysis Tool
|
377 |
+
@tool
|
378 |
+
def analyze_csv_file(file_path: str, query: str) -> str:
|
379 |
+
"""Analyzes a CSV file and answers questions about its contents using an
|
380 |
+
Ollama model.
|
381 |
+
|
382 |
+
Args:
|
383 |
+
file_path (str): The path to the CSV file to analyze.
|
384 |
+
query (str): The question to answer about the CSV file.
|
385 |
+
|
386 |
+
Returns:
|
387 |
+
str: The result of the analysis.
|
388 |
+
"""
|
389 |
+
# Load the CSV file
|
390 |
+
df = pd.read_csv(file_path)
|
391 |
+
df_str = df.to_string(index=False)
|
392 |
+
|
393 |
+
# Compose the prompt
|
394 |
+
prompt = f"""
|
395 |
+
You are a data analyst. Analyze the following CSV data and answer the question provided.
|
396 |
+
|
397 |
+
CSV Dimensions: {df.shape[0]} rows × {df.shape[1]} columns
|
398 |
+
|
399 |
+
CSV Data:
|
400 |
+
{df_str}
|
401 |
+
|
402 |
+
Please provide:
|
403 |
+
1. A summary of the data structure and content
|
404 |
+
2. Key patterns and insights
|
405 |
+
3. Potential data quality issues
|
406 |
+
4. Suggestions for analysis
|
407 |
+
|
408 |
+
User Query:
|
409 |
+
{query}
|
410 |
+
|
411 |
+
Format your response in markdown with sections and bullet points.
|
412 |
+
"""
|
413 |
+
|
414 |
+
model = model_text
|
415 |
+
|
416 |
+
# Call the model
|
417 |
+
response = model.invoke([{"type": "text", "text": prompt}])
|
418 |
+
del df
|
419 |
+
torch.cuda.empty_cache()
|
420 |
+
gc.collect()
|
421 |
+
|
422 |
+
# Return the result
|
423 |
+
if hasattr(response, "content") and isinstance(response.content, str):
|
424 |
+
return response.content
|
425 |
+
return str(response)
|
426 |
+
|
427 |
+
|
428 |
+
# Excel Analysis Tool
|
429 |
+
@tool
|
430 |
+
def analyze_excel_file(file_path: str) -> str:
|
431 |
+
"""Analyzes an Excel file and answers questions about its contents using an
|
432 |
+
Ollama model
|
433 |
+
Args:
|
434 |
+
file_path (str): the path to the Excel file to analyze.
|
435 |
+
query (str): the question to answer about the Excel file.
|
436 |
+
Returns:
|
437 |
+
str: The result of the analysis.
|
438 |
+
"""
|
439 |
+
llm = model_text
|
440 |
+
print(file_path)
|
441 |
+
|
442 |
+
# Read all sheets from the Excel file
|
443 |
+
excel_file = pd.ExcelFile(file_path)
|
444 |
+
sheet_names = excel_file.sheet_names
|
445 |
+
|
446 |
+
result = f"Excel file loaded with {len(sheet_names)} sheets: {', '.join(sheet_names)}\n\n"
|
447 |
+
|
448 |
+
for sheet_name in sheet_names:
|
449 |
+
df = pd.read_excel(file_path, sheet_name=sheet_name)
|
450 |
+
df_str = df.to_string()
|
451 |
+
|
452 |
+
# Build the prompt
|
453 |
+
prompt = f"""Analyze the following Excel sheet data and answer the user's query.
|
454 |
+
Sheet Name: {sheet_name}
|
455 |
+
Dimensions: {len(df)} rows × {len(df.columns)} columns
|
456 |
+
|
457 |
+
Data:
|
458 |
+
{df_str}
|
459 |
+
|
460 |
+
Please provide:
|
461 |
+
1. A summary of the data structure and content
|
462 |
+
2. List all the values of the columns in a proper table format.
|
463 |
+
3. If a file contains food items, assume it refers to the
|
464 |
+
monetary value of the items, not the quantity sold.
|
465 |
+
4. If the file contains food items, make a new list containing
|
466 |
+
only the names of the food items in that column (not including drinks).
|
467 |
+
5. If the file contains any kind of monetary value, it is in USD with two decimal places.
|
468 |
+
|
469 |
+
Format the response clearly using headings and bullet points."""
|
470 |
+
|
471 |
+
# Call the LLM with the prompt
|
472 |
+
response = llm.invoke([HumanMessage(content=prompt)])
|
473 |
+
|
474 |
+
result += f"=== Sheet: {sheet_name} ===\n"
|
475 |
+
result += str(response.content) + "\n"
|
476 |
+
result += "=" * 50 + "\n\n"
|
477 |
+
del df
|
478 |
+
gc.collect()
|
479 |
+
|
480 |
+
excel_file.close()
|
481 |
+
torch.cuda.empty_cache()
|
482 |
+
|
483 |
+
return result
|
484 |
+
|
485 |
+
|
486 |
+
# Audio Transcription Tool
|
487 |
+
def transcribe_audio(audio_file_path: str) -> str:
|
488 |
+
"""Transcribes an audio file using Whisper's audio capabilities.
|
489 |
+
Always give the final answer in the exact format the question asks for, for example listing all the pages mentioned in increasing order on one line.
|
490 |
+
Change "vanilla extract" to "pure vanilla extract" in the final answer.
|
491 |
+
Args:
|
492 |
+
audio_file_path (str): The path to the audio file to transcribe; it must be
|
493 |
+
in a format supported by Whisper (e.g. wav or mp3).
|
494 |
+
Returns:
|
495 |
+
str: The transcript of the audio file.
|
496 |
+
Raises:
|
497 |
+
AssertionError: If the transcription result is not a string.
|
498 |
+
"""
|
499 |
+
|
500 |
+
model = whisper.load_model("base")
|
501 |
+
result = model.transcribe(audio_file_path)
|
502 |
+
assert isinstance(result["text"], str)
|
503 |
+
|
504 |
+
del model
|
505 |
+
torch.cuda.empty_cache()
|
506 |
+
gc.collect()
|
507 |
+
return result["text"]
|
508 |
+
|
509 |
+
|
510 |
+
def _extract_video_id(url: str) -> Optional[str]:
|
511 |
+
"""Extract video ID from YouTube URL.
|
512 |
+
Args:
|
513 |
+
url (str): the URL to the YouTube video.
|
514 |
+
Returns:
|
515 |
+
str: The video ID of the YouTube video.
|
516 |
+
"""
|
517 |
+
patterns = [
|
518 |
+
r"(?:youtube\.com\/watch\?v=|youtube\.com\/embed\/|youtu\.be\/)([^&\n?#]+)",
|
519 |
+
r"(?:youtube\.com\/v\/|youtube\.com\/e\/|youtube\.com\/user\/[^\/]+\/|youtube\.com\/[^\/]+\/|youtube\.com\/embed\/|youtu\.be\/)([^&\n?#]+)",
|
520 |
+
]
|
521 |
+
|
522 |
+
for pattern in patterns:
|
523 |
+
match = re.search(pattern, url)
|
524 |
+
if match:
|
525 |
+
return match.group(1)
|
526 |
+
return None
|
527 |
+
|
528 |
+
|
529 |
+
@tool
|
530 |
+
def transcribe_youtube(url: str) -> str:
|
531 |
+
"""
|
532 |
+
Transcribes a YouTube video using YouTube Transcript API or ChatOllama with Whisper as fallback.
|
533 |
+
|
534 |
+
This function first tries to fetch the transcript of a YouTube video using the YouTube Transcript API.
|
535 |
+
If the transcript is unavailable (e.g., due to captions being disabled), it falls back to using
|
536 |
+
ChatOllama integrated with Whisper to transcribe the audio.
|
537 |
+
|
538 |
+
Args:
|
539 |
+
url (str): The URL to the YouTube video.
|
540 |
+
|
541 |
+
Returns:
|
542 |
+
str: The transcript of the YouTube video, or an error message if transcription fails.
|
543 |
+
"""
|
544 |
+
|
545 |
+
try:
|
546 |
+
# Try using YouTube Transcript API
|
547 |
+
video_id = _extract_video_id(url)
|
548 |
+
transcript = ""
|
549 |
+
transcript_chunks = YouTubeTranscriptApi.get_transcript(
|
550 |
+
video_id, languages=["en"]
|
551 |
+
)
|
552 |
+
for chunk in transcript_chunks:
|
553 |
+
timestamp = str(timedelta(seconds=int(chunk["start"])))
|
554 |
+
transcript += f"[{timestamp}] {chunk['text']}\n"
|
555 |
+
|
556 |
+
# Return API transcript if available
|
557 |
+
if transcript.strip():
|
558 |
+
return transcript
|
559 |
+
|
560 |
+
except (TranscriptsDisabled, NoTranscriptFound, Exception) as err:  # any failure falls back to Whisper below
|
561 |
+
try:
|
562 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
563 |
+
# Download audio from YouTube
|
564 |
+
ydl_opts = {
|
565 |
+
"format": "bestaudio/best",
|
566 |
+
"outtmpl": os.path.join(tmpdir, "audio.%(ext)s"),
|
567 |
+
"quiet": True,
|
568 |
+
"noplaylist": True,
|
569 |
+
"postprocessors": [
|
570 |
+
{
|
571 |
+
"key": "FFmpegExtractAudio",
|
572 |
+
"preferredcodec": "wav",
|
573 |
+
"preferredquality": "192",
|
574 |
+
}
|
575 |
+
],
|
576 |
+
}
|
577 |
+
|
578 |
+
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
579 |
+
info = ydl.extract_info(url, download=True)
|
580 |
+
|
581 |
+
if info is not None:
|
582 |
+
title = info.get("title", "Unknown Title") # Type:None
|
583 |
+
duration = info.get("duration", 0) # in seconds
|
584 |
+
uploader = info.get("uploader", "Unknown Uploader")
|
585 |
+
else:
|
586 |
+
title = "Unknown Title"
|
587 |
+
duration = 0
|
588 |
+
uploader = "Unknown Uploader"
|
589 |
+
|
590 |
+
audio_path = next(
|
591 |
+
(
|
592 |
+
os.path.join(tmpdir, f)
|
593 |
+
for f in os.listdir(tmpdir)
|
594 |
+
if f.endswith(".wav")
|
595 |
+
),
|
596 |
+
None,
|
597 |
+
)
|
598 |
+
if not audio_path:
|
599 |
+
raise RuntimeError("Failed to download or convert audio") from err
|
600 |
+
|
601 |
+
# Use Whisper for initial transcription
|
602 |
+
whisper_model = whisper.load_model("base")
|
603 |
+
transcription = whisper_model.transcribe(audio_path, verbose=False)
|
604 |
+
raw_transcript = transcription["text"]
|
605 |
+
del whisper_model
|
606 |
+
gc.collect()
|
607 |
+
torch.cuda.empty_cache()
|
608 |
+
result = f"Title: {title}\nUploader: {uploader}\nDuration: {duration} seconds\nTranscript: {raw_transcript}"
|
609 |
+
return result
|
610 |
+
except Exception as fallback_exc:
|
611 |
+
raise RuntimeError("Fallback Transcription failed") from fallback_exc
|
612 |
+
return "Transcription failed unexpectedly."
|
613 |
+
|
614 |
+
|
615 |
+
@tool
|
616 |
+
def website_scrape(url: str) -> str:
|
617 |
+
"""scrapes a website and returns the text.
|
618 |
+
args:
|
619 |
+
url (str): the url to the website to scrape.
|
620 |
+
returns:
|
621 |
+
str: the text of the website.
|
622 |
+
"""
|
623 |
+
try:
|
624 |
+
parsed_url = urlparse(url)
|
625 |
+
if not parsed_url.scheme or not parsed_url.netloc:
|
626 |
+
raise ValueError(
|
627 |
+
f"Invalid URL: '{url}'. Call `duckduckgo_search` first to get a valid URL."
|
628 |
+
)
|
629 |
+
with sync_playwright() as p:
|
630 |
+
browser = p.chromium.launch(headless=True)
|
631 |
+
page = browser.new_page()
|
632 |
+
page.goto(url, wait_until="networkidle", timeout=60000)
|
633 |
+
page.wait_for_load_state("domcontentloaded")
|
634 |
+
html_content = page.content()
|
635 |
+
browser.close()
|
636 |
+
|
637 |
+
soup = BeautifulSoup(html_content, "html.parser")
|
638 |
+
|
639 |
+
relevant_text = ""
|
640 |
+
# for header in soup.find_all(["h2", "h3"]):
|
641 |
+
# heading_text = header.get_text().strip().lower()
|
642 |
+
# if "discography" in heading_text or "studio albums" in heading_text:
|
643 |
+
# section_texts = []
|
644 |
+
# tag = header.find_next_sibling()
|
645 |
+
# while tag and (
|
646 |
+
# not isinstance(tag, Tag) or tag.name not in ["h2", "h3"]
|
647 |
+
# ):
|
648 |
+
# section_texts.append(tag.get_text(separator=" ", strip=True))
|
649 |
+
# tag = tag.find_next_sibling()
|
650 |
+
# relevant_text = "\n\n".join(section_texts)
|
651 |
+
# break
|
652 |
+
# if not relevant_text:
|
653 |
+
# article = soup.find("article")
|
654 |
+
# if article:
|
655 |
+
# relevant_text = article.get_text(separator=" ", strip=True)
|
656 |
+
# if not relevant_text:
|
657 |
+
relevant_text = soup.get_text(separator=" ", strip=True)
|
658 |
+
|
659 |
+
# step 2: chunk the text (optional but recommended)
|
660 |
+
def chunk_text(text, max_length=1000):
|
661 |
+
words = text.split()
|
662 |
+
chunks = []
|
663 |
+
for i in range(0, len(words), max_length):
|
664 |
+
chunks.append(" ".join(words[i : i + max_length]))
|
665 |
+
return chunks
|
666 |
+
|
667 |
+
chunks = chunk_text(relevant_text)
|
668 |
+
|
669 |
+
# return only the first 5 chunks to keep it concise
|
670 |
+
return "\n\n".join(chunks[:5])
|
671 |
+
except ValueError as e:
|
672 |
+
# Catch URL validation errors
|
673 |
+
return str(e)
|
674 |
+
except Exception as e:
|
675 |
+
# Catch other unexpected errors
|
676 |
+
return f"Scraping failed: {str(e)}"
|
677 |
+
|
678 |
+
|
679 |
+
class SearchResult(TypedDict):
|
680 |
+
query: str
|
681 |
+
status: str
|
682 |
+
attempt: int
|
683 |
+
results: Optional[List[dict]]
|
684 |
+
error: Optional[str]
|
685 |
+
|
686 |
+
|
687 |
+
@tool
|
688 |
+
def duckduckgo_search(query: str, max_results: int = 10) -> SearchResult:
|
689 |
+
"""
|
690 |
+
Perform a DuckDuckGo search with retry and backoff.
|
691 |
+
Use this FIRST, before invoking any scraping tools.
|
692 |
+
Args:
|
693 |
+
query: The search query string.
|
694 |
+
max_results: Max number of results to return (default 10).
|
695 |
+
Returns:
|
696 |
+
A dict with the query, results, status, attempt count, and any error.
|
697 |
+
"""
|
698 |
+
max_retries = 3
|
699 |
+
base_delay = 2
|
700 |
+
backoff_factor = 2
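|
701 |
+
# Retry delay grows exponentially: base_delay * backoff_factor ** attempt (2s, 4s, 8s).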
|
701 |
+
|
702 |
+
for attempt in range(max_retries):
|
703 |
+
try:
|
704 |
+
with DDGS() as ddgs:
|
705 |
+
results = ddgs.text(keywords=query, max_results=max_results)
|
706 |
+
if results:
|
707 |
+
formatted_results = [
|
708 |
+
{
|
709 |
+
"title": result.get("title", ""),
|
710 |
+
"url": result.get("href", ""),
|
711 |
+
"body": result.get("body", ""),
|
712 |
+
}
|
713 |
+
for result in results
|
714 |
+
]
|
715 |
+
return {
|
716 |
+
"query": query,
|
717 |
+
"status": "success",
|
718 |
+
"attempt": attempt + 1,
|
719 |
+
"results": formatted_results,
|
720 |
+
"error": None,
|
721 |
+
}
|
722 |
+
except Exception as e:
|
723 |
+
print(f"[DuckDuckGo Tool] Attempt {attempt + 1} failed: {e}")
|
724 |
+
time.sleep(base_delay * (backoff_factor**attempt))
|
725 |
+
|
726 |
+
return {
|
727 |
+
"query": query,
|
728 |
+
"status": "failed",
|
729 |
+
"attempt": max_retries,
|
730 |
+
"results": None,
|
731 |
+
"error": "Max retries exceeded or request failed.",
|
732 |
+
}
|
733 |
+
|
734 |
+
|
735 |
+
@tool
|
736 |
+
def reverse_decoder(question: str) -> str:
|
737 |
+
"""Decodes a reversed sentence if the input appears to be written backward.
|
738 |
+
|
739 |
+
Args:
|
740 |
+
question (str): The possibly reversed question string.
|
741 |
+
|
742 |
+
Returns:
|
743 |
+
str: The decoded sentence.
|
744 |
+
"""
|
745 |
+
# Remove leading punctuation if present
|
746 |
+
cleaned = question.strip().strip(".!?")
|
747 |
+
|
748 |
+
# Reverse the string; the caller decides whether the input actually looks reversed.
|
749 |
+
reversed_text = cleaned[::-1]
|
750 |
+
|
751 |
+
return reversed_text
|