# SongGeneration / levo_inference.py
# hainazhu
# Add infer code
# 2644f3e
# raw
# history blame
# 3.71 kB
import os
import sys
sys.path.append('./codeclm/tokenizer')
sys.path.append('./codeclm/tokenizer/Flow1dVAE')
sys.path.append('.')
import torch
import json
from omegaconf import OmegaConf
from codeclm.trainer.codec_song_pl import CodecLM_PL
from codeclm.models import CodecLM
from separator import Separator
class LeVoInference(torch.nn.Module):
    """Inference wrapper around a ``CodecLM`` song-generation model.

    Construction loads the OmegaConf config at ``cfg_path``, builds a
    ``CodecLM_PL`` module in eval mode on CUDA, and wraps its language model
    and tokenizers in a ``CodecLM`` instance. Calling the module generates
    audio for one lyric/description pair, optionally conditioned on a prompt
    audio file that is first passed through ``Separator``.
    """

    # Token sequences longer than this are cut along the last axis before
    # being decoded to audio.
    max_token_len = 3000

    def __init__(self, cfg_path):
        """Load the config and build the model, tokenizers and separator.

        Args:
            cfg_path: Path of the OmegaConf config file to load. The config
                must provide ``max_dur``.
        """
        super().__init__()

        # NOTE: this is a process-wide switch, not local to this instance.
        torch.backends.cudnn.enabled = False
        self._register_resolvers()

        self.cfg = OmegaConf.load(cfg_path)
        self.max_duration = self.cfg.max_dur

        # Define model or load pretrained model
        model_light = CodecLM_PL(self.cfg)
        model_light = model_light.eval().cuda()
        model_light.audiolm.cfg = self.cfg
        self.model_lm = model_light.audiolm
        self.model_audio_tokenizer = model_light.audio_tokenizer
        self.model_seperate_tokenizer = model_light.seperate_tokenizer
        self.model = CodecLM(
            name="tmp",
            lm=self.model_lm,
            audiotokenizer=self.model_audio_tokenizer,
            max_duration=self.max_duration,
            seperate_tokenizer=self.model_seperate_tokenizer,
        )
        self.separator = Separator()

        # Baseline generation parameters; per-call overrides are merged on
        # top of these in forward().
        self.default_params = dict(
            cfg_coef=1.5,
            temperature=1.0,
            top_k=50,
            top_p=0.0,
            record_tokens=True,
            record_window=50,
            extend_stride=5,
            duration=self.max_duration,
        )
        self.model.set_generation_params(**self.default_params)

    @staticmethod
    def _register_resolvers():
        """Register the custom OmegaConf resolvers once per process."""
        resolvers = {
            # SECURITY: this resolver runs eval() on text taken from the
            # config file; only load configs from trusted sources.
            "eval": lambda x: eval(x),
            "concat": lambda *x: [xxx for xx in x for xxx in xx],
            "load_yaml": lambda x: list(OmegaConf.load(x)),
        }
        for name, resolver in resolvers.items():
            # register_new_resolver() raises when the name is already
            # registered, which would make a second instance in the same
            # process fail; skip names that are already present.
            if not OmegaConf.has_resolver(name):
                OmegaConf.register_new_resolver(name, resolver)

    def forward(self, lyric: str, description: str, prompt_audio_path: os.PathLike = None, params=None):
        """Generate audio for a single lyric/description pair.

        Args:
            lyric: Lyric text passed to the model.
            description: Text description passed to the model.
            prompt_audio_path: Optional path to a prompt audio file. When
                given, it is run through ``self.separator`` and the three
                resulting waveforms are used as conditioning; when ``None``
                no audio conditioning is passed.
            params: Optional mapping of generation parameters that override
                ``self.default_params`` for this call. ``None`` or an empty
                mapping means "use the defaults".

        Returns:
            The first element of the output of ``self.model.generate_audio``.
        """
        # Merge into a fresh dict so neither the defaults nor the caller's
        # mapping is mutated.
        merged_params = {**self.default_params, **(params or {})}
        self.model.set_generation_params(**merged_params)

        if prompt_audio_path is None:
            pmt_wav = vocal_wav = bgm_wav = None
        else:
            pmt_wav, vocal_wav, bgm_wav = self.separator.run(prompt_audio_path)

        generate_inp = {
            # NOTE(review): as written this replace() is a no-op; the
            # original literal presumably contained different whitespace
            # that was lost -- confirm against the upstream file.
            'lyrics': [lyric.replace(" ", " ")],
            'descriptions': [description],
            'melody_wavs': pmt_wav,
            'vocal_wavs': vocal_wav,
            'bgm_wavs': bgm_wav,
        }

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            tokens = self.model.generate(**generate_inp, return_tokens=True)

        # Cap the sequence length along the last axis before decoding.
        if tokens.shape[-1] > self.max_token_len:
            tokens = tokens[..., :self.max_token_len]

        with torch.no_grad():
            wav_seperate = self.model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav)

        return wav_seperate[0]
def build_levo_inference(cfg_path='./conf/infer.yaml'):
    """Construct a ``LeVoInference`` from a config file.

    Args:
        cfg_path: Path of the config file handed to ``LeVoInference``.
            Defaults to ``'./conf/infer.yaml'``, so existing zero-argument
            callers behave as before.

    Returns:
        The newly constructed ``LeVoInference`` instance.
    """
    return LeVoInference(cfg_path)
if __name__ == '__main__':
    # Batch command-line entry point:
    #   python levo_inference.py <cfg_path> <save_dir> <input_jsonl>
    # Each non-blank line of <input_jsonl> is a JSON object with the keys
    # "idx", "descriptions", "gt_lyric" and, optionally, "prompt_audio_path".
    # One FLAC file per entry is written to <save_dir>/audios/<idx>.flac.
    import json
    import os
    import sys

    import torchaudio

    # Fail with a usage message rather than a bare IndexError.
    if len(sys.argv) < 4:
        sys.exit(f"usage: python {sys.argv[0]} <cfg_path> <save_dir> <input_jsonl>")
    cfg_path, save_dir, input_jsonl = sys.argv[1:4]

    model = LeVoInference(cfg_path)

    audio_dir = os.path.join(save_dir, "audios")
    os.makedirs(audio_dir, exist_ok=True)

    # Explicit UTF-8 so non-ASCII lyrics do not depend on the platform's
    # default encoding; blank lines are dropped so json.loads() never sees
    # an empty string.
    with open(input_jsonl, "r", encoding="utf-8") as fp:
        lines = [line for line in fp if line.strip()]

    for line in lines:
        item = json.loads(line)
        target_wav_name = os.path.join(audio_dir, f"{item['idx']}.flac")
        descriptions = item["descriptions"]
        lyric = item["gt_lyric"]
        # The prompt is optional: a missing key yields None, which the model
        # treats as "no prompt audio".
        prompt_audio_path = item.get('prompt_audio_path')
        wav = model(lyric, descriptions, prompt_audio_path)
        torchaudio.save(target_wav_name, wav.cpu().float(), model.cfg.sample_rate)