File size: 2,336 Bytes
9965bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch as th
from einops import rearrange

# Public API: the batched split / merge pair. The ``*_legacy`` helpers defined
# below are not exported.
__all__ = [
    "split_wimg",
    "avg_merge_wimg",
]

def split_wimg(wimg, n_img, rtn_overlap=True, *, base_len=128):
    """Split a wide image into ``n_img`` equally sized, evenly overlapping patches.

    The overlap between neighbouring patches is solved from the input width,
    which must satisfy ``W == n_img * base_len - (n_img - 1) * overlap_size``.

    Args:
        wimg: Tensor of shape ``(B, C, H, W)``, or ``(C, H, W)`` which is
            treated as a batch of one.
        n_img: Number of patches to cut along the width; must be >= 1.
        rtn_overlap: If True, also return the solved overlap size.
        base_len: Width of every patch. Defaults to 128, the value that was
            previously hard-coded here.

    Returns:
        Tensor of shape ``(B * n_img, C, H, base_len)`` in which the patches
        of each batch item are contiguous and ordered left to right, or the
        tuple ``(patches, overlap_size)`` when ``rtn_overlap`` is True.

    Raises:
        ValueError: if ``n_img < 1``, if ``W`` cannot be tiled exactly by
            ``n_img`` patches of width ``base_len`` with a constant overlap,
            or if that overlap leaves a non-positive step between patches.
    """
    if wimg.ndim == 3:
        wimg = wimg[None]
    _, chans, h, w = wimg.shape
    if n_img < 1:
        raise ValueError(f"n_img must be >= 1, got {n_img}")

    # A single patch has no neighbour to overlap with (and the general formula
    # would divide by zero); it simply has to span the whole width.
    overlap_size = 0 if n_img == 1 else (n_img * base_len - w) // (n_img - 1)
    if n_img * base_len - overlap_size * (n_img - 1) != w:
        raise ValueError(
            f"width {w} cannot be split into {n_img} patches of width "
            f"{base_len} with a constant overlap"
        )
    stride = base_len - overlap_size
    if stride <= 0:
        raise ValueError(
            f"overlap {overlap_size} leaves no forward step between patches "
            f"of width {base_len}"
        )

    # The kernel spans the full height, so there is exactly one position along
    # H; only the width is strided. Output: (B, C * H * base_len, n_img).
    cols = th.nn.functional.unfold(
        wimg, kernel_size=(h, base_len), stride=(1, stride)
    )
    # (B, C*H*base_len, n) -> (B*n, C, H, base_len); equivalent to the einops
    # pattern "b (c h w) n -> (b n) c h w".
    patches = cols.transpose(1, 2).reshape(-1, chans, h, base_len)

    if rtn_overlap:
        return patches, overlap_size
    return patches

def avg_merge_wimg(imgs, overlap_size, n=None, is_avg=True):
    """Stitch horizontally overlapping patches back into wide images.

    Consecutive groups of ``n`` patches are laid out left to right with a step
    of ``w - overlap_size`` and summed where they overlap. With ``is_avg``, the
    sum is divided by the number of patches covering each pixel.

    Args:
        imgs: Tensor of shape ``(B * n, C, H, w)``; the patches of one output
            image are contiguous and ordered left to right.
        overlap_size: Number of columns shared by neighbouring patches.
        n: Patches per output image. Defaults to all patches (so ``B == 1``).
        is_avg: If True, average overlapping pixels; otherwise return the raw
            sum.

    Returns:
        Tensor of shape ``(B, C, H, n * w - (n - 1) * overlap_size)``.

    Raises:
        ValueError: if ``n < 1`` or the number of patches is not a multiple
            of ``n``.
    """
    total, chans, h, w = imgs.shape
    if n is None:
        n = total
    if n < 1 or total % n:
        raise ValueError(
            f"cannot group {total} patches into images of {n} patches each"
        )

    out_size = (h, n * w - (n - 1) * overlap_size)
    # The kernel spans the full height, so only the width is strided.
    fold_kwargs = dict(kernel_size=(h, w), stride=(1, w - overlap_size))

    # (B*n, C, H, w) -> (B, C*H*w, n): the column layout F.fold consumes;
    # equivalent to the einops pattern "(b n) c h w -> b (c h w) n".
    cols = imgs.reshape(total // n, n, chans * h * w).transpose(1, 2)
    merged = th.nn.functional.fold(cols, out_size, **fold_kwargs)
    if not is_avg:
        return merged

    # Per-pixel coverage count. It is the same for every batch item, so fold a
    # single item's worth of ones and let the division broadcast. A negative
    # overlap leaves uncovered columns with a count of 0 (result NaN there).
    counter = th.nn.functional.fold(
        th.ones_like(cols[:1]), out_size, **fold_kwargs
    )
    return merged / counter

# legacy code use naive implementation

def split_wimg_legacy(himg, n_img, rtn_overlap=True):
    """Naive slicing version of a wide-image split using square ``H x H`` patches.

    Args:
        himg: Tensor of shape ``(C, H, W)`` or ``(B, C, H, W)``. NOTE: for a
            4-D input only the first batch item (``himg[0]``) is split; the
            others are ignored.
        n_img: Number of patches to cut along the width; must be >= 1.
        rtn_overlap: If True, also return the solved overlap size.

    Returns:
        Tensor of shape ``(n_img, C, H, H)`` with patches ordered left to
        right, or ``(patches, overlap_size)`` when ``rtn_overlap`` is True.

    Raises:
        ValueError: if ``n_img < 1``, if ``W`` cannot be tiled exactly by
            ``n_img`` patches of width ``H`` with a constant overlap, or if
            that overlap leaves a non-positive step between patches.
    """
    if himg.ndim == 3:
        himg = himg[None]
    _, _, h, w = himg.shape
    if n_img < 1:
        raise ValueError(f"n_img must be >= 1, got {n_img}")

    # A single patch has nothing to overlap with (the general formula would
    # divide by zero); it must span the whole width.
    overlap_size = 0 if n_img == 1 else (n_img * h - w) // (n_img - 1)
    if n_img * h - overlap_size * (n_img - 1) != w:
        raise ValueError(
            f"width {w} cannot be split into {n_img} patches of width "
            f"{h} with a constant overlap"
        )
    step = h - overlap_size
    if step <= 0:
        raise ValueError(
            f"overlap {overlap_size} leaves no forward step between patches "
            f"of width {h}"
        )

    first = himg[0]
    patches = th.stack(
        [first[:, :, i * step : i * step + h] for i in range(n_img)]
    )
    if rtn_overlap:
        return patches, overlap_size
    return patches

def avg_merge_wimg_legacy(imgs, overlap_size):
    """Naively merge overlapping patches by averaging two stitched versions.

    Two full-width images are assembled: one where every overlap is taken from
    the left-hand (earlier) patch, and one where it is taken from the
    right-hand (later) patch. Their mean is returned. This matches a true
    per-pixel average only while each column is covered by at most two
    patches, i.e. ``overlap_size <= w / 2``.

    Args:
        imgs: Tensor of shape ``(n, C, H, w)`` holding patches ordered left to
            right.
        overlap_size: Number of columns shared by neighbouring patches.

    Returns:
        Tensor of shape ``(C, H, n * w - (n - 1) * overlap_size)``; there is no
        batch dimension.
    """
    _, _, _, width = imgs.shape
    keep = width - overlap_size

    # Overlaps resolved in favour of the earlier (left) patch.
    left_wins = th.cat(
        [imgs[0], *(patch[:, :, overlap_size:] for patch in imgs[1:])],
        dim=-1,
    )
    # Overlaps resolved in favour of the later (right) patch.
    right_wins = th.cat(
        [*(patch[:, :, :keep] for patch in imgs[:-1]), imgs[-1]],
        dim=-1,
    )
    return (left_wins + right_wins) / 2.0