Spaces:
Sleeping
Sleeping
from typing import Optional, Sequence

import streamlit as st
import torch
from lightning import LightningModule

from cvae import CVAE
def format_instruments(text: str) -> str:
    """Convert a space-separated label into a lowercase, underscore-joined key.

    The label is split on single spaces, the first token is discarded, and
    the remaining tokens are lowercased and joined with ``"_"``.

    NOTE(review): the discarded first token is presumably a decorative
    prefix (such as an icon) in the UI label -- confirm against the caller.

    Args:
        text: Space-separated label.

    Returns:
        The normalized key, or an empty string when ``text`` contains no
        token after the first one.
    """
    # Splitting on " " already removes every space from the tokens, so the
    # only normalization left to apply is lowercasing.
    _, *words = text.split(" ")
    return "_".join(word.lower() for word in words)
def choice_to_tensor(choice: Sequence[str]) -> torch.Tensor:
    """Map the selected labels to the index of the matching instrument class.

    Each label is normalized with ``format_instruments``; the results are
    joined with ``"_"`` and the combined key is looked up in the module-level
    ``instruments`` list.

    Args:
        choice: Selected labels, in the order that forms an instrument key.

    Returns:
        A tensor holding the position of the key within ``instruments``.

    Raises:
        ValueError: If the combined key is not present in ``instruments``.
    """
    key = "_".join(map(format_instruments, choice))
    class_index = instruments.index(key)
    return torch.tensor(class_index)
def load_model(device: str) -> LightningModule:
    """Restore the CVAE from its checkpoint file and move it to ``device``.

    Args:
        device: Target device string passed to ``.to()`` (e.g. ``"cuda"`` or
            ``"cpu"``).

    Returns:
        The restored model, placed on ``device``.
    """
    # Constructor arguments forwarded to the checkpoint loader.
    # NOTE(review): 16000 * 4 presumably means 4 seconds of 16 kHz audio --
    # confirm against the training configuration.
    hyperparameters = dict(
        io_channels=1,
        io_features=16000 * 4,
        latent_features=5,
        channels=[32, 64, 128, 256, 512],
        # One class per entry of the module-level instrument list.
        num_classes=len(instruments),
        learning_rate=1e-5,
    )
    restored = CVAE.load_from_checkpoint(
        "epoch=77-step=2819778.ckpt",
        **hyperparameters,
    )
    return restored.to(device)
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Instrument class names. A name's position in this list is used as its class
# index (see choice_to_tensor), so the ordering below must not change: all
# families for "acoustic" first, then "synthetic", then "electronic".
_FAMILIES = (
    "bass",
    "brass",
    "flute",
    "guitar",
    "keyboard",
    "mallet",
    "organ",
    "reed",
    "string",
    "synth_lead",
    "vocal",
)
_SOURCES = ("acoustic", "synthetic", "electronic")
instruments = [
    f"{family}_{source}" for source in _SOURCES for family in _FAMILIES
]
# Load the model once at import time; generate() reuses this shared instance.
model = load_model(device)
def generate(choice: Sequence[str], params: Optional[Sequence[float]] = None):
    """Sample one output from the model for the chosen instrument class.

    Args:
        choice: Selected labels identifying the instrument class; converted
            to a class index by ``choice_to_tensor``.
        params: Optional latent values to use in place of random noise.
            When ``None`` or empty, standard-normal noise of shape ``(1, 5)``
            is drawn instead (5 matches ``latent_features`` in
            ``load_model``).

    Returns:
        The first item of the sampled batch, as a NumPy array.

    Raises:
        ValueError: If ``choice`` does not resolve to a known instrument
            (propagated from ``choice_to_tensor``).
    """
    # Test for None/emptiness explicitly: plain truthiness raises for
    # multi-element NumPy arrays and tensors.
    if params is not None and len(params) > 0:
        # Force float32 so integer inputs (e.g. slider values) do not produce
        # an int64 latent, matching the dtype of the random-noise branch.
        latent = torch.as_tensor(params, dtype=torch.float32).unsqueeze(0)
    else:
        latent = torch.randn(1, 5)
    latent = latent.to(device)
    label = choice_to_tensor(choice).to(device)

    # Inference only: skip autograd bookkeeping so the result can be
    # converted to NumPy without carrying a gradient graph.
    with torch.no_grad():
        sampled = model.sample(eps=latent, c=label)
    return sampled.cpu().numpy()[0]