Update app.py
Browse files
app.py
CHANGED
@@ -9,7 +9,7 @@ import imageio_ffmpeg
|
|
9 |
import gradio as gr
|
10 |
import torch
|
11 |
from PIL import Image
|
12 |
-
from transformers import pipeline, AutoProcessor,
|
13 |
import torchaudio
|
14 |
import numpy as np
|
15 |
from datetime import datetime, timedelta
|
@@ -29,7 +29,21 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
29 |
|
30 |
# Load MusicGen model for music generation
|
31 |
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
# Chatbot models
|
35 |
CHATBOT_MODELS = {
|
@@ -82,9 +96,9 @@ def generate_music_function(prompt, length, genre, custom_genre, lyrics):
|
|
82 |
padding=True,
|
83 |
return_tensors="pt",
|
84 |
)
|
85 |
-
audio_values =
|
86 |
output_file = "generated_music.wav"
|
87 |
-
sampling_rate =
|
88 |
torchaudio.save(output_file, audio_values[0].cpu(), sampling_rate)
|
89 |
return output_file
|
90 |
|
|
|
9 |
import gradio as gr
|
10 |
import torch
|
11 |
from PIL import Image
|
12 |
+
from transformers import pipeline, AutoProcessor, MusicgenForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
|
13 |
import torchaudio
|
14 |
import numpy as np
|
15 |
from datetime import datetime, timedelta
|
|
|
29 |
|
30 |
# Load MusicGen model for music generation
|
31 |
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
32 |
+
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
33 |
+
|
34 |
+
# Explicitly set configurations to avoid conflicts
|
35 |
+
model.config.audio_encoder = {
|
36 |
+
"audio_channels": 1,
|
37 |
+
"codebook_dim": 128,
|
38 |
+
"codebook_size": 2048,
|
39 |
+
"sampling_rate": 32000,
|
40 |
+
}
|
41 |
+
|
42 |
+
model.config.decoder = {
|
43 |
+
"activation_dropout": 0.0,
|
44 |
+
"activation_function": "gelu",
|
45 |
+
"attention_dropout": 0.0,
|
46 |
+
}
|
47 |
|
48 |
# Chatbot models
|
49 |
CHATBOT_MODELS = {
|
|
|
96 |
padding=True,
|
97 |
return_tensors="pt",
|
98 |
)
|
99 |
+
audio_values = model.generate(**inputs, max_new_tokens=int(length * 50))
|
100 |
output_file = "generated_music.wav"
|
101 |
+
sampling_rate = model.config.audio_encoder["sampling_rate"]
|
102 |
torchaudio.save(output_file, audio_values[0].cpu(), sampling_rate)
|
103 |
return output_file
|
104 |
|