Hjgugugjhuhjggg committed
Commit 6538615 · verified · 1 parent: 5ded4bc

Update app.py

Files changed (1):
app.py  +70 -16
app.py CHANGED
@@ -10,7 +10,8 @@ from transformers import (
     AutoTokenizer,
     GenerationConfig,
     StoppingCriteriaList,
-    StoppingCriteria
+    StoppingCriteria,
+    TextStreamer
 )
 import uvicorn
 import asyncio
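Note: the `StoppingCriteria` import is kept because the endpoint further down builds `StoppingCriteriaList([StopOnTokens(stop_token_ids)])`. The `StopOnTokens` class itself is not part of this diff; the following is only a minimal sketch of what such a criterion typically looks like, and the real implementation in app.py may differ.

```python
import torch
from transformers import StoppingCriteria

class StopOnTokens(StoppingCriteria):
    """Stop generation once the most recent token is one of the given ids.

    Sketch only; the actual StopOnTokens in app.py is not shown in this commit.
    """
    def __init__(self, stop_token_ids):
        self.stop_token_ids = set(stop_token_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # input_ids has shape (batch, sequence); inspect the last generated token.
        return input_ids[0, -1].item() in self.stop_token_ids
```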
@@ -25,7 +26,7 @@ class GenerateRequest(BaseModel):
     task_type: str
     temperature: float = 1.0
     max_new_tokens: int = 2
-    stream: bool = True
+    stream: bool = True  # Keep stream parameter in request for flexibility, but handle it correctly in code
     top_p: float = 1.0
     top_k: int = 50
     repetition_penalty: float = 1.0
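Note: with `stream` restored as a request field, a client payload carries it alongside the sampling parameters. The snippet below is a hypothetical call only; the port, the `task_type` value, and the presence of an `input_text` field are assumed from this file rather than verified against a running instance.

```python
import requests

# Hypothetical example payload for the /generate endpoint defined in app.py.
payload = {
    "input_text": "Write a haiku about FastAPI.",  # field assumed from the handlers below
    "task_type": "text-to-text",                   # value assumed; app.py defines the accepted types
    "temperature": 0.7,
    "max_new_tokens": 64,
    "stream": True,                                # False exercises the non-streaming branch
    "top_p": 0.9,
    "top_k": 50,
    "repetition_penalty": 1.1,
}

# Stream the plain-text response chunk by chunk.
with requests.post("http://localhost:7860/generate", json=payload, stream=True) as resp:
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
```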
@@ -88,7 +89,7 @@ async def generate(request: GenerateRequest):
     task_type = request.task_type
     temperature = request.temperature
     max_new_tokens = request.max_new_tokens
-    stream = request.stream
+    stream = request.stream  # Get stream from request, but handle correctly
     top_p = request.top_p
     top_k = request.top_k
     repetition_penalty = request.repetition_penalty
@@ -117,32 +118,70 @@ async def generate(request: GenerateRequest):
         stopping_criteria_list = StoppingCriteriaList([StopOnTokens(stop_token_ids)]) if stop_token_ids else None


-        return StreamingResponse(
-            stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay),
-            media_type="text/plain"
-        )
+        if stream:  # Handle streaming based on request parameter
+            return StreamingResponse(
+                stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay),
+                media_type="text/plain"
+            )
+        else:  # Handle non-streaming case
+            generated_text = generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device)
+            return StreamingResponse(iter([generated_text]), media_type="text/plain")  # Still use StreamingResponse for consistency in return type
+

     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

 async def stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, max_length=2048):
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
+    streamer = TextStreamer(tokenizer)  # Use TextStreamer for proper streaming

     with torch.no_grad():
-        streamer = model.generate(
+        model.generate(
             **encoded_input,
             generation_config=generation_config,
             stopping_criteria=stopping_criteria_list,
-            stream=True,  # Ensure streaming is enabled if supported by the model
+            streamer=streamer,  # Use streamer here instead of stream=True
             return_dict_in_generate=True,
             output_scores=True
         )

-    for output in streamer.sequences[:, encoded_input["input_ids"].shape[-1]:]:  # Stream from the *new* tokens
-        token = tokenizer.decode(output, skip_special_tokens=True)
-        if token:  # Avoid yielding empty tokens
-            yield token
-            await asyncio.sleep(chunk_delay)
+    # TextStreamer handles printing to stdout by default, but we want to stream to client
+    # We need to access the generated text from the streamer and yield it.
+    # TextStreamer is designed for terminal output, not direct access to tokens for streaming.
+    # We need to modify stream_text to correctly stream tokens.
+
+    encoded_input_len = encoded_input["input_ids"].shape[-1]
+    generated_tokens = []
+    for i, output in enumerate(model.generate(
+        **encoded_input,
+        generation_config=generation_config,
+        stopping_criteria=stopping_criteria_list,
+        stream=True,  # Keep stream=True for actual streaming from model
+        return_dict_in_generate=True,
+        output_scores=True,
+    ):
+        if i > 0:  # Skip the first output which is just input
+            new_tokens = output.sequences[:, encoded_input_len:]
+            for token_batch in new_tokens:
+                token = tokenizer.decode(token_batch, skip_special_tokens=True)
+                if token:
+                    yield token
+                    await asyncio.sleep(chunk_delay)
+
+
+async def generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, max_length=2048):
+    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
+
+    with torch.no_grad():
+        output = model.generate(
+            **encoded_input,
+            generation_config=generation_config,
+            stopping_criteria=stopping_criteria_list,
+            return_dict_in_generate=True,
+            output_scores=True
+        )
+    generated_text = tokenizer.decode(output.sequences[0][encoded_input["input_ids"].shape[-1]:], skip_special_tokens=True)
+    return generated_text


 @app.post("/generate-image")
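Note: the comments added in this hunk already flag the core issue: `TextStreamer` prints to stdout, and stock `transformers` `model.generate()` has no `stream=True` keyword, so the second `generate(...)` loop (which as written also appears to be missing a closing parenthesis on the `enumerate(...)` call) would not stream tokens back to the client. The usual pattern for streaming out of a FastAPI endpoint is `TextIteratorStreamer` with generation running in a background thread; the following is a minimal sketch under those assumptions, mirroring the names used in app.py but not taken from this commit.

```python
import asyncio
from threading import Thread
from transformers import TextIteratorStreamer

async def stream_text(model, tokenizer, input_text, generation_config,
                      stopping_criteria_list, device, chunk_delay, max_length=2048):
    # Sketch of a working streaming generator; argument names follow app.py.
    encoded_input = tokenizer(input_text, return_tensors="pt",
                              truncation=True, max_length=max_length).to(device)

    # TextIteratorStreamer buffers decoded text and exposes it as a Python iterator.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        **encoded_input,
        generation_config=generation_config,
        stopping_criteria=stopping_criteria_list,
        streamer=streamer,
    )

    # generate() blocks, so run it in a thread and consume the streamer here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    for text_chunk in streamer:
        if text_chunk:
            yield text_chunk
            await asyncio.sleep(chunk_delay)

    thread.join()
```

Separately, the non-streaming branch calls `generate_non_stream(...)` without `await` even though the helper is declared `async def`, so the handler would wrap a coroutine in the response; either dropping `async` from the helper or awaiting the call would resolve that.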
@@ -173,12 +212,21 @@ async def generate_text_to_speech(request: GenerateRequest):
         audio = audio_generator(validated_body.input_text)[0]

         audio_byte_arr = BytesIO()
-        audio.save(audio_byte_arr)
+        # Assuming audio_generator returns an object with a save method. Adjust based on actual object.
+        # Example for a hypothetical audio object with save method to BytesIO
+        # audio_generator output might vary, check documentation for the specific model/pipeline
+        # If it's raw audio data, you might need to use a library like soundfile to write to BytesIO
+        # Example assuming `audio` is raw data and needs to be saved as wav
+        import soundfile as sf
+        sf.write(audio_byte_arr, audio, samplerate=audio_generator.sampling_rate, format='WAV')  # Assuming samplerate exists
         audio_byte_arr.seek(0)

+
         return StreamingResponse(audio_byte_arr, media_type="audio/wav")

     except Exception as e:
+        import traceback
+        traceback.print_exc()  # Print detailed error for debugging
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

 @app.post("/generate-video")
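Note: the `sf.write(...)` call matches soundfile's signature for writing into a file-like object, but whether the pipeline output is raw samples and whether `audio_generator.sampling_rate` exists depends on the TTS pipeline in use. A transformers `text-to-speech` pipeline typically returns a dict with `"audio"` and `"sampling_rate"` keys; the sketch below assumes that shape and is not taken from this commit.

```python
from io import BytesIO
import soundfile as sf

def tts_to_wav_bytes(audio_generator, text: str) -> BytesIO:
    # Sketch assuming a transformers text-to-speech pipeline that returns
    # {"audio": np.ndarray, "sampling_rate": int}; other pipelines may differ.
    result = audio_generator(text)
    samples = result["audio"]
    sampling_rate = result["sampling_rate"]

    buf = BytesIO()
    # soundfile needs an explicit format when writing to a file-like object.
    sf.write(buf, samples.squeeze(), samplerate=sampling_rate, format="WAV")
    buf.seek(0)
    return buf
```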
@@ -190,12 +238,18 @@ async def generate_video(request: GenerateRequest):
         video = video_generator(validated_body.input_text)[0]

         video_byte_arr = BytesIO()
-        video.save(video_byte_arr)
+        # Assuming video_generator returns an object with a save method. Adjust based on actual object and format.
+        # Example for a hypothetical video object with save method to BytesIO as mp4
+        # video_generator output might vary, check documentation for the specific model/pipeline
+        video.save(video_byte_arr, format='MP4')  # Hypothetical save method, adjust based on actual video object
         video_byte_arr.seek(0)

+
         return StreamingResponse(video_byte_arr, media_type="video/mp4")

     except Exception as e:
+        import traceback
+        traceback.print_exc()  # Print detailed error for debugging
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

 if __name__ == "__main__":
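Note: the commit's own comments mark `video.save(video_byte_arr, format='MP4')` as hypothetical; most text-to-video pipelines return a list of frames rather than an object with a `save` method. One common way to get MP4 bytes is to export the frames to a temporary file and read it back, e.g. with `diffusers.utils.export_to_video`. The sketch below assumes a diffusers-style frame list; the actual `video_generator` in app.py is not shown in this diff.

```python
import os
import tempfile
from io import BytesIO
from diffusers.utils import export_to_video

def frames_to_mp4_bytes(frames, fps: int = 8) -> BytesIO:
    # Sketch assuming `frames` is the list of frames produced by a diffusers
    # text-to-video pipeline; other generators may need a different exporter.
    tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    tmp.close()
    try:
        export_to_video(frames, tmp.name, fps=fps)  # writes an .mp4 file to disk
        with open(tmp.name, "rb") as f:
            buf = BytesIO(f.read())
    finally:
        os.remove(tmp.name)
    buf.seek(0)
    return buf
```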
 