from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional, Dict, Any, Union
import torch
import logging
from pathlib import Path
from litgpt.api import LLM
import os
import uvicorn

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="LLM Engine Service")

# Global variable to store the LLM instance
llm_instance = None

class InitializeRequest(BaseModel):
    """
    Configuration for model initialization including model path
    """
    mode: str = "cpu"
    precision: Optional[str] = None
    quantize: Optional[str] = None
    gpu_count: Union[str, int] = "auto"
    model_path: str
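
# Illustrative /initialize request body (values are examples only; the model must
# already exist under ./checkpoints, e.g. downloaded with `litgpt download`):
# {
#     "mode": "gpu",
#     "precision": "bf16-true",
#     "quantize": "bnb.nf4",
#     "gpu_count": "auto",
#     "model_path": "mistralai/Mistral-7B-Instruct-v0.3"
# }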

class GenerateRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 50
    temperature: float = 1.0
    top_k: Optional[int] = None
    top_p: float = 1.0
    return_as_token_ids: bool = False
    stream: bool = False
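
# Illustrative /generate request body (only "prompt" is required; the other
# fields are optional and shown here with example values):
# {
#     "prompt": "What is the capital of France?",
#     "max_new_tokens": 64,
#     "temperature": 0.7,
#     "top_k": 50,
#     "top_p": 0.9
# }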
@app.post("/initialize")
async def initialize_model(request: InitializeRequest):
"""
Initialize the LLM model with specified configuration.
"""
global llm_instance
try:
# Get the project root directory (where main.py is located)
project_root = Path(__file__).parent
checkpoints_dir = project_root / "checkpoints"
# For LitGPT downloaded models, path includes organization
if "/" in request.model_path:
# e.g., "mistralai/Mistral-7B-Instruct-v0.3"
org, model_name = request.model_path.split("/")
model_path = str(checkpoints_dir / org / model_name)
else:
# Fallback for direct model paths
model_path = str(checkpoints_dir / request.model_path)
logger.info(f"Using model path: {model_path}")
# Load the model
llm_instance = LLM.load(
model=model_path,
distribute=None if request.precision or request.quantize else "auto"
)
# If manual distribution is needed
if request.precision or request.quantize:
llm_instance.distribute(
accelerator="cuda" if request.mode == "gpu" else "cpu",
devices=request.gpu_count,
precision=request.precision,
quantize=request.quantize
)
logger.info(
f"Model initialized successfully with config:\n"
f"Mode: {request.mode}\n"
f"Precision: {request.precision}\n"
f"Quantize: {request.quantize}\n"
f"GPU Count: {request.gpu_count}\n"
f"Model Path: {model_path}\n"
f"Current GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, "
f"{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved"
)
return {"success": True, "message": "Model initialized successfully"}
except Exception as e:
logger.error(f"Error initializing model: {str(e)}")
# Print detailed memory statistics on failure
logger.error(f"GPU Memory Stats:\n"
f"Allocated: {torch.cuda.memory_allocated()/1024**3:.2f}GB\n"
f"Reserved: {torch.cuda.memory_reserved()/1024**3:.2f}GB\n"
f"Max Allocated: {torch.cuda.max_memory_allocated()/1024**3:.2f}GB")
raise HTTPException(status_code=500, detail=f"Error initializing model: {str(e)}")
@app.post("/generate")
async def generate(request: GenerateRequest):
"""
Generate text using the initialized model.
"""
global llm_instance
if llm_instance is None:
raise HTTPException(status_code=400, detail="Model not initialized. Call /initialize first.")
try:
if request.stream:
# For streaming responses, we need to handle differently
# This is a placeholder as the actual streaming implementation
# would need to use StreamingResponse from FastAPI
raise HTTPException(
status_code=400,
detail="Streaming is not currently supported through the API"
)
generated_text = llm_instance.generate(
prompt=request.prompt,
max_new_tokens=request.max_new_tokens,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
return_as_token_ids=request.return_as_token_ids,
stream=False # Force stream to False for now
)
response = {
"generated_text": generated_text if not request.return_as_token_ids else generated_text.tolist(),
"metadata": {
"prompt": request.prompt,
"max_new_tokens": request.max_new_tokens,
"temperature": request.temperature,
"top_k": request.top_k,
"top_p": request.top_p
}
}
return response
except Exception as e:
logger.error(f"Error generating text: {str(e)}")
raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
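
# Example usage (assumed local deployment on the default port; the model name is
# illustrative and must match a checkpoint available under ./checkpoints):
#   curl -X POST http://localhost:8001/initialize \
#        -H "Content-Type: application/json" \
#        -d '{"mode": "cpu", "model_path": "mistralai/Mistral-7B-Instruct-v0.3"}'
#   curl -X POST http://localhost:8001/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello, world!", "max_new_tokens": 32}'
# Note that /generate returns 400 until /initialize has succeeded.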
@app.get("/health")
async def health_check():
"""
Check if the service is running and model is loaded.
"""
global llm_instance
status = {
"status": "healthy",
"model_loaded": llm_instance is not None,
}
if llm_instance is not None:
status["model_info"] = {
"model_path": llm_instance.config.name,
"device": str(next(llm_instance.model.parameters()).device)
}
return status
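
# Example /health response once a model is loaded (field values are illustrative):
# {
#     "status": "healthy",
#     "model_loaded": true,
#     "model_info": {
#         "model_path": "Mistral-7B-Instruct-v0.3",
#         "device": "cuda:0"
#     }
# }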

def main():
    # Read host/port configuration from environment variables
    host = os.getenv("LLM_ENGINE_HOST", "0.0.0.0")
    port = int(os.getenv("LLM_ENGINE_PORT", "8001"))

    # Start the server
    uvicorn.run(
        app,
        host=host,
        port=port,
        log_level="info",
        reload=False
    )


if __name__ == "__main__":
    main()
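
# To run the service (defaults shown; both environment variables are optional):
#   LLM_ENGINE_HOST=0.0.0.0 LLM_ENGINE_PORT=8001 python main.py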