Update app.py
Browse files
app.py
CHANGED
@@ -163,7 +163,7 @@ def ClassifyMIDI(input_midi, input_sampling_resolution):
|
|
163 |
|
164 |
SEQ_LEN = 1026
|
165 |
PAD_IDX = 940
|
166 |
-
DEVICE = '
|
167 |
|
168 |
# instantiate the model
|
169 |
|
@@ -216,7 +216,7 @@ def ClassifyMIDI(input_midi, input_sampling_resolution):
|
|
216 |
|
217 |
for input in input_data:
|
218 |
|
219 |
-
x = torch.tensor(input[:1022], dtype=torch.long, device=
|
220 |
|
221 |
with ctx:
|
222 |
out = model.module.generate(x,
|
|
|
163 |
|
164 |
SEQ_LEN = 1026
|
165 |
PAD_IDX = 940
|
166 |
+
DEVICE = 'cpu' # 'cuda'
|
167 |
|
168 |
# instantiate the model
|
169 |
|
|
|
216 |
|
217 |
for input in input_data:
|
218 |
|
219 |
+
x = torch.tensor(input[:1022], dtype=torch.long, device=DEVICE)
|
220 |
|
221 |
with ctx:
|
222 |
out = model.module.generate(x,
|