Huhujingjing commited on
Commit
1187aa4
·
1 Parent(s): af2334e

Upload model

Browse files
Files changed (1) hide show
  1. modeling_transmxm.py +6 -6
modeling_transmxm.py CHANGED
@@ -1209,7 +1209,7 @@ class TransmxmModel(PreTrainedModel):
1209
  def __init__(self, config):
1210
  super().__init__(config)
1211
 
1212
- self.model = TransMXMNet(
1213
  dim=config.dim,
1214
  n_layer=config.n_layer,
1215
  cutoff=config.cutoff,
@@ -1221,14 +1221,14 @@ class TransmxmModel(PreTrainedModel):
1221
  smiles=config.smiles,
1222
  )
1223
 
1224
- self.mxm_model = None
1225
  self.dataset = None
1226
  self.output = None
1227
  self.data_loader = None
1228
  self.pred_data = None
1229
 
1230
  def forward(self, tensor):
1231
- return self.model.forward_features(tensor)
1232
 
1233
  def SmilesProcessor(self, smiles):
1234
  return self.process.get_data(smiles)
@@ -1242,8 +1242,8 @@ class TransmxmModel(PreTrainedModel):
1242
  drop_last = kwargs.pop('drop_last', False)
1243
  num_workers = kwargs.pop('num_workers', 0)
1244
 
1245
- self.mxm_model = AutoModel.from_pretrained("Huhujingjing/custom-transmxm", trust_remote_code=True).to(device)
1246
- self.mxm_model.eval()
1247
 
1248
  self.dataset = self.process.get_data(smiles)
1249
  self.output = ""
@@ -1264,7 +1264,7 @@ class TransmxmModel(PreTrainedModel):
1264
  batch = batch.to(device)
1265
  with torch.no_grad():
1266
  self.pred_data['smiles'] += batch['smiles']
1267
- self.pred_data['pred'] += self.gcn_model(batch).cpu().tolist()
1268
 
1269
  pred = torch.tensor(self.pred_data['pred']).reshape(-1)
1270
  if device == 'cuda':
 
1209
  def __init__(self, config):
1210
  super().__init__(config)
1211
 
1212
+ self.backbone = TransMXMNet(
1213
  dim=config.dim,
1214
  n_layer=config.n_layer,
1215
  cutoff=config.cutoff,
 
1221
  smiles=config.smiles,
1222
  )
1223
 
1224
+ self.model = None
1225
  self.dataset = None
1226
  self.output = None
1227
  self.data_loader = None
1228
  self.pred_data = None
1229
 
1230
  def forward(self, tensor):
1231
+ return self.bhackbone.forward_features(tensor)
1232
 
1233
  def SmilesProcessor(self, smiles):
1234
  return self.process.get_data(smiles)
 
1242
  drop_last = kwargs.pop('drop_last', False)
1243
  num_workers = kwargs.pop('num_workers', 0)
1244
 
1245
+ self.model = AutoModel.from_pretrained("Huhujingjing/custom-transmxm", trust_remote_code=True).to(device)
1246
+ self.model.eval()
1247
 
1248
  self.dataset = self.process.get_data(smiles)
1249
  self.output = ""
 
1264
  batch = batch.to(device)
1265
  with torch.no_grad():
1266
  self.pred_data['smiles'] += batch['smiles']
1267
+ self.pred_data['pred'] += self.model(batch).cpu().tolist()
1268
 
1269
  pred = torch.tensor(self.pred_data['pred']).reshape(-1)
1270
  if device == 'cuda':