# Spaces: Runtime error
# (Page status banner captured together with the source; not part of the module.)
import asyncio
import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

import httpx
from litserve import LitAPI
from pydantic import BaseModel


class GenerationResponse(BaseModel):
    """Response payload carrying the text produced for a generation request."""

    # The generated text, stored under the same key that
    # ``InferenceApi.encode_response`` uses in its response dict.
    generated_text: str


class InferenceApi(LitAPI):
    """LitAPI implementation that forwards requests to an LLM Server over HTTP.

    The instance keeps no open connection: every call creates a fresh
    ``httpx.AsyncClient`` and closes it when the call finishes.

    Configuration is read from ``config['llm_server']``:

    * ``base_url``   -- server base URL (default ``http://localhost:8002``)
    * ``timeout``    -- request timeout in seconds (default ``60.0``)
    * ``api_prefix`` -- string prepended to every endpoint path (default ``''``)
    * ``endpoints``  -- mapping of logical endpoint name to path
    """

    def __init__(self, config: Dict[str, Any]):
        """Initialize the Inference API with configuration.

        Args:
            config: Application configuration; the ``'llm_server'`` section
                (an empty dict when absent) drives all HTTP calls.
        """
        super().__init__()
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing Inference API")
        self._device = None
        self.stream = False
        self.config = config
        self.llm_config = config.get('llm_server', {})

    def setup(self, device: Optional[str] = None):
        """Synchronous setup method required by LitAPI.

        Records ``device`` and returns ``self`` so calls can be chained.
        """
        self._device = device
        self.logger.info("Inference API setup completed on device: %s", device)
        return self

    async def _get_client(self):
        """Create a new HTTP client configured from ``llm_config``.

        Kept as a coroutine because callers ``await`` it. The caller owns the
        returned client and must close it (e.g. via ``async with``).
        """
        return httpx.AsyncClient(
            base_url=self.llm_config.get('base_url', 'http://localhost:8002'),
            timeout=float(self.llm_config.get('timeout', 60.0)),
        )

    def _get_endpoint(self, endpoint_name: str) -> str:
        """Return the configured path for ``endpoint_name`` with the API prefix.

        An unknown endpoint name yields just the prefix (empty path).
        """
        endpoints = self.llm_config.get('endpoints', {})
        api_prefix = self.llm_config.get('api_prefix', '')
        endpoint = endpoints.get(endpoint_name, '')
        return f"{api_prefix}{endpoint}"

    async def _request_json(
        self,
        method: str,
        endpoint_name: str,
        operation: str,
        result_key: Optional[str] = None,
        **request_kwargs: Any,
    ) -> Any:
        """Send one request to the LLM Server and return its decoded JSON body.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            endpoint_name: Logical endpoint name resolved by ``_get_endpoint``.
            operation: Name of the public method, used in the error log line.
            result_key: When given, return only this key of the JSON body.
            **request_kwargs: Passed through to ``httpx.AsyncClient.request``
                (``json=...``, ``params=...``).

        Raises:
            Exception: Any failure (transport error, non-2xx status, invalid
                JSON, missing ``result_key``) is logged and re-raised as-is.
        """
        try:
            async with await self._get_client() as client:
                response = await client.request(
                    method,
                    self._get_endpoint(endpoint_name),
                    **request_kwargs,
                )
                response.raise_for_status()
                payload = response.json()
                return payload if result_key is None else payload[result_key]
        except Exception as e:
            self.logger.error("Error in %s: %s", operation, e)
            raise

    def predict(self, x: str, **kwargs) -> Iterator[str]:
        """Non-async prediction method that yields results.

        Drives ``_async_predict`` from synchronous code. A private event loop
        is created per call: ``asyncio.get_event_loop()`` is deprecated when no
        loop is running and raises in threads that have no loop set. The loop
        and the async generator are always cleaned up, even if the consumer
        stops iterating early or the request fails.
        """
        loop = asyncio.new_event_loop()
        results = self._async_predict(x, **kwargs)
        try:
            while True:
                try:
                    yield loop.run_until_complete(results.__anext__())
                except StopAsyncIteration:
                    break
        finally:
            try:
                # Let the async generator release its HTTP client, then
                # finalize any other async generators bound to this loop.
                loop.run_until_complete(results.aclose())
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()

    async def _async_predict(self, x: str, **kwargs) -> AsyncIterator[str]:
        """Internal async prediction method.

        Yields streamed chunks when ``self.stream`` is true, otherwise yields
        the single complete response.
        """
        if self.stream:
            async for chunk in self.generate_stream(x, **kwargs):
                yield chunk
        else:
            yield await self.generate_response(x, **kwargs)

    async def generate_embedding(self, text: str) -> List[float]:
        """Generate embedding vector from input text.

        Returns the ``"embedding"`` field of the server's JSON reply.
        """
        self.logger.debug(
            "Forwarding embedding request for text: %s...", text[:50]
        )
        return await self._request_json(
            "POST",
            'embedding',
            "generate_embedding",
            result_key="embedding",
            json={"text": text},
        )

    async def check_system_status(self) -> Dict[str, Any]:
        """Check system status of the LLM Server."""
        self.logger.debug("Checking system status...")
        return await self._request_json(
            "GET", 'system_status', "check_system_status"
        )

    async def download_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """Download model files from the LLM Server.

        ``model_name`` is sent as a query parameter only when it is truthy.
        """
        self.logger.debug(
            "Forwarding model download request for: %s",
            model_name or 'default model',
        )
        return await self._request_json(
            "POST",
            'model_download',
            "download_model",
            params={"model_name": model_name} if model_name else None,
        )

    async def validate_system(self) -> Dict[str, Any]:
        """Validate system configuration and setup."""
        self.logger.debug("Validating system configuration...")
        return await self._request_json(
            "GET", 'system_validate', "validate_system"
        )

    async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize specified model or default model.

        ``model_name`` is sent as a query parameter only when it is truthy.
        """
        self.logger.debug("Initializing model: %s", model_name or 'default')
        return await self._request_json(
            "POST",
            'model_initialize',
            "initialize_model",
            params={"model_name": model_name} if model_name else None,
        )

    async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, Any]:
        """Initialize embedding model.

        Unlike ``initialize_model``, the name travels in the JSON body; an
        empty JSON object is sent when no name is given.
        """
        self.logger.debug(
            "Initializing embedding model: %s", model_name or 'default'
        )
        return await self._request_json(
            "POST",
            'model_initialize_embedding',
            "initialize_embedding_model",
            json={"model_name": model_name} if model_name else {},
        )

    def decode_request(self, request: Any, **kwargs) -> str:
        """Convert the request payload to input format.

        A dict with a ``"prompt"`` key yields that prompt; anything else is
        passed through unchanged.
        """
        if isinstance(request, dict) and "prompt" in request:
            return request["prompt"]
        return request

    def encode_response(self, output: Iterator[str], **kwargs) -> Dict[str, Any]:
        """Convert the model output to a response payload.

        In non-streaming mode the first item of ``output`` becomes
        ``generated_text`` (empty string if ``output`` yields nothing).
        """
        if self.stream:
            # NOTE(review): this hands the raw iterator back inside the dict;
            # confirm the serving layer consumes it as intended for streaming.
            return {"generated_text": output}
        try:
            result = next(output)
            return {"generated_text": result}
        except StopIteration:
            return {"generated_text": ""}

    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> str:
        """Generate a complete response by forwarding the request to the LLM Server.

        Returns the ``"generated_text"`` field of the server's JSON reply.
        """
        self.logger.debug(
            "Forwarding generation request for prompt: %s...", prompt[:50]
        )
        return await self._request_json(
            "POST",
            'generate',
            "generate_response",
            result_key="generated_text",
            json={
                "prompt": prompt,
                "system_message": system_message,
                "max_new_tokens": max_new_tokens,
            },
        )

    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """Generate a streaming response by forwarding the request to the LLM Server.

        Yields text chunks as they arrive. The HTTP client is held in an
        ``async with`` block so it is closed on every exit path: normal
        completion, an HTTP/transport error, or the consumer closing the
        generator early.
        """
        self.logger.debug(
            "Forwarding streaming request for prompt: %s...", prompt[:50]
        )
        try:
            async with await self._get_client() as client:
                async with client.stream(
                    "POST",
                    self._get_endpoint('generate_stream'),
                    json={
                        "prompt": prompt,
                        "system_message": system_message,
                        "max_new_tokens": max_new_tokens,
                    },
                ) as response:
                    response.raise_for_status()
                    async for chunk in response.aiter_text():
                        yield chunk
        except Exception as e:
            self.logger.error("Error in generate_stream: %s", e)
            raise

    async def cleanup(self):
        """Cleanup method - no longer needed as clients are created per-request."""