import torch
from einops import rearrange
from typing import Any, Optional
from diffusers.utils.import_utils import is_xformers_available
from canonicalize.models.transformer_mv2d import XFormersMVAttnProcessor, MVAttnProcessor
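
# Reference-only attention for a multiview UNet: a "write" pass records
# self-attention tokens from a noised reference latent into a ref_dict, and a
# subsequent "read" pass lets every generated view attend to those tokens.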


class ReferenceOnlyAttnProc(torch.nn.Module):
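    """Chainable attention processor implementing reference-only attention.

    The ``mode`` keyword selects the behaviour per call: 'w' writes the
    current tokens into ``ref_dict``, 'r' reads them back so every view
    attends to the stored reference tokens, 'm' mixes them in without
    consuming the entry, and 'n' shares tokens across views without any
    reference. When ``enabled`` is False the call passes straight through
    to the chained processor.
    """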
    def __init__(
        self,
        chained_proc,
        enabled=False,
        name=None
    ) -> None:
        super().__init__()
        self.enabled = enabled
        self.chained_proc = chained_proc
        self.name = name

    def __call__(
        self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None,
        mode="w", ref_dict: Optional[dict] = None, is_cfg_guidance=False, num_views=4,
        multiview_attention=True,
        cross_domain_attention=False,
    ) -> Any:
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        if self.enabled:
            if mode == 'w':
                # Write pass: stash the reference tokens for later read passes,
                # then run plain per-frame self-attention.
                ref_dict[self.name] = encoder_hidden_states
                res = self.chained_proc(
                    attn, hidden_states, encoder_hidden_states, attention_mask,
                    num_views=1, multiview_attention=False, cross_domain_attention=False)
            elif mode == 'r':
                # Read pass: fold the views into the token axis and append the
                # stored reference tokens so every view can attend to them.
                encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c -> b (t d) c', t=num_views)
                if self.name in ref_dict:
                    encoder_hidden_states = torch.cat(
                        [encoder_hidden_states, ref_dict.pop(self.name)], dim=1
                    ).unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
                res = self.chained_proc(
                    attn, hidden_states, encoder_hidden_states, attention_mask,
                    num_views=num_views, multiview_attention=False, cross_domain_attention=False)
            elif mode == 'm':
                # Mix pass: append the stored reference tokens without
                # consuming the ref_dict entry.
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
                res = self.chained_proc(
                    attn, hidden_states, encoder_hidden_states, attention_mask,
                    num_views=num_views, multiview_attention=False, cross_domain_attention=False)
            elif mode == 'n':
                # No-reference pass: share tokens across views, skip the reference.
                encoder_hidden_states = rearrange(encoder_hidden_states, '(b t) d c -> b (t d) c', t=num_views)
                encoder_hidden_states = encoder_hidden_states.unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
                res = self.chained_proc(
                    attn, hidden_states, encoder_hidden_states, attention_mask,
                    num_views=num_views, multiview_attention=False, cross_domain_attention=False)
            else:
                raise ValueError(f"unknown mode {mode!r}")
        else:
            res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
        return res
        
class RefOnlyNoisedUNet(torch.nn.Module):
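    """UNet wrapper implementing the reference-only two-pass scheme.

    On every forward call the reference latent (``cond_lat`` from
    ``cross_attention_kwargs``) is noised to the current timestep and run
    through the UNet in 'w' mode to record per-layer attention tokens; the
    actual sample is then denoised in 'r' mode, attending to those recorded
    tokens. ``train_sched`` is used for noising during training and
    ``val_sched`` during inference.
    """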
    def __init__(self, unet, train_sched, val_sched) -> None:
        super().__init__()
        self.unet = unet
        self.train_sched = train_sched
        self.val_sched = val_sched

        use_xformers = is_xformers_available()
        unet_attn_procs = dict()
        for name, _ in unet.attn_processors.items():
            default_attn_proc = XFormersMVAttnProcessor() if use_xformers else MVAttnProcessor()
            # Only self-attention layers ("attn1") take part in the
            # reference write/read mechanism.
            unet_attn_procs[name] = ReferenceOnlyAttnProc(
                default_attn_proc, enabled=name.endswith("attn1.processor"), name=name)

        self.unet.set_attn_processor(unet_attn_procs)

    def __getattr__(self, name: str):
        # Delegate attribute lookups (dtype, config, ...) to the wrapped UNet
        # when this wrapper does not define them itself.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):
        # Write pass: run the UNet on the noised reference latent purely for
        # the side effect of filling ref_dict; the prediction is discarded.
        if is_cfg_guidance:
            # Drop the unconditional (classifier-free guidance) entry; the
            # reference pass only needs the conditional branch.
            encoder_hidden_states = encoder_hidden_states[1:]
            class_labels = class_labels[1:]
        self.unet(
            noisy_cond_lat, timestep,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
            **kwargs
        )

    def forward(
        self, sample, timestep, encoder_hidden_states, class_labels=None,
        *args, cross_attention_kwargs,
        down_block_res_samples=None, mid_block_res_sample=None,
        **kwargs
    ):
        cond_lat = cross_attention_kwargs['cond_lat']
        is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)
        # Noise the reference latent to the current timestep so the write pass
        # sees it at the same noise level as the sample being denoised.
        noise = torch.randn_like(cond_lat)
        if self.training:
            noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
            noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
        else:
            noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
            noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
        # Write pass: record the reference attention tokens into ref_dict.
        ref_dict = {}
        self.forward_cond(
            noisy_cond_lat, timestep,
            encoder_hidden_states, class_labels,
            ref_dict, is_cfg_guidance, **kwargs
        )
        # Read pass: denoise the sample while attending to the recorded
        # reference tokens; ControlNet-style residuals, if given, are cast to
        # the UNet's weight dtype.
        weight_dtype = self.unet.dtype
        return self.unet(
            sample, timestep,
            encoder_hidden_states, *args,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
            down_block_additional_residuals=[
                res_sample.to(dtype=weight_dtype) for res_sample in down_block_res_samples
            ] if down_block_res_samples is not None else None,
            mid_block_additional_residual=(
                mid_block_res_sample.to(dtype=weight_dtype)
                if mid_block_res_sample is not None else None
            ),
            **kwargs
        )