Gpagejr12 commited on
Commit
2cae516
·
verified ·
1 Parent(s): 7a5d65f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -9
app.py CHANGED
@@ -36,7 +36,11 @@ def generate_music_tensors(descriptions, duration: int):
36
 
37
  def save_audio(samples: torch.Tensor):
38
  sample_rate = 30000
39
- save_path = "/audio_output"
 
 
 
 
40
  assert samples.dim() == 2 or samples.dim() == 3
41
 
42
  samples = samples.detach().cpu()
@@ -45,14 +49,14 @@ def save_audio(samples: torch.Tensor):
45
 
46
  for idx, audio in enumerate(samples):
47
  audio_path = os.path.join(save_path, f"audio_{idx}.wav")
48
- torchaudio.save(audio_path, audio, sample_rate)
49
-
50
- def get_binary_file_downloader_html(bin_file, file_label='File'):
51
- with open(bin_file, 'rb') as f:
52
- data = f.read()
53
- bin_str = base64.b64encode(data).decode()
54
- href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{os.path.basename(bin_file)}">Download {file_label}</a>'
55
- return href
56
 
57
  st.set_page_config(
58
  page_icon= "musical_note",
 
36
 
37
  def save_audio(samples: torch.Tensor):
38
  sample_rate = 30000
39
+ save_path = "/tmp/audio_output" # Use /tmp directory
40
+
41
+ if not os.path.exists(save_path):
42
+ os.makedirs(save_path)
43
+
44
  assert samples.dim() == 2 or samples.dim() == 3
45
 
46
  samples = samples.detach().cpu()
 
49
 
50
  for idx, audio in enumerate(samples):
51
  audio_path = os.path.join(save_path, f"audio_{idx}.wav")
52
+ try:
53
+ torchaudio.save(audio_path, audio, sample_rate)
54
+ except Exception as e:
55
+ st.error(f"Error saving audio file: {e}")
56
+ return None
57
+
58
+ return save_path
59
+
60
 
61
  st.set_page_config(
62
  page_icon= "musical_note",