# DreamVoice / dreamvoice/plugin.py
# Higobeatz, commit 0dabde8: "freevc plugin"
import os
import requests
import yaml
import torch
import librosa
import numpy as np
import soundfile as sf
from pathlib import Path
from transformers import T5Tokenizer, T5EncoderModel
from tqdm import tqdm
from .src.plugin_wrapper import DreamVG
class DreamVoice_Plugin:
    """Generate speaker embeddings from text prompts using DreamVG.

    On construction the plugin:

    * reads a YAML config located next to this module;
    * downloads any missing checkpoints;
    * loads a T5 tokenizer/encoder from ``config['lm_path']``;
    * builds a ``DreamVG`` model from ``config['dreamvg']``.
    """

    # (connect, read) timeout in seconds for checkpoint downloads; without a
    # timeout ``requests`` can block indefinitely on a stalled connection.
    _DOWNLOAD_TIMEOUT = (10, 60)
    # Chunk size (bytes) used when streaming checkpoint downloads to disk.
    _DOWNLOAD_BLOCK_SIZE = 8192

    def __init__(self, config='plugin.yaml', device='cuda'):
        """Load configuration, fetch missing checkpoints and build the models.

        Args:
            config: Config file name, resolved relative to this module's
                directory.
            device: Torch device string the models are placed on.
        """
        script_dir = Path(__file__).resolve().parent
        self.script_dir = script_dir

        with open(script_dir / config, 'r', encoding='utf-8') as fp:
            self.config = yaml.safe_load(fp)

        # Checkpoints must be on disk before any model tries to load them.
        self._ensure_checkpoints_exist()

        self.device = device

        # Text side: T5 tokenizer + encoder, kept in inference (eval) mode.
        lm_path = self.config['lm_path']
        self.tokenizer = T5Tokenizer.from_pretrained(lm_path)
        self.text_encoder = T5EncoderModel.from_pretrained(lm_path).to(device).eval()

        self.dreamvg = self._load_dreamvg()

    def _load_dreamvg(self):
        """Build a DreamVG model from the paths in ``config['dreamvg']``.

        Both paths are resolved relative to ``self.script_dir``.
        """
        dreamvg_cfg = self.config['dreamvg']
        return DreamVG(
            config_path=self.script_dir / dreamvg_cfg['config_path'],
            ckpt_path=self.script_dir / dreamvg_cfg['ckpt_path'],
            device=self.device,
        )

    def _ensure_checkpoints_exist(self):
        """Download every configured checkpoint that is missing on disk.

        Entries without a URL are skipped. Files that already exist are never
        re-fetched.
        """
        checkpoints = [
            ('dreamvg.ckpt_path', self.config.get('dreamvg', {}).get('ckpt_url')),
        ]
        for path_key, url in checkpoints:
            local_path = self._get_local_path(path_key)
            if not local_path.exists() and url:
                print(f"Downloading {path_key} from {url}")
                self._download_file(url, local_path)

    def _get_local_path(self, path_key):
        """Resolve a dotted config key (e.g. ``'dreamvg.ckpt_path'``) to a Path.

        The value stored at that key is joined onto ``self.script_dir``.

        Raises:
            KeyError: if any segment of ``path_key`` is absent from the config
                (or an intermediate value is not a mapping).
        """
        node = self.config
        for key in path_key.split('.'):
            if not isinstance(node, dict) or key not in node:
                raise KeyError(f"config key '{path_key}' not found (missing '{key}')")
            node = node[key]
        return self.script_dir / node

    def _download_file(self, url, local_path):
        """Stream ``url`` to ``local_path``, asking for a HF token on failure.

        An anonymous request is tried first. If it fails (for example, a
        private repository), the user is prompted for a Hugging Face API key.
        The request is then retried with it as a bearer token.

        Data is streamed into a sibling ``.part`` file, which is renamed to
        ``local_path`` only after the whole body has been written. An
        interrupted download therefore never leaves a truncated checkpoint
        that ``_ensure_checkpoints_exist`` would mistake for a complete one.

        Raises:
            RuntimeError: if the retry (with or without a key) also fails.
        """
        try:
            response = requests.get(url, stream=True, timeout=self._DOWNLOAD_TIMEOUT)
            response.raise_for_status()  # Treat HTTP error statuses as failures.
        except requests.exceptions.RequestException as e:
            print(f"Error encountered: {e}")
            # Development mode: the checkpoint may be private, so ask for a key.
            user_input = input("Private checkpoint, please request authorization and enter your Hugging Face API key.").strip()
            self.hf_key = user_input if user_input else None
            headers = {'Authorization': f'Bearer {self.hf_key}'} if self.hf_key else {}
            try:
                response = requests.get(url, stream=True, headers=headers,
                                        timeout=self._DOWNLOAD_TIMEOUT)
                response.raise_for_status()
            except requests.exceptions.RequestException as retry_err:
                print(f"Error encountered in dev mode: {retry_err}")
                # Without a response there is nothing to write; fail loudly
                # here instead of crashing later on ``None.headers``.
                raise RuntimeError(f"Failed to download checkpoint from {url}") from retry_err

        local_path = Path(local_path)
        local_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = local_path.with_name(local_path.name + '.part')
        total_size = int(response.headers.get('content-length', 0))

        try:
            # Context managers close the HTTP connection, the file and the
            # progress bar even if streaming fails part-way.
            with response, open(tmp_path, 'wb') as f, \
                    tqdm(total=total_size, unit='iB', unit_scale=True) as progress:
                for chunk in response.iter_content(chunk_size=self._DOWNLOAD_BLOCK_SIZE):
                    progress.update(len(chunk))
                    f.write(chunk)
            tmp_path.replace(local_path)
        except BaseException:
            # Also covers KeyboardInterrupt: never leave partial data behind.
            if tmp_path.exists():
                tmp_path.unlink()
            raise

    def _init_plugin_mode(self):
        """Rebuild DreamVG and attach a speaker encoder.

        NOTE(review): ``spk_encoder`` is neither imported nor defined in this
        module, so calling this method raises ``NameError`` unless that name is
        provided elsewhere. Nothing in this class calls it. TODO confirm intended
        import.
        """
        self.dreamvg = self._load_dreamvg()
        # Load speaker encoder
        spk_encoder.load_model(self.script_dir / self.config['speaker_path'], self.device)
        self.spk_encoder = spk_encoder
        self.spk_embed_cache = None

    @torch.no_grad()
    def gen_spk(self, prompt,
                prompt_guidance_scale=3, prompt_guidance_rescale=0.0,
                prompt_ddim_steps=100, prompt_eta=1, prompt_random_seed=None,):
        """Generate a speaker embedding from a text description (no gradients).

        The prompt is tokenized (padded/truncated to 32 tokens) and run through
        the T5 encoder. The encoder hidden states and attention mask are then
        passed to ``DreamVG.inference`` together with the sampling options.

        Args:
            prompt: Text passed directly to the T5 tokenizer.
            prompt_guidance_scale: Forwarded as ``guidance_scale``.
            prompt_guidance_rescale: Forwarded as ``guidance_rescale``.
            prompt_ddim_steps: Forwarded as ``ddim_steps``.
            prompt_eta: Forwarded as ``eta``.
            prompt_random_seed: Forwarded as ``random_seed``.

        Returns:
            Whatever ``DreamVG.inference`` returns. This is presumably the
            speaker embedding tensor; TODO confirm.
        """
        text_batch = self.tokenizer(prompt, max_length=32,
                                    padding='max_length', truncation=True, return_tensors="pt")
        text = text_batch.input_ids.to(self.device)
        text_mask = text_batch.attention_mask.to(self.device)
        # Index 0 of the encoder output holds the last hidden states.
        text = self.text_encoder(input_ids=text, attention_mask=text_mask)[0]
        spk_embed = self.dreamvg.inference([text, text_mask],
                                           guidance_scale=prompt_guidance_scale,
                                           guidance_rescale=prompt_guidance_rescale,
                                           ddim_steps=prompt_ddim_steps, eta=prompt_eta,
                                           random_seed=prompt_random_seed)
        return spk_embed