ipd committed on
Commit 5306c2a · 1 Parent(s): 468214c
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. INTRODUCTION.md +9 -0
  2. app.py +523 -0
  3. data/.DS_Store +0 -0
  4. data/bace/test.csv +0 -0
  5. data/bace/train.csv +0 -0
  6. data/bace/valid.csv +0 -0
  7. data/esol/test.csv +109 -0
  8. data/esol/train.csv +0 -0
  9. data/lce/test.csv +31 -0
  10. data/lce/test_data.csv +14 -0
  11. data/lce/train.csv +121 -0
  12. data/lce/train_data.csv +148 -0
  13. models/.DS_Store +0 -0
  14. models/.gitattributes +3 -0
  15. models/fm4m.py +964 -0
  16. models/mhg_model/.DS_Store +0 -0
  17. models/mhg_model/README.md +75 -0
  18. models/mhg_model/__init__.py +5 -0
  19. models/mhg_model/graph_grammar/__init__.py +19 -0
  20. models/mhg_model/graph_grammar/algo/__init__.py +20 -0
  21. models/mhg_model/graph_grammar/algo/tree_decomposition.py +821 -0
  22. models/mhg_model/graph_grammar/graph_grammar/__init__.py +20 -0
  23. models/mhg_model/graph_grammar/graph_grammar/base.py +30 -0
  24. models/mhg_model/graph_grammar/graph_grammar/corpus.py +152 -0
  25. models/mhg_model/graph_grammar/graph_grammar/hrg.py +1065 -0
  26. models/mhg_model/graph_grammar/graph_grammar/symbols.py +180 -0
  27. models/mhg_model/graph_grammar/graph_grammar/utils.py +130 -0
  28. models/mhg_model/graph_grammar/hypergraph.py +544 -0
  29. models/mhg_model/graph_grammar/io/__init__.py +20 -0
  30. models/mhg_model/graph_grammar/io/smi.py +559 -0
  31. models/mhg_model/graph_grammar/nn/__init__.py +11 -0
  32. models/mhg_model/graph_grammar/nn/dataset.py +121 -0
  33. models/mhg_model/graph_grammar/nn/decoder.py +158 -0
  34. models/mhg_model/graph_grammar/nn/encoder.py +199 -0
  35. models/mhg_model/graph_grammar/nn/graph.py +313 -0
  36. models/mhg_model/load.py +103 -0
  37. models/mhg_model/mhg_gnn.egg-info/PKG-INFO +102 -0
  38. models/mhg_model/mhg_gnn.egg-info/SOURCES.txt +46 -0
  39. models/mhg_model/mhg_gnn.egg-info/dependency_links.txt +1 -0
  40. models/mhg_model/mhg_gnn.egg-info/requires.txt +7 -0
  41. models/mhg_model/mhg_gnn.egg-info/top_level.txt +2 -0
  42. models/mhg_model/models/__init__.py +5 -0
  43. models/mhg_model/models/mhgvae.py +956 -0
  44. models/mhg_model/notebooks/mhg-gnn_encoder_decoder_example.ipynb +114 -0
  45. models/mhg_model/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf +3 -0
  46. models/mhg_model/setup.cfg +37 -0
  47. models/mhg_model/setup.py +6 -0
  48. models/selfies_model/README.md +87 -0
  49. models/selfies_model/load.py +102 -0
  50. models/selfies_model/requirements.txt +12 -0
INTRODUCTION.md ADDED
@@ -0,0 +1,9 @@
+ # Foundation Models for Materials - FM4M
+
+ FM4M adopts a modular architecture designed for flexible extensibility. As illustrated in the figure below, it comprises both uni-modal and fused models. Each uni-modal model is pre-trained independently for its respective modality (e.g., SMILES), and users can access individual functionalities directly from the corresponding model directory (e.g., smi-ted/). Some of these uni-modal models can be "late-fused" using fusion algorithms, creating a more powerful multi-modal feature representation for downstream predictions. To simplify usage, we provide fm4m-kit, a wrapper that enables users to easily access the capabilities of all models through straightforward methods. These models are also available on Hugging Face, where they can be accessed via an intuitive and user-friendly GUI.
+
+ <p align="center">
+ <img src="gradio_api/file=img/introduction.png" alt="FM4M Overview" width="50%"/>
+ </p>
+
+
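As a rough illustration of the fm4m-kit wrapper described above, the sketch below embeds SMILES with a single uni-modal encoder and fits a simple regressor on the ESOL split added in this commit. The module path `models.fm4m` matches the file added here, but `get_representation` and its arguments are assumed names for illustration, not a confirmed API.

```python
# Hypothetical sketch: fm4m.get_representation and its signature are assumptions,
# not taken from this commit. Paths and column names match data/esol/*.csv and app.py.
import pandas as pd
from sklearn.linear_model import Ridge

import models.fm4m as fm4m  # wrapper module added in this commit (models/fm4m.py)

train = pd.read_csv("data/esol/train.csv")
test = pd.read_csv("data/esol/test.csv")

# Assumed call: map SMILES strings to fixed-size embeddings with one encoder.
x_train = fm4m.get_representation(train["smiles"].tolist(), model="SELFIES-TED")  # hypothetical
x_test = fm4m.get_representation(test["smiles"].tolist(), model="SELFIES-TED")    # hypothetical

# Lightweight downstream model on top of the frozen representation.
reg = Ridge().fit(x_train, train["prop"])
print("R^2 on the ESOL test split:", reg.score(x_test, test["prop"]))
```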
app.py ADDED
@@ -0,0 +1,523 @@
1
+ import os  # used below by process_custom_file (os.path.getsize)
+ import gradio as gr
2
+ import numpy as np
3
+ import pandas as pd
4
+ from tempfile import NamedTemporaryFile
5
+ from PIL import Image
6
+ from rdkit import RDLogger
7
+ from sklearn.model_selection import train_test_split
8
+
9
+ from molecule_generation_helpers import *
10
+ from property_prediction_helpers import *
11
+
12
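+ # When False, the debugging widgets below (dataset previews, column selectors,
+ # model checkboxes, task radio) stay hidden and are driven programmatically instead.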
+ DEBUG_VISIBLE = False
13
+ RDLogger.logger().setLevel(RDLogger.ERROR)
14
+
15
+ # Predefined dataset paths (these should be adjusted to your file paths)
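+ # Each value is "train_csv, test_csv, input_column, output_column".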
16
+ predefined_datasets = {
17
+ " ": " ",
18
+ "BACE": "./data/bace/train.csv, ./data/bace/test.csv, smiles, Class",
19
+ "ESOL": "./data/esol/train.csv, ./data/esol/test.csv, smiles, prop",
20
+ }
21
+
22
+ # Models
23
+ models_enabled = [
24
+ "MorganFingerprint",
25
+ "SMI-TED",
26
+ "SELFIES-TED",
27
+ "MHG-GED",
28
+ ]
29
+
30
+ blank_df = pd.DataFrame({"id": [], "Model": [], "Score": []})
31
+
32
+
33
+ # Function to load a predefined dataset from the local path
34
+ def load_predefined_dataset(dataset_name):
35
+ val = predefined_datasets.get(dataset_name)
36
+ if val:
37
+ try:
38
+ df = pd.read_csv(val.split(",")[0])
39
+ return (
40
+ df.head(),
41
+ gr.update(choices=list(df.columns), value=None),
42
+ gr.update(choices=list(df.columns), value=None),
43
+ dataset_name.lower(),
44
+ )
45
+ except:
46
+ pass
47
+ else:
48
+ dataset_name = "Custom"
49
+ return (
50
+ pd.DataFrame(),
51
+ gr.update(choices=[], value=None),
52
+ gr.update(choices=[], value=None),
53
+ dataset_name.lower(),
54
+ )
55
+
56
+
57
+ # Function to handle dataset selection (predefined or custom)
58
+ def handle_dataset_selection(selected_dataset, state):
59
+ state["dataset_name"] = (
60
+ selected_dataset if selected_dataset in predefined_datasets else "CUSTOM"
61
+ )
62
+ # Show file upload fields for train and test datasets if "Custom Dataset" is selected
63
+ task_type = (
64
+ "Classification"
65
+ if selected_dataset == "BACE"
66
+ else "Regression" if selected_dataset == "ESOL" else None
67
+ )
68
+ return (
69
+ gr.update(visible=selected_dataset not in predefined_datasets or DEBUG_VISIBLE),
70
+ task_type,
71
+ )
72
+
73
+
74
+ # Function to select input and output columns and display a message
75
+ def select_columns(input_column, output_column, train_data, test_data, state):
76
+ if train_data and test_data and input_column and output_column:
77
+ return f"{train_data.name},{test_data.name},{input_column},{output_column},{state['dataset_name']}"
78
+ return gr.update()
79
+
80
+
81
+ # Function to display the head of the uploaded CSV file
82
+ def display_csv_head(file):
83
+ if file is not None:
84
+ # Load the CSV file into a DataFrame
85
+ df = pd.read_csv(file.name)
86
+ return (
87
+ df.head(),
88
+ gr.update(choices=list(df.columns)),
89
+ gr.update(choices=list(df.columns)),
90
+ )
91
+ return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[])
92
+
93
+
94
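+ # Accept the uploaded CSV only if it is under 50 KB and has "input"/"output" columns;
+ # split it 80/20 into temporary train/test files and infer the task type from the output dtype.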
+ def process_custom_file(file, selected_dataset):
95
+ if file and os.path.getsize(file.name) < 50 * 1024:
96
+ df = pd.read_csv(file.name)
97
+ if "input" in df.columns and "output" in df.columns:
98
+ train, test = train_test_split(df, test_size=0.2)
99
+ with NamedTemporaryFile(
100
+ prefix="fm4m-train-", suffix=".csv", delete=False
101
+ ) as train_file:
102
+ train.to_csv(train_file.name, index=False)
103
+ with NamedTemporaryFile(
104
+ prefix="fm4m-test-", suffix=".csv", delete=False
105
+ ) as test_file:
106
+ test.to_csv(test_file.name, index=False)
107
+ task_type = (
108
+ "Classification" if df["output"].dtype == np.int64 else "Regression"
109
+ )
110
+ return train_file.name, test_file.name, "input", "output", task_type
111
+
112
+ return (
113
+ None,
114
+ None,
115
+ None,
116
+ None,
117
+ gr.update() if selected_dataset in predefined_datasets else None,
118
+ )
119
+
120
+
121
+ def update_plot_choices(current, state):
122
+ choices = []
123
+ if state.get("roc_auc") is not None:
124
+ choices.append("ROC-AUC")
125
+ if state.get("RMSE") is not None:
126
+ choices.append("Parity Plot")
127
+ if state.get("x_batch") is not None:
128
+ choices.append("Latent Space")
129
+ if current in choices:
130
+ return gr.update(choices=choices)
131
+ return gr.update(choices=choices, value=None if len(choices) == 0 else choices[0])
132
+
133
+
134
+ def log_selected(df: pd.DataFrame, evt: gr.SelectData, state):
135
+ state.update(state["results"].get(df.at[evt.index[0], 'id'], {}))
136
+
137
+
138
+ # Dictionary for SMILES strings and corresponding images (you can replace with your actual image paths)
139
+ smiles_image_mapping = {
140
+ # Sample molecule 1
141
+ "Mol 1": {
142
+ "smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1",
143
+ "image": "img/img1.png",
144
+ },
145
+ # Sample molecule 2
146
+ "Mol 2": {
147
+ "smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1",
148
+ "image": "img/img2.png",
149
+ },
150
+ # Sample molecule 3
151
+ "Mol 3": {
152
+ "smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
153
+ "image": "img/img3.png",
154
+ },
155
+ # Sample molecule 4
156
+ "Mol 4": {
157
+ "smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1",
158
+ "image": "img/img4.png",
159
+ },
160
+ # Sample molecule 5
161
+ "Mol 5": {
162
+ "smiles": "C=CCS[C@@H](C)CC(=O)OCC",
163
+ "image": "img/img5.png",
164
+ },
165
+ }
166
+
167
+
168
+ # Load images for selection
169
+ def load_image(path):
170
+ try:
171
+ return Image.open(smiles_image_mapping[path]["image"])
172
+ except:
173
+ pass
174
+
175
+
176
+ # Function to handle image selection
177
+ def handle_image_selection(image_key):
178
+ if not image_key:
179
+ return None, None
180
+ smiles = smiles_image_mapping[image_key]["smiles"]
181
+ mol_image = smiles_to_image(smiles)
182
+ return smiles, mol_image
183
+
184
+
185
+ # Introduction
186
+ with gr.Blocks() as introduction:
187
+ with open("INTRODUCTION.md") as f:
188
+ gr.Markdown(f.read(), sanitize_html=False)
189
+
190
+ # Property Prediction
191
+ with gr.Blocks() as property_prediction:
192
+ state = gr.State({"model_name": "Default - Auto", "results": {}})
193
+ gr.HTML(
194
+ '''
195
+ <p style="text-align: center">
196
+ Task : Property Prediction
197
+ <br>
198
+ Models are fine-tuned with different combinations of modalities on the uploaded or selected built-in dataset.
199
+ </p>
200
+ '''
201
+ )
202
+ with gr.Row():
203
+ with gr.Column():
204
+ # Dropdown menu for predefined datasets including "Custom Dataset" option
205
+ dataset_selector = gr.Dropdown(
206
+ label="Select Dataset",
207
+ choices=list(predefined_datasets.keys()) + ["Custom Dataset"],
208
+ )
209
+ # Display the message for selected columns
210
+ selected_columns_message = gr.Textbox(
211
+ label="Selected Columns Info", visible=DEBUG_VISIBLE
212
+ )
213
+
214
+ with gr.Accordion(
215
+ "Custom Dataset Settings", open=True, visible=DEBUG_VISIBLE
216
+ ) as settings:
217
+ # File upload options for custom dataset (train and test)
218
+ custom_file = gr.File(
219
+ label="Upload Custom Dataset",
220
+ file_types=[".csv"],
221
+ )
222
+ train_file = gr.File(
223
+ label="Upload Custom Train Dataset",
224
+ file_types=[".csv"],
225
+ visible=False,
226
+ )
227
+ train_display = gr.Dataframe(
228
+ label="Train Dataset Preview (First 5 Rows)",
229
+ interactive=False,
230
+ visible=DEBUG_VISIBLE,
231
+ )
232
+
233
+ test_file = gr.File(
234
+ label="Upload Custom Test Dataset",
235
+ file_types=[".csv"],
236
+ visible=False,
237
+ )
238
+ test_display = gr.Dataframe(
239
+ label="Test Dataset Preview (First 5 Rows)",
240
+ interactive=False,
241
+ visible=DEBUG_VISIBLE,
242
+ )
243
+
244
+ # Predefined dataset displays
245
+ predefined_display = gr.Dataframe(
246
+ label="Predefined Dataset Preview (First 5 Rows)",
247
+ interactive=False,
248
+ visible=DEBUG_VISIBLE,
249
+ )
250
+
251
+ # Dropdowns for selecting input and output columns for the custom dataset
252
+ input_column_selector = gr.Dropdown(
253
+ label="Select Input Column",
254
+ choices=[],
255
+ allow_custom_value=True,
256
+ visible=DEBUG_VISIBLE,
257
+ )
258
+ output_column_selector = gr.Dropdown(
259
+ label="Select Output Column",
260
+ choices=[],
261
+ allow_custom_value=True,
262
+ visible=DEBUG_VISIBLE,
263
+ )
264
+
265
+ # When a custom train file is uploaded, display its head and update column selectors
266
+ train_file.change(
267
+ display_csv_head,
268
+ inputs=train_file,
269
+ outputs=[
270
+ train_display,
271
+ input_column_selector,
272
+ output_column_selector,
273
+ ],
274
+ )
275
+
276
+ # When a custom test file is uploaded, display its head
277
+ test_file.change(
278
+ display_csv_head,
279
+ inputs=test_file,
280
+ outputs=[
281
+ test_display,
282
+ input_column_selector,
283
+ output_column_selector,
284
+ ],
285
+ )
286
+
287
+ model_checkbox = gr.CheckboxGroup(
288
+ choices=models_enabled, label="Select Model", visible=DEBUG_VISIBLE
289
+ )
290
+
291
+ task_radiobutton = gr.Radio(
292
+ choices=["Classification", "Regression"],
293
+ label="Task Type",
294
+ visible=DEBUG_VISIBLE,
295
+ )
296
+
297
+ # When a dataset is selected, show either file upload fields (for custom) or load predefined datasets
298
+ # When a predefined dataset is selected, load its head and update column selectors
299
+ dataset_selector.change(lambda: None, outputs=custom_file).then(
300
+ handle_dataset_selection,
301
+ inputs=[dataset_selector, state],
302
+ outputs=[settings, task_radiobutton],
303
+ ).then(
304
+ load_predefined_dataset,
305
+ inputs=dataset_selector,
306
+ outputs=[
307
+ predefined_display,
308
+ input_column_selector,
309
+ output_column_selector,
310
+ selected_columns_message,
311
+ ],
312
+ )
313
+
314
+ custom_file.change(
315
+ process_custom_file,
316
+ inputs=[custom_file, dataset_selector],
317
+ outputs=[
318
+ train_file,
319
+ test_file,
320
+ input_column_selector,
321
+ output_column_selector,
322
+ task_radiobutton,
323
+ ],
324
+ )
325
+ eval_clear_button = gr.Button("Clear")
326
+ eval_button = gr.Button("Submit", variant="primary")
327
+ step_slider = gr.Slider(
328
+ minimum=0,
329
+ maximum=8,
330
+ value=0,
331
+ label="Progress",
332
+ show_label=True,
333
+ interactive=False,
334
+ visible=False,
335
+ )
336
+
337
+ # Right Column
338
+ with gr.Column():
339
+ log_table = gr.Dataframe(value=blank_df, interactive=False)
340
+
341
+ plot_radio = gr.Radio(choices=[], label="Select Plot Type")
342
+
343
+ plot_output = gr.Plot(label="Visualization")
344
+
345
+ log_table.select(log_selected, [log_table, state]).success(
346
+ update_plot_choices, inputs=[plot_radio, state], outputs=plot_radio
347
+ ).then(display_plot, inputs=[plot_radio, state], outputs=plot_output)
348
+
349
+ def clear_eval(state):
350
+ state["results"] = {}
351
+ return None, gr.update(choices=[], value=None), blank_df
352
+
353
+ def eval_part(part, step, selector, show_progress=False):
354
+ return (
355
+ part.then(
356
+ lambda: [models_enabled[x] for x in selector],
357
+ outputs=model_checkbox,
358
+ )
359
+ .then(
360
+ evaluate_and_log,
361
+ inputs=[
362
+ model_checkbox,
363
+ selected_columns_message,
364
+ task_radiobutton,
365
+ log_table,
366
+ state,
367
+ ],
368
+ outputs=log_table,
369
+ show_progress=show_progress,
370
+ )
371
+ .then(lambda: step, outputs=step_slider, show_progress=False)
372
+ )
373
+
374
+ part = (
375
+ eval_button.click(
376
+ lambda: (
377
+ gr.update(interactive=False),
378
+ gr.update(interactive=False),
379
+ ),
380
+ outputs=[eval_clear_button, eval_button],
381
+ )
382
+ .then(
383
+ select_columns,
384
+ inputs=[
385
+ input_column_selector,
386
+ output_column_selector,
387
+ train_file,
388
+ test_file,
389
+ state,
390
+ ],
391
+ outputs=selected_columns_message,
392
+ )
393
+ .then(
394
+ clear_eval,
395
+ inputs=state,
396
+ outputs=[
397
+ plot_output,
398
+ plot_radio,
399
+ log_table,
400
+ ],
401
+ )
402
+ )
403
+ part = part.then(
404
+ lambda: gr.update(value=0, visible=True),
405
+ outputs=step_slider,
406
+ show_progress=False,
407
+ )
408
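+ # Steps 1-4 evaluate each entry of models_enabled on its own (by index);
+ # steps 5-8 evaluate late-fused combinations of the three foundation models.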
+ part = eval_part(part, 1, [0], True)
409
+ part = eval_part(part, 2, [1])
410
+ part = eval_part(part, 3, [2])
411
+ part = eval_part(part, 4, [3])
412
+ part = eval_part(part, 5, [1, 2])
413
+ part = eval_part(part, 6, [2, 3])
414
+ part = eval_part(part, 7, [1, 3])
415
+ part = eval_part(part, 8, [1, 2, 3])
416
+ part = part.then(
417
+ lambda: gr.update(visible=False),
418
+ outputs=step_slider,
419
+ show_progress=False,
420
+ )
421
+ part.then(
422
+ lambda: (
423
+ gr.update(interactive=True),
424
+ gr.update(interactive=True),
425
+ ),
426
+ outputs=[eval_clear_button, eval_button],
427
+ )
428
+
429
+ plot_radio.change(
430
+ display_plot, inputs=[plot_radio, state], outputs=plot_output
431
+ )
432
+
433
+ eval_clear_button.click(
434
+ clear_eval,
435
+ inputs=state,
436
+ outputs=[
437
+ plot_output,
438
+ plot_radio,
439
+ log_table,
440
+ ],
441
+ ).then(lambda: " ", outputs=dataset_selector)
442
+
443
+
444
+ # Molecule Generation
445
+ with gr.Blocks() as molecule_generation:
446
+ gr.HTML(
447
+ '''
448
+ <p style="text-align: center">
449
+ Task : Molecule Generation
450
+ <br>
451
+ Generate a new molecule similar to the initial molecule with better drug-likeness and synthetic accessibility.
452
+ </p>
453
+ '''
454
+ )
455
+ with gr.Row():
456
+ with gr.Column():
457
+ smiles_input = gr.Textbox(label="Input SMILES String")
458
+ image_display = gr.Image(label="Molecule Image", height=250, width=250)
459
+ # Show images for selection
460
+ with gr.Accordion("Select from sample molecules", open=False):
461
+ image_selector = gr.Radio(
462
+ choices=list(smiles_image_mapping.keys()),
463
+ label="Select from sample molecules",
464
+ value=None,
465
+ )
466
+ image_selector.change(load_image, image_selector, image_display)
467
+ clear_button = gr.Button("Clear")
468
+ generate_button = gr.Button("Submit", variant="primary")
469
+
470
+ # Right Column
471
+ with gr.Column():
472
+ gen_image_display = gr.Image(
473
+ label="Generated Molecule Image", height=250, width=250
474
+ )
475
+ generated_output = gr.Textbox(label="Generated Output")
476
+ property_table = gr.Dataframe(label="Molecular Properties Comparison")
477
+
478
+ # Handle image selection
479
+ image_selector.change(
480
+ handle_image_selection,
481
+ inputs=image_selector,
482
+ outputs=[smiles_input, image_display],
483
+ )
484
+ smiles_input.change(
485
+ smiles_to_image, inputs=smiles_input, outputs=image_display
486
+ )
487
+
488
+ # Generate button to display canonical SMILES and molecule image
489
+ generate_button.click(
490
+ lambda: (
491
+ gr.update(interactive=False),
492
+ gr.update(interactive=False),
493
+ ),
494
+ outputs=[clear_button, generate_button],
495
+ ).then(
496
+ generate_canonical,
497
+ inputs=smiles_input,
498
+ outputs=[property_table, generated_output, gen_image_display],
499
+ ).then(
500
+ lambda: (
501
+ gr.update(interactive=True),
502
+ gr.update(interactive=True),
503
+ ),
504
+ outputs=[clear_button, generate_button],
505
+ )
506
+ clear_button.click(
507
+ lambda: (None, None, None, None, None, None),
508
+ outputs=[
509
+ smiles_input,
510
+ image_display,
511
+ image_selector,
512
+ gen_image_display,
513
+ generated_output,
514
+ property_table,
515
+ ],
516
+ )
517
+
518
+
519
+ # Render with tabs
520
+ gr.TabbedInterface(
521
+ [introduction, property_prediction, molecule_generation],
522
+ ["Introduction", "Property Prediction", "Molecule Generation"],
523
+ ).launch(server_name="0.0.0.0", allowed_paths=["./"])
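For reference, `process_custom_file` in app.py above only accepts an uploaded CSV that is under 50 KB and contains `input` and `output` columns; it then makes an 80/20 train/test split and infers the task type from the dtype of `output`. A minimal sketch of preparing such a file (the file name is arbitrary):

```python
import pandas as pd

# Shape expected by the app's custom-dataset upload: an "input" column of SMILES
# and an "output" column of targets. int64 outputs are treated as classification,
# anything else as regression; keep the file under 50 KB.
df = pd.DataFrame({
    "input": ["CCO", "c1ccccc1", "CC(=O)O", "CCN"],
    "output": [0, 1, 0, 1],
})
df.to_csv("my_custom_dataset.csv", index=False)
```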
data/.DS_Store ADDED
Binary file (6.15 kB).
 
data/bace/test.csv ADDED
The diff for this file is too large to render.
 
data/bace/train.csv ADDED
The diff for this file is too large to render.
 
data/bace/valid.csv ADDED
The diff for this file is too large to render.
 
data/esol/test.csv ADDED
@@ -0,0 +1,109 @@
1
+ ,selfies,prop,smiles
2
+ 0,[Cl] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [C] [C] [C] [C] [Branch1] [Branch2] [C] [O] [C] [Ring1] [=Branch1] [Ring1] [Ring1] [C] [Ring1] [Branch2] [C] [Ring1] [=C] [Branch1] [C] [Cl] [C] [Ring1] [=N] [Branch1] [C] [Cl] [Cl],-4.533,ClC4=C(Cl)C5(Cl)C3C1CC(C2OC12)C3C4(Cl)C5(Cl)Cl
3
+ 1,[C] [C] [C] [C] [C] [=O],-1.103,CCCCC=O
4
+ 2,[O] [C] [C] [C] [C] [=C],-0.7909999999999999,OCCCC=C
5
+ 3,[C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [N] [N] [=C] [C] [Branch1] [C] [N] [=C] [Branch1] [C] [Br] [C] [Ring1] [Branch2] [=O],-3.005,c1ccccc1n2ncc(N)c(Br)c2(=O)
6
+ 4,[N] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1],-1.231,Nc1ccc(O)cc1
7
+ 5,[C] [C] [Branch1] [C] [C] [C] [C] [O] [C] [=Branch1] [C] [=O] [C],-1.817,CC(C)CCOC(=O)C
8
+ 6,[C] [O] [P] [=Branch1] [C] [=S] [Branch1] [Ring1] [O] [C] [S] [C] [C] [=Branch1] [C] [=O] [N] [Branch1] [C] [C] [C] [=O],-2.087,COP(=S)(OC)SCC(=O)N(C)C=O
9
+ 7,[Cl] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=Branch1] [Ring2] [=C] [Ring1] [#Branch1] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-6.312,Clc1ccc(Cl)c(c1)c2ccc(Cl)c(Cl)c2
10
+ 8,[C] [Branch1] [C] [Cl] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [N] [=C] [C] [=N] [C] [Ring1] [=Branch1] [=C] [Ring1] [=N] [Cl],-4.438,c2(Cl)c(Cl)c(Cl)c1nccnc1c2(Cl)
11
+ 9,[C] [C] [C] [C] [C] [=C] [Branch1] [C] [C] [N] [=C] [Branch1] [=Branch1] [N] [=C] [Ring1] [#Branch1] [O] [N] [Branch1] [C] [C] [C],-3.57,CCCCc1c(C)nc(nc1O)N(C)C
12
+ 10,[C] [C] [O] [C] [=Branch1] [C] [=O] [C] [C] [=Branch1] [C] [=O] [O] [C] [C],-1.413,CCOC(=O)CC(=O)OCC
13
+ 11,[C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1],-3.192,CC(C)(C)c1ccc(O)cc1
14
+ 12,[C] [C] [=C] [C] [=C] [C] [Branch1] [C] [C] [=C] [Ring1] [#Branch1],-3.035,Cc1cccc(C)c1
15
+ 13,[C] [C] [C] [O] [C] [=Branch1] [C] [=O] [C],-1.125,CCCOC(=O)C
16
+ 14,[C] [S] [C] [=N] [N] [=C] [Branch1] [=Branch2] [C] [=Branch1] [C] [=O] [N] [Ring1] [#Branch1] [N] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C],-2.324,CSc1nnc(c(=O)n1N)C(C)(C)C
17
+ 15,[Cl] [C] [=C] [C] [=C] [Branch1] [Branch1] [C] [=C] [Ring1] [=Branch1] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [Cl],-5.142,Clc1ccc(cc1)c2ccccc2Cl
18
+ 16,[C] [C] [C] [C] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [C] [Branch1] [Ring2] [C] [Ring1] [Branch2] [C] [Branch1] [C] [O] [C] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2],-1.5319999999999998,CC1CC(C)C(=O)C(C1)C(O)CC2CC(=O)NC(=O)C2
19
+ 17,[C] [N] [C] [=Branch1] [C] [=O] [O] [C] [=C] [C] [=C] [C] [Branch1] [Branch2] [N] [=C] [N] [Branch1] [C] [C] [C] [=C] [Ring1] [O],-1.846,CNC(=O)Oc1cccc(N=CN(C)C)c1
20
+ 18,[C] [C] [=C] [C] [=N] [C] [N] [Branch1] [=Branch1] [C] [C] [C] [Ring1] [Ring1] [C] [=N] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [=Branch1] [C] [=O] [N] [C] [Ring2] [Ring1] [Ring1] [=Ring1] [#C],-3.397,Cc3ccnc4N(C1CC1)c2ncccc2C(=O)Nc34
21
+ 19,[C] [C] [N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-2.389,CCNc1ccccc1
22
+ 20,[C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [Branch1] [C] [C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [Ring2] [Ring1] [Ring1] [Ring1] [#Branch2],-6.297000000000001,Cc1c2ccccc2c(C)c3ccc4ccccc4c13
23
+ 21,[F] [C] [=C] [C] [=C] [C] [Branch1] [C] [F] [=C] [Ring1] [#Branch1] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [Branch1] [C] [Cl] [=C] [Branch1] [C] [F] [C] [Branch1] [C] [Cl] [=C] [Ring1] [=Branch2] [F],-5.462000000000001,Fc1cccc(F)c1C(=O)NC(=O)Nc2cc(Cl)c(F)c(Cl)c2F
24
+ 22,[C] [O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-3.057,COc1ccc(Cl)cc1
25
+ 23,[O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=N] [Ring1] [=Branch1],-4.2010000000000005,o1c2ccccc2c3ccccc13
26
+ 24,[C] [=C] [C] [=C] [N] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [C] [Ring1] [#Branch2] [=C] [Ring1] [=C],-3.846,c3ccc2nc1ccccc1cc2c3
27
+ 25,[C] [C] [C] [C] [=Branch1] [C] [=O] [C] [C] [Branch1] [P] [C] [C] [C] [=C] [C] [=Branch1] [C] [=O] [C] [C] [C] [Ring1] [O] [Ring1] [#Branch1] [C] [C] [Ring1] [P] [C] [C] [C] [Ring2] [Ring1] [Ring2] [Branch1] [C] [O] [C] [=Branch1] [C] [=O] [C] [O],-2.893,CC12CC(=O)C3C(CCC4=CC(=O)CCC34C)C2CCC1(O)C(=O)CO
28
+ 26,[C] [C] [C] [=C] [C] [=C] [C] [Branch1] [Ring1] [C] [C] [=C] [Ring1] [Branch2] [N] [Branch1] [Ring2] [C] [O] [C] [C] [=Branch1] [C] [=O] [C] [Cl],-3.319,CCc1cccc(CC)c1N(COC)C(=O)CCl
29
+ 27,[C] [C] [C] [C] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-4.157,CCCCN(C)C(=O)Nc1ccc(Cl)c(Cl)c1
30
+ 28,[C] [S] [C] [=Branch1] [C] [=S] [N] [C] [Ring1] [=Branch1] [=O],-0.396,C1SC(=S)NC1(=O)
31
+ 29,[O] [C] [=C] [C] [=C] [Branch1] [Branch2] [C] [Branch1] [C] [O] [=C] [Ring1] [#Branch1] [C] [O] [C] [=C] [C] [Branch1] [C] [O] [=C] [C] [Branch1] [C] [O] [=C] [Ring1] [Branch2] [C] [=Branch1] [C] [=O] [C] [=Ring1] [=N] [O],-2.7310000000000003,Oc1ccc(c(O)c1)c3oc2cc(O)cc(O)c2c(=O)c3O
32
+ 30,[C] [N] [Branch1] [C] [C] [C] [=N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1] [C],-3.164,CN(C)C=Nc1ccc(Cl)cc1C
33
+ 31,[N] [C] [=Branch1] [C] [=O] [N] [C] [N] [C] [=Branch1] [C] [=O] [N] [C] [Ring1] [=Branch1] [=O],0.652,NC(=O)NC1NC(=O)NC1=O
34
+ 32,[Cl] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-4.063,Clc1cccc2ccccc12
35
+ 33,[O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-3.352,Oc1ccc(Cl)c(Cl)c1
36
+ 34,[C] [C] [Branch1] [C] [C] [C] [Branch1] [#Branch1] [C] [=C] [Branch1] [C] [Cl] [Cl] [C] [Ring1] [Branch2] [C] [=Branch1] [C] [=O] [O] [C] [Branch1] [Ring1] [C] [#N] [C] [=C] [C] [=C] [C] [Branch1] [#Branch2] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [=C] [Ring1] [=N],-6.775,CC1(C)C(C=C(Cl)Cl)C1C(=O)OC(C#N)c2cccc(Oc3ccccc3)c2
37
+ 35,[C] [=C] [C] [=C] [NH1] [N] [=N] [C] [Ring1] [Branch1] [=C] [Ring1] [=Branch2],-2.21,c2ccc1[nH]nnc1c2
38
+ 36,[C] [C] [Branch1] [C] [C] [C] [Branch2] [Ring1] [Branch1] [N] [C] [=C] [C] [=C] [Branch1] [=Branch1] [C] [=C] [Ring1] [=Branch1] [Cl] [C] [Branch1] [C] [F] [Branch1] [C] [F] [F] [C] [=Branch1] [C] [=O] [O] [C] [Branch1] [Ring1] [C] [#N] [C] [=C] [C] [=C] [C] [Branch1] [#Branch2] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [=C] [Ring1] [=N],-8.057,CC(C)C(Nc1ccc(cc1Cl)C(F)(F)F)C(=O)OC(C#N)c2cccc(Oc3ccccc3)c2
39
+ 37,[C] [C] [C],-1.5530000000000002,CCC
40
+ 38,[C] [C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [O] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-3.792,C1Cc2cccc3cccc1c23
41
+ 39,[C] [C] [C] [#C],-1.092,CCC#C
42
+ 40,[Cl] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-3.5580000000000003,Clc1ccc(Cl)cc1
43
+ 41,[C] [C] [=C] [NH1] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch2] [Ring1] [=Branch1],-2.9810000000000003,Cc1c[nH]c2ccccc12
44
+ 42,[C] [C] [#N],0.152,CC#N
45
+ 43,[C] [C] [C] [C] [O],-0.688,CCCCO
46
+ 44,[C] [C] [=Branch1] [C] [=C] [C] [=Branch1] [C] [=C] [C],-2.052,CC(=C)C(=C)C
47
+ 45,[C] [C] [C] [Branch1] [C] [C] [C] [C] [O],-1.308,CCC(C)CCO
48
+ 46,[Cl] [C] [=C] [C] [=C] [Branch1] [=Branch2] [C] [Branch1] [C] [Cl] [=C] [Ring1] [#Branch1] [Cl] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2] [Cl],-7.192,Clc1ccc(c(Cl)c1Cl)c2ccc(Cl)c(Cl)c2Cl
49
+ 47,[C] [C] [=C] [C] [=Branch2] [Ring1] [=Branch1] [=C] [C] [=C] [Ring1] [=Branch1] [N] [S] [=Branch1] [C] [=O] [=Branch1] [C] [=O] [C] [Branch1] [C] [F] [Branch1] [C] [F] [F] [S] [=Branch1] [C] [=O] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-4.945,Cc1cc(ccc1NS(=O)(=O)C(F)(F)F)S(=O)(=O)c2ccccc2
50
+ 48,[O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1] [Cl],-3.22,Oc1ccc(Cl)cc1Cl
51
+ 49,[C] [N] [C] [=Branch2] [Ring1] [Ring2] [=C] [Branch1] [C] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [S] [Ring1] [O] [=Branch1] [C] [=O] [=O] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [C] [=N] [Ring1] [=Branch1],-3.4730000000000003,CN2C(=C(O)c1ccccc1S2(=O)=O)C(=O)Nc3ccccn3
52
+ 50,[C] [C] [C] [C] [C] [C] [Branch1] [S] [C] [C] [C] [=C] [C] [Branch1] [C] [O] [=C] [C] [=C] [Ring1] [O] [Ring1] [#Branch1] [C] [Ring1] [#C] [C] [C] [C] [Ring2] [Ring1] [C] [=O],-3.872,CC12CCC3C(CCc4cc(O)ccc34)C2CCC1=O
53
+ 51,[C] [C] [=C] [C] [=C] [C] [=C] [Branch1] [C] [C] [C] [=C] [C] [=C] [Ring1] [O] [Ring1] [#Branch1],-4.147,Cc1cccc2c(C)cccc12
54
+ 52,[N] [S] [=Branch1] [C] [=O] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [Branch1] [O] [N] [C] [N] [S] [Ring1] [=Branch1] [=Branch1] [C] [=O] [=O] [C] [=C] [Ring1] [N] [Cl],-1.72,NS(=O)(=O)c2cc1c(NCNS1(=O)=O)cc2Cl
55
+ 53,[O] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [N] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-2.725,Oc1cccc2cccnc12
56
+ 54,[C] [C] [C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [Ring1] [#Branch2],-3.447,C1CCc2ccccc2C1
57
+ 55,[C] [C] [O] [C] [Branch1] [C] [C] [O] [C] [C],-0.899,CCOC(C)OCC
58
+ 56,[C] [C] [C] [C] [Ring1] [Ring1] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [Branch1] [Branch1] [C] [Ring1] [Branch2] [=O] [C] [=C] [C] [Branch1] [C] [Cl] [=C] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-3.464,CC12CC2(C)C(=O)N(C1=O)c3cc(Cl)cc(Cl)c3
59
+ 57,[C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=C] [Ring1] [=Branch1],-4.87,Cc1c2ccccc2cc3ccccc13
60
+ 58,[C] [C] [C] [C] [O] [C],-1.072,CCCCOC
61
+ 59,[C] [C] [C] [C] [C] [=Branch1] [C] [=O] [C] [=C] [Ring1] [#Branch1] [C] [C] [C] [C] [C] [C] [C] [Branch1] [#Branch1] [C] [=Branch1] [C] [=O] [C] [O] [C] [Ring1] [=Branch2] [Branch1] [N] [C] [C] [Branch1] [C] [O] [C] [Ring2] [Ring1] [#Branch1] [Ring1] [=C] [C] [=O],-3.0660000000000003,CC13CCC(=O)C=C1CCC4C2CCC(C(=O)CO)C2(CC(O)C34)C=O
62
+ 60,[C] [C] [C] [Branch1] [=Branch1] [C] [Branch1] [C] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [Ring1] [O] [=O],-1.6030000000000002,CCC1(C(C)C)C(=O)NC(=O)NC1=O
63
+ 61,[C] [C] [O] [C] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1],-2.761,CCOC(=O)c1ccc(O)cc1
64
+ 62,[C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [Ring2] [Ring1] [C] [C] [=C] [Ring2] [Ring1] [C] [C] [Ring1] [S] [=C] [Ring1] [=C] [C] [Ring1] [N] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-6.885,c1cc2ccc3ccc4ccc5ccc6ccc1c7c2c3c4c5c67
65
+ 63,[C] [C] [N] [C] [=C] [C] [Branch1] [=Branch1] [N] [Branch1] [C] [C] [C] [=C] [C] [Branch1] [C] [C] [=C] [Ring1] [#Branch2] [N] [C] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [N] [=C] [Ring2] [Ring1] [Ring2] [Ring1] [=Branch1],-4.408,CCN2c1cc(N(C)C)cc(C)c1NC(=O)c3cccnc23
66
+ 64,[C] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-3.301,CN(C)C(=O)Nc1ccc(Cl)c(Cl)c1
67
+ 65,[C] [C] [C] [C] [C] [C] [Branch1] [C] [C] [C],-3.3080000000000003,CCCCCC(C)C
68
+ 66,[C] [O] [C] [=C] [C] [=C] [Branch1] [C] [N] [N] [=C] [Branch1] [#C] [N] [=C] [Ring1] [#Branch1] [C] [Branch1] [Ring1] [O] [C] [=C] [Ring1] [=N] [O] [C] [N] [C] [C] [N] [Branch1] [Branch1] [C] [C] [Ring1] [=Branch1] [C] [=Branch1] [C] [=O] [O] [C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [O],-3.958,COc2cc1c(N)nc(nc1c(OC)c2OC)N3CCN(CC3)C(=O)OCC(C)(C)O
69
+ 67,[C] [=C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [#Branch1] [C] [=C] [Ring1] [O],-0.636,c1cC2C(=O)NC(=O)C2cc1
70
+ 68,[C] [C] [C] [=O],-0.3939999999999999,CCC=O
71
+ 69,[Cl] [C] [=C] [C] [=C] [Branch2] [Ring1] [=Branch2] [C] [N] [Branch1] [Branch2] [C] [C] [C] [C] [C] [Ring1] [Branch1] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [=C] [Ring2] [Ring1] [=Branch1],-5.126,Clc1ccc(CN(C2CCCC2)C(=O)Nc3ccccc3)cc1
72
+ 70,[C] [C] [C] [C] [C] [Branch1] [Ring1] [C] [C] [C] [=O],-2.232,CCCCC(CC)C=O
73
+ 71,[O] [=C] [N] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2] [Branch1] [Ring1] [C] [C] [C] [C] [C] [Branch1] [C] [C] [C],-2.312,O=C1NC(=O)NC(=O)C1(CC)CCC(C)C
74
+ 72,[C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-1.857,CC(=O)Nc1ccccc1
75
+ 73,[C] [=N] [C] [=C] [C] [Branch1] [#Branch1] [C] [=Branch1] [C] [=O] [N] [N] [=C] [Ring1] [#Branch2],-0.7170000000000001,c1nccc(C(=O)NN)c1
76
+ 74,[C] [C] [Branch1] [C] [C] [C] [C] [C] [C] [Branch1] [C] [C] [Branch1] [Ring2] [C] [Ring1] [=Branch1] [C] [Ring1] [=Branch2] [=O],-2.158,CC2(C)C1CCC(C)(C1)C2=O
77
+ 75,[C] [O] [C] [=C] [N] [=C] [C] [=N] [C] [=N] [C] [Ring1] [=Branch1] [=N] [Ring1] [#Branch2],-1.589,COc2cnc1cncnc1n2
78
+ 76,[C] [N] [C] [=Branch1] [C] [=O] [C] [=C] [Branch1] [C] [C] [O] [P] [=Branch1] [C] [=O] [Branch1] [Ring1] [O] [C] [O] [C],-0.949,CNC(=O)C=C(C)OP(=O)(OC)OC
79
+ 77,[O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [N] [Branch1] [Ring1] [C] [C] [C] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [C] [=C] [Ring2] [Ring1] [C] [Ring1] [=Branch1],-3.784,O2c1ccccc1N(CC)C(=O)c3ccccc23
80
+ 78,[C] [=C] [C] [=C] [C] [=C] [Branch1] [Ring1] [O] [C] [C] [Branch1] [Branch2] [C] [C] [=C] [Branch1] [C] [C] [C] [=C] [Ring1] [=N] [O] [C] [Ring1] [P] [=O],-4.0760000000000005,c1cc2ccc(OC)c(CC=C(C)(C))c2oc1=O
81
+ 79,[C] [C] [C] [S] [C] [C] [C],-2.307,CCCSCCC
82
+ 80,[C] [O] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-2.948,CON(C)C(=O)Nc1ccc(Cl)cc1
83
+ 81,[C] [C] [O] [C] [C],-0.718,CCOCC
84
+ 82,[C] [C] [C] [C] [C] [C] [Branch1] [S] [C] [C] [C] [=C] [C] [Branch1] [C] [O] [=C] [C] [=C] [Ring1] [O] [Ring1] [#Branch1] [C] [Ring1] [#C] [C] [C] [Branch1] [C] [O] [C] [Ring2] [Ring1] [Ring1] [O],-3.858,CC34CCC1C(CCc2cc(O)ccc12)C3CC(O)C4O
85
+ 83,[C] [C] [N] [C] [=N] [C] [Branch1] [C] [Cl] [=N] [C] [Branch1] [O] [N] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C] [#N] [=N] [Ring1] [=N],-2.49,CCNc1nc(Cl)nc(NC(C)(C)C#N)n1
86
+ 84,[C] [C] [Branch1] [C] [C] [C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [O],-1.6469999999999998,CC(C)CC(C)(C)O
87
+ 85,[Cl] [C] [=C] [C] [=C] [C] [Branch1] [C] [Br] [=C] [Ring1] [#Branch1],-3.928,Clc1cccc(Br)c1
88
+ 86,[C] [C] [C] [C] [C] [C] [Branch1] [C] [O] [C] [C],-2.033,CCCCCC(O)CC
89
+ 87,[O] [=C] [N] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2] [Branch1] [Ring1] [C] [C] [C] [C] [=C] [Branch1] [C] [C] [C],-2.126,O=C1NC(=O)NC(=O)C1(CC)CC=C(C)C
90
+ 88,[C] [C] [C] [Branch1] [C] [C] [C] [Branch1] [#Branch1] [C] [C] [Branch1] [C] [Br] [=C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [Ring1] [N] [=O],-2.766,CCC(C)C1(CC(Br)=C)C(=O)NC(=O)NC1=O
91
+ 89,[C] [O] [C] [=Branch1] [C] [=O] [C],-0.416,COC(=O)C
92
+ 90,[C] [C] [Branch1] [C] [C] [C] [=C] [C] [=C] [Branch1] [C] [C] [C] [=C] [Ring1] [#Branch1] [O],-3.129,CC(C)c1ccc(C)cc1O
93
+ 91,[C],-0.636,C
94
+ 92,[N] [C] [=N] [C] [Branch1] [C] [O] [=N] [C] [N] [=C] [NH1] [C] [Ring1] [#Branch2] [=Ring1] [Branch1],-1.74,Nc1nc(O)nc2nc[nH]c12
95
+ 93,[F] [C] [=C] [C] [=C] [C] [Branch1] [C] [F] [=C] [Ring1] [#Branch1] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-4.692,Fc1cccc(F)c1C(=O)NC(=O)Nc2ccc(Cl)cc2
96
+ 94,[C] [C] [C] [C] [C] [Branch1] [Branch1] [C] [C] [Ring1] [=Branch1] [C] [Branch1] [C] [C] [Branch1] [C] [C] [O] [Ring1] [#Branch2],-2.579,CC12CCC(CC1)C(C)(C)O2
97
+ 95,[C] [C] [O],0.02,CCO
98
+ 96,[C] [=C] [Branch2] [Ring1] [C] [N] [C] [=Branch1] [C] [=O] [O] [C] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [C] [C] [=C] [C] [=C] [Ring1] [P],-2.29,c1c(NC(=O)OC(C)C(=O)NCC)cccc1
99
+ 97,[C] [C] [Branch1] [C] [C] [=C] [C] [C] [Branch2] [Ring1] [#Branch2] [C] [=Branch1] [C] [=O] [O] [C] [C] [=C] [C] [=C] [C] [Branch1] [#Branch2] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [=C] [Ring1] [=N] [C] [Ring2] [Ring1] [Ring2] [Branch1] [C] [C] [C],-6.763,CC(C)=CC3C(C(=O)OCc2cccc(Oc1ccccc1)c2)C3(C)C
100
+ 98,[C] [C] [C] [C] [N] [C] [=Branch1] [C] [=O] [N] [C] [Branch1] [Branch2] [N] [C] [=Branch1] [C] [=O] [O] [C] [=N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=C] [Ring1] [=Branch1],-2.902,CCCCNC(=O)n1c(NC(=O)OC)nc2ccccc12
101
+ 99,[C] [N] [Branch1] [C] [C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-2.542,CN(C)c1ccccc1
102
+ 100,[C] [O] [C] [=Branch1] [C] [=O] [C] [=C],-0.878,COC(=O)C=C
103
+ 101,[C] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [=N] [O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1] [C] [=C] [Ring1] [=C],-4.477,CN(C)C(=O)Nc2ccc(Oc1ccc(Cl)cc1)cc2
104
+ 102,[O] [=C] [N] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2] [Branch1] [=Branch1] [C] [Branch1] [C] [C] [C] [C] [C] [=C] [Branch1] [C] [C] [C],-2.465,O=C1NC(=O)NC(=O)C1(C(C)C)CC=C(C)C
105
+ 103,[C] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1] [C],-2.6210000000000004,Cc1ccc(O)cc1C
106
+ 104,[Cl] [C] [=C] [C] [=C] [C] [=Branch1] [Ring2] [=N] [Ring1] [=Branch1] [C] [Branch1] [C] [Cl] [Branch1] [C] [Cl] [Cl],-3.833,Clc1cccc(n1)C(Cl)(Cl)Cl
107
+ 105,[C] [C] [=Branch1] [C] [=O] [O] [C] [Branch2] [Ring1] [=C] [C] [C] [C] [C] [C] [C] [C] [=C] [C] [=Branch1] [C] [=O] [C] [C] [C] [Ring1] [#Branch1] [C] [Ring1] [O] [C] [C] [C] [Ring2] [Ring1] [C] [Ring1] [#C] [C] [C] [#C],-4.2410000000000005,CC(=O)OC3(CCC4C2CCC1=CC(=O)CCC1C2CCC34C)C#C
108
+ 106,[C] [N] [C] [=Branch1] [C] [=O] [O] [N] [=C] [Branch1] [Ring2] [C] [S] [C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C],-2.7,CNC(=O)ON=C(CSC)C(C)(C)C
109
+ 107,[C] [C] [C] [C] [C] [C] [C] [Branch1] [C] [C] [O],-2.033,CCCCCCC(C)O
data/esol/train.csv ADDED
The diff for this file is too large to render.
 
data/lce/test.csv ADDED
@@ -0,0 +1,31 @@
1
+ smi1,conc1,smi2,conc2,smi3,conc3,smi4,conc4,smi5,conc5,smi6,conc6,LCE
2
+ C1C(OC(=O)O1)F,0.733,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.267,O,0.0,O,0.0,O,0.0,O,0.0,1.629
3
+ C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.431,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,O,0.0,1.085
4
+ COC(=O)OC,0.299,C(C(F)(F)F)OCC(F)(F)F,0.598,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.103,O,0.0,O,0.0,O,0.0,2.056
5
+ COCCOC,0.358,O1CCOC1,0.532,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.074,[Li+].[N+](=O)([O-])[O-],,O,0.0,O,0.0,1.658
6
+ C1COC(=O)O1,0.197,COC(=O)OC,0.156,COCCOCCOCCOCCOC,0.59,[Li+].F[P-](F)(F)(F)(F)F,0.026,[Li+].[N+](=O)([O-])[O-],0.031,O,0.0,1.638
7
+ C1COC(=O)O1,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.276
8
+ O1CCOC1,0.368,COCCOC,0.547,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.076,CSi(C)(C)([N+]).C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)C(F)(F)C(F)(F)C(F)(F)F,0.008,O,0.0,O,0.0,1.569
9
+ COCCOC,0.507,COC(C(F)(F)F)C(F)(F)F,0.399,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.095,O,0.0,O,0.0,O,0.0,2.268
10
+ C1COC(=O)O1,0.425,O=C(OCC)OCC(F)(F)F,0.481,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0.0,O,0.0,O,0.0,1.602
11
+ C1C(OC(=O)O1)F,0.318,CCOC(=O)OC,0.504,COC(=O)OC,0.094,B(O[Si](C)(C)C)(O[Si](C)(C)C)O[Si](C)(C),0.083,[Li+].F[P-](F)(F)(F)(F)F,0.001,O,0.0,1.678
12
+ O=S1(=O)CCCC1,0.359,C(C(F)(F)F)OC(C(F)F)(F)F,0.504,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.133,[Li+].[N+](=O)([O-])[O-],0.004,O,0.0,O,0.0,2.0
13
+ C1COC(=O)O1,0.594,O=C(OCC)OCC,0.327,[Li+].F[P-](F)(F)(F)(F)F,0.079,O,0.0,O,0.0,O,0.0,0.921
14
+ C1COC(=O)O1,0.331,O=C(OCC)OCC,0.577,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.092,O,0.0,O,0.0,O,0.0,1.301
15
+ C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].C(C(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(C(F)(F)F)(F)F)(F)(F)F,0.069,O,0.0,O,0.0,0.854
16
+ C1C(OC(=O)O1)F,0.107,C1COC(=O)O1,0.526,O=C(OCC)OCC,0.289,[Li+].F[P-](F)(F)(F)(F)F,0.078,O,0.0,O,0.0,1.108
17
+ O1CCOC1,0.322,COCCOC,0.478,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)C(F)(F)C(F)(F)C(F)(F)F,0.2,O,0.0,O,0.0,O,0.0,1.523
18
+ CC1COC(=O)O1,0.595,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.405,O,0.0,O,0.0,O,0.0,O,0.0,1.921
19
+ CC1COC(=O)O1,0.702,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.298,O,0.0,O,0.0,O,0.0,O,0.0,1.602
20
+ O1CCOC1,0.375,COCCOC,0.557,[Li+][S-]SSS[S-][Li+],,[Li+].[N+](=O)([O-])[O-],0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.061,O,0.0,1.523
21
+ COC(=O)OC,0.161,FC(F)C(F)(F)COC(F)(F)C(F)F,0.355,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.484,O,0.0,O,0.0,O,0.0,2.155
22
+ C1COC(=O)O1,0.338,COC(=O)OC,0.625,[Li+].[O-]P(=O)(F)F,0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.03,O,0.0,O,0.0,1.26
23
+ CN(C)C(=O)C(F)(F)F,0.362,C1C(OC(=O)O1)F,0.556,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.081,O,0.0,O,0.0,O,0.0,2.155
24
+ C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.0,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.225
25
+ COCCOC,0.231,FC1CCCCC1,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,2.155
26
+ COCCOC,0.277,FC(F)C(F)(F)COC(F)(F)C(F)F,0.555,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.168,O,0.0,O,0.0,O,0.0,2.155
27
+ O1C(C)CCC1,0.331,FC(F)C(F)(F)COC(F)(F)C(F)F,0.498,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.171,O,0.0,O,0.0,O,0.0,2.301
28
+ COCC(F)(F)C(F)(F)COC,0.864,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.136,O,0.0,O,0.0,O,0.0,O,0.0,1.991
29
+ COC(=O)OC,0.29,C(C(F)(F)F)OCC(F)(F)F,0.589,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0.0,O,0.0,O,0.0,2.301
30
+ C1COC(=O)O1,0.425,O=C(OCC)OCC,0.234,[Li+].F[P-](F)(F)(F)(F)F,0.34,O,0.0,O,0.0,O,0.0,1.398
31
+ COCCOC,0.707,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.147,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.147,O,0.0,O,0.0,O,0.0,1.268
data/lce/test_data.csv ADDED
@@ -0,0 +1,14 @@
1
+ smiles1,conc1,mol1,smiles2,conc2,mol2,smiles3,conc3,mol3,smiles4,conc4,mol4,smiles5,conc5,mol5,smiles6,conc6,LCE_Predicted,LCE
2
+ C1COC(=O)O1,0.519,51.92400559,COC(=O)OC,0.411,41.14791596,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.069,6.928078454,O,0,0,O,0,0,O,0,1.187,1.094
3
+ COCCOC,0.596,59.5609428,COCCOCCOCCOCCOC,0.281,28.07124115,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.124,12.36781605,O,0,0,O,0,0,O,0,1.691,1.384
4
+ C1COC(=O)O1,0.285,28.50894036,C1C(OC(=O)O1)F,0.261,26.07552384,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.228,22.82322096,COC(=O)OC,0.226,22.59231484,O,0,0,O,0,1.508,1.468
5
+ COCCOC,0.434,43.4423376,COCCOCCOCCOCCOC,0.205,20.47449683,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.361,36.08316557,O,0,0,O,0,0,O,0,1.882,1.71
6
+ C1C(OC(=O)O1)F,0.187,18.72872664,COC(=O)OC,0.162,16.22691423,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.109,10.92850826,FC(F)C(F)(F)COC(F)(F)C(F)F,0.541,54.11585087,O,0,0,O,0,2.103,1.832
7
+ C1COC(=O)O1,0.134,13.35070843,C1C(OC(=O)O1)F,0.122,12.2111419,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.107,10.72028474,COC(=O)OC,0.106,10.57995858,FC(F)C(F)(F)COC(F)(F)C(F)F,0.531,53.13790635,O,0,2.077,2.104
8
+ COCCOC,0.096,9.614613177,COCCOCCOCCOCCOC,0.045,4.53139444,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.12,12.01491409,C1COCO1,0.143,14.28400162,FC(F)C(F)(F)COC(F)(F)C(F)F,0.596,59.55507668,O,0,2.211,2.274
9
+ C1COC(=O)O1,0.519,51.92400559,COC(=O)OC,0.411,41.14791596,[Li+].F[P-](F)(F)(F)(F)F,0.069,6.928078454,O,0,0,O,0,0,O,0,1.17,1.071
10
+ C1COC(=O)O1,0.519,51.92400559,COC(=O)OC,0.411,41.14791596,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.069,6.928078454,O,0,0,O,0,0,O,0,1.077,1.166
11
+ C1COC(=O)O1,0.519,51.85215842,COC(=O)OC,0.411,41.09097965,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.069,6.918492083,[Li+].[N+](=O)([O-])[O-],0.001,0.138369842,O,0,0,O,0,1.19,1.335
12
+ C1COC(=O)O1,0.513,51.33049845,COC(=O)OC,0.407,40.6775828,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.069,6.9173773,C1=COC(=O)O1,0.011,1.07454145,O,0,0,O,0,1.114,1.129
13
+ COCCOC,0.53,53.00533987,COCCOCCOCCOCCOC,0.25,24.98156691,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.22,22.01309322,O,0,0,O,0,0,O,0,1.758,1.501
14
+ COCCOC,0.477,47.74974224,COCCOCCOCCOCCOC,0.225,22.50458884,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.297,29.74566892,O,0,0,O,0,0,O,0,1.821,1.663
data/lce/train.csv ADDED
@@ -0,0 +1,121 @@
1
+ smi1,conc1,smi2,conc2,smi3,conc3,smi4,conc4,smi5,conc5,smi6,conc6,LCE
2
+ C1COC(=O)O1,0.327,O=C(OCC)OCC,0.594,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.079,O,0.0,O,0.0,O,0.0,1.155
3
+ C1COC(=O)O1,0.356,COC(=O)OC,0.566,FC(F)(F)COB(OCC(F)(F)F)OCC(F)(F)F,0.007,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.046
4
+ O=S1(=O)CCCC1,0.25,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.75,O,0.0,O,0.0,O,0.0,O,0.0,1.569
5
+ C1COC(=O)O1,0.331,O=C(OCC)OCC,0.577,[Li+].F[P-](F)(F)(F)(F)F,0.092,O,0.0,O,0.0,O,0.0,0.886
6
+ COCCOC,0.763,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.237,O,0.0,O,0.0,O,0.0,O,0.0,1.367
7
+ COCCOC,0.2,FC(F)C(F)(F)COC(F)(F)C(F)F,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.2,O,0.0,O,0.0,O,0.0,2.301
8
+ C1C(OC(=O)O1)F,0.873,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.127,O,0.0,O,0.0,O,0.0,O,0.0,1.489
9
+ COCCOC,0.706,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.008,[Li+].[O-]P(=O)(F)F,0.286,O,0.0,O,0.0,O,0.0,1.244
10
+ C1COC(=O)O1,0.3,CCOC(=O)OC,0.593,C1=COC(=O)O1,0.026,[Li+].F[P-](F)(F)(F)(F)F,0.081,O,0.0,O,0.0,0.745
11
+ COCCOC,0.763,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.174,[Li+].[O-]P(=O)(F)F,0.063,O,0.0,O,0.0,O,0.0,1.292
12
+ CCOCC,0.313,C(C(F)(F)F)OCC(F)(F)F,0.51,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.177,O,0.0,O,0.0,O,0.0,2.301
13
+ O=S1(=O)CCCC1,0.75,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.25,O,0.0,O,0.0,O,0.0,O,0.0,1.745
14
+ COC(=O)OC,0.29,C(C(F)(F)F)OCC(F)(F)F,0.589,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0.0,O,0.0,O,0.0,1.745
15
+ C1COC(=O)O1,0.682,CCOC(=O)OC,0.247,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.043,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.028,O,0.0,O,0.0,1.076
16
+ C1COC(=O)O1,0.359,COC(=O)OC,0.569,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,O,0.0,0.854
17
+ C1COC(=O)O1,0.305,COC(=O)OC,0.242,COCCOCCOCCOCCOC,0.392,[Li+].F[P-](F)(F)(F)(F)F,0.041,[Li+].[N+](=O)([O-])[O-],0.02,O,0.0,1.678
18
+ FC(F)(F)COCCOCC,0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0.0,O,0.0,O,0.0,O,0.0,2.155
19
+ CC#N,0.882,FC,0.065,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,,O,0.0,O,0.0,O,0.0,2.222
20
+ COC(C)C(C)OC,0.879,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0.0,O,0.0,O,0.0,O,0.0,1.638
21
+ CCOP(=O)(OCC)OCC,0.728,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.272,O,0.0,O,0.0,O,0.0,O,0.0,2.0
22
+ COC(=O)OC,0.375,FC(F)C(F)(F)COC(F)(F)C(F)F,0.375,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.25,O,0.0,O,0.0,O,0.0,1.854
23
+ O1CCOC1,0.371,COCCOC,0.552,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.077,O,0.0,O,0.0,O,0.0,1.959
24
+ C1C(OC(=O)O1)F,0.774,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.226,O,0.0,O,0.0,O,0.0,O,0.0,1.587
25
+ CC1COC(=O)O1,0.875,C1C(OC(=O)O1)F,0.051,[Li+].[O-]Cl(=O)(=O)=O,0.074,O,0.0,O,0.0,O,0.0,0.699
26
+ C1C(OC(=O)O1)F,0.264,COC(=O)OCCF,0.479,C(C(F)(F)F)OC(C(F)F)(F)F,0.155,[Li+].F[P-](F)(F)(F)(F)F,0.103,O,0.0,O,0.0,2.097
27
+ C1C(OC(=O)O1)F,0.413,O=C(OCC)OCC,0.497,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.09,O,0.0,O,0.0,O,0.0,1.59
28
+ C1C(OC(=O)O1)F,0.106,C1COC(=O)O1,0.522,O=C(OCC)OCC,0.287,[Li+].F[P-](F)(F)(F)(F)F,0.077,[Rb+].[O-][N+]([O-])=O,0.004,O1CCOCCOCCOCCOCCOCC1,0.004,1.252
29
+ COCCOC,0.259,B(OCC(F)(F)F)(OCC(F)(F)F)OCC(F)(F)F,0.556,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.185,O,0.0,O,0.0,O,0.0,1.337
30
+ C1CCOC1,0.925,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.075,O,0.0,O,0.0,O,0.0,O,0.0,1.377
31
+ C1C(OC(=O)O1)F,0.82,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.18,O,0.0,O,0.0,O,0.0,O,0.0,1.544
32
+ CCOP(=O)(OCC)OCC,0.5,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.5,O,0.0,O,0.0,O,0.0,O,0.0,2.097
33
+ COCCOC,0.731,[Li+].[O-]P(=O)(F)F,0.064,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.205,O,0.0,O,0.0,O,0.0,1.215
34
+ COCCOCCOCCOCCOC,0.819,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.181,O,0.0,O,0.0,O,0.0,O,0.0,1.222
35
+ C1COC(=O)O1,0.338,COC(=O)OC,0.625,[Li+].[O-]P(=O)(F)F,0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.03,O,0.0,O,0.0,1.194
36
+ O1CCOC1,0.463,COCCOC,0.312,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.194,[Li+].[N+](=O)([O-])[O-],0.03,O,0.0,O,0.0,1.824
37
+ C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.333
38
+ O1CCOC1,0.539,COCCOC,0.363,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.075,[Li+].[N+](=O)([O-])[O-],0.023,O,0.0,O,0.0,1.824
39
+ COCCOC,0.257,C(C(F)(F)F)OCC(F)(F)F,0.508,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.235,O,0.0,O,0.0,O,0.0,2.051
40
+ COCCOC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.047,[Li+].FP(F)(=O)([O-]),0.047,O,0.0,O,0.0,O,0.0,1.444
41
+ O1CCOC1,0.478,COCCOC,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.134,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.067,O,0.0,O,0.0,1.854
42
+ CCOCC,0.707,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.293,O,0.0,O,0.0,O,0.0,O,0.0,2.046
43
+ C1COC(=O)O1,0.563,O=C(OCC)OCC,0.31,C1C(OC(=O)O1)F,0.052,[Li+].F[P-](F)(F)(F)(F)F,0.075,O,0.0,O,0.0,1.301
44
+ C1CCOC1,0.942,FC,0.029,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,,O,0.0,O,0.0,O,0.0,2.222
45
+ O1CCOC1,0.478,COCCOC,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.2,O,0.0,O,0.0,O,0.0,1.903
46
+ COCCOC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0.0,O,0.0,O,0.0,O,0.0,1.561
47
+ C1C(OC(=O)O1)F,0.149,COC(=O)OCCF,0.178,C(C(F)(F)F)OC(C(F)F)(F)F,0.564,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.108,O,0.0,O,0.0,1.735
48
+ FC(F)COCCOCC(F)(F),0.845,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.155,O,0.0,O,0.0,O,0.0,O,0.0,2.301
49
+ C1C(OC(=O)O1)F,0.495,COC(=O)OC,0.429,O1CCOCCOCCOCC1,0.003,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.498
50
+ C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.069,O,0.0,O,0.0,0.745
51
+ O=S1(=O)CCCC1,0.758,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.235,[Li+].[N+](=O)([O-])[O-],0.007,O,0.0,O,0.0,O,0.0,1.824
52
+ CCOCC,0.856,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.144,O,0.0,O,0.0,O,0.0,O,0.0,2.0
53
+ O=C(OCC)C,0.105,ClCCl,0.64,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.255,O,0.0,O,0.0,O,0.0,1.456
54
+ COCCOCCOCC(F)(F)OC(F)(F)OC(F)(F)COCCOCCOC,0.708,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.292,O,0.0,O,0.0,O,0.0,O,0.0,1.301
55
+ COCCOC,0.583,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.278,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.139,O,0.0,O,0.0,O,0.0,1.678
56
+ C1C(OC(=O)O1)F,0.662,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.338,O,0.0,O,0.0,O,0.0,O,0.0,1.646
57
+ O1CCOC1,0.397,COCCOC,0.589,[Li+][S-]SSS[S-][Li+],,[Li+].[N+](=O)([O-])[O-],0.012,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.002,O,0.0,1.301
58
+ C1COC(=O)O1,0.308,O=C(OCC)OCC(F)(F)F,0.349,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.343,O,0.0,O,0.0,O,0.0,2.046
59
+ C1COC(=O)O1,0.362,O=C(OCC)OCC,0.548,[Li+].F[P-](F)(F)(F)(F)F,0.09,O,0.0,O,0.0,O,0.0,0.788
60
+ C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.001,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.373
61
+ O1CCOCC1,0.912,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.088,O,0.0,O,0.0,O,0.0,O,0.0,1.602
62
+ CC#N,0.621,C1=COC(=O)O1,0.056,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.323,O,0.0,O,0.0,O,0.0,1.854
63
+ COC(=O)OC,0.684,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.316,O,0.0,O,0.0,O,0.0,O,0.0,2.097
64
+ O=S1(=O)CCCC1,0.714,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.286,O,0.0,O,0.0,O,0.0,O,0.0,1.699
65
+ FC(F)(F)COCCOCC(F)(F)(F),0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0.0,O,0.0,O,0.0,O,0.0,2.155
66
+ CCOCC,0.64,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.36,O,0.0,O,0.0,O,0.0,O,0.0,2.208
67
+ COC(=O)OC,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.4,O,0.0,O,0.0,O,0.0,O,0.0,1.77
68
+ CC1COC(=O)O1,0.887,[Li+].F[As-](F)(F)(F)(F)F,0.113,O,0.0,O,0.0,O,0.0,O,0.0,0.824
69
+ C1COC(=O)O1,0.5,CCOC(=O)OC,0.423,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.046,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.031,O,0.0,O,0.0,0.924
70
+ CCOP(=O)(OCC)OCC,0.214,C(C(F)(F)F)OCC(F)(F)F,0.642,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.144,O,0.0,O,0.0,O,0.0,2.097
71
+ COCCOC,0.682,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.318,O,0.0,O,0.0,O,0.0,O,0.0,2.108
72
+ CC1COC(=O)O1,0.922,[LI+].F[B-](F)(F)OC(C(F)(F)(F))(C(F)(F)(F))C(F)(F)(F),0.078,O,0.0,O,0.0,O,0.0,O,0.0,0.712
73
+ C1COC(=O)O1,0.854,CCOC(=O)OC,0.08,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.039,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.026,O,0.0,O,0.0,1.081
74
+ C1COC(=O)O1,0.519,O=C(OCC)OCC,0.387,[Li+].F[P-](F)(F)(F)(F)F,0.082,[Li+].[O-]P(=O)(F)F,0.012,O,0.0,O,0.0,1.319
75
+ COC(=O)CC(F)(F)F,0.768,C1C(OC(=O)O1)F,0.134,[Li+].F[P-](F)(F)(F)(F)F,0.098,O,0.0,O,0.0,O,0.0,1.62
76
+ C1C(OC(=O)O1)F,0.144,COC(=O)OCCF,0.173,C(C(F)(F)F)OC(C(F)F)(F)F,0.548,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.135,O,0.0,O,0.0,2.222
77
+ C1COC(=O)O1,0.326,COC(=O)OC,0.602,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,O,0.0,0.777
78
+ CCOCC,0.877,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.123,O,0.0,O,0.0,O,0.0,O,0.0,2.018
79
+ COC(=O)OC,0.664,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.336,O,0.0,O,0.0,O,0.0,O,0.0,1.886
80
+ C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].F[B-](F)(F)F,0.069,O,0.0,O,0.0,0.699
81
+ CCOP(=O)(OCC)OCC,0.648,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.352,O,0.0,O,0.0,O,0.0,O,0.0,1.569
82
+ C1C(OC(=O)O1)F,0.481,O=C(OCC)OCC,0.432,[Li+].F[P-](F)(F)(F)(F)F,0.087,O,0.0,O,0.0,O,0.0,1.523
83
+ COCCOC,0.231,FC(F)C(F)(F)COC(F)(F)C(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,2.155
84
+ C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.001,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.488
85
+ O1CCOC1,0.453,COCCOC,0.305,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.127,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.063,[Li+].[N+](=O)([O-])[O-],0.051,O,0.0,2.046
86
+ C1C(OC(=O)O1)F,0.932,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.068,O,0.0,O,0.0,O,0.0,O,0.0,1.41
87
+ COCCOC,0.139,COCC(F)(F)C(F)(F)C(F)(F)C(F)(F)COC,0.692,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.169,O,0.0,O,0.0,O,0.0,2.222
88
+ C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.431,O1CCOCCOCCOCC1,0.0,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.559
89
+ COCCOC,0.231,FC(COC(OCC(F)(F)F)OCC(F)(F)F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,2.301
90
+ CN(C)S(=O)(=O)F,0.921,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.079,O,0.0,O,0.0,O,0.0,O,0.0,1.672
91
+ C1C(OC(=O)O1)F,0.105,C1COC(=O)O1,0.518,O=C(OCC)OCC,0.285,[Li+].F[P-](F)(F)(F)(F)F,0.077,[Rb+].[O-][N+]([O-])=O,0.008,O1CCOCCOCCOCCOCCOCC1,0.008,1.538
92
+ CC1CCC(C)O1,0.893,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.107,O,0.0,O,0.0,O,0.0,O,0.0,1.796
93
+ C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.355
94
+ C1COC(=O)O1,0.444,C1COS(=O)O1,0.497,[Li+].[O-]Cl(=O)(=O)=O,0.059,O,0.0,O,0.0,O,0.0,1.523
95
+ COCCOC,0.371,O1CCOC1,0.552,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.031,[Li+].[N+](=O)([O-])[O-],0.046,O,0.0,O,0.0,1.78
96
+ O=S1(=O)CCCC1,0.764,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.236,O,0.0,O,0.0,O,0.0,O,0.0,1.456
97
+ O1C(C)CCC1,0.908,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.092,O,0.0,O,0.0,O,0.0,O,0.0,1.745
98
+ O1CCOC1,0.362,C(C(F)(F)F)OCC(F)(F)F,0.59,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.048,O,0.0,O,0.0,O,0.0,1.967
99
+ COC(=O)OC,0.543,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.457,O,0.0,O,0.0,O,0.0,O,0.0,2.097
100
+ COCCOC,0.73,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.27,O,0.0,O,0.0,O,0.0,O,0.0,1.143
101
+ O1CCOC1,0.552,COCCOC,0.371,[Li+].[N+](=O)([O-])[O-],0.039,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.039,O,0.0,O,0.0,1.523
102
+ COCCOC,0.242,FC(COC(OCC(F)(F)F)OCC(F)(F)F)(F)F,0.604,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.154,O,0.0,O,0.0,O,0.0,2.301
103
+ CCOP(=O)(OCC)OCC,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.4,O,0.0,O,0.0,O,0.0,O,0.0,2.155
104
+ C1C(OC(=O)O1)F,0.318,CCOC(=O)OC,0.504,COC(=O)OC,0.094,[Li+].F[P-](F)(F)(F)(F)F,0.083,O,0.0,O,0.0,1.301
105
+ COCCOC,0.231,C(C(F)(F)F)OCC(F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,2.222
106
+ C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].F[P-](F)(F)(F)(F)F,0.069,O,0.0,O,0.0,0.699
107
+ COCCOC,0.231,C(C(F)(F)F)OC(=O)OCC(F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,1.495
108
+ C1COC(=O)O1,0.32,COC(=O)OC,0.253,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.427,O,0.0,O,0.0,O,0.0,2.155
109
+ C1C(OC(=O)O1)F,0.312,O=C1OCCC1,0.599,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.068,[Li+].[N+](=O)([O-])[O-],0.021,O,0.0,O,0.0,1.921
110
+ COC(=O)OC,0.478,FC(F)C(F)(F)COC(F)(F)C(F)F,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.067,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.134,O,0.0,O,0.0,1.886
111
+ CCOP(=O)(OCC)OCC,0.259,FC(F)C(F)(F)COC(F)(F)C(F)F,0.556,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.185,O,0.0,O,0.0,O,0.0,2.046
112
+ COCCOC,0.677,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.323,O,0.0,O,0.0,O,0.0,O,0.0,1.745
113
+ C1C(OC(=O)O1)F,0.696,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.304,O,0.0,O,0.0,O,0.0,O,0.0,1.633
114
+ C1CCOC1,0.47,O1C(C)CCC1,0.378,[Li+].F[P-](F)(F)(F)(F)F,0.152,O,0.0,O,0.0,O,0.0,2.097
115
+ FC(F)COCCOCC(F)(F)(F),0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0.0,O,0.0,O,0.0,O,0.0,2.301
116
+ C1COC(=O)O1,0.496,COC(=O)OC,0.393,C1C(OC(=O)O1)F,0.045,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.066,O,0.0,O,0.0,1.108
117
+ C1C(OC(=O)O1)F,0.62,C(C(F)(F)F)OC(=O)OCC(F)(F)F,0.291,[Li+].F[P-](F)(F)(F)(F)F,0.089,O,0.0,O,0.0,O,0.0,1.62
118
+ CCOCC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0.0,O,0.0,O,0.0,O,0.0,1.959
119
+ C1COC(=O)O1,0.526,O=C(OCC)OCC,0.392,[Li+].F[P-](F)(F)(F)(F)F,0.083,O,0.0,O,0.0,O,0.0,1.013
120
+ C1COC(=O)O1,0.05,CCOC(=O)OC,0.237,C(C(F)(F)F)OCC(F)(F)F,0.575,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.123,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.015,O,0.0,1.824
121
+ O=S1(=O)CCCC1,0.429,FC(F)C(F)(F)COC(F)(F)C(F)F,0.429,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.143,O,0.0,O,0.0,O,0.0,1.921
data/lce/train_data.csv ADDED
@@ -0,0 +1,148 @@
1
+ smiles1,conc1,smiles2,conc2,smiles3,conc3,smiles4,conc4,smiles5,conc5,smiles6,conc6,LCE
2
+ CC1COC(=O)O1,0.875,C1C(OC(=O)O1)F,0.051,[Li+].[O-]Cl(=O)(=O)=O,0.074,O,0,O,0,O,0,0.699
3
+ C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].F[P-](F)(F)(F)(F)F,0.069,O,0,O,0,0.699
4
+ FC(F)COCCOCC(F)(F),0.845,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.155,O,0,O,0,O,0,O,0,2.301
5
+ FC(F)COCCOCC(F)(F)(F),0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0,O,0,O,0,O,0,2.301
6
+ CN(C)C(=O)C(F)(F)F,0.362,C1C(OC(=O)O1)F,0.556,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.081,O,0,O,0,O,0,2.155
7
+ COCCOC,0.231,FC1CCCCC1,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,2.155
8
+ CCOP(=O)(OCC)OCC,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.4,O,0,O,0,O,0,O,0,2.155
9
+ O1CCOC1,0.362,C(C(F)(F)F)OCC(F)(F)F,0.59,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.048,O,0,O,0,O,0,1.967
10
+ COCC(F)(F)C(F)(F)COC,0.864,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.136,O,0,O,0,O,0,O,0,1.991
11
+ C1C(OC(=O)O1)F,0.662,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.338,O,0,O,0,O,0,O,0,1.646
12
+ COCCOC,0.358,O1CCOC1,0.532,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.074,[Li+].[N+](=O)([O-])[O-],0.035,O,0,O,0,1.658
13
+ CN(C)S(=O)(=O)F,0.921,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.079,O,0,O,0,O,0,O,0,1.672
14
+ C1C(OC(=O)O1)F,0.106,C1COC(=O)O1,0.522,O=C(OCC)OCC,0.287,[Li+].F[P-](F)(F)(F)(F)F,0.077,[Rb+].[O-][N+]([O-])=O,0.004,O1CCOCCOCCOCCOCCOCC1,0.004,1.252
15
+ C1COC(=O)O1,0.32,COC(=O)OC,0.253,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.427,O,0,O,0,O,0,2.155
16
+ COCCOC,0.277,FC(F)C(F)(F)COC(F)(F)C(F)F,0.555,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.168,O,0,O,0,O,0,2.155
17
+ COC(=O)OC,0.161,FC(F)C(F)(F)COC(F)(F)C(F)F,0.355,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.484,O,0,O,0,O,0,2.155
18
+ FC(F)(F)COCCOCC,0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0,O,0,O,0,O,0,2.155
19
+ FC(F)(F)COCCOCC(F)(F)(F),0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0,O,0,O,0,O,0,2.155
20
+ CCOCC,0.64,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.36,O,0,O,0,O,0,O,0,2.208
21
+ C1C(OC(=O)O1)F,0.144,COC(=O)OCCF,0.173,C(C(F)(F)F)OC(C(F)F)(F)F,0.548,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.135,O,0,O,0,2.222
22
+ CC#N,0.882,FC,0.065,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.054,O,0,O,0,O,0,2.222
23
+ C1CCOC1,0.942,FC,0.029,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.029,O,0,O,0,O,0,2.222
24
+ COCCOC,0.139,COCC(F)(F)C(F)(F)C(F)(F)C(F)(F)COC,0.692,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.169,O,0,O,0,O,0,2.222
25
+ COCCOC,0.231,C(C(F)(F)F)OCC(F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,2.222
26
+ COCCOC,0.507,COC(C(F)(F)F)C(F)(F)F,0.399,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.095,O,0,O,0,O,0,2.268
27
+ CCOCC,0.313,C(C(F)(F)F)OCC(F)(F)F,0.51,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.177,O,0,O,0,O,0,2.301
28
+ COC(=O)OC,0.29,C(C(F)(F)F)OCC(F)(F)F,0.589,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0,O,0,O,0,2.301
29
+ COCCOC,0.242,FC(COC(OCC(F)(F)F)OCC(F)(F)F)(F)F,0.604,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.154,O,0,O,0,O,0,2.301
30
+ O1C(C)CCC1,0.331,FC(F)C(F)(F)COC(F)(F)C(F)F,0.498,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.171,O,0,O,0,O,0,2.301
31
+ COCCOC,0.2,FC(F)C(F)(F)COC(F)(F)C(F)F,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.2,O,0,O,0,O,0,2.301
32
+ COCCOC,0.231,FC(COC(OCC(F)(F)F)OCC(F)(F)F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,2.301
33
+ O=S1(=O)CCCC1,0.359,C(C(F)(F)F)OC(C(F)F)(F)F,0.504,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.133,[Li+].[N+](=O)([O-])[O-],0.004,O,0,O,0,2
34
+ CCOCC,0.856,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.144,O,0,O,0,O,0,O,0,2
35
+ CCOCC,0.877,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.123,O,0,O,0,O,0,O,0,2.018
36
+ CCOCC,0.707,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.293,O,0,O,0,O,0,O,0,2.046
37
+ C1COC(=O)O1,0.308,O=C(OCC)OCC(F)(F)F,0.349,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.343,O,0,O,0,O,0,2.046
38
+ O1CCOC1,0.453,COCCOC,0.305,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.127,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.063,[Li+].[N+](=O)([O-])[O-],0.051,O,0,2.046
39
+ CCOP(=O)(OCC)OCC,0.259,FC(F)C(F)(F)COC(F)(F)C(F)F,0.556,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.185,O,0,O,0,O,0,2.046
40
+ COCCOC,0.257,C(C(F)(F)F)OCC(F)(F)F,0.508,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.235,O,0,O,0,O,0,2.051
41
+ COC(=O)OC,0.299,C(C(F)(F)F)OCC(F)(F)F,0.598,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.103,O,0,O,0,O,0,2.056
42
+ CCOP(=O)(OCC)OCC,0.214,C(C(F)(F)F)OCC(F)(F)F,0.642,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.144,O,0,O,0,O,0,2.097
43
+ COC(=O)OC,0.684,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.316,O,0,O,0,O,0,O,0,2.097
44
+ C1CCOC1,0.47,O1C(C)CCC1,0.378,[Li+].F[P-](F)(F)(F)(F)F,0.152,O,0,O,0,O,0,2.097
45
+ C1C(OC(=O)O1)F,0.264,COC(=O)OCCF,0.479,C(C(F)(F)F)OC(C(F)F)(F)F,0.155,[Li+].F[P-](F)(F)(F)(F)F,0.103,O,0,O,0,2.097
46
+ CCOP(=O)(OCC)OCC,0.5,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.5,O,0,O,0,O,0,O,0,2.097
47
+ COC(=O)OC,0.543,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.457,O,0,O,0,O,0,O,0,2.097
48
+ COCCOC,0.682,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.318,O,0,O,0,O,0,O,0,2.108
49
+ COCCOC,0.231,FC(F)C(F)(F)COC(F)(F)C(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,2.155
50
+ CCOP(=O)(OCC)OCC,0.728,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.272,O,0,O,0,O,0,O,0,2
51
+ COCCOC,0.583,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.278,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.139,O,0,O,0,O,0,1.678
52
+ C1COC(=O)O1,0.305,COC(=O)OC,0.242,COCCOCCOCCOCCOC,0.392,[Li+].F[P-](F)(F)(F)(F)F,0.041,[Li+].[N+](=O)([O-])[O-],0.02,O,0,1.678
53
+ C1C(OC(=O)O1)F,0.318,CCOC(=O)OC,0.504,COC(=O)OC,0.094,B(O[Si](C)(C)C)(O[Si](C)(C)C)O[Si](C)(C),0.083,[Li+].F[P-](F)(F)(F)(F)F,0.001,O,0,1.678
54
+ O=S1(=O)CCCC1,0.714,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.286,O,0,O,0,O,0,O,0,1.699
55
+ C1C(OC(=O)O1)F,0.149,COC(=O)OCCF,0.178,C(C(F)(F)F)OC(C(F)F)(F)F,0.564,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.108,O,0,O,0,1.735
56
+ O=S1(=O)CCCC1,0.75,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.25,O,0,O,0,O,0,O,0,1.745
57
+ COC(=O)OC,0.29,C(C(F)(F)F)OCC(F)(F)F,0.589,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0,O,0,O,0,1.745
58
+ COCCOC,0.677,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.323,O,0,O,0,O,0,O,0,1.745
59
+ O1C(C)CCC1,0.908,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.092,O,0,O,0,O,0,O,0,1.745
60
+ COC(=O)OC,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.4,O,0,O,0,O,0,O,0,1.77
61
+ COCCOC,0.371,O1CCOC1,0.552,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.031,[Li+].[N+](=O)([O-])[O-],0.046,O,0,O,0,1.78
62
+ CC1CCC(C)O1,0.893,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.107,O,0,O,0,O,0,O,0,1.796
63
+ C1COC(=O)O1,0.05,CCOC(=O)OC,0.237,C(C(F)(F)F)OCC(F)(F)F,0.575,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.123,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.015,O,0,1.824
64
+ O=S1(=O)CCCC1,0.758,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.235,[Li+].[N+](=O)([O-])[O-],0.007,O,0,O,0,O,0,1.824
65
+ O1CCOC1,0.463,COCCOC,0.312,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.194,[Li+].[N+](=O)([O-])[O-],0.03,O,0,O,0,1.824
66
+ O1CCOC1,0.539,COCCOC,0.363,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.075,[Li+].[N+](=O)([O-])[O-],0.023,O,0,O,0,1.824
67
+ COC(=O)OC,0.375,FC(F)C(F)(F)COC(F)(F)C(F)F,0.375,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.25,O,0,O,0,O,0,1.854
68
+ O1CCOC1,0.478,COCCOC,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.134,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.067,O,0,O,0,1.854
69
+ CC#N,0.621,C1=COC(=O)O1,0.056,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.323,O,0,O,0,O,0,1.854
70
+ COC(=O)OC,0.478,FC(F)C(F)(F)COC(F)(F)C(F)F,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.067,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.134,O,0,O,0,1.886
71
+ COC(=O)OC,0.664,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.336,O,0,O,0,O,0,O,0,1.886
72
+ O1CCOC1,0.478,COCCOC,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.2,O,0,O,0,O,0,1.903
73
+ O=S1(=O)CCCC1,0.429,FC(F)C(F)(F)COC(F)(F)C(F)F,0.429,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.143,O,0,O,0,O,0,1.921
74
+ C1C(OC(=O)O1)F,0.312,O=C1OCCC1,0.599,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.068,[Li+].[N+](=O)([O-])[O-],0.021,O,0,O,0,1.921
75
+ CC1COC(=O)O1,0.595,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.405,O,0,O,0,O,0,O,0,1.921
76
+ O1CCOC1,0.371,COCCOC,0.552,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.077,O,0,O,0,O,0,1.959
77
+ CCOCC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0,O,0,O,0,O,0,1.959
78
+ C1CCOC1,0.925,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.075,O,0,O,0,O,0,O,0,1.377
79
+ C1COC(=O)O1,0.425,O=C(OCC)OCC,0.234,[Li+].F[P-](F)(F)(F)(F)F,0.34,O,0,O,0,O,0,1.398
80
+ C1C(OC(=O)O1)F,0.932,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.068,O,0,O,0,O,0,O,0,1.41
81
+ COCCOC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.047,[Li+].FP(F)(=O)([O-]),0.047,O,0,O,0,O,0,1.444
82
+ O=S1(=O)CCCC1,0.764,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.236,O,0,O,0,O,0,O,0,1.456
83
+ O=C(OCC)C,0.105,ClCCl,0.64,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.255,O,0,O,0,O,0,1.456
84
+ C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.001,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.488
85
+ C1C(OC(=O)O1)F,0.873,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.127,O,0,O,0,O,0,O,0,1.489
86
+ COCCOC,0.231,C(C(F)(F)F)OC(=O)OCC(F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,1.495
87
+ C1C(OC(=O)O1)F,0.495,COC(=O)OC,0.429,O1CCOCCOCCOCC1,0.003,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.498
88
+ C1C(OC(=O)O1)F,0.481,O=C(OCC)OCC,0.432,[Li+].F[P-](F)(F)(F)(F)F,0.087,O,0,O,0,O,0,1.523
89
+ O1CCOC1,0.322,COCCOC,0.478,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)C(F)(F)C(F)(F)C(F)(F)F,0.2,O,0,O,0,O,0,1.523
90
+ O1CCOC1,0.552,COCCOC,0.371,[Li+].[N+](=O)([O-])[O-],0.039,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.039,O,0,O,0,1.523
91
+ C1COC(=O)O1,0.444,C1COS(=O)O1,0.497,[Li+].[O-]Cl(=O)(=O)=O,0.059,O,0,O,0,O,0,1.523
92
+ C1C(OC(=O)O1)F,0.105,C1COC(=O)O1,0.518,O=C(OCC)OCC,0.285,[Li+].F[P-](F)(F)(F)(F)F,0.077,[Rb+].[O-][N+]([O-])=O,0.008,O1CCOCCOCCOCCOCCOCC1,0.008,1.538
93
+ C1C(OC(=O)O1)F,0.82,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.18,O,0,O,0,O,0,O,0,1.544
94
+ C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.431,O1CCOCCOCCOCC1,0,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.559
95
+ COCCOC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0,O,0,O,0,O,0,1.561
96
+ CCOP(=O)(OCC)OCC,0.648,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.352,O,0,O,0,O,0,O,0,1.569
97
+ O=S1(=O)CCCC1,0.25,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.75,O,0,O,0,O,0,O,0,1.569
98
+ C1C(OC(=O)O1)F,0.774,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.226,O,0,O,0,O,0,O,0,1.587
99
+ C1C(OC(=O)O1)F,0.413,O=C(OCC)OCC,0.497,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.09,O,0,O,0,O,0,1.59
100
+ C1COC(=O)O1,0.425,O=C(OCC)OCC(F)(F)F,0.481,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0,O,0,O,0,1.602
101
+ CC1COC(=O)O1,0.702,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.298,O,0,O,0,O,0,O,0,1.602
102
+ O1CCOCC1,0.912,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.088,O,0,O,0,O,0,O,0,1.602
103
+ C1C(OC(=O)O1)F,0.62,C(C(F)(F)F)OC(=O)OCC(F)(F)F,0.291,[Li+].F[P-](F)(F)(F)(F)F,0.089,O,0,O,0,O,0,1.62
104
+ COC(=O)CC(F)(F)F,0.768,C1C(OC(=O)O1)F,0.134,[Li+].F[P-](F)(F)(F)(F)F,0.098,O,0,O,0,O,0,1.62
105
+ C1C(OC(=O)O1)F,0.733,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.267,O,0,O,0,O,0,O,0,1.629
106
+ C1C(OC(=O)O1)F,0.696,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.304,O,0,O,0,O,0,O,0,1.633
107
+ COC(C)C(C)OC,0.879,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0,O,0,O,0,O,0,1.638
108
+ C1COC(=O)O1,0.197,COC(=O)OC,0.156,COCCOCCOCCOCCOC,0.59,[Li+].F[P-](F)(F)(F)(F)F,0.026,[Li+].[N+](=O)([O-])[O-],0.031,O,0,1.638
109
+ C1COC(=O)O1,0.338,COC(=O)OC,0.625,[Li+].[O-]P(=O)(F)F,0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.03,O,0,O,0,1.26
110
+ COCCOC,0.707,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.147,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.147,O,0,O,0,O,0,1.268
111
+ C1COC(=O)O1,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.276
112
+ COCCOC,0.763,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.174,[Li+].[O-]P(=O)(F)F,0.063,O,0,O,0,O,0,1.292
113
+ C1COC(=O)O1,0.563,O=C(OCC)OCC,0.31,C1C(OC(=O)O1)F,0.052,[Li+].F[P-](F)(F)(F)(F)F,0.075,O,0,O,0,1.301
114
+ COCCOCCOCC(F)(F)OC(F)(F)OC(F)(F)COCCOCCOC,0.708,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.292,O,0,O,0,O,0,O,0,1.301
115
+ C1COC(=O)O1,0.331,O=C(OCC)OCC,0.577,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.092,O,0,O,0,O,0,1.301
116
+ C1C(OC(=O)O1)F,0.318,CCOC(=O)OC,0.504,COC(=O)OC,0.094,[Li+].F[P-](F)(F)(F)(F)F,0.083,O,0,O,0,1.301
117
+ C1COC(=O)O1,0.519,O=C(OCC)OCC,0.387,[Li+].F[P-](F)(F)(F)(F)F,0.082,[Li+].[O-]P(=O)(F)F,0.012,O,0,O,0,1.319
118
+ C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.333
119
+ COCCOC,0.259,B(OCC(F)(F)F)(OCC(F)(F)F)OCC(F)(F)F,0.556,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.185,O,0,O,0,O,0,1.337
120
+ C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.355
121
+ COCCOC,0.763,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.237,O,0,O,0,O,0,O,0,1.367
122
+ C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.001,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.373
123
+ C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].F[B-](F)(F)F,0.069,O,0,O,0,0.699
124
+ CC1COC(=O)O1,0.922,[Li+].F[B-](F)(F)OC(C(F)(F)(F))(C(F)(F)(F))C(F)(F)(F),0.078,O,0,O,0,O,0,O,0,0.712
125
+ C1COC(=O)O1,0.3,CCOC(=O)OC,0.593,C1=COC(=O)O1,0.026,[Li+].F[P-](F)(F)(F)(F)F,0.081,O,0,O,0,0.745
126
+ C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.069,O,0,O,0,0.745
127
+ C1COC(=O)O1,0.326,COC(=O)OC,0.602,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,O,0,0.777
128
+ C1COC(=O)O1,0.362,O=C(OCC)OCC,0.548,[Li+].F[P-](F)(F)(F)(F)F,0.09,O,0,O,0,O,0,0.788
129
+ CC1COC(=O)O1,0.887,[Li+].F[As-](F)(F)(F)(F)F,0.113,O,0,O,0,O,0,O,0,0.824
130
+ C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].C(C(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(C(F)(F)F)(F)F)(F)(F)F,0.069,O,0,O,0,0.854
131
+ C1COC(=O)O1,0.359,COC(=O)OC,0.569,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,O,0,0.854
132
+ C1COC(=O)O1,0.331,O=C(OCC)OCC,0.577,[Li+].F[P-](F)(F)(F)(F)F,0.092,O,0,O,0,O,0,0.886
133
+ C1COC(=O)O1,0.594,O=C(OCC)OCC,0.327,[Li+].F[P-](F)(F)(F)(F)F,0.079,O,0,O,0,O,0,0.921
134
+ C1COC(=O)O1,0.5,CCOC(=O)OC,0.423,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.046,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.031,O,0,O,0,0.924
135
+ C1COC(=O)O1,0.526,O=C(OCC)OCC,0.392,[Li+].F[P-](F)(F)(F)(F)F,0.083,O,0,O,0,O,0,1.013
136
+ C1COC(=O)O1,0.356,COC(=O)OC,0.566,FC(F)(F)COB(OCC(F)(F)F)OCC(F)(F)F,0.007,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.046
137
+ C1COC(=O)O1,0.682,CCOC(=O)OC,0.247,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.043,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.028,O,0,O,0,1.076
138
+ C1COC(=O)O1,0.854,CCOC(=O)OC,0.08,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.039,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.026,O,0,O,0,1.081
139
+ C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.431,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,O,0,1.085
140
+ C1C(OC(=O)O1)F,0.107,C1COC(=O)O1,0.526,O=C(OCC)OCC,0.289,[Li+].F[P-](F)(F)(F)(F)F,0.078,O,0,O,0,1.108
141
+ C1COC(=O)O1,0.496,COC(=O)OC,0.393,C1C(OC(=O)O1)F,0.045,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.066,O,0,O,0,1.108
142
+ COCCOC,0.73,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.27,O,0,O,0,O,0,O,0,1.143
143
+ C1COC(=O)O1,0.327,O=C(OCC)OCC,0.594,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.079,O,0,O,0,O,0,1.155
144
+ C1COC(=O)O1,0.338,COC(=O)OC,0.625,[Li+].[O-]P(=O)(F)F,0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.03,O,0,O,0,1.194
145
+ COCCOC,0.731,[Li+].[O-]P(=O)(F)F,0.064,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.205,O,0,O,0,O,0,1.215
146
+ COCCOCCOCCOCCOC,0.819,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.181,O,0,O,0,O,0,O,0,1.222
147
+ C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.225
148
+ COCCOC,0.706,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.008,[Li+].[O-]P(=O)(F)F,0.286,O,0,O,0,O,0,1.244
models/.DS_Store ADDED
Binary file (6.15 kB).
 
models/.gitattributes ADDED
@@ -0,0 +1,3 @@
1
+ *.csv filter=lfs diff=lfs merge=lfs -text
2
+ *.png filter=lfs diff=lfs merge=lfs -text
3
+ *.pdf filter=lfs diff=lfs merge=lfs -text
models/fm4m.py ADDED
@@ -0,0 +1,964 @@
1
+ from sklearn.metrics import roc_auc_score, roc_curve
2
+
3
+ import datetime
4
+ import os
5
+ import umap
6
+ import numpy as np
7
+
8
+ import matplotlib.pyplot as plt
9
+ import pandas as pd
10
+ import pickle
11
+ import json
12
+
13
+ from xgboost import XGBClassifier, XGBRegressor
14
+ import xgboost as xgb
15
+ from sklearn.metrics import roc_auc_score, mean_squared_error
16
+ # NOTE: the interactive add_new_model()/add_new_dataset() helpers below rely on
+ # ipywidgets, IPython and tkinter; import them here so those names are defined.
+ import ipywidgets as widgets
+ from ipywidgets import Layout, VBox
+ from IPython.display import display
+ import tkinter as tk
+ from tkinter import filedialog
17
+ from sklearn.svm import SVR
18
+ from sklearn.linear_model import LinearRegression
19
+ from sklearn.kernel_ridge import KernelRidge
20
+ import json
21
+ from sklearn.compose import TransformedTargetRegressor
22
+ from sklearn.preprocessing import MinMaxScaler
23
+
24
+
25
+ import torch
26
+ from transformers import AutoTokenizer, AutoModel
27
+
28
+ import sys
29
+ sys.path.append("models/")
30
+
31
+ from models.selfies_ted.load import SELFIES as bart
32
+ from models.mhg_model import load as mhg
33
+ from models.smi_ted.smi_ted_light.load import load_smi_ted
34
+
35
+ import mordred
36
+ from mordred import Calculator, descriptors
37
+ from rdkit import Chem
38
+ from rdkit.Chem import AllChem
39
+
40
+ datasets = {}
41
+ models = {}
42
+ downstream_models ={}
43
+
44
+
45
+ def avail_models_data():
46
+ global datasets
47
+ global models
48
+
49
+ datasets = [{"Dataset": "hiv", "Input": "smiles", "Output": "HIV_active", "Path": "data/hiv", "Timestamp": "2024-06-26 11:27:37"},
50
+ {"Dataset": "esol", "Input": "smiles", "Output": "ESOL predicted log solubility in mols per litre", "Path": "data/esol", "Timestamp": "2024-06-26 11:31:46"},
51
+ {"Dataset": "freesolv", "Input": "smiles", "Output": "expt", "Path": "data/freesolv", "Timestamp": "2024-06-26 11:33:47"},
52
+ {"Dataset": "lipo", "Input": "smiles", "Output": "y", "Path": "data/lipo", "Timestamp": "2024-06-26 11:34:37"},
53
+ {"Dataset": "bace", "Input": "smiles", "Output": "Class", "Path": "data/bace", "Timestamp": "2024-06-26 11:36:40"},
54
+ {"Dataset": "bbbp", "Input": "smiles", "Output": "p_np", "Path": "data/bbbp", "Timestamp": "2024-06-26 11:39:23"},
55
+ {"Dataset": "clintox", "Input": "smiles", "Output": "CT_TOX", "Path": "data/clintox", "Timestamp": "2024-06-26 11:42:43"}]
56
+
57
+
58
+ models = [{"Name": "bart","Model Name": "SELFIES-TED","Description": "BART model for string based SELFIES modality", "Timestamp": "2024-06-21 12:32:20"},
59
+ {"Name": "mol-xl","Model Name": "MolFormer", "Description": "MolFormer model for string based SMILES modality", "Timestamp": "2024-06-21 12:35:56"},
60
+ {"Name": "mhg", "Model Name": "MHG-GED","Description": "Molecular hypergraph model", "Timestamp": "2024-07-10 00:09:42"},
61
+ {"Name": "smi-ted", "Model Name": "SMI-TED","Description": "SMILES based encoder decoder model", "Timestamp": "2024-07-10 00:09:42"}]
62
+
63
+
64
+ def avail_models(raw=False):
65
+ global models
66
+
67
+ models = [{"Name": "smi-ted", "Model Name": "SMI-TED","Description": "SMILES based encoder decoder model"},
68
+ {"Name": "bart","Model Name": "SELFIES-TED","Description": "BART model for string based SELFIES modality"},
69
+ {"Name": "mol-xl","Model Name": "MolFormer", "Description": "MolFormer model for string based SMILES modality"},
70
+ {"Name": "mhg", "Model Name": "MHG-GED","Description": "Molecular hypergraph model"},
71
+ {"Name": "Mordred", "Model Name": "Mordred","Description": "Baseline: A descriptor-calculation software application that can calculate more than 1800 two- and three-dimensional descriptors"},
72
+ {"Name": "MorganFingerprint", "Model Name": "MorganFingerprint","Description": "Baseline: Circular atom environments based descriptor"}
73
+ ]
74
+
75
+
76
+
77
+ if raw: return models
78
+ else:
79
+ return pd.DataFrame(models).drop('Name', axis=1)
80
+
81
+ return models
82
+
83
+ def avail_downstream_models(raw=False):
84
+ global downstream_models
85
+
86
+ downstream_models = [{"Name": "XGBClassifier", "Task Type": "Classification"},
87
+ {"Name": "DefaultClassifier", "Task Type": "Classfication"},
88
+ {"Name": "SVR", "Task Type": "Regression"},
89
+ {"Name": "Kernel Ridge", "Task Type": "Regression"},
90
+ {"Name": "Linear Regression", "Task Type": "Regression"},
91
+ {"Name": "DefaultRegressor", "Task Type": "Regression"},
92
+ ]
93
+
94
+ if raw: return downstream_models
95
+ else:
96
+ return pd.DataFrame(downstream_models)
97
+
98
+
99
+
100
+ def avail_datasets():
101
+ global datasets
102
+
103
+ datasets = [{"Dataset": "hiv", "Input": "smiles", "Output": "HIV_active", "Path": "data/hiv",
104
+ "Timestamp": "2024-06-26 11:27:37"},
105
+ {"Dataset": "esol", "Input": "smiles", "Output": "ESOL predicted log solubility in mols per litre",
106
+ "Path": "data/esol", "Timestamp": "2024-06-26 11:31:46"},
107
+ {"Dataset": "freesolv", "Input": "smiles", "Output": "expt", "Path": "data/freesolv",
108
+ "Timestamp": "2024-06-26 11:33:47"},
109
+ {"Dataset": "lipo", "Input": "smiles", "Output": "y", "Path": "data/lipo",
110
+ "Timestamp": "2024-06-26 11:34:37"},
111
+ {"Dataset": "bace", "Input": "smiles", "Output": "Class", "Path": "data/bace",
112
+ "Timestamp": "2024-06-26 11:36:40"},
113
+ {"Dataset": "bbbp", "Input": "smiles", "Output": "p_np", "Path": "data/bbbp",
114
+ "Timestamp": "2024-06-26 11:39:23"},
115
+ {"Dataset": "clintox", "Input": "smiles", "Output": "CT_TOX", "Path": "data/clintox",
116
+ "Timestamp": "2024-06-26 11:42:43"}]
117
+
118
+ return datasets
119
+
120
+ def reset():
121
+
122
+ """datasets = {"esol": ["smiles", "ESOL predicted log solubility in mols per litre", "data/esol", "2024-06-26 11:36:46.509324"],
123
+ "freesolv": ["smiles", "expt", "data/freesolv", "2024-06-26 11:37:37.393273"],
124
+ "lipo": ["smiles", "y", "data/lipo", "2024-06-26 11:37:37.393273"],
125
+ "hiv": ["smiles", "HIV_active", "data/hiv", "2024-06-26 11:37:37.393273"],
126
+ "bace": ["smiles", "Class", "data/bace", "2024-06-26 11:38:40.058354"],
127
+ "bbbp": ["smiles", "p_np", "data/bbbp","2024-06-26 11:38:40.058354"],
128
+ "clintox": ["smiles", "CT_TOX", "data/clintox","2024-06-26 11:38:40.058354"],
129
+ "sider": ["smiles","1:", "data/sider","2024-06-26 11:38:40.058354"],
130
+ "tox21": ["smiles",":-2", "data/tox21","2024-06-26 11:38:40.058354"]
131
+ }"""
132
+
133
+ datasets = [
134
+ {"Dataset": "hiv", "Input": "smiles", "Output": "HIV_active", "Path": "data/hiv", "Timestamp": "2024-06-26 11:27:37"},
135
+ {"Dataset": "esol", "Input": "smiles", "Output": "ESOL predicted log solubility in mols per litre", "Path": "data/esol", "Timestamp": "2024-06-26 11:31:46"},
136
+ {"Dataset": "freesolv", "Input": "smiles", "Output": "expt", "Path": "data/freesolv", "Timestamp": "2024-06-26 11:33:47"},
137
+ {"Dataset": "lipo", "Input": "smiles", "Output": "y", "Path": "data/lipo", "Timestamp": "2024-06-26 11:34:37"},
138
+ {"Dataset": "bace", "Input": "smiles", "Output": "Class", "Path": "data/bace", "Timestamp": "2024-06-26 11:36:40"},
139
+ {"Dataset": "bbbp", "Input": "smiles", "Output": "p_np", "Path": "data/bbbp", "Timestamp": "2024-06-26 11:39:23"},
140
+ {"Dataset": "clintox", "Input": "smiles", "Output": "CT_TOX", "Path": "data/clintox", "Timestamp": "2024-06-26 11:42:43"},
141
+ #{"Dataset": "sider", "Input": "smiles", "Output": "1:", "path": "data/sider", "Timestamp": "2024-06-26 11:38:40.058354"},
142
+ #{"Dataset": "tox21", "Input": "smiles", "Output": ":-2", "path": "data/tox21", "Timestamp": "2024-06-26 11:38:40.058354"}
143
+ ]
144
+
145
+ models = [{"Name": "bart", "Description": "BART model for string based SELFIES modality",
146
+ "Timestamp": "2024-06-21 12:32:20"},
147
+ {"Name": "mol-xl", "Description": "MolFormer model for string based SMILES modality",
148
+ "Timestamp": "2024-06-21 12:35:56"},
149
+ {"Name": "mhg", "Description": "MHG", "Timestamp": "2024-07-10 00:09:42"},
150
+ {"Name": "spec-gru", "Description": "Spectrum modality with GRU", "Timestamp": "2024-07-10 00:09:42"},
151
+ {"Name": "spec-lstm", "Description": "Spectrum modality with LSTM", "Timestamp": "2024-07-10 00:09:54"},
152
+ {"Name": "3d-vae", "Description": "VAE model for 3D atom positions", "Timestamp": "2024-07-10 00:10:08"}]
153
+
154
+
155
+ downstream_models = [
156
+ {"Name": "XGBClassifier", "Description": "XG Boost Classifier",
157
+ "Timestamp": "2024-06-21 12:31:20"},
158
+ {"Name": "XGBRegressor", "Description": "XG Boost Regressor",
159
+ "Timestamp": "2024-06-21 12:32:56"},
160
+ {"Name": "2-FNN", "Description": "A two layer feedforward network",
161
+ "Timestamp": "2024-06-24 14:34:16"},
162
+ {"Name": "3-FNN", "Description": "A three layer feedforward network",
163
+ "Timestamp": "2024-06-24 14:38:37"},
164
+ ]
165
+
166
+ with open("datasets.json", "w") as outfile:
167
+ json.dump(datasets, outfile)
168
+
169
+ with open("models.json", "w") as outfile:
170
+ json.dump(models, outfile)
171
+
172
+ with open("downstream_models.json", "w") as outfile:
173
+ json.dump(downstream_models, outfile)
174
+
175
+ def update_data_list(list_data):
176
+ #datasets[list_data[0]] = list_data[1:]
177
+
178
+ with open("datasets.json", "w") as outfile:
179
+ json.dump(datasets, outfile)
180
+
181
+ avail_models_data()
182
+
183
+ def update_model_list(list_model):
184
+ #models[list_model[0]] = list_model[1]
185
+
186
+ with open("models.json", "w") as outfile:
187
+ json.dump(list_model, outfile)
188
+
189
+ avail_models_data()
190
+
191
+ def update_downstream_model_list(list_model):
192
+ #models[list_model[0]] = list_model[1]
193
+
194
+ with open("downstream_models.json", "w") as outfile:
195
+ json.dump(list_model, outfile)
196
+
197
+ avail_models_data()
198
+
199
+ avail_models_data()
200
+
201
+
202
+
203
+ def get_representation(train_data,test_data,model_type, return_tensor=True):
204
+ alias = {"MHG-GED": "mhg", "SELFIES-TED": "bart", "MolFormer": "mol-xl", "Molformer": "mol-xl", "SMI-TED": "smi-ted"}
205
+ if model_type in alias.keys():
206
+ model_type = alias[model_type]
207
+
208
+ if model_type == "mhg":
209
+ model = mhg.load("../models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle")
210
+ with torch.no_grad():
211
+ train_emb = model.encode(train_data)
212
+ x_batch = torch.stack(train_emb)
213
+
214
+ test_emb = model.encode(test_data)
215
+ x_batch_test = torch.stack(test_emb)
216
+ if not return_tensor:
217
+ x_batch = pd.DataFrame(x_batch)
218
+ x_batch_test = pd.DataFrame(x_batch_test)
219
+
220
+
221
+ elif model_type == "bart":
222
+ model = bart()
223
+ model.load()
224
+ x_batch = model.encode(train_data, return_tensor=return_tensor)
225
+ x_batch_test = model.encode(test_data, return_tensor=return_tensor)
226
+
227
+ elif model_type == "smi-ted":
228
+ # model = load_smi_ted(folder='../models/smi_ted/smi_ted_light', ckpt_filename='smi-ted-Light_40.pt')
229
+ folder = os.path.join(os.path.dirname(__file__), '../models/smi_ted/smi_ted_light')
230
+ model = load_smi_ted(folder=folder, ckpt_filename='smi-ted-Light_40.pt')
231
+ with torch.no_grad():
232
+ x_batch = model.encode(train_data, return_torch=return_tensor)
233
+ x_batch_test = model.encode(test_data, return_torch=return_tensor)
234
+
235
+ elif model_type == "mol-xl":
236
+ model = AutoModel.from_pretrained("ibm/MoLFormer-XL-both-10pct", deterministic_eval=True,
237
+ trust_remote_code=True)
238
+ tokenizer = AutoTokenizer.from_pretrained("ibm/MoLFormer-XL-both-10pct", trust_remote_code=True)
239
+
240
+ if type(train_data) == list:
241
+ inputs = tokenizer(train_data, padding=True, return_tensors="pt")
242
+ else:
243
+ inputs = tokenizer(list(train_data.values), padding=True, return_tensors="pt")
244
+
245
+ with torch.no_grad():
246
+ outputs = model(**inputs)
247
+
248
+ x_batch = outputs.pooler_output
249
+
250
+ if type(test_data) == list:
251
+ inputs = tokenizer(test_data, padding=True, return_tensors="pt")
252
+ else:
253
+ inputs = tokenizer(list(test_data.values), padding=True, return_tensors="pt")
254
+
255
+ with torch.no_grad():
256
+ outputs = model(**inputs)
257
+
258
+ x_batch_test = outputs.pooler_output
259
+
260
+ if not return_tensor:
261
+ x_batch = pd.DataFrame(x_batch)
262
+ x_batch_test = pd.DataFrame(x_batch_test)
263
+
264
+ elif model_type == 'Mordred':
265
+ all_data = train_data + test_data
266
+ calc = Calculator(descriptors, ignore_3D=True)
267
+ mol_list = [Chem.MolFromSmiles(sm) for sm in all_data]
268
+ x_all = calc.pandas(mol_list)
269
+ print (f'original mordred fv dim: {x_all.shape}')
270
+
271
+ for j in x_all.columns:
272
+ for k in range(len(x_all[j])):
273
+ i = x_all.loc[k, j]
274
+ if type(i) is mordred.error.Missing or type(i) is mordred.error.Error:
275
+ x_all.loc[k, j] = np.nan
276
+
277
+ x_all.dropna(how="any", axis = 1, inplace=True)
278
+ print (f'Nan excluded mordred fv dim: {x_all.shape}')
279
+
280
+ x_batch = x_all.iloc[:len(train_data)]
281
+ x_batch_test = x_all.iloc[len(train_data):]
282
+ # print(f'x_batch: {len(x_batch)}, x_batch_test: {len(x_batch_test)}')
283
+
284
+ elif model_type == 'MorganFingerprint':
285
+ params = {'radius':2, 'nBits':1024}
286
+
287
+ mol_train = [Chem.MolFromSmiles(sm) for sm in train_data]
288
+ mol_test = [Chem.MolFromSmiles(sm) for sm in test_data]
289
+
290
+ x_batch = []
291
+ for mol in mol_train:
292
+ info = {}
293
+ fp = AllChem.GetMorganFingerprintAsBitVect(mol, **params, bitInfo=info)
294
+ vector = list(fp)
295
+ x_batch.append(vector)
296
+ x_batch = pd.DataFrame(x_batch)
297
+
298
+ x_batch_test = []
299
+ for mol in mol_test:
300
+ info = {}
301
+ fp = AllChem.GetMorganFingerprintAsBitVect(mol, **params, bitInfo=info)
302
+ vector = list(fp)
303
+ x_batch_test.append(vector)
304
+ x_batch_test = pd.DataFrame(x_batch_test)
305
+
306
+ return x_batch, x_batch_test
307
+
308
+ def single_modal(model,dataset=None, downstream_model=None, params=None, x_train=None, x_test=None, y_train=None, y_test=None):
309
+ print(model)
310
+ alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "Molformer": "mol-xl", "SMI-TED": "smi-ted"}
311
+ data = avail_models(raw=True)
312
+ df = pd.DataFrame(data)
313
+ #print(list(df["Name"].values))
314
+
315
+ if model in list(df["Name"].values):
316
+ model_type = model
317
+ elif alias[model] in list(df["Name"].values):
318
+ model_type = alias[model]
319
+ else:
320
+ print("Model not available")
321
+ return
322
+
323
+
324
+ data = avail_datasets()
325
+ df = pd.DataFrame(data)
326
+ #print(list(df["Dataset"].values))
327
+
328
+ if dataset in list(df["Dataset"].values):
329
+ task = dataset
330
+ with open(f"representation/{task}_{model_type}.pkl", "rb") as f1:
331
+ x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1)
332
+ print(f" Representation loaded successfully")
333
+
334
+ elif x_train==None:
335
+
336
+ print("Custom Dataset")
337
+ #return
338
+ components = dataset.split(",")
339
+ train_data = pd.read_csv(components[0])[components[2]]
340
+ test_data = pd.read_csv(components[1])[components[2]]
341
+
342
+ y_batch = pd.read_csv(components[0])[components[3]]
343
+ y_batch_test = pd.read_csv(components[1])[components[3]]
344
+
345
+
346
+ x_batch, x_batch_test = get_representation(train_data,test_data,model_type)
347
+
348
+
349
+
350
+ print(f" Representation loaded successfully")
351
+
352
+ else:
353
+
354
+ y_batch = y_train
355
+ y_batch_test = y_test
356
+ x_batch, x_batch_test = get_representation(x_train, x_test, model_type)
357
+
358
+ # exclude row containing Nan value
359
+ if isinstance(x_batch, torch.Tensor):
360
+ x_batch = pd.DataFrame(x_batch)
361
+ nan_indices = x_batch.index[x_batch.isna().any(axis=1)]
362
+ if len(nan_indices) > 0:
363
+ x_batch.dropna(inplace = True)
364
+ for index in sorted(nan_indices, reverse=True):
365
+ del y_batch[index]
366
+ print(f'x_batch Nan index: {nan_indices}')
367
+ print(f'x_batch shape: {x_batch.shape}, y_batch len: {len(y_batch)}')
368
+
369
+ if isinstance(x_batch_test, torch.Tensor):
370
+ x_batch_test = pd.DataFrame(x_batch_test)
371
+ nan_indices = x_batch_test.index[x_batch_test.isna().any(axis=1)]
372
+ if len(nan_indices) > 0:
373
+ x_batch_test.dropna(inplace = True)
374
+ for index in sorted(nan_indices, reverse=True):
375
+ del y_batch_test[index]
376
+ print(f'x_batch_test Nan index: {nan_indices}')
377
+ print(f'x_batch_test shape: {x_batch_test.shape}, y_batch_test len: {len(y_batch_test)}')
378
+
379
+ print(f" Calculating ROC AUC Score ...")
380
+
381
+ if downstream_model == "XGBClassifier":
382
+ if params == None:
383
+ xgb_predict_concat = XGBClassifier()
384
+ else:
385
+ xgb_predict_concat = XGBClassifier(**params) # n_estimators=5000, learning_rate=0.01, max_depth=10
386
+ xgb_predict_concat.fit(x_batch, y_batch)
387
+
388
+ y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
389
+
390
+ roc_auc = roc_auc_score(y_batch_test, y_prob)
391
+ fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
392
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
393
+
394
+ try:
395
+ with open(f"plot_emb/{task}_{model_type}.pkl", "rb") as f1:
396
+ class_0,class_1 = pickle.load(f1)
397
+ except:
398
+ print("Generating latent plots")
399
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
400
+ verbose=False)
401
+ n_samples = np.minimum(1000, len(x_batch))
402
+
403
+ try:x = y_batch.values[:n_samples]
404
+ except: x = y_batch[:n_samples]
405
+ index_0 = [index for index in range(len(x)) if x[index] == 0]
406
+ index_1 = [index for index in range(len(x)) if x[index] == 1]
407
+
408
+ try:
409
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
410
+ class_0 = features_umap[index_0]
411
+ class_1 = features_umap[index_1]
412
+ except:
413
+ class_0 = []
414
+ class_1 = []
415
+ print("Generating latent plots : Done")
416
+
417
+ #vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
418
+
419
+ result = f"ROC-AUC Score: {roc_auc:.4f}"
420
+
421
+ return result, roc_auc,fpr, tpr, class_0, class_1
422
+
423
+ elif downstream_model == "DefaultClassifier":
424
+ xgb_predict_concat = XGBClassifier() # n_estimators=5000, learning_rate=0.01, max_depth=10
425
+ xgb_predict_concat.fit(x_batch, y_batch)
426
+
427
+ y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
428
+
429
+ roc_auc = roc_auc_score(y_batch_test, y_prob)
430
+ fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
431
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
432
+
433
+ try:
434
+ with open(f"plot_emb/{task}_{model_type}.pkl", "rb") as f1:
435
+ class_0,class_1 = pickle.load(f1)
436
+ except:
437
+ print("Generating latent plots")
438
+ reducer = umap.UMAP(metric='euclidean', n_neighbors= 10, n_components=2, low_memory=True, min_dist=0.1, verbose=False)
439
+ n_samples = np.minimum(1000,len(x_batch))
440
+
441
+ try:
442
+ x = y_batch.values[:n_samples]
443
+ except:
444
+ x = y_batch[:n_samples]
445
+
446
+ try:
447
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
448
+ index_0 = [index for index in range(len(x)) if x[index] == 0]
449
+ index_1 = [index for index in range(len(x)) if x[index] == 1]
450
+
451
+ class_0 = features_umap[index_0]
452
+ class_1 = features_umap[index_1]
453
+ except:
454
+ class_0 = []
455
+ class_1 = []
456
+
457
+ print("Generating latent plots : Done")
458
+
459
+ #vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
460
+
461
+ result = f"ROC-AUC Score: {roc_auc:.4f}"
462
+
463
+ return result, roc_auc,fpr, tpr, class_0, class_1
464
+
465
+ elif downstream_model == "SVR":
466
+ if params == None:
467
+ regressor = SVR()
468
+ else:
469
+ regressor = SVR(**params)
470
+ model = TransformedTargetRegressor(regressor= regressor,
471
+ transformer = MinMaxScaler(feature_range=(-1, 1))
472
+ ).fit(x_batch,y_batch)
473
+
474
+ y_prob = model.predict(x_batch_test)
475
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
476
+
477
+ print(f"RMSE Score: {RMSE_score:.4f}")
478
+ result = f"RMSE Score: {RMSE_score:.4f}"
479
+
480
+ print("Generating latent plots")
481
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
482
+ verbose=False)
483
+ n_samples = np.minimum(1000, len(x_batch))
484
+
485
+ try: x = y_batch.values[:n_samples]
486
+ except: x = y_batch[:n_samples]
487
+ #index_0 = [index for index in range(len(x)) if x[index] == 0]
488
+ #index_1 = [index for index in range(len(x)) if x[index] == 1]
489
+
490
+ try:
491
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
492
+ class_0 = features_umap#[index_0]
493
+ class_1 = features_umap#[index_1]
494
+ except:
495
+ class_0 = []
496
+ class_1 = []
497
+ print("Generating latent plots : Done")
498
+
499
+ return result, RMSE_score,y_batch_test, y_prob, class_0, class_1
500
+
501
+ elif downstream_model == "Kernel Ridge":
502
+ if params == None:
503
+ regressor = KernelRidge()
504
+ else:
505
+ regressor = KernelRidge(**params)
506
+ model = TransformedTargetRegressor(regressor=regressor,
507
+ transformer=MinMaxScaler(feature_range=(-1, 1))
508
+ ).fit(x_batch, y_batch)
509
+
510
+ y_prob = model.predict(x_batch_test)
511
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
512
+
513
+ print(f"RMSE Score: {RMSE_score:.4f}")
514
+ result = f"RMSE Score: {RMSE_score:.4f}"
515
+
516
+ print("Generating latent plots")
517
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
518
+ verbose=False)
519
+ n_samples = np.minimum(1000, len(x_batch))
520
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
521
+ try: x = y_batch.values[:n_samples]
522
+ except: x = y_batch[:n_samples]
523
+ # index_0 = [index for index in range(len(x)) if x[index] == 0]
524
+ # index_1 = [index for index in range(len(x)) if x[index] == 1]
525
+
526
+ class_0 = features_umap#[index_0]
527
+ class_1 = features_umap#[index_1]
528
+ print("Generating latent plots : Done")
529
+
530
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
531
+
532
+
533
+ elif downstream_model == "Linear Regression":
534
+ if params == None:
535
+ regressor = LinearRegression()
536
+ else:
537
+ regressor = LinearRegression(**params)
538
+ model = TransformedTargetRegressor(regressor=regressor,
539
+ transformer=MinMaxScaler(feature_range=(-1, 1))
540
+ ).fit(x_batch, y_batch)
541
+
542
+ y_prob = model.predict(x_batch_test)
543
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
544
+
545
+ print(f"RMSE Score: {RMSE_score:.4f}")
546
+ result = f"RMSE Score: {RMSE_score:.4f}"
547
+
548
+ print("Generating latent plots")
549
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
550
+ verbose=False)
551
+ n_samples = np.minimum(1000, len(x_batch))
552
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
553
+ try:x = y_batch.values[:n_samples]
554
+ except: x = y_batch[:n_samples]
555
+ # index_0 = [index for index in range(len(x)) if x[index] == 0]
556
+ # index_1 = [index for index in range(len(x)) if x[index] == 1]
557
+
558
+ class_0 = features_umap#[index_0]
559
+ class_1 = features_umap#[index_1]
560
+ print("Generating latent plots : Done")
561
+
562
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
563
+
564
+
565
+ elif downstream_model == "DefaultRegressor":
566
+ regressor = SVR(kernel="rbf", degree=3, C=5, gamma="scale", epsilon=0.01)
567
+ model = TransformedTargetRegressor(regressor=regressor,
568
+ transformer=MinMaxScaler(feature_range=(-1, 1))
569
+ ).fit(x_batch, y_batch)
570
+
571
+ y_prob = model.predict(x_batch_test)
572
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
573
+
574
+ print(f"RMSE Score: {RMSE_score:.4f}")
575
+ result = f"RMSE Score: {RMSE_score:.4f}"
576
+
577
+ print("Generating latent plots")
578
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
579
+ verbose=False)
580
+ n_samples = np.minimum(1000, len(x_batch))
581
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
582
+ try:x = y_batch.values[:n_samples]
583
+ except: x = y_batch[:n_samples]
584
+ # index_0 = [index for index in range(len(x)) if x[index] == 0]
585
+ # index_1 = [index for index in range(len(x)) if x[index] == 1]
586
+
587
+ class_0 = features_umap#[index_0]
588
+ class_1 = features_umap#[index_1]
589
+ print("Generating latent plots : Done")
590
+
591
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
592
+
593
+
594
+ def multi_modal(model_list,dataset=None, downstream_model=None,params=None, x_train=None, x_test=None, y_train=None, y_test=None):
595
+ #print(model_list)
596
+ data = avail_datasets()
597
+ df = pd.DataFrame(data)
598
+ list(df["Dataset"].values)
599
+
600
+ if dataset in list(df["Dataset"].values):
601
+ task = dataset
602
+ predefined = True
603
+ elif x_train==None:
604
+ predefined = False
605
+ components = dataset.split(",")
606
+ train_data = pd.read_csv(components[0])[components[2]]
607
+ test_data = pd.read_csv(components[1])[components[2]]
608
+
609
+ y_batch = pd.read_csv(components[0])[components[3]]
610
+ y_batch_test = pd.read_csv(components[1])[components[3]]
611
+
612
+ print("Custom Dataset loaded")
613
+ else:
614
+ predefined = False
615
+ y_batch = y_train
616
+ y_batch_test = y_test
617
+ train_data = x_train
618
+ test_data = x_test
619
+
620
+ data = avail_models(raw=True)
621
+ df = pd.DataFrame(data)
622
+ list(df["Name"].values)
623
+
624
+ alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "Molformer": "mol-xl","SMI-TED":"smi-ted", "Mordred": "Mordred", "MorganFingerprint": "MorganFingerprint"}
625
+ #if set(model_list).issubset(list(df["Name"].values)):
626
+ if set(model_list).issubset(list(alias.keys())):
627
+ for i, model in enumerate(model_list):
628
+ if model in alias.keys():
629
+ model_type = alias[model]
630
+ else:
631
+ model_type = model
632
+
633
+ if i == 0:
634
+ if predefined:
635
+ with open(f"representation/{task}_{model_type}.pkl", "rb") as f1:
636
+ x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1)
637
+ print(f" Loaded representation/{task}_{model_type}.pkl")
638
+ else:
639
+ x_batch, x_batch_test = get_representation(train_data, test_data, model_type)
640
+ x_batch = pd.DataFrame(x_batch)
641
+ x_batch_test = pd.DataFrame(x_batch_test)
642
+
643
+ else:
644
+ if predefined:
645
+ with open(f"representation/{task}_{model_type}.pkl", "rb") as f1:
646
+ x_batch_1, y_batch_1, x_batch_test_1, y_batch_test_1 = pickle.load(f1)
647
+ print(f" Loaded representation/{task}_{model_type}.pkl")
648
+ else:
649
+ x_batch_1, x_batch_test_1 = get_representation(train_data, test_data, model_type)
650
+ x_batch_1 = pd.DataFrame(x_batch_1)
651
+ x_batch_test_1 = pd.DataFrame(x_batch_test_1)
652
+
653
+ x_batch = pd.concat([x_batch, x_batch_1], axis=1)
654
+ x_batch_test = pd.concat([x_batch_test, x_batch_test_1], axis=1)
655
+
656
+ else:
657
+ print("Model not available")
658
+ return
659
+
660
+ num_columns = x_batch_test.shape[1]
661
+ x_batch_test.columns = [f'{i + 1}' for i in range(num_columns)]
662
+
663
+ num_columns = x_batch.shape[1]
664
+ x_batch.columns = [f'{i + 1}' for i in range(num_columns)]
665
+
666
+ # exclude row containing Nan value
667
+ if isinstance(x_batch, torch.Tensor):
668
+ x_batch = pd.DataFrame(x_batch)
669
+ nan_indices = x_batch.index[x_batch.isna().any(axis=1)]
670
+ if len(nan_indices) > 0:
671
+ x_batch.dropna(inplace = True)
672
+ for index in sorted(nan_indices, reverse=True):
673
+ del y_batch[index]
674
+ print(f'x_batch Nan index: {nan_indices}')
675
+ print(f'x_batch shape: {x_batch.shape}, y_batch len: {len(y_batch)}')
676
+
677
+ if isinstance(x_batch_test, torch.Tensor):
678
+ x_batch_test = pd.DataFrame(x_batch_test)
679
+ nan_indices = x_batch_test.index[x_batch_test.isna().any(axis=1)]
680
+ if len(nan_indices) > 0:
681
+ x_batch_test.dropna(inplace = True)
682
+ for index in sorted(nan_indices, reverse=True):
683
+ del y_batch_test[index]
684
+ print(f'x_batch_test Nan index: {nan_indices}')
685
+ print(f'x_batch_test shape: {x_batch_test.shape}, y_batch_test len: {len(y_batch_test)}')
686
+
687
+ print(f"Representations loaded successfully")
688
+ try:
689
+ with open(f"plot_emb/{task}_multi.pkl", "rb") as f1:
690
+ class_0, class_1 = pickle.load(f1)
691
+ except:
692
+ print("Generating latent plots")
693
+ reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
694
+ verbose=False)
695
+ n_samples = np.minimum(1000, len(x_batch))
696
+ features_umap = reducer.fit_transform(x_batch[:n_samples])
697
+
698
+ if "Classifier" in downstream_model:
699
+ try: x = y_batch.values[:n_samples]
700
+ except: x = y_batch[:n_samples]
701
+ index_0 = [index for index in range(len(x)) if x[index] == 0]
702
+ index_1 = [index for index in range(len(x)) if x[index] == 1]
703
+
704
+ class_0 = features_umap[index_0]
705
+ class_1 = features_umap[index_1]
706
+
707
+ else:
708
+ class_0 = features_umap
709
+ class_1 = features_umap
710
+
711
+ print("Generating latent plots : Done")
712
+
713
+ print(f" Calculating ROC AUC Score ...")
714
+
715
+
716
+ if downstream_model == "XGBClassifier":
717
+ if params == None:
718
+ xgb_predict_concat = XGBClassifier()
719
+ else:
720
+ xgb_predict_concat = XGBClassifier(**params)#n_estimators=5000, learning_rate=0.01, max_depth=10)
721
+ xgb_predict_concat.fit(x_batch, y_batch)
722
+
723
+ y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
724
+
725
+
726
+ roc_auc = roc_auc_score(y_batch_test, y_prob)
727
+ fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
728
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
729
+
730
+ #vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
731
+
732
+ #vizualize(x_batch_test, y_batch_test)
733
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
734
+ result = f"ROC-AUC Score: {roc_auc:.4f}"
735
+
736
+ return result, roc_auc,fpr, tpr, class_0, class_1
737
+
738
+ elif downstream_model == "DefaultClassifier":
739
+ xgb_predict_concat = XGBClassifier()#n_estimators=5000, learning_rate=0.01, max_depth=10)
740
+ xgb_predict_concat.fit(x_batch, y_batch)
741
+
742
+ y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
743
+
744
+
745
+ roc_auc = roc_auc_score(y_batch_test, y_prob)
746
+ fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
747
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
748
+
749
+ #vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
750
+
751
+ #vizualize(x_batch_test, y_batch_test)
752
+ print(f"ROC-AUC Score: {roc_auc:.4f}")
753
+ result = f"ROC-AUC Score: {roc_auc:.4f}"
754
+
755
+ return result, roc_auc,fpr, tpr, class_0, class_1
756
+
757
+ elif downstream_model == "SVR":
758
+ if params == None:
759
+ regressor = SVR()
760
+ else:
761
+ regressor = SVR(**params)
762
+ model = TransformedTargetRegressor(regressor= regressor,
763
+ transformer = MinMaxScaler(feature_range=(-1, 1))
764
+ ).fit(x_batch,y_batch)
765
+
766
+ y_prob = model.predict(x_batch_test)
767
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
768
+
769
+ print(f"RMSE Score: {RMSE_score:.4f}")
770
+ result = f"RMSE Score: {RMSE_score:.4f}"
771
+
772
+ return result, RMSE_score,y_batch_test, y_prob, class_0, class_1
773
+
774
+ elif downstream_model == "Linear Regression":
775
+ if params == None:
776
+ regressor = LinearRegression()
777
+ else:
778
+ regressor = LinearRegression(**params)
779
+ model = TransformedTargetRegressor(regressor=regressor,
780
+ transformer=MinMaxScaler(feature_range=(-1, 1))
781
+ ).fit(x_batch, y_batch)
782
+
783
+ y_prob = model.predict(x_batch_test)
784
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
785
+
786
+ print(f"RMSE Score: {RMSE_score:.4f}")
787
+ result = f"RMSE Score: {RMSE_score:.4f}"
788
+
789
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
790
+
791
+ elif downstream_model == "Kernel Ridge":
792
+ if params == None:
793
+ regressor = KernelRidge()
794
+ else:
795
+ regressor = KernelRidge(**params)
796
+ model = TransformedTargetRegressor(regressor=regressor,
797
+ transformer=MinMaxScaler(feature_range=(-1, 1))
798
+ ).fit(x_batch, y_batch)
799
+
800
+ y_prob = model.predict(x_batch_test)
801
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
802
+
803
+ print(f"RMSE Score: {RMSE_score:.4f}")
804
+ result = f"RMSE Score: {RMSE_score:.4f}"
805
+
806
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
807
+
808
+ elif downstream_model == "DefaultRegressor":
809
+ regressor = SVR(kernel="rbf", degree=3, C=5, gamma="scale", epsilon=0.01)
810
+ model = TransformedTargetRegressor(regressor=regressor,
811
+ transformer=MinMaxScaler(feature_range=(-1, 1))
812
+ ).fit(x_batch, y_batch)
813
+
814
+ y_prob = model.predict(x_batch_test)
815
+ RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
816
+
817
+ print(f"RMSE Score: {RMSE_score:.4f}")
818
+ result = f"RMSE Score: {RMSE_score:.4f}"
819
+
820
+ return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
821
+
822
+
823
+
824
+ def finetune_optuna(x_batch,y_batch, x_batch_test, y_test ):
825
+ print(f" Finetuning with Optuna and calculating ROC AUC Score ...")
826
+ X_train = x_batch.values
827
+ y_train = y_batch.values
828
+ X_test = x_batch_test.values
829
+ y_test = y_test.values
830
+ def objective(trial):
831
+ # Define parameters to be optimized
832
+ params = {
833
+ # 'objective': 'binary:logistic',
834
+ 'eval_metric': 'auc',
835
+ 'verbosity': 0,
836
+ 'n_estimators': trial.suggest_int('n_estimators', 1000, 10000),
837
+ # 'booster': trial.suggest_categorical('booster', ['gbtree', 'gblinear', 'dart']),
838
+ # 'lambda': trial.suggest_loguniform('lambda', 1e-8, 1.0),
839
+ 'alpha': trial.suggest_loguniform('alpha', 1e-8, 1.0),
840
+ 'max_depth': trial.suggest_int('max_depth', 1, 12),
841
+ # 'eta': trial.suggest_loguniform('eta', 1e-8, 1.0),
842
+ # 'gamma': trial.suggest_loguniform('gamma', 1e-8, 1.0),
843
+ # 'grow_policy': trial.suggest_categorical('grow_policy', ['depthwise', 'lossguide']),
844
+ # "subsample": trial.suggest_float("subsample", 0.05, 1.0),
845
+ # "colsample_bytree": trial.suggest_float("colsample_bytree", 0.05, 1.0),
846
+ }
847
+
848
+ # Train XGBoost model
849
+ dtrain = xgb.DMatrix(X_train, label=y_train)
850
+ dtest = xgb.DMatrix(X_test, label=y_test)
851
+
852
+ model = xgb.train(params, dtrain)
853
+
854
+ # Predict probabilities
855
+ y_pred = model.predict(dtest)
856
+
857
+ # Calculate ROC AUC score
858
+ roc_auc = roc_auc_score(y_test, y_pred)
859
+ print("ROC_AUC : ", roc_auc)
860
+
861
+ return roc_auc
862
+
863
+ def add_new_model():
864
+ models = avail_models(raw=True)
865
+
866
+ # Function to display models
867
+ def display_models():
868
+ for model in models:
869
+ model_display = f"Name: {model['Name']}, Description: {model['Description']}, Timestamp: {model['Timestamp']}"
870
+ print(model_display)
871
+
872
+ # Function to update models
873
+ def update_models(new_name, new_description, new_path):
874
+ new_model = {
875
+ "Name": new_name,
876
+ "Description": new_description,
877
+ "Timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
878
+ #"path": new_path
879
+ }
880
+ models.append(new_model)
881
+ with open("models.json", "w") as outfile:
882
+ json.dump(models, outfile)
883
+
884
+ print("Model uploaded and updated successfully!")
885
+ list_models()
886
+ #display_models()
887
+
888
+ # Widgets
889
+ name_text = widgets.Text(description="Name:", layout=Layout(width='50%'))
890
+ description_text = widgets.Text(description="Description:", layout=Layout(width='50%'))
891
+ path_text = widgets.Text(description="Path:", layout=Layout(width='50%'))
892
+
893
+ def browse_callback(b):
894
+ root = tk.Tk()
895
+ root.withdraw() # Hide the main window
896
+ file_path = filedialog.askopenfilename(title="Select a Model File")
897
+ if file_path:
898
+ path_text.value = file_path
899
+
900
+ browse_button = widgets.Button(description="Browse")
901
+ browse_button.on_click(browse_callback)
902
+
903
+ def submit_callback(b):
904
+ update_models(name_text.value, description_text.value, path_text.value)
905
+
906
+ submit_button = widgets.Button(description="Submit")
907
+ submit_button.on_click(submit_callback)
908
+
909
+ # Display widgets
910
+ display(VBox([name_text, description_text, path_text, browse_button, submit_button]))
911
+
912
+
913
+ def add_new_dataset():
914
+ # Sample data
915
+ datasets = avail_datasets()
916
+
917
+ # Function to display models
918
+ def display_datasets():
919
+ for dataset in datasets:
920
+ dataset_display = f"Name: {dataset['Dataset']}, Input: {dataset['Input']},Output: {dataset['Output']},Path: {dataset['Path']}, Timestamp: {dataset['Timestamp']}"
921
+
922
+ # Function to update models
923
+ def update_datasets(new_dataset, new_input, new_output, new_path):
924
+ new_model = {
925
+ "Dataset": new_dataset,
926
+ "Input": new_input,
927
+ "Output": new_output,
928
+ "Timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
929
+ "Path": os.path.basename(new_path)
930
+ }
931
+ datasets.append(new_model)
932
+ with open("datasets.json", "w") as outfile:
933
+ json.dump(datasets, outfile)
934
+
935
+ print("Dataset uploaded and updated successfully!")
936
+ list_data()
937
+
938
+
939
+ # Widgets
940
+ dataset_text = widgets.Text(description="Dataset:", layout=Layout(width='50%'))
941
+ input_text = widgets.Text(description="Input:", layout=Layout(width='50%'))
942
+ output_text = widgets.Text(description="Output:", layout=Layout(width='50%'))
943
+ path_text = widgets.Text(description="Path:", layout=Layout(width='50%'))
944
+
945
+ def browse_callback(b):
946
+ root = tk.Tk()
947
+ root.withdraw() # Hide the main window
948
+ file_path = filedialog.askopenfilename(title="Select a Dataset File")
949
+ if file_path:
950
+ path_text.value = file_path
951
+
952
+ browse_button = widgets.Button(description="Browse")
953
+ browse_button.on_click(browse_callback)
954
+
955
+ def submit_callback(b):
956
+ update_datasets(dataset_text.value, input_text.value, output_text.value, path_text.value)
957
+
958
+ submit_button = widgets.Button(description="Submit")
959
+ submit_button.on_click(submit_callback)
960
+
961
+ display(VBox([dataset_text, input_text, output_text, path_text, browse_button, submit_button]))
962
+
963
+
964
+
models/mhg_model/.DS_Store ADDED
Binary file (6.15 kB). View file
 
models/mhg_model/README.md ADDED
@@ -0,0 +1,75 @@
1
+ # mhg-gnn
2
+
3
+ This repository provides the PyTorch source code associated with our publication, "MHG-GNN: Combination of Molecular Hypergraph Grammar with Graph Neural Network".
4
+
5
+ **Paper:** [Arxiv Link](https://arxiv.org/pdf/2309.16374)
6
+
7
+ ![mhg-gnn](images/mhg_example1.png)
8
+
9
+ ## Introduction
10
+
11
+ We present MHG-GNN, an autoencoder architecture
12
+ that has an encoder based on a GNN and a decoder based on a sequential model with MHG.
13
+ Since the encoder is a GNN variant, MHG-GNN can accept any molecule as input, and
14
+ demonstrates high predictive performance on molecular graph data.
15
+ In addition, the decoder inherits MHG's theoretical guarantee of always generating a structurally valid molecule as output.
16
+
17
+ ## Table of Contents
18
+
19
+ 1. [Getting Started](#getting-started)
20
+    1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
21
+    2. [Installation](#installation)
22
+ 2. [Feature Extraction](#feature-extraction)
23
+
24
+ ## Getting Started
25
+
26
+ **This code and environment have been tested on Intel E5-2667 CPUs at 3.30GHz and NVIDIA A100 Tensor Core GPUs.**
27
+
28
+ ### Pretrained Models and Training Logs
29
+
30
+ We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. For the model weights, see the [HuggingFace Link](https://huggingface.co/ibm/materials.mhg-ged/blob/main/mhggnn_pretrained_model_0724_2023.pickle)
31
+
32
+ Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs.
33
+
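+ One possible way to fetch the checkpoint is via the `huggingface_hub` client (an assumption for illustration; this repository does not prescribe a download method):
+
+ ```python
+ from huggingface_hub import hf_hub_download  # assumed extra dependency
+
+ ckpt_path = hf_hub_download(
+     repo_id="ibm/materials.mhg-ged",
+     filename="mhggnn_pretrained_model_0724_2023.pickle",
+ )
+ # then copy or move the downloaded file into models/ as described above
+ print(ckpt_path)
+ ```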
34
+ ### Installation
35
+
36
+ We recommend creating a virtual environment. For example:
37
+
38
+ ```
39
+ python3 -m venv .venv
40
+ . .venv/bin/activate
41
+ ```
42
+
43
+ Run the following commands once the virtual environment is activated:
44
+
45
+ ```
46
+ git clone [email protected]:CMD-TRL/mhg-gnn.git
47
+ cd ./mhg-gnn
48
+ pip install .
49
+ ```
50
+
51
+ ## Feature Extraction
52
+
53
+ The example notebook [mhg-gnn_encoder_decoder_example.ipynb](notebooks/mhg-gnn_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks.
54
+
55
+ To load mhg-gnn, you can simply use:
56
+
57
+ ```python
58
+ import torch
59
+ import load
60
+
61
+ model = load.load()
62
+ ```
63
+
64
+ To encode SMILES into embeddings, you can use:
65
+
66
+ ```python
67
+ with torch.no_grad():
68
+ repr = model.encode(["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"])
69
+ ```
70
+
71
+ To decode embeddings back into SMILES strings, use the `decode` function:
72
+
73
+ ```python
74
+ orig = model.decode(repr)
75
+ ```
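+
+ As a minimal illustrative sketch, assuming `scikit-learn` is available and that `model.encode` returns one embedding per input SMILES, the embeddings can serve as a feature matrix for a downstream property model; the regressor choice and target values below are placeholders for illustration only:
+
+ ```python
+ import torch
+ from sklearn.ensemble import RandomForestRegressor  # assumed extra dependency
+
+ smiles_list = ["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"]
+ y = [0.5, 1.2, 0.8]  # placeholder property values, illustration only
+
+ with torch.no_grad():
+     emb = model.encode(smiles_list)
+
+ # stack the per-molecule embeddings into a 2-D feature matrix
+ X = torch.stack([torch.as_tensor(e).flatten() for e in emb]).cpu().numpy()
+
+ reg = RandomForestRegressor(n_estimators=100, random_state=0).fit(X, y)
+ print(reg.predict(X))
+ ```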
models/mhg_model/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # -*- coding:utf-8 -*-
2
+ # Rhizome
3
+ # Version beta 0.0, August 2023
4
+ # Property of IBM Research, Accelerated Discovery
5
+ #
models/mhg_model/graph_grammar/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+ """
8
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
9
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
10
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
11
+ """
12
+
13
+ """ Title """
14
+
15
+ __author__ = "Hiroshi Kajino <[email protected]>"
16
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
17
+ __version__ = "0.1"
18
+ __date__ = "Jan 1 2018"
19
+
models/mhg_model/graph_grammar/algo/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding:utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 1 2018"
20
+
models/mhg_model/graph_grammar/algo/tree_decomposition.py ADDED
@@ -0,0 +1,821 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2017"
18
+ __version__ = "0.1"
19
+ __date__ = "Dec 11 2017"
20
+
21
+ from copy import deepcopy
22
+ from itertools import combinations
23
+ from ..hypergraph import Hypergraph
24
+ import networkx as nx
25
+ import numpy as np
26
+
27
+
28
+ class CliqueTree(nx.Graph):
29
+ ''' clique tree object
30
+
31
+ Attributes
32
+ ----------
33
+ hg : Hypergraph
34
+ This hypergraph will be decomposed.
35
+ root_hg : Hypergraph
36
+ Hypergraph on the root node.
37
+ ident_node_dict : dict
38
+ ident_node_dict[key_node] gives a list of nodes that are identical (i.e., the adjacent hyperedges are common)
39
+ '''
40
+ def __init__(self, hg=None, **kwargs):
41
+ self.hg = deepcopy(hg)
42
+ if self.hg is not None:
43
+ self.ident_node_dict = self.hg.get_identical_node_dict()
44
+ else:
45
+ self.ident_node_dict = {}
46
+ super().__init__(**kwargs)
47
+
48
+ @property
49
+ def root_hg(self):
50
+ ''' return the hypergraph on the root node
51
+ '''
52
+ return self.nodes[0]['subhg']
53
+
54
+ @root_hg.setter
55
+ def root_hg(self, hypergraph):
56
+ ''' set the hypergraph on the root node
57
+ '''
58
+ self.nodes[0]['subhg'] = hypergraph
59
+
60
+ def insert_subhg(self, subhypergraph: Hypergraph) -> None:
61
+ ''' insert a subhypergraph, which is extracted from a root hypergraph, into the tree.
62
+
63
+ Parameters
64
+ ----------
65
+ subhg : Hypergraph
66
+ '''
67
+ num_nodes = self.number_of_nodes()
68
+ self.add_node(num_nodes, subhg=subhypergraph)
69
+ self.add_edge(num_nodes, 0)
70
+ adj_nodes = deepcopy(list(self.adj[0].keys()))
71
+ for each_node in adj_nodes:
72
+ if len(self.nodes[each_node]["subhg"].nodes.intersection(
73
+ self.nodes[num_nodes]["subhg"].nodes)\
74
+ - self.root_hg.nodes) != 0 and each_node != num_nodes:
75
+ self.remove_edge(0, each_node)
76
+ self.add_edge(each_node, num_nodes)
77
+
78
+ def to_irredundant(self) -> None:
79
+ ''' convert the clique tree to be irredundant
80
+ '''
81
+ for each_node in self.hg.nodes:
82
+ subtree = self.subgraph([
83
+ each_tree_node for each_tree_node in self.nodes()\
84
+ if each_node in self.nodes[each_tree_node]["subhg"].nodes]).copy()
85
+ leaf_node_list = [x for x in subtree.nodes() if subtree.degree(x)==1]
86
+ redundant_leaf_node_list = []
87
+ for each_leaf_node in leaf_node_list:
88
+ if len(self.nodes[each_leaf_node]["subhg"].adj_edges(each_node)) == 0:
89
+ redundant_leaf_node_list.append(each_leaf_node)
90
+ for each_red_leaf_node in redundant_leaf_node_list:
91
+ current_node = each_red_leaf_node
92
+ while subtree.degree(current_node) == 1 \
93
+ and len(subtree.nodes[current_node]["subhg"].adj_edges(each_node)) == 0:
94
+ self.nodes[current_node]["subhg"].remove_node(each_node)
95
+ remove_node = current_node
96
+ current_node = list(dict(subtree[remove_node]).keys())[0]
97
+ subtree.remove_node(remove_node)
98
+
99
+ fixed_node_set = deepcopy(self.nodes)
100
+ for each_node in fixed_node_set:
101
+ if self.nodes[each_node]["subhg"].num_edges == 0:
102
+ if len(self[each_node]) == 1:
103
+ self.remove_node(each_node)
104
+ elif len(self[each_node]) == 2:
105
+ self.add_edge(*self[each_node])
106
+ self.remove_node(each_node)
107
+ else:
108
+ pass
109
+ else:
110
+ pass
111
+
112
+ redundant = True
113
+ while redundant:
114
+ redundant = False
115
+ fixed_edge_set = deepcopy(self.edges)
116
+ remove_node_set = set()
117
+ for node_1, node_2 in fixed_edge_set:
118
+ if node_1 in remove_node_set or node_2 in remove_node_set:
119
+ pass
120
+ else:
121
+ if self.nodes[node_1]['subhg'].is_subhg(self.nodes[node_2]['subhg']):
122
+ redundant = True
123
+ adj_node_list = set(self.adj[node_1]) - {node_2}
124
+ self.remove_node(node_1)
125
+ remove_node_set.add(node_1)
126
+ for each_node in adj_node_list:
127
+ self.add_edge(node_2, each_node)
128
+
129
+ elif self.nodes[node_2]['subhg'].is_subhg(self.nodes[node_1]['subhg']):
130
+ redundant = True
131
+ adj_node_list = set(self.adj[node_2]) - {node_1}
132
+ self.remove_node(node_2)
133
+ remove_node_set.add(node_2)
134
+ for each_node in adj_node_list:
135
+ self.add_edge(node_1, each_node)
136
+
137
+ def node_update(self, key_node: str, subhg) -> None:
138
+ """ given a pair of a hypergraph H and its subhypergraph sH, update the root hypergraph to H minus sH.
139
+
140
+ Parameters
141
+ ----------
142
+ key_node : str
143
+ key node that must be removed.
144
+ subhg : Hypergraph
145
+ """
146
+ for each_edge in subhg.edges:
147
+ self.root_hg.remove_edge(each_edge)
148
+ self.root_hg.remove_nodes(self.ident_node_dict[key_node])
149
+
150
+ adj_node_list = list(subhg.nodes)
151
+ for each_node in subhg.nodes:
152
+ if each_node not in self.ident_node_dict[key_node]:
153
+ if set(self.root_hg.adj_edges(each_node)).issubset(subhg.edges):
154
+ self.root_hg.remove_node(each_node)
155
+ adj_node_list.remove(each_node)
156
+ else:
157
+ adj_node_list.remove(each_node)
158
+
159
+ for each_node_1, each_node_2 in combinations(adj_node_list, 2):
160
+ if not self.root_hg.is_adj(each_node_1, each_node_2):
161
+ self.root_hg.add_edge(set([each_node_1, each_node_2]), attr_dict=dict(tmp=True))
162
+
163
+ subhg.remove_edges_with_attr({'tmp' : True})
164
+ self.insert_subhg(subhg)
165
+
166
+ def update(self, subhg, remove_nodes=False):
167
+ """ given a pair of a hypergraph H and its subhypergraph sH, update the root hypergraph to H minus sH.
168
+
169
+ Parameters
170
+ ----------
171
+ subhg : Hypergraph
172
+ """
173
+ for each_edge in subhg.edges:
174
+ self.root_hg.remove_edge(each_edge)
175
+ if remove_nodes:
176
+ remove_edge_list = []
177
+ for each_edge in self.root_hg.edges:
178
+ if set(self.root_hg.nodes_in_edge(each_edge)).issubset(subhg.nodes)\
179
+ and self.root_hg.edge_attr(each_edge).get('tmp', False):
180
+ remove_edge_list.append(each_edge)
181
+ self.root_hg.remove_edges(remove_edge_list)
182
+
183
+ adj_node_list = list(subhg.nodes)
184
+ for each_node in subhg.nodes:
185
+ if self.root_hg.degree(each_node) == 0:
186
+ self.root_hg.remove_node(each_node)
187
+ adj_node_list.remove(each_node)
188
+
189
+ if len(adj_node_list) != 1 and not remove_nodes:
190
+ self.root_hg.add_edge(set(adj_node_list), attr_dict=dict(tmp=True))
191
+ '''
192
+ else:
193
+ for each_node_1, each_node_2 in combinations(adj_node_list, 2):
194
+ if not self.root_hg.is_adj(each_node_1, each_node_2):
195
+ self.root_hg.add_edge(
196
+ [each_node_1, each_node_2], attr_dict=dict(tmp=True))
197
+ '''
198
+ subhg.remove_edges_with_attr({'tmp':True})
199
+ self.insert_subhg(subhg)
200
+
201
+
202
+ def _get_min_deg_node(hg, ident_node_dict: dict, mode='mol'):
203
+ if mode == 'standard':
204
+ degree_dict = hg.degrees()
205
+ min_deg_node = min(degree_dict, key=degree_dict.get)
206
+ min_deg_subhg = hg.adj_subhg(min_deg_node, ident_node_dict)
207
+ return min_deg_node, min_deg_subhg
208
+ elif mode == 'mol':
209
+ degree_dict = hg.degrees()
210
+ min_deg = min(degree_dict.values())
211
+ min_deg_node_list = [each_node for each_node in hg.nodes if degree_dict[each_node]==min_deg]
212
+ min_deg_subhg_list = [hg.adj_subhg(each_min_deg_node, ident_node_dict)
213
+ for each_min_deg_node in min_deg_node_list]
214
+ best_score = np.inf
215
+ best_idx = -1
216
+ for each_idx in range(len(min_deg_subhg_list)):
217
+ if min_deg_subhg_list[each_idx].num_nodes < best_score:
218
+ best_score = min_deg_subhg_list[each_idx].num_nodes
+ best_idx = each_idx
219
+ return min_deg_node_list[best_idx], min_deg_subhg_list[best_idx]
220
+ else:
221
+ raise ValueError
222
+
223
+
224
+ def tree_decomposition(hg, irredundant=True):
225
+ """ compute a tree decomposition of the input hypergraph
226
+
227
+ Parameters
228
+ ----------
229
+ hg : Hypergraph
230
+ hypergraph to be decomposed
231
+ irredundant : bool
232
+ if True, irredundant tree decomposition will be computed.
233
+
234
+ Returns
235
+ -------
236
+ clique_tree : nx.Graph
237
+ each node contains a subhypergraph of `hg`
238
+ """
239
+ org_hg = hg.copy()
240
+ ident_node_dict = hg.get_identical_node_dict()
241
+ clique_tree = CliqueTree(org_hg)
242
+ clique_tree.add_node(0, subhg=org_hg)
243
+ while True:
244
+ degree_dict = org_hg.degrees()
245
+ min_deg_node = min(degree_dict, key=degree_dict.get)
246
+ min_deg_subhg = org_hg.adj_subhg(min_deg_node, ident_node_dict)
247
+ if org_hg.nodes == min_deg_subhg.nodes:
248
+ break
249
+
250
+ # org_hg and min_deg_subhg are divided
251
+ clique_tree.node_update(min_deg_node, min_deg_subhg)
252
+
253
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
254
+
255
+ if irredundant:
256
+ clique_tree.to_irredundant()
257
+ return clique_tree
258
+
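+ # Usage sketch (for clarity): given a Hypergraph instance `hg`,
+ # `ct = tree_decomposition(hg)` returns a CliqueTree; each tree node n
+ # stores its subhypergraph in ct.nodes[n]['subhg'].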
259
+
260
+ def tree_decomposition_with_hrg(hg, hrg, irredundant=True, return_root=False):
261
+ ''' compute a tree decomposition given a hyperedge replacement grammar.
262
+ the resultant clique tree should induce a less compact HRG.
263
+
264
+ Parameters
265
+ ----------
266
+ hg : Hypergraph
267
+ hypergraph to be decomposed
268
+ hrg : HyperedgeReplacementGrammar
269
+ current HRG
270
+ irredundant : bool
271
+ if True, irredundant tree decomposition will be computed.
272
+
273
+ Returns
274
+ -------
275
+ clique_tree : nx.Graph
276
+ each node contains a subhypergraph of `hg`
277
+ '''
278
+ org_hg = hg.copy()
279
+ ident_node_dict = hg.get_identical_node_dict()
280
+ clique_tree = CliqueTree(org_hg)
281
+ clique_tree.add_node(0, subhg=org_hg)
282
+ root_node = 0
283
+
284
+ # construct a clique tree using HRG
285
+ success_any = True
286
+ while success_any:
287
+ success_any = False
288
+ for each_prod_rule in hrg.prod_rule_list:
289
+ org_hg, success, subhg = each_prod_rule.revert(org_hg, True)
290
+ if success:
291
+ if each_prod_rule.is_start_rule: root_node = clique_tree.number_of_nodes()
292
+ success_any = True
293
+ subhg.remove_edges_with_attr({'terminal' : False})
294
+ clique_tree.root_hg = org_hg
295
+ clique_tree.insert_subhg(subhg)
296
+
297
+ clique_tree.root_hg = org_hg
298
+
299
+ for each_edge in deepcopy(org_hg.edges):
300
+ if not org_hg.edge_attr(each_edge)['terminal']:
301
+ node_list = org_hg.nodes_in_edge(each_edge)
302
+ org_hg.remove_edge(each_edge)
303
+
304
+ for each_node_1, each_node_2 in combinations(node_list, 2):
305
+ if not org_hg.is_adj(each_node_1, each_node_2):
306
+ org_hg.add_edge([each_node_1, each_node_2], attr_dict=dict(tmp=True))
307
+
308
+ # construct a clique tree using the existing algorithm
309
+ degree_dict = org_hg.degrees()
310
+ if degree_dict:
311
+ while True:
312
+ min_deg_node, min_deg_subhg = _get_min_deg_node(org_hg, ident_node_dict)
313
+ if org_hg.nodes == min_deg_subhg.nodes: break
314
+
315
+ # org_hg and min_deg_subhg are divided
316
+ clique_tree.node_update(min_deg_node, min_deg_subhg)
317
+
318
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
319
+ if irredundant:
320
+ clique_tree.to_irredundant()
321
+
322
+ if return_root:
323
+ if root_node == 0 and 0 not in clique_tree.nodes:
324
+ root_node = clique_tree.number_of_nodes()
325
+ while root_node not in clique_tree.nodes:
326
+ root_node -= 1
327
+ elif root_node not in clique_tree.nodes:
328
+ while root_node not in clique_tree.nodes:
329
+ root_node -= 1
330
+ else:
331
+ pass
332
+ return clique_tree, root_node
333
+ else:
334
+ return clique_tree
335
+
336
+
337
+ def tree_decomposition_from_leaf(hg, irredundant=True):
338
+ """ compute a tree decomposition of the input hypergraph
339
+
340
+ Parameters
341
+ ----------
342
+ hg : Hypergraph
343
+ hypergraph to be decomposed
344
+ irredundant : bool
345
+ if True, irredundant tree decomposition will be computed.
346
+
347
+ Returns
348
+ -------
349
+ clique_tree : nx.Graph
350
+ each node contains a subhypergraph of `hg`
351
+ """
352
+ def apply_normal_decomposition(clique_tree):
353
+ degree_dict = clique_tree.root_hg.degrees()
354
+ min_deg_node = min(degree_dict, key=degree_dict.get)
355
+ min_deg_subhg = clique_tree.root_hg.adj_subhg(min_deg_node, clique_tree.ident_node_dict)
356
+ if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
357
+ return clique_tree, False
358
+ clique_tree.node_update(min_deg_node, min_deg_subhg)
359
+ return clique_tree, True
360
+
361
+ def apply_min_edge_deg_decomposition(clique_tree):
362
+ edge_degree_dict = clique_tree.root_hg.edge_degrees()
363
+ non_tmp_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \
364
+ if not clique_tree.root_hg.edge_attr(each_edge).get('tmp')]
365
+ if not non_tmp_edge_list:
366
+ return clique_tree, False
367
+ min_deg_edge = None
368
+ min_deg = np.inf
369
+ for each_edge in non_tmp_edge_list:
370
+ if min_deg > edge_degree_dict[each_edge]:
371
+ min_deg_edge = each_edge
372
+ min_deg = edge_degree_dict[each_edge]
373
+ node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge)
374
+ min_deg_subhg = clique_tree.root_hg.get_subhg(
375
+ node_list, [min_deg_edge], clique_tree.ident_node_dict)
376
+ if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
377
+ return clique_tree, False
378
+ clique_tree.update(min_deg_subhg)
379
+ return clique_tree, True
380
+
381
+ org_hg = hg.copy()
382
+ clique_tree = CliqueTree(org_hg)
383
+ clique_tree.add_node(0, subhg=org_hg)
384
+
385
+ success = True
386
+ while success:
387
+ clique_tree, success = apply_min_edge_deg_decomposition(clique_tree)
388
+ if not success:
389
+ clique_tree, success = apply_normal_decomposition(clique_tree)
390
+
391
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
392
+ if irredundant:
393
+ clique_tree.to_irredundant()
394
+ return clique_tree
395
+
396
+ def topological_tree_decomposition(
397
+ hg, irredundant=True, rip_labels=True, shrink_cycle=False, contract_cycles=False):
398
+ ''' compute a tree decomposition of the input hypergraph
399
+
400
+ Parameters
401
+ ----------
402
+ hg : Hypergraph
403
+ hypergraph to be decomposed
404
+ irredundant : bool
405
+ if True, irredundant tree decomposition will be computed.
406
+
407
+ Returns
408
+ -------
409
+ clique_tree : CliqueTree
410
+ each node contains a subhypergraph of `hg`
411
+ '''
412
+ def _contract_tree(clique_tree):
413
+ ''' contract a single leaf
414
+
415
+ Parameters
416
+ ----------
417
+ clique_tree : CliqueTree
418
+
419
+ Returns
420
+ -------
421
+ CliqueTree, bool
422
+ bool represents whether this operation succeeds or not.
423
+ '''
424
+ edge_degree_dict = clique_tree.root_hg.edge_degrees()
425
+ leaf_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \
426
+ if (not clique_tree.root_hg.edge_attr(each_edge).get('tmp'))\
427
+ and edge_degree_dict[each_edge] == 1]
428
+ if not leaf_edge_list:
429
+ return clique_tree, False
430
+ min_deg_edge = leaf_edge_list[0]
431
+ node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge)
432
+ min_deg_subhg = clique_tree.root_hg.get_subhg(
433
+ node_list, [min_deg_edge], clique_tree.ident_node_dict)
434
+ if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
435
+ return clique_tree, False
436
+ clique_tree.update(min_deg_subhg)
437
+ return clique_tree, True
438
+
439
+ def _rip_labels_from_cycles(clique_tree, org_hg):
440
+ ''' rip hyperedge-labels off
441
+
442
+ Parameters
443
+ ----------
444
+ clique_tree : CliqueTree
445
+ org_hg : Hypergraph
446
+
447
+ Returns
448
+ -------
449
+ CliqueTree, bool
450
+ bool represents whether this operation succeeds or not.
451
+ '''
452
+ ident_node_dict = clique_tree.ident_node_dict #hg.get_identical_node_dict()
453
+ for each_edge in clique_tree.root_hg.edges:
454
+ if each_edge in org_hg.edges:
455
+ if org_hg.in_cycle(each_edge):
456
+ node_list = clique_tree.root_hg.nodes_in_edge(each_edge)
457
+ subhg = clique_tree.root_hg.get_subhg(
458
+ node_list, [each_edge], ident_node_dict)
459
+ if clique_tree.root_hg.nodes == subhg.nodes:
460
+ return clique_tree, False
461
+ clique_tree.update(subhg)
462
+ '''
463
+ in_cycle_dict = {each_node: org_hg.node_attr(each_node)['is_in_ring'] for each_node in node_list}
464
+ if not all(in_cycle_dict.values()):
465
+ node_not_in_cycle = [each_node for each_node in in_cycle_dict.keys() if not in_cycle_dict[each_node]][0]
466
+ node_list = [node_not_in_cycle]
467
+ node_list.extend(clique_tree.root_hg.adj_nodes(node_not_in_cycle))
468
+ edge_list = clique_tree.root_hg.adj_edges(node_not_in_cycle)
469
+ import pdb; pdb.set_trace()
470
+ subhg = clique_tree.root_hg.get_subhg(
471
+ node_list, edge_list, ident_node_dict)
472
+
473
+ clique_tree.update(subhg)
474
+ '''
475
+ return clique_tree, True
476
+ return clique_tree, False
477
+
478
+ def _shrink_cycle(clique_tree):
479
+ ''' shrink a cycle
480
+
481
+ Parameters
482
+ ----------
483
+ clique_tree : CliqueTree
484
+
485
+ Returns
486
+ -------
487
+ CliqueTree, bool
488
+ bool represents whether this operation succeeds or not.
489
+ '''
490
+ def filter_subhg(subhg, hg, key_node):
491
+ num_nodes_cycle = 0
492
+ nodes_in_cycle_list = []
493
+ for each_node in subhg.nodes:
494
+ if hg.in_cycle(each_node):
495
+ num_nodes_cycle += 1
496
+ if each_node != key_node:
497
+ nodes_in_cycle_list.append(each_node)
498
+ if num_nodes_cycle > 3:
499
+ break
500
+ if num_nodes_cycle != 3:
501
+ return False
502
+ else:
503
+ for each_edge in hg.edges:
504
+ if set(nodes_in_cycle_list).issubset(hg.nodes_in_edge(each_edge)):
505
+ return False
506
+ return True
507
+
508
+ #ident_node_dict = hg.get_identical_node_dict()
509
+ ident_node_dict = clique_tree.ident_node_dict
510
+ for each_node in clique_tree.root_hg.nodes:
511
+ if clique_tree.root_hg.in_cycle(each_node)\
512
+ and filter_subhg(clique_tree.root_hg.adj_subhg(each_node, ident_node_dict),
513
+ clique_tree.root_hg,
514
+ each_node):
515
+ target_node = each_node
516
+ target_subhg = clique_tree.root_hg.adj_subhg(target_node, ident_node_dict)
517
+ if clique_tree.root_hg.nodes == target_subhg.nodes:
518
+ return clique_tree, False
519
+ clique_tree.update(target_subhg)
520
+ return clique_tree, True
521
+ return clique_tree, False
522
+
523
+ def _contract_cycles(clique_tree):
524
+ '''
525
+ remove a subhypergraph that looks like a cycle on a leaf.
526
+
527
+ Parameters
528
+ ----------
529
+ clique_tree : CliqueTree
530
+
531
+ Returns
532
+ -------
533
+ CliqueTree, bool
534
+ bool represents whether this operation succeeds or not.
535
+ '''
536
+ def _divide_hg(hg):
537
+ ''' divide a hypergraph into subhypergraphs such that
538
+ each subhypergraph is connected to each other in a tree-like way.
539
+
540
+ Parameters
541
+ ----------
542
+ hg : Hypergraph
543
+
544
+ Returns
545
+ -------
546
+ list of Hypergraphs
547
+ each element corresponds to a subhypergraph of `hg`
548
+ '''
549
+ for each_node in hg.nodes:
550
+ if hg.is_dividable(each_node):
551
+ adj_edges_dict = {each_edge: hg.in_cycle(each_edge) for each_edge in hg.adj_edges(each_node)}
552
+ '''
553
+ if any(adj_edges_dict.values()):
554
+ import pdb; pdb.set_trace()
555
+ edge_in_cycle = [each_key for each_key, each_val in adj_edges_dict.items() if each_val][0]
556
+ subhg1, subhg2, subhg3 = hg.divide(each_node, edge_in_cycle)
557
+ return _divide_hg(subhg1) + _divide_hg(subhg2) + _divide_hg(subhg3)
558
+ else:
559
+ '''
560
+ subhg1, subhg2 = hg.divide(each_node)
561
+ return _divide_hg(subhg1) + _divide_hg(subhg2)
562
+ return [hg]
563
+
564
+ def _is_leaf(hg, divided_subhg) -> bool:
565
+ ''' judge whether subhg is a leaf-like in the original hypergraph
566
+
567
+ Parameters
568
+ ----------
569
+ hg : Hypergraph
570
+ divided_subhg : Hypergraph
571
+ `divided_subhg` is a subhypergraph of `hg`
572
+
573
+ Returns
574
+ -------
575
+ bool
576
+ '''
577
+ '''
578
+ adj_edges_set = set([])
579
+ for each_node in divided_subhg.nodes:
580
+ adj_edges_set.update(set(hg.adj_edges(each_node)))
581
+
582
+
583
+ _hg = deepcopy(hg)
584
+ _hg.remove_subhg(divided_subhg)
585
+ if nx.is_connected(_hg.hg) != (len(adj_edges_set - divided_subhg.edges) == 1):
586
+ import pdb; pdb.set_trace()
587
+ return len(adj_edges_set - divided_subhg.edges) == 1
588
+ '''
589
+ _hg = deepcopy(hg)
590
+ _hg.remove_subhg(divided_subhg)
591
+ return nx.is_connected(_hg.hg)
592
+
593
+ subhg_list = _divide_hg(clique_tree.root_hg)
594
+ if len(subhg_list) == 1:
595
+ return clique_tree, False
596
+ else:
597
+ while len(subhg_list) > 1:
598
+ max_leaf_subhg = None
599
+ for each_subhg in subhg_list:
600
+ if _is_leaf(clique_tree.root_hg, each_subhg):
601
+ if max_leaf_subhg is None:
602
+ max_leaf_subhg = each_subhg
603
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
604
+ max_leaf_subhg = each_subhg
605
+ clique_tree.update(max_leaf_subhg)
606
+ subhg_list.remove(max_leaf_subhg)
607
+ return clique_tree, True
608
+
609
+ org_hg = hg.copy()
610
+ clique_tree = CliqueTree(org_hg)
611
+ clique_tree.add_node(0, subhg=org_hg)
612
+
613
+ success = True
614
+ while success:
615
+ '''
616
+ clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
617
+ if not success:
618
+ clique_tree, success = _contract_cycles(clique_tree)
619
+ '''
620
+ clique_tree, success = _contract_tree(clique_tree)
621
+ if not success:
622
+ if rip_labels:
623
+ clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
624
+ if not success:
625
+ if shrink_cycle:
626
+ clique_tree, success = _shrink_cycle(clique_tree)
627
+ if not success:
628
+ if contract_cycles:
629
+ clique_tree, success = _contract_cycles(clique_tree)
630
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
631
+ if irredundant:
632
+ clique_tree.to_irredundant()
633
+ return clique_tree
634
+
635
+ def molecular_tree_decomposition(hg, irredundant=True):
636
+ """ compute a tree decomposition of the input molecular hypergraph
637
+
638
+ Parameters
639
+ ----------
640
+ hg : Hypergraph
641
+ molecular hypergraph to be decomposed
642
+ irredundant : bool
643
+ if True, irredundant tree decomposition will be computed.
644
+
645
+ Returns
646
+ -------
647
+ clique_tree : CliqueTree
648
+ each node contains a subhypergraph of `hg`
649
+ """
650
+ def _divide_hg(hg):
651
+ ''' divide a hypergraph into subhypergraphs such that
652
+ each subhypergraph is connected to each other in a tree-like way.
653
+
654
+ Parameters
655
+ ----------
656
+ hg : Hypergraph
657
+
658
+ Returns
659
+ -------
660
+ list of Hypergraphs
661
+ each element corresponds to a subhypergraph of `hg`
662
+ '''
663
+ is_ring = False
664
+ for each_node in hg.nodes:
665
+ if hg.node_attr(each_node)['is_in_ring']:
666
+ is_ring = True
667
+ if not hg.node_attr(each_node)['is_in_ring'] \
668
+ and hg.degree(each_node) == 2:
669
+ subhg1, subhg2 = hg.divide(each_node)
670
+ return _divide_hg(subhg1) + _divide_hg(subhg2)
671
+
672
+ if is_ring:
673
+ subhg_list = []
674
+ remove_edge_list = []
675
+ remove_node_list = []
676
+ for each_edge in hg.edges:
677
+ node_list = hg.nodes_in_edge(each_edge)
678
+ subhg = hg.get_subhg(node_list, [each_edge], hg.get_identical_node_dict())
679
+ subhg_list.append(subhg)
680
+ remove_edge_list.append(each_edge)
681
+ for each_node in node_list:
682
+ if not hg.node_attr(each_node)['is_in_ring']:
683
+ remove_node_list.append(each_node)
684
+ hg.remove_edges(remove_edge_list)
685
+ hg.remove_nodes(remove_node_list, False)
686
+ return subhg_list + [hg]
687
+ else:
688
+ return [hg]
689
+
690
+ org_hg = hg.copy()
691
+ clique_tree = CliqueTree(org_hg)
692
+ clique_tree.add_node(0, subhg=org_hg)
693
+
694
+ subhg_list = _divide_hg(deepcopy(clique_tree.root_hg))
695
+ #_subhg_list = deepcopy(subhg_list)
696
+ if len(subhg_list) == 1:
697
+ pass
698
+ else:
699
+ while len(subhg_list) > 1:
700
+ max_leaf_subhg = None
701
+ for each_subhg in subhg_list:
702
+ if _is_leaf(clique_tree.root_hg, each_subhg) and not _is_ring(each_subhg):
703
+ if max_leaf_subhg is None:
704
+ max_leaf_subhg = each_subhg
705
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
706
+ max_leaf_subhg = each_subhg
707
+
708
+ if max_leaf_subhg is None:
709
+ for each_subhg in subhg_list:
710
+ if _is_ring_label(clique_tree.root_hg, each_subhg):
711
+ if max_leaf_subhg is None:
712
+ max_leaf_subhg = each_subhg
713
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
714
+ max_leaf_subhg = each_subhg
715
+ if max_leaf_subhg is not None:
716
+ clique_tree.update(max_leaf_subhg)
717
+ subhg_list.remove(max_leaf_subhg)
718
+ else:
719
+ for each_subhg in subhg_list:
720
+ if _is_leaf(clique_tree.root_hg, each_subhg):
721
+ if max_leaf_subhg is None:
722
+ max_leaf_subhg = each_subhg
723
+ elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
724
+ max_leaf_subhg = each_subhg
725
+ if max_leaf_subhg is not None:
726
+ clique_tree.update(max_leaf_subhg, True)
727
+ subhg_list.remove(max_leaf_subhg)
728
+ else:
729
+ break
730
+ if len(subhg_list) > 1:
731
+ '''
732
+ for each_idx, each_subhg in enumerate(subhg_list):
733
+ each_subhg.draw(f'{each_idx}', True)
734
+ clique_tree.root_hg.draw('root', True)
735
+ import pickle
736
+ with open('buggy_hg.pkl', 'wb') as f:
737
+ pickle.dump(hg, f)
738
+ return clique_tree, subhg_list, _subhg_list
739
+ '''
740
+ raise RuntimeError('bug in tree decomposition algorithm')
741
+ clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
742
+
743
+ '''
744
+ for each_tree_node in clique_tree.adj[0]:
745
+ subhg = clique_tree.nodes[each_tree_node]['subhg']
746
+ for each_edge in subhg.edges:
747
+ if set(subhg.nodes_in_edge(each_edge)).issubset(clique_tree.root_hg.nodes):
748
+ clique_tree.root_hg.add_edge(set(subhg.nodes_in_edge(each_edge)), attr_dict=dict(tmp=True))
749
+ '''
750
+ if irredundant:
751
+ clique_tree.to_irredundant()
752
+ return clique_tree #, _subhg_list
753
+
754
+ def _is_leaf(hg, subhg) -> bool:
755
+ ''' judge whether subhg is a leaf-like in the original hypergraph
756
+
757
+ Parameters
758
+ ----------
759
+ hg : Hypergraph
760
+ subhg : Hypergraph
761
+ `subhg` is a subhypergraph of `hg`
762
+
763
+ Returns
764
+ -------
765
+ bool
766
+ '''
767
+ if len(subhg.edges) == 0:
768
+ adj_edge_set = set([])
769
+ subhg_edge_set = set([])
770
+ for each_edge in hg.edges:
771
+ if set(hg.nodes_in_edge(each_edge)).issubset(subhg.nodes) and hg.edge_attr(each_edge).get('tmp', False):
772
+ subhg_edge_set.add(each_edge)
773
+ for each_node in subhg.nodes:
774
+ adj_edge_set.update(set(hg.adj_edges(each_node)))
775
+ if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
776
+ return True
777
+ else:
778
+ return False
779
+ elif len(subhg.edges) == 1:
780
+ adj_edge_set = set([])
781
+ subhg_edge_set = subhg.edges
782
+ for each_node in subhg.nodes:
783
+ for each_adj_edge in hg.adj_edges(each_node):
784
+ adj_edge_set.add(each_adj_edge)
785
+ if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
786
+ return True
787
+ else:
788
+ return False
789
+ else:
790
+ raise ValueError('subhg should be nodes only or one-edge hypergraph.')
791
+
792
+ def _is_ring_label(hg, subhg):
793
+ if len(subhg.edges) != 1:
794
+ return False
795
+ edge_name = list(subhg.edges)[0]
796
+ #assert edge_name in hg.edges, f'{edge_name}'
797
+ is_in_ring = False
798
+ for each_node in subhg.nodes:
799
+ if subhg.node_attr(each_node)['is_in_ring']:
800
+ is_in_ring = True
801
+ else:
802
+ adj_edge_list = list(hg.adj_edges(each_node))
803
+ adj_edge_list.remove(edge_name)
804
+ if len(adj_edge_list) == 1:
805
+ if not hg.edge_attr(adj_edge_list[0]).get('tmp', False):
806
+ return False
807
+ elif len(adj_edge_list) == 0:
808
+ pass
809
+ else:
810
+ raise ValueError
811
+ if is_in_ring:
812
+ return True
813
+ else:
814
+ return False
815
+
816
+ def _is_ring(hg):
817
+ for each_node in hg.nodes:
818
+ if not hg.node_attr(each_node)['is_in_ring']:
819
+ return False
820
+ return True
821
+
models/mhg_model/graph_grammar/graph_grammar/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 1 2018"
20
+
models/mhg_model/graph_grammar/graph_grammar/base.py ADDED
@@ -0,0 +1,30 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2017"
18
+ __version__ = "0.1"
19
+ __date__ = "Dec 11 2017"
20
+
21
+ from abc import ABCMeta, abstractmethod
22
+
23
+ class GraphGrammarBase(metaclass=ABCMeta):
24
+ @abstractmethod
25
+ def learn(self):
26
+ pass
27
+
28
+ @abstractmethod
29
+ def sample(self):
30
+ pass
models/mhg_model/graph_grammar/graph_grammar/corpus.py ADDED
@@ -0,0 +1,152 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jun 4 2018"
20
+
21
+ from collections import Counter
22
+ from functools import partial
23
+ from .utils import _easy_node_match, _edge_match, _node_match, common_node_list, _node_match_prod_rule
24
+ from networkx.algorithms.isomorphism import GraphMatcher
25
+ import os
26
+
27
+
28
+ class CliqueTreeCorpus(object):
29
+
30
+ ''' clique tree corpus
31
+
32
+ Attributes
33
+ ----------
34
+ clique_tree_list : list of CliqueTree
35
+ subhg_list : list of Hypergraph
36
+ '''
37
+
38
+ def __init__(self):
39
+ self.clique_tree_list = []
40
+ self.subhg_list = []
41
+
42
+ @property
43
+ def size(self):
44
+ return len(self.subhg_list)
45
+
46
+ def add_clique_tree(self, clique_tree):
47
+ for each_node in clique_tree.nodes:
48
+ subhg = clique_tree.nodes[each_node]['subhg']
49
+ subhg_idx, _ = self.add_subhg(subhg)  # add_subhg returns (index, is_new)
50
+ clique_tree.nodes[each_node]['subhg_idx'] = subhg_idx
51
+ self.clique_tree_list.append(clique_tree)
52
+
53
+ def add_to_subhg_list(self, clique_tree, root_node):
54
+ parent_node_dict = {}
55
+ current_node = None
56
+ parent_node_dict[root_node] = None
57
+ stack = [root_node]
58
+ while stack:
59
+ current_node = stack.pop()
60
+ current_subhg = clique_tree.nodes[current_node]['subhg']
61
+ for each_child in clique_tree.adj[current_node]:
62
+ if each_child != parent_node_dict[current_node]:
63
+ stack.append(each_child)
64
+ parent_node_dict[each_child] = current_node
65
+ if parent_node_dict[current_node] is not None:
66
+ parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg']
67
+ common, _ = common_node_list(parent_subhg, current_subhg)
68
+ parent_subhg.add_edge(set(common), attr_dict={'tmp': True})
69
+
70
+ parent_node_dict = {}
71
+ current_node = None
72
+ parent_node_dict[root_node] = None
73
+ stack = [root_node]
74
+ while stack:
75
+ current_node = stack.pop()
76
+ current_subhg = clique_tree.nodes[current_node]['subhg']
77
+ for each_child in clique_tree.adj[current_node]:
78
+ if each_child != parent_node_dict[current_node]:
79
+ stack.append(each_child)
80
+ parent_node_dict[each_child] = current_node
81
+ if parent_node_dict[current_node] is not None:
82
+ parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg']
83
+ common, _ = common_node_list(parent_subhg, current_subhg)
84
+ for each_idx, each_node in enumerate(common):
85
+ current_subhg.set_node_attr(each_node, {'ext_id': each_idx})
86
+
87
+ subhg_idx, is_new = self.add_subhg(current_subhg)
88
+ clique_tree.nodes[current_node]['subhg_idx'] = subhg_idx
89
+ return clique_tree
90
+
91
+ def add_subhg(self, subhg):
92
+ if len(self.subhg_list) == 0:
93
+ node_dict = {}
94
+ for each_node in subhg.nodes:
95
+ node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__()
96
+ node_list = []
97
+ for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
98
+ node_list.append(each_key)
99
+ for each_idx, each_node in enumerate(node_list):
100
+ subhg.node_attr(each_node)['order4hrg'] = each_idx
101
+ self.subhg_list.append(subhg)
102
+ return 0, True
103
+ else:
104
+ match = False
105
+ subhg_bond_symbol_counter \
106
+ = Counter([subhg.node_attr(each_node)['symbol'] \
107
+ for each_node in subhg.nodes])
108
+ subhg_atom_symbol_counter \
109
+ = Counter([subhg.edge_attr(each_edge).get('symbol', None) \
110
+ for each_edge in subhg.edges])
111
+ for each_idx, each_subhg in enumerate(self.subhg_list):
112
+ each_bond_symbol_counter \
113
+ = Counter([each_subhg.node_attr(each_node)['symbol'] \
114
+ for each_node in each_subhg.nodes])
115
+ each_atom_symbol_counter \
116
+ = Counter([each_subhg.edge_attr(each_edge).get('symbol', None) \
117
+ for each_edge in each_subhg.edges])
118
+ if not match \
119
+ and (subhg.num_nodes == each_subhg.num_nodes
120
+ and subhg.num_edges == each_subhg.num_edges
121
+ and subhg_bond_symbol_counter == each_bond_symbol_counter
122
+ and subhg_atom_symbol_counter == each_atom_symbol_counter):
123
+ gm = GraphMatcher(each_subhg.hg,
124
+ subhg.hg,
125
+ node_match=_easy_node_match,
126
+ edge_match=_edge_match)
127
+ try:
128
+ isomap = next(gm.isomorphisms_iter())
129
+ match = True
130
+ for each_node in each_subhg.nodes:
131
+ subhg.node_attr(isomap[each_node])['order4hrg'] \
132
+ = each_subhg.node_attr(each_node)['order4hrg']
133
+ if 'ext_id' in each_subhg.node_attr(each_node):
134
+ subhg.node_attr(isomap[each_node])['ext_id'] \
135
+ = each_subhg.node_attr(each_node)['ext_id']
136
+ return each_idx, False
137
+ except StopIteration:
138
+ match = False
139
+ if not match:
140
+ node_dict = {}
141
+ for each_node in subhg.nodes:
142
+ node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__()
143
+ node_list = []
144
+ for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
145
+ node_list.append(each_key)
146
+ for each_idx, each_node in enumerate(node_list):
147
+ subhg.node_attr(each_node)['order4hrg'] = each_idx
148
+
149
+ #for each_idx, each_node in enumerate(subhg.nodes):
150
+ # subhg.node_attr(each_node)['order4hrg'] = each_idx
151
+ self.subhg_list.append(subhg)
152
+ return len(self.subhg_list) - 1, True
models/mhg_model/graph_grammar/graph_grammar/hrg.py ADDED
@@ -0,0 +1,1065 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2017"
18
+ __version__ = "0.1"
19
+ __date__ = "Dec 11 2017"
20
+
21
+ from .corpus import CliqueTreeCorpus
22
+ from .base import GraphGrammarBase
23
+ from .symbols import TSymbol, NTSymbol, BondSymbol
24
+ from .utils import _node_match, _node_match_prod_rule, _edge_match, masked_softmax, common_node_list
25
+ from ..hypergraph import Hypergraph
26
+ from collections import Counter
27
+ from copy import deepcopy
28
+ from ..algo.tree_decomposition import (
29
+ tree_decomposition,
30
+ tree_decomposition_with_hrg,
31
+ tree_decomposition_from_leaf,
32
+ topological_tree_decomposition,
33
+ molecular_tree_decomposition)
34
+ from functools import partial
35
+ from networkx.algorithms.isomorphism import GraphMatcher
36
+ from typing import List, Dict, Tuple
37
+ import networkx as nx
38
+ import numpy as np
39
+ import torch
40
+ import os
41
+ import random
42
+
43
+ DEBUG = False
44
+
45
+
46
+ class ProductionRule(object):
47
+ """ A class of a production rule
48
+
49
+ Attributes
50
+ ----------
51
+ lhs : Hypergraph or None
52
+ the left hand side of the production rule.
53
+ if None, the rule is a starting rule.
54
+ rhs : Hypergraph
55
+ the right hand side of the production rule.
56
+ """
57
+ def __init__(self, lhs, rhs):
58
+ self.lhs = lhs
59
+ self.rhs = rhs
60
+
61
+ @property
62
+ def is_start_rule(self) -> bool:
63
+ return self.lhs.num_nodes == 0
64
+
65
+ @property
66
+ def ext_node(self) -> Dict[int, str]:
67
+ """ return a dict of external nodes
68
+ """
69
+ if self.is_start_rule:
70
+ return {}
71
+ else:
72
+ ext_node_dict = {}
73
+ for each_node in self.lhs.nodes:
74
+ ext_node_dict[self.lhs.node_attr(each_node)["ext_id"]] = each_node
75
+ return ext_node_dict
76
+
77
+ @property
78
+ def lhs_nt_symbol(self) -> NTSymbol:
79
+ if self.is_start_rule:
80
+ return NTSymbol(degree=0, is_aromatic=False, bond_symbol_list=[])
81
+ else:
82
+ return self.lhs.edge_attr(list(self.lhs.edges)[0])['symbol']
83
+
84
+ def rhs_adj_mat(self, node_edge_list):
85
+ ''' return the adjacency matrix of rhs of the production rule
86
+ '''
87
+ return nx.adjacency_matrix(self.rhs.hg, node_edge_list)
88
+
89
+ def draw(self, file_path=None):
90
+ return self.rhs.draw(file_path)
91
+
92
+ def is_same(self, prod_rule, ignore_order=False):
93
+ """ judge whether this production rule is
94
+ the same as the input one, `prod_rule`
95
+
96
+ Parameters
97
+ ----------
98
+ prod_rule : ProductionRule
99
+ production rule to be compared
100
+
101
+ Returns
102
+ -------
103
+ is_same : bool
104
+ isomap : dict
105
+ isomorphism of nodes and hyperedges.
106
+ ex) {'bond_42': 'bond_37', 'bond_2': 'bond_1',
107
+ 'e36': 'e11', 'e16': 'e12', 'e25': 'e18',
108
+ 'bond_40': 'bond_38', 'e26': 'e21', 'bond_41': 'bond_39'}.
109
+ key comes from `prod_rule`, value comes from `self`.
110
+ """
111
+ if self.is_start_rule:
112
+ if not prod_rule.is_start_rule:
113
+ return False, {}
114
+ else:
115
+ if prod_rule.is_start_rule:
116
+ return False, {}
117
+ else:
118
+ if prod_rule.lhs.num_nodes != self.lhs.num_nodes:
119
+ return False, {}
120
+
121
+ if prod_rule.rhs.num_nodes != self.rhs.num_nodes:
122
+ return False, {}
123
+ if prod_rule.rhs.num_edges != self.rhs.num_edges:
124
+ return False, {}
125
+
126
+ subhg_bond_symbol_counter \
127
+ = Counter([prod_rule.rhs.node_attr(each_node)['symbol'] \
128
+ for each_node in prod_rule.rhs.nodes])
129
+ each_bond_symbol_counter \
130
+ = Counter([self.rhs.node_attr(each_node)['symbol'] \
131
+ for each_node in self.rhs.nodes])
132
+ if subhg_bond_symbol_counter != each_bond_symbol_counter:
133
+ return False, {}
134
+
135
+ subhg_atom_symbol_counter \
136
+ = Counter([prod_rule.rhs.edge_attr(each_edge)['symbol'] \
137
+ for each_edge in prod_rule.rhs.edges])
138
+ each_atom_symbol_counter \
139
+ = Counter([self.rhs.edge_attr(each_edge)['symbol'] \
140
+ for each_edge in self.rhs.edges])
141
+ if subhg_atom_symbol_counter != each_atom_symbol_counter:
142
+ return False, {}
143
+
144
+ gm = GraphMatcher(prod_rule.rhs.hg,
145
+ self.rhs.hg,
146
+ partial(_node_match_prod_rule,
147
+ ignore_order=ignore_order),
148
+ partial(_edge_match,
149
+ ignore_order=ignore_order))
150
+ try:
151
+ return True, next(gm.isomorphisms_iter())
152
+ except StopIteration:
153
+ return False, {}
154
+
155
+ def applied_to(self,
156
+ hg: Hypergraph,
157
+ edge: str) -> Tuple[Hypergraph, List[str]]:
158
+ """ augment `hg` by replacing `edge` with `self.rhs`.
159
+
160
+ Parameters
161
+ ----------
162
+ hg : Hypergraph
163
+ edge : str
164
+ `edge` must belong to `hg`
165
+
166
+ Returns
167
+ -------
168
+ hg : Hypergraph
169
+ resultant hypergraph
170
+ nt_edge_list : list
171
+ list of non-terminal edges
172
+ """
173
+ nt_edge_dict = {}
174
+ if self.is_start_rule:
175
+ if (edge is not None) or (hg is not None):
176
+ raise ValueError("edge and hg must be None for this prod rule.")
177
+ hg = Hypergraph()
178
+ node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented.
179
+ for num_idx, each_node in enumerate(self.rhs.nodes):
180
+ hg.add_node(f"bond_{num_idx}",
181
+ #attr_dict=deepcopy(self.rhs.node_attr(each_node)))
182
+ attr_dict=self.rhs.node_attr(each_node))
183
+ node_map_rhs[each_node] = f"bond_{num_idx}"
184
+ for each_edge in self.rhs.edges:
185
+ node_list = []
186
+ for each_node in self.rhs.nodes_in_edge(each_edge):
187
+ node_list.append(node_map_rhs[each_node])
188
+ if isinstance(self.rhs.nodes_in_edge(each_edge), set):
189
+ node_list = set(node_list)
190
+ edge_id = hg.add_edge(
191
+ node_list,
192
+ #attr_dict=deepcopy(self.rhs.edge_attr(each_edge)))
193
+ attr_dict=self.rhs.edge_attr(each_edge))
194
+ if "nt_idx" in hg.edge_attr(edge_id):
195
+ nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
196
+ nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
197
+ return hg, nt_edge_list
198
+ else:
199
+ if edge not in hg.edges:
200
+ raise ValueError("the input hyperedge does not exist.")
201
+ if hg.edge_attr(edge)["terminal"]:
202
+ raise ValueError("the input hyperedge is terminal.")
203
+ if hg.edge_attr(edge)['symbol'] != self.lhs_nt_symbol:
204
+ print(hg.edge_attr(edge)['symbol'], self.lhs_nt_symbol)
205
+ raise ValueError("the input hyperedge's symbol does not match the lhs non-terminal symbol.")
206
+ if DEBUG:
207
+ for node_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
208
+ other_node = self.lhs.nodes_in_edge(list(self.lhs.edges)[0])[node_idx]
209
+ attr = deepcopy(self.lhs.node_attr(other_node))
210
+ attr.pop('ext_id')
211
+ if hg.node_attr(each_node) != attr:
212
+ raise ValueError('node attributes are inconsistent.')
213
+
214
+ # order of nodes that belong to the non-terminal edge in hg
215
+ nt_order_dict = {} # hg_node -> order ("bond_17" : 1)
216
+ nt_order_dict_inv = {} # order -> hg_node
217
+ for each_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
218
+ nt_order_dict[each_node] = each_idx
219
+ nt_order_dict_inv[each_idx] = each_node
220
+
221
+ # construct a node_map_rhs: rhs -> new hg
222
+ node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented.
223
+ node_idx = hg.num_nodes
224
+ for each_node in self.rhs.nodes:
225
+ if "ext_id" in self.rhs.node_attr(each_node):
226
+ node_map_rhs[each_node] \
227
+ = nt_order_dict_inv[
228
+ self.rhs.node_attr(each_node)["ext_id"]]
229
+ else:
230
+ node_map_rhs[each_node] = f"bond_{node_idx}"
231
+ node_idx += 1
232
+
233
+ # delete non-terminal
234
+ hg.remove_edge(edge)
235
+
236
+ # add nodes to hg
237
+ for each_node in self.rhs.nodes:
238
+ hg.add_node(node_map_rhs[each_node],
239
+ attr_dict=self.rhs.node_attr(each_node))
240
+
241
+ # add hyperedges to hg
242
+ for each_edge in self.rhs.edges:
243
+ node_list_hg = []
244
+ for each_node in self.rhs.nodes_in_edge(each_edge):
245
+ node_list_hg.append(node_map_rhs[each_node])
246
+ edge_id = hg.add_edge(
247
+ node_list_hg,
248
+ attr_dict=self.rhs.edge_attr(each_edge))#deepcopy(self.rhs.edge_attr(each_edge)))
249
+ if "nt_idx" in hg.edge_attr(edge_id):
250
+ nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
251
+ nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
252
+ return hg, nt_edge_list
253
+
254
+ def revert(self, hg: Hypergraph, return_subhg=False):
255
+ ''' revert applying this production rule.
256
+ i.e., if there exists a subhypergraph that matches the r.h.s. of this production rule,
257
+ this method replaces the subhypergraph with a non-terminal hyperedge.
258
+
259
+ Parameters
260
+ ----------
261
+ hg : Hypergraph
262
+ hypergraph to be reverted
263
+ return_subhg : bool
264
+ if True, the removed subhypergraph will be returned.
265
+
266
+ Returns
267
+ -------
268
+ hg : Hypergraph
269
+ the resultant hypergraph. if it cannot be reverted, the original one is returned without any replacement.
270
+ success : bool
271
+ this indicates whether reverting succeeded or not.
272
+ '''
273
+ gm = GraphMatcher(hg.hg, self.rhs.hg, node_match=_node_match_prod_rule,
274
+ edge_match=_edge_match)
275
+ try:
276
+ # in case when the matched subhg is connected to the other part via external nodes and more.
277
+ not_iso = True
278
+ while not_iso:
279
+ isomap = next(gm.subgraph_isomorphisms_iter())
280
+ adj_node_set = set([]) # reachable nodes from the internal nodes
281
+ subhg_node_set = set(isomap.keys()) # nodes in subhg
282
+ for each_node in subhg_node_set:
283
+ adj_node_set.add(each_node)
284
+ if isomap[each_node] not in self.ext_node.values():
285
+ adj_node_set.update(hg.hg.adj[each_node])
286
+ if adj_node_set == subhg_node_set:
287
+ not_iso = False
288
+ else:
289
+ if return_subhg:
290
+ return hg, False, Hypergraph()
291
+ else:
292
+ return hg, False
293
+ inv_isomap = {v: k for k, v in isomap.items()}
294
+ '''
295
+ isomap = {'e35': 'e8', 'bond_13': 'bond_18', 'bond_14': 'bond_19',
296
+ 'bond_15': 'bond_17', 'e29': 'e23', 'bond_12': 'bond_20'}
297
+ where keys come from `hg` and values come from `self.rhs`
298
+ '''
299
+ except StopIteration:
300
+ if return_subhg:
301
+ return hg, False, Hypergraph()
302
+ else:
303
+ return hg, False
304
+
305
+ if return_subhg:
306
+ subhg = Hypergraph()
307
+ for each_node in hg.nodes:
308
+ if each_node in isomap:
309
+ subhg.add_node(each_node, attr_dict=hg.node_attr(each_node))
310
+ for each_edge in hg.edges:
311
+ if each_edge in isomap:
312
+ subhg.add_edge(hg.nodes_in_edge(each_edge),
313
+ attr_dict=hg.edge_attr(each_edge),
314
+ edge_name=each_edge)
315
+ subhg.edge_idx = hg.edge_idx
316
+
317
+ # remove subhg except for the external nodes
318
+ for each_key, each_val in isomap.items():
319
+ if each_key.startswith('e'):
320
+ hg.remove_edge(each_key)
321
+ for each_key, each_val in isomap.items():
322
+ if each_key.startswith('bond_'):
323
+ if each_val not in self.ext_node.values():
324
+ hg.remove_node(each_key)
325
+
326
+ # add non-terminal hyperedge
327
+ nt_node_list = []
328
+ for each_ext_id in self.ext_node.keys():
329
+ nt_node_list.append(inv_isomap[self.ext_node[each_ext_id]])
330
+
331
+ hg.add_edge(nt_node_list,
332
+ attr_dict=dict(
333
+ terminal=False,
334
+ symbol=self.lhs_nt_symbol))
335
+ if return_subhg:
336
+ return hg, True, subhg
337
+ else:
338
+ return hg, True
339
+
340
+
341
+ class ProductionRuleCorpus(object):
342
+
343
+ '''
344
+ A corpus of production rules.
345
+ This class maintains
346
+ (i) list of unique production rules,
347
+ (ii) list of unique edge symbols (both terminal and non-terminal), and
348
+ (iii) list of unique node symbols.
349
+
350
+ Attributes
351
+ ----------
352
+ prod_rule_list : list
353
+ list of unique production rules
354
+ edge_symbol_list : list
355
+ list of unique symbols (including both terminal and non-terminal)
356
+ node_symbol_list : list
357
+ list of node symbols
358
+ nt_symbol_list : list
359
+ list of unique lhs symbols
360
+ ext_id_list : list
361
+ list of ext_ids
362
+ lhs_in_prod_rule : array
363
+ a binary matrix whose (i, j) entry is 1 iff the j-th production rule has the i-th non-terminal symbol as its lhs
364
+ '''
365
+
366
+ def __init__(self):
367
+ self.prod_rule_list = []
368
+ self.edge_symbol_list = []
369
+ self.edge_symbol_dict = {}
370
+ self.node_symbol_list = []
371
+ self.node_symbol_dict = {}
372
+ self.nt_symbol_list = []
373
+ self.ext_id_list = []
374
+ self._lhs_in_prod_rule = None
375
+ self.lhs_in_prod_rule_row_list = []
376
+ self.lhs_in_prod_rule_col_list = []
377
+
378
+ @property
379
+ def lhs_in_prod_rule(self):
380
+ if self._lhs_in_prod_rule is None:
381
+ self._lhs_in_prod_rule = torch.sparse.FloatTensor(
382
+ torch.LongTensor(list(zip(self.lhs_in_prod_rule_row_list, self.lhs_in_prod_rule_col_list))).t(),
383
+ torch.FloatTensor([1.0]*len(self.lhs_in_prod_rule_col_list)),
384
+ torch.Size([len(self.nt_symbol_list), len(self.prod_rule_list)])
385
+ ).to_dense()
386
+ return self._lhs_in_prod_rule
387
+
388
+ @property
389
+ def num_prod_rule(self):
390
+ ''' return the number of production rules
391
+
392
+ Returns
393
+ -------
394
+ int : the number of unique production rules
395
+ '''
396
+ return len(self.prod_rule_list)
397
+
398
+ @property
399
+ def start_rule_list(self):
400
+ ''' return a list of start rules
401
+
402
+ Returns
403
+ -------
404
+ list : list of start rules
405
+ '''
406
+ start_rule_list = []
407
+ for each_prod_rule in self.prod_rule_list:
408
+ if each_prod_rule.is_start_rule:
409
+ start_rule_list.append(each_prod_rule)
410
+ return start_rule_list
411
+
412
+ @property
413
+ def num_edge_symbol(self):
414
+ return len(self.edge_symbol_list)
415
+
416
+ @property
417
+ def num_node_symbol(self):
418
+ return len(self.node_symbol_list)
419
+
420
+ @property
421
+ def num_ext_id(self):
422
+ return len(self.ext_id_list)
423
+
424
+ def construct_feature_vectors(self):
425
+ ''' this method constructs feature vectors for the production rules collected so far.
426
+ currently, NTSymbol and TSymbol are treated in the same manner.
427
+ '''
428
+ feature_id_dict = {}
429
+ feature_id_dict['TSymbol'] = 0
430
+ feature_id_dict['NTSymbol'] = 1
431
+ feature_id_dict['BondSymbol'] = 2
432
+ for each_edge_symbol in self.edge_symbol_list:
433
+ for each_attr in each_edge_symbol.__dict__.keys():
434
+ each_val = each_edge_symbol.__dict__[each_attr]
435
+ if isinstance(each_val, list):
436
+ each_val = tuple(each_val)
437
+ if (each_attr, each_val) not in feature_id_dict:
438
+ feature_id_dict[(each_attr, each_val)] = len(feature_id_dict)
439
+
440
+ for each_node_symbol in self.node_symbol_list:
441
+ for each_attr in each_node_symbol.__dict__.keys():
442
+ each_val = each_node_symbol.__dict__[each_attr]
443
+ if isinstance(each_val, list):
444
+ each_val = tuple(each_val)
445
+ if (each_attr, each_val) not in feature_id_dict:
446
+ feature_id_dict[(each_attr, each_val)] = len(feature_id_dict)
447
+ for each_ext_id in self.ext_id_list:
448
+ feature_id_dict[('ext_id', each_ext_id)] = len(feature_id_dict)
449
+ dim = len(feature_id_dict)
450
+
451
+ feature_dict = {}
452
+ for each_edge_symbol in self.edge_symbol_list:
453
+ idx_list = []
454
+ idx_list.append(feature_id_dict[each_edge_symbol.__class__.__name__])
455
+ for each_attr in each_edge_symbol.__dict__.keys():
456
+ each_val = each_edge_symbol.__dict__[each_attr]
457
+ if isinstance(each_val, list):
458
+ each_val = tuple(each_val)
459
+ idx_list.append(feature_id_dict[(each_attr, each_val)])
460
+ feature = torch.sparse.LongTensor(
461
+ torch.LongTensor([idx_list]),
462
+ torch.ones(len(idx_list)),
463
+ torch.Size([len(feature_id_dict)])
464
+ )
465
+ feature_dict[each_edge_symbol] = feature
466
+
467
+ for each_node_symbol in self.node_symbol_list:
468
+ idx_list = []
469
+ idx_list.append(feature_id_dict[each_node_symbol.__class__.__name__])
470
+ for each_attr in each_node_symbol.__dict__.keys():
471
+ each_val = each_node_symbol.__dict__[each_attr]
472
+ if isinstance(each_val, list):
473
+ each_val = tuple(each_val)
474
+ idx_list.append(feature_id_dict[(each_attr, each_val)])
475
+ feature = torch.sparse.LongTensor(
476
+ torch.LongTensor([idx_list]),
477
+ torch.ones(len(idx_list)),
478
+ torch.Size([len(feature_id_dict)])
479
+ )
480
+ feature_dict[each_node_symbol] = feature
481
+ for each_ext_id in self.ext_id_list:
482
+ idx_list = [feature_id_dict[('ext_id', each_ext_id)]]
483
+ feature_dict[('ext_id', each_ext_id)] \
484
+ = torch.sparse.LongTensor(
485
+ torch.LongTensor([idx_list]),
486
+ torch.ones(len(idx_list)),
487
+ torch.Size([len(feature_id_dict)])
488
+ )
489
+ return feature_dict, dim
490
+
491
+ def edge_symbol_idx(self, symbol):
492
+ return self.edge_symbol_dict[symbol]
493
+
494
+ def node_symbol_idx(self, symbol):
495
+ return self.node_symbol_dict[symbol]
496
+
497
+ def append(self, prod_rule: ProductionRule) -> Tuple[int, ProductionRule]:
498
+ """ return whether the input production rule is new or not, and its production rule id.
499
+ Production rules are regarded as the same if
500
+ i) there exists a one-to-one mapping of nodes and edges, and
501
+ ii) all the attributes associated with nodes and hyperedges are the same.
502
+
503
+ Parameters
504
+ ----------
505
+ prod_rule : ProductionRule
506
+
507
+ Returns
508
+ -------
509
+ prod_rule_id : int
510
+ production rule index. if new, a new index will be assigned.
511
+ prod_rule : ProductionRule
512
+ """
513
+ num_lhs = len(self.nt_symbol_list)
514
+ for each_idx, each_prod_rule in enumerate(self.prod_rule_list):
515
+ is_same, isomap = prod_rule.is_same(each_prod_rule)
516
+ if is_same:
517
+ # we do not care about edge and node names, but care about the order of non-terminal edges.
518
+ for key, val in isomap.items(): # key : edges & nodes in each_prod_rule.rhs , val : those in prod_rule.rhs
519
+ if key.startswith("bond_"):
520
+ continue
521
+
522
+ # rewrite `nt_idx` in `prod_rule` for further processing
523
+ if "nt_idx" in prod_rule.rhs.edge_attr(val).keys():
524
+ if "nt_idx" not in each_prod_rule.rhs.edge_attr(key).keys():
525
+ raise ValueError
526
+ prod_rule.rhs.set_edge_attr(
527
+ val,
528
+ {'nt_idx': each_prod_rule.rhs.edge_attr(key)["nt_idx"]})
529
+ return each_idx, prod_rule
530
+ self.prod_rule_list.append(prod_rule)
531
+ self._update_edge_symbol_list(prod_rule)
532
+ self._update_node_symbol_list(prod_rule)
533
+ self._update_ext_id_list(prod_rule)
534
+
535
+ lhs_idx = self.nt_symbol_list.index(prod_rule.lhs_nt_symbol)
536
+ self.lhs_in_prod_rule_row_list.append(lhs_idx)
537
+ self.lhs_in_prod_rule_col_list.append(len(self.prod_rule_list)-1)
538
+ self._lhs_in_prod_rule = None
539
+ return len(self.prod_rule_list)-1, prod_rule
540
+
541
+ def get_prod_rule(self, prod_rule_idx: int) -> ProductionRule:
542
+ return self.prod_rule_list[prod_rule_idx]
543
+
544
+ def sample(self, unmasked_logit_array, nt_symbol, deterministic=False):
545
+ ''' sample a production rule whose lhs is `nt_symbol`, following `unmasked_logit_array`.
546
+
547
+ Parameters
548
+ ----------
549
+ unmasked_logit_array : array-like, length `num_prod_rule`
550
+ nt_symbol : NTSymbol
551
+ '''
552
+ if not isinstance(unmasked_logit_array, np.ndarray):
553
+ unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64)
554
+ if deterministic:
555
+ prob = masked_softmax(unmasked_logit_array,
556
+ self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64))
557
+ return self.prod_rule_list[np.argmax(prob)]
558
+ else:
559
+ return np.random.choice(
560
+ self.prod_rule_list, 1,
561
+ p=masked_softmax(unmasked_logit_array,
562
+ self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64)))[0]
563
+
564
+ def masked_logprob(self, unmasked_logit_array, nt_symbol):
565
+ if not isinstance(unmasked_logit_array, np.ndarray):
566
+ unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64)
567
+ prob = masked_softmax(unmasked_logit_array,
568
+ self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64))
569
+ return np.log(prob)
570
+
571
+ def _update_edge_symbol_list(self, prod_rule: ProductionRule):
572
+ ''' update edge symbol list
573
+
574
+ Parameters
575
+ ----------
576
+ prod_rule : ProductionRule
577
+ '''
578
+ if prod_rule.lhs_nt_symbol not in self.nt_symbol_list:
579
+ self.nt_symbol_list.append(prod_rule.lhs_nt_symbol)
580
+
581
+ for each_edge in prod_rule.rhs.edges:
582
+ if prod_rule.rhs.edge_attr(each_edge)['symbol'] not in self.edge_symbol_dict:
583
+ edge_symbol_idx = len(self.edge_symbol_list)
584
+ self.edge_symbol_list.append(prod_rule.rhs.edge_attr(each_edge)['symbol'])
585
+ self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']] = edge_symbol_idx
586
+ else:
587
+ edge_symbol_idx = self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']]
588
+ prod_rule.rhs.edge_attr(each_edge)['symbol_idx'] = edge_symbol_idx
589
+ pass
590
+
591
+ def _update_node_symbol_list(self, prod_rule: ProductionRule):
592
+ ''' update node symbol list
593
+
594
+ Parameters
595
+ ----------
596
+ prod_rule : ProductionRule
597
+ '''
598
+ for each_node in prod_rule.rhs.nodes:
599
+ if prod_rule.rhs.node_attr(each_node)['symbol'] not in self.node_symbol_dict:
600
+ node_symbol_idx = len(self.node_symbol_list)
601
+ self.node_symbol_list.append(prod_rule.rhs.node_attr(each_node)['symbol'])
602
+ self.node_symbol_dict[prod_rule.rhs.node_attr(each_node)['symbol']] = node_symbol_idx
603
+ else:
604
+ node_symbol_idx = self.node_symbol_dict[prod_rule.rhs.node_attr(each_node)['symbol']]
605
+ prod_rule.rhs.node_attr(each_node)['symbol_idx'] = node_symbol_idx
606
+
607
+ def _update_ext_id_list(self, prod_rule: ProductionRule):
608
+ for each_node in prod_rule.rhs.nodes:
609
+ if 'ext_id' in prod_rule.rhs.node_attr(each_node):
610
+ if prod_rule.rhs.node_attr(each_node)['ext_id'] not in self.ext_id_list:
611
+ self.ext_id_list.append(prod_rule.rhs.node_attr(each_node)['ext_id'])
612
+
613
+
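The `lhs_in_prod_rule` property above lazily materialises a dense 0/1 matrix from the `(row, col)` index lists that `append` maintains, and `sample`/`masked_logprob` use one row of that matrix as a mask. The sketch below reproduces the same construction with the newer `torch.sparse_coo_tensor` (equivalent in effect to the legacy `torch.sparse.FloatTensor` call above); all concrete numbers are illustrative, not taken from a real corpus.

import numpy as np
import torch

# illustrative index lists: entry k says "production rule col_list[k] has lhs symbol row_list[k]"
row_list = [0, 0, 1, 2]
col_list = [0, 1, 2, 3]
num_lhs, num_rules = 3, 4

lhs_in_prod_rule = torch.sparse_coo_tensor(
    torch.tensor([row_list, col_list]),
    torch.ones(len(col_list)),
    size=(num_lhs, num_rules),
).to_dense()                      # (i, j) == 1 iff rule j has lhs symbol i

# masked softmax over the rules whose lhs is symbol 0, mirroring ProductionRuleCorpus.sample
logit = np.random.randn(num_rules)
mask = lhs_in_prod_rule[0].numpy().astype(np.float64)
exp_logit = np.exp(logit - logit.max()) * mask
prob = exp_logit / exp_logit.sum()
print(prob)                       # zero probability outside rules 0 and 1
print(np.random.choice(num_rules, p=prob))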
614
+ class HyperedgeReplacementGrammar(GraphGrammarBase):
615
+ """
616
+ Learn a hyperedge replacement grammar from a set of hypergraphs.
617
+
618
+ Attributes
619
+ ----------
620
+ prod_rule_list : list of ProductionRule
621
+ production rules learned from the input hypergraphs
622
+ """
623
+ def __init__(self,
624
+ tree_decomposition=molecular_tree_decomposition,
625
+ ignore_order=False, **kwargs):
626
+ from functools import partial
627
+ self.prod_rule_corpus = ProductionRuleCorpus()
628
+ self.clique_tree_corpus = CliqueTreeCorpus()
629
+ self.ignore_order = ignore_order
630
+ self.tree_decomposition = partial(tree_decomposition, **kwargs)
631
+
632
+ @property
633
+ def num_prod_rule(self):
634
+ ''' return the number of production rules
635
+
636
+ Returns
637
+ -------
638
+ int : the number of unique production rules
639
+ '''
640
+ return self.prod_rule_corpus.num_prod_rule
641
+
642
+ @property
643
+ def start_rule_list(self):
644
+ ''' return a list of start rules
645
+
646
+ Returns
647
+ -------
648
+ list : list of start rules
649
+ '''
650
+ return self.prod_rule_corpus.start_rule_list
651
+
652
+ @property
653
+ def prod_rule_list(self):
654
+ return self.prod_rule_corpus.prod_rule_list
655
+
656
+ def learn(self, hg_list, logger=print, max_mol=np.inf, print_freq=500):
657
+ """ learn from a list of hypergraphs
658
+
659
+ Parameters
660
+ ----------
661
+ hg_list : list of Hypergraph
662
+
663
+ Returns
664
+ -------
665
+ prod_rule_seq_list : list of integers
666
+ each element corresponds to a sequence of production rules to generate each hypergraph.
667
+ """
668
+ prod_rule_seq_list = []
669
+ idx = 0
670
+ for each_idx, each_hg in enumerate(hg_list):
671
+ clique_tree = self.tree_decomposition(each_hg)
672
+
673
+ # get a pair of myself and children
674
+ root_node = _find_root(clique_tree)
675
+ clique_tree = self.clique_tree_corpus.add_to_subhg_list(clique_tree, root_node)
676
+ prod_rule_seq = []
677
+ stack = []
678
+
679
+ children = sorted(list(clique_tree[root_node].keys()))
680
+
681
+ # extract a temporary production rule
682
+ prod_rule = extract_prod_rule(
683
+ None,
684
+ clique_tree.nodes[root_node]["subhg"],
685
+ [clique_tree.nodes[each_child]["subhg"]
686
+ for each_child in children],
687
+ clique_tree.nodes[root_node].get('subhg_idx', None))
688
+
689
+ # update the production rule list
690
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
691
+ children = reorder_children(root_node,
692
+ children,
693
+ prod_rule,
694
+ clique_tree)
695
+ stack.extend([(root_node, each_child) for each_child in children[::-1]])
696
+ prod_rule_seq.append(prod_rule_id)
697
+
698
+ while len(stack) != 0:
699
+ # get a triple of parent, myself, and children
700
+ parent, myself = stack.pop()
701
+ children = sorted(list(dict(clique_tree[myself]).keys()))
702
+ children.remove(parent)
703
+
704
+ # extract a temp prod rule
705
+ prod_rule = extract_prod_rule(
706
+ clique_tree.nodes[parent]["subhg"],
707
+ clique_tree.nodes[myself]["subhg"],
708
+ [clique_tree.nodes[each_child]["subhg"]
709
+ for each_child in children],
710
+ clique_tree.nodes[myself].get('subhg_idx', None))
711
+
712
+ # update the prod rule list
713
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
714
+ children = reorder_children(myself,
715
+ children,
716
+ prod_rule,
717
+ clique_tree)
718
+ stack.extend([(myself, each_child)
719
+ for each_child in children[::-1]])
720
+ prod_rule_seq.append(prod_rule_id)
721
+ prod_rule_seq_list.append(prod_rule_seq)
722
+ if (each_idx+1) % print_freq == 0:
723
+ msg = f'#(molecules processed)={each_idx+1}\t'\
724
+ f'#(production rules)={self.prod_rule_corpus.num_prod_rule}\t#(subhg in corpus)={self.clique_tree_corpus.size}'
725
+ logger(msg)
726
+ if each_idx > max_mol:
727
+ break
728
+
729
+ print(f'corpus_size = {self.clique_tree_corpus.size}')
730
+ return prod_rule_seq_list
731
+
732
+ def sample(self, z, deterministic=False):
733
+ """ sample a new hypergraph from HRG.
734
+
735
+ Parameters
736
+ ----------
737
+ z : array-like, shape (len, num_prod_rule)
738
+ logit
739
+ deterministic : bool
740
+ if True, deterministic sampling
741
+
742
+ Returns
743
+ -------
744
+ Hypergraph
745
+ """
746
+ seq_idx = 0
747
+ stack = []
748
+ z = z[:, :-1]
749
+ init_prod_rule = self.prod_rule_corpus.sample(z[0], NTSymbol(degree=0,
750
+ is_aromatic=False,
751
+ bond_symbol_list=[]),
752
+ deterministic=deterministic)
753
+ hg, nt_edge_list = init_prod_rule.applied_to(None, None)
754
+ stack = deepcopy(nt_edge_list[::-1])
755
+ while len(stack) != 0 and seq_idx < z.shape[0]-1:
756
+ seq_idx += 1
757
+ nt_edge = stack.pop()
758
+ nt_symbol = hg.edge_attr(nt_edge)['symbol']
759
+ prod_rule = self.prod_rule_corpus.sample(z[seq_idx], nt_symbol, deterministic=deterministic)
760
+ hg, nt_edge_list = prod_rule.applied_to(hg, nt_edge)
761
+ stack.extend(nt_edge_list[::-1])
762
+ if len(stack) != 0:
763
+ raise RuntimeError(f'{len(stack)} non-terminals are left.')
764
+ return hg
765
+
766
+ def construct(self, prod_rule_seq):
767
+ """ construct a hypergraph following `prod_rule_seq`
768
+
769
+ Parameters
770
+ ----------
771
+ prod_rule_seq : list of integers
772
+ a sequence of production rules.
773
+
774
+ Returns
775
+ -------
776
+ UndirectedHypergraph
777
+ """
778
+ seq_idx = 0
779
+ init_prod_rule = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx])
780
+ hg, nt_edge_list = init_prod_rule.applied_to(None, None)
781
+ stack = deepcopy(nt_edge_list[::-1])
782
+ while len(stack) != 0:
783
+ seq_idx += 1
784
+ nt_edge = stack.pop()
785
+ hg, nt_edge_list = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx]).applied_to(hg, nt_edge)
786
+ stack.extend(nt_edge_list[::-1])
787
+ return hg
788
+
789
+ def update_prod_rule_list(self, prod_rule):
790
+ """ return whether the input production rule is new or not, and its production rule id.
791
+ Production rules are regarded as the same if
792
+ i) there exists a one-to-one mapping of nodes and edges, and
793
+ ii) all the attributes associated with nodes and hyperedges are the same.
794
+
795
+ Parameters
796
+ ----------
797
+ prod_rule : ProductionRule
798
+
799
+ Returns
800
+ -------
801
+ is_new : bool
802
+ if True, this production rule is new
803
+ prod_rule_id : int
804
+ production rule index. if new, a new index will be assigned.
805
+ """
806
+ return self.prod_rule_corpus.append(prod_rule)
807
+
808
+
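Taken together, `learn` and `construct` give a round trip from hypergraphs to production-rule sequences and back. A minimal usage sketch follows; the `hrg` module name, the import layout, and the `molecules.smi` input file are assumptions about how the package is installed and invoked, not something this diff specifies.

from models.mhg_model.graph_grammar.io.smi import HGGen                                    # defined later in this commit
from models.mhg_model.graph_grammar.graph_grammar.hrg import HyperedgeReplacementGrammar   # module name assumed

# hypothetical input file: one SMILES string per line
hg_list = list(HGGen('molecules.smi', kekulize=True, add_Hs=False, all_single=True))

hrg = HyperedgeReplacementGrammar()        # molecular_tree_decomposition by default
prod_rule_seq_list = hrg.learn(hg_list)    # one rule-index sequence per input hypergraph

# rebuild the first hypergraph from its rule sequence
hg_reconstructed = hrg.construct(prod_rule_seq_list[0])
print(hrg.num_prod_rule, hg_reconstructed.num_nodes)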
809
+ class IncrementalHyperedgeReplacementGrammar(HyperedgeReplacementGrammar):
810
+ '''
811
+ This class learns HRG incrementally leveraging the previously obtained production rules.
812
+ '''
813
+ def __init__(self, tree_decomposition=tree_decomposition_with_hrg, ignore_order=False):
814
+ self.prod_rule_list = []
815
+ self.tree_decomposition = tree_decomposition
816
+ self.ignore_order = ignore_order
817
+
818
+ def learn(self, hg_list):
819
+ """ learn from a list of hypergraphs
820
+
821
+ Parameters
822
+ ----------
823
+ hg_list : list of UndirectedHypergraph
824
+
825
+ Returns
826
+ -------
827
+ prod_rule_seq_list : list of integers
828
+ each element corresponds to a sequence of production rules to generate each hypergraph.
829
+ """
830
+ prod_rule_seq_list = []
831
+ for each_hg in hg_list:
832
+ clique_tree, root_node = tree_decomposition_with_hrg(each_hg, self, return_root=True)
833
+
834
+ prod_rule_seq = []
835
+ stack = []
836
+
837
+ # get a pair of myself and children
838
+ children = sorted(list(clique_tree[root_node].keys()))
839
+
840
+ # extract a temporary production rule
841
+ prod_rule = extract_prod_rule(None, clique_tree.nodes[root_node]["subhg"],
842
+ [clique_tree.nodes[each_child]["subhg"] for each_child in children])
843
+
844
+ # update the production rule list
845
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
846
+ children = reorder_children(root_node, children, prod_rule, clique_tree)
847
+ stack.extend([(root_node, each_child) for each_child in children[::-1]])
848
+ prod_rule_seq.append(prod_rule_id)
849
+
850
+ while len(stack) != 0:
851
+ # get a triple of parent, myself, and children
852
+ parent, myself = stack.pop()
853
+ children = sorted(list(dict(clique_tree[myself]).keys()))
854
+ children.remove(parent)
855
+
856
+ # extract a temp prod rule
857
+ prod_rule = extract_prod_rule(
858
+ clique_tree.nodes[parent]["subhg"], clique_tree.nodes[myself]["subhg"],
859
+ [clique_tree.nodes[each_child]["subhg"] for each_child in children])
860
+
861
+ # update the prod rule list
862
+ prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
863
+ children = reorder_children(myself, children, prod_rule, clique_tree)
864
+ stack.extend([(myself, each_child) for each_child in children[::-1]])
865
+ prod_rule_seq.append(prod_rule_id)
866
+ prod_rule_seq_list.append(prod_rule_seq)
867
+ self._compute_stats()
868
+ return prod_rule_seq_list
869
+
870
+
871
+ def reorder_children(myself, children, prod_rule, clique_tree):
872
+ """ reorder children so that they match the order in `prod_rule`.
873
+
874
+ Parameters
875
+ ----------
876
+ myself : int
877
+ children : list of int
878
+ prod_rule : ProductionRule
879
+ clique_tree : nx.Graph
880
+
881
+ Returns
882
+ -------
883
+ new_children : list of str
884
+ reordered children
885
+ """
886
+ perm = {} # key : `nt_idx`, val : child
887
+ for each_edge in prod_rule.rhs.edges:
888
+ if "nt_idx" in prod_rule.rhs.edge_attr(each_edge).keys():
889
+ for each_child in children:
890
+ common_node_set = set(
891
+ common_node_list(clique_tree.nodes[myself]["subhg"],
892
+ clique_tree.nodes[each_child]["subhg"])[0])
893
+ if set(prod_rule.rhs.nodes_in_edge(each_edge)) == common_node_set:
894
+ assert prod_rule.rhs.edge_attr(each_edge)["nt_idx"] not in perm
895
+ perm[prod_rule.rhs.edge_attr(each_edge)["nt_idx"]] = each_child
896
+ new_children = []
897
+ assert len(perm) == len(children)
898
+ for i in range(len(perm)):
899
+ new_children.append(perm[i])
900
+ return new_children
901
+
902
+
903
+ def extract_prod_rule(parent_hg, myself_hg, children_hg_list, subhg_idx=None):
904
+ """ extract a production rule from a triple of `parent_hg`, `myself_hg`, and `children_hg_list`.
905
+
906
+ Parameters
907
+ ----------
908
+ parent_hg : Hypergraph
909
+ myself_hg : Hypergraph
910
+ children_hg_list : list of Hypergraph
911
+
912
+ Returns
913
+ -------
914
+ ProductionRule, consisting of
915
+ lhs : Hypergraph or None
916
+ rhs : Hypergraph
917
+ """
918
+ def _add_ext_node(hg, ext_nodes):
919
+ """ mark nodes to be external (ordered ids are assigned)
920
+
921
+ Parameters
922
+ ----------
923
+ hg : UndirectedHypergraph
924
+ ext_nodes : list of str
925
+ list of external nodes
926
+
927
+ Returns
928
+ -------
929
+ hg : Hypergraph
930
+ nodes in `ext_nodes` are marked to be external
931
+ """
932
+ ext_id = 0
933
+ ext_id_exists = []
934
+ for each_node in ext_nodes:
935
+ ext_id_exists.append('ext_id' in hg.node_attr(each_node))
936
+ if ext_id_exists and any(ext_id_exists) != all(ext_id_exists):
937
+ raise ValueError
938
+ if not all(ext_id_exists):
939
+ for each_node in ext_nodes:
940
+ hg.node_attr(each_node)['ext_id'] = ext_id
941
+ ext_id += 1
942
+ return hg
943
+
944
+ def _check_aromatic(hg, node_list):
945
+ is_aromatic = False
946
+ node_aromatic_list = []
947
+ for each_node in node_list:
948
+ if hg.node_attr(each_node)['symbol'].is_aromatic:
949
+ is_aromatic = True
950
+ node_aromatic_list.append(True)
951
+ else:
952
+ node_aromatic_list.append(False)
953
+ return is_aromatic, node_aromatic_list
954
+
955
+ def _check_ring(hg):
956
+ for each_edge in hg.edges:
957
+ if not ('tmp' in hg.edge_attr(each_edge) or (not hg.edge_attr(each_edge)['terminal'])):
958
+ return False
959
+ return True
960
+
961
+ if parent_hg is None:
962
+ lhs = Hypergraph()
963
+ node_list = []
964
+ else:
965
+ lhs = Hypergraph()
966
+ node_list, edge_exists = common_node_list(parent_hg, myself_hg)
967
+ for each_node in node_list:
968
+ lhs.add_node(each_node,
969
+ deepcopy(myself_hg.node_attr(each_node)))
970
+ is_aromatic, _ = _check_aromatic(parent_hg, node_list)
971
+ for_ring = _check_ring(myself_hg)
972
+ bond_symbol_list = []
973
+ for each_node in node_list:
974
+ bond_symbol_list.append(parent_hg.node_attr(each_node)['symbol'])
975
+ lhs.add_edge(
976
+ node_list,
977
+ attr_dict=dict(
978
+ terminal=False,
979
+ edge_exists=edge_exists,
980
+ symbol=NTSymbol(
981
+ degree=len(node_list),
982
+ is_aromatic=is_aromatic,
983
+ bond_symbol_list=bond_symbol_list,
984
+ for_ring=for_ring)))
985
+ try:
986
+ lhs = _add_ext_node(lhs, node_list)
987
+ except ValueError:
988
+ import pdb; pdb.set_trace()
989
+
990
+ rhs = remove_tmp_edge(deepcopy(myself_hg))
991
+ #rhs = remove_ext_node(rhs)
992
+ #rhs = remove_nt_edge(rhs)
993
+ try:
994
+ rhs = _add_ext_node(rhs, node_list)
995
+ except ValueError:
996
+ import pdb; pdb.set_trace()
997
+
998
+ nt_idx = 0
999
+ if children_hg_list is not None:
1000
+ for each_child_hg in children_hg_list:
1001
+ node_list, edge_exists = common_node_list(myself_hg, each_child_hg)
1002
+ is_aromatic, _ = _check_aromatic(myself_hg, node_list)
1003
+ for_ring = _check_ring(each_child_hg)
1004
+ bond_symbol_list = []
1005
+ for each_node in node_list:
1006
+ bond_symbol_list.append(myself_hg.node_attr(each_node)['symbol'])
1007
+ rhs.add_edge(
1008
+ node_list,
1009
+ attr_dict=dict(
1010
+ terminal=False,
1011
+ nt_idx=nt_idx,
1012
+ edge_exists=edge_exists,
1013
+ symbol=NTSymbol(degree=len(node_list),
1014
+ is_aromatic=is_aromatic,
1015
+ bond_symbol_list=bond_symbol_list,
1016
+ for_ring=for_ring)))
1017
+ nt_idx += 1
1018
+ prod_rule = ProductionRule(lhs, rhs)
1019
+ prod_rule.subhg_idx = subhg_idx
1020
+ if DEBUG:
1021
+ if sorted(list(prod_rule.ext_node.keys())) \
1022
+ != list(np.arange(len(prod_rule.ext_node))):
1023
+ raise RuntimeError('ext_id is not continuous')
1024
+ return prod_rule
1025
+
1026
+
1027
+ def _find_root(clique_tree):
1028
+ max_node = None
1029
+ num_nodes_max = -np.inf
1030
+ for each_node in clique_tree.nodes:
1031
+ if clique_tree.nodes[each_node]['subhg'].num_nodes > num_nodes_max:
1032
+ max_node = each_node
1033
+ num_nodes_max = clique_tree.nodes[each_node]['subhg'].num_nodes
1034
+ '''
1035
+ children = sorted(list(clique_tree[each_node].keys()))
1036
+ prod_rule = extract_prod_rule(None,
1037
+ clique_tree.nodes[each_node]["subhg"],
1038
+ [clique_tree.nodes[each_child]["subhg"]
1039
+ for each_child in children])
1040
+ for each_start_rule in start_rule_list:
1041
+ if prod_rule.is_same(each_start_rule):
1042
+ return each_node
1043
+ '''
1044
+ return max_node
1045
+
1046
+ def remove_ext_node(hg):
1047
+ for each_node in hg.nodes:
1048
+ hg.node_attr(each_node).pop('ext_id', None)
1049
+ return hg
1050
+
1051
+ def remove_nt_edge(hg):
1052
+ remove_edge_list = []
1053
+ for each_edge in hg.edges:
1054
+ if not hg.edge_attr(each_edge)['terminal']:
1055
+ remove_edge_list.append(each_edge)
1056
+ hg.remove_edges(remove_edge_list)
1057
+ return hg
1058
+
1059
+ def remove_tmp_edge(hg):
1060
+ remove_edge_list = []
1061
+ for each_edge in hg.edges:
1062
+ if hg.edge_attr(each_edge).get('tmp', False):
1063
+ remove_edge_list.append(each_edge)
1064
+ hg.remove_edges(remove_edge_list)
1065
+ return hg
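One detail worth spelling out in the file above is the traversal order used by `HyperedgeReplacementGrammar.sample` and `construct`: non-terminal edges are pushed onto the stack in reverse order and popped from the end, which yields a depth-first, left-to-right expansion. The dependency-free toy below (unrelated to the actual classes) just demonstrates that stack discipline.

# toy tree: each key lists its children in left-to-right order
tree = {'root': ['A', 'B'], 'A': ['A1', 'A2'], 'B': [], 'A1': [], 'A2': []}

visit_order = []
stack = ['root']
while stack:
    node = stack.pop()                 # pop from the end, like nt_edge = stack.pop()
    visit_order.append(node)
    stack.extend(tree[node][::-1])     # reversed push, like stack.extend(nt_edge_list[::-1])

print(visit_order)                     # ['root', 'A', 'A1', 'A2', 'B']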
models/mhg_model/graph_grammar/graph_grammar/symbols.py ADDED
@@ -0,0 +1,180 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+
15
+ """ Title """
16
+
17
+ __author__ = "Hiroshi Kajino <[email protected]>"
18
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
19
+ __version__ = "0.1"
20
+ __date__ = "Jan 1 2018"
21
+
22
+ from typing import List
23
+
24
+ class TSymbol(object):
25
+
26
+ ''' terminal symbol
27
+
28
+ Attributes
29
+ ----------
30
+ degree : int
31
+ the number of nodes in a hyperedge
32
+ is_aromatic : bool
33
+ whether or not the hyperedge is in an aromatic ring
34
+ symbol : str
35
+ atomic symbol
36
+ num_explicit_Hs : int
37
+ the number of explicit hydrogens associated with this hyperedge
38
+ formal_charge : int
39
+ charge
40
+ chirality : int
41
+ chirality
42
+ '''
43
+
44
+ def __init__(self, degree, is_aromatic,
45
+ symbol, num_explicit_Hs, formal_charge, chirality):
46
+ self.degree = degree
47
+ self.is_aromatic = is_aromatic
48
+ self.symbol = symbol
49
+ self.num_explicit_Hs = num_explicit_Hs
50
+ self.formal_charge = formal_charge
51
+ self.chirality = chirality
52
+
53
+ @property
54
+ def terminal(self):
55
+ return True
56
+
57
+ def __eq__(self, other):
58
+ if not isinstance(other, TSymbol):
59
+ return False
60
+ if self.degree != other.degree:
61
+ return False
62
+ if self.is_aromatic != other.is_aromatic:
63
+ return False
64
+ if self.symbol != other.symbol:
65
+ return False
66
+ if self.num_explicit_Hs != other.num_explicit_Hs:
67
+ return False
68
+ if self.formal_charge != other.formal_charge:
69
+ return False
70
+ if self.chirality != other.chirality:
71
+ return False
72
+ return True
73
+
74
+ def __hash__(self):
75
+ return self.__str__().__hash__()
76
+
77
+ def __str__(self):
78
+ return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
79
+ f'symbol={self.symbol}, '\
80
+ f'num_explicit_Hs={self.num_explicit_Hs}, '\
81
+ f'formal_charge={self.formal_charge}, chirality={self.chirality}'
82
+
83
+
84
+ class NTSymbol(object):
85
+
86
+ ''' non-terminal symbol
87
+
88
+ Attributes
89
+ ----------
90
+ degree : int
91
+ degree of the hyperedge
92
+ is_aromatic : bool
93
+ if True, at least one of the associated bonds must be aromatic.
94
+ bond_symbol_list : list of BondSymbol
95
+ bond symbol associated with each node of the hyperedge.
96
+ for_ring : bool
97
+ if True, this non-terminal is used for a ring part.
98
+ '''
99
+
100
+ def __init__(self, degree: int, is_aromatic: bool,
101
+ bond_symbol_list: list,
102
+ for_ring=False):
103
+ self.degree = degree
104
+ self.is_aromatic = is_aromatic
105
+ self.for_ring = for_ring
106
+ self.bond_symbol_list = bond_symbol_list
107
+
108
+ @property
109
+ def terminal(self) -> bool:
110
+ return False
111
+
112
+ @property
113
+ def symbol(self):
114
+ return f'NT{self.degree}'
115
+
116
+ def __eq__(self, other) -> bool:
117
+ if not isinstance(other, NTSymbol):
118
+ return False
119
+
120
+ if self.degree != other.degree:
121
+ return False
122
+ if self.is_aromatic != other.is_aromatic:
123
+ return False
124
+ if self.for_ring != other.for_ring:
125
+ return False
126
+ if len(self.bond_symbol_list) != len(other.bond_symbol_list):
127
+ return False
128
+ for each_idx in range(len(self.bond_symbol_list)):
129
+ if self.bond_symbol_list[each_idx] != other.bond_symbol_list[each_idx]:
130
+ return False
131
+ return True
132
+
133
+ def __hash__(self):
134
+ return self.__str__().__hash__()
135
+
136
+ def __str__(self) -> str:
137
+ return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
138
+ f'bond_symbol_list={[str(each_symbol) for each_symbol in self.bond_symbol_list]}, '\
139
+ f'for_ring={self.for_ring}'
140
+
141
+
142
+ class BondSymbol(object):
143
+
144
+
145
+ ''' Bond symbol
146
+
147
+ Attributes
148
+ ----------
149
+ is_aromatic : bool
150
+ if True, this bond is aromatic.
151
+ bond_type : int
152
+ integer code of the bond type.
153
+ '''
154
+
155
+ def __init__(self, is_aromatic: bool,
156
+ bond_type: int,
157
+ stereo: int):
158
+ self.is_aromatic = is_aromatic
159
+ self.bond_type = bond_type
160
+ self.stereo = stereo
161
+
162
+ def __eq__(self, other) -> bool:
163
+ if not isinstance(other, BondSymbol):
164
+ return False
165
+
166
+ if self.is_aromatic != other.is_aromatic:
167
+ return False
168
+ if self.bond_type != other.bond_type:
169
+ return False
170
+ if self.stereo != other.stereo:
171
+ return False
172
+ return True
173
+
174
+ def __hash__(self):
175
+ return self.__str__().__hash__()
176
+
177
+ def __str__(self) -> str:
178
+ return f'is_aromatic={self.is_aromatic}, '\
179
+ f'bond_type={self.bond_type}, '\
180
+ f'stereo={self.stereo}, '
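Because equality and hashing of these symbol classes are defined through their string representations, two symbols with identical attributes compare equal and can be used directly as dictionary keys (which is how `ProductionRuleCorpus` indexes `edge_symbol_dict` and `node_symbol_dict`). A small sketch, assuming the repository root is on `PYTHONPATH` so that the module path from this diff is importable:

from models.mhg_model.graph_grammar.graph_grammar.symbols import TSymbol, NTSymbol, BondSymbol

carbon = TSymbol(degree=0, is_aromatic=False, symbol='C',
                 num_explicit_Hs=0, formal_charge=0, chirality=0)
carbon_again = TSymbol(degree=0, is_aromatic=False, symbol='C',
                       num_explicit_Hs=0, formal_charge=0, chirality=0)
single = BondSymbol(is_aromatic=False, bond_type=1, stereo=0)
nt = NTSymbol(degree=2, is_aromatic=False, bond_symbol_list=[single, single])

assert carbon == carbon_again and hash(carbon) == hash(carbon_again)
assert not nt.terminal and nt.symbol == 'NT2'

# value-based hashing lets symbols index a symbol-to-id table
symbol_ids = {carbon: 0, single: 1, nt: 2}
print(symbol_ids[carbon_again])   # 0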
models/mhg_model/graph_grammar/graph_grammar/utils.py ADDED
@@ -0,0 +1,130 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jun 4 2018"
20
+
21
+ from ..hypergraph import Hypergraph
22
+ from copy import deepcopy
23
+ from typing import List, Tuple
24
+ import numpy as np
25
+
26
+
27
+ def common_node_list(hg1: Hypergraph, hg2: Hypergraph) -> Tuple[List[str], bool]:
28
+ """ return a list of common nodes
29
+
30
+ Parameters
31
+ ----------
32
+ hg1, hg2 : Hypergraph
33
+
34
+ Returns
35
+ -------
36
+ (list of str, bool)
37
+ list of common nodes, and a bool indicating whether those nodes already form a non-terminal hyperedge in `hg1`
38
+ """
39
+ if hg1 is None or hg2 is None:
40
+ return [], False
41
+ else:
42
+ node_set = hg1.nodes.intersection(hg2.nodes)
43
+ node_dict = {}
44
+ if 'order4hrg' in hg1.node_attr(list(hg1.nodes)[0]):
45
+ for each_node in node_set:
46
+ node_dict[each_node] = hg1.node_attr(each_node)['order4hrg']
47
+ else:
48
+ for each_node in node_set:
49
+ node_dict[each_node] = hg1.node_attr(each_node)['symbol'].__hash__()
50
+ node_list = []
51
+ for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
52
+ node_list.append(each_key)
53
+ edge_name = hg1.has_edge(node_list, ignore_order=True)
54
+ if edge_name:
55
+ if not hg1.edge_attr(edge_name).get('terminal', True):
56
+ node_list = hg1.nodes_in_edge(edge_name)
57
+ return node_list, True
58
+ else:
59
+ return node_list, False
60
+
61
+
62
+ def _node_match(node1, node2):
63
+ # if the nodes are hyperedges, the `symbol` attribute determines the match
64
+ if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
65
+ return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
66
+ elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
67
+ # bond_symbol
68
+ return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
69
+ else:
70
+ return False
71
+
72
+ def _easy_node_match(node1, node2):
73
+ # if the nodes are hyperedges, the `symbol` attribute determines the match
74
+ if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
75
+ return node1["attr_dict"].get('symbol', None) == node2["attr_dict"].get('symbol', None)
76
+ elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
77
+ # bond_symbol
78
+ return node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)\
79
+ and node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
80
+ else:
81
+ return False
82
+
83
+
84
+ def _node_match_prod_rule(node1, node2, ignore_order=False):
85
+ # if the nodes are hyperedges, the `symbol` attribute determines the match
86
+ if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
87
+ return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
88
+ elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
89
+ # ext_id, order4hrg, bond_symbol
90
+ if ignore_order:
91
+ return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
92
+ else:
93
+ return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']\
94
+ and node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)
95
+ else:
96
+ return False
97
+
98
+
99
+ def _edge_match(edge1, edge2, ignore_order=False):
100
+ #return True
101
+ if ignore_order:
102
+ return True
103
+ else:
104
+ return edge1["order"] == edge2["order"]
105
+
106
+ def masked_softmax(logit, mask):
107
+ ''' compute a probability distribution from logit
108
+
109
+ Parameters
110
+ ----------
111
+ logit : array-like, length D
112
+ each element indicates how each dimension is likely to be chosen
113
+ (the larger, the more likely)
114
+ mask : array-like, length D
115
+ each element is either 0 or 1.
116
+ if 0, the dimension is ignored
117
+ when computing the probability distribution.
118
+
119
+ Returns
120
+ -------
121
+ prob_dist : array, length D
122
+ probability distribution computed from logit.
123
+ if `mask[d] = 0`, `prob_dist[d] = 0`.
124
+ '''
125
+ if logit.shape != mask.shape:
126
+ raise ValueError('logit and mask must have the same shape')
127
+ c = np.max(logit)
128
+ exp_logit = np.exp(logit - c) * mask
129
+ sum_exp_logit = exp_logit @ mask
130
+ return exp_logit / sum_exp_logit
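As a quick check of the behaviour documented above: dimensions with `mask == 0` receive exactly zero probability and the remaining dimensions are renormalised. The import path follows the file location in this diff and assumes the repository root is on `PYTHONPATH`.

import numpy as np
from models.mhg_model.graph_grammar.graph_grammar.utils import masked_softmax

logit = np.array([2.0, 1.0, 0.0, -1.0])
mask = np.array([1.0, 0.0, 1.0, 0.0])

prob = masked_softmax(logit, mask)
print(prob)                              # approximately [0.881, 0., 0.119, 0.]
assert np.isclose(prob.sum(), 1.0)
assert prob[1] == 0.0 and prob[3] == 0.0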
models/mhg_model/graph_grammar/hypergraph.py ADDED
@@ -0,0 +1,544 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 31 2018"
20
+
21
+ from copy import deepcopy
22
+ from typing import List, Dict, Tuple
23
+ import networkx as nx
24
+ import numpy as np
25
+ import os
26
+
27
+
28
+ class Hypergraph(object):
29
+ '''
30
+ A class of a hypergraph.
31
+ Each hyperedge can be ordered. For the ordered case,
32
+ edges adjacent to the hyperedge node are labeled by their orders.
33
+
34
+ Attributes
35
+ ----------
36
+ hg : nx.Graph
37
+ a bipartite graph representation of a hypergraph
38
+ edge_idx : int
39
+ total number of hyperedges that exist so far
40
+ '''
41
+ def __init__(self):
42
+ self.hg = nx.Graph()
43
+ self.edge_idx = 0
44
+ self.nodes = set([])
45
+ self.num_nodes = 0
46
+ self.edges = set([])
47
+ self.num_edges = 0
48
+ self.nodes_in_edge_dict = {}
49
+
50
+ def add_node(self, node: str, attr_dict=None):
51
+ ''' add a node to hypergraph
52
+
53
+ Parameters
54
+ ----------
55
+ node : str
56
+ node name
57
+ attr_dict : dict
58
+ dictionary of node attributes
59
+ '''
60
+ self.hg.add_node(node, bipartite='node', attr_dict=attr_dict)
61
+ if node not in self.nodes:
62
+ self.num_nodes += 1
63
+ self.nodes.add(node)
64
+
65
+ def add_edge(self, node_list: List[str], attr_dict=None, edge_name=None):
66
+ ''' add an edge consisting of nodes `node_list`
67
+
68
+ Parameters
69
+ ----------
70
+ node_list : list
71
+ ordered list of nodes that constitute the edge
72
+ attr_dict : dict
73
+ dictionary of edge attributes
74
+ '''
75
+ if edge_name is None:
76
+ edge = 'e{}'.format(self.edge_idx)
77
+ else:
78
+ assert edge_name not in self.edges
79
+ edge = edge_name
80
+ self.hg.add_node(edge, bipartite='edge', attr_dict=attr_dict)
81
+ if edge not in self.edges:
82
+ self.num_edges += 1
83
+ self.edges.add(edge)
84
+ self.nodes_in_edge_dict[edge] = node_list
85
+ if type(node_list) == list:
86
+ for node_idx, each_node in enumerate(node_list):
87
+ self.hg.add_edge(edge, each_node, order=node_idx)
88
+ if each_node not in self.nodes:
89
+ self.num_nodes += 1
90
+ self.nodes.add(each_node)
91
+
92
+ elif type(node_list) == set:
93
+ for each_node in node_list:
94
+ self.hg.add_edge(edge, each_node, order=-1)
95
+ if each_node not in self.nodes:
96
+ self.num_nodes += 1
97
+ self.nodes.add(each_node)
98
+ else:
99
+ raise ValueError
100
+ self.edge_idx += 1
101
+ return edge
102
+
103
+ def remove_node(self, node: str, remove_connected_edges=True):
104
+ ''' remove a node
105
+
106
+ Parameters
107
+ ----------
108
+ node : str
109
+ node name
110
+ remove_connected_edges : bool
111
+ if True, remove edges that are adjacent to the node
112
+ '''
113
+ if remove_connected_edges:
114
+ connected_edges = deepcopy(self.adj_edges(node))
115
+ for each_edge in connected_edges:
116
+ self.remove_edge(each_edge)
117
+ self.hg.remove_node(node)
118
+ self.num_nodes -= 1
119
+ self.nodes.remove(node)
120
+
121
+ def remove_nodes(self, node_iter, remove_connected_edges=True):
122
+ ''' remove a set of nodes
123
+
124
+ Parameters
125
+ ----------
126
+ node_iter : iterator of strings
127
+ nodes to be removed
128
+ remove_connected_edges : bool
129
+ if True, remove edges that are adjacent to the node
130
+ '''
131
+ for each_node in node_iter:
132
+ self.remove_node(each_node, remove_connected_edges)
133
+
134
+ def remove_edge(self, edge: str):
135
+ ''' remove an edge
136
+
137
+ Parameters
138
+ ----------
139
+ edge : str
140
+ edge to be removed
141
+ '''
142
+ self.hg.remove_node(edge)
143
+ self.edges.remove(edge)
144
+ self.num_edges -= 1
145
+ self.nodes_in_edge_dict.pop(edge)
146
+
147
+ def remove_edges(self, edge_iter):
148
+ ''' remove a set of edges
149
+
150
+ Parameters
151
+ ----------
152
+ edge_iter : iterator of strings
153
+ edges to be removed
154
+ '''
155
+ for each_edge in edge_iter:
156
+ self.remove_edge(each_edge)
157
+
158
+ def remove_edges_with_attr(self, edge_attr_dict):
159
+ remove_edge_list = []
160
+ for each_edge in self.edges:
161
+ satisfy = True
162
+ for each_key, each_val in edge_attr_dict.items():
163
+ if not satisfy:
164
+ break
165
+ try:
166
+ if self.edge_attr(each_edge)[each_key] != each_val:
167
+ satisfy = False
168
+ except KeyError:
169
+ satisfy = False
170
+ if satisfy:
171
+ remove_edge_list.append(each_edge)
172
+ self.remove_edges(remove_edge_list)
173
+
174
+ def remove_subhg(self, subhg):
175
+ ''' remove subhypergraph.
176
+ all of the hyperedges are removed.
177
+ each node of subhg is removed if its degree becomes 0 after removing hyperedges.
178
+
179
+ Parameters
180
+ ----------
181
+ subhg : Hypergraph
182
+ '''
183
+ for each_edge in subhg.edges:
184
+ self.remove_edge(each_edge)
185
+ for each_node in subhg.nodes:
186
+ if self.degree(each_node) == 0:
187
+ self.remove_node(each_node)
188
+
189
+ def nodes_in_edge(self, edge):
190
+ ''' return an ordered list of nodes in a given edge.
191
+
192
+ Parameters
193
+ ----------
194
+ edge : str
195
+ edge whose nodes are returned
196
+
197
+ Returns
198
+ -------
199
+ list or set
200
+ ordered list or set of nodes that belong to the edge
201
+ '''
202
+ if edge.startswith('e'):
203
+ return self.nodes_in_edge_dict[edge]
204
+ else:
205
+ adj_node_list = self.hg.adj[edge]
206
+ adj_node_order_list = []
207
+ adj_node_name_list = []
208
+ for each_node in adj_node_list:
209
+ adj_node_order_list.append(adj_node_list[each_node]['order'])
210
+ adj_node_name_list.append(each_node)
211
+ if adj_node_order_list == [-1] * len(adj_node_order_list):
212
+ return set(adj_node_name_list)
213
+ else:
214
+ return [adj_node_name_list[each_idx] for each_idx
215
+ in np.argsort(adj_node_order_list)]
216
+
217
+ def adj_edges(self, node):
218
+ ''' return a dict of adjacent hyperedges
219
+
220
+ Parameters
221
+ ----------
222
+ node : str
223
+
224
+ Returns
225
+ -------
226
+ set
227
+ set of edges that are adjacent to `node`
228
+ '''
229
+ return self.hg.adj[node]
230
+
231
+ def adj_nodes(self, node):
232
+ ''' return a set of adjacent nodes
233
+
234
+ Parameters
235
+ ----------
236
+ node : str
237
+
238
+ Returns
239
+ -------
240
+ set
241
+ set of nodes that are adjacent to `node`
242
+ '''
243
+ node_set = set([])
244
+ for each_adj_edge in self.adj_edges(node):
245
+ node_set.update(set(self.nodes_in_edge(each_adj_edge)))
246
+ node_set.discard(node)
247
+ return node_set
248
+
249
+ def has_edge(self, node_list, ignore_order=False):
250
+ for each_edge in self.edges:
251
+ if ignore_order:
252
+ if set(self.nodes_in_edge(each_edge)) == set(node_list):
253
+ return each_edge
254
+ else:
255
+ if self.nodes_in_edge(each_edge) == node_list:
256
+ return each_edge
257
+ return False
258
+
259
+ def degree(self, node):
260
+ return len(self.hg.adj[node])
261
+
262
+ def degrees(self):
263
+ return {each_node: self.degree(each_node) for each_node in self.nodes}
264
+
265
+ def edge_degree(self, edge):
266
+ return len(self.nodes_in_edge(edge))
267
+
268
+ def edge_degrees(self):
269
+ return {each_edge: self.edge_degree(each_edge) for each_edge in self.edges}
270
+
271
+ def is_adj(self, node1, node2):
272
+ return node1 in self.adj_nodes(node2)
273
+
274
+ def adj_subhg(self, node, ident_node_dict=None):
275
+ """ return a subhypergraph consisting of a set of nodes and hyperedges adjacent to `node`.
276
+ if an adjacent node has a self-loop hyperedge, it will be also added to the subhypergraph.
277
+
278
+ Parameters
279
+ ----------
280
+ node : str
281
+ ident_node_dict : dict
282
+ dict containing identical nodes. see `get_identical_node_dict` for more details
283
+
284
+ Returns
285
+ -------
286
+ subhg : Hypergraph
287
+ """
288
+ if ident_node_dict is None:
289
+ ident_node_dict = self.get_identical_node_dict()
290
+ adj_node_set = set(ident_node_dict[node])
291
+ adj_edge_set = set([])
292
+ for each_node in ident_node_dict[node]:
293
+ adj_edge_set.update(set(self.adj_edges(each_node)))
294
+ fixed_adj_edge_set = deepcopy(adj_edge_set)
295
+ for each_edge in fixed_adj_edge_set:
296
+ other_nodes = self.nodes_in_edge(each_edge)
297
+ adj_node_set.update(other_nodes)
298
+
299
+ # if the adjacent node has self-loop edge, it will be appended to adj_edge_list.
300
+ for each_node in other_nodes:
301
+ for other_edge in set(self.adj_edges(each_node)) - set([each_edge]):
302
+ if len(set(self.nodes_in_edge(other_edge)) \
303
+ - set(self.nodes_in_edge(each_edge))) == 0:
304
+ adj_edge_set.update(set([other_edge]))
305
+ subhg = Hypergraph()
306
+ for each_node in adj_node_set:
307
+ subhg.add_node(each_node, attr_dict=self.node_attr(each_node))
308
+ for each_edge in adj_edge_set:
309
+ subhg.add_edge(self.nodes_in_edge(each_edge),
310
+ attr_dict=self.edge_attr(each_edge),
311
+ edge_name=each_edge)
312
+ subhg.edge_idx = self.edge_idx
313
+ return subhg
314
+
315
+ def get_subhg(self, node_list, edge_list, ident_node_dict=None):
316
+ """ return a subhypergraph consisting of a set of nodes and hyperedges adjacent to `node`.
317
+ if an adjacent node has a self-loop hyperedge, it will be also added to the subhypergraph.
318
+
319
+ Parameters
320
+ ----------
321
+ node : str
322
+ ident_node_dict : dict
323
+ dict containing identical nodes. see `get_identical_node_dict` for more details
324
+
325
+ Returns
326
+ -------
327
+ subhg : Hypergraph
328
+ """
329
+ if ident_node_dict is None:
330
+ ident_node_dict = self.get_identical_node_dict()
331
+ adj_node_set = set([])
332
+ for each_node in node_list:
333
+ adj_node_set.update(set(ident_node_dict[each_node]))
334
+ adj_edge_set = set(edge_list)
335
+
336
+ subhg = Hypergraph()
337
+ for each_node in adj_node_set:
338
+ subhg.add_node(each_node,
339
+ attr_dict=deepcopy(self.node_attr(each_node)))
340
+ for each_edge in adj_edge_set:
341
+ subhg.add_edge(self.nodes_in_edge(each_edge),
342
+ attr_dict=deepcopy(self.edge_attr(each_edge)),
343
+ edge_name=each_edge)
344
+ subhg.edge_idx = self.edge_idx
345
+ return subhg
346
+
347
+ def copy(self):
348
+ ''' return a copy of the object
349
+
350
+ Returns
351
+ -------
352
+ Hypergraph
353
+ '''
354
+ return deepcopy(self)
355
+
356
+ def node_attr(self, node):
357
+ return self.hg.nodes[node]['attr_dict']
358
+
359
+ def edge_attr(self, edge):
360
+ return self.hg.nodes[edge]['attr_dict']
361
+
362
+ def set_node_attr(self, node, attr_dict):
363
+ for each_key, each_val in attr_dict.items():
364
+ self.hg.nodes[node]['attr_dict'][each_key] = each_val
365
+
366
+ def set_edge_attr(self, edge, attr_dict):
367
+ for each_key, each_val in attr_dict.items():
368
+ self.hg.nodes[edge]['attr_dict'][each_key] = each_val
369
+
370
+ def get_identical_node_dict(self):
371
+ ''' get identical nodes
372
+ nodes are identical if they share the same set of adjacent edges.
373
+
374
+ Returns
375
+ -------
376
+ ident_node_dict : dict
377
+ ident_node_dict[node] returns a list of nodes that are identical to `node`.
378
+ '''
379
+ ident_node_dict = {}
380
+ for each_node in self.nodes:
381
+ ident_node_list = []
382
+ for each_other_node in self.nodes:
383
+ if each_other_node == each_node:
384
+ ident_node_list.append(each_other_node)
385
+ elif self.adj_edges(each_node) == self.adj_edges(each_other_node) \
386
+ and len(self.adj_edges(each_node)) != 0:
387
+ ident_node_list.append(each_other_node)
388
+ ident_node_dict[each_node] = ident_node_list
389
+ return ident_node_dict
390
+ '''
391
+ ident_node_dict = {}
392
+ for each_node in self.nodes:
393
+ ident_node_dict[each_node] = [each_node]
394
+ return ident_node_dict
395
+ '''
396
+
397
+ def get_leaf_edge(self):
398
+ ''' get a hyperedge that is adjacent to exactly one other hyperedge (a leaf edge)
399
+
400
+ Returns
401
+ -------
402
+ if such an edge exists, return it; otherwise, return None.
403
+ '''
404
+ for each_edge in self.edges:
405
+ if len(self.adj_nodes(each_edge)) == 1:
406
+ if 'tmp' not in self.edge_attr(each_edge):
407
+ return each_edge
408
+ return None
409
+
410
+ def get_nontmp_edge(self):
411
+ for each_edge in self.edges:
412
+ if 'tmp' not in self.edge_attr(each_edge):
413
+ return each_edge
414
+ return None
415
+
416
+ def is_subhg(self, hg):
417
+ ''' return whether this hypergraph is a subhypergraph of `hg`
418
+
419
+ Returns
420
+ -------
421
+ True if self is a subhypergraph of `hg`,
422
+ False otherwise.
423
+ '''
424
+ for each_node in self.nodes:
425
+ if each_node not in hg.nodes:
426
+ return False
427
+ for each_edge in self.edges:
428
+ if each_edge not in hg.edges:
429
+ return False
430
+ return True
431
+
432
+ def in_cycle(self, node, visited=None, parent='', root_node='') -> bool:
433
+ ''' if `node` is in a cycle, then return True. otherwise, False.
434
+
435
+ Parameters
436
+ ----------
437
+ node : str
438
+ node in a hypergraph
439
+ visited : list
440
+ list of visited nodes, used for recursion
441
+ parent : str
442
+ parent node, used to eliminate a cycle consisting of two nodes and one edge.
443
+
444
+ Returns
445
+ -------
446
+ bool
447
+ '''
448
+ if visited is None:
449
+ visited = []
450
+ if parent == '':
451
+ visited = []
452
+ if root_node == '':
453
+ root_node = node
454
+ visited.append(node)
455
+ for each_adj_node in self.adj_nodes(node):
456
+ if each_adj_node not in visited:
457
+ if self.in_cycle(each_adj_node, visited, node, root_node):
458
+ return True
459
+ elif each_adj_node != parent and each_adj_node == root_node:
460
+ return True
461
+ return False
462
+
463
+
464
+ def draw(self, file_path=None, with_node=False, with_edge_name=False):
465
+ ''' draw hypergraph
466
+ '''
467
+ import graphviz
468
+ G = graphviz.Graph(format='png')
469
+ for each_node in self.nodes:
470
+ if 'ext_id' in self.node_attr(each_node):
471
+ G.node(each_node, label='',
472
+ shape='circle', width='0.1', height='0.1', style='filled',
473
+ fillcolor='black')
474
+ else:
475
+ if with_node:
476
+ G.node(each_node, label='',
477
+ shape='circle', width='0.1', height='0.1', style='filled',
478
+ fillcolor='gray')
479
+ edge_list = []
480
+ for each_edge in self.edges:
481
+ if self.edge_attr(each_edge).get('terminal', False):
482
+ G.node(each_edge,
483
+ label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
484
+ else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
485
+ fontcolor='black', shape='square')
486
+ elif self.edge_attr(each_edge).get('tmp', False):
487
+ G.node(each_edge, label='tmp' if not with_edge_name else 'tmp, ' + each_edge,
488
+ fontcolor='black', shape='square')
489
+ else:
490
+ G.node(each_edge,
491
+ label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
492
+ else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
493
+ fontcolor='black', shape='square', style='filled')
494
+ if with_node:
495
+ for each_node in self.nodes_in_edge(each_edge):
496
+ G.edge(each_edge, each_node)
497
+ else:
498
+ for each_node in self.nodes_in_edge(each_edge):
499
+ if 'ext_id' in self.node_attr(each_node)\
500
+ and set([each_node, each_edge]) not in edge_list:
501
+ G.edge(each_edge, each_node)
502
+ edge_list.append(set([each_node, each_edge]))
503
+ for each_other_edge in self.adj_nodes(each_edge):
504
+ if set([each_edge, each_other_edge]) not in edge_list:
505
+ num_bond = 0
506
+ common_node_set = set(self.nodes_in_edge(each_edge))\
507
+ .intersection(set(self.nodes_in_edge(each_other_edge)))
508
+ for each_node in common_node_set:
509
+ if self.node_attr(each_node)['symbol'].bond_type in [1, 2, 3]:
510
+ num_bond += self.node_attr(each_node)['symbol'].bond_type
511
+ elif self.node_attr(each_node)['symbol'].bond_type in [12]:
512
+ num_bond += 1
513
+ else:
514
+ raise NotImplementedError('unsupported bond type')
515
+ for _ in range(num_bond):
516
+ G.edge(each_edge, each_other_edge)
517
+ edge_list.append(set([each_edge, each_other_edge]))
518
+ if file_path is not None:
519
+ G.render(file_path, cleanup=True)
520
+ #os.remove(file_path)
521
+ return G
522
+
523
+ def is_dividable(self, node):
524
+ _hg = deepcopy(self.hg)
525
+ _hg.remove_node(node)
526
+ return (not nx.is_connected(_hg))
527
+
528
+ def divide(self, node):
529
+ subhg_list = []
530
+
531
+ hg_wo_node = deepcopy(self)
532
+ hg_wo_node.remove_node(node, remove_connected_edges=False)
533
+ connected_components = nx.connected_components(hg_wo_node.hg)
534
+ for each_component in connected_components:
535
+ node_list = [node]
536
+ edge_list = []
537
+ node_list.extend([each_node for each_node in each_component
538
+ if each_node.startswith('bond_')])
539
+ edge_list.extend([each_edge for each_edge in each_component
540
+ if each_edge.startswith('e')])
541
+ subhg_list.append(self.get_subhg(node_list, edge_list))
542
+ #subhg_list[-1].set_node_attr(node, {'divided': True})
543
+ return subhg_list
544
+
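A minimal usage sketch of the `Hypergraph` wrapper defined above (the import path mirrors the file location in this diff; adjust it to how the package is actually installed). It shows the ordered-vs-unordered hyperedge distinction and the adjacency helpers:

from models.mhg_model.graph_grammar.hypergraph import Hypergraph

hg = Hypergraph()
for name in ['bond_0', 'bond_1', 'bond_2']:
    hg.add_node(name, attr_dict={'symbol': None})

# a list keeps the node order inside the hyperedge, a set does not
e0 = hg.add_edge(['bond_0', 'bond_1'], attr_dict={'terminal': True})
e1 = hg.add_edge({'bond_1', 'bond_2'}, attr_dict={'terminal': True})

print(hg.num_nodes, hg.num_edges)   # 3 2
print(hg.nodes_in_edge(e0))         # ['bond_0', 'bond_1'] (ordered)
print(hg.degree('bond_1'))          # 2: bond_1 sits in both hyperedges
print(hg.adj_nodes('bond_1'))       # {'bond_0', 'bond_2'}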
models/mhg_model/graph_grammar/io/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 1 2018"
20
+
models/mhg_model/graph_grammar/io/smi.py ADDED
@@ -0,0 +1,559 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 12 2018"
20
+
21
+ from copy import deepcopy
22
+ from rdkit import Chem
23
+ from rdkit import RDLogger
24
+ import networkx as nx
25
+ import numpy as np
26
+ from ..hypergraph import Hypergraph
27
+ from ..graph_grammar.symbols import TSymbol, BondSymbol
28
+
29
+ # suppress warnings
30
+ lg = RDLogger.logger()
31
+ lg.setLevel(RDLogger.CRITICAL)
32
+
33
+
34
+ class HGGen(object):
35
+ """
36
+ load .smi file and yield a hypergraph.
37
+
38
+ Attributes
39
+ ----------
40
+ path_to_file : str
41
+ path to .smi file
42
+ kekulize : bool
43
+ kekulize or not
44
+ add_Hs : bool
45
+ add implicit hydrogens to the molecule or not.
46
+ all_single : bool
47
+ if True, all multiple bonds are summarized into a single bond with some attributes
48
+
49
+ Yields
50
+ ------
51
+ Hypergraph
52
+ """
53
+ def __init__(self, path_to_file, kekulize=True, add_Hs=False, all_single=True):
54
+ self.num_line = 1
55
+ self.mol_gen = Chem.SmilesMolSupplier(path_to_file, titleLine=False)
56
+ self.kekulize = kekulize
57
+ self.add_Hs = add_Hs
58
+ self.all_single = all_single
59
+
60
+ def __iter__(self):
61
+ return self
62
+
63
+ def __next__(self):
64
+ '''
65
+ each_mol = None
66
+ while each_mol is None:
67
+ each_mol = next(self.mol_gen)
68
+ '''
69
+ # not ignoring parse errors
70
+ each_mol = next(self.mol_gen)
71
+ if each_mol is None:
72
+ raise ValueError(f'incorrect smiles in line {self.num_line}')
73
+ else:
74
+ self.num_line += 1
75
+ return mol_to_hg(each_mol, self.kekulize, self.add_Hs)
76
+
77
+
78
+ def mol_to_bipartite(mol, kekulize):
79
+ """
80
+ get a bipartite representation of a molecule.
81
+
82
+ Parameters
83
+ ----------
84
+ mol : rdkit.Chem.rdchem.Mol
85
+ molecule object
86
+
87
+ Returns
88
+ -------
89
+ nx.Graph
90
+ a bipartite graph representing which bond is connected to which atoms.
91
+ """
92
+ try:
93
+ mol = standardize_stereo(mol)
94
+ except KeyError:
95
+ print(Chem.MolToSmiles(mol))
96
+ raise KeyError
97
+
98
+ if kekulize:
99
+ Chem.Kekulize(mol)
100
+
101
+ bipartite_g = nx.Graph()
102
+ for each_atom in mol.GetAtoms():
103
+ bipartite_g.add_node(f"atom_{each_atom.GetIdx()}",
104
+ atom_attr=atom_attr(each_atom, kekulize))
105
+
106
+ for each_bond in mol.GetBonds():
107
+ bond_idx = each_bond.GetIdx()
108
+ bipartite_g.add_node(
109
+ f"bond_{bond_idx}",
110
+ bond_attr=bond_attr(each_bond, kekulize))
111
+ bipartite_g.add_edge(
112
+ f"atom_{each_bond.GetBeginAtomIdx()}",
113
+ f"bond_{bond_idx}")
114
+ bipartite_g.add_edge(
115
+ f"atom_{each_bond.GetEndAtomIdx()}",
116
+ f"bond_{bond_idx}")
117
+ return bipartite_g
118
+
119
+
120
+ def mol_to_hg(mol, kekulize, add_Hs):
121
+ """
122
+ convert a molecule into a hypergraph.
123
+
124
+ Parameters
125
+ ----------
126
+ mol : rdkit.Chem.rdchem.Mol
127
+ molecule object
128
+ kekulize : bool
129
+ kekulize or not
130
+ add_Hs : bool
131
+ add implicit hydrogens to the molecule or not.
132
+
133
+ Returns
134
+ -------
135
+ Hypergraph
136
+ """
137
+ if add_Hs:
138
+ mol = Chem.AddHs(mol)
139
+
140
+ if kekulize:
141
+ Chem.Kekulize(mol)
142
+
143
+ bipartite_g = mol_to_bipartite(mol, kekulize)
144
+ hg = Hypergraph()
145
+ for each_atom in [each_node for each_node in bipartite_g.nodes()
146
+ if each_node.startswith('atom_')]:
147
+ node_set = set([])
148
+ for each_bond in bipartite_g.adj[each_atom]:
149
+ hg.add_node(each_bond,
150
+ attr_dict=bipartite_g.nodes[each_bond]['bond_attr'])
151
+ node_set.add(each_bond)
152
+ hg.add_edge(node_set,
153
+ attr_dict=bipartite_g.nodes[each_atom]['atom_attr'])
154
+ return hg
155
+
156
+
157
+ def hg_to_mol(hg, verbose=False):
158
+ """ convert a hypergraph into Mol object
159
+
160
+ Parameters
161
+ ----------
162
+ hg : Hypergraph
163
+
164
+ Returns
165
+ -------
166
+ mol : Chem.RWMol
167
+ """
168
+ mol = Chem.RWMol()
169
+ atom_dict = {}
170
+ bond_set = set([])
171
+ for each_edge in hg.edges:
172
+ atom = Chem.Atom(hg.edge_attr(each_edge)['symbol'].symbol)
173
+ atom.SetNumExplicitHs(hg.edge_attr(each_edge)['symbol'].num_explicit_Hs)
174
+ atom.SetFormalCharge(hg.edge_attr(each_edge)['symbol'].formal_charge)
175
+ atom.SetChiralTag(
176
+ Chem.rdchem.ChiralType.values[
177
+ hg.edge_attr(each_edge)['symbol'].chirality])
178
+ atom_idx = mol.AddAtom(atom)
179
+ atom_dict[each_edge] = atom_idx
180
+
181
+ for each_node in hg.nodes:
182
+ edge_1, edge_2 = hg.adj_edges(each_node)
183
+ if edge_1+edge_2 not in bond_set:
184
+ if hg.node_attr(each_node)['symbol'].bond_type <= 3:
185
+ num_bond = hg.node_attr(each_node)['symbol'].bond_type
186
+ elif hg.node_attr(each_node)['symbol'].bond_type == 12:
187
+ num_bond = 1
188
+ else:
189
+ raise ValueError(f'too many bonds; {hg.node_attr(each_node)["symbol"].bond_type}')
190
+ _ = mol.AddBond(atom_dict[edge_1],
191
+ atom_dict[edge_2],
192
+ order=Chem.rdchem.BondType.values[num_bond])
193
+ bond_idx = mol.GetBondBetweenAtoms(atom_dict[edge_1], atom_dict[edge_2]).GetIdx()
194
+
195
+ # stereo
196
+ mol.GetBondWithIdx(bond_idx).SetStereo(
197
+ Chem.rdchem.BondStereo.values[hg.node_attr(each_node)['symbol'].stereo])
198
+ bond_set.update([edge_1+edge_2])
199
+ bond_set.update([edge_2+edge_1])
200
+ mol.UpdatePropertyCache()
201
+ mol = mol.GetMol()
202
+ not_stereo_mol = deepcopy(mol)
203
+ if Chem.MolFromSmiles(Chem.MolToSmiles(not_stereo_mol)) is None:
204
+ raise RuntimeError('no valid molecule was obtained.')
205
+ try:
206
+ mol = set_stereo(mol)
207
+ is_stereo = True
208
+ except:
209
+ import traceback
210
+ traceback.print_exc()
211
+ is_stereo = False
212
+ mol_tmp = deepcopy(mol)
213
+ Chem.SetAromaticity(mol_tmp)
214
+ if Chem.MolFromSmiles(Chem.MolToSmiles(mol_tmp)) is not None:
215
+ mol = mol_tmp
216
+ else:
217
+ if Chem.MolFromSmiles(Chem.MolToSmiles(mol)) is None:
218
+ mol = not_stereo_mol
219
+ mol.UpdatePropertyCache()
220
+ Chem.GetSymmSSSR(mol)
221
+ mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
222
+ if verbose:
223
+ return mol, is_stereo
224
+ else:
225
+ return mol
226
+
227
+ def hgs_to_mols(hg_list, ignore_error=False):
228
+ if ignore_error:
229
+ mol_list = []
230
+ for each_hg in hg_list:
231
+ try:
232
+ mol = hg_to_mol(each_hg)
233
+ except:
234
+ mol = None
235
+ mol_list.append(mol)
236
+ else:
237
+ mol_list = [hg_to_mol(each_hg) for each_hg in hg_list]
238
+ return mol_list
239
+
240
+ def hgs_to_smiles(hg_list, ignore_error=False):
241
+ mol_list = hgs_to_mols(hg_list, ignore_error)
242
+ smiles_list = []
243
+ for each_mol in mol_list:
244
+ try:
245
+ smiles_list.append(
246
+ Chem.MolToSmiles(
247
+ Chem.MolFromSmiles(
248
+ Chem.MolToSmiles(
249
+ each_mol))))
250
+ except:
251
+ smiles_list.append(None)
252
+ return smiles_list
253
+
254
+ def atom_attr(atom, kekulize):
255
+ """
256
+ get atom's attributes
257
+
258
+ Parameters
259
+ ----------
260
+ atom : rdkit.Chem.rdchem.Atom
261
+ kekulize : bool
262
+ kekulize or not
263
+
264
+ Returns
265
+ -------
266
+ atom_attr : dict
267
+ "is_aromatic" : bool
268
+ the atom is aromatic or not.
269
+ "smarts" : str
270
+ SMARTS representation of the atom.
271
+ """
272
+ if kekulize:
273
+ return {'terminal': True,
274
+ 'is_in_ring': atom.IsInRing(),
275
+ 'symbol': TSymbol(degree=0,
276
+ #degree=atom.GetTotalDegree(),
277
+ is_aromatic=False,
278
+ symbol=atom.GetSymbol(),
279
+ num_explicit_Hs=atom.GetNumExplicitHs(),
280
+ formal_charge=atom.GetFormalCharge(),
281
+ chirality=atom.GetChiralTag().real
282
+ )}
283
+ else:
284
+ return {'terminal': True,
285
+ 'is_in_ring': atom.IsInRing(),
286
+ 'symbol': TSymbol(degree=0,
287
+ #degree=atom.GetTotalDegree(),
288
+ is_aromatic=atom.GetIsAromatic(),
289
+ symbol=atom.GetSymbol(),
290
+ num_explicit_Hs=atom.GetNumExplicitHs(),
291
+ formal_charge=atom.GetFormalCharge(),
292
+ chirality=atom.GetChiralTag().real
293
+ )}
294
+
295
+ def bond_attr(bond, kekulize):
296
+ """
297
+ get bond's attributes
298
+
299
+ Parameters
300
+ ----------
301
+ bond : rdkit.Chem.rdchem.Bond
302
+ kekulize : bool
303
+ kekulize or not
304
+
305
+ Returns
306
+ -------
307
+ bond_attr : dict
308
+ "bond_type" : int
309
+ {0: rdkit.Chem.rdchem.BondType.UNSPECIFIED,
310
+ 1: rdkit.Chem.rdchem.BondType.SINGLE,
311
+ 2: rdkit.Chem.rdchem.BondType.DOUBLE,
312
+ 3: rdkit.Chem.rdchem.BondType.TRIPLE,
313
+ 4: rdkit.Chem.rdchem.BondType.QUADRUPLE,
314
+ 5: rdkit.Chem.rdchem.BondType.QUINTUPLE,
315
+ 6: rdkit.Chem.rdchem.BondType.HEXTUPLE,
316
+ 7: rdkit.Chem.rdchem.BondType.ONEANDAHALF,
317
+ 8: rdkit.Chem.rdchem.BondType.TWOANDAHALF,
318
+ 9: rdkit.Chem.rdchem.BondType.THREEANDAHALF,
319
+ 10: rdkit.Chem.rdchem.BondType.FOURANDAHALF,
320
+ 11: rdkit.Chem.rdchem.BondType.FIVEANDAHALF,
321
+ 12: rdkit.Chem.rdchem.BondType.AROMATIC,
322
+ 13: rdkit.Chem.rdchem.BondType.IONIC,
323
+ 14: rdkit.Chem.rdchem.BondType.HYDROGEN,
324
+ 15: rdkit.Chem.rdchem.BondType.THREECENTER,
325
+ 16: rdkit.Chem.rdchem.BondType.DATIVEONE,
326
+ 17: rdkit.Chem.rdchem.BondType.DATIVE,
327
+ 18: rdkit.Chem.rdchem.BondType.DATIVEL,
328
+ 19: rdkit.Chem.rdchem.BondType.DATIVER,
329
+ 20: rdkit.Chem.rdchem.BondType.OTHER,
330
+ 21: rdkit.Chem.rdchem.BondType.ZERO}
331
+ """
332
+ if kekulize:
333
+ is_aromatic = False
334
+ if bond.GetBondType().real == 12:
335
+ bond_type = 1
336
+ else:
337
+ bond_type = bond.GetBondType().real
338
+ else:
339
+ is_aromatic = bond.GetIsAromatic()
340
+ bond_type = bond.GetBondType().real
341
+ return {'symbol': BondSymbol(is_aromatic=is_aromatic,
342
+ bond_type=bond_type,
343
+ stereo=int(bond.GetStereo())),
344
+ 'is_in_ring': bond.IsInRing()}
345
+
346
+
347
+ def standardize_stereo(mol):
348
+ '''
349
+ 0: rdkit.Chem.rdchem.BondDir.NONE,
350
+ 1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
351
+ 2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
352
+ 3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
353
+ 4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,
354
+
355
+ '''
356
+ # mol = Chem.AddHs(mol) # this removes CIPRank !!!
357
+ for each_bond in mol.GetBonds():
358
+ if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
359
+ begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
360
+ end_stereo_atom_idx = each_bond.GetEndAtomIdx()
361
+ atom_idx_1 = each_bond.GetStereoAtoms()[0]
362
+ atom_idx_2 = each_bond.GetStereoAtoms()[1]
363
+ if mol.GetBondBetweenAtoms(atom_idx_1, begin_stereo_atom_idx):
364
+ begin_atom_idx = atom_idx_1
365
+ end_atom_idx = atom_idx_2
366
+ else:
367
+ begin_atom_idx = atom_idx_2
368
+ end_atom_idx = atom_idx_1
369
+
370
+ begin_another_atom_idx = None
371
+ assert len(mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()) <= 3
372
+ for each_neighbor in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors():
373
+ each_neighbor_idx = each_neighbor.GetIdx()
374
+ if each_neighbor_idx not in [end_stereo_atom_idx, begin_atom_idx]:
375
+ begin_another_atom_idx = each_neighbor_idx
376
+
377
+ end_another_atom_idx = None
378
+ assert len(mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()) <= 3
379
+ for each_neighbor in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors():
380
+ each_neighbor_idx = each_neighbor.GetIdx()
381
+ if each_neighbor_idx not in [begin_stereo_atom_idx, end_atom_idx]:
382
+ end_another_atom_idx = each_neighbor_idx
383
+
384
+ '''
385
+ relationship between begin_atom_idx and end_atom_idx is encoded in GetStereo
386
+ '''
387
+ begin_atom_rank = int(mol.GetAtomWithIdx(begin_atom_idx).GetProp('_CIPRank'))
388
+ end_atom_rank = int(mol.GetAtomWithIdx(end_atom_idx).GetProp('_CIPRank'))
389
+ try:
390
+ begin_another_atom_rank = int(mol.GetAtomWithIdx(begin_another_atom_idx).GetProp('_CIPRank'))
391
+ except:
392
+ begin_another_atom_rank = np.inf
393
+ try:
394
+ end_another_atom_rank = int(mol.GetAtomWithIdx(end_another_atom_idx).GetProp('_CIPRank'))
395
+ except:
396
+ end_another_atom_rank = np.inf
397
+ if begin_atom_rank < begin_another_atom_rank\
398
+ and end_atom_rank < end_another_atom_rank:
399
+ pass
400
+ elif begin_atom_rank < begin_another_atom_rank\
401
+ and end_atom_rank > end_another_atom_rank:
402
+ # (begin_atom_idx +) end_another_atom_idx should be in StereoAtoms
403
+ if each_bond.GetStereo() == 2:
404
+ # set stereo
405
+ each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
406
+ # set bond dir
407
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
408
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
409
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
410
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
411
+ elif each_bond.GetStereo() == 3:
412
+ # set stereo
413
+ each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
414
+ # set bond dir
415
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
416
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
417
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
418
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
419
+ else:
420
+ raise ValueError
421
+ each_bond.SetStereoAtoms(begin_atom_idx, end_another_atom_idx)
422
+ elif begin_atom_rank > begin_another_atom_rank\
423
+ and end_atom_rank < end_another_atom_rank:
424
+ # (end_atom_idx +) begin_another_atom_idx should be in StereoAtoms
425
+ if each_bond.GetStereo() == 2:
426
+ # set stereo
427
+ each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
428
+ # set bond dir
429
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
430
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
431
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
432
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
433
+ elif each_bond.GetStereo() == 3:
434
+ # set stereo
435
+ each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
436
+ # set bond dir
437
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
438
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
439
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
440
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
441
+ else:
442
+ raise ValueError
443
+ each_bond.SetStereoAtoms(begin_another_atom_idx, end_atom_idx)
444
+ elif begin_atom_rank > begin_another_atom_rank\
445
+ and end_atom_rank > end_another_atom_rank:
446
+ # begin_another_atom_idx + end_another_atom_idx should be in StereoAtoms
447
+ if each_bond.GetStereo() == 2:
448
+ # set bond dir
449
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
450
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
451
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
452
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
453
+ elif each_bond.GetStereo() == 3:
454
+ # set bond dir
455
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
456
+ mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
457
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
458
+ mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
459
+ else:
460
+ raise ValueError
461
+ each_bond.SetStereoAtoms(begin_another_atom_idx, end_another_atom_idx)
462
+ else:
463
+ raise RuntimeError
464
+ return mol
465
+
466
+
467
+ def set_stereo(mol):
468
+ '''
469
+ 0: rdkit.Chem.rdchem.BondDir.NONE,
470
+ 1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
471
+ 2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
472
+ 3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
473
+ 4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,
474
+ '''
475
+ _mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
476
+ Chem.Kekulize(_mol, True)
477
+ substruct_match = mol.GetSubstructMatch(_mol)
478
+ if not substruct_match:
479
+ ''' mol and _mol are kekulized.
480
+ sometimes, the order of '=' and '-' changes, which causes mol and _mol not matched.
481
+ '''
482
+ Chem.SetAromaticity(mol)
483
+ Chem.SetAromaticity(_mol)
484
+ substruct_match = mol.GetSubstructMatch(_mol)
485
+ try:
486
+ atom_match = {substruct_match[_mol_atom_idx]: _mol_atom_idx for _mol_atom_idx in range(_mol.GetNumAtoms())} # mol to _mol
487
+ except:
488
+ raise ValueError('two molecules obtained from the same data do not match.')
489
+
490
+ for each_bond in mol.GetBonds():
491
+ begin_atom_idx = each_bond.GetBeginAtomIdx()
492
+ end_atom_idx = each_bond.GetEndAtomIdx()
493
+ _bond = _mol.GetBondBetweenAtoms(atom_match[begin_atom_idx], atom_match[end_atom_idx])
494
+ _bond.SetStereo(each_bond.GetStereo())
495
+
496
+ mol = _mol
497
+ for each_bond in mol.GetBonds():
498
+ if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
499
+ begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
500
+ end_stereo_atom_idx = each_bond.GetEndAtomIdx()
501
+ begin_atom_idx_set = set([each_neighbor.GetIdx()
502
+ for each_neighbor
503
+ in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()
504
+ if each_neighbor.GetIdx() != end_stereo_atom_idx])
505
+ end_atom_idx_set = set([each_neighbor.GetIdx()
506
+ for each_neighbor
507
+ in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()
508
+ if each_neighbor.GetIdx() != begin_stereo_atom_idx])
509
+ if not begin_atom_idx_set:
510
+ each_bond.SetStereo(Chem.rdchem.BondStereo(0))
511
+ continue
512
+ if not end_atom_idx_set:
513
+ each_bond.SetStereo(Chem.rdchem.BondStereo(0))
514
+ continue
515
+ if len(begin_atom_idx_set) == 1:
516
+ begin_atom_idx = begin_atom_idx_set.pop()
517
+ begin_another_atom_idx = None
518
+ if len(end_atom_idx_set) == 1:
519
+ end_atom_idx = end_atom_idx_set.pop()
520
+ end_another_atom_idx = None
521
+ if len(begin_atom_idx_set) == 2:
522
+ atom_idx_1 = begin_atom_idx_set.pop()
523
+ atom_idx_2 = begin_atom_idx_set.pop()
524
+ if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
525
+ begin_atom_idx = atom_idx_1
526
+ begin_another_atom_idx = atom_idx_2
527
+ else:
528
+ begin_atom_idx = atom_idx_2
529
+ begin_another_atom_idx = atom_idx_1
530
+ if len(end_atom_idx_set) == 2:
531
+ atom_idx_1 = end_atom_idx_set.pop()
532
+ atom_idx_2 = end_atom_idx_set.pop()
533
+ if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
534
+ end_atom_idx = atom_idx_1
535
+ end_another_atom_idx = atom_idx_2
536
+ else:
537
+ end_atom_idx = atom_idx_2
538
+ end_another_atom_idx = atom_idx_1
539
+
540
+ if each_bond.GetStereo() == 2: # same side
541
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
542
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
543
+ each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
544
+ elif each_bond.GetStereo() == 3: # opposite side
545
+ mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
546
+ mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
547
+ each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
548
+ else:
549
+ raise ValueError
550
+ return mol
551
+
552
+
553
+ def safe_set_bond_dir(mol, atom_idx_1, atom_idx_2, bond_dir_val):
554
+ if atom_idx_1 is None or atom_idx_2 is None:
555
+ return mol
556
+ else:
557
+ mol.GetBondBetweenAtoms(atom_idx_1, atom_idx_2).SetBondDir(Chem.rdchem.BondDir.values[bond_dir_val])
558
+ return mol
559
+
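The SMILES I/O helpers above convert between RDKit molecules and the hypergraph representation (atoms become hyperedges, bonds become nodes). A minimal round-trip sketch follows; the import path is an assumption about how this repository sits on `sys.path`, and RDKit must be installed.

```python
# Round-trip sketch: SMILES -> RDKit Mol -> hypergraph -> RDKit Mol -> SMILES.
# The import path below is an assumption about the package layout.
from rdkit import Chem
from models.mhg_model.graph_grammar.io.smi import mol_to_hg, hg_to_mol

mol = Chem.MolFromSmiles("CCO")                   # ethanol
hg = mol_to_hg(mol, kekulize=True, add_Hs=False)  # atoms -> hyperedges, bonds -> nodes
recon = hg_to_mol(hg)                             # rebuild an RDKit Mol from the hypergraph
print(Chem.MolToSmiles(recon))                    # expected: "CCO"
```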
models/mhg_model/graph_grammar/nn/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # -*- coding:utf-8 -*-
2
+ # Rhizome
3
+ # Version beta 0.0, August 2023
4
+ # Property of IBM Research, Accelerated Discovery
5
+ #
6
+
7
+ """
8
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
9
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
10
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
11
+ """
models/mhg_model/graph_grammar/nn/dataset.py ADDED
@@ -0,0 +1,121 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Apr 18 2018"
20
+
21
+ from torch.utils.data import Dataset, DataLoader
22
+ import torch
23
+ import numpy as np
24
+
25
+
26
+ def left_padding(sentence_list, max_len, pad_idx=-1, inverse=False):
27
+ ''' pad left
28
+
29
+ Parameters
30
+ ----------
31
+ sentence_list : list of sequences of integers
32
+ max_len : int
33
+ maximum length of sentences.
34
+ if a sentence is shorter than `max_len`, its left part is padded.
35
+ pad_idx : int
36
+ integer for padding
37
+ inverse : bool
38
+ if True, the sequence is inversed.
39
+
40
+ Returns
41
+ -------
42
+ List of torch.LongTensor
43
+ each sentence is left-padded.
44
+ '''
45
+ max_in_list = max([len(each_sen) for each_sen in sentence_list])
46
+
47
+ if max_in_list > max_len:
48
+ raise ValueError('`max_len` should be larger than the maximum length of input sequences, {}.'.format(max_in_list))
49
+
50
+ if inverse:
51
+ return [torch.LongTensor([pad_idx] * (max_len - len(each_sen)) + each_sen[::-1]) for each_sen in sentence_list]
52
+ else:
53
+ return [torch.LongTensor([pad_idx] * (max_len - len(each_sen)) + each_sen) for each_sen in sentence_list]
54
+
55
+
56
+ def right_padding(sentence_list, max_len, pad_idx=-1):
57
+ ''' pad right
58
+
59
+ Parameters
60
+ ----------
61
+ sentence_list : list of sequences of integers
62
+ max_len : int
63
+ maximum length of sentences.
64
+ if a sentence is shorter than `max_len`, its right part is padded.
65
+ pad_idx : int
66
+ integer for padding
67
+
68
+ Returns
69
+ -------
70
+ List of torch.LongTensor
71
+ each sentence is right-padded.
72
+ '''
73
+ max_in_list = max([len(each_sen) for each_sen in sentence_list])
74
+ if max_in_list > max_len:
75
+ raise ValueError('`max_len` should be larger than the maximum length of input sequences, {}.'.format(max_in_list))
76
+
77
+ return [torch.LongTensor(each_sen + [pad_idx] * (max_len - len(each_sen))) for each_sen in sentence_list]
78
+
79
+
80
+ class HRGDataset(Dataset):
81
+
82
+ '''
83
+ A class of HRG data
84
+ '''
85
+
86
+ def __init__(self, hrg, prod_rule_seq_list, max_len, target_val_list=None, inversed_input=False):
87
+ self.hrg = hrg
88
+ self.left_prod_rule_seq_list = left_padding(prod_rule_seq_list,
89
+ max_len,
90
+ inverse=inversed_input)
91
+
92
+ self.right_prod_rule_seq_list = right_padding(prod_rule_seq_list, max_len)
93
+ self.inversed_input = inversed_input
94
+ self.target_val_list = target_val_list
95
+ if target_val_list is not None:
96
+ if len(prod_rule_seq_list) != len(target_val_list):
97
+ raise ValueError(f'prod_rule_seq_list and target_val_list have inconsistent lengths: {len(prod_rule_seq_list)}, {len(target_val_list)}')
98
+
99
+ def __len__(self):
100
+ return len(self.left_prod_rule_seq_list)
101
+
102
+ def __getitem__(self, idx):
103
+ if self.target_val_list is not None:
104
+ return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx], np.float32(self.target_val_list[idx])
105
+ else:
106
+ return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx]
107
+
108
+ @property
109
+ def vocab_size(self):
110
+ return self.hrg.num_prod_rule
111
+
112
+ def batch_padding(each_batch, batch_size, padding_idx):
113
+ num_pad = batch_size - len(each_batch[0])
114
+ if num_pad:
115
+ each_batch[0] = torch.cat([each_batch[0],
116
+ padding_idx * torch.ones((batch_size - len(each_batch[0]),
117
+ len(each_batch[0][0])), dtype=torch.int64)], dim=0)
118
+ each_batch[1] = torch.cat([each_batch[1],
119
+ padding_idx * torch.ones((batch_size - len(each_batch[1]),
120
+ len(each_batch[1][0])), dtype=torch.int64)], dim=0)
121
+ return each_batch, num_pad
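As a quick illustration of the padding helpers defined above, here is a small sketch (the import path is an assumption) showing how variable-length rule-index sequences are padded to a fixed length with the default pad index of -1.

```python
# Pads two sequences of different lengths to max_len=4 with pad_idx=-1.
# The import path is an assumption about the package layout.
from models.mhg_model.graph_grammar.nn.dataset import left_padding, right_padding

seqs = [[5, 3], [7, 1, 2]]
print(left_padding(seqs, max_len=4))   # [tensor([-1, -1,  5,  3]), tensor([-1,  7,  1,  2])]
print(right_padding(seqs, max_len=4))  # [tensor([ 5,  3, -1, -1]), tensor([ 7,  1,  2, -1])]
```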
models/mhg_model/graph_grammar/nn/decoder.py ADDED
@@ -0,0 +1,158 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Aug 9 2018"
20
+
21
+
22
+ import abc
23
+ import numpy as np
24
+ import torch
25
+ from torch import nn
26
+
27
+
28
+ class DecoderBase(nn.Module):
29
+
30
+ def __init__(self):
31
+ super().__init__()
32
+ self.hidden_dict = {}
33
+
34
+ @abc.abstractmethod
35
+ def forward_one_step(self, tgt_emb_in):
36
+ ''' one-step forward model
37
+
38
+ Parameters
39
+ ----------
40
+ tgt_emb_in : Tensor, shape (batch_size, input_dim)
41
+
42
+ Returns
43
+ -------
44
+ Tensor, shape (batch_size, hidden_dim)
45
+ '''
46
+ tgt_emb_out = None
47
+ return tgt_emb_out
48
+
49
+ @abc.abstractmethod
50
+ def init_hidden(self):
51
+ ''' initialize the hidden states
52
+ '''
53
+ pass
54
+
55
+ @abc.abstractmethod
56
+ def feed_hidden(self, hidden_dict_0):
57
+ for each_hidden in self.hidden_dict.keys():
58
+ self.hidden_dict[each_hidden][0] = hidden_dict_0[each_hidden]
59
+
60
+
61
+ class GRUDecoder(DecoderBase):
62
+
63
+ def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
64
+ dropout: float, batch_size: int, use_gpu: bool,
65
+ no_dropout=False):
66
+ super().__init__()
67
+ self.input_dim = input_dim
68
+ self.hidden_dim = hidden_dim
69
+ self.num_layers = num_layers
70
+ self.dropout = dropout
71
+ self.batch_size = batch_size
72
+ self.use_gpu = use_gpu
73
+ self.model = nn.GRU(input_size=self.input_dim,
74
+ hidden_size=self.hidden_dim,
75
+ num_layers=self.num_layers,
76
+ batch_first=True,
77
+ bidirectional=False,
78
+ dropout=self.dropout if not no_dropout else 0
79
+ )
80
+ if self.use_gpu:
81
+ self.model.cuda()
82
+ self.init_hidden()
83
+
84
+ def init_hidden(self):
85
+ self.hidden_dict['h'] = torch.zeros((self.num_layers,
86
+ self.batch_size,
87
+ self.hidden_dim),
88
+ requires_grad=False)
89
+ if self.use_gpu:
90
+ self.hidden_dict['h'] = self.hidden_dict['h'].cuda()
91
+
92
+ def forward_one_step(self, tgt_emb_in):
93
+ ''' one-step forward model
94
+
95
+ Parameters
96
+ ----------
97
+ tgt_emb_in : Tensor, shape (batch_size, input_dim)
98
+
99
+ Returns
100
+ -------
101
+ Tensor, shape (batch_size, hidden_dim)
102
+ '''
103
+ tgt_emb_out, self.hidden_dict['h'] \
104
+ = self.model(tgt_emb_in.view(self.batch_size, 1, -1),
105
+ self.hidden_dict['h'])
106
+ return tgt_emb_out
107
+
108
+
109
+ class LSTMDecoder(DecoderBase):
110
+
111
+ def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
112
+ dropout: float, batch_size: int, use_gpu: bool,
113
+ no_dropout=False):
114
+ super().__init__()
115
+ self.input_dim = input_dim
116
+ self.hidden_dim = hidden_dim
117
+ self.num_layers = num_layers
118
+ self.dropout = dropout
119
+ self.batch_size = batch_size
120
+ self.use_gpu = use_gpu
121
+ self.model = nn.LSTM(input_size=self.input_dim,
122
+ hidden_size=self.hidden_dim,
123
+ num_layers=self.num_layers,
124
+ batch_first=True,
125
+ bidirectional=False,
126
+ dropout=self.dropout if not no_dropout else 0)
127
+ if self.use_gpu:
128
+ self.model.cuda()
129
+ self.init_hidden()
130
+
131
+ def init_hidden(self):
132
+ self.hidden_dict['h'] = torch.zeros((self.num_layers,
133
+ self.batch_size,
134
+ self.hidden_dim),
135
+ requires_grad=False)
136
+ self.hidden_dict['c'] = torch.zeros((self.num_layers,
137
+ self.batch_size,
138
+ self.hidden_dim),
139
+ requires_grad=False)
140
+ if self.use_gpu:
141
+ for each_hidden in self.hidden_dict.keys():
142
+ self.hidden_dict[each_hidden] = self.hidden_dict[each_hidden].cuda()
143
+
144
+ def forward_one_step(self, tgt_emb_in):
145
+ ''' one-step forward model
146
+
147
+ Parameters
148
+ ----------
149
+ tgt_emb_in : Tensor, shape (batch_size, input_dim)
150
+
151
+ Returns
152
+ -------
153
+ Tensor, shape (batch_size, hidden_dim)
154
+ '''
155
+ tgt_hidden_out, self.hidden_dict['h'], self.hidden_dict['c'] \
156
+ = self.model(tgt_emb_in.view(self.batch_size, 1, -1),
157
+ self.hidden_dict['h'], self.hidden_dict['c'])
158
+ return tgt_hidden_out
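The decoders above are driven one token at a time through `forward_one_step`, with the recurrent state kept in `hidden_dict`. A CPU-only sketch with hypothetical dimensions (the import path is assumed):

```python
import torch
from models.mhg_model.graph_grammar.nn.decoder import GRUDecoder  # import path assumed

dec = GRUDecoder(input_dim=16, hidden_dim=32, num_layers=1,
                 dropout=0.0, batch_size=4, use_gpu=False)
dec.init_hidden()                        # reset the GRU hidden state before a new sequence
step_in = torch.randn(4, 16)             # one embedded production-rule token per batch element
step_out = dec.forward_one_step(step_in)
print(step_out.shape)                    # torch.Size([4, 1, 32])
```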
models/mhg_model/graph_grammar/nn/encoder.py ADDED
@@ -0,0 +1,199 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Aug 9 2018"
20
+
21
+
22
+ import abc
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn.functional as F
26
+ from torch import nn
27
+ from typing import List
28
+
29
+
30
+ class EncoderBase(nn.Module):
31
+
32
+ def __init__(self):
33
+ super().__init__()
34
+
35
+ @abc.abstractmethod
36
+ def forward(self, in_seq):
37
+ ''' forward model
38
+
39
+ Parameters
40
+ ----------
41
+ in_seq_emb : Variable, shape (batch_size, max_len, input_dim)
42
+
43
+ Returns
44
+ -------
45
+ hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
46
+ '''
47
+ pass
48
+
49
+ @abc.abstractmethod
50
+ def init_hidden(self):
51
+ ''' initialize the hidden states
52
+ '''
53
+ pass
54
+
55
+
56
+ class GRUEncoder(EncoderBase):
57
+
58
+ def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
59
+ bidirectional: bool, dropout: float, batch_size: int, use_gpu: bool,
60
+ no_dropout=False):
61
+ super().__init__()
62
+ self.input_dim = input_dim
63
+ self.hidden_dim = hidden_dim
64
+ self.num_layers = num_layers
65
+ self.bidirectional = bidirectional
66
+ self.dropout = dropout
67
+ self.batch_size = batch_size
68
+ self.use_gpu = use_gpu
69
+ self.model = nn.GRU(input_size=self.input_dim,
70
+ hidden_size=self.hidden_dim,
71
+ num_layers=self.num_layers,
72
+ batch_first=True,
73
+ bidirectional=self.bidirectional,
74
+ dropout=self.dropout if not no_dropout else 0)
75
+ if self.use_gpu:
76
+ self.model.cuda()
77
+ self.init_hidden()
78
+
79
+
80
+ def init_hidden(self):
81
+ self.h0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
82
+ self.batch_size,
83
+ self.hidden_dim),
84
+ requires_grad=False)
85
+ if self.use_gpu:
86
+ self.h0 = self.h0.cuda()
87
+
88
+ def forward(self, in_seq_emb):
89
+ ''' forward model
90
+
91
+ Parameters
92
+ ----------
93
+ in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)
94
+
95
+ Returns
96
+ -------
97
+ hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
98
+ '''
99
+ max_len = in_seq_emb.size(1)
100
+ hidden_seq_emb, self.h0 = self.model(
101
+ in_seq_emb, self.h0)
102
+ hidden_seq_emb = hidden_seq_emb.view(self.batch_size,
103
+ max_len,
104
+ 1 + self.bidirectional,
105
+ self.hidden_dim)
106
+ return hidden_seq_emb
107
+
108
+
109
+ class LSTMEncoder(EncoderBase):
110
+
111
+ def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
112
+ bidirectional: bool, dropout: float, batch_size: int, use_gpu: bool,
113
+ no_dropout=False):
114
+ super().__init__()
115
+ self.input_dim = input_dim
116
+ self.hidden_dim = hidden_dim
117
+ self.num_layers = num_layers
118
+ self.bidirectional = bidirectional
119
+ self.dropout = dropout
120
+ self.batch_size = batch_size
121
+ self.use_gpu = use_gpu
122
+ self.model = nn.LSTM(input_size=self.input_dim,
123
+ hidden_size=self.hidden_dim,
124
+ num_layers=self.num_layers,
125
+ batch_first=True,
126
+ bidirectional=self.bidirectional,
127
+ dropout=self.dropout if not no_dropout else 0)
128
+ if self.use_gpu:
129
+ self.model.cuda()
130
+ self.init_hidden()
131
+
132
+ def init_hidden(self):
133
+ self.h0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
134
+ self.batch_size,
135
+ self.hidden_dim),
136
+ requires_grad=False)
137
+ self.c0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
138
+ self.batch_size,
139
+ self.hidden_dim),
140
+ requires_grad=False)
141
+ if self.use_gpu:
142
+ self.h0 = self.h0.cuda()
143
+ self.c0 = self.c0.cuda()
144
+
145
+ def forward(self, in_seq_emb):
146
+ ''' forward model
147
+
148
+ Parameters
149
+ ----------
150
+ in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)
151
+
152
+ Returns
153
+ -------
154
+ hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
155
+ '''
156
+ max_len = in_seq_emb.size(1)
157
+ hidden_seq_emb, (self.h0, self.c0) = self.model(
158
+ in_seq_emb, (self.h0, self.c0))
159
+ hidden_seq_emb = hidden_seq_emb.view(self.batch_size,
160
+ max_len,
161
+ 1 + self.bidirectional,
162
+ self.hidden_dim)
163
+ return hidden_seq_emb
164
+
165
+
166
+ class FullConnectedEncoder(EncoderBase):
167
+
168
+ def __init__(self, input_dim: int, hidden_dim: int, max_len: int, hidden_dim_list: List[int],
169
+ batch_size: int, use_gpu: bool):
170
+ super().__init__()
171
+ self.input_dim = input_dim
172
+ self.hidden_dim = hidden_dim
173
+ self.max_len = max_len
174
+ self.hidden_dim_list = hidden_dim_list
175
+ self.use_gpu = use_gpu
176
+ in_out_dim_list = [input_dim * max_len] + list(hidden_dim_list) + [hidden_dim]
177
+ self.linear_list = nn.ModuleList(
178
+ [nn.Linear(in_out_dim_list[each_idx], in_out_dim_list[each_idx + 1])\
179
+ for each_idx in range(len(in_out_dim_list) - 1)])
180
+
181
+ def forward(self, in_seq_emb):
182
+ ''' forward model
183
+
184
+ Parameters
185
+ ----------
186
+ in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)
187
+
188
+ Returns
189
+ -------
190
+ hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
191
+ '''
192
+ batch_size = in_seq_emb.size(0)
193
+ x = in_seq_emb.view(batch_size, -1)
194
+ for each_linear in self.linear_list:
195
+ x = F.relu(each_linear(x))
196
+ return x.view(batch_size, 1, -1)
197
+
198
+ def init_hidden(self):
199
+ pass
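The encoders above all return a tensor of shape `(batch_size, max_len, 1 + bidirectional, hidden_dim)`. A shape check for the GRU variant, again with hypothetical dimensions and an assumed import path:

```python
import torch
from models.mhg_model.graph_grammar.nn.encoder import GRUEncoder  # import path assumed

enc = GRUEncoder(input_dim=16, hidden_dim=32, num_layers=1, bidirectional=True,
                 dropout=0.0, batch_size=4, use_gpu=False)
seq = torch.randn(4, 10, 16)             # (batch_size, max_len, input_dim)
out = enc(seq)
print(out.shape)                         # torch.Size([4, 10, 2, 32]); dim 2 is 1 + bidirectional
```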
models/mhg_model/graph_grammar/nn/graph.py ADDED
@@ -0,0 +1,313 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ # Rhizome
4
+ # Version beta 0.0, August 2023
5
+ # Property of IBM Research, Accelerated Discovery
6
+ #
7
+
8
+ """
9
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
10
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ """ Title """
15
+
16
+ __author__ = "Hiroshi Kajino <[email protected]>"
17
+ __copyright__ = "(c) Copyright IBM Corp. 2018"
18
+ __version__ = "0.1"
19
+ __date__ = "Jan 1 2018"
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from graph_grammar.graph_grammar.hrg import ProductionRuleCorpus
25
+ from torch import nn
26
+ from torch.autograd import Variable
27
+
28
+ class MolecularProdRuleEmbedding(nn.Module):
29
+
30
+ ''' molecular fingerprint layer
31
+ '''
32
+
33
+ def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
34
+ out_dim=32, element_embed_dim=32,
35
+ num_layers=3, padding_idx=None, use_gpu=False):
36
+ super().__init__()
37
+ if padding_idx is not None:
38
+ assert padding_idx == -1, 'padding_idx must be -1.'
39
+ self.prod_rule_corpus = prod_rule_corpus
40
+ self.layer2layer_activation = layer2layer_activation
41
+ self.layer2out_activation = layer2out_activation
42
+ self.out_dim = out_dim
43
+ self.element_embed_dim = element_embed_dim
44
+ self.num_layers = num_layers
45
+ self.padding_idx = padding_idx
46
+ self.use_gpu = use_gpu
47
+
48
+ self.layer2layer_list = []
49
+ self.layer2out_list = []
50
+
51
+ if self.use_gpu:
52
+ self.atom_embed = torch.randn(self.prod_rule_corpus.num_edge_symbol,
53
+ self.element_embed_dim, requires_grad=True).cuda()
54
+ self.bond_embed = torch.randn(self.prod_rule_corpus.num_node_symbol,
55
+ self.element_embed_dim, requires_grad=True).cuda()
56
+ self.ext_id_embed = torch.randn(self.prod_rule_corpus.num_ext_id,
57
+ self.element_embed_dim, requires_grad=True).cuda()
58
+ for _ in range(num_layers):
59
+ self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim).cuda())
60
+ self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim).cuda())
61
+ else:
62
+ self.atom_embed = torch.randn(self.prod_rule_corpus.num_edge_symbol,
63
+ self.element_embed_dim, requires_grad=True)
64
+ self.bond_embed = torch.randn(self.prod_rule_corpus.num_node_symbol,
65
+ self.element_embed_dim, requires_grad=True)
66
+ self.ext_id_embed = torch.randn(self.prod_rule_corpus.num_ext_id,
67
+ self.element_embed_dim, requires_grad=True)
68
+ for _ in range(num_layers):
69
+ self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim))
70
+ self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim))
71
+
72
+
73
+ def forward(self, prod_rule_idx_seq):
74
+ ''' forward model for mini-batch
75
+
76
+ Parameters
77
+ ----------
78
+ prod_rule_idx_seq : (batch_size, length)
79
+
80
+ Returns
81
+ -------
82
+ Variable, shape (batch_size, length, out_dim)
83
+ '''
84
+ batch_size, length = prod_rule_idx_seq.shape
85
+ if self.use_gpu:
86
+ out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
87
+ else:
88
+ out = Variable(torch.zeros((batch_size, length, self.out_dim)))
89
+ for each_batch_idx in range(batch_size):
90
+ for each_idx in range(length):
91
+ if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
92
+ continue
93
+ else:
94
+ each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
95
+ layer_wise_embed_dict = {each_edge: self.atom_embed[
96
+ each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
97
+ for each_edge in each_prod_rule.rhs.edges}
98
+ layer_wise_embed_dict.update({each_node: self.bond_embed[
99
+ each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]
100
+ for each_node in each_prod_rule.rhs.nodes})
101
+ for each_node in each_prod_rule.rhs.nodes:
102
+ if 'ext_id' in each_prod_rule.rhs.node_attr(each_node):
103
+ layer_wise_embed_dict[each_node] \
104
+ = layer_wise_embed_dict[each_node] \
105
+ + self.ext_id_embed[each_prod_rule.rhs.node_attr(each_node)['ext_id']]
106
+
107
+ for each_layer in range(self.num_layers):
108
+ next_layer_embed_dict = {}
109
+ for each_edge in each_prod_rule.rhs.edges:
110
+ v = layer_wise_embed_dict[each_edge]
111
+ for each_node in each_prod_rule.rhs.nodes_in_edge(each_edge):
112
+ v = v + layer_wise_embed_dict[each_node]
113
+ next_layer_embed_dict[each_edge] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
114
+ out[each_batch_idx, each_idx, :] \
115
+ = out[each_batch_idx, each_idx, :] + self.layer2out_activation(self.layer2out_list[each_layer](v))
116
+ for each_node in each_prod_rule.rhs.nodes:
117
+ v = layer_wise_embed_dict[each_node]
118
+ for each_edge in each_prod_rule.rhs.adj_edges(each_node):
119
+ v = v + layer_wise_embed_dict[each_edge]
120
+ next_layer_embed_dict[each_node] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
121
+ out[each_batch_idx, each_idx, :]\
122
+ = out[each_batch_idx, each_idx, :] + self.layer2out_activation(self.layer2out_list[each_layer](v))
123
+ layer_wise_embed_dict = next_layer_embed_dict
124
+
125
+ return out
126
+
127
+
128
+ class MolecularProdRuleEmbeddingLastLayer(nn.Module):
129
+
130
+ ''' molecular fingerprint layer
131
+ '''
132
+
133
+ def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
134
+ out_dim=32, element_embed_dim=32,
135
+ num_layers=3, padding_idx=None, use_gpu=False):
136
+ super().__init__()
137
+ if padding_idx is not None:
138
+ assert padding_idx == -1, 'padding_idx must be -1.'
139
+ self.prod_rule_corpus = prod_rule_corpus
140
+ self.layer2layer_activation = layer2layer_activation
141
+ self.layer2out_activation = layer2out_activation
142
+ self.out_dim = out_dim
143
+ self.element_embed_dim = element_embed_dim
144
+ self.num_layers = num_layers
145
+ self.padding_idx = padding_idx
146
+ self.use_gpu = use_gpu
147
+
148
+ self.layer2layer_list = []
149
+ self.layer2out_list = []
150
+
151
+ if self.use_gpu:
152
+ self.atom_embed = nn.Embedding(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim).cuda()
153
+ self.bond_embed = nn.Embedding(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim).cuda()
154
+ for _ in range(num_layers+1):
155
+ self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim).cuda())
156
+ self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim).cuda())
157
+ else:
158
+ self.atom_embed = nn.Embedding(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim)
159
+ self.bond_embed = nn.Embedding(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim)
160
+ for _ in range(num_layers+1):
161
+ self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim))
162
+ self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim))
163
+
164
+
165
+ def forward(self, prod_rule_idx_seq):
166
+ ''' forward model for mini-batch
167
+
168
+ Parameters
169
+ ----------
170
+ prod_rule_idx_seq : (batch_size, length)
171
+
172
+ Returns
173
+ -------
174
+ Variable, shape (batch_size, length, out_dim)
175
+ '''
176
+ batch_size, length = prod_rule_idx_seq.shape
177
+ if self.use_gpu:
178
+ out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
179
+ else:
180
+ out = Variable(torch.zeros((batch_size, length, self.out_dim)))
181
+ for each_batch_idx in range(batch_size):
182
+ for each_idx in range(length):
183
+ if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
184
+ continue
185
+ else:
186
+ each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
187
+
188
+ if self.use_gpu:
189
+ layer_wise_embed_dict = {each_edge: self.atom_embed(
190
+ Variable(torch.LongTensor(
191
+ [each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
192
+ ), requires_grad=False).cuda())
193
+ for each_edge in each_prod_rule.rhs.edges}
194
+ layer_wise_embed_dict.update({each_node: self.bond_embed(
195
+ Variable(
196
+ torch.LongTensor([
197
+ each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]),
198
+ requires_grad=False).cuda()
199
+ ) for each_node in each_prod_rule.rhs.nodes})
200
+ else:
201
+ layer_wise_embed_dict = {each_edge: self.atom_embed(
202
+ Variable(torch.LongTensor(
203
+ [each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
204
+ ), requires_grad=False))
205
+ for each_edge in each_prod_rule.rhs.edges}
206
+ layer_wise_embed_dict.update({each_node: self.bond_embed(
207
+ Variable(
208
+ torch.LongTensor([
209
+ each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]),
210
+ requires_grad=False)
211
+ ) for each_node in each_prod_rule.rhs.nodes})
212
+
213
+ for each_layer in range(self.num_layers):
214
+ next_layer_embed_dict = {}
215
+ for each_edge in each_prod_rule.rhs.edges:
216
+ v = layer_wise_embed_dict[each_edge]
217
+ for each_node in each_prod_rule.rhs.nodes_in_edge(each_edge):
218
+ v += layer_wise_embed_dict[each_node]
219
+ next_layer_embed_dict[each_edge] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
220
+ for each_node in each_prod_rule.rhs.nodes:
221
+ v = layer_wise_embed_dict[each_node]
222
+ for each_edge in each_prod_rule.rhs.adj_edges(each_node):
223
+ v += layer_wise_embed_dict[each_edge]
224
+ next_layer_embed_dict[each_node] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
225
+ layer_wise_embed_dict = next_layer_embed_dict
226
+ for each_edge in each_prod_rule.rhs.edges:
227
+ out[each_batch_idx, each_idx, :] = self.layer2out_activation(self.layer2out_list[self.num_layers](v))
228
+ for each_edge in each_prod_rule.rhs.edges:
229
+ out[each_batch_idx, each_idx, :] = self.layer2out_activation(self.layer2out_list[self.num_layers](v))
230
+
231
+ return out
232
+
233
+
234
+ class MolecularProdRuleEmbeddingUsingFeatures(nn.Module):
235
+
236
+ ''' molecular fingerprint layer
237
+ '''
238
+
239
+ def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
240
+ out_dim=32, num_layers=3, padding_idx=None, use_gpu=False):
241
+ super().__init__()
242
+ if padding_idx is not None:
243
+ assert padding_idx == -1, 'padding_idx must be -1.'
244
+ self.feature_dict, self.feature_dim = prod_rule_corpus.construct_feature_vectors()
245
+ self.prod_rule_corpus = prod_rule_corpus
246
+ self.layer2layer_activation = layer2layer_activation
247
+ self.layer2out_activation = layer2out_activation
248
+ self.out_dim = out_dim
249
+ self.num_layers = num_layers
250
+ self.padding_idx = padding_idx
251
+ self.use_gpu = use_gpu
252
+
253
+ self.layer2layer_list = []
254
+ self.layer2out_list = []
255
+
256
+ if self.use_gpu:
257
+ for each_key in self.feature_dict:
258
+ self.feature_dict[each_key] = self.feature_dict[each_key].to_dense().cuda()
259
+ for _ in range(num_layers):
260
+ self.layer2layer_list.append(nn.Linear(self.feature_dim, self.feature_dim).cuda())
261
+ self.layer2out_list.append(nn.Linear(self.feature_dim, self.out_dim).cuda())
262
+ else:
263
+ for _ in range(num_layers):
264
+ self.layer2layer_list.append(nn.Linear(self.feature_dim, self.feature_dim))
265
+ self.layer2out_list.append(nn.Linear(self.feature_dim, self.out_dim))
266
+
267
+
268
+ def forward(self, prod_rule_idx_seq):
269
+ ''' forward model for mini-batch
270
+
271
+ Parameters
272
+ ----------
273
+ prod_rule_idx_seq : (batch_size, length)
274
+
275
+ Returns
276
+ -------
277
+ Variable, shape (batch_size, length, out_dim)
278
+ '''
279
+ batch_size, length = prod_rule_idx_seq.shape
280
+ if self.use_gpu:
281
+ out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
282
+ else:
283
+ out = Variable(torch.zeros((batch_size, length, self.out_dim)))
284
+ for each_batch_idx in range(batch_size):
285
+ for each_idx in range(length):
286
+ if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
287
+ continue
288
+ else:
289
+ each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
290
+ edge_list = sorted(list(each_prod_rule.rhs.edges))
291
+ node_list = sorted(list(each_prod_rule.rhs.nodes))
292
+ adj_mat = torch.FloatTensor(each_prod_rule.rhs_adj_mat(edge_list + node_list).todense() + np.identity(len(edge_list)+len(node_list)))
293
+ if self.use_gpu:
294
+ adj_mat = adj_mat.cuda()
295
+ layer_wise_embed = [
296
+ self.feature_dict[each_prod_rule.rhs.edge_attr(each_edge)['symbol']]
297
+ for each_edge in edge_list]\
298
+ + [self.feature_dict[each_prod_rule.rhs.node_attr(each_node)['symbol']]
299
+ for each_node in node_list]
300
+ for each_node in each_prod_rule.ext_node.values():
301
+ layer_wise_embed[each_prod_rule.rhs.num_edges + node_list.index(each_node)] \
302
+ = layer_wise_embed[each_prod_rule.rhs.num_edges + node_list.index(each_node)] \
303
+ + self.feature_dict[('ext_id', each_prod_rule.rhs.node_attr(each_node)['ext_id'])]
304
+ layer_wise_embed = torch.stack(layer_wise_embed)
305
+
306
+ for each_layer in range(self.num_layers):
307
+ message = adj_mat @ layer_wise_embed
308
+ next_layer_embed = self.layer2layer_activation(self.layer2layer_list[each_layer](message))
309
+ out[each_batch_idx, each_idx, :] \
310
+ = out[each_batch_idx, each_idx, :] \
311
+ + self.layer2out_activation(self.layer2out_list[each_layer](message)).sum(dim=0)
312
+ layer_wise_embed = next_layer_embed
313
+ return out
models/mhg_model/load.py ADDED
@@ -0,0 +1,103 @@
1
+ # -*- coding:utf-8 -*-
2
+ # Rhizome
3
+ # Version beta 0.0, August 2023
4
+ # Property of IBM Research, Accelerated Discovery
5
+ #
6
+
7
+ import os
8
+ import pickle
9
+ import sys
10
+
11
+ from rdkit import Chem
12
+ import torch
13
+ from torch_geometric.utils.smiles import from_smiles
14
+
15
+ from typing import Any, Dict, List, Optional, Union
16
+ from typing_extensions import Self
17
+
18
+ from .graph_grammar.io.smi import hg_to_mol
19
+ from .models.mhgvae import GrammarGINVAE
20
+
21
+ from huggingface_hub import hf_hub_download
22
+
23
+
24
+ class PretrainedModelWrapper:
25
+ model: GrammarGINVAE
26
+
27
+ def __init__(self, model_dict: Dict[str, Any]) -> None:
28
+ json_params = model_dict['gnn_params']
29
+ encoder_params = json_params['encoder_params']
30
+ encoder_params['node_feature_size'] = model_dict['num_features']
31
+ encoder_params['edge_feature_size'] = model_dict['num_edge_features']
32
+ self.model = GrammarGINVAE(model_dict['hrg'], rank=-1, encoder_params=encoder_params,
33
+ decoder_params=json_params['decoder_params'],
34
+ prod_rule_embed_params=json_params["prod_rule_embed_params"],
35
+ batch_size=512, max_len=model_dict['max_length'])
36
+ self.model.load_state_dict(model_dict['model_state_dict'])
37
+
38
+ self.model.eval()
39
+
40
+ def to(self, device: Union[str, int, torch.device]) -> Self:
41
+ dev_type = type(device)
42
+ if dev_type != torch.device:
43
+ if dev_type == str or torch.cuda.is_available():
44
+ device = torch.device(device)
45
+ else:
46
+ device = torch.device("mps", device)
47
+
48
+ self.model = self.model.to(device)
49
+ return self
50
+
51
+ def encode(self, data: List[str]) -> List[torch.tensor]:
52
+ # Need to encode them into a graph nn
53
+ output = []
54
+ for d in data:
55
+ params = next(self.model.parameters())
56
+ g = from_smiles(d)
57
+ if (g.cpu() and params != 'cpu') or (not g.cpu() and params == 'cpu'):
58
+ g.to(params.device)
59
+ ltvec = self.model.graph_embed(g.x, g.edge_index, g.edge_attr, g.batch)
60
+ output.append(ltvec[0])
61
+ return output
62
+
63
+ def decode(self, data: List[torch.tensor]) -> List[str]:
64
+ output = []
65
+ for d in data:
66
+ mu, logvar = self.model.get_mean_var(d.unsqueeze(0))
67
+ z = self.model.reparameterize(mu, logvar)
68
+ flags, _, hgs = self.model.decode(z)
69
+ if flags[0]:
70
+ reconstructed_mol, _ = hg_to_mol(hgs[0], True)
71
+ output.append(Chem.MolToSmiles(reconstructed_mol))
72
+ else:
73
+ output.append(None)
74
+ return output
75
+
76
+
77
+ def load(model_name: str = "mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") -> Optional[
78
+ PretrainedModelWrapper]:
79
+
80
+ repo_id = "ibm/materials.mhg-ged"
81
+ filename = "pytorch_model.bin" #"mhggnn_pretrained_model_0724_2023.pickle"
82
+ file_path = hf_hub_download(repo_id=repo_id, filename=filename)
83
+ with open(file_path, "rb") as f:
84
+ model_dict = torch.load(f, weights_only=False)
85
+ return PretrainedModelWrapper(model_dict)
86
+
87
+
88
+ """try:
89
+ if os.path.isfile(model_name):
90
+ with open(model_name, "rb") as f:
91
+ model_dict = pickle.load(f)
92
+ print("MHG Model Loaded")
93
+ return PretrainedModelWrapper(model_dict)
94
+
95
+ except:
96
+
97
+ for p in sys.path:
98
+ file = p + "/" + model_name
99
+ if os.path.isfile(file):
100
+ with open(file, "rb") as f:
101
+ model_dict = pickle.load(f)
102
+ return PretrainedModelWrapper(model_dict)"""
103
+ return None
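Putting the wrapper above together, a minimal end-to-end sketch (this assumes the checkpoint can be downloaded from the `ibm/materials.mhg-ged` repository referenced in `load()`, that `torch-geometric` is installed, and that the package is importable under the path used below):

```python
import torch
from models.mhg_model import load as mhg_load    # import path assumed

model = mhg_load.load()                          # downloads and builds the pretrained GrammarGINVAE
with torch.no_grad():
    embeddings = model.encode(["CCO", "c1ccccc1"])   # one latent vector per input SMILES
    smiles = model.decode(embeddings)                # back to SMILES; entries are None on failure
print(smiles)
```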
models/mhg_model/mhg_gnn.egg-info/PKG-INFO ADDED
@@ -0,0 +1,102 @@
1
+ Metadata-Version: 2.1
2
+ Name: mhg-gnn
3
+ Version: 0.0
4
+ Summary: Package for mhg-gnn
5
+ Author: team
6
+ License: TBD
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: Programming Language :: Python :: 3.9
9
+ Description-Content-Type: text/markdown
10
+ Requires-Dist: networkx>=2.8
11
+ Requires-Dist: numpy<2.0.0,>=1.23.5
12
+ Requires-Dist: pandas>=1.5.3
13
+ Requires-Dist: rdkit-pypi<2023.9.6,>=2022.9.4
14
+ Requires-Dist: torch>=2.0.0
15
+ Requires-Dist: torchinfo>=1.8.0
16
+ Requires-Dist: torch-geometric>=2.3.1
17
+
18
+ # mhg-gnn
19
+
20
+ This repository provides PyTorch source code associated with our publication, "MHG-GNN: Combination of Molecular Hypergraph Grammar with Graph Neural Network"
21
+
22
+ **Paper:** [Arxiv Link](https://arxiv.org/pdf/2309.16374)
23
+
24
+ For more information contact: [email protected]
25
+
26
+ ![mhg-gnn](images/mhg_example1.png)
27
+
28
+ ## Introduction
29
+
30
+ We present MHG-GNN, an autoencoder architecture
31
+ that has an encoder based on GNN and a decoder based on a sequential model with MHG.
32
+ Since the encoder is a GNN variant, MHG-GNN can accept any molecule as input, and
33
+ demonstrate high predictive performance on molecular graph data.
34
+ In addition, the decoder inherits the theoretical guarantee of MHG on always generating a structurally valid molecule as output.
35
+
36
+ ## Table of Contents
37
+
38
+ 1. [Getting Started](#getting-started)
39
+ 1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
40
+ 2. [Replicating Conda Environment](#replicating-conda-environment)
41
+ 2. [Feature Extraction](#feature-extraction)
42
+
43
+ ## Getting Started
44
+
45
+ **This code and environment have been tested on Intel E5-2667 CPUs at 3.30GHz and NVIDIA A100 Tensor Core GPUs.**
46
+
47
+ ### Pretrained Models and Training Logs
48
+
49
+ We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link]()
50
+
51
+ Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs.
52
+
53
+ ### Replicating Conda Environment
54
+
55
+ Follow these steps to replicate our Conda environment and install the necessary libraries:
56
+
57
+ ```
58
+ conda create --name mhg-gnn-env python=3.8.18
59
+ conda activate mhg-gnn-env
60
+ ```
61
+
62
+ #### Install Packages with Conda
63
+
64
+ ```
65
+ conda install -c conda-forge networkx=2.8
66
+ conda install numpy=1.23.5
67
+ # conda install -c conda-forge rdkit=2022.9.4
68
+ conda install pytorch=2.0.0 torchvision torchaudio -c pytorch
69
+ conda install -c conda-forge torchinfo=1.8.0
70
+ conda install pyg -c pyg
71
+ ```
72
+
73
+ #### Install Packages with pip
74
+ ```
75
+ pip install rdkit torch-nl==0.3 torch-scatter torch-sparse
76
+ ```
77
+
78
+ ## Feature Extraction
79
+
80
+ The example notebook [mhg-gnn_encoder_decoder_example.ipynb](notebooks/mhg-gnn_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks.
81
+
82
+ To load mhg-gnn, you can simply use:
83
+
84
+ ```python
85
+ import torch
86
+ import load
87
+
88
+ model = load.load()
89
+ ```
90
+
91
+ To encode SMILES into embeddings, you can use:
92
+
93
+ ```python
94
+ with torch.no_grad():
95
+ repr = model.encode(["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"])
96
+ ```
97
+
98
+ To decode the embeddings back into SMILES strings, you can use:
99
+
100
+ ```python
101
+ orig = model.decode(repr)
102
+ ```
models/mhg_model/mhg_gnn.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,46 @@
1
+ README.md
2
+ setup.cfg
3
+ setup.py
4
+ ./graph_grammar/__init__.py
5
+ ./graph_grammar/hypergraph.py
6
+ ./graph_grammar/algo/__init__.py
7
+ ./graph_grammar/algo/tree_decomposition.py
8
+ ./graph_grammar/graph_grammar/__init__.py
9
+ ./graph_grammar/graph_grammar/base.py
10
+ ./graph_grammar/graph_grammar/corpus.py
11
+ ./graph_grammar/graph_grammar/hrg.py
12
+ ./graph_grammar/graph_grammar/symbols.py
13
+ ./graph_grammar/graph_grammar/utils.py
14
+ ./graph_grammar/io/__init__.py
15
+ ./graph_grammar/io/smi.py
16
+ ./graph_grammar/nn/__init__.py
17
+ ./graph_grammar/nn/dataset.py
18
+ ./graph_grammar/nn/decoder.py
19
+ ./graph_grammar/nn/encoder.py
20
+ ./graph_grammar/nn/graph.py
21
+ ./models/__init__.py
22
+ ./models/mhgvae.py
23
+ graph_grammar/__init__.py
24
+ graph_grammar/hypergraph.py
25
+ graph_grammar/algo/__init__.py
26
+ graph_grammar/algo/tree_decomposition.py
27
+ graph_grammar/graph_grammar/__init__.py
28
+ graph_grammar/graph_grammar/base.py
29
+ graph_grammar/graph_grammar/corpus.py
30
+ graph_grammar/graph_grammar/hrg.py
31
+ graph_grammar/graph_grammar/symbols.py
32
+ graph_grammar/graph_grammar/utils.py
33
+ graph_grammar/io/__init__.py
34
+ graph_grammar/io/smi.py
35
+ graph_grammar/nn/__init__.py
36
+ graph_grammar/nn/dataset.py
37
+ graph_grammar/nn/decoder.py
38
+ graph_grammar/nn/encoder.py
39
+ graph_grammar/nn/graph.py
40
+ mhg_gnn.egg-info/PKG-INFO
41
+ mhg_gnn.egg-info/SOURCES.txt
42
+ mhg_gnn.egg-info/dependency_links.txt
43
+ mhg_gnn.egg-info/requires.txt
44
+ mhg_gnn.egg-info/top_level.txt
45
+ models/__init__.py
46
+ models/mhgvae.py
models/mhg_model/mhg_gnn.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
1
+
models/mhg_model/mhg_gnn.egg-info/requires.txt ADDED
@@ -0,0 +1,7 @@
1
+ networkx>=2.8
2
+ numpy<2.0.0,>=1.23.5
3
+ pandas>=1.5.3
4
+ rdkit-pypi<2023.9.6,>=2022.9.4
5
+ torch>=2.0.0
6
+ torchinfo>=1.8.0
7
+ torch-geometric>=2.3.1
models/mhg_model/mhg_gnn.egg-info/top_level.txt ADDED
@@ -0,0 +1,2 @@
1
+ graph_grammar
2
+ models
models/mhg_model/models/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # -*- coding:utf-8 -*-
2
+ # Rhizome
3
+ # Version beta 0.0, August 2023
4
+ # Property of IBM Research, Accelerated Discovery
5
+ #
models/mhg_model/models/mhgvae.py ADDED
@@ -0,0 +1,956 @@
1
+ # -*- coding:utf-8 -*-
2
+ # Rhizome
3
+ # Version beta 0.0, August 2023
4
+ # Property of IBM Research, Accelerated Discovery
5
+ #
6
+
7
+ """
8
+ PLEASE NOTE THIS IMPLEMENTATION INCLUDES ADAPTED SOURCE CODE
9
+ OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE,
10
+ E.G., GRUEncoder/GRUDecoder, GrammarSeq2SeqVAE AND EVEN SOME METHODS OF GrammarGINVAE.
11
+ THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
12
+ """
13
+
14
+ import numpy as np
15
+ import logging
16
+
17
+ import torch
18
+ from torch.autograd import Variable
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from torch.nn.modules.loss import _Loss
22
+
23
+ from torch_geometric.nn import MessagePassing
24
+ from torch_geometric.nn import global_add_pool
25
+
26
+
27
+ from ..graph_grammar.graph_grammar.symbols import NTSymbol
28
+ from ..graph_grammar.nn.encoder import EncoderBase
29
+ from ..graph_grammar.nn.decoder import DecoderBase
30
+
31
+ def get_atom_edge_feature_dims():
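+ # Return the number of categories for each atom feature (x_map) and bond feature (e_map)
+ # defined in torch_geometric.utils.smiles, so every categorical feature can get its own nn.Embedding.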
32
+ from torch_geometric.utils.smiles import x_map, e_map
33
+ func = lambda x: len(x[1])
34
+ return list(map(func, x_map.items())), list(map(func, e_map.items()))
35
+
36
+
37
+ class FeatureEmbedding(nn.Module):
38
+ def __init__(self, input_dims, embedded_dim):
39
+ super().__init__()
40
+ self.embedding_list = nn.ModuleList()
41
+ for dim in input_dims:
42
+ embedding = nn.Embedding(dim, embedded_dim)
43
+ self.embedding_list.append(embedding)
44
+
45
+ def forward(self, x):
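+ # Each integer column of x indexes its own embedding table; the per-feature embeddings are summed.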
46
+ output = 0
47
+ for i in range(x.shape[1]):
48
+ input = x[:, i].to(torch.int)
49
+ device = next(self.parameters()).device
50
+ if device != input.device:
51
+ input = input.to(device)
52
+ emb = self.embedding_list[i](input)
53
+ output += emb
54
+ return output
55
+
56
+
57
+ class GRUEncoder(EncoderBase):
58
+
59
+ def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
60
+ bidirectional: bool, dropout: float, batch_size: int, rank: int=-1,
61
+ no_dropout: bool=False):
62
+ super().__init__()
63
+ self.input_dim = input_dim
64
+ self.hidden_dim = hidden_dim
65
+ self.num_layers = num_layers
66
+ self.bidirectional = bidirectional
67
+ self.dropout = dropout
68
+ self.batch_size = batch_size
69
+ self.rank = rank
70
+ self.model = nn.GRU(input_size=self.input_dim,
71
+ hidden_size=self.hidden_dim,
72
+ num_layers=self.num_layers,
73
+ batch_first=True,
74
+ bidirectional=self.bidirectional,
75
+ dropout=self.dropout if not no_dropout else 0)
76
+ if self.rank >= 0:
77
+ if torch.cuda.is_available():
78
+ self.model = self.model.to(rank)
79
+ else:
80
+ # support mac mps
81
+ self.model = self.model.to(torch.device("mps", rank))
82
+ self.init_hidden(self.batch_size)
83
+
84
+ def init_hidden(self, bsize):
85
+ self.h0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
86
+ min(self.batch_size, bsize),
87
+ self.hidden_dim),
88
+ requires_grad=False)
89
+ if self.rank >= 0:
90
+ if torch.cuda.is_available():
91
+ self.h0 = self.h0.to(self.rank)
92
+ else:
93
+ # support mac mps
94
+ self.h0 = self.h0.to(torch.device("mps", self.rank))
95
+
96
+ def to(self, device):
97
+ newself = super().to(device)
98
+ newself.model = newself.model.to(device)
99
+ newself.h0 = newself.h0.to(device)
100
+ newself.rank = next(newself.parameters()).get_device()
101
+ return newself
102
+
103
+ def forward(self, in_seq_emb):
104
+ ''' forward model
105
+
106
+ Parameters
107
+ ----------
108
+ in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)
109
+
110
+ Returns
111
+ -------
112
+ hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
113
+ '''
114
+ # Kishi: I think original MHG had this init_hidden()
115
+ self.init_hidden(in_seq_emb.size(0))
116
+ max_len = in_seq_emb.size(1)
117
+ hidden_seq_emb, self.h0 = self.model(
118
+ in_seq_emb, self.h0)
119
+ # As described in Returns, reshape hidden_seq_emb: (batch_size, seq_len, (1 or 2) * hidden_size) -->
120
+ # (batch_size, seq_len, 1 or 2, hidden_size).
121
+ # The raw output of a bidirectional GRU/LSTM concatenates the two directions
122
+ # along the feature dimension
123
+ # (first half for the forward RNN, latter half for the backward RNN),
124
+ # so repack it into a friendlier format with one slot per direction.
125
+ hidden_seq_emb = hidden_seq_emb.view(-1,
126
+ max_len,
127
+ 1 + self.bidirectional,
128
+ self.hidden_dim)
129
+ return hidden_seq_emb
130
+
131
+
132
+ class GRUDecoder(DecoderBase):
133
+
134
+ def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
135
+ dropout: float, batch_size: int, rank: int=-1,
136
+ no_dropout: bool=False):
137
+ super().__init__()
138
+ self.input_dim = input_dim
139
+ self.hidden_dim = hidden_dim
140
+ self.num_layers = num_layers
141
+ self.dropout = dropout
142
+ self.batch_size = batch_size
143
+ self.rank = rank
144
+ self.model = nn.GRU(input_size=self.input_dim,
145
+ hidden_size=self.hidden_dim,
146
+ num_layers=self.num_layers,
147
+ batch_first=True,
148
+ bidirectional=False,
149
+ dropout=self.dropout if not no_dropout else 0
150
+ )
151
+ if self.rank >= 0:
152
+ if torch.cuda.is_available():
153
+ self.model = self.model.to(self.rank)
154
+ else:
155
+ # support mac mps
156
+ self.model = self.model.to(torch.device("mps", self.rank))
157
+ self.init_hidden(self.batch_size)
158
+
159
+ def init_hidden(self, bsize):
160
+ self.hidden_dict['h'] = torch.zeros((self.num_layers,
161
+ min(self.batch_size, bsize),
162
+ self.hidden_dim),
163
+ requires_grad=False)
164
+ if self.rank >= 0:
165
+ if torch.cuda.is_available():
166
+ self.hidden_dict['h'] = self.hidden_dict['h'].to(self.rank)
167
+ else:
168
+ self.hidden_dict['h'] = self.hidden_dict['h'].to(torch.device("mps", self.rank))
169
+
170
+ def to(self, device):
171
+ newself = super().to(device)
172
+ newself.model = newself.model.to(device)
173
+ for k in self.hidden_dict.keys():
174
+ newself.hidden_dict[k] = newself.hidden_dict[k].to(device)
175
+ newself.rank = next(newself.parameters()).get_device()
176
+ return newself
177
+
178
+ def forward_one_step(self, tgt_emb_in):
179
+ ''' one-step forward model
180
+
181
+ Parameters
182
+ ----------
183
+ tgt_emb_in : Tensor, shape (batch_size, input_dim)
184
+
185
+ Returns
186
+ -------
187
+ Tensor, shape (batch_size, hidden_dim)
188
+ '''
189
+ bsize = tgt_emb_in.size(0)
190
+ tgt_emb_out, self.hidden_dict['h'] \
191
+ = self.model(tgt_emb_in.view(bsize, 1, -1),
192
+ self.hidden_dict['h'])
193
+ return tgt_emb_out
194
+
195
+
196
+ class NodeMLP(nn.Module):
197
+ def __init__(self, input_size, output_size, hidden_size):
198
+ super().__init__()
199
+ self.lin1 = nn.Linear(input_size, hidden_size)
200
+ self.nbat = nn.BatchNorm1d(hidden_size)
201
+ self.lin2 = nn.Linear(hidden_size, output_size)
202
+
203
+ def forward(self, x):
204
+ x = self.lin1(x)
205
+ x = self.nbat(x)
206
+ x = x.relu()
207
+ x = self.lin2(x)
208
+ return x
209
+
210
+
211
+ class GINLayer(MessagePassing):
212
+ def __init__(self, node_input_size, node_output_size, node_hidden_size, edge_input_size):
213
+ super().__init__()
214
+ self.node_mlp = NodeMLP(node_input_size, node_output_size, node_hidden_size)
215
+ self.edge_mlp = FeatureEmbedding(edge_input_size, node_output_size)
216
+ self.eps = nn.Parameter(torch.tensor([0.0]))
217
+
218
+ def forward(self, x, edge_index, edge_attr):
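+ # GIN-style update: aggregate edge-conditioned neighbor messages, combine them with (1 + eps) * x,
+ # then transform the result with the node MLP.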
219
+ msg = self.propagate(edge_index, x=x ,edge_attr=edge_attr)
220
+ x = (1.0 + self.eps) * x + msg
221
+ x = x.relu()
222
+ x = self.node_mlp(x)
223
+ return x
224
+
225
+ def message(self, x_j, edge_attr):
226
+ edge_attr = self.edge_mlp(edge_attr)
227
+ x_j = x_j + edge_attr
228
+ x_j = x_j.relu()
229
+ return x_j
230
+
231
+ def update(self, aggr_out):
232
+ return aggr_out
233
+
234
+ #TODO implement the case where features of atoms and edges are considered
235
+ # Check GraphMVP and ogb (open graph benchmark) to realize this
236
+ class GIN(torch.nn.Module):
237
+ def __init__(self, node_feature_size, edge_feature_size, hidden_channels=64,
238
+ proximity_size=3, dropout=0.1):
239
+ super().__init__()
240
+ #print("(num node features, num edge features)=", (node_feature_size, edge_feature_size))
241
+ hsize = hidden_channels * 2
242
+ atom_dim, edge_dim = get_atom_edge_feature_dims()
243
+ self.trans = FeatureEmbedding(atom_dim, hidden_channels)
244
+ ml = []
245
+ for _ in range(proximity_size):
246
+ ml.append(GINLayer(hidden_channels, hidden_channels, hsize, edge_dim))
247
+ self.mlist = nn.ModuleList(ml)
248
+ #It is possible to calculate relu with x.relu() where x is an output
249
+ #self.activations = nn.ModuleList(actl)
250
+ self.dropout = dropout
251
+ self.proximity_size = proximity_size
252
+
253
+ def forward(self, x, edge_index, edge_attr, batch_size):
254
+ x = x.to(torch.float)
255
+ #print("before: edge_weight.shape=", edge_attr.shape)
256
+ edge_attr = edge_attr.to(torch.float)
257
+ #print("after: edge_weight.shape=", edge_attr.shape)
258
+ x = self.trans(x)
259
+ # TODO Check if this x is consistent with global_add_pool
260
+ hlist = [global_add_pool(x, batch_size)]
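+ # Keep a graph-level readout (global add pool) of the input embedding and, below, of every GIN layer;
+ # the pooled vectors are concatenated at the end, giving hidden_channels * (1 + proximity_size) dimensions.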
261
+ for id, m in enumerate(self.mlist):
262
+ x = m(x, edge_index=edge_index, edge_attr=edge_attr)
263
+ #print("Done with one layer")
264
+ ###if id != self.proximity_size - 1:
265
+ x = x.relu()
266
+ x = F.dropout(x, p=self.dropout, training=self.training)
267
+ #h = global_mean_pool(x, batch_size)
268
+ h = global_add_pool(x, batch_size)
269
+ hlist.append(h)
270
+ #print("Done with one relu call: x.shape=", x.shape)
271
+ #print("calling golbal mean pool")
272
+ #print("calling dropout x.shape=", x.shape)
273
+ #print("x=", x)
274
+ #print("hlist[0].shape=", hlist[0].shape)
275
+ x = torch.cat(hlist, dim=1)
276
+ #print("x.shape=", x.shape)
277
+ x = F.dropout(x, p=self.dropout, training=self.training)
278
+
279
+ return x
280
+
281
+
282
+ # TODO copied from MHG implementation and adapted here.
283
+ class GrammarSeq2SeqVAE(nn.Module):
284
+
285
+ '''
286
+ Variational seq2seq with grammar.
287
+ TODO: rewrite this class using mixin
288
+ '''
289
+
290
+ def __init__(self, hrg, rank=-1, latent_dim=64, max_len=80,
291
+ batch_size=64, padding_idx=-1,
292
+ encoder_params={'hidden_dim': 384, 'num_layers': 3, 'bidirectional': True,
293
+ 'dropout': 0.1},
294
+ decoder_params={'hidden_dim': 384, #'num_layers': 2,
295
+ 'num_layers': 3,
296
+ 'dropout': 0.1},
297
+ prod_rule_embed_params={'out_dim': 128},
298
+ no_dropout=False):
299
+
300
+ super().__init__()
301
+ # TODO USE GRU FOR ENCODING AND DECODING
302
+ self.hrg = hrg
303
+ self.rank = rank
304
+ self.prod_rule_corpus = hrg.prod_rule_corpus
305
+ self.prod_rule_embed_params = prod_rule_embed_params
306
+
307
+ self.vocab_size = hrg.num_prod_rule + 1
308
+ self.batch_size = batch_size
309
+ self.padding_idx = np.mod(padding_idx, self.vocab_size)
310
+ self.no_dropout = no_dropout
311
+
312
+ self.latent_dim = latent_dim
313
+ self.max_len = max_len
314
+ self.encoder_params = encoder_params
315
+ self.decoder_params = decoder_params
316
+
317
+ # TODO Simple embedding is used. Check if a domain-dependent embedding works or not.
318
+ embed_out_dim = self.prod_rule_embed_params['out_dim']
319
+ #use MolecularProdRuleEmbedding later on
320
+ self.src_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
321
+ padding_idx=self.padding_idx)
322
+ self.tgt_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
323
+ padding_idx=self.padding_idx)
324
+
325
+ # USE a GRU-based encoder in MHG
326
+ self.encoder = GRUEncoder(input_dim=embed_out_dim, batch_size=self.batch_size,
327
+ rank=self.rank, no_dropout=self.no_dropout,
328
+ **self.encoder_params)
329
+
330
+ lin_dim = (self.encoder_params.get('bidirectional', False) + 1) * self.encoder_params['hidden_dim']
331
+ lin_out_dim = self.latent_dim
332
+ self.hidden2mean = nn.Linear(lin_dim, lin_out_dim, bias=False)
333
+ self.hidden2logvar = nn.Linear(lin_dim, lin_out_dim)
334
+
335
+ # USE a GRU-based decoder in MHG
336
+ self.decoder = GRUDecoder(input_dim=embed_out_dim, batch_size=self.batch_size,
337
+ rank=self.rank, no_dropout=self.no_dropout, **self.decoder_params)
338
+ self.latent2tgt_emb = nn.Linear(self.latent_dim, embed_out_dim)
339
+ self.latent2hidden_dict = nn.ModuleDict()
340
+ dec_lin_out_dim = self.decoder_params['hidden_dim']
341
+ for each_hidden in self.decoder.hidden_dict.keys():
342
+ self.latent2hidden_dict[each_hidden] = nn.Linear(self.latent_dim, dec_lin_out_dim)
343
+ if self.rank >= 0:
344
+ if torch.cuda.is_available():
345
+ self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(self.rank)
346
+ else:
347
+ # support mac mps
348
+ self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(torch.device("mps", self.rank))
349
+
350
+ self.dec2vocab = nn.Linear(dec_lin_out_dim, self.vocab_size)
351
+ self.encoder.init_hidden(self.batch_size)
352
+ self.decoder.init_hidden(self.batch_size)
353
+
354
+ # TODO Do we need this?
355
+ if hasattr(self.src_embedding, 'weight'):
356
+ self.src_embedding.weight.data.uniform_(-0.1, 0.1)
357
+ if hasattr(self.tgt_embedding, 'weight'):
358
+ self.tgt_embedding.weight.data.uniform_(-0.1, 0.1)
359
+
360
+ self.encoder.init_hidden(self.batch_size)
361
+ self.decoder.init_hidden(self.batch_size)
362
+
363
+ def to(self, device):
364
+ newself = super().to(device)
365
+ newself.src_embedding = newself.src_embedding.to(device)
366
+ newself.tgt_embedding = newself.tgt_embedding.to(device)
367
+ newself.encoder = newself.encoder.to(device)
368
+ newself.decoder = newself.decoder.to(device)
369
+ newself.dec2vocab = newself.dec2vocab.to(device)
370
+ newself.hidden2mean = newself.hidden2mean.to(device)
371
+ newself.hidden2logvar = newself.hidden2logvar.to(device)
372
+ newself.latent2tgt_emb = newself.latent2tgt_emb.to(device)
373
+ newself.latent2hidden_dict = newself.latent2hidden_dict.to(device)
374
+ return newself
375
+
376
+ def forward(self, in_seq, out_seq):
377
+ ''' forward model
378
+
379
+ Parameters
380
+ ----------
381
+ in_seq : Variable, shape (batch_size, length)
382
+ each element corresponds to word index.
383
+ where the index should be less than `vocab_size`
384
+
385
+ Returns
386
+ -------
387
+ Variable, shape (batch_size, length, vocab_size)
388
+ logit of each word (applying softmax yields the probability)
389
+ '''
390
+ mu, logvar = self.encode(in_seq)
391
+ z = self.reparameterize(mu, logvar)
392
+ return self.decode(z, out_seq), mu, logvar
393
+
394
+ def encode(self, in_seq):
395
+ src_emb = self.src_embedding(in_seq)
396
+ src_h = self.encoder.forward(src_emb)
397
+ if self.encoder_params.get('bidirectional', False):
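+ # Concatenate the forward direction's state at the last position with the backward direction's state at the first position.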
398
+ concat_src_h = torch.cat((src_h[:, -1, 0, :], src_h[:, 0, 1, :]), dim=1)
399
+ return self.hidden2mean(concat_src_h), self.hidden2logvar(concat_src_h)
400
+ else:
401
+ return self.hidden2mean(src_h[:, -1, :]), self.hidden2logvar(src_h[:, -1, :])
402
+
403
+ def reparameterize(self, mu, logvar, training=True):
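+ # Reparameterization trick: z = mu + std * eps with eps ~ N(0, I), keeping sampling differentiable w.r.t. mu and logvar.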
404
+ if training:
405
+ std = logvar.mul(0.5).exp_()
406
+ device = next(self.parameters()).device
407
+ eps = Variable(std.data.new(std.size()).normal_())
408
+ if device != eps.get_device():
409
+ eps.to(device)
410
+ return eps.mul(std).add_(mu)
411
+ else:
412
+ return mu
413
+
414
+ #TODO Not tested. Need to implement this in case of molecular structure generation
415
+ def sample(self, sample_size=-1, deterministic=True, return_z=False):
416
+ self.eval()
417
+ self.init_hidden()
418
+ if sample_size == -1:
419
+ sample_size = self.batch_size
420
+
421
+ num_iter = int(np.ceil(sample_size / self.batch_size))
422
+ hg_list = []
423
+ z_list = []
424
+ for _ in range(num_iter):
425
+ z = Variable(torch.normal(
426
+ torch.zeros(self.batch_size, self.latent_dim),
427
+ torch.ones(self.batch_size * self.latent_dim))).cuda()
428
+ _, each_hg_list = self.decode(z, deterministic=deterministic)
429
+ z_list.append(z)
430
+ hg_list += each_hg_list
431
+ z = torch.cat(z_list)[:sample_size]
432
+ hg_list = hg_list[:sample_size]
433
+ if return_z:
434
+ return hg_list, z.cpu().detach().numpy()
435
+ else:
436
+ return hg_list
437
+
438
+ def decode(self, z=None, out_seq=None, deterministic=True):
439
+ if z is None:
440
+ z = Variable(torch.normal(
441
+ torch.zeros(self.batch_size, self.latent_dim),
442
+ torch.ones(self.batch_size * self.latent_dim)))
443
+ if self.rank >= 0:
444
+ z = z.to(next(self.parameters()).device)
445
+
446
+ hidden_dict_0 = {}
447
+ for each_hidden in self.latent2hidden_dict.keys():
448
+ hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
449
+ bsize = z.size(0)
450
+ self.decoder.init_hidden(bsize)
451
+ self.decoder.feed_hidden(hidden_dict_0)
452
+
453
+ if out_seq is not None:
454
+ tgt_emb0 = self.latent2tgt_emb(z)
455
+ tgt_emb0 = tgt_emb0.view(tgt_emb0.shape[0], 1, tgt_emb0.shape[1])
456
+ out_seq_emb = self.tgt_embedding(out_seq)
457
+ tgt_emb = torch.cat((tgt_emb0, out_seq_emb), dim=1)[:, :-1, :]
458
+ tgt_emb_pred_list = []
459
+ for each_idx in range(self.max_len):
460
+ tgt_emb_pred = self.decoder.forward_one_step(tgt_emb[:, each_idx, :].view(bsize, 1, -1))
461
+ tgt_emb_pred_list.append(tgt_emb_pred)
462
+ vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
463
+ return vocab_logit
464
+ else:
465
+ with torch.no_grad():
466
+ tgt_emb = self.latent2tgt_emb(z)
467
+ tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
468
+ tgt_emb_pred_list = []
469
+ stack_list = []
470
+ hg_list = []
471
+ nt_symbol_list = []
472
+ nt_edge_list = []
473
+ gen_finish_list = []
474
+ for _ in range(bsize):
475
+ stack_list.append([])
476
+ hg_list.append(None)
477
+ nt_symbol_list.append(NTSymbol(degree=0,
478
+ is_aromatic=False,
479
+ bond_symbol_list=[]))
480
+ nt_edge_list.append(None)
481
+ gen_finish_list.append(False)
482
+
483
+ for idx in range(self.max_len):
484
+ tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
485
+ tgt_emb_pred_list.append(tgt_emb_pred)
486
+ vocab_logit = self.dec2vocab(tgt_emb_pred)
487
+ for each_batch_idx in range(bsize):
488
+ if not gen_finish_list[each_batch_idx]: # if generation has not finished
489
+ # get production rule greedily
490
+ prod_rule = self.hrg.prod_rule_corpus.sample(vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
491
+ nt_symbol_list[each_batch_idx],
492
+ deterministic=deterministic)
493
+ # convert production rule into an index
494
+ tgt_id = self.hrg.prod_rule_list.index(prod_rule)
495
+ # apply the production rule
496
+ hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
497
+ # add non-terminals to the stack
498
+ stack_list[each_batch_idx].extend(nt_edges[::-1])
499
+ # if the stack size is 0, generation has finished!
500
+ if len(stack_list[each_batch_idx]) == 0:
501
+ gen_finish_list[each_batch_idx] = True
502
+ else:
503
+ nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
504
+ nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(nt_edge_list[each_batch_idx])['symbol']
505
+ else:
506
+ tgt_id = np.mod(self.padding_idx, self.vocab_size)
507
+ indice_tensor = torch.LongTensor([tgt_id])
508
+ device = next(self.parameters()).device
509
+ if indice_tensor.device != device:
510
+ indice_tensor = indice_tensor.to(device)
511
+ tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)
512
+ vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
513
+ #for id, v in enumerate(gen_finish_list):
514
+ #if not v:
515
+ # print("bacth id={} not finished generating a sequence: ".format(id))
516
+ return gen_finish_list, vocab_logit, hg_list
517
+
518
+
519
+ # TODO A lot of duplicates with GrammarVAE. Clean up it if necessary
520
+ class GrammarGINVAE(nn.Module):
521
+
522
+ '''
523
+ Variational autoencoder based on GIN and grammar
524
+ '''
525
+
526
+ def __init__(self, hrg, rank=-1, max_len=80,
527
+ batch_size=64, padding_idx=-1,
528
+ encoder_params={'node_feature_size': 4, 'edge_feature_size': 3,
529
+ 'hidden_channels': 64, 'proximity_size': 3,
530
+ 'dropout': 0.1},
531
+ decoder_params={'hidden_dim': 384, 'num_layers': 3,
532
+ 'dropout': 0.1},
533
+ prod_rule_embed_params={'out_dim': 128},
534
+ no_dropout=False):
535
+
536
+ super().__init__()
537
+ # TODO USE GRU FOR ENCODING AND DECODING
538
+ self.hrg = hrg
539
+ self.rank = rank
540
+ self.prod_rule_corpus = hrg.prod_rule_corpus
541
+ self.prod_rule_embed_params = prod_rule_embed_params
542
+
543
+ self.vocab_size = hrg.num_prod_rule + 1
544
+ self.batch_size = batch_size
545
+ self.padding_idx = np.mod(padding_idx, self.vocab_size)
546
+ self.no_dropout = no_dropout
547
+ self.max_len = max_len
548
+ self.encoder_params = encoder_params
549
+ self.decoder_params = decoder_params
550
+
551
+ # TODO Simple embedding is used. Check if a domain-dependent embedding works or not.
552
+ embed_out_dim = self.prod_rule_embed_params['out_dim']
553
+ #use MolecularProdRuleEmbedding later on
554
+ self.tgt_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
555
+ padding_idx=self.padding_idx)
556
+
557
+ self.encoder = GIN(**self.encoder_params)
558
+ self.latent_dim = self.encoder_params['hidden_channels']
559
+ self.proximity_size = self.encoder_params['proximity_size']
560
+ hidden_dim = self.decoder_params['hidden_dim']
561
+ self.hidden2mean = nn.Linear(self.latent_dim * (1 + self.proximity_size), self.latent_dim, bias=False)
562
+ self.hidden2logvar = nn.Linear(self.latent_dim * (1 + self.proximity_size), self.latent_dim)
563
+
564
+ self.decoder = GRUDecoder(input_dim=embed_out_dim, batch_size=self.batch_size,
565
+ rank=self.rank, no_dropout=self.no_dropout, **self.decoder_params)
566
+ self.latent2tgt_emb = nn.Linear(self.latent_dim, embed_out_dim)
567
+ self.latent2hidden_dict = nn.ModuleDict()
568
+ for each_hidden in self.decoder.hidden_dict.keys():
569
+ self.latent2hidden_dict[each_hidden] = nn.Linear(self.latent_dim, hidden_dim)
570
+ if self.rank >= 0:
571
+ if torch.cuda.is_available():
572
+ self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(self.rank)
573
+ else:
574
+ # support mac mps
575
+ self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(torch.device("mps", self.rank))
576
+
577
+ self.dec2vocab = nn.Linear(hidden_dim, self.vocab_size)
578
+ self.decoder.init_hidden(self.batch_size)
579
+
580
+ # TODO Do we need this?
581
+ if hasattr(self.tgt_embedding, 'weight'):
582
+ self.tgt_embedding.weight.data.uniform_(-0.1, 0.1)
583
+ self.decoder.init_hidden(self.batch_size)
584
+
585
+ def to(self, device):
586
+ newself = super().to(device)
587
+ newself.encoder = newself.encoder.to(device)
588
+ newself.decoder = newself.decoder.to(device)
589
+ newself.rank = next(newself.encoder.parameters()).get_device()
590
+ return newself
591
+
592
+ def forward(self, x, edge_index, edge_attr, batch_size, out_seq=None, sched_prob = None):
593
+ mu, logvar = self.encode(x, edge_index, edge_attr, batch_size)
594
+ z = self.reparameterize(mu, logvar)
595
+ return self.decode(z, out_seq, sched_prob=sched_prob), mu, logvar
596
+
597
+ #TODO Not tested. Need to implement this in case of molecular structure generation
598
+ def sample(self, sample_size=-1, deterministic=True, return_z=False):
599
+ self.eval()
600
+ self.init_hidden()
601
+ if sample_size == -1:
602
+ sample_size = self.batch_size
603
+
604
+ num_iter = int(np.ceil(sample_size / self.batch_size))
605
+ hg_list = []
606
+ z_list = []
607
+ for _ in range(num_iter):
608
+ z = Variable(torch.normal(
609
+ torch.zeros(self.batch_size, self.latent_dim),
610
+ torch.ones(self.batch_size * self.latent_dim))).cuda()
611
+ _, each_hg_list = self.decode(z, deterministic=deterministic)
612
+ z_list.append(z)
613
+ hg_list += each_hg_list
614
+ z = torch.cat(z_list)[:sample_size]
615
+ hg_list = hg_list[:sample_size]
616
+ if return_z:
617
+ return hg_list, z.cpu().detach().numpy()
618
+ else:
619
+ return hg_list
620
+
621
+ def decode(self, z=None, out_seq=None, deterministic=True, sched_prob=None):
622
+ if z is None:
623
+ z = Variable(torch.normal(
624
+ torch.zeros(self.batch_size, self.latent_dim),
625
+ torch.ones(self.batch_size * self.latent_dim)))
626
+ if self.rank >= 0:
627
+ z = z.to(next(self.parameters()).device)
628
+
629
+ hidden_dict_0 = {}
630
+ for each_hidden in self.latent2hidden_dict.keys():
631
+ hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
632
+ bsize = z.size(0)
633
+ self.decoder.init_hidden(bsize)
634
+ self.decoder.feed_hidden(hidden_dict_0)
635
+
636
+ if out_seq is not None:
637
+ tgt_emb0 = self.latent2tgt_emb(z)
638
+ tgt_emb0 = tgt_emb0.view(tgt_emb0.shape[0], 1, tgt_emb0.shape[1])
639
+ out_seq_emb = self.tgt_embedding(out_seq)
640
+ tgt_emb = torch.cat((tgt_emb0, out_seq_emb), dim=1)[:, :-1, :]
641
+ tgt_emb_pred_list = []
642
+ tgt_emb_pred = None
643
+ for each_idx in range(self.max_len):
644
+ if tgt_emb_pred is None or sched_prob is None or torch.rand(1)[0] <= sched_prob:
645
+ inp = tgt_emb[:, each_idx, :].view(bsize, 1, -1)
646
+ else:
647
+ cur_logit = self.dec2vocab(tgt_emb_pred)
648
+ yi = torch.argmax(cur_logit, dim=2)
649
+ inp = self.tgt_embedding(yi)
650
+ tgt_emb_pred = self.decoder.forward_one_step(inp)
651
+ tgt_emb_pred_list.append(tgt_emb_pred)
652
+ vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
653
+ return vocab_logit
654
+ else:
655
+ with torch.no_grad():
656
+ tgt_emb = self.latent2tgt_emb(z)
657
+ tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
658
+ tgt_emb_pred_list = []
659
+ stack_list = []
660
+ hg_list = []
661
+ nt_symbol_list = []
662
+ nt_edge_list = []
663
+ gen_finish_list = []
664
+ for _ in range(bsize):
665
+ stack_list.append([])
666
+ hg_list.append(None)
667
+ nt_symbol_list.append(NTSymbol(degree=0,
668
+ is_aromatic=False,
669
+ bond_symbol_list=[]))
670
+ nt_edge_list.append(None)
671
+ gen_finish_list.append(False)
672
+
673
+ for _ in range(self.max_len):
674
+ tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
675
+ tgt_emb_pred_list.append(tgt_emb_pred)
676
+ vocab_logit = self.dec2vocab(tgt_emb_pred)
677
+ for each_batch_idx in range(bsize):
678
+ if not gen_finish_list[each_batch_idx]: # if generation has not finished
679
+ # get production rule greedily
680
+ prod_rule = self.hrg.prod_rule_corpus.sample(vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
681
+ nt_symbol_list[each_batch_idx],
682
+ deterministic=deterministic)
683
+ # convert production rule into an index
684
+ tgt_id = self.hrg.prod_rule_list.index(prod_rule)
685
+ # apply the production rule
686
+ hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
687
+ # add non-terminals to the stack
688
+ stack_list[each_batch_idx].extend(nt_edges[::-1])
689
+ # if the stack size is 0, generation has finished!
690
+ if len(stack_list[each_batch_idx]) == 0:
691
+ gen_finish_list[each_batch_idx] = True
692
+ else:
693
+ nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
694
+ nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(nt_edge_list[each_batch_idx])['symbol']
695
+ else:
696
+ tgt_id = np.mod(self.padding_idx, self.vocab_size)
697
+ indice_tensor = torch.LongTensor([tgt_id])
698
+ if self.rank >= 0:
699
+ indice_tensor = indice_tensor.to(next(self.parameters()).device)
700
+ tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)
701
+ vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
702
+ return gen_finish_list, vocab_logit, hg_list
703
+
704
+ #TODO Not tested. Need to implement this in case of molecular structure generation
705
+ def conditional_distribution(self, z, tgt_id_list):
706
+ self.eval()
707
+ self.init_hidden()
708
+ z = z.cuda()
709
+
710
+ hidden_dict_0 = {}
711
+ for each_hidden in self.latent2hidden_dict.keys():
712
+ hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
713
+ self.decoder.feed_hidden(hidden_dict_0)
714
+
715
+ with torch.no_grad():
716
+ tgt_emb = self.latent2tgt_emb(z)
717
+ tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
718
+ nt_symbol_list = []
719
+ stack_list = []
720
+ hg_list = []
721
+ nt_edge_list = []
722
+ gen_finish_list = []
723
+ for _ in range(self.batch_size):
724
+ nt_symbol_list.append(NTSymbol(degree=0,
725
+ is_aromatic=False,
726
+ bond_symbol_list=[]))
727
+ stack_list.append([])
728
+ hg_list.append(None)
729
+ nt_edge_list.append(None)
730
+ gen_finish_list.append(False)
731
+
732
+ for each_position in range(len(tgt_id_list[0])):
733
+ tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
734
+ for each_batch_idx in range(self.batch_size):
735
+ if not gen_finish_list[each_batch_idx]: # if generation has not finished
736
+ # use the prespecified target ids
737
+ tgt_id = tgt_id_list[each_batch_idx][each_position]
738
+ prod_rule = self.hrg.prod_rule_list[tgt_id]
739
+ # apply the production rule
740
+ hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
741
+ # add non-terminals to the stack
742
+ stack_list[each_batch_idx].extend(nt_edges[::-1])
743
+ # if the stack size is 0, generation has finished!
744
+ if len(stack_list[each_batch_idx]) == 0:
745
+ gen_finish_list[each_batch_idx] = True
746
+ else:
747
+ nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
748
+ nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(nt_edge_list[each_batch_idx])['symbol']
749
+ else:
750
+ tgt_id = np.mod(self.padding_idx, self.vocab_size)
751
+ indice_tensor = torch.LongTensor([tgt_id])
752
+ indice_tensor = indice_tensor.cuda()
753
+ tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)
754
+
755
+ # last one step
756
+ conditional_logprob_list = []
757
+ tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
758
+ vocab_logit = self.dec2vocab(tgt_emb_pred)
759
+ for each_batch_idx in range(self.batch_size):
760
+ if not gen_finish_list[each_batch_idx]: # if generation has not finished
761
+ # get production rule greedily
762
+ masked_logprob = self.hrg.prod_rule_corpus.masked_logprob(
763
+ vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
764
+ nt_symbol_list[each_batch_idx])
765
+ conditional_logprob_list.append(masked_logprob)
766
+ else:
767
+ conditional_logprob_list.append(None)
768
+ return conditional_logprob_list
769
+
770
+ #TODO Not tested. Need to implement this in case of molecular structure generation
771
+ def decode_with_beam_search(self, z, beam_width=1):
772
+ ''' Decode a latent vector using beam search.
773
+
774
+ Parameters
775
+ ----------
776
+ z
777
+ latent vector
778
+ beam_width : int
779
+ parameter for beam search
780
+
781
+ Returns
782
+ -------
783
+ List of Hypergraphs
784
+ '''
785
+ if self.batch_size != 1:
786
+ raise ValueError('this method works only under batch_size=1')
787
+ if self.padding_idx != -1:
788
+ raise ValueError('this method works only under padding_idx=-1')
789
+ top_k_tgt_id_list = [[]] * beam_width
790
+ logprob_list = [0.] * beam_width
791
+
792
+ for each_len in range(self.max_len):
793
+ expanded_logprob_list = np.repeat(logprob_list, self.vocab_size) # including padding_idx
794
+ expanded_length_list = np.array([0] * (beam_width * self.vocab_size))
795
+ for each_beam_idx, each_candidate in enumerate(top_k_tgt_id_list):
796
+ conditional_logprob = self.conditional_distribution(z, [each_candidate])[0]
797
+ if conditional_logprob is None:
798
+ expanded_logprob_list[(each_beam_idx + 1) * self.vocab_size - 1]\
799
+ = logprob_list[each_beam_idx]
800
+ expanded_logprob_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size - 1]\
801
+ = -np.inf
802
+ expanded_length_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size]\
803
+ = len(each_candidate)
804
+ else:
805
+ expanded_logprob_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size - 1]\
806
+ = logprob_list[each_beam_idx] + conditional_logprob
807
+ expanded_logprob_list[(each_beam_idx + 1) * self.vocab_size - 1]\
808
+ = -np.inf
809
+ expanded_length_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size]\
810
+ = len(each_candidate) + 1
811
+ score_list = np.array(expanded_logprob_list) / np.array(expanded_length_list)
812
+ if each_len == 0:
813
+ top_k_list = np.argsort(score_list[:self.vocab_size])[::-1][:beam_width]
814
+ else:
815
+ top_k_list = np.argsort(score_list)[::-1][:beam_width]
816
+ next_top_k_tgt_id_list = []
817
+ next_logprob_list = []
818
+ for each_top_k in top_k_list:
819
+ beam_idx = each_top_k // self.vocab_size
820
+ vocab_idx = each_top_k % self.vocab_size
821
+ if vocab_idx == self.vocab_size - 1:
822
+ next_top_k_tgt_id_list.append(top_k_tgt_id_list[beam_idx])
823
+ next_logprob_list.append(expanded_logprob_list[each_top_k])
824
+ else:
825
+ next_top_k_tgt_id_list.append(top_k_tgt_id_list[beam_idx] + [vocab_idx])
826
+ next_logprob_list.append(expanded_logprob_list[each_top_k])
827
+ top_k_tgt_id_list = next_top_k_tgt_id_list
828
+ logprob_list = next_logprob_list
829
+
830
+ # construct hypergraphs
831
+ hg_list = []
832
+ for each_tgt_id_list in top_k_tgt_id_list:
833
+ hg = None
834
+ stack = []
835
+ nt_edge = None
836
+ for each_idx, each_prod_rule_id in enumerate(each_tgt_id_list):
837
+ prod_rule = self.hrg.prod_rule_list[each_prod_rule_id]
838
+ hg, nt_edges = prod_rule.applied_to(hg, nt_edge)
839
+ stack.extend(nt_edges[::-1])
840
+ try:
841
+ nt_edge = stack.pop()
842
+ except IndexError:
843
+ if each_idx == len(each_tgt_id_list) - 1:
844
+ break
845
+ else:
846
+ raise ValueError('some bugs')
847
+ hg_list.append(hg)
848
+ return hg_list
849
+
850
+ def graph_embed(self, x, edge_index, edge_attr, batch_size):
851
+ src_h = self.encoder.forward(x, edge_index, edge_attr, batch_size)
852
+ return src_h
853
+
854
+ def encode(self, x, edge_index, edge_attr, batch_size):
855
+ #print("device for src_emb=", src_emb.get_device())
856
+ #print("device for self.encoder=", next(self.encoder.parameters()).get_device())
857
+ src_h = self.graph_embed(x, edge_index, edge_attr, batch_size)
858
+ mu, lv = self.get_mean_var(src_h)
859
+ return mu, lv
860
+
861
+ def get_mean_var(self, src_h):
862
+ #src_h = torch.tanh(src_h)
863
+ mu = self.hidden2mean(src_h)
864
+ lv = self.hidden2logvar(src_h)
865
+ mu = torch.tanh(mu)
866
+ lv = torch.tanh(lv)
867
+ return mu, lv
868
+
869
+ def reparameterize(self, mu, logvar, training=True):
870
+ if training:
871
+ std = logvar.mul(0.5).exp_()
872
+ eps = Variable(std.data.new(std.size()).normal_())
873
+ if self.rank >= 0:
874
+ eps = eps.to(next(self.parameters()).device)
875
+ return eps.mul(std).add_(mu)
876
+ else:
877
+ return mu
878
+
879
+ # Copied from the MHG implementation and adapted
880
+ class GrammarVAELoss(_Loss):
881
+
882
+ '''
883
+ a loss function for Grammar VAE
884
+
885
+ Attributes
886
+ ----------
887
+ hrg : HyperedgeReplacementGrammar
888
+ beta : float
889
+ coefficient of KL divergence
890
+ '''
891
+
892
+ def __init__(self, rank, hrg, beta=1.0, **kwargs):
893
+ super().__init__(**kwargs)
894
+ self.hrg = hrg
895
+ self.beta = beta
896
+ self.rank = rank
897
+
898
+ def forward(self, mu, logvar, in_seq_pred, in_seq):
899
+ ''' compute VAE loss
900
+
901
+ Parameters
902
+ ----------
903
+ in_seq_pred : torch.Tensor, shape (batch_size, max_len, vocab_size)
904
+ logit
905
+ in_seq : torch.Tensor, shape (batch_size, max_len)
906
+ each element corresponds to a word id in vocabulary.
907
+ mu : torch.Tensor, shape (batch_size, hidden_dim)
908
+ logvar : torch.Tensor, shape (batch_size, hidden_dim)
909
+ mean and log variance of the normal distribution
910
+ '''
911
+ batch_size = in_seq_pred.shape[0]
912
+ max_len = in_seq_pred.shape[1]
913
+ vocab_size = in_seq_pred.shape[2]
914
+ mask = torch.zeros(in_seq_pred.shape)
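+ # Build a mask that keeps only the logits of production rules whose left-hand-side non-terminal matches
+ # that of the ground-truth rule at each step (plus the padding slot), restricting the cross-entropy
+ # to grammatically admissible choices.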
915
+
916
+ for each_batch in range(batch_size):
917
+ flag = True
918
+ for each_idx in range(max_len):
919
+ prod_rule_idx = in_seq[each_batch, each_idx]
920
+ if prod_rule_idx == vocab_size - 1:
921
+ #### DETERMINE WHETHER THIS SHOULD BE SKIPPED OR NOT
922
+ mask[each_batch, each_idx, prod_rule_idx] = 1
923
+ #break
924
+ continue
925
+ lhs = self.hrg.prod_rule_corpus.prod_rule_list[prod_rule_idx].lhs_nt_symbol
926
+ lhs_idx = self.hrg.prod_rule_corpus.nt_symbol_list.index(lhs)
927
+ mask[each_batch, each_idx, :-1] = torch.FloatTensor(self.hrg.prod_rule_corpus.lhs_in_prod_rule[lhs_idx])
928
+ if self.rank >= 0:
929
+ mask = mask.to(next(self.parameters()).device)
930
+ in_seq_pred = mask * in_seq_pred
931
+
932
+ cross_entropy = F.cross_entropy(
933
+ in_seq_pred.view(-1, vocab_size),
934
+ in_seq.view(-1),
935
+ reduction='sum',
936
+ #ignore_index=self.ignore_index if self.ignore_index is not None else -100
937
+ )
938
+ kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
939
+ return cross_entropy + self.beta * kl_div
940
+
941
+
942
+ class VAELoss(_Loss):
943
+ def __init__(self, beta=0.01):
944
+ super().__init__()
945
+ self.beta = beta
946
+
947
+ def forward(self, mean, log_var, dec_outputs, targets):
948
+
949
+ device = mean.get_device()
950
+ if device >= 0:
951
+ targets = targets.to(mean.get_device())
952
+ reconstruction = F.cross_entropy(dec_outputs.view(-1, dec_outputs.size(2)), targets.view(-1), reduction='sum')
953
+
954
+ KL = 0.5 * torch.sum(1 + log_var - mean ** 2 - torch.exp(log_var))
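+ # Note: as written, KL equals the negative KL divergence, so "- self.beta * KL" below adds
+ # beta times the KL divergence to the reconstruction term.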
955
+ loss = - self.beta * KL + reconstruction
956
+ return loss
models/mhg_model/notebooks/mhg-gnn_encoder_decoder_example.ipynb ADDED
@@ -0,0 +1,114 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "829ddc03",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import sys\n",
11
+ "sys.path.append('..')"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "id": "ea820e23",
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "import torch\n",
22
+ "import load"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "markdown",
27
+ "id": "b9a51fa8",
28
+ "metadata": {},
29
+ "source": [
30
+ "# Load MHG-GNN"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "id": "c6ea1fc8",
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "model_ckp = \"models/model_checkpoints/mhg_model/pickles/mhggnn_pretrained_model_radius7_1116_2023.pickle\"\n",
41
+ "\n",
42
+ "model = load.load(model_name = model_ckp)\n",
43
+ "if model is None:\n",
44
+ " print(\"Model not loaded, please check you have MHG pickle file\")\n",
45
+ "else:\n",
46
+ " print(\"MHG model loaded\")"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "markdown",
51
+ "id": "b4a0b557",
52
+ "metadata": {},
53
+ "source": [
54
+ "# Embeddings\n",
55
+ "\n",
56
+ "※ replace the smiles exaple list with your dataset"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": null,
62
+ "id": "c63a6be6",
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "with torch.no_grad():\n",
67
+ " repr = model.encode([\"CCO\", \"O=C=O\", \"OC(=O)c1ccccc1C(=O)O\"])\n",
68
+ " \n",
69
+ "# Print the latent vectors\n",
70
+ "print(repr)"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "markdown",
75
+ "id": "a59f9442",
76
+ "metadata": {},
77
+ "source": [
78
+ "# Decoding"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": null,
84
+ "id": "6a0d8a41",
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "orig = model.decode(repr)\n",
89
+ "print(orig)"
90
+ ]
91
+ }
92
+ ],
93
+ "metadata": {
94
+ "kernelspec": {
95
+ "display_name": "Python 3 (ipykernel)",
96
+ "language": "python",
97
+ "name": "python3"
98
+ },
99
+ "language_info": {
100
+ "codemirror_mode": {
101
+ "name": "ipython",
102
+ "version": 3
103
+ },
104
+ "file_extension": ".py",
105
+ "mimetype": "text/x-python",
106
+ "name": "python",
107
+ "nbconvert_exporter": "python",
108
+ "pygments_lexer": "ipython3",
109
+ "version": "3.7.10"
110
+ }
111
+ },
112
+ "nbformat": 4,
113
+ "nbformat_minor": 5
114
+ }
models/mhg_model/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa61ca0975afd93bf1f3c43f3b720594018f57ac9df3cb740def2eb28af6529d
3
+ size 342570
models/mhg_model/setup.cfg ADDED
@@ -0,0 +1,37 @@
1
+ [metadata]
2
+ name = mhg-gnn
3
+ version = attr: .__version__
4
+ description = Package for mhg-gnn
5
+ author= team
6
+ long_description_content_type=text/markdown
7
+ long_description = file: README.md
8
+ python_requires = >= 3.9.7
9
+ license = TBD
10
+
11
+ classifiers =
12
+ Programming Language :: Python :: 3
13
+ Programming Language :: Python :: 3.9
14
+
15
+ [options]
16
+ install_requires =
17
+ networkx>=2.8
18
+ numpy>=1.23.5, <2.0.0
19
+ pandas>=1.5.3
20
+ rdkit-pypi>=2022.9.4, <2023.9.6
21
+ torch>=2.0.0
22
+ torchinfo>=1.8.0
23
+ torch-geometric>=2.3.1
24
+ requests>=2.32.2
25
+ scikit-learn>=1.5.0
26
+ urllib3>=2.2.2
27
+
28
+
29
+ setup_requires =
30
+ setuptools
31
+ package_dir =
32
+ = .
33
+ packages=find:
34
+ include_package_data = True
35
+
36
+ [options.packages.find]
37
+ where = .
models/mhg_model/setup.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/env python
2
+
3
+ import setuptools
4
+
5
+ if __name__ == "__main__":
6
+ setuptools.setup()
models/selfies_model/README.md ADDED
@@ -0,0 +1,87 @@
1
+ ---
2
+ license: apache-2.0
3
+ library_name: transformers
4
+ pipeline_tag: feature-extraction
5
+ tags:
6
+ - chemistry
7
+ ---
8
+
9
+ # selfies-ted
10
+
11
+ selfies-ted is a project for encoding SMILES (Simplified Molecular Input Line Entry System) into SELFIES (SELF-referencing Embedded Strings) and generating embeddings for molecular representations.
12
+
13
+ ![selfies-ted](selfies-ted.png)
14
+ ## Model Architecture
15
+
16
+ Configuration details
17
+
18
+ Encoder and Decoder FFN dimensions: 256
19
+ Number of attention heads: 4
20
+ Number of encoder and decoder layers: 2
21
+ Total number of hidden layers: 6
22
+ Maximum position embeddings: 128
23
+ Model dimension (d_model): 256
24
+
25
+ ## Pretrained Models and Training Logs
26
+ We provide checkpoints of the selfies-ted model pre-trained on a dataset of molecules curated from PubChem. The pre-trained model shows competitive performance on molecular representation tasks. For model weights: "HuggingFace link".
27
+
28
+ To install and use the pre-trained model:
29
+
30
+ Download the selfies_ted_model.pkl file from the "HuggingFace link".
31
+ Add the selfies-ted selfies_ted_model.pkl to the models/ directory. The directory structure should look like the following:
32
+
33
+ ```
34
+ models/
35
+ └── selfies_ted_model.pkl
36
+ ```
37
+
38
+ ## Installation
39
+
40
+ To use this project, you'll need to install the required dependencies. We recommend using a virtual environment:
41
+
42
+ ```bash
43
+ python -m venv venv
44
+ source venv/bin/activate # On Windows use `venv\Scripts\activate`
45
+ ```
46
+
47
+ Install the required dependencies
48
+
49
+ ```
50
+ pip install -r requirements.txt
51
+ ```
52
+
53
+
54
+ ## Usage
55
+
56
+ ### Import
57
+
58
+ ```
59
+ import load
60
+ ```
61
+ ### Training the Model
62
+
63
+ To train the model, use the train.py script:
64
+
65
+ ```
66
+ python train.py -f <path_to_your_data_file>
67
+ ```
68
+
69
+
70
+ Note: The actual usage may depend on the specific implementation in load.py. Please refer to the source code for detailed functionality.
71
+
72
+ ### Load the model and tokenizer
73
+ ```
74
+ load.load("path/to/checkpoint.pkl")
75
+ ```
76
+ ### Encode SMILES strings
77
+ ```
78
+ smiles_list = ["COC", "CCO"]
79
+ ```
80
+ ```
81
+ embeddings = load.encode(smiles_list)
82
+ ```
83
+
84
+
85
+ ## Example Notebook
86
+
87
+ An example notebook for this project is provided in `selfies-ted-example.ipynb`.
models/selfies_model/load.py ADDED
@@ -0,0 +1,102 @@
1
+ import torch
2
+ import selfies as sf
3
+ import numpy as np
4
+ import pandas as pd
5
+ from rdkit import Chem
6
+ from transformers import AutoTokenizer, AutoModel
7
+ import gc
8
+ from torch.utils.data import DataLoader, Dataset
9
+ from multiprocessing import Pool, cpu_count
10
+ from tqdm import tqdm
11
+
12
+ import os
13
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
14
+
15
+
16
+ class SELFIESDataset(Dataset):
17
+ def __init__(self, selfies_list):
18
+ self.selfies = selfies_list
19
+
20
+ def __len__(self):
21
+ return len(self.selfies)
22
+
23
+ def __getitem__(self, idx):
24
+ return self.selfies[idx]
25
+
26
+ class SELFIES(torch.nn.Module):
27
+ def __init__(self):
28
+ super().__init__()
29
+ self.model = None
30
+ self.tokenizer = None
31
+ self.invalid = []
32
+
33
+ def smiles_to_selfies(self, smiles):
34
+ try:
35
+ return sf.encoder(smiles.strip()).replace('][', '] [')
36
+ except:
37
+ try:
38
+ smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles.strip()))
39
+ return sf.encoder(smiles).replace('][', '] [')
40
+ except:
41
+ return None
42
+
43
+ def get_selfies(self, smiles_list):
44
+ with Pool(cpu_count()) as pool:
45
+ selfies = list(pool.map(self.smiles_to_selfies, smiles_list))
46
+
47
+ self.invalid = [i for i, s in enumerate(selfies) if s is None]
48
+ selfies = [s if s is not None else '[nop]' for s in selfies]
49
+ return selfies
50
+
51
+ @torch.no_grad()
52
+ def get_embedding_batch(self, selfies_batch):
53
+ encodings = self.tokenizer(
54
+ selfies_batch,
55
+ return_tensors='pt',
56
+ max_length=128,
57
+ truncation=True,
58
+ padding='max_length'
59
+ )
60
+ encodings = {k: v.to(self.model.device) for k, v in encodings.items()}
61
+
62
+ outputs = self.model.encoder(
63
+ input_ids=encodings['input_ids'],
64
+ attention_mask=encodings['attention_mask']
65
+ )
66
+
67
+ model_output = outputs.last_hidden_state
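+ # Masked mean pooling: average the token embeddings over real (non-padding) positions only.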
68
+ input_mask_expanded = encodings['attention_mask'].unsqueeze(-1).expand(model_output.size()).float()
69
+ sum_embeddings = torch.sum(model_output * input_mask_expanded, 1)
70
+ sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
71
+ pooled_output = sum_embeddings / sum_mask
72
+
73
+ return pooled_output.cpu().numpy()
74
+
75
+ def load(self, checkpoint=None):
76
+ self.tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
77
+ self.model = AutoModel.from_pretrained("ibm/materials.selfies-ted")
78
+ self.model.eval()
79
+
80
+ def encode(self, smiles_list=[], use_gpu=False, return_tensor=False, batch_size=128, num_workers=4):
81
+ selfies = self.get_selfies(smiles_list)
82
+ dataset = SELFIESDataset(selfies)
83
+
84
+ device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
85
+ self.model.to(device)
86
+
87
+ loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
88
+
89
+ embeddings = []
90
+ for batch in tqdm(loader, desc="Encoding"):
91
+ emb = self.get_embedding_batch(batch)
92
+ embeddings.append(emb)
93
+ del emb
94
+ gc.collect()
95
+
96
+ emb = np.vstack(embeddings)
97
+
98
+ for idx in self.invalid:
99
+ emb[idx] = np.nan
100
+ print(f"Cannot encode {smiles_list[idx]} to selfies. Embedding replaced by NaN.")
101
+
102
+ return torch.tensor(emb) if return_tensor else pd.DataFrame(emb)
models/selfies_model/requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ torch>=2.1.0
2
+ transformers>=4.38
3
+ numpy>=1.26.1
4
+ datasets>=2.13.1
5
+ evaluate>=0.4.0
6
+ selfies>=2.1.0
7
+ scikit-learn>=1.2.1
8
+ pyarrow>=14.0.1
9
+ requests>=2.31.0
10
+ urllib3>=2.0.7
11
+ aiohttp>=3.9.0
12
+ zipp>=3.17.0