File size: 2,095 Bytes
9965bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
import torch as th
from einops import rearrange

from .generic_sampler import SimpleWork
from .w_img import split_wimg, avg_merge_wimg

class AvgCircle(SimpleWork):
    """Horizontally wrapping (circular) canvas built from overlapping windows.

    The long canvas of width ``num_img * (w - overlap_size)`` is covered by
    ``num_img`` windows of the base width ``w``. Adjacent windows share
    ``overlap_size`` columns, and the last window wraps around onto the first
    columns of the canvas. The wrapped eps function evaluates the per-window
    function on all windows and averages the predictions where they overlap.
    """

    def __init__(self, shape, eps_scalar_t_fn, num_img, overlap_size=32):
        """
        Args:
            shape: ``(c, h, w)`` of a single base window.
            eps_scalar_t_fn: callable ``(xs, scalar_t, enable_grad) -> eps``
                applied to the windows produced by ``split_wimg``.
            num_img: number of windows tiled around the circle.
            overlap_size: number of columns shared by adjacent windows,
                including the wrap-around seam between the last and first one.

        Raises:
            ValueError: if ``overlap_size`` is negative or not smaller than
                ``w`` (the canvas width would be non-positive).
        """
        c, h, w = shape
        if not 0 <= overlap_size < w:
            raise ValueError(
                f"overlap_size must satisfy 0 <= overlap_size < w ({w}), "
                f"got {overlap_size}"
            )
        self.base_img_w = w
        self.overlap_size = overlap_size
        self.num_img = num_img
        # Each window contributes (w - overlap) unique columns; because the
        # canvas wraps, there are num_img seams rather than num_img - 1.
        final_img_w = num_img * (w - overlap_size)
        super().__init__((c, h, final_img_w), self.get_eps_t_fn(eps_scalar_t_fn))

    def get_eps_t_fn(self, eps_scalar_t_fn):
        """Wrap a per-window eps function into one over the circular canvas."""

        def eps_t_fn(long_x, scalar_t, enable_grad=False):
            """Evaluate eps on ``long_x`` by tiling it with wrapping windows.

            Returns a tensor with the same shape as ``long_x``.
            """
            o = self.overlap_size

            # Rotate the canvas by a random offset so window borders fall on
            # different columns each call (presumably to avoid fixed seam
            # positions). The rotation is undone before returning.
            shift = int(np.random.randint(self.base_img_w))
            long_x = th.roll(long_x, shifts=-shift, dims=-1)

            # Append the first `o` columns so the last window wraps onto the
            # start of the canvas.
            x = th.cat([long_x, long_x[..., :o]], dim=-1)
            xs, _overlap = split_wimg(x, self.num_img, rtn_overlap=True)
            assert _overlap == o
            full_eps = eps_scalar_t_fn(xs, scalar_t, enable_grad)  # ((b,n), c, h, w)

            eps = avg_merge_wimg(full_eps, o, n=self.num_img)
            if o > 0:
                # The leading and trailing `o` columns are the same canvas
                # columns; average them and drop the duplicated tail. An
                # explicit end index is used because `-0:` would select the
                # whole tensor rather than nothing.
                width = eps.shape[-1]
                seam = (eps[..., :o] + eps[..., width - o:]) / 2.0
                eps = th.cat([seam, eps[..., o:width - o]], dim=-1)
            assert eps.shape == long_x.shape

            # Undo the random rotation applied above.
            return th.roll(eps, shifts=shift, dims=-1)

        return eps_t_fn

    def x0_fn(self, xt, scalar_t, enable_grad=False):
        """Compute ``x0 = xt - scalar_t * eps`` on the full canvas.

        ``self.eps_scalar_t_fn`` is presumably the wrapped function passed to
        ``SimpleWork.__init__`` above (set by the base class; TODO confirm).

        Returns:
            ``(x0, {}, {"x0": x0.cpu()})``: the prediction, an empty dict,
            and a dict holding a CPU copy of the prediction.
        """
        cur_eps = self.eps_scalar_t_fn(xt, scalar_t, enable_grad)
        x0 = xt - scalar_t * cur_eps
        return x0, {}, {"x0": x0.cpu()}