File size: 5,866 Bytes
d828ce4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a189e20
 
 
d828ce4
a189e20
 
 
 
 
d828ce4
a189e20
 
d828ce4
a189e20
 
 
 
 
 
 
 
 
 
d828ce4
 
 
 
 
 
 
a189e20
 
 
 
 
 
 
 
 
 
d828ce4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, Dict, Any, Union
import torch
import logging
from pathlib import Path
from litgpt.api import LLM
import os
import uvicorn

# Set up logging: configure the root logger at INFO level and create the
# logger used by every handler in this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application object; the route decorators below register on it.
app = FastAPI(title="LLM Engine Service")

# Global variable to store the LLM instance. It is None until the
# /initialize handler assigns a loaded model to it.
llm_instance: Optional[LLM] = None

class InitializeRequest(BaseModel):
    """
    Request body for the /initialize endpoint.

    Configuration for model initialization including model path.
    """
    # Target hardware. The handler maps "gpu" to the "cuda" accelerator and
    # any other value to "cpu"; it is only consulted when `precision` or
    # `quantize` is set.
    mode: str = "cpu"
    # Forwarded as `precision` to the model's distribute() call. Setting this
    # (or `quantize`) makes the handler load with distribute=None and then
    # distribute manually instead of using distribute="auto".
    precision: Optional[str] = None
    # Forwarded as `quantize` to the model's distribute() call; see above.
    quantize: Optional[str] = None
    # Forwarded as `devices` to the model's distribute() call.
    gpu_count: Union[str, int] = "auto"
    # Model location relative to the "checkpoints" directory next to this
    # file, e.g. "mistralai/Mistral-7B-Instruct-v0.3" or a bare directory name.
    model_path: str

class GenerateRequest(BaseModel):
    """
    Request body for the /generate endpoint.

    Every field except `stream` is forwarded unchanged to the loaded model's
    generate() call.
    """
    # Text to continue.
    prompt: str
    # Forwarded as `max_new_tokens`.
    max_new_tokens: int = 50
    # Forwarded as `temperature`.
    temperature: float = 1.0
    # Forwarded as `top_k`; None is passed through as-is.
    top_k: Optional[int] = None
    # Forwarded as `top_p`.
    top_p: float = 1.0
    # When True, the handler returns the result of `.tolist()` on the
    # generation output instead of the output itself.
    return_as_token_ids: bool = False
    # Streaming flag. The handler rejects requests that set it and always
    # calls generate() with stream=False.
    stream: bool = False

def _resolve_checkpoint_path(model_path: str) -> str:
    """
    Translate a client-supplied model identifier into a checkpoint directory.

    The identifier is split on "/" and joined onto the "checkpoints"
    directory that sits next to this file, so "org/model" becomes
    <this dir>/checkpoints/org/model and a bare "model" becomes
    <this dir>/checkpoints/model. Any number of segments is accepted.

    Args:
        model_path: Identifier taken from the request body.

    Returns:
        The checkpoint directory as a string.

    Raises:
        HTTPException: 400 if any segment is empty or "..". This covers an
            empty identifier, leading/trailing/double slashes, and attempts
            to climb out of the checkpoints directory.
    """
    # Get the project root directory (where this file is located)
    checkpoints_dir = Path(__file__).parent / "checkpoints"

    segments = model_path.split("/")
    if any(segment in ("", "..") for segment in segments):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid model path: {model_path!r}"
        )

    return str(checkpoints_dir.joinpath(*segments))


@app.post("/initialize")
async def initialize_model(request: InitializeRequest):
    """
    Initialize the LLM model with specified configuration.

    Without `precision`/`quantize` the model is loaded with
    distribute="auto". With either of them set it is loaded with
    distribute=None and then distributed manually using `mode`, `gpu_count`,
    `precision` and `quantize` from the request.

    The module-level `llm_instance` is only replaced once loading (and, if
    requested, distribution) has succeeded. If manual distribution fails,
    `llm_instance` is left as None so /generate and /health never see a
    partially initialized model.

    Returns:
        {"success": True, "message": "Model initialized successfully"}

    Raises:
        HTTPException: 400 for an invalid `model_path`; 500 if loading or
            distributing the model fails.
    """
    global llm_instance

    # Validate before the try block so the 400 is not re-wrapped as a 500 by
    # the broad handler below.
    model_path = _resolve_checkpoint_path(request.model_path)
    logger.info(f"Using model path: {model_path}")

    manual_distribution = bool(request.precision or request.quantize)

    try:
        # Load into a local first; the global is published only on success.
        llm = LLM.load(
            model=model_path,
            distribute=None if manual_distribution else "auto"
        )

        # If manual distribution is needed
        if manual_distribution:
            # Drop any previous instance before distribute() runs, so the old
            # and new models are not both referenced during it. Presumably
            # this is where the weights get placed on the device - confirm
            # against the litgpt API docs.
            llm_instance = None
            llm.distribute(
                accelerator="cuda" if request.mode == "gpu" else "cpu",
                devices=request.gpu_count,
                precision=request.precision,
                quantize=request.quantize
            )

        logger.info(
            f"Model initialized successfully with config:\n"
            f"Mode: {request.mode}\n"
            f"Precision: {request.precision}\n"
            f"Quantize: {request.quantize}\n"
            f"GPU Count: {request.gpu_count}\n"
            f"Model Path: {model_path}\n"
            f"Current GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, "
            f"{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved"
        )

    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Error initializing model: {str(e)}")
        # Print detailed memory statistics on failure
        logger.error(f"GPU Memory Stats:\n"
                     f"Allocated: {torch.cuda.memory_allocated()/1024**3:.2f}GB\n"
                     f"Reserved: {torch.cuda.memory_reserved()/1024**3:.2f}GB\n"
                     f"Max Allocated: {torch.cuda.max_memory_allocated()/1024**3:.2f}GB")
        raise HTTPException(status_code=500, detail=f"Error initializing model: {str(e)}") from e

    llm_instance = llm
    return {"success": True, "message": "Model initialized successfully"}

@app.post("/generate")
async def generate(request: GenerateRequest):
    """
    Generate text using the initialized model.

    Returns:
        A dict with "generated_text" (the model output, or its `.tolist()`
        when `return_as_token_ids` is set) and "metadata" echoing the prompt
        and sampling parameters of the request.

    Raises:
        HTTPException: 400 if no model has been initialized or if
            `stream` is requested; 500 if generation itself fails.
    """
    if llm_instance is None:
        raise HTTPException(status_code=400, detail="Model not initialized. Call /initialize first.")

    # Reject streaming before entering the try block. Raised inside it, this
    # 400 would be caught by the broad `except Exception` below and reach the
    # client as a 500. A real streaming implementation would need
    # StreamingResponse from FastAPI.
    if request.stream:
        raise HTTPException(
            status_code=400,
            detail="Streaming is not currently supported through the API"
        )

    try:
        # NOTE(review): this call is not awaited, so it runs synchronously
        # inside the coroutine for the whole generation.
        generated_text = llm_instance.generate(
            prompt=request.prompt,
            max_new_tokens=request.max_new_tokens,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            return_as_token_ids=request.return_as_token_ids,
            stream=False  # Force stream to False for now
        )

        response = {
            "generated_text": generated_text if not request.return_as_token_ids else generated_text.tolist(),
            "metadata": {
                "prompt": request.prompt,
                "max_new_tokens": request.max_new_tokens,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p
            }
        }

    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Error generating text: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}") from e

    return response

@app.get("/health")
async def health_check():
    """
    Check if the service is running and model is loaded.

    Returns:
        A dict with "status" (always "healthy") and "model_loaded". When a
        model is loaded it also carries "model_info" with "model_path" (the
        model config's name) and "device" (the device of the model's first
        parameter, or "unknown" if that cannot be determined).
    """
    status = {
        "status": "healthy",
        "model_loaded": llm_instance is not None,
    }

    if llm_instance is not None:
        # A liveness probe should not fail just because the device cannot be
        # read: tolerate a missing/None `model` attribute and a model without
        # parameters instead of letting AttributeError or StopIteration
        # escape.
        model = getattr(llm_instance, "model", None)
        first_param = next(model.parameters(), None) if model is not None else None

        status["model_info"] = {
            "model_path": llm_instance.config.name,
            "device": str(first_param.device) if first_param is not None else "unknown"
        }

    return status

def main():
    """
    Run the FastAPI app under uvicorn.

    The bind address comes from the LLM_ENGINE_HOST environment variable
    (default "0.0.0.0") and the port from LLM_ENGINE_PORT (default "8001",
    converted to int).
    """
    env = os.environ
    bind_host = env.get("LLM_ENGINE_HOST", "0.0.0.0")
    bind_port = int(env.get("LLM_ENGINE_PORT", "8001"))

    # Collect the server options in one place, then start serving.
    server_options = {
        "host": bind_host,
        "port": bind_port,
        "log_level": "info",
        "reload": False,
    }
    uvicorn.run(app, **server_options)


if __name__ == "__main__":
    main()