#!/usr/bin/python3.6

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# simple command-line wrapper around the chunked_dataset_iterator
# Example:
#   block_randomize my_chunked_data_folder/
#   block_randomize --azure-storage-key $MY_KEY https://myaccount.blob.core.windows.net/mycontainer/my_chunked_data_folder

import os, sys, inspect

# prepend the directory above this script to sys.path so that the infinibatch
# package can be imported when running from the source tree
sys.path.insert(
    0,
    os.path.dirname(
        os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
    ),
)

from infinibatch.datasets import chunked_dataset_iterator

from typing import Union, Iterator, Optional, Dict, Tuple
import re
import gzip


# helper functions to abstract access to Azure blobs
# @TODO: These will be abstracted into a helper library in a future version.
def _try_parse_azure_blob_uri(path: str) -> Optional[Tuple[str, str, str]]:
    # note: dots are escaped so they match literally, and a failed match returns
    # None explicitly instead of being swallowed by a bare except
    m = re.match(r"https://([a-z0-9]*)\.blob\.core\.windows\.net/([^/]*)/(.*)", path)
    if m is None:
        return None
    return (m.group(1), m.group(2), m.group(3))
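
# Example (hypothetical URI) of what the parser above yields; the triple is
# (storage account, container, blob path within the container):
#   _try_parse_azure_blob_uri("https://myaccount.blob.core.windows.net/mycontainer/data/chunk0.gz")
#   returns ("myaccount", "mycontainer", "data/chunk0.gz")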


def _get_azure_key(
    storage_account: str, credentials: Optional[Union[str, Dict[str, str]]]
) -> Optional[str]:
    # credentials may be None, a single key that applies to all accounts, or a
    # dict that maps each storage-account name to its key
    if not credentials:
        return None
    elif isinstance(credentials, str):
        return credentials
    else:
        return credentials[storage_account]
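
# Example (hypothetical keys) of how credentials are resolved: a plain string is
# used for every account, while a dict selects the key by account name:
#   _get_azure_key("myaccount", "sv=abc")                 returns "sv=abc"
#   _get_azure_key("myaccount", {"myaccount": "sv=abc"})  returns "sv=abc"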


def read_utf8_file(
    path: str, credentials: Optional[Union[str, Dict[str, str]]]
) -> Iterator[str]:
    blob_data = _try_parse_azure_blob_uri(path)
    if blob_data is None:
        with open(path, "rb") as f:
            data = f.read()
    else:
        try:
            # pip install azure-storage-blob
            from azure.storage.blob import BlobClient
        except ImportError:
            print(
                "Failed to import azure.storage.blob. Please pip install azure-storage-blob",
                file=sys.stderr,
            )
            raise
        data = (
            BlobClient.from_blob_url(
                path,
                credential=_get_azure_key(
                    storage_account=blob_data[0], credentials=credentials
                ),
            )
            .download_blob()
            .readall()
        )
    if path.endswith(".gz"):
        data = gzip.decompress(data)
    # @TODO: auto-detect UCS-2 by BOM
    return iter(data.decode(encoding="utf-8").splitlines())
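
# Example (hypothetical path): stream the lines of a local gzipped chunk file;
# pass credentials=None for local files:
#   for line in read_utf8_file("my_chunked_data_folder/chunk0.gz", credentials=None):
#       ...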


def enumerate_files(
    dir: str, ext: str, credentials: Optional[Union[str, Dict[str, str]]]
):
    blob_data = _try_parse_azure_blob_uri(dir)
    if blob_data is None:
        return [
            os.path.join(dir, path.name)
            for path in os.scandir(dir)
            if path.is_file() and (ext is None or path.name.endswith(ext))
        ]
    else:
        try:
            # pip install azure-storage-blob
            from azure.storage.blob import ContainerClient
        except ImportError:
            print(
                "Failed to import azure.storage.blob. Please pip install azure-storage-blob",
                file=sys.stderr,
            )
            raise
        account, container, blob_path = blob_data

        print("enumerate_files: enumerating blobs in", dir, file=sys.stderr, flush=True)
        # @BUGBUG: The prefix does not seem to have to match from the start of the name; it seems it can also be a substring.
        container_uri = "https://" + account + ".blob.core.windows.net/" + container
        container_client = ContainerClient.from_container_url(
            container_uri, credential=_get_azure_key(account, credentials)
        )
        if not blob_path.endswith("/"):
            blob_path += "/"
        blob_uris = [
            container_uri + "/" + blob["name"]
            for blob in container_client.walk_blobs(blob_path, delimiter="")
            if (ext is None or blob["name"].endswith(ext))
        ]
        print(
            "enumerate_files:",
            len(blob_uris),
            "blobs found",
            file=sys.stderr,
            flush=True,
        )
        for blob_name in blob_uris[:10]:
            print(blob_name, file=sys.stderr, flush=True)
        return blob_uris
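
# Example (hypothetical folder): list all .gz chunk files under a local folder:
#   chunk_paths = enumerate_files("my_chunked_data_folder", ".gz", credentials=None)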


if len(sys.argv) < 2:
    sys.exit("usage: block_randomize [--azure-storage-key KEY] PATH [PATH...]")

if sys.argv[1] == "--azure-storage-key":
    credential = sys.argv[2]
    paths = sys.argv[3:]
else:
    credential = None
    paths = sys.argv[1:]

chunk_file_paths = [  # enumerate all .gz files in the given paths
    subpath for path in paths for subpath in enumerate_files(path, ".gz", credential)
]
chunk_file_paths.sort()  # make sure file order is always the same, independent of OS
print(
    "block_randomize: reading from",
    len(chunk_file_paths),
    "chunk files",
    file=sys.stderr,
)

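# Create the randomizing iterator over all chunk files. As understood here,
# buffer_size is the number of items held in the shuffling buffer and the fixed
# seed makes the randomized order reproducible across runs.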
ds = chunked_dataset_iterator(
    chunk_refs=chunk_file_paths,
    read_chunk_fn=lambda path: read_utf8_file(path, credential),
    shuffle=True,
    buffer_size=1000000,
    seed=1,
    use_windowed=True,
)
for line in ds:
    print(line)