Yjhhh committed on
Commit
9967326
verified
1 Parent(s): 03ca244

Update app.py

Files changed (1)
  1. app.py +101 -138
app.py CHANGED
@@ -1,19 +1,12 @@
- from fastapi import FastAPI, HTTPException
  from pydantic import BaseModel
- from llama_cpp import Llama
- from concurrent.futures import ThreadPoolExecutor, as_completed
- from tqdm import tqdm
  import uvicorn
- from dotenv import load_dotenv
- import io
  import requests
  import asyncio
- import time
-
- # Load environment variables
- load_dotenv()

- # Initialize the FastAPI application
  app = FastAPI()

  # Model configuration
@@ -31,149 +24,119 @@ model_configs = [
      {"repo_id": "Ffftdtd5dtft/gemma-2-2b-it-Q2_K-GGUF", "filename": "gemma-2-2b-it-q2_k.gguf", "name": "Gemma 2-2B IT"},
      {"repo_id": "Ffftdtd5dtft/sarvam-2b-v0.5-Q2_K-GGUF", "filename": "sarvam-2b-v0.5-q2_k.gguf", "name": "Sarvam 2B v0.5"},
      {"repo_id": "Ffftdtd5dtft/WizardLM-13B-Uncensored-Q2_K-GGUF", "filename": "wizardlm-13b-uncensored-q2_k.gguf", "name": "WizardLM 13B Uncensored"},
      {"repo_id": "Ffftdtd5dtft/WizardLM-7B-Uncensored-Q2_K-GGUF", "filename": "wizardlm-7b-uncensored-q2_k.gguf", "name": "WizardLM 7B Uncensored"},
      {"repo_id": "Ffftdtd5dtft/Qwen2-Math-7B-Instruct-Q2_K-GGUF", "filename": "qwen2-math-7b-instruct-q2_k.gguf", "name": "Qwen2 Math 7B Instruct"}
  ]

- # Class to manage the models
  class ModelManager:
      def __init__(self):
-         self.models = []
-         self.configs = {}

      async def download_model_to_memory(self, model_config):
-         print(f"Downloading model: {model_config['name']}...")
          url = f"https://huggingface.co/{model_config['repo_id']}/resolve/main/{model_config['filename']}"
-         response = requests.get(url)
-         if response.status_code == 200:
-             model_file = io.BytesIO(response.content)
-             return model_file
-         else:
-             raise Exception(f"Error downloading the model: {response.status_code}")

      async def load_model(self, model_config):
          try:
-             start_time = time.time()
-             model_file = await self.download_model_to_memory(model_config)
-             print(f"Loading model: {model_config['name']}...")
-
-             # Load the model using llama_cpp
-             llama = await asyncio.get_event_loop().run_in_executor(
                  None,
-                 lambda: Llama.from_pretrained(model_file)
              )
-
-             # Simulate splitting the load if it takes more than 1 second
-             async def load_part(part):
-                 # This function simulates loading one part of the model
-                 await asyncio.sleep(0.1)  # Simulates a small loading delay
-
-             if time.time() - start_time > 1:
-                 print(f"Model {model_config['name']} took more than 1 second to load, splitting the load...")
-                 await asyncio.gather(*(load_part(part) for part in range(5)))  # Simulated split into 5 parts
-
-             tokenizer = llama.tokenizer
-
-             # Keep the tokens and tokenizer in RAM
-             model_data = {
-                 'model': llama,
-                 'tokenizer': tokenizer,
-                 'pad_token': tokenizer.pad_token,
-                 'pad_token_id': tokenizer.pad_token_id,
-                 'eos_token': tokenizer.eos_token,
-                 'eos_token_id': tokenizer.eos_token_id,
-                 'bos_token': tokenizer.bos_token,
-                 'bos_token_id': tokenizer.bos_token_id,
-                 'unk_token': tokenizer.unk_token,
-                 'unk_token_id': tokenizer.unk_token_id
-             }
-
-             self.models.append({"model_data": model_data, "name": model_config['name']})
          except Exception as e:
-             print(f"Error loading the model: {e}")
-
-     async def load_all_models(self):
-         print("Starting to load the models...")
-         start_time = time.time()
-         tasks = [self.load_model(config) for config in model_configs]
-         await asyncio.gather(*tasks)
-         end_time = time.time()
-         print(f"All models were loaded in {end_time - start_time:.2f} seconds.")
-
- # Instantiate the ModelManager and load the models
- model_manager = ModelManager()
-
- @app.on_event("startup")
- async def startup_event():
-     await model_manager.load_all_models()
-
- # Global model for the chat request
- class ChatRequest(BaseModel):
-     message: str
-     top_k: int = 50
-     top_p: float = 0.95
-     temperature: float = 0.7
-
- # Token limit for responses
- TOKEN_LIMIT = 1000  # Defines the token limit allowed per response
-
- # Function to generate chat responses
- async def generate_chat_response(request, model_data):
      try:
-         user_input = normalize_input(request.message)
-         llama = model_data['model_data']['model']
-         tokenizer = model_data['model_data']['tokenizer']
-
-         # Generate the response quickly
-         response = await asyncio.get_event_loop().run_in_executor(
-             None,
-             lambda: llama.generate(user_input, max_length=TOKEN_LIMIT, do_sample=True, top_k=request.top_k, top_p=request.top_p, temperature=request.temperature)
-         )
-         generated_text = response['generated_text']
-         # Split a long response
-         split_response = split_long_response(generated_text)
-         return {"response": split_response, "literal": user_input, "model_name": model_data['name']}
      except Exception as e:
-         print(f"Error generating the response: {e}")
-         return {"response": "Error generating the response", "literal": user_input, "model_name": model_data['name']}
-
- def split_long_response(response):
-     """Split the response into smaller parts if it exceeds the token limit."""
-     parts = []
-     while len(response) > TOKEN_LIMIT:
-         part = response[:TOKEN_LIMIT]
-         response = response[TOKEN_LIMIT:]
-         parts.append(part.strip())
-     if response:
-         parts.append(response.strip())
-     return '\n'.join(parts)
-
- def remove_duplicates(text):
-     """Remove duplicate lines from the text."""
-     lines = text.splitlines()
-     unique_lines = list(dict.fromkeys(lines))
-     return '\n'.join(unique_lines)
-
- def remove_repetitive_responses(responses):
-     unique_responses = []
-     seen_responses = set()
-     for response in responses:
-         normalized_response = remove_duplicates(response['response'])
-         if normalized_response not in seen_responses:
-             seen_responses.add(normalized_response)
-             response['response'] = normalized_response
-             unique_responses.append(response)
-     return unique_responses

- @app.post("/chat")
- async def chat(request: ChatRequest):
-     results = []
-     for model_data in model_manager.models:
-         response = await generate_chat_response(request, model_data)
-         results.append(response)
-     unique_results = remove_repetitive_responses(results)
-     return {"results": unique_results}
-
- # Run the FastAPI application
- if __name__ == "__main__":
      uvicorn.run(app, host="0.0.0.0", port=7860)
+ from fastapi import FastAPI, HTTPException, Request
  from pydantic import BaseModel
  import uvicorn
  import requests
+ import io
  import asyncio
+ from typing import List, Dict, Any
+ from llama_cpp import Llama  # Adjust according to the library you are using

  app = FastAPI()

  # Model configuration
      {"repo_id": "Ffftdtd5dtft/gemma-2-2b-it-Q2_K-GGUF", "filename": "gemma-2-2b-it-q2_k.gguf", "name": "Gemma 2-2B IT"},
      {"repo_id": "Ffftdtd5dtft/sarvam-2b-v0.5-Q2_K-GGUF", "filename": "sarvam-2b-v0.5-q2_k.gguf", "name": "Sarvam 2B v0.5"},
      {"repo_id": "Ffftdtd5dtft/WizardLM-13B-Uncensored-Q2_K-GGUF", "filename": "wizardlm-13b-uncensored-q2_k.gguf", "name": "WizardLM 13B Uncensored"},
+     {"repo_id": "Ffftdtd5dtft/Qwen2-Math-72B-Instruct-Q2_K-GGUF", "filename": "qwen2-math-72b-instruct-q2_k.gguf", "name": "Qwen2 Math 72B Instruct"},
      {"repo_id": "Ffftdtd5dtft/WizardLM-7B-Uncensored-Q2_K-GGUF", "filename": "wizardlm-7b-uncensored-q2_k.gguf", "name": "WizardLM 7B Uncensored"},
      {"repo_id": "Ffftdtd5dtft/Qwen2-Math-7B-Instruct-Q2_K-GGUF", "filename": "qwen2-math-7b-instruct-q2_k.gguf", "name": "Qwen2 Math 7B Instruct"}
  ]

  class ModelManager:
      def __init__(self):
+         self.models = {}
+         self.model_parts = {}
+         self.load_lock = asyncio.Lock()
+         self.index_lock = asyncio.Lock()
+         self.part_size = 1024 * 1024  # Size of each part in bytes (1 MB)

      async def download_model_to_memory(self, model_config):
          url = f"https://huggingface.co/{model_config['repo_id']}/resolve/main/{model_config['filename']}"
+         try:
+             response = requests.get(url)
+             response.raise_for_status()
+             return io.BytesIO(response.content)
+         except requests.RequestException as e:
+             raise HTTPException(status_code=500, detail=f"Error downloading the model: {e}")

      async def load_model(self, model_config):
+         async with self.load_lock:
+             try:
+                 model_file = await self.download_model_to_memory(model_config)
+                 llama = Llama(model_file)  # Adjust according to the correct library and class
+
+                 tokenizer = llama.tokenizer
+                 model_data = {
+                     'model': llama,
+                     'tokenizer': tokenizer,
+                     'pad_token': tokenizer.pad_token,
+                     'pad_token_id': tokenizer.pad_token_id,
+                     'eos_token': tokenizer.eos_token,
+                     'eos_token_id': tokenizer.eos_token_id,
+                     'bos_token': tokenizer.bos_token,
+                     'bos_token_id': tokenizer.bos_token_id,
+                     'unk_token': tokenizer.unk_token,
+                     'unk_token_id': tokenizer.unk_token_id
+                 }
+
+                 self.models[model_config['name']] = model_data
+                 await self.handle_large_model(model_config, model_file)
+             except Exception as e:
+                 print(f"Error loading the model: {e}")
+
+     async def handle_large_model(self, model_config, model_file):
+         total_size = len(model_file.getvalue())
+         num_parts = (total_size + self.part_size - 1) // self.part_size
+
+         for i in range(num_parts):
+             start = i * self.part_size
+             end = min(start + self.part_size, total_size)
+             model_part = io.BytesIO(model_file.getvalue()[start:end])
+             await self.index_model_part(model_part, i)
+
+     async def index_model_part(self, model_part, part_index):
+         async with self.index_lock:
+             part_name = f"part_{part_index}"
+             llama_part = Llama(model_part)
+             self.model_parts[part_name] = llama_part
+
+     async def generate_response(self, user_input):
+         tasks = [self.generate_chat_response(user_input, model_data) for model_data in self.models.values()]
+         responses = await asyncio.gather(*tasks)
+         return responses
+
+     async def generate_chat_response(self, user_input, model_data):
          try:
+             llama = model_data['model']
+             tokenizer = model_data['tokenizer']
+
+             response = await asyncio.get_event_loop().run_in_executor(
                  None,
+                 lambda: llama.generate(user_input, max_length=1000, do_sample=True)
              )
+             generated_text = response['generated_text']
+
+             # Split the generated text into parts if necessary
+             parts = []
+             while len(generated_text) > 1000:
+                 part = generated_text[:1000]
+                 generated_text = generated_text[1000:]
+                 parts.append(part.strip())
+             if generated_text:
+                 parts.append(generated_text.strip())
+
+             return {"response": '\n'.join(parts), "model_name": model_data['name']}
          except Exception as e:
+             print(f"Error generating the response: {e}")
+             return {"response": "Error generating the response", "model_name": model_data['name']}
+
+ @app.post("/chat")
+ async def chat(request: Request):
+     body = await request.json()
+     user_input = body.get('message', '').strip()
+     if not user_input:
+         raise HTTPException(status_code=400, detail="The message cannot be empty.")
+
      try:
+         model_manager = ModelManager()
+         responses = await model_manager.generate_response(user_input)
+         return {"responses": responses}
      except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))

+ def start_uvicorn():
      uvicorn.run(app, host="0.0.0.0", port=7860)
+
+ if __name__ == "__main__":
+     loop = asyncio.get_event_loop()
+     model_manager = ModelManager()
+     tasks = [model_manager.load_model(config) for config in model_configs]
+     loop.run_until_complete(asyncio.gather(*tasks))
+     start_uvicorn()
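
For reference, here is a minimal client sketch (not part of the commit) for exercising the /chat endpoint added in this version of app.py. It assumes the server is running locally on port 7860, as configured in start_uvicorn(), and that the endpoint reads a JSON body with a "message" field and returns a "responses" list with one entry per loaded model; the helper name ask_models is hypothetical.

# Hypothetical usage sketch: query the /chat endpoint of the running app.
# Assumes app.py is serving on http://localhost:7860 (see start_uvicorn()).
import requests

def ask_models(message: str) -> list:
    resp = requests.post(
        "http://localhost:7860/chat",
        json={"message": message},   # the endpoint reads body["message"]
        timeout=300,                 # model inference can be slow
    )
    resp.raise_for_status()
    return resp.json()["responses"]  # one entry per loaded model

if __name__ == "__main__":
    for item in ask_models("What is 2 + 2?"):
        print(item["model_name"], "->", item["response"][:200])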