DawnC committed · verified
Commit d217fb0 · 1 Parent(s): 91e463e

Upload llm_model_manager.py

Files changed (1)
  1. llm_model_manager.py +83 -35
llm_model_manager.py CHANGED
@@ -2,6 +2,7 @@ import os
 import re
 import torch
 import logging
+import threading
 from typing import Dict, Optional, Any
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 from huggingface_hub import login
@@ -20,7 +21,22 @@ class LLMModelManager:
     """
     Responsible for loading the LLM, device management and text generation.
     Manages the model, memory optimization and device configuration.
+    Implements the singleton pattern so that the entire application loads the model only once.
     """
+
+    _instance = None
+    _initialized = False
+    _lock = threading.Lock()
+
+    def __new__(cls, *args, **kwargs):
+        """
+        Singleton implementation: ensures the whole application creates only one LLMModelManager.
+        """
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = super(LLMModelManager, cls).__new__(cls)
+        return cls._instance
 
     def __init__(self,
                  model_path: Optional[str] = None,
@@ -30,7 +46,7 @@ class LLMModelManager:
                  temperature: float = 0.3,
                  top_p: float = 0.85):
         """
-        Initialize the model manager
+        Initialize the model manager (only runs the first time the instance is created)
 
         Args:
             model_path: Path to the LLM or a HuggingFace model name; defaults to Llama 3.2
@@ -40,36 +56,48 @@ class LLMModelManager:
             temperature: Temperature parameter for text generation
             top_p: Nucleus sampling probability threshold for text generation
         """
-        # Set up a dedicated logger
-        self.logger = logging.getLogger(self.__class__.__name__)
-        if not self.logger.handlers:
-            handler = logging.StreamHandler()
-            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-            handler.setFormatter(formatter)
-            self.logger.addHandler(handler)
-            self.logger.setLevel(logging.INFO)
-
-        # Model configuration
-        self.model_path = model_path or "meta-llama/Llama-3.2-3B-Instruct"
-        self.tokenizer_path = tokenizer_path or self.model_path
-
-        # Device management
-        self.device = self._detect_device(device)
-        self.logger.info(f"Device selected: {self.device}")
-
-        # Generation parameters
-        self.max_length = max_length
-        self.temperature = temperature
-        self.top_p = top_p
-
-        # Model state
-        self.model = None
-        self.tokenizer = None
-        self._model_loaded = False
-        self.call_count = 0
-
-        # HuggingFace authentication
-        self.hf_token = self._setup_huggingface_auth()
+        # Avoid re-initialization
+        if self._initialized:
+            return
+
+        with self._lock:
+            if self._initialized:
+                return
+
+            # Set up a dedicated logger
+            self.logger = logging.getLogger(self.__class__.__name__)
+            if not self.logger.handlers:
+                handler = logging.StreamHandler()
+                formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+                handler.setFormatter(formatter)
+                self.logger.addHandler(handler)
+                self.logger.setLevel(logging.INFO)
+
+            # Model configuration
+            self.model_path = model_path or "meta-llama/Llama-3.2-3B-Instruct"
+            self.tokenizer_path = tokenizer_path or self.model_path
+
+            # Device management
+            self.device = self._detect_device(device)
+            self.logger.info(f"Device selected: {self.device}")
+
+            # Generation parameters
+            self.max_length = max_length
+            self.temperature = temperature
+            self.top_p = top_p
+
+            # Model state
+            self.model = None
+            self.tokenizer = None
+            self._model_loaded = False
+            self.call_count = 0
+
+            # HuggingFace authentication
+            self.hf_token = self._setup_huggingface_auth()
+
+            # Mark as initialized
+            self._initialized = True
+            self.logger.info("LLMModelManager singleton initialized")
 
     def _detect_device(self, device: Optional[str]) -> str:
         """
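
Note: the hunks above add a thread-safe singleton. __new__ uses double-checked locking (check, acquire _lock, check again) so two threads constructing the manager at the same time cannot create two instances, and __init__ returns early once _initialized is set, so later constructor calls are no-ops. A minimal sketch of the resulting behaviour, assuming the file is importable as llm_model_manager (the import path is an assumption, not shown in this commit):

    # Sketch only: the import path is an assumption, not part of this commit.
    from llm_model_manager import LLMModelManager

    first = LLMModelManager(temperature=0.3)
    second = LLMModelManager(temperature=0.9)  # arguments ignored: __init__ bails out early

    assert first is second            # __new__ always hands back the same instance
    assert first.temperature == 0.3   # state from the first construction is kept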
@@ -119,11 +147,16 @@
     def _load_model(self):
         """
         Load the LLM model and tokenizer, using 8-bit quantization to save memory.
+        An enhanced state check ensures the model is loaded only once.
 
         Raises:
             ModelLoadingError: when model loading fails
         """
-        if self._model_loaded:
+        # Complete model-state check
+        if (self._model_loaded and
+            hasattr(self, 'model') and self.model is not None and
+            hasattr(self, 'tokenizer') and self.tokenizer is not None):
+            self.logger.info("Model already loaded, skipping reload")
             return
 
         try:
@@ -160,7 +193,7 @@
             )
 
             self._model_loaded = True
-            self.logger.info("Model loaded successfully")
+            self.logger.info("Model loaded successfully (singleton instance)")
 
         except Exception as e:
             error_msg = f"Failed to load model: {str(e)}"
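
Note: the body of _load_model is mostly outside the visible hunks; only its docstring ("8-bit quantization to save memory") and the BitsAndBytesConfig import indicate how loading works. The sketch below shows a typical 8-bit load with transformers and is an assumption about what the hidden code does, not a copy of it:

    # Sketch only: an assumed 8-bit load; the actual _load_model body is not in this diff.
    # Requires the bitsandbytes package and a prior huggingface_hub.login() for the gated Llama weights.
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    model_path = "meta-llama/Llama-3.2-3B-Instruct"   # default from __init__
    quant_config = BitsAndBytesConfig(load_in_8bit=True)

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=quant_config,
        device_map="auto",   # let accelerate place the quantized layers on the GPU
    )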
 
@@ -331,7 +364,7 @@
         """Reset the model context and clear the GPU cache"""
         if self._model_loaded:
             self._clear_gpu_cache()
-            self.logger.info("Model context reset")
+            self.logger.info("Model context reset (singleton instance)")
         else:
             self.logger.info("Model not loaded, no context to reset")
 
@@ -374,5 +407,20 @@
             "device": self.device,
             "is_loaded": self._model_loaded,
             "call_count": self.call_count,
-            "has_hf_token": self.hf_token is not None
+            "has_hf_token": self.hf_token is not None,
+            "is_singleton": True
         }
+
+    @classmethod
+    def reset_singleton(cls):
+        """
+        Reset the singleton instance (only for tests or application restarts).
+        Note: this means the model will need to be reloaded.
+        """
+        with cls._lock:
+            if cls._instance is not None:
+                instance = cls._instance
+                if hasattr(instance, 'logger'):
+                    instance.logger.info("Resetting singleton instance")
+                cls._instance = None
+                cls._initialized = False
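
Note: reset_singleton is documented as being for tests or restarts only. A short illustrative pytest fixture (not part of this commit) that uses it to isolate tests, assuming the same hypothetical import path as above:

    # Sketch only: illustrative test fixture, not part of this commit.
    import pytest
    from llm_model_manager import LLMModelManager  # import path assumed, as above

    @pytest.fixture
    def fresh_manager():
        LLMModelManager.reset_singleton()   # drop any instance left over from other tests
        manager = LLMModelManager()
        yield manager
        LLMModelManager.reset_singleton()   # the next construction rebuilds and reloads the model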