from utils.lmdb import get_array_shape_from_lmdb, retrieve_row_from_lmdb
from torch.utils.data import Dataset
import numpy as np
import torch
import lmdb
import json
from pathlib import Path
from PIL import Image
import os


class TextDataset(Dataset):
    """Line-per-prompt text dataset, with an optional second file of extended
    prompts that must have the same number of lines as the main prompt file."""

    def __init__(self, prompt_path, extended_prompt_path=None):
        with open(prompt_path, encoding="utf-8") as f:
            self.prompt_list = [line.rstrip() for line in f]

        if extended_prompt_path is not None:
            with open(extended_prompt_path, encoding="utf-8") as f:
                self.extended_prompt_list = [line.rstrip() for line in f]
            assert len(self.extended_prompt_list) == len(self.prompt_list), \
                "extended prompt file must have the same number of lines as the prompt file"
        else:
            self.extended_prompt_list = None

    def __len__(self):
        return len(self.prompt_list)

    def __getitem__(self, idx):
        batch = {
            "prompts": self.prompt_list[idx],
            "idx": idx,
        }
        if self.extended_prompt_list is not None:
            batch["extended_prompts"] = self.extended_prompt_list[idx]
        return batch


class ODERegressionLMDBDataset(Dataset):
    """Reads (prompt, ODE-trajectory latent) pairs from a single LMDB file."""

    def __init__(self, data_path: str, max_pair: int = int(1e8)):
        self.env = lmdb.open(data_path, readonly=True,
                             lock=False, readahead=False, meminit=False)
        self.latents_shape = get_array_shape_from_lmdb(self.env, 'latents')
        self.max_pair = max_pair

    def __len__(self):
        return min(self.latents_shape[0], self.max_pair)

    def __getitem__(self, idx):
        """
        Outputs:
            - prompts: a single prompt string.
            - ode_latent: Tensor of shape
              (num_denoising_steps, num_frames, num_channels, height, width),
              ordered from pure noise to clean image.
        """
        latents = retrieve_row_from_lmdb(
            self.env,
            "latents", np.float16, idx,
            shape=self.latents_shape[1:]
        )

        # Add a singleton denoising-step axis if the stored latent is 4-D.
        if len(latents.shape) == 4:
            latents = latents[None, ...]

        prompts = retrieve_row_from_lmdb(
            self.env,
            "prompts", str, idx
        )
        return {
            "prompts": prompts,
            "ode_latent": torch.tensor(latents, dtype=torch.float32)
        }


class ShardingLMDBDataset(Dataset):
    """Like ODERegressionLMDBDataset, but reads from a directory of LMDB
    shards, mapping each global index to a (shard_id, local_index) pair."""

    def __init__(self, data_path: str, max_pair: int = int(1e8)):
        self.envs = []
        self.index = []

        for fname in sorted(os.listdir(data_path)):
            path = os.path.join(data_path, fname)
            env = lmdb.open(path, readonly=True,
                            lock=False, readahead=False, meminit=False)
            self.envs.append(env)

        self.latents_shape = [None] * len(self.envs)
        for shard_id, env in enumerate(self.envs):
            self.latents_shape[shard_id] = get_array_shape_from_lmdb(env, 'latents')
            for local_i in range(self.latents_shape[shard_id][0]):
                self.index.append((shard_id, local_i))

        self.max_pair = max_pair

    def __len__(self):
        # Cap at max_pair, matching ODERegressionLMDBDataset.
        return min(len(self.index), self.max_pair)

    def __getitem__(self, idx):
        """
        Outputs:
            - prompts: a single prompt string.
            - ode_latent: Tensor of shape
              (num_denoising_steps, num_frames, num_channels, height, width),
              ordered from pure noise to clean image.
        """
        shard_id, local_idx = self.index[idx]

        latents = retrieve_row_from_lmdb(
            self.envs[shard_id],
            "latents", np.float16, local_idx,
            shape=self.latents_shape[shard_id][1:]
        )

        # Add a singleton denoising-step axis if the stored latent is 4-D.
        if len(latents.shape) == 4:
            latents = latents[None, ...]
        prompts = retrieve_row_from_lmdb(
            self.envs[shard_id],
            "prompts", str, local_idx
        )
        return {
            "prompts": prompts,
            "ode_latent": torch.tensor(latents, dtype=torch.float32)
        }


class TextImagePairDataset(Dataset):
    def __init__(
        self,
        data_dir,
        transform=None,
        eval_first_n=-1,
        pad_to_multiple_of=None
    ):
        """
        Args:
            data_dir (str): Path to the directory containing:
                - target_crop_info_*.json (metadata file)
                - */ (subdirectory containing images with matching aspect ratio)
            transform (callable, optional): Optional transform applied to each image.
            eval_first_n (int, optional): Keep only the first n metadata entries;
                -1 (default) keeps all of them.
            pad_to_multiple_of (int, optional): If set, duplicate the last entry
                until the dataset length is a multiple of this value.
        """
        self.transform = transform
        data_dir = Path(data_dir)

        # Find the metadata JSON file
        metadata_files = list(data_dir.glob('target_crop_info_*.json'))
        if not metadata_files:
            raise FileNotFoundError(f"No metadata file found in {data_dir}")
        if len(metadata_files) > 1:
            raise ValueError(f"Multiple metadata files found in {data_dir}")
        metadata_path = metadata_files[0]

        # Extract aspect ratio from the metadata filename
        # (e.g. target_crop_info_26-15.json -> 26-15)
        aspect_ratio = metadata_path.stem.split('_')[-1]

        # Images live in a subfolder named after the aspect ratio
        self.image_dir = data_dir / aspect_ratio
        if not self.image_dir.exists():
            raise FileNotFoundError(f"Image directory not found: {self.image_dir}")

        # Load metadata
        with open(metadata_path, 'r') as f:
            self.metadata = json.load(f)

        eval_first_n = eval_first_n if eval_first_n != -1 else len(self.metadata)
        self.metadata = self.metadata[:eval_first_n]

        # Verify all images exist
        for item in self.metadata:
            image_path = self.image_dir / item['file_name']
            if not image_path.exists():
                raise FileNotFoundError(f"Image not found: {image_path}")

        self.dummy_prompt = "DUMMY PROMPT"

        # Number of real entries before padding
        self.pre_pad_len = len(self.metadata)

        if pad_to_multiple_of is not None and len(self.metadata) % pad_to_multiple_of != 0:
            # Duplicate the last entry until the length is a multiple of pad_to_multiple_of
            self.metadata += [self.metadata[-1]] * (
                pad_to_multiple_of - len(self.metadata) % pad_to_multiple_of
            )

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        """
        Returns:
            dict: A dictionary containing:
                - image: PIL Image (or the transformed image if a transform is set)
                - prompts: str, the image caption
                - target_bbox: list of int [x1, y1, x2, y2]
                - target_ratio: str
                - type: str
                - origin_size: tuple of int (width, height)
                - idx: int
        """
        item = self.metadata[idx]

        # Load image
        image_path = self.image_dir / item['file_name']
        image = Image.open(image_path).convert('RGB')

        # Apply transform if specified
        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'prompts': item['caption'],
            'target_bbox': item['target_crop']['target_bbox'],
            'target_ratio': item['target_crop']['target_ratio'],
            'type': item['type'],
            'origin_size': (item['origin_width'], item['origin_height']),
            'idx': idx
        }


def cycle(dl):
    """Yield batches from a dataloader forever (restarting it when exhausted)."""
    while True:
        for data in dl:
            yield data
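

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module: it writes a throwaway
# prompt file, wraps TextDataset in a DataLoader, and uses cycle() to draw
# batches indefinitely, the way a training loop typically would. The file
# contents, batch size, and iteration count below are illustrative assumptions.
if __name__ == "__main__":
    import tempfile
    from torch.utils.data import DataLoader

    # Create a tiny prompt file (one prompt per line, as TextDataset expects).
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
        f.write("a cat playing piano\na dog surfing a wave\na city at night\n")
        prompt_path = f.name

    dataset = TextDataset(prompt_path)
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    batches = cycle(loader)  # infinite iterator over the DataLoader

    for _ in range(3):
        batch = next(batches)
        # Default collation gives a tensor of indices and a list of strings.
        print(batch["idx"], batch["prompts"])

    os.unlink(prompt_path)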