huytofu92 commited on
Commit
29cd08b
·
1 Parent(s): b120423

Fix tools desc

Browse files
Files changed (2) hide show
  1. mini_agents.py +2 -2
  2. vlm_tools.py +17 -16
mini_agents.py CHANGED
@@ -1,7 +1,7 @@
1
  from smolagents import CodeAgent, InferenceClientModel
2
  from tools import sort_list, operate_two_numbers, convert_number, load_dataframe_from_csv
3
  from tools import to_dataframe, to_json, get_dataframe_data, get_dataframe_column, get_dataframe_row, get_dataframe_groupby
4
- from vlm_tools import download_image, image_processing, object_detection_tool, ocr_scan, extract_frames_from_video
5
  from audio_tools import audio_to_base64, noise_reduction, audio_segmentation, speaker_diarization
6
  from community_tools import community_tools
7
  import os
@@ -40,7 +40,7 @@ vlm_model = InferenceClientModel(
40
 
41
  vlm_agent = CodeAgent(
42
  model=vlm_model,
43
- tools=[download_image, image_processing, object_detection_tool, ocr_scan, extract_frames_from_video],
44
  max_steps=4,
45
  name="vlm_agent",
46
  description="This agent is responsible for downloading images, processing images, detecting objects in them and extracting text from them."
 
1
  from smolagents import CodeAgent, InferenceClientModel
2
  from tools import sort_list, operate_two_numbers, convert_number, load_dataframe_from_csv
3
  from tools import to_dataframe, to_json, get_dataframe_data, get_dataframe_column, get_dataframe_row, get_dataframe_groupby
4
+ from vlm_tools import download_image, image_processing, object_detection_tool, ocr_scan_tool, extract_frames_from_video
5
  from audio_tools import audio_to_base64, noise_reduction, audio_segmentation, speaker_diarization
6
  from community_tools import community_tools
7
  import os
 
40
 
41
  vlm_agent = CodeAgent(
42
  model=vlm_model,
43
+ tools=[download_image, image_processing, object_detection_tool, ocr_scan_tool, extract_frames_from_video],
44
  max_steps=4,
45
  name="vlm_agent",
46
  description="This agent is responsible for downloading images, processing images, detecting objects in them and extracting text from them."
vlm_tools.py CHANGED
@@ -145,7 +145,7 @@ class ObjectDetectionTool(Tool):
145
  self.names_path = names_path
146
  self.onnx_model = onnxruntime.InferenceSession(self.onnx_path)
147
 
148
- def forward(self, frames: List[str])->List[List[str]]:
149
  # Load class labels
150
  with open(self.names_path, 'r') as f:
151
  classes = [line.strip() for line in f.readlines()]
@@ -163,22 +163,23 @@ class ObjectDetectionTool(Tool):
163
 
164
  return detected_objects
165
 
166
- @tool
167
- def ocr_scan(frames: List[str])->List[List[str]]:
168
- """
169
- Scan an image for text
170
- Args:
171
- frames: The list of frames (images) to scan for text
172
- Returns:
173
- The list of text in the images
174
- """
175
- scanned_text = []
176
- for frame in frames:
177
- image_data = base64.b64decode(frame)
178
- img = Image.open(BytesIO(image_data))
179
- scanned_text.append(pytesseract.image_to_string(img))
180
- return scanned_text
181
 
 
182
  object_detection_tool = ObjectDetectionTool()
183
 
184
 
 
145
  self.names_path = names_path
146
  self.onnx_model = onnxruntime.InferenceSession(self.onnx_path)
147
 
148
+ def forward(self, frames: any)->any:
149
  # Load class labels
150
  with open(self.names_path, 'r') as f:
151
  classes = [line.strip() for line in f.readlines()]
 
163
 
164
  return detected_objects
165
 
166
+ class OCRTool(Tool):
167
+ description = "Scan an image for text. It takes a list of frames (images) as input and returns a list of text in the images."
168
+ name = "ocr_scan"
169
+ inputs = {
170
+ "frames": {"type": "List[str]", "description": "The list of frames (images) to scan for text"}
171
+ }
172
+ output_type = "List[List[str]]"
173
+
174
+ def forward(self, frames: any)->any:
175
+ scanned_text = []
176
+ for frame in frames:
177
+ image_data = base64.b64decode(frame)
178
+ img = Image.open(BytesIO(image_data))
179
+ scanned_text.append(pytesseract.image_to_string(img))
180
+ return scanned_text
181
 
182
+ ocr_scan_tool = OCRTool()
183
  object_detection_tool = ObjectDetectionTool()
184
 
185