Hjgugugjhuhjggg committed on
Commit b2da2fd (verified)
1 Parent(s): 02659ff

Update app.py

Files changed (1)
  1. app.py +10 -13
app.py CHANGED
@@ -24,7 +24,7 @@ try:
     import psutil
     import resource
     total_memory = psutil.virtual_memory().total
-    limit = int(total_memory * 80.0)  # 1% of the total, in bytes  # Correction: use 0.01 for 1%
+    limit = int(total_memory * 0.01)  # 1% of the total, in bytes  # Correction: use 0.01 for 1%
     resource.setrlimit(resource.RLIMIT_AS, (limit, limit))
     print(f"Memory limit set to {limit} bytes (1% of total system memory).")  # Print to verify the applied limit
 except Exception as e:
@@ -132,6 +132,7 @@ async def generate(request: GenerateRequest):
         repetition_penalty=repetition_penalty,
         do_sample=do_sample,
         num_return_sequences=num_return_sequences,
+        stream=stream,  # Add stream=True/False to the generation config
     )
 
     stop_token_ids = []
@@ -160,7 +161,6 @@ async def stream_text(model, tokenizer, input_text, generation_config, stopping_
     """
     # Limit the input to minimize memory usage
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
-    encoded_input_len = encoded_input["input_ids"].shape[-1]
 
     # torch.no_grad() avoids storing information for gradients
     with torch.no_grad():
@@ -169,18 +169,15 @@ async def stream_text(model, tokenizer, input_text, generation_config, stopping_
             **encoded_input,
             generation_config=generation_config,
             stopping_criteria=stopping_criteria_list,
-            return_dict_in_generate=True,
-            output_scores=True,
-            # stream=True,  # Remove 'stream=True' here, since GenerationConfig handles it
+            # return_dict_in_generate=True,  # Remove return_dict_in_generate for streaming
+            # output_scores=True,  # output_scores might not be needed for streaming text only
         ):
-            # Extract only the generated tokens (excluding the input)
-            new_tokens = output.sequences[:, encoded_input_len:]
-            for token_batch in new_tokens:
-                token = tokenizer.decode(token_batch, skip_special_tokens=True)
-                if token:
-                    # Each token is sent immediately
-                    yield token
-                    await asyncio.sleep(chunk_delay)
+            # In streaming mode, output is directly the generated token IDs
+            token = tokenizer.decode(output, skip_special_tokens=True)
+            if token:
+                # Each token is sent immediately
+                yield token
+                await asyncio.sleep(chunk_delay)
     await cleanup_memory(device)
 
 async def generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, max_length=64):
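For context on the first hunk: capping RLIMIT_AS to a fraction of physical RAM makes oversized allocations fail fast instead of swapping. A minimal, self-contained sketch of that pattern follows; the helper name cap_address_space and the 1% fraction are illustrative, not part of app.py.

import psutil
import resource

def cap_address_space(fraction=0.01):
    """Limit this process's virtual address space to a fraction of total RAM (Unix only)."""
    total_memory = psutil.virtual_memory().total
    limit = int(total_memory * fraction)  # e.g. 0.01 -> 1% of total RAM, in bytes
    # RLIMIT_AS bounds the virtual memory the process may map;
    # allocations beyond the limit fail (typically raising MemoryError in Python).
    resource.setrlimit(resource.RLIMIT_AS, (limit, limit))
    return limit

if __name__ == "__main__":
    try:
        print(f"Memory limit set to {cap_address_space(0.01)} bytes (1% of total).")
    except (ValueError, OSError) as e:
        print(f"Could not set memory limit: {e}")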
 
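The remaining hunks route stream through GenerationConfig and decode each output chunk as it arrives. For comparison only, and not what this commit does, a common way to stream decoded text with the transformers library is TextIteratorStreamer, which runs generate() in a worker thread and yields text pieces as they are produced. The model name and generation parameters below are placeholders.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "gpt2"  # placeholder model, not the one served by app.py
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

def stream_tokens(prompt, max_new_tokens=64):
    """Yield decoded text chunks as the model generates them."""
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=64)
    # The streamer receives token IDs from generate() and decodes them incrementally.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(**inputs, streamer=streamer,
                             max_new_tokens=max_new_tokens, do_sample=False)
    # generate() blocks, so it runs in a worker thread while we consume the streamer.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    for text in streamer:
        if text:
            yield text

for chunk in stream_tokens("Hello, world"):
    print(chunk, end="", flush=True)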