Gpagejr12 commited on
Commit
58fc0c2
·
verified ·
1 Parent(s): bdb5f86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -16,10 +16,8 @@ def load_model():
16
  model = MusicGen.get_pretrained('facebook/musicgen-small')
17
  return model
18
 
19
- def generate_music_tensors(descriptions, duration: int):
20
- # Create a new instance of the model with the desired device
21
- device = torch.device('cpu')
22
- model = MusicGen.get_pretrained('facebook/musicgen-small').to(device)
23
 
24
  model.set_generation_params(
25
  use_sampling=True,
@@ -27,11 +25,13 @@ def generate_music_tensors(descriptions, duration: int):
27
  duration=duration
28
  )
29
 
30
- with st.spinner("Generating Music..."):
 
31
  output = model.generate(
32
  descriptions=descriptions,
33
  progress=True,
34
- return_tokens=True
 
35
  )
36
 
37
  st.success("Music Generation Complete!")
 
16
  model = MusicGen.get_pretrained('facebook/musicgen-small')
17
  return model
18
 
19
+ def generate_music_tensors(descriptions, duration: int, device):
20
+ model = load_model()
 
 
21
 
22
  model.set_generation_params(
23
  use_sampling=True,
 
25
  duration=duration
26
  )
27
 
28
+ # Set the device during the forward pass
29
+ with torch.no_grad():
30
  output = model.generate(
31
  descriptions=descriptions,
32
  progress=True,
33
+ return_tokens=True,
34
+ device=device
35
  )
36
 
37
  st.success("Music Generation Complete!")