import json
import os
from pathlib import Path

import lmdb
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

from utils.lmdb import get_array_shape_from_lmdb, retrieve_row_from_lmdb


class TextDataset(Dataset):
    """Text prompts read from a file, one prompt per line.

    Optionally pairs each prompt with an extended prompt from a second file
    of the same length.
    """

    def __init__(self, prompt_path, extended_prompt_path=None):
        with open(prompt_path, encoding="utf-8") as f:
            self.prompt_list = [line.rstrip() for line in f]

        if extended_prompt_path is not None:
            with open(extended_prompt_path, encoding="utf-8") as f:
                self.extended_prompt_list = [line.rstrip() for line in f]
            assert len(self.extended_prompt_list) == len(self.prompt_list), \
                "extended prompts must align one-to-one with prompts"
        else:
            self.extended_prompt_list = None

    def __len__(self):
        return len(self.prompt_list)

    def __getitem__(self, idx):
        batch = {
            "prompts": self.prompt_list[idx],
            "idx": idx,
        }
        if self.extended_prompt_list is not None:
            batch["extended_prompts"] = self.extended_prompt_list[idx]
        return batch
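

def _example_text_dataset(prompt_path="prompts.txt"):
    # Illustrative sketch, not part of the original module: "prompts.txt" is
    # a hypothetical file with one prompt per line. The default DataLoader
    # collation batches the string prompts into a list and the integer
    # indices into a LongTensor.
    from torch.utils.data import DataLoader
    dataset = TextDataset(prompt_path)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    return batch["prompts"], batch["idx"]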


class ODERegressionLMDBDataset(Dataset):
    """ODE denoising trajectories and prompts stored in a single LMDB database."""

    def __init__(self, data_path: str, max_pair: int = int(1e8)):
        self.env = lmdb.open(data_path, readonly=True,
                             lock=False, readahead=False, meminit=False)
        self.latents_shape = get_array_shape_from_lmdb(self.env, 'latents')
        # Cap on the number of usable rows (useful for debugging / subsampling).
        self.max_pair = max_pair

    def __len__(self):
        return min(self.latents_shape[0], self.max_pair)

    def __getitem__(self, idx):
        """
        Outputs:
            - prompts: String. The text prompt for this trajectory.
            - ode_latent: Tensor of shape (num_denoising_steps, num_frames,
              num_channels, height, width), ordered from pure noise to the
              clean latent.
        """
        latents = retrieve_row_from_lmdb(
            self.env,
            "latents", np.float16, idx, shape=self.latents_shape[1:]
        )
        if len(latents.shape) == 4:
            # A single-step row is stored without the step axis; add it back.
            latents = latents[None, ...]

        prompts = retrieve_row_from_lmdb(
            self.env,
            "prompts", str, idx
        )
        return {
            "prompts": prompts,
            "ode_latent": torch.tensor(latents, dtype=torch.float32)
        }
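

def _example_ode_regression_dataset(data_path="data/ode_pairs_lmdb"):
    # Illustrative sketch, not part of the original module; the LMDB path is
    # a hypothetical placeholder. Each item pairs a prompt with its denoising
    # trajectory, ordered from pure noise to the clean latent.
    dataset = ODERegressionLMDBDataset(data_path, max_pair=1000)
    sample = dataset[0]
    # ode_latent: (num_denoising_steps, num_frames, num_channels, height, width)
    return sample["prompts"], sample["ode_latent"].shape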


class ShardingLMDBDataset(Dataset):
    """Like ODERegressionLMDBDataset, but reads from a directory of LMDB shards."""

    def __init__(self, data_path: str, max_pair: int = int(1e8)):
        self.envs = []
        self.index = []

        for fname in sorted(os.listdir(data_path)):
            path = os.path.join(data_path, fname)
            env = lmdb.open(path,
                            readonly=True,
                            lock=False,
                            readahead=False,
                            meminit=False)
            self.envs.append(env)

        # Map a flat global index to (shard_id, local row within that shard).
        self.latents_shape = [None] * len(self.envs)
        for shard_id, env in enumerate(self.envs):
            self.latents_shape[shard_id] = get_array_shape_from_lmdb(env, 'latents')
            for local_i in range(self.latents_shape[shard_id][0]):
                self.index.append((shard_id, local_i))

        self.max_pair = max_pair

    def __len__(self):
        # Honor the max_pair cap, as in ODERegressionLMDBDataset.
        return min(len(self.index), self.max_pair)

    def __getitem__(self, idx):
        """
        Outputs:
            - prompts: String. The text prompt for this trajectory.
            - ode_latent: Tensor of shape (num_denoising_steps, num_frames,
              num_channels, height, width), ordered from pure noise to the
              clean latent.
        """
        shard_id, local_idx = self.index[idx]

        latents = retrieve_row_from_lmdb(
            self.envs[shard_id],
            "latents", np.float16, local_idx,
            shape=self.latents_shape[shard_id][1:]
        )
        if len(latents.shape) == 4:
            # A single-step row is stored without the step axis; add it back.
            latents = latents[None, ...]

        prompts = retrieve_row_from_lmdb(
            self.envs[shard_id],
            "prompts", str, local_idx
        )
        return {
            "prompts": prompts,
            "ode_latent": torch.tensor(latents, dtype=torch.float32)
        }
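

def _example_sharding_dataset(data_path="data/ode_shards"):
    # Illustrative sketch, not part of the original module; data_path is a
    # hypothetical directory whose entries (e.g. shard_00/, shard_01/) are
    # each an LMDB database. Items are addressed by a flat global index.
    dataset = ShardingLMDBDataset(data_path)
    first = dataset[0]                # row 0 of the first shard
    last = dataset[len(dataset) - 1]  # last row of the last shard
    return first["prompts"], last["ode_latent"].shape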


class TextImagePairDataset(Dataset):
    def __init__(
        self,
        data_dir,
        transform=None,
        eval_first_n=-1,
        pad_to_multiple_of=None
    ):
        """
        Args:
            data_dir (str): Path to the directory containing:
                - target_crop_info_*.json (metadata file)
                - */ (subdirectory containing images with matching aspect ratio)
            transform (callable, optional): Optional transform to be applied on the image
            eval_first_n (int, optional): If not -1, keep only the first n entries.
            pad_to_multiple_of (int, optional): If set, duplicate the last entry
                until the dataset length is a multiple of this value (useful for
                even sharding across processes).
        """
        self.transform = transform
        data_dir = Path(data_dir)

        # Find the metadata JSON file
        metadata_files = list(data_dir.glob('target_crop_info_*.json'))
        if not metadata_files:
            raise FileNotFoundError(f"No metadata file found in {data_dir}")
        if len(metadata_files) > 1:
            raise ValueError(f"Multiple metadata files found in {data_dir}")
        metadata_path = metadata_files[0]

        # Extract aspect ratio from metadata filename (e.g. target_crop_info_26-15.json -> 26-15)
        aspect_ratio = metadata_path.stem.split('_')[-1]

        # Use aspect ratio subfolder for images
        self.image_dir = data_dir / aspect_ratio
        if not self.image_dir.exists():
            raise FileNotFoundError(f"Image directory not found: {self.image_dir}")

        # Load metadata
        with open(metadata_path, 'r') as f:
            self.metadata = json.load(f)

        eval_first_n = eval_first_n if eval_first_n != -1 else len(self.metadata)
        self.metadata = self.metadata[:eval_first_n]

        # Verify all images exist
        for item in self.metadata:
            image_path = self.image_dir / item['file_name']
            if not image_path.exists():
                raise FileNotFoundError(f"Image not found: {image_path}")

        self.dummy_prompt = "DUMMY PROMPT"

        # Record the true length, then pad by duplicating the last entry so
        # len(self) is a multiple of pad_to_multiple_of.
        self.pre_pad_len = len(self.metadata)
        if pad_to_multiple_of is not None and len(self.metadata) % pad_to_multiple_of != 0:
            self.metadata += [self.metadata[-1]] * (
                pad_to_multiple_of - len(self.metadata) % pad_to_multiple_of
            )

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        """
        Returns:
            dict: A dictionary containing:
                - image: PIL Image (or the transformed image if transform is set)
                - prompts: str, the image caption
                - target_bbox: list of int [x1, y1, x2, y2]
                - target_ratio: str
                - type: str
                - origin_size: tuple of int (width, height)
                - idx: int, the dataset index
        """
        item = self.metadata[idx]

        # Load image
        image_path = self.image_dir / item['file_name']
        image = Image.open(image_path).convert('RGB')

        # Apply transform if specified
        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'prompts': item['caption'],
            'target_bbox': item['target_crop']['target_bbox'],
            'target_ratio': item['target_crop']['target_ratio'],
            'type': item['type'],
            'origin_size': (item['origin_width'], item['origin_height']),
            'idx': idx
        }
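

def _example_text_image_pair_dataset(data_dir="data/eval_26-15"):
    # Illustrative sketch, not part of the original module; the directory name
    # is a hypothetical placeholder. ToTensor() converts the PIL image to a
    # CHW float tensor; pad_to_multiple_of=8 duplicates the last entry so the
    # dataset splits evenly across 8 ranks (rows past pre_pad_len can be
    # dropped after gathering results).
    from torchvision import transforms
    dataset = TextImagePairDataset(
        data_dir,
        transform=transforms.ToTensor(),
        pad_to_multiple_of=8,
    )
    sample = dataset[0]
    return sample["image"].shape, dataset.pre_pad_len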


def cycle(dl):
    """Yield batches from a dataloader forever, restarting at each epoch end."""
    while True:
        for data in dl:
            yield data
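

def _example_cycle(loader):
    # Illustrative sketch, not part of the original module: cycle() turns a
    # finite dataloader into an endless stream, so step-based training loops
    # can call next() without handling StopIteration at epoch boundaries.
    stream = cycle(loader)
    batch = None
    for _ in range(10):  # e.g. 10 optimizer steps, independent of len(loader)
        batch = next(stream)
    return batch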