Spaces:
Running
on
Zero
Running
on
Zero
File size: 9,398 Bytes
51ce47d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
"""This file contains the definition of data loader using webdataset.
This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
Reference:
https://github.com/mlfoundations/open_clip/blob/main/src/training/data.py
https://github.com/huggingface/open-muse/blob/main/training/data.py
"""
import math
from typing import List, Union, Text
import webdataset as wds
import torch
from torch.utils.data import default_collate
from torchvision import transforms
from torch.utils.data import Dataset
import linecache
import json
def filter_keys(key_set):
    """Builds a function that keeps only selected keys of a dictionary.

    Args:
        key_set: A container of keys to keep; membership is tested with `in`.

    Returns:
        A function that maps a dictionary to a new dictionary holding only the
        entries whose key is contained in `key_set`. The input dictionary is
        not modified.
    """

    def _keep_listed(sample):
        kept = {}
        for name, value in sample.items():
            if name in key_set:
                kept[name] = value
        return kept

    return _keep_listed
class ImageTransform:
    """Holds the torchvision transforms applied to images for training and evaluation.

    After construction the instance exposes:
        train_transform: resize -> random or center crop -> optional horizontal
            flip -> ToTensor -> Normalize.
        eval_transform: resize to `crop_size` -> center crop -> ToTensor -> Normalize.
    """

    def __init__(self,
                 resize_shorter_edge: int = 256,
                 crop_size: int = 256,
                 random_crop: bool = True,
                 random_flip: bool = True,
                 normalize_mean: Union[List[float], None] = None,
                 normalize_std: Union[List[float], None] = None):
        """Initializes the training and evaluation transforms.

        Bicubic interpolation is always used for resizing.

        Args:
            resize_shorter_edge: An integer, the shorter edge size to resize the input image to
                (training only).
            crop_size: An integer, the size to crop the input image to.
            random_crop: A boolean, whether to use random crop augmentation during training.
                A center crop is used when False.
            random_flip: A boolean, whether to use random flipping augmentation during training.
            normalize_mean: A list of float, the normalization mean used to normalize the image
                tensor. Defaults to [0., 0., 0.] when None.
            normalize_std: A list of float, the normalization std used to normalize the image
                tensor. Defaults to [1., 1., 1.] when None.
        """
        # Build the default lists per call instead of in the signature: they are handed to
        # transforms.Normalize, and a signature-level list would be one object shared by
        # every instance.
        if normalize_mean is None:
            normalize_mean = [0., 0., 0.]
        if normalize_std is None:
            normalize_std = [1., 1., 1.]

        interpolation = transforms.InterpolationMode.BICUBIC

        train_transform = [
            transforms.Resize(resize_shorter_edge, interpolation=interpolation, antialias=True),
        ]
        if random_crop:
            train_transform.append(transforms.RandomCrop(crop_size))
        else:
            train_transform.append(transforms.CenterCrop(crop_size))
        if random_flip:
            train_transform.append(transforms.RandomHorizontalFlip())
        train_transform.append(transforms.ToTensor())
        # normalize_mean = [0, 0, 0] and normalize_std = [1, 1, 1] will normalize images into [0, 1],
        # normalize_mean = [0.5, 0.5, 0.5] and normalize_std = [0.5, 0.5, 0.5] will normalize images into [-1, 1].
        train_transform.append(transforms.Normalize(normalize_mean, normalize_std))
        self.train_transform = transforms.Compose(train_transform)

        self.eval_transform = transforms.Compose(
            [
                # Note that we always resize to crop_size during eval to ensure the results
                # can be compared against reference numbers on ImageNet etc.
                transforms.Resize(crop_size, interpolation=interpolation, antialias=True),
                transforms.CenterCrop(crop_size),
                transforms.ToTensor(),
                transforms.Normalize(normalize_mean, normalize_std),
            ]
        )
        print(f"self.train_transform: {self.train_transform}")
        print(f"self.eval_transform: {self.eval_transform}")
class SimpleImageDataset:
    """Builds webdataset-based training and evaluation loaders for image/class-id samples.

    The training loader resamples shards and shuffles samples; the evaluation loader
    reads the shard list once, split across workers. Both are exposed through
    read-only properties.
    """

    def __init__(
        self,
        train_shards_path: Union[Text, List[Text]],
        eval_shards_path: Union[Text, List[Text]],
        num_train_examples: int,
        per_gpu_batch_size: int,
        global_batch_size: int,
        num_workers_per_gpu: int = 12,
        resize_shorter_edge: int = 256,
        crop_size: int = 256,
        random_crop: bool = True,
        random_flip: bool = True,
        normalize_mean: Union[List[float], None] = None,
        normalize_std: Union[List[float], None] = None,
    ):
        """Initializes the training and evaluation datasets and loaders.

        Args:
            train_shards_path: A string or list of string, path to the training data shards in webdataset format.
            eval_shards_path: A string or list of string, path to the evaluation data shards in webdataset format.
            num_train_examples: An integer, total number of training examples.
            per_gpu_batch_size: An integer, number of examples per GPU batch.
            global_batch_size: An integer, total number of examples in a batch across all GPUs.
            num_workers_per_gpu: An integer, number of workers per GPU. 0 loads data in the
                calling process.
            resize_shorter_edge: An integer, the shorter edge size to resize the input image to.
            crop_size: An integer, the size to crop the input image to.
            random_crop: A boolean, whether to use random crop augmentation during training.
            random_flip: A boolean, whether to use random flipping augmentation during training.
            normalize_mean: A list of float, the normalization mean used to normalize the image tensor.
                Defaults to [0., 0., 0.] when None.
            normalize_std: A list of float, the normalization std used to normalize the image tensor.
                Defaults to [1., 1., 1.] when None.
        """
        # Resolve the defaults here (rather than forwarding None) so ImageTransform always
        # receives concrete lists, and no list object is shared through the signature.
        if normalize_mean is None:
            normalize_mean = [0., 0., 0.]
        if normalize_std is None:
            normalize_std = [1., 1., 1.]

        transform = ImageTransform(
            resize_shorter_edge, crop_size, random_crop, random_flip,
            normalize_mean, normalize_std)

        train_processing_pipeline = self._make_processing_pipeline(transform.train_transform)
        test_processing_pipeline = self._make_processing_pipeline(transform.eval_transform)

        # With zero workers the calling process does the loading, so count it as one loader.
        # This also avoids dividing by zero below.
        num_loading_procs = max(1, num_workers_per_gpu)
        # DataLoader rejects persistent_workers=True when num_workers is 0.
        persistent_workers = num_workers_per_gpu > 0

        # Create train dataset and loader.
        pipeline = [
            wds.ResampledShards(train_shards_path),
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.shuffle(bufsize=5000,
                        initial=1000),
            *train_processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
        ]

        # Round the epoch up so every worker yields the same number of batches.
        num_worker_batches = math.ceil(num_train_examples /
                                       (global_batch_size * num_loading_procs))
        num_batches = num_worker_batches * num_loading_procs
        num_samples = num_batches * global_batch_size

        # Each worker is iterating over the complete dataset.
        self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
        self._train_dataloader = wds.WebLoader(
            self._train_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers_per_gpu,
            pin_memory=True,
            persistent_workers=persistent_workers,
        )
        # Add meta-data to dataloader instance for convenience.
        self._train_dataloader.num_batches = num_batches
        self._train_dataloader.num_samples = num_samples

        # Create eval dataset and loader.
        pipeline = [
            wds.SimpleShardList(eval_shards_path),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.ignore_and_continue),
            *test_processing_pipeline,
            wds.batched(per_gpu_batch_size, partial=True, collation_fn=default_collate),
        ]
        self._eval_dataset = wds.DataPipeline(*pipeline)
        self._eval_dataloader = wds.WebLoader(
            self._eval_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers_per_gpu,
            pin_memory=True,
            persistent_workers=persistent_workers,
        )

    @staticmethod
    def _make_processing_pipeline(image_transform):
        """Returns the decode/rename/filter/transform stages shared by train and eval.

        Args:
            image_transform: A callable applied to the decoded image of each sample.

        Returns:
            A list of webdataset pipeline stages.
        """
        return [
            wds.decode(wds.autodecode.ImageHandler("pil", extensions=["webp", "png", "jpg", "jpeg"])),
            wds.rename(
                image="jpg;png;jpeg;webp",
                class_id="cls",
                handler=wds.warn_and_continue,
            ),
            wds.map(filter_keys({"image", "class_id", "filename"})),
            wds.map_dict(
                image=image_transform,
                class_id=int,
                handler=wds.warn_and_continue,
            ),
        ]

    @property
    def train_dataset(self):
        """The training webdataset pipeline."""
        return self._train_dataset

    @property
    def train_dataloader(self):
        """The training loader; carries `num_batches` and `num_samples` attributes."""
        return self._train_dataloader

    @property
    def eval_dataset(self):
        """The evaluation webdataset pipeline."""
        return self._eval_dataset

    @property
    def eval_dataloader(self):
        """The evaluation loader."""
        return self._eval_dataloader
class PretoeknizedDataSetJSONL(Dataset):
    """Map-style dataset over a JSONL file of pre-tokenized samples.

    Every line of the file is parsed as a JSON object from which the "class_id"
    and "tokens" entries are read. Lines are fetched on demand through `linecache`.
    """

    def __init__(self, data_path):
        """Counts the lines of the JSONL file.

        Args:
            data_path: Path to the JSONL file.
        """
        super().__init__()
        self.jsonl_file = data_path
        # Count inside a `with` block so the file handle is closed right away.
        with open(self.jsonl_file) as jsonl:
            self.num_lines = sum(1 for _ in jsonl)
        # Drop a stale linecache entry for this path (if any) so that later lookups
        # read the file as it is now.
        linecache.checkcache(self.jsonl_file)
        print("Number of data:", self.num_lines)

    def __len__(self):
        """Returns the number of lines counted at construction time."""
        return self.num_lines

    def __getitem__(self, idx):
        """Returns the sample stored on line `idx` (0-based; negative counts from the end).

        Returns:
            A tuple `(class_id, tokens)` of tensors built from the line's JSON object.

        Raises:
            IndexError: If `idx` is outside `[-len(self), len(self))`.
        """
        position = idx + self.num_lines if idx < 0 else idx
        # linecache returns an empty string for a missing line, which would otherwise
        # surface as a JSON decoding error; report it as an index problem instead.
        if not 0 <= position < self.num_lines:
            raise IndexError(
                f"index {idx} is out of range for dataset with {self.num_lines} lines")
        # linecache line numbers are 1-based.
        line = linecache.getline(self.jsonl_file, position + 1).strip()
        data = json.loads(line)
        return torch.tensor(data["class_id"]), torch.tensor(data["tokens"])