|
import torch |
|
import torchaudio |
|
from einops import rearrange |
|
import argparse |
|
import json |
|
import os |
|
from tqdm import tqdm |
|
import random |
|
import numpy as np |
|
import time |
|
|
|
from diffrhythm.infer.infer_utils import ( |
|
get_reference_latent, |
|
get_lrc_token, |
|
get_style_prompt, |
|
prepare_model, |
|
get_negative_style_prompt |
|
) |
|
|
|
def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
    """Decode latents to audio, either in one pass or in overlapping chunks.

    Args:
        latents: Tensor indexed as ``[batch, channel, frame]``; the batch size
            is read from ``shape[0]`` and the frame count from ``shape[2]``.
        vae_model: Object exposing ``decode_export(latents)``.
        chunked: When false, the whole tensor is decoded in a single call and
            ``overlap`` / ``chunk_size`` are ignored.
        overlap: Number of latent frames shared by consecutive chunks. Half of
            it is trimmed from each interior chunk edge before stitching.
        chunk_size: Number of latent frames decoded per call in chunked mode.

    Returns:
        The decoder output (non-chunked), or a float32 tensor of shape
        ``(batch, 2, frames * 2048)`` stitched from the decoded chunks.

    Raises:
        ValueError: In chunked mode, if ``chunk_size`` is not positive or
            ``overlap`` is not in ``[0, chunk_size)``.
    """
    # Output samples per latent frame used to size the stitched buffer.
    # NOTE(review): assumes decode_export() emits exactly this many samples per
    # frame -- confirm against the VAE in use.
    downsampling_ratio = 2048
    # Channel count of the stitched output buffer (presumably stereo).
    io_channels = 2

    if not chunked:
        return vae_model.decode_export(latents)

    # A non-positive hop would either make range() fail with an opaque message
    # (hop == 0) or silently skip every chunk but the tail (hop < 0).
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError(
            f"require chunk_size > 0 and 0 <= overlap < chunk_size, "
            f"got chunk_size={chunk_size}, overlap={overlap}"
        )

    hop_size = chunk_size - overlap
    batch_size = latents.shape[0]
    total_size = latents.shape[2]

    # Regularly spaced chunk starts; slices are views, so nothing is copied.
    starts = list(range(0, total_size - chunk_size + 1, hop_size))
    chunks = [latents[:, :, s:s + chunk_size] for s in starts]
    # If the regular grid does not end exactly at the last frame (or the input
    # is shorter than one chunk), add a final chunk anchored at the end.
    if not starts or starts[-1] + chunk_size != total_size:
        chunks.append(latents[:, :, -chunk_size:])
    num_chunks = len(chunks)

    samples_per_latent = downsampling_ratio
    y_size = total_size * samples_per_latent
    y_final = torch.zeros(
        (batch_size, io_channels, y_size), device=latents.device
    )

    # Samples trimmed from each interior chunk edge (half the overlap).
    ol = (overlap // 2) * samples_per_latent

    for idx, x_chunk in enumerate(chunks):
        y_chunk = vae_model.decode_export(x_chunk)
        is_first = idx == 0
        is_last = idx == num_chunks - 1

        if is_last:
            # The last chunk is right-aligned with the end of the output.
            t_end = y_size
            t_start = t_end - y_chunk.shape[2]
        else:
            t_start = idx * hop_size * samples_per_latent
            t_end = t_start + chunk_size * samples_per_latent

        chunk_start = 0
        chunk_end = y_chunk.shape[2]
        if not is_first:
            # Drop the leading half-overlap; the previous chunk covers it.
            t_start += ol
            chunk_start += ol
        if not is_last:
            # Drop the trailing half-overlap; the next chunk covers it.
            t_end -= ol
            chunk_end -= ol

        y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]

    return y_final
|
|
|
def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, sway_sampling_coef, start_time):
    """Sample latents, decode them, and return peak-normalised int16 audio.

    Args:
        cfm_model: Object exposing ``sample(...)``; the first element of its
            return value is used as the generated latents.
        vae_model: Decoder handed to ``decode_audio``.
        cond, text, duration, style_prompt, negative_style_prompt, steps,
        sway_sampling_coef, start_time: Forwarded unchanged to
            ``cfm_model.sample``; ``cfg_strength`` is fixed at 4.0.

    Returns:
        A CPU ``torch.int16`` tensor laid out as ``(d, b * n)``: the batch
        items of the decoded ``(b, d, n)`` audio are concatenated along the
        last axis.
    """
    with torch.inference_mode():
        generated, _ = cfm_model.sample(
            cond=cond,
            text=text,
            duration=duration,
            style_prompt=style_prompt,
            negative_style_prompt=negative_style_prompt,
            steps=steps,
            cfg_strength=4.0,
            sway_sampling_coef=sway_sampling_coef,
            start_time=start_time,
        )

        # Swap the last two axes before decoding.
        # NOTE(review): presumably (b, t, d) -> (b, d, t) -- confirm.
        latent = generated.to(torch.float32).transpose(1, 2)

        output = decode_audio(latent, vae_model)

        # Lay the batch items end to end as one long sequence per channel.
        output = rearrange(output, "b d n -> d (b n)")

        # Peak-normalise, clip, and convert to 16-bit PCM. The peak is floored
        # at a tiny positive value so an all-silent decode yields zeros rather
        # than NaNs from a division by zero.
        output = output.to(torch.float32)
        peak = output.abs().max().clamp(min=1e-8)
        output = output.div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    return output
|
|
|
if __name__ == "__main__":
    # Command-line entry point: generate one song from a lyrics file and a
    # reference audio clip, then write it to <output-dir>/output.wav.

    # Supported song lengths (seconds) and the frame count used for each.
    frames_by_length = {95: 2048, 285: 6144}

    parser = argparse.ArgumentParser()
    parser.add_argument('--lrc-path', type=str, default="example/eg.lrc")
    parser.add_argument('--ref-audio-path', type=str, default="example/eg.mp3")
    # Restrict to the supported lengths so an unsupported value is rejected by
    # argparse instead of leaving max_frames undefined further down.
    parser.add_argument('--audio-length', type=int, default=95,
                        choices=sorted(frames_by_length))
    parser.add_argument('--output-dir', type=str, default="example/output")
    args = parser.parse_args()

    device = 'cuda'

    audio_length = args.audio_length
    max_frames = frames_by_length[audio_length]

    cfm, tokenizer, muq, vae = prepare_model(device)

    # Read lyrics as UTF-8 explicitly rather than relying on the platform's
    # locale encoding.
    with open(args.lrc_path, 'r', encoding="utf-8") as f:
        lrc = f.read()
    lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)

    style_prompt = get_style_prompt(muq, args.ref_audio_path)

    negative_style_prompt = get_negative_style_prompt(device)

    latent_prompt = get_reference_latent(device, max_frames)

    s_t = time.time()
    generated_song = inference(cfm_model=cfm,
                               vae_model=vae,
                               cond=latent_prompt,
                               text=lrc_prompt,
                               duration=max_frames,
                               style_prompt=style_prompt,
                               negative_style_prompt=negative_style_prompt,
                               # inference() has no defaults for these two, so
                               # they must be passed explicitly.
                               # NOTE(review): 32 steps / no sway sampling are
                               # assumed defaults -- confirm intended values.
                               steps=32,
                               sway_sampling_coef=None,
                               start_time=start_time
                               )
    e_t = time.time() - s_t
    print(f"inference cost {e_t} seconds")

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    output_path = os.path.join(output_dir, "output.wav")
    torchaudio.save(output_path, generated_song, sample_rate=44100)
|
|