Spaces:
Running
Running
File size: 8,295 Bytes
d2a45b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 |
"""
Security module for AIBOM generator implementation.
This module provides security functions that can be integrated
into the AIBOM generator to improve input validation, error handling,
and protection against common web vulnerabilities.
"""
import re
import os
import json
import logging
from typing import Dict, Any, Optional, Union
# Set up logging
# Module-level configuration: timestamped, INFO-level output shared by
# every helper in this module via the `logger` object below.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
def validate_model_id(model_id: str) -> str:
    """
    Validate a model ID to prevent injection and path traversal attacks.

    Accepts the characters found in Hugging Face-style model IDs:
    letters, digits, underscores, hyphens, dots, and forward slashes
    (e.g. "org/model-v1.5"). Dots are permitted for version suffixes,
    but any ".." sequence is rejected to block path traversal.

    Args:
        model_id: The model ID to validate

    Returns:
        The validated model ID, unchanged

    Raises:
        ValueError: If the model ID is empty, not a string, contains
            characters outside the allow-list, or contains ".."
    """
    if not model_id or not isinstance(model_id, str):
        raise ValueError("Model ID must be a non-empty string")
    # Allow-list: alphanumerics, underscore, hyphen, dot, forward slash.
    # NOTE: the previous pattern rejected dots entirely, which broke
    # legitimate IDs such as "org/model-v1.5" and made the ".." check
    # below unreachable.
    if not re.match(r'^[a-zA-Z0-9_\-./]+$', model_id):
        raise ValueError(f"Invalid model ID format: {model_id}")
    # Prevent path traversal attempts (".." is otherwise allowed by the
    # dot in the character class above).
    if '..' in model_id:
        raise ValueError(f"Invalid model ID - contains path traversal sequence: {model_id}")
    return model_id
def safe_path_join(directory: str, filename: str) -> str:
    """
    Join a base directory and a filename while neutralising traversal input.

    Any directory components embedded in *filename* (including ".."
    sequences) are discarded by reducing it to its final path component
    before joining.

    Args:
        directory: Base directory
        filename: Filename to append; only its basename is used

    Returns:
        The joined, traversal-safe path
    """
    # basename() strips any attacker-supplied directory prefix.
    safe_name = os.path.basename(filename)
    return os.path.join(directory, safe_name)
def safe_json_parse(json_string: str, default: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Parse a JSON string, returning a fallback value instead of raising.

    Args:
        json_string: JSON text to parse
        default: Value returned when parsing fails; a fresh empty dict
            is used when omitted

    Returns:
        The parsed object on success, otherwise *default*
    """
    fallback = {} if default is None else default
    try:
        parsed = json.loads(json_string)
    except (json.JSONDecodeError, TypeError) as e:
        logger.error(f"Invalid JSON: {e}")
        return fallback
    return parsed
def sanitize_html_output(text: str) -> str:
    """
    Escape HTML special characters to prevent XSS when echoing text.

    Args:
        text: Text to sanitize

    Returns:
        The text with &, <, >, ", ' and / replaced by HTML entities,
        or "" for empty/non-string input
    """
    if not text or not isinstance(text, str):
        return ""
    # BUG FIX: the previous mapping had been HTML-decoded at some point
    # and mapped each character to itself, making this function a no-op
    # and the XSS protection nonexistent.
    # '&' must be replaced first so the entities inserted by the later
    # replacements are not double-escaped; dicts preserve insertion order.
    replacements = {
        '&': '&amp;',
        '<': '&lt;',
        '>': '&gt;',
        '"': '&quot;',
        "'": '&#x27;',
        '/': '&#x2F;',
    }
    for char, entity in replacements.items():
        text = text.replace(char, entity)
    return text
def secure_file_operations(file_path: str, operation: str, content: Optional[str] = None) -> Union[str, bool]:
    """
    Perform a guarded file operation with uniform error handling.

    Args:
        file_path: Path of the file to operate on
        operation: One of 'read', 'write', or 'append'
        content: Text to write for 'write'/'append'; ignored for 'read'

    Returns:
        The file's contents for 'read'; True on a successful
        'write'/'append'; False for an unknown operation or when a
        write-style call omits content; "" when a read fails
    """
    try:
        if operation == 'read':
            with open(file_path, 'r', encoding='utf-8') as handle:
                return handle.read()
        if content is not None and operation in ('write', 'append'):
            # 'write' truncates, 'append' extends; both share one path.
            mode = 'w' if operation == 'write' else 'a'
            with open(file_path, mode, encoding='utf-8') as handle:
                handle.write(content)
            return True
        logger.error(f"Invalid file operation: {operation}")
        return False
    except Exception as e:
        logger.error(f"File operation failed: {e}")
        # Reads report failure as empty text; writes as False.
        return "" if operation == 'read' else False
def validate_url(url: str) -> bool:
    """
    Check whether a string looks like a well-formed http(s) URL.

    Args:
        url: Candidate URL

    Returns:
        True when the URL matches the expected scheme/host/port/path/
        query/fragment structure, False otherwise
    """
    # Scheme, dotted hostname, optional port, then optional path, query
    # string, and fragment. Anchored at both ends so trailing junk fails.
    pattern = (
        r'^(https?):\/\/'
        r'(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*'
        r'([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])'
        r'(:\d+)?'
        r'(\/[-a-zA-Z0-9%_.~#+]*)*'
        r'(\?[;&a-zA-Z0-9%_.~+=-]*)?'
        r'(\#[-a-zA-Z0-9%_.~+=/]*)?$'
    )
    return re.match(pattern, url) is not None
def secure_template_rendering(template_content: str, context: Dict[str, Any]) -> str:
    """
    Render a template with auto-escaping, preferring Jinja2.

    Falls back to naive "{{key}}" substitution (HTML-escaping string
    values via sanitize_html_output) when Jinja2 is unavailable, and
    to "" on any other rendering failure.

    Args:
        template_content: Template text
        context: Variables made available to the template

    Returns:
        The rendered output, or "" if rendering fails
    """
    try:
        from jinja2 import Template
        return Template(template_content, autoescape=True).render(**context)
    except ImportError:
        logger.error("Jinja2 not available, falling back to basic rendering")
        # Very basic fallback (not recommended for production): only
        # string values are substituted, and only exact "{{key}}" tokens.
        rendered = template_content
        for name, value in context.items():
            if isinstance(value, str):
                rendered = rendered.replace("{{" + name + "}}", sanitize_html_output(value))
        return rendered
    except Exception as e:
        logger.error(f"Template rendering failed: {e}")
        return ""
def implement_rate_limiting(user_id: str, action: str, limit: int, period: int) -> bool:
    """
    Rate-limiting hook (placeholder implementation).

    A production version would consult a shared counter store (database
    or cache): check whether *user_id* has performed *action* more than
    *limit* times within the last *period* seconds, incrementing the
    counter and allowing the action if not, denying it otherwise. This
    stub merely logs the check and always allows the action.

    Args:
        user_id: Identifier for the user
        action: Action being performed
        limit: Maximum number of actions allowed
        period: Time period in seconds

    Returns:
        True if the action is allowed, False if the rate limit is
        exceeded (never returned by this placeholder)
    """
    logger.info(f"Rate limiting check for user {user_id}, action {action}")
    return True
# Integration example for the AIBOM generator
def secure_aibom_generation(model_id: str, output_file: Optional[str] = None) -> Dict[str, Any]:
    """
    Example wiring of this module's security helpers into AIBOM generation.

    Args:
        model_id: Model ID to generate an AIBOM for
        output_file: Optional path to write the AIBOM JSON to

    Returns:
        The generated AIBOM data, or an {"error": ...} dict on failure
    """
    try:
        # Reject malformed IDs up front, before doing any work.
        checked_id = validate_model_id(model_id)
        # Placeholder for the real AIBOM generation logic.
        aibom_data = {"message": f"AIBOM for {checked_id}"}
        if output_file:
            # Re-split and re-join the path so traversal input is stripped.
            out_dir = os.path.dirname(output_file)
            out_name = os.path.basename(output_file)
            safe_output_path = safe_path_join(out_dir, out_name)
            payload = json.dumps(aibom_data, indent=2)
            secure_file_operations(safe_output_path, 'write', payload)
        return aibom_data
    except ValueError as e:
        # Validation failures get a client-safe, non-leaking message.
        logger.error(f"Validation error: {e}")
        return {"error": "Invalid input parameters"}
    except Exception as e:
        # Anything else is logged but not exposed to the caller.
        logger.error(f"AIBOM generation failed: {e}")
        return {"error": "An internal error occurred"}
|