File size: 10,748 Bytes
076e7a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import os
import pickle
from lynxkite.core.ops import op
from matplotlib import pyplot as plt
import pandas as pd
from rdkit.Chem.Draw import rdMolDraw2D
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Descriptors
from rdkit.Chem import Crippen, Lipinski
from rdkit import DataStructs
import math
import io
from rdkit.Chem import AllChem
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
from sklearn.model_selection import train_test_split
import numpy as np


@op("LynxKite Graph Analytics", "View mol filter", view="matplotlib", slow=True)
def mol_filter(
    bundle,
    *,
    table_name: str,
    SMILES_Column: str,
    mols_per_row: int,
    filter_smarts: str = None,
    filter_smiles: str = None,
    highlight: bool = True,
):
    """
    Draw a grid of molecules, optionally filtered by and highlighted with a substructure.

    Each molecule is drawn with a legend showing its ``Name`` and ``pIC50``
    values (both columns must exist in the table) plus the computed molecular
    weight, Crippen logP and Lipinski H-bond donor/acceptor counts.

    Parameters:
    - bundle: data bundle containing a DataFrame in bundle.dfs[table_name]
    - table_name: name of the table in bundle.dfs
    - SMILES_Column: column containing SMILES strings; rows RDKit cannot parse
      are skipped
    - mols_per_row: number of molecules per row in the grid
    - filter_smarts: SMARTS pattern to filter and highlight
    - filter_smiles: SMILES substructure to filter and highlight (used only
      when filter_smarts is empty)
    - highlight: whether to highlight the first substructure match

    Raises:
    - ValueError: if the SMARTS/SMILES filter pattern cannot be parsed, or if
      no molecule passes the filter.
    """
    # get DataFrame; MolFromSmiles returns None for unparsable SMILES, and
    # those rows are dropped
    df = bundle.dfs[table_name].copy()
    df["mol"] = df[SMILES_Column].apply(Chem.MolFromSmiles)
    df = df[df["mol"].notnull()].reset_index(drop=True)

    # compile substructure query if provided. RDKit returns None for an
    # unparsable pattern; fail loudly instead of drawing the unfiltered set.
    query = None
    if filter_smarts:
        query = Chem.MolFromSmarts(filter_smarts)
        if query is None:
            raise ValueError(f"Invalid SMARTS pattern: {filter_smarts!r}")
    elif filter_smiles:
        query = Chem.MolFromSmiles(filter_smiles)
        if query is None:
            raise ValueError(f"Invalid SMILES pattern: {filter_smiles!r}")

    # compute properties shown in the legends
    df["MW"] = df["mol"].apply(Descriptors.MolWt)
    df["logP"] = df["mol"].apply(Crippen.MolLogP)
    df["HBD"] = df["mol"].apply(Lipinski.NumHDonors)
    df["HBA"] = df["mol"].apply(Lipinski.NumHAcceptors)

    entries = []
    for _, row in df.iterrows():
        mol = row["mol"]
        # filter by substructure
        if query is not None and not mol.HasSubstructMatch(query):
            continue

        atom_ids, bond_ids = [], []
        if highlight and query is not None:
            # only the first match returned by RDKit is highlighted
            atom_ids = list(mol.GetSubstructMatch(query))
            matched = set(atom_ids)
            # highlight every bond whose two ends are both matched atoms
            bond_ids = [
                bond.GetIdx()
                for bond in mol.GetBonds()
                if bond.GetBeginAtomIdx() in matched and bond.GetEndAtomIdx() in matched
            ]

        legend = (
            f"{row['Name']}  pIC50={row['pIC50']:.2f}\n"
            f"MW={row['MW']:.1f}, logP={row['logP']:.2f}\n"
            f"HBD={row['HBD']}, HBA={row['HBA']}"
        )
        entries.append((mol, legend, atom_ids, bond_ids))

    if not entries:
        raise ValueError("No molecules passed the filter.")

    # draw each filtered molecule into its own PNG, then decode it for the grid
    images = []
    for mol, legend, atom_ids, bond_ids in entries:
        drawer = rdMolDraw2D.MolDraw2DCairo(400, 350)
        opts = drawer.drawOptions()
        opts.legendFontSize = 200
        drawer.DrawMolecule(mol, legend=legend, highlightAtoms=atom_ids, highlightBonds=bond_ids)
        drawer.FinishDrawing()
        images.append(Image.open(io.BytesIO(drawer.GetDrawingText())))

    plot_gallery(images, num_cols=mols_per_row)


@op("LynxKite Graph Analytics", "Lipinski filter")
def lipinski_filter(bundle, *, table_name: str, column_name: str, strict_lipinski: bool = True):
    """
    Annotate, and optionally filter, molecules by Lipinski's rule of five.

    Parameters:
    - bundle: data bundle containing a DataFrame in bundle.dfs[table_name]
    - table_name: name of the table in bundle.dfs
    - column_name: column containing SMILES strings; rows RDKit cannot parse
      are dropped
    - strict_lipinski: if True, also drop molecules that fail any of the four
      criteria (a message listing their SMILES is printed)

    Returns:
    A copy of the table with added MW, logP, HBD, HBA and boolean
    pass_lipinski columns. Only this DataFrame is returned; the other tables
    of the input bundle are not part of the result.
    """
    # The input bundle is never modified, so only the table itself is copied.
    df = bundle.dfs[table_name].copy()
    df["mol"] = df[column_name].apply(Chem.MolFromSmiles)
    df = df[df["mol"].notnull()].reset_index(drop=True)

    # compute properties
    df["MW"] = df["mol"].apply(Descriptors.MolWt)
    df["logP"] = df["mol"].apply(Crippen.MolLogP)
    df["HBD"] = df["mol"].apply(Lipinski.NumHDonors)
    df["HBA"] = df["mol"].apply(Lipinski.NumHAcceptors)

    # pass only when all four criteria hold: MW <= 500, logP <= 5, HBD <= 5, HBA <= 10
    df["pass_lipinski"] = (
        (df["MW"] <= 500) & (df["logP"] <= 5) & (df["HBD"] <= 5) & (df["HBA"] <= 10)
    )
    # RDKit Mol objects are only needed for the descriptor computation
    df = df.drop("mol", axis=1)

    # if strict_lipinski, drop those that fail
    if strict_lipinski:
        failed = df.loc[~df["pass_lipinski"], column_name].tolist()
        df = df[df["pass_lipinski"]].reset_index(drop=True)
        if failed:
            print(f"Dropped {len(failed)} molecules that failed Lipinski: {failed}")

    return df


@op("LynxKite Graph Analytics", "View mol image", view="matplotlib", slow=True)
def mol_image(bundle, *, table_name: str, smiles_column: str, mols_per_row: int):
    """
    Render every parseable molecule of a table as a gallery of annotated 2D drawings.

    Parameters:
    - bundle: data bundle containing a DataFrame in bundle.dfs[table_name]
    - table_name: name of the table in bundle.dfs
    - smiles_column: column containing SMILES strings; rows RDKit cannot parse
      are skipped
    - mols_per_row: number of molecules per row in the grid

    Each caption shows the row's ``Name`` and ``pIC50`` values (both columns
    must exist) plus computed MW, Crippen logP and Lipinski HBD/HBA counts.

    Raises:
    - ValueError: if no molecule in the table could be parsed.
    """
    table = bundle.dfs[table_name].copy()
    table["mol"] = table[smiles_column].apply(Chem.MolFromSmiles)
    table = table[table["mol"].notnull()].reset_index(drop=True)

    # descriptor columns, added in this order
    descriptor_fns = (
        ("MW", Descriptors.MolWt),
        ("logP", Crippen.MolLogP),
        ("HBD", Lipinski.NumHDonors),
        ("HBA", Lipinski.NumHAcceptors),
    )
    for column, fn in descriptor_fns:
        table[column] = table["mol"].apply(fn)

    captions = [
        f"{rec['Name']}  pIC50={rec['pIC50']:.2f}\n"
        f"MW={rec['MW']:.1f}, logP={rec['logP']:.2f}\n"
        f"HBD={rec['HBD']}, HBA={rec['HBA']}"
        for _, rec in table.iterrows()
    ]

    molecules = table["mol"].tolist()
    if not molecules:
        raise ValueError("No valid molecules to draw.")

    def _render(molecule, caption):
        # Draw a single molecule to an in-memory PNG and decode it as a PIL image.
        canvas = rdMolDraw2D.MolDraw2DCairo(400, 350)
        canvas.drawOptions().legendFontSize = 200
        canvas.DrawMolecule(molecule, legend=caption)
        canvas.FinishDrawing()
        return Image.open(io.BytesIO(canvas.GetDrawingText()))

    gallery = [_render(molecule, caption) for molecule, caption in zip(molecules, captions)]
    plot_gallery(gallery, num_cols=mols_per_row)


def plot_gallery(images, num_cols):
    """
    Lay out images on a grid of matplotlib axes in a new figure.

    Parameters:
    - images: sequence of images accepted by ``Axes.imshow`` (the callers in
      this file pass PIL images)
    - num_cols: maximum number of images per row; must be at least 1

    Unused cells in the last row are left empty; tick marks are removed from
    every cell. Nothing is returned: the figure is left as the current pyplot
    figure.

    Raises:
    - ValueError: if num_cols < 1 or images is empty.
    """
    if num_cols < 1:
        raise ValueError(f"num_cols must be a positive integer, got {num_cols}")
    if len(images) == 0:
        raise ValueError("No images to plot.")

    num_rows = math.ceil(len(images) / num_cols)
    # squeeze=False always yields a 2-D array of Axes; without it a 1x1 grid
    # returns a bare Axes object, which has no .flat / .flatten().
    _fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(num_cols * 4, num_rows * 3.5), squeeze=False
    )
    for i, ax in enumerate(axes.flat):
        if i < len(images):
            ax.imshow(images[i])
        ax.set_xticks([])
        ax.set_yticks([])
    plt.tight_layout()


@op("LynxKite Graph Analytics", "Train QSAR model")
def build_qsar_model(
    bundle,
    *,
    table_name: str,
    smiles_col: str,
    target_col: str,
    fp_type: str,
    radius: int = 2,
    n_bits: int = 2048,
    test_size: float = 0.2,
    random_state: int = 42,
    out_dir: str = "Models",
):
    """
    Train and save a RandomForest QSAR model using one fingerprint type.

    Parameters
    ----------
    bundle : any
        An object with a dict-like attribute `.dfs` mapping table names to DataFrames.
    table_name : str
        Key into bundle.dfs to get the DataFrame.
    smiles_col : str
        Name of the column containing SMILES strings. Unparsable rows are dropped.
    target_col : str
        Name of the column containing the numeric response. Rows with a
        missing response are dropped.
    fp_type : str
        Fingerprint to compute: "ecfp", "rdkit", "torsion", "atompair", or "maccs".
    radius : int
        Radius for the Morgan (ECFP) fingerprint.
    n_bits : int
        Bit-vector length for all fp types except MACCS (always 167).
    test_size : float
        Fraction of data held out for testing.
    random_state : int
        Seed for both the train/test split and the RandomForest.
    out_dir : str
        Directory (created if needed) in which to save `qsar_model_<fp_type>.pkl`.

    Returns
    -------
    metrics_df : pandas.DataFrame
        One row per split ("train", "test") with R2, MAE and RMSE columns.
        The fitted model itself is not returned; it is pickled to
        `<out_dir>/qsar_model_<fp_type>.pkl`.

    Raises
    ------
    KeyError
        If `table_name` is not in bundle.dfs or `target_col` is not a column.
    ValueError
        If `fp_type` is unsupported or no usable rows remain.
    """
    # MACCSkeys is an rdkit.Chem submodule; `from rdkit import Chem` does not
    # guarantee it is loaded, so import it explicitly.
    from rdkit.Chem import MACCSkeys

    # 1) load and sanitize data
    df = bundle.dfs.get(table_name)
    if df is None:
        raise KeyError(f"Table '{table_name}' not found in bundle.dfs")
    if target_col not in df.columns:
        raise KeyError(f"Column '{target_col}' not found in table '{table_name}'")

    # Resolve the fingerprint function before any heavy work so that an
    # unsupported fp_type fails fast.
    fingerprinters = {
        "ecfp": lambda m: AllChem.GetMorganFingerprintAsBitVect(m, radius, nBits=n_bits),
        "rdkit": lambda m: Chem.RDKFingerprint(m, fpSize=n_bits),
        "torsion": lambda m: AllChem.GetHashedTopologicalTorsionFingerprintAsBitVect(m, nBits=n_bits),
        "atompair": lambda m: AllChem.GetHashedAtomPairFingerprintAsBitVect(m, nBits=n_bits),
        "maccs": MACCSkeys.GenMACCSKeys,  # 167 bits
    }
    fingerprint = fingerprinters.get(fp_type)
    if fingerprint is None:
        raise ValueError(f"Unsupported fingerprint type: '{fp_type}'")

    df = df.copy()
    df["mol"] = df[smiles_col].apply(Chem.MolFromSmiles)
    df = df[df["mol"].notnull()].reset_index(drop=True)
    if df.empty:
        raise ValueError(f"No valid molecules in '{smiles_col}'")
    # Rows without a response value can be used neither for fitting nor scoring.
    df = df[df[target_col].notnull()].reset_index(drop=True)
    if df.empty:
        raise ValueError(f"No non-missing values in '{target_col}' for valid molecules")

    # 2) create a fixed train/test split
    indices = np.arange(len(df))
    train_idx, test_idx = train_test_split(indices, test_size=test_size, random_state=random_state)

    # 3) featurize: one 0/1 int8 row per molecule
    def _to_array(bv):
        arr = np.zeros((bv.GetNumBits(),), dtype=np.int8)
        DataStructs.ConvertToNumpyArray(bv, arr)
        return arr

    X = np.vstack([_to_array(fingerprint(mol)) for mol in df["mol"]])
    y = df[target_col].values

    # 4) split features/labels
    X_train, y_train = X[train_idx], y[train_idx]
    X_test, y_test = X[test_idx], y[test_idx]

    # 5) train RandomForest
    model = RandomForestRegressor(random_state=random_state)
    model.fit(X_train, y_train)

    # 6) compute performance metrics
    def _metrics(y_true, y_pred):
        mse = mean_squared_error(y_true, y_pred)
        return {
            "R2": r2_score(y_true, y_pred),
            "MAE": mean_absolute_error(y_true, y_pred),
            "RMSE": np.sqrt(mse),
        }

    train_m = _metrics(y_train, model.predict(X_train))
    test_m = _metrics(y_test, model.predict(X_test))
    metrics_df = pd.DataFrame([{"split": "train", **train_m}, {"split": "test", **test_m}])

    # 7) save the model
    # NOTE: loading a pickle can execute arbitrary code; only load this file
    # from a trusted location.
    os.makedirs(out_dir, exist_ok=True)
    model_file = os.path.join(out_dir, f"qsar_model_{fp_type}.pkl")
    with open(model_file, "wb") as fout:
        pickle.dump(model, fout)

    print(f"Trained & saved QSAR model for '{fp_type}' → {model_file}")
    return metrics_df