# audio/main.py
from fastapi import FastAPI, UploadFile, HTTPException
from fastapi.responses import JSONResponse
import tensorflow as tf
import librosa
import numpy as np
import uvicorn
import os

# Load the pre-trained model
loaded_model = tf.keras.models.load_model('depression_audio_model1.keras')
print("Model loaded successfully.")

# Constants for feature extraction
N_MELS = 128         # number of mel bands
N_FFT = 2048         # FFT window size
HOP_LENGTH = 512     # hop between analysis frames, in samples
DURATION = 10        # seconds of audio to load per clip
SAMPLE_RATE = 22050  # target sample rate in Hz
# 10 s * 22050 Hz / 512-sample hop -> 430 time frames
FIXED_SHAPE = (N_MELS, int(DURATION * SAMPLE_RATE / HOP_LENGTH))

# Create the FastAPI app
app = FastAPI()

def extract_mel_spectrogram(file_path, n_mels=N_MELS, n_fft=N_FFT, hop_length=HOP_LENGTH, duration=DURATION, sample_rate=SAMPLE_RATE):
    """Load an audio clip and return a normalized, fixed-size mel spectrogram in dB."""
    signal, _ = librosa.load(file_path, sr=sample_rate, duration=duration)
    mel_spectrogram = librosa.feature.melspectrogram(y=signal, sr=sample_rate, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length)
    mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)
    # Standardize to zero mean and unit variance; fall back to mean-centering
    # when the spectrogram is constant (std == 0)
    mean = mel_spectrogram_db.mean()
    std = mel_spectrogram_db.std()
    if std > 0:
        mel_spectrogram_db = (mel_spectrogram_db - mean) / std
    else:
        mel_spectrogram_db = mel_spectrogram_db - mean
    # Zero-pad or truncate along the time axis to exactly FIXED_SHAPE[1] frames
    # (librosa's centered STFT yields one extra frame for a full-length clip,
    # so the truncate branch is the common path)
    if mel_spectrogram_db.shape[1] < FIXED_SHAPE[1]:
        pad_width = FIXED_SHAPE[1] - mel_spectrogram_db.shape[1]
        mel_spectrogram_db = np.pad(mel_spectrogram_db, ((0, 0), (0, pad_width)), mode='constant')
    else:
        mel_spectrogram_db = mel_spectrogram_db[:, :FIXED_SHAPE[1]]
    return mel_spectrogram_db
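
# Quick sanity-check sketch ("example.wav" is a hypothetical local clip, not
# part of this repo):
#   spec = extract_mel_spectrogram("example.wav")
#   spec.shape -> (128, 430), i.e. FIXED_SHAPE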

def inference(file_path):
    """Run the model on a single audio file and return the predicted class index."""
    mel_spectrogram_db = extract_mel_spectrogram(file_path)
    mel_spectrogram_db = mel_spectrogram_db.reshape(1, *mel_spectrogram_db.shape)  # add batch dimension
    prediction = loaded_model.predict(mel_spectrogram_db)
    predicted_label = np.argmax(prediction, axis=-1)
    return int(predicted_label[0])

@app.post("/predict")
async def predict(file: UploadFile):
    # Validate the file type up front so a bad upload returns 400 rather than
    # being swallowed by the generic handler below and re-raised as 500
    if not file.filename.endswith(('.wav', '.mp3')):
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload a .wav or .mp3 audio file.")
    # Save the upload to a temporary location so librosa can read it from disk
    temp_file_path = f"temp_{file.filename}"
    try:
        with open(temp_file_path, "wb") as temp_file:
            temp_file.write(await file.read())
        # Perform inference
        predicted_label = inference(temp_file_path)
        return JSONResponse(content={"prediction": predicted_label})
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during prediction: {str(e)}")
    finally:
        # Remove the temporary file whether or not inference succeeded
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
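
# Example request, assuming the server is reachable on localhost:7860 (see the
# entrypoint at the bottom of this file) and a local "sample.wav" exists:
#   curl -X POST -F "file=@sample.wav" http://localhost:7860/predict
# Response body: {"prediction": <predicted class index>}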

@app.get("/")
async def root():
    return {"message": "API is up and running!"}
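
# Minimal local entrypoint; host and port are assumptions (7860 is the
# conventional Hugging Face Spaces port), not values taken from the original app.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)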