import asyncio
import logging
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Union

import httpx
from litserve import LitAPI
from pydantic import BaseModel


class GenerationResponse(BaseModel):
    """Payload schema for a completed generation."""

    generated_text: str


class InferenceApi(LitAPI):
    """LitServe API that forwards generation requests to an upstream LLM server.

    HTTP calls are made with an ``httpx.AsyncClient``. LitServe's ``predict``
    hook is synchronous here, so the async calls are driven on a private event
    loop owned by this instance and reused across requests.
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8002",
        timeout: float = 60.0,
    ):
        """Initialize the Inference API with configuration.

        Args:
            base_url: Root URL of the upstream LLM server.
            timeout: Timeout in seconds passed to ``httpx.AsyncClient``.
        """
        super().__init__()
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing Inference API")
        # Private names chosen so they do not collide with LitAPI attributes.
        self._upstream_url = base_url
        self._upstream_timeout = timeout
        self.client: Optional[httpx.AsyncClient] = None
        # Event loop used to run async HTTP calls from the sync predict().
        # Created lazily and reused so pooled connections stay on one loop.
        self._proxy_loop: Optional[asyncio.AbstractEventLoop] = None

    def _new_client(self) -> httpx.AsyncClient:
        """Build an HTTP client pointed at the configured upstream server."""
        return httpx.AsyncClient(
            base_url=self._upstream_url,
            timeout=self._upstream_timeout,
        )

    def _get_client(self) -> httpx.AsyncClient:
        """Return the HTTP client, creating it if setup() has not run."""
        if self.client is None:
            self.client = self._new_client()
        return self.client

    def _get_loop(self) -> asyncio.AbstractEventLoop:
        """Return this instance's private event loop, creating it if needed."""
        if self._proxy_loop is None or self._proxy_loop.is_closed():
            self._proxy_loop = asyncio.new_event_loop()
        return self._proxy_loop

    async def setup(self, device: Optional[str] = None):
        """Setup method required by LitAPI - initialize HTTP client.

        NOTE(review): if the serving framework calls ``setup`` without
        awaiting it, this coroutine never runs; the request methods therefore
        also create the client lazily via ``_get_client``.
        """
        self._device = device
        self.client = self._new_client()
        self.logger.info("Inference API setup completed on device: %s", device)

    def predict(self, x: str, **kwargs) -> Iterator[str]:
        """Non-async prediction method that yields results.

        Steps ``_async_predict`` one item at a time on the private event loop.
        Must not be called from a thread whose event loop is already running
        (``run_until_complete`` would raise ``RuntimeError``).
        """
        loop = self._get_loop()
        agen = self._async_predict(x, **kwargs)
        try:
            while True:
                try:
                    item = loop.run_until_complete(agen.__anext__())
                except StopAsyncIteration:
                    break
                yield item
        finally:
            # Runs on normal exhaustion, errors, and early close by the
            # consumer; releases any in-flight upstream stream.
            loop.run_until_complete(agen.aclose())

    async def _async_predict(self, x: str, **kwargs) -> AsyncIterator[str]:
        """Internal async prediction method.

        Yields text chunks from the streaming endpoint when ``self.stream`` is
        true; otherwise yields a single complete response.
        """
        if self.stream:
            chunks = self.generate_stream(x, **kwargs)
            try:
                async for chunk in chunks:
                    yield chunk
            finally:
                # ``async for`` does not close the inner generator when this
                # one is closed early; do it explicitly so the HTTP response
                # context exits promptly.
                await chunks.aclose()
        else:
            yield await self.generate_response(x, **kwargs)

    def decode_request(self, request: Any, **kwargs) -> str:
        """Convert the request payload to input format.

        Returns ``request["prompt"]`` for a dict carrying a ``"prompt"`` key;
        any other payload is passed through unchanged.
        """
        if isinstance(request, dict) and "prompt" in request:
            return request["prompt"]
        return request

    def encode_response(
        self, output: Iterator[str], **kwargs
    ) -> Union[Dict[str, Any], Iterator[Dict[str, Any]]]:
        """Convert the model output to a response payload.

        Streaming: returns an iterator yielding one ``{"generated_text": chunk}``
        payload per streamed chunk. Non-streaming: returns a single payload
        built from the first (and only) item of ``output``, or an empty string
        if ``output`` is empty.
        """
        if self.stream:
            return self._encode_stream(output)
        try:
            result = next(output)
        except StopIteration:
            return {"generated_text": ""}
        return {"generated_text": result}

    @staticmethod
    def _encode_stream(output: Iterator[str]) -> Iterator[Dict[str, Any]]:
        """Wrap each streamed text chunk in its own response payload."""
        for chunk in output:
            yield {"generated_text": chunk}

    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Generate a complete response by forwarding the request to the LLM Server.

        Returns:
            The ``"generated_text"`` field of the upstream JSON response.

        Raises:
            httpx.HTTPStatusError: upstream returned a 4xx/5xx status.
            Any other error from the request or JSON decoding is logged and
            re-raised unchanged.
        """
        self.logger.debug("Forwarding generation request for prompt: %.50s...", prompt)
        try:
            response = await self._get_client().post(
                "/api/v1/generate",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens,
                },
            )
            response.raise_for_status()
            data = response.json()
            return data["generated_text"]
        except Exception as e:
            # Log with traceback, then let the caller handle the failure.
            self.logger.exception("Error in generate_response: %s", e)
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterator[str]:
        """Generate a streaming response by forwarding the request to the LLM Server.

        Yields decoded text chunks as ``httpx`` delivers them; chunk boundaries
        are whatever the transport produces, not token boundaries.

        Raises:
            httpx.HTTPStatusError: upstream returned a 4xx/5xx status.
            Other request/transport errors are logged and re-raised unchanged.
        """
        self.logger.debug("Forwarding streaming request for prompt: %.50s...", prompt)
        try:
            async with self._get_client().stream(
                "POST",
                "/api/v1/generate/stream",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens,
                },
            ) as response:
                response.raise_for_status()
                async for chunk in response.aiter_text():
                    yield chunk
        except Exception as e:
            self.logger.exception("Error in generate_stream: %s", e)
            raise

    async def cleanup(self):
        """Cleanup method - close HTTP client.

        NOTE(review): the client's connections are used on this instance's
        private loop; awaiting this from a different loop may fail to close
        them cleanly — confirm how the serving framework invokes cleanup.
        """
        if self.client:
            await self.client.aclose()
            self.client = None