Update app.py
app.py CHANGED
@@ -1,6 +1,5 @@
-import os
+# import os
 import spaces
-from dataclasses import dataclass
 
 import gradio as gr
 import torch
@@ -22,6 +21,7 @@ from torch import Tensor, nn
 from transformers import CLIPTextModel, CLIPTokenizer
 from transformers import T5EncoderModel, T5Tokenizer
 from safetensors.torch import load_file
+# from torch.profiler import profile, record_function, ProfilerActivity
 # from optimum.quanto import freeze, qfloat8, quantize
 
 
@@ -216,18 +216,27 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
     q, k = apply_rope(q, k, pe)
 
     x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
-    x = rearrange(x, "B H L D -> B L (H D)")
+    # x = rearrange(x, "B H L D -> B L (H D)")
+    x = x.permute(0, 2, 1, 3).contiguous().reshape(x.size(0), x.size(2), -1)
 
     return x
 
 
-def rope(pos
-    assert dim % 2 == 0
+def rope(pos, dim, theta):
     scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
-    omega = 1.0 / (theta**scale)
-
-    out = torch.
-    out =
+    omega = 1.0 / (theta ** scale)
+
+    # out = torch.einsum("...n,d->...nd", pos, omega)
+    out = pos.unsqueeze(-1) * omega.unsqueeze(0)
+
+    cos_out = torch.cos(out)
+    sin_out = torch.sin(out)
+    out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
+
+    # out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+    b, n, d, _ = out.shape
+    out = out.view(b, n, d, 2, 2)
+
     return out.float()
 
 
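
The new rope() drops the einsum and einops calls in favor of plain broadcasting and view operations. A minimal sanity check of those equivalences (not part of app.py), assuming pos is 2-D with shape (batch, sequence), as the later "b, n, d, _ = out.shape" unpacking implies:

import torch

b, n, d = 2, 64, 16
pos = torch.randn(b, n, dtype=torch.float64)
omega = torch.randn(d, dtype=torch.float64)

# einsum "...n,d->...nd" versus the broadcasted product used in the new rope()
ref = torch.einsum("...n,d->...nd", pos, omega)
new = pos.unsqueeze(-1) * omega.unsqueeze(0)  # (b, n, 1) * (1, d) -> (b, n, d)
assert torch.allclose(ref, new)

# The attention-output reshape that replaces rearrange(x, "B H L D -> B L (H D)")
B, H, L, D = 2, 4, 8, 16
x = torch.randn(B, H, L, D)
flat = x.permute(0, 2, 1, 3).contiguous().reshape(x.size(0), x.size(2), -1)
assert flat.shape == (B, L, H * D)
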
@@ -267,9 +276,12 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
     """
     t = time_factor * t
     half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        t.device
-    )
+
+    # Does not block the CUDA stream, but differs from the official Flux code by about 1e-4:
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
+
+    # Blocks the CUDA stream, but is consistent with the official code:
+    # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
 
     args = t[:, None].float() * freqs[None]
     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
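
In timestep_embedding, the two variants kept in the diff differ only in where the arange is materialized: the new one builds the frequency table directly on t's device, while the old one builds it on the CPU and copies it over with .to(t.device). The committed comment says the device-side version avoids stalling the CUDA stream at the cost of roughly 1e-4 differences from the official Flux code. A small side-by-side sketch (not part of app.py):

import math
import torch

def freqs_on_device(t: torch.Tensor, half: int, max_period: int = 10000) -> torch.Tensor:
    # New variant: the arange lives on t.device, so no host-to-device copy is needed.
    return torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32, device=t.device) / half)

def freqs_via_copy(t: torch.Tensor, half: int, max_period: int = 10000) -> torch.Tensor:
    # Old variant: computed on the CPU, then copied to t.device.
    return torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32) / half).to(t.device)

t = torch.zeros(4)
print((freqs_on_device(t, 128) - freqs_via_copy(t, 128)).abs().max())  # identical on CPU; small differences can appear on CUDA
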
@@ -327,7 +339,10 @@ class SelfAttention(nn.Module):
 
     def forward(self, x: Tensor, pe: Tensor) -> Tensor:
         qkv = self.qkv(x)
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = qkv.shape
+        qkv = qkv.view(B, L, 3, self.num_heads, -1)
+        q, k, v = qkv.permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
         x = attention(q, k, v, pe=pe)
         x = self.proj(x)
@@ -394,14 +409,20 @@ class DoubleStreamBlock(nn.Module):
         img_modulated = self.img_norm1(img)
         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        # img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = img_qkv.shape
+        H = self.num_heads
+        D = img_qkv.shape[-1] // (3 * H)
+        img_q, img_k, img_v = img_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
 
         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        # txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        B, L, _ = txt_qkv.shape
+        txt_q, txt_k, txt_v = txt_qkv.view(B, L, 3, H, D).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
 
         # run actual attention
@@ -460,7 +481,9 @@ class SingleStreamBlock(nn.Module):
         x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
 
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        # q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        qkv = qkv.view(qkv.size(0), qkv.size(1), 3, self.num_heads, self.hidden_size // self.num_heads)
+        q, k, v = qkv.permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
 
         # compute attention
@@ -677,9 +700,7 @@ def denoise(
             timesteps=t_vec,
             guidance=guidance_vec,
         )
-
         img = img + (t_prev - t_curr) * pred
-
     return img
 
 
@@ -723,7 +744,7 @@ from safetensors.torch import load_file
 
 sd = load_file(hf_hub_download(repo_id="lllyasviel/flux1-dev-bnb-nf4", filename="flux1-dev-bnb-nf4.safetensors"))
 sd = {k.replace("model.diffusion_model.", ""): v for k, v in sd.items() if "model.diffusion_model" in k}
-model = Flux().to(dtype=torch.
+model = Flux().to(dtype=torch.bfloat16, device="cuda")
 result = model.load_state_dict(sd)
 print(result)
 
@@ -731,7 +752,7 @@ print(result)
 # result = model.load_state_dict(load_file("/storage/dev/nyanko/flux-dev/flux1-dev.sft"))
 
 @spaces.GPU
-@torch.
+@torch.no_grad()
 def generate_image(
     prompt, width, height, guidance, seed,
     do_img2img, init_image, image2image_strength, resize_img,
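
The new @torch.no_grad() decorator runs the whole generation function without autograd tracking, the usual choice for inference: no graph is built, so activation memory is not kept around. A tiny illustration (not part of app.py):

import torch

x = torch.ones(3, requires_grad=True)
with torch.no_grad():
    y = x * 2  # no autograd graph is recorded inside no_grad
print(y.requires_grad)  # False
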
@@ -742,7 +763,7 @@ def generate_image(
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
     torch_device = torch.device(device)
-
+
     global model
     model = model.to(torch_device)
 
@@ -761,7 +782,7 @@ def generate_image(
     generator = torch.Generator(device=device).manual_seed(seed)
     x = torch.randn(1, 16, 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), device=device, dtype=torch.bfloat16, generator=generator)
 
-    num_steps =
+    num_steps = 20
     timesteps = get_schedule(num_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=True)
 
     if do_img2img and init_image is not None:
@@ -770,13 +791,16 @@
         timesteps = timesteps[t_idx:]
         x = t * x + (1.0 - t) * init_image.to(x.dtype)
 
-
-
-
-
-
-
-
+    inp = prepare(t5=t5, clip=clip, img=x, prompt=prompt)
+    x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
+
+    # with profile(activities=[ProfilerActivity.CPU],record_shapes=True,profile_memory=True) as prof:
+    #     print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
+
+    x = unpack(x.float(), height, width)
+    with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
+        x = (x / ae.config.scaling_factor) + ae.config.shift_factor
+        x = ae.decode(x).sample
 
     x = x.clamp(-1, 1)
     x = rearrange(x[0], "c h w -> h w c")