AurelioAguirre committed
Commit db5664e · 1 Parent(s): 19d483b

fixing pydantic v4

Files changed (2)
  1. main/api.py +19 -36
  2. main/main.py +1 -1
main/api.py CHANGED

@@ -1,36 +1,33 @@
 import httpx
-from typing import Optional, Iterator, Union, Any
+from typing import Optional, AsyncIterator, Dict, Any
 import logging
 from litserve import LitAPI
+from pydantic import BaseModel
+
+class GenerationResponse(BaseModel):
+    generated_text: str
 
 class InferenceApi(LitAPI):
-    def __init__(self, config: dict):
+    def __init__(self):
         """Initialize the Inference API with configuration."""
         super().__init__()
         self.logger = logging.getLogger(__name__)
         self.logger.info("Initializing Inference API")
-
-        # Get base URL from config
-        self.base_url = config["llm_server"]["base_url"]
-        self.timeout = config["llm_server"].get("timeout", 60)
-        self.client = None  # Will be initialized in setup()
-
-        # Set request timeout from config
-        self.request_timeout = float(self.timeout)
+        self.client = None
 
     async def setup(self, device: Optional[str] = None):
         """Setup method required by LitAPI - initialize HTTP client"""
-        self._device = device  # Store device as required by LitAPI
+        self._device = device
         self.client = httpx.AsyncClient(
-            base_url=self.base_url,
-            timeout=self.timeout
+            base_url="http://localhost:8002",  # We'll need to make this configurable
+            timeout=60.0
         )
         self.logger.info(f"Inference API setup completed on device: {device}")
 
-    async def predict(self, x: str, **kwargs) -> Union[str, Iterator[str]]:
+    async def predict(self, x: str, **kwargs) -> AsyncIterator[str]:
         """
         Main prediction method required by LitAPI.
-        If streaming is enabled, yields chunks; otherwise returns via yield.
+        Always yields, either chunks in streaming mode or complete response in non-streaming mode.
         """
         if self.stream:
             async for chunk in self.generate_stream(x, **kwargs):
@@ -41,21 +38,16 @@ class InferenceApi(LitAPI):
 
     def decode_request(self, request: Any, **kwargs) -> str:
         """Convert the request payload to input format."""
-        # For our case, we expect the request to be text
         if isinstance(request, dict) and "prompt" in request:
             return request["prompt"]
         return request
 
-    def encode_response(self, output: Union[str, Iterator[str]], **kwargs) -> Union[str, Iterator[str]]:
+    def encode_response(self, output: AsyncIterator[str], **kwargs) -> AsyncIterator[Dict[str, str]]:
         """Convert the model output to a response payload."""
-        if self.stream:
-            # For streaming, yield each chunk wrapped in a dict
-            async def stream_wrapper():
-                async for chunk in output:
-                    yield {"generated_text": chunk}
-        else:
-            # For non-streaming, return complete response
-            return {"generated_text": output}
+        async def wrapper():
+            async for chunk in output:
+                yield {"generated_text": chunk}
+        return wrapper()
 
     async def generate_response(
         self,
@@ -88,7 +80,7 @@ class InferenceApi(LitAPI):
         prompt: str,
         system_message: Optional[str] = None,
         max_new_tokens: Optional[int] = None
-    ) -> Iterator[str]:
+    ) -> AsyncIterator[str]:
         """Generate a streaming response by forwarding the request to the LLM Server."""
         self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")
 
@@ -110,16 +102,7 @@ class InferenceApi(LitAPI):
             self.logger.error(f"Error in generate_stream: {str(e)}")
             raise
 
-    # ... [rest of the methods remain the same: generate_embedding, check_system_status, etc.]
-
    async def cleanup(self):
         """Cleanup method - close HTTP client"""
         if self.client:
-            await self.client.aclose()
-
-    def log(self, key: str, value: Any):
-        """Override log method to use our logger if queue not set"""
-        if self._logger_queue is None:
-            self.logger.info(f"Log event: {key}={value}")
-        else:
-            super().log(key, value)
+            await self.client.aclose()
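Note that the base URL and timeout that previously came from the config dict are hardcoded in setup() for now, as the inline comment in the diff acknowledges. A minimal sketch of one way to make them configurable again without reintroducing a constructor argument, assuming environment variables; the names LLM_SERVER_URL and LLM_SERVER_TIMEOUT are hypothetical and not part of this commit:

import os
import httpx

# Hypothetical follow-up, not part of this commit: read the LLM server
# address and timeout from environment variables, falling back to the
# values hardcoded in setup() above.
def build_llm_client() -> httpx.AsyncClient:
    base_url = os.environ.get("LLM_SERVER_URL", "http://localhost:8002")
    timeout = float(os.environ.get("LLM_SERVER_TIMEOUT", "60"))
    return httpx.AsyncClient(base_url=base_url, timeout=timeout)

setup() could then call build_llm_client() instead of constructing the client inline, keeping the no-argument __init__() that main.py now relies on.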
main/main.py CHANGED

@@ -32,7 +32,7 @@ def create_app():
     config = load_config()
 
     # Initialize API and router
-    api = InferenceApi(config)
+    api = InferenceApi()
     init_router(config)
 
     # Create LitServer instance
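For context, only the InferenceApi call loses its argument here; the rest of create_app() is untouched by this commit. A rough sketch of the surrounding wiring under stated assumptions: the import paths for load_config and init_router and the LitServer options are guesses, not shown in this diff.

import litserve as ls

from main.api import InferenceApi      # this repo
from main.routes import init_router    # assumed module path, not shown in the diff
from main.config import load_config    # assumed module path, not shown in the diff

def create_app():
    config = load_config()

    # Initialize API and router
    api = InferenceApi()  # no config argument after this commit
    init_router(config)

    # Create LitServer instance; stream=True is an assumption that matches
    # the async-generator predict()/encode_response() in api.py
    server = ls.LitServer(api, stream=True)
    return server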