# zzc0208's picture
# Upload 265 files
# f1f9265 verified
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids
import fcntl
import os
import shutil
import sys
import time
from collections import deque
from datetime import datetime
from urllib.parse import urlparse
# Rolling log of the most recent downloads. download_file appends one
# (remote, local, completion timestamp, elapsed seconds) tuple per call;
# the deque silently drops the oldest entry once it holds 1000.
recent_downloads = deque(maxlen=1000)
# Registry of file objects returned by download_and_open, keyed by a
# (remote, local, mode, timestamp) tuple of strings. download_and_open skips
# all bookkeeping when this is set to None.
open_objects = {}
# Threshold download_and_open compares len(open_objects) against before it
# registers another file object.
max_open_objects = 100
class ULockFile:
    """A simple exclusive lock backed by ``flock`` on a lock file.

    We don't need any of the third party libraries since we rely on POSIX
    semantics for linking below anyway.

    Use as a context manager: entering blocks until the exclusive lock is
    held, leaving removes the lock file and releases the lock.

    Because the lock file is deleted on release, a waiter may end up holding
    a lock on a file that is no longer reachable through ``path`` while a
    newcomer creates and locks a fresh file.  To rule that out, the lock file
    is removed *before* the lock is dropped, and every acquirer re-checks
    after locking that ``path`` still refers to the file it locked.
    """

    def __init__(self, path):
        """Remember the lock file location; nothing is opened yet."""
        self.lockfile_path = path
        self.lockfile = None

    def __enter__(self):
        """Block until the exclusive lock on the current lock file is held."""
        while True:
            handle = open(self.lockfile_path, "w")
            try:
                fcntl.flock(handle.fileno(), fcntl.LOCK_EX)
                try:
                    # The previous holder may have unlinked the file while we
                    # were waiting; a lock on an orphaned inode protects
                    # nothing, so make sure the path still points at it.
                    still_current = os.path.samestat(
                        os.fstat(handle.fileno()),
                        os.stat(self.lockfile_path),
                    )
                except FileNotFoundError:
                    still_current = False
            except BaseException:
                # Never leak the descriptor if locking/stat-ing fails.
                handle.close()
                raise
            if still_current:
                self.lockfile = handle
                return self
            # Stale file: drop it (closing releases the lock) and retry.
            handle.close()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Remove the lock file, then release the lock and close the handle."""
        try:
            # Unlink while the lock is still held so that nobody can lock the
            # old file and believe it is the current one.
            os.unlink(self.lockfile_path)
        except FileNotFoundError:
            pass
        finally:
            fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_UN)
            self.lockfile.close()
            self.lockfile = None
def pipe_download(remote, local):
    """Perform a download for a pipe: url.

    Everything after the ``pipe:`` prefix is a shell command template; the
    placeholder ``{local}`` is substituted with ``local`` and the resulting
    command is run through ``os.system``.

    Raises:
        AssertionError: if ``remote`` is not a ``pipe:`` url or the command
            exits with a non-zero status.
    """
    # Explicit checks instead of ``assert`` statements: under ``python -O``
    # asserts are stripped, which would skip running the command entirely.
    if not remote.startswith("pipe:"):
        raise AssertionError("Not a pipe: url: %s" % remote)
    cmd = remote[len("pipe:"):].format(local=local)
    if os.system(cmd) != 0:
        raise AssertionError("Command failed: %s" % cmd)
def _as_local_path(url):
    """Turn a ``file:`` url or plain path into an absolute filesystem path."""
    parsed = urlparse(url)
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if parsed.scheme not in ("file", ""):
        raise AssertionError("Not a local file url: %s" % url)
    if parsed.scheme or parsed.netloc:
        raw = parsed.path
    else:
        # A plain path is taken verbatim: urlparse would otherwise chop off
        # anything after '?', '#' or ';', all of which are legal in file names.
        raw = url
    return os.path.abspath(raw)


def copy_file(remote, local):
    """Copy a local file from ``remote`` to ``local``.

    Both arguments may be plain paths or ``file:`` urls.  Nothing is copied
    when both resolve to the same absolute path.

    Raises:
        AssertionError: if either argument uses a scheme other than ``file``.
    """
    source = _as_local_path(remote)
    target = _as_local_path(local)
    if source == target:
        # Copying a file onto itself would make shutil raise SameFileError.
        return
    shutil.copyfile(source, target)
# Non-zero when the WIDS_VERBOSE_CMD environment variable asks for the
# external download tools to run with their verbose output.
verbose_cmd = int(os.getenv("WIDS_VERBOSE_CMD", "0"))


def vcmd(flag, verbose_flag=""):
    """Pick the command-line flag matching the current verbosity setting.

    Returns ``verbose_flag`` when verbose commands are enabled, otherwise the
    (quiet) ``flag``.
    """
    if verbose_cmd:
        return verbose_flag
    return flag
# Default scheme -> handler table.  A handler is either a callable taking
# (remote, local) or a shell command template with {url} and {local}
# placeholders.  The four curl-backed schemes share one template.
default_cmds = dict(
    posixpath=copy_file,
    file=copy_file,
    pipe=pipe_download,
    **dict.fromkeys(
        ("http", "https", "ftp", "ftps"),
        "curl " + vcmd("-s") + " -L {url} -o {local}",
    ),
    gs="gsutil " + vcmd("-q") + " cp {url} {local}",
    s3="aws s3 cp {url} {local}",
)
# TODO(ligeng): change HTTPS download to python requests library
def download_file_no_log(remote, local, handlers=default_cmds):
    """Download a file from a remote url to a local path.

    The remote url can be a pipe: url, in which case the remainder of
    the url is treated as a command template that is executed to perform
    the download.  Otherwise the url scheme selects the handler; a url
    without a scheme is handled as ``posixpath``.

    Args:
        remote: source url or path.
        local: destination path.
        handlers: mapping from scheme to either a callable taking
            ``(remote, local)`` or a shell command template with ``{url}``
            and ``{local}`` placeholders.

    Returns:
        ``local``.

    Raises:
        ValueError: if no handler is registered for the scheme.
        AssertionError: if the handler is neither callable nor a string, or
            if the shell command exits with a non-zero status.
    """
    if remote.startswith("pipe:"):
        schema = "pipe"
    else:
        schema = urlparse(remote).scheme or "posixpath"

    handler = handlers.get(schema)
    if handler is None:
        raise ValueError("Unknown schema: %s" % schema)

    if callable(handler):
        handler(remote, local)
        return local

    # Explicit checks instead of ``assert`` statements: under ``python -O``
    # asserts are stripped, which would skip running the command entirely
    # and report a download that never happened.
    if not isinstance(handler, str):
        raise AssertionError("Handler for %s must be callable or str" % schema)
    # NOTE(review): url and local are substituted unquoted, so shell
    # metacharacters in them (e.g. '&' in a query string) reach the shell
    # as-is.  Quoting here would break caller-supplied templates that already
    # quote their placeholders, so that is left to the templates.
    cmd = handler.format(url=remote, local=local)
    if os.system(cmd) != 0:
        raise AssertionError("Command failed: %s" % cmd)
    return local
def download_file(remote, local, handlers=default_cmds, verbose=False):
    """Download ``remote`` to ``local`` and record the attempt.

    Every attempt, successful or not, appends a
    ``(remote, local, finish timestamp, elapsed seconds)`` tuple to
    ``recent_downloads``.

    Args:
        remote: source url or path.
        local: destination path.
        handlers: scheme -> handler mapping passed to download_file_no_log.
        verbose: when true, report a completed download on stderr.

    Returns:
        ``local``.  Exceptions from the underlying download propagate.
    """
    start = time.time()
    succeeded = False
    try:
        result = download_file_no_log(remote, local, handlers=handlers)
        succeeded = True
        return result
    finally:
        # Read the clock once so the recorded and the printed duration agree.
        finished = time.time()
        elapsed = finished - start
        recent_downloads.append((remote, local, finished, elapsed))
        # Only claim "downloaded" when the download did not raise.
        if verbose and succeeded:
            print(
                "downloaded",
                remote,
                "to",
                local,
                "in",
                elapsed,
                "seconds",
                file=sys.stderr,
            )
def download_and_open(remote, local, mode="rb", handlers=default_cmds, verbose=False):
    """Open ``remote``, downloading it to ``local`` first when necessary.

    While an exclusive lock on ``local + ".lock"`` is held:

    * if ``remote`` exists as a local path it is opened directly;
    * otherwise ``local`` is used as a cache: it is downloaded if missing and
      then opened.

    Unless ``open_objects`` is ``None``, the returned file object is
    registered there after entries whose file is already closed are pruned.

    Args:
        remote: source url or path.
        local: cache path for the downloaded copy.
        mode: mode passed to ``open``.
        handlers: scheme -> handler mapping passed to download_file.
        verbose: when true, report on stderr whether a download or the cached
            copy is used.

    Returns:
        The opened file object.

    Raises:
        RuntimeError: if more than ``max_open_objects`` registered file
            objects are still open.
    """
    with ULockFile(local + ".lock"):
        if os.path.exists(remote):
            result = open(remote, mode)
        else:
            if not os.path.exists(local):
                if verbose:
                    print("downloading", remote, "to", local, file=sys.stderr)
                download_file(remote, local, handlers=handlers)
            elif verbose:
                print("using cached", local, file=sys.stderr)
            result = open(local, mode)

        if open_objects is not None:
            # Forget handles that callers have closed in the meantime.
            closed_keys = [k for k, v in open_objects.items() if v.closed]
            for stale in closed_keys:
                del open_objects[stale]
            if len(open_objects) > max_open_objects:
                # Do not leak the handle that was just opened.
                result.close()
                raise RuntimeError("Too many open objects")
            # NOTE(review): the timestamp has one-second resolution, so opening
            # the same (remote, local, mode) twice within a second replaces the
            # earlier registry entry.
            current_time = datetime.now().strftime("%Y%m%d%H%M%S")
            key = tuple(str(x) for x in [remote, local, mode, current_time])
            open_objects[key] = result
        return result