Spaces: Running
chong.zhang committed on
Commit · 2c50d95
1 Parent(s): b6363bb
update
Browse files
- inspiremusic/.DS_Store +0 -0
- inspiremusic/bin/inference.py +15 -6
- inspiremusic/cli/frontend.py +4 -5
- inspiremusic/cli/inference.py +59 -65
- inspiremusic/cli/inspiremusic.py +22 -11
- inspiremusic/cli/model.py +27 -15
- inspiremusic/flow/flow.py +1 -1
- inspiremusic/llm/llm.py +29 -71
- inspiremusic/transformer/qwen_encoder.py +26 -7
- inspiremusic/utils/common.py +47 -5
- inspiremusic/utils/executor.py +9 -3
- inspiremusic/utils/utils.py +23 -6
- inspiremusic/wavtokenizer/.DS_Store +0 -0
inspiremusic/.DS_Store
DELETED
Binary file (8.2 kB)
inspiremusic/bin/inference.py
CHANGED
@@ -28,7 +28,6 @@ from inspiremusic.cli.model import InspireMusicModel
 from inspiremusic.dataset.dataset import Dataset
 import time
 from inspiremusic.utils.audio_utils import trim_audio, fade_out, process_audio
-from inspiremusic.utils.common import MUSIC_STRUCTURE_LABELS
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
@@ -42,6 +41,7 @@ def get_args():
     parser.add_argument('--wavtokenizer', required=True, help='wavtokenizer model file')
     parser.add_argument('--chorus', default="random",required=False, help='chorus tag generation mode, eg. random, verse, chorus, intro.')
     parser.add_argument('--fast', action='store_true', required=False, help='True: fast inference mode, without flow matching for fast inference. False: normal inference mode, with flow matching for high quality.')
+    parser.add_argument('--dtype', type=str, default="fp16", required=False, choices=["fp16", "bf16", "fp32"], help='data type')
     parser.add_argument('--fp16', default=True, type=bool, required=False, help='inference with fp16 model')
     parser.add_argument('--fade_out', default=True, type=bool, required=False, help='add fade out effect to generated audio')
     parser.add_argument('--fade_out_duration', default=1.0, type=float, required=False, help='fade out duration in seconds')
@@ -53,7 +53,7 @@ def get_args():
                         help='sampling rate of input audio')
     parser.add_argument('--output_sample_rate', type=int, default=48000, required=False, choices=[24000, 48000],
                         help='sampling rate of generated output audio')
-    parser.add_argument('--min_generate_audio_seconds', type=float, default=
+    parser.add_argument('--min_generate_audio_seconds', type=float, default=0.0, required=False,
                         help='the minimum generated audio length in seconds')
     parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0, required=False,
                         help='the maximum generated audio length in seconds')
@@ -70,9 +70,9 @@ def get_args():
     print(args)
     return args
 
-
 def main():
     args = get_args()
+    chorus_labels = ["intro", "verse1", "chorus", "verse2", "outro"]
     logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(message)s')
     os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
 
@@ -85,11 +85,20 @@ def main():
 
     # Init inspiremusic models from configs
     use_cuda = args.gpu >= 0 and torch.cuda.is_available()
-
+    if args.gpu >=0:
+        if torch.cuda.is_available():
+            device = torch.device('cuda')
+        elif torch.backends.mps.is_available():
+            device = torch.device('mps')
+        elif torch.xpu.is_available():
+            device = torch.device('xpu')
+        else:
+            device = torch.device('cpu')
+
     with open(args.config, 'r') as f:
         configs = load_hyperpyyaml(f)
 
-    model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], args.fast, args.fp16)
+    model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], args.dtype, args.fast, args.fp16)
 
     model.load(args.llm_model, args.flow_model, args.music_tokenizer, args.wavtokenizer)
 
@@ -153,7 +162,7 @@ def main():
             time_end = batch["time_end"].to(device)
             chorus = batch["chorus"].to(torch.int)
 
-            text_prompt = f"<|{batch['time_start'].numpy()[0]}|><|{
+            text_prompt = f"<|{batch['time_start'].numpy()[0]}|><|{chorus_labels[chorus.numpy()[0]]}|><|{batch['text'][0]}|><|{batch['time_end'].numpy()[0]}|>"
             chorus = chorus.to(device)
 
             if batch["acoustic_token"] is None:
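Aside (not part of the commit): the device fallback introduced above is easy to exercise on its own. A minimal sketch, using a hypothetical helper name `pick_device`, that mirrors the CUDA → MPS → XPU → CPU order used in the diff:

import torch

def pick_device(gpu: int) -> torch.device:
    # gpu < 0 forces CPU; otherwise prefer CUDA, then Apple MPS, then Intel XPU.
    if gpu >= 0:
        if torch.cuda.is_available():
            return torch.device('cuda')
        if torch.backends.mps.is_available():
            return torch.device('mps')
        if hasattr(torch, 'xpu') and torch.xpu.is_available():
            return torch.device('xpu')
    return torch.device('cpu')

print(pick_device(0))  # e.g. cuda on a GPU machine, cpu otherwise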
inspiremusic/cli/frontend.py
CHANGED
@@ -29,6 +29,7 @@ class InspireMusicFrontEnd:
                  music_tokenizer_dir: str,
                  audio_tokenizer_dir: str,
                  instruct: bool = False,
+                 dtype: str = "fp16",
                  fast: bool = False,
                  fp16: bool = True,
                  allowed_special: str = 'all'):
@@ -39,7 +40,7 @@ class InspireMusicFrontEnd:
         self.bandwidth_id = torch.tensor([0]).to(self.device)
         self.wavtokenizer = WavTokenizer.from_pretrained_feat(f"{audio_tokenizer_dir}/config.yaml", f"{audio_tokenizer_dir}/model.pt").to(self.device)
 
-        self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], fast, fp16)
+        self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], dtype, fast, fp16)
         self.model = self.model.load(llm_model, flow_model, music_tokenizer_dir, audio_tokenizer_dir)
 
         self.instruct = instruct
@@ -69,12 +70,10 @@ class InspireMusicFrontEnd:
             text = text.replace(" - ", ",")
             text = remove_bracket(text)
             text = re.sub(r'[,,]+$', '。', text)
-            texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
-                                         token_min_n=60, merge_len=20, comma_split=False))
+            texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False))
         else:
             text = spell_out_number(text, self.inflect_parser)
-            texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
-                                         token_min_n=60, merge_len=20, comma_split=False))
+            texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False))
         if split is False:
             return text
         return texts
inspiremusic/cli/inference.py
CHANGED
@@ -23,53 +23,60 @@ from inspiremusic.utils.file_utils import logging
 import torch
 from inspiremusic.utils.audio_utils import trim_audio, fade_out, process_audio
 
-def
+def env_variables():
     os.environ['PYTHONIOENCODING'] = 'UTF-8'
     os.environ['TOKENIZERS_PARALLELISM'] = 'False'
-
+    current_working_dir = os.getcwd()
+    main_root = os.path.realpath(os.path.join(current_working_dir, '../../'))
     bin_dir = os.path.join(main_root, 'inspiremusic')
     third_party_matcha_tts_path = os.path.join(main_root, 'third_party', 'Matcha-TTS')
     python_path = f"{main_root}:{bin_dir}:{third_party_matcha_tts_path}:{os.environ.get('PYTHONPATH', '')}"
-    os.environ['
+    os.environ['PYTHONPATH'] = python_path
     sys.path.extend([main_root, third_party_matcha_tts_path])
 
-class
+class InspireMusicModel:
     def __init__(self,
-                 model_name: str
+                 model_name: str,
                  model_dir: str = None,
-                 min_generate_audio_seconds: float =
+                 min_generate_audio_seconds: float = 0.0,
                  max_generate_audio_seconds: float = 30.0,
                  sample_rate: int = 24000,
                  output_sample_rate: int = 48000,
                  load_jit: bool = True,
                  load_onnx: bool = False,
+                 dtype: str = "fp16",
                  fast: bool = False,
                  fp16: bool = True,
-                 gpu: int =
+                 gpu: int = 1,
                  result_dir: str = None,
-                 hub="modelscope"
+                 hub="modelscope",
+                 repo_url=None,
+                 token=None):
         os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
 
         # Set model_dir or default to downloading if it doesn't exist
         if model_dir is None:
-
-
-
+            if sys.platform == "win32":
+                model_dir = f"..\..\pretrained_models\{model_name}"
+            else:
+                model_dir = f"../../pretrained_models/{model_name}"
 
-        if not os.path.isfile(
+        if not os.path.isfile(os.path.join(model_dir, "llm.pt")):
             if hub == "modelscope":
                 from modelscope import snapshot_download
                 if model_name == "InspireMusic-Base":
                     snapshot_download(f"iic/InspireMusic", local_dir=model_dir)
                 else:
                     snapshot_download(f"iic/{model_name}", local_dir=model_dir)
+            elif hub == "huggingface":
+                from huggingface_hub import snapshot_download
+                snapshot_download(repo_id=f"FunAudioLLM/{model_name}", local_dir=model_dir)
 
         self.model_dir = model_dir
-        print(self.model_dir)
 
         self.sample_rate = sample_rate
         self.output_sample_rate = 24000 if fast else output_sample_rate
-        self.result_dir = result_dir or
+        self.result_dir = result_dir or os.path.join("exp", model_name)
         os.makedirs(self.result_dir, exist_ok=True)
 
         self.min_generate_audio_seconds = min_generate_audio_seconds
@@ -79,9 +86,17 @@ class InspireMusicUnified:
         assert self.min_generate_audio_seconds <= self.max_generate_audio_seconds, "Min audio seconds must be less than or equal to max audio seconds"
 
         use_cuda = gpu >= 0 and torch.cuda.is_available()
-
-
-
+        if gpu >=0:
+            if torch.cuda.is_available():
+                self.device = torch.device('cuda')
+            elif torch.backends.mps.is_available():
+                self.device = torch.device('mps')
+            elif torch.xpu.is_available():
+                self.device = torch.device('xpu')
+            else:
+                self.device = torch.device('cpu')
+
+        self.model = InspireMusic(self.model_dir, load_jit=load_jit, load_onnx=load_onnx, dtype=dtype, fast=fast, fp16=fp16)
 
         logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
@@ -90,6 +105,7 @@ class InspireMusicUnified:
                   task: str = 'text-to-music',
                   text: str = None,
                   audio_prompt: str = None, # audio prompt file path
+                  instruct: str = None,
                   chorus: str = "verse",
                   time_start: float = 0.0,
                   time_end: float = 30.0,
@@ -205,84 +221,61 @@ class InspireMusicUnified:
 
 def get_args():
     parser = argparse.ArgumentParser(description='Run inference with your model')
-    parser.add_argument('-m', '--model_name', default="InspireMusic-1.5B-Long",
-                        help='Model name')
+    parser.add_argument('-m', '--model_name', default="InspireMusic-1.5B-Long", help='Model name')
 
-    parser.add_argument('-d', '--model_dir',
-                        help='Model folder path')
+    parser.add_argument('-d', '--model_dir', help='Model folder path')
 
-    parser.add_argument('-t', '--text', default="Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.",
-                        help='Prompt text')
+    parser.add_argument('-t', '--text', default="Experience soothing and sensual instrumental jazz with a touch of Bossa Nova, perfect for a relaxing restaurant or spa ambiance.", help='Prompt text')
 
-    parser.add_argument('-a', '--audio_prompt', default=None,
-                        help='Prompt audio')
+    parser.add_argument('-a', '--audio_prompt', default=None, help='Prompt audio')
 
-    parser.add_argument('-c', '--chorus', default="intro",
-                        help='Chorus tag generation mode (e.g., random, verse, chorus, intro, outro)')
+    parser.add_argument('-c', '--chorus', default="intro", help='Chorus tag generation mode (e.g., random, verse, chorus, intro, outro)')
 
-    parser.add_argument('-f', '--fast', type=bool, default=False,
-                        help='Enable fast inference mode (without flow matching)')
+    parser.add_argument('-f', '--fast', type=bool, default=False, help='Enable fast inference mode (without flow matching)')
 
-    parser.add_argument('-g', '--gpu', type=int, default=
-                        help='GPU ID for this rank, -1 for CPU')
+    parser.add_argument('-g', '--gpu', type=int, default=1, help='GPU ID for this rank, -1 for CPU')
 
-    parser.add_argument('--task', default='text-to-music', choices=['text-to-music', 'continuation', 'reconstruct', 'super_resolution'],
-                        help='Inference task type: text-to-music, continuation, reconstruct, super_resolution')
+    parser.add_argument('--task', default='text-to-music', choices=['text-to-music', 'continuation', 'reconstruct', 'super_resolution'], help='Inference task type: text-to-music, continuation, reconstruct, super_resolution')
 
-    parser.add_argument('-r', '--result_dir', default="exp/inspiremusic",
-                        help='Directory to save generated audio')
+    parser.add_argument('-r', '--result_dir', default="exp/inspiremusic", help='Directory to save generated audio')
 
-    parser.add_argument('-o', '--output_fn', default="output_audio",
-                        help='Output file name')
+    parser.add_argument('-o', '--output_fn', default="output_audio", help='Output file name')
 
-    parser.add_argument('--format', type=str, default="wav", choices=["wav", "mp3", "m4a", "flac"],
-                        help='Format of output audio')
+    parser.add_argument('--format', type=str, default="wav", choices=["wav", "mp3", "m4a", "flac"], help='Format of output audio')
 
-    parser.add_argument('--sample_rate', type=int, default=24000,
-                        help='Sampling rate of input audio')
+    parser.add_argument('--sample_rate', type=int, default=24000, help='Sampling rate of input audio')
 
-    parser.add_argument('--output_sample_rate', type=int, default=48000, choices=[24000, 48000],
-                        help='Sampling rate of generated output audio')
+    parser.add_argument('--output_sample_rate', type=int, default=48000, choices=[24000, 48000], help='Sampling rate of generated output audio')
 
-    parser.add_argument('-s', '--time_start', type=float, default=0.0,
-                        help='Start time in seconds')
+    parser.add_argument('-s', '--time_start', type=float, default=0.0, help='Start time in seconds')
 
-    parser.add_argument('-e', '--time_end', type=float, default=30.0,
-                        help='End time in seconds')
+    parser.add_argument('-e', '--time_end', type=float, default=30.0, help='End time in seconds')
 
-    parser.add_argument('--max_audio_prompt_length', type=float, default=5.0,
-                        help='Maximum audio prompt length in seconds')
+    parser.add_argument('--max_audio_prompt_length', type=float, default=5.0, help='Maximum audio prompt length in seconds')
 
-    parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0,
-                        help='Minimum generated audio length in seconds')
+    parser.add_argument('--min_generate_audio_seconds', type=float, default=10.0, help='Minimum generated audio length in seconds')
 
-    parser.add_argument('--max_generate_audio_seconds', type=float, default=
-                        help='Maximum generated audio length in seconds')
+    parser.add_argument('--max_generate_audio_seconds', type=float, default=30.0, help='Maximum generated audio length in seconds')
 
-    parser.add_argument('--fp16', type=bool, default=True,
-                        help='Inference with fp16 model')
+    parser.add_argument('--fp16', type=bool, default=True, help='Inference with fp16 model')
 
-    parser.add_argument('--fade_out', type=bool, default=True,
-                        help='Apply fade out effect to generated audio')
+    parser.add_argument('--fade_out', type=bool, default=True, help='Apply fade out effect to generated audio')
 
-    parser.add_argument('--fade_out_duration', type=float, default=1.0,
-                        help='Fade out duration in seconds')
+    parser.add_argument('--fade_out_duration', type=float, default=1.0, help='Fade out duration in seconds')
 
-    parser.add_argument('--trim', type=bool, default=False,
-                        help='Trim the silence ending of generated audio')
+    parser.add_argument('--trim', type=bool, default=False, help='Trim the silence ending of generated audio')
 
     args = parser.parse_args()
 
     if not args.model_dir:
-        args.model_dir = os.path.join("pretrained_models", args.model_name)
+        args.model_dir = os.path.join("../../pretrained_models", args.model_name)
 
     print(args)
     return args
-
 def main():
-
+    env_variables()
     args = get_args()
-    model =
+    model = InspireMusicModel(model_name = args.model_name,
                               model_dir = args.model_dir,
                               min_generate_audio_seconds = args.min_generate_audio_seconds,
                               max_generate_audio_seconds = args.max_generate_audio_seconds,
@@ -290,6 +283,7 @@ def main():
                               output_sample_rate = args.output_sample_rate,
                               load_jit = True,
                               load_onnx = False,
+                              dtype="fp16",
                               fast = args.fast,
                               fp16 = args.fp16,
                               gpu = args.gpu,
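Aside (not part of the commit): the hub branch added above fetches checkpoints on first use. A minimal sketch of the same idea, assuming the `modelscope` and `huggingface_hub` packages are installed and the repo ids follow the naming used in the diff; the helper name `fetch_checkpoints` is illustrative:

import os

def fetch_checkpoints(model_name: str, model_dir: str, hub: str = "modelscope") -> str:
    # Only download when the LLM checkpoint is missing, as the diff does.
    if not os.path.isfile(os.path.join(model_dir, "llm.pt")):
        if hub == "modelscope":
            from modelscope import snapshot_download
            snapshot_download(f"iic/{model_name}", local_dir=model_dir)
        elif hub == "huggingface":
            from huggingface_hub import snapshot_download
            snapshot_download(repo_id=f"FunAudioLLM/{model_name}", local_dir=model_dir)
    return model_dir

# fetch_checkpoints("InspireMusic-1.5B-Long", "pretrained_models/InspireMusic-1.5B-Long")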
inspiremusic/cli/inspiremusic.py
CHANGED
@@ -12,32 +12,41 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import sys
 import time
 from tqdm import tqdm
 from hyperpyyaml import load_hyperpyyaml
 from inspiremusic.cli.frontend import InspireMusicFrontEnd
 from inspiremusic.cli.model import InspireMusicModel
 from inspiremusic.utils.file_utils import logging
+from inspiremusic.utils.utils import download_model
 import torch
 
 class InspireMusic:
-    def __init__(self, model_dir, load_jit=True, load_onnx=False, fast = False, fp16=True, hub="modelscope"):
+    def __init__(self, model_dir, load_jit=True, load_onnx=False, dtype = "fp16", fast = False, fp16=True, hub="modelscope", repo_url=None, token=None):
         instruct = True if '-Instruct' in model_dir else False
 
         if model_dir is None:
-
+            if sys.platform == "win32":
+                model_dir = f"..\..\pretrained_models\{model_name}"
+            else:
+                model_dir = f"../../pretrained_models/{model_name}"
 
-        if not os.path.isfile(
+        if not os.path.isfile(os.path.join(model_dir, "llm.pt")):
             model_name = model_dir.split("/")[-1]
             if hub == "modelscope":
                 from modelscope import snapshot_download
                 if model_name == "InspireMusic-Base":
                     snapshot_download(f"iic/InspireMusic", local_dir=model_dir)
                 else:
-                    snapshot_download(f"iic/
+                    snapshot_download(f"iic/InspireMusic", local_dir=model_dir)
+            elif hub == "huggingface":
+                from huggingface_hub import snapshot_download
+                snapshot_download(repo_id=f"FunAudioLLM/{model_name}", local_dir=model_dir)
+            else:
+                download_model(repo_url, model_dir, token)
 
-
-        with open('{}/inspiremusic.yaml'.format(model_dir), 'r') as f:
+        with open(os.path.join(model_dir, 'inspiremusic.yaml'), 'r') as f:
             configs = load_hyperpyyaml(f)
 
         self.frontend = InspireMusicFrontEnd(configs,
@@ -47,15 +56,17 @@ class InspireMusic:
                                              '{}/music_tokenizer/'.format(model_dir),
                                              '{}/wavtokenizer/'.format(model_dir),
                                              instruct,
+                                             dtype,
                                              fast,
                                              fp16,
                                              configs['allowed_special'])
 
-        self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], fast, fp16)
-        self.model.load('
-                        '
-                        '
-                        '
+        self.model = InspireMusicModel(configs['llm'], configs['flow'], configs['hift'], configs['wavtokenizer'], dtype, fast, fp16)
+        self.model.load(os.path.join(model_dir, 'llm.pt'),
+                        os.path.join(model_dir, 'flow.pt'),
+                        os.path.join(model_dir, 'music_tokenizer'),
+                        os.path.join(model_dir, 'wavtokenizer', "model.pt"),
+                        )
         del configs
 
     @torch.inference_mode()
inspiremusic/cli/model.py
CHANGED
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+import sys
 import numpy as np
 import threading
 import time
@@ -21,23 +23,37 @@ from inspiremusic.wavtokenizer.decoder.pretrained import WavTokenizer
 from torch.cuda.amp import autocast
 import logging
 import torch
-import os
-
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 class InspireMusicModel:
-
     def __init__(self,
                  llm: torch.nn.Module,
                  flow: torch.nn.Module,
                  music_tokenizer: torch.nn.Module,
                  wavtokenizer: torch.nn.Module,
+                 dtype: str = "fp16",
                  fast: bool = False,
                  fp16: bool = True,
                  ):
-
-
+
+        if torch.cuda.is_available():
+            self.device = torch.device('cuda')
+        elif torch.backends.mps.is_available():
+            self.device = torch.device('mps')
+        elif torch.xpu.is_available():
+            self.device = torch.device('xpu')
+        else:
+            self.device = torch.device('cpu')
+
+        if dtype == "fp16":
+            self.dtype = torch.float16
+        elif dtype == "bf16":
+            self.dtype = torch.bfloat16
+        else:
+            self.dtype = torch.float32
+
+        self.llm = llm.to(self.dtype)
         self.flow = flow
         self.music_tokenizer = music_tokenizer
         self.wavtokenizer = wavtokenizer
@@ -66,7 +82,7 @@ class InspireMusicModel:
     def load(self, llm_model, flow_model, hift_model, wavtokenizer_model):
         if llm_model is not None:
             self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
-            self.llm.to(self.device).eval()
+            self.llm.to(self.device).to(self.dtype).eval()
         else:
             self.llm = None
         if flow_model is not None:
@@ -74,19 +90,15 @@ class InspireMusicModel:
             self.flow.to(self.device).eval()
         if hift_model is not None:
             if ".pt" not in hift_model:
-                self.music_tokenizer = VQVAE(
-                    hift_model + '/model.pt', with_encoder=True)
+                self.music_tokenizer = VQVAE(os.path.join(hift_model, 'config.json'), os.path.join(hift_model, 'model.pt'), with_encoder=True)
             else:
-                self.music_tokenizer = VQVAE(os.path.dirname(hift_model)
-                    hift_model, with_encoder=True)
+                self.music_tokenizer = VQVAE(os.path.join(os.path.dirname(hift_model), 'config.json'), hift_model, with_encoder=True)
             self.music_tokenizer.to(self.device).eval()
         if wavtokenizer_model is not None:
             if ".pt" not in wavtokenizer_model:
-                self.wavtokenizer = WavTokenizer.from_pretrained_feat(
-                    wavtokenizer_model + '/model.pt')
+                self.wavtokenizer = WavTokenizer.from_pretrained_feat(os.path.join(wavtokenizer_model, 'config.yaml'), os.path.join(wavtokenizer_model, 'model.pt'))
             else:
-                self.wavtokenizer = WavTokenizer.from_pretrained_feat(
-                    wavtokenizer_model )
+                self.wavtokenizer = WavTokenizer.from_pretrained_feat(os.path.join(os.path.dirname(wavtokenizer_model), 'config.yaml'), wavtokenizer_model)
             self.wavtokenizer.to(self.device)
 
     def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
@@ -110,7 +122,7 @@ class InspireMusicModel:
     def llm_job(self, text, audio_token, audio_token_len, prompt_text, llm_prompt_audio_token, embeddings, uuid, duration_to_gen, task):
         with self.llm_context:
             local_res = []
-            with autocast(enabled=self.fp16):
+            with autocast(enabled=self.fp16, dtype=self.dtype, cache_enabled=True):
                 inference_kwargs = {
                     'text': text.to(self.device),
                     'text_len': torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
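Aside (not part of the commit): `InspireMusicModel.__init__` now maps the `dtype` string onto a torch dtype and reuses it for autocast. A standalone sketch of that mapping; the helper name `resolve_dtype` is illustrative:

import torch
from torch.cuda.amp import autocast

def resolve_dtype(dtype: str) -> torch.dtype:
    # Same mapping as the commit: "fp16" -> float16, "bf16" -> bfloat16, anything else -> float32.
    return {"fp16": torch.float16, "bf16": torch.bfloat16}.get(dtype, torch.float32)

dtype = resolve_dtype("bf16")
with autocast(enabled=torch.cuda.is_available(), dtype=dtype, cache_enabled=True):
    x = torch.randn(2, 4)
    y = x @ x.T  # runs in reduced precision inside the autocast region on CUDA
print(dtype, y.dtype)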
inspiremusic/flow/flow.py
CHANGED
@@ -39,7 +39,7 @@ class MaskedDiff(torch.nn.Module):
                  'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
                  mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 128, 'sampling_rate': 48000,
                                         'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 48000},
-                 generator_model_dir: str = "pretrained_models/InspireMusic-Base/music_tokenizer",
+                 generator_model_dir: str = "../../pretrained_models/InspireMusic-Base/music_tokenizer",
                  num_codebooks: int = 4
                  ):
         super().__init__()
inspiremusic/llm/llm.py
CHANGED
@@ -50,9 +50,19 @@ class LLM(torch.nn.Module):
             length_normalized_loss: bool = True,
             lsm_weight: float = 0.0,
             frozen_input_embed: bool = False,
+            dtype: str = "fp16",
+            text_token_size: int = 151643,
             **kwargs,
     ):
         super().__init__()
+
+        if dtype == "fp16":
+            self.dtype = torch.float16
+        elif dtype == "bf16":
+            self.dtype = torch.bfloat16
+        else:
+            self.dtype = torch.float32
+
         self.llm_input_size = llm_input_size
         self.audio_token_size = audio_token_size
         # 1. build text token inputs related modules
@@ -115,34 +125,9 @@ class LLM(torch.nn.Module):
 
         encoder_name = encoder_conf.pop("name", "transformer")
         model = None
-
-
-
-                **encoder_conf,
-                input_size=self.input_size,
-                use_cnn_module=False,
-                macaron_style=False,
-            )
-        elif encoder_name == "conformer":
-            from inspiremusic.transformer.encoder.conformer_encoder import ConformerEncoder
-            model = ConformerEncoder(
-                **encoder_conf,
-                input_size=self.input_size,
-            )
-        elif encoder_name == "llama_encoder":
-            from inspiremusic.transformer.encoder.llama_encoder import LlamaEncoder
-            model = LlamaEncoder(
-                **encoder_conf,
-                input_size=self.input_size,
-            )
-        elif encoder_name == "qwen2":
-            from inspiremusic.transformer.encoder.qwen_encoder import QwenEncoder
-            model = QwenEncoder(
-                **encoder_conf,
-                input_size=self.input_size,
-            )
-        elif encoder_name == "qwen2.5":
-            from inspiremusic.transformer.encoder.qwen_encoder import QwenEncoder
+
+        if "qwen" in encoder_name:
+            from inspiremusic.transformer.qwen_encoder import QwenEncoder
             model = QwenEncoder(
                 **encoder_conf,
                 input_size=self.input_size,
@@ -237,8 +222,7 @@ class LLM(torch.nn.Module):
         time_end_embed = self.time_embedding(time_end).to(text_token.dtype)
         chorus_embed = self.chorus_embedding(chorus)
 
-        lm_target = [torch.tensor(
-            [IGNORE_ID] * (4 + text_token_len[i]) + audio_token[i,:audio_token_len[i]].tolist() + [self.audio_token_size]) for i in range(text_token.size(0))]
+        lm_target = [torch.tensor([IGNORE_ID] * (4 + text_token_len[i]) + audio_token[i,:audio_token_len[i]].tolist() + [self.audio_token_size]) for i in range(text_token.size(0))]
 
         lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
 
@@ -250,18 +234,9 @@ class LLM(torch.nn.Module):
         audio_token = self.speech_embedding(audio_token)
 
         # 5. unpad and pad
-        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb,
-                                                         [time_start_embed,
-                                                          time_end_embed,
-                                                          chorus_embed],
-                                                         text_token,
-                                                         text_token_len,
-                                                         task_id_emb,
-                                                         audio_token,
-                                                         audio_token_len,
-                                                         seg_len)
+        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, [time_start_embed, time_end_embed, chorus_embed], text_token, text_token_len, task_id_emb, audio_token, audio_token_len, seg_len)
         # 6. run lm forward
-        lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
+        lm_output, lm_output_mask = self.llm(lm_input.to(self.dtype), lm_input_len.to(device))
         logits = self.llm_decoder(lm_output)
         loss = self.criterion_ce(logits, lm_target)
 
@@ -290,7 +265,7 @@ class LLM(torch.nn.Module):
              prompt_audio_token: torch.Tensor,
              prompt_audio_token_len: torch.Tensor,
              embeddings: List,
-             duration_to_gen: float =
+             duration_to_gen: float = 30,
              task: str = "continuation",
             token_rate: int = 75,
             limit_audio_prompt_len: int = 5,
@@ -317,8 +292,7 @@ class LLM(torch.nn.Module):
             time_end_embed = self.time_embedding(time_end).reshape(1, 1, -1) # .half()
             chorus_embed = self.chorus_embedding(chorus).reshape(1, 1, -1) # .half()
         else:
-            time_start_embed = self.time_embedding(
-                time_start.view(-1)).reshape(1, chorus.size(1), -1) # .half()
+            time_start_embed = self.time_embedding(time_start.view(-1)).reshape(1, chorus.size(1), -1) # .half()
             time_end_embed = self.time_embedding(time_end.view(-1)).reshape(1, chorus.size(1), -1) # .half()
             chorus_embed = self.chorus_embedding(chorus) # .half()
 
@@ -332,10 +306,10 @@ class LLM(torch.nn.Module):
         else:
             audio_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
 
-        if prompt_audio_token_len:
-
-        else:
-
+        #if prompt_audio_token_len:
+        #    prompt_audio_token_emb = self.speech_embedding(prompt_audio_token)
+        #else:
+        #    prompt_audio_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
         # Check if removing prompt audio token will fail decoding.
 
         if task == "continuation":
@@ -344,31 +318,18 @@ class LLM(torch.nn.Module):
                 chorus_embed, text, task_id_emb, audio_token_emb], dim=1)
 
             if infer_cfg:
-                audio_cfg = self.speech_embedding(
-
-                lm_cf_input = torch.concat(
-                    [sos_eos_emb, torch.rand_like(time_start_embed),
-                     torch.rand_like(time_end_embed),
-                     torch.rand_like(chorus_embed), text_cfg, task_id_emb,
-                     audio_cfg], dim=1)
+                audio_cfg = self.speech_embedding(audio_token.new_zeros(audio_token.shape))
+                lm_cf_input = torch.concat([sos_eos_emb, torch.rand_like(time_start_embed), torch.rand_like(time_end_embed), torch.rand_like(chorus_embed), text_cfg, task_id_emb, audio_cfg], dim=1)
                 lm_input = torch.cat([lm_input, lm_cf_input], 0)
         else:
-            lm_input = torch.concat(
-                [sos_eos_emb, time_start_embed, time_end_embed,
-                 chorus_embed, text, task_id_emb], dim=1)
+            lm_input = torch.concat([sos_eos_emb, time_start_embed, time_end_embed, chorus_embed, text, task_id_emb], dim=1)
             if infer_cfg:
-                lm_cf_input = torch.concat(
-                    [sos_eos_emb, torch.rand_like(time_start_embed),
-                     torch.rand_like(time_end_embed),
-                     torch.rand_like(chorus_embed), text_cfg, task_id_emb],
-                    dim=1)
+                lm_cf_input = torch.concat([sos_eos_emb, torch.rand_like(time_start_embed), torch.rand_like(time_end_embed), torch.rand_like(chorus_embed), text_cfg, task_id_emb], dim=1)
                 lm_input = torch.cat([lm_input, lm_cf_input], 0)
 
         # 4. cal min/max_length
-        min_len = 0.9 * duration_to_gen * token_rate
+        min_len = int(0.9 * duration_to_gen * token_rate)
         max_len = duration_to_gen * token_rate
-        logging.info(
-            f"LLM generation sequence length: 115,784, generate audio length {duration_to_gen}s.")
 
         # 5. step by step decode
         out_tokens = []
@@ -376,7 +337,7 @@ class LLM(torch.nn.Module):
             state = None
 
             for i in range(int(max_len)):
-                y_pred, _, state = self.llm.forward_one_step(lm_input, torch.ones(lm_input.shape[0], lm_input.shape[1], device=lm_input.device).to(torch.bool), cache=state)
+                y_pred, _, state = self.llm.forward_one_step(lm_input.to(self.dtype), torch.ones(lm_input.shape[0], lm_input.shape[1], device=lm_input.device).to(torch.bool), cache=state)
                 logits = self.llm_decoder(y_pred[:, -1])
                 if infer_cfg:
                     # perform context free guidance
@@ -389,10 +350,7 @@ class LLM(torch.nn.Module):
                 logp = logp.squeeze(dim=0)
 
                 if i < int(min_len):
-                    logp[self.audio_token_size] = torch.tensor(float('-inf'), dtype=
-
-                if i < int(min_len):
-                    logp[self.audio_token_size] = torch.tensor(float('-inf'), dtype=torch.float16)
+                    logp[self.audio_token_size] = torch.tensor(float('-inf'), dtype=self.dtype)
 
                 top_ids = self.sampling_ids(logp, out_tokens, ignore_eos=i < min_len).item()
 
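Aside (not part of the commit): the decoding loop above derives its token budget from the requested duration. With `token_rate = 75` tokens per second, a 30-second request gives `max_len = 2250` and `min_len = int(0.9 * 2250) = 2025`, so the end-of-audio token is suppressed until at least 90% of the requested length has been generated. A small sketch of that arithmetic:

def token_budget(duration_to_gen: float, token_rate: int = 75):
    # min_len enforces at least 90% of the requested duration before EOS is allowed.
    min_len = int(0.9 * duration_to_gen * token_rate)
    max_len = int(duration_to_gen * token_rate)
    return min_len, max_len

print(token_budget(30.0))  # (2025, 2250)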
inspiremusic/transformer/qwen_encoder.py
CHANGED
@@ -22,6 +22,7 @@ class QwenEncoder(nn.Module):
     def __init__(
             self,
             input_size: int,
+            dtype: str = "fp16",
             pretrain_path: str = "Qwen/Qwen2.0-0.5B",
             trainable: bool = False,
             do_fusion_emb: bool = False,
@@ -30,7 +31,15 @@ class QwenEncoder(nn.Module):
         super(QwenEncoder, self).__init__()
         self.input_size = input_size
         self.trainable = trainable
-
+
+        if dtype == "fp16":
+            self.dtype = torch.float16
+        elif dtype == "bf16":
+            self.dtype = torch.bfloat16
+        else:
+            self.dtype = torch.float32
+
+        self.model = AutoModelForCausalLM.from_pretrained(pretrain_path, device_map="auto", attn_implementation="flash_attention_2", torch_dtype=self.dtype)
         self._output_size = self.model.config.hidden_size
         self.do_fusion_emb = do_fusion_emb
         self.hidden_norm = torch.nn.LayerNorm(self._output_size)
@@ -88,14 +97,19 @@ class QwenEmbeddingEncoder(nn.Module):
     def __init__(
             self,
             input_size: int,
+            dtype: str = "fp16",
             pretrain_path: str = "Qwen/Qwen2.0-0.5B",
     ):
         super(QwenEmbeddingEncoder, self).__init__()
         self.input_size = input_size
+        if dtype == "fp16":
+            self.dtype = torch.float16
+        elif dtype == "bf16":
+            self.dtype = torch.bfloat16
+        else:
+            self.dtype = torch.float32
         from transformers import Qwen2ForCausalLM
-
-        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path,
-                                                      device_map="cpu")
+        self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="auto", attn_implementation="flash_attention_2", torch_dtype=self.dtype)
         self._output_size = self.model.config.hidden_size
 
     def output_size(self) -> int:
@@ -137,14 +151,19 @@ class QwenInputOnlyEncoder(nn.Module):
     def __init__(
             self,
             input_size: int,
+            dtype: str = "fp16",
             pretrain_path: str = "Qwen/Qwen2.0-0.5B",
     ):
         super(QwenInputOnlyEncoder, self).__init__()
         self.input_size = input_size
+        if dtype == "fp16":
+            self.dtype = torch.float16
+        elif dtype == "bf16":
+            self.dtype = torch.bfloat16
+        else:
+            self.dtype = torch.float32
         from transformers import Qwen2ForCausalLM
-
-        model = Qwen2ForCausalLM.from_pretrained(pretrain_path,
-                                                 device_map="cpu")
+        model = Qwen2ForCausalLM.from_pretrained(pretrain_path, device_map="auto", attn_implementation="flash_attention_2", torch_dtype=self.dtype)
         self.embed = model.model.embed_tokens
         for p in self.embed.parameters():
             p.requires_grad = False
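Aside (not part of the commit): the encoders above now load the Qwen backbone directly in the requested precision. A minimal sketch of loading a Hugging Face causal LM with an explicit `torch_dtype`; `device_map="auto"` needs the `accelerate` package, and the `attn_implementation="flash_attention_2"` flag used in the diff additionally requires `flash-attn`, so it is omitted here:

import torch
from transformers import AutoModelForCausalLM

def load_backbone(pretrain_path: str = "Qwen/Qwen2.0-0.5B", dtype: str = "bf16"):
    torch_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(dtype, torch.float32)
    # Weights are placed automatically across the available devices.
    return AutoModelForCausalLM.from_pretrained(pretrain_path, device_map="auto", torch_dtype=torch_dtype)

# model = load_backbone()
# print(model.config.hidden_size)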
inspiremusic/utils/common.py
CHANGED
@@ -16,12 +16,9 @@
 """Unility functions for Transformer."""
 
 from typing import List
-
 import torch
 IGNORE_ID = -1
 
-MUSIC_STRUCTURE_LABELS = ["intro", "verse1", "chorus", "verse2", "outro"]
-
 def pad_list(xs: List[torch.Tensor], pad_value: int):
     """Perform padding for the list of tensors.
 
@@ -92,16 +89,61 @@ def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
     denominator = torch.sum(mask)
     return (numerator / denominator).detach()
 
-
 def get_padding(kernel_size, dilation=1):
     return int((kernel_size * dilation - dilation) / 2)
 
-
 def init_weights(m, mean=0.0, std=0.01):
     classname = m.__class__.__name__
     if classname.find("Conv") != -1:
         m.weight.data.normal_(mean, std)
 
+def keep_rhythm(next_token, current_time_signature):
+    allowed_durations = get_allowed_durations(current_time_signature)
+    if next_token not in allowed_durations:
+        next_token = random.choice(allowed_durations)
+    return next_token
+
+def keep_harmony(next_token, current_chord):
+    allowed_notes = get_allowed_notes(current_chord)  # Define allowed notes for the chord
+    if next_token not in allowed_notes:
+        next_token = random.choice(allowed_notes)  # Replace with a valid note
+    return next_token
+
+def relieve_repetition(weighted_scores, recent_tokens, repetition_penalty=1.2):
+    for token in recent_tokens:
+        if weighted_scores[token] > 0:
+            weighted_scores[token] /= repetition_penalty
+    return weighted_scores
+
+def top_p_sampling_with_constraints(weighted_scores, decoded_tokens, top_p=0.85, temperature=1.1, current_chord=None, current_time_signature=None, recent_tokens=None):
+    # Apply temperature scaling
+    weighted_scores = weighted_scores ** (1 / temperature)
+    weighted_scores /= weighted_scores.sum()
+
+    if recent_tokens:
+        weighted_scores = relieve_repetition(weighted_scores, recent_tokens)
+
+    # Sort weighted scores in descending order
+    sorted_weighted_scores, _ = torch.sort(weighted_scores, descending=True)
+
+    # Compute cumulative weighted scores
+    cumulative_weighted_scores = torch.cumsum(sorted_weighted_scores, dim=0)
+
+    # Find the threthold index of top-p
+    cutoff_index = torch.where(cumulative_weighted_scores >= top_p)[0][0]
+    selected_weighted_scores = sorted_weighted_scores[:cutoff_index + 1]
+
+    # Apply domain-specific constraints
+    if current_chord:
+        selected_weighted_scores = keep_harmony(selected_weighted_scores, current_chord)
+    if current_time_signature:
+        selected_weighted_scores = keep_rhythm(selected_weighted_scores, current_time_signature)
+
+    # Normalize selected probabilities
+    selected_weighted_scores /= selected_weighted_scores.sum()
+
+    # Sample top-p tokens from the distribution
+    return random_sampling(selected_weighted_scores, decoded_tokens)
 def topk_sampling(weighted_scores, decoded_tokens, top_k=25):
     zeros = weighted_scores.new_ones(weighted_scores.shape) * float('-inf')
     values,indices = torch.topk(weighted_scores,top_k)
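Aside (not part of the commit): `top_p_sampling_with_constraints` layers chord/rhythm constraints and a repetition penalty on top of standard nucleus sampling; the helpers it calls (`get_allowed_durations`, `get_allowed_notes`, `random_sampling`) are defined elsewhere in the codebase. For reference, a self-contained sketch of plain top-p sampling over a non-negative score vector:

import torch

def top_p_sample(weighted_scores: torch.Tensor, top_p: float = 0.85) -> int:
    # Normalize to a distribution, keep the smallest prefix whose cumulative
    # mass reaches top_p, renormalize, and sample one token id from it.
    probs = weighted_scores / weighted_scores.sum()
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=0)
    cutoff = int(torch.where(cumulative >= top_p)[0][0].item())
    kept = sorted_probs[:cutoff + 1]
    choice = torch.multinomial(kept / kept.sum(), num_samples=1).item()
    return int(sorted_idx[choice].item())

scores = torch.tensor([0.1, 2.0, 0.5, 1.4])
print(top_p_sample(scores))  # usually index 1 or 3, the two highest-scoring tokens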
inspiremusic/utils/executor.py
CHANGED
@@ -24,13 +24,19 @@ from inspiremusic.utils.train_utils import update_parameter_and_lr, log_per_step
 from torch.cuda.amp import GradScaler, autocast
 
 class Executor:
-
     def __init__(self):
         self.step = 0
         self.epoch = 0
         self.rank = int(os.environ.get('RANK', 0))
-
-
+        if torch.cuda.is_available():
+            if torch.cuda.is_available():
+                self.device = torch.device('cuda:{}'.format(self.rank))
+            elif torch.backends.mps.is_available():
+                self.device = torch.device('mps')
+            elif torch.xpu.is_available():
+                self.device = torch.device('xpu')
+            else:
+                self.device = torch.device('cpu')
     def train_one_epoch(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join, scaler=None):
         ''' Train one epoch
         '''
inspiremusic/utils/utils.py
CHANGED
@@ -1,5 +1,27 @@
 import os
 import sys
+import subprocess
+
+def download_model(repo_url: str, output_dir: str = None, token: str = None):
+    try:
+        if token:
+            repo_url = repo_url.replace("https://", f"https://USER:{token}@")
+        else:
+            repo_url = f"https://www.modelscope.cn/models/iic/{repo_url}"
+
+        cmd = ["git", "clone", repo_url]
+        if output_dir:
+            cmd.append(output_dir)
+
+        result = subprocess.run(
+            cmd,
+            check=True,
+            capture_output=True,
+            text=True
+        )
+        print("Success:", result.stdout)
+    except subprocess.CalledProcessError as e:
+        print("Error:", e.stderr)
 
 def align_trans_scp_file(trans, scp):
     trans_dict = {}
@@ -14,9 +36,4 @@ def align_trans_scp_file(trans, scp):
             scp_dict[sec[0]] = sec[1]
     with open("text", "w") as f:
         for k, v in scp_dict.items():
-            f.write("%s\t%s\n"%(k,trans_dict[k]))
-
-if __name__ == '__main__':
-    trans = sys.argv[1]
-    scp = sys.argv[2]
-    align_trans_scp_file(trans, scp)
+            f.write("%s\t%s\n"%(k,trans_dict[k]))
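Aside (not part of the commit): the new `download_model` helper simply shells out to `git clone`. A usage sketch, assuming `git` (and `git-lfs` for the weights) is installed and the name matches a ModelScope model id:

from inspiremusic.utils.utils import download_model

# Without a token the helper expands the name into a ModelScope URL
# (https://www.modelscope.cn/models/iic/InspireMusic-1.5B-Long) and clones it.
download_model("InspireMusic-1.5B-Long", output_dir="pretrained_models/InspireMusic-1.5B-Long")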
inspiremusic/wavtokenizer/.DS_Store
DELETED
Binary file (6.15 kB)