Spaces:
Runtime error
Runtime error
File size: 1,041 Bytes
cc0dd3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
# Copyright (c) OpenMMLab. All rights reserved
import warnings
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmpretrain.models import BaseRetriever
from mmpretrain.registry import HOOKS
@HOOKS.register_module()
class PrepareProtoBeforeValLoopHook(Hook):
    """Hook that prepares the prototype of a retriever before validation.

    Since the encoders of the retriever change during training, the
    prototype changes accordingly. So the ``prototype_vecs`` need to be
    regenerated before the validation loop.
    """

    def before_val(self, runner) -> None:
        """Regenerate the retriever prototype, or warn if not applicable.

        Args:
            runner: The runner driving the loop. Only its ``model``
                attribute is read here.
        """
        model = runner.model
        # Wrapped models (as detected by ``is_model_wrapper``) expose the
        # underlying model through ``.module``; the type check below has to
        # look at that inner model.
        if is_model_wrapper(model):
            model = model.module

        if isinstance(model, BaseRetriever):
            # Retrievers that do not define ``prepare_prototype`` are
            # skipped silently.
            if hasattr(model, 'prepare_prototype'):
                model.prepare_prototype()
        else:
            # Name this hook as it is actually registered so that the
            # warning can be traced back to the config entry causing it.
            warnings.warn(
                'Only the `mmpretrain.models.retrievers.BaseRetriever` '
                'can execute `PrepareProtoBeforeValLoopHook`, but got '
                f'`{type(model)}`')
|