suric commited on
Commit
a99ae87
·
1 Parent(s): 6a24aec

include audio gen

Browse files
Files changed (1) hide show
  1. gradio_components/prediction.py +5 -2
gradio_components/prediction.py CHANGED
@@ -8,13 +8,16 @@ import gradio as gr
8
  import torch
9
  from audiocraft.data.audio import audio_write
10
  from audiocraft.data.audio_utils import convert_audio
11
- from audiocraft.models import MusicGen
12
  from basic_pitch import ICASSP_2022_MODEL_PATH
13
  from transformers import AutoModelForSeq2SeqLM
14
 
15
 
16
  def load_model(version="facebook/musicgen-melody"):
17
- return MusicGen.get_pretrained(version)
 
 
 
18
 
19
 
20
  def _do_predictions(
 
8
  import torch
9
  from audiocraft.data.audio import audio_write
10
  from audiocraft.data.audio_utils import convert_audio
11
+ from audiocraft.models import MusicGen, AudioGen
12
  from basic_pitch import ICASSP_2022_MODEL_PATH
13
  from transformers import AutoModelForSeq2SeqLM
14
 
15
 
16
  def load_model(version="facebook/musicgen-melody"):
17
+ if version in ["facebook/audiogen-medium"]:
18
+ return AudioGen.get_pretrained(version)
19
+ else:
20
+ return MusicGen.get_pretrained(version)
21
 
22
 
23
  def _do_predictions(