Spaces:
Runtime error
Runtime error
import io | |
from contextlib import contextmanager | |
import mmengine.fileio as fileio | |
from mmengine.fileio import LocalBackend, PetrelBackend, get_file_backend | |
def patch_func(module, fn_name_to_wrap):
    """Build a decorator that replaces ``module.<fn_name_to_wrap>``.

    The current attribute value is captured when ``patch_func`` is called.
    Applying the returned decorator to ``fn_new``:

    * installs ``fn_new`` as ``module.<fn_name_to_wrap>``;
    * exposes the captured original as ``fn_new._fallback`` so the new
      function can delegate to it;
    * records ``(module, fn_name_to_wrap, original)`` in the shared list
      ``patch_func._backup`` so the patch can be undone later.

    Args:
        module: Object (module, class, ...) owning the attribute.
        fn_name_to_wrap (str): Name of the attribute to replace.

    Returns:
        Callable: Decorator that returns ``fn_new`` unchanged.
    """
    fn_to_wrap = getattr(module, fn_name_to_wrap)

    def wrap(fn_new):
        # Look up (or create) the shared backup list only now. Reading it
        # when patch_func() is called would give two decorators created
        # before either is applied separate fresh lists, and whichever was
        # stored first would be overwritten and its entries lost.
        backup = getattr(patch_func, '_backup', None)
        if backup is None:
            backup = []
            patch_func._backup = backup
        setattr(module, fn_name_to_wrap, fn_new)
        backup.append((module, fn_name_to_wrap, fn_to_wrap))
        fn_new._fallback = fn_to_wrap
        return fn_new

    return wrap
def _strip_join_component(component):
    """Drop leading ``/`` and ``./`` prefixes from one ``os.path.join`` part.

    Used by the patched ``os.path.join`` before components are handed to a
    non-local backend's ``join_path``. Unlike ``str.lstrip('./')``, which
    removes *every* leading ``.`` and ``/`` character, this keeps names that
    merely start with a dot (``'.gitattributes'`` stays ``'.gitattributes'``,
    ``'..'`` stays ``'..'``). A bare ``'.'`` collapses to ``''``.
    """
    while component.startswith('/') or component.startswith('./'):
        if component.startswith('/'):
            component = component[1:]
        else:
            component = component[2:]
    return '' if component == '.' else component


@contextmanager
def patch_fileio(global_vars=None):
    """Temporarily route common file APIs through mmengine file backends.

    Inside the context the functions below are replaced by wrappers that call
    the original when ``get_file_backend(path)`` is a ``LocalBackend`` and use
    that backend otherwise:

    * ``builtins.open``
    * ``os.path.join``, ``os.path.isdir``, ``os.path.isfile``,
      ``os.path.exists``
    * ``os.mkdir``, ``os.makedirs``, ``os.listdir``, ``os.chmod``,
      ``os.stat``
    * ``glob.glob``, ``filecmp.cmp``, ``shutil.copy``
    * ``torch.load``, ``torch.save``
    * ``sentencepiece.SentencePieceProcessor.LoadFromFile``

    Every replacement recorded in ``patch_func._backup`` is undone on exit.
    This also happens when installing the patches fails part-way (e.g. an
    import error), so a failed entry does not leave the interpreter
    half-patched.

    Args:
        global_vars (dict, optional): Namespace such as a module's
            ``globals()``. If it has an ``'open'`` entry, that entry points at
            the patched ``open`` for the duration of the context and is
            restored afterwards.

    Nested use is a no-op: only the outermost context installs patches.
    """
    if getattr(patch_fileio, '_patched', False):
        # Already inside an active patch_fileio context; patching again
        # would wrap the wrappers, so just run the body.
        yield
        return

    swapped_open = False
    bak_open = None
    try:
        # Set first so the finally-block below always resets it.
        setattr(patch_fileio, '_patched', True)

        import builtins

        @patch_func(builtins, 'open')
        def open(file, mode='r', *args, **kwargs):
            backend = get_file_backend(file)
            if isinstance(backend, LocalBackend):
                return open._fallback(file, mode, *args, **kwargs)
            # Non-local objects are read completely into an in-memory file.
            if 'b' in mode:
                return io.BytesIO(backend.get(file, *args, **kwargs))
            return io.StringIO(backend.get_text(file, *args, **kwargs))

        if global_vars is not None and 'open' in global_vars:
            bak_open = global_vars['open']
            global_vars['open'] = builtins.open
            swapped_open = True

        import os

        @patch_func(os.path, 'join')
        def join(a, *paths):
            backend = get_file_backend(
                a.decode('utf-8') if isinstance(a, bytes) else a)
            if isinstance(backend, LocalBackend):
                return join._fallback(a, *paths)
            paths = [
                _strip_join_component(item) for item in paths if len(item) > 0
            ]
            return backend.join_path(a, *paths)

        @patch_func(os.path, 'isdir')
        def isdir(path):
            backend = get_file_backend(path)
            if isinstance(backend, LocalBackend):
                return isdir._fallback(path)
            return backend.isdir(path)

        @patch_func(os.path, 'isfile')
        def isfile(path):
            backend = get_file_backend(path)
            if isinstance(backend, LocalBackend):
                return isfile._fallback(path)
            return backend.isfile(path)

        @patch_func(os.path, 'exists')
        def exists(path):
            backend = get_file_backend(path)
            if isinstance(backend, LocalBackend):
                return exists._fallback(path)
            return backend.exists(path)

        @patch_func(os, 'mkdir')
        def mkdir(path, *args, **kwargs):
            backend = get_file_backend(path)
            if isinstance(backend, LocalBackend):
                return mkdir._fallback(path, *args, **kwargs)
            # Non-local backends: nothing is done (returns None).

        @patch_func(os, 'makedirs')
        def makedirs(path, *args, **kwargs):
            backend = get_file_backend(path)
            if isinstance(backend, LocalBackend):
                return makedirs._fallback(path, *args, **kwargs)
            # Non-local backends: nothing is done (returns None).

        @patch_func(os, 'listdir')
        def listdir(path):
            backend = get_file_backend(path)
            if isinstance(backend, LocalBackend):
                return listdir._fallback(path)
            return backend.list_dir_or_file(path)

        @patch_func(os, 'chmod')
        def chmod(path, *args, **kwargs):
            backend = get_file_backend(path)
            if isinstance(backend, LocalBackend):
                return chmod._fallback(path, *args, **kwargs)
            # Non-local backends: nothing is done (returns None).

        @patch_func(os, 'stat')
        def stat(path, *args, **kwargs):
            backend = get_file_backend(path)
            if isinstance(backend, LocalBackend):
                return stat._fallback(path, *args, **kwargs)
            # Non-local backends: returns None.

        import glob as glob_pkg

        @patch_func(glob_pkg, 'glob')
        def glob(pathname, *, recursive=False):
            backend = get_file_backend(pathname)
            if isinstance(backend, LocalBackend):
                return glob._fallback(pathname, recursive=recursive)
            # Only the DeepSpeed checkpoint shard patterns are supported on
            # non-local backends: list the directory and filter by suffix.
            for suffix in ('_optim_states.pt', '_model_states.pt'):
                if pathname.endswith('*' + suffix):
                    dirname = os.path.split(pathname)[0]
                    names = backend.list_dir_or_file(
                        dirname, recursive=recursive)
                    return [
                        os.path.join(dirname, name) for name in names
                        if name.endswith(suffix)
                    ]
            if '*' in pathname:
                raise NotImplementedError
            return backend.list_dir_or_file(pathname, recursive=recursive)

        import filecmp

        @patch_func(filecmp, 'cmp')
        def cmp(f1, f2, *args, **kwargs):
            # Materialise both files locally, then compare with the original.
            with fileio.get_local_path(f1) as f1, \
                    fileio.get_local_path(f2) as f2:
                return cmp._fallback(f1, f2, *args, **kwargs)

        import shutil

        @patch_func(shutil, 'copy')
        def copy(src, dst, **kwargs):
            from pathlib import Path

            # pathlib collapses the '//' after a URI scheme to '/'; restore
            # it so get_file_backend recognises the scheme again.
            if isinstance(src, Path):
                src = str(src).replace(':/', '://')
            if isinstance(dst, Path):
                dst = str(dst).replace(':/', '://')
            src_backend = get_file_backend(src)
            dst_backend = get_file_backend(dst)
            src_local = isinstance(src_backend, LocalBackend)
            dst_local = isinstance(dst_backend, LocalBackend)
            if src_local and dst_local:
                return copy._fallback(src, dst, **kwargs)
            if src_local and isinstance(dst_backend, PetrelBackend):
                return dst_backend.copyfile_from_local(str(src), str(dst))
            if isinstance(src_backend, PetrelBackend) and dst_local:
                return src_backend.copyfile_to_local(str(src), str(dst))
            # NOTE(review): any other backend combination (e.g. remote to
            # remote) copies nothing and returns None -- confirm intended.

        import torch

        @patch_func(torch, 'load')
        def load(f, *args, **kwargs):
            # String paths (local or remote) are read into memory through
            # mmengine; file-like objects are passed through unchanged.
            if isinstance(f, str):
                f = io.BytesIO(fileio.get(f))
            return load._fallback(f, *args, **kwargs)

        @patch_func(torch, 'save')
        def save(obj, f, *args, **kwargs):
            backend = get_file_backend(f)
            if isinstance(backend, LocalBackend):
                return save._fallback(obj, f, *args, **kwargs)
            # Serialise into memory, then upload with a single put().
            with io.BytesIO() as buffer:
                save._fallback(obj, buffer, *args, **kwargs)
                buffer.seek(0)
                backend.put(buffer, f)

        from sentencepiece import SentencePieceProcessor

        @patch_func(SentencePieceProcessor, 'LoadFromFile')
        def LoadFromFile(cls, path):
            if not path:
                return LoadFromFile._fallback(cls, path)
            backend = get_file_backend(path)
            if isinstance(backend, LocalBackend):
                return LoadFromFile._fallback(cls, path)
            from tempfile import TemporaryDirectory

            # Copy the model into a temporary local directory and load it
            # from there.
            with TemporaryDirectory() as tmpdir:
                local_path = backend.copyfile_to_local(path, tmpdir)
                return LoadFromFile._fallback(cls, local_path)

        yield
    finally:
        # Undo patches in reverse (LIFO) order. Popping also empties the
        # shared list, so entries do not pile up and get replayed later.
        backup = getattr(patch_func, '_backup', [])
        while backup:
            module, fn_name, original_fn = backup.pop()
            setattr(module, fn_name, original_fn)
        if swapped_open:
            global_vars['open'] = bak_open
        setattr(patch_fileio, '_patched', False)
def patch_hf_auto_from_pretrained(petrel_hub):
    """Redirect HF/PEFT ``from_pretrained`` calls onto ``petrel_hub``.

    For each class collected below, ``Cls.from_pretrained(name, ...)`` is
    replaced by a wrapper. The wrapper enters ``patch_fileio()``, joins
    ``petrel_hub`` and ``name`` with (the patched) ``os.path.join``, and calls
    the original implementation, which is kept as ``Cls._from_pretrained``.
    Classes that already have a ``_from_pretrained`` attribute are left
    untouched, because the wrapper would otherwise call that unrelated
    attribute. Once this has completed, later calls are no-ops.

    Args:
        petrel_hub (str): Root path/URI that model names are joined onto.
    """
    if hasattr(patch_hf_auto_from_pretrained, '_patched'):
        return
    from peft import PeftModel
    from transformers import (AutoConfig, AutoFeatureExtractor,
                              AutoImageProcessor, AutoModelForCausalLM,
                              AutoProcessor, AutoTokenizer,
                              ImageProcessingMixin, PreTrainedModel,
                              PreTrainedTokenizerBase, ProcessorMixin)
    from transformers.models.auto.auto_factory import _BaseAutoModelClass

    target_cls = list(_BaseAutoModelClass.__subclasses__())
    for base in (AutoModelForCausalLM, AutoConfig, AutoTokenizer,
                 AutoImageProcessor, AutoFeatureExtractor, AutoProcessor,
                 PreTrainedTokenizerBase, ImageProcessingMixin,
                 PreTrainedModel, ProcessorMixin, PeftModel):
        target_cls.extend([base] + base.__subclasses__())

    import os

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        with patch_fileio():
            # os.path.join is the patched version here, so a non-local hub
            # root is joined by its own backend.
            # NOTE(review): if the original implementation calls another
            # wrapped class's from_pretrained internally, the hub prefix is
            # joined a second time -- confirm this is intended.
            model_path = os.path.join(petrel_hub,
                                      pretrained_model_name_or_path)
            obj = cls._from_pretrained(model_path, *args, **kwargs)
        return obj

    # Snapshot every original before patching anything. Once a parent class
    # is patched, a subclass that inherits ``from_pretrained`` would
    # otherwise pick up the wrapper instead of the original.
    originals = {
        klass: klass.from_pretrained
        for klass in set(target_cls)
        if not hasattr(klass, '_from_pretrained')
    }
    for klass, original in originals.items():
        # Store the underlying function as a classmethod, so it binds to the
        # class actually used at call time. A bound method copied here would
        # stay bound to ``klass``, and subclasses inheriting it would build
        # an instance of ``klass`` instead of themselves.
        func = getattr(original, '__func__', None)
        klass._from_pretrained = (
            classmethod(func) if func is not None else original)
        klass.from_pretrained = from_pretrained
    patch_hf_auto_from_pretrained._patched = True
def patch_hf_save_pretrained():
    """Make ``save_pretrained`` of HF/PEFT classes go through ``patch_fileio``.

    Every collected class that has a ``save_pretrained`` gets a wrapper. The
    wrapper enters ``patch_fileio()`` and forces ``save_function=torch.save``
    (the patched ``torch.save`` while the context is active) and
    ``safe_serialization=False``. This overrides any values the caller passed
    for those two keyword arguments. Once this has completed, later calls are
    no-ops.
    """
    if hasattr(patch_hf_save_pretrained, '_patched'):
        return
    import functools

    import torch
    from peft import PeftModel
    from transformers import (AutoConfig, AutoTokenizer, PreTrainedModel,
                              PreTrainedTokenizerBase)
    from transformers.models.auto.auto_factory import _BaseAutoModelClass

    target_cls = []
    for base in (AutoConfig, AutoTokenizer, PreTrainedTokenizerBase,
                 PreTrainedModel, _BaseAutoModelClass, PeftModel):
        target_cls.extend([base] + base.__subclasses__())

    def _patch_wrap(method):

        @functools.wraps(method)
        def wrapped_method(self, *args, **kwargs):
            with patch_fileio():
                # Looked up inside the context, so the patched torch.save
                # is the one handed to save_pretrained.
                kwargs['save_function'] = torch.save
                kwargs['safe_serialization'] = False
                obj = method(self, *args, **kwargs)
            return obj

        # Marker so a subclass that merely inherits an already wrapped
        # method is not wrapped a second time.
        wrapped_method._fileio_wrapped = True
        return wrapped_method

    for cls in set(target_cls):
        method = getattr(cls, 'save_pretrained', None)
        if method is None or getattr(method, '_fileio_wrapped', False):
            continue
        cls.save_pretrained = _patch_wrap(method)
    patch_hf_save_pretrained._patched = True
def patch_deepspeed_engine():
    """Let DeepSpeed place its ``zero_to_fp32.py`` recovery script on Petrel.

    Replaces ``DeepSpeedEngine._copy_recovery_script`` with a version that
    uploads the script through ``PetrelBackend.copyfile_from_local`` when
    ``save_path`` resolves to a Petrel backend. Otherwise it copies the script
    with ``shutil.copyfile`` and adjusts its permissions through the engine's
    ``_change_recovery_script_permissions``. Once this has completed, later
    calls are no-ops.
    """
    if hasattr(patch_deepspeed_engine, '_patched'):
        return

    def _copy_recovery_script(self, save_path):
        import os
        from shutil import copyfile

        from deepspeed.utils import zero_to_fp32

        script = 'zero_to_fp32.py'
        src = zero_to_fp32.__file__
        dst = os.path.join(save_path, script)
        backend = get_file_backend(save_path)
        if isinstance(backend, PetrelBackend):
            backend.copyfile_from_local(src, dst)
            return
        copyfile(src, dst)
        # Permission handling applies only to the local copy. Under
        # patch_fileio, os.chmod does nothing and os.stat returns None for
        # non-local paths.
        self._change_recovery_script_permissions(dst)

    from deepspeed.runtime.engine import DeepSpeedEngine
    DeepSpeedEngine._copy_recovery_script = _copy_recovery_script
    patch_deepspeed_engine._patched = True