Batnini commited on
Commit
f3ad633
·
verified ·
1 Parent(s): 233fac6

Update tools/tool_agent.py

Browse files
Files changed (1) hide show
  1. tools/tool_agent.py +27 -13
tools/tool_agent.py CHANGED
@@ -10,22 +10,36 @@ class ToolCallingAgent:
10
  )
11
 
12
  def generate(self, prompt, tools):
13
- # Format the tools specification
14
  tools_json = json.dumps(tools, ensure_ascii=False)
15
-
16
- # Create the tool-calling prompt
17
  system_msg = f"""You are an AI assistant that can call tools.
18
- Available tools: {tools_json}
19
- Respond with JSON containing 'tool_name' and 'parameters'."""
20
-
21
- # Generate the response
 
 
22
  response = self.model(
23
- f"<|system|>{system_msg}</s><|user|>{prompt}</s>",
24
  max_new_tokens=200,
25
- do_sample=True
26
  )
27
-
 
 
 
 
 
 
 
 
 
 
28
  try:
29
- return json.loads(response[0]['generated_text'])
30
- except:
31
- return {"error": "Failed to parse tool call"}
 
 
 
 
 
 
10
  )
11
 
12
  def generate(self, prompt, tools):
 
13
  tools_json = json.dumps(tools, ensure_ascii=False)
 
 
14
  system_msg = f"""You are an AI assistant that can call tools.
15
+ Available tools: {tools_json}
16
+ Respond ONLY with a valid JSON containing keys 'tool_name' and 'parameters'."""
17
+
18
+ # Construct prompt with system and user tokens (assuming model supports these)
19
+ full_prompt = f"<|system|>{system_msg}</s><|user|>{prompt}</s>"
20
+
21
  response = self.model(
22
+ full_prompt,
23
  max_new_tokens=200,
24
+ do_sample=False # deterministic output for better JSON consistency
25
  )
26
+
27
+ text = response[0]['generated_text']
28
+
29
+ # Extract JSON substring between first '{' and last '}'
30
+ json_start = text.find("{")
31
+ json_end = text.rfind("}") + 1
32
+ if json_start == -1 or json_end == -1:
33
+ return {"error": "No JSON found in model output", "raw_output": text}
34
+
35
+ json_text = text[json_start:json_end]
36
+
37
  try:
38
+ return json.loads(json_text)
39
+ except json.JSONDecodeError as e:
40
+ return {
41
+ "error": "Failed to parse JSON",
42
+ "message": str(e),
43
+ "raw_output": text,
44
+ "extracted_json": json_text
45
+ }