Update app.py
app.py CHANGED
@@ -24,7 +24,7 @@ try:
     import psutil
     import resource
     total_memory = psutil.virtual_memory().total
-    limit = int(total_memory *
+    limit = int(total_memory * 90.0)  # 1% of the total in bytes  # Correction: use 0.01 for 1%
     resource.setrlimit(resource.RLIMIT_AS, (limit, limit))
     print(f"Memory limit set to {limit} bytes (1% of total system memory).")  # Print to verify the applied limit
 except Exception as e:
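Note that the comments above state the intent is a cap of 1% of total system memory, while the new factor is 90.0, which sets the RLIMIT_AS ceiling to 90x physical RAM and effectively leaves the process uncapped. A minimal sketch of the 1% version, assuming the same psutil/resource approach used here (Unix only, since RLIMIT_AS limits the virtual address space):

import psutil
import resource

# Sketch (assumption): cap the address space at 1% of physical RAM,
# matching the "1%" stated in the comments; the committed code uses 90.0.
total_memory = psutil.virtual_memory().total
limit = int(total_memory * 0.01)  # 1% of total, in bytes
resource.setrlimit(resource.RLIMIT_AS, (limit, limit))  # hard and soft limit
print(f"Memory limit set to {limit} bytes (1% of total system memory).")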
@@ -55,6 +55,7 @@ class GenerateRequest(BaseModel):
     do_sample: bool = True
     chunk_delay: float = 0.0
     stop_sequences: list[str] = []
+    chunk_token_limit: int = 100  # New parameter to limit the number of tokens per chunk
 
     @field_validator("model_name")
     def model_name_cannot_be_empty(cls, v):
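For reference, a hypothetical client call exercising the new chunk_token_limit field. The /generate route, the default Space port 7860, the model name, and the input_text/stream field names are assumptions; the other fields come from GenerateRequest as shown in this diff:

import requests

payload = {
    "model_name": "gpt2",          # hypothetical model
    "input_text": "Hello, world",  # assumed field name
    "stream": True,                # assumed field name
    "do_sample": True,
    "chunk_delay": 0.0,
    "stop_sequences": ["\n\n"],
    "chunk_token_limit": 100,      # new field: max tokens per streamed chunk
}

# Stream the plain-text response chunk by chunk.
with requests.post("http://localhost:7860/generate", json=payload, stream=True) as r:
    for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)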
@@ -119,6 +120,7 @@ async def generate(request: GenerateRequest):
         do_sample = request.do_sample
         chunk_delay = request.chunk_delay
         stop_sequences = request.stop_sequences
+        chunk_token_limit = request.chunk_token_limit
 
         model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
         device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -143,7 +145,7 @@ async def generate(request: GenerateRequest):
         if stream:
             # StreamingResponse is used with the async function that sends each token in real time.
             response = StreamingResponse(
-                stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay),
+                stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, chunk_token_limit, stopping_criteria_list),  # Pass stopping_criteria_list
                 media_type="text/plain"
             )
         else:
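The call above now passes stopping_criteria_list both positionally and again as the trailing stop_criteria argument. Its construction is not part of this diff; as an illustration only, a sketch of a custom StoppingCriteria that stops once any of the request's stop_sequences appears in the generated text, assuming the transformers StoppingCriteria/StoppingCriteriaList API:

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSequences(StoppingCriteria):
    """Illustrative sketch: stop when any stop sequence shows up in the newly generated text."""
    def __init__(self, tokenizer, stop_sequences, prompt_length):
        self.tokenizer = tokenizer
        self.stop_sequences = stop_sequences
        self.prompt_length = prompt_length  # number of prompt tokens to skip when decoding

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        generated = self.tokenizer.decode(input_ids[0, self.prompt_length:], skip_special_tokens=True)
        return any(seq in generated for seq in self.stop_sequences)

# Hypothetical construction (not shown in this commit):
# stopping_criteria_list = StoppingCriteriaList(
#     [StopOnSequences(tokenizer, stop_sequences, encoded_input.input_ids.shape[-1])]
# )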
@@ -155,31 +157,83 @@ async def generate(request: GenerateRequest):
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
-async def stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay,
+async def stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, chunk_token_limit, stop_criteria):  # Accept stop_criteria
     """
-    Generates tokens asynchronously and sends them to the client in real time.
+    Generates tokens asynchronously and sends them to the client in real time, splitting the response into chunks if it exceeds the token limit.
+    Generation stops automatically once the StoppingCriteriaList is met.
     """
     # Limit the input to minimize memory usage
-    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=
+    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=64).to(device)
+
+    current_chunk_tokens = 0
+    current_chunk_text = ""
+    past_key_values = None  # To maintain state for streaming
 
     # With torch.no_grad() we avoid storing gradient information
     with torch.no_grad():
+        input_ids = encoded_input.input_ids
+        # Manual token-by-token generation for stop control and chunking
+        while True:  # Infinite loop, broken by the stop conditions below
+            outputs = model(
+                input_ids,
+                past_key_values=past_key_values,
+                use_cache=True,  # Important for stateful generation
+                return_dict=True
+            )
+            next_token_logits = outputs.logits[:, -1, :]
+
+            # Apply sampling to pick the next token (mirroring generation_config)
+            if generation_config.do_sample:
+                # Apply temperature and top-k sampling
+                next_token_logits = next_token_logits / generation_config.temperature
+
+                # Top-k filtering
+                if generation_config.top_k is not None and generation_config.top_k > 0:
+                    v, _ = torch.topk(next_token_logits, min(generation_config.top_k, next_token_logits.size(-1)))
+                    next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')
+
+                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                # Greedy decoding
+                next_tokens = torch.argmax(next_token_logits, dim=-1)
+
+            # Check stop criteria BEFORE adding the token to the output
+            if stop_criteria and stop_criteria(input_ids, next_token_logits):  # Check stopping criteria
+                break  # Stop generation if the criteria are met
+
+            next_tokens = next_tokens.unsqueeze(0)  # Reshape to [1, 1] for concat
+            next_token_text = tokenizer.decode(next_tokens[0], skip_special_tokens=True)
+
+            token_count = len(tokenizer.encode(current_chunk_text + next_token_text)) - len(tokenizer.encode(current_chunk_text))
+
+            if current_chunk_tokens + token_count > chunk_token_limit:
+                yield current_chunk_text
+                current_chunk_text = next_token_text
+                current_chunk_tokens = token_count
+            else:
+                current_chunk_text += next_token_text
+                current_chunk_tokens += token_count
+
+            yield current_chunk_text  # Yield every token/chunk
+
+            input_ids = torch.cat([input_ids, next_tokens], dim=-1)  # Append the next token to input_ids for the next iteration
+            past_key_values = outputs.past_key_values  # Update past key values for stateful generation
+
+            await asyncio.sleep(chunk_delay)
+
+            if input_ids.shape[-1] >= generation_config.max_new_tokens + encoded_input.input_ids.shape[-1]:  # Check max_new_tokens limit
+                break  # Stop if max_new_tokens is reached
+
+        # Make sure the last chunk is sent
+        if current_chunk_text:
+            yield current_chunk_text
+
     await cleanup_memory(device)
 
+
 async def generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, max_length=64):
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
     with torch.no_grad():
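Two details of the new loop are worth noting: every forward pass re-feeds the full input_ids even though past_key_values is carried along, and current_chunk_text is yielded on every iteration while it keeps accumulating, so a client that concatenates the stream sees repeated prefixes. A minimal sketch of the cache-friendly variant, assuming a standard Hugging Face causal LM (greedy decoding only, for brevity): after the first pass only the newest token is fed, and only the newly decoded text is yielded.

import asyncio
import torch

async def stream_text_incremental(model, tokenizer, input_text, max_new_tokens=64, chunk_delay=0.0, device="cpu"):
    """Sketch only: greedy decoding with a KV cache, yielding just the newly decoded text."""
    encoded = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=64).to(device)
    input_ids = encoded.input_ids
    past_key_values = None
    next_input = input_ids  # first pass sees the whole prompt

    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(next_input, past_key_values=past_key_values, use_cache=True, return_dict=True)
            past_key_values = outputs.past_key_values
            next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1, keepdim=True)  # shape [1, 1]

            if tokenizer.eos_token_id is not None and next_token.item() == tokenizer.eos_token_id:
                break

            input_ids = torch.cat([input_ids, next_token], dim=-1)
            next_input = next_token  # subsequent passes only see the newest token
            yield tokenizer.decode(next_token[0], skip_special_tokens=True)  # yield the delta, not the whole chunk
            await asyncio.sleep(chunk_delay)

Such a generator could be handed to StreamingResponse in the same way as stream_text above.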