breadlicker45 committed on
Commit
77ba504
·
verified ·
1 Parent(s): 5d8ea59

Upload app_onnx.py

Browse files
Files changed (1) hide show
  1. app_onnx.py +16 -26
app_onnx.py CHANGED
@@ -29,6 +29,7 @@ def softmax(x, axis):
29
  exp_x_shifted = np.exp(x - x_max)
30
  return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
31
 
 
32
  def sample_top_p_k(probs, p, k, generator=None):
33
  if generator is None:
34
  generator = np.random
@@ -48,9 +49,10 @@ def sample_top_p_k(probs, p, k, generator=None):
48
  next_token = next_token.reshape(*shape[:-1])
49
  return next_token
50
 
 
51
  def apply_io_binding(model: rt.InferenceSession, inputs, outputs, batch_size, past_len, cur_len):
52
  io_binding = model.io_binding()
53
- for input_ in model.get_inputs():
54
  name = input_.name
55
  if name.startswith("past_key_values"):
56
  present_name = name.replace("past_key_values", "present")
@@ -80,8 +82,7 @@ def apply_io_binding(model: rt.InferenceSession, inputs, outputs, batch_size, pa
80
  return io_binding
81
 
82
  def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
83
- disable_patch_change=False, disable_control_change=False, disable_channels=None,
84
- repetition_penalty=1.0, generator=None):
85
  tokenizer = model[2]
86
  if disable_channels is not None:
87
  disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
@@ -106,7 +107,7 @@ def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98
106
  prompt = prompt[..., :max_token_seq]
107
  if prompt.shape[-1] < max_token_seq:
108
  prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
109
- mode="constant", constant_values=tokenizer.pad_id)
110
  input_tensor = prompt
111
  cur_len = input_tensor.shape[1]
112
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
@@ -161,6 +162,7 @@ def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98
161
  mask = mask[:, None, :]
162
  x = next_token_seq
163
  if i != 0:
 
164
  if i == 1:
165
  hidden = np.zeros((batch_size, 0, emb_size), dtype=np.float32)
166
  model1_inputs["hidden"] = rt.OrtValue.ortvalue_from_numpy(hidden, device_type=device)
@@ -176,16 +178,6 @@ def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98
176
  model[1].run_with_iobinding(io_binding)
177
  io_binding.synchronize_outputs()
178
  logits = model1_outputs["y"].numpy()
179
-
180
- # Apply repetition penalty
181
- if repetition_penalty != 1.0:
182
- for b in range(batch_size):
183
- if not end[b]:
184
- prev_tokens = input_tensor[b, :cur_len].tolist()
185
- used_tokens = set(prev_tokens)
186
- for token in used_tokens:
187
- logits[b, :, token] /= repetition_penalty
188
-
189
  scores = softmax(logits / temp, -1) * mask
190
  samples = sample_top_p_k(scores, top_p, top_k, generator)
191
  if i == 0:
@@ -204,8 +196,8 @@ def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98
204
  break
205
  if next_token_seq.shape[1] < max_token_seq:
206
  next_token_seq = np.pad(next_token_seq,
207
- ((0, 0), (0, max_token_seq - next_token_seq.shape[-1])),
208
- mode="constant", constant_values=tokenizer.pad_id)
209
  next_token_seq = next_token_seq[:, None, :]
210
  input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
211
  past_len = cur_len
@@ -594,12 +586,10 @@ if __name__ == "__main__":
594
  input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
595
  input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.95)
596
  input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=20)
597
- input_rep_penalty = gr.Slider(label="repetition penalty", minimum=1.0, maximum=2.0,
598
- step=0.05, value=1.0)
599
  input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
600
  input_render_audio = gr.Checkbox(label="render audio after generation", value=True)
601
  example3 = gr.Examples([[1, 0.94, 128], [1, 0.98, 20], [1, 0.98, 12]],
602
- [input_temp, input_top_p, input_top_k])
603
  run_btn = gr.Button("generate", variant="primary")
604
  # stop_btn = gr.Button("stop and output")
605
  output_midi_seq = gr.State()
@@ -615,13 +605,13 @@ if __name__ == "__main__":
615
  midi_outputs.append(output_midi)
616
  audio_outputs.append(output_audio)
617
  run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
618
- input_continuation_select, input_instruments, input_drum_kit, input_bpm,
619
- input_time_sig, input_key_sig, input_midi, input_midi_events,
620
- input_reduce_cc_st, input_remap_track_channel,
621
- input_add_default_instr, input_remove_empty_channels,
622
- input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
623
- input_top_k, input_rep_penalty, input_allow_cc],
624
- [output_midi_seq, output_continuation_state, input_seed, js_msg], queue=True)
625
  finish_run_event = run_event.then(fn=finish_run,
626
  inputs=[input_model, output_midi_seq],
627
  outputs=midi_outputs + [js_msg],
 
29
  exp_x_shifted = np.exp(x - x_max)
30
  return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
31
 
32
+
33
  def sample_top_p_k(probs, p, k, generator=None):
34
  if generator is None:
35
  generator = np.random
 
49
  next_token = next_token.reshape(*shape[:-1])
50
  return next_token
51
 
52
+
53
  def apply_io_binding(model: rt.InferenceSession, inputs, outputs, batch_size, past_len, cur_len):
54
  io_binding = model.io_binding()
55
+ for input_ in model.get_inputs():
56
  name = input_.name
57
  if name.startswith("past_key_values"):
58
  present_name = name.replace("past_key_values", "present")
 
82
  return io_binding
83
 
84
  def generate(model, prompt=None, batch_size=1, max_len=512, temp=1.0, top_p=0.98, top_k=20,
85
+ disable_patch_change=False, disable_control_change=False, disable_channels=None, generator=None):
 
86
  tokenizer = model[2]
87
  if disable_channels is not None:
88
  disable_channels = [tokenizer.parameter_ids["channel"][c] for c in disable_channels]
 
107
  prompt = prompt[..., :max_token_seq]
108
  if prompt.shape[-1] < max_token_seq:
109
  prompt = np.pad(prompt, ((0, 0), (0, 0), (0, max_token_seq - prompt.shape[-1])),
110
+ mode="constant", constant_values=tokenizer.pad_id)
111
  input_tensor = prompt
112
  cur_len = input_tensor.shape[1]
113
  bar = tqdm.tqdm(desc="generating", total=max_len - cur_len, disable=in_space)
 
162
  mask = mask[:, None, :]
163
  x = next_token_seq
164
  if i != 0:
165
+ # cached
166
  if i == 1:
167
  hidden = np.zeros((batch_size, 0, emb_size), dtype=np.float32)
168
  model1_inputs["hidden"] = rt.OrtValue.ortvalue_from_numpy(hidden, device_type=device)
 
178
  model[1].run_with_iobinding(io_binding)
179
  io_binding.synchronize_outputs()
180
  logits = model1_outputs["y"].numpy()
 
 
 
 
 
 
 
 
 
 
181
  scores = softmax(logits / temp, -1) * mask
182
  samples = sample_top_p_k(scores, top_p, top_k, generator)
183
  if i == 0:
 
196
  break
197
  if next_token_seq.shape[1] < max_token_seq:
198
  next_token_seq = np.pad(next_token_seq,
199
+ ((0, 0), (0, max_token_seq - next_token_seq.shape[-1])),
200
+ mode="constant", constant_values=tokenizer.pad_id)
201
  next_token_seq = next_token_seq[:, None, :]
202
  input_tensor = np.concatenate([input_tensor, next_token_seq], axis=1)
203
  past_len = cur_len
 
586
  input_temp = gr.Slider(label="temperature", minimum=0.1, maximum=1.2, step=0.01, value=1)
587
  input_top_p = gr.Slider(label="top p", minimum=0.1, maximum=1, step=0.01, value=0.95)
588
  input_top_k = gr.Slider(label="top k", minimum=1, maximum=128, step=1, value=20)
 
 
589
  input_allow_cc = gr.Checkbox(label="allow midi cc event", value=True)
590
  input_render_audio = gr.Checkbox(label="render audio after generation", value=True)
591
  example3 = gr.Examples([[1, 0.94, 128], [1, 0.98, 20], [1, 0.98, 12]],
592
+ [input_temp, input_top_p, input_top_k])
593
  run_btn = gr.Button("generate", variant="primary")
594
  # stop_btn = gr.Button("stop and output")
595
  output_midi_seq = gr.State()
 
605
  midi_outputs.append(output_midi)
606
  audio_outputs.append(output_audio)
607
  run_event = run_btn.click(run, [input_model, tab_select, output_midi_seq, output_continuation_state,
608
+ input_continuation_select, input_instruments, input_drum_kit, input_bpm,
609
+ input_time_sig, input_key_sig, input_midi, input_midi_events,
610
+ input_reduce_cc_st, input_remap_track_channel,
611
+ input_add_default_instr, input_remove_empty_channels,
612
+ input_seed, input_seed_rand, input_gen_events, input_temp, input_top_p,
613
+ input_top_k, input_allow_cc],
614
+ [output_midi_seq, output_continuation_state, input_seed, js_msg], queue=True)
615
  finish_run_event = run_event.then(fn=finish_run,
616
  inputs=[input_model, output_midi_seq],
617
  outputs=midi_outputs + [js_msg],