CoralLeiCN
commited on
Commit
·
0b4ed3e
1
Parent(s):
d564629
Add BaseballQATool: implement tool for answering baseball player queries
Browse files- agent/agents.py +3 -0
- agent/tools.py +45 -0
agent/agents.py
CHANGED
@@ -16,6 +16,7 @@ from agent.tools import (
|
|
16 |
TranscribeYoutubeVideo,
|
17 |
UnderstandImageBytes,
|
18 |
WikipediaSearchTool,
|
|
|
19 |
)
|
20 |
from agent.utils import gemini_client, gemini_model_liteLLM
|
21 |
|
@@ -63,6 +64,7 @@ class BasicAgent:
|
|
63 |
code_execution_tool = CodeExecutionTool()
|
64 |
wiki_retriever = WikipediaSearchTool()
|
65 |
chess_best_move = ChessBestMove()
|
|
|
66 |
model = gemini_model_liteLLM(self.model)
|
67 |
|
68 |
if if_sleep:
|
@@ -81,6 +83,7 @@ class BasicAgent:
|
|
81 |
code_execution_tool,
|
82 |
wiki_retriever,
|
83 |
chess_best_move,
|
|
|
84 |
],
|
85 |
model=model,
|
86 |
additional_authorized_imports=["pandas"],
|
|
|
16 |
TranscribeYoutubeVideo,
|
17 |
UnderstandImageBytes,
|
18 |
WikipediaSearchTool,
|
19 |
+
BaseballQATool,
|
20 |
)
|
21 |
from agent.utils import gemini_client, gemini_model_liteLLM
|
22 |
|
|
|
64 |
code_execution_tool = CodeExecutionTool()
|
65 |
wiki_retriever = WikipediaSearchTool()
|
66 |
chess_best_move = ChessBestMove()
|
67 |
+
baseball_qa_tool = BaseballQATool()
|
68 |
model = gemini_model_liteLLM(self.model)
|
69 |
|
70 |
if if_sleep:
|
|
|
83 |
code_execution_tool,
|
84 |
wiki_retriever,
|
85 |
chess_best_move,
|
86 |
+
baseball_qa_tool,
|
87 |
],
|
88 |
model=model,
|
89 |
additional_authorized_imports=["pandas"],
|
agent/tools.py
CHANGED
@@ -9,6 +9,51 @@ from PIL import Image
|
|
9 |
from smolagents import Tool
|
10 |
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
class ChessBestMove(Tool):
|
13 |
name = "chess_best_move"
|
14 |
description = """Find the best move for a chess position.
|
|
|
9 |
from smolagents import Tool
|
10 |
|
11 |
|
12 |
+
class BaseballQATool(Tool):
|
13 |
+
name = "baseball_qa"
|
14 |
+
description = (
|
15 |
+
"""This tool can answer questions about baseball players information."""
|
16 |
+
)
|
17 |
+
inputs = {
|
18 |
+
"player_name": {
|
19 |
+
"type": "string",
|
20 |
+
"description": "The name of the baseball player.",
|
21 |
+
},
|
22 |
+
"question": {
|
23 |
+
"type": "string",
|
24 |
+
"description": "The question to ask about the player.",
|
25 |
+
},
|
26 |
+
"team_name": {
|
27 |
+
"type": "string",
|
28 |
+
"description": "The name of the team the player is associated with.",
|
29 |
+
"nullable": True,
|
30 |
+
},
|
31 |
+
}
|
32 |
+
output_type = "string"
|
33 |
+
|
34 |
+
def forward(self, player_name: str, question: str, team_name: str = ""):
|
35 |
+
config = types.GenerateContentConfig(
|
36 |
+
temperature=0,
|
37 |
+
candidate_count=1,
|
38 |
+
response_mime_type="application/json",
|
39 |
+
top_p=0.95,
|
40 |
+
seed=42,
|
41 |
+
)
|
42 |
+
client = genai.Client()
|
43 |
+
response = client.models.generate_content(
|
44 |
+
model="gemini-2.5-pro",
|
45 |
+
contents=types.Content(
|
46 |
+
parts=[
|
47 |
+
types.Part(
|
48 |
+
text=f"Pay attention to the details. Make corrections if needed. Player: {player_name}\nTeam: {team_name}\nQuestion: {question}"
|
49 |
+
)
|
50 |
+
],
|
51 |
+
),
|
52 |
+
config=config,
|
53 |
+
)
|
54 |
+
return response.text
|
55 |
+
|
56 |
+
|
57 |
class ChessBestMove(Tool):
|
58 |
name = "chess_best_move"
|
59 |
description = """Find the best move for a chess position.
|