asigalov61 commited on
Commit
c8071ec
·
verified ·
1 Parent(s): 2fdf5af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -4
app.py CHANGED
@@ -89,7 +89,7 @@ def GenerateGroove():
89
 
90
  print('Sample input events', drums_score[:5])
91
  print('=' * 70)
92
- print('Generating...')
93
 
94
  num_prime_chords = 7
95
 
@@ -98,15 +98,54 @@ def GenerateGroove():
98
  for d in drums_score[:num_prime_chords]:
99
 
100
  outy.extend(d)
 
 
 
 
 
 
 
101
 
102
- for i in tqdm.tqdm(range(num_prime_chords, len(drums_score))):
103
 
104
  outy.extend(drums_score[i])
105
 
106
  if i == num_prime_chords:
107
- outy.append(256+8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
- out = generate_chords(outy)
110
 
111
  outy.extend(out)
112
 
 
89
 
90
  print('Sample input events', drums_score[:5])
91
  print('=' * 70)
92
+ print('Prepping drums track...')
93
 
94
  num_prime_chords = 7
95
 
 
98
  for d in drums_score[:num_prime_chords]:
99
 
100
  outy.extend(d)
101
+
102
+ print('Generating...')
103
+
104
+ max_notes_per_chord=8,
105
+ num_samples=4,
106
+ num_memory_tokens = 4096,
107
+ temperature=1.0):
108
 
109
+ for i in range(num_prime_chords, len(drums_score)):
110
 
111
  outy.extend(drums_score[i])
112
 
113
  if i == num_prime_chords:
114
+ outy.append(256+12)
115
+
116
+ input_seq = outy[-num_memory_tokens:]
117
+
118
+ seq = copy.deepcopy(input_seq)
119
+
120
+ batch_value = 256
121
+
122
+ nc = 0
123
+
124
+ while batch_value > 255 and nc < max_notes_per_chord:
125
+
126
+ x = torch.tensor([seq] * num_samples, dtype=torch.long, device='cuda')
127
+
128
+ with ctx:
129
+ out = model.generate(x,
130
+ 1,
131
+ temperature=temperature,
132
+ return_prime=False,
133
+ verbose=False)
134
+
135
+ out1 = [o[0] for o in out.tolist() if o[0] > 255]
136
+
137
+ if not out1:
138
+ out1 = [-1]
139
+
140
+ batch_value = random.choice(out1)
141
+
142
+ if batch_value > 255:
143
+ seq.append(batch_value)
144
+
145
+ if batch_value > 383:
146
+ nc += 1
147
 
148
+ out = seq[len(input_seq):]
149
 
150
  outy.extend(out)
151