sachin
committed on
Commit
·
af923b7
1
Parent(s):
3f6f875
add-tts
Browse files
- src/server/main.py +4 -4
src/server/main.py
CHANGED
@@ -176,7 +176,7 @@ async def generate_audio(
|
|
176 |
response_format: Annotated[ResponseFormat, Body(include_in_schema=False)] = config.response_format,
|
177 |
speed: Annotated[float, Body(include_in_schema=False)] = SPEED,
|
178 |
) -> StreamingResponse:
|
179 |
-
tts, tokenizer, description_tokenizer =
|
180 |
if speed != SPEED:
|
181 |
logger.warning(
|
182 |
"Specifying speed isn't supported by this model. Audio will be generated with the default speed"
|
@@ -190,11 +190,11 @@ async def generate_audio(
|
|
190 |
desc_inputs = description_tokenizer(voice,
|
191 |
return_tensors="pt",
|
192 |
padding="max_length",
|
193 |
-
max_length=
|
194 |
prompt_inputs = tokenizer(input,
|
195 |
return_tensors="pt",
|
196 |
padding="max_length",
|
197 |
-
max_length=
|
198 |
|
199 |
# Use the tensor fields directly instead of BatchEncoding object
|
200 |
input_ids = desc_inputs["input_ids"]
|
@@ -262,7 +262,7 @@ async def generate_audio_batch(
|
|
262 |
response_format: Annotated[ResponseFormat, Body()] = config.response_format,
|
263 |
speed: Annotated[float, Body(include_in_schema=False)] = SPEED,
|
264 |
) -> StreamingResponse:
|
265 |
-
tts, tokenizer, description_tokenizer =
|
266 |
if speed != SPEED:
|
267 |
logger.warning(
|
268 |
"Specifying speed isn't supported by this model. Audio will be generated with the default speed"
|
|
|
176 |
response_format: Annotated[ResponseFormat, Body(include_in_schema=False)] = config.response_format,
|
177 |
speed: Annotated[float, Body(include_in_schema=False)] = SPEED,
|
178 |
) -> StreamingResponse:
|
179 |
+
tts, tokenizer, description_tokenizer = tts_model_manager.get_or_load_model(model)
|
180 |
if speed != SPEED:
|
181 |
logger.warning(
|
182 |
"Specifying speed isn't supported by this model. Audio will be generated with the default speed"
|
|
|
190 |
desc_inputs = description_tokenizer(voice,
|
191 |
return_tensors="pt",
|
192 |
padding="max_length",
|
193 |
+
max_length=tts_model_manager.max_length).to(device)
|
194 |
prompt_inputs = tokenizer(input,
|
195 |
return_tensors="pt",
|
196 |
padding="max_length",
|
197 |
+
max_length=tts_model_manager.max_length).to(device)
|
198 |
|
199 |
# Use the tensor fields directly instead of BatchEncoding object
|
200 |
input_ids = desc_inputs["input_ids"]
|
|
|
262 |
response_format: Annotated[ResponseFormat, Body()] = config.response_format,
|
263 |
speed: Annotated[float, Body(include_in_schema=False)] = SPEED,
|
264 |
) -> StreamingResponse:
|
265 |
+
tts, tokenizer, description_tokenizer = tts_model_manager.get_or_load_model(model)
|
266 |
if speed != SPEED:
|
267 |
logger.warning(
|
268 |
"Specifying speed isn't supported by this model. Audio will be generated with the default speed"
|