Hjgugugjhuhjggg committed
Commit: f1afab6 (verified)
Parent(s): ac0dda7

Update app.py

Files changed (1): app.py (+16, -32)
app.py CHANGED

```diff
@@ -9,7 +9,8 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     GenerationConfig,
-    StoppingCriteriaList
+    StoppingCriteriaList,
+    TextIteratorStreamer  # import TextIteratorStreamer
 )
 import uvicorn
 import asyncio
@@ -120,44 +121,27 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequ
 
     stopping_criteria = StoppingCriteriaList([stop_criteria])
 
-    output_text = ""
-    outputs = model.generate(
+    streamer = TextIteratorStreamer(tokenizer, chunk_delay=chunk_delay, skip_prompt=True)  # initialize the streamer
+
+    generation_kwargs = dict(
         **encoded_input,
-        do_sample=generation_config.do_sample,
-        max_new_tokens=generation_config.max_new_tokens,
-        temperature=generation_config.temperature,
-        top_p=generation_config.top_p,
-        top_k=generation_config.top_k,
-        repetition_penalty=generation_config.repetition_penalty,
-        num_return_sequences=generation_config.num_return_sequences,
+        generation_config=generation_config,
         stopping_criteria=stopping_criteria,
-        output_scores=True,
-        return_dict_in_generate=True
+        streamer=streamer,  # pass the streamer to generate
+        return_dict_in_generate=True,
+        output_scores=True
     )
 
-    for output in outputs.sequences:
-        for token_id in output:
-            token = tokenizer.decode(token_id, skip_special_tokens=True)
-            yield token
-            await asyncio.sleep(chunk_delay)
+    async def generate_task():
+        model.generate(**generation_kwargs)  # run generate in the background
+
+    asyncio.create_task(generate_task())  # start the generation task
 
-    if stop_sequences and any(stop in output_text for stop in stop_sequences):
-        yield output_text
+    for token in streamer:  # iterate over the streamer to get tokens one by one
+        yield token
+        if stop_sequences and any(stop in token for stop in stop_sequences):  # check stop sequences on each token
             return
 
-    outputs = model.generate(
-        **encoded_input,
-        do_sample=generation_config.do_sample,
-        max_new_tokens=generation_config.max_new_tokens,
-        temperature=generation_config.temperature,
-        top_p=generation_config.top_p,
-        top_k=generation_config.top_k,
-        repetition_penalty=generation_config.repetition_penalty,
-        num_return_sequences=generation_config.num_return_sequences,
-        stopping_criteria=stopping_criteria,
-        output_scores=True,
-        return_dict_in_generate=True
-    )
 
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
```
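
For reference, the transformers documentation drives `TextIteratorStreamer` from a worker thread rather than an asyncio task: `model.generate()` is a synchronous, blocking call, so a coroutine that invokes it directly (as `generate_task()` above does) stalls the event loop until generation finishes. Note also that `chunk_delay` is not among `TextIteratorStreamer`'s documented parameters (`skip_prompt`, `timeout`, plus decode kwargs), so any pacing has to happen in the consumer. A minimal sketch of the thread-based pattern, assuming `model`, `tokenizer`, and `encoded_input` as in the diff:

```python
from threading import Thread

from transformers import TextIteratorStreamer

# The streamer decodes tokens as generate() produces them; skip_prompt=True
# drops the echoed prompt, and timeout unblocks the consumer if generation stalls.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=60.0)

generation_kwargs = dict(**encoded_input, streamer=streamer, max_new_tokens=256)

# generate() blocks, so run it in a worker thread and consume the streamer here.
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

for new_text in streamer:
    print(new_text, end="", flush=True)
thread.join()
```

Inside an async endpoint, the same blocking call is usually handed off with `asyncio.to_thread(model.generate, **generation_kwargs)` (or `loop.run_in_executor`), which keeps the event loop free while tokens stream out.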
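
One more caveat in the new loop: stop sequences are matched against each decoded chunk in isolation, so a stop string split across two chunks is never detected. A hedged sketch that matches against the accumulated text instead (`buffer` is an illustrative name; the other identifiers mirror the diff):

```python
buffer = ""
for chunk in streamer:
    buffer += chunk
    yield chunk  # still stream each chunk as soon as it arrives
    # Match against everything emitted so far, so stop strings that
    # straddle chunk boundaries are still caught.
    if stop_sequences and any(stop in buffer for stop in stop_sequences):
        return
```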