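# Inference script for DiffRhythm: generate a full-length song from lyrics
# plus a text or audio style prompt.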
import argparse
import io
import os
import random
import time

import numpy as np
import pydub
import torch
import torchaudio
from einops import rearrange
from diffrhythm.infer.infer_utils import (
    decode_audio,
    get_lrc_token,
    get_negative_style_prompt,
    get_reference_latent,
    get_style_prompt,
    prepare_model,
    eval_song,
)


def inference(
    cfm_model,
    vae_model,
    eval_model,
    eval_muq,
    cond,
    text,
    duration,
    style_prompt,
    negative_style_prompt,
    steps,
    cfg_strength,
    sway_sampling_coef,
    start_time,
    file_type,
    vocal_flag,
    odeint_method,
    pred_frames,
    batch_infer_num,
    chunked=True,
):
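    """Sample song latents with the CFM model, decode them with the VAE, and
    return the result in the requested file format.

    With batch_infer_num > 1, several candidates are generated and eval_song
    selects the best one. For file_type "wav" a (sample_rate, ndarray) tuple
    is returned; for "mp3"/"ogg" the encoded bytes are returned.
    """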
    with torch.inference_mode():
        latents, _ = cfm_model.sample(
            cond=cond,
            text=text,
            duration=duration,
            style_prompt=style_prompt,
            negative_style_prompt=negative_style_prompt,
            steps=steps,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            start_time=start_time,
            vocal_flag=vocal_flag,
            odeint_method=odeint_method,
            latent_pred_segments=pred_frames,
            batch_infer_num=batch_infer_num,
        )
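
        # Decode each candidate latent back to a waveform with the VAE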
        outputs = []
        for latent in latents:
            latent = latent.to(torch.float32)
            latent = latent.transpose(1, 2)  # [b d t]
            output = decode_audio(latent, vae_model, chunked=chunked)
            # Rearrange audio batch to a single sequence
            output = rearrange(output, "b d n -> d (b n)")
            outputs.append(output)
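
        # With multiple candidates, let the evaluation model pick the best song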
        if batch_infer_num > 1:
            generated_song = eval_song(eval_model, eval_muq, outputs)
        else:
            generated_song = outputs[0]

        # Peak-normalize the selected song by its own maximum and clip to [-1, 1]
        output_tensor = (
            generated_song.to(torch.float32)
            .div(generated_song.abs().max())
            .clamp(-1, 1)
            .cpu()
        )
        output_np = output_tensor.numpy().T.astype(np.float32)

        if file_type == "wav":
            return (44100, output_np)
        else:
            buffer = io.BytesIO()
            # Scale to the int16 range; 32767 avoids overflow at exactly +1.0
            output_np = np.int16(output_np * 32767)
            song = pydub.AudioSegment(
                output_np.tobytes(), frame_rate=44100, sample_width=2, channels=2
            )
            if file_type == "mp3":
                song.export(buffer, format="mp3", bitrate="320k")
            else:
                song.export(buffer, format="ogg", bitrate="320k")
            return buffer.getvalue()
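

# Example usage (file name, paths, and prompt are illustrative):
#   python infer.py --lrc-path infer/example/eg.lrc \
#       --ref-prompt "pop, piano, upbeat" --audio-length 95 --chunked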
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lrc-path",
        type=str,
        help="path to the lyrics (.lrc) of the target song",
    )
    parser.add_argument(
        "--ref-prompt",
        type=str,
        required=False,
        help="reference text prompt used as the style prompt for the target song",
    )
    parser.add_argument(
        "--ref-audio-path",
        type=str,
        required=False,
        help="reference audio used as the style prompt for the target song",
    )
    parser.add_argument(
        "--chunked",
        action="store_true",
        help="whether to use chunked VAE decoding",
    )
    parser.add_argument(
        "--audio-length",
        type=int,
        default=95,
        choices=[95, 285],
        help="length of the generated song in seconds",
    )
    parser.add_argument(
        "--repo-id", type=str, default="ASLP-lab/DiffRhythm-base", help="target model"
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="infer/example/output",
        help="output directory for the generated song",
    )
    parser.add_argument(
        "--edit",
        action="store_true",
        help="whether to enable edit mode",
    )
    parser.add_argument(
        "--ref-song",
        type=str,
        required=False,
        help="reference song used as the latent prompt for editing",
    )
    parser.add_argument(
        "--edit-segments",
        type=str,
        required=False,
        help="edit segments of the target song",
    )

    args = parser.parse_args()

    assert (
        args.ref_prompt or args.ref_audio_path
    ), "either ref_prompt or ref_audio_path should be provided"
    assert not (
        args.ref_prompt and args.ref_audio_path
    ), "only one of ref_prompt and ref_audio_path should be provided"
    if args.edit:
        assert (
            args.ref_song and args.edit_segments
        ), "ref_song and edit_segments should be provided for editing"
device = "cpu" | |
if torch.cuda.is_available(): | |
device = "cuda" | |
elif torch.mps.is_available(): | |
device = "mps" | |
    audio_length = args.audio_length
    if audio_length == 95:
        max_frames = 2048
    elif audio_length == 285:
        max_frames = 6144
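
    # Load the CFM backbone, lyric tokenizer, MuQ style encoder, VAE, and the
    # candidate-evaluation models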
    cfm, tokenizer, muq, vae, eval_model, eval_muq = prepare_model(
        max_frames, device, repo_id=args.repo_id
    )

    if args.lrc_path:
        with open(args.lrc_path, "r", encoding="utf-8") as f:
            lrc = f.read()
    else:
        lrc = ""
    lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)
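
    # Build the style prompt from reference audio if given, otherwise from text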
    if args.ref_audio_path:
        style_prompt = get_style_prompt(muq, args.ref_audio_path)
    else:
        style_prompt = get_style_prompt(muq, prompt=args.ref_prompt)
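
    # Negative style prompt used for classifier-free guidance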
    negative_style_prompt = get_negative_style_prompt(device)
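
    # Latent prompt: encodes the reference song for editing, otherwise a blank
    # latent covering the full song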
    latent_prompt, pred_frames = get_reference_latent(
        device, max_frames, args.edit, args.edit_segments, args.ref_song, vae
    )

    # The sampling hyperparameters below (steps, cfg_strength, sway coefficient,
    # ODE solver, vocal flag, candidate count) are assumed defaults; adjust as
    # needed
    s_t = time.time()
    sample_rate, generated_song = inference(
        cfm_model=cfm,
        vae_model=vae,
        eval_model=eval_model,
        eval_muq=eval_muq,
        cond=latent_prompt,
        text=lrc_prompt,
        duration=max_frames,
        style_prompt=style_prompt,
        negative_style_prompt=negative_style_prompt,
        steps=32,
        cfg_strength=4.0,
        sway_sampling_coef=None,
        start_time=start_time,
        file_type="wav",
        vocal_flag=False,
        odeint_method="euler",
        pred_frames=pred_frames,
        batch_infer_num=5,
        chunked=args.chunked,
    )
    # inference() already picks the best candidate via eval_song and
    # peak-normalizes it; convert the float32 (samples, channels) array into
    # an int16 (channels, samples) tensor for saving
    generated_song = torch.from_numpy((generated_song * 32767).astype(np.int16).T)

    e_t = time.time() - s_t
    print(f"inference cost {e_t:.2f} seconds")

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "output.wav")
    torchaudio.save(output_path, generated_song, sample_rate=sample_rate)