import os
import tempfile

import librosa
import numpy as np
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, UploadFile, HTTPException
from fastapi.responses import JSONResponse

# Load the trained Keras classifier once, at import time; inference() reuses
# this single instance for every request.
_MODEL_PATH = 'depression_audio_model1.keras'  # relative path: resolved against the working directory
loaded_model = tf.keras.models.load_model(_MODEL_PATH)
print("Model loaded successfully.")

# Audio preprocessing constants used by the mel-spectrogram feature extractor.
SAMPLE_RATE = 22050  # resampling target passed to librosa, in Hz
DURATION = 10        # seconds of audio read from each file
N_FFT = 2048         # STFT window length, in samples
HOP_LENGTH = 512     # hop between successive STFT frames, in samples
N_MELS = 128         # number of mel bands
# (mel bands, time frames) that every spectrogram is padded/truncated to.
# NOTE(review): with librosa's default centred STFT a full-length clip yields
# 1 + 220500 // 512 = 431 frames, so the last frame is dropped. Presumably this
# matches how the model was trained — confirm before changing.
FIXED_SHAPE = (N_MELS, (DURATION * SAMPLE_RATE) // HOP_LENGTH)

# Create the FastAPI application; the route handlers in this module register
# themselves on it through the @app.get / @app.post decorators.
app = FastAPI()

def extract_mel_spectrogram(file_path, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP_LENGTH, duration=DURATION, sample_rate=SAMPLE_RATE):
    """Load an audio file and return a standardized, fixed-width log-mel spectrogram.

    Steps: resample to ``sample_rate`` and keep at most the first ``duration``
    seconds; compute a mel power spectrogram; convert to dB relative to the
    clip's own peak (``ref=np.max``); standardize to zero mean / unit standard
    deviation (mean-centre only, when the spectrogram is constant); finally
    zero-pad or truncate the time axis to a fixed number of frames.

    Args:
        file_path: Path of the audio file to load.
        n_mels: Number of mel bands (rows of the result).
        n_fft: STFT window length in samples.
        hop_length: STFT hop in samples.
        duration: Seconds of audio to read. ``None`` reads the whole file, in
            which case the width falls back to ``FIXED_SHAPE[1]``.
        sample_rate: Resampling rate in Hz.

    Returns:
        2-D numpy array of shape ``(n_mels, int(duration * sample_rate / hop_length))``,
        i.e. ``FIXED_SHAPE`` for the default arguments.
    """
    signal, _ = librosa.load(file_path, sr=sample_rate, duration=duration)
    mel_spectrogram = librosa.feature.melspectrogram(y=signal, sr=sample_rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
    mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)

    # Per-clip standardization. A constant spectrogram has std == 0, so it is only centred.
    mean = mel_spectrogram_db.mean()
    std = mel_spectrogram_db.std()
    if std > 0:
        mel_spectrogram_db = (mel_spectrogram_db - mean) / std
    else:
        mel_spectrogram_db = mel_spectrogram_db - mean

    # Derive the target width from the arguments actually used, so that e.g. a
    # smaller hop_length does not silently keep only part of the clip. For the
    # defaults this equals FIXED_SHAPE[1], which is what the model input expects.
    if duration is None:
        target_frames = FIXED_SHAPE[1]
    else:
        target_frames = int(duration * sample_rate / hop_length)

    n_frames = mel_spectrogram_db.shape[1]
    if n_frames < target_frames:
        # Pad on the right with 0, which is the post-standardization mean.
        mel_spectrogram_db = np.pad(mel_spectrogram_db, ((0, 0), (0, target_frames - n_frames)), mode='constant')
    else:
        mel_spectrogram_db = mel_spectrogram_db[:, :target_frames]
    return mel_spectrogram_db

def inference(file_path):
    """Classify one audio file with the loaded model.

    Returns "depressed" when the arg-max index of the model output (over the
    last axis) is 1, otherwise "non-depressed".

    NOTE(review): this assumes the model emits at least two class scores per
    sample; with a single sigmoid output the arg-max would always be 0 — confirm
    against the model definition.
    """
    features = extract_mel_spectrogram(file_path)
    batch = np.expand_dims(features, axis=0)  # prepend the batch axis: (1, *features.shape)
    scores = loaded_model.predict(batch)
    class_index = int(np.argmax(scores, axis=-1)[0])
    if class_index == 1:
        return "depressed"
    return "non-depressed"

@app.post("/predict")
async def predict(file: UploadFile):
    """Classify an uploaded ``.wav``/``.mp3`` file and return ``{"prediction": label}``.

    Raises:
        HTTPException: 400 when the upload has no filename or its extension is
            not ``.wav``/``.mp3`` (checked case-insensitively); 500 when saving
            the upload or running inference fails.
    """
    # Validate before the try block so this 400 is not converted into a 500 by
    # the generic exception handler below. filename may be None.
    filename = (file.filename or "").lower()
    if not filename.endswith(('.wav', '.mp3')):
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload an audio file.")

    # Use a unique, server-chosen temp path instead of one built from the
    # client-supplied name: that name may contain path separators, and
    # concurrent uploads with the same name would overwrite each other. The
    # extension is kept in case the audio backend relies on it to detect the format.
    suffix = "." + filename.rsplit(".", 1)[-1]
    fd, temp_file_path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as temp_file:
            temp_file.write(await file.read())

        # NOTE(review): inference() is synchronous CPU-bound work and blocks the
        # event loop while it runs; consider offloading it to a worker thread
        # once the model's thread-safety has been confirmed.
        predicted_label = inference(temp_file_path)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}") from e
    finally:
        # Always remove the temp file, including when decoding or inference fails.
        os.remove(temp_file_path)

    return JSONResponse(content={"prediction": predicted_label})

@app.get("/")
async def root():
    """Health check: confirm the service is reachable."""
    status = {"message": "API is up and running!"}
    return status