|
import torch |
|
import random |
|
|
|
|
|
def shuffle_indices(size, seed=None): |
|
if seed is not None: |
|
random.seed(seed) |
|
indices = list(range(size)) |
|
random.shuffle(indices) |
|
return indices |
|
|
|
|
|
def shuffle_tensors2(tensor, current_indices, target_indices): |
|
tensor_dict = {current_idx: t for current_idx, |
|
t in zip(current_indices, tensor)} |
|
shuffled_tensors = [tensor_dict[current_idx] |
|
for current_idx in target_indices] |
|
return torch.stack(shuffled_tensors) |
|
|
|
|
|
def grid_to_list(tensor, grid_size): |
|
frame_count = len(tensor) * grid_size * grid_size |
|
flattened_list = [flatten_grid(grid.unsqueeze( |
|
0), [grid_size, grid_size]) for grid in tensor] |
|
list_tensor = torch.cat(flattened_list, dim=-2) |
|
return torch.cat(torch.chunk(list_tensor, frame_count, dim=-2), dim=0) |
|
|
|
|
|
def list_to_grid(tensor, grid_size): |
|
grid_frame_count = grid_size * grid_size |
|
grid_count = len(tensor) // grid_frame_count |
|
flat_grids = [torch.cat([a for a in tensor[i * grid_frame_count:(i + 1) |
|
* grid_frame_count]], dim=-2).unsqueeze(0) for i in range(grid_count)] |
|
unflattened_grids = [unflatten_grid( |
|
flat_grid, [grid_size, grid_size]) for flat_grid in flat_grids] |
|
return torch.cat(unflattened_grids, dim=0) |
|
|
|
|
|
def flatten_grid(x, grid_shape): |
|
B, H, W, C = x.size() |
|
hs, ws = grid_shape |
|
img_h = H // hs |
|
flattened = torch.cat(torch.split(x, img_h, dim=1), dim=2) |
|
return flattened |
|
|
|
|
|
def unflatten_grid(x, grid_shape): |
|
''' |
|
x: B x C x H x W |
|
''' |
|
B, H, W, C = x.size() |
|
hs, ws = grid_shape |
|
img_w = W // (ws) |
|
|
|
unflattened = torch.cat(torch.split(x, img_w, dim=2), dim=1) |
|
|
|
return unflattened |