File size: 3,630 Bytes
b181476
 
 
 
 
3db7383
b181476
 
 
 
 
 
 
 
 
 
 
 
 
 
3db7383
 
92b443e
3db7383
b181476
 
92b443e
 
 
 
 
 
 
 
b181476
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92b443e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b181476
 
 
 
 
 
 
 
 
3db7383
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import json
import logging
import os
import threading

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configure root logging at INFO once, at import time, and obtain the
# logger this module writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# ASGI application object; the route decorators in this module register on it.
app = FastAPI()

# Define input model for validation
class CoachingInput(BaseModel):
    """Request body accepted by the ``/generate_coaching`` endpoint.

    All four fields are required and annotated as ``str``; pydantic validates
    the incoming JSON against them before the handler runs. The handler in
    this module interpolates each field into the text prompt given to the model.
    """

    # Job role of the person being coached (free text).
    role: str
    # Project identifier; treated as an opaque string, never parsed.
    project_id: str
    # Milestones as one free-text string (not a structured list).
    milestones: str
    # Free-text reflection entry; presumably the longest field. The prompt is
    # truncated at tokenization, so very long values are cut off there.
    reflection_log: str

# Define model path (absolute path in the container). Both locations can be
# overridden through environment variables; the defaults are the original
# hard-coded values, so existing deployments behave the same.
model_path = os.environ.get("COACHING_MODEL_PATH", "/app/fine-tuned-construction-llm")
# Model id used when the local directory above does not exist.
fallback_model = os.environ.get("COACHING_FALLBACK_MODEL", "gpt2")

# Load model and tokenizer once at import time; request handlers in this
# module reuse these module-level objects.
try:
    if os.path.isdir(model_path):
        logger.info("Loading local model from %s", model_path)
        # local_files_only: load strictly from disk, never try to download.
        model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True)
        tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
    else:
        logger.warning(
            "Model directory not found: %s. Falling back to pre-trained model: %s",
            model_path,
            fallback_model,
        )
        model = AutoModelForCausalLM.from_pretrained(fallback_model)
        tokenizer = AutoTokenizer.from_pretrained(fallback_model)
    logger.info("Model and tokenizer loaded successfully")
except Exception as e:
    # logger.exception records the full traceback; chaining with `from e`
    # keeps the root cause visible to whatever imports this module.
    # RuntimeError is a subclass of Exception, so existing handlers still match.
    logger.exception("Failed to load model or tokenizer: %s", e)
    raise RuntimeError(f"Model loading failed: {e}") from e

# Serializes access to the shared tokenizer/model. Generation runs in a worker
# thread (see generate_coaching), and Hugging Face fast tokenizers may raise
# "Already borrowed" when called concurrently with truncation enabled, so only
# one request tokenizes/generates at a time.
_generation_lock = threading.Lock()

# Token budgets for the prompt and for the model's continuation.
_MAX_PROMPT_TOKENS = 512
_MAX_NEW_TOKENS = 200


def _default_coaching_response():
    """Return a fresh copy of the canned response used when model output is unusable."""
    return {
        "checklist": ["Inspect safety equipment", "Review milestone progress"],
        "tips": ["Prioritize team communication", "Check weather updates"],
        "quote": "Every step forward counts!",
    }


def _build_prompt(data):
    """Render the request fields into the single-line prompt fed to the model."""
    return (
        f"Role: {data.role}, Project: {data.project_id}, "
        f"Milestones: {data.milestones}, Reflection: {data.reflection_log}"
    )


def _generate_text(prompt):
    """Run blocking tokenization + generation and return only the generated text.

    Intended to run in a worker thread; holds ``_generation_lock`` for the
    whole tokenizer/model interaction.
    """
    with _generation_lock:
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=_MAX_PROMPT_TOKENS,
            truncation=True,
        )
        input_ids = inputs["input_ids"]
        # Tokenizers without a pad token (GPT-2's, used as the fallback, has
        # none) make generate() warn on every call; EOS is the usual stand-in.
        pad_token_id = (
            tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id
        )
        outputs = model.generate(
            input_ids,
            attention_mask=inputs.get("attention_mask"),
            # max_new_tokens (not max_length, which also counts the prompt) so
            # a long prompt cannot consume the whole budget.
            max_new_tokens=_MAX_NEW_TOKENS,
            num_return_sequences=1,
            # NOTE(review): banning repeated 2-grams makes well-formed multi-item
            # JSON very unlikely (e.g. '", "' repeats); revisit if JSON is expected.
            no_repeat_ngram_size=2,
            do_sample=True,
            temperature=0.7,
            pad_token_id=pad_token_id,
        )
        # A causal LM returns prompt + continuation; drop the prompt so the JSON
        # check sees only what the model actually generated.
        new_tokens = outputs[0][input_ids.shape[-1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)


def _parse_coaching_response(response_text):
    """Decode the model's reply as a JSON object, or fall back to the defaults.

    Only the leading JSON object is parsed; any text the model emits after
    the closing brace is ignored instead of causing a parse failure.
    """
    text = response_text.strip()
    if not text.startswith("{"):
        logger.warning("Model output is not JSON, using default response")
        return _default_coaching_response()
    try:
        parsed, _end = json.JSONDecoder().raw_decode(text)
    except json.JSONDecodeError:
        logger.warning("Failed to parse model output as JSON, using default response")
        return _default_coaching_response()
    return parsed


@app.post("/generate_coaching")
async def generate_coaching(data: CoachingInput):
    """Generate a coaching checklist, tips and quote for one request.

    The blocking model call runs in FastAPI's thread pool, so the event loop
    (and other endpoints such as ``/health``) stays responsive meanwhile.

    Returns:
        The JSON object produced by the model, or the canned default response
        when the output is not a parseable JSON object.

    Raises:
        HTTPException: status 500 on any unexpected failure. Details are
        logged server-side instead of being echoed to the client.
    """
    try:
        prompt = _build_prompt(data)
        response_text = await run_in_threadpool(_generate_text, prompt)
        return _parse_coaching_response(response_text)
    except Exception as e:
        logger.exception("Error generating coaching response: %s", e)
        raise HTTPException(status_code=500, detail="Internal server error") from e

@app.get("/health")
async def health_check():
    """Return a constant health payload.

    Performs no checks of its own; responding at all shows the process is up.
    """
    payload = {"status": "healthy"}
    return payload