File size: 2,696 Bytes
6682f41 895f8b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import os
import tempfile

from fastapi import FastAPI, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import tensorflow as tf
import librosa
import numpy as np
import uvicorn
# Load the pre-trained Keras model once, at import time, so that every request
# reuses the same in-memory instance. The path is relative, so it is resolved
# against the process's current working directory; a missing file makes the
# module fail to import.
loaded_model = tf.keras.models.load_model('depression_audio_model1.keras')
print("Model loaded successfully.")
# Audio front-end settings. They are used as the defaults for feature
# extraction and determine the spectrogram shape handed to the model.
SAMPLE_RATE = 22050  # passed to librosa as the target sampling rate (`sr`)
DURATION = 10        # passed to librosa.load as `duration` (max audio read)
N_FFT = 2048         # FFT window size for the mel spectrogram
HOP_LENGTH = 512     # samples between successive spectrogram frames
N_MELS = 128         # number of mel bands (rows of the spectrogram)

# (mel bands, time frames) that every spectrogram is padded or cropped to.
FIXED_SHAPE = (N_MELS, int(DURATION * SAMPLE_RATE / HOP_LENGTH))
# Create the FastAPI application object. The route decorators further down
# (`@app.post`, `@app.get`) register their handlers on this instance.
app = FastAPI()
def extract_mel_spectrogram(file_path, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP_LENGTH, duration=DURATION, sample_rate=SAMPLE_RATE):
    """Turn an audio file into a standardised, fixed-width log-mel spectrogram.

    Args:
        file_path: Path of the audio file handed to ``librosa.load``.
        n_mels: Number of mel bands.
        n_fft: FFT window size.
        hop_length: Hop between successive frames, in samples.
        duration: Maximum amount of audio to read (``librosa.load`` ``duration``).
        sample_rate: Rate the audio is resampled to on load.

    Returns:
        A 2-D array whose second axis always has ``FIXED_SHAPE[1]`` columns.
        Note that the target width comes from the module-level ``FIXED_SHAPE``
        and does not follow the ``duration`` / ``hop_length`` / ``sample_rate``
        arguments.
    """
    audio, _ = librosa.load(file_path, sr=sample_rate, duration=duration)
    power = librosa.feature.melspectrogram(
        y=audio,
        sr=sample_rate,
        n_mels=n_mels,
        n_fft=n_fft,
        hop_length=hop_length,
    )
    log_mel = librosa.power_to_db(power, ref=np.max)

    # Standardise over the whole spectrogram. Both statistics are taken before
    # anything is modified; the division is skipped when the spread is not
    # positive so a constant input cannot cause a divide-by-zero.
    centre = log_mel.mean()
    spread = log_mel.std()
    log_mel = log_mel - centre
    if spread > 0:
        log_mel = log_mel / spread

    # Force the time axis to the fixed width: zero-pad on the right when the
    # clip is short, otherwise keep only the leading frames.
    target_frames = FIXED_SHAPE[1]
    missing = target_frames - log_mel.shape[1]
    if missing > 0:
        return np.pad(log_mel, ((0, 0), (0, missing)), mode='constant')
    return log_mel[:, :target_frames]
def inference(file_path):
    """Run the loaded model on one audio file and return the top class index.

    Args:
        file_path: Path of the audio file to classify.

    Returns:
        The index of the highest-scoring output along the model's last axis,
        for the single item in the batch, as a plain ``int``.
    """
    features = extract_mel_spectrogram(file_path)
    # Prepend a batch axis of size one; the spectrogram itself is unchanged.
    batch = features.reshape((1,) + features.shape)
    scores = loaded_model.predict(batch)
    # NOTE(review): argmax presumes one score per class on the last axis --
    # confirm against the model's output layer.
    return int(np.argmax(scores, axis=-1)[0])
@app.post("/predict")
async def predict(file: UploadFile):
    """Classify an uploaded audio clip and return the predicted label.

    Args:
        file: The uploaded audio file. Its name must end in ``.wav`` or
            ``.mp3`` (compared case-insensitively).

    Returns:
        A JSON response of the form ``{"prediction": <int>}``.

    Raises:
        HTTPException: 400 when the file name is missing or does not have an
            accepted extension; 500 when saving the upload or running
            inference fails.
    """
    # Validate before entering the try block: otherwise the 400 raised here is
    # caught by the generic handler below and reported to the client as a 500.
    upload_name = (file.filename or "").lower()
    suffix = next((ext for ext in ('.wav', '.mp3') if upload_name.endswith(ext)), None)
    if suffix is None:
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload an audio file.")

    temp_file_path = None
    try:
        # Use a server-generated temp name rather than the client-supplied
        # file name: that rules out path traversal and collisions between
        # concurrent uploads. The extension is kept in case the audio decoder
        # uses it as a format hint. delete=False because the file is reopened
        # by path during inference.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(await file.read())

        # NOTE(review): inference is synchronous and runs on the event loop,
        # so other requests wait while a prediction is in progress.
        predicted_label = inference(temp_file_path)
    except HTTPException:
        # Preserve any deliberate HTTP error instead of masking it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {e}") from e
    finally:
        # Always clean up, including when decoding or inference failed.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                pass  # best-effort cleanup; the file may already be gone

    return JSONResponse(content={"prediction": predicted_label})
@app.get("/")
async def root():
    """Liveness endpoint: return a fixed message showing the API is reachable."""
    status = {"message": "API is up and running!"}
    return status
|