Spaces:
Sleeping
Sleeping
import io | |
from ditk import logging | |
import os | |
import pickle | |
import time | |
from functools import lru_cache | |
from typing import Union | |
import torch | |
from .import_helper import try_import_ceph, try_import_redis, try_import_rediscluster, try_import_mc | |
from .lock_helper import get_file_lock | |
# Module-level client handles. They start as ``None`` and are filled in lazily by
# ``_ensure_memcached`` / ``_ensure_rediscluster`` the first time that backend is used.
_memcached = None
_redis_cluster = None

# Optional DI-store backend, enabled by setting the environment variable ``DI_STORE=on``
# (case-insensitive). When it is not enabled, both helpers are exported as ``None``.
if os.environ.get('DI_STORE', 'off').lower() != 'on':
    save_to_di_store = read_from_di_store = None
else:
    print('Enable DI-store')
    from di_store import Client
    di_store_config_path = os.environ.get("DI_STORE_CONFIG_PATH", './di_store.yaml')
    di_store_client = Client(di_store_config_path)

    def save_to_di_store(data):
        """Store ``data`` in DI-store and return what ``Client.put`` hands back (presumably the object reference)."""
        return di_store_client.put(data)

    def read_from_di_store(object_ref):
        """Fetch the object behind ``object_ref``, then delete it from DI-store so each reference is read once."""
        fetched = di_store_client.get(object_ref)
        di_store_client.delete(object_ref)
        return fetched
@lru_cache()
def get_ceph_package():
    """
    Overview:
        Return the ceph client produced by ``try_import_ceph``. A ``None`` result means ceph is unavailable; \
        ``read_from_path`` and ``save_file_ceph`` use that to fall back to the local file system.
        The result is memoized, so the import attempt runs once per process instead of on every read/save call.
    Returns:
        - (:obj:`object` or :obj:`None`): The ceph client, or ``None`` when ceph is unavailable
    """
    return try_import_ceph()
@lru_cache()
def get_redis_package():
    """
    Overview:
        Return the redis package produced by ``try_import_redis`` (``_get_redis`` uses its ``StrictRedis``).
        The result is memoized, so the import attempt runs once per process instead of on every call.
    Returns:
        - (:obj:`module` or :obj:`None`): The redis package; presumably ``None`` when it is not installed — confirm \
            against ``try_import_redis``
    """
    return try_import_redis()
@lru_cache()
def get_rediscluster_package():
    """
    Overview:
        Return the redis-cluster package produced by ``try_import_rediscluster`` \
        (``_ensure_rediscluster`` uses its ``RedisCluster``).
        The result is memoized, so the import attempt runs once per process instead of on every call.
    Returns:
        - (:obj:`module` or :obj:`None`): The redis-cluster package; presumably ``None`` when it is not installed — \
            confirm against ``try_import_rediscluster``
    """
    return try_import_rediscluster()
@lru_cache()
def get_mc_package():
    """
    Overview:
        Return the memcached client package produced by ``try_import_mc``. ``read_file`` and ``save_file`` pick \
        the ``'mc'`` backend whenever this is not ``None``.
        The result is memoized, so the import attempt runs once per process rather than on every read/save call \
        and every iteration of the ``read_from_mc`` retry loop.
    Returns:
        - (:obj:`module` or :obj:`None`): The memcached package, or ``None`` when it is unavailable
    """
    return try_import_mc()
def read_from_ceph(path: str) -> object:
    """
    Overview:
        Fetch the raw bytes stored at ``path`` in ceph and unpickle them
    Arguments:
        - path (:obj:`str`): Object path in ceph, expected to start with ``"s3://"``
    Returns:
        - (:obj:`data`): The unpickled object
    Raises:
        - FileNotFoundError: When ceph returns an empty / falsy value for ``path``
    """
    # NOTE: ``pickle.loads`` can execute arbitrary code; only read objects from trusted buckets.
    raw = get_ceph_package().Get(path)
    if raw:
        return pickle.loads(raw)
    raise FileNotFoundError("File({}) doesn't exist in ceph".format(path))
def _get_redis(host='localhost', port=6379):
    """
    Overview:
        Build a redis client for the given server. A new client object is constructed on every call.
    Arguments:
        - host (:obj:`str`): Redis server host
        - port (:obj:`int`): Redis server port
    Returns:
        - (:obj:`Redis(object)`): A ``StrictRedis`` client bound to ``host``:``port`` using database ``0``
    """
    redis_pkg = get_redis_package()
    return redis_pkg.StrictRedis(host=host, port=port, db=0)
def read_from_redis(path: str) -> object:
    """
    Overview:
        Read and unpickle the value stored under key ``path`` in redis
    Arguments:
        - path (:obj:`str`): File path in redis, could be a string key
    Returns:
        - (:obj:`data`): Deserialized data
    Raises:
        - FileNotFoundError: If ``path`` has no value in redis (same error type as ``read_from_ceph``)
    """
    value = _get_redis().get(path)
    # redis-py returns ``None`` for a missing key. Fail with a clear error instead of letting
    # ``pickle.loads(None)`` raise an unrelated ``TypeError``.
    if value is None:
        raise FileNotFoundError("Key({}) doesn't exist in redis".format(path))
    # NOTE: ``pickle.loads`` can execute arbitrary code; only read keys written by trusted producers.
    return pickle.loads(value)
def _ensure_rediscluster(startup_nodes=None):
    """
    Overview:
        Lazily create the module-wide ``RedisCluster`` client and return it.
        ``startup_nodes`` only takes effect on the first call that creates the client; later calls reuse the \
        existing client and ignore the argument.
    Arguments:
        - startup_nodes (:obj:`list` of :obj:`dict` or :obj:`None`): Cluster entry points, each a dict with \
            ``host`` and ``port`` keys. ``None`` means ``[{"host": "127.0.0.1", "port": "7000"}]``.
    Returns:
        - (:obj:`RedisCluster(object)`): The shared client, created with ``decode_responses=False`` \
            so values come back as raw bytes
    """
    global _redis_cluster
    if _redis_cluster is None:
        if startup_nodes is None:
            # Build the default per call instead of sharing one mutable default list across calls.
            startup_nodes = [{"host": "127.0.0.1", "port": "7000"}]
        _redis_cluster = get_rediscluster_package().RedisCluster(startup_nodes=startup_nodes, decode_responses=False)
    return _redis_cluster
def read_from_rediscluster(path: str) -> object:
    """
    Overview:
        Read and unpickle the value stored under key ``path`` in the redis cluster
    Arguments:
        - path (:obj:`str`): File path in rediscluster, could be a string key
    Returns:
        - (:obj:`data`): Deserialized data
    Raises:
        - FileNotFoundError: If ``path`` has no value in the cluster (same error type as ``read_from_ceph``)
    """
    _ensure_rediscluster()
    value_bytes = _redis_cluster.get(path)
    # A missing key comes back as ``None``. Fail with a clear error instead of letting
    # ``pickle.loads(None)`` raise an unrelated ``TypeError``.
    if value_bytes is None:
        raise FileNotFoundError("Key({}) doesn't exist in rediscluster".format(path))
    # NOTE: ``pickle.loads`` can execute arbitrary code; only read keys written by trusted producers.
    return pickle.loads(value_bytes)
def read_from_file(path: str) -> object:
    """
    Overview:
        Unpickle an object from a file on the local file system
    Arguments:
        - path (:obj:`str`): Path of the pickle file on the local file system
    Returns:
        - (:obj:`data`): The unpickled object
    """
    # NOTE: ``pickle.load`` can execute arbitrary code; only load files from trusted sources.
    with open(path, "rb") as handle:
        return pickle.load(handle)
def _ensure_memcached(
        server_list_config_file: str = "/mnt/lustre/share/memcached_client/server_list.conf",
        client_config_file: str = "/mnt/lustre/share/memcached_client/client.conf",
):
    """
    Overview:
        Lazily create the module-wide memcached client and return it.
        The config file arguments only take effect on the first call that creates the client; later calls \
        reuse the existing client.
    Arguments:
        - server_list_config_file (:obj:`str`): Path of memcached_client's ``server_list.conf``
        - client_config_file (:obj:`str`): Path of memcached_client's ``client.conf``
    Returns:
        - (:obj:`MemcachedClient instance`): The shared client built by ``MemcachedClient.GetInstance``
    """
    global _memcached
    if _memcached is None:
        _memcached = get_mc_package().MemcachedClient.GetInstance(server_list_config_file, client_config_file)
    return _memcached
def read_from_mc(path: str, flush=False, max_retries=None, retry_interval=0.01) -> object:
    """
    Overview:
        Read a file through memcached and deserialize it with ``torch.load`` onto CPU, so the file must have been \
        written by ``torch.save()``. A failed attempt is logged and retried after ``retry_interval`` seconds.
    Arguments:
        - path (:obj:`str`): File path in local system, used as the memcached key
        - flush (:obj:`bool`): If True, only issue a ``Get`` with ``MC_READ_THROUGH`` (presumably to make \
            memcached reload the file) and return ``None`` without deserializing
        - max_retries (:obj:`int` or :obj:`None`): Maximum number of retries after a failed attempt. \
            ``None`` (default) retries forever, which matches the original behaviour.
        - retry_interval (:obj:`float`): Seconds to sleep between attempts
    Returns:
        - (:obj:`data`): Deserialized data, or ``None`` when ``flush`` is True
    Raises:
        - Exception: The last error, re-raised once ``max_retries`` retries have been used up
    """
    _ensure_memcached()
    mc = get_mc_package()  # resolved once instead of on every attempt
    attempt = 0
    while True:
        try:
            value = mc.pyvector()
            if flush:
                _memcached.Get(path, value, mc.MC_READ_THROUGH)
                return None
            _memcached.Get(path, value)
            value_buf = mc.ConvertBuffer(value)
            return torch.load(io.BytesIO(value_buf), map_location='cpu')
        except Exception as e:
            if max_retries is not None and attempt >= max_retries:
                raise
            attempt += 1
            # Report the actual cause; the previous bare ``print`` hid why reads kept failing.
            logging.warning('read mc failed ({}), retry...'.format(e))
            time.sleep(retry_interval)
def read_from_path(path: str):
    """
    Overview:
        Read a pickled object from ceph when the ceph package is available, otherwise from the local file system
    Arguments:
        - path (:obj:`str`): File path; a ceph path starting with ``"s3://"``, or a local path when ceph is missing
    Returns:
        - (:obj:`data`): Deserialized data
    """
    if get_ceph_package() is not None:
        return read_from_ceph(path)
    logging.info(
        "You do not have ceph installed! Loading local file!"
        " If you are not testing locally, something is wrong!"
    )
    return read_from_file(path)
def save_file_ceph(path, data):
    """
    Overview:
        Pickle ``data`` and store it in ceph. When the ceph package is unavailable, write it to the local file \
        system instead, unless the directory part of ``path`` is the sentinel ``'do_not_save'``, in which case \
        nothing is written.
    Arguments:
        - path (:obj:`str`): File path in ceph, starting with ``"s3://"``; treated as a local path when ceph is missing
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    Raises:
        - RuntimeError: If the ceph client has neither ``save_from_string`` nor ``put``
    """
    payload = pickle.dumps(data)
    save_path, file_name = os.path.dirname(path), os.path.basename(path)
    ceph = get_ceph_package()

    if ceph is None:
        size = len(payload)
        if save_path == 'do_not_save':
            logging.info(
                "You do not have ceph installed! ignored file {} of size {}!".format(file_name, size) +
                " If you are not testing locally, something is wrong!"
            )
            return
        local_path = os.path.join(save_path, file_name)
        with open(local_path, 'wb') as f:
            logging.info(
                "You do not have ceph installed! Saving as local file at {} of size {}!".format(local_path, size) +
                " If you are not testing locally, something is wrong!"
            )
            f.write(payload)
        return

    # Different ceph clients expose different upload APIs; prefer ``save_from_string`` when present.
    if hasattr(ceph, 'save_from_string'):
        ceph.save_from_string(save_path, file_name, payload)
    elif hasattr(ceph, 'put'):
        ceph.put(os.path.join(save_path, file_name), payload)
    else:
        raise RuntimeError('ceph can not save file, check your ceph installation')
def save_file_redis(path, data):
    """
    Overview:
        Pickle ``data`` and store it in redis under key ``path``
    Arguments:
        - path (:obj:`str`): Key (file path) in redis
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    """
    client = _get_redis()
    client.set(path, pickle.dumps(data))
def save_file_rediscluster(path, data):
    """
    Overview:
        Pickle ``data`` and store it in the redis cluster under key ``path``
    Arguments:
        - path (:obj:`str`): Key (file path) in the redis cluster
        - data (:obj:`Any`): Could be dict, list or tensor etc.
    """
    _ensure_rediscluster()
    serialized = pickle.dumps(data)
    _redis_cluster.set(path, serialized)
def read_file(path: str, fs_type: Union[None, str] = None, use_lock: bool = False) -> object:
    """
    Overview:
        Read file from path, dispatching on the file system type
    Arguments:
        - path (:obj:`str`): The path of file to read
        - fs_type (:obj:`str` or :obj:`None`): The file system type, one of ``{'normal', 'ceph', 'mc'}``. \
            When ``None`` it is inferred: ``'ceph'`` if ``path`` starts with ``s3`` (case-insensitive), \
            ``'mc'`` if the memcached package is available, otherwise ``'normal'``.
        - use_lock (:obj:`bool`): Whether to hold a read file lock while loading; only used for ``'normal'``
    Returns:
        - (:obj:`object`): The loaded data. ``'ceph'`` data is unpickled via ``read_from_path``; \
            ``'normal'`` and ``'mc'`` data is loaded with ``torch.load`` onto CPU.
    """
    if fs_type is None:
        if path.lower().startswith('s3'):
            fs_type = 'ceph'
        elif get_mc_package() is not None:
            fs_type = 'mc'
        else:
            fs_type = 'normal'
    assert fs_type in ['normal', 'ceph', 'mc'], 'unsupported fs_type: {}'.format(fs_type)
    if fs_type == 'ceph':
        return read_from_path(path)
    if fs_type == 'mc':
        return read_from_mc(path)
    if fs_type == 'normal':
        if use_lock:
            with get_file_lock(path, 'read'):
                return torch.load(path, map_location='cpu')
        return torch.load(path, map_location='cpu')
    # Only reachable when assertions are stripped (``python -O``); this used to surface as UnboundLocalError.
    raise ValueError('unsupported fs_type: {}'.format(fs_type))
def save_file(path: str, data: object, fs_type: Union[None, str] = None, use_lock: bool = False) -> None:
    """
    Overview:
        Save data to file of path, dispatching on the file system type
    Arguments:
        - path (:obj:`str`): The path of file to save to
        - data (:obj:`object`): The data to save
        - fs_type (:obj:`str` or :obj:`None`): The file system type, one of ``{'normal', 'ceph', 'mc'}``. \
            When ``None`` it is inferred: ``'ceph'`` if ``path`` starts with ``s3`` (case-insensitive), \
            ``'mc'`` if the memcached package is available, otherwise ``'normal'``.
        - use_lock (:obj:`bool`): Whether to hold a write file lock while saving; only used for ``'normal'``
    Notes:
        ``'ceph'`` pickles via ``save_file_ceph``. ``'normal'`` uses ``torch.save``. ``'mc'`` also writes with \
        ``torch.save`` and then calls ``read_from_mc(path, flush=True)``.
    """
    if fs_type is None:
        if path.lower().startswith('s3'):
            fs_type = 'ceph'
        elif get_mc_package() is not None:
            fs_type = 'mc'
        else:
            fs_type = 'normal'
    assert fs_type in ['normal', 'ceph', 'mc'], 'unsupported fs_type: {}'.format(fs_type)
    if fs_type == 'ceph':
        save_file_ceph(path, data)
    elif fs_type == 'normal':
        if use_lock:
            with get_file_lock(path, 'write'):
                torch.save(data, path)
        else:
            torch.save(data, path)
    elif fs_type == 'mc':
        torch.save(data, path)
        # Read-through Get after writing; presumably refreshes memcached's copy of the new file.
        read_from_mc(path, flush=True)
    else:
        # Only reachable under ``python -O``; previously an unknown fs_type silently saved nothing.
        raise ValueError('unsupported fs_type: {}'.format(fs_type))
def remove_file(path: str, fs_type: Union[None, str] = None) -> None:
    """
    Overview:
        Recursively remove a file or directory with ``rm -rf`` semantics: missing paths are ignored.
        The call blocks until the removal is done.
    Arguments:
        - path (:obj:`str`): The path of file you want to remove. For the normal file system, a leading ``~`` \
            and shell-style wildcards are expanded, as the former shell-based implementation did. The path is \
            never split on whitespace.
        - fs_type (:obj:`str` or :obj:`None`): The file system type, one of ``{'normal', 'ceph'}``; \
            inferred as ``'ceph'`` when ``path`` starts with ``s3`` (case-insensitive)
    """
    import glob
    import shutil
    import subprocess

    if fs_type is None:
        fs_type = 'ceph' if path.lower().startswith('s3') else 'normal'
    assert fs_type in ['normal', 'ceph'], 'unsupported fs_type: {}'.format(fs_type)
    if fs_type == 'ceph':
        # Pass an argv list (no shell) so spaces or shell metacharacters in ``path`` cannot inject commands.
        # Wait for completion so the object is really gone when we return.
        try:
            subprocess.run(["aws", "s3", "rm", "--recursive", path], stdout=subprocess.DEVNULL, check=False)
        except FileNotFoundError:
            logging.warning("aws CLI not found, could not remove {}".format(path))
        return
    # glob.glob returns a literal path only if it exists and expands wildcards like the shell did.
    # A path that matches nothing is silently ignored, like ``rm -f``.
    for target in glob.glob(os.path.expanduser(path)):
        if os.path.isdir(target) and not os.path.islink(target):
            shutil.rmtree(target, ignore_errors=True)
        else:
            try:
                os.remove(target)
            except OSError as e:
                # Best-effort like ``rm -rf``: report, don't raise.
                logging.warning("Failed to remove {}: {}".format(target, e))