nickfraser commited on
Commit
4024f9d
·
1 Parent(s): dca9b6e

Updated math model to target int8 x int8 kernels.

Browse files
Files changed (3) hide show
  1. math_model.py +8 -6
  2. test_quant_conv2d.py +13 -8
  3. test_quant_linear.py +11 -7
math_model.py CHANGED
@@ -47,11 +47,12 @@ class QuantLinear(nn.Module):
47
  # - multiply this sum with every weight zero-point (e.g., `torch.sum(quant_input, dim=-1) * self.weight_zp`
48
  # - Subtract from previous output (e.g., `quant_output -= torch.sum(quant_input, dim=-1) * self.weight_zp`)
49
  # - All other code is just to make sure the broadcasting semantics work correctly
50
- quant_weight = quantize(self.linear.weight, self.weight_scale, self.weight_zp, is_asym=True).to(torch.uint8)
 
51
  fused_input_scale = self.input_scale / self.mul_factor # Fuse SmoothQuant and input scales, can be computed offline
52
  quant_input = quantize(x, fused_input_scale, self.input_zp, is_asym=False).to(torch.int8)
53
  quant_output = torch.nn.functional.linear(quant_input.to(torch.float32), quant_weight.to(torch.float32), None).to(torch.int32) # Convert inputs to FP32 to avoid F.linear quantizing the output to int8
54
- correction = torch.sum(quant_input, dim=-1, keepdim=True).to(torch.int32) * (-self.weight_zp).to(torch.uint8).view([1]*(quant_input.ndim-1) + [self.weight_zp.nelement()]) # Correct for weight zero-point
55
  quant_output = quant_output + correction
56
  output = dequantize(quant_output, (self.weight_scale * self.input_scale).view([1]*(quant_output.ndim-1) + [(self.weight_scale * self.input_scale).nelement()]), 0.0)
57
  output += self.linear.bias
@@ -103,15 +104,16 @@ class QuantConv2d(nn.Module):
103
  # - multiply this sum with every weight zero-point (e.g., `sum * self.weight_zp`
104
  # - Subtract from previous output (e.g., `quant_output -= sum * self.weight_zp`)
105
  # - All other code is just to make sure the broadcasting semantics work correctly
106
- quant_weight = quantize(self.conv2d.weight, self.weight_scale, self.weight_zp, is_asym=True).to(torch.uint8)
 
107
  b_shape = list(quant_weight.shape) # Used for weight zero-point correction
108
  b_shape[0] = 1 # Used for weight zero-point correction
109
- weight_cat = torch.ones((1,1,1,1)).broadcast_to(b_shape).to(torch.uint8) # Used for weight zero-point correction
110
- quant_weight = torch.cat((quant_weight,weight_cat),dim=0).to(torch.uint8) # Create extra output channel, used for weight zero-point correction
111
  fused_input_scale = self.input_scale / self.mul_factor # Fuse SmoothQuant and input scales, can be computed offline
112
  quant_input = quantize(x, fused_input_scale, self.input_zp, is_asym=False).to(torch.int8)
113
  quant_output = torch.nn.functional.conv2d(quant_input.to(torch.float32), quant_weight.to(torch.float32), None).to(torch.int32) # Convert inputs to FP32 to avoid F.conv2d quantizing the output to int8
114
- correction = quant_output[:,-1,:,:] * (-self.weight_zp).to(torch.uint8).view([1, self.weight_zp.nelement()] + [1]*(quant_output.ndim-2)) # Correct zero-point for weight
115
  quant_output = quant_output[:,:-1,:,:] + correction
116
  output = dequantize(quant_output, (self.weight_scale * self.input_scale).view([1, (self.weight_scale * self.input_scale).nelement()] + [1]*(quant_output.ndim-2)), 0.0)
117
  output += self.conv2d.bias.view([1, self.conv2d.bias.nelement()] + [1]*(quant_output.ndim-2))
 
47
  # - multiply this sum with every weight zero-point (e.g., `torch.sum(quant_input, dim=-1) * self.weight_zp`
48
  # - Subtract from previous output (e.g., `quant_output -= torch.sum(quant_input, dim=-1) * self.weight_zp`)
49
  # - All other code is just to make sure the broadcasting semantics work correctly
50
+ weight_zp_int8 = (self.weight_zp - 128).to(torch.int8).to(torch.float32)
51
+ quant_weight = quantize(self.linear.weight, self.weight_scale, weight_zp_int8, is_asym=False).to(torch.int8)
52
  fused_input_scale = self.input_scale / self.mul_factor # Fuse SmoothQuant and input scales, can be computed offline
53
  quant_input = quantize(x, fused_input_scale, self.input_zp, is_asym=False).to(torch.int8)
54
  quant_output = torch.nn.functional.linear(quant_input.to(torch.float32), quant_weight.to(torch.float32), None).to(torch.int32) # Convert inputs to FP32 to avoid F.linear quantizing the output to int8
55
+ correction = torch.sum(quant_input, dim=-1, keepdim=True).to(torch.int32) * (-weight_zp_int8).to(torch.int8).view([1]*(quant_input.ndim-1) + [self.weight_zp.nelement()]) # Correct for weight zero-point
56
  quant_output = quant_output + correction
57
  output = dequantize(quant_output, (self.weight_scale * self.input_scale).view([1]*(quant_output.ndim-1) + [(self.weight_scale * self.input_scale).nelement()]), 0.0)
58
  output += self.linear.bias
 
104
  # - multiply this sum with every weight zero-point (e.g., `sum * self.weight_zp`
105
  # - Subtract from previous output (e.g., `quant_output -= sum * self.weight_zp`)
106
  # - All other code is just to make sure the broadcasting semantics work correctly
107
+ weight_zp_int8 = (self.weight_zp - 128).to(torch.int8).to(torch.float32)
108
+ quant_weight = quantize(self.conv2d.weight, self.weight_scale, weight_zp_int8, is_asym=False).to(torch.int8)
109
  b_shape = list(quant_weight.shape) # Used for weight zero-point correction
110
  b_shape[0] = 1 # Used for weight zero-point correction
111
+ weight_cat = torch.ones((1,1,1,1)).broadcast_to(b_shape).to(torch.int8) # Used for weight zero-point correction
112
+ quant_weight = torch.cat((quant_weight,weight_cat),dim=0).to(torch.int8) # Create extra output channel, used for weight zero-point correction
113
  fused_input_scale = self.input_scale / self.mul_factor # Fuse SmoothQuant and input scales, can be computed offline
114
  quant_input = quantize(x, fused_input_scale, self.input_zp, is_asym=False).to(torch.int8)
115
  quant_output = torch.nn.functional.conv2d(quant_input.to(torch.float32), quant_weight.to(torch.float32), None).to(torch.int32) # Convert inputs to FP32 to avoid F.conv2d quantizing the output to int8
116
+ correction = quant_output[:,-1,:,:] * (-weight_zp_int8).to(torch.int8).view([1, self.weight_zp.nelement()] + [1]*(quant_output.ndim-2)) # Correct zero-point for weight
117
  quant_output = quant_output[:,:-1,:,:] + correction
118
  output = dequantize(quant_output, (self.weight_scale * self.input_scale).view([1, (self.weight_scale * self.input_scale).nelement()] + [1]*(quant_output.ndim-2)), 0.0)
119
  output += self.conv2d.bias.view([1, self.conv2d.bias.nelement()] + [1]*(quant_output.ndim-2))
test_quant_conv2d.py CHANGED
@@ -1,23 +1,28 @@
1
  import torch
 
2
  from math_model import QuantConv2d
3
 
4
  torch.manual_seed(0)
5
 
6
  batch_size = 1
7
- out_ch = 8
8
- in_ch = 4
9
  k = 3
10
  h = 5
11
  w = 5
12
 
 
 
 
13
  quant_params = {
14
  'smoothquant_mul': torch.rand((in_ch,)),
15
  'smoothquant_mul_shape': (1,in_ch,1,1),
16
  'weight_scale': torch.rand((out_ch,)),
 
17
  'weight_scale_shape': (out_ch,1,1,1),
18
- 'weight_zp': torch.randint(-255, 0, (out_ch,)),
19
  'weight_zp_shape': (out_ch,1,1,1),
20
- 'input_scale': torch.rand((1,)),
21
  'input_scale_shape': tuple(),
22
  'input_zp': torch.zeros((1,)),
23
  'input_zp_shape': tuple(),
@@ -25,10 +30,10 @@ quant_params = {
25
 
26
  print(quant_params)
27
 
28
- l = QuantConv2d(in_ch, out_ch, k, quant_params)
29
- i = torch.rand((batch_size,in_ch,h,w))
30
- o_qdq = l(i)
31
- o_qop = l(i, qop=True)
32
  print(o_qdq.shape)
33
  print(o_qop.shape)
34
  print(o_qdq - o_qop)
 
1
  import torch
2
+ import torch.nn as nn
3
  from math_model import QuantConv2d
4
 
5
  torch.manual_seed(0)
6
 
7
  batch_size = 1
8
+ out_ch = 128
9
+ in_ch = 64
10
  k = 3
11
  h = 5
12
  w = 5
13
 
14
+ i = 2*torch.rand((batch_size,in_ch,h,w)) - 1.
15
+ l = nn.Conv2d(in_ch, out_ch, k, bias=True)
16
+
17
  quant_params = {
18
  'smoothquant_mul': torch.rand((in_ch,)),
19
  'smoothquant_mul_shape': (1,in_ch,1,1),
20
  'weight_scale': torch.rand((out_ch,)),
21
+ 'weight_scale': torch.max(torch.abs(torch.flatten(l.weight, start_dim=1)), dim=1).values / 128.,
22
  'weight_scale_shape': (out_ch,1,1,1),
23
+ 'weight_zp': torch.clamp(torch.round((torch.mean((l.weight), dim=(1,2,3))) * (128 / torch.max(torch.abs(torch.flatten(l.weight, start_dim=1)), dim=1).values)) + 128, 0, 255),
24
  'weight_zp_shape': (out_ch,1,1,1),
25
+ 'input_scale': torch.max(torch.abs(i)) / 128.,
26
  'input_scale_shape': tuple(),
27
  'input_zp': torch.zeros((1,)),
28
  'input_zp_shape': tuple(),
 
30
 
31
  print(quant_params)
32
 
33
+ ql = QuantConv2d(in_ch, out_ch, k, quant_params)
34
+ ql.conv2d.load_state_dict(l.state_dict())
35
+ o_qdq = ql(i)
36
+ o_qop = ql(i, qop=True)
37
  print(o_qdq.shape)
38
  print(o_qop.shape)
39
  print(o_qdq - o_qop)
test_quant_linear.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
  from math_model import QuantLinear
3
 
4
  torch.manual_seed(0)
@@ -7,14 +8,17 @@ batch_size = 1
7
  out_ch = 128
8
  in_ch = 64
9
 
 
 
 
10
  quant_params = {
11
  'smoothquant_mul': torch.rand((in_ch,)),
12
  'smoothquant_mul_shape': (1,in_ch),
13
- 'weight_scale': torch.rand((out_ch,)),
14
  'weight_scale_shape': (out_ch,1),
15
- 'weight_zp': torch.randint(-255, 0, (out_ch,)),
16
  'weight_zp_shape': (out_ch,1),
17
- 'input_scale': torch.rand((1,)),
18
  'input_scale_shape': tuple(),
19
  'input_zp': torch.zeros((1,)),
20
  'input_zp_shape': tuple(),
@@ -22,10 +26,10 @@ quant_params = {
22
 
23
  print(quant_params)
24
 
25
- l = QuantLinear(in_ch, out_ch, quant_params)
26
- i = torch.rand((batch_size,in_ch))
27
- o_qdq = l(i)
28
- o_qop = l(i, qop=True)
29
  print(o_qdq.shape)
30
  print(o_qop.shape)
31
  print(o_qdq - o_qop)
 
1
  import torch
2
+ import torch.nn as nn
3
  from math_model import QuantLinear
4
 
5
  torch.manual_seed(0)
 
8
  out_ch = 128
9
  in_ch = 64
10
 
11
+ i = 2*torch.rand((batch_size,in_ch)) - 1.
12
+ l = nn.Linear(in_ch, out_ch, bias=True)
13
+
14
  quant_params = {
15
  'smoothquant_mul': torch.rand((in_ch,)),
16
  'smoothquant_mul_shape': (1,in_ch),
17
+ 'weight_scale': torch.max(torch.abs(l.weight), dim=1).values / 128.,
18
  'weight_scale_shape': (out_ch,1),
19
+ 'weight_zp': torch.clamp(torch.round((torch.mean((l.weight), dim=1)) * (128 / torch.max(torch.abs(l.weight), dim=1).values)) + 128, 0, 255),
20
  'weight_zp_shape': (out_ch,1),
21
+ 'input_scale': torch.max(torch.abs(i)) / 128.,
22
  'input_scale_shape': tuple(),
23
  'input_zp': torch.zeros((1,)),
24
  'input_zp_shape': tuple(),
 
26
 
27
  print(quant_params)
28
 
29
+ ql = QuantLinear(in_ch, out_ch, quant_params)
30
+ ql.linear.load_state_dict(l.state_dict())
31
+ o_qdq = ql(i)
32
+ o_qop = ql(i, qop=True)
33
  print(o_qdq.shape)
34
  print(o_qop.shape)
35
  print(o_qdq - o_qop)