nakas committed on
Commit
61440ca
·
1 Parent(s): fed2fbd

Update audiocraft/data/audio.py

Browse files
Files changed (1) hide show
  1. audiocraft/data/audio.py +16 -7
audiocraft/data/audio.py CHANGED
@@ -72,7 +72,6 @@ def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
72
  def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
73
  """FFMPEG-based audio file reading using PyAV bindings.
74
  Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
75
-
76
  Args:
77
  filepath (str or Path): Path to audio file to read.
78
  seek_time (float): Time at which to start reading in the file.
@@ -116,7 +115,6 @@ def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: floa
116
  def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
117
  duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
118
  """Read audio by picking the most appropriate backend tool based on the audio format.
119
-
120
  Args:
121
  filepath (str or Path): Path to audio file to read.
122
  seek_time (float): Time at which to start reading in the file.
@@ -152,7 +150,7 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
152
 
153
  def audio_write(stem_name: tp.Union[str, Path],
154
  wav: torch.Tensor, sample_rate: int,
155
- normalize: bool = True,
156
  strategy: str = 'peak', peak_clip_headroom_db: float = 1,
157
  rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
158
  loudness_compressor: bool = False,
@@ -161,6 +159,8 @@ def audio_write(stem_name: tp.Union[str, Path],
161
  """Convenience function for saving audio to disk. Returns the filename the audio was written to.
162
  Args:
163
  stem_name (str or Path): Filename without extension which will be added automatically.
 
 
164
  normalize (bool): if `True` (default), normalizes according to the prescribed
165
  strategy (see after). If `False`, the strategy is only used in case clipping
166
  would happen.
@@ -172,7 +172,7 @@ def audio_write(stem_name: tp.Union[str, Path],
172
  than the `peak_clip` one to avoid further clipping.
173
  loudness_headroom_db (float): Target loudness for loudness normalization.
174
  loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
175
- log_clipping (bool): If True, basic logging on stderr when clipping still
176
  occurs despite strategy (only for 'rms').
177
  make_parent_dir (bool): Make parent directory if it doesn't exist.
178
  Returns:
@@ -187,17 +187,26 @@ def audio_write(stem_name: tp.Union[str, Path],
187
  wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
188
  rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping,
189
  sample_rate=sample_rate, stem_name=str(stem_name))
190
- suffix = '.wav'
 
 
 
 
 
 
 
 
 
191
  if not add_suffix:
192
  suffix = ''
193
  path = Path(str(stem_name) + suffix)
194
  if make_parent_dir:
195
  path.parent.mkdir(exist_ok=True, parents=True)
196
  try:
197
- ta.save(path, wav, sample_rate)
198
  except Exception:
199
  if path.exists():
200
  # we do not want to leave half written files around.
201
  path.unlink()
202
  raise
203
- return path
 
72
  def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
73
  """FFMPEG-based audio file reading using PyAV bindings.
74
  Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
 
75
  Args:
76
  filepath (str or Path): Path to audio file to read.
77
  seek_time (float): Time at which to start reading in the file.
 
115
  def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
116
  duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
117
  """Read audio by picking the most appropriate backend tool based on the audio format.
 
118
  Args:
119
  filepath (str or Path): Path to audio file to read.
120
  seek_time (float): Time at which to start reading in the file.
 
150
 
151
  def audio_write(stem_name: tp.Union[str, Path],
152
  wav: torch.Tensor, sample_rate: int,
153
+ format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
154
  strategy: str = 'peak', peak_clip_headroom_db: float = 1,
155
  rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
156
  loudness_compressor: bool = False,
 
159
  """Convenience function for saving audio to disk. Returns the filename the audio was written to.
160
  Args:
161
  stem_name (str or Path): Filename without extension which will be added automatically.
162
+ format (str): Either "wav" or "mp3".
163
+ mp3_rate (int): kbps when using mp3s.
164
  normalize (bool): if `True` (default), normalizes according to the prescribed
165
  strategy (see after). If `False`, the strategy is only used in case clipping
166
  would happen.
 
172
  than the `peak_clip` one to avoid further clipping.
173
  loudness_headroom_db (float): Target loudness for loudness normalization.
174
  loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
175
+ log_clipping (bool): If True, basic logging on stderr when clipping still
176
  occurs despite strategy (only for 'rms').
177
  make_parent_dir (bool): Make parent directory if it doesn't exist.
178
  Returns:
 
187
  wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
188
  rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping,
189
  sample_rate=sample_rate, stem_name=str(stem_name))
190
+ kwargs: dict = {}
191
+ if format == 'mp3':
192
+ suffix = '.mp3'
193
+ kwargs.update({"compression": mp3_rate})
194
+ elif format == 'wav':
195
+ wav = i16_pcm(wav)
196
+ suffix = '.wav'
197
+ kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
198
+ else:
199
+ raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
200
  if not add_suffix:
201
  suffix = ''
202
  path = Path(str(stem_name) + suffix)
203
  if make_parent_dir:
204
  path.parent.mkdir(exist_ok=True, parents=True)
205
  try:
206
+ ta.save(path, wav, sample_rate, **kwargs)
207
  except Exception:
208
  if path.exists():
209
  # we do not want to leave half written files around.
210
  path.unlink()
211
  raise
212
+ return path