Yjhhh commited on
Commit
243993a
1 Parent(s): 6de284f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -12,9 +12,14 @@ import re
12
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
 
14
  # Cargar los modelos una sola vez
 
 
 
 
 
15
  models = {
16
- 'Large': MusicGen.get_pretrained('nateraw/musicgen-songstarter-v0.2').to(device),
17
- 'Small': MusicGen.get_pretrained('facebook/musicgen-small').to(device)
18
  }
19
 
20
  def get_model(model_choice):
@@ -33,12 +38,12 @@ def generate_music(description, melody_audio, duration, model_choice):
33
  description = [description]
34
  if melody_audio:
35
  melody, sr = torchaudio.load(melody_audio, normalize=True)
36
- melody = melody.to(device)
37
  wav = model.generate_with_chroma(description, melody[None], sr)
38
  else:
39
- wav = model.generate(description, use_amp=torch.cuda.is_available())
40
  else:
41
- wav = model.generate_unconditional(1, use_amp=torch.cuda.is_available())
42
 
43
  filename = f'{str(uuid.uuid4())}.wav'
44
  path = audio_write(filename, wav[0].cpu().to(torch.float32), model.sample_rate, strategy="loudness", loudness_compressor=True)
 
12
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
 
14
  # Cargar los modelos una sola vez
15
+ def load_model(model_name):
16
+ """Cargar el modelo especificado."""
17
+ return MusicGen.get_pretrained(model_name)
18
+
19
+ # Inicializar los modelos
20
  models = {
21
+ 'Large': load_model('nateraw/musicgen-songstarter-v0.2'),
22
+ 'Small': load_model('facebook/musicgen-small')
23
  }
24
 
25
  def get_model(model_choice):
 
38
  description = [description]
39
  if melody_audio:
40
  melody, sr = torchaudio.load(melody_audio, normalize=True)
41
+ melody = melody.to(device) if torch.cuda.is_available() else melody
42
  wav = model.generate_with_chroma(description, melody[None], sr)
43
  else:
44
+ wav = model.generate(description)
45
  else:
46
+ wav = model.generate_unconditional(1)
47
 
48
  filename = f'{str(uuid.uuid4())}.wav'
49
  path = audio_write(filename, wav[0].cpu().to(torch.float32), model.sample_rate, strategy="loudness", loudness_compressor=True)