ginipick committed · Commit a7a4022 · verified · Parent: fece974

Update app.py

Files changed (1): app.py (+38 -26)
app.py CHANGED
@@ -1,5 +1,6 @@
 import os
-import spaces
+# Set environment variable before importing torch to avoid nested tensor issues
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 
 import time
 import gradio as gr
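
Note: the replacement comment assumes torch consults PYTORCH_ENABLE_MPS_FALLBACK when it initializes, so the assignment has to precede the first `import torch` anywhere in the process. A minimal sketch of that ordering:

import os

# Must run before torch is imported anywhere in this process; once torch
# has initialized its backends, changing the variable has no effect.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402  (deliberately after the environment setup)

x = torch.ones(2)  # on Apple MPS, unsupported ops now fall back to CPU instead of raising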
@@ -14,7 +15,6 @@ from tqdm import tqdm
 import bitsandbytes as bnb
 from bitsandbytes.nn.modules import Params4bit, QuantState
 
-import torch
 import random
 from einops import rearrange, repeat
 from diffusers import AutoencoderKL
@@ -22,6 +22,9 @@ from torch import Tensor, nn
 from transformers import CLIPTextModel, CLIPTokenizer
 from transformers import T5EncoderModel, T5Tokenizer
 
+# Import spaces after other imports to minimize conflicts
+import spaces
+
 # ---------------- Encoders ----------------
 
 class HFEmbedder(nn.Module):
@@ -58,10 +61,32 @@ class HFEmbedder(nn.Module):
         )
         return outputs[self.output_key]
 
-device = "cuda"
-t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16).to(device)
-clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
-ae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
+# Initialize models without GPU decorator first
+device = "cuda" if torch.cuda.is_available() else "cpu"
+t5 = None
+clip = None
+ae = None
+model = None
+model_initialized = False
+
+def initialize_models():
+    global t5, clip, ae, model, model_initialized
+    if not model_initialized:
+        print("Initializing models...")
+        t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16).to(device)
+        clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
+        ae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
+
+        # Load the NF4 quantized checkpoint
+        from huggingface_hub import hf_hub_download
+        from safetensors.torch import load_file
+
+        sd = load_file(hf_hub_download(repo_id="lllyasviel/flux1-dev-bnb-nf4", filename="flux1-dev-bnb-nf4-v2.safetensors"))
+        sd = {k.replace("model.diffusion_model.", ""): v for k, v in sd.items() if "model.diffusion_model" in k}
+        model = Flux().to(dtype=torch.bfloat16, device=device)
+        result = model.load_state_dict(sd)
+        model_initialized = True
+        print("Models initialized successfully!")
 
 # ---------------- NF4 ----------------
 
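Note: the initializer moved here bundles two reusable patterns: a guarded lazy load (construct once, cache in module globals) and key remapping so the safetensors entries match the standalone Flux module's parameter names. A condensed sketch of both, with the function name get_nf4_state_dict chosen for illustration:

from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

_state_dict = None  # module-level cache, filled on first call

def get_nf4_state_dict():
    """Download the NF4 checkpoint once and remap its keys."""
    global _state_dict
    if _state_dict is None:
        path = hf_hub_download(
            repo_id="lllyasviel/flux1-dev-bnb-nf4",
            filename="flux1-dev-bnb-nf4-v2.safetensors",
        )
        raw = load_file(path)
        # Checkpoint keys look like "model.diffusion_model.<name>"; strip
        # the prefix so they match Flux's own state dict.
        _state_dict = {
            k.replace("model.diffusion_model.", ""): v
            for k, v in raw.items()
            if "model.diffusion_model" in k
        }
    return _state_dict
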
@@ -163,7 +188,7 @@ class Linear(ForgeLoader4Bit):
             self.bias.data = self.bias.data.to(x.dtype)
         return functional_linear_4bits(x, self.weight, self.bias)
 
-import torch.nn as nn
+# Override Linear after all torch imports are done
 nn.Linear = Linear
 
 # ---------------- Model ----------------
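
Note: nn.Linear = Linear is a module-attribute monkeypatch: any class defined after that line resolves nn.Linear at instantiation time and therefore builds the 4-bit variant. A minimal sketch of why the ordering matters, using a tagged stand-in instead of the real 4-bit layer:

import torch.nn as nn

class TaggedLinear(nn.Linear):
    """Stand-in for the 4-bit Linear; same behavior, just tagged."""
    patched = True

nn.Linear = TaggedLinear  # later references to nn.Linear now resolve here

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)  # looked up at runtime -> TaggedLinear

print(getattr(TinyModel().proj, "patched", False))  # prints: True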
@@ -608,34 +633,22 @@ def get_image(image) -> torch.Tensor | None:
     img: torch.Tensor = transform(image)
     return img[None, ...]
 
-# Load the NF4 quantized checkpoint
-from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file
-
-sd = load_file(hf_hub_download(repo_id="lllyasviel/flux1-dev-bnb-nf4", filename="flux1-dev-bnb-nf4-v2.safetensors"))
-sd = {k.replace("model.diffusion_model.", ""): v for k, v in sd.items() if "model.diffusion_model" in k}
-model = Flux().to(dtype=torch.bfloat16, device="cuda")
-result = model.load_state_dict(sd)
-model_zero_init = False
-
-@spaces.GPU
+@spaces.GPU(duration=120)
 @torch.no_grad()
 def generate_image(
     prompt, width, height, guidance, inference_steps, seed,
     do_img2img, init_image, image2image_strength, resize_img,
     progress=gr.Progress(track_tqdm=True),
 ):
+    # Initialize models on first run
+    initialize_models()
+
     if seed == 0:
         seed = int(random.random() * 1_000_000)
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
     torch_device = torch.device(device)
 
-    global model, model_zero_init
-    if not model_zero_init:
-        model = model.to(torch_device)
-        model_zero_init = True
-
     if do_img2img and init_image is not None:
         init_image = get_image(init_image)
         if resize_img:
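
Note: on ZeroGPU Spaces no GPU is attached while the module imports, which is why the heavy setup moved inside the decorated function; @spaces.GPU(duration=120) requests a GPU slice of up to 120 seconds per call. A minimal sketch of the pattern, with build_pipeline as a hypothetical stand-in for the app's model setup:

import spaces
import torch

pipe = None  # keep import time cheap; the GPU exists only inside decorated calls

@spaces.GPU(duration=120)  # request up to 120 s of GPU time per invocation
@torch.no_grad()
def infer(prompt: str):
    global pipe
    if pipe is None:
        pipe = build_pipeline().to("cuda")  # hypothetical loader, runs once
    return pipe(prompt)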
@@ -759,6 +772,5 @@ if __name__ == "__main__":
     demo = create_demo()
     # Enable the queue to handle concurrency
     demo.queue()
-    # Launch with show_api=False and share=True to avoid the "bool is not iterable" error
-    # and the "ValueError: When localhost is not accessible..." error.
-    demo.launch(show_api=False, share=True, server_name="0.0.0.0", mcp_server=True)
+    # Launch with appropriate settings
+    demo.launch(show_api=False, share=True)
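
Note: demo.queue() is what keeps long GPU calls from overlapping, and share=True tunnels the UI when localhost is not reachable from outside. A self-contained sketch of the same launch sequence:

import gradio as gr

def echo(text):
    return text  # placeholder for generate_image

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    result = gr.Textbox(label="Result")
    prompt.submit(echo, inputs=prompt, outputs=result)

if __name__ == "__main__":
    demo.queue()  # queue requests so they run one at a time
    demo.launch(show_api=False, share=True)  # same flags as the commit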