import os
import random
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
import webdataset as wds
from einops import rearrange
from PIL import Image
from skimage.io import imread
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler

# from ldm.base_utils import read_pickle, pose_inverse
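
# Expected on-disk layout (inferred from the sampling logic below, not verified
# against the original data): each object folder holds 96 renders arranged as
# 6 elevation rings ([-10, 0, 10, 20, 30, 40] degrees) x 16 azimuth steps, so
# frames [16*k, 16*(k+1)) all share elevation `elevations[k]`.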


def add_margin(pil_img, color=0, size=256):
    """Paste `pil_img` centered onto a `size` x `size` canvas filled with `color`."""
    width, height = pil_img.size
    result = Image.new(pil_img.mode, (size, size), color)
    result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
    return result
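
# For example, add_margin(Image.new('RGBA', (100, 200)), size=256) yields a
# 256x256 canvas with the original pasted at offset (78, 28); images larger
# than `size` are cropped by the negative paste offsets.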


def prepare_inputs(image_path, elevation_input, crop_size=-1, image_size=256):
    image_input = Image.open(image_path)  # expected RGBA: the alpha channel is used as the mask
    if crop_size != -1:
        # Crop to the tight bounding box of the non-transparent pixels,
        # rescale so the longer side equals `crop_size`, then pad to square.
        alpha_np = np.asarray(image_input)[:, :, 3]
        coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]  # (row, col) -> (x, y)
        min_x, min_y = np.min(coords, 0)
        max_x, max_y = np.max(coords, 0)
        ref_img_ = image_input.crop((min_x, min_y, max_x + 1, max_y + 1))  # PIL boxes exclude the right/bottom edge
        h, w = ref_img_.height, ref_img_.width
        scale = crop_size / max(h, w)
        h_, w_ = int(scale * h), int(scale * w)
        ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC)
        image_input = add_margin(ref_img_, size=image_size)
    else:
        image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
        image_input = image_input.resize((image_size, image_size), resample=Image.BICUBIC)

    image_input = np.asarray(image_input).astype(np.float32) / 255.0
    ref_mask = image_input[:, :, 3:]
    image_input[:, :, :3] = image_input[:, :, :3] * ref_mask + 1 - ref_mask  # composite onto white background
    image_input = image_input[:, :, :3] * 2.0 - 1.0  # [0, 1] -> [-1, 1]
    image_input = torch.from_numpy(image_input.astype(np.float32))
    elevation_input = torch.from_numpy(np.asarray([np.deg2rad(elevation_input)], np.float32))
    return {"input_image": image_input, "input_elevation": elevation_input}


class VideoTrainDataset(Dataset):
    def __init__(self, base_folder='/data/yanghaibo/datas/OBJAVERSE-LVIS/images', width=1024, height=576, sample_frames=25):
        """
        Args:
            base_folder (str): Directory with one sub-folder of rendered frames per object.
            width (int): Output frame width in pixels.
            height (int): Output frame height in pixels.
            sample_frames (int): Nominal number of frames per sample (the block
                sampling in __getitem__ always yields 16).
        """
        self.base_folder = base_folder
        self.folders = os.listdir(self.base_folder)
        self.num_samples = len(self.folders)
        self.channels = 3
        self.width = width
        self.height = height
        self.sample_frames = sample_frames
        self.elevations = [-10, 0, 10, 20, 30, 40]  # one 16-frame ring per elevation

    def __len__(self):
        return self.num_samples

    def load_im(self, path):
        """Load an RGBA image, composite it onto white, and return (RGB PIL image, alpha mask)."""
        img = imread(path)
        img = img.astype(np.float32) / 255.0
        mask = img[:, :, 3:]
        img[:, :, :3] = img[:, :, :3] * mask + 1 - mask  # white background
        img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
        return img, mask

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Ignored; a folder is drawn at random, so every call returns a fresh sample.
        Returns:
            dict: 'video' tensor of shape (channels, frames, height, width) in [-1, 1],
            plus the ring's elevation, a caption id, and fixed fps / motion-bucket conditioning.
        """
        # Randomly select a folder (representing one object's renders) from the base folder
        chosen_folder = random.choice(self.folders)
        folder_path = os.path.join(self.base_folder, chosen_folder)
        frames = os.listdir(folder_path)
        # Sort the frames by name
        frames.sort()

        # Ensure the selected folder has at least `sample_frames` frames
        if len(frames) < self.sample_frames:
            raise ValueError(
                f"The selected folder '{chosen_folder}' contains fewer than `{self.sample_frames}` frames.")

        # Randomly select a start index; each 16-frame block shares a fixed elevation
        start_idx = random.randint(0, len(frames) - 1)
        range_id = start_idx // 16  # 0..5, indexing into self.elevations
        elevation = self.elevations[range_id]

        # Walk from start_idx to the end of the block, then wrap around to the block's
        # start, so each sample is one full 16-view ring at a single elevation.
        selected_frames = []
        for frame_idx in range(start_idx, (range_id + 1) * 16):
            selected_frames.append(frames[frame_idx])
        for frame_idx in range(range_id * 16, start_idx):
            selected_frames.append(frames[frame_idx])

        # Initialize a tensor to store the pixel values, sized by the actual number of
        # selected frames; sizing it by sample_frames would leave trailing slots of
        # torch.empty uninitialized whenever sample_frames != 16.
        pixel_values = torch.empty((len(selected_frames), self.channels, self.height, self.width))

        # Load and process each frame
        for i, frame_name in enumerate(selected_frames):
            frame_path = os.path.join(folder_path, frame_name)
            img, mask = self.load_im(frame_path)

            # Resize the image and convert it to a tensor
            img_resized = img.resize((self.width, self.height))
            img_tensor = torch.from_numpy(np.array(img_resized)).float()

            # Normalize the image by scaling pixel values to [-1, 1]
            img_normalized = img_tensor / 127.5 - 1

            # Rearrange channels if necessary
            if self.channels == 3:
                img_normalized = img_normalized.permute(2, 0, 1)  # HWC -> CHW for RGB images
            elif self.channels == 1:
                img_normalized = img_normalized.mean(dim=2, keepdim=True)  # grayscale

            pixel_values[i] = img_normalized

        pixel_values = rearrange(pixel_values, 't c h w -> c t h w')
        caption = chosen_folder + "_" + str(start_idx)
        return {'video': pixel_values, 'elevation': elevation, 'caption': caption, "fps_id": 7, "motion_bucket_id": 127}
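
# A quick sanity check (the path is a placeholder; each object folder must hold
# 96 frames as described at the top of the file):
#     ds = VideoTrainDataset(base_folder='/path/to/renders', width=512, height=512)
#     sample = ds[0]
#     sample['video'].shape  # -> torch.Size([3, 16, 512, 512])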


class SyncDreamerEvalData(Dataset):
    def __init__(self, image_dir):
        self.image_size = 512
        self.image_dir = Path(image_dir)
        self.crop_size = 20

        self.fns = []
        for fn in Path(image_dir).iterdir():
            if fn.suffix == '.png':
                self.fns.append(fn)
        print(f'============= length of dataset {len(self.fns)} =============')

    def __len__(self):
        return len(self.fns)

    def get_data_for_index(self, index):
        input_img_fn = self.fns[index]
        elevation = 0
        # The bare positional 512 fills `crop_size`, not `image_size`; made explicit here
        # (the output resolution therefore stays at the default image_size=256).
        return prepare_inputs(input_img_fn, elevation, crop_size=512)

    def __getitem__(self, index):
        return self.get_data_for_index(index)


class VideoDataset(pl.LightningDataModule):
    def __init__(self, base_folder, eval_folder, width, height, sample_frames, batch_size, num_workers=4, seed=0, **kwargs):
        super().__init__()
        self.base_folder = base_folder
        self.eval_folder = eval_folder
        self.width = width
        self.height = height
        self.sample_frames = sample_frames
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.additional_args = kwargs

    def setup(self, stage=None):  # Lightning calls setup(stage=...), so the argument is required
        self.train_dataset = VideoTrainDataset(self.base_folder, self.width, self.height, self.sample_frames)
        self.val_dataset = SyncDreamerEvalData(image_dir=self.eval_folder)

    def train_dataloader(self):
        # DistributedSampler requires torch.distributed to be initialized;
        # shuffle=False because the sampler already shuffles per epoch.
        sampler = DistributedSampler(self.train_dataset, seed=self.seed)
        return wds.WebLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)

    def val_dataloader(self):
        return wds.WebLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)

    def test_dataloader(self):
        return wds.WebLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
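

# A minimal smoke test, assuming locally available data: the paths below are
# placeholders, and `setup()` plus a manual batch fetch stand in for a full
# Trainer run. The validation loader is used because the train loader's
# DistributedSampler needs torch.distributed to be initialized.
if __name__ == "__main__":
    dm = VideoDataset(
        base_folder="/path/to/objaverse/renders",  # hypothetical path
        eval_folder="/path/to/eval/pngs",          # hypothetical path, RGBA .png files
        width=512, height=512, sample_frames=16,
        batch_size=1, num_workers=0,
    )
    dm.setup()
    batch = next(iter(dm.val_dataloader()))
    # Expected: torch.Size([1, 256, 256, 3]) and torch.Size([1, 1])
    print(batch["input_image"].shape, batch["input_elevation"].shape)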