File size: 1,002 Bytes
74b17e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import os
import torch
import torch.nn as nn
class Connector(nn.Module):
    """Wrapper module that delegates to a connector network in ``_connector``.

    ``__init__`` leaves ``_connector`` as ``None``; presumably subclasses assign
    the actual ``nn.Module`` before ``load_model`` or ``forward`` is called
    (verify against the subclasses).
    """

    def __init__(self, config=None):
        """Create the wrapper with no connector attached.

        Args:
            config: Not used by this base class.
        """
        super().__init__()
        self._connector = None

    def load_model(self, **kwargs):
        """Optionally load pretrained connector weights, then freeze the connector.

        Keyword Args:
            pretrained_connector_path: Directory that contains
                ``pytorch_model.bin``. When missing or ``None``, no weights
                are loaded.

        Raises:
            AttributeError: If ``_connector`` is still ``None``.
        """
        pretrained_connector_path = kwargs.get('pretrained_connector_path', None)
        if pretrained_connector_path is not None:
            pretrained_connector_path = os.path.join(pretrained_connector_path, 'pytorch_model.bin')
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted checkpoints (consider weights_only=True where the
            # installed torch version supports it).
            connector_weights = torch.load(pretrained_connector_path, map_location='cpu')

            # Keep only entries that live under the "_connector." namespace and
            # strip everything up to and including its first occurrence.
            # Matching on the dotted prefix (not the bare keyword) avoids an
            # IndexError on keys that merely contain "_connector", and
            # maxsplit=1 keeps the remainder intact if the prefix repeats.
            prefix = '_connector' + '.'
            connector_state = {
                key.split(prefix, 1)[1]: value
                for key, value in connector_weights.items()
                if prefix in key
            }
            self._connector.load_state_dict(connector_state)
            print(f'Loading connector from {pretrained_connector_path}...')

        # NOTE(review): the freeze is applied whether or not a checkpoint was
        # loaded; confirm this matches the intended placement.
        for p in self._connector.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Apply the wrapped connector to ``x`` and return its output."""
        return self._connector(x)
|