File size: 2,922 Bytes
acb544e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import os
import re
from typing import Any, Dict, List, Literal, Optional
import structlog
import yaml
from pydantic import BaseModel, Field
# Module-level structured logger.
LOGGER = structlog.getLogger(__name__)
# Finds each ``${...}`` reference in a string; group 1 is the text between the
# braces (e.g. ``NAME`` or ``NAME:default``). NOTE: the class ``[^}^{]`` also
# excludes a literal ``^``, so ``${A^B}`` is not treated as a reference.
_var_matcher = re.compile(r"\${([^}^{]+)}")
# Registered as the ``!envvar`` implicit resolver in ``load_yaml``: recognises a
# scalar that contains a ``${...}`` reference. PyYAML applies resolver regexes
# with ``re.match``, so the leading ``[^$]*`` means no other ``$`` may appear
# before the first ``${``.
_tag_matcher = re.compile(r"[^$]*\${([^}^{]+)}.*")
class RateLimitConfig(BaseModel):
    """Request rate-limiting settings.

    Rate limiting is disabled unless ``enabled`` is set. ``limit`` holds the
    limit as a string (default ``"100/minute"``); its exact grammar is defined
    by whatever consumes it, not by this model.
    """

    enabled: bool = False
    limit: str = "100/minute"
class CacheConfig(BaseModel):
    """Cache settings: an entry TTL and an optional maximum size."""

    # Time-to-live for cache entries (units presumably seconds — confirm at use site).
    ttl: int = 60
    # Left as None unless a maximum size is configured.
    max_size: Optional[int] = None
class AuthConfig(BaseModel):
    """Authentication settings.

    ``type`` is required and selects the scheme. All credential fields are
    optional at the model level. Presumably ``token`` goes with
    ``http_bearer`` and ``username``/``password`` with ``http_basic``, but this
    model does not enforce that pairing.
    """

    # Required: no default is given, so validation fails if it is missing.
    type: Literal["http_bearer", "http_basic"]
    token: Optional[str] = None
    username: Optional[str] = None
    password: Optional[str] = None
class TracingConfig(BaseModel):
    """Tracing exporter settings.

    ``exporter`` defaults to ``"console"``. ``endpoint`` is optional and is
    presumably only meaningful for the ``"otel_http"`` exporter.
    """

    exporter: Literal["otel_http", "console"] = "console"
    endpoint: Optional[str] = None
class MetricsConfig(BaseModel):
    """Metrics exporter settings.

    ``exporter`` defaults to ``"console"``. ``endpoint`` is optional and is
    presumably only meaningful for the ``"otel_http"`` exporter.
    """

    exporter: Literal["otel_http", "prometheus", "console"] = "console"
    endpoint: Optional[str] = None
class AppConfig(BaseModel):
    """General application settings; every field has a default."""

    name: Optional[str] = "LLM Guard API"
    port: Optional[int] = 7860
    log_level: Optional[str] = "INFO"
    # Presumably stops scanning at the first failing scanner — confirm at use site.
    scan_fail_fast: Optional[bool] = False
    # Scan time limits; units presumably seconds — confirm at use site.
    scan_prompt_timeout: Optional[int] = 10
    scan_output_timeout: Optional[int] = 30
class ScannerConfig(BaseModel):
    """One scanner entry: a required ``type`` plus free-form ``params``.

    ``params`` defaults to a fresh empty dict per instance. Presumably it holds
    keyword arguments for the scanner named by ``type`` — verify against caller.
    """

    type: str
    # default_factory gives each instance its own dict instead of a shared one.
    params: Optional[Dict] = Field(default_factory=dict)
class Config(BaseModel):
    """Root configuration model.

    Both scanner lists are required. ``rate_limit``, ``cache`` and ``app`` are
    built from their own model defaults when omitted, while ``auth``,
    ``tracing`` and ``metrics`` stay ``None`` when omitted.
    """

    # Required: no default, validation fails if either list is missing.
    input_scanners: List[ScannerConfig]
    output_scanners: List[ScannerConfig]

    # Optional sections (field order preserved).
    rate_limit: RateLimitConfig = Field(default_factory=RateLimitConfig)
    cache: CacheConfig = Field(default_factory=CacheConfig)
    auth: Optional[AuthConfig] = None
    app: AppConfig = Field(default_factory=AppConfig)
    tracing: Optional[TracingConfig] = None
    metrics: Optional[MetricsConfig] = None
def _path_constructor(_loader: Any, node: Any):
    """YAML constructor for the ``!envvar`` tag: expand environment references.

    Every ``${NAME}`` or ``${NAME:default}`` reference in the scalar text is
    replaced by the value of environment variable ``NAME``. When ``NAME`` is
    unset, everything after the FIRST ``:`` is used as the fallback, so
    defaults may themselves contain colons (URLs, ``host:port``). Without a
    default, an unset variable expands to the empty string.

    Args:
        _loader: The YAML loader invoking the constructor (unused).
        node: The scalar node being constructed; only ``node.value`` is read.

    Returns:
        The scalar text with every reference substituted.
    """

    def replace_fn(match):
        # partition() splits on the first colon only; the previous
        # ``split(":")[1]`` truncated defaults such as
        # ``${OTEL_ENDPOINT:http://localhost:4318}`` to ``"http"``.
        name, _, default = match.group(1).partition(":")
        return os.environ.get(name, default)

    return _var_matcher.sub(replace_fn, node.value)
def load_yaml(filename: str) -> dict:
    """Load a YAML file, expanding ``${VAR}`` / ``${VAR:default}`` references.

    Scalars matching ``_tag_matcher`` are implicitly tagged ``!envvar`` and
    built by ``_path_constructor``. The resolver and constructor are
    registered on the shared ``yaml.SafeLoader`` class, so they apply to every
    ``yaml.safe_load`` call in the process. NOTE: PyYAML only consults
    implicit resolvers for plain (unquoted) scalars, so quoted values are not
    expanded.

    Args:
        filename: Path of the YAML file to read (decoded as UTF-8).

    Returns:
        The parsed top-level mapping. An empty dict is returned, and the
        problem logged where applicable, when the file cannot be opened,
        decoded or parsed, when it is empty, or when its top level is not a
        mapping.
    """
    # Register once: add_implicit_resolver appends to a list on every call, so
    # unguarded registration piles up duplicate resolvers on the global loader.
    if "!envvar" not in yaml.SafeLoader.yaml_constructors:
        yaml.add_implicit_resolver("!envvar", _tag_matcher, None, yaml.SafeLoader)
        yaml.add_constructor("!envvar", _path_constructor, yaml.SafeLoader)

    try:
        with open(filename, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
    except (OSError, UnicodeDecodeError, yaml.YAMLError) as exc:
        # OSError covers FileNotFoundError, PermissionError and IsADirectoryError.
        LOGGER.error("Error loading YAML file", exception=exc, file_name=filename)
        return dict()

    if data is None:
        # yaml.safe_load returns None for an empty (or comment-only) document.
        return dict()
    if not isinstance(data, dict):
        LOGGER.error(
            "YAML file must contain a mapping at the top level",
            file_name=filename,
        )
        return dict()
    return data
def get_config(file_name: str) -> Optional[Config]:
    """Load and validate the application configuration from a YAML file.

    Args:
        file_name: Path to the YAML configuration file.

    Returns:
        A validated ``Config``, or ``None`` when ``load_yaml`` yields nothing
        usable: an empty/falsy result (load error or empty document) or a
        value that is not a mapping.

    Raises:
        pydantic's ``ValidationError`` from ``Config(**conf)`` when the mapping
        does not satisfy the model (e.g. a missing ``input_scanners``).
    """
    LOGGER.debug("Loading config file", file_name=file_name)
    conf = load_yaml(file_name)
    # ``not conf`` covers ``{}`` and also ``None`` (what yaml.safe_load returns
    # for an empty document); the old ``conf == {}`` let None reach
    # ``Config(**None)`` and raise TypeError.
    if not conf:
        return None
    if not isinstance(conf, dict):
        LOGGER.error(
            "Config file must contain a mapping at the top level",
            file_name=file_name,
        )
        return None
    return Config(**conf)
|