File size: 4,575 Bytes
fd5e0f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import numpy as np
import torchvision.transforms as tvtf
from tools.StableDiffusion.Hack_SD_stepwise import Hack_SDPipe_Stepwise

'''
Input: multi-view images with added noise.
Denoises the latents to x0, or denoises from step t1 to step t2.
'''

class HackSD_MCS():
    '''
    Multi-view consistent sampling (MCS) wrapper around a step-wise
    Stable Diffusion pipeline.

    Per-step workflow (driven by the caller):
        1. encode input images to ``self.latents`` and add noise
        2. predict the step noise and solve x0
        3. warp the multi-view RGB-D x0 into a rectified target image
        4. encode the rectified target and recover its implied noise
        5. rectify the predicted noise with it and step-denoise
    '''
    def __init__(self, device='cpu', use_lcm=True, denoise_steps=20,
                 sd_ckpt='tools/StableDiffusion/ckpt',
                 lcm_ckpt='latent-consistency/lcm-lora-sdv1-5') -> None:
        '''
        device        : torch device (string or torch.device) for the pipeline
        use_lcm       : enable the LCM-LoRA fast sampler
        denoise_steps : number of denoising steps taken out of the 50-step
                        schedule set up in _load_model
        sd_ckpt       : path/id of the Stable Diffusion checkpoint
        lcm_ckpt      : path/id of the LCM-LoRA weights

        NOTE(review): the original docstring said "ref_rgb should be -1~1
        tensor B*3*H*W", but no such parameter exists; images are supplied
        later via _encode_mv_init_images (expected in 0-1 range there).
        '''
        self.device = device
        self.target_type = np.float32
        self.use_lcm = use_lcm
        self.sd_ckpt = sd_ckpt
        self.lcm_ckpt = lcm_ckpt
        self._load_model()
        # define step to add noise and steps to denoise
        self.denoise_steps = denoise_steps
        self.timesteps = self.model.timesteps

    def _load_model(self):
        # Load the hacked step-wise SD pipeline, optionally attach LCM-LoRA,
        # and fix the inference schedule to 50 steps.
        self.model = Hack_SDPipe_Stepwise.from_pretrained(self.sd_ckpt)
        self.model._use_lcm(self.use_lcm, self.lcm_ckpt)
        self.model.re_init(num_inference_steps=50)
        try:
            self.model.enable_xformers_memory_efficient_attention()
        except Exception:
            # xformers is optional; fall back to default attention.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
            pass
        self.model = self.model.to(self.device)

    def to(self, device):
        # Move the underlying pipeline to `device` and remember the device.
        self.device = device
        self.model.to(device)

    @torch.no_grad()
    def _add_noise_to_latent(self, latents):
        '''
        Add scheduler noise to `latents` at the timestep denoising starts from.

        latents : B*4*(H/8)*(W/8) latent tensor
        Returns (noisy_latent, timestep, target) where `target` is the
        supervision signal matching the scheduler's prediction type.
        Raises ValueError for an unsupported prediction type (previously
        `target` was silently left undefined, causing a NameError).
        '''
        bsz = latents.shape[0]
        # in the Stable Diffusion, the iterations numbers is 1000 for adding
        # the noise and denoising; pick the schedule entry `denoise_steps`
        # from the end so exactly that many steps remain.
        timestep = self.timesteps[-self.denoise_steps]
        timestep = timestep.repeat(bsz).to(self.device)
        # target noise
        noise = torch.randn_like(latents)
        # add noise
        noisy_latent = self.model.scheduler.add_noise(latents, noise, timestep)
        # -------------------- noise for supervision -----------------
        prediction_type = self.model.scheduler.config.prediction_type
        if prediction_type == "epsilon":
            target = noise
        elif prediction_type == "v_prediction":
            target = self.model.scheduler.get_velocity(latents, noise, timestep)
        else:
            raise ValueError(
                f"Unsupported scheduler prediction_type: {prediction_type}")
        return noisy_latent, timestep, target

    @torch.no_grad()
    def _encode_mv_init_images(self, images):
        '''
        Encode multi-view images to latents and add starting noise.

        images : B*3*H*W tensor in 0-1 range (rescaled to -1~1 for the VAE)
        Side effect: sets self.latents to the noised latents.
        '''
        images = images * 2 - 1
        self.latents = self.model._encode(images)
        self.latents, _, _ = self._add_noise_to_latent(self.latents)

    @torch.no_grad()
    def _sd_forward(self, denoise_step, prompt_latent: torch.Tensor):
        '''
        Predict noise at denoise step `denoise_step` and solve x0.

        denoise_step  : 0-based index into the remaining denoise schedule
        prompt_latent : text-prompt embedding, repeated across the batch
        Returns (t, noise_pred, x0) with x0 rescaled to 0-1.
        '''
        # temp noise prediction
        t = self.timesteps[[-self.denoise_steps + denoise_step]].to(self.device)
        noise_pred = self.model._step_noise(
            self.latents, t, prompt_latent.repeat(len(self.latents), 1, 1))
        # solve image
        _, x0 = self.model._solve_x0(self.latents, noise_pred, t)
        x0 = (x0 + 1) / 2  # in 0-1
        return t, noise_pred, x0

    @torch.no_grad()
    def _denoise_to_x0(self, timestep_in_1000, prompt_latent: torch.Tensor):
        '''
        Predict noise at an explicit scheduler timestep (0-999) and solve x0.

        Returns (noise_pred, x0) with x0 rescaled to 0-1.
        '''
        # temp noise prediction
        noise_pred = self.model._step_noise(
            self.latents, timestep_in_1000,
            prompt_latent.repeat(len(self.latents), 1, 1))
        # solve image
        _, x0 = self.model._solve_x0(self.latents, noise_pred, timestep_in_1000)
        x0 = (x0 + 1) / 2  # in 0-1
        return noise_pred, x0

    @torch.no_grad()
    def _step_denoise(self, t, pred_noise, rect_x0, rect_w=0.7):
        '''
        Rectify the predicted noise toward a warped target and step-denoise.

        t          : current timestep tensor
        pred_noise : B*4*(H/8)*(W/8) predicted noise
        rect_x0    : B*3*H*W rectified target image in 0-1 range
        rect_w     : blend weight of the rectified noise (0 = ignore target)
        Side effect: advances self.latents by one denoise step.
        '''
        # encode rect_x0 to latent (VAE expects -1~1)
        rect_x0 = rect_x0 * 2 - 1
        rect_latent = self.model._encode(rect_x0)
        # noise that would explain rect_latent given the current latents
        rect_noise = self.model._solve_noise_given_x0_latent(
            self.latents, rect_latent, t)
        # noise rectification: match the rectified noise's per-sample std to
        # the predicted noise's std before blending
        rect_noise = rect_noise / rect_noise.std(dim=list(range(1, rect_noise.ndim)), keepdim=True) \
                                * pred_noise.std(dim=list(range(1, pred_noise.ndim)), keepdim=True)
        pred_noise = pred_noise * (1. - rect_w) + rect_noise * rect_w
        # step forward
        self.latents = self.model._step_denoise(self.latents, pred_noise, t)

    @torch.no_grad()
    def _decode_mv_imgs(self):
        # Decode the current latents to images, rescaled from -1~1 to 0-1.
        imgs = self.model._decode(self.latents)
        imgs = (imgs + 1) / 2
        return imgs