Spaces:
Running
on
Zero
Running
on
Zero
Update model/flux.py
Browse files- model/flux.py +280 -0
model/flux.py
CHANGED
|
@@ -10,6 +10,7 @@ from scepter.modules.utils.config import dict_to_yaml
|
|
| 10 |
from scepter.modules.utils.distribute import we
|
| 11 |
from scepter.modules.utils.file_system import FS
|
| 12 |
from torch import Tensor, nn
|
|
|
|
| 13 |
from torch.utils.checkpoint import checkpoint_sequential
|
| 14 |
|
| 15 |
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
|
@@ -404,4 +405,283 @@ class Flux(BaseModel):
|
|
| 404 |
return dict_to_yaml('MODEL',
|
| 405 |
__class__.__name__,
|
| 406 |
Flux.para_dict,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
set_name=True)
|
|
|
|
| 10 |
from scepter.modules.utils.distribute import we
|
| 11 |
from scepter.modules.utils.file_system import FS
|
| 12 |
from torch import Tensor, nn
|
| 13 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 14 |
from torch.utils.checkpoint import checkpoint_sequential
|
| 15 |
|
| 16 |
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
|
|
|
| 405 |
return dict_to_yaml('MODEL',
|
| 406 |
__class__.__name__,
|
| 407 |
Flux.para_dict,
|
| 408 |
+
set_name=True)
|
| 409 |
+
|
| 410 |
+
@BACKBONES.register_class()
class FluxMR(Flux):
    """Flux variant that processes a batch of images with different resolutions.

    Each sample arrives as a flattened latent together with its true
    ``(height, width)`` in ``cond["x_shapes"]``. Samples are cut into 2x2
    patches, projected, right-padded to a common token length and run
    through the transformer with an attention mask that hides the padding.
    """

    def prepare_input(self, x, cond):
        """Turn packed latents and conditioning into padded token sequences.

        Args:
            x: Iterable of per-sample latents. Each item is sliced as
                ``item[:, :h * w]`` and viewed as ``(-1, h, w)``, so it is
                presumably ``(channels, flattened_pixels)`` -- TODO confirm
                against the caller.
            cond: Dict providing ``"context"``, ``"y"`` (a tensor each, or a
                list of tensors that is concatenated along dim 0) and
                ``"x_shapes"`` (one ``(h, w)`` pair per sample).

        Returns:
            Tuple ``(x, x_ids, txt, txt_ids, y, mask_x, mask_txt,
            x_seq_length)`` where ``x``/``x_ids``/``mask_x`` are padded
            batch-first and ``x_seq_length`` lists the unpadded token count
            of every sample.
        """
        if isinstance(cond['context'], list):
            context = torch.cat(cond["context"], dim=0).to(x)
            y = torch.cat(cond["y"], dim=0).to(x)
        else:
            context, y = cond['context'].to(x), cond['y'].to(x)

        x_list, x_id_list, mask_x_list, x_seq_length = [], [], [], []
        for sample, shape in zip(x, cond["x_shapes"]):
            # Drop the padding of the packed latent and restore its 2-D layout.
            image = sample[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
            _, h, w = image.shape
            tokens = rearrange(image,
                               "c (h ph) (w pw) -> (h w) (c ph pw)",
                               ph=2,
                               pw=2)

            # Position ids: channel 0 stays zero, channel 1 is the patch row
            # and channel 2 the patch column.
            ids = torch.zeros(h // 2, w // 2, 3)
            ids[..., 1] = torch.arange(h // 2)[:, None]
            ids[..., 2] = torch.arange(w // 2)[None, :]
            ids = rearrange(ids, "h w c -> (h w) c")

            tokens = self.img_in(tokens)
            num_tokens = tokens.shape[0]
            x_list.append(tokens)
            x_id_list.append(ids)
            mask_x_list.append(
                torch.ones(num_tokens, dtype=torch.bool, device=tokens.device))
            x_seq_length.append(num_tokens)

        # pad_sequence right-pads with zeros, i.e. False for the masks.
        x_tokens = pad_sequence(tuple(x_list), batch_first=True)
        # [b, pad_seq, 3]; padded positions carry all-zero ids.
        x_ids = pad_sequence(tuple(x_id_list), batch_first=True).to(x_tokens)
        mask_x = pad_sequence(tuple(mask_x_list), batch_first=True)

        txt = self.txt_in(context)
        batch, txt_len = context.shape[0], context.shape[1]
        txt_ids = x_tokens.new_zeros(batch, txt_len, 3)
        mask_txt = torch.ones(batch,
                              txt_len,
                              dtype=torch.bool,
                              device=x_tokens.device)

        return x_tokens, x_ids, txt, txt_ids, y, mask_x, mask_txt, x_seq_length

    def unpack(self, x: Tensor, cond: dict = None, x_seq_length: list = None) -> Tensor:
        """Convert per-sample output tokens back to flattened latents.

        For every sample the last ``h * w`` tokens of its valid range
        (``[:seq_length]``) are un-patchified, where ``h``/``w`` are half the
        size stored in ``cond["x_shapes"]`` (rounded up). The results are
        right-padded to a common length and returned channel-first.

        Args:
            x: Batch of token sequences produced by the final layer.
            cond: Dict providing ``"x_shapes"``.
            x_seq_length: Unpadded token count per sample, as returned by
                ``prepare_input``.

        Returns:
            Tensor laid out as ``(batch, channels, padded_pixels)``.
        """
        x_list = []
        for u, shape, seq_length in zip(x, cond["x_shapes"], x_seq_length):
            height, width = shape
            h, w = math.ceil(height / 2), math.ceil(width / 2)
            x_list.append(
                rearrange(
                    u[seq_length - h * w:seq_length, ...],
                    "(h w) (c ph pw) -> (h ph w pw) c",
                    h=h,
                    w=w,
                    ph=2,
                    pw=2,
                ))
        return pad_sequence(tuple(x_list), batch_first=True).permute(0, 2, 1)

    def forward(
        self,
        x: Tensor,
        t: Tensor,
        cond: dict = {},  # only read, never mutated
        guidance: Tensor | None = None,
        gc_seg: int = 0,
        **kwargs
    ) -> Tensor:
        """Run the transformer on a mixed-resolution batch.

        Args:
            x: Packed latents, see ``prepare_input``.
            t: Timesteps fed to ``timestep_embedding``.
            cond: Conditioning dict, see ``prepare_input``.
            guidance: Guidance strength; required when
                ``self.guidance_embed`` is set.
            gc_seg: Gradient-checkpoint segment count. A negative value
                disables checkpointing, ``0`` uses one segment per block.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            The output of ``unpack`` for the predicted tokens.

        Raises:
            ValueError: If the model embeds guidance and ``guidance`` is None.
        """
        x, x_ids, txt, txt_ids, y, mask_x, mask_txt, seq_length_list = \
            self.prepare_input(x, cond)

        # Conditioning vector: timestep (+ guidance) + pooled text vector.
        vec = self.time_in(timestep_embedding(t, 256))
        if self.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)

        pe = self.pe_embedder(torch.cat((txt_ids, x_ids), dim=1))

        # mask[b, i, j] is True only when tokens i and j are both real
        # (non-padding) tokens of sample b.
        mask_aside = torch.cat((mask_txt, mask_x), dim=1)
        mask = mask_aside[:, None, :] * mask_aside[:, :, None]

        txt_length = txt.shape[1]
        x = torch.cat((txt, x), 1)

        # Both stacks take the same arguments; the double-stream blocks
        # additionally receive the text length.
        stages = (
            (self.double_blocks,
             dict(vec=vec, pe=pe, mask=mask, txt_length=txt_length)),
            (self.single_blocks, dict(vec=vec, pe=pe, mask=mask)),
        )
        for blocks, block_kwargs in stages:
            if self.use_grad_checkpoint and gc_seg >= 0:
                x = checkpoint_sequential(
                    functions=[partial(block, **block_kwargs) for block in blocks],
                    segments=gc_seg if gc_seg > 0 else len(blocks),
                    input=x,
                    use_reentrant=False
                )
            else:
                for block in blocks:
                    x = block(x, **block_kwargs)

        # Keep only the image tokens.
        x = x[:, txt_length:, ...]
        x = self.final_layer(x, vec)  # (N, T, patch_size ** 2 * out_channels)
        return self.unpack(x, cond, seq_length_list)

    @staticmethod
    def get_config_template():
        """Return the YAML configuration template of this model."""
        # Use this class's own para_dict rather than reaching into the
        # FluxEdit subclass, which is only defined further down the module.
        return dict_to_yaml('MODEL',
                            __class__.__name__,
                            FluxMR.para_dict,
                            set_name=True)
|
| 538 |
+
@BACKBONES.register_class()
class FluxEdit(FluxMR):
    """FluxMR variant that appends optional edit frames to every sample.

    The tokens of each frame listed in ``cond["edit_x"]`` are concatenated
    after the tokens of the target image of the same sample. Their row
    position ids are shifted by the target's patch-row count.
    """

    def prepare_input(self, x, cond, *args, **kwargs):
        """Turn packed latents, edit frames and text into padded sequences.

        Args:
            x: Iterable of per-sample latents. Each item is sliced as
                ``item[:, :h * w]`` and viewed as ``(-1, h, w)``.
            cond: Dict providing ``"context"`` and ``"y"`` (one entry per
                sample; each context entry is iterated frame by frame),
                ``"x_shapes"``, ``"align"`` and optionally ``"edit_x"`` (per
                sample either ``None`` or an iterable of frames that are
                3-D after ``squeeze(0)``).
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            Tuple ``(x, x_ids, txt, txt_ids, y, mask_x, mask_txt,
            x_seq_length)``; sequences are padded batch-first and
            ``x_seq_length`` lists the unpadded image-token count per sample.
        """

        def patchify(image, row_offset=0):
            # Split a (c, h, w) image into 2x2 patches and build matching
            # position ids: channel 0 stays zero, channel 1 is the (shifted)
            # patch row and channel 2 the patch column.
            _, h, w = image.shape
            tokens = rearrange(image,
                               "c (h ph) (w pw) -> (h w) (c ph pw)",
                               ph=2,
                               pw=2)
            ids = torch.zeros(h // 2, w // 2, 3)
            ids[..., 1] = torch.arange(row_offset, h // 2 + row_offset)[:, None]
            ids[..., 2] = torch.arange(w // 2)[None, :]
            return tokens, rearrange(ids, "h w c -> (h w) c")

        context, y = cond["context"], cond["y"]
        batch_frames, batch_frames_ids, batch_shift = [], [], []

        # cond['align'] is still zipped so the iteration length is unchanged,
        # but its value is currently not used.
        for sample, shape, _is_align in zip(x, cond["x_shapes"], cond['align']):
            # Drop the padding of the packed latent and restore its 2-D layout.
            image = sample[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
            tokens, ids = patchify(image)
            # Edit frames of this sample start right below the target image.
            batch_shift.append(image.shape[1] // 2)
            batch_frames.append([tokens])
            batch_frames_ids.append([ids])

        if 'edit_x' in cond:
            for i, edit in enumerate(cond['edit_x']):
                if edit is None:
                    continue
                for frame in edit:
                    tokens, ids = patchify(frame.squeeze(0),
                                           row_offset=batch_shift[i])
                    batch_frames[i].append(tokens)
                    batch_frames_ids[i].append(ids)

        x_list, x_id_list, mask_x_list, x_seq_length = [], [], [], []
        for frames, frame_ids in zip(batch_frames, batch_frames_ids):
            sample_tokens = torch.cat([self.img_in(frame) for frame in frames],
                                      dim=0)
            num_tokens = sample_tokens.shape[0]
            x_list.append(sample_tokens)
            x_id_list.append(torch.cat(frame_ids, dim=0))
            mask_x_list.append(
                torch.ones(num_tokens,
                           dtype=torch.bool,
                           device=sample_tokens.device))
            x_seq_length.append(num_tokens)

        # pad_sequence right-pads with zeros, i.e. False for the masks.
        x_tokens = pad_sequence(tuple(x_list), batch_first=True)
        # [b, pad_seq, 3]; padded positions carry all-zero ids.
        x_ids = pad_sequence(tuple(x_id_list), batch_first=True).to(x_tokens)
        mask_x = pad_sequence(tuple(mask_x_list), batch_first=True)

        txt_list, mask_txt_list, y_list = [], [], []
        for ctx, yy in zip(context, y):
            sample_txt = torch.cat(
                [self.txt_in(one_ctx.to(x_tokens)) for one_ctx in ctx], dim=0)
            txt_list.append(sample_txt)
            # Build the mask on the device of the projected text so it can be
            # concatenated with mask_x, wherever the raw context lives.
            mask_txt_list.append(
                torch.ones(sample_txt.shape[0],
                           dtype=torch.bool,
                           device=sample_txt.device))
            # One pooled vector per sample, averaged over dim 0 and moved to
            # the compute device.
            y_list.append(yy.mean(dim=0, keepdim=True).to(x_tokens.device))

        txt = pad_sequence(tuple(txt_list), batch_first=True)
        txt_ids = x_tokens.new_zeros(txt.shape[0], txt.shape[1], 3)
        mask_txt = pad_sequence(tuple(mask_txt_list), batch_first=True)
        y = torch.cat(y_list, dim=0)
        return x_tokens, x_ids, txt, txt_ids, y, mask_x, mask_txt, x_seq_length

    def unpack(self, x: Tensor, cond: dict = None, x_seq_length: list = None) -> Tensor:
        """Convert the target-image tokens of each sample back to latents.

        Only the first ``h * w`` tokens of every sample are kept (the target
        image precedes its edit frames), where ``h``/``w`` are half the size
        stored in ``cond["x_shapes"]`` (rounded up).

        Args:
            x: Batch of token sequences produced by the final layer.
            cond: Dict providing ``"x_shapes"``.
            x_seq_length: One entry per sample; only its length takes part
                in the iteration.

        Returns:
            Tensor laid out as ``(batch, channels, padded_pixels)``.
        """
        x_list = []
        for u, shape, _seq_length in zip(x, cond["x_shapes"], x_seq_length):
            height, width = shape
            h, w = math.ceil(height / 2), math.ceil(width / 2)
            x_list.append(
                rearrange(
                    u[:h * w, ...],
                    "(h w) (c ph pw) -> (h ph w pw) c",
                    h=h,
                    w=w,
                    ph=2,
                    pw=2,
                ))
        return pad_sequence(tuple(x_list), batch_first=True).permute(0, 2, 1)

    def forward(
        self,
        x: Tensor,
        t: Tensor,
        cond: dict = {},  # only read, never mutated
        guidance: Tensor | None = None,
        gc_seg: int = 0,
        text_position_embeddings=None
    ) -> Tensor:
        """Run the transformer on targets plus their edit frames.

        Args:
            x: Packed latents, see ``prepare_input``.
            t: Timesteps fed to ``timestep_embedding``.
            cond: Conditioning dict, see ``prepare_input``.
            guidance: Guidance strength; required when
                ``self.guidance_embed`` is set.
            gc_seg: Gradient-checkpoint segment count. A negative value
                disables checkpointing, ``0`` uses one segment per block.
            text_position_embeddings: Forwarded positionally to
                ``prepare_input``, which currently ignores it.

        Returns:
            The output of ``unpack`` for the predicted tokens.

        Raises:
            ValueError: If the model embeds guidance and ``guidance`` is None.
        """
        x, x_ids, txt, txt_ids, y, mask_x, mask_txt, seq_length_list = \
            self.prepare_input(x, cond, text_position_embeddings)

        # Conditioning vector: timestep (+ guidance) + pooled text vector.
        vec = self.time_in(timestep_embedding(t, 256))
        if self.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)

        pe = self.pe_embedder(torch.cat((txt_ids, x_ids), dim=1))

        # mask[b, i, j] is True only when tokens i and j are both real
        # (non-padding) tokens of sample b.
        mask_aside = torch.cat((mask_txt, mask_x), dim=1)
        mask = mask_aside[:, None, :] * mask_aside[:, :, None]

        txt_length = txt.shape[1]
        x = torch.cat((txt, x), 1)

        # Both stacks take the same arguments; the double-stream blocks
        # additionally receive the text length.
        stages = (
            (self.double_blocks,
             dict(vec=vec, pe=pe, mask=mask, txt_length=txt_length)),
            (self.single_blocks, dict(vec=vec, pe=pe, mask=mask)),
        )
        for blocks, block_kwargs in stages:
            if self.use_grad_checkpoint and gc_seg >= 0:
                x = checkpoint_sequential(
                    functions=[partial(block, **block_kwargs) for block in blocks],
                    segments=gc_seg if gc_seg > 0 else len(blocks),
                    input=x,
                    use_reentrant=False
                )
            else:
                for block in blocks:
                    x = block(x, **block_kwargs)

        # Keep only the image tokens.
        x = x[:, txt_length:, ...]
        x = self.final_layer(x, vec)  # (N, T, patch_size ** 2 * out_channels)
        return self.unpack(x, cond, seq_length_list)

    @staticmethod
    def get_config_template():
        """Return the YAML configuration template of this model."""
        return dict_to_yaml('MODEL',
                            __class__.__name__,
                            FluxEdit.para_dict,
                            set_name=True)
|