geethareddy committed
Commit e7c1a90 · verified · 1 Parent(s): 84d2880

Update app.py

Files changed (1):
  app.py +44 -23
app.py CHANGED
@@ -9,6 +9,7 @@ import os
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+# Initialize FastAPI app
 app = FastAPI()
 
 # Define input model for validation
@@ -18,27 +19,52 @@ class CoachingInput(BaseModel):
     milestones: str
     reflection_log: str
 
-# Define model path (absolute path in the container)
+# Global variables for model and tokenizer
+model = None
+tokenizer = None
+model_load_status = "not_loaded"
+
+# Define model path and fallback
 model_path = "/app/fine-tuned-construction-llm"
-fallback_model = "gpt2"  # Fallback to a pre-trained model if local model is unavailable
+fallback_model = "distilgpt2"  # Smaller model for faster loading
 
-# Load model and tokenizer
-try:
-    if os.path.isdir(model_path):
-        logger.info(f"Loading local model from {model_path}")
-        model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True)
-        tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
-    else:
-        logger.warning(f"Model directory not found: {model_path}. Falling back to pre-trained model: {fallback_model}")
-        model = AutoModelForCausalLM.from_pretrained(fallback_model)
-        tokenizer = AutoTokenizer.from_pretrained(fallback_model)
-    logger.info("Model and tokenizer loaded successfully")
-except Exception as e:
-    logger.error(f"Failed to load model or tokenizer: {str(e)}")
-    raise Exception(f"Model loading failed: {str(e)}")
+# Load model and tokenizer at startup
+def load_model():
+    global model, tokenizer, model_load_status
+    try:
+        if os.path.isdir(model_path):
+            logger.info(f"Loading local model from {model_path}")
+            model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True)
+            tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
+            model_load_status = "local_model_loaded"
+        else:
+            logger.warning(f"Model directory not found: {model_path}. Falling back to pre-trained model: {fallback_model}")
+            model = AutoModelForCausalLM.from_pretrained(fallback_model)
+            tokenizer = AutoTokenizer.from_pretrained(fallback_model)
+            model_load_status = "fallback_model_loaded"
+        logger.info("Model and tokenizer loaded successfully")
+    except Exception as e:
+        logger.error(f"Failed to load model or tokenizer: {str(e)}")
+        model_load_status = f"failed: {str(e)}"
+        # Do not raise an exception; allow the app to start
+
+# Load model on startup
+load_model()
+
+@app.on_event("startup")
+async def startup_event():
+    logger.info("FastAPI application started")
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy", "model_load_status": model_load_status}
 
 @app.post("/generate_coaching")
 async def generate_coaching(data: CoachingInput):
+    if model is None or tokenizer is None:
+        logger.error("Model or tokenizer not loaded")
+        raise HTTPException(status_code=503, detail="Model not loaded. Please check server logs.")
+
     try:
         # Prepare input text
         input_text = (
@@ -62,8 +88,7 @@ async def generate_coaching(data: CoachingInput):
         # Decode and parse response
         response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-        # Since gpt2 may not output JSON, parse the response manually or use fallback
-        # This is a simplified parsing logic; adjust based on your model's output format
+        # Since distilgpt2 may not output JSON, parse the response manually or use fallback
         if not response_text.startswith("{"):
            checklist = ["Inspect safety equipment", "Review milestone progress"]
            tips = ["Prioritize team communication", "Check weather updates"]
@@ -85,8 +110,4 @@ async def generate_coaching(data: CoachingInput):
 
     except Exception as e:
         logger.error(f"Error generating coaching response: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
-
-@app.get("/health")
-async def health_check():
-    return {"status": "healthy"}
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
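
After this change, a failed model load no longer aborts startup: the app comes up, /health reports model_load_status, and /generate_coaching returns 503 until a model is available. A minimal client sketch of that behavior, assuming the server listens on localhost:7860 (a typical Spaces port, not stated in this commit) and that CoachingInput has no required fields beyond the two visible in this diff:

import requests

BASE_URL = "http://localhost:7860"  # assumed host/port, adjust for your deployment

# /health now reports how (or whether) the model was loaded
health = requests.get(f"{BASE_URL}/health").json()
print(health)  # e.g. {"status": "healthy", "model_load_status": "fallback_model_loaded"}

# /generate_coaching responds with 503 while model/tokenizer are None,
# and 200 with the generated coaching payload once loading succeeded
resp = requests.post(
    f"{BASE_URL}/generate_coaching",
    json={"milestones": "Foundation poured", "reflection_log": "Crew ahead of schedule"},  # illustrative values
)
print(resp.status_code, resp.json())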