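"""Gradio demo for a dance classifier: record or upload a short song clip and a
ShortChunkCNN model predicts which dances fit the music."""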
from pathlib import Path
import gradio as gr
import numpy as np
import torch
from preprocessing.preprocess import AudioPipeline
from dancer_net.dancer_net import ShortChunkCNN
import os
import json
from functools import cache
import pandas as pd


# Cache the model load so repeated predictions reuse the same weights and labels.
@cache
def get_model(device) -> tuple[ShortChunkCNN, np.ndarray]:
    model_path = "logs/20221226-230930"
    weights = os.path.join(model_path, "dancer_net.pt")
    config_path = os.path.join(model_path, "config.json")
    with open(config_path) as f:
        config = json.load(f)
    labels = np.array(sorted(config["classes"]))
    model = ShortChunkCNN(n_class=len(labels))
    model.load_state_dict(torch.load(weights, map_location=device))
    model = model.to(device).eval()
    return model, labels


# The preprocessing pipeline depends only on the sample rate, so it can be cached too.
@cache
def get_pipeline(sample_rate: int) -> AudioPipeline:
    return AudioPipeline(input_freq=sample_rate)


def get_dance_map() -> dict:
    df = pd.read_csv("data/dance_mapping.csv")
    return df.set_index("id").to_dict()["name"]


def predict(audio: tuple[int, np.ndarray]) -> dict[str, float] | str:
    sample_rate, waveform = audio
    expected_duration = 6  # seconds
    threshold = 0.5
    sample_len = sample_rate * expected_duration
    # Fall back to CPU when Apple's MPS backend is unavailable (e.g. on Linux hosts).
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    audio_pipeline = get_pipeline(sample_rate)
    model, labels = get_model(device)

    if sample_len > len(waveform):
        raise gr.Error("You must record for at least 6 seconds")

    # Downmix stereo recordings to mono; otherwise just add a channel dimension.
    if len(waveform.shape) > 1 and waveform.shape[1] > 1:
        waveform = waveform.transpose(1, 0)
        waveform = waveform.mean(axis=0, keepdims=True)
    else:
        waveform = np.expand_dims(waveform, 0)

    # Trim to the expected duration and rescale to [-1, 1].
    waveform = waveform[:, :sample_len]
    waveform = (waveform - waveform.min()) / (waveform.max() - waveform.min()) * 2 - 1
    waveform = waveform.astype("float32")
    waveform = torch.from_numpy(waveform)

    spectrogram = audio_pipeline(waveform)
    spectrogram = spectrogram.unsqueeze(0).to(device)

    with torch.no_grad():
        results = model(spectrogram)

    # Keep only the dances whose probability clears the threshold.
    dance_mapping = get_dance_map()
    results = results.squeeze(0).detach().cpu().numpy()
    result_mask = results > threshold
    probs = results[result_mask]
    dances = labels[result_mask]
    if not len(dances):
        return "Couldn't find a dance."
    return {dance_mapping[dance_id]: float(prob) for dance_id, prob in zip(dances, probs)}


def demo():
    title = "Dance Classifier"
    description = "Record 6 seconds of a song and find out what dance fits the music."
    with gr.Blocks() as app:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)
        with gr.Tab("Record Song"):
            mic_audio = gr.Audio(source="microphone", label="Song Recording")
            mic_submit = gr.Button("Predict")
        with gr.Tab("Upload Song"):
            audio_file = gr.Audio(label="Song Audio File")
            audio_file_submit = gr.Button("Predict")

        song_samples = Path(os.path.dirname(__file__), "assets", "song-samples")
        example_audio = [str(song) for song in song_samples.iterdir() if song.name[0] != "."]
        labels = gr.Label(label="Dances")

        gr.Markdown("## Examples")
        gr.Examples(
            examples=example_audio,
            inputs=audio_file,
            outputs=labels,
            fn=predict,
        )

        audio_file_submit.click(fn=predict, inputs=audio_file, outputs=labels)
        mic_submit.click(fn=predict, inputs=mic_audio, outputs=labels)
    return app


if __name__ == "__main__":
    demo().launch()