# Inference-API / main / adapter.py
import httpx
import logging
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, AsyncIterator, List
class LLMAdapter(ABC):
"""Abstract base class for LLM adapters."""
@abstractmethod
async def generate_response(
self,
prompt: str,
system_message: Optional[str] = None,
max_new_tokens: Optional[int] = None
) -> str:
"""Generate a complete response from the LLM."""
pass
@abstractmethod
async def generate_stream(
self,
prompt: str,
system_message: Optional[str] = None,
max_new_tokens: Optional[int] = None
) -> AsyncIterator[str]:
"""Generate a streaming response from the LLM."""
pass
@abstractmethod
async def generate_embedding(self, text: str) -> List[float]:
"""Generate embedding vector from input text."""
pass
@abstractmethod
async def check_system_status(self) -> Dict[str, Any]:
"""Check system status of the LLM Server."""
pass
@abstractmethod
async def validate_system(self) -> Dict[str, Any]:
"""Validate system configuration and setup."""
pass
@abstractmethod
async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
"""Initialize specified model or default model."""
pass
@abstractmethod
async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
"""Initialize embedding model."""
pass
@abstractmethod
async def download_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
"""Download model files."""
pass
@abstractmethod
async def cleanup(self):
"""Cleanup resources."""
pass
class HTTPLLMAdapter(LLMAdapter):
"""HTTP adapter for connecting to LLM services over HTTP."""
def __init__(self, config: Dict[str, Any]):
"""Initialize the HTTP LLM Adapter with configuration."""
self.logger = logging.getLogger(__name__)
self.logger.info("Initializing HTTP LLM Adapter")
self.config = config
self.llm_config = config.get('llm_server', {})
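        # Illustrative shape of the config this adapter reads (the endpoint paths and
        # api_prefix below are placeholders, not the server's actual routes):
        #
        # {
        #     "llm_server": {
        #         "host": "localhost",
        #         "port": 8002,
        #         "timeout": 60.0,
        #         "api_prefix": "/api/v1",
        #         "endpoints": {
        #             "generate": "/generate",
        #             "generate_stream": "/generate/stream",
        #             "embedding": "/embedding",
        #             "system_status": "/system/status",
        #             "system_validate": "/system/validate",
        #             "model_initialize": "/model/initialize",
        #             "model_initialize_embedding": "/model/initialize/embedding",
        #             "model_download": "/model/download",
        #         },
        #     }
        # }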
async def _get_client(self):
"""Get or create HTTP client as needed"""
host = self.llm_config.get('host', 'localhost')
port = self.llm_config.get('port', 8002)
# Construct base URL, omitting port for HF spaces
if 'hf.space' in host:
base_url = f"https://{host}"
else:
base_url = f"http://{host}:{port}"
return httpx.AsyncClient(
base_url=base_url,
timeout=float(self.llm_config.get('timeout', 60.0))
)
def _get_endpoint(self, endpoint_name: str) -> str:
"""Get full endpoint path including prefix"""
endpoints = self.llm_config.get('endpoints', {})
api_prefix = self.llm_config.get('api_prefix', '')
endpoint = endpoints.get(endpoint_name, '')
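        # e.g. api_prefix "/api/v1" + endpoint "/generate" -> "/api/v1/generate"
        # (values are illustrative; the real paths come from the config)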
return f"{api_prefix}{endpoint}"
async def _make_request(
self,
method: str,
endpoint: str,
*,
params: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
stream: bool = False
) -> Any:
"""Make an authenticated request to the LLM Server."""
base_url = self.llm_config.get('host', 'http://localhost:8001')
full_endpoint = f"{base_url.rstrip('/')}/{self._get_endpoint(endpoint).lstrip('/')}"
try:
self.logger.info(f"Making {method} request to: {full_endpoint}")
# Create client outside the with block for streaming
client = await self._get_client()
if stream:
# For streaming, return both client and response context managers
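                # The caller must enter the stream context manager and close the
                # client when done (see generate_stream for the expected pattern).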
return client, client.stream(
method,
self._get_endpoint(endpoint),
params=params,
json=json
)
else:
# For non-streaming, use context manager
async with client as c:
response = await c.request(
method,
self._get_endpoint(endpoint),
params=params,
json=json
)
response.raise_for_status()
return response
except Exception as e:
self.logger.error(f"Error in request to {full_endpoint}: {str(e)}")
raise
async def generate_response(
self,
prompt: str,
system_message: Optional[str] = None,
max_new_tokens: Optional[int] = None
) -> str:
"""Generate a complete response by forwarding the request to the LLM Server."""
self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")
try:
response = await self._make_request(
"POST",
"generate",
json={
"prompt": prompt,
"system_message": system_message,
"max_new_tokens": max_new_tokens
}
)
data = response.json()
return data["generated_text"]
except Exception as e:
self.logger.error(f"Error in generate_response: {str(e)}")
raise
async def generate_stream(
self,
prompt: str,
system_message: Optional[str] = None,
max_new_tokens: Optional[int] = None
) -> AsyncIterator[str]:
"""Generate a streaming response by forwarding the request to the LLM Server."""
self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")
try:
client, stream_cm = await self._make_request(
"POST",
"generate_stream",
json={
"prompt": prompt,
"system_message": system_message,
"max_new_tokens": max_new_tokens
},
stream=True
)
async with client:
async with stream_cm as response:
async for chunk in response.aiter_text():
yield chunk
except Exception as e:
self.logger.error(f"Error in generate_stream: {str(e)}")
raise
async def generate_embedding(self, text: str) -> List[float]:
"""Generate embedding vector from input text."""
self.logger.debug(f"Forwarding embedding request for text: {text[:50]}...")
try:
response = await self._make_request(
"POST",
"embedding",
json={"text": text}
)
data = response.json()
return data["embedding"]
except Exception as e:
self.logger.error(f"Error in generate_embedding: {str(e)}")
raise
async def check_system_status(self) -> Dict[str, Any]:
"""Check system status of the LLM Server."""
self.logger.debug("Checking system status...")
try:
response = await self._make_request(
"GET",
"system_status"
)
return response.json()
except Exception as e:
self.logger.error(f"Error in check_system_status: {str(e)}")
raise
async def validate_system(self) -> Dict[str, Any]:
"""Validate system configuration and setup."""
self.logger.debug("Validating system configuration...")
try:
response = await self._make_request(
"GET",
"system_validate"
)
return response.json()
except Exception as e:
self.logger.error(f"Error in validate_system: {str(e)}")
raise
async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
"""Initialize specified model or default model."""
self.logger.debug(f"Initializing model: {model_name or 'default'}")
try:
response = await self._make_request(
"POST",
"model_initialize",
params={"model_name": model_name} if model_name else None
)
return response.json()
except Exception as e:
self.logger.error(f"Error in initialize_model: {str(e)}")
raise
async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
"""Initialize embedding model."""
self.logger.debug(f"Initializing embedding model: {model_name or 'default'}")
try:
response = await self._make_request(
"POST",
"model_initialize_embedding",
json={"model_name": model_name} if model_name else {}
)
return response.json()
except Exception as e:
self.logger.error(f"Error in initialize_embedding_model: {str(e)}")
raise
async def download_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
"""Download model files from the LLM Server."""
self.logger.debug(f"Forwarding model download request for: {model_name or 'default model'}")
try:
response = await self._make_request(
"POST",
"model_download",
params={"model_name": model_name} if model_name else None
)
return response.json()
except Exception as e:
self.logger.error(f"Error in download_model: {str(e)}")
raise
async def cleanup(self):
"""Cleanup method - no longer needed as clients are created per-request"""
pass
class OpenAIAdapter(LLMAdapter):
"""Adapter for OpenAI-compatible services (OpenAI, Azure OpenAI, local services with OpenAI API)."""
def __init__(self, config: Dict[str, Any]):
self.logger = logging.getLogger(__name__)
self.logger.info("Initializing OpenAI Adapter")
self.config = config
self.openai_config = config.get('openai', {})
# Additional OpenAI-specific setup would go here
async def generate_response(self, prompt: str, system_message: Optional[str] = None, max_new_tokens: Optional[int] = None) -> str:
"""OpenAI implementation - would use openai Python client"""
        # Implementation would go here
        raise NotImplementedError("OpenAIAdapter.generate_response is not implemented yet")
    async def generate_stream(self, prompt: str, system_message: Optional[str] = None, max_new_tokens: Optional[int] = None) -> AsyncIterator[str]:
        """OpenAI streaming implementation"""
        # Implementation would go here. Yield directly so this method is itself an
        # async generator, matching how HTTPLLMAdapter.generate_stream is consumed.
        yield "Not implemented yet"
# ... implementations for other methods
class vLLMAdapter(LLMAdapter):
"""Adapter for vLLM services."""
def __init__(self, config: Dict[str, Any]):
self.logger = logging.getLogger(__name__)
self.logger.info("Initializing vLLM Adapter")
self.config = config
self.vllm_config = config.get('vllm', {})
# Additional vLLM-specific setup would go here
# ... implementations for all methods
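# NOTE: OpenAIAdapter and vLLMAdapter are WIP stubs. Python's ABC machinery raises
# TypeError when create_adapter() tries to instantiate them until every LLMAdapter
# abstract method is implemented.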
# Factory function to create the appropriate adapter
def create_adapter(config: Dict[str, Any]) -> LLMAdapter:
"""Create an adapter instance based on configuration."""
adapter_type = config.get('adapter', {}).get('type', 'http')
if adapter_type == 'http':
return HTTPLLMAdapter(config)
elif adapter_type == 'openai':
return OpenAIAdapter(config)
elif adapter_type == 'vllm':
return vLLMAdapter(config)
else:
raise ValueError(f"Unknown adapter type: {adapter_type}")
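# Minimal usage sketch, assuming an HTTP adapter and a reachable LLM Server at the
# configured host/port (config values and prompts below are placeholders; "endpoints"
# and "api_prefix" are omitted here -- see the config sketch in HTTPLLMAdapter.__init__):
if __name__ == "__main__":
    import asyncio

    async def _demo():
        config = {
            "adapter": {"type": "http"},
            "llm_server": {"host": "localhost", "port": 8002, "timeout": 60.0},
        }
        adapter = create_adapter(config)
        # One-shot generation
        print(await adapter.generate_response("Hello!", system_message="Be concise."))
        # Streaming generation
        async for chunk in adapter.generate_stream("Tell me a short story."):
            print(chunk, end="", flush=True)
        await adapter.cleanup()

    asyncio.run(_demo())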