Hjgugugjhuhjggg committed
Commit 436996b (verified)
Parent: f1afab6

Update app.py

Files changed (1):
  app.py  +33 -17
app.py CHANGED
@@ -9,8 +9,7 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     GenerationConfig,
-    StoppingCriteriaList,
-    TextIteratorStreamer  # Import TextIteratorStreamer
+    StoppingCriteriaList
 )
 import uvicorn
 import asyncio
@@ -24,7 +23,7 @@ class GenerateRequest(BaseModel):
     input_text: str = ""
     task_type: str
     temperature: float = 1.0
-    max_new_tokens: int = 200
+    max_new_tokens: int = 4
     stream: bool = True
     top_p: float = 1.0
     top_k: int = 50
@@ -121,27 +120,44 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequ
 
     stopping_criteria = StoppingCriteriaList([stop_criteria])
 
-    streamer = TextIteratorStreamer(tokenizer, chunk_delay=chunk_delay, skip_prompt=True)  # Initialize the streamer
-
-    generation_kwargs = dict(
+    output_text = ""
+    outputs = model.generate(
         **encoded_input,
-        generation_config=generation_config,
+        do_sample=generation_config.do_sample,
+        max_new_tokens=generation_config.max_new_tokens,
+        temperature=generation_config.temperature,
+        top_p=generation_config.top_p,
+        top_k=generation_config.top_k,
+        repetition_penalty=generation_config.repetition_penalty,
+        num_return_sequences=generation_config.num_return_sequences,
         stopping_criteria=stopping_criteria,
-        streamer=streamer,  # Pass the streamer to generate
-        return_dict_in_generate=True,
-        output_scores=True
+        output_scores=True,
+        return_dict_in_generate=True
     )
 
-    async def generate_task():
-        model.generate(**generation_kwargs)  # Run generate in the background
-
-    asyncio.create_task(generate_task())  # Start the generation task
+    for output in outputs.sequences:
+        for token_id in output:
+            token = tokenizer.decode(token_id, skip_special_tokens=True)
+            yield token
+            await asyncio.sleep(chunk_delay)
 
-    for token in streamer:  # Iterate over the streamer to get tokens one by one
-        yield token
-        if stop_sequences and any(stop in token for stop in stop_sequences):  # Check stop sequences on each token
+    if stop_sequences and any(stop in output_text for stop in stop_sequences):
+        yield output_text
         return
 
+    outputs = model.generate(
+        **encoded_input,
+        do_sample=generation_config.do_sample,
+        max_new_tokens=generation_config.max_new_tokens,
+        temperature=generation_config.temperature,
+        top_p=generation_config.top_p,
+        top_k=generation_config.top_k,
+        repetition_penalty=generation_config.repetition_penalty,
+        num_return_sequences=generation_config.num_return_sequences,
+        stopping_criteria=stopping_criteria,
+        output_scores=True,
+        return_dict_in_generate=True
+    )
 
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
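
As a usage note, here is a minimal client-side sketch of how the streaming behavior touched by this commit could be exercised. The JSON fields mirror the GenerateRequest model shown in the diff; the endpoint path /generate, the port, and the task_type value are assumptions, since only /generate-image appears in this excerpt.

import httpx

# Hypothetical request against the app started with uvicorn; adjust the URL and
# path to the actual service (only /generate-image is visible in this diff).
payload = {
    "input_text": "Hello, world",
    "task_type": "text-generation",  # assumed value; the diff only shows the field name
    "temperature": 1.0,
    "max_new_tokens": 4,             # new default introduced by this commit
    "stream": True,
    "top_p": 1.0,
    "top_k": 50,
}

# Stream the response and print each chunk as it arrives.
with httpx.stream("POST", "http://localhost:8000/generate", json=payload, timeout=None) as response:
    response.raise_for_status()
    for chunk in response.iter_text():
        print(chunk, end="", flush=True)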