Uhhy committed on
Commit
ed0dbc1
1 Parent(s): ac307eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -64
app.py CHANGED
@@ -1,83 +1,66 @@
1
- import os
2
- import uuid
3
- import torch
4
- import re
5
  import gradio as gr
6
  import torchaudio
7
  from audiocraft.models import MusicGen
8
  from audiocraft.data.audio import audio_write
9
- import spaces # Importar spaces
10
-
11
- # Decorador para gestionar el uso de GPU
12
- def gpu_decorator(duration):
13
- def decorator(func):
14
- def wrapper(*args, **kwargs):
15
- with spaces.GPU(duration=duration): # Solicitar GPU por el tiempo especificado
16
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
- return func(*args, device=device, **kwargs) # Pasar el dispositivo a la función
18
- return wrapper
19
- return decorator
20
 
21
- # Cargar el modelo `musicgen-small` una única vez
22
- model = MusicGen.get_pretrained("facebook/musicgen-small")
23
- model.to(torch.device('cpu')) # Inicialmente configurar el modelo para CPU
24
 
25
- @gpu_decorator(duration=120) # Decorar la función con el uso de GPU
26
- def generate_music(description, melody_audio, duration, device):
27
- # Limpiar el texto de la descripción
28
- description = clean_text(description)
29
- model.set_generation_params(duration=int(duration * 1000)) # Convertir segundos a milisegundos
30
 
31
- try:
32
- # Cambiar el modelo a GPU si está disponible
33
- model.to(device)
 
 
34
 
35
- with torch.no_grad():
36
- if description:
37
- description = [description]
38
- if melody_audio:
39
- # Cargar el archivo de audio para remixar
40
- melody, sr = torchaudio.load(melody_audio, normalize=True)
41
- melody = melody.to(device)
42
- wav = model.generate_with_chroma(description, melody[None], sr)
43
- else:
44
- wav = model.generate(description)
45
  else:
46
- wav = model.generate_unconditional(1)
47
-
48
- # Guardar el archivo de música generado
49
- filename = f'{str(uuid.uuid4())}.wav'
50
- path = audio_write(filename, wav[0].cpu().to(torch.float32), model.sample_rate, strategy="loudness", loudness_compressor=True)
51
-
52
- if not os.path.exists(path):
53
- raise ValueError(f'Failed to save audio to {path}')
54
-
55
- return path
56
-
57
- except Exception as e:
58
- return str(e)
59
-
60
- def clean_text(text):
61
- text = re.sub(r'http\S+', '', text)
62
- text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
63
- return text
64
 
65
- # Definir la interfaz de Gradio
66
- description = gr.Textbox(label="Description", placeholder="Acoustic, guitar, melody, trap, D minor, 90 bpm")
 
 
67
  melody_audio = gr.Audio(label="Melody Audio (optional)", type="filepath")
68
- duration = gr.Number(label="Duration (seconds)", value=10, precision=0)
69
- output_path = gr.File(label="Generated Music")
70
 
71
  gr.Interface(
72
  fn=generate_music,
73
  inputs=[description, melody_audio, duration],
74
  outputs=output_path,
75
- title="MusicGen Melody Demo",
76
- description="Generate music using the MusicGen melody model. Optionally remix with an audio file. Download the generated audio file.",
77
  examples=[
78
- ["happy rock", None, 8],
79
- ["energetic EDM", None, 8],
80
- ["chillwave", "./assets/example_melody.mp3", 10]
81
  ]
82
  ).launch()
83
-
 
1
+ import spaces
 
 
 
2
  import gradio as gr
3
  import torchaudio
4
  from audiocraft.models import MusicGen
5
  from audiocraft.data.audio import audio_write
6
+ import logging
7
+ import os
8
+ import uuid
9
+ from torch.cuda.amp import autocast
10
+ import torch
 
 
 
 
 
 
11
 
12
+ # Configura o logging
13
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
14
 
15
+ logging.info("Carregando o modelo pré-treinado.")
16
+ model = MusicGen.get_pretrained('nateraw/musicgen-songstarter-v0.2')
 
 
 
17
 
18
+ @spaces.GPU(duration=120)
19
+ def generate_music(description, melody_audio, duration):
20
+ with autocast():
21
+ logging.info("Iniciando a geração de música.")
22
+ model.set_generation_params(duration=duration)
23
 
24
+ if description:
25
+ description = [description]
26
+ if melody_audio:
27
+ logging.info(f"Carregando a melodia de áudio de: {melody_audio}")
28
+ melody, sr = torchaudio.load(melody_audio)
29
+ logging.info("Gerando música com descrição e melodia.")
30
+ wav = model.generate_with_chroma(description, melody[None], sr)
 
 
 
31
  else:
32
+ logging.info("Gerando música apenas com descrição.")
33
+ wav = model.generate(description)
34
+ else:
35
+ logging.info("Gerando música de forma incondicional.")
36
+ wav = model.generate_unconditional(1)
37
+
38
+ filename = f'{str(uuid.uuid4())}.wav'
39
+ logging.info(f"Salvando a música gerada com o nome: {filename}")
40
+ path = audio_write(filename, wav[0].cpu().to(torch.float32), model.sample_rate, strategy="loudness", loudness_compressor=True)
41
+ print("Música salva em", path, ".")
42
+ # Verifica a forma do tensor de áudio e se foi salvo corretamente
43
+ logging.info(f"A forma do tensor de áudio gerado: {wav[0].shape}")
44
+ logging.info("Música gerada e salva com sucesso.")
45
+ if not os.path.exists(path):
46
+ raise ValueError(f'Failed to save audio to {path}')
 
 
 
47
 
48
+ return path
49
+
50
+ # Define a interface Gradio
51
+ description = gr.Textbox(label="Description", placeholder="acoustic, guitar, melody, trap, d minor, 90 bpm")
52
  melody_audio = gr.Audio(label="Melody Audio (optional)", type="filepath")
53
+ duration = gr.Slider(label="Duration (seconds)", minimum=10, maximum=600, step=10, value=30) # Máximo 10 minutos (600 segundos)
54
+ output_path = gr.Audio(label="Generated Music", type="filepath")
55
 
56
  gr.Interface(
57
  fn=generate_music,
58
  inputs=[description, melody_audio, duration],
59
  outputs=output_path,
60
+ title="MusicGen Demo",
61
+ description="Generate music using the MusicGen model.",
62
  examples=[
63
+ ["trap, synthesizer, songstarters, dark, G# minor, 140 bpm", "./assets/kalhonaho.mp3", 30],
64
+ ["upbeat, electronic, synth, dance, 120 bpm", None, 60]
 
65
  ]
66
  ).launch()