# flash_vstream/model/compress_functions.py
# Copyright 2024 Flash-VStream Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import torch
import torch.nn as nn
import torch.nn.functional as F


def drop_feature(img_feature, video_max_frames, img_similarity=None):
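    """Greedily drop near-duplicate frames until at most `video_max_frames` remain.

    Keeps a sliding buffer of T0 = video_max_frames frames. For each incoming
    frame, the most similar *adjacent* pair in the buffer (by cosine similarity
    of the flattened [P, D] features) is located and one frame of that pair is
    dropped at random. Returns the compressed [T0, P, D] features, the [T0 - 1]
    adjacent-pair similarities, and the per-step lists of surviving frame indices.
    """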
T, P, D = img_feature.shape
indices = [[i] for i in range(T)]
T0 = video_max_frames
if T <= T0:
return img_feature, img_similarity, [indices]
cur_feature = img_feature[:T0] # [T0, P, D]
if img_similarity is not None:
cur_sim = img_similarity[:T0 - 1]
else:
cur_sim = F.cosine_similarity(cur_feature[:-1].view(T0 - 1, P * D), cur_feature[1:].view(T0 - 1, P * D)) # [T0 - 1]
cur_indices = indices[:T0]
step_indices = [cur_indices]
for i in range(T0, T):
new_feature = img_feature[i]
new_sim = F.cosine_similarity(cur_feature[-1].view(-1), new_feature.view(-1), dim=0)
all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
all_indices = cur_indices + [[i]]
all_sim = torch.cat([cur_sim, new_sim.unsqueeze(0)], dim=0)
        # index of the most similar adjacent pair (frames idx and idx + 1)
        idx = torch.argmax(all_sim)
        # randomly drop either the left or the right frame of that pair
        if random.randint(0, 1) > 0:
            idx = idx + 1
        cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
        if idx == T0:  # the dropped frame is the newly appended one; just truncate
cur_sim = all_sim[:T0 - 1]
cur_indices = all_indices[:-1]
elif idx == 0:
cur_sim = all_sim[1:]
cur_indices = all_indices[1:]
        else:
            cur_sim = torch.cat([all_sim[:idx], all_sim[idx + 1:]])
            # the dropped frame's neighbors become adjacent; refresh their similarity
            cur_sim[idx - 1] = F.cosine_similarity(all_feature[idx - 1].view(-1), all_feature[idx + 1].view(-1), dim=0)
cur_indices = all_indices[:idx] + all_indices[idx + 1:]
step_indices.append(cur_indices)
# print(f'Note: perform drop feature {img_feature.shape} to {cur_feature.shape}')
return cur_feature, cur_sim, step_indices


def merge_feature(img_feature, video_max_frames, img_similarity=None):
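    """Greedily merge near-duplicate frames until at most `video_max_frames` remain.

    Same buffer scheme as `drop_feature`, but the most similar adjacent pair is
    averaged into a single frame instead of dropped, and the returned index
    lists record which original frames each merged slot aggregates.
    """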
T, P, D = img_feature.shape
indices = [[i] for i in range(T)]
T0 = video_max_frames
if T <= T0:
return img_feature, img_similarity, [indices]
cur_feature = img_feature[:T0] # [T0, P, D]
cur_indices = indices[:T0]
step_indices = [cur_indices]
if img_similarity is not None:
cur_sim = img_similarity[:T0 - 1]
else:
cur_sim = F.cosine_similarity(cur_feature[:-1].view(T0 - 1, P * D), cur_feature[1:].view(T0 - 1, P * D)) # [T0 - 1]
for i in range(T0, T):
new_feature = img_feature[i]
new_sim = F.cosine_similarity(cur_feature[-1].view(-1), new_feature.view(-1), dim=0)
all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
all_sim = torch.cat([cur_sim, new_sim.unsqueeze(0)], dim=0)
all_indices = cur_indices + [[i]]
        idx = torch.argmax(all_sim)
        # average the most similar adjacent pair into the right slot, then drop the left slot
        all_feature[idx + 1] = (all_feature[idx] + all_feature[idx + 1]) / 2.0
        all_indices[idx + 1] = all_indices[idx] + all_indices[idx + 1]
        cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
        cur_sim = torch.cat([all_sim[:idx], all_sim[idx + 1:]])
        cur_indices = all_indices[:idx] + all_indices[idx + 1:]
        if idx > 0:  # refresh similarity with the left neighbor of the merged frame
            cur_sim[idx - 1] = F.cosine_similarity(all_feature[idx - 1].view(-1), all_feature[idx + 1].view(-1), dim=0)
        if idx + 1 < T0:  # refresh similarity with the right neighbor
            cur_sim[idx] = F.cosine_similarity(all_feature[idx + 1].view(-1), all_feature[idx + 2].view(-1), dim=0)
step_indices.append(cur_indices)
# print(f'Note: perform merge feature {img_feature.shape} to {cur_feature.shape}')
return cur_feature, cur_sim, step_indices


def kmeans_feature(img_feature, video_max_frames, img_similarity=None):
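    """Compress T frames to `video_max_frames` centroids with plain k-means.

    Frames are flattened to [T, P * D] vectors and clustered with Euclidean
    k-means; the centroids are reshaped back to [T0, P, D]. `img_similarity`
    is passed through unchanged, and the returned index lists give each
    centroid's cluster membership.
    """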
def kmeans_torch(X, num_clusters, distance='euclidean', tol=1e-4, max_iter=10):
indices = torch.randperm(X.size(0))[:num_clusters]
centroids = X[indices]
for i in range(max_iter):
if distance == 'euclidean':
dists = torch.cdist(X, centroids, p=2)
else:
raise NotImplementedError("Only Euclidean distance is supported yet")
labels = torch.argmin(dists, dim=1)
new_centroids = []
for j in range(num_clusters):
cluster_points = X[labels == j]
if len(cluster_points) > 0:
new_centroid = cluster_points.mean(0)
                else:  # empty cluster: re-seed from a random point to avoid a NaN centroid
                    new_centroid = X[random.randint(0, X.size(0) - 1)]
new_centroids.append(new_centroid)
new_centroids = torch.stack(new_centroids)
diff = torch.norm(centroids - new_centroids, dim=1).sum()
if diff < tol:
break
centroids = new_centroids
return centroids, labels, i
T, P, D = img_feature.shape
T0 = video_max_frames
if T <= T0:
return img_feature, img_similarity, [[[i] for i in range(T)]]
X = img_feature.view(T, -1) # [T, P, D]
centroids, labels, exit_step = kmeans_torch(X, T0)
reduced_feature = centroids.view(T0, P, D)
# print(f'Note: perform kmeans feature {img_feature.shape} to {reduced_feature.shape}, exit at step={exit_step}') # actually, K=T0
step_indices = [[] for _ in range(T0)]
for i in range(T0):
step_indices[i] = [j for j in range(T) if labels[j] == i]
return reduced_feature, img_similarity, [step_indices]


def weighted_kmeans_feature(img_feature, video_max_frames, weights=None):
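    """Weighted k-means variant: each frame carries a weight (defaults to 1).

    Centroids are weighted averages of their members, and the returned weights
    are the per-centroid weight sums, so repeated compression keeps track of
    how many original frames each centroid represents.
    """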
if weights is None:
weights = torch.ones(img_feature.size(0), dtype=img_feature.dtype, device=img_feature.device)
def weighted_kmeans_torch(X, num_clusters, weights=None, distance='euclidean', tol=1e-4, max_iter=10):
indices = torch.randperm(X.size(0), device=X.device)[:num_clusters]
centroids = X[indices]
for i in range(max_iter):
if distance == 'euclidean':
dists = ((X.unsqueeze(1) - centroids.unsqueeze(0)) ** 2).sum(dim=2).sqrt()
else:
raise NotImplementedError("Only Euclidean distance is supported yet")
labels = torch.argmin(dists, dim=1)
weighted_sum = torch.zeros_like(centroids)
weights_sum = torch.zeros(num_clusters, dtype=X.dtype, device=X.device)
for j in range(num_clusters):
cluster_mask = labels == j
weighted_sum[j] = torch.sum(weights[cluster_mask, None] * X[cluster_mask], dim=0)
weights_sum[j] = torch.sum(weights[cluster_mask])
mask = weights_sum > 0
new_centroids = torch.zeros_like(weighted_sum)
new_centroids[mask] = weighted_sum[mask] / weights_sum[mask, None]
            if mask.sum() < num_clusters:  # empty clusters: re-seed from random points to avoid NaN centroids
                new_centroids[~mask] = torch.stack([X[random.randint(0, X.size(0) - 1)] for _ in range(int(num_clusters - mask.sum()))])
diff = torch.norm(centroids - new_centroids, dim=1).sum()
if diff < tol:
break
centroids = new_centroids
return centroids, labels, weights_sum, i
T, P, D = img_feature.shape
T0 = video_max_frames
if T <= T0:
return img_feature, weights, [[[i] for i in range(T)]]
X = img_feature.view(T, -1) # [T, P, D]
centroids, labels, weights, exit_step = weighted_kmeans_torch(X, T0, weights)
reduced_feature = centroids.view(T0, P, D)
# print(f'Note: perform weighted kmeans feature {img_feature.shape} to {reduced_feature.shape}, exit at step={exit_step}') # actually, K=T0
step_indices = [[] for _ in range(T0)]
for i in range(T0):
step_indices[i] = [j for j in range(T) if labels[j] == i]
return reduced_feature, weights, [step_indices]


def k_drop_feature(img_feature, video_max_frames, img_similarity=None):
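    """Drop frames using a full pairwise similarity matrix instead of adjacent pairs.

    Maintains a [T0, T0] cosine-similarity matrix over the buffer (diagonal
    masked with -100). For each incoming frame, the globally most similar pair
    is located and one of its two frames is dropped at random. Note that the
    similarity slot of the return value is None.
    """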
T, P, D = img_feature.shape
indices = [[i] for i in range(T)]
T0 = video_max_frames
if T <= T0:
return img_feature, img_similarity, [indices]
cur_feature = img_feature[:T0] # [T0, P, D]
normed_cur_features = F.normalize(cur_feature.view(T0, P * D), p=2, dim=1)
cur_sim = torch.mm(normed_cur_features, normed_cur_features.T) # [T0, T0]
cur_sim.fill_diagonal_(-100.0)
cur_indices = indices[:T0]
step_indices = [cur_indices]
for i in range(T0, T):
# get new feature
new_feature = img_feature[i]
normed_new_feature = F.normalize(new_feature.view(1, P * D), p=2, dim=1)
new_sim = torch.mm(normed_cur_features, normed_new_feature.T) # [T0, 1]
all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
normed_all_features = torch.cat([normed_cur_features, normed_new_feature], dim=0)
all_indices = cur_indices + [[i]]
# get new similarity
all_sim_1 = torch.cat([cur_sim, new_sim], dim=1) # [T0, T0 + 1]
        all_sim = torch.cat([all_sim_1, torch.full_like(all_sim_1[-1:], -100.0)], dim=0)  # [T0 + 1, T0 + 1]
        all_sim[-1, :-1] = new_sim.view(-1)  # mirror the new column into the new row
        # locate the globally most similar pair (row, col) in the flattened matrix
        idx = torch.argmax(all_sim)
        left, right = idx // (T0 + 1), idx % (T0 + 1)
        assert all_sim[left, right] == torch.max(all_sim)
        # randomly drop one of the two frames of that pair
        idx = left if random.randint(0, 1) > 0 else right
# get compressed feature and similarity
cur_feature = torch.cat([all_feature[:idx], all_feature[idx + 1:]])
normed_cur_features = torch.cat([normed_all_features[:idx], normed_all_features[idx + 1:]])
cur_indices = all_indices[:idx] + all_indices[idx + 1:]
cur_sim_1 = torch.cat([all_sim[:idx], all_sim[idx + 1:]], dim=0) # [T0, T0 + 1]
cur_sim = torch.cat([cur_sim_1[:, :idx], cur_sim_1[:, idx + 1:]], dim=1) # [T0, T0]
step_indices.append(cur_indices)
# print(f'Note: perform k-drop feature {img_feature.shape} to {cur_feature.shape}')
return cur_feature, None, step_indices


def k_merge_feature(img_feature, video_max_frames, img_similarity=None):
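    """Merge frames using a full pairwise similarity matrix instead of adjacent pairs.

    Like `k_drop_feature`, but the globally most similar pair (left, right) is
    averaged into the `right` slot, whose row and column of the similarity
    matrix are then recomputed before the `left` slot is removed.
    """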
T, P, D = img_feature.shape
indices = [[i] for i in range(T)]
T0 = video_max_frames
if T <= T0:
return img_feature, img_similarity, [indices]
cur_feature = img_feature[:T0] # [T0, P, D]
normed_cur_features = F.normalize(cur_feature.view(T0, P * D), p=2, dim=1)
cur_sim = torch.mm(normed_cur_features, normed_cur_features.T) # [T0, T0]
cur_sim.fill_diagonal_(-100.0)
cur_indices = indices[:T0]
step_indices = [cur_indices]
for i in range(T0, T):
# get new feature
new_feature = img_feature[i]
normed_new_feature = F.normalize(new_feature.view(1, P * D), p=2, dim=1)
new_sim = torch.mm(normed_cur_features, normed_new_feature.T) # [T0, 1]
all_feature = torch.cat([cur_feature, new_feature.unsqueeze(0)], dim=0)
normed_all_features = torch.cat([normed_cur_features, normed_new_feature], dim=0)
all_indices = cur_indices + [[i]]
# get new similarity
all_sim_1 = torch.cat([cur_sim, new_sim], dim=1) # [T0, T0 + 1]
        all_sim = torch.cat([all_sim_1, torch.full_like(all_sim_1[-1:], -100.0)], dim=0)  # [T0 + 1, T0 + 1]
        all_sim[-1, :-1] = new_sim.view(-1)  # mirror the new column into the new row
        # locate the globally most similar pair (row, col) in the flattened matrix
        idx = torch.argmax(all_sim)
        left, right = idx // (T0 + 1), idx % (T0 + 1)
        assert all_sim[left, right] == torch.max(all_sim)
# update feature
all_feature[right] = (all_feature[left] + all_feature[right]) / 2.0
normed_all_features[right] = F.normalize(all_feature[right].view(1, P * D), p=2, dim=1)
all_indices[right] = all_indices[left] + all_indices[right]
        # update similarity: recompute the merged slot's row and column, re-mask its diagonal
        new_sim = torch.mm(normed_all_features, normed_all_features[right:right + 1].T)  # [T0 + 1, 1]
        all_sim[right, :] = new_sim.view(-1)
        all_sim[:, right:right + 1] = new_sim
        all_sim[right, right] = -100.0
# get compressed feature and similarity
cur_feature = torch.cat([all_feature[:left], all_feature[left + 1:]])
normed_cur_features = torch.cat([normed_all_features[:left], normed_all_features[left + 1:]])
cur_indices = all_indices[:left] + all_indices[left + 1:]
cur_sim_1 = torch.cat([all_sim[:left], all_sim[left + 1:]], dim=0) # [T0, T0 + 1]
cur_sim = torch.cat([cur_sim_1[:, :left], cur_sim_1[:, left + 1:]], dim=1) # [T0, T0]
step_indices.append(cur_indices)
# print(f'Note: perform k-merge feature {img_feature.shape} to {cur_feature.shape}')
return cur_feature, cur_sim, step_indices


def attention_feature(img_feature, video_max_frames, attention_fn=None, update_ratio=0.2):
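    """Fold frames chunk-by-chunk into a fixed-size memory via `attention_fn`.

    The first T0 frames initialize a [T0 * P, D] memory; each subsequent chunk
    of up to T0 frames is flattened to tokens and written into the memory by
    `attention_fn(memory, new_tokens, update_ratio=...)`. Unlike the other
    compressors, this returns only (features, None).
    """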
T, P, D = img_feature.shape
T0 = video_max_frames
if T <= T0:
return img_feature, None
cur_feature = img_feature[:T0] # [T0, P, D]
turing_memory = cur_feature.reshape(T0*P, D) # [T0*P, D]
    for i in range(T0, T, T0):
        j = min(i + T0, T)
        new_feature = img_feature[i:j]  # [n, P, D] with n = j - i <= T0
        new_feature = new_feature.reshape(-1, D)  # [n * P, D]
        turing_memory = attention_fn(turing_memory, new_feature, update_ratio=update_ratio)  # [T0 * P, D]
cur_feature = turing_memory.reshape(T0, P, D)
# print(f'Note: perform {attention_fn.__name__} feature {img_feature.shape} to {cur_feature.shape}')
return cur_feature, None
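

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original pipeline): run each
    # compressor on a toy [T, P, D] tensor and check the output shapes. The
    # `toy_attention` update rule below is a hypothetical stand-in for the real
    # attention_fn supplied by the model; only its call signature is assumed here.
    torch.manual_seed(0)
    feats = torch.randn(10, 4, 8)  # T=10 frames, P=4 patches, D=8 channels

    for fn in (drop_feature, merge_feature, kmeans_feature, k_drop_feature, k_merge_feature):
        reduced, _, step_indices = fn(feats, video_max_frames=4)
        assert reduced.shape == (4, 4, 8), fn.__name__

    reduced, weights, _ = weighted_kmeans_feature(feats, 4)
    assert reduced.shape == (4, 4, 8) and weights.shape == (4,)

    def toy_attention(memory, new_tokens, update_ratio=0.2):
        # blend the memory toward its attention readout of the new tokens
        attn = torch.softmax(memory @ new_tokens.T / new_tokens.size(-1) ** 0.5, dim=-1)
        return (1 - update_ratio) * memory + update_ratio * attn @ new_tokens

    reduced, _ = attention_feature(feats, 4, attention_fn=toy_attention)
    assert reduced.shape == (4, 4, 8)
    print('all compressors map [10, 4, 8] -> [4, 4, 8]')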