Chris4K committed on
Commit
0dad39b
·
verified ·
1 Parent(s): 5a1eeff

Update services/model_service.py

Browse files
Files changed (1) hide show
  1. services/model_service.py +12 -11
services/model_service.py CHANGED
@@ -25,32 +25,33 @@ class ModelService:
25
  @lru_cache(maxsize=1)
26
  def _load_models(self):
27
  try:
28
- self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME )
 
 
 
 
29
 
30
- # Modify the model configuration to use a valid rope_scaling format
31
- config = LlamaConfig.from_pretrained(settings.model_name)
32
  if hasattr(config, "rope_scaling"):
 
33
  config.rope_scaling = {
34
- "type": "linear",
35
- "factor": 32.0
36
  }
37
-
38
- # Load model with updated configuration
39
- #self.model = AutoModelForCausalLM.from_pretrained(model_name, config=config).to(device)
40
 
41
-
42
-
43
  self.model = AutoModelForCausalLM.from_pretrained(
44
  settings.MODEL_NAME,
45
  torch_dtype=torch.float16 if settings.DEVICE == "cuda" else torch.float32,
46
  device_map="auto" if settings.DEVICE == "cuda" else None,
47
  config=config
48
  )
 
 
49
  self.embedder = SentenceTransformer(settings.EMBEDDER_MODEL)
 
50
  except Exception as e:
51
  logger.error(f"Error loading models: {e}")
52
  raise
53
 
54
  def get_models(self):
55
  return self.tokenizer, self.model, self.embedder
56
-
 
25
  @lru_cache(maxsize=1)
26
  def _load_models(self):
27
  try:
28
+ # Load tokenizer
29
+ self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
30
+
31
+ # Load model configuration and modify rope_scaling if applicable
32
+ config = LlamaConfig.from_pretrained(settings.MODEL_NAME)
33
 
 
 
34
  if hasattr(config, "rope_scaling"):
35
+ logger.info("Updating rope_scaling configuration...")
36
  config.rope_scaling = {
37
+ "type": "linear", # Ensure the type is valid
38
+ "factor": 32.0 # Ensure factor is a valid float
39
  }
 
 
 
40
 
41
+ # Load model with the updated configuration
 
42
  self.model = AutoModelForCausalLM.from_pretrained(
43
  settings.MODEL_NAME,
44
  torch_dtype=torch.float16 if settings.DEVICE == "cuda" else torch.float32,
45
  device_map="auto" if settings.DEVICE == "cuda" else None,
46
  config=config
47
  )
48
+
49
+ # Load sentence embedder
50
  self.embedder = SentenceTransformer(settings.EMBEDDER_MODEL)
51
+
52
  except Exception as e:
53
  logger.error(f"Error loading models: {e}")
54
  raise
55
 
56
  def get_models(self):
57
  return self.tokenizer, self.model, self.embedder