sachin committed on
Commit
3fa9edb
·
1 Parent(s): 20c50d1
Files changed (1) hide show
  1. src/server/main.py +48 -52
src/server/main.py CHANGED
@@ -68,37 +68,6 @@ async def get_user_id_for_rate_limit(request: Request):
68
  limiter = Limiter(key_func=get_user_id_for_rate_limit)
69
 
70
  # Request/Response Models
71
- class SpeechRequest(BaseModel):
72
- input: str = Field(..., description="Text to convert to speech (max 1000 characters)")
73
- voice: str = Field(..., description="Voice identifier for the TTS service")
74
- model: str = Field(..., description="TTS model to use")
75
- response_format: ResponseFormat = Field(tts_config.response_format, description="Audio format: mp3, flac, or wav")
76
- speed: float = Field(SPEED, description="Speech speed (default: 1.0)")
77
-
78
- @field_validator("input")
79
- def input_must_be_valid(cls, v):
80
- if len(v) > 1000:
81
- raise ValueError("Input cannot exceed 1000 characters")
82
- return v.strip()
83
-
84
- @field_validator("response_format")
85
- def validate_response_format(cls, v):
86
- supported_formats = [ResponseFormat.MP3, ResponseFormat.FLAC, ResponseFormat.WAV]
87
- if v not in supported_formats:
88
- raise ValueError(f"Response format must be one of {[fmt.value for fmt in supported_formats]}")
89
- return v
90
-
91
- class Config:
92
- schema_extra = {
93
- "example": {
94
- "input": "Hello, how are you?",
95
- "voice": "female-1",
96
- "model": "tts-model-1",
97
- "response_format": "mp3",
98
- "speed": 1.0
99
- }
100
- }
101
-
102
  class TranscriptionResponse(BaseModel):
103
  text: str = Field(..., description="Transcribed text from the audio")
104
 
@@ -120,8 +89,9 @@ class AudioProcessingResponse(BaseModel):
120
  class ChatRequest(BaseModel):
121
  prompt: str = Field(..., description="Base64-encoded encrypted prompt (max 1000 characters after decryption)")
122
  src_lang: str = Field(..., description="Base64-encoded encrypted source language code")
 
123
 
124
- @field_validator("prompt", "src_lang")
125
  def must_be_valid_base64(cls, v):
126
  try:
127
  base64.b64decode(v)
@@ -133,7 +103,8 @@ class ChatRequest(BaseModel):
133
  schema_extra = {
134
  "example": {
135
  "prompt": "base64_encoded_encrypted_prompt",
136
- "src_lang": "base64_encoded_encrypted_kan_Knda"
 
137
  }
138
  }
139
 
@@ -213,16 +184,18 @@ class ExternalTTSService(TTSService):
213
  async def generate_speech(self, payload: dict) -> requests.Response:
214
  try:
215
  return requests.post(
216
- settings.external_tts_url,
217
  json=payload,
218
- headers={"accept": "application/json", "Content-Type": "application/json"},
219
  stream=True,
220
  timeout=60
221
  )
222
  except requests.Timeout:
 
223
  raise HTTPException(status_code=504, detail="External TTS API timeout")
224
  except requests.RequestException as e:
225
- raise HTTPException(status_code=500, detail=f"External TTS API error: {str(e)}")
 
226
 
227
  def get_tts_service() -> TTSService:
228
  return ExternalTTSService()
@@ -310,53 +283,68 @@ async def app_register_user(
310
 
311
  @app.post("/v1/audio/speech",
312
  summary="Generate Speech from Text",
313
- description="Convert text to speech in the specified format using an external TTS service. Rate limited to 5 requests per minute per user. Requires authentication.",
314
  tags=["Audio"],
315
  responses={
316
  200: {"description": "Audio stream", "content": {"audio/mp3": {"example": "Binary audio data"}}},
317
- 400: {"description": "Invalid input"},
318
  401: {"description": "Unauthorized - Token required"},
319
  429: {"description": "Rate limit exceeded"},
 
320
  504: {"description": "TTS service timeout"}
321
  })
322
  @limiter.limit(settings.speech_rate_limit)
323
  async def generate_audio(
324
  request: Request,
325
- speech_request: SpeechRequest = Depends(),
 
326
  credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
 
327
  tts_service: TTSService = Depends(get_tts_service)
328
  ):
329
  user_id = await get_current_user(credentials)
330
- if not speech_request.input.strip():
 
 
 
 
 
 
 
 
 
 
331
  raise HTTPException(status_code=400, detail="Input cannot be empty")
 
 
332
 
333
  logger.info("Processing speech request", extra={
334
  "endpoint": "/v1/audio/speech",
335
- "input_length": len(speech_request.input),
336
  "client_ip": get_remote_address(request),
337
  "user_id": user_id
338
  })
339
 
340
  payload = {
341
- "input": speech_request.input,
342
- "voice": speech_request.voice,
343
- "model": speech_request.model,
344
- "response_format": speech_request.response_format.value,
345
- "speed": speech_request.speed
346
  }
347
 
348
- response = await tts_service.generate_speech(payload)
349
- response.raise_for_status()
 
 
 
 
350
 
351
  headers = {
352
- "Content-Disposition": f"inline; filename=\"speech.{speech_request.response_format.value}\"",
353
  "Cache-Control": "no-cache",
354
- "Content-Type": f"audio/{speech_request.response_format.value}"
355
  }
356
 
357
  return StreamingResponse(
358
  response.iter_content(chunk_size=8192),
359
- media_type=f"audio/{speech_request.response_format.value}",
360
  headers=headers
361
  )
362
 
@@ -398,6 +386,14 @@ async def chat(
398
  logger.error(f"Source language decryption failed: {str(e)}")
399
  raise HTTPException(status_code=400, detail="Invalid encrypted source language")
400
 
 
 
 
 
 
 
 
 
401
  if not decrypted_prompt:
402
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
403
  if len(decrypted_prompt) > 1000:
@@ -410,7 +406,7 @@ async def chat(
410
  payload = {
411
  "prompt": decrypted_prompt,
412
  "src_lang": decrypted_src_lang,
413
- "tgt_lang": decrypted_src_lang
414
  }
415
 
416
  response = requests.post(
 
68
  limiter = Limiter(key_func=get_user_id_for_rate_limit)
69
 
70
  # Request/Response Models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  class TranscriptionResponse(BaseModel):
72
  text: str = Field(..., description="Transcribed text from the audio")
73
 
 
89
  class ChatRequest(BaseModel):
90
  prompt: str = Field(..., description="Base64-encoded encrypted prompt (max 1000 characters after decryption)")
91
  src_lang: str = Field(..., description="Base64-encoded encrypted source language code")
92
+ tgt_lang: str = Field(..., description="Base64-encoded encrypted target language code")
93
 
94
+ @field_validator("prompt", "src_lang", "tgt_lang")
95
  def must_be_valid_base64(cls, v):
96
  try:
97
  base64.b64decode(v)
 
103
  schema_extra = {
104
  "example": {
105
  "prompt": "base64_encoded_encrypted_prompt",
106
+ "src_lang": "base64_encoded_encrypted_kan_Knda",
107
+ "tgt_lang": "base64_encoded_encrypted_kan_Knda"
108
  }
109
  }
110
 
 
184
  async def generate_speech(self, payload: dict) -> requests.Response:
185
  try:
186
  return requests.post(
187
+ f"{settings.external_tts_url}/audio/speech",
188
  json=payload,
189
+ headers={"accept": "*/*", "Content-Type": "application/json"},
190
  stream=True,
191
  timeout=60
192
  )
193
  except requests.Timeout:
194
+ logger.error("External TTS API timeout")
195
  raise HTTPException(status_code=504, detail="External TTS API timeout")
196
  except requests.RequestException as e:
197
+ logger.error(f"External TTS API error: {str(e)}")
198
+ raise HTTPException(status_code=502, detail=f"External TTS service error: {str(e)}")
199
 
200
  def get_tts_service() -> TTSService:
201
  return ExternalTTSService()
 
283
 
284
  @app.post("/v1/audio/speech",
285
  summary="Generate Speech from Text",
286
+ description="Convert encrypted text to speech using an external TTS service. Rate limited to 5 requests per minute per user. Requires authentication and X-Session-Key header.",
287
  tags=["Audio"],
288
  responses={
289
  200: {"description": "Audio stream", "content": {"audio/mp3": {"example": "Binary audio data"}}},
290
+ 400: {"description": "Invalid or empty input"},
291
  401: {"description": "Unauthorized - Token required"},
292
  429: {"description": "Rate limit exceeded"},
293
+ 502: {"description": "External TTS service unavailable"},
294
  504: {"description": "TTS service timeout"}
295
  })
296
  @limiter.limit(settings.speech_rate_limit)
297
  async def generate_audio(
298
  request: Request,
299
+ input: str = Query(..., description="Base64-encoded encrypted text to convert to speech (max 1000 characters after decryption)"),
300
+ response_format: str = Query("mp3", description="Audio format (ignored, defaults to mp3 for external API)"),
301
  credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
302
+ x_session_key: str = Header(..., alias="X-Session-Key"),
303
  tts_service: TTSService = Depends(get_tts_service)
304
  ):
305
  user_id = await get_current_user(credentials)
306
+ session_key = base64.b64decode(x_session_key)
307
+
308
+ # Decrypt input
309
+ try:
310
+ encrypted_input = base64.b64decode(input)
311
+ decrypted_input = decrypt_data(encrypted_input, session_key).decode("utf-8")
312
+ except Exception as e:
313
+ logger.error(f"Input decryption failed: {str(e)}")
314
+ raise HTTPException(status_code=400, detail="Invalid encrypted input")
315
+
316
+ if not decrypted_input.strip():
317
  raise HTTPException(status_code=400, detail="Input cannot be empty")
318
+ if len(decrypted_input) > 1000:
319
+ raise HTTPException(status_code=400, detail="Decrypted input cannot exceed 1000 characters")
320
 
321
  logger.info("Processing speech request", extra={
322
  "endpoint": "/v1/audio/speech",
323
+ "input_length": len(decrypted_input),
324
  "client_ip": get_remote_address(request),
325
  "user_id": user_id
326
  })
327
 
328
  payload = {
329
+ "text": decrypted_input
 
 
 
 
330
  }
331
 
332
+ try:
333
+ response = await tts_service.generate_speech(payload)
334
+ response.raise_for_status()
335
+ except requests.HTTPError as e:
336
+ logger.error(f"External TTS request failed: {str(e)}")
337
+ raise HTTPException(status_code=502, detail=f"External TTS service error: {str(e)}")
338
 
339
  headers = {
340
+ "Content-Disposition": "inline; filename=\"speech.mp3\"",
341
  "Cache-Control": "no-cache",
342
+ "Content-Type": "audio/mp3"
343
  }
344
 
345
  return StreamingResponse(
346
  response.iter_content(chunk_size=8192),
347
+ media_type="audio/mp3",
348
  headers=headers
349
  )
350
 
 
386
  logger.error(f"Source language decryption failed: {str(e)}")
387
  raise HTTPException(status_code=400, detail="Invalid encrypted source language")
388
 
389
+ # Decrypt the target language
390
+ try:
391
+ encrypted_tgt_lang = base64.b64decode(chat_request.tgt_lang)
392
+ decrypted_tgt_lang = decrypt_data(encrypted_tgt_lang, session_key).decode("utf-8")
393
+ except Exception as e:
394
+ logger.error(f"Target language decryption failed: {str(e)}")
395
+ raise HTTPException(status_code=400, detail="Invalid encrypted target language")
396
+
397
  if not decrypted_prompt:
398
  raise HTTPException(status_code=400, detail="Prompt cannot be empty")
399
  if len(decrypted_prompt) > 1000:
 
406
  payload = {
407
  "prompt": decrypted_prompt,
408
  "src_lang": decrypted_src_lang,
409
+ "tgt_lang": decrypted_tgt_lang
410
  }
411
 
412
  response = requests.post(