tori29umai commited on
Commit
397e88b
1 Parent(s): fae7c5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -9
app.py CHANGED
@@ -28,13 +28,12 @@ if not os.path.exists("models"):
28
  os.makedirs("models")
29
 
30
  # 使用するモデルのファイル名を指定
31
- model_filename = "EZO-Common-9B-gemma-2-it.f16.gguf.gguf"
32
  model_path = os.path.join("models", model_filename)
33
 
34
  # モデルファイルが存在しない場合はダウンロード
35
  if not os.path.exists(model_path):
36
- dl_guff_model("models", f"https://huggingface.co/Aratako/Ninja-v1-RP-expressive-v2-GGUF/resolve/main/{model_filename}")
37
-
38
 
39
  class ConfigManager:
40
  @staticmethod
@@ -228,8 +227,10 @@ class GenTextParams:
228
 
229
  class LlamaAdapter:
230
  def __init__(self, model_path, params, n_gpu_layers):
231
- self.llm = Llama(model_path=model_path, n_ctx=params.chat_n_ctx, n_gpu_layers=n_gpu_layers)
232
  self.params = params
 
 
233
 
234
  def generate_text(self, text, author_description, gen_characters, gen_token_multiplier, instruction):
235
  max_tokens = int(gen_characters * gen_token_multiplier)
@@ -295,7 +296,6 @@ def load_model_gpu(model_type, model_path, n_gpu_layers, params):
295
  print(f"{model_type} モデル {model_path} のロードが完了しました。(n_gpu_layers: {n_gpu_layers})")
296
  return llama
297
 
298
-
299
  class CharacterMaker:
300
  def __init__(self):
301
  self.llama = None
@@ -309,8 +309,16 @@ class CharacterMaker:
309
 
310
  def load_model(self, model_type):
311
  with self.model_lock:
312
- if self.current_model == model_type:
313
- return
 
 
 
 
 
 
 
 
314
 
315
  self.model_loaded.clear()
316
  if self.llama:
@@ -318,8 +326,6 @@ class CharacterMaker:
318
  self.llama = None
319
 
320
  try:
321
- model_path = os.path.join(MODEL_DIR, self.settings[f'DEFAULT_{model_type.upper()}_MODEL'])
322
- n_gpu_layers = self.settings[f'{model_type.lower()}_n_gpu_layers']
323
  self.llama = load_model_gpu(model_type, model_path, n_gpu_layers, params)
324
  self.current_model = model_type
325
  self.model_loaded.set()
@@ -327,6 +333,7 @@ class CharacterMaker:
327
  print(f"{model_type} モデルのロード中にエラーが発生しました: {str(e)}")
328
  self.model_loaded.set()
329
 
 
330
  def generate_response(self, input_str):
331
  self.load_model('CHAT')
332
  if not self.model_loaded.wait(timeout=30) or not self.llama:
 
28
  os.makedirs("models")
29
 
30
  # 使用するモデルのファイル名を指定
31
+ model_filename = "EZO-Common-9B-gemma-2-it.f16.gguf"
32
  model_path = os.path.join("models", model_filename)
33
 
34
  # モデルファイルが存在しない場合はダウンロード
35
  if not os.path.exists(model_path):
36
+ dl_guff_model("models", f"https://huggingface.co/MCZK/EZO-Common-9B-gemma-2-it-GGUF/resolve/main//{model_filename}")
 
37
 
38
  class ConfigManager:
39
  @staticmethod
 
227
 
228
  class LlamaAdapter:
229
  def __init__(self, model_path, params, n_gpu_layers):
230
+ self.model_path = model_path
231
  self.params = params
232
+ self.n_gpu_layers = n_gpu_layers
233
+ self.llm = Llama(model_path=model_path, n_ctx=params.chat_n_ctx, n_gpu_layers=n_gpu_layers)
234
 
235
  def generate_text(self, text, author_description, gen_characters, gen_token_multiplier, instruction):
236
  max_tokens = int(gen_characters * gen_token_multiplier)
 
296
  print(f"{model_type} モデル {model_path} のロードが完了しました。(n_gpu_layers: {n_gpu_layers})")
297
  return llama
298
 
 
299
  class CharacterMaker:
300
  def __init__(self):
301
  self.llama = None
 
309
 
310
  def load_model(self, model_type):
311
  with self.model_lock:
312
+ model_path = os.path.join(MODEL_DIR, self.settings[f'DEFAULT_{model_type.upper()}_MODEL'])
313
+ n_gpu_layers = self.settings[f'{model_type.lower()}_n_gpu_layers']
314
+
315
+ # 現在のモデルが既にロードされているか、同じ設定であるかチェック
316
+ if self.llama and self.current_model == model_type:
317
+ if (self.llama.model_path == model_path and
318
+ self.llama.n_gpu_layers == n_gpu_layers):
319
+ print(f"{model_type} モデルは既にロードされています。再ロードをスキップします。")
320
+ self.model_loaded.set()
321
+ return
322
 
323
  self.model_loaded.clear()
324
  if self.llama:
 
326
  self.llama = None
327
 
328
  try:
 
 
329
  self.llama = load_model_gpu(model_type, model_path, n_gpu_layers, params)
330
  self.current_model = model_type
331
  self.model_loaded.set()
 
333
  print(f"{model_type} モデルのロード中にエラーが発生しました: {str(e)}")
334
  self.model_loaded.set()
335
 
336
+
337
  def generate_response(self, input_str):
338
  self.load_model('CHAT')
339
  if not self.model_loaded.wait(timeout=30) or not self.llama: