import asyncio
import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

import httpx
from litserve import LitAPI
from pydantic import BaseModel


class GenerationResponse(BaseModel):
    generated_text: str


class InferenceApi(LitAPI):
    def __init__(self, config: Dict[str, Any]):
        """Initialize the Inference API with configuration."""
        super().__init__()
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing Inference API")
        self._device = None
        self.stream = False
        self.config = config
        self.llm_config = config.get('llm_server', {})

    def setup(self, device: Optional[str] = None):
        """Synchronous setup method required by LitAPI."""
        self._device = device
        self.logger.info(f"Inference API setup completed on device: {device}")
        return self  # Returning self allows call chaining

    async def _get_client(self):
        """Create an HTTP client for a single request."""
        return httpx.AsyncClient(
            base_url=self.llm_config.get('base_url', 'http://localhost:8002'),
            timeout=float(self.llm_config.get('timeout', 60.0))
        )

    def _get_endpoint(self, endpoint_name: str) -> str:
        """Get the full endpoint path, including the API prefix."""
        endpoints = self.llm_config.get('endpoints', {})
        api_prefix = self.llm_config.get('api_prefix', '')
        endpoint = endpoints.get(endpoint_name, '')
        return f"{api_prefix}{endpoint}"

    def predict(self, x: str, **kwargs) -> Iterator[str]:
        """Non-async prediction method that yields results."""
        # Drive the internal async generator from this synchronous method by
        # stepping it on the event loop one item at a time.
        loop = asyncio.get_event_loop()

        async def async_gen():
            async for item in self._async_predict(x, **kwargs):
                yield item

        gen = async_gen()
        while True:
            try:
                yield loop.run_until_complete(gen.__anext__())
            except StopAsyncIteration:
                break

    async def _async_predict(self, x: str, **kwargs) -> AsyncIterator[str]:
        """Internal async prediction method."""
        if self.stream:
            async for chunk in self.generate_stream(x, **kwargs):
                yield chunk
        else:
            response = await self.generate_response(x, **kwargs)
            yield response

    async def generate_embedding(self, text: str) -> List[float]:
        """Generate an embedding vector from input text."""
        self.logger.debug(f"Forwarding embedding request for text: {text[:50]}...")
        try:
            async with await self._get_client() as client:
                response = await client.post(
                    self._get_endpoint('embedding'),
                    json={"text": text}
                )
                response.raise_for_status()
                data = response.json()
                return data["embedding"]
        except Exception as e:
            self.logger.error(f"Error in generate_embedding: {str(e)}")
            raise

    async def check_system_status(self) -> Dict[str, Any]:
        """Check the system status of the LLM Server."""
        self.logger.debug("Checking system status...")
        try:
            async with await self._get_client() as client:
                response = await client.get(
                    self._get_endpoint('system_status')
                )
                response.raise_for_status()
                return response.json()
        except Exception as e:
            self.logger.error(f"Error in check_system_status: {str(e)}")
            raise

    async def download_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """Download model files from the LLM Server."""
        self.logger.debug(f"Forwarding model download request for: {model_name or 'default model'}")
        try:
            async with await self._get_client() as client:
                response = await client.post(
                    self._get_endpoint('model_download'),
                    params={"model_name": model_name} if model_name else None
                )
                response.raise_for_status()
                return response.json()
        except Exception as e:
            self.logger.error(f"Error in download_model: {str(e)}")
            raise

    async def validate_system(self) -> Dict[str, Any]:
        """Validate system configuration and setup."""
        self.logger.debug("Validating system configuration...")
        try:
            async with await self._get_client() as client:
                response = await client.get(
                    self._get_endpoint('system_validate')
                )
                response.raise_for_status()
                return response.json()
        except Exception as e:
            self.logger.error(f"Error in validate_system: {str(e)}")
            raise

    async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize the specified model or the default model."""
        self.logger.debug(f"Initializing model: {model_name or 'default'}")
        try:
            async with await self._get_client() as client:
                response = await client.post(
                    self._get_endpoint('model_initialize'),
                    params={"model_name": model_name} if model_name else None
                )
                response.raise_for_status()
                return response.json()
        except Exception as e:
            self.logger.error(f"Error in initialize_model: {str(e)}")
            raise

    async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize the embedding model."""
        self.logger.debug(f"Initializing embedding model: {model_name or 'default'}")
        try:
            async with await self._get_client() as client:
                response = await client.post(
                    self._get_endpoint('model_initialize_embedding'),
                    json={"model_name": model_name} if model_name else {}
                )
                response.raise_for_status()
                return response.json()
        except Exception as e:
            self.logger.error(f"Error in initialize_embedding_model: {str(e)}")
            raise

    def decode_request(self, request: Any, **kwargs) -> str:
        """Convert the request payload to the model input format."""
        if isinstance(request, dict) and "prompt" in request:
            return request["prompt"]
        return request

    def encode_response(self, output: Iterator[str], **kwargs) -> Dict[str, Any]:
        """Convert the model output to a response payload."""
        if self.stream:
            return {"generated_text": output}
        try:
            result = next(output)
            return {"generated_text": result}
        except StopIteration:
            return {"generated_text": ""}

    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> str:
        """Generate a complete response by forwarding the request to the LLM Server."""
        self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")
        try:
            async with await self._get_client() as client:
                response = await client.post(
                    self._get_endpoint('generate'),
                    json={
                        "prompt": prompt,
                        "system_message": system_message,
                        "max_new_tokens": max_new_tokens
                    }
                )
                response.raise_for_status()
                data = response.json()
                return data["generated_text"]
        except Exception as e:
            self.logger.error(f"Error in generate_response: {str(e)}")
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """Generate a streaming response by forwarding the request to the LLM Server."""
        self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")
        try:
            # Use the client as a context manager so it is closed even if streaming fails.
            async with await self._get_client() as client:
                async with client.stream(
                    "POST",
                    self._get_endpoint('generate_stream'),
                    json={
                        "prompt": prompt,
                        "system_message": system_message,
                        "max_new_tokens": max_new_tokens
                    }
                ) as response:
                    response.raise_for_status()
                    async for chunk in response.aiter_text():
                        yield chunk
        except Exception as e:
            self.logger.error(f"Error in generate_stream: {str(e)}")
            raise

    async def cleanup(self):
        """Cleanup method - no longer needed since clients are created per request."""
        pass