import os
import json
import random

import torch
from torch_geometric.data import Dataset

from .data_utils import smiles2data, escape_custom_split_sequence, reformat_smiles, generate_rsmiles
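
# Expected behavior of the .data_utils helpers, inferred from their usage in
# this file (assumptions, not verified against that module): smiles2data
# featurizes a SMILES string into a torch_geometric graph; reformat_smiles
# rewrites a SMILES according to `smiles_type`; generate_rsmiles produces
# root-aligned (R-SMILES) reaction pairs; escape_custom_split_sequence handles
# the [START_SMILES]/[END_SMILES] markers for the downstream tokenizer.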

class SynthesisDataset(Dataset):
    """Reaction dataset pairing SMILES text prompts with molecular graphs.

    Supports retrosynthesis ('PtoR'), forward prediction ('RtoP'), and
    pretraining roots; the task is inferred from the root path.
    """

    def __init__(self,
                 root,
                 mode,
                 smi_max_len=128,
                 use_graph=True,
                 disable_graph_cache=False,
                 smiles_type='default',
                 roundrobin_train=False,
                 test_subset=-1,
                 ):
        super().__init__(root)
        self.root = root
        # Infer the task from the dataset root path.
        if 'PtoR' in root:
            self.task = 'retro'
        elif 'pretrain' in root:
            self.task = 'pretrain'
        elif 'RtoP' in root:
            self.task = 'forward'
        else:
            raise NotImplementedError(f'Invalid task: {root}')
        if mode == 'valid':
            mode = 'val'
        self.mode = mode
        self.smi_max_len = smi_max_len
        self.tokenizer = None
        self.use_graph = use_graph
        self.disable_graph_cache = disable_graph_cache
        self.smiles_type = smiles_type
        self.roundrobin_train = roundrobin_train

        # SMILES -> graph-id map; the graphs themselves are cached in idx_graph_map.pt.
        with open(os.path.join(root, 'mol_graphid_map.json')) as f:
            self.mol_idx_map = json.load(f)
        if self.use_graph:
            self.idx_graph_map = torch.load(os.path.join(root, 'idx_graph_map.pt'))
        if self.roundrobin_train and mode == 'train':
            # reload_data() advances this counter modulo 10, so the first call
            # loads shard (-2 + 1) % 10 == 9.
            self.reload_counter = -2
            self.reload_data()
        else:
            with open(os.path.join(root, mode, f'src-{mode}.txt')) as f:
                self.input_list = f.readlines()
            with open(os.path.join(root, mode, f'tgt-{mode}.txt')) as f:
                self.output_list = f.readlines()
            assert len(self.input_list) == len(self.output_list)
            self.renew_r_smiles()
            self.input_list = [smi.strip().replace(' ', '') for smi in self.input_list]
            self.output_list = [smi.strip().replace(' ', '') for smi in self.output_list]

        if test_subset > 0 and mode == 'test':
            assert test_subset <= len(self.input_list)
            self.input_list = self.input_list[:test_subset]
            self.output_list = self.output_list[:test_subset]

    def reload_data(self):
        """Load the next round-robin training shard and drop malformed pairs."""
        if not self.roundrobin_train:
            return
        self.reload_counter = (self.reload_counter + 1) % 10
        if hasattr(self, 'input_list'):
            del self.input_list
        if hasattr(self, 'output_list'):
            del self.output_list
        with open(os.path.join(self.root, f'train/src-train_{self.reload_counter}.txt')) as f:
            self.input_list = f.readlines()
        with open(os.path.join(self.root, f'train/tgt-train_{self.reload_counter}.txt')) as f:
            self.output_list = f.readlines()
        assert len(self.input_list) == len(self.output_list)
        self.renew_r_smiles()
        self.input_list = [smi.strip().replace(' ', '') for smi in self.input_list]
        self.output_list = [smi.strip().replace(' ', '') for smi in self.output_list]

        # Keep only pairs whose source and target contain the same number of
        # dot-separated molecules.
        input_list, output_list = [], []
        for input_smiles, output_smiles in zip(self.input_list, self.output_list):
            if input_smiles.count('.') != output_smiles.count('.'):
                continue
            input_list.append(input_smiles)
            output_list.append(output_smiles)
        self.input_list = input_list
        self.output_list = output_list
        print(f'Reloaded data from {self.root}/train/src-train_{self.reload_counter}.txt, '
              f'filtered len={len(self.input_list)}', flush=True)

    def renew_r_smiles(self):
        """Regenerate root-aligned R-SMILES; only applies to the training set."""
        if self.smiles_type == 'r_smiles' and self.mode == 'train':
            if not hasattr(self, 'input_list_mapped'):
                # Back up the original lists so every renewal regenerates
                # R-SMILES from the same source data.
                self.input_list_mapped = self.input_list
                self.output_list_mapped = self.output_list
            self.output_list, self.input_list = generate_rsmiles(self.output_list_mapped, self.input_list_mapped)
            self.input_list = [smi.strip().replace(' ', '') for smi in self.input_list]
            self.output_list = [smi.strip().replace(' ', '') for smi in self.output_list]

    def get(self, index):
        return self.__getitem__(index)

    def len(self):
        # torch_geometric.data.Dataset declares len() as abstract; delegate to __len__.
        return len(self)

    def __len__(self):
        return len(self.input_list)

    def make_prompt(self, input_smiles, output_smiles, smi_max_len=512):
        FORWARD_PROMPT = 'Question: Given the following reactant molecules: {}, what are the expected products? Answer: The product molecules are '
        FORWARD_CATALYST_PROMPT = '{}, and the following catalyst molecules: {}'
        RETRO_PROMPT = 'Question: Given the following product molecules: {}, what are the reactants that produce them? Answer: The reactant molecules are '
        # RETRO_PROMPT = 'Predict the reaction that produces the following product: {} '
        PRETRAIN_PROMPT = 'Reconstruct the masked molecule: {}. Answer: '
        # Rewrite each SMILES according to self.smiles_type and truncate to smi_max_len.
        smiles_wrapper = lambda x: reformat_smiles(x, smiles_type=self.smiles_type)[:smi_max_len]

        if self.task == 'retro':
            assert '<separated>' not in input_smiles
            smiles_list = input_smiles.split('.')
            in_prompt = '; '.join([f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES]' for smi in smiles_list])
            input_prompt = RETRO_PROMPT.format(in_prompt)
        elif self.task == 'forward':
            if '<separated>' in input_smiles:
                # Reactants and reagents/catalysts arrive as two '<separated>'-joined groups.
                reactant_smiles, reagent_smiles = input_smiles.split('<separated>')
                reactant_smiles = reactant_smiles.split('.')
                reagent_smiles = reagent_smiles.split('.')
                reactant_prompt = '; '.join([f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES]' for smi in reactant_smiles])
                reagent_prompt = '; '.join([f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES]' for smi in reagent_smiles])
                smiles_list = reactant_smiles + reagent_smiles
                input_prompt = FORWARD_CATALYST_PROMPT.format(reactant_prompt, reagent_prompt)
            else:
                smiles_list = input_smiles.split('.')
                input_prompt = '; '.join([f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES]' for smi in smiles_list])
            input_prompt = FORWARD_PROMPT.format(input_prompt)
        elif self.task == 'pretrain':
            in_prompt = '; '.join([f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES]' for smi in input_smiles.split('.')])
            input_prompt = PRETRAIN_PROMPT.format(in_prompt)
            # For pretraining, the graphs come from the target molecules.
            smiles_list = output_smiles.split('.')

        # output_smiles = ' '.join([f'[START_SMILES]{smi[:smi_max_len]}[END_SMILES]' for smi in output_smiles.split('.')])
        output_smiles = f'[START_SMILES]{output_smiles}[END_SMILES]'
        output_smiles = escape_custom_split_sequence(output_smiles)
        return input_prompt, smiles_list, output_smiles
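
    # Illustrative example (an assumption here: smiles_type='default' leaves
    # the SMILES unchanged). For the retro task, make_prompt('CCO', 'CC.O') yields:
    #   input_prompt  -> 'Question: Given the following product molecules:
    #                     [START_SMILES]CCO[END_SMILES], what are the reactants
    #                     that produce them? Answer: The reactant molecules are '
    #   smiles_list   -> ['CCO']
    #   output_smiles -> '[START_SMILES]CC.O[END_SMILES]' (before escaping)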

    def __getitem__(self, index):
        input_smiles = self.input_list[index]
        output_smiles = self.output_list[index]
        input_text, smiles_list, output_text = self.make_prompt(
            input_smiles, output_smiles, smi_max_len=self.smi_max_len)
        output_text = output_text.strip() + '\n'
        graph_list = []
        if self.use_graph:
            for smiles in smiles_list:
                if self.disable_graph_cache:
                    # Featurize on the fly instead of reading the precomputed cache.
                    graph_item = smiles2data(smiles)
                else:
                    assert smiles in self.mol_idx_map
                    idx = self.mol_idx_map[smiles]
                    assert idx in self.idx_graph_map
                    graph_item = self.idx_graph_map[idx]
                graph_list.append(graph_item)
        return index, graph_list, output_text, input_text
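
if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. The root path and
    # its layout (per-split src-/tgt- text files, mol_graphid_map.json, and
    # idx_graph_map.pt) are assumptions; run via `python -m <package>.<module>`
    # so the relative import of .data_utils resolves.
    dataset = SynthesisDataset(
        root='data/USPTO-50k-PtoR',  # hypothetical path; 'PtoR' selects the retro task
        mode='test',
        smi_max_len=128,
        use_graph=True,
        test_subset=100,             # evaluate on the first 100 test reactions only
    )
    index, graph_list, output_text, input_text = dataset[0]
    print(input_text)
    print(output_text)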