Huhujingjing
commited on
Commit
·
1187aa4
1
Parent(s):
af2334e
Upload model
Browse files- modeling_transmxm.py +6 -6
modeling_transmxm.py
CHANGED
@@ -1209,7 +1209,7 @@ class TransmxmModel(PreTrainedModel):
|
|
1209 |
def __init__(self, config):
|
1210 |
super().__init__(config)
|
1211 |
|
1212 |
-
self.
|
1213 |
dim=config.dim,
|
1214 |
n_layer=config.n_layer,
|
1215 |
cutoff=config.cutoff,
|
@@ -1221,14 +1221,14 @@ class TransmxmModel(PreTrainedModel):
|
|
1221 |
smiles=config.smiles,
|
1222 |
)
|
1223 |
|
1224 |
-
self.
|
1225 |
self.dataset = None
|
1226 |
self.output = None
|
1227 |
self.data_loader = None
|
1228 |
self.pred_data = None
|
1229 |
|
1230 |
def forward(self, tensor):
|
1231 |
-
return self.
|
1232 |
|
1233 |
def SmilesProcessor(self, smiles):
|
1234 |
return self.process.get_data(smiles)
|
@@ -1242,8 +1242,8 @@ class TransmxmModel(PreTrainedModel):
|
|
1242 |
drop_last = kwargs.pop('drop_last', False)
|
1243 |
num_workers = kwargs.pop('num_workers', 0)
|
1244 |
|
1245 |
-
self.
|
1246 |
-
self.
|
1247 |
|
1248 |
self.dataset = self.process.get_data(smiles)
|
1249 |
self.output = ""
|
@@ -1264,7 +1264,7 @@ class TransmxmModel(PreTrainedModel):
|
|
1264 |
batch = batch.to(device)
|
1265 |
with torch.no_grad():
|
1266 |
self.pred_data['smiles'] += batch['smiles']
|
1267 |
-
self.pred_data['pred'] += self.
|
1268 |
|
1269 |
pred = torch.tensor(self.pred_data['pred']).reshape(-1)
|
1270 |
if device == 'cuda':
|
|
|
1209 |
def __init__(self, config):
|
1210 |
super().__init__(config)
|
1211 |
|
1212 |
+
self.backbone = TransMXMNet(
|
1213 |
dim=config.dim,
|
1214 |
n_layer=config.n_layer,
|
1215 |
cutoff=config.cutoff,
|
|
|
1221 |
smiles=config.smiles,
|
1222 |
)
|
1223 |
|
1224 |
+
self.model = None
|
1225 |
self.dataset = None
|
1226 |
self.output = None
|
1227 |
self.data_loader = None
|
1228 |
self.pred_data = None
|
1229 |
|
1230 |
def forward(self, tensor):
|
1231 |
+
return self.bhackbone.forward_features(tensor)
|
1232 |
|
1233 |
def SmilesProcessor(self, smiles):
|
1234 |
return self.process.get_data(smiles)
|
|
|
1242 |
drop_last = kwargs.pop('drop_last', False)
|
1243 |
num_workers = kwargs.pop('num_workers', 0)
|
1244 |
|
1245 |
+
self.model = AutoModel.from_pretrained("Huhujingjing/custom-transmxm", trust_remote_code=True).to(device)
|
1246 |
+
self.model.eval()
|
1247 |
|
1248 |
self.dataset = self.process.get_data(smiles)
|
1249 |
self.output = ""
|
|
|
1264 |
batch = batch.to(device)
|
1265 |
with torch.no_grad():
|
1266 |
self.pred_data['smiles'] += batch['smiles']
|
1267 |
+
self.pred_data['pred'] += self.model(batch).cpu().tolist()
|
1268 |
|
1269 |
pred = torch.tensor(self.pred_data['pred']).reshape(-1)
|
1270 |
if device == 'cuda':
|