#!/usr/bin/python3.6
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# Simple command-line wrapper around the chunked_dataset_iterator.
# Examples:
#   block_randomize my_chunked_data_folder/
#   block_randomize --azure-storage-key $MY_KEY https://myaccount.blob.core.windows.net/mycontainer/my_chunked_data_folder
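# The script enumerates all .gz chunk files under the given paths (local folders
# or Azure blob containers) and streams their lines to stdout in randomized order.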
import os, sys, inspect

sys.path.insert(
    0,
    os.path.dirname(
        os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
    ),
)  # find our imports

from infinibatch.datasets import chunked_dataset_iterator
from typing import Union, Iterator, Optional, Dict
import re
import gzip

# helper functions to abstract access to Azure blobs
# @TODO: These will be abstracted into a helper library in a future version.


def _try_parse_azure_blob_uri(path: str):
    # returns (account, container, blob path) for a blob URI, or None for anything else
    m = re.match(r"https://([a-z0-9]*)\.blob\.core\.windows\.net/([^/]*)/(.*)", path)
    if m is None:
        return None
    return (m.group(1), m.group(2), m.group(3))

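# Usage sketch (account/container names are the hypothetical ones from the
# examples above):
#   _try_parse_azure_blob_uri("https://myaccount.blob.core.windows.net/mycontainer/data/chunk0.gz")
#     -> ("myaccount", "mycontainer", "data/chunk0.gz")
#   _try_parse_azure_blob_uri("my_chunked_data_folder/chunk0.gz")  -> None  (not a blob URI)
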
def _get_azure_key(
    storage_account: str, credentials: Optional[Union[str, Dict[str, str]]]
):
    if not credentials:
        return None
    elif isinstance(credentials, str):
        return credentials
    else:
        return credentials[storage_account]

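# Resolution sketch (hypothetical key values): a plain string is used as the key
# for any account, while a dict is looked up by storage-account name, e.g.
#   _get_azure_key("myaccount", "SOME_KEY")                 -> "SOME_KEY"
#   _get_azure_key("myaccount", {"myaccount": "SOME_KEY"})  -> "SOME_KEY"
#   _get_azure_key("myaccount", None)                       -> None
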
def read_utf8_file(
    path: str, credentials: Optional[Union[str, Dict[str, str]]]
) -> Iterator[str]:
    blob_data = _try_parse_azure_blob_uri(path)
    if blob_data is None:
        with open(path, "rb") as f:
            data = f.read()
    else:
        try:
            # pip install azure-storage-blob
            from azure.storage.blob import BlobClient
        except ImportError:
            print(
                "Failed to import azure.storage.blob. Please pip install azure-storage-blob",
                file=sys.stderr,
            )
            raise
        data = (
            BlobClient.from_blob_url(
                path,
                credential=_get_azure_key(
                    storage_account=blob_data[0], credentials=credentials
                ),
            )
            .download_blob()
            .readall()
        )
    if path.endswith(".gz"):
        data = gzip.decompress(data)
    # @TODO: auto-detect UCS-2 by BOM
    return iter(data.decode(encoding="utf-8").splitlines())

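# Usage sketch (hypothetical path, for illustration only): local files and blob
# URIs are read the same way, and .gz content is decompressed transparently, e.g.
#   for line in read_utf8_file("my_chunked_data_folder/chunk0.gz", credentials=None):
#       print(line)
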
def enumerate_files(
    dir: str, ext: str, credentials: Optional[Union[str, Dict[str, str]]]
):
    blob_data = _try_parse_azure_blob_uri(dir)
    if blob_data is None:
        return [
            os.path.join(dir, path.name)
            for path in os.scandir(dir)
            if path.is_file() and (ext is None or path.name.endswith(ext))
        ]
    else:
        try:
            # pip install azure-storage-blob
            from azure.storage.blob import ContainerClient
        except ImportError:
            print(
                "Failed to import azure.storage.blob. Please pip install azure-storage-blob",
                file=sys.stderr,
            )
            raise
        account, container, blob_path = blob_data
        print("enumerate_files: enumerating blobs in", dir, file=sys.stderr, flush=True)
        # @BUGBUG: The prefix does not seem to be anchored to the start of the
        #          blob name; it can apparently also match as a substring.
        container_uri = "https://" + account + ".blob.core.windows.net/" + container
        container_client = ContainerClient.from_container_url(
            container_uri, credential=_get_azure_key(account, credentials)
        )
        if not blob_path.endswith("/"):
            blob_path += "/"  # make sure the prefix denotes a folder
        blob_uris = [
            container_uri + "/" + blob["name"]
            for blob in container_client.walk_blobs(blob_path, delimiter="")
            if (ext is None or blob["name"].endswith(ext))
        ]
        print(
            "enumerate_files:",
            len(blob_uris),
            "blobs found",
            file=sys.stderr,
            flush=True,
        )
        for blob_name in blob_uris[:10]:
            print(blob_name, file=sys.stderr, flush=True)
        return blob_uris

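# Usage sketch (hypothetical folder and file names): returns full paths or blob
# URIs of all files with the given extension, e.g.
#   enumerate_files("my_chunked_data_folder", ".gz", credentials=None)
#     -> ["my_chunked_data_folder/chunk0.gz", "my_chunked_data_folder/chunk1.gz", ...]
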
# parse the command line
if len(sys.argv) < 2:
    sys.exit("usage: block_randomize [--azure-storage-key KEY] PATH [PATH...]")
if sys.argv[1] == "--azure-storage-key":
    credential = sys.argv[2]
    paths = sys.argv[3:]
else:
    credential = None
    paths = sys.argv[1:]

chunk_file_paths = [  # enumerate all .gz files in the given paths
    subpath for path in paths for subpath in enumerate_files(path, ".gz", credential)
]
chunk_file_paths.sort()  # make sure file order is always the same, independent of OS
print(
    "block_randomize: reading from",
    len(chunk_file_paths),
    "chunk files",
    file=sys.stderr,
)

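# A hedged note on the arguments below (based on the chunked_dataset_iterator
# interface imported above): buffer_size is the number of items held in the
# in-memory shuffling buffer, and a fixed seed makes the randomized order
# reproducible across runs.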
ds = chunked_dataset_iterator(
    chunk_refs=chunk_file_paths,
    read_chunk_fn=lambda path: read_utf8_file(path, credential),
    shuffle=True,
    buffer_size=1000000,
    seed=1,
    use_windowed=True,
)
for line in ds:
    print(line)