Chris4K committed on
Commit
894f4ee
·
verified ·
1 Parent(s): 288c963

Update services/model_service.py

Browse files
Files changed (1) hide show
  1. services/model_service.py +31 -30
services/model_service.py CHANGED
@@ -1,8 +1,8 @@
1
- from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig
2
- from config.config import settings
3
  from sentence_transformers import SentenceTransformer
4
  import torch
5
  import logging
 
6
 
7
  logger = logging.getLogger(__name__)
8
 
@@ -18,43 +18,44 @@ class ModelService:
18
  def __init__(self):
19
  if not self._initialized:
20
  self._initialized = True
 
 
 
21
  self._load_models()
22
 
23
  def _load_models(self):
24
  try:
 
 
25
  # Load tokenizer
26
- #self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
27
-
28
- ## Load model configuration
29
- #config = LlamaConfig.from_pretrained(settings.MODEL_NAME)
30
-
31
- ## Check quantization type and adjust accordingly
32
- #if config.get('quantization_config', {}).get('type', '') == 'compressed-tensors':
33
- # logger.warning("Quantization type 'compressed-tensors' is not supported. Switching to 'bitsandbytes_8bit'.")
34
- # config.quantization_config['type'] = 'bitsandbytes_8bit'
35
-
36
- ## Load model with the updated configuration
37
- #self.model = AutoModelForCausalLM.from_pretrained(
38
- # settings.MODEL_NAME,
39
- # config=config,
40
- # torch_dtype=torch.float16 if settings.DEVICE == "cuda" else torch.float32,
41
- # device_map="auto" if settings.DEVICE == "cuda" else None
42
- #)
43
-
44
- #-----
45
- # Load Llama 3.2 model
46
- model_name = settings.MODEL_NAME #"meta-llama/Llama-3.2-3B-Instruct" # Replace with the exact model path
47
- tokenizer = AutoTokenizer.from_pretrained(model_name)
48
- #model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
49
- self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map=None, torch_dtype=torch.float32)
50
-
51
-
52
  # Load sentence embedder
53
  self.embedder = SentenceTransformer(settings.EMBEDDER_MODEL)
 
54
 
55
  except Exception as e:
56
  logger.error(f"Error loading models: {e}")
57
- raise
58
 
59
  def get_models(self):
60
- return self.tokenizer, self.model, self.embedder
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
 
2
  from sentence_transformers import SentenceTransformer
3
  import torch
4
  import logging
5
+ from config.config import settings
6
 
7
  logger = logging.getLogger(__name__)
8
 
 
18
  def __init__(self):
19
  if not self._initialized:
20
  self._initialized = True
21
+ self.tokenizer = None
22
+ self.model = None
23
+ self.embedder = None
24
  self._load_models()
25
 
26
  def _load_models(self):
27
  try:
28
+ logger.info("Loading models...")
29
+
30
  # Load tokenizer
31
+ self.tokenizer = AutoTokenizer.from_pretrained(settings.MODEL_NAME)
32
+ logger.info(f"Tokenizer for {settings.MODEL_NAME} loaded successfully.")
33
+
34
+ # Load language model
35
+ quantization_device = settings.DEVICE
36
+ quantization_bits = settings.QUANTIZATION_BITS
37
+
38
+ self.model = AutoModelForCausalLM.from_pretrained(
39
+ settings.MODEL_NAME,
40
+ torch_dtype=torch.float16 if quantization_device == "cuda" else torch.float32,
41
+ device_map="auto" if quantization_device == "cuda" else None,
42
+ load_in_8bit=(quantization_bits == 8),
43
+ trust_remote_code=True
44
+ )
45
+ logger.info(f"Model {settings.MODEL_NAME} loaded successfully on {quantization_device}.")
46
+
 
 
 
 
 
 
 
 
 
 
47
  # Load sentence embedder
48
  self.embedder = SentenceTransformer(settings.EMBEDDER_MODEL)
49
+ logger.info(f"Embedder {settings.EMBEDDER_MODEL} loaded successfully.")
50
 
51
  except Exception as e:
52
  logger.error(f"Error loading models: {e}")
53
+ raise RuntimeError(f"Failed to initialize ModelService: {str(e)}")
54
 
55
  def get_models(self):
56
+ """
57
+ Returns the tokenizer, language model, and sentence embedder instances.
58
+ """
59
+ if not self.tokenizer or not self.model or not self.embedder:
60
+ raise RuntimeError("Models are not fully loaded.")
61
+ return self.tokenizer, self.model, self.embedder