Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -136,22 +136,21 @@ class InferenceModel(object):
|
|
136 |
@property
|
137 |
def input_shapes(self):
|
138 |
return {
|
139 |
-
|
140 |
-
|
141 |
}
|
142 |
|
143 |
def _parse_gin(self, gin_files):
|
144 |
"""解析用于训练模型的 gin 文件。"""
|
145 |
print(f"[{current_time()}] 日志:解析 gin 文件")
|
146 |
gin_bindings = [
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
]
|
152 |
with gin.unlock_config():
|
153 |
-
gin.parse_config_files_and_bindings(
|
154 |
-
gin_files, gin_bindings, finalize_config=False)
|
155 |
|
156 |
def _load_model(self):
|
157 |
"""在解析训练 gin 配置后加载 T5X `Model`。"""
|
@@ -159,11 +158,11 @@ class InferenceModel(object):
|
|
159 |
model_config = gin.get_configurable(network.T5Config)()
|
160 |
module = network.Transformer(config=model_config)
|
161 |
return models.ContinuousInputsEncoderDecoderModel(
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
|
168 |
|
169 |
def restore_from_checkpoint(self, checkpoint_path):
|
@@ -176,33 +175,31 @@ class InferenceModel(object):
|
|
176 |
partitioner=self.partitioner)
|
177 |
|
178 |
restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
|
179 |
-
|
180 |
|
181 |
train_state_axes = train_state_initializer.train_state_axes
|
182 |
self._predict_fn = self._get_predict_fn(train_state_axes)
|
183 |
self._train_state = train_state_initializer.from_checkpoint_or_scratch(
|
184 |
-
|
185 |
|
186 |
@functools.lru_cache()
|
187 |
def _get_predict_fn(self, train_state_axes):
|
188 |
"""生成一个分区的预测函数用于解码。"""
|
189 |
print(f"[{current_time()}] 日志:生成用于解码的预测函数")
|
190 |
def partial_predict_fn(params, batch, decode_rng):
|
191 |
-
return self.model.predict_batch_with_aux(
|
192 |
-
params, batch, decoder_params={'decode_rng': None})
|
193 |
return self.partitioner.partition(
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
)
|
200 |
|
201 |
def predict_tokens(self, batch, seed=0):
|
202 |
"""从预处理的数据集批次中预测 tokens。"""
|
203 |
print(f"[{current_time()}] 运行:从预处理数据集中预测音符序列")
|
204 |
-
prediction, _ = self._predict_fn(
|
205 |
-
self._train_state.params, batch, jax.random.PRNGKey(seed))
|
206 |
return self.vocabulary.decode_tf(prediction).numpy()
|
207 |
|
208 |
def __call__(self, audio):
|
@@ -255,16 +252,16 @@ self._train_state.params, batch, jax.random.PRNGKey(seed))
|
|
255 |
def preprocess(self, ds):
|
256 |
pp_chain = [
|
257 |
functools.partial(
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
# 在训练期间进行缓存。
|
264 |
preprocessors.add_dummy_targets,
|
265 |
functools.partial(
|
266 |
-
|
267 |
-
|
268 |
]
|
269 |
for pp in pp_chain:
|
270 |
ds = pp(ds)
|
@@ -276,10 +273,10 @@ self._train_state.params, batch, jax.random.PRNGKey(seed))
|
|
276 |
# 向下取整到最接近的符号化时间步。
|
277 |
start_time -= start_time % (1 / self.codec.steps_per_second)
|
278 |
return {
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
}
|
284 |
|
285 |
@staticmethod
|
|
|
136 |
@property
|
137 |
def input_shapes(self):
|
138 |
return {
|
139 |
+
'encoder_input_tokens': (self.batch_size, self.inputs_length),
|
140 |
+
'decoder_input_tokens': (self.batch_size, self.outputs_length)
|
141 |
}
|
142 |
|
143 |
def _parse_gin(self, gin_files):
|
144 |
"""解析用于训练模型的 gin 文件。"""
|
145 |
print(f"[{current_time()}] 日志:解析 gin 文件")
|
146 |
gin_bindings = [
|
147 |
+
'from __gin__ import dynamic_registration',
|
148 |
+
'from mt3 import vocabularies',
|
149 |
+
'[email protected]()',
|
150 |
+
'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
|
151 |
]
|
152 |
with gin.unlock_config():
|
153 |
+
gin.parse_config_files_and_bindings(gin_files, gin_bindings, finalize_config=False)
|
|
|
154 |
|
155 |
def _load_model(self):
|
156 |
"""在解析训练 gin 配置后加载 T5X `Model`。"""
|
|
|
158 |
model_config = gin.get_configurable(network.T5Config)()
|
159 |
module = network.Transformer(config=model_config)
|
160 |
return models.ContinuousInputsEncoderDecoderModel(
|
161 |
+
module=module,
|
162 |
+
input_vocabulary=self.output_features['inputs'].vocabulary,
|
163 |
+
output_vocabulary=self.output_features['targets'].vocabulary,
|
164 |
+
optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
|
165 |
+
input_depth=spectrograms.input_depth(self.spectrogram_config))
|
166 |
|
167 |
|
168 |
def restore_from_checkpoint(self, checkpoint_path):
|
|
|
175 |
partitioner=self.partitioner)
|
176 |
|
177 |
restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
|
178 |
+
path=checkpoint_path, mode='specific', dtype='float32')
|
179 |
|
180 |
train_state_axes = train_state_initializer.train_state_axes
|
181 |
self._predict_fn = self._get_predict_fn(train_state_axes)
|
182 |
self._train_state = train_state_initializer.from_checkpoint_or_scratch(
|
183 |
+
[restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
|
184 |
|
185 |
@functools.lru_cache()
|
186 |
def _get_predict_fn(self, train_state_axes):
|
187 |
"""生成一个分区的预测函数用于解码。"""
|
188 |
print(f"[{current_time()}] 日志:生成用于解码的预测函数")
|
189 |
def partial_predict_fn(params, batch, decode_rng):
|
190 |
+
return self.model.predict_batch_with_aux(params, batch, decoder_params={'decode_rng': None})
|
|
|
191 |
return self.partitioner.partition(
|
192 |
+
partial_predict_fn,
|
193 |
+
in_axis_resources=(
|
194 |
+
train_state_axes.params,
|
195 |
+
t5x.partitioning.PartitionSpec('data',), None),
|
196 |
+
out_axis_resources=t5x.partitioning.PartitionSpec('data',)
|
197 |
)
|
198 |
|
199 |
def predict_tokens(self, batch, seed=0):
|
200 |
"""从预处理的数据集批次中预测 tokens。"""
|
201 |
print(f"[{current_time()}] 运行:从预处理数据集中预测音符序列")
|
202 |
+
prediction, _ = self._predict_fn(self._train_state.params, batch, jax.random.PRNGKey(seed))
|
|
|
203 |
return self.vocabulary.decode_tf(prediction).numpy()
|
204 |
|
205 |
def __call__(self, audio):
|
|
|
252 |
def preprocess(self, ds):
|
253 |
pp_chain = [
|
254 |
functools.partial(
|
255 |
+
t5.data.preprocessors.split_tokens_to_inputs_length,
|
256 |
+
sequence_length=self.sequence_length,
|
257 |
+
output_features=self.output_features,
|
258 |
+
feature_key='inputs',
|
259 |
+
additional_feature_keys=['input_times']),
|
260 |
# 在训练期间进行缓存。
|
261 |
preprocessors.add_dummy_targets,
|
262 |
functools.partial(
|
263 |
+
preprocessors.compute_spectrograms,
|
264 |
+
spectrogram_config=self.spectrogram_config)
|
265 |
]
|
266 |
for pp in pp_chain:
|
267 |
ds = pp(ds)
|
|
|
273 |
# 向下取整到最接近的符号化时间步。
|
274 |
start_time -= start_time % (1 / self.codec.steps_per_second)
|
275 |
return {
|
276 |
+
'est_tokens': tokens,
|
277 |
+
'start_time': start_time,
|
278 |
+
# 内部 MT3 代码期望原始输入,这里不使用。
|
279 |
+
'raw_inputs': []
|
280 |
}
|
281 |
|
282 |
@staticmethod
|