# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import re

import boto3
import webdataset as wds
from boto3.s3.transfer import TransferConfig
from webdataset.handlers import reraise_exception
def setup_s3_args(args):
    """Default ``args.s3_data_endpoint`` to ``args.s3_endpoint`` when unset.

    Mutates *args* in place and returns ``None``.
    """
    if args.s3_data_endpoint:
        return
    # No dedicated data endpoint was given: reuse the general S3 endpoint.
    args.s3_data_endpoint = args.s3_endpoint
def save_on_s3(filename, s3_path, s3_endpoint):
    """Upload the local file *filename* to *s3_path* (``s3://bucket/key``).

    Credentials are read from the ``AWS_ACCESS_KEY_ID`` and
    ``AWS_SECRET_ACCESS_KEY`` environment variables.
    """
    # re.split on "s3://(.*?)/(.*)$" yields ['', bucket, key, ''] for a
    # well-formed path; unpacking raises ValueError otherwise.
    _, bucket_name, object_key, _ = re.split("s3://(.*?)/(.*)$", s3_path)
    client = boto3.client(
        service_name='s3',
        aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
        aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
        endpoint_url=s3_endpoint,
    )
    client.upload_file(filename, bucket_name, object_key)
def download_from_s3(s3_path, s3_endpoint, filename, multipart_threshold_mb=512, multipart_chunksize_mb=512, max_io_queue=1000):
    """Download an S3 object to a local file using multipart transfers.

    Credentials are read from the ``AWS_ACCESS_KEY_ID`` and
    ``AWS_SECRET_ACCESS_KEY`` environment variables.

    Args:
        s3_path: Source object, formatted as ``s3://bucket/key``.
        s3_endpoint: Endpoint URL of the S3-compatible service.
        filename: Local destination path.
        multipart_threshold_mb: Size (MiB) above which multipart download kicks in.
        multipart_chunksize_mb: Size (MiB) of each multipart chunk.
        max_io_queue: Maximum number of queued I/O chunks. Previously
            hard-coded to 1000; parameterized (default unchanged) to match
            the sibling helper that already takes ``s3_max_io_queue``.

    Raises:
        KeyError: If the AWS credential environment variables are unset.
        ValueError: If ``s3_path`` does not match ``s3://bucket/key``.
    """
    MB = 1024 ** 2
    transfer_config = TransferConfig(
        multipart_threshold=multipart_threshold_mb * MB,
        multipart_chunksize=multipart_chunksize_mb * MB,
        max_io_queue=max_io_queue)
    s3_client = boto3.client(
        service_name='s3',
        aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
        aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
        endpoint_url=s3_endpoint,
    )
    # ['', bucket, key, ''] for a well-formed s3:// path.
    _, bucket, key, _ = re.split("s3://(.*?)/(.*)$", s3_path)
    s3_client.download_file(bucket, key, filename, Config=transfer_config)
def override_wds_s3_tar_loading(s3_data_endpoint, s3_multipart_threshold_mb, s3_multipart_chunksize_mb, s3_max_io_queue):
    """Monkey-patch webdataset so ``s3://`` tar shards are fetched via boto3.

    Replaces ``wds.tariterators.url_opener`` (module-level side effect) with a
    variant that downloads ``s3://`` URLs fully into memory through a boto3
    client configured here; every other URL scheme falls back to webdataset's
    own ``gopen``.

    Args:
        s3_data_endpoint: Endpoint URL of the S3-compatible service.
        s3_multipart_threshold_mb: Size (MiB) above which multipart download kicks in.
        s3_multipart_chunksize_mb: Size (MiB) of each multipart chunk.
        s3_max_io_queue: Maximum number of queued I/O chunks.
    """
    # When loading from S3 using boto3, hijack webdatasets tar loading
    MB = 1024 ** 2
    # Transfer settings shared by every shard download performed below.
    transfer_config = TransferConfig(
        multipart_threshold=s3_multipart_threshold_mb * MB,
        multipart_chunksize=s3_multipart_chunksize_mb * MB,
        max_io_queue=s3_max_io_queue)
    # One client instance, captured by the nested closures below.
    s3_client = boto3.client(
        service_name='s3',
        aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
        aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
        endpoint_url=s3_data_endpoint,
    )

    def get_bytes_io(path):
        # Download the whole object at `path` (s3://bucket/key) into memory
        # and return it as a rewound file-like object.
        byte_io = io.BytesIO()
        _, bucket, key, _ = re.split("s3://(.*?)/(.*)$", path)
        s3_client.download_fileobj(bucket, key, byte_io, Config=transfer_config)
        byte_io.seek(0)
        return byte_io

    def gopen_with_s3(url, mode="rb", bufsize=8192, **kw):
        """gopen from webdataset, but with s3 support"""
        if url.startswith("s3://"):
            # NOTE(review): mode/bufsize/kw are ignored for s3 URLs; the whole
            # object is buffered in memory regardless.
            return get_bytes_io(url)
        else:
            return wds.gopen.gopen(url, mode, bufsize, **kw)

    def url_opener(data, handler=reraise_exception, **kw):
        # Drop-in replacement for wds.tariterators.url_opener: attach an open
        # stream to each incoming sample dict and yield it.
        for sample in data:
            url = sample["url"]
            try:
                stream = gopen_with_s3(url, **kw)
                sample.update(stream=stream)
                yield sample
            except Exception as exn:
                # Append the offending URL to the exception args for context,
                # then let the handler decide: truthy -> skip this sample,
                # falsy -> stop iterating (default handler re-raises).
                exn.args = exn.args + (url,)
                if handler(exn):
                    continue
                else:
                    break

    wds.tariterators.url_opener = url_opener