File size: 7,886 Bytes
95f97c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import torch
from torch_geometric.data import Dataset
import os
import random
import json
from .data_utils import smiles2data, escape_custom_split_sequence, reformat_smiles, generate_rsmiles

class SynthesisDataset(Dataset):
    """Reaction SMILES dataset yielding text prompts, target text and molecule graphs.

    The task is inferred from ``root`` (checked in this order): a path
    containing ``'PtoR'`` is retrosynthesis (``'retro'``), ``'pretrain'`` is
    masked-molecule reconstruction (``'pretrain'``) and ``'RtoP'`` is forward
    prediction (``'forward'``); anything else raises ``NotImplementedError``.

    Source/target SMILES are read line by line from
    ``<root>/<mode>/src-<mode>.txt`` and ``<root>/<mode>/tgt-<mode>.txt``.
    With ``roundrobin_train`` in train mode they are instead read from the
    numbered shards ``<root>/train/src-train_<k>.txt`` /
    ``tgt-train_<k>.txt``, which :meth:`reload_data` cycles through.

    Args:
        root: dataset directory; must contain ``mol_graphid_map.json`` and,
            when ``use_graph`` is true, ``idx_graph_map.pt``.
        mode: split name; ``'valid'`` is accepted as an alias for ``'val'``.
        smi_max_len: per-molecule character cap applied when building the
            input prompt.
        use_graph: if true, ``__getitem__`` returns one graph per molecule.
        disable_graph_cache: build graphs on the fly with ``smiles2data``
            instead of looking them up in the cached graph map.
        smiles_type: forwarded to ``reformat_smiles``; ``'r_smiles'`` also
            enables regeneration of training pairs via ``generate_rsmiles``.
        roundrobin_train: in train mode, load data from 10 rotating shards.
        test_subset: if positive and ``mode == 'test'``, keep only the first
            ``test_subset`` examples.
    """

    def __init__(self,
                 root,
                 mode,
                 smi_max_len=128,
                 use_graph=True,
                 disable_graph_cache=False,
                 smiles_type='default',
                 roundrobin_train=False,
                 test_subset=-1):
        super(SynthesisDataset, self).__init__(root)
        self.root = root
        if 'PtoR' in root:
            self.task = 'retro'
        elif 'pretrain' in root:
            self.task = 'pretrain'
        elif 'RtoP' in root:
            self.task = 'forward'
        else:
            raise NotImplementedError(f'Invalid task: {root}')
        if mode == 'valid':
            mode = 'val'
        self.mode = mode
        self.smi_max_len = smi_max_len
        self.tokenizer = None
        self.use_graph = use_graph
        self.disable_graph_cache = disable_graph_cache
        self.smiles_type = smiles_type
        self.roundrobin_train = roundrobin_train
        with open(os.path.join(root, 'mol_graphid_map.json')) as f:
            self.mol_idx_map = json.load(f)
        if self.use_graph:
            # NOTE(review): torch.load unpickles arbitrary objects, so only
            # point ``root`` at trusted data. torch >= 2.6 defaults to
            # ``weights_only=True``, which may reject stored graph objects --
            # verify against the installed torch version.
            self.idx_graph_map = torch.load(os.path.join(root, 'idx_graph_map.pt'))

        if self.roundrobin_train and mode == 'train':
            # reload_data() advances the counter before reading, so the shard
            # loaded here is (-2 + 1) % 10 == 9; the next reload reads shard 0.
            self.reload_counter = -2
            self.reload_data()
        else:
            with open(os.path.join(root, mode, f'src-{mode}.txt')) as f:
                self.input_list = f.readlines()
            with open(os.path.join(root, mode, f'tgt-{mode}.txt')) as f:
                self.output_list = f.readlines()
            assert len(self.input_list) == len(self.output_list)
            self.renew_r_smiles()
            self.input_list = self._strip_smiles_lines(self.input_list)
            self.output_list = self._strip_smiles_lines(self.output_list)
        if test_subset > 0 and mode == 'test':
            assert test_subset <= len(self.input_list)
            # Truncate both sides so sources and targets stay aligned.
            self.input_list = self.input_list[:test_subset]
            self.output_list = self.output_list[:test_subset]

    @staticmethod
    def _strip_smiles_lines(lines):
        """Return ``lines`` with surrounding whitespace stripped and all spaces removed."""
        return [smi.strip().replace(' ', '') for smi in lines]

    def reload_data(self):
        """Load the next round-robin training shard; no-op unless ``roundrobin_train``.

        Advances ``reload_counter`` modulo 10 and reads
        ``<root>/train/src-train_<k>.txt`` / ``tgt-train_<k>.txt``. It then
        regenerates R-SMILES if enabled (see :meth:`renew_r_smiles`) and
        strips spaces. Finally it drops pairs whose source and target contain
        a different number of ``'.'`` characters.
        """
        if not self.roundrobin_train:
            return
        self.reload_counter = (self.reload_counter + 1) % 10
        # Free the previous shard -- including the R-SMILES snapshot that
        # renew_r_smiles() keeps -- before reading the next one. Without
        # dropping the snapshot, renew_r_smiles() would keep regenerating
        # pairs from the first shard it ever saw instead of this one.
        for attr in ('input_list', 'output_list', 'input_list_mapped', 'output_list_mapped'):
            if hasattr(self, attr):
                delattr(self, attr)
        with open(os.path.join(self.root, f'train/src-train_{self.reload_counter}.txt')) as f:
            self.input_list = f.readlines()
        with open(os.path.join(self.root, f'train/tgt-train_{self.reload_counter}.txt')) as f:
            self.output_list = f.readlines()
        assert len(self.input_list) == len(self.output_list)
        self.renew_r_smiles()
        self.input_list = self._strip_smiles_lines(self.input_list)
        self.output_list = self._strip_smiles_lines(self.output_list)
        kept_pairs = [
            (input_smiles, output_smiles)
            for input_smiles, output_smiles in zip(self.input_list, self.output_list)
            if input_smiles.count('.') == output_smiles.count('.')
        ]
        self.input_list = [input_smiles for input_smiles, _ in kept_pairs]
        self.output_list = [output_smiles for _, output_smiles in kept_pairs]
        # Logged after filtering so the reported length is the filtered one.
        print(f'Reloaded data from {self.root}/train/src-train_{self.reload_counter}.txt, filtered len={len(self.input_list)}', flush=True)

    def renew_r_smiles(self):
        """Regenerate training pairs with ``generate_rsmiles``.

        Only acts when ``smiles_type == 'r_smiles'`` and ``mode == 'train'``.
        On the first call the current ``input_list``/``output_list`` are kept
        as ``input_list_mapped``/``output_list_mapped``. Every call then
        rebuilds both lists from that snapshot (``generate_rsmiles`` receives
        targets first and sources second) and strips spaces.
        """
        if self.smiles_type == 'r_smiles' and self.mode == 'train':
            # only renew r_smiles for the training set
            if not hasattr(self, 'input_list_mapped'):
                # back up the original input_list and output_list
                self.input_list_mapped = self.input_list
                self.output_list_mapped = self.output_list
            self.output_list, self.input_list = generate_rsmiles(self.output_list_mapped, self.input_list_mapped)
            self.input_list = self._strip_smiles_lines(self.input_list)
            self.output_list = self._strip_smiles_lines(self.output_list)

    def get(self, index):
        """PyG ``Dataset`` hook; delegates to ``__getitem__``."""
        return self.__getitem__(index)

    def len(self):
        """PyG ``Dataset`` hook; returns the number of examples."""
        return len(self)

    def __len__(self):
        return len(self.input_list)

    def _format_smiles_list(self, smiles_list, smi_max_len):
        """Reformat, truncate and wrap each SMILES in ``[START_SMILES]``/``[END_SMILES]``, joined by ``'; '``."""
        return '; '.join(
            f'[START_SMILES]{reformat_smiles(smi, smiles_type=self.smiles_type)[:smi_max_len]}[END_SMILES]'
            for smi in smiles_list
        )

    def make_prompt(self, input_smiles, output_smiles, smi_max_len=512):
        """Build the text prompt and target for one example.

        Args:
            input_smiles: ``'.'``-separated source SMILES. For the forward
                task it may contain ``'<separated>'`` splitting reactants from
                reagents; for retro it must not.
            output_smiles: target SMILES; it is not truncated.
            smi_max_len: per-molecule character cap inside the prompt.

        Returns:
            ``(input_prompt, smiles_list, output_smiles)``:
            - ``smiles_list`` holds the molecules whose graphs ``__getitem__``
              fetches. For retro these are the inputs, for forward the
              reactants (+ reagents), and for pretrain the ``'.'``-split
              *output*.
            - ``output_smiles`` is wrapped in ``[START_SMILES]``/``[END_SMILES]``
              and passed through ``escape_custom_split_sequence``.

        Raises:
            NotImplementedError: if ``self.task`` is not a known task.
        """
        FORWARD_PROMPT = 'Question: Given the following reactant molecules: {}, what are the expected products? Answer: The product molecules are '
        FORWARD_CATALYST_PROMPT = '{}, and the following catalyst molecules: {}'
        RETRO_PROMPT = 'Question: Given the following product molecules: {}, what are the reactants that produce them? Answer: The reactant molecules are '
        PRETRAIN_PROMPT = 'Reconstruct the masked molecule: {}. Answer: '
        if self.task == 'retro':
            assert '<separated>' not in input_smiles
            smiles_list = input_smiles.split('.')
            input_prompt = RETRO_PROMPT.format(self._format_smiles_list(smiles_list, smi_max_len))
        elif self.task == 'forward':
            if '<separated>' in input_smiles:
                reactant_part, reagent_part = input_smiles.split('<separated>')
                reactant_smiles = reactant_part.split('.')
                reagent_smiles = reagent_part.split('.')
                smiles_list = reactant_smiles + reagent_smiles
                input_prompt = FORWARD_CATALYST_PROMPT.format(
                    self._format_smiles_list(reactant_smiles, smi_max_len),
                    self._format_smiles_list(reagent_smiles, smi_max_len),
                )
            else:
                smiles_list = input_smiles.split('.')
                input_prompt = self._format_smiles_list(smiles_list, smi_max_len)
            input_prompt = FORWARD_PROMPT.format(input_prompt)
        elif self.task == 'pretrain':
            input_prompt = PRETRAIN_PROMPT.format(
                self._format_smiles_list(input_smiles.split('.'), smi_max_len))
            smiles_list = output_smiles.split('.')
        else:
            raise NotImplementedError(f'Invalid task: {self.task}')
        output_smiles = f'[START_SMILES]{output_smiles}[END_SMILES]'
        output_smiles = escape_custom_split_sequence(output_smiles)

        return input_prompt, smiles_list, output_smiles

    def __getitem__(self, index):
        """Return ``(index, graph_list, output_text, input_text)`` for one example.

        ``output_text`` is the stripped target followed by a single newline.
        ``graph_list`` is empty unless ``use_graph`` is true. Otherwise it
        holds one graph per entry of the prompt's ``smiles_list``, either
        built with ``smiles2data`` or looked up via ``mol_idx_map`` ->
        ``idx_graph_map``.
        """
        input_smiles = self.input_list[index]
        output_smiles = self.output_list[index]
        input_text, smiles_list, output_text = self.make_prompt(input_smiles, output_smiles, smi_max_len=self.smi_max_len)
        output_text = output_text.strip() + '\n'

        graph_list = []
        if self.use_graph:
            for smiles in smiles_list:
                if self.disable_graph_cache:
                    graph_item = smiles2data(smiles)
                else:
                    assert smiles in self.mol_idx_map
                    idx = self.mol_idx_map[smiles]
                    assert idx in self.idx_graph_map
                    graph_item = self.idx_graph_map[idx]
                graph_list.append(graph_item)

        return index, graph_list, output_text, input_text