import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

import httpx


class InferenceApi:
    """Async HTTP client that forwards inference requests to an LLM Server.

    Each public coroutine sends one request to the server configured under
    ``config["llm_server"]``, raises on a non-2xx status, and returns the
    decoded JSON reply (or one field of it). Any failure is logged together
    with its traceback and then re-raised unchanged.

    The instance owns an ``httpx.AsyncClient``. Release it with ``close()``,
    or use the instance as an async context manager::

        async with InferenceApi(config) as api:
            text = await api.generate_response("Hello")
    """

    def __init__(self, config: dict):
        """Initialize the Inference API with configuration.

        Args:
            config: Must contain ``config["llm_server"]["base_url"]``; may
                contain ``config["llm_server"]["timeout"]`` (default 60),
                which is passed unchanged to ``httpx.AsyncClient``.

        Raises:
            KeyError: If ``"llm_server"`` or its ``"base_url"`` is missing.
        """
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing Inference API")

        # Get base URL from config
        server_config = config["llm_server"]
        self.base_url = server_config["base_url"]
        self.timeout = server_config.get("timeout", 60)

        # Every request path used below is relative to base_url.
        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=self.timeout,
        )

        self.logger.info("Inference API initialized successfully")

    async def __aenter__(self) -> "InferenceApi":
        """Return self so the instance can be used in ``async with``."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the HTTP client when leaving an ``async with`` block."""
        await self.close()

    @staticmethod
    def _generation_payload(
        prompt: str,
        system_message: Optional[str],
        max_new_tokens: Optional[int],
    ) -> Dict[str, Any]:
        """Build the JSON body shared by the generate and stream endpoints.

        ``None`` values are kept in the dict, so they are sent as JSON ``null``.
        """
        return {
            "prompt": prompt,
            "system_message": system_message,
            "max_new_tokens": max_new_tokens,
        }

    async def _request_json(
        self,
        method: str,
        path: str,
        error_message: str,
        key: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """Send one request and return its decoded JSON body (or ``body[key]``).

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            path: Path relative to ``base_url``.
            error_message: Prefix of the log line written on failure; the
                line reads ``"<error_message>: <exception>"``.
            key: If given, return only this field of the JSON body.
            **kwargs: Forwarded to ``httpx.AsyncClient.request``
                (e.g. ``json=``, ``params=``).

        Raises:
            httpx.HTTPError: On transport failure or a non-2xx status.
            ValueError: If the response body is not valid JSON.
            KeyError: If ``key`` is given but absent from the body.
        """
        try:
            response = await self.client.request(method, path, **kwargs)
            response.raise_for_status()
            data = response.json()
            # Field extraction stays inside the try so a missing key is logged too.
            return data if key is None else data[key]
        except Exception as e:
            # logger.exception records the traceback as well as the message.
            self.logger.exception("%s: %s", error_message, e)
            raise

    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Generate a complete response by forwarding the request to the LLM Server.

        Args:
            prompt: Prompt text sent as ``"prompt"``.
            system_message: Sent as ``"system_message"``; ``None`` becomes JSON ``null``.
            max_new_tokens: Sent as ``"max_new_tokens"``; ``None`` becomes JSON ``null``.

        Returns:
            The ``generated_text`` field of the server's JSON reply.

        Raises:
            httpx.HTTPError: On transport failure or a non-2xx status.
            ValueError: If the reply is not valid JSON.
            KeyError: If the reply has no ``generated_text`` field.
        """
        # NOTE(review): logs the first 50 characters of the prompt at DEBUG level;
        # prompts may contain user data.
        self.logger.debug("Forwarding generation request for prompt: %s...", prompt[:50])
        return await self._request_json(
            "POST",
            "/api/v1/generate",
            "Error in generate_response",
            key="generated_text",
            json=self._generation_payload(prompt, system_message, max_new_tokens),
        )

    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterator[str]:
        """Generate a streaming response by forwarding the request to the LLM Server.

        This is an async generator; consume it with ``async for``. It yields
        decoded text chunks as ``httpx`` delivers them. No line or token
        framing is applied here. The HTTP stream stays open until the
        generator is exhausted or closed.

        Args:
            prompt: Prompt text sent as ``"prompt"``.
            system_message: Sent as ``"system_message"``; ``None`` becomes JSON ``null``.
            max_new_tokens: Sent as ``"max_new_tokens"``; ``None`` becomes JSON ``null``.

        Raises:
            httpx.HTTPError: On transport failure or a non-2xx status.
        """
        self.logger.debug("Forwarding streaming request for prompt: %s...", prompt[:50])
        try:
            async with self.client.stream(
                "POST",
                "/api/v1/generate/stream",
                json=self._generation_payload(prompt, system_message, max_new_tokens),
            ) as response:
                response.raise_for_status()
                async for chunk in response.aiter_text():
                    yield chunk
        except Exception as e:
            self.logger.exception("Error in generate_stream: %s", e)
            raise

    async def generate_embedding(self, text: str) -> List[float]:
        """Generate an embedding by forwarding the request to the LLM Server.

        Args:
            text: Text sent as ``"text"``.

        Returns:
            The ``embedding`` field of the server's JSON reply.

        Raises:
            httpx.HTTPError: On transport failure or a non-2xx status.
            ValueError: If the reply is not valid JSON.
            KeyError: If the reply has no ``embedding`` field.
        """
        self.logger.debug("Forwarding embedding request for text: %s...", text[:50])
        return await self._request_json(
            "POST",
            "/api/v1/embedding",
            "Error in generate_embedding",
            key="embedding",
            json={"text": text},
        )

    async def check_system_status(self) -> Dict[str, Union[Dict, str]]:
        """Get the system status from the LLM Server.

        Returns:
            The server's decoded JSON reply, unmodified.
        """
        return await self._request_json(
            "GET",
            "/api/v1/system/status",
            "Error getting system status",
        )

    async def validate_system(self) -> Dict[str, Union[Dict, str, List[str]]]:
        """Get the system validation status from the LLM Server.

        Returns:
            The server's decoded JSON reply, unmodified.
        """
        return await self._request_json(
            "GET",
            "/api/v1/system/validate",
            "Error validating system",
        )

    async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """Ask the LLM Server to initialize a model.

        Args:
            model_name: Sent as the ``model_name`` query parameter. When it is
                ``None`` or empty, no query parameter is sent.

        Returns:
            The server's decoded JSON reply, unmodified.
        """
        return await self._request_json(
            "POST",
            "/api/v1/model/initialize",
            "Error initializing model",
            params={"model_name": model_name} if model_name else None,
        )

    async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """Ask the LLM Server to initialize an embedding model.

        Args:
            model_name: Sent as the ``model_name`` query parameter. When it is
                ``None`` or empty, no query parameter is sent.

        Returns:
            The server's decoded JSON reply, unmodified.
        """
        return await self._request_json(
            "POST",
            "/api/v1/model/initialize/embedding",
            "Error initializing embedding model",
            params={"model_name": model_name} if model_name else None,
        )

    async def close(self):
        """Close the HTTP client session."""
        await self.client.aclose()