|
import logging |
|
import torch |
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
class ContextualWeightOverrideAgent:
    """Map context tags (e.g. ``"outdoor"``) to per-model weight multipliers."""

    # Built-in table used when no custom table is supplied:
    # tag -> {model_id: multiplier}.
    DEFAULT_CONTEXT_OVERRIDES = {
        "outdoor": {
            "model_1": 0.8,
            "model_5": 1.2,
        },
        "low_light": {
            "model_2": 0.7,
            "model_7": 1.3,
        },
        "sunny": {
            "model_3": 0.9,
            "model_4": 1.1,
        },
    }

    def __init__(self, context_overrides=None):
        """Create the agent.

        Args:
            context_overrides: Optional mapping ``tag -> {model_id: multiplier}``
                that replaces the built-in table. It is copied, so later
                mutations of the caller's mapping (or of the class default)
                do not leak into this instance.
        """
        source = (
            self.DEFAULT_CONTEXT_OVERRIDES
            if context_overrides is None
            else context_overrides
        )
        self.context_overrides = {
            tag: dict(multipliers) for tag, multipliers in source.items()
        }

    def get_overrides(self, context_tags: "list[str] | None") -> dict:
        """Return the combined weight multipliers for the given context tags.

        Multipliers from different tags that target the same model are
        multiplied together. Tags missing from the table are ignored, and a
        tag listed more than once is applied only once.

        Args:
            context_tags: Tags describing the current context. ``None`` or an
                empty sequence yields an empty result.

        Returns:
            Dict mapping model id to its combined multiplier. Models not
            affected by any tag are absent.
        """
        combined_overrides = {}
        if not context_tags:
            return combined_overrides

        # dict.fromkeys drops duplicates while keeping first-seen order, so a
        # repeated tag does not compound its multipliers.
        for tag in dict.fromkeys(context_tags):
            for model_id, multiplier in self.context_overrides.get(tag, {}).items():
                combined_overrides[model_id] = (
                    combined_overrides.get(model_id, 1.0) * multiplier
                )
        return combined_overrides
|
|
|
|
|
class ModelWeightManager:
    """Compute normalized per-model weights for combining ensemble predictions.

    Weights start from ``base_weights`` and are multiplied, in order, by
    optional context-tag overrides, a consensus/conflict situation factor, and
    a per-model confidence factor. The result is normalized to sum to 1.
    """

    def __init__(self):
        # Prior weight per model id; these values sum to 1.0.
        self.base_weights = {
            "model_1": 0.15,
            "model_2": 0.15,
            "model_3": 0.15,
            "model_4": 0.15,
            "model_5": 0.15,
            "model_5b": 0.10,
            "model_6": 0.10,
            "model_7": 0.05,
        }
        # Multiplier applied when the named situation is detected.
        self.situation_weights = {
            "high_confidence": 1.2,
            "low_confidence": 0.8,
            "conflict": 0.5,
            "consensus": 1.5,
        }
        self.context_override_agent = ContextualWeightOverrideAgent()

    def adjust_weights(self, predictions, confidence_scores, context_tags: "list[str] | None" = None):
        """Dynamically adjust weights based on prediction patterns and optional context.

        Args:
            predictions: Mapping model id -> prediction. ``None`` values and the
                string ``"Error"`` are treated as missing. Remaining values must
                be hashable because they are compared through ``set()``.
            confidence_scores: Mapping model id -> score. A score above 0.8
                applies ``high_confidence``; below 0.5 applies
                ``low_confidence``. ``None`` scores are skipped. A score for a
                model without a weight yields weight 0.0 for that model.
            context_tags: Optional tags forwarded to the context override agent.

        Returns:
            Dict model id -> weight, normalized by ``_normalize_weights``.
        """
        adjusted_weights = self.base_weights.copy()

        if context_tags:
            overrides = self.context_override_agent.get_overrides(context_tags)
            for model_id, multiplier in overrides.items():
                # Models without a base weight enter with 0.0 and stay at 0.0.
                adjusted_weights[model_id] = adjusted_weights.get(model_id, 0.0) * multiplier

        # NOTE(review): the consensus/conflict factors scale every model by the
        # same amount, so _normalize_weights cancels them out. Confirm whether
        # they were meant to target specific models (e.g. the agreeing ones).
        if self._has_consensus(predictions):
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["consensus"]

        if self._has_conflicts(predictions):
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["conflict"]

        for model, confidence in (confidence_scores or {}).items():
            if confidence is None:
                # No usable score; comparing None with a float would raise TypeError.
                continue
            if confidence > 0.8:
                factor = self.situation_weights["high_confidence"]
            elif confidence < 0.5:
                factor = self.situation_weights["low_confidence"]
            else:
                continue
            # .get mirrors the context-override handling above; indexing
            # directly raised KeyError for models without a weight.
            adjusted_weights[model] = adjusted_weights.get(model, 0.0) * factor

        return self._normalize_weights(adjusted_weights)

    @staticmethod
    def _valid_predictions(predictions):
        """Return prediction values that are neither ``None`` nor ``"Error"``."""
        if not predictions:
            return []
        return [p for p in predictions.values() if p is not None and p != "Error"]

    def _has_consensus(self, predictions):
        """Check if models agree on prediction.

        True when there is at least one valid prediction and all valid
        predictions are equal (a single valid prediction counts as consensus).
        """
        valid = self._valid_predictions(predictions)
        return bool(valid) and len(set(valid)) == 1

    def _has_conflicts(self, predictions):
        """Check if models have conflicting predictions.

        True when the valid predictions contain at least two distinct values.
        """
        return len(set(self._valid_predictions(predictions))) > 1

    def _normalize_weights(self, weights):
        """Normalize weights to sum to 1.

        If ``weights`` sum to zero, log a warning and fall back to the base
        weights (normalized), as the warning states. If the base weights also
        sum to zero, return a uniform distribution over the base models. With
        no base models at all, return an empty dict.
        """
        total = sum(weights.values())
        if total != 0:
            return {k: v / total for k, v in weights.items()}

        logger.warning("All weights became zero after adjustments. Reverting to base weights.")
        base_total = sum(self.base_weights.values())
        if base_total != 0:
            return {k: v / base_total for k, v in self.base_weights.items()}
        if not self.base_weights:
            # Avoid ZeroDivisionError in the uniform fallback below.
            return {}
        uniform = 1.0 / len(self.base_weights)
        return {k: uniform for k in self.base_weights}