import pytorch_lightning as pl
import numpy as np
import torch
import PIL
import os
import random
from skimage.io import imread
import webdataset as wds
import PIL.Image as Image
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from pathlib import Path
# from ldm.base_utils import read_pickle, pose_inverse
import torchvision.transforms as transforms
import torchvision
from einops import rearrange

def add_margin(pil_img, color=0, size=256):
    """Paste `pil_img` centered on a square `size` x `size` canvas filled with `color`."""
    width, height = pil_img.size
    result = Image.new(pil_img.mode, (size, size), color)
    result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
    return result
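
# Illustrative check (hypothetical sizes, not from this module): a 100x60 image
# pasted onto a black 256x256 canvas gets 78-pixel left/right and 98-pixel
# top/bottom margins:
#   padded = add_margin(Image.new("RGB", (100, 60)), color=0, size=256)
#   assert padded.size == (256, 256)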

def prepare_inputs(image_path, elevation_input, crop_size=-1, image_size=256):
    """Load an RGBA image, optionally recrop it around its alpha mask, composite it
    onto a white background, and return it with the elevation converted to radians."""
    image_input = Image.open(image_path)

    if crop_size != -1:
        # Crop to the tight bounding box of the alpha channel, rescale so the longer
        # side equals `crop_size`, then pad back to a square of side `image_size`.
        alpha_np = np.asarray(image_input)[:, :, 3]
        coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
        min_x, min_y = np.min(coords, 0)
        max_x, max_y = np.max(coords, 0)
        ref_img_ = image_input.crop((min_x, min_y, max_x, max_y))
        h, w = ref_img_.height, ref_img_.width
        scale = crop_size / max(h, w)
        h_, w_ = int(scale * h), int(scale * w)
        ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC)
        image_input = add_margin(ref_img_, size=image_size)
    else:
        image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
        image_input = image_input.resize((image_size, image_size), resample=Image.BICUBIC)

    image_input = np.asarray(image_input).astype(np.float32) / 255.0
    ref_mask = image_input[:, :, 3:]
    image_input[:, :, :3] = image_input[:, :, :3] * ref_mask + 1 - ref_mask  # white background
    image_input = image_input[:, :, :3] * 2.0 - 1.0  # scale RGB to [-1, 1]
    image_input = torch.from_numpy(image_input.astype(np.float32))
    elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
    return {"input_image": image_input, "input_elevation": elevation_input}
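
# Usage sketch for `prepare_inputs` (illustrative only; the path and values below
# are assumptions, not part of this module):
#   batch = prepare_inputs("example_rgba.png", elevation_input=30, crop_size=200)
#   batch["input_image"].shape      # torch.Size([256, 256, 3]), values in [-1, 1]
#   batch["input_elevation"]        # tensor([0.5236]) == 30 degrees in radians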

class VideoTrainDataset(Dataset):
    def __init__(self, base_folder='/data/yanghaibo/datas/OBJAVERSE-LVIS/images', width=1024, height=576, sample_frames=25):
        """
        Args:
            base_folder (str): Root directory; each subfolder holds the rendered frames of one object.
            width (int): Output frame width in pixels.
            height (int): Output frame height in pixels.
            sample_frames (int): Number of frames per sampled clip.
        """
        self.base_folder = base_folder
        self.folders = os.listdir(self.base_folder)
        self.num_samples = len(self.folders)
        self.channels = 3  # RGB
        self.width = width
        self.height = height
        self.sample_frames = sample_frames
        # One elevation (in degrees) per consecutive 16-frame block of renderings.
        self.elevations = [-10, 0, 10, 20, 30, 40]

    def __len__(self):
        return self.num_samples

    def load_im(self, path):
        """Read an RGBA frame, composite it onto a white background, and return the RGB image and alpha mask."""
        img = imread(path)
        img = img.astype(np.float32) / 255.0
        mask = img[:, :, 3:]
        img[:, :, :3] = img[:, :, :3] * mask + 1 - mask  # white background
        img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
        return img, mask

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the sample to return (unused; a folder is drawn at random).
        Returns:
            dict: the 'video' tensor of shape (channels, sample_frames, height, width),
                  plus the clip's elevation, caption, fps id, and motion bucket id.
        """
        # Randomly select a folder (representing a video) from the base folder
        chosen_folder = random.choice(self.folders)
        folder_path = os.path.join(self.base_folder, chosen_folder)
        frames = os.listdir(folder_path)
        # Sort the frames by name
        frames.sort()

        # Ensure the selected folder has at least `sample_frames` frames
        if len(frames) < self.sample_frames:
            raise ValueError(
                f"The selected folder '{chosen_folder}' contains fewer than `{self.sample_frames}` frames.")

        # Randomly select a start index. Each consecutive block of 16 frames shares a
        # fixed elevation; the clip stays inside one block, rotated to begin at
        # `start_idx` and wrapping around within that block.
        start_idx = random.randint(0, len(frames) - 1)
        range_id = start_idx // 16  # block index: 0..5, one per elevation
        elevation = self.elevations[range_id]
        selected_frames = []
        for frame_idx in range(start_idx, (range_id + 1) * 16):
            selected_frames.append(frames[frame_idx])
        for frame_idx in range(range_id * 16, start_idx):
            selected_frames.append(frames[frame_idx])

        # Initialize a tensor to store the pixel values. Note that each block yields
        # exactly 16 frames, so `sample_frames` is effectively expected to be 16 here.
        pixel_values = torch.empty((self.sample_frames, self.channels, self.height, self.width))

        # Load and process each frame
        for i, frame_name in enumerate(selected_frames):
            frame_path = os.path.join(folder_path, frame_name)
            img, mask = self.load_im(frame_path)

            # Resize the image and convert it to a tensor
            img_resized = img.resize((self.width, self.height))
            img_tensor = torch.from_numpy(np.array(img_resized)).float()

            # Normalize the image by scaling pixel values to [-1, 1]
            img_normalized = img_tensor / 127.5 - 1

            # Rearrange channels if necessary
            if self.channels == 3:
                img_normalized = img_normalized.permute(2, 0, 1)  # For RGB images
            elif self.channels == 1:
                img_normalized = img_normalized.mean(dim=2, keepdim=True)  # For grayscale images

            pixel_values[i] = img_normalized

        pixel_values = rearrange(pixel_values, 't c h w -> c t h w')
        caption = chosen_folder + "_" + str(start_idx)
        return {'video': pixel_values, 'elevation': elevation, 'caption': caption, "fps_id": 7, "motion_bucket_id": 127}
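
# Data-layout note (inferred from the 16-frame blocks in __getitem__ above; an
# assumption about the rendering setup rather than something this file verifies):
# each object folder is expected to hold 96 sorted frames, i.e. 6 elevation blocks
# of 16 azimuth steps, so that frames [k*16, (k+1)*16) share elevations[k].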

class SyncDreamerEvalData(Dataset):
    def __init__(self, image_dir):
        self.image_size = 512
        self.image_dir = Path(image_dir)
        self.crop_size = 20

        self.fns = []
        for fn in Path(image_dir).iterdir():
            if fn.suffix == '.png':
                self.fns.append(fn)
        print('============= length of dataset %d =============' % len(self.fns))

    def __len__(self):
        return len(self.fns)

    def get_data_for_index(self, index):
        input_img_fn = self.fns[index]
        elevation = 0
        # 512 binds to `crop_size` here; `image_size` keeps its default of 256.
        return prepare_inputs(input_img_fn, elevation, 512)

    def __getitem__(self, index):
        return self.get_data_for_index(index)

class VideoDataset(pl.LightningDataModule):
    def __init__(self, base_folder, eval_folder, width, height, sample_frames, batch_size, num_workers=4, seed=0, **kwargs):
        super().__init__()
        self.base_folder = base_folder
        self.eval_folder = eval_folder
        self.width = width
        self.height = height
        self.sample_frames = sample_frames
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.additional_args = kwargs

    def setup(self, stage=None):
        # Lightning invokes `setup(stage)`, so the hook must accept the stage argument.
        self.train_dataset = VideoTrainDataset(self.base_folder, self.width, self.height, self.sample_frames)
        self.val_dataset = SyncDreamerEvalData(image_dir=self.eval_folder)

    def train_dataloader(self):
        sampler = DistributedSampler(self.train_dataset, seed=self.seed)
        return wds.WebLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)

    def val_dataloader(self):
        return wds.WebLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        return wds.WebLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
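
# Minimal smoke-test sketch (hypothetical paths; not part of the training code).
# It draws one clip from the training dataset and checks the tensor layout.
if __name__ == "__main__":
    # NOTE: the folder below is an assumed placeholder; point it at a directory whose
    # subfolders each contain the per-object frame renderings described above.
    dataset = VideoTrainDataset(base_folder="/path/to/train_frames",
                                width=512, height=512, sample_frames=16)
    sample = dataset[0]
    print(sample['video'].shape)     # expected: torch.Size([3, 16, 512, 512])
    print(sample['elevation'], sample['caption'])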