gelnesr commited on
Commit
e640230
·
verified ·
1 Parent(s): 514b274

update device

Browse files
Files changed (1) hide show
  1. Dyna-1/model/model.py +1 -6
Dyna-1/model/model.py CHANGED
@@ -99,8 +99,7 @@ class ESM_model(nn.Module):
99
  self.method = method
100
  self.layer = layer
101
  if 'esm3' in self.method:
102
- #self.model = ESM3.from_pretrained("esm3_sm_open_v1").to(DEVICE,non_blocking=True).to(torch.float32)
103
- self.model = ESM3.from_pretrained("esm3_sm_open_v1").to(DEVICE,non_blocking=True).to(torch.float32)
104
  '''except GatedRepoError as e:
105
  print(f"No access to gated repository: {e}")
106
  except OSError as e:
@@ -109,10 +108,6 @@ class ESM_model(nn.Module):
109
  else:
110
  print(f"Other error occurred: {e}")'''
111
 
112
- self.n_layers = len(self.model.transformer.blocks)
113
- self.hidden_size = self.model.transformer.blocks[0].attn.d_model
114
- elif 'esmc' in self.method:
115
- self.model = ESMC.from_pretrained("esmc_300m").to(DEVICE,non_blocking=True).to(torch.float32)
116
  self.n_layers = len(self.model.transformer.blocks)
117
  self.hidden_size = self.model.transformer.blocks[0].attn.d_model
118
  elif self.method == 'esm2':
 
99
  self.method = method
100
  self.layer = layer
101
  if 'esm3' in self.method:
102
+ self.model = ESM3.from_pretrained("esm3_sm_open_v1").to(torch.float32)
 
103
  '''except GatedRepoError as e:
104
  print(f"No access to gated repository: {e}")
105
  except OSError as e:
 
108
  else:
109
  print(f"Other error occurred: {e}")'''
110
 
 
 
 
 
111
  self.n_layers = len(self.model.transformer.blocks)
112
  self.hidden_size = self.model.transformer.blocks[0].attn.d_model
113
  elif self.method == 'esm2':