import logging
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Dict, List, Optional

import httpx


class LLMAdapter(ABC):
    """Abstract base class for LLM adapters."""

    @abstractmethod
    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> str:
        """Generate a complete response from the LLM."""
        pass

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """Generate a streaming response from the LLM."""
        pass

    @abstractmethod
    async def generate_embedding(self, text: str) -> List[float]:
        """Generate an embedding vector from input text."""
        pass

    @abstractmethod
    async def check_system_status(self) -> Dict[str, Any]:
        """Check the system status of the LLM server."""
        pass

    @abstractmethod
    async def validate_system(self) -> Dict[str, Any]:
        """Validate system configuration and setup."""
        pass

    @abstractmethod
    async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize the specified model, or the default model."""
        pass

    @abstractmethod
    async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize the embedding model."""
        pass

    @abstractmethod
    async def download_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """Download model files."""
        pass

    @abstractmethod
    async def cleanup(self):
        """Clean up resources."""
        pass


class HTTPLLMAdapter(LLMAdapter):
    """HTTP adapter for connecting to LLM services over HTTP."""

    def __init__(self, config: Dict[str, Any]):
        """Initialize the HTTP LLM adapter with configuration."""
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing HTTP LLM Adapter")
        self.config = config
        self.llm_config = config.get('llm_server', {})

    def _get_base_url(self) -> str:
        """Build the server base URL from the configured host and port."""
        host = self.llm_config.get('host', 'localhost')
        port = self.llm_config.get('port', 8002)
        # Omit the port for Hugging Face Spaces, which are served over HTTPS.
        if 'hf.space' in host:
            return f"https://{host}"
        return f"http://{host}:{port}"

    async def _get_client(self) -> httpx.AsyncClient:
        """Create a new HTTP client; clients are created per request."""
        return httpx.AsyncClient(
            base_url=self._get_base_url(),
            timeout=float(self.llm_config.get('timeout', 60.0))
        )

    def _get_endpoint(self, endpoint_name: str) -> str:
        """Get the full endpoint path, including the API prefix."""
        endpoints = self.llm_config.get('endpoints', {})
        api_prefix = self.llm_config.get('api_prefix', '')
        endpoint = endpoints.get(endpoint_name, '')
        return f"{api_prefix}{endpoint}"

    async def _make_request(
        self,
        method: str,
        endpoint: str,
        *,
        params: Optional[Dict[str, Any]] = None,
        json: Optional[Dict[str, Any]] = None,
        stream: bool = False
    ) -> Any:
        """Make a request to the LLM server."""
        path = self._get_endpoint(endpoint)
        # Derive the logged URL from the same base URL the client uses, so
        # the log matches the actual request target.
        full_url = f"{self._get_base_url().rstrip('/')}/{path.lstrip('/')}"
        try:
            self.logger.info(f"Making {method} request to: {full_url}")
            # Create the client outside any `async with` block so that a
            # streaming response can outlive this function call.
            client = await self._get_client()
            if stream:
                # For streaming, return the client and the response context
                # manager; the caller must enter and close both.
                return client, client.stream(method, path, params=params, json=json)
            # For non-streaming requests, close the client before returning.
            async with client as c:
                response = await c.request(method, path, params=params, json=json)
                response.raise_for_status()
                return response
        except Exception as e:
            self.logger.error(f"Error in request to {full_url}: {str(e)}")
            raise
    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> str:
        """Generate a complete response by forwarding the request to the LLM server."""
        self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")
        try:
            response = await self._make_request(
                "POST",
                "generate",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                }
            )
            data = response.json()
            return data["generated_text"]
        except Exception as e:
            self.logger.error(f"Error in generate_response: {str(e)}")
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """Generate a streaming response by forwarding the request to the LLM server."""
        self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")
        try:
            client, stream_cm = await self._make_request(
                "POST",
                "generate_stream",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                },
                stream=True
            )
            # Close the client once the stream is fully consumed.
            async with client:
                async with stream_cm as response:
                    response.raise_for_status()
                    async for chunk in response.aiter_text():
                        yield chunk
        except Exception as e:
            self.logger.error(f"Error in generate_stream: {str(e)}")
            raise

    async def generate_embedding(self, text: str) -> List[float]:
        """Generate an embedding vector from input text."""
        self.logger.debug(f"Forwarding embedding request for text: {text[:50]}...")
        try:
            response = await self._make_request(
                "POST",
                "embedding",
                json={"text": text}
            )
            data = response.json()
            return data["embedding"]
        except Exception as e:
            self.logger.error(f"Error in generate_embedding: {str(e)}")
            raise

    async def check_system_status(self) -> Dict[str, Any]:
        """Check the system status of the LLM server."""
        self.logger.debug("Checking system status...")
        try:
            response = await self._make_request("GET", "system_status")
            return response.json()
        except Exception as e:
            self.logger.error(f"Error in check_system_status: {str(e)}")
            raise

    async def validate_system(self) -> Dict[str, Any]:
        """Validate system configuration and setup."""
        self.logger.debug("Validating system configuration...")
        try:
            response = await self._make_request("GET", "system_validate")
            return response.json()
        except Exception as e:
            self.logger.error(f"Error in validate_system: {str(e)}")
            raise

    async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize the specified model, or the default model."""
        self.logger.debug(f"Initializing model: {model_name or 'default'}")
        try:
            response = await self._make_request(
                "POST",
                "model_initialize",
                params={"model_name": model_name} if model_name else None
            )
            return response.json()
        except Exception as e:
            self.logger.error(f"Error in initialize_model: {str(e)}")
            raise

    async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize the embedding model."""
        self.logger.debug(f"Initializing embedding model: {model_name or 'default'}")
        try:
            response = await self._make_request(
                "POST",
                "model_initialize_embedding",
                json={"model_name": model_name} if model_name else {}
            )
            return response.json()
        except Exception as e:
            self.logger.error(f"Error in initialize_embedding_model: {str(e)}")
            raise

    async def download_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """Download model files from the LLM server."""
        self.logger.debug(f"Forwarding model download request for: {model_name or 'default model'}")
        try:
            response = await self._make_request(
                "POST",
                "model_download",
                params={"model_name": model_name} if model_name else None
            )
            return response.json()
        except Exception as e:
            self.logger.error(f"Error in download_model: {str(e)}")
            raise

    async def cleanup(self):
        """Cleanup is a no-op: clients are created and closed per request."""
        pass


class OpenAIAdapter(LLMAdapter):
    """Adapter for OpenAI-compatible services (OpenAI, Azure OpenAI, local services with an OpenAI-compatible API)."""

    def __init__(self, config: Dict[str, Any]):
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing OpenAI Adapter")
        self.config = config
        self.openai_config = config.get('openai', {})
        # Additional OpenAI-specific setup would go here

    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> str:
        """OpenAI implementation; would use the openai Python client."""
        # Implementation would go here. Raising keeps the declared `-> str`
        # contract honest instead of silently returning None.
        raise NotImplementedError("OpenAIAdapter.generate_response is not implemented yet")

    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """OpenAI streaming implementation."""
        # Implementation would go here. Yielding directly makes this an async
        # generator, so callers consume it the same way as
        # HTTPLLMAdapter.generate_stream (async for chunk in ...).
        yield "Not implemented yet"

    # ... implementations for other methods
    # Note: the remaining abstract methods must be implemented before this
    # class can be instantiated.


class vLLMAdapter(LLMAdapter):
    """Adapter for vLLM services."""

    def __init__(self, config: Dict[str, Any]):
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing vLLM Adapter")
        self.config = config
        self.vllm_config = config.get('vllm', {})
        # Additional vLLM-specific setup would go here

    # ... implementations for all methods
    # Note: all abstract methods must be implemented before this class can be
    # instantiated.


# Factory function to create the appropriate adapter
def create_adapter(config: Dict[str, Any]) -> LLMAdapter:
    """Create an adapter instance based on configuration."""
    adapter_type = config.get('adapter', {}).get('type', 'http')
    if adapter_type == 'http':
        return HTTPLLMAdapter(config)
    elif adapter_type == 'openai':
        return OpenAIAdapter(config)
    elif adapter_type == 'vllm':
        return vLLMAdapter(config)
    else:
        raise ValueError(f"Unknown adapter type: {adapter_type}")
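

# Usage sketch, assuming a running LLM server that exposes the endpoints
# configured below. The host, port, api_prefix, endpoint paths, and prompts
# are illustrative values only, not part of the adapter contract; a real
# deployment supplies its own 'llm_server' settings, which HTTPLLMAdapter
# reads as shown above.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        config = {
            "adapter": {"type": "http"},
            "llm_server": {
                "host": "localhost",        # assumed example host
                "port": 8002,               # assumed example port
                "timeout": 60.0,
                "api_prefix": "/api/v1",    # assumed example prefix
                "endpoints": {
                    "generate": "/generate",               # assumed paths
                    "generate_stream": "/generate/stream",
                },
            },
        }
        adapter = create_adapter(config)
        # Non-streaming call: returns the full generated text at once.
        text = await adapter.generate_response("Hello!", system_message="Be brief.")
        print(text)
        # Streaming call: chunks are printed as the server produces them.
        async for chunk in adapter.generate_stream("Hello again!"):
            print(chunk, end="", flush=True)
        await adapter.cleanup()

    asyncio.run(_demo())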