File size: 5,119 Bytes
1e95c1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import requests
import yaml
import torch
import librosa
import numpy as np
import soundfile as sf
from pathlib import Path
from transformers import T5Tokenizer, T5EncoderModel
from tqdm import tqdm
from .src.plugin_wrapper import DreamVG


class DreamVoice_Plugin:
    """Generate speaker embeddings from text prompts with DreamVG.

    A prompt is tokenized and encoded with a T5 text encoder. The encoding is
    then passed to ``DreamVG.inference``, whose result is returned by
    :meth:`gen_spk`.
    """

    # (connect, read) timeouts in seconds for checkpoint downloads. The read
    # timeout bounds the wait between received bytes, not the total transfer.
    _HTTP_TIMEOUT = (10, 60)

    def __init__(self, config='plugin.yaml', device='cuda'):
        """Load the config, fetch missing checkpoints, and build the models.

        Args:
            config: YAML config file name, resolved relative to this module's
                directory.
            device: torch device the text encoder and DreamVG are placed on.
        """
        self.script_dir = Path(__file__).resolve().parent

        with open(self.script_dir / config, 'r', encoding='utf-8') as fp:
            self.config = yaml.safe_load(fp)

        # Download any checkpoint that is missing locally before loading it.
        self._ensure_checkpoints_exist()

        self.device = device

        # Tokenizer and frozen (eval-mode) T5 text encoder.
        lm_path = self.config['lm_path']
        self.tokenizer = T5Tokenizer.from_pretrained(lm_path)
        self.text_encoder = T5EncoderModel.from_pretrained(lm_path).to(device).eval()

        self.dreamvg = self._build_dreamvg()

    def _build_dreamvg(self):
        """Construct DreamVG from the ``dreamvg`` section of the config.

        ``config_path`` and ``ckpt_path`` are resolved relative to the module
        directory. The model is created on ``self.device``.
        """
        section = self.config['dreamvg']
        return DreamVG(
            config_path=self.script_dir / section['config_path'],
            ckpt_path=self.script_dir / section['ckpt_path'],
            device=self.device,
        )

    def _lookup_config(self, dotted_key):
        """Return the config value at a dotted key path such as ``'a.b'``.

        Returns None when any level is missing or is not a mapping.
        """
        node = self.config
        for key in dotted_key.split('.'):
            if not isinstance(node, dict) or key not in node:
                return None
            node = node[key]
        return node

    def _ensure_checkpoints_exist(self):
        """Download every configured checkpoint that is not present locally.

        A checkpoint is skipped when it has no configured local path, when
        the file already exists, or when no download URL is configured.
        """
        checkpoints = [
            ('dreamvg.ckpt_path', self._lookup_config('dreamvg.ckpt_url')),
        ]

        for path_key, url in checkpoints:
            local_path = self._get_local_path(path_key)
            if local_path is None or local_path.exists() or not url:
                continue
            print(f"Downloading {path_key} from {url}")
            self._download_file(url, local_path)

    def _get_local_path(self, path_key):
        """Resolve a dotted config key to a path under the module directory.

        Args:
            path_key: dotted config key, e.g. ``'dreamvg.ckpt_path'``.

        Returns:
            ``script_dir / <value>``, or None if the key is missing or its
            value is not a string / path-like object.
        """
        value = self._lookup_config(path_key)
        if not isinstance(value, (str, os.PathLike)):
            return None
        return self.script_dir / value

    def _download_file(self, url, local_path):
        """Stream ``url`` into ``local_path`` and show a progress bar.

        An unauthenticated GET is tried first. If it fails, the user is
        prompted for a Hugging Face API key, and the request is retried with a
        bearer token. An empty answer retries without a token.

        Data is written to a sibling ``.part`` file. That file is moved onto
        ``local_path`` only after the whole body was received, so an
        interrupted download never leaves a truncated checkpoint behind.

        Raises:
            RuntimeError: if the retry also fails.
        """
        response = None
        try:
            response = requests.get(url, stream=True, timeout=self._HTTP_TIMEOUT)
            response.raise_for_status()  # Ensure we raise an exception for HTTP errors
        except requests.exceptions.RequestException as first_err:
            print(f"Error encountered: {first_err}")
            if response is not None:
                response.close()
                response = None

            # Development mode: prompt user for Hugging Face API key
            user_input = input("Private checkpoint, please request authorization and enter your Hugging Face API key.")
            self.hf_key = user_input if user_input else None
            headers = {'Authorization': f'Bearer {self.hf_key}'} if self.hf_key else {}

            try:
                response = requests.get(url, stream=True, headers=headers,
                                        timeout=self._HTTP_TIMEOUT)
                response.raise_for_status()
            except requests.exceptions.RequestException as retry_err:
                print(f"Error encountered in dev mode: {retry_err}")
                if response is not None:
                    response.close()
                # Previously execution continued with response=None and
                # crashed with AttributeError on response.headers.
                raise RuntimeError(
                    f"Failed to download checkpoint from {url}") from retry_err

        local_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = local_path.with_name(local_path.name + '.part')

        total_size = int(response.headers.get('content-length', 0))
        block_size = 8192
        try:
            with response, open(tmp_path, 'wb') as f, \
                    tqdm(total=total_size, unit='iB', unit_scale=True) as bar:
                for chunk in response.iter_content(chunk_size=block_size):
                    bar.update(len(chunk))
                    f.write(chunk)
            # Atomic on the same filesystem: the final path only ever holds a
            # complete download.
            tmp_path.replace(local_path)
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise

    def _init_plugin_mode(self):
        """(Re)build DreamVG and load the speaker encoder.

        Raises:
            RuntimeError: ``spk_encoder`` is not available in this module.
                Nothing imports it here, so the original code always failed
                with NameError, and only after reloading the DreamVG
                checkpoint.
        """
        encoder = globals().get('spk_encoder')
        if encoder is None:
            raise RuntimeError(
                "Speaker encoder is unavailable: 'spk_encoder' is not imported "
                "in this module.")

        self.dreamvg = self._build_dreamvg()

        encoder.load_model(self.script_dir / self.config['speaker_path'], self.device)
        self.spk_encoder = encoder
        self.spk_embed_cache = None

    @torch.no_grad()
    def gen_spk(self, prompt,
                prompt_guidance_scale=3, prompt_guidance_rescale=0.0,
                prompt_ddim_steps=100, prompt_eta=1, prompt_random_seed=None):
        """Generate a speaker embedding from a text prompt.

        The prompt is tokenized to exactly 32 tokens (padded or truncated) and
        encoded with the T5 encoder. The encoder's first output and the
        attention mask are then passed to ``DreamVG.inference``.

        Args:
            prompt: text (or list of texts) accepted by the T5 tokenizer.
            prompt_guidance_scale: forwarded as ``guidance_scale``.
            prompt_guidance_rescale: forwarded as ``guidance_rescale``.
            prompt_ddim_steps: forwarded as ``ddim_steps``.
            prompt_eta: forwarded as ``eta``.
            prompt_random_seed: forwarded as ``random_seed``.

        Returns:
            Whatever ``DreamVG.inference`` returns (presumably a speaker
            embedding tensor -- TODO confirm against DreamVG).
        """
        text_batch = self.tokenizer(prompt, max_length=32,
                                    padding='max_length', truncation=True, return_tensors="pt")
        text_mask = text_batch.attention_mask.to(self.device)
        token_ids = text_batch.input_ids.to(self.device)
        text = self.text_encoder(input_ids=token_ids, attention_mask=text_mask)[0]

        return self.dreamvg.inference([text, text_mask],
                                      guidance_scale=prompt_guidance_scale,
                                      guidance_rescale=prompt_guidance_rescale,
                                      ddim_steps=prompt_ddim_steps, eta=prompt_eta,
                                      random_seed=prompt_random_seed)