"""
Configuration settings for AI Web Scraper
Centralized configuration management for security, performance, and features
"""
import os
from typing import Dict, List, Optional
from dataclasses import dataclass
@dataclass
class SecurityConfig:
"""Security-related configuration"""
# URL validation settings
    allowed_schemes: Optional[List[str]] = None
    blocked_domains: Optional[List[str]] = None
max_url_length: int = 2048
# Rate limiting
requests_per_minute: int = 30
requests_per_hour: int = 500
# Content safety
max_content_size: int = 10 * 1024 * 1024 # 10MB
max_processing_time: int = 60 # seconds
def __post_init__(self):
if self.allowed_schemes is None:
self.allowed_schemes = ['http', 'https']
if self.blocked_domains is None:
self.blocked_domains = [
'localhost', '127.0.0.1', '0.0.0.0',
'192.168.', '10.', '172.16.', '172.17.',
'172.18.', '172.19.', '172.20.', '172.21.',
'172.22.', '172.23.', '172.24.', '172.25.',
'172.26.', '172.27.', '172.28.', '172.29.',
'172.30.', '172.31.'
]
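        # Note: entries ending in '.' are IP-prefix blocks (e.g. '10.' covers
        # the 10.0.0.0/8 private range); Config.is_url_allowed matches them
        # by prefix, and bare hostnames like 'localhost' by exact match.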
@dataclass
class ModelConfig:
"""AI model configuration"""
# Primary summarization model
primary_model: str = "facebook/bart-large-cnn"
# Fallback model for faster processing
fallback_model: str = "sshleifer/distilbart-cnn-12-6"
# Model parameters
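    # Note: 1024 tokens is the maximum input size supported by the BART
    # checkpoints above; longer inputs must be truncated or chunked upstream.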
max_input_length: int = 1024
max_summary_length: int = 500
min_summary_length: int = 30
# Performance settings
device: str = "auto" # auto, cpu, cuda
batch_size: int = 1
use_fast_tokenizer: bool = True
@dataclass
class ScrapingConfig:
"""Web scraping configuration"""
# Request settings
timeout: int = 15
max_retries: int = 3
retry_delay: int = 1
# User agent string
user_agent: str = "Mozilla/5.0 (compatible; AI-WebScraper/1.0; Research Tool)"
# Content extraction
min_content_length: int = 100
max_content_length: int = 100000
# Robots.txt settings
respect_robots_txt: bool = True
robots_cache_duration: int = 3600 # seconds
@dataclass
class UIConfig:
"""User interface configuration"""
# Default values
default_summary_length: int = 300
max_summary_length: int = 500
min_summary_length: int = 100
# Interface settings
enable_batch_processing: bool = True
max_batch_size: int = 10
show_advanced_options: bool = False
# Export settings
    supported_export_formats: Optional[List[str]] = None
def __post_init__(self):
if self.supported_export_formats is None:
self.supported_export_formats = ["CSV", "JSON"]
class Config:
"""Main configuration class"""
def __init__(self):
self.security = SecurityConfig()
self.models = ModelConfig()
self.scraping = ScrapingConfig()
self.ui = UIConfig()
# Load from environment variables if available
self._load_from_env()
def _load_from_env(self):
"""Load configuration from environment variables"""
# Security settings
if os.getenv('MAX_REQUESTS_PER_MINUTE'):
self.security.requests_per_minute = int(os.getenv('MAX_REQUESTS_PER_MINUTE'))
if os.getenv('MAX_CONTENT_SIZE'):
self.security.max_content_size = int(os.getenv('MAX_CONTENT_SIZE'))
# Model settings
if os.getenv('PRIMARY_MODEL'):
self.models.primary_model = os.getenv('PRIMARY_MODEL')
if os.getenv('FALLBACK_MODEL'):
self.models.fallback_model = os.getenv('FALLBACK_MODEL')
if os.getenv('DEVICE'):
self.models.device = os.getenv('DEVICE')
# Scraping settings
if os.getenv('REQUEST_TIMEOUT'):
self.scraping.timeout = int(os.getenv('REQUEST_TIMEOUT'))
if os.getenv('USER_AGENT'):
self.scraping.user_agent = os.getenv('USER_AGENT')
if os.getenv('RESPECT_ROBOTS_TXT'):
self.scraping.respect_robots_txt = os.getenv('RESPECT_ROBOTS_TXT').lower() == 'true'
def get_model_device(self) -> str:
"""Get the appropriate device for model inference"""
if self.models.device == "auto":
try:
import torch
return "cuda" if torch.cuda.is_available() else "cpu"
except ImportError:
return "cpu"
return self.models.device
def is_url_allowed(self, url: str) -> bool:
"""Check if URL is allowed based on security settings"""
from urllib.parse import urlparse
try:
parsed = urlparse(url)
# Check scheme
if parsed.scheme not in self.security.allowed_schemes:
return False
            # Check blocked domains: exact match for hostnames, prefix match
            # for IP-range entries such as '10.' (plain substring matching
            # would wrongly block hosts like 'notlocalhost.example.com')
            hostname = parsed.hostname or ''
            for blocked in self.security.blocked_domains:
                if hostname == blocked or hostname.startswith(blocked):
                    return False
# Check URL length
if len(url) > self.security.max_url_length:
return False
return True
except Exception:
return False
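    # Illustrative results with the default SecurityConfig:
    #   is_url_allowed("https://example.com/page")    -> True
    #   is_url_allowed("http://127.0.0.1:8080/admin") -> False (blocked host)
    #   is_url_allowed("ftp://example.com/file")      -> False (scheme not allowed)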
def get_request_headers(self) -> Dict[str, str]:
"""Get standard request headers"""
return {
'User-Agent': self.scraping.user_agent,
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
'Accept-Language': 'en-US,en;q=0.5',
'Accept-Encoding': 'gzip, deflate',
'Connection': 'keep-alive',
'Upgrade-Insecure-Requests': '1',
}
# Global configuration instance
config = Config()
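# Example usage (an illustrative sketch; assumes `requests` is available,
# as the scraper itself depends on it):
#
#   import requests
#   from config import config
#
#   if config.is_url_allowed(url):
#       response = requests.get(
#           url,
#           headers=config.get_request_headers(),
#           timeout=config.scraping.timeout,
#       )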
# Environment-specific overrides for Hugging Face Spaces
if os.getenv('SPACE_ID'):
# Running on Hugging Face Spaces
config.models.device = "auto"
config.security.requests_per_minute = 20 # More conservative on shared infrastructure
config.scraping.timeout = 10 # Shorter timeout on shared infrastructure
# Enable GPU if available
if os.getenv('CUDA_VISIBLE_DEVICES'):
config.models.device = "cuda"
# Development mode overrides
if os.getenv('ENVIRONMENT') == 'development':
config.security.requests_per_minute = 100
config.scraping.timeout = 30
config.ui.show_advanced_options = True
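
# Minimal self-check when this module is run directly (a sketch; it only
# prints the effective settings so environment overrides are easy to verify):
if __name__ == "__main__":
    print(f"Primary model:    {config.models.primary_model}")
    print(f"Inference device: {config.get_model_device()}")
    print(f"Requests/minute:  {config.security.requests_per_minute}")
    print(f"Request timeout:  {config.scraping.timeout}s")
    for url in ("https://example.com", "http://127.0.0.1:8080/admin"):
        print(f"is_url_allowed({url!r}) -> {config.is_url_allowed(url)}")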