import gc
import json
import os
import random
import resource
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as TT
import torchvision.transforms.functional as TTF
from accelerate.logging import get_logger
from torch.utils.data import Dataset, Sampler
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize

# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord  # isort:skip

decord.bridge.set_bridge("torch")

from .constants import (  # noqa
    COMMON_LLM_START_PHRASES,
    PRECOMPUTED_CONDITIONS_DIR_NAME,
    PRECOMPUTED_DIR_NAME,
    PRECOMPUTED_LATENTS_DIR_NAME,
)

# Decord is causing us some issues!
# Let's try to increase file descriptor limits to avoid this error:
#
#     decord._ffi.base.DECORDError: Resource temporarily unavailable
try:
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    print(f"Current file descriptor limits: soft={soft}, hard={hard}")
    # Try to increase the soft limit to the hard limit if possible
    if soft < hard:
        resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard))
        new_soft, new_hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        print(f"Updated file descriptor limits: soft={new_soft}, hard={new_hard}")
except Exception as e:
    print(f"Could not check or update file descriptor limits: {e}")

logger = get_logger(__name__)


# TODO(aryan): This needs a refactor with separation of concerns.
# Images should be handled separately. Videos should be handled separately.
# Loading should be handled separately.
# Preprocessing (aspect ratio, resizing) should be handled separately.
# URL loading should be handled.
# Parquet format should be handled.
# Loading from ZIP should be handled.
class ImageOrVideoDataset(Dataset):
    def __init__(
        self,
        data_root: str,
        caption_column: str,
        video_column: str,
        resolution_buckets: List[Tuple[int, int, int]],
        dataset_file: Optional[str] = None,
        id_token: Optional[str] = None,
        remove_llm_prefixes: bool = False,
    ) -> None:
        super().__init__()

        self.data_root = Path(data_root)
        self.dataset_file = dataset_file
        self.caption_column = caption_column
        self.video_column = video_column
        self.id_token = f"{id_token.strip()} " if id_token else ""
        self.resolution_buckets = resolution_buckets

        # Four methods of loading data are supported.
        # - Using a CSV: caption_column and video_column must be columns in the CSV. One could
        #   make use of other columns too, such as a motion score or aesthetic score, by modifying the
        #   logic in CSV processing.
        # - Using two files containing line-separated captions and relative paths to videos.
        # - Using a JSON file containing a list of dictionaries, where each dictionary has a `caption_column` and a `video_column` key.
        # - Using a JSONL file containing line-separated dictionaries, where each dictionary has a `caption_column` and a `video_column` key.
        # For a more detailed explanation of how to prepare the dataset, check out the README. An illustrative sketch follows below.
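        # Illustrative sketch of the supported layouts (the column names and file names here are
        # placeholders, not requirements; the real names come from `caption_column` / `video_column`):
        #
        #   CSV:    caption,video
        #           "A cat playing with a ball","videos/cat.mp4"
        #   Files:  prompts.txt -> one caption per line; videos.txt -> one relative video path per line
        #   JSON:   [{"caption": "A cat playing with a ball", "video": "videos/cat.mp4"}, ...]
        #   JSONL:  {"caption": "A cat playing with a ball", "video": "videos/cat.mp4"}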
        if dataset_file is None:
            self.prompts, self.video_paths = self._load_dataset_from_local_path()
        elif dataset_file.endswith(".csv"):
            self.prompts, self.video_paths = self._load_dataset_from_csv()
        elif dataset_file.endswith(".json"):
            self.prompts, self.video_paths = self._load_dataset_from_json()
        elif dataset_file.endswith(".jsonl"):
            self.prompts, self.video_paths = self._load_dataset_from_jsonl()
        else:
            raise ValueError(
                "Expected `--dataset_file` to be a path to a CSV, JSON, or JSONL file, or None when `--data_root` contains files with line-separated text prompts and video paths."
            )

        if len(self.video_paths) != len(self.prompts):
            raise ValueError(
                f"Expected length of prompts and videos to be the same but found {len(self.prompts)=} and {len(self.video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
            )

        # Clean LLM start phrases
        if remove_llm_prefixes:
            for i in range(len(self.prompts)):
                self.prompts[i] = self.prompts[i].strip()
                for phrase in COMMON_LLM_START_PHRASES:
                    if self.prompts[i].startswith(phrase):
                        self.prompts[i] = self.prompts[i].removeprefix(phrase).strip()

        self.video_transforms = transforms.Compose(
            [
                transforms.Lambda(self.scale_transform),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ]
        )
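
    # Net effect of the transform pipeline above: `scale_transform` maps raw uint8 pixels from
    # [0, 255] to [0, 1], and the Normalize step then shifts them to [-1, 1]
    # (for example, 255 -> 1.0 -> (1.0 - 0.5) / 0.5 = 1.0, and 0 -> 0.0 -> -1.0).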
    @staticmethod
    def scale_transform(x: torch.Tensor) -> torch.Tensor:
        # Must be a staticmethod so that `transforms.Lambda(self.scale_transform)` binds the
        # frame tensor (not `self`) to `x`.
        return x / 255.0

    def __len__(self) -> int:
        return len(self.video_paths)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        if isinstance(index, list):
            # Here, index is actually a list of data objects that we need to return.
            # The BucketSampler should ideally return indices. But, in the sampler, we'd like
            # to have information about num_frames, height and width. Since this is not stored
            # as metadata, we need to read the video to get this information. You could read this
            # information without loading the full video in memory, but we do it anyway. In order
            # to not load the video twice (once to get the metadata, and once to return the loaded video
            # based on sampled indices), we cache it in the BucketSampler. When the sampler is about
            # to yield, we yield the cached data instead of indices. So, this special check ensures
            # that data is not loaded a second time. PRs are welcome for improvements.
            return index

        prompt = self.id_token + self.prompts[index]
        video_path: Path = self.video_paths[index]

        if video_path.suffix.lower() in [".png", ".jpg", ".jpeg"]:
            video = self._preprocess_image(video_path)
        else:
            video = self._preprocess_video(video_path)

        return {
            "prompt": prompt,
            "video": video,
            "video_metadata": {
                "num_frames": video.shape[0],
                "height": video.shape[2],
                "width": video.shape[3],
            },
        }

    def _load_dataset_from_local_path(self) -> Tuple[List[str], List[Path]]:
        if not self.data_root.exists():
            raise ValueError("Root folder for videos does not exist")

        prompt_path = self.data_root.joinpath(self.caption_column)
        video_path = self.data_root.joinpath(self.video_column)

        if not prompt_path.exists() or not prompt_path.is_file():
            raise ValueError(
                "Expected `--caption_column` to be a path to a file in `--data_root` containing line-separated text prompts."
            )
        if not video_path.exists() or not video_path.is_file():
            raise ValueError(
                "Expected `--video_column` to be a path to a file in `--data_root` containing line-separated paths to video data in the same directory."
            )

        with open(prompt_path, "r", encoding="utf-8") as file:
            prompts = [line.strip() for line in file.readlines() if len(line.strip()) > 0]
        with open(video_path, "r", encoding="utf-8") as file:
            video_paths = [self.data_root.joinpath(line.strip()) for line in file.readlines() if len(line.strip()) > 0]

        if any(not path.is_file() for path in video_paths):
            raise ValueError(
                f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found at least one path that is not a valid file."
            )

        return prompts, video_paths

    def _load_dataset_from_csv(self) -> Tuple[List[str], List[Path]]:
        df = pd.read_csv(self.dataset_file)
        prompts = df[self.caption_column].tolist()
        video_paths = df[self.video_column].tolist()
        video_paths = [self.data_root.joinpath(line.strip()) for line in video_paths]

        if any(not path.is_file() for path in video_paths):
            raise ValueError(
                f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found at least one path that is not a valid file."
            )

        return prompts, video_paths

    def _load_dataset_from_json(self) -> Tuple[List[str], List[Path]]:
        with open(self.dataset_file, "r", encoding="utf-8") as file:
            data = json.load(file)

        prompts = [entry[self.caption_column] for entry in data]
        video_paths = [self.data_root.joinpath(entry[self.video_column].strip()) for entry in data]

        if any(not path.is_file() for path in video_paths):
            raise ValueError(
                f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found at least one path that is not a valid file."
            )

        return prompts, video_paths

    def _load_dataset_from_jsonl(self) -> Tuple[List[str], List[Path]]:
        with open(self.dataset_file, "r", encoding="utf-8") as file:
            data = [json.loads(line) for line in file]

        prompts = [entry[self.caption_column] for entry in data]
        video_paths = [self.data_root.joinpath(entry[self.video_column].strip()) for entry in data]

        if any(not path.is_file() for path in video_paths):
            raise ValueError(
                f"Expected `{self.video_column=}` to be a path to a file in `{self.data_root=}` containing line-separated paths to video data but found at least one path that is not a valid file."
            )

        return prompts, video_paths

    def _preprocess_image(self, path: Path) -> torch.Tensor:
        # TODO(aryan): Support alpha channel in future by whitening background
        image = TTF.Image.open(path.as_posix()).convert("RGB")
        image = TTF.to_tensor(image)
        image = image * 2.0 - 1.0
        image = image.unsqueeze(0).contiguous()  # [C, H, W] -> [1, C, H, W] (1-frame video)
        return image

    def _preprocess_video(self, path: Path) -> torch.Tensor:
        """
        Loads a single video and returns a [F, C, H, W] tensor of frames normalized to [-1, 1].
        Retries a few times when decord runs out of file descriptors. `self.max_num_frames` is
        expected to be set by a subclass (e.g. from `resolution_buckets`) before this is called.
        """
        max_retries = 3
        retry_delay = 1.0  # seconds

        for attempt in range(max_retries):
            try:
                # Create video reader
                video_reader = decord.VideoReader(uri=path.as_posix())
                video_num_frames = len(video_reader)

                # Sample frames uniformly. The step is clamped to at least 1 so that videos with
                # fewer than `max_num_frames` frames do not produce a zero step.
                step = max(video_num_frames // self.max_num_frames, 1)
                indices = list(range(0, video_num_frames, step))
                frames = video_reader.get_batch(indices)
                frames = frames[: self.max_num_frames].float()
                frames = frames.permute(0, 3, 1, 2).contiguous()
                frames = torch.stack([self.video_transforms(frame) for frame in frames], dim=0)

                # Explicitly clean up resources
                del video_reader

                # Force garbage collection occasionally
                if random.random() < 0.05:  # 5% chance
                    gc.collect()

                return frames
            except decord._ffi.base.DECORDError as e:
                error_msg = str(e)
                if "Resource temporarily unavailable" in error_msg and attempt < max_retries - 1:
                    logger.warning(f"Retry {attempt + 1}/{max_retries} loading video {path}: {error_msg}")
                    # Clean up and wait before retrying, with an increasing backoff
                    gc.collect()
                    time.sleep(retry_delay * (attempt + 1))
                else:
                    # Either not a resource error or we've run out of retries
                    logger.error(f"Failed to load video {path} after {attempt + 1} attempts: {error_msg}")
                    raise RuntimeError(f"Failed to load video after {max_retries} attempts: {error_msg}")


class ImageOrVideoDatasetWithResizing(ImageOrVideoDataset):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.max_num_frames = max(self.resolution_buckets, key=lambda x: x[0])[0]

    def _preprocess_image(self, path: Path) -> torch.Tensor:
        # TODO(aryan): Support alpha channel in future by whitening background
        image = TTF.Image.open(path.as_posix()).convert("RGB")
        image = TTF.to_tensor(image)

        nearest_res = self._find_nearest_resolution(image.shape[1], image.shape[2])
        image = resize(image, nearest_res)

        image = image * 2.0 - 1.0
        image = image.unsqueeze(0).contiguous()  # [C, H, W] -> [1, C, H, W] (1-frame video)
        return image

    def _preprocess_video(self, path: Path) -> torch.Tensor:
        max_retries = 3
        retry_delay = 1.0  # seconds

        for attempt in range(max_retries):
            try:
                # Create video reader
                video_reader = decord.VideoReader(uri=path.as_posix())
                video_num_frames = len(video_reader)

                # Find an appropriate frame bucket for the video
                video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames]
                if not video_buckets:
                    _, h, w = self.resolution_buckets[0]
                    video_buckets = [(1, h, w)]

                nearest_frame_bucket = min(
                    video_buckets,
                    key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)),
                    default=video_buckets[0],
                )[0]

                # Extract and process frames
                frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
                frames = video_reader.get_batch(frame_indices)
                frames = frames[:nearest_frame_bucket].float()
                frames = frames.permute(0, 3, 1, 2).contiguous()

                nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
                frames_resized = torch.stack([resize(frame, nearest_res) for frame in frames], dim=0)
                frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)

                # Explicitly clean up resources
                del video_reader

                # Force garbage collection occasionally
                if random.random() < 0.05:  # 5% chance
                    gc.collect()

                return frames
            except decord._ffi.base.DECORDError as e:
                error_msg = str(e)
                if "Resource temporarily unavailable" in error_msg and attempt < max_retries - 1:
                    logger.warning(f"Retry {attempt + 1}/{max_retries} loading video {path}: {error_msg}")
                    # Clean up and wait before retrying, with an increasing backoff
                    gc.collect()
                    time.sleep(retry_delay * (attempt + 1))
                else:
                    # Either not a resource error or we've run out of retries
                    logger.error(f"Failed to load video {path} after {attempt + 1} attempts: {error_msg}")
                    raise RuntimeError(f"Failed to load video after {max_retries} attempts: {error_msg}")

    def _find_nearest_resolution(self, height, width):
        nearest_res = min(self.resolution_buckets, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
        return nearest_res[1], nearest_res[2]


class ImageOrVideoDatasetWithResizeAndRectangleCrop(ImageOrVideoDataset):
    def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.video_reshape_mode = video_reshape_mode
        self.max_num_frames = max(self.resolution_buckets, key=lambda x: x[0])[0]

    def _resize_for_rectangle_crop(self, arr, image_size):
        reshape_mode = self.video_reshape_mode
        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
            arr = resize(
                arr,
                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
                interpolation=InterpolationMode.BICUBIC,
            )
        else:
            arr = resize(
                arr,
                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
                interpolation=InterpolationMode.BICUBIC,
            )

        h, w = arr.shape[2], arr.shape[3]
        arr = arr.squeeze(0)

        delta_h = h - image_size[0]
        delta_w = w - image_size[1]

        if reshape_mode == "random" or reshape_mode == "none":
            top = np.random.randint(0, delta_h + 1)
            left = np.random.randint(0, delta_w + 1)
        elif reshape_mode == "center":
            top, left = delta_h // 2, delta_w // 2
        else:
            raise NotImplementedError(f"Unsupported `video_reshape_mode`: {reshape_mode}")

        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
        return arr

    def _preprocess_video(self, path: Path) -> torch.Tensor:
        max_retries = 3
        retry_delay = 1.0  # seconds

        for attempt in range(max_retries):
            try:
                # Create video reader
                video_reader = decord.VideoReader(uri=path.as_posix())
                video_num_frames = len(video_reader)

                # Find an appropriate frame bucket for the video
                video_buckets = [bucket for bucket in self.resolution_buckets if bucket[0] <= video_num_frames]
                if not video_buckets:
                    _, h, w = self.resolution_buckets[0]
                    video_buckets = [(1, h, w)]

                nearest_frame_bucket = min(
                    video_buckets,
                    key=lambda x: abs(x[0] - min(video_num_frames, self.max_num_frames)),
                    default=video_buckets[0],
                )[0]

                # Extract and process frames
                frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
                frames = video_reader.get_batch(frame_indices)
                frames = frames[:nearest_frame_bucket].float()
                frames = frames.permute(0, 3, 1, 2).contiguous()

                nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
                frames_resized = self._resize_for_rectangle_crop(frames, nearest_res)
                frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)

                # Explicitly clean up resources
                del video_reader

                # Force garbage collection occasionally
                if random.random() < 0.05:  # 5% chance
                    gc.collect()

                return frames
            except decord._ffi.base.DECORDError as e:
                error_msg = str(e)
                if "Resource temporarily unavailable" in error_msg and attempt < max_retries - 1:
                    logger.warning(f"Retry {attempt + 1}/{max_retries} loading video {path}: {error_msg}")
                    # Clean up and wait before retrying, with an increasing backoff
                    gc.collect()
                    time.sleep(retry_delay * (attempt + 1))
                else:
                    # Either not a resource error or we've run out of retries
                    logger.error(f"Failed to load video {path} after {attempt + 1} attempts: {error_msg}")
                    raise RuntimeError(f"Failed to load video after {max_retries} attempts: {error_msg}")

    def _find_nearest_resolution(self, height, width):
        # Use `self.resolution_buckets`; `self.resolutions` does not exist on this class.
        nearest_res = min(self.resolution_buckets, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
        return nearest_res[1], nearest_res[2]


class PrecomputedDataset(Dataset):
    def __init__(self, data_root: str, model_name: Optional[str] = None, cleaned_model_id: Optional[str] = None) -> None:
        super().__init__()

        self.data_root = Path(data_root)

        if model_name and cleaned_model_id:
            precomputation_dir = self.data_root / f"{model_name}_{cleaned_model_id}_{PRECOMPUTED_DIR_NAME}"
            self.latents_path = precomputation_dir / PRECOMPUTED_LATENTS_DIR_NAME
            self.conditions_path = precomputation_dir / PRECOMPUTED_CONDITIONS_DIR_NAME
        else:
            self.latents_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_LATENTS_DIR_NAME
            self.conditions_path = self.data_root / PRECOMPUTED_DIR_NAME / PRECOMPUTED_CONDITIONS_DIR_NAME

        self.latent_conditions = sorted(os.listdir(self.latents_path))
        self.text_conditions = sorted(os.listdir(self.conditions_path))

        assert len(self.latent_conditions) == len(self.text_conditions), "Number of precomputed latent and text condition files do not match"

    def __len__(self) -> int:
        return len(self.latent_conditions)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        conditions = {}
        latent_path = self.latents_path / self.latent_conditions[index]
        condition_path = self.conditions_path / self.text_conditions[index]
        conditions["latent_conditions"] = torch.load(latent_path, map_location="cpu", weights_only=True)
        conditions["text_conditions"] = torch.load(condition_path, map_location="cpu", weights_only=True)
        return conditions


class BucketSampler(Sampler):
    r"""
    PyTorch Sampler that groups 3D data by height, width and number of frames.

    Args:
        data_source (`ImageOrVideoDataset`):
            A PyTorch dataset object that is an instance of `ImageOrVideoDataset`.
        batch_size (`int`, defaults to `8`):
            The batch size to use for training.
        shuffle (`bool`, defaults to `True`):
            Whether or not to shuffle the data in each batch before dispatching to the dataloader.
        drop_last (`bool`, defaults to `False`):
            Whether or not to drop incomplete buckets of data after completely iterating over all data
            in the dataset. If set to True, only batches that have `batch_size` number of entries will
            be yielded. If set to False, it is guaranteed that all data in the dataset will be processed
            and batches that do not have `batch_size` number of entries will also be yielded.

    See the usage sketch at the end of this file for how this sampler is typically wired into a dataloader.
    """

    def __init__(
        self, data_source: ImageOrVideoDataset, batch_size: int = 8, shuffle: bool = True, drop_last: bool = False
    ) -> None:
        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

        self.buckets = {resolution: [] for resolution in data_source.resolution_buckets}

        self._raised_warning_for_drop_last = False

    def __len__(self):
        if self.drop_last and not self._raised_warning_for_drop_last:
            self._raised_warning_for_drop_last = True
            logger.warning(
                "Calculating the length of the bucket sampler is not possible when `drop_last` is set to True. This may cause problems when setting the number of epochs used for training."
            )
        return (len(self.data_source) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        for index, data in enumerate(self.data_source):
            video_metadata = data["video_metadata"]
            f, h, w = video_metadata["num_frames"], video_metadata["height"], video_metadata["width"]

            self.buckets[(f, h, w)].append(data)
            if len(self.buckets[(f, h, w)]) == self.batch_size:
                if self.shuffle:
                    random.shuffle(self.buckets[(f, h, w)])
                yield self.buckets[(f, h, w)]
                del self.buckets[(f, h, w)]
                self.buckets[(f, h, w)] = []

        if self.drop_last:
            return

        for fhw, bucket in list(self.buckets.items()):
            if len(bucket) == 0:
                continue
            if self.shuffle:
                random.shuffle(bucket)
            yield bucket
            del self.buckets[fhw]
            self.buckets[fhw] = []
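

# Usage sketch (illustrative only; the paths, file names, bucket values, and collate_fn below are
# placeholder assumptions, not part of this module). The sampler yields whole buckets of already-loaded
# samples rather than indices, so the dataloader is created with `batch_size=1` and the collate_fn
# unwraps the single bucket it receives:
#
#     from torch.utils.data import DataLoader
#
#     dataset = ImageOrVideoDatasetWithResizing(
#         data_root="path/to/dataset",
#         caption_column="prompts.txt",
#         video_column="videos.txt",
#         resolution_buckets=[(49, 480, 720)],
#     )
#     sampler = BucketSampler(dataset, batch_size=8, shuffle=True)
#     dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, collate_fn=lambda batch: batch[0])
#
#     for bucket in dataloader:
#         # `bucket` is a list of up to 8 sample dicts sharing the same (num_frames, height, width)
#         ...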