Hmjz100 commited on
Commit
f9a5f59
·
verified ·
1 Parent(s): 804bc89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +169 -76
app.py CHANGED
@@ -133,81 +133,174 @@ class InferenceModel(object):
133
  self.restore_from_checkpoint(checkpoint_path)
134
 
135
  @property
136
- def input_shapes(self):
137
- return {
138
- 'encoder_input_tokens': (self.batch_size, self.inputs_length),
139
- 'decoder_input_tokens': (self.batch_size, self.outputs_length)
140
- }
141
-
142
- def _parse_gin(self, gin_files):
143
- print(f"[{current_time()}] 日志:解析 gin 文件")
144
- gin_bindings = [
145
- 'from __gin__ import dynamic_registration',
146
- 'from mt3 import vocabularies',
147
- 'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
148
- 'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
149
- ]
150
- with gin.unlock_config():
151
- gin.parse_config_files_and_bindings(
152
- gin_files, gin_bindings, finalize_config=False)
153
-
154
- def _load_model(self):
155
- print(f"[{current_time()}] 日志:加载 T5X 模型")
156
- model_config = gin.get_configurable(network.T5Config)()
157
- module = network.Transformer(config=model_config)
158
- return models.ContinuousInputsEncoderDecoderModel(
159
- module=module,
160
- input_vocabulary=self.output_features['inputs'].vocabulary,
161
- output_vocabulary=self.output_features['targets'].vocabulary,
162
- optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
163
- input_depth=spectrograms.input_depth(self.spectrogram_config))
164
-
165
-
166
- def restore_from_checkpoint(self, checkpoint_path):
167
- print(f"[{current_time()}] 日志:从检查点恢复训练状态")
168
- train_state_initializer = t5x.utils.TrainStateInitializer(
169
- optimizer_def=self.model.optimizer_def,
170
- init_fn=self.model.get_initial_variables,
171
- input_shapes=self.input_shapes,
172
- partitioner=self.partitioner)
173
-
174
- restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
175
- path=checkpoint_path, mode='specific', dtype='float32')
176
-
177
- train_state_axes = train_state_initializer.train_state_axes
178
- self._predict_fn = self._get_predict_fn(train_state_axes)
179
- self._train_state = train_state_initializer.from_checkpoint_or_scratch(
180
- [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
181
-
182
- @functools.lru_cache()
183
- def _get_predict_fn(self, train_state_axes):
184
- print(f"[{current_time()}] 日志:生成用于解码的预测函数")
185
- def partial_predict_fn(params, batch, decode_rng):
186
- return self.model.predict_batch_with_aux(
187
- params, batch, decoder_params={'decode_rng': None})
188
- return self.partitioner.partition(
189
- partial_predict_fn,
190
- in_axis_resources=(
191
- train_state_axes.params,
192
- t5x.partitioning.PartitionSpec('data',), None),
193
- out_axis_resources=t5x.partitioning.PartitionSpec('data',)
194
- )
195
-
196
- def predict_tokens(self, batch, seed=0):
197
- print(f"[{current_time()}] 运行:从预处理数据集中预测音符序列")
198
- prediction, _ = self._predict_fn(
199
- self._train_state.params, batch, jax.random.PRNGKey(seed))
200
- return self.vocabulary.decode_tf(prediction).numpy()
201
-
202
- def __call__(self, audio):
203
- filename = os.path.basename(audio) # 获取输入文件的文件名
204
- print(f"[{current_time()}] 运行:输入文件: {filename}")
205
- with open(audio, 'rb') as fd:
206
- contents = fd.read()
207
- audio = upload_audio(contents,sample_rate=16000)
208
- est_ns = inference_model(audio)
209
- note_seq.sequence_proto_to_midi_file(est_ns, './transcribed.mid')
210
- return './transcribed.mid'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  title = "MT3"
213
  description = "MT3:多任务多音轨音乐转录的 Gradio 演示。要使用它,只需上传音频文件,或点击示例以查看效果。更多信息请参阅下面的链接。"
@@ -224,4 +317,4 @@ gr.Interface(
224
  description=description,
225
  article=article,
226
  examples=examples
227
- ).launch(server_port=7861)
 
133
  self.restore_from_checkpoint(checkpoint_path)
134
 
135
  @property
136
+ def input_shapes(self):
137
+ return {
138
+ 'encoder_input_tokens': (self.batch_size, self.inputs_length),
139
+ 'decoder_input_tokens': (self.batch_size, self.outputs_length)
140
+ }
141
+
142
+ def _parse_gin(self, gin_files):
143
+ """解析用于训练模型的 gin 文件。"""
144
+ print(f"[{current_time()}] 日志:解析 gin 文件")
145
+ gin_bindings = [
146
+ 'from __gin__ import dynamic_registration',
147
+ 'from mt3 import vocabularies',
148
+ 'VOCAB_CONFIG=@vocabularies.VocabularyConfig()',
149
+ 'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
150
+ ]
151
+ with gin.unlock_config():
152
+ gin.parse_config_files_and_bindings(
153
+ gin_files, gin_bindings, finalize_config=False)
154
+
155
+ def _load_model(self):
156
+ """在解析训练 gin 配置后加载 T5X `Model`。"""
157
+ print(f"[{current_time()}] 日志:加载 T5X 模型")
158
+ model_config = gin.get_configurable(network.T5Config)()
159
+ module = network.Transformer(config=model_config)
160
+ return models.ContinuousInputsEncoderDecoderModel(
161
+ module=module,
162
+ input_vocabulary=self.output_features['inputs'].vocabulary,
163
+ output_vocabulary=self.output_features['targets'].vocabulary,
164
+ optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
165
+ input_depth=spectrograms.input_depth(self.spectrogram_config))
166
+
167
+
168
+ def restore_from_checkpoint(self, checkpoint_path):
169
+ """从检查点中恢复训练状态,重置 self._predict_fn()。"""
170
+ print(f"[{current_time()}] 日志:从检查点恢复训练状态")
171
+ train_state_initializer = t5x.utils.TrainStateInitializer(
172
+ optimizer_def=self.model.optimizer_def,
173
+ init_fn=self.model.get_initial_variables,
174
+ input_shapes=self.input_shapes,
175
+ partitioner=self.partitioner)
176
+
177
+ restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
178
+ path=checkpoint_path, mode='specific', dtype='float32')
179
+
180
+ train_state_axes = train_state_initializer.train_state_axes
181
+ self._predict_fn = self._get_predict_fn(train_state_axes)
182
+ self._train_state = train_state_initializer.from_checkpoint_or_scratch(
183
+ [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
184
+
185
+ @functools.lru_cache()
186
+ def _get_predict_fn(self, train_state_axes):
187
+ """生成一个分区的预测函数用于解码。"""
188
+ print(f"[{current_time()}] 日志:生成用于解码的预测函数")
189
+ def partial_predict_fn(params, batch, decode_rng):
190
+ return self.model.predict_batch_with_aux(
191
+ params, batch, decoder_params={'decode_rng': None})
192
+ return self.partitioner.partition(
193
+ partial_predict_fn,
194
+ in_axis_resources=(
195
+ train_state_axes.params,
196
+ t5x.partitioning.PartitionSpec('data',), None),
197
+ out_axis_resources=t5x.partitioning.PartitionSpec('data',)
198
+ )
199
+
200
+ def predict_tokens(self, batch, seed=0):
201
+ """从预处理的数据集批次中预测 tokens。"""
202
+ print(f"[{current_time()}] 运行:从预处理数据集中预测音符序列")
203
+ prediction, _ = self._predict_fn(
204
+ self._train_state.params, batch, jax.random.PRNGKey(seed))
205
+ return self.vocabulary.decode_tf(prediction).numpy()
206
+
207
+ def __call__(self, audio):
208
+ """从音频样本推断出音符序列。
209
+
210
+ 参数:
211
+ audio:16kHz 的单个音频样本的 1 维 numpy 数组。
212
+ 返回:
213
+ 转录音频的音符序列。
214
+ """
215
+ print(f"[{current_time()}] 运行:从音频样本中推断音符序列")
216
+ ds = self.audio_to_dataset(audio)
217
+ ds = self.preprocess(ds)
218
+
219
+ model_ds = self.model.FEATURE_CONVERTER_CLS(pack=False)(
220
+ ds, task_feature_lengths=self.sequence_length)
221
+ model_ds = model_ds.batch(self.batch_size)
222
+
223
+ inferences = (tokens for batch in model_ds.as_numpy_iterator()
224
+ for tokens in self.predict_tokens(batch))
225
+
226
+ predictions = []
227
+ for example, tokens in zip(ds.as_numpy_iterator(), inferences):
228
+ predictions.append(self.postprocess(tokens, example))
229
+
230
+ result = metrics_utils.event_predictions_to_ns(
231
+ predictions, codec=self.codec, encoding_spec=self.encoding_spec)
232
+ return result['est_ns']
233
+
234
+ def audio_to_dataset(self, audio):
235
+ """从输入音频创建一个包含频谱图的 TF Dataset。"""
236
+ print(f"[{current_time()}] 运行:从音频创建包含频谱图的 TF Dataset")
237
+ frames, frame_times = self._audio_to_frames(audio)
238
+ return tf.data.Dataset.from_tensors({
239
+ 'inputs': frames,
240
+ 'input_times': frame_times,
241
+ })
242
+
243
+ def _audio_to_frames(self, audio):
244
+ """从音频计算频谱图帧。"""
245
+ print(f"[{current_time()}] 运行:从音频计算频谱图帧")
246
+ frame_size = self.spectrogram_config.hop_width
247
+ padding = [0, frame_size - len(audio) % frame_size]
248
+ audio = np.pad(audio, padding, mode='constant')
249
+ frames = spectrograms.split_audio(audio, self.spectrogram_config)
250
+ num_frames = len(audio) // frame_size
251
+ times = np.arange(num_frames) / self.spectrogram_config.frames_per_second
252
+ return frames, times
253
+
254
+ def preprocess(self, ds):
255
+ pp_chain = [
256
+ functools.partial(
257
+ t5.data.preprocessors.split_tokens_to_inputs_length,
258
+ sequence_length=self.sequence_length,
259
+ output_features=self.output_features,
260
+ feature_key='inputs',
261
+ additional_feature_keys=['input_times']),
262
+ # 在训练期间进行缓存。
263
+ preprocessors.add_dummy_targets,
264
+ functools.partial(
265
+ preprocessors.compute_spectrograms,
266
+ spectrogram_config=self.spectrogram_config)
267
+ ]
268
+ for pp in pp_chain:
269
+ ds = pp(ds)
270
+ return ds
271
+
272
+ def postprocess(self, tokens, example):
273
+ tokens = self._trim_eos(tokens)
274
+ start_time = example['input_times'][0]
275
+ # 向下取整到最接近的符号化时间步。
276
+ start_time -= start_time % (1 / self.codec.steps_per_second)
277
+ return {
278
+ 'est_tokens': tokens,
279
+ 'start_time': start_time,
280
+ # 内部 MT3 代码期望原始输入,这里不使用。
281
+ 'raw_inputs': []
282
+ }
283
+
284
+ @staticmethod
285
+ def _trim_eos(tokens):
286
+ tokens = np.array(tokens, np.int32)
287
+ if vocabularies.DECODED_EOS_ID in tokens:
288
+ tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
289
+ return tokens
290
+
291
+
292
+ inference_model = InferenceModel('/home/user/app/checkpoints/mt3/', 'mt3')
293
+
294
+
295
+ def inference(audio):
296
+ filename = os.path.basename(audio) # 获取输入文件的文件名
297
+ print(f"[{current_time()}] 运行:输入文件: {filename}")
298
+ with open(audio, 'rb') as fd:
299
+ contents = fd.read()
300
+ audio = upload_audio(contents,sample_rate=16000)
301
+ est_ns = inference_model(audio)
302
+ note_seq.sequence_proto_to_midi_file(est_ns, './transcribed.mid')
303
+ return './transcribed.mid'
304
 
305
  title = "MT3"
306
  description = "MT3:多任务多音轨音乐转录的 Gradio 演示。要使用它,只需上传音频文件,或点击示例以查看效果。更多信息请参阅下面的链接。"
 
317
  description=description,
318
  article=article,
319
  examples=examples
320
+ ).launch(server_port=7861)