File size: 3,751 Bytes
82ea528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
from tqdm import tqdm, trange

from ..utils.sampling_utils import generate_eta_values


@torch.no_grad()
def mochi_sample(model, z, sigmas, callback=None):
    """Integrate latent ``z`` over the ``sigmas`` schedule with Euler steps.

    At every step the model is called as ``model(z=..., sigma=...)`` with a
    sigma vector holding the current sigma once per batch element; its
    prediction is scaled by the gap to the next sigma and added to the latent.

    Args:
        model: Callable accepting keyword arguments ``z`` and ``sigma``.
        z: Starting latent; the first entry of its initial shape is used as
            the length of the per-step sigma vector.
        sigmas: Schedule; each consecutive pair ``(sigmas[k], sigmas[k + 1])``
            defines one step, so ``len(sigmas) - 1`` steps are taken.
        callback: Optional ``callback(step_index, latent)`` called after each
            step with the updated latent.

    Returns:
        The latent after the last step.
    """
    num_steps = len(sigmas) - 1
    # Captured once up front; the batch size is read from the initial shape.
    initial_dims = z.shape

    progress = tqdm(range(0, num_steps), desc="Processing Samples", total=num_steps)
    for step in progress:
        sigma_now = sigmas[step]
        sigma_next = sigmas[step + 1]
        sigma_batch = torch.full([initial_dims[0]], sigma_now, device=z.device)

        prediction = model(z=z, sigma=sigma_batch)
        z = z + prediction * (sigma_now - sigma_next)

        if callback is not None:
            callback(step, z)

    return z


def get_rf_forward_sample_fn(gamma, seed, correction=True):
    """Build a sampler implementing the controlled forward ODE (Algorithm 1).

    The returned sampler integrates a latent along ``sigmas``. It optionally
    blends the model's (negated) output with a conditional vector field that
    pulls the state toward a fixed Gaussian noise target ``y1`` drawn from a
    seeded generator.

    Args:
        gamma: Weight of the conditional vector field relative to the model's
            unconditional field. It is used only when ``correction`` is True.
        seed: Seed for the CPU generator that draws the noise target ``y1``.
        correction: When False, the unconditional field is integrated on its
            own and ``gamma``/``y1`` have no effect on the result.

    Returns:
        ``sample_forward(model, y0, sigmas, extra_args=None, callback=None,
        disable=None)``, which returns the latent after the last step.
    """
    # Controlled Forward ODE (Algorithm 1)
    # NOTE(review): the generator is created once per factory call, so each
    # invocation of the returned sampler continues the same RNG stream. Only
    # the first call draws exactly the noise determined by `seed` alone.
    generator = torch.Generator()
    generator.manual_seed(seed)

    @torch.no_grad()
    def sample_forward(model, y0, sigmas, extra_args=None, callback=None, disable=None):
        """Integrate ``y0`` along ``sigmas`` and return the final state.

        ``model`` is called as ``model(Y, sigma_vector, **extra_args)``.
        ``callback``, if given, receives a dict with the current state after
        each step. ``disable`` is forwarded to ``trange`` to silence the
        progress bar.
        """
        # Avoid a shared mutable default; callers passing a dict are unaffected.
        extra_args = {} if extra_args is None else extra_args

        Y = y0.clone()
        # Draw on the CPU generator so the sample does not depend on y0's
        # device. Then match y0's device AND dtype so the update below does
        # not silently promote Y (e.g. half -> float32).
        y1 = torch.randn(Y.shape, generator=generator).to(device=y0.device, dtype=y0.dtype)
        N = len(sigmas) - 1
        s_in = y0.new_ones([y0.shape[0]])
        for i in trange(N, disable=disable):
            t_i = sigmas[i]

            # 6. Unconditional vector field: the negated model output at (Y, t_i).
            unconditional_vector_field = -model(Y, sigmas[i] * s_in, **extra_args)

            if correction:
                # 7. Conditional vector field u(Y | y1) = (y1 - Y) / (1 - t_i).
                conditional_vector_field = (y1 - Y) / (1 - t_i)

                # 8. Controlled field: u + gamma * (u(. | y1) - u).
                controlled_vector_field = unconditional_vector_field + gamma * (
                    conditional_vector_field - unconditional_vector_field
                )
            else:
                controlled_vector_field = unconditional_vector_field

            # 9. Euler step: Y_{i+1} = Y_i + u_hat * (sigma_{i+1} - sigma_i).
            Y = Y + controlled_vector_field * (sigmas[i + 1] - sigmas[i])

            if callback is not None:
                callback({'x': Y, 'denoised': Y, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i]})

        return Y

    return sample_forward


def get_rf_reverse_sample_fn(latent_image, eta, start_time, end_time, eta_trend):
    """Build a sampler implementing the controlled reverse ODE (Algorithm 2).

    The returned sampler integrates a latent along ``sigmas``. At each step it
    blends the model's output with a conditional vector field pulling the
    state toward ``latent_image``, weighted by a per-step eta schedule.

    Args:
        latent_image: Reference latent ``y0`` that the trajectory is steered
            toward. It is copied and moved to the input latent's device.
        eta, start_time, end_time, eta_trend: Forwarded to
            ``generate_eta_values`` to build the per-step guidance weights.

    Returns:
        ``sample_reverse(model, y1, sigmas, extra_args=None, callback=None,
        disable=None)``, which returns the latent after the last step.
    """
    # Controlled Reverse ODE (Algorithm 2)
    @torch.no_grad()
    def sample_reverse(model, y1, sigmas, extra_args=None, callback=None, disable=None):
        """Integrate ``y1`` along ``sigmas`` and return the final state.

        ``model`` is called as ``model(X, sigma_vector, **extra_args)``.
        ``callback``, if given, receives a dict with the current state after
        each step. ``disable`` is forwarded to ``trange``.
        """
        # Avoid a shared mutable default; callers passing a dict are unaffected.
        extra_args = {} if extra_args is None else extra_args

        X = y1.clone()
        N = len(sigmas) - 1
        y0 = latent_image.clone().to(y1.device)
        # Indexed once per step below, so presumably yields at least N values
        # — TODO confirm against generate_eta_values.
        eta_values = generate_eta_values(N, start_time, end_time, eta, eta_trend)
        s_in = y0.new_ones([y0.shape[0]])
        for i in trange(N, disable=disable):
            t_i = 1 - sigmas[i]

            # 5. Unconditional vector field: the model output at (X, sigma_i).
            unconditional_vector_field = model(X, sigmas[i] * s_in, **extra_args)

            # 6. Conditional vector field u(X | y0) = (y0 - X) / (1 - t_i).
            conditional_vector_field = (y0 - X) / (1 - t_i)

            # 7. Controlled field: u + eta_i * (u(. | y0) - u).
            controlled_vector_field = unconditional_vector_field + eta_values[i] * (
                conditional_vector_field - unconditional_vector_field
            )

            # 8. Euler step: X_{i+1} = X_i + u_hat * (sigma_i - sigma_{i+1}).
            X = X + controlled_vector_field * (sigmas[i] - sigmas[i + 1])

            if callback is not None:
                callback({'x': X, 'denoised': X, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i]})

        return X

    return sample_reverse