Spaces:
Runtime error
Runtime error
File size: 5,544 Bytes
95f97c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
from utils import *
import torch
from rxnfp.transformer_fingerprints import (
RXNBERTFingerprintGenerator, get_default_model_and_tokenizer, generate_fingerprints
)
class Reaction_model:
    """Retrieval-style baselines that answer a test reaction with a training reaction's actions.

    Every ``generate_*`` method returns one prediction per entry of
    ``test_list``, where a prediction is the ``'actions'`` value of some entry
    of ``train_list``.

    NOTE(review): ``time_it``, ``random``, ``defaultdict``, ``tqdm`` and
    ``json`` are not imported explicitly in this file; presumably they come
    from ``from utils import *`` -- confirm.
    """

    def __init__(self, train_list, test_list):
        """Store both splits and build the default RXNBERT fingerprint generator.

        Args:
            train_list: Training reactions. The methods of this class read the
                keys ``'actions'``, ``'extracted_molecules'``, ``'index'``,
                ``'REACTANT'`` and ``'PRODUCT'`` from each entry.
            test_list: Test reactions; ``'extracted_molecules'``,
                ``'REACTANT'`` and ``'PRODUCT'`` are read from each entry.
        """
        self.train_list = train_list
        self.test_list = test_list
        model, tokenizer = get_default_model_and_tokenizer()
        self.rxnfp_generator = RXNBERTFingerprintGenerator(model, tokenizer)

    @staticmethod
    def _reaction_smiles(rxn):
        """Format a reaction as ``reactant1.reactant2>>product`` using the first product only."""
        return f"{'.'.join(rxn['REACTANT'])}>>{rxn['PRODUCT'][0]}"

    @time_it
    def generate_random(self):
        """Predict by sampling training reactions without replacement.

        Returns:
            A list of ``len(self.test_list)`` ``'actions'`` values.

        Raises:
            ValueError: From ``random.sample`` when there are more test
                reactions than training reactions.
        """
        pred = random.sample(self.train_list, k=len(self.test_list))
        pred = [i['actions'] for i in pred]
        return pred

    @time_it
    def generate_random_compatible_old(self):
        """Predict with a random training reaction that has at most as many molecules.

        Training reactions are bucketed by ``len(extracted_molecules) - 1``.
        For each test reaction a position is drawn uniformly from all buckets
        whose key is <= the test reaction's key, and the buckets are then
        walked in ascending key order to resolve that position.

        Returns:
            A list with one ``'actions'`` value per test reaction.

        Raises:
            KeyError: If no training reaction has exactly the same molecule
                count as a test reaction (the cumulative table has no entry).
        """
        pred_list = []
        len_id_map = defaultdict(list)
        for train_rxn in self.train_list:
            len_id_map[len(train_rxn['extracted_molecules']) - 1].append(train_rxn['index'])

        keys = sorted(len_id_map)
        # accumulated_counts[k] = number of training reactions whose key is <= k.
        accumulated_counts = {}
        count = 0
        for key in keys:
            count += len(len_id_map[key])
            accumulated_counts[key] = count

        for rxn in self.test_list:
            test_token_num = len(rxn['extracted_molecules']) - 1
            idx = random.randint(0, accumulated_counts[test_token_num] - 1)
            for key in keys:
                bucket = len_id_map[key]
                if idx < len(bucket):
                    # assumes each entry's 'index' equals its position in
                    # train_list -- TODO confirm against the dataset builder
                    pred_list.append(self.train_list[bucket[idx]]['actions'])
                    break
                idx -= len(bucket)
        return pred_list

    @time_it
    def generate_random_compatible(self):
        """Predict with a random training reaction that has the same molecule count.

        Training reactions are bucketed by ``len(extracted_molecules) - 1`` and
        one member of the matching bucket is drawn uniformly per test reaction.

        Returns:
            A list with one ``'actions'`` value per test reaction.

        Raises:
            IndexError: From ``random.choice`` when no training reaction shares
                the test reaction's molecule count (the bucket is empty).
        """
        pred_list = []
        len_id_map = defaultdict(list)
        for train_rxn in self.train_list:
            len_id_map[len(train_rxn['extracted_molecules']) - 1].append(train_rxn['index'])

        for rxn in self.test_list:
            mole_num = len(rxn['extracted_molecules']) - 1
            # assumes each entry's 'index' equals its position in train_list
            # -- TODO confirm against the dataset builder
            pred_list.append(self.train_list[random.choice(len_id_map[mole_num])]['actions'])
        return pred_list

    @time_it
    def generate_nn(self, batch_size=2048):
        """Predict with the training reaction whose fingerprint is most cosine-similar.

        Both the training and the test fingerprints are L2-normalised, so the
        matrix product below is a true cosine similarity. (Normalising only the
        test side, as was done previously, cannot change a per-column argmax
        and left the ranking dominated by the norm of the training vectors.)

        Args:
            batch_size: Number of reactions fingerprinted (and, for the test
                set, compared) per batch.

        Returns:
            A list with one ``'actions'`` value per test reaction.

        Raises:
            ValueError: If there are test reactions but no training reactions.
        """
        if not self.test_list:
            return []
        if not self.train_list:
            raise ValueError("generate_nn requires at least one training reaction")

        train_rxns = [self._reaction_smiles(rxn) for rxn in self.train_list]
        test_rxns = [self._reaction_smiles(rxn) for rxn in self.test_list]
        train_rxns_batches = [train_rxns[i:i + batch_size] for i in range(0, len(train_rxns), batch_size)]
        test_rxns_batches = [test_rxns[i:i + batch_size] for i in range(0, len(test_rxns), batch_size)]

        # Use the GPU when present, but keep the baseline runnable on CPU-only hosts.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        train_fps = []
        for batch in tqdm(train_rxns_batches, desc='Generating fingerprints for training reactions'):
            train_fps.extend(self.rxnfp_generator.convert_batch(batch))
        train_fps = torch.tensor(train_fps, device=device)  # N x D
        train_fps = torch.nn.functional.normalize(train_fps, p=2, dim=1)

        most_similar_indices = []
        for batch in tqdm(test_rxns_batches, desc='Generating fingerprints for test reactions'):
            batch_fps = self.rxnfp_generator.convert_batch(batch)
            batch_fps = torch.tensor(batch_fps, device=device)  # BS x D
            batch_fps = torch.nn.functional.normalize(batch_fps, p=2, dim=1)
            similarity_matrix = torch.matmul(train_fps, batch_fps.T)  # N x BS
            # Best training row for every test column.
            most_similar_indices.extend(torch.argmax(similarity_matrix, dim=0).tolist())
        return [self.train_list[i]['actions'] for i in most_similar_indices]

    def save_results(self, gt_list, pred_list, target_file):
        """Write ground-truth/prediction pairs to ``target_file`` as indented JSON.

        Each record has the keys ``"targets"``, ``"indices"`` (the position in
        the input lists) and ``"predictions"``. Pairs are formed with ``zip``,
        so extra items in the longer list are dropped.

        Args:
            gt_list: Ground-truth values.
            pred_list: Predicted values, aligned with ``gt_list``.
            target_file: Output path; missing parent directories are created.
        """
        import os  # local import: keeps this method independent of what utils re-exports

        text_dict_list = [{
            "targets": gt,
            "indices": i,
            "predictions": pred,
        } for i, (gt, pred) in enumerate(zip(gt_list, pred_list))]

        # A first run has no results/<name>/ directory yet; open() alone would fail.
        target_dir = os.path.dirname(target_file)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        with open(target_file, 'w', encoding='utf-8') as f:
            json.dump(text_dict_list, f, indent=4)
def parse_args():
    """Parse the command-line options of the baseline script.

    Options:
        --name:       string, defaults to ``'none'``.
        --train_file: string, defaults to ``None``.
        --test_file:  string, defaults to ``None``.
        --use_tok:    flag, ``False`` unless given.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description="A simple argument parser")
    cli.add_argument('--name', type=str, default='none')
    # Both dataset paths share the same shape: optional string, unset by default.
    for path_option in ('--train_file', '--test_file'):
        cli.add_argument(path_option, type=str, default=None)
    cli.add_argument('--use_tok', action='store_true', default=False)
    return cli.parse_args()
def read_dataset(args):
    """Load the training and test JSON files named on ``args``.

    Progress is reported on stdout: the path before each read and the number
    of loaded samples after it. The training file is read first.

    Args:
        args: Object with ``train_file`` and ``test_file`` path attributes.

    Returns:
        A ``(train_ds, test_ds)`` tuple of the decoded JSON documents.
    """
    def load_json(path):
        # One read cycle: announce, decode as UTF-8 JSON, report the size.
        print(f'Reading {path}...')
        with open(path, 'r', encoding='utf-8') as handle:
            samples = json.load(handle)
        print(f'{len(samples)} samples read.')
        return samples

    train_ds = load_json(args.train_file)
    test_ds = load_json(args.test_file)
    return train_ds, test_ds
def run_baselines(args):
    """Run the three baselines in order, score each one and save its predictions.

    For every baseline the label is printed, predictions are generated,
    ``Metric_calculator`` is called with the ground truth, the predictions and
    ``args.use_tok``, and the pairs are written under ``results/<args.name>/``.

    Args:
        args: Parsed options; ``name``, ``use_tok`` and the dataset paths
            consumed by ``read_dataset`` are used.
    """
    set_random_seed(0)  # presumably seeds the RNGs for reproducible sampling -- confirm in utils
    train_ds, test_ds = read_dataset(args)
    model = Reaction_model(train_ds, test_ds)
    calculator = Metric_calculator()
    gt_list = [sample['actions'] for sample in test_ds]

    # (label, prediction method, output path); executed strictly in this order.
    baselines = (
        ('Random:', model.generate_random, f'results/{args.name}/random.json'),
        ('Random (compatible pattern):', model.generate_random_compatible,
         f'results/{args.name}/random_compatible.json'),
        ('Nearest neighbor:', model.generate_nn, f'results/{args.name}/nn.json'),
    )
    for label, predict, out_path in baselines:
        print(label)
        pred_list = predict()
        calculator(gt_list, pred_list, args.use_tok)
        model.save_results(gt_list, pred_list, out_path)
if __name__ == "__main__":
    # Script entry point: parse the command-line options, then run every baseline.
    args = parse_args()
    run_baselines(args)