jaxmetaverse's picture
Upload folder using huggingface_hub
82ea528 verified
import torch
import random
def shuffle_indices(size, seed=None):
    """Return the integers ``0 .. size - 1`` in a random order.

    When ``seed`` is not ``None`` the module-level ``random`` generator is
    reseeded with it before shuffling, so the same ``(size, seed)`` pair
    always yields the same permutation. The reseeding is global: it also
    affects every later use of the ``random`` module.
    """
    reproducible = seed is not None
    if reproducible:
        random.seed(seed)
    order = [*range(size)]
    # In-place shuffle driven by the global ``random`` generator.
    random.shuffle(order)
    return order
def shuffle_tensors2(tensor, current_indices, target_indices):
    """Reorder the entries of ``tensor`` from one index order to another.

    ``current_indices[k]`` is the label attached to the k-th entry of
    ``tensor``. The result stacks the entries so that position ``j`` holds
    the entry labelled ``target_indices[j]``.

    Pairing stops at the shorter of ``tensor`` and ``current_indices``. If a
    label repeats in ``current_indices`` the last entry carrying it wins. A
    label in ``target_indices`` that was never paired raises ``KeyError``.
    """
    by_label = dict(zip(current_indices, tensor))
    reordered = []
    for label in target_indices:
        reordered.append(by_label[label])
    return torch.stack(reordered)
def grid_to_list(tensor, grid_size):
    """Split a batch of square image grids into a batch of single frames.

    Each entry of ``tensor`` is treated as a ``grid_size`` x ``grid_size``
    grid. Every grid is passed through ``flatten_grid`` (with its leading
    dimension restored), the results are joined along the second-to-last
    dimension, and that joined tensor is chunked into
    ``len(tensor) * grid_size * grid_size`` pieces which are concatenated
    along dimension 0.

    NOTE(review): presumably ``tensor`` is laid out (batch, height, width,
    channels), matching the names used in ``flatten_grid`` -- confirm
    against the caller.
    """
    total_frames = len(tensor) * grid_size * grid_size
    strips = []
    for grid in tensor:
        # Iterating drops dimension 0; put it back before flattening.
        single = grid.unsqueeze(0)
        strips.append(flatten_grid(single, [grid_size, grid_size]))
    long_strip = torch.cat(strips, dim=-2)
    frames = torch.chunk(long_strip, total_frames, dim=-2)
    return torch.cat(frames, dim=0)
def list_to_grid(tensor, grid_size):
    """Assemble consecutive frames into ``grid_size`` x ``grid_size`` grids.

    Frames are consumed in groups of ``grid_size * grid_size``; trailing
    frames that do not fill a complete group are ignored. Each group is
    concatenated along the second-to-last dimension, given a leading
    dimension of size 1, and handed to ``unflatten_grid``. The resulting
    grids are concatenated along dimension 0.
    """
    frames_per_grid = grid_size * grid_size
    n_grids = len(tensor) // frames_per_grid

    # Phase 1: one long strip per complete group of frames.
    strips = []
    for g in range(n_grids):
        start = g * frames_per_grid
        group = tensor[start:start + frames_per_grid]
        strips.append(torch.cat(list(group), dim=-2).unsqueeze(0))

    # Phase 2: fold every strip back into a 2-D grid.
    grids = []
    for strip in strips:
        grids.append(unflatten_grid(strip, [grid_size, grid_size]))
    return torch.cat(grids, dim=0)
def flatten_grid(x, grid_shape):
    """Lay the row bands of an image grid side by side in a single strip.

    ``x`` must be 4-D; the code names its dimensions (B, H, W, C).
    ``grid_shape`` is a two-element ``(rows, cols)`` pair, of which only the
    row count is used: dimension 1 is split into pieces of size
    ``H // rows`` and the pieces are concatenated along dimension 2.
    """
    # Unpacking all four sizes keeps the implicit "must be 4-D" check.
    _, height, _, _ = x.size()
    rows, _ = grid_shape
    band_height = height // rows
    bands = x.split(band_height, dim=1)
    return torch.cat(bands, dim=2)
def unflatten_grid(x, grid_shape):
    """Fold a horizontal strip of images back into a stacked grid.

    ``x`` must be 4-D; the code names its dimensions (B, H, W, C).
    ``grid_shape`` is a two-element ``(rows, cols)`` pair, of which only the
    column count is used: dimension 2 is split into pieces of size
    ``W // cols`` and the pieces are concatenated along dimension 1.
    """
    # Unpacking all four sizes keeps the implicit "must be 4-D" check.
    _, _, width, _ = x.size()
    _, cols = grid_shape
    piece_width = width // cols
    pieces = x.split(piece_width, dim=2)
    return torch.cat(pieces, dim=1)