import logging

from utils.registry import MODEL_REGISTRY

logger = logging.getLogger(__name__)


class ContextualWeightOverrideAgent:
    """Supplies per-model weight multipliers for known deployment contexts."""

    def __init__(self):
        logger.info("Initializing ContextualWeightOverrideAgent.")
        # Multipliers above 1.0 boost a model in the given context; below 1.0 damp it.
        self.context_overrides = {
            "outdoor": {
                "model_1": 0.8,
                "model_5": 1.2,
            },
            "low_light": {
                "model_2": 0.7,
                "model_7": 1.3,
            },
            "sunny": {
                "model_3": 0.9,
                "model_4": 1.1,
            },
        }

    def get_overrides(self, context_tags: list[str]) -> dict[str, float]:
        """Combine the override tables of all recognized tags into one mapping."""
        logger.info(f"Getting weight overrides for context tags: {context_tags}")
        combined_overrides = {}
        for tag in context_tags:
            if tag in self.context_overrides:
                for model_id, multiplier in self.context_overrides[tag].items():
                    # Multipliers from multiple active tags compound multiplicatively.
                    combined_overrides[model_id] = combined_overrides.get(model_id, 1.0) * multiplier
        logger.info(f"Combined context overrides: {combined_overrides}")
        return combined_overrides


class ModelWeightManager:
    """Maintains ensemble base weights and adjusts them for each prediction round."""

    def __init__(self, strongest_model_id: str | None = None):
        logger.info(f"Initializing ModelWeightManager with strongest_model_id: {strongest_model_id}")

        num_models = len(MODEL_REGISTRY)
        if num_models > 0:
            if strongest_model_id and strongest_model_id in MODEL_REGISTRY:
                logger.info(f"Designating '{strongest_model_id}' as the strongest model.")
                # The strongest model takes half of the total weight; the remainder
                # is split evenly among the other registered models.
                strongest_weight_share = 0.5
                self.base_weights = {strongest_model_id: strongest_weight_share}
                remaining_models = [mid for mid in MODEL_REGISTRY if mid != strongest_model_id]
                if remaining_models:
                    other_models_weight_share = (1.0 - strongest_weight_share) / len(remaining_models)
                    for model_id in remaining_models:
                        self.base_weights[model_id] = other_models_weight_share
                else:
                    # The strongest model is the only registered model.
                    self.base_weights[strongest_model_id] = 1.0
            else:
                if strongest_model_id and strongest_model_id not in MODEL_REGISTRY:
                    logger.warning(f"Strongest model ID '{strongest_model_id}' not found in MODEL_REGISTRY. Distributing weights equally.")
                initial_weight = 1.0 / num_models
                self.base_weights = {model_id: initial_weight for model_id in MODEL_REGISTRY}
        else:
            self.base_weights = {}
        logger.info(f"Base weights initialized: {self.base_weights}")

        # Situation multipliers applied on top of the base weights each round.
        self.situation_weights = {
            "high_confidence": 1.2,
            "low_confidence": 0.8,
            "conflict": 0.5,
            "consensus": 1.5,
        }
        self.context_override_agent = ContextualWeightOverrideAgent()
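
    # Sketch of the split above under an assumed four-model registry with
    # "model_1" designated strongest: model_1 gets 0.5 and each of the other
    # three models gets (1.0 - 0.5) / 3 ≈ 0.1667, so the base weights sum to 1.0.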

    def adjust_weights(self, predictions, confidence_scores, context_tags: list[str] | None = None):
        """Dynamically adjust weights based on prediction patterns and optional context."""
        logger.info("Adjusting model weights.")
        adjusted_weights = self.base_weights.copy()
        logger.info(f"Initial adjusted weights (copy of base): {adjusted_weights}")

        # 1. Apply contextual overrides first so the situational multipliers
        #    below operate on context-corrected weights.
        if context_tags:
            logger.info(f"Applying contextual overrides for tags: {context_tags}")
            overrides = self.context_override_agent.get_overrides(context_tags)
            for model_id, multiplier in overrides.items():
                # Models absent from the base weights default to 0.0 and stay there.
                adjusted_weights[model_id] = adjusted_weights.get(model_id, 0.0) * multiplier
            logger.info(f"Adjusted weights after context overrides: {adjusted_weights}")

        # 2. Boost every weight when the models agree on a single label.
        if self._has_consensus(predictions):
            logger.info("Consensus detected. Boosting weights for consensus.")
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["consensus"]
            logger.info(f"Adjusted weights after consensus boost: {adjusted_weights}")

        # 3. Dampen every weight when the models disagree.
        if self._has_conflicts(predictions):
            logger.info("Conflicts detected. Reducing weights for conflict.")
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["conflict"]
            logger.info(f"Adjusted weights after conflict reduction: {adjusted_weights}")

        # 4. Reward high-confidence models and penalize low-confidence ones.
        logger.info("Adjusting weights based on model confidence scores.")
        for model, confidence in confidence_scores.items():
            if model not in adjusted_weights:
                # Ignore confidence scores for models without a weight entry.
                continue
            if confidence > 0.8:
                adjusted_weights[model] *= self.situation_weights["high_confidence"]
                logger.info(f"Model '{model}' has high confidence ({confidence:.2f}). Weight boosted.")
            elif confidence < 0.5:
                adjusted_weights[model] *= self.situation_weights["low_confidence"]
                logger.info(f"Model '{model}' has low confidence ({confidence:.2f}). Weight reduced.")
        logger.info(f"Adjusted weights before normalization: {adjusted_weights}")

        normalized_weights = self._normalize_weights(adjusted_weights)
        logger.info(f"Final normalized adjusted weights: {normalized_weights}")
        return normalized_weights
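
    # Worked trace under assumed inputs: a model at base weight 0.25 with a
    # context multiplier of 1.2, a consensus boost (x1.5), and confidence 0.9
    # (x1.2) reaches 0.25 * 1.2 * 1.5 * 1.2 = 0.54 before normalization. A
    # multiplier applied uniformly to every model (consensus/conflict) cancels
    # once _normalize_weights rescales the weights to sum to 1.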

    def _has_consensus(self, predictions):
        """Check if all models that produced a usable label agree on it."""
        logger.info("Checking for consensus among model predictions.")
        non_none_predictions = [
            p.get("Label")
            for p in predictions.values()
            if isinstance(p, dict) and p.get("Label") not in (None, "Error")
        ]
        logger.debug(f"Non-none predictions for consensus check: {non_none_predictions}")
        result = len(non_none_predictions) > 0 and len(set(non_none_predictions)) == 1
        logger.info(f"Consensus detected: {result}")
        return result

    def _has_conflicts(self, predictions):
        """Check if models that produced usable labels disagree."""
        logger.info("Checking for conflicts among model predictions.")
        non_none_predictions = [
            p.get("Label")
            for p in predictions.values()
            if isinstance(p, dict) and p.get("Label") not in (None, "Error")
        ]
        logger.debug(f"Non-none predictions for conflict check: {non_none_predictions}")
        result = len(non_none_predictions) > 1 and len(set(non_none_predictions)) > 1
        logger.info(f"Conflicts detected: {result}")
        return result
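
    # Behaviour of the two detectors on illustrative inputs (labels assumed):
    #   {"m1": {"Label": "real"}, "m2": {"Label": "real"}}  -> consensus, no conflict
    #   {"m1": {"Label": "real"}, "m2": {"Label": "fake"}}  -> conflict, no consensus
    #   None entries and "Error" labels are excluded from both checks.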

    def _normalize_weights(self, weights):
        """Normalize weights so they sum to 1."""
        logger.info("Normalizing weights.")
        total = sum(weights.values())
        if total == 0:
            logger.warning("All weights became zero after adjustments. Reverting to equal base weights for registered models.")
            # Fall back to a uniform distribution over the registered models.
            num_registered_models = len(MODEL_REGISTRY)
            if num_registered_models > 0:
                return {k: 1.0 / num_registered_models for k in MODEL_REGISTRY}
            return {}
        normalized = {k: v / total for k, v in weights.items()}
        logger.info(f"Weights normalized. Total sum: {sum(normalized.values()):.2f}")
        return normalized
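

# Minimal usage sketch, assuming MODEL_REGISTRY contains models registered as
# "model_1" and "model_2" (the registry's real contents live in utils.registry)
# and that prediction dicts carry the "Label" key the helpers above expect.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    manager = ModelWeightManager(strongest_model_id="model_1")
    predictions = {
        "model_1": {"Label": "real"},
        "model_2": {"Label": "fake"},
    }
    confidence_scores = {"model_1": 0.9, "model_2": 0.4}
    weights = manager.adjust_weights(predictions, confidence_scores, context_tags=["outdoor"])
    print(weights)  # Normalized weights summing to 1.0 (or {} if the registry is empty).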