Spaces:
Runtime error
Runtime error
File size: 2,301 Bytes
476ac07 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dist import init_dist
from .attention import (post_process_for_sequence_parallel_attn,
pre_process_for_sequence_parallel_attn,
sequence_parallel_wrapper)
from .comm import (all_to_all, gather_for_sequence_parallel,
gather_forward_split_backward, split_for_sequence_parallel,
split_forward_gather_backward)
from .data_collate import (pad_cumulative_len_for_sequence_parallel,
pad_for_sequence_parallel)
from .reduce_loss import reduce_sequence_parallel_loss
from .sampler import SequenceParallelSampler
from .setup_distributed import (get_data_parallel_group,
get_data_parallel_rank,
get_data_parallel_world_size,
get_inner_sequence_parallel_group,
get_inner_sequence_parallel_rank,
get_inner_sequence_parallel_world_size,
get_sequence_parallel_group,
get_sequence_parallel_rank,
get_sequence_parallel_world_size,
init_inner_sequence_parallel,
init_sequence_parallel,
is_inner_sequence_parallel_initialized)
# Public API of this package: the names imported at the top of this module
# that are exposed through ``from <package> import *``.
# One entry per line (original order kept) so that adding or removing a
# name shows up as a single-line diff.
__all__ = [
    'sequence_parallel_wrapper',
    'pre_process_for_sequence_parallel_attn',
    'post_process_for_sequence_parallel_attn',
    'pad_for_sequence_parallel',
    'split_for_sequence_parallel',
    'SequenceParallelSampler',
    'init_sequence_parallel',
    'get_sequence_parallel_group',
    'get_sequence_parallel_world_size',
    'get_sequence_parallel_rank',
    'get_data_parallel_group',
    'get_data_parallel_world_size',
    'get_data_parallel_rank',
    'reduce_sequence_parallel_loss',
    'init_dist',
    'all_to_all',
    'gather_for_sequence_parallel',
    'split_forward_gather_backward',
    'gather_forward_split_backward',
    'get_inner_sequence_parallel_group',
    'get_inner_sequence_parallel_rank',
    'get_inner_sequence_parallel_world_size',
    'init_inner_sequence_parallel',
    'is_inner_sequence_parallel_initialized',
    'pad_cumulative_len_for_sequence_parallel',
]
|