Spaces:
Runtime error
Runtime error
File size: 7,673 Bytes
95f97c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
from data_provider.context_gen import *
def parse_args():
    """Build this script's command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``name``, ``seed``, ``epochs``,
        ``chunk_size``, ``rxn_num``, ``k`` and ``root``.
    """
    parser = argparse.ArgumentParser(description="A simple argument parser")
    # Script arguments as (flag, default, type) triples, registered in order.
    option_table = (
        ('--name', 'none', str),
        ('--seed', 0, int),
        ('--epochs', 100, int),
        ('--chunk_size', 100, int),
        ('--rxn_num', 50000, int),
        ('--k', 4, int),
        ('--root', 'data/pretrain_data', str),
    )
    for flag, default_value, value_type in option_table:
        parser.add_argument(flag, default=default_value, type=value_type)
    return parser.parse_args()
def pad_shorter_array(arr1, arr2):
    """Zero-pad the shorter of two arrays at the end so both share a length.

    Lengths are compared along the first axis. The longer array (or both,
    when the lengths already match) is returned unchanged. Intended for 1-D
    arrays: ``np.pad`` with a single ``(before, after)`` pair pads every axis.

    Returns:
        Tuple ``(arr1, arr2)`` with equal first-axis length.
    """
    gap = arr1.shape[0] - arr2.shape[0]
    if gap > 0:
        arr2 = np.pad(arr2, (0, gap), 'constant')
    elif gap < 0:
        arr1 = np.pad(arr1, (0, -gap), 'constant')
    return arr1, arr2
def plot_distribution(values, target_path, x_lim=None, y_lim=None, chunk_size=100, color='blue'):
    """Save a bar chart of chunk-averaged values, sorted largest-first.

    ``values`` is split into consecutive chunks of ``chunk_size``, and any
    trailing partial chunk is dropped. Each bar is the mean of one chunk, and
    the bars are sorted in descending order. X ticks sit at chunk positions
    but are labelled with the original (un-chunked) index.

    Args:
        values: 1-D numeric sequence (list or ndarray).
        target_path: output file path passed to ``plt.savefig``.
        x_lim: optional ``(low, high)`` x-axis limits, in chunk units.
        y_lim: optional ``(low, high)`` y-axis limits.
        chunk_size: number of consecutive values averaged into one bar.
        color: bar color.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')
    # Accept plain lists as well as ndarrays (``.reshape`` needs an array).
    values = np.asarray(values)
    num_full_chunks = len(values) // chunk_size
    chunk_means = np.mean(values[:num_full_chunks * chunk_size].reshape(-1, chunk_size), axis=1)
    chunk_means = np.sort(chunk_means)[::-1]

    plt.figure(figsize=(10, 4), dpi=100)
    plt.bar(np.arange(len(chunk_means)), chunk_means, color=color)
    # Tick positions are in chunk units; labels show the original index.
    tick_labels = np.array([0, 200000, 400000, 600000, 800000, 1000000], dtype=int)
    plt.xticks((tick_labels / chunk_size).astype(int), tick_labels)
    plt.ylabel('Molecule Frequency', fontsize=20)
    if x_lim:
        plt.xlim(*x_lim)
    if y_lim:
        plt.ylim(*y_lim)
    plt.tick_params(axis='both', which='major', labelsize=12)
    plt.tight_layout(pad=0.5)
    plt.savefig(target_path)
    print(f'Figure saved to {target_path}')
    # Close (not just clear) the figure so repeated calls do not keep
    # accumulating open figures in pyplot's figure manager.
    plt.close()
def plot_compare_distribution(list1, list2, target_path, x_lim=None, y_lim=None, labels=('Random', 'Ours'), colors=('blue', 'orange'), chunk_size=100):
    """Overlay two chunk-averaged, largest-first bar charts and save the figure.

    Both inputs are zero-padded to a common length with ``pad_shorter_array``.
    Each is then split into consecutive chunks of ``chunk_size`` (a trailing
    partial chunk is dropped), averaged per chunk, and sorted in descending
    order independently. ``list1`` is drawn first, so ``list2`` is layered on
    top.

    Args:
        list1, list2: 1-D numeric sequences; ``list1`` gets ``labels[0]`` and
            ``colors[0]``, ``list2`` gets ``labels[1]`` and ``colors[1]``.
        target_path: output file path passed to ``plt.savefig``.
        x_lim: optional ``(low, high)`` x-axis limits, in chunk units.
        y_lim: optional ``(low, high)`` y-axis limits.
        labels: legend labels for ``(list1, list2)``.
        colors: bar colors for ``(list1, list2)``.
        chunk_size: number of consecutive values averaged into one bar.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')
    # pad_shorter_array reads ``.shape``, so convert lists to arrays first.
    list1, list2 = pad_shorter_array(np.asarray(list1), np.asarray(list2))
    # Count chunks AFTER padding. Counting from the unpadded list1 silently
    # discarded the tail of list2 whenever list2 was the longer input.
    num_full_chunks = len(list1) // chunk_size
    values1, values2 = [
        np.sort(np.mean(values[:num_full_chunks * chunk_size].reshape(-1, chunk_size), axis=1))[::-1]
        for values in (list1, list2)
    ]

    plt.figure(figsize=(10, 6), dpi=100)
    x = np.arange(len(values1))
    plt.bar(x, values1, color=colors[0], label=labels[0], alpha=0.6)
    plt.bar(x, values2, color=colors[1], label=labels[1], alpha=0.5)
    # Tick positions are in chunk units; labels show the original index.
    tick_labels = np.array([0, 200000, 400000, 600000, 800000, 1000000], dtype=int)
    plt.xticks((tick_labels / chunk_size).astype(int), tick_labels)
    plt.ylabel('Molecule Frequency', fontsize=20)
    if x_lim:
        plt.xlim(*x_lim)
    if y_lim:
        plt.ylim(*y_lim)
    plt.tick_params(axis='both', which='major', labelsize=18)
    plt.tight_layout(pad=0.5)
    plt.legend(fontsize=24, loc='upper right')
    plt.savefig(target_path)
    print(f'Figure saved to {target_path}')
    # Close (not just clear) so repeated calls do not leak open figures.
    plt.close()
def statistics(args):
    """Print summary statistics of the reaction and molecule-property data.

    Loads ``Reaction_Cluster(args.root)`` and prints:
    - the number of reactions;
    - the number of distinct molecules appearing under any of the
      REACTANT / CATALYST / SOLVENT / PRODUCT keys;
    - how many ``property_data`` entries carry an abstract, a property block,
      and experimental / computed property lists, with per-molecule item
      counts.

    Ratios whose denominator is zero are reported as 0 instead of raising
    ZeroDivisionError.

    NOTE(review): property counts are divided by ``mol_num`` (molecules seen
    in reactions), so ratios can exceed 100% if ``property_data`` contains
    molecules that never appear in a reaction. Confirm this is intended.
    """
    def _div(num, den):
        # Empty datasets report 0 rather than crashing the summary.
        return num / den if den else 0.0

    if args.seed:
        set_random_seed(args.seed)
    # 1141864 rxns from ord
    # 1120773 rxns from uspto
    cluster = Reaction_Cluster(args.root)
    rxn_num = len(cluster.reaction_data)

    mol_set = set()
    for rxn_dict in cluster.reaction_data:
        for key in ('REACTANT', 'CATALYST', 'SOLVENT', 'PRODUCT'):
            mol_set.update(rxn_dict[key])
    mol_num = len(mol_set)

    abstract_num = 0
    property_num = 0
    calculated_property_num = 0
    experimental_property_num = 0
    # Running totals of list lengths; divided by mol_num when printed.
    experimental_property_total = 0
    calculated_property_total = 0
    for mol_dict in cluster.property_data:
        if 'abstract' in mol_dict:
            abstract_num += 1
        if 'property' in mol_dict:
            property_num += 1
            props = mol_dict['property']
            if 'Experimental Properties' in props:
                experimental_property_num += 1
                experimental_property_total += len(props['Experimental Properties'])
            if 'Computed Properties' in props:
                calculated_property_num += 1
                calculated_property_total += len(props['Computed Properties'])

    print(f'Reaction Number: {rxn_num}')
    print(f'Molecule Number: {mol_num}')
    print(f'Abstract Number: {abstract_num}/{mol_num}({_div(abstract_num, mol_num)*100:.2f}%)')
    print(f'Property Number: {property_num}/{mol_num}({_div(property_num, mol_num)*100:.2f}%)')
    print(f'- Experimental Properties Number: {experimental_property_num}/{property_num}({_div(experimental_property_num, property_num)*100:.2f}%), {_div(experimental_property_total, mol_num):.2f} items per molecule')
    print(f'- Computed Properties: {calculated_property_num}/{property_num}({_div(calculated_property_num, property_num)*100:.2f}%), {_div(calculated_property_total, mol_num):.2f} items per molecule')
def visualize(args):
    """Plot cluster-based mol/rxn distributions against a random baseline.

    Calls ``cluster.visualize_mol_distribution()`` directly and through
    ``cluster._randomly(...)``. The latter presumably runs the same method
    under random sampling; confirm in ``Reaction_Cluster``. Writes six PDFs
    under ``results/<args.name>/``, creating the directory if needed.
    """
    import os

    if args.seed:
        set_random_seed(args.seed)
    cluster = Reaction_Cluster(args.root)
    prob_values, rxn_weights = cluster.visualize_mol_distribution()
    rand_prob_values, rand_rxn_weights = cluster._randomly(
        cluster.visualize_mol_distribution
    )
    fig_root = f'results/{args.name}/'
    # savefig fails if the output directory does not exist yet.
    os.makedirs(fig_root, exist_ok=True)
    plot_distribution(prob_values, fig_root+'mol_distribution.pdf')
    plot_distribution(rxn_weights, fig_root+'rxns_distribution.pdf')
    plot_distribution(rand_prob_values, fig_root+'mol_distribution_random.pdf')
    plot_distribution(rand_rxn_weights, fig_root+'rxns_distribution_random.pdf')
    # plot_compare_distribution labels its FIRST series 'Random' by default,
    # so the random baseline must go first (same order as visualize_frequency).
    plot_compare_distribution(rand_prob_values, prob_values, fig_root+'Compare_mol.pdf', y_lim=(-0.5, 15.5))
    plot_compare_distribution(rand_rxn_weights, rxn_weights, fig_root+'Compare_rxns.pdf')
def visualize_frequency(args):
    """Plot molecule sampling frequency before vs. after adjustment, with caching.

    Computes ``cluster.visualize_mol_frequency(rxn_num, k, epochs)`` directly
    and through ``cluster._randomly(...)``. The results are cached in
    ``results/<args.name>/freq_E<epochs>_Rxn<rxn_num>_K<k>.npy``, and the
    cache is reused when it exists. Saves a plot of each series plus an
    overlay comparison.
    """
    import os

    if args.seed:
        set_random_seed(args.seed)
    fig_root = f'results/{args.name}/'
    # Both np.save and savefig fail if the output directory is missing.
    os.makedirs(fig_root, exist_ok=True)
    name_suffix = f'E{args.epochs}_Rxn{args.rxn_num}_K{args.k}'
    cache_path = os.path.join(fig_root, f'freq_{name_suffix}.npy')
    if os.path.exists(cache_path):
        # NOTE: allow_pickle=True unpickles file contents; only load caches
        # written by this script.
        mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq = np.load(cache_path, allow_pickle=True)
    else:
        cluster = Reaction_Cluster(args.root)
        freq_kwargs = dict(rxn_num=args.rxn_num, k=args.k, epochs=args.epochs)
        mol_freq, rxn_freq = cluster.visualize_mol_frequency(**freq_kwargs)
        rand_mol_freq, rand_rxn_freq = cluster._randomly(
            cluster.visualize_mol_frequency, **freq_kwargs
        )
        # Build a 1-D object array explicitly. np.array([...], dtype=object)
        # silently yields a 2-D (4, N) array when all four results share a
        # length. Both layouts still unpack into four items on load.
        cache = np.empty(4, dtype=object)
        for idx, freq in enumerate((mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq)):
            cache[idx] = np.asarray(freq)
        np.save(cache_path, cache, allow_pickle=True)

    color1 = '#FA7F6F'
    color2 = '#80AFBF'
    color3 = '#FFBE7A'  # currently unused palette entry
    freq_x_lim = (-50000 // args.chunk_size, 1200000 // args.chunk_size)
    plot_distribution(mol_freq, fig_root+f'mol_frequency_{name_suffix}.pdf', x_lim=freq_x_lim, y_lim=(-2, 62), chunk_size=args.chunk_size, color=color2)
    plot_distribution(rand_mol_freq, fig_root+f'mol_frequency_random_{name_suffix}.pdf', x_lim=freq_x_lim, y_lim=(-2, 62), chunk_size=args.chunk_size, color=color2)
    plot_compare_distribution(rand_mol_freq, mol_freq, fig_root+f'Compare_mol_{name_suffix}.pdf', y_lim=(-2, 62), labels=['Before Adjustment', 'After Adjustment'], colors=[color1, color2], chunk_size=args.chunk_size)
    # Reaction-frequency plots (rxn_freq / rand_rxn_freq) are currently disabled:
    # plot_distribution(rxn_freq, fig_root+f'rxns_frequency_{name_suffix}.pdf', chunk_size=args.chunk_size, color=color1)
    # plot_distribution(rand_rxn_freq, fig_root+f'rxns_frequency_random_{name_suffix}.pdf', chunk_size=args.chunk_size, color=color1)
    # plot_compare_distribution(rand_rxn_freq, rxn_freq, fig_root+f'Compare_rxns_{name_suffix}.pdf', chunk_size=args.chunk_size)
if __name__ == '__main__':
    args = parse_args()
    print(args, flush=True)
    # Other entry points, currently disabled: statistics(args), visualize(args)
    visualize_frequency(args)