File size: 2,959 Bytes
9965bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from collections import defaultdict

import math
import numpy as np
import torch as th
from tqdm import tqdm

# Names exported by ``from <module> import *``. ``batch_mul`` is not listed,
# so a star-import will not bring it into scope.
__all__ = [
    "generic_sampler",
    "SimpleWork",
]


def batch_mul(a, b):  # pylint: disable=invalid-name
    """Multiply ``a`` and ``b`` elementwise along a shared leading axis.

    The product is expressed with ``einsum`` so that any trailing dimensions
    are carried through by the ellipsis; the leading axis of both operands
    must match.
    """
    equation = "a...,a...->a..."
    product = th.einsum(equation, a, b)
    return product

class SimpleWork:
    """Bundle of sampler callbacks built around a noise-prediction model.

    Supplies the pieces ``generic_sampler`` consumes: an initial sample
    (``generate_xT``), an x0 predictor (``x0_fn``), a noise source
    (``noise``) and a decreasing noise-level schedule (``rev_ts``).

    Args:
        shape: Per-sample shape, excluding the leading batch dimension.
        eps_scalar_t_fn: Callable ``(xt, scalar_t, y=None)`` returning the
            predicted noise for ``xt`` at noise level ``scalar_t``.
        device: Device that generated tensors are moved to. Defaults to
            ``"cuda"`` (the previous hard-coded behavior); pass ``"cpu"``
            to run without a GPU.
        sigma_max: Largest noise level: first entry of ``rev_ts`` and the
            scale of the ``generate_xT`` samples.
        sigma_min: Smallest noise level: last entry of ``rev_ts``.
    """

    def __init__(
        self,
        shape,
        eps_scalar_t_fn,
        device="cuda",
        sigma_max=80.0,
        sigma_min=1e-3,
    ):
        self.shape = shape
        self.eps_scalar_t_fn = eps_scalar_t_fn
        self.device = device
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min

    def generate_xT(self, n):  # pylint: disable=invalid-name
        """Return ``n`` Gaussian samples of shape ``(n, *shape)`` scaled by ``sigma_max``."""
        # Drawn with the default (CPU) generator and then moved, exactly as
        # before, so seeded runs keep producing the same draws.
        return self.sigma_max * th.randn((n, *self.shape)).to(self.device)

    def x0_fn(self, xt, scalar_t, y=None):
        """Estimate the clean sample from ``xt`` at noise level ``scalar_t``.

        Computes ``x0 = xt - scalar_t * eps`` with ``eps`` from
        ``eps_scalar_t_fn`` and clips it to ``[-1, 1]`` (presumably the data
        range of the model — confirm).

        Returns:
            ``(x0, loss_info, traj_info)``: ``loss_info`` is always empty and
            ``traj_info`` holds a CPU copy of ``x0`` under ``"x0"``.
        """
        cur_eps = self.eps_scalar_t_fn(xt, scalar_t, y=y)
        x0 = xt - scalar_t * cur_eps
        x0 = th.clip(x0, -1, 1)
        return x0, {}, {"x0": x0.cpu()}

    def noise(self, xt, scalar_t):
        """Return standard Gaussian noise shaped like ``xt``; ``scalar_t`` is unused."""
        del scalar_t
        return th.randn_like(xt)

    def rev_ts(self, n_step, ts_order):
        """Return ``n_step + 1`` noise levels decreasing from ``sigma_max`` to ``sigma_min``.

        The levels are spaced uniformly in ``t ** (1 / ts_order)`` and then
        raised back to the power ``ts_order``, which clusters them toward
        ``sigma_min`` for ``ts_order > 1``.
        """
        exponent = 1.0 / ts_order
        _rev_ts = th.pow(
            th.linspace(
                np.power(self.sigma_max, exponent),
                np.power(self.sigma_min, exponent),
                n_step + 1,
            ),
            ts_order,
        )
        return _rev_ts.to(self.device)

def generic_sampler(  # pylint: disable=too-many-locals,too-many-arguments
    x,
    rev_ts,
    noise_fn,
    x0_pred_fn,
    xt_lgv_fn=None,
    s_churn=0.0,
    before_step_fn=None,
    end_fn=None,  # currently unused; kept for interface compatibility
    is_tqdm=True,
    is_traj=True,
    s_t_min=0.05,
    s_t_max=50.0,
    s_noise=1.003,
):
    """Integrate from ``rev_ts[0]`` down to ``rev_ts[-1]`` with stochastic Heun steps.

    Each step optionally "churns" the state by adding noise (only when
    ``s_t_min < t < s_t_max``). It then takes an Euler step using the
    ``x0_pred_fn`` estimate and, unless the next level is 0, corrects it
    with a second evaluation at ``next_t`` (Heun / trapezoidal update).

    Args:
        x: Initial sample tensor, or a zero-argument callable returning one.
        rev_ts: Decreasing sequence of noise levels. Elements should be
            tensors, since ``th.sqrt`` is applied to them during churn.
        noise_fn: ``(x, t) -> noise`` used for the churn step.
        x0_pred_fn: ``(x, t) -> (x0, loss_info, traj_info)``. When
            ``is_traj`` is set, the two dicts are appended to the returned
            records.
        xt_lgv_fn: Not supported; passing a truthy value raises.
        s_churn: Churn amount; the per-step inflation factor is
            ``min(s_churn / len(rev_ts), sqrt(2) - 1)``.
        before_step_fn: Optional ``(x, t) -> x`` hook run before each
            prediction.
        end_fn: Unused.
        is_tqdm: Wrap the step loop in a ``tqdm`` progress bar.
        is_traj: Record and return the trajectory instead of the final
            sample.
        s_t_min, s_t_max: Churn is applied only when ``s_t_min < t < s_t_max``.
        s_noise: Multiplier on the churn noise.

    Returns:
        When ``is_traj`` is true, ``(traj, measure_loss)``, both dicts of
        lists. ``traj["xt"]`` holds the initial sample followed by the
        (CPU, detached) state after every step. Otherwise returns the final
        sample tensor.

    Raises:
        NotImplementedError: If ``xt_lgv_fn`` is given. This is a subclass
            of the ``RuntimeError`` that was previously raised after the
            first step.
    """
    del end_fn  # not used yet
    if xt_lgv_fn:
        # Fail before spending any model evaluations.
        raise NotImplementedError("Not implemented")

    measure_loss = defaultdict(list)
    traj = defaultdict(list)
    if callable(x):
        x = x()
    # Bug fix: this used to test ``traj`` (always empty here, hence falsy),
    # so the initial sample was never recorded.
    if is_traj:
        traj["xt"].append(x.cpu())

    # NOTE(review): divides by len(rev_ts) (= number of steps + 1); the EDM
    # reference divides by the number of steps — confirm which is intended.
    eta = min(s_churn / len(rev_ts), math.sqrt(2.0) - 1)

    loop = zip(rev_ts[:-1], rev_ts[1:])
    if is_tqdm:
        loop = tqdm(loop, total=max(len(rev_ts) - 1, 0))

    running_x = x
    for cur_t, next_t in loop:
        cur_x = running_x
        if s_t_min < cur_t < s_t_max:
            # Inflate the noise level to hat_t and add matching noise.
            hat_cur_t = cur_t + eta * cur_t
            cur_noise = noise_fn(cur_x, cur_t)
            cur_x = cur_x + s_noise * cur_noise * th.sqrt(hat_cur_t ** 2 - cur_t ** 2)
            cur_t = hat_cur_t

        if before_step_fn is not None:
            # TODO: may change the callback
            cur_x = before_step_fn(cur_x, cur_t)

        # Euler predictor.
        x0, loss_info, traj_info = x0_pred_fn(cur_x, cur_t)
        epsilon_1 = (cur_x - x0) / cur_t
        xt_next = x0 + next_t * epsilon_1

        # Heun corrector. It is skipped at t == 0, where (x - x0) / t would
        # divide by zero; there the Euler result is exact.
        if next_t > 0:
            x0, loss_info, traj_info = x0_pred_fn(xt_next, next_t)
            epsilon_2 = (xt_next - x0) / next_t
            xt_next = cur_x + (next_t - cur_t) * (epsilon_1 + epsilon_2) / 2

        running_x = xt_next

        if is_traj:
            for key, value in loss_info.items():
                measure_loss[key].append(value)
            for key, value in traj_info.items():
                traj[key].append(value)
            traj["xt"].append(running_x.to("cpu").detach())

    if is_traj:
        return traj, measure_loss
    return running_x