update device
Browse files- Dyna-1/model/model.py +1 -6
Dyna-1/model/model.py
CHANGED
@@ -99,8 +99,7 @@ class ESM_model(nn.Module):
|
|
99 |
self.method = method
|
100 |
self.layer = layer
|
101 |
if 'esm3' in self.method:
|
102 |
-
|
103 |
-
self.model = ESM3.from_pretrained("esm3_sm_open_v1").to(DEVICE,non_blocking=True).to(torch.float32)
|
104 |
'''except GatedRepoError as e:
|
105 |
print(f"No access to gated repository: {e}")
|
106 |
except OSError as e:
|
@@ -109,10 +108,6 @@ class ESM_model(nn.Module):
|
|
109 |
else:
|
110 |
print(f"Other error occurred: {e}")'''
|
111 |
|
112 |
-
self.n_layers = len(self.model.transformer.blocks)
|
113 |
-
self.hidden_size = self.model.transformer.blocks[0].attn.d_model
|
114 |
-
elif 'esmc' in self.method:
|
115 |
-
self.model = ESMC.from_pretrained("esmc_300m").to(DEVICE,non_blocking=True).to(torch.float32)
|
116 |
self.n_layers = len(self.model.transformer.blocks)
|
117 |
self.hidden_size = self.model.transformer.blocks[0].attn.d_model
|
118 |
elif self.method == 'esm2':
|
|
|
99 |
self.method = method
|
100 |
self.layer = layer
|
101 |
if 'esm3' in self.method:
|
102 |
+
self.model = ESM3.from_pretrained("esm3_sm_open_v1").to(torch.float32)
|
|
|
103 |
'''except GatedRepoError as e:
|
104 |
print(f"No access to gated repository: {e}")
|
105 |
except OSError as e:
|
|
|
108 |
else:
|
109 |
print(f"Other error occurred: {e}")'''
|
110 |
|
|
|
|
|
|
|
|
|
111 |
self.n_layers = len(self.model.transformer.blocks)
|
112 |
self.hidden_size = self.model.transformer.blocks[0].attn.d_model
|
113 |
elif self.method == 'esm2':
|