File size: 8,587 Bytes
8f7f87a 1146644 8f7f87a e1eac06 8f7f87a 39558cb 8f7f87a e1eac06 8f7f87a e1eac06 8f7f87a 1146644 e1eac06 1146644 e1eac06 1146644 e1eac06 1146644 e1eac06 1146644 8f7f87a 39558cb 8f7f87a e1eac06 8f7f87a e1eac06 8f7f87a e1eac06 8f7f87a e1eac06 8f7f87a e1eac06 8f7f87a e1eac06 8f7f87a e1eac06 8f7f87a e1eac06 8f7f87a e1eac06 8f7f87a e1eac06 8f7f87a e1eac06 8f7f87a e1eac06 8f7f87a e1eac06 73783f9 e1eac06 8f7f87a e1eac06 73783f9 e1eac06 8f7f87a e1eac06 8f7f87a e1eac06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import logging
import torch
from utils.registry import MODEL_REGISTRY # Import MODEL_REGISTRY
logger = logging.getLogger(__name__)


class ContextualWeightOverrideAgent:
    """Map context tags (e.g. "outdoor", "low_light") to per-model weight multipliers.

    Each known tag maps to a ``{model_id: multiplier}`` table. A multiplier
    below 1.0 penalises a model for that context; above 1.0 boosts it.
    """

    def __init__(self, context_overrides=None):
        """Create the agent.

        Args:
            context_overrides: Optional ``{context_tag: {model_id: multiplier}}``
                table replacing the built-in example table. It is copied one
                level deep, so later mutation by the caller does not leak in.
        """
        logger.info("Initializing ContextualWeightOverrideAgent.")
        if context_overrides is None:
            context_overrides = {
                # Example: when image is outdoor, model_1 is penalized, model_5 is boosted
                "outdoor": {
                    "model_1": 0.8,  # Example: Reduce weight of model_1 by 20% for outdoor scenes
                    "model_5": 1.2,  # Example: Boost weight of model_5 by 20% for outdoor scenes
                },
                "low_light": {
                    "model_2": 0.7,
                    "model_7": 1.3,
                },
                "sunny": {
                    "model_3": 0.9,
                    "model_4": 1.1,
                },
                # Add more contexts and their specific model weight adjustments here
            }
        # Copy one level deep so each agent owns its per-tag tables.
        self.context_overrides = {
            tag: dict(table) for tag, table in context_overrides.items()
        }

    def get_overrides(self, context_tags: list[str]) -> dict:
        """Combine the multipliers of every recognised tag in *context_tags*.

        When a model appears under several distinct tags, its multipliers are
        multiplied together (cumulative effect). A tag repeated in the input
        is counted once, so it does not compound its own multiplier. Unknown
        tags are ignored.

        Args:
            context_tags: Context tags describing the input. ``None`` or an
                empty list yields no overrides.

        Returns:
            ``{model_id: combined_multiplier}`` for every affected model.
        """
        logger.info("Getting weight overrides for context tags: %s", context_tags)
        combined_overrides = {}
        # dict.fromkeys de-duplicates while preserving first-seen order.
        for tag in dict.fromkeys(context_tags or ()):
            for model_id, multiplier in self.context_overrides.get(tag, {}).items():
                combined_overrides[model_id] = combined_overrides.get(model_id, 1.0) * multiplier
        logger.info("Combined context overrides: %s", combined_overrides)
        return combined_overrides
class ModelWeightManager:
    """Produce normalized per-model ensemble weights.

    Base weights are derived from ``MODEL_REGISTRY`` at construction time.
    ``adjust_weights`` applies three steps in order:

    1. Contextual multipliers from a ``ContextualWeightOverrideAgent``.
    2. Situation multipliers (consensus, conflict, per-model confidence).
    3. A final normalization so the weights sum to 1.
    """

    # Per-model confidence strictly above this boosts the model's weight...
    HIGH_CONFIDENCE_THRESHOLD = 0.8
    # ...and strictly below this reduces it; values in between are untouched.
    LOW_CONFIDENCE_THRESHOLD = 0.5

    def __init__(self, strongest_model_id: str = None, strongest_weight_share: float = 0.5):
        """Initialize base weights, situation multipliers and the context agent.

        Args:
            strongest_model_id: Registered model to favour. If it is not in
                ``MODEL_REGISTRY``, a warning is logged and weights are split
                equally.
            strongest_weight_share: Fraction of the total weight given to the
                strongest model. The rest is split equally across the other
                registered models. If the strongest model is the only one, it
                gets 1.0.

        Raises:
            ValueError: If ``strongest_weight_share`` is not in (0, 1].
        """
        logger.info("Initializing ModelWeightManager with strongest_model_id: %s", strongest_model_id)
        if not 0.0 < strongest_weight_share <= 1.0:
            raise ValueError(
                f"strongest_weight_share must be in (0, 1], got {strongest_weight_share!r}"
            )
        self.base_weights = self._initial_weights(strongest_model_id, strongest_weight_share)
        logger.info("Base weights initialized: %s", self.base_weights)
        self.situation_weights = {
            "high_confidence": 1.2,  # Boost weights for high confidence predictions
            "low_confidence": 0.8,  # Reduce weights for low confidence
            "conflict": 0.5,  # Reduce weights when models disagree
            "consensus": 1.5,  # Boost weights when models agree
        }
        self.context_override_agent = ContextualWeightOverrideAgent()

    @staticmethod
    def _initial_weights(strongest_model_id, strongest_weight_share):
        """Build the base weight table from the current MODEL_REGISTRY (empty if none registered)."""
        model_ids = list(MODEL_REGISTRY.keys())
        if not model_ids:
            return {}  # Handle case with no registered models
        if strongest_model_id and strongest_model_id in MODEL_REGISTRY:
            logger.info("Designating '%s' as the strongest model.", strongest_model_id)
            others = [mid for mid in model_ids if mid != strongest_model_id]
            if not others:
                return {strongest_model_id: 1.0}  # Only one model, which is the strongest
            weights = {strongest_model_id: strongest_weight_share}
            weights.update(dict.fromkeys(others, (1.0 - strongest_weight_share) / len(others)))
            return weights
        if strongest_model_id:
            logger.warning(
                "Strongest model ID '%s' not found in MODEL_REGISTRY. Distributing weights equally.",
                strongest_model_id,
            )
        return dict.fromkeys(model_ids, 1.0 / len(model_ids))

    def adjust_weights(self, predictions, confidence_scores, context_tags: list[str] = None):
        """Dynamically adjust weights based on prediction patterns and optional context.

        Args:
            predictions: Mapping of model_id -> prediction. Only dict
                predictions whose ``"Label"`` is neither None nor ``"Error"``
                count towards consensus/conflict. ``None`` or an empty mapping
                means there are no usable labels.
            confidence_scores: Mapping of model_id -> numeric confidence
                (presumably in [0, 1] — TODO confirm with callers). Entries
                for models without a base weight, or with a ``None`` score,
                are skipped. ``None`` is treated as an empty mapping.
            context_tags: Optional context tags passed to the override agent.

        Returns:
            Mapping of model_id -> weight summing to 1. If every weight ends up
            non-positive, it falls back to equal weights over MODEL_REGISTRY.
        """
        logger.info("Adjusting model weights.")
        adjusted_weights = self.base_weights.copy()
        logger.info("Initial adjusted weights (copy of base): %s", adjusted_weights)

        # 1. Apply contextual overrides first. Only models that already carry a
        # base weight are scaled. Overrides naming unregistered IDs are ignored
        # rather than inserted as zero-weight entries in the result.
        if context_tags:
            logger.info("Applying contextual overrides for tags: %s", context_tags)
            overrides = self.context_override_agent.get_overrides(context_tags)
            for model_id, multiplier in overrides.items():
                if model_id in adjusted_weights:
                    adjusted_weights[model_id] *= multiplier
                else:
                    logger.debug("Ignoring context override for unknown model '%s'.", model_id)
            logger.info("Adjusted weights after context overrides: %s", adjusted_weights)

        # 2. Apply situation-based adjustments (consensus, conflict, confidence).
        # NOTE(review): consensus/conflict scale *every* weight by the same
        # positive factor, so the normalization below cancels them out. They
        # only change the logged intermediate values. A per-model rule, e.g.
        # boosting models that agree with the majority, was presumably
        # intended — confirm.
        if self._has_consensus(predictions):
            logger.info("Consensus detected. Boosting weights for consensus.")
            factor = self.situation_weights["consensus"]
            adjusted_weights = {m: w * factor for m, w in adjusted_weights.items()}
            logger.info("Adjusted weights after consensus boost: %s", adjusted_weights)

        if self._has_conflicts(predictions):
            logger.info("Conflicts detected. Reducing weights for conflict.")
            factor = self.situation_weights["conflict"]
            adjusted_weights = {m: w * factor for m, w in adjusted_weights.items()}
            logger.info("Adjusted weights after conflict reduction: %s", adjusted_weights)

        logger.info("Adjusting weights based on model confidence scores.")
        for model, confidence in (confidence_scores or {}).items():
            if model not in adjusted_weights or confidence is None:
                # These entries would otherwise raise KeyError / TypeError.
                logger.debug("Skipping confidence for model '%s' (unknown model or missing score).", model)
                continue
            if confidence > self.HIGH_CONFIDENCE_THRESHOLD:
                adjusted_weights[model] *= self.situation_weights["high_confidence"]
                logger.info("Model '%s' has high confidence (%.2f). Weight boosted.", model, confidence)
            elif confidence < self.LOW_CONFIDENCE_THRESHOLD:
                adjusted_weights[model] *= self.situation_weights["low_confidence"]
                logger.info("Model '%s' has low confidence (%.2f). Weight reduced.", model, confidence)

        logger.info("Adjusted weights before normalization: %s", adjusted_weights)
        normalized_weights = self._normalize_weights(adjusted_weights)
        logger.info("Final normalized adjusted weights: %s", normalized_weights)
        return normalized_weights

    @staticmethod
    def _valid_labels(predictions):
        """Return labels of dict predictions whose "Label" is neither None nor "Error"."""
        if not predictions:
            return []
        labels = (p.get("Label") for p in predictions.values() if isinstance(p, dict))
        return [label for label in labels if label is not None and label != "Error"]

    def _has_consensus(self, predictions):
        """Return True if at least one usable label exists and all usable labels are equal."""
        logger.info("Checking for consensus among model predictions.")
        labels = self._valid_labels(predictions)
        logger.debug("Non-none predictions for consensus check: %s", labels)
        result = len(labels) > 0 and len(set(labels)) == 1
        logger.info("Consensus detected: %s", result)
        return result

    def _has_conflicts(self, predictions):
        """Return True if there are at least two usable labels and they are not all equal."""
        logger.info("Checking for conflicts among model predictions.")
        labels = self._valid_labels(predictions)
        logger.debug("Non-none predictions for conflict check: %s", labels)
        result = len(labels) > 1 and len(set(labels)) > 1
        logger.info("Conflicts detected: %s", result)
        return result

    def _normalize_weights(self, weights):
        """Normalize weights to sum to 1.

        If the total is not positive, fall back to equal weights over every
        model in MODEL_REGISTRY. The result is empty if none are registered.
        """
        logger.info("Normalizing weights.")
        total = sum(weights.values())
        # `<= 0` also guards against a negative total flipping every sign.
        if total <= 0:
            logger.warning("All weights became zero after adjustments. Reverting to equal base weights for registered models.")
            registered = list(MODEL_REGISTRY.keys())
            if not registered:
                return {}  # No models registered
            return dict.fromkeys(registered, 1.0 / len(registered))
        normalized = {k: v / total for k, v in weights.items()}
        logger.info("Weights normalized. Total sum: %.2f", sum(normalized.values()))
        return normalized