Hjgugugjhuhjggg committed
Commit 1cb967f
1 Parent(s): 84d1dae

Update app.py

Files changed (1)
  1. app.py +53 -73
app.py CHANGED
@@ -1,10 +1,9 @@
 import gc
 import psutil
 import os
-import time
 import torch
 from fastapi import FastAPI
-from vllm import VLLM
+from langchain.llms import VLLM
 from chatgptcache import cache
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
@@ -15,13 +14,14 @@ from collections import Counter
 import asyncio
 import torch.nn.utils.prune as prune
 from concurrent.futures import ThreadPoolExecutor
+from langchain.prompts import PromptTemplate
+from langchain.chains import LLMChain
 
 nltk.download('punkt')
 nltk.download('stopwords')
 
 app = FastAPI()
 
-# Define the models (they will be loaded later)
 model_1 = None
 model_2 = None
 model_3 = None
@@ -37,12 +37,10 @@ previous_responses_2 = []
 previous_responses_3 = []
 previous_responses_4 = []
 
-MAX_TOKENS = 2048  # Maximum number of tokens for model input and output
+MAX_TOKENS = 2048
 
-# Use ThreadPoolExecutor for parallel execution
 executor = ThreadPoolExecutor(max_workers=4)
 
-# Device configuration (CPU)
 device = torch.device("cpu")
 
 def get_best_response(new_response, previous_responses):
@@ -90,17 +88,16 @@ def apply_pruning(model)
     for name, module in model.named_modules():
         if isinstance(module, torch.nn.Linear):
             prune.random_unstructured(module, name="weight", amount=0.2)
-            prune.remove(module, name="weight")  # Optional: remove the pruning mask to keep the pruned weights
+            prune.remove(module, name="weight")
     return model
 
 def split_input(input_text, max_tokens):
-    tokens = input_text.split()  # Split the input into words (tokens)
+    tokens = input_text.split()
     chunks = []
     chunk = []
     total_tokens = 0
-
     for word in tokens:
-        word_length = len(word.split())  # Estimate the token length
+        word_length = len(word.split())
         if total_tokens + word_length > max_tokens:
             chunks.append(" ".join(chunk))
             chunk = [word]
@@ -108,20 +105,17 @@ def split_input(input_text, max_tokens):
         else:
             chunk.append(word)
             total_tokens += word_length
-
     if chunk:
-        chunks.append(" ".join(chunk))  # Append the last chunk
-
+        chunks.append(" ".join(chunk))
     return chunks
 
 def split_output(output_text, max_tokens):
-    tokens = output_text.split()  # Split the output into words (tokens)
+    tokens = output_text.split()
     chunks = []
     chunk = []
     total_tokens = 0
-
     for word in tokens:
-        word_length = len(word.split())  # Estimate the token length
+        word_length = len(word.split())
         if total_tokens + word_length > max_tokens:
             chunks.append(" ".join(chunk))
             chunk = [word]
@@ -129,44 +123,48 @@ def split_output(output_text, max_tokens):
         else:
             chunk.append(word)
             total_tokens += word_length
-
     if chunk:
-        chunks.append(" ".join(chunk))  # Append the last chunk
-
+        chunks.append(" ".join(chunk))
     return chunks
 
-async def load_model_async(model_name: str):
-    max_model_len = MAX_TOKENS  # Set the maximum model length (tokens)
-    if model_name == "model_1":
-        return VLLM("Hjgugugjhuhjggg/llama-3.2-1B-spinquant-hf", device=device, max_model_len=max_model_len)
-    elif model_name == "model_2":
-        return VLLM("Qwen/Qwen2.5-Coder-1.5B", device=device, max_model_len=max_model_len)
-    elif model_name == "model_3":
-        return VLLM("Qwen/Qwen2.5-3B-Instruct", device=device, max_model_len=max_model_len)
-    elif model_name == "model_4":
-        return VLLM("gpt2", device=device, max_model_len=max_model_len)
-    return None
+def create_langchain_model(model_name: str, device: torch.device, cache, previous_responses):
+    vllm_llm = VLLM(model_name=model_name, device=device)
+    template = """
+    You are a helpful assistant. Given the following text, generate a meaningful response:
+    {input_text}
+    """
+    prompt = PromptTemplate(input_variables=["input_text"], template=template)
+    chain = LLMChain(llm=vllm_llm, prompt=prompt)
+    def generate_for_model(input_text):
+        cached_output = cache.get(input_text)
+        if cached_output:
+            return cached_output
+        input_chunks = split_input(input_text, MAX_TOKENS)
+        output_text = ""
+        prev_output = ""
+        for chunk in input_chunks:
+            prompt = prev_output + chunk
+            output_text += chain.run(input_text=prompt)
+            prev_output = output_text.split()[-50:]
+        output_chunks = split_output(output_text, MAX_TOKENS)
+        best_response = get_best_response(output_chunks[0], previous_responses)
+        cache.put(input_text, best_response)
+        previous_responses.append(best_response)
+        return best_response
+    return generate_for_model
 
 async def load_models():
     global model_1, model_2, model_3, model_4
-    tasks = [
-        load_model_async("model_1"),
-        load_model_async("model_2"),
-        load_model_async("model_3"),
-        load_model_async("model_4"),
-    ]
-    results = await asyncio.gather(*tasks)
-    model_1, model_2, model_3, model_4 = results
-    model_1 = apply_pruning(model_1)
-    model_2 = apply_pruning(model_2)
-    model_3 = apply_pruning(model_3)
-    model_4 = apply_pruning(model_4)
-    print("Models loaded and pruned successfully.")
+    model_1 = create_langchain_model("Hjgugugjhuhjggg/llama-3.2-1B-spinquant-hf", device, cache_1, previous_responses_1)
+    model_2 = create_langchain_model("Qwen/Qwen2.5-Coder-1.5B", device, cache_2, previous_responses_2)
+    model_3 = create_langchain_model("Qwen/Qwen2.5-3B-Instruct", device, cache_3, previous_responses_3)
+    model_4 = create_langchain_model("gpt2", device, cache_4, previous_responses_4)
+    print("Models loaded successfully.")
 
 async def optimize_models_periodically():
     while True:
-        await load_models()  # Load and optimize the models automatically
-        await asyncio.sleep(3600)  # Optimize the models every hour (adjust the interval as needed)
+        await load_models()
+        await asyncio.sleep(3600)
 
 @app.on_event("startup")
 async def startup():
@@ -181,34 +179,16 @@ async def monitor_memory():
 
 @app.get("/generate")
 async def generate_response(model_name: str, input_text: str):
-    def generate_for_model(model, input_text, cache, previous_responses):
-        cached_output = cache.get(input_text)
-        if cached_output:
-            return cached_output
-
-        input_chunks = split_input(input_text, MAX_TOKENS)
-        output_text = ""
-        prev_output = ""
-
-        for chunk in input_chunks:
-            prompt = prev_output + chunk
-            output_text += model.generate(prompt)
-            prev_output = output_text.split()[-50:]
-
-        output_chunks = split_output(output_text, MAX_TOKENS)
-        best_response = get_best_response(output_chunks[0], previous_responses)
-        cache.put(input_text, best_response)
-        previous_responses.append(best_response)
-        return best_response
-
-    result = await asyncio.get_event_loop().run_in_executor(
-        executor,
-        generate_for_model,
-        model_1 if model_name == "model1" else model_2 if model_name == "model2" else model_3 if model_name == "model3" else model_4,
-        input_text,
-        cache_1 if model_name == "model1" else cache_2 if model_name == "model2" else cache_3 if model_name == "model3" else cache_4,
-        previous_responses_1 if model_name == "model1" else previous_responses_2 if model_name == "model2" else previous_responses_3 if model_name == "model3" else previous_responses_4
-    )
+    if model_name == "model1":
+        result = await asyncio.get_event_loop().run_in_executor(executor, model_1, input_text)
+    elif model_name == "model2":
+        result = await asyncio.get_event_loop().run_in_executor(executor, model_2, input_text)
+    elif model_name == "model3":
+        result = await asyncio.get_event_loop().run_in_executor(executor, model_3, input_text)
+    elif model_name == "model4":
+        result = await asyncio.get_event_loop().run_in_executor(executor, model_4, input_text)
+    else:
+        return {"error": "Model not found"}
     return {f"{model_name}_output": result}
 
 @app.get("/unified_summary")