# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# This file is modified from https://github.com/NVlabs/VILA/tree/main/llava/wids

import base64
import gzip
import hashlib
import io
import json
import math
import os
import os.path as osp
import random
import re
import sqlite3
import sys
import tempfile
import uuid
import warnings
from functools import lru_cache, partial
from typing import Any, BinaryIO, Dict, Optional, TypeVar, Union
from urllib.parse import quote, urlparse

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

from .wids_dl import download_and_open
from .wids_lru import LRUCache
from .wids_mmtar import MMIndexedTar
from .wids_specs import load_dsdesc_and_resolve, urldir
from .wids_tar import TarFileReader, find_index_file

try:
    from torch.utils.data import Dataset, Sampler
except ImportError:

    class Dataset:
        pass

    class Sampler:
        pass


T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)


def compute_file_md5sum(fname: Union[str, BinaryIO], chunksize: int = 1000000) -> str:
    """Compute the md5sum of a file in chunks.

    Parameters
    ----------
    fname : Union[str, BinaryIO]
        Filename or file object
    chunksize : int, optional
        Chunk size in bytes, by default 1000000

    Returns
    -------
    str
        MD5 sum of the file

    Examples
    --------
    >>> compute_file_md5sum("test.txt")
    'd41d8cd98f00b204e9800998ecf8427e'
    """
    md5 = hashlib.md5()
    if isinstance(fname, str):
        with open(fname, "rb") as f:
            for chunk in iter(lambda: f.read(chunksize), b""):
                md5.update(chunk)
    else:
        fname.seek(0)
        for chunk in iter(lambda: fname.read(chunksize), b""):
            md5.update(chunk)
    return md5.hexdigest()


def compute_num_samples(fname):
    # IndexedTarSamples takes keyword-only arguments, so pass the path explicitly.
    ds = IndexedTarSamples(path=fname)
    return len(ds)


def splitname(fname):
    """Returns the basename and extension of a filename"""
    assert "." in fname, "Filename must have an extension"
    # basename, extension = re.match(r"^((?:.*/)?.*?)(\..*)$", fname).groups()
    basename, extension = os.path.splitext(fname)
    return basename, extension
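
# Note: with os.path.splitext the extension keeps its leading dot and only the
# last suffix is split off. A minimal illustration (hypothetical filenames):
#
#     splitname("sample1.jpg")      # -> ("sample1", ".jpg")
#     splitname("sample1.json.gz")  # -> ("sample1.json", ".gz")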

# NOTE(ligeng): changed to an ordered mapping for a more flexible dict
# TODO(ligeng): submit a PR to fix the mapping issue.


def group_by_key(names):
    """Group the file names by key.

    Args:
        names: A list of file names.

    Returns:
        A list of lists of indices, where each sublist contains indices of files
        with the same key.
    """
    groups = []
    kmaps = {}
    for i, fname in enumerate(names):
        # Ignore files without an extension.
        if "." not in fname:
            print(f"Warning: Ignoring file {fname} (no '.')")
            continue
        if fname == ".":
            print("Warning: Ignoring the '.' file.")
            continue
        key, ext = splitname(fname)
        if key not in kmaps:
            kmaps[key] = []
        kmaps[key].append(i)
    for k, v in kmaps.items():
        groups.append(v)
    return groups
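
# A minimal sketch of the grouping behavior (hypothetical file list):
#
#     group_by_key(["s0.jpg", "s0.txt", "s1.jpg", "s1.txt"])
#     # -> [[0, 1], [2, 3]]   (indices grouped by shared basename)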


def default_decoder(sample: Dict[str, Any], format: Optional[Union[bool, str]] = True):
    """A default decoder for webdataset.

    This handles common file extensions: .txt, .cls, .cls2,
    .jpg, .png, .json, .npy, .mp, .pt, .pth, .pickle, .pkl.
    These are the most common extensions used in webdataset.
    For other extensions, users can provide their own decoder.

    Args:
        sample: sample, modified in place
    """
    sample = dict(sample)
    for key, stream in sample.items():
        extensions = key.split(".")
        if len(extensions) < 1:
            continue
        extension = extensions[-1]
        if extension in ["gz"]:
            decompressed = gzip.decompress(stream.read())
            stream = io.BytesIO(decompressed)
            if len(extensions) < 2:
                sample[key] = stream
                continue
            extension = extensions[-2]
        if key.startswith("__"):
            continue
        elif extension in ["txt", "text"]:
            value = stream.read()
            sample[key] = value.decode("utf-8")
        elif extension in ["cls", "cls2"]:
            value = stream.read()
            sample[key] = int(value.decode("utf-8"))
        elif extension in ["jpg", "png", "ppm", "pgm", "pbm", "pnm"]:
            if format == "PIL":
                import PIL.Image

                sample[key] = PIL.Image.open(stream)
            elif format == "numpy":
                import numpy as np
                import PIL.Image

                sample[key] = np.asarray(PIL.Image.open(stream))
            else:
                raise ValueError(f"Unknown format: {format}")
        elif extension == "json":
            import json

            value = stream.read()
            sample[key] = json.loads(value)
        elif extension == "npy":
            import numpy as np

            sample[key] = np.load(stream)
        elif extension == "mp":
            import msgpack

            value = stream.read()
            sample[key] = msgpack.unpackb(value, raw=False)
        elif extension in ["pt", "pth"]:
            import torch

            sample[key] = torch.load(stream)
        elif extension in ["pickle", "pkl"]:
            import pickle

            sample[key] = pickle.load(stream)
        elif extension == "mp4":
            # Write stream to a temporary file
            # with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmpfile:
            #     tmpfile.write(stream.read())
            #     tmpfile_path = tmpfile.name
            # sample[key] = tmpfile_path
            sample[key] = io.BytesIO(stream.read())
    return sample
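
# A minimal sketch of how the decoder is applied (hypothetical byte payloads):
#
#     raw = {
#         "__key__": "sample1",
#         ".txt": io.BytesIO(b"hello"),
#         ".json": io.BytesIO(b'{"label": 3}'),
#     }
#     decoded = default_decoder(raw, format="PIL")
#     # decoded[".txt"] == "hello", decoded[".json"] == {"label": 3}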


def update_dict_with_extend(original_dict, update_dict):
    for key, value in update_dict.items():
        if key in original_dict and isinstance(original_dict[key], list) and isinstance(value, list):
            original_dict[key].extend(value)
        else:
            original_dict[key] = value
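
# A minimal sketch of the merge behavior: list-valued keys are extended in
# place, everything else is overwritten (hypothetical dicts):
#
#     d = {"shardlist": [1, 2], "name": "a"}
#     update_dict_with_extend(d, {"shardlist": [3], "name": "b"})
#     # d == {"shardlist": [1, 2, 3], "name": "b"}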


open_itfs = {}


class IndexedTarSamples:
    """A class that accesses samples in a tar file. The tar file must follow
    WebDataset conventions. The tar file is indexed when the IndexedTarSamples
    object is created. The samples are accessed by index using the __getitem__
    method. The __getitem__ method returns a dictionary containing the files
    for the sample. The key for each file is the extension of the file name.
    The key "__key__" is reserved for the key of the sample (the basename of
    each file without the extension). For example, if the tar file contains
    the files "sample1.jpg" and "sample1.txt", then the sample with key
    "sample1" will be returned as the dictionary {".jpg": ..., ".txt": ...}.
    """

    def __init__(
        self,
        *,
        path=None,
        stream=None,
        md5sum=None,
        expected_size=None,
        use_mmap=True,
        index_file=find_index_file,
    ):
        assert path is not None or stream is not None

        # Create TarFileReader object to read from tar_file
        self.path = path
        stream = self.stream = stream or open(path, "rb")

        # Verify the MD5 sum
        if md5sum is not None:
            stream.seek(0)
            got = compute_file_md5sum(stream)
            assert got == md5sum, f"MD5 sum mismatch: expected {md5sum}, got {got}"
            stream.seek(0)

        # Use either the mmap-based or the stream-based implementation
        # NOTE(ligeng): https://stackoverflow.com/questions/11072705/twitter-trends-api-unicodedecodeerror-utf8-codec-cant-decode-byte-0x8b-in-po
        # import gzip
        # print("convert to gzip IO stream")
        # stream = gzip.GzipFile(fileobj=stream)
        if use_mmap:
            self.reader = MMIndexedTar(stream)
        else:
            self.reader = TarFileReader(stream, index_file=index_file)

        # Get the list of all files in the stream
        all_files = self.reader.names()

        # Group files by key into samples
        self.samples = group_by_key(all_files)
        # print("DEBUG:", list(all_files)[:20])
        # print("DEBUG:", self.samples[:20])

        # Check that the number of samples is correct
        if expected_size is not None:
            assert len(self) == expected_size, f"Expected {expected_size} samples, got {len(self)}"

        self.uuid = str(uuid.uuid4())

    def close(self):
        self.reader.close()
        if not self.stream.closed:
            self.stream.close()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Get indexes of files for the sample at index idx
        try:
            indexes = self.samples[idx]
        except IndexError as e:
            print(f"[wids-debug] curr idx: {idx}, total sample length: {len(self.samples)} {e}")
            raise e
        sample = {}
        key = None
        for i in indexes:
            # Get filename and data for the file at index i
            fname, data = self.reader.get_file(i)
            # Split filename into key and extension
            k, ext = splitname(fname)
            # Make sure all files in the sample have the same key
            key = key or k
            assert key == k
            sample[ext] = data
        # Add the key to the sample
        sample["__key__"] = key
        return sample

    def __str__(self):
        return f"<IndexedTarSamples-{id(self)} {self.path}>"

    def __repr__(self):
        return str(self)
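
# A minimal usage sketch (the shard path below is hypothetical):
#
#     its = IndexedTarSamples(path="/data/shard-000000.tar")
#     print(len(its))   # number of grouped samples in the shard
#     sample = its[0]   # e.g. {"__key__": "sample1", ".jpg": <raw bytes/stream>, ...}
#     its.close()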


def hash_localname(dldir="/tmp/_wids_cache"):
    os.makedirs(dldir, exist_ok=True)

    connection = sqlite3.connect(os.path.join(dldir, "cache.db"))
    cursor = connection.cursor()
    cursor.execute("CREATE TABLE IF NOT EXISTS cache (url TEXT PRIMARY KEY, path TEXT, checksum TEXT)")
    connection.commit()

    def f(shard):
        """Given a URL, return a local name for the shard."""
        if shard.startswith("pipe:"):
            # hash the entire URL string and base64-encode it
            hex32 = base64.urlsafe_b64encode(hashlib.sha256(shard.encode()).digest())[:32].decode()
            return os.path.join(dldir, "pipe__" + hex32)
        else:
            # hash the host and directory components into a 16 character string
            dirname = urldir(shard)
            hex16 = base64.urlsafe_b64encode(hashlib.sha256(dirname.encode()).digest())[:16].decode()
            # the cache name is the concatenation of the hex16 string and the file name component of the URL
            cachename = "data__" + hex16 + "__" + os.path.basename(urlparse(shard).path)
            checksum = None
            cursor.execute(
                "INSERT OR REPLACE INTO cache VALUES (?, ?, ?)",
                (shard, cachename, checksum),
            )
            connection.commit()
            return os.path.join(dldir, cachename)

    return f


def cache_localname(cachedir):
    os.makedirs(cachedir, exist_ok=True)

    def f(shard):
        """Given a URL, return a local name for the shard."""
        path = urlparse(shard).path
        fname = os.path.basename(path)
        return os.path.join(cachedir, fname)

    return f


def default_localname(dldir="/tmp/_wids_cache"):
    os.makedirs(dldir, exist_ok=True)

    def f(shard):
        """Given a URL, return a local name for the shard."""
        cachename = quote(shard, safe="+-")
        return os.path.join(dldir, cachename)

    return f
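
# A minimal sketch of the naming strategies (the URL is hypothetical): each
# factory creates its cache directory and returns a URL -> local-path function.
# cache_localname keeps only the file name; hash_localname prefixes a hash of
# the URL's directory component; default_localname percent-encodes the full URL:
#
#     name = default_localname("/tmp/_wids_cache")
#     name("https://host/data/shard-000000.tar")
#     # -> "/tmp/_wids_cache/https%3A%2F%2Fhost%2Fdata%2Fshard-000000.tar"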


class LRUShards:
    """A class that manages a cache of shards. The cache is an LRU cache that
    stores the local names of the shards as keys and the downloaded paths as
    values. The shards are downloaded to a directory specified by dldir.
    The local name of a shard is computed by the localname function, which
    takes the shard URL as an argument. If keep is True, the downloaded files
    are not deleted when they are no longer needed.
    """

    def __init__(self, lru_size, keep=False, localname=default_localname()):
        self.localname = localname
        # the cache contains the local name as the key and the downloaded path as the value
        self.lru = LRUCache(lru_size, release_handler=self.release_handler)
        # keep statistics
        self.reset_stats()

    def reset_stats(self):
        self.accesses = 0
        self.misses = 0

    def __len__(self):
        return len(self.lru)

    def release_handler(self, key, value):
        value.close()

    def clear(self):
        self.lru.clear()

    def get_shard(self, url):
        assert isinstance(url, str)
        self.accesses += 1
        if url not in self.lru:
            local = self.localname(url)
            with download_and_open(url, local) as stream:
                itf = IndexedTarSamples(path=local, stream=stream)
                self.lru[url] = itf
            self.misses += 1
            self.last_missed = True
        else:
            self.last_missed = False
        return self.lru[url]


def interpret_transformations(transformations):
    """Interpret the transformations argument.

    This takes care of transformations specified as string shortcuts
    and returns a list of callables.
    """
    if not isinstance(transformations, list):
        transformations = [transformations]

    result = []
    for transformation in transformations:
        if transformation == "PIL":
            transformation = partial(default_decoder, format="PIL")
        elif transformation == "numpy":
            transformation = partial(default_decoder, format="numpy")
        else:
            assert callable(transformation)
        result.append(transformation)

    return result


def hash_dataset_name(input_string):
    """Compute a hash of the input string and return the first 16 characters of the hash."""
    # Compute SHA256 hash of the input string
    hash_object = hashlib.sha256(input_string.encode())
    hash_digest = hash_object.digest()

    # Encode the hash in base64
    base64_encoded_hash = base64.urlsafe_b64encode(hash_digest)

    # Return the first 16 characters of the base64-encoded hash
    return base64_encoded_hash[:16].decode("ascii")


@lru_cache(maxsize=16)
def lru_json_load(fpath):
    with open(fpath) as fp:
        return json.load(fp)


class ShardListDataset(Dataset[T]):
    """An indexable dataset based on a list of shards.

    The dataset is either given as a list of shards with optional options and name,
    or as a URL pointing to a JSON descriptor file.

    Datasets can reference other datasets via `source_url`.

    Shard references within a dataset are resolved relative to an explicitly
    given `base` property, or relative to the URL from which the dataset
    descriptor was loaded.
    """

    def __init__(
        self,
        shards,
        *,
        cache_size=int(1e12),
        cache_dir=None,
        lru_size=10,
        dataset_name=None,
        localname=None,
        transformations="PIL",
        keep=False,
        base=None,
        options=None,
    ):
        """Create a ShardListDataset.

        Args:
            shards: a list of (filename, length) pairs or a URL pointing to a JSON descriptor file
            cache_size: the number of shards to keep in the cache
            lru_size: the number of shards to keep in the LRU cache
            localname: a function that maps URLs to local filenames

        Note that there are two caches: an on-disk directory, and an in-memory LRU cache.
        """
        if options is None:
            options = {}

        super().__init__()
        # shards is a list of (filename, length) pairs. We'll need to
        # keep track of the lengths and cumulative lengths to know how
        # to map indices to shards and indices within shards.
        if isinstance(shards, (str, io.IOBase)):
            if base is None and isinstance(shards, str):
                shards = osp.expanduser(shards)
                base = urldir(shards)
            self.base = base
            self.spec = load_dsdesc_and_resolve(shards, options=options, base=base)
            self.shards = self.spec.get("shardlist", [])
            self.dataset_name = self.spec.get("name") or hash_dataset_name(str(shards))
        else:
            raise NotImplementedError("Only support taking path/url to JSON descriptor file.")
            self.base = None
            self.spec = options
            self.shards = shards
            self.dataset_name = dataset_name or hash_dataset_name(str(shards))

        self.lengths = [shard["nsamples"] for shard in self.shards]
        self.cum_lengths = np.cumsum(self.lengths)
        self.total_length = self.cum_lengths[-1]

        if cache_dir is not None:
            # when a cache dir is explicitly given, we download files into
            # that directory without any changes
            self.cache_dir = cache_dir
            self.localname = cache_localname(cache_dir)
        elif localname is not None:
            # when a localname function is given, we use that
            self.cache_dir = None
            self.localname = localname
        else:
            import getpass

            # when no cache dir or localname are given, use the cache from the environment
            self.cache_dir = os.environ.get("WIDS_CACHE", "~/.cache/_wids_cache")
            self.cache_dir = osp.expanduser(self.cache_dir)
            self.localname = default_localname(self.cache_dir)

        self.data_info = (
            f"[WebShardedList] {str(shards)}, base: {self.base}, name: {self.spec.get('name')}, "
            f"nfiles: {len(self.shards)}"
        )
        if True or int(os.environ.get("WIDS_VERBOSE", 0)):
            nbytes = sum(shard.get("filesize", 0) for shard in self.shards)
            nsamples = sum(shard["nsamples"] for shard in self.shards)
            self.data_info += f", nbytes: {nbytes}, samples: {nsamples}, cache: {self.cache_dir}"
            # print(
            #     "[WebShardedList]",
            #     str(shards),
            #     "base:",
            #     self.base,
            #     "name:",
            #     self.spec.get("name"),
            #     "nfiles:",
            #     len(self.shards),
            #     "nbytes:",
            #     nbytes,
            #     "samples:",
            #     nsamples,
            #     "cache:",
            #     self.cache_dir,
            #     file=sys.stderr,
            # )

        self.transformations = interpret_transformations(transformations)

        if lru_size > 200:
            warnings.warn("LRU size is very large; consider reducing it to avoid running out of file descriptors")

        self.cache = LRUShards(lru_size, localname=self.localname, keep=keep)

    def add_transform(self, transform):
        """Add a transformation to the dataset."""
        self.transformations.append(transform)
        return self

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return self.total_length

    def get_stats(self):
        """Return the number of cache accesses and misses."""
        return self.cache.accesses, self.cache.misses

    def check_cache_misses(self):
        """Check if the cache miss rate is too high."""
        accesses, misses = self.get_stats()
        if accesses > 100 and misses / accesses > 0.3:
            # output a warning only once
            self.check_cache_misses = lambda: None
            print(f"Warning: ShardListDataset has a cache miss rate of {misses / accesses:.1%}")

    def get_shard(self, index):
        """Get the shard and index within the shard corresponding to the given index."""
        # Find the shard corresponding to the given index.
        shard_idx = np.searchsorted(self.cum_lengths, index, side="right")

        # Figure out which index within the shard corresponds to the
        # given index.
        if shard_idx == 0:
            inner_idx = index
        else:
            inner_idx = index - self.cum_lengths[shard_idx - 1]

        # Get the shard and return the corresponding element.
        desc = self.shards[shard_idx]
        url = desc["url"]
        if url.startswith(("https://", "http://", "gs://", "/", "~")):
            # absolute path or URL: use as-is
            pass
        else:
            # relative path: resolve against base_path or the dataset base
            if self.base is None and "base_path" not in self.spec:
                raise FileNotFoundError("passing a relative path in shardlist but no base found.")
            base_path = self.spec["base_path"] if "base_path" in self.spec else self.base
            url = osp.abspath(osp.join(osp.expanduser(base_path), url))
        desc["url"] = url

        try:
            shard = self.cache.get_shard(url)
        except UnicodeDecodeError as e:
            print("UnicodeDecodeError:", desc)
            raise e
        return shard, inner_idx, desc

    def __getitem__(self, index):
        """Return the sample corresponding to the given index."""
        shard, inner_idx, desc = self.get_shard(index)
        sample = shard[inner_idx]

        # Check if we're missing the cache too often.
        self.check_cache_misses()

        sample["__dataset__"] = desc.get("dataset")
        sample["__index__"] = index
        sample["__shard__"] = desc["url"]
        sample["__shardindex__"] = inner_idx

        # Apply transformations
        for transform in self.transformations:
            sample = transform(sample)

        return sample

    def close(self):
        """Close the dataset."""
        self.cache.clear()
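
# A minimal usage sketch (the descriptor path "dataset.json" is hypothetical;
# it must point to a wids JSON descriptor with a "shardlist" entry):
#
#     dataset = ShardListDataset("dataset.json", lru_size=4, transformations="PIL")
#     sample = dataset[0]
#     # sample contains "__key__", "__shard__", "__index__", "__shardindex__",
#     # plus decoded entries such as ".jpg" / ".txt", depending on the shard.
#
# For loading, the shard-aware samplers defined below (e.g. ShardListSampler or
# ChunkedSampler) can be passed to a torch DataLoader to preserve locality.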


class ShardListDatasetMulti(ShardListDataset):
    """An indexable dataset based on a list of shards.

    The dataset is either given as a list of shards with optional options and name,
    or as a URL pointing to a JSON descriptor file.

    Datasets can reference other datasets via `source_url`.

    Shard references within a dataset are resolved relative to an explicitly
    given `base` property, or relative to the URL from which the dataset
    descriptor was loaded.
    """

    def __init__(
        self,
        shards,
        *,
        cache_size=int(1e12),
        cache_dir=None,
        lru_size=10,
        dataset_name=None,
        localname=None,
        transformations="PIL",
        keep=False,
        base=None,
        options=None,
        sort_data_inseq=False,
        num_replicas=None,
    ):
        """Create a ShardListDatasetMulti.

        Args:
            shards: a list of (filename, length) pairs or a URL pointing to a JSON descriptor file
            cache_size: the number of shards to keep in the cache
            lru_size: the number of shards to keep in the LRU cache
            localname: a function that maps URLs to local filenames

        Note that there are two caches: an on-disk directory, and an in-memory LRU cache.
        """
        if options is None:
            options = {}
        # shards is a list of (filename, length) pairs. We'll need to
        # keep track of the lengths and cumulative lengths to know how
        # to map indices to shards and indices within shards.
        shards_lists = shards if isinstance(shards, list) else [shards]
        bases = base if isinstance(base, list) else [base] * len(shards_lists)
        self.spec = {}
        self.shards = []
        self.num_per_dir = {}
        for base, shards in zip(bases, shards_lists):
            if isinstance(shards, (str, io.IOBase)):
                if base is None and isinstance(shards, str):
                    shards = osp.expanduser(shards)
                    base = urldir(shards)
                self.base = base
                _spec = load_dsdesc_and_resolve(shards, options=options, base=base)
                update_dict_with_extend(self.spec, _spec)
                self.num_per_dir[os.path.basename(os.path.dirname(shards))] = sum(
                    [shard["nsamples"] for shard in _spec.get("shardlist", [])]
                )
            else:
                raise NotImplementedError("Only support taking path/url to JSON descriptor file.")
                self.base = None
                self.spec = options
                self.shards = shards
                self.dataset_name = dataset_name or hash_dataset_name(str(shards))

        if sort_data_inseq and len(self.spec.get("shardlist", [])) > 0:
            num_replicas = num_replicas or dist.get_world_size()
            self.spec["shardlist"] = split_and_recombine(self.spec["shardlist"], num_replicas)
        self.shards.extend(self.spec.get("shardlist", []))

        self.dataset_name = self.spec.get("name") or hash_dataset_name(str(shards))
        self.lengths = [shard["nsamples"] for shard in self.shards]
        self.cum_lengths = np.cumsum(self.lengths)
        self.total_length = self.cum_lengths[-1]

        if cache_dir is not None:
            # when a cache dir is explicitly given, we download files into
            # that directory without any changes
            self.cache_dir = cache_dir
            self.localname = cache_localname(cache_dir)
        elif localname is not None:
            # when a localname function is given, we use that
            self.cache_dir = None
            self.localname = localname
        else:
            import getpass

            # when no cache dir or localname are given, use the cache from the environment
            self.cache_dir = os.environ.get("WIDS_CACHE", "~/.cache/_wids_cache")
            self.cache_dir = osp.expanduser(self.cache_dir)
            self.localname = default_localname(self.cache_dir)

        self.data_info = (
            f"[WebShardedList] {str(shards)}, base: {self.base}, name: {self.spec.get('name')}, "
            f"nfiles: {len(self.shards)}"
        )
        if True or int(os.environ.get("WIDS_VERBOSE", 0)):
            nbytes = sum(shard.get("filesize", 0) for shard in self.shards)
            nsamples = sum(shard["nsamples"] for shard in self.shards)
            self.data_info += f", nbytes: {nbytes}, samples: {nsamples}, cache: {self.cache_dir}"

        self.transformations = interpret_transformations(transformations)

        if lru_size > 200:
            warnings.warn("LRU size is very large; consider reducing it to avoid running out of file descriptors")

        self.cache = LRUShards(lru_size, localname=self.localname, keep=keep)


def split_and_recombine(lst, n):
    from collections import OrderedDict

    def extract_prefix(i):
        return i["url"].split("/")[-2]

    unique_parts = list(OrderedDict((extract_prefix(item), None) for item in lst).keys())
    split_dict = {part: [] for part in unique_parts}

    for part in unique_parts:
        part_list = [item for item in lst if extract_prefix(item) == part]
        chunk_size = max(1, len(part_list) // n)  # make sure chunk_size is at least 1
        chunks = [part_list[i * chunk_size : (i + 1) * chunk_size] for i in range(n)]
        # Handle the last chunk: if the split is uneven, append the leftover elements to it.
        if len(part_list) % n != 0:
            chunks[-1].extend(part_list[n * chunk_size :])
        split_dict[part] = chunks

    recombined_list = []
    for i in range(n):
        for part in unique_parts:
            recombined_list.extend(split_dict[part][i])

    return recombined_list
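
# A minimal sketch of the recombination (hypothetical shard descriptors): each
# source directory's shards are split into n roughly equal chunks, and chunk i
# of every directory lands in the slice consumed by rank i.
#
#     shardlist = [
#         {"url": "/data/dirA/s0.tar"}, {"url": "/data/dirA/s1.tar"},
#         {"url": "/data/dirB/s0.tar"}, {"url": "/data/dirB/s1.tar"},
#     ]
#     split_and_recombine(shardlist, 2)
#     # -> [dirA/s0, dirB/s0, dirA/s1, dirB/s1]  (one chunk per directory per rank)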


def lengths_to_ranges(lengths):
    """Convert a list of lengths to a list of ranges."""
    ranges = []
    start = 0
    for length in lengths:
        ranges.append((start, start + length))
        start += length
    return ranges
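
# For example, lengths_to_ranges([3, 2, 4]) -> [(0, 3), (3, 5), (5, 9)]:
# each (start, end) pair is the half-open index range covered by one shard.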


def intersect_range(a, b):
    """Return the intersection of the two half-open integer intervals."""
    result = max(a[0], b[0]), min(a[1], b[1])
    if result[0] >= result[1]:
        return None
    return result


def intersect_ranges(rangelist, r):
    """Return the intersection of the half-open integer interval r with the list of half-open integer intervals."""
    result = []
    for a in rangelist:
        x = intersect_range(a, r)
        if x is not None:
            result.append(x)
    return result


def iterate_ranges(ranges, rng, indexshuffle=True, shardshuffle=True):
    """Iterate over the ranges in a random order."""
    shard_indexes = list(range(len(ranges)))
    if shardshuffle:
        rng.shuffle(shard_indexes)
    for i in shard_indexes:
        lo, hi = ranges[i]
        sample_indexes = list(range(lo, hi))
        if indexshuffle:
            rng.shuffle(sample_indexes)
        yield from sample_indexes
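
# A minimal sketch of the traversal order (hypothetical ranges):
#
#     rng = random.Random(0)
#     list(iterate_ranges([(0, 3), (3, 5)], rng, indexshuffle=False, shardshuffle=False))
#     # -> [0, 1, 2, 3, 4]; with shuffling enabled, shards are visited in random
#     # order and indices are permuted within each shard.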


class ShardListSampler(Sampler):
    """A sampler that samples consistently with a ShardListDataset.

    This sampler is used to sample from a ShardListDataset in a way that
    preserves locality.

    This returns a permutation of the indexes by shard, then a permutation of
    indexes within each shard. This ensures that the data is accessed in a
    way that preserves locality.

    Note that how this ends up splitting data between multiple workers depends
    on the details of the DataLoader. Generally, each worker will end up loading
    samples from the same shard.

    Other more sophisticated shard-aware samplers are possible and will likely
    be added.
    """

    def __init__(self, dataset, *, lengths=None, seed=0, shufflefirst=False):
        if lengths is None:
            lengths = list(dataset.lengths)
        self.ranges = lengths_to_ranges(lengths)
        self.seed = seed
        self.shufflefirst = shufflefirst
        self.epoch = 0

    def __iter__(self):
        self.rng = random.Random(self.seed + 1289738273 * self.epoch)
        shardshuffle = self.shufflefirst or self.epoch > 0
        yield from iterate_ranges(self.ranges, self.rng, shardshuffle=shardshuffle)
        self.epoch += 1


ShardedSampler = ShardListSampler


class ChunkedSampler(Sampler):
    """A sampler that samples in chunks and then shuffles the samples within each chunk.

    This preserves locality of reference while still shuffling the data.
    """

    def __init__(
        self,
        dataset,
        *,
        num_samples=None,
        chunksize=2000,
        seed=0,
        shuffle=False,
        shufflefirst=False,
    ):
        if isinstance(num_samples, int):
            lo, hi = 0, num_samples
        elif num_samples is None:
            lo, hi = 0, len(dataset)
        else:
            lo, hi = num_samples
        self.ranges = [(i, min(i + chunksize, hi)) for i in range(lo, hi, chunksize)]
        self.seed = seed
        self.shuffle = shuffle
        self.shufflefirst = shufflefirst
        self.epoch = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        self.rng = random.Random(self.seed + 1289738273 * self.epoch)
        shardshuffle = self.shufflefirst or self.epoch > 0
        yield from iterate_ranges(
            self.ranges,
            self.rng,
            indexshuffle=self.shuffle,
            shardshuffle=(self.shuffle and shardshuffle),
        )
        self.epoch += 1

    def __len__(self):
        return len(self.ranges)
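
# A minimal usage sketch (assumes `dataset` is a ShardListDataset): indices are
# shuffled within ~2000-sample chunks so consecutive reads mostly stay within
# the shards currently held in the LRU cache. Decoded samples (e.g. PIL images)
# may need a custom collate_fn for batching.
#
#     sampler = ChunkedSampler(dataset, chunksize=2000, shuffle=True, seed=0)
#     loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=8)
#     sampler.set_epoch(0)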


def DistributedChunkedSampler(
    dataset: Dataset,
    *,
    num_replicas: Optional[int] = None,
    num_samples: Optional[int] = None,
    rank: Optional[int] = None,
    shuffle: bool = True,
    shufflefirst: bool = False,
    seed: int = 0,
    drop_last: Optional[bool] = None,
    chunksize: int = 1000000,
) -> ChunkedSampler:
    """Return a ChunkedSampler for the current worker in distributed training.

    Reverts to a simple ChunkedSampler if not running in distributed mode.

    Since the split among workers takes place before the chunk shuffle,
    workers end up with a fixed set of shards they need to download. The
    more workers, the fewer shards are used by each worker.
    """
    if drop_last is not None:
        warnings.warn("DistributedChunkedSampler does not support drop_last, thus it will be ignored")
    if not dist.is_initialized():
        warnings.warn("DistributedChunkedSampler is called without distributed initialized; assuming single process")
        num_replicas = 1
        rank = 0
    else:
        num_replicas = num_replicas or dist.get_world_size()
        rank = rank or dist.get_rank()
    assert rank >= 0 and rank < num_replicas

    num_samples = num_samples or len(dataset)
    worker_chunk = (num_samples + num_replicas - 1) // num_replicas
    worker_start = rank * worker_chunk
    worker_end = min(worker_start + worker_chunk, num_samples)
    return ChunkedSampler(
        dataset,
        num_samples=(worker_start, worker_end),
        chunksize=chunksize,
        seed=seed,
        shuffle=shuffle,
        shufflefirst=shufflefirst,
    )
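
# A minimal usage sketch (assumes torch.distributed has been initialized and
# `dataset` is a ShardListDataset): each rank receives a contiguous slice of
# the index space, then shuffles it chunk by chunk.
#
#     sampler = DistributedChunkedSampler(dataset, chunksize=1000000, shuffle=True)
#     loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=8)
#     for epoch in range(num_epochs):    # num_epochs is hypothetical
#         sampler.set_epoch(epoch)
#         for batch in loader:
#             ...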


class DistributedRangedSampler(Sampler):
    """A sampler that hands each rank a contiguous, unshuffled slice of the dataset.

    Unlike ChunkedSampler, this sampler does not shuffle; it simply yields the
    index range assigned to the current rank, optionally resuming from an offset
    set via set_start().
    """

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        num_samples: Optional[int] = None,
        rank: Optional[int] = None,
        drop_last: Optional[bool] = None,
    ):
        if drop_last is not None:
            warnings.warn("DistributedRangedSampler does not support drop_last, thus it will be ignored")
        if not dist.is_initialized():
            warnings.warn(
                "DistributedRangedSampler is called without distributed initialized; assuming single process"
            )
            num_replicas = 1
            rank = 0
        else:
            num_replicas = num_replicas or dist.get_world_size()
            rank = rank or dist.get_rank()
        assert rank >= 0 and rank < num_replicas

        num_samples = num_samples or len(dataset)
        self.worker_chunk = num_samples // num_replicas
        self.worker_start = rank * self.worker_chunk
        self.worker_end = min((rank + 1) * self.worker_chunk, num_samples)
        self.ranges = range(self.worker_start, self.worker_end)
        self.epoch = 0
        self.step_start = 0

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __len__(self):
        return len(self.ranges)

    def set_start(self, start):
        self.step_start = start

    def __iter__(self):
        yield from self.ranges[self.step_start :]
        self.epoch += 1


class DistributedLocalSampler(DistributedSampler):
    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample a contiguous chunk per rank instead of the strided
        # indices[self.rank : self.total_size : self.num_replicas] split
        chunk_size = self.total_size // self.num_replicas
        begin_idx = chunk_size * self.rank
        stop_idx = chunk_size * (self.rank + 1)
        indices = indices[begin_idx:stop_idx]

        # print("[SamplerIndices: ]", indices)
        assert len(indices) == self.num_samples
        return iter(indices)
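
# A minimal usage sketch (assumes torch.distributed is initialized): this is a
# drop-in replacement for torch's DistributedSampler that gives each rank a
# contiguous block of indices rather than a strided (round-robin) split, which
# keeps each rank's reads concentrated in fewer shards.
#
#     sampler = DistributedLocalSampler(dataset, shuffle=True, seed=0)
#     loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=8)
#     sampler.set_epoch(epoch)    # call once per epoch; `epoch` is hypothetical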