# mcp-deepfake-forensics / agents / ensemble_weights.py
# LPX55's picture
# feat: integrate ONNX model inference and logging enhancements, add contextual intelligence and forensic anomaly detection agents
# e1eac06
import logging
import torch
from utils.registry import MODEL_REGISTRY # Import MODEL_REGISTRY
logger = logging.getLogger(__name__)
class ContextualWeightOverrideAgent:
    """Translate scene-context tags into per-model weight multipliers.

    ``context_overrides`` maps a context tag (e.g. ``"outdoor"``) to a dict of
    ``{model_id: multiplier}``. A multiplier below 1.0 penalizes the model for
    that context, a multiplier above 1.0 boosts it.
    """

    def __init__(self):
        logger.info("Initializing ContextualWeightOverrideAgent.")
        self.context_overrides = {
            # Example: when image is outdoor, model_X is penalized, model_Y is boosted
            "outdoor": {
                "model_1": 0.8,  # Example: reduce weight of model_1 by 20% for outdoor scenes
                "model_5": 1.2,  # Example: boost weight of model_5 by 20% for outdoor scenes
            },
            "low_light": {
                "model_2": 0.7,
                "model_7": 1.3,
            },
            "sunny": {
                "model_3": 0.9,
                "model_4": 1.1,
            },
            # Add more contexts and their specific model weight adjustments here
        }

    def get_overrides(self, context_tags: list[str]) -> dict:
        """Return the combined multiplier per model for the given context tags.

        Args:
            context_tags: Context tags describing the input. ``None`` or an
                empty list yields no overrides; a single string is treated as
                one tag. Tags without an entry in ``context_overrides`` are
                ignored.

        Returns:
            A new ``{model_id: multiplier}`` dict. When a model is listed under
            several of the given tags, its multipliers are multiplied together.
        """
        logger.info("Getting weight overrides for context tags: %s", context_tags)

        # A bare string would otherwise be iterated character by character and
        # silently match nothing.
        if isinstance(context_tags, str):
            context_tags = [context_tags]

        # De-duplicate while preserving order: a tag that is reported twice
        # describes the same context and must not compound its multiplier.
        unique_tags = dict.fromkeys(context_tags or ())

        combined_overrides = {}
        for tag in unique_tags:
            for model_id, multiplier in self.context_overrides.get(tag, {}).items():
                # Multipliers from different contexts accumulate multiplicatively.
                combined_overrides[model_id] = combined_overrides.get(model_id, 1.0) * multiplier

        logger.info("Combined context overrides: %s", combined_overrides)
        return combined_overrides
class ModelWeightManager:
    """Hold base ensemble weights per registered model and adjust them per call.

    Base weights are derived from ``MODEL_REGISTRY`` at construction time.
    ``adjust_weights`` then applies context overrides, consensus/conflict
    factors and per-model confidence factors, and returns weights normalized
    to sum to 1.
    """

    # Share of the total weight given to the designated strongest model.
    STRONGEST_WEIGHT_SHARE = 0.5
    # Confidence strictly above this value gets the "high_confidence" factor.
    HIGH_CONFIDENCE_THRESHOLD = 0.8
    # Confidence strictly below this value gets the "low_confidence" factor.
    LOW_CONFIDENCE_THRESHOLD = 0.5

    def __init__(self, strongest_model_id: str = None):
        """Initialize base weights from ``MODEL_REGISTRY``.

        Args:
            strongest_model_id: Optional id of a registered model that should
                receive ``STRONGEST_WEIGHT_SHARE`` of the weight, with the rest
                split equally among the other models. If it is missing or not
                registered, all models are weighted equally.
        """
        logger.info("Initializing ModelWeightManager with strongest_model_id: %s", strongest_model_id)

        # Dynamically initialize base_weights from MODEL_REGISTRY
        model_ids = list(MODEL_REGISTRY.keys())
        if not model_ids:
            self.base_weights = {}  # Handle case with no registered models
        elif strongest_model_id and strongest_model_id in MODEL_REGISTRY:
            logger.info("Designating '%s' as the strongest model.", strongest_model_id)
            remaining_models = [mid for mid in model_ids if mid != strongest_model_id]
            if remaining_models:
                share = self.STRONGEST_WEIGHT_SHARE
                other_share = (1.0 - share) / len(remaining_models)
                self.base_weights = {strongest_model_id: share}
                self.base_weights.update(dict.fromkeys(remaining_models, other_share))
            else:
                # Only one model, which is the strongest
                self.base_weights = {strongest_model_id: 1.0}
        else:
            if strongest_model_id:
                logger.warning(
                    "Strongest model ID '%s' not found in MODEL_REGISTRY. Distributing weights equally.",
                    strongest_model_id,
                )
            self.base_weights = dict.fromkeys(model_ids, 1.0 / len(model_ids))
        logger.info("Base weights initialized: %s", self.base_weights)

        self.situation_weights = {
            "high_confidence": 1.2,  # Boost weights for high confidence predictions
            "low_confidence": 0.8,  # Reduce weights for low confidence
            "conflict": 0.5,  # Reduce weights when models disagree
            "consensus": 1.5,  # Boost weights when models agree
        }
        self.context_override_agent = ContextualWeightOverrideAgent()

    def adjust_weights(self, predictions, confidence_scores, context_tags: list[str] = None):
        """Dynamically adjust weights based on prediction patterns and optional context.

        Args:
            predictions: Mapping of model id to a prediction dict carrying a
                ``"Label"`` key. Entries that are not dicts, or whose label is
                ``None`` or ``"Error"``, are ignored. ``None`` is treated as
                empty.
            confidence_scores: Mapping of model id to a numeric confidence.
                Entries for models without a base weight, or with a ``None``
                confidence, are skipped. ``None`` is treated as empty.
            context_tags: Optional context tags forwarded to the context
                override agent.

        Returns:
            A dict of model id to weight, normalized to sum to 1 (see
            ``_normalize_weights`` for the all-zero fallback).
        """
        logger.info("Adjusting model weights.")
        adjusted_weights = self.base_weights.copy()
        logger.info("Initial adjusted weights (copy of base): %s", adjusted_weights)

        # 1. Apply contextual overrides first
        if context_tags:
            logger.info("Applying contextual overrides for tags: %s", context_tags)
            overrides = self.context_override_agent.get_overrides(context_tags)
            for model_id, multiplier in overrides.items():
                if model_id not in adjusted_weights:
                    # Do not inject zero-weight entries for models that have no
                    # base weight; they would leak into the returned dict.
                    logger.debug("Skipping context override for unknown model '%s'.", model_id)
                    continue
                adjusted_weights[model_id] *= multiplier
            logger.info("Adjusted weights after context overrides: %s", adjusted_weights)

        # 2. Apply situation-based adjustments (consensus, conflict, confidence)
        # NOTE(review): the consensus/conflict factors scale every model by the
        # same amount, so they cancel out in the final normalization. Kept for
        # behavioural parity; confirm whether per-model scaling was intended.
        if self._has_consensus(predictions):
            logger.info("Consensus detected. Boosting weights for consensus.")
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["consensus"]
            logger.info("Adjusted weights after consensus boost: %s", adjusted_weights)

        if self._has_conflicts(predictions):
            logger.info("Conflicts detected. Reducing weights for conflict.")
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["conflict"]
            logger.info("Adjusted weights after conflict reduction: %s", adjusted_weights)

        # Adjust based on confidence
        logger.info("Adjusting weights based on model confidence scores.")
        for model, confidence in (confidence_scores or {}).items():
            if model not in adjusted_weights:
                # A score for a model without a base weight used to raise KeyError.
                logger.warning("Ignoring confidence score for unknown model '%s'.", model)
                continue
            if confidence is None:
                # No usable score (e.g. the model failed); leave its weight as is.
                continue
            if confidence > self.HIGH_CONFIDENCE_THRESHOLD:
                adjusted_weights[model] *= self.situation_weights["high_confidence"]
                logger.info("Model '%s' has high confidence (%.2f). Weight boosted.", model, confidence)
            elif confidence < self.LOW_CONFIDENCE_THRESHOLD:
                adjusted_weights[model] *= self.situation_weights["low_confidence"]
                logger.info("Model '%s' has low confidence (%.2f). Weight reduced.", model, confidence)

        logger.info("Adjusted weights before normalization: %s", adjusted_weights)
        normalized_weights = self._normalize_weights(adjusted_weights)
        logger.info("Final normalized adjusted weights: %s", normalized_weights)
        return normalized_weights

    @staticmethod
    def _valid_labels(predictions):
        """Return the usable ``"Label"`` values from a predictions mapping.

        Non-dict entries and labels that are ``None`` or ``"Error"`` are dropped.
        """
        if not predictions:
            return []
        labels = (p.get("Label") for p in predictions.values() if isinstance(p, dict))
        return [label for label in labels if label is not None and label != "Error"]

    def _has_consensus(self, predictions):
        """Check if models agree on prediction (at least one valid label, all equal)."""
        logger.info("Checking for consensus among model predictions.")
        labels = self._valid_labels(predictions)
        logger.debug("Non-none predictions for consensus check: %s", labels)
        # Exactly one distinct label also implies the list is non-empty.
        result = len(set(labels)) == 1
        logger.info("Consensus detected: %s", result)
        return result

    def _has_conflicts(self, predictions):
        """Check if models have conflicting predictions (two or more distinct labels)."""
        logger.info("Checking for conflicts among model predictions.")
        labels = self._valid_labels(predictions)
        logger.debug("Non-none predictions for conflict check: %s", labels)
        result = len(set(labels)) > 1
        logger.info("Conflicts detected: %s", result)
        return result

    def _normalize_weights(self, weights):
        """Normalize weights to sum to 1.

        If the total is not positive, fall back to equal weights over the
        models in ``MODEL_REGISTRY`` (or an empty dict when none are
        registered).
        """
        logger.info("Normalizing weights.")
        total = sum(weights.values())
        # A non-positive total cannot be normalized meaningfully (division by
        # zero, or sign-flipped weights for a negative total).
        if total <= 0:
            logger.warning("All weights became zero after adjustments. Reverting to equal base weights for registered models.")
            # Revert to equal weights for all *registered* models if total becomes zero
            registered_ids = list(MODEL_REGISTRY.keys())
            if registered_ids:
                return dict.fromkeys(registered_ids, 1.0 / len(registered_ids))
            return {}  # No models registered
        normalized = {k: v / total for k, v in weights.items()}
        logger.info("Weights normalized. Total sum: %.2f", sum(normalized.values()))
        return normalized