import httpx
from typing import Optional, AsyncIterator, List, Dict, Union
import logging


class InferenceApi:
    def __init__(self, config: dict):
        """Initialize the Inference API with configuration."""
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing Inference API")

        # Get base URL and timeout from config
        self.base_url = config["llm_server"]["base_url"]
        self.timeout = config["llm_server"].get("timeout", 60)

        # Initialize the async HTTP client
        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            timeout=self.timeout
        )

        self.logger.info("Inference API initialized successfully")

    async def generate_response(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> str:
        """
        Generate a complete response by forwarding the request to the LLM Server.
        """
        self.logger.debug(f"Forwarding generation request for prompt: {prompt[:50]}...")
        try:
            response = await self.client.post(
                "/api/v1/generate",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                }
            )
            response.raise_for_status()
            data = response.json()
            return data["generated_text"]
        except Exception as e:
            self.logger.error(f"Error in generate_response: {str(e)}")
            raise

    async def generate_stream(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_new_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """
        Generate a streaming response by forwarding the request to the LLM Server.
        """
        self.logger.debug(f"Forwarding streaming request for prompt: {prompt[:50]}...")
        try:
            async with self.client.stream(
                "POST",
                "/api/v1/generate/stream",
                json={
                    "prompt": prompt,
                    "system_message": system_message,
                    "max_new_tokens": max_new_tokens
                }
            ) as response:
                response.raise_for_status()
                async for chunk in response.aiter_text():
                    yield chunk
        except Exception as e:
            self.logger.error(f"Error in generate_stream: {str(e)}")
            raise

    async def generate_embedding(self, text: str) -> List[float]:
        """
        Generate embedding by forwarding the request to the LLM Server.
        """
        self.logger.debug(f"Forwarding embedding request for text: {text[:50]}...")
        try:
            response = await self.client.post(
                "/api/v1/embedding",
                json={"text": text}
            )
            response.raise_for_status()
            data = response.json()
            return data["embedding"]
        except Exception as e:
            self.logger.error(f"Error in generate_embedding: {str(e)}")
            raise

    async def check_system_status(self) -> Dict[str, Union[Dict, str]]:
        """
        Get system status from the LLM Server.
        """
        try:
            response = await self.client.get("/api/v1/system/status")
            response.raise_for_status()
            return response.json()
        except Exception as e:
            self.logger.error(f"Error getting system status: {str(e)}")
            raise

    async def validate_system(self) -> Dict[str, Union[Dict, str, List[str]]]:
        """
        Get system validation status from the LLM Server.
        """
        try:
            response = await self.client.get("/api/v1/system/validate")
            response.raise_for_status()
            return response.json()
        except Exception as e:
            self.logger.error(f"Error validating system: {str(e)}")
            raise

    async def initialize_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """
        Initialize a model on the LLM Server.
        """
        try:
            response = await self.client.post(
                "/api/v1/model/initialize",
                params={"model_name": model_name} if model_name else None
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            self.logger.error(f"Error initializing model: {str(e)}")
            raise

    async def initialize_embedding_model(self, model_name: Optional[str] = None) -> Dict[str, str]:
        """
        Initialize an embedding model on the LLM Server.
        """
        try:
            response = await self.client.post(
                "/api/v1/model/initialize/embedding",
                params={"model_name": model_name} if model_name else None
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            self.logger.error(f"Error initializing embedding model: {str(e)}")
            raise

    async def close(self):
        """Close the HTTP client session."""
        await self.client.aclose()
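

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the service wiring).
# It assumes an LLM Server is reachable at the base_url below; the URL and
# the prompts are placeholders, not values taken from this project's config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo():
        api = InferenceApi({"llm_server": {"base_url": "http://localhost:8001"}})
        try:
            # Single-shot generation
            text = await api.generate_response("Hello!", system_message="Be brief.")
            print(text)

            # Streaming generation, consumed chunk by chunk
            async for chunk in api.generate_stream("Tell me a short story"):
                print(chunk, end="", flush=True)
        finally:
            await api.close()

    asyncio.run(_demo())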