Yhhxhfh committed
Commit 05c34a8 · verified · 1 Parent(s): ae5e30e

Update app.py

Files changed (1)
  app.py +39 -13
app.py CHANGED
@@ -1,6 +1,6 @@
 from pydantic import BaseModel
 from llama_cpp import Llama
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
 import re
 import gradio as gr
 import os
@@ -9,6 +9,8 @@ from functools import lru_cache
 from dotenv import load_dotenv
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse
+from queue import Queue
+import pickle  # for persistence

 urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

@@ -35,21 +37,43 @@ global_data = {
 }

 response_cache = {}
+model_cache_dir = "model_cache"  # directory for caching models on disk
+os.makedirs(model_cache_dir, exist_ok=True)

 class ModelManager:
-    def __init__(self):
+    def __init__(self, max_models=10):
         self.models = {}
+        self.max_models = max_models
+        self.model_cache_dir = model_cache_dir
+
     def load_model(self, model_config):
         model_name = model_config['name']
+        cache_file = os.path.join(self.model_cache_dir, f"{model_name}.pkl")
         if model_name not in self.models:
             try:
-                self.models[model_name] = Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename'], use_auth_token=HUGGINGFACE_TOKEN)
-            except Exception:
+                if os.path.exists(cache_file):
+                    with open(cache_file, "rb") as f:
+                        self.models[model_name] = pickle.load(f)
+                    print(f"Model {model_name} loaded from cache.")
+                else:
+                    self.models[model_name] = Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename'], use_auth_token=HUGGINGFACE_TOKEN)
+                    with open(cache_file, "wb") as f:
+                        pickle.dump(self.models[model_name], f)
+                    print(f"Model {model_name} loaded and saved to cache.")
+            except Exception as e:
+                print(f"Error loading model {model_name}: {e}")
                 self.models[model_name] = None
+
+    def get_model(self, model_name):
+        return self.models.get(model_name)
+
     def unload_model(self, model_name):
         if model_name in self.models and self.models[model_name] is not None:
+            cache_file = os.path.join(self.model_cache_dir, f"{model_name}.pkl")
+            with open(cache_file, "wb") as f:
+                pickle.dump(self.models[model_name], f)
             del self.models[model_name]
-
+            print(f"Model {model_name} unloaded and saved to cache.")

 model_manager = ModelManager()

@@ -86,19 +110,21 @@ async def process_message(message):
         return response_cache[inputs]

     responses = {}
-    for config in global_data['model_configs']:
-        model_name = config['name']
-        model_manager.load_model(config)
-        model = model_manager.models.get(model_name)
-        if model:
-            responses[model_name] = generate_model_response(model, inputs)
-            model_manager.unload_model(model_name) #Unload immediately after use
+    with ThreadPoolExecutor(max_workers=model_manager.max_models) as executor:
+        futures = [executor.submit(model_manager.load_model, config) for config in global_data['model_configs']]
+        for future in as_completed(futures):
+            future.result()
+
+    for config in global_data['model_configs']:
+        model = model_manager.get_model(config['name'])
+        if model:
+            responses[config['name']] = generate_model_response(model, inputs)
+            model_manager.unload_model(config['name'])

     formatted_response = "\n\n".join([f"**{model}:**\n{response}" for model, response in responses.items()])
     response_cache[inputs] = formatted_response
     return formatted_response

-
 @app.post("/generate_multimodel")
 async def api_generate_multimodel(request: Request):
     try:
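
For reference, a minimal usage sketch of the revised loading flow (not part of the commit). The config entry, repo_id, and filename below are hypothetical placeholders; global_data, model_manager, and generate_model_response are assumed to be defined in app.py as shown in the diff above.

# Hypothetical sketch only: placeholder model config with a single entry.
global_data['model_configs'] = [
    {'name': 'demo-model', 'repo_id': 'example-org/example-repo', 'filename': 'model.gguf'},
]

# Load every configured model in parallel, as process_message now does.
with ThreadPoolExecutor(max_workers=model_manager.max_models) as executor:
    futures = [executor.submit(model_manager.load_model, cfg) for cfg in global_data['model_configs']]
    for future in as_completed(futures):
        future.result()

# Query each model, then unload it (unload_model re-pickles it into model_cache/).
for cfg in global_data['model_configs']:
    model = model_manager.get_model(cfg['name'])
    if model:
        print(generate_model_response(model, "Hello"))
    model_manager.unload_model(cfg['name'])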