File size: 1,041 Bytes
cc0dd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Copyright (c) OpenMMLab. All rights reserved
import warnings

from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmpretrain.models import BaseRetriever
from mmpretrain.registry import HOOKS


@HOOKS.register_module()
class PrepareProtoBeforeValLoopHook(Hook):
    """The hook to prepare the prototype in retrievers.

    Since the encoders of the retriever change during training, the
    prototype changes accordingly. So the `prototype_vecs` needs to be
    regenerated before the validation loop.
    """

    def before_val(self, runner) -> None:
        """Regenerate the retriever's prototype before validation starts.

        If ``runner.model`` is a model wrapper, the wrapped ``module`` is
        used instead. When the resulting model is a ``BaseRetriever`` that
        exposes ``prepare_prototype``, that method is called. For any other
        model type a warning is emitted and nothing else happens.

        Args:
            runner: The runner whose ``model`` attribute is inspected.
        """
        model = runner.model
        # Look through the wrapper so the type check below sees the
        # underlying model rather than the wrapper class.
        if is_model_wrapper(model):
            model = model.module

        if not isinstance(model, BaseRetriever):
            # Report this hook's real class name; the previous message
            # referred to a hook name that does not match this class.
            warnings.warn(
                'Only the `mmpretrain.models.retrievers.BaseRetriever` '
                f'can execute `{type(self).__name__}`, but got '
                f'`{type(model)}`')
            return

        # Retrievers without a `prepare_prototype` attribute are skipped
        # silently, exactly as before.
        if hasattr(model, 'prepare_prototype'):
            model.prepare_prototype()