Spaces:
Sleeping
Sleeping
include audio gen
Browse files
gradio_components/prediction.py
CHANGED
@@ -8,13 +8,16 @@ import gradio as gr
|
|
8 |
import torch
|
9 |
from audiocraft.data.audio import audio_write
|
10 |
from audiocraft.data.audio_utils import convert_audio
|
11 |
-
from audiocraft.models import MusicGen
|
12 |
from basic_pitch import ICASSP_2022_MODEL_PATH
|
13 |
from transformers import AutoModelForSeq2SeqLM
|
14 |
|
15 |
|
16 |
def load_model(version="facebook/musicgen-melody"):
|
17 |
-
|
|
|
|
|
|
|
18 |
|
19 |
|
20 |
def _do_predictions(
|
|
|
8 |
import torch
|
9 |
from audiocraft.data.audio import audio_write
|
10 |
from audiocraft.data.audio_utils import convert_audio
|
11 |
+
from audiocraft.models import MusicGen, AudioGen
|
12 |
from basic_pitch import ICASSP_2022_MODEL_PATH
|
13 |
from transformers import AutoModelForSeq2SeqLM
|
14 |
|
15 |
|
16 |
def load_model(version="facebook/musicgen-melody"):
|
17 |
+
if version in ["facebook/audiogen-medium"]:
|
18 |
+
return AudioGen.get_pretrained(version)
|
19 |
+
else:
|
20 |
+
return MusicGen.get_pretrained(version)
|
21 |
|
22 |
|
23 |
def _do_predictions(
|