File size: 4,658 Bytes
c0ec7e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import networkx as nx
import numpy as np
import torch
from rdkit import Chem
from torch_geometric.utils import from_smiles
from torch_geometric.data import Data

from deepscreen.data.featurizers.categorical import one_of_k_encoding_unk, one_of_k_encoding
from deepscreen.utils import get_logger

# Module-level logger; the SMILES featurizers in this module report conversion failures through it.
log = get_logger(__name__)


def atom_features(atom, explicit_H=False, use_chirality=True):
    """
    Featurize a single RDKit atom as a 1-D numpy array (adapted from TransformerCPI 2.0).

    The vector is the concatenation of:
      * element-symbol one-hot over C, N, O, F, P, S, Cl, Br, I, 'other' (10)
      * degree one-hot over 0..6 (7)
      * formal charge and number of radical electrons (2)
      * hybridization one-hot over SP, SP2, SP3, SP3D, SP3D2, 'other' (6)
      * aromaticity flag (1)
      * total-hydrogen-count one-hot over 0..4, only when ``explicit_H`` is False (5)
      * CIP-code one-hot over 'R', 'S' plus a "chirality possible" flag,
        only when ``use_chirality`` is True (3)
    i.e. 34 entries with the default arguments.

    Args:
        atom: an RDKit atom.
        explicit_H: set when hydrogens are explicit graph atoms (e.g. QM8/QM9);
            the hydrogen-count block is then skipped.
        use_chirality: append the CIP-code / chirality block.

    Returns:
        numpy array holding the concatenated features.

    NOTE(review): the degree block uses ``one_of_k_encoding`` (no 'unknown' slot),
    so a degree outside 0..6 presumably raises -- confirm against that helper.
    """
    symbol = ['C', 'N', 'O', 'F', 'P', 'S', 'Cl', 'Br', 'I', 'other']  # 10-dim
    degree = [0, 1, 2, 3, 4, 5, 6]  # 7-dim
    hybridization_type = [Chem.rdchem.HybridizationType.SP,
                          Chem.rdchem.HybridizationType.SP2,
                          Chem.rdchem.HybridizationType.SP3,
                          Chem.rdchem.HybridizationType.SP3D,
                          Chem.rdchem.HybridizationType.SP3D2,
                          'other']  # 6-dim

    # 10 + 7 + 2 + 6 + 1 = 26
    results = (
        one_of_k_encoding_unk(atom.GetSymbol(), symbol)
        + one_of_k_encoding(atom.GetDegree(), degree)
        + [atom.GetFormalCharge(), atom.GetNumRadicalElectrons()]
        + one_of_k_encoding_unk(atom.GetHybridization(), hybridization_type)
        + [atom.GetIsAromatic()]
    )

    # With explicit hydrogens (QM8, QM9) the H count is not meaningful, so skip it.
    # 26 + 5 = 31
    if not explicit_H:
        results = results + one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])

    # 31 + 3 = 34
    if use_chirality:
        # `_CIPCode` may be absent (GetProp would then raise); test for it explicitly
        # instead of a bare `except:` that would also swallow KeyboardInterrupt etc.
        if atom.HasProp('_CIPCode'):
            cip_encoding = one_of_k_encoding_unk(atom.GetProp('_CIPCode'), ['R', 'S'])
        else:
            cip_encoding = [False, False]
        results = results + cip_encoding + [atom.HasProp('_ChiralityPossible')]

    return np.array(results)


def bond_features(bond):
    """
    Featurize an RDKit bond as a numpy array of six flags.

    Layout: [is_single, is_double, is_triple, is_aromatic, is_conjugated, is_in_ring].
    """
    bond_type = bond.GetBondType()
    reference_types = (
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    )
    type_flags = [bond_type == reference for reference in reference_types]
    structure_flags = [bond.GetIsConjugated(), bond.IsInRing()]
    return np.array(type_flags + structure_flags)


def smiles_to_graph_pyg(smiles):
    """
    Convert a SMILES string to a graph with PyTorch Geometric's built-in ``from_smiles``.

    Returns the resulting ``Data`` object, or ``None`` after logging a warning
    if ``from_smiles`` raises.
    """
    graph = None
    try:
        graph = from_smiles(smiles)
    except Exception as e:
        log.warning(f"Failed to featurize the following SMILES to graph: {smiles} due to {str(e)}")
    return graph


def smiles_to_graph(smiles, atom_features: callable = atom_features):
    """
    Convert a SMILES string to a PyTorch Geometric ``Data`` graph using a custom atom featurizer.

    Node features are ``atom_features(atom)`` for every atom, each row divided by its
    own sum. Edges are the molecule's bonds, emitted in both directions.

    Args:
        smiles: SMILES string, parsed with RDKit.
        atom_features: callable mapping an RDKit atom to a 1-D numeric array;
            defaults to this module's ``atom_features``.

    Returns:
        ``Data(x=..., edge_index=...)`` where ``edge_index`` has shape (2, num_edges),
        or ``None`` (after logging a warning) if the SMILES cannot be converted,
        including when RDKit cannot parse it or the molecule has no atoms.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        # RDKit reports a parse failure by returning None rather than raising.
        if mol is None:
            raise ValueError("RDKit could not parse the SMILES string")
        # Without atoms there is no node to attach the fallback self-loop below to.
        if mol.GetNumAtoms() == 0:
            raise ValueError("molecule has no atoms")

        features = []
        for atom in mol.GetAtoms():
            feature = atom_features(atom)
            # Row-normalise so each atom's feature vector sums to 1.
            features.append(feature / sum(feature))
        x = torch.Tensor(np.array(features))

        edges = [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()] for bond in mol.GetBonds()]
        if not edges:
            # Bond-less molecules get a single self-loop on atom 0, presumably so
            # downstream code never sees an empty edge_index -- confirm with consumers.
            edge_index = [[0, 0]]
        else:
            # to_directed() turns every undirected bond into both (i, j) and (j, i).
            directed = nx.Graph(edges).to_directed()
            edge_index = [[src, dst] for src, dst in directed.edges]

        return Data(x=x, edge_index=torch.LongTensor(edge_index).transpose(0, 1))

    except Exception as e:
        log.warning(f"Failed to convert SMILES ({smiles}) to graph due to {str(e)}")
        return None


def smiles_to_mol_features(smiles, num_atom_feat=None):
    """
    Featurize a SMILES string as a dense atom-feature matrix plus an adjacency matrix.

    Args:
        smiles: SMILES string, parsed with RDKit.
        num_atom_feat: unused. The feature width always comes from this module's
            ``atom_features``. The argument is kept (now optional) so existing
            callers that pass it keep working.

    Returns:
        ``(atom_feat, adj_mat)``:
            * ``atom_feat`` is a float64 array of shape (num_atoms, num_features).
              Row ``i`` holds ``atom_features`` of the atom with index ``i``.
            * ``adj_mat`` is RDKit's ``GetAdjacencyMatrix`` for the molecule.
        On failure a warning is logged and a single ``None`` is returned (not a pair).
        Callers must therefore check for ``None`` before unpacking.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        # RDKit reports a parse failure by returning None rather than raising.
        if mol is None:
            raise ValueError("RDKit could not parse the SMILES string")
        if mol.GetNumAtoms() == 0:
            raise ValueError("molecule has no atoms")

        # GetAtoms() yields atoms in index order, so row i is atom i. Each atom is
        # featurized exactly once; np.stack raises if the feature widths disagree.
        atom_feat = np.stack([atom_features(atom) for atom in mol.GetAtoms()]).astype(np.float64)
        adj_mat = np.array(Chem.GetAdjacencyMatrix(mol))

        return atom_feat, adj_mat

    except Exception as e:
        log.warning(f"Failed to featurize the following SMILES to molecular features: {smiles} due to {str(e)}")
        return None