kokluch commited on
Commit
a6ba120
·
1 Parent(s): 9a8c112

Add VirusTotal tools

Browse files
Gradio_UI.py CHANGED
@@ -141,7 +141,7 @@ def stream_to_gradio(
141
 
142
  for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
143
  # Track tokens if model provides them
144
- if hasattr(agent.model, "last_input_token_count"):
145
  total_input_tokens += agent.model.last_input_token_count
146
  total_output_tokens += agent.model.last_output_token_count
147
  if isinstance(step_log, ActionStep):
 
141
 
142
  for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
143
  # Track tokens if model provides them
144
+ if hasattr(agent.model, "last_input_token_count") and agent.model.last_input_token_count is not None:
145
  total_input_tokens += agent.model.last_input_token_count
146
  total_output_tokens += agent.model.last_output_token_count
147
  if isinstance(step_log, ActionStep):
app.py CHANGED
@@ -1,22 +1,17 @@
1
- from smolagents import CodeAgent,DuckDuckGoSearchTool, HfApiModel,load_tool,tool
2
  import datetime
3
- import requests
 
4
  import pytz
 
5
  import yaml
6
- from tools.final_answer import FinalAnswerTool
7
 
8
  from Gradio_UI import GradioUI
 
 
 
 
9
 
10
- # Below is an example of a tool that does nothing. Amaze us with your creativity !
11
- @tool
12
- def my_cutom_tool(arg1:str, arg2:int)-> str: #it's import to specify the return type
13
- #Keep this format for the description / args / args description but feel free to modify the tool
14
- """A tool that does nothing yet
15
- Args:
16
- arg1: the first argument
17
- arg2: the second argument
18
- """
19
- return "What magic will you build ?"
20
 
21
  @tool
22
  def get_current_time_in_timezone(timezone: str) -> str:
@@ -33,25 +28,198 @@ def get_current_time_in_timezone(timezone: str) -> str:
33
  except Exception as e:
34
  return f"Error fetching time for timezone '{timezone}': {str(e)}"
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  final_answer = FinalAnswerTool()
38
  model = HfApiModel(
39
- max_tokens=2096,
40
- temperature=0.5,
41
- model_id='https://wxknx1kg971u7k1n.us-east-1.aws.endpoints.huggingface.cloud',# it is possible that this model may be overloaded
42
- custom_role_conversions=None,
 
43
  )
44
 
45
-
46
  # Import tool from Hub
47
- image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)
 
 
48
 
49
  with open("prompts.yaml", 'r') as stream:
50
  prompt_templates = yaml.safe_load(stream)
51
 
52
  agent = CodeAgent(
53
  model=model,
54
- tools=[final_answer], ## add your tools here (don't remove final answer)
55
  max_steps=6,
56
  verbosity_level=1,
57
  grammar=None,
@@ -61,5 +229,4 @@ agent = CodeAgent(
61
  prompt_templates=prompt_templates
62
  )
63
 
64
-
65
  GradioUI(agent).launch()
 
 
1
  import datetime
2
+ import os
3
+
4
  import pytz
5
+ import requests
6
  import yaml
7
+ from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, tool
8
 
9
  from Gradio_UI import GradioUI
10
+ from classes.get_url_report import GetURLReportResponse, Data, Attributes, Stats
11
+ from classes.ip_address_report import IPAddressReport, TotalVotes, AnalysisStats
12
+ from classes.scan_url import DataAnalysis, Links, ScanResponse
13
+ from tools.final_answer import FinalAnswerTool
14
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  @tool
17
  def get_current_time_in_timezone(timezone: str) -> str:
 
28
  except Exception as e:
29
  return f"Error fetching time for timezone '{timezone}': {str(e)}"
30
 
31
+ @tool
32
+ def get_my_ip_address() -> str:
33
+ """
34
+ Retrieves the public IP address of the machine running this code.
35
+
36
+ Returns:
37
+ str: The public IP address.
38
+
39
+ Raises:
40
+ Exception: If the request to the external service fails.
41
+ """
42
+ url = "https://api.ipify.org?format=json"
43
+ response = requests.get(url)
44
+
45
+ if response.status_code == 200:
46
+ data = response.json()
47
+ return data["ip"]
48
+ else:
49
+ raise Exception(f"Failed to retrieve IP address: {response.status_code} - {response.text}")
50
+
51
+ @tool
52
+ def get_ip_address_report(ip_address: str) -> IPAddressReport:
53
+ """
54
+ Fetches the IP address report from the VirusTotal API and returns it as an IPAddressReport object.
55
+
56
+ Args:
57
+ ip_address: The IP address to fetch the report for.
58
+
59
+ Returns:
60
+ IPAddressReport: An object containing the IP address report.
61
+
62
+ Raises:
63
+ Exception: If the request to the VirusTotal API fails.
64
+ """
65
+ url = f"https://www.virustotal.com/api/v3/ip_addresses/{ip_address}"
66
+ headers = {
67
+ "accept": "application/json",
68
+ "x-apikey": os.getenv('VT_API_KEY')
69
+ }
70
+
71
+ response = requests.get(url, headers=headers)
72
+
73
+ if response.status_code == 200:
74
+ data = response.json()
75
+ report = IPAddressReport(
76
+ id=data["data"]["id"],
77
+ type=data["data"]["type"],
78
+ reputation=data["data"]["attributes"]["reputation"],
79
+ continent=data["data"]["attributes"]["continent"],
80
+ as_owner=data["data"]["attributes"]["as_owner"],
81
+ country=data["data"]["attributes"]["country"],
82
+ tags=data["data"]["attributes"]["tags"],
83
+ total_votes=TotalVotes(
84
+ harmless=data["data"]["attributes"]["total_votes"]["harmless"],
85
+ malicious=data["data"]["attributes"]["total_votes"]["malicious"]
86
+ ),
87
+ network=data["data"]["attributes"]["network"],
88
+ last_analysis_stats=AnalysisStats(
89
+ malicious=data["data"]["attributes"]["last_analysis_stats"]["malicious"],
90
+ suspicious=data["data"]["attributes"]["last_analysis_stats"]["suspicious"],
91
+ undetected=data["data"]["attributes"]["last_analysis_stats"]["undetected"],
92
+ harmless=data["data"]["attributes"]["last_analysis_stats"]["harmless"],
93
+ timeout=data["data"]["attributes"]["last_analysis_stats"]["timeout"]
94
+ )
95
+ )
96
+ return report
97
+ else:
98
+ raise Exception(f"Failed to retrieve data: {response.status_code} - {response.text}")
99
+
100
+ @tool
101
+ def scan_url(url: str) -> ScanResponse:
102
+ """
103
+ Request a scan of a given URL using the VirusTotal API.
104
+
105
+ Args:
106
+ url: The URL to scan.
107
+
108
+ Returns:
109
+ ScanResponse: The response from the VirusTotal API.
110
+
111
+ Raises:
112
+ Exception: If the request to the external service fails.
113
+ """
114
+
115
+ endpoint = "https://www.virustotal.com/api/v3/urls"
116
+
117
+ payload = { "url" : url }
118
+ headers = {
119
+ "accept": "application/json",
120
+ "x-apikey": os.getenv('VT_API_KEY') ,
121
+ "content-type": "application/x-www-form-urlencoded"
122
+ }
123
+
124
+ # Send a POST request to the VirusTotal API
125
+ response = requests.post(endpoint, headers=headers, data=payload)
126
+
127
+ print(response.text)
128
+
129
+ try:
130
+ # Raise an exception if the request was unsuccessful
131
+ response.raise_for_status()
132
+
133
+ response_json = response.json()
134
+
135
+ return ScanResponse(
136
+ data=DataAnalysis(
137
+ type=response_json["data"]["type"],
138
+ id=response_json["data"]["id"],
139
+ links=Links(self_url=response_json["data"]["links"]["self"])
140
+ )
141
+ )
142
+
143
+ except requests.exceptions.RequestException as e:
144
+ # Handle any errors that occur during the request
145
+ raise Exception(f"Failed to retrieve data: {response.status_code} - {response.text}")
146
+
147
+ @tool
148
+ def get_scan_report(scan: ScanResponse) -> GetURLReportResponse:
149
+ """
150
+ Fetch a report of a scan of a given URL using the VirusTotal API.
151
+
152
+ Args:
153
+ scan: The ScanResponse object returned by calling scan_url tool.
154
+
155
+ Returns:
156
+ GetURLReportResponse: The response from the VirusTotal API.
157
+
158
+ Raises:
159
+ Exception: If the request to the external service fails.
160
+ """
161
+
162
+ headers = {
163
+ "accept": "application/json",
164
+ "x-apikey": os.getenv('VT_API_KEY')
165
+ }
166
+
167
+ # Send a GET request to the VirusTotal API
168
+ response = requests.get(scan.data.links.self_url, headers=headers)
169
+
170
+ print(response.text)
171
+
172
+ try:
173
+ # Raise an exception if the request was unsuccessful
174
+ response.raise_for_status()
175
+
176
+ response_json = response.json()
177
+
178
+ # Creating an instance of the data class from the JSON response
179
+ response = GetURLReportResponse(
180
+ data=Data(
181
+ id=response_json["data"]["id"],
182
+ type=response_json["data"]["type"],
183
+ attributes=Attributes(
184
+ date=response_json["data"]["attributes"]["date"],
185
+ status=response_json["data"]["attributes"]["status"],
186
+ stats=Stats(
187
+ malicious=response_json["data"]["attributes"]["stats"]["malicious"],
188
+ suspicious=response_json["data"]["attributes"]["stats"]["suspicious"],
189
+ undetected=response_json["data"]["attributes"]["stats"]["undetected"],
190
+ harmless=response_json["data"]["attributes"]["stats"]["harmless"],
191
+ timeout=response_json["data"]["attributes"]["stats"]["timeout"]
192
+ )
193
+ )
194
+ )
195
+ )
196
+
197
+ return response
198
+
199
+ except requests.exceptions.RequestException as e:
200
+ # Handle any errors that occur during the request
201
+ raise Exception(f"Failed to retrieve data: {response.status_code} - {response.text}")
202
 
203
  final_answer = FinalAnswerTool()
204
  model = HfApiModel(
205
+ token=os.getenv('HF_TOKEN'),
206
+ max_tokens=2096,
207
+ temperature=0.5,
208
+ model_id=os.getenv('MODEL'),
209
+ custom_role_conversions=None,
210
  )
211
 
 
212
  # Import tool from Hub
213
+ # image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)
214
+
215
+ web_search_tool = DuckDuckGoSearchTool()
216
 
217
  with open("prompts.yaml", 'r') as stream:
218
  prompt_templates = yaml.safe_load(stream)
219
 
220
  agent = CodeAgent(
221
  model=model,
222
+ tools=[final_answer, get_my_ip_address, get_ip_address_report, scan_url, get_scan_report], ## add your tools here (don't remove final answer)
223
  max_steps=6,
224
  verbosity_level=1,
225
  grammar=None,
 
229
  prompt_templates=prompt_templates
230
  )
231
 
 
232
  GradioUI(agent).launch()
classes/get_url_report.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from datetime import datetime
3
+
4
+ @dataclass
5
+ class Stats:
6
+ """
7
+ Represents the statistics of the analysis results.
8
+
9
+ Attributes:
10
+ malicious (int): Number of engines that detected the URL as malicious.
11
+ suspicious (int): Number of engines that detected the URL as suspicious.
12
+ undetected (int): Number of engines that did not detect the URL.
13
+ harmless (int): Number of engines that detected the URL as harmless.
14
+ timeout (int): Number of engines that timed out during the analysis.
15
+ """
16
+ malicious: int
17
+ suspicious: int
18
+ undetected: int
19
+ harmless: int
20
+ timeout: int
21
+
22
+ @dataclass
23
+ class Attributes:
24
+ """
25
+ Represents the attributes of the analysis. If status is queued, retry later.
26
+
27
+ Attributes:
28
+ date (int): The timestamp of the analysis in Unix epoch format.
29
+ status (str): The status of the analysis (e.g., "queued").
30
+ stats (Stats): An instance of the Stats class containing analysis statistics.
31
+ """
32
+ date: int
33
+ status: str
34
+ stats: Stats
35
+
36
+ def get_date_as_datetime(self) -> datetime:
37
+ """
38
+ Convert the Unix epoch timestamp to a datetime object.
39
+
40
+ Returns:
41
+ datetime: The datetime representation of the analysis date.
42
+ """
43
+ return datetime.fromtimestamp(self.date)
44
+
45
+ @dataclass
46
+ class Data:
47
+ """
48
+ Represents the data section of the VirusTotal analysis response.
49
+
50
+ Attributes:
51
+ id (str): The unique identifier for the analysis.
52
+ type (str): The type of the data, which is "analysis" in this context.
53
+ attributes (Attributes): An instance of the Attributes class containing analysis details.
54
+ """
55
+ id: str
56
+ type: str
57
+ attributes: Attributes
58
+
59
+ @dataclass
60
+ class GetURLReportResponse:
61
+ """
62
+ Represents the overall response from the VirusTotal API for a URL scan analysis.
63
+
64
+ Attributes:
65
+ data (Data): An instance of the Data class containing analysis details.
66
+ """
67
+ data: Data
68
+
classes/ip_address_report.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ @dataclass
4
+ class TotalVotes:
5
+ """
6
+ Represents the total votes for an IP address.
7
+
8
+ Attributes:
9
+ harmless (int): The number of votes indicating the IP address is harmless.
10
+ malicious (int): The number of votes indicating the IP address is malicious.
11
+ """
12
+ harmless: int
13
+ malicious: int
14
+
15
+ @dataclass
16
+ class AnalysisStats:
17
+ """
18
+ Represents the statistics from the last analysis of an IP address.
19
+
20
+ Attributes:
21
+ malicious (int): The number of engines that detected the IP address as malicious.
22
+ suspicious (int): The number of engines that detected the IP address as suspicious.
23
+ undetected (int): The number of engines that did not detect any issues with the IP address.
24
+ harmless (int): The number of engines that detected the IP address as harmless.
25
+ timeout (int): The number of engines that timed out during the analysis.
26
+ """
27
+ malicious: int
28
+ suspicious: int
29
+ undetected: int
30
+ harmless: int
31
+ timeout: int
32
+
33
+ @dataclass
34
+ class IPAddressReport:
35
+ """
36
+ Represents a report for an IP address.
37
+
38
+ Attributes:
39
+ id (str): The ID of the IP address.
40
+ type (str): The type of the report (e.g., "ip_address").
41
+ reputation (int): The reputation score of the IP address.
42
+ continent (str): The continent where the IP address is located.
43
+ as_owner (str): The owner of the autonomous system (AS) associated with the IP address.
44
+ country (str): The country where the IP address is located.
45
+ tags (List[str]): A list of tags associated with the IP address.
46
+ total_votes (TotalVotes): The total votes for the IP address.
47
+ network (str): The network associated with the IP address.
48
+ last_analysis_stats (AnalysisStats): The statistics from the last analysis of the IP address.
49
+ """
50
+ id: str
51
+ type: str
52
+ reputation: int
53
+ continent: str
54
+ as_owner: str
55
+ country: str
56
+ tags: list
57
+ total_votes: TotalVotes
58
+ network: str
59
+ last_analysis_stats: AnalysisStats
classes/scan_url.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ @dataclass
5
+ class Links:
6
+ """
7
+ Represents the links associated with an analysis.
8
+
9
+ Attributes:
10
+ self_url (str): The URL to access the analysis resource itself.
11
+ """
12
+ self_url: str
13
+
14
+ @dataclass
15
+ class DataAnalysis:
16
+ """
17
+ Represents the data section of the VirusTotal analysis response.
18
+
19
+ Attributes:
20
+ type (str): The type of the data, which is "analysis" in this context.
21
+ id (str): The unique identifier for the analysis.
22
+ links (Links): An instance of the Links class containing related URLs.
23
+ """
24
+ type: str
25
+ id: str
26
+ links: Links
27
+
28
+ @dataclass
29
+ class ScanResponse:
30
+ """
31
+ Represents the overall response from the VirusTotal API for a URL scan analysis.
32
+
33
+ Attributes:
34
+ data (DataAnalysis): An instance of the DataAnalysis class containing analysis details.
35
+ """
36
+ data: DataAnalysis