|
import os |
|
import requests |
|
import yaml |
|
import torch |
|
import librosa |
|
import numpy as np |
|
import soundfile as sf |
|
from pathlib import Path |
|
from transformers import T5Tokenizer, T5EncoderModel |
|
from tqdm import tqdm |
|
from .src.plugin_wrapper import DreamVG |
|
|
|
|
|
class DreamVoice_Plugin:
    """Text-prompted speaker-embedding generator built on a T5 encoder and DreamVG.

    Configuration is a YAML file (``plugin.yaml`` by default) resolved relative
    to this module's directory. Keys read by this class:

    * ``lm_path`` -- passed to ``T5Tokenizer.from_pretrained`` and
      ``T5EncoderModel.from_pretrained``.
    * ``dreamvg.config_path`` / ``dreamvg.ckpt_path`` -- joined onto this
      module's directory and handed to ``DreamVG``.
    * ``dreamvg.ckpt_url`` -- optional URL used to fetch ``dreamvg.ckpt_path``
      when that file does not exist locally.
    * ``speaker_path`` -- only read by ``_init_plugin_mode``.
    """

    def __init__(self, config='plugin.yaml', device='cuda'):
        """Load the config, fetch missing checkpoints, and build the models.

        Args:
            config: YAML config file name, joined onto this module's directory
                (``pathlib`` keeps an absolute path as-is).
            device: torch device the text encoder and DreamVG are placed on.
        """
        script_dir = Path(__file__).resolve().parent
        config_path = script_dir / config

        with open(config_path, 'r', encoding='utf-8') as fp:
            self.config = yaml.safe_load(fp)

        self.script_dir = script_dir
        # Filled in by _download_file if the user supplies a Hugging Face token.
        self.hf_key = None

        # Must run before the models below try to read their checkpoint files.
        self._ensure_checkpoints_exist()

        self.device = device

        lm_path = self.config['lm_path']
        self.tokenizer = T5Tokenizer.from_pretrained(lm_path)
        # eval() switches off training-only behaviour such as dropout.
        self.text_encoder = T5EncoderModel.from_pretrained(lm_path).to(device).eval()

        self.dreamvg = self._build_dreamvg()

    def _build_dreamvg(self):
        """Construct a DreamVG model from the ``dreamvg`` section of the config."""
        dreamvg_cfg = self.config['dreamvg']
        return DreamVG(
            config_path=self.script_dir / dreamvg_cfg['config_path'],
            ckpt_path=self.script_dir / dreamvg_cfg['ckpt_path'],
            device=self.device,
        )

    def _ensure_checkpoints_exist(self):
        """Download every configured checkpoint that is missing on disk.

        Entries without a URL are skipped; this method does not report a
        missing file that has no URL to fetch it from.
        """
        checkpoints = [
            ('dreamvg.ckpt_path', self.config.get('dreamvg', {}).get('ckpt_url')),
        ]

        for path_key, url in checkpoints:
            local_path = self._get_local_path(path_key)
            if not local_path.exists() and url:
                print(f"Downloading {path_key} from {url}")
                self._download_file(url, local_path)

    def _get_local_path(self, path_key):
        """Resolve a dotted config key (e.g. ``'dreamvg.ckpt_path'``) to a path.

        The value found in the config is joined onto ``self.script_dir``.

        Raises:
            KeyError: if any component of ``path_key`` is missing, or an
                intermediate component is not a mapping.
        """
        node = self.config
        for key in path_key.split('.'):
            if not isinstance(node, dict) or key not in node:
                raise KeyError(f"Config key '{path_key}' not found (missing '{key}')")
            node = node[key]
        return self.script_dir / node

    def _download_file(self, url, local_path):
        """Stream ``url`` into ``local_path``, asking for a HF token if needed.

        Raises:
            RuntimeError: if the download fails even after the token retry.
        """
        response = self._request_checkpoint(url)
        try:
            self._stream_to_file(response, local_path)
        finally:
            response.close()

    def _request_checkpoint(self, url):
        """Return a successful streaming response for ``url``.

        An anonymous request is tried first. On failure the user is prompted
        for a Hugging Face API key (stored on ``self.hf_key``). The request is
        then retried, sending the key as a Bearer token when one was given.

        Raises:
            RuntimeError: if the retry fails as well.
        """
        try:
            response = requests.get(url, stream=True)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f"Error encountered: {e}")
        else:
            return response

        user_input = input("Private checkpoint, please request authorization and enter your Hugging Face API key.")
        self.hf_key = user_input if user_input else None

        headers = {'Authorization': f'Bearer {self.hf_key}'} if self.hf_key else {}

        try:
            response = requests.get(url, stream=True, headers=headers)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f"Error encountered in dev mode: {e}")
            # Fail loudly: there is no response body to write.
            raise RuntimeError(f"Failed to download checkpoint from {url}") from e
        return response

    @staticmethod
    def _stream_to_file(response, local_path, block_size=8192):
        """Write ``response``'s body to ``local_path`` with a tqdm progress bar.

        Data goes to a sibling ``.part`` file that is renamed onto
        ``local_path`` only after the whole body was written. An interrupted
        download therefore never leaves a truncated file that a later
        existence check would mistake for a complete checkpoint.
        """
        local_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = local_path.with_name(local_path.name + '.part')
        # 0 when the server sends no Content-Length; tqdm then shows no total.
        total_size = int(response.headers.get('content-length', 0))

        try:
            with open(tmp_path, 'wb') as f:
                with tqdm(total=total_size, unit='iB', unit_scale=True) as progress:
                    for chunk in response.iter_content(chunk_size=block_size):
                        progress.update(len(chunk))
                        f.write(chunk)
            # Rename only after the file handle is closed.
            tmp_path.replace(local_path)
        except BaseException:
            # Also covers KeyboardInterrupt: never leave partial data behind.
            if tmp_path.exists():
                tmp_path.unlink()
            raise

    def _init_plugin_mode(self, spk_encoder=None):
        """Rebuild DreamVG and optionally load a speaker encoder.

        Args:
            spk_encoder: object exposing ``load_model(path, device)``. Its
                weights are loaded from the ``speaker_path`` config entry.
                No ``spk_encoder`` is imported by this module, so the caller
                must supply it. When omitted, ``self.spk_encoder`` is set to
                ``None``.
        """
        self.dreamvg = self._build_dreamvg()

        if spk_encoder is not None:
            spk_encoder.load_model(self.script_dir / self.config['speaker_path'], self.device)
        self.spk_encoder = spk_encoder
        self.spk_embed_cache = None

    @torch.no_grad()
    def gen_spk(self, prompt,
                prompt_guidance_scale=3, prompt_guidance_rescale=0.0,
                prompt_ddim_steps=100, prompt_eta=1, prompt_random_seed=None):
        """Generate a speaker embedding from a text description.

        The prompt is tokenised to exactly 32 tokens (padded and truncated)
        and encoded by the T5 encoder. The encoder output and the attention
        mask are then passed to ``DreamVG.inference``. Gradients are disabled.

        Args:
            prompt: text passed to the T5 tokenizer.
            prompt_guidance_scale: forwarded as ``guidance_scale``.
            prompt_guidance_rescale: forwarded as ``guidance_rescale``.
            prompt_ddim_steps: forwarded as ``ddim_steps``.
            prompt_eta: forwarded as ``eta``.
            prompt_random_seed: forwarded as ``random_seed``.

        Returns:
            Whatever ``self.dreamvg.inference`` returns -- presumably the
            speaker embedding tensor (TODO confirm against DreamVG).
        """
        text_batch = self.tokenizer(prompt, max_length=32, padding='max_length',
                                    truncation=True, return_tensors="pt")
        input_ids = text_batch.input_ids.to(self.device)
        text_mask = text_batch.attention_mask.to(self.device)
        # Element 0 of the encoder output is its last hidden state.
        text = self.text_encoder(input_ids=input_ids, attention_mask=text_mask)[0]

        spk_embed = self.dreamvg.inference([text, text_mask],
                                           guidance_scale=prompt_guidance_scale,
                                           guidance_rescale=prompt_guidance_rescale,
                                           ddim_steps=prompt_ddim_steps, eta=prompt_eta,
                                           random_seed=prompt_random_seed)
        return spk_embed