Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import torch.nn as nn | |
class Connector(nn.Module):
    """Wrapper module that delegates to an inner module held in ``_connector``.

    ``__init__`` only installs a ``None`` placeholder for ``_connector``;
    both ``load_model`` and ``forward`` dereference it, so a real module has
    to be assigned first (presumably by a subclass -- confirm against the
    rest of the project).
    """

    def __init__(self, config=None):
        super().__init__()
        # Placeholder: must be replaced with an actual nn.Module before
        # load_model() or forward() is called.
        self._connector = None

    @staticmethod
    def _select_submodule_weights(weights, keyword):
        """Return the entries of *weights* that live under ``keyword.``.

        Each selected key is rewritten to the part that follows the first
        occurrence of ``keyword + '.'``; keys that do not contain that
        dotted prefix are dropped.
        """
        prefix = keyword + '.'
        selected = {}
        for key, value in weights.items():
            # partition() splits at the FIRST occurrence only, so nested
            # names that repeat the keyword keep their full remainder, and
            # keys that merely contain the keyword without the dot are
            # skipped instead of raising IndexError.
            _, found, remainder = key.partition(prefix)
            if found:
                selected[remainder] = value
        return selected

    def load_model(self, **kwargs):
        """Optionally load weights into ``_connector``, then freeze it.

        Keyword Args:
            pretrained_connector_path: Directory that contains a
                ``pytorch_model.bin`` checkpoint. When omitted or ``None``
                no weights are loaded.

        In every case all parameters of ``_connector`` end up with
        ``requires_grad = False``.
        """
        pretrained_connector_path = kwargs.get('pretrained_connector_path', None)
        if pretrained_connector_path is not None:
            pretrained_connector_path = os.path.join(pretrained_connector_path, 'pytorch_model.bin')
            # NOTE(security): torch.load is pickle-based; only point this at
            # trusted checkpoints (consider ``weights_only=True`` where the
            # installed torch version supports it).
            connector_weights = torch.load(pretrained_connector_path, map_location='cpu')
            self._connector.load_state_dict(
                self._select_submodule_weights(connector_weights, '_connector')
            )
            print(f'Loading connector from {pretrained_connector_path}...')

        # Freeze the connector regardless of whether weights were loaded.
        for p in self._connector.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Apply the wrapped ``_connector`` module to *x* and return its output."""
        return self._connector(x)