ciyidogan commited on
Commit
944d883
·
verified ·
1 Parent(s): d7b1a88

Update llm_model.py

Browse files
Files changed (1) hide show
  1. llm_model.py +24 -18
llm_model.py CHANGED
@@ -3,27 +3,36 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequen
3
  from log import log
4
  from pydantic import BaseModel
5
 
6
- model = None
7
- tokenizer = None
8
- eos_token_id = None
 
 
 
 
 
 
 
 
 
9
 
10
  class Message(BaseModel):
11
  user_input: str
12
 
13
  def setup_model(s_config):
14
- global model, tokenizer, eos_token_id
15
  try:
16
  log("🧠 setup_model() başladı")
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
  log(f"📡 Kullanılan cihaz: {device}")
19
- tokenizer = AutoTokenizer.from_pretrained(s_config.MODEL_BASE, use_fast=False)
20
  log("📦 Tokenizer yüklendi. Ana model indiriliyor...")
21
- model = AutoModelForCausalLM.from_pretrained(s_config.MODEL_BASE, torch_dtype=torch.float32).to(device)
22
  log("📦 Ana model indirildi ve yüklendi. eval() çağırılıyor...")
23
- tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
24
- model.config.pad_token_id = tokenizer.pad_token_id
25
- eos_token_id = tokenizer("<|im_end|>", add_special_tokens=False)["input_ids"][0]
26
- model.eval()
27
  log("✅ Ana model eval() çağrıldı")
28
  log(f"📦 Intent modeli indiriliyor: {s_config.INTENT_MODEL_ID}")
29
  _ = AutoTokenizer.from_pretrained(s_config.INTENT_MODEL_ID)
@@ -35,9 +44,12 @@ def setup_model(s_config):
35
  traceback.print_exc()
36
 
37
  async def generate_response(text, app_config):
 
 
 
 
38
  messages = [{"role": "user", "content": text}]
39
  encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
40
- eos_token = tokenizer("<|im_end|>", add_special_tokens=False)["input_ids"][0]
41
  input_ids = encodeds.to(model.device)
42
  attention_mask = (input_ids != tokenizer.pad_token_id).long()
43
 
@@ -47,7 +59,7 @@ async def generate_response(text, app_config):
47
  attention_mask=attention_mask,
48
  max_new_tokens=128,
49
  do_sample=app_config.USE_SAMPLING,
50
- eos_token_id=eos_token,
51
  pad_token_id=tokenizer.pad_token_id,
52
  return_dict_in_generate=True,
53
  output_scores=True
@@ -67,9 +79,3 @@ async def generate_response(text, app_config):
67
  decoded = decoded[start + len(tag):].strip()
68
  break
69
  return decoded, top_conf
70
-
71
- def get_model():
72
- return model
73
-
74
- def get_tokenizer():
75
- return tokenizer
 
3
  from log import log
4
  from pydantic import BaseModel
5
 
6
+ _model = None
7
+ _tokenizer = None
8
+ _eos_token_id = None
9
+
10
+ def get_model():
11
+ return _model
12
+
13
+ def get_tokenizer():
14
+ return _tokenizer
15
+
16
+ def get_eos_token_id():
17
+ return _eos_token_id
18
 
19
  class Message(BaseModel):
20
  user_input: str
21
 
22
  def setup_model(s_config):
23
+ global _model, _tokenizer, _eos_token_id
24
  try:
25
  log("🧠 setup_model() başladı")
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
  log(f"📡 Kullanılan cihaz: {device}")
28
+ _tokenizer = AutoTokenizer.from_pretrained(s_config.MODEL_BASE, use_fast=False)
29
  log("📦 Tokenizer yüklendi. Ana model indiriliyor...")
30
+ _model = AutoModelForCausalLM.from_pretrained(s_config.MODEL_BASE, torch_dtype=torch.float32).to(device)
31
  log("📦 Ana model indirildi ve yüklendi. eval() çağırılıyor...")
32
+ _tokenizer.pad_token = _tokenizer.pad_token or _tokenizer.eos_token
33
+ _model.config.pad_token_id = _tokenizer.pad_token_id
34
+ _eos_token_id = _tokenizer("<|im_end|>", add_special_tokens=False)["input_ids"][0]
35
+ _model.eval()
36
  log("✅ Ana model eval() çağrıldı")
37
  log(f"📦 Intent modeli indiriliyor: {s_config.INTENT_MODEL_ID}")
38
  _ = AutoTokenizer.from_pretrained(s_config.INTENT_MODEL_ID)
 
44
  traceback.print_exc()
45
 
46
  async def generate_response(text, app_config):
47
+ model = get_model()
48
+ tokenizer = get_tokenizer()
49
+ eos_token_id = get_eos_token_id()
50
+
51
  messages = [{"role": "user", "content": text}]
52
  encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
 
53
  input_ids = encodeds.to(model.device)
54
  attention_mask = (input_ids != tokenizer.pad_token_id).long()
55
 
 
59
  attention_mask=attention_mask,
60
  max_new_tokens=128,
61
  do_sample=app_config.USE_SAMPLING,
62
+ eos_token_id=eos_token_id,
63
  pad_token_id=tokenizer.pad_token_id,
64
  return_dict_in_generate=True,
65
  output_scores=True
 
79
  decoded = decoded[start + len(tag):].strip()
80
  break
81
  return decoded, top_conf