Felguk committed on
Commit
4e611fb
·
verified ·
1 Parent(s): 1c3ad73

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -4
app.py CHANGED
@@ -9,7 +9,7 @@ import imageio_ffmpeg
9
  import gradio as gr
10
  import torch
11
  from PIL import Image
12
- from transformers import pipeline, AutoProcessor, MusicgenForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer
13
  import torchaudio
14
  import numpy as np
15
  from datetime import datetime, timedelta
@@ -29,7 +29,21 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
29
 
30
  # Load MusicGen model for music generation
31
  processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
32
- musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  # Chatbot models
35
  CHATBOT_MODELS = {
@@ -82,9 +96,9 @@ def generate_music_function(prompt, length, genre, custom_genre, lyrics):
82
  padding=True,
83
  return_tensors="pt",
84
  )
85
- audio_values = musicgen_model.generate(**inputs, max_new_tokens=int(length * 50))
86
  output_file = "generated_music.wav"
87
- sampling_rate = musicgen_model.config.audio_encoder.sampling_rate
88
  torchaudio.save(output_file, audio_values[0].cpu(), sampling_rate)
89
  return output_file
90
 
 
9
  import gradio as gr
10
  import torch
11
  from PIL import Image
12
+ from transformers import pipeline, AutoProcessor, MusicgenForCausalLM, AutoModelForCausalLM, AutoTokenizer
13
  import torchaudio
14
  import numpy as np
15
  from datetime import datetime, timedelta
 
29
 
30
  # Load MusicGen model for music generation
31
  processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
32
+ model = MusicgenForCausalLM.from_pretrained("facebook/musicgen-small")
33
+
34
+ # Explicitly set configurations to avoid conflicts
35
+ model.config.audio_encoder = {
36
+ "audio_channels": 1,
37
+ "codebook_dim": 128,
38
+ "codebook_size": 2048,
39
+ "sampling_rate": 32000,
40
+ }
41
+
42
+ model.config.decoder = {
43
+ "activation_dropout": 0.0,
44
+ "activation_function": "gelu",
45
+ "attention_dropout": 0.0,
46
+ }
47
 
48
  # Chatbot models
49
  CHATBOT_MODELS = {
 
96
  padding=True,
97
  return_tensors="pt",
98
  )
99
+ audio_values = model.generate(**inputs, max_new_tokens=int(length * 50))
100
  output_file = "generated_music.wav"
101
+ sampling_rate = model.config.audio_encoder["sampling_rate"]
102
  torchaudio.save(output_file, audio_values[0].cpu(), sampling_rate)
103
  return output_file
104