Update app.py
app.py CHANGED
@@ -10,7 +10,8 @@ from transformers import (
     AutoTokenizer,
     GenerationConfig,
     StoppingCriteriaList,
-    StoppingCriteria
+    StoppingCriteria,
+    TextStreamer
 )
 import uvicorn
 import asyncio
@@ -25,7 +26,7 @@ class GenerateRequest(BaseModel):
     task_type: str
     temperature: float = 1.0
     max_new_tokens: int = 2
-    stream: bool = True
+    stream: bool = True  # Keep stream parameter in request for flexibility, but handle it correctly in code
     top_p: float = 1.0
     top_k: int = 50
     repetition_penalty: float = 1.0
@@ -88,7 +89,7 @@ async def generate(request: GenerateRequest):
         task_type = request.task_type
         temperature = request.temperature
         max_new_tokens = request.max_new_tokens
-        stream = request.stream
+        stream = request.stream  # Get stream from request, but handle correctly
         top_p = request.top_p
         top_k = request.top_k
         repetition_penalty = request.repetition_penalty
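The next hunk builds its stopping criteria from a StopOnTokens class that is defined elsewhere in app.py and not shown in this diff. For orientation only, a minimal StoppingCriteria subclass along those lines (an assumption about its behaviour, not the file's actual definition) could look like this:

# Sketch only: stops generation once the last generated token is one of
# stop_token_ids. The real StopOnTokens in app.py may differ.
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_token_ids):
        self.stop_token_ids = set(int(t) for t in stop_token_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # input_ids holds the full sequence generated so far, shape (batch, seq_len)
        return int(input_ids[0, -1]) in self.stop_token_ids

Wrapped as StoppingCriteriaList([StopOnTokens(stop_token_ids)]), this is exactly the shape of object the hunk below passes to generation.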
@@ -117,32 +118,70 @@ async def generate(request: GenerateRequest):
         stopping_criteria_list = StoppingCriteriaList([StopOnTokens(stop_token_ids)]) if stop_token_ids else None


-        [four removed lines not preserved in this view]
+        if stream:  # Handle streaming based on request parameter
+            return StreamingResponse(
+                stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay),
+                media_type="text/plain"
+            )
+        else:  # Handle non-streaming case
+            generated_text = generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device)
+            return StreamingResponse(iter([generated_text]), media_type="text/plain")  # Still use StreamingResponse for consistency in return type
+

     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

 async def stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, max_length=2048):
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
+    streamer = TextStreamer(tokenizer)  # Use TextStreamer for proper streaming

     with torch.no_grad():
-        [removed line not preserved in this view]
+        model.generate(
             **encoded_input,
             generation_config=generation_config,
             stopping_criteria=stopping_criteria_list,
-            [removed line not preserved in this view]
+            streamer=streamer,  # Use streamer here instead of stream=True
             return_dict_in_generate=True,
             output_scores=True
         )

-    [five removed lines not preserved in this view]
+    # TextStreamer handles printing to stdout by default, but we want to stream to client
+    # We need to access the generated text from the streamer and yield it.
+    # TextStreamer is designed for terminal output, not direct access to tokens for streaming.
+    # We need to modify stream_text to correctly stream tokens.
+
+    encoded_input_len = encoded_input["input_ids"].shape[-1]
+    generated_tokens = []
+    for i, output in enumerate(model.generate(
+        **encoded_input,
+        generation_config=generation_config,
+        stopping_criteria=stopping_criteria_list,
+        stream=True,  # Keep stream=True for actual streaming from model
+        return_dict_in_generate=True,
+        output_scores=True,
+    ):
+        if i > 0:  # Skip the first output which is just input
+            new_tokens = output.sequences[:, encoded_input_len:]
+            for token_batch in new_tokens:
+                token = tokenizer.decode(token_batch, skip_special_tokens=True)
+                if token:
+                    yield token
+                    await asyncio.sleep(chunk_delay)
+
+
+async def generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, max_length=2048):
+    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
+
+    with torch.no_grad():
+        output = model.generate(
+            **encoded_input,
+            generation_config=generation_config,
+            stopping_criteria=stopping_criteria_list,
+            return_dict_in_generate=True,
+            output_scores=True
+        )
+    generated_text = tokenizer.decode(output.sequences[0][encoded_input["input_ids"].shape[-1]:], skip_special_tokens=True)
+    return generated_text


 @app.post("/generate-image")
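A note on the revised stream_text: transformers' generate() does not expose a stream=True argument, and TextStreamer only prints to stdout, so the hunk above is unlikely to stream tokens back to the client as written. A minimal sketch of the same idea built on TextIteratorStreamer and a background thread (the helper name is hypothetical; model, tokenizer, generation_config, stopping_criteria_list, device and chunk_delay are assumed to be the same objects app.py already uses) might look like:

# Sketch only: stream decoded text chunks from model.generate using
# TextIteratorStreamer; generate() runs in a worker thread while the async
# generator yields whatever the streamer has produced so far.
import asyncio
from threading import Thread
from transformers import TextIteratorStreamer

async def stream_text_iter(model, tokenizer, input_text, generation_config,
                           stopping_criteria_list, device, chunk_delay,
                           max_length=2048):
    encoded_input = tokenizer(input_text, return_tensors="pt",
                              truncation=True, max_length=max_length).to(device)
    # skip_prompt=True so only newly generated text is emitted, not the prompt
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
                                    skip_special_tokens=True)
    generation_kwargs = dict(
        **encoded_input,
        generation_config=generation_config,
        stopping_criteria=stopping_criteria_list,
        streamer=streamer,
    )
    # generate() blocks, so run it in a background thread and drain the streamer
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    for chunk in streamer:
        if chunk:
            yield chunk
            await asyncio.sleep(chunk_delay)
    thread.join()

A generator like this can be handed to StreamingResponse in the if stream: branch exactly as the hunk already does for stream_text; the non-streaming branch would likewise need to await its helper if that helper stays async.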
@@ -173,12 +212,21 @@ async def generate_text_to_speech(request: GenerateRequest):
         audio = audio_generator(validated_body.input_text)[0]

         audio_byte_arr = BytesIO()
-        [removed line not preserved in this view]
+        # Assuming audio_generator returns an object with a save method. Adjust based on actual object.
+        # Example for a hypothetical audio object with save method to BytesIO
+        # audio_generator output might vary, check documentation for the specific model/pipeline
+        # If it's raw audio data, you might need to use a library like soundfile to write to BytesIO
+        # Example assuming `audio` is raw data and needs to be saved as wav
+        import soundfile as sf
+        sf.write(audio_byte_arr, audio, samplerate=audio_generator.sampling_rate, format='WAV')  # Assuming samplerate exists
         audio_byte_arr.seek(0)

+
         return StreamingResponse(audio_byte_arr, media_type="audio/wav")

     except Exception as e:
+        import traceback
+        traceback.print_exc()  # Print detailed error for debugging
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

 @app.post("/generate-video")
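On the audio hunk: if audio_generator is a transformers text-to-speech / text-to-audio pipeline, its output is typically a dict with "audio" and "sampling_rate" keys rather than an object carrying a sampling_rate attribute, and soundfile can write that straight into the BytesIO buffer. A sketch under that assumption (the helper name is illustrative; check the actual pipeline's output format):

# Sketch only: turn a text-to-audio pipeline result into an in-memory WAV.
# Assumes pipeline_output looks like {"audio": ndarray, "sampling_rate": int}.
from io import BytesIO
import numpy as np
import soundfile as sf

def audio_to_wav_bytes(pipeline_output) -> BytesIO:
    # Some pipelines return shape (1, n); soundfile expects (n,) or (n, channels)
    data = np.asarray(pipeline_output["audio"]).squeeze()
    buf = BytesIO()
    sf.write(buf, data, pipeline_output["sampling_rate"], format="WAV")
    buf.seek(0)
    return buf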
@@ -190,12 +238,18 @@ async def generate_video(request: GenerateRequest):
         video = video_generator(validated_body.input_text)[0]

         video_byte_arr = BytesIO()
-        [removed line not preserved in this view]
+        # Assuming video_generator returns an object with a save method. Adjust based on actual object and format.
+        # Example for a hypothetical video object with save method to BytesIO as mp4
+        # video_generator output might vary, check documentation for the specific model/pipeline
+        video.save(video_byte_arr, format='MP4')  # Hypothetical save method, adjust based on actual video object
         video_byte_arr.seek(0)

+
         return StreamingResponse(video_byte_arr, media_type="video/mp4")

     except Exception as e:
+        import traceback
+        traceback.print_exc()  # Print detailed error for debugging
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

 if __name__ == "__main__":
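On the video hunk: the in-diff comments already flag video.save(..., format='MP4') as hypothetical. If video_generator is a diffusers text-to-video pipeline that returns a list of frames, one route (an assumption, not the repo's confirmed API) is diffusers' export_to_video, which writes to a file path, so the bytes pass through a temporary file before landing in the BytesIO buffer:

# Sketch only: encode a list of frames (PIL images or numpy arrays) to MP4
# bytes via a temporary file, since export_to_video targets a path.
import os
import tempfile
from io import BytesIO
from diffusers.utils import export_to_video

def frames_to_mp4_bytes(frames) -> BytesIO:
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        export_to_video(frames, tmp_path)
        with open(tmp_path, "rb") as f:
            buf = BytesIO(f.read())
    finally:
        os.remove(tmp_path)
    buf.seek(0)
    return buf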