File size: 2,039 Bytes
5984d9a
 
e83e5dc
 
 
4b5d582
5984d9a
 
 
 
 
 
 
 
e83e5dc
4b5d582
 
 
 
 
 
 
e83e5dc
 
 
 
4b5d582
e83e5dc
 
 
5984d9a
 
e83e5dc
 
5984d9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e83e5dc
 
 
5984d9a
 
e83e5dc
5984d9a
e83e5dc
 
 
 
 
 
 
 
 
5984d9a
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import logging
from collections import defaultdict
from typing import List, Callable
from gt4sd.properties import PropertyPredictorRegistry
from gt4sd.algorithms.prediction.paccmann.core import PaccMann, AffinityPredictor
import torch

import mols2grid
import pandas as pd

# Module logger. A NullHandler is attached so that importing this module never
# triggers logging's "no handlers could be found" fallback output; the
# application decides where records actually go.
logger = logging.getLogger(name=__name__)
logger.addHandler(hdlr=logging.NullHandler())


def get_affinity_function(target: str) -> Callable:
    """
    Build a batch affinity evaluator bound to a single protein target.

    Args:
        target: Protein target handed to ``AffinityPredictor`` once per ligand.

    Returns:
        A callable taking a list of ligand SMILES strings and returning the
        stacked PaccMann predictions as a Python list (via ``Tensor.tolist``).
        An empty input list yields ``[]`` without building a predictor.
    """

    def _predict(mols: List[str]) -> list:
        if not mols:
            # torch.stack raises on an empty sequence; nothing to score anyway.
            return []
        # The target is repeated so protein_targets and ligands have equal
        # length (presumably scored pairwise, i-th target with i-th ligand).
        algorithm = PaccMann(
            AffinityPredictor(protein_targets=[target] * len(mols), ligands=mols)
        )
        # Request one sample per ligand and stack them into a single tensor.
        predictions = list(algorithm.sample(len(mols)))
        return torch.stack(predictions).tolist()

    return _predict


# Per-molecule property predictors, keyed by the names that callers pass in
# draw_grid_generate's ``properties`` list. Each value is the predictor the
# gt4sd registry returns for the paired registry name; it is invoked with a
# single SMILES string.
EVAL_DICT = {
    display_name: PropertyPredictorRegistry.get_property_predictor(registry_name)
    for display_name, registry_name in (
        ("qed", "qed"),
        ("sa", "sas"),
    )
}


def draw_grid_generate(
    samples: List[str],
    properties: List[str],
    protein_target: str,
    n_cols: int = 3,
    size=(140, 200),
) -> str:
    """
    Uses mols2grid to draw a HTML grid for the generated molecules.

    Args:
        samples: The generated samples, as SMILES strings.
        properties: Names of the properties shown in each molecule's tooltip.
            ``"affinity"`` is computed against ``protein_target``; every other
            name must be a key of ``EVAL_DICT``. The list is not modified.
        protein_target: Protein target used for the ``"affinity"`` property.
            May be empty when ``"affinity"`` is not requested.
        n_cols: Number of columns in grid. Defaults to 3.
        size: Size of molecule in grid. Defaults to (140, 200).

    Returns:
        HTML to display

    Raises:
        ValueError: If ``"affinity"`` is requested with an empty
            ``protein_target``.
        KeyError: If any other requested property is not in ``EVAL_DICT``.
    """
    wants_affinity = "affinity" in properties
    if wants_affinity and not protein_target:
        raise ValueError("Property 'affinity' requires a non-empty protein_target.")

    # Filter into a new list instead of calling properties.remove(), which would
    # mutate the caller's list and drop only the first "affinity" entry.
    per_molecule_props = [prop for prop in properties if prop != "affinity"]

    result = defaultdict(list)
    result.update(
        {"SMILES": samples, "Name": [f"Generated_{i}" for i in range(len(samples))]},
    )
    if wants_affinity:
        # Build the evaluator locally rather than storing it in the shared
        # EVAL_DICT, so one call's target can never leak into a later call.
        # It scores the whole batch at once.
        # NOTE(review): these values are stored raw, unlike the "prop = value"
        # strings below; kept as-is to preserve the existing tooltip output.
        result["affinity"] = get_affinity_function(protein_target)(samples)

    # Fill per-molecule properties
    for sample in samples:
        for prop in per_molecule_props:
            value = EVAL_DICT[prop](sample)
            result[prop].append(f"{prop} = {value}")

    result_df = pd.DataFrame(result)
    obj = mols2grid.display(
        result_df,
        tooltip=list(result.keys()),
        height=1100,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data