ginipick committed
Commit 423b272 · verified · 1 Parent(s): 7a61cf3

Update app.py

Files changed (1): app.py +20 -3
app.py CHANGED

@@ -94,9 +94,20 @@ def initialize_models():
     if not model_initialized:
         print("Initializing models...")
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16).to(device)
-        clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
-        ae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(device)
+
+        # Temporarily restore original Linear for loading standard models
+        original_linear = nn.Linear
+        if BNB_AVAILABLE:
+            nn.Linear = original_linear
+
+        # Load standard models without quantization
+        t5 = HFEmbedder("DeepFloyd/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(device)
+        clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(device)
+        ae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to(device)
+
+        # Re-apply quantized Linear for Flux model
+        if BNB_AVAILABLE:
+            nn.Linear = Linear

         # Load the NF4 quantized checkpoint
         from huggingface_hub import hf_hub_download
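Review note: the hunk above swaps nn.Linear by hand before and after the standard loads. The same guard can be written as a context manager, so the previous class is restored even if a load raises partway through. A minimal sketch under that assumption; use_linear is a hypothetical helper, not part of this commit:

import contextlib
import torch.nn as nn

@contextlib.contextmanager
def use_linear(linear_cls):
    # Install linear_cls as nn.Linear for the duration of the block,
    # then restore whatever was installed before, even on an exception.
    saved = nn.Linear
    nn.Linear = linear_cls
    try:
        yield
    finally:
        nn.Linear = saved

One subtlety worth noting: by the time initialize_models() runs, the module-level patch has usually replaced nn.Linear already, so the local original_linear = nn.Linear in this hunk appears to capture the quantized class rather than the stock one when BNB_AVAILABLE is true. Saving a reference to the stock class before any patching avoids that ambiguity.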
@@ -106,6 +117,11 @@ def initialize_models():
         sd = {k.replace("model.diffusion_model.", ""): v for k, v in sd.items() if "model.diffusion_model" in k}
         model = Flux().to(dtype=torch.bfloat16, device=device)
         result = model.load_state_dict(sd)
+
+        # Restore original Linear
+        if BNB_AVAILABLE:
+            nn.Linear = original_linear
+
         model_initialized = True
         print("Models initialized successfully!")
@@ -214,6 +230,7 @@ if BNB_AVAILABLE:
     original_linear = nn.Linear
     nn.Linear = Linear
 else:
+    original_linear = nn.Linear
     print("Warning: BitsAndBytes not available, using standard Linear layers")

 # ---------------- Model ----------------
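The one-line addition in this last hunk defines original_linear on the else: path as well, so the module-level name exists whether or not BitsAndBytes is installed, and later reads of it cannot raise a NameError on the non-quantized path.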
 