import logging
from typing import Any, AsyncIterator, Optional, Union

import httpx
from litserve import LitAPI


class InferenceApi(LitAPI):
    def __init__(self, config: dict):
        """Initialize the Inference API with configuration."""
        super().__init__()
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing Inference API")

        # Get base URL and timeout from config
        self.base_url = config["llm_server"]["base_url"]
        self.timeout = config["llm_server"].get("timeout", 60)
        self.client = None  # Will be initialized in setup()

        # Set request timeout from config
        self.request_timeout = float(self.timeout)
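
    # Illustrative config shape (the "llm_server", "base_url", and "timeout"
    # keys match the lookups above; the URL value itself is just an example):
    #
    #     config = {
    #         "llm_server": {
    #             "base_url": "http://localhost:8000",
    #             "timeout": 60,
    #         }
    #     }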

    async def setup(self, device: Optional[str] = None):
        """Setup method required by LitAPI - initialize the HTTP client."""
        self._device = device  # Store device as required by LitAPI
        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=self.timeout,
        )
        self.logger.info(f"Inference API setup completed on device: {device}")

    async def predict(self, x: str, **kwargs) -> AsyncIterator[str]:
        """
        Main prediction method required by LitAPI.
        If streaming is enabled, yields chunks as they arrive; otherwise yields
        the complete response as a single chunk (an async generator cannot
        `return` a value, so both paths must yield).
        """
        if self.stream:
            async for chunk in self.generate_stream(x, **kwargs):
                yield chunk
        else:
            yield await self.generate_response(x, **kwargs)

    def decode_request(self, request: Any, **kwargs) -> str:
        """Convert the request payload to the input format."""
        # For our case, we expect the request to be text, either raw
        # or wrapped in a dict under the "prompt" key
        if isinstance(request, dict) and "prompt" in request:
            return request["prompt"]
        return request
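
    # Example payloads (illustrative; both decode to the same prompt string):
    #     decode_request({"prompt": "Hello"})  # -> "Hello"
    #     decode_request("Hello")              # -> "Hello"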

    def encode_response(
        self, output: Union[str, AsyncIterator[str]], **kwargs
    ) -> Union[dict, AsyncIterator[dict]]:
        """Convert the model output to a response payload."""
        if self.stream:
            # For streaming, yield each chunk wrapped in a dict
            async def stream_wrapper():
                async for chunk in output:
                    yield {"generated_text": chunk}

            return stream_wrapper()
        # For non-streaming, return the complete response
        return {"generated_text": output}

    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Generate a complete response by forwarding the request to the LLM Server."""
        self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")
        try:
            response = await self.client.post(
                "/api/v1/generate",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens,
                },
            )
            response.raise_for_status()
            data = response.json()
            return data["generated_text"]
        except Exception as e:
            self.logger.error(f"Error in generate_response: {str(e)}")
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterator[str]:
        """Generate a streaming response by forwarding the request to the LLM Server."""
        self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")
        try:
            async with self.client.stream(
                "POST",
                "/api/v1/generate/stream",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens,
                },
            ) as response:
                response.raise_for_status()
                async for chunk in response.aiter_text():
                    yield chunk
        except Exception as e:
            self.logger.error(f"Error in generate_stream: {str(e)}")
            raise

    # ... [rest of the methods remain the same: generate_embedding, check_system_status, etc.]

    async def cleanup(self):
        """Cleanup method - close the HTTP client."""
        if self.client:
            await self.client.aclose()

    def log(self, key: str, value: Any):
        """Override log method to use our logger if the queue is not set."""
        if self._logger_queue is None:
            self.logger.info(f"Log event: {key}={value}")
        else:
            super().log(key, value)
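

# Minimal usage sketch (assumed wiring: LitServer's exact arguments can vary
# across litserve versions, and the base_url/port values are examples only):
if __name__ == "__main__":
    import litserve as ls

    config = {"llm_server": {"base_url": "http://localhost:8000", "timeout": 60}}
    server = ls.LitServer(InferenceApi(config), stream=True)
    server.run(port=8080)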