Hjgugugjhuhjggg committed
Commit f9fbc8e (verified)
1 Parent(s): f40b225

Update app.py

Files changed (1)
  1. app.py +73 -19
app.py CHANGED
@@ -24,7 +24,7 @@ try:
     import psutil
     import resource
     total_memory = psutil.virtual_memory().total
-    limit = int(total_memory * 80.0)  # 1% of the total in bytes  # Correction: use 0.01 for 1%
+    limit = int(total_memory * 90.0)  # 1% of the total in bytes  # Correction: use 0.01 for 1%
     resource.setrlimit(resource.RLIMIT_AS, (limit, limit))
     print(f"Memory limit set to {limit} bytes (1% of total system memory).")  # Print to verify the applied limit
 except Exception as e:
@@ -55,6 +55,7 @@ class GenerateRequest(BaseModel):
     do_sample: bool = True
     chunk_delay: float = 0.0
     stop_sequences: list[str] = []
+    chunk_token_limit: int = 100  # New parameter to limit the tokens per chunk
 
     @field_validator("model_name")
     def model_name_cannot_be_empty(cls, v):
@@ -119,6 +120,7 @@ async def generate(request: GenerateRequest):
         do_sample = request.do_sample
         chunk_delay = request.chunk_delay
         stop_sequences = request.stop_sequences
+        chunk_token_limit = request.chunk_token_limit
 
         model, tokenizer = await model_loader.load_model_and_tokenizer(model_name)
         device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -143,7 +145,7 @@ async def generate(request: GenerateRequest):
         if stream:
             # StreamingResponse is used with the async function that sends each token in real time.
             response = StreamingResponse(
-                stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay),
+                stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, chunk_token_limit, stopping_criteria_list),  # Pass stopping_criteria_list
                 media_type="text/plain"
             )
         else:
@@ -155,31 +157,83 @@ async def generate(request: GenerateRequest):
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
-async def stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, max_length=64):
+async def stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, chunk_token_limit, stop_criteria):  # Accept stop_criteria
     """
-    Generates tokens asynchronously and sends them to the client in real time.
+    Generates tokens asynchronously and sends them to the client in real time, splitting the response into chunks if it exceeds the token limit.
+    Generation stops automatically once the StoppingCriteriaList is satisfied.
     """
     # Limit the input to minimize memory usage
-    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
+    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=64).to(device)
+
+    current_chunk_tokens = 0
+    current_chunk_text = ""
+    past_key_values = None  # To maintain state for streaming
 
     # torch.no_grad() avoids storing gradient information
     with torch.no_grad():
-        # The text is generated iteratively (streaming)
-        for output in model.generate(
-            **encoded_input,
-            generation_config=generation_config,
-            stopping_criteria=stopping_criteria_list,
-            # return_dict_in_generate=True,  # Remove return_dict_in_generate for streaming
-            # output_scores=True,  # output_scores might not be needed for streaming text only
-        ):
-            # In streaming mode, output is directly the generated token IDs
-            token = tokenizer.decode(output, skip_special_tokens=True)
-            if token:
-                # Each token is sent immediately
-                yield token
-                await asyncio.sleep(chunk_delay)
+        input_ids = encoded_input.input_ids
+        # Manual token-by-token generation for stop control and chunking
+        while True:  # Infinite loop, broken by the stop conditions
+            outputs = model(
+                input_ids,
+                past_key_values=past_key_values,
+                use_cache=True,  # Important for stateful generation
+                return_dict=True
+            )
+            next_token_logits = outputs.logits[:, -1, :]
+
+            # Apply sampling to obtain the next token (as configured in generation_config)
+            if generation_config.do_sample:
+                # Apply temperature and top-k/top-p sampling
+                next_token_logits = next_token_logits / generation_config.temperature
+
+                # Top-k filtering
+                if generation_config.top_k is not None and generation_config.top_k > 0:
+                    v, _ = torch.topk(next_token_logits, min(generation_config.top_k, next_token_logits.size(-1)))
+                    next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')
+
+                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                # Greedy decoding
+                next_tokens = torch.argmax(next_token_logits, dim=-1)
+
+            # Check stop criteria BEFORE adding the token to the output
+            if stop_criteria and stop_criteria(input_ids, next_token_logits):  # Check stopping criteria
+                break  # Stop generation if the criteria are met
+
+            next_tokens = next_tokens.unsqueeze(0)  # Reshape to [1, 1] for concat
+            next_token_text = tokenizer.decode(next_tokens[0], skip_special_tokens=True)
+
+            token_count = len(tokenizer.encode(current_chunk_text + next_token_text)) - len(tokenizer.encode(current_chunk_text))
+
+            if current_chunk_tokens + token_count > chunk_token_limit:
+                yield current_chunk_text
+                current_chunk_text = next_token_text
+                current_chunk_tokens = token_count
+            else:
+                current_chunk_text += next_token_text
+                current_chunk_tokens += token_count
+
+            yield current_chunk_text  # Yield every token/chunk
+
+            input_ids = torch.cat([input_ids, next_tokens], dim=-1)  # Append the next token to input_ids for the next iteration
+            past_key_values = outputs.past_key_values  # Update past key values for stateful generation
+
+            await asyncio.sleep(chunk_delay)
+
+            if input_ids.shape[-1] >= generation_config.max_new_tokens + encoded_input.input_ids.shape[-1]:  # Check the max_new_tokens limit
+                break  # Stop once max_new_tokens is reached
+
+        # Make sure the last chunk is sent
+        if current_chunk_text:
+            yield current_chunk_text
+
     await cleanup_memory(device)
 
+
 async def generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, max_length=64):
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
     with torch.no_grad():
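
The memory-limit hunk multiplies total system memory by 90.0, while its own comment says the intent is a 1% cap (multiplier 0.01). A minimal standalone sketch of the 1% cap that the comment describes, assuming a Linux host where RLIMIT_AS is honored; the helper name and fraction default are illustrative, not part of the commit:

import psutil
import resource

def set_memory_limit(fraction: float = 0.01) -> int:
    """Hypothetical helper: cap the process address space at a fraction of total RAM."""
    total_memory = psutil.virtual_memory().total
    limit = int(total_memory * fraction)  # limit in bytes
    resource.setrlimit(resource.RLIMIT_AS, (limit, limit))
    return limit

if __name__ == "__main__":
    applied = set_memory_limit()
    print(f"Memory limit set to {applied} bytes")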
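A hedged client-side example of exercising the new chunk_token_limit field over the streaming endpoint. The route path (/generate), host, and any request fields not visible in this diff (model_name, input_text, stream) are assumptions about the rest of app.py, not confirmed by the commit:

import requests

# Hypothetical request body; only do_sample, chunk_delay, stop_sequences and
# chunk_token_limit are visible in this diff, the other fields are assumed.
payload = {
    "model_name": "gpt2",
    "input_text": "Hello, world",
    "stream": True,
    "do_sample": True,
    "chunk_delay": 0.0,
    "stop_sequences": [],
    "chunk_token_limit": 100,  # new field added by this commit
}

# Assumed endpoint; adjust host, port and path to the actual FastAPI app.
with requests.post("http://localhost:8000/generate", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        if chunk:
            print(chunk, end="", flush=True)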
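The rewritten stream_text decodes manually: temperature scaling, optional top-k filtering, then multinomial sampling (or argmax when do_sample is off). A self-contained sketch of that decode step on dummy logits, using plain PyTorch; the default temperature and top_k values here are illustrative, not taken from the commit:

import torch

def sample_next_token(next_token_logits: torch.Tensor,
                      do_sample: bool = True,
                      temperature: float = 0.7,
                      top_k: int = 50) -> torch.Tensor:
    """Pick the next token id from [batch, vocab] logits, mirroring the loop in stream_text."""
    if do_sample:
        next_token_logits = next_token_logits / temperature  # temperature scaling
        if top_k is not None and top_k > 0:
            # Keep only the top_k highest logits; mask the rest to -inf
            v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
            next_token_logits = next_token_logits.masked_fill(
                next_token_logits < v[:, [-1]], float("-inf"))
        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(1)
    return torch.argmax(next_token_logits, dim=-1)  # greedy decoding

# Usage on dummy logits: batch of 1, vocabulary of 10
print(sample_next_token(torch.randn(1, 10)))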