Spaces:
Running
Running
jannisborn
commited on
Commit
•
8b150bd
1
Parent(s):
3a8c428
feat: Initial RT app
Browse files- LICENSE +21 -0
- README.md +13 -11
- app.py +158 -0
- model_cards/.DS_Store +0 -0
- model_cards/regression_transformer.png +0 -0
- model_cards/regression_transformer_article.md +59 -0
- model_cards/regression_transformer_description.md +8 -0
- model_cards/regression_transformer_examples.csv +7 -0
- requirements.txt +5 -0
- utils.py +172 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2022 Generative Toolkit 4 Scientific Discovery
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,12 +1,14 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
emoji: 😻
|
4 |
-
colorFrom: indigo
|
5 |
-
colorTo: red
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 3.12.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
---
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# gt4sd-apps
|
2 |
+
Web apps of GT4SD models powered via gradio.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
+
## Installation
|
5 |
+
1. Install `gt4sd` from [`gt4sd-core`](https://github.com/GT4SD/gt4sd-core).
|
6 |
+
2. Install requirements in env:
|
7 |
+
```sh
|
8 |
+
conda activate gt4sd
|
9 |
+
pip install -r requirements.txt
|
10 |
+
```
|
11 |
+
3. Run a demo on a localhost:
|
12 |
+
```sh
|
13 |
+
python apps/algorithms/conditional_generation/regression_transformer/app.py
|
14 |
+
```
|
app.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import pathlib
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
import pandas as pd
|
6 |
+
from gt4sd.algorithms.conditional_generation.regression_transformer import (
|
7 |
+
RegressionTransformer,
|
8 |
+
)
|
9 |
+
from gt4sd.algorithms.registry import ApplicationsRegistry
|
10 |
+
from utils import (
|
11 |
+
draw_grid_generate,
|
12 |
+
draw_grid_predict,
|
13 |
+
get_application,
|
14 |
+
get_inference_dict,
|
15 |
+
get_rt_name,
|
16 |
+
)
|
17 |
+
|
18 |
+
# Module-level logger named after this module; a NullHandler is attached to it.
logger = logging.getLogger(__name__)
|
19 |
+
logger.addHandler(logging.NullHandler())
|
20 |
+
|
21 |
+
|
22 |
+
def _split_comma_list(text: str) -> list:
    """Split a comma-separated UI field into items, ignoring spaces ("" -> [])."""
    return [] if text == "" else text.replace(" ", "").split(",")


def _parse_property_goals(property_goal: str) -> dict:
    """Parse "<prop>:value, <prop>:value" into a {name: float} dictionary."""
    if property_goal == "":
        raise ValueError(
            "For conditional generation you have to specify `property_goal`."
        )
    goals = {}
    for entry in property_goal.split(","):
        parts = entry.split(":")
        if len(parts) < 2:
            # Without this check a missing colon surfaces as a bare IndexError.
            raise ValueError(
                f"Could not parse property goal {entry!r}, expected `<prop>:value`."
            )
        goals[parts[0].strip()] = float(parts[1])
    return goals


def regression_transformer(
    algorithm: str,
    task: str,
    target: str,
    number_of_samples: int,
    search: str,
    temperature: float,
    tolerance: int,
    wrapper: bool,
    fraction_to_mask: float,
    property_goal: str,
    tokens_to_mask: str,
    substructures_to_mask: str,
    substructures_to_keep: str,
):
    """
    Run a Regression Transformer model and render the result as HTML.

    Args:
        algorithm: UI display name in the form "<domain>: <version>"
            (e.g. "Molecules: Qed"). The part before the colon selects the
            application, the last space-separated token (lower-cased) is the
            algorithm version.
        task: "Predict" or anything else (treated as generation).
        target: The input sequence passed to the model.
        number_of_samples: Number of samples requested from the model.
        search: Decoding search method; lower-cased before being forwarded.
        temperature: Decoding temperature forwarded to the configuration.
        tolerance: Tolerance forwarded to the configuration.
        wrapper: Whether to build a sampling wrapper. Ignored (with a warning)
            when `task` is "Predict".
        fraction_to_mask: Forwarded in the sampling wrapper.
        property_goal: Comma-separated `<prop>:value` pairs; required when the
            sampling wrapper is used.
        tokens_to_mask: Comma-separated tokens, forwarded in the sampling wrapper.
        substructures_to_mask: Comma-separated substructures, forwarded in the
            sampling wrapper.
        substructures_to_keep: Comma-separated substructures, forwarded in the
            sampling wrapper.

    Returns:
        The output of `draw_grid_predict` (for "Predict") or `draw_grid_generate`.

    Raises:
        ValueError: If the sampling wrapper is used and `property_goal` is empty
            or contains an entry that is not of the form `<prop>:value`.
    """
    if task == "Predict" and wrapper:
        logger.warning(
            f"For prediction, no sampling_wrapper will be used, ignoring: fraction_to_mask: {fraction_to_mask}, "
            f"tokens_to_mask: {tokens_to_mask}, substructures_to_mask={substructures_to_mask}, "
            f"substructures_to_keep: {substructures_to_keep}."
        )
        sampling_wrapper = {}
    elif not wrapper:
        sampling_wrapper = {}
    else:
        sampling_wrapper = {
            "substructures_to_keep": _split_comma_list(substructures_to_keep),
            "substructures_to_mask": _split_comma_list(substructures_to_mask),
            # Previously parsed but never forwarded, so the UI field had no
            # effect; spaces are stripped so "N, C" yields ["N", "C"].
            "tokens_to_mask": _split_comma_list(tokens_to_mask),
            "text_filtering": False,
            "fraction_to_mask": fraction_to_mask,
            "property_goal": _parse_property_goals(property_goal),
        }

    # "<domain>: <version>" -> domain selects the application class.
    domain = algorithm.split(":")[0]
    algorithm_application = get_application(domain)
    algorithm_version = algorithm.split(" ")[-1].lower()
    config = algorithm_application(
        algorithm_version=algorithm_version,
        search=search.lower(),
        temperature=temperature,
        tolerance=tolerance,
        sampling_wrapper=sampling_wrapper,
    )
    model = RegressionTransformer(configuration=config, target=target)
    samples = list(model.sample(number_of_samples))

    if task == "Predict":
        # Only the first sample is drawn for predictions.
        return draw_grid_predict(samples[0], target, domain=domain)
    return draw_grid_generate(samples, domain=domain)
|
91 |
+
|
92 |
+
|
93 |
+
if __name__ == "__main__":

    # Preparation (retrieve all available algorithms)
    all_algos = ApplicationsRegistry.list_available()
    rt_algos = list(
        filter(lambda x: "RegressionTransformer" in x["algorithm_name"], all_algos)
    )
    rt_names = list(map(get_rt_name, rt_algos))

    # Load the inference dictionary of every RT model, keyed by display name.
    # NOTE(review): `properties` is not read again below; the loop is kept
    # because `get_inference_dict` is still called once per model.
    properties = {}
    for algo in rt_algos:
        application = get_application(
            algo["algorithm_application"].split("Transformer")[-1]
        )
        data = get_inference_dict(
            application=application, algorithm_version=algo["algorithm_version"]
        )
        properties[get_rt_name(algo)] = data

    # Load metadata
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    # Headerless CSV: one example per row, empty cells become "".
    examples = pd.read_csv(
        metadata_root.joinpath("regression_transformer_examples.csv"), header=None
    ).fillna("")

    # Read the markdown with an explicit encoding instead of the platform default.
    with open(
        metadata_root.joinpath("regression_transformer_article.md"),
        "r",
        encoding="utf-8",
    ) as f:
        article = f.read()
    with open(
        metadata_root.joinpath("regression_transformer_description.md"),
        "r",
        encoding="utf-8",
    ) as f:
        description = f.read()

    # The order of `inputs` matches the positional parameters of
    # `regression_transformer`.
    demo = gr.Interface(
        fn=regression_transformer,
        title="Regression Transformer",
        inputs=[
            gr.Dropdown(rt_names, label="Algorithm version", value="Molecules: Qed"),
            gr.Radio(choices=["Predict", "Generate"], label="Task", value="Generate"),
            gr.Textbox(
                label="Input", placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1", lines=1
            ),
            gr.Slider(
                minimum=1, maximum=50, value=10, label="Number of samples", step=1
            ),
            gr.Radio(choices=["Sample", "Greedy"], label="Search", value="Sample"),
            gr.Slider(minimum=0.5, maximum=2, value=1, label="Decoding temperature"),
            gr.Slider(minimum=5, maximum=100, value=30, label="Tolerance", step=1),
            gr.Radio(choices=[True, False], label="Sampling Wrapper", value=True),
            gr.Slider(minimum=0, maximum=1, value=0.5, label="Fraction to mask"),
            gr.Textbox(label="Property goal", placeholder="<qed>:0.75", lines=1),
            gr.Textbox(label="Tokens to mask", placeholder="N, C", lines=1),
            gr.Textbox(
                label="Substructures to mask", placeholder="C(=O), C#C", lines=1
            ),
            gr.Textbox(
                label="Substructures to keep", placeholder="C1=CC=C(Cl)C=C1", lines=1
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)
|
model_cards/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
model_cards/regression_transformer.png
ADDED
model_cards/regression_transformer_article.md
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Model card -- Regression Transformer
|
2 |
+
|
3 |
+
## Parameters
|
4 |
+
|
5 |
+
### Algorithm Version:
|
6 |
+
Which model checkpoint to use (trained on different datasets).
|
7 |
+
|
8 |
+
### Task
|
9 |
+
Whether the multitask model should be used for property prediction or conditional generation (default).
|
10 |
+
|
11 |
+
### Input
|
12 |
+
The input sequence. In the default setting (where `Task` is *Generate* and `Sampling Wrapper` is *True*) this can be a seed SMILES (for the molecule models) or amino-acid sequence (for the protein models). The model will locally adapt the seed sequence by masking `Fraction to mask` of the tokens.
|
13 |
+
If the `Task` is *Predict*, the sequences are given as SELFIES for the molecule models. Moreover, the tokens that should be predicted (`[MASK]` in the input) have to be given explicitly. Populate the examples to understand better.
|
14 |
+
NOTE: When setting `Task` to *Generate*, and `Sampling Wrapper` to *False*, the user has maximal control over the generative process and can explicitly decide which tokens should be masked.
|
15 |
+
|
16 |
+
### Number of samples
|
17 |
+
How many samples should be generated (between 1 and 50). If `Task` is *Predict*, this has to be set to 1.
|
18 |
+
|
19 |
+
### Search
|
20 |
+
Decoding search method. Use *Sample* if `Task` is *Generate*. If `Task` is *Predict*, use *Greedy*.
|
21 |
+
|
22 |
+
### Tolerance
|
23 |
+
Precision tolerance; only used if `Task` is *Generate*. This is a single float between 0 and 100 for the tolerated deviation between desired/primed property and predicted property of the generated molecule. Given in percentage with respect to the property range encountered during training.
|
24 |
+
NOTE: The tolerance is *only* used for post-hoc filtering of the generated samples.
|
25 |
+
|
26 |
+
### Sampling Wrapper
|
27 |
+
Only used if `Task` is *Generate*. If set to *False*, the user has to provide a full RT-sequence as `Input` and has to **explicitly** decide which tokens are masked (see example below). This gives full control but is tedious. Instead, if `Sampling Wrapper` is set to *True*, the RT stochastically determines which parts of the sequence are masked.
|
28 |
+
**NOTE**: All below arguments only apply if `Sampling Wrapper` is *True*.
|
29 |
+
|
30 |
+
#### Fraction to mask
|
31 |
+
Specifies the ratio of tokens that can be changed by the model. Argument only applies if `Task` is *Generate* and `Sampling Wrapper` is *True*.
|
32 |
+
|
33 |
+
#### Property goal
|
34 |
+
Specifies the desired target properties for the generation. They need to be given in the format `<prop>:value`. If the model supports multiple properties, give them separated by a comma `,`. Argument only applies if `Task` is *Generate* and `Sampling Wrapper` is *True*.
|
35 |
+
|
36 |
+
#### Tokens to mask
|
37 |
+
Optionally specifies which tokens (atoms, bonds etc) can be masked. Please separate multiple tokens by comma (`,`). If not specified, all tokens can be masked. Argument only applies if `Task` is *Generate* and `Sampling Wrapper` is *True*.
|
38 |
+
|
39 |
+
#### Substructures to mask
|
40 |
+
Optionally specifies a list of substructures that should *definitely* be masked (excluded from stochastic masking). Given in SMILES format. If multiple are provided, separate by comma (`,`). Argument only applies if `Task` is *Generate* and `Sampling Wrapper` is *True*.
|
41 |
+
*NOTE*: Most models operate on SELFIES and the matching of the substructures occurs in SELFIES simply on a string level.
|
42 |
+
|
43 |
+
#### Substructures to keep
|
44 |
+
Optionally specifies a list of substructures that should definitely be present in the target sample (i.e., excluded from stochastic masking). Given in SMILES format. Argument only applies if `Task` is *Generate* and `Sampling Wrapper` is *True*.
|
45 |
+
*NOTE*: This keeps tokens even if they are included in `tokens_to_mask`.
|
46 |
+
*NOTE*: Most models operate on SELFIES and the matching of the substructures occurs in SELFIES simply on a string level.
|
47 |
+
|
48 |
+
## Citation
|
49 |
+
|
50 |
+
```bib
|
51 |
+
@article{born2022regression,
|
52 |
+
title={Regression Transformer: Concurrent Conditional Generation and Regression by Blending Numerical and Textual Tokens},
|
53 |
+
author={Born, Jannis and Manica, Matteo},
|
54 |
+
journal={arXiv preprint arXiv:2202.01338},
|
55 |
+
note={Spotlight talk at ICLR workshop on Machine Learning for Drug Discovery},
|
56 |
+
year={2022}
|
57 |
+
}
|
58 |
+
```
|
59 |
+
|
model_cards/regression_transformer_description.md
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
### Concurrent sequence regression and generation for molecular language modeling
|
3 |
+
|
4 |
+
The RT is a multitask Transformer that reformulates regression as a conditional sequence modeling task.
|
5 |
+
This yields a dichotomous language model that seamlessly integrates regression with property-driven conditional generation.
|
6 |
+
**Further reading:** [arXiv preprint](https://arxiv.org/abs/2202.01338) and [GitHub development code](https://github.com/IBM/regression-transformer).
|
7 |
+
|
8 |
+
Each `algorithm_version` refers to one trained model. Each model can be used for **two tasks**, either to *predict* one (or multiple) properties of a molecule or to *generate* a molecule (given a seed molecule and a property constraint).
|
model_cards/regression_transformer_examples.csv
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Molecules: Logp_and_synthesizability,Generate,CCOC1=NC=NC(=C1C)NCCOC(C)C,3,Sample,1.2,20,True,0.3,"<logp>:0.390, <scs>:2.628",N,(C)C,CCO
|
2 |
+
Molecules: Qed,Generate,CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1,10,Sample,1.0,30,True,0.5,<qed>:0.75,"N, C","C(=O), CC",C1=CC=C(Cl)C=C1
|
3 |
+
Molecules: Logp_and_synthesizability,Predict,<logp>[MASK][MASK][MASK][MASK][MASK]|<scs>[MASK][MASK][MASK][MASK][MASK]|[C][C][O][C][=N][C][=N][C][Branch1_2][Branch1_1][=C][Ring1][Branch1_2][C][N][C][C][O][C][Branch1_1][C][C][C],1,Greedy,1.0,30,False,0.0,,,,
|
4 |
+
Proteins: Stability,Predict,<stab>[MASK][MASK][MASK][MASK][MASK]|GSQEVNSGTQTYKNASPEEAERIARKAGATTWTEKGNKWEIRI,1,Greedy,1.0,1,False,0.0,,,,
|
5 |
+
Proteins: Stability,Generate,GSQEVNSGTQTYKNASPEEAERIARKAGATTWTEKGNKWEIRI,10,Sample,1.2,30,True,0.3,<stab>:0.393,,SQEVNSGTQTYKN,WTEK
|
6 |
+
Molecules: Qed,Generate,<qed>0.717|[MASK][MASK][MASK][MASK][MASK][C][Branch2_1][Ring1][Ring1][MASK][MASK][=C][C][Branch1_1][C][C][=N][C][MASK][MASK][=C][C][=C][Ring1][O][Ring1][Branch1_2][=C][Ring2][MASK][MASK],10,Sample,1.2,30,False,0.0,,,,
|
7 |
+
Molecules: Solubility,Generate,ClC(Cl)C(Cl)Cl,5,Sample,1.3,40,True,0.4,<esol>:0.754,,,
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gt4sd>=1.0.0
|
2 |
+
gradio>=3.9
|
3 |
+
markdown-it-py>=2.1.0
|
4 |
+
mols2grid>=0.2.0
|
5 |
+
pandas>=1.0.0
|
utils.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
from collections import defaultdict
|
5 |
+
from typing import Dict, List, Tuple
|
6 |
+
|
7 |
+
import mols2grid
|
8 |
+
import pandas as pd
|
9 |
+
from gt4sd.algorithms import (
|
10 |
+
RegressionTransformerMolecules,
|
11 |
+
RegressionTransformerProteins,
|
12 |
+
)
|
13 |
+
from gt4sd.algorithms.core import AlgorithmConfiguration
|
14 |
+
from rdkit import Chem
|
15 |
+
from terminator.selfies import decoder
|
16 |
+
|
17 |
+
# Module-level logger named after this module; a NullHandler is attached to it.
logger = logging.getLogger(__name__)
|
18 |
+
logger.addHandler(logging.NullHandler())
|
19 |
+
|
20 |
+
|
21 |
+
def get_application(application: str) -> AlgorithmConfiguration:
    """
    Map a domain name onto its Regression Transformer configuration class.

    Args:
        application: Either "Molecules" or "Proteins".

    Returns:
        `RegressionTransformerMolecules` or `RegressionTransformerProteins`.

    Raises:
        ValueError: If `application` is neither "Molecules" nor "Proteins".
    """
    # Guard clause first: reject anything that is not one of the two domains.
    if application != "Molecules" and application != "Proteins":
        raise ValueError(
            "Currently only models for molecules and proteins are supported"
        )
    if application == "Molecules":
        return RegressionTransformerMolecules
    return RegressionTransformerProteins
|
40 |
+
|
41 |
+
|
42 |
+
def get_inference_dict(
    application: AlgorithmConfiguration, algorithm_version: str
) -> Dict:
    """
    Get inference dictionary for a given application and algorithm version.

    Args:
        application: algorithm application (Molecules or Proteins); called with
            `algorithm_version` to build the configuration.
        algorithm_version: algorithm version (e.g. qed)

    Returns:
        The parsed content of `inference.json`, located in the directory
        returned by the configuration's `ensure_artifacts()`.
    """
    config = application(algorithm_version=algorithm_version)
    inference_path = os.path.join(config.ensure_artifacts(), "inference.json")
    # Pin the encoding: JSON is UTF-8, the platform default may not be.
    with open(inference_path, "r", encoding="utf-8") as f:
        return json.load(f)
|
59 |
+
|
60 |
+
|
61 |
+
def get_rt_name(x: Dict) -> str:
    """
    Build the UI display name of a regression transformer model.

    Args:
        x: dictionary providing the keys `algorithm_application` and
            `algorithm_version`.

    Returns:
        "<domain>: <Version>", where the domain is whatever follows the last
        "Transformer" in the application name and the version is capitalized.
    """
    domain = x["algorithm_application"].split("Transformer")[-1]
    version = x["algorithm_version"].capitalize()
    return f"{domain}: {version}"
|
76 |
+
|
77 |
+
|
78 |
+
def draw_grid_predict(prediction: str, target: str, domain: str) -> str:
    """
    Render a property prediction as a one-cell mols2grid HTML grid.

    Args:
        prediction: Predicted sequence; every "<name>value" chunk in it becomes
            a column of the form "Name = value".
        target: Input sequence; only the part after the last "|" is drawn.
        domain: Domain of the prediction ("Molecules" or "Proteins").

    Returns:
        HTML to display

    Raises:
        ValueError: If `domain` is neither "Molecules" nor "Proteins".
    """
    if domain not in ("Molecules", "Proteins"):
        raise ValueError(f"Unsupported domain {domain}")

    seq = target.split("|")[-1]

    # Molecules go through the SELFIES decoder, proteins through RDKit's FASTA
    # parser.
    if domain == "Molecules":
        converter = decoder
    else:

        def converter(x):
            return Chem.MolToSmiles(Chem.MolFromFASTA(x))

    try:
        seq = converter(seq)
    except Exception:
        # Best effort: keep the unconverted sequence when conversion fails.
        logger.warning(f"Could not draw sequence {seq}")

    result = {"SMILES": [seq], "Name": ["Target"]}
    # Add properties
    for chunk in prediction.split("<")[1:]:
        pieces = chunk.split(">")
        prop_name = pieces[0]
        prop_value = pieces[1]
        result[prop_name] = f"{prop_name.capitalize()} = {prop_value}"

    grid = mols2grid.display(
        pd.DataFrame(result),
        tooltip=list(result.keys()),
        height=900,
        n_cols=1,
        name="Results",
        size=(600, 700),
    )
    return grid.data
|
121 |
+
|
122 |
+
|
123 |
+
def draw_grid_generate(
    samples: List[Tuple[str]], domain: str, n_cols: int = 5, size=(140, 200)
) -> str:
    """
    Uses mols2grid to draw a HTML grid for the generated molecules

    Args:
        samples: The generated samples (with properties). Element 0 of each
            sample is the sequence, element 1 the property string.
        domain: Domain of the prediction (molecules or proteins)
        n_cols: Number of columns in grid. Defaults to 5.
        size: Size of molecule in grid. Defaults to (140, 200).

    Returns:
        HTML to display

    Raises:
        ValueError: If `domain` is neither "Molecules" nor "Proteins".
        IndexError: If `samples` is empty (the property names are read from the
            first sample).
    """

    if domain not in ["Molecules", "Proteins"]:
        raise ValueError(f"Unsupported domain {domain}")

    if domain == "Proteins":
        # Convert each sequence on its own and fall back to the raw sequence on
        # failure. Previously a single failure left `smis` unbound and the
        # function crashed with an UnboundLocalError right after the warning.
        smis = []
        conversion_failed = False
        for sample in samples:
            try:
                smis.append(Chem.MolToSmiles(Chem.MolFromFASTA(sample[0])))
            except Exception:
                conversion_failed = True
                smis.append(sample[0])
        if conversion_failed:
            logger.warning(f"Could not convert some sequences {samples}")
    else:
        smis = [s[0] for s in samples]

    result = defaultdict(list)
    result.update({"SMILES": smis, "Name": [f"sample_{i}" for i in range(len(smis))]})

    # Create properties: names are taken from the "<name>" tags of the first
    # sample's property string.
    properties = [s.split("<")[1] for s in samples[0][1].split(">")[:-1]]
    # Fill properties
    for sample in samples:
        for prop in properties:
            # Text after the last "<prop" occurrence: skip the ">" and read up
            # to the next "<".
            value = float(sample[1].split(prop)[-1][1:].split("<")[0])
            result[prop].append(f"{prop} = {value}")

    result_df = pd.DataFrame(result)
    obj = mols2grid.display(
        result_df,
        tooltip=list(result.keys()),
        height=1100,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data
|