import logging
import torch

logger = logging.getLogger(__name__)

class ContextualWeightOverrideAgent:
    """Provides per-model weight multipliers keyed by contextual tags (e.g. scene or lighting conditions)."""

    def __init__(self):
        self.context_overrides = {
            # Example: for outdoor scenes, model_1 is penalized and model_5 is boosted
            "outdoor": {
                "model_1": 0.8,  # reduce weight of model_1 by 20% for outdoor scenes
                "model_5": 1.2,  # boost weight of model_5 by 20% for outdoor scenes
            },
            "low_light": {
                "model_2": 0.7,
                "model_7": 1.3,
            },
            "sunny": {
                "model_3": 0.9,
                "model_4": 1.1,
            },
            # Add more contexts and their model-specific weight adjustments here
        }

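    # Illustrative example using only the tags defined above: get_overrides(["outdoor", "low_light"])
    # returns {"model_1": 0.8, "model_5": 1.2, "model_2": 0.7, "model_7": 1.3}; a model listed
    # under several matching tags has its multipliers multiplied together.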
    def get_overrides(self, context_tags: list[str]) -> dict[str, float]:
        """Return the combined weight multipliers for the given context tags."""
        combined_overrides = {}
        for tag in context_tags:
            if tag in self.context_overrides:
                for model_id, multiplier in self.context_overrides[tag].items():
                    # If a model appears under several matching tags, its multipliers are
                    # combined multiplicatively so the effects accumulate.
                    combined_overrides[model_id] = combined_overrides.get(model_id, 1.0) * multiplier
        return combined_overrides


class ModelWeightManager:
    """Maintains base ensemble weights and adjusts them using context, consensus, and confidence signals."""

    def __init__(self):
        self.base_weights = {
            "model_1": 0.15,  # SwinV2 Based
            "model_2": 0.15,  # ViT Based
            "model_3": 0.15,  # SDXL Dataset
            "model_4": 0.15,  # SDXL + FLUX
            "model_5": 0.15,  # ViT Based
            "model_5b": 0.10, # ViT Based, Newer Dataset
            "model_6": 0.10,  # Swin, Midj + SDXL
            "model_7": 0.05   # ViT
        }
        self.situation_weights = {
            "high_confidence": 1.2,  # Boost weights for high confidence predictions
            "low_confidence": 0.8,   # Reduce weights for low confidence
            "conflict": 0.5,         # Reduce weights when models disagree
            "consensus": 1.5,        # Boost weights when models agree
        }
        self.context_override_agent = ContextualWeightOverrideAgent()
    
    def adjust_weights(self, predictions, confidence_scores, context_tags: list[str] | None = None):
        """Dynamically adjust weights based on prediction patterns and optional context."""
        adjusted_weights = self.base_weights.copy()

        # 1. Apply contextual overrides first
        if context_tags:
            overrides = self.context_override_agent.get_overrides(context_tags)
            for model_id, multiplier in overrides.items():
                adjusted_weights[model_id] = adjusted_weights.get(model_id, 0.0) * multiplier
        
        # 2. Apply situation-based adjustments (consensus, conflict, confidence)
        # Check for consensus
        if self._has_consensus(predictions):
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["consensus"]
        
        # Check for conflicts
        if self._has_conflicts(predictions):
            for model in adjusted_weights:
                adjusted_weights[model] *= self.situation_weights["conflict"]
        
        # Adjust based on confidence
        for model, confidence in confidence_scores.items():
            if confidence > 0.8:
                adjusted_weights[model] *= self.situation_weights["high_confidence"]
            elif confidence < 0.5:
                adjusted_weights[model] *= self.situation_weights["low_confidence"]
        
        return self._normalize_weights(adjusted_weights)
    
    def _has_consensus(self, predictions):
        """Check whether all usable model predictions agree on a single label."""
        # Drop missing, malformed, or errored predictions before checking agreement.
        labels = [
            p.get("Label") for p in predictions.values()
            if isinstance(p, dict) and p.get("Label") not in (None, "Error")
        ]
        return len(labels) > 0 and len(set(labels)) == 1

    def _has_conflicts(self, predictions):
        """Check whether usable model predictions disagree with each other."""
        # Drop missing, malformed, or errored predictions before checking for disagreement.
        labels = [
            p.get("Label") for p in predictions.values()
            if isinstance(p, dict) and p.get("Label") not in (None, "Error")
        ]
        return len(labels) > 1 and len(set(labels)) > 1
    
    def _normalize_weights(self, weights):
        """Normalize weights so they sum to 1."""
        total = sum(weights.values())
        if total == 0:
            # All weights were driven to zero by aggressive multipliers; fall back to
            # equal weights across the base models rather than dividing by zero.
            logger.warning("All weights became zero after adjustments. Falling back to equal weights.")
            return {k: 1.0 / len(self.base_weights) for k in self.base_weights}
        return {k: v / total for k, v in weights.items()}
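

# Minimal usage sketch (illustrative only). The prediction format below (one dict per
# model with a "Label" key) and the example label values are assumptions inferred from
# how _has_consensus/_has_conflicts read predictions; they are not defined in this file.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = ModelWeightManager()

    # Hypothetical per-model outputs and confidence scores for a single image.
    predictions = {
        "model_1": {"Label": "ai_generated"},
        "model_2": {"Label": "ai_generated"},
        "model_3": {"Label": "real"},
        "model_7": None,  # a failed model is ignored by the consensus/conflict checks
    }
    confidence_scores = {"model_1": 0.92, "model_2": 0.85, "model_3": 0.45}

    weights = manager.adjust_weights(predictions, confidence_scores, context_tags=["outdoor"])
    print(weights)  # adjusted weights, normalized to sum to 1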