paccmann_gp / utils.py
jannisborn's picture
update
4b5d582 unverified
raw
history blame
2.04 kB
import logging
from collections import defaultdict
from typing import List, Callable
from gt4sd.properties import PropertyPredictorRegistry
from gt4sd.algorithms.prediction.paccmann.core import PaccMann, AffinityPredictor
import torch
import mols2grid
import pandas as pd
# Module-level logger. A NullHandler (which discards every record) is attached
# so that this module never emits log output unless the application adds its
# own handlers.
logging.getLogger(__name__).addHandler(logging.NullHandler())
logger = logging.getLogger(__name__)
def get_affinity_function(target: str) -> Callable:
    """
    Builds a function predicting the affinity of molecules to a protein target.

    Args:
        target: The protein target. It is paired with every molecule handed
            to the returned function.

    Returns:
        A callable that takes a list of molecules (presumably SMILES strings -
        confirm against ``AffinityPredictor``) and returns the stacked PaccMann
        samples converted to a plain Python list. An empty input yields an
        empty list.
    """

    def predict_affinity(mols: List[str]) -> list:
        n_mols = len(mols)
        # torch.stack raises on an empty sequence, so answer empty input
        # directly instead of building a predictor for nothing.
        if n_mols == 0:
            return []
        predictor = PaccMann(
            AffinityPredictor(protein_targets=[target] * n_mols, ligands=mols)
        )
        # One sample is drawn per molecule; the sampled tensors are stacked
        # into a single tensor before conversion to a list.
        return torch.stack(list(predictor.sample(n_mols))).tolist()

    return predict_affinity
# Property evaluators keyed by the property name used in this module. Each
# pair is (name used here, name registered in PropertyPredictorRegistry);
# note that "sa" is looked up under the registry name "sas".
EVAL_DICT = {
    label: PropertyPredictorRegistry.get_property_predictor(registry_name)
    for label, registry_name in (("qed", "qed"), ("sa", "sas"))
}
def draw_grid_generate(
    samples: List[str],
    properties: List[str],
    protein_target: str,
    n_cols: int = 3,
    size=(140, 200),
) -> str:
    """
    Uses mols2grid to draw a HTML grid for the generated molecules

    Args:
        samples: The generated samples.
        properties: Names of the properties to compute for the samples. Each
            name must be a key of ``EVAL_DICT`` or ``"affinity"``. The list is
            not modified.
        protein_target: Protein target used for affinity prediction. An empty
            string means no affinity predictor is set up for this call.
        n_cols: Number of columns in grid. Defaults to 3.
        size: Size of molecule in grid. Defaults to (140, 200).

    Returns:
        HTML to display

    Raises:
        KeyError: If a requested property has no evaluator. This includes
            requesting ``"affinity"`` with an empty ``protein_target``.
    """
    # Use a per-call copy of the evaluators: writing the affinity predictor
    # into the shared EVAL_DICT would let a predictor built for one target
    # leak into later calls made with a different (or no) target.
    evaluators = dict(EVAL_DICT)
    if protein_target != "":
        evaluators["affinity"] = get_affinity_function(protein_target)

    result = {
        "SMILES": samples,
        "Name": [f"Generated_{i}" for i in range(len(samples))],
    }

    # Affinity is predicted for all samples in one batch, unlike the other
    # properties, which are evaluated one molecule at a time.
    if "affinity" in properties:
        if "affinity" not in evaluators:
            raise KeyError(
                "affinity was requested but no protein_target was provided"
            )
        result["affinity"] = evaluators["affinity"](samples)

    # Fill properties. Filter into a new list rather than calling
    # properties.remove(), which would alter the caller's list.
    for prop in (p for p in properties if p != "affinity"):
        evaluate = evaluators[prop]
        result[prop] = [f"{prop} = {evaluate(sample)}" for sample in samples]

    result_df = pd.DataFrame(result)
    obj = mols2grid.display(
        result_df,
        tooltip=list(result),
        height=1100,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data