File size: 1,124 Bytes
82ea528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41

# Module-level attention-override table: ten "inputs" flags, one "middle_0"
# flag and twelve "outputs" flags, each list paired with its index list.
# NOTE(review): the inputs/middle/outputs layout looks like a UNet-style block
# structure; confirm whether it is meaningful for Flux models.
default_attn = dict(
    inputs=[True for _ in range(10)],
    input_idxs=[*range(10)],
    middle_0=True,
    outputs=[True for _ in range(12)],
    output_idxs=[*range(12)],
)


class ApplyFluxRaveAttentionNode:
    """ComfyUI node that attaches RAVE settings to a cloned model.

    The node clones the incoming model and gives the clone a fresh shallow
    copy of ``model_options['transformer_options']``. It then stores the
    chosen ``grid_size`` and ``seed`` under that copy's ``'RAVE'`` key.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs for ComfyUI."""
        return {"required":
                {
                    "model": ("MODEL",),
                    "grid_size": ("INT", {"default": 3, "min": 1, "max": 10}),
                    "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                },
                "optional": {
                    "attn_override": ("ATTN_OVERRIDE",)
                }
                }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply"

    CATEGORY = "fluxtapoz"

    def apply(self, model, grid_size, seed, attn_override=None):
        """Return a clone of *model* with RAVE options written into it.

        Args:
            model: Object exposing ``clone()`` and a ``model_options`` dict.
            grid_size: Value stored as ``transformer_options['RAVE']['grid_size']``.
            seed: Value stored as ``transformer_options['RAVE']['seed']``.
            attn_override: Optional ATTN_OVERRIDE input. It is accepted but
                not used by this node.
                NOTE(review): confirm whether it should be forwarded.
                Defaults to ``None`` rather than a shared module-level dict,
                which avoids the mutable-default-argument pitfall.

        Returns:
            A one-element tuple holding the patched clone, matching
            ``RETURN_TYPES``.
        """
        model = model.clone()

        # Copy before writing so the transformer_options dict obtained from the
        # clone is never mutated in place. It may still be shared with the
        # source model if clone() copies shallowly.
        transformer_options = {**model.model_options.get('transformer_options', {})}
        model.model_options['transformer_options'] = transformer_options

        transformer_options['RAVE'] = {
            "grid_size": grid_size,
            "seed": seed,
        }

        return (model, )