import os
import json
import random
from collections import defaultdict

import torch
from torch_geometric.data import Dataset

from data_provider.caption_dataset import PretrainCaptionDataset
from data_provider.synthesis_dataset import SynthesisDataset

from .context_gen import Reaction_Cluster
from .data_utils import smiles2data, reformat_smiles


def format_float_from_string(s):
    """Format a numeric string to two decimal places; return non-numeric strings unchanged."""
    try:
        float_value = float(s)
        return f'{float_value:.2f}'
    except ValueError:
        return s
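# Illustrative behavior (hypothetical inputs): format_float_from_string('78.3749')
# returns '78.37', while a non-numeric value such as 'white solid' passes through unchanged.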


class MoleculeAbstract(Dataset):
    def __init__(self,
                 root,
                 rxn_num=1000,
                 rxn_batch_size=4,
                 smi_max_len=128,
                 prompt=None,
                 disable_graph_cache=False,
                 disable_graphs=False,
                 context_style='weighted_rxn',
                 use_caption_dataset=False,
                 caption_batch_num=10000,
                 synthesis_datasetpath=None,
                 synthesis_batch_num=10000,
                 reverse_ratio=0.5,
                 enable_abstract=True,
                 enable_property=True,
                 smiles_type='default',
                 mode='train'
                 ):
        super(MoleculeAbstract, self).__init__(root)
        self.root = root
        self.rxn_num = rxn_num
        self.rxn_batch_size = rxn_batch_size
        self.smi_max_len = smi_max_len
        self.context_style = context_style
        self.tokenizer = None
        self.disable_graph_cache = disable_graph_cache
        self.disable_graphs = disable_graphs
        self.use_caption_dataset = use_caption_dataset
        self.smiles_type = smiles_type
        # Optional auxiliary caption data mixed into each epoch.
        if use_caption_dataset:
            self.caption_dataset = PretrainCaptionDataset(
                os.path.join(root, '../caption_data'),
                smi_max_len=smi_max_len,
                use_graph=not self.disable_graphs,
                disable_graph_cache=disable_graph_cache,
                smiles_type=smiles_type,
            )
            self.caption_batch_num = caption_batch_num
        # Optional auxiliary synthesis data, enabled by passing a dataset path.
        self.use_synthesis_dataset = bool(synthesis_datasetpath)
        if self.use_synthesis_dataset:
            self.synthesis_dataset = SynthesisDataset(
                synthesis_datasetpath,
                'train',
                smi_max_len,
                roundrobin_train=True,
                use_graph=not disable_graphs,
                disable_graph_cache=disable_graph_cache,
                smiles_type='default',
            )
            self.synthesis_batch_num = synthesis_batch_num
        if not self.disable_graphs:
            # Cached mapping from canonical SMILES to precomputed molecular graphs.
            self.mol_graph_map = torch.load(os.path.join(self.root, 'mol_graph_map.pt'))
        reaction_filename = 'reactions/reactions_test.json' if (mode == 'test') else 'reactions/reactions.json'
        if smiles_type == 'r_smiles':
            reaction_filename = 'reactions/reactions_wRSMILES.json'
        self.cluster = Reaction_Cluster(self.root, reaction_filename=reaction_filename, reverse_ratio=reverse_ratio)
        self.reload_data_list()
        self.abstract_max_len = 10240
        self.property_max_len = 10240
        self.enable_abstract = enable_abstract
        self.enable_property = enable_property

    def get(self, index):
        return self.__getitem__(index)

    def len(self):
        return len(self)

    def __len__(self):
        # Reaction batches come first, followed by caption batches, then synthesis batches.
        data_len = len(self.data_list)
        if self.use_caption_dataset:
            data_len += len(self.caption_index_list)
        if self.use_synthesis_dataset:
            data_len += len(self.synthesis_index_list)
        return data_len

    def reload_data_list(self):
        # Resample the reaction-context batches and the auxiliary dataset indices.
        k = self.rxn_batch_size
        if self.context_style == 'weighted_rxn':
            self.data_list = self.cluster(self.rxn_num, k=k)
        elif self.context_style == 'uniform_rxn':
            self.data_list = self.cluster.generate_batch_uniform_rxn(self.rxn_num, k=k)
        elif self.context_style == 'uniform_mol':
            self.data_list = self.cluster.generate_batch_uniform_mol(self.rxn_num, k=k)
        elif self.context_style == 'single_mol':
            self.data_list = self.cluster.generate_batch_single(self.rxn_num)
        elif self.context_style == 'hybrid':
            # Half weighted reaction batches, half uniform molecule batches.
            self.data_list = self.cluster(self.rxn_num // 2, k=k)
            self.data_list += self.cluster.generate_batch_uniform_mol(self.rxn_num // 2, k=k)
        else:
            raise NotImplementedError(f'unknown context_style: {self.context_style}')
        if self.use_caption_dataset:
            assert self.caption_batch_num * k <= len(self.caption_dataset)
            caption_index_list = random.sample(range(len(self.caption_dataset)), self.caption_batch_num * k)
            # Group the sampled caption indices into batches of size k.
            self.caption_index_list = [caption_index_list[i * k:(i + 1) * k] for i in range(self.caption_batch_num)]
        else:
            self.caption_index_list = []
        if self.use_synthesis_dataset:
            if self.synthesis_dataset.roundrobin_train:
                self.synthesis_dataset.reload_data()
            assert self.synthesis_batch_num <= len(self.synthesis_dataset)
            self.synthesis_index_list = random.sample(range(len(self.synthesis_dataset)), self.synthesis_batch_num)
        else:
            self.synthesis_index_list = []

    def make_prompt(self, mol_batch, smi_max_len=128):
        mol_prompt_list, text_prompt_list = [], []
        last_role = None
        for mol_dict in mol_batch:
            smiles = mol_dict['canon_smiles']
            if self.smiles_type == 'r_smiles':
                if 'r_smiles' in mol_dict:
                    smiles = mol_dict['r_smiles']
                # else:
                #     smiles = reformat_smiles(smiles, smiles_type='restricted')
            else:
                smiles = reformat_smiles(smiles, smiles_type=self.smiles_type)
            mol_prompt = f'[START_SMILES]{smiles[:smi_max_len]}[END_SMILES]. '
            if 'role' in mol_dict:
                role = {
                    'REACTANT': 'Reactant',
                    'CATALYST': 'Catalyst',
                    'SOLVENT': 'Solvent',
                    'PRODUCT': 'Product',
                }[mol_dict['role']]
                # Emit the role label only when it changes between consecutive molecules.
                if last_role != role:
                    mol_prompt = f'{role}: {mol_prompt}'
                last_role = role
            text_prompt = self.make_abstract(mol_dict)
            mol_prompt_list.append(mol_prompt)
            text_prompt_list.append(text_prompt)
        return mol_prompt_list, text_prompt_list
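    # Illustrative output for a two-molecule batch (hypothetical SMILES and text):
    #   mol_prompt_list  == ['Reactant: [START_SMILES]CCO[END_SMILES]. ',
    #                        'Product: [START_SMILES]CC=O[END_SMILES]. ']
    #   text_prompt_list == ['[Abstract] Ethanol is ... [Properties] Boiling Point: 78.37; . ', ...]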

    def make_abstract(self, mol_dict):
        prompt = ''
        if self.enable_abstract and 'abstract' in mol_dict:
            abstract_string = mol_dict['abstract'][:self.abstract_max_len]
            prompt += f'[Abstract] {abstract_string} '
        if self.enable_property:
            property_string = ''
            property_dict = mol_dict.get('property', {})
            for property_key in ['Experimental Properties', 'Computed Properties']:
                if property_key not in property_dict:
                    continue
                for key, value in property_dict[property_key].items():
                    if isinstance(value, float):
                        key_value_string = f'{key}: {value:.2f}; '
                    elif isinstance(value, str):
                        float_value = format_float_from_string(value)
                        key_value_string = f'{key}: {float_value}; '
                    else:
                        key_value_string = f'{key}: {value}; '
                    # Stop once the accumulated property text would exceed the budget.
                    if len(property_string + key_value_string) > self.property_max_len:
                        break
                    property_string += key_value_string
            if property_string:
                property_string = property_string[:self.property_max_len]
                prompt += f'[Properties] {property_string}. '
        return prompt

    def get_caption_data(self, index):
        caption_index = self.caption_index_list[index]
        graph_list, mol_prompt_list, text_prompt_list = [], [], []
        for idx in caption_index:
            graph_item, text, smiles_prompt = self.caption_dataset[idx]
            graph_list.append(graph_item)
            mol_prompt_list.append(smiles_prompt)
            text_prompt_list.append(text)
        return graph_list, mol_prompt_list, text_prompt_list

    def get_synthesis_data(self, index):
        synthesis_index = self.synthesis_index_list[index]
        _, graph_list, output_text, input_text = self.synthesis_dataset[synthesis_index]
        return graph_list, [input_text], [output_text]

    def __getitem__(self, index):
        # Indices are laid out as: reaction batches, then caption batches,
        # then synthesis batches (matching __len__ above).
        if index < len(self.data_list):
            mol_batch = self.data_list[index]
        elif index < len(self.data_list) + len(self.caption_index_list):
            assert self.use_caption_dataset
            return self.get_caption_data(index - len(self.data_list))
        else:
            assert self.use_synthesis_dataset
            return self.get_synthesis_data(index - (len(self.data_list) + len(self.caption_index_list)))
        graph_list = []
        for mol_dict in mol_batch:
            smiles = mol_dict['canon_smiles']
            if self.disable_graphs:
                graph_item = None
            else:
                if self.disable_graph_cache:
                    graph_item = smiles2data(smiles)
                else:
                    assert smiles in self.mol_graph_map
                    graph_item = self.mol_graph_map[smiles]
            graph_list.append(graph_item)
        mol_prompt_list, text_prompt_list = self.make_prompt(mol_batch, smi_max_len=self.smi_max_len)
        return graph_list, mol_prompt_list, text_prompt_list
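

# Minimal usage sketch. The root path and its expected contents
# ('mol_graph_map.pt' plus a 'reactions/' folder, as read in __init__ above)
# are assumptions about the on-disk layout; adjust to your data.
if __name__ == '__main__':
    dataset = MoleculeAbstract(
        root='data/molecule_abstract',  # hypothetical path
        rxn_num=8,
        rxn_batch_size=2,
    )
    graph_list, mol_prompts, text_prompts = dataset[0]
    print(len(dataset), 'items; first mol prompt:', mol_prompts[0][:80])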