# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

from mmengine._strategy import DeepSpeedStrategy as MMEngineDeepSpeedStrategy

from xtuner import DS_CEPH_DIR
from xtuner.parallel.sequence import init_sequence_parallel
from xtuner.utils.fileio import patch_fileio

class DeepSpeedStrategy(MMEngineDeepSpeedStrategy):

    def __init__(self, *args, **kwargs):
        sequence_parallel_size = kwargs.pop('sequence_parallel_size', 1)
        self.sequence_parallel_size = sequence_parallel_size
        super().__init__(*args, **kwargs)

        from transformers.integrations.deepspeed import HfDeepSpeedConfig

        # `HfDeepSpeedConfig` registers itself with transformers through a
        # weak reference, so it has to be saved as an attribute to keep the
        # object alive; otherwise ZeRO-3 detection inside transformers
        # (e.g. during `from_pretrained`) would silently stop working.
        self.hf_deepspeed_config = HfDeepSpeedConfig(self.config)
    def _wrap_model(self, model):
        wrapper = super()._wrap_model(model)
        # Hard-coded workaround for DeepSpeed ZeRO-3: under ZeRO-3 the model
        # is not moved to CUDA inside `deepspeed.initialize`, so the data
        # preprocessor has to be placed on the GPU explicitly.
        assert hasattr(wrapper.model, 'data_preprocessor')
        wrapper.model.data_preprocessor.cuda()
        return wrapper
    def save_checkpoint(self, *args, **kwargs) -> None:
        if DS_CEPH_DIR:
            from os import path as osp

            # Redirect the checkpoint path from the local work_dir prefix to
            # the Ceph directory, then patch file I/O so the checkpoint is
            # written there instead of to local disk.
            work_dir_prefix = osp.split(self.work_dir)[0]
            filename = kwargs['filename'].replace(work_dir_prefix,
                                                  DS_CEPH_DIR)
            kwargs['filename'] = filename
            with patch_fileio():
                super().save_checkpoint(*args, **kwargs)
        else:
            super().save_checkpoint(*args, **kwargs)
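
    # Illustrative example of the remapping above (hypothetical values):
    # with work_dir='/home/user/work_dirs/exp' and
    # DS_CEPH_DIR='s3://my-bucket', a checkpoint destined for
    # '/home/user/work_dirs/exp/iter_500.pth' is instead written to
    # 's3://my-bucket/exp/iter_500.pth'.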
    def load_checkpoint(self, *args, **kwargs) -> dict:
        if DS_CEPH_DIR:
            with patch_fileio():
                checkpoint = super().load_checkpoint(*args, **kwargs)
        else:
            checkpoint = super().load_checkpoint(*args, **kwargs)
        return checkpoint
    def resume(self, *args, **kwargs) -> dict:
        if DS_CEPH_DIR:
            with patch_fileio():
                checkpoint = super().resume(*args, **kwargs)
        else:
            checkpoint = super().resume(*args, **kwargs)
        return checkpoint
    def _setup_distributed(  # type: ignore
        self,
        launcher: Optional[str] = None,
        backend: str = 'nccl',
        **kwargs,
    ):
        super()._setup_distributed(launcher, backend, **kwargs)
        # Sequence-parallel process groups can only be created after the
        # default distributed process group has been initialized above.
        init_sequence_parallel(self.sequence_parallel_size)
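
# A minimal usage sketch (assumptions: the DeepSpeed JSON path and the
# sequence-parallel size below are hypothetical; in xtuner this strategy is
# normally built from a config by mmengine's FlexibleRunner rather than
# instantiated by hand):
#
#   strategy = dict(
#       type='DeepSpeedStrategy',
#       config='deepspeed_zero3.json',  # hypothetical DeepSpeed config file
#       sequence_parallel_size=2,
#   )
#   # Passed to the runner via, e.g.,
#   # runner = FlexibleRunner(model=..., strategy=strategy, ...)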