jadechoghari
committed on
Update audioldm_train/modules/latent_diffusion/ddpm.py
Browse files
audioldm_train/modules/latent_diffusion/ddpm.py
CHANGED
@@ -1335,7 +1335,7 @@ class LatentDiffusion(DDPM):
|
|
1335 |
waveform = self.first_stage_model.vocoder(mel)
|
1336 |
waveform = waveform.cpu().detach().numpy()
|
1337 |
if save:
|
1338 |
-
self.save_waveform(waveform, savepath
|
1339 |
return waveform
|
1340 |
|
1341 |
def encode_first_stage(self, x):
|
@@ -1818,44 +1818,31 @@ class LatentDiffusion(DDPM):
|
|
1818 |
**kwargs,
|
1819 |
)
|
1820 |
|
1821 |
-
def save_waveform(self, waveform, savepath, name="
|
1822 |
-
|
1823 |
-
|
1824 |
-
|
1825 |
-
|
1826 |
-
|
1827 |
-
|
1828 |
-
|
1829 |
-
|
1830 |
-
|
1831 |
-
|
1832 |
-
|
1833 |
-
|
1834 |
-
|
1835 |
-
|
1836 |
-
|
1837 |
-
|
1838 |
-
|
1839 |
-
|
1840 |
-
|
1841 |
-
|
|
|
|
|
|
|
1842 |
|
1843 |
-
if (not ".wav" in name[i])
|
1844 |
-
else os.path.basename(name[i]).split(".")[0]
|
1845 |
-
),
|
1846 |
-
)
|
1847 |
-
else:
|
1848 |
-
# import pdb
|
1849 |
-
# pdb.set_trace()
|
1850 |
-
raise NotImplementedError
|
1851 |
-
todo_waveform = waveform[i, 0]
|
1852 |
-
todo_waveform = (
|
1853 |
-
todo_waveform / np.max(np.abs(todo_waveform))
|
1854 |
-
) * 0.8 # Normalize the energy of the generation output
|
1855 |
-
try:
|
1856 |
-
sf.write(path, todo_waveform, samplerate=self.sampling_rate)
|
1857 |
-
except:
|
1858 |
-
print('waveform name ERROR!!!!!!!!!!!!')
|
1859 |
|
1860 |
@torch.no_grad()
|
1861 |
def sample_log(
|
@@ -2054,7 +2041,7 @@ class LatentDiffusion(DDPM):
|
|
2054 |
print("Choose the following indexes:", best_index)
|
2055 |
except Exception as e:
|
2056 |
print("Warning: while calculating CLAP score (not fatal), ", e)
|
2057 |
-
self.save_waveform(waveform,
|
2058 |
return waveform_save_path
|
2059 |
|
2060 |
|
|
|
1335 |
waveform = self.first_stage_model.vocoder(mel)
|
1336 |
waveform = waveform.cpu().detach().numpy()
|
1337 |
if save:
|
1338 |
+
self.save_waveform(waveform, savepath="./")
|
1339 |
return waveform
|
1340 |
|
1341 |
def encode_first_stage(self, x):
|
|
|
1818 |
**kwargs,
|
1819 |
)
|
1820 |
|
1821 |
+
def save_waveform(self, waveform, savepath="./", name="awesome.wav", n_gen=1):
    """Peak-normalize the first waveform of a batch and write it to disk.

    Only ``waveform[0, 0]`` is saved; any further entries are ignored.
    Saving is best-effort: a failure to create the directory or write the
    file is printed, not raised.

    Args:
        waveform: Array indexed as ``waveform[0, 0]`` to obtain the samples
            (presumably ``(batch, channel, samples)`` -- confirm with the
            caller).
        savepath: Directory the file is written into; created if missing.
        name: Output file name. A string with no extension gets ``.wav``
            appended. A list/tuple is reduced to the ``_``-joined base
            names of its items (directory part and a trailing ``.wav``
            dropped) plus ``.wav``.
        n_gen: Unused; kept so existing callers keep working.

    Raises:
        TypeError: If ``name`` is neither a string nor a list/tuple.
    """
    if isinstance(name, (list, tuple)):
        stems = []
        for item in name:
            # Keep only the file stem: a directory component (especially an
            # absolute one) would make os.path.join() ignore `savepath`.
            stem = os.path.basename(str(item))
            if stem.lower().endswith(".wav"):
                stem = stem[: -len(".wav")]
            stems.append(stem)
        name = "_".join(stems) + ".wav"
    elif not isinstance(name, str):
        raise TypeError("Name must be a string or list")
    elif not os.path.splitext(name)[1]:
        # The writer infers the container format from the extension.
        name += ".wav"

    todo_waveform = waveform[0, 0]
    peak = np.max(np.abs(todo_waveform))
    if peak > 0:
        # Scale the loudest sample to 0.8 of full scale.
        todo_waveform = (todo_waveform / peak) * 0.8
    # A silent waveform is written unchanged; dividing by a zero peak would
    # fill the output with NaNs.

    path = os.path.join(savepath, name)

    try:
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        sf.write(path, todo_waveform, samplerate=self.sampling_rate)
        print(f'Waveform saved at -> {path}')
    except Exception as e:
        # Deliberately broad: a failed save must not abort the caller.
        print(f'Error saving waveform: {e}')
|
1844 |
+
|
1845 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1846 |
|
1847 |
@torch.no_grad()
|
1848 |
def sample_log(
|
|
|
2041 |
print("Choose the following indexes:", best_index)
|
2042 |
except Exception as e:
|
2043 |
print("Warning: while calculating CLAP score (not fatal), ", e)
|
2044 |
+
self.save_waveform(waveform, savepath="./")
|
2045 |
return waveform_save_path
|
2046 |
|
2047 |
|