Update app.py
app.py
CHANGED
@@ -5,18 +5,19 @@ from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, field_validator
 from transformers import (
     AutoConfig,
-    pipeline,
     AutoModelForCausalLM,
     AutoTokenizer,
     GenerationConfig,
     StoppingCriteriaList,
     StoppingCriteria,
-    TextStreamer
+    TextStreamer,
+    pipeline
 )
 import uvicorn
 import asyncio
 from io import BytesIO
-
+import soundfile as sf
+import traceback
 
 app = FastAPI()
 
@@ -26,7 +27,7 @@ class GenerateRequest(BaseModel):
     task_type: str
     temperature: float = 1.0
     max_new_tokens: int = 2
-    stream: bool = True
+    stream: bool = True
     top_p: float = 1.0
     top_k: int = 50
     repetition_penalty: float = 1.0
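The `stream` field added above is what lets a caller choose between a chunked and a one-shot response. For illustration only, a hypothetical client call is sketched below; the host, port, and the `/generate` route are assumptions not visible in this diff, and the payload uses only fields defined on `GenerateRequest` or read elsewhere in app.py.

# Hypothetical client; URL and route are assumptions, field names come from GenerateRequest.
import requests

payload = {
    "model_name": "gpt2",      # model_name/input_text are read by the other endpoints in app.py
    "input_text": "Hello",
    "task_type": "text",
    "stream": True,            # the field introduced in this change
    "max_new_tokens": 32,
}
with requests.post("http://localhost:8000/generate", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)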
@@ -89,7 +90,7 @@ async def generate(request: GenerateRequest):
         task_type = request.task_type
         temperature = request.temperature
         max_new_tokens = request.max_new_tokens
-        stream = request.stream
+        stream = request.stream
         top_p = request.top_p
         top_k = request.top_k
         repetition_penalty = request.repetition_penalty
@@ -117,71 +118,50 @@ async def generate(request: GenerateRequest):
         stop_token_ids = tokenizer.convert_tokens_to_ids(stop_sequences)
         stopping_criteria_list = StoppingCriteriaList([StopOnTokens(stop_token_ids)]) if stop_token_ids else None
 
-
-        if stream: # Handle streaming based on request parameter
+        if stream:
             return StreamingResponse(
                 stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay),
                 media_type="text/plain"
             )
-        else:
+        else:
             generated_text = generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device)
-            return StreamingResponse(iter([generated_text]), media_type="text/plain")
-
+            return StreamingResponse(iter([generated_text]), media_type="text/plain")
 
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 async def stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, max_length=2048):
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    encoded_input_len = encoded_input["input_ids"].shape[-1]
-    generated_tokens = []
-    for i, output in enumerate(model.generate(
-        **encoded_input,
-        generation_config=generation_config,
-        stopping_criteria=stopping_criteria_list,
-        stream=True, # Keep stream=True for actual streaming from model
-        return_dict_in_generate=True,
-        output_scores=True,
-    ):
-        if i > 0: # Skip the first output which is just input
-            new_tokens = output.sequences[:, encoded_input_len:]
-            for token_batch in new_tokens:
-                token = tokenizer.decode(token_batch, skip_special_tokens=True)
-                if token:
-                    yield token
-                    await asyncio.sleep(chunk_delay)
+    encoded_input_len = encoded_input["input_ids"].shape[-1]
+
+    for output in model.generate(
+        **encoded_input,
+        generation_config=generation_config,
+        stopping_criteria=stopping_criteria_list,
+        stream=True,
+        return_dict_in_generate=True,
+        output_scores=True,
+    ):
+        new_tokens = output.sequences[:, encoded_input_len:]
+        for token_batch in new_tokens:
+            token = tokenizer.decode(token_batch, skip_special_tokens=True)
+            if token:
+                yield token
+                await asyncio.sleep(chunk_delay)
 
 
 async def generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, max_length=2048):
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
 
-
-
-
-
-
-
-
-
-
-    return generated_text
+    output = model.generate(
+        **encoded_input,
+        generation_config=generation_config,
+        stopping_criteria=stopping_criteria_list,
+        return_dict_in_generate=True,
+        output_scores=True
+    )
+    generated_text = tokenizer.decode(output.sequences[0][encoded_input["input_ids"].shape[-1]:], skip_special_tokens=True)
+    return generated_text
 
 
 @app.post("/generate-image")
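A caveat on the streaming loop above: stock `transformers` `model.generate()` does not take a `stream=True` keyword (streaming is normally driven through the `streamer` argument), so iterating over `generate()` as written relies on a custom or patched implementation. Below is a minimal alternative sketch using `TextIteratorStreamer` with generation in a background thread, built only from objects already present in this file; treat it as an assumption-laden illustration, not the committed behaviour.

# Sketch only: token streaming via TextIteratorStreamer instead of stream=True.
import asyncio
from threading import Thread
from transformers import TextIteratorStreamer

async def stream_text(model, tokenizer, input_text, generation_config, device, chunk_delay, max_length=2048):
    encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so it runs in a worker thread while the streamer is drained here.
    Thread(
        target=model.generate,
        kwargs={**encoded_input, "generation_config": generation_config, "streamer": streamer},
    ).start()
    for text_chunk in streamer:
        if text_chunk:
            yield text_chunk
            await asyncio.sleep(chunk_delay)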
@@ -209,24 +189,17 @@ async def generate_text_to_speech(request: GenerateRequest):
         device = "cuda" if torch.cuda.is_available() else "cpu"
 
         audio_generator = pipeline("text-to-speech", model=validated_body.model_name, device=device)
-        audio = audio_generator(validated_body.input_text)
+        audio = audio_generator(validated_body.input_text)
+        sampling_rate = audio_generator.sampling_rate
 
         audio_byte_arr = BytesIO()
-
-        # Example for a hypothetical audio object with save method to BytesIO
-        # audio_generator output might vary, check documentation for the specific model/pipeline
-        # If it's raw audio data, you might need to use a library like soundfile to write to BytesIO
-        # Example assuming `audio` is raw data and needs to be saved as wav
-        import soundfile as sf
-        sf.write(audio_byte_arr, audio, samplerate=audio_generator.sampling_rate, format='WAV') # Assuming samplerate exists
+        sf.write(audio_byte_arr, audio, sampling_rate, format='WAV')
         audio_byte_arr.seek(0)
 
-
         return StreamingResponse(audio_byte_arr, media_type="audio/wav")
 
     except Exception as e:
-
-        traceback.print_exc() # Print detailed error for debugging
+        traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 @app.post("/generate-video")
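One note on the audio path: the `text-to-speech` pipeline call typically returns a dict with `audio` and `sampling_rate` keys, so the sampling rate can also be taken from the returned dict rather than from the pipeline object. A small sketch under that assumption (the exact output shape depends on the model, so verify before relying on it):

# Sketch only: assumes the TTS pipeline returns {"audio": ndarray, "sampling_rate": int};
# check the specific model's documentation before relying on these keys.
from io import BytesIO
import soundfile as sf

result = audio_generator(validated_body.input_text)
audio_byte_arr = BytesIO()
sf.write(audio_byte_arr, result["audio"].squeeze(), result["sampling_rate"], format="WAV")
audio_byte_arr.seek(0)  # rewind before handing the buffer to StreamingResponse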
@@ -235,21 +208,16 @@ async def generate_video(request: GenerateRequest):
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
         video_generator = pipeline("text-to-video", model=validated_body.model_name, device=device)
-        video = video_generator(validated_body.input_text)
+        video = video_generator(validated_body.input_text)
 
         video_byte_arr = BytesIO()
-
-        # Example for a hypothetical video object with save method to BytesIO as mp4
-        # video_generator output might vary, check documentation for the specific model/pipeline
-        video.save(video_byte_arr, format='MP4') # Hypothetical save method, adjust based on actual video object
+        video.save(video_byte_arr)
         video_byte_arr.seek(0)
 
-
         return StreamingResponse(video_byte_arr, media_type="video/mp4")
 
     except Exception as e:
-
-        traceback.print_exc() # Print detailed error for debugging
+        traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 if __name__ == "__main__":
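`text-to-video` is not a task the `transformers` `pipeline()` ships with, and `.save()` on the result is hypothetical, as the comment in the removed code already noted; text-to-video models are more commonly run through `diffusers`. A rough sketch under that assumption (the checkpoint name and the `export_to_video` helper are illustrative, not part of this commit):

# Sketch only: text-to-video via diffusers, exported to MP4 bytes for StreamingResponse.
from io import BytesIO
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b").to(device)

frames = pipe(validated_body.input_text).frames[0]     # frames of the first clip; exact shape varies by diffusers version
video_path = export_to_video(frames, "generated.mp4")  # writes an MP4 file and returns its path

with open(video_path, "rb") as f:
    video_byte_arr = BytesIO(f.read())
video_byte_arr.seek(0)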