yinuozhang committed
Commit 071db43 • 1 Parent(s): 01a1f08
model
model.py CHANGED
@@ -14,7 +14,7 @@ import gc
 from torch.optim.lr_scheduler import _LRScheduler
 from transformers import EsmModel, PreTrainedModel
 from configuration import MetaLATTEConfig
-
+from urllib.parse import urljoin
 seed_everything(42)
 
 class GELU(nn.Module):
@@ -226,9 +226,19 @@ class MultitaskProteinModel(PreTrainedModel):
         config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path)
 
         model = cls(config)
-        state_dict = torch.load(f"{pretrained_model_name_or_path}/pytorch_model.bin", map_location=torch.device('cpu'))['state_dict']
-
+        #state_dict = torch.load(f"{pretrained_model_name_or_path}/pytorch_model.bin", map_location=torch.device('cpu'))['state_dict']
+        try:
+            state_dict_url = urljoin(f"https://huggingface.co/{pretrained_model_name_or_path}/resolve/main/", "pytorch_model.bin")
+            state_dict = torch.hub.load_state_dict_from_url(
+                state_dict_url,
+                map_location=torch.device('cpu')
+            )['state_dict']
+            model.load_state_dict(state_dict, strict=False)
+        except Exception as e:
+            raise RuntimeError(f"Error loading state_dict from {pretrained_model_name_or_path}/pytorch_model.bin: {e}")
+
         return model
+
 
     def forward(self, input_ids, attention_mask=None):
         outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
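The substance of the change: instead of torch.load on a local pytorch_model.bin, the checkpoint is now fetched from the Hugging Face Hub. urljoin assembles the /resolve/main/ URL for the file, torch.hub.load_state_dict_from_url downloads and caches it, and load_state_dict(..., strict=False) applies the weights while silently skipping any mismatched keys. A minimal sketch of that loading path, with some-org/MetaLATTE standing in as a hypothetical repo id:

from urllib.parse import urljoin

import torch

# Hypothetical repo id; substitute the model's actual Hub repository.
repo_id = "some-org/MetaLATTE"

# The trailing slash on the base URL matters: without it, urljoin would
# replace the last path segment ("main") instead of appending the filename.
url = urljoin(f"https://huggingface.co/{repo_id}/resolve/main/", "pytorch_model.bin")
# -> https://huggingface.co/some-org/MetaLATTE/resolve/main/pytorch_model.bin

# Downloads into torch hub's checkpoint cache on first use and reuses the
# cached file afterwards; map_location keeps all tensors on CPU.
checkpoint = torch.hub.load_state_dict_from_url(url, map_location=torch.device("cpu"))
state_dict = checkpoint["state_dict"]  # this .bin wraps the weights under 'state_dict'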
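With the remote load in place, from_pretrained needs only a repo id rather than a directory that already contains the checkpoint. A usage sketch under the same assumptions (hypothetical repo id, model.py importable as model):

import torch

from model import MultitaskProteinModel  # class defined in this repo's model.py

# Hypothetical repo id; replace with the actual Hub repository.
model = MultitaskProteinModel.from_pretrained("some-org/MetaLATTE")
model.eval()

# Dummy token ids; real inputs would come from the matching ESM tokenizer.
input_ids = torch.randint(0, 20, (1, 32))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask)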