abhik1368 commited on
Commit
4cf35ad
·
unverified ·
1 Parent(s): f407b8d

New tools and filters for cheminfo

Browse files

* Update Cheminformatics use cases

Add _cheminfo_tools.py with the Lipinski filter, View Mol Image, and View mol filter with SMARTS and SMILES; highlighting is done.

* update new filters and chembl webapi

update new filters and chembl webapi

veber, pains, muegge, brenk_aggregator_filter, egan , ghose , new qsar2.py code with matplotlib plots.

* update tools

update on chembl uniprot based search

* update the code

Delete the old files and folder

Put in example \ Cheminformatics folders

Chembl web service client with example
Plots with plot qsar and plot qsar2 with confidence intervals

* Update new code with new workspace

New workspace created deleted ex1 and ex2 .
Deleted the ecfp and maccs model .pkl file

examples/.crdt/Image table.lynxkite.json.crdt ADDED
Binary file (31.8 kB). View file
 
examples/.crdt/requirements.txt.crdt ADDED
Binary file (251 Bytes). View file
 
examples/Cheminformatics/chem_utils.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import sys
4
+ from io import StringIO
5
+ from operator import itemgetter
6
+ from typing import List
7
+ from typing import Tuple
8
+ import itertools
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ import seaborn as sns
12
+ from rdkit import Chem, DataStructs, RDLogger
13
+ from rdkit.Chem.Draw import rdMolDraw2D
14
+ from rdkit.Chem.rdchem import Mol
15
+ from rdkit.ML.Cluster import Butina
16
+ from rdkit.rdBase import BlockLogs
17
+
18
+ import pandas as pd
19
+ from rdkit.Chem.rdMMPA import FragmentMol
20
+ from rdkit.Chem.rdRGroupDecomposition import RGroupDecompose
21
+
22
+
23
def smi2mol_with_errors(smi: str) -> Tuple[Mol, str]:
    """Parse SMILES and return any associated errors or warnings

    ``sys.stderr`` is temporarily replaced by an in-memory buffer while the
    SMILES is parsed, and whatever was written to it is returned.

    NOTE(review): this only captures text that reaches Python's ``sys.stderr``;
    whether RDKit's logger is routed there is configured elsewhere - confirm.

    :param smi: input SMILES
    :return: tuple of RDKit molecule (as returned by ``Chem.MolFromSmiles``),
        captured warning or error text
    """
    # Remember the stream that is active right now (it may already be a
    # redirect, e.g. in a notebook) so it can be put back afterwards instead
    # of unconditionally falling back to sys.__stderr__.
    previous_stderr = sys.stderr
    capture = sys.stderr = StringIO()
    try:
        mol = Chem.MolFromSmiles(smi)
    finally:
        # Always restore, even if parsing raised.
        sys.stderr = previous_stderr
    return mol, capture.getvalue()
35
+
36
+
37
def count_fragments(mol: Mol) -> int:
    """Count the number of fragments in a molecule

    :param mol: RDKit molecule
    :return: number of fragments
    """
    # Only the count is needed, so ask for the per-fragment atom-index tuples
    # rather than building a separate molecule for every fragment.
    return len(Chem.GetMolFrags(mol))
44
+
45
+
46
def get_largest_fragment(mol: Mol) -> Mol:
    """Return the fragment with the largest number of atoms

    When several fragments share the largest atom count, the first one in the
    order returned by ``Chem.GetMolFrags`` is returned.

    :param mol: RDKit molecule
    :return: RDKit molecule with the largest number of atoms
    :raises IndexError: if the molecule has no fragments
    """
    fragments = Chem.GetMolFrags(mol, asMols=True)
    if not fragments:
        raise IndexError("molecule has no fragments")
    # A single pass is enough; max() keeps the first of equally sized fragments.
    return max(fragments, key=lambda frag: frag.GetNumAtoms())
56
+
57
+
58
+ # ----------- Clustering
59
+ # https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GroupShuffleSplit.html
60
def taylor_butina_clustering(
    fp_list: List[DataStructs.ExplicitBitVect], cutoff: float = 0.65
) -> List[int]:
    """Cluster fingerprints with RDKit's Taylor-Butina implementation

    :param fp_list: a list of fingerprints
    :param cutoff: distance cutoff (1 - Tanimoto similarity)
    :return: a list with one cluster id per input fingerprint, in input order
    """
    n_points = len(fp_list)
    # Flattened lower triangle of the distance matrix: each fingerprint is
    # compared with all fingerprints that precede it.
    distances = []
    for row in range(1, n_points):
        similarities = DataStructs.BulkTanimotoSimilarity(fp_list[row], fp_list[:row])
        distances += [1 - sim for sim in similarities]
    clusters = Butina.ClusterData(distances, n_points, cutoff, isDistData=True)
    # Turn the tuple-of-member-tuples result into a per-item label array.
    labels = np.zeros(n_points, dtype=int)
    for cluster_id, members in enumerate(clusters):
        for item in members:
            labels[item] = cluster_id
    return labels.tolist()
80
+
81
+
82
+ # ----------- Atom tagging
83
def label_atoms(mol: Mol, labels: List[str]) -> Mol:
    """Label atoms when depicting a molecule

    Sets the ``atomNote`` property of every atom to the label at the atom's
    index. The molecule is modified in place and also returned.

    :param mol: input molecule
    :param labels: labels, one for each atom (indexed by atom index)
    :return: molecule with labels
    :raises IndexError: if ``labels`` has no entry for some atom index
    """
    # Every atom's note is overwritten below, so no separate clearing pass is
    # needed.
    for atom in mol.GetAtoms():
        atom.SetProp("atomNote", f"{labels[atom.GetIdx()]}")
    return mol
95
+
96
+
97
def tag_atoms(mol: Mol, atoms_to_tag: List[int], tag: str = "x") -> Mol:
    """Tag atoms with a specified string

    Clears the ``atomNote`` property on every atom, then sets it to ``tag`` on
    the requested atoms. The molecule is modified in place and also returned.

    :param mol: input molecule
    :param atoms_to_tag: indices of atoms to tag
    :param tag: string to use for the tags
    :return: molecule with atoms tagged
    """
    # Plain loops: these calls are made for their side effect only.
    for atom in mol.GetAtoms():
        atom.SetProp("atomNote", "")
    for idx in atoms_to_tag:
        mol.GetAtomWithIdx(idx).SetProp("atomNote", tag)
    return mol
108
+
109
+
110
+ # ----------- Logging
111
def rd_shut_the_hell_up() -> None:
    """Raise the RDKit logger's level to CRITICAL to quieten it

    :return: None
    """
    RDLogger.logger().setLevel(RDLogger.CRITICAL)
118
+
119
+
120
def demo_block_logs() -> None:
    """Show the ``BlockLogs`` way of turning off RDKit logging

    A ``BlockLogs`` object is created, work would be done while it exists,
    and the object is then deleted.

    :return: None
    """
    log_blocker = BlockLogs()
    # ... code that should run quietly goes here ...
    del log_blocker
128
+
129
+
130
+ # ----------- Image generation
131
def boxplot_base64_image(dist: np.ndarray, x_lim: list[int] = [0, 10]) -> str:
    """
    Plot a distribution as a seaborn boxplot and return it as an HTML <img> tag
    whose source is the base64 encoded PNG.

    Parameters:
    dist (np.ndarray): The distribution data to plot.
    x_lim (list[int]): The x-axis limits for the boxplot; only elements 0 and 1
        are read and the argument is never modified.

    Returns:
    str: An ``<img>`` tag with the base64 encoded image embedded as a data URI.
    """
    # These calls change seaborn/matplotlib global settings, as before.
    sns.set(rc={"figure.figsize": (3, 1)})
    sns.set_style("whitegrid")
    # Draw on a figure created for this call so that a figure the caller has
    # open is neither drawn into nor closed.
    fig, ax = plt.subplots(figsize=(3, 1))
    try:
        sns.boxplot(x=dist, ax=ax)
        ax.set_xlim(x_lim[0], x_lim[1])
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f'<img align="left" src="data:image/png;base64,{encoded}">'
151
+
152
+
153
def mol_to_base64_image(mol: Chem.Mol) -> str:
    """
    Draw an RDKit molecule on a 300x150 Cairo canvas and return it as an HTML
    <img> tag with the PNG embedded as base64.

    Parameters:
    mol (Chem.Mol): The RDKit molecule to convert.

    Returns:
    str: An ``<img>`` tag containing the base64 encoded image.
    """
    canvas = rdMolDraw2D.MolDraw2DCairo(300, 150)
    canvas.DrawMolecule(mol)
    canvas.FinishDrawing()
    png_b64 = base64.b64encode(canvas.GetDrawingText()).decode("utf8")
    return f"<img src='data:image/png;base64, {png_b64}'/>"
170
+
171
+
172
def cleanup_fragment(mol: Mol) -> Tuple[Mol, int]:
    """
    Clear atom map numbers and turn dummy atoms into hydrogens

    Every atom has its atom map number reset to 0. Atoms with atomic number 0
    are counted as R-groups and converted to hydrogen; all hydrogens are then
    removed from the molecule.
    :param mol: input molecule (its atoms are modified in place)
    :return: modified molecule, number of R-groups
    """
    dummy_atoms = 0
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(0)
        if atom.GetAtomicNum() != 0:
            continue
        # Atomic number 0 marks an attachment point; count it and make it an H.
        dummy_atoms += 1
        atom.SetAtomicNum(1)
    return Chem.RemoveAllHs(mol), dummy_atoms
186
+
187
+
188
def generate_fragments(mol: Mol) -> pd.DataFrame:
    """
    Generate candidate scaffold fragments for a molecule using the RDKit
    :param mol: RDKit molecule
    :return: a Pandas dataframe with columns Scaffold (SMILES), NumAtoms and
        NumRgroupgs (number of R-groups), one row per unique scaffold SMILES;
        the input molecule itself is always included
    """
    fragment_sets = FragmentMol(mol)
    total_atoms = mol.GetNumAtoms()
    rows = []
    # FragmentMol yields groups of molecules; walk them as one flat stream and
    # ignore empty entries.
    for piece in itertools.chain.from_iterable(fragment_sets):
        if not piece:
            continue
        # Each entry may hold several disconnected parts; keep the biggest.
        largest = get_largest_fragment(piece)
        # Keep only fragments with more than 0.67 of the input's atom count.
        if not largest.GetNumAtoms() / total_atoms > 0.67:
            continue
        # Strip atom map numbers / attachment points before writing SMILES.
        cleaned, n_rgroups = cleanup_fragment(largest)
        rows.append([Chem.MolToSmiles(cleaned), cleaned.GetNumAtoms(), n_rgroups])
    # The whole molecule is always a candidate as well.
    rows.append([Chem.MolToSmiles(mol), total_atoms, 1])
    frame = pd.DataFrame(rows, columns=["Scaffold", "NumAtoms", "NumRgroupgs"])
    return frame.drop_duplicates("Scaffold")
215
+
216
+
217
def find_scaffolds(df_in: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Generate scaffolds for a set of molecules
    :param df_in: Pandas dataframe with [SMILES, Name, mol] columns, where mol
        holds the RDKit molecule
    :return: dataframe with one row per (molecule, scaffold) pair, dataframe
        with unique scaffolds sorted by Count and then NumAtoms, both descending
    """
    # One fragment table per input molecule, tagged with the molecule's
    # identity.
    per_molecule = []
    for smiles, name, mol in df_in[["SMILES", "Name", "mol"]].values:
        per_molecule.append(generate_fragments(mol).assign(Name=name, SMILES=smiles))
    mol_df = pd.concat(per_molecule)
    # For each scaffold: how many distinct molecules contain it, and its size.
    summary_rows = [
        [scaffold, len(group.Name.unique()), group.NumAtoms.values[0]]
        for scaffold, group in mol_df.groupby("Scaffold")
    ]
    scaffold_df = pd.DataFrame(summary_rows, columns=["Scaffold", "Count", "NumAtoms"])
    # Keep only scaffolds whose count does not exceed the number of input rows.
    num_df_rows = len(df_in)
    scaffold_df = scaffold_df.query(f"Count <= {num_df_rows}")
    scaffold_df = scaffold_df.sort_values(["Count", "NumAtoms"], ascending=[False, False])
    return mol_df, scaffold_df
243
+
244
+
245
def get_molecules_with_scaffold(
    scaffold: str, mol_df: pd.DataFrame, activity_df: pd.DataFrame
) -> Tuple[List[str], pd.DataFrame]:
    """
    Associate molecules with scaffolds
    :param scaffold: scaffold SMILES
    :param mol_df: dataframe with molecules and scaffolds, returned by find_scaffolds()
    :param activity_df: dataframe with [SMILES, Name, pIC50] columns; a "mol"
        column of RDKit molecules is used when present, otherwise molecules
        are parsed from the SMILES column
    :return: unique core(s) with R-groups labeled (a numpy array from
        ``Series.unique()`` when there are matches, otherwise an empty list),
        dataframe with [SMILES, Name, pIC50]
    """
    match_df = mol_df.query("Scaffold == @scaffold")
    merge_df = match_df.merge(activity_df, on=["SMILES", "Name"])
    # The documented inputs carry no "mol" column, so only rely on one when it
    # is actually there; otherwise rebuild the molecules from their SMILES.
    if "mol" in merge_df.columns:
        mols = merge_df["mol"]
    else:
        mols = [Chem.MolFromSmiles(smi) for smi in merge_df["SMILES"]]
    scaffold_mol = Chem.MolFromSmiles(scaffold)
    rgroup_match, _unmatched = RGroupDecompose(scaffold_mol, mols, asSmiles=True)
    activity_cols = merge_df[["SMILES", "Name", "pIC50"]]
    if len(rgroup_match):
        return pd.DataFrame(rgroup_match).Core.unique(), activity_cols
    return [], activity_cols
examples/Cheminformatics/chembl_api_uses.lynxkite.json ADDED
The diff for this file is too large to render. See raw diff
 
examples/Cheminformatics/chembl_tools.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lynxkite.core.ops import op
2
+ import pandas as pd
3
+ from chembl_webresource_client.new_client import new_client
4
+ from rdkit import Chem
5
+
6
+
7
@op("LynxKite Graph Analytics", "chembl sim search")
def similarity_to_dataframe(*, smiles: str, cutoff: int = 70) -> pd.DataFrame:
    """
    Query the ChEMBL similarity endpoint and tabulate the hits.

    Problems are reported with ``print`` rather than raised: an unparsable
    SMILES or a failed query yields an empty DataFrame with the expected
    columns.

    Parameters
    ----------
    smiles : str
        The SMILES string to search on.
    cutoff : int
        Similarity threshold passed to the ChEMBL query as ``similarity``.

    Returns
    -------
    pd.DataFrame
        Columns: 'molecule_chembl_id', 'similarity'
    """
    columns = ["molecule_chembl_id", "similarity"]

    # Reject SMILES that RDKit cannot parse before contacting the service.
    if Chem.MolFromSmiles(smiles) is None:
        print("Please input a correct SMILES string.")
        return pd.DataFrame(columns=columns)

    try:
        hits = new_client.similarity.filter(smiles=smiles, similarity=cutoff).only(columns)
        # list() forces the query to be evaluated here, inside the try block.
        frame = pd.DataFrame.from_records(list(hits), columns=columns)
    except Exception as exc:
        # Best-effort op: network errors, unexpected API replies, etc. are
        # reported and an empty table is returned.
        print("An error occurred during the similarity search.")
        print(" Details:", str(exc))
        return pd.DataFrame(columns=columns)

    if frame.empty:
        print("No hits found for that SMILES at the given cutoff.")
    return frame
54
+
55
+
56
@op("LynxKite Graph Analytics", "chembl structure")
def _chembl_structures(
    df: pd.DataFrame, *, id_col: str = "molecule_chembl_id", timeout: int = 5
) -> pd.DataFrame:
    """
    Look up each ChEMBL molecule ID in ``df`` and add its structure fields.

    One ChEMBL query is made per row. Rows without a structure, or whose
    lookup fails, keep ``None`` in the new columns and a message is printed.

    Parameters
    ----------
    df : pd.DataFrame
        Input DataFrame; must contain `id_col`. It is not modified.
    id_col : str
        Name of the column in `df` that holds ChEMBL IDs (e.g. 'CHEMBL1234').
    timeout : int
        Currently unused by this function.

    Returns
    -------
    pd.DataFrame
        A copy of `df` with three additional columns:
          - smiles
          - standard_inchi
          - standard_inchi_key
    """
    # Output column -> key inside the record's "molecule_structures" mapping.
    structure_fields = {
        "smiles": "canonical_smiles",
        "standard_inchi": "standard_inchi",
        "standard_inchi_key": "standard_inchi_key",
    }
    result = df.copy()
    for column in structure_fields:
        result[column] = None

    molecules = new_client.molecule

    for row_label, chembl_id in result[id_col].items():
        try:
            query = molecules.filter(chembl_id=chembl_id).only(
                ["molecule_chembl_id", "molecule_structures"]
            )
            # Take the first record of the result, if any.
            record = next(iter(query), None)
            structures = record.get("molecule_structures") if record else None
            if not structures:
                print(f"[Warning] No structure found for {chembl_id}")
                continue
            for column, key in structure_fields.items():
                result.at[row_label, column] = structures.get(key)
        except Exception as exc:
            # A failed lookup must not abort the remaining rows.
            print(f"[Error] Lookup failed for {chembl_id}: {exc!s}")

    return result
110
+
111
+
112
@op("LynxKite Graph Analytics", "get chembl drugs")
def fetch_chembl_drugs(
    *, first_approval: int = 2000, development_phase: int = None
) -> pd.DataFrame:
    """
    Fetch drugs from ChEMBL filtered by first-approval year and, optionally,
    development phase, returning key fields as a DataFrame.

    Invalid arguments and query failures are reported with ``print`` and
    produce an empty DataFrame instead of an exception.

    Parameters
    ----------
    first_approval : int, optional
        Only include drugs first approved in or after this year (default=2000).
        If None, do not filter by approval year.
    development_phase : int, optional
        Only include drugs in this development phase (e.g. 2, 3, 4).
        If None, do not filter by phase.

    Returns
    -------
    pd.DataFrame
        Columns:
          - development_phase
          - first_approval
          - molecule_chembl_id
          - synonyms
          - usan_stem
          - usan_stem_definition
          - usan_year

        If no results (or on error), returns an empty DataFrame with these columns.
    """
    cols = [
        "development_phase",
        "first_approval",
        "molecule_chembl_id",
        "synonyms",
        "usan_stem",
        "usan_stem_definition",
        "usan_year",
    ]
    empty_df = pd.DataFrame(columns=cols)

    # Validate inputs
    if first_approval is not None and not isinstance(first_approval, int):
        print("Error: first_approval must be an integer year.")
        return empty_df
    if development_phase is not None and not isinstance(development_phase, int):
        print("Error: development_phase must be an integer.")
        return empty_df

    try:
        drug = new_client.drug
        # Each filter is only applied when its argument was supplied.
        if first_approval is not None:
            drug = drug.filter(first_approval__gte=first_approval)
        if development_phase is not None:
            drug = drug.filter(development_phase=development_phase)
        df = pd.DataFrame(drug.only(cols), columns=cols)
    except Exception as e:
        # Best-effort op: report the failure and hand back an empty table.
        print("An error occurred during the ChEMBL query:")
        print(" ", str(e))
        return empty_df

    if df.empty:
        print("No drugs found for those filters.")
    return df
189
+
190
+
191
@op("LynxKite Graph Analytics", "get bioactivity from uniprot")
def fetch_chembl_bioactivity(*, uniprot_id: str = "Q9NZQ7"):
    """
    Fetch bioactivity data from ChEMBL for a given UniProt ID.

    The first ChEMBL target whose components include the accession is used,
    and its activities with standard type IC50, Ki or Kd are returned.

    Parameters
    ----------
    uniprot_id : str
        UniProt accession used to look up the ChEMBL target.

    Returns
    -------
    pd.DataFrame
        One row per activity record; an empty DataFrame when no target matches
        the accession, so the result is a table in every case.
    """
    targets = list(new_client.target.filter(target_components__accession=uniprot_id))
    if not targets:
        # Keep the return type consistent with the success path.
        return pd.DataFrame()

    target_chembl_id = targets[0]["target_chembl_id"]
    activities = new_client.activity.filter(
        target_chembl_id=target_chembl_id, standard_type__in=["IC50", "Ki", "Kd"]
    )
    return pd.DataFrame(activities)
examples/Cheminformatics/cheminfo_tools.py CHANGED
@@ -16,6 +16,7 @@ from sklearn.ensemble import RandomForestRegressor
16
  from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
17
  from sklearn.model_selection import train_test_split
18
  import numpy as np
 
19
 
20
 
21
  @op("LynxKite Graph Analytics", "View mol filter", view="matplotlib", slow=True)
@@ -303,3 +304,612 @@ def build_qsar_model(
303
 
304
  print(f"Trained & saved QSAR model for '{fp_type}' → {model_file}")
305
  return metrics_df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error
17
  from sklearn.model_selection import train_test_split
18
  import numpy as np
19
+ from rdkit.Chem import MACCSkeys
20
 
21
 
22
  @op("LynxKite Graph Analytics", "View mol filter", view="matplotlib", slow=True)
 
304
 
305
  print(f"Trained & saved QSAR model for '{fp_type}' → {model_file}")
306
  return metrics_df
307
+
308
+
309
def predict_with_ci(model, X, confidence=0.95):
    """
    Predict with an ensemble and derive an interval from its members.

    Every estimator in ``model.estimators_`` predicts ``X``; the mean over
    estimators is the prediction, and the lower/upper bounds are the
    percentiles of the per-estimator predictions that leave
    ``(1 - confidence) / 2`` of them in each tail.

    Parameters
    ----------
    model : object
        Fitted ensemble exposing ``estimators_``, each with ``predict(X)``
        (e.g. a RandomForestRegressor).
    X : array-like
        Samples to predict.
    confidence : float
        Central fraction of per-estimator predictions covered by the
        interval; must lie in [0, 1].

    Returns
    -------
    tuple of np.ndarray
        ``(mean, lower, upper)``, one value per sample.

    Raises
    ------
    ValueError
        If ``confidence`` is outside [0, 1].
    """
    if not 0.0 <= confidence <= 1.0:
        raise ValueError(f"confidence must be between 0 and 1, got {confidence!r}")
    # Rows: estimators; columns: samples.
    per_tree = np.array([tree.predict(X) for tree in model.estimators_])
    tail = (1.0 - confidence) / 2.0
    # Both bounds in a single percentile pass over the estimator axis.
    lower, upper = np.percentile(per_tree, [tail * 100, (1.0 - tail) * 100], axis=0)
    return np.mean(per_tree, axis=0), lower, upper
325
+
326
+
327
+ # --- End of predict_with_ci definition ---
328
+
329
+
330
@op("LynxKite Graph Analytics", "Train QSAR2")
def build_qsar_model2(
    df: pd.DataFrame,
    *,
    smiles_col: str,
    target_col: str,
    fp_type: str,
    radius: int = 2,
    n_bits: int = 2048,
    test_size: float = 0.2,
    random_state: int = 42,
    out_dir: str = "Models",
    confidence: float = 0.95,
):
    """
    Train and save a RandomForest QSAR model and return per-point results.

    Rows with a missing/non-numeric target or an unparsable SMILES are
    dropped, the remaining molecules are split into train and test sets,
    fingerprinted, and used to fit a ``RandomForestRegressor``. The fitted
    model is pickled to ``<out_dir>/qsar_model_<fp_type>.pkl``; a failure to
    save is printed and does not prevent the results from being returned.

    Parameters
    ----------
    df : pd.DataFrame
        Input table; it is copied, not modified.
    smiles_col : str
        Column holding SMILES strings.
    target_col : str
        Column holding the value to predict (coerced to numeric).
    fp_type : str
        One of 'ecfp', 'rdkit', 'torsion', 'atompair', 'maccs'.
    radius : int
        Radius for the 'ecfp' fingerprint.
    n_bits : int
        Fingerprint length (ignored for 'maccs', which uses 167 bits).
    test_size : float
        Passed to ``train_test_split``.
    random_state : int
        Seed for the split and the random forest.
    out_dir : str
        Directory the pickled model is written to (created if missing).
    confidence : float, optional
        Passed to ``predict_with_ci`` for the interval columns.

    Returns
    -------
    pandas.DataFrame
        Columns: 'actual', 'predicted', 'lower_ci', 'upper_ci', 'split',
        'split_R2', 'split_MAE', 'split_RMSE'. The metric columns repeat the
        overall metric of the row's split ('train' or 'test').

    Raises
    ------
    KeyError
        If ``df`` is None.
    ValueError
        If ``fp_type`` is unsupported, no usable rows remain, no molecule can
        be featurized, or either split ends up empty.
    """
    # Check for a missing table before touching it.
    if df is None:
        raise KeyError("Table not found")

    # Fingerprint name -> (function building the bit vector, vector length).
    fingerprinters = {
        "ecfp": (
            lambda m: AllChem.GetMorganFingerprintAsBitVect(m, radius, nBits=n_bits),
            n_bits,
        ),
        "rdkit": (lambda m: Chem.RDKFingerprint(m, fpSize=n_bits), n_bits),
        "torsion": (
            lambda m: AllChem.GetHashedTopologicalTorsionFingerprintAsBitVect(m, nBits=n_bits),
            n_bits,
        ),
        "atompair": (
            lambda m: AllChem.GetHashedAtomPairFingerprintAsBitVect(m, nBits=n_bits),
            n_bits,
        ),
        "maccs": (lambda m: MACCSkeys.GenMACCSKeys(m), 167),
    }
    # Fail once, up front, instead of warning for every molecule and then
    # reporting that nothing could be featurized.
    if fp_type not in fingerprinters:
        raise ValueError(f"Unsupported fp type: '{fp_type}'")
    make_fingerprint, fp_length = fingerprinters[fp_type]

    # 1) Clean the data: numeric target, parsable SMILES.
    df = df.copy()
    df[target_col] = pd.to_numeric(df[target_col], errors="coerce")
    df.dropna(subset=[target_col, smiles_col], inplace=True)
    df["mol"] = df[smiles_col].apply(Chem.MolFromSmiles)
    df = df[df["mol"].notnull()].reset_index(drop=True)
    if df.empty:
        raise ValueError("No valid molecules or targets")

    # 2) Split row positions; sets give constant-time membership tests below.
    train_idx, test_idx = train_test_split(
        np.arange(len(df)), test_size=test_size, random_state=random_state
    )
    train_rows = {int(i) for i in train_idx}
    test_rows = {int(i) for i in test_idx}

    # 3) Featurize; a molecule that fails is skipped, not fatal.
    print(f"Featurizing using {fp_type}...")
    fps = []
    kept_rows = []
    for row, mol in enumerate(df["mol"]):
        try:
            arr = np.zeros((fp_length,), dtype=np.int8)
            DataStructs.ConvertToNumpyArray(make_fingerprint(mol), arr)
        except Exception as e:
            print(f"Warning: Featurization failed index {row}. Skipping. Error: {e}")
            continue
        fps.append(arr)
        kept_rows.append(row)
    if not fps:
        raise ValueError("No molecules featurized.")
    X = np.vstack(fps)
    y = df.iloc[kept_rows][target_col].values

    # 4) Map the original split onto the rows that survived featurization.
    train_pos = [pos for pos, row in enumerate(kept_rows) if row in train_rows]
    test_pos = [pos for pos, row in enumerate(kept_rows) if row in test_rows]
    X_train, y_train = X[train_pos], y[train_pos]
    X_test, y_test = X[test_pos], y[test_pos]
    if X_train.shape[0] == 0 or X_test.shape[0] == 0:
        raise ValueError("Train or test split empty after filtering.")

    # 5) Train.
    print("Training RandomForestRegressor...")
    model = RandomForestRegressor(random_state=random_state, n_jobs=-1)
    model.fit(X_train, y_train)

    # 6) Predictions, intervals and summary metrics per split.
    print("Calculating predictions and metrics...")

    def _split_results(split_name, X_part, y_true):
        # Per-point results for one split, with its metrics repeated per row.
        predicted, lower, upper = predict_with_ci(model, X_part, confidence)
        frame = pd.DataFrame(
            {
                "actual": y_true,
                "predicted": predicted,
                "lower_ci": lower,
                "upper_ci": upper,
                "split": split_name,
            }
        )
        frame["split_R2"] = r2_score(y_true, predicted)
        frame["split_MAE"] = mean_absolute_error(y_true, predicted)
        frame["split_RMSE"] = np.sqrt(mean_squared_error(y_true, predicted))
        return frame

    results_df = pd.concat(
        [
            _split_results("train", X_train, y_train),
            _split_results("test", X_test, y_test),
        ],
        ignore_index=True,
    )

    # 7) Save the model; saving is best-effort so the results are still returned.
    os.makedirs(out_dir, exist_ok=True)
    model_file = os.path.join(out_dir, f"qsar_model_{fp_type}.pkl")
    try:
        with open(model_file, "wb") as fout:
            pickle.dump(model, fout)
        print(f"Trained & saved QSAR model for '{fp_type}' -> {model_file}")
    except Exception as e:
        print(f"Error saving model to {model_file}: {e}")

    return results_df
510
+
511
+
512
+ @op("LynxKite Graph Analytics", "plot qsar", view="matplotlib")
513
+ def plot_qsar(results_df: pd.DataFrame):
514
+ """
515
+ Plots actual vs. predicted values from a QSAR results DataFrame.
516
+
517
+ Requires a single positional argument: the results DataFrame. All other
518
+ parameters are optional keyword arguments. It extracts summary metrics
519
+ directly from columns ('split_R2', 'split_MAE', 'split_RMSE')
520
+ expected within the results_df.
521
+ """
522
+ title = "QSAR Model Performance: Actual vs. Predicted"
523
+ xlabel = "Actual Values"
524
+ ylabel = "Predicted Values"
525
+ show_metrics = True
526
+
527
+ if not isinstance(results_df, pd.DataFrame):
528
+ raise TypeError(
529
+ "plot_qsar() missing 1 required positional argument: 'results_df' or the provided argument is not a pandas DataFrame."
530
+ )
531
+
532
+ required_cols = ["actual", "predicted", "lower_ci", "upper_ci", "split"]
533
+ if not all(col in results_df.columns for col in required_cols):
534
+ raise ValueError(f"Invalid 'results_df'. Must contain columns: {required_cols}")
535
+
536
+ metric_cols = ["split_R2", "split_MAE", "split_RMSE"]
537
+ metrics_available = all(col in results_df.columns for col in metric_cols)
538
+ if show_metrics and not metrics_available:
539
+ print(
540
+ f"Warning: Metrics display requested, but one or more metric columns ({metric_cols}) are missing in results_df."
541
+ )
542
+
543
+ # --- Prepare Data ---
544
+ train_data = results_df[results_df["split"] == "train"]
545
+ test_data = results_df[results_df["split"] == "test"]
546
+ can_plot_train = not train_data.empty
547
+ can_plot_test = not test_data.empty
548
+
549
+ if not can_plot_train and not can_plot_test:
550
+ print("Warning: Both training and test data subsets are empty. Cannot generate plot.")
551
+ return # Exit function early if no data
552
+
553
+ # --- Create Plot (Internal Figure/Axes) ---
554
+ fig, ax = plt.subplots(figsize=(8, 8))
555
+
556
+ # --- Plotting Logic ---
557
+ # (Draws scatter, error bars, line, grid, labels, title, legend on 'ax')
558
+ if can_plot_train:
559
+ train_error = [
560
+ train_data["predicted"] - train_data["lower_ci"],
561
+ train_data["upper_ci"] - train_data["predicted"],
562
+ ]
563
+ ax.scatter(
564
+ train_data["actual"],
565
+ train_data["predicted"],
566
+ label="Train",
567
+ alpha=0.6,
568
+ s=30,
569
+ edgecolors="w",
570
+ linewidth=0.5,
571
+ )
572
+ ax.errorbar(
573
+ train_data["actual"],
574
+ train_data["predicted"],
575
+ yerr=train_error,
576
+ fmt="none",
577
+ ecolor="tab:blue",
578
+ label="_nolegend_",
579
+ capsize=0,
580
+ elinewidth=1,
581
+ )
582
+
583
+ if can_plot_test:
584
+ test_error = [
585
+ test_data["predicted"] - test_data["lower_ci"],
586
+ test_data["upper_ci"] - test_data["predicted"],
587
+ ]
588
+ ax.scatter(
589
+ test_data["actual"],
590
+ test_data["predicted"],
591
+ label="Test",
592
+ alpha=0.8,
593
+ s=40,
594
+ edgecolors="w",
595
+ linewidth=0.5,
596
+ )
597
+ ax.errorbar(
598
+ test_data["actual"],
599
+ test_data["predicted"],
600
+ yerr=test_error,
601
+ fmt="none",
602
+ ecolor="tab:orange",
603
+ label="_nolegend_",
604
+ capsize=0,
605
+ elinewidth=1,
606
+ )
607
+
608
+ all_actual = results_df["actual"].dropna()
609
+ all_pred_ci = pd.concat(
610
+ [results_df["predicted"], results_df["lower_ci"], results_df["upper_ci"]]
611
+ ).dropna()
612
+ all_values = pd.concat([all_actual, all_pred_ci]).dropna()
613
+ if all_values.empty:
614
+ min_val, max_val = 0, 1
615
+ else:
616
+ min_val, max_val = all_values.min(), all_values.max()
617
+ if min_val == max_val:
618
+ min_val -= 0.5
619
+ max_val += 0.5
620
+ padding = (max_val - min_val) * 0.05
621
+ min_val -= padding
622
+ max_val += padding
623
+ ax.plot([min_val, max_val], [min_val, max_val], "k--", alpha=0.7, lw=1, label="y=x")
624
+ ax.set_xlim(min_val, max_val)
625
+ ax.set_ylim(min_val, max_val)
626
+ ax.set_aspect("equal", adjustable="box")
627
+ ax.grid(True, linestyle=":", alpha=0.6)
628
+ ax.set_xlabel(xlabel)
629
+ ax.set_ylabel(ylabel)
630
+ ax.set_title(title)
631
+ ax.legend(loc="lower right")
632
+
633
+ # --- Display Metrics Text ---
634
+ if show_metrics and metrics_available:
635
+ # (Logic for extracting and formatting metrics text remains the same)
636
+ metrics_text = ""
637
+ try:
638
+ if can_plot_train:
639
+ train_metrics = train_data[metric_cols].iloc[0]
640
+ r2_tr = (
641
+ f"{train_metrics['split_R2']:.3f}"
642
+ if pd.notna(train_metrics["split_R2"])
643
+ else "N/A"
644
+ )
645
+ mae_tr = (
646
+ f"{train_metrics['split_MAE']:.3f}"
647
+ if pd.notna(train_metrics["split_MAE"])
648
+ else "N/A"
649
+ )
650
+ rmse_tr = (
651
+ f"{train_metrics['split_RMSE']:.3f}"
652
+ if pd.notna(train_metrics["split_RMSE"])
653
+ else "N/A"
654
+ )
655
+ metrics_text += f"Train: $R^2$={r2_tr}, MAE={mae_tr}, RMSE={rmse_tr}\n"
656
+ else:
657
+ metrics_text += "Train: N/A (No Data)\n"
658
+ if can_plot_test:
659
+ test_metrics = test_data[metric_cols].iloc[0]
660
+ r2_te = (
661
+ f"{test_metrics['split_R2']:.3f}"
662
+ if pd.notna(test_metrics["split_R2"])
663
+ else "N/A"
664
+ )
665
+ mae_te = (
666
+ f"{test_metrics['split_MAE']:.3f}"
667
+ if pd.notna(test_metrics["split_MAE"])
668
+ else "N/A"
669
+ )
670
+ rmse_te = (
671
+ f"{test_metrics['split_RMSE']:.3f}"
672
+ if pd.notna(test_metrics["split_RMSE"])
673
+ else "N/A"
674
+ )
675
+ metrics_text += f"Test: $R^2$={r2_te}, MAE={mae_te}, RMSE={rmse_te}"
676
+ else:
677
+ metrics_text += "Test: N/A (No Data)"
678
+ if metrics_text:
679
+ ax.text(
680
+ 0.05,
681
+ 0.95,
682
+ metrics_text.strip(),
683
+ transform=ax.transAxes,
684
+ fontsize=9,
685
+ verticalalignment="top",
686
+ bbox=dict(boxstyle="round,pad=0.5", fc="white", alpha=0.8),
687
+ )
688
+ except Exception as e:
689
+ print(f"An error occurred during metrics display: {e}")
690
+ ax.text(
691
+ 0.05,
692
+ 0.95,
693
+ "Error displaying metrics",
694
+ transform=ax.transAxes,
695
+ fontsize=9,
696
+ color="red",
697
+ verticalalignment="top",
698
+ bbox=dict(boxstyle="round,pad=0.5", fc="white", alpha=0.8),
699
+ )
700
+
701
+
702
def _qsar2_axis_limits(values: pd.Series, pad_fraction: float = 0.10):
    """Return ``(low, high)`` axis limits covering *values* plus relative padding."""
    values = values.dropna()
    if values.empty:
        low, high = 0.0, 1.0
    else:
        low, high = values.min(), values.max()
    if low == high:
        # Degenerate range: widen it so the axes do not collapse to a point.
        low -= 0.5
        high += 0.5
    padding = (high - low) * pad_fraction
    return low - padding, high + padding


def _qsar2_metrics_line(label: str, split_df: pd.DataFrame, metric_cols) -> str:
    """Format one ``label: R2/MAE/RMSE`` line, or ``label: N/A`` without usable metrics."""
    complete_rows = split_df[metric_cols].dropna()
    if complete_rows.empty:
        # Covers both "no rows in this split" and "metrics are all NaN".
        return f"{label}: N/A"
    first = complete_rows.iloc[0]
    return (
        f"{label}: $R^2$={first['split_R2']:.3f}, "
        f"MAE={first['split_MAE']:.3f}, RMSE={first['split_RMSE']:.3f}"
    )


@op("LynxKite Graph Analytics", "plot qsar2", view="matplotlib")
def plot_qsar2(results_df: pd.DataFrame):
    """
    Draw an actual-vs-predicted scatter plot for a QSAR model.

    Training and test points are drawn in different colours on a square
    figure together with the y = x line. When a usable test-set RMSE is
    present, two dashed lines parallel to y = x are added at a distance of
    ``rmse_multiplier_for_bands * RMSE``. Per-point confidence intervals are
    not used.

    The title, axis labels, band multiplier (1.0) and the metrics switch (on)
    are fixed inside the function; ``results_df`` is the only argument. The
    function returns nothing: the figure it creates is left as the current
    pyplot figure.

    Parameters
    ----------
    results_df : pd.DataFrame
        Must contain the columns 'actual', 'predicted' and 'split'. Rows whose
        'split' equals "train" or "test" are plotted. If 'split_RMSE' is
        present, the first non-null value among the test rows sets the band
        width. If 'split_R2', 'split_MAE' and 'split_RMSE' are all present,
        the first complete row of each split is shown in a text box.

    Raises
    ------
    TypeError
        If ``results_df`` is not a pandas DataFrame.
    ValueError
        If a required column is missing.
    """
    COLOR_TRAIN = "royalblue"
    COLOR_TEST = "darkorange"
    COLOR_PERFECT = "black"
    COLOR_BANDS = "dimgrey"  # Less prominent than the y = x line
    COLOR_GRID = "lightgrey"
    title = "QSAR Model Performance: Actual vs. Predicted"
    xlabel = "Actual Values"
    ylabel = "Predicted Values"
    show_metrics = True
    rmse_multiplier_for_bands = 1.0

    # --- Input validation ---
    if not isinstance(results_df, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame.")

    required_cols = ["actual", "predicted", "split"]
    if not all(col in results_df.columns for col in required_cols):
        raise ValueError(f"DataFrame must contain columns: {required_cols}")

    metric_cols = ["split_R2", "split_MAE", "split_RMSE"]
    metrics_available = all(col in results_df.columns for col in metric_cols)
    has_rmse = "split_RMSE" in results_df.columns
    bands_possible = rmse_multiplier_for_bands is not None and has_rmse

    if show_metrics and not metrics_available:
        print(
            f"Warning: Metrics display requested, but one or more metric columns ({metric_cols}) are missing."
        )
    if rmse_multiplier_for_bands is not None and not has_rmse:
        print("Warning: Error bands requested, but 'split_RMSE' column is missing.")

    # --- Prepare data ---
    train_data = results_df[results_df["split"] == "train"]
    test_data = results_df[results_df["split"] == "test"]
    can_plot_train = not train_data.empty
    can_plot_test = not test_data.empty

    if not can_plot_train and not can_plot_test:
        print("Warning: Both training and test data subsets are empty. Cannot generate plot.")
        return

    # --- Create plot ---
    # NOTE: style.use() changes the global rcParams, so it also affects
    # figures created later in the same process.
    plt.style.use("seaborn-v0_8-whitegrid")
    _, ax = plt.subplots(figsize=(8, 8))

    common_scatter_kws = {"s": 45, "alpha": 0.75, "edgecolor": "black", "linewidth": 0.5}
    if can_plot_train:
        ax.scatter(
            train_data["actual"],
            train_data["predicted"],
            label="Training set",
            marker="o",
            color=COLOR_TRAIN,
            **common_scatter_kws,
        )
    if can_plot_test:
        ax.scatter(
            test_data["actual"],
            test_data["predicted"],
            label="Test set",
            marker="o",
            color=COLOR_TEST,
            **common_scatter_kws,
        )

    # Limits are computed over every row, not only the train/test subsets.
    min_val, max_val = _qsar2_axis_limits(
        pd.concat([results_df["actual"], results_df["predicted"]])
    )

    # y = x reference line.
    ax.plot(
        [min_val, max_val],
        [min_val, max_val],
        color=COLOR_PERFECT,
        linestyle="-",
        linewidth=1.5,
        alpha=0.9,
        label="_nolegend_",
    )

    # --- Error bands based on the test RMSE ---
    if bands_possible and can_plot_test:
        test_rmse_values = test_data["split_RMSE"].dropna()
        if test_rmse_values.empty:
            print("Warning: Could not plot error bands: no valid test 'split_RMSE' value.")
        else:
            # Best-effort: the bands are decoration, so a malformed RMSE value
            # must not abort the whole plot.
            try:
                rmse_test = test_rmse_values.iloc[0]
                if rmse_test >= 0:
                    margin = rmse_multiplier_for_bands * rmse_test
                    # Raw f-string: "\p" and "\," are mathtext, not Python escapes.
                    band_label = rf"$\pm {rmse_multiplier_for_bands}\,$RMSE"
                    # Only the upper band gets a legend entry.
                    for offset, label in ((margin, band_label), (-margin, "_nolegend_")):
                        ax.plot(
                            [min_val, max_val],
                            [min_val + offset, max_val + offset],
                            color=COLOR_BANDS,
                            linestyle="--",
                            linewidth=1.0,
                            alpha=0.7,
                            label=label,
                        )
            except Exception as e:
                print(f"Warning: Could not plot error bands: {e}")

    # --- Axes cosmetics ---
    ax.set_xlim(min_val, max_val)
    ax.set_ylim(min_val, max_val)
    ax.set_aspect("equal", adjustable="box")
    ax.grid(True, which="both", linestyle=":", linewidth=0.7, color=COLOR_GRID, alpha=0.7)
    ax.set_axisbelow(True)  # Keep the grid behind the data points
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=15, pad=15, weight="semibold")
    ax.legend(loc="best", frameon=True, framealpha=0.85, fontsize=10, shadow=False)

    # --- Metrics text box ---
    if show_metrics and metrics_available:
        # Best-effort: a formatting failure (e.g. non-numeric metric values)
        # is reported but does not discard the plot.
        try:
            metrics_text = "\n".join(
                [
                    _qsar2_metrics_line("Train", train_data, metric_cols),
                    _qsar2_metrics_line("Test", test_data, metric_cols),
                ]
            )
            ax.text(
                0.05,
                0.95,
                metrics_text,
                transform=ax.transAxes,
                fontsize=9,
                verticalalignment="top",
                bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.7),
            )
        except Exception as e:
            print(f"An error occurred during metrics display: {e}")
examples/Cheminformatics/qsar_example.lynxkite.json ADDED
The diff for this file is too large to render. See raw diff
 
examples/draw_molecules.py DELETED
@@ -1,29 +0,0 @@
1
- from lynxkite.core.ops import op
2
- import pandas as pd
3
- import base64
4
- import io
5
-
6
-
7
def pil_to_data(image):
    """Return *image* encoded as a base64 PNG ``data:`` URI string.

    ``image`` only needs a ``save(file_obj, format=...)`` method; the PNG
    bytes it writes are base64-encoded and prefixed with the data-URI header.
    """
    png_stream = io.BytesIO()
    image.save(png_stream, format="png")
    encoded = base64.b64encode(png_stream.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"
12
-
13
-
14
def smiles_to_data(smiles):
    """Render a SMILES string as a base64 PNG data URI.

    Returns ``None`` when *smiles* is not a string (for example a missing
    value in a DataFrame column) or when RDKit cannot parse it.
    """
    # Import the submodules explicitly: a bare ``import rdkit`` does not load
    # ``rdkit.Chem`` or ``rdkit.Chem.Draw``, so attribute access on the
    # package only works if some other module imported them first.
    from rdkit import Chem
    from rdkit.Chem import Draw

    if not isinstance(smiles, str):
        return None
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return pil_to_data(Draw.MolToImage(mol))
23
-
24
-
25
@op("LynxKite Graph Analytics", "Draw molecules")
def draw_molecules(df: pd.DataFrame, *, smiles_column: str, image_column: str = "image"):
    """Return a copy of *df* with a column of rendered molecule images.

    Each value of ``smiles_column`` is passed through ``smiles_to_data`` and
    the results are stored in ``image_column``. The input frame is left
    untouched.
    """
    rendered = df[smiles_column].apply(smiles_to_data)
    result = df.copy()
    result[image_column] = rendered
    return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
  # Example of a requirements.txt file. LynxKite will automatically install anything you put here.
2
  faker
3
  matplotlib
 
 
 
 
1
  # Example of a requirements.txt file. LynxKite will automatically install anything you put here.
2
  faker
3
  matplotlib
4
+ chembl_webresource_client
5
+ rcsb-api
6
+ # itertools ships with the Python standard library; it must not be listed as a pip requirement.