Hjgugugjhuhjggg committed on
Commit 395fc4a · verified · 1 Parent(s): 6538615

Update app.py

Files changed (1)
  1. app.py +41 -73
app.py CHANGED
@@ -5,18 +5,19 @@ from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, field_validator
 from transformers import (
     AutoConfig,
-    pipeline,
     AutoModelForCausalLM,
     AutoTokenizer,
     GenerationConfig,
     StoppingCriteriaList,
     StoppingCriteria,
-    TextStreamer
+    TextStreamer,
+    pipeline
 )
 import uvicorn
 import asyncio
 from io import BytesIO
-from transformers import pipeline
+import soundfile as sf
+import traceback
 
 app = FastAPI()
 
@@ -26,7 +27,7 @@ class GenerateRequest(BaseModel):
     task_type: str
     temperature: float = 1.0
     max_new_tokens: int = 2
-    stream: bool = True  # Keep stream parameter in request for flexibility, but handle it correctly in code
+    stream: bool = True
     top_p: float = 1.0
     top_k: int = 50
     repetition_penalty: float = 1.0
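
Note: for reference, a client call against /generate with these request fields might look like the sketch below. The input_text and model_name fields come from the rest of app.py; the host, port and example model name are assumptions, not part of this commit.

import requests

# Hypothetical local deployment; host and port are assumptions.
resp = requests.post(
    "http://localhost:7860/generate",
    json={
        "model_name": "gpt2",              # example model, an assumption
        "input_text": "Hello, world",
        "task_type": "text-generation",
        "temperature": 0.7,
        "max_new_tokens": 64,
        "stream": True,
        "top_p": 0.9,
        "top_k": 50,
        "repetition_penalty": 1.1,
    },
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)
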
@@ -89,7 +90,7 @@ async def generate(request: GenerateRequest):
         task_type = request.task_type
         temperature = request.temperature
         max_new_tokens = request.max_new_tokens
-        stream = request.stream  # Get stream from request, but handle correctly
+        stream = request.stream
         top_p = request.top_p
         top_k = request.top_k
         repetition_penalty = request.repetition_penalty
@@ -117,71 +118,50 @@ async def generate(request: GenerateRequest):
         stop_token_ids = tokenizer.convert_tokens_to_ids(stop_sequences)
         stopping_criteria_list = StoppingCriteriaList([StopOnTokens(stop_token_ids)]) if stop_token_ids else None
 
-
-        if stream:  # Handle streaming based on request parameter
+        if stream:
             return StreamingResponse(
                 stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay),
                 media_type="text/plain"
             )
-        else:  # Handle non-streaming case
+        else:
             generated_text = generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device)
-            return StreamingResponse(iter([generated_text]), media_type="text/plain")  # Still use StreamingResponse for consistency in return type
-
+            return StreamingResponse(iter([generated_text]), media_type="text/plain")
 
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 async def stream_text(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, chunk_delay, max_length=2048):
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
-    streamer = TextStreamer(tokenizer)  # Use TextStreamer for proper streaming
-
-    with torch.no_grad():
-        model.generate(
-            **encoded_input,
-            generation_config=generation_config,
-            stopping_criteria=stopping_criteria_list,
-            streamer=streamer,  # Use streamer here instead of stream=True
-            return_dict_in_generate=True,
-            output_scores=True
-        )
-
-    # TextStreamer handles printing to stdout by default, but we want to stream to client
-    # We need to access the generated text from the streamer and yield it.
-    # TextStreamer is designed for terminal output, not direct access to tokens for streaming.
-    # We need to modify stream_text to correctly stream tokens.
-
-    encoded_input_len = encoded_input["input_ids"].shape[-1]
-    generated_tokens = []
-    for i, output in enumerate(model.generate(
-        **encoded_input,
-        generation_config=generation_config,
-        stopping_criteria=stopping_criteria_list,
-        stream=True,  # Keep stream=True for actual streaming from model
-        return_dict_in_generate=True,
-        output_scores=True,
-    ):
-        if i > 0:  # Skip the first output which is just input
-            new_tokens = output.sequences[:, encoded_input_len:]
-            for token_batch in new_tokens:
-                token = tokenizer.decode(token_batch, skip_special_tokens=True)
-                if token:
-                    yield token
-                    await asyncio.sleep(chunk_delay)
+    encoded_input_len = encoded_input["input_ids"].shape[-1]
+
+    for output in model.generate(
+        **encoded_input,
+        generation_config=generation_config,
+        stopping_criteria=stopping_criteria_list,
+        stream=True,
+        return_dict_in_generate=True,
+        output_scores=True,
+    ):
+        new_tokens = output.sequences[:, encoded_input_len:]
+        for token_batch in new_tokens:
+            token = tokenizer.decode(token_batch, skip_special_tokens=True)
+            if token:
+                yield token
+                await asyncio.sleep(chunk_delay)
 
 
 async def generate_non_stream(model, tokenizer, input_text, generation_config, stopping_criteria_list, device, max_length=2048):
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
 
-    with torch.no_grad():
-        output = model.generate(
-            **encoded_input,
-            generation_config=generation_config,
-            stopping_criteria=stopping_criteria_list,
-            return_dict_in_generate=True,
-            output_scores=True
-        )
-    generated_text = tokenizer.decode(output.sequences[0][encoded_input["input_ids"].shape[-1]:], skip_special_tokens=True)
-    return generated_text
+    output = model.generate(
+        **encoded_input,
+        generation_config=generation_config,
+        stopping_criteria=stopping_criteria_list,
+        return_dict_in_generate=True,
+        output_scores=True
+    )
+    generated_text = tokenizer.decode(output.sequences[0][encoded_input["input_ids"].shape[-1]:], skip_special_tokens=True)
+    return generated_text
 
 
 @app.post("/generate-image")
@@ -209,24 +189,17 @@ async def generate_text_to_speech(request: GenerateRequest):
         device = "cuda" if torch.cuda.is_available() else "cpu"
 
         audio_generator = pipeline("text-to-speech", model=validated_body.model_name, device=device)
-        audio = audio_generator(validated_body.input_text)[0]
+        audio = audio_generator(validated_body.input_text)
+        sampling_rate = audio_generator.sampling_rate
 
         audio_byte_arr = BytesIO()
-        # Assuming audio_generator returns an object with a save method. Adjust based on actual object.
-        # Example for a hypothetical audio object with save method to BytesIO
-        # audio_generator output might vary, check documentation for the specific model/pipeline
-        # If it's raw audio data, you might need to use a library like soundfile to write to BytesIO
-        # Example assuming `audio` is raw data and needs to be saved as wav
-        import soundfile as sf
-        sf.write(audio_byte_arr, audio, samplerate=audio_generator.sampling_rate, format='WAV')  # Assuming samplerate exists
+        sf.write(audio_byte_arr, audio, sampling_rate, format='WAV')
         audio_byte_arr.seek(0)
 
-
         return StreamingResponse(audio_byte_arr, media_type="audio/wav")
 
     except Exception as e:
-        import traceback
-        traceback.print_exc()  # Print detailed error for debugging
+        traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 @app.post("/generate-video")
@@ -235,21 +208,16 @@ async def generate_video(request: GenerateRequest):
         validated_body = request
         device = "cuda" if torch.cuda.is_available() else "cpu"
         video_generator = pipeline("text-to-video", model=validated_body.model_name, device=device)
-        video = video_generator(validated_body.input_text)[0]
+        video = video_generator(validated_body.input_text)
 
         video_byte_arr = BytesIO()
-        # Assuming video_generator returns an object with a save method. Adjust based on actual object and format.
-        # Example for a hypothetical video object with save method to BytesIO as mp4
-        # video_generator output might vary, check documentation for the specific model/pipeline
-        video.save(video_byte_arr, format='MP4')  # Hypothetical save method, adjust based on actual video object
+        video.save(video_byte_arr)
         video_byte_arr.seek(0)
 
-
         return StreamingResponse(video_byte_arr, media_type="video/mp4")
 
     except Exception as e:
-        import traceback
-        traceback.print_exc()  # Print detailed error for debugging
+        traceback.print_exc()
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 if __name__ == "__main__":
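
Finally, neither side of the /generate-video hunk is likely to run as-is: "text-to-video" does not appear to be a task the transformers pipeline factory recognizes, and the object it would return has no save() method. Text-to-video models are usually served through diffusers instead; a rough sketch of that route, with the model name and the exact shape of result.frames treated as assumptions that vary by diffusers version:

import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video

device = "cuda" if torch.cuda.is_available() else "cpu"
# Example checkpoint; any diffusers text-to-video model should follow the same pattern.
pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b").to(device)

result = pipe("a panda playing guitar", num_inference_steps=25)
frames = result.frames[0]   # first video in the batch; older diffusers versions return the frame list directly
# export_to_video writes an .mp4 to disk and returns its path; stream that file back to the client.
video_path = export_to_video(frames, "output.mp4")
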
 