File size: 3,253 Bytes
0a97d6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import yaml
import torch
from diffusers import DDIMScheduler
from .model.p2e_cross import P2E_Cross
from .utils import scale_shift, scale_shift_re, rescale_noise_cfg


class DreamVG:
    """Inference wrapper around a text-conditioned ``P2E_Cross`` diffusion model.

    The constructor reads a YAML config, builds the model, restores its
    weights from a checkpoint and prepares a ``DDIMScheduler``.
    :meth:`inference` then samples a single latent vector of width
    ``config['model']['unet']['in_channels']`` conditioned on a text input.

    NOTE(review): the attribute name ``spk_shape`` suggests the generated
    vector is a speaker/voice embedding -- confirm against the model code.
    """

    def __init__(self,
                 config_path='configs/plugin_cross.yaml',
                 ckpt_path='../ckpts/dreamvc_plugin.pt',
                 device='cpu'):
        """Build the model and scheduler from a config and a checkpoint.

        Args:
            config_path: Path to a YAML file providing a ``model`` section
                (passed to ``P2E_Cross``; ``model.unet.in_channels`` is read
                here) and a ``scheduler`` section with ``num_train_steps``,
                ``beta_start``, ``beta_end``, ``scale`` and ``shift``.
            ckpt_path: Path to a ``torch.save``-d dict whose ``'model'`` entry
                is the state dict of the network.
            device: Device on which the model, the random generator and the
                initial noise live.
        """
        # Explicit encoding so parsing does not depend on the host locale.
        with open(config_path, 'r', encoding='utf-8') as fp:
            config = yaml.safe_load(fp)

        scheduler_cfg = config['scheduler']

        self.device = device
        self.model = P2E_Cross(config['model']).to(device)

        # map_location remaps the stored tensors onto the requested device.
        # Without it a checkpoint saved from GPU tensors cannot be restored on
        # a CPU-only host, even though 'cpu' is this class's default device.
        # SECURITY: torch.load unpickles the file -- only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(ckpt_path, map_location=device)
        self.model.load_state_dict(checkpoint['model'])
        self.model.eval()

        self.noise_scheduler = DDIMScheduler(
            num_train_timesteps=scheduler_cfg['num_train_steps'],
            beta_start=scheduler_cfg['beta_start'],
            beta_end=scheduler_cfg['beta_end'],
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            clip_sample=False,
            prediction_type='v_prediction',
        )
        # Constants handed to scale_shift_re when post-processing the sample.
        self.scale = scheduler_cfg['scale']
        self.shift = scheduler_cfg['shift']
        # Width of the generated vector (second dimension of the latent).
        self.spk_shape = config['model']['unet']['in_channels']

    @torch.no_grad()
    def inference(self, text,
                  guidance_scale=5, guidance_rescale=0.7,
                  ddim_steps=50, eta=1, random_seed=2023,
                  ):
        """Sample one latent vector conditioned on ``text`` with DDIM.

        Args:
            text: A pair ``(text, text_mask)``; both elements are forwarded
                unchanged to the model as conditioning. They are not moved to
                ``self.device`` here. Only one latent (batch size 1) is
                generated, so presumably the conditioning also has batch
                size 1 -- TODO confirm against ``P2E_Cross``.
            guidance_scale: Classifier-free guidance weight. A falsy value
                (``0`` / ``None``) disables guidance and runs a single
                conditional pass per step.
            guidance_rescale: When greater than 0 (and guidance is enabled),
                the guided prediction is post-processed by
                ``rescale_noise_cfg`` with this factor.
            ddim_steps: Number of inference steps given to the scheduler.
            eta: Forwarded to ``DDIMScheduler.step``.
            random_seed: Seed for the sampling generator. ``None`` draws a
                fresh non-deterministic seed instead.

        Returns:
            The final latent after ``scale_shift_re(latents, 1 / self.scale,
            self.shift)``. The initial noise has shape
            ``(1, self.spk_shape)``; the output is presumably the same
            shape -- confirm against the scheduler and helper.
        """
        text, text_mask = text
        self.model.eval()

        gen_shape = (1, self.spk_shape)

        generator = torch.Generator(device=self.device)
        if random_seed is not None:
            generator.manual_seed(random_seed)
        else:
            generator.seed()

        self.noise_scheduler.set_timesteps(ddim_steps)

        # Start from pure Gaussian noise.
        latents = torch.randn(gen_shape, generator=generator, device=self.device)

        for t in self.noise_scheduler.timesteps:
            latents = self.noise_scheduler.scale_model_input(latents, t)

            if guidance_scale:
                output_text = self.model(latents, t, text, text_mask, train_cfg=False)
                # NOTE(review): train_cfg=True with cfg_prob=1.0 presumably
                # makes the model drop the conditioning on every call, giving
                # the unconditional branch -- confirm in P2E_Cross.
                output_uncond = self.model(latents, t, text, text_mask,
                                           train_cfg=True, cfg_prob=1.0)

                # Classifier-free guidance: extrapolate from the unconditional
                # prediction towards the conditional one.
                output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
                if guidance_rescale > 0.0:
                    output_pred = rescale_noise_cfg(output_pred, output_text,
                                                    guidance_rescale=guidance_rescale)
            else:
                output_pred = self.model(latents, t, text, text_mask, train_cfg=False)

            latents = self.noise_scheduler.step(model_output=output_pred, timestep=t,
                                                sample=latents, eta=eta,
                                                generator=generator).prev_sample

        # Map the sample back out of the normalised diffusion space.
        # NOTE(review): the helper receives the reciprocal of the configured
        # scale; its exact formula is defined in .utils -- confirm there.
        pred = scale_shift_re(latents, 1 / self.scale, self.shift)
        return pred