LPX55 commited on
Commit
0e2cdc4
·
1 Parent(s): e799e01

feat: dynamically initialize model weights using MODEL_REGISTRY and adjust weights for specific models

Browse files
Files changed (1) hide show
  1. agents/ensemble_weights.py +11 -13
agents/ensemble_weights.py CHANGED
@@ -1,5 +1,6 @@
1
  import logging
2
  import torch
 
3
 
4
  logger = logging.getLogger(__name__)
5
 
@@ -15,7 +16,7 @@ class ContextualWeightOverrideAgent:
15
  "model_2": 0.7,
16
  "model_7": 1.3,
17
  },
18
- "sunny": {
19
  "model_3": 0.9,
20
  "model_4": 1.1,
21
  }
@@ -36,21 +37,18 @@ class ContextualWeightOverrideAgent:
36
 
37
  class ModelWeightManager:
38
  def __init__(self):
39
- self.base_weights = {
40
- "model_1": 0.15, # SwinV2 Based
41
- "model_2": 0.15, # ViT Based
42
- "model_3": 0.15, # SDXL Dataset
43
- "model_4": 0.15, # SDXL + FLUX
44
- "model_5": 0.15, # ViT Based
45
- "model_5b": 0.10, # ViT Based, Newer Dataset
46
- "model_6": 0.10, # Swin, Midj + SDXL
47
- "model_7": 0.05 # ViT
48
- }
49
  self.situation_weights = {
50
  "high_confidence": 1.2, # Boost weights for high confidence predictions
51
  "low_confidence": 0.8, # Reduce weights for low confidence
52
- "conflict": 0.5, # Reduce weights when models disagree
53
- "consensus": 1.5 # Boost weights when models agree
54
  }
55
  self.context_override_agent = ContextualWeightOverrideAgent()
56
 
 
1
  import logging
2
  import torch
3
+ from utils.registry import MODEL_REGISTRY
4
 
5
  logger = logging.getLogger(__name__)
6
 
 
16
  "model_2": 0.7,
17
  "model_7": 1.3,
18
  },
19
+ "SDXL": {
20
  "model_3": 0.9,
21
  "model_4": 1.1,
22
  }
 
37
 
38
  class ModelWeightManager:
39
  def __init__(self):
40
+ # Dynamically initialize base_weights for all models
41
+ self.base_weights = {model_id: 1.0 for model_id in MODEL_REGISTRY}
42
+ # Assign a higher weight to the strongest model
43
+ if "simple_prediction" in self.base_weights:
44
+ self.base_weights["simple_prediction"] = 2.5 # Tune as needed
45
+
46
+ # Situation-based multipliers (applied to all models)
 
 
 
47
  self.situation_weights = {
48
  "high_confidence": 1.2, # Boost weights for high confidence predictions
49
  "low_confidence": 0.8, # Reduce weights for low confidence
50
+ "conflict": 0.5, # Reduce weights when models disagree
51
+ "consensus": 1.5 # Boost weights when models agree
52
  }
53
  self.context_override_agent = ContextualWeightOverrideAgent()
54