import logging
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Dict, List, Optional

import httpx

class LLMAdapter(ABC):
    """Abstract base class for LLM adapters."""

    @abstractmethod
    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> str:
        """Generate a complete response from the LLM."""
        pass

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """Generate a streaming response from the LLM."""
        pass

    @abstractmethod
    async def generate_embedding(self, text: str) -> List[float]:
        """Generate embedding vector from input text."""
        pass

    @abstractmethod
    async def check_system_status(self) -> Dict[str, Any]:
        """Check system status of the LLM Server."""
        pass

    @abstractmethod
    async def validate_system(self) -> Dict[str, Any]:
        """Validate system configuration and setup."""
        pass

    @abstractmethod
    async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize specified model or default model."""
        pass

    @abstractmethod
    async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize embedding model."""
        pass

    @abstractmethod
    async def download_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """Download model files."""
        pass

    @abstractmethod
    async def cleanup(self):
        """Cleanup resources."""
        pass


class HTTPLLMAdapter(LLMAdapter):
    """Adapter for connecting to LLM services over HTTP."""

    def __init__(self, config: Dict[str, Any]):
        """Initialize the HTTP LLM Adapter with configuration."""
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing HTTP LLM Adapter")
        self.config = config
        self.llm_config = config.get('llm_server', {})

    async def _get_client(self):
        """Get or create an HTTP client as needed."""
        host = self.llm_config.get('host', 'localhost')
        port = self.llm_config.get('port', 8002)
        # Construct the base URL, omitting the port for HF Spaces hosts
        if 'hf.space' in host:
            base_url = f"https://{host}"
        else:
            base_url = f"http://{host}:{port}"
        return httpx.AsyncClient(
            base_url=base_url,
            timeout=float(self.llm_config.get('timeout', 60.0))
        )

    def _get_endpoint(self, endpoint_name: str) -> str:
        """Get the full endpoint path, including the API prefix."""
        endpoints = self.llm_config.get('endpoints', {})
        api_prefix = self.llm_config.get('api_prefix', '')
        endpoint = endpoints.get(endpoint_name, '')
        return f"{api_prefix}{endpoint}"
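
    # Illustrative shape of the configuration this adapter reads. The keys
    # below mirror what _get_client and _get_endpoint look up; the specific
    # values and endpoint paths are assumptions, not server-defined defaults:
    #
    #   llm_server:
    #     host: localhost
    #     port: 8002
    #     timeout: 60.0
    #     api_prefix: /api/v1
    #     endpoints:
    #       generate: /generate
    #       generate_stream: /generate/stream
    #       embedding: /embedding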
    async def _make_request(
        self,
        method: str,
        endpoint: str,
        *,
        params: Optional[Dict[str, Any]] = None,
        json: Optional[Dict[str, Any]] = None,
        stream: bool = False
    ) -> Any:
        """Make a request to the LLM Server."""
        path = self._get_endpoint(endpoint)
        try:
            # Create the client outside a `with` block so that streaming
            # callers can manage its lifetime themselves
            client = await self._get_client()
            self.logger.info(f"Making {method} request to {path} (base_url={client.base_url})")
            if stream:
                # For streaming, return both the client and the response
                # context manager; the caller must close both
                return client, client.stream(
                    method,
                    path,
                    params=params,
                    json=json
                )
            else:
                # For non-streaming requests, close the client when done
                async with client as c:
                    response = await c.request(
                        method,
                        path,
                        params=params,
                        json=json
                    )
                    response.raise_for_status()
                    return response
        except Exception as e:
            self.logger.error(f"Error in {method} request to {path}: {str(e)}")
            raise
    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> str:
        """Generate a complete response by forwarding the request to the LLM Server."""
        self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")
        try:
            response = await self._make_request(
                "POST",
                "generate",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                }
            )
            data = response.json()
            return data["generated_text"]
        except Exception as e:
            self.logger.error(f"Error in generate_response: {str(e)}")
            raise
    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """Generate a streaming response by forwarding the request to the LLM Server."""
        self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")
        try:
            client, stream_cm = await self._make_request(
                "POST",
                "generate_stream",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                },
                stream=True
            )
            # Close the client once the stream has been fully consumed
            async with client:
                async with stream_cm as response:
                    async for chunk in response.aiter_text():
                        yield chunk
        except Exception as e:
            self.logger.error(f"Error in generate_stream: {str(e)}")
            raise
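
    # Illustrative consumption of the stream from calling code (assumes an
    # `adapter` built from a config like the one sketched above):
    #
    #   async for chunk in adapter.generate_stream("Tell me about streaming"):
    #       print(chunk, end="", flush=True)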
    async def generate_embedding(self, text: str) -> List[float]:
        """Generate embedding vector from input text."""
        self.logger.debug(f"Forwarding embedding request for text: {text[:50]}...")
        try:
            response = await self._make_request(
                "POST",
                "embedding",
                json={"text": text}
            )
            data = response.json()
            return data["embedding"]
        except Exception as e:
            self.logger.error(f"Error in generate_embedding: {str(e)}")
            raise

    async def check_system_status(self) -> Dict[str, Any]:
        """Check system status of the LLM Server."""
        self.logger.debug("Checking system status...")
        try:
            response = await self._make_request(
                "GET",
                "system_status"
            )
            return response.json()
        except Exception as e:
            self.logger.error(f"Error in check_system_status: {str(e)}")
            raise

    async def validate_system(self) -> Dict[str, Any]:
        """Validate system configuration and setup."""
        self.logger.debug("Validating system configuration...")
        try:
            response = await self._make_request(
                "GET",
                "system_validate"
            )
            return response.json()
        except Exception as e:
            self.logger.error(f"Error in validate_system: {str(e)}")
            raise

    async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize specified model or default model."""
        self.logger.debug(f"Initializing model: {model_name or 'default'}")
        try:
            response = await self._make_request(
                "POST",
                "model_initialize",
                params={"model_name": model_name} if model_name else None
            )
            return response.json()
        except Exception as e:
            self.logger.error(f"Error in initialize_model: {str(e)}")
            raise

    async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize embedding model."""
        self.logger.debug(f"Initializing embedding model: {model_name or 'default'}")
        try:
            response = await self._make_request(
                "POST",
                "model_initialize_embedding",
                json={"model_name": model_name} if model_name else {}
            )
            return response.json()
        except Exception as e:
            self.logger.error(f"Error in initialize_embedding_model: {str(e)}")
            raise

    async def download_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """Download model files from the LLM Server."""
        self.logger.debug(f"Forwarding model download request for: {model_name or 'default model'}")
        try:
            response = await self._make_request(
                "POST",
                "model_download",
                params={"model_name": model_name} if model_name else None
            )
            return response.json()
        except Exception as e:
            self.logger.error(f"Error in download_model: {str(e)}")
            raise

    async def cleanup(self):
        """Cleanup method - no longer needed, as clients are created per request."""
        pass

class OpenAIAdapter(LLMAdapter):
    """Adapter for OpenAI-compatible services (OpenAI, Azure OpenAI, local services with an OpenAI-style API)."""

    def __init__(self, config: Dict[str, Any]):
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing OpenAI Adapter")
        self.config = config
        self.openai_config = config.get('openai', {})
        # Additional OpenAI-specific setup would go here

    async def generate_response(self, prompt: str, system_message: Optional[str] = None, max_new_tokens: Optional[int] = None) -> str:
        """OpenAI implementation - would use the openai Python client."""
        # Implementation would go here
        raise NotImplementedError("OpenAI response generation is not implemented yet")

    async def generate_stream(self, prompt: str, system_message: Optional[str] = None, max_new_tokens: Optional[int] = None) -> AsyncIterator[str]:
        """OpenAI streaming implementation."""
        # Implementation would go here; the placeholder `yield` keeps this an
        # async generator, matching how HTTPLLMAdapter.generate_stream is consumed
        yield "Not implemented yet"
    # ... implementations for other methods


class vLLMAdapter(LLMAdapter):
    """Adapter for vLLM services."""

    def __init__(self, config: Dict[str, Any]):
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing vLLM Adapter")
        self.config = config
        self.vllm_config = config.get('vllm', {})
        # Additional vLLM-specific setup would go here

    # ... implementations for all methods

# Factory function to create the appropriate adapter
def create_adapter(config: Dict[str, Any]) -> LLMAdapter:
    """Create an adapter instance based on configuration."""
    adapter_type = config.get('adapter', {}).get('type', 'http')
    if adapter_type == 'http':
        return HTTPLLMAdapter(config)
    elif adapter_type == 'openai':
        return OpenAIAdapter(config)
    elif adapter_type == 'vllm':
        return vLLMAdapter(config)
    else:
        raise ValueError(f"Unknown adapter type: {adapter_type}")
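

# Minimal usage sketch. The config values and endpoint paths below are
# illustrative assumptions and must match whatever LLM Server is actually
# running; this only shows how the factory and adapter are wired together.
if __name__ == "__main__":
    import asyncio

    example_config = {
        "adapter": {"type": "http"},
        "llm_server": {
            "host": "localhost",
            "port": 8002,
            "timeout": 60.0,
            "api_prefix": "/api/v1",
            "endpoints": {
                "generate": "/generate",
                "generate_stream": "/generate/stream",
                "embedding": "/embedding",
                "system_status": "/system/status",
            },
        },
    }

    async def main() -> None:
        adapter = create_adapter(example_config)
        try:
            text = await adapter.generate_response("Hello!", system_message="Be brief.")
            print(text)
        finally:
            await adapter.cleanup()

    asyncio.run(main())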