"""
LLM Inference Server main application using LitServe framework.
"""
import litserve as ls
import yaml
import logging
from pathlib import Path
from fastapi.middleware.cors import CORSMiddleware
from .routes import router, init_router
from .api import InferenceApi

# Store process list globally so it doesn't get garbage collected
_WORKER_PROCESSES = []
_MANAGER = None

def setup_logging():
    """Set up basic logging configuration"""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    return logging.getLogger(__name__)

def load_config():
    """Load configuration from config.yaml"""
    config_path = Path(__file__).parent / "config.yaml"
    with open(config_path) as f:
        return yaml.safe_load(f)
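
# The config.yaml this module reads is expected to contain at least the keys
# used below (a sketch, not the full schema -- any model-specific settings
# consumed by InferenceApi live under whatever sections that class defines):
#
#   server:
#     timeout: 60          # request timeout in seconds
#     max_batch_size: 1
#   llm_server:
#     api_prefix: /api/v1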

def create_app():
    """Create and configure the application instance."""
    global _WORKER_PROCESSES, _MANAGER

    logger = setup_logging()
    config = load_config()
    server_config = config.get('server', {})
    logger.info("Loaded configuration; server settings: %s", server_config)

    # Initialize API with config
    api = InferenceApi(config)

    # Initialize router with API instance
    init_router(api)

    # Create LitServer instance
    server = ls.LitServer(
        api,
        timeout=server_config.get('timeout', 60),
        max_batch_size=server_config.get('max_batch_size', 1),
        track_requests=True
    )

    # Launch inference workers (assuming single uvicorn worker for now)
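    # Note: launch_inference_worker() is a LitServer internal rather than a
    # documented public API, so its name and signature may differ across
    # litserve versions; pin the litserve version this code was written against.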
    _MANAGER, _WORKER_PROCESSES = server.launch_inference_worker(num_uvicorn_servers=1)

    # Get the FastAPI app
    app = server.app

    # Add CORS middleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Add routes with configured prefix
    api_prefix = config.get('llm_server', {}).get('api_prefix', '/api/v1')
    app.include_router(router, prefix=api_prefix)

    # Set the response queue ID for the app
    app.response_queue_id = 0  # Since we're using a single worker

    return app

# Create the app instance for uvicorn
app = create_app()
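
if __name__ == "__main__":
    # Convenience entry point for local runs -- a minimal sketch. In a real
    # deployment the app is typically served as `uvicorn <package>.main:app`;
    # the host/port here are illustrative defaults, not values from config.yaml.
    # Keep a single uvicorn worker: create_app() launches one inference worker
    # and pins app.response_queue_id to 0.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)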