Chris4K commited on
Commit
86179ff
·
verified ·
1 Parent(s): b572bdb

Update services/model_service.py

Browse files
Files changed (1) hide show
  1. services/model_service.py +5 -4
services/model_service.py CHANGED
@@ -25,7 +25,8 @@ class ModelService:
25
  # Load tokenizer
26
  self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
27
 
28
-
 
29
 
30
  # Check quantization type and adjust accordingly
31
  if config.get('quantization_config', {}).get('type', '') == 'compressed-tensors':
@@ -34,10 +35,10 @@ class ModelService:
34
 
35
  # Load model with the updated configuration
36
  self.model = AutoModelForCausalLM.from_pretrained(
37
- model_type == "llama" ,
38
  settings.MODEL_NAME,
 
39
  torch_dtype=torch.float16 if settings.DEVICE == "cuda" else torch.float32,
40
- device_map="auto" if settings.DEVICE == "cuda" else None
41
  )
42
 
43
  # Load sentence embedder
@@ -48,4 +49,4 @@ class ModelService:
48
  raise
49
 
50
  def get_models(self):
51
- return self.tokenizer, self.model, self.embedder
 
25
  # Load tokenizer
26
  self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
27
 
28
+ # Load model configuration
29
+ config = LlamaConfig.from_pretrained(settings.MODEL_NAME)
30
 
31
  # Check quantization type and adjust accordingly
32
  if config.get('quantization_config', {}).get('type', '') == 'compressed-tensors':
 
35
 
36
  # Load model with the updated configuration
37
  self.model = AutoModelForCausalLM.from_pretrained(
 
38
  settings.MODEL_NAME,
39
+ config=config,
40
  torch_dtype=torch.float16 if settings.DEVICE == "cuda" else torch.float32,
41
+ device_map="auto" if settings.DEVICE == "cuda" else None
42
  )
43
 
44
  # Load sentence embedder
 
49
  raise
50
 
51
  def get_models(self):
52
+ return self.tokenizer, self.model, self.embedder