geethareddy commited on
Commit
875d1bc
·
verified ·
1 Parent(s): d25340b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -38
app.py CHANGED
@@ -1,11 +1,10 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
- from contextlib import asynccontextmanager
5
  import json
6
  import logging
7
  import os
8
- import asyncio
 
9
 
10
  # Set up logging to stdout only
11
  logging.basicConfig(
@@ -17,6 +16,9 @@ logging.basicConfig(
17
  )
18
  logger = logging.getLogger(__name__)
19
 
 
 
 
20
  # Global variables for model and tokenizer
21
  model = None
22
  tokenizer = None
@@ -26,8 +28,8 @@ model_load_status = "not_loaded"
26
  model_path = "/app/fine-tuned-construction-llm"
27
  fallback_model = "distilgpt2"
28
 
29
- # Asynchronous function to load model in the background
30
- async def load_model_background():
31
  global model, tokenizer, model_load_status
32
  try:
33
  if os.path.isdir(model_path):
@@ -45,41 +47,41 @@ async def load_model_background():
45
  logger.error(f"Failed to load model or tokenizer: {str(e)}")
46
  model_load_status = f"failed: {str(e)}"
47
 
48
- # Lifespan event handler to manage startup and shutdown
49
- @asynccontextmanager
50
- async def lifespan(app: FastAPI):
51
- logger.debug("FastAPI application starting")
52
- # Start the background task for model loading
53
- asyncio.create_task(load_model_background())
54
- yield
55
- logger.debug("FastAPI application shutting down")
56
-
57
- # Initialize FastAPI app with lifespan handler
58
- app = FastAPI(lifespan=lifespan)
59
-
60
- # Define input model for validation
61
- class CoachingInput(BaseModel):
62
- role: str
63
- project_id: str
64
- milestones: str
65
- reflection_log: str
66
 
67
- @app.get("/")
68
- async def root():
69
  logger.debug("Root endpoint accessed")
70
- return {"message": "Supervisor AI Coach is running"}
71
 
72
- @app.get("/health")
73
- async def health_check():
74
  logger.debug("Health endpoint accessed")
75
- return {
76
  "status": "healthy" if model_load_status in ["local_model_loaded", "fallback_model_loaded"] else "starting",
77
  "model_load_status": model_load_status
78
- }
79
 
80
- @app.post("/generate_coaching")
81
- async def generate_coaching(data: CoachingInput):
82
  logger.debug("Generate coaching endpoint accessed")
 
 
 
 
 
 
 
 
 
 
 
 
83
  if model is None or tokenizer is None:
84
  logger.warning("Model or tokenizer not loaded")
85
  # Return a static response if the model isn't loaded yet
@@ -88,13 +90,13 @@ async def generate_coaching(data: CoachingInput):
88
  "tips": ["Prioritize team communication", "Check weather updates"],
89
  "quote": "Every step forward counts!"
90
  }
91
- return response_json
92
 
93
  try:
94
  # Prepare input text
95
  input_text = (
96
- f"Role: {data.role}, Project: {data.project_id}, "
97
- f"Milestones: {data.milestones}, Reflection: {data.reflection_log}"
98
  )
99
 
100
  # Tokenize input
@@ -131,8 +133,16 @@ async def generate_coaching(data: CoachingInput):
131
  }
132
  logger.warning("Failed to parse model output as JSON, using default response")
133
 
134
- return response_json
135
 
136
  except Exception as e:
137
  logger.error(f"Error generating coaching response: {str(e)}")
138
- raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
3
  import json
4
  import logging
5
  import os
6
+ import threading
7
+ import time
8
 
9
  # Set up logging to stdout only
10
  logging.basicConfig(
 
16
  )
17
  logger = logging.getLogger(__name__)
18
 
19
+ # Initialize Flask app
20
+ app = Flask(__name__)
21
+
22
  # Global variables for model and tokenizer
23
  model = None
24
  tokenizer = None
 
28
  model_path = "/app/fine-tuned-construction-llm"
29
  fallback_model = "distilgpt2"
30
 
31
+ # Function to load model in the background
32
+ def load_model_background():
33
  global model, tokenizer, model_load_status
34
  try:
35
  if os.path.isdir(model_path):
 
47
  logger.error(f"Failed to load model or tokenizer: {str(e)}")
48
  model_load_status = f"failed: {str(e)}"
49
 
50
+ # Start model loading in a background thread
51
+ def start_background_tasks():
52
+ logger.debug("Starting background tasks")
53
+ thread = threading.Thread(target=load_model_background)
54
+ thread.daemon = True
55
+ thread.start()
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
+ @app.route("/")
58
+ def root():
59
  logger.debug("Root endpoint accessed")
60
+ return jsonify({"message": "Supervisor AI Coach is running"})
61
 
62
+ @app.route("/health")
63
+ def health_check():
64
  logger.debug("Health endpoint accessed")
65
+ return jsonify({
66
  "status": "healthy" if model_load_status in ["local_model_loaded", "fallback_model_loaded"] else "starting",
67
  "model_load_status": model_load_status
68
+ })
69
 
70
+ @app.route("/generate_coaching", methods=["POST"])
71
+ def generate_coaching():
72
  logger.debug("Generate coaching endpoint accessed")
73
+ # Manual validation of request data (replacing Pydantic)
74
+ data = request.get_json()
75
+ if not data:
76
+ logger.error("Invalid request: No JSON data provided")
77
+ return jsonify({"error": "Invalid request: JSON data required"}), 400
78
+
79
+ required_fields = ["role", "project_id", "milestones", "reflection_log"]
80
+ missing_fields = [field for field in required_fields if field not in data]
81
+ if missing_fields:
82
+ logger.error(f"Missing required fields: {missing_fields}")
83
+ return jsonify({"error": f"Missing required fields: {missing_fields}"}), 400
84
+
85
  if model is None or tokenizer is None:
86
  logger.warning("Model or tokenizer not loaded")
87
  # Return a static response if the model isn't loaded yet
 
90
  "tips": ["Prioritize team communication", "Check weather updates"],
91
  "quote": "Every step forward counts!"
92
  }
93
+ return jsonify(response_json)
94
 
95
  try:
96
  # Prepare input text
97
  input_text = (
98
+ f"Role: {data['role']}, Project: {data['project_id']}, "
99
+ f"Milestones: {data['milestones']}, Reflection: {data['reflection_log']}"
100
  )
101
 
102
  # Tokenize input
 
133
  }
134
  logger.warning("Failed to parse model output as JSON, using default response")
135
 
136
+ return jsonify(response_json)
137
 
138
  except Exception as e:
139
  logger.error(f"Error generating coaching response: {str(e)}")
140
+ return jsonify({"error": f"Internal server error: {str(e)}"}), 500
141
+
142
+ if __name__ == "__main__":
143
+ # Start background tasks before the app runs
144
+ start_background_tasks()
145
+ # Run Flask app with waitress for production-ready WSGI server
146
+ from waitress import serve
147
+ logger.debug("Starting Flask app with Waitress")
148
+ serve(app, host="0.0.0.0", port=7860)