Chris4K committed on
Commit
44c613f
·
verified ·
1 Parent(s): 09b1583

Update services/model_service.py

Browse files
Files changed (1) hide show
  1. services/model_service.py +14 -11
services/model_service.py CHANGED
@@ -21,20 +21,22 @@ class ModelService:
21
  self._initialized = True
22
  self._load_models()
23
 
24
- @lru_cache(maxsize=1)
25
  def _load_models(self):
26
  try:
27
  # Load tokenizer
28
  self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
29
-
30
- # Load model configuration without modifying rope_scaling
31
  config = LlamaConfig.from_pretrained(settings.MODEL_NAME)
32
-
33
- # Remove rope_scaling if present
34
- if hasattr(config, "rope_scaling"):
35
- logger.info("Removing rope_scaling from configuration...")
36
- config.rope_scaling = None
37
-
 
 
 
38
  # Load model with the updated configuration
39
  self.model = AutoModelForCausalLM.from_pretrained(
40
  settings.MODEL_NAME,
@@ -42,13 +44,14 @@ class ModelService:
42
  device_map="auto" if settings.DEVICE == "cuda" else None,
43
  config=config
44
  )
45
-
46
  # Load sentence embedder
47
  self.embedder = SentenceTransformer(settings.EMBEDDER_MODEL)
48
-
49
  except Exception as e:
50
  logger.error(f"Error loading models: {e}")
51
  raise
52
 
 
53
  def get_models(self):
54
  return self.tokenizer, self.model, self.embedder
 
21
  self._initialized = True
22
  self._load_models()
23
 
 
24
  def _load_models(self):
25
  try:
26
  # Load tokenizer
27
  self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
28
+
29
+ # Load model configuration
30
  config = LlamaConfig.from_pretrained(settings.MODEL_NAME)
31
+
32
+ # Check and update rope_scaling if necessary
33
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
34
+ logger.info("Updating rope_scaling in configuration...")
35
+ config.rope_scaling = {
36
+ "type": "linear", # Ensure this matches the expected type
37
+ "factor": config.rope_scaling.get('factor', 1.0) # Use existing factor or default to 1.0
38
+ }
39
+
40
  # Load model with the updated configuration
41
  self.model = AutoModelForCausalLM.from_pretrained(
42
  settings.MODEL_NAME,
 
44
  device_map="auto" if settings.DEVICE == "cuda" else None,
45
  config=config
46
  )
47
+
48
  # Load sentence embedder
49
  self.embedder = SentenceTransformer(settings.EMBEDDER_MODEL)
50
+
51
  except Exception as e:
52
  logger.error(f"Error loading models: {e}")
53
  raise
54
 
55
+
56
  def get_models(self):
57
  return self.tokenizer, self.model, self.embedder