# Hugging Face Space status at capture time: "Runtime error".
import os
import warnings
from functools import cache
from pathlib import Path

import gradio as gr
import numpy as np
import torch
import torchaudio
import yaml
from torch import nn

from models.audio_spectrogram_transformer import AST, ASTExtractorWrapper
from models.training_environment import TrainingEnvironment
# Training config the demo predictor is built from. Resolved relative to this
# file (as the example assets in `demo` already are) so the app can be launched
# from any working directory.
# NOTE(review): paths *inside* the config (e.g. `checkpoint`) are passed to
# torch.load as-is, so relative ones are still resolved against the CWD.
CONFIG_FILE = Path(__file__).parent / "models" / "config" / "train_local.yaml"
# Model class and feature-extractor class used to build the predictor.
MODEL_CLS, EXTRACTOR = AST, ASTExtractorWrapper
class DancePredictor:
    """Callable wrapper that turns raw audio into dance-label probabilities.

    Calling an instance with ``(waveform, sample_rate)`` returns a mapping of
    dance label -> softmax probability for every label whose probability is
    strictly greater than ``threshold``.
    """

    def __init__(
        self,
        weight_path: str,
        labels: list[str],
        expected_duration=6,
        threshold=0.5,
        resample_frequency=16000,
        device="cpu",
    ):
        """Load the model weights and create the feature extractor.

        Args:
            weight_path: Path to a checkpoint file containing a ``"state_dict"``.
            labels: Class labels, in the same order as the model's outputs.
            expected_duration: Seconds of resampled audio kept per prediction.
            threshold: Labels whose softmax probability exceeds this value
                are reported.
            resample_frequency: Rate (Hz) the audio is resampled to before
                feature extraction.
            device: Torch device string for the model and its inputs.
        """
        super().__init__()
        self.expected_duration = expected_duration
        self.threshold = threshold
        self.resample_frequency = resample_frequency
        self.labels = np.array(labels)
        self.device = device
        self.model = self.get_model(weight_path)
        self.extractor = EXTRACTOR()

    def get_model(self, weight_path: str) -> nn.Module:
        """Build ``MODEL_CLS`` and load weights from a training checkpoint.

        The checkpoint's ``state_dict`` keys carry a leading ``"model."``
        (presumably the attribute name of the network inside the training
        wrapper). Only that *leading* prefix is stripped, so keys of nested
        submodules that are themselves called ``model`` are left intact.

        Returns:
            The model on ``self.device``, switched to eval mode.
        """
        # NOTE(review): torch.load unpickles; weight_path must be trusted.
        checkpoint = torch.load(weight_path, map_location=self.device)
        weights = {
            key.removeprefix("model."): value
            for key, value in checkpoint["state_dict"].items()
        }
        model = MODEL_CLS(self.labels)
        incompatible = model.load_state_dict(weights, strict=False)
        if incompatible.missing_keys:
            # strict=False would otherwise silently leave these parameters at
            # the values of the freshly constructed model.
            warnings.warn(
                f"{len(incompatible.missing_keys)} model parameters were not "
                f"found in checkpoint {weight_path!r}, e.g. "
                f"{incompatible.missing_keys[:3]}",
                stacklevel=2,
            )
        return model.to(self.device).eval()

    @classmethod
    def from_config(cls, config_path: str) -> "DancePredictor":
        """Create a predictor from a training YAML config.

        Reads ``checkpoint`` (weight path) and ``dance_ids`` (labels, sorted)
        from the config. The remaining settings use fixed defaults. The device
        is chosen at runtime rather than assuming Apple MPS, which does not
        exist on most hosts.
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        return cls(
            config["checkpoint"],
            sorted(config["dance_ids"]),
            expected_duration=6,
            threshold=0.5,
            resample_frequency=16000,
            device=cls._default_device(),
        )

    @staticmethod
    def _default_device() -> str:
        """Return "cuda", then "mps", then "cpu" -- the first one available."""
        if torch.cuda.is_available():
            return "cuda"
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            return "mps"
        return "cpu"

    def __call__(self, waveform: np.ndarray, sample_rate: int) -> dict[str, float]:
        """Predict which dances fit an audio clip.

        Args:
            waveform: Samples, either 1-D (mono, duplicated to two channels)
                or 2-D laid out as (samples, channels).
            sample_rate: Sample rate of ``waveform`` in Hz.

        Returns:
            Dance label -> probability for each label above ``self.threshold``.
            The mapping may be empty.
        """
        if waveform.ndim == 1:
            waveform = np.stack([waveform, waveform]).T
        # Transpose to channels-first (channels, samples) for torchaudio.
        audio = torch.from_numpy(waveform.T)
        # Presumably a WAV round-trip so integer PCM comes back as float
        # samples -- confirm. NOTE(review): apply_codec is deprecated in
        # recent torchaudio releases; check against the pinned version.
        audio = torchaudio.functional.apply_codec(
            audio, sample_rate, "wav", channels_first=True
        )
        audio = torchaudio.functional.resample(
            audio, sample_rate, self.resample_frequency
        )
        # Keep at most `expected_duration` seconds.
        # TODO: pad clips shorter than expected_duration.
        audio = audio[:, : self.resample_frequency * self.expected_duration]
        features = self.extractor(audio).unsqueeze(0).to(self.device)
        # Pure inference: skip autograd bookkeeping.
        with torch.no_grad():
            logits = self.model(features)
            probs = nn.functional.softmax(logits.squeeze(0), dim=0)
        probs = probs.cpu().numpy()
        mask = probs > self.threshold
        return {
            dance: float(prob)
            for dance, prob in zip(self.labels[mask], probs[mask])
        }
@cache
def get_model(config_path: str) -> DancePredictor:
    """Return the predictor built from ``config_path``, loading it only once.

    Memoized with ``functools.cache`` so each Gradio request reuses the same
    model instead of re-reading the checkpoint from disk. ``config_path``
    must be hashable (a ``str`` or ``Path``).
    """
    return DancePredictor.from_config(config_path)
def predict(audio: "tuple[int, np.ndarray] | None") -> "dict[str, float] | str":
    """Gradio callback: classify an ``(sample_rate, waveform)`` tuple.

    Returns:
        The predictor's label -> probability mapping. Returns the string
        ``"Dance Not Found"`` when no audio was supplied (``None`` or an empty
        waveform) or when no dance clears the predictor's threshold.
    """
    if audio is None:
        return "Dance Not Found"
    sample_rate, waveform = audio
    if waveform.size == 0:
        return "Dance Not Found"
    model = get_model(CONFIG_FILE)
    results = model(waveform, sample_rate)
    return results if results else "Dance Not Found"
def demo():
    """Build the Gradio app: upload/record tabs plus a list of known dances."""
    title = "Dance Classifier"
    description = "What should I dance to this song? Pass some audio to the Dance Classifier to find out!"
    song_samples = Path(__file__).parent / "assets" / "song-samples"
    # Skip hidden files (e.g. .DS_Store); sort so the example order is stable.
    example_audio = sorted(
        str(song) for song in song_samples.iterdir() if not song.name.startswith(".")
    )
    predictor = get_model(CONFIG_FILE)
    all_dances = predictor.labels
    recording_interface = gr.Interface(
        fn=predict,
        # Keep the hint in sync with the clip length the predictor keeps.
        description=f"Record at least **{predictor.expected_duration} seconds** of the song.",
        inputs=gr.Audio(source="microphone", label="Song Recording"),
        outputs=gr.Label(label="Dances"),
        examples=example_audio,
    )
    uploading_interface = gr.Interface(
        fn=predict,
        inputs=gr.Audio(label="Song Audio File"),
        outputs=gr.Label(label="Dances"),
        examples=example_audio,
    )
    with gr.Blocks() as app:
        gr.Markdown(f"# {title}")
        gr.Markdown(description)
        gr.TabbedInterface(
            [uploading_interface, recording_interface], ["Upload Song", "Record Song"]
        )
        with gr.Accordion("See all dances", open=False):
            gr.Markdown("\n".join(f"- {dance}" for dance in all_dances))
    return app
if __name__ == "__main__":
    # Build the app, then start the Gradio server.
    app = demo()
    app.launch()