Batnini commited on
Commit
f00bc1d
·
verified ·
1 Parent(s): f9a91b4

Update tools/tool_agent.py

Browse files
Files changed (1) hide show
  1. tools/tool_agent.py +20 -23
tools/tool_agent.py CHANGED
@@ -1,4 +1,3 @@
1
- # tools/tool_agent.py
2
  from transformers import pipeline
3
  import json
4
 
@@ -6,29 +5,27 @@ class ToolCallingAgent:
6
  def __init__(self):
7
  self.model = pipeline(
8
  "text-generation",
9
- model="aubmindlab/aragpt2-base", # Arabic-friendly, small
10
- device=-1
11
  )
12
-
13
  def generate(self, prompt, tools):
 
14
  tools_json = json.dumps(tools, ensure_ascii=False)
15
- full_prompt = f"""أنت مساعد ذكي. اختر أداة واحدة من القائمة التالية وأعد النتيجة كـ JSON فقط:
16
- {tools_json}
17
-
18
- مدخل المستخدم: {prompt}
19
-
20
- صيغة الإخراج المطلوبة:
21
- {{"tool_name": "اسم_الأداة", "parameters": {{"param": "القيمة"}}}}"""
22
-
 
 
 
 
 
23
  try:
24
- response = self.model(
25
- full_prompt,
26
- max_new_tokens=150,
27
- do_sample=False
28
- )
29
- text = response[0]['generated_text']
30
- json_start = text.find("{")
31
- json_end = text.rfind("}") + 1
32
- return json.loads(text[json_start:json_end])
33
- except Exception as e:
34
- return {"tool_name": "error", "parameters": {"message": str(e)}}
 
 
1
  from transformers import pipeline
2
  import json
3
 
 
5
  def __init__(self):
6
  self.model = pipeline(
7
  "text-generation",
8
+ model="cognitivecomputations/dolphin-2.9-llama3-8b",
9
+ device_map="auto"
10
  )
11
+
12
  def generate(self, prompt, tools):
13
+ # Format the tools specification
14
  tools_json = json.dumps(tools, ensure_ascii=False)
15
+
16
+ # Create the tool-calling prompt
17
+ system_msg = f"""You are an AI assistant that can call tools.
18
+ Available tools: {tools_json}
19
+ Respond with JSON containing 'tool_name' and 'parameters'."""
20
+
21
+ # Generate the response
22
+ response = self.model(
23
+ f"<|system|>{system_msg}</s><|user|>{prompt}</s>",
24
+ max_new_tokens=200,
25
+ do_sample=True
26
+ )
27
+
28
  try:
29
+ return json.loads(response[0]['generated_text'])
30
+ except:
31
+ return {"error": "Failed to parse tool call"}