teticio committed on
Commit
fdc373f
1 Parent(s): 62617b3

different seed for latent and denoising

Browse files
Files changed (1) hide show
  1. audiodiffusion/__init__.py +50 -24
audiodiffusion/__init__.py CHANGED
@@ -59,37 +59,44 @@ class AudioDiffusion:
59
  top_db=top_db)
60
 
61
  def generate_spectrogram_and_audio(
62
- self,
63
- steps: int = 1000,
64
- generator: torch.Generator = None
65
- ) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
 
66
  """Generate random mel spectrogram and convert to audio.
67
 
68
  Args:
69
  steps (int): number of de-noising steps to perform (defaults to num_train_timesteps)
70
  generator (torch.Generator): random number generator or None
 
 
71
 
72
  Returns:
73
  PIL Image: mel spectrogram
74
  (float, np.ndarray): sample rate and raw audio
75
  """
76
- images, (sample_rate, audios) = self.pipe(mel=self.mel,
77
- batch_size=1,
78
- steps=steps,
79
- generator=generator)
 
 
 
80
  return images[0], (sample_rate, audios[0])
81
 
82
  def generate_spectrogram_and_audio_from_audio(
83
- self,
84
- audio_file: str = None,
85
- raw_audio: np.ndarray = None,
86
- slice: int = 0,
87
- start_step: int = 0,
88
- steps: int = 1000,
89
- generator: torch.Generator = None,
90
- mask_start_secs: float = 0,
91
- mask_end_secs: float = 0
92
- ) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
 
93
  """Generate random mel spectrogram from audio input and convert to audio.
94
 
95
  Args:
@@ -101,6 +108,8 @@ class AudioDiffusion:
101
  generator (torch.Generator): random number generator or None
102
  mask_start_secs (float): number of seconds of audio to mask (not generate) at start
103
  mask_end_secs (float): number of seconds of audio to mask (not generate) at end
 
 
104
 
105
  Returns:
106
  PIL Image: mel spectrogram
@@ -117,7 +126,9 @@ class AudioDiffusion:
117
  steps=steps,
118
  generator=generator,
119
  mask_start_secs=mask_start_secs,
120
- mask_end_secs=mask_end_secs)
 
 
121
  return images[0], (sample_rate, audios[0])
122
 
123
  @staticmethod
@@ -160,7 +171,9 @@ class AudioDiffusionPipeline(DiffusionPipeline):
160
  steps: int = 1000,
161
  generator: torch.Generator = None,
162
  mask_start_secs: float = 0,
163
- mask_end_secs: float = 0
 
 
164
  ) -> Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]]:
165
  """Generate random mel spectrogram from audio input and convert to audio.
166
 
@@ -175,6 +188,8 @@ class AudioDiffusionPipeline(DiffusionPipeline):
175
  generator (torch.Generator): random number generator or None
176
  mask_start_secs (float): number of seconds of audio to mask (not generate) at start
177
  mask_end_secs (float): number of seconds of audio to mask (not generate) at end
 
 
178
 
179
  Returns:
180
  List[PIL Image]: mel spectrograms
@@ -182,6 +197,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
182
  """
183
 
184
  self.scheduler.set_timesteps(steps)
 
185
  mask = None
186
  images = noise = torch.randn(
187
  (batch_size, self.unet.in_channels, mel.y_res, mel.x_res),
@@ -218,10 +234,20 @@ class AudioDiffusionPipeline(DiffusionPipeline):
218
  for step, t in enumerate(
219
  self.progress_bar(self.scheduler.timesteps[start_step:])):
220
  model_output = self.unet(images, t)['sample']
221
- images = self.scheduler.step(model_output,
222
- t,
223
- images,
224
- generator=generator)['prev_sample']
 
 
 
 
 
 
 
 
 
 
225
 
226
  if mask is not None:
227
  if mask_start > 0:
 
59
  top_db=top_db)
60
 
61
  def generate_spectrogram_and_audio(
62
+ self,
63
+ steps: int = 1000,
64
+ generator: torch.Generator = None,
65
+ step_generator: torch.Generator = None,
66
+ eta: float = 0) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
67
  """Generate random mel spectrogram and convert to audio.
68
 
69
  Args:
70
  steps (int): number of de-noising steps to perform (defaults to num_train_timesteps)
71
  generator (torch.Generator): random number generator or None
72
+ step_generator (torch.Generator): random number generator used to denoise or None
73
+ eta (float): parameter between 0 and 1 used with DDIM scheduler
74
 
75
  Returns:
76
  PIL Image: mel spectrogram
77
  (float, np.ndarray): sample rate and raw audio
78
  """
79
+ images, (sample_rate,
80
+ audios) = self.pipe(mel=self.mel,
81
+ batch_size=1,
82
+ steps=steps,
83
+ generator=generator,
84
+ step_generator=step_generator,
85
+ eta=eta)
86
  return images[0], (sample_rate, audios[0])
87
 
88
  def generate_spectrogram_and_audio_from_audio(
89
+ self,
90
+ audio_file: str = None,
91
+ raw_audio: np.ndarray = None,
92
+ slice: int = 0,
93
+ start_step: int = 0,
94
+ steps: int = 1000,
95
+ generator: torch.Generator = None,
96
+ mask_start_secs: float = 0,
97
+ mask_end_secs: float = 0,
98
+ step_generator: torch.Generator = None,
99
+ eta: float = 0) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
100
  """Generate random mel spectrogram from audio input and convert to audio.
101
 
102
  Args:
 
108
  generator (torch.Generator): random number generator or None
109
  mask_start_secs (float): number of seconds of audio to mask (not generate) at start
110
  mask_end_secs (float): number of seconds of audio to mask (not generate) at end
111
+ step_generator (torch.Generator): random number generator used to denoise or None
112
+ eta (float): parameter between 0 and 1 used with DDIM scheduler
113
 
114
  Returns:
115
  PIL Image: mel spectrogram
 
126
  steps=steps,
127
  generator=generator,
128
  mask_start_secs=mask_start_secs,
129
+ mask_end_secs=mask_end_secs,
130
+ step_generator=step_generator,
131
+ eta=eta)
132
  return images[0], (sample_rate, audios[0])
133
 
134
  @staticmethod
 
171
  steps: int = 1000,
172
  generator: torch.Generator = None,
173
  mask_start_secs: float = 0,
174
+ mask_end_secs: float = 0,
175
+ step_generator: torch.Generator = None,
176
+ eta: float = 0
177
  ) -> Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]]:
178
  """Generate random mel spectrogram from audio input and convert to audio.
179
 
 
188
  generator (torch.Generator): random number generator or None
189
  mask_start_secs (float): number of seconds of audio to mask (not generate) at start
190
  mask_end_secs (float): number of seconds of audio to mask (not generate) at end
191
+ step_generator (torch.Generator): random number generator used to denoise or None
192
+ eta (float): parameter between 0 and 1 used with DDIM scheduler
193
 
194
  Returns:
195
  List[PIL Image]: mel spectrograms
 
197
  """
198
 
199
  self.scheduler.set_timesteps(steps)
200
+ step_generator = step_generator or generator
201
  mask = None
202
  images = noise = torch.randn(
203
  (batch_size, self.unet.in_channels, mel.y_res, mel.x_res),
 
234
  for step, t in enumerate(
235
  self.progress_bar(self.scheduler.timesteps[start_step:])):
236
  model_output = self.unet(images, t)['sample']
237
+
238
+ if isinstance(self.scheduler, DDIMScheduler):
239
+ images = self.scheduler.step(
240
+ model_output=model_output,
241
+ timestep=t,
242
+ sample=images,
243
+ eta=eta,
244
+ generator=step_generator)['prev_sample']
245
+ else:
246
+ images = self.scheduler.step(
247
+ model_output=model_output,
248
+ timestep=t,
249
+ sample=images,
250
+ generator=step_generator)['prev_sample']
251
 
252
  if mask is not None:
253
  if mask_start > 0: