# Listing metadata from the hosting page: file size 12,401 bytes; revisions 0a97d6c, 43c1510.
import os
import requests
import yaml
import torch
import librosa
import numpy as np
import soundfile as sf
from pathlib import Path
from transformers import T5Tokenizer, T5EncoderModel
from tqdm import tqdm
from .src.vc_wrapper import ReDiffVC, DreamVC
from .src.plugin_wrapper import DreamVG
from .src.modules.speaker_encoder.encoder import inference as spk_encoder
from .src.modules.BigVGAN.inference import load_model as load_vocoder
from .src.feats.contentvec_hf import get_content_model, get_content
class DreamVoice:
    """Voice-conversion pipeline driven by a text prompt or a reference speaker.

    Two modes are supported:

    * ``'plugin'``  -- a ``DreamVG`` model turns the encoded text prompt into a
      speaker embedding, which a ``ReDiffVC`` model then uses to convert the
      content audio. ``simplevc`` is only available in this mode.
    * ``'end2end'`` -- a single ``DreamVC`` model is conditioned directly on the
      encoded text prompt.

    All relative paths in the configuration file are resolved against the
    directory containing this module.
    """

    _MODES = ('plugin', 'end2end')
    # Seconds to wait for a connection / for data before a request is aborted.
    _REQUEST_TIMEOUT = 60
    # Multiplier applied to ``chunk_size`` to obtain the chunk length in
    # content frames (hard-coded as 50 in the original implementation).
    _FRAMES_PER_CHUNK_UNIT = 50

    def __init__(self, config='dreamvc.yaml', mode='plugin', device='cuda', chunk_size=16):
        """Load the configuration, fetch missing checkpoints and build all models.

        Args:
            config: Configuration file name, relative to this module's directory.
            mode: ``'plugin'`` or ``'end2end'``.
            device: Torch device on which every model is placed.
            chunk_size: Length of one conversion chunk; multiplied by 50 to get
                the number of content frames per chunk.

        Raises:
            NotImplementedError: If ``mode`` is not a supported mode.
            ValueError: If ``chunk_size`` is not positive.
        """
        # Validate cheap arguments first so a typo does not cost a full
        # checkpoint download and model load before failing.
        if mode not in self._MODES:
            raise NotImplementedError("Select mode from 'plugin' and 'end2end'")
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")

        # Initial setup
        script_dir = Path(__file__).resolve().parent
        config_path = script_dir / config

        # Load configuration file
        with open(config_path, 'r', encoding='utf-8') as fp:
            self.config = yaml.safe_load(fp)

        self.script_dir = script_dir

        # Ensure all checkpoints are downloaded
        self._ensure_checkpoints_exist()

        # Initialize attributes
        self.device = device
        self.sr = self.config['sample_rate']
        # Defined in every mode so load/save_spk_embed never hit a missing attribute.
        self.spk_embed_cache = None

        # Load vocoder
        vocoder_path = script_dir / self.config['vocoder_path']
        self.hifigan, _ = load_vocoder(vocoder_path, device)
        self.hifigan.eval()

        # Load content model
        self.content_model = get_content_model().to(device)

        # Load tokenizer and text encoder
        lm_path = self.config['lm_path']
        self.tokenizer = T5Tokenizer.from_pretrained(lm_path)
        self.text_encoder = T5EncoderModel.from_pretrained(lm_path).to(device).eval()

        # Set mode
        self.mode = mode
        if mode == 'plugin':
            self._init_plugin_mode()
        else:
            self._init_end2end_mode()

        # Chunk length in content frames.
        # NOTE(review): the previous comment said "10s clips", which does not
        # match the default of 16; presumably the content model emits 50 frames
        # per second, making the default 16 s -- confirm against the model.
        self.chunk_size = chunk_size * self._FRAMES_PER_CHUNK_UNIT

    def _lookup(self, dotted_key):
        """Return the config value at ``dotted_key`` (e.g. ``'dreamvc.ckpt_path'``).

        Returns ``None`` when any component of the key is missing or when an
        intermediate value is not a mapping.
        """
        node = self.config
        for key in dotted_key.split('.'):
            if not isinstance(node, dict) or key not in node:
                return None
            node = node[key]
        return node

    def _ensure_checkpoints_exist(self):
        """Download every configured checkpoint that is not present locally.

        An entry is skipped when it has no URL, when its local path is not
        configured, or when the local file already exists.
        """
        checkpoints = (
            ('vocoder_path', 'vocoder_url'),
            ('vocoder_config_path', 'vocoder_config_url'),
            ('speaker_path', 'speaker_url'),
            ('dreamvc.ckpt_path', 'dreamvc.ckpt_url'),
            ('rediffvc.ckpt_path', 'rediffvc.ckpt_url'),
            ('dreamvg.ckpt_path', 'dreamvg.ckpt_url'),
        )

        for path_key, url_key in checkpoints:
            url = self._lookup(url_key)
            if not url:
                continue
            local_path = self._get_local_path(path_key)
            if local_path is None or local_path.exists():
                continue
            print(f"Downloading {path_key} from {url}")
            self._download_file(url, local_path)

    def _get_local_path(self, path_key):
        """Resolve the configured path at ``path_key`` against the module directory.

        Returns:
            The resolved ``Path``, or ``None`` when the configuration holds no
            (non-empty string) path under that key.
        """
        relative = self._lookup(path_key)
        if not isinstance(relative, str) or not relative:
            return None
        return self.script_dir / relative

    def _download_file(self, url, local_path):
        """Stream ``url`` into ``local_path``, showing a progress bar.

        The first attempt is unauthenticated. If it fails, the user is prompted
        for a Hugging Face API key and the request is retried with a bearer
        token.

        Raises:
            requests.exceptions.RequestException: If the retry fails as well.
        """
        try:
            # Attempt to send a GET request to the URL
            response = requests.get(url, stream=True, timeout=self._REQUEST_TIMEOUT)
            response.raise_for_status()  # Ensure we raise an exception for HTTP errors
        except requests.exceptions.RequestException as e:
            # Log the error for debugging purposes
            print(f"Error encountered: {e}")

            # Development mode: prompt user for Hugging Face API key
            user_input = input("Private checkpoint, please request authorization and enter your Hugging Face API key.")
            # Strip stray whitespace from a pasted key; empty input means "no key".
            self.hf_key = user_input.strip() or None

            # Set headers if an API key is provided
            headers = {'Authorization': f'Bearer {self.hf_key}'} if self.hf_key else {}

            try:
                # Attempt to send a GET request with headers in development mode
                response = requests.get(url, stream=True, headers=headers,
                                        timeout=self._REQUEST_TIMEOUT)
                response.raise_for_status()  # Ensure we raise an exception for HTTP errors
            except requests.exceptions.RequestException as e:
                # Log the error, then propagate it: continuing with no response
                # would only fail later with an unrelated AttributeError.
                print(f"Error encountered in dev mode: {e}")
                raise

        local_path.parent.mkdir(parents=True, exist_ok=True)

        total_size = int(response.headers.get('content-length', 0))
        block_size = 8192

        # Write to a temporary sibling and rename on success, so an interrupted
        # download never leaves a truncated file that later passes the
        # "already exists" check in _ensure_checkpoints_exist.
        partial_path = local_path.with_name(local_path.name + '.part')
        try:
            with response, \
                    open(partial_path, 'wb') as f, \
                    tqdm(total=total_size, unit='iB', unit_scale=True) as progress:
                for chunk in response.iter_content(chunk_size=block_size):
                    progress.update(len(chunk))
                    f.write(chunk)
        except BaseException:
            # Includes KeyboardInterrupt: remove the partial file, then re-raise.
            if partial_path.exists():
                partial_path.unlink()
            raise
        partial_path.replace(local_path)

    def _init_plugin_mode(self):
        """Build the ReDiffVC and DreamVG wrappers and load the speaker encoder."""
        # Initialize ReDiffVC
        self.dreamvc = ReDiffVC(
            config_path=self.script_dir / self.config['rediffvc']['config_path'],
            ckpt_path=self.script_dir / self.config['rediffvc']['ckpt_path'],
            device=self.device
        )

        # Initialize DreamVG
        self.dreamvg = DreamVG(
            config_path=self.script_dir / self.config['dreamvg']['config_path'],
            ckpt_path=self.script_dir / self.config['dreamvg']['ckpt_path'],
            device=self.device
        )

        # Load speaker encoder (module-level state inside the encoder module)
        spk_encoder.load_model(self.script_dir / self.config['speaker_path'], self.device)
        self.spk_encoder = spk_encoder

    def _init_end2end_mode(self):
        """Build the DreamVC wrapper used for prompt-conditioned conversion."""
        # Initialize DreamVC
        self.dreamvc = DreamVC(
            config_path=self.script_dir / self.config['dreamvc']['config_path'],
            ckpt_path=self.script_dir / self.config['dreamvc']['ckpt_path'],
            device=self.device
        )

    def _load_content(self, audio_path):
        """Load ``audio_path`` at 16 kHz and return its content features.

        The waveform is zero-padded at the end to a multiple of ``16 * 160``
        samples before being passed to the content model.
        """
        content_audio, _ = librosa.load(audio_path, sr=16000)

        # Number of samples needed to reach the next multiple of 16*160
        block = 16 * 160
        missing = -len(content_audio) % block
        if missing:
            content_audio = np.pad(content_audio, (0, missing), mode='constant')

        content_audio = torch.tensor(content_audio).unsqueeze(0).to(self.device)
        return get_content(self.content_model, content_audio)

    def _iter_content_chunks(self, content_clip):
        """Yield consecutive slices of ``content_clip`` along dimension 1.

        Each slice holds at most ``self.chunk_size`` frames; the last one may be
        shorter. ``content_clip`` must be 3-dimensional.
        """
        _, total_frames, _ = content_clip.shape
        for start in range(0, total_frames, self.chunk_size):
            yield content_clip[:, start:start + self.chunk_size, :]

    def _convert_with_speaker(self, spk_embed, content_clip, guidance_scale,
                              guidance_rescale, ddim_steps, eta, random_seed):
        """Convert ``content_clip`` chunk by chunk using a speaker embedding.

        Returns the per-chunk outputs of ``self.dreamvc.inference`` concatenated
        along the last dimension.
        """
        converted = [
            self.dreamvc.inference(
                spk_embed, chunk, None,
                guidance_scale=guidance_scale,
                guidance_rescale=guidance_rescale,
                ddim_steps=ddim_steps,
                eta=eta,
                random_seed=random_seed)
            for chunk in self._iter_content_chunks(content_clip)
        ]
        return torch.cat(converted, dim=-1)

    def _vocode(self, features):
        """Run the vocoder on ``features`` and return the result as a NumPy array.

        Dimension 1 of ``features`` is squeezed before the vocoder call and the
        two leading dimensions of the vocoder output are squeezed afterwards.
        """
        waveform = self.hifigan(features.squeeze(1))
        return waveform.cpu().numpy().squeeze(0).squeeze(0)

    def load_spk_embed(self, emb_path):
        """Load a speaker embedding from ``emb_path`` into the cache."""
        # NOTE(security): torch.load unpickles the file; only load embeddings
        # from trusted sources.
        self.spk_embed_cache = torch.load(emb_path, map_location=self.device)

    def save_spk_embed(self, emb_path):
        """Save the cached speaker embedding (moved to CPU) to ``emb_path``."""
        assert self.spk_embed_cache is not None, "no speaker embedding is cached"
        torch.save(self.spk_embed_cache.cpu(), emb_path)

    def save_audio(self, output_path, audio, sr):
        """Write ``audio`` to ``output_path`` with sample rate ``sr``."""
        sf.write(output_path, audio, samplerate=sr)

    @torch.no_grad()
    def genvc(self, content_audio, prompt,
              prompt_guidance_scale=3, prompt_guidance_rescale=0.0,
              prompt_ddim_steps=100, prompt_eta=1, prompt_random_seed=None,
              vc_guidance_scale=3, vc_guidance_rescale=0.0,
              vc_ddim_steps=50, vc_eta=1, vc_random_seed=None,
              ):
        """Convert ``content_audio`` to the voice described by ``prompt``.

        In ``'plugin'`` mode the ``prompt_*`` settings control speaker-embedding
        generation and the ``vc_*`` settings control the conversion; the
        generated embedding is stored in ``self.spk_embed_cache``. In
        ``'end2end'`` mode only the ``prompt_*`` settings are used.

        Args:
            content_audio: Path of the audio whose content is converted.
            prompt: Text description of the target voice (tokenized to 32 tokens).

        Returns:
            ``(audio, sample_rate)`` with ``audio`` a NumPy array and
            ``sample_rate`` the configured ``sample_rate``.

        Raises:
            NotImplementedError: If ``self.mode`` is not a supported mode.
        """
        content_clip = self._load_content(content_audio)

        text_batch = self.tokenizer(prompt, max_length=32,
                                    padding='max_length', truncation=True, return_tensors="pt")
        text_ids = text_batch.input_ids.to(self.device)
        text_mask = text_batch.attention_mask.to(self.device)
        text = self.text_encoder(input_ids=text_ids, attention_mask=text_mask)[0]

        if self.mode == 'plugin':
            spk_embed = self.dreamvg.inference([text, text_mask],
                                               guidance_scale=prompt_guidance_scale,
                                               guidance_rescale=prompt_guidance_rescale,
                                               ddim_steps=prompt_ddim_steps, eta=prompt_eta,
                                               random_seed=prompt_random_seed)
            gen_audio = self._convert_with_speaker(
                spk_embed, content_clip,
                guidance_scale=vc_guidance_scale,
                guidance_rescale=vc_guidance_rescale,
                ddim_steps=vc_ddim_steps,
                eta=vc_eta,
                random_seed=vc_random_seed)
            self.spk_embed_cache = spk_embed

        elif self.mode == 'end2end':
            # Each call receives its own chunk. Passing the whole clip here (as
            # the code previously did) converts the full utterance once per
            # chunk and concatenates the repeats.
            converted = [
                self.dreamvc.inference([text, text_mask], chunk,
                                       guidance_scale=prompt_guidance_scale,
                                       guidance_rescale=prompt_guidance_rescale,
                                       ddim_steps=prompt_ddim_steps,
                                       eta=prompt_eta, random_seed=prompt_random_seed)
                for chunk in self._iter_content_chunks(content_clip)
            ]
            gen_audio = torch.cat(converted, dim=-1)

        else:
            raise NotImplementedError("Select mode from 'plugin' and 'end2end'")

        return self._vocode(gen_audio), self.sr

    @torch.no_grad()
    def simplevc(self, content_audio, speaker_audio=None, use_spk_cache=False,
                 vc_guidance_scale=3, vc_guidance_rescale=0.0,
                 vc_ddim_steps=50, vc_eta=1, vc_random_seed=None,
                 ):
        """Convert ``content_audio`` to a reference or cached speaker (plugin mode only).

        Args:
            content_audio: Path of the audio whose content is converted.
            speaker_audio: Optional path of a reference recording. When given,
                its embedding is computed, used, and stored in the cache.
            use_spk_cache: Use the cached embedding when ``speaker_audio`` is
                not given.

        Returns:
            ``(audio, sample_rate)`` with ``audio`` a NumPy array and
            ``sample_rate`` the configured ``sample_rate``.

        Raises:
            AssertionError: If the mode is not ``'plugin'``, or the cache is
                requested but empty.
            NotImplementedError: If neither ``speaker_audio`` nor
                ``use_spk_cache`` is provided.
        """
        assert self.mode == 'plugin', "simplevc is only available in 'plugin' mode"

        if speaker_audio is not None:
            speaker_wave, _ = librosa.load(speaker_audio, sr=16000)
            speaker_wave = torch.tensor(speaker_wave).unsqueeze(0).to(self.device)
            spk_embed = spk_encoder.embed_utterance_batch(speaker_wave)
            self.spk_embed_cache = spk_embed
        elif use_spk_cache:
            assert self.spk_embed_cache is not None, "no speaker embedding is cached"
            spk_embed = self.spk_embed_cache
        else:
            raise NotImplementedError(
                "Provide speaker_audio or set use_spk_cache=True")

        content_clip = self._load_content(content_audio)
        gen_audio = self._convert_with_speaker(
            spk_embed, content_clip,
            guidance_scale=vc_guidance_scale,
            guidance_rescale=vc_guidance_rescale,
            ddim_steps=vc_ddim_steps,
            eta=vc_eta,
            random_seed=vc_random_seed)

        return self._vocode(gen_audio), self.sr
if __name__ == '__main__':
    # Demo: convert 'test.wav' to a prompted voice in plugin mode and write
    # the result to 'debug.wav'.
    # NOTE(review): this module uses relative imports, so it presumably has to
    # be launched as part of its package (``python -m ...``) -- confirm.
    dreamvoice = DreamVoice(config='dreamvc.yaml', mode='plugin', device='cuda')
    content_audio = 'test.wav'
    prompt = 'young female voice, sounds young and cute'
    gen_audio, sr = dreamvoice.genvc(content_audio, prompt)
    dreamvoice.save_audio('debug.wav', gen_audio, sr)