AurelioAguirre committed
Commit 02fd6bb · 1 Parent(s): da1009f

added OpenAI schema-based endpoint and response

Files changed (6)
  1. main/api.py +19 -21
  2. main/config.yaml +13 -6
  3. main/main.py +16 -12
  4. main/routes.py +62 -3
  5. main/schemas.py +54 -2
  6. requirements.txt +1 -0
main/api.py CHANGED
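Note: the key addition in this file is _get_endpoint(), which builds request paths by concatenating the configured api_prefix with the named entry under llm_server.endpoints. A minimal standalone sketch of that lookup, using the values config.yaml gets in this commit (the resolve() helper below is illustrative only, not code from the diff):

    # Illustration only: mirrors the lookup done by InferenceApi._get_endpoint()
    llm_config = {
        "api_prefix": "/api/v1",
        "endpoints": {
            "generate": "/generate",
            "generate_stream": "/generate/stream",
        },
    }

    def resolve(endpoint_name: str) -> str:
        endpoints = llm_config.get("endpoints", {})
        api_prefix = llm_config.get("api_prefix", "")
        return f"{api_prefix}{endpoints.get(endpoint_name, '')}"

    print(resolve("generate"))         # -> /api/v1/generate
    print(resolve("generate_stream"))  # -> /api/v1/generate/stream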
@@ -9,13 +9,15 @@ class GenerationResponse(BaseModel):
     generated_text: str
 
 class InferenceApi(LitAPI):
-    def __init__(self):
+    def __init__(self, config: Dict[str, Any]):
         """Initialize the Inference API with configuration."""
         super().__init__()
         self.logger = logging.getLogger(__name__)
         self.logger.info("Initializing Inference API")
         self._device = None
-        self.stream = False  # Add stream flag for compatibility with LitAPI
+        self.stream = False
+        self.config = config
+        self.llm_config = config.get('llm_server', {})
 
     async def setup(self, device: Optional[str] = None):
         """Setup method required by LitAPI"""
@@ -25,15 +27,19 @@ class InferenceApi(LitAPI):
     async def _get_client(self):
         """Get or create HTTP client as needed"""
         return httpx.AsyncClient(
-            base_url="http://localhost:8002",
-            timeout=60.0
+            base_url=self.llm_config.get('base_url', 'http://localhost:8002'),
+            timeout=float(self.llm_config.get('timeout', 60.0))
         )
 
+    def _get_endpoint(self, endpoint_name: str) -> str:
+        """Get full endpoint path including prefix"""
+        endpoints = self.llm_config.get('endpoints', {})
+        api_prefix = self.llm_config.get('api_prefix', '')
+        endpoint = endpoints.get(endpoint_name, '')
+        return f"{api_prefix}{endpoint}"
+
     def predict(self, x: str, **kwargs) -> Iterator[str]:
-        """
-        Non-async prediction method that yields results.
-        Implements required LitAPI method.
-        """
+        """Non-async prediction method that yields results."""
         loop = asyncio.get_event_loop()
         async def async_gen():
             async for item in self._async_predict(x, **kwargs):
@@ -47,9 +53,7 @@ class InferenceApi(LitAPI):
                 break
 
     async def _async_predict(self, x: str, **kwargs) -> AsyncIterator[str]:
-        """
-        Internal async prediction method.
-        """
+        """Internal async prediction method."""
        if self.stream:
            async for chunk in self.generate_stream(x, **kwargs):
                yield chunk
@@ -58,19 +62,13 @@ class InferenceApi(LitAPI):
             yield response
 
     def decode_request(self, request: Any, **kwargs) -> str:
-        """
-        Convert the request payload to input format.
-        Implements required LitAPI method.
-        """
+        """Convert the request payload to input format."""
         if isinstance(request, dict) and "prompt" in request:
             return request["prompt"]
         return request
 
     def encode_response(self, output: Iterator[str], **kwargs) -> Dict[str, Any]:
-        """
-        Convert the model output to a response payload.
-        Implements required LitAPI method.
-        """
+        """Convert the model output to a response payload."""
         if self.stream:
             return {"generated_text": output}
         try:
@@ -91,7 +89,7 @@ class InferenceApi(LitAPI):
         try:
             async with await self._get_client() as client:
                 response = await client.post(
-                    "/api/v1/generate",
+                    self._get_endpoint('generate'),
                     json={
                         "prompt": prompt,
                         "system_message": system_message,
@@ -119,7 +117,7 @@ class InferenceApi(LitAPI):
         client = await self._get_client()
         async with client.stream(
             "POST",
-            "/api/v1/generate/stream",
+            self._get_endpoint('generate_stream'),
             json={
                 "prompt": prompt,
                 "system_message": system_message,
main/config.yaml CHANGED
@@ -1,11 +1,18 @@
 server:
+  host: "localhost"
   port: 8001
   timeout: 60
+  max_batch_size: 1
 
 llm_server:
-  base_url: "https://teamgenki-llmserver.hf.space:7680" # URL of your LLM Server
-  timeout: 60 # Timeout for requests to LLM Server
-
-logging:
-  level: "INFO"
-  format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+  base_url: "https://teamgenki-llmserver.hf.space:7860"
+  timeout: 60.0
+  api_prefix: "/api/v1" # This will be used for route prefixing
+  endpoints:
+    generate: "/generate"
+    generate_stream: "/generate/stream"
+    embedding: "/embedding"
+    system_status: "/system/status"
+    system_validate: "/system/validate"
+    model_initialize: "/model/initialize"
+    model_initialize_embedding: "/model/initialize/embedding"
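Note: load_config() itself is outside this diff. A minimal sketch of a loader compatible with the new layout, assuming PyYAML is available (the path and function name here are assumptions, not taken from the repo):

    # Sketch of a config loader for the layout above; the real load_config()
    # is not shown in this commit, so treat names and paths as placeholders.
    import yaml

    def load_config(path: str = "main/config.yaml") -> dict:
        with open(path) as f:
            return yaml.safe_load(f)

    config = load_config()
    llm = config.get("llm_server", {})
    print(llm["base_url"])    # https://teamgenki-llmserver.hf.space:7860
    print(llm["api_prefix"])  # /api/v1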
main/main.py CHANGED
@@ -31,17 +31,18 @@ async def async_main():
     try:
         # Load configuration
         config = load_config()
+        server_config = config.get('server', {})
 
-        # Initialize API and router with await
-        api = InferenceApi()
-        await api.setup()  # Properly await the setup
-        await init_router(config)  # Modified to be async
+        # Initialize API with config and await setup
+        api = InferenceApi(config)
+        await api.setup()
+        await init_router(config)
 
-        # Create LitServer instance
+        # Create LitServer instance with config
        server = ls.LitServer(
            api,
-            timeout=config.get("server", {}).get("timeout", 60),
-            max_batch_size=1,
+            timeout=server_config.get('timeout', 60),
+            max_batch_size=server_config.get('max_batch_size', 1),
            track_requests=True
        )
 
@@ -54,13 +55,16 @@ async def async_main():
            allow_headers=["*"],
        )
 
-        # Add our routes to the server's FastAPI app
-        server.app.include_router(router, prefix="/api/v1")
+        # Add routes with configured prefix
+        api_prefix = config.get('llm_server', {}).get('api_prefix', '/api/v1')
+        server.app.include_router(router, prefix=api_prefix)
 
-        port = config.get("server", {}).get("port", 8001)
+        # Get configured port
+        port = server_config.get('port', 8001)
+        host = server_config.get('host', 'localhost')
 
-        # Get the server instance from our app and run it
-        server.run(port=port)
+        # Run server
+        server.run(host=host, port=port)
 
     except Exception as e:
         logger.error(f"Server initialization failed: {str(e)}")
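Note: the __main__ entry point for async_main() is not part of this hunk; it would typically be driven roughly as sketched below (an assumption about this repo, and note that server.run() is itself a blocking call):

    # Hypothetical entry point for async_main(); not taken from the diff.
    import asyncio

    if __name__ == "__main__":
        asyncio.run(async_main())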
main/routes.py CHANGED
@@ -6,7 +6,9 @@ from .schemas import (
     EmbeddingRequest,
     EmbeddingResponse,
     SystemStatusResponse,
-    ValidationResponse
+    ValidationResponse,
+    ChatCompletionRequest,
+    ChatCompletionResponse
 )
 import logging
 
@@ -14,11 +16,68 @@ router = APIRouter()
 logger = logging.getLogger(__name__)
 api = None
 
+
+@router.post("/v1/chat/completions")
+async def create_chat_completion(request: ChatCompletionRequest):
+    """OpenAI-compatible chat completion endpoint"""
+    logger.info(f"Received chat completion request with {len(request.messages)} messages")
+
+    try:
+        # Extract the last user message, or combine messages if needed
+        last_message = request.messages[-1].content
+
+        if request.stream:
+            # For streaming, we need to create a generator that yields OpenAI-compatible chunks
+            async def generate_stream():
+                async for chunk in api.generate_stream(
+                    prompt=last_message,
+                ):
+                    # Create a streaming response chunk in OpenAI format
+                    response_chunk = {
+                        "id": "chatcmpl-123",
+                        "object": "chat.completion.chunk",
+                        "created": int(time()),
+                        "model": request.model,
+                        "choices": [{
+                            "index": 0,
+                            "delta": {
+                                "content": chunk
+                            },
+                            "finish_reason": None
+                        }]
+                    }
+                    yield f"data: {json.dumps(response_chunk)}\n\n"
+
+                # Send the final chunk
+                yield f"data: [DONE]\n\n"
+
+            return StreamingResponse(
+                generate_stream(),
+                media_type="text/event-stream"
+            )
+
+        else:
+            # For non-streaming, generate the full response
+            response_text = await api.generate_response(
+                prompt=last_message,
+            )
+
+            # Convert to OpenAI format
+            return ChatCompletionResponse.from_response(
+                content=response_text,
+                model=request.model
+            )
+
+    except Exception as e:
+        logger.error(f"Error in chat completion endpoint: {str(e)}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
 async def init_router(config: dict):
     """Initialize router with config and Inference API instance"""
     global api
-    api = InferenceApi()
-    await api.setup()  # Properly await the setup
+    api = InferenceApi(config)
+    await api.setup()
     logger.info("Router initialized with Inference API instance")
 
 @router.post("/generate")
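Note: the new handler references json.dumps, time(), StreamingResponse and HTTPException, so those need to be in scope in routes.py; their imports are not visible in this hunk. Also, because main.py mounts the router under the configured api_prefix ("/api/v1" by default) and the route declares its own "/v1/chat/completions" path, the full URL as the diff stands works out to /api/v1/v1/chat/completions. A sketch of a client request using httpx, with host and port taken from the server section of config.yaml (adjust the path if the prefix is configured differently):

    # Example client call against the OpenAI-compatible endpoint.
    import httpx

    payload = {
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,
    }
    resp = httpx.post(
        "http://localhost:8001/api/v1/v1/chat/completions",
        json=payload,
        timeout=60.0,
    )
    print(resp.json()["choices"][0]["message"]["content"])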
main/schemas.py CHANGED
@@ -1,5 +1,57 @@
-from pydantic import BaseModel
-from typing import Optional, List, Dict, Union
+from pydantic import BaseModel, Field
+from typing import List, Optional, Dict, Union, Literal
+from time import time
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[ChatMessage]
+    stream: bool = False
+
+class ChatCompletionMessage(BaseModel):
+    role: str = "assistant"
+    content: str
+
+class ChatCompletionChoice(BaseModel):
+    index: int = 0
+    message: ChatCompletionMessage
+    logprobs: Optional[None] = None
+    finish_reason: str = "stop"
+
+class CompletionTokenDetails(BaseModel):
+    reasoning_tokens: int = 0
+    accepted_prediction_tokens: int = 0
+    rejected_prediction_tokens: int = 0
+
+class CompletionUsage(BaseModel):
+    prompt_tokens: int = 9  # Placeholder values
+    completion_tokens: int = 12
+    total_tokens: int = 21
+    completion_tokens_details: CompletionTokenDetails = Field(default_factory=CompletionTokenDetails)
+
+class ChatCompletionResponse(BaseModel):
+    id: str = Field(default="chatcmpl-123")
+    object: str = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time()))
+    model: str = "gpt-4o-mini"
+    system_fingerprint: str = "fp_44709d6fcb"
+    choices: List[ChatCompletionChoice]
+    usage: CompletionUsage = Field(default_factory=CompletionUsage)
+
+    @classmethod
+    def from_response(cls, content: str, model: str = "gpt-4o-mini") -> "ChatCompletionResponse":
+        """Create a ChatCompletionResponse from a simple response string"""
+        return cls(
+            model=model,
+            choices=[
+                ChatCompletionChoice(
+                    message=ChatCompletionMessage(content=content)
+                )
+            ]
+        )
 
 class GenerateRequest(BaseModel):
     prompt: str
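Note: from_response() wraps a plain string in the OpenAI chat.completion shape, with the id, usage, and fingerprint values falling back to the placeholder defaults declared above. A quick sketch of what it produces (Pydantic v2 model_dump shown):

    # Illustration of the response schema; everything beyond "content" comes
    # from the defaults declared in ChatCompletionResponse.
    resp = ChatCompletionResponse.from_response("Hello there!", model="gpt-4o-mini")
    print(resp.model_dump()["choices"][0]["message"])
    # {'role': 'assistant', 'content': 'Hello there!'}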
requirements.txt CHANGED
@@ -9,6 +9,7 @@ httptools==0.6.4
 httpx==0.28.1
 idna==3.10
 litserve==0.2.5
+numpy==2.2.1
 pydantic==2.10.4
 pydantic_core==2.27.2
 python-dotenv==1.0.1