import logging
from typing import Any, AsyncIterator, Iterator, Optional, Union

import httpx
from litserve import LitAPI


class InferenceApi(LitAPI):
    """LitServe API that forwards generation requests to a remote LLM server.

    Requests are proxied over HTTP with ``httpx.AsyncClient`` to the server at
    ``config["llm_server"]["base_url"]``. The one-shot endpoint
    ``/api/v1/generate`` or the streaming endpoint ``/api/v1/generate/stream``
    is used, depending on ``self.stream``.
    """

    def __init__(self, config: dict):
        """Initialize the Inference API with configuration.

        Args:
            config: Mapping with an ``"llm_server"`` section containing
                ``"base_url"`` (required) and ``"timeout"`` (optional,
                defaults to 60; passed to httpx, which uses seconds).

        Raises:
            KeyError: If ``"llm_server"`` or its ``"base_url"`` is missing.
            TypeError: If ``"timeout"`` is not float-convertible (e.g. None).
            ValueError: If ``"timeout"`` is a non-numeric string.
        """
        super().__init__()
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing Inference API")

        llm_config = config["llm_server"]
        self.base_url = llm_config["base_url"]
        self.timeout = llm_config.get("timeout", 60)
        # Normalized to float so string values from config files still work;
        # this is the value actually handed to the HTTP client.
        self.request_timeout = float(self.timeout)
        # Created in setup(), or lazily on first request via _get_client().
        self.client: Optional[httpx.AsyncClient] = None

    def _make_client(self) -> httpx.AsyncClient:
        """Build a new HTTP client bound to the configured server and timeout."""
        return httpx.AsyncClient(base_url=self.base_url, timeout=self.request_timeout)

    def _get_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating it on first use if needed."""
        if self.client is None:
            # setup() is declared async. If the serving framework calls it
            # without awaiting, the client would never be created; creating
            # it here keeps requests working either way. There is no await
            # between the check and the assignment, so this is atomic with
            # respect to the event loop.
            self.client = self._make_client()
        return self.client

    async def setup(self, device: Optional[str] = None):
        """Setup hook required by LitAPI: record the device, create the HTTP client.

        Args:
            device: Device identifier supplied by the caller. It is only
                stored; this class never uses it for computation.
        """
        self._device = device  # Store device as required by LitAPI
        self.client = self._make_client()
        self.logger.info("Inference API setup completed on device: %s", device)

    async def predict(self, x: str, **kwargs) -> Union[str, AsyncIterator[str]]:
        """Main prediction method required by LitAPI.

        Args:
            x: Prompt text, as produced by ``decode_request``.
            **kwargs: Forwarded to ``generate_stream`` / ``generate_response``
                (``system_message``, ``max_new_tokens``).

        Returns:
            If ``self.stream`` is true, an async iterator of text chunks
            (consume it with ``async for``). Otherwise, the complete
            generated text.
        """
        # Mixing ``yield`` with ``return <value>`` is a SyntaxError in an async
        # generator. Streaming therefore returns the generator object instead
        # of yielding from it.
        # NOTE(review): some LitServe versions require ``predict`` itself to be
        # an (async) generator function when stream=True. Confirm against the
        # installed version; if so, move the streaming path into a dedicated
        # async-generator method.
        if self.stream:
            return self.generate_stream(x, **kwargs)
        return await self.generate_response(x, **kwargs)

    def decode_request(self, request: Any, **kwargs) -> str:
        """Convert the request payload to the prompt text.

        A dict containing a ``"prompt"`` key yields that value. Anything else
        is passed through unchanged (expected to already be the prompt text).
        """
        if isinstance(request, dict) and "prompt" in request:
            return request["prompt"]
        return request

    def encode_response(
        self, output: Union[str, AsyncIterator[str]], **kwargs
    ) -> Union[dict, AsyncIterator[dict]]:
        """Wrap model output as ``{"generated_text": ...}`` response payload(s).

        Returns:
            In streaming mode, an async iterator yielding one dict per chunk
            of ``output``. Otherwise, a single dict holding the full text.
        """
        if not self.stream:
            return {"generated_text": output}

        async def stream_wrapper():
            async for chunk in output:
                yield {"generated_text": chunk}

        # Must return the generator; otherwise streaming responses are None.
        return stream_wrapper()

    @staticmethod
    def _build_payload(
        prompt: str, system_message: Optional[str], max_new_tokens: Optional[int]
    ) -> dict:
        """Build the JSON request body shared by both generation endpoints."""
        return {
            "prompt": prompt,
            "system_message": system_message,
            "max_new_tokens": max_new_tokens,
        }

    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Generate a complete response by forwarding the request to the LLM Server.

        Returns:
            The ``"generated_text"`` field of the server's JSON reply.

        Raises:
            httpx.HTTPStatusError: If the server replies with a 4xx/5xx status.
            httpx.HTTPError: On transport errors or timeouts.
            KeyError: If the reply JSON lacks ``"generated_text"``.
        """
        self.logger.debug("Forwarding generation request for prompt: %s...", prompt[:50])
        client = self._get_client()
        try:
            response = await client.post(
                "/api/v1/generate",
                json=self._build_payload(prompt, system_message, max_new_tokens),
            )
            response.raise_for_status()
            return response.json()["generated_text"]
        except Exception as e:
            # Log with traceback, then propagate to the caller unchanged.
            self.logger.exception("Error in generate_response: %s", e)
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterator[str]:
        """Generate a streaming response by forwarding the request to the LLM Server.

        Yields:
            Decoded text chunks exactly as ``httpx`` delivers them. Chunk
            boundaries are transport-defined, not token-aligned.

        Raises:
            httpx.HTTPStatusError: If the server replies with a 4xx/5xx status.
            httpx.HTTPError: On transport errors or timeouts.
        """
        self.logger.debug("Forwarding streaming request for prompt: %s...", prompt[:50])
        client = self._get_client()
        try:
            async with client.stream(
                "POST",
                "/api/v1/generate/stream",
                json=self._build_payload(prompt, system_message, max_new_tokens),
            ) as response:
                response.raise_for_status()
                async for chunk in response.aiter_text():
                    yield chunk
        except Exception as e:
            # GeneratorExit / CancelledError are BaseException subclasses, so an
            # early close by the consumer is not logged as an error here.
            self.logger.exception("Error in generate_stream: %s", e)
            raise

    # ... [rest of the methods remain the same: generate_embedding, check_system_status, etc.]

    async def cleanup(self):
        """Close the HTTP client, if one was created, and forget it."""
        client, self.client = self.client, None
        if client is not None:
            await client.aclose()

    def log(self, key: str, value: Any):
        """Log via the base-class queue when one is set, else via ``self.logger``."""
        # getattr guards against the base class not having set the attribute.
        if getattr(self, "_logger_queue", None) is None:
            self.logger.info("Log event: %s=%s", key, value)
        else:
            super().log(key, value)