File size: 6,574 Bytes
0e4f45d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
import numpy as np
import torch
import torch.nn.functional as NF
from torch_scatter import scatter_mean
class FeatureBank:
    """Per-object memory bank of key/value feature columns with budgeted eviction.

    Each object ``class_idx`` in ``range(obj_n)`` owns:

    * ``keys[class_idx]``: a 2-D ``(d_key, bank_n)`` tensor, one column per entry;
    * ``values[class_idx]``: a 2-D ``(d_val, bank_n)`` tensor;
    * ``info[class_idx]``: a ``(bank_n, 2)`` tensor.
      - Column 0 is the frame index at which the entry was stored.
      - Column 1 is a per-entry score that ``remove`` divides by the entry's age.
        It is presumably an access count maintained by code outside this class;
        confirm.

    ``peak_n`` / ``replace_n`` record, per object, the largest bank size seen and
    the total number of entries evicted by ``remove``.
    """

    def __init__(self, obj_n, memory_budget, device, update_rate=0.1, thres_close=0.95):
        """Create an empty bank.

        Args:
            obj_n: number of objects; one bank is kept per object.
            memory_budget: total number of entries, split evenly across objects.
            device: torch device on which ``info`` tensors and scatter buffers live.
            update_rate: default blending weight used by ``update``.
            thres_close: cosine-similarity threshold. A new feature above it is
                merged into its best-matching entry instead of being appended.
        """
        self.obj_n = obj_n
        self.update_rate = update_rate
        self.thres_close = thres_close
        self.device = device

        self.info = [None for _ in range(obj_n)]
        self.peak_n = np.zeros(obj_n)
        self.replace_n = np.zeros(obj_n)

        self.class_budget = memory_budget // obj_n
        # NOTE(review): with exactly two objects the per-object budget is cut to 80%
        # (and becomes a float). The reason is not visible here; confirm intent.
        if obj_n == 2:
            self.class_budget = 0.8 * self.class_budget

        self.keys = None
        self.values = None

    def _new_info(self, n, frame_idx, usage=0.0):
        """Build ``n`` info rows stamped with ``frame_idx`` and an initial score ``usage``."""
        info = torch.full((n, 2), float(usage), device=self.device)
        info[:, 0] = frame_idx
        return info

    def _track_peak(self, class_idx):
        """Record the current size of ``class_idx``'s bank if it is a new maximum."""
        self.peak_n[class_idx] = max(self.peak_n[class_idx], self.info[class_idx].shape[0])

    def init_bank(self, keys, values, frame_idx=0):
        """(Re)initialise every per-object bank from ``keys`` / ``values``.

        Args:
            keys: sequence of ``obj_n`` 2-D ``(d_key, n)`` tensors.
            values: sequence of ``obj_n`` 2-D ``(d_val, n)`` tensors.
            frame_idx: frame index stamped on every initial entry (score 0).

        The list containers are copied, so later ``append`` / ``update`` calls do
        not rebind entries of the caller's lists. The tensors themselves are stored
        by reference.
        """
        self.keys = list(keys)
        self.values = list(values)

        for class_idx in range(self.obj_n):
            _, bank_n = self.keys[class_idx].shape
            self.info[class_idx] = self._new_info(bank_n, frame_idx)
            self._track_peak(class_idx)

    def append(self, keys, values, frame_idx=0):
        """Concatenate new columns onto every bank, or initialise it if still empty.

        Appended entries start with a score of 20, rather than the 0 used by
        ``init_bank`` / ``update``. This presumably protects them from immediate
        eviction in ``remove``; confirm intent.
        """
        if self.keys is None:
            self.init_bank(keys, values, frame_idx)
            return

        for class_idx in range(self.obj_n):
            self.keys[class_idx] = torch.cat([self.keys[class_idx], keys[class_idx]], dim=1)
            self.values[class_idx] = torch.cat([self.values[class_idx], values[class_idx]], dim=1)
            _, bank_n = keys[class_idx].shape
            new_info = self._new_info(bank_n, frame_idx, usage=20.0)
            self.info[class_idx] = torch.cat([self.info[class_idx], new_info], dim=0)
            self._track_peak(class_idx)

    def _blend(self, bank, new_feats, target_idx, unique_idx, update_rate):
        """Blend matched new features into ``bank`` columns in place.

        ``new_feats[:, j]`` belongs to bank column ``target_idx[j]``. All new
        features mapped to the same column are L2-normalised and averaged. Each
        touched column (``unique_idx``) then becomes::

            |old| * ((1 - update_rate) * old/|old| + update_rate * mean_new)

        so the entry keeps its original magnitude.
        """
        dim, bank_n = bank.shape
        normed_bank = NF.normalize(bank, dim=0)
        mag_bank = bank.norm(p=2, dim=0)
        normed_new = NF.normalize(new_feats, dim=0)

        # The buffer follows the input dtype: scatter_add requires src and out to match.
        bank_update = torch.zeros((dim, bank_n), dtype=normed_new.dtype, device=self.device)
        scatter_idx = target_idx.unsqueeze(0).expand(dim, -1)
        scatter_mean(normed_new, scatter_idx, dim=1, out=bank_update)

        bank[:, unique_idx] = mag_bank[unique_idx] * (
            (1 - update_rate) * normed_bank[:, unique_idx]
            + update_rate * bank_update[:, unique_idx]
        )

    def update(self, prev_key, prev_value, frame_idx, update_rate=-1):
        """Fold one frame's features into every per-object bank.

        For each object, every column of ``prev_key[class_idx]`` (shape
        ``(d_key, n)``) is matched to its most cosine-similar bank key.

        * Columns with similarity ``> thres_close`` are blended into the matched
          entries (keys and values alike, see ``_blend``).
        * The remaining columns are appended as new entries with score 0.

        If appending would exceed ``class_budget``, ``remove`` evicts low-score
        entries first. If ``remove`` cannot free enough room, the bank grows past
        the budget.

        Args:
            prev_key: sequence of ``obj_n`` ``(d_key, n)`` tensors.
            prev_value: sequence of ``obj_n`` ``(d_val, n)`` tensors.
            frame_idx: frame index stamped on newly appended entries.
            update_rate: blending weight for merged entries. ``-1`` means use
                ``self.update_rate``.
        """
        if update_rate == -1:
            update_rate = self.update_rate

        for class_idx in range(self.obj_n):
            cur_key = prev_key[class_idx]
            cur_value = prev_value[class_idx]
            bank_n = self.keys[class_idx].shape[1]

            # Cosine similarity of every bank key against every new key: (bank_n, n).
            corr = torch.mm(
                NF.normalize(self.keys[class_idx], dim=0).transpose(0, 1),
                NF.normalize(cur_key, dim=0),
            )
            related_bank_idx = corr.argmax(dim=0, keepdim=True)  # (1, n)
            related_bank_corr = torch.gather(corr, 0, related_bank_idx)[0]  # (n,)

            # Similar enough to an existing entry: merge into it.
            merge_cols = (related_bank_corr > self.thres_close).nonzero(as_tuple=True)[0]
            matched = related_bank_idx[0, merge_cols]
            unique_matched = matched.unique()
            self._blend(self.keys[class_idx], cur_key[:, merge_cols], matched, unique_matched, update_rate)
            self._blend(self.values[class_idx], cur_value[:, merge_cols], matched, unique_matched, update_rate)

            # Not similar enough: append as new entries, evicting first if over budget.
            new_cols = (related_bank_corr <= self.thres_close).nonzero(as_tuple=True)[0]
            n_new = new_cols.numel()
            if self.class_budget < bank_n + n_new:
                self.remove(class_idx, n_new, frame_idx)

            self.keys[class_idx] = torch.cat([self.keys[class_idx], cur_key[:, new_cols]], dim=1)
            self.values[class_idx] = torch.cat([self.values[class_idx], cur_value[:, new_cols]], dim=1)
            self.info[class_idx] = torch.cat(
                [self.info[class_idx], self._new_info(n_new, frame_idx)], dim=0
            )
            self._track_peak(class_idx)

            # Cap scores so the ratios computed in remove() stay finite.
            self.info[class_idx][:, 1] = torch.clamp(self.info[class_idx][:, 1], 0, 1e5)

    def remove(self, class_idx, request_n, frame_idx):
        """Evict low-score entries of ``class_idx`` to make room for ``request_n`` more.

        Each entry is ranked by ``score / age``, where ``age = frame_idx - stored_frame``.
        Every pass drops all entries whose ratio is ``<= int(min_ratio) + 1``.
        Passes repeat until ``class_budget - size - request_n >= 0`` or the bank is
        empty. At least one pass runs whenever the bank is non-empty.

        Returns:
            The final balance ``class_budget - size - request_n``. It is negative
            if even an empty bank cannot fit ``request_n`` entries.
        """
        old_size = self.keys[class_idx].shape[1]

        # Age is clamped to 1 frame. Entries stamped with the current frame would
        # otherwise divide by zero: NaN/inf ratios make int() raise.
        age = torch.clamp(frame_idx - self.info[class_idx][:, 0], min=1)
        lfu = self.info[class_idx][:, 1] / age

        balance = (self.class_budget - old_size) - request_n
        # Stop once the bank is empty: lfu.min() on an empty tensor raises.
        while lfu.numel() > 0:
            thres_dynamic = int(lfu.min()) + 1
            keep = lfu > thres_dynamic
            self.keys[class_idx] = self.keys[class_idx][:, keep]
            self.values[class_idx] = self.values[class_idx][:, keep]
            self.info[class_idx] = self.info[class_idx][keep]
            lfu = lfu[keep]

            balance = (self.class_budget - self.keys[class_idx].shape[1]) - request_n
            if balance >= 0:
                break

        new_size = self.keys[class_idx].shape[1]
        self.replace_n[class_idx] += old_size - new_size

        return balance

    def print_peak_mem(self):
        """Print the per-object peak utilisation and replacement ratios relative to the budget."""
        ur = self.peak_n / self.class_budget
        rr = self.replace_n / self.class_budget
        print(f'Obj num: {self.obj_n}.', f'Budget / obj: {self.class_budget}.', f'UR: {ur}.', f'Replace: {rr}.')
|