File size: 3,321 Bytes
476ac07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

from mmengine import MessageHub
from mmengine.dist import get_rank
from mmengine.hooks import Hook

# Accepted types for the ``data_batch`` argument of the hook callbacks; the
# trailing ``None`` makes the argument optional.
DATA_BATCH = Union[dict, tuple, list, None]


class VarlenAttnArgsToMessageHubHook(Hook):
    """Route variable-length attention arguments through a ``MessageHub``.

    Before every training or validation iteration, ``cumulative_len`` and
    ``max_seqlen`` are popped out of ``data_batch['data']`` (so they are no
    longer part of the batch data) and published to the ``'varlen_attn_args'``
    MessageHub instance under the keys ``cumulative_len_rank_{rank}`` and
    ``max_seqlen_rank_{rank}``. After the iteration both keys are reset to
    ``None``.
    """

    # Name of the MessageHub instance the arguments are published to.
    _message_hub_name = 'varlen_attn_args'

    def _publish_varlen_attn_args(self, data_batch: DATA_BATCH) -> None:
        """Pop the varlen attention args from ``data_batch`` and publish them.

        Args:
            data_batch (dict): Data from dataloader. Must contain a ``'data'``
                entry holding ``cumulative_len`` and ``max_seqlen``; both keys
                are removed from it in place.

        Raises:
            ValueError: If ``data_batch`` is ``None`` or has no ``'data'``
                entry, or if ``cumulative_len`` does not contain exactly one
                element.
            KeyError: If ``cumulative_len`` or ``max_seqlen`` is missing from
                ``data_batch['data']``.
        """
        # Explicit checks instead of ``assert`` so validation survives ``-O``.
        if data_batch is None:
            raise ValueError('`data_batch` is required but got None.')
        if 'data' not in data_batch:
            raise ValueError("`data_batch` has no 'data' entry.")
        data = data_batch['data']

        rank = get_rank()
        message_hub = MessageHub.get_instance(self._message_hub_name)

        cumulative_len = data.pop('cumulative_len')
        # Only a single entry is supported. NOTE(review): presumably one entry
        # per sample, i.e. batch size 1 per rank -- confirm with the collator.
        if len(cumulative_len) != 1:
            raise ValueError(
                '`cumulative_len` must contain exactly one element, got '
                f'{len(cumulative_len)}.')
        # Move to the current CUDA device before publishing (assumes the
        # element is a torch.Tensor -- verify against the collator).
        message_hub.update_info(f'cumulative_len_rank_{rank}',
                                cumulative_len[0].cuda())

        max_seqlen = data.pop('max_seqlen')
        message_hub.update_info(f'max_seqlen_rank_{rank}', max_seqlen)

    def _clear_varlen_attn_args(self) -> None:
        """Reset this rank's published varlen attention args to ``None``."""
        rank = get_rank()
        message_hub = MessageHub.get_instance(self._message_hub_name)
        message_hub.update_info(f'cumulative_len_rank_{rank}', None)
        message_hub.update_info(f'max_seqlen_rank_{rank}', None)

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: DATA_BATCH = None) -> None:
        """Publish the varlen attention args before a training iteration.

        Args:
            runner (Runner): The runner of the training process (unused).
            batch_idx (int): The index of the current batch (unused).
            data_batch (dict): Data from dataloader; ``cumulative_len`` and
                ``max_seqlen`` are popped from ``data_batch['data']``.
        """
        self._publish_varlen_attn_args(data_batch)

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Clear the published varlen attention args after a training
        iteration.

        Args:
            runner (Runner): The runner of the training process (unused).
            batch_idx (int): The index of the current batch (unused).
            data_batch (dict or tuple or list, optional): Data from
                dataloader (unused).
            outputs (dict, optional): Outputs from model (unused).
        """
        self._clear_varlen_attn_args()

    def before_val_iter(self,
                        runner,
                        batch_idx: int,
                        data_batch: DATA_BATCH = None) -> None:
        """Publish the varlen attention args before a validation iteration.

        Args:
            runner (Runner): The runner of the validation process (unused).
            batch_idx (int): The index of the current batch in the val loop
                (unused).
            data_batch (dict): Data from dataloader; ``cumulative_len`` and
                ``max_seqlen`` are popped from ``data_batch['data']``.
        """
        self._publish_varlen_attn_args(data_batch)

    def after_val_iter(self,
                       runner,
                       batch_idx: int,
                       data_batch: DATA_BATCH = None,
                       outputs=None) -> None:
        """Clear the published varlen attention args after a validation
        iteration.

        Args:
            runner (Runner): The runner of the validation process (unused).
            batch_idx (int): The index of the current batch in the val loop
                (unused).
            data_batch (dict or tuple or list, optional): Data from
                dataloader (unused).
            outputs (Sequence, optional): Outputs from model (unused).
        """
        self._clear_varlen_attn_args()