# api.py file in main directory of the Inference API module.
import logging
from typing import Any, AsyncIterator, Dict, Optional

import httpx
from litserve import LitAPI
from pydantic import BaseModel

# Default address of the upstream LLM server; override with InferenceApi(base_url=...).
DEFAULT_LLM_SERVER_URL = "http://localhost:8002"
# Default per-request timeout, in seconds, for calls to the LLM server.
DEFAULT_LLM_TIMEOUT = 60.0


class GenerationResponse(BaseModel):
    """Response schema carrying the generated text."""

    generated_text: str


class InferenceApi(LitAPI):
    """LitServe API that forwards generation requests to an upstream LLM server.

    Non-streaming requests are POSTed to ``/api/v1/generate`` and streaming
    requests to ``/api/v1/generate/stream``, both relative to ``base_url``.
    """

    def __init__(
        self,
        *,
        base_url: str = DEFAULT_LLM_SERVER_URL,
        timeout: float = DEFAULT_LLM_TIMEOUT,
    ):
        """Initialize the Inference API with configuration.

        Args:
            base_url: Base URL of the LLM server that requests are forwarded to.
            timeout: Timeout in seconds passed to the HTTP client.
        """
        super().__init__()
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing Inference API")
        # Private names to avoid clashing with attributes LitAPI may define.
        self._llm_base_url = base_url
        self._llm_timeout = timeout
        self.client: Optional[httpx.AsyncClient] = None

    def _create_client(self) -> httpx.AsyncClient:
        """Build an HTTP client bound to the configured LLM server."""
        return httpx.AsyncClient(base_url=self._llm_base_url, timeout=self._llm_timeout)

    def _get_client(self) -> httpx.AsyncClient:
        """Return the HTTP client, creating it if setup() has not run or cleanup() closed it."""
        if self.client is None:
            self.client = self._create_client()
        return self.client

    async def setup(self, device: Optional[str] = None):
        """Setup method required by LitAPI - initialize the HTTP client.

        Args:
            device: Device assigned by LitServe; it is only recorded and logged here.
        """
        self._device = device
        if self.client is not None:
            # Re-running setup must not leak the previous client's connections.
            await self.client.aclose()
        self.client = self._create_client()
        self.logger.info("Inference API setup completed on device: %s", device)

    async def predict(self, x: str, **kwargs) -> AsyncIterator[str]:
        """Main prediction method required by LitAPI.

        Always yields: individual chunks in streaming mode, or a single complete
        response in non-streaming mode. ``kwargs`` are forwarded unchanged to
        ``generate_stream`` / ``generate_response``.
        """
        if self.stream:
            async for chunk in self.generate_stream(x, **kwargs):
                yield chunk
        else:
            response = await self.generate_response(x, **kwargs)
            yield response

    def decode_request(self, request: Any, **kwargs) -> str:
        """Convert the request payload to input format.

        Returns ``request["prompt"]`` for a dict containing a ``"prompt"`` key;
        any other payload is passed through unchanged.
        """
        if isinstance(request, dict) and "prompt" in request:
            return request["prompt"]
        return request

    def encode_response(
        self, output: AsyncIterator[str], **kwargs
    ) -> AsyncIterator[Dict[str, str]]:
        """Convert the model output to a response payload.

        Wraps every chunk produced by ``output`` as ``{"generated_text": chunk}``.
        """

        async def wrapper():
            async for chunk in output:
                yield {"generated_text": chunk}

        return wrapper()

    @staticmethod
    def _build_payload(
        prompt: str, system_message: Optional[str], max_new_tokens: Optional[int]
    ) -> Dict[str, Any]:
        """Build the JSON body shared by the generate and stream endpoints."""
        # None values are sent as JSON null, matching the original request shape.
        return {
            "prompt": prompt,
            "system_message": system_message,
            "max_new_tokens": max_new_tokens,
        }

    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Generate a complete response by forwarding the request to the LLM Server.

        Returns:
            The ``generated_text`` field of the server's JSON reply.

        Raises:
            httpx.HTTPStatusError: If the server replies with an error status.
            Any other exception from the HTTP call or response parsing is
            logged and re-raised.
        """
        # Lazy %-formatting: only rendered when DEBUG is enabled, and "%.50s"
        # truncates like prompt[:50] without failing on non-str prompts.
        self.logger.debug("Forwarding generation request for prompt: %.50s...", prompt)
        try:
            response = await self._get_client().post(
                "/api/v1/generate",
                json=self._build_payload(prompt, system_message, max_new_tokens),
            )
            response.raise_for_status()
            data = response.json()
            return data["generated_text"]
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            self.logger.exception("Error in generate_response: %s", e)
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterator[str]:
        """Generate a streaming response by forwarding the request to the LLM Server.

        Yields:
            Decoded text chunks as they arrive from the server.

        Raises:
            httpx.HTTPStatusError: If the server replies with an error status.
            Any other exception from the HTTP call is logged and re-raised.
        """
        self.logger.debug("Forwarding streaming request for prompt: %.50s...", prompt)
        try:
            async with self._get_client().stream(
                "POST",
                "/api/v1/generate/stream",
                json=self._build_payload(prompt, system_message, max_new_tokens),
            ) as response:
                response.raise_for_status()
                async for chunk in response.aiter_text():
                    yield chunk
        except Exception as e:
            self.logger.exception("Error in generate_stream: %s", e)
            raise

    async def cleanup(self):
        """Cleanup method - close the HTTP client.

        The attribute is reset to None first, so repeated calls are no-ops and
        later requests transparently create a fresh client.
        """
        client, self.client = self.client, None
        if client is not None:
            await client.aclose()