ginipick commited on
Commit
374614c
·
verified ·
1 Parent(s): 26b3498

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -10
app.py CHANGED
@@ -116,25 +116,49 @@ def initialize_models():
116
  ae = ae.to(device)
117
 
118
  print("Loading Flux model...")
119
- # Use the standard Flux model instead of quantized version
120
- # This will use more memory but avoid compatibility issues
121
  from huggingface_hub import hf_hub_download
122
  from safetensors.torch import load_file
123
 
124
  try:
125
- # Try to load from the standard Flux checkpoint
126
- print("Loading standard Flux model (this may take a while)...")
 
127
  model = Flux()
128
- model = model.to(dtype=torch.bfloat16, device=device)
129
 
130
- # You would need to download the standard Flux weights
131
- # For now, let's create a randomly initialized model for testing
132
- print("Warning: Using randomly initialized Flux model for testing")
133
- print("To use a pretrained model, you need to load proper Flux weights")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  except Exception as e:
136
  print(f"Error initializing Flux model: {e}")
137
- raise
 
138
 
139
  model_initialized = True
140
  print("Models initialized successfully!")
 
116
  ae = ae.to(device)
117
 
118
  print("Loading Flux model...")
 
 
119
  from huggingface_hub import hf_hub_download
120
  from safetensors.torch import load_file
121
 
122
  try:
123
+ # Try to load from a standard Flux checkpoint
124
+ # First, let's try the schnell version which might be smaller
125
+ print("Attempting to load Flux model weights...")
126
  model = Flux()
 
127
 
128
+ # Try loading from black-forest-labs directly
129
+ try:
130
+ # Note: You might need to authenticate with HuggingFace for this
131
+ sd = load_file(hf_hub_download(repo_id="black-forest-labs/FLUX.1-schnell", filename="flux1-schnell.safetensors"))
132
+ # Adjust state dict keys if needed
133
+ model.load_state_dict(sd, strict=False)
134
+ print("Loaded Flux schnell model successfully!")
135
+ except Exception as e1:
136
+ print(f"Could not load Flux schnell: {e1}")
137
+
138
+ # Try the dev version
139
+ try:
140
+ sd = load_file(hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="flux1-dev.safetensors"))
141
+ model.load_state_dict(sd, strict=False)
142
+ print("Loaded Flux dev model successfully!")
143
+ except Exception as e2:
144
+ print(f"Could not load Flux dev: {e2}")
145
+
146
+ # If no pretrained weights are available, warn the user
147
+ print("\n" + "="*50)
148
+ print("WARNING: Could not load pretrained Flux weights!")
149
+ print("The model will use random initialization.")
150
+ print("For proper results, you need to:")
151
+ print("1. Authenticate with HuggingFace: huggingface-cli login")
152
+ print("2. Accept the Flux model license agreement")
153
+ print("3. Or use a publicly available Flux checkpoint")
154
+ print("="*50 + "\n")
155
+
156
+ model = model.to(dtype=torch.bfloat16, device=device)
157
 
158
  except Exception as e:
159
  print(f"Error initializing Flux model: {e}")
160
+ # Continue with random initialization for now
161
+ model = Flux().to(dtype=torch.bfloat16, device=device)
162
 
163
  model_initialized = True
164
  print("Models initialized successfully!")