File size: 8,295 Bytes
d2a45b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
"""
Security module for AIBOM generator implementation.

This module provides security functions that can be integrated
into the AIBOM generator to improve input validation, error handling,
and protection against common web vulnerabilities.
"""

import re
import os
import json
import logging
from typing import Dict, Any, Optional, Union

# Configure logging as a side effect of importing this module: INFO level,
# with each record formatted as "timestamp - logger name - level - message".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# Module-level logger (named after this module) used by the functions below.
logger = logging.getLogger(__name__)

def validate_model_id(model_id: str) -> str:
    """
    Validate model ID to prevent injection attacks.

    A valid model ID is a non-empty string made up solely of ASCII letters,
    digits, underscores, hyphens, and forward slashes.

    Args:
        model_id: The model ID to validate

    Returns:
        The validated model ID (unchanged)

    Raises:
        ValueError: If the model ID is empty, is not a string, or contains
            invalid characters
    """
    # Check the type first so truthiness is only ever evaluated on a str.
    if not isinstance(model_id, str) or not model_id:
        raise ValueError("Model ID must be a non-empty string")

    # Use fullmatch rather than match with a '$' anchor: '$' also matches just
    # before a trailing newline, which would let "name\n" slip through.
    if not re.fullmatch(r'[a-zA-Z0-9_\-/]+', model_id):
        raise ValueError(f"Invalid model ID format: {model_id}")

    # Defence in depth: the character whitelist above already excludes '.',
    # but keep the explicit traversal check in case the whitelist is relaxed.
    if '..' in model_id:
        raise ValueError(f"Invalid model ID - contains path traversal sequence: {model_id}")

    return model_id

def safe_path_join(directory: str, filename: str) -> str:
    """
    Safely join directory and filename to prevent path traversal attacks.

    Only the final path component of ``filename`` is used; any directory
    part it carries is discarded before joining.

    Args:
        directory: Base directory
        filename: Filename to append

    Returns:
        Safe file path inside ``directory``

    Raises:
        ValueError: If the final component of ``filename`` is the parent
            directory reference ("..")
    """
    # Strip any directory components so only the leaf name is joined.
    leaf = os.path.basename(filename)

    # basename("..") and basename("a/..") both yield "..", which would make
    # the joined path point at the parent of `directory`.
    if leaf == os.pardir:
        raise ValueError("Invalid filename - contains path traversal sequence")

    return os.path.join(directory, leaf)

def safe_json_parse(json_string: str, default: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Parse a JSON string, falling back to a default instead of raising.

    Args:
        json_string: JSON string to parse
        default: Value returned when parsing fails; an empty dict is used
            when this is None

    Returns:
        Whatever ``json.loads`` produces for the input, or the default value
        if decoding fails or the input has an unsupported type
    """
    fallback = {} if default is None else default

    try:
        parsed = json.loads(json_string)
    except (json.JSONDecodeError, TypeError) as e:
        # Failure is logged, not raised; the caller gets the fallback.
        logger.error(f"Invalid JSON: {e}")
        return fallback

    return parsed

# Translation table mapping each HTML-significant character to its entity.
# Applied in a single pass, so the '&' inside the generated entities is never
# re-escaped.
_HTML_ESCAPE_TABLE = str.maketrans({
    '&': '&amp;',
    '<': '&lt;',
    '>': '&gt;',
    '"': '&quot;',
    "'": '&#x27;',
    '/': '&#x2F;',
})


def sanitize_html_output(text: str) -> str:
    """
    Sanitize text for safe HTML output to prevent XSS attacks.

    Escapes ``&``, ``<``, ``>``, ``"``, ``'`` and ``/`` as HTML entities.

    Args:
        text: Text to sanitize

    Returns:
        Sanitized text, or an empty string if ``text`` is empty or not a
        string
    """
    if not text or not isinstance(text, str):
        return ""

    return text.translate(_HTML_ESCAPE_TABLE)

def secure_file_operations(file_path: str, operation: str, content: Optional[str] = None) -> Union[str, bool]:
    """
    Read, write, or append to a UTF-8 text file, logging errors instead of raising.

    Args:
        file_path: Path to the file
        operation: Operation to perform ('read', 'write', 'append')
        content: Text to store (required for 'write' and 'append')

    Returns:
        The file's text for 'read'; True after a successful 'write' or
        'append'. An unrecognised operation, or a write/append with no
        content, returns False. If an exception occurs, returns "" for
        'read' and False otherwise.
    """
    try:
        if operation == 'read':
            with open(file_path, 'r', encoding='utf-8') as handle:
                return handle.read()

        # Map the remaining supported operations onto open() modes.
        if operation == 'write':
            mode = 'w'
        elif operation == 'append':
            mode = 'a'
        else:
            mode = None

        # Unknown operation, or a write/append without anything to write.
        if mode is None or content is None:
            logger.error(f"Invalid file operation: {operation}")
            return False

        with open(file_path, mode, encoding='utf-8') as handle:
            handle.write(content)
        return True
    except Exception as e:
        logger.error(f"File operation failed: {e}")
        if operation == 'read':
            return ""
        return False

# Compiled once at import time instead of on every validate_url() call.
_URL_PATTERN = re.compile(
    r'^(https?):\/\/'  # http:// or https://
    r'(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*'  # domain segments
    r'([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])'  # last domain segment
    r'(:\d+)?'  # optional port
    r'(\/[-a-zA-Z0-9%_.~#+]*)*'  # path
    r'(\?[;&a-zA-Z0-9%_.~+=-]*)?'  # query string
    r'(\#[-a-zA-Z0-9%_.~+=/]*)?$'  # fragment
)


def validate_url(url: str) -> bool:
    """
    Validate URL format to prevent malicious URL injection.

    Only ``http`` and ``https`` URLs with a hostname made of alphanumeric /
    hyphen labels are accepted, optionally followed by a port, path, query
    string, and fragment drawn from a restricted character set.

    Args:
        url: URL to validate

    Returns:
        True if URL is valid, False otherwise (including when ``url`` is not
        a string)
    """
    # Non-string input can never be a valid URL; report it as invalid rather
    # than letting the regex engine raise TypeError.
    if not isinstance(url, str):
        return False

    # fullmatch requires the pattern to consume the whole string. With
    # match(), the trailing '$' would also accept a URL followed by "\n".
    return _URL_PATTERN.fullmatch(url) is not None

def secure_template_rendering(template_content: str, context: Dict[str, Any]) -> str:
    """
    Render a template with HTML auto-escaping.

    Uses a Jinja2 ``Template`` with ``autoescape=True``. If that import
    fails, a minimal fallback substitutes ``{{key}}`` placeholders (no
    surrounding whitespace) with the escaped value for every context entry
    whose value is a string; other entries are left untouched.

    Args:
        template_content: Template content
        context: Context variables for rendering

    Returns:
        Rendered template, or an empty string if rendering raises an
        error other than ImportError
    """
    try:
        from jinja2 import Template
        return Template(template_content, autoescape=True).render(**context)
    except ImportError:
        logger.error("Jinja2 not available, falling back to basic rendering")
        # Very basic fallback (not recommended for production)
        rendered = template_content
        for name, item in context.items():
            if not isinstance(item, str):
                continue
            rendered = rendered.replace("{{" + name + "}}", sanitize_html_output(item))
        return rendered
    except Exception as e:
        logger.error(f"Template rendering failed: {e}")
        return ""

def implement_rate_limiting(user_id: str, action: str, limit: int, period: int) -> bool:
    """
    Placeholder rate-limit check that currently permits every action.

    No request counts are tracked: ``limit`` and ``period`` are accepted but
    not consulted. A real implementation would keep per-user counters in a
    database or cache, deny the action once ``limit`` is exceeded within
    ``period`` seconds, and otherwise increment the counter and allow it.

    Args:
        user_id: Identifier for the user
        action: Action being performed
        limit: Maximum number of actions allowed
        period: Time period in seconds

    Returns:
        True if action is allowed, False if rate limit exceeded (this
        placeholder always returns True)
    """
    # Only record that a check was requested; nothing is enforced yet.
    message = f"Rate limiting check for user {user_id}, action {action}"
    logger.info(message)

    allowed = True
    return allowed

# Integration example for the AIBOM generator
def secure_aibom_generation(model_id: str, output_file: Optional[str] = None) -> Dict[str, Any]:
    """
    Demonstrate how the security helpers fit into AIBOM generation.

    The model ID is validated, a stub AIBOM payload is built, and, when an
    output path is given, the payload is written to it as indented JSON.

    Args:
        model_id: Model ID to generate AIBOM for
        output_file: Optional output file path

    Returns:
        Generated AIBOM data, or a dict with an "error" key if validation
        fails or an unexpected exception occurs
    """
    try:
        # Reject malformed IDs before doing any work.
        validated_model_id = validate_model_id(model_id)

        # Stand-in for the real AIBOM generation logic.
        aibom_data = {"message": f"AIBOM for {validated_model_id}"}

        # Nothing to persist: hand the payload straight back.
        if not output_file:
            return aibom_data

        # Split into directory and leaf name, then rebuild through the helper.
        parent_dir, leaf_name = os.path.split(output_file)
        safe_output_path = safe_path_join(parent_dir, leaf_name)
        serialized = json.dumps(aibom_data, indent=2)
        secure_file_operations(safe_output_path, 'write', serialized)

        return aibom_data

    except ValueError as e:
        # Validation problems get a generic message; details go to the log.
        logger.error(f"Validation error: {e}")
        return {"error": "Invalid input parameters"}

    except Exception as e:
        # Anything else is reported as an internal error.
        logger.error(f"AIBOM generation failed: {e}")
        return {"error": "An internal error occurred"}