Batnini commited on
Commit
f9a91b4
·
verified ·
1 Parent(s): bdb3160

Update tools/tool_agent.py

Browse files
Files changed (1) hide show
  1. tools/tool_agent.py +26 -24
tools/tool_agent.py CHANGED
@@ -1,7 +1,18 @@
1
- def generate(self, prompt, tools):
2
- try:
 
 
 
 
 
 
 
 
 
 
 
3
  tools_json = json.dumps(tools, ensure_ascii=False)
4
- prompt = f"""أنت مساعد ذكي. اختر أداة واحدة من القائمة التالية ثم أعد النتيجة كـ JSON فقط:
5
  {tools_json}
6
 
7
  مدخل المستخدم: {prompt}
@@ -9,24 +20,15 @@ def generate(self, prompt, tools):
9
  صيغة الإخراج المطلوبة:
10
  {{"tool_name": "اسم_الأداة", "parameters": {{"param": "القيمة"}}}}"""
11
 
12
- response = self.model(
13
- prompt,
14
- max_new_tokens=150,
15
- do_sample=False
16
- )
17
-
18
- text = response[0]['generated_text']
19
-
20
- # Extract JSON from text
21
- json_start = text.find("{")
22
- json_end = text.rfind("}") + 1
23
- parsed = json.loads(text[json_start:json_end])
24
-
25
- # Fallback if JSON missing keys
26
- if "tool_name" not in parsed:
27
- parsed = {"tool_name": "error", "parameters": {"message": "لم يتم التعرف على الأداة"}}
28
-
29
- return parsed
30
-
31
- except Exception as e:
32
- return {"tool_name": "error", "parameters": {"message": f"فشل المعالجة: {str(e)}"}}
 
1
+ # tools/tool_agent.py
2
+ from transformers import pipeline
3
+ import json
4
+
5
+ class ToolCallingAgent:
6
+ def __init__(self):
7
+ self.model = pipeline(
8
+ "text-generation",
9
+ model="aubmindlab/aragpt2-base", # Arabic-friendly, small
10
+ device=-1
11
+ )
12
+
13
+ def generate(self, prompt, tools):
14
  tools_json = json.dumps(tools, ensure_ascii=False)
15
+ full_prompt = f"""أنت مساعد ذكي. اختر أداة واحدة من القائمة التالية وأعد النتيجة كـ JSON فقط:
16
  {tools_json}
17
 
18
  مدخل المستخدم: {prompt}
 
20
  صيغة الإخراج المطلوبة:
21
  {{"tool_name": "اسم_الأداة", "parameters": {{"param": "القيمة"}}}}"""
22
 
23
+ try:
24
+ response = self.model(
25
+ full_prompt,
26
+ max_new_tokens=150,
27
+ do_sample=False
28
+ )
29
+ text = response[0]['generated_text']
30
+ json_start = text.find("{")
31
+ json_end = text.rfind("}") + 1
32
+ return json.loads(text[json_start:json_end])
33
+ except Exception as e:
34
+ return {"tool_name": "error", "parameters": {"message": str(e)}}