Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -116,25 +116,49 @@ def initialize_models():
|
|
116 |
ae = ae.to(device)
|
117 |
|
118 |
print("Loading Flux model...")
|
119 |
-
# Use the standard Flux model instead of quantized version
|
120 |
-
# This will use more memory but avoid compatibility issues
|
121 |
from huggingface_hub import hf_hub_download
|
122 |
from safetensors.torch import load_file
|
123 |
|
124 |
try:
|
125 |
-
# Try to load from
|
126 |
-
|
|
|
127 |
model = Flux()
|
128 |
-
model = model.to(dtype=torch.bfloat16, device=device)
|
129 |
|
130 |
-
#
|
131 |
-
|
132 |
-
|
133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
|
135 |
except Exception as e:
|
136 |
print(f"Error initializing Flux model: {e}")
|
137 |
-
|
|
|
138 |
|
139 |
model_initialized = True
|
140 |
print("Models initialized successfully!")
|
|
|
116 |
ae = ae.to(device)
|
117 |
|
118 |
print("Loading Flux model...")
|
|
|
|
|
119 |
from huggingface_hub import hf_hub_download
|
120 |
from safetensors.torch import load_file
|
121 |
|
122 |
try:
|
123 |
+
# Try to load from a standard Flux checkpoint
|
124 |
+
# First, let's try the schnell version which might be smaller
|
125 |
+
print("Attempting to load Flux model weights...")
|
126 |
model = Flux()
|
|
|
127 |
|
128 |
+
# Try loading from black-forest-labs directly
|
129 |
+
try:
|
130 |
+
# Note: You might need to authenticate with HuggingFace for this
|
131 |
+
sd = load_file(hf_hub_download(repo_id="black-forest-labs/FLUX.1-schnell", filename="flux1-schnell.safetensors"))
|
132 |
+
# Adjust state dict keys if needed
|
133 |
+
model.load_state_dict(sd, strict=False)
|
134 |
+
print("Loaded Flux schnell model successfully!")
|
135 |
+
except Exception as e1:
|
136 |
+
print(f"Could not load Flux schnell: {e1}")
|
137 |
+
|
138 |
+
# Try the dev version
|
139 |
+
try:
|
140 |
+
sd = load_file(hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="flux1-dev.safetensors"))
|
141 |
+
model.load_state_dict(sd, strict=False)
|
142 |
+
print("Loaded Flux dev model successfully!")
|
143 |
+
except Exception as e2:
|
144 |
+
print(f"Could not load Flux dev: {e2}")
|
145 |
+
|
146 |
+
# If no pretrained weights are available, warn the user
|
147 |
+
print("\n" + "="*50)
|
148 |
+
print("WARNING: Could not load pretrained Flux weights!")
|
149 |
+
print("The model will use random initialization.")
|
150 |
+
print("For proper results, you need to:")
|
151 |
+
print("1. Authenticate with HuggingFace: huggingface-cli login")
|
152 |
+
print("2. Accept the Flux model license agreement")
|
153 |
+
print("3. Or use a publicly available Flux checkpoint")
|
154 |
+
print("="*50 + "\n")
|
155 |
+
|
156 |
+
model = model.to(dtype=torch.bfloat16, device=device)
|
157 |
|
158 |
except Exception as e:
|
159 |
print(f"Error initializing Flux model: {e}")
|
160 |
+
# Continue with random initialization for now
|
161 |
+
model = Flux().to(dtype=torch.bfloat16, device=device)
|
162 |
|
163 |
model_initialized = True
|
164 |
print("Models initialized successfully!")
|