Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -94,9 +94,20 @@ def initialize_models():
|
|
94 |
if not model_initialized:
|
95 |
print("Initializing models...")
|
96 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
97 |
-
|
98 |
-
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
# Load the NF4 quantized checkpoint
|
102 |
from huggingface_hub import hf_hub_download
|
@@ -106,6 +117,11 @@ def initialize_models():
|
|
106 |
sd = {k.replace("model.diffusion_model.", ""): v for k, v in sd.items() if "model.diffusion_model" in k}
|
107 |
model = Flux().to(dtype=torch.bfloat16, device=device)
|
108 |
result = model.load_state_dict(sd)
|
|
|
|
|
|
|
|
|
|
|
109 |
model_initialized = True
|
110 |
print("Models initialized successfully!")
|
111 |
|
@@ -214,6 +230,7 @@ if BNB_AVAILABLE:
|
|
214 |
original_linear = nn.Linear
|
215 |
nn.Linear = Linear
|
216 |
else:
|
|
|
217 |
print("Warning: BitsAndBytes not available, using standard Linear layers")
|
218 |
|
219 |
# ---------------- Model ----------------
|
|
|
94 |
if not model_initialized:
|
95 |
print("Initializing models...")
|
96 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
97 |
+
|
98 |
+
# Temporarily restore original Linear for loading standard models
|
99 |
+
original_linear = nn.Linear
|
100 |
+
if BNB_AVAILABLE:
|
101 |
+
nn.Linear = original_linear
|
102 |
+
|
103 |
+
# Load standard models without quantization
|
104 |
+
t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(device)
|
105 |
+
clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(device)
|
106 |
+
ae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(device)
|
107 |
+
|
108 |
+
# Re-apply quantized Linear for Flux model
|
109 |
+
if BNB_AVAILABLE:
|
110 |
+
nn.Linear = Linear
|
111 |
|
112 |
# Load the NF4 quantized checkpoint
|
113 |
from huggingface_hub import hf_hub_download
|
|
|
117 |
sd = {k.replace("model.diffusion_model.", ""): v for k, v in sd.items() if "model.diffusion_model" in k}
|
118 |
model = Flux().to(dtype=torch.bfloat16, device=device)
|
119 |
result = model.load_state_dict(sd)
|
120 |
+
|
121 |
+
# Restore original Linear
|
122 |
+
if BNB_AVAILABLE:
|
123 |
+
nn.Linear = original_linear
|
124 |
+
|
125 |
model_initialized = True
|
126 |
print("Models initialized successfully!")
|
127 |
|
|
|
230 |
original_linear = nn.Linear
|
231 |
nn.Linear = Linear
|
232 |
else:
|
233 |
+
original_linear = nn.Linear
|
234 |
print("Warning: BitsAndBytes not available, using standard Linear layers")
|
235 |
|
236 |
# ---------------- Model ----------------
|