# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Type

from mmpretrain.registry import MODELS


class ExtendModule:
    """Combine the base language model with an adapter.

    Instantiating this class does not return an ``ExtendModule``. Instead it
    builds (or takes) the base instance and mixes the adapter class into that
    instance's class. It then returns whatever the adapter's ``extend_init``
    returns for that instance.

    Args:
        base (object): Base module. Either any object that represents an
            instance of a language model, or a dict config that can be passed
            to ``MODELS.build`` to build the base module.
        adapter (dict): Config dict used to build the adapter. Its ``'type'``
            key names an adapter class registered in ``MODELS``. The remaining
            keys are forwarded as keyword arguments to the adapter's
            ``extend_init``. The caller's dict is not modified.

    Raises:
        KeyError: If ``adapter`` has no ``'type'`` key, or if the named
            adapter is not registered in ``MODELS``.
    """

    def __new__(cls, base: object, adapter: dict):
        if isinstance(base, dict):
            base = MODELS.build(base)

        # Pop from a shallow copy so the caller's config keeps its 'type'
        # key and can be reused (e.g. to build the model a second time).
        adapter = dict(adapter)
        adapter_type = adapter.pop('type')
        adapter_module = MODELS.get(adapter_type)
        if adapter_module is None:
            # Without this check, type() below would fail with an obscure
            # "bases must be types" TypeError.
            raise KeyError(
                f'{adapter_type!r} is not registered in the MODELS registry.')

        cls.extend_instance(base, adapter_module)
        # __new__ returns a non-ExtendModule object, so __init__ is not run.
        return adapter_module.extend_init(base, **adapter)

    @classmethod
    def extend_instance(cls, base: object, mixin: Type[Any]):
        """Apply a mixin to a class instance after creation.

        The instance's class is replaced by a newly created class with the
        same name whose bases are ``(mixin, original_class)``. Only ``base``
        is re-classed; other instances of the original class are unaffected.

        Args:
            base (object): Base module instance.
            mixin (Type[Any]): Adapter class type to mix in.
        """
        base_cls = base.__class__
        base_cls_name = base_cls.__name__
        # The mixin must come first in the bases so that its methods (e.g.
        # forward()) take precedence over the base class's in the MRO.
        base.__class__ = type(base_cls_name, (mixin, base_cls), {})


def getattr_recursive(obj, att):
    """Return a nested attribute of ``obj``.

    Example:
        ``getattr_recursive(obj, 'a.b.c')`` is equivalent to ``obj.a.b.c``.
        An empty ``att`` returns ``obj`` itself.
    """
    if att == '':
        return obj
    i = att.find('.')
    if i < 0:
        return getattr(obj, att)
    return getattr_recursive(getattr(obj, att[:i]), att[i + 1:])


def setattr_recursive(obj, att, val):
    """Set a nested attribute of ``obj``.

    Example:
        ``setattr_recursive(obj, 'a.b.c', val)`` is equivalent to
        ``obj.a.b.c = val``.
    """
    # Split off the final attribute name; everything before the last '.'
    # is the path to the object that owns it (empty when there is no '.').
    parent_path, _, name = att.rpartition('.')
    if parent_path:
        obj = getattr_recursive(obj, parent_path)
    setattr(obj, name, val)