Hmjz100 committed on
Commit
804bc89
·
verified ·
1 Parent(s): bfd6354

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -242
app.py CHANGED
@@ -5,14 +5,11 @@ import pytz
5
  from pathlib import Path
6
 
7
  def current_time():
8
- current = datetime.datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y年-%m月-%d日 %H时:%M分:%S秒")
9
- return current
10
 
11
  print(f"[{current_time()}] 开始部署空间...")
12
- """
13
- print(f"[{current_time()}] 日志:安装 - 必要包")
14
- os.system("pip install -r ./requirements.txt")
15
- """
16
  print(f"[{current_time()}] 日志:安装 - gsutil")
17
  os.system("pip install gsutil")
18
  print(f"[{current_time()}] 日志:Git - 克隆 Github 的 T5X 训练框架到当前目录")
@@ -30,7 +27,6 @@ os.system("pip install langchain")
30
  print(f"[{current_time()}] 日志:安装 - sentence-transformers")
31
  os.system("pip install sentence-transformers")
32
 
33
- # 安装 airio
34
  print(f"[{current_time()}] 日志:Git - 克隆 Github 的 airio 到当前目录")
35
  os.system("git clone --branch=main https://github.com/google/airio")
36
  print(f"[{current_time()}] 日志:文件 - 移动 airio 到当前目录并重命名为 airio_tmp 并删除")
@@ -38,7 +34,6 @@ os.system("mv airio airio_tmp; mv airio_tmp/* .; rm -r airio_tmp")
38
  print(f"[{current_time()}] 日志:Python - 使用 pip 安装 当前目录内的 Python 包")
39
  os.system("python3 -m pip install -e .")
40
 
41
- # 安装 mt3
42
  print(f"[{current_time()}] 日志:Git - 克隆 Github 的 MT3 模型到当前目录")
43
  os.system("git clone --branch=main https://github.com/magenta/mt3")
44
  print(f"[{current_time()}] 日志:文件 - 移动 mt3 到当前目录并重命名为 mt3_tmp 并删除")
@@ -46,33 +41,28 @@ os.system("mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp")
46
  print(f"[{current_time()}] 日志:Python - 使用 pip 从 storage.googleapis.com 安装 jax[cuda11_local] nest-asyncio pyfluidsynth")
47
  os.system("python3 -m pip install jax[cuda11_local] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
48
  print(f"[{current_time()}] 日志:安装 - 更新 jaxlib")
49
- os.system("pip install --upgrade jaxlib==0.4.27")
50
  print(f"[{current_time()}] 日志:Python - 使用 pip 安装 当前目录内的 Python 包")
51
  os.system("python3 -m pip install -e .")
52
  print(f"[{current_time()}] 日志:安装 - TensorFlow CPU")
53
  os.system("pip install tensorflow_cpu")
54
 
55
- # 复制检查点
56
  print(f"[{current_time()}] 日志:gsutil - 复制 MT3 检查点到当前目录")
57
  os.system("gsutil -q -m cp -r gs://mt3/checkpoints .")
58
 
59
- # 复制 soundfont 文件(原始文件来自 https://sites.google.com/site/soundfonts4u)
60
  print(f"[{current_time()}] 日志:gsutil - 复制 SoundFont 文件到当前目录")
61
  os.system("gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .")
62
 
63
- #@title 导入和定义
64
  print(f"[{current_time()}] 日志:导入 - 必要工具")
65
  import functools
66
  import os
67
  import numpy as np
68
  import tensorflow.compat.v2 as tf
69
 
70
- import functools
71
  import gin
72
  import jax
73
  import librosa
74
  import note_seq
75
-
76
  import seqio
77
  import t5
78
  import t5x
@@ -92,230 +82,132 @@ SAMPLE_RATE = 16000
92
  SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'
93
 
94
  def upload_audio(audio, sample_rate):
95
- return note_seq.audio_io.wav_data_to_samples_librosa(
96
- audio, sample_rate=sample_rate)
97
 
98
 
99
  print(f"[{current_time()}] 日志:开始包装模型...")
100
  class InferenceModel(object):
101
- """音乐转录的 T5X 模型包装器。"""
102
-
103
- def __init__(self, checkpoint_path, model_type='mt3'):
104
-
105
- # 模型常量。
106
- if model_type == 'ismir2021':
107
- num_velocity_bins = 127
108
- self.encoding_spec = note_sequences.NoteEncodingSpec
109
- self.inputs_length = 512
110
- elif model_type == 'mt3':
111
- num_velocity_bins = 1
112
- self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
113
- self.inputs_length = 256
114
- else:
115
- raise ValueError('unknown model_type: %s' % model_type)
116
-
117
- gin_files = ['/home/user/app/mt3/gin/model.gin',
118
- '/home/user/app/mt3/gin/mt3.gin']
119
-
120
- self.batch_size = 8
121
- self.outputs_length = 1024
122
- self.sequence_length = {'inputs': self.inputs_length,
123
- 'targets': self.outputs_length}
124
-
125
- self.partitioner = t5x.partitioning.PjitPartitioner(
126
- model_parallel_submesh=None, num_partitions=1)
127
-
128
- # 构建编解码器和词汇表。
129
- print(f"[{current_time()}] 日志:构建编解码器")
130
- self.spectrogram_config = spectrograms.SpectrogramConfig()
131
- self.codec = vocabularies.build_codec(
132
- vocab_config=vocabularies.VocabularyConfig(
133
- num_velocity_bins=num_velocity_bins)
134
- )
135
- self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
136
- self.output_features = {
137
- 'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
138
- 'targets': seqio.Feature(vocabulary=self.vocabulary),
139
- }
140
-
141
- # 创建 T5X 模型。
142
- print(f"[{current_time()}] 日志:创建 T5X 模型")
143
- self._parse_gin(gin_files)
144
- self.model = self._load_model()
145
-
146
- # 从检查点中恢复。
147
- print(f"[{current_time()}] 日志:恢复模型检查点")
148
- self.restore_from_checkpoint(checkpoint_path)
149
-
150
- @property
151
- def input_shapes(self):
152
- return {
153
- 'encoder_input_tokens': (self.batch_size, self.inputs_length),
154
- 'decoder_input_tokens': (self.batch_size, self.outputs_length)
155
- }
156
-
157
- def _parse_gin(self, gin_files):
158
- """解析用于训练模型的 gin 文件。"""
159
- print(f"[{current_time()}] 日志:解析 gin 文件")
160
- gin_bindings = [
161
- 'from __gin__ import dynamic_registration',
162
- 'from mt3 import vocabularies',
163
164
- 'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
165
- ]
166
- with gin.unlock_config():
167
- gin.parse_config_files_and_bindings(
168
- gin_files, gin_bindings, finalize_config=False)
169
-
170
- def _load_model(self):
171
- """在解析训练 gin 配置后加载 T5X `Model`。"""
172
- print(f"[{current_time()}] 日志:加载 T5X 模型")
173
- model_config = gin.get_configurable(network.T5Config)()
174
- module = network.Transformer(config=model_config)
175
- return models.ContinuousInputsEncoderDecoderModel(
176
- module=module,
177
- input_vocabulary=self.output_features['inputs'].vocabulary,
178
- output_vocabulary=self.output_features['targets'].vocabulary,
179
- optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
180
- input_depth=spectrograms.input_depth(self.spectrogram_config))
181
-
182
-
183
- def restore_from_checkpoint(self, checkpoint_path):
184
- """从检查点中恢复训练状态,重置 self._predict_fn()。"""
185
- print(f"[{current_time()}] 日志:从检查点恢复训练状态")
186
- train_state_initializer = t5x.utils.TrainStateInitializer(
187
- optimizer_def=self.model.optimizer_def,
188
- init_fn=self.model.get_initial_variables,
189
- input_shapes=self.input_shapes,
190
- partitioner=self.partitioner)
191
-
192
- restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
193
- path=checkpoint_path, mode='specific', dtype='float32')
194
-
195
- train_state_axes = train_state_initializer.train_state_axes
196
- self._predict_fn = self._get_predict_fn(train_state_axes)
197
- self._train_state = train_state_initializer.from_checkpoint_or_scratch(
198
- [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
199
-
200
- @functools.lru_cache()
201
- def _get_predict_fn(self, train_state_axes):
202
- """生成一个分区的预测函数用于解码。"""
203
- print(f"[{current_time()}] 日志:生成用于解码的预测函数")
204
- def partial_predict_fn(params, batch, decode_rng):
205
- return self.model.predict_batch_with_aux(
206
- params, batch, decoder_params={'decode_rng': None})
207
- return self.partitioner.partition(
208
- partial_predict_fn,
209
- in_axis_resources=(
210
- train_state_axes.params,
211
- t5x.partitioning.PartitionSpec('data',), None),
212
- out_axis_resources=t5x.partitioning.PartitionSpec('data',)
213
- )
214
-
215
- def predict_tokens(self, batch, seed=0):
216
- """从预处理的数据集批次中预测 tokens。"""
217
- print(f"[{current_time()}] 运行:从预处理数据集中预测音符序列")
218
- prediction, _ = self._predict_fn(
219
- self._train_state.params, batch, jax.random.PRNGKey(seed))
220
- return self.vocabulary.decode_tf(prediction).numpy()
221
-
222
- def __call__(self, audio):
223
- """从音频样本推断出音符序列。
224
-
225
- 参数:
226
- audio:16kHz 的单个音频样本的 1 维 numpy 数组。
227
- 返回:
228
- 转录音频的音符序列。
229
- """
230
- print(f"[{current_time()}] 运行:从音频样本中推断音符序列")
231
- ds = self.audio_to_dataset(audio)
232
- ds = self.preprocess(ds)
233
-
234
- model_ds = self.model.FEATURE_CONVERTER_CLS(pack=False)(
235
- ds, task_feature_lengths=self.sequence_length)
236
- model_ds = model_ds.batch(self.batch_size)
237
-
238
- inferences = (tokens for batch in model_ds.as_numpy_iterator()
239
- for tokens in self.predict_tokens(batch))
240
-
241
- predictions = []
242
- for example, tokens in zip(ds.as_numpy_iterator(), inferences):
243
- predictions.append(self.postprocess(tokens, example))
244
-
245
- result = metrics_utils.event_predictions_to_ns(
246
- predictions, codec=self.codec, encoding_spec=self.encoding_spec)
247
- return result['est_ns']
248
-
249
- def audio_to_dataset(self, audio):
250
- """从输入音频创建一个包含频谱图的 TF Dataset。"""
251
- print(f"[{current_time()}] 运行:从音频创建包含频谱图的 TF Dataset")
252
- frames, frame_times = self._audio_to_frames(audio)
253
- return tf.data.Dataset.from_tensors({
254
- 'inputs': frames,
255
- 'input_times': frame_times,
256
- })
257
-
258
- def _audio_to_frames(self, audio):
259
- """从音频计算频谱图帧。"""
260
- print(f"[{current_time()}] 运行:从音频计算频谱图帧")
261
- frame_size = self.spectrogram_config.hop_width
262
- padding = [0, frame_size - len(audio) % frame_size]
263
- audio = np.pad(audio, padding, mode='constant')
264
- frames = spectrograms.split_audio(audio, self.spectrogram_config)
265
- num_frames = len(audio) // frame_size
266
- times = np.arange(num_frames) / self.spectrogram_config.frames_per_second
267
- return frames, times
268
-
269
- def preprocess(self, ds):
270
- pp_chain = [
271
- functools.partial(
272
- t5.data.preprocessors.split_tokens_to_inputs_length,
273
- sequence_length=self.sequence_length,
274
- output_features=self.output_features,
275
- feature_key='inputs',
276
- additional_feature_keys=['input_times']),
277
- # 在训练期间进行缓存。
278
- preprocessors.add_dummy_targets,
279
- functools.partial(
280
- preprocessors.compute_spectrograms,
281
- spectrogram_config=self.spectrogram_config)
282
- ]
283
- for pp in pp_chain:
284
- ds = pp(ds)
285
- return ds
286
-
287
- def postprocess(self, tokens, example):
288
- tokens = self._trim_eos(tokens)
289
- start_time = example['input_times'][0]
290
- # 向下取整到最接近的符号化时间步。
291
- start_time -= start_time % (1 / self.codec.steps_per_second)
292
- return {
293
- 'est_tokens': tokens,
294
- 'start_time': start_time,
295
- # 内部 MT3 代码期望原始输入,这里不使用。
296
- 'raw_inputs': []
297
- }
298
-
299
- @staticmethod
300
- def _trim_eos(tokens):
301
- tokens = np.array(tokens, np.int32)
302
- if vocabularies.DECODED_EOS_ID in tokens:
303
- tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
304
- return tokens
305
-
306
-
307
- inference_model = InferenceModel('/home/user/app/checkpoints/mt3/', 'mt3')
308
-
309
-
310
- def inference(audio):
311
- filename = os.path.basename(audio) # 获取输入文件的文件名
312
- print(f"[{current_time()}] 运行:输入文件: {filename}")
313
- with open(audio, 'rb') as fd:
314
- contents = fd.read()
315
- audio = upload_audio(contents,sample_rate=16000)
316
- est_ns = inference_model(audio)
317
- note_seq.sequence_proto_to_midi_file(est_ns, './transcribed.mid')
318
- return './transcribed.mid'
319
 
320
  title = "MT3"
321
  description = "MT3:多任务多音轨音乐转录的 Gradio 演示。要使用它,只需上传音频文件,或点击示例以查看效果。更多信息请参阅下面的链接。"
@@ -325,11 +217,11 @@ article = "<p style='text-align: center'>出错了?试试把文件转换为MP3
325
  examples=[['canon.flac'], ['download.wav']]
326
 
327
  gr.Interface(
328
- inference,
329
- gr.Audio(type="filepath", label="输入"),
330
- outputs = gr.File(label="输出"),
331
- title=title,
332
- description=description,
333
- article=article,
334
- examples=examples
335
- ).launch()
 
5
  from pathlib import Path
6
 
7
  def current_time():
8
+ current = datetime.datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y年-%m月-%d日 %H时:%M分:%S秒")
9
+ return current
10
 
11
  print(f"[{current_time()}] 开始部署空间...")
12
+
 
 
 
13
  print(f"[{current_time()}] 日志:安装 - gsutil")
14
  os.system("pip install gsutil")
15
  print(f"[{current_time()}] 日志:Git - 克隆 Github 的 T5X 训练框架到当前目录")
 
27
  print(f"[{current_time()}] 日志:安装 - sentence-transformers")
28
  os.system("pip install sentence-transformers")
29
 
 
30
  print(f"[{current_time()}] 日志:Git - 克隆 Github 的 airio 到当前目录")
31
  os.system("git clone --branch=main https://github.com/google/airio")
32
  print(f"[{current_time()}] 日志:文件 - 移动 airio 到当前目录并重命名为 airio_tmp 并删除")
 
34
  print(f"[{current_time()}] 日志:Python - 使用 pip 安装 当前目录内的 Python 包")
35
  os.system("python3 -m pip install -e .")
36
 
 
37
  print(f"[{current_time()}] 日志:Git - 克隆 Github 的 MT3 模型到当前目录")
38
  os.system("git clone --branch=main https://github.com/magenta/mt3")
39
  print(f"[{current_time()}] 日志:文件 - 移动 mt3 到当前目录并重命名为 mt3_tmp 并删除")
 
41
  print(f"[{current_time()}] 日志:Python - 使用 pip 从 storage.googleapis.com 安装 jax[cuda11_local] nest-asyncio pyfluidsynth")
42
  os.system("python3 -m pip install jax[cuda11_local] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
43
  print(f"[{current_time()}] 日志:安装 - 更新 jaxlib")
44
+ os.system("pip install --upgrade jaxlib")
45
  print(f"[{current_time()}] 日志:Python - 使用 pip 安装 当前目录内的 Python 包")
46
  os.system("python3 -m pip install -e .")
47
  print(f"[{current_time()}] 日志:安装 - TensorFlow CPU")
48
  os.system("pip install tensorflow_cpu")
49
 
 
50
  print(f"[{current_time()}] 日志:gsutil - 复制 MT3 检查点到当前目录")
51
  os.system("gsutil -q -m cp -r gs://mt3/checkpoints .")
52
 
 
53
  print(f"[{current_time()}] 日志:gsutil - 复制 SoundFont 文件到当前目录")
54
  os.system("gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .")
55
 
 
56
  print(f"[{current_time()}] 日志:导入 - 必要工具")
57
  import functools
58
  import os
59
  import numpy as np
60
  import tensorflow.compat.v2 as tf
61
 
 
62
  import gin
63
  import jax
64
  import librosa
65
  import note_seq
 
66
  import seqio
67
  import t5
68
  import t5x
 
82
  SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'
83
 
84
  def upload_audio(audio, sample_rate):
85
+ return note_seq.audio_io.wav_data_to_samples_librosa(
86
+ audio, sample_rate=sample_rate)
87
 
88
 
89
  print(f"[{current_time()}] 日志:开始包装模型...")
90
  class InferenceModel(object):
91
+ """音乐转录的 T5X 模型包装器。"""
92
+
93
+ def __init__(self, checkpoint_path, model_type='mt3'):
94
+ if model_type == 'ismir2021':
95
+ num_velocity_bins = 127
96
+ self.encoding_spec = note_sequences.NoteEncodingSpec
97
+ self.inputs_length = 512
98
+ elif model_type == 'mt3':
99
+ num_velocity_bins = 1
100
+ self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
101
+ self.inputs_length = 256
102
+ else:
103
+ raise ValueError('unknown model_type: %s' % model_type)
104
+
105
+ gin_files = ['/home/user/app/mt3/gin/model.gin',
106
+ '/home/user/app/mt3/gin/mt3.gin']
107
+
108
+ self.batch_size = 8
109
+ self.outputs_length = 1024
110
+ self.sequence_length = {'inputs': self.inputs_length,
111
+ 'targets': self.outputs_length}
112
+
113
+ self.partitioner = t5x.partitioning.PjitPartitioner(
114
+ model_parallel_submesh=None, num_partitions=1)
115
+
116
+ print(f"[{current_time()}] 日志:构建编解码器")
117
+ self.spectrogram_config = spectrograms.SpectrogramConfig()
118
+ self.codec = vocabularies.build_codec(
119
+ vocab_config=vocabularies.VocabularyConfig(
120
+ num_velocity_bins=num_velocity_bins)
121
+ )
122
+ self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
123
+ self.output_features = {
124
+ 'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
125
+ 'targets': seqio.Feature(vocabulary=self.vocabulary),
126
+ }
127
+
128
+ print(f"[{current_time()}] 日志:创建 T5X 模型")
129
+ self._parse_gin(gin_files)
130
+ self.model = self._load_model()
131
+
132
+ print(f"[{current_time()}] 日志:恢复模型检查点")
133
+ self.restore_from_checkpoint(checkpoint_path)
134
+
135
+ @property
136
+ def input_shapes(self):
137
+ return {
138
+ 'encoder_input_tokens': (self.batch_size, self.inputs_length),
139
+ 'decoder_input_tokens': (self.batch_size, self.outputs_length)
140
+ }
141
+
142
+ def _parse_gin(self, gin_files):
143
+ print(f"[{current_time()}] 日志:解析 gin 文件")
144
+ gin_bindings = [
145
+ 'from __gin__ import dynamic_registration',
146
+ 'from mt3 import vocabularies',
147
148
+ 'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
149
+ ]
150
+ with gin.unlock_config():
151
+ gin.parse_config_files_and_bindings(
152
+ gin_files, gin_bindings, finalize_config=False)
153
+
154
+ def _load_model(self):
155
+ print(f"[{current_time()}] 日志:加载 T5X 模型")
156
+ model_config = gin.get_configurable(network.T5Config)()
157
+ module = network.Transformer(config=model_config)
158
+ return models.ContinuousInputsEncoderDecoderModel(
159
+ module=module,
160
+ input_vocabulary=self.output_features['inputs'].vocabulary,
161
+ output_vocabulary=self.output_features['targets'].vocabulary,
162
+ optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
163
+ input_depth=spectrograms.input_depth(self.spectrogram_config))
164
+
165
+
166
+ def restore_from_checkpoint(self, checkpoint_path):
167
+ print(f"[{current_time()}] 日志:从检查点恢复训练状态")
168
+ train_state_initializer = t5x.utils.TrainStateInitializer(
169
+ optimizer_def=self.model.optimizer_def,
170
+ init_fn=self.model.get_initial_variables,
171
+ input_shapes=self.input_shapes,
172
+ partitioner=self.partitioner)
173
+
174
+ restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
175
+ path=checkpoint_path, mode='specific', dtype='float32')
176
+
177
+ train_state_axes = train_state_initializer.train_state_axes
178
+ self._predict_fn = self._get_predict_fn(train_state_axes)
179
+ self._train_state = train_state_initializer.from_checkpoint_or_scratch(
180
+ [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
181
+
182
+ @functools.lru_cache()
183
+ def _get_predict_fn(self, train_state_axes):
184
+ print(f"[{current_time()}] 日志:生成用于解码的预测函数")
185
+ def partial_predict_fn(params, batch, decode_rng):
186
+ return self.model.predict_batch_with_aux(
187
+ params, batch, decoder_params={'decode_rng': None})
188
+ return self.partitioner.partition(
189
+ partial_predict_fn,
190
+ in_axis_resources=(
191
+ train_state_axes.params,
192
+ t5x.partitioning.PartitionSpec('data',), None),
193
+ out_axis_resources=t5x.partitioning.PartitionSpec('data',)
194
+ )
195
+
196
+ def predict_tokens(self, batch, seed=0):
197
+ print(f"[{current_time()}] 运行:从预处理数据集中预测音符序列")
198
+ prediction, _ = self._predict_fn(
199
+ self._train_state.params, batch, jax.random.PRNGKey(seed))
200
+ return self.vocabulary.decode_tf(prediction).numpy()
201
+
202
+ def __call__(self, audio):
203
+ filename = os.path.basename(audio) # 获取输入文件的文件名
204
+ print(f"[{current_time()}] 运行:输入文件: {filename}")
205
+ with open(audio, 'rb') as fd:
206
+ contents = fd.read()
207
+ audio = upload_audio(contents,sample_rate=16000)
208
+ est_ns = inference_model(audio)
209
+ note_seq.sequence_proto_to_midi_file(est_ns, './transcribed.mid')
210
+ return './transcribed.mid'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
  title = "MT3"
213
  description = "MT3:多任务多音轨音乐转录的 Gradio 演示。要使用它,只需上传音频文件,或点击示例以查看效果。更多信息请参阅下面的链接。"
 
217
  examples=[['canon.flac'], ['download.wav']]
218
 
219
  gr.Interface(
220
+ inference,
221
+ gr.Audio(type="filepath", label="输入"),
222
+ outputs=gr.File(label="输出"),
223
+ title=title,
224
+ description=description,
225
+ article=article,
226
+ examples=examples
227
+ ).launch(server_port=7861)