oniati commited on
Commit
bdcc0c1
·
verified ·
1 Parent(s): 4d7ce0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +283 -271
app.py CHANGED
@@ -1,333 +1,345 @@
1
- import gradio as gr
2
  import os
3
- import datetime
4
- import pytz
5
- from pathlib import Path
6
 
7
- def current_time():
8
- current = datetime.datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y年-%m月-%d日 %H时:%M分:%S秒")
9
- return current
10
 
11
- print(f"[{current_time()}] 开始部署空间...")
12
 
13
- """
14
- print(f"[{current_time()}] 日志:安装 - 必要包")
15
- os.system("pip install -r ./requirements.txt")
16
- """
17
- print(f"[{current_time()}] 日志:安装 - gsutil")
18
  os.system("pip install gsutil")
19
- print(f"[{current_time()}] 日志:Git - 克隆 Github 的 T5X 训练框架到当前目录")
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  os.system("git clone --branch=main https://github.com/google-research/t5x")
21
- print(f"[{current_time()}] 日志:文件 - 移动 t5x 到当前目录并重命名为 t5x_tmp 并删除")
22
  os.system("mv t5x t5x_tmp; mv t5x_tmp/* .; rm -r t5x_tmp")
23
- print(f"[{current_time()}] 日志:编辑 - 替换 setup.py 内的文本“jax[tpu]”为“jax”")
24
  os.system("sed -i 's:jax\[tpu\]:jax:' setup.py")
25
- print(f"[{current_time()}] 日志:Python - 使用 pip 安装 当前目录内的 Python 包")
26
  os.system("python3 -m pip install -e .")
27
- print(f"[{current_time()}] 日志:Python - 更新 Python 包管理器 pip")
28
  os.system("python3 -m pip install --upgrade pip")
29
- print(f"[{current_time()}] 日志:安装 - langchain")
30
  os.system("pip install langchain")
31
- print(f"[{current_time()}] 日志:安装 - sentence-transformers")
32
  os.system("pip install sentence-transformers")
33
 
34
- # 安装 airio
35
- print(f"[{current_time()}] 日志:Git - 克隆 Github 的 airio 到当前目录")
36
- os.system("git clone --branch=main https://github.com/google/airio")
37
- print(f"[{current_time()}] 日志:文件 - 移动 airio 到当前目录并重命名为 airio_tmp 并删除")
38
- os.system("mv airio airio_tmp; mv airio_tmp/* .; rm -r airio_tmp")
39
- print(f"[{current_time()}] 日志:Python - 使用 pip 安装 当前目录内的 Python 包")
40
- os.system("python3 -m pip install -e .")
 
 
41
 
42
- # 安装 mt3
43
- print(f"[{current_time()}] 日志:Git - 克隆 Github 的 MT3 模型到当前目录")
44
  os.system("git clone --branch=main https://github.com/magenta/mt3")
45
- print(f"[{current_time()}] 日志:文件 - 移动 mt3 到当前目录并重命名为 mt3_tmp 并删除")
46
  os.system("mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp")
47
- print(f"[{current_time()}] 日志:Python - 使用 pip 从 storage.googleapis.com 安装 jax[cuda11_local] nest-asyncio pyfluidsynth")
48
  os.system("python3 -m pip install jax[cuda11_local] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
49
- print(f"[{current_time()}] 日志:Python - 使用 pip 安装 当前目录内的 Python 包")
50
  os.system("python3 -m pip install -e .")
51
- print(f"[{current_time()}] 日志:安装 - TensorFlow CPU")
52
  os.system("pip install tensorflow_cpu")
 
 
53
 
54
- # 复制检查点
55
- print(f"[{current_time()}] 日志:gsutil - 复制 MT3 检查点到当前目录")
56
  os.system("gsutil -q -m cp -r gs://mt3/checkpoints .")
57
 
58
- # 复制 soundfont 文件(原始文件来自 https://sites.google.com/site/soundfonts4u
59
- print(f"[{current_time()}] 日志:gsutil - 复制 SoundFont 文件到当前目录")
60
  os.system("gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .")
61
 
62
- #@title 导入和定义
63
- print(f"[{current_time()}] 日志:导入 - 必要工具")
 
 
 
 
64
  import functools
65
  import os
 
66
  import numpy as np
 
67
  import tensorflow.compat.v2 as tf
68
 
69
  import functools
70
  import gin
71
  import jax
 
72
  import librosa
73
  import note_seq
74
 
 
 
75
  import seqio
76
  import t5
77
  import t5x
78
-
79
- from mt3 import metrics_utils
80
- from mt3 import models
81
- from mt3 import network
82
- from mt3 import note_sequences
83
- from mt3 import preprocessors
84
  from mt3 import spectrograms
85
  from mt3 import vocabularies
86
 
 
87
  import nest_asyncio
88
  nest_asyncio.apply()
89
 
90
- SAMPLE_RATE = 16000
91
  SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'
92
 
93
  def upload_audio(audio, sample_rate):
94
- return note_seq.audio_io.wav_data_to_samples_librosa(
95
- audio, sample_rate=sample_rate)
 
 
96
 
97
 
98
- print(f"[{current_time()}] 日志:开始包装模型...")
99
  class InferenceModel(object):
100
- """音乐转录的 T5X 模型包装器。"""
101
-
102
- def __init__(self, checkpoint_path, model_type='mt3'):
103
-
104
- # 模型常量。
105
- if model_type == 'ismir2021':
106
- num_velocity_bins = 127
107
- self.encoding_spec = note_sequences.NoteEncodingSpec
108
- self.inputs_length = 512
109
- elif model_type == 'mt3':
110
- num_velocity_bins = 1
111
- self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
112
- self.inputs_length = 256
113
- else:
114
- raise ValueError('unknown model_type: %s' % model_type)
115
-
116
- gin_files = ['/home/user/app/mt3/gin/model.gin',
117
- '/home/user/app/mt3/gin/mt3.gin']
118
-
119
- self.batch_size = 8
120
- self.outputs_length = 1024
121
- self.sequence_length = {'inputs': self.inputs_length,
122
- 'targets': self.outputs_length}
123
-
124
- self.partitioner = t5x.partitioning.PjitPartitioner(
125
- model_parallel_submesh=None, num_partitions=1)
126
-
127
- # 构建编解码器和词汇表。
128
- print(f"[{current_time()}] 日志:构建编解码器")
129
- self.spectrogram_config = spectrograms.SpectrogramConfig()
130
- self.codec = vocabularies.build_codec(
131
- vocab_config=vocabularies.VocabularyConfig(
132
- num_velocity_bins=num_velocity_bins)
133
- )
134
- self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
135
- self.output_features = {
136
- 'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
137
- 'targets': seqio.Feature(vocabulary=self.vocabulary),
138
- }
139
-
140
- # 创建 T5X 模型。
141
- print(f"[{current_time()}] 日志:创建 T5X 模型")
142
- self._parse_gin(gin_files)
143
- self.model = self._load_model()
144
-
145
- # 从检查点中恢复。
146
- print(f"[{current_time()}] 日志:恢复模型检查点")
147
- self.restore_from_checkpoint(checkpoint_path)
148
-
149
- @property
150
- def input_shapes(self):
151
- return {
152
- 'encoder_input_tokens': (self.batch_size, self.inputs_length),
153
- 'decoder_input_tokens': (self.batch_size, self.outputs_length)
154
- }
155
-
156
- def _parse_gin(self, gin_files):
157
- """解析用于训练模型的 gin 文件。"""
158
- print(f"[{current_time()}] 日志:解析 gin 文件")
159
- gin_bindings = [
160
- 'from __gin__ import dynamic_registration',
161
- 'from mt3 import vocabularies',
162
163
- 'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
164
- ]
165
- with gin.unlock_config():
166
- gin.parse_config_files_and_bindings(
167
- gin_files, gin_bindings, finalize_config=False)
168
-
169
- def _load_model(self):
170
- """在解析训练 gin 配置后加载 T5X `Model`。"""
171
- print(f"[{current_time()}] 日志:加载 T5X 模型")
172
- model_config = gin.get_configurable(network.T5Config)()
173
- module = network.Transformer(config=model_config)
174
- return models.ContinuousInputsEncoderDecoderModel(
175
- module=module,
176
- input_vocabulary=self.output_features['inputs'].vocabulary,
177
- output_vocabulary=self.output_features['targets'].vocabulary,
178
- optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
179
- input_depth=spectrograms.input_depth(self.spectrogram_config))
180
-
181
-
182
- def restore_from_checkpoint(self, checkpoint_path):
183
- """从检查点中恢复训练状态,重置 self._predict_fn()。"""
184
- print(f"[{current_time()}] 日志:从检查点恢复训练状态")
185
- train_state_initializer = t5x.utils.TrainStateInitializer(
186
- optimizer_def=self.model.optimizer_def,
187
- init_fn=self.model.get_initial_variables,
188
- input_shapes=self.input_shapes,
189
- partitioner=self.partitioner)
190
-
191
- restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
192
- path=checkpoint_path, mode='specific', dtype='float32')
193
-
194
- train_state_axes = train_state_initializer.train_state_axes
195
- self._predict_fn = self._get_predict_fn(train_state_axes)
196
- self._train_state = train_state_initializer.from_checkpoint_or_scratch(
197
- [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
198
-
199
- @functools.lru_cache()
200
- def _get_predict_fn(self, train_state_axes):
201
- """生成一个分区的预测函数用于解码。"""
202
- print(f"[{current_time()}] 日志:生成用于解码的预测函数")
203
- def partial_predict_fn(params, batch, decode_rng):
204
- return self.model.predict_batch_with_aux(
205
- params, batch, decoder_params={'decode_rng': None})
206
- return self.partitioner.partition(
207
- partial_predict_fn,
208
- in_axis_resources=(
209
- train_state_axes.params,
210
- t5x.partitioning.PartitionSpec('data',), None),
211
- out_axis_resources=t5x.partitioning.PartitionSpec('data',)
212
- )
213
-
214
- def predict_tokens(self, batch, seed=0):
215
- """从预处理的数据集批次中预测 tokens。"""
216
- print(f"[{current_time()}] 运行:从预处理数据集中预测音符序列")
217
- prediction, _ = self._predict_fn(
218
- self._train_state.params, batch, jax.random.PRNGKey(seed))
219
- return self.vocabulary.decode_tf(prediction).numpy()
220
-
221
- def __call__(self, audio):
222
- """从音频样本推断出音符序列。
223
- 参数:
224
- audio:16kHz 的单个音频样本的 1 维 numpy 数组。
225
- 返回:
226
- 转录音频的音符序列。
227
- """
228
- print(f"[{current_time()}] 运行:从音频样本中推断音符序列")
229
- ds = self.audio_to_dataset(audio)
230
- ds = self.preprocess(ds)
231
-
232
- model_ds = self.model.FEATURE_CONVERTER_CLS(pack=False)(
233
- ds, task_feature_lengths=self.sequence_length)
234
- model_ds = model_ds.batch(self.batch_size)
235
-
236
- inferences = (tokens for batch in model_ds.as_numpy_iterator()
237
- for tokens in self.predict_tokens(batch))
238
-
239
- predictions = []
240
- for example, tokens in zip(ds.as_numpy_iterator(), inferences):
241
- predictions.append(self.postprocess(tokens, example))
242
-
243
- result = metrics_utils.event_predictions_to_ns(
244
- predictions, codec=self.codec, encoding_spec=self.encoding_spec)
245
- return result['est_ns']
246
-
247
- def audio_to_dataset(self, audio):
248
- """从输入音频创建一个包含频谱图的 TF Dataset。"""
249
- print(f"[{current_time()}] 运行:从音频创建包含频谱图的 TF Dataset")
250
- frames, frame_times = self._audio_to_frames(audio)
251
- return tf.data.Dataset.from_tensors({
252
- 'inputs': frames,
253
- 'input_times': frame_times,
254
- })
255
-
256
- def _audio_to_frames(self, audio):
257
- """从音频计算频谱图帧。"""
258
- print(f"[{current_time()}] 运行:从音频计算频谱图帧")
259
- frame_size = self.spectrogram_config.hop_width
260
- padding = [0, frame_size - len(audio) % frame_size]
261
- audio = np.pad(audio, padding, mode='constant')
262
- frames = spectrograms.split_audio(audio, self.spectrogram_config)
263
- num_frames = len(audio) // frame_size
264
- times = np.arange(num_frames) / self.spectrogram_config.frames_per_second
265
- return frames, times
266
-
267
- def preprocess(self, ds):
268
- pp_chain = [
269
- functools.partial(
270
- t5.data.preprocessors.split_tokens_to_inputs_length,
271
- sequence_length=self.sequence_length,
272
- output_features=self.output_features,
273
- feature_key='inputs',
274
- additional_feature_keys=['input_times']),
275
- # 在训练期间进行缓存。
276
- preprocessors.add_dummy_targets,
277
- functools.partial(
278
- preprocessors.compute_spectrograms,
279
- spectrogram_config=self.spectrogram_config)
280
- ]
281
- for pp in pp_chain:
282
- ds = pp(ds)
283
- return ds
284
-
285
- def postprocess(self, tokens, example):
286
- tokens = self._trim_eos(tokens)
287
- start_time = example['input_times'][0]
288
- # 向下取整到最接近的符号化时间步。
289
- start_time -= start_time % (1 / self.codec.steps_per_second)
290
- return {
291
- 'est_tokens': tokens,
292
- 'start_time': start_time,
293
- # 内部 MT3 代码期望原始输入,这里不使用。
294
- 'raw_inputs': []
295
- }
296
-
297
- @staticmethod
298
- def _trim_eos(tokens):
299
- tokens = np.array(tokens, np.int32)
300
- if vocabularies.DECODED_EOS_ID in tokens:
301
- tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
302
- return tokens
303
 
304
 
305
  inference_model = InferenceModel('/home/user/app/checkpoints/mt3/', 'mt3')
306
 
307
 
308
  def inference(audio):
309
- filename = os.path.basename(audio) # 获取输入文件的文件名
310
- print(f"[{current_time()}] 运行:输入文件: {filename}")
311
- with open(audio, 'rb') as fd:
312
- contents = fd.read()
313
- audio = upload_audio(contents,sample_rate=16000)
314
- est_ns = inference_model(audio)
315
- note_seq.sequence_proto_to_midi_file(est_ns, './transcribed.mid')
316
- return './transcribed.mid'
317
-
 
318
  title = "MT3"
319
- description = "MT3:多任务多音轨音乐转录的 Gradio 演示。要使用它,只需上传音频文件,或点击示例以查看效果。更多信息请参阅下面的链接。"
320
 
321
- article = "<p style='text-align: center'>出错了?试试把文件转换为MP3后再上传吧~</p><p style='text-align: center'><a href='https://arxiv.org/abs/2111.03017' target='_blank'>MT3: 多任务多音轨音乐转录</a> | <a href='https://github.com/hmjz100/mt3' target='_blank'>Github 仓库</a></p>"
322
 
323
  examples=[['download.wav']]
324
 
325
  gr.Interface(
326
- inference,
327
- gr.Audio(type="filepath", label="输入"),
328
- outputs = gr.File(label="输出"),
329
- title=title,
330
- description=description,
331
- article=article,
332
- examples=examples
333
- ).launch()
 
 
 
 
 
1
  import os
2
+ os.system("pip install gradio==3.50")
 
 
3
 
4
+ import gradio as gr
 
 
5
 
 
6
 
7
+
8
+ from pathlib import Path
 
 
 
9
  os.system("pip install gsutil")
10
+
11
+
12
+
13
+
14
+
15
+
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+
24
  os.system("git clone --branch=main https://github.com/google-research/t5x")
25
+
26
  os.system("mv t5x t5x_tmp; mv t5x_tmp/* .; rm -r t5x_tmp")
27
+
28
  os.system("sed -i 's:jax\[tpu\]:jax:' setup.py")
29
+
30
  os.system("python3 -m pip install -e .")
31
+
32
  os.system("python3 -m pip install --upgrade pip")
33
+
34
  os.system("pip install langchain")
35
+
36
  os.system("pip install sentence-transformers")
37
 
38
+ # install mt3
39
+
40
+
41
+
42
+
43
+
44
+
45
+
46
+
47
 
 
 
48
  os.system("git clone --branch=main https://github.com/magenta/mt3")
49
+
50
  os.system("mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp")
51
+
52
  os.system("python3 -m pip install jax[cuda11_local] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
53
+
54
  os.system("python3 -m pip install -e .")
55
+
56
  os.system("pip install tensorflow_cpu")
57
+ # copy checkpoints
58
+
59
 
 
 
60
  os.system("gsutil -q -m cp -r gs://mt3/checkpoints .")
61
 
62
+ # copy soundfont (originally from https://sites.google.com/site/soundfonts4u)
63
+
64
  os.system("gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .")
65
 
66
+ #@title Imports and Definitions
67
+
68
+
69
+
70
+
71
+
72
  import functools
73
  import os
74
+
75
  import numpy as np
76
+
77
  import tensorflow.compat.v2 as tf
78
 
79
  import functools
80
  import gin
81
  import jax
82
+
83
  import librosa
84
  import note_seq
85
 
86
+
87
+
88
  import seqio
89
  import t5
90
  import t5x
 
 
 
 
 
 
91
  from mt3 import spectrograms
92
  from mt3 import vocabularies
93
 
94
+
95
  import nest_asyncio
96
  nest_asyncio.apply()
97
 
 
98
  SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'
99
 
100
  def upload_audio(audio, sample_rate):
101
+ return note_seq.audio_io.wav_data_to_samples_librosa(
102
+ audio, sample_rate=sample_rate)
103
+
104
+
105
 
106
 
 
107
  class InferenceModel(object):
108
+ """Wrapper of T5X model for music transcription."""
109
+
110
+ def __init__(self, checkpoint_path, model_type='mt3'):
111
+
112
+ # Model Constants.
113
+ if model_type == 'ismir2021':
114
+ num_velocity_bins = 127
115
+ self.encoding_spec = note_sequences.NoteEncodingSpec
116
+ self.inputs_length = 512
117
+ elif model_type == 'mt3':
118
+ num_velocity_bins = 1
119
+ self.encoding_spec = note_sequences.NoteEncodingWithTiesSpec
120
+ self.inputs_length = 256
121
+ else:
122
+ raise ValueError('unknown model_type: %s' % model_type)
123
+
124
+ gin_files = ['/home/user/app/mt3/gin/model.gin',
125
+ '/home/user/app/mt3/gin/mt3.gin']
126
+
127
+ self.batch_size = 8
128
+ self.outputs_length = 1024
129
+ self.sequence_length = {'inputs': self.inputs_length,
130
+ 'targets': self.outputs_length}
131
+
132
+ self.partitioner = t5x.partitioning.PjitPartitioner(
133
+ model_parallel_submesh=None, num_partitions=1)
134
+
135
+ # Build Codecs and Vocabularies.
136
+ self.spectrogram_config = spectrograms.SpectrogramConfig()
137
+ self.codec = vocabularies.build_codec(
138
+ vocab_config=vocabularies.VocabularyConfig(
139
+ num_velocity_bins=num_velocity_bins))
140
+ self.vocabulary = vocabularies.vocabulary_from_codec(self.codec)
141
+ self.output_features = {
142
+ 'inputs': seqio.ContinuousFeature(dtype=tf.float32, rank=2),
143
+ 'targets': seqio.Feature(vocabulary=self.vocabulary),
144
+ }
145
+
146
+ # Create a T5X model.
147
+ self._parse_gin(gin_files)
148
+ self.model = self._load_model()
149
+
150
+ # Restore from checkpoint.
151
+ self.restore_from_checkpoint(checkpoint_path)
152
+
153
+ @property
154
+ def input_shapes(self):
155
+ return {
156
+ 'encoder_input_tokens': (self.batch_size, self.inputs_length),
157
+ 'decoder_input_tokens': (self.batch_size, self.outputs_length)
158
+ }
159
+
160
+ def _parse_gin(self, gin_files):
161
+ """Parse gin files used to train the model."""
162
+ gin_bindings = [
163
+ 'from __gin__ import dynamic_registration',
164
+ 'from mt3 import vocabularies',
165
166
+ 'vocabularies.VocabularyConfig.num_velocity_bins=%NUM_VELOCITY_BINS'
167
+ ]
168
+ with gin.unlock_config():
169
+ gin.parse_config_files_and_bindings(
170
+ gin_files, gin_bindings, finalize_config=False)
171
+
172
+ def _load_model(self):
173
+ """Load up a T5X `Model` after parsing training gin config."""
174
+ model_config = gin.get_configurable(network.T5Config)()
175
+ module = network.Transformer(config=model_config)
176
+ return models.ContinuousInputsEncoderDecoderModel(
177
+ module=module,
178
+ input_vocabulary=self.output_features['inputs'].vocabulary,
179
+ output_vocabulary=self.output_features['targets'].vocabulary,
180
+ optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
181
+ input_depth=spectrograms.input_depth(self.spectrogram_config))
182
+
183
+
184
+ def restore_from_checkpoint(self, checkpoint_path):
185
+ """Restore training state from checkpoint, resets self._predict_fn()."""
186
+ train_state_initializer = t5x.utils.TrainStateInitializer(
187
+ optimizer_def=self.model.optimizer_def,
188
+ init_fn=self.model.get_initial_variables,
189
+ input_shapes=self.input_shapes,
190
+ partitioner=self.partitioner)
191
+
192
+ restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
193
+ path=checkpoint_path, mode='specific', dtype='float32')
194
+
195
+ train_state_axes = train_state_initializer.train_state_axes
196
+ self._predict_fn = self._get_predict_fn(train_state_axes)
197
+ self._train_state = train_state_initializer.from_checkpoint_or_scratch(
198
+ [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
199
+
200
+ @functools.lru_cache()
201
+ def _get_predict_fn(self, train_state_axes):
202
+ """Generate a partitioned prediction function for decoding."""
203
+ def partial_predict_fn(params, batch, decode_rng):
204
+ return self.model.predict_batch_with_aux(
205
+ params, batch, decoder_params={'decode_rng': None})
206
+ return self.partitioner.partition(
207
+ partial_predict_fn,
208
+ in_axis_resources=(
209
+ train_state_axes.params,
210
+ t5x.partitioning.PartitionSpec('data',), None),
211
+ out_axis_resources=t5x.partitioning.PartitionSpec('data',)
212
+ )
213
+
214
+ def predict_tokens(self, batch, seed=0):
215
+ """Predict tokens from preprocessed dataset batch."""
216
+ prediction, _ = self._predict_fn(
217
+ self._train_state.params, batch, jax.random.PRNGKey(seed))
218
+ return self.vocabulary.decode_tf(prediction).numpy()
219
+
220
+ def __call__(self, audio):
221
+ """Infer note sequence from audio samples.
222
+
223
+ Args:
224
+ audio: 1-d numpy array of audio samples (16kHz) for a single example.
225
+ Returns:
226
+ A note_sequence of the transcribed audio.
227
+ """
228
+ ds = self.audio_to_dataset(audio)
229
+ ds = self.preprocess(ds)
230
+
231
+ model_ds = self.model.FEATURE_CONVERTER_CLS(pack=False)(
232
+ ds, task_feature_lengths=self.sequence_length)
233
+ model_ds = model_ds.batch(self.batch_size)
234
+
235
+ inferences = (tokens for batch in model_ds.as_numpy_iterator()
236
+ for tokens in self.predict_tokens(batch))
237
+
238
+ predictions = []
239
+ for example, tokens in zip(ds.as_numpy_iterator(), inferences):
240
+ predictions.append(self.postprocess(tokens, example))
241
+
242
+ result = metrics_utils.event_predictions_to_ns(
243
+ predictions, codec=self.codec, encoding_spec=self.encoding_spec)
244
+ return result['est_ns']
245
+
246
+ def audio_to_dataset(self, audio):
247
+ """Create a TF Dataset of spectrograms from input audio."""
248
+ frames, frame_times = self._audio_to_frames(audio)
249
+ return tf.data.Dataset.from_tensors({
250
+ 'inputs': frames,
251
+ 'input_times': frame_times,
252
+ })
253
+
254
+ def _audio_to_frames(self, audio):
255
+ """Compute spectrogram frames from audio."""
256
+ frame_size = self.spectrogram_config.hop_width
257
+ padding = [0, frame_size - len(audio) % frame_size]
258
+ audio = np.pad(audio, padding, mode='constant')
259
+ frames = spectrograms.split_audio(audio, self.spectrogram_config)
260
+ num_frames = len(audio) // frame_size
261
+ times = np.arange(num_frames) / self.spectrogram_config.frames_per_second
262
+ return frames, times
263
+
264
+ def preprocess(self, ds):
265
+ pp_chain = [
266
+ functools.partial(
267
+ t5.data.preprocessors.split_tokens_to_inputs_length,
268
+ sequence_length=self.sequence_length,
269
+ output_features=self.output_features,
270
+ feature_key='inputs',
271
+ additional_feature_keys=['input_times']),
272
+ # Cache occurs here during training.
273
+ preprocessors.add_dummy_targets,
274
+ functools.partial(
275
+ preprocessors.compute_spectrograms,
276
+ spectrogram_config=self.spectrogram_config)
277
+ ]
278
+ for pp in pp_chain:
279
+ ds = pp(ds)
280
+ return ds
281
+
282
+ def postprocess(self, tokens, example):
283
+ tokens = self._trim_eos(tokens)
284
+ start_time = example['input_times'][0]
285
+ # Round down to nearest symbolic token step.
286
+ start_time -= start_time % (1 / self.codec.steps_per_second)
287
+ return {
288
+ 'est_tokens': tokens,
289
+ 'start_time': start_time,
290
+ # Internal MT3 code expects raw inputs, not used here.
291
+ 'raw_inputs': []
292
+ }
293
+
294
+ @staticmethod
295
+ def _trim_eos(tokens):
296
+ tokens = np.array(tokens, np.int32)
297
+ if vocabularies.DECODED_EOS_ID in tokens:
298
+ tokens = tokens[:np.argmax(tokens == vocabularies.DECODED_EOS_ID)]
299
+ return tokens
300
+
301
+
302
+
303
+
304
+
305
+
306
+
307
+
308
+
309
+
310
+
311
 
312
 
313
  inference_model = InferenceModel('/home/user/app/checkpoints/mt3/', 'mt3')
314
 
315
 
316
  def inference(audio):
317
+ with open(audio, 'rb') as fd:
318
+ contents = fd.read()
319
+ audio = upload_audio(contents,sample_rate=16000)
320
+
321
+ est_ns = inference_model(audio)
322
+
323
+ note_seq.sequence_proto_to_midi_file(est_ns, './transcribed.mid')
324
+
325
+ return './transcribed.mid'
326
+
327
  title = "MT3"
328
+ description = "Gradio demo for MT3: Multi-Task Multitrack Music Transcription. To use it, simply upload your audio file, or click one of the examples to load them. Read more at the links below."
329
 
330
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.03017' target='_blank'>MT3: Multi-Task Multitrack Music Transcription</a> | <a href='https://github.com/magenta/mt3' target='_blank'>Github Repo</a></p>"
331
 
332
  examples=[['download.wav']]
333
 
334
  gr.Interface(
335
+ inference,
336
+ gr.inputs.Audio(type="filepath", label="Input"),
337
+ [gr.outputs.File(label="Output")],
338
+ title=title,
339
+ description=description,
340
+ article=article,
341
+ examples=examples,
342
+ allow_flagging=False,
343
+ allow_screenshot=False,
344
+ enable_queue=True
345
+ ).launch()