Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
sachin
committed on
Commit
·
3fa9edb
1
Parent(s):
20c50d1
update-
Browse files- src/server/main.py +48 -52
src/server/main.py
CHANGED
@@ -68,37 +68,6 @@ async def get_user_id_for_rate_limit(request: Request):
|
|
68 |
limiter = Limiter(key_func=get_user_id_for_rate_limit)
|
69 |
|
70 |
# Request/Response Models
|
71 |
-
class SpeechRequest(BaseModel):
|
72 |
-
input: str = Field(..., description="Text to convert to speech (max 1000 characters)")
|
73 |
-
voice: str = Field(..., description="Voice identifier for the TTS service")
|
74 |
-
model: str = Field(..., description="TTS model to use")
|
75 |
-
response_format: ResponseFormat = Field(tts_config.response_format, description="Audio format: mp3, flac, or wav")
|
76 |
-
speed: float = Field(SPEED, description="Speech speed (default: 1.0)")
|
77 |
-
|
78 |
-
@field_validator("input")
|
79 |
-
def input_must_be_valid(cls, v):
|
80 |
-
if len(v) > 1000:
|
81 |
-
raise ValueError("Input cannot exceed 1000 characters")
|
82 |
-
return v.strip()
|
83 |
-
|
84 |
-
@field_validator("response_format")
|
85 |
-
def validate_response_format(cls, v):
|
86 |
-
supported_formats = [ResponseFormat.MP3, ResponseFormat.FLAC, ResponseFormat.WAV]
|
87 |
-
if v not in supported_formats:
|
88 |
-
raise ValueError(f"Response format must be one of {[fmt.value for fmt in supported_formats]}")
|
89 |
-
return v
|
90 |
-
|
91 |
-
class Config:
|
92 |
-
schema_extra = {
|
93 |
-
"example": {
|
94 |
-
"input": "Hello, how are you?",
|
95 |
-
"voice": "female-1",
|
96 |
-
"model": "tts-model-1",
|
97 |
-
"response_format": "mp3",
|
98 |
-
"speed": 1.0
|
99 |
-
}
|
100 |
-
}
|
101 |
-
|
102 |
class TranscriptionResponse(BaseModel):
|
103 |
text: str = Field(..., description="Transcribed text from the audio")
|
104 |
|
@@ -120,8 +89,9 @@ class AudioProcessingResponse(BaseModel):
|
|
120 |
class ChatRequest(BaseModel):
|
121 |
prompt: str = Field(..., description="Base64-encoded encrypted prompt (max 1000 characters after decryption)")
|
122 |
src_lang: str = Field(..., description="Base64-encoded encrypted source language code")
|
|
|
123 |
|
124 |
-
@field_validator("prompt", "src_lang")
|
125 |
def must_be_valid_base64(cls, v):
|
126 |
try:
|
127 |
base64.b64decode(v)
|
@@ -133,7 +103,8 @@ class ChatRequest(BaseModel):
|
|
133 |
schema_extra = {
|
134 |
"example": {
|
135 |
"prompt": "base64_encoded_encrypted_prompt",
|
136 |
-
"src_lang": "base64_encoded_encrypted_kan_Knda"
|
|
|
137 |
}
|
138 |
}
|
139 |
|
@@ -213,16 +184,18 @@ class ExternalTTSService(TTSService):
|
|
213 |
async def generate_speech(self, payload: dict) -> requests.Response:
|
214 |
try:
|
215 |
return requests.post(
|
216 |
-
settings.external_tts_url,
|
217 |
json=payload,
|
218 |
-
headers={"accept": "
|
219 |
stream=True,
|
220 |
timeout=60
|
221 |
)
|
222 |
except requests.Timeout:
|
|
|
223 |
raise HTTPException(status_code=504, detail="External TTS API timeout")
|
224 |
except requests.RequestException as e:
|
225 |
-
|
|
|
226 |
|
227 |
def get_tts_service() -> TTSService:
|
228 |
return ExternalTTSService()
|
@@ -310,53 +283,68 @@ async def app_register_user(
|
|
310 |
|
311 |
@app.post("/v1/audio/speech",
|
312 |
summary="Generate Speech from Text",
|
313 |
-
description="Convert text to speech
|
314 |
tags=["Audio"],
|
315 |
responses={
|
316 |
200: {"description": "Audio stream", "content": {"audio/mp3": {"example": "Binary audio data"}}},
|
317 |
-
400: {"description": "Invalid input"},
|
318 |
401: {"description": "Unauthorized - Token required"},
|
319 |
429: {"description": "Rate limit exceeded"},
|
|
|
320 |
504: {"description": "TTS service timeout"}
|
321 |
})
|
322 |
@limiter.limit(settings.speech_rate_limit)
|
323 |
async def generate_audio(
|
324 |
request: Request,
|
325 |
-
|
|
|
326 |
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
|
|
|
327 |
tts_service: TTSService = Depends(get_tts_service)
|
328 |
):
|
329 |
user_id = await get_current_user(credentials)
|
330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
331 |
raise HTTPException(status_code=400, detail="Input cannot be empty")
|
|
|
|
|
332 |
|
333 |
logger.info("Processing speech request", extra={
|
334 |
"endpoint": "/v1/audio/speech",
|
335 |
-
"input_length": len(
|
336 |
"client_ip": get_remote_address(request),
|
337 |
"user_id": user_id
|
338 |
})
|
339 |
|
340 |
payload = {
|
341 |
-
"
|
342 |
-
"voice": speech_request.voice,
|
343 |
-
"model": speech_request.model,
|
344 |
-
"response_format": speech_request.response_format.value,
|
345 |
-
"speed": speech_request.speed
|
346 |
}
|
347 |
|
348 |
-
|
349 |
-
|
|
|
|
|
|
|
|
|
350 |
|
351 |
headers = {
|
352 |
-
"Content-Disposition":
|
353 |
"Cache-Control": "no-cache",
|
354 |
-
"Content-Type":
|
355 |
}
|
356 |
|
357 |
return StreamingResponse(
|
358 |
response.iter_content(chunk_size=8192),
|
359 |
-
media_type=
|
360 |
headers=headers
|
361 |
)
|
362 |
|
@@ -398,6 +386,14 @@ async def chat(
|
|
398 |
logger.error(f"Source language decryption failed: {str(e)}")
|
399 |
raise HTTPException(status_code=400, detail="Invalid encrypted source language")
|
400 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
401 |
if not decrypted_prompt:
|
402 |
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
403 |
if len(decrypted_prompt) > 1000:
|
@@ -410,7 +406,7 @@ async def chat(
|
|
410 |
payload = {
|
411 |
"prompt": decrypted_prompt,
|
412 |
"src_lang": decrypted_src_lang,
|
413 |
-
"tgt_lang":
|
414 |
}
|
415 |
|
416 |
response = requests.post(
|
|
|
68 |
limiter = Limiter(key_func=get_user_id_for_rate_limit)
|
69 |
|
70 |
# Request/Response Models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
class TranscriptionResponse(BaseModel):
|
72 |
text: str = Field(..., description="Transcribed text from the audio")
|
73 |
|
|
|
89 |
class ChatRequest(BaseModel):
|
90 |
prompt: str = Field(..., description="Base64-encoded encrypted prompt (max 1000 characters after decryption)")
|
91 |
src_lang: str = Field(..., description="Base64-encoded encrypted source language code")
|
92 |
+
tgt_lang: str = Field(..., description="Base64-encoded encrypted target language code")
|
93 |
|
94 |
+
@field_validator("prompt", "src_lang", "tgt_lang")
|
95 |
def must_be_valid_base64(cls, v):
|
96 |
try:
|
97 |
base64.b64decode(v)
|
|
|
103 |
schema_extra = {
|
104 |
"example": {
|
105 |
"prompt": "base64_encoded_encrypted_prompt",
|
106 |
+
"src_lang": "base64_encoded_encrypted_kan_Knda",
|
107 |
+
"tgt_lang": "base64_encoded_encrypted_kan_Knda"
|
108 |
}
|
109 |
}
|
110 |
|
|
|
184 |
async def generate_speech(self, payload: dict) -> requests.Response:
|
185 |
try:
|
186 |
return requests.post(
|
187 |
+
f"{settings.external_tts_url}/audio/speech",
|
188 |
json=payload,
|
189 |
+
headers={"accept": "*/*", "Content-Type": "application/json"},
|
190 |
stream=True,
|
191 |
timeout=60
|
192 |
)
|
193 |
except requests.Timeout:
|
194 |
+
logger.error("External TTS API timeout")
|
195 |
raise HTTPException(status_code=504, detail="External TTS API timeout")
|
196 |
except requests.RequestException as e:
|
197 |
+
logger.error(f"External TTS API error: {str(e)}")
|
198 |
+
raise HTTPException(status_code=502, detail=f"External TTS service error: {str(e)}")
|
199 |
|
200 |
def get_tts_service() -> TTSService:
|
201 |
return ExternalTTSService()
|
|
|
283 |
|
284 |
@app.post("/v1/audio/speech",
|
285 |
summary="Generate Speech from Text",
|
286 |
+
description="Convert encrypted text to speech using an external TTS service. Rate limited to 5 requests per minute per user. Requires authentication and X-Session-Key header.",
|
287 |
tags=["Audio"],
|
288 |
responses={
|
289 |
200: {"description": "Audio stream", "content": {"audio/mp3": {"example": "Binary audio data"}}},
|
290 |
+
400: {"description": "Invalid or empty input"},
|
291 |
401: {"description": "Unauthorized - Token required"},
|
292 |
429: {"description": "Rate limit exceeded"},
|
293 |
+
502: {"description": "External TTS service unavailable"},
|
294 |
504: {"description": "TTS service timeout"}
|
295 |
})
|
296 |
@limiter.limit(settings.speech_rate_limit)
|
297 |
async def generate_audio(
|
298 |
request: Request,
|
299 |
+
input: str = Query(..., description="Base64-encoded encrypted text to convert to speech (max 1000 characters after decryption)"),
|
300 |
+
response_format: str = Query("mp3", description="Audio format (ignored, defaults to mp3 for external API)"),
|
301 |
credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
|
302 |
+
x_session_key: str = Header(..., alias="X-Session-Key"),
|
303 |
tts_service: TTSService = Depends(get_tts_service)
|
304 |
):
|
305 |
user_id = await get_current_user(credentials)
|
306 |
+
session_key = base64.b64decode(x_session_key)
|
307 |
+
|
308 |
+
# Decrypt input
|
309 |
+
try:
|
310 |
+
encrypted_input = base64.b64decode(input)
|
311 |
+
decrypted_input = decrypt_data(encrypted_input, session_key).decode("utf-8")
|
312 |
+
except Exception as e:
|
313 |
+
logger.error(f"Input decryption failed: {str(e)}")
|
314 |
+
raise HTTPException(status_code=400, detail="Invalid encrypted input")
|
315 |
+
|
316 |
+
if not decrypted_input.strip():
|
317 |
raise HTTPException(status_code=400, detail="Input cannot be empty")
|
318 |
+
if len(decrypted_input) > 1000:
|
319 |
+
raise HTTPException(status_code=400, detail="Decrypted input cannot exceed 1000 characters")
|
320 |
|
321 |
logger.info("Processing speech request", extra={
|
322 |
"endpoint": "/v1/audio/speech",
|
323 |
+
"input_length": len(decrypted_input),
|
324 |
"client_ip": get_remote_address(request),
|
325 |
"user_id": user_id
|
326 |
})
|
327 |
|
328 |
payload = {
|
329 |
+
"text": decrypted_input
|
|
|
|
|
|
|
|
|
330 |
}
|
331 |
|
332 |
+
try:
|
333 |
+
response = await tts_service.generate_speech(payload)
|
334 |
+
response.raise_for_status()
|
335 |
+
except requests.HTTPError as e:
|
336 |
+
logger.error(f"External TTS request failed: {str(e)}")
|
337 |
+
raise HTTPException(status_code=502, detail=f"External TTS service error: {str(e)}")
|
338 |
|
339 |
headers = {
|
340 |
+
"Content-Disposition": "inline; filename=\"speech.mp3\"",
|
341 |
"Cache-Control": "no-cache",
|
342 |
+
"Content-Type": "audio/mp3"
|
343 |
}
|
344 |
|
345 |
return StreamingResponse(
|
346 |
response.iter_content(chunk_size=8192),
|
347 |
+
media_type="audio/mp3",
|
348 |
headers=headers
|
349 |
)
|
350 |
|
|
|
386 |
logger.error(f"Source language decryption failed: {str(e)}")
|
387 |
raise HTTPException(status_code=400, detail="Invalid encrypted source language")
|
388 |
|
389 |
+
# Decrypt the target language
|
390 |
+
try:
|
391 |
+
encrypted_tgt_lang = base64.b64decode(chat_request.tgt_lang)
|
392 |
+
decrypted_tgt_lang = decrypt_data(encrypted_tgt_lang, session_key).decode("utf-8")
|
393 |
+
except Exception as e:
|
394 |
+
logger.error(f"Target language decryption failed: {str(e)}")
|
395 |
+
raise HTTPException(status_code=400, detail="Invalid encrypted target language")
|
396 |
+
|
397 |
if not decrypted_prompt:
|
398 |
raise HTTPException(status_code=400, detail="Prompt cannot be empty")
|
399 |
if len(decrypted_prompt) > 1000:
|
|
|
406 |
payload = {
|
407 |
"prompt": decrypted_prompt,
|
408 |
"src_lang": decrypted_src_lang,
|
409 |
+
"tgt_lang": decrypted_tgt_lang
|
410 |
}
|
411 |
|
412 |
response = requests.post(
|