Hmjz100 committed
Commit 1db8186
Parent: 37af952

Update app.py

Files changed (1)
app.py +32 -35
app.py CHANGED
@@ -136,22 +136,21 @@ class InferenceModel(object):
   @property
   def input_shapes(self):
     return {
-        'encoder_input_tokens': (self.batch_size, self.inputs_length),
-        'decoder_input_tokens': (self.batch_size, self.outputs_length)
+        'encoder_input_tokens': (self.batch_size, self.inputs_length),
+        'decoder_input_tokens': (self.batch_size, self.outputs_length)
     }
 
   def _parse_gin(self, gin_files):
     """Parse the gin files used to train the model."""
     print(f"[{current_time()}] Log: parsing gin files")
     gin_bindings = [
-        'from __gin__ import dynamic_registration',
-        'from mt3 import vocabularies',
-        'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
+        'from __gin__ import dynamic_registration',
+        'from mt3 import vocabularies',
+        'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
     ]
     with gin.unlock_config():
-      gin.parse_config_files_and_bindings(
-          gin_files, gin_bindings, finalize_config=False)
+      gin.parse_config_files_and_bindings(gin_files, gin_bindings, finalize_config=False)
 
   def _load_model(self):
     """Load the T5X `Model` after parsing the training gin config."""
@@ -159,11 +158,11 @@ class InferenceModel(object):
     model_config = gin.get_configurable(network.T5Config)()
     module = network.Transformer(config=model_config)
     return models.ContinuousInputsEncoderDecoderModel(
-        module=module,
-        input_vocabulary=self.output_features['inputs'].vocabulary,
-        output_vocabulary=self.output_features['targets'].vocabulary,
-        optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
-        input_depth=spectrograms.input_depth(self.spectrogram_config))
+        module=module,
+        input_vocabulary=self.output_features['inputs'].vocabulary,
+        output_vocabulary=self.output_features['targets'].vocabulary,
+        optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
+        input_depth=spectrograms.input_depth(self.spectrogram_config))
 
 
   def restore_from_checkpoint(self, checkpoint_path):
@@ -176,33 +175,31 @@ class InferenceModel(object):
         partitioner=self.partitioner)
 
     restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
-        path=checkpoint_path, mode='specific', dtype='float32')
+        path=checkpoint_path, mode='specific', dtype='float32')
 
     train_state_axes = train_state_initializer.train_state_axes
     self._predict_fn = self._get_predict_fn(train_state_axes)
     self._train_state = train_state_initializer.from_checkpoint_or_scratch(
-        [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
+        [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
 
   @functools.lru_cache()
   def _get_predict_fn(self, train_state_axes):
     """Generate a partitioned prediction function for decoding."""
     print(f"[{current_time()}] Log: generating the prediction function used for decoding")
     def partial_predict_fn(params, batch, decode_rng):
-      return self.model.predict_batch_with_aux(
-          params, batch, decoder_params={'decode_rng': None})
+      return self.model.predict_batch_with_aux(params, batch, decoder_params={'decode_rng': None})
     return self.partitioner.partition(
-        partial_predict_fn,
-        in_axis_resources=(
-            train_state_axes.params,
-            t5x.partitioning.PartitionSpec('data',), None),
-        out_axis_resources=t5x.partitioning.PartitionSpec('data',)
+        partial_predict_fn,
+        in_axis_resources=(
+            train_state_axes.params,
+            t5x.partitioning.PartitionSpec('data',), None),
+        out_axis_resources=t5x.partitioning.PartitionSpec('data',)
     )
 
   def predict_tokens(self, batch, seed=0):
     """Predict tokens from a preprocessed dataset batch."""
     print(f"[{current_time()}] Run: predicting note sequences from the preprocessed dataset")
-    prediction, _ = self._predict_fn(
-        self._train_state.params, batch, jax.random.PRNGKey(seed))
+    prediction, _ = self._predict_fn(self._train_state.params, batch, jax.random.PRNGKey(seed))
     return self.vocabulary.decode_tf(prediction).numpy()
 
   def __call__(self, audio):
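Two notes on the hunk above: `@functools.lru_cache()` means the partitioned predict function is built only once per `train_state_axes`, and the `seed` argument of `predict_tokens` feeds `jax.random.PRNGKey(seed)`, so decoding is reproducible for a fixed seed. A small stand-alone sketch of that seeding pattern; `fake_predict` is a made-up stand-in for `self._predict_fn`, and only `jax.random.PRNGKey` comes from the code above.

import jax

def fake_predict(params, batch, rng):
  # Stand-in for self._predict_fn: just draws from rng to show determinism.
  return jax.random.uniform(rng, (3,))

a = fake_predict(None, None, jax.random.PRNGKey(0))
b = fake_predict(None, None, jax.random.PRNGKey(0))
print(bool((a == b).all()))  # True: same seed, same decode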
@@ -255,16 +252,16 @@ self._train_state.params, batch, jax.random.PRNGKey(seed))
   def preprocess(self, ds):
     pp_chain = [
         functools.partial(
-            t5.data.preprocessors.split_tokens_to_inputs_length,
-            sequence_length=self.sequence_length,
-            output_features=self.output_features,
-            feature_key='inputs',
-            additional_feature_keys=['input_times']),
+            t5.data.preprocessors.split_tokens_to_inputs_length,
+            sequence_length=self.sequence_length,
+            output_features=self.output_features,
+            feature_key='inputs',
+            additional_feature_keys=['input_times']),
         # Cache occurs here during training.
         preprocessors.add_dummy_targets,
         functools.partial(
-            preprocessors.compute_spectrograms,
-            spectrogram_config=self.spectrogram_config)
+            preprocessors.compute_spectrograms,
+            spectrogram_config=self.spectrogram_config)
     ]
     for pp in pp_chain:
       ds = pp(ds)
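The `pp_chain` above is simply a list of dataset-to-dataset functions applied in order. A generic sketch of that pattern, with plain Python lists standing in for the `tf.data.Dataset` and hypothetical preprocessors in place of the MT3/T5 ones:

import functools

def scale(ds, factor):
  # Hypothetical preprocessor: multiplies every element.
  return [x * factor for x in ds]

def add_offset(ds, offset):
  # Hypothetical preprocessor: shifts every element.
  return [x + offset for x in ds]

pp_chain = [
    functools.partial(scale, factor=2),
    functools.partial(add_offset, offset=1),
]

ds = [1, 2, 3]
for pp in pp_chain:
  ds = pp(ds)
print(ds)  # [3, 5, 7]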
@@ -276,10 +273,10 @@ self._train_state.params, batch, jax.random.PRNGKey(seed))
     # Round down to the nearest symbolic time step.
     start_time -= start_time % (1 / self.codec.steps_per_second)
     return {
-        'est_tokens': tokens,
-        'start_time': start_time,
-        # The internal MT3 code expects raw inputs, which are not used here.
-        'raw_inputs': []
+        'est_tokens': tokens,
+        'start_time': start_time,
+        # The internal MT3 code expects raw inputs, which are not used here.
+        'raw_inputs': []
     }
 
   @staticmethod
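The `start_time` adjustment in the last hunk floors the time onto the codec's step grid. A worked example, assuming a hypothetical codec with 100 steps per second (the real value comes from `self.codec.steps_per_second`):

steps_per_second = 100          # assumed value for illustration
start_time = 3.5678
start_time -= start_time % (1 / steps_per_second)
print(start_time)               # ~3.56: floored onto the 10 ms grid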
 