Update tools/tool_agent.py
Browse files- tools/tool_agent.py +14 -10
tools/tool_agent.py
CHANGED
@@ -3,25 +3,29 @@ import json
|
|
3 |
|
4 |
class ToolCallingAgent:
|
5 |
def __init__(self):
|
6 |
-
#
|
7 |
self.model = pipeline(
|
8 |
"text-generation",
|
9 |
-
model="
|
10 |
-
device=-1,
|
11 |
-
torch_dtype="float32"
|
12 |
)
|
13 |
|
14 |
def generate(self, prompt, tools):
|
15 |
try:
|
16 |
tools_json = json.dumps(tools, ensure_ascii=False)
|
17 |
-
prompt = f"""Respond with JSON for one tool call
|
18 |
|
19 |
response = self.model(
|
20 |
prompt,
|
21 |
-
max_new_tokens=
|
22 |
-
do_sample=False
|
23 |
)
|
24 |
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
class ToolCallingAgent:
|
5 |
def __init__(self):
|
6 |
+
# Small CPU-friendly model
|
7 |
self.model = pipeline(
|
8 |
"text-generation",
|
9 |
+
model="gpt2", # Replace with small model you want
|
10 |
+
device=-1,
|
11 |
+
torch_dtype="float32"
|
12 |
)
|
13 |
|
14 |
def generate(self, prompt, tools):
|
15 |
try:
|
16 |
tools_json = json.dumps(tools, ensure_ascii=False)
|
17 |
+
prompt = f"""Respond ONLY with JSON for one tool call from the following list: {tools_json}\nUser input: {prompt}"""
|
18 |
|
19 |
response = self.model(
|
20 |
prompt,
|
21 |
+
max_new_tokens=100,
|
22 |
+
do_sample=False
|
23 |
)
|
24 |
|
25 |
+
# Try to find JSON in output
|
26 |
+
text = response[0]['generated_text']
|
27 |
+
json_start = text.find("{")
|
28 |
+
json_end = text.rfind("}") + 1
|
29 |
+
return json.loads(text[json_start:json_end])
|
30 |
+
except Exception as e:
|
31 |
+
return {"tool_name": "error", "parameters": {"message": f"Failed to process request: {str(e)}"}}
|