File size: 6,574 Bytes
0e4f45d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import numpy as np
import torch
import torch.nn.functional as NF

from torch_scatter import scatter_mean


class FeatureBank:
    """Per-object memory of key/value feature columns with similarity merging and eviction.

    For each of ``obj_n`` objects the bank keeps a key matrix and a value matrix whose columns
    are individual entries (2-D tensors, concatenated along ``dim=1``), plus an ``info`` tensor
    with one row per entry: column 0 holds the frame index at which the entry was added and
    column 1 holds a usage score (initialised here, clamped to [0, 1e5] by ``update``) that
    ``remove`` divides by the entry's age to decide what to evict. This class never increments
    the score itself; presumably the code that reads the bank does -- confirm against the caller.
    """

    def __init__(self, obj_n, memory_budget, device, update_rate=0.1, thres_close=0.95):
        """Create an empty bank.

        Args:
            obj_n: number of objects; each gets its own key/value/info slot.
            memory_budget: overall entry budget; each object may hold
                ``memory_budget // obj_n`` entries (80% of that when ``obj_n == 2``).
            device: torch device on which this class allocates its bookkeeping tensors.
            update_rate: default blend factor used by ``update`` when merging similar entries.
            thres_close: cosine-similarity threshold above which an incoming entry is merged
                into its most similar bank entry instead of being appended.
        """
        self.obj_n = obj_n
        self.update_rate = update_rate
        self.thres_close = thres_close
        self.device = device

        self.info = [None] * obj_n  # per object: (n_entries, 2) tensor [frame added, usage score]
        self.peak_n = np.zeros(obj_n)  # largest entry count observed per object
        self.replace_n = np.zeros(obj_n)  # total number of evicted entries per object

        self.class_budget = memory_budget // obj_n
        if obj_n == 2:
            # NOTE(review): the reason for the 80% cap with exactly two objects is not visible here.
            self.class_budget = 0.8 * self.class_budget

        self.keys = None  # list of per-object 2-D key tensors once initialised
        self.values = None  # list of per-object 2-D value tensors once initialised

    def _make_info(self, n, frame_idx, use_count=0.0):
        """Return an (n, 2) float info tensor with column 0 = frame_idx and column 1 = use_count."""
        info = torch.full((n, 2), float(use_count), device=self.device)
        info[:, 0] = frame_idx
        return info

    def init_bank(self, keys, values, frame_idx=0):
        """(Re)initialise the bank from per-object key/value tensors.

        Args:
            keys / values: per-object sequences of 2-D tensors (columns are entries).
            frame_idx: frame index recorded for every initial entry; usage scores start at 0.
        """
        # Copy the containers so later per-object reassignment does not rewrite the caller's
        # lists (and so tuples work too). The tensors themselves are still shared, so the
        # in-place column merge in ``update`` writes into them while they remain in the bank.
        self.keys = list(keys)
        self.values = list(values)

        for class_idx in range(self.obj_n):
            _, bank_n = self.keys[class_idx].shape
            self.info[class_idx] = self._make_info(bank_n, frame_idx)
            self.peak_n[class_idx] = max(self.peak_n[class_idx], bank_n)

    def append(self, keys, values, frame_idx=0):
        """Append entries for every object without merging or budget checks.

        Falls back to ``init_bank`` when the bank is still empty. Appended entries start with a
        usage score of 20 (``update`` uses 0), which makes them score higher in ``remove``'s
        usage-per-age ranking and thus survive eviction longer.
        """
        if self.keys is None:
            self.init_bank(keys, values, frame_idx)
            return

        for class_idx in range(self.obj_n):
            self.keys[class_idx] = torch.cat([self.keys[class_idx], keys[class_idx]], dim=1)
            self.values[class_idx] = torch.cat([self.values[class_idx], values[class_idx]], dim=1)

            _, new_n = keys[class_idx].shape
            new_info = self._make_info(new_n, frame_idx, use_count=20)
            self.info[class_idx] = torch.cat([self.info[class_idx], new_info], dim=0)
            self.peak_n[class_idx] = max(self.peak_n[class_idx], self.info[class_idx].shape[0])

    def _merge_into(self, bank, incoming, cols, target_idx, unique_target_idx, update_rate):
        """Blend ``incoming[:, cols]`` into the bank columns ``target_idx`` points at, in place.

        Each targeted bank column ``b`` becomes
        ``|b| * ((1 - update_rate) * b / |b| + update_rate * m)``, where ``m`` is the mean of the
        unit-normalised incoming columns mapped to ``b``.
        """
        dim, bank_n = bank.shape
        normed_bank = NF.normalize(bank, dim=0)
        normed_incoming = NF.normalize(incoming, dim=0)
        magnitudes = bank.norm(p=2, dim=0)

        # Per-bank-column mean of the matched incoming columns (zero for untouched columns).
        # The buffer uses the source dtype because scatter_add_ requires matching dtypes.
        mean_incoming = torch.zeros((dim, bank_n), dtype=normed_incoming.dtype, device=self.device)
        scatter_mean(normed_incoming[:, cols], target_idx.unsqueeze(0).expand(dim, -1),
                     dim=1, out=mean_incoming)

        bank[:, unique_target_idx] = magnitudes[unique_target_idx] * (
            (1 - update_rate) * normed_bank[:, unique_target_idx]
            + update_rate * mean_incoming[:, unique_target_idx])

    def update(self, prev_key, prev_value, frame_idx, update_rate=-1):
        """Fold a new frame's key/value entries into the bank.

        For every object, each incoming key column is matched to its most similar bank key by
        cosine similarity. Incoming entries whose best similarity exceeds ``thres_close`` are
        blended into the matched bank key/value columns. The rest are appended with a usage
        score of 0, after calling ``remove`` if appending would exceed the per-object budget.
        Finally all usage scores are clamped to [0, 1e5].

        Args:
            prev_key / prev_value: per-object sequences of 2-D tensors indexed like the bank;
                columns are entries (presumably with the bank's row dimension -- the
                concatenation requires it).
            frame_idx: current frame index; recorded for appended entries and used as "now"
                when ``remove`` computes entry ages.
            update_rate: blend factor for merged entries; -1 (default) means ``self.update_rate``.
        """
        if update_rate == -1:
            update_rate = self.update_rate

        for class_idx in range(self.obj_n):
            _, bank_n = self.keys[class_idx].shape
            new_keys = prev_key[class_idx]
            new_values = prev_value[class_idx]

            # Cosine similarity between every bank key and every incoming key.
            normed_keys = NF.normalize(self.keys[class_idx], dim=0)
            normed_new_keys = NF.normalize(new_keys, dim=0)
            corr = torch.mm(normed_keys.transpose(0, 1), normed_new_keys)  # bank_n, new_n
            related_bank_idx = corr.argmax(dim=0, keepdim=True)  # 1, new_n
            related_bank_corr = torch.gather(corr, 0, related_bank_idx)  # 1, new_n

            # Above the threshold: merge into the most similar bank entry.
            close_cols = (related_bank_corr[0] > self.thres_close).nonzero(as_tuple=False)[:, 0]
            target_idx = related_bank_idx[0, close_cols]
            unique_target_idx = target_idx.unique(dim=0)
            self._merge_into(self.keys[class_idx], new_keys, close_cols,
                             target_idx, unique_target_idx, update_rate)
            self._merge_into(self.values[class_idx], new_values, close_cols,
                             target_idx, unique_target_idx, update_rate)

            # At or below the threshold: append, evicting first if over budget.
            far_cols = (related_bank_corr[0] <= self.thres_close).nonzero(as_tuple=False)[:, 0]
            if self.class_budget < bank_n + far_cols.shape[0]:
                self.remove(class_idx, far_cols.shape[0], frame_idx)

            self.keys[class_idx] = torch.cat([self.keys[class_idx], new_keys[:, far_cols]], dim=1)
            self.values[class_idx] = torch.cat(
                [self.values[class_idx], new_values[:, far_cols]], dim=1)
            new_info = self._make_info(far_cols.shape[0], frame_idx)
            self.info[class_idx] = torch.cat([self.info[class_idx], new_info], dim=0)

            self.peak_n[class_idx] = max(self.peak_n[class_idx], self.info[class_idx].shape[0])

            self.info[class_idx][:, 1] = torch.clamp(self.info[class_idx][:, 1], 0, 1e5)  # Prevent inf

    def remove(self, class_idx, request_n, frame_idx):
        """Evict low-usage entries of one object to make room for ``request_n`` new entries.

        Entries are scored by usage score / age (age = ``frame_idx`` - frame added). Each pass
        drops every entry whose score is <= floor(min score) + 1. Passes repeat until the
        request fits or nothing more can be ranked. Ranking stops when the bank is empty or the
        minimum score is infinite, e.g. only used entries of age 0 remain. A never-used entry of
        age 0 (0 / 0) is scored as 0.

        Returns:
            Remaining capacity after eviction minus ``request_n``; negative when the request
            still does not fit.
        """
        old_size = self.keys[class_idx].shape[1]

        age = frame_idx - self.info[class_idx][:, 0]
        lfu = self.info[class_idx][:, 1] / age
        lfu = lfu.masked_fill(torch.isnan(lfu), 0.0)

        balance = (self.class_budget - old_size) - request_n
        while lfu.numel() > 0:
            lowest = lfu.min()
            if not torch.isfinite(lowest):
                break  # int() cannot threshold on +/-inf; nothing rankable is left to evict
            thres_dynamic = int(lowest) + 1

            keep = lfu > thres_dynamic
            self.keys[class_idx] = self.keys[class_idx][:, keep]
            self.values[class_idx] = self.values[class_idx][:, keep]
            self.info[class_idx] = self.info[class_idx][keep]
            lfu = lfu[keep]

            balance = (self.class_budget - self.keys[class_idx].shape[1]) - request_n
            if balance >= 0:
                break

        new_size = self.keys[class_idx].shape[1]
        self.replace_n[class_idx] += old_size - new_size

        return balance

    def print_peak_mem(self):
        """Print per-object peak usage and eviction counts as fractions of ``class_budget``."""
        ur = self.peak_n / self.class_budget
        rr = self.replace_n / self.class_budget
        print(f'Obj num: {self.obj_n}.', f'Budget / obj: {self.class_budget}.', f'UR: {ur}.', f'Replace: {rr}.')