aroraaman's picture
Add all of `fourm`
3424266
# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import re
import boto3
import webdataset as wds
from boto3.s3.transfer import TransferConfig
from webdataset.handlers import reraise_exception
def setup_s3_args(args):
    """Default the data endpoint to the general S3 endpoint.

    Modifies ``args`` in place. When ``args.s3_data_endpoint`` is falsy
    (e.g. ``None`` or ``""``), it is overwritten with ``args.s3_endpoint``.
    A truthy value is left untouched.
    """
    if args.s3_data_endpoint:
        # An explicit data endpoint was given; keep it.
        return
    args.s3_data_endpoint = args.s3_endpoint
def save_on_s3(filename, s3_path, s3_endpoint):
    """Upload a local file to S3.

    Credentials are read from the ``AWS_ACCESS_KEY_ID`` and
    ``AWS_SECRET_ACCESS_KEY`` environment variables.

    Args:
        filename: Path of the local file to upload.
        s3_path: Destination URI of the form ``s3://bucket/key``.
        s3_endpoint: Endpoint URL passed to the boto3 S3 client.

    Raises:
        ValueError: If ``s3_path`` is not of the form ``s3://bucket/key``.
        KeyError: If either AWS credential environment variable is unset.
    """
    # Parse the destination before building a client. A malformed path then
    # fails fast with a message that names it. The whole string must match,
    # so a stray prefix such as "x/s3://b/k" is rejected.
    match = re.fullmatch(r"s3://([^/]+)/(.+)", s3_path, flags=re.DOTALL)
    if match is None:
        raise ValueError(f"Expected an S3 URI of the form 's3://bucket/key', got {s3_path!r}")
    bucket, key = match.groups()

    s3_client = boto3.client(
        service_name='s3',
        aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
        aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
        endpoint_url=s3_endpoint,
    )
    s3_client.upload_file(filename, bucket, key)
def download_from_s3(s3_path, s3_endpoint, filename, multipart_threshold_mb=512,
                     multipart_chunksize_mb=512, max_io_queue=1000):
    """Download a single S3 object to a local file.

    Credentials are read from the ``AWS_ACCESS_KEY_ID`` and
    ``AWS_SECRET_ACCESS_KEY`` environment variables.

    Args:
        s3_path: Source URI of the form ``s3://bucket/key``.
        s3_endpoint: Endpoint URL passed to the boto3 S3 client.
        filename: Local destination path.
        multipart_threshold_mb: ``TransferConfig.multipart_threshold``, in MiB.
        multipart_chunksize_mb: ``TransferConfig.multipart_chunksize``, in MiB.
        max_io_queue: Forwarded unchanged to ``TransferConfig(max_io_queue=...)``.
            The default of 1000 matches the value previously hard-coded here.

    Raises:
        ValueError: If ``s3_path`` is not of the form ``s3://bucket/key``.
        KeyError: If either AWS credential environment variable is unset.
    """
    # Validate the source URI first, so a bad path fails before any transfer setup.
    match = re.fullmatch(r"s3://([^/]+)/(.+)", s3_path, flags=re.DOTALL)
    if match is None:
        raise ValueError(f"Expected an S3 URI of the form 's3://bucket/key', got {s3_path!r}")
    bucket, key = match.groups()

    MB = 1024 ** 2
    transfer_config = TransferConfig(
        multipart_threshold=multipart_threshold_mb * MB,
        multipart_chunksize=multipart_chunksize_mb * MB,
        max_io_queue=max_io_queue,
    )
    s3_client = boto3.client(
        service_name='s3',
        aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
        aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
        endpoint_url=s3_endpoint,
    )
    s3_client.download_file(bucket, key, filename, Config=transfer_config)
def override_wds_s3_tar_loading(s3_data_endpoint, s3_multipart_threshold_mb, s3_multipart_chunksize_mb, s3_max_io_queue):
    """Make webdataset's tar loading read ``s3://`` URLs through boto3.

    Builds one boto3 S3 client, with credentials from the ``AWS_ACCESS_KEY_ID`` /
    ``AWS_SECRET_ACCESS_KEY`` environment variables. It then replaces
    ``wds.tariterators.url_opener`` with a version that behaves as follows:

    - each ``s3://`` URL is downloaded fully into an in-memory ``io.BytesIO``;
    - any other URL is still opened with ``wds.gopen.gopen``.

    Args:
        s3_data_endpoint: Endpoint URL for the boto3 S3 client.
        s3_multipart_threshold_mb: ``TransferConfig.multipart_threshold``, in MiB.
        s3_multipart_chunksize_mb: ``TransferConfig.multipart_chunksize``, in MiB.
        s3_max_io_queue: Forwarded unchanged to ``TransferConfig(max_io_queue=...)``.

    Raises:
        KeyError: If either AWS credential environment variable is unset.

    NOTE(review): this is a process-wide monkeypatch of webdataset. Presumably it
    only affects code that looks up ``url_opener`` through ``wds.tariterators``
    at call time (TODO confirm against the pinned webdataset version).
    NOTE(review): the client is created here and captured by the closures. boto3
    documents clients as not shareable across processes, so confirm this is
    called in each data-loading worker, or that workers do not fork after it.
    """
    # When loading from S3 using boto3, hijack webdataset's tar loading.
    MB = 1024 ** 2
    transfer_config = TransferConfig(
        multipart_threshold=s3_multipart_threshold_mb * MB,
        multipart_chunksize=s3_multipart_chunksize_mb * MB,
        max_io_queue=s3_max_io_queue)
    s3_client = boto3.client(
        service_name='s3',
        aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
        aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
        endpoint_url=s3_data_endpoint,
    )
    # Compiled once; used for every shard URL.
    s3_uri = re.compile(r"s3://([^/]+)/(.+)", flags=re.DOTALL)

    def get_bytes_io(path):
        """Download the object at ``path`` (``s3://bucket/key``) into a rewound BytesIO.

        Raises ValueError with a readable message for a malformed URI, e.g. one
        without a key.
        """
        match = s3_uri.fullmatch(path)
        if match is None:
            raise ValueError(f"Expected an S3 URI of the form 's3://bucket/key', got {path!r}")
        bucket, key = match.groups()
        byte_io = io.BytesIO()
        s3_client.download_fileobj(bucket, key, byte_io, Config=transfer_config)
        byte_io.seek(0)  # rewind so the consumer reads from the first byte
        return byte_io

    def gopen_with_s3(url, mode="rb", bufsize=8192, **kw):
        """gopen from webdataset, but with s3 support.

        For ``s3://`` URLs, ``mode``, ``bufsize`` and ``kw`` are ignored. The
        result is always a readable in-memory buffer.
        """
        if url.startswith("s3://"):
            return get_bytes_io(url)
        return wds.gopen.gopen(url, mode, bufsize, **kw)

    def url_opener(data, handler=reraise_exception, **kw):
        """Yield each sample dict from ``data`` with an opened ``stream`` added.

        On failure, the URL is appended to the exception's ``args`` and the
        exception is passed to ``handler``. A truthy return skips the sample;
        a falsy one stops iteration.
        """
        for sample in data:
            url = sample["url"]
            try:
                stream = gopen_with_s3(url, **kw)
                sample.update(stream=stream)
                yield sample
            except Exception as exn:
                exn.args = exn.args + (url,)
                if handler(exn):
                    continue
                else:
                    break

    wds.tariterators.url_opener = url_opener