AurelioAguirre committed on
Commit
b5d6152
·
1 Parent(s): f876b72

added OpenAI schema-based endpoint and response v3

Browse files
Files changed (2) hide show
  1. main/main.py +5 -3
  2. main/routes.py +43 -40
main/main.py CHANGED
@@ -33,10 +33,12 @@ async def async_main():
33
  config = load_config()
34
  server_config = config.get('server', {})
35
 
36
- # Initialize API with config and await setup
37
  api = InferenceApi(config)
 
38
  await api.setup()
39
- await init_router(config)
 
40
 
41
  # Create LitServer instance with config
42
  server = ls.LitServer(
@@ -61,7 +63,7 @@ async def async_main():
61
 
62
  # Get configured port
63
  port = server_config.get('port', 8001)
64
- host = server_config.get('host', 'localhost')
65
 
66
  # Run server
67
  server.run(host=host, port=port)
 
33
  config = load_config()
34
  server_config = config.get('server', {})
35
 
36
+ # Initialize API with config
37
  api = InferenceApi(config)
38
+ # Setup API first
39
  await api.setup()
40
+ # Initialize router with the already setup API instance
41
+ await init_router(api)
42
 
43
  # Create LitServer instance with config
44
  server = ls.LitServer(
 
63
 
64
  # Get configured port
65
  port = server_config.get('port', 8001)
66
+ host = server_config.get('host', '0.0.0.0')
67
 
68
  # Run server
69
  server.run(host=host, port=port)
main/routes.py CHANGED
@@ -1,5 +1,9 @@
1
  from fastapi import APIRouter, HTTPException
 
2
  from typing import Optional
 
 
 
3
  from .api import InferenceApi
4
  from .schemas import (
5
  GenerateRequest,
@@ -10,12 +14,49 @@ from .schemas import (
10
  ChatCompletionRequest,
11
  ChatCompletionResponse
12
  )
13
- import logging
14
 
15
  router = APIRouter()
16
  logger = logging.getLogger(__name__)
17
  api = None
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  @router.post("/v1/chat/completions")
21
  async def create_chat_completion(request: ChatCompletionRequest):
@@ -72,44 +113,6 @@ async def create_chat_completion(request: ChatCompletionRequest):
72
  logger.error(f"Error in chat completion endpoint: {str(e)}")
73
  raise HTTPException(status_code=500, detail=str(e))
74
 
75
-
76
- async def init_router(config: dict):
77
- """Initialize router with config and Inference API instance"""
78
- global api
79
- api = InferenceApi(config)
80
- await api.setup()
81
- logger.info("Router initialized with Inference API instance")
82
-
83
- @router.post("/generate")
84
- async def generate_text(request: GenerateRequest):
85
- """Generate text response from prompt"""
86
- logger.info(f"Received generation request for prompt: {request.prompt[:50]}...")
87
- try:
88
- response = await api.generate_response(
89
- prompt=request.prompt,
90
- system_message=request.system_message,
91
- max_new_tokens=request.max_new_tokens
92
- )
93
- logger.info("Successfully generated response")
94
- return {"generated_text": response}
95
- except Exception as e:
96
- logger.error(f"Error in generate_text endpoint: {str(e)}")
97
- raise HTTPException(status_code=500, detail=str(e))
98
-
99
- @router.post("/generate/stream")
100
- async def generate_stream(request: GenerateRequest):
101
- """Generate streaming text response from prompt"""
102
- logger.info(f"Received streaming generation request for prompt: {request.prompt[:50]}...")
103
- try:
104
- return api.generate_stream(
105
- prompt=request.prompt,
106
- system_message=request.system_message,
107
- max_new_tokens=request.max_new_tokens
108
- )
109
- except Exception as e:
110
- logger.error(f"Error in generate_stream endpoint: {str(e)}")
111
- raise HTTPException(status_code=500, detail=str(e))
112
-
113
  @router.post("/embedding", response_model=EmbeddingResponse)
114
  async def generate_embedding(request: EmbeddingRequest):
115
  """Generate embedding vector from text"""
@@ -175,4 +178,4 @@ async def initialize_embedding_model(model_name: Optional[str] = None):
175
  async def shutdown_event():
176
  """Clean up resources on shutdown"""
177
  if api:
178
- await api.close()
 
1
  from fastapi import APIRouter, HTTPException
2
+ from fastapi.responses import StreamingResponse
3
  from typing import Optional
4
+ import json
5
+ from time import time
6
+ import logging
7
  from .api import InferenceApi
8
  from .schemas import (
9
  GenerateRequest,
 
14
  ChatCompletionRequest,
15
  ChatCompletionResponse
16
  )
 
17
 
18
  router = APIRouter()
19
  logger = logging.getLogger(__name__)
20
  api = None
21
 
22
async def init_router(inference_api: InferenceApi):
    """Publish a pre-initialized InferenceApi instance to the route handlers.

    The caller is responsible for having awaited the instance's setup; this
    function only stores it in the module-level ``api`` global that every
    endpoint reads.
    """
    global api
    api = inference_api
    logger.info("Router initialized with Inference API instance")
27
+
28
@router.post("/generate")
async def generate_text(request: GenerateRequest):
    """Produce a single, non-streaming completion for ``request.prompt``.

    Returns a JSON object of the form ``{"generated_text": <str>}``.
    Any failure from the inference backend is surfaced as HTTP 500 with the
    exception text as the detail.
    """
    # Log only a prefix of the prompt so log lines stay bounded.
    logger.info(f"Received generation request for prompt: {request.prompt[:50]}...")
    try:
        generated = await api.generate_response(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens,
        )
        logger.info("Successfully generated response")
        return {"generated_text": generated}
    except Exception as exc:
        logger.error(f"Error in generate_text endpoint: {str(exc)}")
        raise HTTPException(status_code=500, detail=str(exc))
43
+
44
@router.post("/generate/stream")
async def generate_stream(request: GenerateRequest):
    """Stream a completion for ``request.prompt`` as server-sent events."""
    logger.info(f"Received streaming generation request for prompt: {request.prompt[:50]}...")
    try:
        token_stream = api.generate_stream(
            prompt=request.prompt,
            system_message=request.system_message,
            max_new_tokens=request.max_new_tokens,
        )
        # NOTE(review): this except only covers constructing the response;
        # errors raised while the stream is consumed occur after we return
        # and are not converted to HTTP 500 here — confirm this is intended.
        return StreamingResponse(token_stream, media_type="text/event-stream")
    except Exception as exc:
        logger.error(f"Error in generate_stream endpoint: {str(exc)}")
        raise HTTPException(status_code=500, detail=str(exc))
60
 
61
  @router.post("/v1/chat/completions")
62
  async def create_chat_completion(request: ChatCompletionRequest):
 
113
  logger.error(f"Error in chat completion endpoint: {str(e)}")
114
  raise HTTPException(status_code=500, detail=str(e))
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  @router.post("/embedding", response_model=EmbeddingResponse)
117
  async def generate_embedding(request: EmbeddingRequest):
118
  """Generate embedding vector from text"""
 
178
async def shutdown_event():
    """Release the Inference API's resources on application shutdown.

    No-op when ``init_router`` was never called (``api`` is still None).
    """
    if api:
        await api.cleanup()