# Viewer snapshot header: file size 9,668 bytes; revisions 95d4bb7, ca71d6b.
from dataclasses import dataclass
from einops import rearrange,repeat
import torch
import torch.nn.functional as F
from torch import Tensor
from typing import List
from flux.sampling import get_schedule, unpack,denoise_kv,denoise_kv_inf
from flux.util import load_flow_model
from flux.model import Flux_kv
@dataclass
class SamplingOptions:
    """Plain settings container for one inversion / denoising edit run.

    No validation is performed here; the values are read as-is by the
    editing modules in this file.
    """

    # Prompt describing the source image (not read in this file —
    # presumably consumed by the caller when preparing inputs).
    source_prompt: str = ''
    # Prompt describing the desired edit (not read in this file —
    # presumably consumed by the caller when preparing inputs).
    target_prompt: str = ''
    # Image size in pixels; the editing modules divide both by 8 to get the
    # size the mask is resampled to, and pass both to ``unpack``.
    width: int = 1366
    height: int = 768
    # NOTE(review): not read anywhere in this file; the inversion pass
    # builds its schedule from ``denoise_num_steps`` — confirm intent.
    inversion_num_steps: int = 0
    # Number of steps passed to ``get_schedule``.
    denoise_num_steps: int = 0
    # Number of leading schedule entries dropped before sampling.
    skip_step: int = 0
    # Guidance value used for the inversion / source branch.
    inversion_guidance: float = 1.0
    # Guidance value used for the denoising / target branch.
    denoise_guidance: float = 1.0
    # Random seed (not read in this file — presumably applied by the caller).
    seed: int = 42
    # When True, denoising starts from the clean latents re-noised with
    # fresh Gaussian noise instead of from the inverted latents.
    re_init: bool = False
    # When True, an attention mask separating masked and background image
    # tokens is built and stored for the model.
    attn_mask: bool = False
class only_Flux(torch.nn.Module):
    """Base module that loads the flow model and builds attention masks.

    It defines no ``forward``; the editing subclasses in this file add one.
    """

    def __init__(self, device, name='flux-dev'):
        """Load the flow model ``name`` onto ``device``.

        Args:
            device: Device handed to ``load_flow_model``; kept as
                ``self.device``.
            name: Model identifier handed to ``load_flow_model``; kept as
                ``self.name``.
        """
        # Initialise nn.Module first so its internal registries exist before
        # any attribute (in particular the sub-module below) is assigned.
        super().__init__()
        self.device = device
        self.name = name
        self.model = load_flow_model(self.name, device=self.device, flux_cls=Flux_kv)

    def create_attention_mask(self, seq_len, mask_indices, text_len=512, device='cuda'):
        """Build a boolean attention mask separating masked and background tokens.

        The sequence is treated as ``text_len`` text tokens followed by
        ``seq_len - text_len`` image tokens. Rows are queries, columns are
        keys, and ``True`` marks a pair that may attend:

        * text queries attend to every key;
        * masked image queries attend to text keys and masked image keys;
        * background image queries attend to text keys and background keys.

        Args:
            seq_len (int): Total sequence length (text + image tokens).
            mask_indices: Indices of the masked tokens in image-token
                coordinates (a list of ints or an integer tensor); they are
                shifted by ``text_len`` internally. May be empty.
            text_len (int): Number of leading text tokens. Defaults to 512.
            device: Device on which the mask is created, e.g. 'cuda' or 'cpu'.

        Returns:
            torch.Tensor: Boolean mask of shape ``(1, seq_len, seq_len)``.

        Raises:
            ValueError: If ``seq_len`` is smaller than ``text_len``.
            IndexError: If a shifted mask index falls outside the sequence
                (raised by the indexing operation on CPU).
        """
        if seq_len < text_len:
            raise ValueError(
                f"seq_len ({seq_len}) must be at least text_len ({text_len})"
            )

        # Shift image-space indices into full-sequence coordinates. Forcing an
        # integer dtype keeps an empty selection usable as an index.
        mask_token_indices = (
            torch.as_tensor(mask_indices, dtype=torch.long, device=device).reshape(-1)
            + text_len
        )

        # One membership flag per token for each of the three regions.
        is_text = torch.arange(seq_len, device=device) < text_len
        is_masked = torch.zeros(seq_len, dtype=torch.bool, device=device)
        is_masked[mask_token_indices] = True
        is_background = ~(is_text | is_masked)

        # Image tokens only see image tokens of their own region.
        same_region = (is_masked[:, None] & is_masked[None, :]) | (
            is_background[:, None] & is_background[None, :]
        )
        # Text rows are fully open; text columns are visible to every query.
        attention_mask = is_text[:, None] | is_text[None, :] | same_region
        return attention_mask.unsqueeze(0)
class Flux_kv_edit_inf(only_Flux):
    """Editing module that runs a single ``denoise_kv_inf`` pass.

    The masked image tokens are regenerated by ``denoise_kv_inf`` and pasted
    back into a copy of the source latents.
    """

    def __init__(self, device, name):
        """Load the flow model ``name`` onto ``device`` (see ``only_Flux``)."""
        super().__init__(device, name)

    @staticmethod
    def _patchify_mask(mask, opts):
        """Resample a pixel-space mask and fold it into per-token form.

        The mask is bilinearly resized to ``(opts.height // 8, opts.width // 8)``,
        every positive value is set to 1, the channel axis is repeated 16
        times, and each 2x2 spatial patch is folded into the last axis, giving
        a tensor of shape ``(b, tokens, channels * 16 * 4)``.
        """
        latent_size = (opts.height // 8, opts.width // 8)
        # F.interpolate returns a new tensor, so the in-place binarisation
        # below does not touch the caller's mask.
        mask = F.interpolate(mask, size=latent_size, mode='bilinear', align_corners=False)
        mask[mask > 0] = 1
        mask = repeat(mask, 'b c h w -> b (repeat c) h w', repeat=16)
        return rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

    @torch.inference_mode()
    def forward(self, inp, inp_target, mask: Tensor, opts):
        """Edit the masked region of the source latents.

        Args:
            inp: Mapping for the source image; the keys ``img``, ``img_ids``,
                ``txt``, ``txt_ids`` and ``vec`` are read. ``img`` must be a
                3-D tensor whose second axis is the image-token axis.
            inp_target: Mapping for the target prompt; the keys ``txt``,
                ``txt_ids`` and ``vec`` are read.
            mask: 4-D ``(b, c, H, W)`` floating tensor; positive values mark
                the region to edit.
            opts: Options object; ``height``, ``width``, ``attn_mask``,
                ``denoise_num_steps``, ``skip_step``, ``inversion_guidance``
                and ``denoise_guidance`` are read.

        Returns:
            The result of ``unpack`` applied to the edited latents (cast to
            float32) with ``opts.height`` and ``opts.width``.
        """
        info = {'feature': {}}
        # Unpacking also enforces that the latents are 3-D.
        _, L, _ = inp["img"].shape

        mask = self._patchify_mask(mask, opts)
        info['mask'] = mask
        # A token counts as masked when any of its patch entries is set.
        bool_mask = mask.sum(dim=2) > 0.5
        # Column 1 of nonzero() is the token position. These are image-token
        # coordinates; in the full sequence the text length must be added.
        info['mask_indices'] = torch.nonzero(bool_mask)[:, 1]

        # Only build an attention mask when at least one token is background;
        # with an all-true mask there is nothing to separate.
        if opts.attn_mask and (~bool_mask).any():
            # NOTE(review): 512 is the assumed number of text tokens placed
            # before the image tokens — confirm for every supported model.
            attention_mask = self.create_attention_mask(
                L + 512, info['mask_indices'], device=self.device
            )
        else:
            attention_mask = None
        info['attention_mask'] = attention_mask

        denoise_timesteps = get_schedule(
            opts.denoise_num_steps,
            inp["img"].shape[1],
            shift=(self.name != "flux-schnell"),
        )
        denoise_timesteps = denoise_timesteps[opts.skip_step:]

        # inference_mode (decorator) already disables gradient tracking, so
        # no additional no_grad context is needed here.
        info['inject'] = True
        z_fe, info = denoise_kv_inf(
            self.model,
            img=inp["img"],
            img_ids=inp['img_ids'],
            source_txt=inp['txt'],
            source_txt_ids=inp['txt_ids'],
            source_vec=inp['vec'],
            target_txt=inp_target['txt'],
            target_txt_ids=inp_target['txt_ids'],
            target_vec=inp_target['vec'],
            timesteps=denoise_timesteps,
            source_guidance=opts.inversion_guidance,
            target_guidance=opts.denoise_guidance,
            info=info,
        )

        # Paste the regenerated tokens into a copy so the caller's source
        # latents in ``inp["img"]`` are not overwritten.
        mask_indices = info['mask_indices']  # image-token coordinates
        z0 = inp["img"].clone()
        z0[:, mask_indices, ...] = z_fe

        return unpack(z0.float(), opts.height, opts.width)
class Flux_kv_edit(only_Flux):
    """Two-stage editing module: ``inverse`` then ``denoise``.

    ``inverse`` runs ``denoise_kv`` with ``inverse=True`` on the source
    latents and returns an ``info`` dict; ``denoise`` reuses that dict while
    regenerating only the masked image tokens.
    """

    def __init__(self, device, name):
        """Load the flow model ``name`` onto ``device`` (see ``only_Flux``)."""
        super().__init__(device, name)

    @torch.inference_mode()
    def forward(self, inp, inp_target, mask: Tensor, opts):
        """Run ``inverse`` followed by ``denoise`` and return the result."""
        z0, zt, info = self.inverse(inp, mask, opts)
        return self.denoise(z0, zt, inp_target, mask, opts, info)

    @staticmethod
    def _patchify_mask(mask, opts):
        """Resample a pixel-space mask and fold it into per-token form.

        The mask is bilinearly resized to ``(opts.height // 8, opts.width // 8)``,
        every positive value is set to 1, the channel axis is repeated 16
        times, and each 2x2 spatial patch is folded into the last axis, giving
        a tensor of shape ``(b, tokens, channels * 16 * 4)``.
        """
        latent_size = (opts.height // 8, opts.width // 8)
        # F.interpolate returns a new tensor, so the in-place binarisation
        # below does not touch the caller's mask.
        mask = F.interpolate(mask, size=latent_size, mode='bilinear', align_corners=False)
        mask[mask > 0] = 1
        mask = repeat(mask, 'b c h w -> b (repeat c) h w', repeat=16)
        return rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

    @torch.inference_mode()
    def inverse(self, inp, mask, opts):
        """Invert the source latents and collect state for ``denoise``.

        Args:
            inp: Mapping unpacked into ``denoise_kv`` as keyword arguments;
                ``inp["img"]`` must be a 3-D tensor whose second axis is the
                image-token axis.
            mask: 4-D ``(b, c, H, W)`` floating tensor; only used when
                ``opts.attn_mask`` is set.
            opts: Options object; ``height``, ``width``, ``attn_mask``,
                ``denoise_num_steps``, ``skip_step`` and
                ``inversion_guidance`` are read.

        Returns:
            tuple: ``(z0, zt, info)`` — a clone of ``inp["img"]``, the tensor
            returned by ``denoise_kv``, and the ``info`` dict it returned.

        Raises:
            AssertionError: If ``opts.attn_mask`` is set and the mask selects
                no token or every token.
        """
        info = {'feature': {}}
        # Unpacking also enforces that the latents are 3-D.
        _, L, _ = inp["img"].shape

        attention_mask = None
        if opts.attn_mask:
            mask = self._patchify_mask(mask, opts)
            # A token counts as masked when any of its patch entries is set.
            bool_mask = mask.sum(dim=2) > 0.5
            # Column 1 of nonzero() is the token position (image-token
            # coordinates; the text length is added when building the mask).
            mask_indices = torch.nonzero(bool_mask)[:, 1]
            # Raised explicitly (same exception type as the former asserts)
            # so the checks still run under ``python -O``.
            if not bool_mask.any():
                raise AssertionError("mask is all false")
            if bool_mask.all():
                raise AssertionError("mask is all true")
            # NOTE(review): 512 is the assumed number of text tokens placed
            # before the image tokens — confirm for every supported model.
            attention_mask = self.create_attention_mask(
                L + 512, mask_indices, device=mask.device
            )
        # Always present (None when disabled), matching Flux_kv_edit_inf.
        info['attention_mask'] = attention_mask

        # The inversion uses the same schedule settings as the denoising pass.
        denoise_timesteps = get_schedule(
            opts.denoise_num_steps,
            inp["img"].shape[1],
            shift=(self.name != "flux-schnell"),
        )
        denoise_timesteps = denoise_timesteps[opts.skip_step:]

        # Noising (inversion) pass; keep an untouched copy of the clean latents.
        z0 = inp["img"].clone()
        info['inverse'] = True
        zt, info = denoise_kv(
            self.model,
            **inp,
            timesteps=denoise_timesteps,
            guidance=opts.inversion_guidance,
            inverse=True,
            info=info,
        )
        return z0, zt, info

    @torch.inference_mode()
    def denoise(self, z0, zt, inp_target, mask: Tensor, opts, info):
        """Regenerate the masked tokens and blend them into the clean latents.

        Args:
            z0: Clean source latents (as returned by ``inverse``). Not modified.
            zt: Inverted latents (as returned by ``inverse``).
            inp_target: Mapping unpacked into ``denoise_kv`` as keyword
                arguments; its ``img`` entry supplies the token count for the
                schedule and is replaced, in a copy, by the masked start
                tokens. The caller's mapping is not modified.
            mask: 4-D ``(b, c, H, W)`` floating tensor; positive values mark
                the region to edit.
            opts: Options object; ``height``, ``width``, ``denoise_num_steps``,
                ``skip_step``, ``re_init`` and ``denoise_guidance`` are read.
            info: State dict from ``inverse``; ``mask``, ``mask_indices`` and
                ``inverse`` are written into it before it is passed on.

        Returns:
            The result of ``unpack`` applied to the blended latents (cast to
            float32) with ``opts.height`` and ``opts.width``.
        """
        mask = self._patchify_mask(mask, opts)
        info['mask'] = mask
        bool_mask = mask.sum(dim=2) > 0.5
        # Token positions of the masked region in image-token coordinates.
        mask_indices = torch.nonzero(bool_mask)[:, 1]
        info['mask_indices'] = mask_indices

        denoise_timesteps = get_schedule(
            opts.denoise_num_steps,
            inp_target["img"].shape[1],
            shift=(self.name != "flux-schnell"),
        )
        denoise_timesteps = denoise_timesteps[opts.skip_step:]

        if opts.re_init:
            # Re-noise the clean latents to the first schedule time instead of
            # starting from the inverted latents.
            noise = torch.randn_like(zt)
            t = denoise_timesteps[0]
            start = z0 * (1 - t) + noise * t
        else:
            start = zt

        # Only the masked tokens are denoised. Work on a shallow copy so the
        # caller's mapping keeps its full-length ``img`` entry; otherwise a
        # second call would compute the schedule from the masked length.
        model_inputs = dict(inp_target)
        model_inputs["img"] = start[:, mask_indices, ...]

        info['inverse'] = False
        x, _ = denoise_kv(
            self.model,
            **model_inputs,
            timesteps=denoise_timesteps,
            guidance=opts.denoise_guidance,
            inverse=False,
            info=info,
        )

        # Blend the regenerated tokens back at their positions, weighted by
        # the per-token mask; write into a copy so the caller's z0 survives.
        region_mask = info['mask'][:, mask_indices, ...]
        z0 = z0.clone()
        z0[:, mask_indices, ...] = z0[:, mask_indices, ...] * (1 - region_mask) + x * region_mask

        return unpack(z0.float(), opts.height, opts.width)