Spaces:
Running
Running
refactor
Browse files- app.py +288 -0
- inference.py +303 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/__init__.py +0 -0
- src/data/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/__pycache__/dataset.cpython-310.pyc +0 -0
- src/data/__pycache__/utils.cpython-310.pyc +0 -0
- src/data/dataset.py +317 -0
- src/data/utils.py +143 -0
- src/model/__init__.py +0 -0
- src/model/__pycache__/__init__.cpython-310.pyc +0 -0
- src/model/__pycache__/layers.cpython-310.pyc +0 -0
- src/model/__pycache__/loss.cpython-310.pyc +0 -0
- src/model/__pycache__/models.cpython-310.pyc +0 -0
- src/model/layers.py +234 -0
- src/model/loss.py +85 -0
- src/model/models.py +269 -0
- src/util/__init__.py +0 -0
- src/util/__pycache__/__init__.cpython-310.pyc +0 -0
- src/util/__pycache__/smiles_cor.cpython-310.pyc +0 -0
- src/util/__pycache__/utils.cpython-310.pyc +0 -0
- src/util/smiles_cor.py +1284 -0
- src/util/utils.py +930 -0
- train.py +462 -0
app.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from inference import Inference
|
3 |
+
import PIL
|
4 |
+
from PIL import Image
|
5 |
+
import pandas as pd
|
6 |
+
import random
|
7 |
+
from rdkit import Chem
|
8 |
+
from rdkit.Chem import Draw
|
9 |
+
from rdkit.Chem.Draw import IPythonConsole
|
10 |
+
import shutil
|
11 |
+
import os
|
12 |
+
import time
|
13 |
+
|
14 |
+
class DrugGENConfig:
    """Baseline settings handed to ``Inference``; subclasses override per model.

    Every value is a plain class attribute, so an instance can be created
    without arguments and individual fields reassigned afterwards.
    """

    # ---- Inference configuration ----
    submodel = "DrugGEN"
    inference_model = "experiments/models/DrugGEN/"
    sample_num = 100
    disable_correction = False  # corresponds to correct=True in old config

    # ---- Data configuration ----
    inf_smiles = "data/chembl_test.smi"  # corresponds to inf_raw_file in old config
    train_smiles = "data/chembl_train.smi"
    train_drug_smiles = "data/akt1_train.smi"
    inf_batch_size = 1
    mol_data_dir = "data"
    features = False

    # ---- Model configuration ----
    act = "relu"
    max_atom = 45
    dim = 128
    depth = 1
    heads = 8
    mlp_ratio = 3
    dropout = 0.0

    # ---- Seed configuration ----
    set_seed = True
    seed = 10
|
41 |
+
|
42 |
+
|
43 |
+
class DrugGENAKT1Config(DrugGENConfig):
    """Overrides selecting the DrugGEN-AKT1 checkpoint and AKT1 drug set."""

    submodel = "DrugGEN"
    inference_model = "experiments/models/DrugGEN-AKT1/"
    train_drug_smiles = "data/akt1_train.smi"
    max_atom = 45
|
48 |
+
|
49 |
+
|
50 |
+
class DrugGENCDK2Config(DrugGENConfig):
    """Overrides selecting the DrugGEN-CDK2 checkpoint and CDK2 drug set."""

    submodel = "DrugGEN"
    inference_model = "experiments/models/DrugGEN-CDK2/"
    train_drug_smiles = "data/cdk2_train.smi"
    max_atom = 38
|
55 |
+
|
56 |
+
|
57 |
+
class NoTargetConfig(DrugGENConfig):
    """Overrides selecting the NoTarget submodel and its checkpoint."""

    submodel = "NoTarget"
    inference_model = "experiments/models/NoTarget/"
    # No specific target, use general ChEMBL data
    train_drug_smiles = "data/chembl_train.smi"
|
61 |
+
|
62 |
+
|
63 |
+
# Registry of ready-made configuration instances, keyed by model label.
model_configs = {
    label: config_cls()
    for label, config_cls in (
        ("DrugGEN-AKT1", DrugGENAKT1Config),
        ("DrugGEN-CDK2", DrugGENCDK2Config),
        ("DrugGEN-NoTarget", NoTargetConfig),
    )
}
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
def function(model_name: str, num_molecules: int, seed_num: "str | int | None") -> tuple[PIL.Image.Image, pd.DataFrame, str]:
    '''
    Generate molecules with the selected model and summarise the run.

    Args:
        model_name: Key of ``model_configs`` (e.g. "DrugGEN-AKT1").
        num_molecules: Number of molecules to generate; at most 250.
        seed_num: Seed as text or integer. ``None`` or a blank string picks a
            random seed in [0, 10000].

    Returns:
        image, score_df, file path

    Raises:
        gr.Error: If more than 250 molecules are requested or the seed is not
            an integer.
        KeyError: If ``model_name`` is not a key of ``model_configs``.
    '''
    import copy  # local import so the block does not depend on file-level edits

    # Validate the request before any configuration is touched.
    num_molecules = int(num_molecules)
    if num_molecules > 250:
        raise gr.Error("You have requested to generate more than the allowed limit of 250 molecules. Please reduce your request to 250 or fewer.")

    # Look the configuration up under the UI label *before* shortening the
    # name: model_configs has no "NoTarget" key, only "DrugGEN-NoTarget".
    config_key = model_name
    if model_name == "DrugGEN-NoTarget":
        model_name = "NoTarget"

    # Work on a copy so per-request values (sample_num, seed) never leak into
    # the shared module-level configuration instances.
    config = copy.copy(model_configs[config_key])
    config.sample_num = num_molecules

    # The textbox delivers a string, but tolerate an integer as well.
    seed_text = "" if seed_num is None else str(seed_num).strip()
    if not seed_text:
        config.seed = random.randint(0, 10000)
    else:
        try:
            config.seed = int(seed_text)
        except ValueError:
            raise gr.Error("The seed must be an integer value!") from None

    inferer = Inference(config)
    start_time = time.time()
    scores = inferer.inference()  # single-row DataFrame of metrics
    et = time.time() - start_time

    # Display label -> column name in the DataFrame returned by inference().
    metric_columns = {
        "Validity": "validity",
        "Uniqueness": "uniqueness",
        "Novelty (Train)": "novelty",
        "Novelty (Test)": "novelty_test",
        "Drug Novelty": "drug_novelty",
        "Max Length": "max_len",
        "Mean Atom Type": "mean_atom_type",
        "SNN ChEMBL": "snn_chembl",
        "SNN Drug": "snn_drug",
        "Internal Diversity": "IntDiv",
        "QED": "qed",
        "SA Score": "sa",
    }
    score_df = pd.DataFrame({
        "Runtime (seconds)": [et],
        **{label: [scores[column].iloc[0]] for label, column in metric_columns.items()},
    })

    # Prefer the legacy output location when it exists; otherwise use the file
    # Inference.inference() writes ("<submodel>_denovo_mols.smi").
    # NOTE(review): nothing visible here writes the legacy file -- confirm
    # whether the legacy branch is still needed.
    legacy_output = f'experiments/inference/{model_name}/inference_drugs.txt'
    produced_output = f'{config.submodel}_denovo_mols.smi'
    new_path = f'{model_name}_denovo_mols.smi'

    source_path = legacy_output if os.path.exists(legacy_output) else produced_output
    if source_path != new_path:
        # os.replace overwrites a download left over from a previous run.
        os.replace(source_path, new_path)

    with open(new_path) as f:
        generated_molecule_list = [line.strip() for line in f if line.strip()]

    # Drop the "SMILES" header row so it is not sampled as a molecule.
    if generated_molecule_list and generated_molecule_list[0] == "SMILES":
        generated_molecule_list = generated_molecule_list[1:]

    # Show at most 12 distinct molecules; sample() draws without replacement,
    # so the grid never repeats an entry, and stays reproducible per seed.
    rng = random.Random(config.seed)
    if len(generated_molecule_list) > 12:
        selected_smiles = rng.sample(generated_molecule_list, k=12)
    else:
        selected_smiles = generated_molecule_list

    # Parse each SMILES once and skip the ones RDKit rejects.
    parsed = (Chem.MolFromSmiles(smi) for smi in selected_smiles)
    selected_molecules = [mol for mol in parsed if mol is not None]

    drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
    drawOptions.prepareMolsBeforeDrawing = False
    drawOptions.bondLineWidth = 0.5

    molecule_image = Draw.MolsToGridImage(
        selected_molecules,
        molsPerRow=3,
        subImgSize=(400, 400),
        maxMols=len(selected_molecules),
        returnPNG=False,
        drawOptions=drawOptions,
        highlightAtomLists=None,
        highlightBondLists=None,
    )

    return molecule_image, score_df, new_path
|
150 |
+
|
151 |
+
|
152 |
+
|
153 |
+
# Gradio front end: inputs on the left, results on the right.
# NOTE(review): the nesting of the layout blocks below was reconstructed from a
# listing that had lost its indentation -- confirm against the deployed app.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    with gr.Row():
        # Left column: title, background information and generation inputs.
        with gr.Column(scale=1):
            gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
            with gr.Row():
                gr.Markdown("[](https://arxiv.org/abs/2302.07868)")
                gr.Markdown("[](https://github.com/HUBioDataLab/DrugGEN)")

            with gr.Accordion("About DrugGEN Models", open=False):
                gr.Markdown("""
## Model Variations

### DrugGEN-AKT1
This model is designed to generate molecules targeting the human AKT1 protein (UniProt ID: P31749), a serine/threonine-protein kinase that plays a key role in regulating cell survival, metabolism, and growth. AKT1 is a significant target in cancer therapy, particularly for breast, colorectal, and ovarian cancers.

The model learns from:
- General drug-like molecules from ChEMBL database
- Known AKT1 inhibitors
- Maximum atom count: 45

### DrugGEN-CDK2
This model targets the human CDK2 protein (UniProt ID: P24941), a cyclin-dependent kinase involved in cell cycle regulation. CDK2 inhibitors are being investigated for treating various cancers, particularly those with dysregulated cell cycle control.

The model learns from:
- General drug-like molecules from ChEMBL database
- Known CDK2 inhibitors
- Maximum atom count: 38

### DrugGEN-NoTarget
This is a general-purpose model that generates diverse drug-like molecules without targeting a specific protein. It's useful for:
- Exploring chemical space
- Generating diverse scaffolds
- Creating molecules with drug-like properties

## How It Works
DrugGEN uses a graph-based generative adversarial network (GAN) architecture where:
1. The generator creates molecular graphs
2. The discriminator evaluates them against real molecules
3. The model learns to generate increasingly realistic and target-specific molecules

For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
""")

            with gr.Accordion("Understanding the Metrics", open=False):
                gr.Markdown("""
## Evaluation Metrics

### Basic Metrics
- **Validity**: Percentage of generated molecules that are chemically valid
- **Uniqueness**: Percentage of unique molecules among valid ones
- **Runtime**: Time taken to generate the requested molecules

### Novelty Metrics
- **Novelty (Train)**: Percentage of molecules not found in the training set
- **Novelty (Test)**: Percentage of molecules not found in the test set
- **Drug Novelty**: Percentage of molecules not found in known drugs

### Structural Metrics
- **Max Length**: Maximum component length in the generated molecules
- **Mean Atom Type**: Average distribution of atom types
- **Internal Diversity**: Diversity within the generated set (higher is more diverse)

### Drug-likeness Metrics
- **QED (Quantitative Estimate of Drug-likeness)**: Score from 0-1 measuring how drug-like a molecule is (higher is better)
- **SA Score (Synthetic Accessibility)**: Score from 1-10 indicating ease of synthesis (lower is easier)

### Similarity Metrics
- **SNN ChEMBL**: Similarity to ChEMBL molecules (higher means more similar to known drug-like compounds)
- **SNN Drug**: Similarity to known drugs (higher means more similar to approved drugs)
""")

            # The choices must match the keys of model_configs.
            model_name = gr.Radio(
                choices=("DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"),
                value="DrugGEN-AKT1",
                label="Select Target Model",
                info="Choose which protein target or general model to use for molecule generation"
            )

            # maximum=250 mirrors the hard limit enforced inside function().
            num_molecules = gr.Slider(
                minimum=10,
                maximum=250,
                value=100,
                step=10,
                label="Number of Molecules to Generate",
                # The trailing "" is an empty literal concatenated onto the text; it adds nothing.
                info="This space runs on a CPU, which may result in slower performance. Generating 200 molecules takes approximately 6 minutes. Therefore, We set a 250-molecule cap. On a GPU, the model can generate 10,000 molecules in the same amount of time. Please check our GitHub repo for running our models on GPU.""
            )

            # Free-text seed; function() converts it to an int or picks one at random.
            seed_num = gr.Textbox(
                label="Random Seed (Optional)",
                value="",
                info="Set a specific seed for reproducible results, or leave empty for random generation"
            )

            submit_button = gr.Button(
                value="Generate Molecules",
                variant="primary",
                size="lg"
            )

        # Right column: generated output and metrics.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Generated Molecules"):
                    image_output = gr.Image(
                        label="Sample of Generated Molecules",
                        elem_id="molecule_display"
                    )
                    file_download = gr.File(
                        label="Download All Generated Molecules (SMILES format)",
                    )

                with gr.TabItem("Performance Metrics"):
                    # Headers match the columns of the DataFrame built in function().
                    scores_df = gr.Dataframe(
                        label="Model Performance Metrics",
                        headers=["Runtime (seconds)", "Validity", "Uniqueness", "Novelty (Train)", "Novelty (Test)",
                                 "Drug Novelty", "Max Length", "Mean Atom Type", "SNN ChEMBL", "SNN Drug",
                                 "Internal Diversity", "QED", "SA Score"]
                    )

            with gr.Accordion("Generation Settings", open=False):
                gr.Markdown("""
## Technical Details

- This demo runs on CPU which limits generation speed
- Generating 200 molecules takes approximately 6 minutes
- For faster generation or larger batches, run the model on GPU using our GitHub repository
- The model uses a graph-based representation of molecules
- Maximum atom count varies by model (AKT1: 45, CDK2: 38)
""")

    gr.Markdown("### Created by the HU BioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")

    # Wire the button: three inputs in, (image, metrics table, file) out.
    submit_button.click(function, inputs=[model_name, num_molecules, seed_num], outputs=[image_output, scores_df, file_download], api_name="inference")
#demo.queue(concurrency_count=1)
demo.queue()
demo.launch()
|
288 |
+
|
inference.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import time
|
4 |
+
import random
|
5 |
+
import pickle
|
6 |
+
import argparse
|
7 |
+
import os.path as osp
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.utils.data
|
11 |
+
from torch_geometric.loader import DataLoader
|
12 |
+
|
13 |
+
import pandas as pd
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from rdkit import RDLogger, Chem
|
17 |
+
from rdkit.Chem import QED, RDConfig
|
18 |
+
|
19 |
+
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
|
20 |
+
import sascorer
|
21 |
+
|
22 |
+
from src.util.utils import *
|
23 |
+
from src.model.models import Generator
|
24 |
+
from src.data.dataset import DruggenDataset
|
25 |
+
from src.data.utils import get_encoders_decoders, load_molecules
|
26 |
+
from src.model.loss import generator_loss
|
27 |
+
from src.util.smiles_cor import smi_correct
|
28 |
+
|
29 |
+
|
30 |
+
class Inference(object):
    """Inference class for DrugGEN.

    Loads the inference dataset and a trained generator described by a
    configuration object, samples molecules and reports generation metrics.

    NOTE(review): ``np``, ``AllChem``, ``Metrics``, ``novelty``,
    ``fraction_valid``, ``fraction_unique``, ``average_agg_tanimoto`` and
    ``internal_diversity`` are not imported explicitly in this file; presumably
    they arrive through ``from src.util.utils import *`` -- confirm.
    """

    def __init__(self, config):
        """Seed the RNGs (optional), load the dataset and build the generator.

        Args:
            config: Object exposing the attributes read below, e.g. an
                ``argparse.Namespace`` or a plain configuration instance.
        """
        if config.set_seed:
            np.random.seed(config.seed)
            random.seed(config.seed)
            torch.manual_seed(config.seed)
            torch.cuda.manual_seed_all(config.seed)

            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

            os.environ["PYTHONHASHSEED"] = str(config.seed)

            print(f'Using seed {config.seed}')

        self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

        # Initialize configurations
        self.submodel = config.submodel
        self.inference_model = config.inference_model
        self.sample_num = config.sample_num
        self.disable_correction = config.disable_correction

        # Data loader.
        self.inf_smiles = config.inf_smiles  # SMILES containing text file for first dataset.
                                             # Write the full path to file.

        inf_smiles_basename = osp.basename(self.inf_smiles)

        # Get the base name without extension and add max_atom to it
        self.max_atom = config.max_atom  # Model is based on one-shot generation.
        inf_smiles_base = os.path.splitext(inf_smiles_basename)[0]

        # Change extension from .smi to .pt and add max_atom to the filename
        self.inf_dataset_file = f"{inf_smiles_base}{self.max_atom}.pt"

        self.inf_batch_size = config.inf_batch_size
        self.train_smiles = config.train_smiles
        self.train_drug_smiles = config.train_drug_smiles
        self.mol_data_dir = config.mol_data_dir  # Directory where the dataset files are stored.
        self.dataset_name = self.inf_dataset_file.split(".")[0]
        self.features = config.features  # Small model uses atom types as node features. (Boolean, False uses atom types only.)
                                         # Additional node features can be added. Please check new_dataloarder.py Line 102.

        # Get atom and bond encoders/decoders
        self.atom_encoder, self.atom_decoder, self.bond_encoder, self.bond_decoder = get_encoders_decoders(
            self.train_smiles,
            self.train_drug_smiles,
            self.max_atom
        )

        self.inf_dataset = DruggenDataset(self.mol_data_dir,
                                          self.inf_dataset_file,
                                          self.inf_smiles,
                                          self.max_atom,
                                          self.features,
                                          atom_encoder=self.atom_encoder,
                                          atom_decoder=self.atom_decoder,
                                          bond_encoder=self.bond_encoder,
                                          bond_decoder=self.bond_decoder)

        self.inf_loader = DataLoader(self.inf_dataset,
                                     shuffle=True,
                                     batch_size=self.inf_batch_size,
                                     drop_last=True)  # PyG dataloader for the first GAN.

        self.m_dim = len(self.atom_decoder) if not self.features else int(self.inf_loader.dataset[0].x.shape[1])  # Atom type dimension.
        self.b_dim = len(self.bond_decoder)  # Bond type dimension.
        self.vertexes = int(self.inf_loader.dataset[0].x.shape[0])  # Number of nodes in the graph.

        # Model configurations.
        self.act = config.act
        self.dim = config.dim
        self.depth = config.depth
        self.heads = config.heads
        self.mlp_ratio = config.mlp_ratio
        self.dropout = config.dropout

        self.build_model()

    def build_model(self):
        """Create the generator, move it to ``self.device`` and print a summary."""
        self.G = Generator(self.act,
                           self.vertexes,
                           self.b_dim,
                           self.m_dim,
                           self.dropout,
                           dim=self.dim,
                           depth=self.depth,
                           heads=self.heads,
                           mlp_ratio=self.mlp_ratio)
        self.G.to(self.device)
        self.print_network(self.G, 'G')

    def print_network(self, model, name):
        """Print the model, its name and its total parameter count."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def restore_model(self, submodel, model_directory):
        """Load generator weights from ``<model_directory>/<submodel>-G.ckpt``."""
        print('Loading the model...')
        G_path = os.path.join(model_directory, '{}-G.ckpt'.format(submodel))
        # map_location keeps every tensor on the storage it is deserialized to
        # (CPU), so a GPU-saved checkpoint also loads on a CPU-only host.
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))

    @staticmethod
    def _read_smiles(path):
        """Return the lines of the text file at ``path`` and close the file."""
        with open(path, 'r') as handle:
            return handle.read().splitlines()

    def inference(self):
        """Generate molecules with the trained generator and score them.

        Also writes the resulting SMILES to ``<submodel>_denovo_mols.smi`` in
        the working directory (one per line, after a ``SMILES`` header).

        Returns:
            pandas.DataFrame: a single row with the columns ``submodel``,
            ``validity``, ``uniqueness``, ``novelty``, ``novelty_test``,
            ``drug_novelty``, ``max_len``, ``mean_atom_type``, ``snn_chembl``,
            ``snn_drug``, ``IntDiv``, ``qed`` and ``sa``.
        """
        # Load the trained generator.
        self.restore_model(self.submodel, self.inference_model)

        # smiles data for metrics calculation.
        chembl_smiles = self._read_smiles(self.train_smiles)
        chembl_test = self._read_smiles(self.inf_smiles)
        drug_smiles = self._read_smiles(self.train_drug_smiles)
        drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
        drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]

        # Make directories if not exist (exist_ok avoids the check-then-create race).
        output_dir = "experiments/inference/{}".format(self.submodel)
        os.makedirs(output_dir, exist_ok=True)

        self.G.eval()

        metric_calc_dr = []
        uniqueness_calc = []
        real_smiles_snn = []
        # Placeholder first row that generated node samples are concatenated
        # onto. It is zero-filled rather than left uninitialized so that
        # mean_atom_type below is deterministic.
        # NOTE(review): the placeholder row itself still takes part in that
        # metric -- confirm whether it should be dropped.
        nodes_sample = torch.zeros(1, self.vertexes, 1).to(self.device)
        generated_smiles = []
        none_counter = 0

        # Inference mode
        with torch.inference_mode(), tqdm(range(self.sample_num)) as pbar:
            pbar.set_description('Inference mode for {} model started'.format(self.submodel))
            for data in self.inf_loader:
                # Preprocess dataset
                _, a_tensor, x_tensor = load_molecules(
                    data=data,
                    batch_size=self.inf_batch_size,
                    device=self.device,
                    b_dim=self.b_dim,
                    m_dim=self.m_dim,
                )

                _, _, node_sample, edge_sample = self.G(a_tensor, x_tensor)

                # Index of the highest-scoring class along the last axis.
                g_edges_hat_sample = torch.max(edge_sample, -1)[1]
                g_nodes_hat_sample = torch.max(node_sample, -1)[1]

                fake_mol_g = [self.inf_dataset.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=False, file_name=self.dataset_name)
                              for e_, n_ in zip(g_edges_hat_sample, g_nodes_hat_sample)]

                a_tensor_sample = torch.max(a_tensor, -1)[1]
                x_tensor_sample = torch.max(x_tensor, -1)[1]
                real_mols = [self.inf_dataset.matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(), strict=True, file_name=self.dataset_name)
                             for e_, n_ in zip(a_tensor_sample, x_tensor_sample)]

                inference_drugs = [None if line is None else Chem.MolToSmiles(line) for line in fake_mol_g]
                # Keep only the longest '.'-separated component of each SMILES.
                inference_drugs = [None if x is None else max(x.split('.'), key=len) for x in inference_drugs]

                for molecules in inference_drugs:
                    if molecules is None:
                        none_counter += 1
                        continue
                    molecules = molecules.replace("*", "C")
                    generated_smiles.append(molecules)
                    uniqueness_calc.append(molecules)
                    # NOTE(review): the view() assumes one molecule per batch
                    # (inf_batch_size == 1) -- confirm.
                    nodes_sample = torch.cat((nodes_sample, g_nodes_hat_sample.view(1, self.vertexes, 1)), 0)
                    pbar.update(1)
                    metric_calc_dr.append(molecules)

                real_smiles_snn.append(real_mols[0])
                generation_number = sum(x is not None for x in metric_calc_dr)
                if generation_number == self.sample_num or none_counter == self.sample_num:
                    break

        if not self.disable_correction:
            # Build the corrector once, right where it is used.
            correct = smi_correct(self.submodel, output_dir)
            gen_smi = correct.correct_smiles_list(generated_smiles)
        else:
            gen_smi = generated_smiles

        # Fingerprints come from the SMILES collected before correction.
        raw_mols = (Chem.MolFromSmiles(x) for x in uniqueness_calc)
        gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(m, 2, nBits=1024) for m in raw_mols if m is not None]
        real_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in real_smiles_snn if x is not None]

        if not self.disable_correction:
            val = round(len(gen_smi)/self.sample_num, 3)
        else:
            val = round(fraction_valid(gen_smi), 3)

        uniq = round(fraction_unique(gen_smi), 3)
        nov = round(novelty(gen_smi, chembl_smiles), 3)
        nov_test = round(novelty(gen_smi, chembl_test), 3)
        drug_nov = round(novelty(gen_smi, drug_smiles), 3)
        max_len = round(Metrics.max_component(gen_smi, self.vertexes), 3)
        mean_atom = round(Metrics.mean_atom_type(nodes_sample), 3)
        snn_chembl = round(average_agg_tanimoto(np.array(real_vecs), np.array(gen_vecs)), 3)
        snn_drug = round(average_agg_tanimoto(np.array(drug_vecs), np.array(gen_vecs)), 3)
        int_div = round((internal_diversity(np.array(gen_vecs)))[0], 3)

        # Parse every final SMILES once and reuse the molecules for QED and SA.
        scored_mols = [m for m in map(Chem.MolFromSmiles, gen_smi) if m is not None]
        qed = round(np.mean([QED.qed(m) for m in scored_mols]), 3)
        sa = round(np.mean([sascorer.calculateScore(m) for m in scored_mols]), 3)

        model_res = pd.DataFrame({"submodel": [self.submodel], "validity": [val],
                                  "uniqueness": [uniq], "novelty": [nov],
                                  "novelty_test": [nov_test], "drug_novelty": [drug_nov],
                                  "max_len": [max_len], "mean_atom_type": [mean_atom],
                                  "snn_chembl": [snn_chembl], "snn_drug": [snn_drug],
                                  "IntDiv": [int_div], "qed": [qed], "sa": [sa]})

        # Write generated SMILES to a temporary file for app.py to use
        temp_file = f'{self.submodel}_denovo_mols.smi'
        with open(temp_file, 'w') as f:
            f.write("SMILES\n")
            f.writelines(f"{smiles}\n" for smiles in gen_smi)

        return model_res
|
269 |
+
|
270 |
+
|
271 |
+
def _build_arg_parser():
    """Create the command-line parser used when this file runs as a script."""
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # Inference configuration.
    add('--submodel', type=str, default="DrugGEN", help="Chose model subtype: DrugGEN, NoTarget", choices=['DrugGEN', 'NoTarget'])
    add('--inference_model', type=str, help="Path to the model for inference")
    add('--sample_num', type=int, default=100, help='inference samples')
    add('--disable_correction', action='store_true', help='Disable SMILES correction')

    # Data configuration.
    add('--inf_smiles', type=str, required=True)
    add('--train_smiles', type=str, required=True)
    add('--train_drug_smiles', type=str, required=True)
    add('--inf_batch_size', type=int, default=1, help='Batch size for inference')
    add('--mol_data_dir', type=str, default='data')
    add('--features', action='store_true', help='features dimension for nodes')

    # Model configuration.
    add('--act', type=str, default="relu", help="Activation function for the model.", choices=['relu', 'tanh', 'leaky', 'sigmoid'])
    add('--max_atom', type=int, default=45, help='Max atom number for molecules must be specified.')
    add('--dim', type=int, default=128, help='Dimension of the Transformer Encoder model for the GAN.')
    add('--depth', type=int, default=1, help='Depth of the Transformer model from the GAN.')
    add('--heads', type=int, default=8, help='Number of heads for the MultiHeadAttention module from the GAN.')
    add('--mlp_ratio', type=int, default=3, help='MLP ratio for the Transformer.')
    add('--dropout', type=float, default=0., help='dropout rate')

    # Seed configuration.
    add('--set_seed', action='store_true', help='set seed for reproducibility')
    add('--seed', type=int, default=1, help='seed for reproducibility')

    return cli


if __name__ == "__main__":
    # Parse the command line, then build and run a single inference pass.
    config = _build_arg_parser().parse_args()
    inference = Inference(config)
    inference.inference()
|
src/__init__.py
ADDED
File without changes
|
src/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (150 Bytes). View file
|
|
src/data/__init__.py
ADDED
File without changes
|
src/data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (155 Bytes). View file
|
|
src/data/__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (12.9 kB). View file
|
|
src/data/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (4.75 kB). View file
|
|
src/data/dataset.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import os.path as osp
|
3 |
+
import re
|
4 |
+
import pickle
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch_geometric.data import Data, InMemoryDataset
|
12 |
+
|
13 |
+
from rdkit import Chem, RDLogger
|
14 |
+
|
15 |
+
from src.data.utils import label2onehot
|
16 |
+
|
17 |
+
RDLogger.DisableLog('rdApp.*')
|
18 |
+
|
19 |
+
|
20 |
+
class DruggenDataset(InMemoryDataset):
    """
    In-memory PyTorch Geometric dataset of molecular graphs built from a SMILES file.

    Raw SMILES are filtered (parsable, at most ``max_atom`` atoms, only atoms and
    bonds known to the supplied encoders), converted into ``Data`` objects holding
    one-hot node features, an edge index and encoded bond types, and saved to disk.
    The class also offers helpers to turn label matrices back into RDKit molecules
    and to repair valency violations.
    """

    def __init__(self, root, dataset_file, raw_files, max_atom, features,
                 atom_encoder, atom_decoder, bond_encoder, bond_decoder,
                 transform=None, pre_transform=None, pre_filter=None):
        """
        Initialize the DruggenDataset with pre-loaded encoder/decoder dictionaries.

        Parameters:
            root (str): Root directory.
            dataset_file (str): Name of the processed dataset file.
            raw_files (str): Path to the raw SMILES file.
            max_atom (int): Maximum number of atoms allowed in a molecule.
            features (bool): Whether to include additional node features.
            atom_encoder (dict): Pre-loaded atom encoder dictionary.
            atom_decoder (dict): Pre-loaded atom decoder dictionary.
            bond_encoder (dict): Pre-loaded bond encoder dictionary.
            bond_decoder (dict): Pre-loaded bond decoder dictionary.
            transform, pre_transform, pre_filter: See PyG InMemoryDataset.
        """
        # These attributes are assigned before the base-class constructor runs
        # because the properties and process() below read them.
        # NOTE(review): the PyG base constructor is expected to call process()
        # when the processed file is missing - confirm for the installed version.
        self.dataset_name = dataset_file.split(".")[0]
        self.dataset_file = dataset_file
        self.raw_files = raw_files
        self.max_atom = max_atom
        self.features = features

        # Use the provided encoder/decoder mappings.
        self.atom_encoder_m = atom_encoder
        self.atom_decoder_m = atom_decoder
        self.bond_encoder_m = bond_encoder
        self.bond_decoder_m = bond_decoder

        self.atom_num_types = len(atom_encoder)
        self.bond_num_types = len(bond_encoder)

        super().__init__(root, transform, pre_transform, pre_filter)
        path = osp.join(self.processed_dir, dataset_file)
        self.data, self.slices = torch.load(path)
        self.root = root

    @property
    def processed_dir(self):
        """
        Returns the directory where processed dataset files are stored.

        This override stores processed files directly in ``root``.
        """
        return self.root

    @property
    def raw_file_names(self):
        """
        Returns the raw SMILES file name.
        """
        return self.raw_files

    @property
    def processed_file_names(self):
        """
        Returns the name of the processed dataset file.
        """
        return self.dataset_file

    def _filter_smiles(self, smiles_list):
        """
        Filters the input list of SMILES strings to keep only valid molecules that:
          - Can be successfully parsed,
          - Have a number of atoms less than or equal to the maximum allowed (max_atom),
          - Contain only atoms present in the atom_encoder,
          - Contain only bonds present in the bond_encoder.

        Parameters:
            smiles_list (list): List of SMILES strings.

        Returns:
            max_length (int): Maximum number of atoms found in the filtered molecules
                (0 when nothing passes the filter).
            filtered_smiles (list): List of valid SMILES strings.
        """
        max_length = 0
        filtered_smiles = []
        for smiles in tqdm(smiles_list, desc="Filtering SMILES"):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue

            # Check molecule size
            molecule_size = mol.GetNumAtoms()
            if molecule_size > self.max_atom:
                continue

            # Filter out molecules with atoms not in the atom_encoder
            if not all(atom.GetAtomicNum() in self.atom_encoder_m for atom in mol.GetAtoms()):
                continue

            # Filter out molecules with bonds not in the bond_encoder
            if not all(bond.GetBondType() in self.bond_encoder_m for bond in mol.GetBonds()):
                continue

            filtered_smiles.append(smiles)
            max_length = max(max_length, molecule_size)
        return max_length, filtered_smiles

    def _genA(self, mol, connected=True, max_length=None):
        """
        Generates the adjacency matrix for a molecule based on its bond structure.

        Parameters:
            mol (rdkit.Chem.Mol): The molecule.
            connected (bool): If True, reject molecules in which some atom has no bond.
                If False, the matrix is returned without that check.
            max_length (int, optional): The size of the matrix; if None, uses number of atoms in mol.

        Returns:
            np.array: Symmetric (max_length, max_length) matrix whose entries are the
                encoded bond types, or None when ``connected`` is True and at least
                one atom of the molecule has degree zero.
        """
        num_atoms = mol.GetNumAtoms()
        max_length = max_length if max_length is not None else num_atoms
        A = np.zeros((max_length, max_length))
        bonds = list(mol.GetBonds())
        begin = [b.GetBeginAtomIdx() for b in bonds]
        end = [b.GetEndAtomIdx() for b in bonds]
        bond_type = [self.bond_encoder_m[b.GetBondType()] for b in bonds]
        # Fill both triangles so the matrix is symmetric.
        A[begin, end] = bond_type
        A[end, begin] = bond_type
        # Only the real atoms are checked; padded rows are always zero.
        degree = np.sum(A[:num_atoms, :num_atoms], axis=-1)
        # The previous one-line conditional returned None for every call with
        # connected=False; the check is now applied only when it was requested.
        if connected and not (degree > 0).all():
            return None
        return A

    def _genX(self, mol, max_length=None):
        """
        Generates the feature vector for each atom in a molecule by encoding their atomic numbers.

        Parameters:
            mol (rdkit.Chem.Mol): The molecule.
            max_length (int, optional): Length of the feature vector; if None, uses number of atoms in mol.

        Returns:
            np.array: Array of atom feature indices, padded with zeros if necessary, or None
                when an atom is missing from the atom encoder (the molecule is reported
                on stdout).
        """
        max_length = max_length if max_length is not None else mol.GetNumAtoms()
        try:
            return np.array([self.atom_encoder_m[atom.GetAtomicNum()] for atom in mol.GetAtoms()] +
                            [0] * (max_length - mol.GetNumAtoms()))
        except KeyError as e:
            print(f"Skipping molecule with unsupported atom: {e}")
            print(f"Skipped SMILES: {Chem.MolToSmiles(mol)}")
            return None

    def _genF(self, mol, max_length=None):
        """
        Generates additional node features for a molecule using various atomic properties.

        Every atom row concatenates indicator flags for degree, explicit valence,
        hybridization, implicit valence, aromaticity, the no-implicit flag, explicit and
        implicit hydrogen counts, radical electrons, ring membership and ring sizes 2-8.

        Parameters:
            mol (rdkit.Chem.Mol): The molecule.
            max_length (int, optional): Number of rows in the features matrix; if None, uses number of atoms.

        Returns:
            np.array: Array of additional features for each atom, padded with zero rows if necessary.
        """
        max_length = max_length if max_length is not None else mol.GetNumAtoms()
        features = np.array([[*[a.GetDegree() == i for i in range(5)],
                              *[a.GetExplicitValence() == i for i in range(9)],
                              *[int(a.GetHybridization()) == i for i in range(1, 7)],
                              *[a.GetImplicitValence() == i for i in range(9)],
                              a.GetIsAromatic(),
                              a.GetNoImplicit(),
                              *[a.GetNumExplicitHs() == i for i in range(5)],
                              *[a.GetNumImplicitHs() == i for i in range(5)],
                              *[a.GetNumRadicalElectrons() == i for i in range(5)],
                              a.IsInRing(),
                              *[a.IsInRingSize(i) for i in range(2, 9)]]
                             for a in mol.GetAtoms()], dtype=np.int32)
        return np.vstack((features, np.zeros((max_length - features.shape[0], features.shape[1]))))

    def decoder_load(self, dictionary_name, file):
        """
        Returns the pre-loaded decoder dictionary based on the dictionary name.

        Parameters:
            dictionary_name (str): Name of the dictionary ("atom" or "bond").
            file: Placeholder parameter for compatibility (ignored).

        Returns:
            dict: The corresponding decoder dictionary.

        Raises:
            ValueError: If ``dictionary_name`` is neither "atom" nor "bond".
        """
        if dictionary_name == "atom":
            return self.atom_decoder_m
        elif dictionary_name == "bond":
            return self.bond_decoder_m
        else:
            raise ValueError("Unknown dictionary name.")

    def matrices2mol(self, node_labels, edge_labels, strict=True, file_name=None):
        """
        Converts graph representations (node labels and edge labels) back to an RDKit molecule.

        Parameters:
            node_labels (iterable): Encoded atom labels.
            edge_labels (np.array): Adjacency matrix with encoded bond types.
            strict (bool): If True, sanitizes the molecule and returns None on failure.
            file_name: Placeholder parameter for compatibility (ignored).

        Returns:
            rdkit.Chem.Mol: The resulting molecule, or None if sanitization fails.
        """
        mol = Chem.RWMol()
        for node_label in node_labels:
            mol.AddAtom(Chem.Atom(self.atom_decoder_m[node_label]))
        for start, end in zip(*np.nonzero(edge_labels)):
            # Only the lower triangle is used, so each bond is added once.
            if start > end:
                mol.AddBond(int(start), int(end), self.bond_decoder_m[edge_labels[start, end]])
        if strict:
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                mol = None
        return mol

    def check_valency(self, mol):
        """
        Checks that no atom in the molecule has exceeded its allowed valency.

        Parameters:
            mol (rdkit.Chem.Mol): The molecule.

        Returns:
            tuple: (True, None) if valid; (False, atomid_valence) if there is a valency issue,
                where ``atomid_valence`` is the list of integers found in the error message
                from its first '#' character onwards.
        """
        try:
            Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
            return True, None
        except ValueError as e:
            e = str(e)
            p = e.find('#')
            e_sub = e[p:]
            atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
            return False, atomid_valence

    def correct_mol(self, mol):
        """
        Corrects a molecule by removing bonds until all atoms satisfy their valency limits.

        On every pass the bond with the highest bond-type value on the offending atom
        is removed. The loop stops when the valency check passes, or when the offending
        atom has no bond left to remove.

        Parameters:
            mol (rdkit.Chem.Mol): The molecule (modified in place).

        Returns:
            rdkit.Chem.Mol: The corrected molecule.
        """
        while True:
            flag, atomid_valence = self.check_valency(mol)
            if flag:
                break
            # Expecting two numbers: atom index and its valence.
            assert len(atomid_valence) == 2
            idx = atomid_valence[0]
            queue = [
                (b.GetIdx(), int(b.GetBondType()), b.GetBeginAtomIdx(), b.GetEndAtomIdx())
                for b in mol.GetAtomWithIdx(idx).GetBonds()
            ]
            if not queue:
                # Nothing can be removed for this atom, so another pass would hit the
                # same failure forever; hand the molecule back as it is.
                break
            # max() returns the first bond with the largest bond-type value, matching
            # a stable descending sort followed by taking element 0.
            strongest = max(queue, key=lambda tup: tup[1])
            mol.RemoveBond(strongest[2], strongest[3])
        return mol

    def process(self, size=None):
        """
        Processes the raw SMILES file by filtering and converting each valid SMILES into a PyTorch Geometric Data object.
        The resulting dataset is saved to disk.

        Parameters:
            size (optional): Placeholder parameter for compatibility (ignored).

        Side Effects:
            Sets ``self.m_dim`` and saves the processed dataset as a file in the processed directory.
        """
        # Read raw SMILES from file (assuming CSV with no header)
        smiles_list = pd.read_csv(self.raw_files, header=None)[0].tolist()
        max_length, filtered_smiles = self._filter_smiles(smiles_list)
        data_list = []
        self.m_dim = len(self.atom_decoder_m)
        for smiles in tqdm(filtered_smiles, desc='Processing dataset', total=len(filtered_smiles)):
            mol = Chem.MolFromSmiles(smiles)
            A = self._genA(mol, connected=True, max_length=max_length)
            if A is not None:
                x_array = self._genX(mol, max_length=max_length)
                if x_array is None:
                    continue
                # One-hot encode the padded atom labels: (max_length, m_dim).
                x = torch.from_numpy(x_array).to(torch.long).view(1, -1)
                x = label2onehot(x, self.m_dim).squeeze()
                if self.features:
                    f = torch.from_numpy(self._genF(mol, max_length=max_length)).to(torch.long).view(x.shape[0], -1)
                    x = torch.concat((x, f), dim=-1)
                adjacency = torch.from_numpy(A)
                # Sparse representation: indices of non-zero entries and their bond codes.
                edge_index = adjacency.nonzero(as_tuple=False).t().contiguous()
                edge_attr = adjacency[edge_index[0], edge_index[1]].to(torch.long)
                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles)
                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                data_list.append(data)
        torch.save(self.collate(data_list), osp.join(self.processed_dir, self.dataset_file))
|
src/data/utils.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import pickle
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch_geometric.data import Data, InMemoryDataset
|
9 |
+
import torch_geometric.utils as geoutils
|
10 |
+
|
11 |
+
from rdkit import Chem, RDLogger
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
def label2onehot(labels, dim, device=None):
    """
    Convert label indices to one-hot vectors.

    Parameters:
        labels (torch.Tensor): Integer (long) tensor of class indices, any shape.
        dim (int): Number of classes, i.e. the size of the new trailing axis.
        device (optional): Device for the result. When omitted, the result is
            created on the same device as ``labels`` so that labels already living
            on an accelerator no longer fail in ``scatter_``.

    Returns:
        torch.Tensor: Float tensor of shape ``labels.shape + (dim,)`` with a single
            1 per position along the last axis.
    """
    # Truthiness (not "is None") is kept so callers that passed a falsy device
    # keep getting the fallback they got before.
    target_device = device if device else labels.device
    out = torch.zeros(list(labels.size()) + [dim], device=target_device)

    # Write a 1 at the label's index along the newly added last axis.
    out.scatter_(out.dim() - 1, labels.unsqueeze(-1), 1.)

    return out.float()
|
24 |
+
|
25 |
+
|
26 |
+
def get_encoders_decoders(raw_file1, raw_file2, max_atom):
    """
    Load the atom/bond encoder and decoder dictionaries for a pair of SMILES
    files, building and caching them on first use.

    The cache files live under ``data/encoders`` and ``data/decoders`` and are
    named after the two input file stems (sorted alphabetically), so the same
    pair of files maps to the same cache regardless of argument order.

    Parameters:
        raw_file1 (str): Path to the first SMILES file.
        raw_file2 (str): Path to the second SMILES file.
        max_atom (int): Maximum allowed number of atoms in a molecule; larger
            molecules are ignored when the dictionaries are built.

    Returns:
        atom_encoder (dict): Mapping from atomic numbers to indices.
        atom_decoder (dict): Mapping from indices to atomic numbers.
        bond_encoder (dict): Mapping from bond types to indices.
        bond_decoder (dict): Mapping from indices to bond types.
    """
    stems = sorted(
        os.path.splitext(os.path.basename(path))[0]
        for path in (raw_file1, raw_file2)
    )
    suffix = f"{stems[0]}_{stems[1]}"

    enc_dir = os.path.join("data", "encoders")
    dec_dir = os.path.join("data", "decoders")
    # Same order as the returned tuple: atom enc, atom dec, bond enc, bond dec.
    cache_paths = (
        os.path.join(enc_dir, f"atom_{suffix}.pkl"),
        os.path.join(dec_dir, f"atom_{suffix}.pkl"),
        os.path.join(enc_dir, f"bond_{suffix}.pkl"),
        os.path.join(dec_dir, f"bond_{suffix}.pkl"),
    )

    # Fast path: every cache file is present, so just unpickle them.
    if all(os.path.exists(path) for path in cache_paths):
        cached = []
        for path in cache_paths:
            with open(path, "rb") as handle:
                cached.append(pickle.load(handle))
        print("Loaded existing encoders/decoders!")
        return tuple(cached)

    print("Creating new encoders/decoders...")
    # One SMILES per row, no header; both files are scanned together.
    all_smiles = pd.read_csv(raw_file1, header=None)[0].tolist()
    all_smiles = all_smiles + pd.read_csv(raw_file2, header=None)[0].tolist()

    atomic_numbers = set()
    bond_types = set()
    for smiles in tqdm(all_smiles, desc="Processing SMILES"):
        mol = Chem.MolFromSmiles(smiles)
        # Skip unparsable molecules and those above the size limit.
        if mol is None or mol.GetNumAtoms() > max_atom:
            continue
        atomic_numbers.update(atom.GetAtomicNum() for atom in mol.GetAtoms())
        bond_types.update(bond.GetBondType() for bond in mol.GetBonds())

    # Atomic number 0 acts as the PAD symbol for atoms.
    atomic_numbers.add(0)
    atom_labels = sorted(atomic_numbers)
    # RDKit's BondType.ZERO is placed first as the PAD bond type.
    bond_labels = [Chem.rdchem.BondType.ZERO] + sorted(bond_types)

    atom_encoder = {label: index for index, label in enumerate(atom_labels)}
    atom_decoder = dict(enumerate(atom_labels))
    bond_encoder = {label: index for index, label in enumerate(bond_labels)}
    bond_decoder = dict(enumerate(bond_labels))

    os.makedirs(enc_dir, exist_ok=True)
    os.makedirs(dec_dir, exist_ok=True)

    mappings = (atom_encoder, atom_decoder, bond_encoder, bond_decoder)
    for path, mapping in zip(cache_paths, mappings):
        with open(path, "wb") as handle:
            pickle.dump(mapping, handle)

    print("Encoders/decoders created and saved.")
    return atom_encoder, atom_decoder, bond_encoder, bond_decoder
|
127 |
+
|
128 |
+
def load_molecules(data=None, b_dim=32, m_dim=32, device=None, batch_size=32):
    """
    Turn a batched graph object into dense node/edge tensors plus a flat
    per-graph vector.

    Parameters:
        data: Batched graph exposing ``x``, ``edge_index``, ``edge_attr`` and
            ``batch`` (presumably a PyG ``Batch`` - confirm against the caller).
        b_dim (int): Number of bond classes used for the one-hot edge encoding.
        m_dim (int): Accepted for interface compatibility; not used here.
        device: Device the data is moved to and the one-hot tensor is built on.
        batch_size (int): Number of graphs in ``data``. The node count per graph
            is derived as total nodes divided by this value, so every graph is
            assumed to carry the same number of (padded) nodes.

    Returns:
        tuple: ``(real_graphs, a_tensor, x_tensor)`` where ``real_graphs`` is the
            per-graph concatenation of flattened node and edge tensors,
            ``a_tensor`` the one-hot dense adjacency and ``x_tensor`` the node
            features reshaped to ``(batch_size, nodes_per_graph, -1)``.
    """
    data = data.to(device)
    nodes_per_graph = int(data.batch.shape[0] / batch_size)

    dense_adj = geoutils.to_dense_adj(
        edge_index=data.edge_index,
        batch=data.batch,
        edge_attr=data.edge_attr,
        max_num_nodes=nodes_per_graph,
    )
    x_tensor = data.x.view(batch_size, nodes_per_graph, -1)
    a_tensor = label2onehot(dense_adj, b_dim, device)

    # Flatten both tensors per graph and join them: nodes first, then edges.
    flat_parts = (
        x_tensor.reshape(batch_size, -1),
        a_tensor.reshape(batch_size, -1),
    )
    real_graphs = torch.concat(flat_parts, dim=-1)

    return real_graphs, a_tensor, x_tensor
|
src/model/__init__.py
ADDED
File without changes
|
src/model/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (156 Bytes). View file
|
|
src/model/__pycache__/layers.cpython-310.pyc
ADDED
Binary file (8.31 kB). View file
|
|
src/model/__pycache__/loss.cpython-310.pyc
ADDED
Binary file (2.04 kB). View file
|
|
src/model/__pycache__/models.cpython-310.pyc
ADDED
Binary file (7.35 kB). View file
|
|
src/model/layers.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
class MLP(nn.Module):
    """
    Two-layer perceptron: Linear -> ReLU -> Linear, with dropout applied to the
    final output.

    Attributes:
        fc1 (nn.Linear): Input-to-hidden projection.
        act (nn.ReLU): Activation between the two projections.
        fc2 (nn.Linear): Hidden-to-output projection.
        droprateout (nn.Dropout): Dropout applied to the output of ``fc2``.
    """

    def __init__(self, in_feat, hid_feat=None, out_feat=None, dropout=0.):
        """
        Build the perceptron.

        Args:
            in_feat (int): Number of input features.
            hid_feat (int, optional): Hidden width; falls back to ``in_feat`` when
                missing (or otherwise falsy).
            out_feat (int, optional): Output width; falls back to ``in_feat`` when
                missing (or otherwise falsy).
            dropout (float, optional): Dropout rate on the output. Defaults to 0.
        """
        super().__init__()

        hidden_width = hid_feat or in_feat
        output_width = out_feat or in_feat

        self.fc1 = nn.Linear(in_feat, hidden_width)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.droprateout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Apply the perceptron to ``x``.

        Args:
            x (torch.Tensor): Input tensor whose last axis has ``in_feat`` entries.

        Returns:
            torch.Tensor: Result of fc1, ReLU, fc2 and dropout, in that order.
        """
        hidden = self.act(self.fc1(x))
        return self.droprateout(self.fc2(hidden))
|
55 |
+
|
56 |
+
class MHA(nn.Module):
    """
    Multi-head attention for the graph transformer, in which projected edge
    features modulate the (element-wise) query/key scores.

    Attributes:
        heads (int): Number of attention heads.
        scale (float): ``1 / sqrt(dim)``; stored but not read by ``forward``.
        q, k, v (nn.Linear): Node projections for queries, keys and values.
        e (nn.Linear): Edge feature projection.
        d_k (int): Feature width of one head.
        out_e (nn.Linear): Output projection for the edge stream.
        out_n (nn.Linear): Output projection for the node stream.
    """

    def __init__(self, dim, heads, attention_dropout=0.):
        """
        Build the attention module.

        Args:
            dim (int): Feature width of nodes and edges; must be divisible by ``heads``.
            heads (int): Number of attention heads.
            attention_dropout (float, optional): Accepted for interface
                compatibility; no dropout layer is created from it.
        """
        super().__init__()

        # The feature width has to split evenly across heads.
        assert dim % heads == 0

        self.heads = heads
        self.scale = 1. / math.sqrt(dim)

        # Node projections.
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        # Edge projection.
        self.e = nn.Linear(dim, dim)
        self.d_k = dim // heads

        # Output projections for the two streams.
        self.out_e = nn.Linear(dim, dim)
        self.out_n = nn.Linear(dim, dim)

    def forward(self, node, edge):
        """
        Run edge-modulated attention.

        Args:
            node (torch.Tensor): Node features, shape (batch, num_nodes, dim).
            edge (torch.Tensor): Edge features, shape (batch, num_nodes, num_nodes, dim).

        Returns:
            tuple: (updated node features, updated edge features)
        """
        batch, num_nodes, channels = node.shape
        per_head = channels // self.heads

        def split_heads(projected):
            # (b, n, c) -> (b, n, heads, c // heads)
            return projected.view(batch, num_nodes, self.heads, per_head)

        queries = split_heads(self.q(node)).unsqueeze(2)  # (b, n, 1, heads, c//heads)
        keys = split_heads(self.k(node)).unsqueeze(1)     # (b, 1, n, heads, c//heads)
        values = split_heads(self.v(node)).unsqueeze(1)   # (b, 1, n, heads, c//heads)
        edge_gate = self.e(edge).view(batch, num_nodes, num_nodes, self.heads, per_head)

        # Pairwise element-wise scores, scaled by sqrt(d_k), then modulated by
        # the edge projection.
        scores = queries * keys
        scores = scores / math.sqrt(self.d_k)
        scores = scores * (edge_gate + 1) * edge_gate

        # The edge stream is taken from the scores before normalisation.
        edge_out = self.out_e(scores.flatten(3))

        # Normalise over axis 2 (the key-node axis) and aggregate the values.
        weights = F.softmax(scores, dim=2)
        aggregated = (weights * values).sum(dim=2).flatten(2)
        node_out = self.out_n(aggregated)

        return node_out, edge_out
|
138 |
+
|
139 |
+
class Encoder_Block(nn.Module):
    """
    One transformer encoder block operating jointly on node and edge features.

    Layout:
      - layer norm on the nodes, then edge-modulated multi-head attention;
      - residual additions and layer norm for both streams;
      - one MLP per stream, each with a residual addition and a final layer norm.

    Attributes:
        ln1, ln3, ln4, ln5, ln6 (nn.LayerNorm): Layer normalization modules.
        attn (MHA): Multi-head attention module.
        mlp, mlp2 (MLP): Feed-forward modules for the node and edge streams.
    """

    def __init__(self, dim, heads, act, mlp_ratio=4, drop_rate=0.):
        """
        Build the block.

        Args:
            dim (int): Feature width of nodes and edges.
            heads (int): Number of attention heads.
            act (callable): Accepted for interface compatibility; not used here.
            mlp_ratio (int, optional): Hidden width of the MLPs as a multiple of
                ``dim``. Defaults to 4.
            drop_rate (float, optional): Dropout rate handed to the attention
                module and both MLPs. Defaults to 0.
        """
        super().__init__()

        hidden_width = dim * mlp_ratio

        self.ln1 = nn.LayerNorm(dim)
        self.attn = MHA(dim, heads, drop_rate)
        self.ln3 = nn.LayerNorm(dim)
        self.ln4 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, hidden_width, dim, dropout=drop_rate)
        self.mlp2 = MLP(dim, hidden_width, dim, dropout=drop_rate)
        self.ln5 = nn.LayerNorm(dim)
        self.ln6 = nn.LayerNorm(dim)

    def forward(self, x, y):
        """
        Process one block.

        Args:
            x (torch.Tensor): Node feature tensor.
            y (torch.Tensor): Edge feature tensor.

        Returns:
            tuple: (updated node features, updated edge features)
        """
        normed_nodes = self.ln1(x)
        attn_nodes, attn_edges = self.attn(normed_nodes, y)

        # The node residual starts from the normalised nodes, the edge residual
        # from the raw edge input.
        nodes = self.ln3(normed_nodes + attn_nodes)
        edges = self.ln4(y + attn_edges)

        # Feed-forward stage, node stream first and edge stream second.
        nodes = self.ln5(nodes + self.mlp(nodes))
        edges = self.ln6(edges + self.mlp2(edges))
        return nodes, edges
|
194 |
+
|
195 |
+
class TransformerEncoder(nn.Module):
    """
    Stack of ``Encoder_Block`` modules applied one after another.

    Attributes:
        Encoder_Blocks (nn.ModuleList): The blocks, in application order.
    """

    def __init__(self, dim, depth, heads, act, mlp_ratio=4, drop_rate=0.1):
        """
        Build the encoder stack.

        Args:
            dim (int): Feature width of nodes and edges.
            depth (int): Number of encoder blocks.
            heads (int): Attention heads per block.
            act (callable): Forwarded to every block.
            mlp_ratio (int, optional): MLP hidden-width multiplier. Defaults to 4.
            drop_rate (float, optional): Dropout rate forwarded to every block.
                Defaults to 0.1.
        """
        super().__init__()

        self.Encoder_Blocks = nn.ModuleList()
        for _ in range(depth):
            self.Encoder_Blocks.append(
                Encoder_Block(dim, heads, act, mlp_ratio, drop_rate)
            )

    def forward(self, x, y):
        """
        Pass node and edge features through every block in order.

        Args:
            x (torch.Tensor): Node feature tensor.
            y (torch.Tensor): Edge feature tensor.

        Returns:
            tuple: (final node features, final edge features); the inputs are
                returned untouched when the stack is empty.
        """
        nodes, edges = x, y
        for layer in self.Encoder_Blocks:
            nodes, edges = layer(nodes, edges)
        return nodes, edges
|
src/model/loss.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def gradient_penalty(discriminator, real_node, real_edge, fake_node, fake_edge, batch_size, device):
    """
    Compute the WGAN-GP gradient penalty on random real/fake interpolations.

    Args:
        discriminator: Callable invoked as ``discriminator(edge, node)``.
        real_node: Real node features.
        real_edge: Real edge features.
        fake_node: Generated node features.
        fake_edge: Generated edge features.
        batch_size: Number of samples; used for the random factors and to
            flatten the gradients per sample.
        device: Device on which the random factors are created.

    Returns:
        Scalar tensor: mean over the batch of ``(||grad||_2 - 1) ** 2``, where
        the gradient is taken w.r.t. both interpolated inputs jointly.
    """
    # One mixing factor per sample; the edge factor is drawn first, then the node one.
    edge_mix = torch.rand(batch_size, 1, 1, 1, device=device)
    node_mix = torch.rand(batch_size, 1, 1, device=device)

    # Points on the straight line between real and generated samples.
    mixed_node = node_mix * real_node + (1 - node_mix) * fake_node
    mixed_node.requires_grad_(True)
    mixed_edge = edge_mix * real_edge + (1 - edge_mix) * fake_edge
    mixed_edge.requires_grad_(True)

    critic_out = discriminator(mixed_edge, mixed_node)

    # A tensor of ones as grad_outputs gives the gradient of the summed outputs.
    grad_seed = torch.ones(critic_out.size(), requires_grad=False).to(device)
    node_grad, edge_grad = torch.autograd.grad(
        outputs=critic_out,
        inputs=[mixed_node, mixed_edge],
        grad_outputs=grad_seed,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )

    # Treat node and edge gradients of a sample as a single flat vector.
    flat_grads = torch.cat(
        [node_grad.view(batch_size, -1), edge_grad.view(batch_size, -1)],
        dim=1,
    )

    return ((flat_grads.norm(2, dim=1) - 1) ** 2).mean()
|
50 |
+
|
51 |
+
|
52 |
+
def discriminator_loss(generator, discriminator, drug_adj, drug_annot, mol_adj, mol_annot, batch_size, device, lambda_gp):
    """
    Compute the WGAN-GP discriminator (critic) loss.

    Args:
        generator: Callable invoked as ``generator(mol_adj, mol_annot)`` and
            returning ``(node, edge, node_sample, edge_sample)``.
        discriminator: Callable invoked as ``discriminator(edge, node)``.
        drug_adj: Real edge input for the discriminator.
        drug_annot: Real node input for the discriminator.
        mol_adj: Edge input for the generator.
        mol_annot: Node input for the generator.
        batch_size: Forwarded to ``gradient_penalty``.
        device: Forwarded to ``gradient_penalty``.
        lambda_gp: Weight applied to the gradient penalty.

    Returns:
        Tuple ``(node, edge, d_loss)`` where ``node`` and ``edge`` are the first
        two generator outputs and ``d_loss`` is the scalar critic loss.
    """
    # Critic score on real samples; negated because the critic maximises it.
    real_scores = discriminator(drug_adj, drug_annot)
    real_term = -torch.mean(real_scores)

    # Critic score on generated samples. The samples are detached so that this
    # loss does not back-propagate into the generator.
    node, edge, node_sample, edge_sample = generator(mol_adj, mol_annot)
    fake_nodes = node_sample.detach()
    fake_edges = edge_sample.detach()
    fake_scores = discriminator(fake_edges, fake_nodes)
    fake_term = torch.mean(fake_scores)

    # Gradient penalty on interpolations between real and generated samples.
    gp = gradient_penalty(discriminator, drug_annot, drug_adj, fake_nodes, fake_edges, batch_size, device)

    wasserstein_term = fake_term + real_term
    d_loss = wasserstein_term + lambda_gp * gp

    return node, edge, d_loss
|
73 |
+
|
74 |
+
|
75 |
+
def generator_loss(generator, discriminator, mol_adj, mol_annot, batch_size):
    """
    Compute the WGAN generator loss.

    Args:
        generator: Callable invoked as ``generator(mol_adj, mol_annot)`` and
            returning ``(node, edge, node_sample, edge_sample)``.
        discriminator: Callable invoked as ``discriminator(edge_sample, node_sample)``.
        mol_adj: Edge input for the generator.
        mol_annot: Node input for the generator.
        batch_size: Not used by this function; kept in the signature for callers.

    Returns:
        Tuple ``(g_loss, node, edge, node_sample, edge_sample)`` where ``g_loss``
        is the negated mean discriminator output on the generated samples.
    """
    generated = generator(mol_adj, mol_annot)
    node, edge, node_sample, edge_sample = generated

    # The generator wants the critic to score its samples highly, so the loss
    # is the negative mean score (no detach: gradients must reach the generator).
    critic_scores = discriminator(edge_sample, node_sample)
    g_loss = torch.neg(torch.mean(critic_scores))

    return g_loss, node, edge, node_sample, edge_sample
|
src/model/models.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from src.model.layers import TransformerEncoder
|
4 |
+
|
5 |
+
class Generator(nn.Module):
    """
    Generator network that uses a Transformer Encoder to process node and edge features.

    Node and edge inputs are first embedded by separate two-layer MLPs, the edge
    embedding is symmetrised, both are passed through a Transformer Encoder, and
    linear readout layers map the results back to node / edge feature sizes.
    """

    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        """
        Initializes the Generator.

        Args:
            act (str): Activation function to use ("relu", "leaky", "sigmoid", or "tanh").
            vertexes (int): Number of vertexes in the graph.
            edges (int): Number of edge features.
            nodes (int): Number of node features.
            dropout (float): Dropout rate.
            dim (int): Dimensionality used for intermediate features.
            depth (int): Number of Transformer encoder blocks.
            heads (int): Number of attention heads in the Transformer.
            mlp_ratio (int): Ratio for determining hidden layer size in MLP modules.

        Raises:
            ValueError: If ``act`` is not one of the supported activation names.
        """
        super().__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # Resolve the activation by name and fail early on unknown names,
        # instead of letting a raw string reach nn.Sequential.
        activations = {
            "relu": nn.ReLU,
            "leaky": nn.LeakyReLU,
            "sigmoid": nn.Sigmoid,
            "tanh": nn.Tanh,
        }
        if act not in activations:
            raise ValueError("Unsupported activation function: {}".format(act))
        # A single activation instance is shared by every sub-module, as before.
        act = activations[act]()

        # Flattened input size and flattened transformer size (informational).
        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim

        self.node_layers = nn.Sequential(
            nn.Linear(nodes, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        self.edge_layers = nn.Sequential(
            nn.Linear(edges, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        self.TransformerEncoder = TransformerEncoder(
            dim=self.dim, depth=self.depth, heads=self.heads, act=act,
            mlp_ratio=self.mlp_ratio, drop_rate=self.dropout
        )

        self.readout_e = nn.Linear(self.dim, edges)
        self.readout_n = nn.Linear(self.dim, nodes)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, z_e, z_n):
        """
        Forward pass of the Generator.

        Args:
            z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges).
            z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes).

        Returns:
            tuple: A tuple containing:
                - node: Node features after the transformer.
                - edge: Edge features after the transformer.
                - node_sample: Linear readout of the node features.
                - edge_sample: Linear readout of the edge features.
        """
        # Embed node and edge features with their respective MLPs.
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        # Symmetrize the edge features by averaging with the transpose over the
        # two vertex dimensions.
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2

        node, edge = self.TransformerEncoder(node, edge)

        # Readout layers produce the final (un-normalised) outputs.
        node_sample = self.readout_n(node)
        edge_sample = self.readout_e(edge)

        return node, edge, node_sample, edge_sample
|
104 |
+
|
105 |
+
|
106 |
+
class Discriminator(nn.Module):
    """
    Discriminator network that evaluates node and edge features.

    It embeds features with linear layers, applies a Transformer Encoder to
    capture dependencies, and predicts a scalar value with an MLP on the
    flattened node features.

    This class is used in DrugGEN model.
    """

    def __init__(self, act, vertexes, edges, nodes, dropout, dim, depth, heads, mlp_ratio):
        """
        Initializes the Discriminator.

        Args:
            act (str): Activation function type ("relu", "leaky", "sigmoid", or "tanh").
            vertexes (int): Number of vertexes.
            edges (int): Number of edge features.
            nodes (int): Number of node features.
            dropout (float): Dropout rate.
            dim (int): Dimensionality for intermediate representations.
            depth (int): Number of Transformer encoder blocks.
            heads (int): Number of attention heads.
            mlp_ratio (int): MLP ratio for hidden layer dimensions.

        Raises:
            ValueError: If ``act`` is not one of the supported activation names.
        """
        super().__init__()
        self.vertexes = vertexes
        self.edges = edges
        self.nodes = nodes
        self.depth = depth
        self.dim = dim
        self.heads = heads
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # Resolve the activation by name and fail early on unknown names,
        # instead of letting a raw string reach nn.Sequential.
        activations = {
            "relu": nn.ReLU,
            "leaky": nn.LeakyReLU,
            "sigmoid": nn.Sigmoid,
            "tanh": nn.Tanh,
        }
        if act not in activations:
            raise ValueError("Unsupported activation function: {}".format(act))
        # A single activation instance is shared by every sub-module, as before.
        act = activations[act]()

        self.features = vertexes * vertexes * edges + vertexes * nodes
        self.transformer_dim = vertexes * vertexes * dim + vertexes * dim

        # Embedding layers for node and edge features.
        self.node_layers = nn.Sequential(
            nn.Linear(nodes, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        self.edge_layers = nn.Sequential(
            nn.Linear(edges, 64), act,
            nn.Linear(64, dim), act,
            nn.Dropout(self.dropout)
        )
        # Transformer Encoder for modeling node and edge interactions.
        self.TransformerEncoder = TransformerEncoder(
            dim=self.dim, depth=self.depth, heads=self.heads, act=act,
            mlp_ratio=self.mlp_ratio, drop_rate=self.dropout
        )
        # Sizes of the flattened node / edge representations.
        self.node_features = vertexes * dim
        self.edge_features = vertexes * vertexes * dim
        # MLP that maps the flattened node features to a single score.
        self.node_mlp = nn.Sequential(
            nn.Linear(self.node_features, 64), act,
            nn.Linear(64, 32), act,
            nn.Linear(32, 16), act,
            nn.Linear(16, 1)
        )

    def forward(self, z_e, z_n):
        """
        Forward pass of the Discriminator.

        Args:
            z_e (torch.Tensor): Edge features tensor of shape (batch, vertexes, vertexes, edges).
            z_n (torch.Tensor): Node features tensor of shape (batch, vertexes, nodes).

        Returns:
            torch.Tensor: Prediction scores of shape (batch, 1).
        """
        batch = z_n.shape[0]

        # Embed node and edge features separately.
        node = self.node_layers(z_n)
        edge = self.edge_layers(z_e)
        # Symmetrize edge features by averaging with the transpose over the
        # two vertex dimensions.
        edge = (edge + edge.permute(0, 2, 1, 3)) / 2

        node, edge = self.TransformerEncoder(node, edge)

        # Flatten per-vertex node features for the MLP. reshape (not view)
        # so a non-contiguous encoder output does not raise.
        node = node.reshape(batch, -1)
        prediction = self.node_mlp(node)

        return prediction
|
210 |
+
|
211 |
+
|
212 |
+
class simple_disc(nn.Module):
    """
    A simplified discriminator that processes flattened features through an MLP
    to predict a scalar score.

    This class is used in NoTarget model.
    """

    def __init__(self, act, m_dim, vertexes, b_dim):
        """
        Initializes the simple discriminator.

        Args:
            act (str): Activation function type ("relu", "leaky", "sigmoid", or "tanh").
            m_dim (int): Dimensionality for atom type features.
            vertexes (int): Number of vertexes.
            b_dim (int): Dimensionality for bond type features.

        Raises:
            ValueError: If ``act`` is not one of the supported activation names.
        """
        super().__init__()

        activations = {
            "relu": nn.ReLU,
            "leaky": nn.LeakyReLU,
            "sigmoid": nn.Sigmoid,
            "tanh": nn.Tanh,
        }
        if act not in activations:
            raise ValueError("Unsupported activation function: {}".format(act))
        # A single activation instance is shared by every layer, as before.
        act = activations[act]()

        # Flattened input size: per-vertex atom features plus per-pair bond features.
        features = vertexes * m_dim + vertexes * vertexes * b_dim
        self.predictor = nn.Sequential(
            nn.Linear(features, 256), act,
            nn.Linear(256, 128), act,
            nn.Linear(128, 64), act,
            nn.Linear(64, 32), act,
            nn.Linear(32, 16), act,
            nn.Linear(16, 1)
        )

    def forward(self, x):
        """
        Forward pass of the simple discriminator.

        Args:
            x (torch.Tensor): Input features tensor whose last dimension equals
                ``vertexes * m_dim + vertexes * vertexes * b_dim``.

        Returns:
            torch.Tensor: Prediction scores with a last dimension of size 1.
        """
        return self.predictor(x)
|
src/util/__init__.py
ADDED
File without changes
|
src/util/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (155 Bytes). View file
|
|
src/util/__pycache__/smiles_cor.cpython-310.pyc
ADDED
Binary file (30.2 kB). View file
|
|
src/util/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (30 kB). View file
|
|
src/util/smiles_cor.py
ADDED
@@ -0,0 +1,1284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import random
|
4 |
+
import re
|
5 |
+
import itertools
|
6 |
+
import statistics
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.optim as optim
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
from torchtext.data import TabularDataset, Field, BucketIterator, Iterator
|
15 |
+
|
16 |
+
from rdkit import Chem, rdBase, RDLogger
|
17 |
+
from rdkit.Chem import (
|
18 |
+
MolStandardize,
|
19 |
+
GraphDescriptors,
|
20 |
+
Lipinski,
|
21 |
+
AllChem,
|
22 |
+
)
|
23 |
+
from rdkit.Chem.rdSLNParse import MolFromSLN
|
24 |
+
from rdkit.Chem.rdmolfiles import MolFromSmiles
|
25 |
+
from chembl_structure_pipeline import standardizer
|
26 |
+
|
27 |
+
RDLogger.DisableLog('rdApp.*')
|
28 |
+
|
29 |
+
SEED = 42
|
30 |
+
random.seed(SEED)
|
31 |
+
torch.manual_seed(SEED)
|
32 |
+
torch.backends.cudnn.deterministic = True
|
33 |
+
|
34 |
+
##################################################################################################
|
35 |
+
##################################################################################################
|
36 |
+
# #
|
37 |
+
# THIS SCRIPT IS DIRECTLY ADAPTED FROM https://github.com/LindeSchoenmaker/SMILES-corrector #
|
38 |
+
# #
|
39 |
+
##################################################################################################
|
40 |
+
##################################################################################################
|
41 |
+
def is_smiles(array,
|
42 |
+
TRG,
|
43 |
+
reverse: bool,
|
44 |
+
return_output=False,
|
45 |
+
src=None,
|
46 |
+
src_field=None):
|
47 |
+
"""Turns predicted tokens within batch into smiles and evaluates their validity
|
48 |
+
Arguments:
|
49 |
+
array: Tensor with most probable token for each location for each sequence in batch
|
50 |
+
[trg len, batch size]
|
51 |
+
TRG: target field for getting tokens from vocab
|
52 |
+
reverse (bool): True if the target sequence is reversed
|
53 |
+
return_output (bool): True if output sequences and their validity should be saved
|
54 |
+
Returns:
|
55 |
+
df: dataframe with correct and incorrect sequences
|
56 |
+
valids: list with booleans that show if prediction was a valid SMILES (True) or invalid one (False)
|
57 |
+
smiless: list of the predicted smiles
|
58 |
+
"""
|
59 |
+
trg_field = TRG
|
60 |
+
valids = []
|
61 |
+
smiless = []
|
62 |
+
if return_output:
|
63 |
+
df = pd.DataFrame()
|
64 |
+
else:
|
65 |
+
df = None
|
66 |
+
batch_size = array.size(1)
|
67 |
+
# check if the first token should be removed, first token is zero because
|
68 |
+
# outputs initaliazed to all be zeros
|
69 |
+
if int((array[0, 0]).tolist()) == 0:
|
70 |
+
start = 1
|
71 |
+
else:
|
72 |
+
start = 0
|
73 |
+
# for each sequence in the batch
|
74 |
+
for i in range(0, batch_size):
|
75 |
+
# turns sequence from tensor to list skipps first row as this is not
|
76 |
+
# filled in in forward
|
77 |
+
sequence = (array[start:, i]).tolist()
|
78 |
+
# goes from embedded to tokens
|
79 |
+
trg_tokens = [trg_field.vocab.itos[int(t)] for t in sequence]
|
80 |
+
# print(trg_tokens)
|
81 |
+
# takes all tokens untill eos token, model would be faster if did this
|
82 |
+
# one step earlier, but then changes in vocab order would disrupt.
|
83 |
+
rev_tokens = list(
|
84 |
+
itertools.takewhile(lambda x: x != "<eos>", trg_tokens))
|
85 |
+
if reverse:
|
86 |
+
rev_tokens = rev_tokens[::-1]
|
87 |
+
smiles = "".join(rev_tokens)
|
88 |
+
# determine how many valid smiles are made
|
89 |
+
valid = True if MolFromSmiles(smiles) else False
|
90 |
+
valids.append(valid)
|
91 |
+
smiless.append(smiles)
|
92 |
+
if return_output:
|
93 |
+
if valid:
|
94 |
+
df.loc[i, "CORRECT"] = smiles
|
95 |
+
else:
|
96 |
+
df.loc[i, "INCORRECT"] = smiles
|
97 |
+
|
98 |
+
# add the original drugex outputs to the _de dataframe
|
99 |
+
if return_output and src is not None:
|
100 |
+
for i in range(0, batch_size):
|
101 |
+
# turns sequence from tensor to list skipps first row as this is
|
102 |
+
# <sos> for src
|
103 |
+
sequence = (src[1:, i]).tolist()
|
104 |
+
# goes from embedded to tokens
|
105 |
+
src_tokens = [src_field.vocab.itos[int(t)] for t in sequence]
|
106 |
+
# takes all tokens untill eos token, model would be faster if did
|
107 |
+
# this one step earlier, but then changes in vocab order would
|
108 |
+
# disrupt.
|
109 |
+
rev_tokens = list(
|
110 |
+
itertools.takewhile(lambda x: x != "<eos>", src_tokens))
|
111 |
+
smiles = "".join(rev_tokens)
|
112 |
+
df.loc[i, "ORIGINAL"] = smiles
|
113 |
+
|
114 |
+
return df, valids, smiless
|
115 |
+
|
116 |
+
|
117 |
+
def is_unchanged(array,
                 TRG,
                 reverse: bool,
                 return_output=False,
                 src=None,
                 src_field=None):
    """Count predictions that are invalid SMILES and identical to their input.

    Arguments:
        array: Tensor with most probable token for each location for each sequence in batch
            [trg len, batch size]
        TRG: target field for getting tokens from vocab (``TRG.vocab.itos``)
        reverse (bool): True if the target sequence is reversed
        return_output (bool): not used by this function; kept for signature
            compatibility with ``is_smiles``
        src: source tensor [src len, batch size]; required
        src_field: source field for getting tokens from vocab; required
    Returns:
        unchanged (int): number of sequences in the batch whose decoded prediction
            equals the decoded source and is not a valid SMILES
    Raises:
        ValueError: if ``src`` or ``src_field`` is not provided
    """
    if src is None or src_field is None:
        raise ValueError("is_unchanged requires both src and src_field")

    batch_size = array.size(1)
    # Nothing to compare; also avoids indexing array[0, 0] on an empty batch.
    if batch_size == 0:
        return 0

    # The first row is skipped when it still holds the zero the output tensor
    # was initialised with (checked on the first sequence only).
    start = 1 if int(array[0, 0]) == 0 else 0

    def _decode(column, itos):
        # indices -> tokens, cut at the first <eos>
        tokens = [itos[int(t)] for t in column.tolist()]
        return list(itertools.takewhile(lambda tok: tok != "<eos>", tokens))

    src_itos = src_field.vocab.itos
    trg_itos = TRG.vocab.itos

    unchanged = 0
    for i in range(batch_size):
        # first row of src is <sos>
        source = "".join(_decode(src[1:, i], src_itos))

        tokens = _decode(array[start:, i], trg_itos)
        if reverse:
            tokens = tokens[::-1]
        smiles = "".join(tokens)

        # Only parse when the strings match; RDKit returns None for SMILES
        # it cannot parse.
        if smiles == source and not MolFromSmiles(smiles):
            unchanged += 1

    return unchanged
|
182 |
+
|
183 |
+
|
184 |
+
def molecule_reconstruction(array, TRG, reverse: bool, outputs):
    """Turn target tokens within a batch into SMILES and compare them to predicted SMILES.

    Arguments:
        array: Tensor with target's token for each location for each sequence in batch
            [trg len, batch size]
        TRG: target field for getting tokens from vocab (``TRG.vocab.itos``)
        reverse (bool): True if the target sequence is reversed
        outputs: list of predicted SMILES sequences, one per sequence in the batch
    Returns:
        matches (int): number of predictions that describe the same molecule as
            their target (each is a substructure of the other)
    """
    trg_field = TRG
    batch_size = array.size(1)

    matches = 0
    for i in range(batch_size):
        # first row is skipped, as in the original decoding
        tokens = [trg_field.vocab.itos[int(t)] for t in array[1:, i].tolist()]
        # keep everything up to the first <eos>
        tokens = list(itertools.takewhile(lambda tok: tok != "<eos>", tokens))
        if reverse:
            tokens = tokens[::-1]
        target = "".join(tokens)

        # RDKit returns None for SMILES it cannot parse; an unparsable target
        # or prediction simply cannot be a match.
        m = MolFromSmiles(target)
        p = MolFromSmiles(outputs[i])
        if m is None or p is None:
            continue
        if m.HasSubstructMatch(p) and p.HasSubstructMatch(m):
            matches += 1
    return matches
|
221 |
+
|
222 |
+
|
223 |
+
def complexity_whitlock(mol: Chem.Mol, includeAllDescs=False):
    """
    Whitlock complexity, as defined in DOI:10.1021/jo9814546.

    S: complexity = 4*#rings + 2*#unsat + #hetatm + 2*#chiral
       (here #rings is the ring count minus the aromatic ring count)
    Other descriptors, returned only when ``includeAllDescs`` is true:
        H: size = sum of bond orders (Hydrogen atoms included)
        G: S + H
        Ratio: S / H

    Returns S alone, or a dict with keys WhitlockS, WhitlockH, WhitlockG and
    WhitlockRatio. The input molecule is copied and not modified.
    """
    work = Chem.Mol(mol)

    # Ring term is taken before aromaticity is re-perceived below.
    ring_count = Lipinski.RingCount(work) - Lipinski.NumAromaticRings(work)

    Chem.rdmolops.SetAromaticity(work)
    double_bonds = [b for b in work.GetBonds() if b.GetBondTypeAsDouble() == 2]
    unsat_count = len(double_bonds)

    non_carbon = Chem.MolFromSmarts("[!#6]")
    hetero_count = len(work.GetSubstructMatches(non_carbon))

    # Chiral centres are assigned from an embedded 3D structure.
    AllChem.EmbedMolecule(work)
    Chem.AssignAtomChiralTagsFromStructure(work)
    chiral_count = len(Chem.FindMolChiralCenters(work))

    S = 4 * ring_count + 2 * unsat_count + hetero_count + 2 * chiral_count
    if includeAllDescs:
        Chem.rdmolops.Kekulize(work)
        work = Chem.AddHs(work)
        H = sum(b.GetBondTypeAsDouble() for b in work.GetBonds())
        return {
            "WhitlockS": S,
            "WhitlockH": H,
            "WhitlockG": S + H,
            "WhitlockRatio": S / H,
        }
    return S
|
250 |
+
|
251 |
+
|
252 |
+
def complexity_baronechanon(mol: Chem.Mol):
    """
    Barone-Chanon complexity, as defined in DOI:10.1021/ci000145p.

    Sums, over a kekulized, stereo-free, hydrogen-stripped copy of ``mol``:
    a per-atom degree term ``3 * 2**(explicit valence - explicit Hs - 1)``,
    a per-atom element term (3 for carbon, 6 otherwise) and 6 per ring atom
    for every ring. The input molecule is not modified.
    """
    work = Chem.Mol(mol)
    Chem.Kekulize(work)
    Chem.RemoveStereochemistry(work)
    work = Chem.RemoveHs(work, updateExplicitCount=True)

    degree_term = 0
    element_term = 0
    for atom in work.GetAtoms():
        exponent = atom.GetExplicitValence() - atom.GetNumExplicitHs() - 1
        degree_term += 3 * 2**exponent
        if atom.GetSymbol() == "C":
            element_term += 3
        else:
            element_term += 6

    ring_term = sum(6 * len(ring) for ring in work.GetRingInfo().AtomRings())
    return degree_term + element_term + ring_term
|
267 |
+
|
268 |
+
|
269 |
+
def calc_complexity(array,
                    TRG,
                    reverse,
                    valids,
                    complexity_function=GraphDescriptors.BertzCT):
    """Calculate the mean complexity of the sequences whose prediction was invalid.

    Arguments:
        array: Tensor with target's token for each location for each sequence in batch
            [trg len, batch size]
        TRG: target field for getting tokens from vocab (``TRG.vocab.itos``)
        reverse (bool): True if the target sequence is reversed
        valids: list with booleans that show if prediction was a valid SMILES (True) or invalid one (False)
        complexity_function: the type of complexity measure that will be used
            GraphDescriptors.BertzCT
            complexity_whitlock
            complexity_baronechanon
    Returns:
        mean: mean of the complexity values, or 0 when no molecule could be scored
    """
    trg_field = TRG

    # only keep the columns (batch dimension) where valid is False
    invalid = ~torch.tensor(valids, dtype=torch.bool)
    array = array[:, invalid]

    complexities = []
    for i in range(array.size(1)):
        # first row is skipped, as in the original decoding
        tokens = [trg_field.vocab.itos[int(t)] for t in array[1:, i].tolist()]
        # keep everything up to the first <eos>
        tokens = list(itertools.takewhile(lambda tok: tok != "<eos>", tokens))
        if reverse:
            tokens = tokens[::-1]
        source = "".join(tokens)

        # Exception (not BaseException) so KeyboardInterrupt / SystemExit
        # still propagate; the SLN parser remains the fallback.
        try:
            m = MolFromSmiles(source)
        except Exception:
            m = MolFromSLN(source)
        # MolFromSmiles signals a parse failure by returning None; such
        # sequences cannot be scored and are left out of the mean.
        if m is None:
            continue
        complexities.append(complexity_function(m))

    if complexities:
        return statistics.mean(complexities)
    return 0
|
322 |
+
|
323 |
+
|
324 |
+
def epoch_time(start_time, end_time):
    """Split the interval between two timestamps into minutes and seconds.

    Args:
        start_time: start of the interval, in seconds.
        end_time: end of the interval, in seconds.

    Returns:
        Tuple ``(minutes, seconds)``; both values are truncated to ``int``.
    """
    duration = end_time - start_time
    minutes = int(duration / 60)
    # whatever is left once the whole minutes are taken out
    seconds = int(duration - minutes * 60)
    return minutes, seconds
|
329 |
+
|
330 |
+
|
331 |
+
class Convo:
    """Training, evaluation and inference routines for the seq2seq models.

    This is a mixin: it defines no ``__init__`` and relies on the class it is
    combined with (``Seq2Seq`` in this file) to be an ``nn.Module`` and to
    provide ``forward``, ``translate_sentence`` and the attributes
    ``loader_train``, ``loader_valid``, ``loader_drugex``, ``TRG``, ``SRC``,
    ``device``, ``out``, ``epochs``, ``lr`` and ``clip``.

    Methods
    -------
    train_model()
        train model for initialized number of epochs
    evaluate(return_output)
        use model with validation loader (& optionally drugex loader) to get
        test loss & other metrics
    translate(loader)
        translate inputs from loader (different from evaluate in that no
        target sequence is used)
    """

    def train_model(self):
        """Train for up to ``self.epochs`` epochs and keep the best weights.

        Every epoch runs teacher-forced training over ``self.loader_train``.
        When a validation loader is set, the model is scored with
        ``evaluate``; the weights are written to ``{out}.pkg`` whenever the
        molecule reconstruction error improves, and training stops early once
        10 epochs pass without an improvement (unless the best error is
        still 1). Without a validation loader the mean training loss is used
        to select the best weights instead.

        A one-line summary per epoch is appended to ``{out}.log``. After the
        loop the final weights are stored in ``{out}_last.pkg``, the best
        weights are loaded back into the model and, if validation ran, the
        outputs of the last evaluation are written to ``{out}.csv`` and
        ``{out}_de.csv``.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        # The criterion is stateless, so one instance serves every batch.
        criterion = nn.CrossEntropyLoss(
            ignore_index=self.TRG.vocab.stoi[self.TRG.pad_token])
        best_error = np.inf
        last_save = 0
        # Stay None when no validation loader is configured; in that case
        # there is nothing to export once training has finished.
        df_output = None
        df_output_de = None
        with open(f"{self.out}.log", "a") as log:
            for epoch in range(self.epochs):
                self.train()
                start_time = time.time()
                loss_train = 0
                for i, batch in enumerate(self.loader_train):
                    optimizer.zero_grad()
                    trg = batch.trg
                    src = batch.src
                    # Teacher forcing: the decoder receives the target
                    # without its last token and is scored against the
                    # target shifted by one position.
                    output, _ = self(src, trg[:, :-1])
                    output = output.contiguous().view(-1, output.shape[-1])
                    trg = trg[:, 1:].contiguous().view(-1)
                    loss = criterion(output, trg)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.parameters(),
                                                   self.clip)
                    optimizer.step()
                    loss_train += loss.item()
                # len(loader) is the number of batches, so this is the mean
                # loss per batch
                loss_train /= len(self.loader_train)
                info = f"Epoch: {epoch+1:02} step: {i} loss_train: {loss_train:.4g}"
                stop_early = False
                if self.loader_valid is not None:
                    # per-molecule output is only requested on the last epoch
                    return_output = epoch + 1 == self.epochs
                    (
                        valids,
                        loss_valid,
                        valids_de,
                        df_output,
                        df_output_de,
                        right_molecules,
                        complexity,
                        unchanged,
                        unchanged_de,
                    ) = self.evaluate(return_output)
                    n_valid_set = len(self.loader_valid.dataset)
                    reconstruction_error = 1 - right_molecules / n_valid_set
                    error = 1 - valids / n_valid_set
                    complexity = complexity / len(self.loader_valid)
                    # guard: every prediction may be valid, leaving nothing
                    # to divide by
                    n_invalid = n_valid_set - valids
                    unchan = unchanged / n_invalid if n_invalid else 0.0
                    info += f" loss_valid: {loss_valid:.4g} error_rate: {error:.4g} molecule_reconstruction_error_rate: {reconstruction_error:.4g} unchanged: {unchan:.4g} invalid_target_complexity: {complexity:.4g}"
                    if self.loader_drugex is not None:
                        n_drugex = len(self.loader_drugex.dataset)
                        error_de = 1 - valids_de / n_drugex
                        n_invalid_de = n_drugex - valids_de
                        unchan_de = (unchanged_de / n_invalid_de
                                     if n_invalid_de else 0.0)
                        info += f" error_rate_drugex: {error_de:.4g} unchanged_drugex: {unchan_de:.4g}"

                    if reconstruction_error < best_error:
                        torch.save(self.state_dict(), f"{self.out}.pkg")
                        best_error = reconstruction_error
                        last_save = epoch
                    elif epoch - last_save >= 10 and best_error != 1:
                        # No improvement for 10 epochs: store the current
                        # weights and collect the full output before stopping.
                        torch.save(self.state_dict(),
                                   f"{self.out}_last.pkg")
                        final_metrics = self.evaluate(True)
                        df_output = final_metrics[3]
                        df_output_de = final_metrics[4]
                        stop_early = True
                elif loss_train < best_error:
                    # No validation data: fall back to the training loss as
                    # the model-selection criterion.
                    torch.save(self.state_dict(), f"{self.out}.pkg")
                    best_error = loss_train
                end_time = time.time()
                epoch_mins, epoch_secs = epoch_time(start_time, end_time)
                info += f" Time: {epoch_mins}m {epoch_secs}s"
                print(info, file=log, flush=True)
                if stop_early:
                    break

        torch.save(self.state_dict(), f"{self.out}_last.pkg")
        self.load_state_dict(torch.load(f"{self.out}.pkg"))
        if df_output is not None:
            df_output.to_csv(f"{self.out}.csv", index=False)
        if df_output_de is not None:
            df_output_de.to_csv(f"{self.out}_de.csv", index=False)

    def evaluate(self, return_output):
        """Score the model on the validation loader (and the drugex loader).

        Validation batches are decoded with teacher forcing through
        ``forward``; drugex batches are decoded without a target through
        ``translate_sentence``.

        Args:
            return_output: forwarded to ``is_smiles`` / ``is_unchanged``.
                NOTE(review): presumably controls whether ``is_smiles``
                returns a per-molecule dataframe - confirm in that helper.

        Returns:
            Tuple ``(valids, loss_valid, valids_de, df_output, df_output_de,
            right_molecules, complexity, unchanged, unchanged_de)`` where
            ``valids`` / ``valids_de`` are the summed validity flags from
            ``is_smiles``, ``loss_valid`` is the mean cross-entropy per
            validation batch, the two dataframes hold the concatenated batch
            output of ``is_smiles``, ``right_molecules`` sums the results of
            ``molecule_reconstruction``, ``complexity`` sums the per-batch
            results of ``calc_complexity`` and ``unchanged`` /
            ``unchanged_de`` sum the results of ``is_unchanged``.
        """
        self.eval()
        criterion = nn.CrossEntropyLoss(
            ignore_index=self.TRG.vocab.stoi[self.TRG.pad_token])
        test_loss = 0
        df_output = pd.DataFrame()
        df_output_de = pd.DataFrame()
        valids = 0
        valids_de = 0
        unchanged = 0
        unchanged_de = 0
        right_molecules = 0
        complexity = 0
        with torch.no_grad():
            for batch in self.loader_valid:
                trg = batch.trg
                src = batch.src
                output, _ = self.forward(src, trg[:, :-1])
                pred_token = output.argmax(2)
                # the helper functions index sequences along dim 1, so the
                # batch-first tensors are transposed before being passed on
                array = torch.transpose(pred_token, 0, 1)
                trg_trans = torch.transpose(trg, 0, 1)
                src_trans = torch.transpose(src, 0, 1)
                output = output.contiguous().view(-1, output.shape[-1])
                trg = trg[:, 1:].contiguous().view(-1)
                df_batch, valid, smiless = is_smiles(
                    array, self.TRG, reverse=True, return_output=return_output)
                unchanged += is_unchanged(
                    array,
                    self.TRG,
                    reverse=True,
                    return_output=return_output,
                    src=src_trans,
                    src_field=self.SRC,
                )
                matches = molecule_reconstruction(trg_trans,
                                                  self.TRG,
                                                  reverse=True,
                                                  outputs=smiless)
                complexity += calc_complexity(trg_trans,
                                              self.TRG,
                                              reverse=True,
                                              valids=valid)
                if df_batch is not None:
                    df_output = pd.concat([df_output, df_batch],
                                          ignore_index=True)
                right_molecules += matches
                valids += sum(valid)
                test_loss += criterion(output, trg).item()
            if self.loader_drugex is not None:
                for batch in self.loader_drugex:
                    src = batch.src
                    output = self.translate_sentence(src, self.TRG,
                                                     self.device)
                    # checks the number of valid smiles
                    pred_token = output.argmax(2)
                    array = torch.transpose(pred_token, 0, 1)
                    src_trans = torch.transpose(src, 0, 1)
                    df_batch, valid, smiless = is_smiles(
                        array,
                        self.TRG,
                        reverse=True,
                        return_output=return_output,
                        src=src_trans,
                        src_field=self.SRC,
                    )
                    unchanged_de += is_unchanged(
                        array,
                        self.TRG,
                        reverse=True,
                        return_output=return_output,
                        src=src_trans,
                        src_field=self.SRC,
                    )
                    if df_batch is not None:
                        df_output_de = pd.concat([df_output_de, df_batch],
                                                 ignore_index=True)
                    valids_de += sum(valid)
        return (
            valids,
            test_loss / len(self.loader_valid),
            valids_de,
            df_output,
            df_output_de,
            right_molecules,
            complexity,
            unchanged,
            unchanged_de,
        )

    def translate(self, loader):
        """Decode every batch of ``loader`` without using a target sequence.

        Args:
            loader: iterable of batches exposing a ``src`` attribute.

        Returns:
            Tuple ``(valids_de, df_output_de)``: the summed validity flags
            returned by ``is_smiles`` and the concatenation of the per-batch
            dataframes it returns.
        """
        self.eval()
        df_output_de = pd.DataFrame()
        valids_de = 0
        with torch.no_grad():
            for batch in loader:
                src = batch.src
                output = self.translate_sentence(src, self.TRG, self.device)
                # checks the number of valid smiles
                pred_token = output.argmax(2)
                array = torch.transpose(pred_token, 0, 1)
                src_trans = torch.transpose(src, 0, 1)
                df_batch, valid, smiless = is_smiles(
                    array,
                    self.TRG,
                    reverse=True,
                    return_output=True,
                    src=src_trans,
                    src_field=self.SRC,
                )
                if df_batch is not None:
                    df_output_de = pd.concat([df_output_de, df_batch],
                                             ignore_index=True)
                valids_de += sum(valid)
        return valids_de, df_output_de
|
573 |
+
|
574 |
+
|
575 |
+
class Encoder(nn.Module):
    """Transformer encoder.

    Token embeddings (scaled by ``sqrt(hid_dim)``) are summed with learned
    positional embeddings, passed through dropout and then through a stack of
    ``n_layers`` ``EncoderLayer`` blocks.
    """

    def __init__(self, input_dim, hid_dim, n_layers, n_heads, pf_dim, dropout,
                 max_length, device):
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        # learned positions, so inputs longer than max_length cannot be encoded
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        blocks = []
        for _ in range(n_layers):
            blocks.append(
                EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device))
        self.layers = nn.ModuleList(blocks)

        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, src, src_mask):
        """Encode ``src`` ([batch size, src len]) to [batch size, src len, hid dim]."""
        n_seqs, seq_len = src.shape[0], src.shape[1]
        # positions = [batch size, src len], each row counting 0..src len - 1
        positions = torch.arange(0, seq_len).unsqueeze(0).repeat(n_seqs, 1)
        positions = positions.to(self.device)
        scaled_tokens = self.tok_embedding(src) * self.scale
        hidden = self.dropout(scaled_tokens + self.pos_embedding(positions))
        # hidden = [batch size, src len, hid dim]
        for block in self.layers:
            hidden = block(hidden, src_mask)
        return hidden
|
606 |
+
|
607 |
+
|
608 |
+
class EncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a feed-forward network.

    Each sub-layer output goes through dropout, is added to its input
    (residual connection) and is then layer-normalised.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()

        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads,
                                                      dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(
            hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, update):
        """Apply dropout to ``update``, add ``residual`` and normalise."""
        return norm(residual + self.dropout(update))

    def forward(self, src, src_mask):
        """Transform ``src`` ([batch size, src len, hid dim]); shape is kept."""
        # every position attends over the whole (masked) source sequence
        attended, _ = self.self_attention(src, src, src, src_mask)
        src = self._add_and_norm(self.self_attn_layer_norm, src, attended)
        # position-wise feed-forward sub-layer
        transformed = self.positionwise_feedforward(src)
        return self._add_and_norm(self.ff_layer_norm, src, transformed)
|
636 |
+
|
637 |
+
|
638 |
+
class MultiHeadAttentionLayer(nn.Module):
    """Scaled dot-product attention computed over ``n_heads`` parallel heads.

    ``hid_dim`` must be divisible by ``n_heads``; every head works on
    ``hid_dim // n_heads`` features.
    """

    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()
        assert hid_dim % n_heads == 0
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        # linear projections for queries, keys, values and the merged output
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def _split_heads(self, projected, batch_size):
        """[batch size, len, hid dim] -> [batch size, n heads, len, head dim]."""
        per_head = projected.view(batch_size, -1, self.n_heads, self.head_dim)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key`` / ``value``.

        Args:
            query: [batch size, query len, hid dim]
            key: [batch size, key len, hid dim]
            value: [batch size, value len, hid dim]
            mask: optional tensor broadcastable to the attention scores;
                positions where it equals 0 are excluded from attention.

        Returns:
            Tuple ``(x, attention)`` with ``x`` of shape
            [batch size, query len, hid dim] and ``attention`` of shape
            [batch size, n heads, query len, key len].
        """
        batch_size = query.shape[0]
        heads_q = self._split_heads(self.fc_q(query), batch_size)
        heads_k = self._split_heads(self.fc_k(key), batch_size)
        heads_v = self._split_heads(self.fc_v(value), batch_size)
        # scores = [batch size, n heads, query len, key len]
        scores = torch.matmul(heads_q, heads_k.permute(0, 1, 3, 2))
        scores = scores / self.scale
        if mask is not None:
            # a large negative score becomes ~0 weight after the softmax
            scores = scores.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(scores, dim=-1)
        context = torch.matmul(self.dropout(attention), heads_v)
        # merge the heads back: [batch size, query len, hid dim]
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.view(batch_size, -1, self.hid_dim)
        return self.fc_o(context), attention
|
688 |
+
|
689 |
+
|
690 |
+
class PositionwiseFeedforwardLayer(nn.Module):
    """Two-layer feed-forward network applied to every position.

    Expands ``hid_dim`` to ``pf_dim`` with a ReLU and dropout in between,
    then projects back to ``hid_dim``.
    """

    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` ([batch size, seq len, hid dim]) to a tensor of the same shape."""
        # expanded = [batch size, seq len, pf dim]
        expanded = torch.relu(self.fc_1(x))
        expanded = self.dropout(expanded)
        # back to [batch size, seq len, hid dim]
        return self.fc_2(expanded)
|
706 |
+
|
707 |
+
|
708 |
+
class Decoder(nn.Module):
    """Transformer decoder.

    Embeds the target tokens (scaled by ``sqrt(hid_dim)``) plus learned
    positions, runs them through ``n_layers`` ``DecoderLayer`` blocks that
    attend to the encoder output, and projects the result onto the target
    vocabulary.
    """

    def __init__(
        self,
        output_dim,
        hid_dim,
        n_layers,
        n_heads,
        pf_dim,
        dropout,
        max_length,
        device,
    ):
        super().__init__()
        self.device = device
        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        # learned positions, so targets longer than max_length cannot be decoded
        self.pos_embedding = nn.Embedding(max_length, hid_dim)
        blocks = []
        for _ in range(n_layers):
            blocks.append(
                DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device))
        self.layers = nn.ModuleList(blocks)
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Decode ``trg`` given the encoded source.

        Args:
            trg: target token indices, [batch size, trg len]
            enc_src: encoder output, [batch size, src len, hid dim]
            trg_mask: mask applied in the decoder self-attention
            src_mask: mask applied in the encoder-decoder attention

        Returns:
            Tuple ``(output, attention)``: logits of shape
            [batch size, trg len, output dim] and the encoder-decoder
            attention of the last layer,
            [batch size, n heads, trg len, src len].
        """
        n_seqs, seq_len = trg.shape[0], trg.shape[1]
        # positions = [batch size, trg len], each row counting 0..trg len - 1
        positions = torch.arange(0, seq_len).unsqueeze(0).repeat(n_seqs, 1)
        positions = positions.to(self.device)
        scaled_tokens = self.tok_embedding(trg) * self.scale
        trg = self.dropout(scaled_tokens + self.pos_embedding(positions))
        # trg = [batch size, trg len, hid dim]
        for block in self.layers:
            # only the attention of the final block survives the loop
            trg, attention = block(trg, enc_src, trg_mask, src_mask)
        output = self.fc_out(trg)
        return output, attention
|
753 |
+
|
754 |
+
|
755 |
+
class DecoderLayer(nn.Module):
    """One decoder block.

    Consists of masked self-attention, attention over the encoder output and
    a position-wise feed-forward network. Each sub-layer output goes through
    dropout, is added to its input and is then layer-normalised.
    """

    def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads,
                                                      dropout, device)
        self.encoder_attention = MultiHeadAttentionLayer(
            hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(
            hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def _add_and_norm(self, norm, residual, update):
        """Apply dropout to ``update``, add ``residual`` and normalise."""
        return norm(residual + self.dropout(update))

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Transform ``trg`` ([batch size, trg len, hid dim]).

        Returns:
            Tuple ``(trg, attention)``: the updated target representation
            (same shape as the input) and the encoder-decoder attention,
            [batch size, n heads, trg len, src len].
        """
        # self-attention over the target sequence
        self_attended, _ = self.self_attention(trg, trg, trg, trg_mask)
        trg = self._add_and_norm(self.self_attn_layer_norm, trg, self_attended)
        # attention from the target over the encoded source
        cross_attended, attention = self.encoder_attention(
            trg, enc_src, enc_src, src_mask)
        trg = self._add_and_norm(self.enc_attn_layer_norm, trg, cross_attended)
        # position-wise feed-forward sub-layer
        fed_forward = self.positionwise_feedforward(trg)
        trg = self._add_and_norm(self.ff_layer_norm, trg, fed_forward)
        return trg, attention
|
793 |
+
|
794 |
+
|
795 |
+
class Seq2Seq(nn.Module, Convo):
    """Encoder-decoder transformer bundled with its data loaders and settings.

    The training and evaluation loops come from ``Convo``; this class adds
    the padding/causal masks, the forward pass and target-free decoding.
    """

    def __init__(
        self,
        encoder,
        decoder,
        src_pad_idx,
        trg_pad_idx,
        device,
        loader_train: DataLoader,
        out: str,
        loader_valid=None,
        loader_drugex=None,
        epochs=100,
        lr=0.0005,
        clip=0.1,
        reverse=True,
        TRG=None,
        SRC=None,
    ):
        super().__init__()
        # network
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device
        # data and output prefix used by the Convo routines
        self.loader_train = loader_train
        self.out = out
        self.loader_valid = loader_valid
        self.loader_drugex = loader_drugex
        # optimisation settings
        self.epochs = epochs
        self.lr = lr
        self.clip = clip
        self.reverse = reverse
        # vocabulary fields
        self.TRG = TRG
        self.SRC = SRC

    def make_src_mask(self, src):
        """Return a [batch size, 1, 1, src len] mask that is False on padding."""
        not_padding = src != self.src_pad_idx
        return not_padding.unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        """Return a [batch size, 1, trg len, trg len] padding + causal mask."""
        # padding part = [batch size, 1, 1, trg len]
        not_padding = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)
        seq_len = trg.shape[1]
        # lower-triangular part: position i may only look at positions <= i
        lower_triangle = torch.ones((seq_len, seq_len), device=self.device)
        causal = torch.tril(lower_triangle).bool()
        return not_padding & causal

    def forward(self, src, trg):
        """Run encoder and decoder on batch-first index tensors.

        Args:
            src: [batch size, src len]
            trg: [batch size, trg len]

        Returns:
            Tuple ``(output, attention)`` as returned by the decoder:
            logits [batch size, trg len, output dim] and attention
            [batch size, n heads, trg len, src len].
        """
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        encoded = self.encoder(src, src_mask)
        return self.decoder(trg, encoded, trg_mask, src_mask)

    def translate_sentence(self, src, trg_field, device, max_len=202):
        """Greedily decode ``src`` for exactly ``max_len`` steps.

        Decoding starts from the init token of ``trg_field`` and does not
        stop at an end token; every sequence in the batch is extended by one
        arg-max token per step.

        Returns:
            The decoder logits of the final step, i.e. the logits for the
            ``max_len`` positions fed in at that step.
        """
        self.eval()
        src_mask = self.make_src_mask(src)
        with torch.no_grad():
            encoded = self.encoder(src, src_mask)
        start_index = trg_field.vocab.stoi[trg_field.init_token]
        n_seqs = src.shape[0]
        # one start token per sequence: [batch size, 1]
        trg = torch.LongTensor([start_index]).unsqueeze(0).to(device)
        trg = trg.repeat(n_seqs, 1)
        for _ in range(max_len):
            trg_mask = self.make_trg_mask(trg)
            with torch.no_grad():
                output, _ = self.decoder(trg, encoded, trg_mask, src_mask)
            # most likely token at the newest position, [batch size, 1]
            next_tokens = output.argmax(2)[:, -1].unsqueeze(1)
            trg = torch.cat((trg, next_tokens), 1)

        return output
|
883 |
+
|
884 |
+
|
885 |
+
def remove_floats(df: pd.DataFrame, subset: str):
    """Preprocessing step to remove any entries that are not strings.

    Args:
        df: frame to filter; it is left unmodified.
        subset: name of the column whose values must be strings.

    Returns:
        A copy of the rows of ``df`` whose ``subset`` value is a ``str``
        (this drops e.g. ``NaN`` floats); the original index is kept.
    """
    # Build an explicit boolean mask instead of overwriting the column with
    # ``astype(str)``, so the caller's frame is not changed as a side effect.
    is_text = pd.Series([isinstance(value, str) for value in df[subset]],
                        index=df.index,
                        dtype=bool)
    return df[is_text].copy()
|
893 |
+
|
894 |
+
|
895 |
+
# Compiled once at import time; one alternative per SMILES token type
# (bracket atoms, two-letter halogens, organic-subset atoms, bonds, branches,
# ring-closure digits, ...).
_SMI_TOKEN_REGEX = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
)


def smi_tokenizer(smi: str, reverse=False) -> list:
    """Tokenize a SMILES molecule.

    Args:
        smi: SMILES string to split.
        reverse: when True, return the tokens in reversed order.

    Returns:
        List of token strings. No ``<sos>`` / ``<eos>`` markers are added.

    Raises:
        AssertionError: if ``smi`` contains characters that match no token,
            i.e. joining the tokens does not reproduce ``smi``.
    """
    tokens = _SMI_TOKEN_REGEX.findall(smi)
    # Raised explicitly (rather than with ``assert``) so the check is kept
    # under ``python -O``; unmatched characters would otherwise be dropped
    # silently by ``findall``.
    if "".join(tokens) != smi:
        raise AssertionError(
            f"SMILES contains characters that cannot be tokenized: {smi!r}")
    if reverse:
        return tokens[::-1]
    return tokens
|
913 |
+
|
914 |
+
|
915 |
+
def init_weights(m: nn.Module):
    """Xavier-initialise the weight matrix of ``m`` in place.

    Intended for ``model.apply(init_weights)``. Modules without a ``weight``
    attribute, with ``weight`` set to ``None`` (e.g. a normalisation layer
    created without affine parameters) or with a weight of fewer than two
    dimensions are left untouched.
    """
    weight = getattr(m, "weight", None)
    if weight is not None and weight.dim() > 1:
        nn.init.xavier_uniform_(weight.data)
|
918 |
+
|
919 |
+
|
920 |
+
def count_parameters(model: nn.Module):
    """Return the total number of elements in the trainable parameters of ``model``."""
    total = 0
    for parameter in model.parameters():
        # parameters with requires_grad=False are not counted
        if parameter.requires_grad:
            total += parameter.numel()
    return total
|
922 |
+
|
923 |
+
|
924 |
+
def epoch_time(start_time, end_time):
    """Return the time between two timestamps as ``(minutes, seconds)``.

    Both timestamps are in seconds; the two returned values are truncated to
    ``int``.
    """
    elapsed = end_time - start_time
    whole_minutes = int(elapsed / 60)
    leftover_seconds = int(elapsed - (whole_minutes * 60))
    return whole_minutes, leftover_seconds
|
929 |
+
|
930 |
+
|
931 |
+
def initialize_model(folder_out: str,
                     data_source: str,
                     error_source: str,
                     device: torch.device,
                     threshold: int,
                     epochs: int,
                     layers: int = 3,
                     batch_size: int = 16,
                     invalid_type: str = "all",
                     num_errors: int = 1,
                     validation_step=False):
    """Build vocabularies, data iterators and the transformer model.

    Args:
        folder_out: prefix of the data files and folder in which the model
            files are written. It is concatenated with file names directly,
            so it is expected to end with a path separator.
        data_source: first part of the training file name.
        error_source: path of a CSV file with ``SMILES`` and
            ``SMILES_TARGET`` columns, used as the drugex loader.
        device: device on which the batches and the model are placed.
        threshold: maximum sequence length; 2 is added for the start and
            stop tokens.
        epochs: number of training epochs stored on the model.
        layers: number of encoder layers and of decoder layers.
        batch_size: batch size of the training iterator.
        invalid_type: type of errors, part of the training file name.
        num_errors: number of errors, part of the training file name.
        validation_step: when True, load the ``_train`` / ``_dev`` split from
            ``{folder_out}errors/split/`` and create a validation iterator;
            otherwise train on the single unsplit file without validation.

    Returns:
        Tuple ``(model, out, SRC)``: the ``Seq2Seq`` model, the path prefix
        used for its output files and the source field.
    """

    def tokenize_source(smiles):
        return smi_tokenizer(smiles)

    def tokenize_target(smiles):
        # targets are tokenized in reversed order
        return smi_tokenizer(smiles, reverse=True)

    # set fields
    SRC = Field(
        tokenize=tokenize_source,
        init_token="<sos>",
        eos_token="<eos>",
        batch_first=True,
    )
    TRG = Field(
        tokenize=tokenize_target,
        init_token="<sos>",
        eos_token="<eos>",
        batch_first=True,
    )

    # CSV column -> (batch attribute, field)
    pair_fields = {"ERROR": ("src", SRC), "STD_SMILES": ("trg", TRG)}
    if validation_step:
        train, val = TabularDataset.splits(
            path=f'{folder_out}errors/split/',
            train=f"{data_source}_{invalid_type}_{num_errors}_errors_train.csv",
            validation=
            f"{data_source}_{invalid_type}_{num_errors}_errors_dev.csv",
            format="CSV",
            skip_header=False,
            fields=pair_fields,
        )
        SRC.build_vocab(train, val, max_size=1000)
        TRG.build_vocab(train, val, max_size=1000)
    else:
        train = TabularDataset(
            path=
            f'{folder_out}{data_source}_{invalid_type}_{num_errors}_errors.csv',
            format="CSV",
            skip_header=False,
            fields=pair_fields,
        )
        SRC.build_vocab(train, max_size=1000)
        TRG.build_vocab(train, max_size=1000)

    drugex = TabularDataset(
        path=error_source,
        format="csv",
        skip_header=False,
        fields={
            "SMILES": ("src", SRC),
            "SMILES_TARGET": ("trg", TRG)
        },
    )

    # model parameters
    input_dim = len(SRC.vocab)
    output_dim = len(TRG.vocab)
    hid_dim = 256
    n_heads = 8
    pf_dim = 512
    dropout = 0.1
    src_pad_idx = SRC.vocab.stoi[SRC.pad_token]
    trg_pad_idx = TRG.vocab.stoi[TRG.pad_token]
    # add 2 to length for start and stop tokens
    max_length = threshold + 2

    # model name and output prefix
    model_out_folder = f"{folder_out}"
    model_name = "transformer_%s_%s_%s_%s_%s" % (
        invalid_type, num_errors, data_source, batch_size, layers)
    if not os.path.exists(model_out_folder):
        os.mkdir(model_out_folder)
    out = os.path.join(model_out_folder, model_name)

    # the vocabularies are stored next to the model files
    torch.save(SRC.vocab, f'{out}_vocab_src.pth')
    torch.save(TRG.vocab, f'{out}_vocab_trg.pth')

    # The iterators sort by source length within a batch (sort_within_batch),
    # with the aim of minimising padding.
    if validation_step:
        train_iter, val_iter = BucketIterator.splits(
            (train, val),
            batch_sizes=(batch_size, 256),
            sort_within_batch=True,
            shuffle=True,
            sort_key=lambda x: len(x.src),
            device=device,
        )
    else:
        train_iter = BucketIterator(
            train,
            batch_size=batch_size,
            sort_within_batch=True,
            shuffle=True,
            sort_key=lambda x: len(x.src),
            device=device,
        )
        val_iter = None

    drugex_iter = Iterator(
        drugex,
        batch_size=64,
        device=device,
        sort=False,
        sort_within_batch=True,
        sort_key=lambda x: len(x.src),
        repeat=False,
    )

    # model initialization; encoder and decoder share all hyperparameters
    enc = Encoder(
        input_dim,
        hid_dim,
        layers,
        n_heads,
        pf_dim,
        dropout,
        max_length,
        device,
    )
    dec = Decoder(
        output_dim,
        hid_dim,
        layers,
        n_heads,
        pf_dim,
        dropout,
        max_length,
        device,
    )

    model = Seq2Seq(
        enc,
        dec,
        src_pad_idx,
        trg_pad_idx,
        device,
        train_iter,
        out=out,
        loader_valid=val_iter,
        loader_drugex=drugex_iter,
        epochs=epochs,
        TRG=TRG,
        SRC=SRC,
    ).to(device)

    return model, out, SRC
|
1121 |
+
|
1122 |
+
|
1123 |
+
def train_model(model, out, assess):
    """Apply given weights (& assess performance or train further) or start training new model

    Args:
        model: initialized model
        out: path prefix of the model files; ``{out}.pkg`` holds the best
            weights and ``{out}_last.pkg`` the weights after the last epoch
        assess: when True and ``{out}.pkg`` exists, only load the best
            weights and run ``model.evaluate(True)`` instead of training

    Returns:
        model with (new) weights
    """
    best_path = f"{out}.pkg"
    last_path = f"{out}_last.pkg"

    if os.path.exists(best_path):
        if assess:
            model.load_state_dict(torch.load(f=best_path))
            # the metrics returned by evaluate are not used here
            model.evaluate(True)
        else:
            # Continue from the model after the last epoch, not the best
            # epoch. ``_last.pkg`` is only written when training ends, so
            # fall back to the best weights if a previous run was interrupted
            # before that point.
            resume_path = last_path if os.path.exists(last_path) else best_path
            model.load_state_dict(torch.load(f=resume_path))
            model.train_model()
    else:
        # no stored weights: initialise and train from scratch
        model = model.apply(init_weights)
        model.train_model()

    return model
|
1178 |
+
|
1179 |
+
|
1180 |
+
def correct_SMILES(model, out, error_source, device, SRC):
    """Run the trained model over a CSV of SMILES and return its translations.

    Args:
        model: initialized model; its weights are replaced by ``{out}.pkg``.
        out: path prefix of the checkpoint file (without the ``.pkg`` suffix).
        error_source: path of the CSV file to read; its ``SMILES`` column is
            loaded as the ``src`` field.
        device: device handed to the batch iterator.
        SRC: field object used to process the ``SMILES`` column.

    Returns:
        valids: first value returned by ``model.translate`` (number of fixed
            outputs, per the original documentation).
        df_output: dataframe containing output (either correct or incorrect)
            & original input
    """
    # NOTE(original author): account for tokens that are not yet in SRC
    # without changing existing SRC token embeddings.
    def _source_length(example):
        return len(example.src)

    column_to_field = {"SMILES": ("src", SRC)}
    errors = TabularDataset(
        path=error_source,
        format="csv",
        skip_header=False,
        fields=column_to_field,
    )

    errors_loader = Iterator(
        errors,
        batch_size=64,
        device=device,
        sort=False,
        sort_within_batch=True,
        sort_key=_source_length,
        repeat=False,
    )

    # Weights are always deserialised onto the CPU before being copied in.
    state = torch.load(f=out + ".pkg", map_location=torch.device('cpu'))
    model.load_state_dict(state)
    # add option to use different iterator maybe?

    fixed_count, translated = model.translate(errors_loader)
    return fixed_count, translated
|
1216 |
+
|
1217 |
+
|
1218 |
+
|
1219 |
+
class smi_correct(object):
    """Wrapper that corrects SMILES with the transformer and standardizes the result."""

    def __init__(self, model_name, trans_file_path):
        """Store configuration and make sure the output folder exists.

        Args:
            model_name: stored on the instance; not used by the methods below.
            trans_file_path: stored on the instance; not used by the methods below.
        """
        # set random seed, used for error generation & initiation transformer
        self.SEED = 42
        random.seed(self.SEED)
        self.model_name = model_name
        self.folder_out = "data/"

        self.trans_file_path = trans_file_path

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.folder_out, exist_ok=True)

        self.invalid_type = 'multiple'
        self.num_errors = 12
        self.threshold = 200
        self.data_source = f"PAPYRUS_{self.threshold}"
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        self.initialize_source = 'data/papyrus_rnn_S.csv'  # change this path

    def standardization_pipeline(self, smile):
        """Standardize one SMILES string and keep its largest fragment.

        Args:
            smile: SMILES string; any non-string value yields ``None``.

        Returns:
            The standardized SMILES, or ``None`` when the input is not a
            string, cannot be parsed, or is flagged for exclusion by
            ``standardizer.get_parent_mol``.
        """
        if not isinstance(smile, str):
            return None
        m = Chem.MolFromSmiles(smile)
        # skips smiles for which no mol file could be generated
        if m is None:
            return None
        # standardizes
        std_m = standardizer.standardize_mol(m)
        # strips salts
        std_m_p, exclude = standardizer.get_parent_mol(std_m)
        if exclude:
            return None
        # choose largest fragment for rare cases where chembl structure
        # pipeline leaves 2 fragments
        desalter = MolStandardize.rdMolStandardize.LargestFragmentChooser()
        return Chem.MolToSmiles(desalter.choose(std_m_p))

    def remove_smiles_duplicates(self, dataframe: pd.DataFrame,
                                 subset: str) -> pd.DataFrame:
        """Return ``dataframe`` without rows duplicated in the ``subset`` column."""
        return dataframe.drop_duplicates(subset=subset)

    def correct(self, smi):
        """Correct the SMILES in a CSV file and return the unique standardized results.

        Args:
            smi: forwarded to ``correct_SMILES`` as ``error_source``, i.e. the
                path of the CSV file holding the SMILES to correct.

        Returns:
            DataFrame with a single ``SMILES`` column of standardized,
            de-duplicated, non-null corrected SMILES.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model, out, SRC = initialize_model(self.folder_out,
                                           self.data_source,
                                           error_source=self.initialize_source,
                                           device=device,
                                           threshold=self.threshold,
                                           epochs=30,
                                           layers=3,
                                           batch_size=16,
                                           invalid_type=self.invalid_type,
                                           num_errors=self.num_errors)

        _, df_output = correct_SMILES(model, out, smi, device, SRC)

        # Series.map always yields a Series, whereas DataFrame.apply(axis=1)
        # returns a DataFrame for empty input, which cannot be assigned to a
        # single column.
        df_output["SMILES"] = df_output["CORRECT"].map(self.standardization_pipeline)

        df_output = (
            self.remove_smiles_duplicates(df_output, subset="SMILES")
            .drop(columns=["CORRECT", "INCORRECT", "ORIGINAL"])
            .dropna()
        )

        return df_output
|
src/util/utils.py
ADDED
@@ -0,0 +1,930 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import math
|
4 |
+
import datetime
|
5 |
+
import warnings
|
6 |
+
import itertools
|
7 |
+
from copy import deepcopy
|
8 |
+
from functools import partial
|
9 |
+
from collections import Counter
|
10 |
+
from multiprocessing import Pool
|
11 |
+
from statistics import mean
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
import matplotlib.pyplot as plt
|
15 |
+
from matplotlib.lines import Line2D
|
16 |
+
from scipy.spatial.distance import cosine as cos_distance
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import wandb
|
20 |
+
|
21 |
+
from rdkit import Chem, DataStructs, RDLogger
|
22 |
+
from rdkit.Chem import (
|
23 |
+
AllChem,
|
24 |
+
Draw,
|
25 |
+
Descriptors,
|
26 |
+
Lipinski,
|
27 |
+
Crippen,
|
28 |
+
rdMolDescriptors,
|
29 |
+
FilterCatalog,
|
30 |
+
)
|
31 |
+
from rdkit.Chem.Scaffolds import MurckoScaffold
|
32 |
+
|
33 |
+
# Disable RDKit warnings
|
34 |
+
RDLogger.DisableLog("rdApp.*")
|
35 |
+
|
36 |
+
|
37 |
+
class Metrics(object):
    """
    Collection of static methods to compute various metrics for molecules.
    """

    @staticmethod
    def valid(x):
        """
        Checks whether the molecule is valid.

        Args:
            x: RDKit molecule object (or None).

        Returns:
            bool: True if molecule is not None and has a non-empty SMILES representation.
        """
        return x is not None and Chem.MolToSmiles(x) != ''

    @staticmethod
    def tanimoto_sim_1v2(data1, data2):
        """
        Computes the average Tanimoto similarity for paired fingerprints.

        Fingerprints are paired by position; when the inputs differ in size,
        only the first ``min(data1.size, data2.size)`` pairs are compared.

        Args:
            data1: Indexable fingerprint collection exposing ``.size``.
            data2: Indexable fingerprint collection exposing ``.size``.

        Returns:
            float: The average Tanimoto similarity between corresponding fingerprints.

        Raises:
            statistics.StatisticsError: If there are no pairs to compare.
        """
        # The original picked the LARGER size (or even the array itself),
        # which either broke range() or indexed past the shorter input.
        n_pairs = min(data1.size, data2.size)
        sims = [
            DataStructs.FingerprintSimilarity(data1[i], data2[i])
            for i in range(n_pairs)
        ]
        return mean(sims)

    @staticmethod
    def mol_length(x):
        """
        Counts the alphabetic characters of the longest fragment in a SMILES string.

        Args:
            x (str): SMILES string, or None.

        Returns:
            int: Number of alphabetic characters in the longest dot-separated
                fragment (two-letter element symbols count twice); 0 for None.
        """
        if x is None:
            return 0
        longest_fragment = max(x.split(sep="."), key=len)
        return sum(1 for char in longest_fragment if char.isalpha())

    @staticmethod
    def max_component(data, max_len):
        """
        Returns the average normalized length of molecules in the dataset.

        Each molecule's length (see ``mol_length``) is divided by max_len, then averaged.

        Args:
            data (iterable): Collection of SMILES strings.
            max_len (int): Maximum possible length for normalization.

        Returns:
            float: Normalized average length.
        """
        lengths = np.array(list(map(Metrics.mol_length, data)), dtype=np.float32)
        return (lengths / max_len).mean()

    @staticmethod
    def mean_atom_type(data):
        """
        Computes the average number of unique atom types in the provided node data.

        Args:
            data (iterable): Iterable whose elements expose ``.unique().tolist()``
                (e.g. tensors of node labels).

        Returns:
            float: The average count of unique values per element, minus one
                (presumably to discount a padding / "no atom" label - confirm).
        """
        atom_types_used = [len(nodes.unique().tolist()) for nodes in data]
        return np.mean(atom_types_used) - 1
|
128 |
+
|
129 |
+
|
130 |
+
def mols2grid_image(mols, path):
    """
    Writes one PNG per valid molecule into ``path``.

    Files are named after the 1-based position of the molecule in ``mols``
    (``1.png``, ``2.png``, ...); positions holding invalid molecules are skipped.

    Args:
        mols (list): List of RDKit molecule objects (None entries allowed).
        path (str): Directory where images will be saved.
    """
    # None entries become empty molecules, which then fail the validity check.
    drawable = [Chem.RWMol() if mol is None else mol for mol in mols]

    for position, mol in enumerate(drawable, start=1):
        if not Metrics.valid(mol):
            continue
        AllChem.Compute2DCoords(mol)
        image_file = os.path.join(path, "{}.png".format(position))
        Draw.MolToFile(mol, image_file, size=(1200, 1200))
        # wandb.save(image_file)  # Optionally save to Weights & Biases
|
151 |
+
|
152 |
+
|
153 |
+
def save_smiles_matrices(mols, edges_hard, nodes_hard, path, data_source=None):
    """
    Appends the edge matrix, node matrix and SMILES of each valid molecule to a text file.

    Files are named after the 1-based position of the molecule in ``mols``
    (``1.txt``, ``2.txt``, ...) and are opened in append mode.

    Args:
        mols (list): List of RDKit molecule objects (None entries allowed).
        edges_hard (torch.Tensor): Edge tensor, indexed by molecule position.
        nodes_hard (torch.Tensor): Node tensor, indexed by molecule position.
        path (str): Directory where files will be saved.
        data_source: Unused.
    """
    molecules = [Chem.RWMol() if mol is None else mol for mol in mols]

    for position, molecule in enumerate(molecules):
        if not Metrics.valid(molecule):
            continue
        save_path = os.path.join(path, "{}.txt".format(position + 1))
        edge_matrix = edges_hard[position].cpu().numpy()
        node_matrix = nodes_hard[position].cpu().numpy()
        with open(save_path, "a") as f:
            np.savetxt(f, edge_matrix, header="edge matrix:\n", fmt='%1.2f')
            f.write("\n")
            np.savetxt(f, node_matrix, header="node matrix:\n", footer="\nsmiles:", fmt='%1.2f')
            f.write("\n")
            # The SMILES line follows the "smiles:" footer written above.
            print(Chem.MolToSmiles(molecule), file=f)
        # wandb.save(save_path)  # Optionally save to Weights & Biases
|
182 |
+
|
183 |
+
def dense_to_sparse_with_attr(adj):
    """
    Converts a dense adjacency matrix to index/attribute form.

    For a 3D input the row and column indices of graph ``b`` are shifted by
    ``b * adj.size(-1)``, so all graphs share one flat node numbering.

    Args:
        adj (torch.Tensor): Adjacency tensor (2D or 3D) with square last two dimensions.

    Returns:
        tuple: ``(index, edge_attr)`` where ``index`` is a tuple of row and
            column index tensors and ``edge_attr`` holds the non-zero values.
    """
    assert 2 <= adj.dim() <= 3
    assert adj.size(-2) == adj.size(-1)

    nonzero_index = adj.nonzero(as_tuple=True)
    edge_attr = adj[nonzero_index]

    if len(nonzero_index) != 3:
        return nonzero_index, edge_attr

    graph_ids, rows, cols = nonzero_index
    offset = graph_ids * adj.size(-1)
    return (offset + rows, offset + cols), edge_attr
|
203 |
+
|
204 |
+
|
205 |
+
def mol_sample(sample_directory, edges, nodes, idx, i, matrices2mol, dataset_name):
    """
    Decodes predicted graphs into molecules and saves images and text dumps.

    Output goes to ``<sample_directory>/<idx+1>_<i+1>-epoch_iteration``; the
    folder is removed again when nothing was written into it.

    Args:
        sample_directory (str): Directory to save the samples.
        edges (torch.Tensor): Edge predictions tensor (arg-max taken over the last dim).
        nodes (torch.Tensor): Node predictions tensor (arg-max taken over the last dim).
        idx (int): First counter used in the folder name.
        i (int): Second counter used in the folder name.
        matrices2mol (callable): Function to convert matrices to RDKit molecule.
        dataset_name (str): Passed to ``matrices2mol`` as ``file_name``.
    """
    folder_name = "{}_{}-epoch_iteration".format(idx + 1, i + 1)
    sample_path = os.path.join(sample_directory, folder_name)

    # Hard assignment: index of the highest-scoring class per edge / node.
    edge_labels = torch.max(edges, -1)[1]
    node_labels = torch.max(nodes, -1)[1]

    mol = []
    for edge_matrix, node_vector in zip(edge_labels, node_labels):
        mol.append(
            matrices2mol(node_vector.data.cpu().numpy(), edge_matrix.data.cpu().numpy(),
                         strict=True, file_name=dataset_name)
        )

    if not os.path.exists(sample_path):
        os.makedirs(sample_path)

    mols2grid_image(mol, sample_path)
    save_smiles_matrices(mol, edge_labels.detach(), node_labels.detach(), sample_path)

    # Remove the directory if no files were saved
    if not os.listdir(sample_path):
        os.rmdir(sample_path)

    print("Valid molecules are saved.")
    print("Valid matrices and smiles are saved")
|
239 |
+
|
240 |
+
|
241 |
+
def logging(log_path, start_time, i, idx, loss, save_path, drug_smiles, edge, node,
            matrices2mol, dataset_name, real_adj, real_annot, drug_vecs):
    """
    Computes generation metrics for one batch and logs them.

    Molecules are decoded from the predicted and the real graphs, the generated
    SMILES are appended to ``<save_path>/samples.txt``, the metrics are merged
    into ``loss`` and sent to wandb, and one summary line is appended to
    ``log_path`` and printed.

    Args:
        log_path (str): Path of the log file (opened in append mode).
        start_time (float): Start time used to compute the elapsed time.
        i (int): Current iteration index (reported as ``i + 1``).
        idx (int): Current epoch index.
        loss (dict): Updated in place with the metric values; every entry is
            written to the log line formatted as a float.
        save_path (str): Directory holding ``samples.txt``.
        drug_smiles (list): Reference drug SMILES for the ``Novelty_akt`` score.
        edge (torch.Tensor): Edge prediction tensor.
        node (torch.Tensor): Node prediction tensor.
        matrices2mol (callable): Function to convert matrices to molecules.
        dataset_name (str): Passed to ``matrices2mol`` as ``file_name``.
        real_adj (torch.Tensor): Ground truth adjacency tensor.
        real_annot (torch.Tensor): Ground truth annotation tensor.
        drug_vecs (list): Drug fingerprints for the ``SNN_akt`` score.
    """
    def _decode(edge_batch, node_batch):
        # One molecule (or None) per (edge matrix, node vector) pair.
        return [
            matrices2mol(n_.data.cpu().numpy(), e_.data.cpu().numpy(),
                         strict=True, file_name=dataset_name)
            for e_, n_ in zip(edge_batch, node_batch)
        ]

    def _longest_fragment(smiles):
        return None if smiles is None else max(smiles.split('.'), key=len)

    predicted_edges = torch.max(edge, -1)[1]
    predicted_nodes = torch.max(node, -1)[1]

    real_edges = torch.max(real_adj, -1)[1].float()
    real_nodes = torch.max(real_annot, -1)[1].float()

    # Generate molecules from predictions and real data
    mols = _decode(predicted_edges, predicted_nodes)
    real_mol = _decode(real_edges, real_nodes)

    # Compute average number of atom types
    atom_types_average = Metrics.mean_atom_type(predicted_nodes)
    real_smiles = [Chem.MolToSmiles(m) for m in real_mol if m is not None]

    # gen_smiles keeps a None slot for every failed molecule; uniq_smiles drops them.
    gen_smiles = [None if m is None else Chem.MolToSmiles(m) for m in mols]
    uniq_smiles = [smiles for smiles in gen_smiles if smiles is not None]

    # Process SMILES to take the longest fragment if multiple are present
    gen_smiles_saves = [_longest_fragment(smiles) for smiles in gen_smiles]
    uniq_smiles_saves = [_longest_fragment(smiles) for smiles in uniq_smiles]

    # Save the generated SMILES to a text file
    sample_save_dir = os.path.join(save_path, "samples.txt")
    with open(sample_save_dir, "a") as f:
        f.writelines(s + "\n" for s in gen_smiles_saves if s is not None)

    k = len(set(uniq_smiles_saves) - {None})
    elapsed = time.time() - start_time
    elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
    log_str = "Elapsed [{}], Epoch/Iteration [{}/{}]".format(elapsed, idx, i + 1)

    # Generate molecular fingerprints for similarity computations
    gen_vecs = [AllChem.GetMorganFingerprintAsBitVect(m, 2, nBits=1024)
                for m in mols if m is not None]
    chembl_vecs = [AllChem.GetMorganFingerprintAsBitVect(m, 2, nBits=1024)
                   for m in real_mol if m is not None]

    # Compute evaluation metrics: validity, uniqueness, novelty, similarity
    # scores, and average maximum molecule length.
    valid = fraction_valid(gen_smiles_saves)
    unique = fraction_unique(uniq_smiles_saves, k)
    novel_starting_mol = novelty(gen_smiles_saves, real_smiles)
    novel_akt = novelty(gen_smiles_saves, drug_smiles)
    if uniq_smiles_saves:
        snn_chembl = average_agg_tanimoto(np.array(chembl_vecs), np.array(gen_vecs))
        snn_akt = average_agg_tanimoto(np.array(drug_vecs), np.array(gen_vecs))
        maxlen = Metrics.max_component(uniq_smiles_saves, 45)
    else:
        snn_chembl = 0
        snn_akt = 0
        maxlen = 0

    metrics = {
        'Validity': valid,
        'Uniqueness': unique,
        'Novelty': novel_starting_mol,
        'Novelty_akt': novel_akt,
        'SNN_chembl': snn_chembl,
        'SNN_akt': snn_akt,
        'MaxLen': maxlen,
        'Atom_types': atom_types_average,
    }

    # Update loss dictionary with computed metrics
    loss.update(metrics)

    # Log metrics using wandb (a separate dict, as in the original call)
    wandb.log(dict(metrics))

    # Append each metric to the log string and write to the log file
    log_str += "".join(", {}: {:.4f}".format(tag, value) for tag, value in loss.items())
    with open(log_path, "a") as f:
        f.write(log_str + "\n")
    print(log_str)
    print("\n")
|
356 |
+
|
357 |
+
|
358 |
+
def plot_grad_flow(named_parameters, model, itera, epoch, grad_flow_directory):
    """
    Plots the gradients flowing through different layers during training.

    This is useful to check for possible gradient vanishing or exploding problems.
    Bias parameters, frozen parameters and parameters that currently hold no
    gradient are left out of the plot.

    Args:
        named_parameters (iterable): Iterable of (name, parameter) tuples from the model.
        model (str): Name of the model (concatenated into the file name).
        itera (int): Iteration index.
        epoch (int): Current epoch.
        grad_flow_directory (str): Directory to save the gradient flow plot.
    """
    ave_grads = []
    max_grads = []
    layers = []
    for name, param in named_parameters:
        if not param.requires_grad or "bias" in name:
            continue
        # A trainable parameter that took no part in the last backward pass
        # has grad None; skip it instead of failing on None.abs().
        if param.grad is None:
            continue
        magnitude = param.grad.abs()
        layers.append(name)
        ave_grads.append(magnitude.mean().cpu())
        max_grads.append(magnitude.max().cpu())

    # NOTE(review): everything below draws onto the current pyplot figure,
    # which is neither created nor closed here - confirm that is intended
    # when this function is called repeatedly.
    # Plot maximum gradients and average gradients for each layer
    positions = np.arange(len(max_grads))
    plt.bar(positions, max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(positions, ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=-0.001, top=1)  # Zoom in on lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("Average Gradient")
    plt.title("Gradient Flow")
    plt.grid(True)
    plt.legend([
        Line2D([0], [0], color="c", lw=4),
        Line2D([0], [0], color="b", lw=4),
        Line2D([0], [0], color="k", lw=4)
    ], ['max-gradient', 'mean-gradient', 'zero-gradient'])
    # Save the plot to the specified directory
    file_name = "weights_" + model + "_" + str(itera) + "_" + str(epoch) + ".png"
    plt.savefig(os.path.join(grad_flow_directory, file_name), dpi=500, bbox_inches='tight')
|
397 |
+
|
398 |
+
|
399 |
+
def get_mol(smiles_or_mol):
    """
    Turns a SMILES string into a sanitized RDKit molecule.

    Anything that is not a string is handed back unchanged.

    Args:
        smiles_or_mol (str or RDKit Mol): SMILES string or RDKit molecule.

    Returns:
        RDKit Mol or None: The parsed molecule; None for an empty string, an
            unparsable SMILES, or a molecule whose sanitization raises ValueError.
    """
    if not isinstance(smiles_or_mol, str):
        return smiles_or_mol
    if len(smiles_or_mol) == 0:
        return None

    parsed = Chem.MolFromSmiles(smiles_or_mol)
    if parsed is None:
        return None

    try:
        Chem.SanitizeMol(parsed)
    except ValueError:
        return None
    return parsed
|
421 |
+
|
422 |
+
|
423 |
+
def mapper(n_jobs):
    """
    Returns a mapping function for parallel or serial processing.

    If n_jobs == 1, returns a function that wraps the built-in map and returns a list.
    If n_jobs is another int, returns a function that runs the map in a
    multiprocessing pool of that size.
    Otherwise n_jobs is treated as a pool-like object and its ``map`` is returned.

    Args:
        n_jobs (int or pool object): Number of jobs or a Pool instance.

    Returns:
        callable: A function that acts like map.
    """
    if n_jobs == 1:
        def _serial_map(*args, **kwargs):
            return list(map(*args, **kwargs))
        return _serial_map

    if isinstance(n_jobs, int):
        def _pool_map(*args, **kwargs):
            # The pool is created per call and torn down on exit. The original
            # built it eagerly, so workers leaked when the mapper was never
            # called, and a second call hit an already terminated pool.
            with Pool(n_jobs) as pool:
                return pool.map(*args, **kwargs)
        return _pool_map

    return n_jobs.map
|
450 |
+
|
451 |
+
|
452 |
+
def remove_invalid(gen, canonize=True, n_jobs=1):
    """
    Drops the entries of ``gen`` that do not yield a valid molecule.

    Args:
        gen (list): List of SMILES strings.
        canonize (bool): If True, return canonical SMILES; if False, return
            the surviving input entries unchanged.
        n_jobs (int): Number of parallel jobs.

    Returns:
        list: Valid entries, canonicalized when ``canonize`` is True.
    """
    run_map = mapper(n_jobs)
    if canonize:
        canonical = run_map(canonic_smiles, gen)
        return [smiles for smiles in canonical if smiles is not None]

    parsed = run_map(get_mol, gen)
    return [entry for entry, mol in zip(gen, parsed) if mol is not None]
|
470 |
+
|
471 |
+
|
472 |
+
def fraction_valid(gen, n_jobs=1):
    """
    Computes the fraction of valid molecules in the dataset.

    Args:
        gen (list): List of SMILES strings.
        n_jobs (int): Number of parallel jobs.

    Returns:
        float: Fraction of entries that yield a molecule; 0.0 for empty input.
    """
    mols = mapper(n_jobs)(get_mol, gen)
    if not mols:
        # Nothing generated: report 0 like the sibling metrics do, instead of
        # dividing by zero.
        return 0.0
    return 1 - mols.count(None) / len(mols)
|
485 |
+
|
486 |
+
|
487 |
+
def canonic_smiles(smiles_or_mol):
    """
    Returns the canonical SMILES of a SMILES string or molecule.

    Args:
        smiles_or_mol (str or RDKit Mol): Input molecule.

    Returns:
        str or None: Canonical SMILES string, or None when ``get_mol`` yields no molecule.
    """
    parsed = get_mol(smiles_or_mol)
    return None if parsed is None else Chem.MolToSmiles(parsed)
|
501 |
+
|
502 |
+
|
503 |
+
def fraction_unique(gen, k=None, n_jobs=1, check_validity=True):
    """
    Computes the fraction of unique molecules.

    Optionally computes unique@k, where only the first k molecules are considered.

    Args:
        gen (list): List of SMILES strings.
        k (int): Optional cutoff for unique@k computation.
        n_jobs (int): Number of parallel jobs.
        check_validity (bool): If True, entries are canonicalized and invalid
            ones are dropped before counting. If False, entries are compared
            as given (None entries are still dropped).

    Returns:
        float: Fraction of unique molecules; 0 when nothing is left to count.
    """
    if k is not None:
        if len(gen) < k:
            warnings.warn("Can't compute unique@{}.".format(k) +
                          " gen contains only {} molecules".format(len(gen)))
        gen = gen[:k]

    # The original only bound `canonic` inside the check_validity branch, so
    # check_validity=False ended in an UnboundLocalError.
    if check_validity:
        candidates = mapper(n_jobs)(canonic_smiles, gen)
    else:
        candidates = gen

    canonic = [smiles for smiles in candidates if smiles is not None]
    if not canonic:
        return 0
    return len(set(canonic)) / len(canonic)
|
528 |
+
|
529 |
+
|
530 |
+
def novelty(gen, train, n_jobs=1):
    """
    Computes the novelty score of generated molecules.

    Novelty is the fraction of distinct valid generated molecules (as canonical
    SMILES) that do not appear in the training set.

    Args:
        gen (list): List of generated SMILES strings.
        train (list): List of training SMILES strings (compared as given).
        n_jobs (int): Number of parallel jobs.

    Returns:
        float: Novelty score; 0 when no generated molecule is valid.
    """
    generated = set(mapper(n_jobs)(canonic_smiles, gen))
    generated.discard(None)
    known = set(train)
    if not generated:
        return 0
    return len(generated - known) / len(generated)
|
548 |
+
|
549 |
+
|
550 |
+
def internal_diversity(gen):
    """
    Computes the internal diversity of a set of molecules.

    For every molecule, diversity is one minus its mean Tanimoto similarity to
    the whole set, as returned by ``average_agg_tanimoto``.

    Args:
        gen: Array-like representation of molecules.

    Returns:
        tuple: Mean and standard deviation of the per-molecule diversity.
    """
    mean_similarities = average_agg_tanimoto(gen, gen, agg="mean", intdiv=True)
    diversity = [1 - similarity for similarity in mean_similarities]
    return np.mean(diversity), np.std(diversity)
|
564 |
+
|
565 |
+
|
566 |
+
def average_agg_tanimoto(stock_vecs, gen_vecs, batch_size=5000, agg='max', device='cpu', p=1, intdiv=False):
    """
    Aggregates Tanimoto similarities between generated and reference fingerprints.

    For each fingerprint in gen_vecs, the similarities to all fingerprints in
    stock_vecs are reduced with ``agg`` (max or mean). Both sets are processed
    in batches of ``batch_size`` rows.

    Args:
        stock_vecs (numpy.ndarray): Fingerprint vectors of the reference set, one per row.
        gen_vecs (numpy.ndarray): Fingerprint vectors of the generated set, one per row.
        batch_size (int): Batch size for processing fingerprints.
        agg (str): Aggregation method, either 'max' or 'mean'.
        device (str): Device to perform computations on.
        p (int): Power applied to similarities before aggregation and inverted afterwards.
        intdiv (bool): If True, return the per-fingerprint scores instead of their mean.

    Returns:
        float or numpy.ndarray: Mean aggregated similarity, or the array of individual scores.
    """
    assert agg in ['max', 'mean'], "Can aggregate only max or mean"
    n_generated = len(gen_vecs)
    agg_tanimoto = np.zeros(n_generated)
    total = np.zeros(n_generated)

    for stock_start in range(0, stock_vecs.shape[0], batch_size):
        x_stock = torch.tensor(stock_vecs[stock_start:stock_start + batch_size]).to(device).float()
        stock_bits = x_stock.sum(1, keepdim=True)
        for gen_start in range(0, gen_vecs.shape[0], batch_size):
            y_gen = torch.tensor(gen_vecs[gen_start:gen_start + batch_size]).to(device).float()
            y_gen = y_gen.transpose(0, 1)
            # Shared on-bits for every (stock, generated) pair.
            tp = torch.mm(x_stock, y_gen)
            # Compute Jaccard/Tanimoto similarity
            union = stock_bits + y_gen.sum(0, keepdim=True) - tp
            jac = (tp / union).cpu().numpy()
            # 0/0 (two all-zero fingerprints) counts as identical.
            jac[np.isnan(jac)] = 1
            if p != 1:
                jac = jac ** p
            window = slice(gen_start, gen_start + y_gen.shape[1])
            if agg == 'max':
                agg_tanimoto[window] = np.maximum(agg_tanimoto[window], jac.max(0))
            elif agg == 'mean':
                agg_tanimoto[window] += jac.sum(0)
                total[window] += jac.shape[0]

    if agg == 'mean':
        agg_tanimoto /= total
    if p != 1:
        agg_tanimoto = (agg_tanimoto) ** (1 / p)
    if intdiv:
        return agg_tanimoto
    return np.mean(agg_tanimoto)
|
612 |
+
|
613 |
+
|
614 |
+
def str2bool(v):
    """
    Converts a string to a boolean.

    Args:
        v (str): Input string.

    Returns:
        bool: True if the string is exactly 'true' (case insensitive), else False.
    """
    # Compare for equality. The previous `in ('true')` was a substring test
    # against the *string* 'true' (parentheses alone do not build a tuple),
    # so '', 't' or 'ru' were all treated as True.
    return v.lower() == 'true'
|
625 |
+
|
626 |
+
|
627 |
+
def obey_lipinski(mol):
    """
    Counts how many of the five drug-likeness rules below a molecule satisfies.

    The rules checked are: exact molecular weight < 500, at most 5 hydrogen
    bond donors, at most 10 hydrogen bond acceptors, logP within [-2, 5],
    and at most 10 rotatable bonds.

    Args:
        mol (RDKit Mol): Molecule object. It is copied before sanitization,
            so the caller's object is not modified.

    Returns:
        int: Number of rules satisfied (0-5).
    """
    mol = deepcopy(mol)
    Chem.SanitizeMol(mol)
    rule_1 = Descriptors.ExactMolWt(mol) < 500
    rule_2 = Lipinski.NumHDonors(mol) <= 5
    rule_3 = Lipinski.NumHAcceptors(mol) <= 10
    # Bind the numeric logP first. The former
    # `(logp := Crippen.MolLogP(mol) >= -2) & (logp <= 5)` assigned the
    # *boolean* comparison result to logp (":=" binds looser than ">="),
    # which made `logp <= 5` always true and dropped the upper bound.
    logp = Crippen.MolLogP(mol)
    rule_4 = -2 <= logp <= 5
    rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(mol) <= 10
    return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])
|
647 |
+
|
648 |
+
|
649 |
+
def obey_veber(mol):
    """
    Counts how many of the two Veber criteria a molecule satisfies.

    The criteria checked are: at most 10 rotatable bonds and a topological
    polar surface area of at most 140.

    Args:
        mol (RDKit Mol): Molecule object. A copy is sanitized, so the
            caller's object is left untouched.

    Returns:
        int: Number of criteria satisfied (0-2).
    """
    working_copy = deepcopy(mol)
    Chem.SanitizeMol(working_copy)
    few_rotatable_bonds = rdMolDescriptors.CalcNumRotatableBonds(working_copy) <= 10
    small_polar_surface = rdMolDescriptors.CalcTPSA(working_copy) <= 140
    return np.sum([int(few_rotatable_bonds), int(small_polar_surface)])
|
666 |
+
|
667 |
+
|
668 |
+
def load_pains_filters():
    """
    Builds a filter catalog containing the PAINS A, B and C filter sets.

    Returns:
        FilterCatalog: An RDKit FilterCatalog object containing PAINS filters.
    """
    params = FilterCatalog.FilterCatalogParams()
    available = FilterCatalog.FilterCatalogParams.FilterCatalogs
    # Register the three PAINS families in the same A, B, C order.
    for pains_family in (available.PAINS_A, available.PAINS_B, available.PAINS_C):
        params.AddCatalog(pains_family)
    return FilterCatalog.FilterCatalog(params)
|
681 |
+
|
682 |
+
|
683 |
+
def is_pains(mol, catalog):
    """
    Tells whether a molecule matches any entry of the given filter catalog.

    Args:
        mol (RDKit Mol): Molecule object.
        catalog (FilterCatalog): A catalog of PAINS filters.

    Returns:
        bool: True if the catalog reports a first match for the molecule,
        else False.
    """
    return catalog.GetFirstMatch(mol) is not None
|
696 |
+
|
697 |
+
|
698 |
+
def mapper(n_jobs):
    """
    Picks a map-like callable for serial or pooled execution.

    * ``n_jobs == 1``: a wrapper around the built-in ``map`` that returns a list.
    * any other ``int``: a wrapper around ``Pool(n_jobs).map``; the pool is
      created immediately and terminated after the wrapper's first call.
    * anything else: ``n_jobs`` is treated as a pool-like object and its
      bound ``map`` method is returned as is.

    Args:
        n_jobs (int or pool object): Number of jobs or a Pool instance.

    Returns:
        callable: A function that acts like map.
    """
    if n_jobs == 1:
        def _serial_map(*args, **kwargs):
            return list(map(*args, **kwargs))
        return _serial_map

    if not isinstance(n_jobs, int):
        # Caller supplied its own pool; it stays responsible for closing it.
        return n_jobs.map

    pool = Pool(n_jobs)

    def _pooled_map(*args, **kwargs):
        try:
            return pool.map(*args, **kwargs)
        finally:
            # The pool is single-use: shut the workers down even on failure.
            pool.terminate()

    return _pooled_map
|
725 |
+
|
726 |
+
|
727 |
+
def fragmenter(mol):
    """
    Splits a molecule on its BRICS bonds and lists the fragment SMILES.

    Args:
        mol (str or RDKit Mol): Input molecule; converted with get_mol first.

    Returns:
        list: Fragment SMILES strings, obtained by splitting the SMILES of
        the fragmented molecule on '.'.
    """
    broken_mol = AllChem.FragmentOnBRICSBonds(get_mol(mol))
    return Chem.MolToSmiles(broken_mol).split(".")
|
740 |
+
|
741 |
+
|
742 |
+
def get_mol(smiles_or_mol):
    """
    Turns a SMILES string into a sanitized RDKit molecule.

    Non-string inputs are assumed to be molecules already and are returned
    unchanged.

    Args:
        smiles_or_mol (str or RDKit Mol): SMILES string or molecule.

    Returns:
        RDKit Mol or None: The parsed and sanitized molecule; None when the
        string is empty, cannot be parsed, or sanitization raises ValueError.
    """
    if not isinstance(smiles_or_mol, str):
        return smiles_or_mol

    if not smiles_or_mol:
        return None

    parsed = Chem.MolFromSmiles(smiles_or_mol)
    if parsed is None:
        return None

    try:
        Chem.SanitizeMol(parsed)
    except ValueError:
        return None
    return parsed
|
764 |
+
|
765 |
+
|
766 |
+
def compute_fragments(mol_list, n_jobs=1):
    """
    Tallies the BRICS fragments found across a list of molecules.

    Args:
        mol_list (list): List of molecules (SMILES or RDKit Mol).
        n_jobs (int): Number of parallel jobs, forwarded to mapper.

    Returns:
        Counter: A Counter dictionary mapping fragment SMILES to counts.
    """
    map_ = mapper(n_jobs)
    tally = Counter()
    for fragment_smiles in map_(fragmenter, mol_list):
        tally.update(fragment_smiles)
    return tally
|
781 |
+
|
782 |
+
|
783 |
+
def compute_scaffolds(mol_list, n_jobs=1, min_rings=2):
    """
    Tallies the scaffolds (as SMILES) found across a list of molecules.

    Molecules for which compute_scaffold yields None are not counted.

    Args:
        mol_list (list): List of molecules.
        n_jobs (int): Number of parallel jobs, forwarded to mapper.
        min_rings (int): Minimum number of rings required in a scaffold.

    Returns:
        Counter: A Counter mapping scaffold SMILES to counts.
    """
    map_ = mapper(n_jobs)
    scaffold_of = partial(compute_scaffold, min_rings=min_rings)
    tally = Counter(map_(scaffold_of, mol_list))
    # Drop the bucket that collected the rejected molecules, if any.
    tally.pop(None, None)
    return tally
|
803 |
+
|
804 |
+
|
805 |
+
def get_n_rings(mol):
    """
    Reports how many rings the molecule's ring info contains.

    Args:
        mol (RDKit Mol): Molecule object.

    Returns:
        int: Number of rings.
    """
    ring_info = mol.GetRingInfo()
    return ring_info.NumRings()
|
816 |
+
|
817 |
+
|
818 |
+
def compute_scaffold(mol, min_rings=2):
    """
    Computes the Murcko scaffold of a molecule and returns its canonical SMILES if it has enough rings.

    Args:
        mol (str or RDKit Mol): Input molecule; converted with get_mol first.
        min_rings (int): Minimum number of rings required.

    Returns:
        str or None: SMILES of the scaffold, or None when the input cannot
        be converted to a molecule, scaffold extraction fails, the scaffold
        SMILES is empty, or the scaffold has fewer than min_rings rings.
    """
    mol = get_mol(mol)
    # get_mol returns None for empty/unparsable SMILES. The except clause
    # below only covers ValueError/RuntimeError, so do not rely on it to
    # absorb a None argument; report "no scaffold" explicitly instead.
    if mol is None:
        return None
    try:
        scaffold = MurckoScaffold.GetScaffoldForMol(mol)
    except (ValueError, RuntimeError):
        return None
    n_rings = get_n_rings(scaffold)
    scaffold_smiles = Chem.MolToSmiles(scaffold)
    if scaffold_smiles == '' or n_rings < min_rings:
        return None
    return scaffold_smiles
|
839 |
+
|
840 |
+
|
841 |
+
class Metric:
    """
    Base class for metrics comparing a reference and a generated molecule set.

    Subclasses provide ``precalc`` (molecules -> intermediate representation)
    and ``metric`` (two representations -> value).
    """
    def __init__(self, n_jobs=1, device='cpu', batch_size=512, **kwargs):
        self.n_jobs = n_jobs
        self.device = device
        self.batch_size = batch_size
        # Any extra keyword argument becomes an attribute of the same name.
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)

    def __call__(self, ref=None, gen=None, pref=None, pgen=None):
        """
        Evaluates the metric for a reference/generated pair.

        Each side must be given in exactly one form: raw molecules
        (ref / gen) or an already precalculated representation (pref / pgen).

        Args:
            ref: Reference molecule list.
            gen: Generated molecule list.
            pref: Precalculated reference metric.
            pgen: Precalculated generated metric.

        Returns:
            Whatever ``self.metric`` returns for the two representations.
        """
        assert (ref is None) != (pref is None), "specify ref xor pref"
        assert (gen is None) != (pgen is None), "specify gen xor pgen"
        reference_repr = self.precalc(ref) if pref is None else pref
        generated_repr = self.precalc(gen) if pgen is None else pgen
        return self.metric(reference_repr, generated_repr)

    def precalc(self, molecules):
        """
        Builds the intermediate representation for a list of molecules.
        Must be overridden by subclasses.
        """
        raise NotImplementedError

    def metric(self, pref, pgen):
        """
        Computes the metric value from two precalculated representations.
        Must be overridden by subclasses.
        """
        raise NotImplementedError
|
890 |
+
|
891 |
+
|
892 |
+
class FragMetric(Metric):
    """
    Metric comparing the fragment counts of two molecule sets.
    """
    def precalc(self, mols):
        # Fragment counts are stored under the 'frag' key.
        fragment_counts = compute_fragments(mols, n_jobs=self.n_jobs)
        return {'frag': fragment_counts}

    def metric(self, pref, pgen):
        ref_counts, gen_counts = pref['frag'], pgen['frag']
        return cos_similarity(ref_counts, gen_counts)
|
901 |
+
|
902 |
+
|
903 |
+
class ScafMetric(Metric):
    """
    Metric comparing the scaffold counts of two molecule sets.
    """
    def precalc(self, mols):
        # Scaffold counts are stored under the 'scaf' key.
        scaffold_counts = compute_scaffolds(mols, n_jobs=self.n_jobs)
        return {'scaf': scaffold_counts}

    def metric(self, pref, pgen):
        ref_counts, gen_counts = pref['scaf'], pgen['scaf']
        return cos_similarity(ref_counts, gen_counts)
|
912 |
+
|
913 |
+
|
914 |
+
def cos_similarity(ref_counts, gen_counts):
    """
    Computes the cosine similarity between two count mappings.

    Both mappings are expanded onto the union of their keys (missing keys
    count as 0) before the comparison.

    Args:
        ref_counts (dict): Reference counts, keyed by item.
        gen_counts (dict): Generated counts, keyed by item.

    Returns:
        float: 1 minus the cosine distance of the two count vectors, or NaN
        if either mapping is empty.
    """
    if 0 in (len(ref_counts), len(gen_counts)):
        return np.nan

    all_keys = np.unique([*ref_counts.keys(), *gen_counts.keys()])
    ref_vec, gen_vec = (
        np.array([counts.get(key, 0) for key in all_keys])
        for counts in (ref_counts, gen_counts)
    )
    return 1 - cos_distance(ref_vec, gen_vec)
|
train.py
ADDED
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import random
|
4 |
+
import pickle
|
5 |
+
import argparse
|
6 |
+
import os.path as osp
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.utils.data
|
10 |
+
from torch import nn
|
11 |
+
from torch_geometric.loader import DataLoader
|
12 |
+
|
13 |
+
import wandb
|
14 |
+
from rdkit import RDLogger
|
15 |
+
|
16 |
+
torch.set_num_threads(5)
|
17 |
+
RDLogger.DisableLog('rdApp.*')
|
18 |
+
|
19 |
+
from src.util.utils import *
|
20 |
+
from src.model.models import Generator, Discriminator, simple_disc
|
21 |
+
from src.data.dataset import DruggenDataset
|
22 |
+
from src.data.utils import get_encoders_decoders, load_molecules
|
23 |
+
from src.model.loss import discriminator_loss, generator_loss
|
24 |
+
|
25 |
+
class Train(object):
|
26 |
+
"""Trainer for DrugGEN."""
|
27 |
+
|
28 |
+
def __init__(self, config):
    """Read the run configuration, build the datasets/loaders and the networks.

    The constructor copies the fields it needs from ``config`` onto the
    instance, derives the processed-dataset file names from the raw SMILES
    file names and ``max_atom``, creates two ``DruggenDataset`` objects with
    matching PyG ``DataLoader``s, and finally calls ``build_model``.

    Args:
        config: Object exposing the configuration attributes read below
            (``set_seed``, ``seed``, ``raw_file``, ``drug_raw_file``,
            ``max_atom``, ...). Presumably an ``argparse`` namespace --
            confirm against the caller.
    """
    # NOTE(review): indentation was lost in this view; every statement down
    # to the "Using seed" print is assumed to belong to this `if` -- confirm.
    if config.set_seed:
        # `np` is not imported by name in this file; presumably it arrives
        # through the star import from src.util.utils -- confirm.
        np.random.seed(config.seed)
        random.seed(config.seed)
        torch.manual_seed(config.seed)
        torch.cuda.manual_seed_all(config.seed)

        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        os.environ["PYTHONHASHSEED"] = str(config.seed)

        print(f'Using seed {config.seed}')

    # Use CUDA when available, otherwise fall back to the CPU.
    self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

    # Initialize configurations
    self.submodel = config.submodel

    # Data loader.
    self.raw_file = config.raw_file  # SMILES containing text file for dataset.
                                     # Write the full path to file.
    self.drug_raw_file = config.drug_raw_file  # SMILES containing text file for second dataset.
                                               # Write the full path to file.

    # Automatically infer dataset file names from raw file names
    raw_file_basename = osp.basename(self.raw_file)
    drug_raw_file_basename = osp.basename(self.drug_raw_file)

    # Get the base name without extension and add max_atom to it
    self.max_atom = config.max_atom  # Model is based on one-shot generation.
    raw_file_base = os.path.splitext(raw_file_basename)[0]
    drug_raw_file_base = os.path.splitext(drug_raw_file_basename)[0]

    # Processed dataset file name: "<raw base name><max_atom>.pt"
    self.dataset_file = f"{raw_file_base}{self.max_atom}.pt"
    self.drugs_dataset_file = f"{drug_raw_file_base}{self.max_atom}.pt"

    self.mol_data_dir = config.mol_data_dir  # Directory where the dataset files are stored.
    self.drug_data_dir = config.drug_data_dir  # Directory where the drug dataset files are stored.
    # Dataset names are the part of the file name before the first ".".
    self.dataset_name = self.dataset_file.split(".")[0]
    self.drugs_dataset_name = self.drugs_dataset_file.split(".")[0]
    self.features = config.features  # Small model uses atom types as node features. (Boolean, False uses atom types only.)
                                     # Additional node features can be added. Please check new_dataloarder.py Line 102.
    self.batch_size = config.batch_size  # Batch size for training.

    # Read later by build_model to decide whether to wrap the networks in nn.DataParallel.
    self.parallel = config.parallel

    # Get atom and bond encoders/decoders
    atom_encoder, atom_decoder, bond_encoder, bond_decoder = get_encoders_decoders(
        self.raw_file,
        self.drug_raw_file,
        self.max_atom
    )
    self.atom_encoder = atom_encoder
    self.atom_decoder = atom_decoder
    self.bond_encoder = bond_encoder
    self.bond_decoder = bond_decoder

    # First dataset, built from raw_file; both datasets share the same encoders/decoders.
    self.dataset = DruggenDataset(self.mol_data_dir,
                                  self.dataset_file,
                                  self.raw_file,
                                  self.max_atom,
                                  self.features,
                                  atom_encoder=atom_encoder,
                                  atom_decoder=atom_decoder,
                                  bond_encoder=bond_encoder,
                                  bond_decoder=bond_decoder)

    self.loader = DataLoader(self.dataset,
                             shuffle=True,
                             batch_size=self.batch_size,
                             drop_last=True)  # PyG dataloader for the GAN.

    # Second dataset, built from drug_raw_file.
    self.drugs = DruggenDataset(self.drug_data_dir,
                                self.drugs_dataset_file,
                                self.drug_raw_file,
                                self.max_atom,
                                self.features,
                                atom_encoder=atom_encoder,
                                atom_decoder=atom_decoder,
                                bond_encoder=bond_encoder,
                                bond_decoder=bond_decoder)

    self.drugs_loader = DataLoader(self.drugs,
                                   shuffle=True,
                                   batch_size=self.batch_size,
                                   drop_last=True)  # PyG dataloader for the second GAN.

    # Sizes are taken from the decoders, or from the first sample's `x` tensor of the first dataset.
    self.m_dim = len(self.atom_decoder) if not self.features else int(self.loader.dataset[0].x.shape[1])  # Atom type dimension.
    self.b_dim = len(self.bond_decoder)  # Bond type dimension.
    self.vertexes = int(self.loader.dataset[0].x.shape[0])  # Number of nodes in the graph.

    # Model configurations.
    self.act = config.act
    self.lambda_gp = config.lambda_gp
    self.dim = config.dim
    self.depth = config.depth
    self.heads = config.heads
    self.mlp_ratio = config.mlp_ratio
    self.ddepth = config.ddepth
    self.ddropout = config.ddropout

    # Training configurations.
    self.epoch = config.epoch
    self.g_lr = config.g_lr
    self.d_lr = config.d_lr
    self.dropout = config.dropout
    self.beta1 = config.beta1
    self.beta2 = config.beta2

    # Directories.
    self.log_dir = config.log_dir
    self.sample_dir = config.sample_dir
    self.model_save_dir = config.model_save_dir

    # Step size.
    self.log_step = config.log_sample_step

    # resume training
    self.resume = config.resume
    self.resume_epoch = config.resume_epoch
    self.resume_iter = config.resume_iter
    self.resume_directory = config.resume_directory

    # wandb configuration
    self.use_wandb = config.use_wandb
    self.online = config.online
    self.exp_name = config.exp_name

    # Run identifier built from the main hyper-parameters; used as a sub-directory name by build_model.
    self.arguments = "{}_{}_glr{}_dlr{}_dim{}_depth{}_heads{}_batch{}_epoch{}_dataset{}_dropout{}".format(self.exp_name, self.submodel, self.g_lr, self.d_lr, self.dim, self.depth, self.heads, self.batch_size, self.epoch, self.dataset_name, self.dropout)

    # Creates self.G / self.D and their optimizers.
    self.build_model(self.model_save_dir, self.arguments)
|
162 |
+
|
163 |
+
|
164 |
+
def build_model(self, model_save_dir, arguments):
    """Create the generator, the discriminator and their optimizers.

    Both networks receive the same leading positional arguments
    (activation, number of graph nodes, number of bond types, node feature
    dimension) followed by their own dropout value, and the keyword
    arguments ``dim``, ``depth``, ``heads`` and ``mlp_ratio``. The
    discriminator uses ``self.ddropout`` / ``self.ddepth`` where the
    generator uses ``self.dropout`` / ``self.depth``.

    A parameter summary of each network is written under
    ``<model_save_dir>/<arguments>`` via ``print_network``. The networks
    are wrapped in ``nn.DataParallel`` when ``self.parallel`` is set and
    more than one CUDA device is visible, then moved to ``self.device``.

    Args:
        model_save_dir: Root directory for model artifacts.
        arguments: Run identifier used as the sub-directory name.
    """
    # Leading positional arguments common to Generator and Discriminator.
    shared_positional = (self.act, self.vertexes, self.b_dim, self.m_dim)

    self.G = Generator(*shared_positional,
                       self.dropout,
                       dim=self.dim,
                       depth=self.depth,
                       heads=self.heads,
                       mlp_ratio=self.mlp_ratio)

    self.D = Discriminator(*shared_positional,
                           self.ddropout,
                           dim=self.dim,
                           depth=self.ddepth,
                           heads=self.heads,
                           mlp_ratio=self.mlp_ratio)

    # Optimizers are created before any DataParallel wrapping.
    self.g_optimizer = torch.optim.AdamW(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
    self.d_optimizer = torch.optim.AdamW(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])

    summary_dir = os.path.join(model_save_dir, arguments)
    for network, tag in ((self.G, 'G'), (self.D, 'D')):
        self.print_network(network, tag, summary_dir)

    if self.parallel and torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        self.G = nn.DataParallel(self.G)
        self.D = nn.DataParallel(self.D)

    self.G.to(self.device)
    self.D.to(self.device)
|
227 |
+
|
228 |
+
def print_network(self, model, name, save_dir):
    """Print out the network information.

    Sums the element counts of all parameters of ``model``, creates
    ``save_dir`` if it does not exist, and writes a summary (module class
    name, parameter names with their sizes, total parameter count) to
    ``<save_dir>/<name>_modules.txt`` while echoing the same lines to
    stdout.

    Args:
        model: Network exposing ``parameters()``, ``modules()`` and
            ``named_parameters()`` on its modules.
        name: Prefix of the summary file name.
        save_dir: Directory the summary file is written to.
    """
    num_params = 0
    for p in model.parameters():
        num_params += p.numel()

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    network_path = os.path.join(save_dir, "{}_modules.txt".format(name))
    with open(network_path, "w+") as file:
        for module in model.modules():
            file.write(f"{module.__class__.__name__}:\n")
            print(module.__class__.__name__)
            for n, param in module.named_parameters():
                if param is not None:
                    file.write(f" - {n}: {param.size()}\n")
                    print(f" - {n}: {param.size()}")
            # NOTE(review): indentation was lost in this view. The break is
            # assumed to end the outer loop after the first module yielded
            # by model.modules() -- confirm against the original file.
            break
        file.write(f"Total number of parameters: {num_params}\n")
        print(f"Total number of parameters: {num_params}\n\n")
|
249 |
+
|
250 |
+
def restore_model(self, epoch, iteration, model_directory):
|
251 |
+
"""Restore the trained generator and discriminator."""
|
252 |
+
print('Loading the trained models from epoch / iteration {}-{}...'.format(epoch, iteration))
|
253 |
+
|
254 |
+
G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(epoch, iteration))
|
255 |
+
D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(epoch, iteration))
|
256 |
+
self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
|
257 |
+
self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))
|
258 |
+
|
259 |
+
def save_model(self, model_directory, idx,i):
|
260 |
+
G_path = os.path.join(model_directory, '{}-{}-G.ckpt'.format(idx+1,i+1))
|
261 |
+
D_path = os.path.join(model_directory, '{}-{}-D.ckpt'.format(idx+1,i+1))
|
262 |
+
torch.save(self.G.state_dict(), G_path)
|
263 |
+
torch.save(self.D.state_dict(), D_path)
|
264 |
+
|
265 |
+
def reset_grad(self):
|
266 |
+
"""Reset the gradient buffers."""
|
267 |
+
self.g_optimizer.zero_grad()
|
268 |
+
self.d_optimizer.zero_grad()
|
269 |
+
|
270 |
+
def train(self, config):
    """Run the GAN training loop.

    Steps performed, in order:

    1. Initialise a wandb run (mode ``online``/``offline``/``disabled``
       depending on ``self.use_wandb`` and ``self.online``) and register the
       ``G_modules.txt`` / ``D_modules.txt`` summaries with it.
    2. Derive and create the model and sample directories and the log path
       from ``self.arguments``.
    3. Read the SMILES strings in ``self.drug_raw_file`` and build Morgan
       fingerprints for those RDKit can parse; both are handed to
       ``logging`` for metric calculation.
    4. Optionally restore a checkpoint when ``self.resume`` is set.
    5. For ``self.epoch`` epochs, iterate over ``self.loader`` (cycling
       ``self.drugs_loader`` alongside it) and perform one discriminator
       update followed by one generator update per batch.

    Args:
        config: Run configuration; only forwarded to ``wandb.init`` as the
            run's ``config``.

    Raises:
        ValueError: If ``self.submodel`` is neither ``"DrugGEN"`` nor
            ``"NoTarget"``.
    """
    # Fail before any side effects instead of hitting an unbound
    # DISC_node/DISC_edge in the middle of the first batch.
    if self.submodel not in ("DrugGEN", "NoTarget"):
        raise ValueError(
            "Unknown submodel {!r}; expected 'DrugGEN' or 'NoTarget'".format(self.submodel)
        )

    if self.use_wandb:
        mode = 'online' if self.online else 'offline'
    else:
        mode = 'disabled'
    kwargs = {'name': self.exp_name, 'project': 'druggen', 'config': config,
              'settings': wandb.Settings(_disable_stats=True), 'reinit': True, 'mode': mode, 'save_code': True}
    wandb.init(**kwargs)

    wandb.save(os.path.join(self.model_save_dir, self.arguments, "G_modules.txt"))
    wandb.save(os.path.join(self.model_save_dir, self.arguments, "D_modules.txt"))

    self.model_directory = os.path.join(self.model_save_dir, self.arguments)
    self.sample_directory = os.path.join(self.sample_dir, self.arguments)
    self.log_path = os.path.join(self.log_dir, "{}.txt".format(self.arguments))
    if not os.path.exists(self.model_directory):
        os.makedirs(self.model_directory)
    if not os.path.exists(self.sample_directory):
        os.makedirs(self.sample_directory)

    # smiles data for metrics calculation.
    # The context manager closes the file; the previous bare open() leaked the handle.
    with open(self.drug_raw_file, 'r') as smiles_file:
        drug_smiles = smiles_file.read().splitlines()
    drug_mols = [Chem.MolFromSmiles(smi) for smi in drug_smiles]
    # Entries RDKit could not parse come back as None and are skipped.
    drug_vecs = [AllChem.GetMorganFingerprintAsBitVect(x, 2, nBits=1024) for x in drug_mols if x is not None]

    if self.resume:
        self.restore_model(self.resume_epoch, self.resume_iter, self.resume_directory)

    # Start training.
    print('Start training...')
    self.start_time = time.time()
    for idx in range(self.epoch):
        # =================================================================================== #
        #                             1. Preprocess input data                                #
        # =================================================================================== #
        # Load the data
        dataloader_iterator = iter(self.drugs_loader)

        wandb.log({"epoch": idx})

        for i, data in enumerate(self.loader):
            # The drug loader may be shorter than the main loader: restart it
            # when exhausted so every batch is paired with a drug batch.
            try:
                drugs = next(dataloader_iterator)
            except StopIteration:
                dataloader_iterator = iter(self.drugs_loader)
                drugs = next(dataloader_iterator)

            wandb.log({"iter": i})

            # Preprocess both dataset
            _, a_tensor, x_tensor = load_molecules(
                data=data,
                batch_size=self.batch_size,
                device=self.device,
                b_dim=self.b_dim,
                m_dim=self.m_dim,
            )

            _, drugs_a_tensor, drugs_x_tensor = load_molecules(
                data=drugs,
                batch_size=self.batch_size,
                device=self.device,
                b_dim=self.b_dim,
                m_dim=self.m_dim,
            )

            # Training configuration.
            GEN_node = x_tensor  # Generator input node features (annotation matrix of real molecules)
            GEN_edge = a_tensor  # Generator input edge features (adjacency matrix of real molecules)
            if self.submodel == "DrugGEN":
                DISC_node = drugs_x_tensor  # Discriminator input node features (annotation matrix of drug molecules)
                DISC_edge = drugs_a_tensor  # Discriminator input edge features (adjacency matrix of drug molecules)
            elif self.submodel == "NoTarget":
                DISC_node = x_tensor  # Discriminator input node features (annotation matrix of real molecules)
                DISC_edge = a_tensor  # Discriminator input edge features (adjacency matrix of real molecules)

            # =================================================================================== #
            #                                 2. Train the GAN                                    #
            # =================================================================================== #

            loss = {}
            self.reset_grad()
            # Compute discriminator loss.
            _, _, d_loss = discriminator_loss(self.G,
                                              self.D,
                                              DISC_edge,
                                              DISC_node,
                                              GEN_edge,
                                              GEN_node,
                                              self.batch_size,
                                              self.device,
                                              self.lambda_gp)
            d_total = d_loss
            wandb.log({"d_loss": d_total.item()})

            loss["d_total"] = d_total.item()
            d_total.backward()
            self.d_optimizer.step()

            self.reset_grad()

            # Compute generator loss.
            generator_output = generator_loss(self.G,
                                              self.D,
                                              GEN_edge,
                                              GEN_node,
                                              self.batch_size)
            g_loss, _, _, node_sample, edge_sample = generator_output
            g_total = g_loss
            wandb.log({"g_loss": g_total.item()})

            loss["g_total"] = g_total.item()
            g_total.backward()
            self.g_optimizer.step()

            # Logging.
            # NOTE(review): sampling and checkpointing are taken to belong to
            # the log_step branch (the source indentation was lost) - confirm.
            if (i+1) % self.log_step == 0:
                logging(self.log_path, self.start_time, i, idx, loss, self.sample_directory,
                        drug_smiles, edge_sample, node_sample, self.dataset.matrices2mol,
                        self.dataset_name, a_tensor, x_tensor, drug_vecs)

                mol_sample(self.sample_directory, edge_sample.detach(), node_sample.detach(),
                           idx, i, self.dataset.matrices2mol, self.dataset_name)
                print("samples saved at epoch {} and iteration {}".format(idx,i))

                self.save_model(self.model_directory, idx, i)
                print("model saved at epoch {} and iteration {}".format(idx,i))
|
398 |
+
|
399 |
+
|
400 |
+
if __name__ == '__main__':
    # Command-line entry point: parse the options, apply the submodel-specific
    # rules for --drug_raw_file, then build the trainer and start training.

    def _str2bool(value):
        """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

        ``type=bool`` treats every non-empty string (including ``"False"``)
        as ``True``; this parser interprets the text instead.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        # The empty string stays False, as it was under bool('').
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    parser = argparse.ArgumentParser()

    # Data configuration.
    parser.add_argument('--raw_file', type=str, required=True)
    parser.add_argument('--drug_raw_file', type=str, required=False, help='Required for DrugGEN model, optional for NoTarget')
    parser.add_argument('--drug_data_dir', type=str, default='data')
    parser.add_argument('--mol_data_dir', type=str, default='data')
    parser.add_argument('--features', action='store_true', help='features dimension for nodes')

    # Model configuration.
    parser.add_argument('--submodel', type=str, default="DrugGEN", help="Chose model subtype: DrugGEN, NoTarget", choices=['DrugGEN', 'NoTarget'])
    parser.add_argument('--act', type=str, default="relu", help="Activation function for the model.", choices=['relu', 'tanh', 'leaky', 'sigmoid'])
    parser.add_argument('--max_atom', type=int, default=45, help='Max atom number for molecules must be specified.')
    parser.add_argument('--dim', type=int, default=128, help='Dimension of the Transformer Encoder model for the GAN.')
    parser.add_argument('--depth', type=int, default=1, help='Depth of the Transformer model from the GAN.')
    parser.add_argument('--ddepth', type=int, default=1, help='Depth of the Transformer model from the discriminator.')
    parser.add_argument('--heads', type=int, default=8, help='Number of heads for the MultiHeadAttention module from the GAN.')
    parser.add_argument('--mlp_ratio', type=int, default=3, help='MLP ratio for the Transformer.')
    parser.add_argument('--dropout', type=float, default=0., help='dropout rate')
    parser.add_argument('--ddropout', type=float, default=0., help='dropout rate for the discriminator')
    parser.add_argument('--lambda_gp', type=float, default=10, help='Gradient penalty lambda multiplier for the GAN.')

    # Training configuration.
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for the training.')
    parser.add_argument('--epoch', type=int, default=10, help='Epoch number for Training.')
    parser.add_argument('--g_lr', type=float, default=0.00001, help='learning rate for G')
    parser.add_argument('--d_lr', type=float, default=0.00001, help='learning rate for D')
    parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for Adam optimizer')
    parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer')
    parser.add_argument('--log_dir', type=str, default='experiments/logs')
    parser.add_argument('--sample_dir', type=str, default='experiments/samples')
    parser.add_argument('--model_save_dir', type=str, default='experiments/models')
    parser.add_argument('--log_sample_step', type=int, default=1000, help='step size for sampling during training')

    # Resume training.
    # Accepts both "--resume" (flag form) and "--resume true|false"; the
    # explicit parser makes "--resume False" actually disable resuming.
    parser.add_argument('--resume', type=_str2bool, nargs='?', const=True, default=False, help='resume training')
    parser.add_argument('--resume_epoch', type=int, default=None, help='resume training from this epoch')
    parser.add_argument('--resume_iter', type=int, default=None, help='resume training from this step')
    parser.add_argument('--resume_directory', type=str, default=None, help='load pretrained weights from this directory')

    # Seed configuration.
    parser.add_argument('--set_seed', action='store_true', help='set seed for reproducibility')
    parser.add_argument('--seed', type=int, default=1, help='seed for reproducibility')

    # wandb configuration.
    parser.add_argument('--use_wandb', action='store_true', help='use wandb for logging')
    parser.add_argument('--online', action='store_true', help='use wandb online')
    parser.add_argument('--exp_name', type=str, default='druggen', help='experiment name')
    parser.add_argument('--parallel', action='store_true', help='Parallelize training')

    config = parser.parse_args()

    # Check if drug_raw_file is provided when using DrugGEN model
    if config.submodel == "DrugGEN" and not config.drug_raw_file:
        parser.error("--drug_raw_file is required when using DrugGEN model")

    # If using NoTarget model and drug_raw_file is not provided, fall back to
    # the bundled AKT reference file. Per the original note it is not used for
    # training; it is only supplied for ease of use and the encoders/decoders.
    if config.submodel == "NoTarget" and not config.drug_raw_file:
        config.drug_raw_file = "data/akt_train.smi"

    # Resuming builds checkpoint paths from these three values; report a
    # missing one here rather than failing later while loading weights.
    if config.resume and (config.resume_epoch is None
                          or config.resume_iter is None
                          or config.resume_directory is None):
        parser.error("--resume requires --resume_epoch, --resume_iter and --resume_directory")

    trainer = Train(config)
    trainer.train(config)
|