File size: 7,673 Bytes
95f97c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from data_provider.context_gen import *

def parse_args():
    """Define this script's command-line options and parse them from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``name``, ``seed``, ``epochs``,
        ``chunk_size``, ``rxn_num``, ``k`` and ``root``.
    """
    parser = argparse.ArgumentParser(description="A simple argument parser")

    # Script arguments, registered in order as (flag, default, type).
    script_options = (
        ('--name', 'none', str),
        ('--seed', 0, int),
        ('--epochs', 100, int),
        ('--chunk_size', 100, int),
        ('--rxn_num', 50000, int),
        ('--k', 4, int),
        ('--root', 'data/pretrain_data', str),
    )
    for flag, default_value, value_type in script_options:
        parser.add_argument(flag, default=default_value, type=value_type)

    return parser.parse_args()

def pad_shorter_array(arr1, arr2):
    """Zero-pad whichever array is shorter so both share the same first-axis length.

    Padding is appended at the end via ``np.pad(..., 'constant')``. An array that
    is already long enough is returned unchanged (the same object).

    Args:
        arr1, arr2: numpy arrays (anything exposing ``.shape``).

    Returns:
        Tuple ``(arr1, arr2)`` after padding.
    """
    target_len = max(arr1.shape[0], arr2.shape[0])

    def _extend(arr):
        # Only the shorter array has a positive deficit; the other passes through.
        deficit = target_len - arr.shape[0]
        return np.pad(arr, (0, deficit), 'constant') if deficit else arr

    return _extend(arr1), _extend(arr2)

def plot_distribution(values, target_path, x_lim=None, y_lim=None, chunk_size=100, color='blue'):
    num_full_chunks = len(values) // chunk_size
    values = np.mean(values[:num_full_chunks*chunk_size].reshape(-1, chunk_size), axis=1)
    values = np.sort(values)[::-1]
    plt.figure(figsize=(10, 4), dpi=100)
    x = np.arange(len(values))
    plt.bar(x, values, color=color)
    current_values = np.array([0, 200000, 400000, 600000, 800000, 1000000], dtype=int)
    plt.xticks((current_values/chunk_size).astype(int), current_values)
    plt.ylabel('Molecule Frequency', fontsize=20)
    if x_lim:
        plt.xlim(*x_lim)
    if y_lim:
        plt.ylim(*y_lim)
    plt.tick_params(axis='both', which='major', labelsize=12)
    plt.tight_layout(pad=0.5)
    plt.savefig(target_path)
    print(f'Figure saved to {target_path}')
    plt.clf()

def plot_compare_distribution(list1, list2, target_path, x_lim=None, y_lim=None, labels=('Random', 'Ours'), colors=('blue', 'orange'), chunk_size=100):
    """Overlay two chunk-averaged, descending-sorted distributions in one bar chart.

    The shorter series is zero-padded to the longer one's length. Both are then
    averaged over consecutive ``chunk_size`` groups (a trailing partial group is
    dropped), and each series' group means are sorted from largest to smallest.

    Args:
        list1: first series; drawn first and labelled ``labels[0]``.
        list2: second series; drawn on top and labelled ``labels[1]``.
        target_path: output file path; its directory must already exist.
        x_lim: optional ``(low, high)`` passed to ``plt.xlim``.
        y_lim: optional ``(low, high)`` passed to ``plt.ylim``.
        labels: two legend labels, in the order of ``list1``, ``list2``.
        colors: two bar colours, in the order of ``list1``, ``list2``.
        chunk_size: number of elements averaged into one bar.
    """
    # Pad first, then derive the chunk count from the common length. Computing it
    # from the unpadded list1 would silently drop the tail of a longer list2.
    list1, list2 = pad_shorter_array(np.asarray(list1), np.asarray(list2))
    num_full_chunks = len(list1) // chunk_size
    values1, values2 = [
        np.sort(np.mean(values[:num_full_chunks * chunk_size].reshape(-1, chunk_size), axis=1))[::-1]
        for values in (list1, list2)
    ]

    fig = plt.figure(figsize=(10, 6), dpi=100)
    try:
        x = np.arange(len(values1))
        plt.bar(x, values1, color=colors[0], label=labels[0], alpha=0.6)
        plt.bar(x, values2, color=colors[1], label=labels[1], alpha=0.5)
        # Ticks are expressed in raw element counts, mapped onto chunk indices.
        current_values = np.array([0, 200000, 400000, 600000, 800000, 1000000], dtype=int)
        plt.xticks((current_values / chunk_size).astype(int), current_values)
        plt.ylabel('Molecule Frequency', fontsize=20)
        if x_lim is not None:
            plt.xlim(*x_lim)
        if y_lim is not None:
            plt.ylim(*y_lim)
        plt.tick_params(axis='both', which='major', labelsize=18)
        plt.tight_layout(pad=0.5)
        plt.legend(fontsize=24, loc='upper right')
        plt.savefig(target_path)
        print(f'Figure saved to {target_path}')
    finally:
        # Release the figure instead of only clearing it.
        plt.close(fig)

def statistics(args):
    if args.seed:
        set_random_seed(args.seed)
    # 1141864 rxns from ord
    # 1120773 rxns from uspto
    cluster = Reaction_Cluster(args.root)

    rxn_num = len(cluster.reaction_data)
    abstract_num = 0
    property_num = 0
    calculated_property_num = 0
    experimental_property_num = 0
    avg_calculated_property_len = 0
    avg_experimental_property_len = 0
    mol_set = set()
    for rxn_dict in cluster.reaction_data:
        for key in ['REACTANT', 'CATALYST', 'SOLVENT', 'PRODUCT']:
            for mol in rxn_dict[key]:
                mol_set.add(mol)
    mol_num = len(mol_set)

    for mol_dict in cluster.property_data:
        if 'abstract' in mol_dict:
            abstract_num += 1
        if 'property' in mol_dict:
            property_num += 1
            if 'Experimental Properties' in mol_dict['property']:
                experimental_property_num += 1
                avg_experimental_property_len += len(mol_dict['property']['Experimental Properties'])
            if 'Computed Properties' in mol_dict['property']:
                calculated_property_num += 1
                avg_calculated_property_len += len(mol_dict['property']['Computed Properties'])
            
    print(f'Reaction Number: {rxn_num}')
    print(f'Molecule Number: {mol_num}')
    print(f'Abstract Number: {abstract_num}/{mol_num}({abstract_num/mol_num*100:.2f}%)')
    print(f'Property Number: {property_num}/{mol_num}({property_num/mol_num*100:.2f}%)')
    print(f'- Experimental Properties Number: {experimental_property_num}/{property_num}({experimental_property_num/property_num*100:.2f}%), {avg_experimental_property_len/mol_num:.2f} items per molecule')
    print(f'- Computed Properties: {calculated_property_num}/{property_num}({calculated_property_num/property_num*100:.2f}%), {avg_calculated_property_len/mol_num:.2f} items per molecule')

def visualize(args):
    """Plot molecule/reaction distributions from ``Reaction_Cluster`` against a random baseline.

    PDFs are written under ``results/<args.name>/``, and the directory is created
    if missing. A ``seed`` of 0 (the CLI default) skips ``set_random_seed``.
    ``args.chunk_size`` controls how many elements are averaged into each bar.
    """
    if args.seed:
        set_random_seed(args.seed)
    cluster = Reaction_Cluster(args.root)
    prob_values, rxn_weights = cluster.visualize_mol_distribution()
    rand_prob_values, rand_rxn_weights = cluster._randomly(
        cluster.visualize_mol_distribution
    )
    fig_root = f'results/{args.name}/'
    # savefig does not create missing directories.
    os.makedirs(fig_root, exist_ok=True)
    chunk_size = args.chunk_size

    plot_distribution(prob_values, fig_root+'mol_distribution.pdf', chunk_size=chunk_size)
    plot_distribution(rxn_weights, fig_root+'rxns_distribution.pdf', chunk_size=chunk_size)
    plot_distribution(rand_prob_values, fig_root+'mol_distribution_random.pdf', chunk_size=chunk_size)
    plot_distribution(rand_rxn_weights, fig_root+'rxns_distribution_random.pdf', chunk_size=chunk_size)

    # plot_compare_distribution labels its first series 'Random' and its second
    # 'Ours' by default, so the random baseline must come first.
    plot_compare_distribution(rand_prob_values, prob_values, fig_root+'Compare_mol.pdf', y_lim=(-0.5, 15.5), chunk_size=chunk_size)
    plot_compare_distribution(rand_rxn_weights, rxn_weights, fig_root+'Compare_rxns.pdf', chunk_size=chunk_size)


def visualize_frequency(args):
    """Plot molecule sampling frequencies, with and without adjustment, and cache the raw counts.

    Frequencies from ``Reaction_Cluster.visualize_mol_frequency`` and from its
    ``_randomly`` wrapper are cached in ``results/<args.name>/freq_<suffix>.npy``.
    The suffix encodes epochs, rxn_num and k. If the cache exists it is loaded
    instead of recomputing. The output directory is created if missing. A ``seed``
    of 0 (the CLI default) skips ``set_random_seed``.
    """
    if args.seed:
        set_random_seed(args.seed)
    fig_root = f'results/{args.name}/'
    # np.save and savefig both fail if the directory does not exist yet.
    os.makedirs(fig_root, exist_ok=True)
    name_suffix = f'E{args.epochs}_Rxn{args.rxn_num}_K{args.k}'
    cache_path = os.path.join(fig_root, f'freq_{name_suffix}.npy')
    if os.path.exists(cache_path):
        # NOTE: allow_pickle=True unpickles the file; only load caches this script wrote.
        mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq = np.load(cache_path, allow_pickle=True)
    else:
        cluster = Reaction_Cluster(args.root)
        mol_freq, rxn_freq = cluster.visualize_mol_frequency(rxn_num=args.rxn_num, k=args.k, epochs=args.epochs)
        rand_mol_freq, rand_rxn_freq = cluster._randomly(
            cluster.visualize_mol_frequency,
            rxn_num=args.rxn_num, k=args.k, epochs=args.epochs
        )
        np.save(cache_path, np.array([mol_freq, rxn_freq, rand_mol_freq, rand_rxn_freq], dtype=object), allow_pickle=True)

    color1 = '#FA7F6F'
    color2 = '#80AFBF'
    # Shared axis limits; x limits are expressed in chunk indices.
    x_lim = (-50000 // args.chunk_size, 1200000 // args.chunk_size)
    y_lim = (-2, 62)
    plot_distribution(mol_freq, fig_root+f'mol_frequency_{name_suffix}.pdf', x_lim=x_lim, y_lim=y_lim, chunk_size=args.chunk_size, color=color2)
    # plot_distribution(rxn_freq, fig_root+f'rxns_frequency_{name_suffix}.pdf', chunk_size=args.chunk_size, color=color1)
    plot_distribution(rand_mol_freq, fig_root+f'mol_frequency_random_{name_suffix}.pdf', x_lim=x_lim, y_lim=y_lim, chunk_size=args.chunk_size, color=color2)
    # plot_distribution(rand_rxn_freq, fig_root+f'rxns_frequency_random_{name_suffix}.pdf', chunk_size=args.chunk_size, color=color1)

    # Random frequencies first: they are labelled 'Before Adjustment'.
    plot_compare_distribution(rand_mol_freq, mol_freq, fig_root+f'Compare_mol_{name_suffix}.pdf', y_lim=y_lim, labels=['Before Adjustment', 'After Adjustment'], colors=[color1, color2], chunk_size=args.chunk_size)
    # plot_compare_distribution(rand_rxn_freq, rxn_freq, fig_root+f'Compare_rxns_{name_suffix}.pdf', chunk_size=args.chunk_size)

if __name__ == '__main__':
    # Script entry point: parse and echo the CLI options, then run one analysis.
    # Point `entry_point` at `statistics` or `visualize` to produce the other reports.
    args = parse_args()
    print(args, flush=True)
    entry_point = visualize_frequency
    entry_point(args)