Update app.py
app.py
CHANGED
@@ -1,4 +1,5 @@
 import os
 import torch
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
@@ -10,7 +11,6 @@ from transformers import (
     GenerationConfig,
     StoppingCriteriaList,
     StoppingCriteria,
-    TextStreamer,
     pipeline
 )
 import uvicorn
@@ -19,8 +19,28 @@ from io import BytesIO
 import soundfile as sf
 import traceback

 app = FastAPI()

 class GenerateRequest(BaseModel):
     model_name: str
     input_text: str = ""
@@ -54,16 +74,17 @@ class LocalModelLoader:
         self.loaded_models = {}

     async def load_model_and_tokenizer(self, model_name):
         if model_name in self.loaded_models:
             return self.loaded_models[model_name]
         try:
             config = AutoConfig.from_pretrained(model_name)
             tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
-
-
             if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
                 tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
-
             self.loaded_models[model_name] = (model, tokenizer)
             return model, tokenizer
         except Exception as e:
@@ -81,10 +102,10 @@ class StopOnTokens(StoppingCriteria):
                 return True
         return False

-
 @app.post("/generate")
 async def generate(request: GenerateRequest):
     try:
         model_name = request.model_name
         input_text = request.input_text
         task_type = request.task_type
@@ -119,66 +140,79 @@ async def generate(request: GenerateRequest):
         stopping_criteria_list = StoppingCriteriaList([StopOnTokens(stop_token_ids)]) if stop_token_ids else None

         if stream:
-
                 stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay),
                 media_type="text/plain"
             )
         else:
-            generated_text = generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device)
-

     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

-async def stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, max_length=
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
     encoded_input_len = encoded_input["input_ids"].shape[-1]

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
-
-
-
-
-
-
-
-
     generated_text = tokenizer.decode(output.sequences[0][encoded_input["input_ids"].shape[-1]:], skip_special_tokens=True)
     return generated_text

-
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
     try:
         validated_body = request
-        device =

-
-

         img_byte_arr = BytesIO()
         image.save(img_byte_arr, format="PNG")
         img_byte_arr.seek(0)
-
         return StreamingResponse(img_byte_arr, media_type="image/png")
-
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@@ -186,18 +220,19 @@ async def generate_image(request: GenerateRequest):
 async def generate_text_to_speech(request: GenerateRequest):
     try:
         validated_body = request
-        device =

-
-
-

         audio_byte_arr = BytesIO()
         sf.write(audio_byte_arr, audio, sampling_rate, format='WAV')
         audio_byte_arr.seek(0)
-
         return StreamingResponse(audio_byte_arr, media_type="audio/wav")
-
     except Exception as e:
         traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@@ -206,19 +241,20 @@ async def generate_text_to_speech(request: GenerateRequest):
 async def generate_video(request: GenerateRequest):
     try:
         validated_body = request
-        device =
-
-

         video_byte_arr = BytesIO()
         video.save(video_byte_arr)
         video_byte_arr.seek(0)
-
         return StreamingResponse(video_byte_arr, media_type="video/mp4")
-
     except Exception as e:
         traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=7860)

 import os
+import gc
 import torch
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse

     GenerationConfig,
     StoppingCriteriaList,
     StoppingCriteria,
     pipeline
 )
 import uvicorn

 import soundfile as sf
 import traceback

+# --- Block that limits RAM usage to 1% (Unix environments only) ---
+try:
+    import psutil
+    import resource
+    total_memory = psutil.virtual_memory().total
+    limit = int(total_memory * 0.01)  # 1% of the total, in bytes
+    resource.setrlimit(resource.RLIMIT_AS, (limit, limit))
+    print(f"Memory limit set to {limit} bytes (1% of total system memory).")
+except Exception as e:
+    print("Could not set the memory limit:", e)
+# --- End of RAM-limit block ---
+
 app = FastAPI()

+# Asynchronous helper that frees memory (RAM and the CUDA cache)
+async def cleanup_memory(device: str):
+    gc.collect()
+    if device == "cuda":
+        torch.cuda.empty_cache()
+    # Brief pause to let the memory actually be released
+    await asyncio.sleep(0.01)
+
 class GenerateRequest(BaseModel):
     model_name: str
     input_text: str = ""

         self.loaded_models = {}

     async def load_model_and_tokenizer(self, model_name):
+        # Use the model requested by the user
         if model_name in self.loaded_models:
             return self.loaded_models[model_name]
         try:
             config = AutoConfig.from_pretrained(model_name)
             tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
+            # torch_dtype=torch.float16 reduces the memory footprint (when supported)
+            model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16)
+            # Set the padding token if needed
             if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
                 tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
             self.loaded_models[model_name] = (model, tokenizer)
             return model, tokenizer
         except Exception as e:
                 return True
         return False

 @app.post("/generate")
 async def generate(request: GenerateRequest):
     try:
+        # Extract the parameters from the request
         model_name = request.model_name
         input_text = request.input_text
         task_type = request.task_type

         stopping_criteria_list = StoppingCriteriaList([StopOnTokens(stop_token_ids)]) if stop_token_ids else None

         if stream:
+            # StreamingResponse wraps the async generator that sends each token in real time.
+            response = StreamingResponse(
                 stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay),
                 media_type="text/plain"
             )
         else:
+            generated_text = await generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device)
+            response = StreamingResponse(iter([generated_text]), media_type="text/plain")

+        await cleanup_memory(device)
+        return response
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

+async def stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, max_length=64):
+    """
+    Generate tokens asynchronously and send them to the client in real time.
+    """
+    # Limit the input length to minimize memory usage
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
     encoded_input_len = encoded_input["input_ids"].shape[-1]

+    # torch.no_grad() avoids storing information for gradients
+    with torch.no_grad():
+        # Generate the text iteratively (streaming)
+        for output in model.generate(
+            **encoded_input,
+            generation_config=generation_config,
+            stopping_criteria=stopping_criteria_list,
+            stream=True,
+            return_dict_in_generate=True,
+            output_scores=True,
+        ):
+            # Keep only the newly generated tokens (excluding the input)
+            new_tokens = output.sequences[:, encoded_input_len:]
+            for token_batch in new_tokens:
+                token = tokenizer.decode(token_batch, skip_special_tokens=True)
+                if token:
+                    # Send each token as soon as it is available
+                    yield token
+                    await asyncio.sleep(chunk_delay)
+    await cleanup_memory(device)
+
+async def generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, max_length=64):
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
+    with torch.no_grad():
+        output = model.generate(
+            **encoded_input,
+            generation_config=generation_config,
+            stopping_criteria=stopping_criteria_list,
+            return_dict_in_generate=True,
+            output_scores=True
+        )
     generated_text = tokenizer.decode(output.sequences[0][encoded_input["input_ids"].shape[-1]:], skip_special_tokens=True)
+    await cleanup_memory(device)
     return generated_text

 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
     try:
         validated_body = request
+        device = 0 if torch.cuda.is_available() else -1  # pipeline expects an int for CUDA

+        # Run the pipeline in a separate thread
+        image_generator = await asyncio.to_thread(pipeline, "text-to-image", model=validated_body.model_name, device=device)
+        results = await asyncio.to_thread(image_generator, validated_body.input_text)
+        image = results[0]

         img_byte_arr = BytesIO()
         image.save(img_byte_arr, format="PNG")
         img_byte_arr.seek(0)
+        await cleanup_memory("cuda" if device == 0 else "cpu")
         return StreamingResponse(img_byte_arr, media_type="image/png")
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

 async def generate_text_to_speech(request: GenerateRequest):
     try:
         validated_body = request
+        device = 0 if torch.cuda.is_available() else -1

+        # Run the pipeline in a separate thread
+        tts_generator = await asyncio.to_thread(pipeline, "text-to-speech", model=validated_body.model_name, device=device)
+        tts_results = await asyncio.to_thread(tts_generator, validated_body.input_text)
+        audio = tts_results
+        sampling_rate = tts_generator.sampling_rate

         audio_byte_arr = BytesIO()
         sf.write(audio_byte_arr, audio, sampling_rate, format='WAV')
         audio_byte_arr.seek(0)
+        await cleanup_memory("cuda" if device == 0 else "cpu")
         return StreamingResponse(audio_byte_arr, media_type="audio/wav")
     except Exception as e:
         traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

 async def generate_video(request: GenerateRequest):
     try:
         validated_body = request
+        device = 0 if torch.cuda.is_available() else -1
+
+        # Run the pipeline in a separate thread
+        video_generator = await asyncio.to_thread(pipeline, "text-to-video", model=validated_body.model_name, device=device)
+        video = await asyncio.to_thread(video_generator, validated_body.input_text)

         video_byte_arr = BytesIO()
         video.save(video_byte_arr)
         video_byte_arr.seek(0)
+        await cleanup_memory("cuda" if device == 0 else "cpu")
         return StreamingResponse(video_byte_arr, media_type="video/mp4")
     except Exception as e:
         traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

 if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=7860)
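
For reference, a minimal client sketch for the /generate endpoint changed in this commit (not part of the diff itself). It assumes the server is running locally on port 7860 as configured above, and that GenerateRequest also exposes the stream flag and task_type value used in the handler; the model id and field values below are placeholders.

# Hypothetical client; only model_name and input_text are confirmed by the diff,
# the remaining fields are assumptions about the full GenerateRequest schema.
import requests

payload = {
    "model_name": "gpt2",          # placeholder Hugging Face model id
    "input_text": "Hello, world",  # prompt
    "task_type": "text",           # assumed value; only the field name appears in the diff
    "stream": True,                # assumed request field driving the streaming branch
}

# With stream=True the handler returns a text/plain StreamingResponse,
# so the body can be consumed chunk by chunk as tokens arrive.
with requests.post("http://localhost:7860/generate", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)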