Spaces:
Running
Running
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- INTRODUCTION.md +9 -0
- app.py +523 -0
- data/.DS_Store +0 -0
- data/bace/test.csv +0 -0
- data/bace/train.csv +0 -0
- data/bace/valid.csv +0 -0
- data/esol/test.csv +109 -0
- data/esol/train.csv +0 -0
- data/lce/test.csv +31 -0
- data/lce/test_data.csv +14 -0
- data/lce/train.csv +121 -0
- data/lce/train_data.csv +148 -0
- models/.DS_Store +0 -0
- models/.gitattributes +3 -0
- models/fm4m.py +964 -0
- models/mhg_model/.DS_Store +0 -0
- models/mhg_model/README.md +75 -0
- models/mhg_model/__init__.py +5 -0
- models/mhg_model/graph_grammar/__init__.py +19 -0
- models/mhg_model/graph_grammar/algo/__init__.py +20 -0
- models/mhg_model/graph_grammar/algo/tree_decomposition.py +821 -0
- models/mhg_model/graph_grammar/graph_grammar/__init__.py +20 -0
- models/mhg_model/graph_grammar/graph_grammar/base.py +30 -0
- models/mhg_model/graph_grammar/graph_grammar/corpus.py +152 -0
- models/mhg_model/graph_grammar/graph_grammar/hrg.py +1065 -0
- models/mhg_model/graph_grammar/graph_grammar/symbols.py +180 -0
- models/mhg_model/graph_grammar/graph_grammar/utils.py +130 -0
- models/mhg_model/graph_grammar/hypergraph.py +544 -0
- models/mhg_model/graph_grammar/io/__init__.py +20 -0
- models/mhg_model/graph_grammar/io/smi.py +559 -0
- models/mhg_model/graph_grammar/nn/__init__.py +11 -0
- models/mhg_model/graph_grammar/nn/dataset.py +121 -0
- models/mhg_model/graph_grammar/nn/decoder.py +158 -0
- models/mhg_model/graph_grammar/nn/encoder.py +199 -0
- models/mhg_model/graph_grammar/nn/graph.py +313 -0
- models/mhg_model/load.py +103 -0
- models/mhg_model/mhg_gnn.egg-info/PKG-INFO +102 -0
- models/mhg_model/mhg_gnn.egg-info/SOURCES.txt +46 -0
- models/mhg_model/mhg_gnn.egg-info/dependency_links.txt +1 -0
- models/mhg_model/mhg_gnn.egg-info/requires.txt +7 -0
- models/mhg_model/mhg_gnn.egg-info/top_level.txt +2 -0
- models/mhg_model/models/__init__.py +5 -0
- models/mhg_model/models/mhgvae.py +956 -0
- models/mhg_model/notebooks/mhg-gnn_encoder_decoder_example.ipynb +114 -0
- models/mhg_model/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf +3 -0
- models/mhg_model/setup.cfg +37 -0
- models/mhg_model/setup.py +6 -0
- models/selfies_model/README.md +87 -0
- models/selfies_model/load.py +102 -0
- models/selfies_model/requirements.txt +12 -0
INTRODUCTION.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Foundation Models for Materials - FM4M
|
2 |
+
|
3 |
+
FM4M adopts a modular architecture designed for flexible extensibility. As illustrated in the figure below, it comprises both uni-modal and fused models. Each uni-modal model is pre-trained independently for its respective modality (e.g., SMILES), and users can access individual functionalities directly from the corresponding model directory (e.g., smi-ted/). Some of these uni-modal models can be "late-fused" using fusion algorithms, creating a more powerful multi-modal feature representation for downstream predictions. To simplify usage, we provide fm4m-kit, a wrapper that enables users to easily access the capabilities of all models through straightforward methods. These models are also available on Hugging Face, where they can be accessed via an intuitive and user-friendly GUI.
|
4 |
+
|
5 |
+
<p align="center">
|
6 |
+
<img src="gradio_api/file=img/introduction.png" alt="FM4M Overview" width="50%"/>
|
7 |
+
</p>
|
8 |
+
|
9 |
+
|
app.py
ADDED
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
from tempfile import NamedTemporaryFile
|
5 |
+
from PIL import Image
|
6 |
+
from rdkit import RDLogger
|
7 |
+
from sklearn.model_selection import train_test_split
|
8 |
+
|
9 |
+
from molecule_generation_helpers import *
|
10 |
+
from property_prediction_helpers import *
|
11 |
+
|
12 |
+
DEBUG_VISIBLE = False
|
13 |
+
RDLogger.logger().setLevel(RDLogger.ERROR)
|
14 |
+
|
15 |
+
# Predefined dataset paths (these should be adjusted to your file paths)
|
16 |
+
predefined_datasets = {
|
17 |
+
" ": " ",
|
18 |
+
"BACE": "./data/bace/train.csv, ./data/bace/test.csv, smiles, Class",
|
19 |
+
"ESOL": "./data/esol/train.csv, ./data/esol/test.csv, smiles, prop",
|
20 |
+
}
|
21 |
+
|
22 |
+
# Models
|
23 |
+
models_enabled = [
|
24 |
+
"MorganFingerprint",
|
25 |
+
"SMI-TED",
|
26 |
+
"SELFIES-TED",
|
27 |
+
"MHG-GED",
|
28 |
+
]
|
29 |
+
|
30 |
+
blank_df = pd.DataFrame({"id": [], "Model": [], "Score": []})
|
31 |
+
|
32 |
+
|
33 |
+
# Function to load a predefined dataset from the local path
|
34 |
+
def load_predefined_dataset(dataset_name):
|
35 |
+
val = predefined_datasets.get(dataset_name)
|
36 |
+
if val:
|
37 |
+
try:
|
38 |
+
df = pd.read_csv(val.split(",")[0])
|
39 |
+
return (
|
40 |
+
df.head(),
|
41 |
+
gr.update(choices=list(df.columns), value=None),
|
42 |
+
gr.update(choices=list(df.columns), value=None),
|
43 |
+
dataset_name.lower(),
|
44 |
+
)
|
45 |
+
except:
|
46 |
+
pass
|
47 |
+
else:
|
48 |
+
dataset_name = "Custom"
|
49 |
+
return (
|
50 |
+
pd.DataFrame(),
|
51 |
+
gr.update(choices=[], value=None),
|
52 |
+
gr.update(choices=[], value=None),
|
53 |
+
dataset_name.lower(),
|
54 |
+
)
|
55 |
+
|
56 |
+
|
57 |
+
# Function to handle dataset selection (predefined or custom)
|
58 |
+
def handle_dataset_selection(selected_dataset, state):
|
59 |
+
state["dataset_name"] = (
|
60 |
+
selected_dataset if selected_dataset in predefined_datasets else "CUSTOM"
|
61 |
+
)
|
62 |
+
# Show file upload fields for train and test datasets if "Custom Dataset" is selected
|
63 |
+
task_type = (
|
64 |
+
"Classification"
|
65 |
+
if selected_dataset == "BACE"
|
66 |
+
else "Regression" if selected_dataset == "ESOL" else None
|
67 |
+
)
|
68 |
+
return (
|
69 |
+
gr.update(visible=selected_dataset not in predefined_datasets or DEBUG_VISIBLE),
|
70 |
+
task_type,
|
71 |
+
)
|
72 |
+
|
73 |
+
|
74 |
+
# Function to select input and output columns and display a message
|
75 |
+
def select_columns(input_column, output_column, train_data, test_data, state):
|
76 |
+
if train_data and test_data and input_column and output_column:
|
77 |
+
return f"{train_data.name},{test_data.name},{input_column},{output_column},{state['dataset_name']}"
|
78 |
+
return gr.update()
|
79 |
+
|
80 |
+
|
81 |
+
# Function to display the head of the uploaded CSV file
|
82 |
+
def display_csv_head(file):
|
83 |
+
if file is not None:
|
84 |
+
# Load the CSV file into a DataFrame
|
85 |
+
df = pd.read_csv(file.name)
|
86 |
+
return (
|
87 |
+
df.head(),
|
88 |
+
gr.update(choices=list(df.columns)),
|
89 |
+
gr.update(choices=list(df.columns)),
|
90 |
+
)
|
91 |
+
return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[])
|
92 |
+
|
93 |
+
|
94 |
+
def process_custom_file(file, selected_dataset):
|
95 |
+
if file and os.path.getsize(file.name) < 50 * 1024:
|
96 |
+
df = pd.read_csv(file.name)
|
97 |
+
if "input" in df.columns and "output" in df.columns:
|
98 |
+
train, test = train_test_split(df, test_size=0.2)
|
99 |
+
with NamedTemporaryFile(
|
100 |
+
prefix="fm4m-train-", suffix=".csv", delete=False
|
101 |
+
) as train_file:
|
102 |
+
train.to_csv(train_file.name, index=False)
|
103 |
+
with NamedTemporaryFile(
|
104 |
+
prefix="fm4m-test-", suffix=".csv", delete=False
|
105 |
+
) as test_file:
|
106 |
+
test.to_csv(test_file.name, index=False)
|
107 |
+
task_type = (
|
108 |
+
"Classification" if df["output"].dtype == np.int64 else "Regression"
|
109 |
+
)
|
110 |
+
return train_file.name, test_file.name, "input", "output", task_type
|
111 |
+
|
112 |
+
return (
|
113 |
+
None,
|
114 |
+
None,
|
115 |
+
None,
|
116 |
+
None,
|
117 |
+
gr.update() if selected_dataset in predefined_datasets else None,
|
118 |
+
)
|
119 |
+
|
120 |
+
|
121 |
+
def update_plot_choices(current, state):
|
122 |
+
choices = []
|
123 |
+
if state.get("roc_auc") is not None:
|
124 |
+
choices.append("ROC-AUC")
|
125 |
+
if state.get("RMSE") is not None:
|
126 |
+
choices.append("Parity Plot")
|
127 |
+
if state.get("x_batch") is not None:
|
128 |
+
choices.append("Latent Space")
|
129 |
+
if current in choices:
|
130 |
+
return gr.update(choices=choices)
|
131 |
+
return gr.update(choices=choices, value=None if len(choices) == 0 else choices[0])
|
132 |
+
|
133 |
+
|
134 |
+
def log_selected(df: pd.DataFrame, evt: gr.SelectData, state):
|
135 |
+
state.update(state["results"].get(df.at[evt.index[0], 'id'], {}))
|
136 |
+
|
137 |
+
|
138 |
+
# Dictionary for SMILES strings and corresponding images (you can replace with your actual image paths)
|
139 |
+
smiles_image_mapping = {
|
140 |
+
# Example SMILES for ethanol
|
141 |
+
"Mol 1": {
|
142 |
+
"smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1",
|
143 |
+
"image": "img/img1.png",
|
144 |
+
},
|
145 |
+
# Example SMILES for butane
|
146 |
+
"Mol 2": {
|
147 |
+
"smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1",
|
148 |
+
"image": "img/img2.png",
|
149 |
+
},
|
150 |
+
# Example SMILES for ethylamine
|
151 |
+
"Mol 3": {
|
152 |
+
"smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
|
153 |
+
"image": "img/img3.png",
|
154 |
+
},
|
155 |
+
# Example SMILES for diethyl ether
|
156 |
+
"Mol 4": {
|
157 |
+
"smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1",
|
158 |
+
"image": "img/img4.png",
|
159 |
+
},
|
160 |
+
# Example SMILES for chloroethane
|
161 |
+
"Mol 5": {
|
162 |
+
"smiles": "C=CCS[C@@H](C)CC(=O)OCC",
|
163 |
+
"image": "img/img5.png",
|
164 |
+
},
|
165 |
+
}
|
166 |
+
|
167 |
+
|
168 |
+
# Load images for selection
|
169 |
+
def load_image(path):
|
170 |
+
try:
|
171 |
+
return Image.open(smiles_image_mapping[path]["image"])
|
172 |
+
except:
|
173 |
+
pass
|
174 |
+
|
175 |
+
|
176 |
+
# Function to handle image selection
|
177 |
+
def handle_image_selection(image_key):
|
178 |
+
if not image_key:
|
179 |
+
return None, None
|
180 |
+
smiles = smiles_image_mapping[image_key]["smiles"]
|
181 |
+
mol_image = smiles_to_image(smiles)
|
182 |
+
return smiles, mol_image
|
183 |
+
|
184 |
+
|
185 |
+
# Introduction
|
186 |
+
with gr.Blocks() as introduction:
|
187 |
+
with open("INTRODUCTION.md") as f:
|
188 |
+
gr.Markdown(f.read(), sanitize_html=False)
|
189 |
+
|
190 |
+
# Property Prediction
|
191 |
+
with gr.Blocks() as property_prediction:
|
192 |
+
state = gr.State({"model_name": "Default - Auto", "results": {}})
|
193 |
+
gr.HTML(
|
194 |
+
'''
|
195 |
+
<p style="text-align: center">
|
196 |
+
Task : Property Prediction
|
197 |
+
<br>
|
198 |
+
Models are finetuned with different combination of modalities on the uploaded or selected built data set.
|
199 |
+
</p>
|
200 |
+
'''
|
201 |
+
)
|
202 |
+
with gr.Row():
|
203 |
+
with gr.Column():
|
204 |
+
# Dropdown menu for predefined datasets including "Custom Dataset" option
|
205 |
+
dataset_selector = gr.Dropdown(
|
206 |
+
label="Select Dataset",
|
207 |
+
choices=list(predefined_datasets.keys()) + ["Custom Dataset"],
|
208 |
+
)
|
209 |
+
# Display the message for selected columns
|
210 |
+
selected_columns_message = gr.Textbox(
|
211 |
+
label="Selected Columns Info", visible=DEBUG_VISIBLE
|
212 |
+
)
|
213 |
+
|
214 |
+
with gr.Accordion(
|
215 |
+
"Custom Dataset Settings", open=True, visible=DEBUG_VISIBLE
|
216 |
+
) as settings:
|
217 |
+
# File upload options for custom dataset (train and test)
|
218 |
+
custom_file = gr.File(
|
219 |
+
label="Upload Custom Dataset",
|
220 |
+
file_types=[".csv"],
|
221 |
+
)
|
222 |
+
train_file = gr.File(
|
223 |
+
label="Upload Custom Train Dataset",
|
224 |
+
file_types=[".csv"],
|
225 |
+
visible=False,
|
226 |
+
)
|
227 |
+
train_display = gr.Dataframe(
|
228 |
+
label="Train Dataset Preview (First 5 Rows)",
|
229 |
+
interactive=False,
|
230 |
+
visible=DEBUG_VISIBLE,
|
231 |
+
)
|
232 |
+
|
233 |
+
test_file = gr.File(
|
234 |
+
label="Upload Custom Test Dataset",
|
235 |
+
file_types=[".csv"],
|
236 |
+
visible=False,
|
237 |
+
)
|
238 |
+
test_display = gr.Dataframe(
|
239 |
+
label="Test Dataset Preview (First 5 Rows)",
|
240 |
+
interactive=False,
|
241 |
+
visible=DEBUG_VISIBLE,
|
242 |
+
)
|
243 |
+
|
244 |
+
# Predefined dataset displays
|
245 |
+
predefined_display = gr.Dataframe(
|
246 |
+
label="Predefined Dataset Preview (First 5 Rows)",
|
247 |
+
interactive=False,
|
248 |
+
visible=DEBUG_VISIBLE,
|
249 |
+
)
|
250 |
+
|
251 |
+
# Dropdowns for selecting input and output columns for the custom dataset
|
252 |
+
input_column_selector = gr.Dropdown(
|
253 |
+
label="Select Input Column",
|
254 |
+
choices=[],
|
255 |
+
allow_custom_value=True,
|
256 |
+
visible=DEBUG_VISIBLE,
|
257 |
+
)
|
258 |
+
output_column_selector = gr.Dropdown(
|
259 |
+
label="Select Output Column",
|
260 |
+
choices=[],
|
261 |
+
allow_custom_value=True,
|
262 |
+
visible=DEBUG_VISIBLE,
|
263 |
+
)
|
264 |
+
|
265 |
+
# When a custom train file is uploaded, display its head and update column selectors
|
266 |
+
train_file.change(
|
267 |
+
display_csv_head,
|
268 |
+
inputs=train_file,
|
269 |
+
outputs=[
|
270 |
+
train_display,
|
271 |
+
input_column_selector,
|
272 |
+
output_column_selector,
|
273 |
+
],
|
274 |
+
)
|
275 |
+
|
276 |
+
# When a custom test file is uploaded, display its head
|
277 |
+
test_file.change(
|
278 |
+
display_csv_head,
|
279 |
+
inputs=test_file,
|
280 |
+
outputs=[
|
281 |
+
test_display,
|
282 |
+
input_column_selector,
|
283 |
+
output_column_selector,
|
284 |
+
],
|
285 |
+
)
|
286 |
+
|
287 |
+
model_checkbox = gr.CheckboxGroup(
|
288 |
+
choices=models_enabled, label="Select Model", visible=DEBUG_VISIBLE
|
289 |
+
)
|
290 |
+
|
291 |
+
task_radiobutton = gr.Radio(
|
292 |
+
choices=["Classification", "Regression"],
|
293 |
+
label="Task Type",
|
294 |
+
visible=DEBUG_VISIBLE,
|
295 |
+
)
|
296 |
+
|
297 |
+
# When a dataset is selected, show either file upload fields (for custom) or load predefined datasets
|
298 |
+
# When a predefined dataset is selected, load its head and update column selectors
|
299 |
+
dataset_selector.change(lambda: None, outputs=custom_file).then(
|
300 |
+
handle_dataset_selection,
|
301 |
+
inputs=[dataset_selector, state],
|
302 |
+
outputs=[settings, task_radiobutton],
|
303 |
+
).then(
|
304 |
+
load_predefined_dataset,
|
305 |
+
inputs=dataset_selector,
|
306 |
+
outputs=[
|
307 |
+
predefined_display,
|
308 |
+
input_column_selector,
|
309 |
+
output_column_selector,
|
310 |
+
selected_columns_message,
|
311 |
+
],
|
312 |
+
)
|
313 |
+
|
314 |
+
custom_file.change(
|
315 |
+
process_custom_file,
|
316 |
+
inputs=[custom_file, dataset_selector],
|
317 |
+
outputs=[
|
318 |
+
train_file,
|
319 |
+
test_file,
|
320 |
+
input_column_selector,
|
321 |
+
output_column_selector,
|
322 |
+
task_radiobutton,
|
323 |
+
],
|
324 |
+
)
|
325 |
+
eval_clear_button = gr.Button("Clear")
|
326 |
+
eval_button = gr.Button("Submit", variant="primary")
|
327 |
+
step_slider = gr.Slider(
|
328 |
+
minimum=0,
|
329 |
+
maximum=8,
|
330 |
+
value=0,
|
331 |
+
label="Progress",
|
332 |
+
show_label=True,
|
333 |
+
interactive=False,
|
334 |
+
visible=False,
|
335 |
+
)
|
336 |
+
|
337 |
+
# Right Column
|
338 |
+
with gr.Column():
|
339 |
+
log_table = gr.Dataframe(value=blank_df, interactive=False)
|
340 |
+
|
341 |
+
plot_radio = gr.Radio(choices=[], label="Select Plot Type")
|
342 |
+
|
343 |
+
plot_output = gr.Plot(label="Visualization")
|
344 |
+
|
345 |
+
log_table.select(log_selected, [log_table, state]).success(
|
346 |
+
update_plot_choices, inputs=[plot_radio, state], outputs=plot_radio
|
347 |
+
).then(display_plot, inputs=[plot_radio, state], outputs=plot_output)
|
348 |
+
|
349 |
+
def clear_eval(state):
|
350 |
+
state["results"] = {}
|
351 |
+
return None, gr.update(choices=[], value=None), blank_df
|
352 |
+
|
353 |
+
def eval_part(part, step, selector, show_progress=False):
|
354 |
+
return (
|
355 |
+
part.then(
|
356 |
+
lambda: [models_enabled[x] for x in selector],
|
357 |
+
outputs=model_checkbox,
|
358 |
+
)
|
359 |
+
.then(
|
360 |
+
evaluate_and_log,
|
361 |
+
inputs=[
|
362 |
+
model_checkbox,
|
363 |
+
selected_columns_message,
|
364 |
+
task_radiobutton,
|
365 |
+
log_table,
|
366 |
+
state,
|
367 |
+
],
|
368 |
+
outputs=log_table,
|
369 |
+
show_progress=show_progress,
|
370 |
+
)
|
371 |
+
.then(lambda: step, outputs=step_slider, show_progress=False)
|
372 |
+
)
|
373 |
+
|
374 |
+
part = (
|
375 |
+
eval_button.click(
|
376 |
+
lambda: (
|
377 |
+
gr.update(interactive=False),
|
378 |
+
gr.update(interactive=False),
|
379 |
+
),
|
380 |
+
outputs=[eval_clear_button, eval_button],
|
381 |
+
)
|
382 |
+
.then(
|
383 |
+
select_columns,
|
384 |
+
inputs=[
|
385 |
+
input_column_selector,
|
386 |
+
output_column_selector,
|
387 |
+
train_file,
|
388 |
+
test_file,
|
389 |
+
state,
|
390 |
+
],
|
391 |
+
outputs=selected_columns_message,
|
392 |
+
)
|
393 |
+
.then(
|
394 |
+
clear_eval,
|
395 |
+
inputs=state,
|
396 |
+
outputs=[
|
397 |
+
plot_output,
|
398 |
+
plot_radio,
|
399 |
+
log_table,
|
400 |
+
],
|
401 |
+
)
|
402 |
+
)
|
403 |
+
part = part.then(
|
404 |
+
lambda: gr.update(value=0, visible=True),
|
405 |
+
outputs=step_slider,
|
406 |
+
show_progress=False,
|
407 |
+
)
|
408 |
+
part = eval_part(part, 1, [0], True)
|
409 |
+
part = eval_part(part, 2, [1])
|
410 |
+
part = eval_part(part, 3, [2])
|
411 |
+
part = eval_part(part, 4, [3])
|
412 |
+
part = eval_part(part, 5, [1, 2])
|
413 |
+
part = eval_part(part, 6, [2, 3])
|
414 |
+
part = eval_part(part, 7, [1, 3])
|
415 |
+
part = eval_part(part, 8, [1, 2, 3])
|
416 |
+
part = part.then(
|
417 |
+
lambda: gr.update(visible=False),
|
418 |
+
outputs=step_slider,
|
419 |
+
show_progress=False,
|
420 |
+
)
|
421 |
+
part.then(
|
422 |
+
lambda: (
|
423 |
+
gr.update(interactive=True),
|
424 |
+
gr.update(interactive=True),
|
425 |
+
),
|
426 |
+
outputs=[eval_clear_button, eval_button],
|
427 |
+
)
|
428 |
+
|
429 |
+
plot_radio.change(
|
430 |
+
display_plot, inputs=[plot_radio, state], outputs=plot_output
|
431 |
+
)
|
432 |
+
|
433 |
+
eval_clear_button.click(
|
434 |
+
clear_eval,
|
435 |
+
inputs=state,
|
436 |
+
outputs=[
|
437 |
+
plot_output,
|
438 |
+
plot_radio,
|
439 |
+
log_table,
|
440 |
+
],
|
441 |
+
).then(lambda: " ", outputs=dataset_selector)
|
442 |
+
|
443 |
+
|
444 |
+
# Molecule Generation
|
445 |
+
with gr.Blocks() as molecule_generation:
|
446 |
+
gr.HTML(
|
447 |
+
'''
|
448 |
+
<p style="text-align: center">
|
449 |
+
Task : Molecule Generation
|
450 |
+
<br>
|
451 |
+
Generate a new molecule similar to the initial molecule with better drug-likeness and synthetic accessibility.
|
452 |
+
</p>
|
453 |
+
'''
|
454 |
+
)
|
455 |
+
with gr.Row():
|
456 |
+
with gr.Column():
|
457 |
+
smiles_input = gr.Textbox(label="Input SMILES String")
|
458 |
+
image_display = gr.Image(label="Molecule Image", height=250, width=250)
|
459 |
+
# Show images for selection
|
460 |
+
with gr.Accordion("Select from sample molecules", open=False):
|
461 |
+
image_selector = gr.Radio(
|
462 |
+
choices=list(smiles_image_mapping.keys()),
|
463 |
+
label="Select from sample molecules",
|
464 |
+
value=None,
|
465 |
+
)
|
466 |
+
image_selector.change(load_image, image_selector, image_display)
|
467 |
+
clear_button = gr.Button("Clear")
|
468 |
+
generate_button = gr.Button("Submit", variant="primary")
|
469 |
+
|
470 |
+
# Right Column
|
471 |
+
with gr.Column():
|
472 |
+
gen_image_display = gr.Image(
|
473 |
+
label="Generated Molecule Image", height=250, width=250
|
474 |
+
)
|
475 |
+
generated_output = gr.Textbox(label="Generated Output")
|
476 |
+
property_table = gr.Dataframe(label="Molecular Properties Comparison")
|
477 |
+
|
478 |
+
# Handle image selection
|
479 |
+
image_selector.change(
|
480 |
+
handle_image_selection,
|
481 |
+
inputs=image_selector,
|
482 |
+
outputs=[smiles_input, image_display],
|
483 |
+
)
|
484 |
+
smiles_input.change(
|
485 |
+
smiles_to_image, inputs=smiles_input, outputs=image_display
|
486 |
+
)
|
487 |
+
|
488 |
+
# Generate button to display canonical SMILES and molecule image
|
489 |
+
generate_button.click(
|
490 |
+
lambda: (
|
491 |
+
gr.update(interactive=False),
|
492 |
+
gr.update(interactive=False),
|
493 |
+
),
|
494 |
+
outputs=[clear_button, generate_button],
|
495 |
+
).then(
|
496 |
+
generate_canonical,
|
497 |
+
inputs=smiles_input,
|
498 |
+
outputs=[property_table, generated_output, gen_image_display],
|
499 |
+
).then(
|
500 |
+
lambda: (
|
501 |
+
gr.update(interactive=True),
|
502 |
+
gr.update(interactive=True),
|
503 |
+
),
|
504 |
+
outputs=[clear_button, generate_button],
|
505 |
+
)
|
506 |
+
clear_button.click(
|
507 |
+
lambda: (None, None, None, None, None, None),
|
508 |
+
outputs=[
|
509 |
+
smiles_input,
|
510 |
+
image_display,
|
511 |
+
image_selector,
|
512 |
+
gen_image_display,
|
513 |
+
generated_output,
|
514 |
+
property_table,
|
515 |
+
],
|
516 |
+
)
|
517 |
+
|
518 |
+
|
519 |
+
# Render with tabs
|
520 |
+
gr.TabbedInterface(
|
521 |
+
[introduction, property_prediction, molecule_generation],
|
522 |
+
["Introduction", "Property Prediction", "Molecule Generation"],
|
523 |
+
).launch(server_name="0.0.0.0", allowed_paths=["./"])
|
data/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
data/bace/test.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/bace/train.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/bace/valid.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/esol/test.csv
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,selfies,prop,smiles
|
2 |
+
0,[Cl] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [C] [C] [C] [C] [Branch1] [Branch2] [C] [O] [C] [Ring1] [=Branch1] [Ring1] [Ring1] [C] [Ring1] [Branch2] [C] [Ring1] [=C] [Branch1] [C] [Cl] [C] [Ring1] [=N] [Branch1] [C] [Cl] [Cl],-4.533,ClC4=C(Cl)C5(Cl)C3C1CC(C2OC12)C3C4(Cl)C5(Cl)Cl
|
3 |
+
1,[C] [C] [C] [C] [C] [=O],-1.103,CCCCC=O
|
4 |
+
2,[O] [C] [C] [C] [C] [=C],-0.7909999999999999,OCCCC=C
|
5 |
+
3,[C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [N] [N] [=C] [C] [Branch1] [C] [N] [=C] [Branch1] [C] [Br] [C] [Ring1] [Branch2] [=O],-3.005,c1ccccc1n2ncc(N)c(Br)c2(=O)
|
6 |
+
4,[N] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1],-1.231,Nc1ccc(O)cc1
|
7 |
+
5,[C] [C] [Branch1] [C] [C] [C] [C] [O] [C] [=Branch1] [C] [=O] [C],-1.817,CC(C)CCOC(=O)C
|
8 |
+
6,[C] [O] [P] [=Branch1] [C] [=S] [Branch1] [Ring1] [O] [C] [S] [C] [C] [=Branch1] [C] [=O] [N] [Branch1] [C] [C] [C] [=O],-2.087,COP(=S)(OC)SCC(=O)N(C)C=O
|
9 |
+
7,[Cl] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=Branch1] [Ring2] [=C] [Ring1] [#Branch1] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-6.312,Clc1ccc(Cl)c(c1)c2ccc(Cl)c(Cl)c2
|
10 |
+
8,[C] [Branch1] [C] [Cl] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [N] [=C] [C] [=N] [C] [Ring1] [=Branch1] [=C] [Ring1] [=N] [Cl],-4.438,c2(Cl)c(Cl)c(Cl)c1nccnc1c2(Cl)
|
11 |
+
9,[C] [C] [C] [C] [C] [=C] [Branch1] [C] [C] [N] [=C] [Branch1] [=Branch1] [N] [=C] [Ring1] [#Branch1] [O] [N] [Branch1] [C] [C] [C],-3.57,CCCCc1c(C)nc(nc1O)N(C)C
|
12 |
+
10,[C] [C] [O] [C] [=Branch1] [C] [=O] [C] [C] [=Branch1] [C] [=O] [O] [C] [C],-1.413,CCOC(=O)CC(=O)OCC
|
13 |
+
11,[C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1],-3.192,CC(C)(C)c1ccc(O)cc1
|
14 |
+
12,[C] [C] [=C] [C] [=C] [C] [Branch1] [C] [C] [=C] [Ring1] [#Branch1],-3.035,Cc1cccc(C)c1
|
15 |
+
13,[C] [C] [C] [O] [C] [=Branch1] [C] [=O] [C],-1.125,CCCOC(=O)C
|
16 |
+
14,[C] [S] [C] [=N] [N] [=C] [Branch1] [=Branch2] [C] [=Branch1] [C] [=O] [N] [Ring1] [#Branch1] [N] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C],-2.324,CSc1nnc(c(=O)n1N)C(C)(C)C
|
17 |
+
15,[Cl] [C] [=C] [C] [=C] [Branch1] [Branch1] [C] [=C] [Ring1] [=Branch1] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [Cl],-5.142,Clc1ccc(cc1)c2ccccc2Cl
|
18 |
+
16,[C] [C] [C] [C] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [C] [Branch1] [Ring2] [C] [Ring1] [Branch2] [C] [Branch1] [C] [O] [C] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2],-1.5319999999999998,CC1CC(C)C(=O)C(C1)C(O)CC2CC(=O)NC(=O)C2
|
19 |
+
17,[C] [N] [C] [=Branch1] [C] [=O] [O] [C] [=C] [C] [=C] [C] [Branch1] [Branch2] [N] [=C] [N] [Branch1] [C] [C] [C] [=C] [Ring1] [O],-1.846,CNC(=O)Oc1cccc(N=CN(C)C)c1
|
20 |
+
18,[C] [C] [=C] [C] [=N] [C] [N] [Branch1] [=Branch1] [C] [C] [C] [Ring1] [Ring1] [C] [=N] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [=Branch1] [C] [=O] [N] [C] [Ring2] [Ring1] [Ring1] [=Ring1] [#C],-3.397,Cc3ccnc4N(C1CC1)c2ncccc2C(=O)Nc34
|
21 |
+
19,[C] [C] [N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-2.389,CCNc1ccccc1
|
22 |
+
20,[C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [Branch1] [C] [C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [Ring2] [Ring1] [Ring1] [Ring1] [#Branch2],-6.297000000000001,Cc1c2ccccc2c(C)c3ccc4ccccc4c13
|
23 |
+
21,[F] [C] [=C] [C] [=C] [C] [Branch1] [C] [F] [=C] [Ring1] [#Branch1] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [Branch1] [C] [Cl] [=C] [Branch1] [C] [F] [C] [Branch1] [C] [Cl] [=C] [Ring1] [=Branch2] [F],-5.462000000000001,Fc1cccc(F)c1C(=O)NC(=O)Nc2cc(Cl)c(F)c(Cl)c2F
|
24 |
+
22,[C] [O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-3.057,COc1ccc(Cl)cc1
|
25 |
+
23,[O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=N] [Ring1] [=Branch1],-4.2010000000000005,o1c2ccccc2c3ccccc13
|
26 |
+
24,[C] [=C] [C] [=C] [N] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [C] [Ring1] [#Branch2] [=C] [Ring1] [=C],-3.846,c3ccc2nc1ccccc1cc2c3
|
27 |
+
25,[C] [C] [C] [C] [=Branch1] [C] [=O] [C] [C] [Branch1] [P] [C] [C] [C] [=C] [C] [=Branch1] [C] [=O] [C] [C] [C] [Ring1] [O] [Ring1] [#Branch1] [C] [C] [Ring1] [P] [C] [C] [C] [Ring2] [Ring1] [Ring2] [Branch1] [C] [O] [C] [=Branch1] [C] [=O] [C] [O],-2.893,CC12CC(=O)C3C(CCC4=CC(=O)CCC34C)C2CCC1(O)C(=O)CO
|
28 |
+
26,[C] [C] [C] [=C] [C] [=C] [C] [Branch1] [Ring1] [C] [C] [=C] [Ring1] [Branch2] [N] [Branch1] [Ring2] [C] [O] [C] [C] [=Branch1] [C] [=O] [C] [Cl],-3.319,CCc1cccc(CC)c1N(COC)C(=O)CCl
|
29 |
+
27,[C] [C] [C] [C] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-4.157,CCCCN(C)C(=O)Nc1ccc(Cl)c(Cl)c1
|
30 |
+
28,[C] [S] [C] [=Branch1] [C] [=S] [N] [C] [Ring1] [=Branch1] [=O],-0.396,C1SC(=S)NC1(=O)
|
31 |
+
29,[O] [C] [=C] [C] [=C] [Branch1] [Branch2] [C] [Branch1] [C] [O] [=C] [Ring1] [#Branch1] [C] [O] [C] [=C] [C] [Branch1] [C] [O] [=C] [C] [Branch1] [C] [O] [=C] [Ring1] [Branch2] [C] [=Branch1] [C] [=O] [C] [=Ring1] [=N] [O],-2.7310000000000003,Oc1ccc(c(O)c1)c3oc2cc(O)cc(O)c2c(=O)c3O
|
32 |
+
30,[C] [N] [Branch1] [C] [C] [C] [=N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1] [C],-3.164,CN(C)C=Nc1ccc(Cl)cc1C
|
33 |
+
31,[N] [C] [=Branch1] [C] [=O] [N] [C] [N] [C] [=Branch1] [C] [=O] [N] [C] [Ring1] [=Branch1] [=O],0.652,NC(=O)NC1NC(=O)NC1=O
|
34 |
+
32,[Cl] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-4.063,Clc1cccc2ccccc12
|
35 |
+
33,[O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-3.352,Oc1ccc(Cl)c(Cl)c1
|
36 |
+
34,[C] [C] [Branch1] [C] [C] [C] [Branch1] [#Branch1] [C] [=C] [Branch1] [C] [Cl] [Cl] [C] [Ring1] [Branch2] [C] [=Branch1] [C] [=O] [O] [C] [Branch1] [Ring1] [C] [#N] [C] [=C] [C] [=C] [C] [Branch1] [#Branch2] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [=C] [Ring1] [=N],-6.775,CC1(C)C(C=C(Cl)Cl)C1C(=O)OC(C#N)c2cccc(Oc3ccccc3)c2
|
37 |
+
35,[C] [=C] [C] [=C] [NH1] [N] [=N] [C] [Ring1] [Branch1] [=C] [Ring1] [=Branch2],-2.21,c2ccc1[nH]nnc1c2
|
38 |
+
36,[C] [C] [Branch1] [C] [C] [C] [Branch2] [Ring1] [Branch1] [N] [C] [=C] [C] [=C] [Branch1] [=Branch1] [C] [=C] [Ring1] [=Branch1] [Cl] [C] [Branch1] [C] [F] [Branch1] [C] [F] [F] [C] [=Branch1] [C] [=O] [O] [C] [Branch1] [Ring1] [C] [#N] [C] [=C] [C] [=C] [C] [Branch1] [#Branch2] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [=C] [Ring1] [=N],-8.057,CC(C)C(Nc1ccc(cc1Cl)C(F)(F)F)C(=O)OC(C#N)c2cccc(Oc3ccccc3)c2
|
39 |
+
37,[C] [C] [C],-1.5530000000000002,CCC
|
40 |
+
38,[C] [C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [O] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-3.792,C1Cc2cccc3cccc1c23
|
41 |
+
39,[C] [C] [C] [#C],-1.092,CCC#C
|
42 |
+
40,[Cl] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-3.5580000000000003,Clc1ccc(Cl)cc1
|
43 |
+
41,[C] [C] [=C] [NH1] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch2] [Ring1] [=Branch1],-2.9810000000000003,Cc1c[nH]c2ccccc12
|
44 |
+
42,[C] [C] [#N],0.152,CC#N
|
45 |
+
43,[C] [C] [C] [C] [O],-0.688,CCCCO
|
46 |
+
44,[C] [C] [=Branch1] [C] [=C] [C] [=Branch1] [C] [=C] [C],-2.052,CC(=C)C(=C)C
|
47 |
+
45,[C] [C] [C] [Branch1] [C] [C] [C] [C] [O],-1.308,CCC(C)CCO
|
48 |
+
46,[Cl] [C] [=C] [C] [=C] [Branch1] [=Branch2] [C] [Branch1] [C] [Cl] [=C] [Ring1] [#Branch1] [Cl] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2] [Cl],-7.192,Clc1ccc(c(Cl)c1Cl)c2ccc(Cl)c(Cl)c2Cl
|
49 |
+
47,[C] [C] [=C] [C] [=Branch2] [Ring1] [=Branch1] [=C] [C] [=C] [Ring1] [=Branch1] [N] [S] [=Branch1] [C] [=O] [=Branch1] [C] [=O] [C] [Branch1] [C] [F] [Branch1] [C] [F] [F] [S] [=Branch1] [C] [=O] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-4.945,Cc1cc(ccc1NS(=O)(=O)C(F)(F)F)S(=O)(=O)c2ccccc2
|
50 |
+
48,[O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1] [Cl],-3.22,Oc1ccc(Cl)cc1Cl
|
51 |
+
49,[C] [N] [C] [=Branch2] [Ring1] [Ring2] [=C] [Branch1] [C] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [S] [Ring1] [O] [=Branch1] [C] [=O] [=O] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [C] [=N] [Ring1] [=Branch1],-3.4730000000000003,CN2C(=C(O)c1ccccc1S2(=O)=O)C(=O)Nc3ccccn3
|
52 |
+
50,[C] [C] [C] [C] [C] [C] [Branch1] [S] [C] [C] [C] [=C] [C] [Branch1] [C] [O] [=C] [C] [=C] [Ring1] [O] [Ring1] [#Branch1] [C] [Ring1] [#C] [C] [C] [C] [Ring2] [Ring1] [C] [=O],-3.872,CC12CCC3C(CCc4cc(O)ccc34)C2CCC1=O
|
53 |
+
51,[C] [C] [=C] [C] [=C] [C] [=C] [Branch1] [C] [C] [C] [=C] [C] [=C] [Ring1] [O] [Ring1] [#Branch1],-4.147,Cc1cccc2c(C)cccc12
|
54 |
+
52,[N] [S] [=Branch1] [C] [=O] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [Branch1] [O] [N] [C] [N] [S] [Ring1] [=Branch1] [=Branch1] [C] [=O] [=O] [C] [=C] [Ring1] [N] [Cl],-1.72,NS(=O)(=O)c2cc1c(NCNS1(=O)=O)cc2Cl
|
55 |
+
53,[O] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [N] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-2.725,Oc1cccc2cccnc12
|
56 |
+
54,[C] [C] [C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [Ring1] [#Branch2],-3.447,C1CCc2ccccc2C1
|
57 |
+
55,[C] [C] [O] [C] [Branch1] [C] [C] [O] [C] [C],-0.899,CCOC(C)OCC
|
58 |
+
56,[C] [C] [C] [C] [Ring1] [Ring1] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [Branch1] [Branch1] [C] [Ring1] [Branch2] [=O] [C] [=C] [C] [Branch1] [C] [Cl] [=C] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-3.464,CC12CC2(C)C(=O)N(C1=O)c3cc(Cl)cc(Cl)c3
|
59 |
+
57,[C] [C] [=C] [C] [=C] [C] [=C] [C] [Ring1] [=Branch1] [=C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=C] [Ring1] [=Branch1],-4.87,Cc1c2ccccc2cc3ccccc13
|
60 |
+
58,[C] [C] [C] [C] [O] [C],-1.072,CCCCOC
|
61 |
+
59,[C] [C] [C] [C] [C] [=Branch1] [C] [=O] [C] [=C] [Ring1] [#Branch1] [C] [C] [C] [C] [C] [C] [C] [Branch1] [#Branch1] [C] [=Branch1] [C] [=O] [C] [O] [C] [Ring1] [=Branch2] [Branch1] [N] [C] [C] [Branch1] [C] [O] [C] [Ring2] [Ring1] [#Branch1] [Ring1] [=C] [C] [=O],-3.0660000000000003,CC13CCC(=O)C=C1CCC4C2CCC(C(=O)CO)C2(CC(O)C34)C=O
|
62 |
+
60,[C] [C] [C] [Branch1] [=Branch1] [C] [Branch1] [C] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [Ring1] [O] [=O],-1.6030000000000002,CCC1(C(C)C)C(=O)NC(=O)NC1=O
|
63 |
+
61,[C] [C] [O] [C] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1],-2.761,CCOC(=O)c1ccc(O)cc1
|
64 |
+
62,[C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [C] [=C] [Ring2] [Ring1] [C] [C] [=C] [Ring2] [Ring1] [C] [C] [Ring1] [S] [=C] [Ring1] [=C] [C] [Ring1] [N] [=C] [Ring1] [#Branch2] [Ring1] [=Branch1],-6.885,c1cc2ccc3ccc4ccc5ccc6ccc1c7c2c3c4c5c67
|
65 |
+
63,[C] [C] [N] [C] [=C] [C] [Branch1] [=Branch1] [N] [Branch1] [C] [C] [C] [=C] [C] [Branch1] [C] [C] [=C] [Ring1] [#Branch2] [N] [C] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [N] [=C] [Ring2] [Ring1] [Ring2] [Ring1] [=Branch1],-4.408,CCN2c1cc(N(C)C)cc(C)c1NC(=O)c3cccnc23
|
66 |
+
64,[C] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [Branch1] [C] [Cl] [=C] [Ring1] [Branch2],-3.301,CN(C)C(=O)Nc1ccc(Cl)c(Cl)c1
|
67 |
+
65,[C] [C] [C] [C] [C] [C] [Branch1] [C] [C] [C],-3.3080000000000003,CCCCCC(C)C
|
68 |
+
66,[C] [O] [C] [=C] [C] [=C] [Branch1] [C] [N] [N] [=C] [Branch1] [#C] [N] [=C] [Ring1] [#Branch1] [C] [Branch1] [Ring1] [O] [C] [=C] [Ring1] [=N] [O] [C] [N] [C] [C] [N] [Branch1] [Branch1] [C] [C] [Ring1] [=Branch1] [C] [=Branch1] [C] [=O] [O] [C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [O],-3.958,COc2cc1c(N)nc(nc1c(OC)c2OC)N3CCN(CC3)C(=O)OCC(C)(C)O
|
69 |
+
67,[C] [=C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [#Branch1] [C] [=C] [Ring1] [O],-0.636,c1cC2C(=O)NC(=O)C2cc1
|
70 |
+
68,[C] [C] [C] [=O],-0.3939999999999999,CCC=O
|
71 |
+
69,[Cl] [C] [=C] [C] [=C] [Branch2] [Ring1] [=Branch2] [C] [N] [Branch1] [Branch2] [C] [C] [C] [C] [C] [Ring1] [Branch1] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [=C] [Ring2] [Ring1] [=Branch1],-5.126,Clc1ccc(CN(C2CCCC2)C(=O)Nc3ccccc3)cc1
|
72 |
+
70,[C] [C] [C] [C] [C] [Branch1] [Ring1] [C] [C] [C] [=O],-2.232,CCCCC(CC)C=O
|
73 |
+
71,[O] [=C] [N] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2] [Branch1] [Ring1] [C] [C] [C] [C] [C] [Branch1] [C] [C] [C],-2.312,O=C1NC(=O)NC(=O)C1(CC)CCC(C)C
|
74 |
+
72,[C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-1.857,CC(=O)Nc1ccccc1
|
75 |
+
73,[C] [=N] [C] [=C] [C] [Branch1] [#Branch1] [C] [=Branch1] [C] [=O] [N] [N] [=C] [Ring1] [#Branch2],-0.7170000000000001,c1nccc(C(=O)NN)c1
|
76 |
+
74,[C] [C] [Branch1] [C] [C] [C] [C] [C] [C] [Branch1] [C] [C] [Branch1] [Ring2] [C] [Ring1] [=Branch1] [C] [Ring1] [=Branch2] [=O],-2.158,CC2(C)C1CCC(C)(C1)C2=O
|
77 |
+
75,[C] [O] [C] [=C] [N] [=C] [C] [=N] [C] [=N] [C] [Ring1] [=Branch1] [=N] [Ring1] [#Branch2],-1.589,COc2cnc1cncnc1n2
|
78 |
+
76,[C] [N] [C] [=Branch1] [C] [=O] [C] [=C] [Branch1] [C] [C] [O] [P] [=Branch1] [C] [=O] [Branch1] [Ring1] [O] [C] [O] [C],-0.949,CNC(=O)C=C(C)OP(=O)(OC)OC
|
79 |
+
77,[O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [N] [Branch1] [Ring1] [C] [C] [C] [=Branch1] [C] [=O] [C] [=C] [C] [=C] [C] [=C] [Ring2] [Ring1] [C] [Ring1] [=Branch1],-3.784,O2c1ccccc1N(CC)C(=O)c3ccccc23
|
80 |
+
78,[C] [=C] [C] [=C] [C] [=C] [Branch1] [Ring1] [O] [C] [C] [Branch1] [Branch2] [C] [C] [=C] [Branch1] [C] [C] [C] [=C] [Ring1] [=N] [O] [C] [Ring1] [P] [=O],-4.0760000000000005,c1cc2ccc(OC)c(CC=C(C)(C))c2oc1=O
|
81 |
+
79,[C] [C] [C] [S] [C] [C] [C],-2.307,CCCSCCC
|
82 |
+
80,[C] [O] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-2.948,CON(C)C(=O)Nc1ccc(Cl)cc1
|
83 |
+
81,[C] [C] [O] [C] [C],-0.718,CCOCC
|
84 |
+
82,[C] [C] [C] [C] [C] [C] [Branch1] [S] [C] [C] [C] [=C] [C] [Branch1] [C] [O] [=C] [C] [=C] [Ring1] [O] [Ring1] [#Branch1] [C] [Ring1] [#C] [C] [C] [Branch1] [C] [O] [C] [Ring2] [Ring1] [Ring1] [O],-3.858,CC34CCC1C(CCc2cc(O)ccc12)C3CC(O)C4O
|
85 |
+
83,[C] [C] [N] [C] [=N] [C] [Branch1] [C] [Cl] [=N] [C] [Branch1] [O] [N] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C] [#N] [=N] [Ring1] [=N],-2.49,CCNc1nc(Cl)nc(NC(C)(C)C#N)n1
|
86 |
+
84,[C] [C] [Branch1] [C] [C] [C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [O],-1.6469999999999998,CC(C)CC(C)(C)O
|
87 |
+
85,[Cl] [C] [=C] [C] [=C] [C] [Branch1] [C] [Br] [=C] [Ring1] [#Branch1],-3.928,Clc1cccc(Br)c1
|
88 |
+
86,[C] [C] [C] [C] [C] [C] [Branch1] [C] [O] [C] [C],-2.033,CCCCCC(O)CC
|
89 |
+
87,[O] [=C] [N] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2] [Branch1] [Ring1] [C] [C] [C] [C] [=C] [Branch1] [C] [C] [C],-2.126,O=C1NC(=O)NC(=O)C1(CC)CC=C(C)C
|
90 |
+
88,[C] [C] [C] [Branch1] [C] [C] [C] [Branch1] [#Branch1] [C] [C] [Branch1] [C] [Br] [=C] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [Ring1] [N] [=O],-2.766,CCC(C)C1(CC(Br)=C)C(=O)NC(=O)NC1=O
|
91 |
+
89,[C] [O] [C] [=Branch1] [C] [=O] [C],-0.416,COC(=O)C
|
92 |
+
90,[C] [C] [Branch1] [C] [C] [C] [=C] [C] [=C] [Branch1] [C] [C] [C] [=C] [Ring1] [#Branch1] [O],-3.129,CC(C)c1ccc(C)cc1O
|
93 |
+
91,[C],-0.636,C
|
94 |
+
92,[N] [C] [=N] [C] [Branch1] [C] [O] [=N] [C] [N] [=C] [NH1] [C] [Ring1] [#Branch2] [=Ring1] [Branch1],-1.74,Nc1nc(O)nc2nc[nH]c12
|
95 |
+
93,[F] [C] [=C] [C] [=C] [C] [Branch1] [C] [F] [=C] [Ring1] [#Branch1] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1],-4.692,Fc1cccc(F)c1C(=O)NC(=O)Nc2ccc(Cl)cc2
|
96 |
+
94,[C] [C] [C] [C] [C] [Branch1] [Branch1] [C] [C] [Ring1] [=Branch1] [C] [Branch1] [C] [C] [Branch1] [C] [C] [O] [Ring1] [#Branch2],-2.579,CC12CCC(CC1)C(C)(C)O2
|
97 |
+
95,[C] [C] [O],0.02,CCO
|
98 |
+
96,[C] [=C] [Branch2] [Ring1] [C] [N] [C] [=Branch1] [C] [=O] [O] [C] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [C] [C] [=C] [C] [=C] [Ring1] [P],-2.29,c1c(NC(=O)OC(C)C(=O)NCC)cccc1
|
99 |
+
97,[C] [C] [Branch1] [C] [C] [=C] [C] [C] [Branch2] [Ring1] [#Branch2] [C] [=Branch1] [C] [=O] [O] [C] [C] [=C] [C] [=C] [C] [Branch1] [#Branch2] [O] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [=C] [Ring1] [=N] [C] [Ring2] [Ring1] [Ring2] [Branch1] [C] [C] [C],-6.763,CC(C)=CC3C(C(=O)OCc2cccc(Oc1ccccc1)c2)C3(C)C
|
100 |
+
98,[C] [C] [C] [C] [N] [C] [=Branch1] [C] [=O] [N] [C] [Branch1] [Branch2] [N] [C] [=Branch1] [C] [=O] [O] [C] [=N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=C] [Ring1] [=Branch1],-2.902,CCCCNC(=O)n1c(NC(=O)OC)nc2ccccc12
|
101 |
+
99,[C] [N] [Branch1] [C] [C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1],-2.542,CN(C)c1ccccc1
|
102 |
+
100,[C] [O] [C] [=Branch1] [C] [=O] [C] [=C],-0.878,COC(=O)C=C
|
103 |
+
101,[C] [N] [Branch1] [C] [C] [C] [=Branch1] [C] [=O] [N] [C] [=C] [C] [=C] [Branch1] [=N] [O] [C] [=C] [C] [=C] [Branch1] [C] [Cl] [C] [=C] [Ring1] [#Branch1] [C] [=C] [Ring1] [=C],-4.477,CN(C)C(=O)Nc2ccc(Oc1ccc(Cl)cc1)cc2
|
104 |
+
102,[O] [=C] [N] [C] [=Branch1] [C] [=O] [N] [C] [=Branch1] [C] [=O] [C] [Ring1] [Branch2] [Branch1] [=Branch1] [C] [Branch1] [C] [C] [C] [C] [C] [=C] [Branch1] [C] [C] [C],-2.465,O=C1NC(=O)NC(=O)C1(C(C)C)CC=C(C)C
|
105 |
+
103,[C] [C] [=C] [C] [=C] [Branch1] [C] [O] [C] [=C] [Ring1] [#Branch1] [C],-2.6210000000000004,Cc1ccc(O)cc1C
|
106 |
+
104,[Cl] [C] [=C] [C] [=C] [C] [=Branch1] [Ring2] [=N] [Ring1] [=Branch1] [C] [Branch1] [C] [Cl] [Branch1] [C] [Cl] [Cl],-3.833,Clc1cccc(n1)C(Cl)(Cl)Cl
|
107 |
+
105,[C] [C] [=Branch1] [C] [=O] [O] [C] [Branch2] [Ring1] [=C] [C] [C] [C] [C] [C] [C] [C] [=C] [C] [=Branch1] [C] [=O] [C] [C] [C] [Ring1] [#Branch1] [C] [Ring1] [O] [C] [C] [C] [Ring2] [Ring1] [C] [Ring1] [#C] [C] [C] [#C],-4.2410000000000005,CC(=O)OC3(CCC4C2CCC1=CC(=O)CCC1C2CCC34C)C#C
|
108 |
+
106,[C] [N] [C] [=Branch1] [C] [=O] [O] [N] [=C] [Branch1] [Ring2] [C] [S] [C] [C] [Branch1] [C] [C] [Branch1] [C] [C] [C],-2.7,CNC(=O)ON=C(CSC)C(C)(C)C
|
109 |
+
107,[C] [C] [C] [C] [C] [C] [C] [Branch1] [C] [C] [O],-2.033,CCCCCCC(C)O
|
data/esol/train.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/lce/test.csv
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
smi1,conc1,smi2,conc2,smi3,conc3,smi4,conc4,smi5,conc5,smi6,conc6,LCE
|
2 |
+
C1C(OC(=O)O1)F,0.733,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.267,O,0.0,O,0.0,O,0.0,O,0.0,1.629
|
3 |
+
C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.431,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,O,0.0,1.085
|
4 |
+
COC(=O)OC,0.299,C(C(F)(F)F)OCC(F)(F)F,0.598,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.103,O,0.0,O,0.0,O,0.0,2.056
|
5 |
+
COCCOC,0.358,O1CCOC1,0.532,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.074,[Li+].[N+](=O)([O-])[O-],,O,0.0,O,0.0,1.658
|
6 |
+
C1COC(=O)O1,0.197,COC(=O)OC,0.156,COCCOCCOCCOCCOC,0.59,[Li+].F[P-](F)(F)(F)(F)F,0.026,[Li+].[N+](=O)([O-])[O-],0.031,O,0.0,1.638
|
7 |
+
C1COC(=O)O1,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.276
|
8 |
+
O1CCOC1,0.368,COCCOC,0.547,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.076,CSi(C)(C)([N+]).C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)C(F)(F)C(F)(F)C(F)(F)F,0.008,O,0.0,O,0.0,1.569
|
9 |
+
COCCOC,0.507,COC(C(F)(F)F)C(F)(F)F,0.399,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.095,O,0.0,O,0.0,O,0.0,2.268
|
10 |
+
C1COC(=O)O1,0.425,O=C(OCC)OCC(F)(F)F,0.481,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0.0,O,0.0,O,0.0,1.602
|
11 |
+
C1C(OC(=O)O1)F,0.318,CCOC(=O)OC,0.504,COC(=O)OC,0.094,B(O[Si](C)(C)C)(O[Si](C)(C)C)O[Si](C)(C),0.083,[Li+].F[P-](F)(F)(F)(F)F,0.001,O,0.0,1.678
|
12 |
+
O=S1(=O)CCCC1,0.359,C(C(F)(F)F)OC(C(F)F)(F)F,0.504,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.133,[Li+].[N+](=O)([O-])[O-],0.004,O,0.0,O,0.0,2.0
|
13 |
+
C1COC(=O)O1,0.594,O=C(OCC)OCC,0.327,[Li+].F[P-](F)(F)(F)(F)F,0.079,O,0.0,O,0.0,O,0.0,0.921
|
14 |
+
C1COC(=O)O1,0.331,O=C(OCC)OCC,0.577,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.092,O,0.0,O,0.0,O,0.0,1.301
|
15 |
+
C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].C(C(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(C(F)(F)F)(F)F)(F)(F)F,0.069,O,0.0,O,0.0,0.854
|
16 |
+
C1C(OC(=O)O1)F,0.107,C1COC(=O)O1,0.526,O=C(OCC)OCC,0.289,[Li+].F[P-](F)(F)(F)(F)F,0.078,O,0.0,O,0.0,1.108
|
17 |
+
O1CCOC1,0.322,COCCOC,0.478,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)C(F)(F)C(F)(F)C(F)(F)F,0.2,O,0.0,O,0.0,O,0.0,1.523
|
18 |
+
CC1COC(=O)O1,0.595,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.405,O,0.0,O,0.0,O,0.0,O,0.0,1.921
|
19 |
+
CC1COC(=O)O1,0.702,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.298,O,0.0,O,0.0,O,0.0,O,0.0,1.602
|
20 |
+
O1CCOC1,0.375,COCCOC,0.557,[Li+][S-]SSS[S-][Li+],,[Li+].[N+](=O)([O-])[O-],0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.061,O,0.0,1.523
|
21 |
+
COC(=O)OC,0.161,FC(F)C(F)(F)COC(F)(F)C(F)F,0.355,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.484,O,0.0,O,0.0,O,0.0,2.155
|
22 |
+
C1COC(=O)O1,0.338,COC(=O)OC,0.625,[Li+].[O-]P(=O)(F)F,0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.03,O,0.0,O,0.0,1.26
|
23 |
+
CN(C)C(=O)C(F)(F)F,0.362,C1C(OC(=O)O1)F,0.556,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.081,O,0.0,O,0.0,O,0.0,2.155
|
24 |
+
C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.0,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.225
|
25 |
+
COCCOC,0.231,FC1CCCCC1,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,2.155
|
26 |
+
COCCOC,0.277,FC(F)C(F)(F)COC(F)(F)C(F)F,0.555,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.168,O,0.0,O,0.0,O,0.0,2.155
|
27 |
+
O1C(C)CCC1,0.331,FC(F)C(F)(F)COC(F)(F)C(F)F,0.498,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.171,O,0.0,O,0.0,O,0.0,2.301
|
28 |
+
COCC(F)(F)C(F)(F)COC,0.864,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.136,O,0.0,O,0.0,O,0.0,O,0.0,1.991
|
29 |
+
COC(=O)OC,0.29,C(C(F)(F)F)OCC(F)(F)F,0.589,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0.0,O,0.0,O,0.0,2.301
|
30 |
+
C1COC(=O)O1,0.425,O=C(OCC)OCC,0.234,[Li+].F[P-](F)(F)(F)(F)F,0.34,O,0.0,O,0.0,O,0.0,1.398
|
31 |
+
COCCOC,0.707,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.147,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.147,O,0.0,O,0.0,O,0.0,1.268
|
data/lce/test_data.csv
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
smiles1,conc1,mol1,smiles2,conc2,mol2,smiles3,conc3,mol3,smiles4,conc4,mol4,smiles5,conc5,mol5,smiles6,conc6,LCE_Predicted,LCE
|
2 |
+
C1COC(=O)O1,0.519,51.92400559,COC(=O)OC,0.411,41.14791596,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.069,6.928078454,O,0,0,O,0,0,O,0,1.187,1.094
|
3 |
+
COCCOC,0.596,59.5609428,COCCOCCOCCOCCOC,0.281,28.07124115,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.124,12.36781605,O,0,0,O,0,0,O,0,1.691,1.384
|
4 |
+
C1COC(=O)O1,0.285,28.50894036,C1C(OC(=O)O1)F,0.261,26.07552384,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.228,22.82322096,COC(=O)OC,0.226,22.59231484,O,0,0,O,0,1.508,1.468
|
5 |
+
COCCOC,0.434,43.4423376,COCCOCCOCCOCCOC,0.205,20.47449683,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.361,36.08316557,O,0,0,O,0,0,O,0,1.882,1.71
|
6 |
+
C1C(OC(=O)O1)F,0.187,18.72872664,COC(=O)OC,0.162,16.22691423,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.109,10.92850826,FC(F)C(F)(F)COC(F)(F)C(F)F,0.541,54.11585087,O,0,0,O,0,2.103,1.832
|
7 |
+
C1COC(=O)O1,0.134,13.35070843,C1C(OC(=O)O1)F,0.122,12.2111419,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.107,10.72028474,COC(=O)OC,0.106,10.57995858,FC(F)C(F)(F)COC(F)(F)C(F)F,0.531,53.13790635,O,0,2.077,2.104
|
8 |
+
COCCOC,0.096,9.614613177,COCCOCCOCCOCCOC,0.045,4.53139444,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.12,12.01491409,C1COCO1,0.143,14.28400162,FC(F)C(F)(F)COC(F)(F)C(F)F,0.596,59.55507668,O,0,2.211,2.274
|
9 |
+
C1COC(=O)O1,0.519,51.92400559,COC(=O)OC,0.411,41.14791596,[Li+].F[P-](F)(F)(F)(F)F,0.069,6.928078454,O,0,0,O,0,0,O,0,1.17,1.071
|
10 |
+
C1COC(=O)O1,0.519,51.92400559,COC(=O)OC,0.411,41.14791596,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.069,6.928078454,O,0,0,O,0,0,O,0,1.077,1.166
|
11 |
+
C1COC(=O)O1,0.519,51.85215842,COC(=O)OC,0.411,41.09097965,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.069,6.918492083,[Li+].[N+](=O)([O-])[O-],0.001,0.138369842,O,0,0,O,0,1.19,1.335
|
12 |
+
C1COC(=O)O1,0.513,51.33049845,COC(=O)OC,0.407,40.6775828,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.069,6.9173773,C1=COC(=O)O1,0.011,1.07454145,O,0,0,O,0,1.114,1.129
|
13 |
+
COCCOC,0.53,53.00533987,COCCOCCOCCOCCOC,0.25,24.98156691,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.22,22.01309322,O,0,0,O,0,0,O,0,1.758,1.501
|
14 |
+
COCCOC,0.477,47.74974224,COCCOCCOCCOCCOC,0.225,22.50458884,[Li+].[N-](S(=O)(=O)F)S(=O)(=O)F,0.297,29.74566892,O,0,0,O,0,0,O,0,1.821,1.663
|
data/lce/train.csv
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
smi1,conc1,smi2,conc2,smi3,conc3,smi4,conc4,smi5,conc5,smi6,conc6,LCE
|
2 |
+
C1COC(=O)O1,0.327,O=C(OCC)OCC,0.594,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.079,O,0.0,O,0.0,O,0.0,1.155
|
3 |
+
C1COC(=O)O1,0.356,COC(=O)OC,0.566,FC(F)(F)COB(OCC(F)(F)F)OCC(F)(F)F,0.007,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.046
|
4 |
+
O=S1(=O)CCCC1,0.25,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.75,O,0.0,O,0.0,O,0.0,O,0.0,1.569
|
5 |
+
C1COC(=O)O1,0.331,O=C(OCC)OCC,0.577,[Li+].F[P-](F)(F)(F)(F)F,0.092,O,0.0,O,0.0,O,0.0,0.886
|
6 |
+
COCCOC,0.763,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.237,O,0.0,O,0.0,O,0.0,O,0.0,1.367
|
7 |
+
COCCOC,0.2,FC(F)C(F)(F)COC(F)(F)C(F)F,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.2,O,0.0,O,0.0,O,0.0,2.301
|
8 |
+
C1C(OC(=O)O1)F,0.873,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.127,O,0.0,O,0.0,O,0.0,O,0.0,1.489
|
9 |
+
COCCOC,0.706,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.008,[Li+].[O-]P(=O)(F)F,0.286,O,0.0,O,0.0,O,0.0,1.244
|
10 |
+
C1COC(=O)O1,0.3,CCOC(=O)OC,0.593,C1=COC(=O)O1,0.026,[Li+].F[P-](F)(F)(F)(F)F,0.081,O,0.0,O,0.0,0.745
|
11 |
+
COCCOC,0.763,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.174,[Li+].[O-]P(=O)(F)F,0.063,O,0.0,O,0.0,O,0.0,1.292
|
12 |
+
CCOCC,0.313,C(C(F)(F)F)OCC(F)(F)F,0.51,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.177,O,0.0,O,0.0,O,0.0,2.301
|
13 |
+
O=S1(=O)CCCC1,0.75,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.25,O,0.0,O,0.0,O,0.0,O,0.0,1.745
|
14 |
+
COC(=O)OC,0.29,C(C(F)(F)F)OCC(F)(F)F,0.589,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0.0,O,0.0,O,0.0,1.745
|
15 |
+
C1COC(=O)O1,0.682,CCOC(=O)OC,0.247,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.043,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.028,O,0.0,O,0.0,1.076
|
16 |
+
C1COC(=O)O1,0.359,COC(=O)OC,0.569,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,O,0.0,0.854
|
17 |
+
C1COC(=O)O1,0.305,COC(=O)OC,0.242,COCCOCCOCCOCCOC,0.392,[Li+].F[P-](F)(F)(F)(F)F,0.041,[Li+].[N+](=O)([O-])[O-],0.02,O,0.0,1.678
|
18 |
+
FC(F)(F)COCCOCC,0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0.0,O,0.0,O,0.0,O,0.0,2.155
|
19 |
+
CC#N,0.882,FC,0.065,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,,O,0.0,O,0.0,O,0.0,2.222
|
20 |
+
COC(C)C(C)OC,0.879,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0.0,O,0.0,O,0.0,O,0.0,1.638
|
21 |
+
CCOP(=O)(OCC)OCC,0.728,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.272,O,0.0,O,0.0,O,0.0,O,0.0,2.0
|
22 |
+
COC(=O)OC,0.375,FC(F)C(F)(F)COC(F)(F)C(F)F,0.375,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.25,O,0.0,O,0.0,O,0.0,1.854
|
23 |
+
O1CCOC1,0.371,COCCOC,0.552,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.077,O,0.0,O,0.0,O,0.0,1.959
|
24 |
+
C1C(OC(=O)O1)F,0.774,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.226,O,0.0,O,0.0,O,0.0,O,0.0,1.587
|
25 |
+
CC1COC(=O)O1,0.875,C1C(OC(=O)O1)F,0.051,[Li+].[O-]Cl(=O)(=O)=O,0.074,O,0.0,O,0.0,O,0.0,0.699
|
26 |
+
C1C(OC(=O)O1)F,0.264,COC(=O)OCCF,0.479,C(C(F)(F)F)OC(C(F)F)(F)F,0.155,[Li+].F[P-](F)(F)(F)(F)F,0.103,O,0.0,O,0.0,2.097
|
27 |
+
C1C(OC(=O)O1)F,0.413,O=C(OCC)OCC,0.497,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.09,O,0.0,O,0.0,O,0.0,1.59
|
28 |
+
C1C(OC(=O)O1)F,0.106,C1COC(=O)O1,0.522,O=C(OCC)OCC,0.287,[Li+].F[P-](F)(F)(F)(F)F,0.077,[Rb+].[O-][N+]([O-])=O,0.004,O1CCOCCOCCOCCOCCOCC1,0.004,1.252
|
29 |
+
COCCOC,0.259,B(OCC(F)(F)F)(OCC(F)(F)F)OCC(F)(F)F,0.556,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.185,O,0.0,O,0.0,O,0.0,1.337
|
30 |
+
C1CCOC1,0.925,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.075,O,0.0,O,0.0,O,0.0,O,0.0,1.377
|
31 |
+
C1C(OC(=O)O1)F,0.82,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.18,O,0.0,O,0.0,O,0.0,O,0.0,1.544
|
32 |
+
CCOP(=O)(OCC)OCC,0.5,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.5,O,0.0,O,0.0,O,0.0,O,0.0,2.097
|
33 |
+
COCCOC,0.731,[Li+].[O-]P(=O)(F)F,0.064,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.205,O,0.0,O,0.0,O,0.0,1.215
|
34 |
+
COCCOCCOCCOCCOC,0.819,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.181,O,0.0,O,0.0,O,0.0,O,0.0,1.222
|
35 |
+
C1COC(=O)O1,0.338,COC(=O)OC,0.625,[Li+].[O-]P(=O)(F)F,0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.03,O,0.0,O,0.0,1.194
|
36 |
+
O1CCOC1,0.463,COCCOC,0.312,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.194,[Li+].[N+](=O)([O-])[O-],0.03,O,0.0,O,0.0,1.824
|
37 |
+
C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.333
|
38 |
+
O1CCOC1,0.539,COCCOC,0.363,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.075,[Li+].[N+](=O)([O-])[O-],0.023,O,0.0,O,0.0,1.824
|
39 |
+
COCCOC,0.257,C(C(F)(F)F)OCC(F)(F)F,0.508,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.235,O,0.0,O,0.0,O,0.0,2.051
|
40 |
+
COCCOC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.047,[Li+].FP(F)(=O)([O-]),0.047,O,0.0,O,0.0,O,0.0,1.444
|
41 |
+
O1CCOC1,0.478,COCCOC,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.134,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.067,O,0.0,O,0.0,1.854
|
42 |
+
CCOCC,0.707,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.293,O,0.0,O,0.0,O,0.0,O,0.0,2.046
|
43 |
+
C1COC(=O)O1,0.563,O=C(OCC)OCC,0.31,C1C(OC(=O)O1)F,0.052,[Li+].F[P-](F)(F)(F)(F)F,0.075,O,0.0,O,0.0,1.301
|
44 |
+
C1CCOC1,0.942,FC,0.029,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,,O,0.0,O,0.0,O,0.0,2.222
|
45 |
+
O1CCOC1,0.478,COCCOC,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.2,O,0.0,O,0.0,O,0.0,1.903
|
46 |
+
COCCOC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0.0,O,0.0,O,0.0,O,0.0,1.561
|
47 |
+
C1C(OC(=O)O1)F,0.149,COC(=O)OCCF,0.178,C(C(F)(F)F)OC(C(F)F)(F)F,0.564,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.108,O,0.0,O,0.0,1.735
|
48 |
+
FC(F)COCCOCC(F)(F),0.845,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.155,O,0.0,O,0.0,O,0.0,O,0.0,2.301
|
49 |
+
C1C(OC(=O)O1)F,0.495,COC(=O)OC,0.429,O1CCOCCOCCOCC1,0.003,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.498
|
50 |
+
C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.069,O,0.0,O,0.0,0.745
|
51 |
+
O=S1(=O)CCCC1,0.758,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.235,[Li+].[N+](=O)([O-])[O-],0.007,O,0.0,O,0.0,O,0.0,1.824
|
52 |
+
CCOCC,0.856,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.144,O,0.0,O,0.0,O,0.0,O,0.0,2.0
|
53 |
+
O=C(OCC)C,0.105,ClCCl,0.64,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.255,O,0.0,O,0.0,O,0.0,1.456
|
54 |
+
COCCOCCOCC(F)(F)OC(F)(F)OC(F)(F)COCCOCCOC,0.708,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.292,O,0.0,O,0.0,O,0.0,O,0.0,1.301
|
55 |
+
COCCOC,0.583,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.278,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.139,O,0.0,O,0.0,O,0.0,1.678
|
56 |
+
C1C(OC(=O)O1)F,0.662,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.338,O,0.0,O,0.0,O,0.0,O,0.0,1.646
|
57 |
+
O1CCOC1,0.397,COCCOC,0.589,[Li+][S-]SSS[S-][Li+],,[Li+].[N+](=O)([O-])[O-],0.012,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.002,O,0.0,1.301
|
58 |
+
C1COC(=O)O1,0.308,O=C(OCC)OCC(F)(F)F,0.349,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.343,O,0.0,O,0.0,O,0.0,2.046
|
59 |
+
C1COC(=O)O1,0.362,O=C(OCC)OCC,0.548,[Li+].F[P-](F)(F)(F)(F)F,0.09,O,0.0,O,0.0,O,0.0,0.788
|
60 |
+
C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.001,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.373
|
61 |
+
O1CCOCC1,0.912,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.088,O,0.0,O,0.0,O,0.0,O,0.0,1.602
|
62 |
+
CC#N,0.621,C1=COC(=O)O1,0.056,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.323,O,0.0,O,0.0,O,0.0,1.854
|
63 |
+
COC(=O)OC,0.684,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.316,O,0.0,O,0.0,O,0.0,O,0.0,2.097
|
64 |
+
O=S1(=O)CCCC1,0.714,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.286,O,0.0,O,0.0,O,0.0,O,0.0,1.699
|
65 |
+
FC(F)(F)COCCOCC(F)(F)(F),0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0.0,O,0.0,O,0.0,O,0.0,2.155
|
66 |
+
CCOCC,0.64,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.36,O,0.0,O,0.0,O,0.0,O,0.0,2.208
|
67 |
+
COC(=O)OC,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.4,O,0.0,O,0.0,O,0.0,O,0.0,1.77
|
68 |
+
CC1COC(=O)O1,0.887,[Li+].F[As-](F)(F)(F)(F)F,0.113,O,0.0,O,0.0,O,0.0,O,0.0,0.824
|
69 |
+
C1COC(=O)O1,0.5,CCOC(=O)OC,0.423,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.046,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.031,O,0.0,O,0.0,0.924
|
70 |
+
CCOP(=O)(OCC)OCC,0.214,C(C(F)(F)F)OCC(F)(F)F,0.642,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.144,O,0.0,O,0.0,O,0.0,2.097
|
71 |
+
COCCOC,0.682,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.318,O,0.0,O,0.0,O,0.0,O,0.0,2.108
|
72 |
+
CC1COC(=O)O1,0.922,[LI+].F[B-](F)(F)OC(C(F)(F)(F))(C(F)(F)(F))C(F)(F)(F),0.078,O,0.0,O,0.0,O,0.0,O,0.0,0.712
|
73 |
+
C1COC(=O)O1,0.854,CCOC(=O)OC,0.08,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.039,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.026,O,0.0,O,0.0,1.081
|
74 |
+
C1COC(=O)O1,0.519,O=C(OCC)OCC,0.387,[Li+].F[P-](F)(F)(F)(F)F,0.082,[Li+].[O-]P(=O)(F)F,0.012,O,0.0,O,0.0,1.319
|
75 |
+
COC(=O)CC(F)(F)F,0.768,C1C(OC(=O)O1)F,0.134,[Li+].F[P-](F)(F)(F)(F)F,0.098,O,0.0,O,0.0,O,0.0,1.62
|
76 |
+
C1C(OC(=O)O1)F,0.144,COC(=O)OCCF,0.173,C(C(F)(F)F)OC(C(F)F)(F)F,0.548,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.135,O,0.0,O,0.0,2.222
|
77 |
+
C1COC(=O)O1,0.326,COC(=O)OC,0.602,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,O,0.0,0.777
|
78 |
+
CCOCC,0.877,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.123,O,0.0,O,0.0,O,0.0,O,0.0,2.018
|
79 |
+
COC(=O)OC,0.664,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.336,O,0.0,O,0.0,O,0.0,O,0.0,1.886
|
80 |
+
C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].F[B-](F)(F)F,0.069,O,0.0,O,0.0,0.699
|
81 |
+
CCOP(=O)(OCC)OCC,0.648,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.352,O,0.0,O,0.0,O,0.0,O,0.0,1.569
|
82 |
+
C1C(OC(=O)O1)F,0.481,O=C(OCC)OCC,0.432,[Li+].F[P-](F)(F)(F)(F)F,0.087,O,0.0,O,0.0,O,0.0,1.523
|
83 |
+
COCCOC,0.231,FC(F)C(F)(F)COC(F)(F)C(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,2.155
|
84 |
+
C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.001,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.488
|
85 |
+
O1CCOC1,0.453,COCCOC,0.305,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.127,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.063,[Li+].[N+](=O)([O-])[O-],0.051,O,0.0,2.046
|
86 |
+
C1C(OC(=O)O1)F,0.932,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.068,O,0.0,O,0.0,O,0.0,O,0.0,1.41
|
87 |
+
COCCOC,0.139,COCC(F)(F)C(F)(F)C(F)(F)C(F)(F)COC,0.692,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.169,O,0.0,O,0.0,O,0.0,2.222
|
88 |
+
C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.431,O1CCOCCOCCOCC1,0.0,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.559
|
89 |
+
COCCOC,0.231,FC(COC(OCC(F)(F)F)OCC(F)(F)F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,2.301
|
90 |
+
CN(C)S(=O)(=O)F,0.921,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.079,O,0.0,O,0.0,O,0.0,O,0.0,1.672
|
91 |
+
C1C(OC(=O)O1)F,0.105,C1COC(=O)O1,0.518,O=C(OCC)OCC,0.285,[Li+].F[P-](F)(F)(F)(F)F,0.077,[Rb+].[O-][N+]([O-])=O,0.008,O1CCOCCOCCOCCOCCOCC1,0.008,1.538
|
92 |
+
CC1CCC(C)O1,0.893,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.107,O,0.0,O,0.0,O,0.0,O,0.0,1.796
|
93 |
+
C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0.0,O,0.0,1.355
|
94 |
+
C1COC(=O)O1,0.444,C1COS(=O)O1,0.497,[Li+].[O-]Cl(=O)(=O)=O,0.059,O,0.0,O,0.0,O,0.0,1.523
|
95 |
+
COCCOC,0.371,O1CCOC1,0.552,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.031,[Li+].[N+](=O)([O-])[O-],0.046,O,0.0,O,0.0,1.78
|
96 |
+
O=S1(=O)CCCC1,0.764,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.236,O,0.0,O,0.0,O,0.0,O,0.0,1.456
|
97 |
+
O1C(C)CCC1,0.908,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.092,O,0.0,O,0.0,O,0.0,O,0.0,1.745
|
98 |
+
O1CCOC1,0.362,C(C(F)(F)F)OCC(F)(F)F,0.59,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.048,O,0.0,O,0.0,O,0.0,1.967
|
99 |
+
COC(=O)OC,0.543,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.457,O,0.0,O,0.0,O,0.0,O,0.0,2.097
|
100 |
+
COCCOC,0.73,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.27,O,0.0,O,0.0,O,0.0,O,0.0,1.143
|
101 |
+
O1CCOC1,0.552,COCCOC,0.371,[Li+].[N+](=O)([O-])[O-],0.039,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.039,O,0.0,O,0.0,1.523
|
102 |
+
COCCOC,0.242,FC(COC(OCC(F)(F)F)OCC(F)(F)F)(F)F,0.604,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.154,O,0.0,O,0.0,O,0.0,2.301
|
103 |
+
CCOP(=O)(OCC)OCC,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.4,O,0.0,O,0.0,O,0.0,O,0.0,2.155
|
104 |
+
C1C(OC(=O)O1)F,0.318,CCOC(=O)OC,0.504,COC(=O)OC,0.094,[Li+].F[P-](F)(F)(F)(F)F,0.083,O,0.0,O,0.0,1.301
|
105 |
+
COCCOC,0.231,C(C(F)(F)F)OCC(F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,2.222
|
106 |
+
C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].F[P-](F)(F)(F)(F)F,0.069,O,0.0,O,0.0,0.699
|
107 |
+
COCCOC,0.231,C(C(F)(F)F)OC(=O)OCC(F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0.0,O,0.0,O,0.0,1.495
|
108 |
+
C1COC(=O)O1,0.32,COC(=O)OC,0.253,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.427,O,0.0,O,0.0,O,0.0,2.155
|
109 |
+
C1C(OC(=O)O1)F,0.312,O=C1OCCC1,0.599,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.068,[Li+].[N+](=O)([O-])[O-],0.021,O,0.0,O,0.0,1.921
|
110 |
+
COC(=O)OC,0.478,FC(F)C(F)(F)COC(F)(F)C(F)F,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.067,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.134,O,0.0,O,0.0,1.886
|
111 |
+
CCOP(=O)(OCC)OCC,0.259,FC(F)C(F)(F)COC(F)(F)C(F)F,0.556,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.185,O,0.0,O,0.0,O,0.0,2.046
|
112 |
+
COCCOC,0.677,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.323,O,0.0,O,0.0,O,0.0,O,0.0,1.745
|
113 |
+
C1C(OC(=O)O1)F,0.696,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.304,O,0.0,O,0.0,O,0.0,O,0.0,1.633
|
114 |
+
C1CCOC1,0.47,O1C(C)CCC1,0.378,[Li+].F[P-](F)(F)(F)(F)F,0.152,O,0.0,O,0.0,O,0.0,2.097
|
115 |
+
FC(F)COCCOCC(F)(F)(F),0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0.0,O,0.0,O,0.0,O,0.0,2.301
|
116 |
+
C1COC(=O)O1,0.496,COC(=O)OC,0.393,C1C(OC(=O)O1)F,0.045,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.066,O,0.0,O,0.0,1.108
|
117 |
+
C1C(OC(=O)O1)F,0.62,C(C(F)(F)F)OC(=O)OCC(F)(F)F,0.291,[Li+].F[P-](F)(F)(F)(F)F,0.089,O,0.0,O,0.0,O,0.0,1.62
|
118 |
+
CCOCC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0.0,O,0.0,O,0.0,O,0.0,1.959
|
119 |
+
C1COC(=O)O1,0.526,O=C(OCC)OCC,0.392,[Li+].F[P-](F)(F)(F)(F)F,0.083,O,0.0,O,0.0,O,0.0,1.013
|
120 |
+
C1COC(=O)O1,0.05,CCOC(=O)OC,0.237,C(C(F)(F)F)OCC(F)(F)F,0.575,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.123,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.015,O,0.0,1.824
|
121 |
+
O=S1(=O)CCCC1,0.429,FC(F)C(F)(F)COC(F)(F)C(F)F,0.429,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.143,O,0.0,O,0.0,O,0.0,1.921
|
data/lce/train_data.csv
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
smiles1,conc1,smiles2,conc2,smiles3,conc3,smiles4,conc4,smiles5,conc5,smiles6,conc6,LCE
|
2 |
+
CC1COC(=O)O1,0.875,C1C(OC(=O)O1)F,0.051,[Li+].[O-]Cl(=O)(=O)=O,0.074,O,0,O,0,O,0,0.699
|
3 |
+
C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].F[P-](F)(F)(F)(F)F,0.069,O,0,O,0,0.699
|
4 |
+
FC(F)COCCOCC(F)(F),0.845,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.155,O,0,O,0,O,0,O,0,2.301
|
5 |
+
FC(F)COCCOCC(F)(F)(F),0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0,O,0,O,0,O,0,2.301
|
6 |
+
CN(C)C(=O)C(F)(F)F,0.362,C1C(OC(=O)O1)F,0.556,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.081,O,0,O,0,O,0,2.155
|
7 |
+
COCCOC,0.231,FC1CCCCC1,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,2.155
|
8 |
+
CCOP(=O)(OCC)OCC,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.4,O,0,O,0,O,0,O,0,2.155
|
9 |
+
O1CCOC1,0.362,C(C(F)(F)F)OCC(F)(F)F,0.59,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.048,O,0,O,0,O,0,1.967
|
10 |
+
COCC(F)(F)C(F)(F)COC,0.864,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.136,O,0,O,0,O,0,O,0,1.991
|
11 |
+
C1C(OC(=O)O1)F,0.662,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.338,O,0,O,0,O,0,O,0,1.646
|
12 |
+
COCCOC,0.358,O1CCOC1,0.532,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.074,[Li+].[N+](=O)([O-])[O-],0.035,O,0,O,0,1.658
|
13 |
+
CN(C)S(=O)(=O)F,0.921,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.079,O,0,O,0,O,0,O,0,1.672
|
14 |
+
C1C(OC(=O)O1)F,0.106,C1COC(=O)O1,0.522,O=C(OCC)OCC,0.287,[Li+].F[P-](F)(F)(F)(F)F,0.077,[Rb+].[O-][N+]([O-])=O,0.004,O1CCOCCOCCOCCOCCOCC1,0.004,1.252
|
15 |
+
C1COC(=O)O1,0.32,COC(=O)OC,0.253,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.427,O,0,O,0,O,0,2.155
|
16 |
+
COCCOC,0.277,FC(F)C(F)(F)COC(F)(F)C(F)F,0.555,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.168,O,0,O,0,O,0,2.155
|
17 |
+
COC(=O)OC,0.161,FC(F)C(F)(F)COC(F)(F)C(F)F,0.355,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.484,O,0,O,0,O,0,2.155
|
18 |
+
FC(F)(F)COCCOCC,0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0,O,0,O,0,O,0,2.155
|
19 |
+
FC(F)(F)COCCOCC(F)(F)(F),0.838,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.162,O,0,O,0,O,0,O,0,2.155
|
20 |
+
CCOCC,0.64,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.36,O,0,O,0,O,0,O,0,2.208
|
21 |
+
C1C(OC(=O)O1)F,0.144,COC(=O)OCCF,0.173,C(C(F)(F)F)OC(C(F)F)(F)F,0.548,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.135,O,0,O,0,2.222
|
22 |
+
CC#N,0.882,FC,0.065,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.054,O,0,O,0,O,0,2.222
|
23 |
+
C1CCOC1,0.942,FC,0.029,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.029,O,0,O,0,O,0,2.222
|
24 |
+
COCCOC,0.139,COCC(F)(F)C(F)(F)C(F)(F)C(F)(F)COC,0.692,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.169,O,0,O,0,O,0,2.222
|
25 |
+
COCCOC,0.231,C(C(F)(F)F)OCC(F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,2.222
|
26 |
+
COCCOC,0.507,COC(C(F)(F)F)C(F)(F)F,0.399,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.095,O,0,O,0,O,0,2.268
|
27 |
+
CCOCC,0.313,C(C(F)(F)F)OCC(F)(F)F,0.51,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.177,O,0,O,0,O,0,2.301
|
28 |
+
COC(=O)OC,0.29,C(C(F)(F)F)OCC(F)(F)F,0.589,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0,O,0,O,0,2.301
|
29 |
+
COCCOC,0.242,FC(COC(OCC(F)(F)F)OCC(F)(F)F)(F)F,0.604,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.154,O,0,O,0,O,0,2.301
|
30 |
+
O1C(C)CCC1,0.331,FC(F)C(F)(F)COC(F)(F)C(F)F,0.498,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.171,O,0,O,0,O,0,2.301
|
31 |
+
COCCOC,0.2,FC(F)C(F)(F)COC(F)(F)C(F)F,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.2,O,0,O,0,O,0,2.301
|
32 |
+
COCCOC,0.231,FC(COC(OCC(F)(F)F)OCC(F)(F)F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,2.301
|
33 |
+
O=S1(=O)CCCC1,0.359,C(C(F)(F)F)OC(C(F)F)(F)F,0.504,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.133,[Li+].[N+](=O)([O-])[O-],0.004,O,0,O,0,2
|
34 |
+
CCOCC,0.856,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.144,O,0,O,0,O,0,O,0,2
|
35 |
+
CCOCC,0.877,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.123,O,0,O,0,O,0,O,0,2.018
|
36 |
+
CCOCC,0.707,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.293,O,0,O,0,O,0,O,0,2.046
|
37 |
+
C1COC(=O)O1,0.308,O=C(OCC)OCC(F)(F)F,0.349,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.343,O,0,O,0,O,0,2.046
|
38 |
+
O1CCOC1,0.453,COCCOC,0.305,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.127,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.063,[Li+].[N+](=O)([O-])[O-],0.051,O,0,2.046
|
39 |
+
CCOP(=O)(OCC)OCC,0.259,FC(F)C(F)(F)COC(F)(F)C(F)F,0.556,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.185,O,0,O,0,O,0,2.046
|
40 |
+
COCCOC,0.257,C(C(F)(F)F)OCC(F)(F)F,0.508,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.235,O,0,O,0,O,0,2.051
|
41 |
+
COC(=O)OC,0.299,C(C(F)(F)F)OCC(F)(F)F,0.598,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.103,O,0,O,0,O,0,2.056
|
42 |
+
CCOP(=O)(OCC)OCC,0.214,C(C(F)(F)F)OCC(F)(F)F,0.642,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.144,O,0,O,0,O,0,2.097
|
43 |
+
COC(=O)OC,0.684,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.316,O,0,O,0,O,0,O,0,2.097
|
44 |
+
C1CCOC1,0.47,O1C(C)CCC1,0.378,[Li+].F[P-](F)(F)(F)(F)F,0.152,O,0,O,0,O,0,2.097
|
45 |
+
C1C(OC(=O)O1)F,0.264,COC(=O)OCCF,0.479,C(C(F)(F)F)OC(C(F)F)(F)F,0.155,[Li+].F[P-](F)(F)(F)(F)F,0.103,O,0,O,0,2.097
|
46 |
+
CCOP(=O)(OCC)OCC,0.5,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.5,O,0,O,0,O,0,O,0,2.097
|
47 |
+
COC(=O)OC,0.543,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.457,O,0,O,0,O,0,O,0,2.097
|
48 |
+
COCCOC,0.682,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.318,O,0,O,0,O,0,O,0,2.108
|
49 |
+
COCCOC,0.231,FC(F)C(F)(F)COC(F)(F)C(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,2.155
|
50 |
+
CCOP(=O)(OCC)OCC,0.728,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.272,O,0,O,0,O,0,O,0,2
|
51 |
+
COCCOC,0.583,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.278,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.139,O,0,O,0,O,0,1.678
|
52 |
+
C1COC(=O)O1,0.305,COC(=O)OC,0.242,COCCOCCOCCOCCOC,0.392,[Li+].F[P-](F)(F)(F)(F)F,0.041,[Li+].[N+](=O)([O-])[O-],0.02,O,0,1.678
|
53 |
+
C1C(OC(=O)O1)F,0.318,CCOC(=O)OC,0.504,COC(=O)OC,0.094,B(O[Si](C)(C)C)(O[Si](C)(C)C)O[Si](C)(C),0.083,[Li+].F[P-](F)(F)(F)(F)F,0.001,O,0,1.678
|
54 |
+
O=S1(=O)CCCC1,0.714,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.286,O,0,O,0,O,0,O,0,1.699
|
55 |
+
C1C(OC(=O)O1)F,0.149,COC(=O)OCCF,0.178,C(C(F)(F)F)OC(C(F)F)(F)F,0.564,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.108,O,0,O,0,1.735
|
56 |
+
O=S1(=O)CCCC1,0.75,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.25,O,0,O,0,O,0,O,0,1.745
|
57 |
+
COC(=O)OC,0.29,C(C(F)(F)F)OCC(F)(F)F,0.589,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0,O,0,O,0,1.745
|
58 |
+
COCCOC,0.677,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.323,O,0,O,0,O,0,O,0,1.745
|
59 |
+
O1C(C)CCC1,0.908,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.092,O,0,O,0,O,0,O,0,1.745
|
60 |
+
COC(=O)OC,0.6,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.4,O,0,O,0,O,0,O,0,1.77
|
61 |
+
COCCOC,0.371,O1CCOC1,0.552,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.031,[Li+].[N+](=O)([O-])[O-],0.046,O,0,O,0,1.78
|
62 |
+
CC1CCC(C)O1,0.893,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.107,O,0,O,0,O,0,O,0,1.796
|
63 |
+
C1COC(=O)O1,0.05,CCOC(=O)OC,0.237,C(C(F)(F)F)OCC(F)(F)F,0.575,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.123,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.015,O,0,1.824
|
64 |
+
O=S1(=O)CCCC1,0.758,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.235,[Li+].[N+](=O)([O-])[O-],0.007,O,0,O,0,O,0,1.824
|
65 |
+
O1CCOC1,0.463,COCCOC,0.312,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.194,[Li+].[N+](=O)([O-])[O-],0.03,O,0,O,0,1.824
|
66 |
+
O1CCOC1,0.539,COCCOC,0.363,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.075,[Li+].[N+](=O)([O-])[O-],0.023,O,0,O,0,1.824
|
67 |
+
COC(=O)OC,0.375,FC(F)C(F)(F)COC(F)(F)C(F)F,0.375,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.25,O,0,O,0,O,0,1.854
|
68 |
+
O1CCOC1,0.478,COCCOC,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.134,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.067,O,0,O,0,1.854
|
69 |
+
CC#N,0.621,C1=COC(=O)O1,0.056,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.323,O,0,O,0,O,0,1.854
|
70 |
+
COC(=O)OC,0.478,FC(F)C(F)(F)COC(F)(F)C(F)F,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.067,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.134,O,0,O,0,1.886
|
71 |
+
COC(=O)OC,0.664,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.336,O,0,O,0,O,0,O,0,1.886
|
72 |
+
O1CCOC1,0.478,COCCOC,0.322,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.2,O,0,O,0,O,0,1.903
|
73 |
+
O=S1(=O)CCCC1,0.429,FC(F)C(F)(F)COC(F)(F)C(F)F,0.429,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.143,O,0,O,0,O,0,1.921
|
74 |
+
C1C(OC(=O)O1)F,0.312,O=C1OCCC1,0.599,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.068,[Li+].[N+](=O)([O-])[O-],0.021,O,0,O,0,1.921
|
75 |
+
CC1COC(=O)O1,0.595,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.405,O,0,O,0,O,0,O,0,1.921
|
76 |
+
O1CCOC1,0.371,COCCOC,0.552,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.077,O,0,O,0,O,0,1.959
|
77 |
+
CCOCC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0,O,0,O,0,O,0,1.959
|
78 |
+
C1CCOC1,0.925,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.075,O,0,O,0,O,0,O,0,1.377
|
79 |
+
C1COC(=O)O1,0.425,O=C(OCC)OCC,0.234,[Li+].F[P-](F)(F)(F)(F)F,0.34,O,0,O,0,O,0,1.398
|
80 |
+
C1C(OC(=O)O1)F,0.932,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.068,O,0,O,0,O,0,O,0,1.41
|
81 |
+
COCCOC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.047,[Li+].FP(F)(=O)([O-]),0.047,O,0,O,0,O,0,1.444
|
82 |
+
O=S1(=O)CCCC1,0.764,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.236,O,0,O,0,O,0,O,0,1.456
|
83 |
+
O=C(OCC)C,0.105,ClCCl,0.64,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.255,O,0,O,0,O,0,1.456
|
84 |
+
C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.001,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.488
|
85 |
+
C1C(OC(=O)O1)F,0.873,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.127,O,0,O,0,O,0,O,0,1.489
|
86 |
+
COCCOC,0.231,C(C(F)(F)F)OC(=O)OCC(F)(F)F,0.577,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.192,O,0,O,0,O,0,1.495
|
87 |
+
C1C(OC(=O)O1)F,0.495,COC(=O)OC,0.429,O1CCOCCOCCOCC1,0.003,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.498
|
88 |
+
C1C(OC(=O)O1)F,0.481,O=C(OCC)OCC,0.432,[Li+].F[P-](F)(F)(F)(F)F,0.087,O,0,O,0,O,0,1.523
|
89 |
+
O1CCOC1,0.322,COCCOC,0.478,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)C(F)(F)C(F)(F)C(F)(F)F,0.2,O,0,O,0,O,0,1.523
|
90 |
+
O1CCOC1,0.552,COCCOC,0.371,[Li+].[N+](=O)([O-])[O-],0.039,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.039,O,0,O,0,1.523
|
91 |
+
C1COC(=O)O1,0.444,C1COS(=O)O1,0.497,[Li+].[O-]Cl(=O)(=O)=O,0.059,O,0,O,0,O,0,1.523
|
92 |
+
C1C(OC(=O)O1)F,0.105,C1COC(=O)O1,0.518,O=C(OCC)OCC,0.285,[Li+].F[P-](F)(F)(F)(F)F,0.077,[Rb+].[O-][N+]([O-])=O,0.008,O1CCOCCOCCOCCOCCOCC1,0.008,1.538
|
93 |
+
C1C(OC(=O)O1)F,0.82,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.18,O,0,O,0,O,0,O,0,1.544
|
94 |
+
C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.431,O1CCOCCOCCOCC1,0,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.559
|
95 |
+
COCCOC,0.906,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0,O,0,O,0,O,0,1.561
|
96 |
+
CCOP(=O)(OCC)OCC,0.648,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.352,O,0,O,0,O,0,O,0,1.569
|
97 |
+
O=S1(=O)CCCC1,0.25,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.75,O,0,O,0,O,0,O,0,1.569
|
98 |
+
C1C(OC(=O)O1)F,0.774,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.226,O,0,O,0,O,0,O,0,1.587
|
99 |
+
C1C(OC(=O)O1)F,0.413,O=C(OCC)OCC,0.497,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.09,O,0,O,0,O,0,1.59
|
100 |
+
C1COC(=O)O1,0.425,O=C(OCC)OCC(F)(F)F,0.481,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.094,O,0,O,0,O,0,1.602
|
101 |
+
CC1COC(=O)O1,0.702,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.298,O,0,O,0,O,0,O,0,1.602
|
102 |
+
O1CCOCC1,0.912,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.088,O,0,O,0,O,0,O,0,1.602
|
103 |
+
C1C(OC(=O)O1)F,0.62,C(C(F)(F)F)OC(=O)OCC(F)(F)F,0.291,[Li+].F[P-](F)(F)(F)(F)F,0.089,O,0,O,0,O,0,1.62
|
104 |
+
COC(=O)CC(F)(F)F,0.768,C1C(OC(=O)O1)F,0.134,[Li+].F[P-](F)(F)(F)(F)F,0.098,O,0,O,0,O,0,1.62
|
105 |
+
C1C(OC(=O)O1)F,0.733,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.267,O,0,O,0,O,0,O,0,1.629
|
106 |
+
C1C(OC(=O)O1)F,0.696,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.304,O,0,O,0,O,0,O,0,1.633
|
107 |
+
COC(C)C(C)OC,0.879,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.121,O,0,O,0,O,0,O,0,1.638
|
108 |
+
C1COC(=O)O1,0.197,COC(=O)OC,0.156,COCCOCCOCCOCCOC,0.59,[Li+].F[P-](F)(F)(F)(F)F,0.026,[Li+].[N+](=O)([O-])[O-],0.031,O,0,1.638
|
109 |
+
C1COC(=O)O1,0.338,COC(=O)OC,0.625,[Li+].[O-]P(=O)(F)F,0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.03,O,0,O,0,1.26
|
110 |
+
COCCOC,0.707,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.147,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.147,O,0,O,0,O,0,1.268
|
111 |
+
C1COC(=O)O1,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.276
|
112 |
+
COCCOC,0.763,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.174,[Li+].[O-]P(=O)(F)F,0.063,O,0,O,0,O,0,1.292
|
113 |
+
C1COC(=O)O1,0.563,O=C(OCC)OCC,0.31,C1C(OC(=O)O1)F,0.052,[Li+].F[P-](F)(F)(F)(F)F,0.075,O,0,O,0,1.301
|
114 |
+
COCCOCCOCC(F)(F)OC(F)(F)OC(F)(F)COCCOCCOC,0.708,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.292,O,0,O,0,O,0,O,0,1.301
|
115 |
+
C1COC(=O)O1,0.331,O=C(OCC)OCC,0.577,[Li+].[B-]1(OC(=O)C(=O)O1)(F)F,0.092,O,0,O,0,O,0,1.301
|
116 |
+
C1C(OC(=O)O1)F,0.318,CCOC(=O)OC,0.504,COC(=O)OC,0.094,[Li+].F[P-](F)(F)(F)(F)F,0.083,O,0,O,0,1.301
|
117 |
+
C1COC(=O)O1,0.519,O=C(OCC)OCC,0.387,[Li+].F[P-](F)(F)(F)(F)F,0.082,[Li+].[O-]P(=O)(F)F,0.012,O,0,O,0,1.319
|
118 |
+
C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.333
|
119 |
+
COCCOC,0.259,B(OCC(F)(F)F)(OCC(F)(F)F)OCC(F)(F)F,0.556,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.185,O,0,O,0,O,0,1.337
|
120 |
+
C1C(OC(=O)O1)F,0.496,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.002,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.355
|
121 |
+
COCCOC,0.763,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.237,O,0,O,0,O,0,O,0,1.367
|
122 |
+
C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0.001,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.373
|
123 |
+
C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].F[B-](F)(F)F,0.069,O,0,O,0,0.699
|
124 |
+
CC1COC(=O)O1,0.922,[Li+].F[B-](F)(F)OC(C(F)(F)(F))(C(F)(F)(F))C(F)(F)(F),0.078,O,0,O,0,O,0,O,0,0.712
|
125 |
+
C1COC(=O)O1,0.3,CCOC(=O)OC,0.593,C1=COC(=O)O1,0.026,[Li+].F[P-](F)(F)(F)(F)F,0.081,O,0,O,0,0.745
|
126 |
+
C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.069,O,0,O,0,0.745
|
127 |
+
C1COC(=O)O1,0.326,COC(=O)OC,0.602,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,O,0,0.777
|
128 |
+
C1COC(=O)O1,0.362,O=C(OCC)OCC,0.548,[Li+].F[P-](F)(F)(F)(F)F,0.09,O,0,O,0,O,0,0.788
|
129 |
+
CC1COC(=O)O1,0.887,[Li+].F[As-](F)(F)(F)(F)F,0.113,O,0,O,0,O,0,O,0,0.824
|
130 |
+
C1COC(=O)O1,0.507,COC(=O)OC,0.402,C1=COC(=O)O1,0.022,[Li+].C(C(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(C(F)(F)F)(F)F)(F)(F)F,0.069,O,0,O,0,0.854
|
131 |
+
C1COC(=O)O1,0.359,COC(=O)OC,0.569,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,O,0,0.854
|
132 |
+
C1COC(=O)O1,0.331,O=C(OCC)OCC,0.577,[Li+].F[P-](F)(F)(F)(F)F,0.092,O,0,O,0,O,0,0.886
|
133 |
+
C1COC(=O)O1,0.594,O=C(OCC)OCC,0.327,[Li+].F[P-](F)(F)(F)(F)F,0.079,O,0,O,0,O,0,0.921
|
134 |
+
C1COC(=O)O1,0.5,CCOC(=O)OC,0.423,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.046,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.031,O,0,O,0,0.924
|
135 |
+
C1COC(=O)O1,0.526,O=C(OCC)OCC,0.392,[Li+].F[P-](F)(F)(F)(F)F,0.083,O,0,O,0,O,0,1.013
|
136 |
+
C1COC(=O)O1,0.356,COC(=O)OC,0.566,FC(F)(F)COB(OCC(F)(F)F)OCC(F)(F)F,0.007,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.046
|
137 |
+
C1COC(=O)O1,0.682,CCOC(=O)OC,0.247,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.043,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.028,O,0,O,0,1.076
|
138 |
+
C1COC(=O)O1,0.854,CCOC(=O)OC,0.08,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.039,[Li+].O=C1O[B-]2(OC1=O)OC(=O)C(=O)O2,0.026,O,0,O,0,1.081
|
139 |
+
C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.431,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,O,0,1.085
|
140 |
+
C1C(OC(=O)O1)F,0.107,C1COC(=O)O1,0.526,O=C(OCC)OCC,0.289,[Li+].F[P-](F)(F)(F)(F)F,0.078,O,0,O,0,1.108
|
141 |
+
C1COC(=O)O1,0.496,COC(=O)OC,0.393,C1C(OC(=O)O1)F,0.045,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.066,O,0,O,0,1.108
|
142 |
+
COCCOC,0.73,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.27,O,0,O,0,O,0,O,0,1.143
|
143 |
+
C1COC(=O)O1,0.327,O=C(OCC)OCC,0.594,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.079,O,0,O,0,O,0,1.155
|
144 |
+
C1COC(=O)O1,0.338,COC(=O)OC,0.625,[Li+].[O-]P(=O)(F)F,0.008,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.03,O,0,O,0,1.194
|
145 |
+
COCCOC,0.731,[Li+].[O-]P(=O)(F)F,0.064,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.205,O,0,O,0,O,0,1.215
|
146 |
+
COCCOCCOCCOCCOC,0.819,FS([N-]S(F)(=O)=O)(=O)=O.[Li+],0.181,O,0,O,0,O,0,O,0,1.222
|
147 |
+
C1C(OC(=O)O1)F,0.497,COC(=O)OC,0.43,O1CCOCCOCCOCC1,0,[Li+].F[P-](F)(F)(F)(F)F,0.072,O,0,O,0,1.225
|
148 |
+
COCCOC,0.706,[Li+].C(F)(F)(F)S(=O)(=O)[N-]S(=O)(=O)C(F)(F)F,0.008,[Li+].[O-]P(=O)(F)F,0.286,O,0,O,0,O,0,1.244
|
models/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
models/.gitattributes
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
*.csv filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.pdf filter=lfs diff=lfs merge=lfs -text
|
models/fm4m.py
ADDED
@@ -0,0 +1,964 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sklearn.metrics import roc_auc_score, roc_curve
|
2 |
+
|
3 |
+
import datetime
|
4 |
+
import os
|
5 |
+
import umap
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import pandas as pd
|
10 |
+
import pickle
|
11 |
+
import json
|
12 |
+
|
13 |
+
from xgboost import XGBClassifier, XGBRegressor
|
14 |
+
import xgboost as xgb
|
15 |
+
from sklearn.metrics import roc_auc_score, mean_squared_error
|
16 |
+
import xgboost as xgb
|
17 |
+
from sklearn.svm import SVR
|
18 |
+
from sklearn.linear_model import LinearRegression
|
19 |
+
from sklearn.kernel_ridge import KernelRidge
|
20 |
+
import json
|
21 |
+
from sklearn.compose import TransformedTargetRegressor
|
22 |
+
from sklearn.preprocessing import MinMaxScaler
|
23 |
+
|
24 |
+
|
25 |
+
import torch
|
26 |
+
from transformers import AutoTokenizer, AutoModel
|
27 |
+
|
28 |
+
import sys
|
29 |
+
sys.path.append("models/")
|
30 |
+
|
31 |
+
from models.selfies_ted.load import SELFIES as bart
|
32 |
+
from models.mhg_model import load as mhg
|
33 |
+
from models.smi_ted.smi_ted_light.load import load_smi_ted
|
34 |
+
|
35 |
+
import mordred
|
36 |
+
from mordred import Calculator, descriptors
|
37 |
+
from rdkit import Chem
|
38 |
+
from rdkit.Chem import AllChem
|
39 |
+
|
40 |
+
datasets = {}
|
41 |
+
models = {}
|
42 |
+
downstream_models ={}
|
43 |
+
|
44 |
+
|
45 |
+
def avail_models_data():
|
46 |
+
global datasets
|
47 |
+
global models
|
48 |
+
|
49 |
+
datasets = [{"Dataset": "hiv", "Input": "smiles", "Output": "HIV_active", "Path": "data/hiv", "Timestamp": "2024-06-26 11:27:37"},
|
50 |
+
{"Dataset": "esol", "Input": "smiles", "Output": "ESOL predicted log solubility in mols per litre", "Path": "data/esol", "Timestamp": "2024-06-26 11:31:46"},
|
51 |
+
{"Dataset": "freesolv", "Input": "smiles", "Output": "expt", "Path": "data/freesolv", "Timestamp": "2024-06-26 11:33:47"},
|
52 |
+
{"Dataset": "lipo", "Input": "smiles", "Output": "y", "Path": "data/lipo", "Timestamp": "2024-06-26 11:34:37"},
|
53 |
+
{"Dataset": "bace", "Input": "smiles", "Output": "Class", "Path": "data/bace", "Timestamp": "2024-06-26 11:36:40"},
|
54 |
+
{"Dataset": "bbbp", "Input": "smiles", "Output": "p_np", "Path": "data/bbbp", "Timestamp": "2024-06-26 11:39:23"},
|
55 |
+
{"Dataset": "clintox", "Input": "smiles", "Output": "CT_TOX", "Path": "data/clintox", "Timestamp": "2024-06-26 11:42:43"}]
|
56 |
+
|
57 |
+
|
58 |
+
models = [{"Name": "bart","Model Name": "SELFIES-TED","Description": "BART model for string based SELFIES modality", "Timestamp": "2024-06-21 12:32:20"},
|
59 |
+
{"Name": "mol-xl","Model Name": "MolFormer", "Description": "MolFormer model for string based SMILES modality", "Timestamp": "2024-06-21 12:35:56"},
|
60 |
+
{"Name": "mhg", "Model Name": "MHG-GED","Description": "Molecular hypergraph model", "Timestamp": "2024-07-10 00:09:42"},
|
61 |
+
{"Name": "smi-ted", "Model Name": "SMI-TED","Description": "SMILES based encoder decoder model", "Timestamp": "2024-07-10 00:09:42"}]
|
62 |
+
|
63 |
+
|
64 |
+
def avail_models(raw=False):
|
65 |
+
global models
|
66 |
+
|
67 |
+
models = [{"Name": "smi-ted", "Model Name": "SMI-TED","Description": "SMILES based encoder decoder model"},
|
68 |
+
{"Name": "bart","Model Name": "SELFIES-TED","Description": "BART model for string based SELFIES modality"},
|
69 |
+
{"Name": "mol-xl","Model Name": "MolFormer", "Description": "MolFormer model for string based SMILES modality"},
|
70 |
+
{"Name": "mhg", "Model Name": "MHG-GED","Description": "Molecular hypergraph model"},
|
71 |
+
{"Name": "Mordred", "Model Name": "Mordred","Description": "Baseline: A descriptor-calculation software application that can calculate more than 1800 two- and three-dimensional descriptors"},
|
72 |
+
{"Name": "MorganFingerprint", "Model Name": "MorganFingerprint","Description": "Baseline: Circular atom environments based descriptor"}
|
73 |
+
]
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
if raw: return models
|
78 |
+
else:
|
79 |
+
return pd.DataFrame(models).drop('Name', axis=1)
|
80 |
+
|
81 |
+
return models
|
82 |
+
|
83 |
+
def avail_downstream_models(raw=False):
|
84 |
+
global downstream_models
|
85 |
+
|
86 |
+
downstream_models = [{"Name": "XGBClassifier", "Task Type": "Classfication"},
|
87 |
+
{"Name": "DefaultClassifier", "Task Type": "Classfication"},
|
88 |
+
{"Name": "SVR", "Task Type": "Regression"},
|
89 |
+
{"Name": "Kernel Ridge", "Task Type": "Regression"},
|
90 |
+
{"Name": "Linear Regression", "Task Type": "Regression"},
|
91 |
+
{"Name": "DefaultRegressor", "Task Type": "Regression"},
|
92 |
+
]
|
93 |
+
|
94 |
+
if raw: return downstream_models
|
95 |
+
else:
|
96 |
+
return pd.DataFrame(downstream_models)
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
def avail_datasets():
|
101 |
+
global datasets
|
102 |
+
|
103 |
+
datasets = [{"Dataset": "hiv", "Input": "smiles", "Output": "HIV_active", "Path": "data/hiv",
|
104 |
+
"Timestamp": "2024-06-26 11:27:37"},
|
105 |
+
{"Dataset": "esol", "Input": "smiles", "Output": "ESOL predicted log solubility in mols per litre",
|
106 |
+
"Path": "data/esol", "Timestamp": "2024-06-26 11:31:46"},
|
107 |
+
{"Dataset": "freesolv", "Input": "smiles", "Output": "expt", "Path": "data/freesolv",
|
108 |
+
"Timestamp": "2024-06-26 11:33:47"},
|
109 |
+
{"Dataset": "lipo", "Input": "smiles", "Output": "y", "Path": "data/lipo",
|
110 |
+
"Timestamp": "2024-06-26 11:34:37"},
|
111 |
+
{"Dataset": "bace", "Input": "smiles", "Output": "Class", "Path": "data/bace",
|
112 |
+
"Timestamp": "2024-06-26 11:36:40"},
|
113 |
+
{"Dataset": "bbbp", "Input": "smiles", "Output": "p_np", "Path": "data/bbbp",
|
114 |
+
"Timestamp": "2024-06-26 11:39:23"},
|
115 |
+
{"Dataset": "clintox", "Input": "smiles", "Output": "CT_TOX", "Path": "data/clintox",
|
116 |
+
"Timestamp": "2024-06-26 11:42:43"}]
|
117 |
+
|
118 |
+
return datasets
|
119 |
+
|
120 |
+
def reset():
|
121 |
+
|
122 |
+
"""datasets = {"esol": ["smiles", "ESOL predicted log solubility in mols per litre", "data/esol", "2024-06-26 11:36:46.509324"],
|
123 |
+
"freesolv": ["smiles", "expt", "data/freesolv", "2024-06-26 11:37:37.393273"],
|
124 |
+
"lipo": ["smiles", "y", "data/lipo", "2024-06-26 11:37:37.393273"],
|
125 |
+
"hiv": ["smiles", "HIV_active", "data/hiv", "2024-06-26 11:37:37.393273"],
|
126 |
+
"bace": ["smiles", "Class", "data/bace", "2024-06-26 11:38:40.058354"],
|
127 |
+
"bbbp": ["smiles", "p_np", "data/bbbp","2024-06-26 11:38:40.058354"],
|
128 |
+
"clintox": ["smiles", "CT_TOX", "data/clintox","2024-06-26 11:38:40.058354"],
|
129 |
+
"sider": ["smiles","1:", "data/sider","2024-06-26 11:38:40.058354"],
|
130 |
+
"tox21": ["smiles",":-2", "data/tox21","2024-06-26 11:38:40.058354"]
|
131 |
+
}"""
|
132 |
+
|
133 |
+
datasets = [
|
134 |
+
{"Dataset": "hiv", "Input": "smiles", "Output": "HIV_active", "Path": "data/hiv", "Timestamp": "2024-06-26 11:27:37"},
|
135 |
+
{"Dataset": "esol", "Input": "smiles", "Output": "ESOL predicted log solubility in mols per litre", "Path": "data/esol", "Timestamp": "2024-06-26 11:31:46"},
|
136 |
+
{"Dataset": "freesolv", "Input": "smiles", "Output": "expt", "Path": "data/freesolv", "Timestamp": "2024-06-26 11:33:47"},
|
137 |
+
{"Dataset": "lipo", "Input": "smiles", "Output": "y", "Path": "data/lipo", "Timestamp": "2024-06-26 11:34:37"},
|
138 |
+
{"Dataset": "bace", "Input": "smiles", "Output": "Class", "Path": "data/bace", "Timestamp": "2024-06-26 11:36:40"},
|
139 |
+
{"Dataset": "bbbp", "Input": "smiles", "Output": "p_np", "Path": "data/bbbp", "Timestamp": "2024-06-26 11:39:23"},
|
140 |
+
{"Dataset": "clintox", "Input": "smiles", "Output": "CT_TOX", "Path": "data/clintox", "Timestamp": "2024-06-26 11:42:43"},
|
141 |
+
#{"Dataset": "sider", "Input": "smiles", "Output": "1:", "path": "data/sider", "Timestamp": "2024-06-26 11:38:40.058354"},
|
142 |
+
#{"Dataset": "tox21", "Input": "smiles", "Output": ":-2", "path": "data/tox21", "Timestamp": "2024-06-26 11:38:40.058354"}
|
143 |
+
]
|
144 |
+
|
145 |
+
models = [{"Name": "bart", "Description": "BART model for string based SELFIES modality",
|
146 |
+
"Timestamp": "2024-06-21 12:32:20"},
|
147 |
+
{"Name": "mol-xl", "Description": "MolFormer model for string based SMILES modality",
|
148 |
+
"Timestamp": "2024-06-21 12:35:56"},
|
149 |
+
{"Name": "mhg", "Description": "MHG", "Timestamp": "2024-07-10 00:09:42"},
|
150 |
+
{"Name": "spec-gru", "Description": "Spectrum modality with GRU", "Timestamp": "2024-07-10 00:09:42"},
|
151 |
+
{"Name": "spec-lstm", "Description": "Spectrum modality with LSTM", "Timestamp": "2024-07-10 00:09:54"},
|
152 |
+
{"Name": "3d-vae", "Description": "VAE model for 3D atom positions", "Timestamp": "2024-07-10 00:10:08"}]
|
153 |
+
|
154 |
+
|
155 |
+
downstream_models = [
|
156 |
+
{"Name": "XGBClassifier", "Description": "XG Boost Classifier",
|
157 |
+
"Timestamp": "2024-06-21 12:31:20"},
|
158 |
+
{"Name": "XGBRegressor", "Description": "XG Boost Regressor",
|
159 |
+
"Timestamp": "2024-06-21 12:32:56"},
|
160 |
+
{"Name": "2-FNN", "Description": "A two layer feedforward network",
|
161 |
+
"Timestamp": "2024-06-24 14:34:16"},
|
162 |
+
{"Name": "3-FNN", "Description": "A three layer feedforward network",
|
163 |
+
"Timestamp": "2024-06-24 14:38:37"},
|
164 |
+
]
|
165 |
+
|
166 |
+
with open("datasets.json", "w") as outfile:
|
167 |
+
json.dump(datasets, outfile)
|
168 |
+
|
169 |
+
with open("models.json", "w") as outfile:
|
170 |
+
json.dump(models, outfile)
|
171 |
+
|
172 |
+
with open("downstream_models.json", "w") as outfile:
|
173 |
+
json.dump(downstream_models, outfile)
|
174 |
+
|
175 |
+
def update_data_list(list_data):
|
176 |
+
#datasets[list_data[0]] = list_data[1:]
|
177 |
+
|
178 |
+
with open("datasets.json", "w") as outfile:
|
179 |
+
json.dump(datasets, outfile)
|
180 |
+
|
181 |
+
avail_models_data()
|
182 |
+
|
183 |
+
def update_model_list(list_model):
|
184 |
+
#models[list_model[0]] = list_model[1]
|
185 |
+
|
186 |
+
with open("models.json", "w") as outfile:
|
187 |
+
json.dump(list_model, outfile)
|
188 |
+
|
189 |
+
avail_models_data()
|
190 |
+
|
191 |
+
def update_downstream_model_list(list_model):
|
192 |
+
#models[list_model[0]] = list_model[1]
|
193 |
+
|
194 |
+
with open("downstream_models.json", "w") as outfile:
|
195 |
+
json.dump(list_model, outfile)
|
196 |
+
|
197 |
+
avail_models_data()
|
198 |
+
|
199 |
+
avail_models_data()
|
200 |
+
|
201 |
+
|
202 |
+
|
203 |
+
def get_representation(train_data,test_data,model_type, return_tensor=True):
|
204 |
+
alias = {"MHG-GED": "mhg", "SELFIES-TED": "bart", "MolFormer": "mol-xl", "Molformer": "mol-xl", "SMI-TED": "smi-ted"}
|
205 |
+
if model_type in alias.keys():
|
206 |
+
model_type = alias[model_type]
|
207 |
+
|
208 |
+
if model_type == "mhg":
|
209 |
+
model = mhg.load("../models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle")
|
210 |
+
with torch.no_grad():
|
211 |
+
train_emb = model.encode(train_data)
|
212 |
+
x_batch = torch.stack(train_emb)
|
213 |
+
|
214 |
+
test_emb = model.encode(test_data)
|
215 |
+
x_batch_test = torch.stack(test_emb)
|
216 |
+
if not return_tensor:
|
217 |
+
x_batch = pd.DataFrame(x_batch)
|
218 |
+
x_batch_test = pd.DataFrame(x_batch_test)
|
219 |
+
|
220 |
+
|
221 |
+
elif model_type == "bart":
|
222 |
+
model = bart()
|
223 |
+
model.load()
|
224 |
+
x_batch = model.encode(train_data, return_tensor=return_tensor)
|
225 |
+
x_batch_test = model.encode(test_data, return_tensor=return_tensor)
|
226 |
+
|
227 |
+
elif model_type == "smi-ted":
|
228 |
+
# model = load_smi_ted(folder='../models/smi_ted/smi_ted_light', ckpt_filename='smi-ted-Light_40.pt')
|
229 |
+
folder = os.path.join(os.path.dirname(__file__), '../models/smi_ted/smi_ted_light')
|
230 |
+
model = load_smi_ted(folder=folder, ckpt_filename='smi-ted-Light_40.pt')
|
231 |
+
with torch.no_grad():
|
232 |
+
x_batch = model.encode(train_data, return_torch=return_tensor)
|
233 |
+
x_batch_test = model.encode(test_data, return_torch=return_tensor)
|
234 |
+
|
235 |
+
elif model_type == "mol-xl":
|
236 |
+
model = AutoModel.from_pretrained("ibm/MoLFormer-XL-both-10pct", deterministic_eval=True,
|
237 |
+
trust_remote_code=True)
|
238 |
+
tokenizer = AutoTokenizer.from_pretrained("ibm/MoLFormer-XL-both-10pct", trust_remote_code=True)
|
239 |
+
|
240 |
+
if type(train_data) == list:
|
241 |
+
inputs = tokenizer(train_data, padding=True, return_tensors="pt")
|
242 |
+
else:
|
243 |
+
inputs = tokenizer(list(train_data.values), padding=True, return_tensors="pt")
|
244 |
+
|
245 |
+
with torch.no_grad():
|
246 |
+
outputs = model(**inputs)
|
247 |
+
|
248 |
+
x_batch = outputs.pooler_output
|
249 |
+
|
250 |
+
if type(test_data) == list:
|
251 |
+
inputs = tokenizer(test_data, padding=True, return_tensors="pt")
|
252 |
+
else:
|
253 |
+
inputs = tokenizer(list(test_data.values), padding=True, return_tensors="pt")
|
254 |
+
|
255 |
+
with torch.no_grad():
|
256 |
+
outputs = model(**inputs)
|
257 |
+
|
258 |
+
x_batch_test = outputs.pooler_output
|
259 |
+
|
260 |
+
if not return_tensor:
|
261 |
+
x_batch = pd.DataFrame(x_batch)
|
262 |
+
x_batch_test = pd.DataFrame(x_batch_test)
|
263 |
+
|
264 |
+
elif model_type == 'Mordred':
|
265 |
+
all_data = train_data + test_data
|
266 |
+
calc = Calculator(descriptors, ignore_3D=True)
|
267 |
+
mol_list = [Chem.MolFromSmiles(sm) for sm in all_data]
|
268 |
+
x_all = calc.pandas(mol_list)
|
269 |
+
print (f'original mordred fv dim: {x_all.shape}')
|
270 |
+
|
271 |
+
for j in x_all.columns:
|
272 |
+
for k in range(len(x_all[j])):
|
273 |
+
i = x_all.loc[k, j]
|
274 |
+
if type(i) is mordred.error.Missing or type(i) is mordred.error.Error:
|
275 |
+
x_all.loc[k, j] = np.nan
|
276 |
+
|
277 |
+
x_all.dropna(how="any", axis = 1, inplace=True)
|
278 |
+
print (f'Nan excluded mordred fv dim: {x_all.shape}')
|
279 |
+
|
280 |
+
x_batch = x_all.iloc[:len(train_data)]
|
281 |
+
x_batch_test = x_all.iloc[len(train_data):]
|
282 |
+
# print(f'x_batch: {len(x_batch)}, x_batch_test: {len(x_batch_test)}')
|
283 |
+
|
284 |
+
elif model_type == 'MorganFingerprint':
|
285 |
+
params = {'radius':2, 'nBits':1024}
|
286 |
+
|
287 |
+
mol_train = [Chem.MolFromSmiles(sm) for sm in train_data]
|
288 |
+
mol_test = [Chem.MolFromSmiles(sm) for sm in test_data]
|
289 |
+
|
290 |
+
x_batch = []
|
291 |
+
for mol in mol_train:
|
292 |
+
info = {}
|
293 |
+
fp = AllChem.GetMorganFingerprintAsBitVect(mol, **params, bitInfo=info)
|
294 |
+
vector = list(fp)
|
295 |
+
x_batch.append(vector)
|
296 |
+
x_batch = pd.DataFrame(x_batch)
|
297 |
+
|
298 |
+
x_batch_test = []
|
299 |
+
for mol in mol_test:
|
300 |
+
info = {}
|
301 |
+
fp = AllChem.GetMorganFingerprintAsBitVect(mol, **params, bitInfo=info)
|
302 |
+
vector = list(fp)
|
303 |
+
x_batch_test.append(vector)
|
304 |
+
x_batch_test = pd.DataFrame(x_batch_test)
|
305 |
+
|
306 |
+
return x_batch, x_batch_test
|
307 |
+
|
308 |
+
def single_modal(model,dataset=None, downstream_model=None, params=None, x_train=None, x_test=None, y_train=None, y_test=None):
|
309 |
+
print(model)
|
310 |
+
alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "Molformer": "mol-xl", "SMI-TED": "smi-ted"}
|
311 |
+
data = avail_models(raw=True)
|
312 |
+
df = pd.DataFrame(data)
|
313 |
+
#print(list(df["Name"].values))
|
314 |
+
|
315 |
+
if model in list(df["Name"].values):
|
316 |
+
model_type = model
|
317 |
+
elif alias[model] in list(df["Name"].values):
|
318 |
+
model_type = alias[model]
|
319 |
+
else:
|
320 |
+
print("Model not available")
|
321 |
+
return
|
322 |
+
|
323 |
+
|
324 |
+
data = avail_datasets()
|
325 |
+
df = pd.DataFrame(data)
|
326 |
+
#print(list(df["Dataset"].values))
|
327 |
+
|
328 |
+
if dataset in list(df["Dataset"].values):
|
329 |
+
task = dataset
|
330 |
+
with open(f"representation/{task}_{model_type}.pkl", "rb") as f1:
|
331 |
+
x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1)
|
332 |
+
print(f" Representation loaded successfully")
|
333 |
+
|
334 |
+
elif x_train==None:
|
335 |
+
|
336 |
+
print("Custom Dataset")
|
337 |
+
#return
|
338 |
+
components = dataset.split(",")
|
339 |
+
train_data = pd.read_csv(components[0])[components[2]]
|
340 |
+
test_data = pd.read_csv(components[1])[components[2]]
|
341 |
+
|
342 |
+
y_batch = pd.read_csv(components[0])[components[3]]
|
343 |
+
y_batch_test = pd.read_csv(components[1])[components[3]]
|
344 |
+
|
345 |
+
|
346 |
+
x_batch, x_batch_test = get_representation(train_data,test_data,model_type)
|
347 |
+
|
348 |
+
|
349 |
+
|
350 |
+
print(f" Representation loaded successfully")
|
351 |
+
|
352 |
+
else:
|
353 |
+
|
354 |
+
y_batch = y_train
|
355 |
+
y_batch_test = y_test
|
356 |
+
x_batch, x_batch_test = get_representation(x_train, x_test, model_type)
|
357 |
+
|
358 |
+
# exclude row containing Nan value
|
359 |
+
if isinstance(x_batch, torch.Tensor):
|
360 |
+
x_batch = pd.DataFrame(x_batch)
|
361 |
+
nan_indices = x_batch.index[x_batch.isna().any(axis=1)]
|
362 |
+
if len(nan_indices) > 0:
|
363 |
+
x_batch.dropna(inplace = True)
|
364 |
+
for index in sorted(nan_indices, reverse=True):
|
365 |
+
del y_batch[index]
|
366 |
+
print(f'x_batch Nan index: {nan_indices}')
|
367 |
+
print(f'x_batch shape: {x_batch.shape}, y_batch len: {len(y_batch)}')
|
368 |
+
|
369 |
+
if isinstance(x_batch_test, torch.Tensor):
|
370 |
+
x_batch_test = pd.DataFrame(x_batch_test)
|
371 |
+
nan_indices = x_batch_test.index[x_batch_test.isna().any(axis=1)]
|
372 |
+
if len(nan_indices) > 0:
|
373 |
+
x_batch_test.dropna(inplace = True)
|
374 |
+
for index in sorted(nan_indices, reverse=True):
|
375 |
+
del y_batch_test[index]
|
376 |
+
print(f'x_batch_test Nan index: {nan_indices}')
|
377 |
+
print(f'x_batch_test shape: {x_batch_test.shape}, y_batch_test len: {len(y_batch_test)}')
|
378 |
+
|
379 |
+
print(f" Calculating ROC AUC Score ...")
|
380 |
+
|
381 |
+
if downstream_model == "XGBClassifier":
|
382 |
+
if params == None:
|
383 |
+
xgb_predict_concat = XGBClassifier()
|
384 |
+
else:
|
385 |
+
xgb_predict_concat = XGBClassifier(**params) # n_estimators=5000, learning_rate=0.01, max_depth=10
|
386 |
+
xgb_predict_concat.fit(x_batch, y_batch)
|
387 |
+
|
388 |
+
y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
|
389 |
+
|
390 |
+
roc_auc = roc_auc_score(y_batch_test, y_prob)
|
391 |
+
fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
|
392 |
+
print(f"ROC-AUC Score: {roc_auc:.4f}")
|
393 |
+
|
394 |
+
try:
|
395 |
+
with open(f"plot_emb/{task}_{model_type}.pkl", "rb") as f1:
|
396 |
+
class_0,class_1 = pickle.load(f1)
|
397 |
+
except:
|
398 |
+
print("Generating latent plots")
|
399 |
+
reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
|
400 |
+
verbose=False)
|
401 |
+
n_samples = np.minimum(1000, len(x_batch))
|
402 |
+
|
403 |
+
try:x = y_batch.values[:n_samples]
|
404 |
+
except: x = y_batch[:n_samples]
|
405 |
+
index_0 = [index for index in range(len(x)) if x[index] == 0]
|
406 |
+
index_1 = [index for index in range(len(x)) if x[index] == 1]
|
407 |
+
|
408 |
+
try:
|
409 |
+
features_umap = reducer.fit_transform(x_batch[:n_samples])
|
410 |
+
class_0 = features_umap[index_0]
|
411 |
+
class_1 = features_umap[index_1]
|
412 |
+
except:
|
413 |
+
class_0 = []
|
414 |
+
class_1 = []
|
415 |
+
print("Generating latent plots : Done")
|
416 |
+
|
417 |
+
#vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
|
418 |
+
|
419 |
+
result = f"ROC-AUC Score: {roc_auc:.4f}"
|
420 |
+
|
421 |
+
return result, roc_auc,fpr, tpr, class_0, class_1
|
422 |
+
|
423 |
+
elif downstream_model == "DefaultClassifier":
|
424 |
+
xgb_predict_concat = XGBClassifier() # n_estimators=5000, learning_rate=0.01, max_depth=10
|
425 |
+
xgb_predict_concat.fit(x_batch, y_batch)
|
426 |
+
|
427 |
+
y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
|
428 |
+
|
429 |
+
roc_auc = roc_auc_score(y_batch_test, y_prob)
|
430 |
+
fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
|
431 |
+
print(f"ROC-AUC Score: {roc_auc:.4f}")
|
432 |
+
|
433 |
+
try:
|
434 |
+
with open(f"plot_emb/{task}_{model_type}.pkl", "rb") as f1:
|
435 |
+
class_0,class_1 = pickle.load(f1)
|
436 |
+
except:
|
437 |
+
print("Generating latent plots")
|
438 |
+
reducer = umap.UMAP(metric='euclidean', n_neighbors= 10, n_components=2, low_memory=True, min_dist=0.1, verbose=False)
|
439 |
+
n_samples = np.minimum(1000,len(x_batch))
|
440 |
+
|
441 |
+
try:
|
442 |
+
x = y_batch.values[:n_samples]
|
443 |
+
except:
|
444 |
+
x = y_batch[:n_samples]
|
445 |
+
|
446 |
+
try:
|
447 |
+
features_umap = reducer.fit_transform(x_batch[:n_samples])
|
448 |
+
index_0 = [index for index in range(len(x)) if x[index] == 0]
|
449 |
+
index_1 = [index for index in range(len(x)) if x[index] == 1]
|
450 |
+
|
451 |
+
class_0 = features_umap[index_0]
|
452 |
+
class_1 = features_umap[index_1]
|
453 |
+
except:
|
454 |
+
class_0 = []
|
455 |
+
class_1 = []
|
456 |
+
|
457 |
+
print("Generating latent plots : Done")
|
458 |
+
|
459 |
+
#vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
|
460 |
+
|
461 |
+
result = f"ROC-AUC Score: {roc_auc:.4f}"
|
462 |
+
|
463 |
+
return result, roc_auc,fpr, tpr, class_0, class_1
|
464 |
+
|
465 |
+
elif downstream_model == "SVR":
|
466 |
+
if params == None:
|
467 |
+
regressor = SVR()
|
468 |
+
else:
|
469 |
+
regressor = SVR(**params)
|
470 |
+
model = TransformedTargetRegressor(regressor= regressor,
|
471 |
+
transformer = MinMaxScaler(feature_range=(-1, 1))
|
472 |
+
).fit(x_batch,y_batch)
|
473 |
+
|
474 |
+
y_prob = model.predict(x_batch_test)
|
475 |
+
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
|
476 |
+
|
477 |
+
print(f"RMSE Score: {RMSE_score:.4f}")
|
478 |
+
result = f"RMSE Score: {RMSE_score:.4f}"
|
479 |
+
|
480 |
+
print("Generating latent plots")
|
481 |
+
reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
|
482 |
+
verbose=False)
|
483 |
+
n_samples = np.minimum(1000, len(x_batch))
|
484 |
+
|
485 |
+
try: x = y_batch.values[:n_samples]
|
486 |
+
except: x = y_batch[:n_samples]
|
487 |
+
#index_0 = [index for index in range(len(x)) if x[index] == 0]
|
488 |
+
#index_1 = [index for index in range(len(x)) if x[index] == 1]
|
489 |
+
|
490 |
+
try:
|
491 |
+
features_umap = reducer.fit_transform(x_batch[:n_samples])
|
492 |
+
class_0 = features_umap#[index_0]
|
493 |
+
class_1 = features_umap#[index_1]
|
494 |
+
except:
|
495 |
+
class_0 = []
|
496 |
+
class_1 = []
|
497 |
+
print("Generating latent plots : Done")
|
498 |
+
|
499 |
+
return result, RMSE_score,y_batch_test, y_prob, class_0, class_1
|
500 |
+
|
501 |
+
elif downstream_model == "Kernel Ridge":
|
502 |
+
if params == None:
|
503 |
+
regressor = KernelRidge()
|
504 |
+
else:
|
505 |
+
regressor = KernelRidge(**params)
|
506 |
+
model = TransformedTargetRegressor(regressor=regressor,
|
507 |
+
transformer=MinMaxScaler(feature_range=(-1, 1))
|
508 |
+
).fit(x_batch, y_batch)
|
509 |
+
|
510 |
+
y_prob = model.predict(x_batch_test)
|
511 |
+
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
|
512 |
+
|
513 |
+
print(f"RMSE Score: {RMSE_score:.4f}")
|
514 |
+
result = f"RMSE Score: {RMSE_score:.4f}"
|
515 |
+
|
516 |
+
print("Generating latent plots")
|
517 |
+
reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
|
518 |
+
verbose=False)
|
519 |
+
n_samples = np.minimum(1000, len(x_batch))
|
520 |
+
features_umap = reducer.fit_transform(x_batch[:n_samples])
|
521 |
+
try: x = y_batch.values[:n_samples]
|
522 |
+
except: x = y_batch[:n_samples]
|
523 |
+
# index_0 = [index for index in range(len(x)) if x[index] == 0]
|
524 |
+
# index_1 = [index for index in range(len(x)) if x[index] == 1]
|
525 |
+
|
526 |
+
class_0 = features_umap#[index_0]
|
527 |
+
class_1 = features_umap#[index_1]
|
528 |
+
print("Generating latent plots : Done")
|
529 |
+
|
530 |
+
return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
|
531 |
+
|
532 |
+
|
533 |
+
elif downstream_model == "Linear Regression":
|
534 |
+
if params == None:
|
535 |
+
regressor = LinearRegression()
|
536 |
+
else:
|
537 |
+
regressor = LinearRegression(**params)
|
538 |
+
model = TransformedTargetRegressor(regressor=regressor,
|
539 |
+
transformer=MinMaxScaler(feature_range=(-1, 1))
|
540 |
+
).fit(x_batch, y_batch)
|
541 |
+
|
542 |
+
y_prob = model.predict(x_batch_test)
|
543 |
+
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
|
544 |
+
|
545 |
+
print(f"RMSE Score: {RMSE_score:.4f}")
|
546 |
+
result = f"RMSE Score: {RMSE_score:.4f}"
|
547 |
+
|
548 |
+
print("Generating latent plots")
|
549 |
+
reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
|
550 |
+
verbose=False)
|
551 |
+
n_samples = np.minimum(1000, len(x_batch))
|
552 |
+
features_umap = reducer.fit_transform(x_batch[:n_samples])
|
553 |
+
try:x = y_batch.values[:n_samples]
|
554 |
+
except: x = y_batch[:n_samples]
|
555 |
+
# index_0 = [index for index in range(len(x)) if x[index] == 0]
|
556 |
+
# index_1 = [index for index in range(len(x)) if x[index] == 1]
|
557 |
+
|
558 |
+
class_0 = features_umap#[index_0]
|
559 |
+
class_1 = features_umap#[index_1]
|
560 |
+
print("Generating latent plots : Done")
|
561 |
+
|
562 |
+
return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
|
563 |
+
|
564 |
+
|
565 |
+
elif downstream_model == "DefaultRegressor":
|
566 |
+
regressor = SVR(kernel="rbf", degree=3, C=5, gamma="scale", epsilon=0.01)
|
567 |
+
model = TransformedTargetRegressor(regressor=regressor,
|
568 |
+
transformer=MinMaxScaler(feature_range=(-1, 1))
|
569 |
+
).fit(x_batch, y_batch)
|
570 |
+
|
571 |
+
y_prob = model.predict(x_batch_test)
|
572 |
+
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
|
573 |
+
|
574 |
+
print(f"RMSE Score: {RMSE_score:.4f}")
|
575 |
+
result = f"RMSE Score: {RMSE_score:.4f}"
|
576 |
+
|
577 |
+
print("Generating latent plots")
|
578 |
+
reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
|
579 |
+
verbose=False)
|
580 |
+
n_samples = np.minimum(1000, len(x_batch))
|
581 |
+
features_umap = reducer.fit_transform(x_batch[:n_samples])
|
582 |
+
try:x = y_batch.values[:n_samples]
|
583 |
+
except: x = y_batch[:n_samples]
|
584 |
+
# index_0 = [index for index in range(len(x)) if x[index] == 0]
|
585 |
+
# index_1 = [index for index in range(len(x)) if x[index] == 1]
|
586 |
+
|
587 |
+
class_0 = features_umap#[index_0]
|
588 |
+
class_1 = features_umap#[index_1]
|
589 |
+
print("Generating latent plots : Done")
|
590 |
+
|
591 |
+
return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
|
592 |
+
|
593 |
+
|
594 |
+
def multi_modal(model_list,dataset=None, downstream_model=None,params=None, x_train=None, x_test=None, y_train=None, y_test=None):
|
595 |
+
#print(model_list)
|
596 |
+
data = avail_datasets()
|
597 |
+
df = pd.DataFrame(data)
|
598 |
+
list(df["Dataset"].values)
|
599 |
+
|
600 |
+
if dataset in list(df["Dataset"].values):
|
601 |
+
task = dataset
|
602 |
+
predefined = True
|
603 |
+
elif x_train==None:
|
604 |
+
predefined = False
|
605 |
+
components = dataset.split(",")
|
606 |
+
train_data = pd.read_csv(components[0])[components[2]]
|
607 |
+
test_data = pd.read_csv(components[1])[components[2]]
|
608 |
+
|
609 |
+
y_batch = pd.read_csv(components[0])[components[3]]
|
610 |
+
y_batch_test = pd.read_csv(components[1])[components[3]]
|
611 |
+
|
612 |
+
print("Custom Dataset loaded")
|
613 |
+
else:
|
614 |
+
predefined = False
|
615 |
+
y_batch = y_train
|
616 |
+
y_batch_test = y_test
|
617 |
+
train_data = x_train
|
618 |
+
test_data = x_test
|
619 |
+
|
620 |
+
data = avail_models(raw=True)
|
621 |
+
df = pd.DataFrame(data)
|
622 |
+
list(df["Name"].values)
|
623 |
+
|
624 |
+
alias = {"MHG-GED":"mhg", "SELFIES-TED": "bart", "MolFormer":"mol-xl", "Molformer": "mol-xl","SMI-TED":"smi-ted", "Mordred": "Mordred", "MorganFingerprint": "MorganFingerprint"}
|
625 |
+
#if set(model_list).issubset(list(df["Name"].values)):
|
626 |
+
if set(model_list).issubset(list(alias.keys())):
|
627 |
+
for i, model in enumerate(model_list):
|
628 |
+
if model in alias.keys():
|
629 |
+
model_type = alias[model]
|
630 |
+
else:
|
631 |
+
model_type = model
|
632 |
+
|
633 |
+
if i == 0:
|
634 |
+
if predefined:
|
635 |
+
with open(f"representation/{task}_{model_type}.pkl", "rb") as f1:
|
636 |
+
x_batch, y_batch, x_batch_test, y_batch_test = pickle.load(f1)
|
637 |
+
print(f" Loaded representation/{task}_{model_type}.pkl")
|
638 |
+
else:
|
639 |
+
x_batch, x_batch_test = get_representation(train_data, test_data, model_type)
|
640 |
+
x_batch = pd.DataFrame(x_batch)
|
641 |
+
x_batch_test = pd.DataFrame(x_batch_test)
|
642 |
+
|
643 |
+
else:
|
644 |
+
if predefined:
|
645 |
+
with open(f"representation/{task}_{model_type}.pkl", "rb") as f1:
|
646 |
+
x_batch_1, y_batch_1, x_batch_test_1, y_batch_test_1 = pickle.load(f1)
|
647 |
+
print(f" Loaded representation/{task}_{model_type}.pkl")
|
648 |
+
else:
|
649 |
+
x_batch_1, x_batch_test_1 = get_representation(train_data, test_data, model_type)
|
650 |
+
x_batch_1 = pd.DataFrame(x_batch_1)
|
651 |
+
x_batch_test_1 = pd.DataFrame(x_batch_test_1)
|
652 |
+
|
653 |
+
x_batch = pd.concat([x_batch, x_batch_1], axis=1)
|
654 |
+
x_batch_test = pd.concat([x_batch_test, x_batch_test_1], axis=1)
|
655 |
+
|
656 |
+
else:
|
657 |
+
print("Model not available")
|
658 |
+
return
|
659 |
+
|
660 |
+
num_columns = x_batch_test.shape[1]
|
661 |
+
x_batch_test.columns = [f'{i + 1}' for i in range(num_columns)]
|
662 |
+
|
663 |
+
num_columns = x_batch.shape[1]
|
664 |
+
x_batch.columns = [f'{i + 1}' for i in range(num_columns)]
|
665 |
+
|
666 |
+
# exclude row containing Nan value
|
667 |
+
if isinstance(x_batch, torch.Tensor):
|
668 |
+
x_batch = pd.DataFrame(x_batch)
|
669 |
+
nan_indices = x_batch.index[x_batch.isna().any(axis=1)]
|
670 |
+
if len(nan_indices) > 0:
|
671 |
+
x_batch.dropna(inplace = True)
|
672 |
+
for index in sorted(nan_indices, reverse=True):
|
673 |
+
del y_batch[index]
|
674 |
+
print(f'x_batch Nan index: {nan_indices}')
|
675 |
+
print(f'x_batch shape: {x_batch.shape}, y_batch len: {len(y_batch)}')
|
676 |
+
|
677 |
+
if isinstance(x_batch_test, torch.Tensor):
|
678 |
+
x_batch_test = pd.DataFrame(x_batch_test)
|
679 |
+
nan_indices = x_batch_test.index[x_batch_test.isna().any(axis=1)]
|
680 |
+
if len(nan_indices) > 0:
|
681 |
+
x_batch_test.dropna(inplace = True)
|
682 |
+
for index in sorted(nan_indices, reverse=True):
|
683 |
+
del y_batch_test[index]
|
684 |
+
print(f'x_batch_test Nan index: {nan_indices}')
|
685 |
+
print(f'x_batch_test shape: {x_batch_test.shape}, y_batch_test len: {len(y_batch_test)}')
|
686 |
+
|
687 |
+
print(f"Representations loaded successfully")
|
688 |
+
try:
|
689 |
+
with open(f"plot_emb/{task}_multi.pkl", "rb") as f1:
|
690 |
+
class_0, class_1 = pickle.load(f1)
|
691 |
+
except:
|
692 |
+
print("Generating latent plots")
|
693 |
+
reducer = umap.UMAP(metric='euclidean', n_neighbors=10, n_components=2, low_memory=True, min_dist=0.1,
|
694 |
+
verbose=False)
|
695 |
+
n_samples = np.minimum(1000, len(x_batch))
|
696 |
+
features_umap = reducer.fit_transform(x_batch[:n_samples])
|
697 |
+
|
698 |
+
if "Classifier" in downstream_model:
|
699 |
+
try: x = y_batch.values[:n_samples]
|
700 |
+
except: x = y_batch[:n_samples]
|
701 |
+
index_0 = [index for index in range(len(x)) if x[index] == 0]
|
702 |
+
index_1 = [index for index in range(len(x)) if x[index] == 1]
|
703 |
+
|
704 |
+
class_0 = features_umap[index_0]
|
705 |
+
class_1 = features_umap[index_1]
|
706 |
+
|
707 |
+
else:
|
708 |
+
class_0 = features_umap
|
709 |
+
class_1 = features_umap
|
710 |
+
|
711 |
+
print("Generating latent plots : Done")
|
712 |
+
|
713 |
+
print(f" Calculating ROC AUC Score ...")
|
714 |
+
|
715 |
+
|
716 |
+
if downstream_model == "XGBClassifier":
|
717 |
+
if params == None:
|
718 |
+
xgb_predict_concat = XGBClassifier()
|
719 |
+
else:
|
720 |
+
xgb_predict_concat = XGBClassifier(**params)#n_estimators=5000, learning_rate=0.01, max_depth=10)
|
721 |
+
xgb_predict_concat.fit(x_batch, y_batch)
|
722 |
+
|
723 |
+
y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
|
724 |
+
|
725 |
+
|
726 |
+
roc_auc = roc_auc_score(y_batch_test, y_prob)
|
727 |
+
fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
|
728 |
+
print(f"ROC-AUC Score: {roc_auc:.4f}")
|
729 |
+
|
730 |
+
#vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
|
731 |
+
|
732 |
+
#vizualize(x_batch_test, y_batch_test)
|
733 |
+
print(f"ROC-AUC Score: {roc_auc:.4f}")
|
734 |
+
result = f"ROC-AUC Score: {roc_auc:.4f}"
|
735 |
+
|
736 |
+
return result, roc_auc,fpr, tpr, class_0, class_1
|
737 |
+
|
738 |
+
elif downstream_model == "DefaultClassifier":
|
739 |
+
xgb_predict_concat = XGBClassifier()#n_estimators=5000, learning_rate=0.01, max_depth=10)
|
740 |
+
xgb_predict_concat.fit(x_batch, y_batch)
|
741 |
+
|
742 |
+
y_prob = xgb_predict_concat.predict_proba(x_batch_test)[:, 1]
|
743 |
+
|
744 |
+
|
745 |
+
roc_auc = roc_auc_score(y_batch_test, y_prob)
|
746 |
+
fpr, tpr, _ = roc_curve(y_batch_test, y_prob)
|
747 |
+
print(f"ROC-AUC Score: {roc_auc:.4f}")
|
748 |
+
|
749 |
+
#vizualize(roc_auc,fpr, tpr, x_batch, y_batch )
|
750 |
+
|
751 |
+
#vizualize(x_batch_test, y_batch_test)
|
752 |
+
print(f"ROC-AUC Score: {roc_auc:.4f}")
|
753 |
+
result = f"ROC-AUC Score: {roc_auc:.4f}"
|
754 |
+
|
755 |
+
return result, roc_auc,fpr, tpr, class_0, class_1
|
756 |
+
|
757 |
+
elif downstream_model == "SVR":
|
758 |
+
if params == None:
|
759 |
+
regressor = SVR()
|
760 |
+
else:
|
761 |
+
regressor = SVR(**params)
|
762 |
+
model = TransformedTargetRegressor(regressor= regressor,
|
763 |
+
transformer = MinMaxScaler(feature_range=(-1, 1))
|
764 |
+
).fit(x_batch,y_batch)
|
765 |
+
|
766 |
+
y_prob = model.predict(x_batch_test)
|
767 |
+
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
|
768 |
+
|
769 |
+
print(f"RMSE Score: {RMSE_score:.4f}")
|
770 |
+
result = f"RMSE Score: {RMSE_score:.4f}"
|
771 |
+
|
772 |
+
return result, RMSE_score,y_batch_test, y_prob, class_0, class_1
|
773 |
+
|
774 |
+
elif downstream_model == "Linear Regression":
|
775 |
+
if params == None:
|
776 |
+
regressor = LinearRegression()
|
777 |
+
else:
|
778 |
+
regressor = LinearRegression(**params)
|
779 |
+
model = TransformedTargetRegressor(regressor=regressor,
|
780 |
+
transformer=MinMaxScaler(feature_range=(-1, 1))
|
781 |
+
).fit(x_batch, y_batch)
|
782 |
+
|
783 |
+
y_prob = model.predict(x_batch_test)
|
784 |
+
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
|
785 |
+
|
786 |
+
print(f"RMSE Score: {RMSE_score:.4f}")
|
787 |
+
result = f"RMSE Score: {RMSE_score:.4f}"
|
788 |
+
|
789 |
+
return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
|
790 |
+
|
791 |
+
elif downstream_model == "Kernel Ridge":
|
792 |
+
if params == None:
|
793 |
+
regressor = KernelRidge()
|
794 |
+
else:
|
795 |
+
regressor = KernelRidge(**params)
|
796 |
+
model = TransformedTargetRegressor(regressor=regressor,
|
797 |
+
transformer=MinMaxScaler(feature_range=(-1, 1))
|
798 |
+
).fit(x_batch, y_batch)
|
799 |
+
|
800 |
+
y_prob = model.predict(x_batch_test)
|
801 |
+
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
|
802 |
+
|
803 |
+
print(f"RMSE Score: {RMSE_score:.4f}")
|
804 |
+
result = f"RMSE Score: {RMSE_score:.4f}"
|
805 |
+
|
806 |
+
return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
|
807 |
+
|
808 |
+
elif downstream_model == "DefaultRegressor":
|
809 |
+
regressor = SVR(kernel="rbf", degree=3, C=5, gamma="scale", epsilon=0.01)
|
810 |
+
model = TransformedTargetRegressor(regressor=regressor,
|
811 |
+
transformer=MinMaxScaler(feature_range=(-1, 1))
|
812 |
+
).fit(x_batch, y_batch)
|
813 |
+
|
814 |
+
y_prob = model.predict(x_batch_test)
|
815 |
+
RMSE_score = np.sqrt(mean_squared_error(y_batch_test, y_prob))
|
816 |
+
|
817 |
+
print(f"RMSE Score: {RMSE_score:.4f}")
|
818 |
+
result = f"RMSE Score: {RMSE_score:.4f}"
|
819 |
+
|
820 |
+
return result, RMSE_score, y_batch_test, y_prob, class_0, class_1
|
821 |
+
|
822 |
+
|
823 |
+
|
824 |
+
def finetune_optuna(x_batch,y_batch, x_batch_test, y_test ):
|
825 |
+
print(f" Finetuning with Optuna and calculating ROC AUC Score ...")
|
826 |
+
X_train = x_batch.values
|
827 |
+
y_train = y_batch.values
|
828 |
+
X_test = x_batch_test.values
|
829 |
+
y_test = y_test.values
|
830 |
+
def objective(trial):
|
831 |
+
# Define parameters to be optimized
|
832 |
+
params = {
|
833 |
+
# 'objective': 'binary:logistic',
|
834 |
+
'eval_metric': 'auc',
|
835 |
+
'verbosity': 0,
|
836 |
+
'n_estimators': trial.suggest_int('n_estimators', 1000, 10000),
|
837 |
+
# 'booster': trial.suggest_categorical('booster', ['gbtree', 'gblinear', 'dart']),
|
838 |
+
# 'lambda': trial.suggest_loguniform('lambda', 1e-8, 1.0),
|
839 |
+
'alpha': trial.suggest_loguniform('alpha', 1e-8, 1.0),
|
840 |
+
'max_depth': trial.suggest_int('max_depth', 1, 12),
|
841 |
+
# 'eta': trial.suggest_loguniform('eta', 1e-8, 1.0),
|
842 |
+
# 'gamma': trial.suggest_loguniform('gamma', 1e-8, 1.0),
|
843 |
+
# 'grow_policy': trial.suggest_categorical('grow_policy', ['depthwise', 'lossguide']),
|
844 |
+
# "subsample": trial.suggest_float("subsample", 0.05, 1.0),
|
845 |
+
# "colsample_bytree": trial.suggest_float("colsample_bytree", 0.05, 1.0),
|
846 |
+
}
|
847 |
+
|
848 |
+
# Train XGBoost model
|
849 |
+
dtrain = xgb.DMatrix(X_train, label=y_train)
|
850 |
+
dtest = xgb.DMatrix(X_test, label=y_test)
|
851 |
+
|
852 |
+
model = xgb.train(params, dtrain)
|
853 |
+
|
854 |
+
# Predict probabilities
|
855 |
+
y_pred = model.predict(dtest)
|
856 |
+
|
857 |
+
# Calculate ROC AUC score
|
858 |
+
roc_auc = roc_auc_score(y_test, y_pred)
|
859 |
+
print("ROC_AUC : ", roc_auc)
|
860 |
+
|
861 |
+
return roc_auc
|
862 |
+
|
863 |
+
def add_new_model():
|
864 |
+
models = avail_models(raw=True)
|
865 |
+
|
866 |
+
# Function to display models
|
867 |
+
def display_models():
|
868 |
+
for model in models:
|
869 |
+
model_display = f"Name: {model['Name']}, Description: {model['Description']}, Timestamp: {model['Timestamp']}"
|
870 |
+
print(model_display)
|
871 |
+
|
872 |
+
# Function to update models
|
873 |
+
def update_models(new_name, new_description, new_path):
|
874 |
+
new_model = {
|
875 |
+
"Name": new_name,
|
876 |
+
"Description": new_description,
|
877 |
+
"Timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
878 |
+
#"path": new_path
|
879 |
+
}
|
880 |
+
models.append(new_model)
|
881 |
+
with open("models.json", "w") as outfile:
|
882 |
+
json.dump(models, outfile)
|
883 |
+
|
884 |
+
print("Model uploaded and updated successfully!")
|
885 |
+
list_models()
|
886 |
+
#display_models()
|
887 |
+
|
888 |
+
# Widgets
|
889 |
+
name_text = widgets.Text(description="Name:", layout=Layout(width='50%'))
|
890 |
+
description_text = widgets.Text(description="Description:", layout=Layout(width='50%'))
|
891 |
+
path_text = widgets.Text(description="Path:", layout=Layout(width='50%'))
|
892 |
+
|
893 |
+
def browse_callback(b):
|
894 |
+
root = tk.Tk()
|
895 |
+
root.withdraw() # Hide the main window
|
896 |
+
file_path = filedialog.askopenfilename(title="Select a Model File")
|
897 |
+
if file_path:
|
898 |
+
path_text.value = file_path
|
899 |
+
|
900 |
+
browse_button = widgets.Button(description="Browse")
|
901 |
+
browse_button.on_click(browse_callback)
|
902 |
+
|
903 |
+
def submit_callback(b):
|
904 |
+
update_models(name_text.value, description_text.value, path_text.value)
|
905 |
+
|
906 |
+
submit_button = widgets.Button(description="Submit")
|
907 |
+
submit_button.on_click(submit_callback)
|
908 |
+
|
909 |
+
# Display widgets
|
910 |
+
display(VBox([name_text, description_text, path_text, browse_button, submit_button]))
|
911 |
+
|
912 |
+
|
913 |
+
def add_new_dataset():
|
914 |
+
# Sample data
|
915 |
+
datasets = avail_datasets()
|
916 |
+
|
917 |
+
# Function to display models
|
918 |
+
def display_datasets():
|
919 |
+
for dataset in datasets:
|
920 |
+
dataset_display = f"Name: {dataset['Dataset']}, Input: {dataset['Input']},Output: {dataset['Output']},Path: {dataset['Path']}, Timestamp: {dataset['Timestamp']}"
|
921 |
+
|
922 |
+
# Function to update models
|
923 |
+
def update_datasets(new_dataset, new_input, new_output, new_path):
|
924 |
+
new_model = {
|
925 |
+
"Dataset": new_dataset,
|
926 |
+
"Input": new_input,
|
927 |
+
"Output": new_output,
|
928 |
+
"Timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
929 |
+
"Path": os.path.basename(new_path)
|
930 |
+
}
|
931 |
+
datasets.append(new_model)
|
932 |
+
with open("datasets.json", "w") as outfile:
|
933 |
+
json.dump(datasets, outfile)
|
934 |
+
|
935 |
+
print("Dataset uploaded and updated successfully!")
|
936 |
+
list_data()
|
937 |
+
|
938 |
+
|
939 |
+
# Widgets
|
940 |
+
dataset_text = widgets.Text(description="Dataset:", layout=Layout(width='50%'))
|
941 |
+
input_text = widgets.Text(description="Input:", layout=Layout(width='50%'))
|
942 |
+
output_text = widgets.Text(description="Output:", layout=Layout(width='50%'))
|
943 |
+
path_text = widgets.Text(description="Path:", layout=Layout(width='50%'))
|
944 |
+
|
945 |
+
def browse_callback(b):
|
946 |
+
root = tk.Tk()
|
947 |
+
root.withdraw() # Hide the main window
|
948 |
+
file_path = filedialog.askopenfilename(title="Select a Dataset File")
|
949 |
+
if file_path:
|
950 |
+
path_text.value = file_path
|
951 |
+
|
952 |
+
browse_button = widgets.Button(description="Browse")
|
953 |
+
browse_button.on_click(browse_callback)
|
954 |
+
|
955 |
+
def submit_callback(b):
|
956 |
+
update_datasets(dataset_text.value, input_text.value, output_text.value, path_text.value)
|
957 |
+
|
958 |
+
submit_button = widgets.Button(description="Submit")
|
959 |
+
submit_button.on_click(submit_callback)
|
960 |
+
|
961 |
+
display(VBox([dataset_text, input_text, output_text, path_text, browse_button, submit_button]))
|
962 |
+
|
963 |
+
|
964 |
+
|
models/mhg_model/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
models/mhg_model/README.md
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# mhg-gnn
|
2 |
+
|
3 |
+
This repository provides PyTorch source code assosiated with our publication, "MHG-GNN: Combination of Molecular Hypergraph Grammar with Graph Neural Network"
|
4 |
+
|
5 |
+
**Paper:** [Arxiv Link](https://arxiv.org/pdf/2309.16374)
|
6 |
+
|
7 |
+

|
8 |
+
|
9 |
+
## Introduction
|
10 |
+
|
11 |
+
We present MHG-GNN, an autoencoder architecture
|
12 |
+
that has an encoder based on GNN and a decoder based on a sequential model with MHG.
|
13 |
+
Since the encoder is a GNN variant, MHG-GNN can accept any molecule as input, and
|
14 |
+
demonstrate high predictive performance on molecular graph data.
|
15 |
+
In addition, the decoder inherits the theoretical guarantee of MHG on always generating a structurally valid molecule as output.
|
16 |
+
|
17 |
+
## Table of Contents
|
18 |
+
|
19 |
+
1. [Getting Started](#getting-started)
|
20 |
+
1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
|
21 |
+
2. [Installation](#installation)
|
22 |
+
2. [Feature Extraction](#feature-extraction)
|
23 |
+
|
24 |
+
## Getting Started
|
25 |
+
|
26 |
+
**This code and environment have been tested on Intel E5-2667 CPUs at 3.30GHz and NVIDIA A100 Tensor Core GPUs.**
|
27 |
+
|
28 |
+
### Pretrained Models and Training Logs
|
29 |
+
|
30 |
+
We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link](https://huggingface.co/ibm/materials.mhg-ged/blob/main/mhggnn_pretrained_model_0724_2023.pickle)
|
31 |
+
|
32 |
+
Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs.
|
33 |
+
|
34 |
+
### Installation
|
35 |
+
|
36 |
+
We recommend to create a virtual environment. For example:
|
37 |
+
|
38 |
+
```
|
39 |
+
python3 -m venv .venv
|
40 |
+
. .venv/bin/activate
|
41 |
+
```
|
42 |
+
|
43 |
+
Type the following command once the virtual environment is activated:
|
44 |
+
|
45 |
+
```
|
46 |
+
git clone [email protected]:CMD-TRL/mhg-gnn.git
|
47 |
+
cd ./mhg-gnn
|
48 |
+
pip install .
|
49 |
+
```
|
50 |
+
|
51 |
+
## Feature Extraction
|
52 |
+
|
53 |
+
The example notebook [mhg-gnn_encoder_decoder_example.ipynb](notebooks/mhg-gnn_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks.
|
54 |
+
|
55 |
+
To load mhg-gnn, you can simply use:
|
56 |
+
|
57 |
+
```python
|
58 |
+
import torch
|
59 |
+
import load
|
60 |
+
|
61 |
+
model = load.load()
|
62 |
+
```
|
63 |
+
|
64 |
+
To encode SMILES into embeddings, you can use:
|
65 |
+
|
66 |
+
```python
|
67 |
+
with torch.no_grad():
|
68 |
+
repr = model.encode(["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"])
|
69 |
+
```
|
70 |
+
|
71 |
+
For decoder, you can use the function, so you can return from embeddings to SMILES strings:
|
72 |
+
|
73 |
+
```python
|
74 |
+
orig = model.decode(repr)
|
75 |
+
```
|
models/mhg_model/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# Rhizome
|
3 |
+
# Version beta 0.0, August 2023
|
4 |
+
# Property of IBM Research, Accelerated Discovery
|
5 |
+
#
|
models/mhg_model/graph_grammar/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
"""
|
8 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
9 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
10 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
11 |
+
"""
|
12 |
+
|
13 |
+
""" Title """
|
14 |
+
|
15 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
16 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
17 |
+
__version__ = "0.1"
|
18 |
+
__date__ = "Jan 1 2018"
|
19 |
+
|
models/mhg_model/graph_grammar/algo/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding:utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jan 1 2018"
|
20 |
+
|
models/mhg_model/graph_grammar/algo/tree_decomposition.py
ADDED
@@ -0,0 +1,821 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2017"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Dec 11 2017"
|
20 |
+
|
21 |
+
from copy import deepcopy
|
22 |
+
from itertools import combinations
|
23 |
+
from ..hypergraph import Hypergraph
|
24 |
+
import networkx as nx
|
25 |
+
import numpy as np
|
26 |
+
|
27 |
+
|
28 |
+
class CliqueTree(nx.Graph):
|
29 |
+
''' clique tree object
|
30 |
+
|
31 |
+
Attributes
|
32 |
+
----------
|
33 |
+
hg : Hypergraph
|
34 |
+
This hypergraph will be decomposed.
|
35 |
+
root_hg : Hypergraph
|
36 |
+
Hypergraph on the root node.
|
37 |
+
ident_node_dict : dict
|
38 |
+
ident_node_dict[key_node] gives a list of nodes that are identical (i.e., the adjacent hyperedges are common)
|
39 |
+
'''
|
40 |
+
def __init__(self, hg=None, **kwargs):
|
41 |
+
self.hg = deepcopy(hg)
|
42 |
+
if self.hg is not None:
|
43 |
+
self.ident_node_dict = self.hg.get_identical_node_dict()
|
44 |
+
else:
|
45 |
+
self.ident_node_dict = {}
|
46 |
+
super().__init__(**kwargs)
|
47 |
+
|
48 |
+
@property
|
49 |
+
def root_hg(self):
|
50 |
+
''' return the hypergraph on the root node
|
51 |
+
'''
|
52 |
+
return self.nodes[0]['subhg']
|
53 |
+
|
54 |
+
@root_hg.setter
|
55 |
+
def root_hg(self, hypergraph):
|
56 |
+
''' set the hypergraph on the root node
|
57 |
+
'''
|
58 |
+
self.nodes[0]['subhg'] = hypergraph
|
59 |
+
|
60 |
+
def insert_subhg(self, subhypergraph: Hypergraph) -> None:
|
61 |
+
''' insert a subhypergraph, which is extracted from a root hypergraph, into the tree.
|
62 |
+
|
63 |
+
Parameters
|
64 |
+
----------
|
65 |
+
subhg : Hypergraph
|
66 |
+
'''
|
67 |
+
num_nodes = self.number_of_nodes()
|
68 |
+
self.add_node(num_nodes, subhg=subhypergraph)
|
69 |
+
self.add_edge(num_nodes, 0)
|
70 |
+
adj_nodes = deepcopy(list(self.adj[0].keys()))
|
71 |
+
for each_node in adj_nodes:
|
72 |
+
if len(self.nodes[each_node]["subhg"].nodes.intersection(
|
73 |
+
self.nodes[num_nodes]["subhg"].nodes)\
|
74 |
+
- self.root_hg.nodes) != 0 and each_node != num_nodes:
|
75 |
+
self.remove_edge(0, each_node)
|
76 |
+
self.add_edge(each_node, num_nodes)
|
77 |
+
|
78 |
+
def to_irredundant(self) -> None:
|
79 |
+
''' convert the clique tree to be irredundant
|
80 |
+
'''
|
81 |
+
for each_node in self.hg.nodes:
|
82 |
+
subtree = self.subgraph([
|
83 |
+
each_tree_node for each_tree_node in self.nodes()\
|
84 |
+
if each_node in self.nodes[each_tree_node]["subhg"].nodes]).copy()
|
85 |
+
leaf_node_list = [x for x in subtree.nodes() if subtree.degree(x)==1]
|
86 |
+
redundant_leaf_node_list = []
|
87 |
+
for each_leaf_node in leaf_node_list:
|
88 |
+
if len(self.nodes[each_leaf_node]["subhg"].adj_edges(each_node)) == 0:
|
89 |
+
redundant_leaf_node_list.append(each_leaf_node)
|
90 |
+
for each_red_leaf_node in redundant_leaf_node_list:
|
91 |
+
current_node = each_red_leaf_node
|
92 |
+
while subtree.degree(current_node) == 1 \
|
93 |
+
and len(subtree.nodes[current_node]["subhg"].adj_edges(each_node)) == 0:
|
94 |
+
self.nodes[current_node]["subhg"].remove_node(each_node)
|
95 |
+
remove_node = current_node
|
96 |
+
current_node = list(dict(subtree[remove_node]).keys())[0]
|
97 |
+
subtree.remove_node(remove_node)
|
98 |
+
|
99 |
+
fixed_node_set = deepcopy(self.nodes)
|
100 |
+
for each_node in fixed_node_set:
|
101 |
+
if self.nodes[each_node]["subhg"].num_edges == 0:
|
102 |
+
if len(self[each_node]) == 1:
|
103 |
+
self.remove_node(each_node)
|
104 |
+
elif len(self[each_node]) == 2:
|
105 |
+
self.add_edge(*self[each_node])
|
106 |
+
self.remove_node(each_node)
|
107 |
+
else:
|
108 |
+
pass
|
109 |
+
else:
|
110 |
+
pass
|
111 |
+
|
112 |
+
redundant = True
|
113 |
+
while redundant:
|
114 |
+
redundant = False
|
115 |
+
fixed_edge_set = deepcopy(self.edges)
|
116 |
+
remove_node_set = set()
|
117 |
+
for node_1, node_2 in fixed_edge_set:
|
118 |
+
if node_1 in remove_node_set or node_2 in remove_node_set:
|
119 |
+
pass
|
120 |
+
else:
|
121 |
+
if self.nodes[node_1]['subhg'].is_subhg(self.nodes[node_2]['subhg']):
|
122 |
+
redundant = True
|
123 |
+
adj_node_list = set(self.adj[node_1]) - {node_2}
|
124 |
+
self.remove_node(node_1)
|
125 |
+
remove_node_set.add(node_1)
|
126 |
+
for each_node in adj_node_list:
|
127 |
+
self.add_edge(node_2, each_node)
|
128 |
+
|
129 |
+
elif self.nodes[node_2]['subhg'].is_subhg(self.nodes[node_1]['subhg']):
|
130 |
+
redundant = True
|
131 |
+
adj_node_list = set(self.adj[node_2]) - {node_1}
|
132 |
+
self.remove_node(node_2)
|
133 |
+
remove_node_set.add(node_2)
|
134 |
+
for each_node in adj_node_list:
|
135 |
+
self.add_edge(node_1, each_node)
|
136 |
+
|
137 |
+
def node_update(self, key_node: str, subhg) -> None:
|
138 |
+
""" given a pair of a hypergraph, H, and its subhypergraph, sH, return a hypergraph H\sH.
|
139 |
+
|
140 |
+
Parameters
|
141 |
+
----------
|
142 |
+
key_node : str
|
143 |
+
key node that must be removed.
|
144 |
+
subhg : Hypegraph
|
145 |
+
"""
|
146 |
+
for each_edge in subhg.edges:
|
147 |
+
self.root_hg.remove_edge(each_edge)
|
148 |
+
self.root_hg.remove_nodes(self.ident_node_dict[key_node])
|
149 |
+
|
150 |
+
adj_node_list = list(subhg.nodes)
|
151 |
+
for each_node in subhg.nodes:
|
152 |
+
if each_node not in self.ident_node_dict[key_node]:
|
153 |
+
if set(self.root_hg.adj_edges(each_node)).issubset(subhg.edges):
|
154 |
+
self.root_hg.remove_node(each_node)
|
155 |
+
adj_node_list.remove(each_node)
|
156 |
+
else:
|
157 |
+
adj_node_list.remove(each_node)
|
158 |
+
|
159 |
+
for each_node_1, each_node_2 in combinations(adj_node_list, 2):
|
160 |
+
if not self.root_hg.is_adj(each_node_1, each_node_2):
|
161 |
+
self.root_hg.add_edge(set([each_node_1, each_node_2]), attr_dict=dict(tmp=True))
|
162 |
+
|
163 |
+
subhg.remove_edges_with_attr({'tmp' : True})
|
164 |
+
self.insert_subhg(subhg)
|
165 |
+
|
166 |
+
def update(self, subhg, remove_nodes=False):
|
167 |
+
""" given a pair of a hypergraph, H, and its subhypergraph, sH, return a hypergraph H\sH.
|
168 |
+
|
169 |
+
Parameters
|
170 |
+
----------
|
171 |
+
subhg : Hypegraph
|
172 |
+
"""
|
173 |
+
for each_edge in subhg.edges:
|
174 |
+
self.root_hg.remove_edge(each_edge)
|
175 |
+
if remove_nodes:
|
176 |
+
remove_edge_list = []
|
177 |
+
for each_edge in self.root_hg.edges:
|
178 |
+
if set(self.root_hg.nodes_in_edge(each_edge)).issubset(subhg.nodes)\
|
179 |
+
and self.root_hg.edge_attr(each_edge).get('tmp', False):
|
180 |
+
remove_edge_list.append(each_edge)
|
181 |
+
self.root_hg.remove_edges(remove_edge_list)
|
182 |
+
|
183 |
+
adj_node_list = list(subhg.nodes)
|
184 |
+
for each_node in subhg.nodes:
|
185 |
+
if self.root_hg.degree(each_node) == 0:
|
186 |
+
self.root_hg.remove_node(each_node)
|
187 |
+
adj_node_list.remove(each_node)
|
188 |
+
|
189 |
+
if len(adj_node_list) != 1 and not remove_nodes:
|
190 |
+
self.root_hg.add_edge(set(adj_node_list), attr_dict=dict(tmp=True))
|
191 |
+
'''
|
192 |
+
else:
|
193 |
+
for each_node_1, each_node_2 in combinations(adj_node_list, 2):
|
194 |
+
if not self.root_hg.is_adj(each_node_1, each_node_2):
|
195 |
+
self.root_hg.add_edge(
|
196 |
+
[each_node_1, each_node_2], attr_dict=dict(tmp=True))
|
197 |
+
'''
|
198 |
+
subhg.remove_edges_with_attr({'tmp':True})
|
199 |
+
self.insert_subhg(subhg)
|
200 |
+
|
201 |
+
|
202 |
+
def _get_min_deg_node(hg, ident_node_dict: dict, mode='mol'):
|
203 |
+
if mode == 'standard':
|
204 |
+
degree_dict = hg.degrees()
|
205 |
+
min_deg_node = min(degree_dict, key=degree_dict.get)
|
206 |
+
min_deg_subhg = hg.adj_subhg(min_deg_node, ident_node_dict)
|
207 |
+
return min_deg_node, min_deg_subhg
|
208 |
+
elif mode == 'mol':
|
209 |
+
degree_dict = hg.degrees()
|
210 |
+
min_deg = min(degree_dict.values())
|
211 |
+
min_deg_node_list = [each_node for each_node in hg.nodes if degree_dict[each_node]==min_deg]
|
212 |
+
min_deg_subhg_list = [hg.adj_subhg(each_min_deg_node, ident_node_dict)
|
213 |
+
for each_min_deg_node in min_deg_node_list]
|
214 |
+
best_score = np.inf
|
215 |
+
best_idx = -1
|
216 |
+
for each_idx in range(len(min_deg_subhg_list)):
|
217 |
+
if min_deg_subhg_list[each_idx].num_nodes < best_score:
|
218 |
+
best_idx = each_idx
|
219 |
+
return min_deg_node_list[each_idx], min_deg_subhg_list[each_idx]
|
220 |
+
else:
|
221 |
+
raise ValueError
|
222 |
+
|
223 |
+
|
224 |
+
def tree_decomposition(hg, irredundant=True):
|
225 |
+
""" compute a tree decomposition of the input hypergraph
|
226 |
+
|
227 |
+
Parameters
|
228 |
+
----------
|
229 |
+
hg : Hypergraph
|
230 |
+
hypergraph to be decomposed
|
231 |
+
irredundant : bool
|
232 |
+
if True, irredundant tree decomposition will be computed.
|
233 |
+
|
234 |
+
Returns
|
235 |
+
-------
|
236 |
+
clique_tree : nx.Graph
|
237 |
+
each node contains a subhypergraph of `hg`
|
238 |
+
"""
|
239 |
+
org_hg = hg.copy()
|
240 |
+
ident_node_dict = hg.get_identical_node_dict()
|
241 |
+
clique_tree = CliqueTree(org_hg)
|
242 |
+
clique_tree.add_node(0, subhg=org_hg)
|
243 |
+
while True:
|
244 |
+
degree_dict = org_hg.degrees()
|
245 |
+
min_deg_node = min(degree_dict, key=degree_dict.get)
|
246 |
+
min_deg_subhg = org_hg.adj_subhg(min_deg_node, ident_node_dict)
|
247 |
+
if org_hg.nodes == min_deg_subhg.nodes:
|
248 |
+
break
|
249 |
+
|
250 |
+
# org_hg and min_deg_subhg are divided
|
251 |
+
clique_tree.node_update(min_deg_node, min_deg_subhg)
|
252 |
+
|
253 |
+
clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
|
254 |
+
|
255 |
+
if irredundant:
|
256 |
+
clique_tree.to_irredundant()
|
257 |
+
return clique_tree
|
258 |
+
|
259 |
+
|
260 |
+
def tree_decomposition_with_hrg(hg, hrg, irredundant=True, return_root=False):
|
261 |
+
''' compute a tree decomposition given a hyperedge replacement grammar.
|
262 |
+
the resultant clique tree should induce a less compact HRG.
|
263 |
+
|
264 |
+
Parameters
|
265 |
+
----------
|
266 |
+
hg : Hypergraph
|
267 |
+
hypergraph to be decomposed
|
268 |
+
hrg : HyperedgeReplacementGrammar
|
269 |
+
current HRG
|
270 |
+
irredundant : bool
|
271 |
+
if True, irredundant tree decomposition will be computed.
|
272 |
+
|
273 |
+
Returns
|
274 |
+
-------
|
275 |
+
clique_tree : nx.Graph
|
276 |
+
each node contains a subhypergraph of `hg`
|
277 |
+
'''
|
278 |
+
org_hg = hg.copy()
|
279 |
+
ident_node_dict = hg.get_identical_node_dict()
|
280 |
+
clique_tree = CliqueTree(org_hg)
|
281 |
+
clique_tree.add_node(0, subhg=org_hg)
|
282 |
+
root_node = 0
|
283 |
+
|
284 |
+
# construct a clique tree using HRG
|
285 |
+
success_any = True
|
286 |
+
while success_any:
|
287 |
+
success_any = False
|
288 |
+
for each_prod_rule in hrg.prod_rule_list:
|
289 |
+
org_hg, success, subhg = each_prod_rule.revert(org_hg, True)
|
290 |
+
if success:
|
291 |
+
if each_prod_rule.is_start_rule: root_node = clique_tree.number_of_nodes()
|
292 |
+
success_any = True
|
293 |
+
subhg.remove_edges_with_attr({'terminal' : False})
|
294 |
+
clique_tree.root_hg = org_hg
|
295 |
+
clique_tree.insert_subhg(subhg)
|
296 |
+
|
297 |
+
clique_tree.root_hg = org_hg
|
298 |
+
|
299 |
+
for each_edge in deepcopy(org_hg.edges):
|
300 |
+
if not org_hg.edge_attr(each_edge)['terminal']:
|
301 |
+
node_list = org_hg.nodes_in_edge(each_edge)
|
302 |
+
org_hg.remove_edge(each_edge)
|
303 |
+
|
304 |
+
for each_node_1, each_node_2 in combinations(node_list, 2):
|
305 |
+
if not org_hg.is_adj(each_node_1, each_node_2):
|
306 |
+
org_hg.add_edge([each_node_1, each_node_2], attr_dict=dict(tmp=True))
|
307 |
+
|
308 |
+
# construct a clique tree using the existing algorithm
|
309 |
+
degree_dict = org_hg.degrees()
|
310 |
+
if degree_dict:
|
311 |
+
while True:
|
312 |
+
min_deg_node, min_deg_subhg = _get_min_deg_node(org_hg, ident_node_dict)
|
313 |
+
if org_hg.nodes == min_deg_subhg.nodes: break
|
314 |
+
|
315 |
+
# org_hg and min_deg_subhg are divided
|
316 |
+
clique_tree.node_update(min_deg_node, min_deg_subhg)
|
317 |
+
|
318 |
+
clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
|
319 |
+
if irredundant:
|
320 |
+
clique_tree.to_irredundant()
|
321 |
+
|
322 |
+
if return_root:
|
323 |
+
if root_node == 0 and 0 not in clique_tree.nodes:
|
324 |
+
root_node = clique_tree.number_of_nodes()
|
325 |
+
while root_node not in clique_tree.nodes:
|
326 |
+
root_node -= 1
|
327 |
+
elif root_node not in clique_tree.nodes:
|
328 |
+
while root_node not in clique_tree.nodes:
|
329 |
+
root_node -= 1
|
330 |
+
else:
|
331 |
+
pass
|
332 |
+
return clique_tree, root_node
|
333 |
+
else:
|
334 |
+
return clique_tree
|
335 |
+
|
336 |
+
|
337 |
+
def tree_decomposition_from_leaf(hg, irredundant=True):
|
338 |
+
""" compute a tree decomposition of the input hypergraph
|
339 |
+
|
340 |
+
Parameters
|
341 |
+
----------
|
342 |
+
hg : Hypergraph
|
343 |
+
hypergraph to be decomposed
|
344 |
+
irredundant : bool
|
345 |
+
if True, irredundant tree decomposition will be computed.
|
346 |
+
|
347 |
+
Returns
|
348 |
+
-------
|
349 |
+
clique_tree : nx.Graph
|
350 |
+
each node contains a subhypergraph of `hg`
|
351 |
+
"""
|
352 |
+
def apply_normal_decomposition(clique_tree):
|
353 |
+
degree_dict = clique_tree.root_hg.degrees()
|
354 |
+
min_deg_node = min(degree_dict, key=degree_dict.get)
|
355 |
+
min_deg_subhg = clique_tree.root_hg.adj_subhg(min_deg_node, clique_tree.ident_node_dict)
|
356 |
+
if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
|
357 |
+
return clique_tree, False
|
358 |
+
clique_tree.node_update(min_deg_node, min_deg_subhg)
|
359 |
+
return clique_tree, True
|
360 |
+
|
361 |
+
def apply_min_edge_deg_decomposition(clique_tree):
|
362 |
+
edge_degree_dict = clique_tree.root_hg.edge_degrees()
|
363 |
+
non_tmp_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \
|
364 |
+
if not clique_tree.root_hg.edge_attr(each_edge).get('tmp')]
|
365 |
+
if not non_tmp_edge_list:
|
366 |
+
return clique_tree, False
|
367 |
+
min_deg_edge = None
|
368 |
+
min_deg = np.inf
|
369 |
+
for each_edge in non_tmp_edge_list:
|
370 |
+
if min_deg > edge_degree_dict[each_edge]:
|
371 |
+
min_deg_edge = each_edge
|
372 |
+
min_deg = edge_degree_dict[each_edge]
|
373 |
+
node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge)
|
374 |
+
min_deg_subhg = clique_tree.root_hg.get_subhg(
|
375 |
+
node_list, [min_deg_edge], clique_tree.ident_node_dict)
|
376 |
+
if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
|
377 |
+
return clique_tree, False
|
378 |
+
clique_tree.update(min_deg_subhg)
|
379 |
+
return clique_tree, True
|
380 |
+
|
381 |
+
org_hg = hg.copy()
|
382 |
+
clique_tree = CliqueTree(org_hg)
|
383 |
+
clique_tree.add_node(0, subhg=org_hg)
|
384 |
+
|
385 |
+
success = True
|
386 |
+
while success:
|
387 |
+
clique_tree, success = apply_min_edge_deg_decomposition(clique_tree)
|
388 |
+
if not success:
|
389 |
+
clique_tree, success = apply_normal_decomposition(clique_tree)
|
390 |
+
|
391 |
+
clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
|
392 |
+
if irredundant:
|
393 |
+
clique_tree.to_irredundant()
|
394 |
+
return clique_tree
|
395 |
+
|
396 |
+
def topological_tree_decomposition(
|
397 |
+
hg, irredundant=True, rip_labels=True, shrink_cycle=False, contract_cycles=False):
|
398 |
+
''' compute a tree decomposition of the input hypergraph
|
399 |
+
|
400 |
+
Parameters
|
401 |
+
----------
|
402 |
+
hg : Hypergraph
|
403 |
+
hypergraph to be decomposed
|
404 |
+
irredundant : bool
|
405 |
+
if True, irredundant tree decomposition will be computed.
|
406 |
+
|
407 |
+
Returns
|
408 |
+
-------
|
409 |
+
clique_tree : CliqueTree
|
410 |
+
each node contains a subhypergraph of `hg`
|
411 |
+
'''
|
412 |
+
def _contract_tree(clique_tree):
|
413 |
+
''' contract a single leaf
|
414 |
+
|
415 |
+
Parameters
|
416 |
+
----------
|
417 |
+
clique_tree : CliqueTree
|
418 |
+
|
419 |
+
Returns
|
420 |
+
-------
|
421 |
+
CliqueTree, bool
|
422 |
+
bool represents whether this operation succeeds or not.
|
423 |
+
'''
|
424 |
+
edge_degree_dict = clique_tree.root_hg.edge_degrees()
|
425 |
+
leaf_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \
|
426 |
+
if (not clique_tree.root_hg.edge_attr(each_edge).get('tmp'))\
|
427 |
+
and edge_degree_dict[each_edge] == 1]
|
428 |
+
if not leaf_edge_list:
|
429 |
+
return clique_tree, False
|
430 |
+
min_deg_edge = leaf_edge_list[0]
|
431 |
+
node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge)
|
432 |
+
min_deg_subhg = clique_tree.root_hg.get_subhg(
|
433 |
+
node_list, [min_deg_edge], clique_tree.ident_node_dict)
|
434 |
+
if clique_tree.root_hg.nodes == min_deg_subhg.nodes:
|
435 |
+
return clique_tree, False
|
436 |
+
clique_tree.update(min_deg_subhg)
|
437 |
+
return clique_tree, True
|
438 |
+
|
439 |
+
def _rip_labels_from_cycles(clique_tree, org_hg):
|
440 |
+
''' rip hyperedge-labels off
|
441 |
+
|
442 |
+
Parameters
|
443 |
+
----------
|
444 |
+
clique_tree : CliqueTree
|
445 |
+
org_hg : Hypergraph
|
446 |
+
|
447 |
+
Returns
|
448 |
+
-------
|
449 |
+
CliqueTree, bool
|
450 |
+
bool represents whether this operation succeeds or not.
|
451 |
+
'''
|
452 |
+
ident_node_dict = clique_tree.ident_node_dict #hg.get_identical_node_dict()
|
453 |
+
for each_edge in clique_tree.root_hg.edges:
|
454 |
+
if each_edge in org_hg.edges:
|
455 |
+
if org_hg.in_cycle(each_edge):
|
456 |
+
node_list = clique_tree.root_hg.nodes_in_edge(each_edge)
|
457 |
+
subhg = clique_tree.root_hg.get_subhg(
|
458 |
+
node_list, [each_edge], ident_node_dict)
|
459 |
+
if clique_tree.root_hg.nodes == subhg.nodes:
|
460 |
+
return clique_tree, False
|
461 |
+
clique_tree.update(subhg)
|
462 |
+
'''
|
463 |
+
in_cycle_dict = {each_node: org_hg.node_attr(each_node)['is_in_ring'] for each_node in node_list}
|
464 |
+
if not all(in_cycle_dict.values()):
|
465 |
+
node_not_in_cycle = [each_node for each_node in in_cycle_dict.keys() if not in_cycle_dict[each_node]][0]
|
466 |
+
node_list = [node_not_in_cycle]
|
467 |
+
node_list.extend(clique_tree.root_hg.adj_nodes(node_not_in_cycle))
|
468 |
+
edge_list = clique_tree.root_hg.adj_edges(node_not_in_cycle)
|
469 |
+
import pdb; pdb.set_trace()
|
470 |
+
subhg = clique_tree.root_hg.get_subhg(
|
471 |
+
node_list, edge_list, ident_node_dict)
|
472 |
+
|
473 |
+
clique_tree.update(subhg)
|
474 |
+
'''
|
475 |
+
return clique_tree, True
|
476 |
+
return clique_tree, False
|
477 |
+
|
478 |
+
def _shrink_cycle(clique_tree):
|
479 |
+
''' shrink a cycle
|
480 |
+
|
481 |
+
Parameters
|
482 |
+
----------
|
483 |
+
clique_tree : CliqueTree
|
484 |
+
|
485 |
+
Returns
|
486 |
+
-------
|
487 |
+
CliqueTree, bool
|
488 |
+
bool represents whether this operation succeeds or not.
|
489 |
+
'''
|
490 |
+
def filter_subhg(subhg, hg, key_node):
|
491 |
+
num_nodes_cycle = 0
|
492 |
+
nodes_in_cycle_list = []
|
493 |
+
for each_node in subhg.nodes:
|
494 |
+
if hg.in_cycle(each_node):
|
495 |
+
num_nodes_cycle += 1
|
496 |
+
if each_node != key_node:
|
497 |
+
nodes_in_cycle_list.append(each_node)
|
498 |
+
if num_nodes_cycle > 3:
|
499 |
+
break
|
500 |
+
if num_nodes_cycle != 3:
|
501 |
+
return False
|
502 |
+
else:
|
503 |
+
for each_edge in hg.edges:
|
504 |
+
if set(nodes_in_cycle_list).issubset(hg.nodes_in_edge(each_edge)):
|
505 |
+
return False
|
506 |
+
return True
|
507 |
+
|
508 |
+
#ident_node_dict = hg.get_identical_node_dict()
|
509 |
+
ident_node_dict = clique_tree.ident_node_dict
|
510 |
+
for each_node in clique_tree.root_hg.nodes:
|
511 |
+
if clique_tree.root_hg.in_cycle(each_node)\
|
512 |
+
and filter_subhg(clique_tree.root_hg.adj_subhg(each_node, ident_node_dict),
|
513 |
+
clique_tree.root_hg,
|
514 |
+
each_node):
|
515 |
+
target_node = each_node
|
516 |
+
target_subhg = clique_tree.root_hg.adj_subhg(target_node, ident_node_dict)
|
517 |
+
if clique_tree.root_hg.nodes == target_subhg.nodes:
|
518 |
+
return clique_tree, False
|
519 |
+
clique_tree.update(target_subhg)
|
520 |
+
return clique_tree, True
|
521 |
+
return clique_tree, False
|
522 |
+
|
523 |
+
def _contract_cycles(clique_tree):
|
524 |
+
'''
|
525 |
+
remove a subhypergraph that looks like a cycle on a leaf.
|
526 |
+
|
527 |
+
Parameters
|
528 |
+
----------
|
529 |
+
clique_tree : CliqueTree
|
530 |
+
|
531 |
+
Returns
|
532 |
+
-------
|
533 |
+
CliqueTree, bool
|
534 |
+
bool represents whether this operation succeeds or not.
|
535 |
+
'''
|
536 |
+
def _divide_hg(hg):
|
537 |
+
''' divide a hypergraph into subhypergraphs such that
|
538 |
+
each subhypergraph is connected to each other in a tree-like way.
|
539 |
+
|
540 |
+
Parameters
|
541 |
+
----------
|
542 |
+
hg : Hypergraph
|
543 |
+
|
544 |
+
Returns
|
545 |
+
-------
|
546 |
+
list of Hypergraphs
|
547 |
+
each element corresponds to a subhypergraph of `hg`
|
548 |
+
'''
|
549 |
+
for each_node in hg.nodes:
|
550 |
+
if hg.is_dividable(each_node):
|
551 |
+
adj_edges_dict = {each_edge: hg.in_cycle(each_edge) for each_edge in hg.adj_edges(each_node)}
|
552 |
+
'''
|
553 |
+
if any(adj_edges_dict.values()):
|
554 |
+
import pdb; pdb.set_trace()
|
555 |
+
edge_in_cycle = [each_key for each_key, each_val in adj_edges_dict.items() if each_val][0]
|
556 |
+
subhg1, subhg2, subhg3 = hg.divide(each_node, edge_in_cycle)
|
557 |
+
return _divide_hg(subhg1) + _divide_hg(subhg2) + _divide_hg(subhg3)
|
558 |
+
else:
|
559 |
+
'''
|
560 |
+
subhg1, subhg2 = hg.divide(each_node)
|
561 |
+
return _divide_hg(subhg1) + _divide_hg(subhg2)
|
562 |
+
return [hg]
|
563 |
+
|
564 |
+
def _is_leaf(hg, divided_subhg) -> bool:
|
565 |
+
''' judge whether subhg is a leaf-like in the original hypergraph
|
566 |
+
|
567 |
+
Parameters
|
568 |
+
----------
|
569 |
+
hg : Hypergraph
|
570 |
+
divided_subhg : Hypergraph
|
571 |
+
`divided_subhg` is a subhypergraph of `hg`
|
572 |
+
|
573 |
+
Returns
|
574 |
+
-------
|
575 |
+
bool
|
576 |
+
'''
|
577 |
+
'''
|
578 |
+
adj_edges_set = set([])
|
579 |
+
for each_node in divided_subhg.nodes:
|
580 |
+
adj_edges_set.update(set(hg.adj_edges(each_node)))
|
581 |
+
|
582 |
+
|
583 |
+
_hg = deepcopy(hg)
|
584 |
+
_hg.remove_subhg(divided_subhg)
|
585 |
+
if nx.is_connected(_hg.hg) != (len(adj_edges_set - divided_subhg.edges) == 1):
|
586 |
+
import pdb; pdb.set_trace()
|
587 |
+
return len(adj_edges_set - divided_subhg.edges) == 1
|
588 |
+
'''
|
589 |
+
_hg = deepcopy(hg)
|
590 |
+
_hg.remove_subhg(divided_subhg)
|
591 |
+
return nx.is_connected(_hg.hg)
|
592 |
+
|
593 |
+
subhg_list = _divide_hg(clique_tree.root_hg)
|
594 |
+
if len(subhg_list) == 1:
|
595 |
+
return clique_tree, False
|
596 |
+
else:
|
597 |
+
while len(subhg_list) > 1:
|
598 |
+
max_leaf_subhg = None
|
599 |
+
for each_subhg in subhg_list:
|
600 |
+
if _is_leaf(clique_tree.root_hg, each_subhg):
|
601 |
+
if max_leaf_subhg is None:
|
602 |
+
max_leaf_subhg = each_subhg
|
603 |
+
elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
|
604 |
+
max_leaf_subhg = each_subhg
|
605 |
+
clique_tree.update(max_leaf_subhg)
|
606 |
+
subhg_list.remove(max_leaf_subhg)
|
607 |
+
return clique_tree, True
|
608 |
+
|
609 |
+
org_hg = hg.copy()
|
610 |
+
clique_tree = CliqueTree(org_hg)
|
611 |
+
clique_tree.add_node(0, subhg=org_hg)
|
612 |
+
|
613 |
+
success = True
|
614 |
+
while success:
|
615 |
+
'''
|
616 |
+
clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
|
617 |
+
if not success:
|
618 |
+
clique_tree, success = _contract_cycles(clique_tree)
|
619 |
+
'''
|
620 |
+
clique_tree, success = _contract_tree(clique_tree)
|
621 |
+
if not success:
|
622 |
+
if rip_labels:
|
623 |
+
clique_tree, success = _rip_labels_from_cycles(clique_tree, hg)
|
624 |
+
if not success:
|
625 |
+
if shrink_cycle:
|
626 |
+
clique_tree, success = _shrink_cycle(clique_tree)
|
627 |
+
if not success:
|
628 |
+
if contract_cycles:
|
629 |
+
clique_tree, success = _contract_cycles(clique_tree)
|
630 |
+
clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
|
631 |
+
if irredundant:
|
632 |
+
clique_tree.to_irredundant()
|
633 |
+
return clique_tree
|
634 |
+
|
635 |
+
def molecular_tree_decomposition(hg, irredundant=True):
|
636 |
+
""" compute a tree decomposition of the input molecular hypergraph
|
637 |
+
|
638 |
+
Parameters
|
639 |
+
----------
|
640 |
+
hg : Hypergraph
|
641 |
+
molecular hypergraph to be decomposed
|
642 |
+
irredundant : bool
|
643 |
+
if True, irredundant tree decomposition will be computed.
|
644 |
+
|
645 |
+
Returns
|
646 |
+
-------
|
647 |
+
clique_tree : CliqueTree
|
648 |
+
each node contains a subhypergraph of `hg`
|
649 |
+
"""
|
650 |
+
def _divide_hg(hg):
|
651 |
+
''' divide a hypergraph into subhypergraphs such that
|
652 |
+
each subhypergraph is connected to each other in a tree-like way.
|
653 |
+
|
654 |
+
Parameters
|
655 |
+
----------
|
656 |
+
hg : Hypergraph
|
657 |
+
|
658 |
+
Returns
|
659 |
+
-------
|
660 |
+
list of Hypergraphs
|
661 |
+
each element corresponds to a subhypergraph of `hg`
|
662 |
+
'''
|
663 |
+
is_ring = False
|
664 |
+
for each_node in hg.nodes:
|
665 |
+
if hg.node_attr(each_node)['is_in_ring']:
|
666 |
+
is_ring = True
|
667 |
+
if not hg.node_attr(each_node)['is_in_ring'] \
|
668 |
+
and hg.degree(each_node) == 2:
|
669 |
+
subhg1, subhg2 = hg.divide(each_node)
|
670 |
+
return _divide_hg(subhg1) + _divide_hg(subhg2)
|
671 |
+
|
672 |
+
if is_ring:
|
673 |
+
subhg_list = []
|
674 |
+
remove_edge_list = []
|
675 |
+
remove_node_list = []
|
676 |
+
for each_edge in hg.edges:
|
677 |
+
node_list = hg.nodes_in_edge(each_edge)
|
678 |
+
subhg = hg.get_subhg(node_list, [each_edge], hg.get_identical_node_dict())
|
679 |
+
subhg_list.append(subhg)
|
680 |
+
remove_edge_list.append(each_edge)
|
681 |
+
for each_node in node_list:
|
682 |
+
if not hg.node_attr(each_node)['is_in_ring']:
|
683 |
+
remove_node_list.append(each_node)
|
684 |
+
hg.remove_edges(remove_edge_list)
|
685 |
+
hg.remove_nodes(remove_node_list, False)
|
686 |
+
return subhg_list + [hg]
|
687 |
+
else:
|
688 |
+
return [hg]
|
689 |
+
|
690 |
+
org_hg = hg.copy()
|
691 |
+
clique_tree = CliqueTree(org_hg)
|
692 |
+
clique_tree.add_node(0, subhg=org_hg)
|
693 |
+
|
694 |
+
subhg_list = _divide_hg(deepcopy(clique_tree.root_hg))
|
695 |
+
#_subhg_list = deepcopy(subhg_list)
|
696 |
+
if len(subhg_list) == 1:
|
697 |
+
pass
|
698 |
+
else:
|
699 |
+
while len(subhg_list) > 1:
|
700 |
+
max_leaf_subhg = None
|
701 |
+
for each_subhg in subhg_list:
|
702 |
+
if _is_leaf(clique_tree.root_hg, each_subhg) and not _is_ring(each_subhg):
|
703 |
+
if max_leaf_subhg is None:
|
704 |
+
max_leaf_subhg = each_subhg
|
705 |
+
elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
|
706 |
+
max_leaf_subhg = each_subhg
|
707 |
+
|
708 |
+
if max_leaf_subhg is None:
|
709 |
+
for each_subhg in subhg_list:
|
710 |
+
if _is_ring_label(clique_tree.root_hg, each_subhg):
|
711 |
+
if max_leaf_subhg is None:
|
712 |
+
max_leaf_subhg = each_subhg
|
713 |
+
elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
|
714 |
+
max_leaf_subhg = each_subhg
|
715 |
+
if max_leaf_subhg is not None:
|
716 |
+
clique_tree.update(max_leaf_subhg)
|
717 |
+
subhg_list.remove(max_leaf_subhg)
|
718 |
+
else:
|
719 |
+
for each_subhg in subhg_list:
|
720 |
+
if _is_leaf(clique_tree.root_hg, each_subhg):
|
721 |
+
if max_leaf_subhg is None:
|
722 |
+
max_leaf_subhg = each_subhg
|
723 |
+
elif max_leaf_subhg.num_nodes < each_subhg.num_nodes:
|
724 |
+
max_leaf_subhg = each_subhg
|
725 |
+
if max_leaf_subhg is not None:
|
726 |
+
clique_tree.update(max_leaf_subhg, True)
|
727 |
+
subhg_list.remove(max_leaf_subhg)
|
728 |
+
else:
|
729 |
+
break
|
730 |
+
if len(subhg_list) > 1:
|
731 |
+
'''
|
732 |
+
for each_idx, each_subhg in enumerate(subhg_list):
|
733 |
+
each_subhg.draw(f'{each_idx}', True)
|
734 |
+
clique_tree.root_hg.draw('root', True)
|
735 |
+
import pickle
|
736 |
+
with open('buggy_hg.pkl', 'wb') as f:
|
737 |
+
pickle.dump(hg, f)
|
738 |
+
return clique_tree, subhg_list, _subhg_list
|
739 |
+
'''
|
740 |
+
raise RuntimeError('bug in tree decomposition algorithm')
|
741 |
+
clique_tree.root_hg.remove_edges_with_attr({'tmp' : True})
|
742 |
+
|
743 |
+
'''
|
744 |
+
for each_tree_node in clique_tree.adj[0]:
|
745 |
+
subhg = clique_tree.nodes[each_tree_node]['subhg']
|
746 |
+
for each_edge in subhg.edges:
|
747 |
+
if set(subhg.nodes_in_edge(each_edge)).issubset(clique_tree.root_hg.nodes):
|
748 |
+
clique_tree.root_hg.add_edge(set(subhg.nodes_in_edge(each_edge)), attr_dict=dict(tmp=True))
|
749 |
+
'''
|
750 |
+
if irredundant:
|
751 |
+
clique_tree.to_irredundant()
|
752 |
+
return clique_tree #, _subhg_list
|
753 |
+
|
754 |
+
def _is_leaf(hg, subhg) -> bool:
|
755 |
+
''' judge whether subhg is a leaf-like in the original hypergraph
|
756 |
+
|
757 |
+
Parameters
|
758 |
+
----------
|
759 |
+
hg : Hypergraph
|
760 |
+
subhg : Hypergraph
|
761 |
+
`subhg` is a subhypergraph of `hg`
|
762 |
+
|
763 |
+
Returns
|
764 |
+
-------
|
765 |
+
bool
|
766 |
+
'''
|
767 |
+
if len(subhg.edges) == 0:
|
768 |
+
adj_edge_set = set([])
|
769 |
+
subhg_edge_set = set([])
|
770 |
+
for each_edge in hg.edges:
|
771 |
+
if set(hg.nodes_in_edge(each_edge)).issubset(subhg.nodes) and hg.edge_attr(each_edge).get('tmp', False):
|
772 |
+
subhg_edge_set.add(each_edge)
|
773 |
+
for each_node in subhg.nodes:
|
774 |
+
adj_edge_set.update(set(hg.adj_edges(each_node)))
|
775 |
+
if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
|
776 |
+
return True
|
777 |
+
else:
|
778 |
+
return False
|
779 |
+
elif len(subhg.edges) == 1:
|
780 |
+
adj_edge_set = set([])
|
781 |
+
subhg_edge_set = subhg.edges
|
782 |
+
for each_node in subhg.nodes:
|
783 |
+
for each_adj_edge in hg.adj_edges(each_node):
|
784 |
+
adj_edge_set.add(each_adj_edge)
|
785 |
+
if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1:
|
786 |
+
return True
|
787 |
+
else:
|
788 |
+
return False
|
789 |
+
else:
|
790 |
+
raise ValueError('subhg should be nodes only or one-edge hypergraph.')
|
791 |
+
|
792 |
+
def _is_ring_label(hg, subhg):
|
793 |
+
if len(subhg.edges) != 1:
|
794 |
+
return False
|
795 |
+
edge_name = list(subhg.edges)[0]
|
796 |
+
#assert edge_name in hg.edges, f'{edge_name}'
|
797 |
+
is_in_ring = False
|
798 |
+
for each_node in subhg.nodes:
|
799 |
+
if subhg.node_attr(each_node)['is_in_ring']:
|
800 |
+
is_in_ring = True
|
801 |
+
else:
|
802 |
+
adj_edge_list = list(hg.adj_edges(each_node))
|
803 |
+
adj_edge_list.remove(edge_name)
|
804 |
+
if len(adj_edge_list) == 1:
|
805 |
+
if not hg.edge_attr(adj_edge_list[0]).get('tmp', False):
|
806 |
+
return False
|
807 |
+
elif len(adj_edge_list) == 0:
|
808 |
+
pass
|
809 |
+
else:
|
810 |
+
raise ValueError
|
811 |
+
if is_in_ring:
|
812 |
+
return True
|
813 |
+
else:
|
814 |
+
return False
|
815 |
+
|
816 |
+
def _is_ring(hg):
|
817 |
+
for each_node in hg.nodes:
|
818 |
+
if not hg.node_attr(each_node)['is_in_ring']:
|
819 |
+
return False
|
820 |
+
return True
|
821 |
+
|
models/mhg_model/graph_grammar/graph_grammar/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jan 1 2018"
|
20 |
+
|
models/mhg_model/graph_grammar/graph_grammar/base.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2017"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Dec 11 2017"
|
20 |
+
|
21 |
+
from abc import ABCMeta, abstractmethod
|
22 |
+
|
23 |
+
class GraphGrammarBase(metaclass=ABCMeta):
|
24 |
+
@abstractmethod
|
25 |
+
def learn(self):
|
26 |
+
pass
|
27 |
+
|
28 |
+
@abstractmethod
|
29 |
+
def sample(self):
|
30 |
+
pass
|
models/mhg_model/graph_grammar/graph_grammar/corpus.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jun 4 2018"
|
20 |
+
|
21 |
+
from collections import Counter
|
22 |
+
from functools import partial
|
23 |
+
from .utils import _easy_node_match, _edge_match, _node_match, common_node_list, _node_match_prod_rule
|
24 |
+
from networkx.algorithms.isomorphism import GraphMatcher
|
25 |
+
import os
|
26 |
+
|
27 |
+
|
28 |
+
class CliqueTreeCorpus(object):
|
29 |
+
|
30 |
+
''' clique tree corpus
|
31 |
+
|
32 |
+
Attributes
|
33 |
+
----------
|
34 |
+
clique_tree_list : list of CliqueTree
|
35 |
+
subhg_list : list of Hypergraph
|
36 |
+
'''
|
37 |
+
|
38 |
+
def __init__(self):
|
39 |
+
self.clique_tree_list = []
|
40 |
+
self.subhg_list = []
|
41 |
+
|
42 |
+
@property
|
43 |
+
def size(self):
|
44 |
+
return len(self.subhg_list)
|
45 |
+
|
46 |
+
def add_clique_tree(self, clique_tree):
|
47 |
+
for each_node in clique_tree.nodes:
|
48 |
+
subhg = clique_tree.nodes[each_node]['subhg']
|
49 |
+
subhg_idx = self.add_subhg(subhg)
|
50 |
+
clique_tree.nodes[each_node]['subhg_idx'] = subhg_idx
|
51 |
+
self.clique_tree_list.append(clique_tree)
|
52 |
+
|
53 |
+
def add_to_subhg_list(self, clique_tree, root_node):
|
54 |
+
parent_node_dict = {}
|
55 |
+
current_node = None
|
56 |
+
parent_node_dict[root_node] = None
|
57 |
+
stack = [root_node]
|
58 |
+
while stack:
|
59 |
+
current_node = stack.pop()
|
60 |
+
current_subhg = clique_tree.nodes[current_node]['subhg']
|
61 |
+
for each_child in clique_tree.adj[current_node]:
|
62 |
+
if each_child != parent_node_dict[current_node]:
|
63 |
+
stack.append(each_child)
|
64 |
+
parent_node_dict[each_child] = current_node
|
65 |
+
if parent_node_dict[current_node] is not None:
|
66 |
+
parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg']
|
67 |
+
common, _ = common_node_list(parent_subhg, current_subhg)
|
68 |
+
parent_subhg.add_edge(set(common), attr_dict={'tmp': True})
|
69 |
+
|
70 |
+
parent_node_dict = {}
|
71 |
+
current_node = None
|
72 |
+
parent_node_dict[root_node] = None
|
73 |
+
stack = [root_node]
|
74 |
+
while stack:
|
75 |
+
current_node = stack.pop()
|
76 |
+
current_subhg = clique_tree.nodes[current_node]['subhg']
|
77 |
+
for each_child in clique_tree.adj[current_node]:
|
78 |
+
if each_child != parent_node_dict[current_node]:
|
79 |
+
stack.append(each_child)
|
80 |
+
parent_node_dict[each_child] = current_node
|
81 |
+
if parent_node_dict[current_node] is not None:
|
82 |
+
parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg']
|
83 |
+
common, _ = common_node_list(parent_subhg, current_subhg)
|
84 |
+
for each_idx, each_node in enumerate(common):
|
85 |
+
current_subhg.set_node_attr(each_node, {'ext_id': each_idx})
|
86 |
+
|
87 |
+
subhg_idx, is_new = self.add_subhg(current_subhg)
|
88 |
+
clique_tree.nodes[current_node]['subhg_idx'] = subhg_idx
|
89 |
+
return clique_tree
|
90 |
+
|
91 |
+
def add_subhg(self, subhg):
|
92 |
+
if len(self.subhg_list) == 0:
|
93 |
+
node_dict = {}
|
94 |
+
for each_node in subhg.nodes:
|
95 |
+
node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__()
|
96 |
+
node_list = []
|
97 |
+
for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
|
98 |
+
node_list.append(each_key)
|
99 |
+
for each_idx, each_node in enumerate(node_list):
|
100 |
+
subhg.node_attr(each_node)['order4hrg'] = each_idx
|
101 |
+
self.subhg_list.append(subhg)
|
102 |
+
return 0, True
|
103 |
+
else:
|
104 |
+
match = False
|
105 |
+
subhg_bond_symbol_counter \
|
106 |
+
= Counter([subhg.node_attr(each_node)['symbol'] \
|
107 |
+
for each_node in subhg.nodes])
|
108 |
+
subhg_atom_symbol_counter \
|
109 |
+
= Counter([subhg.edge_attr(each_edge).get('symbol', None) \
|
110 |
+
for each_edge in subhg.edges])
|
111 |
+
for each_idx, each_subhg in enumerate(self.subhg_list):
|
112 |
+
each_bond_symbol_counter \
|
113 |
+
= Counter([each_subhg.node_attr(each_node)['symbol'] \
|
114 |
+
for each_node in each_subhg.nodes])
|
115 |
+
each_atom_symbol_counter \
|
116 |
+
= Counter([each_subhg.edge_attr(each_edge).get('symbol', None) \
|
117 |
+
for each_edge in each_subhg.edges])
|
118 |
+
if not match \
|
119 |
+
and (subhg.num_nodes == each_subhg.num_nodes
|
120 |
+
and subhg.num_edges == each_subhg.num_edges
|
121 |
+
and subhg_bond_symbol_counter == each_bond_symbol_counter
|
122 |
+
and subhg_atom_symbol_counter == each_atom_symbol_counter):
|
123 |
+
gm = GraphMatcher(each_subhg.hg,
|
124 |
+
subhg.hg,
|
125 |
+
node_match=_easy_node_match,
|
126 |
+
edge_match=_edge_match)
|
127 |
+
try:
|
128 |
+
isomap = next(gm.isomorphisms_iter())
|
129 |
+
match = True
|
130 |
+
for each_node in each_subhg.nodes:
|
131 |
+
subhg.node_attr(isomap[each_node])['order4hrg'] \
|
132 |
+
= each_subhg.node_attr(each_node)['order4hrg']
|
133 |
+
if 'ext_id' in each_subhg.node_attr(each_node):
|
134 |
+
subhg.node_attr(isomap[each_node])['ext_id'] \
|
135 |
+
= each_subhg.node_attr(each_node)['ext_id']
|
136 |
+
return each_idx, False
|
137 |
+
except StopIteration:
|
138 |
+
match = False
|
139 |
+
if not match:
|
140 |
+
node_dict = {}
|
141 |
+
for each_node in subhg.nodes:
|
142 |
+
node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__()
|
143 |
+
node_list = []
|
144 |
+
for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
|
145 |
+
node_list.append(each_key)
|
146 |
+
for each_idx, each_node in enumerate(node_list):
|
147 |
+
subhg.node_attr(each_node)['order4hrg'] = each_idx
|
148 |
+
|
149 |
+
#for each_idx, each_node in enumerate(subhg.nodes):
|
150 |
+
# subhg.node_attr(each_node)['order4hrg'] = each_idx
|
151 |
+
self.subhg_list.append(subhg)
|
152 |
+
return len(self.subhg_list) - 1, True
|
models/mhg_model/graph_grammar/graph_grammar/hrg.py
ADDED
@@ -0,0 +1,1065 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2017"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Dec 11 2017"
|
20 |
+
|
21 |
+
from .corpus import CliqueTreeCorpus
|
22 |
+
from .base import GraphGrammarBase
|
23 |
+
from .symbols import TSymbol, NTSymbol, BondSymbol
|
24 |
+
from .utils import _node_match, _node_match_prod_rule, _edge_match, masked_softmax, common_node_list
|
25 |
+
from ..hypergraph import Hypergraph
|
26 |
+
from collections import Counter
|
27 |
+
from copy import deepcopy
|
28 |
+
from ..algo.tree_decomposition import (
|
29 |
+
tree_decomposition,
|
30 |
+
tree_decomposition_with_hrg,
|
31 |
+
tree_decomposition_from_leaf,
|
32 |
+
topological_tree_decomposition,
|
33 |
+
molecular_tree_decomposition)
|
34 |
+
from functools import partial
|
35 |
+
from networkx.algorithms.isomorphism import GraphMatcher
|
36 |
+
from typing import List, Dict, Tuple
|
37 |
+
import networkx as nx
|
38 |
+
import numpy as np
|
39 |
+
import torch
|
40 |
+
import os
|
41 |
+
import random
|
42 |
+
|
43 |
+
DEBUG = False
|
44 |
+
|
45 |
+
|
46 |
+
class ProductionRule(object):
|
47 |
+
""" A class of a production rule
|
48 |
+
|
49 |
+
Attributes
|
50 |
+
----------
|
51 |
+
lhs : Hypergraph or None
|
52 |
+
the left hand side of the production rule.
|
53 |
+
if None, the rule is a starting rule.
|
54 |
+
rhs : Hypergraph
|
55 |
+
the right hand side of the production rule.
|
56 |
+
"""
|
57 |
+
def __init__(self, lhs, rhs):
|
58 |
+
self.lhs = lhs
|
59 |
+
self.rhs = rhs
|
60 |
+
|
61 |
+
@property
|
62 |
+
def is_start_rule(self) -> bool:
|
63 |
+
return self.lhs.num_nodes == 0
|
64 |
+
|
65 |
+
@property
|
66 |
+
def ext_node(self) -> Dict[int, str]:
|
67 |
+
""" return a dict of external nodes
|
68 |
+
"""
|
69 |
+
if self.is_start_rule:
|
70 |
+
return {}
|
71 |
+
else:
|
72 |
+
ext_node_dict = {}
|
73 |
+
for each_node in self.lhs.nodes:
|
74 |
+
ext_node_dict[self.lhs.node_attr(each_node)["ext_id"]] = each_node
|
75 |
+
return ext_node_dict
|
76 |
+
|
77 |
+
@property
|
78 |
+
def lhs_nt_symbol(self) -> NTSymbol:
|
79 |
+
if self.is_start_rule:
|
80 |
+
return NTSymbol(degree=0, is_aromatic=False, bond_symbol_list=[])
|
81 |
+
else:
|
82 |
+
return self.lhs.edge_attr(list(self.lhs.edges)[0])['symbol']
|
83 |
+
|
84 |
+
def rhs_adj_mat(self, node_edge_list):
|
85 |
+
''' return the adjacency matrix of rhs of the production rule
|
86 |
+
'''
|
87 |
+
return nx.adjacency_matrix(self.rhs.hg, node_edge_list)
|
88 |
+
|
89 |
+
def draw(self, file_path=None):
|
90 |
+
return self.rhs.draw(file_path)
|
91 |
+
|
92 |
+
def is_same(self, prod_rule, ignore_order=False):
|
93 |
+
""" judge whether this production rule is
|
94 |
+
the same as the input one, `prod_rule`
|
95 |
+
|
96 |
+
Parameters
|
97 |
+
----------
|
98 |
+
prod_rule : ProductionRule
|
99 |
+
production rule to be compared
|
100 |
+
|
101 |
+
Returns
|
102 |
+
-------
|
103 |
+
is_same : bool
|
104 |
+
isomap : dict
|
105 |
+
isomorphism of nodes and hyperedges.
|
106 |
+
ex) {'bond_42': 'bond_37', 'bond_2': 'bond_1',
|
107 |
+
'e36': 'e11', 'e16': 'e12', 'e25': 'e18',
|
108 |
+
'bond_40': 'bond_38', 'e26': 'e21', 'bond_41': 'bond_39'}.
|
109 |
+
key comes from `prod_rule`, value comes from `self`.
|
110 |
+
"""
|
111 |
+
if self.is_start_rule:
|
112 |
+
if not prod_rule.is_start_rule:
|
113 |
+
return False, {}
|
114 |
+
else:
|
115 |
+
if prod_rule.is_start_rule:
|
116 |
+
return False, {}
|
117 |
+
else:
|
118 |
+
if prod_rule.lhs.num_nodes != self.lhs.num_nodes:
|
119 |
+
return False, {}
|
120 |
+
|
121 |
+
if prod_rule.rhs.num_nodes != self.rhs.num_nodes:
|
122 |
+
return False, {}
|
123 |
+
if prod_rule.rhs.num_edges != self.rhs.num_edges:
|
124 |
+
return False, {}
|
125 |
+
|
126 |
+
subhg_bond_symbol_counter \
|
127 |
+
= Counter([prod_rule.rhs.node_attr(each_node)['symbol'] \
|
128 |
+
for each_node in prod_rule.rhs.nodes])
|
129 |
+
each_bond_symbol_counter \
|
130 |
+
= Counter([self.rhs.node_attr(each_node)['symbol'] \
|
131 |
+
for each_node in self.rhs.nodes])
|
132 |
+
if subhg_bond_symbol_counter != each_bond_symbol_counter:
|
133 |
+
return False, {}
|
134 |
+
|
135 |
+
subhg_atom_symbol_counter \
|
136 |
+
= Counter([prod_rule.rhs.edge_attr(each_edge)['symbol'] \
|
137 |
+
for each_edge in prod_rule.rhs.edges])
|
138 |
+
each_atom_symbol_counter \
|
139 |
+
= Counter([self.rhs.edge_attr(each_edge)['symbol'] \
|
140 |
+
for each_edge in self.rhs.edges])
|
141 |
+
if subhg_atom_symbol_counter != each_atom_symbol_counter:
|
142 |
+
return False, {}
|
143 |
+
|
144 |
+
gm = GraphMatcher(prod_rule.rhs.hg,
|
145 |
+
self.rhs.hg,
|
146 |
+
partial(_node_match_prod_rule,
|
147 |
+
ignore_order=ignore_order),
|
148 |
+
partial(_edge_match,
|
149 |
+
ignore_order=ignore_order))
|
150 |
+
try:
|
151 |
+
return True, next(gm.isomorphisms_iter())
|
152 |
+
except StopIteration:
|
153 |
+
return False, {}
|
154 |
+
|
155 |
+
def applied_to(self,
|
156 |
+
hg: Hypergraph,
|
157 |
+
edge: str) -> Tuple[Hypergraph, List[str]]:
|
158 |
+
""" augment `hg` by replacing `edge` with `self.rhs`.
|
159 |
+
|
160 |
+
Parameters
|
161 |
+
----------
|
162 |
+
hg : Hypergraph
|
163 |
+
edge : str
|
164 |
+
`edge` must belong to `hg`
|
165 |
+
|
166 |
+
Returns
|
167 |
+
-------
|
168 |
+
hg : Hypergraph
|
169 |
+
resultant hypergraph
|
170 |
+
nt_edge_list : list
|
171 |
+
list of non-terminal edges
|
172 |
+
"""
|
173 |
+
nt_edge_dict = {}
|
174 |
+
if self.is_start_rule:
|
175 |
+
if (edge is not None) or (hg is not None):
|
176 |
+
ValueError("edge and hg must be None for this prod rule.")
|
177 |
+
hg = Hypergraph()
|
178 |
+
node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented.
|
179 |
+
for num_idx, each_node in enumerate(self.rhs.nodes):
|
180 |
+
hg.add_node(f"bond_{num_idx}",
|
181 |
+
#attr_dict=deepcopy(self.rhs.node_attr(each_node)))
|
182 |
+
attr_dict=self.rhs.node_attr(each_node))
|
183 |
+
node_map_rhs[each_node] = f"bond_{num_idx}"
|
184 |
+
for each_edge in self.rhs.edges:
|
185 |
+
node_list = []
|
186 |
+
for each_node in self.rhs.nodes_in_edge(each_edge):
|
187 |
+
node_list.append(node_map_rhs[each_node])
|
188 |
+
if isinstance(self.rhs.nodes_in_edge(each_edge), set):
|
189 |
+
node_list = set(node_list)
|
190 |
+
edge_id = hg.add_edge(
|
191 |
+
node_list,
|
192 |
+
#attr_dict=deepcopy(self.rhs.edge_attr(each_edge)))
|
193 |
+
attr_dict=self.rhs.edge_attr(each_edge))
|
194 |
+
if "nt_idx" in hg.edge_attr(edge_id):
|
195 |
+
nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
|
196 |
+
nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
|
197 |
+
return hg, nt_edge_list
|
198 |
+
else:
|
199 |
+
if edge not in hg.edges:
|
200 |
+
raise ValueError("the input hyperedge does not exist.")
|
201 |
+
if hg.edge_attr(edge)["terminal"]:
|
202 |
+
raise ValueError("the input hyperedge is terminal.")
|
203 |
+
if hg.edge_attr(edge)['symbol'] != self.lhs_nt_symbol:
|
204 |
+
print(hg.edge_attr(edge)['symbol'], self.lhs_nt_symbol)
|
205 |
+
raise ValueError("the input hyperedge and lhs have inconsistent number of nodes.")
|
206 |
+
if DEBUG:
|
207 |
+
for node_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
|
208 |
+
other_node = self.lhs.nodes_in_edge(list(self.lhs.edges)[0])[node_idx]
|
209 |
+
attr = deepcopy(self.lhs.node_attr(other_node))
|
210 |
+
attr.pop('ext_id')
|
211 |
+
if hg.node_attr(each_node) != attr:
|
212 |
+
raise ValueError('node attributes are inconsistent.')
|
213 |
+
|
214 |
+
# order of nodes that belong to the non-terminal edge in hg
|
215 |
+
nt_order_dict = {} # hg_node -> order ("bond_17" : 1)
|
216 |
+
nt_order_dict_inv = {} # order -> hg_node
|
217 |
+
for each_idx, each_node in enumerate(hg.nodes_in_edge(edge)):
|
218 |
+
nt_order_dict[each_node] = each_idx
|
219 |
+
nt_order_dict_inv[each_idx] = each_node
|
220 |
+
|
221 |
+
# construct a node_map_rhs: rhs -> new hg
|
222 |
+
node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented.
|
223 |
+
node_idx = hg.num_nodes
|
224 |
+
for each_node in self.rhs.nodes:
|
225 |
+
if "ext_id" in self.rhs.node_attr(each_node):
|
226 |
+
node_map_rhs[each_node] \
|
227 |
+
= nt_order_dict_inv[
|
228 |
+
self.rhs.node_attr(each_node)["ext_id"]]
|
229 |
+
else:
|
230 |
+
node_map_rhs[each_node] = f"bond_{node_idx}"
|
231 |
+
node_idx += 1
|
232 |
+
|
233 |
+
# delete non-terminal
|
234 |
+
hg.remove_edge(edge)
|
235 |
+
|
236 |
+
# add nodes to hg
|
237 |
+
for each_node in self.rhs.nodes:
|
238 |
+
hg.add_node(node_map_rhs[each_node],
|
239 |
+
attr_dict=self.rhs.node_attr(each_node))
|
240 |
+
|
241 |
+
# add hyperedges to hg
|
242 |
+
for each_edge in self.rhs.edges:
|
243 |
+
node_list_hg = []
|
244 |
+
for each_node in self.rhs.nodes_in_edge(each_edge):
|
245 |
+
node_list_hg.append(node_map_rhs[each_node])
|
246 |
+
edge_id = hg.add_edge(
|
247 |
+
node_list_hg,
|
248 |
+
attr_dict=self.rhs.edge_attr(each_edge))#deepcopy(self.rhs.edge_attr(each_edge)))
|
249 |
+
if "nt_idx" in hg.edge_attr(edge_id):
|
250 |
+
nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id
|
251 |
+
nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))]
|
252 |
+
return hg, nt_edge_list
|
253 |
+
|
254 |
+
def revert(self, hg: Hypergraph, return_subhg=False):
|
255 |
+
''' revert applying this production rule.
|
256 |
+
i.e., if there exists a subhypergraph that matches the r.h.s. of this production rule,
|
257 |
+
this method replaces the subhypergraph with a non-terminal hyperedge.
|
258 |
+
|
259 |
+
Parameters
|
260 |
+
----------
|
261 |
+
hg : Hypergraph
|
262 |
+
hypergraph to be reverted
|
263 |
+
return_subhg : bool
|
264 |
+
if True, the removed subhypergraph will be returned.
|
265 |
+
|
266 |
+
Returns
|
267 |
+
-------
|
268 |
+
hg : Hypergraph
|
269 |
+
the resultant hypergraph. if it cannot be reverted, the original one is returned without any replacement.
|
270 |
+
success : bool
|
271 |
+
this indicates whether reverting is successed or not.
|
272 |
+
'''
|
273 |
+
gm = GraphMatcher(hg.hg, self.rhs.hg, node_match=_node_match_prod_rule,
|
274 |
+
edge_match=_edge_match)
|
275 |
+
try:
|
276 |
+
# in case when the matched subhg is connected to the other part via external nodes and more.
|
277 |
+
not_iso = True
|
278 |
+
while not_iso:
|
279 |
+
isomap = next(gm.subgraph_isomorphisms_iter())
|
280 |
+
adj_node_set = set([]) # reachable nodes from the internal nodes
|
281 |
+
subhg_node_set = set(isomap.keys()) # nodes in subhg
|
282 |
+
for each_node in subhg_node_set:
|
283 |
+
adj_node_set.add(each_node)
|
284 |
+
if isomap[each_node] not in self.ext_node.values():
|
285 |
+
adj_node_set.update(hg.hg.adj[each_node])
|
286 |
+
if adj_node_set == subhg_node_set:
|
287 |
+
not_iso = False
|
288 |
+
else:
|
289 |
+
if return_subhg:
|
290 |
+
return hg, False, Hypergraph()
|
291 |
+
else:
|
292 |
+
return hg, False
|
293 |
+
inv_isomap = {v: k for k, v in isomap.items()}
|
294 |
+
'''
|
295 |
+
isomap = {'e35': 'e8', 'bond_13': 'bond_18', 'bond_14': 'bond_19',
|
296 |
+
'bond_15': 'bond_17', 'e29': 'e23', 'bond_12': 'bond_20'}
|
297 |
+
where keys come from `hg` and values come from `self.rhs`
|
298 |
+
'''
|
299 |
+
except StopIteration:
|
300 |
+
if return_subhg:
|
301 |
+
return hg, False, Hypergraph()
|
302 |
+
else:
|
303 |
+
return hg, False
|
304 |
+
|
305 |
+
if return_subhg:
|
306 |
+
subhg = Hypergraph()
|
307 |
+
for each_node in hg.nodes:
|
308 |
+
if each_node in isomap:
|
309 |
+
subhg.add_node(each_node, attr_dict=hg.node_attr(each_node))
|
310 |
+
for each_edge in hg.edges:
|
311 |
+
if each_edge in isomap:
|
312 |
+
subhg.add_edge(hg.nodes_in_edge(each_edge),
|
313 |
+
attr_dict=hg.edge_attr(each_edge),
|
314 |
+
edge_name=each_edge)
|
315 |
+
subhg.edge_idx = hg.edge_idx
|
316 |
+
|
317 |
+
# remove subhg except for the externael nodes
|
318 |
+
for each_key, each_val in isomap.items():
|
319 |
+
if each_key.startswith('e'):
|
320 |
+
hg.remove_edge(each_key)
|
321 |
+
for each_key, each_val in isomap.items():
|
322 |
+
if each_key.startswith('bond_'):
|
323 |
+
if each_val not in self.ext_node.values():
|
324 |
+
hg.remove_node(each_key)
|
325 |
+
|
326 |
+
# add non-terminal hyperedge
|
327 |
+
nt_node_list = []
|
328 |
+
for each_ext_id in self.ext_node.keys():
|
329 |
+
nt_node_list.append(inv_isomap[self.ext_node[each_ext_id]])
|
330 |
+
|
331 |
+
hg.add_edge(nt_node_list,
|
332 |
+
attr_dict=dict(
|
333 |
+
terminal=False,
|
334 |
+
symbol=self.lhs_nt_symbol))
|
335 |
+
if return_subhg:
|
336 |
+
return hg, True, subhg
|
337 |
+
else:
|
338 |
+
return hg, True
|
339 |
+
|
340 |
+
|
341 |
+
class ProductionRuleCorpus(object):
|
342 |
+
|
343 |
+
'''
|
344 |
+
A corpus of production rules.
|
345 |
+
This class maintains
|
346 |
+
(i) list of unique production rules,
|
347 |
+
(ii) list of unique edge symbols (both terminal and non-terminal), and
|
348 |
+
(iii) list of unique node symbols.
|
349 |
+
|
350 |
+
Attributes
|
351 |
+
----------
|
352 |
+
prod_rule_list : list
|
353 |
+
list of unique production rules
|
354 |
+
edge_symbol_list : list
|
355 |
+
list of unique symbols (including both terminal and non-terminal)
|
356 |
+
node_symbol_list : list
|
357 |
+
list of node symbols
|
358 |
+
nt_symbol_list : list
|
359 |
+
list of unique lhs symbols
|
360 |
+
ext_id_list : list
|
361 |
+
list of ext_ids
|
362 |
+
lhs_in_prod_rule : array
|
363 |
+
a matrix of lhs vs prod_rule (= lhs_in_prod_rule)
|
364 |
+
'''
|
365 |
+
|
366 |
+
def __init__(self):
|
367 |
+
self.prod_rule_list = []
|
368 |
+
self.edge_symbol_list = []
|
369 |
+
self.edge_symbol_dict = {}
|
370 |
+
self.node_symbol_list = []
|
371 |
+
self.node_symbol_dict = {}
|
372 |
+
self.nt_symbol_list = []
|
373 |
+
self.ext_id_list = []
|
374 |
+
self._lhs_in_prod_rule = None
|
375 |
+
self.lhs_in_prod_rule_row_list = []
|
376 |
+
self.lhs_in_prod_rule_col_list = []
|
377 |
+
|
378 |
+
@property
|
379 |
+
def lhs_in_prod_rule(self):
|
380 |
+
if self._lhs_in_prod_rule is None:
|
381 |
+
self._lhs_in_prod_rule = torch.sparse.FloatTensor(
|
382 |
+
torch.LongTensor(list(zip(self.lhs_in_prod_rule_row_list, self.lhs_in_prod_rule_col_list))).t(),
|
383 |
+
torch.FloatTensor([1.0]*len(self.lhs_in_prod_rule_col_list)),
|
384 |
+
torch.Size([len(self.nt_symbol_list), len(self.prod_rule_list)])
|
385 |
+
).to_dense()
|
386 |
+
return self._lhs_in_prod_rule
|
387 |
+
|
388 |
+
@property
|
389 |
+
def num_prod_rule(self):
|
390 |
+
''' return the number of production rules
|
391 |
+
|
392 |
+
Returns
|
393 |
+
-------
|
394 |
+
int : the number of unique production rules
|
395 |
+
'''
|
396 |
+
return len(self.prod_rule_list)
|
397 |
+
|
398 |
+
@property
|
399 |
+
def start_rule_list(self):
|
400 |
+
''' return a list of start rules
|
401 |
+
|
402 |
+
Returns
|
403 |
+
-------
|
404 |
+
list : list of start rules
|
405 |
+
'''
|
406 |
+
start_rule_list = []
|
407 |
+
for each_prod_rule in self.prod_rule_list:
|
408 |
+
if each_prod_rule.is_start_rule:
|
409 |
+
start_rule_list.append(each_prod_rule)
|
410 |
+
return start_rule_list
|
411 |
+
|
412 |
+
@property
|
413 |
+
def num_edge_symbol(self):
|
414 |
+
return len(self.edge_symbol_list)
|
415 |
+
|
416 |
+
@property
|
417 |
+
def num_node_symbol(self):
|
418 |
+
return len(self.node_symbol_list)
|
419 |
+
|
420 |
+
@property
|
421 |
+
def num_ext_id(self):
|
422 |
+
return len(self.ext_id_list)
|
423 |
+
|
424 |
+
def construct_feature_vectors(self):
|
425 |
+
''' this method constructs feature vectors for the production rules collected so far.
|
426 |
+
currently, NTSymbol and TSymbol are treated in the same manner.
|
427 |
+
'''
|
428 |
+
feature_id_dict = {}
|
429 |
+
feature_id_dict['TSymbol'] = 0
|
430 |
+
feature_id_dict['NTSymbol'] = 1
|
431 |
+
feature_id_dict['BondSymbol'] = 2
|
432 |
+
for each_edge_symbol in self.edge_symbol_list:
|
433 |
+
for each_attr in each_edge_symbol.__dict__.keys():
|
434 |
+
each_val = each_edge_symbol.__dict__[each_attr]
|
435 |
+
if isinstance(each_val, list):
|
436 |
+
each_val = tuple(each_val)
|
437 |
+
if (each_attr, each_val) not in feature_id_dict:
|
438 |
+
feature_id_dict[(each_attr, each_val)] = len(feature_id_dict)
|
439 |
+
|
440 |
+
for each_node_symbol in self.node_symbol_list:
|
441 |
+
for each_attr in each_node_symbol.__dict__.keys():
|
442 |
+
each_val = each_node_symbol.__dict__[each_attr]
|
443 |
+
if isinstance(each_val, list):
|
444 |
+
each_val = tuple(each_val)
|
445 |
+
if (each_attr, each_val) not in feature_id_dict:
|
446 |
+
feature_id_dict[(each_attr, each_val)] = len(feature_id_dict)
|
447 |
+
for each_ext_id in self.ext_id_list:
|
448 |
+
feature_id_dict[('ext_id', each_ext_id)] = len(feature_id_dict)
|
449 |
+
dim = len(feature_id_dict)
|
450 |
+
|
451 |
+
feature_dict = {}
|
452 |
+
for each_edge_symbol in self.edge_symbol_list:
|
453 |
+
idx_list = []
|
454 |
+
idx_list.append(feature_id_dict[each_edge_symbol.__class__.__name__])
|
455 |
+
for each_attr in each_edge_symbol.__dict__.keys():
|
456 |
+
each_val = each_edge_symbol.__dict__[each_attr]
|
457 |
+
if isinstance(each_val, list):
|
458 |
+
each_val = tuple(each_val)
|
459 |
+
idx_list.append(feature_id_dict[(each_attr, each_val)])
|
460 |
+
feature = torch.sparse.LongTensor(
|
461 |
+
torch.LongTensor([idx_list]),
|
462 |
+
torch.ones(len(idx_list)),
|
463 |
+
torch.Size([len(feature_id_dict)])
|
464 |
+
)
|
465 |
+
feature_dict[each_edge_symbol] = feature
|
466 |
+
|
467 |
+
for each_node_symbol in self.node_symbol_list:
|
468 |
+
idx_list = []
|
469 |
+
idx_list.append(feature_id_dict[each_node_symbol.__class__.__name__])
|
470 |
+
for each_attr in each_node_symbol.__dict__.keys():
|
471 |
+
each_val = each_node_symbol.__dict__[each_attr]
|
472 |
+
if isinstance(each_val, list):
|
473 |
+
each_val = tuple(each_val)
|
474 |
+
idx_list.append(feature_id_dict[(each_attr, each_val)])
|
475 |
+
feature = torch.sparse.LongTensor(
|
476 |
+
torch.LongTensor([idx_list]),
|
477 |
+
torch.ones(len(idx_list)),
|
478 |
+
torch.Size([len(feature_id_dict)])
|
479 |
+
)
|
480 |
+
feature_dict[each_node_symbol] = feature
|
481 |
+
for each_ext_id in self.ext_id_list:
|
482 |
+
idx_list = [feature_id_dict[('ext_id', each_ext_id)]]
|
483 |
+
feature_dict[('ext_id', each_ext_id)] \
|
484 |
+
= torch.sparse.LongTensor(
|
485 |
+
torch.LongTensor([idx_list]),
|
486 |
+
torch.ones(len(idx_list)),
|
487 |
+
torch.Size([len(feature_id_dict)])
|
488 |
+
)
|
489 |
+
return feature_dict, dim
|
490 |
+
|
491 |
+
def edge_symbol_idx(self, symbol):
|
492 |
+
return self.edge_symbol_dict[symbol]
|
493 |
+
|
494 |
+
def node_symbol_idx(self, symbol):
|
495 |
+
return self.node_symbol_dict[symbol]
|
496 |
+
|
497 |
+
def append(self, prod_rule: ProductionRule) -> Tuple[int, ProductionRule]:
|
498 |
+
""" return whether the input production rule is new or not, and its production rule id.
|
499 |
+
Production rules are regarded as the same if
|
500 |
+
i) there exists a one-to-one mapping of nodes and edges, and
|
501 |
+
ii) all the attributes associated with nodes and hyperedges are the same.
|
502 |
+
|
503 |
+
Parameters
|
504 |
+
----------
|
505 |
+
prod_rule : ProductionRule
|
506 |
+
|
507 |
+
Returns
|
508 |
+
-------
|
509 |
+
prod_rule_id : int
|
510 |
+
production rule index. if new, a new index will be assigned.
|
511 |
+
prod_rule : ProductionRule
|
512 |
+
"""
|
513 |
+
num_lhs = len(self.nt_symbol_list)
|
514 |
+
for each_idx, each_prod_rule in enumerate(self.prod_rule_list):
|
515 |
+
is_same, isomap = prod_rule.is_same(each_prod_rule)
|
516 |
+
if is_same:
|
517 |
+
# we do not care about edge and node names, but care about the order of non-terminal edges.
|
518 |
+
for key, val in isomap.items(): # key : edges & nodes in each_prod_rule.rhs , val : those in prod_rule.rhs
|
519 |
+
if key.startswith("bond_"):
|
520 |
+
continue
|
521 |
+
|
522 |
+
# rewrite `nt_idx` in `prod_rule` for further processing
|
523 |
+
if "nt_idx" in prod_rule.rhs.edge_attr(val).keys():
|
524 |
+
if "nt_idx" not in each_prod_rule.rhs.edge_attr(key).keys():
|
525 |
+
raise ValueError
|
526 |
+
prod_rule.rhs.set_edge_attr(
|
527 |
+
val,
|
528 |
+
{'nt_idx': each_prod_rule.rhs.edge_attr(key)["nt_idx"]})
|
529 |
+
return each_idx, prod_rule
|
530 |
+
self.prod_rule_list.append(prod_rule)
|
531 |
+
self._update_edge_symbol_list(prod_rule)
|
532 |
+
self._update_node_symbol_list(prod_rule)
|
533 |
+
self._update_ext_id_list(prod_rule)
|
534 |
+
|
535 |
+
lhs_idx = self.nt_symbol_list.index(prod_rule.lhs_nt_symbol)
|
536 |
+
self.lhs_in_prod_rule_row_list.append(lhs_idx)
|
537 |
+
self.lhs_in_prod_rule_col_list.append(len(self.prod_rule_list)-1)
|
538 |
+
self._lhs_in_prod_rule = None
|
539 |
+
return len(self.prod_rule_list)-1, prod_rule
|
540 |
+
|
541 |
+
def get_prod_rule(self, prod_rule_idx: int) -> ProductionRule:
|
542 |
+
return self.prod_rule_list[prod_rule_idx]
|
543 |
+
|
544 |
+
def sample(self, unmasked_logit_array, nt_symbol, deterministic=False):
|
545 |
+
''' sample a production rule whose lhs is `nt_symbol`, followihng `unmasked_logit_array`.
|
546 |
+
|
547 |
+
Parameters
|
548 |
+
----------
|
549 |
+
unmasked_logit_array : array-like, length `num_prod_rule`
|
550 |
+
nt_symbol : NTSymbol
|
551 |
+
'''
|
552 |
+
if not isinstance(unmasked_logit_array, np.ndarray):
|
553 |
+
unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64)
|
554 |
+
if deterministic:
|
555 |
+
prob = masked_softmax(unmasked_logit_array,
|
556 |
+
self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64))
|
557 |
+
return self.prod_rule_list[np.argmax(prob)]
|
558 |
+
else:
|
559 |
+
return np.random.choice(
|
560 |
+
self.prod_rule_list, 1,
|
561 |
+
p=masked_softmax(unmasked_logit_array,
|
562 |
+
self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64)))[0]
|
563 |
+
|
564 |
+
def masked_logprob(self, unmasked_logit_array, nt_symbol):
|
565 |
+
if not isinstance(unmasked_logit_array, np.ndarray):
|
566 |
+
unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64)
|
567 |
+
prob = masked_softmax(unmasked_logit_array,
|
568 |
+
self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64))
|
569 |
+
return np.log(prob)
|
570 |
+
|
571 |
+
def _update_edge_symbol_list(self, prod_rule: ProductionRule):
|
572 |
+
''' update edge symbol list
|
573 |
+
|
574 |
+
Parameters
|
575 |
+
----------
|
576 |
+
prod_rule : ProductionRule
|
577 |
+
'''
|
578 |
+
if prod_rule.lhs_nt_symbol not in self.nt_symbol_list:
|
579 |
+
self.nt_symbol_list.append(prod_rule.lhs_nt_symbol)
|
580 |
+
|
581 |
+
for each_edge in prod_rule.rhs.edges:
|
582 |
+
if prod_rule.rhs.edge_attr(each_edge)['symbol'] not in self.edge_symbol_dict:
|
583 |
+
edge_symbol_idx = len(self.edge_symbol_list)
|
584 |
+
self.edge_symbol_list.append(prod_rule.rhs.edge_attr(each_edge)['symbol'])
|
585 |
+
self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']] = edge_symbol_idx
|
586 |
+
else:
|
587 |
+
edge_symbol_idx = self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']]
|
588 |
+
prod_rule.rhs.edge_attr(each_edge)['symbol_idx'] = edge_symbol_idx
|
589 |
+
pass
|
590 |
+
|
591 |
+
def _update_node_symbol_list(self, prod_rule: ProductionRule):
|
592 |
+
''' update node symbol list
|
593 |
+
|
594 |
+
Parameters
|
595 |
+
----------
|
596 |
+
prod_rule : ProductionRule
|
597 |
+
'''
|
598 |
+
for each_node in prod_rule.rhs.nodes:
|
599 |
+
if prod_rule.rhs.node_attr(each_node)['symbol'] not in self.node_symbol_dict:
|
600 |
+
node_symbol_idx = len(self.node_symbol_list)
|
601 |
+
self.node_symbol_list.append(prod_rule.rhs.node_attr(each_node)['symbol'])
|
602 |
+
self.node_symbol_dict[prod_rule.rhs.node_attr(each_node)['symbol']] = node_symbol_idx
|
603 |
+
else:
|
604 |
+
node_symbol_idx = self.node_symbol_dict[prod_rule.rhs.node_attr(each_node)['symbol']]
|
605 |
+
prod_rule.rhs.node_attr(each_node)['symbol_idx'] = node_symbol_idx
|
606 |
+
|
607 |
+
def _update_ext_id_list(self, prod_rule: ProductionRule):
|
608 |
+
for each_node in prod_rule.rhs.nodes:
|
609 |
+
if 'ext_id' in prod_rule.rhs.node_attr(each_node):
|
610 |
+
if prod_rule.rhs.node_attr(each_node)['ext_id'] not in self.ext_id_list:
|
611 |
+
self.ext_id_list.append(prod_rule.rhs.node_attr(each_node)['ext_id'])
|
612 |
+
|
613 |
+
|
614 |
+
class HyperedgeReplacementGrammar(GraphGrammarBase):
|
615 |
+
"""
|
616 |
+
Learn a hyperedge replacement grammar from a set of hypergraphs.
|
617 |
+
|
618 |
+
Attributes
|
619 |
+
----------
|
620 |
+
prod_rule_list : list of ProductionRule
|
621 |
+
production rules learned from the input hypergraphs
|
622 |
+
"""
|
623 |
+
def __init__(self,
|
624 |
+
tree_decomposition=molecular_tree_decomposition,
|
625 |
+
ignore_order=False, **kwargs):
|
626 |
+
from functools import partial
|
627 |
+
self.prod_rule_corpus = ProductionRuleCorpus()
|
628 |
+
self.clique_tree_corpus = CliqueTreeCorpus()
|
629 |
+
self.ignore_order = ignore_order
|
630 |
+
self.tree_decomposition = partial(tree_decomposition, **kwargs)
|
631 |
+
|
632 |
+
@property
|
633 |
+
def num_prod_rule(self):
|
634 |
+
''' return the number of production rules
|
635 |
+
|
636 |
+
Returns
|
637 |
+
-------
|
638 |
+
int : the number of unique production rules
|
639 |
+
'''
|
640 |
+
return self.prod_rule_corpus.num_prod_rule
|
641 |
+
|
642 |
+
@property
|
643 |
+
def start_rule_list(self):
|
644 |
+
''' return a list of start rules
|
645 |
+
|
646 |
+
Returns
|
647 |
+
-------
|
648 |
+
list : list of start rules
|
649 |
+
'''
|
650 |
+
return self.prod_rule_corpus.start_rule_list
|
651 |
+
|
652 |
+
@property
|
653 |
+
def prod_rule_list(self):
|
654 |
+
return self.prod_rule_corpus.prod_rule_list
|
655 |
+
|
656 |
+
def learn(self, hg_list, logger=print, max_mol=np.inf, print_freq=500):
|
657 |
+
""" learn from a list of hypergraphs
|
658 |
+
|
659 |
+
Parameters
|
660 |
+
----------
|
661 |
+
hg_list : list of Hypergraph
|
662 |
+
|
663 |
+
Returns
|
664 |
+
-------
|
665 |
+
prod_rule_seq_list : list of integers
|
666 |
+
each element corresponds to a sequence of production rules to generate each hypergraph.
|
667 |
+
"""
|
668 |
+
prod_rule_seq_list = []
|
669 |
+
idx = 0
|
670 |
+
for each_idx, each_hg in enumerate(hg_list):
|
671 |
+
clique_tree = self.tree_decomposition(each_hg)
|
672 |
+
|
673 |
+
# get a pair of myself and children
|
674 |
+
root_node = _find_root(clique_tree)
|
675 |
+
clique_tree = self.clique_tree_corpus.add_to_subhg_list(clique_tree, root_node)
|
676 |
+
prod_rule_seq = []
|
677 |
+
stack = []
|
678 |
+
|
679 |
+
children = sorted(list(clique_tree[root_node].keys()))
|
680 |
+
|
681 |
+
# extract a temporary production rule
|
682 |
+
prod_rule = extract_prod_rule(
|
683 |
+
None,
|
684 |
+
clique_tree.nodes[root_node]["subhg"],
|
685 |
+
[clique_tree.nodes[each_child]["subhg"]
|
686 |
+
for each_child in children],
|
687 |
+
clique_tree.nodes[root_node].get('subhg_idx', None))
|
688 |
+
|
689 |
+
# update the production rule list
|
690 |
+
prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
|
691 |
+
children = reorder_children(root_node,
|
692 |
+
children,
|
693 |
+
prod_rule,
|
694 |
+
clique_tree)
|
695 |
+
stack.extend([(root_node, each_child) for each_child in children[::-1]])
|
696 |
+
prod_rule_seq.append(prod_rule_id)
|
697 |
+
|
698 |
+
while len(stack) != 0:
|
699 |
+
# get a triple of parent, myself, and children
|
700 |
+
parent, myself = stack.pop()
|
701 |
+
children = sorted(list(dict(clique_tree[myself]).keys()))
|
702 |
+
children.remove(parent)
|
703 |
+
|
704 |
+
# extract a temp prod rule
|
705 |
+
prod_rule = extract_prod_rule(
|
706 |
+
clique_tree.nodes[parent]["subhg"],
|
707 |
+
clique_tree.nodes[myself]["subhg"],
|
708 |
+
[clique_tree.nodes[each_child]["subhg"]
|
709 |
+
for each_child in children],
|
710 |
+
clique_tree.nodes[myself].get('subhg_idx', None))
|
711 |
+
|
712 |
+
# update the prod rule list
|
713 |
+
prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
|
714 |
+
children = reorder_children(myself,
|
715 |
+
children,
|
716 |
+
prod_rule,
|
717 |
+
clique_tree)
|
718 |
+
stack.extend([(myself, each_child)
|
719 |
+
for each_child in children[::-1]])
|
720 |
+
prod_rule_seq.append(prod_rule_id)
|
721 |
+
prod_rule_seq_list.append(prod_rule_seq)
|
722 |
+
if (each_idx+1) % print_freq == 0:
|
723 |
+
msg = f'#(molecules processed)={each_idx+1}\t'\
|
724 |
+
f'#(production rules)={self.prod_rule_corpus.num_prod_rule}\t#(subhg in corpus)={self.clique_tree_corpus.size}'
|
725 |
+
logger(msg)
|
726 |
+
if each_idx > max_mol:
|
727 |
+
break
|
728 |
+
|
729 |
+
print(f'corpus_size = {self.clique_tree_corpus.size}')
|
730 |
+
return prod_rule_seq_list
|
731 |
+
|
732 |
+
def sample(self, z, deterministic=False):
|
733 |
+
""" sample a new hypergraph from HRG.
|
734 |
+
|
735 |
+
Parameters
|
736 |
+
----------
|
737 |
+
z : array-like, shape (len, num_prod_rule)
|
738 |
+
logit
|
739 |
+
deterministic : bool
|
740 |
+
if True, deterministic sampling
|
741 |
+
|
742 |
+
Returns
|
743 |
+
-------
|
744 |
+
Hypergraph
|
745 |
+
"""
|
746 |
+
seq_idx = 0
|
747 |
+
stack = []
|
748 |
+
z = z[:, :-1]
|
749 |
+
init_prod_rule = self.prod_rule_corpus.sample(z[0], NTSymbol(degree=0,
|
750 |
+
is_aromatic=False,
|
751 |
+
bond_symbol_list=[]),
|
752 |
+
deterministic=deterministic)
|
753 |
+
hg, nt_edge_list = init_prod_rule.applied_to(None, None)
|
754 |
+
stack = deepcopy(nt_edge_list[::-1])
|
755 |
+
while len(stack) != 0 and seq_idx < z.shape[0]-1:
|
756 |
+
seq_idx += 1
|
757 |
+
nt_edge = stack.pop()
|
758 |
+
nt_symbol = hg.edge_attr(nt_edge)['symbol']
|
759 |
+
prod_rule = self.prod_rule_corpus.sample(z[seq_idx], nt_symbol, deterministic=deterministic)
|
760 |
+
hg, nt_edge_list = prod_rule.applied_to(hg, nt_edge)
|
761 |
+
stack.extend(nt_edge_list[::-1])
|
762 |
+
if len(stack) != 0:
|
763 |
+
raise RuntimeError(f'{len(stack)} non-terminals are left.')
|
764 |
+
return hg
|
765 |
+
|
766 |
+
def construct(self, prod_rule_seq):
|
767 |
+
""" construct a hypergraph following `prod_rule_seq`
|
768 |
+
|
769 |
+
Parameters
|
770 |
+
----------
|
771 |
+
prod_rule_seq : list of integers
|
772 |
+
a sequence of production rules.
|
773 |
+
|
774 |
+
Returns
|
775 |
+
-------
|
776 |
+
UndirectedHypergraph
|
777 |
+
"""
|
778 |
+
seq_idx = 0
|
779 |
+
init_prod_rule = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx])
|
780 |
+
hg, nt_edge_list = init_prod_rule.applied_to(None, None)
|
781 |
+
stack = deepcopy(nt_edge_list[::-1])
|
782 |
+
while len(stack) != 0:
|
783 |
+
seq_idx += 1
|
784 |
+
nt_edge = stack.pop()
|
785 |
+
hg, nt_edge_list = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx]).applied_to(hg, nt_edge)
|
786 |
+
stack.extend(nt_edge_list[::-1])
|
787 |
+
return hg
|
788 |
+
|
789 |
+
def update_prod_rule_list(self, prod_rule):
|
790 |
+
""" return whether the input production rule is new or not, and its production rule id.
|
791 |
+
Production rules are regarded as the same if
|
792 |
+
i) there exists a one-to-one mapping of nodes and edges, and
|
793 |
+
ii) all the attributes associated with nodes and hyperedges are the same.
|
794 |
+
|
795 |
+
Parameters
|
796 |
+
----------
|
797 |
+
prod_rule : ProductionRule
|
798 |
+
|
799 |
+
Returns
|
800 |
+
-------
|
801 |
+
is_new : bool
|
802 |
+
if True, this production rule is new
|
803 |
+
prod_rule_id : int
|
804 |
+
production rule index. if new, a new index will be assigned.
|
805 |
+
"""
|
806 |
+
return self.prod_rule_corpus.append(prod_rule)
|
807 |
+
|
808 |
+
|
809 |
+
class IncrementalHyperedgeReplacementGrammar(HyperedgeReplacementGrammar):
|
810 |
+
'''
|
811 |
+
This class learns HRG incrementally leveraging the previously obtained production rules.
|
812 |
+
'''
|
813 |
+
def __init__(self, tree_decomposition=tree_decomposition_with_hrg, ignore_order=False):
|
814 |
+
self.prod_rule_list = []
|
815 |
+
self.tree_decomposition = tree_decomposition
|
816 |
+
self.ignore_order = ignore_order
|
817 |
+
|
818 |
+
def learn(self, hg_list):
|
819 |
+
""" learn from a list of hypergraphs
|
820 |
+
|
821 |
+
Parameters
|
822 |
+
----------
|
823 |
+
hg_list : list of UndirectedHypergraph
|
824 |
+
|
825 |
+
Returns
|
826 |
+
-------
|
827 |
+
prod_rule_seq_list : list of integers
|
828 |
+
each element corresponds to a sequence of production rules to generate each hypergraph.
|
829 |
+
"""
|
830 |
+
prod_rule_seq_list = []
|
831 |
+
for each_hg in hg_list:
|
832 |
+
clique_tree, root_node = tree_decomposition_with_hrg(each_hg, self, return_root=True)
|
833 |
+
|
834 |
+
prod_rule_seq = []
|
835 |
+
stack = []
|
836 |
+
|
837 |
+
# get a pair of myself and children
|
838 |
+
children = sorted(list(clique_tree[root_node].keys()))
|
839 |
+
|
840 |
+
# extract a temporary production rule
|
841 |
+
prod_rule = extract_prod_rule(None, clique_tree.nodes[root_node]["subhg"],
|
842 |
+
[clique_tree.nodes[each_child]["subhg"] for each_child in children])
|
843 |
+
|
844 |
+
# update the production rule list
|
845 |
+
prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
|
846 |
+
children = reorder_children(root_node, children, prod_rule, clique_tree)
|
847 |
+
stack.extend([(root_node, each_child) for each_child in children[::-1]])
|
848 |
+
prod_rule_seq.append(prod_rule_id)
|
849 |
+
|
850 |
+
while len(stack) != 0:
|
851 |
+
# get a triple of parent, myself, and children
|
852 |
+
parent, myself = stack.pop()
|
853 |
+
children = sorted(list(dict(clique_tree[myself]).keys()))
|
854 |
+
children.remove(parent)
|
855 |
+
|
856 |
+
# extract a temp prod rule
|
857 |
+
prod_rule = extract_prod_rule(
|
858 |
+
clique_tree.nodes[parent]["subhg"], clique_tree.nodes[myself]["subhg"],
|
859 |
+
[clique_tree.nodes[each_child]["subhg"] for each_child in children])
|
860 |
+
|
861 |
+
# update the prod rule list
|
862 |
+
prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
|
863 |
+
children = reorder_children(myself, children, prod_rule, clique_tree)
|
864 |
+
stack.extend([(myself, each_child) for each_child in children[::-1]])
|
865 |
+
prod_rule_seq.append(prod_rule_id)
|
866 |
+
prod_rule_seq_list.append(prod_rule_seq)
|
867 |
+
self._compute_stats()
|
868 |
+
return prod_rule_seq_list
|
869 |
+
|
870 |
+
|
871 |
+
def reorder_children(myself, children, prod_rule, clique_tree):
|
872 |
+
""" reorder children so that they match the order in `prod_rule`.
|
873 |
+
|
874 |
+
Parameters
|
875 |
+
----------
|
876 |
+
myself : int
|
877 |
+
children : list of int
|
878 |
+
prod_rule : ProductionRule
|
879 |
+
clique_tree : nx.Graph
|
880 |
+
|
881 |
+
Returns
|
882 |
+
-------
|
883 |
+
new_children : list of str
|
884 |
+
reordered children
|
885 |
+
"""
|
886 |
+
perm = {} # key : `nt_idx`, val : child
|
887 |
+
for each_edge in prod_rule.rhs.edges:
|
888 |
+
if "nt_idx" in prod_rule.rhs.edge_attr(each_edge).keys():
|
889 |
+
for each_child in children:
|
890 |
+
common_node_set = set(
|
891 |
+
common_node_list(clique_tree.nodes[myself]["subhg"],
|
892 |
+
clique_tree.nodes[each_child]["subhg"])[0])
|
893 |
+
if set(prod_rule.rhs.nodes_in_edge(each_edge)) == common_node_set:
|
894 |
+
assert prod_rule.rhs.edge_attr(each_edge)["nt_idx"] not in perm
|
895 |
+
perm[prod_rule.rhs.edge_attr(each_edge)["nt_idx"]] = each_child
|
896 |
+
new_children = []
|
897 |
+
assert len(perm) == len(children)
|
898 |
+
for i in range(len(perm)):
|
899 |
+
new_children.append(perm[i])
|
900 |
+
return new_children
|
901 |
+
|
902 |
+
|
903 |
+
def extract_prod_rule(parent_hg, myself_hg, children_hg_list, subhg_idx=None):
|
904 |
+
""" extract a production rule from a triple of `parent_hg`, `myself_hg`, and `children_hg_list`.
|
905 |
+
|
906 |
+
Parameters
|
907 |
+
----------
|
908 |
+
parent_hg : Hypergraph
|
909 |
+
myself_hg : Hypergraph
|
910 |
+
children_hg_list : list of Hypergraph
|
911 |
+
|
912 |
+
Returns
|
913 |
+
-------
|
914 |
+
ProductionRule, consisting of
|
915 |
+
lhs : Hypergraph or None
|
916 |
+
rhs : Hypergraph
|
917 |
+
"""
|
918 |
+
def _add_ext_node(hg, ext_nodes):
|
919 |
+
""" mark nodes to be external (ordered ids are assigned)
|
920 |
+
|
921 |
+
Parameters
|
922 |
+
----------
|
923 |
+
hg : UndirectedHypergraph
|
924 |
+
ext_nodes : list of str
|
925 |
+
list of external nodes
|
926 |
+
|
927 |
+
Returns
|
928 |
+
-------
|
929 |
+
hg : Hypergraph
|
930 |
+
nodes in `ext_nodes` are marked to be external
|
931 |
+
"""
|
932 |
+
ext_id = 0
|
933 |
+
ext_id_exists = []
|
934 |
+
for each_node in ext_nodes:
|
935 |
+
ext_id_exists.append('ext_id' in hg.node_attr(each_node))
|
936 |
+
if ext_id_exists and any(ext_id_exists) != all(ext_id_exists):
|
937 |
+
raise ValueError
|
938 |
+
if not all(ext_id_exists):
|
939 |
+
for each_node in ext_nodes:
|
940 |
+
hg.node_attr(each_node)['ext_id'] = ext_id
|
941 |
+
ext_id += 1
|
942 |
+
return hg
|
943 |
+
|
944 |
+
def _check_aromatic(hg, node_list):
|
945 |
+
is_aromatic = False
|
946 |
+
node_aromatic_list = []
|
947 |
+
for each_node in node_list:
|
948 |
+
if hg.node_attr(each_node)['symbol'].is_aromatic:
|
949 |
+
is_aromatic = True
|
950 |
+
node_aromatic_list.append(True)
|
951 |
+
else:
|
952 |
+
node_aromatic_list.append(False)
|
953 |
+
return is_aromatic, node_aromatic_list
|
954 |
+
|
955 |
+
def _check_ring(hg):
|
956 |
+
for each_edge in hg.edges:
|
957 |
+
if not ('tmp' in hg.edge_attr(each_edge) or (not hg.edge_attr(each_edge)['terminal'])):
|
958 |
+
return False
|
959 |
+
return True
|
960 |
+
|
961 |
+
if parent_hg is None:
|
962 |
+
lhs = Hypergraph()
|
963 |
+
node_list = []
|
964 |
+
else:
|
965 |
+
lhs = Hypergraph()
|
966 |
+
node_list, edge_exists = common_node_list(parent_hg, myself_hg)
|
967 |
+
for each_node in node_list:
|
968 |
+
lhs.add_node(each_node,
|
969 |
+
deepcopy(myself_hg.node_attr(each_node)))
|
970 |
+
is_aromatic, _ = _check_aromatic(parent_hg, node_list)
|
971 |
+
for_ring = _check_ring(myself_hg)
|
972 |
+
bond_symbol_list = []
|
973 |
+
for each_node in node_list:
|
974 |
+
bond_symbol_list.append(parent_hg.node_attr(each_node)['symbol'])
|
975 |
+
lhs.add_edge(
|
976 |
+
node_list,
|
977 |
+
attr_dict=dict(
|
978 |
+
terminal=False,
|
979 |
+
edge_exists=edge_exists,
|
980 |
+
symbol=NTSymbol(
|
981 |
+
degree=len(node_list),
|
982 |
+
is_aromatic=is_aromatic,
|
983 |
+
bond_symbol_list=bond_symbol_list,
|
984 |
+
for_ring=for_ring)))
|
985 |
+
try:
|
986 |
+
lhs = _add_ext_node(lhs, node_list)
|
987 |
+
except ValueError:
|
988 |
+
import pdb; pdb.set_trace()
|
989 |
+
|
990 |
+
rhs = remove_tmp_edge(deepcopy(myself_hg))
|
991 |
+
#rhs = remove_ext_node(rhs)
|
992 |
+
#rhs = remove_nt_edge(rhs)
|
993 |
+
try:
|
994 |
+
rhs = _add_ext_node(rhs, node_list)
|
995 |
+
except ValueError:
|
996 |
+
import pdb; pdb.set_trace()
|
997 |
+
|
998 |
+
nt_idx = 0
|
999 |
+
if children_hg_list is not None:
|
1000 |
+
for each_child_hg in children_hg_list:
|
1001 |
+
node_list, edge_exists = common_node_list(myself_hg, each_child_hg)
|
1002 |
+
is_aromatic, _ = _check_aromatic(myself_hg, node_list)
|
1003 |
+
for_ring = _check_ring(each_child_hg)
|
1004 |
+
bond_symbol_list = []
|
1005 |
+
for each_node in node_list:
|
1006 |
+
bond_symbol_list.append(myself_hg.node_attr(each_node)['symbol'])
|
1007 |
+
rhs.add_edge(
|
1008 |
+
node_list,
|
1009 |
+
attr_dict=dict(
|
1010 |
+
terminal=False,
|
1011 |
+
nt_idx=nt_idx,
|
1012 |
+
edge_exists=edge_exists,
|
1013 |
+
symbol=NTSymbol(degree=len(node_list),
|
1014 |
+
is_aromatic=is_aromatic,
|
1015 |
+
bond_symbol_list=bond_symbol_list,
|
1016 |
+
for_ring=for_ring)))
|
1017 |
+
nt_idx += 1
|
1018 |
+
prod_rule = ProductionRule(lhs, rhs)
|
1019 |
+
prod_rule.subhg_idx = subhg_idx
|
1020 |
+
if DEBUG:
|
1021 |
+
if sorted(list(prod_rule.ext_node.keys())) \
|
1022 |
+
!= list(np.arange(len(prod_rule.ext_node))):
|
1023 |
+
raise RuntimeError('ext_id is not continuous')
|
1024 |
+
return prod_rule
|
1025 |
+
|
1026 |
+
|
1027 |
+
def _find_root(clique_tree):
|
1028 |
+
max_node = None
|
1029 |
+
num_nodes_max = -np.inf
|
1030 |
+
for each_node in clique_tree.nodes:
|
1031 |
+
if clique_tree.nodes[each_node]['subhg'].num_nodes > num_nodes_max:
|
1032 |
+
max_node = each_node
|
1033 |
+
num_nodes_max = clique_tree.nodes[each_node]['subhg'].num_nodes
|
1034 |
+
'''
|
1035 |
+
children = sorted(list(clique_tree[each_node].keys()))
|
1036 |
+
prod_rule = extract_prod_rule(None,
|
1037 |
+
clique_tree.nodes[each_node]["subhg"],
|
1038 |
+
[clique_tree.nodes[each_child]["subhg"]
|
1039 |
+
for each_child in children])
|
1040 |
+
for each_start_rule in start_rule_list:
|
1041 |
+
if prod_rule.is_same(each_start_rule):
|
1042 |
+
return each_node
|
1043 |
+
'''
|
1044 |
+
return max_node
|
1045 |
+
|
1046 |
+
def remove_ext_node(hg):
|
1047 |
+
for each_node in hg.nodes:
|
1048 |
+
hg.node_attr(each_node).pop('ext_id', None)
|
1049 |
+
return hg
|
1050 |
+
|
1051 |
+
def remove_nt_edge(hg):
|
1052 |
+
remove_edge_list = []
|
1053 |
+
for each_edge in hg.edges:
|
1054 |
+
if not hg.edge_attr(each_edge)['terminal']:
|
1055 |
+
remove_edge_list.append(each_edge)
|
1056 |
+
hg.remove_edges(remove_edge_list)
|
1057 |
+
return hg
|
1058 |
+
|
1059 |
+
def remove_tmp_edge(hg):
|
1060 |
+
remove_edge_list = []
|
1061 |
+
for each_edge in hg.edges:
|
1062 |
+
if hg.edge_attr(each_edge).get('tmp', False):
|
1063 |
+
remove_edge_list.append(each_edge)
|
1064 |
+
hg.remove_edges(remove_edge_list)
|
1065 |
+
return hg
|
models/mhg_model/graph_grammar/graph_grammar/symbols.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
|
15 |
+
""" Title """
|
16 |
+
|
17 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
18 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
19 |
+
__version__ = "0.1"
|
20 |
+
__date__ = "Jan 1 2018"
|
21 |
+
|
22 |
+
from typing import List
|
23 |
+
|
24 |
+
class TSymbol(object):
|
25 |
+
|
26 |
+
''' terminal symbol
|
27 |
+
|
28 |
+
Attributes
|
29 |
+
----------
|
30 |
+
degree : int
|
31 |
+
the number of nodes in a hyperedge
|
32 |
+
is_aromatic : bool
|
33 |
+
whether or not the hyperedge is in an aromatic ring
|
34 |
+
symbol : str
|
35 |
+
atomic symbol
|
36 |
+
num_explicit_Hs : int
|
37 |
+
the number of hydrogens associated to this hyperedge
|
38 |
+
formal_charge : int
|
39 |
+
charge
|
40 |
+
chirality : int
|
41 |
+
chirality
|
42 |
+
'''
|
43 |
+
|
44 |
+
def __init__(self, degree, is_aromatic,
|
45 |
+
symbol, num_explicit_Hs, formal_charge, chirality):
|
46 |
+
self.degree = degree
|
47 |
+
self.is_aromatic = is_aromatic
|
48 |
+
self.symbol = symbol
|
49 |
+
self.num_explicit_Hs = num_explicit_Hs
|
50 |
+
self.formal_charge = formal_charge
|
51 |
+
self.chirality = chirality
|
52 |
+
|
53 |
+
@property
|
54 |
+
def terminal(self):
|
55 |
+
return True
|
56 |
+
|
57 |
+
def __eq__(self, other):
|
58 |
+
if not isinstance(other, TSymbol):
|
59 |
+
return False
|
60 |
+
if self.degree != other.degree:
|
61 |
+
return False
|
62 |
+
if self.is_aromatic != other.is_aromatic:
|
63 |
+
return False
|
64 |
+
if self.symbol != other.symbol:
|
65 |
+
return False
|
66 |
+
if self.num_explicit_Hs != other.num_explicit_Hs:
|
67 |
+
return False
|
68 |
+
if self.formal_charge != other.formal_charge:
|
69 |
+
return False
|
70 |
+
if self.chirality != other.chirality:
|
71 |
+
return False
|
72 |
+
return True
|
73 |
+
|
74 |
+
def __hash__(self):
|
75 |
+
return self.__str__().__hash__()
|
76 |
+
|
77 |
+
def __str__(self):
|
78 |
+
return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
|
79 |
+
f'symbol={self.symbol}, '\
|
80 |
+
f'num_explicit_Hs={self.num_explicit_Hs}, '\
|
81 |
+
f'formal_charge={self.formal_charge}, chirality={self.chirality}'
|
82 |
+
|
83 |
+
|
84 |
+
class NTSymbol(object):
|
85 |
+
|
86 |
+
''' non-terminal symbol
|
87 |
+
|
88 |
+
Attributes
|
89 |
+
----------
|
90 |
+
degree : int
|
91 |
+
degree of the hyperedge
|
92 |
+
is_aromatic : bool
|
93 |
+
if True, at least one of the associated bonds must be aromatic.
|
94 |
+
node_aromatic_list : list of bool
|
95 |
+
indicate whether each of the nodes is aromatic or not.
|
96 |
+
bond_type_list : list of int
|
97 |
+
bond type of each node"
|
98 |
+
'''
|
99 |
+
|
100 |
+
def __init__(self, degree: int, is_aromatic: bool,
|
101 |
+
bond_symbol_list: list,
|
102 |
+
for_ring=False):
|
103 |
+
self.degree = degree
|
104 |
+
self.is_aromatic = is_aromatic
|
105 |
+
self.for_ring = for_ring
|
106 |
+
self.bond_symbol_list = bond_symbol_list
|
107 |
+
|
108 |
+
@property
|
109 |
+
def terminal(self) -> bool:
|
110 |
+
return False
|
111 |
+
|
112 |
+
@property
|
113 |
+
def symbol(self):
|
114 |
+
return f'NT{self.degree}'
|
115 |
+
|
116 |
+
def __eq__(self, other) -> bool:
|
117 |
+
if not isinstance(other, NTSymbol):
|
118 |
+
return False
|
119 |
+
|
120 |
+
if self.degree != other.degree:
|
121 |
+
return False
|
122 |
+
if self.is_aromatic != other.is_aromatic:
|
123 |
+
return False
|
124 |
+
if self.for_ring != other.for_ring:
|
125 |
+
return False
|
126 |
+
if len(self.bond_symbol_list) != len(other.bond_symbol_list):
|
127 |
+
return False
|
128 |
+
for each_idx in range(len(self.bond_symbol_list)):
|
129 |
+
if self.bond_symbol_list[each_idx] != other.bond_symbol_list[each_idx]:
|
130 |
+
return False
|
131 |
+
return True
|
132 |
+
|
133 |
+
def __hash__(self):
|
134 |
+
return self.__str__().__hash__()
|
135 |
+
|
136 |
+
def __str__(self) -> str:
|
137 |
+
return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
|
138 |
+
f'bond_symbol_list={[str(each_symbol) for each_symbol in self.bond_symbol_list]}'\
|
139 |
+
f'for_ring={self.for_ring}'
|
140 |
+
|
141 |
+
|
142 |
+
class BondSymbol(object):
|
143 |
+
|
144 |
+
|
145 |
+
''' Bond symbol
|
146 |
+
|
147 |
+
Attributes
|
148 |
+
----------
|
149 |
+
is_aromatic : bool
|
150 |
+
if True, at least one of the associated bonds must be aromatic.
|
151 |
+
bond_type : int
|
152 |
+
bond type of each node"
|
153 |
+
'''
|
154 |
+
|
155 |
+
def __init__(self, is_aromatic: bool,
|
156 |
+
bond_type: int,
|
157 |
+
stereo: int):
|
158 |
+
self.is_aromatic = is_aromatic
|
159 |
+
self.bond_type = bond_type
|
160 |
+
self.stereo = stereo
|
161 |
+
|
162 |
+
def __eq__(self, other) -> bool:
|
163 |
+
if not isinstance(other, BondSymbol):
|
164 |
+
return False
|
165 |
+
|
166 |
+
if self.is_aromatic != other.is_aromatic:
|
167 |
+
return False
|
168 |
+
if self.bond_type != other.bond_type:
|
169 |
+
return False
|
170 |
+
if self.stereo != other.stereo:
|
171 |
+
return False
|
172 |
+
return True
|
173 |
+
|
174 |
+
def __hash__(self):
|
175 |
+
return self.__str__().__hash__()
|
176 |
+
|
177 |
+
def __str__(self) -> str:
|
178 |
+
return f'is_aromatic={self.is_aromatic}, '\
|
179 |
+
f'bond_type={self.bond_type}, '\
|
180 |
+
f'stereo={self.stereo}, '
|
models/mhg_model/graph_grammar/graph_grammar/utils.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jun 4 2018"
|
20 |
+
|
21 |
+
from ..hypergraph import Hypergraph
|
22 |
+
from copy import deepcopy
|
23 |
+
from typing import List
|
24 |
+
import numpy as np
|
25 |
+
|
26 |
+
|
27 |
+
def common_node_list(hg1: Hypergraph, hg2: Hypergraph) -> List[str]:
|
28 |
+
""" return a list of common nodes
|
29 |
+
|
30 |
+
Parameters
|
31 |
+
----------
|
32 |
+
hg1, hg2 : Hypergraph
|
33 |
+
|
34 |
+
Returns
|
35 |
+
-------
|
36 |
+
list of str
|
37 |
+
list of common nodes
|
38 |
+
"""
|
39 |
+
if hg1 is None or hg2 is None:
|
40 |
+
return [], False
|
41 |
+
else:
|
42 |
+
node_set = hg1.nodes.intersection(hg2.nodes)
|
43 |
+
node_dict = {}
|
44 |
+
if 'order4hrg' in hg1.node_attr(list(hg1.nodes)[0]):
|
45 |
+
for each_node in node_set:
|
46 |
+
node_dict[each_node] = hg1.node_attr(each_node)['order4hrg']
|
47 |
+
else:
|
48 |
+
for each_node in node_set:
|
49 |
+
node_dict[each_node] = hg1.node_attr(each_node)['symbol'].__hash__()
|
50 |
+
node_list = []
|
51 |
+
for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]):
|
52 |
+
node_list.append(each_key)
|
53 |
+
edge_name = hg1.has_edge(node_list, ignore_order=True)
|
54 |
+
if edge_name:
|
55 |
+
if not hg1.edge_attr(edge_name).get('terminal', True):
|
56 |
+
node_list = hg1.nodes_in_edge(edge_name)
|
57 |
+
return node_list, True
|
58 |
+
else:
|
59 |
+
return node_list, False
|
60 |
+
|
61 |
+
|
62 |
+
def _node_match(node1, node2):
|
63 |
+
# if the nodes are hyperedges, `atom_attr` determines the match
|
64 |
+
if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
|
65 |
+
return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
|
66 |
+
elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
|
67 |
+
# bond_symbol
|
68 |
+
return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
|
69 |
+
else:
|
70 |
+
return False
|
71 |
+
|
72 |
+
def _easy_node_match(node1, node2):
|
73 |
+
# if the nodes are hyperedges, `atom_attr` determines the match
|
74 |
+
if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
|
75 |
+
return node1["attr_dict"].get('symbol', None) == node2["attr_dict"].get('symbol', None)
|
76 |
+
elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
|
77 |
+
# bond_symbol
|
78 |
+
return node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)\
|
79 |
+
and node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
|
80 |
+
else:
|
81 |
+
return False
|
82 |
+
|
83 |
+
|
84 |
+
def _node_match_prod_rule(node1, node2, ignore_order=False):
|
85 |
+
# if the nodes are hyperedges, `atom_attr` determines the match
|
86 |
+
if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge':
|
87 |
+
return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol']
|
88 |
+
elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node':
|
89 |
+
# ext_id, order4hrg, bond_symbol
|
90 |
+
if ignore_order:
|
91 |
+
return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']
|
92 |
+
else:
|
93 |
+
return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']\
|
94 |
+
and node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)
|
95 |
+
else:
|
96 |
+
return False
|
97 |
+
|
98 |
+
|
99 |
+
def _edge_match(edge1, edge2, ignore_order=False):
|
100 |
+
#return True
|
101 |
+
if ignore_order:
|
102 |
+
return True
|
103 |
+
else:
|
104 |
+
return edge1["order"] == edge2["order"]
|
105 |
+
|
106 |
+
def masked_softmax(logit, mask):
|
107 |
+
''' compute a probability distribution from logit
|
108 |
+
|
109 |
+
Parameters
|
110 |
+
----------
|
111 |
+
logit : array-like, length D
|
112 |
+
each element indicates how each dimension is likely to be chosen
|
113 |
+
(the larger, the more likely)
|
114 |
+
mask : array-like, length D
|
115 |
+
each element is either 0 or 1.
|
116 |
+
if 0, the dimension is ignored
|
117 |
+
when computing the probability distribution.
|
118 |
+
|
119 |
+
Returns
|
120 |
+
-------
|
121 |
+
prob_dist : array, length D
|
122 |
+
probability distribution computed from logit.
|
123 |
+
if `mask[d] = 0`, `prob_dist[d] = 0`.
|
124 |
+
'''
|
125 |
+
if logit.shape != mask.shape:
|
126 |
+
raise ValueError('logit and mask must have the same shape')
|
127 |
+
c = np.max(logit)
|
128 |
+
exp_logit = np.exp(logit - c) * mask
|
129 |
+
sum_exp_logit = exp_logit @ mask
|
130 |
+
return exp_logit / sum_exp_logit
|
models/mhg_model/graph_grammar/hypergraph.py
ADDED
@@ -0,0 +1,544 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jan 31 2018"
|
20 |
+
|
21 |
+
from copy import deepcopy
|
22 |
+
from typing import List, Dict, Tuple
|
23 |
+
import networkx as nx
|
24 |
+
import numpy as np
|
25 |
+
import os
|
26 |
+
|
27 |
+
|
28 |
+
class Hypergraph(object):
|
29 |
+
'''
|
30 |
+
A class of a hypergraph.
|
31 |
+
Each hyperedge can be ordered. For the ordered case,
|
32 |
+
edges adjacent to the hyperedge node are labeled by their orders.
|
33 |
+
|
34 |
+
Attributes
|
35 |
+
----------
|
36 |
+
hg : nx.Graph
|
37 |
+
a bipartite graph representation of a hypergraph
|
38 |
+
edge_idx : int
|
39 |
+
total number of hyperedges that exist so far
|
40 |
+
'''
|
41 |
+
def __init__(self):
|
42 |
+
self.hg = nx.Graph()
|
43 |
+
self.edge_idx = 0
|
44 |
+
self.nodes = set([])
|
45 |
+
self.num_nodes = 0
|
46 |
+
self.edges = set([])
|
47 |
+
self.num_edges = 0
|
48 |
+
self.nodes_in_edge_dict = {}
|
49 |
+
|
50 |
+
def add_node(self, node: str, attr_dict=None):
|
51 |
+
''' add a node to hypergraph
|
52 |
+
|
53 |
+
Parameters
|
54 |
+
----------
|
55 |
+
node : str
|
56 |
+
node name
|
57 |
+
attr_dict : dict
|
58 |
+
dictionary of node attributes
|
59 |
+
'''
|
60 |
+
self.hg.add_node(node, bipartite='node', attr_dict=attr_dict)
|
61 |
+
if node not in self.nodes:
|
62 |
+
self.num_nodes += 1
|
63 |
+
self.nodes.add(node)
|
64 |
+
|
65 |
+
def add_edge(self, node_list: List[str], attr_dict=None, edge_name=None):
|
66 |
+
''' add an edge consisting of nodes `node_list`
|
67 |
+
|
68 |
+
Parameters
|
69 |
+
----------
|
70 |
+
node_list : list
|
71 |
+
ordered list of nodes that consist the edge
|
72 |
+
attr_dict : dict
|
73 |
+
dictionary of edge attributes
|
74 |
+
'''
|
75 |
+
if edge_name is None:
|
76 |
+
edge = 'e{}'.format(self.edge_idx)
|
77 |
+
else:
|
78 |
+
assert edge_name not in self.edges
|
79 |
+
edge = edge_name
|
80 |
+
self.hg.add_node(edge, bipartite='edge', attr_dict=attr_dict)
|
81 |
+
if edge not in self.edges:
|
82 |
+
self.num_edges += 1
|
83 |
+
self.edges.add(edge)
|
84 |
+
self.nodes_in_edge_dict[edge] = node_list
|
85 |
+
if type(node_list) == list:
|
86 |
+
for node_idx, each_node in enumerate(node_list):
|
87 |
+
self.hg.add_edge(edge, each_node, order=node_idx)
|
88 |
+
if each_node not in self.nodes:
|
89 |
+
self.num_nodes += 1
|
90 |
+
self.nodes.add(each_node)
|
91 |
+
|
92 |
+
elif type(node_list) == set:
|
93 |
+
for each_node in node_list:
|
94 |
+
self.hg.add_edge(edge, each_node, order=-1)
|
95 |
+
if each_node not in self.nodes:
|
96 |
+
self.num_nodes += 1
|
97 |
+
self.nodes.add(each_node)
|
98 |
+
else:
|
99 |
+
raise ValueError
|
100 |
+
self.edge_idx += 1
|
101 |
+
return edge
|
102 |
+
|
103 |
+
def remove_node(self, node: str, remove_connected_edges=True):
|
104 |
+
''' remove a node
|
105 |
+
|
106 |
+
Parameters
|
107 |
+
----------
|
108 |
+
node : str
|
109 |
+
node name
|
110 |
+
remove_connected_edges : bool
|
111 |
+
if True, remove edges that are adjacent to the node
|
112 |
+
'''
|
113 |
+
if remove_connected_edges:
|
114 |
+
connected_edges = deepcopy(self.adj_edges(node))
|
115 |
+
for each_edge in connected_edges:
|
116 |
+
self.remove_edge(each_edge)
|
117 |
+
self.hg.remove_node(node)
|
118 |
+
self.num_nodes -= 1
|
119 |
+
self.nodes.remove(node)
|
120 |
+
|
121 |
+
def remove_nodes(self, node_iter, remove_connected_edges=True):
|
122 |
+
''' remove a set of nodes
|
123 |
+
|
124 |
+
Parameters
|
125 |
+
----------
|
126 |
+
node_iter : iterator of strings
|
127 |
+
nodes to be removed
|
128 |
+
remove_connected_edges : bool
|
129 |
+
if True, remove edges that are adjacent to the node
|
130 |
+
'''
|
131 |
+
for each_node in node_iter:
|
132 |
+
self.remove_node(each_node, remove_connected_edges)
|
133 |
+
|
134 |
+
def remove_edge(self, edge: str):
|
135 |
+
''' remove an edge
|
136 |
+
|
137 |
+
Parameters
|
138 |
+
----------
|
139 |
+
edge : str
|
140 |
+
edge to be removed
|
141 |
+
'''
|
142 |
+
self.hg.remove_node(edge)
|
143 |
+
self.edges.remove(edge)
|
144 |
+
self.num_edges -= 1
|
145 |
+
self.nodes_in_edge_dict.pop(edge)
|
146 |
+
|
147 |
+
def remove_edges(self, edge_iter):
|
148 |
+
''' remove a set of edges
|
149 |
+
|
150 |
+
Parameters
|
151 |
+
----------
|
152 |
+
edge_iter : iterator of strings
|
153 |
+
edges to be removed
|
154 |
+
'''
|
155 |
+
for each_edge in edge_iter:
|
156 |
+
self.remove_edge(each_edge)
|
157 |
+
|
158 |
+
def remove_edges_with_attr(self, edge_attr_dict):
|
159 |
+
remove_edge_list = []
|
160 |
+
for each_edge in self.edges:
|
161 |
+
satisfy = True
|
162 |
+
for each_key, each_val in edge_attr_dict.items():
|
163 |
+
if not satisfy:
|
164 |
+
break
|
165 |
+
try:
|
166 |
+
if self.edge_attr(each_edge)[each_key] != each_val:
|
167 |
+
satisfy = False
|
168 |
+
except KeyError:
|
169 |
+
satisfy = False
|
170 |
+
if satisfy:
|
171 |
+
remove_edge_list.append(each_edge)
|
172 |
+
self.remove_edges(remove_edge_list)
|
173 |
+
|
174 |
+
def remove_subhg(self, subhg):
|
175 |
+
''' remove subhypergraph.
|
176 |
+
all of the hyperedges are removed.
|
177 |
+
each node of subhg is removed if its degree becomes 0 after removing hyperedges.
|
178 |
+
|
179 |
+
Parameters
|
180 |
+
----------
|
181 |
+
subhg : Hypergraph
|
182 |
+
'''
|
183 |
+
for each_edge in subhg.edges:
|
184 |
+
self.remove_edge(each_edge)
|
185 |
+
for each_node in subhg.nodes:
|
186 |
+
if self.degree(each_node) == 0:
|
187 |
+
self.remove_node(each_node)
|
188 |
+
|
189 |
+
def nodes_in_edge(self, edge):
|
190 |
+
''' return an ordered list of nodes in a given edge.
|
191 |
+
|
192 |
+
Parameters
|
193 |
+
----------
|
194 |
+
edge : str
|
195 |
+
edge whose nodes are returned
|
196 |
+
|
197 |
+
Returns
|
198 |
+
-------
|
199 |
+
list or set
|
200 |
+
ordered list or set of nodes that belong to the edge
|
201 |
+
'''
|
202 |
+
if edge.startswith('e'):
|
203 |
+
return self.nodes_in_edge_dict[edge]
|
204 |
+
else:
|
205 |
+
adj_node_list = self.hg.adj[edge]
|
206 |
+
adj_node_order_list = []
|
207 |
+
adj_node_name_list = []
|
208 |
+
for each_node in adj_node_list:
|
209 |
+
adj_node_order_list.append(adj_node_list[each_node]['order'])
|
210 |
+
adj_node_name_list.append(each_node)
|
211 |
+
if adj_node_order_list == [-1] * len(adj_node_order_list):
|
212 |
+
return set(adj_node_name_list)
|
213 |
+
else:
|
214 |
+
return [adj_node_name_list[each_idx] for each_idx
|
215 |
+
in np.argsort(adj_node_order_list)]
|
216 |
+
|
217 |
+
def adj_edges(self, node):
|
218 |
+
''' return a dict of adjacent hyperedges
|
219 |
+
|
220 |
+
Parameters
|
221 |
+
----------
|
222 |
+
node : str
|
223 |
+
|
224 |
+
Returns
|
225 |
+
-------
|
226 |
+
set
|
227 |
+
set of edges that are adjacent to `node`
|
228 |
+
'''
|
229 |
+
return self.hg.adj[node]
|
230 |
+
|
231 |
+
def adj_nodes(self, node):
|
232 |
+
''' return a set of adjacent nodes
|
233 |
+
|
234 |
+
Parameters
|
235 |
+
----------
|
236 |
+
node : str
|
237 |
+
|
238 |
+
Returns
|
239 |
+
-------
|
240 |
+
set
|
241 |
+
set of nodes that are adjacent to `node`
|
242 |
+
'''
|
243 |
+
node_set = set([])
|
244 |
+
for each_adj_edge in self.adj_edges(node):
|
245 |
+
node_set.update(set(self.nodes_in_edge(each_adj_edge)))
|
246 |
+
node_set.discard(node)
|
247 |
+
return node_set
|
248 |
+
|
249 |
+
def has_edge(self, node_list, ignore_order=False):
|
250 |
+
for each_edge in self.edges:
|
251 |
+
if ignore_order:
|
252 |
+
if set(self.nodes_in_edge(each_edge)) == set(node_list):
|
253 |
+
return each_edge
|
254 |
+
else:
|
255 |
+
if self.nodes_in_edge(each_edge) == node_list:
|
256 |
+
return each_edge
|
257 |
+
return False
|
258 |
+
|
259 |
+
def degree(self, node):
|
260 |
+
return len(self.hg.adj[node])
|
261 |
+
|
262 |
+
def degrees(self):
|
263 |
+
return {each_node: self.degree(each_node) for each_node in self.nodes}
|
264 |
+
|
265 |
+
def edge_degree(self, edge):
|
266 |
+
return len(self.nodes_in_edge(edge))
|
267 |
+
|
268 |
+
def edge_degrees(self):
|
269 |
+
return {each_edge: self.edge_degree(each_edge) for each_edge in self.edges}
|
270 |
+
|
271 |
+
def is_adj(self, node1, node2):
|
272 |
+
return node1 in self.adj_nodes(node2)
|
273 |
+
|
274 |
+
def adj_subhg(self, node, ident_node_dict=None):
|
275 |
+
""" return a subhypergraph consisting of a set of nodes and hyperedges adjacent to `node`.
|
276 |
+
if an adjacent node has a self-loop hyperedge, it will be also added to the subhypergraph.
|
277 |
+
|
278 |
+
Parameters
|
279 |
+
----------
|
280 |
+
node : str
|
281 |
+
ident_node_dict : dict
|
282 |
+
dict containing identical nodes. see `get_identical_node_dict` for more details
|
283 |
+
|
284 |
+
Returns
|
285 |
+
-------
|
286 |
+
subhg : Hypergraph
|
287 |
+
"""
|
288 |
+
if ident_node_dict is None:
|
289 |
+
ident_node_dict = self.get_identical_node_dict()
|
290 |
+
adj_node_set = set(ident_node_dict[node])
|
291 |
+
adj_edge_set = set([])
|
292 |
+
for each_node in ident_node_dict[node]:
|
293 |
+
adj_edge_set.update(set(self.adj_edges(each_node)))
|
294 |
+
fixed_adj_edge_set = deepcopy(adj_edge_set)
|
295 |
+
for each_edge in fixed_adj_edge_set:
|
296 |
+
other_nodes = self.nodes_in_edge(each_edge)
|
297 |
+
adj_node_set.update(other_nodes)
|
298 |
+
|
299 |
+
# if the adjacent node has self-loop edge, it will be appended to adj_edge_list.
|
300 |
+
for each_node in other_nodes:
|
301 |
+
for other_edge in set(self.adj_edges(each_node)) - set([each_edge]):
|
302 |
+
if len(set(self.nodes_in_edge(other_edge)) \
|
303 |
+
- set(self.nodes_in_edge(each_edge))) == 0:
|
304 |
+
adj_edge_set.update(set([other_edge]))
|
305 |
+
subhg = Hypergraph()
|
306 |
+
for each_node in adj_node_set:
|
307 |
+
subhg.add_node(each_node, attr_dict=self.node_attr(each_node))
|
308 |
+
for each_edge in adj_edge_set:
|
309 |
+
subhg.add_edge(self.nodes_in_edge(each_edge),
|
310 |
+
attr_dict=self.edge_attr(each_edge),
|
311 |
+
edge_name=each_edge)
|
312 |
+
subhg.edge_idx = self.edge_idx
|
313 |
+
return subhg
|
314 |
+
|
315 |
+
def get_subhg(self, node_list, edge_list, ident_node_dict=None):
|
316 |
+
""" return a subhypergraph consisting of a set of nodes and hyperedges adjacent to `node`.
|
317 |
+
if an adjacent node has a self-loop hyperedge, it will be also added to the subhypergraph.
|
318 |
+
|
319 |
+
Parameters
|
320 |
+
----------
|
321 |
+
node : str
|
322 |
+
ident_node_dict : dict
|
323 |
+
dict containing identical nodes. see `get_identical_node_dict` for more details
|
324 |
+
|
325 |
+
Returns
|
326 |
+
-------
|
327 |
+
subhg : Hypergraph
|
328 |
+
"""
|
329 |
+
if ident_node_dict is None:
|
330 |
+
ident_node_dict = self.get_identical_node_dict()
|
331 |
+
adj_node_set = set([])
|
332 |
+
for each_node in node_list:
|
333 |
+
adj_node_set.update(set(ident_node_dict[each_node]))
|
334 |
+
adj_edge_set = set(edge_list)
|
335 |
+
|
336 |
+
subhg = Hypergraph()
|
337 |
+
for each_node in adj_node_set:
|
338 |
+
subhg.add_node(each_node,
|
339 |
+
attr_dict=deepcopy(self.node_attr(each_node)))
|
340 |
+
for each_edge in adj_edge_set:
|
341 |
+
subhg.add_edge(self.nodes_in_edge(each_edge),
|
342 |
+
attr_dict=deepcopy(self.edge_attr(each_edge)),
|
343 |
+
edge_name=each_edge)
|
344 |
+
subhg.edge_idx = self.edge_idx
|
345 |
+
return subhg
|
346 |
+
|
347 |
+
def copy(self):
|
348 |
+
''' return a copy of the object
|
349 |
+
|
350 |
+
Returns
|
351 |
+
-------
|
352 |
+
Hypergraph
|
353 |
+
'''
|
354 |
+
return deepcopy(self)
|
355 |
+
|
356 |
+
def node_attr(self, node):
|
357 |
+
return self.hg.nodes[node]['attr_dict']
|
358 |
+
|
359 |
+
def edge_attr(self, edge):
|
360 |
+
return self.hg.nodes[edge]['attr_dict']
|
361 |
+
|
362 |
+
def set_node_attr(self, node, attr_dict):
|
363 |
+
for each_key, each_val in attr_dict.items():
|
364 |
+
self.hg.nodes[node]['attr_dict'][each_key] = each_val
|
365 |
+
|
366 |
+
def set_edge_attr(self, edge, attr_dict):
|
367 |
+
for each_key, each_val in attr_dict.items():
|
368 |
+
self.hg.nodes[edge]['attr_dict'][each_key] = each_val
|
369 |
+
|
370 |
+
def get_identical_node_dict(self):
|
371 |
+
''' get identical nodes
|
372 |
+
nodes are identical if they share the same set of adjacent edges.
|
373 |
+
|
374 |
+
Returns
|
375 |
+
-------
|
376 |
+
ident_node_dict : dict
|
377 |
+
ident_node_dict[node] returns a list of nodes that are identical to `node`.
|
378 |
+
'''
|
379 |
+
ident_node_dict = {}
|
380 |
+
for each_node in self.nodes:
|
381 |
+
ident_node_list = []
|
382 |
+
for each_other_node in self.nodes:
|
383 |
+
if each_other_node == each_node:
|
384 |
+
ident_node_list.append(each_other_node)
|
385 |
+
elif self.adj_edges(each_node) == self.adj_edges(each_other_node) \
|
386 |
+
and len(self.adj_edges(each_node)) != 0:
|
387 |
+
ident_node_list.append(each_other_node)
|
388 |
+
ident_node_dict[each_node] = ident_node_list
|
389 |
+
return ident_node_dict
|
390 |
+
'''
|
391 |
+
ident_node_dict = {}
|
392 |
+
for each_node in self.nodes:
|
393 |
+
ident_node_dict[each_node] = [each_node]
|
394 |
+
return ident_node_dict
|
395 |
+
'''
|
396 |
+
|
397 |
+
def get_leaf_edge(self):
|
398 |
+
''' get an edge that is incident only to one edge
|
399 |
+
|
400 |
+
Returns
|
401 |
+
-------
|
402 |
+
if exists, return a leaf edge. otherwise, return None.
|
403 |
+
'''
|
404 |
+
for each_edge in self.edges:
|
405 |
+
if len(self.adj_nodes(each_edge)) == 1:
|
406 |
+
if 'tmp' not in self.edge_attr(each_edge):
|
407 |
+
return each_edge
|
408 |
+
return None
|
409 |
+
|
410 |
+
def get_nontmp_edge(self):
|
411 |
+
for each_edge in self.edges:
|
412 |
+
if 'tmp' not in self.edge_attr(each_edge):
|
413 |
+
return each_edge
|
414 |
+
return None
|
415 |
+
|
416 |
+
def is_subhg(self, hg):
|
417 |
+
''' return whether this hypergraph is a subhypergraph of `hg`
|
418 |
+
|
419 |
+
Returns
|
420 |
+
-------
|
421 |
+
True if self \in hg,
|
422 |
+
False otherwise.
|
423 |
+
'''
|
424 |
+
for each_node in self.nodes:
|
425 |
+
if each_node not in hg.nodes:
|
426 |
+
return False
|
427 |
+
for each_edge in self.edges:
|
428 |
+
if each_edge not in hg.edges:
|
429 |
+
return False
|
430 |
+
return True
|
431 |
+
|
432 |
+
def in_cycle(self, node, visited=None, parent='', root_node='') -> bool:
|
433 |
+
''' if `node` is in a cycle, then return True. otherwise, False.
|
434 |
+
|
435 |
+
Parameters
|
436 |
+
----------
|
437 |
+
node : str
|
438 |
+
node in a hypergraph
|
439 |
+
visited : list
|
440 |
+
list of visited nodes, used for recursion
|
441 |
+
parent : str
|
442 |
+
parent node, used to eliminate a cycle consisting of two nodes and one edge.
|
443 |
+
|
444 |
+
Returns
|
445 |
+
-------
|
446 |
+
bool
|
447 |
+
'''
|
448 |
+
if visited is None:
|
449 |
+
visited = []
|
450 |
+
if parent == '':
|
451 |
+
visited = []
|
452 |
+
if root_node == '':
|
453 |
+
root_node = node
|
454 |
+
visited.append(node)
|
455 |
+
for each_adj_node in self.adj_nodes(node):
|
456 |
+
if each_adj_node not in visited:
|
457 |
+
if self.in_cycle(each_adj_node, visited, node, root_node):
|
458 |
+
return True
|
459 |
+
elif each_adj_node != parent and each_adj_node == root_node:
|
460 |
+
return True
|
461 |
+
return False
|
462 |
+
|
463 |
+
|
464 |
+
def draw(self, file_path=None, with_node=False, with_edge_name=False):
|
465 |
+
''' draw hypergraph
|
466 |
+
'''
|
467 |
+
import graphviz
|
468 |
+
G = graphviz.Graph(format='png')
|
469 |
+
for each_node in self.nodes:
|
470 |
+
if 'ext_id' in self.node_attr(each_node):
|
471 |
+
G.node(each_node, label='',
|
472 |
+
shape='circle', width='0.1', height='0.1', style='filled',
|
473 |
+
fillcolor='black')
|
474 |
+
else:
|
475 |
+
if with_node:
|
476 |
+
G.node(each_node, label='',
|
477 |
+
shape='circle', width='0.1', height='0.1', style='filled',
|
478 |
+
fillcolor='gray')
|
479 |
+
edge_list = []
|
480 |
+
for each_edge in self.edges:
|
481 |
+
if self.edge_attr(each_edge).get('terminal', False):
|
482 |
+
G.node(each_edge,
|
483 |
+
label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
|
484 |
+
else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
|
485 |
+
fontcolor='black', shape='square')
|
486 |
+
elif self.edge_attr(each_edge).get('tmp', False):
|
487 |
+
G.node(each_edge, label='tmp' if not with_edge_name else 'tmp, ' + each_edge,
|
488 |
+
fontcolor='black', shape='square')
|
489 |
+
else:
|
490 |
+
G.node(each_edge,
|
491 |
+
label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
|
492 |
+
else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
|
493 |
+
fontcolor='black', shape='square', style='filled')
|
494 |
+
if with_node:
|
495 |
+
for each_node in self.nodes_in_edge(each_edge):
|
496 |
+
G.edge(each_edge, each_node)
|
497 |
+
else:
|
498 |
+
for each_node in self.nodes_in_edge(each_edge):
|
499 |
+
if 'ext_id' in self.node_attr(each_node)\
|
500 |
+
and set([each_node, each_edge]) not in edge_list:
|
501 |
+
G.edge(each_edge, each_node)
|
502 |
+
edge_list.append(set([each_node, each_edge]))
|
503 |
+
for each_other_edge in self.adj_nodes(each_edge):
|
504 |
+
if set([each_edge, each_other_edge]) not in edge_list:
|
505 |
+
num_bond = 0
|
506 |
+
common_node_set = set(self.nodes_in_edge(each_edge))\
|
507 |
+
.intersection(set(self.nodes_in_edge(each_other_edge)))
|
508 |
+
for each_node in common_node_set:
|
509 |
+
if self.node_attr(each_node)['symbol'].bond_type in [1, 2, 3]:
|
510 |
+
num_bond += self.node_attr(each_node)['symbol'].bond_type
|
511 |
+
elif self.node_attr(each_node)['symbol'].bond_type in [12]:
|
512 |
+
num_bond += 1
|
513 |
+
else:
|
514 |
+
raise NotImplementedError('unsupported bond type')
|
515 |
+
for _ in range(num_bond):
|
516 |
+
G.edge(each_edge, each_other_edge)
|
517 |
+
edge_list.append(set([each_edge, each_other_edge]))
|
518 |
+
if file_path is not None:
|
519 |
+
G.render(file_path, cleanup=True)
|
520 |
+
#os.remove(file_path)
|
521 |
+
return G
|
522 |
+
|
523 |
+
def is_dividable(self, node):
|
524 |
+
_hg = deepcopy(self.hg)
|
525 |
+
_hg.remove_node(node)
|
526 |
+
return (not nx.is_connected(_hg))
|
527 |
+
|
528 |
+
def divide(self, node):
|
529 |
+
subhg_list = []
|
530 |
+
|
531 |
+
hg_wo_node = deepcopy(self)
|
532 |
+
hg_wo_node.remove_node(node, remove_connected_edges=False)
|
533 |
+
connected_components = nx.connected_components(hg_wo_node.hg)
|
534 |
+
for each_component in connected_components:
|
535 |
+
node_list = [node]
|
536 |
+
edge_list = []
|
537 |
+
node_list.extend([each_node for each_node in each_component
|
538 |
+
if each_node.startswith('bond_')])
|
539 |
+
edge_list.extend([each_edge for each_edge in each_component
|
540 |
+
if each_edge.startswith('e')])
|
541 |
+
subhg_list.append(self.get_subhg(node_list, edge_list))
|
542 |
+
#subhg_list[-1].set_node_attr(node, {'divided': True})
|
543 |
+
return subhg_list
|
544 |
+
|
models/mhg_model/graph_grammar/io/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jan 1 2018"
|
20 |
+
|
models/mhg_model/graph_grammar/io/smi.py
ADDED
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jan 12 2018"
|
20 |
+
|
21 |
+
from copy import deepcopy
|
22 |
+
from rdkit import Chem
|
23 |
+
from rdkit import RDLogger
|
24 |
+
import networkx as nx
|
25 |
+
import numpy as np
|
26 |
+
from ..hypergraph import Hypergraph
|
27 |
+
from ..graph_grammar.symbols import TSymbol, BondSymbol
|
28 |
+
|
29 |
+
# supress warnings
|
30 |
+
lg = RDLogger.logger()
|
31 |
+
lg.setLevel(RDLogger.CRITICAL)
|
32 |
+
|
33 |
+
|
34 |
+
class HGGen(object):
|
35 |
+
"""
|
36 |
+
load .smi file and yield a hypergraph.
|
37 |
+
|
38 |
+
Attributes
|
39 |
+
----------
|
40 |
+
path_to_file : str
|
41 |
+
path to .smi file
|
42 |
+
kekulize : bool
|
43 |
+
kekulize or not
|
44 |
+
add_Hs : bool
|
45 |
+
add implicit hydrogens to the molecule or not.
|
46 |
+
all_single : bool
|
47 |
+
if True, all multiple bonds are summarized into a single bond with some attributes
|
48 |
+
|
49 |
+
Yields
|
50 |
+
------
|
51 |
+
Hypergraph
|
52 |
+
"""
|
53 |
+
def __init__(self, path_to_file, kekulize=True, add_Hs=False, all_single=True):
|
54 |
+
self.num_line = 1
|
55 |
+
self.mol_gen = Chem.SmilesMolSupplier(path_to_file, titleLine=False)
|
56 |
+
self.kekulize = kekulize
|
57 |
+
self.add_Hs = add_Hs
|
58 |
+
self.all_single = all_single
|
59 |
+
|
60 |
+
def __iter__(self):
|
61 |
+
return self
|
62 |
+
|
63 |
+
def __next__(self):
|
64 |
+
'''
|
65 |
+
each_mol = None
|
66 |
+
while each_mol is None:
|
67 |
+
each_mol = next(self.mol_gen)
|
68 |
+
'''
|
69 |
+
# not ignoring parse errors
|
70 |
+
each_mol = next(self.mol_gen)
|
71 |
+
if each_mol is None:
|
72 |
+
raise ValueError(f'incorrect smiles in line {self.num_line}')
|
73 |
+
else:
|
74 |
+
self.num_line += 1
|
75 |
+
return mol_to_hg(each_mol, self.kekulize, self.add_Hs)
|
76 |
+
|
77 |
+
|
78 |
+
def mol_to_bipartite(mol, kekulize):
|
79 |
+
"""
|
80 |
+
get a bipartite representation of a molecule.
|
81 |
+
|
82 |
+
Parameters
|
83 |
+
----------
|
84 |
+
mol : rdkit.Chem.rdchem.Mol
|
85 |
+
molecule object
|
86 |
+
|
87 |
+
Returns
|
88 |
+
-------
|
89 |
+
nx.Graph
|
90 |
+
a bipartite graph representing which bond is connected to which atoms.
|
91 |
+
"""
|
92 |
+
try:
|
93 |
+
mol = standardize_stereo(mol)
|
94 |
+
except KeyError:
|
95 |
+
print(Chem.MolToSmiles(mol))
|
96 |
+
raise KeyError
|
97 |
+
|
98 |
+
if kekulize:
|
99 |
+
Chem.Kekulize(mol)
|
100 |
+
|
101 |
+
bipartite_g = nx.Graph()
|
102 |
+
for each_atom in mol.GetAtoms():
|
103 |
+
bipartite_g.add_node(f"atom_{each_atom.GetIdx()}",
|
104 |
+
atom_attr=atom_attr(each_atom, kekulize))
|
105 |
+
|
106 |
+
for each_bond in mol.GetBonds():
|
107 |
+
bond_idx = each_bond.GetIdx()
|
108 |
+
bipartite_g.add_node(
|
109 |
+
f"bond_{bond_idx}",
|
110 |
+
bond_attr=bond_attr(each_bond, kekulize))
|
111 |
+
bipartite_g.add_edge(
|
112 |
+
f"atom_{each_bond.GetBeginAtomIdx()}",
|
113 |
+
f"bond_{bond_idx}")
|
114 |
+
bipartite_g.add_edge(
|
115 |
+
f"atom_{each_bond.GetEndAtomIdx()}",
|
116 |
+
f"bond_{bond_idx}")
|
117 |
+
return bipartite_g
|
118 |
+
|
119 |
+
|
120 |
+
def mol_to_hg(mol, kekulize, add_Hs):
|
121 |
+
"""
|
122 |
+
get a bipartite representation of a molecule.
|
123 |
+
|
124 |
+
Parameters
|
125 |
+
----------
|
126 |
+
mol : rdkit.Chem.rdchem.Mol
|
127 |
+
molecule object
|
128 |
+
kekulize : bool
|
129 |
+
kekulize or not
|
130 |
+
add_Hs : bool
|
131 |
+
add implicit hydrogens to the molecule or not.
|
132 |
+
|
133 |
+
Returns
|
134 |
+
-------
|
135 |
+
Hypergraph
|
136 |
+
"""
|
137 |
+
if add_Hs:
|
138 |
+
mol = Chem.AddHs(mol)
|
139 |
+
|
140 |
+
if kekulize:
|
141 |
+
Chem.Kekulize(mol)
|
142 |
+
|
143 |
+
bipartite_g = mol_to_bipartite(mol, kekulize)
|
144 |
+
hg = Hypergraph()
|
145 |
+
for each_atom in [each_node for each_node in bipartite_g.nodes()
|
146 |
+
if each_node.startswith('atom_')]:
|
147 |
+
node_set = set([])
|
148 |
+
for each_bond in bipartite_g.adj[each_atom]:
|
149 |
+
hg.add_node(each_bond,
|
150 |
+
attr_dict=bipartite_g.nodes[each_bond]['bond_attr'])
|
151 |
+
node_set.add(each_bond)
|
152 |
+
hg.add_edge(node_set,
|
153 |
+
attr_dict=bipartite_g.nodes[each_atom]['atom_attr'])
|
154 |
+
return hg
|
155 |
+
|
156 |
+
|
157 |
+
def hg_to_mol(hg, verbose=False):
|
158 |
+
""" convert a hypergraph into Mol object
|
159 |
+
|
160 |
+
Parameters
|
161 |
+
----------
|
162 |
+
hg : Hypergraph
|
163 |
+
|
164 |
+
Returns
|
165 |
+
-------
|
166 |
+
mol : Chem.RWMol
|
167 |
+
"""
|
168 |
+
mol = Chem.RWMol()
|
169 |
+
atom_dict = {}
|
170 |
+
bond_set = set([])
|
171 |
+
for each_edge in hg.edges:
|
172 |
+
atom = Chem.Atom(hg.edge_attr(each_edge)['symbol'].symbol)
|
173 |
+
atom.SetNumExplicitHs(hg.edge_attr(each_edge)['symbol'].num_explicit_Hs)
|
174 |
+
atom.SetFormalCharge(hg.edge_attr(each_edge)['symbol'].formal_charge)
|
175 |
+
atom.SetChiralTag(
|
176 |
+
Chem.rdchem.ChiralType.values[
|
177 |
+
hg.edge_attr(each_edge)['symbol'].chirality])
|
178 |
+
atom_idx = mol.AddAtom(atom)
|
179 |
+
atom_dict[each_edge] = atom_idx
|
180 |
+
|
181 |
+
for each_node in hg.nodes:
|
182 |
+
edge_1, edge_2 = hg.adj_edges(each_node)
|
183 |
+
if edge_1+edge_2 not in bond_set:
|
184 |
+
if hg.node_attr(each_node)['symbol'].bond_type <= 3:
|
185 |
+
num_bond = hg.node_attr(each_node)['symbol'].bond_type
|
186 |
+
elif hg.node_attr(each_node)['symbol'].bond_type == 12:
|
187 |
+
num_bond = 1
|
188 |
+
else:
|
189 |
+
raise ValueError(f'too many bonds; {hg.node_attr(each_node)["bond_symbol"].bond_type}')
|
190 |
+
_ = mol.AddBond(atom_dict[edge_1],
|
191 |
+
atom_dict[edge_2],
|
192 |
+
order=Chem.rdchem.BondType.values[num_bond])
|
193 |
+
bond_idx = mol.GetBondBetweenAtoms(atom_dict[edge_1], atom_dict[edge_2]).GetIdx()
|
194 |
+
|
195 |
+
# stereo
|
196 |
+
mol.GetBondWithIdx(bond_idx).SetStereo(
|
197 |
+
Chem.rdchem.BondStereo.values[hg.node_attr(each_node)['symbol'].stereo])
|
198 |
+
bond_set.update([edge_1+edge_2])
|
199 |
+
bond_set.update([edge_2+edge_1])
|
200 |
+
mol.UpdatePropertyCache()
|
201 |
+
mol = mol.GetMol()
|
202 |
+
not_stereo_mol = deepcopy(mol)
|
203 |
+
if Chem.MolFromSmiles(Chem.MolToSmiles(not_stereo_mol)) is None:
|
204 |
+
raise RuntimeError('no valid molecule was obtained.')
|
205 |
+
try:
|
206 |
+
mol = set_stereo(mol)
|
207 |
+
is_stereo = True
|
208 |
+
except:
|
209 |
+
import traceback
|
210 |
+
traceback.print_exc()
|
211 |
+
is_stereo = False
|
212 |
+
mol_tmp = deepcopy(mol)
|
213 |
+
Chem.SetAromaticity(mol_tmp)
|
214 |
+
if Chem.MolFromSmiles(Chem.MolToSmiles(mol_tmp)) is not None:
|
215 |
+
mol = mol_tmp
|
216 |
+
else:
|
217 |
+
if Chem.MolFromSmiles(Chem.MolToSmiles(mol)) is None:
|
218 |
+
mol = not_stereo_mol
|
219 |
+
mol.UpdatePropertyCache()
|
220 |
+
Chem.GetSymmSSSR(mol)
|
221 |
+
mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
|
222 |
+
if verbose:
|
223 |
+
return mol, is_stereo
|
224 |
+
else:
|
225 |
+
return mol
|
226 |
+
|
227 |
+
def hgs_to_mols(hg_list, ignore_error=False):
|
228 |
+
if ignore_error:
|
229 |
+
mol_list = []
|
230 |
+
for each_hg in hg_list:
|
231 |
+
try:
|
232 |
+
mol = hg_to_mol(each_hg)
|
233 |
+
except:
|
234 |
+
mol = None
|
235 |
+
mol_list.append(mol)
|
236 |
+
else:
|
237 |
+
mol_list = [hg_to_mol(each_hg) for each_hg in hg_list]
|
238 |
+
return mol_list
|
239 |
+
|
240 |
+
def hgs_to_smiles(hg_list, ignore_error=False):
|
241 |
+
mol_list = hgs_to_mols(hg_list, ignore_error)
|
242 |
+
smiles_list = []
|
243 |
+
for each_mol in mol_list:
|
244 |
+
try:
|
245 |
+
smiles_list.append(
|
246 |
+
Chem.MolToSmiles(
|
247 |
+
Chem.MolFromSmiles(
|
248 |
+
Chem.MolToSmiles(
|
249 |
+
each_mol))))
|
250 |
+
except:
|
251 |
+
smiles_list.append(None)
|
252 |
+
return smiles_list
|
253 |
+
|
254 |
+
def atom_attr(atom, kekulize):
|
255 |
+
"""
|
256 |
+
get atom's attributes
|
257 |
+
|
258 |
+
Parameters
|
259 |
+
----------
|
260 |
+
atom : rdkit.Chem.rdchem.Atom
|
261 |
+
kekulize : bool
|
262 |
+
kekulize or not
|
263 |
+
|
264 |
+
Returns
|
265 |
+
-------
|
266 |
+
atom_attr : dict
|
267 |
+
"is_aromatic" : bool
|
268 |
+
the atom is aromatic or not.
|
269 |
+
"smarts" : str
|
270 |
+
SMARTS representation of the atom.
|
271 |
+
"""
|
272 |
+
if kekulize:
|
273 |
+
return {'terminal': True,
|
274 |
+
'is_in_ring': atom.IsInRing(),
|
275 |
+
'symbol': TSymbol(degree=0,
|
276 |
+
#degree=atom.GetTotalDegree(),
|
277 |
+
is_aromatic=False,
|
278 |
+
symbol=atom.GetSymbol(),
|
279 |
+
num_explicit_Hs=atom.GetNumExplicitHs(),
|
280 |
+
formal_charge=atom.GetFormalCharge(),
|
281 |
+
chirality=atom.GetChiralTag().real
|
282 |
+
)}
|
283 |
+
else:
|
284 |
+
return {'terminal': True,
|
285 |
+
'is_in_ring': atom.IsInRing(),
|
286 |
+
'symbol': TSymbol(degree=0,
|
287 |
+
#degree=atom.GetTotalDegree(),
|
288 |
+
is_aromatic=atom.GetIsAromatic(),
|
289 |
+
symbol=atom.GetSymbol(),
|
290 |
+
num_explicit_Hs=atom.GetNumExplicitHs(),
|
291 |
+
formal_charge=atom.GetFormalCharge(),
|
292 |
+
chirality=atom.GetChiralTag().real
|
293 |
+
)}
|
294 |
+
|
295 |
+
def bond_attr(bond, kekulize):
|
296 |
+
"""
|
297 |
+
get atom's attributes
|
298 |
+
|
299 |
+
Parameters
|
300 |
+
----------
|
301 |
+
bond : rdkit.Chem.rdchem.Bond
|
302 |
+
kekulize : bool
|
303 |
+
kekulize or not
|
304 |
+
|
305 |
+
Returns
|
306 |
+
-------
|
307 |
+
bond_attr : dict
|
308 |
+
"bond_type" : int
|
309 |
+
{0: rdkit.Chem.rdchem.BondType.UNSPECIFIED,
|
310 |
+
1: rdkit.Chem.rdchem.BondType.SINGLE,
|
311 |
+
2: rdkit.Chem.rdchem.BondType.DOUBLE,
|
312 |
+
3: rdkit.Chem.rdchem.BondType.TRIPLE,
|
313 |
+
4: rdkit.Chem.rdchem.BondType.QUADRUPLE,
|
314 |
+
5: rdkit.Chem.rdchem.BondType.QUINTUPLE,
|
315 |
+
6: rdkit.Chem.rdchem.BondType.HEXTUPLE,
|
316 |
+
7: rdkit.Chem.rdchem.BondType.ONEANDAHALF,
|
317 |
+
8: rdkit.Chem.rdchem.BondType.TWOANDAHALF,
|
318 |
+
9: rdkit.Chem.rdchem.BondType.THREEANDAHALF,
|
319 |
+
10: rdkit.Chem.rdchem.BondType.FOURANDAHALF,
|
320 |
+
11: rdkit.Chem.rdchem.BondType.FIVEANDAHALF,
|
321 |
+
12: rdkit.Chem.rdchem.BondType.AROMATIC,
|
322 |
+
13: rdkit.Chem.rdchem.BondType.IONIC,
|
323 |
+
14: rdkit.Chem.rdchem.BondType.HYDROGEN,
|
324 |
+
15: rdkit.Chem.rdchem.BondType.THREECENTER,
|
325 |
+
16: rdkit.Chem.rdchem.BondType.DATIVEONE,
|
326 |
+
17: rdkit.Chem.rdchem.BondType.DATIVE,
|
327 |
+
18: rdkit.Chem.rdchem.BondType.DATIVEL,
|
328 |
+
19: rdkit.Chem.rdchem.BondType.DATIVER,
|
329 |
+
20: rdkit.Chem.rdchem.BondType.OTHER,
|
330 |
+
21: rdkit.Chem.rdchem.BondType.ZERO}
|
331 |
+
"""
|
332 |
+
if kekulize:
|
333 |
+
is_aromatic = False
|
334 |
+
if bond.GetBondType().real == 12:
|
335 |
+
bond_type = 1
|
336 |
+
else:
|
337 |
+
bond_type = bond.GetBondType().real
|
338 |
+
else:
|
339 |
+
is_aromatic = bond.GetIsAromatic()
|
340 |
+
bond_type = bond.GetBondType().real
|
341 |
+
return {'symbol': BondSymbol(is_aromatic=is_aromatic,
|
342 |
+
bond_type=bond_type,
|
343 |
+
stereo=int(bond.GetStereo())),
|
344 |
+
'is_in_ring': bond.IsInRing()}
|
345 |
+
|
346 |
+
|
347 |
+
def standardize_stereo(mol):
|
348 |
+
'''
|
349 |
+
0: rdkit.Chem.rdchem.BondDir.NONE,
|
350 |
+
1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
|
351 |
+
2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
|
352 |
+
3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
|
353 |
+
4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,
|
354 |
+
|
355 |
+
'''
|
356 |
+
# mol = Chem.AddHs(mol) # this removes CIPRank !!!
|
357 |
+
for each_bond in mol.GetBonds():
|
358 |
+
if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
|
359 |
+
begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
|
360 |
+
end_stereo_atom_idx = each_bond.GetEndAtomIdx()
|
361 |
+
atom_idx_1 = each_bond.GetStereoAtoms()[0]
|
362 |
+
atom_idx_2 = each_bond.GetStereoAtoms()[1]
|
363 |
+
if mol.GetBondBetweenAtoms(atom_idx_1, begin_stereo_atom_idx):
|
364 |
+
begin_atom_idx = atom_idx_1
|
365 |
+
end_atom_idx = atom_idx_2
|
366 |
+
else:
|
367 |
+
begin_atom_idx = atom_idx_2
|
368 |
+
end_atom_idx = atom_idx_1
|
369 |
+
|
370 |
+
begin_another_atom_idx = None
|
371 |
+
assert len(mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()) <= 3
|
372 |
+
for each_neighbor in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors():
|
373 |
+
each_neighbor_idx = each_neighbor.GetIdx()
|
374 |
+
if each_neighbor_idx not in [end_stereo_atom_idx, begin_atom_idx]:
|
375 |
+
begin_another_atom_idx = each_neighbor_idx
|
376 |
+
|
377 |
+
end_another_atom_idx = None
|
378 |
+
assert len(mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()) <= 3
|
379 |
+
for each_neighbor in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors():
|
380 |
+
each_neighbor_idx = each_neighbor.GetIdx()
|
381 |
+
if each_neighbor_idx not in [begin_stereo_atom_idx, end_atom_idx]:
|
382 |
+
end_another_atom_idx = each_neighbor_idx
|
383 |
+
|
384 |
+
'''
|
385 |
+
relationship between begin_atom_idx and end_atom_idx is encoded in GetStereo
|
386 |
+
'''
|
387 |
+
begin_atom_rank = int(mol.GetAtomWithIdx(begin_atom_idx).GetProp('_CIPRank'))
|
388 |
+
end_atom_rank = int(mol.GetAtomWithIdx(end_atom_idx).GetProp('_CIPRank'))
|
389 |
+
try:
|
390 |
+
begin_another_atom_rank = int(mol.GetAtomWithIdx(begin_another_atom_idx).GetProp('_CIPRank'))
|
391 |
+
except:
|
392 |
+
begin_another_atom_rank = np.inf
|
393 |
+
try:
|
394 |
+
end_another_atom_rank = int(mol.GetAtomWithIdx(end_another_atom_idx).GetProp('_CIPRank'))
|
395 |
+
except:
|
396 |
+
end_another_atom_rank = np.inf
|
397 |
+
if begin_atom_rank < begin_another_atom_rank\
|
398 |
+
and end_atom_rank < end_another_atom_rank:
|
399 |
+
pass
|
400 |
+
elif begin_atom_rank < begin_another_atom_rank\
|
401 |
+
and end_atom_rank > end_another_atom_rank:
|
402 |
+
# (begin_atom_idx +) end_another_atom_idx should be in StereoAtoms
|
403 |
+
if each_bond.GetStereo() == 2:
|
404 |
+
# set stereo
|
405 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
|
406 |
+
# set bond dir
|
407 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
|
408 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
|
409 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
|
410 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
|
411 |
+
elif each_bond.GetStereo() == 3:
|
412 |
+
# set stereo
|
413 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
|
414 |
+
# set bond dir
|
415 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
|
416 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
|
417 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
|
418 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
|
419 |
+
else:
|
420 |
+
raise ValueError
|
421 |
+
each_bond.SetStereoAtoms(begin_atom_idx, end_another_atom_idx)
|
422 |
+
elif begin_atom_rank > begin_another_atom_rank\
|
423 |
+
and end_atom_rank < end_another_atom_rank:
|
424 |
+
# (end_atom_idx +) begin_another_atom_idx should be in StereoAtoms
|
425 |
+
if each_bond.GetStereo() == 2:
|
426 |
+
# set stereo
|
427 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
|
428 |
+
# set bond dir
|
429 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
|
430 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
|
431 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
|
432 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
|
433 |
+
elif each_bond.GetStereo() == 3:
|
434 |
+
# set stereo
|
435 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
|
436 |
+
# set bond dir
|
437 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
|
438 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
|
439 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
|
440 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
|
441 |
+
else:
|
442 |
+
raise ValueError
|
443 |
+
each_bond.SetStereoAtoms(begin_another_atom_idx, end_atom_idx)
|
444 |
+
elif begin_atom_rank > begin_another_atom_rank\
|
445 |
+
and end_atom_rank > end_another_atom_rank:
|
446 |
+
# begin_another_atom_idx + end_another_atom_idx should be in StereoAtoms
|
447 |
+
if each_bond.GetStereo() == 2:
|
448 |
+
# set bond dir
|
449 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
|
450 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
|
451 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
|
452 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
|
453 |
+
elif each_bond.GetStereo() == 3:
|
454 |
+
# set bond dir
|
455 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
|
456 |
+
mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
|
457 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
|
458 |
+
mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
|
459 |
+
else:
|
460 |
+
raise ValueError
|
461 |
+
each_bond.SetStereoAtoms(begin_another_atom_idx, end_another_atom_idx)
|
462 |
+
else:
|
463 |
+
raise RuntimeError
|
464 |
+
return mol
|
465 |
+
|
466 |
+
|
467 |
+
def set_stereo(mol):
|
468 |
+
'''
|
469 |
+
0: rdkit.Chem.rdchem.BondDir.NONE,
|
470 |
+
1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
|
471 |
+
2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
|
472 |
+
3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
|
473 |
+
4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,
|
474 |
+
'''
|
475 |
+
_mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
|
476 |
+
Chem.Kekulize(_mol, True)
|
477 |
+
substruct_match = mol.GetSubstructMatch(_mol)
|
478 |
+
if not substruct_match:
|
479 |
+
''' mol and _mol are kekulized.
|
480 |
+
sometimes, the order of '=' and '-' changes, which causes mol and _mol not matched.
|
481 |
+
'''
|
482 |
+
Chem.SetAromaticity(mol)
|
483 |
+
Chem.SetAromaticity(_mol)
|
484 |
+
substruct_match = mol.GetSubstructMatch(_mol)
|
485 |
+
try:
|
486 |
+
atom_match = {substruct_match[_mol_atom_idx]: _mol_atom_idx for _mol_atom_idx in range(_mol.GetNumAtoms())} # mol to _mol
|
487 |
+
except:
|
488 |
+
raise ValueError('two molecules obtained from the same data do not match.')
|
489 |
+
|
490 |
+
for each_bond in mol.GetBonds():
|
491 |
+
begin_atom_idx = each_bond.GetBeginAtomIdx()
|
492 |
+
end_atom_idx = each_bond.GetEndAtomIdx()
|
493 |
+
_bond = _mol.GetBondBetweenAtoms(atom_match[begin_atom_idx], atom_match[end_atom_idx])
|
494 |
+
_bond.SetStereo(each_bond.GetStereo())
|
495 |
+
|
496 |
+
mol = _mol
|
497 |
+
for each_bond in mol.GetBonds():
|
498 |
+
if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
|
499 |
+
begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
|
500 |
+
end_stereo_atom_idx = each_bond.GetEndAtomIdx()
|
501 |
+
begin_atom_idx_set = set([each_neighbor.GetIdx()
|
502 |
+
for each_neighbor
|
503 |
+
in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()
|
504 |
+
if each_neighbor.GetIdx() != end_stereo_atom_idx])
|
505 |
+
end_atom_idx_set = set([each_neighbor.GetIdx()
|
506 |
+
for each_neighbor
|
507 |
+
in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()
|
508 |
+
if each_neighbor.GetIdx() != begin_stereo_atom_idx])
|
509 |
+
if not begin_atom_idx_set:
|
510 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo(0))
|
511 |
+
continue
|
512 |
+
if not end_atom_idx_set:
|
513 |
+
each_bond.SetStereo(Chem.rdchem.BondStereo(0))
|
514 |
+
continue
|
515 |
+
if len(begin_atom_idx_set) == 1:
|
516 |
+
begin_atom_idx = begin_atom_idx_set.pop()
|
517 |
+
begin_another_atom_idx = None
|
518 |
+
if len(end_atom_idx_set) == 1:
|
519 |
+
end_atom_idx = end_atom_idx_set.pop()
|
520 |
+
end_another_atom_idx = None
|
521 |
+
if len(begin_atom_idx_set) == 2:
|
522 |
+
atom_idx_1 = begin_atom_idx_set.pop()
|
523 |
+
atom_idx_2 = begin_atom_idx_set.pop()
|
524 |
+
if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
|
525 |
+
begin_atom_idx = atom_idx_1
|
526 |
+
begin_another_atom_idx = atom_idx_2
|
527 |
+
else:
|
528 |
+
begin_atom_idx = atom_idx_2
|
529 |
+
begin_another_atom_idx = atom_idx_1
|
530 |
+
if len(end_atom_idx_set) == 2:
|
531 |
+
atom_idx_1 = end_atom_idx_set.pop()
|
532 |
+
atom_idx_2 = end_atom_idx_set.pop()
|
533 |
+
if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
|
534 |
+
end_atom_idx = atom_idx_1
|
535 |
+
end_another_atom_idx = atom_idx_2
|
536 |
+
else:
|
537 |
+
end_atom_idx = atom_idx_2
|
538 |
+
end_another_atom_idx = atom_idx_1
|
539 |
+
|
540 |
+
if each_bond.GetStereo() == 2: # same side
|
541 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
|
542 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
|
543 |
+
each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
|
544 |
+
elif each_bond.GetStereo() == 3: # opposite side
|
545 |
+
mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
|
546 |
+
mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
|
547 |
+
each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
|
548 |
+
else:
|
549 |
+
raise ValueError
|
550 |
+
return mol
|
551 |
+
|
552 |
+
|
553 |
+
def safe_set_bond_dir(mol, atom_idx_1, atom_idx_2, bond_dir_val):
|
554 |
+
if atom_idx_1 is None or atom_idx_2 is None:
|
555 |
+
return mol
|
556 |
+
else:
|
557 |
+
mol.GetBondBetweenAtoms(atom_idx_1, atom_idx_2).SetBondDir(Chem.rdchem.BondDir.values[bond_dir_val])
|
558 |
+
return mol
|
559 |
+
|
models/mhg_model/graph_grammar/nn/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# Rhizome
|
3 |
+
# Version beta 0.0, August 2023
|
4 |
+
# Property of IBM Research, Accelerated Discovery
|
5 |
+
#
|
6 |
+
|
7 |
+
"""
|
8 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
9 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
10 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
11 |
+
"""
|
models/mhg_model/graph_grammar/nn/dataset.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Apr 18 2018"
|
20 |
+
|
21 |
+
from torch.utils.data import Dataset, DataLoader
|
22 |
+
import torch
|
23 |
+
import numpy as np
|
24 |
+
|
25 |
+
|
26 |
+
def left_padding(sentence_list, max_len, pad_idx=-1, inverse=False):
|
27 |
+
''' pad left
|
28 |
+
|
29 |
+
Parameters
|
30 |
+
----------
|
31 |
+
sentence_list : list of sequences of integers
|
32 |
+
max_len : int
|
33 |
+
maximum length of sentences.
|
34 |
+
if a sentence is shorter than `max_len`, its left part is padded.
|
35 |
+
pad_idx : int
|
36 |
+
integer for padding
|
37 |
+
inverse : bool
|
38 |
+
if True, the sequence is inversed.
|
39 |
+
|
40 |
+
Returns
|
41 |
+
-------
|
42 |
+
List of torch.LongTensor
|
43 |
+
each sentence is left-padded.
|
44 |
+
'''
|
45 |
+
max_in_list = max([len(each_sen) for each_sen in sentence_list])
|
46 |
+
|
47 |
+
if max_in_list > max_len:
|
48 |
+
raise ValueError('`max_len` should be larger than the maximum length of input sequences, {}.'.format(max_in_list))
|
49 |
+
|
50 |
+
if inverse:
|
51 |
+
return [torch.LongTensor([pad_idx] * (max_len - len(each_sen)) + each_sen[::-1]) for each_sen in sentence_list]
|
52 |
+
else:
|
53 |
+
return [torch.LongTensor([pad_idx] * (max_len - len(each_sen)) + each_sen) for each_sen in sentence_list]
|
54 |
+
|
55 |
+
|
56 |
+
def right_padding(sentence_list, max_len, pad_idx=-1):
|
57 |
+
''' pad right
|
58 |
+
|
59 |
+
Parameters
|
60 |
+
----------
|
61 |
+
sentence_list : list of sequences of integers
|
62 |
+
max_len : int
|
63 |
+
maximum length of sentences.
|
64 |
+
if a sentence is shorter than `max_len`, its right part is padded.
|
65 |
+
pad_idx : int
|
66 |
+
integer for padding
|
67 |
+
|
68 |
+
Returns
|
69 |
+
-------
|
70 |
+
List of torch.LongTensor
|
71 |
+
each sentence is right-padded.
|
72 |
+
'''
|
73 |
+
max_in_list = max([len(each_sen) for each_sen in sentence_list])
|
74 |
+
if max_in_list > max_len:
|
75 |
+
raise ValueError('`max_len` should be larger than the maximum length of input sequences, {}.'.format(max_in_list))
|
76 |
+
|
77 |
+
return [torch.LongTensor(each_sen + [pad_idx] * (max_len - len(each_sen))) for each_sen in sentence_list]
|
78 |
+
|
79 |
+
|
80 |
+
class HRGDataset(Dataset):
|
81 |
+
|
82 |
+
'''
|
83 |
+
A class of HRG data
|
84 |
+
'''
|
85 |
+
|
86 |
+
def __init__(self, hrg, prod_rule_seq_list, max_len, target_val_list=None, inversed_input=False):
|
87 |
+
self.hrg = hrg
|
88 |
+
self.left_prod_rule_seq_list = left_padding(prod_rule_seq_list,
|
89 |
+
max_len,
|
90 |
+
inverse=inversed_input)
|
91 |
+
|
92 |
+
self.right_prod_rule_seq_list = right_padding(prod_rule_seq_list, max_len)
|
93 |
+
self.inserved_input = inversed_input
|
94 |
+
self.target_val_list = target_val_list
|
95 |
+
if target_val_list is not None:
|
96 |
+
if len(prod_rule_seq_list) != len(target_val_list):
|
97 |
+
raise ValueError(f'prod_rule_seq_list and target_val_list have inconsistent lengths: {len(prod_rule_seq_list)}, {len(target_val_list)}')
|
98 |
+
|
99 |
+
def __len__(self):
|
100 |
+
return len(self.left_prod_rule_seq_list)
|
101 |
+
|
102 |
+
def __getitem__(self, idx):
|
103 |
+
if self.target_val_list is not None:
|
104 |
+
return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx], np.float32(self.target_val_list[idx])
|
105 |
+
else:
|
106 |
+
return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx]
|
107 |
+
|
108 |
+
@property
|
109 |
+
def vocab_size(self):
|
110 |
+
return self.hrg.num_prod_rule
|
111 |
+
|
112 |
+
def batch_padding(each_batch, batch_size, padding_idx):
|
113 |
+
num_pad = batch_size - len(each_batch[0])
|
114 |
+
if num_pad:
|
115 |
+
each_batch[0] = torch.cat([each_batch[0],
|
116 |
+
padding_idx * torch.ones((batch_size - len(each_batch[0]),
|
117 |
+
len(each_batch[0][0])), dtype=torch.int64)], dim=0)
|
118 |
+
each_batch[1] = torch.cat([each_batch[1],
|
119 |
+
padding_idx * torch.ones((batch_size - len(each_batch[1]),
|
120 |
+
len(each_batch[1][0])), dtype=torch.int64)], dim=0)
|
121 |
+
return each_batch, num_pad
|
models/mhg_model/graph_grammar/nn/decoder.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Aug 9 2018"
|
20 |
+
|
21 |
+
|
22 |
+
import abc
|
23 |
+
import numpy as np
|
24 |
+
import torch
|
25 |
+
from torch import nn
|
26 |
+
|
27 |
+
|
28 |
+
class DecoderBase(nn.Module):
|
29 |
+
|
30 |
+
def __init__(self):
|
31 |
+
super().__init__()
|
32 |
+
self.hidden_dict = {}
|
33 |
+
|
34 |
+
@abc.abstractmethod
|
35 |
+
def forward_one_step(self, tgt_emb_in):
|
36 |
+
''' one-step forward model
|
37 |
+
|
38 |
+
Parameters
|
39 |
+
----------
|
40 |
+
tgt_emb_in : Tensor, shape (batch_size, input_dim)
|
41 |
+
|
42 |
+
Returns
|
43 |
+
-------
|
44 |
+
Tensor, shape (batch_size, hidden_dim)
|
45 |
+
'''
|
46 |
+
tgt_emb_out = None
|
47 |
+
return tgt_emb_out
|
48 |
+
|
49 |
+
@abc.abstractmethod
|
50 |
+
def init_hidden(self):
|
51 |
+
''' initialize the hidden states
|
52 |
+
'''
|
53 |
+
pass
|
54 |
+
|
55 |
+
@abc.abstractmethod
|
56 |
+
def feed_hidden(self, hidden_dict_0):
|
57 |
+
for each_hidden in self.hidden_dict.keys():
|
58 |
+
self.hidden_dict[each_hidden][0] = hidden_dict_0[each_hidden]
|
59 |
+
|
60 |
+
|
61 |
+
class GRUDecoder(DecoderBase):
|
62 |
+
|
63 |
+
def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
|
64 |
+
dropout: float, batch_size: int, use_gpu: bool,
|
65 |
+
no_dropout=False):
|
66 |
+
super().__init__()
|
67 |
+
self.input_dim = input_dim
|
68 |
+
self.hidden_dim = hidden_dim
|
69 |
+
self.num_layers = num_layers
|
70 |
+
self.dropout = dropout
|
71 |
+
self.batch_size = batch_size
|
72 |
+
self.use_gpu = use_gpu
|
73 |
+
self.model = nn.GRU(input_size=self.input_dim,
|
74 |
+
hidden_size=self.hidden_dim,
|
75 |
+
num_layers=self.num_layers,
|
76 |
+
batch_first=True,
|
77 |
+
bidirectional=False,
|
78 |
+
dropout=self.dropout if not no_dropout else 0
|
79 |
+
)
|
80 |
+
if self.use_gpu:
|
81 |
+
self.model.cuda()
|
82 |
+
self.init_hidden()
|
83 |
+
|
84 |
+
def init_hidden(self):
|
85 |
+
self.hidden_dict['h'] = torch.zeros((self.num_layers,
|
86 |
+
self.batch_size,
|
87 |
+
self.hidden_dim),
|
88 |
+
requires_grad=False)
|
89 |
+
if self.use_gpu:
|
90 |
+
self.hidden_dict['h'] = self.hidden_dict['h'].cuda()
|
91 |
+
|
92 |
+
def forward_one_step(self, tgt_emb_in):
|
93 |
+
''' one-step forward model
|
94 |
+
|
95 |
+
Parameters
|
96 |
+
----------
|
97 |
+
tgt_emb_in : Tensor, shape (batch_size, input_dim)
|
98 |
+
|
99 |
+
Returns
|
100 |
+
-------
|
101 |
+
Tensor, shape (batch_size, hidden_dim)
|
102 |
+
'''
|
103 |
+
tgt_emb_out, self.hidden_dict['h'] \
|
104 |
+
= self.model(tgt_emb_in.view(self.batch_size, 1, -1),
|
105 |
+
self.hidden_dict['h'])
|
106 |
+
return tgt_emb_out
|
107 |
+
|
108 |
+
|
109 |
+
class LSTMDecoder(DecoderBase):
|
110 |
+
|
111 |
+
def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
|
112 |
+
dropout: float, batch_size: int, use_gpu: bool,
|
113 |
+
no_dropout=False):
|
114 |
+
super().__init__()
|
115 |
+
self.input_dim = input_dim
|
116 |
+
self.hidden_dim = hidden_dim
|
117 |
+
self.num_layers = num_layers
|
118 |
+
self.dropout = dropout
|
119 |
+
self.batch_size = batch_size
|
120 |
+
self.use_gpu = use_gpu
|
121 |
+
self.model = nn.LSTM(input_size=self.input_dim,
|
122 |
+
hidden_size=self.hidden_dim,
|
123 |
+
num_layers=self.num_layers,
|
124 |
+
batch_first=True,
|
125 |
+
bidirectional=False,
|
126 |
+
dropout=self.dropout if not no_dropout else 0)
|
127 |
+
if self.use_gpu:
|
128 |
+
self.model.cuda()
|
129 |
+
self.init_hidden()
|
130 |
+
|
131 |
+
def init_hidden(self):
|
132 |
+
self.hidden_dict['h'] = torch.zeros((self.num_layers,
|
133 |
+
self.batch_size,
|
134 |
+
self.hidden_dim),
|
135 |
+
requires_grad=False)
|
136 |
+
self.hidden_dict['c'] = torch.zeros((self.num_layers,
|
137 |
+
self.batch_size,
|
138 |
+
self.hidden_dim),
|
139 |
+
requires_grad=False)
|
140 |
+
if self.use_gpu:
|
141 |
+
for each_hidden in self.hidden_dict.keys():
|
142 |
+
self.hidden_dict[each_hidden] = self.hidden_dict[each_hidden].cuda()
|
143 |
+
|
144 |
+
def forward_one_step(self, tgt_emb_in):
|
145 |
+
''' one-step forward model
|
146 |
+
|
147 |
+
Parameters
|
148 |
+
----------
|
149 |
+
tgt_emb_in : Tensor, shape (batch_size, input_dim)
|
150 |
+
|
151 |
+
Returns
|
152 |
+
-------
|
153 |
+
Tensor, shape (batch_size, hidden_dim)
|
154 |
+
'''
|
155 |
+
tgt_hidden_out, self.hidden_dict['h'], self.hidden_dict['c'] \
|
156 |
+
= self.model(tgt_emb_in.view(self.batch_size, 1, -1),
|
157 |
+
self.hidden_dict['h'], self.hidden_dict['c'])
|
158 |
+
return tgt_hidden_out
|
models/mhg_model/graph_grammar/nn/encoder.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Aug 9 2018"
|
20 |
+
|
21 |
+
|
22 |
+
import abc
|
23 |
+
import numpy as np
|
24 |
+
import torch
|
25 |
+
import torch.nn.functional as F
|
26 |
+
from torch import nn
|
27 |
+
from typing import List
|
28 |
+
|
29 |
+
|
30 |
+
class EncoderBase(nn.Module):
|
31 |
+
|
32 |
+
def __init__(self):
|
33 |
+
super().__init__()
|
34 |
+
|
35 |
+
@abc.abstractmethod
|
36 |
+
def forward(self, in_seq):
|
37 |
+
''' forward model
|
38 |
+
|
39 |
+
Parameters
|
40 |
+
----------
|
41 |
+
in_seq_emb : Variable, shape (batch_size, max_len, input_dim)
|
42 |
+
|
43 |
+
Returns
|
44 |
+
-------
|
45 |
+
hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
|
46 |
+
'''
|
47 |
+
pass
|
48 |
+
|
49 |
+
@abc.abstractmethod
|
50 |
+
def init_hidden(self):
|
51 |
+
''' initialize the hidden states
|
52 |
+
'''
|
53 |
+
pass
|
54 |
+
|
55 |
+
|
56 |
+
class GRUEncoder(EncoderBase):
|
57 |
+
|
58 |
+
def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
|
59 |
+
bidirectional: bool, dropout: float, batch_size: int, use_gpu: bool,
|
60 |
+
no_dropout=False):
|
61 |
+
super().__init__()
|
62 |
+
self.input_dim = input_dim
|
63 |
+
self.hidden_dim = hidden_dim
|
64 |
+
self.num_layers = num_layers
|
65 |
+
self.bidirectional = bidirectional
|
66 |
+
self.dropout = dropout
|
67 |
+
self.batch_size = batch_size
|
68 |
+
self.use_gpu = use_gpu
|
69 |
+
self.model = nn.GRU(input_size=self.input_dim,
|
70 |
+
hidden_size=self.hidden_dim,
|
71 |
+
num_layers=self.num_layers,
|
72 |
+
batch_first=True,
|
73 |
+
bidirectional=self.bidirectional,
|
74 |
+
dropout=self.dropout if not no_dropout else 0)
|
75 |
+
if self.use_gpu:
|
76 |
+
self.model.cuda()
|
77 |
+
self.init_hidden()
|
78 |
+
|
79 |
+
|
80 |
+
def init_hidden(self):
|
81 |
+
self.h0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
|
82 |
+
self.batch_size,
|
83 |
+
self.hidden_dim),
|
84 |
+
requires_grad=False)
|
85 |
+
if self.use_gpu:
|
86 |
+
self.h0 = self.h0.cuda()
|
87 |
+
|
88 |
+
def forward(self, in_seq_emb):
|
89 |
+
''' forward model
|
90 |
+
|
91 |
+
Parameters
|
92 |
+
----------
|
93 |
+
in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)
|
94 |
+
|
95 |
+
Returns
|
96 |
+
-------
|
97 |
+
hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
|
98 |
+
'''
|
99 |
+
max_len = in_seq_emb.size(1)
|
100 |
+
hidden_seq_emb, self.h0 = self.model(
|
101 |
+
in_seq_emb, self.h0)
|
102 |
+
hidden_seq_emb = hidden_seq_emb.view(self.batch_size,
|
103 |
+
max_len,
|
104 |
+
1 + self.bidirectional,
|
105 |
+
self.hidden_dim)
|
106 |
+
return hidden_seq_emb
|
107 |
+
|
108 |
+
|
109 |
+
class LSTMEncoder(EncoderBase):
|
110 |
+
|
111 |
+
def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
|
112 |
+
bidirectional: bool, dropout: float, batch_size: int, use_gpu: bool,
|
113 |
+
no_dropout=False):
|
114 |
+
super().__init__()
|
115 |
+
self.input_dim = input_dim
|
116 |
+
self.hidden_dim = hidden_dim
|
117 |
+
self.num_layers = num_layers
|
118 |
+
self.bidirectional = bidirectional
|
119 |
+
self.dropout = dropout
|
120 |
+
self.batch_size = batch_size
|
121 |
+
self.use_gpu = use_gpu
|
122 |
+
self.model = nn.LSTM(input_size=self.input_dim,
|
123 |
+
hidden_size=self.hidden_dim,
|
124 |
+
num_layers=self.num_layers,
|
125 |
+
batch_first=True,
|
126 |
+
bidirectional=self.bidirectional,
|
127 |
+
dropout=self.dropout if not no_dropout else 0)
|
128 |
+
if self.use_gpu:
|
129 |
+
self.model.cuda()
|
130 |
+
self.init_hidden()
|
131 |
+
|
132 |
+
def init_hidden(self):
|
133 |
+
self.h0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
|
134 |
+
self.batch_size,
|
135 |
+
self.hidden_dim),
|
136 |
+
requires_grad=False)
|
137 |
+
self.c0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
|
138 |
+
self.batch_size,
|
139 |
+
self.hidden_dim),
|
140 |
+
requires_grad=False)
|
141 |
+
if self.use_gpu:
|
142 |
+
self.h0 = self.h0.cuda()
|
143 |
+
self.c0 = self.c0.cuda()
|
144 |
+
|
145 |
+
def forward(self, in_seq_emb):
|
146 |
+
''' forward model
|
147 |
+
|
148 |
+
Parameters
|
149 |
+
----------
|
150 |
+
in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)
|
151 |
+
|
152 |
+
Returns
|
153 |
+
-------
|
154 |
+
hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
|
155 |
+
'''
|
156 |
+
max_len = in_seq_emb.size(1)
|
157 |
+
hidden_seq_emb, (self.h0, self.c0) = self.model(
|
158 |
+
in_seq_emb, (self.h0, self.c0))
|
159 |
+
hidden_seq_emb = hidden_seq_emb.view(self.batch_size,
|
160 |
+
max_len,
|
161 |
+
1 + self.bidirectional,
|
162 |
+
self.hidden_dim)
|
163 |
+
return hidden_seq_emb
|
164 |
+
|
165 |
+
|
166 |
+
class FullConnectedEncoder(EncoderBase):
|
167 |
+
|
168 |
+
def __init__(self, input_dim: int, hidden_dim: int, max_len: int, hidden_dim_list: List[int],
|
169 |
+
batch_size: int, use_gpu: bool):
|
170 |
+
super().__init__()
|
171 |
+
self.input_dim = input_dim
|
172 |
+
self.hidden_dim = hidden_dim
|
173 |
+
self.max_len = max_len
|
174 |
+
self.hidden_dim_list = hidden_dim_list
|
175 |
+
self.use_gpu = use_gpu
|
176 |
+
in_out_dim_list = [input_dim * max_len] + list(hidden_dim_list) + [hidden_dim]
|
177 |
+
self.linear_list = nn.ModuleList(
|
178 |
+
[nn.Linear(in_out_dim_list[each_idx], in_out_dim_list[each_idx + 1])\
|
179 |
+
for each_idx in range(len(in_out_dim_list) - 1)])
|
180 |
+
|
181 |
+
def forward(self, in_seq_emb):
|
182 |
+
''' forward model
|
183 |
+
|
184 |
+
Parameters
|
185 |
+
----------
|
186 |
+
in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)
|
187 |
+
|
188 |
+
Returns
|
189 |
+
-------
|
190 |
+
hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
|
191 |
+
'''
|
192 |
+
batch_size = in_seq_emb.size(0)
|
193 |
+
x = in_seq_emb.view(batch_size, -1)
|
194 |
+
for each_linear in self.linear_list:
|
195 |
+
x = F.relu(each_linear(x))
|
196 |
+
return x.view(batch_size, 1, -1)
|
197 |
+
|
198 |
+
def init_hidden(self):
|
199 |
+
pass
|
models/mhg_model/graph_grammar/nn/graph.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# Rhizome
|
4 |
+
# Version beta 0.0, August 2023
|
5 |
+
# Property of IBM Research, Accelerated Discovery
|
6 |
+
#
|
7 |
+
|
8 |
+
"""
|
9 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS)
|
10 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
""" Title """
|
15 |
+
|
16 |
+
__author__ = "Hiroshi Kajino <[email protected]>"
|
17 |
+
__copyright__ = "(c) Copyright IBM Corp. 2018"
|
18 |
+
__version__ = "0.1"
|
19 |
+
__date__ = "Jan 1 2018"
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
from graph_grammar.graph_grammar.hrg import ProductionRuleCorpus
|
25 |
+
from torch import nn
|
26 |
+
from torch.autograd import Variable
|
27 |
+
|
28 |
+
class MolecularProdRuleEmbedding(nn.Module):
|
29 |
+
|
30 |
+
''' molecular fingerprint layer
|
31 |
+
'''
|
32 |
+
|
33 |
+
def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
|
34 |
+
out_dim=32, element_embed_dim=32,
|
35 |
+
num_layers=3, padding_idx=None, use_gpu=False):
|
36 |
+
super().__init__()
|
37 |
+
if padding_idx is not None:
|
38 |
+
assert padding_idx == -1, 'padding_idx must be -1.'
|
39 |
+
self.prod_rule_corpus = prod_rule_corpus
|
40 |
+
self.layer2layer_activation = layer2layer_activation
|
41 |
+
self.layer2out_activation = layer2out_activation
|
42 |
+
self.out_dim = out_dim
|
43 |
+
self.element_embed_dim = element_embed_dim
|
44 |
+
self.num_layers = num_layers
|
45 |
+
self.padding_idx = padding_idx
|
46 |
+
self.use_gpu = use_gpu
|
47 |
+
|
48 |
+
self.layer2layer_list = []
|
49 |
+
self.layer2out_list = []
|
50 |
+
|
51 |
+
if self.use_gpu:
|
52 |
+
self.atom_embed = torch.randn(self.prod_rule_corpus.num_edge_symbol,
|
53 |
+
self.element_embed_dim, requires_grad=True).cuda()
|
54 |
+
self.bond_embed = torch.randn(self.prod_rule_corpus.num_node_symbol,
|
55 |
+
self.element_embed_dim, requires_grad=True).cuda()
|
56 |
+
self.ext_id_embed = torch.randn(self.prod_rule_corpus.num_ext_id,
|
57 |
+
self.element_embed_dim, requires_grad=True).cuda()
|
58 |
+
for _ in range(num_layers):
|
59 |
+
self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim).cuda())
|
60 |
+
self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim).cuda())
|
61 |
+
else:
|
62 |
+
self.atom_embed = torch.randn(self.prod_rule_corpus.num_edge_symbol,
|
63 |
+
self.element_embed_dim, requires_grad=True)
|
64 |
+
self.bond_embed = torch.randn(self.prod_rule_corpus.num_node_symbol,
|
65 |
+
self.element_embed_dim, requires_grad=True)
|
66 |
+
self.ext_id_embed = torch.randn(self.prod_rule_corpus.num_ext_id,
|
67 |
+
self.element_embed_dim, requires_grad=True)
|
68 |
+
for _ in range(num_layers):
|
69 |
+
self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim))
|
70 |
+
self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim))
|
71 |
+
|
72 |
+
|
73 |
+
def forward(self, prod_rule_idx_seq):
|
74 |
+
''' forward model for mini-batch
|
75 |
+
|
76 |
+
Parameters
|
77 |
+
----------
|
78 |
+
prod_rule_idx_seq : (batch_size, length)
|
79 |
+
|
80 |
+
Returns
|
81 |
+
-------
|
82 |
+
Variable, shape (batch_size, length, out_dim)
|
83 |
+
'''
|
84 |
+
batch_size, length = prod_rule_idx_seq.shape
|
85 |
+
if self.use_gpu:
|
86 |
+
out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
|
87 |
+
else:
|
88 |
+
out = Variable(torch.zeros((batch_size, length, self.out_dim)))
|
89 |
+
for each_batch_idx in range(batch_size):
|
90 |
+
for each_idx in range(length):
|
91 |
+
if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
|
92 |
+
continue
|
93 |
+
else:
|
94 |
+
each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
|
95 |
+
layer_wise_embed_dict = {each_edge: self.atom_embed[
|
96 |
+
each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
|
97 |
+
for each_edge in each_prod_rule.rhs.edges}
|
98 |
+
layer_wise_embed_dict.update({each_node: self.bond_embed[
|
99 |
+
each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]
|
100 |
+
for each_node in each_prod_rule.rhs.nodes})
|
101 |
+
for each_node in each_prod_rule.rhs.nodes:
|
102 |
+
if 'ext_id' in each_prod_rule.rhs.node_attr(each_node):
|
103 |
+
layer_wise_embed_dict[each_node] \
|
104 |
+
= layer_wise_embed_dict[each_node] \
|
105 |
+
+ self.ext_id_embed[each_prod_rule.rhs.node_attr(each_node)['ext_id']]
|
106 |
+
|
107 |
+
for each_layer in range(self.num_layers):
|
108 |
+
next_layer_embed_dict = {}
|
109 |
+
for each_edge in each_prod_rule.rhs.edges:
|
110 |
+
v = layer_wise_embed_dict[each_edge]
|
111 |
+
for each_node in each_prod_rule.rhs.nodes_in_edge(each_edge):
|
112 |
+
v = v + layer_wise_embed_dict[each_node]
|
113 |
+
next_layer_embed_dict[each_edge] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
|
114 |
+
out[each_batch_idx, each_idx, :] \
|
115 |
+
= out[each_batch_idx, each_idx, :] + self.layer2out_activation(self.layer2out_list[each_layer](v))
|
116 |
+
for each_node in each_prod_rule.rhs.nodes:
|
117 |
+
v = layer_wise_embed_dict[each_node]
|
118 |
+
for each_edge in each_prod_rule.rhs.adj_edges(each_node):
|
119 |
+
v = v + layer_wise_embed_dict[each_edge]
|
120 |
+
next_layer_embed_dict[each_node] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
|
121 |
+
out[each_batch_idx, each_idx, :]\
|
122 |
+
= out[each_batch_idx, each_idx, :] + self.layer2out_activation(self.layer2out_list[each_layer](v))
|
123 |
+
layer_wise_embed_dict = next_layer_embed_dict
|
124 |
+
|
125 |
+
return out
|
126 |
+
|
127 |
+
|
128 |
+
class MolecularProdRuleEmbeddingLastLayer(nn.Module):
|
129 |
+
|
130 |
+
''' molecular fingerprint layer
|
131 |
+
'''
|
132 |
+
|
133 |
+
def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
|
134 |
+
out_dim=32, element_embed_dim=32,
|
135 |
+
num_layers=3, padding_idx=None, use_gpu=False):
|
136 |
+
super().__init__()
|
137 |
+
if padding_idx is not None:
|
138 |
+
assert padding_idx == -1, 'padding_idx must be -1.'
|
139 |
+
self.prod_rule_corpus = prod_rule_corpus
|
140 |
+
self.layer2layer_activation = layer2layer_activation
|
141 |
+
self.layer2out_activation = layer2out_activation
|
142 |
+
self.out_dim = out_dim
|
143 |
+
self.element_embed_dim = element_embed_dim
|
144 |
+
self.num_layers = num_layers
|
145 |
+
self.padding_idx = padding_idx
|
146 |
+
self.use_gpu = use_gpu
|
147 |
+
|
148 |
+
self.layer2layer_list = []
|
149 |
+
self.layer2out_list = []
|
150 |
+
|
151 |
+
if self.use_gpu:
|
152 |
+
self.atom_embed = nn.Embedding(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim).cuda()
|
153 |
+
self.bond_embed = nn.Embedding(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim).cuda()
|
154 |
+
for _ in range(num_layers+1):
|
155 |
+
self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim).cuda())
|
156 |
+
self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim).cuda())
|
157 |
+
else:
|
158 |
+
self.atom_embed = nn.Embedding(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim)
|
159 |
+
self.bond_embed = nn.Embedding(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim)
|
160 |
+
for _ in range(num_layers+1):
|
161 |
+
self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim))
|
162 |
+
self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim))
|
163 |
+
|
164 |
+
|
165 |
+
def forward(self, prod_rule_idx_seq):
|
166 |
+
''' forward model for mini-batch
|
167 |
+
|
168 |
+
Parameters
|
169 |
+
----------
|
170 |
+
prod_rule_idx_seq : (batch_size, length)
|
171 |
+
|
172 |
+
Returns
|
173 |
+
-------
|
174 |
+
Variable, shape (batch_size, length, out_dim)
|
175 |
+
'''
|
176 |
+
batch_size, length = prod_rule_idx_seq.shape
|
177 |
+
if self.use_gpu:
|
178 |
+
out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
|
179 |
+
else:
|
180 |
+
out = Variable(torch.zeros((batch_size, length, self.out_dim)))
|
181 |
+
for each_batch_idx in range(batch_size):
|
182 |
+
for each_idx in range(length):
|
183 |
+
if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
|
184 |
+
continue
|
185 |
+
else:
|
186 |
+
each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
|
187 |
+
|
188 |
+
if self.use_gpu:
|
189 |
+
layer_wise_embed_dict = {each_edge: self.atom_embed(
|
190 |
+
Variable(torch.LongTensor(
|
191 |
+
[each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
|
192 |
+
), requires_grad=False).cuda())
|
193 |
+
for each_edge in each_prod_rule.rhs.edges}
|
194 |
+
layer_wise_embed_dict.update({each_node: self.bond_embed(
|
195 |
+
Variable(
|
196 |
+
torch.LongTensor([
|
197 |
+
each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]),
|
198 |
+
requires_grad=False).cuda()
|
199 |
+
) for each_node in each_prod_rule.rhs.nodes})
|
200 |
+
else:
|
201 |
+
layer_wise_embed_dict = {each_edge: self.atom_embed(
|
202 |
+
Variable(torch.LongTensor(
|
203 |
+
[each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
|
204 |
+
), requires_grad=False))
|
205 |
+
for each_edge in each_prod_rule.rhs.edges}
|
206 |
+
layer_wise_embed_dict.update({each_node: self.bond_embed(
|
207 |
+
Variable(
|
208 |
+
torch.LongTensor([
|
209 |
+
each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]),
|
210 |
+
requires_grad=False)
|
211 |
+
) for each_node in each_prod_rule.rhs.nodes})
|
212 |
+
|
213 |
+
for each_layer in range(self.num_layers):
|
214 |
+
next_layer_embed_dict = {}
|
215 |
+
for each_edge in each_prod_rule.rhs.edges:
|
216 |
+
v = layer_wise_embed_dict[each_edge]
|
217 |
+
for each_node in each_prod_rule.rhs.nodes_in_edge(each_edge):
|
218 |
+
v += layer_wise_embed_dict[each_node]
|
219 |
+
next_layer_embed_dict[each_edge] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
|
220 |
+
for each_node in each_prod_rule.rhs.nodes:
|
221 |
+
v = layer_wise_embed_dict[each_node]
|
222 |
+
for each_edge in each_prod_rule.rhs.adj_edges(each_node):
|
223 |
+
v += layer_wise_embed_dict[each_edge]
|
224 |
+
next_layer_embed_dict[each_node] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
|
225 |
+
layer_wise_embed_dict = next_layer_embed_dict
|
226 |
+
for each_edge in each_prod_rule.rhs.edges:
|
227 |
+
out[each_batch_idx, each_idx, :] = self.layer2out_activation(self.layer2out_list[self.num_layers](v))
|
228 |
+
for each_edge in each_prod_rule.rhs.edges:
|
229 |
+
out[each_batch_idx, each_idx, :] = self.layer2out_activation(self.layer2out_list[self.num_layers](v))
|
230 |
+
|
231 |
+
return out
|
232 |
+
|
233 |
+
|
234 |
+
class MolecularProdRuleEmbeddingUsingFeatures(nn.Module):
|
235 |
+
|
236 |
+
''' molecular fingerprint layer
|
237 |
+
'''
|
238 |
+
|
239 |
+
def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
|
240 |
+
out_dim=32, num_layers=3, padding_idx=None, use_gpu=False):
|
241 |
+
super().__init__()
|
242 |
+
if padding_idx is not None:
|
243 |
+
assert padding_idx == -1, 'padding_idx must be -1.'
|
244 |
+
self.feature_dict, self.feature_dim = prod_rule_corpus.construct_feature_vectors()
|
245 |
+
self.prod_rule_corpus = prod_rule_corpus
|
246 |
+
self.layer2layer_activation = layer2layer_activation
|
247 |
+
self.layer2out_activation = layer2out_activation
|
248 |
+
self.out_dim = out_dim
|
249 |
+
self.num_layers = num_layers
|
250 |
+
self.padding_idx = padding_idx
|
251 |
+
self.use_gpu = use_gpu
|
252 |
+
|
253 |
+
self.layer2layer_list = []
|
254 |
+
self.layer2out_list = []
|
255 |
+
|
256 |
+
if self.use_gpu:
|
257 |
+
for each_key in self.feature_dict:
|
258 |
+
self.feature_dict[each_key] = self.feature_dict[each_key].to_dense().cuda()
|
259 |
+
for _ in range(num_layers):
|
260 |
+
self.layer2layer_list.append(nn.Linear(self.feature_dim, self.feature_dim).cuda())
|
261 |
+
self.layer2out_list.append(nn.Linear(self.feature_dim, self.out_dim).cuda())
|
262 |
+
else:
|
263 |
+
for _ in range(num_layers):
|
264 |
+
self.layer2layer_list.append(nn.Linear(self.feature_dim, self.feature_dim))
|
265 |
+
self.layer2out_list.append(nn.Linear(self.feature_dim, self.out_dim))
|
266 |
+
|
267 |
+
|
268 |
+
def forward(self, prod_rule_idx_seq):
|
269 |
+
''' forward model for mini-batch
|
270 |
+
|
271 |
+
Parameters
|
272 |
+
----------
|
273 |
+
prod_rule_idx_seq : (batch_size, length)
|
274 |
+
|
275 |
+
Returns
|
276 |
+
-------
|
277 |
+
Variable, shape (batch_size, length, out_dim)
|
278 |
+
'''
|
279 |
+
batch_size, length = prod_rule_idx_seq.shape
|
280 |
+
if self.use_gpu:
|
281 |
+
out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda()
|
282 |
+
else:
|
283 |
+
out = Variable(torch.zeros((batch_size, length, self.out_dim)))
|
284 |
+
for each_batch_idx in range(batch_size):
|
285 |
+
for each_idx in range(length):
|
286 |
+
if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list):
|
287 |
+
continue
|
288 |
+
else:
|
289 |
+
each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])]
|
290 |
+
edge_list = sorted(list(each_prod_rule.rhs.edges))
|
291 |
+
node_list = sorted(list(each_prod_rule.rhs.nodes))
|
292 |
+
adj_mat = torch.FloatTensor(each_prod_rule.rhs_adj_mat(edge_list + node_list).todense() + np.identity(len(edge_list)+len(node_list)))
|
293 |
+
if self.use_gpu:
|
294 |
+
adj_mat = adj_mat.cuda()
|
295 |
+
layer_wise_embed = [
|
296 |
+
self.feature_dict[each_prod_rule.rhs.edge_attr(each_edge)['symbol']]
|
297 |
+
for each_edge in edge_list]\
|
298 |
+
+ [self.feature_dict[each_prod_rule.rhs.node_attr(each_node)['symbol']]
|
299 |
+
for each_node in node_list]
|
300 |
+
for each_node in each_prod_rule.ext_node.values():
|
301 |
+
layer_wise_embed[each_prod_rule.rhs.num_edges + node_list.index(each_node)] \
|
302 |
+
= layer_wise_embed[each_prod_rule.rhs.num_edges + node_list.index(each_node)] \
|
303 |
+
+ self.feature_dict[('ext_id', each_prod_rule.rhs.node_attr(each_node)['ext_id'])]
|
304 |
+
layer_wise_embed = torch.stack(layer_wise_embed)
|
305 |
+
|
306 |
+
for each_layer in range(self.num_layers):
|
307 |
+
message = adj_mat @ layer_wise_embed
|
308 |
+
next_layer_embed = self.layer2layer_activation(self.layer2layer_list[each_layer](message))
|
309 |
+
out[each_batch_idx, each_idx, :] \
|
310 |
+
= out[each_batch_idx, each_idx, :] \
|
311 |
+
+ self.layer2out_activation(self.layer2out_list[each_layer](message)).sum(dim=0)
|
312 |
+
layer_wise_embed = next_layer_embed
|
313 |
+
return out
|
models/mhg_model/load.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# Rhizome
|
3 |
+
# Version beta 0.0, August 2023
|
4 |
+
# Property of IBM Research, Accelerated Discovery
|
5 |
+
#
|
6 |
+
|
7 |
+
import os
|
8 |
+
import pickle
|
9 |
+
import sys
|
10 |
+
|
11 |
+
from rdkit import Chem
|
12 |
+
import torch
|
13 |
+
from torch_geometric.utils.smiles import from_smiles
|
14 |
+
|
15 |
+
from typing import Any, Dict, List, Optional, Union
|
16 |
+
from typing_extensions import Self
|
17 |
+
|
18 |
+
from .graph_grammar.io.smi import hg_to_mol
|
19 |
+
from .models.mhgvae import GrammarGINVAE
|
20 |
+
|
21 |
+
from huggingface_hub import hf_hub_download
|
22 |
+
|
23 |
+
|
24 |
+
class PretrainedModelWrapper:
|
25 |
+
model: GrammarGINVAE
|
26 |
+
|
27 |
+
def __init__(self, model_dict: Dict[str, Any]) -> None:
|
28 |
+
json_params = model_dict['gnn_params']
|
29 |
+
encoder_params = json_params['encoder_params']
|
30 |
+
encoder_params['node_feature_size'] = model_dict['num_features']
|
31 |
+
encoder_params['edge_feature_size'] = model_dict['num_edge_features']
|
32 |
+
self.model = GrammarGINVAE(model_dict['hrg'], rank=-1, encoder_params=encoder_params,
|
33 |
+
decoder_params=json_params['decoder_params'],
|
34 |
+
prod_rule_embed_params=json_params["prod_rule_embed_params"],
|
35 |
+
batch_size=512, max_len=model_dict['max_length'])
|
36 |
+
self.model.load_state_dict(model_dict['model_state_dict'])
|
37 |
+
|
38 |
+
self.model.eval()
|
39 |
+
|
40 |
+
def to(self, device: Union[str, int, torch.device]) -> Self:
|
41 |
+
dev_type = type(device)
|
42 |
+
if dev_type != torch.device:
|
43 |
+
if dev_type == str or torch.cuda.is_available():
|
44 |
+
device = torch.device(device)
|
45 |
+
else:
|
46 |
+
device = torch.device("mps", device)
|
47 |
+
|
48 |
+
self.model = self.model.to(device)
|
49 |
+
return self
|
50 |
+
|
51 |
+
def encode(self, data: List[str]) -> List[torch.tensor]:
|
52 |
+
# Need to encode them into a graph nn
|
53 |
+
output = []
|
54 |
+
for d in data:
|
55 |
+
params = next(self.model.parameters())
|
56 |
+
g = from_smiles(d)
|
57 |
+
if (g.cpu() and params != 'cpu') or (not g.cpu() and params == 'cpu'):
|
58 |
+
g.to(params.device)
|
59 |
+
ltvec = self.model.graph_embed(g.x, g.edge_index, g.edge_attr, g.batch)
|
60 |
+
output.append(ltvec[0])
|
61 |
+
return output
|
62 |
+
|
63 |
+
def decode(self, data: List[torch.tensor]) -> List[str]:
|
64 |
+
output = []
|
65 |
+
for d in data:
|
66 |
+
mu, logvar = self.model.get_mean_var(d.unsqueeze(0))
|
67 |
+
z = self.model.reparameterize(mu, logvar)
|
68 |
+
flags, _, hgs = self.model.decode(z)
|
69 |
+
if flags[0]:
|
70 |
+
reconstructed_mol, _ = hg_to_mol(hgs[0], True)
|
71 |
+
output.append(Chem.MolToSmiles(reconstructed_mol))
|
72 |
+
else:
|
73 |
+
output.append(None)
|
74 |
+
return output
|
75 |
+
|
76 |
+
|
77 |
+
def load(model_name: str = "mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") -> Optional[
|
78 |
+
PretrainedModelWrapper]:
|
79 |
+
|
80 |
+
repo_id = "ibm/materials.mhg-ged"
|
81 |
+
filename = "pytorch_model.bin" #"mhggnn_pretrained_model_0724_2023.pickle"
|
82 |
+
file_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
83 |
+
with open(file_path, "rb") as f:
|
84 |
+
model_dict = torch.load(f, weights_only=False)
|
85 |
+
return PretrainedModelWrapper(model_dict)
|
86 |
+
|
87 |
+
|
88 |
+
"""try:
|
89 |
+
if os.path.isfile(model_name):
|
90 |
+
with open(model_name, "rb") as f:
|
91 |
+
model_dict = pickle.load(f)
|
92 |
+
print("MHG Model Loaded")
|
93 |
+
return PretrainedModelWrapper(model_dict)
|
94 |
+
|
95 |
+
except:
|
96 |
+
|
97 |
+
for p in sys.path:
|
98 |
+
file = p + "/" + model_name
|
99 |
+
if os.path.isfile(file):
|
100 |
+
with open(file, "rb") as f:
|
101 |
+
model_dict = pickle.load(f)
|
102 |
+
return PretrainedModelWrapper(model_dict)"""
|
103 |
+
return None
|
models/mhg_model/mhg_gnn.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: mhg-gnn
|
3 |
+
Version: 0.0
|
4 |
+
Summary: Package for mhg-gnn
|
5 |
+
Author: team
|
6 |
+
License: TBD
|
7 |
+
Classifier: Programming Language :: Python :: 3
|
8 |
+
Classifier: Programming Language :: Python :: 3.9
|
9 |
+
Description-Content-Type: text/markdown
|
10 |
+
Requires-Dist: networkx>=2.8
|
11 |
+
Requires-Dist: numpy<2.0.0,>=1.23.5
|
12 |
+
Requires-Dist: pandas>=1.5.3
|
13 |
+
Requires-Dist: rdkit-pypi<2023.9.6,>=2022.9.4
|
14 |
+
Requires-Dist: torch>=2.0.0
|
15 |
+
Requires-Dist: torchinfo>=1.8.0
|
16 |
+
Requires-Dist: torch-geometric>=2.3.1
|
17 |
+
|
18 |
+
# mhg-gnn
|
19 |
+
|
20 |
+
This repository provides PyTorch source code assosiated with our publication, "MHG-GNN: Combination of Molecular Hypergraph Grammar with Graph Neural Network"
|
21 |
+
|
22 |
+
**Paper:** [Arxiv Link](https://arxiv.org/pdf/2309.16374)
|
23 |
+
|
24 |
+
For more information contact: [email protected]
|
25 |
+
|
26 |
+

|
27 |
+
|
28 |
+
## Introduction
|
29 |
+
|
30 |
+
We present MHG-GNN, an autoencoder architecture
|
31 |
+
that has an encoder based on GNN and a decoder based on a sequential model with MHG.
|
32 |
+
Since the encoder is a GNN variant, MHG-GNN can accept any molecule as input, and
|
33 |
+
demonstrate high predictive performance on molecular graph data.
|
34 |
+
In addition, the decoder inherits the theoretical guarantee of MHG on always generating a structurally valid molecule as output.
|
35 |
+
|
36 |
+
## Table of Contents
|
37 |
+
|
38 |
+
1. [Getting Started](#getting-started)
|
39 |
+
1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs)
|
40 |
+
2. [Replicating Conda Environment](#replicating-conda-environment)
|
41 |
+
2. [Feature Extraction](#feature-extraction)
|
42 |
+
|
43 |
+
## Getting Started
|
44 |
+
|
45 |
+
**This code and environment have been tested on Intel E5-2667 CPUs at 3.30GHz and NVIDIA A100 Tensor Core GPUs.**
|
46 |
+
|
47 |
+
### Pretrained Models and Training Logs
|
48 |
+
|
49 |
+
We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link]()
|
50 |
+
|
51 |
+
Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs.
|
52 |
+
|
53 |
+
### Replacicating Conda Environment
|
54 |
+
|
55 |
+
Follow these steps to replicate our Conda environment and install the necessary libraries:
|
56 |
+
|
57 |
+
```
|
58 |
+
conda create --name mhg-gnn-env python=3.8.18
|
59 |
+
conda activate mhg-gnn-env
|
60 |
+
```
|
61 |
+
|
62 |
+
#### Install Packages with Conda
|
63 |
+
|
64 |
+
```
|
65 |
+
conda install -c conda-forge networkx=2.8
|
66 |
+
conda install numpy=1.23.5
|
67 |
+
# conda install -c conda-forge rdkit=2022.9.4
|
68 |
+
conda install pytorch=2.0.0 torchvision torchaudio -c pytorch
|
69 |
+
conda install -c conda-forge torchinfo=1.8.0
|
70 |
+
conda install pyg -c pyg
|
71 |
+
```
|
72 |
+
|
73 |
+
#### Install Packages with pip
|
74 |
+
```
|
75 |
+
pip install rdkit torch-nl==0.3 torch-scatter torch-sparse
|
76 |
+
```
|
77 |
+
|
78 |
+
## Feature Extraction
|
79 |
+
|
80 |
+
The example notebook [mhg-gnn_encoder_decoder_example.ipynb](notebooks/mhg-gnn_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks.
|
81 |
+
|
82 |
+
To load mhg-gnn, you can simply use:
|
83 |
+
|
84 |
+
```python
|
85 |
+
import torch
|
86 |
+
import load
|
87 |
+
|
88 |
+
model = load.load()
|
89 |
+
```
|
90 |
+
|
91 |
+
To encode SMILES into embeddings, you can use:
|
92 |
+
|
93 |
+
```python
|
94 |
+
with torch.no_grad():
|
95 |
+
repr = model.encode(["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"])
|
96 |
+
```
|
97 |
+
|
98 |
+
For decoder, you can use the function, so you can return from embeddings to SMILES strings:
|
99 |
+
|
100 |
+
```python
|
101 |
+
orig = model.decode(repr)
|
102 |
+
```
|
models/mhg_model/mhg_gnn.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
README.md
|
2 |
+
setup.cfg
|
3 |
+
setup.py
|
4 |
+
./graph_grammar/__init__.py
|
5 |
+
./graph_grammar/hypergraph.py
|
6 |
+
./graph_grammar/algo/__init__.py
|
7 |
+
./graph_grammar/algo/tree_decomposition.py
|
8 |
+
./graph_grammar/graph_grammar/__init__.py
|
9 |
+
./graph_grammar/graph_grammar/base.py
|
10 |
+
./graph_grammar/graph_grammar/corpus.py
|
11 |
+
./graph_grammar/graph_grammar/hrg.py
|
12 |
+
./graph_grammar/graph_grammar/symbols.py
|
13 |
+
./graph_grammar/graph_grammar/utils.py
|
14 |
+
./graph_grammar/io/__init__.py
|
15 |
+
./graph_grammar/io/smi.py
|
16 |
+
./graph_grammar/nn/__init__.py
|
17 |
+
./graph_grammar/nn/dataset.py
|
18 |
+
./graph_grammar/nn/decoder.py
|
19 |
+
./graph_grammar/nn/encoder.py
|
20 |
+
./graph_grammar/nn/graph.py
|
21 |
+
./models/__init__.py
|
22 |
+
./models/mhgvae.py
|
23 |
+
graph_grammar/__init__.py
|
24 |
+
graph_grammar/hypergraph.py
|
25 |
+
graph_grammar/algo/__init__.py
|
26 |
+
graph_grammar/algo/tree_decomposition.py
|
27 |
+
graph_grammar/graph_grammar/__init__.py
|
28 |
+
graph_grammar/graph_grammar/base.py
|
29 |
+
graph_grammar/graph_grammar/corpus.py
|
30 |
+
graph_grammar/graph_grammar/hrg.py
|
31 |
+
graph_grammar/graph_grammar/symbols.py
|
32 |
+
graph_grammar/graph_grammar/utils.py
|
33 |
+
graph_grammar/io/__init__.py
|
34 |
+
graph_grammar/io/smi.py
|
35 |
+
graph_grammar/nn/__init__.py
|
36 |
+
graph_grammar/nn/dataset.py
|
37 |
+
graph_grammar/nn/decoder.py
|
38 |
+
graph_grammar/nn/encoder.py
|
39 |
+
graph_grammar/nn/graph.py
|
40 |
+
mhg_gnn.egg-info/PKG-INFO
|
41 |
+
mhg_gnn.egg-info/SOURCES.txt
|
42 |
+
mhg_gnn.egg-info/dependency_links.txt
|
43 |
+
mhg_gnn.egg-info/requires.txt
|
44 |
+
mhg_gnn.egg-info/top_level.txt
|
45 |
+
models/__init__.py
|
46 |
+
models/mhgvae.py
|
models/mhg_model/mhg_gnn.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
models/mhg_model/mhg_gnn.egg-info/requires.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
networkx>=2.8
|
2 |
+
numpy<2.0.0,>=1.23.5
|
3 |
+
pandas>=1.5.3
|
4 |
+
rdkit-pypi<2023.9.6,>=2022.9.4
|
5 |
+
torch>=2.0.0
|
6 |
+
torchinfo>=1.8.0
|
7 |
+
torch-geometric>=2.3.1
|
models/mhg_model/mhg_gnn.egg-info/top_level.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
graph_grammar
|
2 |
+
models
|
models/mhg_model/models/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# Rhizome
|
3 |
+
# Version beta 0.0, August 2023
|
4 |
+
# Property of IBM Research, Accelerated Discovery
|
5 |
+
#
|
models/mhg_model/models/mhgvae.py
ADDED
@@ -0,0 +1,956 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding:utf-8 -*-
|
2 |
+
# Rhizome
|
3 |
+
# Version beta 0.0, August 2023
|
4 |
+
# Property of IBM Research, Accelerated Discovery
|
5 |
+
#
|
6 |
+
|
7 |
+
"""
|
8 |
+
PLEASE NOTE THIS IMPLEMENTATION INCLUDES ADAPTED SOURCE CODE
|
9 |
+
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE,
|
10 |
+
E.G., GRUEncoder/GRUDecoder, GrammarSeq2SeqVAE AND EVEN SOME METHODS OF GrammarGINVAE.
|
11 |
+
THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE.
|
12 |
+
"""
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import logging
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch.autograd import Variable
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from torch.nn.modules.loss import _Loss
|
22 |
+
|
23 |
+
from torch_geometric.nn import MessagePassing
|
24 |
+
from torch_geometric.nn import global_add_pool
|
25 |
+
|
26 |
+
|
27 |
+
from ..graph_grammar.graph_grammar.symbols import NTSymbol
|
28 |
+
from ..graph_grammar.nn.encoder import EncoderBase
|
29 |
+
from ..graph_grammar.nn.decoder import DecoderBase
|
30 |
+
|
31 |
+
def get_atom_edge_feature_dims():
|
32 |
+
from torch_geometric.utils.smiles import x_map, e_map
|
33 |
+
func = lambda x: len(x[1])
|
34 |
+
return list(map(func, x_map.items())), list(map(func, e_map.items()))
|
35 |
+
|
36 |
+
|
37 |
+
class FeatureEmbedding(nn.Module):
|
38 |
+
def __init__(self, input_dims, embedded_dim):
|
39 |
+
super().__init__()
|
40 |
+
self.embedding_list = nn.ModuleList()
|
41 |
+
for dim in input_dims:
|
42 |
+
embedding = nn.Embedding(dim, embedded_dim)
|
43 |
+
self.embedding_list.append(embedding)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
output = 0
|
47 |
+
for i in range(x.shape[1]):
|
48 |
+
input = x[:, i].to(torch.int)
|
49 |
+
device = next(self.parameters()).device
|
50 |
+
if device != input.device:
|
51 |
+
input = input.to(device)
|
52 |
+
emb = self.embedding_list[i](input)
|
53 |
+
output += emb
|
54 |
+
return output
|
55 |
+
|
56 |
+
|
57 |
+
class GRUEncoder(EncoderBase):
|
58 |
+
|
59 |
+
def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
|
60 |
+
bidirectional: bool, dropout: float, batch_size: int, rank: int=-1,
|
61 |
+
no_dropout: bool=False):
|
62 |
+
super().__init__()
|
63 |
+
self.input_dim = input_dim
|
64 |
+
self.hidden_dim = hidden_dim
|
65 |
+
self.num_layers = num_layers
|
66 |
+
self.bidirectional = bidirectional
|
67 |
+
self.dropout = dropout
|
68 |
+
self.batch_size = batch_size
|
69 |
+
self.rank = rank
|
70 |
+
self.model = nn.GRU(input_size=self.input_dim,
|
71 |
+
hidden_size=self.hidden_dim,
|
72 |
+
num_layers=self.num_layers,
|
73 |
+
batch_first=True,
|
74 |
+
bidirectional=self.bidirectional,
|
75 |
+
dropout=self.dropout if not no_dropout else 0)
|
76 |
+
if self.rank >= 0:
|
77 |
+
if torch.cuda.is_available():
|
78 |
+
self.model = self.model.to(rank)
|
79 |
+
else:
|
80 |
+
# support mac mps
|
81 |
+
self.model = self.model.to(torch.device("mps", rank))
|
82 |
+
self.init_hidden(self.batch_size)
|
83 |
+
|
84 |
+
def init_hidden(self, bsize):
|
85 |
+
self.h0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
|
86 |
+
min(self.batch_size, bsize),
|
87 |
+
self.hidden_dim),
|
88 |
+
requires_grad=False)
|
89 |
+
if self.rank >= 0:
|
90 |
+
if torch.cuda.is_available():
|
91 |
+
self.h0 = self.h0.to(self.rank)
|
92 |
+
else:
|
93 |
+
# support mac mps
|
94 |
+
self.h0 = self.h0.to(torch.device("mps", self.rank))
|
95 |
+
|
96 |
+
def to(self, device):
|
97 |
+
newself = super().to(device)
|
98 |
+
newself.model = newself.model.to(device)
|
99 |
+
newself.h0 = newself.h0.to(device)
|
100 |
+
newself.rank = next(newself.parameters()).get_device()
|
101 |
+
return newself
|
102 |
+
|
103 |
+
def forward(self, in_seq_emb):
|
104 |
+
''' forward model
|
105 |
+
|
106 |
+
Parameters
|
107 |
+
----------
|
108 |
+
in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)
|
109 |
+
|
110 |
+
Returns
|
111 |
+
-------
|
112 |
+
hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
|
113 |
+
'''
|
114 |
+
# Kishi: I think original MHG had this init_hidden()
|
115 |
+
self.init_hidden(in_seq_emb.size(0))
|
116 |
+
max_len = in_seq_emb.size(1)
|
117 |
+
hidden_seq_emb, self.h0 = self.model(
|
118 |
+
in_seq_emb, self.h0)
|
119 |
+
# As shown as returns, convert hidden_seq_emb: (batch_size, seq_len, (1 or 2) * hidden_size) -->
|
120 |
+
# (batch_size, seq_len, 1 or 2, hidden_size)
|
121 |
+
# In the original input the original GRU/LSTM with bidirectional encoding
|
122 |
+
# has contactinated tensors
|
123 |
+
# (first half for forward RNN, latter half for backward RNN)
|
124 |
+
# so convert them in a more friendly format packed for each RNN
|
125 |
+
hidden_seq_emb = hidden_seq_emb.view(-1,
|
126 |
+
max_len,
|
127 |
+
1 + self.bidirectional,
|
128 |
+
self.hidden_dim)
|
129 |
+
return hidden_seq_emb
|
130 |
+
|
131 |
+
|
132 |
+
class GRUDecoder(DecoderBase):
|
133 |
+
|
134 |
+
def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
|
135 |
+
dropout: float, batch_size: int, rank: int=-1,
|
136 |
+
no_dropout: bool=False):
|
137 |
+
super().__init__()
|
138 |
+
self.input_dim = input_dim
|
139 |
+
self.hidden_dim = hidden_dim
|
140 |
+
self.num_layers = num_layers
|
141 |
+
self.dropout = dropout
|
142 |
+
self.batch_size = batch_size
|
143 |
+
self.rank = rank
|
144 |
+
self.model = nn.GRU(input_size=self.input_dim,
|
145 |
+
hidden_size=self.hidden_dim,
|
146 |
+
num_layers=self.num_layers,
|
147 |
+
batch_first=True,
|
148 |
+
bidirectional=False,
|
149 |
+
dropout=self.dropout if not no_dropout else 0
|
150 |
+
)
|
151 |
+
if self.rank >= 0:
|
152 |
+
if torch.cuda.is_available():
|
153 |
+
self.model = self.model.to(self.rank)
|
154 |
+
else:
|
155 |
+
# support mac mps
|
156 |
+
self.model = self.model.to(torch.device("mps", self.rank))
|
157 |
+
self.init_hidden(self.batch_size)
|
158 |
+
|
159 |
+
def init_hidden(self, bsize):
|
160 |
+
self.hidden_dict['h'] = torch.zeros((self.num_layers,
|
161 |
+
min(self.batch_size, bsize),
|
162 |
+
self.hidden_dim),
|
163 |
+
requires_grad=False)
|
164 |
+
if self.rank >= 0:
|
165 |
+
if torch.cuda.is_available():
|
166 |
+
self.hidden_dict['h'] = self.hidden_dict['h'].to(self.rank)
|
167 |
+
else:
|
168 |
+
self.hidden_dict['h'] = self.hidden_dict['h'].to(torch.device("mps", self.rank))
|
169 |
+
|
170 |
+
def to(self, device):
|
171 |
+
newself = super().to(device)
|
172 |
+
newself.model = newself.model.to(device)
|
173 |
+
for k in self.hidden_dict.keys():
|
174 |
+
newself.hidden_dict[k] = newself.hidden_dict[k].to(device)
|
175 |
+
newself.rank = next(newself.parameters()).get_device()
|
176 |
+
return newself
|
177 |
+
|
178 |
+
def forward_one_step(self, tgt_emb_in):
|
179 |
+
''' one-step forward model
|
180 |
+
|
181 |
+
Parameters
|
182 |
+
----------
|
183 |
+
tgt_emb_in : Tensor, shape (batch_size, input_dim)
|
184 |
+
|
185 |
+
Returns
|
186 |
+
-------
|
187 |
+
Tensor, shape (batch_size, hidden_dim)
|
188 |
+
'''
|
189 |
+
bsize = tgt_emb_in.size(0)
|
190 |
+
tgt_emb_out, self.hidden_dict['h'] \
|
191 |
+
= self.model(tgt_emb_in.view(bsize, 1, -1),
|
192 |
+
self.hidden_dict['h'])
|
193 |
+
return tgt_emb_out
|
194 |
+
|
195 |
+
|
196 |
+
class NodeMLP(nn.Module):
|
197 |
+
def __init__(self, input_size, output_size, hidden_size):
|
198 |
+
super().__init__()
|
199 |
+
self.lin1 = nn.Linear(input_size, hidden_size)
|
200 |
+
self.nbat = nn.BatchNorm1d(hidden_size)
|
201 |
+
self.lin2 = nn.Linear(hidden_size, output_size)
|
202 |
+
|
203 |
+
def forward(self, x):
|
204 |
+
x = self.lin1(x)
|
205 |
+
x = self.nbat(x)
|
206 |
+
x = x.relu()
|
207 |
+
x = self.lin2(x)
|
208 |
+
return x
|
209 |
+
|
210 |
+
|
211 |
+
class GINLayer(MessagePassing):
|
212 |
+
def __init__(self, node_input_size, node_output_size, node_hidden_size, edge_input_size):
|
213 |
+
super().__init__()
|
214 |
+
self.node_mlp = NodeMLP(node_input_size, node_output_size, node_hidden_size)
|
215 |
+
self.edge_mlp = FeatureEmbedding(edge_input_size, node_output_size)
|
216 |
+
self.eps = nn.Parameter(torch.tensor([0.0]))
|
217 |
+
|
218 |
+
def forward(self, x, edge_index, edge_attr):
|
219 |
+
msg = self.propagate(edge_index, x=x ,edge_attr=edge_attr)
|
220 |
+
x = (1.0 + self.eps) * x + msg
|
221 |
+
x = x.relu()
|
222 |
+
x = self.node_mlp(x)
|
223 |
+
return x
|
224 |
+
|
225 |
+
def message(self, x_j, edge_attr):
|
226 |
+
edge_attr = self.edge_mlp(edge_attr)
|
227 |
+
x_j = x_j + edge_attr
|
228 |
+
x_j = x_j.relu()
|
229 |
+
return x_j
|
230 |
+
|
231 |
+
def update(self, aggr_out):
|
232 |
+
return aggr_out
|
233 |
+
|
234 |
+
#TODO implement the case where features of atoms and edges are considered
|
235 |
+
# Check GraphMVP and ogb (open graph benchmark) to realize this
|
236 |
+
class GIN(torch.nn.Module):
|
237 |
+
def __init__(self, node_feature_size, edge_feature_size, hidden_channels=64,
|
238 |
+
proximity_size=3, dropout=0.1):
|
239 |
+
super().__init__()
|
240 |
+
#print("(num node features, num edge features)=", (node_feature_size, edge_feature_size))
|
241 |
+
hsize = hidden_channels * 2
|
242 |
+
atom_dim, edge_dim = get_atom_edge_feature_dims()
|
243 |
+
self.trans = FeatureEmbedding(atom_dim, hidden_channels)
|
244 |
+
ml = []
|
245 |
+
for _ in range(proximity_size):
|
246 |
+
ml.append(GINLayer(hidden_channels, hidden_channels, hsize, edge_dim))
|
247 |
+
self.mlist = nn.ModuleList(ml)
|
248 |
+
#It is possible to calculate relu with x.relu() where x is an output
|
249 |
+
#self.activations = nn.ModuleList(actl)
|
250 |
+
self.dropout = dropout
|
251 |
+
self.proximity_size = proximity_size
|
252 |
+
|
253 |
+
def forward(self, x, edge_index, edge_attr, batch_size):
|
254 |
+
x = x.to(torch.float)
|
255 |
+
#print("before: edge_weight.shape=", edge_attr.shape)
|
256 |
+
edge_attr = edge_attr.to(torch.float)
|
257 |
+
#print("after: edge_weight.shape=", edge_attr.shape)
|
258 |
+
x = self.trans(x)
|
259 |
+
# TODO Check if this x is consistent with global_add_pool
|
260 |
+
hlist = [global_add_pool(x, batch_size)]
|
261 |
+
for id, m in enumerate(self.mlist):
|
262 |
+
x = m(x, edge_index=edge_index, edge_attr=edge_attr)
|
263 |
+
#print("Done with one layer")
|
264 |
+
###if id != self.proximity_size - 1:
|
265 |
+
x = x.relu()
|
266 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
267 |
+
#h = global_mean_pool(x, batch_size)
|
268 |
+
h = global_add_pool(x, batch_size)
|
269 |
+
hlist.append(h)
|
270 |
+
#print("Done with one relu call: x.shape=", x.shape)
|
271 |
+
#print("calling golbal mean pool")
|
272 |
+
#print("calling dropout x.shape=", x.shape)
|
273 |
+
#print("x=", x)
|
274 |
+
#print("hlist[0].shape=", hlist[0].shape)
|
275 |
+
x = torch.cat(hlist, dim=1)
|
276 |
+
#print("x.shape=", x.shape)
|
277 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
278 |
+
|
279 |
+
return x
|
280 |
+
|
281 |
+
|
282 |
+
# TODO copied from MHG implementation and adapted here.
|
283 |
+
class GrammarSeq2SeqVAE(nn.Module):
|
284 |
+
|
285 |
+
'''
|
286 |
+
Variational seq2seq with grammar.
|
287 |
+
TODO: rewrite this class using mixin
|
288 |
+
'''
|
289 |
+
|
290 |
+
def __init__(self, hrg, rank=-1, latent_dim=64, max_len=80,
|
291 |
+
batch_size=64, padding_idx=-1,
|
292 |
+
encoder_params={'hidden_dim': 384, 'num_layers': 3, 'bidirectional': True,
|
293 |
+
'dropout': 0.1},
|
294 |
+
decoder_params={'hidden_dim': 384, #'num_layers': 2,
|
295 |
+
'num_layers': 3,
|
296 |
+
'dropout': 0.1},
|
297 |
+
prod_rule_embed_params={'out_dim': 128},
|
298 |
+
no_dropout=False):
|
299 |
+
|
300 |
+
super().__init__()
|
301 |
+
# TODO USE GRU FOR ENCODING AND DECODING
|
302 |
+
self.hrg = hrg
|
303 |
+
self.rank = rank
|
304 |
+
self.prod_rule_corpus = hrg.prod_rule_corpus
|
305 |
+
self.prod_rule_embed_params = prod_rule_embed_params
|
306 |
+
|
307 |
+
self.vocab_size = hrg.num_prod_rule + 1
|
308 |
+
self.batch_size = batch_size
|
309 |
+
self.padding_idx = np.mod(padding_idx, self.vocab_size)
|
310 |
+
self.no_dropout = no_dropout
|
311 |
+
|
312 |
+
self.latent_dim = latent_dim
|
313 |
+
self.max_len = max_len
|
314 |
+
self.encoder_params = encoder_params
|
315 |
+
self.decoder_params = decoder_params
|
316 |
+
|
317 |
+
# TODO Simple embedding is used. Check if a domain-dependent embedding works or not.
|
318 |
+
embed_out_dim = self.prod_rule_embed_params['out_dim']
|
319 |
+
#use MolecularProdRuleEmbedding later on
|
320 |
+
self.src_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
|
321 |
+
padding_idx=self.padding_idx)
|
322 |
+
self.tgt_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
|
323 |
+
padding_idx=self.padding_idx)
|
324 |
+
|
325 |
+
# USE a GRU-based encoder in MHG
|
326 |
+
self.encoder = GRUEncoder(input_dim=embed_out_dim, batch_size=self.batch_size,
|
327 |
+
rank=self.rank, no_dropout=self.no_dropout,
|
328 |
+
**self.encoder_params)
|
329 |
+
|
330 |
+
lin_dim = (self.encoder_params.get('bidirectional', False) + 1) * self.encoder_params['hidden_dim']
|
331 |
+
lin_out_dim = self.latent_dim
|
332 |
+
self.hidden2mean = nn.Linear(lin_dim, lin_out_dim, bias=False)
|
333 |
+
self.hidden2logvar = nn.Linear(lin_dim, lin_out_dim)
|
334 |
+
|
335 |
+
# USE a GRU-based decoder in MHG
|
336 |
+
self.decoder = GRUDecoder(input_dim=embed_out_dim, batch_size=self.batch_size,
|
337 |
+
rank=self.rank, no_dropout=self.no_dropout, **self.decoder_params)
|
338 |
+
self.latent2tgt_emb = nn.Linear(self.latent_dim, embed_out_dim)
|
339 |
+
self.latent2hidden_dict = nn.ModuleDict()
|
340 |
+
dec_lin_out_dim = self.decoder_params['hidden_dim']
|
341 |
+
for each_hidden in self.decoder.hidden_dict.keys():
|
342 |
+
self.latent2hidden_dict[each_hidden] = nn.Linear(self.latent_dim, dec_lin_out_dim)
|
343 |
+
if self.rank >= 0:
|
344 |
+
if torch.cuda.is_available():
|
345 |
+
self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(self.rank)
|
346 |
+
else:
|
347 |
+
# support mac mps
|
348 |
+
self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(torch.device("mps", self.rank))
|
349 |
+
|
350 |
+
self.dec2vocab = nn.Linear(dec_lin_out_dim, self.vocab_size)
|
351 |
+
self.encoder.init_hidden(self.batch_size)
|
352 |
+
self.decoder.init_hidden(self.batch_size)
|
353 |
+
|
354 |
+
# TODO Do we need this?
|
355 |
+
if hasattr(self.src_embedding, 'weight'):
|
356 |
+
self.src_embedding.weight.data.uniform_(-0.1, 0.1)
|
357 |
+
if hasattr(self.tgt_embedding, 'weight'):
|
358 |
+
self.tgt_embedding.weight.data.uniform_(-0.1, 0.1)
|
359 |
+
|
360 |
+
self.encoder.init_hidden(self.batch_size)
|
361 |
+
self.decoder.init_hidden(self.batch_size)
|
362 |
+
|
363 |
+
def to(self, device):
|
364 |
+
newself = super().to(device)
|
365 |
+
newself.src_embedding = newself.src_embedding.to(device)
|
366 |
+
newself.tgt_embedding = newself.tgt_embedding.to(device)
|
367 |
+
newself.encoder = newself.encoder.to(device)
|
368 |
+
newself.decoder = newself.decoder.to(device)
|
369 |
+
newself.dec2vocab = newself.dec2vocab.to(device)
|
370 |
+
newself.hidden2mean = newself.hidden2mean.to(device)
|
371 |
+
newself.hidden2logvar = newself.hidden2logvar.to(device)
|
372 |
+
newself.latent2tgt_emb = newself.latent2tgt_emb.to(device)
|
373 |
+
newself.latent2hidden_dict = newself.latent2hidden_dict.to(device)
|
374 |
+
return newself
|
375 |
+
|
376 |
+
def forward(self, in_seq, out_seq):
|
377 |
+
''' forward model
|
378 |
+
|
379 |
+
Parameters
|
380 |
+
----------
|
381 |
+
in_seq : Variable, shape (batch_size, length)
|
382 |
+
each element corresponds to word index.
|
383 |
+
where the index should be less than `vocab_size`
|
384 |
+
|
385 |
+
Returns
|
386 |
+
-------
|
387 |
+
Variable, shape (batch_size, length, vocab_size)
|
388 |
+
logit of each word (applying softmax yields the probability)
|
389 |
+
'''
|
390 |
+
mu, logvar = self.encode(in_seq)
|
391 |
+
z = self.reparameterize(mu, logvar)
|
392 |
+
return self.decode(z, out_seq), mu, logvar
|
393 |
+
|
394 |
+
def encode(self, in_seq):
|
395 |
+
src_emb = self.src_embedding(in_seq)
|
396 |
+
src_h = self.encoder.forward(src_emb)
|
397 |
+
if self.encoder_params.get('bidirectional', False):
|
398 |
+
concat_src_h = torch.cat((src_h[:, -1, 0, :], src_h[:, 0, 1, :]), dim=1)
|
399 |
+
return self.hidden2mean(concat_src_h), self.hidden2logvar(concat_src_h)
|
400 |
+
else:
|
401 |
+
return self.hidden2mean(src_h[:, -1, :]), self.hidden2logvar(src_h[:, -1, :])
|
402 |
+
|
403 |
+
def reparameterize(self, mu, logvar, training=True):
|
404 |
+
if training:
|
405 |
+
std = logvar.mul(0.5).exp_()
|
406 |
+
device = next(self.parameters()).device
|
407 |
+
eps = Variable(std.data.new(std.size()).normal_())
|
408 |
+
if device != eps.get_device():
|
409 |
+
eps.to(device)
|
410 |
+
return eps.mul(std).add_(mu)
|
411 |
+
else:
|
412 |
+
return mu
|
413 |
+
|
414 |
+
#TODO Not tested. Need to implement this in case of molecular structure generation
|
415 |
+
def sample(self, sample_size=-1, deterministic=True, return_z=False):
|
416 |
+
self.eval()
|
417 |
+
self.init_hidden()
|
418 |
+
if sample_size == -1:
|
419 |
+
sample_size = self.batch_size
|
420 |
+
|
421 |
+
num_iter = int(np.ceil(sample_size / self.batch_size))
|
422 |
+
hg_list = []
|
423 |
+
z_list = []
|
424 |
+
for _ in range(num_iter):
|
425 |
+
z = Variable(torch.normal(
|
426 |
+
torch.zeros(self.batch_size, self.latent_dim),
|
427 |
+
torch.ones(self.batch_size * self.latent_dim))).cuda()
|
428 |
+
_, each_hg_list = self.decode(z, deterministic=deterministic)
|
429 |
+
z_list.append(z)
|
430 |
+
hg_list += each_hg_list
|
431 |
+
z = torch.cat(z_list)[:sample_size]
|
432 |
+
hg_list = hg_list[:sample_size]
|
433 |
+
if return_z:
|
434 |
+
return hg_list, z.cpu().detach().numpy()
|
435 |
+
else:
|
436 |
+
return hg_list
|
437 |
+
|
438 |
+
def decode(self, z=None, out_seq=None, deterministic=True):
|
439 |
+
if z is None:
|
440 |
+
z = Variable(torch.normal(
|
441 |
+
torch.zeros(self.batch_size, self.latent_dim),
|
442 |
+
torch.ones(self.batch_size * self.latent_dim)))
|
443 |
+
if self.rank >= 0:
|
444 |
+
z = z.to(next(self.parameters()).device)
|
445 |
+
|
446 |
+
hidden_dict_0 = {}
|
447 |
+
for each_hidden in self.latent2hidden_dict.keys():
|
448 |
+
hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
|
449 |
+
bsize = z.size(0)
|
450 |
+
self.decoder.init_hidden(bsize)
|
451 |
+
self.decoder.feed_hidden(hidden_dict_0)
|
452 |
+
|
453 |
+
if out_seq is not None:
|
454 |
+
tgt_emb0 = self.latent2tgt_emb(z)
|
455 |
+
tgt_emb0 = tgt_emb0.view(tgt_emb0.shape[0], 1, tgt_emb0.shape[1])
|
456 |
+
out_seq_emb = self.tgt_embedding(out_seq)
|
457 |
+
tgt_emb = torch.cat((tgt_emb0, out_seq_emb), dim=1)[:, :-1, :]
|
458 |
+
tgt_emb_pred_list = []
|
459 |
+
for each_idx in range(self.max_len):
|
460 |
+
tgt_emb_pred = self.decoder.forward_one_step(tgt_emb[:, each_idx, :].view(bsize, 1, -1))
|
461 |
+
tgt_emb_pred_list.append(tgt_emb_pred)
|
462 |
+
vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
|
463 |
+
return vocab_logit
|
464 |
+
else:
|
465 |
+
with torch.no_grad():
|
466 |
+
tgt_emb = self.latent2tgt_emb(z)
|
467 |
+
tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
|
468 |
+
tgt_emb_pred_list = []
|
469 |
+
stack_list = []
|
470 |
+
hg_list = []
|
471 |
+
nt_symbol_list = []
|
472 |
+
nt_edge_list = []
|
473 |
+
gen_finish_list = []
|
474 |
+
for _ in range(bsize):
|
475 |
+
stack_list.append([])
|
476 |
+
hg_list.append(None)
|
477 |
+
nt_symbol_list.append(NTSymbol(degree=0,
|
478 |
+
is_aromatic=False,
|
479 |
+
bond_symbol_list=[]))
|
480 |
+
nt_edge_list.append(None)
|
481 |
+
gen_finish_list.append(False)
|
482 |
+
|
483 |
+
for idx in range(self.max_len):
|
484 |
+
tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
|
485 |
+
tgt_emb_pred_list.append(tgt_emb_pred)
|
486 |
+
vocab_logit = self.dec2vocab(tgt_emb_pred)
|
487 |
+
for each_batch_idx in range(bsize):
|
488 |
+
if not gen_finish_list[each_batch_idx]: # if generation has not finished
|
489 |
+
# get production rule greedily
|
490 |
+
prod_rule = self.hrg.prod_rule_corpus.sample(vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
|
491 |
+
nt_symbol_list[each_batch_idx],
|
492 |
+
deterministic=deterministic)
|
493 |
+
# convert production rule into an index
|
494 |
+
tgt_id = self.hrg.prod_rule_list.index(prod_rule)
|
495 |
+
# apply the production rule
|
496 |
+
hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
|
497 |
+
# add non-terminals to the stack
|
498 |
+
stack_list[each_batch_idx].extend(nt_edges[::-1])
|
499 |
+
# if the stack size is 0, generation has finished!
|
500 |
+
if len(stack_list[each_batch_idx]) == 0:
|
501 |
+
gen_finish_list[each_batch_idx] = True
|
502 |
+
else:
|
503 |
+
nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
|
504 |
+
nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(nt_edge_list[each_batch_idx])['symbol']
|
505 |
+
else:
|
506 |
+
tgt_id = np.mod(self.padding_idx, self.vocab_size)
|
507 |
+
indice_tensor = torch.LongTensor([tgt_id])
|
508 |
+
device = next(self.parameters()).device
|
509 |
+
if indice_tensor.device != device:
|
510 |
+
indice_tensor = indice_tensor.to(device)
|
511 |
+
tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)
|
512 |
+
vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
|
513 |
+
#for id, v in enumerate(gen_finish_list):
|
514 |
+
#if not v:
|
515 |
+
# print("bacth id={} not finished generating a sequence: ".format(id))
|
516 |
+
return gen_finish_list, vocab_logit, hg_list
|
517 |
+
|
518 |
+
|
519 |
+
# TODO A lot of duplicates with GrammarVAE. Clean up it if necessary
|
520 |
+
class GrammarGINVAE(nn.Module):
|
521 |
+
|
522 |
+
'''
|
523 |
+
Variational autoencoder based on GIN and grammar
|
524 |
+
'''
|
525 |
+
|
526 |
+
def __init__(self, hrg, rank=-1, max_len=80,
|
527 |
+
batch_size=64, padding_idx=-1,
|
528 |
+
encoder_params={'node_feature_size': 4, 'edge_feature_size': 3,
|
529 |
+
'hidden_channels': 64, 'proximity_size': 3,
|
530 |
+
'dropout': 0.1},
|
531 |
+
decoder_params={'hidden_dim': 384, 'num_layers': 3,
|
532 |
+
'dropout': 0.1},
|
533 |
+
prod_rule_embed_params={'out_dim': 128},
|
534 |
+
no_dropout=False):
|
535 |
+
|
536 |
+
super().__init__()
|
537 |
+
# TODO USE GRU FOR ENCODING AND DECODING
|
538 |
+
self.hrg = hrg
|
539 |
+
self.rank = rank
|
540 |
+
self.prod_rule_corpus = hrg.prod_rule_corpus
|
541 |
+
self.prod_rule_embed_params = prod_rule_embed_params
|
542 |
+
|
543 |
+
self.vocab_size = hrg.num_prod_rule + 1
|
544 |
+
self.batch_size = batch_size
|
545 |
+
self.padding_idx = np.mod(padding_idx, self.vocab_size)
|
546 |
+
self.no_dropout = no_dropout
|
547 |
+
self.max_len = max_len
|
548 |
+
self.encoder_params = encoder_params
|
549 |
+
self.decoder_params = decoder_params
|
550 |
+
|
551 |
+
# TODO Simple embedding is used. Check if a domain-dependent embedding works or not.
|
552 |
+
embed_out_dim = self.prod_rule_embed_params['out_dim']
|
553 |
+
#use MolecularProdRuleEmbedding later on
|
554 |
+
self.tgt_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
|
555 |
+
padding_idx=self.padding_idx)
|
556 |
+
|
557 |
+
self.encoder = GIN(**self.encoder_params)
|
558 |
+
self.latent_dim = self.encoder_params['hidden_channels']
|
559 |
+
self.proximity_size = self.encoder_params['proximity_size']
|
560 |
+
hidden_dim = self.decoder_params['hidden_dim']
|
561 |
+
self.hidden2mean = nn.Linear(self.latent_dim * (1 + self.proximity_size), self.latent_dim, bias=False)
|
562 |
+
self.hidden2logvar = nn.Linear(self.latent_dim * (1 + self.proximity_size), self.latent_dim)
|
563 |
+
|
564 |
+
self.decoder = GRUDecoder(input_dim=embed_out_dim, batch_size=self.batch_size,
|
565 |
+
rank=self.rank, no_dropout=self.no_dropout, **self.decoder_params)
|
566 |
+
self.latent2tgt_emb = nn.Linear(self.latent_dim, embed_out_dim)
|
567 |
+
self.latent2hidden_dict = nn.ModuleDict()
|
568 |
+
for each_hidden in self.decoder.hidden_dict.keys():
|
569 |
+
self.latent2hidden_dict[each_hidden] = nn.Linear(self.latent_dim, hidden_dim)
|
570 |
+
if self.rank >= 0:
|
571 |
+
if torch.cuda.is_available():
|
572 |
+
self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(self.rank)
|
573 |
+
else:
|
574 |
+
# support mac mps
|
575 |
+
self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(torch.device("mps", self.rank))
|
576 |
+
|
577 |
+
self.dec2vocab = nn.Linear(hidden_dim, self.vocab_size)
|
578 |
+
self.decoder.init_hidden(self.batch_size)
|
579 |
+
|
580 |
+
# TODO Do we need this?
|
581 |
+
if hasattr(self.tgt_embedding, 'weight'):
|
582 |
+
self.tgt_embedding.weight.data.uniform_(-0.1, 0.1)
|
583 |
+
self.decoder.init_hidden(self.batch_size)
|
584 |
+
|
585 |
+
def to(self, device):
|
586 |
+
newself = super().to(device)
|
587 |
+
newself.encoder = newself.encoder.to(device)
|
588 |
+
newself.decoder = newself.decoder.to(device)
|
589 |
+
newself.rank = next(newself.encoder.parameters()).get_device()
|
590 |
+
return newself
|
591 |
+
|
592 |
+
def forward(self, x, edge_index, edge_attr, batch_size, out_seq=None, sched_prob = None):
|
593 |
+
mu, logvar = self.encode(x, edge_index, edge_attr, batch_size)
|
594 |
+
z = self.reparameterize(mu, logvar)
|
595 |
+
return self.decode(z, out_seq, sched_prob=sched_prob), mu, logvar
|
596 |
+
|
597 |
+
#TODO Not tested. Need to implement this in case of molecular structure generation
|
598 |
+
def sample(self, sample_size=-1, deterministic=True, return_z=False):
|
599 |
+
self.eval()
|
600 |
+
self.init_hidden()
|
601 |
+
if sample_size == -1:
|
602 |
+
sample_size = self.batch_size
|
603 |
+
|
604 |
+
num_iter = int(np.ceil(sample_size / self.batch_size))
|
605 |
+
hg_list = []
|
606 |
+
z_list = []
|
607 |
+
for _ in range(num_iter):
|
608 |
+
z = Variable(torch.normal(
|
609 |
+
torch.zeros(self.batch_size, self.latent_dim),
|
610 |
+
torch.ones(self.batch_size * self.latent_dim))).cuda()
|
611 |
+
_, each_hg_list = self.decode(z, deterministic=deterministic)
|
612 |
+
z_list.append(z)
|
613 |
+
hg_list += each_hg_list
|
614 |
+
z = torch.cat(z_list)[:sample_size]
|
615 |
+
hg_list = hg_list[:sample_size]
|
616 |
+
if return_z:
|
617 |
+
return hg_list, z.cpu().detach().numpy()
|
618 |
+
else:
|
619 |
+
return hg_list
|
620 |
+
|
621 |
+
def decode(self, z=None, out_seq=None, deterministic=True, sched_prob=None):
|
622 |
+
if z is None:
|
623 |
+
z = Variable(torch.normal(
|
624 |
+
torch.zeros(self.batch_size, self.latent_dim),
|
625 |
+
torch.ones(self.batch_size * self.latent_dim)))
|
626 |
+
if self.rank >= 0:
|
627 |
+
z = z.to(next(self.parameters()).device)
|
628 |
+
|
629 |
+
hidden_dict_0 = {}
|
630 |
+
for each_hidden in self.latent2hidden_dict.keys():
|
631 |
+
hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
|
632 |
+
bsize = z.size(0)
|
633 |
+
self.decoder.init_hidden(bsize)
|
634 |
+
self.decoder.feed_hidden(hidden_dict_0)
|
635 |
+
|
636 |
+
if out_seq is not None:
|
637 |
+
tgt_emb0 = self.latent2tgt_emb(z)
|
638 |
+
tgt_emb0 = tgt_emb0.view(tgt_emb0.shape[0], 1, tgt_emb0.shape[1])
|
639 |
+
out_seq_emb = self.tgt_embedding(out_seq)
|
640 |
+
tgt_emb = torch.cat((tgt_emb0, out_seq_emb), dim=1)[:, :-1, :]
|
641 |
+
tgt_emb_pred_list = []
|
642 |
+
tgt_emb_pred = None
|
643 |
+
for each_idx in range(self.max_len):
|
644 |
+
if tgt_emb_pred is None or sched_prob is None or torch.rand(1)[0] <= sched_prob:
|
645 |
+
inp = tgt_emb[:, each_idx, :].view(bsize, 1, -1)
|
646 |
+
else:
|
647 |
+
cur_logit = self.dec2vocab(tgt_emb_pred)
|
648 |
+
yi = torch.argmax(cur_logit, dim=2)
|
649 |
+
inp = self.tgt_embedding(yi)
|
650 |
+
tgt_emb_pred = self.decoder.forward_one_step(inp)
|
651 |
+
tgt_emb_pred_list.append(tgt_emb_pred)
|
652 |
+
vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
|
653 |
+
return vocab_logit
|
654 |
+
else:
|
655 |
+
with torch.no_grad():
|
656 |
+
tgt_emb = self.latent2tgt_emb(z)
|
657 |
+
tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
|
658 |
+
tgt_emb_pred_list = []
|
659 |
+
stack_list = []
|
660 |
+
hg_list = []
|
661 |
+
nt_symbol_list = []
|
662 |
+
nt_edge_list = []
|
663 |
+
gen_finish_list = []
|
664 |
+
for _ in range(bsize):
|
665 |
+
stack_list.append([])
|
666 |
+
hg_list.append(None)
|
667 |
+
nt_symbol_list.append(NTSymbol(degree=0,
|
668 |
+
is_aromatic=False,
|
669 |
+
bond_symbol_list=[]))
|
670 |
+
nt_edge_list.append(None)
|
671 |
+
gen_finish_list.append(False)
|
672 |
+
|
673 |
+
for _ in range(self.max_len):
|
674 |
+
tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
|
675 |
+
tgt_emb_pred_list.append(tgt_emb_pred)
|
676 |
+
vocab_logit = self.dec2vocab(tgt_emb_pred)
|
677 |
+
for each_batch_idx in range(bsize):
|
678 |
+
if not gen_finish_list[each_batch_idx]: # if generation has not finished
|
679 |
+
# get production rule greedily
|
680 |
+
prod_rule = self.hrg.prod_rule_corpus.sample(vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
|
681 |
+
nt_symbol_list[each_batch_idx],
|
682 |
+
deterministic=deterministic)
|
683 |
+
# convert production rule into an index
|
684 |
+
tgt_id = self.hrg.prod_rule_list.index(prod_rule)
|
685 |
+
# apply the production rule
|
686 |
+
hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
|
687 |
+
# add non-terminals to the stack
|
688 |
+
stack_list[each_batch_idx].extend(nt_edges[::-1])
|
689 |
+
# if the stack size is 0, generation has finished!
|
690 |
+
if len(stack_list[each_batch_idx]) == 0:
|
691 |
+
gen_finish_list[each_batch_idx] = True
|
692 |
+
else:
|
693 |
+
nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
|
694 |
+
nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(nt_edge_list[each_batch_idx])['symbol']
|
695 |
+
else:
|
696 |
+
tgt_id = np.mod(self.padding_idx, self.vocab_size)
|
697 |
+
indice_tensor = torch.LongTensor([tgt_id])
|
698 |
+
if self.rank >= 0:
|
699 |
+
indice_tensor = indice_tensor.to(next(self.parameters()).device)
|
700 |
+
tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)
|
701 |
+
vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
|
702 |
+
return gen_finish_list, vocab_logit, hg_list
|
703 |
+
|
704 |
+
#TODO Not tested. Need to implement this in case of molecular structure generation
|
705 |
+
def conditional_distribution(self, z, tgt_id_list):
|
706 |
+
self.eval()
|
707 |
+
self.init_hidden()
|
708 |
+
z = z.cuda()
|
709 |
+
|
710 |
+
hidden_dict_0 = {}
|
711 |
+
for each_hidden in self.latent2hidden_dict.keys():
|
712 |
+
hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
|
713 |
+
self.decoder.feed_hidden(hidden_dict_0)
|
714 |
+
|
715 |
+
with torch.no_grad():
|
716 |
+
tgt_emb = self.latent2tgt_emb(z)
|
717 |
+
tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
|
718 |
+
nt_symbol_list = []
|
719 |
+
stack_list = []
|
720 |
+
hg_list = []
|
721 |
+
nt_edge_list = []
|
722 |
+
gen_finish_list = []
|
723 |
+
for _ in range(self.batch_size):
|
724 |
+
nt_symbol_list.append(NTSymbol(degree=0,
|
725 |
+
is_aromatic=False,
|
726 |
+
bond_symbol_list=[]))
|
727 |
+
stack_list.append([])
|
728 |
+
hg_list.append(None)
|
729 |
+
nt_edge_list.append(None)
|
730 |
+
gen_finish_list.append(False)
|
731 |
+
|
732 |
+
for each_position in range(len(tgt_id_list[0])):
|
733 |
+
tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
|
734 |
+
for each_batch_idx in range(self.batch_size):
|
735 |
+
if not gen_finish_list[each_batch_idx]: # if generation has not finished
|
736 |
+
# use the prespecified target ids
|
737 |
+
tgt_id = tgt_id_list[each_batch_idx][each_position]
|
738 |
+
prod_rule = self.hrg.prod_rule_list[tgt_id]
|
739 |
+
# apply the production rule
|
740 |
+
hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
|
741 |
+
# add non-terminals to the stack
|
742 |
+
stack_list[each_batch_idx].extend(nt_edges[::-1])
|
743 |
+
# if the stack size is 0, generation has finished!
|
744 |
+
if len(stack_list[each_batch_idx]) == 0:
|
745 |
+
gen_finish_list[each_batch_idx] = True
|
746 |
+
else:
|
747 |
+
nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
|
748 |
+
nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(nt_edge_list[each_batch_idx])['symbol']
|
749 |
+
else:
|
750 |
+
tgt_id = np.mod(self.padding_idx, self.vocab_size)
|
751 |
+
indice_tensor = torch.LongTensor([tgt_id])
|
752 |
+
indice_tensor = indice_tensor.cuda()
|
753 |
+
tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)
|
754 |
+
|
755 |
+
# last one step
|
756 |
+
conditional_logprob_list = []
|
757 |
+
tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
|
758 |
+
vocab_logit = self.dec2vocab(tgt_emb_pred)
|
759 |
+
for each_batch_idx in range(self.batch_size):
|
760 |
+
if not gen_finish_list[each_batch_idx]: # if generation has not finished
|
761 |
+
# get production rule greedily
|
762 |
+
masked_logprob = self.hrg.prod_rule_corpus.masked_logprob(
|
763 |
+
vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
|
764 |
+
nt_symbol_list[each_batch_idx])
|
765 |
+
conditional_logprob_list.append(masked_logprob)
|
766 |
+
else:
|
767 |
+
conditional_logprob_list.append(None)
|
768 |
+
return conditional_logprob_list
|
769 |
+
|
770 |
+
#TODO Not tested. Need to implement this in case of molecular structure generation
|
771 |
+
def decode_with_beam_search(self, z, beam_width=1):
|
772 |
+
''' Decode a latent vector using beam search.
|
773 |
+
|
774 |
+
Parameters
|
775 |
+
----------
|
776 |
+
z
|
777 |
+
latent vector
|
778 |
+
beam_width : int
|
779 |
+
parameter for beam search
|
780 |
+
|
781 |
+
Returns
|
782 |
+
-------
|
783 |
+
List of Hypergraphs
|
784 |
+
'''
|
785 |
+
if self.batch_size != 1:
|
786 |
+
raise ValueError('this method works only under batch_size=1')
|
787 |
+
if self.padding_idx != -1:
|
788 |
+
raise ValueError('this method works only under padding_idx=-1')
|
789 |
+
top_k_tgt_id_list = [[]] * beam_width
|
790 |
+
logprob_list = [0.] * beam_width
|
791 |
+
|
792 |
+
for each_len in range(self.max_len):
|
793 |
+
expanded_logprob_list = np.repeat(logprob_list, self.vocab_size) # including padding_idx
|
794 |
+
expanded_length_list = np.array([0] * (beam_width * self.vocab_size))
|
795 |
+
for each_beam_idx, each_candidate in enumerate(top_k_tgt_id_list):
|
796 |
+
conditional_logprob = self.conditional_distribution(z, [each_candidate])[0]
|
797 |
+
if conditional_logprob is None:
|
798 |
+
expanded_logprob_list[(each_beam_idx + 1) * self.vocab_size - 1]\
|
799 |
+
= logprob_list[each_beam_idx]
|
800 |
+
expanded_logprob_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size - 1]\
|
801 |
+
= -np.inf
|
802 |
+
expanded_length_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size]\
|
803 |
+
= len(each_candidate)
|
804 |
+
else:
|
805 |
+
expanded_logprob_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size - 1]\
|
806 |
+
= logprob_list[each_beam_idx] + conditional_logprob
|
807 |
+
expanded_logprob_list[(each_beam_idx + 1) * self.vocab_size - 1]\
|
808 |
+
= -np.inf
|
809 |
+
expanded_length_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size]\
|
810 |
+
= len(each_candidate) + 1
|
811 |
+
score_list = np.array(expanded_logprob_list) / np.array(expanded_length_list)
|
812 |
+
if each_len == 0:
|
813 |
+
top_k_list = np.argsort(score_list[:self.vocab_size])[::-1][:beam_width]
|
814 |
+
else:
|
815 |
+
top_k_list = np.argsort(score_list)[::-1][:beam_width]
|
816 |
+
next_top_k_tgt_id_list = []
|
817 |
+
next_logprob_list = []
|
818 |
+
for each_top_k in top_k_list:
|
819 |
+
beam_idx = each_top_k // self.vocab_size
|
820 |
+
vocab_idx = each_top_k % self.vocab_size
|
821 |
+
if vocab_idx == self.vocab_size - 1:
|
822 |
+
next_top_k_tgt_id_list.append(top_k_tgt_id_list[beam_idx])
|
823 |
+
next_logprob_list.append(expanded_logprob_list[each_top_k])
|
824 |
+
else:
|
825 |
+
next_top_k_tgt_id_list.append(top_k_tgt_id_list[beam_idx] + [vocab_idx])
|
826 |
+
next_logprob_list.append(expanded_logprob_list[each_top_k])
|
827 |
+
top_k_tgt_id_list = next_top_k_tgt_id_list
|
828 |
+
logprob_list = next_logprob_list
|
829 |
+
|
830 |
+
# construct hypergraphs
|
831 |
+
hg_list = []
|
832 |
+
for each_tgt_id_list in top_k_tgt_id_list:
|
833 |
+
hg = None
|
834 |
+
stack = []
|
835 |
+
nt_edge = None
|
836 |
+
for each_idx, each_prod_rule_id in enumerate(each_tgt_id_list):
|
837 |
+
prod_rule = self.hrg.prod_rule_list[each_prod_rule_id]
|
838 |
+
hg, nt_edges = prod_rule.applied_to(hg, nt_edge)
|
839 |
+
stack.extend(nt_edges[::-1])
|
840 |
+
try:
|
841 |
+
nt_edge = stack.pop()
|
842 |
+
except IndexError:
|
843 |
+
if each_idx == len(each_tgt_id_list) - 1:
|
844 |
+
break
|
845 |
+
else:
|
846 |
+
raise ValueError('some bugs')
|
847 |
+
hg_list.append(hg)
|
848 |
+
return hg_list
|
849 |
+
|
850 |
+
def graph_embed(self, x, edge_index, edge_attr, batch_size):
|
851 |
+
src_h = self.encoder.forward(x, edge_index, edge_attr, batch_size)
|
852 |
+
return src_h
|
853 |
+
|
854 |
+
def encode(self, x, edge_index, edge_attr, batch_size):
|
855 |
+
#print("device for src_emb=", src_emb.get_device())
|
856 |
+
#print("device for self.encoder=", next(self.encoder.parameters()).get_device())
|
857 |
+
src_h = self.graph_embed(x, edge_index, edge_attr, batch_size)
|
858 |
+
mu, lv = self.get_mean_var(src_h)
|
859 |
+
return mu, lv
|
860 |
+
|
861 |
+
def get_mean_var(self, src_h):
|
862 |
+
#src_h = torch.tanh(src_h)
|
863 |
+
mu = self.hidden2mean(src_h)
|
864 |
+
lv = self.hidden2logvar(src_h)
|
865 |
+
mu = torch.tanh(mu)
|
866 |
+
lv = torch.tanh(lv)
|
867 |
+
return mu, lv
|
868 |
+
|
869 |
+
def reparameterize(self, mu, logvar, training=True):
|
870 |
+
if training:
|
871 |
+
std = logvar.mul(0.5).exp_()
|
872 |
+
eps = Variable(std.data.new(std.size()).normal_())
|
873 |
+
if self.rank >= 0:
|
874 |
+
eps = eps.to(next(self.parameters()).device)
|
875 |
+
return eps.mul(std).add_(mu)
|
876 |
+
else:
|
877 |
+
return mu
|
878 |
+
|
879 |
+
# Copied from the MHG implementation and adapted
|
880 |
+
class GrammarVAELoss(_Loss):
|
881 |
+
|
882 |
+
'''
|
883 |
+
a loss function for Grammar VAE
|
884 |
+
|
885 |
+
Attributes
|
886 |
+
----------
|
887 |
+
hrg : HyperedgeReplacementGrammar
|
888 |
+
beta : float
|
889 |
+
coefficient of KL divergence
|
890 |
+
'''
|
891 |
+
|
892 |
+
def __init__(self, rank, hrg, beta=1.0, **kwargs):
|
893 |
+
super().__init__(**kwargs)
|
894 |
+
self.hrg = hrg
|
895 |
+
self.beta = beta
|
896 |
+
self.rank = rank
|
897 |
+
|
898 |
+
def forward(self, mu, logvar, in_seq_pred, in_seq):
|
899 |
+
''' compute VAE loss
|
900 |
+
|
901 |
+
Parameters
|
902 |
+
----------
|
903 |
+
in_seq_pred : torch.Tensor, shape (batch_size, max_len, vocab_size)
|
904 |
+
logit
|
905 |
+
in_seq : torch.Tensor, shape (batch_size, max_len)
|
906 |
+
each element corresponds to a word id in vocabulary.
|
907 |
+
mu : torch.Tensor, shape (batch_size, hidden_dim)
|
908 |
+
logvar : torch.Tensor, shape (batch_size, hidden_dim)
|
909 |
+
mean and log variance of the normal distribution
|
910 |
+
'''
|
911 |
+
batch_size = in_seq_pred.shape[0]
|
912 |
+
max_len = in_seq_pred.shape[1]
|
913 |
+
vocab_size = in_seq_pred.shape[2]
|
914 |
+
mask = torch.zeros(in_seq_pred.shape)
|
915 |
+
|
916 |
+
for each_batch in range(batch_size):
|
917 |
+
flag = True
|
918 |
+
for each_idx in range(max_len):
|
919 |
+
prod_rule_idx = in_seq[each_batch, each_idx]
|
920 |
+
if prod_rule_idx == vocab_size - 1:
|
921 |
+
#### DETERMINE WHETHER THIS SHOULD BE SKIPPED OR NOT
|
922 |
+
mask[each_batch, each_idx, prod_rule_idx] = 1
|
923 |
+
#break
|
924 |
+
continue
|
925 |
+
lhs = self.hrg.prod_rule_corpus.prod_rule_list[prod_rule_idx].lhs_nt_symbol
|
926 |
+
lhs_idx = self.hrg.prod_rule_corpus.nt_symbol_list.index(lhs)
|
927 |
+
mask[each_batch, each_idx, :-1] = torch.FloatTensor(self.hrg.prod_rule_corpus.lhs_in_prod_rule[lhs_idx])
|
928 |
+
if self.rank >= 0:
|
929 |
+
mask = mask.to(next(self.parameters()).device)
|
930 |
+
in_seq_pred = mask * in_seq_pred
|
931 |
+
|
932 |
+
cross_entropy = F.cross_entropy(
|
933 |
+
in_seq_pred.view(-1, vocab_size),
|
934 |
+
in_seq.view(-1),
|
935 |
+
reduction='sum',
|
936 |
+
#ignore_index=self.ignore_index if self.ignore_index is not None else -100
|
937 |
+
)
|
938 |
+
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
|
939 |
+
return cross_entropy + self.beta * kl_div
|
940 |
+
|
941 |
+
|
942 |
+
class VAELoss(_Loss):
|
943 |
+
def __init__(self, beta=0.01):
|
944 |
+
super().__init__()
|
945 |
+
self.beta = beta
|
946 |
+
|
947 |
+
def forward(self, mean, log_var, dec_outputs, targets):
|
948 |
+
|
949 |
+
device = mean.get_device()
|
950 |
+
if device >= 0:
|
951 |
+
targets = targets.to(mean.get_device())
|
952 |
+
reconstruction = F.cross_entropy(dec_outputs.view(-1, dec_outputs.size(2)), targets.view(-1), reduction='sum')
|
953 |
+
|
954 |
+
KL = 0.5 * torch.sum(1 + log_var - mean ** 2 - torch.exp(log_var))
|
955 |
+
loss = - self.beta * KL + reconstruction
|
956 |
+
return loss
|
models/mhg_model/notebooks/mhg-gnn_encoder_decoder_example.ipynb
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "829ddc03",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"import sys\n",
|
11 |
+
"sys.path.append('..')"
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"execution_count": null,
|
17 |
+
"id": "ea820e23",
|
18 |
+
"metadata": {},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"import torch\n",
|
22 |
+
"import load"
|
23 |
+
]
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"cell_type": "markdown",
|
27 |
+
"id": "b9a51fa8",
|
28 |
+
"metadata": {},
|
29 |
+
"source": [
|
30 |
+
"# Load MHG-GNN"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "code",
|
35 |
+
"execution_count": null,
|
36 |
+
"id": "c6ea1fc8",
|
37 |
+
"metadata": {},
|
38 |
+
"outputs": [],
|
39 |
+
"source": [
|
40 |
+
"model_ckp = \"models/model_checkpoints/mhg_model/pickles/mhggnn_pretrained_model_radius7_1116_2023.pickle\"\n",
|
41 |
+
"\n",
|
42 |
+
"model = load.load(model_name = model_ckp)\n",
|
43 |
+
"if model is None:\n",
|
44 |
+
" print(\"Model not loaded, please check you have MHG pickle file\")\n",
|
45 |
+
"else:\n",
|
46 |
+
" print(\"MHG model loaded\")"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "markdown",
|
51 |
+
"id": "b4a0b557",
|
52 |
+
"metadata": {},
|
53 |
+
"source": [
|
54 |
+
"# Embeddings\n",
|
55 |
+
"\n",
|
56 |
+
"※ replace the smiles exaple list with your dataset"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "code",
|
61 |
+
"execution_count": null,
|
62 |
+
"id": "c63a6be6",
|
63 |
+
"metadata": {},
|
64 |
+
"outputs": [],
|
65 |
+
"source": [
|
66 |
+
"with torch.no_grad():\n",
|
67 |
+
" repr = model.encode([\"CCO\", \"O=C=O\", \"OC(=O)c1ccccc1C(=O)O\"])\n",
|
68 |
+
" \n",
|
69 |
+
"# Print the latent vectors\n",
|
70 |
+
"print(repr)"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"cell_type": "markdown",
|
75 |
+
"id": "a59f9442",
|
76 |
+
"metadata": {},
|
77 |
+
"source": [
|
78 |
+
"# Decoding"
|
79 |
+
]
|
80 |
+
},
|
81 |
+
{
|
82 |
+
"cell_type": "code",
|
83 |
+
"execution_count": null,
|
84 |
+
"id": "6a0d8a41",
|
85 |
+
"metadata": {},
|
86 |
+
"outputs": [],
|
87 |
+
"source": [
|
88 |
+
"orig = model.decode(repr)\n",
|
89 |
+
"print(orig)"
|
90 |
+
]
|
91 |
+
}
|
92 |
+
],
|
93 |
+
"metadata": {
|
94 |
+
"kernelspec": {
|
95 |
+
"display_name": "Python 3 (ipykernel)",
|
96 |
+
"language": "python",
|
97 |
+
"name": "python3"
|
98 |
+
},
|
99 |
+
"language_info": {
|
100 |
+
"codemirror_mode": {
|
101 |
+
"name": "ipython",
|
102 |
+
"version": 3
|
103 |
+
},
|
104 |
+
"file_extension": ".py",
|
105 |
+
"mimetype": "text/x-python",
|
106 |
+
"name": "python",
|
107 |
+
"nbconvert_exporter": "python",
|
108 |
+
"pygments_lexer": "ipython3",
|
109 |
+
"version": "3.7.10"
|
110 |
+
}
|
111 |
+
},
|
112 |
+
"nbformat": 4,
|
113 |
+
"nbformat_minor": 5
|
114 |
+
}
|
models/mhg_model/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fa61ca0975afd93bf1f3c43f3b720594018f57ac9df3cb740def2eb28af6529d
|
3 |
+
size 342570
|
models/mhg_model/setup.cfg
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[metadata]
|
2 |
+
name = mhg-gnn
|
3 |
+
version = attr: .__version__
|
4 |
+
description = Package for mhg-gnn
|
5 |
+
author= team
|
6 |
+
long_description_content_type=text/markdown
|
7 |
+
long_description = file: README.md
|
8 |
+
python_requires = >= 3.9.7
|
9 |
+
license = TBD
|
10 |
+
|
11 |
+
classifiers =
|
12 |
+
Programming Language :: Python :: 3
|
13 |
+
Programming Language :: Python :: 3.9
|
14 |
+
|
15 |
+
[options]
|
16 |
+
install_requires =
|
17 |
+
networkx>=2.8
|
18 |
+
numpy>=1.23.5, <2.0.0
|
19 |
+
pandas>=1.5.3
|
20 |
+
rdkit-pypi>=2022.9.4, <2023.9.6
|
21 |
+
torch>=2.0.0
|
22 |
+
torchinfo>=1.8.0
|
23 |
+
torch-geometric>=2.3.1
|
24 |
+
requests>=2.32.2
|
25 |
+
scikit-learn>=1.5.0
|
26 |
+
urllib3>=2.2.2
|
27 |
+
|
28 |
+
|
29 |
+
setup_requires =
|
30 |
+
setuptools
|
31 |
+
package_dir =
|
32 |
+
= .
|
33 |
+
packages=find:
|
34 |
+
include_package_data = True
|
35 |
+
|
36 |
+
[options.packages.find]
|
37 |
+
where = .
|
models/mhg_model/setup.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
import setuptools
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
setuptools.setup()
|
models/selfies_model/README.md
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
library_name: transformers
|
4 |
+
pipeline_tag: feature-extraction
|
5 |
+
tags:
|
6 |
+
- chemistry
|
7 |
+
---
|
8 |
+
|
9 |
+
# selfies-ted
|
10 |
+
|
11 |
+
selfies-ted is a project for encoding SMILES (Simplified Molecular Input Line Entry System) into SELFIES (SELF-referencing Embedded Strings) and generating embeddings for molecular representations.
|
12 |
+
|
13 |
+

|
14 |
+
## Model Architecture
|
15 |
+
|
16 |
+
Configuration details
|
17 |
+
|
18 |
+
Encoder and Decoder FFN dimensions: 256
|
19 |
+
Number of attention heads: 4
|
20 |
+
Number of encoder and decoder layers: 2
|
21 |
+
Total number of hidden layers: 6
|
22 |
+
Maximum position embeddings: 128
|
23 |
+
Model dimension (d_model): 256
|
24 |
+
|
25 |
+
## Pretrained Models and Training Logs
|
26 |
+
We provide checkpoints of the selfies-ted model pre-trained on a dataset of molecules curated from PubChem. The pre-trained model shows competitive performance on molecular representation tasks. For model weights: "HuggingFace link".
|
27 |
+
|
28 |
+
To install and use the pre-trained model:
|
29 |
+
|
30 |
+
Download the selfies_ted_model.pkl file from the "HuggingFace link".
|
31 |
+
Add the selfies-ted selfies_ted_model.pkl to the models/ directory. The directory structure should look like the following:
|
32 |
+
|
33 |
+
```
|
34 |
+
models/
|
35 |
+
└── selfies_ted_model.pkl
|
36 |
+
```
|
37 |
+
|
38 |
+
## Installation
|
39 |
+
|
40 |
+
To use this project, you'll need to install the required dependencies. We recommend using a virtual environment:
|
41 |
+
|
42 |
+
```bash
|
43 |
+
python -m venv venv
|
44 |
+
source venv/bin/activate # On Windows use `venv\Scripts\activate`
|
45 |
+
```
|
46 |
+
|
47 |
+
Install the required dependencies
|
48 |
+
|
49 |
+
```
|
50 |
+
pip install -r requirements.txt
|
51 |
+
```
|
52 |
+
|
53 |
+
|
54 |
+
## Usage
|
55 |
+
|
56 |
+
### Import
|
57 |
+
|
58 |
+
```
|
59 |
+
import load
|
60 |
+
```
|
61 |
+
### Training the Model
|
62 |
+
|
63 |
+
To train the model, use the train.py script:
|
64 |
+
|
65 |
+
```
|
66 |
+
python train.py -f <path_to_your_data_file>
|
67 |
+
```
|
68 |
+
|
69 |
+
|
70 |
+
Note: The actual usage may depend on the specific implementation in load.py. Please refer to the source code for detailed functionality.
|
71 |
+
|
72 |
+
### Load the model and tokenizer
|
73 |
+
```
|
74 |
+
load.load("path/to/checkpoint.pkl")
|
75 |
+
```
|
76 |
+
### Encode SMILES strings
|
77 |
+
```
|
78 |
+
smiles_list = ["COC", "CCO"]
|
79 |
+
```
|
80 |
+
```
|
81 |
+
embeddings = load.encode(smiles_list)
|
82 |
+
```
|
83 |
+
|
84 |
+
|
85 |
+
## Example Notebook
|
86 |
+
|
87 |
+
Example notebook of this project is `selfies-ted-example.ipynb`.
|
models/selfies_model/load.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import selfies as sf
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
from rdkit import Chem
|
6 |
+
from transformers import AutoTokenizer, AutoModel
|
7 |
+
import gc
|
8 |
+
from torch.utils.data import DataLoader, Dataset
|
9 |
+
from multiprocessing import Pool, cpu_count
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
import os
|
13 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
14 |
+
|
15 |
+
|
16 |
+
class SELFIESDataset(Dataset):
|
17 |
+
def __init__(self, selfies_list):
|
18 |
+
self.selfies = selfies_list
|
19 |
+
|
20 |
+
def __len__(self):
|
21 |
+
return len(self.selfies)
|
22 |
+
|
23 |
+
def __getitem__(self, idx):
|
24 |
+
return self.selfies[idx]
|
25 |
+
|
26 |
+
class SELFIES(torch.nn.Module):
|
27 |
+
def __init__(self):
|
28 |
+
super().__init__()
|
29 |
+
self.model = None
|
30 |
+
self.tokenizer = None
|
31 |
+
self.invalid = []
|
32 |
+
|
33 |
+
def smiles_to_selfies(self, smiles):
|
34 |
+
try:
|
35 |
+
return sf.encoder(smiles.strip()).replace('][', '] [')
|
36 |
+
except:
|
37 |
+
try:
|
38 |
+
smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles.strip()))
|
39 |
+
return sf.encoder(smiles).replace('][', '] [')
|
40 |
+
except:
|
41 |
+
return None
|
42 |
+
|
43 |
+
def get_selfies(self, smiles_list):
|
44 |
+
with Pool(cpu_count()) as pool:
|
45 |
+
selfies = list(pool.map(self.smiles_to_selfies, smiles_list))
|
46 |
+
|
47 |
+
self.invalid = [i for i, s in enumerate(selfies) if s is None]
|
48 |
+
selfies = [s if s is not None else '[nop]' for s in selfies]
|
49 |
+
return selfies
|
50 |
+
|
51 |
+
@torch.no_grad()
|
52 |
+
def get_embedding_batch(self, selfies_batch):
|
53 |
+
encodings = self.tokenizer(
|
54 |
+
selfies_batch,
|
55 |
+
return_tensors='pt',
|
56 |
+
max_length=128,
|
57 |
+
truncation=True,
|
58 |
+
padding='max_length'
|
59 |
+
)
|
60 |
+
encodings = {k: v.to(self.model.device) for k, v in encodings.items()}
|
61 |
+
|
62 |
+
outputs = self.model.encoder(
|
63 |
+
input_ids=encodings['input_ids'],
|
64 |
+
attention_mask=encodings['attention_mask']
|
65 |
+
)
|
66 |
+
|
67 |
+
model_output = outputs.last_hidden_state
|
68 |
+
input_mask_expanded = encodings['attention_mask'].unsqueeze(-1).expand(model_output.size()).float()
|
69 |
+
sum_embeddings = torch.sum(model_output * input_mask_expanded, 1)
|
70 |
+
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
71 |
+
pooled_output = sum_embeddings / sum_mask
|
72 |
+
|
73 |
+
return pooled_output.cpu().numpy()
|
74 |
+
|
75 |
+
def load(self, checkpoint=None):
|
76 |
+
self.tokenizer = AutoTokenizer.from_pretrained("ibm/materials.selfies-ted")
|
77 |
+
self.model = AutoModel.from_pretrained("ibm/materials.selfies-ted")
|
78 |
+
self.model.eval()
|
79 |
+
|
80 |
+
def encode(self, smiles_list=[], use_gpu=False, return_tensor=False, batch_size=128, num_workers=4):
|
81 |
+
selfies = self.get_selfies(smiles_list)
|
82 |
+
dataset = SELFIESDataset(selfies)
|
83 |
+
|
84 |
+
device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
|
85 |
+
self.model.to(device)
|
86 |
+
|
87 |
+
loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
|
88 |
+
|
89 |
+
embeddings = []
|
90 |
+
for batch in tqdm(loader, desc="Encoding"):
|
91 |
+
emb = self.get_embedding_batch(batch)
|
92 |
+
embeddings.append(emb)
|
93 |
+
del emb
|
94 |
+
gc.collect()
|
95 |
+
|
96 |
+
emb = np.vstack(embeddings)
|
97 |
+
|
98 |
+
for idx in self.invalid:
|
99 |
+
emb[idx] = np.nan
|
100 |
+
print(f"Cannot encode {smiles_list[idx]} to selfies. Embedding replaced by NaN.")
|
101 |
+
|
102 |
+
return torch.tensor(emb) if return_tensor else pd.DataFrame(emb)
|
models/selfies_model/requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=2.1.0
|
2 |
+
transformers>=4.38
|
3 |
+
numpy>=1.26.1
|
4 |
+
datasets>=2.13.1
|
5 |
+
evaluate>=0.4.0
|
6 |
+
selfies>=2.1.0
|
7 |
+
scikit-learn>=1.2.1
|
8 |
+
pyarrow>=14.0.1
|
9 |
+
requests>=2.31.0
|
10 |
+
urllib3>=2.0.7
|
11 |
+
aiohttp>=3.9.0
|
12 |
+
zipp>=3.17.0
|