from fastapi import FastAPI, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import tensorflow as tf
import librosa
import numpy as np
import uvicorn
import os

# Load the pre-trained model
loaded_model = tf.keras.models.load_model('depression_audio_model1.keras')
print("Model loaded successfully.")

# Constants for mel-spectrogram extraction
N_MELS = 128
N_FFT = 2048
HOP_LENGTH = 512
DURATION = 10  # seconds of audio used per clip
SAMPLE_RATE = 22050
FIXED_SHAPE = (N_MELS, int(DURATION * SAMPLE_RATE / HOP_LENGTH))

# Create the FastAPI app
app = FastAPI()


def extract_mel_spectrogram(file_path, n_mels=N_MELS, n_fft=N_FFT,
                            hop_length=HOP_LENGTH, duration=DURATION,
                            sample_rate=SAMPLE_RATE):
    """Load an audio file and return a normalized, fixed-size mel spectrogram."""
    signal, _ = librosa.load(file_path, sr=sample_rate, duration=duration)
    mel_spectrogram = librosa.feature.melspectrogram(
        y=signal, sr=sample_rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length
    )
    mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)

    # Standardize to zero mean and unit variance (skip the division for silent clips)
    mean = mel_spectrogram_db.mean()
    std = mel_spectrogram_db.std()
    if std > 0:
        mel_spectrogram_db = (mel_spectrogram_db - mean) / std
    else:
        mel_spectrogram_db = mel_spectrogram_db - mean

    # Zero-pad short clips and truncate long ones so every input matches FIXED_SHAPE
    if mel_spectrogram_db.shape[1] < FIXED_SHAPE[1]:
        pad_width = FIXED_SHAPE[1] - mel_spectrogram_db.shape[1]
        mel_spectrogram_db = np.pad(mel_spectrogram_db, ((0, 0), (0, pad_width)), mode='constant')
    else:
        mel_spectrogram_db = mel_spectrogram_db[:, :FIXED_SHAPE[1]]
    return mel_spectrogram_db


def inference(file_path):
    """Run the model on one audio file and return the predicted class index."""
    mel_spectrogram_db = extract_mel_spectrogram(file_path)
    mel_spectrogram_db = mel_spectrogram_db.reshape(1, *mel_spectrogram_db.shape)  # Add batch dimension
    prediction = loaded_model.predict(mel_spectrogram_db)
    predicted_label = np.argmax(prediction, axis=-1)
    return int(predicted_label[0])


@app.post("/predict")
async def predict(file: UploadFile):
    # Validate the file type before doing any work. This check lives outside the
    # try block below so the 400 response is not swallowed and re-raised as a 500.
    if not file.filename.endswith(('.wav', '.mp3')):
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload an audio file.")

    # Save the uploaded file to a temporary location
    temp_file_path = f"temp_{file.filename}"
    try:
        with open(temp_file_path, "wb") as temp_file:
            temp_file.write(await file.read())

        # Perform inference
        predicted_label = inference(temp_file_path)
        return JSONResponse(content={"prediction": predicted_label})
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
    finally:
        # Always remove the temporary file, even if inference fails
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)


@app.get("/")
async def root():
    return {"message": "API is up and running!"}


if __name__ == "__main__":
    # uvicorn is imported above; run the app directly when executed as a script
    uvicorn.run(app, host="0.0.0.0", port=8000)
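
# --- Usage sketch (not part of the server) ---
# A minimal example of calling the /predict endpoint from a client. It assumes
# the server is running locally on port 8000 and that a file named "sample.wav"
# exists; both are illustrative assumptions, not part of the original app.
#
#   import requests
#
#   with open("sample.wav", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/predict",
#           # The form field name must be "file" to match the endpoint parameter
#           files={"file": ("sample.wav", f, "audio/wav")},
#       )
#   print(resp.json())  # e.g. {"prediction": 0}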