Omarrran committed
Commit f4c7b32 · verified · 1 Parent(s): 7598a90

Create app.py

Files changed (1):
  app.py  +405 -0

app.py ADDED
@@ -0,0 +1,405 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm
import io
from PIL import Image

# Implementation of the W8A16LinearLayer
class W8A16LinearLayer(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super().__init__()

        # Placeholder int8 weights; high is exclusive, so 128 covers the full int8 range [-128, 127]
        self.register_buffer(
            "int8_weights",
            torch.randint(
                -128, 128, (out_features, in_features), dtype=torch.int8
            )
        )

        self.register_buffer("scales",
                             torch.randn((out_features), dtype=dtype))

        if bias:
            self.register_buffer("bias",
                                 torch.randn((1, out_features),
                                             dtype=dtype))
        else:
            self.bias = None

    def quantize(self, weights):
        """
        Quantize floating-point weights to int8 precision.

        Args:
            weights: Tensor of weights to quantize (shape: out_features x in_features)

        Returns:
            (int8_weights, scales); the int8_weights and scales buffers are also updated in place
        """
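        # Worked example: a row whose largest |weight| is 2.54 gets scale 2.54 / 127 = 0.02,
        # so a weight of 1.234 is stored as round(1.234 / 0.02) = 62 and recovered as 62 * 0.02 = 1.24.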
        w_fp32 = weights.clone().to(torch.float32)

        # Calculate scales as the max absolute value for each output row
        # divided by 127 (max value for int8)
        scales = w_fp32.abs().max(dim=-1).values / 127
        scales = scales.to(weights.dtype)

        # Guard against all-zero rows (e.g. the padded rows of the "eye" pattern below),
        # which would otherwise give a zero scale and NaNs in the division
        scales = torch.where(scales == 0, torch.ones_like(scales), scales)

        # Quantize by dividing by scales and rounding to nearest integer
        int8_weights = torch.round(weights / scales.unsqueeze(1)).to(torch.int8)

        # Update the model buffers
        self.int8_weights = int8_weights
        self.scales = scales

        return int8_weights, scales

    def forward(self, input):
        """
        Forward pass through the quantized linear layer

        Args:
            input: Input tensor (shape: batch_size x seq_len x in_features)

        Returns:
            output: Output tensor after the linear transformation
        """
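        # W is approximately scales[:, None] * int8_weights, so
        # F.linear(input, int8_weights.to(dtype)) * scales equals input @ W.T up to rounding error.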
        # Cast int8 weights to the input dtype while preserving the values
        casted_weights = self.int8_weights.to(input.dtype)

        # Perform the linear multiplication and apply the per-row scaling factors
        output = F.linear(input, casted_weights) * self.scales

        # Add bias if present
        if self.bias is not None:
            output = output + self.bias

        return output

# Helper functions for visualization

def plot_weight_matrix(weights, title="Weight Matrix"):
    """Create a heatmap visualization of weight matrices"""
    plt.figure(figsize=(10, 8))

    # Create a colormap centered at zero; the small floor keeps TwoSlopeNorm valid
    # even when the matrix is all zeros (e.g. an all-zero error map)
    vmax = max(abs(weights.min().item()), abs(weights.max().item()), 1e-6)
    vmin = -vmax
    norm = TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)

    # .float() makes int8 and bfloat16 tensors convertible to NumPy
    plt.imshow(weights.detach().float().cpu().numpy(), cmap='RdBu_r', norm=norm)
    plt.colorbar(label='Weight Value')
    plt.title(title)

    # Save the plot to a bytes buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)

    return Image.open(buf)

def plot_weight_distribution(weights, title="Weight Distribution"):
    """Create a histogram visualization of weight distributions"""
    plt.figure(figsize=(10, 6))

    # Flatten the weights to 1D for the histogram; .float() handles int8/bfloat16 inputs
    flat_weights = weights.flatten().detach().float().cpu().numpy()

    plt.hist(flat_weights, bins=50, alpha=0.7, color='blue')
    plt.xlabel('Weight Value')
    plt.ylabel('Frequency')
    plt.title(title)
    plt.grid(alpha=0.3)

    # Save the plot to a bytes buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)

    return Image.open(buf)

def calculate_quantization_error(original, quantized, scales):
    """Calculate error metrics between original and dequantized weights"""
    # Dequantize the weights
    dequantized = quantized.float() * scales.unsqueeze(1)
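    # With round-to-nearest quantization, each reconstructed weight should differ from the
    # original by at most about half a quantization step (scale / 2) for its output row.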

    # Calculate error metrics
    abs_error = (original - dequantized).abs()
    max_error = abs_error.max().item()
    mean_error = abs_error.mean().item()

    return max_error, mean_error, dequantized

# Gradio UI components

def initialize_model(in_features, out_features, with_bias, dtype_str):
    """Initialize a new quantized linear layer model"""
    # Map dtype string to torch dtype
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16
    }
    dtype = dtype_map[dtype_str]

    # Create the model
    model = W8A16LinearLayer(in_features, out_features, bias=with_bias, dtype=dtype)

    # Generate random weights for visualization
    random_weights = torch.randn((out_features, in_features), dtype=dtype)

    # Original weights visualization
    weights_vis = plot_weight_matrix(random_weights, "Original Weights")
    dist_vis = plot_weight_distribution(random_weights, "Original Weight Distribution")

    # Quantize the weights
    int8_weights, scales = model.quantize(random_weights)

    # Quantized weights visualization
    q_weights_vis = plot_weight_matrix(int8_weights, "Quantized Weights (INT8)")
    q_dist_vis = plot_weight_distribution(int8_weights, "Quantized Weight Distribution")

    # Calculate quantization error
    max_error, mean_error, dequantized = calculate_quantization_error(
        random_weights, int8_weights, scales
    )

    # Dequantized weights visualization
    deq_weights_vis = plot_weight_matrix(dequantized, "Dequantized Weights")

    # Error visualization
    error = (random_weights - dequantized).abs()
    error_vis = plot_weight_matrix(error, "Quantization Error (Absolute)")

    # Create model summary
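    # The memory-savings figure below counts one byte per int8 weight plus the per-row scale
    # amortised over its row, relative to the bytes per weight of the original dtype.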
    model_info = f"""
    ## Model Configuration
    - Input Features: {in_features}
    - Output Features: {out_features}
    - Bias: {"Yes" if with_bias else "No"}
    - Data Type: {dtype_str}

    ## Quantization Stats
    - Original Weights Shape: {random_weights.shape}
    - Quantized Weights Shape: {int8_weights.shape}
    - Scales Shape: {scales.shape}
    - Maximum Quantization Error: {max_error:.6f}
    - Mean Quantization Error: {mean_error:.6f}
    - Memory Savings: {100 * (1 - (int8_weights.element_size() + scales.element_size() * scales.numel() / int8_weights.numel()) / random_weights.element_size()):.2f}%
    """

    # Create sample input/output for the model
    sample_input = torch.randn(1, in_features, dtype=dtype)
    sample_output = model(sample_input)

    io_info = f"""
    ## Sample Input/Output
    - Input Shape: {sample_input.shape}
    - Output Shape: {sample_output.shape}
    - Output Range: [{sample_output.min().item():.4f}, {sample_output.max().item():.4f}]
    """

    return model_info, io_info, weights_vis, q_weights_vis, deq_weights_vis, dist_vis, q_dist_vis, error_vis

def quantize_custom_weights(in_features, out_features, with_bias, dtype_str, weight_pattern):
    """Quantize custom weights based on the selected pattern"""
    # Map dtype string to torch dtype
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16
    }
    dtype = dtype_map[dtype_str]

    # Create the model
    model = W8A16LinearLayer(in_features, out_features, bias=with_bias, dtype=dtype)

    # Generate weights based on pattern
    if weight_pattern == "random":
        custom_weights = torch.randn((out_features, in_features), dtype=dtype)
    elif weight_pattern == "eye":
        # Identity matrix (or closest approximation if dimensions don't match)
        custom_weights = torch.zeros((out_features, in_features), dtype=dtype)
        min_dim = min(out_features, in_features)
        custom_weights[:min_dim, :min_dim] = torch.eye(min_dim, dtype=dtype)
    elif weight_pattern == "ones":
        custom_weights = torch.ones((out_features, in_features), dtype=dtype)
    elif weight_pattern == "alternating":
        custom_weights = torch.ones((out_features, in_features), dtype=dtype)
        # Create a checkerboard pattern
        for i in range(out_features):
            for j in range(in_features):
                if (i + j) % 2 == 1:
                    custom_weights[i, j] = -1.0
    elif weight_pattern == "gradient":
        # Linear gradient from -1 to 1
        x = torch.linspace(-1, 1, in_features)
        y = torch.linspace(-1, 1, out_features)
        xx, yy = torch.meshgrid(x, y, indexing='ij')
        custom_weights = (xx + yy).t().to(dtype)
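        # (xx + yy) has shape (in_features, out_features); the transpose gives the expected
        # (out_features, in_features) weight layout before casting to the selected dtype.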

    # Original weights visualization
    weights_vis = plot_weight_matrix(custom_weights, f"Original Weights ({weight_pattern})")
    dist_vis = plot_weight_distribution(custom_weights, "Original Weight Distribution")

    # Quantize the weights
    int8_weights, scales = model.quantize(custom_weights)

    # Quantized weights visualization
    q_weights_vis = plot_weight_matrix(int8_weights, "Quantized Weights (INT8)")
    q_dist_vis = plot_weight_distribution(int8_weights, "Quantized Weight Distribution")

    # Calculate quantization error
    max_error, mean_error, dequantized = calculate_quantization_error(
        custom_weights, int8_weights, scales
    )

    # Dequantized weights visualization
    deq_weights_vis = plot_weight_matrix(dequantized, "Dequantized Weights")

    # Error visualization
    error = (custom_weights - dequantized).abs()
    error_vis = plot_weight_matrix(error, "Quantization Error (Absolute)")

    # Quantization details
    quant_info = f"""
    ## Quantization Details
    - Original Data Type: {dtype_str}
    - Quantized Data Type: int8 (8-bit)
    - Weight Pattern: {weight_pattern}

    ## Error Analysis
    - Maximum Quantization Error: {max_error:.6f}
    - Mean Quantization Error: {mean_error:.6f}
    - Memory Savings: {100 * (1 - (int8_weights.element_size() + scales.element_size() * scales.numel() / int8_weights.numel()) / custom_weights.element_size()):.2f}%

    ## Tensor Shapes
    - Original Weights: {custom_weights.shape}
    - Quantized Weights: {int8_weights.shape}
    - Quantization Scales: {scales.shape}
    """

    # Histogram of the per-row quantization scales
    plt.figure(figsize=(10, 6))
    plt.hist(scales.detach().float().cpu().numpy(), bins=30, alpha=0.7, color='green')
    plt.xlabel('Scale Value')
    plt.ylabel('Frequency')
    plt.title('Distribution of Quantization Scales')
    plt.grid(alpha=0.3)

    # Save the plot to a bytes buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close()
    buf.seek(0)
    scales_vis = Image.open(buf)

    return quant_info, weights_vis, q_weights_vis, deq_weights_vis, dist_vis, q_dist_vis, error_vis, scales_vis

# Create Gradio interface
with gr.Blocks(title="8-Bit Weight Quantizer") as demo:
    gr.Markdown("# PyTorch 8-Bit Weight Quantizer")
    gr.Markdown("""
    This tool demonstrates quantization of neural network weights to INT8 precision.
    It implements a custom `W8A16LinearLayer` that uses 8-bit weights with 16-bit activations.
    """)

    with gr.Tabs():
        with gr.TabItem("Initialize Model"):
            with gr.Row():
                with gr.Column():
                    in_feat = gr.Slider(minimum=1, maximum=512, value=16, step=1, label="Input Features")
                    out_feat = gr.Slider(minimum=1, maximum=512, value=32, step=1, label="Output Features")
                    with_bias = gr.Checkbox(value=True, label="Include Bias")
                    dtype = gr.Dropdown(choices=["float32", "float16", "bfloat16"], value="float32", label="Data Type")
                    init_btn = gr.Button("Initialize Model")

                with gr.Column():
                    model_info = gr.Markdown()
                    io_info = gr.Markdown()

            with gr.Row():
                orig_weights = gr.Image(label="Original Weights")
                quant_weights = gr.Image(label="Quantized Weights (INT8)")
                dequant_weights = gr.Image(label="Dequantized Weights")

            with gr.Row():
                orig_dist = gr.Image(label="Original Weight Distribution")
                quant_dist = gr.Image(label="Quantized Weight Distribution")
                error_vis = gr.Image(label="Quantization Error")

        with gr.TabItem("Custom Quantization"):
            with gr.Row():
                with gr.Column():
                    c_in_feat = gr.Slider(minimum=1, maximum=512, value=16, step=1, label="Input Features")
                    c_out_feat = gr.Slider(minimum=1, maximum=512, value=32, step=1, label="Output Features")
                    c_with_bias = gr.Checkbox(value=True, label="Include Bias")
                    c_dtype = gr.Dropdown(choices=["float32", "float16", "bfloat16"], value="float32", label="Data Type")
                    weight_pattern = gr.Dropdown(
                        choices=["random", "eye", "ones", "alternating", "gradient"],
                        value="random",
                        label="Weight Pattern"
                    )
                    quantize_btn = gr.Button("Quantize Weights")

                with gr.Column():
                    quant_details = gr.Markdown()

            with gr.Row():
                c_orig_weights = gr.Image(label="Original Weights")
                c_quant_weights = gr.Image(label="Quantized Weights (INT8)")
                c_dequant_weights = gr.Image(label="Dequantized Weights")

            with gr.Row():
                c_orig_dist = gr.Image(label="Original Weight Distribution")
                c_quant_dist = gr.Image(label="Quantized Weight Distribution")
                c_error_vis = gr.Image(label="Quantization Error")

            with gr.Row():
                scales_dist = gr.Image(label="Quantization Scales Distribution")

        with gr.TabItem("About"):
            gr.Markdown("""
            ## 8-bit Quantizer Implementation

            This implementation includes:

            1. **W8A16LinearLayer** - A PyTorch module that uses INT8 weights and FP16/BF16/FP32 activations
            2. **Quantization** - Converts FP32/FP16/BF16 weights to INT8 using per-output-channel scaling
            3. **Visualization** - Shows the impact of quantization on weight distributions and errors

            ### How It Works:

            1. For each output channel, find the maximum absolute weight value
            2. Scale all weights in that channel so the maximum fits in the INT8 range (-128 to 127)
            3. Round the scaled weights to integers and store them as INT8
            4. During inference, multiply the INT8 weights by the scaling factors to recover approximate FP values (see the example below)

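            For example, a row whose largest absolute weight is 0.5 uses scale 0.5 / 127 ≈ 0.0039,
            so a weight of 0.3 is stored as the integer 76 and recovered as 76 * 0.0039 ≈ 0.299.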

            The quantization process reduces memory usage by up to 75% compared to FP32 weights.

            ### References:

            - This implementation is based on modern techniques used in LLM quantization
            - Similar methods are used in libraries like bitsandbytes, AutoGPTQ, and GPTQ-for-LLaMa
            """)

    # Connect buttons to functions
    init_btn.click(
        initialize_model,
        inputs=[in_feat, out_feat, with_bias, dtype],
        outputs=[model_info, io_info, orig_weights, quant_weights, dequant_weights, orig_dist, quant_dist, error_vis]
    )

    quantize_btn.click(
        quantize_custom_weights,
        inputs=[c_in_feat, c_out_feat, c_with_bias, c_dtype, weight_pattern],
        outputs=[quant_details, c_orig_weights, c_quant_weights, c_dequant_weights, c_orig_dist, c_quant_dist, c_error_vis, scales_dist]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()