chaojiemao committed
Commit 097bf68 · verified · 1 Parent(s): ec43f9b

Create model/flux.py

Files changed (1): model/flux.py (+407, -0)
model/flux.py ADDED
import math
from collections import OrderedDict
from functools import partial

import torch
from einops import rearrange, repeat
from scepter.modules.model.base_model import BaseModel
from scepter.modules.model.registry import BACKBONES
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.distribute import we
from scepter.modules.utils.file_system import FS
from torch import Tensor, nn
from torch.utils.checkpoint import checkpoint_sequential

from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
                     MLPEmbedder, SingleStreamBlock,
                     timestep_embedding)


@BACKBONES.register_class()
class Flux(BaseModel):
    """
    Transformer backbone diffusion model with RoPE (rotary position embedding).
    """
    para_dict = {
        "IN_CHANNELS": {
            "value": 64,
            "description": "model's input channels."
        },
        "OUT_CHANNELS": {
            "value": 64,
            "description": "model's output channels."
        },
        "HIDDEN_SIZE": {
            "value": 1024,
            "description": "model's hidden size."
        },
        "NUM_HEADS": {
            "value": 16,
            "description": "number of heads in the transformer."
        },
        "AXES_DIM": {
            "value": [16, 56, 56],
            "description": "dimensions of the axes of the positional encoding."
        },
        "THETA": {
            "value": 10_000,
            "description": "theta for positional encoding."
        },
        "VEC_IN_DIM": {
            "value": 768,
            "description": "dimension of the vector input."
        },
        "GUIDANCE_EMBED": {
            "value": False,
            "description": "whether to use guidance embedding."
        },
        "CONTEXT_IN_DIM": {
            "value": 4096,
            "description": "dimension of the context input."
        },
        "MLP_RATIO": {
            "value": 4.0,
            "description": "ratio of mlp hidden size to hidden size."
        },
        "QKV_BIAS": {
            "value": True,
            "description": "whether to use bias in the qkv projection."
        },
        "DEPTH": {
            "value": 19,
            "description": "number of double-stream transformer blocks."
        },
        "DEPTH_SINGLE_BLOCKS": {
            "value": 38,
            "description": "number of single-stream transformer blocks."
        },
        "USE_GRAD_CHECKPOINT": {
            "value": False,
            "description": "whether to use gradient checkpointing."
        },
        "ATTN_BACKEND": {
            "value": "pytorch",
            "description": "backend for the transformer blocks, 'pytorch' or 'flash_attn'."
        }
    }

    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger=logger)
        self.in_channels = cfg.IN_CHANNELS
        self.out_channels = cfg.get("OUT_CHANNELS", self.in_channels)
        hidden_size = cfg.get("HIDDEN_SIZE", 1024)
        num_heads = cfg.get("NUM_HEADS", 16)
        axes_dim = cfg.AXES_DIM
        theta = cfg.THETA
        vec_in_dim = cfg.VEC_IN_DIM
        self.guidance_embed = cfg.GUIDANCE_EMBED
        context_in_dim = cfg.CONTEXT_IN_DIM
        mlp_ratio = cfg.MLP_RATIO
        qkv_bias = cfg.QKV_BIAS
        depth = cfg.DEPTH
        depth_single_blocks = cfg.DEPTH_SINGLE_BLOCKS
        self.use_grad_checkpoint = cfg.get("USE_GRAD_CHECKPOINT", False)
        self.attn_backend = cfg.get("ATTN_BACKEND", "pytorch")
        self.lora_model = cfg.get("DIFFUSERS_LORA_MODEL", None)
        self.swift_lora_model = cfg.get("SWIFT_LORA_MODEL", None)
        self.pretrain_adapter = cfg.get("PRETRAIN_ADAPTER", None)

        if hidden_size % num_heads != 0:
            raise ValueError(
                f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
            )
        pe_dim = hidden_size // num_heads
        if sum(axes_dim) != pe_dim:
            raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
            if self.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    backend=self.attn_backend
                )
                for _ in range(depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads,
                                  mlp_ratio=mlp_ratio, backend=self.attn_backend)
                for _ in range(depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

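    # prepare_input: patchify the latent into 2x2 patches (one token per patch),
    # build 3-component RoPE position ids for the image tokens (axis 0 unused,
    # axes 1/2 are the patch row/column), use all-zero position ids for the text
    # tokens, and cast everything to x's dtype/device.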
    def prepare_input(self, x, context, y, x_shape=None):
        # x.shape [6, 16, 16, 16] target is [6, 16, 768, 1360]
        bs, c, h, w = x.shape
        x = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        x_id = torch.zeros(h // 2, w // 2, 3)
        x_id[..., 1] = x_id[..., 1] + torch.arange(h // 2)[:, None]
        x_id[..., 2] = x_id[..., 2] + torch.arange(w // 2)[None, :]
        x_ids = repeat(x_id, "h w c -> b (h w) c", b=bs)
        txt_ids = torch.zeros(bs, context.shape[1], 3)
        return x, x_ids.to(x), context.to(x), txt_ids.to(x), y.to(x), h, w

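    # unpack: inverse of prepare_input's patchify step; folds the token sequence
    # back into a [B, C, H, W] latent given the pre-patchify height/width.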
    def unpack(self, x: Tensor, height: int, width: int) -> Tensor:
        return rearrange(
            x,
            "b (h w) (c ph pw) -> b c (h ph) (w pw)",
            h=math.ceil(height / 2),
            w=math.ceil(width / 2),
            ph=2,
            pw=2,
        )

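    # merge_diffuser_lora: fold a diffusers-formatted LoRA state dict into this
    # model's state dict. For each mapped weight the delta (lora_B @ lora_A) is
    # computed per projection, deltas for fused layers (e.g. qkv) are
    # concatenated along the output dimension, and the result is added to the
    # base weight scaled by `scale`.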
    def merge_diffuser_lora(self, ori_sd, lora_sd, scale=1.0):
        key_map = {
            "single_blocks.{}.linear1.weight": {"key_list": [
                ["transformer.single_transformer_blocks.{}.attn.to_q.lora_A.weight",
                 "transformer.single_transformer_blocks.{}.attn.to_q.lora_B.weight"],
                ["transformer.single_transformer_blocks.{}.attn.to_k.lora_A.weight",
                 "transformer.single_transformer_blocks.{}.attn.to_k.lora_B.weight"],
                ["transformer.single_transformer_blocks.{}.attn.to_v.lora_A.weight",
                 "transformer.single_transformer_blocks.{}.attn.to_v.lora_B.weight"],
                ["transformer.single_transformer_blocks.{}.proj_mlp.lora_A.weight",
                 "transformer.single_transformer_blocks.{}.proj_mlp.lora_B.weight"]
            ], "num": 38},
            "single_blocks.{}.modulation.lin.weight": {"key_list": [
                ["transformer.single_transformer_blocks.{}.norm.linear.lora_A.weight",
                 "transformer.single_transformer_blocks.{}.norm.linear.lora_B.weight"],
            ], "num": 38},
            "single_blocks.{}.linear2.weight": {"key_list": [
                ["transformer.single_transformer_blocks.{}.proj_out.lora_A.weight",
                 "transformer.single_transformer_blocks.{}.proj_out.lora_B.weight"],
            ], "num": 38},
            "double_blocks.{}.txt_attn.qkv.weight": {"key_list": [
                ["transformer.transformer_blocks.{}.attn.add_q_proj.lora_A.weight",
                 "transformer.transformer_blocks.{}.attn.add_q_proj.lora_B.weight"],
                ["transformer.transformer_blocks.{}.attn.add_k_proj.lora_A.weight",
                 "transformer.transformer_blocks.{}.attn.add_k_proj.lora_B.weight"],
                ["transformer.transformer_blocks.{}.attn.add_v_proj.lora_A.weight",
                 "transformer.transformer_blocks.{}.attn.add_v_proj.lora_B.weight"],
            ], "num": 19},
            "double_blocks.{}.img_attn.qkv.weight": {"key_list": [
                ["transformer.transformer_blocks.{}.attn.to_q.lora_A.weight",
                 "transformer.transformer_blocks.{}.attn.to_q.lora_B.weight"],
                ["transformer.transformer_blocks.{}.attn.to_k.lora_A.weight",
                 "transformer.transformer_blocks.{}.attn.to_k.lora_B.weight"],
                ["transformer.transformer_blocks.{}.attn.to_v.lora_A.weight",
                 "transformer.transformer_blocks.{}.attn.to_v.lora_B.weight"],
            ], "num": 19},
            "double_blocks.{}.img_attn.proj.weight": {"key_list": [
                ["transformer.transformer_blocks.{}.attn.to_out.0.lora_A.weight",
                 "transformer.transformer_blocks.{}.attn.to_out.0.lora_B.weight"]
            ], "num": 19},
            "double_blocks.{}.txt_attn.proj.weight": {"key_list": [
                ["transformer.transformer_blocks.{}.attn.to_add_out.lora_A.weight",
                 "transformer.transformer_blocks.{}.attn.to_add_out.lora_B.weight"]
            ], "num": 19},
            "double_blocks.{}.img_mlp.0.weight": {"key_list": [
                ["transformer.transformer_blocks.{}.ff.net.0.proj.lora_A.weight",
                 "transformer.transformer_blocks.{}.ff.net.0.proj.lora_B.weight"]
            ], "num": 19},
            "double_blocks.{}.img_mlp.2.weight": {"key_list": [
                ["transformer.transformer_blocks.{}.ff.net.2.lora_A.weight",
                 "transformer.transformer_blocks.{}.ff.net.2.lora_B.weight"]
            ], "num": 19},
            "double_blocks.{}.txt_mlp.0.weight": {"key_list": [
                ["transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_A.weight",
                 "transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_B.weight"]
            ], "num": 19},
            "double_blocks.{}.txt_mlp.2.weight": {"key_list": [
                ["transformer.transformer_blocks.{}.ff_context.net.2.lora_A.weight",
                 "transformer.transformer_blocks.{}.ff_context.net.2.lora_B.weight"]
            ], "num": 19},
            "double_blocks.{}.img_mod.lin.weight": {"key_list": [
                ["transformer.transformer_blocks.{}.norm1.linear.lora_A.weight",
                 "transformer.transformer_blocks.{}.norm1.linear.lora_B.weight"]
            ], "num": 19},
            "double_blocks.{}.txt_mod.lin.weight": {"key_list": [
                ["transformer.transformer_blocks.{}.norm1_context.linear.lora_A.weight",
                 "transformer.transformer_blocks.{}.norm1_context.linear.lora_B.weight"]
            ], "num": 19}
        }
        for k, v in key_map.items():
            key_list = v["key_list"]
            block_num = v["num"]
            for block_id in range(block_num):
                current_weight_list = []
                for k_list in key_list:
                    current_weight = torch.matmul(
                        lora_sd[k_list[0].format(block_id)].permute(1, 0),
                        lora_sd[k_list[1].format(block_id)].permute(1, 0)).permute(1, 0)
                    current_weight_list.append(current_weight)
                current_weight = torch.cat(current_weight_list, dim=0)
                ori_sd[k.format(block_id)] += scale * current_weight
        return ori_sd

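    # merge_swift_lora: fold a SWIFT-format LoRA into the state dict. Keys are
    # grouped by their base weight name, the delta lora_B @ lora_A is computed
    # for each pair, and the result is added to the base weight scaled by `scale`.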
    def merge_swift_lora(self, ori_sd, lora_sd, scale=1.0):
        have_lora_keys = {}
        for k, v in lora_sd.items():
            k = k[len("model."):] if k.startswith("model.") else k
            ori_key = k.split("lora")[0] + "weight"
            if ori_key not in ori_sd:
                raise ValueError(f"{ori_key} should be in the original state dict")
            if ori_key not in have_lora_keys:
                have_lora_keys[ori_key] = {}
            if "lora_A" in k:
                have_lora_keys[ori_key]["lora_A"] = v
            elif "lora_B" in k:
                have_lora_keys[ori_key]["lora_B"] = v
            else:
                raise NotImplementedError
        for key, v in have_lora_keys.items():
            current_weight = torch.matmul(v["lora_A"].permute(1, 0),
                                          v["lora_B"].permute(1, 0)).permute(1, 0)
            ori_sd[key] += scale * current_weight
        return ori_sd

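    # load_pretrained_model: fetch the checkpoint through scepter's FS, unwrap
    # nested 'state_dict'/'model' containers, optionally merge diffusers or
    # SWIFT LoRA weights and a pretrained adapter, adapt img_in.weight by
    # copying the first 64 input columns and zero-padding the rest when shapes
    # differ, then load with strict=False and log missing/unexpected keys.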
    def load_pretrained_model(self, pretrained_model):
        if next(self.parameters()).device.type == 'meta':
            map_location = we.device_id
        else:
            map_location = "cpu"
        if self.lora_model is not None:
            map_location = we.device_id
        if pretrained_model is not None:
            with FS.get_from(pretrained_model, wait_finish=True) as local_model:
                if local_model.endswith('safetensors'):
                    from safetensors.torch import load_file as load_safetensors
                    sd = load_safetensors(local_model, device=map_location)
                else:
                    sd = torch.load(local_model, map_location=map_location)
                if "state_dict" in sd:
                    sd = sd["state_dict"]
                if "model" in sd:
                    sd = sd["model"]["model"]

            if self.lora_model is not None:
                with FS.get_from(self.lora_model, wait_finish=True) as local_model:
                    if local_model.endswith('safetensors'):
                        from safetensors.torch import load_file as load_safetensors
                        lora_sd = load_safetensors(local_model, device=map_location)
                    else:
                        lora_sd = torch.load(local_model, map_location=map_location)
                sd = self.merge_diffuser_lora(sd, lora_sd)
            if self.swift_lora_model is not None:
                with FS.get_from(self.swift_lora_model, wait_finish=True) as local_model:
                    if local_model.endswith('safetensors'):
                        from safetensors.torch import load_file as load_safetensors
                        lora_sd = load_safetensors(local_model, device=map_location)
                    else:
                        lora_sd = torch.load(local_model, map_location=map_location)
                sd = self.merge_swift_lora(sd, lora_sd)

            adapter_ckpt = {}
            if self.pretrain_adapter is not None:
                with FS.get_from(self.pretrain_adapter, wait_finish=True) as local_adapter:
                    if local_adapter.endswith('safetensors'):
                        from safetensors.torch import load_file as load_safetensors
                        adapter_ckpt = load_safetensors(local_adapter, device=map_location)
                    else:
                        adapter_ckpt = torch.load(local_adapter, map_location=map_location)
            sd.update(adapter_ckpt)

            new_ckpt = OrderedDict()
            for k, v in sd.items():
                if k == "img_in.weight":
                    model_p = self.state_dict()[k]
                    if v.shape != model_p.shape:
                        model_p.zero_()
                        model_p[:, :64].copy_(v[:, :64])
                        new_ckpt[k] = torch.nn.parameter.Parameter(model_p)
                    else:
                        new_ckpt[k] = v
                else:
                    new_ckpt[k] = v

            missing, unexpected = self.load_state_dict(new_ckpt, strict=False, assign=True)
            self.logger.info(
                f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys'
            )
            if len(missing) > 0:
                self.logger.info(f'Missing Keys:\n {missing}')
            if len(unexpected) > 0:
                self.logger.info(f'\nUnexpected Keys:\n {unexpected}')

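    # forward: patchify the noisy latent, embed the timestep (and optional
    # guidance) plus the pooled text vector into the modulation vector `vec`,
    # run the concatenated [txt, img] token sequence through the double-stream
    # and then single-stream blocks (optionally with sequential gradient
    # checkpointing controlled by gc_seg), drop the text tokens, apply the
    # final layer, and un-patchify back to the latent shape.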
    def forward(
        self,
        x: Tensor,
        t: Tensor,
        cond: dict = {},
        guidance: Tensor | None = None,
        gc_seg: int = 0
    ) -> Tensor:
        x, x_ids, txt, txt_ids, y, h, w = self.prepare_input(x, cond["context"], cond["y"])
        # running on sequences img
        x = self.img_in(x)
        vec = self.time_in(timestep_embedding(t, 256))
        if self.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)
        ids = torch.cat((txt_ids, x_ids), dim=1)
        pe = self.pe_embedder(ids)
        kwargs = dict(
            vec=vec,
            pe=pe,
            txt_length=txt.shape[1],
        )
        x = torch.cat((txt, x), 1)
        if self.use_grad_checkpoint and gc_seg >= 0:
            x = checkpoint_sequential(
                functions=[partial(block, **kwargs) for block in self.double_blocks],
                segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
                input=x,
                use_reentrant=False
            )
        else:
            for block in self.double_blocks:
                x = block(x, **kwargs)

        kwargs = dict(
            vec=vec,
            pe=pe,
        )

        if self.use_grad_checkpoint and gc_seg >= 0:
            x = checkpoint_sequential(
                functions=[partial(block, **kwargs) for block in self.single_blocks],
                segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
                input=x,
                use_reentrant=False
            )
        else:
            for block in self.single_blocks:
                x = block(x, **kwargs)
        x = x[:, txt.shape[1]:, ...]
        x = self.final_layer(x, vec)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpack(x, h, w)
        return x

    @staticmethod
    def get_config_template():
        return dict_to_yaml('MODEL',
                            __class__.__name__,
                            Flux.para_dict,
                            set_name=True)
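
Note (illustrative, not part of the commit): a minimal sketch of the patchify/unpack round trip that prepare_input and unpack implement, using only torch and einops. The shapes B, C, H, W below are made up for the example; it shows how a [B, C, H, W] latent maps to a token sequence of length (H/2)*(W/2) with token dimension C*4, and back.

import math
import torch
from einops import rearrange

B, C, H, W = 2, 16, 64, 64  # hypothetical latent shape
x = torch.randn(B, C, H, W)

# patchify: one token per 2x2 patch, token dim = C * 2 * 2
seq = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
assert seq.shape == (B, (H // 2) * (W // 2), C * 4)

# unpack: fold the sequence back into the original latent layout
x_back = rearrange(seq, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                   h=math.ceil(H / 2), w=math.ceil(W / 2), ph=2, pw=2)
assert torch.equal(x, x_back)  # pure reshape/transpose, so the round trip is exact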