"""
LLM Inference Server main application using LitServe framework.
"""
import logging
from pathlib import Path

import litserve as ls
import yaml
from fastapi.middleware.cors import CORSMiddleware

from .api import InferenceApi
from .routes import router, init_router

def setup_logging():
    """Set up basic logging configuration"""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    return logging.getLogger(__name__)

def load_config():
    """Load configuration from config.yaml"""
    config_path = Path(__file__).parent / "config.yaml"
    with open(config_path) as f:
        return yaml.safe_load(f)
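
# The keys read below imply a config.yaml shaped roughly like this sketch
# (illustrative values only, mirroring the defaults this module falls back to):
#
#   server:
#     host: "0.0.0.0"
#     port: 8001
#     timeout: 60
#     max_batch_size: 1
#   llm_server:
#     api_prefix: "/api/v1"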

# Module-level initialization, so `app` is importable by an ASGI server
logger = setup_logging()
config = load_config()
server_config = config.get('server', {})
api = InferenceApi(config)

# Create LitServer instance
server = ls.LitServer(
    api,
    timeout=server_config.get('timeout', 60),
    max_batch_size=server_config.get('max_batch_size', 1),
    track_requests=True
)

# Get the FastAPI app from LitServer
app = server.app

# Add CORS middleware (wide open for development; pin allow_origins in production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Add routes with configured prefix
api_prefix = config.get('llm_server', {}).get('api_prefix', '/api/v1')
app.include_router(router, prefix=api_prefix)
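
# With the default prefix, the endpoints defined on `router` are served under
# /api/v1/..., alongside the /predict endpoint LitServe registers itself.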

# Note: FastAPI has deprecated on_event in favor of lifespan handlers; the
# hooks are kept here because the app object is created by LitServe.
@app.on_event("startup")
async def startup_event():
    """Initialize async components on startup."""
    # Initialize the router's async components
    await init_router(api)
    # Launch the inference worker. Caveat: launch_inference_worker() and
    # stop_inference_worker() are LitServe internals whose signatures have
    # varied across releases, so pin the litserve version this targets.
    server.launch_inference_worker()

@app.on_event("shutdown")
async def shutdown_event():
    """Clean up on shutdown."""
    server.stop_inference_worker()

def run_server():
    """Run the server directly through LitServer.

    `server.run()` blocks and manages its own event loop and worker
    processes, so it must not be wrapped in `asyncio.run()` (uvicorn
    refuses to start inside an already-running event loop).
    """
    port = server_config.get('port', 8001)
    host = server_config.get('host', '0.0.0.0')
    server.run(host=host, port=port)

def main():
    """Entry point that runs the server directly."""
    run_server()

if __name__ == "__main__":
    main()
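
# Usage sketch (the module path is an assumption; adjust to your package layout):
#
#   python -m yourpackage.main                  # run via LitServer's own runner
#   uvicorn yourpackage.main:app --port 8001    # serve the wrapped FastAPI app;
#                                               # the startup hook launches the worker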