Batnini committed on
Commit
dae0170
·
verified ·
1 Parent(s): 07a1c4f

Update tools/tool_agent.py

Browse files
Files changed (1) hide show
  1. tools/tool_agent.py +14 -10
tools/tool_agent.py CHANGED
@@ -3,25 +3,29 @@ import json
3
 
4
  class ToolCallingAgent:
5
  def __init__(self):
6
- # Force CPU and smaller model
7
  self.model = pipeline(
8
  "text-generation",
9
- model="cognitivecomputations/dolphin-2.1-mistral-7b", # Smaller than llama3
10
- device=-1, # Force CPU
11
- torch_dtype="float32" # Better for CPU
12
  )
13
 
14
  def generate(self, prompt, tools):
15
  try:
16
  tools_json = json.dumps(tools, ensure_ascii=False)
17
- prompt = f"""Respond with JSON for one tool call. Tools: {tools_json}\nInput: {prompt}"""
18
 
19
  response = self.model(
20
  prompt,
21
- max_new_tokens=150, # Shorter responses
22
- do_sample=False # More deterministic
23
  )
24
 
25
- return json.loads(response[0]['generated_text'])
26
- except:
27
- return {"tool_name": "error", "parameters": {"message": "Failed to process request"}}
 
 
 
 
 
3
 
4
  class ToolCallingAgent:
5
  def __init__(self):
6
+ # Small CPU-friendly model
7
  self.model = pipeline(
8
  "text-generation",
9
+ model="gpt2", # Replace with small model you want
10
+ device=-1,
11
+ torch_dtype="float32"
12
  )
13
 
14
  def generate(self, prompt, tools):
15
  try:
16
  tools_json = json.dumps(tools, ensure_ascii=False)
17
+ prompt = f"""Respond ONLY with JSON for one tool call from the following list: {tools_json}\nUser input: {prompt}"""
18
 
19
  response = self.model(
20
  prompt,
21
+ max_new_tokens=100,
22
+ do_sample=False
23
  )
24
 
25
+ # Try to find JSON in output
26
+ text = response[0]['generated_text']
27
+ json_start = text.find("{")
28
+ json_end = text.rfind("}") + 1
29
+ return json.loads(text[json_start:json_end])
30
+ except Exception as e:
31
+ return {"tool_name": "error", "parameters": {"message": f"Failed to process request: {str(e)}"}}