import json
import os

import librosa
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from muq import MuQMuLan
from mutagen.mp3 import MP3

from diffrhythm.model import DiT, CFM


def prepare_model(device):
    # CFM backbone: download the DiT checkpoint and build the flow-matching model.
    dit_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-base", filename="cfm_model.pt")
    dit_config_path = "./diffrhythm/config/diffrhythm-1b.json"
    with open(dit_config_path) as f:
        model_config = json.load(f)
    dit_model_cls = DiT
    cfm = CFM(
        transformer=dit_model_cls(**model_config["model"], use_style_prompt=True),
        num_channels=model_config["model"]["mel_dim"],
        use_style_prompt=True,
    )
    cfm = cfm.to(device)
    cfm = load_checkpoint(cfm, dit_ckpt_path, device=device, use_ema=False)

    # Phoneme tokenizer for the lyrics.
    tokenizer = CNENTokenizer()

    # MuQ-MuLan encodes the reference audio into a style embedding.
    muq = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large")
    muq = muq.to(device).eval()

    # TorchScript VAE used to decode latents back to audio.
    vae_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-vae", filename="vae_model.pt")
    vae = torch.jit.load(vae_ckpt_path).to(device)

    return cfm, tokenizer, muq, vae


def get_reference_latent(device, max_frames):
    # All-zero latent used as the reference prompt; 64 is the latent channel dim.
    return torch.zeros(1, max_frames, 64).to(device)


def get_negative_style_prompt(device):
    # Precomputed embedding used as the negative style prompt.
    file_path = "./prompt/negative_prompt.npy"
    vocal_style = np.load(file_path)

    vocal_style = torch.from_numpy(vocal_style).to(device)
    vocal_style = vocal_style.half()

    return vocal_style


def get_style_prompt(model, wav_path):
    mulan = model

    # Read duration and native sample rate without decoding the full file when possible.
    ext = os.path.splitext(wav_path)[-1].lower()
    if ext == ".mp3":
        meta = MP3(wav_path)
        audio_len = meta.info.length
        src_sr = meta.info.sample_rate
    elif ext == ".wav":
        audio, sr = librosa.load(wav_path, sr=None)
        audio_len = librosa.get_duration(y=audio, sr=sr)
        src_sr = sr
    else:
        raise ValueError("Unsupported file format: {}".format(ext))

    assert audio_len >= 10, "Reference audio must be at least 10 seconds long"

    # Take a 10-second window centered on the middle of the track.
    mid_time = audio_len // 2
    start_time = mid_time - 5
    wav, sr = librosa.load(wav_path, sr=None, offset=start_time, duration=10)

    # MuQ-MuLan expects 24 kHz input.
    resampled_wav = librosa.resample(wav, orig_sr=src_sr, target_sr=24000)
    resampled_wav = torch.tensor(resampled_wav).unsqueeze(0).to(model.device)

    with torch.no_grad():
        audio_emb = mulan(wavs=resampled_wav)

    audio_emb = audio_emb.half()

    return audio_emb
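

# Usage sketch (any .wav/.mp3 of at least 10 seconds works; the embedding width D
# depends on the MuQ-MuLan checkpoint):
#
#   style_emb = get_style_prompt(muq, "reference.wav")  # fp16 tensor of shape (1, D)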


def parse_lyrics(lyrics: str):
    # Parse LRC-formatted lyrics: each line looks like "[mm:ss.xx]lyric text".
    lyrics_with_time = []
    lyrics = lyrics.strip()
    for line in lyrics.split("\n"):
        try:
            # Fixed-width slice: chars 1-8 hold "mm:ss.xx", char 9 is the closing bracket.
            time, lyric = line[1:9], line[10:]
            lyric = lyric.strip()
            mins, secs = time.split(":")
            secs = int(mins) * 60 + float(secs)
            lyrics_with_time.append((secs, lyric))
        except ValueError:
            # Skip malformed or non-timestamped lines.
            continue
    return lyrics_with_time
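

# Worked example of the fixed-width parse above (illustrative input):
#
#   parse_lyrics("[00:10.00]Hello world\n[00:15.50]Second line")
#   -> [(10.0, 'Hello world'), (15.5, 'Second line')]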


class CNENTokenizer:
    def __init__(self):
        with open("./diffrhythm/g2p/g2p/vocab.json", "r") as file:
            self.phone2id: dict = json.load(file)["vocab"]
        self.id2phone = {v: k for (k, v) in self.phone2id.items()}

        # Deferred import to avoid paying the g2p startup cost at module load time.
        from diffrhythm.g2p.g2p_generation import chn_eng_g2p

        self.tokenizer = chn_eng_g2p

    def encode(self, text):
        phone, token = self.tokenizer(text)
        # Shift ids by +1 so that 0 stays reserved as the padding token.
        token = [x + 1 for x in token]
        return token

    def decode(self, token):
        return "|".join([self.id2phone[x - 1] for x in token])


def get_lrc_token(text, tokenizer, device):
    max_frames = 2048
    lyrics_shift = 0
    sampling_rate = 44100
    downsample_rate = 2048
    # Each latent frame covers downsample_rate / sampling_rate seconds of audio.
    max_secs = max_frames / (sampling_rate / downsample_rate)

    # Token id scheme: 0 = pad, 1 = comma, 2 = period.
    pad_token_id = 0
    comma_token_id = 1
    period_token_id = 2

    lrc_with_time = parse_lyrics(text)

    # Tokenize each lyric line, keeping its start time.
    modified_lrc_with_time = []
    for i in range(len(lrc_with_time)):
        time, line = lrc_with_time[i]
        line_token = tokenizer.encode(line)
        modified_lrc_with_time.append((time, line_token))
    lrc_with_time = modified_lrc_with_time

    # Drop lines starting beyond the maximum length, then drop the final line.
    lrc_with_time = [(time_start, line) for (time_start, line) in lrc_with_time if time_start < max_secs]
    lrc_with_time = lrc_with_time[:-1] if len(lrc_with_time) >= 1 else lrc_with_time

    normalized_start_time = 0.0

    lrc = torch.zeros((max_frames,), dtype=torch.long)

    last_end_pos = 0
    for time_start, line in lrc_with_time:
        # Replace in-line periods with commas, then terminate the line with a period.
        tokens = [token if token != period_token_id else comma_token_id for token in line] + [period_token_id]
        tokens = torch.tensor(tokens, dtype=torch.long)
        num_tokens = tokens.shape[0]

        # Map the line's start time to a latent-frame index.
        gt_frame_start = int(time_start * sampling_rate / downsample_rate)

        frame_shift = lyrics_shift

        # Never overlap the previous line; truncate anything past max_frames.
        frame_start = max(gt_frame_start - frame_shift, last_end_pos)
        frame_len = min(num_tokens, max_frames - frame_start)

        lrc[frame_start:frame_start + frame_len] = tokens[:frame_len]

        last_end_pos = frame_start + frame_len

    lrc_emb = lrc.unsqueeze(0).to(device)

    normalized_start_time = torch.tensor(normalized_start_time).unsqueeze(0).to(device)
    normalized_start_time = normalized_start_time.half()

    return lrc_emb, normalized_start_time
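

# Shape sketch (with the defaults above):
#
#   lrc_emb, t0 = get_lrc_token(lrc_text, tokenizer, device)
#   lrc_emb.shape  # torch.Size([1, 2048]), dtype long
#   t0.shape       # torch.Size([1]), fp16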


def load_checkpoint(model, ckpt_path, device, use_ema=True):
    if device == "cuda":
        model = model.half()

    ckpt_type = ckpt_path.split(".")[-1]
    if ckpt_type == "safetensors":
        from safetensors.torch import load_file

        checkpoint = load_file(ckpt_path)
    else:
        checkpoint = torch.load(ckpt_path, weights_only=True)

    if use_ema:
        if ckpt_type == "safetensors":
            checkpoint = {"ema_model_state_dict": checkpoint}
        # Strip the "ema_model." prefix and drop EMA bookkeeping entries.
        checkpoint["model_state_dict"] = {
            k.replace("ema_model.", ""): v
            for k, v in checkpoint["ema_model_state_dict"].items()
            if k not in ["initted", "step"]
        }
        model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    else:
        if ckpt_type == "safetensors":
            checkpoint = {"model_state_dict": checkpoint}
        model.load_state_dict(checkpoint["model_state_dict"], strict=False)

    return model.to(device)
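

if __name__ == "__main__":
    # Smoke-test sketch wiring the helpers together end to end. The reference
    # track path and LRC text below are placeholders; the final decode from
    # latents to audio (via cfm/vae) is model-specific and not shown here.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cfm, tokenizer, muq, vae = prepare_model(device)

    max_frames = 2048
    latent_prompt = get_reference_latent(device, max_frames)
    negative_style_prompt = get_negative_style_prompt(device)
    style_prompt = get_style_prompt(muq, "reference.wav")  # placeholder path

    lrc_text = "[00:10.00]Hello world"  # placeholder LRC
    lrc_emb, start_time = get_lrc_token(lrc_text, tokenizer, device)
    print(lrc_emb.shape, start_time.shape, style_prompt.shape)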