gyigit commited on
Commit
4c9e6d9
·
1 Parent(s): 4d9e86e
app.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from inference import Inference
3
+ import PIL
4
+ from PIL import Image
5
+ import pandas as pd
6
+ import random
7
+ from rdkit import Chem
8
+ from rdkit.Chem import Draw
9
+ from rdkit.Chem.Draw import IPythonConsole
10
+ import shutil
11
+ import os
12
+ import time
13
+
14
+ class DrugGENConfig:
15
+ # Inference configuration
16
+ submodel='DrugGEN'
17
+ inference_model="experiments/models/DrugGEN/"
18
+ sample_num=100
19
+ disable_correction=False # corresponds to correct=True in old config
20
+
21
+ # Data configuration
22
+ inf_smiles='data/chembl_test.smi' # corresponds to inf_raw_file in old config
23
+ train_smiles='data/chembl_train.smi'
24
+ train_drug_smiles='data/akt1_train.smi'
25
+ inf_batch_size=1
26
+ mol_data_dir='data'
27
+ features=False
28
+
29
+ # Model configuration
30
+ act='relu'
31
+ max_atom=45
32
+ dim=128
33
+ depth=1
34
+ heads=8
35
+ mlp_ratio=3
36
+ dropout=0.
37
+
38
+ # Seed configuration
39
+ set_seed=True
40
+ seed=10
41
+
42
+
43
+ class DrugGENAKT1Config(DrugGENConfig):
44
+ submodel='DrugGEN'
45
+ inference_model="experiments/models/DrugGEN-AKT1/"
46
+ train_drug_smiles='data/akt1_train.smi'
47
+ max_atom=45
48
+
49
+
50
+ class DrugGENCDK2Config(DrugGENConfig):
51
+ submodel='DrugGEN'
52
+ inference_model="experiments/models/DrugGEN-CDK2/"
53
+ train_drug_smiles='data/cdk2_train.smi'
54
+ max_atom=38
55
+
56
+
57
+ class NoTargetConfig(DrugGENConfig):
58
+ submodel="NoTarget"
59
+ inference_model="experiments/models/NoTarget/"
60
+ train_drug_smiles='data/chembl_train.smi' # No specific target, use general ChEMBL data
61
+
62
+
63
+ model_configs = {
64
+ "DrugGEN-AKT1": DrugGENAKT1Config(),
65
+ "DrugGEN-CDK2": DrugGENCDK2Config(),
66
+ "DrugGEN-NoTarget": NoTargetConfig(),
67
+ }
68
+
69
+
70
+
71
+ def function(model_name: str, num_molecules: int, seed_num: int) -> tuple[PIL.Image, pd.DataFrame, str]:
72
+ '''
73
+ Returns:
74
+ image, score_df, file path
75
+ '''
76
+ if model_name == "DrugGEN-NoTarget":
77
+ model_name = "NoTarget"
78
+
79
+ config = model_configs[model_name]
80
+ config.sample_num = num_molecules
81
+
82
+ if config.sample_num > 250:
83
+ raise gr.Error("You have requested to generate more than the allowed limit of 250 molecules. Please reduce your request to 250 or fewer.")
84
+
85
+ if seed_num is None or seed_num.strip() == "":
86
+ config.seed = random.randint(0, 10000)
87
+ else:
88
+ try:
89
+ config.seed = int(seed_num)
90
+ except ValueError:
91
+ raise gr.Error("The seed must be an integer value!")
92
+
93
+
94
+ inferer = Inference(config)
95
+ start_time = time.time()
96
+ scores = inferer.inference() # create scores_df out of this
97
+ et = time.time() - start_time
98
+
99
+ score_df = pd.DataFrame({
100
+ "Runtime (seconds)": [et],
101
+ "Validity": [scores["validity"].iloc[0]],
102
+ "Uniqueness": [scores["uniqueness"].iloc[0]],
103
+ "Novelty (Train)": [scores["novelty"].iloc[0]],
104
+ "Novelty (Test)": [scores["novelty_test"].iloc[0]],
105
+ "Drug Novelty": [scores["drug_novelty"].iloc[0]],
106
+ "Max Length": [scores["max_len"].iloc[0]],
107
+ "Mean Atom Type": [scores["mean_atom_type"].iloc[0]],
108
+ "SNN ChEMBL": [scores["snn_chembl"].iloc[0]],
109
+ "SNN Drug": [scores["snn_drug"].iloc[0]],
110
+ "Internal Diversity": [scores["IntDiv"].iloc[0]],
111
+ "QED": [scores["qed"].iloc[0]],
112
+ "SA Score": [scores["sa"].iloc[0]]
113
+ })
114
+
115
+ output_file_path = f'experiments/inference/{model_name}/inference_drugs.txt'
116
+
117
+ new_path = f'{model_name}_denovo_mols.smi'
118
+ os.rename(output_file_path, new_path)
119
+
120
+ with open(new_path) as f:
121
+ inference_drugs = f.read()
122
+
123
+ generated_molecule_list = inference_drugs.split("\n")[:-1]
124
+
125
+ rng = random.Random(config.seed)
126
+ if num_molecules > 12:
127
+ selected_molecules = rng.choices(generated_molecule_list, k=12)
128
+ else:
129
+ selected_molecules = generated_molecule_list
130
+
131
+ selected_molecules = [Chem.MolFromSmiles(mol) for mol in selected_molecules if Chem.MolFromSmiles(mol) is not None]
132
+
133
+ drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
134
+ drawOptions.prepareMolsBeforeDrawing = False
135
+ drawOptions.bondLineWidth = 0.5
136
+
137
+ molecule_image = Draw.MolsToGridImage(
138
+ selected_molecules,
139
+ molsPerRow=3,
140
+ subImgSize=(400, 400),
141
+ maxMols=len(selected_molecules),
142
+ # legends=None,
143
+ returnPNG=False,
144
+ drawOptions=drawOptions,
145
+ highlightAtomLists=None,
146
+ highlightBondLists=None,
147
+ )
148
+
149
+ return molecule_image, score_df, new_path
150
+
151
+
152
+
153
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
154
+ with gr.Row():
155
+ with gr.Column(scale=1):
156
+ gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
157
+ with gr.Row():
158
+ gr.Markdown("[![arXiv](https://img.shields.io/badge/arXiv-2302.07868-b31b1b.svg)](https://arxiv.org/abs/2302.07868)")
159
+ gr.Markdown("[![github-repository](https://img.shields.io/badge/GitHub-black?logo=github)](https://github.com/HUBioDataLab/DrugGEN)")
160
+
161
+ with gr.Accordion("About DrugGEN Models", open=False):
162
+ gr.Markdown("""
163
+ ## Model Variations
164
+
165
+ ### DrugGEN-AKT1
166
+ This model is designed to generate molecules targeting the human AKT1 protein (UniProt ID: P31749), a serine/threonine-protein kinase that plays a key role in regulating cell survival, metabolism, and growth. AKT1 is a significant target in cancer therapy, particularly for breast, colorectal, and ovarian cancers.
167
+
168
+ The model learns from:
169
+ - General drug-like molecules from ChEMBL database
170
+ - Known AKT1 inhibitors
171
+ - Maximum atom count: 45
172
+
173
+ ### DrugGEN-CDK2
174
+ This model targets the human CDK2 protein (UniProt ID: P24941), a cyclin-dependent kinase involved in cell cycle regulation. CDK2 inhibitors are being investigated for treating various cancers, particularly those with dysregulated cell cycle control.
175
+
176
+ The model learns from:
177
+ - General drug-like molecules from ChEMBL database
178
+ - Known CDK2 inhibitors
179
+ - Maximum atom count: 38
180
+
181
+ ### DrugGEN-NoTarget
182
+ This is a general-purpose model that generates diverse drug-like molecules without targeting a specific protein. It's useful for:
183
+ - Exploring chemical space
184
+ - Generating diverse scaffolds
185
+ - Creating molecules with drug-like properties
186
+
187
+ ## How It Works
188
+ DrugGEN uses a graph-based generative adversarial network (GAN) architecture where:
189
+ 1. The generator creates molecular graphs
190
+ 2. The discriminator evaluates them against real molecules
191
+ 3. The model learns to generate increasingly realistic and target-specific molecules
192
+
193
+ For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
194
+ """)
195
+
196
+ with gr.Accordion("Understanding the Metrics", open=False):
197
+ gr.Markdown("""
198
+ ## Evaluation Metrics
199
+
200
+ ### Basic Metrics
201
+ - **Validity**: Percentage of generated molecules that are chemically valid
202
+ - **Uniqueness**: Percentage of unique molecules among valid ones
203
+ - **Runtime**: Time taken to generate the requested molecules
204
+
205
+ ### Novelty Metrics
206
+ - **Novelty (Train)**: Percentage of molecules not found in the training set
207
+ - **Novelty (Test)**: Percentage of molecules not found in the test set
208
+ - **Drug Novelty**: Percentage of molecules not found in known drugs
209
+
210
+ ### Structural Metrics
211
+ - **Max Length**: Maximum component length in the generated molecules
212
+ - **Mean Atom Type**: Average distribution of atom types
213
+ - **Internal Diversity**: Diversity within the generated set (higher is more diverse)
214
+
215
+ ### Drug-likeness Metrics
216
+ - **QED (Quantitative Estimate of Drug-likeness)**: Score from 0-1 measuring how drug-like a molecule is (higher is better)
217
+ - **SA Score (Synthetic Accessibility)**: Score from 1-10 indicating ease of synthesis (lower is easier)
218
+
219
+ ### Similarity Metrics
220
+ - **SNN ChEMBL**: Similarity to ChEMBL molecules (higher means more similar to known drug-like compounds)
221
+ - **SNN Drug**: Similarity to known drugs (higher means more similar to approved drugs)
222
+ """)
223
+
224
+ model_name = gr.Radio(
225
+ choices=("DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"),
226
+ value="DrugGEN-AKT1",
227
+ label="Select Target Model",
228
+ info="Choose which protein target or general model to use for molecule generation"
229
+ )
230
+
231
+ num_molecules = gr.Slider(
232
+ minimum=10,
233
+ maximum=250,
234
+ value=100,
235
+ step=10,
236
+ label="Number of Molecules to Generate",
237
+ info="This space runs on a CPU, which may result in slower performance. Generating 200 molecules takes approximately 6 minutes. Therefore, We set a 250-molecule cap. On a GPU, the model can generate 10,000 molecules in the same amount of time. Please check our GitHub repo for running our models on GPU.""
238
+ )
239
+
240
+ seed_num = gr.Textbox(
241
+ label="Random Seed (Optional)",
242
+ value="",
243
+ info="Set a specific seed for reproducible results, or leave empty for random generation"
244
+ )
245
+
246
+ submit_button = gr.Button(
247
+ value="Generate Molecules",
248
+ variant="primary",
249
+ size="lg"
250
+ )
251
+
252
+ with gr.Column(scale=2):
253
+ with gr.Tabs():
254
+ with gr.TabItem("Generated Molecules"):
255
+ image_output = gr.Image(
256
+ label="Sample of Generated Molecules",
257
+ elem_id="molecule_display"
258
+ )
259
+ file_download = gr.File(
260
+ label="Download All Generated Molecules (SMILES format)",
261
+ )
262
+
263
+ with gr.TabItem("Performance Metrics"):
264
+ scores_df = gr.Dataframe(
265
+ label="Model Performance Metrics",
266
+ headers=["Runtime (seconds)", "Validity", "Uniqueness", "Novelty (Train)", "Novelty (Test)",
267
+ "Drug Novelty", "Max Length", "Mean Atom Type", "SNN ChEMBL", "SNN Drug",
268
+ "Internal Diversity", "QED", "SA Score"]
269
+ )
270
+
271
+ with gr.Accordion("Generation Settings", open=False):
272
+ gr.Markdown("""
273
+ ## Technical Details
274
+
275
+ - This demo runs on CPU which limits generation speed
276
+ - Generating 200 molecules takes approximately 6 minutes
277
+ - For faster generation or larger batches, run the model on GPU using our GitHub repository
278
+ - The model uses a graph-based representation of molecules
279
+ - Maximum atom count varies by model (AKT1: 45, CDK2: 38)
280
+ """)
281
+
282
+ gr.Markdown("### Created by the HU BioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")
283
+
284
+ submit_button.click(function, inputs=[model_name, num_molecules, seed_num], outputs=[image_output, scores_df, file_download], api_name="inference")
285
+ #demo.queue(concurrency_count=1)
286
+ demo.queue()
287
+ demo.launch()
288
+
inference.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+ import random
5
+ import pickle
6
+ import argparse
7
+ import os.path as osp
8
+
9
+ import torch
10
+ import torch.utils.data
11
+ from torch_geometric.loader import DataLoader
12
+
13
+ import pandas as pd
14
+ from tqdm import tqdm
15
+
16
+ from rdkit import RDLogger, Chem
17
+ from rdkit.Chem import QED, RDConfig
18
+
19
+ sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
20
+ import sascorer
21
+
22
+ from src.util.utils import *
23
+ from src.model.models import Generator
24
+ from src.data.dataset import DruggenDataset
25
+ from src.data.utils import get_encoders_decoders, load_molecules
26
+ from src.model.loss import generator_loss
27
+ from src.util.smiles_cor import smi_correct
28
+
29
+
30
+ class Inference(object):
31
+ """Inference class for DrugGEN."""
32
+
33
+ def __init__(self, config):
34
+ if config.set_seed:
35
+ np.random.seed(config.seed)
36
+ random.seed(config.seed)
37
+ torch.manual_seed(config.seed)
38
+ torch.cuda.manual_seed_all(config.seed)
39
+
40
+ torch.backends.cudnn.deterministic = True
41
+ torch.backends.cudnn.benchmark = False
42
+
43
+ os.environ["PYTHONHASHSEED"] = str(config.seed)
44
+
45
+ print(f'Using seed {config.seed}')
46
+
47
+ self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
48
+
49
+ # Initialize configurations
50
+ self.submodel = config.submodel
51
+ self.inference_model = config.inference_model
52
+ self.sample_num = config.sample_num
53
+ self.disable_correction = config.disable_correction
54
+
55
+ # Data loader.
56
+ self.inf_smiles = config.inf_smiles # SMILES containing text file for first dataset.
57
+ # Write the full path to file.
58
+
59
+ inf_smiles_basename = osp.basename(self.inf_smiles)
60
+
61
+ # Get the base name without extension and add max_atom to it
62
+ self.max_atom = config.max_atom # Model is based on one-shot generation.
63
+ inf_smiles_base = os.path.splitext(inf_smiles_basename)[0]
64
+
65
+ # Change extension from .smi to .pt and add max_atom to the filename
66
+ self.inf_dataset_file = f"{inf_smiles_base}{self.max_atom}.pt"
67
+
68
+ self.inf_batch_size = config.inf_batch_size
69
+ self.train_smiles = config.train_smiles
70
+ self.train_drug_smiles = config.train_drug_smiles
71
+ self.mol_data_dir = config.mol_data_dir # Directory where the dataset files are stored.
72
+ self.dataset_name = self.inf_dataset_file.split(".")[0]
73
+ self.features = config.features # Small model uses atom types as node features. (Boolean, False uses atom types only.)
74
+ # Additional node features can be added. Please check new_dataloarder.py Line 102.
75
+
76
+ # Get atom and bond encoders/decoders
77
+ self.atom_encoder, self.atom_decoder, self.bond_encoder, self.bond_decoder = get_encoders_decoders(
78
+ self.train_smiles,
79
+ self.train_drug_smiles,
80
+ self.max_atom
81
+ )
82
+
83
+ self.inf_dataset = DruggenDataset(self.mol_data_dir,
84
+ self.inf_dataset_file,
85
+ self.inf_smiles,
86
+ self.max_atom,
87
+ self.features,
88
+ atom_encoder=self.atom_encoder,
89
+ atom_decoder=self.atom_decoder,
90
+ bond_encoder=self.bond_encoder,
91
+ bond_decoder=self.bond_decoder)
92
+
93
+ self.inf_loader = DataLoader(self.inf_dataset,
94
+ shuffle=True,
95
+ batch_size=self.inf_batch_size,
96
+ drop_last=True) # PyG dataloader for the first GAN.
97
+
98
+ self.m_dim = len(self.atom_decoder) if not self.features else int(self.inf_loader.dataset[0].x.shape[1]) # Atom type dimension.
99
+ self.b_dim = len(self.bond_decoder) # Bond type dimension.
100
+ self.vertexes = int(self.inf_loader.dataset[0].x.shape[0]) # Number of nodes in the graph.
101
+
102
+ # Model configurations.
103
+ self.act = config.act
104
+ self.dim = config.dim
105
+ self.depth = config.depth
106
+ self.heads = config.heads
107
+ self.mlp_ratio = config.mlp_ratio
108
+ self.dropout = config.dropout
109
+
110
+ self.build_model()
111
+
112
+ def build_model(self):
113
+ """Create generators and discriminators."""
114
+ self.G = Generator(self.act,
115
+ self.vertexes,
116
+ self.b_dim,
117
+ self.m_dim,
118
+ self.dropout,
119
+ dim=self.dim,
120
+ depth=self.depth,
121
+ heads=self.heads,
122
+ mlp_ratio=self.mlp_ratio)
123
+ self.G.to(self.device)
124
+ self.print_network(self.G, 'G')
125
+
126
+ def print_network(self, model, name):
127
+ """Print out the network information."""
128
+ num_params = 0
129
+ for p in model.parameters():
130
+ num_params += p.numel()
131
+ print(model)
132
+ print(name)
133
+ print("The number of parameters: {}".format(num_params))
134
+
135
+ def restore_model(self, submodel, model_directory):
136
+ """Restore the trained generator and discriminator."""
137
+ print('Loading the model...')
138
+ G_path = os.path.join(model_directory, '{}-G.ckpt'.format(submodel))
139
+ self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
140
+
141
+ def inference(self):
142
+ # Load the trained generator.
143
+ self.restore_model(self.submodel, self.inference_model)
144
+
145
+ # smiles data for metrics calculation.
146
+ chembl_smiles = [line for line in open(self.train_smiles, 'r').read().splitlines()]
147
+ chembl_test = [line for line in open(self.inf_smiles, 'r').read().splitlines()]
148
+ drug_smiles = [line for line in open(self.train_drug_smiles, 'r').read().splitlines()]
149
+ drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
150
+ drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]
151
+
152
+
153
+ # Make directories if not exist.
154
+ if not os.path.exists("experiments/inference/{}".format(self.submodel)):
155
+ os.makedirs("experiments/inference/{}".format(self.submodel))
156
+
157
+ if not self.disable_correction:
158
+ correct = smi_correct(self.submodel, "experiments/inference/{}".format(self.submodel))
159
+
160
+ search_res = pd.DataFrame(columns=["submodel", "validity",
161
+ "uniqueness", "novelty",
162
+ "novelty_test", "drug_novelty",
163
+ "max_len", "mean_atom_type",
164
+ "snn_chembl", "snn_drug", "IntDiv", "qed", "sa"])
165
+
166
+ self.G.eval()
167
+
168
+ start_time = time.time()
169
+ metric_calc_dr = []
170
+ uniqueness_calc = []
171
+ real_smiles_snn = []
172
+ nodes_sample = torch.Tensor(size=[1, self.vertexes, 1]).to(self.device)
173
+ generated_smiles = []
174
+ val_counter = 0
175
+ none_counter = 0
176
+
177
+ # Inference mode
178
+ with torch.inference_mode():
179
+ pbar = tqdm(range(self.sample_num))
180
+ pbar.set_description('Inference mode for {} model started'.format(self.submodel))
181
+ for i, data in enumerate(self.inf_loader):
182
+ val_counter += 1
183
+ # Preprocess dataset
184
+ _, a_tensor, x_tensor = load_molecules(
185
+ data=data,
186
+ batch_size=self.inf_batch_size,
187
+ device=self.device,
188
+ b_dim=self.b_dim,
189
+ m_dim=self.m_dim,
190
+ )
191
+
192
+ _, _, node_sample, edge_sample = self.G(a_tensor, x_tensor)
193
+
194
+ g_edges_hat_sample = torch.max(edge_sample, -1)[1]
195
+ g_nodes_hat_sample = torch.max(node_sample, -1)[1]
196
+
197
+ fake_mol_g = [self.inf_dataset.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=False, file_name=self.dataset_name)
198
+ for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
199
+
200
+ a_tensor_sample = torch.max(a_tensor, -1)[1]
201
+ x_tensor_sample = torch.max(x_tensor, -1)[1]
202
+ real_mols = [self.inf_dataset.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=self.dataset_name)
203
+ for e_, n_ in zip(a_tensor_sample, x_tensor_sample)]
204
+
205
+ inference_drugs = [None if line is None else Chem.MolToSmiles(line) for line in fake_mol_g]
206
+ inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]
207
+
208
+ for molecules in inference_drugs:
209
+ if molecules is None:
210
+ none_counter += 1
211
+
212
+ for molecules in inference_drugs:
213
+ if molecules is not None:
214
+ molecules = molecules.replace("*", "C")
215
+ generated_smiles.append(molecules)
216
+ uniqueness_calc.append(molecules)
217
+ nodes_sample = torch.cat((nodes_sample, g_nodes_hat_sample.view(1, self.vertexes, 1)), 0)
218
+ pbar.update(1)
219
+ metric_calc_dr.append(molecules)
220
+
221
+ real_smiles_snn.append(real_mols[0])
222
+ generation_number = len([x for x in metric_calc_dr if x is not None])
223
+ if generation_number == self.sample_num or none_counter == self.sample_num:
224
+ break
225
+
226
+ if not self.disable_correction:
227
+ correct = smi_correct(self.submodel, "experiments/inference/{}".format(self.submodel))
228
+ gen_smi = correct.correct_smiles_list(generated_smiles)
229
+ else:
230
+ gen_smi = generated_smiles
231
+
232
+ et = time.time() - start_time
233
+
234
+ gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(x), 2, nBits=1024) for x in uniqueness_calc if Chem.MolFromSmiles(x) is not None]
235
+ real_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_smiles_snn if x is not None]
236
+
237
+ if not self.disable_correction:
238
+ val = round(len(gen_smi)/self.sample_num, 3)
239
+ else:
240
+ val = round(fraction_valid(gen_smi), 3)
241
+
242
+ uniq = round(fraction_unique(gen_smi), 3)
243
+ nov = round(novelty(gen_smi, chembl_smiles), 3)
244
+ nov_test = round(novelty(gen_smi, chembl_test), 3)
245
+ drug_nov = round(novelty(gen_smi, drug_smiles), 3)
246
+ max_len = round(Metrics.max_component(gen_smi, self.vertexes), 3)
247
+ mean_atom = round(Metrics.mean_atom_type(nodes_sample), 3)
248
+ snn_chembl = round(average_agg_tanimoto(np.array(real_vecs), np.array(gen_vecs)), 3)
249
+ snn_drug = round(average_agg_tanimoto(np.array(drug_vecs), np.array(gen_vecs)), 3)
250
+ int_div = round((internal_diversity(np.array(gen_vecs)))[0], 3)
251
+ qed = round(np.mean([QED.qed(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)
252
+ sa = round(np.mean([sascorer.calculateScore(Chem.MolFromSmiles(x)) for x in gen_smi if Chem.MolFromSmiles(x) is not None]), 3)
253
+
254
+ model_res = pd.DataFrame({"submodel": [self.submodel], "validity": [val],
255
+ "uniqueness": [uniq], "novelty": [nov],
256
+ "novelty_test": [nov_test], "drug_novelty": [drug_nov],
257
+ "max_len": [max_len], "mean_atom_type": [mean_atom],
258
+ "snn_chembl": [snn_chembl], "snn_drug": [snn_drug],
259
+ "IntDiv": [int_div], "qed": [qed], "sa": [sa]})
260
+
261
+ # Write generated SMILES to a temporary file for app.py to use
262
+ temp_file = f'{self.submodel}_denovo_mols.smi'
263
+ with open(temp_file, 'w') as f:
264
+ f.write("SMILES\n")
265
+ for smiles in gen_smi:
266
+ f.write(f"{smiles}\n")
267
+
268
+ return model_res
269
+
270
+
271
+ if __name__=="__main__":
272
+ parser = argparse.ArgumentParser()
273
+
274
+ # Inference configuration.
275
+ parser.add_argument('--submodel', type=str, default="DrugGEN", help="Chose model subtype: DrugGEN, NoTarget", choices=['DrugGEN', 'NoTarget'])
276
+ parser.add_argument('--inference_model', type=str, help="Path to the model for inference")
277
+ parser.add_argument('--sample_num', type=int, default=100, help='inference samples')
278
+ parser.add_argument('--disable_correction', action='store_true', help='Disable SMILES correction')
279
+
280
+ # Data configuration.
281
+ parser.add_argument('--inf_smiles', type=str, required=True)
282
+ parser.add_argument('--train_smiles', type=str, required=True)
283
+ parser.add_argument('--train_drug_smiles', type=str, required=True)
284
+ parser.add_argument('--inf_batch_size', type=int, default=1, help='Batch size for inference')
285
+ parser.add_argument('--mol_data_dir', type=str, default='data')
286
+ parser.add_argument('--features', action='store_true', help='features dimension for nodes')
287
+
288
+ # Model configuration.
289
+ parser.add_argument('--act', type=str, default="relu", help="Activation function for the model.", choices=['relu', 'tanh', 'leaky', 'sigmoid'])
290
+ parser.add_argument('--max_atom', type=int, default=45, help='Max atom number for molecules must be specified.')
291
+ parser.add_argument('--dim', type=int, default=128, help='Dimension of the Transformer Encoder model for the GAN.')
292
+ parser.add_argument('--depth', type=int, default=1, help='Depth of the Transformer model from the GAN.')
293
+ parser.add_argument('--heads', type=int, default=8, help='Number of heads for the MultiHeadAttention module from the GAN.')
294
+ parser.add_argument('--mlp_ratio', type=int, default=3, help='MLP ratio for the Transformer.')
295
+ parser.add_argument('--dropout', type=float, default=0., help='dropout rate')
296
+
297
+ # Seed configuration.
298
+ parser.add_argument('--set_seed', action='store_true', help='set seed for reproducibility')
299
+ parser.add_argument('--seed', type=int, default=1, help='seed for reproducibility')
300
+
301
+ config = parser.parse_args()
302
+ inference = Inference(config)
303
+ inference.inference()
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (150 Bytes). View file
 
src/data/__init__.py ADDED
File without changes
src/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (155 Bytes). View file
 
src/data/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (12.9 kB). View file
 
src/data/__pycache__/utils.cpython-310.pyc ADDED
Binary file (4.75 kB). View file
 
src/data/dataset.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+ import re
4
+ import pickle
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ from tqdm import tqdm
9
+
10
+ import torch
11
+ from torch_geometric.data import Data, InMemoryDataset
12
+
13
+ from rdkit import Chem, RDLogger
14
+
15
+ from src.data.utils import label2onehot
16
+
17
+ RDLogger.DisableLog('rdApp.*')
18
+
19
+
20
+ class DruggenDataset(InMemoryDataset):
21
+ def __init__(self, root, dataset_file, raw_files, max_atom, features,
22
+ atom_encoder, atom_decoder, bond_encoder, bond_decoder,
23
+ transform=None, pre_transform=None, pre_filter=None):
24
+ """
25
+ Initialize the DruggenDataset with pre-loaded encoder/decoder dictionaries.
26
+
27
+ Parameters:
28
+ root (str): Root directory.
29
+ dataset_file (str): Name of the processed dataset file.
30
+ raw_files (str): Path to the raw SMILES file.
31
+ max_atom (int): Maximum number of atoms allowed in a molecule.
32
+ features (bool): Whether to include additional node features.
33
+ atom_encoder (dict): Pre-loaded atom encoder dictionary.
34
+ atom_decoder (dict): Pre-loaded atom decoder dictionary.
35
+ bond_encoder (dict): Pre-loaded bond encoder dictionary.
36
+ bond_decoder (dict): Pre-loaded bond decoder dictionary.
37
+ transform, pre_transform, pre_filter: See PyG InMemoryDataset.
38
+ """
39
+ self.dataset_name = dataset_file.split(".")[0]
40
+ self.dataset_file = dataset_file
41
+ self.raw_files = raw_files
42
+ self.max_atom = max_atom
43
+ self.features = features
44
+
45
+ # Use the provided encoder/decoder mappings.
46
+ self.atom_encoder_m = atom_encoder
47
+ self.atom_decoder_m = atom_decoder
48
+ self.bond_encoder_m = bond_encoder
49
+ self.bond_decoder_m = bond_decoder
50
+
51
+ self.atom_num_types = len(atom_encoder)
52
+ self.bond_num_types = len(bond_encoder)
53
+
54
+ super().__init__(root, transform, pre_transform, pre_filter)
55
+ path = osp.join(self.processed_dir, dataset_file)
56
+ self.data, self.slices = torch.load(path)
57
+ self.root = root
58
+
59
+ @property
60
+ def processed_dir(self):
61
+ """
62
+ Returns the directory where processed dataset files are stored.
63
+ """
64
+ return self.root
65
+
66
+ @property
67
+ def raw_file_names(self):
68
+ """
69
+ Returns the raw SMILES file name.
70
+ """
71
+ return self.raw_files
72
+
73
+ @property
74
+ def processed_file_names(self):
75
+ """
76
+ Returns the name of the processed dataset file.
77
+ """
78
+ return self.dataset_file
79
+
80
+ def _filter_smiles(self, smiles_list):
81
+ """
82
+ Filters the input list of SMILES strings to keep only valid molecules that:
83
+ - Can be successfully parsed,
84
+ - Have a number of atoms less than or equal to the maximum allowed (max_atom),
85
+ - Contain only atoms present in the atom_encoder,
86
+ - Contain only bonds present in the bond_encoder.
87
+
88
+ Parameters:
89
+ smiles_list (list): List of SMILES strings.
90
+
91
+ Returns:
92
+ max_length (int): Maximum number of atoms found in the filtered molecules.
93
+ filtered_smiles (list): List of valid SMILES strings.
94
+ """
95
+ max_length = 0
96
+ filtered_smiles = []
97
+ for smiles in tqdm(smiles_list, desc="Filtering SMILES"):
98
+ mol = Chem.MolFromSmiles(smiles)
99
+ if mol is None:
100
+ continue
101
+
102
+ # Check molecule size
103
+ molecule_size = mol.GetNumAtoms()
104
+ if molecule_size > self.max_atom:
105
+ continue
106
+
107
+ # Filter out molecules with atoms not in the atom_encoder
108
+ if not all(atom.GetAtomicNum() in self.atom_encoder_m for atom in mol.GetAtoms()):
109
+ continue
110
+
111
+ # Filter out molecules with bonds not in the bond_encoder
112
+ if not all(bond.GetBondType() in self.bond_encoder_m for bond in mol.GetBonds()):
113
+ continue
114
+
115
+ filtered_smiles.append(smiles)
116
+ max_length = max(max_length, molecule_size)
117
+ return max_length, filtered_smiles
118
+
119
+ def _genA(self, mol, connected=True, max_length=None):
120
+ """
121
+ Generates the adjacency matrix for a molecule based on its bond structure.
122
+
123
+ Parameters:
124
+ mol (rdkit.Chem.Mol): The molecule.
125
+ connected (bool): If True, ensures all atoms are connected.
126
+ max_length (int, optional): The size of the matrix; if None, uses number of atoms in mol.
127
+
128
+ Returns:
129
+ np.array: Adjacency matrix with bond types as entries, or None if disconnected.
130
+ """
131
+ max_length = max_length if max_length is not None else mol.GetNumAtoms()
132
+ A = np.zeros((max_length, max_length))
133
+ begin = [b.GetBeginAtomIdx() for b in mol.GetBonds()]
134
+ end = [b.GetEndAtomIdx() for b in mol.GetBonds()]
135
+ bond_type = [self.bond_encoder_m[b.GetBondType()] for b in mol.GetBonds()]
136
+ A[begin, end] = bond_type
137
+ A[end, begin] = bond_type
138
+ degree = np.sum(A[:mol.GetNumAtoms(), :mol.GetNumAtoms()], axis=-1)
139
+ return A if connected and (degree > 0).all() else None
140
+
141
+ def _genX(self, mol, max_length=None):
142
+ """
143
+ Generates the feature vector for each atom in a molecule by encoding their atomic numbers.
144
+
145
+ Parameters:
146
+ mol (rdkit.Chem.Mol): The molecule.
147
+ max_length (int, optional): Length of the feature vector; if None, uses number of atoms in mol.
148
+
149
+ Returns:
150
+ np.array: Array of atom feature indices, padded with zeros if necessary, or None on error.
151
+ """
152
+ max_length = max_length if max_length is not None else mol.GetNumAtoms()
153
+ try:
154
+ return np.array([self.atom_encoder_m[atom.GetAtomicNum()] for atom in mol.GetAtoms()] +
155
+ [0] * (max_length - mol.GetNumAtoms()))
156
+ except KeyError as e:
157
+ print(f"Skipping molecule with unsupported atom: {e}")
158
+ print(f"Skipped SMILES: {Chem.MolToSmiles(mol)}")
159
+ return None
160
+
161
+ def _genF(self, mol, max_length=None):
162
+ """
163
+ Generates additional node features for a molecule using various atomic properties.
164
+
165
+ Parameters:
166
+ mol (rdkit.Chem.Mol): The molecule.
167
+ max_length (int, optional): Number of rows in the features matrix; if None, uses number of atoms.
168
+
169
+ Returns:
170
+ np.array: Array of additional features for each atom, padded with zeros if necessary.
171
+ """
172
+ max_length = max_length if max_length is not None else mol.GetNumAtoms()
173
+ features = np.array([[*[a.GetDegree() == i for i in range(5)],
174
+ *[a.GetExplicitValence() == i for i in range(9)],
175
+ *[int(a.GetHybridization()) == i for i in range(1, 7)],
176
+ *[a.GetImplicitValence() == i for i in range(9)],
177
+ a.GetIsAromatic(),
178
+ a.GetNoImplicit(),
179
+ *[a.GetNumExplicitHs() == i for i in range(5)],
180
+ *[a.GetNumImplicitHs() == i for i in range(5)],
181
+ *[a.GetNumRadicalElectrons() == i for i in range(5)],
182
+ a.IsInRing(),
183
+ *[a.IsInRingSize(i) for i in range(2, 9)]]
184
+ for a in mol.GetAtoms()], dtype=np.int32)
185
+ return np.vstack((features, np.zeros((max_length - features.shape[0], features.shape[1]))))
186
+
187
+ def decoder_load(self, dictionary_name, file):
188
+ """
189
+ Returns the pre-loaded decoder dictionary based on the dictionary name.
190
+
191
+ Parameters:
192
+ dictionary_name (str): Name of the dictionary ("atom" or "bond").
193
+ file: Placeholder parameter for compatibility.
194
+
195
+ Returns:
196
+ dict: The corresponding decoder dictionary.
197
+ """
198
+ if dictionary_name == "atom":
199
+ return self.atom_decoder_m
200
+ elif dictionary_name == "bond":
201
+ return self.bond_decoder_m
202
+ else:
203
+ raise ValueError("Unknown dictionary name.")
204
+
205
+ def matrices2mol(self, node_labels, edge_labels, strict=True, file_name=None):
206
+ """
207
+ Converts graph representations (node labels and edge labels) back to an RDKit molecule.
208
+
209
+ Parameters:
210
+ node_labels (iterable): Encoded atom labels.
211
+ edge_labels (np.array): Adjacency matrix with encoded bond types.
212
+ strict (bool): If True, sanitizes the molecule and returns None on failure.
213
+ file_name: Placeholder parameter for compatibility.
214
+
215
+ Returns:
216
+ rdkit.Chem.Mol: The resulting molecule, or None if sanitization fails.
217
+ """
218
+ mol = Chem.RWMol()
219
+ for node_label in node_labels:
220
+ mol.AddAtom(Chem.Atom(self.atom_decoder_m[node_label]))
221
+ for start, end in zip(*np.nonzero(edge_labels)):
222
+ if start > end:
223
+ mol.AddBond(int(start), int(end), self.bond_decoder_m[edge_labels[start, end]])
224
+ if strict:
225
+ try:
226
+ Chem.SanitizeMol(mol)
227
+ except Exception:
228
+ mol = None
229
+ return mol
230
+
231
+ def check_valency(self, mol):
232
+ """
233
+ Checks that no atom in the molecule has exceeded its allowed valency.
234
+
235
+ Parameters:
236
+ mol (rdkit.Chem.Mol): The molecule.
237
+
238
+ Returns:
239
+ tuple: (True, None) if valid; (False, atomid_valence) if there is a valency issue.
240
+ """
241
+ try:
242
+ Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
243
+ return True, None
244
+ except ValueError as e:
245
+ e = str(e)
246
+ p = e.find('#')
247
+ e_sub = e[p:]
248
+ atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
249
+ return False, atomid_valence
250
+
251
+ def correct_mol(self, mol):
252
+ """
253
+ Corrects a molecule by removing bonds until all atoms satisfy their valency limits.
254
+
255
+ Parameters:
256
+ mol (rdkit.Chem.Mol): The molecule.
257
+
258
+ Returns:
259
+ rdkit.Chem.Mol: The corrected molecule.
260
+ """
261
+ while True:
262
+ flag, atomid_valence = self.check_valency(mol)
263
+ if flag:
264
+ break
265
+ else:
266
+ # Expecting two numbers: atom index and its valence.
267
+ assert len(atomid_valence) == 2
268
+ idx = atomid_valence[0]
269
+ queue = []
270
+ for b in mol.GetAtomWithIdx(idx).GetBonds():
271
+ queue.append((b.GetIdx(), int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx()))
272
+ queue.sort(key=lambda tup: tup[1], reverse=True)
273
+ if queue:
274
+ start = queue[0][2]
275
+ end = queue[0][3]
276
+ mol.RemoveBond(start, end)
277
+ return mol
278
+
279
+
280
+ def process(self, size=None):
281
+ """
282
+ Processes the raw SMILES file by filtering and converting each valid SMILES into a PyTorch Geometric Data object.
283
+ The resulting dataset is saved to disk.
284
+
285
+ Parameters:
286
+ size (optional): Placeholder parameter for compatibility.
287
+
288
+ Side Effects:
289
+ Saves the processed dataset as a file in the processed directory.
290
+ """
291
+ # Read raw SMILES from file (assuming CSV with no header)
292
+ smiles_list = pd.read_csv(self.raw_files, header=None)[0].tolist()
293
+ max_length, filtered_smiles = self._filter_smiles(smiles_list)
294
+ data_list = []
295
+ self.m_dim = len(self.atom_decoder_m)
296
+ for smiles in tqdm(filtered_smiles, desc='Processing dataset', total=len(filtered_smiles)):
297
+ mol = Chem.MolFromSmiles(smiles)
298
+ A = self._genA(mol, connected=True, max_length=max_length)
299
+ if A is not None:
300
+ x_array = self._genX(mol, max_length=max_length)
301
+ if x_array is None:
302
+ continue
303
+ x = torch.from_numpy(x_array).to(torch.long).view(1, -1)
304
+ x = label2onehot(x, self.m_dim).squeeze()
305
+ if self.features:
306
+ f = torch.from_numpy(self._genF(mol, max_length=max_length)).to(torch.long).view(x.shape[0], -1)
307
+ x = torch.concat((x, f), dim=-1)
308
+ adjacency = torch.from_numpy(A)
309
+ edge_index = adjacency.nonzero(as_tuple=False).t().contiguous()
310
+ edge_attr = adjacency[edge_index[0], edge_index[1]].to(torch.long)
311
+ data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)
312
+ if self.pre_filter is not None and not self.pre_filter(data):
313
+ continue
314
+ if self.pre_transform is not None:
315
+ data = self.pre_transform(data)
316
+ data_list.append(data)
317
+ torch.save(self.collate(data_list), osp.join(self.processed_dir, self.dataset_file))
src/data/utils.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+
4
+ import pandas as pd
5
+ from tqdm import tqdm
6
+
7
+ import torch
8
+ from torch_geometric.data import Data, InMemoryDataset
9
+ import torch_geometric.utils as geoutils
10
+
11
+ from rdkit import Chem, RDLogger
12
+
13
+
14
+
15
+ def label2onehot(labels, dim, device=None):
16
+ """Convert label indices to one-hot vectors."""
17
+ out = torch.zeros(list(labels.size())+[dim])
18
+ if device:
19
+ out = out.to(device)
20
+
21
+ out.scatter_(len(out.size())-1,labels.unsqueeze(-1),1.)
22
+
23
+ return out.float()
24
+
25
+
26
+ def get_encoders_decoders(raw_file1, raw_file2, max_atom):
27
+ """
28
+ Given two raw SMILES files, either load the atom and bond encoders/decoders
29
+ if they exist (naming them based on the file names) or create and save them.
30
+
31
+ Parameters:
32
+ raw_file1 (str): Path to the first SMILES file.
33
+ raw_file2 (str): Path to the second SMILES file.
34
+ max_atom (int): Maximum allowed number of atoms in a molecule.
35
+
36
+ Returns:
37
+ atom_encoder (dict): Mapping from atomic numbers to indices.
38
+ atom_decoder (dict): Mapping from indices to atomic numbers.
39
+ bond_encoder (dict): Mapping from bond types to indices.
40
+ bond_decoder (dict): Mapping from indices to bond types.
41
+ """
42
+ # Determine unique suffix based on the two file names (alphabetically sorted for consistency)
43
+ name1 = os.path.splitext(os.path.basename(raw_file1))[0]
44
+ name2 = os.path.splitext(os.path.basename(raw_file2))[0]
45
+ sorted_names = sorted([name1, name2])
46
+ suffix = f"{sorted_names[0]}_{sorted_names[1]}"
47
+
48
+ # Define encoder/decoder directories and file paths
49
+ enc_dir = os.path.join("data", "encoders")
50
+ dec_dir = os.path.join("data", "decoders")
51
+ atom_encoder_path = os.path.join(enc_dir, f"atom_{suffix}.pkl")
52
+ atom_decoder_path = os.path.join(dec_dir, f"atom_{suffix}.pkl")
53
+ bond_encoder_path = os.path.join(enc_dir, f"bond_{suffix}.pkl")
54
+ bond_decoder_path = os.path.join(dec_dir, f"bond_{suffix}.pkl")
55
+
56
+ # If all files exist, load and return them
57
+ if (os.path.exists(atom_encoder_path) and os.path.exists(atom_decoder_path) and
58
+ os.path.exists(bond_encoder_path) and os.path.exists(bond_decoder_path)):
59
+ with open(atom_encoder_path, "rb") as f:
60
+ atom_encoder = pickle.load(f)
61
+ with open(atom_decoder_path, "rb") as f:
62
+ atom_decoder = pickle.load(f)
63
+ with open(bond_encoder_path, "rb") as f:
64
+ bond_encoder = pickle.load(f)
65
+ with open(bond_decoder_path, "rb") as f:
66
+ bond_decoder = pickle.load(f)
67
+ print("Loaded existing encoders/decoders!")
68
+ return atom_encoder, atom_decoder, bond_encoder, bond_decoder
69
+
70
+ # Otherwise, create the encoders/decoders
71
+ print("Creating new encoders/decoders...")
72
+ # Read SMILES from both files (assuming one SMILES per row, no header)
73
+ smiles1 = pd.read_csv(raw_file1, header=None)[0].tolist()
74
+ smiles2 = pd.read_csv(raw_file2, header=None)[0].tolist()
75
+ smiles_combined = smiles1 + smiles2
76
+
77
+ atom_labels = set()
78
+ bond_labels = set()
79
+ max_length = 0
80
+ filtered_smiles = []
81
+
82
+ # Process each SMILES: keep only valid molecules with <= max_atom atoms
83
+ for smiles in tqdm(smiles_combined, desc="Processing SMILES"):
84
+ mol = Chem.MolFromSmiles(smiles)
85
+ if mol is None:
86
+ continue
87
+ molecule_size = mol.GetNumAtoms()
88
+ if molecule_size > max_atom:
89
+ continue
90
+ filtered_smiles.append(smiles)
91
+ # Collect atomic numbers
92
+ atom_labels.update([atom.GetAtomicNum() for atom in mol.GetAtoms()])
93
+ max_length = max(max_length, molecule_size)
94
+ # Collect bond types
95
+ bond_labels.update([bond.GetBondType() for bond in mol.GetBonds()])
96
+
97
+ # Add a PAD symbol (here using 0 for atoms)
98
+ atom_labels.add(0)
99
+ atom_labels = sorted(atom_labels)
100
+
101
+ # For bonds, prepend the PAD bond type (using rdkit's BondType.ZERO)
102
+ bond_labels = sorted(bond_labels)
103
+ bond_labels = [Chem.rdchem.BondType.ZERO] + bond_labels
104
+
105
+ # Create encoder and decoder dictionaries
106
+ atom_encoder = {l: i for i, l in enumerate(atom_labels)}
107
+ atom_decoder = {i: l for i, l in enumerate(atom_labels)}
108
+ bond_encoder = {l: i for i, l in enumerate(bond_labels)}
109
+ bond_decoder = {i: l for i, l in enumerate(bond_labels)}
110
+
111
+ # Ensure directories exist
112
+ os.makedirs(enc_dir, exist_ok=True)
113
+ os.makedirs(dec_dir, exist_ok=True)
114
+
115
+ # Save the encoders/decoders to disk
116
+ with open(atom_encoder_path, "wb") as f:
117
+ pickle.dump(atom_encoder, f)
118
+ with open(atom_decoder_path, "wb") as f:
119
+ pickle.dump(atom_decoder, f)
120
+ with open(bond_encoder_path, "wb") as f:
121
+ pickle.dump(bond_encoder, f)
122
+ with open(bond_decoder_path, "wb") as f:
123
+ pickle.dump(bond_decoder, f)
124
+
125
+ print("Encoders/decoders created and saved.")
126
+ return atom_encoder, atom_decoder, bond_encoder, bond_decoder
127
+
128
+ def load_molecules(data=None, b_dim=32, m_dim=32, device=None, batch_size=32):
129
+ data = data.to(device)
130
+ a = geoutils.to_dense_adj(
131
+ edge_index = data.edge_index,
132
+ batch=data.batch,
133
+ edge_attr=data.edge_attr,
134
+ max_num_nodes=int(data.batch.shape[0]/batch_size)
135
+ )
136
+ x_tensor = data.x.view(batch_size,int(data.batch.shape[0]/batch_size),-1)
137
+ a_tensor = label2onehot(a, b_dim, device)
138
+
139
+ a_tensor_vec = a_tensor.reshape(batch_size,-1)
140
+ x_tensor_vec = x_tensor.reshape(batch_size,-1)
141
+ real_graphs = torch.concat((x_tensor_vec,a_tensor_vec),dim=-1)
142
+
143
+ return real_graphs, a_tensor, x_tensor
src/model/__init__.py ADDED
File without changes
src/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (156 Bytes). View file
 
src/model/__pycache__/layers.cpython-310.pyc ADDED
Binary file (8.31 kB). View file
 
src/model/__pycache__/loss.cpython-310.pyc ADDED
Binary file (2.04 kB). View file
 
src/model/__pycache__/models.cpython-310.pyc ADDED
Binary file (7.35 kB). View file
 
src/model/layers.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+
7
+ class MLP(nn.Module):
8
+ """
9
+ A simple Multi-Layer Perceptron (MLP) module consisting of two linear layers with a ReLU activation in between,
10
+ followed by a dropout on the output.
11
+
12
+ Attributes:
13
+ fc1 (nn.Linear): The first fully-connected layer.
14
+ act (nn.ReLU): ReLU activation function.
15
+ fc2 (nn.Linear): The second fully-connected layer.
16
+ droprateout (nn.Dropout): Dropout layer applied to the output.
17
+ """
18
+ def __init__(self, in_feat, hid_feat=None, out_feat=None, dropout=0.):
19
+ """
20
+ Initializes the MLP module.
21
+
22
+ Args:
23
+ in_feat (int): Number of input features.
24
+ hid_feat (int, optional): Number of hidden features. Defaults to in_feat if not provided.
25
+ out_feat (int, optional): Number of output features. Defaults to in_feat if not provided.
26
+ dropout (float, optional): Dropout rate. Defaults to 0.
27
+ """
28
+ super().__init__()
29
+
30
+ # Set hidden and output dimensions to input dimension if not specified
31
+ if not hid_feat:
32
+ hid_feat = in_feat
33
+ if not out_feat:
34
+ out_feat = in_feat
35
+
36
+ self.fc1 = nn.Linear(in_feat, hid_feat)
37
+ self.act = nn.ReLU()
38
+ self.fc2 = nn.Linear(hid_feat, out_feat)
39
+ self.droprateout = nn.Dropout(dropout)
40
+
41
+ def forward(self, x):
42
+ """
43
+ Forward pass for the MLP.
44
+
45
+ Args:
46
+ x (torch.Tensor): Input tensor.
47
+
48
+ Returns:
49
+ torch.Tensor: Output tensor after applying the linear layers, activation, and dropout.
50
+ """
51
+ x = self.fc1(x)
52
+ x = self.act(x)
53
+ x = self.fc2(x)
54
+ return self.droprateout(x)
55
+
56
+ class MHA(nn.Module):
57
+ """
58
+ Multi-Head Attention (MHA) module of the graph transformer with edge features incorporated into the attention computation.
59
+
60
+ Attributes:
61
+ heads (int): Number of attention heads.
62
+ scale (float): Scaling factor for the attention scores.
63
+ q, k, v (nn.Linear): Linear layers to project the node features into query, key, and value embeddings.
64
+ e (nn.Linear): Linear layer to project the edge features.
65
+ d_k (int): Dimension of each attention head.
66
+ out_e (nn.Linear): Linear layer applied to the computed edge features.
67
+ out_n (nn.Linear): Linear layer applied to the aggregated node features.
68
+ """
69
+ def __init__(self, dim, heads, attention_dropout=0.):
70
+ """
71
+ Initializes the Multi-Head Attention module.
72
+
73
+ Args:
74
+ dim (int): Dimensionality of the input features.
75
+ heads (int): Number of attention heads.
76
+ attention_dropout (float, optional): Dropout rate for attention (not used explicitly in this implementation).
77
+ """
78
+ super().__init__()
79
+
80
+ # Ensure that dimension is divisible by the number of heads
81
+ assert dim % heads == 0
82
+
83
+ self.heads = heads
84
+ self.scale = 1. / math.sqrt(dim) # Scaling factor for attention
85
+ # Linear layers for projecting node features
86
+ self.q = nn.Linear(dim, dim)
87
+ self.k = nn.Linear(dim, dim)
88
+ self.v = nn.Linear(dim, dim)
89
+ # Linear layer for projecting edge features
90
+ self.e = nn.Linear(dim, dim)
91
+ self.d_k = dim // heads # Dimension per head
92
+
93
+ # Linear layers for output transformations
94
+ self.out_e = nn.Linear(dim, dim)
95
+ self.out_n = nn.Linear(dim, dim)
96
+
97
+ def forward(self, node, edge):
98
+ """
99
+ Forward pass for the Multi-Head Attention.
100
+
101
+ Args:
102
+ node (torch.Tensor): Node feature tensor of shape (batch, num_nodes, dim).
103
+ edge (torch.Tensor): Edge feature tensor of shape (batch, num_nodes, num_nodes, dim).
104
+
105
+ Returns:
106
+ tuple: (updated node features, updated edge features)
107
+ """
108
+ b, n, c = node.shape
109
+
110
+ # Compute query, key, and value embeddings and reshape for multi-head attention
111
+ q_embed = self.q(node).view(b, n, self.heads, c // self.heads)
112
+ k_embed = self.k(node).view(b, n, self.heads, c // self.heads)
113
+ v_embed = self.v(node).view(b, n, self.heads, c // self.heads)
114
+
115
+ # Compute edge embeddings
116
+ e_embed = self.e(edge).view(b, n, n, self.heads, c // self.heads)
117
+
118
+ # Adjust dimensions for broadcasting: add singleton dimensions to queries and keys
119
+ q_embed = q_embed.unsqueeze(2) # Shape: (b, n, 1, heads, c//heads)
120
+ k_embed = k_embed.unsqueeze(1) # Shape: (b, 1, n, heads, c//heads)
121
+
122
+ # Compute attention scores
123
+ attn = q_embed * k_embed
124
+ attn = attn / math.sqrt(self.d_k)
125
+ attn = attn * (e_embed + 1) * e_embed # Modulated attention incorporating edge features
126
+
127
+ edge_out = self.out_e(attn.flatten(3)) # Flatten last dimension for linear layer
128
+
129
+ # Apply softmax over the node dimension to obtain normalized attention weights
130
+ attn = F.softmax(attn, dim=2)
131
+
132
+ v_embed = v_embed.unsqueeze(1) # Adjust dimensions to broadcast: (b, 1, n, heads, c//heads)
133
+ v_embed = attn * v_embed
134
+ v_embed = v_embed.sum(dim=2).flatten(2)
135
+ node_out = self.out_n(v_embed)
136
+
137
+ return node_out, edge_out
138
+
139
+ class Encoder_Block(nn.Module):
140
+ """
141
+ Transformer encoder block that integrates node and edge features.
142
+
143
+ Consists of:
144
+ - A multi-head attention layer with edge modulation.
145
+ - Two MLP layers, each with residual connections and layer normalization.
146
+
147
+ Attributes:
148
+ ln1, ln3, ln4, ln5, ln6 (nn.LayerNorm): Layer normalization modules.
149
+ attn (MHA): Multi-head attention module.
150
+ mlp, mlp2 (MLP): MLP modules for further transformation of node and edge features.
151
+ """
152
+ def __init__(self, dim, heads, act, mlp_ratio=4, drop_rate=0.):
153
+ """
154
+ Initializes the encoder block.
155
+
156
+ Args:
157
+ dim (int): Dimensionality of the input features.
158
+ heads (int): Number of attention heads.
159
+ act (callable): Activation function (not explicitly used in this block, but provided for potential extensions).
160
+ mlp_ratio (int, optional): Ratio to determine the hidden layer size in the MLP. Defaults to 4.
161
+ drop_rate (float, optional): Dropout rate applied in the MLPs. Defaults to 0.
162
+ """
163
+ super().__init__()
164
+
165
+ self.ln1 = nn.LayerNorm(dim)
166
+ self.attn = MHA(dim, heads, drop_rate)
167
+ self.ln3 = nn.LayerNorm(dim)
168
+ self.ln4 = nn.LayerNorm(dim)
169
+ self.mlp = MLP(dim, dim * mlp_ratio, dim, dropout=drop_rate)
170
+ self.mlp2 = MLP(dim, dim * mlp_ratio, dim, dropout=drop_rate)
171
+ self.ln5 = nn.LayerNorm(dim)
172
+ self.ln6 = nn.LayerNorm(dim)
173
+
174
+ def forward(self, x, y):
175
+ """
176
+ Forward pass of the encoder block.
177
+
178
+ Args:
179
+ x (torch.Tensor): Node feature tensor.
180
+ y (torch.Tensor): Edge feature tensor.
181
+
182
+ Returns:
183
+ tuple: (updated node features, updated edge features)
184
+ """
185
+ x1 = self.ln1(x)
186
+ x2, y1 = self.attn(x1, y)
187
+ x2 = x1 + x2
188
+ y2 = y + y1
189
+ x2 = self.ln3(x2)
190
+ y2 = self.ln4(y2)
191
+ x = self.ln5(x2 + self.mlp(x2))
192
+ y = self.ln6(y2 + self.mlp2(y2))
193
+ return x, y
194
+
195
+ class TransformerEncoder(nn.Module):
196
+ """
197
+ Transformer Encoder composed of a sequence of encoder blocks.
198
+
199
+ Attributes:
200
+ Encoder_Blocks (nn.ModuleList): A list of Encoder_Block modules stacked sequentially.
201
+ """
202
+ def __init__(self, dim, depth, heads, act, mlp_ratio=4, drop_rate=0.1):
203
+ """
204
+ Initializes the Transformer Encoder.
205
+
206
+ Args:
207
+ dim (int): Dimensionality of the input features.
208
+ depth (int): Number of encoder blocks to stack.
209
+ heads (int): Number of attention heads in each block.
210
+ act (callable): Activation function (passed to encoder blocks for potential use).
211
+ mlp_ratio (int, optional): Ratio for determining the hidden layer size in MLP modules. Defaults to 4.
212
+ drop_rate (float, optional): Dropout rate for the MLPs within each block. Defaults to 0.1.
213
+ """
214
+ super().__init__()
215
+
216
+ self.Encoder_Blocks = nn.ModuleList([
217
+ Encoder_Block(dim, heads, act, mlp_ratio, drop_rate)
218
+ for _ in range(depth)
219
+ ])
220
+
221
+ def forward(self, x, y):
222
+ """
223
+ Forward pass of the Transformer Encoder.
224
+
225
+ Args:
226
+ x (torch.Tensor): Node feature tensor.
227
+ y (torch.Tensor): Edge feature tensor.
228
+
229
+ Returns:
230
+ tuple: (final node features, final edge features) after processing through all encoder blocks.
231
+ """
232
+ for block in self.Encoder_Blocks:
233
+ x, y = block(x, y)
234
+ return x, y
src/model/loss.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def gradient_penalty(discriminator, real_node, real_edge, fake_node, fake_edge, batch_size, device):
5
+ """
6
+ Calculate gradient penalty for WGAN-GP.
7
+
8
+ Args:
9
+ discriminator: The discriminator model
10
+ real_node: Real node features
11
+ real_edge: Real edge features
12
+ fake_node: Generated node features
13
+ fake_edge: Generated edge features
14
+ batch_size: Batch size
15
+ device: Device to compute on
16
+
17
+ Returns:
18
+ Gradient penalty term
19
+ """
20
+ # Generate random interpolation factors
21
+ eps_edge = torch.rand(batch_size, 1, 1, 1, device=device)
22
+ eps_node = torch.rand(batch_size, 1, 1, device=device)
23
+
24
+ # Create interpolated samples
25
+ int_node = (eps_node * real_node + (1 - eps_node) * fake_node).requires_grad_(True)
26
+ int_edge = (eps_edge * real_edge + (1 - eps_edge) * fake_edge).requires_grad_(True)
27
+
28
+ logits_interpolated = discriminator(int_edge, int_node)
29
+
30
+ # Calculate gradients for both node and edge inputs
31
+ weight = torch.ones(logits_interpolated.size(), requires_grad=False).to(device)
32
+ gradients = torch.autograd.grad(
33
+ outputs=logits_interpolated,
34
+ inputs=[int_node, int_edge],
35
+ grad_outputs=weight,
36
+ create_graph=True,
37
+ retain_graph=True,
38
+ only_inputs=True
39
+ )
40
+
41
+ # Combine gradients from both inputs
42
+ gradients_node = gradients[0].view(batch_size, -1)
43
+ gradients_edge = gradients[1].view(batch_size, -1)
44
+ gradients = torch.cat([gradients_node, gradients_edge], dim=1)
45
+
46
+ # Calculate gradient penalty
47
+ gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
48
+
49
+ return gradient_penalty
50
+
51
+
52
+ def discriminator_loss(generator, discriminator, drug_adj, drug_annot, mol_adj, mol_annot, batch_size, device, lambda_gp):
53
+ # Compute loss for drugs
54
+ logits_real_disc = discriminator(drug_adj, drug_annot)
55
+
56
+ # Use mean reduction for more stable training
57
+ prediction_real = -torch.mean(logits_real_disc)
58
+
59
+ # Compute loss for generated molecules
60
+ node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)
61
+
62
+ logits_fake_disc = discriminator(edge_sample.detach(), node_sample.detach())
63
+
64
+ prediction_fake = torch.mean(logits_fake_disc)
65
+
66
+ # Compute gradient penalty using the new function
67
+ gp = gradient_penalty(discriminator, drug_annot, drug_adj, node_sample.detach(), edge_sample.detach(), batch_size, device)
68
+
69
+ # Calculate total discriminator loss
70
+ d_loss = prediction_fake + prediction_real + lambda_gp * gp
71
+
72
+ return node, edge, d_loss
73
+
74
+
75
+ def generator_loss(generator, discriminator, mol_adj, mol_annot, batch_size):
76
+ # Generate fake molecules
77
+ node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)
78
+
79
+ # Compute logits for fake molecules
80
+ logits_fake_disc = discriminator(edge_sample, node_sample)
81
+
82
+ prediction_fake = -torch.mean(logits_fake_disc)
83
+ g_loss = prediction_fake
84
+
85
+ return g_loss, node, edge, node_sample, edge_sample
src/model/models.py ADDED
@@ -0,0 +1,269 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from src.model.layers import TransformerEncoder
4
+
5
+ class Generator(nn.Module):
6
+ """
7
+ Generator network that uses a Transformer Encoder to process node and edge features.
8
+
9
+ The network first processes input node and edge features with separate linear layers,
10
+ then applies a Transformer Encoder to model interactions, and finally outputs both transformed
11
+ features and readout samples.
12
+ """
13
+ def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
14
+ """
15
+ Initializes the Generator.
16
+
17
+ Args:
18
+ act (str): Type of activation function to use ("relu", "leaky", "sigmoid", or "tanh").
19
+ vertexes (int): Number of vertexes in the graph.
20
+ edges (int): Number of edge features.
21
+ nodes (int): Number of node features.
22
+ dropout (float): Dropout rate.
23
+ dim (int): Dimensionality used for intermediate features.
24
+ depth (int): Number of Transformer encoder blocks.
25
+ heads (int): Number of attention heads in the Transformer.
26
+ mlp_ratio (int): Ratio for determining hidden layer size in MLP modules.
27
+ """
28
+ super(Generator, self).__init__()
29
+ self.vertexes = vertexes
30
+ self.edges = edges
31
+ self.nodes = nodes
32
+ self.depth = depth
33
+ self.dim = dim
34
+ self.heads = heads
35
+ self.mlp_ratio = mlp_ratio
36
+ self.dropout = dropout
37
+
38
+ # Set the activation function based on the provided string
39
+ if act == "relu":
40
+ act = nn.ReLU()
41
+ elif act == "leaky":
42
+ act = nn.LeakyReLU()
43
+ elif act == "sigmoid":
44
+ act = nn.Sigmoid()
45
+ elif act == "tanh":
46
+ act = nn.Tanh()
47
+
48
+ # Calculate the total number of features and dimensions for transformer
49
+ self.features = vertexes * vertexes * edges + vertexes * nodes
50
+ self.transformer_dim = vertexes * vertexes * dim + vertexes * dim
51
+
52
+ self.node_layers = nn.Sequential(
53
+ nn.Linear(nodes, 64), act,
54
+ nn.Linear(64, dim), act,
55
+ nn.Dropout(self.dropout)
56
+ )
57
+ self.edge_layers = nn.Sequential(
58
+ nn.Linear(edges, 64), act,
59
+ nn.Linear(64, dim), act,
60
+ nn.Dropout(self.dropout)
61
+ )
62
+ self.TransformerEncoder = TransformerEncoder(
63
+ dim=self.dim, depth=self.depth, heads=self.heads, act=act,
64
+ mlp_ratio=self.mlp_ratio, drop_rate=self.dropout
65
+ )
66
+
67
+ self.readout_e = nn.Linear(self.dim, edges)
68
+ self.readout_n = nn.Linear(self.dim, nodes)
69
+ self.softmax = nn.Softmax(dim=-1)
70
+
71
+ def forward(self, z_e, z_n):
72
+ """
73
+ Forward pass of the Generator.
74
+
75
+ Args:
76
+ z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges).
77
+ z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes).
78
+
79
+ Returns:
80
+ tuple: A tuple containing:
81
+ - node: Updated node features after the transformer.
82
+ - edge: Updated edge features after the transformer.
83
+ - node_sample: Readout sample from node features.
84
+ - edge_sample: Readout sample from edge features.
85
+ """
86
+ b, n, c = z_n.shape
87
+ # The fourth dimension of edge features
88
+ _, _, _, d = z_e.shape
89
+
90
+ # Process node and edge features through their respective layers
91
+ node = self.node_layers(z_n)
92
+ edge = self.edge_layers(z_e)
93
+ # Symmetrize the edge features by averaging with its transpose along vertex dimensions
94
+ edge = (edge + edge.permute(0, 2, 1, 3)) / 2
95
+
96
+ # Pass the features through the Transformer Encoder
97
+ node, edge = self.TransformerEncoder(node, edge)
98
+
99
+ # Readout layers to generate final outputs
100
+ node_sample = self.readout_n(node)
101
+ edge_sample = self.readout_e(edge)
102
+
103
+ return node, edge, node_sample, edge_sample
104
+
105
+
106
+ class Discriminator(nn.Module):
107
+ """
108
+ Discriminator network that evaluates node and edge features.
109
+
110
+ It processes features with linear layers, applies a Transformer Encoder to capture dependencies,
111
+ and finally predicts a scalar value using an MLP on aggregated node features.
112
+
113
+ This class is used in DrugGEN model.
114
+ """
115
+ def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
116
+ """
117
+ Initializes the Discriminator.
118
+
119
+ Args:
120
+ act (str): Activation function type ("relu", "leaky", "sigmoid", or "tanh").
121
+ vertexes (int): Number of vertexes.
122
+ edges (int): Number of edge features.
123
+ nodes (int): Number of node features.
124
+ dropout (float): Dropout rate.
125
+ dim (int): Dimensionality for intermediate representations.
126
+ depth (int): Number of Transformer encoder blocks.
127
+ heads (int): Number of attention heads.
128
+ mlp_ratio (int): MLP ratio for hidden layer dimensions.
129
+ """
130
+ super(Discriminator, self).__init__()
131
+ self.vertexes = vertexes
132
+ self.edges = edges
133
+ self.nodes = nodes
134
+ self.depth = depth
135
+ self.dim = dim
136
+ self.heads = heads
137
+ self.mlp_ratio = mlp_ratio
138
+ self.dropout = dropout
139
+
140
+ # Set the activation function
141
+ if act == "relu":
142
+ act = nn.ReLU()
143
+ elif act == "leaky":
144
+ act = nn.LeakyReLU()
145
+ elif act == "sigmoid":
146
+ act = nn.Sigmoid()
147
+ elif act == "tanh":
148
+ act = nn.Tanh()
149
+
150
+ self.features = vertexes * vertexes * edges + vertexes * nodes
151
+ self.transformer_dim = vertexes * vertexes * dim + vertexes * dim
152
+
153
+ # Define layers for processing node and edge features
154
+ self.node_layers = nn.Sequential(
155
+ nn.Linear(nodes, 64), act,
156
+ nn.Linear(64, dim), act,
157
+ nn.Dropout(self.dropout)
158
+ )
159
+ self.edge_layers = nn.Sequential(
160
+ nn.Linear(edges, 64), act,
161
+ nn.Linear(64, dim), act,
162
+ nn.Dropout(self.dropout)
163
+ )
164
+ # Transformer Encoder for modeling node and edge interactions
165
+ self.TransformerEncoder = TransformerEncoder(
166
+ dim=self.dim, depth=self.depth, heads=self.heads, act=act,
167
+ mlp_ratio=self.mlp_ratio, drop_rate=self.dropout
168
+ )
169
+ # Calculate dimensions for node features aggregation
170
+ self.node_features = vertexes * dim
171
+ self.edge_features = vertexes * vertexes * dim
172
+ # MLP to predict a scalar value from aggregated node features
173
+ self.node_mlp = nn.Sequential(
174
+ nn.Linear(self.node_features, 64), act,
175
+ nn.Linear(64, 32), act,
176
+ nn.Linear(32, 16), act,
177
+ nn.Linear(16, 1)
178
+ )
179
+
180
+ def forward(self, z_e, z_n):
181
+ """
182
+ Forward pass of the Discriminator.
183
+
184
+ Args:
185
+ z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges).
186
+ z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes).
187
+
188
+ Returns:
189
+ torch.Tensor: Prediction scores (typically a scalar per sample).
190
+ """
191
+ b, n, c = z_n.shape
192
+ # Unpack the shape of edge features (not used further directly)
193
+ _, _, _, d = z_e.shape
194
+
195
+ # Process node and edge features separately
196
+ node = self.node_layers(z_n)
197
+ edge = self.edge_layers(z_e)
198
+ # Symmetrize edge features by averaging with its transpose
199
+ edge = (edge + edge.permute(0, 2, 1, 3)) / 2
200
+
201
+ # Process features through the Transformer Encoder
202
+ node, edge = self.TransformerEncoder(node, edge)
203
+
204
+ # Flatten node features for MLP
205
+ node = node.view(b, -1)
206
+ # Predict a scalar score using the node MLP
207
+ prediction = self.node_mlp(node)
208
+
209
+ return prediction
210
+
211
+
212
+ class simple_disc(nn.Module):
213
+ """
214
+ A simplified discriminator that processes flattened features through an MLP
215
+ to predict a scalar score.
216
+
217
+ This class is used in NoTarget model.
218
+ """
219
+ def __init__(self, act, m_dim, vertexes, b_dim):
220
+ """
221
+ Initializes the simple discriminator.
222
+
223
+ Args:
224
+ act (str): Activation function type ("relu", "leaky", "sigmoid", or "tanh").
225
+ m_dim (int): Dimensionality for atom type features.
226
+ vertexes (int): Number of vertexes.
227
+ b_dim (int): Dimensionality for bond type features.
228
+ """
229
+ super().__init__()
230
+
231
+ # Set the activation function and check if it's supported
232
+ if act == "relu":
233
+ act = nn.ReLU()
234
+ elif act == "leaky":
235
+ act = nn.LeakyReLU()
236
+ elif act == "sigmoid":
237
+ act = nn.Sigmoid()
238
+ elif act == "tanh":
239
+ act = nn.Tanh()
240
+ else:
241
+ raise ValueError("Unsupported activation function: {}".format(act))
242
+
243
+ # Compute total number of features combining both dimensions
244
+ features = vertexes * m_dim + vertexes * vertexes * b_dim
245
+ print(vertexes)
246
+ print(m_dim)
247
+ print(b_dim)
248
+ print(features)
249
+ self.predictor = nn.Sequential(
250
+ nn.Linear(features, 256), act,
251
+ nn.Linear(256, 128), act,
252
+ nn.Linear(128, 64), act,
253
+ nn.Linear(64, 32), act,
254
+ nn.Linear(32, 16), act,
255
+ nn.Linear(16, 1)
256
+ )
257
+
258
+ def forward(self, x):
259
+ """
260
+ Forward pass of the simple discriminator.
261
+
262
+ Args:
263
+ x (torch.Tensor): Input features tensor.
264
+
265
+ Returns:
266
+ torch.Tensor: Prediction scores.
267
+ """
268
+ prediction = self.predictor(x)
269
+ return prediction
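Putting the two main networks together, a short shape sanity check under the default hyperparameters from app.py (dim=128, depth=1, heads=8, mlp_ratio=3, max_atom=45); the edge and node feature counts of 5 and 10 are illustrative assumptions:

import torch
from src.model.models import Generator, Discriminator

vertexes, edges, nodes = 45, 5, 10
gen = Generator(act="relu", vertexes=vertexes, edges=edges, nodes=nodes,
                dropout=0.0, dim=128, depth=1, heads=8, mlp_ratio=3)
disc = Discriminator(act="relu", vertexes=vertexes, edges=edges, nodes=nodes,
                     dropout=0.0, dim=128, depth=1, heads=8, mlp_ratio=3)

z_e = torch.randn(2, vertexes, vertexes, edges)  # edge feature input
z_n = torch.randn(2, vertexes, nodes)            # node feature input
node, edge, node_sample, edge_sample = gen(z_e, z_n)
score = disc(edge_sample, node_sample)           # one scalar per molecule
assert score.shape == (2, 1)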
src/util/__init__.py ADDED
File without changes
src/util/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (155 Bytes)
src/util/__pycache__/smiles_cor.cpython-310.pyc ADDED
Binary file (30.2 kB)
src/util/__pycache__/utils.cpython-310.pyc ADDED
Binary file (30 kB)
src/util/smiles_cor.py ADDED
@@ -0,0 +1,1284 @@
1
+ import os
2
+ import time
3
+ import random
4
+ import re
5
+ import itertools
6
+ import statistics
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.optim as optim
13
+ from torch.utils.data import DataLoader
14
+ from torchtext.data import TabularDataset, Field, BucketIterator, Iterator
15
+
16
+ from rdkit import Chem, rdBase, RDLogger
17
+ from rdkit.Chem import (
18
+ MolStandardize,
19
+ GraphDescriptors,
20
+ Lipinski,
21
+ AllChem,
22
+ )
23
+ from rdkit.Chem.rdSLNParse import MolFromSLN
24
+ from rdkit.Chem.rdmolfiles import MolFromSmiles
25
+ from chembl_structure_pipeline import standardizer
26
+
27
+ RDLogger.DisableLog('rdApp.*')
28
+
29
+ SEED = 42
30
+ random.seed(SEED)
31
+ torch.manual_seed(SEED)
32
+ torch.backends.cudnn.deterministic = True
33
+
34
+ ##################################################################################################
35
+ ##################################################################################################
36
+ # #
37
+ #  THIS SCRIPT IS DIRECTLY ADAPTED FROM https://github.com/LindeSchoenmaker/SMILES-corrector #
38
+ # #
39
+ ##################################################################################################
40
+ ##################################################################################################
41
+ def is_smiles(array,
42
+ TRG,
43
+ reverse: bool,
44
+ return_output=False,
45
+ src=None,
46
+ src_field=None):
47
+ """Turns predicted tokens within batch into smiles and evaluates their validity
48
+ Arguments:
49
+ array: Tensor with most probable token for each location for each sequence in batch
50
+ [trg len, batch size]
51
+ TRG: target field for getting tokens from vocab
52
+ reverse (bool): True if the target sequence is reversed
53
+ return_output (bool): True if output sequences and their validity should be saved
54
+ Returns:
55
+ df: dataframe with correct and incorrect sequences
56
+ valids: list with booleans that show if prediction was a valid SMILES (True) or invalid one (False)
57
+ smiless: list of the predicted smiles
58
+ """
59
+ trg_field = TRG
60
+ valids = []
61
+ smiless = []
62
+ if return_output:
63
+ df = pd.DataFrame()
64
+ else:
65
+ df = None
66
+ batch_size = array.size(1)
67
+ # check if the first token should be removed, first token is zero because
68
+ # outputs are initialized to all zeros
69
+ if int((array[0, 0]).tolist()) == 0:
70
+ start = 1
71
+ else:
72
+ start = 0
73
+ # for each sequence in the batch
74
+ for i in range(0, batch_size):
75
+ # turns sequence from tensor to list, skips first row as this is not
76
+ # filled in in forward
77
+ sequence = (array[start:, i]).tolist()
78
+ # goes from embedded to tokens
79
+ trg_tokens = [trg_field.vocab.itos[int(t)] for t in sequence]
80
+ # print(trg_tokens)
81
+ # takes all tokens until eos token, model would be faster if did this
82
+ # one step earlier, but then changes in vocab order would disrupt.
83
+ rev_tokens = list(
84
+ itertools.takewhile(lambda x: x != "<eos>", trg_tokens))
85
+ if reverse:
86
+ rev_tokens = rev_tokens[::-1]
87
+ smiles = "".join(rev_tokens)
88
+ # determine how many valid smiles are made
89
+ valid = True if MolFromSmiles(smiles) else False
90
+ valids.append(valid)
91
+ smiless.append(smiles)
92
+ if return_output:
93
+ if valid:
94
+ df.loc[i, "CORRECT"] = smiles
95
+ else:
96
+ df.loc[i, "INCORRECT"] = smiles
97
+
98
+ # add the original drugex outputs to the _de dataframe
99
+ if return_output and src is not None:
100
+ for i in range(0, batch_size):
101
+ # turns sequence from tensor to list, skips first row as this is
102
+ # <sos> for src
103
+ sequence = (src[1:, i]).tolist()
104
+ # goes from embedded to tokens
105
+ src_tokens = [src_field.vocab.itos[int(t)] for t in sequence]
106
+ # takes all tokens until eos token, model would be faster if did
107
+ # this one step earlier, but then changes in vocab order would
108
+ # disrupt.
109
+ rev_tokens = list(
110
+ itertools.takewhile(lambda x: x != "<eos>", src_tokens))
111
+ smiles = "".join(rev_tokens)
112
+ df.loc[i, "ORIGINAL"] = smiles
113
+
114
+ return df, valids, smiless
115
+
116
+
117
+ def is_unchanged(array,
118
+ TRG,
119
+ reverse: bool,
120
+ return_output=False,
121
+ src=None,
122
+ src_field=None):
123
+ """Checks is output is different from input
124
+ Arguments:
125
+ array: Tensor with most probable token for each location for each sequence in batch
126
+ [trg len, batch size]
127
+ TRG: target field for getting tokens from vocab
128
+ reverse (bool): True if the target sequence is reversed
129
+ return_output (bool): True if output sequences and their validity should be saved
130
+ Returns:
131
+ unchanged (int): number of invalid output sequences that are identical to their input
134
+ """
135
+ trg_field = TRG
136
+ sources = []
137
+ batch_size = array.size(1)
138
+ unchanged = 0
139
+
140
+ # check if the first token should be removed, first token is zero because
141
+ # outputs are initialized to all zeros
142
+ if int((array[0, 0]).tolist()) == 0:
143
+ start = 1
144
+ else:
145
+ start = 0
146
+
147
+ for i in range(0, batch_size):
148
+ # turns sequence from tensor to list, skips first row as this is <sos>
149
+ # for src
150
+ sequence = (src[1:, i]).tolist()
151
+ # goes from embedded to tokens
152
+ src_tokens = [src_field.vocab.itos[int(t)] for t in sequence]
153
+ # takes all tokens until eos token, model would be faster if did this
154
+ # one step earlier, but then changes in vocab order would disrupt.
155
+ rev_tokens = list(
156
+ itertools.takewhile(lambda x: x != "<eos>", src_tokens))
157
+ smiles = "".join(rev_tokens)
158
+ sources.append(smiles)
159
+
160
+ # for each sequence in the batch
161
+ for i in range(0, batch_size):
162
+ # turns sequence from tensor to list, skips first row as this is not
163
+ # filled in in forward
164
+ sequence = (array[start:, i]).tolist()
165
+ # goes from embedded to tokens
166
+ trg_tokens = [trg_field.vocab.itos[int(t)] for t in sequence]
167
+ # print(trg_tokens)
168
+ # takes all tokens until eos token, model would be faster if did this
169
+ # one step earlier, but then changes in vocab order would disrupt.
170
+ rev_tokens = list(
171
+ itertools.takewhile(lambda x: x != "<eos>", trg_tokens))
172
+ if reverse:
173
+ rev_tokens = rev_tokens[::-1]
174
+ smiles = "".join(rev_tokens)
175
+ # determine how many valid smiles are made
176
+ valid = True if MolFromSmiles(smiles) else False
177
+ if not valid:
178
+ if smiles == sources[i]:
179
+ unchanged += 1
180
+
181
+ return unchanged
182
+
183
+
184
+ def molecule_reconstruction(array, TRG, reverse: bool, outputs):
185
+ """Turns target tokens within batch into smiles and compares them to predicted output smiles
186
+ Arguments:
187
+ array: Tensor with target's token for each location for each sequence in batch
188
+ [trg len, batch size]
189
+ TRG: target field for getting tokens from vocab
190
+ reverse (bool): True if the target sequence is reversed
191
+ outputs: list of predicted SMILES sequences
192
+ Returns:
193
+ matches(int): number of total right molecules
194
+ """
195
+ trg_field = TRG
196
+ matches = 0
197
+ targets = []
198
+ batch_size = array.size(1)
199
+ # for each sequence in the batch
200
+ for i in range(0, batch_size):
201
+ # turns sequence from tensor to list, skips first row as this is not
202
+ # filled in in forward
203
+ sequence = (array[1:, i]).tolist()
204
+ # goes from embedded to tokens
205
+ trg_tokens = [trg_field.vocab.itos[int(t)] for t in sequence]
206
+ # takes all tokens until eos token, model would be faster if did this
207
+ # one step earlier, but then changes in vocab order would disrupt.
208
+ rev_tokens = list(
209
+ itertools.takewhile(lambda x: x != "<eos>", trg_tokens))
210
+ if reverse:
211
+ rev_tokens = rev_tokens[::-1]
212
+ smiles = "".join(rev_tokens)
213
+ targets.append(smiles)
214
+ for i in range(0, batch_size):
215
+ m = MolFromSmiles(targets[i])
216
+ p = MolFromSmiles(outputs[i])
217
+ if p is not None:
218
+ if m.HasSubstructMatch(p) and p.HasSubstructMatch(m):
219
+ matches += 1
220
+ return matches
221
+
222
+
223
+ def complexity_whitlock(mol: Chem.Mol, includeAllDescs=False):
224
+ """
225
+ Complexity as defined in DOI:10.1021/jo9814546
226
+ S: complexity = 4*#rings + 2*#unsat + #hetatm + 2*#chiral
227
+ Other descriptors:
228
+ H: size = #bonds (Hydrogen atoms included)
229
+ G: S + H
230
+ Ratio: S / H
231
+ """
232
+ mol_ = Chem.Mol(mol)
233
+ nrings = Lipinski.RingCount(mol_) - Lipinski.NumAromaticRings(mol_)
234
+ Chem.rdmolops.SetAromaticity(mol_)
235
+ unsat = sum(1 for bond in mol_.GetBonds()
236
+ if bond.GetBondTypeAsDouble() == 2)
237
+ hetatm = len(mol_.GetSubstructMatches(Chem.MolFromSmarts("[!#6]")))
238
+ AllChem.EmbedMolecule(mol_)
239
+ Chem.AssignAtomChiralTagsFromStructure(mol_)
240
+ chiral = len(Chem.FindMolChiralCenters(mol_))
241
+ S = 4 * nrings + 2 * unsat + hetatm + 2 * chiral
242
+ if not includeAllDescs:
243
+ return S
244
+ Chem.rdmolops.Kekulize(mol_)
245
+ mol_ = Chem.AddHs(mol_)
246
+ H = sum(bond.GetBondTypeAsDouble() for bond in mol_.GetBonds())
247
+ G = S + H
248
+ R = S / H
249
+ return {"WhitlockS": S, "WhitlockH": H, "WhitlockG": G, "WhitlockRatio": R}
250
+
251
+
252
+ def complexity_baronechanon(mol: Chem.Mol):
253
+ """
254
+ Complexity as defined in DOI:10.1021/ci000145p
255
+ """
256
+ mol_ = Chem.Mol(mol)
257
+ Chem.Kekulize(mol_)
258
+ Chem.RemoveStereochemistry(mol_)
259
+ mol_ = Chem.RemoveHs(mol_, updateExplicitCount=True)
260
+ degree, counts = 0, 0
261
+ for atom in mol_.GetAtoms():
262
+ degree += 3 * 2**(atom.GetExplicitValence() - atom.GetNumExplicitHs() -
263
+ 1)
264
+ counts += 3 if atom.GetSymbol() == "C" else 6
265
+ ringterm = sum(map(lambda x: 6 * len(x), mol_.GetRingInfo().AtomRings()))
266
+ return degree + counts + ringterm
267
+
268
+
269
+ def calc_complexity(array,
270
+ TRG,
271
+ reverse,
272
+ valids,
273
+ complexity_function=GraphDescriptors.BertzCT):
274
+ """Calculates the complexity of inputs that are not correct.
275
+ Arguments:
276
+ array: Tensor with target's token for each location for each sequence in batch
277
+ [trg len, batch size]
278
+ TRG: target field for getting tokens from vocab
279
+ reverse (bool): True if the target sequence is reversed
280
+ valids: list with booleans that show if prediction was a valid SMILES (True) or invalid one (False)
281
+ complexity_function: the type of complexity measure that will be used
282
+ GraphDescriptors.BertzCT
283
+ complexity_whitlock
284
+ complexity_baronechanon
285
+ Returns:
286
+ matches(int): mean of complexity values
287
+ """
288
+ trg_field = TRG
289
+ sources = []
290
+ complexities = []
291
+ loc = torch.BoolTensor(valids)
292
+ # only keeps rows in batch size dimension where valid is false
293
+ array = array[:, loc == False]
294
+ # should check if this still works
295
+ # array = torch.transpose(array, 0, 1)
296
+ array_size = array.size(1)
297
+ for i in range(0, array_size):
298
+ # turns sequence from tensor to list, skips first row as this is not
299
+ # filled in in forward
300
+ sequence = (array[1:, i]).tolist()
301
+ # goes from embedded to tokens
302
+ trg_tokens = [trg_field.vocab.itos[int(t)] for t in sequence]
303
+ # takes all tokens until eos token, model would be faster if did this
304
+ # one step earlier, but then changes in vocab order would disrupt.
305
+ rev_tokens = list(
306
+ itertools.takewhile(lambda x: x != "<eos>", trg_tokens))
307
+ if reverse:
308
+ rev_tokens = rev_tokens[::-1]
309
+ smiles = "".join(rev_tokens)
310
+ sources.append(smiles)
311
+ for source in sources:
312
+ try:
313
+ m = MolFromSmiles(source)
314
+ except BaseException:
315
+ m = MolFromSLN(source)
316
+ complexities.append(complexity_function(m))
317
+ if len(complexities) > 0:
318
+ mean = statistics.mean(complexities)
319
+ else:
320
+ mean = 0
321
+ return mean
322
+
323
+
324
+ def epoch_time(start_time, end_time):
325
+ elapsed_time = end_time - start_time
326
+ elapsed_mins = int(elapsed_time / 60)
327
+ elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
328
+ return elapsed_mins, elapsed_secs
329
+
330
+
331
+ class Convo:
332
+ """Class for training and evaluating transformer and convolutional neural network
333
+
334
+ Methods
335
+ -------
336
+ train_model()
337
+ train model for initialized number of epochs
338
+ evaluate(return_output)
339
+ use model with validation loader (& optionally drugex loader) to get test loss & other metrics
340
+ translate(loader)
341
+ translate inputs from loader (different from evaluate in that no target sequence is used)
342
+ """
343
+
344
+ def train_model(self):
345
+ optimizer = optim.Adam(self.parameters(), lr=self.lr)
346
+ log = open(f"{self.out}.log", "a")
347
+ best_error = np.inf
348
+ for epoch in range(self.epochs):
349
+ self.train()
350
+ start_time = time.time()
351
+ loss_train = 0
352
+ for i, batch in enumerate(self.loader_train):
353
+ optimizer.zero_grad()
354
+ # changed src,trg call to match with bentrevett
355
+ # src, trg = batch['src'], batch['trg']
356
+ trg = batch.trg
357
+ src = batch.src
358
+ output, attention = self(src, trg[:, :-1])
359
+ # feed the source and target into def forward to get the output
360
+ # Xuhan uses forward for this, with istrain = true
361
+ output_dim = output.shape[-1]
362
+ # changed
363
+ output = output.contiguous().view(-1, output_dim)
364
+ trg = trg[:, 1:].contiguous().view(-1)
365
+ # output = output[:,:,0]#.view(-1)
366
+ # output = output[1:].view(-1, output.shape[-1])
367
+ # trg = trg[1:].view(-1)
368
+ loss = nn.CrossEntropyLoss(
369
+ ignore_index=self.TRG.vocab.stoi[self.TRG.pad_token])
370
+ a, b = output.view(-1), trg.to(self.device).view(-1)
371
+ # changed
372
+ # loss = loss(output.view(0), trg.view(0).to(device))
373
+ loss = loss(output, trg)
374
+ loss.backward()
375
+ torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip)
376
+ optimizer.step()
377
+ loss_train += loss.item()
378
+ # turned off for now, as not using voc so won't work, output is a tensor
379
+ # output = [(trg len - 1) * batch size, output dim]
380
+ # smiles, valid = is_valid_smiles(output, reversed)
381
+ # if valid:
382
+ # valids += 1
383
+ # smiless.append(smiles)
384
+ # added .dataset because len(iterator) gives len(self.dataset) /
385
+ # self.batch_size)
386
+ loss_train /= len(self.loader_train)
387
+ info = f"Epoch: {epoch+1:02} step: {i} loss_train: {loss_train:.4g}"
388
+ # model is used to generate trg based on src from the validation set to assess performance
389
+ # similar to Xuhan, although he doesn't use the if loop
390
+ if self.loader_valid is not None:
391
+ return_output = False
392
+ if epoch + 1 == self.epochs:
393
+ return_output = True
394
+ (
395
+ valids,
396
+ loss_valid,
397
+ valids_de,
398
+ df_output,
399
+ df_output_de,
400
+ right_molecules,
401
+ complexity,
402
+ unchanged,
403
+ unchanged_de,
404
+ ) = self.evaluate(return_output)
405
+ reconstruction_error = 1 - right_molecules / len(
406
+ self.loader_valid.dataset)
407
+ error = 1 - valids / len(self.loader_valid.dataset)
408
+ complexity = complexity / len(self.loader_valid)
409
+ unchan = unchanged / (len(self.loader_valid.dataset) - valids)
410
+ info += f" loss_valid: {loss_valid:.4g} error_rate: {error:.4g} molecule_reconstruction_error_rate: {reconstruction_error:.4g} unchanged: {unchan:.4g} invalid_target_complexity: {complexity:.4g}"
411
+ if self.loader_drugex is not None:
412
+ error_de = 1 - valids_de / len(self.loader_drugex.dataset)
413
+ unchan_de = unchanged_de / (
414
+ len(self.loader_drugex.dataset) - valids_de)
415
+ info += f" error_rate_drugex: {error_de:.4g} unchanged_drugex: {unchan_de:.4g}"
416
+
417
+ if reconstruction_error < best_error:
418
+ torch.save(self.state_dict(), f"{self.out}.pkg")
419
+ best_error = reconstruction_error
420
+ last_save = epoch
421
+ else:
422
+ if epoch - last_save >= 10 and best_error != 1:
423
+ torch.save(self.state_dict(), f"{self.out}_last.pkg")
424
+ (
425
+ valids,
426
+ loss_valid,
427
+ valids_de,
428
+ df_output,
429
+ df_output_de,
430
+ right_molecules,
431
+ complexity,
432
+ unchanged,
433
+ unchanged_de,
434
+ ) = self.evaluate(True)
435
+ end_time = time.time()
436
+ epoch_mins, epoch_secs = epoch_time(
437
+ start_time, end_time)
438
+ info += f" Time: {epoch_mins}m {epoch_secs}s"
439
+
440
+ break
441
+ elif error < best_error:
442
+ torch.save(self.state_dict(), f"{self.out}.pkg")
443
+ best_error = error
444
+ end_time = time.time()
445
+ epoch_mins, epoch_secs = epoch_time(start_time, end_time)
446
+ info += f" Time: {epoch_mins}m {epoch_secs}s"
447
+
448
+
449
+ torch.save(self.state_dict(), f"{self.out}_last.pkg")
450
+ log.close()
451
+ self.load_state_dict(torch.load(f"{self.out}.pkg"))
452
+ df_output.to_csv(f"{self.out}.csv", index=False)
453
+ df_output_de.to_csv(f"{self.out}_de.csv", index=False)
454
+
455
+ def evaluate(self, return_output):
456
+ self.eval()
457
+ test_loss = 0
458
+ df_output = pd.DataFrame()
459
+ df_output_de = pd.DataFrame()
460
+ valids = 0
461
+ valids_de = 0
462
+ unchanged = 0
463
+ unchanged_de = 0
464
+ right_molecules = 0
465
+ complexity = 0
466
+ with torch.no_grad():
467
+ for _, batch in enumerate(self.loader_valid):
468
+ trg = batch.trg
469
+ src = batch.src
470
+ output, attention = self.forward(src, trg[:, :-1])
471
+ pred_token = output.argmax(2)
472
+ array = torch.transpose(pred_token, 0, 1)
473
+ trg_trans = torch.transpose(trg, 0, 1)
474
+ output_dim = output.shape[-1]
475
+ output = output.contiguous().view(-1, output_dim)
476
+ trg = trg[:, 1:].contiguous().view(-1)
477
+ src_trans = torch.transpose(src, 0, 1)
478
+ df_batch, valid, smiless = is_smiles(
479
+ array, self.TRG, reverse=True, return_output=return_output)
480
+ unchanged += is_unchanged(
481
+ array,
482
+ self.TRG,
483
+ reverse=True,
484
+ return_output=return_output,
485
+ src=src_trans,
486
+ src_field=self.SRC,
487
+ )
488
+ matches = molecule_reconstruction(trg_trans,
489
+ self.TRG,
490
+ reverse=True,
491
+ outputs=smiless)
492
+ complexity += calc_complexity(trg_trans,
493
+ self.TRG,
494
+ reverse=True,
495
+ valids=valid)
496
+ if df_batch is not None:
497
+ df_output = pd.concat([df_output, df_batch],
498
+ ignore_index=True)
499
+ right_molecules += matches
500
+ valids += sum(valid)
501
+ # trg = trg[1:].view(-1)
502
+ # output, trg = output[1:].view(-1, output.shape[-1]), trg[1:].view(-1)
503
+ loss = nn.CrossEntropyLoss(
504
+ ignore_index=self.TRG.vocab.stoi[self.TRG.pad_token])
505
+ loss = loss(output, trg)
506
+ test_loss += loss.item()
507
+ if self.loader_drugex is not None:
508
+ for _, batch in enumerate(self.loader_drugex):
509
+ src = batch.src
510
+ output = self.translate_sentence(src, self.TRG,
511
+ self.device)
512
+ # checks the number of valid smiles
513
+ pred_token = output.argmax(2)
514
+ array = torch.transpose(pred_token, 0, 1)
515
+ src_trans = torch.transpose(src, 0, 1)
516
+ df_batch, valid, smiless = is_smiles(
517
+ array,
518
+ self.TRG,
519
+ reverse=True,
520
+ return_output=return_output,
521
+ src=src_trans,
522
+ src_field=self.SRC,
523
+ )
524
+ unchanged_de += is_unchanged(
525
+ array,
526
+ self.TRG,
527
+ reverse=True,
528
+ return_output=return_output,
529
+ src=src_trans,
530
+ src_field=self.SRC,
531
+ )
532
+ if df_batch is not None:
533
+ df_output_de = pd.concat([df_output_de, df_batch],
534
+ ignore_index=True)
535
+ valids_de += sum(valid)
536
+ return (
537
+ valids,
538
+ test_loss / len(self.loader_valid),
539
+ valids_de,
540
+ df_output,
541
+ df_output_de,
542
+ right_molecules,
543
+ complexity,
544
+ unchanged,
545
+ unchanged_de,
546
+ )
547
+
548
+ def translate(self, loader):
549
+ self.eval()
550
+ df_output_de = pd.DataFrame()
551
+ valids_de = 0
552
+ with torch.no_grad():
553
+ for _, batch in enumerate(loader):
554
+ src = batch.src
555
+ output = self.translate_sentence(src, self.TRG, self.device)
556
+ # checks the number of valid smiles
557
+ pred_token = output.argmax(2)
558
+ array = torch.transpose(pred_token, 0, 1)
559
+ src_trans = torch.transpose(src, 0, 1)
560
+ df_batch, valid, smiless = is_smiles(
561
+ array,
562
+ self.TRG,
563
+ reverse=True,
564
+ return_output=True,
565
+ src=src_trans,
566
+ src_field=self.SRC,
567
+ )
568
+ if df_batch is not None:
569
+ df_output_de = pd.concat([df_output_de, df_batch],
570
+ ignore_index=True)
571
+ valids_de += sum(valid)
572
+ return valids_de, df_output_de
573
+
574
+
575
+ class Encoder(nn.Module):
576
+
577
+ def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout,
578
+ max_length, device):
579
+ super().__init__()
580
+ self.device = device
581
+ self.tok_embedding = nn.Embedding(input_dim, hid_dim)
582
+ self.pos_embedding = nn.Embedding(max_length, hid_dim)
583
+ self.layers = nn.ModuleList([
584
+ EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
585
+ for _ in range(n_layers)
586
+ ])
587
+
588
+ self.dropout = nn.Dropout(dropout)
589
+ self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
590
+
591
+ def forward(self, src, src_mask):
592
+ # src = [batch size, src len]
593
+ # src_mask = [batch size, src len]
594
+ batch_size = src.shape[0]
595
+ src_len = src.shape[1]
596
+ pos = (torch.arange(0, src_len).unsqueeze(0).repeat(batch_size,
597
+ 1).to(self.device))
598
+ # pos = [batch size, src len]
599
+ src = self.dropout((self.tok_embedding(src) * self.scale) +
600
+ self.pos_embedding(pos))
601
+ # src = [batch size, src len, hid dim]
602
+ for layer in self.layers:
603
+ src = layer(src, src_mask)
604
+ # src = [batch size, src len, hid dim]
605
+ return src
606
+
607
+
608
+ class EncoderLayer(nn.Module):
609
+
610
+ def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
611
+ super().__init__()
612
+
613
+ self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
614
+ self.ff_layer_norm = nn.LayerNorm(hid_dim)
615
+ self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads,
616
+ dropout, device)
617
+ self.positionwise_feedforward = PositionwiseFeedforwardLayer(
618
+ hid_dim, pf_dim, dropout)
619
+ self.dropout = nn.Dropout(dropout)
620
+
621
+ def forward(self, src, src_mask):
622
+ # src = [batch size, src len, hid dim]
623
+ # src_mask = [batch size, src len]
624
+ # self attention
625
+ _src, _ = self.self_attention(src, src, src, src_mask)
626
+ # dropout, residual connection and layer norm
627
+ src = self.self_attn_layer_norm(src + self.dropout(_src))
628
+ # src = [batch size, src len, hid dim]
629
+ # positionwise feedforward
630
+ _src = self.positionwise_feedforward(src)
631
+ # dropout, residual and layer norm
632
+ src = self.ff_layer_norm(src + self.dropout(_src))
633
+ # src = [batch size, src len, hid dim]
634
+
635
+ return src
636
+
637
+
638
+ class MultiHeadAttentionLayer(nn.Module):
639
+
640
+ def __init__(self, hid_dim, n_heads, dropout, device):
641
+ super().__init__()
642
+ assert hid_dim % n_heads == 0
643
+ self.hid_dim = hid_dim
644
+ self.n_heads = n_heads
645
+ self.head_dim = hid_dim // n_heads
646
+ self.fc_q = nn.Linear(hid_dim, hid_dim)
647
+ self.fc_k = nn.Linear(hid_dim, hid_dim)
648
+ self.fc_v = nn.Linear(hid_dim, hid_dim)
649
+ self.fc_o = nn.Linear(hid_dim, hid_dim)
650
+ self.dropout = nn.Dropout(dropout)
651
+ self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
652
+
653
+ def forward(self, query, key, value, mask=None):
654
+ batch_size = query.shape[0]
655
+ # query = [batch size, query len, hid dim]
656
+ # key = [batch size, key len, hid dim]
657
+ # value = [batch size, value len, hid dim]
658
+ Q = self.fc_q(query)
659
+ K = self.fc_k(key)
660
+ V = self.fc_v(value)
661
+ # Q = [batch size, query len, hid dim]
662
+ # K = [batch size, key len, hid dim]
663
+ # V = [batch size, value len, hid dim]
664
+ Q = Q.view(batch_size, -1, self.n_heads,
665
+ self.head_dim).permute(0, 2, 1, 3)
666
+ K = K.view(batch_size, -1, self.n_heads,
667
+ self.head_dim).permute(0, 2, 1, 3)
668
+ V = V.view(batch_size, -1, self.n_heads,
669
+ self.head_dim).permute(0, 2, 1, 3)
670
+ # Q = [batch size, n heads, query len, head dim]
671
+ # K = [batch size, n heads, key len, head dim]
672
+ # V = [batch size, n heads, value len, head dim]
673
+ energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
674
+ # energy = [batch size, n heads, query len, key len]
675
+ if mask is not None:
676
+ energy = energy.masked_fill(mask == 0, -1e10)
677
+ attention = torch.softmax(energy, dim=-1)
678
+ # attention = [batch size, n heads, query len, key len]
679
+ x = torch.matmul(self.dropout(attention), V)
680
+ # x = [batch size, n heads, query len, head dim]
681
+ x = x.permute(0, 2, 1, 3).contiguous()
682
+ # x = [batch size, query len, n heads, head dim]
683
+ x = x.view(batch_size, -1, self.hid_dim)
684
+ # x = [batch size, query len, hid dim]
685
+ x = self.fc_o(x)
686
+ # x = [batch size, query len, hid dim]
687
+ return x, attention
688
+
689
+
690
+ class PositionwiseFeedforwardLayer(nn.Module):
691
+
692
+ def __init__(self, hid_dim, pf_dim, dropout):
693
+ super().__init__()
694
+ self.fc_1 = nn.Linear(hid_dim, pf_dim)
695
+ self.fc_2 = nn.Linear(pf_dim, hid_dim)
696
+ self.dropout = nn.Dropout(dropout)
697
+
698
+ def forward(self, x):
699
+ # x = [batch size, seq len, hid dim]
700
+ x = self.dropout(torch.relu(self.fc_1(x)))
701
+ # x = [batch size, seq len, pf dim]
702
+ x = self.fc_2(x)
703
+ # x = [batch size, seq len, hid dim]
704
+
705
+ return x
706
+
707
+
708
+ class Decoder(nn.Module):
709
+
710
+ def __init__(
711
+ self,
712
+ output_dim,
713
+ hid_dim,
714
+ n_layers,
715
+ n_heads,
716
+ pf_dim,
717
+ dropout,
718
+ max_length,
719
+ device,
720
+ ):
721
+ super().__init__()
722
+ self.device = device
723
+ self.tok_embedding = nn.Embedding(output_dim, hid_dim)
724
+ self.pos_embedding = nn.Embedding(max_length, hid_dim)
725
+ self.layers = nn.ModuleList([
726
+ DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
727
+ for _ in range(n_layers)
728
+ ])
729
+ self.fc_out = nn.Linear(hid_dim, output_dim)
730
+ self.dropout = nn.Dropout(dropout)
731
+ self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)
732
+
733
+ def forward(self, trg, enc_src, trg_mask, src_mask):
734
+ # trg = [batch size, trg len]
735
+ # enc_src = [batch size, src len, hid dim]
736
+ # trg_mask = [batch size, trg len]
737
+ # src_mask = [batch size, src len]
738
+ batch_size = trg.shape[0]
739
+ trg_len = trg.shape[1]
740
+ pos = (torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size,
741
+ 1).to(self.device))
742
+ # pos = [batch size, trg len]
743
+ trg = self.dropout((self.tok_embedding(trg) * self.scale) +
744
+ self.pos_embedding(pos))
745
+ # trg = [batch size, trg len, hid dim]
746
+ for layer in self.layers:
747
+ trg, attention = layer(trg, enc_src, trg_mask, src_mask)
748
+ # trg = [batch size, trg len, hid dim]
749
+ # attention = [batch size, n heads, trg len, src len]
750
+ output = self.fc_out(trg)
751
+ # output = [batch size, trg len, output dim]
752
+ return output, attention
753
+
754
+
755
+ class DecoderLayer(nn.Module):
756
+
757
+ def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
758
+ super().__init__()
759
+ self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
760
+ self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
761
+ self.ff_layer_norm = nn.LayerNorm(hid_dim)
762
+ self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads,
763
+ dropout, device)
764
+ self.encoder_attention = MultiHeadAttentionLayer(
765
+ hid_dim, n_heads, dropout, device)
766
+ self.positionwise_feedforward = PositionwiseFeedforwardLayer(
767
+ hid_dim, pf_dim, dropout)
768
+ self.dropout = nn.Dropout(dropout)
769
+
770
+ def forward(self, trg, enc_src, trg_mask, src_mask):
771
+ # trg = [batch size, trg len, hid dim]
772
+ # enc_src = [batch size, src len, hid dim]
773
+ # trg_mask = [batch size, trg len]
774
+ # src_mask = [batch size, src len]
775
+ # self attention
776
+ _trg, _ = self.self_attention(trg, trg, trg, trg_mask)
777
+ # dropout, residual connection and layer norm
778
+ trg = self.self_attn_layer_norm(trg + self.dropout(_trg))
779
+ # trg = [batch size, trg len, hid dim]
780
+ # encoder attention
781
+ _trg, attention = self.encoder_attention(trg, enc_src, enc_src,
782
+ src_mask)
783
+ # dropout, residual connection and layer norm
784
+ trg = self.enc_attn_layer_norm(trg + self.dropout(_trg))
785
+ # trg = [batch size, trg len, hid dim]
786
+ # positionwise feedforward
787
+ _trg = self.positionwise_feedforward(trg)
788
+ # dropout, residual and layer norm
789
+ trg = self.ff_layer_norm(trg + self.dropout(_trg))
790
+ # trg = [batch size, trg len, hid dim]
791
+ # attention = [batch size, n heads, trg len, src len]
792
+ return trg, attention
793
+
794
+
795
+ class Seq2Seq(nn.Module, Convo):
796
+
797
+ def __init__(
798
+ self,
799
+ encoder,
800
+ decoder,
801
+ src_pad_idx,
802
+ trg_pad_idx,
803
+ device,
804
+ loader_train: DataLoader,
805
+ out: str,
806
+ loader_valid=None,
807
+ loader_drugex=None,
808
+ epochs=100,
809
+ lr=0.0005,
810
+ clip=0.1,
811
+ reverse=True,
812
+ TRG=None,
813
+ SRC=None,
814
+ ):
815
+ super().__init__()
816
+ self.encoder = encoder
817
+ self.decoder = decoder
818
+ self.src_pad_idx = src_pad_idx
819
+ self.trg_pad_idx = trg_pad_idx
820
+ self.device = device
821
+ self.loader_train = loader_train
822
+ self.out = out
823
+ self.loader_valid = loader_valid
824
+ self.loader_drugex = loader_drugex
825
+ self.epochs = epochs
826
+ self.lr = lr
827
+ self.clip = clip
828
+ self.reverse = reverse
829
+ self.TRG = TRG
830
+ self.SRC = SRC
831
+
832
+ def make_src_mask(self, src):
833
+ # src = [batch size, src len]
834
+ src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
835
+ # src_mask = [batch size, 1, 1, src len]
836
+ return src_mask
837
+
838
+ def make_trg_mask(self, trg):
839
+ # trg = [batch size, trg len]
840
+ trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
841
+ # trg_pad_mask = [batch size, 1, 1, trg len]
842
+ trg_len = trg.shape[1]
843
+ trg_sub_mask = torch.tril(
844
+ torch.ones((trg_len, trg_len), device=self.device)).bool()
845
+ # trg_sub_mask = [trg len, trg len]
846
+ trg_mask = trg_pad_mask & trg_sub_mask
847
+ # trg_mask = [batch size, 1, trg len, trg len]
848
+ return trg_mask
849
+
850
+ def forward(self, src, trg):
851
+ # src = [batch size, src len]
852
+ # trg = [batch size, trg len]
853
+ src_mask = self.make_src_mask(src)
854
+ trg_mask = self.make_trg_mask(trg)
855
+ # src_mask = [batch size, 1, 1, src len]
856
+ # trg_mask = [batch size, 1, trg len, trg len]
857
+ enc_src = self.encoder(src, src_mask)
858
+ # enc_src = [batch size, src len, hid dim]
859
+ output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
860
+ # output = [batch size, trg len, output dim]
861
+ # attention = [batch size, n heads, trg len, src len]
862
+ return output, attention
863
+
864
+ def translate_sentence(self, src, trg_field, device, max_len=202):
865
+ self.eval()
866
+ src_mask = self.make_src_mask(src)
867
+ with torch.no_grad():
868
+ enc_src = self.encoder(src, src_mask)
869
+ trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]
870
+ batch_size = src.shape[0]
871
+ trg = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
872
+ trg = trg.repeat(batch_size, 1)
873
+ for i in range(max_len):
874
+ # turned model into self.
875
+ trg_mask = self.make_trg_mask(trg)
876
+ with torch.no_grad():
877
+ output, attention = self.decoder(trg, enc_src, trg_mask,
878
+ src_mask)
879
+ pred_tokens = output.argmax(2)[:, -1].unsqueeze(1)
880
+ trg = torch.cat((trg, pred_tokens), 1)
881
+
882
+ return output
883
+
884
+
885
+ def remove_floats(df: pd.DataFrame, subset: str):
886
+ """Preprocessing step to remove any entries that are not strings"""
887
+ df_subset = df[subset]
888
+ df[subset] = df[subset].astype(str)
889
+ # only keep entries that stayed the same after applying astype str
890
+ df = df[df[subset] == df_subset].copy()
891
+
892
+ return df
893
+
894
+
895
+ def smi_tokenizer(smi: str, reverse=False) -> list:
896
+ """
897
+ Tokenize a SMILES molecule
898
+ """
899
+ pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
900
+ regex = re.compile(pattern)
901
+ # tokens = ['<sos>'] + [token for token in regex.findall(smi)] + ['<eos>']
902
+ tokens = [token for token in regex.findall(smi)]
903
+ # assert smi == ''.join(tokens[1:-1])
904
+ assert smi == "".join(tokens[:])
905
+ # try:
906
+ # assert smi == "".join(tokens[:])
907
+ # except:
908
+ # print(smi)
909
+ # print("".join(tokens[:]))
910
+ if reverse:
911
+ return tokens[::-1]
912
+ return tokens
913
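For illustration, the tokenizer splits a SMILES string into atom, bond, and ring-closure tokens and can return them reversed for the target field; a small sketch of the expected behaviour:

tokens = smi_tokenizer("CC(=O)Oc1ccccc1C(=O)O")
# ['C', 'C', '(', '=', 'O', ')', 'O', 'c', '1', 'c', 'c', 'c', 'c', 'c', '1', 'C', '(', '=', 'O', ')', 'O']
assert "".join(tokens) == "CC(=O)Oc1ccccc1C(=O)O"
assert smi_tokenizer("CC(=O)Oc1ccccc1C(=O)O", reverse=True) == tokens[::-1]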
+
914
+
915
+ def init_weights(m: nn.Module):
916
+ if hasattr(m, "weight") and m.weight.dim() > 1:
917
+ nn.init.xavier_uniform_(m.weight.data)
918
+
919
+
920
+ def count_parameters(model: nn.Module):
921
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
922
+
923
+
924
+ def epoch_time(start_time, end_time):
925
+ elapsed_time = end_time - start_time
926
+ elapsed_mins = int(elapsed_time / 60)
927
+ elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
928
+ return elapsed_mins, elapsed_secs
929
+
930
+
931
+ def initialize_model(folder_out: str,
932
+ data_source: str,
933
+ error_source: str,
934
+ device: torch.device,
935
+ threshold: int,
936
+ epochs: int,
937
+ layers: int = 3,
938
+ batch_size: int = 16,
939
+ invalid_type: str = "all",
940
+ num_errors: int = 1,
941
+ validation_step=False):
942
+ """Create encoder decoder models for specified model (currently only translator) & type of invalid SMILES
943
+
944
+ param data: collection of invalid, valid SMILES pairs
945
+ param invalid_smiles_path: path to previously generated invalid SMILES
946
+ param invalid_type: type of errors introduced into invalid SMILES
947
+
948
+ return:
949
+
950
+ """
951
+
952
+ # set fields
953
+ SRC = Field(
954
+ tokenize=lambda x: smi_tokenizer(x),
955
+ init_token="<sos>",
956
+ eos_token="<eos>",
957
+ batch_first=True,
958
+ )
959
+ TRG = Field(
960
+ tokenize=lambda x: smi_tokenizer(x, reverse=True),
961
+ init_token="<sos>",
962
+ eos_token="<eos>",
963
+ batch_first=True,
964
+ )
965
+
966
+ if validation_step:
967
+ train, val = TabularDataset.splits(
968
+ path=f'{folder_out}errors/split/',
969
+ train=f"{data_source}_{invalid_type}_{num_errors}_errors_train.csv",
970
+ validation=
971
+ f"{data_source}_{invalid_type}_{num_errors}_errors_dev.csv",
972
+ format="CSV",
973
+ skip_header=False,
974
+ fields={
975
+ "ERROR": ("src", SRC),
976
+ "STD_SMILES": ("trg", TRG)
977
+ },
978
+ )
979
+ SRC.build_vocab(train, val, max_size=1000)
980
+ TRG.build_vocab(train, val, max_size=1000)
981
+ else:
982
+ train = TabularDataset(
983
+ path=
984
+ f'{folder_out}{data_source}_{invalid_type}_{num_errors}_errors.csv',
985
+ format="CSV",
986
+ skip_header=False,
987
+ fields={
988
+ "ERROR": ("src", SRC),
989
+ "STD_SMILES": ("trg", TRG)
990
+ },
991
+ )
992
+ SRC.build_vocab(train, max_size=1000)
993
+ TRG.build_vocab(train, max_size=1000)
994
+
995
+ drugex = TabularDataset(
996
+ path=error_source,
997
+ format="csv",
998
+ skip_header=False,
999
+ fields={
1000
+ "SMILES": ("src", SRC),
1001
+ "SMILES_TARGET": ("trg", TRG)
1002
+ },
1003
+ )
1004
+
1005
+
1006
+ #SRC.vocab = torch.load('vocab_src.pth')
1007
+ #TRG.vocab = torch.load('vocab_trg.pth')
1008
+
1009
+ # model parameters
1010
+ EPOCHS = epochs
1011
+ BATCH_SIZE = batch_size
1012
+ INPUT_DIM = len(SRC.vocab)
1013
+ OUTPUT_DIM = len(TRG.vocab)
1014
+ HID_DIM = 256
1015
+ ENC_LAYERS = layers
1016
+ DEC_LAYERS = layers
1017
+ ENC_HEADS = 8
1018
+ DEC_HEADS = 8
1019
+ ENC_PF_DIM = 512
1020
+ DEC_PF_DIM = 512
1021
+ ENC_DROPOUT = 0.1
1022
+ DEC_DROPOUT = 0.1
1023
+ SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
1024
+ TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
1025
+ # add 2 to length for start and stop tokens
1026
+ MAX_LENGTH = threshold + 2
1027
+
1028
+ # model name
1029
+ MODEL_OUT_FOLDER = f"{folder_out}"
1030
+
1031
+ MODEL_NAME = "transformer_%s_%s_%s_%s_%s" % (
1032
+ invalid_type, num_errors, data_source, BATCH_SIZE, layers)
1033
+ if not os.path.exists(MODEL_OUT_FOLDER):
1034
+ os.mkdir(MODEL_OUT_FOLDER)
1035
+
1036
+ out = os.path.join(MODEL_OUT_FOLDER, MODEL_NAME)
1037
+
1038
+ torch.save(SRC.vocab, f'{out}_vocab_src.pth')
1039
+ torch.save(TRG.vocab, f'{out}_vocab_trg.pth')
1040
+
1041
+ # iterator is a dataloader
1042
+ # iterator to pass to the same length and create batches in which the
1043
+ # amount of padding is minimized
1044
+ if validation_step:
1045
+ train_iter, val_iter = BucketIterator.splits(
1046
+ (train, val),
1047
+ batch_sizes=(BATCH_SIZE, 256),
1048
+ sort_within_batch=True,
1049
+ shuffle=True,
1050
+ # the BucketIterator needs to be told what function it should use to
1051
+ # group the data.
1052
+ sort_key=lambda x: len(x.src),
1053
+ device=device,
1054
+ )
1055
+ else:
1056
+ train_iter = BucketIterator(
1057
+ train,
1058
+ batch_size=BATCH_SIZE,
1059
+ sort_within_batch=True,
1060
+ shuffle=True,
1061
+ # the BucketIterator needs to be told what function it should use to
1062
+ # group the data.
1063
+ sort_key=lambda x: len(x.src),
1064
+ device=device,
1065
+ )
1066
+ val_iter = None
1067
+
1068
+ drugex_iter = Iterator(
1069
+ drugex,
1070
+ batch_size=64,
1071
+ device=device,
1072
+ sort=False,
1073
+ sort_within_batch=True,
1074
+ sort_key=lambda x: len(x.src),
1075
+ repeat=False,
1076
+ )
1077
+
1078
+
1079
+ # model initialization
1080
+
1081
+ enc = Encoder(
1082
+ INPUT_DIM,
1083
+ HID_DIM,
1084
+ ENC_LAYERS,
1085
+ ENC_HEADS,
1086
+ ENC_PF_DIM,
1087
+ ENC_DROPOUT,
1088
+ MAX_LENGTH,
1089
+ device,
1090
+ )
1091
+ dec = Decoder(
1092
+ OUTPUT_DIM,
1093
+ HID_DIM,
1094
+ DEC_LAYERS,
1095
+ DEC_HEADS,
1096
+ DEC_PF_DIM,
1097
+ DEC_DROPOUT,
1098
+ MAX_LENGTH,
1099
+ device,
1100
+ )
1101
+
1102
+ model = Seq2Seq(
1103
+ enc,
1104
+ dec,
1105
+ SRC_PAD_IDX,
1106
+ TRG_PAD_IDX,
1107
+ device,
1108
+ train_iter,
1109
+ out=out,
1110
+ loader_valid=val_iter,
1111
+ loader_drugex=drugex_iter,
1112
+ epochs=EPOCHS,
1113
+ TRG=TRG,
1114
+ SRC=SRC,
1115
+ ).to(device)
1116
+
1117
+
1118
+
1119
+
1120
+ return model, out, SRC
1121
+
1122
+
1123
+ def train_model(model, out, assess):
1124
+ """Apply given weights (& assess performance or train further) or start training new model
1125
+
1126
+ Args:
1127
+ model: initialized model
1128
+ out: .pkg file with model parameters
1129
+ assess (bool): if True, load the given weights and only evaluate the model
1130
+
1131
+ Returns:
1132
+ model with (new) weights
1133
+ """
1134
+
1135
+ if os.path.exists(f"{out}.pkg") and assess:
1136
+
1137
+
1138
+ model.load_state_dict(torch.load(f=out + ".pkg"))
1139
+ (
1140
+ valids,
1141
+ loss_valid,
1142
+ valids_de,
1143
+ df_output,
1144
+ df_output_de,
1145
+ right_molecules,
1146
+ complexity,
1147
+ unchanged,
1148
+ unchanged_de,
1149
+ ) = model.evaluate(True)
1150
+
1151
+
1152
+ # log = open('unchanged.log', 'a')
1153
+ # info = f'type: comb unchanged: {unchan:.4g} unchanged_drugex: {unchan_de:.4g}'
1154
+ # print(info, file=log, flush = True)
1155
+ # print(valids_de)
1156
+ # print(unchanged_de)
1157
+
1158
+ # print(unchan)
1159
+ # print(unchan_de)
1160
+ # df_output_de.to_csv(f'{out}_de_new.csv', index = False)
1161
+
1162
+ # error_de = 1 - valids_de / len(drugex_iter.dataset)
1163
+ # print(error_de)
1164
+ # df_output.to_csv(f'{out}_par.csv', index = False)
1165
+
1166
+ elif os.path.exists(f"{out}.pkg"):
1167
+
1168
+ # starts from the model after the last epoch, not the best epoch
1169
+ model.load_state_dict(torch.load(f=out + "_last.pkg"))
1170
+ # need to change how log file names epochs
1171
+ model.train_model()
1172
+ else:
1173
+
1174
+ model = model.apply(init_weights)
1175
+ model.train_model()
1176
+
1177
+ return model
1178
+
1179
+
1180
+ def correct_SMILES(model, out, error_source, device, SRC):
1181
+ """Model that is given corrects SMILES and return number of correct ouputs and dataframe containing all outputs
1182
+ Args:
1183
+ model: initialized model
1184
+ out: .pkg file with model parameters
1185
+ error_source: path to a CSV file with the SMILES to be corrected
+ device: torch device to run the model on
+ SRC: source field used to tokenize the input SMILES
1186
+
1187
+ Returns:
1188
+ valids: number of fixed outputs
1189
+ df_output: dataframe containing output (either correct or incorrect) & original input
1190
+ """
1191
+ ## account for tokens that are not yet in SRC without changing existing SRC token embeddings
1192
+ errors = TabularDataset(
1193
+ path=error_source,
1194
+ format="csv",
1195
+ skip_header=False,
1196
+ fields={"SMILES": ("src", SRC)},
1197
+ )
1198
+
1199
+ errors_loader = Iterator(
1200
+ errors,
1201
+ batch_size=64,
1202
+ device=device,
1203
+ sort=False,
1204
+ sort_within_batch=True,
1205
+ sort_key=lambda x: len(x.src),
1206
+ repeat=False,
1207
+ )
1208
+ model.load_state_dict(torch.load(f=out + ".pkg",map_location=torch.device('cpu')))
1209
+ # add option to use different iterator maybe?
1210
+
1211
+ valids, df_output = model.translate(errors_loader)
1212
+ #df_output.to_csv(f"{error_source}_fixed.csv", index=False)
1213
+
1214
+
1215
+ return valids, df_output
1216
+
1217
+
1218
+
1219
+ class smi_correct(object):
1220
+ def __init__(self, model_name, trans_file_path):
1221
+ # set random seed, used for error generation & initiation transformer
1222
+
1223
+ self.SEED = 42
1224
+ random.seed(self.SEED)
1225
+ self.model_name = model_name
1226
+ self.folder_out = "data/"
1227
+
1228
+ self.trans_file_path = trans_file_path
1229
+
1230
+ if not os.path.exists(self.folder_out):
1231
+ os.makedirs(self.folder_out)
1232
+
1233
+ self.invalid_type = 'multiple'
1234
+ self.num_errors = 12
1235
+ self.threshold = 200
1236
+ self.data_source = f"PAPYRUS_{self.threshold}"
1237
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
1238
+ self.initialize_source = 'data/papyrus_rnn_S.csv' # change this path
1239
+
1240
+ def standardization_pipeline(self, smile):
1241
+ desalter = MolStandardize.rdMolStandardize.LargestFragmentChooser()
1242
+ std_smile = None
1243
+ if not isinstance(smile, str): return None
1244
+ m = Chem.MolFromSmiles(smile)
1245
+ # skips smiles for which no mol file could be generated
1246
+ if m is not None:
1247
+ # standardizes
1248
+ std_m = standardizer.standardize_mol(m)
1249
+ # strips salts
1250
+ std_m_p, exclude = standardizer.get_parent_mol(std_m)
1251
+ if not exclude:
1252
+ # choose largest fragment for rare cases where chembl structure
1253
+ # pipeline leaves 2 fragments
1254
+ std_m_p_d = desalter.choose(std_m_p)
1255
+ std_smile = Chem.MolToSmiles(std_m_p_d)
1256
+ return std_smile
1257
+
1258
+ def remove_smiles_duplicates(self, dataframe: pd.DataFrame,
1259
+ subset: str) -> pd.DataFrame:
1260
+ return dataframe.drop_duplicates(subset=subset)
1261
+
1262
+ def correct(self, smi):
1263
+
1264
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1265
+
1266
+ model, out, SRC = initialize_model(self.folder_out,
1267
+ self.data_source,
1268
+ error_source=self.initialize_source,
1269
+ device=device,
1270
+ threshold=self.threshold,
1271
+ epochs=30,
1272
+ layers=3,
1273
+ batch_size=16,
1274
+ invalid_type=self.invalid_type,
1275
+ num_errors=self.num_errors)
1276
+
1277
+ valids, df_output = correct_SMILES(model, out, smi, device,
1278
+ SRC)
1279
+
1280
+ df_output["SMILES"] = df_output.apply(lambda row: self.standardization_pipeline(row["CORRECT"]), axis=1)
1281
+
1282
+ df_output = self.remove_smiles_duplicates(df_output, subset="SMILES").drop(columns=["CORRECT", "INCORRECT", "ORIGINAL"]).dropna()
1283
+
1284
+ return df_output
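The whole correction pipeline is driven through the smi_correct wrapper; a minimal usage sketch, assuming a CSV of generated SMILES (with a "SMILES" column) and pretrained transformer weights are available, with placeholder paths and arguments:

corrector = smi_correct(model_name="DrugGEN", trans_file_path=None)   # hypothetical arguments
fixed_df = corrector.correct("data/generated_invalid_smiles.csv")     # placeholder input path
fixed_df.to_csv("data/generated_fixed_smiles.csv", index=False)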
src/util/utils.py ADDED
@@ -0,0 +1,930 @@
1
+ import os
2
+ import time
3
+ import math
4
+ import datetime
5
+ import warnings
6
+ import itertools
7
+ from copy import deepcopy
8
+ from functools import partial
9
+ from collections import Counter
10
+ from multiprocessing import Pool
11
+ from statistics import mean
12
+
13
+ import numpy as np
14
+ import matplotlib.pyplot as plt
15
+ from matplotlib.lines import Line2D
16
+ from scipy.spatial.distance import cosine as cos_distance
17
+
18
+ import torch
19
+ import wandb
20
+
21
+ from rdkit import Chem, DataStructs, RDLogger
22
+ from rdkit.Chem import (
23
+ AllChem,
24
+ Draw,
25
+ Descriptors,
26
+ Lipinski,
27
+ Crippen,
28
+ rdMolDescriptors,
29
+ FilterCatalog,
30
+ )
31
+ from rdkit.Chem.Scaffolds import MurckoScaffold
32
+
33
+ # Disable RDKit warnings
34
+ RDLogger.DisableLog("rdApp.*")
35
+
36
+
37
+ class Metrics(object):
38
+ """
39
+ Collection of static methods to compute various metrics for molecules.
40
+ """
41
+
42
+ @staticmethod
43
+ def valid(x):
44
+ """
45
+ Checks whether the molecule is valid.
46
+
47
+ Args:
48
+ x: RDKit molecule object.
49
+
50
+ Returns:
51
+ bool: True if molecule is valid and has a non-empty SMILES representation.
52
+ """
53
+ return x is not None and Chem.MolToSmiles(x) != ''
54
+
55
+ @staticmethod
56
+ def tanimoto_sim_1v2(data1, data2):
57
+ """
58
+ Computes the average Tanimoto similarity for paired fingerprints.
59
+
60
+ Args:
61
+ data1: Fingerprint data for first set.
62
+ data2: Fingerprint data for second set.
63
+
64
+ Returns:
65
+ float: The average Tanimoto similarity between corresponding fingerprints.
66
+ """
67
+ # Determine the minimum size between two arrays for pairing
68
+ min_len = data1.size if data1.size > data2.size else data2
69
+ sims = []
70
+ for i in range(min_len):
71
+ sim = DataStructs.FingerprintSimilarity(data1[i], data2[i])
72
+ sims.append(sim)
73
+ # Use 'mean' from statistics; note that variable 'sim' was used, corrected to use sims list.
74
+ mean_sim = mean(sims)
75
+ return mean_sim
76
+
77
+ @staticmethod
78
+ def mol_length(x):
79
+ """
80
+ Computes the length of the largest fragment (by character count) in a SMILES string.
81
+
82
+ Args:
83
+ x (str): SMILES string.
84
+
85
+ Returns:
86
+ int: Number of alphabetic characters in the longest fragment of the SMILES.
87
+ """
88
+ if x is not None:
89
+ # Split at dots (.) and take the fragment with maximum length, then count alphabetic characters.
90
+ return len([char for char in max(x.split(sep="."), key=len).upper() if char.isalpha()])
91
+ else:
92
+ return 0
93
+
94
+ @staticmethod
95
+ def max_component(data, max_len):
96
+ """
97
+ Returns the average normalized length of molecules in the dataset.
98
+
99
+ Each molecule's length is computed and divided by max_len, then averaged.
100
+
101
+ Args:
102
+ data (iterable): Collection of SMILES strings.
103
+ max_len (int): Maximum possible length for normalization.
104
+
105
+ Returns:
106
+ float: Normalized average length.
107
+ """
108
+ lengths = np.array(list(map(Metrics.mol_length, data)), dtype=np.float32)
109
+ return (lengths / max_len).mean()
110
+
111
+ @staticmethod
112
+ def mean_atom_type(data):
113
+ """
114
+ Computes the average number of unique atom types in the provided node data.
115
+
116
+ Args:
117
+ data (iterable): Iterable containing node data with unique atom types.
118
+
119
+ Returns:
120
+ float: The average count of unique atom types, subtracting one.
121
+ """
122
+ atom_types_used = []
123
+ for i in data:
124
+ # Assuming each element i has a .unique() method that returns unique atom types.
125
+ atom_types_used.append(len(i.unique().tolist()))
126
+ av_type = np.mean(atom_types_used) - 1
127
+ return av_type
128
+
129
+
130
+ def mols2grid_image(mols, path):
131
+ """
132
+     Saves an image file for each valid molecule in the list.
133
+
134
+ For each molecule in the list, computes 2D coordinates and saves an image file.
135
+
136
+ Args:
137
+ mols (list): List of RDKit molecule objects.
138
+ path (str): Directory where images will be saved.
139
+ """
140
+ # Replace None molecules with an empty molecule
141
+ mols = [e if e is not None else Chem.RWMol() for e in mols]
142
+
143
+ for i in range(len(mols)):
144
+ if Metrics.valid(mols[i]):
145
+ AllChem.Compute2DCoords(mols[i])
146
+ file_path = os.path.join(path, "{}.png".format(i + 1))
147
+ Draw.MolToFile(mols[i], file_path, size=(1200, 1200))
148
+ # wandb.save(file_path) # Optionally save to Weights & Biases
149
+ else:
150
+ continue
151
+
152
+
153
+ def save_smiles_matrices(mols, edges_hard, nodes_hard, path, data_source=None):
154
+ """
155
+ Saves the edge and node matrices along with SMILES strings to text files.
156
+
157
+ Each file contains the edge matrix, node matrix, and SMILES representation for a molecule.
158
+
159
+ Args:
160
+ mols (list): List of RDKit molecule objects.
161
+ edges_hard (torch.Tensor): Tensor of edge features.
162
+ nodes_hard (torch.Tensor): Tensor of node features.
163
+ path (str): Directory where files will be saved.
164
+ data_source: Optional data source information (not used in function).
165
+ """
166
+ mols = [e if e is not None else Chem.RWMol() for e in mols]
167
+
168
+ for i in range(len(mols)):
169
+ if Metrics.valid(mols[i]):
170
+ save_path = os.path.join(path, "{}.txt".format(i + 1))
171
+ with open(save_path, "a") as f:
172
+ np.savetxt(f, edges_hard[i].cpu().numpy(), header="edge matrix:\n", fmt='%1.2f')
173
+ f.write("\n")
174
+ np.savetxt(f, nodes_hard[i].cpu().numpy(), header="node matrix:\n", footer="\nsmiles:", fmt='%1.2f')
175
+ f.write("\n")
176
+ # Append the SMILES representation to the file
177
+ with open(save_path, "a") as f:
178
+ print(Chem.MolToSmiles(mols[i]), file=f)
179
+ # wandb.save(save_path) # Optionally save to Weights & Biases
180
+ else:
181
+ continue
182
+
183
+ def dense_to_sparse_with_attr(adj):
184
+ """
185
+ Converts a dense adjacency matrix to a sparse representation.
186
+
187
+ Args:
188
+ adj (torch.Tensor): Adjacency matrix tensor (2D or 3D) with square last two dimensions.
189
+
190
+ Returns:
191
+ tuple: A tuple containing indices and corresponding edge attributes.
192
+ """
193
+ assert adj.dim() >= 2 and adj.dim() <= 3
194
+ assert adj.size(-1) == adj.size(-2)
195
+
196
+ index = adj.nonzero(as_tuple=True)
197
+ edge_attr = adj[index]
198
+
199
+ if len(index) == 3:
200
+ batch = index[0] * adj.size(-1)
201
+ index = (batch + index[1], batch + index[2])
202
+ return index, edge_attr
203
+
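+ # Illustrative note (not in the original file): for a small dense adjacency such as
+ # torch.tensor([[0., 1.], [1., 0.]]), dense_to_sparse_with_attr returns the index
+ # tensors (tensor([0, 1]), tensor([1, 0])) and edge_attr tensor([1., 1.]).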
204
+
205
+ def mol_sample(sample_directory, edges, nodes, idx, i, matrices2mol, dataset_name):
206
+ """
207
+ Samples molecules from edge and node predictions, then saves grid images and text files.
208
+
209
+ Args:
210
+ sample_directory (str): Directory to save the samples.
211
+ edges (torch.Tensor): Edge predictions tensor.
212
+ nodes (torch.Tensor): Node predictions tensor.
213
+ idx (int): Current index for naming the sample.
214
+ i (int): Epoch/iteration index.
215
+ matrices2mol (callable): Function to convert matrices to RDKit molecule.
216
+ dataset_name (str): Name of the dataset for file naming.
217
+ """
218
+ sample_path = os.path.join(sample_directory, "{}_{}-epoch_iteration".format(idx + 1, i + 1))
219
+ # Get the index of the maximum predicted feature along the last dimension
220
+ g_edges_hat_sample = torch.max(edges, -1)[1]
221
+ g_nodes_hat_sample = torch.max(nodes, -1)[1]
222
+ # Convert matrices to molecule objects
223
+ mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(),
224
+ strict=True, file_name=dataset_name)
225
+ for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
226
+
227
+ if not os.path.exists(sample_path):
228
+ os.makedirs(sample_path)
229
+
230
+ mols2grid_image(mol, sample_path)
231
+ save_smiles_matrices(mol, g_edges_hat_sample.detach(), g_nodes_hat_sample.detach(), sample_path)
232
+
233
+ # Remove the directory if no files were saved
234
+ if len(os.listdir(sample_path)) == 0:
235
+ os.rmdir(sample_path)
236
+
237
+ print("Valid molecules are saved.")
238
+ print("Valid matrices and smiles are saved")
239
+
240
+
241
+ def logging(log_path, start_time, i, idx, loss, save_path, drug_smiles, edge, node,
242
+ matrices2mol, dataset_name, real_adj, real_annot, drug_vecs):
243
+ """
244
+ Logs training statistics and evaluation metrics.
245
+
246
+ The function generates molecules from predictions, computes various metrics such as
247
+ validity, uniqueness, novelty, and similarity scores, and logs them using wandb and a file.
248
+
249
+ Args:
250
+ log_path (str): Path to save the log file.
251
+ start_time (float): Start time to compute elapsed time.
252
+ i (int): Current iteration index.
253
+ idx (int): Current epoch index.
254
+ loss (dict): Dictionary to update with loss and metric values.
255
+ save_path (str): Directory path to save sample outputs.
256
+ drug_smiles (list): List of reference drug SMILES.
257
+ edge (torch.Tensor): Edge prediction tensor.
258
+ node (torch.Tensor): Node prediction tensor.
259
+ matrices2mol (callable): Function to convert matrices to molecules.
260
+ dataset_name (str): Dataset name.
261
+ real_adj (torch.Tensor): Ground truth adjacency matrix tensor.
262
+ real_annot (torch.Tensor): Ground truth annotation tensor.
263
+ drug_vecs (list): List of drug vectors for similarity calculation.
264
+ """
265
+ g_edges_hat_sample = torch.max(edge, -1)[1]
266
+ g_nodes_hat_sample = torch.max(node, -1)[1]
267
+
268
+ a_tensor_sample = torch.max(real_adj, -1)[1].float()
269
+ x_tensor_sample = torch.max(real_annot, -1)[1].float()
270
+
271
+ # Generate molecules from predictions and real data
272
+ mols = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(),
273
+ strict=True, file_name=dataset_name)
274
+ for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]
275
+ real_mol = [matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(),
276
+ strict=True, file_name=dataset_name)
277
+ for e_, n_ in zip(a_tensor_sample, x_tensor_sample)]
278
+
279
+ # Compute average number of atom types
280
+ atom_types_average = Metrics.mean_atom_type(g_nodes_hat_sample)
281
+ real_smiles = [Chem.MolToSmiles(x) for x in real_mol if x is not None]
282
+ gen_smiles = []
283
+ uniq_smiles = []
284
+ for line in mols:
285
+ if line is not None:
286
+ gen_smiles.append(Chem.MolToSmiles(line))
287
+ uniq_smiles.append(Chem.MolToSmiles(line))
288
+ elif line is None:
289
+ gen_smiles.append(None)
290
+
291
+ # Process SMILES to take the longest fragment if multiple are present
292
+ gen_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in gen_smiles]
293
+ uniq_smiles_saves = [None if x is None else max(x.split('.'), key=len) for x in uniq_smiles]
294
+
295
+ # Save the generated SMILES to a text file
296
+ sample_save_dir = os.path.join(save_path, "samples.txt")
297
+ with open(sample_save_dir, "a") as f:
298
+ for s in gen_smiles_saves:
299
+ if s is not None:
300
+ f.write(s + "\n")
301
+
302
+ k = len(set(uniq_smiles_saves) - {None})
303
+ et = time.time() - start_time
304
+ et = str(datetime.timedelta(seconds=et))[:-7]
305
+ log_str = "Elapsed [{}], Epoch/Iteration [{}/{}]".format(et, idx, i + 1)
306
+
307
+ # Generate molecular fingerprints for similarity computations
308
+ gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in mols if x is not None]
309
+ chembl_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_mol if x is not None]
310
+
311
+ # Compute evaluation metrics: validity, uniqueness, novelty, similarity scores, and average maximum molecule length.
312
+ valid = fraction_valid(gen_smiles_saves)
313
+ unique = fraction_unique(uniq_smiles_saves, k)
314
+ novel_starting_mol = novelty(gen_smiles_saves, real_smiles)
315
+ novel_akt = novelty(gen_smiles_saves, drug_smiles)
316
+ if len(uniq_smiles_saves) == 0:
317
+ snn_chembl = 0
318
+ snn_akt = 0
319
+ maxlen = 0
320
+ else:
321
+ snn_chembl = average_agg_tanimoto(np.array(chembl_vecs), np.array(gen_vecs))
322
+ snn_akt = average_agg_tanimoto(np.array(drug_vecs), np.array(gen_vecs))
323
+ maxlen = Metrics.max_component(uniq_smiles_saves, 45)
324
+
325
+ # Update loss dictionary with computed metrics
326
+ loss.update({
327
+ 'Validity': valid,
328
+ 'Uniqueness': unique,
329
+ 'Novelty': novel_starting_mol,
330
+ 'Novelty_akt': novel_akt,
331
+ 'SNN_chembl': snn_chembl,
332
+ 'SNN_akt': snn_akt,
333
+ 'MaxLen': maxlen,
334
+ 'Atom_types': atom_types_average
335
+ })
336
+
337
+ # Log metrics using wandb
338
+ wandb.log({
339
+ "Validity": valid,
340
+ "Uniqueness": unique,
341
+ "Novelty": novel_starting_mol,
342
+ "Novelty_akt": novel_akt,
343
+ "SNN_chembl": snn_chembl,
344
+ "SNN_akt": snn_akt,
345
+ "MaxLen": maxlen,
346
+ "Atom_types": atom_types_average
347
+ })
348
+
349
+ # Append each metric to the log string and write to the log file
350
+ for tag, value in loss.items():
351
+ log_str += ", {}: {:.4f}".format(tag, value)
352
+ with open(log_path, "a") as f:
353
+ f.write(log_str + "\n")
354
+ print(log_str)
355
+ print("\n")
356
+
357
+
358
+ def plot_grad_flow(named_parameters, model, itera, epoch, grad_flow_directory):
359
+ """
360
+ Plots the gradients flowing through different layers during training.
361
+
362
+ This is useful to check for possible gradient vanishing or exploding problems.
363
+
364
+ Args:
365
+ named_parameters (iterable): Iterable of (name, parameter) tuples from the model.
366
+ model (str): Name of the model (used for saving the plot).
367
+ itera (int): Iteration index.
368
+ epoch (int): Current epoch.
369
+ grad_flow_directory (str): Directory to save the gradient flow plot.
370
+ """
371
+ ave_grads = []
372
+ max_grads = []
373
+ layers = []
374
+ for n, p in named_parameters:
375
+ if p.requires_grad and ("bias" not in n):
376
+ layers.append(n)
377
+ ave_grads.append(p.grad.abs().mean().cpu())
378
+ max_grads.append(p.grad.abs().max().cpu())
379
+ # Plot maximum gradients and average gradients for each layer
380
+ plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
381
+ plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
382
+ plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
383
+ plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
384
+ plt.xlim(left=0, right=len(ave_grads))
385
+ plt.ylim(bottom=-0.001, top=1) # Zoom in on lower gradient regions
386
+ plt.xlabel("Layers")
387
+ plt.ylabel("Average Gradient")
388
+ plt.title("Gradient Flow")
389
+ plt.grid(True)
390
+ plt.legend([
391
+ Line2D([0], [0], color="c", lw=4),
392
+ Line2D([0], [0], color="b", lw=4),
393
+ Line2D([0], [0], color="k", lw=4)
394
+ ], ['max-gradient', 'mean-gradient', 'zero-gradient'])
395
+ # Save the plot to the specified directory
396
+     plt.savefig(os.path.join(grad_flow_directory, "weights_" + model + "_" + str(itera) + "_" + str(epoch) + ".png"), dpi=500, bbox_inches='tight')
+     plt.close()
397
+
398
+
399
+ def get_mol(smiles_or_mol):
400
+ """
401
+ Loads a SMILES string or molecule into an RDKit molecule object.
402
+
403
+ Args:
404
+ smiles_or_mol (str or RDKit Mol): SMILES string or RDKit molecule.
405
+
406
+ Returns:
407
+ RDKit Mol or None: Sanitized molecule object, or None if invalid.
408
+ """
409
+ if isinstance(smiles_or_mol, str):
410
+ if len(smiles_or_mol) == 0:
411
+ return None
412
+ mol = Chem.MolFromSmiles(smiles_or_mol)
413
+ if mol is None:
414
+ return None
415
+ try:
416
+ Chem.SanitizeMol(mol)
417
+ except ValueError:
418
+ return None
419
+ return mol
420
+ return smiles_or_mol
421
+
422
+
423
+ def mapper(n_jobs):
424
+ """
425
+ Returns a mapping function for parallel or serial processing.
426
+
427
+ If n_jobs == 1, returns the built-in map function.
428
+ If n_jobs > 1, returns a function that uses a multiprocessing pool.
429
+
430
+ Args:
431
+ n_jobs (int or pool object): Number of jobs or a Pool instance.
432
+
433
+ Returns:
434
+ callable: A function that acts like map.
435
+ """
436
+ if n_jobs == 1:
437
+ def _mapper(*args, **kwargs):
438
+ return list(map(*args, **kwargs))
439
+ return _mapper
440
+ if isinstance(n_jobs, int):
441
+ pool = Pool(n_jobs)
442
+ def _mapper(*args, **kwargs):
443
+ try:
444
+ result = pool.map(*args, **kwargs)
445
+ finally:
446
+ pool.terminate()
447
+ return result
448
+ return _mapper
449
+ return n_jobs.map
450
+
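+ # Illustrative note (not in the original file): mapper(1) returns a serial map, so
+ # mapper(1)(get_mol, ["CCO", "not_a_smiles"]) yields [Mol, None], while mapper(4)
+ # runs the same call across a 4-process multiprocessing Pool.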
451
+
452
+ def remove_invalid(gen, canonize=True, n_jobs=1):
453
+ """
454
+ Removes invalid molecules from the provided dataset.
455
+
456
+ Optionally canonizes the SMILES strings.
457
+
458
+ Args:
459
+ gen (list): List of SMILES strings.
460
+ canonize (bool): Whether to convert to canonical SMILES.
461
+ n_jobs (int): Number of parallel jobs.
462
+
463
+ Returns:
464
+ list: Filtered list of valid molecules.
465
+ """
466
+ if not canonize:
467
+ mols = mapper(n_jobs)(get_mol, gen)
468
+ return [gen_ for gen_, mol in zip(gen, mols) if mol is not None]
469
+ return [x for x in mapper(n_jobs)(canonic_smiles, gen) if x is not None]
470
+
471
+
472
+ def fraction_valid(gen, n_jobs=1):
473
+ """
474
+ Computes the fraction of valid molecules in the dataset.
475
+
476
+ Args:
477
+ gen (list): List of SMILES strings.
478
+ n_jobs (int): Number of parallel jobs.
479
+
480
+ Returns:
481
+ float: Fraction of molecules that are valid.
482
+ """
483
+ gen = mapper(n_jobs)(get_mol, gen)
484
+ return 1 - gen.count(None) / len(gen)
485
+
486
+
487
+ def canonic_smiles(smiles_or_mol):
488
+ """
489
+ Converts a SMILES string or molecule to its canonical SMILES.
490
+
491
+ Args:
492
+ smiles_or_mol (str or RDKit Mol): Input molecule.
493
+
494
+ Returns:
495
+ str or None: Canonical SMILES string or None if invalid.
496
+ """
497
+ mol = get_mol(smiles_or_mol)
498
+ if mol is None:
499
+ return None
500
+ return Chem.MolToSmiles(mol)
501
+
502
+
503
+ def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
504
+ """
505
+ Computes the fraction of unique molecules.
506
+
507
+ Optionally computes unique@k, where only the first k molecules are considered.
508
+
509
+ Args:
510
+ gen (list): List of SMILES strings.
511
+ k (int): Optional cutoff for unique@k computation.
512
+ n_jobs (int): Number of parallel jobs.
513
+ check_validity (bool): Whether to check for validity of molecules.
514
+
515
+ Returns:
516
+ float: Fraction of unique molecules.
517
+ """
518
+ if k is not None:
519
+ if len(gen) < k:
520
+ warnings.warn("Can't compute unique@{}.".format(k) +
521
+ " gen contains only {} molecules".format(len(gen)))
522
+ gen = gen[:k]
523
+     if check_validity:
+         canonic = list(mapper(n_jobs)(canonic_smiles, gen))
+         canonic = [i for i in canonic if i is not None]
+     else:
+         canonic = list(gen)
+     set_canonic = set(canonic)
+     return 0 if len(canonic) == 0 else len(set_canonic) / len(canonic)
528
+
529
+
530
+ def novelty(gen, train, n_jobs=1):
531
+ """
532
+ Computes the novelty score of generated molecules.
533
+
534
+ Novelty is defined as the fraction of generated molecules that do not appear in the training set.
535
+
536
+ Args:
537
+ gen (list): List of generated SMILES strings.
538
+ train (list): List of training SMILES strings.
539
+ n_jobs (int): Number of parallel jobs.
540
+
541
+ Returns:
542
+ float: Novelty score.
543
+ """
544
+ gen_smiles = mapper(n_jobs)(canonic_smiles, gen)
545
+ gen_smiles_set = set(gen_smiles) - {None}
546
+ train_set = set(train)
547
+ return 0 if len(gen_smiles_set) == 0 else len(gen_smiles_set - train_set) / len(gen_smiles_set)
548
+
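+ # Illustrative note (not in the original file): novelty(["CCO", "CCN"], ["CCO"])
+ # evaluates to 0.5, since only one of the two canonical generated SMILES is
+ # absent from the training list.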
549
+
550
+ def internal_diversity(gen):
551
+ """
552
+ Computes the internal diversity of a set of molecules.
553
+
554
+ Internal diversity is defined as one minus the average Tanimoto similarity between all pairs.
555
+
556
+ Args:
557
+ gen: Array-like representation of molecules.
558
+
559
+ Returns:
560
+ tuple: Mean and standard deviation of internal diversity.
561
+ """
562
+ diversity = [1 - x for x in average_agg_tanimoto(gen, gen, agg="mean", intdiv=True)]
563
+ return np.mean(diversity), np.std(diversity)
564
+
565
+
566
+ def average_agg_tanimoto(stock_vecs, gen_vecs, batch_size=5000, agg='max', device='cpu', p=1, intdiv=False):
567
+ """
568
+         min_len = data1.size if data1.size < data2.size else data2.size
569
+
570
+ For each fingerprint in gen_vecs, finds the closest (max or mean) similarity with fingerprints in stock_vecs.
571
+
572
+ Args:
573
+         # Average the pairwise similarities with statistics.mean.
574
+ gen_vecs (numpy.ndarray): Array of fingerprint vectors from the generated set.
575
+ batch_size (int): Batch size for processing fingerprints.
576
+ agg (str): Aggregation method, either 'max' or 'mean'.
577
+ device (str): Device to perform computations on.
578
+ p (int): Power for averaging.
579
+ intdiv (bool): Whether to return individual similarities or the average.
580
+
581
+ Returns:
582
+ float or numpy.ndarray: Average aggregated Tanimoto similarity or array of individual scores.
583
+ """
584
+ assert agg in ['max', 'mean'], "Can aggregate only max or mean"
585
+ agg_tanimoto = np.zeros(len(gen_vecs))
586
+ total = np.zeros(len(gen_vecs))
587
+ for j in range(0, stock_vecs.shape[0], batch_size):
588
+ x_stock = torch.tensor(stock_vecs[j:j + batch_size]).to(device).float()
589
+ for i in range(0, gen_vecs.shape[0], batch_size):
590
+ y_gen = torch.tensor(gen_vecs[i:i + batch_size]).to(device).float()
591
+ y_gen = y_gen.transpose(0, 1)
592
+ tp = torch.mm(x_stock, y_gen)
593
+ # Compute Jaccard/Tanimoto similarity
594
+ jac = (tp / (x_stock.sum(1, keepdim=True) + y_gen.sum(0, keepdim=True) - tp)).cpu().numpy()
595
+ jac[np.isnan(jac)] = 1
596
+ if p != 1:
597
+ jac = jac ** p
598
+ if agg == 'max':
599
+ agg_tanimoto[i:i + y_gen.shape[1]] = np.maximum(
600
+ agg_tanimoto[i:i + y_gen.shape[1]], jac.max(0))
601
+ elif agg == 'mean':
602
+ agg_tanimoto[i:i + y_gen.shape[1]] += jac.sum(0)
603
+ total[i:i + y_gen.shape[1]] += jac.shape[0]
604
+ if agg == 'mean':
605
+ agg_tanimoto /= total
606
+ if p != 1:
607
+ agg_tanimoto = (agg_tanimoto) ** (1 / p)
608
+ if intdiv:
609
+ return agg_tanimoto
610
+ else:
611
+ return np.mean(agg_tanimoto)
612
+
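+ # Illustrative note (not in the original file): given two numpy arrays of Morgan
+ # bit vectors, built the same way as in logging(), e.g.
+ #     ref_fps = np.array([AllChem.GetMorganFingerprintAsBitVect(m, 2, nBits=1024) for m in ref_mols])
+ #     gen_fps = np.array([AllChem.GetMorganFingerprintAsBitVect(m, 2, nBits=1024) for m in gen_mols])
+ # average_agg_tanimoto(ref_fps, gen_fps) returns the mean nearest-neighbour (SNN)
+ # similarity, and average_agg_tanimoto(fps, fps, agg="mean", intdiv=True) is the
+ # input consumed by internal_diversity.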
613
+
614
+ def str2bool(v):
615
+ """
616
+ Converts a string to a boolean.
617
+
618
+ Args:
619
+ v (str): Input string.
620
+
621
+ Returns:
622
+ bool: True if the string is 'true' (case insensitive), else False.
623
+ """
624
+     return v.lower() == 'true'
625
+
626
+
627
+ def obey_lipinski(mol):
628
+ """
629
+ Checks if a molecule obeys Lipinski's Rule of Five.
630
+
631
+ The function evaluates weight, hydrogen bond donors and acceptors, logP, and rotatable bonds.
632
+
633
+ Args:
634
+ mol (RDKit Mol): Molecule object.
635
+
636
+ Returns:
637
+ int: Number of Lipinski rules satisfied.
638
+ """
639
+ mol = deepcopy(mol)
640
+ Chem.SanitizeMol(mol)
641
+ rule_1 = Descriptors.ExactMolWt(mol) < 500
642
+ rule_2 = Lipinski.NumHDonors(mol) <= 5
643
+ rule_3 = Lipinski.NumHAcceptors(mol) <= 10
644
+     rule_4 = -2 <= Crippen.MolLogP(mol) <= 5
645
+ rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10
646
+ return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])
647
+
648
+
649
+ def obey_veber(mol):
650
+ """
651
+ Checks if a molecule obeys Veber's rules.
652
+
653
+ Veber's rules focus on the number of rotatable bonds and topological polar surface area.
654
+
655
+ Args:
656
+ mol (RDKit Mol): Molecule object.
657
+
658
+ Returns:
659
+ int: Number of Veber's rules satisfied.
660
+ """
661
+ mol = deepcopy(mol)
662
+ Chem.SanitizeMol(mol)
663
+ rule_1 = rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10
664
+ rule_2 = rdMolDescriptors.CalcTPSA(mol) <= 140
665
+ return np.sum([int(a) for a in [rule_1, rule_2]])
666
+
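+ # Illustrative note (not in the original file): for aspirin,
+ #     m = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")
+ # obey_lipinski(m) returns 5 and obey_veber(m) returns 2, i.e. all rules pass.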
667
+
668
+ def load_pains_filters():
669
+ """
670
+ Loads the PAINS (Pan-Assay INterference compoundS) filters A, B, and C.
671
+
672
+ Returns:
673
+ FilterCatalog: An RDKit FilterCatalog object containing PAINS filters.
674
+ """
675
+ params = FilterCatalog.FilterCatalogParams()
676
+ params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS_A)
677
+ params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS_B)
678
+ params.AddCatalog(FilterCatalog.FilterCatalogParams.FilterCatalogs.PAINS_C)
679
+ catalog = FilterCatalog.FilterCatalog(params)
680
+ return catalog
681
+
682
+
683
+ def is_pains(mol, catalog):
684
+ """
685
+ Checks if the given molecule is a PAINS compound.
686
+
687
+ Args:
688
+ mol (RDKit Mol): Molecule object.
689
+ catalog (FilterCatalog): A catalog of PAINS filters.
690
+
691
+ Returns:
692
+ bool: True if the molecule matches a PAINS filter, else False.
693
+ """
694
+ entry = catalog.GetFirstMatch(mol)
695
+ return entry is not None
696
+
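+ # Illustrative note (not in the original file):
+ #     catalog = load_pains_filters()
+ #     is_pains(Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O"), catalog)  # False for aspirin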
697
+
698
+ def mapper(n_jobs):
699
+ """
700
+ Returns a mapping function for parallel or serial processing.
701
+
702
+ If n_jobs == 1, returns the built-in map function.
703
+ If n_jobs > 1, returns a function that uses a multiprocessing pool.
704
+
705
+ Args:
706
+ n_jobs (int or pool object): Number of jobs or a Pool instance.
707
+
708
+ Returns:
709
+ callable: A function that acts like map.
710
+ """
711
+ if n_jobs == 1:
712
+ def _mapper(*args, **kwargs):
713
+ return list(map(*args, **kwargs))
714
+ return _mapper
715
+ if isinstance(n_jobs, int):
716
+ pool = Pool(n_jobs)
717
+ def _mapper(*args, **kwargs):
718
+ try:
719
+ result = pool.map(*args, **kwargs)
720
+ finally:
721
+ pool.terminate()
722
+ return result
723
+ return _mapper
724
+ return n_jobs.map
725
+
726
+
727
+ def fragmenter(mol):
728
+ """
729
+ Fragments a molecule using BRICS and returns a list of fragment SMILES.
730
+
731
+ Args:
732
+ mol (str or RDKit Mol): Input molecule.
733
+
734
+ Returns:
735
+ list: List of fragment SMILES strings.
736
+ """
737
+ fgs = AllChem.FragmentOnBRICSBonds(get_mol(mol))
738
+ fgs_smi = Chem.MolToSmiles(fgs).split(".")
739
+ return fgs_smi
740
+
741
+
742
+ def get_mol(smiles_or_mol):
743
+ """
744
+ Loads a SMILES string or molecule into an RDKit molecule object.
745
+
746
+ Args:
747
+ smiles_or_mol (str or RDKit Mol): SMILES string or molecule.
748
+
749
+ Returns:
750
+ RDKit Mol or None: Sanitized molecule object or None if invalid.
751
+ """
752
+ if isinstance(smiles_or_mol, str):
753
+ if len(smiles_or_mol) == 0:
754
+ return None
755
+ mol = Chem.MolFromSmiles(smiles_or_mol)
756
+ if mol is None:
757
+ return None
758
+ try:
759
+ Chem.SanitizeMol(mol)
760
+ except ValueError:
761
+ return None
762
+ return mol
763
+ return smiles_or_mol
764
+
765
+
766
+ def compute_fragments(mol_list, n_jobs=1):
767
+ """
768
+ Fragments a list of molecules using BRICS and returns a counter of fragment occurrences.
769
+
770
+ Args:
771
+ mol_list (list): List of molecules (SMILES or RDKit Mol).
772
+ n_jobs (int): Number of parallel jobs.
773
+
774
+ Returns:
775
+ Counter: A Counter dictionary mapping fragment SMILES to counts.
776
+ """
777
+ fragments = Counter()
778
+ for mol_frag in mapper(n_jobs)(fragmenter, mol_list):
779
+ fragments.update(mol_frag)
780
+ return fragments
781
+
782
+
783
+ def compute_scaffolds(mol_list, n_jobs=1, min_rings=2):
784
+ """
785
+ Extracts scaffolds from a list of molecules as canonical SMILES.
786
+
787
+ Only scaffolds with at least min_rings rings are considered.
788
+
789
+ Args:
790
+ mol_list (list): List of molecules.
791
+ n_jobs (int): Number of parallel jobs.
792
+ min_rings (int): Minimum number of rings required in a scaffold.
793
+
794
+ Returns:
795
+ Counter: A Counter mapping scaffold SMILES to counts.
796
+ """
797
+ scaffolds = Counter()
798
+ map_ = mapper(n_jobs)
799
+ scaffolds = Counter(map_(partial(compute_scaffold, min_rings=min_rings), mol_list))
800
+ if None in scaffolds:
801
+ scaffolds.pop(None)
802
+ return scaffolds
803
+
804
+
805
+ def get_n_rings(mol):
806
+ """
807
+ Computes the number of rings in a molecule.
808
+
809
+ Args:
810
+ mol (RDKit Mol): Molecule object.
811
+
812
+ Returns:
813
+ int: Number of rings.
814
+ """
815
+ return mol.GetRingInfo().NumRings()
816
+
817
+
818
+ def compute_scaffold(mol, min_rings=2):
819
+ """
820
+ Computes the Murcko scaffold of a molecule and returns its canonical SMILES if it has enough rings.
821
+
822
+ Args:
823
+ mol (str or RDKit Mol): Input molecule.
824
+ min_rings (int): Minimum number of rings required.
825
+
826
+ Returns:
827
+ str or None: Canonical SMILES of the scaffold if valid, else None.
828
+ """
829
+ mol = get_mol(mol)
830
+ try:
831
+ scaffold = MurckoScaffold.GetScaffoldForMol(mol)
832
+ except (ValueError, RuntimeError):
833
+ return None
834
+ n_rings = get_n_rings(scaffold)
835
+ scaffold_smiles = Chem.MolToSmiles(scaffold)
836
+ if scaffold_smiles == '' or n_rings < min_rings:
837
+ return None
838
+ return scaffold_smiles
839
+
840
+
841
+ class Metric:
842
+ """
843
+ Abstract base class for chemical metrics.
844
+
845
+ Derived classes should implement the precalc and metric methods.
846
+ """
847
+ def __init__(self, n_jobs=1, device='cpu', batch_size=512, **kwargs):
848
+ self.n_jobs = n_jobs
849
+ self.device = device
850
+ self.batch_size = batch_size
851
+ for k, v in kwargs.items():
852
+ setattr(self, k, v)
853
+
854
+ def __call__(self, ref=None, gen=None, pref=None, pgen=None):
855
+ """
856
+ Computes the metric between reference and generated molecules.
857
+
858
+ Exactly one of ref or pref, and gen or pgen should be provided.
859
+
860
+ Args:
861
+ ref: Reference molecule list.
862
+ gen: Generated molecule list.
863
+ pref: Precalculated reference metric.
864
+ pgen: Precalculated generated metric.
865
+
866
+ Returns:
867
+ Metric value computed by the metric method.
868
+ """
869
+ assert (ref is None) != (pref is None), "specify ref xor pref"
870
+ assert (gen is None) != (pgen is None), "specify gen xor pgen"
871
+ if pref is None:
872
+ pref = self.precalc(ref)
873
+ if pgen is None:
874
+ pgen = self.precalc(gen)
875
+ return self.metric(pref, pgen)
876
+
877
+ def precalc(self, molecules):
878
+ """
879
+ Pre-calculates necessary representations from a list of molecules.
880
+ Should be implemented by derived classes.
881
+ """
882
+ raise NotImplementedError
883
+
884
+ def metric(self, pref, pgen):
885
+ """
886
+ Computes the metric given precalculated representations.
887
+ Should be implemented by derived classes.
888
+ """
889
+ raise NotImplementedError
890
+
891
+
892
+ class FragMetric(Metric):
893
+ """
894
+ Metrics based on molecular fragments.
895
+ """
896
+ def precalc(self, mols):
897
+ return {'frag': compute_fragments(mols, n_jobs=self.n_jobs)}
898
+
899
+ def metric(self, pref, pgen):
900
+ return cos_similarity(pref['frag'], pgen['frag'])
901
+
902
+
903
+ class ScafMetric(Metric):
904
+ """
905
+ Metrics based on molecular scaffolds.
906
+ """
907
+ def precalc(self, mols):
908
+ return {'scaf': compute_scaffolds(mols, n_jobs=self.n_jobs)}
909
+
910
+ def metric(self, pref, pgen):
911
+ return cos_similarity(pref['scaf'], pgen['scaf'])
912
+
913
+
914
+ def cos_similarity(ref_counts, gen_counts):
915
+ """
916
+ Computes cosine similarity between two molecular vectors.
917
+
918
+ Args:
919
+ ref_counts (dict): Reference molecular vectors.
920
+ gen_counts (dict): Generated molecular vectors.
921
+
922
+ Returns:
923
+ float: Cosine similarity between the two molecular vectors.
924
+ """
925
+ if len(ref_counts) == 0 or len(gen_counts) == 0:
926
+ return np.nan
927
+ keys = np.unique(list(ref_counts.keys()) + list(gen_counts.keys()))
928
+ ref_vec = np.array([ref_counts.get(k, 0) for k in keys])
929
+ gen_vec = np.array([gen_counts.get(k, 0) for k in keys])
930
+ return 1 - cos_distance(ref_vec, gen_vec)
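+
+ # Illustrative usage sketch (not part of the original module): a minimal,
+ # self-contained check of the metric helpers above on a tiny SMILES list.
+ # The example molecules are arbitrary and only meant to show the call pattern.
+ if __name__ == "__main__":
+     sample = ["CCO", "c1ccccc1O", "CC(=O)Nc1ccc(O)cc1"]
+     print("valid fraction:  ", fraction_valid(sample))    # expected 1.0
+     print("unique fraction: ", fraction_unique(sample))   # expected 1.0
+     print("fragment similarity (self):", FragMetric(n_jobs=1)(ref=sample, gen=sample))  # expected 1.0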
train.py ADDED
@@ -0,0 +1,462 @@
1
+ import os
2
+ import time
3
+ import random
4
+ import pickle
5
+ import argparse
6
+ import os.path as osp
7
+
8
+ import torch
9
+ import torch.utils.data
10
+ from torch import nn
11
+ from torch_geometric.loader import DataLoader
12
+
13
+ import wandb
14
+ from rdkit import RDLogger
15
+
16
+ torch.set_num_threads(5)
17
+ RDLogger.DisableLog('rdApp.*')
18
+
19
+ from src.util.utils import *
20
+ from src.model.models import Generator, Discriminator, simple_disc
21
+ from src.data.dataset import DruggenDataset
22
+ from src.data.utils import get_encoders_decoders, load_molecules
23
+ from src.model.loss import discriminator_loss, generator_loss
24
+
25
+ class Train(object):
26
+ """Trainer for DrugGEN."""
27
+
28
+ def __init__(self, config):
29
+ if config.set_seed:
30
+ np.random.seed(config.seed)
31
+ random.seed(config.seed)
32
+ torch.manual_seed(config.seed)
33
+ torch.cuda.manual_seed_all(config.seed)
34
+
35
+ torch.backends.cudnn.deterministic = True
36
+ torch.backends.cudnn.benchmark = False
37
+
38
+ os.environ["PYTHONHASHSEED"] = str(config.seed)
39
+
40
+ print(f'Using seed {config.seed}')
41
+
42
+ self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
43
+
44
+ # Initialize configurations
45
+ self.submodel = config.submodel
46
+
47
+ # Data loader.
48
+ self.raw_file = config.raw_file # SMILES containing text file for dataset.
49
+ # Write the full path to file.
50
+ self.drug_raw_file = config.drug_raw_file # SMILES containing text file for second dataset.
51
+ # Write the full path to file.
52
+
53
+ # Automatically infer dataset file names from raw file names
54
+ raw_file_basename = osp.basename(self.raw_file)
55
+ drug_raw_file_basename = osp.basename(self.drug_raw_file)
56
+
57
+ # Get the base name without extension and add max_atom to it
58
+ self.max_atom = config.max_atom # Model is based on one-shot generation.
59
+ raw_file_base = os.path.splitext(raw_file_basename)[0]
60
+ drug_raw_file_base = os.path.splitext(drug_raw_file_basename)[0]
61
+
62
+ # Change extension from .smi to .pt and add max_atom to the filename
63
+ self.dataset_file = f"{raw_file_base}{self.max_atom}.pt"
64
+ self.drugs_dataset_file = f"{drug_raw_file_base}{self.max_atom}.pt"
65
+
66
+ self.mol_data_dir = config.mol_data_dir # Directory where the dataset files are stored.
67
+ self.drug_data_dir = config.drug_data_dir # Directory where the drug dataset files are stored.
68
+ self.dataset_name = self.dataset_file.split(".")[0]
69
+ self.drugs_dataset_name = self.drugs_dataset_file.split(".")[0]
70
+         self.features = config.features  # If False, only atom types are used as node features.
71
+ # Additional node features can be added. Please check new_dataloarder.py Line 102.
72
+ self.batch_size = config.batch_size # Batch size for training.
73
+
74
+ self.parallel = config.parallel
75
+
76
+ # Get atom and bond encoders/decoders
77
+ atom_encoder, atom_decoder, bond_encoder, bond_decoder = get_encoders_decoders(
78
+ self.raw_file,
79
+ self.drug_raw_file,
80
+ self.max_atom
81
+ )
82
+ self.atom_encoder = atom_encoder
83
+ self.atom_decoder = atom_decoder
84
+ self.bond_encoder = bond_encoder
85
+ self.bond_decoder = bond_decoder
86
+
87
+ self.dataset = DruggenDataset(self.mol_data_dir,
88
+ self.dataset_file,
89
+ self.raw_file,
90
+ self.max_atom,
91
+ self.features,
92
+ atom_encoder=atom_encoder,
93
+ atom_decoder=atom_decoder,
94
+ bond_encoder=bond_encoder,
95
+ bond_decoder=bond_decoder)
96
+
97
+ self.loader = DataLoader(self.dataset,
98
+ shuffle=True,
99
+ batch_size=self.batch_size,
100
+ drop_last=True) # PyG dataloader for the GAN.
101
+
102
+ self.drugs = DruggenDataset(self.drug_data_dir,
103
+ self.drugs_dataset_file,
104
+ self.drug_raw_file,
105
+ self.max_atom,
106
+ self.features,
107
+ atom_encoder=atom_encoder,
108
+ atom_decoder=atom_decoder,
109
+ bond_encoder=bond_encoder,
110
+ bond_decoder=bond_decoder)
111
+
112
+ self.drugs_loader = DataLoader(self.drugs,
113
+ shuffle=True,
114
+ batch_size=self.batch_size,
115
+ drop_last=True) # PyG dataloader for the second GAN.
116
+
117
+ self.m_dim = len(self.atom_decoder) if not self.features else int(self.loader.dataset[0].x.shape[1]) # Atom type dimension.
118
+ self.b_dim = len(self.bond_decoder) # Bond type dimension.
119
+ self.vertexes = int(self.loader.dataset[0].x.shape[0]) # Number of nodes in the graph.
120
+
121
+ # Model configurations.
122
+ self.act = config.act
123
+ self.lambda_gp = config.lambda_gp
124
+ self.dim = config.dim
125
+ self.depth = config.depth
126
+ self.heads = config.heads
127
+ self.mlp_ratio = config.mlp_ratio
128
+ self.ddepth = config.ddepth
129
+ self.ddropout = config.ddropout
130
+
131
+ # Training configurations.
132
+ self.epoch = config.epoch
133
+ self.g_lr = config.g_lr
134
+ self.d_lr = config.d_lr
135
+ self.dropout = config.dropout
136
+ self.beta1 = config.beta1
137
+ self.beta2 = config.beta2
138
+
139
+ # Directories.
140
+ self.log_dir = config.log_dir
141
+ self.sample_dir = config.sample_dir
142
+ self.model_save_dir = config.model_save_dir
143
+
144
+ # Step size.
145
+ self.log_step = config.log_sample_step
146
+
147
+ # resume training
148
+ self.resume = config.resume
149
+ self.resume_epoch = config.resume_epoch
150
+ self.resume_iter = config.resume_iter
151
+ self.resume_directory = config.resume_directory
152
+
153
+ # wandb configuration
154
+ self.use_wandb = config.use_wandb
155
+ self.online = config.online
156
+ self.exp_name = config.exp_name
157
+
158
+ # Arguments for the model.
159
+ self.arguments = "{}_{}_glr{}_dlr{}_dim{}_depth{}_heads{}_batch{}_epoch{}_dataset{}_dropout{}".format(self.exp_name, self.submodel, self.g_lr, self.d_lr, self.dim, self.depth, self.heads, self.batch_size, self.epoch, self.dataset_name, self.dropout)
160
+
161
+ self.build_model(self.model_save_dir, self.arguments)
162
+
163
+
164
+ def build_model(self, model_save_dir, arguments):
165
+ """Create generators and discriminators."""
166
+
167
+ ''' Generator is based on Transformer Encoder:
168
+
169
+ @ g_conv_dim: Dimensions for MLP layers before Transformer Encoder
170
+ @ vertexes: maximum length of generated molecules (atom length)
171
+ @ b_dim: number of bond types
172
+ @ m_dim: number of atom types (or number of features used)
173
+         @ dropout: dropout probability
174
+ @ dim: Hidden dimension of Transformer Encoder
175
+ @ depth: Transformer layer number
176
+ @ heads: Number of multihead-attention heads
177
+ @ mlp_ratio: Read-out layer dimension of Transformer
178
+         @ drop_rate: deprecated
179
+ @ tra_conv: Whether module creates output for TransformerConv discriminator
180
+ '''
181
+ self.G = Generator(self.act,
182
+ self.vertexes,
183
+ self.b_dim,
184
+ self.m_dim,
185
+ self.dropout,
186
+ dim=self.dim,
187
+ depth=self.depth,
188
+ heads=self.heads,
189
+ mlp_ratio=self.mlp_ratio)
190
+
191
+ ''' Discriminator implementation with Transformer Encoder:
192
+
193
+ @ act: Activation function for MLP
194
+ @ vertexes: maximum length of generated molecules (molecule length)
195
+ @ b_dim: number of bond types
196
+ @ m_dim: number of atom types (or number of features used)
197
+         @ dropout: dropout probability
198
+ @ dim: Hidden dimension of Transformer Encoder
199
+ @ depth: Transformer layer number
200
+ @ heads: Number of multihead-attention heads
201
+ @ mlp_ratio: Read-out layer dimension of Transformer'''
202
+
203
+ self.D = Discriminator(self.act,
204
+ self.vertexes,
205
+ self.b_dim,
206
+ self.m_dim,
207
+ self.ddropout,
208
+ dim=self.dim,
209
+ depth=self.ddepth,
210
+ heads=self.heads,
211
+ mlp_ratio=self.mlp_ratio)
212
+
213
+ self.g_optimizer = torch.optim.AdamW(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
214
+ self.d_optimizer = torch.optim.AdamW(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
215
+
216
+ network_path = os.path.join(model_save_dir, arguments)
217
+ self.print_network(self.G, 'G', network_path)
218
+ self.print_network(self.D, 'D', network_path)
219
+
220
+ if self.parallel and torch.cuda.device_count() > 1:
221
+ print(f"Using {torch.cuda.device_count()} GPUs!")
222
+ self.G = nn.DataParallel(self.G)
223
+ self.D = nn.DataParallel(self.D)
224
+
225
+ self.G.to(self.device)
226
+ self.D.to(self.device)
227
+
228
+ def print_network(self, model, name, save_dir):
229
+ """Print out the network information."""
230
+ num_params = 0
231
+ for p in model.parameters():
232
+ num_params += p.numel()
233
+
234
+ if not os.path.exists(save_dir):
235
+ os.makedirs(save_dir)
236
+
237
+ network_path = os.path.join(save_dir, "{}_modules.txt".format(name))
238
+ with open(network_path, "w+") as file:
239
+ for module in model.modules():
240
+ file.write(f"{module.__class__.__name__}:\n")
241
+ print(module.__class__.__name__)
242
+ for n, param in module.named_parameters():
243
+ if param is not None:
244
+ file.write(f" - {n}: {param.size()}\n")
245
+ print(f" - {n}: {param.size()}")
246
+ break
247
+ file.write(f"Total number of parameters: {num_params}\n")
248
+ print(f"Total number of parameters: {num_params}\n\n")
249
+
250
+ def restore_model(self, epoch, iteration, model_directory):
251
+ """Restore the trained generator and discriminator."""
252
+ print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))
253
+
254
+ G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
255
+ D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(epoch, iteration))
256
+ self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
257
+ self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
258
+
259
+ def save_model(self, model_directory, idx,i):
260
+ G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(idx+1,i+1))
261
+ D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(idx+1,i+1))
262
+ torch.save(self.G.state_dict(), G_path)
263
+ torch.save(self.D.state_dict(), D_path)
264
+
265
+ def reset_grad(self):
266
+ """Reset the gradient buffers."""
267
+ self.g_optimizer.zero_grad()
268
+ self.d_optimizer.zero_grad()
269
+
270
+ def train(self, config):
271
+ ''' Training Script starts from here'''
272
+ if self.use_wandb:
273
+ mode = 'online' if self.online else 'offline'
274
+ else:
275
+ mode = 'disabled'
276
+ kwargs = {'name': self.exp_name, 'project': 'druggen', 'config': config,
277
+ 'settings': wandb.Settings(_disable_stats=True), 'reinit': True, 'mode': mode, 'save_code': True}
278
+ wandb.init(**kwargs)
279
+
280
+ wandb.save(os.path.join(self.model_save_dir, self.arguments, "G_modules.txt"))
281
+ wandb.save(os.path.join(self.model_save_dir, self.arguments, "D_modules.txt"))
282
+
283
+ self.model_directory = os.path.join(self.model_save_dir, self.arguments)
284
+ self.sample_directory = os.path.join(self.sample_dir, self.arguments)
285
+ self.log_path = os.path.join(self.log_dir, "{}.txt".format(self.arguments))
286
+ if not os.path.exists(self.model_directory):
287
+ os.makedirs(self.model_directory)
288
+ if not os.path.exists(self.sample_directory):
289
+ os.makedirs(self.sample_directory)
290
+
291
+ # smiles data for metrics calculation.
292
+         with open(self.drug_raw_file, 'r') as f:
+             drug_smiles = f.read().splitlines()
293
+ drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
294
+ drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]
295
+
296
+ if self.resume:
297
+ self.restore_model(self.resume_epoch, self.resume_iter, self.resume_directory)
298
+
299
+ # Start training.
300
+ print('Start training...')
301
+ self.start_time = time.time()
302
+ for idx in range(self.epoch):
303
+ # =================================================================================== #
304
+ # 1. Preprocess input data #
305
+ # =================================================================================== #
306
+ # Load the data
307
+ dataloader_iterator = iter(self.drugs_loader)
308
+
309
+ wandb.log({"epoch": idx})
310
+
311
+ for i, data in enumerate(self.loader):
312
+ try:
313
+ drugs = next(dataloader_iterator)
314
+ except StopIteration:
315
+ dataloader_iterator = iter(self.drugs_loader)
316
+ drugs = next(dataloader_iterator)
317
+
318
+ wandb.log({"iter": i})
319
+
320
+                 # Preprocess both datasets
321
+ real_graphs, a_tensor, x_tensor = load_molecules(
322
+ data=data,
323
+ batch_size=self.batch_size,
324
+ device=self.device,
325
+ b_dim=self.b_dim,
326
+ m_dim=self.m_dim,
327
+ )
328
+
329
+ drug_graphs, drugs_a_tensor, drugs_x_tensor = load_molecules(
330
+ data=drugs,
331
+ batch_size=self.batch_size,
332
+ device=self.device,
333
+ b_dim=self.b_dim,
334
+ m_dim=self.m_dim,
335
+ )
336
+
337
+ # Training configuration.
338
+ GEN_node = x_tensor # Generator input node features (annotation matrix of real molecules)
339
+ GEN_edge = a_tensor # Generator input edge features (adjacency matrix of real molecules)
340
+ if self.submodel == "DrugGEN":
341
+ DISC_node = drugs_x_tensor # Discriminator input node features (annotation matrix of drug molecules)
342
+ DISC_edge = drugs_a_tensor # Discriminator input edge features (adjacency matrix of drug molecules)
343
+ elif self.submodel == "NoTarget":
344
+ DISC_node = x_tensor # Discriminator input node features (annotation matrix of real molecules)
345
+ DISC_edge = a_tensor # Discriminator input edge features (adjacency matrix of real molecules)
346
+
347
+ # =================================================================================== #
348
+ # 2. Train the GAN #
349
+ # =================================================================================== #
350
+
351
+ loss = {}
352
+ self.reset_grad()
353
+ # Compute discriminator loss.
354
+ node, edge, d_loss = discriminator_loss(self.G,
355
+ self.D,
356
+ DISC_edge,
357
+ DISC_node,
358
+ GEN_edge,
359
+ GEN_node,
360
+ self.batch_size,
361
+ self.device,
362
+ self.lambda_gp)
363
+ d_total = d_loss
364
+ wandb.log({"d_loss": d_total.item()})
365
+
366
+ loss["d_total"] = d_total.item()
367
+ d_total.backward()
368
+ self.d_optimizer.step()
369
+
370
+ self.reset_grad()
371
+
372
+ # Compute generator loss.
373
+ generator_output = generator_loss(self.G,
374
+ self.D,
375
+ GEN_edge,
376
+ GEN_node,
377
+ self.batch_size)
378
+ g_loss, node, edge, node_sample, edge_sample = generator_output
379
+ g_total = g_loss
380
+ wandb.log({"g_loss": g_total.item()})
381
+
382
+ loss["g_total"] = g_total.item()
383
+ g_total.backward()
384
+ self.g_optimizer.step()
385
+
386
+ # Logging.
387
+ if (i+1) % self.log_step == 0:
388
+ logging(self.log_path, self.start_time, i, idx, loss, self.sample_directory,
389
+ drug_smiles,edge_sample, node_sample, self.dataset.matrices2mol,
390
+ self.dataset_name, a_tensor, x_tensor, drug_vecs)
391
+
392
+ mol_sample(self.sample_directory, edge_sample.detach(), node_sample.detach(),
393
+ idx, i, self.dataset.matrices2mol, self.dataset_name)
394
+ print("samples saved at epoch {} and iteration {}".format(idx,i))
395
+
396
+ self.save_model(self.model_directory, idx, i)
397
+ print("model saved at epoch {} and iteration {}".format(idx,i))
398
+
399
+
400
+ if __name__ == '__main__':
401
+ parser = argparse.ArgumentParser()
402
+
403
+ # Data configuration.
404
+ parser.add_argument('--raw_file', type=str, required=True)
405
+ parser.add_argument('--drug_raw_file', type=str, required=False, help='Required for DrugGEN model, optional for NoTarget')
406
+ parser.add_argument('--drug_data_dir', type=str, default='data')
407
+ parser.add_argument('--mol_data_dir', type=str, default='data')
408
+     parser.add_argument('--features', action='store_true', help='use additional node features (otherwise only atom types are used)')
409
+
410
+ # Model configuration.
411
+     parser.add_argument('--submodel', type=str, default="DrugGEN", help="Choose the model subtype: DrugGEN or NoTarget", choices=['DrugGEN', 'NoTarget'])
412
+ parser.add_argument('--act', type=str, default="relu", help="Activation function for the model.", choices=['relu', 'tanh', 'leaky', 'sigmoid'])
413
+     parser.add_argument('--max_atom', type=int, default=45, help='Maximum number of atoms in generated molecules.')
414
+ parser.add_argument('--dim', type=int, default=128, help='Dimension of the Transformer Encoder model for the GAN.')
415
+ parser.add_argument('--depth', type=int, default=1, help='Depth of the Transformer model from the GAN.')
416
+ parser.add_argument('--ddepth', type=int, default=1, help='Depth of the Transformer model from the discriminator.')
417
+ parser.add_argument('--heads', type=int, default=8, help='Number of heads for the MultiHeadAttention module from the GAN.')
418
+ parser.add_argument('--mlp_ratio', type=int, default=3, help='MLP ratio for the Transformer.')
419
+ parser.add_argument('--dropout', type=float, default=0., help='dropout rate')
420
+ parser.add_argument('--ddropout', type=float, default=0., help='dropout rate for the discriminator')
421
+ parser.add_argument('--lambda_gp', type=float, default=10, help='Gradient penalty lambda multiplier for the GAN.')
422
+
423
+ # Training configuration.
424
+ parser.add_argument('--batch_size', type=int, default=128, help='Batch size for the training.')
425
+ parser.add_argument('--epoch', type=int, default=10, help='Epoch number for Training.')
426
+ parser.add_argument('--g_lr', type=float, default=0.00001, help='learning rate for G')
427
+ parser.add_argument('--d_lr', type=float, default=0.00001, help='learning rate for D')
428
+ parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for Adam optimizer')
429
+ parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
430
+ parser.add_argument('--log_dir', type=str, default='experiments/logs')
431
+ parser.add_argument('--sample_dir', type=str, default='experiments/samples')
432
+ parser.add_argument('--model_save_dir', type=str, default='experiments/models')
433
+ parser.add_argument('--log_sample_step', type=int, default=1000, help='step size for sampling during training')
434
+
435
+     parser.add_argument('--resume', action='store_true', help='resume training')
436
+ parser.add_argument('--resume', type=bool, default=False, help='resume training')
437
+ parser.add_argument('--resume_epoch', type=int, default=None, help='resume training from this epoch')
438
+ parser.add_argument('--resume_iter', type=int, default=None, help='resume training from this step')
439
+ parser.add_argument('--resume_directory', type=str, default=None, help='load pretrained weights from this directory')
440
+
441
+ # Seed configuration.
442
+ parser.add_argument('--set_seed', action='store_true', help='set seed for reproducibility')
443
+ parser.add_argument('--seed', type=int, default=1, help='seed for reproducibility')
444
+
445
+ # wandb configuration.
446
+ parser.add_argument('--use_wandb', action='store_true', help='use wandb for logging')
447
+ parser.add_argument('--online', action='store_true', help='use wandb online')
448
+ parser.add_argument('--exp_name', type=str, default='druggen', help='experiment name')
449
+ parser.add_argument('--parallel', action='store_true', help='Parallelize training')
450
+
451
+ config = parser.parse_args()
452
+
453
+ # Check if drug_raw_file is provided when using DrugGEN model
454
+ if config.submodel == "DrugGEN" and not config.drug_raw_file:
455
+ parser.error("--drug_raw_file is required when using DrugGEN model")
456
+
457
+ # If using NoTarget model and drug_raw_file is not provided, use a dummy file
458
+ if config.submodel == "NoTarget" and not config.drug_raw_file:
459
+         config.drug_raw_file = "data/akt_train.smi"  # Reference file for the NoTarget model (AKT); used only to build the encoders/decoders, not for training.
460
+
461
+ trainer = Train(config)
462
+ trainer.train(config)
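+
+ # Illustrative invocation (not part of the original script), using the SMILES
+ # files referenced elsewhere in this commit; adjust paths and hyperparameters
+ # to your own setup:
+ # python train.py --raw_file data/chembl_train.smi --drug_raw_file data/akt1_train.smi \
+ #     --submodel DrugGEN --max_atom 45 --batch_size 128 --epoch 10 --set_seed --seed 10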