AurelioAguirre committed on
Commit
28fa644
·
1 Parent(s): eda2ff2

Small refactor, moving the endpoints into routes.py. Also added a streaming endpoint and a from_pretrained initialization endpoint.

Files changed (2)
  1. main/main.py +4 -166
  2. main/routes.py +365 -0
main/main.py CHANGED
@@ -1,13 +1,9 @@
- from fastapi import FastAPI, HTTPException
+ from fastapi import FastAPI
  from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel
- from typing import Optional, Union
- import torch
  import logging
- from pathlib import Path
- from litgpt.api import LLM
  import os
  import uvicorn
+ from routes import router

  # Set up logging
  logging.basicConfig(level=logging.INFO)
@@ -30,166 +26,8 @@ app.add_middleware(
      allow_headers=["*"],
  )

- # Global variable to store the LLM instance
- llm_instance = None
-
- class InitializeRequest(BaseModel):
-     """
-     Configuration for model initialization including model path
-     """
-     mode: str = "cpu"
-     precision: Optional[str] = None
-     quantize: Optional[str] = None
-     gpu_count: Union[str, int] = "auto"
-     model_path: str
-
- class GenerateRequest(BaseModel):
-     prompt: str
-     max_new_tokens: int = 50
-     temperature: float = 1.0
-     top_k: Optional[int] = None
-     top_p: float = 1.0
-     return_as_token_ids: bool = False
-     stream: bool = False
-
- @app.get("/")
- async def root():
-     """Root endpoint to verify service is running"""
-     return {
-         "status": "running",
-         "service": "LLM Engine",
-         "endpoints": {
-             "initialize": "/initialize",
-             "generate": "/generate",
-             "health": "/health"
-         }
-     }
-
- @app.post("/initialize")
- async def initialize_model(request: InitializeRequest):
-     """
-     Initialize the LLM model with specified configuration.
-     """
-     global llm_instance
-
-     try:
-         # Get the project root directory (where main.py is located)
-         project_root = Path(__file__).parent
-         checkpoints_dir = project_root / "checkpoints"
-         logger.info(f"Checkpoint dir is: {checkpoints_dir}")
-
-         # For LitGPT downloaded models, path includes organization
-         if "/" in request.model_path:
-             # e.g., "mistralai/Mistral-7B-Instruct-v0.3"
-             org, model_name = request.model_path.split("/")
-             model_path = str(checkpoints_dir / org / model_name)
-         else:
-             # Fallback for direct model paths
-             model_path = str(checkpoints_dir / request.model_path)
-
-         logger.info(f"Using model path: {model_path}")
-
-         # Load the model
-         llm_instance = LLM.load(
-             model=model_path,
-             distribute=None if request.precision or request.quantize else "auto"
-         )
-
-         # If manual distribution is needed
-         if request.precision or request.quantize:
-             llm_instance.distribute(
-                 accelerator="cuda" if request.mode == "gpu" else "cpu",
-                 devices=request.gpu_count,
-                 precision=request.precision,
-                 quantize=request.quantize
-             )
-
-         logger.info(
-             f"Model initialized successfully with config:\n"
-             f"Mode: {request.mode}\n"
-             f"Precision: {request.precision}\n"
-             f"Quantize: {request.quantize}\n"
-             f"GPU Count: {request.gpu_count}\n"
-             f"Model Path: {model_path}\n"
-             f"Current GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, "
-             f"{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved"
-         )
-
-         return {"success": True, "message": "Model initialized successfully"}
-
-     except Exception as e:
-         logger.error(f"Error initializing model: {str(e)}")
-         # Print detailed memory statistics on failure
-         logger.error(f"GPU Memory Stats:\n"
-                      f"Allocated: {torch.cuda.memory_allocated()/1024**3:.2f}GB\n"
-                      f"Reserved: {torch.cuda.memory_reserved()/1024**3:.2f}GB\n"
-                      f"Max Allocated: {torch.cuda.max_memory_allocated()/1024**3:.2f}GB")
-         raise HTTPException(status_code=500, detail=f"Error initializing model: {str(e)}")
-
- @app.post("/generate")
- async def generate(request: GenerateRequest):
-     """
-     Generate text using the initialized model.
-     """
-     global llm_instance
-
-     if llm_instance is None:
-         raise HTTPException(status_code=400, detail="Model not initialized. Call /initialize first.")
-
-     try:
-         if request.stream:
-             raise HTTPException(
-                 status_code=400,
-                 detail="Streaming is not currently supported through the API"
-             )
-
-         generated_text = llm_instance.generate(
-             prompt=request.prompt,
-             max_new_tokens=request.max_new_tokens,
-             temperature=request.temperature,
-             top_k=request.top_k,
-             top_p=request.top_p,
-             return_as_token_ids=request.return_as_token_ids,
-             stream=False # Force stream to False for now
-         )
-
-         response = {
-             "generated_text": generated_text if not request.return_as_token_ids else generated_text.tolist(),
-             "metadata": {
-                 "prompt": request.prompt,
-                 "max_new_tokens": request.max_new_tokens,
-                 "temperature": request.temperature,
-                 "top_k": request.top_k,
-                 "top_p": request.top_p
-             }
-         }
-
-         return response
-
-     except Exception as e:
-         logger.error(f"Error generating text: {str(e)}")
-         raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
-
- @app.get("/health")
- async def health_check():
-     """
-     Check if the service is running and model is loaded.
-     """
-     global llm_instance
-
-     status = {
-         "status": "healthy",
-         "model_loaded": llm_instance is not None,
-     }
-
-     if llm_instance is not None:
-         logger.info(f"llm_instance is: {llm_instance}")
-         status["model_info"] = {
-             "model_path": llm_instance.config.name,
-             "device": str(next(llm_instance.model.parameters()).device)
-         }
-
-     return status
+ # Include the router from routes.py
+ app.include_router(router)

  def main():
      # Load environment variables or configuration here
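Taken together, the changes above reduce main.py to a thin entry point that wires the router into the app. The sketch below shows roughly what the slimmed-down file looks like after this commit; the FastAPI() construction, the CORS settings other than allow_headers, and the uvicorn host/port in main() are assumptions inferred from the surrounding context lines, not part of the visible diff.

# Sketch of main.py after this commit (assumed details are marked in comments).
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import logging
import os
import uvicorn
from routes import router

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()  # assumed; only the middleware context lines are visible in the diff
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],      # assumed
    allow_credentials=True,   # assumed
    allow_methods=["*"],      # assumed
    allow_headers=["*"],
)

# Include the router from routes.py
app.include_router(router)

def main():
    # Load environment variables or configuration here
    port = int(os.environ.get("PORT", 8000))  # assumed default
    uvicorn.run(app, host="0.0.0.0", port=port)  # assumed host/port

if __name__ == "__main__":
    main()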
main/routes.py ADDED
@@ -0,0 +1,365 @@
+ from fastapi import APIRouter, HTTPException
+ from fastapi.responses import StreamingResponse
+ from pydantic import BaseModel
+ from typing import Optional, Union, AsyncGenerator
+ import torch
+ import logging
+ from pathlib import Path
+ from litgpt.api import LLM
+ import json
+ import asyncio
+
+ # Set up logging
+ logger = logging.getLogger(__name__)
+
+ # Create router instance
+ router = APIRouter()
+
+ # Global variable to store the LLM instance
+ llm_instance = None
+
+ class InitializeRequest(BaseModel):
+     """
+     Configuration for model initialization including model path
+     """
+     mode: str = "cpu"
+     precision: Optional[str] = None
+     quantize: Optional[str] = None
+     gpu_count: Union[str, int] = "auto"
+     model_path: str
+
+ class GenerateRequest(BaseModel):
+     prompt: str
+     max_new_tokens: int = 50
+     temperature: float = 1.0
+     top_k: Optional[int] = None
+     top_p: float = 1.0
+     return_as_token_ids: bool = False
+     stream: bool = False
+
+ # A Pydantic model for the streaming generation request
+ class StreamGenerateRequest(BaseModel):
+     prompt: str
+     max_new_tokens: int = 50
+     temperature: float = 1.0
+     top_k: Optional[int] = None
+     top_p: float = 1.0
+
+ class InitializeCustomRequest(BaseModel):
+     """
+     Configuration for custom model initialization using from_pretrained
+     """
+     mode: str = "cpu"
+     precision: Optional[str] = None
+     quantize: Optional[str] = None
+     gpu_count: Union[str, int] = "auto"
+     folder_path: str # Path to the model folder relative to checkpoints
+     model_filename: str # Name of the model file (e.g., "lit_model.pth")
+     config_filename: str = "config.json" # Default config filename
+     tokenizer_filename: Optional[str] = "tokenizer.json" # Optional tokenizer filename
+
+
+ @router.post("/initialize/custom")
+ async def initialize_custom_model(request: InitializeCustomRequest):
+     """
+     Initialize a custom model using from_pretrained method.
+     This is for models that are already downloaded and stored in the checkpoints directory.
+     """
+     global llm_instance
+
+     try:
+         # Get the project root directory and construct paths
+         project_root = Path(__file__).parent
+         checkpoints_dir = project_root / "checkpoints"
+         model_dir = checkpoints_dir / request.folder_path
+
+         logger.info(f"Loading custom model from directory: {model_dir}")
+
+         # Verify that all required files exist
+         model_path = model_dir / request.model_filename
+         config_path = model_dir / request.config_filename
+
+         if not model_path.exists():
+             raise HTTPException(
+                 status_code=400,
+                 detail=f"Model file not found: {request.model_filename}"
+             )
+
+         if not config_path.exists():
+             raise HTTPException(
+                 status_code=400,
+                 detail=f"Config file not found: {request.config_filename}"
+             )
+
+         # Check for tokenizer if specified
+         tokenizer_path = None
+         if request.tokenizer_filename:
+             tokenizer_path = model_dir / request.tokenizer_filename
+             if not tokenizer_path.exists():
+                 raise HTTPException(
+                     status_code=400,
+                     detail=f"Tokenizer file not found: {request.tokenizer_filename}"
+                 )
+
+         # Load the model using from_pretrained
+         llm_instance = LLM.from_pretrained(
+             path=str(model_dir),
+             model_file=request.model_filename,
+             config_file=request.config_filename,
+             tokenizer_file=request.tokenizer_filename if request.tokenizer_filename else None,
+             distribute=None if request.precision or request.quantize else "auto"
+         )
+
+         # If manual distribution is needed
+         if request.precision or request.quantize:
+             llm_instance.distribute(
+                 accelerator="cuda" if request.mode == "gpu" else "cpu",
+                 devices=request.gpu_count,
+                 precision=request.precision,
+                 quantize=request.quantize
+             )
+
+         # Log success and memory stats
+         logger.info(
+             f"Custom model initialized successfully with config:\n"
+             f"Mode: {request.mode}\n"
+             f"Precision: {request.precision}\n"
+             f"Quantize: {request.quantize}\n"
+             f"GPU Count: {request.gpu_count}\n"
+             f"Model Directory: {model_dir}\n"
+             f"Model File: {request.model_filename}\n"
+             f"Config File: {request.config_filename}\n"
+             f"Tokenizer File: {request.tokenizer_filename}\n"
+             f"Current GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, "
+             f"{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved"
+         )
+
+         return {
+             "success": True,
+             "message": "Custom model initialized successfully",
+             "model_info": {
+                 "folder": str(model_dir),
+                 "model_file": request.model_filename,
+                 "config_file": request.config_filename,
+                 "tokenizer_file": request.tokenizer_filename
+             }
+         }
+
+     except Exception as e:
+         logger.error(f"Error initializing custom model: {str(e)}")
+         # Print detailed memory statistics on failure
+         logger.error(f"GPU Memory Stats:\n"
+                      f"Allocated: {torch.cuda.memory_allocated()/1024**3:.2f}GB\n"
+                      f"Reserved: {torch.cuda.memory_reserved()/1024**3:.2f}GB\n"
+                      f"Max Allocated: {torch.cuda.max_memory_allocated()/1024**3:.2f}GB")
+         raise HTTPException(status_code=500, detail=f"Error initializing custom model: {str(e)}")
+
+
+ # Endpoint for streaming generation
+ @router.post("/generate/stream")
+ async def generate_stream(request: StreamGenerateRequest):
+     """
+     Generate text using the initialized model with streaming response.
+     Returns a StreamingResponse that yields JSON-formatted chunks of text.
+     """
+     global llm_instance
+
+     if llm_instance is None:
+         raise HTTPException(
+             status_code=400,
+             detail="Model not initialized. Call /initialize first."
+         )
+
+     async def event_generator() -> AsyncGenerator[str, None]:
+         try:
+             # Start the generation with streaming enabled
+             async for token in llm_instance.generate(
+                 prompt=request.prompt,
+                 max_new_tokens=request.max_new_tokens,
+                 temperature=request.temperature,
+                 top_k=request.top_k,
+                 top_p=request.top_p,
+                 stream=True # Enable streaming
+             ):
+                 # Create a JSON response for each token
+                 chunk = {
+                     "token": token,
+                     "metadata": {
+                         "prompt": request.prompt,
+                         "is_finished": False
+                     }
+                 }
+                 # Format as SSE data
+                 yield f"data: {json.dumps(chunk)}\n\n"
+
+                 # Small delay to prevent overwhelming the client
+                 await asyncio.sleep(0.01)
+
+             # Send final message indicating completion
+             final_chunk = {
+                 "token": "",
+                 "metadata": {
+                     "prompt": request.prompt,
+                     "is_finished": True
+                 }
+             }
+             yield f"data: {json.dumps(final_chunk)}\n\n"
+
+         except Exception as e:
+             logger.error(f"Error in stream generation: {str(e)}")
+             error_chunk = {
+                 "error": str(e),
+                 "metadata": {
+                     "prompt": request.prompt,
+                     "is_finished": True
+                 }
+             }
+             yield f"data: {json.dumps(error_chunk)}\n\n"
+
+     return StreamingResponse(
+         event_generator(),
+         media_type="text/event-stream",
+         headers={
+             'Cache-Control': 'no-cache',
+             'Connection': 'keep-alive',
+         }
+     )
+
+ @router.get("/")
+ async def root():
+     """Root endpoint to verify service is running"""
+     return {
+         "status": "running",
+         "service": "LLM Engine",
+         "endpoints": {
+             "initialize": "/initialize",
+             "generate": "/generate",
+             "health": "/health"
+         }
+     }
+
+ @router.post("/initialize")
+ async def initialize_model(request: InitializeRequest):
+     """
+     Initialize the LLM model with specified configuration.
+     """
+     global llm_instance
+
+     try:
+         # Get the project root directory (where main.py is located)
+         project_root = Path(__file__).parent
+         checkpoints_dir = project_root / "checkpoints"
+         logger.info(f"Checkpoint dir is: {checkpoints_dir}")
+
+         # For LitGPT downloaded models, path includes organization
+         if "/" in request.model_path:
+             # e.g., "mistralai/Mistral-7B-Instruct-v0.3"
+             org, model_name = request.model_path.split("/")
+             model_path = str(checkpoints_dir / org / model_name)
+         else:
+             # Fallback for direct model paths
+             model_path = str(checkpoints_dir / request.model_path)
+
+         logger.info(f"Using model path: {model_path}")
+
+         # Load the model
+         llm_instance = LLM.load(
+             model=model_path,
+             distribute=None if request.precision or request.quantize else "auto"
+         )
+
+         # If manual distribution is needed
+         if request.precision or request.quantize:
+             llm_instance.distribute(
+                 accelerator="cuda" if request.mode == "gpu" else "cpu",
+                 devices=request.gpu_count,
+                 precision=request.precision,
+                 quantize=request.quantize
+             )
+
+         logger.info(
+             f"Model initialized successfully with config:\n"
+             f"Mode: {request.mode}\n"
+             f"Precision: {request.precision}\n"
+             f"Quantize: {request.quantize}\n"
+             f"GPU Count: {request.gpu_count}\n"
+             f"Model Path: {model_path}\n"
+             f"Current GPU Memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, "
+             f"{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved"
+         )
+
+         return {"success": True, "message": "Model initialized successfully"}
+
+     except Exception as e:
+         logger.error(f"Error initializing model: {str(e)}")
+         # Print detailed memory statistics on failure
+         logger.error(f"GPU Memory Stats:\n"
+                      f"Allocated: {torch.cuda.memory_allocated()/1024**3:.2f}GB\n"
+                      f"Reserved: {torch.cuda.memory_reserved()/1024**3:.2f}GB\n"
+                      f"Max Allocated: {torch.cuda.max_memory_allocated()/1024**3:.2f}GB")
+         raise HTTPException(status_code=500, detail=f"Error initializing model: {str(e)}")
+
+ @router.post("/generate")
+ async def generate(request: GenerateRequest):
+     """
+     Generate text using the initialized model.
+     """
+     global llm_instance
+
+     if llm_instance is None:
+         raise HTTPException(status_code=400, detail="Model not initialized. Call /initialize first.")
+
+     try:
+         if request.stream:
+             raise HTTPException(
+                 status_code=400,
+                 detail="Streaming is not currently supported through the API"
+             )
+
+         generated_text = llm_instance.generate(
+             prompt=request.prompt,
+             max_new_tokens=request.max_new_tokens,
+             temperature=request.temperature,
+             top_k=request.top_k,
+             top_p=request.top_p,
+             return_as_token_ids=request.return_as_token_ids,
+             stream=False # Force stream to False for now
+         )
+
+         response = {
+             "generated_text": generated_text if not request.return_as_token_ids else generated_text.tolist(),
+             "metadata": {
+                 "prompt": request.prompt,
+                 "max_new_tokens": request.max_new_tokens,
+                 "temperature": request.temperature,
+                 "top_k": request.top_k,
+                 "top_p": request.top_p
+             }
+         }
+
+         return response
+
+     except Exception as e:
+         logger.error(f"Error generating text: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
+
+ @router.get("/health")
+ async def health_check():
+     """
+     Check if the service is running and model is loaded.
+     """
+     global llm_instance
+
+     status = {
+         "status": "healthy",
+         "model_loaded": llm_instance is not None,
+     }
+
+     if llm_instance is not None:
+         logger.info(f"llm_instance is: {llm_instance}")
+         status["model_info"] = {
+             "model_path": llm_instance.config.name,
+             "device": str(next(llm_instance.model.parameters()).device)
+         }
+
+     return status
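For reference, here is a minimal client sketch for consuming the new /generate/stream endpoint added in routes.py. It is not part of this commit; the base URL http://localhost:8000 and the use of the requests library are assumptions, and a model must already have been loaded via POST /initialize.

# Hypothetical client sketch for the SSE streaming endpoint (not part of this commit).
import json
import requests  # assumed dependency

def stream_generate(prompt: str, base_url: str = "http://localhost:8000"):  # assumed base URL
    payload = {"prompt": prompt, "max_new_tokens": 50, "temperature": 1.0}
    with requests.post(f"{base_url}/generate/stream", json=payload, stream=True) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines(decode_unicode=True):
            # Each SSE event arrives as a line of the form "data: {...json...}"
            if not line or not line.startswith("data: "):
                continue
            chunk = json.loads(line[len("data: "):])
            if chunk.get("metadata", {}).get("is_finished"):
                break
            print(chunk.get("token", ""), end="", flush=True)

if __name__ == "__main__":
    stream_generate("Tell me a short story.")

Each data: line carries one token plus metadata, and the endpoint marks the final chunk with "is_finished": true, so the client stops reading once it sees that flag.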