ginipick commited on
Commit
26b3498
·
verified ·
1 Parent(s): 423b272

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -24
app.py CHANGED
@@ -46,6 +46,13 @@ except Exception as e:
46
  print(f"Warning: Could not import bitsandbytes: {e}")
47
  BNB_AVAILABLE = False
48
 
 
 
 
 
 
 
 
49
  # ---------------- Encoders ----------------
50
 
51
  class HFEmbedder(nn.Module):
@@ -95,32 +102,39 @@ def initialize_models():
95
  print("Initializing models...")
96
  device = "cuda" if torch.cuda.is_available() else "cpu"
97
 
98
- # Temporarily restore original Linear for loading standard models
99
- original_linear = nn.Linear
100
- if BNB_AVAILABLE:
101
- nn.Linear = original_linear
102
 
103
- # Load standard models without quantization
104
- t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(device)
105
- clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(device)
106
- ae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(device)
107
 
108
- # Re-apply quantized Linear for Flux model
109
- if BNB_AVAILABLE:
110
- nn.Linear = Linear
111
 
112
- # Load the NF4 quantized checkpoint
 
 
113
  from huggingface_hub import hf_hub_download
114
  from safetensors.torch import load_file
115
 
116
- sd = load_file(hf_hub_download(repo_id="lllyasviel/flux1-dev-bnb-nf4", filename="flux1-dev-bnb-nf4-v2.safetensors"))
117
- sd = {k.replace("model.diffusion_model.", ""): v for k, v in sd.items() if "model.diffusion_model" in k}
118
- model = Flux().to(dtype=torch.bfloat16, device=device)
119
- result = model.load_state_dict(sd)
120
-
121
- # Restore original Linear
122
- if BNB_AVAILABLE:
123
- nn.Linear = original_linear
 
 
 
 
 
 
124
 
125
  model_initialized = True
126
  print("Models initialized successfully!")
@@ -226,11 +240,9 @@ if BNB_AVAILABLE:
226
  self.bias.data = self.bias.data.to(x.dtype)
227
  return functional_linear_4bits(x, self.weight, self.bias)
228
 
229
- # Override Linear after all torch imports are done
230
- original_linear = nn.Linear
231
- nn.Linear = Linear
232
  else:
233
- original_linear = nn.Linear
234
  print("Warning: BitsAndBytes not available, using standard Linear layers")
235
 
236
  # ---------------- Model ----------------
 
46
  print(f"Warning: Could not import bitsandbytes: {e}")
47
  BNB_AVAILABLE = False
48
 
49
+ # Store original Linear class before any modifications
50
+ original_linear = nn.Linear
51
+
52
+ # Disable BNB for now due to compatibility issues
53
+ BNB_AVAILABLE = False
54
+ print("Note: BitsAndBytes quantization disabled for compatibility")
55
+
56
  # ---------------- Encoders ----------------
57
 
58
  class HFEmbedder(nn.Module):
 
102
  print("Initializing models...")
103
  device = "cuda" if torch.cuda.is_available() else "cpu"
104
 
105
+ # Load standard models
106
+ print("Loading T5 encoder...")
107
+ t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
108
+ t5 = t5.to(device)
109
 
110
+ print("Loading CLIP encoder...")
111
+ clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
112
+ clip = clip.to(device)
 
113
 
114
+ print("Loading VAE...")
115
+ ae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
116
+ ae = ae.to(device)
117
 
118
+ print("Loading Flux model...")
119
+ # Use the standard Flux model instead of quantized version
120
+ # This will use more memory but avoid compatibility issues
121
  from huggingface_hub import hf_hub_download
122
  from safetensors.torch import load_file
123
 
124
+ try:
125
+ # Try to load from the standard Flux checkpoint
126
+ print("Loading standard Flux model (this may take a while)...")
127
+ model = Flux()
128
+ model = model.to(dtype=torch.bfloat16, device=device)
129
+
130
+ # You would need to download the standard Flux weights
131
+ # For now, let's create a randomly initialized model for testing
132
+ print("Warning: Using randomly initialized Flux model for testing")
133
+ print("To use a pretrained model, you need to load proper Flux weights")
134
+
135
+ except Exception as e:
136
+ print(f"Error initializing Flux model: {e}")
137
+ raise
138
 
139
  model_initialized = True
140
  print("Models initialized successfully!")
 
240
  self.bias.data = self.bias.data.to(x.dtype)
241
  return functional_linear_4bits(x, self.weight, self.bias)
242
 
243
+ # Don't override Linear globally - we'll only use it for Flux model
244
+ pass
 
245
  else:
 
246
  print("Warning: BitsAndBytes not available, using standard Linear layers")
247
 
248
  # ---------------- Model ----------------