oniati commited on
Commit
2c1adaa
·
1 Parent(s): ed59c7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -19
app.py CHANGED
@@ -6,18 +6,15 @@ from pathlib import Path
6
  os.system("pip install gsutil")
7
 
8
 
9
- os.system("git clone --branch=main https://github.com/google-research/t5x")
10
- os.system("mv t5x t5x_tmp; mv t5x_tmp/* .; rm -r t5x_tmp")
11
- os.system("sed -i 's:jax\[tpu\]:jax:' setup.py")
12
- os.system("python3 -m pip install -e .")
13
- os.system("python3 -m pip install --upgrade pip")
14
 
15
 
16
 
17
  # install mt3
18
  os.system("git clone --branch=main https://github.com/magenta/mt3")
19
  os.system("mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp")
20
- os.system("python3 -m pip install -e .")
 
21
 
22
  # copy checkpoints
23
  os.system("gsutil -q -m cp -r gs://mt3/checkpoints .")
@@ -35,18 +32,13 @@ import functools
35
  import os
36
 
37
  import numpy as np
38
-
39
  import tensorflow.compat.v2 as tf
40
 
41
  import functools
42
  import gin
43
- import jax.linear_util
44
- jax.extend.linear_util = jax.linear_util
45
  import librosa
46
  import note_seq
47
-
48
-
49
-
50
  import seqio
51
  import t5
52
  import t5x
@@ -59,6 +51,7 @@ from mt3 import preprocessors
59
  from mt3 import spectrograms
60
  from mt3 import vocabularies
61
 
 
62
 
63
  import nest_asyncio
64
  nest_asyncio.apply()
@@ -66,9 +59,12 @@ nest_asyncio.apply()
66
  SAMPLE_RATE = 16000
67
  SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'
68
 
69
- def upload_audio(audio, sample_rate):
 
 
 
70
  return note_seq.audio_io.wav_data_to_samples_librosa(
71
- audio, sample_rate=sample_rate)
72
 
73
 
74
 
@@ -89,16 +85,16 @@ class InferenceModel(object):
89
  else:
90
  raise ValueError('unknown model_type: %s' % model_type)
91
 
92
- gin_files = ['/home/user/app/mt3/gin/model.gin',
93
- '/home/user/app/mt3/gin/mt3.gin']
94
 
95
  self.batch_size = 8
96
  self.outputs_length = 1024
97
- self.sequence_length = {'inputs': self.inputs_length,
98
  'targets': self.outputs_length}
99
 
100
  self.partitioner = t5x.partitioning.PjitPartitioner(
101
- model_parallel_submesh=(1, 1, 1, 1), num_partitions=1)
102
 
103
  # Build Codecs and Vocabularies.
104
  self.spectrogram_config = spectrograms.SpectrogramConfig()
@@ -187,9 +183,10 @@ class InferenceModel(object):
187
 
188
  def __call__(self, audio):
189
  """Infer note sequence from audio samples.
190
-
191
  Args:
192
  audio: 1-d numpy array of audio samples (16kHz) for a single example.
 
193
  Returns:
194
  A note_sequence of the transcribed audio.
195
  """
 
6
  os.system("pip install gsutil")
7
 
8
 
9
+
 
 
 
 
10
 
11
 
12
 
13
  # install mt3
14
  os.system("git clone --branch=main https://github.com/magenta/mt3")
15
  os.system("mv mt3 mt3_tmp; mv mt3_tmp/* .; rm -r mt3_tmp")
16
+ os.system("python3 -m pip install jax[cuda11_local] nest-asyncio pyfluidsynth==1.3.0 -e . -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
17
+ ")
18
 
19
  # copy checkpoints
20
  os.system("gsutil -q -m cp -r gs://mt3/checkpoints .")
 
32
  import os
33
 
34
  import numpy as np
 
35
  import tensorflow.compat.v2 as tf
36
 
37
  import functools
38
  import gin
39
+ import jax
 
40
  import librosa
41
  import note_seq
 
 
 
42
  import seqio
43
  import t5
44
  import t5x
 
51
  from mt3 import spectrograms
52
  from mt3 import vocabularies
53
 
54
+ from google.colab import files
55
 
56
  import nest_asyncio
57
  nest_asyncio.apply()
 
59
  SAMPLE_RATE = 16000
60
  SF2_PATH = 'SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'
61
 
62
+ def upload_audio(sample_rate):
63
+ data = list(files.upload().values())
64
+ if len(data) > 1:
65
+ print('Multiple files uploaded; using only one.')
66
  return note_seq.audio_io.wav_data_to_samples_librosa(
67
+ data[0], sample_rate=sample_rate)
68
 
69
 
70
 
 
85
  else:
86
  raise ValueError('unknown model_type: %s' % model_type)
87
 
88
+ gin_files = ['/content/mt3/gin/model.gin',
89
+ f'/content/mt3/gin/{model_type}.gin']
90
 
91
  self.batch_size = 8
92
  self.outputs_length = 1024
93
+ self.sequence_length = {'inputs': self.inputs_length,
94
  'targets': self.outputs_length}
95
 
96
  self.partitioner = t5x.partitioning.PjitPartitioner(
97
+ num_partitions=1)
98
 
99
  # Build Codecs and Vocabularies.
100
  self.spectrogram_config = spectrograms.SpectrogramConfig()
 
183
 
184
  def __call__(self, audio):
185
  """Infer note sequence from audio samples.
186
+
187
  Args:
188
  audio: 1-d numpy array of audio samples (16kHz) for a single example.
189
+
190
  Returns:
191
  A note_sequence of the transcribed audio.
192
  """