File size: 6,415 Bytes
0a97d6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import yaml
import torch
from diffusers import DDIMScheduler
from .model.model import DiffVC
from .model.model_cross import DiffVC_Cross
from .utils import scale_shift, scale_shift_re, rescale_noise_cfg


class ReDiffVC(object):
    """Speaker-embedding-conditioned diffusion voice conversion (``DiffVC``).

    Loads a ``DiffVC`` network from a YAML config plus a checkpoint and pairs
    it with a v-prediction ``DDIMScheduler``. :meth:`inference` samples from
    Gaussian noise, conditioned on content features, a speaker embedding and
    optional f0, with optional classifier-free guidance.
    """

    def __init__(self,
                 config_path='configs/diffvc_base.yaml',
                 ckpt_path='../ckpts/dreamvc_base.pt',
                 device='cpu'):
        """Build the model and the noise scheduler.

        Args:
            config_path: YAML file with ``model`` and ``scheduler`` sections.
            ckpt_path: File loadable by ``torch.load`` holding a dict whose
                ``'model'`` entry is the ``DiffVC`` state dict.
            device: Device for the model, the initial noise and the RNG
                generator used during sampling.
        """
        with open(config_path, 'r') as fp:
            config = yaml.safe_load(fp)

        self.device = device
        self.model = DiffVC(config['model']).to(device)
        # map_location remaps stored tensors onto `device`. A checkpoint saved
        # from GPU tensors then still loads with the default device='cpu' on
        # a machine without CUDA.
        checkpoint = torch.load(ckpt_path, map_location=device)
        self.model.load_state_dict(checkpoint['model'])
        self.model.eval()

        sched_cfg = config['scheduler']
        self.noise_scheduler = DDIMScheduler(num_train_timesteps=sched_cfg['num_train_steps'],
                                             beta_start=sched_cfg['beta_start'],
                                             beta_end=sched_cfg['beta_end'],
                                             rescale_betas_zero_snr=True,
                                             timestep_spacing="trailing",
                                             clip_sample=False,
                                             prediction_type='v_prediction')
        # Latent normalization constants; `inference` post-processes the final
        # latents with scale_shift_re(scale=1/self.scale, shift=self.shift).
        self.scale = sched_cfg['scale']
        self.shift = sched_cfg['shift']
        # Height of the generated sample, taken from the first entry of the
        # UNet sample size (presumably the number of mel bins; confirm
        # against the config).
        self.melshape = config['model']['unet']['sample_size'][0]

    @torch.no_grad()
    def inference(self,
                  spk_embed, content_clip, f0_clip=None,
                  guidance_scale=3, guidance_rescale=0.7,
                  ddim_steps=50, eta=1, random_seed=2023):
        """Sample a converted output with DDIM and classifier-free guidance.

        Args:
            spk_embed: Speaker conditioning forwarded to the model.
            content_clip: Content conditioning forwarded to the model. Its
                second-to-last dimension sets the width of the generated
                sample (assumed to be the frame axis; TODO confirm).
            f0_clip: Optional pitch conditioning forwarded to the model.
            guidance_scale: CFG weight. A falsy value (0/None) disables
                guidance and runs one conditional model pass per step.
            guidance_rescale: If > 0, the guided prediction is passed through
                ``rescale_noise_cfg`` together with the conditional prediction.
            ddim_steps: Number of DDIM inference steps.
            eta: Forwarded to ``DDIMScheduler.step``.
            random_seed: Seed for the sampling generator. ``None`` draws a
                fresh, non-deterministic seed.

        Returns:
            The denoised latents, of shape
            ``(1, 1, self.melshape, content_clip.shape[-2])``, mapped through
            ``scale_shift_re(scale=1/self.scale, shift=self.shift)``.
        """
        self.model.eval()
        if random_seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(random_seed)
        else:
            generator = torch.Generator(device=self.device)
            generator.seed()

        self.noise_scheduler.set_timesteps(ddim_steps)

        # Start from Gaussian noise; the width follows the content features.
        gen_shape = (1, 1, self.melshape, content_clip.shape[-2])
        latents = torch.randn(gen_shape, generator=generator, device=self.device)

        for t in self.noise_scheduler.timesteps:
            # Only the model sees the scaled input; scheduler.step() must
            # receive the unscaled latents (the standard diffusers pattern).
            model_input = self.noise_scheduler.scale_model_input(latents, t)

            if guidance_scale:
                output_cond = self.model(model_input, t, content_clip, spk_embed, f0_clip,
                                         train_cfg=False)
                # Unconditional branch: presumably drops the speaker condition
                # (speaker_cfg=1.0) while keeping pitch (pitch_cfg=0.0).
                # Verify against DiffVC.forward.
                output_uncond = self.model(model_input, t, content_clip, spk_embed, f0_clip,
                                           train_cfg=True, speaker_cfg=1.0, pitch_cfg=0.0)

                output_pred = output_uncond + guidance_scale * (output_cond - output_uncond)
                if guidance_rescale > 0.0:
                    output_pred = rescale_noise_cfg(output_pred, output_cond,
                                                    guidance_rescale=guidance_rescale)
            else:
                output_pred = self.model(model_input, t, content_clip, spk_embed, f0_clip,
                                         train_cfg=False)

            latents = self.noise_scheduler.step(model_output=output_pred, timestep=t, sample=latents,
                                                eta=eta, generator=generator).prev_sample

        return scale_shift_re(latents, scale=1/self.scale, shift=self.shift)


class DreamVC(object):
    """Text-conditioned diffusion voice conversion (``DiffVC_Cross``).

    Loads a ``DiffVC_Cross`` network from a YAML config plus a checkpoint and
    pairs it with a v-prediction ``DDIMScheduler``. :meth:`inference` samples
    from Gaussian noise, conditioned on content features, a text condition
    with its mask, and optional f0, with optional classifier-free guidance.
    """

    def __init__(self,
                 config_path='configs/diffvc_cross.yaml',
                 ckpt_path='../ckpts/dreamvc_cross.pt',
                 device='cpu'):
        """Build the model and the noise scheduler.

        Args:
            config_path: YAML file with ``model`` and ``scheduler`` sections.
            ckpt_path: File loadable by ``torch.load`` holding a dict whose
                ``'model'`` entry is the ``DiffVC_Cross`` state dict.
            device: Device for the model, the initial noise and the RNG
                generator used during sampling.
        """
        with open(config_path, 'r') as fp:
            config = yaml.safe_load(fp)

        self.device = device
        self.model = DiffVC_Cross(config['model']).to(device)
        # map_location remaps stored tensors onto `device`. A checkpoint saved
        # from GPU tensors then still loads with the default device='cpu' on
        # a machine without CUDA.
        checkpoint = torch.load(ckpt_path, map_location=device)
        self.model.load_state_dict(checkpoint['model'])
        self.model.eval()

        sched_cfg = config['scheduler']
        self.noise_scheduler = DDIMScheduler(num_train_timesteps=sched_cfg['num_train_steps'],
                                             beta_start=sched_cfg['beta_start'],
                                             beta_end=sched_cfg['beta_end'],
                                             rescale_betas_zero_snr=True,
                                             timestep_spacing="trailing",
                                             clip_sample=False,
                                             prediction_type='v_prediction')
        # Latent normalization constants; `inference` post-processes the final
        # latents with scale_shift_re(scale=1/self.scale, shift=self.shift).
        self.scale = sched_cfg['scale']
        self.shift = sched_cfg['shift']
        # Height of the generated sample, taken from the first entry of the
        # UNet sample size (presumably the number of mel bins; confirm
        # against the config).
        self.melshape = config['model']['unet']['sample_size'][0]

    @torch.no_grad()
    def inference(self,
                  text, content_clip, f0_clip=None,
                  guidance_scale=3, guidance_rescale=0.7,
                  ddim_steps=50, eta=1, random_seed=2023):
        """Sample a converted output with DDIM and classifier-free guidance.

        Args:
            text: A 2-item pair ``(text, text_mask)``. Both items are forwarded
                to the model (presumably text-encoder features and their
                padding mask; verify against the caller).
            content_clip: Content conditioning forwarded to the model. Its
                second-to-last dimension sets the width of the generated
                sample (assumed to be the frame axis; TODO confirm).
            f0_clip: Optional pitch conditioning forwarded to the model.
            guidance_scale: CFG weight. A falsy value (0/None) disables
                guidance and runs one conditional model pass per step.
            guidance_rescale: If > 0, the guided prediction is passed through
                ``rescale_noise_cfg`` together with the conditional prediction.
            ddim_steps: Number of DDIM inference steps.
            eta: Forwarded to ``DDIMScheduler.step``.
            random_seed: Seed for the sampling generator. ``None`` draws a
                fresh, non-deterministic seed.

        Returns:
            The denoised latents, of shape
            ``(1, 1, self.melshape, content_clip.shape[-2])``, mapped through
            ``scale_shift_re(scale=1/self.scale, shift=self.shift)``.
        """
        text_embed, text_mask = text
        self.model.eval()
        if random_seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(random_seed)
        else:
            generator = torch.Generator(device=self.device)
            generator.seed()

        self.noise_scheduler.set_timesteps(ddim_steps)

        # Start from Gaussian noise; the width follows the content features.
        gen_shape = (1, 1, self.melshape, content_clip.shape[-2])
        latents = torch.randn(gen_shape, generator=generator, device=self.device)

        for t in self.noise_scheduler.timesteps:
            # Only the model sees the scaled input; scheduler.step() must
            # receive the unscaled latents (the standard diffusers pattern).
            model_input = self.noise_scheduler.scale_model_input(latents, t)

            if guidance_scale:
                output_cond = self.model(model_input, t, content_clip, text_embed, text_mask,
                                         f0_clip, train_cfg=False)
                # Unconditional branch: presumably drops the text/speaker
                # condition (speaker_cfg=1.0) while keeping pitch
                # (pitch_cfg=0.0). Verify against DiffVC_Cross.forward.
                output_uncond = self.model(model_input, t, content_clip, text_embed, text_mask,
                                           f0_clip, train_cfg=True,
                                           speaker_cfg=1.0, pitch_cfg=0.0)

                output_pred = output_uncond + guidance_scale * (output_cond - output_uncond)
                if guidance_rescale > 0.0:
                    output_pred = rescale_noise_cfg(output_pred, output_cond,
                                                    guidance_rescale=guidance_rescale)
            else:
                output_pred = self.model(model_input, t, content_clip, text_embed, text_mask,
                                         f0_clip, train_cfg=False)

            latents = self.noise_scheduler.step(model_output=output_pred, timestep=t, sample=latents,
                                                eta=eta, generator=generator).prev_sample

        return scale_shift_re(latents, scale=1/self.scale, shift=self.shift)