chong.zhang committed on
Commit 2c50d95 · 1 Parent(s): b6363bb
inspiremusic/.DS_Store DELETED
Binary file (8.2 kB)
 
inspiremusic/bin/inference.py CHANGED
@@ -28,7 +28,6 @@ from inspiremusic.cli.model import InspireMusicModel
28
  from inspiremusic.dataset.dataset import Dataset
29
  import time
30
  from inspiremusic.utils.audio_utils import trim_audio, fade_out, process_audio
31
- from inspiremusic.utils.common import MUSIC_STRUCTURE_LABELS
32
 
33
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
34
 
@@ -42,6 +41,7 @@ def get_args():
42
  parser.add_argument('--wavtokenizer', required=True, help='wavtokenizer model file')
43
  parser.add_argument('--chorus', default="random",required=False, help='chorus tag generation mode, eg. random, verse, chorus, intro.')
44
  parser.add_argument('--fast', action='store_true', required=False, help='True: fast inference mode, without flow matching for fast inference. False: normal inference mode, with flow matching for high quality.')
 
45
  parser.add_argument('--fp16', default=True, type=bool, required=False, help='inference with fp16 model')
46
  parser.add_argument('--fade_out', default=True, type=bool, required=False, help='add fade out effect to generated audio')
47
  parser.add_argument('--fade_out_duration', default=1.0, type=float, required=False, help='fade out duration in seconds')
@@ -53,7 +53,7 @@ def get_args():
53
  help='sampling rate of input audio')
54
  parser.add_argument('--output_sample_rate', type=int, default=48000, required=False, choices=[24000, 48000],
55
  help='sampling rate of generated output audio')
56
- parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0, required=False,
57
  help='the minimum generated audio length in seconds')
58
  parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0, required=False,
59
  help='the maximum generated audio length in seconds')
@@ -70,9 +70,9 @@ def get_args():
70
  print(args)
71
  return args
72
 
73
-
74
  def main():
75
  args = get_args()
 
76
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')
77
  os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
78
 
@@ -85,11 +85,20 @@ def main():
85
 
86
  # Init inspiremusic models from configs
87
  use_cuda = args.gpu >= 0 and torch.cuda.is_available()
88
- device = torch.device('cuda' if use_cuda else 'cpu')
89
  with open(args.config, 'r') as f:
90
  configs = load_hyperpyyaml(f)
91
 
92
- model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], args.fast, args.fp16)
93
 
94
  model.load(args.llm_model, args.flow_model, args.music_tokenizer, args.wavtokenizer)
95
 
@@ -153,7 +162,7 @@ def main():
153
  time_end = batch["time_end"].to(device)
154
  chorus = batch["chorus"].to(torch.int)
155
 
156
- text_prompt = f"<|{batch['time_start'].numpy()[0]}|><|{MUSIC_STRUCTURE_LABELS[chorus.numpy()[0]]}|><|{batch['text'][0]}|><|{batch['time_end'].numpy()[0]}|>"
157
  chorus = chorus.to(device)
158
 
159
  if batch["acoustic_token"] is None:
 
28
  from inspiremusic.dataset.dataset import Dataset
29
  import time
30
  from inspiremusic.utils.audio_utils import trim_audio, fade_out, process_audio
 
31
 
32
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
33
 
 
41
  parser.add_argument('--wavtokenizer', required=True, help='wavtokenizer model file')
42
  parser.add_argument('--chorus', default="random",required=False, help='chorus tag generation mode, eg. random, verse, chorus, intro.')
43
  parser.add_argument('--fast', action='store_true', required=False, help='True: fast inference mode, without flow matching for fast inference. False: normal inference mode, with flow matching for high quality.')
44
+ parser.add_argument('--dtype', type=str, default="fp16", required=False, choices=["fp16", "bf16", "fp32"], help='data type')
45
  parser.add_argument('--fp16', default=True, type=bool, required=False, help='inference with fp16 model')
46
  parser.add_argument('--fade_out', default=True, type=bool, required=False, help='add fade out effect to generated audio')
47
  parser.add_argument('--fade_out_duration', default=1.0, type=float, required=False, help='fade out duration in seconds')
 
53
  help='sampling rate of input audio')
54
  parser.add_argument('--output_sample_rate', type=int, default=48000, required=False, choices=[24000, 48000],
55
  help='sampling rate of generated output audio')
56
+ parser.add_argument('--min_generate_audio_seconds', type=float, default=0.0, required=False,
57
  help='the minimum generated audio length in seconds')
58
  parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0, required=False,
59
  help='the maximum generated audio length in seconds')
 
70
  print(args)
71
  return args
72
 
 
73
  def main():
74
  args = get_args()
75
+ chorus_labels = ["intro", "verse1", "chorus", "verse2", "outro"]
76
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')
77
  os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
78
 
 
85
 
86
  # Init inspiremusic models from configs
87
  use_cuda = args.gpu >= 0 and torch.cuda.is_available()
88
+ if args.gpu >=0:
89
+ if torch.cuda.is_available():
90
+ device = torch.device('cuda')
91
+ elif torch.backends.mps.is_available():
92
+ device = torch.device('mps')
93
+ elif torch.xpu.is_available():
94
+ device = torch.device('xpu')
95
+ else:
96
+ device = torch.device('cpu')
97
+
98
  with open(args.config, 'r') as f:
99
  configs = load_hyperpyyaml(f)
100
 
101
+ model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], args.dtype, args.fast, args.fp16)
102
 
103
  model.load(args.llm_model, args.flow_model, args.music_tokenizer, args.wavtokenizer)
104
 
 
162
  time_end = batch["time_end"].to(device)
163
  chorus = batch["chorus"].to(torch.int)
164
 
165
+ text_prompt = f"<|{batch['time_start'].numpy()[0]}|><|{chorus_labels[chorus.numpy()[0]]}|><|{batch['text'][0]}|><|{batch['time_end'].numpy()[0]}|>"
166
  chorus = chorus.to(device)
167
 
168
  if batch["acoustic_token"] is None:
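The hunks above swap the fixed cuda/cpu choice for a probe over several accelerator backends (CUDA, Apple MPS, Intel XPU). A minimal standalone sketch of that selection logic follows; the hasattr guard for PyTorch builds that ship without torch.xpu and the explicit CPU fallback are my additions, not part of the diff.

import torch

def select_device(gpu: int = 0) -> torch.device:
    """Pick an accelerator when one is requested and available, else CPU."""
    if gpu >= 0:
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        # torch.xpu only exists on builds compiled with Intel XPU support
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return torch.device("xpu")
    return torch.device("cpu")

Depending on how the final else in the new block is indented, one branch can leave device unassigned (for example --gpu -1 with no fallback); returning an explicit CPU default closes that gap.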
inspiremusic/cli/frontend.py CHANGED
@@ -29,6 +29,7 @@ class InspireMusicFrontEnd:
29
  music_tokenizer_dir: str,
30
  audio_tokenizer_dir: str,
31
  instruct: bool = False,
 
32
  fast: bool = False,
33
  fp16: bool = True,
34
  allowed_special: str = 'all'):
@@ -39,7 +40,7 @@ class InspireMusicFrontEnd:
39
  self.bandwidth_id = torch.tensor([0]).to(self.device)
40
  self.wavtokenizer = WavTokenizer.from_pretrained_feat(f"{audio_tokenizer_dir}/config.yaml", f"{audio_tokenizer_dir}/model.pt").to(self.device)
41
 
42
- self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], fast, fp16)
43
  self.model = self.model.load(llm_model, flow_model, music_tokenizer_dir, audio_tokenizer_dir)
44
 
45
  self.instruct = instruct
@@ -69,12 +70,10 @@ class InspireMusicFrontEnd:
69
  text = text.replace(" - ", ",")
70
  text = remove_bracket(text)
71
  text = re.sub(r'[,,]+$', '。', text)
72
- texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
73
- token_min_n=60, merge_len=20, comma_split=False))
74
  else:
75
  text = spell_out_number(text, self.inflect_parser)
76
- texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
77
- token_min_n=60, merge_len=20, comma_split=False))
78
  if split is False:
79
  return text
80
  return texts
 
29
  music_tokenizer_dir: str,
30
  audio_tokenizer_dir: str,
31
  instruct: bool = False,
32
+ dtype: str = "fp16",
33
  fast: bool = False,
34
  fp16: bool = True,
35
  allowed_special: str = 'all'):
 
40
  self.bandwidth_id = torch.tensor([0]).to(self.device)
41
  self.wavtokenizer = WavTokenizer.from_pretrained_feat(f"{audio_tokenizer_dir}/config.yaml", f"{audio_tokenizer_dir}/model.pt").to(self.device)
42
 
43
+ self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], dtype, fast, fp16)
44
  self.model = self.model.load(llm_model, flow_model, music_tokenizer_dir, audio_tokenizer_dir)
45
 
46
  self.instruct = instruct
 
70
  text = text.replace(" - ", ",")
71
  text = remove_bracket(text)
72
  text = re.sub(r'[,,]+$', '。', text)
73
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False))
 
74
  else:
75
  text = spell_out_number(text, self.inflect_parser)
76
+ texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False))
 
77
  if split is False:
78
  return text
79
  return texts
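The frontend now accepts a dtype string ("fp16", "bf16", "fp32") and passes it straight through to InspireMusicModel. The mapping from that string to a torch.dtype is repeated as an if/elif chain in several files below; a small helper capturing the same fallback behaviour (the function name is mine) would look like this:

import torch

_DTYPE_MAP = {"fp16": torch.float16, "bf16": torch.bfloat16}

def resolve_dtype(name: str) -> torch.dtype:
    # Anything other than fp16/bf16 falls back to full precision,
    # matching the else branches in cli/model.py, llm/llm.py and qwen_encoder.py.
    return _DTYPE_MAP.get(name, torch.float32)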
inspiremusic/cli/inference.py CHANGED
@@ -23,53 +23,60 @@ from inspiremusic.utils.file_utils import logging
23
  import torch
24
  from inspiremusic.utils.audio_utils import trim_audio, fade_out, process_audio
25
 
26
- def set_env_variables():
27
  os.environ['PYTHONIOENCODING'] = 'UTF-8'
28
  os.environ['TOKENIZERS_PARALLELISM'] = 'False'
29
- main_root = os.getcwd()
 
30
  bin_dir = os.path.join(main_root, 'inspiremusic')
31
  third_party_matcha_tts_path = os.path.join(main_root, 'third_party', 'Matcha-TTS')
32
  python_path = f"{main_root}:{bin_dir}:{third_party_matcha_tts_path}:{os.environ.get('PYTHONPATH', '')}"
33
- os.environ['PATH'] = python_path
34
  sys.path.extend([main_root, third_party_matcha_tts_path])
35
 
36
- class InspireMusicUnified:
37
  def __init__(self,
38
- model_name: str = "InspireMusic-1.5B-Long",
39
  model_dir: str = None,
40
- min_generate_audio_seconds: float = 10.0,
41
  max_generate_audio_seconds: float = 30.0,
42
  sample_rate: int = 24000,
43
  output_sample_rate: int = 48000,
44
  load_jit: bool = True,
45
  load_onnx: bool = False,
 
46
  fast: bool = False,
47
  fp16: bool = True,
48
- gpu: int = 0,
49
  result_dir: str = None,
50
- hub="modelscope"):
 
 
51
  os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
52
 
53
  # Set model_dir or default to downloading if it doesn't exist
54
  if model_dir is None:
55
- model_dir = f"pretrained_models/{model_name}"
56
- else:
57
- model_dir = model_dir.replace("../../", "./")
 
58
 
59
- if not os.path.isfile(f"{model_dir}/llm.pt"):
60
  if hub == "modelscope":
61
  from modelscope import snapshot_download
62
  if model_name == "InspireMusic-Base":
63
  snapshot_download(f"iic/InspireMusic", local_dir=model_dir)
64
  else:
65
  snapshot_download(f"iic/{model_name}", local_dir=model_dir)
 
 
 
66
 
67
  self.model_dir = model_dir
68
- print(self.model_dir)
69
 
70
  self.sample_rate = sample_rate
71
  self.output_sample_rate = 24000 if fast else output_sample_rate
72
- self.result_dir = result_dir or f"exp/{model_name}"
73
  os.makedirs(self.result_dir, exist_ok=True)
74
 
75
  self.min_generate_audio_seconds = min_generate_audio_seconds
@@ -79,9 +86,17 @@ class InspireMusicUnified:
79
  assert self.min_generate_audio_seconds <= self.max_generate_audio_seconds, "Min audio seconds must be less than or equal to max audio seconds"
80
 
81
  use_cuda = gpu >= 0 and torch.cuda.is_available()
82
- self.device = torch.device('cuda' if use_cuda else 'cpu')
83
- self.model = InspireMusic(self.model_dir, load_jit=load_jit, load_onnx=load_onnx, fast=fast, fp16=fp16)
84
- self.model.model.llm = self.model.model.llm.to(torch.float16)
85
 
86
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
87
 
@@ -90,6 +105,7 @@ class InspireMusicUnified:
90
  task: str = 'text-to-music',
91
  text: str = None,
92
  audio_prompt: str = None, # audio prompt file path
 
93
  chorus: str = "verse",
94
  time_start: float = 0.0,
95
  time_end: float = 30.0,
@@ -205,84 +221,61 @@ class InspireMusicUnified:
205
 
206
  def get_args():
207
  parser = argparse.ArgumentParser(description='Run inference with your model')
208
- parser.add_argument('-m', '--model_name', default="InspireMusic-1.5B-Long",
209
- help='Model name')
210
 
211
- parser.add_argument('-d', '--model_dir',
212
- help='Model folder path')
213
 
214
- parser.add_argument('-t', '--text', default="Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.",
215
- help='Prompt text')
216
 
217
- parser.add_argument('-a', '--audio_prompt', default=None,
218
- help='Prompt audio')
219
 
220
- parser.add_argument('-c', '--chorus', default="intro",
221
- help='Chorus tag generation mode (e.g., random, verse, chorus, intro, outro)')
222
 
223
- parser.add_argument('-f', '--fast', type=bool, default=False,
224
- help='Enable fast inference mode (without flow matching)')
225
 
226
- parser.add_argument('-g', '--gpu', type=int, default=0,
227
- help='GPU ID for this rank, -1 for CPU')
228
 
229
- parser.add_argument('--task', default='text-to-music', choices=['text-to-music', 'continuation', 'reconstruct', 'super_resolution'],
230
- help='Inference task type: text-to-music, continuation, reconstruct, super_resolution')
231
 
232
- parser.add_argument('-r', '--result_dir', default="exp/inspiremusic",
233
- help='Directory to save generated audio')
234
 
235
- parser.add_argument('-o', '--output_fn', default="output_audio",
236
- help='Output file name')
237
 
238
- parser.add_argument('--format', type=str, default="wav", choices=["wav", "mp3", "m4a", "flac"],
239
- help='Format of output audio')
240
 
241
- parser.add_argument('--sample_rate', type=int, default=24000,
242
- help='Sampling rate of input audio')
243
 
244
- parser.add_argument('--output_sample_rate', type=int, default=48000, choices=[24000, 48000],
245
- help='Sampling rate of generated output audio')
246
 
247
- parser.add_argument('-s', '--time_start', type=float, default=0.0,
248
- help='Start time in seconds')
249
 
250
- parser.add_argument('-e', '--time_end', type=float, default=30.0,
251
- help='End time in seconds')
252
 
253
- parser.add_argument('--max_audio_prompt_length', type=float, default=5.0,
254
- help='Maximum audio prompt length in seconds')
255
 
256
- parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0,
257
- help='Minimum generated audio length in seconds')
258
 
259
- parser.add_argument('--max_generate_audio_seconds', type=float, default=300.0,
260
- help='Maximum generated audio length in seconds')
261
 
262
- parser.add_argument('--fp16', type=bool, default=True,
263
- help='Inference with fp16 model')
264
 
265
- parser.add_argument('--fade_out', type=bool, default=True,
266
- help='Apply fade out effect to generated audio')
267
 
268
- parser.add_argument('--fade_out_duration', type=float, default=1.0,
269
- help='Fade out duration in seconds')
270
 
271
- parser.add_argument('--trim', type=bool, default=False,
272
- help='Trim the silence ending of generated audio')
273
 
274
  args = parser.parse_args()
275
 
276
  if not args.model_dir:
277
- args.model_dir = os.path.join("pretrained_models", args.model_name)
278
 
279
  print(args)
280
  return args
281
-
282
  def main():
283
- set_env_variables()
284
  args = get_args()
285
- model = InspireMusicUnified(model_name = args.model_name,
286
  model_dir = args.model_dir,
287
  min_generate_audio_seconds = args.min_generate_audio_seconds,
288
  max_generate_audio_seconds = args.max_generate_audio_seconds,
@@ -290,6 +283,7 @@ def main():
290
  output_sample_rate = args.output_sample_rate,
291
  load_jit = True,
292
  load_onnx = False,
 
293
  fast = args.fast,
294
  fp16 = args.fp16,
295
  gpu = args.gpu,
 
23
  import torch
24
  from inspiremusic.utils.audio_utils import trim_audio, fade_out, process_audio
25
 
26
+ def env_variables():
27
  os.environ['PYTHONIOENCODING'] = 'UTF-8'
28
  os.environ['TOKENIZERS_PARALLELISM'] = 'False'
29
+ current_working_dir = os.getcwd()
30
+ main_root = os.path.realpath(os.path.join(current_working_dir, '../../'))
31
  bin_dir = os.path.join(main_root, 'inspiremusic')
32
  third_party_matcha_tts_path = os.path.join(main_root, 'third_party', 'Matcha-TTS')
33
  python_path = f"{main_root}:{bin_dir}:{third_party_matcha_tts_path}:{os.environ.get('PYTHONPATH', '')}"
34
+ os.environ['PYTHONPATH'] = python_path
35
  sys.path.extend([main_root, third_party_matcha_tts_path])
36
 
37
+ class InspireMusicModel:
38
  def __init__(self,
39
+ model_name: str,
40
  model_dir: str = None,
41
+ min_generate_audio_seconds: float = 0.0,
42
  max_generate_audio_seconds: float = 30.0,
43
  sample_rate: int = 24000,
44
  output_sample_rate: int = 48000,
45
  load_jit: bool = True,
46
  load_onnx: bool = False,
47
+ dtype: str = "fp16",
48
  fast: bool = False,
49
  fp16: bool = True,
50
+ gpu: int = 1,
51
  result_dir: str = None,
52
+ hub="modelscope",
53
+ repo_url=None,
54
+ token=None):
55
  os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
56
 
57
  # Set model_dir or default to downloading if it doesn't exist
58
  if model_dir is None:
59
+ if sys.platform == "win32":
60
+ model_dir = f"..\..\pretrained_models\{model_name}"
61
+ else:
62
+ model_dir = f"../../pretrained_models/{model_name}"
63
 
64
+ if not os.path.isfile(os.path.join(model_dir, "llm.pt")):
65
  if hub == "modelscope":
66
  from modelscope import snapshot_download
67
  if model_name == "InspireMusic-Base":
68
  snapshot_download(f"iic/InspireMusic", local_dir=model_dir)
69
  else:
70
  snapshot_download(f"iic/{model_name}", local_dir=model_dir)
71
+ elif hub == "huggingface":
72
+ from huggingface_hub import snapshot_download
73
+ snapshot_download(repo_id=f"FunAudioLLM/{model_name}", local_dir=model_dir)
74
 
75
  self.model_dir = model_dir
 
76
 
77
  self.sample_rate = sample_rate
78
  self.output_sample_rate = 24000 if fast else output_sample_rate
79
+ self.result_dir = result_dir or os.path.join("exp", model_name)
80
  os.makedirs(self.result_dir, exist_ok=True)
81
 
82
  self.min_generate_audio_seconds = min_generate_audio_seconds
 
86
  assert self.min_generate_audio_seconds <= self.max_generate_audio_seconds, "Min audio seconds must be less than or equal to max audio seconds"
87
 
88
  use_cuda = gpu >= 0 and torch.cuda.is_available()
89
+ if gpu >=0:
90
+ if torch.cuda.is_available():
91
+ self.device = torch.device('cuda')
92
+ elif torch.backends.mps.is_available():
93
+ self.device = torch.device('mps')
94
+ elif torch.xpu.is_available():
95
+ self.device = torch.device('xpu')
96
+ else:
97
+ self.device = torch.device('cpu')
98
+
99
+ self.model = InspireMusic(self.model_dir, load_jit=load_jit, load_onnx=load_onnx, dtype=dtype, fast=fast, fp16=fp16)
100
 
101
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
102
 
 
105
  task: str = 'text-to-music',
106
  text: str = None,
107
  audio_prompt: str = None, # audio prompt file path
108
+ instruct: str = None,
109
  chorus: str = "verse",
110
  time_start: float = 0.0,
111
  time_end: float = 30.0,
 
221
 
222
  def get_args():
223
  parser = argparse.ArgumentParser(description='Run inference with your model')
224
+ parser.add_argument('-m', '--model_name', default="InspireMusic-1.5B-Long", help='Model name')
 
225
 
226
+ parser.add_argument('-d', '--model_dir', help='Model folder path')
 
227
 
228
+ parser.add_argument('-t', '--text', default="Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.", help='Prompt text')
 
229
 
230
+ parser.add_argument('-a', '--audio_prompt', default=None, help='Prompt audio')
 
231
 
232
+ parser.add_argument('-c', '--chorus', default="intro", help='Chorus tag generation mode (e.g., random, verse, chorus, intro, outro)')
 
233
 
234
+ parser.add_argument('-f', '--fast', type=bool, default=False, help='Enable fast inference mode (without flow matching)')
 
235
 
236
+ parser.add_argument('-g', '--gpu', type=int, default=1, help='GPU ID for this rank, -1 for CPU')
 
237
 
238
+ parser.add_argument('--task', default='text-to-music', choices=['text-to-music', 'continuation', 'reconstruct', 'super_resolution'], help='Inference task type: text-to-music, continuation, reconstruct, super_resolution')
 
239
 
240
+ parser.add_argument('-r', '--result_dir', default="exp/inspiremusic", help='Directory to save generated audio')
 
241
 
242
+ parser.add_argument('-o', '--output_fn', default="output_audio", help='Output file name')
 
243
 
244
+ parser.add_argument('--format', type=str, default="wav", choices=["wav", "mp3", "m4a", "flac"], help='Format of output audio')
 
245
 
246
+ parser.add_argument('--sample_rate', type=int, default=24000, help='Sampling rate of input audio')
 
247
 
248
+ parser.add_argument('--output_sample_rate', type=int, default=48000, choices=[24000, 48000], help='Sampling rate of generated output audio')
 
249
 
250
+ parser.add_argument('-s', '--time_start', type=float, default=0.0, help='Start time in seconds')
 
251
 
252
+ parser.add_argument('-e', '--time_end', type=float, default=30.0, help='End time in seconds')
 
253
 
254
+ parser.add_argument('--max_audio_prompt_length', type=float, default=5.0, help='Maximum audio prompt length in seconds')
 
255
 
256
+ parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0, help='Minimum generated audio length in seconds')
 
257
 
258
+ parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0, help='Maximum generated audio length in seconds')
 
259
 
260
+ parser.add_argument('--fp16', type=bool, default=True, help='Inference with fp16 model')
 
261
 
262
+ parser.add_argument('--fade_out', type=bool, default=True, help='Apply fade out effect to generated audio')
 
263
 
264
+ parser.add_argument('--fade_out_duration', type=float, default=1.0, help='Fade out duration in seconds')
 
265
 
266
+ parser.add_argument('--trim', type=bool, default=False, help='Trim the silence ending of generated audio')
 
267
 
268
  args = parser.parse_args()
269
 
270
  if not args.model_dir:
271
+ args.model_dir = os.path.join("../../pretrained_models", args.model_name)
272
 
273
  print(args)
274
  return args
 
275
  def main():
276
+ env_variables()
277
  args = get_args()
278
+ model = InspireMusicModel(model_name = args.model_name,
279
  model_dir = args.model_dir,
280
  min_generate_audio_seconds = args.min_generate_audio_seconds,
281
  max_generate_audio_seconds = args.max_generate_audio_seconds,
 
283
  output_sample_rate = args.output_sample_rate,
284
  load_jit = True,
285
  load_onnx = False,
286
+ dtype="fp16",
287
  fast = args.fast,
288
  fp16 = args.fp16,
289
  gpu = args.gpu,
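When no model_dir is given, the new constructor hard-codes the separators for each platform ("..\..\pretrained_models\{model_name}" on Windows, "../../pretrained_models/{model_name}" elsewhere). As a sketch of an alternative, not what the diff does, os.path can build the same relative path portably and also avoids the unescaped backslashes in the Windows f-string:

import os

def default_model_dir(model_name: str) -> str:
    # Resolve ../../pretrained_models/<model_name> relative to the current
    # working directory, letting os.path pick the platform's separator.
    root = os.path.realpath(os.path.join(os.getcwd(), "..", ".."))
    return os.path.join(root, "pretrained_models", model_name)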
inspiremusic/cli/inspiremusic.py CHANGED
@@ -12,32 +12,41 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  import os
 
15
  import time
16
  from tqdm import tqdm
17
  from hyperpyyaml import load_hyperpyyaml
18
  from inspiremusic.cli.frontend import InspireMusicFrontEnd
19
  from inspiremusic.cli.model import InspireMusicModel
20
  from inspiremusic.utils.file_utils import logging
 
21
  import torch
22
 
23
  class InspireMusic:
24
- def __init__(self, model_dir, load_jit=True, load_onnx=False, fast = False, fp16=True, hub="modelscope"):
25
  instruct = True if '-Instruct' in model_dir else False
26
 
27
  if model_dir is None:
28
- model_dir = f"pretrained_models/InspireMusic-1.5B-Long"
 
 
 
29
 
30
- if not os.path.isfile(f"{model_dir}/llm.pt"):
31
  model_name = model_dir.split("/")[-1]
32
  if hub == "modelscope":
33
  from modelscope import snapshot_download
34
  if model_name == "InspireMusic-Base":
35
  snapshot_download(f"iic/InspireMusic", local_dir=model_dir)
36
  else:
37
- snapshot_download(f"iic/{model_name}", local_dir=model_dir)
38
 
39
- assert os.path.exists(f'{model_dir}/inspiremusic.yaml')
40
- with open('{}/inspiremusic.yaml'.format(model_dir), 'r') as f:
41
  configs = load_hyperpyyaml(f)
42
 
43
  self.frontend = InspireMusicFrontEnd(configs,
@@ -47,15 +56,17 @@ class InspireMusic:
47
  '{}/music_tokenizer/'.format(model_dir),
48
  '{}/wavtokenizer/'.format(model_dir),
49
  instruct,
 
50
  fast,
51
  fp16,
52
  configs['allowed_special'])
53
 
54
- self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], fast, fp16)
55
- self.model.load('{}/llm.pt'.format(model_dir),
56
- '{}/flow.pt'.format(model_dir),
57
- '{}/music_tokenizer/'.format(model_dir),
58
- '{}/wavtokenizer/model.pt'.format(model_dir))
 
59
  del configs
60
 
61
  @torch.inference_mode()
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  import os
15
+ import sys
16
  import time
17
  from tqdm import tqdm
18
  from hyperpyyaml import load_hyperpyyaml
19
  from inspiremusic.cli.frontend import InspireMusicFrontEnd
20
  from inspiremusic.cli.model import InspireMusicModel
21
  from inspiremusic.utils.file_utils import logging
22
+ from inspiremusic.utils.utils import download_model
23
  import torch
24
 
25
  class InspireMusic:
26
+ def __init__(self, model_dir, load_jit=True, load_onnx=False, dtype = "fp16", fast = False, fp16=True, hub="modelscope", repo_url=None, token=None):
27
  instruct = True if '-Instruct' in model_dir else False
28
 
29
  if model_dir is None:
30
+ if sys.platform == "win32":
31
+ model_dir = f"..\..\pretrained_models\{model_name}"
32
+ else:
33
+ model_dir = f"../../pretrained_models/{model_name}"
34
 
35
+ if not os.path.isfile(os.path.join(model_dir, "llm.pt")):
36
  model_name = model_dir.split("/")[-1]
37
  if hub == "modelscope":
38
  from modelscope import snapshot_download
39
  if model_name == "InspireMusic-Base":
40
  snapshot_download(f"iic/InspireMusic", local_dir=model_dir)
41
  else:
42
+ snapshot_download(f"iic/InspireMusic", local_dir=model_dir)
43
+ elif hub == "huggingface":
44
+ from huggingface_hub import snapshot_download
45
+ snapshot_download(repo_id=f"FunAudioLLM/{model_name}", local_dir=model_dir)
46
+ else:
47
+ download_model(repo_url, model_dir, token)
48
 
49
+ with open(os.path.join(model_dir, 'inspiremusic.yaml'), 'r') as f:
 
50
  configs = load_hyperpyyaml(f)
51
 
52
  self.frontend = InspireMusicFrontEnd(configs,
 
56
  '{}/music_tokenizer/'.format(model_dir),
57
  '{}/wavtokenizer/'.format(model_dir),
58
  instruct,
59
+ dtype,
60
  fast,
61
  fp16,
62
  configs['allowed_special'])
63
 
64
+ self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], dtype, fast, fp16)
65
+ self.model.load(os.path.join(model_dir, 'llm.pt'),
66
+ os.path.join(model_dir, 'flow.pt'),
67
+ os.path.join(model_dir, 'music_tokenizer'),
68
+ os.path.join(model_dir, 'wavtokenizer', "model.pt"),
69
+ )
70
  del configs
71
 
72
  @torch.inference_mode()
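When llm.pt is missing, the rewritten constructor falls back through three download paths: ModelScope, Hugging Face, or a raw git clone via the new download_model helper. A condensed sketch of that dispatch, assuming the hub names and repo naming used elsewhere in this commit (cli/inference.py spells the ModelScope repo as iic/<model_name> for non-Base models):

def fetch_checkpoint(model_name: str, model_dir: str, hub: str = "modelscope",
                     repo_url: str = None, token: str = None) -> None:
    # Download model weights into model_dir from the selected hub.
    if hub == "modelscope":
        from modelscope import snapshot_download
        repo = "iic/InspireMusic" if model_name == "InspireMusic-Base" else f"iic/{model_name}"
        snapshot_download(repo, local_dir=model_dir)
    elif hub == "huggingface":
        from huggingface_hub import snapshot_download
        snapshot_download(repo_id=f"FunAudioLLM/{model_name}", local_dir=model_dir)
    else:
        from inspiremusic.utils.utils import download_model
        download_model(repo_url, model_dir, token)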
inspiremusic/cli/model.py CHANGED
@@ -11,6 +11,8 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
 
 
14
  import numpy as np
15
  import threading
16
  import time
@@ -21,23 +23,37 @@ from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer
21
  from torch.cuda.amp import autocast
22
  import logging
23
  import torch
24
- import os
25
-
26
 
27
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
28
 
29
  class InspireMusicModel:
30
-
31
  def __init__(self,
32
  llm: torch.nn.Module,
33
  flow: torch.nn.Module,
34
  music_tokenizer: torch.nn.Module,
35
  wavtokenizer: torch.nn.Module,
 
36
  fast: bool = False,
37
  fp16: bool = True,
38
  ):
39
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40
- self.llm = llm
41
  self.flow = flow
42
  self.music_tokenizer = music_tokenizer
43
  self.wavtokenizer = wavtokenizer
@@ -66,7 +82,7 @@ class InspireMusicModel:
66
  def load(self, llm_model, flow_model, hift_model, wavtokenizer_model):
67
  if llm_model is not None:
68
  self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
69
- self.llm.to(self.device).eval()
70
  else:
71
  self.llm = None
72
  if flow_model is not None:
@@ -74,19 +90,15 @@ class InspireMusicModel:
74
  self.flow.to(self.device).eval()
75
  if hift_model is not None:
76
  if ".pt" not in hift_model:
77
- self.music_tokenizer = VQVAE( hift_model + '/config.json',
78
- hift_model + '/model.pt', with_encoder=True)
79
  else:
80
- self.music_tokenizer = VQVAE(os.path.dirname(hift_model) + '/config.json',
81
- hift_model, with_encoder=True)
82
  self.music_tokenizer.to(self.device).eval()
83
  if wavtokenizer_model is not None:
84
  if ".pt" not in wavtokenizer_model:
85
- self.wavtokenizer = WavTokenizer.from_pretrained_feat( wavtokenizer_model + '/config.yaml',
86
- wavtokenizer_model + '/model.pt')
87
  else:
88
- self.wavtokenizer = WavTokenizer.from_pretrained_feat( os.path.dirname(wavtokenizer_model) + '/config.yaml',
89
- wavtokenizer_model )
90
  self.wavtokenizer.to(self.device)
91
 
92
  def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
@@ -110,7 +122,7 @@ class InspireMusicModel:
110
  def llm_job(self, text, audio_token, audio_token_len, prompt_text, llm_prompt_audio_token, embeddings, uuid, duration_to_gen, task):
111
  with self.llm_context:
112
  local_res = []
113
- with autocast(enabled=self.fp16):
114
  inference_kwargs = {
115
  'text': text.to(self.device),
116
  'text_len': torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
+ import os
15
+ import sys
16
  import numpy as np
17
  import threading
18
  import time
 
23
  from torch.cuda.amp import autocast
24
  import logging
25
  import torch
 
 
26
 
27
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
28
 
29
  class InspireMusicModel:
 
30
  def __init__(self,
31
  llm: torch.nn.Module,
32
  flow: torch.nn.Module,
33
  music_tokenizer: torch.nn.Module,
34
  wavtokenizer: torch.nn.Module,
35
+ dtype: str = "fp16",
36
  fast: bool = False,
37
  fp16: bool = True,
38
  ):
39
+
40
+ if torch.cuda.is_available():
41
+ self.device = torch.device('cuda')
42
+ elif torch.backends.mps.is_available():
43
+ self.device = torch.device('mps')
44
+ elif torch.xpu.is_available():
45
+ self.device = torch.device('xpu')
46
+ else:
47
+ self.device = torch.device('cpu')
48
+
49
+ if dtype == "fp16":
50
+ self.dtype = torch.float16
51
+ elif dtype == "bf16":
52
+ self.dtype = torch.bfloat16
53
+ else:
54
+ self.dtype = torch.float32
55
+
56
+ self.llm = llm.to(self.dtype)
57
  self.flow = flow
58
  self.music_tokenizer = music_tokenizer
59
  self.wavtokenizer = wavtokenizer
 
82
  def load(self, llm_model, flow_model, hift_model, wavtokenizer_model):
83
  if llm_model is not None:
84
  self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
85
+ self.llm.to(self.device).to(self.dtype).eval()
86
  else:
87
  self.llm = None
88
  if flow_model is not None:
 
90
  self.flow.to(self.device).eval()
91
  if hift_model is not None:
92
  if ".pt" not in hift_model:
93
+ self.music_tokenizer = VQVAE(os.path.join(hift_model, 'config.json'), os.path.join(hift_model, 'model.pt'), with_encoder=True)
 
94
  else:
95
+ self.music_tokenizer = VQVAE(os.path.join(os.path.dirname(hift_model), 'config.json'), hift_model, with_encoder=True)
 
96
  self.music_tokenizer.to(self.device).eval()
97
  if wavtokenizer_model is not None:
98
  if ".pt" not in wavtokenizer_model:
99
+ self.wavtokenizer = WavTokenizer.from_pretrained_feat(os.path.join(wavtokenizer_model, 'config.yaml'), os.path.join(wavtokenizer_model, 'model.pt'))
 
100
  else:
101
+ self.wavtokenizer = WavTokenizer.from_pretrained_feat(os.path.join(os.path.dirname(wavtokenizer_model), 'config.yaml'), wavtokenizer_model)
 
102
  self.wavtokenizer.to(self.device)
103
 
104
  def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
 
122
  def llm_job(self, text, audio_token, audio_token_len, prompt_text, llm_prompt_audio_token, embeddings, uuid, duration_to_gen, task):
123
  with self.llm_context:
124
  local_res = []
125
+ with autocast(enabled=self.fp16, dtype=self.dtype, cache_enabled=True):
126
  inference_kwargs = {
127
  'text': text.to(self.device),
128
  'text_len': torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
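InspireMusicModel now records a target dtype next to the device, casts the LLM weights to it at load time, and wraps llm_job in autocast(enabled=self.fp16, dtype=self.dtype, cache_enabled=True). The sketch below shows the same pattern with the device-agnostic torch.autocast API instead of the torch.cuda.amp.autocast used in the diff; autocast coverage for cpu/mps/xpu and for fp16 on those backends depends on the installed PyTorch version.

import torch

def mixed_precision_forward(model: torch.nn.Module, x: torch.Tensor,
                            device: torch.device, dtype: torch.dtype = torch.float16):
    # Cast weights once, then run the forward pass under autocast so that
    # matmul-heavy ops execute in reduced precision.
    model = model.to(device).to(dtype).eval()
    with torch.inference_mode():
        with torch.autocast(device_type=device.type, dtype=dtype,
                            enabled=dtype != torch.float32):
            return model(x.to(device))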
inspiremusic/flow/flow.py CHANGED
@@ -39,7 +39,7 @@ class MaskedDiff(torch.nn.Module):
39
  'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
40
  mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 128, 'sampling_rate': 48000,
41
  'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 48000},
42
- generator_model_dir: str = "pretrained_models/InspireMusic-Base/music_tokenizer",
43
  num_codebooks: int = 4
44
  ):
45
  super().__init__()
 
39
  'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
40
  mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 128, 'sampling_rate': 48000,
41
  'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 48000},
42
+ generator_model_dir: str = "../../pretrained_models/InspireMusic-Base/music_tokenizer",
43
  num_codebooks: int = 4
44
  ):
45
  super().__init__()
inspiremusic/llm/llm.py CHANGED
@@ -50,9 +50,19 @@ class LLM(torch.nn.Module):
50
  length_normalized_loss: bool = True,
51
  lsm_weight: float = 0.0,
52
  frozen_input_embed: bool = False,
 
 
53
  **kwargs,
54
  ):
55
  super().__init__()
56
  self.llm_input_size = llm_input_size
57
  self.audio_token_size = audio_token_size
58
  # 1. build text token inputs related modules
@@ -115,34 +125,9 @@ class LLM(torch.nn.Module):
115
 
116
  encoder_name = encoder_conf.pop("name", "transformer")
117
  model = None
118
- if encoder_name == "transformer":
119
- from inspiremusic.transformer.encoder.conformer_encoder import ConformerEncoder
120
- model = ConformerEncoder(
121
- **encoder_conf,
122
- input_size=self.input_size,
123
- use_cnn_module=False,
124
- macaron_style=False,
125
- )
126
- elif encoder_name == "conformer":
127
- from inspiremusic.transformer.encoder.conformer_encoder import ConformerEncoder
128
- model = ConformerEncoder(
129
- **encoder_conf,
130
- input_size=self.input_size,
131
- )
132
- elif encoder_name == "llama_encoder":
133
- from inspiremusic.transformer.encoder.llama_encoder import LlamaEncoder
134
- model = LlamaEncoder(
135
- **encoder_conf,
136
- input_size=self.input_size,
137
- )
138
- elif encoder_name == "qwen2":
139
- from inspiremusic.transformer.encoder.qwen_encoder import QwenEncoder
140
- model = QwenEncoder(
141
- **encoder_conf,
142
- input_size=self.input_size,
143
- )
144
- elif encoder_name == "qwen2.5":
145
- from inspiremusic.transformer.encoder.qwen_encoder import QwenEncoder
146
  model = QwenEncoder(
147
  **encoder_conf,
148
  input_size=self.input_size,
@@ -237,8 +222,7 @@ class LLM(torch.nn.Module):
237
  time_end_embed = self.time_embedding(time_end).to(text_token.dtype)
238
  chorus_embed = self.chorus_embedding(chorus)
239
 
240
- lm_target = [torch.tensor(
241
- [IGNORE_ID] * (4 + text_token_len[i]) + audio_token[i,:audio_token_len[i]].tolist() + [self.audio_token_size]) for i in range(text_token.size(0))]
242
 
243
  lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
244
 
@@ -250,18 +234,9 @@ class LLM(torch.nn.Module):
250
  audio_token = self.speech_embedding(audio_token)
251
 
252
  # 5. unpad and pad
253
- lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb,
254
- [time_start_embed,
255
- time_end_embed,
256
- chorus_embed],
257
- text_token,
258
- text_token_len,
259
- task_id_emb,
260
- audio_token,
261
- audio_token_len,
262
- seg_len)
263
  # 6. run lm forward
264
- lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
265
  logits = self.llm_decoder(lm_output)
266
  loss = self.criterion_ce(logits, lm_target)
267
 
@@ -290,7 +265,7 @@ class LLM(torch.nn.Module):
290
  prompt_audio_token: torch.Tensor,
291
  prompt_audio_token_len: torch.Tensor,
292
  embeddings: List,
293
- duration_to_gen: float = 300,
294
  task: str = "continuation",
295
  token_rate: int = 75,
296
  limit_audio_prompt_len: int = 5,
@@ -317,8 +292,7 @@ class LLM(torch.nn.Module):
317
  time_end_embed = self.time_embedding(time_end).reshape(1, 1, -1) # .half()
318
  chorus_embed = self.chorus_embedding(chorus).reshape(1, 1, -1) # .half()
319
  else:
320
- time_start_embed = self.time_embedding(
321
- time_start.view(-1)).reshape(1, chorus.size(1), -1) # .half()
322
  time_end_embed = self.time_embedding(time_end.view(-1)).reshape(1, chorus.size(1), -1) # .half()
323
  chorus_embed = self.chorus_embedding(chorus) # .half()
324
 
@@ -332,10 +306,10 @@ class LLM(torch.nn.Module):
332
  else:
333
  audio_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
334
 
335
- if prompt_audio_token_len:
336
- prompt_audio_token_emb = self.speech_embedding(prompt_audio_token)
337
- else:
338
- prompt_audio_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
339
  # Check if removing prompt audio token will fail decoding.
340
 
341
  if task == "continuation":
@@ -344,31 +318,18 @@ class LLM(torch.nn.Module):
344
  chorus_embed, text, task_id_emb, audio_token_emb], dim=1)
345
 
346
  if infer_cfg:
347
- audio_cfg = self.speech_embedding(
348
- audio_token.new_zeros(audio_token.shape))
349
- lm_cf_input = torch.concat(
350
- [sos_eos_emb, torch.rand_like(time_start_embed),
351
- torch.rand_like(time_end_embed),
352
- torch.rand_like(chorus_embed), text_cfg, task_id_emb,
353
- audio_cfg], dim=1)
354
  lm_input = torch.cat([lm_input, lm_cf_input], 0)
355
  else:
356
- lm_input = torch.concat(
357
- [sos_eos_emb, time_start_embed, time_end_embed,
358
- chorus_embed, text, task_id_emb], dim=1)
359
  if infer_cfg:
360
- lm_cf_input = torch.concat(
361
- [sos_eos_emb, torch.rand_like(time_start_embed),
362
- torch.rand_like(time_end_embed),
363
- torch.rand_like(chorus_embed), text_cfg, task_id_emb],
364
- dim=1)
365
  lm_input = torch.cat([lm_input, lm_cf_input], 0)
366
 
367
  # 4. cal min/max_length
368
- min_len = 0.9 * duration_to_gen * token_rate
369
  max_len = duration_to_gen * token_rate
370
- logging.info(
371
- f"LLM generation sequence length: {max_len}, generate audio length {duration_to_gen}s.")
372
 
373
  # 5. step by step decode
374
  out_tokens = []
@@ -376,7 +337,7 @@ class LLM(torch.nn.Module):
376
  state = None
377
 
378
  for i in range(int(max_len)):
379
- y_pred, _, state = self.llm.forward_one_step(lm_input, torch.ones(lm_input.shape[0], lm_input.shape[1], device=lm_input.device).to(torch.bool), cache=state)
380
  logits = self.llm_decoder(y_pred[:, -1])
381
  if infer_cfg:
382
  # perform context free guidance
@@ -389,10 +350,7 @@ class LLM(torch.nn.Module):
389
  logp = logp.squeeze(dim=0)
390
 
391
  if i < int(min_len):
392
- logp[self.audio_token_size] = torch.tensor(float('-inf'), dtype=torch.float16)
393
-
394
- if i < int(min_len):
395
- logp[self.audio_token_size] = torch.tensor(float('-inf'), dtype=torch.float16)
396
 
397
  top_ids = self.sampling_ids(logp, out_tokens, ignore_eos=i < min_len).item()
398
 
 
50
  length_normalized_loss: bool = True,
51
  lsm_weight: float = 0.0,
52
  frozen_input_embed: bool = False,
53
+ dtype: str = "fp16",
54
+ text_token_size: int = 151643,
55
  **kwargs,
56
  ):
57
  super().__init__()
58
+
59
+ if dtype == "fp16":
60
+ self.dtype = torch.float16
61
+ elif dtype == "bf16":
62
+ self.dtype = torch.bfloat16
63
+ else:
64
+ self.dtype = torch.float32
65
+
66
  self.llm_input_size = llm_input_size
67
  self.audio_token_size = audio_token_size
68
  # 1. build text token inputs related modules
 
125
 
126
  encoder_name = encoder_conf.pop("name", "transformer")
127
  model = None
128
+
129
+ if "qwen" in encoder_name:
130
+ from inspiremusic.transformer.qwen_encoder import QwenEncoder
131
  model = QwenEncoder(
132
  **encoder_conf,
133
  input_size=self.input_size,
 
222
  time_end_embed = self.time_embedding(time_end).to(text_token.dtype)
223
  chorus_embed = self.chorus_embedding(chorus)
224
 
225
+ lm_target = [torch.tensor([IGNORE_ID] * (4 + text_token_len[i]) + audio_token[i,:audio_token_len[i]].tolist() + [self.audio_token_size]) for i in range(text_token.size(0))]
 
226
 
227
  lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
228
 
 
234
  audio_token = self.speech_embedding(audio_token)
235
 
236
  # 5. unpad and pad
237
+ lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, [time_start_embed, time_end_embed, chorus_embed], text_token, text_token_len, task_id_emb, audio_token, audio_token_len, seg_len)
238
  # 6. run lm forward
239
+ lm_output, lm_output_mask = self.llm(lm_input.to(self.dtype), lm_input_len.to(device))
240
  logits = self.llm_decoder(lm_output)
241
  loss = self.criterion_ce(logits, lm_target)
242
 
 
265
  prompt_audio_token: torch.Tensor,
266
  prompt_audio_token_len: torch.Tensor,
267
  embeddings: List,
268
+ duration_to_gen: float = 30,
269
  task: str = "continuation",
270
  token_rate: int = 75,
271
  limit_audio_prompt_len: int = 5,
 
292
  time_end_embed = self.time_embedding(time_end).reshape(1, 1, -1) # .half()
293
  chorus_embed = self.chorus_embedding(chorus).reshape(1, 1, -1) # .half()
294
  else:
295
+ time_start_embed = self.time_embedding(time_start.view(-1)).reshape(1, chorus.size(1), -1) # .half()
 
296
  time_end_embed = self.time_embedding(time_end.view(-1)).reshape(1, chorus.size(1), -1) # .half()
297
  chorus_embed = self.chorus_embedding(chorus) # .half()
298
 
 
306
  else:
307
  audio_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
308
 
309
+ #if prompt_audio_token_len:
310
+ # prompt_audio_token_emb = self.speech_embedding(prompt_audio_token)
311
+ #else:
312
+ # prompt_audio_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
313
  # Check if removing prompt audio token will fail decoding.
314
 
315
  if task == "continuation":
 
318
  chorus_embed, text, task_id_emb, audio_token_emb], dim=1)
319
 
320
  if infer_cfg:
321
+ audio_cfg = self.speech_embedding(audio_token.new_zeros(audio_token.shape))
322
+ lm_cf_input = torch.concat([sos_eos_emb, torch.rand_like(time_start_embed), torch.rand_like(time_end_embed), torch.rand_like(chorus_embed), text_cfg, task_id_emb, audio_cfg], dim=1)
323
  lm_input = torch.cat([lm_input, lm_cf_input], 0)
324
  else:
325
+ lm_input = torch.concat([sos_eos_emb, time_start_embed, time_end_embed, chorus_embed, text, task_id_emb], dim=1)
 
 
326
  if infer_cfg:
327
+ lm_cf_input = torch.concat([sos_eos_emb, torch.rand_like(time_start_embed), torch.rand_like(time_end_embed), torch.rand_like(chorus_embed), text_cfg, task_id_emb], dim=1)
 
 
 
 
328
  lm_input = torch.cat([lm_input, lm_cf_input], 0)
329
 
330
  # 4. cal min/max_length
331
+ min_len = int(0.9 * duration_to_gen * token_rate)
332
  max_len = duration_to_gen * token_rate
 
 
333
 
334
  # 5. step by step decode
335
  out_tokens = []
 
337
  state = None
338
 
339
  for i in range(int(max_len)):
340
+ y_pred, _, state = self.llm.forward_one_step(lm_input.to(self.dtype), torch.ones(lm_input.shape[0], lm_input.shape[1], device=lm_input.device).to(torch.bool), cache=state)
341
  logits = self.llm_decoder(y_pred[:, -1])
342
  if infer_cfg:
343
  # perform context free guidance
 
350
  logp = logp.squeeze(dim=0)
351
 
352
  if i < int(min_len):
353
+ logp[self.audio_token_size] = torch.tensor(float('-inf'), dtype=self.dtype)
 
 
 
354
 
355
  top_ids = self.sampling_ids(logp, out_tokens, ignore_eos=i < min_len).item()
356
 
inspiremusic/transformer/qwen_encoder.py CHANGED
@@ -22,6 +22,7 @@ class QwenEncoder(nn.Module):
22
  def __init__(
23
  self,
24
  input_size: int,
 
25
  pretrain_path: str = "Qwen/Qwen2.0-0.5B",
26
  trainable: bool = False,
27
  do_fusion_emb: bool = False,
@@ -30,7 +31,15 @@ class QwenEncoder(nn.Module):
30
  super(QwenEncoder, self).__init__()
31
  self.input_size = input_size
32
  self.trainable = trainable
33
- self.model = AutoModelForCausalLM.from_pretrained(pretrain_path, device_map="cpu")
34
  self._output_size = self.model.config.hidden_size
35
  self.do_fusion_emb = do_fusion_emb
36
  self.hidden_norm = torch.nn.LayerNorm(self._output_size)
@@ -88,14 +97,19 @@ class QwenEmbeddingEncoder(nn.Module):
88
  def __init__(
89
  self,
90
  input_size: int,
 
91
  pretrain_path: str = "Qwen/Qwen2.0-0.5B",
92
  ):
93
  super(QwenEmbeddingEncoder, self).__init__()
94
  self.input_size = input_size
95
  from transformers import Qwen2ForCausalLM
96
- # self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="cpu", attn_implementation="flash_attention_2")
97
- self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path,
98
- device_map="cpu")
99
  self._output_size = self.model.config.hidden_size
100
 
101
  def output_size(self) -> int:
@@ -137,14 +151,19 @@ class QwenInputOnlyEncoder(nn.Module):
137
  def __init__(
138
  self,
139
  input_size: int,
 
140
  pretrain_path: str = "Qwen/Qwen2.0-0.5B",
141
  ):
142
  super(QwenInputOnlyEncoder, self).__init__()
143
144
  from transformers import Qwen2ForCausalLM
145
- # model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="cpu", attn_implementation="flash_attention_2")
146
- model = Qwen2ForCausalLM.from_pretrained(pretrain_path,
147
- device_map="cpu")
148
  self.embed = model.model.embed_tokens
149
  for p in self.embed.parameters():
150
  p.requires_grad = False
 
22
  def __init__(
23
  self,
24
  input_size: int,
25
+ dtype: str = "fp16",
26
  pretrain_path: str = "Qwen/Qwen2.0-0.5B",
27
  trainable: bool = False,
28
  do_fusion_emb: bool = False,
 
31
  super(QwenEncoder, self).__init__()
32
  self.input_size = input_size
33
  self.trainable = trainable
34
+
35
+ if dtype == "fp16":
36
+ self.dtype = torch.float16
37
+ elif dtype == "bf16":
38
+ self.dtype = torch.bfloat16
39
+ else:
40
+ self.dtype = torch.float32
41
+
42
+ self.model = AutoModelForCausalLM.from_pretrained(pretrain_path, device_map="auto", attn_implementation="flash_attention_2", torch_dtype=self.dtype)
43
  self._output_size = self.model.config.hidden_size
44
  self.do_fusion_emb = do_fusion_emb
45
  self.hidden_norm = torch.nn.LayerNorm(self._output_size)
 
97
  def __init__(
98
  self,
99
  input_size: int,
100
+ dtype: str = "fp16",
101
  pretrain_path: str = "Qwen/Qwen2.0-0.5B",
102
  ):
103
  super(QwenEmbeddingEncoder, self).__init__()
104
  self.input_size = input_size
105
+ if dtype == "fp16":
106
+ self.dtype = torch.float16
107
+ elif dtype == "bf16":
108
+ self.dtype = torch.bfloat16
109
+ else:
110
+ self.dtype = torch.float32
111
  from transformers import Qwen2ForCausalLM
112
+ self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="auto", attn_implementation="flash_attention_2", torch_dtype=self.dtype)
 
 
113
  self._output_size = self.model.config.hidden_size
114
 
115
  def output_size(self) -> int:
 
151
  def __init__(
152
  self,
153
  input_size: int,
154
+ dtype: str = "fp16",
155
  pretrain_path: str = "Qwen/Qwen2.0-0.5B",
156
  ):
157
  super(QwenInputOnlyEncoder, self).__init__()
158
  self.input_size = input_size
159
+ if dtype == "fp16":
160
+ self.dtype = torch.float16
161
+ elif dtype == "bf16":
162
+ self.dtype = torch.bfloat16
163
+ else:
164
+ self.dtype = torch.float32
165
  from transformers import Qwen2ForCausalLM
166
+ model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="auto", attn_implementation="flash_attention_2", torch_dtype=self.dtype)
 
 
167
  self.embed = model.model.embed_tokens
168
  for p in self.embed.parameters():
169
  p.requires_grad = False
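All three Qwen encoders now load their backbone with device_map="auto", an explicit torch_dtype, and attn_implementation="flash_attention_2". FlashAttention-2 is an optional dependency, so here is a hedged sketch with a fallback to the stock attention kernels; the try/except is my addition and assumes a transformers release with attn_implementation support (roughly 4.36+).

import torch
from transformers import AutoModelForCausalLM

def load_qwen_backbone(pretrain_path: str, dtype: torch.dtype = torch.float16):
    # Prefer FlashAttention-2 when flash-attn is installed, otherwise fall back.
    try:
        return AutoModelForCausalLM.from_pretrained(
            pretrain_path,
            device_map="auto",
            torch_dtype=dtype,
            attn_implementation="flash_attention_2",
        )
    except (ImportError, ValueError):
        return AutoModelForCausalLM.from_pretrained(
            pretrain_path, device_map="auto", torch_dtype=dtype
        )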
inspiremusic/utils/common.py CHANGED
@@ -16,12 +16,9 @@
16
  """Unility functions for Transformer."""
17
 
18
  from typing import List
19
-
20
  import torch
21
  IGNORE_ID = -1
22
 
23
- MUSIC_STRUCTURE_LABELS = ["intro", "verse1", "chorus", "verse2", "outro"]
24
-
25
  def pad_list(xs: List[torch.Tensor], pad_value: int):
26
  """Perform padding for the list of tensors.
27
 
@@ -92,16 +89,61 @@ def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
92
  denominator = torch.sum(mask)
93
  return (numerator / denominator).detach()
94
 
95
-
96
  def get_padding(kernel_size, dilation=1):
97
  return int((kernel_size * dilation - dilation) / 2)
98
 
99
-
100
  def init_weights(m, mean=0.0, std=0.01):
101
  classname = m.__class__.__name__
102
  if classname.find("Conv") != -1:
103
  m.weight.data.normal_(mean, std)
104
 
105
  def topk_sampling(weighted_scores, decoded_tokens, top_k=25):
106
  zeros = weighted_scores.new_ones(weighted_scores.shape) * float('-inf')
107
  values,indices = torch.topk(weighted_scores,top_k)
 
16
  """Unility functions for Transformer."""
17
 
18
  from typing import List
 
19
  import torch
20
  IGNORE_ID = -1
21
 
 
 
22
  def pad_list(xs: List[torch.Tensor], pad_value: int):
23
  """Perform padding for the list of tensors.
24
 
 
89
  denominator = torch.sum(mask)
90
  return (numerator / denominator).detach()
91
 
 
92
  def get_padding(kernel_size, dilation=1):
93
  return int((kernel_size * dilation - dilation) / 2)
94
 
 
95
  def init_weights(m, mean=0.0, std=0.01):
96
  classname = m.__class__.__name__
97
  if classname.find("Conv") != -1:
98
  m.weight.data.normal_(mean, std)
99
 
100
+ def keep_rhythm(next_token, current_time_signature):
101
+ allowed_durations = get_allowed_durations(current_time_signature)
102
+ if next_token not in allowed_durations:
103
+ next_token = random.choice(allowed_durations)
104
+ return next_token
105
+
106
+ def keep_harmony(next_token, current_chord):
107
+ allowed_notes = get_allowed_notes(current_chord) # Define allowed notes for the chord
108
+ if next_token not in allowed_notes:
109
+ next_token = random.choice(allowed_notes) # Replace with a valid note
110
+ return next_token
111
+
112
+ def relieve_repetition(weighted_scores, recent_tokens, repetition_penalty=1.2):
113
+ for token in recent_tokens:
114
+ if weighted_scores[token] > 0:
115
+ weighted_scores[token] /= repetition_penalty
116
+ return weighted_scores
117
+
118
+ def top_p_sampling_with_constraints(weighted_scores, decoded_tokens, top_p=0.85, temperature=1.1, current_chord=None, current_time_signature=None, recent_tokens=None):
119
+ # Apply temperature scaling
120
+ weighted_scores = weighted_scores ** (1 / temperature)
121
+ weighted_scores /= weighted_scores.sum()
122
+
123
+ if recent_tokens:
124
+ weighted_scores = relieve_repetition(weighted_scores, recent_tokens)
125
+
126
+ # Sort weighted scores in descending order
127
+ sorted_weighted_scores, _ = torch.sort(weighted_scores, descending=True)
128
+
129
+ # Compute cumulative weighted scores
130
+ cumulative_weighted_scores = torch.cumsum(sorted_weighted_scores, dim=0)
131
+
132
+ # Find the threthold index of top-p
133
+ cutoff_index = torch.where(cumulative_weighted_scores >= top_p)[0][0]
134
+ selected_weighted_scores = sorted_weighted_scores[:cutoff_index + 1]
135
+
136
+ # Apply domain-specific constraints
137
+ if current_chord:
138
+ selected_weighted_scores = keep_harmony(selected_weighted_scores, current_chord)
139
+ if current_time_signature:
140
+ selected_weighted_scores = keep_rhythm(selected_weighted_scores, current_time_signature)
141
+
142
+ # Normalize selected probabilities
143
+ selected_weighted_scores /= selected_weighted_scores.sum()
144
+
145
+ # Sample top-p tokens from the distribution
146
+ return random_sampling(selected_weighted_scores, decoded_tokens)
147
  def topk_sampling(weighted_scores, decoded_tokens, top_k=25):
148
  zeros = weighted_scores.new_ones(weighted_scores.shape) * float('-inf')
149
  values,indices = torch.topk(weighted_scores,top_k)
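The new top_p_sampling_with_constraints combines temperature scaling, a repetition penalty, a top-p cutoff, and optional harmony/rhythm constraints; its helpers (get_allowed_durations, get_allowed_notes, random_sampling) and the random import are not part of this hunk. For reference, a minimal sketch of the unconstrained core over a 1-D non-negative score vector:

import torch

def top_p_sample(weighted_scores: torch.Tensor, top_p: float = 0.85,
                 temperature: float = 1.1) -> int:
    # Temperature-scale the scores, keep the smallest prefix whose cumulative
    # probability reaches top_p, renormalize, and sample one token id.
    scores = weighted_scores.clamp_min(1e-12) ** (1.0 / temperature)
    probs = scores / scores.sum()
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=0)
    cutoff = int(torch.searchsorted(cumulative, top_p).item()) + 1
    kept = sorted_probs[:cutoff] / sorted_probs[:cutoff].sum()
    choice = torch.multinomial(kept, 1).item()
    return int(sorted_ids[choice].item())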
inspiremusic/utils/executor.py CHANGED
@@ -24,13 +24,19 @@ from inspiremusic.utils.train_utils import update_parameter_and_lr, log_per_step
24
  from torch.cuda.amp import GradScaler, autocast
25
 
26
  class Executor:
27
-
28
  def __init__(self):
29
  self.step = 0
30
  self.epoch = 0
31
  self.rank = int(os.environ.get('RANK', 0))
32
- self.device = torch.device('cuda:{}'.format(self.rank))
33
-
 
 
  def train_one_epoch(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join, scaler=None):
35
  ''' Train one epoch
36
  '''
 
24
  from torch.cuda.amp import GradScaler, autocast
25
 
26
  class Executor:
 
27
  def __init__(self):
28
  self.step = 0
29
  self.epoch = 0
30
  self.rank = int(os.environ.get('RANK', 0))
31
+ if torch.cuda.is_available():
32
+ if torch.cuda.is_available():
33
+ self.device = torch.device('cuda:{}'.format(self.rank))
34
+ elif torch.backends.mps.is_available():
35
+ self.device = torch.device('mps')
36
+ elif torch.xpu.is_available():
37
+ self.device = torch.device('xpu')
38
+ else:
39
+ self.device = torch.device('cpu')
40
  def train_one_epoch(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join, scaler=None):
41
  ''' Train one epoch
42
  '''
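The Executor keeps binding to cuda:<rank> when CUDA is present and otherwise falls back like the other files in this commit; as written, the hunk tests torch.cuda.is_available() twice, in both the outer and inner if. A compact sketch of the intended per-rank binding, with the redundant check dropped and torch.xpu guarded as before:

import os
import torch

def rank_device(rank: int = None) -> torch.device:
    # Bind this process to cuda:<rank> if possible, else probe other backends.
    rank = int(os.environ.get("RANK", 0)) if rank is None else rank
    if torch.cuda.is_available():
        return torch.device(f"cuda:{rank}")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")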
inspiremusic/utils/utils.py CHANGED
@@ -1,5 +1,27 @@
1
  import os
2
  import sys
3
 
4
  def align_trans_scp_file(trans, scp):
5
  trans_dict = {}
@@ -14,9 +36,4 @@ def align_trans_scp_file(trans, scp):
14
  scp_dict[sec[0]] = sec[1]
15
  with open("text", "w") as f:
16
  for k, v in scp_dict.items():
17
- f.write("%s\t%s\n"%(k,trans_dict[k]))
18
-
19
- if __name__ == '__main__':
20
- trans = sys.argv[1]
21
- scp = sys.argv[2]
22
- align_trans_scp_file(trans, scp)
 
1
  import os
2
  import sys
3
+ import subprocess
4
+
5
+ def download_model(repo_url: str, output_dir: str = None, token: str = None):
6
+ try:
7
+ if token:
8
+ repo_url = repo_url.replace("https://", f"https://USER:{token}@")
9
+ else:
10
+ repo_url = f"https://www.modelscope.cn/models/iic/{repo_url}"
11
+
12
+ cmd = ["git", "clone", repo_url]
13
+ if output_dir:
14
+ cmd.append(output_dir)
15
+
16
+ result = subprocess.run(
17
+ cmd,
18
+ check=True,
19
+ capture_output=True,
20
+ text=True
21
+ )
22
+ print("Success:", result.stdout)
23
+ except subprocess.CalledProcessError as e:
24
+ print("Error:", e.stderr)
25
 
26
  def align_trans_scp_file(trans, scp):
27
  trans_dict = {}
 
36
  scp_dict[sec[0]] = sec[1]
37
  with open("text", "w") as f:
38
  for k, v in scp_dict.items():
39
+ f.write("%s\t%s\n"%(k,trans_dict[k]))
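The new download_model helper shells out to git clone, embedding the access token in the URL when one is given and otherwise prefixing the ModelScope base URL, so the positional argument is a bare model name in the no-token case. A hypothetical usage sketch; the destination paths and token value are placeholders:

# Public repo by name: resolves to https://www.modelscope.cn/models/iic/InspireMusic-1.5B-Long
download_model("InspireMusic-1.5B-Long",
               output_dir="pretrained_models/InspireMusic-1.5B-Long")

# Private repo via full URL, with the token embedded into the clone URL
download_model("https://www.modelscope.cn/models/iic/InspireMusic-1.5B-Long.git",
               output_dir="pretrained_models/InspireMusic-1.5B-Long",
               token="<personal-access-token>")

Note that checkpoints on these hubs are stored with Git LFS, so the clone only retrieves pointer files unless git-lfs is installed; that caveat is mine, not stated in the diff.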
inspiremusic/wavtokenizer/.DS_Store DELETED
Binary file (6.15 kB)