Spaces: Running on Zero

Upload app_onnx.py

app_onnx.py  +16 -26  CHANGED
@@ -29,6 +29,7 @@ def softmax(x, axis):
     exp_x_shifted = np.exp(x - x_max)
     return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)

+
 def sample_top_p_k(probs, p, k, generator=None):
     if generator is None:
         generator = np.random
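For context, sample_top_p_k combines nucleus (top-p) and top-k filtering before drawing a token. Its body sits outside this diff, so the following is only a generic 1-D sketch of the technique in NumPy, with illustrative names and defaults, not the function above:

    import numpy as np

    # Generic top-p / top-k sketch over a 1-D probability vector.
    # NOT the sample_top_p_k defined above (its body is outside this diff);
    # names and defaults here are illustrative.
    def sample_top_p_k_1d(probs, p=0.98, k=20, rng=None):
        rng = rng or np.random.default_rng()
        order = np.argsort(probs)[::-1]                  # highest probability first
        sorted_probs = probs[order]
        cum = np.cumsum(sorted_probs)
        keep = (cum - sorted_probs < p) & (np.arange(len(probs)) < k)
        keep[0] = True                                   # always keep the best token
        filtered = sorted_probs * keep
        filtered /= filtered.sum()
        return order[rng.choice(len(probs), p=filtered)]

    token = sample_top_p_k_1d(np.array([0.5, 0.3, 0.15, 0.05]))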
@@ -48,9 +49,10 @@ def sample_top_p_k(probs, p, k, generator=None):
     next_token = next_token.reshape(*shape[:-1])
     return next_token

+
 def apply_io_binding(model: rt.InferenceSession, inputs, outputs, batch_size, past_len, cur_len):
     io_binding = model.io_binding()
-    for input_ in
+    for input_ in model.get_inputs():
         name = input_.name
         if name.startswith("past_key_values"):
             present_name = name.replace("past_key_values", "present")
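The loop above walks every session input (pairing each "past_key_values" input with its "present" output) through ONNX Runtime IO binding. A minimal sketch of that mechanism, assuming a hypothetical model.onnx with a single input "x" and output "y" (placeholders, not the app's real model or tensor names):

    import numpy as np
    import onnxruntime as rt

    # Minimal IO-binding sketch; "model.onnx", "x" and "y" are placeholders.
    sess = rt.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
    io_binding = sess.io_binding()

    x = rt.OrtValue.ortvalue_from_numpy(np.zeros((1, 16), dtype=np.float32))
    io_binding.bind_ortvalue_input("x", x)   # bind an existing OrtValue as input
    io_binding.bind_output("y")              # let ONNX Runtime allocate the output

    sess.run_with_iobinding(io_binding)
    y = io_binding.get_outputs()[0].numpy()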
@@ -80,8 +82,7 @@ def apply_io_binding(model: rt.InferenceSession, inputs, outputs, batch_size, pa
     return io_binding

 def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
-             disable_patch_change=False, disable_control_change=False, disable_channels=None,
-             repetition_penalty=1.0, generator=None):
+             disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
     tokenizer = model[2]
     if disable_channels is not None:
         disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
@@ -106,7 +107,7 @@ def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98
         prompt = prompt[..., :max_token_seq]
         if prompt.shape[-1] < max_token_seq:
             prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
-
+                            mode="constant", constant_values=tokenizer.pad_id)
         input_tensor = prompt
     cur_len = input_tensor.shape[1]
     bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
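The new line pads the prompt's last axis up to max_token_seq with the tokenizer's pad id. A small self-contained example of that np.pad call, where pad_id = 0 stands in for tokenizer.pad_id:

    import numpy as np

    # Pad the token axis of a (batch, events, tokens) array up to max_token_seq.
    # pad_id = 0 is a stand-in for tokenizer.pad_id.
    max_token_seq, pad_id = 8, 0
    prompt = np.ones((1, 2, 5), dtype=np.int64)
    padded = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
                    mode="constant", constant_values=pad_id)
    print(padded.shape)  # (1, 2, 8)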
@@ -161,6 +162,7 @@ def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98
             mask = mask[:, None, :]
             x = next_token_seq
             if i != 0:
+                # cached
                 if i == 1:
                     hidden = np.zeros((batch_size, 0, emb_size), dtype=np.float32)
                     model1_inputs["hidden"] = rt.OrtValue.ortvalue_from_numpy(hidden, device_type=device)
@@ -176,16 +178,6 @@ def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98
             model[1].run_with_iobinding(io_binding)
             io_binding.synchronize_outputs()
             logits = model1_outputs["y"].numpy()
-
-            # Apply repetition penalty
-            if repetition_penalty != 1.0:
-                for b in range(batch_size):
-                    if not end[b]:
-                        prev_tokens = input_tensor[b, :cur_len].tolist()
-                        used_tokens = set(prev_tokens)
-                        for token in used_tokens:
-                            logits[b, :, token] /= repetition_penalty
-
             scores = softmax(logits / temp, -1) * mask
             samples = sample_top_p_k(scores, top_p, top_k, generator)
             if i == 0:
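This hunk removes the repetition_penalty option entirely. The deleted logic amounts to dividing the logits of every previously used token by the penalty before softmax; a standalone sketch of that idea for a single batch element, with illustrative names:

    import numpy as np

    # Standalone sketch of the removed logic: penalize tokens that already
    # appeared in the sequence by dividing their logits before softmax.
    def apply_repetition_penalty(logits, prev_tokens, repetition_penalty=1.2):
        logits = logits.copy()
        for token in set(prev_tokens):
            logits[token] /= repetition_penalty
        return logits

    penalized = apply_repetition_penalty(np.array([2.0, 1.0, 0.5]), prev_tokens=[0, 0, 2])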
@@ -204,8 +196,8 @@ def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98
                 break
         if next_token_seq.shape[1] < max_token_seq:
             next_token_seq = np.pad(next_token_seq,
-
-
+                                    ((0, 0), (0, max_token_seq - next_token_seq.shape[-1])),
+                                    mode="constant", constant_values=tokenizer.pad_id)
         next_token_seq = next_token_seq[:, None, :]
         input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
         past_len = cur_len
@@ -594,12 +586,10 @@ if __name__ == "__main__":
             input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
             input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.95)
             input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=20)
-            input_rep_penalty = gr.Slider(label="repetition penalty", minimum=1.0, maximum=2.0,
-                                          step=0.05, value=1.0)
             input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
             input_render_audio = gr.Checkbox(label="render audio after generation", value=True)
             example3 = gr.Examples([[1, 0.94, 128], [1, 0.98, 20], [1, 0.98, 12]],
-
+                                   [input_temp, input_top_p, input_top_k])
             run_btn = gr.Button("generate", variant="primary")
             # stop_btn = gr.Button("stop and output")
             output_midi_seq = gr.State()
@@ -615,13 +605,13 @@ if __name__ == "__main__":
             midi_outputs.append(output_midi)
             audio_outputs.append(output_audio)
         run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
-
-
-
-
-
-
-
+                                        input_continuation_select, input_instruments, input_drum_kit, input_bpm,
+                                        input_time_sig, input_key_sig, input_midi, input_midi_events,
+                                        input_reduce_cc_st, input_remap_track_channel,
+                                        input_add_default_instr, input_remove_empty_channels,
+                                        input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
+                                        input_top_k, input_allow_cc],
+                                  [output_midi_seq, output_continuation_state, input_seed, js_msg], queue=True)
         finish_run_event = run_event.then(fn=finish_run,
                                           inputs=[input_model, output_midi_seq],
                                           outputs=midi_outputs + [js_msg],
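The run_btn.click wiring above maps a long list of Gradio components to the run function's positional arguments and routes its return values to the listed outputs. A toy version of the same pattern, with toy components and a toy run function rather than the app's real signature:

    import gradio as gr

    # Toy version of the Button.click wiring pattern; the real app passes far
    # more inputs and returns a State plus other outputs.
    def run(temp, top_p, top_k):
        return f"temp={temp}, top_p={top_p}, top_k={top_k}"

    with gr.Blocks() as demo:
        input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
        input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.95)
        input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=20)
        result = gr.Textbox(label="result")
        run_btn = gr.Button("generate", variant="primary")
        run_btn.click(run, inputs=[input_temp, input_top_p, input_top_k], outputs=[result], queue=True)

    # demo.launch()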