Niansuh commited on
Commit
1d86e6d
·
verified ·
1 Parent(s): 6b3b956

Update api/utils.py

Browse files
Files changed (1) hide show
  1. api/utils.py +9 -11
api/utils.py CHANGED
@@ -1,24 +1,22 @@
1
  from fastapi import HTTPException
2
- from fastapi.responses import StreamingResponse
3
  from api.config import MODEL_PROVIDER_MAPPING
4
- from api.models import ChatRequest
5
  from api.provider import gizai
6
  from api.logger import setup_logger
7
 
8
  logger = setup_logger(__name__)
9
 
10
- async def process_streaming_response(request: ChatRequest):
11
- provider_name = MODEL_PROVIDER_MAPPING.get(request.model)
12
  if provider_name == 'gizai':
13
  # GizAI does not support streaming; process as non-streaming
14
- response = await gizai.process_non_streaming_response(request)
15
- return StreamingResponse(iter([json.dumps(response)]), media_type="application/json")
16
  else:
17
- raise HTTPException(status_code=400, detail=f"Model {request.model} is not supported for streaming.")
18
 
19
- async def process_non_streaming_response(request: ChatRequest):
20
- provider_name = MODEL_PROVIDER_MAPPING.get(request.model)
21
  if provider_name == 'gizai':
22
- return await gizai.process_non_streaming_response(request)
23
  else:
24
- raise HTTPException(status_code=400, detail=f"Model {request.model} is not supported.")
 
1
  from fastapi import HTTPException
 
2
  from api.config import MODEL_PROVIDER_MAPPING
 
3
  from api.provider import gizai
4
  from api.logger import setup_logger
5
 
6
  logger = setup_logger(__name__)
7
 
8
+ async def process_streaming_response(request_data):
9
+ provider_name = MODEL_PROVIDER_MAPPING.get(request_data.get('model'))
10
  if provider_name == 'gizai':
11
  # GizAI does not support streaming; process as non-streaming
12
+ response = await gizai.process_non_streaming_response(request_data)
13
+ return iter([json.dumps(response)])
14
  else:
15
+ raise HTTPException(status_code=400, detail=f"Model {request_data.get('model')} is not supported for streaming.")
16
 
17
+ async def process_non_streaming_response(request_data):
18
+ provider_name = MODEL_PROVIDER_MAPPING.get(request_data.get('model'))
19
  if provider_name == 'gizai':
20
+ return await gizai.process_non_streaming_response(request_data)
21
  else:
22
+ raise HTTPException(status_code=400, detail=f"Model {request_data.get('model')} is not supported.")