Hmjz100 commited on
Commit
6314551
1 Parent(s): bc4a291

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -34
app.py CHANGED
@@ -3,45 +3,47 @@ import os
3
  import datetime
4
  import pytz
5
 
6
- current_time = datetime.datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y年-%m月-%d日 %H时:%M分:%S秒")
 
 
7
 
8
- print(f"[{current_time}] 日志: - 部署空间")
9
 
10
  from pathlib import Path
11
- print(f"[{current_time}] 日志: - 安装 gsutil")
12
  os.system("pip install gsutil")
13
- print(f"[{current_time}] 日志: - Github 克隆 T5X 训练框架")
14
  os.system("git clone --branch=main https://github.com/google-research/t5x")
15
- print(f"[{current_time}] 日志: - T5X 训练框架转变为临时文件")
16
  os.system("mv t5x t5x_tmp; mv t5x_tmp/* .; rm -r t5x_tmp")
17
- print(f"[{current_time}] 日志: - 修改当前目录下的 setup.py 内的 jax[tpu]jax")
18
  os.system("sed -i 's:jax\[tpu\]:jax:' setup.py")
19
- print(f"[{current_time}] 日志: - 安装当前目录中的 Python 包")
20
  os.system("python3 -m pip install -e .")
21
- print(f"[{current_time}] 日志: - 更新 Python 包管理器 pip 到最新版")
22
  os.system("python3 -m pip install --upgrade pip")
23
 
24
 
25
  # 安装 mt3
26
- print(f"[{current_time}] 日志: - Github 克隆 MT3 模型")
27
  os.system("git clone --branch=main https://github.com/magenta/mt3")
28
- print(f"[{current_time}] 日志: - MT3 模型转变为临时文件")
29
  os.system("mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp")
30
- print(f"[{current_time}] 日志: - 安装当前目录中的 Python 包")
31
  os.system("python3 -m pip install -e .")
32
- print(f"[{current_time}] 日志: - 安装 TensorFlow CPU")
33
  os.system("pip install tensorflow_cpu")
34
 
35
  # 复制检查点
36
- print(f"[{current_time}] 日志: - 复制 MT3 内的检查点到当前目录")
37
  os.system("gsutil -q -m cp -r gs://mt3/checkpoints .")
38
 
39
  # 复制 soundfont 文件(原始文件来自 https://sites.google.com/site/soundfonts4u)
40
- print(f"[{current_time}] 日志: - 复制 SoundFont 文件到当前目录")
41
  os.system("gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .")
42
 
43
  #@title 导入和定义
44
- print(f"[{current_time}] 日志: - 导入实用命令")
45
  import functools
46
 
47
  import numpy as np
@@ -77,14 +79,13 @@ def upload_audio(audio, sample_rate):
77
  audio, sample_rate=sample_rate)
78
 
79
 
80
- print(f"[{current_time}] 日志: - 包装模型")
81
  class InferenceModel(object):
82
  """音乐转录的 T5X 模型包装器。"""
83
 
84
  def __init__(self, checkpoint_path, model_type='mt3'):
85
 
86
  # 模型常量。
87
- print(f"[{current_time}] 日志: - 设置模型常量")
88
  if model_type == 'ismir2021':
89
  num_velocity_bins = 127
90
  self.encoding_spec = note_sequences.NoteEncodingSpec
@@ -108,7 +109,7 @@ class InferenceModel(object):
108
  model_parallel_submesh=None, num_partitions=1)
109
 
110
  # 构建编解码器和词汇表。
111
- print(f"[{current_time}] 日志: - 构建编解码器")
112
  self.spectrogram_config = spectrograms.SpectrogramConfig()
113
  self.codec = vocabularies.build_codec(
114
  vocab_config=vocabularies.VocabularyConfig(
@@ -120,12 +121,12 @@ class InferenceModel(object):
120
  }
121
 
122
  # 创建 T5X 模型。
123
- print(f"[{current_time}] 日志: - 创建 T5X 模型")
124
  self._parse_gin(gin_files)
125
  self.model = self._load_model()
126
 
127
  # 从检查点中恢复。
128
- print(f"[{current_time}] 日志: - 恢复检查点")
129
  self.restore_from_checkpoint(checkpoint_path)
130
 
131
  @property
@@ -137,7 +138,7 @@ class InferenceModel(object):
137
 
138
  def _parse_gin(self, gin_files):
139
  """解析用于训练模型的 gin 文件。"""
140
- print(f"[{current_time}] 日志: - 解析 gin 文件")
141
  gin_bindings = [
142
  'from __gin__ import dynamic_registration',
143
  'from mt3 import vocabularies',
@@ -150,7 +151,7 @@ class InferenceModel(object):
150
 
151
  def _load_model(self):
152
  """在解析训练 gin 配置后加载 T5X `Model`。"""
153
- print(f"[{current_time}] 日志: - 加载 T5X 模型")
154
  model_config = gin.get_configurable(network.T5Config)()
155
  module = network.Transformer(config=model_config)
156
  return models.ContinuousInputsEncoderDecoderModel(
@@ -163,7 +164,7 @@ class InferenceModel(object):
163
 
164
  def restore_from_checkpoint(self, checkpoint_path):
165
  """从检查点中恢复训练状态,重置 self._predict_fn()。"""
166
- print(f"[{current_time}] 日志: - 从检查点恢复训练状态")
167
  train_state_initializer = t5x.utils.TrainStateInitializer(
168
  optimizer_def=self.model.optimizer_def,
169
  init_fn=self.model.get_initial_variables,
@@ -181,7 +182,7 @@ class InferenceModel(object):
181
  @functools.lru_cache()
182
  def _get_predict_fn(self, train_state_axes):
183
  """生成一个分区的预测函数用于解码。"""
184
- print(f"[{current_time}] 日志: - 生成用于解码的预测函数")
185
  def partial_predict_fn(params, batch, decode_rng):
186
  return self.model.predict_batch_with_aux(
187
  params, batch, decoder_params={'decode_rng': None})
@@ -195,7 +196,7 @@ class InferenceModel(object):
195
 
196
  def predict_tokens(self, batch, seed=0):
197
  """从预处理的数据集批次中预测 tokens。"""
198
- print(f"[{current_time}] 日志: - 从数据集中预测 tokens")
199
  prediction, _ = self._predict_fn(
200
  self._train_state.params, batch, jax.random.PRNGKey(seed))
201
  return self.vocabulary.decode_tf(prediction).numpy()
@@ -208,7 +209,7 @@ self._train_state.params, batch, jax.random.PRNGKey(seed))
208
  返回:
209
  转录音频的音符序列。
210
  """
211
- print(f"[{current_time}] 日志: - 推断音符序列")
212
  ds = self.audio_to_dataset(audio)
213
  ds = self.preprocess(ds)
214
 
@@ -229,7 +230,7 @@ self._train_state.params, batch, jax.random.PRNGKey(seed))
229
 
230
  def audio_to_dataset(self, audio):
231
  """从输入音频创建一个包含频谱图的 TF Dataset。"""
232
- print(f"[{current_time}] 日志: - 创建 TF Dataset")
233
  frames, frame_times = self._audio_to_frames(audio)
234
  return tf.data.Dataset.from_tensors({
235
  'inputs': frames,
@@ -238,9 +239,9 @@ self._train_state.params, batch, jax.random.PRNGKey(seed))
238
 
239
  def _audio_to_frames(self, audio):
240
  """从音频计算频谱图帧。"""
241
- print(f"[{current_time}] 日志: - 计算频谱图帧")
242
  frame_size = self.spectrogram_config.hop_width
243
- padding = [0, frame_size - len(audio) % frame_size]
244
  audio = np.pad(audio, padding, mode='constant')
245
  frames = spectrograms.split_audio(audio, self.spectrogram_config)
246
  num_frames = len(audio) // frame_size
@@ -289,20 +290,19 @@ inference_model = InferenceModel('/home/user/app/checkpoints/mt3/', 'mt3')
289
 
290
 
291
  def inference(audio):
 
 
292
  with open(audio, 'rb') as fd:
293
  contents = fd.read()
294
  audio = upload_audio(contents,sample_rate=16000)
295
-
296
  est_ns = inference_model(audio)
297
-
298
  note_seq.sequence_proto_to_midi_file(est_ns, './transcribed.mid')
299
-
300
  return './transcribed.mid'
301
 
302
  title = "MT3"
303
- description = "MT3:多任务多音轨音乐转录的 Gradio 演示。要使用它,只需上传音频文件,或点击示例以加载它们。更多信息请参阅下面的链接。"
304
 
305
- article = "<p style='text-align: center'>出错了?试试把文件转换为MP3后再上传吧~</p><p style='text-align: center'><a href='https://arxiv.org/abs/2111.03017' target='_blank'>MT3: 多任务多音轨音乐转录</a> | <a href='https://github.com/magenta/mt3' target='_blank'>Github 仓库</a></p>"
306
 
307
  examples=[['canon.flac'], ['download.wav']]
308
 
 
3
  import datetime
4
  import pytz
5
 
6
+ def current_time():
7
+ current = datetime.datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y年-%m月-%d日 %H时:%M分:%S秒")
8
+ return current
9
 
10
+ print(f"[{current_time()}] 开始部署空间...")
11
 
12
  from pathlib import Path
13
+ print(f"[{current_time()}] 日志:安装 - gsutil")
14
  os.system("pip install gsutil")
15
+ print(f"[{current_time()}] 日志:Git - 克隆 Github T5X 训练框架到当前目录")
16
  os.system("git clone --branch=main https://github.com/google-research/t5x")
17
+ print(f"[{current_time()}] 日志:文件 - 移动 t5x 到当前目录并重命名为 t5x_tmp 并删除")
18
  os.system("mv t5x t5x_tmp; mv t5x_tmp/* .; rm -r t5x_tmp")
19
+ print(f"[{current_time()}] 日志:编辑 - 替换 setup.py 内的文本“jax[tpu]”为“jax")
20
  os.system("sed -i 's:jax\[tpu\]:jax:' setup.py")
21
+ print(f"[{current_time()}] 日志:Python - 使用 pip 安装 当前目录内的 Python 包")
22
  os.system("python3 -m pip install -e .")
23
+ print(f"[{current_time()}] 日志:Python - 更新 Python 包管理器 pip")
24
  os.system("python3 -m pip install --upgrade pip")
25
 
26
 
27
  # 安装 mt3
28
+ print(f"[{current_time()}] 日志:Git - 克隆 Github MT3 模型到当前目录")
29
  os.system("git clone --branch=main https://github.com/magenta/mt3")
30
+ print(f"[{current_time()}] 日志:文件 - 移动 mt3 到当前目录并重命名为 mt3_tmp 并删除")
31
  os.system("mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp")
32
+ print(f"[{current_time()}] 日志:Python - 使用 pip 安装 当前目录内的 Python 包")
33
  os.system("python3 -m pip install -e .")
34
+ print(f"[{current_time()}] 日志:安装 - TensorFlow CPU")
35
  os.system("pip install tensorflow_cpu")
36
 
37
  # 复制检查点
38
+ print(f"[{current_time()}] 日志:gsutil - 复制 MT3 检查点到当前目录")
39
  os.system("gsutil -q -m cp -r gs://mt3/checkpoints .")
40
 
41
  # 复制 soundfont 文件(原始文件来自 https://sites.google.com/site/soundfonts4u)
42
+ print(f"[{current_time()}] 日志:gsutil - 复制 SoundFont 文件到当前目录")
43
  os.system("gsutil -q -m cp gs://magentadata/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 .")
44
 
45
  #@title 导入和定义
46
+ print(f"[{current_time()}] 日志:导入 - 必要工具")
47
  import functools
48
 
49
  import numpy as np
 
79
  audio, sample_rate=sample_rate)
80
 
81
 
82
+ print(f"[{current_time()}] 日志:开始包装模型...")
83
  class InferenceModel(object):
84
  """音乐转录的 T5X 模型包装器。"""
85
 
86
  def __init__(self, checkpoint_path, model_type='mt3'):
87
 
88
  # 模型常量。
 
89
  if model_type == 'ismir2021':
90
  num_velocity_bins = 127
91
  self.encoding_spec = note_sequences.NoteEncodingSpec
 
109
  model_parallel_submesh=None, num_partitions=1)
110
 
111
  # 构建编解码器和词汇表。
112
+ print(f"[{current_time()}] 日志:构建编解码器")
113
  self.spectrogram_config = spectrograms.SpectrogramConfig()
114
  self.codec = vocabularies.build_codec(
115
  vocab_config=vocabularies.VocabularyConfig(
 
121
  }
122
 
123
  # 创建 T5X 模型。
124
+ print(f"[{current_time()}] 日志:创建 T5X 模型")
125
  self._parse_gin(gin_files)
126
  self.model = self._load_model()
127
 
128
  # 从检查点中恢复。
129
+ print(f"[{current_time()}] 日志:恢复模型检查点")
130
  self.restore_from_checkpoint(checkpoint_path)
131
 
132
  @property
 
138
 
139
  def _parse_gin(self, gin_files):
140
  """解析用于训练模型的 gin 文件。"""
141
+ print(f"[{current_time()}] 日志:解析 gin 文件")
142
  gin_bindings = [
143
  'from __gin__ import dynamic_registration',
144
  'from mt3 import vocabularies',
 
151
 
152
  def _load_model(self):
153
  """在解析训练 gin 配置后加载 T5X `Model`。"""
154
+ print(f"[{current_time()}] 日志:加载 T5X 模型")
155
  model_config = gin.get_configurable(network.T5Config)()
156
  module = network.Transformer(config=model_config)
157
  return models.ContinuousInputsEncoderDecoderModel(
 
164
 
165
  def restore_from_checkpoint(self, checkpoint_path):
166
  """从检查点中恢复训练状态,重置 self._predict_fn()。"""
167
+ print(f"[{current_time()}] 日志:从检查点恢复训练状态")
168
  train_state_initializer = t5x.utils.TrainStateInitializer(
169
  optimizer_def=self.model.optimizer_def,
170
  init_fn=self.model.get_initial_variables,
 
182
  @functools.lru_cache()
183
  def _get_predict_fn(self, train_state_axes):
184
  """生成一个分区的预测函数用于解码。"""
185
+ print(f"[{current_time()}] 日志:生成用于解码的预测函数")
186
  def partial_predict_fn(params, batch, decode_rng):
187
  return self.model.predict_batch_with_aux(
188
  params, batch, decoder_params={'decode_rng': None})
 
196
 
197
  def predict_tokens(self, batch, seed=0):
198
  """从预处理的数据集批次中预测 tokens。"""
199
+ print(f"[{current_time()}] 运行:从预处理数据集中预测音符序列")
200
  prediction, _ = self._predict_fn(
201
  self._train_state.params, batch, jax.random.PRNGKey(seed))
202
  return self.vocabulary.decode_tf(prediction).numpy()
 
209
  返回:
210
  转录音频的音符序列。
211
  """
212
+ print(f"[{current_time()}] 运行:从音频样本中推断音符序列")
213
  ds = self.audio_to_dataset(audio)
214
  ds = self.preprocess(ds)
215
 
 
230
 
231
  def audio_to_dataset(self, audio):
232
  """从输入音频创建一个包含频谱图的 TF Dataset。"""
233
+ print(f"[{current_time()}] 运行:从音频创建包含频谱图的 TF Dataset")
234
  frames, frame_times = self._audio_to_frames(audio)
235
  return tf.data.Dataset.from_tensors({
236
  'inputs': frames,
 
239
 
240
  def _audio_to_frames(self, audio):
241
  """从音频计算频谱图帧。"""
242
+ print(f"[{current_time()}] 运行:从音频计算频谱图帧")
243
  frame_size = self.spectrogram_config.hop_width
244
+ padding = [0, frame_size提示 - len(audio) % frame_size]
245
  audio = np.pad(audio, padding, mode='constant')
246
  frames = spectrograms.split_audio(audio, self.spectrogram_config)
247
  num_frames = len(audio) // frame_size
 
290
 
291
 
292
  def inference(audio):
293
+ filename = os.path.basename(audio) # 获取输入文件的文件名
294
+ print(f"[{current_time()}] 运行:输入文件: {filename}")
295
  with open(audio, 'rb') as fd:
296
  contents = fd.read()
297
  audio = upload_audio(contents,sample_rate=16000)
 
298
  est_ns = inference_model(audio)
 
299
  note_seq.sequence_proto_to_midi_file(est_ns, './transcribed.mid')
 
300
  return './transcribed.mid'
301
 
302
  title = "MT3"
303
+ description = "MT3:多任务多音轨音乐转录的 Gradio 演示。要使用它,只需上传音频文件,或点击示例以查看效果。更多信息请参阅下面的链接。"
304
 
305
+ article = "<p style='text-align: center'>出错了?试试把文件转换为MP3后再上传吧~</p><p style='text-align: center'><a href='https://arxiv.org/abs/2111.03017' target='_blank'>MT3: 多任务多音轨音乐转录</a> | <a href='https://github.com/hmjz100/mt3' target='_blank'>Github 仓库</a></p>"
306
 
307
  examples=[['canon.flac'], ['download.wav']]
308