import os
import requests
import yaml
import torch
import librosa
import numpy as np
import soundfile as sf
from pathlib import Path
from transformers import T5Tokenizer, T5EncoderModel
from tqdm import tqdm
from .src.plugin_wrapper import DreamVG


class DreamVoice_Plugin:
    """Text-prompted speaker-embedding generator built on a T5 encoder and DreamVG.

    On construction the YAML configuration is loaded from the directory that
    contains this file, the DreamVG checkpoint is downloaded if it is missing
    and a URL is configured, and the tokenizer, text encoder and DreamVG
    wrapper are instantiated.
    """

    # Seconds to wait for the server to connect / send data before giving up.
    # Without a timeout ``requests`` may block forever on a stalled connection.
    _REQUEST_TIMEOUT = 30

    # Chunk size (bytes) used when streaming a checkpoint to disk.
    _DOWNLOAD_BLOCK_SIZE = 8192

    def __init__(self, config='plugin.yaml', device='cuda'):
        """Load the configuration, fetch missing checkpoints and build the models.

        Args:
            config: Path of the YAML configuration file, resolved relative to
                the directory containing this module.
            device: Device identifier the text encoder is moved to and that is
                forwarded to ``DreamVG``.
        """
        # Initial setup
        script_dir = Path(__file__).resolve().parent
        config_path = script_dir / config

        # Load configuration file
        with open(config_path, 'r', encoding='utf-8') as fp:
            self.config = yaml.safe_load(fp)

        self.script_dir = script_dir

        # Ensure all checkpoints are downloaded
        self._ensure_checkpoints_exist()

        # Initialize attributes
        self.device = device

        # Load tokenizer and text encoder
        lm_path = self.config['lm_path']
        self.tokenizer = T5Tokenizer.from_pretrained(lm_path)
        self.text_encoder = T5EncoderModel.from_pretrained(lm_path).to(device).eval()

        self.dreamvg = self._build_dreamvg()

    def _build_dreamvg(self):
        """Construct the ``DreamVG`` wrapper from the ``dreamvg`` config section."""
        dreamvg_cfg = self.config['dreamvg']
        return DreamVG(
            config_path=self.script_dir / dreamvg_cfg['config_path'],
            ckpt_path=self.script_dir / dreamvg_cfg['ckpt_path'],
            device=self.device
        )

    def _ensure_checkpoints_exist(self):
        """Download every configured checkpoint that is not yet on disk.

        A checkpoint is only downloaded when its local file is absent and a
        URL is present in the configuration.
        """
        checkpoints = [
            ('dreamvg.ckpt_path', self.config.get('dreamvg', {}).get('ckpt_url'))
        ]

        for path_key, url in checkpoints:
            local_path = self._get_local_path(path_key)
            if not local_path.exists() and url:
                print(f"Downloading {path_key} from {url}")
                self._download_file(url, local_path)

    def _get_local_path(self, path_key):
        """Resolve a dotted configuration key to a path under the script directory.

        Args:
            path_key: Dot-separated key into the configuration mapping, e.g.
                ``'dreamvg.ckpt_path'``.

        Returns:
            ``self.script_dir`` joined with the configured value.

        Raises:
            KeyError: If any component of ``path_key`` is missing from the
                configuration.
            TypeError: If the configured value is not a string or path.
        """
        node = self.config
        for key in path_key.split('.'):
            # Fail with a clear message instead of silently carrying an empty
            # dict forward and crashing later on ``Path / dict``.
            if not isinstance(node, dict) or key not in node:
                raise KeyError(f"Missing configuration entry '{path_key}'")
            node = node[key]

        if not isinstance(node, (str, os.PathLike)):
            raise TypeError(
                f"Configuration entry '{path_key}' must be a path, "
                f"got {type(node).__name__}"
            )
        return self.script_dir / node

    def _request_checkpoint(self, url):
        """Open a streaming GET for ``url``, retrying once with a user-supplied token.

        If the anonymous request fails, the user is prompted on stdin for a
        Hugging Face API key, which is stored in ``self.hf_key`` and sent as a
        bearer token on the second attempt.

        Returns:
            The successful streaming ``requests.Response``.

        Raises:
            RuntimeError: If the second attempt fails as well.
        """
        try:
            # Attempt to send a GET request to the URL
            response = requests.get(url, stream=True, timeout=self._REQUEST_TIMEOUT)
            response.raise_for_status()  # Ensure we raise an exception for HTTP errors
            return response
        except requests.exceptions.RequestException as e:
            # Log the error for debugging purposes
            print(f"Error encountered: {e}")

        # Development mode: prompt user for Hugging Face API key
        user_input = input("Private checkpoint, please request authorization and enter your Hugging Face API key.")
        self.hf_key = user_input if user_input else None

        # Set headers if an API key is provided
        headers = {'Authorization': f'Bearer {self.hf_key}'} if self.hf_key else {}

        try:
            # Attempt to send a GET request with headers in development mode
            response = requests.get(
                url, stream=True, headers=headers, timeout=self._REQUEST_TIMEOUT
            )
            response.raise_for_status()  # Ensure we raise an exception for HTTP errors
        except requests.exceptions.RequestException as e:
            # Log the error for debugging purposes
            print(f"Error encountered in dev mode: {e}")
            # Surface the failure instead of continuing with no response,
            # which previously crashed with an unrelated AttributeError.
            raise RuntimeError(f"Could not download checkpoint from {url}") from e
        return response

    def _download_file(self, url, local_path):
        """Stream ``url`` to ``local_path`` while showing a progress bar.

        The data is first written to a ``.part`` file next to the target and
        renamed once the transfer completes, so an interrupted download never
        leaves a truncated file at ``local_path`` (which would otherwise be
        treated as a valid checkpoint on the next start).

        Args:
            url: Location of the file to download.
            local_path: ``pathlib.Path`` of the destination file.

        Raises:
            RuntimeError: If the file cannot be requested, even after the
                authenticated retry.
        """
        response = self._request_checkpoint(url)

        local_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = local_path.with_name(local_path.name + '.part')

        # A missing content-length means the size is unknown; tqdm expects
        # ``None`` in that case rather than 0.
        total_size = int(response.headers.get('content-length', 0)) or None

        try:
            with open(tmp_path, 'wb') as f, \
                    tqdm(total=total_size, unit='iB', unit_scale=True) as progress:
                for chunk in response.iter_content(chunk_size=self._DOWNLOAD_BLOCK_SIZE):
                    if chunk:
                        progress.update(len(chunk))
                        f.write(chunk)
            os.replace(tmp_path, local_path)
        except BaseException:
            # Remove the partial file on any failure (including Ctrl-C).
            if tmp_path.exists():
                tmp_path.unlink()
            raise
        finally:
            response.close()

    def _init_plugin_mode(self):
        """Rebuild DreamVG and attach the speaker encoder.

        NOTE(review): ``spk_encoder`` is not imported or defined anywhere in
        this file, so calling this method raises ``NameError`` unless the name
        is provided elsewhere -- confirm the intended import.
        """
        # Initialize DreamVG
        self.dreamvg = self._build_dreamvg()

        # Load speaker encoder
        spk_encoder.load_model(self.script_dir / self.config['speaker_path'], self.device)
        self.spk_encoder = spk_encoder
        self.spk_embed_cache = None

    @torch.no_grad()
    def gen_spk(self, prompt,
                prompt_guidance_scale=3, prompt_guidance_rescale=0.0,
                prompt_ddim_steps=100, prompt_eta=1, prompt_random_seed=None,):
        """Generate a speaker embedding from a text prompt.

        The prompt is tokenized (padded/truncated to 32 tokens), encoded with
        the T5 text encoder and passed, together with its attention mask, to
        ``DreamVG.inference``.

        Args:
            prompt: Text given to the tokenizer.
            prompt_guidance_scale: Forwarded as ``guidance_scale``.
            prompt_guidance_rescale: Forwarded as ``guidance_rescale``.
            prompt_ddim_steps: Forwarded as ``ddim_steps``.
            prompt_eta: Forwarded as ``eta``.
            prompt_random_seed: Forwarded as ``random_seed``.

        Returns:
            Whatever ``DreamVG.inference`` returns for the encoded prompt.
        """
        text_batch = self.tokenizer(prompt, max_length=32,
                                    padding='max_length', truncation=True, return_tensors="pt")
        text, text_mask = text_batch.input_ids.to(self.device), \
            text_batch.attention_mask.to(self.device)
        # Index 0 of the encoder output is used as the text representation.
        text = self.text_encoder(input_ids=text, attention_mask=text_mask)[0]

        spk_embed = self.dreamvg.inference([text, text_mask],
                                           guidance_scale=prompt_guidance_scale,
                                           guidance_rescale=prompt_guidance_rescale,
                                           ddim_steps=prompt_ddim_steps, eta=prompt_eta,
                                           random_seed=prompt_random_seed)
        return spk_embed