# Quantization reduces a bit representation to fewer bits for efficient storage or computation.
# Most floating point data types define a mapping from a bit pattern to a floating point value,
# e.g. for 4 bits: 0010 = 2 maps to 2 / max(1111) = 2/15 = 0.133333.
# As such, we can represent a floating point quantization as a mapping from integers to
# floating point values, e.g. for 2 bits: [0, 1, 2, 3] -> [-1.0, -0.25, 0.25, 1.0]
import numpy as np
from scipy.spatial.distance import cdist
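
# A quick check of the arithmetic in the header comment (pure Python, no libraries needed):
print('0b0010 / 0b1111 =', 0b0010 / 0b1111)  # 2/15 ≈ 0.133333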

index = np.array([0, 1, 2, 3, 4, 5, 6, 7])
values = np.linspace(-1.0, 1.0, 8) # 3-bit linear quantization
print('quantization values:', values)

# To quantize an input distribution we first need to normalize its range into the range of the
# quantization values, in this case [-1.0, 1.0].
# We can do this through division by the absolute maximum value if our distribution is roughly
# symmetric (most distributions in deep learning are approximately normally distributed).

rand_inputs = np.random.randn(1024, 1024).astype(np.float32)

absmax = np.max(np.abs(rand_inputs))
normed = rand_inputs / absmax
print('normalized min and max range', np.min(normed), np.max(normed))

# The next step is to round each input value to the closest quantization value.
# This could be done with a binary search of each element of the normalized input tensor
# against the sorted values array; here we simply compute the distances to all quantization
# values and pick the closest one directly.

dist = cdist(normed.flatten().reshape(-1, 1), values.reshape(-1, 1))
closest_idx = np.argmin(dist, 1).reshape(rand_inputs.shape)

val, count = np.unique(closest_idx, return_counts=True)
print('Values:', val)
print('Count:', count)
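
# As an aside, the binary-search variant mentioned above can be sketched with
# np.searchsorted: bucketing each element by the midpoints between adjacent (sorted)
# quantization values is equivalent to rounding to the nearest value.
midpoints = (values[:-1] + values[1:]) / 2.0
closest_idx_bs = np.searchsorted(midpoints, normed.flatten()).reshape(rand_inputs.shape)
print('binary search agrees with distance search:', np.array_equal(closest_idx_bs, closest_idx))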

# closest_idx now holds the quantized 3-bit representation (8 different values). We can use this
# representation to store the data efficiently.


# ==================DEQUANTIZATION========================
# To dequantize the tensor we reverse the operations that we did:
# 1. look up the values corresponding to the 3-bit indices
# 2. denormalize by multiplying by absmax

dequant = values[closest_idx]*absmax
# mean absolute error:
error = np.abs(dequant-rand_inputs).mean()
print(f'Absolute linear 3-bit quantization error: {error:.4f}')

# This yields a mean absolute error of about 0.34 per value. We can do better with non-linear quantization.

# ==================NON-LINEAR QUANTIZATION========================
# In non-linear quantization the spacing between quantization values is not uniform.
# This allows us to allocate more values to regions of high density. For example, the normal
# distribution has most of its mass around 0, so we place more quantization values there.
# This can reduce the overall quantization error.
index = np.array([0, 1, 2, 3, 4, 5, 6, 7])
values = np.array([-1.0, -0.5, -0.25, -0.075, 0.075, 0.25, 0.5, 1.0])

dist = cdist(normed.flatten().reshape(-1, 1), values.reshape(-1, 1))
closest_idx = np.argmin(dist, 1).reshape(rand_inputs.shape)

val, count = np.unique(closest_idx, return_counts=True)
print('Values:', val)
print('Count:', count)

dequant = values[closest_idx]*absmax
error = np.abs(dequant-rand_inputs).mean()
print(f'Absolute non-linear 3-bit quantization error: {error:.4f}')
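
# As an aside, one principled way to construct such a non-linear codebook (a sketch of
# quantile-based quantization, not how the hand-picked values above were chosen) is to
# place the values at evenly spaced quantiles of the target distribution, here the
# standard normal:
from scipy.stats import norm
quantile_values = norm.ppf(np.linspace(0.5/8, 1 - 0.5/8, 8))
quantile_values /= np.abs(quantile_values).max()  # rescale into [-1, 1]
print('quantile-based 3-bit codebook:', quantile_values)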

# ==================DYNAMIC QUANTIZATION========================
# Adapted from: https://github.com/facebookresearch/bitsandbytes/blob/main/bitsandbytes/functional.py
def create_dynamic_map(signed=True, n=7):
    '''
    Creates the dynamic quantization map.
    The dynamic data type is made up of a dynamic exponent and
    fraction. As the exponent increases from 0 to -7, the number
    of bits available for the fraction shrinks.
    This is a generalization of the dynamic type where a certain
    number of bits can be reserved for the linear quantization
    region (the fraction). n determines the maximum number of
    exponent bits.
    For more details see
    [8-Bit Approximations for Parallelism in Deep Learning](https://arxiv.org/abs/1511.04561)
    '''

    data = []
    # these are additional items that come from the case
    # where all the exponent bits are zero and no
    # indicator bit is present
    additional_items = 2**(7-n)-1
    if not signed: additional_items = 2*additional_items
    for i in range(n):
        fraction_items = 2**(i+7-n)+1 if signed else 2**(i+7-n+1)+1
        boundaries = np.linspace(0.1, 1, fraction_items)
        means = (boundaries[:-1]+boundaries[1:])/2.0
        data += ((10**(-(n-1)+i))*means).tolist()
        if signed:
            data += (-(10**(-(n-1)+i))*means).tolist()

    if additional_items > 0:
        boundaries = np.linspace(0.1, 1, additional_items+1)
        means = (boundaries[:-1]+boundaries[1:])/2.0
        data += ((10**(-(n-1)+i))*means).tolist()
        if signed:
            data += (-(10**(-(n-1)+i))*means).tolist()

    data.append(0)
    data.append(1.0)
    data.sort()
    return np.array(data)

import time

values = create_dynamic_map(signed=True)
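# the signed dynamic map fills all 2^8 = 256 states of one byte
print('number of dynamic quantization values:', values.size)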

t0 = time.time()
dist = cdist(normed.flatten().reshape(-1, 1), values.reshape(-1, 1))
closest_idx = np.argmin(dist, 1).reshape(rand_inputs.shape)
quant_time = time.time()-t0

dequant = values[closest_idx]*absmax
error = np.abs(dequant-rand_inputs).mean()
print(f'Absolute dynamic 8-bit quantization error: {error:.4f}')
print(f'Total time taken: {quant_time:.4f} seconds.')

# This yields an error as low as 0.012. We can do even better with block-wise quantization.

# ==================BLOCK-WISE QUANTIZATION========================
# In block-wise quantization, each block of values gets its own absmax normalization constant,
# so outliers only affect the precision of their own block.
# Performing block-wise quantization without optimized code is slow, so we use the
# bitsandbytes library to do it quickly.
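
# A minimal, unoptimized numpy sketch of this idea (block size 4096 is an illustrative
# choice; the optimized bitsandbytes kernels below may differ in details):
blocks = rand_inputs.reshape(-1, 4096)
block_absmax = np.abs(blocks).max(axis=1, keepdims=True)  # one scale per block
midpoints = (values[:-1] + values[1:]) / 2.0
block_idx = np.searchsorted(midpoints, blocks / block_absmax)
block_dequant = values[block_idx] * block_absmax
print(f'Block-wise dynamic 8-bit error (numpy sketch): {np.abs(block_dequant - blocks).mean():.4f}')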

import torch
import bitsandbytes.functional as F

rand_inputs = torch.from_numpy(rand_inputs)
t0 = time.time()
quant_values, quant_state = F.quantize_blockwise(rand_inputs)
quant_time = time.time()-t0
dequant_values = F.dequantize_blockwise(quant_values, quant_state)

error = torch.abs(dequant_values-rand_inputs).mean().item()
print(f'Absolute dynamic block-wise 8-bit quantization error: {error:.4f}')
print(f'Total time taken (CPU): {quant_time:.4f} seconds.')

rand_inputs = rand_inputs.cuda()
torch.cuda.synchronize()
t0 = time.time()
quant_values, quant_state = F.quantize_blockwise(rand_inputs)
torch.cuda.synchronize()  # CUDA kernels launch asynchronously; synchronize before stopping the clock
quant_time = time.time()-t0
print(f'Total time taken (GPU): {quant_time:.4f} seconds.')