import os

import torch
import torch.nn as nn


class Connector(nn.Module):
    """Base wrapper around a connector module held in ``self._connector``.

    This base class sets ``self._connector`` to ``None``. It is presumably
    assigned an ``nn.Module`` by subclasses (TODO confirm) before
    ``load_model`` or ``forward`` is called. Otherwise those methods fail on
    the ``None`` attribute.
    """

    def __init__(self, config=None):
        # ``config`` is accepted for interface compatibility but unused here.
        super().__init__()
        self._connector = None

    @staticmethod
    def _get_w(weights, keyword):
        """Return the entries of *weights* that live under ``keyword + '.'``, with that prefix stripped.

        Only keys containing ``keyword + '.'`` are kept. Everything up to and
        including the FIRST occurrence of that prefix is removed, so
        ``'model._connector.0.weight'`` becomes ``'0.weight'``. Keys that merely
        contain *keyword* as a substring (e.g. ``'_connector_norm.weight'``)
        are skipped instead of raising ``IndexError``.

        Args:
            weights: Mapping of parameter names to values (e.g. a state dict).
            keyword: Submodule name whose entries should be extracted.

        Returns:
            A new dict with the stripped keys mapped to the original values.
        """
        prefix = keyword + '.'
        # maxsplit=1 keeps any later occurrence of the prefix inside the remainder.
        return {k.split(prefix, 1)[1]: v for k, v in weights.items() if prefix in k}

    def load_model(self, **kwargs):
        """Optionally load connector weights from disk, then freeze the connector.

        Keyword Args:
            pretrained_connector_path: Directory containing ``pytorch_model.bin``.
                If given, the file is loaded onto the CPU. Its entries under
                ``'_connector.'`` are loaded (strictly) into ``self._connector``.
                If omitted or ``None``, no weights are loaded.

        All parameters of ``self._connector`` get ``requires_grad = False``
        whether or not weights were loaded. Callers that want to train the
        connector must re-enable gradients afterwards.
        """
        pretrained_connector_path = kwargs.get('pretrained_connector_path', None)
        if pretrained_connector_path is not None:
            pretrained_connector_path = os.path.join(pretrained_connector_path, 'pytorch_model.bin')
            # NOTE(review): torch.load unpickles the file, so only load trusted
            # checkpoints (consider weights_only=True where the torch version supports it).
            connector_weights = torch.load(pretrained_connector_path, map_location='cpu')
            self._connector.load_state_dict(self._get_w(connector_weights, '_connector'))
            print(f'Loading connector from {pretrained_connector_path}...')

        for p in self._connector.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Apply the wrapped connector module to ``x``."""
        return self._connector(x)