File size: 5,434 Bytes
9965bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
import torch as th
from einops import rearrange

from .generic_sampler import SimpleWork
from .w_img import split_wimg, avg_merge_wimg

class CondIndSimple(SimpleWork):
    """Sampler for one wide image assembled from ``num_img`` overlapping tiles.

    The wrapped noise predictor is evaluated on every tile and a second time
    on each tile's trailing ``overlap_size`` columns; the second prediction is
    subtracted from the first on those columns before the tiles are merged
    back into a wide prediction.
    """

    def __init__(self, shape, eps_scalar_t_fn, num_img, overlap_size=32):
        """Derive the wide-image shape and install the combined predictor.

        Args:
            shape: ``(c, h, w)`` of a single tile.
            eps_scalar_t_fn: callable invoked as ``fn(x, t, y=y)``.
            num_img: number of tiles making up the wide image.
            overlap_size: columns shared by neighbouring tiles; must be
                exactly ``w // 2``.
        """
        channels, height, tile_w = shape
        assert overlap_size == tile_w // 2
        self.overlap_size = overlap_size
        self.num_img = num_img
        # Each of the (num_img - 1) seams shares overlap_size columns, so
        # those columns are only counted once in the total width.
        wide_w = tile_w * num_img - self.overlap_size * (num_img - 1)
        combined_fn = self.get_eps_t_fn(eps_scalar_t_fn)
        super().__init__((channels, height, wide_w), combined_fn)

    def loss(self, x):
        """Return the squared seam mismatch between consecutive tiles.

        For every adjacent pair along dim 0, the trailing ``overlap_size``
        columns of the first tile are compared with the leading
        ``overlap_size`` columns of the second. The result holds one value
        per pair (``len(x) - 1`` entries).
        """
        width = self.overlap_size
        trailing = x[:-1][:, :, :, -width:]
        leading = x[1:][:, :, :, :width]
        mismatch = th.abs(trailing - leading)
        return th.sum(mismatch ** 2, dim=(1, 2, 3))

    def get_eps_t_fn(self, eps_scalar_t_fn):
        """Wrap a per-tile predictor into one that works on the wide image.

        The returned closure reads ``self.num_img`` and ``self.overlap_size``
        at call time.
        """

        def eps_t_fn(long_x, scalar_t, y=None):
            tiles_per_img = self.num_img
            overlap = self.overlap_size

            def to_tile_major(batch_major):
                # ((b n), c, h, w) -> (n, b, c, h, w)
                return rearrange(
                    batch_major, "(b n) c h w -> n b c h w", n=tiles_per_img
                )

            tiles = split_wimg(long_x, tiles_per_img, rtn_overlap=False)

            # One label / timestep entry per tile rather than per wide image.
            if y is not None:
                y = y.repeat_interleave(tiles_per_img)
            scalar_t = scalar_t.repeat_interleave(tiles_per_img)

            # Prediction on the complete tiles.
            tile_eps = to_tile_major(eps_scalar_t_fn(tiles, scalar_t, y=y))

            # Prediction on only the trailing strip of every tile
            # (width is overlap_size, i.e. w // 2).
            strip_eps = to_tile_major(
                eps_scalar_t_fn(tiles[:, :, :, -overlap:], scalar_t, y=y)
            )

            # The last tile has no right-hand neighbour: nothing to remove.
            strip_eps[-1] = 0

            tile_eps[..., -overlap:] = tile_eps[..., -overlap:] - strip_eps

            flat_eps = rearrange(tile_eps, "n b c h w -> (b n) c h w")
            return avg_merge_wimg(
                flat_eps, overlap, n=tiles_per_img, is_avg=False
            )

        return eps_t_fn



class CondIndSR(SimpleWork):
    """Super-resolution sampler for a wide image built from overlapping tiles.

    A low-resolution wide image conditions the noise predictor. The predictor
    is evaluated on every tile (``square_fn``) and again on each tile's
    trailing ``overlap_size`` columns (``half_fn``); the second prediction is
    subtracted from the first on those columns before the tiles are merged.
    """

    def __init__(self, shape, eps_scalar_t_fn, num_img, low_res, overlap_size=128):
        """Derive the wide-image shape and build the tile-level predictors.

        Args:
            shape: ``(c, h, w)`` of a single high-resolution tile.
            eps_scalar_t_fn: callable invoked as ``fn(x, t_vec, low_res)``.
            num_img: number of tiles making up the wide image.
            low_res: low-resolution conditioning tensor; its last dim must be
                ``(low_res.shape[-2] // 2) * (num_img + 1)``.
            overlap_size: columns shared by neighbouring tiles; must be
                exactly ``w // 2``.

        Raises:
            AssertionError: if either size constraint above is violated.
        """
        c, h, w = shape
        assert overlap_size == w // 2, (
            f"overlap_size must equal w // 2 ({w // 2}), got {overlap_size}"
        )
        self.overlap_size = overlap_size
        # Overlap in the low-res image is half of its height.
        # NOTE(review): this presumes square low-res tiles - confirm.
        self.low_overlap_size = low_res.shape[-2] // 2
        self.num_img = num_img
        final_img_w = w * num_img - self.overlap_size * (num_img - 1)
        expected_low_w = self.low_overlap_size * (num_img + 1)
        assert low_res.shape[-1] == expected_low_w, (
            f"low_res width must be {expected_low_w}, got {low_res.shape[-1]}"
        )

        self.square_fn = self.get_square_sr_fn(eps_scalar_t_fn, low_res)
        self.half_fn = self.get_half_sr_fn(eps_scalar_t_fn, low_res)

        super().__init__((c, h, final_img_w), self.get_eps_t_fn())

    @staticmethod
    def _grad_context(enable_grad):
        """Return the autograd context manager matching ``enable_grad``."""
        return th.enable_grad() if enable_grad else th.no_grad()

    @staticmethod
    def _time_vector(_x, _t):
        """Broadcast ``_t`` to one entry per row of ``_x``.

        The vector is allocated on ``_x``'s device rather than forced onto
        CUDA, so the sampler also runs on CPU or on a non-default GPU.
        """
        return th.ones(_x.shape[0], device=_x.device) * _t

    def get_square_sr_fn(self, eps_fn, low_res):
        """Build the predictor that is evaluated on complete tiles.

        ``low_res`` is split once here with ``split_wimg`` and captured by the
        returned closure. The closure returns the prediction rearranged to
        ``(n, b, c, h, w)``.
        """
        low_res = split_wimg(low_res, self.num_img, False)

        def _fn(_x, _t, enable_grad):
            with self._grad_context(enable_grad):
                vec_t = self._time_vector(_x, _t)
                rtn = eps_fn(_x, vec_t, low_res)
            rtn = rearrange(
                rtn,
                "(b n) c h w -> n b c h w", n=self.num_img
            )
            return rtn

        return _fn

    def get_half_sr_fn(self, eps_fn, low_res):
        """Build the predictor that is evaluated on trailing overlap strips.

        The closure feeds ``eps_fn`` the last ``overlap_size`` columns of each
        tile together with the last ``low_overlap_size`` columns of each
        low-res tile, and returns the prediction rearranged to
        ``(n, b, c, h, w)`` with the entry for the last tile zeroed.
        """
        low_res = split_wimg(low_res, self.num_img, False)

        def _fn(_x, _t, enable_grad):
            with self._grad_context(enable_grad):
                vec_t = self._time_vector(_x, _t)
                half_eps = eps_fn(
                    _x[:, :, :, -self.overlap_size:],
                    vec_t,
                    low_res[:, :, :, -self.low_overlap_size:],
                )
            half_eps = rearrange(
                half_eps,
                "(b n) c h w -> n b c h w", n=self.num_img
            )

            # The last tile has no right-hand neighbour: nothing to remove.
            half_eps[-1] = 0
            return half_eps

        return _fn

    def get_eps_t_fn(self):
        """Return the wide-image predictor combining ``square_fn`` and ``half_fn``."""

        def eps_t_fn(in_x, scalar_t, enable_grad=False):
            xs = split_wimg(in_x, self.num_img, rtn_overlap=False)

            # Prediction on complete tiles.
            full_eps = self.square_fn(xs, scalar_t, enable_grad)
            # Prediction on the trailing overlap strips only.
            half_eps = self.half_fn(xs, scalar_t, enable_grad)

            # Remove the strip prediction from the trailing columns of each tile.
            full_eps[:, :, :, :, -self.overlap_size:] = (
                full_eps[:, :, :, :, -self.overlap_size:] - half_eps
            )
            whole_eps = rearrange(
                full_eps,
                "n b c h w -> (b n) c h w"
            )
            out_eps = avg_merge_wimg(
                whole_eps, self.overlap_size, n=self.num_img, is_avg=False
            )
            return out_eps

        return eps_t_fn




# class CondIndLong(SimpleWork):
#     def __init__(self, shape, eps_scalar_t_fn, overlap_size=32):
#         super().__init__(shape, eps_scalar_t_fn)
#         self.overlap_size = overlap_size

#     def loss(self, x):
#         x1, x2 = x[:-1], x[1:]
#         return th.sum(
#             (th.abs(x1[:, :, :, -self.overlap_size :] - x2[:, :, :, : self.overlap_size])) ** 2,
#             dim=(1, 2, 3),
#         )

#     def generate_xT(self, n):
#         white_noise = th.randn((n , *self.shape)).cuda()
#         return self.noise(white_noise, None) * 80.0

#     def noise(self, xt, scalar_t):
#         del scalar_t
#         noise = th.randn_like(xt)
#         b, _, _, w = xt.shape
#         final_img_w = w * b - self.overlap_size * (b - 1)
#         noise = rearrange(noise, "(t n) c h w -> t c h (n w)", t=1)[:, :, :, :final_img_w]
#         noise = split_wimg(noise, b, rtn_overlap=False)
#         return noise

#     def merge(self, xs):
#         return avg_merge_wimg(xs, self.overlap_size)