File size: 1,494 Bytes
9965bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch as th
from einops import rearrange

from .generic_sampler import SimpleWork
from .w_img import split_wimg, avg_merge_wimg

class AvgLong(SimpleWork):
    """Sampler work item that treats one wide canvas as ``num_img`` overlapping tiles.

    The canvas width is ``w * num_img - overlap_size * (num_img - 1)``. This is
    consistent with ``num_img`` tiles of width ``w`` in which neighbouring
    tiles share ``overlap_size`` columns.

    Each call of the wrapped noise predictor does three things:
      * splits the canvas with ``split_wimg``;
      * evaluates the tile-level ``eps_scalar_t_fn`` on all tiles as one batch;
      * merges the per-tile predictions back with ``avg_merge_wimg``.
    """

    def __init__(self, shape, eps_scalar_t_fn, num_img, overlap_size=32):
        """
        Args:
            shape: ``(c, h, w)`` of a single tile.
            eps_scalar_t_fn: tile-level predictor, called as
                ``eps_scalar_t_fn(tiles, t, y=y)``.
            num_img: number of tiles laid out along the width.
            overlap_size: columns shared by adjacent tiles; must equal ``w // 2``.

        Raises:
            ValueError: if ``overlap_size != w // 2``.
        """
        c, h, w = shape
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        # NOTE(review): presumably split_wimg / avg_merge_wimg assume a half-tile
        # overlap, hence this constraint — confirm against w_img.
        if overlap_size != w // 2:
            raise ValueError(
                f"overlap_size must equal w // 2 (= {w // 2}), got {overlap_size}"
            )
        self.overlap_size = overlap_size
        self.num_img = num_img
        final_img_w = w * num_img - self.overlap_size * (num_img - 1)
        super().__init__((c, h, final_img_w), self.get_eps_t_fn(eps_scalar_t_fn))

    def loss(self, x):
        """Return the squared mismatch between the shared columns of neighbouring items.

        ``x[i]`` is compared with ``x[i + 1]``: the last ``overlap_size`` columns
        of the former against the first ``overlap_size`` columns of the latter.
        The squared absolute difference is summed over dims 1-3.

        ``x`` is indexed as (at least) 4-D. Presumably it is an ``(N, c, h, w)``
        stack of tiles, giving a result of shape ``(N - 1,)`` — confirm with callers.
        """
        left, right = x[:-1], x[1:]
        diff = left[:, :, :, -self.overlap_size :] - right[:, :, :, : self.overlap_size]
        return th.sum(th.abs(diff) ** 2, dim=(1, 2, 3))

    def get_eps_t_fn(self, eps_scalar_t_fn):
        """Wrap a tile-level predictor into one that accepts the whole long canvas."""

        def eps_t_fn(long_x, scalar_t, y=None):
            """Predict noise for ``long_x`` by running ``eps_scalar_t_fn`` per tile."""
            xs = split_wimg(long_x, self.num_img, rtn_overlap=False)
            # Repeat each per-sample condition num_img times, i.e. [s0]*n + [s1]*n + ...
            # This matches the batch-major ``(b n)`` tile order noted below.
            # NOTE(review): repeat_interleave without ``dim`` flattens its input,
            # so this assumes y and scalar_t are 1-D ``(b,)`` — confirm for
            # vector-valued conditioning.
            if y is not None:
                y = y.repeat_interleave(self.num_img)
            scalar_t = scalar_t.repeat_interleave(self.num_img)
            eps = eps_scalar_t_fn(xs, scalar_t, y=y)  # ((b n), c, h, w)
            # ``eps`` is already in the ``(b n) c h w`` layout that is handed to
            # avg_merge_wimg. Permuting to ``n b`` and straight back is a no-op,
            # so no reordering is done here.
            # NOTE(review): is_avg=False is passed through unchanged; its exact
            # merge semantics live in w_img — confirm.
            return avg_merge_wimg(eps, self.overlap_size, n=self.num_img, is_avg=False)

        return eps_t_fn