# Copyright 2024 Flash-VStream Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import torch
import torch.nn.functional as F


def drop_feature(img_feature, video_max_frames, img_similarity=None):
    """Compress a [T, P, D] frame-feature stream to T0 frames by repeatedly
    dropping one frame of the most similar adjacent pair."""
    T, P, D = img_feature.shape
    indices = [[i] for i in range(T)]
    T0 = video_max_frames
    if T <= T0:
        return img_feature, img_similarity, [indices]
    cur_feature = img_feature[:T0]  # [T0, P, D]
    if img_similarity is not None:
        cur_sim = img_similarity[:T0 - 1]
    else:
        # cosine similarity between each pair of temporally adjacent frames
        cur_sim = F.cosine_similarity(cur_feature[:-1].view(T0 - 1, P * D),
                                      cur_feature[1:].view(T0 - 1, P * D))  # [T0 - 1]
    cur_indices = indices[:T0]
    step_indices = [cur_indices]
    for i in range(T0, T):
        # append the incoming frame, then evict one frame to stay within T0
        new_feature = img_feature[i]
        new_sim = F.cosine_similarity(cur_feature[-1].view(-1), new_feature.view(-1), dim=0)
        all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
        all_indices = cur_indices + [[i]]
        all_sim = torch.cat([cur_sim, new_sim.unsqueeze(0)], dim=0)
        # the most similar adjacent pair is (idx, idx + 1); drop either member at random
        idx = torch.argmax(all_sim)
        if random.randint(0, 1) > 0:
            idx = idx + 1
        cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
        if idx == T0:
            # dropped the newly appended frame
            cur_sim = all_sim[:T0 - 1]
            cur_indices = all_indices[:-1]
        elif idx == 0:
            # dropped the oldest frame
            cur_sim = all_sim[1:]
            cur_indices = all_indices[1:]
        else:
            # dropped an interior frame: its two neighbors become adjacent
            cur_sim = torch.cat([all_sim[:idx], all_sim[idx + 1:]])
            cur_sim[idx - 1] = F.cosine_similarity(all_feature[idx - 1].view(-1),
                                                   all_feature[idx + 1].view(-1), dim=0)
            cur_indices = all_indices[:idx] + all_indices[idx + 1:]
        step_indices.append(cur_indices)
    # print(f'Note: perform drop feature {img_feature.shape} to {cur_feature.shape}')
    return cur_feature, cur_sim, step_indices


def merge_feature(img_feature, video_max_frames, img_similarity=None):
    """Compress a [T, P, D] frame-feature stream to T0 frames by repeatedly
    averaging the most similar adjacent pair."""
    T, P, D = img_feature.shape
    indices = [[i] for i in range(T)]
    T0 = video_max_frames
    if T <= T0:
        return img_feature, img_similarity, [indices]
    cur_feature = img_feature[:T0]  # [T0, P, D]
    cur_indices = indices[:T0]
    step_indices = [cur_indices]
    if img_similarity is not None:
        cur_sim = img_similarity[:T0 - 1]
    else:
        cur_sim = F.cosine_similarity(cur_feature[:-1].view(T0 - 1, P * D),
                                      cur_feature[1:].view(T0 - 1, P * D))  # [T0 - 1]
    for i in range(T0, T):
        new_feature = img_feature[i]
        new_sim = F.cosine_similarity(cur_feature[-1].view(-1), new_feature.view(-1), dim=0)
        all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
        all_sim = torch.cat([cur_sim, new_sim.unsqueeze(0)], dim=0)
        all_indices = cur_indices + [[i]]
        # merge the most similar adjacent pair (idx, idx + 1) into slot idx + 1
        idx = torch.argmax(all_sim)
        all_feature[idx + 1] = (all_feature[idx] + all_feature[idx + 1]) / 2.0
        all_indices[idx + 1] = all_indices[idx] + all_indices[idx + 1]
        cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
        cur_sim = torch.cat([all_sim[:idx], all_sim[idx + 1:]])
        cur_indices = all_indices[:idx] + all_indices[idx + 1:]
        # refresh the similarities touching the merged frame
        if idx > 0:
            cur_sim[idx - 1] = F.cosine_similarity(all_feature[idx - 1].view(-1),
                                                   all_feature[idx + 1].view(-1), dim=0)
        if idx + 1 < T0:
            cur_sim[idx] = F.cosine_similarity(all_feature[idx + 1].view(-1),
                                               all_feature[idx + 2].view(-1), dim=0)
        step_indices.append(cur_indices)
    # print(f'Note: perform merge feature {img_feature.shape} to {cur_feature.shape}')
    return cur_feature, cur_sim, step_indices
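

# Illustrative usage sketch, not part of the original file. The shapes below
# are assumptions for demonstration: T = 8 frames of P = 4 patch tokens with
# D = 16 channels, compressed to a budget of video_max_frames = 4.
def _demo_drop_merge():
    feats = torch.randn(8, 4, 16)
    kept, kept_sim, history = drop_feature(feats, video_max_frames=4)
    # kept: [4, 4, 16]; history[-1] maps each kept slot to its source frame(s)
    merged, merged_sim, history = merge_feature(feats, video_max_frames=4)
    # merge_feature averages instead of dropping, so history[-1] entries may
    # list several source frames per retained slot
    return kept, merged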


def kmeans_feature(img_feature, video_max_frames, img_similarity=None):
    """Compress a [T, P, D] frame-feature stream to T0 frames with k-means,
    treating each frame as one point and returning the centroids."""
    def kmeans_torch(X, num_clusters, distance='euclidean', tol=1e-4, max_iter=10):
        indices = torch.randperm(X.size(0))[:num_clusters]
        centroids = X[indices]
        for i in range(max_iter):
            if distance == 'euclidean':
                dists = torch.cdist(X, centroids, p=2)
            else:
                raise NotImplementedError("Only Euclidean distance is supported for now")
            labels = torch.argmin(dists, dim=1)
            new_centroids = []
            for j in range(num_clusters):
                cluster_points = X[labels == j]
                if len(cluster_points) > 0:
                    new_centroid = cluster_points.mean(0)
                else:
                    # fix nan centroids: reseed an empty cluster with a random point
                    new_centroid = X[random.randint(0, X.size(0) - 1)]
                new_centroids.append(new_centroid)
            new_centroids = torch.stack(new_centroids)
            diff = torch.norm(centroids - new_centroids, dim=1).sum()
            if diff < tol:
                break
            centroids = new_centroids
        return centroids, labels, i

    T, P, D = img_feature.shape
    T0 = video_max_frames
    if T <= T0:
        return img_feature, img_similarity, [[[i] for i in range(T)]]
    X = img_feature.view(T, -1)  # [T, P * D]
    centroids, labels, exit_step = kmeans_torch(X, T0)
    reduced_feature = centroids.view(T0, P, D)
    # print(f'Note: perform kmeans feature {img_feature.shape} to {reduced_feature.shape}, exit at step={exit_step}')
    # K = T0, so each retained slot corresponds to one cluster
    step_indices = [[] for _ in range(T0)]
    for i in range(T0):
        step_indices[i] = [j for j in range(T) if labels[j] == i]
    return reduced_feature, img_similarity, [step_indices]


def weighted_kmeans_feature(img_feature, video_max_frames, weights=None):
    """Weighted k-means variant: each frame carries a weight (e.g. how many
    source frames it already represents) and centroids are weighted means."""
    if weights is None:
        weights = torch.ones(img_feature.size(0), dtype=img_feature.dtype, device=img_feature.device)

    def weighted_kmeans_torch(X, num_clusters, weights=None, distance='euclidean', tol=1e-4, max_iter=10):
        indices = torch.randperm(X.size(0), device=X.device)[:num_clusters]
        centroids = X[indices]
        for i in range(max_iter):
            if distance == 'euclidean':
                dists = ((X.unsqueeze(1) - centroids.unsqueeze(0)) ** 2).sum(dim=2).sqrt()
            else:
                raise NotImplementedError("Only Euclidean distance is supported for now")
            labels = torch.argmin(dists, dim=1)
            weighted_sum = torch.zeros_like(centroids)
            weights_sum = torch.zeros(num_clusters, dtype=X.dtype, device=X.device)
            for j in range(num_clusters):
                cluster_mask = labels == j
                weighted_sum[j] = torch.sum(weights[cluster_mask, None] * X[cluster_mask], dim=0)
                weights_sum[j] = torch.sum(weights[cluster_mask])
            mask = weights_sum > 0
            new_centroids = torch.zeros_like(weighted_sum)
            new_centroids[mask] = weighted_sum[mask] / weights_sum[mask, None]
            if mask.sum() < num_clusters:
                # fix nan centroids: reseed empty clusters with random points
                new_centroids[~mask] = torch.stack([X[random.randint(0, X.size(0) - 1)]
                                                    for _ in range(num_clusters - mask.sum())])
            diff = torch.norm(centroids - new_centroids, dim=1).sum()
            if diff < tol:
                break
            centroids = new_centroids
        return centroids, labels, weights_sum, i

    T, P, D = img_feature.shape
    T0 = video_max_frames
    if T <= T0:
        return img_feature, weights, [[[i] for i in range(T)]]
    X = img_feature.view(T, -1)  # [T, P * D]
    centroids, labels, weights, exit_step = weighted_kmeans_torch(X, T0, weights)
    reduced_feature = centroids.view(T0, P, D)
    # print(f'Note: perform weighted kmeans feature {img_feature.shape} to {reduced_feature.shape}, exit at step={exit_step}')
    # K = T0, so each retained slot corresponds to one cluster
    step_indices = [[] for _ in range(T0)]
    for i in range(T0):
        step_indices[i] = [j for j in range(T) if labels[j] == i]
    return reduced_feature, weights, [step_indices]
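

# Illustrative usage sketch with assumed shapes, not from the original repo.
# The k-means compressors map a stream of T frames to exactly T0 centroid
# frames; the returned groups give cluster membership rather than a per-step
# eviction history.
def _demo_kmeans():
    feats = torch.randn(10, 4, 16)
    reduced, _, groups = kmeans_feature(feats, video_max_frames=4)
    assert reduced.shape == (4, 4, 16)  # one centroid per retained slot
    # The weighted variant returns per-cluster weight sums, which can be fed
    # back in as `weights` when compressing a growing stream incrementally.
    reduced, cluster_weights, groups = weighted_kmeans_feature(feats, video_max_frames=4)
    return reduced, cluster_weights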


def k_drop_feature(img_feature, video_max_frames, img_similarity=None):
    """Global variant of drop_feature: maintain the full [T0, T0] cosine
    similarity matrix and drop one frame of the most similar pair overall."""
    T, P, D = img_feature.shape
    indices = [[i] for i in range(T)]
    T0 = video_max_frames
    if T <= T0:
        return img_feature, img_similarity, [indices]
    cur_feature = img_feature[:T0]  # [T0, P, D]
    normed_cur_features = F.normalize(cur_feature.view(T0, P * D), p=2, dim=1)
    cur_sim = torch.mm(normed_cur_features, normed_cur_features.T)  # [T0, T0]
    cur_sim.fill_diagonal_(-100.0)  # mask self-similarity so argmax picks a real pair
    cur_indices = indices[:T0]
    step_indices = [cur_indices]
    for i in range(T0, T):
        # get new feature
        new_feature = img_feature[i]
        normed_new_feature = F.normalize(new_feature.view(1, P * D), p=2, dim=1)
        new_sim = torch.mm(normed_cur_features, normed_new_feature.T)  # [T0, 1]
        all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
        normed_all_features = torch.cat([normed_cur_features, normed_new_feature], dim=0)
        all_indices = cur_indices + [[i]]
        # get new similarity
        all_sim_1 = torch.cat([cur_sim, new_sim], dim=1)  # [T0, T0 + 1]
        all_sim = torch.cat([all_sim_1, torch.ones_like(all_sim_1[-1:]) * -100.0], dim=0)  # [T0 + 1, T0 + 1]
        all_sim[-1, :-1] = new_sim.view(-1)
        # choose compression position: the flat argmax encodes the pair (left, right)
        idx = torch.argmax(all_sim)
        left, right = idx // (T0 + 1), idx % (T0 + 1)
        assert all_sim[left, right] == torch.max(all_sim)
        # drop either member of the most similar pair at random
        if random.randint(0, 1) > 0:
            idx = left
        else:
            idx = right
        # get compressed feature and similarity
        cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
        normed_cur_features = torch.cat([normed_all_features[:idx], normed_all_features[idx + 1:]])
        cur_indices = all_indices[:idx] + all_indices[idx + 1:]
        cur_sim_1 = torch.cat([all_sim[:idx], all_sim[idx + 1:]], dim=0)  # [T0, T0 + 1]
        cur_sim = torch.cat([cur_sim_1[:, :idx], cur_sim_1[:, idx + 1:]], dim=1)  # [T0, T0]
        step_indices.append(cur_indices)
    # print(f'Note: perform k-drop feature {img_feature.shape} to {cur_feature.shape}')
    return cur_feature, None, step_indices


def k_merge_feature(img_feature, video_max_frames, img_similarity=None):
    """Global variant of merge_feature: maintain the full [T0, T0] cosine
    similarity matrix and average the most similar pair overall."""
    T, P, D = img_feature.shape
    indices = [[i] for i in range(T)]
    T0 = video_max_frames
    if T <= T0:
        return img_feature, img_similarity, [indices]
    cur_feature = img_feature[:T0]  # [T0, P, D]
    normed_cur_features = F.normalize(cur_feature.view(T0, P * D), p=2, dim=1)
    cur_sim = torch.mm(normed_cur_features, normed_cur_features.T)  # [T0, T0]
    cur_sim.fill_diagonal_(-100.0)
    cur_indices = indices[:T0]
    step_indices = [cur_indices]
    for i in range(T0, T):
        # get new feature
        new_feature = img_feature[i]
        normed_new_feature = F.normalize(new_feature.view(1, P * D), p=2, dim=1)
        new_sim = torch.mm(normed_cur_features, normed_new_feature.T)  # [T0, 1]
        all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
        normed_all_features = torch.cat([normed_cur_features, normed_new_feature], dim=0)
        all_indices = cur_indices + [[i]]
        # get new similarity
        all_sim_1 = torch.cat([cur_sim, new_sim], dim=1)  # [T0, T0 + 1]
        all_sim = torch.cat([all_sim_1, torch.ones_like(all_sim_1[-1:]) * -100.0], dim=0)  # [T0 + 1, T0 + 1]
        all_sim[-1, :-1] = new_sim.view(-1)
        # choose compression position
        idx = torch.argmax(all_sim)
        left, right = idx // (T0 + 1), idx % (T0 + 1)
        assert all_sim[left, right] == torch.max(all_sim)
        # merge the pair into slot `right`, then remove slot `left`
        all_feature[right] = (all_feature[left] + all_feature[right]) / 2.0
        normed_all_features[right] = F.normalize(all_feature[right].view(-1), p=2, dim=0)
        all_indices[right] = all_indices[left] + all_indices[right]
        # update similarity of every frame against the merged frame
        new_sim = torch.mm(normed_all_features, normed_all_features[right:right + 1].T)  # [T0 + 1, 1]
        all_sim[right, :] = new_sim.view(-1)
        all_sim[:, right:right + 1] = new_sim
        all_sim[right, right] = -100.0
        # get compressed feature and similarity
        cur_feature = torch.cat([all_feature[:left], all_feature[left + 1:]])
        normed_cur_features = torch.cat([normed_all_features[:left], normed_all_features[left + 1:]])
        cur_indices = all_indices[:left] + all_indices[left + 1:]
        cur_sim_1 = torch.cat([all_sim[:left], all_sim[left + 1:]], dim=0)  # [T0, T0 + 1]
        cur_sim = torch.cat([cur_sim_1[:, :left], cur_sim_1[:, left + 1:]], dim=1)  # [T0, T0]
        step_indices.append(cur_indices)
    # print(f'Note: perform k-merge feature {img_feature.shape} to {cur_feature.shape}')
    return cur_feature, cur_sim, step_indices
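

# Illustrative usage sketch with assumed shapes: the global variants compare
# every retained frame against every other, not just temporal neighbors, at
# the cost of maintaining a [T0, T0] similarity matrix.
def _demo_k_variants():
    feats = torch.randn(12, 4, 16)
    dropped, _, steps = k_drop_feature(feats, video_max_frames=6)   # similarity slot is None
    merged, sim_matrix, steps = k_merge_feature(feats, video_max_frames=6)
    return dropped, merged, sim_matrix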


def attention_feature(img_feature, video_max_frames, attention_fn=None, update_ratio=0.2):
    """Compress a [T, P, D] stream into a fixed [T0, P, D] memory by folding
    each chunk of incoming frames into the memory via `attention_fn`."""
    T, P, D = img_feature.shape
    T0 = video_max_frames
    if T <= T0:
        return img_feature, None
    cur_feature = img_feature[:T0]  # [T0, P, D]
    turing_memory = cur_feature.reshape(T0 * P, D)  # [T0 * P, D]
    for i in range(T0, T, T0):
        j = min(i + T0, T)
        new_feature = img_feature[i:j]  # [j - i, P, D]
        new_feature = new_feature.reshape(-1, D)  # [(j - i) * P, D]
        turing_memory = attention_fn(turing_memory, new_feature, update_ratio=update_ratio)  # [T0 * P, D]
        cur_feature = turing_memory.reshape(T0, P, D)
    # print(f'Note: perform {attention_fn.__name__} feature {img_feature.shape} to {cur_feature.shape}')
    return cur_feature, None
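

# Illustrative stand-in for `attention_fn` (an assumption, not the
# Flash-VStream implementation): a single cross-attention read from the new
# tokens, blended into the memory by `update_ratio`. Any callable mapping an
# [M, D] memory and [N, D] tokens back to [M, D] fits the expected signature.
def _naive_attention_update(memory, new_tokens, update_ratio=0.2):
    scale = memory.size(1) ** 0.5
    attn = torch.softmax(memory @ new_tokens.T / scale, dim=-1)  # [M, N]
    read = attn @ new_tokens                                     # [M, D]
    return (1 - update_ratio) * memory + update_ratio * read


# Example call with dummy data:
#     feats = torch.randn(10, 4, 16)
#     compressed, _ = attention_feature(feats, video_max_frames=4,
#                                       attention_fn=_naive_attention_update)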