asigalov61 commited on
Commit
0fd605a
·
verified ·
1 Parent(s): c8d891d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -43
app.py CHANGED
@@ -19,41 +19,7 @@ import TMIDIX
19
  import matplotlib.pyplot as plt
20
 
21
  in_space = os.getenv("SYSTEM") == "spaces"
22
-
23
- # =================================================================================================
24
-
25
- @spaces.GPU
26
- def generate_drums(notes_times,
27
- max_drums_limit = 8,
28
- num_memory_tokens = 4096,
29
- temperature=0.9):
30
-
31
- ctx = torch.amp.autocast(device_type='cuda', dtype=torch.float16)
32
-
33
- x = torch.tensor([notes_times] * 1, dtype=torch.long, device='cuda')
34
-
35
- o = 128
36
-
37
- ncount = 0
38
-
39
- while o > 127 and ncount < max_drums_limit:
40
- with ctx:
41
- out = model.generate(x[-num_memory_tokens:],
42
- 1,
43
- temperature=temperature,
44
- return_prime=False,
45
- verbose=False)
46
-
47
- o = out.tolist()[0][0]
48
-
49
- if 256 <= o < 384:
50
- ncount += 1
51
-
52
- if o > 127:
53
- x = torch.cat((x, out), 1)
54
-
55
- return x.tolist()[0][len(notes_times):]
56
-
57
  # =================================================================================================
58
 
59
  @spaces.GPU
@@ -150,15 +116,38 @@ def GenerateDrums(input_midi, input_num_tokens):
150
 
151
  output = []
152
 
 
 
 
 
153
  for c in comp_times[:input_num_tokens]:
154
- output.append(c)
155
-
156
- out = generate_drums(output,
157
- temperature=0.9,
158
- max_drums_limit=8,
159
- num_memory_tokens=4096
160
- )
161
- output.extend(out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  print('=' * 70)
164
  print('Done!')
 
19
  import matplotlib.pyplot as plt
20
 
21
  in_space = os.getenv("SYSTEM") == "spaces"
22
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # =================================================================================================
24
 
25
  @spaces.GPU
 
116
 
117
  output = []
118
 
119
+ temperature=0.9,
120
+ max_drums_limit=8,
121
+ num_memory_tokens=4096
122
+
123
  for c in comp_times[:input_num_tokens]:
124
+ output.append(c)
125
+
126
+ x = torch.tensor([output] * 1, dtype=torch.long, device=DEVICE)
127
+
128
+ o = 128
129
+
130
+ ncount = 0
131
+
132
+ while o > 127 and ncount < max_drums_limit:
133
+ with ctx:
134
+ out = model.generate(x[-num_memory_tokens:],
135
+ 1,
136
+ temperature=temperature,
137
+ return_prime=False,
138
+ verbose=False)
139
+
140
+ o = out.tolist()[0][0]
141
+
142
+ if 256 <= o < 384:
143
+ ncount += 1
144
+
145
+ if o > 127:
146
+ x = torch.cat((x, out), 1)
147
+
148
+ x_output = x.tolist()[0][len(notes_times):]
149
+
150
+ output.extend(x_output1)
151
 
152
  print('=' * 70)
153
  print('Done!')