File size: 2,301 Bytes
476ac07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dist import init_dist

from .attention import (post_process_for_sequence_parallel_attn,
                        pre_process_for_sequence_parallel_attn,
                        sequence_parallel_wrapper)
from .comm import (all_to_all, gather_for_sequence_parallel,
                   gather_forward_split_backward, split_for_sequence_parallel,
                   split_forward_gather_backward)
from .data_collate import (pad_cumulative_len_for_sequence_parallel,
                           pad_for_sequence_parallel)
from .reduce_loss import reduce_sequence_parallel_loss
from .sampler import SequenceParallelSampler
from .setup_distributed import (get_data_parallel_group,
                                get_data_parallel_rank,
                                get_data_parallel_world_size,
                                get_inner_sequence_parallel_group,
                                get_inner_sequence_parallel_rank,
                                get_inner_sequence_parallel_world_size,
                                get_sequence_parallel_group,
                                get_sequence_parallel_rank,
                                get_sequence_parallel_world_size,
                                init_inner_sequence_parallel,
                                init_sequence_parallel,
                                is_inner_sequence_parallel_initialized)

# Public API re-exported by this package. Every entry is a name imported
# above; this list is what ``from <package> import *`` exposes. Keep it in
# sync with the import statements when adding or removing helpers.
__all__ = [
    'sequence_parallel_wrapper',
    'pre_process_for_sequence_parallel_attn',
    'post_process_for_sequence_parallel_attn',
    'pad_for_sequence_parallel',
    'split_for_sequence_parallel',
    'SequenceParallelSampler',
    'init_sequence_parallel',
    'get_sequence_parallel_group',
    'get_sequence_parallel_world_size',
    'get_sequence_parallel_rank',
    'get_data_parallel_group',
    'get_data_parallel_world_size',
    'get_data_parallel_rank',
    'reduce_sequence_parallel_loss',
    'init_dist',
    'all_to_all',
    'gather_for_sequence_parallel',
    'split_forward_gather_backward',
    'gather_forward_split_backward',
    'get_inner_sequence_parallel_group',
    'get_inner_sequence_parallel_rank',
    'get_inner_sequence_parallel_world_size',
    'init_inner_sequence_parallel',
    'is_inner_sequence_parallel_initialized',
    'pad_cumulative_len_for_sequence_parallel',
]