AurelioAguirre commited on
Commit
a4e24d4
·
1 Parent(s): 1eab622

Fixing dockerfile v3

Browse files
Files changed (2) hide show
  1. src/api.py +53 -88
  2. src/main.py +3 -2
src/api.py CHANGED
@@ -1,24 +1,60 @@
1
  import httpx
2
- from typing import Optional, Iterator, List, Dict, Union
3
  import logging
 
4
 
5
- class InferenceApi:
6
  def __init__(self, config: dict):
7
  """Initialize the Inference API with configuration."""
 
8
  self.logger = logging.getLogger(__name__)
9
  self.logger.info("Initializing Inference API")
10
 
11
  # Get base URL from config
12
  self.base_url = config["llm_server"]["base_url"]
13
  self.timeout = config["llm_server"].get("timeout", 60)
 
14
 
15
- # Initialize HTTP client
 
 
 
 
 
16
  self.client = httpx.AsyncClient(
17
  base_url=self.base_url,
18
  timeout=self.timeout
19
  )
 
20
 
21
- self.logger.info("Inference API initialized successfully")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  async def generate_response(
24
  self,
@@ -26,9 +62,7 @@ class InferenceApi:
26
  system_message: Optional[str] = None,
27
  max_new_tokens: Optional[int] = None
28
  ) -> str:
29
- """
30
- Generate a complete response by forwarding the request to the LLM Server.
31
- """
32
  self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")
33
 
34
  try:
@@ -54,9 +88,7 @@ class InferenceApi:
54
  system_message: Optional[str] = None,
55
  max_new_tokens: Optional[int] = None
56
  ) -> Iterator[str]:
57
- """
58
- Generate a streaming response by forwarding the request to the LLM Server.
59
- """
60
  self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")
61
 
62
  try:
@@ -77,83 +109,16 @@ class InferenceApi:
77
  self.logger.error(f"Error in generate_stream: {str(e)}")
78
  raise
79
 
80
- async def generate_embedding(self, text: str) -> List[float]:
81
- """
82
- Generate embedding by forwarding the request to the LLM Server.
83
- """
84
- self.logger.debug(f"Forwarding embedding request for text: {text[:50]}...")
85
 
86
- try:
87
- response = await self.client.post(
88
- "/api/v1/embedding",
89
- json={"text": text}
90
- )
91
- response.raise_for_status()
92
- data = response.json()
93
- return data["embedding"]
94
-
95
- except Exception as e:
96
- self.logger.error(f"Error in generate_embedding: {str(e)}")
97
- raise
98
-
99
- async def check_system_status(self) -> Dict[str, Union[Dict, str]]:
100
- """
101
- Get system status from the LLM Server.
102
- """
103
- try:
104
- response = await self.client.get("/api/v1/system/status")
105
- response.raise_for_status()
106
- return response.json()
107
-
108
- except Exception as e:
109
- self.logger.error(f"Error getting system status: {str(e)}")
110
- raise
111
-
112
- async def validate_system(self) -> Dict[str, Union[Dict, str, List[str]]]:
113
- """
114
- Get system validation status from the LLM Server.
115
- """
116
- try:
117
- response = await self.client.get("/api/v1/system/validate")
118
- response.raise_for_status()
119
- return response.json()
120
-
121
- except Exception as e:
122
- self.logger.error(f"Error validating system: {str(e)}")
123
- raise
124
-
125
- async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
126
- """
127
- Initialize a model on the LLM Server.
128
- """
129
- try:
130
- response = await self.client.post(
131
- "/api/v1/model/initialize",
132
- params={"model_name": model_name} if model_name else None
133
- )
134
- response.raise_for_status()
135
- return response.json()
136
-
137
- except Exception as e:
138
- self.logger.error(f"Error initializing model: {str(e)}")
139
- raise
140
-
141
- async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
142
- """
143
- Initialize an embedding model on the LLM Server.
144
- """
145
- try:
146
- response = await self.client.post(
147
- "/api/v1/model/initialize/embedding",
148
- params={"model_name": model_name} if model_name else None
149
- )
150
- response.raise_for_status()
151
- return response.json()
152
-
153
- except Exception as e:
154
- self.logger.error(f"Error initializing embedding model: {str(e)}")
155
- raise
156
 
157
- async def close(self):
158
- """Close the HTTP client session."""
159
- await self.client.aclose()
 
 
 
 
1
  import httpx
2
+ from typing import Optional, Iterator, Union, Any
3
  import logging
4
+ from litserve import LitAPI
5
 
6
+ class InferenceApi(LitAPI):
7
  def __init__(self, config: dict):
8
  """Initialize the Inference API with configuration."""
9
+ super().__init__()
10
  self.logger = logging.getLogger(__name__)
11
  self.logger.info("Initializing Inference API")
12
 
13
  # Get base URL from config
14
  self.base_url = config["llm_server"]["base_url"]
15
  self.timeout = config["llm_server"].get("timeout", 60)
16
+ self.client = None # Will be initialized in setup()
17
 
18
+ # Set request timeout from config
19
+ self.request_timeout = float(self.timeout)
20
+
21
+ async def setup(self, device: Optional[str] = None):
22
+ """Setup method required by LitAPI - initialize HTTP client"""
23
+ self._device = device # Store device as required by LitAPI
24
  self.client = httpx.AsyncClient(
25
  base_url=self.base_url,
26
  timeout=self.timeout
27
  )
28
+ self.logger.info(f"Inference API setup completed on device: {device}")
29
 
30
+ async def predict(self, x: str, **kwargs) -> Union[str, Iterator[str]]:
31
+ """
32
+ Main prediction method required by LitAPI.
33
+ If streaming is enabled, yields chunks; otherwise returns complete response.
34
+ """
35
+ if self.stream:
36
+ async for chunk in self.generate_stream(x, **kwargs):
37
+ yield chunk
38
+ else:
39
+ return await self.generate_response(x, **kwargs)
40
+
41
+ def decode_request(self, request: Any, **kwargs) -> str:
42
+ """Convert the request payload to input format."""
43
+ # For our case, we expect the request to be text
44
+ if isinstance(request, dict) and "prompt" in request:
45
+ return request["prompt"]
46
+ return request
47
+
48
+ def encode_response(self, output: Union[str, Iterator[str]], **kwargs) -> Union[str, Iterator[str]]:
49
+ """Convert the model output to a response payload."""
50
+ if self.stream:
51
+ # For streaming, yield each chunk wrapped in a dict
52
+ async def stream_wrapper():
53
+ async for chunk in output:
54
+ yield {"generated_text": chunk}
55
+ else:
56
+ # For non-streaming, return complete response
57
+ return {"generated_text": output}
58
 
59
  async def generate_response(
60
  self,
 
62
  system_message: Optional[str] = None,
63
  max_new_tokens: Optional[int] = None
64
  ) -> str:
65
+ """Generate a complete response by forwarding the request to the LLM Server."""
 
 
66
  self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")
67
 
68
  try:
 
88
  system_message: Optional[str] = None,
89
  max_new_tokens: Optional[int] = None
90
  ) -> Iterator[str]:
91
+ """Generate a streaming response by forwarding the request to the LLM Server."""
 
 
92
  self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")
93
 
94
  try:
 
109
  self.logger.error(f"Error in generate_stream: {str(e)}")
110
  raise
111
 
112
+ # ... [rest of the methods remain the same: generate_embedding, check_system_status, etc.]
 
 
 
 
113
 
114
+ async def cleanup(self):
115
+ """Cleanup method - close HTTP client"""
116
+ if self.client:
117
+ await self.client.aclose()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
+ def log(self, key: str, value: Any):
120
+ """Override log method to use our logger if queue not set"""
121
+ if self._logger_queue is None:
122
+ self.logger.info(f"Log event: {key}={value}")
123
+ else:
124
+ super().log(key, value)
src/main.py CHANGED
@@ -6,6 +6,7 @@ import yaml
6
  import logging
7
  from pathlib import Path
8
  from .routes import router, init_router
 
9
 
10
  def setup_logging():
11
  """Set up basic logging configuration"""
@@ -31,9 +32,9 @@ def main():
31
 
32
  # Initialize the router with our config
33
  init_router(config)
34
-
35
  # Create LitServer instance
36
- server = ls.LitServer(
37
  timeout=config.get("server", {}).get("timeout", 60),
38
  max_batch_size=1,
39
  track_requests=True
 
6
  import logging
7
  from pathlib import Path
8
  from .routes import router, init_router
9
+ from api import InferenceApi
10
 
11
  def setup_logging():
12
  """Set up basic logging configuration"""
 
32
 
33
  # Initialize the router with our config
34
  init_router(config)
35
+ api = InferenceApi()
36
  # Create LitServer instance
37
+ server = ls.LitServer(api,
38
  timeout=config.get("server", {}).get("timeout", 60),
39
  max_batch_size=1,
40
  track_requests=True