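# RAVE cross-frame attention for ComfyUI: frames are tiled into grids so that
# self-attention is shared across the frames of a video batch.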
import random
import torch
from einops import rearrange

from comfy.ldm.modules.attention import optimized_attention

from .rave_utils import grid_to_list, list_to_grid, shuffle_indices, shuffle_tensors2


def padding_count(n_frames, grid_frame_count):
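    """Number of frames to append so that n_frames becomes a multiple of grid_frame_count."""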
    remainder = n_frames % grid_frame_count
    if remainder == 0:
        return 0
    else:
        difference = grid_frame_count - remainder
        return difference


def unpatchify(x, h, w, p=2):
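    """Unpack h*w patch tokens (each holding a p x p patch of d channels) back into a flat per-position token sequence."""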
    x = rearrange(x, 'b (h w) (p q d) -> b (h p) (w q) d', h=h, p=p, q=p)
    return rearrange(x, 'b h w d -> b (h w) d')


def patchify(x, h, w, p=2):
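    """Pack a flat token sequence into h*w patch tokens of p*p*d channels each."""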
    return rearrange(x, 'b (p h q w) d -> b (h w) (p q d)', p=p, q=p, h=h, w=w)


def rave_attention(q, k, v, extra_options, n_heads):
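    """RAVE-style attention: shuffle the frames, tile grid_size x grid_size of them into
    single large frames, run self-attention jointly over each grid, then split the result
    back into per-frame outputs in the original order."""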
    # Recover the per-frame latent height/width from the token count and the original aspect ratio.
    batch_size, sequence_length, dim = q.shape
    shape = extra_options['original_shape']
    oh, ow = shape[-2:]
    ratio = oh/ow
    d = sequence_length
    w = int((d/ratio)**(0.5))
    h = int(d/w)

    rave_opts = extra_options.get('RAVE', {})
    grid_size = rave_opts.get('grid_size', 2)
    seed = rave_opts.get('seed', 1)
    len_conds = len(extra_options['cond_or_uncond'])
    n_frames = batch_size // len_conds
    original_n_frames = n_frames

    grid_frame_count = grid_size * grid_size
    n_padding_frames = padding_count(n_frames, grid_frame_count)
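    # Pad each cond's batch with randomly repeated frames so the frame count fills whole grids.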
    if n_padding_frames > 0:
        random.seed(seed)
        cond_qs = []
        cond_ks = []
        cond_vs = []
        padding_frames = [random.randint(
            0, n_frames-1) for _ in range(n_padding_frames)]
        for cond_idx in range(len_conds):
            start, end = cond_idx*n_frames, (cond_idx+1)*n_frames
            cond_q = q[start:end]
            cond_q = torch.cat([cond_q, cond_q[padding_frames]])
            cond_qs.append(cond_q)
            cond_k = k[start:end]
            cond_k = torch.cat([cond_k, cond_k[padding_frames]])
            cond_ks.append(cond_k)
            cond_v = v[start:end]
            cond_v = torch.cat([cond_v, cond_v[padding_frames]])
            cond_vs.append(cond_v)

        q = torch.cat(cond_qs)
        k = torch.cat(cond_ks)
        v = torch.cat(cond_vs)

    n_frames = n_frames + n_padding_frames

    # Restore the 2D spatial layout of each frame's tokens.
    q = rearrange(q, 'b (h w) c -> b h w c', h=h, w=w)
    k = rearrange(k, 'b (h w) c -> b h w c', h=h, w=w)
    v = rearrange(v, 'b (h w) c -> b h w c', h=h, w=w)

    # Deterministic (seeded) permutation of the frame order so each grid mixes frames from across the clip.
    target_indexes = shuffle_indices(n_frames, seed=seed)

    original_indexes = list(range(n_frames))
    qs = []
    ks = []
    vs = []

    # Per cond/uncond batch: reorder the frames, then tile grid_size x grid_size frames into one large grid.
    for i in range(len_conds):
        start, end = i*n_frames, (i+1)*n_frames
        q[start:end] = shuffle_tensors2(
            q[start:end], original_indexes, target_indexes)
        qs.append(list_to_grid(q[start:end], grid_size))
        k[start:end] = shuffle_tensors2(
            k[start:end], original_indexes, target_indexes)
        ks.append(list_to_grid(k[start:end], grid_size))
        v[start:end] = shuffle_tensors2(
            v[start:end], original_indexes, target_indexes)
        vs.append(list_to_grid(v[start:end], grid_size))

    q = torch.cat(qs)
    k = torch.cat(ks)
    v = torch.cat(vs)

    # Flatten each grid back into a token sequence and attend over the whole grid at once,
    # so tokens from different frames in the same grid attend to each other.
    q = rearrange(q, 'b h w c -> b (h w) c')
    k = rearrange(k, 'b h w c -> b (h w) c')
    v = rearrange(v, 'b h w c -> b (h w) c')
    out = optimized_attention(q, k, v, n_heads, None)

    # Split the attended grids back into individual frames.
    gh, gw = grid_size*h, grid_size*w
    out = rearrange(out, 'b (h w) c -> b h w c', h=gh, w=gw)
    out = grid_to_list(out, grid_size)
    out = rearrange(out, 'b h w c -> b (h w) c')

    # Undo the frame shuffle and drop any padding frames before returning the per-cond outputs.
    outs = []
    for i in range(len_conds):
        start, end = i*n_frames, (i+1)*n_frames
        cond_out = shuffle_tensors2(
            out[start:end], target_indexes, original_indexes)
        cond_out = cond_out[:original_n_frames]
        outs.append(cond_out)

    return torch.cat(outs)