altozachmo committed on
Commit
9ff7774
·
1 Parent(s): 50df397

remove logging change

Browse files
agents/agent.py CHANGED
@@ -8,6 +8,7 @@ from tools.text_search import TextSearch
8
  from tools.text_splitter import text_splitter
9
  from tools.video_analyzer import YouTubeObjectCounterTool
10
 
 
11
  class MyAgent:
12
  def __init__(
13
  self,
@@ -45,12 +46,11 @@ class MyAgent:
45
  DuckDuckGoSearchTool(), # Search tool for web queries
46
  WikipediaSearchTool(), # Search tool for Wikipedia queries
47
  TextSearch(), # Search tool for text queries
48
- text_splitter, # Text splitter tool for breaking down large texts
49
- # into manageable lists.
50
- YouTubeObjectCounterTool(), # Tool for analyzing YouTube videos
51
  ]
52
 
53
-
54
  # Initialize the agent with the specified provider and model ID
55
  if provider == "litellm":
56
  self.agent = CodeAgent(
 
8
  from tools.text_splitter import text_splitter
9
  from tools.video_analyzer import YouTubeObjectCounterTool
10
 
11
+
12
  class MyAgent:
13
  def __init__(
14
  self,
 
46
  DuckDuckGoSearchTool(), # Search tool for web queries
47
  WikipediaSearchTool(), # Search tool for Wikipedia queries
48
  TextSearch(), # Search tool for text queries
49
+ text_splitter, # Text splitter tool for breaking down large texts
50
+ # into manageable lists.
51
+ YouTubeObjectCounterTool(), # Tool for analyzing YouTube videos
52
  ]
53
 
 
54
  # Initialize the agent with the specified provider and model ID
55
  if provider == "litellm":
56
  self.agent = CodeAgent(
app.py CHANGED
@@ -70,7 +70,11 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
70
  results_log = []
71
  answers_payload = []
72
  print(f"Running agent on {len(questions_data)} questions...")
73
- for item in tqdm(questions_data[0:3], desc="Agent is answering questions...", total=len(questions_data)):
 
 
 
 
74
  task_id = item.get("task_id")
75
  question_text = item.get("question")
76
  if not task_id or question_text is None:
@@ -78,7 +82,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
78
  continue
79
  try:
80
  submitted_answer = agent(question_text)
81
- time.sleep(30) # to avoid rate limiting
82
  answers_payload.append(
83
  {"task_id": task_id, "submitted_answer": submitted_answer}
84
  )
 
70
  results_log = []
71
  answers_payload = []
72
  print(f"Running agent on {len(questions_data)} questions...")
73
+ for item in tqdm(
74
+ questions_data[0:3],
75
+ desc="Agent is answering questions...",
76
+ total=len(questions_data),
77
+ ):
78
  task_id = item.get("task_id")
79
  question_text = item.get("question")
80
  if not task_id or question_text is None:
 
82
  continue
83
  try:
84
  submitted_answer = agent(question_text)
85
+ time.sleep(30) # to avoid rate limiting
86
  answers_payload.append(
87
  {"task_id": task_id, "submitted_answer": submitted_answer}
88
  )
run_local_agent.py CHANGED
@@ -4,6 +4,7 @@ from utils import run_agent
4
  import os
5
  import json
6
  from dotenv import load_dotenv
 
7
  load_dotenv()
8
 
9
  QUESTIONS_FILEPATH: str = os.getenv("QUESTIONS_FILEPATH", default="metadata.jsonl")
 
4
  import os
5
  import json
6
  from dotenv import load_dotenv
7
+
8
  load_dotenv()
9
 
10
  QUESTIONS_FILEPATH: str = os.getenv("QUESTIONS_FILEPATH", default="metadata.jsonl")
test.py CHANGED
@@ -1,5 +1,6 @@
1
  from smolagents import LiteLLMModel, OpenAIServerModel
2
  from dotenv import load_dotenv
 
3
  load_dotenv()
4
 
5
  model_id = "ollama_chat/mistral-small3.1:latest"
 
1
  from smolagents import LiteLLMModel, OpenAIServerModel
2
  from dotenv import load_dotenv
3
+
4
  load_dotenv()
5
 
6
  model_id = "ollama_chat/mistral-small3.1:latest"
tools/text_search.py CHANGED
@@ -1,5 +1,6 @@
1
  from smolagents import Tool
2
 
 
3
  class TextSearch(Tool):
4
  name: str = "text_search_tool"
5
  description: str = "This tool searches through a string for substrings and returns the indices of all occurances of that substring."
@@ -11,7 +12,7 @@ class TextSearch(Tool):
11
  "search_text": {
12
  "type": "string",
13
  "description": "The text to search for within source_text.",
14
- }
15
  }
16
  output_type: str = "array"
17
 
 
1
  from smolagents import Tool
2
 
3
+
4
  class TextSearch(Tool):
5
  name: str = "text_search_tool"
6
  description: str = "This tool searches through a string for substrings and returns the indices of all occurances of that substring."
 
12
  "search_text": {
13
  "type": "string",
14
  "description": "The text to search for within source_text.",
15
+ },
16
  }
17
  output_type: str = "array"
18
 
tools/text_splitter.py CHANGED
@@ -1,10 +1,11 @@
1
  from smolagents import tool
2
 
 
3
  @tool
4
  def text_splitter(text: str, separator: str = "\n") -> list[str]:
5
  """
6
- Splits the input text string into a list on `separator` which
7
- defaults to the newline character. This is useful for when
8
  you need to browse through a large text file that may contain
9
  a list your are interested in.
10
 
 
1
  from smolagents import tool
2
 
3
+
4
  @tool
5
  def text_splitter(text: str, separator: str = "\n") -> list[str]:
6
  """
7
+ Splits the input text string into a list on `separator` which
8
+ defaults to the newline character. This is useful for when
9
  you need to browse through a large text file that may contain
10
  a list your are interested in.
11
 
tools/video_analyzer.py CHANGED
@@ -6,8 +6,7 @@ from yt_dlp import YoutubeDL
6
  from transformers import pipeline
7
  from typing import Any
8
  from PIL import Image
9
- import numpy as np
10
- from transformers import logging
11
 
12
  class YouTubeObjectCounterTool(Tool):
13
  name = "youtube_object_counter"
@@ -15,12 +14,12 @@ class YouTubeObjectCounterTool(Tool):
15
  inputs = {
16
  "url": {
17
  "type": "string",
18
- "description": "The URL of the YouTube video to analyze."
19
  },
20
  "label": {
21
  "type": "string",
22
- "description": "The type of object to count (e.g., 'bird', 'person', 'car', 'dog'). Use common object names recognized by standard object detection models."
23
- }
24
  }
25
  output_type = "string"
26
 
@@ -28,16 +27,16 @@ class YouTubeObjectCounterTool(Tool):
28
  """Downloads the YouTube video to a temporary file."""
29
  print(f"Downloading video from {url}...")
30
  temp_dir = tempfile.mkdtemp()
31
-
32
  video_path = os.path.join(temp_dir, "video.mp4")
33
-
34
  ydl_opts = {
35
- 'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best',
36
- 'outtmpl': video_path,
37
- 'quiet': True,
38
- 'no_warnings': True
39
  }
40
-
41
  try:
42
  with YoutubeDL(ydl_opts) as ydl:
43
  ydl.download([url])
@@ -50,22 +49,24 @@ class YouTubeObjectCounterTool(Tool):
50
 
51
  def _count_objects_in_frame(self, frame, label: str):
52
  """Counts objects of specified label in a single frame using the object detection model."""
53
-
54
  try:
55
  # Convert OpenCV BGR frame to RGB
56
  rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
57
-
58
  # Convert numpy array to PIL Image
59
  pil_image = Image.fromarray(rgb_frame)
60
-
61
  # Load the detector
62
  detector = pipeline("object-detection", model="facebook/detr-resnet-50")
63
-
64
  # Run detection with PIL Image
65
  results = detector(pil_image)
66
-
67
  # Count objects matching the label
68
- object_count = sum(1 for result in results if label.lower() in result['label'].lower())
 
 
69
  return object_count
70
  except Exception as e:
71
  print(f"Error detecting objects in frame: {str(e)}")
@@ -74,65 +75,73 @@ class YouTubeObjectCounterTool(Tool):
74
  def _analyze_video(self, video_path: str, label: str) -> dict[str, Any]:
75
  """Analyzes the video frame by frame and counts objects of the specified label."""
76
  sample_rate = 30
77
- print(f"Analyzing video {video_path}, looking for '{label}' objects, sampling every {sample_rate} frames...")
78
-
 
 
79
  # Open the video file
80
  cap = cv2.VideoCapture(video_path)
81
  if not cap.isOpened():
82
  raise RuntimeError(f"Error: Could not open video file {video_path}")
83
-
84
  # Get video properties
85
  fps = cap.get(cv2.CAP_PROP_FPS)
86
  frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
87
  duration = frame_count / fps
88
-
89
  # Initialize results
90
  frame_results = []
91
  total_objects = 0
92
  max_objects = 0
93
  max_objects_frame = 0
94
  frame_idx = 0
95
-
96
  # Process frames
97
  while cap.isOpened():
98
  ret, frame = cap.read()
99
  if not ret:
100
  break
101
-
102
  # Only process every nth frame
103
  if frame_idx % sample_rate == 0:
104
  time_point = frame_idx / fps
105
  print(f"Processing frame {frame_idx} at time {time_point:.2f}s...")
106
-
107
  object_count = self._count_objects_in_frame(frame, label)
108
  total_objects += object_count
109
-
110
  if object_count > max_objects:
111
  max_objects = object_count
112
  max_objects_frame = frame_idx
113
-
114
- frame_results.append({
115
- "frame": frame_idx,
116
- "time": time_point,
117
- "object_count": object_count
118
- })
119
-
 
 
120
  frame_idx += 1
121
-
122
  # Release resources
123
  cap.release()
124
-
125
  # Calculate statistics
126
- avg_objects_per_frame = total_objects / len(frame_results) if frame_results else 0
 
 
127
  max_objects_time = max_objects_frame / fps if max_objects_frame else 0
128
-
129
  # Clean up the temporary file
130
  try:
131
  os.remove(video_path)
132
  print(f"Deleted temporary video file: {video_path}")
133
  except Exception as e:
134
- print(f"Warning: Failed to delete temporary video file: {video_path} | {str(e)}")
135
-
 
 
136
  return {
137
  "frame_results": frame_results,
138
  "total_frames_analyzed": len(frame_results),
@@ -143,48 +152,48 @@ class YouTubeObjectCounterTool(Tool):
143
  "max_objects_in_single_frame": max_objects,
144
  "max_objects_frame": max_objects_frame,
145
  "max_objects_time": max_objects_time,
146
- "label": label
147
  }
148
 
149
  def forward(self, url: str, label: str) -> str:
150
  """
151
  Analyzes a YouTube video frame by frame and counts objects of the specified type.
152
-
153
  Args:
154
  url (str): The URL of the YouTube video to analyze.
155
  label (str): The type of object to count (e.g., 'bird', 'person', 'car', 'dog').
156
-
157
  Returns:
158
  str: A detailed report of object counts per frame and summary statistics.
159
  """
160
 
161
- logging.set_verbosity_error()
162
  try:
163
  # Download the video
164
  video_path = self._download_video(url)
165
-
166
  # Analyze the video
167
  results = self._analyze_video(video_path, label)
168
-
169
  # Generate a report
170
  report = [
171
  f"# {label.title()} Count Analysis for YouTube Video",
172
  f"Video URL: {url}",
173
  f"Video duration: {results['video_duration']:.2f} seconds",
174
  f"Analyzed {results['total_frames_analyzed']} frames out of {results['total_frames']} total frames",
175
- f"Sampling rate: 1 frame every 30 frames (approximately {results['fps']/30:.2f} frames per second)",
176
  "## Summary",
177
  f"Average {label}s per analyzed frame: {results['average_objects_per_analyzed_frame']:.2f}",
178
  f"Maximum {label}s in a single frame: {results['max_objects_in_single_frame']} (at {results['max_objects_time']:.2f} seconds)",
179
  ]
180
-
181
  # Add frame-by-frame details
182
  report.append("## Frame-by-Frame Analysis")
183
  for result in results["frame_results"]:
184
- report.append(f"Frame {result['frame']} (Time: {result['time']:.2f}s): {result['object_count']} {label}s")
185
-
 
 
186
  return "\n".join(report)
187
-
188
  except Exception as e:
189
  return f"Error analyzing video: {str(e)}"
190
-
 
6
  from transformers import pipeline
7
  from typing import Any
8
  from PIL import Image
9
+
 
10
 
11
  class YouTubeObjectCounterTool(Tool):
12
  name = "youtube_object_counter"
 
14
  inputs = {
15
  "url": {
16
  "type": "string",
17
+ "description": "The URL of the YouTube video to analyze.",
18
  },
19
  "label": {
20
  "type": "string",
21
+ "description": "The type of object to count (e.g., 'bird', 'person', 'car', 'dog'). Use common object names recognized by standard object detection models.",
22
+ },
23
  }
24
  output_type = "string"
25
 
 
27
  """Downloads the YouTube video to a temporary file."""
28
  print(f"Downloading video from {url}...")
29
  temp_dir = tempfile.mkdtemp()
30
+
31
  video_path = os.path.join(temp_dir, "video.mp4")
32
+
33
  ydl_opts = {
34
+ "format": "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best",
35
+ "outtmpl": video_path,
36
+ "quiet": True,
37
+ "no_warnings": True,
38
  }
39
+
40
  try:
41
  with YoutubeDL(ydl_opts) as ydl:
42
  ydl.download([url])
 
49
 
50
  def _count_objects_in_frame(self, frame, label: str):
51
  """Counts objects of specified label in a single frame using the object detection model."""
52
+
53
  try:
54
  # Convert OpenCV BGR frame to RGB
55
  rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
56
+
57
  # Convert numpy array to PIL Image
58
  pil_image = Image.fromarray(rgb_frame)
59
+
60
  # Load the detector
61
  detector = pipeline("object-detection", model="facebook/detr-resnet-50")
62
+
63
  # Run detection with PIL Image
64
  results = detector(pil_image)
65
+
66
  # Count objects matching the label
67
+ object_count = sum(
68
+ 1 for result in results if label.lower() in result["label"].lower()
69
+ )
70
  return object_count
71
  except Exception as e:
72
  print(f"Error detecting objects in frame: {str(e)}")
 
75
  def _analyze_video(self, video_path: str, label: str) -> dict[str, Any]:
76
  """Analyzes the video frame by frame and counts objects of the specified label."""
77
  sample_rate = 30
78
+ print(
79
+ f"Analyzing video {video_path}, looking for '{label}' objects, sampling every {sample_rate} frames..."
80
+ )
81
+
82
  # Open the video file
83
  cap = cv2.VideoCapture(video_path)
84
  if not cap.isOpened():
85
  raise RuntimeError(f"Error: Could not open video file {video_path}")
86
+
87
  # Get video properties
88
  fps = cap.get(cv2.CAP_PROP_FPS)
89
  frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
90
  duration = frame_count / fps
91
+
92
  # Initialize results
93
  frame_results = []
94
  total_objects = 0
95
  max_objects = 0
96
  max_objects_frame = 0
97
  frame_idx = 0
98
+
99
  # Process frames
100
  while cap.isOpened():
101
  ret, frame = cap.read()
102
  if not ret:
103
  break
104
+
105
  # Only process every nth frame
106
  if frame_idx % sample_rate == 0:
107
  time_point = frame_idx / fps
108
  print(f"Processing frame {frame_idx} at time {time_point:.2f}s...")
109
+
110
  object_count = self._count_objects_in_frame(frame, label)
111
  total_objects += object_count
112
+
113
  if object_count > max_objects:
114
  max_objects = object_count
115
  max_objects_frame = frame_idx
116
+
117
+ frame_results.append(
118
+ {
119
+ "frame": frame_idx,
120
+ "time": time_point,
121
+ "object_count": object_count,
122
+ }
123
+ )
124
+
125
  frame_idx += 1
126
+
127
  # Release resources
128
  cap.release()
129
+
130
  # Calculate statistics
131
+ avg_objects_per_frame = (
132
+ total_objects / len(frame_results) if frame_results else 0
133
+ )
134
  max_objects_time = max_objects_frame / fps if max_objects_frame else 0
135
+
136
  # Clean up the temporary file
137
  try:
138
  os.remove(video_path)
139
  print(f"Deleted temporary video file: {video_path}")
140
  except Exception as e:
141
+ print(
142
+ f"Warning: Failed to delete temporary video file: {video_path} | {str(e)}"
143
+ )
144
+
145
  return {
146
  "frame_results": frame_results,
147
  "total_frames_analyzed": len(frame_results),
 
152
  "max_objects_in_single_frame": max_objects,
153
  "max_objects_frame": max_objects_frame,
154
  "max_objects_time": max_objects_time,
155
+ "label": label,
156
  }
157
 
158
  def forward(self, url: str, label: str) -> str:
159
  """
160
  Analyzes a YouTube video frame by frame and counts objects of the specified type.
161
+
162
  Args:
163
  url (str): The URL of the YouTube video to analyze.
164
  label (str): The type of object to count (e.g., 'bird', 'person', 'car', 'dog').
165
+
166
  Returns:
167
  str: A detailed report of object counts per frame and summary statistics.
168
  """
169
 
 
170
  try:
171
  # Download the video
172
  video_path = self._download_video(url)
173
+
174
  # Analyze the video
175
  results = self._analyze_video(video_path, label)
176
+
177
  # Generate a report
178
  report = [
179
  f"# {label.title()} Count Analysis for YouTube Video",
180
  f"Video URL: {url}",
181
  f"Video duration: {results['video_duration']:.2f} seconds",
182
  f"Analyzed {results['total_frames_analyzed']} frames out of {results['total_frames']} total frames",
183
+ f"Sampling rate: 1 frame every 30 frames (approximately {results['fps'] / 30:.2f} frames per second)",
184
  "## Summary",
185
  f"Average {label}s per analyzed frame: {results['average_objects_per_analyzed_frame']:.2f}",
186
  f"Maximum {label}s in a single frame: {results['max_objects_in_single_frame']} (at {results['max_objects_time']:.2f} seconds)",
187
  ]
188
+
189
  # Add frame-by-frame details
190
  report.append("## Frame-by-Frame Analysis")
191
  for result in results["frame_results"]:
192
+ report.append(
193
+ f"Frame {result['frame']} (Time: {result['time']:.2f}s): {result['object_count']} {label}s"
194
+ )
195
+
196
  return "\n".join(report)
197
+
198
  except Exception as e:
199
  return f"Error analyzing video: {str(e)}"