auto cuda
model.py
CHANGED
@@ -1,7 +1,8 @@
 from cvae import CVAE
 import torch
 from typing import Sequence
-
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 instruments = ['bass_acoustic', 'brass_acoustic', 'flute_acoustic', 'guitar_acoustic', 'keyboard_acoustic', 'mallet_acoustic', 'organ_acoustic', 'reed_acoustic', 'string_acoustic', 'synth_lead_acoustic', 'vocal_acoustic', 'bass_synthetic', 'brass_synthetic', 'flute_synthetic', 'guitar_synthetic', 'keyboard_synthetic', 'mallet_synthetic', 'organ_synthetic', 'reed_synthetic', 'string_synthetic', 'synth_lead_synthetic', 'vocal_synthetic', 'bass_electronic', 'brass_electronic', 'flute_electronic', 'guitar_electronic', 'keyboard_electronic', 'mallet_electronic', 'organ_electronic', 'reed_electronic', 'string_electronic', 'synth_lead_electronic', 'vocal_electronic']
 
@@ -13,7 +14,7 @@ model = CVAE.load_from_checkpoint(
     channels=[32, 64, 128, 256, 512],
     num_classes=len(instruments),
     learning_rate=1e-5
-)
+).to(device)
 
 def format(text):
     text = text.split(' ')[-1]
@@ -24,5 +25,5 @@ def choice_to_tensor(choice: Sequence[str]) -> torch.Tensor:
     return torch.tensor(instruments.index(choice))
 
 def generate(choice: Sequence[str], params: Sequence[int]=None):
-    noise = torch.tensor(params).unsqueeze(0).to(
-    return model.sample(eps=noise, c = choice_to_tensor(choice).to(
+    noise = torch.tensor(params).unsqueeze(0).to(device) if params else torch.randn(1, 5).to('cuda')
+    return model.sample(eps=noise, c = choice_to_tensor(choice).to(device)).cpu().numpy()[0]
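For context, a minimal usage sketch of the module after this commit. It is an illustration only: it assumes the file is importable as model, that the referenced checkpoint is available so CVAE.load_from_checkpoint succeeds at import time, and that the latent size is 5 (inferred from the torch.randn(1, 5) fallback in generate); none of these details beyond what appears in the diff are guaranteed by the commit.

import model  # the model.py shown above; importing it loads the checkpoint onto `device`

# Let generate() draw a random latent vector (params=None takes the torch.randn(1, 5) branch)
audio_a = model.generate('guitar_acoustic')

# Or pass an explicit latent vector; its length must match the CVAE's latent size (assumed 5)
audio_b = model.generate('flute_synthetic', params=[1, 0, -1, 2, 0])

print(audio_a.shape)  # a numpy array, since generate() ends with .cpu().numpy()[0]

Note that the params=None branch still moves the random noise to 'cuda' directly rather than to device, so the sketch assumes a CUDA-capable machine for that call; only the explicit-params path is fully device-agnostic after this change.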