nyanko7 commited on
Commit
e11ace5
·
verified ·
1 Parent(s): 33db744

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -29
app.py CHANGED
@@ -1,6 +1,5 @@
1
- import os
2
  import spaces
3
- from dataclasses import dataclass
4
 
5
  import gradio as gr
6
  import torch
@@ -22,6 +21,7 @@ from torch import Tensor, nn
22
  from transformers import CLIPTextModel, CLIPTokenizer
23
  from transformers import T5EncoderModel, T5Tokenizer
24
  from safetensors.torch import load_file
 
25
  # from optimum.quanto import freeze, qfloat8, quantize
26
 
27
 
@@ -216,18 +216,27 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
216
  q, k = apply_rope(q, k, pe)
217
 
218
  x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
219
- x = rearrange(x, "B H L D -> B L (H D)")
 
220
 
221
  return x
222
 
223
 
224
- def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
225
- assert dim % 2 == 0
226
  scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
227
- omega = 1.0 / (theta**scale)
228
- out = torch.einsum("...n,d->...nd", pos, omega)
229
- out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
230
- out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
 
 
 
 
 
 
 
 
 
231
  return out.float()
232
 
233
 
@@ -267,9 +276,12 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
267
  """
268
  t = time_factor * t
269
  half = dim // 2
270
- freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
271
- t.device
272
- )
 
 
 
273
 
274
  args = t[:, None].float() * freqs[None]
275
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@@ -327,7 +339,10 @@ class SelfAttention(nn.Module):
327
 
328
  def forward(self, x: Tensor, pe: Tensor) -> Tensor:
329
  qkv = self.qkv(x)
330
- q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
 
 
 
331
  q, k = self.norm(q, k, v)
332
  x = attention(q, k, v, pe=pe)
333
  x = self.proj(x)
@@ -394,14 +409,20 @@ class DoubleStreamBlock(nn.Module):
394
  img_modulated = self.img_norm1(img)
395
  img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
396
  img_qkv = self.img_attn.qkv(img_modulated)
397
- img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
 
 
 
 
398
  img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
399
 
400
  # prepare txt for attention
401
  txt_modulated = self.txt_norm1(txt)
402
  txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
403
  txt_qkv = self.txt_attn.qkv(txt_modulated)
404
- txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
 
 
405
  txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
406
 
407
  # run actual attention
@@ -460,7 +481,9 @@ class SingleStreamBlock(nn.Module):
460
  x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
461
  qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
462
 
463
- q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
 
 
464
  q, k = self.norm(q, k, v)
465
 
466
  # compute attention
@@ -677,9 +700,7 @@ def denoise(
677
  timesteps=t_vec,
678
  guidance=guidance_vec,
679
  )
680
-
681
  img = img + (t_prev - t_curr) * pred
682
-
683
  return img
684
 
685
 
@@ -723,7 +744,7 @@ from safetensors.torch import load_file
723
 
724
  sd = load_file(hf_hub_download(repo_id="lllyasviel/flux1-dev-bnb-nf4", filename="flux1-dev-bnb-nf4.safetensors"))
725
  sd = {k.replace("model.diffusion_model.", ""): v for k, v in sd.items() if "model.diffusion_model" in k}
726
- model = Flux().to(dtype=torch.float16, device="cuda")
727
  result = model.load_state_dict(sd)
728
  print(result)
729
 
@@ -731,7 +752,7 @@ print(result)
731
  # result = model.load_state_dict(load_file("/storage/dev/nyanko/flux-dev/flux1-dev.sft"))
732
 
733
  @spaces.GPU
734
- @torch.inference_mode()
735
  def generate_image(
736
  prompt, width, height, guidance, seed,
737
  do_img2img, init_image, image2image_strength, resize_img,
@@ -742,7 +763,7 @@ def generate_image(
742
 
743
  device = "cuda" if torch.cuda.is_available() else "cpu"
744
  torch_device = torch.device(device)
745
-
746
  global model
747
  model = model.to(torch_device)
748
 
@@ -761,7 +782,7 @@ def generate_image(
761
  generator = torch.Generator(device=device).manual_seed(seed)
762
  x = torch.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), device=device, dtype=torch.bfloat16, generator=generator)
763
 
764
- num_steps = 25
765
  timesteps = get_schedule(num_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True)
766
 
767
  if do_img2img and init_image is not None:
@@ -770,13 +791,16 @@ def generate_image(
770
  timesteps = timesteps[t_idx:]
771
  x = t * x + (1.0 - t) * init_image.to(x.dtype)
772
 
773
- with torch_device:
774
- inp = prepare(t5=t5, clip=clip, img=x, prompt=prompt)
775
- x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
776
- x = unpack(x.float(), height, width)
777
- with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
778
- x = x = (x / ae.config.scaling_factor) + ae.config.shift_factor
779
- x = ae.decode(x).sample
 
 
 
780
 
781
  x = x.clamp(-1, 1)
782
  x = rearrange(x[0], "c h w -> h w c")
 
1
+ # import os
2
  import spaces
 
3
 
4
  import gradio as gr
5
  import torch
 
21
  from transformers import CLIPTextModel, CLIPTokenizer
22
  from transformers import T5EncoderModel, T5Tokenizer
23
  from safetensors.torch import load_file
24
+ # from torch.profiler import profile, record_function, ProfilerActivity
25
  # from optimum.quanto import freeze, qfloat8, quantize
26
 
27
 
 
216
  q, k = apply_rope(q, k, pe)
217
 
218
  x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
219
+ # x = rearrange(x, "B H L D -> B L (H D)")
220
+ x = x.permute(0, 2, 1, 3).contiguous().reshape(x.size(0), x.size(2), -1)
221
 
222
  return x
223
 
224
 
225
+ def rope(pos, dim, theta):
 
226
  scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
227
+ omega = 1.0 / (theta ** scale)
228
+
229
+ # out = torch.einsum("...n,d->...nd", pos, omega)
230
+ out = pos.unsqueeze(-1) * omega.unsqueeze(0)
231
+
232
+ cos_out = torch.cos(out)
233
+ sin_out = torch.sin(out)
234
+ out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
235
+
236
+ # out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
237
+ b, n, d, _ = out.shape
238
+ out = out.view(b, n, d, 2, 2)
239
+
240
  return out.float()
241
 
242
 
 
276
  """
277
  t = time_factor * t
278
  half = dim // 2
279
+
280
+ # Do not block CUDA steam, but having about 1e-4 differences with Flux official codes:
281
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
282
+
283
+ # Block CUDA steam, but consistent with official codes:
284
+ # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
285
 
286
  args = t[:, None].float() * freqs[None]
287
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
 
339
 
340
  def forward(self, x: Tensor, pe: Tensor) -> Tensor:
341
  qkv = self.qkv(x)
342
+ # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
343
+ B, L, _ = qkv.shape
344
+ qkv = qkv.view(B, L, 3, self.num_heads, -1)
345
+ q, k, v = qkv.permute(2, 0, 3, 1, 4)
346
  q, k = self.norm(q, k, v)
347
  x = attention(q, k, v, pe=pe)
348
  x = self.proj(x)
 
409
  img_modulated = self.img_norm1(img)
410
  img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
411
  img_qkv = self.img_attn.qkv(img_modulated)
412
+ # img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
413
+ B, L, _ = img_qkv.shape
414
+ H = self.num_heads
415
+ D = img_qkv.shape[-1] // (3 * H)
416
+ img_q, img_k, img_v = img_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
417
  img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
418
 
419
  # prepare txt for attention
420
  txt_modulated = self.txt_norm1(txt)
421
  txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
422
  txt_qkv = self.txt_attn.qkv(txt_modulated)
423
+ # txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
424
+ B, L, _ = txt_qkv.shape
425
+ txt_q, txt_k, txt_v = txt_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
426
  txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
427
 
428
  # run actual attention
 
481
  x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
482
  qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
483
 
484
+ # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
485
+ qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, self.hidden_size // self.num_heads)
486
+ q, k, v = qkv.permute(2, 0, 3, 1, 4)
487
  q, k = self.norm(q, k, v)
488
 
489
  # compute attention
 
700
  timesteps=t_vec,
701
  guidance=guidance_vec,
702
  )
 
703
  img = img + (t_prev - t_curr) * pred
 
704
  return img
705
 
706
 
 
744
 
745
  sd = load_file(hf_hub_download(repo_id="lllyasviel/flux1-dev-bnb-nf4", filename="flux1-dev-bnb-nf4.safetensors"))
746
  sd = {k.replace("model.diffusion_model.", ""): v for k, v in sd.items() if "model.diffusion_model" in k}
747
+ model = Flux().to(dtype=torch.bfloat16, device="cuda")
748
  result = model.load_state_dict(sd)
749
  print(result)
750
 
 
752
  # result = model.load_state_dict(load_file("/storage/dev/nyanko/flux-dev/flux1-dev.sft"))
753
 
754
  @spaces.GPU
755
+ @torch.no_grad()
756
  def generate_image(
757
  prompt, width, height, guidance, seed,
758
  do_img2img, init_image, image2image_strength, resize_img,
 
763
 
764
  device = "cuda" if torch.cuda.is_available() else "cpu"
765
  torch_device = torch.device(device)
766
+
767
  global model
768
  model = model.to(torch_device)
769
 
 
782
  generator = torch.Generator(device=device).manual_seed(seed)
783
  x = torch.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), device=device, dtype=torch.bfloat16, generator=generator)
784
 
785
+ num_steps = 20
786
  timesteps = get_schedule(num_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True)
787
 
788
  if do_img2img and init_image is not None:
 
791
  timesteps = timesteps[t_idx:]
792
  x = t * x + (1.0 - t) * init_image.to(x.dtype)
793
 
794
+ inp = prepare(t5=t5, clip=clip, img=x, prompt=prompt)
795
+ x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
796
+
797
+ # with profile(activities=[ProfilerActivity.CPU],record_shapes=True,profile_memory=True) as prof:
798
+ # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
799
+
800
+ x = unpack(x.float(), height, width)
801
+ with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
802
+ x = x = (x / ae.config.scaling_factor) + ae.config.shift_factor
803
+ x = ae.decode(x).sample
804
 
805
  x = x.clamp(-1, 1)
806
  x = rearrange(x[0], "c h w -> h w c")