geethareddy commited on
Commit
d25340b
·
verified ·
1 Parent(s): fdb1a6b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -16
app.py CHANGED
@@ -1,9 +1,11 @@
1
- from fastapi import FastAPI, HTTPException, BackgroundTasks
2
  from pydantic import BaseModel
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
4
  import json
5
  import logging
6
  import os
 
7
 
8
  # Set up logging to stdout only
9
  logging.basicConfig(
@@ -15,16 +17,6 @@ logging.basicConfig(
15
  )
16
  logger = logging.getLogger(__name__)
17
 
18
- # Initialize FastAPI app
19
- app = FastAPI()
20
-
21
- # Define input model for validation
22
- class CoachingInput(BaseModel):
23
- role: str
24
- project_id: str
25
- milestones: str
26
- reflection_log: str
27
-
28
  # Global variables for model and tokenizer
29
  model = None
30
  tokenizer = None
@@ -53,11 +45,24 @@ async def load_model_background():
53
  logger.error(f"Failed to load model or tokenizer: {str(e)}")
54
  model_load_status = f"failed: {str(e)}"
55
 
56
- # Startup event to initiate model loading in the background
57
- @app.on_event("startup")
58
- async def startup_event(background_tasks: BackgroundTasks):
59
- logger.debug("FastAPI application started")
60
- background_tasks.add_task(load_model_background)
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  @app.get("/")
63
  async def root():
 
1
+ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from contextlib import asynccontextmanager
5
  import json
6
  import logging
7
  import os
8
+ import asyncio
9
 
10
  # Set up logging to stdout only
11
  logging.basicConfig(
 
17
  )
18
  logger = logging.getLogger(__name__)
19
 
 
 
 
 
 
 
 
 
 
 
20
  # Global variables for model and tokenizer
21
  model = None
22
  tokenizer = None
 
45
  logger.error(f"Failed to load model or tokenizer: {str(e)}")
46
  model_load_status = f"failed: {str(e)}"
47
 
48
+ # Lifespan event handler to manage startup and shutdown
49
+ @asynccontextmanager
50
+ async def lifespan(app: FastAPI):
51
+ logger.debug("FastAPI application starting")
52
+ # Start the background task for model loading
53
+ asyncio.create_task(load_model_background())
54
+ yield
55
+ logger.debug("FastAPI application shutting down")
56
+
57
+ # Initialize FastAPI app with lifespan handler
58
+ app = FastAPI(lifespan=lifespan)
59
+
60
+ # Define input model for validation
61
+ class CoachingInput(BaseModel):
62
+ role: str
63
+ project_id: str
64
+ milestones: str
65
+ reflection_log: str
66
 
67
  @app.get("/")
68
  async def root():