Hjgugugjhuhjggg committed
Commit 3ef7ee3 (verified)
1 Parent(s): bd394da

Update app.py

Files changed (1)
app.py +88 -52
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import gc
 import torch
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
@@ -10,7 +11,6 @@ from transformers import (
     GenerationConfig,
     StoppingCriteriaList,
     StoppingCriteria,
-    TextStreamer,
     pipeline
 )
 import uvicorn
@@ -19,8 +19,28 @@ from io import BytesIO
 import soundfile as sf
 import traceback

+# --- Block that limits RAM to 1% of the total (Unix environments only) ---
+try:
+    import psutil
+    import resource
+    total_memory = psutil.virtual_memory().total
+    limit = int(total_memory * 0.01)  # 1% of the total, in bytes
+    resource.setrlimit(resource.RLIMIT_AS, (limit, limit))
+    print(f"Memory limit set to {limit} bytes (1% of total system memory).")
+except Exception as e:
+    print("Could not set the memory limit:", e)
+# --- End of the RAM limit block ---
+
 app = FastAPI()

+# Async helper that frees memory (RAM and the CUDA cache)
+async def cleanup_memory(device: str):
+    gc.collect()
+    if device == "cuda":
+        torch.cuda.empty_cache()
+    # Brief pause to give the memory time to be released
+    await asyncio.sleep(0.01)
+
 class GenerateRequest(BaseModel):
     model_name: str
     input_text: str = ""
@@ -54,16 +74,17 @@ class LocalModelLoader:
         self.loaded_models = {}

     async def load_model_and_tokenizer(self, model_name):
+        # Use the model requested by the user
         if model_name in self.loaded_models:
             return self.loaded_models[model_name]
         try:
             config = AutoConfig.from_pretrained(model_name)
             tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
-            model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
-
+            # torch_dtype=torch.float16 is used to reduce the memory footprint (when possible)
+            model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16)
+            # Set the padding token if necessary
             if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
                 tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
-
             self.loaded_models[model_name] = (model, tokenizer)
             return model, tokenizer
         except Exception as e:
@@ -81,10 +102,10 @@ class StopOnTokens(StoppingCriteria):
                 return True
         return False

-
 @app.post("/generate")
 async def generate(request: GenerateRequest):
     try:
+        # Extract the parameters from the request
         model_name = request.model_name
         input_text = request.input_text
         task_type = request.task_type
@@ -119,66 +140,79 @@ async def generate(request: GenerateRequest):
         stopping_criteria_list = StoppingCriteriaList([StopOnTokens(stop_token_ids)]) if stop_token_ids else None

         if stream:
-            return StreamingResponse(
+            # Use StreamingResponse with the async function that sends each token in real time.
+            response = StreamingResponse(
                 stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay),
                 media_type="text/plain"
             )
         else:
-            generated_text = generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device)
-            return StreamingResponse(iter([generated_text]), media_type="text/plain")
+            generated_text = await generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device)
+            response = StreamingResponse(iter([generated_text]), media_type="text/plain")

+        await cleanup_memory(device)
+        return response
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

-async def stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, max_length=2048):
+async def stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, max_length=64):
+    """
+    Generates tokens asynchronously and sends them to the client in real time.
+    """
+    # Limit the input to minimize memory usage
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
     encoded_input_len = encoded_input["input_ids"].shape[-1]

-    for output in model.generate(
-        **encoded_input,
-        generation_config=generation_config,
-        stopping_criteria=stopping_criteria_list,
-        stream=True,
-        return_dict_in_generate=True,
-        output_scores=True,
-    ):
-        new_tokens = output.sequences[:, encoded_input_len:]
-        for token_batch in new_tokens:
-            token = tokenizer.decode(token_batch, skip_special_tokens=True)
-            if token:
-                yield token
-                await asyncio.sleep(chunk_delay)
-
-
-async def generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, max_length=2048):
+    # torch.no_grad() avoids storing gradient information
+    with torch.no_grad():
+        # Generate the text iteratively (streaming)
+        for output in model.generate(
+            **encoded_input,
+            generation_config=generation_config,
+            stopping_criteria=stopping_criteria_list,
+            stream=True,
+            return_dict_in_generate=True,
+            output_scores=True,
+        ):
+            # Extract only the generated tokens (excluding the input)
+            new_tokens = output.sequences[:, encoded_input_len:]
+            for token_batch in new_tokens:
+                token = tokenizer.decode(token_batch, skip_special_tokens=True)
+                if token:
+                    # Send each token immediately
+                    yield token
+                    await asyncio.sleep(chunk_delay)
+    await cleanup_memory(device)
+
+async def generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, max_length=64):
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
-
-    output = model.generate(
-        **encoded_input,
-        generation_config=generation_config,
-        stopping_criteria=stopping_criteria_list,
-        return_dict_in_generate=True,
-        output_scores=True
-    )
+    with torch.no_grad():
+        output = model.generate(
+            **encoded_input,
+            generation_config=generation_config,
+            stopping_criteria=stopping_criteria_list,
+            return_dict_in_generate=True,
+            output_scores=True
+        )
     generated_text = tokenizer.decode(output.sequences[0][encoded_input["input_ids"].shape[-1]:], skip_special_tokens=True)
+    await cleanup_memory(device)
     return generated_text

-
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
     try:
         validated_body = request
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        device = 0 if torch.cuda.is_available() else -1  # pipeline expects an int for CUDA

-        image_generator = pipeline("text-to-image", model=validated_body.model_name, device=device)
-        image = image_generator(validated_body.input_text)[0]
+        # Run the pipeline in a separate thread
+        image_generator = await asyncio.to_thread(pipeline, "text-to-image", model=validated_body.model_name, device=device)
+        results = await asyncio.to_thread(image_generator, validated_body.input_text)
+        image = results[0]

         img_byte_arr = BytesIO()
         image.save(img_byte_arr, format="PNG")
         img_byte_arr.seek(0)
-
+        await cleanup_memory("cuda" if device == 0 else "cpu")
         return StreamingResponse(img_byte_arr, media_type="image/png")
-
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@@ -186,18 +220,19 @@ async def generate_image(request: GenerateRequest):
 async def generate_text_to_speech(request: GenerateRequest):
     try:
         validated_body = request
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+        device = 0 if torch.cuda.is_available() else -1

-        audio_generator = pipeline("text-to-speech", model=validated_body.model_name, device=device)
-        audio = audio_generator(validated_body.input_text)
-        sampling_rate = audio_generator.sampling_rate
+        # Run the pipeline in a separate thread
+        tts_generator = await asyncio.to_thread(pipeline, "text-to-speech", model=validated_body.model_name, device=device)
+        tts_results = await asyncio.to_thread(tts_generator, validated_body.input_text)
+        audio = tts_results
+        sampling_rate = tts_generator.sampling_rate

         audio_byte_arr = BytesIO()
         sf.write(audio_byte_arr, audio, sampling_rate, format='WAV')
         audio_byte_arr.seek(0)
-
+        await cleanup_memory("cuda" if device == 0 else "cpu")
         return StreamingResponse(audio_byte_arr, media_type="audio/wav")
-
     except Exception as e:
         traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
@@ -206,19 +241,20 @@ async def generate_text_to_speech(request: GenerateRequest):
 async def generate_video(request: GenerateRequest):
     try:
         validated_body = request
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        video_generator = pipeline("text-to-video", model=validated_body.model_name, device=device)
-        video = video_generator(validated_body.input_text)
+        device = 0 if torch.cuda.is_available() else -1
+
+        # Run the pipeline in a separate thread
+        video_generator = await asyncio.to_thread(pipeline, "text-to-video", model=validated_body.model_name, device=device)
+        video = await asyncio.to_thread(video_generator, validated_body.input_text)

         video_byte_arr = BytesIO()
         video.save(video_byte_arr)
         video_byte_arr.seek(0)
-
+        await cleanup_memory("cuda" if device == 0 else "cpu")
         return StreamingResponse(video_byte_arr, media_type="video/mp4")
-
     except Exception as e:
         traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
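
A note on the new RAM cap: resource.setrlimit(resource.RLIMIT_AS, ...) bounds the process's virtual address space, and once the cap is reached further allocations fail with MemoryError rather than being trimmed gracefully. A minimal standalone sketch of that behaviour (Unix only, using an illustrative 512 MiB cap instead of the 1% chosen in this commit):

import resource

# Cap the address space at roughly 512 MiB (an illustrative number, not the 1%
# used in app.py), then allocate past it: the allocation raises MemoryError.
soft, hard = resource.getrlimit(resource.RLIMIT_AS)
resource.setrlimit(resource.RLIMIT_AS, (512 * 1024 * 1024, hard))

try:
    blob = bytearray(1024 * 1024 * 1024)  # 1 GiB, well past the cap
except MemoryError:
    print("allocation rejected by RLIMIT_AS")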
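
On the streaming path: stream=True is not an argument of the stock transformers generate() API, so the loop in stream_text may not stream as written. A common alternative is TextIteratorStreamer (the sibling of the TextStreamer import this commit removes), with generate() running in a worker thread. The sketch below is that alternative pattern, not the code in this commit; it assumes the same model, tokenizer, generation_config, stopping_criteria_list, device and chunk_delay objects used by the endpoint above.

import asyncio
import threading

from transformers import TextIteratorStreamer

# Sketch of an alternative streaming generator built on TextIteratorStreamer:
# generate() runs in a worker thread and the streamer hands back decoded text
# chunks as they are produced, which the async generator forwards to the client.
async def stream_text_with_streamer(model, tokenizer, input_text,
                                    generation_config, stopping_criteria_list,
                                    device, chunk_delay, max_length=64):
    encoded_input = tokenizer(input_text, return_tensors="pt",
                              truncation=True, max_length=max_length).to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
                                    skip_special_tokens=True)
    generation_kwargs = dict(
        **encoded_input,
        generation_config=generation_config,
        stopping_criteria=stopping_criteria_list,
        streamer=streamer,
    )
    # generate() blocks, so it gets its own thread while we drain the streamer.
    worker = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    worker.start()
    chunks = iter(streamer)
    while True:
        # next() blocks on the streamer's queue, so push it off the event loop.
        chunk = await asyncio.to_thread(next, chunks, None)
        if chunk is None:
            break
        if chunk:
            yield chunk
            await asyncio.sleep(chunk_delay)
    worker.join()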
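
On the image endpoint: "text-to-image" is handed to the transformers pipeline factory here; text-to-image checkpoints are more commonly served through diffusers. A minimal sketch of what the handler body could look like with diffusers instead (an assumption about the intended model type, not the code in this commit):

import asyncio
from io import BytesIO

import torch
from diffusers import DiffusionPipeline
from fastapi.responses import StreamingResponse

# Sketch of a text-to-image handler body built on diffusers. model_name and
# prompt would come from the request body; float16 is only used on CUDA.
async def render_image(model_name: str, prompt: str) -> StreamingResponse:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    def _generate():
        pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=dtype).to(device)
        return pipe(prompt).images[0]  # a PIL.Image

    image = await asyncio.to_thread(_generate)  # keep the event loop free
    buf = BytesIO()
    image.save(buf, format="PNG")
    buf.seek(0)
    return StreamingResponse(buf, media_type="image/png")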
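
Finally, a hedged example of calling the streaming /generate endpoint once the app is running on port 7860. model_name, input_text and task_type are the GenerateRequest fields visible in the diff; the stream flag is assumed from the handler's if stream: branch, and the model name is only illustrative.

import requests

# Hypothetical client call against a local instance (port 7860 as in __main__).
payload = {
    "model_name": "gpt2",           # any causal LM the server can load
    "input_text": "Hello, world",
    "task_type": "text-to-text",    # assumed value; accepted values are not shown in the diff
    "stream": True,                 # assumed field name
}

with requests.post("http://localhost:7860/generate", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        if chunk:
            print(chunk, end="", flush=True)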