File size: 9,668 Bytes
95d4bb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca71d6b
95d4bb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
from dataclasses import dataclass

from einops import rearrange,repeat
import torch
import torch.nn.functional as F
from torch import Tensor
from typing import List

from flux.sampling import get_schedule, unpack,denoise_kv,denoise_kv_inf
from flux.util import load_flow_model
from flux.model import Flux_kv

@dataclass()
class SamplingOptions:
    """Settings for one editing run, read by the Flux_kv_edit* modules.

    Field order defines the positional constructor signature; keep it stable.
    """

    # Text prompts. Not read by the classes in this file (presumably consumed
    # by the caller when building the conditioning dicts -- confirm).
    source_prompt: str = ''
    target_prompt: str = ''

    # Resolution in pixels. The editors divide both by 8 to size the mask and
    # pass them to ``unpack``.
    width: int = 1366
    height: int = 768

    # Schedule. ``denoise_num_steps`` feeds get_schedule for both the inversion
    # and the denoising pass; the first ``skip_step`` timesteps are dropped.
    # ``inversion_num_steps`` is not read in this file.
    inversion_num_steps: int = 0
    denoise_num_steps: int = 0
    skip_step: int = 0

    # Guidance scales for the source/inversion and target/denoising passes.
    inversion_guidance: float = 1.0
    denoise_guidance: float = 1.0

    seed: int = 42            # not read in this file
    re_init: bool = False     # Flux_kv_edit.denoise: start from re-noised z0 instead of the inverted latent
    attn_mask: bool = False   # build a region-restricted attention mask

class only_Flux(torch.nn.Module):  # base class: model initialisation + attention-mask helper only
    """Base ``nn.Module`` that loads a Flux model and builds attention masks.

    Subclasses provide ``forward``. This class only loads the model via
    ``load_flow_model(..., flux_cls=Flux_kv)`` and offers
    :meth:`create_attention_mask`.
    """

    def __init__(self, device,name='flux-dev'):
        # nn.Module bookkeeping must exist before any attribute is assigned.
        super().__init__()
        self.device = device
        self.name = name
        self.model = load_flow_model(self.name, device=self.device,flux_cls=Flux_kv)

    def create_attention_mask(self,seq_len, mask_indices, text_len=512, device='cuda'):
        """Build a boolean query-by-key attention mask over a text+image sequence.

        The sequence is taken to be ``text_len`` text tokens followed by image
        tokens. ``True`` marks a query/key pair that may attend:

        * text queries attend to every key;
        * masked-region image queries attend to text keys and masked-region keys;
        * background image queries attend to text keys and background keys.

        Args:
            seq_len (int): total sequence length (text + image tokens).
            mask_indices (Sequence[int] | torch.Tensor): positions of the
                masked-region tokens counted from the first image token
                (``text_len`` is added here). May be empty.
            text_len (int): number of leading text tokens, default 512.
            device (str | torch.device): device of the returned mask.

        Returns:
            torch.Tensor: bool tensor of shape ``(1, seq_len, seq_len)``.
        """
        positions = torch.arange(seq_len, device=device)
        is_text = positions < text_len

        # as_tensor with an explicit long dtype accepts lists or tensors on any
        # device, and also empty input (torch.tensor([]) would be float and
        # unusable as an index).
        image_idx = torch.as_tensor(mask_indices, dtype=torch.long, device=device).reshape(-1)
        is_masked = torch.zeros(seq_len, dtype=torch.bool, device=device)
        is_masked[image_idx + text_len] = True

        # Vectorised membership test instead of a per-token ``idx not in ...`` loop.
        is_background = ~(is_text | is_masked)

        rows_text = is_text.unsqueeze(1)  # text queries: every key
        rows_masked = is_masked.unsqueeze(1) & (is_text | is_masked).unsqueeze(0)
        rows_background = is_background.unsqueeze(1) & (is_text | is_background).unsqueeze(0)
        attention_mask = rows_text | rows_masked | rows_background
        return attention_mask.unsqueeze(0)
     
class Flux_kv_edit_inf(only_Flux):
    """Mask-guided edit driven by a single ``denoise_kv_inf`` call.

    There is no separate ``inverse`` step here. The source and target
    conditionings are both passed to ``denoise_kv_inf``, whose output replaces
    the masked tokens of the source latents.
    """

    def __init__(self, device,name):
        super().__init__(device,name)

    @torch.inference_mode()
    def forward(self,inp,inp_target,mask:Tensor,opts):
        """Edit the masked region of ``inp["img"]`` toward the target conditioning.

        Args:
            inp: source dict; reads ``img``, ``img_ids``, ``txt``, ``txt_ids``
                and ``vec``. ``inp["img"]`` is not modified.
            inp_target: target dict; reads ``txt``, ``txt_ids`` and ``vec``.
            mask: pixel-space edit mask in ``(b, c, H, W)`` layout (required by
                the einops patterns below); values > 0 mark the edit region.
            opts: a ``SamplingOptions``-like object.

        Returns:
            ``unpack(z0.float(), opts.height, opts.width)``, where ``z0`` is a
            copy of the source token sequence whose masked tokens were
            replaced by the ``denoise_kv_inf`` output.
        """
        info = {}
        info['feature'] = {}
        # Rank-3 unpacking also rejects malformed input early.
        _, L, _ = inp["img"].shape
        h = opts.height // 8
        w = opts.width // 8

        # Pixel mask -> latent grid, binarised so any overlap counts as "edit".
        mask = F.interpolate(mask, size=(h,w), mode='bilinear', align_corners=False)
        mask[mask > 0] = 1
        # Tile over 16 channels and pack 2x2 patches into tokens; presumably
        # this mirrors the packing of inp["img"] -- TODO confirm.
        mask = repeat(mask, 'b c h w -> b (repeat c) h w', repeat=16)
        mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        info['mask'] = mask
        bool_mask = (mask.sum(dim=2) > 0.5)
        # Token positions inside the image sequence (the batch column of
        # nonzero() is dropped; the text offset is added by create_attention_mask).
        info['mask_indices'] = torch.nonzero(bool_mask)[:,1]

        # Restrict attention only if some token lies outside the mask.
        if opts.attn_mask and (~bool_mask).any():
            # +512: assumed text-token count (create_attention_mask's default text_len).
            attention_mask = self.create_attention_mask(L+512, info['mask_indices'], device=self.device)
        else:
            attention_mask = None
        info['attention_mask'] = attention_mask

        denoise_timesteps = get_schedule(opts.denoise_num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
        denoise_timesteps = denoise_timesteps[opts.skip_step:]

        # The inference_mode decorator already disables autograd.
        info['inject'] = True
        z_fe, info = denoise_kv_inf(self.model, img=inp["img"], img_ids=inp['img_ids'],
                                    source_txt=inp['txt'], source_txt_ids=inp['txt_ids'], source_vec=inp['vec'],
                                    target_txt=inp_target['txt'], target_txt_ids=inp_target['txt_ids'], target_vec=inp_target['vec'],
                                    timesteps=denoise_timesteps, source_guidance=opts.inversion_guidance, target_guidance=opts.denoise_guidance,
                                    info=info)

        # Write the edited tokens into a copy so the caller's inp["img"] is not mutated.
        z0 = inp["img"].clone()
        mask_indices = info['mask_indices']
        z0[:, mask_indices,...] = z_fe

        return unpack(z0.float(),  opts.height, opts.width)

class Flux_kv_edit(only_Flux):
    """Mask-guided edit: invert the source latents, then re-denoise the masked tokens.

    ``forward`` runs :meth:`inverse` (``denoise_kv(..., inverse=True)``) followed
    by :meth:`denoise`, which denoises only the masked tokens under the target
    conditioning and blends them back into the source latents.
    """

    def __init__(self, device,name):
        super().__init__(device,name)

    @torch.inference_mode()
    def forward(self,inp,inp_target,mask:Tensor,opts):
        """Run :meth:`inverse` then :meth:`denoise`; returns the latter's result."""
        z0,zt,info = self.inverse(inp,mask,opts)
        z0 = self.denoise(z0,zt,inp_target,mask,opts,info)
        return z0

    def _mask_to_tokens(self, mask, opts):
        """Project a pixel-space mask onto the packed image-token sequence.

        ``mask`` must be ``(b, c, H, W)`` (required by the einops patterns);
        values > 0 are treated as the edit region. The caller's tensor is not
        modified (``F.interpolate`` returns a new tensor).

        Returns:
            (token_mask, bool_mask, mask_indices): the binarised mask as
            ``(b, tokens, channels)``; a per-token flag (channel sum > 0.5);
            and the flagged token positions with the batch column of
            ``nonzero`` dropped.
        """
        h = opts.height // 8
        w = opts.width // 8
        mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
        mask[mask > 0] = 1
        mask = repeat(mask, 'b c h w -> b (repeat c) h w', repeat=16)
        mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        bool_mask = mask.sum(dim=2) > 0.5
        mask_indices = torch.nonzero(bool_mask)[:, 1]
        return mask, bool_mask, mask_indices

    @torch.inference_mode()
    def inverse(self,inp,mask,opts):
        """Invert ``inp["img"]`` along the (skipped) denoising schedule.

        Args:
            inp: source dict, forwarded as keyword arguments to ``denoise_kv``
                (must contain at least ``img``).
            mask: pixel-space edit mask; only used when ``opts.attn_mask``.
            opts: a ``SamplingOptions``-like object.

        Returns:
            (z0, zt, info): a clone of ``inp["img"]``, the ``denoise_kv``
            output with ``inverse=True``, and the info dict it returned.
        """
        info = {}
        info['feature'] = {}
        # Rank-3 unpacking also rejects malformed input early.
        _, L, _ = inp["img"].shape

        if opts.attn_mask:
            _, bool_mask, mask_indices = self._mask_to_tokens(mask, opts)
            # The attention mask needs both a masked and an unmasked region.
            assert not (~bool_mask).all(), "mask is all false"
            assert not (bool_mask).all(), "mask is all true"
            # +512: assumed text-token count (create_attention_mask's default text_len).
            attention_mask = self.create_attention_mask(L+512, mask_indices, device=mask.device)
            info['attention_mask'] = attention_mask

        denoise_timesteps = get_schedule(opts.denoise_num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
        denoise_timesteps = denoise_timesteps[opts.skip_step:]

        # Keep an untouched copy of the source latents for the final blend.
        z0 = inp["img"].clone()
        info['inverse'] = True
        zt, info = denoise_kv(self.model, **inp, timesteps=denoise_timesteps, guidance=opts.inversion_guidance, inverse=True, info=info)
        return z0,zt,info

    @torch.inference_mode()
    def denoise(self,z0,zt,inp_target,mask:Tensor,opts,info):
        """Denoise the masked tokens and blend them into ``z0``.

        Args:
            z0: source latents; its masked tokens are overwritten in place.
            zt: inverted latents from :meth:`inverse`.
            inp_target: target dict forwarded to ``denoise_kv``. It is not
                modified; a shallow copy receives the masked ``img`` tokens.
            mask: pixel-space edit mask.
            opts: a ``SamplingOptions``-like object.
            info: the dict returned by :meth:`inverse`; ``mask``,
                ``mask_indices`` and ``inverse`` are set on it.

        Returns:
            ``unpack(z0.float(), opts.height, opts.width)`` after blending.
        """
        token_mask, _, mask_indices = self._mask_to_tokens(mask, opts)
        info['mask'] = token_mask
        info['mask_indices'] = mask_indices

        denoise_timesteps = get_schedule(opts.denoise_num_steps, inp_target["img"].shape[1], shift=(self.name != "flux-schnell"))
        denoise_timesteps = denoise_timesteps[opts.skip_step:]

        if opts.re_init:
            # Re-noise the source latents to the first timestep instead of
            # reusing the inverted trajectory.
            noise = torch.randn_like(zt)
            t = denoise_timesteps[0]
            start = z0 * (1 - t) + noise * t
        else:
            start = zt

        # Only the masked tokens are denoised. Use a shallow copy so the
        # caller's dict keeps its full-length "img".
        target = dict(inp_target)
        target["img"] = start[:, mask_indices,...]

        info['inverse'] = False
        x, _ = denoise_kv(self.model, **target, timesteps=denoise_timesteps, guidance=opts.denoise_guidance, inverse=False, info=info)

        # Per-channel blend: 1-valued mask channels take the new tokens.
        region_mask = info['mask'][:, mask_indices,...]
        z0[:, mask_indices,...] = z0[:, mask_indices,...] * (1 - region_mask) + x * region_mask

        return unpack(z0.float(),  opts.height, opts.width)