import torch
import torch.nn.functional as F
import numpy as np
import random
import math


def create_weighted_mask_batched(h, w):
    # Build a 2-D weight map for an h x w crop that is largest at the centre and
    # falls off linearly to zero at the borders, used to blend overlapping crops.
    y_mask = np.linspace(0, 1, h)
    y_mask = np.minimum(y_mask, 1 - y_mask)
    x_mask = np.linspace(0, 1, w)
    x_mask = np.minimum(x_mask, 1 - x_mask)
    weighted_mask = np.outer(y_mask, x_mask)
    return torch.from_numpy(weighted_mask).float()
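# A minimal sanity-check sketch (not from the original code): the weight map is
# zero on the crop borders and largest near the centre, so overlapping crops are
# blended smoothly when they are recombined below.
#
#   >>> m = create_weighted_mask_batched(224, 224)
#   >>> m.shape
#   torch.Size([224, 224])
#   >>> float(m[0, 0]) == 0.0 and float(m[112, 112]) > 0.0
#   True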
def reconstruct_video_new_2_batched(cropped_tensors, crop_positions, original_shape):
    B, T, C, H, W = original_shape
    # Initialize an empty tensor to store the reconstructed video
    reconstructed_video = torch.zeros((B, T, C, H, W)).to(cropped_tensors[0].device)
    # Create a tensor to store the sum of weighted masks
    weighted_masks_sum = torch.zeros((B, T, C, H, W)).to(cropped_tensors[0].device)
    # Create a weighted mask for the crops
    weighted_mask = create_weighted_mask_batched(224, 224).to(cropped_tensors[0].device)
    weighted_mask = weighted_mask[None, None, None, :, :]  # Extend dimensions to match the cropped tensors
    for idx, crop in enumerate(cropped_tensors):
        start_h, start_w = crop_positions[idx]
        # Multiply the crop with the weighted mask
        weighted_crop = crop * weighted_mask
        # Add the weighted crop to the corresponding location in the reconstructed_video tensor
        reconstructed_video[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)] += weighted_crop
        # Update the weighted_masks_sum tensor
        weighted_masks_sum[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)] += weighted_mask
    # Add a small epsilon value to avoid division by zero
    epsilon = 1e-8
    # Normalize each pixel by its accumulated mask weight (plus epsilon)
    reconstructed_video /= (weighted_masks_sum + epsilon)
    return reconstructed_video
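# Illustrative round trip (uses the crop helper defined further down in this
# file); only the shapes are checked here, since border pixels receive near-zero
# weight and are not reconstructed exactly.
#
#   >>> video = torch.rand(2, 3, 3, 256, 320)                      # (B, T, C, H, W)
#   >>> crops, positions = get_minimal_224_crops_new_batched(video, 1)
#   >>> recon = reconstruct_video_new_2_batched(crops, positions, video.shape)
#   >>> recon.shape
#   torch.Size([2, 3, 3, 256, 320])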
# Convenience wrappers around F.interpolate: `resize` scales spatially by a
# factor, `upsample` resizes to an explicit (H, W).
resize = lambda x, a: F.interpolate(x, [int(a * x.shape[-2]), int(a * x.shape[-1])], mode='bilinear', align_corners=False)
upsample = lambda x, H, W: F.interpolate(x, [int(H), int(W)], mode='bilinear', align_corners=False)
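# Quick shape check for the two helpers (illustrative only):
#
#   >>> x = torch.rand(1, 3, 100, 200)
#   >>> resize(x, 0.5).shape
#   torch.Size([1, 3, 50, 100])
#   >>> upsample(x, 224, 224).shape
#   torch.Size([1, 3, 224, 224])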
def compute_optical_flow(embedding_tensor, mask_tensor, frame_size):
    # Unroll the mask tensor and find the indices of the unmasked tokens in the second frame
    mask_unrolled = mask_tensor.view(-1)
    second_frame_unmask_indices = torch.where(mask_unrolled[frame_size**2:] == False)[0]
    # Split the embedding tensor into the parts corresponding to the first and the second frame
    first_frame_embeddings = embedding_tensor[0, :frame_size**2, :]
    second_frame_embeddings = embedding_tensor[0, frame_size**2:, :]
    # Cosine similarity between the unmasked second-frame embeddings and all first-frame embeddings
    dot_product = torch.matmul(second_frame_embeddings, first_frame_embeddings.T)
    norms = torch.norm(second_frame_embeddings, dim=1)[:, None] * torch.norm(first_frame_embeddings, dim=1)[None, :]
    cos_sim_matrix = dot_product / norms
    # For each unmasked second-frame token, find the most similar first-frame token
    first_frame_most_similar_indices = cos_sim_matrix.argmax(dim=-1)
    # Convert the 1-D token indices into 2-D coordinates
    second_frame_y = second_frame_unmask_indices // frame_size
    second_frame_x = second_frame_unmask_indices % frame_size
    first_frame_y = first_frame_most_similar_indices // frame_size
    first_frame_x = first_frame_most_similar_indices % frame_size
    # Compute the x and y displacements and convert them to float
    displacements_x = (second_frame_x - first_frame_x).float()
    displacements_y = (second_frame_y - first_frame_y).float()
    # Initialize the optical flow tensor
    optical_flow = torch.zeros((2, frame_size, frame_size), device=embedding_tensor.device)
    # Assign the computed displacements to the corresponding pixels in the optical flow tensor
    optical_flow[0, second_frame_y, second_frame_x] = displacements_x
    optical_flow[1, second_frame_y, second_frame_x] = displacements_y
    return optical_flow
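# Shape sketch (illustrative, with assumed shapes): the embedding tensor is
# assumed to hold all tokens of the first frame followed by only the visible
# (unmasked) tokens of the second frame, matching how the function is called below.
#
#   >>> fs = 14                                           # 224 / patch size 16
#   >>> mask = torch.cat([torch.zeros(fs * fs, dtype=torch.bool),   # frame A fully visible
#   ...                   torch.rand(fs * fs) > 0.2])               # frame B mostly masked
#   >>> n_vis = int((~mask[fs * fs:]).sum())
#   >>> emb = torch.rand(1, fs * fs + n_vis, 768)         # frame-A tokens + visible frame-B tokens
#   >>> compute_optical_flow(emb, mask, fs).shape
#   torch.Size([2, 14, 14])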
def get_minimal_224_crops_new_batched(video_tensor, N):
    B, T, C, H, W = video_tensor.shape
    # Calculate the number of crops needed in both the height and width dimensions
    num_crops_h = math.ceil(H / 224) if H > 224 else 1
    num_crops_w = math.ceil(W / 224) if W > 224 else 1
    # Calculate the step size for the height and width dimensions
    step_size_h = 0 if H <= 224 else max(0, (H - 224) // (num_crops_h - 1))
    step_size_w = 0 if W <= 224 else max(0, (W - 224) // (num_crops_w - 1))
    # Create lists to store the cropped tensors and their start positions
    cropped_tensors = []
    crop_positions = []
    # Iterate over the height and width dimensions and extract the 224x224 crops
    for i in range(num_crops_h):
        for j in range(num_crops_w):
            start_h = i * step_size_h
            start_w = j * step_size_w
            end_h = min(start_h + 224, H)
            end_w = min(start_w + 224, W)
            crop = video_tensor[:, :, :, start_h:end_h, start_w:end_w]
            cropped_tensors.append(crop)
            crop_positions.append((start_h, start_w))
    D = len(cropped_tensors)
    # If N is greater than the number of deterministic crops (and H and W exceed 224),
    # generate additional random crops
    if N > D and H > 224 and W > 224:
        for _ in range(N - D):
            start_h = random.randint(0, H - 224)
            start_w = random.randint(0, W - 224)
            crop = video_tensor[:, :, :, start_h:(start_h + 224), start_w:(start_w + 224)]
            cropped_tensors.append(crop)
            crop_positions.append((start_h, start_w))
    # Reshape the cropped tensors to the required output shape (B, T, C, 224, 224)
    cropped_tensors = [crop.reshape(B, T, C, 224, 224) for crop in cropped_tensors]
    return cropped_tensors, crop_positions
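# Example of the deterministic crop grid (illustrative): a 448x448 video is
# covered by four 224x224 crops whose top-left corners tile the frame.
#
#   >>> video = torch.rand(1, 3, 3, 448, 448)
#   >>> crops, positions = get_minimal_224_crops_new_batched(video, 1)
#   >>> len(crops), crops[0].shape
#   (4, torch.Size([1, 3, 3, 224, 224]))
#   >>> positions
#   [(0, 0), (0, 224), (224, 0), (224, 224)]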
def get_honglin_3frame_vmae_optical_flow_crop_batched(generator,
                                                      mask_generator,
                                                      img1,
                                                      img2,
                                                      img3,
                                                      neg_back_flow=True,
                                                      num_scales=1,
                                                      min_scale=400,
                                                      N_mask_samples=100,
                                                      mask_ratio=0.8,
                                                      flow_frames='23'):
    B = img1.shape[0]
    assert len(img1.shape) == 4
    assert num_scales >= 1

    # For scaling
    h1 = img2.shape[-2]
    w1 = img2.shape[-1]
    assert min_scale < h1

    if neg_back_flow is False:
        print('WARNING: Not calculating negative backward flow')

    alpha = (min_scale / img1.shape[-2]) ** (1 / 4)

    frame_size = 224 // generator.patch_size[-1]
    patch_size = generator.patch_size[-1]

    all_fwd_flows_e2d = []

    for aidx in range(num_scales):
        img1_scaled = resize(img1.clone(), alpha ** aidx)
        img2_scaled = resize(img2.clone(), alpha ** aidx)
        img3_scaled = resize(img3.clone(), alpha ** aidx)

        h2 = img2_scaled.shape[-2]
        w2 = img2_scaled.shape[-1]
        s_h = h1 / h2
        s_w = w1 / w2

        # compute_optical_flow technically returns negative backward flow, so by
        # default the frames are fed to the model in reverse order
        if neg_back_flow is True:
            video = torch.cat([img3_scaled.unsqueeze(1), img2_scaled.unsqueeze(1), img1_scaled.unsqueeze(1)], 1)
        else:
            video = torch.cat([img1_scaled.unsqueeze(1), img2_scaled.unsqueeze(1), img3_scaled.unsqueeze(1)], 1)

        # Works even if the incoming video is already 224x224
        crops1, c_pos1 = get_minimal_224_crops_new_batched(video, 1)
        num_crops = len(crops1)
        N_samples = N_mask_samples

        crop = torch.cat(crops1, 0).cuda()
        optical_flows_enc2dec = torch.zeros(B * num_crops, 2, frame_size, frame_size).cuda()
        mask_counts = torch.zeros(frame_size, frame_size).cuda()

        i = 0
        # Draw at least N_samples random masks, and keep going until every
        # target-frame patch has been visible at least once
        while i < N_samples or (mask_counts == 0).any().item():
            mask_generator.mask_ratio = mask_ratio

            # Every crop and batch element shares the same mask for a given sample;
            # acceptable for now
            mask = mask_generator(num_frames=3)[None]
            mask_2f = ~mask[0, frame_size * frame_size * 2:]
            mask_counts += mask_2f.reshape(frame_size, frame_size)

            with torch.cuda.amp.autocast(enabled=True):
                processed_x = crop.transpose(1, 2)

                encoder_out = generator.encoder(processed_x.to(torch.float16), mask.repeat(B * num_crops, 1))
                encoder_to_decoder = generator.encoder_to_decoder(encoder_out)

                # Keep only the tokens and mask entries of the frame pair used for flow
                if flow_frames == '23':
                    encoder_to_decoder = encoder_to_decoder[:, frame_size * frame_size:, :]
                    flow_mask = mask[:, frame_size * frame_size:]
                elif flow_frames == '12':
                    encoder_to_decoder = encoder_to_decoder[:, :frame_size * frame_size * 2, :]
                    flow_mask = mask[:, :frame_size * frame_size * 2]

                # One flow map per batch element (and per crop) for now
                optical_flow_e2d = []
                for b in range(B * num_crops):
                    batch_flow = compute_optical_flow(encoder_to_decoder[b].unsqueeze(0), flow_mask, frame_size)
                    optical_flow_e2d.append(batch_flow.unsqueeze(0))

                optical_flow_e2d = torch.cat(optical_flow_e2d, 0)
                optical_flows_enc2dec += optical_flow_e2d

            i += 1

        # Average the accumulated flow by how often each patch was visible
        optical_flows_enc2dec = optical_flows_enc2dec / mask_counts

        # Rescale the displacements by the video-to-crop size ratio and the pyramid scale factor
        scale_factor_y = video.shape[-2] / 224
        scale_factor_x = video.shape[-1] / 224
        scaled_optical_flow = torch.zeros_like(optical_flows_enc2dec)
        scaled_optical_flow[:, 0, :, :] = optical_flows_enc2dec[:, 0, :, :] * scale_factor_x * s_w
        scaled_optical_flow[:, 1, :, :] = optical_flows_enc2dec[:, 1, :, :] * scale_factor_y * s_h

        # Split the crops back up and blend them into a full-resolution flow map
        crop_flows_enc2dec = scaled_optical_flow.split(B, 0)
        optical_flows_enc2dec_joined = reconstruct_video_new_2_batched(
            [_.unsqueeze(1).repeat_interleave(patch_size, -1).repeat_interleave(patch_size, -2).cpu() for _ in
             crop_flows_enc2dec], c_pos1, (B, 1, 2, video.shape[-2], video.shape[-1])).squeeze(1)

        all_fwd_flows_e2d.append(optical_flows_enc2dec_joined)

    # Upsample every scale to the resolution of the first scale and average them
    all_fwd_flows_e2d_new = []
    for r in all_fwd_flows_e2d:
        new_r = upsample(r, all_fwd_flows_e2d[0].shape[-2], all_fwd_flows_e2d[0].shape[-1])
        all_fwd_flows_e2d_new.append(new_r.unsqueeze(-1))
    return_flow = torch.cat(all_fwd_flows_e2d_new, -1).mean(-1)

    if neg_back_flow is True:
        return_flow = -return_flow
        all_fwd_flows_e2d_new = [-_ for _ in all_fwd_flows_e2d_new]

    return return_flow, all_fwd_flows_e2d_new
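# Hedged usage sketch. `generator` is assumed to be a VideoMAE-style model
# exposing `.encoder`, `.encoder_to_decoder` and `.patch_size`, and
# `mask_generator(num_frames=3)` is assumed to return a boolean per-patch mask;
# neither object is defined in this file, so the call below is illustrative only.
#
#   >>> img1 = torch.rand(1, 3, 448, 448).cuda()   # three consecutive RGB frames
#   >>> img2 = torch.rand(1, 3, 448, 448).cuda()
#   >>> img3 = torch.rand(1, 3, 448, 448).cuda()
#   >>> flow, per_scale_flows = get_honglin_3frame_vmae_optical_flow_crop_batched(
#   ...     generator, mask_generator, img1, img2, img3,
#   ...     num_scales=1, N_mask_samples=100, mask_ratio=0.8, flow_frames='23')
#   >>> flow.shape                                  # (B, 2, H, W) flow estimate
#   torch.Size([1, 2, 448, 448])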