Spaces:
Build error
Build error
""" | |
Utilities for working with the local dataset cache. | |
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp | |
Copyright by the AllenNLP authors. | |
""" | |
from __future__ import (absolute_import, division, print_function, unicode_literals) | |
import sys | |
import json | |
import logging | |
import os | |
import six | |
import shutil | |
import tempfile | |
import fnmatch | |
from functools import wraps | |
from hashlib import sha256 | |
from io import open | |
# import boto3 | |
# from botocore.config import Config | |
# from botocore.exceptions import ClientError | |
import requests | |
from tqdm import tqdm | |
# Locate torch's cache directory; this module's own cache is a subfolder of it.
try:
    from torch.hub import _get_torch_home
    torch_cache_home = _get_torch_home()
except ImportError:
    # torch (or its private _get_torch_home helper) is not importable, so resolve
    # the directory here: $TORCH_HOME, else $XDG_CACHE_HOME/torch, else
    # ~/.cache/torch. expanduser() applies to whichever value wins.
    torch_cache_home = os.path.expanduser(
        os.getenv(
            'TORCH_HOME',
            os.path.join(os.getenv('XDG_CACHE_HOME', '~/.cache'), 'torch'),
        )
    )

default_cache_path = os.path.join(torch_cache_home, 'pytorch_transformers')
try: | |
from urllib.parse import urlparse | |
except ImportError: | |
from urlparse import urlparse | |
# Cache root: $PYTORCH_TRANSFORMERS_CACHE, else $PYTORCH_PRETRAINED_BERT_CACHE,
# else ``default_cache_path``. It is a ``pathlib.Path`` when pathlib can be
# imported, and a plain string otherwise.
try:
    from pathlib import Path
    PYTORCH_PRETRAINED_BERT_CACHE = Path(
        os.getenv(
            'PYTORCH_TRANSFORMERS_CACHE',
            os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path),
        )
    )
except (AttributeError, ImportError):
    # pathlib unavailable (e.g. an old Python 2 install): keep a string path.
    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv(
        'PYTORCH_TRANSFORMERS_CACHE',
        os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', default_cache_path),
    )

# Alias of PYTORCH_PRETRAINED_BERT_CACHE, kept for backward compatibility.
PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE
# Default artifact file names (presumably looked up by model/config loaders
# elsewhere in the package -- confirm against callers).
CONFIG_NAME = "config.json"
TF_WEIGHTS_NAME = 'model.ckpt'
WEIGHTS_NAME = "pytorch_model.bin"

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
# Same condition as ``six.PY2``, spelled with ``sys`` like the rest of this file.
if sys.version_info[0] != 2:
    def add_start_docstrings(*docstr):
        """
        Build a decorator that prepends ``''.join(docstr)`` to the decorated
        object's ``__doc__`` and returns the object itself.

        A missing docstring (``__doc__`` is ``None``: no docstring was written,
        or Python runs with ``-OO``) is treated as empty instead of raising
        TypeError on string concatenation.
        """
        def docstring_decorator(fn):
            fn.__doc__ = ''.join(docstr) + (fn.__doc__ or '')
            return fn
        return docstring_decorator

    def add_end_docstrings(*docstr):
        """
        Build a decorator that appends ``''.join(docstr)`` to the decorated
        object's ``__doc__`` and returns the object itself.

        A missing docstring (``None``) is treated as empty.
        """
        def docstring_decorator(fn):
            fn.__doc__ = (fn.__doc__ or '') + ''.join(docstr)
            return fn
        return docstring_decorator
else:
    # Not possible to update class docstrings on python2, so both decorators
    # return the decorated object unchanged.
    def add_start_docstrings(*docstr):
        """No-op on Python 2: return a decorator that leaves ``fn`` untouched."""
        def docstring_decorator(fn):
            return fn
        return docstring_decorator

    def add_end_docstrings(*docstr):
        """No-op on Python 2: return a decorator that leaves ``fn`` untouched."""
        def docstring_decorator(fn):
            return fn
        return docstring_decorator
def url_to_filename(url, etag=None):
    """
    Derive a deterministic cache file name from ``url`` and an optional ``etag``.

    The name is the hex SHA-256 digest of the UTF-8 encoded url. When ``etag``
    is non-empty, a '.' followed by the hex SHA-256 digest of the UTF-8 encoded
    etag is appended.
    """
    pieces = [sha256(url.encode('utf-8')).hexdigest()]
    if etag:
        pieces.append(sha256(etag.encode('utf-8')).hexdigest())
    return '.'.join(pieces)
def filename_to_url(filename, cache_dir=None):
    """
    Look up the source url and etag recorded for a cached file.

    Args:
        filename: name of the cached file inside ``cache_dir``.
        cache_dir: cache directory (str or ``Path``); defaults to
            ``PYTORCH_TRANSFORMERS_CACHE``.

    Returns:
        ``(url, etag)`` read from the ``<filename>.json`` metadata file; the
        etag may be ``None``.

    Raises:
        EnvironmentError: if the cached file or its metadata file is missing
            (the cached file is checked first).
    """
    if cache_dir is None:
        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    meta_path = cache_path + '.json'
    for required_path in (cache_path, meta_path):
        if not os.path.exists(required_path):
            raise EnvironmentError("file {} not found".format(required_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    return metadata['url'], metadata['etag']
def cached_path(url_or_filename, cache_dir=None, force_download=False, proxies=None):
    """
    Resolve ``url_or_filename`` to a local file path.

    http/https/s3 urls are fetched through the local cache (downloading on a
    miss) and the cached path is returned. An existing local path is returned
    unchanged.

    Args:
        url_or_filename: url or local path (str or ``Path``).
        cache_dir: cache directory to use instead of the default
            ``PYTORCH_TRANSFORMERS_CACHE``.
        force_download: if True, re-download the file even if it is already
            cached in ``cache_dir``.
        proxies: forwarded to ``get_from_cache`` for remote urls.

    Raises:
        EnvironmentError: the input looks like a local path but does not exist.
        ValueError: the input has a scheme that is neither a supported url nor
            an existing local path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3:
        if isinstance(url_or_filename, Path):
            url_or_filename = str(url_or_filename)
        if isinstance(cache_dir, Path):
            cache_dir = str(cache_dir)

    scheme = urlparse(url_or_filename).scheme
    if scheme in ('http', 'https', 's3'):
        # Remote resource: serve it from the cache, downloading if necessary.
        return get_from_cache(url_or_filename, cache_dir=cache_dir,
                              force_download=force_download, proxies=proxies)
    if os.path.exists(url_or_filename):
        # Existing local file: nothing to fetch.
        return url_or_filename
    if scheme == '':
        # Looks like a local path, but nothing is there.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    # Some scheme we do not know how to handle.
    raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
def split_s3_path(url):
    """
    Split an ``s3://bucket/key`` url into ``(bucket_name, key)``.

    Raises:
        ValueError: if the url has no bucket (netloc) or no path component.
    """
    parsed = urlparse(url)
    bucket_name, s3_path = parsed.netloc, parsed.path
    if not (bucket_name and s3_path):
        raise ValueError("bad s3 path {}".format(url))
    # urlparse keeps the leading '/' on the path; drop exactly one of them.
    key = s3_path[1:] if s3_path.startswith("/") else s3_path
    return bucket_name, key
def s3_request(func):
    """
    Decorate an S3 helper ``func(url, ...)`` so that a 404 ``ClientError`` is
    re-raised as ``EnvironmentError("file <url> not found")``.

    Any other ``ClientError`` propagates unchanged. The wrapped function keeps
    its ``__name__`` and ``__doc__`` (via ``functools.wraps``).

    NOTE(review): ``ClientError`` comes from botocore, whose import is
    commented out at the top of this file. Unless the name is provided
    elsewhere, any exception raised by ``func`` surfaces as NameError here.
    """
    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            # The error code is not guaranteed to be numeric (e.g. access-denied
            # style codes). int() on such a code would raise ValueError and hide
            # the real ClientError, so compare the code as text instead.
            if str(exc.response["Error"]["Code"]) == "404":
                raise EnvironmentError("file {} not found".format(url))
            raise
    return wrapper
def s3_etag(url, proxies=None):
    """
    Return the ETag of the S3 object addressed by ``url``.

    ``proxies`` is passed to the boto3 client configuration.

    NOTE(review): the ``boto3`` / ``botocore`` imports near the top of this
    file are commented out. Unless ``boto3`` and ``Config`` are provided
    elsewhere, calling this raises NameError.
    """
    resource = boto3.resource("s3", config=Config(proxies=proxies))
    bucket_name, key = split_s3_path(url)
    s3_object = resource.Object(bucket_name, key)
    return s3_object.e_tag
def s3_get(url, temp_file, proxies=None):
    """
    Download the S3 object addressed by ``url`` into the file object ``temp_file``.

    ``proxies`` is passed to the boto3 client configuration.

    NOTE(review): the ``boto3`` / ``botocore`` imports near the top of this
    file are commented out. Unless ``boto3`` and ``Config`` are provided
    elsewhere, calling this raises NameError.
    """
    resource = boto3.resource("s3", config=Config(proxies=proxies))
    bucket_name, key = split_s3_path(url)
    bucket = resource.Bucket(bucket_name)
    bucket.download_fileobj(key, temp_file)
def http_get(url, temp_file, proxies=None):
    """
    Stream the body of ``url`` into the writable binary file object
    ``temp_file``, showing a tqdm byte-progress bar.

    Args:
        url: http(s) url to download.
        temp_file: open binary file object that receives the bytes.
        proxies: forwarded to ``requests.get``.

    Raises:
        requests.HTTPError: if the server answers with an error status. It
            subclasses ``EnvironmentError``/``IOError``, so callers that catch
            those keep working. Without this check, a 404/500 error page would
            be cached as if it were the requested file.
    """
    req = requests.get(url, stream=True, proxies=proxies)
    try:
        req.raise_for_status()
        content_length = req.headers.get('Content-Length')
        total = int(content_length) if content_length is not None else None
        progress = tqdm(unit="B", total=total)
        try:
            for chunk in req.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
        finally:
            # Close the bar even if the transfer fails mid-stream.
            progress.close()
    finally:
        # A streamed response holds its connection until closed.
        req.close()
def get_from_cache(url, cache_dir=None, force_download=False, proxies=None):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    The cache file name is ``url_to_filename(url, etag)``, so a change in the
    remote ETag yields a new cache entry. A ``<cache file>.json`` metadata file
    recording ``{'url': ..., 'etag': ...}`` is written next to each download.

    Args:
        url: ``http(s)://`` or ``s3://`` url to fetch.
        cache_dir: cache directory (str or ``Path``); defaults to
            ``PYTORCH_TRANSFORMERS_CACHE``. Created if it does not exist.
        force_download: download again even if a cached copy exists.
        proxies: forwarded to the HEAD request and to the download helpers.

    Returns:
        Path (str) of the cached file.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_TRANSFORMERS_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)
    if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
        cache_dir = str(cache_dir)
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url, proxies=proxies)
    else:
        try:
            response = requests.head(url, allow_redirects=True, proxies=proxies)
            if response.status_code != 200:
                etag = None
            else:
                etag = response.headers.get("ETag")
        except EnvironmentError:
            # requests' exceptions derive from IOError, so network failures
            # (e.g. offline) end up here and we fall back to the cache below.
            etag = None

    if sys.version_info[0] == 2 and etag is not None:
        etag = etag.decode('utf-8')
    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # If we don't have a connection (etag is None) and can't identify the file,
    # fall back to the most recently written cached copy of this url, if any.
    if not os.path.exists(cache_path) and etag is None:
        matching_files = [
            name for name in fnmatch.filter(os.listdir(cache_dir), filename + '.*')
            if not name.endswith('.json')
        ]
        if matching_files:
            # os.listdir() returns entries in arbitrary order, so choose by
            # modification time to really get the latest download.
            latest = max(
                matching_files,
                key=lambda name: os.path.getmtime(os.path.join(cache_dir, name)),
            )
            cache_path = os.path.join(cache_dir, latest)

    if not os.path.exists(cache_path) or force_download:
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file, proxies=proxies)
            else:
                http_get(url, temp_file, proxies=proxies)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w') as meta_file:
                output_string = json.dumps(meta)
                if sys.version_info[0] == 2 and isinstance(output_string, str):
                    output_string = unicode(output_string, 'utf-8')  # The beauty of python 2
                meta_file.write(output_string)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path