ginipick committed (verified)
Commit 03a1ea8 · Parent: a5ab155

Update app.py

Files changed (1): app.py (+127 -102)
app.py CHANGED
@@ -1,9 +1,24 @@
 import os
-# Set environment variable before importing torch to avoid nested tensor issues
+import sys
+
+# Disable bitsandbytes triton integration to avoid conflicts
+os.environ["BITSANDBYTES_NOWELCOME"] = "1"
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 
-# Import spaces FIRST before any torch imports
-import spaces
+# Try to handle spaces import gracefully
+try:
+    import spaces
+    SPACES_AVAILABLE = True
+except Exception as e:
+    print(f"Warning: Could not import spaces: {e}")
+    SPACES_AVAILABLE = False
+    # Create a dummy decorator if spaces is not available
+    class spaces:
+        @staticmethod
+        def GPU(duration=None):
+            def decorator(func):
+                return func
+            return decorator
 
 import time
 import gradio as gr
@@ -15,9 +30,6 @@ import math
 from typing import Callable
 
 from tqdm import tqdm
-import bitsandbytes as bnb
-from bitsandbytes.nn.modules import Params4bit, QuantState
-
 import random
 from einops import rearrange, repeat
 from diffusers import AutoencoderKL
@@ -25,6 +37,15 @@ from torch import Tensor, nn
 from transformers import CLIPTextModel, CLIPTokenizer
 from transformers import T5EncoderModel, T5Tokenizer
 
+# Import bitsandbytes after spaces to avoid conflicts
+try:
+    import bitsandbytes as bnb
+    from bitsandbytes.nn.modules import Params4bit, QuantState
+    BNB_AVAILABLE = True
+except Exception as e:
+    print(f"Warning: Could not import bitsandbytes: {e}")
+    BNB_AVAILABLE = False
+
 # ---------------- Encoders ----------------
 
 class HFEmbedder(nn.Module):
@@ -90,106 +111,110 @@ def initialize_models():
 
 # ---------------- NF4 ----------------
 
-def functional_linear_4bits(x, weight, bias):
-    import bitsandbytes as bnb
-    out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
-    out = out.to(x)
-    return out
-
-class ForgeParams4bit(Params4bit):
-    """Subclass to force re-quantization to GPU if needed."""
-    def to(self, *args, **kwargs):
-        import torch
-        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
-        if device is not None and device.type == "cuda" and not self.bnb_quantized:
-            return self._quantize(device)
-        else:
-            n = ForgeParams4bit(
-                torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
-                requires_grad=self.requires_grad,
-                quant_state=self.quant_state,
-                compress_statistics=False,
-                blocksize=64,
-                quant_type=self.quant_type,
-                quant_storage=self.quant_storage,
-                bnb_quantized=self.bnb_quantized,
-                module=self.module
-            )
-            self.module.quant_state = n.quant_state
-            self.data = n.data
-            self.quant_state = n.quant_state
-            return n
-
-class ForgeLoader4Bit(nn.Module):
-    def __init__(self, *, device, dtype, quant_type, **kwargs):
-        super().__init__()
-        self.dummy = nn.Parameter(torch.empty(1, device=device, dtype=dtype))
-        self.weight = None
-        self.quant_state = None
-        self.bias = None
-        self.quant_type = quant_type
-
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
-        super()._save_to_state_dict(destination, prefix, keep_vars)
-        from bitsandbytes.nn.modules import QuantState
-        quant_state = getattr(self.weight, "quant_state", None)
-        if quant_state is not None:
-            for k, v in quant_state.as_dict(packed=True).items():
-                destination[prefix + "weight." + k] = v if keep_vars else v.detach()
-        return
-
-    def _load_from_state_dict(
-        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
-    ):
-        from bitsandbytes.nn.modules import Params4bit
-        import torch
-
-        quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}
-        if any('bitsandbytes' in k for k in quant_state_keys):
-            quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}
-            self.weight = ForgeParams4bit.from_prequantized(
-                data=state_dict[prefix + 'weight'],
-                quantized_stats=quant_state_dict,
-                requires_grad=False,
-                device=torch.device('cuda'),
-                module=self
-            )
-            self.quant_state = self.weight.quant_state
-
-            if prefix + 'bias' in state_dict:
-                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
-            del self.dummy
-        elif hasattr(self, 'dummy'):
-            if prefix + 'weight' in state_dict:
-                self.weight = ForgeParams4bit(
-                    state_dict[prefix + 'weight'].to(self.dummy),
-                    requires_grad=False,
-                    compress_statistics=True,
+if BNB_AVAILABLE:
+    def functional_linear_4bits(x, weight, bias):
+        import bitsandbytes as bnb
+        out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
+        out = out.to(x)
+        return out
+
+    class ForgeParams4bit(Params4bit):
+        """Subclass to force re-quantization to GPU if needed."""
+        def to(self, *args, **kwargs):
+            import torch
+            device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
+            if device is not None and device.type == "cuda" and not self.bnb_quantized:
+                return self._quantize(device)
+            else:
+                n = ForgeParams4bit(
+                    torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
+                    requires_grad=self.requires_grad,
+                    quant_state=self.quant_state,
+                    compress_statistics=False,
+                    blocksize=64,
                     quant_type=self.quant_type,
-                    quant_storage=torch.uint8,
-                    module=self,
+                    quant_storage=self.quant_storage,
+                    bnb_quantized=self.bnb_quantized,
+                    module=self.module
+                )
+                self.module.quant_state = n.quant_state
+                self.data = n.data
+                self.quant_state = n.quant_state
+                return n
+
+    class ForgeLoader4Bit(nn.Module):
+        def __init__(self, *, device, dtype, quant_type, **kwargs):
+            super().__init__()
+            self.dummy = nn.Parameter(torch.empty(1, device=device, dtype=dtype))
+            self.weight = None
+            self.quant_state = None
+            self.bias = None
+            self.quant_type = quant_type
+
+        def _save_to_state_dict(self, destination, prefix, keep_vars):
+            super()._save_to_state_dict(destination, prefix, keep_vars)
+            from bitsandbytes.nn.modules import QuantState
+            quant_state = getattr(self.weight, "quant_state", None)
+            if quant_state is not None:
+                for k, v in quant_state.as_dict(packed=True).items():
+                    destination[prefix + "weight." + k] = v if keep_vars else v.detach()
+            return
+
+        def _load_from_state_dict(
+            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        ):
+            from bitsandbytes.nn.modules import Params4bit
+            import torch
+
+            quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}
+            if any('bitsandbytes' in k for k in quant_state_keys):
+                quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}
+                self.weight = ForgeParams4bit.from_prequantized(
+                    data=state_dict[prefix + 'weight'],
+                    quantized_stats=quant_state_dict,
+                    requires_grad=False,
+                    device=torch.device('cuda'),
+                    module=self
                 )
                 self.quant_state = self.weight.quant_state
 
-            if prefix + 'bias' in state_dict:
-                self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
-
-            del self.dummy
-        else:
-            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
-class Linear(ForgeLoader4Bit):
-    def __init__(self, *args, device=None, dtype=None, **kwargs):
-        super().__init__(device=device, dtype=dtype, quant_type='nf4')
-
-    def forward(self, x):
-        self.weight.quant_state = self.quant_state
-        if self.bias is not None and self.bias.dtype != x.dtype:
-            self.bias.data = self.bias.data.to(x.dtype)
-        return functional_linear_4bits(x, self.weight, self.bias)
-
-# Override Linear after all torch imports are done
-nn.Linear = Linear
+                if prefix + 'bias' in state_dict:
+                    self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
+                del self.dummy
+            elif hasattr(self, 'dummy'):
+                if prefix + 'weight' in state_dict:
+                    self.weight = ForgeParams4bit(
+                        state_dict[prefix + 'weight'].to(self.dummy),
+                        requires_grad=False,
+                        compress_statistics=True,
+                        quant_type=self.quant_type,
+                        quant_storage=torch.uint8,
+                        module=self,
+                    )
+                    self.quant_state = self.weight.quant_state
+
+                if prefix + 'bias' in state_dict:
+                    self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
+
+                del self.dummy
+            else:
+                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+    class Linear(ForgeLoader4Bit):
+        def __init__(self, *args, device=None, dtype=None, **kwargs):
+            super().__init__(device=device, dtype=dtype, quant_type='nf4')
+
+        def forward(self, x):
+            self.weight.quant_state = self.quant_state
+            if self.bias is not None and self.bias.dtype != x.dtype:
+                self.bias.data = self.bias.data.to(x.dtype)
+            return functional_linear_4bits(x, self.weight, self.bias)
+
+    # Override Linear after all torch imports are done
+    original_linear = nn.Linear
+    nn.Linear = Linear
+else:
+    print("Warning: BitsAndBytes not available, using standard Linear layers")
 
 # ---------------- Model ----------------
 
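Note on the spaces fallback in the first hunk: when the spaces package cannot be imported, the dummy class keeps existing @spaces.GPU(...) call sites importable and callable. A minimal standalone sketch of the pattern; the generate function is a hypothetical stand-in, not part of this commit:

try:
    import spaces  # real package on Hugging Face ZeroGPU Spaces
except Exception:
    class spaces:  # dummy shim so @spaces.GPU(...) stays valid
        @staticmethod
        def GPU(duration=None):
            def decorator(func):
                return func  # no-op: run on whatever device is available
            return decorator

@spaces.GPU(duration=120)
def generate(prompt):
    # hypothetical workload; gets a GPU only when the real `spaces` is present
    return f"generated: {prompt}"

print(generate("test"))

With the real package the decorator requests a ZeroGPU slot for the call; with the shim it is a plain pass-through, so the rest of the file needs no changes.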
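Note on the NF4 hunk: the quantized Linear is swapped in only when bitsandbytes imports cleanly, and the new original_linear binding keeps a handle to the stock class so the monkey-patch is reversible. A minimal sketch of that guard pattern, assuming the import failed; PatchedLinear is a hypothetical stand-in for the NF4 Linear above:

import torch.nn as nn

BNB_AVAILABLE = False  # assume the bitsandbytes import failed

class PatchedLinear(nn.Linear):
    """Hypothetical stand-in for the NF4 Linear class."""

original_linear = nn.Linear  # keep a handle so the patch can be undone

if BNB_AVAILABLE:
    nn.Linear = PatchedLinear  # later nn.Linear(...) calls build the patched class
else:
    print("Warning: BitsAndBytes not available, using standard Linear layers")

layer = nn.Linear(4, 4)  # standard layer here, since the patch was skipped
assert type(layer) is original_linear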