File size: 5,530 Bytes
4c9e6d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import pickle

import pandas as pd
from tqdm import tqdm

import torch
from torch_geometric.data import Data, InMemoryDataset
import torch_geometric.utils as geoutils

from rdkit import Chem, RDLogger



def label2onehot(labels, dim, device=None):
    """Convert label indices to one-hot vectors.

    Parameters:
        labels (torch.Tensor): Tensor of label indices of any shape. May be
            a float tensor (e.g. a dense adjacency matrix of bond-type
            indices); it is cast to int64 before scattering, since
            ``scatter_`` requires an integer index tensor.
        dim (int): Size of the one-hot dimension (number of classes).
        device (torch.device, optional): Device to place the output on.

    Returns:
        torch.Tensor: Float tensor of shape ``labels.shape + (dim,)`` with a
        1.0 at each label's index and 0.0 elsewhere.
    """
    out = torch.zeros(list(labels.size()) + [dim])
    if device:
        out = out.to(device)

    # scatter_ demands an int64 index; callers may pass float label tensors
    # (e.g. the dense adjacency produced in load_molecules), so cast here.
    out.scatter_(len(out.size()) - 1, labels.unsqueeze(-1).long(), 1.)

    return out.float()


def get_encoders_decoders(raw_file1, raw_file2, max_atom):
    """
    Given two raw SMILES files, either load the atom and bond encoders/decoders
    if they exist (naming them based on the file names) or create and save them.

    Parameters:
        raw_file1 (str): Path to the first SMILES file (one SMILES per row, no header).
        raw_file2 (str): Path to the second SMILES file (one SMILES per row, no header).
        max_atom (int): Maximum allowed number of atoms in a molecule; larger
            molecules are skipped when building the vocabularies.

    Returns:
        atom_encoder (dict): Mapping from atomic numbers to indices.
        atom_decoder (dict): Mapping from indices to atomic numbers.
        bond_encoder (dict): Mapping from bond types to indices.
        bond_decoder (dict): Mapping from indices to bond types.
    """
    # Suffix built from alphabetically sorted basenames so the cache file
    # names are independent of the argument order.
    name1 = os.path.splitext(os.path.basename(raw_file1))[0]
    name2 = os.path.splitext(os.path.basename(raw_file2))[0]
    suffix = "_".join(sorted([name1, name2]))

    # Cache locations; insertion order matters (matches the return order).
    enc_dir = os.path.join("data", "encoders")
    dec_dir = os.path.join("data", "decoders")
    paths = (
        os.path.join(enc_dir, f"atom_{suffix}.pkl"),   # atom_encoder
        os.path.join(dec_dir, f"atom_{suffix}.pkl"),   # atom_decoder
        os.path.join(enc_dir, f"bond_{suffix}.pkl"),   # bond_encoder
        os.path.join(dec_dir, f"bond_{suffix}.pkl"),   # bond_decoder
    )

    # Fast path: reuse the cached mappings when all four files are present.
    # NOTE: pickle.load executes arbitrary code — only load caches this
    # function itself wrote; do not point it at untrusted files.
    if all(os.path.exists(p) for p in paths):
        loaded = []
        for p in paths:
            with open(p, "rb") as f:
                loaded.append(pickle.load(f))
        print("Loaded existing encoders/decoders!")
        return tuple(loaded)

    # Otherwise, build the vocabularies from scratch.
    print("Creating new encoders/decoders...")
    smiles1 = pd.read_csv(raw_file1, header=None)[0].tolist()
    smiles2 = pd.read_csv(raw_file2, header=None)[0].tolist()

    atom_labels = set()
    bond_labels = set()

    # Collect atom/bond label sets from valid molecules with <= max_atom atoms.
    for smiles in tqdm(smiles1 + smiles2, desc="Processing SMILES"):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None or mol.GetNumAtoms() > max_atom:
            continue
        atom_labels.update(atom.GetAtomicNum() for atom in mol.GetAtoms())
        bond_labels.update(bond.GetBondType() for bond in mol.GetBonds())

    # Index 0 is reserved for padding in both vocabularies:
    # atomic number 0 for atoms, BondType.ZERO for bonds.
    atom_labels.add(0)
    atom_labels = sorted(atom_labels)
    bond_labels = [Chem.rdchem.BondType.ZERO] + sorted(bond_labels)

    atom_encoder = {l: i for i, l in enumerate(atom_labels)}
    atom_decoder = {i: l for i, l in enumerate(atom_labels)}
    bond_encoder = {l: i for i, l in enumerate(bond_labels)}
    bond_decoder = {i: l for i, l in enumerate(bond_labels)}

    # Ensure cache directories exist, then persist all four mappings.
    os.makedirs(enc_dir, exist_ok=True)
    os.makedirs(dec_dir, exist_ok=True)
    mappings = (atom_encoder, atom_decoder, bond_encoder, bond_decoder)
    for obj, p in zip(mappings, paths):
        with open(p, "wb") as f:
            pickle.dump(obj, f)

    print("Encoders/decoders created and saved.")
    return atom_encoder, atom_decoder, bond_encoder, bond_decoder

def load_molecules(data=None, b_dim=32, m_dim=32, device=None, batch_size=32):
    """Unpack a batched torch_geometric graph into dense tensors.

    Parameters:
        data (torch_geometric.data.Batch): Batched graph with ``x``,
            ``edge_index``, ``edge_attr`` and ``batch`` attributes. Required;
            the ``None`` default exists only for signature compatibility.
        b_dim (int): Number of bond classes for one-hot encoding the adjacency.
        m_dim (int): Unused; kept for backward compatibility with callers.
        device (torch.device, optional): Device to move the batch to.
        batch_size (int): Number of graphs in the batch. Assumes every graph
            is padded to the same node count, i.e. total nodes is divisible
            by batch_size — TODO confirm against the dataset pipeline.

    Returns:
        tuple: ``(real_graphs, a_tensor, x_tensor)`` where ``real_graphs`` is
        the per-graph concatenation of flattened node features and flattened
        one-hot adjacency, ``a_tensor`` is the one-hot dense adjacency, and
        ``x_tensor`` is the node-feature tensor reshaped per graph.
    """
    if data is None:
        # Fail loudly instead of crashing on data.to(None-like) below.
        raise ValueError("load_molecules requires a batched `data` argument")

    data = data.to(device)

    # Nodes per graph, hoisted once instead of recomputed three times.
    nodes_per_graph = data.batch.shape[0] // batch_size

    a = geoutils.to_dense_adj(
        edge_index=data.edge_index,
        batch=data.batch,
        edge_attr=data.edge_attr,
        max_num_nodes=nodes_per_graph,
    )
    x_tensor = data.x.view(batch_size, nodes_per_graph, -1)
    a_tensor = label2onehot(a, b_dim, device)

    # Flatten both tensors and concatenate into one vector per graph.
    a_tensor_vec = a_tensor.reshape(batch_size, -1)
    x_tensor_vec = x_tensor.reshape(batch_size, -1)
    real_graphs = torch.cat((x_tensor_vec, a_tensor_vec), dim=-1)

    return real_graphs, a_tensor, x_tensor