# File-viewer header (page residue, not part of the program):
# DurreSudoku's picture | Update app.py | 223715d verified | raw | history blame | 3.68 kB
import gradio as gr
import keras
import librosa
import hopsworks
import os
import numpy as np
import shutil
from functions import log_mel_spectrogram, split_spectrogram, load_audio_file, image_transformer, save_spectrogram_as_png
from datasets import load_dataset
def empty_string():
    """Return an empty string.

    Used as a Gradio callback to blank the answer box when the submit
    button is pressed.
    """
    blank = str()
    return blank
def create_image_folder(folder):
    """Ensure that the directory ``folder`` exists.

    The original implementation wrapped ``os.mkdir`` in a bare ``except:``
    that merely instantiated (and discarded) a ``FileExistsError``, which
    silently hid every failure, including permission errors.  Here an
    already-existing directory is tolerated explicitly and any other
    ``OSError`` propagates to the caller.

    Args:
        folder: Path of the directory to create.  Missing parent
            directories are created as well.

    Returns:
        None.

    Raises:
        OSError: If the directory cannot be created (for example, the path
            exists but is a regular file, or permissions are insufficient).
    """
    # exist_ok=True makes the call idempotent across repeated requests.
    os.makedirs(folder, exist_ok=True)
def delete_folder(folder):
    """Recursively delete the directory ``folder`` if it exists.

    The original implementation used a bare ``except:`` that instantiated
    (and discarded) a ``FileNotFoundError``, thereby hiding every possible
    failure.  Only the "nothing to delete" case is ignored now; other
    errors propagate so that real problems are visible.

    Args:
        folder: Path of the directory tree to remove.

    Returns:
        None.

    Raises:
        OSError: If the tree exists but cannot be removed (for example, the
            path is not a directory or permissions are insufficient).
    """
    try:
        shutil.rmtree(folder)
    except FileNotFoundError:
        # The folder is already gone; cleanup is a no-op in that case.
        pass
def create_dataset(image_folder):
    """Turn a folder of spectrogram images into a TensorFlow dataset.

    The folder is loaded with ``datasets.load_dataset``, its ``"train"``
    split is selected and printed, ``image_transformer`` is applied in
    batched mode with ``mode="L"`` (presumably single-channel greyscale --
    confirm in ``functions.image_transformer``), and the result is exposed
    as a TF dataset yielding the ``"image"`` column with batch size 1.

    Args:
        image_folder: Path of the directory containing the images.

    Returns:
        The dataset produced by ``to_tf_dataset``.
    """
    all_splits = load_dataset(image_folder, split=None)
    train_split = all_splits["train"]
    print(train_split)
    transformed = train_split.map(
        image_transformer,
        batched=True,
        fn_kwargs={"mode": "L"},
    )
    return transformed.to_tf_dataset(batch_size=1, columns="image")
def majority_vote(raw_predictions):
    """Return the genre name predicted most often across all rows.

    Each row of ``raw_predictions`` is reduced to the index of its largest
    score; the index that occurs most frequently wins.  On a tie the
    smallest index wins, because ``np.unique`` returns sorted labels and
    ``argmax`` picks the first maximum.

    Args:
        raw_predictions: 2-D array-like of per-class scores, one row per
            spectrogram segment.

    Returns:
        The entry of the module-level ``label_decoding`` mapping for the
        winning class index.
    """
    per_segment_votes = np.argmax(raw_predictions, axis=1)
    candidates, tallies = np.unique(per_segment_votes, return_counts=True)
    top_position = np.argmax(tallies)
    return label_decoding[candidates[top_position]]
def predict(audio):
    """Predict the genre of an uploaded audio file.

    The audio is loaded, converted to a log-mel spectrogram, split into
    fixed-size segments that are written as PNG files into the module-level
    ``folder``, and each segment is classified by the module-level
    ``model``.  The final answer is the majority vote over all segments.

    Changes compared with the previous version:
      * a missing upload is rejected up front;
      * only ``Exception`` subclasses are caught while loading, so
        ``KeyboardInterrupt`` / ``SystemExit`` are no longer swallowed;
      * the image folder is emptied before new segments are written, so
        PNG files left over from an earlier (longer) upload can no longer
        take part in the vote;
      * an upload that yields no segments returns a message instead of
        failing inside the dataset loader.

    Args:
        audio: Path of the uploaded audio file as supplied by the Gradio
            ``Audio`` component, or ``None`` when nothing was uploaded.

    Returns:
        A human-readable string containing either the predicted genre or
        an error message.
    """
    load_error = "Error when loading audio. Did you submit a file?"
    if not audio:
        return load_error
    try:
        audio_array = load_audio_file(audio, sample_rate, res_type, duration)
    except Exception:
        return load_error

    # Start from an empty folder: segment files are named 1..N, so a shorter
    # clip would otherwise leave higher-numbered files from a previous run.
    # NOTE(review): the folder is shared by all requests; concurrent
    # submissions can still interfere with each other.
    delete_folder(folder)
    create_image_folder(folder)

    spectrogram = log_mel_spectrogram(audio_array, sample_rate, nfft, hop_length, window)
    spec_splits = split_spectrogram(spectrogram, output_shape)

    segment_count = 0
    for segment_count, split in enumerate(spec_splits, start=1):
        save_path = os.path.join(folder, f"{segment_count}_spec.png")
        save_spectrogram_as_png(split, save_path, sample_rate, nfft, hop_length)
    if segment_count == 0:
        return "Could not extract any spectrogram segments from the audio. Is the clip too short?"

    image_dataset = create_dataset(folder)
    raw_preds = model.predict(image_dataset, verbose=0)
    genre_pred = majority_vote(raw_preds)
    return f"The submitted audio belongs to the {genre_pred} genre!"
# --- Audio loading -----------------------------------------------------------
# Passed to load_audio_file(); sample_rate is also used when computing and
# saving the spectrograms.
sample_rate, res_type = 22050, "kaiser_fast"
# NOTE(review): the meaning of duration == 0 is defined by load_audio_file
# (presumably "no limit") -- confirm in functions.py.
duration = 0

# --- Spectrogram computation -------------------------------------------------
nfft, hop_length, window = 2048, 512, "hann"
# Shape handed to split_spectrogram() for each segment.
output_shape = (128, 256)

# --- Working directory for the per-segment PNG files -------------------------
folder = "images"

# --- Class index -> genre name (index order matters) -------------------------
label_decoding = dict(
    enumerate(
        (
            "Electronic",
            "Experimental",
            "Folk",
            "Hip-Hop",
            "Instrumental",
            "International",
            "Pop",
            "Rock",
        )
    )
)

# --- Classifier, loaded once at import time from a local file ----------------
model_path = "best_model.keras"
model = keras.models.load_model(model_path)
# Disabled alternative model source. The block below is a bare string literal,
# so none of it is executed. If turned back into code it would log in to
# Hopsworks, look up version 1 of "cnn_genre_classifier" in the model registry,
# download it, and load the downloaded model with Keras, replacing `model`.
"""
model_version = 1
project = hopsworks.login()
mr = project.get_model_registry()
model = mr.get_model("cnn_genre_classifier", version=model_version)
model_dir = model.download()
model = keras.models.load_model(model_dir)
"""
# Gradio UI: an introduction, an audio upload next to a read-only answer box,
# and a submit button that runs the prediction pipeline.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(
"""
# Music Genre Classifier
Hello!
This is a prototype for a genre classification service, where you can upload an audio file,
and the model will predict which genre it belongs to!
The model has been trained to predict 8 top-level genres, that each encompasses a multitude of sub-genres.
Upload your favorite song and give it a try!
"""
        )
    with gr.Row():
        with gr.Column():
            audio = gr.Audio(sources="upload", type="filepath", label="Upload your song here", format="wav")
        with gr.Column():
            answer_box = gr.Text(label="Answer appears here", interactive=False)
    with gr.Row():
        submit_audio = gr.Button(value="Submit audio for prediction")

    # The three steps are chained with .then() so they run strictly in order:
    # clear the answer box, predict, then clean up.  Previously they were three
    # independent click listeners with no ordering guarantee, and the cleanup
    # listener called delete_folder() without its required argument, so it
    # raised TypeError and the image folder was never removed.
    # .then() runs regardless of whether the previous step succeeded, so the
    # folder is cleaned up even when prediction fails.
    submit_audio.click(
        fn=empty_string, outputs=answer_box, trigger_mode="once"
    ).then(
        fn=predict, inputs=audio, outputs=answer_box
    ).then(
        fn=lambda: delete_folder(folder)
    )

demo.launch(share=True)