Eachan Johnson committed on
Commit 6335878 · 1 parent: 2e4ba77

Add application file

Files changed (1)
  1. app.py +254 -0
app.py ADDED
@@ -0,0 +1,254 @@
+ """Gradio demo for schemist."""
+ 
+ from typing import Dict, Iterable, List, Union
+ from io import TextIOWrapper
+ import os
+ os.environ["COMMANDLINE_ARGS"] = "--no-gradio-queue"
+ 
+ from carabiner import cast, print_err
+ from carabiner.pd import read_table
+ import gradio as gr
+ import nemony as nm
+ import numpy as np
+ import pandas as pd
+ from rdkit.Chem import Draw, Mol
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedModel
+ import schemist as sch
+ from schemist.converting import (
+     _TO_FUNCTIONS,
+     _FROM_FUNCTIONS,
+     convert_string_representation,
+     _x2mol,
+ )
+ from schemist.tables import converter
+ 
+ MODELS = (
+     "scbirlab/lchemme-base-zinc22-lteq300",
+     "scbirlab/lchemme-base-dosedo-lteq300",
+     "facebook/bart-base",
+ )
+ 
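+ # Load the tokenizer and seq2seq model for each checkpoint once at startup,
+ # caching the downloaded weights locally in ./model-cache.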
+ models = {model_name: (
+     AutoTokenizer.from_pretrained(model_name, cache_dir="model-cache"),
+     AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir="model-cache"),
+ ) for model_name in MODELS}
+ 
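+ # Read an uploaded table and return it as a visible dataframe together with a
+ # dropdown of its non-numeric columns (not used by the interface below).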
+ def load_input_data(file: TextIOWrapper) -> pd.DataFrame:
+     df = read_table(file.name)
+     string_cols = list(df.select_dtypes(exclude=[np.number]))
+     df = gr.Dataframe(value=df, visible=True)
+     return df, gr.Dropdown(choices=string_cols, interactive=True)
+ 
+ 
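+ # Split pasted text on newlines and commas, stripping whitespace from each entry.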
+ def _clean_split_input(strings: str) -> List[str]:
+     return [s2.strip() for s in strings.split("\n") for s2 in s.split(",")]
+ 
+ 
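+ # Convert pasted structures to one or more output representations via schemist,
+ # returning a dict mapping each representation name to a list of strings.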
+ def _convert_input(
+     strings: str,
+     input_representation: str = 'smiles',
+     output_representation: Union[Iterable[str], str] = 'smiles'
+ ) -> Dict[str, List[str]]:
+     strings = _clean_split_input(strings)
+     converted = convert_string_representation(
+         strings=strings,
+         input_representation=input_representation,
+         output_representation=output_representation,
+     )
+     return {
+         key: list(map(str, cast(val, to=list)))
+         for key, val in converted.items()
+     }
+ 
+ 
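+ # Canonicalize a batch with one model: a single forward pass followed by a greedy
+ # argmax over the logits (no autoregressive generation), decoded back to SMILES,
+ # with InChIKeys derived from the outputs for comparison.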
+ def model_convert(
+     df: pd.DataFrame,
+     name: str,
+     tokenizer,
+     model: PreTrainedModel
+ ) -> pd.DataFrame:
+ 
+     model_basename = name.split("/")[-1]
+     inputs = tokenizer(df["inputs"].tolist(), padding=True, return_tensors="pt")  # pad so multiple inputs batch into one tensor
+     model.eval()
+     model_args = {key: inputs[key] for key in ['input_ids', 'attention_mask']}
+     outputs = model(
+         **model_args,
+         # decoder_input_ids=model_args['input_ids'],
+     )
+     output_smiles = tokenizer.batch_decode(
+         outputs.logits.argmax(dim=-1),
+         skip_special_tokens=True,
+         clean_up_tokenization_spaces=True,
+     )
+     output_inchikey = convert_string_representation(
+         strings=output_smiles,
+         output_representation="inchikey",
+     )
+     return pd.DataFrame({
+         f"{model_basename}_smiles": output_smiles,
+         f"{model_basename}_inchikey": output_inchikey,
+     })
+ 
+ 
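+ # Handle a textbox submission: compute reference SMILES and InChIKeys via schemist,
+ # run each selected model on the same inputs, and concatenate the results into one table.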
+ def convert_one(
+     strings: str,
+     output_representation: Union[Iterable[str], str] = MODELS[0]
+ ):
+     input_representation: str = 'smiles'
+     df = pd.DataFrame({
+         "inputs": _clean_split_input(strings),
+     })
+ 
+     true_canonical_df = convert_file(
+         df=df,
+         column="inputs",
+         input_representation=input_representation,
+         output_representation=["smiles", "inchikey"]
+     )
+ 
+     output_representation = cast(output_representation, to=list)
+     model_canonical_dfs = {
+         model_name: model_convert(df, model_name, *models[model_name])
+         for model_name in output_representation
+     }
+ 
+     return gr.DataFrame(
+         pd.concat([true_canonical_df] + list(model_canonical_dfs.values()), axis=1),
+         visible=True
+     )
+ 
+ 
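+ # Convert a whole dataframe column between representations with schemist's
+ # converter, reporting progress and error counts through gr.Info and the console.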
+ def convert_file(
+     df: pd.DataFrame,
+     column: str = 'smiles',
+     input_representation: str = 'smiles',
+     output_representation: Union[str, Iterable[str]] = 'smiles'
+ ):
+     message = f"Converting from {input_representation} to {output_representation}..."
+     print_err(message)
+     gr.Info(message, duration=3)
+     errors, df = converter(
+         df=df,
+         column=column,
+         input_representation=input_representation,
+         output_representation=output_representation,
+     )
+     df = df[
+         cast(output_representation, to=list) +
+         [col for col in df if col not in output_representation]
+     ]
+     all_err = sum(err for key, err in errors.items())
+     message = (
+         f"Converted {df.shape[0]} molecules from "
+         f"{input_representation} to {output_representation} "
+         f"with {all_err} errors!"
+     )
+     print_err(message)
+     gr.Info(message, duration=5)
+     return df
+ 
+ 
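+ # Render the pasted molecules as a grid image with RDKit, labelled by InChIKey and ID.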
+ def draw_one(
+     strings: Union[Iterable[str], str]
+ ):
+     input_representation: str = 'smiles'
+     _ids = _convert_input(
+         strings,
+         input_representation,
+         ["inchikey", "id"],
+     )
+     mols = cast(_x2mol(_clean_split_input(strings), input_representation), to=list)
+     if isinstance(mols, Mol):
+         mols = [mols]
+     return Draw.MolsToGridImage(
+         mols,
+         molsPerRow=min(3, len(mols)),
+         subImgSize=(300, 300),
+         legends=["\n".join(items) for items in zip(*_ids.values())],
+     )
+ 
+ 
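+ # Write the converted table to a CSV named after a hash of its contents and
+ # expose it through the download button.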
+ def download_table(
+     df: pd.DataFrame
+ ) -> gr.DownloadButton:
+     df_hash = nm.hash(pd.util.hash_pandas_object(df).values)
+     filename = f"converted-{df_hash}.csv"
+     df.to_csv(filename, index=False)
+     return gr.DownloadButton(value=filename, visible=True)
+ 
+ 
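+ # Assemble the Gradio interface: a text input, model checkboxes, examples,
+ # a results table, a structure image, and a download button.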
+ with gr.Blocks() as demo:
+ 
+     gr.Markdown(
+         """
+         # SMILES canonicalization with LChemME
+ 
+         Interface to demonstrate SMILES canonicalization using Large Chemical Models pre-trained with
+         [LChemME](https://github.com/scbirlab/lchemme).
+ 
+         """
+     )
+ 
+     input_line = gr.Textbox(
+         label="Input",
+         placeholder="Paste your molecule(s) here, one per line",
+         lines=2,
+         interactive=True,
+         submit_btn=True,
+     )
+     output_format_single = gr.CheckboxGroup(
+         label="Use model(s):",
+         choices=list(MODELS),
+         value=MODELS[:1],
+         interactive=True,
+     )
+     examples = gr.Examples(
+         examples=[
+             ["CC(Oc1c(cccc1)C(=O)N)=O", MODELS[0]],
+             ["O=S1(N([C@H](C)COC(NC[3H])=O)C[C@H]([C@@H](Oc2cc(-c3cnc(c(c3)C)OC)ccc21)CN(C)C(c1c(C)c(sc1Cl)C)=O)C)=O", MODELS[1]],
+             ["CC(Oc1ccccc1C(O)=O)=O", MODELS[0]],
+             ["CC(Oc1ccccc1C(O)=O)=O", MODELS[2]],
+         ],
+         inputs=[input_line, output_format_single],
+     )
+     download_single = gr.DownloadButton(
+         label="Download converted data",
+         visible=False,
+     )
+ 
+     output_line = gr.DataFrame(
+         label="Converted",
+         interactive=False,
+         visible=False,
+     )
+     drawing = gr.Image(label="Chemical structures")
+ 
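+     # Wire the events: on submit, convert the input, then draw the structures,
+     # then prepare the CSV download.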
+     gr.on(
+         [
+             input_line.submit,
+         ],
+         fn=convert_one,
+         inputs=[
+             input_line,
+             output_format_single,
+         ],
+         outputs=[
+             output_line,
+         ]
+     ).then(
+         draw_one,
+         inputs=[
+             input_line,
+         ],
+         outputs=drawing,
+     ).then(
+         download_table,
+         inputs=output_line,
+         outputs=download_single
+     )
+ 
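+ # Enable queuing and launch the app with a public share link.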
+ if __name__ == "__main__":
+     demo.queue()
+     demo.launch(share=True)
+ 