File size: 5,992 Bytes
d828ce4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, Dict, Any, Union
import torch
import logging
from pathlib import Path
from litgpt.api import LLM
import os
import uvicorn

# Configure the root logger at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ASGI application; the route decorators in this module register handlers on it.
app = FastAPI(title="LLM Engine Service")

# Process-wide handle to the loaded LLM. Starts as None, is assigned by the
# /initialize handler, and is read by the /generate and /health handlers.
llm_instance = None

class InitializeRequest(BaseModel):
    """
    Request body for POST /initialize: where to load the model from and how
    to place it on devices.

    When both ``precision`` and ``quantize`` are None the handler loads the
    model with ``distribute="auto"`` and does not consult ``mode`` or
    ``gpu_count``; otherwise all four settings are forwarded to
    ``LLM.distribute``.
    """
    # "gpu" selects the "cuda" accelerator; any other value selects "cpu".
    mode: str = "cpu"
    # Forwarded unchanged as the ``precision`` argument of ``LLM.distribute``.
    precision: Optional[str] = None
    # Forwarded unchanged as the ``quantize`` argument of ``LLM.distribute``.
    quantize: Optional[str] = None
    # Forwarded unchanged as the ``devices`` argument of ``LLM.distribute``.
    gpu_count: Union[str, int] = "auto"
    # Required. Passed as the ``model`` argument of ``LLM.load``.
    model_path: str

class GenerateRequest(BaseModel):
    """
    Request body for POST /generate.

    Every field except ``stream`` is forwarded under the same name to the
    loaded model's ``generate`` call.
    """
    # Input text handed to the model; also echoed back in the response metadata.
    prompt: str
    # Sampling settings, passed through to ``generate`` without modification.
    max_new_tokens: int = 50
    temperature: float = 1.0
    top_k: Optional[int] = None
    top_p: float = 1.0
    # When True the handler calls ``.tolist()`` on the result before returning it.
    return_as_token_ids: bool = False
    # Streaming is not implemented: the handler rejects True and always calls
    # ``generate`` with ``stream=False``.
    stream: bool = False

def _gpu_memory_summary(include_peak: bool = False) -> str:
    """Return a one-line CUDA memory snapshot for log messages.

    Args:
        include_peak: Also append the peak allocated figure reported by
            ``torch.cuda.max_memory_allocated``.

    Returns:
        Allocated/reserved (and optionally peak) memory in GB, or a fixed
        notice when CUDA is not available, so CPU-only runs do not log
        GPU figures that carry no information.
    """
    if not torch.cuda.is_available():
        return "n/a (CUDA not available)"

    bytes_per_gb = 1024 ** 3
    summary = (
        f"{torch.cuda.memory_allocated() / bytes_per_gb:.2f}GB allocated, "
        f"{torch.cuda.memory_reserved() / bytes_per_gb:.2f}GB reserved"
    )
    if include_peak:
        summary += (
            f", {torch.cuda.max_memory_allocated() / bytes_per_gb:.2f}GB max allocated"
        )
    return summary


@app.post("/initialize")
async def initialize_model(request: InitializeRequest):
    """
    Load the model at ``request.model_path`` and publish it as ``llm_instance``.

    When neither ``precision`` nor ``quantize`` is given, the model is loaded
    with ``distribute="auto"``. Otherwise it is loaded undistributed and then
    placed explicitly via ``distribute()`` using ``mode``, ``gpu_count``,
    ``precision`` and ``quantize`` from the request.

    The global ``llm_instance`` is replaced only after loading (and, where
    applicable, distribution) has fully succeeded, so a failed request never
    exposes a partially initialised model and leaves any previously loaded
    model in place.

    Returns:
        ``{"success": True, "message": "Model initialized successfully"}``.

    Raises:
        HTTPException: status 500 carrying the underlying error message when
            loading or distributing the model fails.
    """
    global llm_instance

    use_auto_settings = request.precision is None and request.quantize is None

    try:
        if use_auto_settings:
            # Let the load function handle distribution automatically.
            candidate = LLM.load(
                model=request.model_path,
                distribute="auto",
            )
        else:
            # Load without distributing, then distribute manually with the
            # requested settings.
            candidate = LLM.load(
                model=request.model_path,
                distribute=None,
            )
            candidate.distribute(
                accelerator="cuda" if request.mode == "gpu" else "cpu",
                devices=request.gpu_count,
                precision=request.precision,
                quantize=request.quantize,
            )
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error initializing model: %s", e)
        logger.error("GPU Memory Stats: %s", _gpu_memory_summary(include_peak=True))
        raise HTTPException(
            status_code=500,
            detail=f"Error initializing model: {str(e)}",
        ) from e

    # Publish only a fully initialised model.
    llm_instance = candidate

    if use_auto_settings:
        logger.info(
            "Model initialized with auto settings:\n"
            "Model Path: %s\n"
            "Current GPU Memory: %s",
            request.model_path,
            _gpu_memory_summary(),
        )
    else:
        logger.info(
            "Model initialized successfully with config:\n"
            "Mode: %s\n"
            "Precision: %s\n"
            "Quantize: %s\n"
            "GPU Count: %s\n"
            "Model Path: %s\n"
            "Current GPU Memory: %s",
            request.mode,
            request.precision,
            request.quantize,
            request.gpu_count,
            request.model_path,
            _gpu_memory_summary(),
        )

    return {"success": True, "message": "Model initialized successfully"}

@app.post("/generate")
async def generate(request: GenerateRequest):
    """
    Generate text (or token ids) from the model loaded by /initialize.

    Returns:
        A dict with ``generated_text`` (the model output, converted with
        ``.tolist()`` when ``return_as_token_ids`` is set) and ``metadata``
        echoing the prompt and sampling parameters of the request.

    Raises:
        HTTPException: status 400 when no model has been initialised or when
            ``stream`` is requested; status 500 when generation itself fails.
    """
    if llm_instance is None:
        raise HTTPException(status_code=400, detail="Model not initialized. Call /initialize first.")

    # Validate before entering the try block: HTTPException derives from
    # Exception, so raising it inside the try would be caught by the generic
    # handler below and reported to the client as a 500 instead of a 400.
    if request.stream:
        # Streaming would need FastAPI's StreamingResponse; not implemented.
        raise HTTPException(
            status_code=400,
            detail="Streaming is not currently supported through the API"
        )

    try:
        generated_text = llm_instance.generate(
            prompt=request.prompt,
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            return_as_token_ids=request.return_as_token_ids,
            stream=False  # Streaming is rejected above
        )
        if request.return_as_token_ids:
            # Convert the returned token ids to a plain list for the JSON response.
            generated_text = generated_text.tolist()
    except Exception as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("Error generating text: %s", e)
        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}") from e

    return {
        "generated_text": generated_text,
        "metadata": {
            "prompt": request.prompt,
            "max_new_tokens": request.max_new_tokens,
            "temperature": request.temperature,
            "top_k": request.top_k,
            "top_p": request.top_p
        }
    }

@app.get("/health")
async def health_check():
    """
    Report service liveness and, when a model is loaded, where it lives.

    Always returns ``status`` and ``model_loaded``; ``model_info`` (config
    name and the device of the model's first parameter) is included only
    when a model has been initialised.
    """
    engine = llm_instance

    # Nothing loaded yet: liveness information only.
    if engine is None:
        return {"status": "healthy", "model_loaded": False}

    reported_name = engine.config.name
    first_parameter = next(engine.model.parameters())

    return {
        "status": "healthy",
        "model_loaded": True,
        "model_info": {
            "model_path": reported_name,
            "device": str(first_parameter.device),
        },
    }

def main():
    """
    Run the service under uvicorn.

    The bind address comes from the LLM_ENGINE_HOST and LLM_ENGINE_PORT
    environment variables, falling back to 0.0.0.0 and 8001.
    """
    bind_host = os.environ.get("LLM_ENGINE_HOST", "0.0.0.0")
    bind_port = int(os.environ.get("LLM_ENGINE_PORT", "8001"))

    server_options = {
        "host": bind_host,
        "port": bind_port,
        "log_level": "info",
        "reload": False,
    }
    uvicorn.run(app, **server_options)

# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()