File size: 2,018 Bytes
bdac835 6ed901c e72f4c2 f22a587 bdac835 e72f4c2 bdac835 6ed901c e72f4c2 6ed901c e72f4c2 6ed901c e72f4c2 6ed901c e72f4c2 6ed901c bdac835 e72f4c2 bdac835 e72f4c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
from typing import Optional, Sequence

import streamlit as st
import torch
from lightning import LightningModule

from cvae import CVAE
def format_instruments(text: str) -> str:
    """Turn a UI label into a snake_case key fragment.

    The first space-separated token is discarded (presumably an icon or prefix
    on the label — confirm against the UI), the remaining words are lowercased
    and joined with ``"_"``. Empty tokens produced by repeated or trailing
    spaces are skipped so they cannot inject stray underscores into the key.

    Args:
        text: The label to convert, e.g. ``"X Synth Lead"``.

    Returns:
        The snake_case fragment, e.g. ``"synth_lead"``; ``""`` when the label
        has no words after the first token.
    """
    words = text.split(" ")[1:]
    return "_".join(word.lower() for word in words if word)
def choice_to_tensor(choice: Sequence[str]) -> torch.Tensor:
    """Map UI labels to the class-index tensor of the matching instrument.

    Each label is normalised with ``format_instruments`` and the pieces are
    joined with ``"_"`` to build a key, which is looked up in ``instruments``.

    Args:
        choice: The UI labels whose normalised forms make up the key.

    Returns:
        A 0-d integer tensor holding the key's index in ``instruments``.

    Raises:
        ValueError: If the combined key is not in ``instruments``.
    """
    key = "_".join(format_instruments(label) for label in choice)
    try:
        index = instruments.index(key)
    except ValueError:
        # Replace list.index's bare "'x' is not in list" with a message that
        # shows where the key came from and what was expected.
        raise ValueError(
            f"Unknown instrument {key!r} built from {choice!r}; "
            f"expected one of {instruments}"
        ) from None
    return torch.tensor(index)
@st.cache_resource
def load_model(device: str) -> LightningModule:
    """Restore the trained CVAE from its checkpoint, ready for inference.

    ``st.cache_resource`` caches the returned model per distinct ``device``
    argument, so the checkpoint is not re-read on every call.

    Args:
        device: Torch device string such as ``"cuda"`` or ``"cpu"``.

    Returns:
        The restored ``CVAE``, moved to ``device`` and switched to eval mode.
    """
    model = CVAE.load_from_checkpoint(
        "epoch=77-step=2819778.ckpt",
        # Deserialize tensors directly onto the target device instead of the
        # device the checkpoint was saved from.
        map_location=device,
        # Constructor arguments; presumably must match the training-time
        # architecture — TODO confirm against the training config.
        io_channels=1,
        io_features=16000 * 4,  # presumably 4 s of 16 kHz audio — confirm
        latent_features=5,
        channels=[32, 64, 128, 256, 512],
        num_classes=len(instruments),
        learning_rate=1e-5,
    )
    # Lightning leaves a freshly loaded module in training mode (its checkpoint
    # docs call model.eval() explicitly). This app only samples, so disable any
    # training-time behaviour such as dropout or batch-norm statistic updates.
    return model.to(device).eval()
# Run on the GPU when PyTorch reports CUDA is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Class labels, in model class-index order: all families for the "acoustic"
# source, then "synthetic", then "electronic". A name's position in this list
# is the class id given to the model (via choice_to_tensor) and len() sets
# num_classes, so the ordering must stay fixed.
instruments = [
    f"{family}_{source}"
    for source in ("acoustic", "synthetic", "electronic")
    for family in (
        "bass",
        "brass",
        "flute",
        "guitar",
        "keyboard",
        "mallet",
        "organ",
        "reed",
        "string",
        "synth_lead",
        "vocal",
    )
]
# Build the model once at import time. load_model is wrapped in
# st.cache_resource, so repeated executions of this line with the same
# `device` get the cached instance back. generate() reads this global.
model = load_model(device)
def generate(choice: Sequence[str], params: Optional[Sequence[float]] = None):
    """Sample one output from the CVAE for the selected instrument.

    Args:
        choice: UI labels identifying the instrument; converted to a class
            index by ``choice_to_tensor``.
        params: Optional latent vector to decode. When ``None`` or empty, a
            random standard-normal latent of shape ``(1, 5)`` is drawn.

    Returns:
        The first element of the batch returned by ``model.sample``, as a
        NumPy array.
    """
    # Explicit None/length check: plain truthiness raises for NumPy arrays and
    # tensors with more than one element.
    if params is not None and len(params) > 0:
        # Force the default float dtype (the same dtype torch.randn produces);
        # integer params would otherwise yield an int64 latent.
        noise = torch.tensor(params, dtype=torch.get_default_dtype()).unsqueeze(0)
    else:
        noise = torch.randn(1, 5)
    label = choice_to_tensor(choice).to(device)
    # Inference only: no_grad skips autograd bookkeeping and ensures the result
    # does not require grad, which would make .numpy() raise.
    with torch.no_grad():
        sample = model.sample(eps=noise.to(device), c=label)
    return sample.cpu().numpy()[0]
|