# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import re
import boto3
import webdataset as wds

from boto3.s3.transfer import TransferConfig
from webdataset.handlers import reraise_exception


def setup_s3_args(args):
    """Fall back to the general S3 endpoint when no data-specific endpoint is set."""
    if not args.s3_data_endpoint:
        args.s3_data_endpoint = args.s3_endpoint
    

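def save_on_s3(filename, s3_path, s3_endpoint):
    """Upload the local file `filename` to `s3_path` (s3://<bucket>/<key>)."""
    # Credentials are read from the environment; a missing variable raises KeyError.
    s3_client = boto3.client(
        service_name='s3',
        aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
        aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
        endpoint_url=s3_endpoint,
    )

    # re.split yields ['', bucket, key, ''] for a well-formed s3:// path.
    _, bucket, key, _ = re.split(r"s3://(.*?)/(.*)$", s3_path)

    s3_client.upload_file(filename, bucket, key)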


def download_from_s3(s3_path, s3_endpoint, filename, multipart_threshold_mb=512, multipart_chunksize_mb=512):
    """Download `s3_path` (s3://<bucket>/<key>) to the local file `filename`."""
    # Sizes are given in MB; TransferConfig expects bytes.
    MB = 1024 ** 2
    transfer_config = TransferConfig(
        multipart_threshold=multipart_threshold_mb * MB,
        multipart_chunksize=multipart_chunksize_mb * MB,
        max_io_queue=1000,
    )

    # Credentials are read from the environment; a missing variable raises KeyError.
    s3_client = boto3.client(
        service_name='s3',
        aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
        aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
        endpoint_url=s3_endpoint,
    )

    # re.split yields ['', bucket, key, ''] for a well-formed s3:// path.
    _, bucket, key, _ = re.split(r"s3://(.*?)/(.*)$", s3_path)

    s3_client.download_file(bucket, key, filename, Config=transfer_config)


def override_wds_s3_tar_loading(s3_data_endpoint, s3_multipart_threshold_mb, s3_multipart_chunksize_mb, s3_max_io_queue):
    """Monkey-patch webdataset's tar loading so that s3:// shard URLs are fetched with boto3."""
    # Sizes are given in MB; TransferConfig expects bytes.
    MB = 1024 ** 2
    transfer_config = TransferConfig(
        multipart_threshold=s3_multipart_threshold_mb * MB,
        multipart_chunksize=s3_multipart_chunksize_mb * MB,
        max_io_queue=s3_max_io_queue,
    )

    s3_client = boto3.client(
        service_name='s3',
        aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
        aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
        endpoint_url=s3_data_endpoint,
    )

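    def get_bytes_io(path):
        """Download the whole S3 object at `path` into an in-memory buffer."""
        byte_io = io.BytesIO()
        _, bucket, key, _ = re.split(r"s3://(.*?)/(.*)$", path)
        s3_client.download_fileobj(bucket, key, byte_io, Config=transfer_config)
        # Rewind so the tar reader starts at the beginning of the buffer.
        byte_io.seek(0)
        return byte_io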
    
    def gopen_with_s3(url, mode="rb", bufsize=8192, **kw):
        """gopen from webdataset, but with s3 support"""
        if url.startswith("s3://"):
            return get_bytes_io(url)
        else:
            return wds.gopen.gopen(url, mode, bufsize, **kw)

    def url_opener(data, handler=reraise_exception, **kw):
        """Open each sample's URL and yield the sample with a `stream` entry added."""
        for sample in data:
            url = sample["url"]
            try:
                stream = gopen_with_s3(url, **kw)
                sample.update(stream=stream)
                yield sample
            except Exception as exn:
                # Attach the URL to the exception so the handler can report it.
                exn.args = exn.args + (url,)
                if handler(exn):
                    continue
                else:
                    break

    # Replace webdataset's default url_opener with the S3-aware version defined above.
    wds.tariterators.url_opener = url_opener
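

# Illustrative sketch, not used by the functions above: how an
# "s3://<bucket>/<key>" path breaks down into its bucket and key.
# `split_s3_path` is a hypothetical helper name introduced only for this
# example. It differs from the inline re.split calls in one respect, which
# is an assumption about desired behavior: a malformed path raises a
# ValueError that names the path, instead of an unpacking error.
import re  # already imported at the top of this module; repeated so the sketch stands alone


def split_s3_path(s3_path):
    """Return (bucket, key) for a path of the form s3://<bucket>/<key>."""
    match = re.fullmatch(r"s3://(.*?)/(.*)", s3_path)
    if match is None:
        raise ValueError(f"Not a valid S3 path: {s3_path!r}")
    return match.group(1), match.group(2)


# Usage: split_s3_path("s3://my-bucket/shards/00000.tar") returns
# ("my-bucket", "shards/00000.tar"); the bucket ends at the first "/"
# after the scheme, and everything after it is the key.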