Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -89,7 +89,7 @@ def GenerateGroove():
|
|
89 |
|
90 |
print('Sample input events', drums_score[:5])
|
91 |
print('=' * 70)
|
92 |
-
print('
|
93 |
|
94 |
num_prime_chords = 7
|
95 |
|
@@ -98,15 +98,54 @@ def GenerateGroove():
|
|
98 |
for d in drums_score[:num_prime_chords]:
|
99 |
|
100 |
outy.extend(d)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
-
for i in
|
103 |
|
104 |
outy.extend(drums_score[i])
|
105 |
|
106 |
if i == num_prime_chords:
|
107 |
-
outy.append(256+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
-
out =
|
110 |
|
111 |
outy.extend(out)
|
112 |
|
|
|
89 |
|
90 |
print('Sample input events', drums_score[:5])
|
91 |
print('=' * 70)
|
92 |
+
print('Prepping drums track...')
|
93 |
|
94 |
num_prime_chords = 7
|
95 |
|
|
|
98 |
for d in drums_score[:num_prime_chords]:
|
99 |
|
100 |
outy.extend(d)
|
101 |
+
|
102 |
+
print('Generating...')
|
103 |
+
|
104 |
+
max_notes_per_chord=8,
|
105 |
+
num_samples=4,
|
106 |
+
num_memory_tokens = 4096,
|
107 |
+
temperature=1.0):
|
108 |
|
109 |
+
for i in range(num_prime_chords, len(drums_score)):
|
110 |
|
111 |
outy.extend(drums_score[i])
|
112 |
|
113 |
if i == num_prime_chords:
|
114 |
+
outy.append(256+12)
|
115 |
+
|
116 |
+
input_seq = outy[-num_memory_tokens:]
|
117 |
+
|
118 |
+
seq = copy.deepcopy(input_seq)
|
119 |
+
|
120 |
+
batch_value = 256
|
121 |
+
|
122 |
+
nc = 0
|
123 |
+
|
124 |
+
while batch_value > 255 and nc < max_notes_per_chord:
|
125 |
+
|
126 |
+
x = torch.tensor([seq] * num_samples, dtype=torch.long, device='cuda')
|
127 |
+
|
128 |
+
with ctx:
|
129 |
+
out = model.generate(x,
|
130 |
+
1,
|
131 |
+
temperature=temperature,
|
132 |
+
return_prime=False,
|
133 |
+
verbose=False)
|
134 |
+
|
135 |
+
out1 = [o[0] for o in out.tolist() if o[0] > 255]
|
136 |
+
|
137 |
+
if not out1:
|
138 |
+
out1 = [-1]
|
139 |
+
|
140 |
+
batch_value = random.choice(out1)
|
141 |
+
|
142 |
+
if batch_value > 255:
|
143 |
+
seq.append(batch_value)
|
144 |
+
|
145 |
+
if batch_value > 383:
|
146 |
+
nc += 1
|
147 |
|
148 |
+
out = seq[len(input_seq):]
|
149 |
|
150 |
outy.extend(out)
|
151 |
|