File size: 3,095 Bytes
bca3a49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import numpy as np
import torch
from torch import nn
from rdkit import Chem

from dockformer.data.utils import FeatureTensorDict
from dockformer.utils.consts import POSSIBLE_BOND_TYPES, POSSIBLE_ATOM_TYPES, POSSIBLE_CHARGES, POSSIBLE_CHIRALITIES


def get_atom_features(atom: Chem.Atom):
    """Map an RDKit atom to class indices for its element, formal charge and chirality.

    - Element: looked up in POSSIBLE_ATOM_TYPES. Unknown symbols are reported on
      stdout and mapped to the "Ni" class as a placeholder.
    - Formal charge: clipped to the range [-1, 1], then looked up in POSSIBLE_CHARGES.
    - Chiral tag: looked up unchanged in POSSIBLE_CHIRALITIES. If the tag is
      absent, ``.index`` raises (ValueError for a list/tuple; the container type
      is defined elsewhere — TODO confirm).

    Returns:
        dict with keys "atom_type", "atom_charge", "atom_chirality" mapping to int indices.
    """
    # TODO: this is temporary, we need to add more features, for example for Zn
    symbol = atom.GetSymbol()
    is_known = symbol in POSSIBLE_ATOM_TYPES
    if not is_known:
        print(f"********Unknown atom type {symbol}")
    type_index = POSSIBLE_ATOM_TYPES.index(symbol if is_known else "Ni")

    # Charges beyond +/-1 are folded into the extreme classes.
    clipped_charge = min(max(atom.GetFormalCharge(), -1), 1)
    charge_index = POSSIBLE_CHARGES.index(clipped_charge)

    chirality_index = POSSIBLE_CHIRALITIES.index(atom.GetChiralTag())

    return dict(
        atom_type=type_index,
        atom_charge=charge_index,
        atom_chirality=chirality_index,
    )


def get_bond_features(bond: Chem.Bond):
    """Return {"bond_type": index of the bond's RDKit type within POSSIBLE_BOND_TYPES}."""
    rdkit_bond_type = bond.GetBondType()
    features = {}
    features["bond_type"] = POSSIBLE_BOND_TYPES.index(rdkit_bond_type)
    return features


def make_ligand_features(ligand: Chem.Mol) -> FeatureTensorDict:
    """Build per-atom and pairwise bond feature tensors for a ligand molecule.

    Args:
        ligand: RDKit molecule whose atoms and bonds are featurized via
            ``get_atom_features`` / ``get_bond_features``.

    Returns:
        A dict containing:
          - "ligand_atype", "ligand_charge", "ligand_chirality": int64 tensors of
            shape (num_atoms,) holding class indices. These are kept for
            reconstruction at the end of the pipeline.
          - "ligand_bonds": int64 tensor of shape (num_bonds, 3). Each row is
            (atom1_pos, atom2_pos, bond_type_index). Its shape is (0, 3) when
            the ligand has no bonds.
          - "ligand_target_feat": float tensor of shape
            (num_atoms, len(POSSIBLE_ATOM_TYPES) + len(POSSIBLE_CHARGES)
            + len(POSSIBLE_CHIRALITIES)). It concatenates the one-hot atom
            type, charge and chirality.
          - "ligand_bonds_feat": float tensor of shape
            (num_atoms, num_atoms, len(POSSIBLE_BOND_TYPES)) with one-hot bond types.
    """
    atoms_features = []
    # Maps the RDKit atom index to the atom's row position in the feature tensors.
    atom_idx_to_atom_pos_idx = {}
    for atom in ligand.GetAtoms():
        atom_idx_to_atom_pos_idx[atom.GetIdx()] = len(atoms_features)
        atoms_features.append(get_atom_features(atom))
    num_atoms = len(atoms_features)

    def _index_tensor(key: str) -> torch.Tensor:
        # Gather one per-atom class index into an int64 tensor of shape (num_atoms,).
        return torch.tensor(np.array([feat[key] for feat in atoms_features], dtype=np.int64))

    atom_types = _index_tensor("atom_type")
    atom_charges = _index_tensor("atom_charge")
    atom_chiralities = _index_tensor("atom_chirality")

    ligand_target_feat = torch.cat(
        [
            nn.functional.one_hot(atom_types, num_classes=len(POSSIBLE_ATOM_TYPES)).float(),
            nn.functional.one_hot(atom_charges, num_classes=len(POSSIBLE_CHARGES)).float(),
            nn.functional.one_hot(atom_chiralities, num_classes=len(POSSIBLE_CHIRALITIES)).float(),
        ],
        dim=1,
    )

    # One-hot bond-type encoding over atom pairs.
    # NOTE(review): only the [begin, end] entry is set for each bond; [end, begin]
    # stays zero. Confirm whether downstream code symmetrizes this before changing it.
    ligand_bonds_feat = torch.zeros((num_atoms, num_atoms, len(POSSIBLE_BOND_TYPES)))
    ligand_bonds = []
    for bond in ligand.GetBonds():
        atom1_idx = atom_idx_to_atom_pos_idx[bond.GetBeginAtomIdx()]
        atom2_idx = atom_idx_to_atom_pos_idx[bond.GetEndAtomIdx()]
        bond_type = get_bond_features(bond)["bond_type"]
        ligand_bonds.append((atom1_idx, atom2_idx, bond_type))
        ligand_bonds_feat[atom1_idx, atom2_idx, bond_type] = 1

    return {
        # These are used for reconstruction at the end of the pipeline
        "ligand_atype": atom_types,
        "ligand_charge": atom_charges,
        "ligand_chirality": atom_chiralities,
        # reshape keeps the (num_bonds, 3) row layout even for bond-less ligands;
        # torch.tensor([]) alone would produce a 1-D tensor of shape (0,).
        "ligand_bonds": torch.tensor(ligand_bonds, dtype=torch.int64).reshape(-1, 3),
        # these are the actual features
        "ligand_target_feat": ligand_target_feat,
        "ligand_bonds_feat": ligand_bonds_feat.float(),
    }