a1c00l committed on
Commit
d2a45b4
·
verified ·
1 Parent(s): 8402521

Upload aibom_security.py

Browse files
Files changed (1) hide show
  1. src/aibom_generator/aibom_security.py +256 -0
src/aibom_generator/aibom_security.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Security module for AIBOM generator implementation.
3
+
4
+ This module provides security functions that can be integrated
5
+ into the AIBOM generator to improve input validation, error handling,
6
+ and protection against common web vulnerabilities.
7
+ """
8
+
9
import json
import logging
import os
import re
import threading
import time
from collections import deque
from typing import Dict, Any, Optional, Union
14
+
15
# Set up logging.
# NOTE: basicConfig() is called at import time, so merely importing this
# module configures the root logger (level INFO, format below) when the
# application has not installed its own handlers yet.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
21
+
22
def validate_model_id(model_id: str) -> str:
    """
    Validate model ID to prevent injection attacks.

    Only ASCII letters, digits, hyphens, underscores, and forward slashes
    are accepted; the whole string must consist of those characters.

    Args:
        model_id: The model ID to validate

    Returns:
        The validated model ID, unchanged

    Raises:
        ValueError: If the model ID is empty, is not a string, or contains
            invalid characters
    """
    if not isinstance(model_id, str) or not model_id:
        raise ValueError("Model ID must be a non-empty string")

    # fullmatch() instead of match() with a trailing "$": "$" also matches
    # just before a final newline, which would let "org/model\n" through.
    if not re.fullmatch(r'[a-zA-Z0-9_\-/]+', model_id):
        raise ValueError(f"Invalid model ID format: {model_id}")

    # Prevent path traversal attempts. The whitelist above already rejects
    # ".", so this check is defence in depth should the whitelist ever be
    # widened.
    if '..' in model_id:
        raise ValueError(f"Invalid model ID - contains path traversal sequence: {model_id}")

    return model_id
47
+
48
def safe_path_join(directory: str, filename: str) -> str:
    """
    Safely join directory and filename to prevent path traversal attacks.

    Every directory component of ``filename`` is discarded, so the result
    always names an entry directly inside ``directory``.

    Args:
        directory: Base directory
        filename: Filename to append; any leading path components
            (separated by "/" or "\\") are stripped

    Returns:
        Safe file path

    Raises:
        ValueError: If no usable file name remains after stripping, i.e.
            the final component is empty, "." or ".."
    """
    # Treat backslashes as separators as well, so Windows-style input such as
    # "..\\..\\secret" is reduced to its last component on POSIX too.
    name = os.path.basename(filename.replace("\\", "/"))

    # basename("..") is "..", and basename("dir/") is ""; neither names a
    # file inside ``directory``, and ".." would escape it.
    if name in ("", ".", ".."):
        raise ValueError(f"Invalid filename: {filename}")

    return os.path.join(directory, name)
62
+
63
def safe_json_parse(json_string: str, default: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Safely parse JSON with error handling.

    Args:
        json_string: JSON string to parse
        default: Default value to return if parsing fails; an empty dict
            is used when omitted

    Returns:
        Parsed JSON value or the default value. Note that a valid document
        whose top level is not an object (e.g. a list or number) is
        returned as-is.
    """
    if default is None:
        default = {}

    try:
        return json.loads(json_string)
    # ValueError covers json.JSONDecodeError and the UnicodeDecodeError raised
    # for undecodable bytes; RecursionError is raised for pathologically deep
    # nesting; TypeError for input that is not str/bytes/bytearray.
    except (ValueError, TypeError, RecursionError) as e:
        logger.error("Invalid JSON: %s", e)
        return default
82
+
83
# Translation table applied in a single pass, so the "&" introduced by one
# entity can never be escaped a second time.
_HTML_ESCAPE_TABLE = str.maketrans({
    '&': '&amp;',
    '<': '&lt;',
    '>': '&gt;',
    '"': '&quot;',
    "'": '&#x27;',
    '/': '&#x2F;',
})


def sanitize_html_output(text: str) -> str:
    """
    Sanitize text for safe HTML output to prevent XSS attacks.

    Args:
        text: Text to sanitize

    Returns:
        The text with ``& < > " ' /`` replaced by HTML entities, or an
        empty string when ``text`` is empty or not a string
    """
    if not isinstance(text, str) or not text:
        return ""

    return text.translate(_HTML_ESCAPE_TABLE)
110
+
111
def secure_file_operations(file_path: str, operation: str, content: Optional[str] = None) -> Union[str, bool]:
    """
    Perform secure file operations with proper error handling.

    Files are always opened as UTF-8 text. Failures are logged and reported
    through the return value rather than raised.

    Args:
        file_path: Path to the file
        operation: Operation to perform ('read', 'write', 'append')
        content: Content to write (required for 'write' and 'append')

    Returns:
        File content for 'read' operation ("" if the read fails), True for
        successful 'write'/'append', False otherwise
    """
    write_modes = {'write': 'w', 'append': 'a'}

    try:
        if operation == 'read':
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()

        mode = write_modes.get(operation)
        if mode is None:
            logger.error("Invalid file operation: %s", operation)
            return False

        # Checked before opening so that a 'write' without content never
        # truncates an existing file.
        if content is None:
            logger.error("No content provided for file operation: %s", operation)
            return False

        with open(file_path, mode, encoding='utf-8') as f:
            f.write(content)
        return True
    except Exception as e:
        # Deliberate best-effort boundary: callers get a sentinel, not a raise.
        logger.error("File operation failed: %s", e)
        return "" if operation == 'read' else False
141
+
142
# Compiled once at import time. No "^"/"$" anchors: validate_url() uses
# fullmatch(), which (unlike "$") does not tolerate a trailing newline.
_URL_PATTERN = re.compile(
    r'(https?):\/\/'  # http:// or https://
    r'(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*'  # domain segments
    r'([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])'  # last domain segment
    r'(:\d+)?'  # optional port
    r'(\/[-a-zA-Z0-9%_.~#+]*)*'  # path
    r'(\?[;&a-zA-Z0-9%_.~+=-]*)?'  # query string
    r'(\#[-a-zA-Z0-9%_.~+=/]*)?'  # fragment
)


def validate_url(url: str) -> bool:
    """
    Validate URL format to prevent malicious URL injection.

    Only ``http``/``https`` URLs built from the conservative character set
    in the pattern above are accepted.

    Args:
        url: URL to validate

    Returns:
        True if URL is valid, False otherwise (including non-string input)
    """
    if not isinstance(url, str):
        return False

    return _URL_PATTERN.fullmatch(url) is not None
164
+
165
def _render_basic_fallback(template_content: str, context: Dict[str, Any]) -> str:
    """Substitute ``{{key}}`` placeholders with HTML-escaped string values in one pass."""
    replacements = {
        "{{" + key + "}}": sanitize_html_output(value)
        for key, value in context.items()
        if isinstance(value, str)
    }
    if not replacements:
        return template_content

    # A single regex pass never re-scans text that was just substituted, so a
    # context value containing "{{other}}" cannot pull in another variable.
    # Longest placeholder first so one is never shadowed by a shorter prefix.
    pattern = re.compile(
        "|".join(re.escape(p) for p in sorted(replacements, key=len, reverse=True))
    )
    return pattern.sub(lambda m: replacements[m.group(0)], template_content)


def secure_template_rendering(template_content: str, context: Dict[str, Any]) -> str:
    """
    Render templates securely with auto-escaping enabled.

    Uses Jinja2 with ``autoescape=True`` when it is installed. Otherwise a
    very basic fallback replaces exact ``{{key}}`` placeholders (no inner
    whitespace) for string-valued context entries only; it is not
    recommended for production.

    NOTE(review): auto-escaping protects the rendered *values*, not the
    template itself. If ``template_content`` can come from an untrusted
    source, consider Jinja2's sandboxed environment.

    Args:
        template_content: Template content
        context: Context variables for rendering

    Returns:
        Rendered template, or an empty string if rendering fails
    """
    try:
        from jinja2 import Template
    except ImportError:
        Template = None
        logger.error("Jinja2 not available, falling back to basic rendering")

    try:
        if Template is None:
            return _render_basic_fallback(template_content, context)
        return Template(template_content, autoescape=True).render(**context)
    except Exception as e:
        logger.error("Template rendering failed: %s", e)
        return ""
195
+
196
# Process-local sliding-window state: (user_id, action) -> monotonic
# timestamps of the calls that were allowed. Guarded by _RATE_LIMIT_LOCK.
_RATE_LIMIT_HISTORY = {}
_RATE_LIMIT_LOCK = threading.Lock()


def implement_rate_limiting(user_id: str, action: str, limit: int, period: int) -> bool:
    """
    Implement basic rate limiting to prevent abuse.

    Uses an in-memory sliding window per ``(user_id, action)`` pair. The
    state lives in this process only: it is lost on restart and is not
    shared between worker processes, and entries for idle users are not
    evicted. A deployment that needs either property should back this with
    a shared cache or database instead.

    Args:
        user_id: Identifier for the user
        action: Action being performed
        limit: Maximum number of actions allowed within ``period``;
            a value <= 0 denies every request
        period: Time period in seconds; a value <= 0 means no earlier
            request is ever inside the window, so requests are allowed

    Returns:
        True if action is allowed, False if rate limit exceeded
    """
    logger.info(f"Rate limiting check for user {user_id}, action {action}")

    if limit <= 0:
        return False

    key = (user_id, action)
    # monotonic() so wall-clock adjustments cannot reopen or extend a window.
    now = time.monotonic()
    cutoff = now - period

    with _RATE_LIMIT_LOCK:
        history = _RATE_LIMIT_HISTORY.setdefault(key, deque())

        # Drop timestamps that have fallen out of the window.
        while history and history[0] <= cutoff:
            history.popleft()

        if len(history) >= limit:
            logger.warning("Rate limit exceeded for user %s, action %s", user_id, action)
            return False

        history.append(now)
        return True
220
+
221
+ # Integration example for the AIBOM generator
222
def secure_aibom_generation(model_id: str, output_file: Optional[str] = None) -> Dict[str, Any]:
    """
    Example of how to integrate security improvements into AIBOM generation.

    Args:
        model_id: Model ID to generate AIBOM for
        output_file: Optional output file path

    Returns:
        Generated AIBOM data, or a dict with a generic "error" message when
        validation or generation fails. A failure to write ``output_file``
        is logged but does not change the returned data.
    """
    try:
        # Validate input
        validated_model_id = validate_model_id(model_id)

        # Process model ID securely
        # (This would call your actual AIBOM generation logic)
        aibom_data = {"message": f"AIBOM for {validated_model_id}"}

        # Handle output file securely if provided
        if output_file:
            # NOTE(review): the directory part is taken from output_file
            # itself, so this normalises the path but does not confine it to
            # any particular base directory.
            safe_output_path = safe_path_join(os.path.dirname(output_file), os.path.basename(output_file))
            written = secure_file_operations(safe_output_path, 'write', json.dumps(aibom_data, indent=2))
            if not written:
                # Surface the failed write here, with the target path.
                logger.error("Failed to write AIBOM to %s", safe_output_path)

        return aibom_data

    except ValueError as e:
        # Handle validation errors; details go to the log, not the caller
        logger.error("Validation error: %s", e)
        return {"error": "Invalid input parameters"}

    except Exception as e:
        # Handle unexpected errors
        logger.error("AIBOM generation failed: %s", e)
        return {"error": "An internal error occurred"}