LPX55 commited on
Commit
39558cb
·
1 Parent(s): 088ff8c

Revert "feat: dynamically initialize model weights using MODEL_REGISTRY and adjust weights for specific models"

Browse files

This reverts commit 0e2cdc4906eb4f67ab48ec61749fbaabacd8574d.

Files changed (1) hide show
  1. agents/ensemble_weights.py +13 -11
agents/ensemble_weights.py CHANGED
@@ -1,6 +1,5 @@
1
  import logging
2
  import torch
3
- from utils.registry import MODEL_REGISTRY
4
 
5
  logger = logging.getLogger(__name__)
6
 
@@ -16,7 +15,7 @@ class ContextualWeightOverrideAgent:
16
  "model_2": 0.7,
17
  "model_7": 1.3,
18
  },
19
- "SDXL": {
20
  "model_3": 0.9,
21
  "model_4": 1.1,
22
  }
@@ -37,18 +36,21 @@ class ContextualWeightOverrideAgent:
37
 
38
  class ModelWeightManager:
39
  def __init__(self):
40
- # Dynamically initialize base_weights for all models
41
- self.base_weights = {model_id: 1.0 for model_id in MODEL_REGISTRY}
42
- # Assign a higher weight to the strongest model
43
- if "simple_prediction" in self.base_weights:
44
- self.base_weights["simple_prediction"] = 2.5 # Tune as needed
45
-
46
- # Situation-based multipliers (applied to all models)
 
 
 
47
  self.situation_weights = {
48
  "high_confidence": 1.2, # Boost weights for high confidence predictions
49
  "low_confidence": 0.8, # Reduce weights for low confidence
50
- "conflict": 0.5, # Reduce weights when models disagree
51
- "consensus": 1.5 # Boost weights when models agree
52
  }
53
  self.context_override_agent = ContextualWeightOverrideAgent()
54
 
 
1
  import logging
2
  import torch
 
3
 
4
  logger = logging.getLogger(__name__)
5
 
 
15
  "model_2": 0.7,
16
  "model_7": 1.3,
17
  },
18
+ "sunny": {
19
  "model_3": 0.9,
20
  "model_4": 1.1,
21
  }
 
36
 
37
  class ModelWeightManager:
38
  def __init__(self):
39
+ self.base_weights = {
40
+ "model_1": 0.15, # SwinV2 Based
41
+ "model_2": 0.15, # ViT Based
42
+ "model_3": 0.15, # SDXL Dataset
43
+ "model_4": 0.15, # SDXL + FLUX
44
+ "model_5": 0.15, # ViT Based
45
+ "model_5b": 0.10, # ViT Based, Newer Dataset
46
+ "model_6": 0.10, # Swin, Midj + SDXL
47
+ "model_7": 0.05 # ViT
48
+ }
49
  self.situation_weights = {
50
  "high_confidence": 1.2, # Boost weights for high confidence predictions
51
  "low_confidence": 0.8, # Reduce weights for low confidence
52
+ "conflict": 0.5, # Reduce weights when models disagree
53
+ "consensus": 1.5 # Boost weights when models agree
54
  }
55
  self.context_override_agent = ContextualWeightOverrideAgent()
56