|
from fastapi import FastAPI, UploadFile, HTTPException |
|
from fastapi.responses import JSONResponse |
|
import tensorflow as tf |
|
import librosa |
|
import numpy as np |
|
import uvicorn |
|
import os |
|
|
|
|
|
loaded_model = tf.keras.models.load_model('depression_audio_model1.keras') |
|
print("Model loaded successfully.") |
|
|
|
|
|
N_MELS = 128 |
|
N_FFT = 2048 |
|
HOP_LENGTH = 512 |
|
DURATION = 10 |
|
SAMPLE_RATE = 22050 |
|
FIXED_SHAPE = (N_MELS, int(DURATION * SAMPLE_RATE / HOP_LENGTH)) |
|
|
|
|
|
app = FastAPI() |
|
|
|
def extract_mel_spectrogram(file_path, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP_LENGTH, duration=DURATION, sample_rate=SAMPLE_RATE): |
|
signal, _ = librosa.load(file_path, sr=sample_rate, duration=duration) |
|
mel_spectrogram = librosa.feature.melspectrogram(y=signal, sr=sample_rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length) |
|
mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max) |
|
mean = mel_spectrogram_db.mean() |
|
std = mel_spectrogram_db.std() |
|
if std > 0: |
|
mel_spectrogram_db = (mel_spectrogram_db - mean) / std |
|
else: |
|
mel_spectrogram_db = mel_spectrogram_db - mean |
|
|
|
if mel_spectrogram_db.shape[1] < FIXED_SHAPE[1]: |
|
pad_width = FIXED_SHAPE[1] - mel_spectrogram_db.shape[1] |
|
mel_spectrogram_db = np.pad(mel_spectrogram_db, ((0, 0), (0, pad_width)), mode='constant') |
|
else: |
|
mel_spectrogram_db = mel_spectrogram_db[:, :FIXED_SHAPE[1]] |
|
return mel_spectrogram_db |
|
|
|
def inference(file_path): |
|
mel_spectrogram_db = extract_mel_spectrogram(file_path) |
|
mel_spectrogram_db = mel_spectrogram_db.reshape(1, *mel_spectrogram_db.shape) |
|
prediction = loaded_model.predict(mel_spectrogram_db) |
|
predicted_label = np.argmax(prediction, axis=-1) |
|
|
|
return int(predicted_label[0]) |
|
|
|
@app.post("/predict") |
|
async def predict(file: UploadFile): |
|
try: |
|
|
|
if not file.filename.endswith(('.wav', '.mp3')): |
|
raise HTTPException(status_code=400, detail="Invalid file type. Please upload an audio file.") |
|
|
|
|
|
temp_file_path = f"temp_{file.filename}" |
|
with open(temp_file_path, "wb") as temp_file: |
|
temp_file.write(await file.read()) |
|
|
|
|
|
predicted_label = inference(temp_file_path) |
|
|
|
|
|
os.remove(temp_file_path) |
|
|
|
return JSONResponse(content={"prediction": predicted_label}) |
|
|
|
except Exception as e: |
|
raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}") |
|
|
|
@app.get("/") |
|
async def root(): |
|
return {"message": "API is up and running!"} |
|
|