Camil Ziane
init space
74b17e0
raw
history blame
1 kB
import os
import torch
import torch.nn as nn
class Connector(nn.Module):
    """Base wrapper around a connector sub-module stored in ``self._connector``.

    This base class leaves ``self._connector`` as ``None``; presumably a
    subclass assigns the actual module (TODO confirm against subclasses).
    Both ``load_model`` and ``forward`` require it to be set.
    """

    def __init__(self, config=None):
        # ``config`` is accepted for interface compatibility but unused here.
        super().__init__()
        self._connector = None

    @staticmethod
    def _extract_sub_state_dict(weights, keyword):
        """Return the entries of *weights* belonging to the ``keyword`` sub-module.

        Only keys containing ``keyword + '.'`` are kept, and each is renamed to
        everything after the first occurrence of that prefix.  Matching on the
        dotted prefix (not the bare keyword) avoids an IndexError on keys such
        as ``'_connector_norm.weight'`` that contain the keyword but no dot.
        """
        prefix = keyword + '.'
        return {k.split(prefix, 1)[1]: v for k, v in weights.items() if prefix in k}

    def load_model(self, **kwargs):
        """Optionally load pretrained connector weights, then freeze the connector.

        Args:
            **kwargs: ``pretrained_connector_path`` (optional) is a directory
                containing ``pytorch_model.bin``. Entries whose keys contain
                ``'_connector.'`` are loaded (strict ``load_state_dict``) into
                ``self._connector`` with that prefix stripped.

        After the optional load, every parameter of ``self._connector`` gets
        ``requires_grad = False``, whether or not weights were loaded.
        """
        pretrained_connector_path = kwargs.get('pretrained_connector_path', None)
        if pretrained_connector_path is not None:
            pretrained_connector_path = os.path.join(pretrained_connector_path, 'pytorch_model.bin')
            # NOTE(review): torch.load unpickles the file, so only load trusted
            # checkpoints (consider weights_only=True on torch >= 1.13).
            connector_weights = torch.load(pretrained_connector_path, map_location='cpu')
            self._connector.load_state_dict(
                self._extract_sub_state_dict(connector_weights, '_connector')
            )
            print(f'Loading connector from {pretrained_connector_path}...')
        # Freeze unconditionally; any later unfreezing is left to the caller.
        for p in self._connector.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Apply the wrapped connector module to ``x``."""
        return self._connector(x)