# Inference-API / main/main.py
# NOTE: Hugging Face file-viewer metadata (author: AurelioAguirre,
# commit c6b21e3 "changed to uvicorn setup for HF", 2.83 kB) was pasted
# here as bare text and has been converted to comments so the module parses.
"""
LLM Inference Server main application using LitServe framework.
"""
import litserve as ls
import yaml
import logging
import asyncio
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .routes import router, init_router
from .api import InferenceApi
def setup_logging():
"""Set up basic logging configuration"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
return logging.getLogger(__name__)
def load_config():
    """Load configuration from the config.yaml that sits next to this module.

    Returns:
        The parsed YAML document (typically a dict).

    Raises:
        FileNotFoundError: if config.yaml is missing.
        yaml.YAMLError: if the file is not valid YAML.
    """
    config_path = Path(__file__).parent / "config.yaml"
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) can
    # corrupt non-ASCII config values; YAML files are UTF-8 by convention.
    with open(config_path, encoding="utf-8") as f:
        return yaml.safe_load(f)
async def init_app() -> tuple[FastAPI, InferenceApi, dict]:
    """Initialize and configure the FastAPI application.

    Returns:
        (app, api, config): the FastAPI app extracted from the LitServer,
        the InferenceApi instance, and the raw config dict.

    Raises:
        Exception: re-raised after logging if any initialization step fails.
    """
    logger = setup_logging()
    try:
        # Load configuration from config.yaml next to this module
        config = load_config()
        server_config = config.get('server', {})
        # Initialize API with config
        api = InferenceApi(config)
        # Initialize router with the API instance (must complete before the
        # router is registered on the app below)
        await init_router(api)
        # Create LitServer instance with config
        server = ls.LitServer(
            api,
            timeout=server_config.get('timeout', 60),
            max_batch_size=server_config.get('max_batch_size', 1),
            track_requests=True
        )
        # Get the FastAPI app from the LitServer
        app = server.app
        # Add CORS middleware (wide-open: all origins/methods/headers).
        # NOTE(review): allow_origins=["*"] with allow_credentials=True is
        # rejected by browsers per the CORS spec — confirm intended policy.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        # Add routes with configured prefix.
        # NOTE(review): prefix is read from 'llm_server' while the server
        # settings above come from 'server' — verify config.yaml uses both keys.
        api_prefix = config.get('llm_server', {}).get('api_prefix', '/api/v1')
        app.include_router(router, prefix=api_prefix)
        return app, api, config
    except Exception as e:
        # Boundary handler: log the failure, then propagate so startup aborts.
        logger.error(f"Application initialization failed: {str(e)}")
        raise
# Create the FastAPI app instance at import time so uvicorn can target
# "main:app". asyncio.run() replaces the deprecated
# get_event_loop().run_until_complete() pattern (DeprecationWarning since
# Python 3.10 when no event loop is running, which is the case at import).
app, api_instance, config_dict = asyncio.run(init_app())
async def run_server():
    """Run the server directly on the configured host/port (not through uvicorn)."""
    cfg = config_dict.get('server', {})
    # Build a fresh LitServer around the already-initialized API instance.
    lit_server = ls.LitServer(
        api_instance,
        timeout=cfg.get('timeout', 60),
        max_batch_size=cfg.get('max_batch_size', 1),
        track_requests=True
    )
    # Blocks until the server is shut down.
    lit_server.run(
        host=cfg.get('host', '0.0.0.0'),
        port=cfg.get('port', 8001)
    )
def main():
    """Entry point that runs the server directly (bypassing uvicorn)."""
    # run_server() blocks inside LitServer.run() until shutdown.
    asyncio.run(run_server())
if __name__ == "__main__":
    main()