Niksa Praljak committed on
Commit
d5de529
·
1 Parent(s): c11fc4d

add Facilitator script

Browse files
Files changed (1) hide show
  1. run_Facilitator_sample.py +139 -0
run_Facilitator_sample.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ from argparse import Namespace
3
+ import json
4
+ import pandas as pd
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ import pytorch_lightning as pl
11
+ import Stage1_source.preprocess as prep
12
+ import Stage1_source.model as mod
13
+ import Stage1_source.PL_wrapper as PL_wrap
14
+
15
+
16
+ # Step 1: Load JSON configuration
17
def load_json_config(json_path):
    """Load a JSON configuration file and return its parsed contents.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        The object produced by ``json.load`` (a ``dict`` when the file
        holds a top-level JSON object).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON is UTF-8 by specification; be explicit so parsing does not depend
    # on the platform's default locale encoding.
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)
25
+
26
+ # Step 2: Convert JSON dictionary to Namespace
27
def convert_to_namespace(config_dict):
    """Recursively convert a dictionary to an argparse Namespace.

    Nested dictionaries become nested ``Namespace`` objects. Every other
    value (including lists, even lists that contain dictionaries) is kept
    unchanged.

    Args:
        config_dict: Mapping with string keys, e.g. a parsed JSON config.

    Returns:
        A ``Namespace`` whose attributes mirror the keys of ``config_dict``.
    """
    # Build a fresh mapping instead of overwriting entries of the caller's
    # dictionary, so the original config stays usable as plain dict data.
    converted = {
        key: convert_to_namespace(value) if isinstance(value, dict) else value
        for key, value in config_dict.items()
    }
    return Namespace(**converted)
35
+
36
def prepare_model(args, weights_path=None) -> nn.Module:
    """Build the Facilitator model, load its weights, and switch it to eval mode.

    Args:
        args: Flat config object providing ``emb_dim``, ``hid_dim`` and
            ``dropout`` attributes.
        weights_path: Optional path to the state-dict file. When omitted,
            ``BioM3_Facilitator_epoch20.bin`` inside the module-level
            ``save_dir`` is used.

    Returns:
        The ``mod.Facilitator`` instance with weights loaded, in eval mode.

    Raises:
        NameError: If ``weights_path`` is omitted and no module-level
            ``save_dir`` has been defined.
    """
    model = mod.Facilitator(
        in_dim=args.emb_dim,
        hid_dim=args.hid_dim,
        out_dim=args.emb_dim,  # output dimension is set equal to the input dimension
        dropout=args.dropout,
    )

    if weights_path is None:
        # Default location: resolved relative to the global ``save_dir``.
        weights_path = f"{save_dir}/BioM3_Facilitator_epoch20.bin"

    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    print("Model loaded successfully with weights!")
    return model
51
+
52
def compute_mmd_loss(x, y, kernel="rbf", sigma=1.0):
    """Compute the MMD loss between two sets of embeddings.

    The value is ``mean(K_xx) - 2 * mean(K_xy) + mean(K_yy)`` where each
    ``K`` is a Gaussian (RBF) kernel matrix. The means run over the full
    matrices, so the self-similarity (diagonal) entries are included.

    Args:
        x: Tensor of shape [N, D].
        y: Tensor of shape [M, D]; ``M`` may differ from ``N`` but ``D``
            must match.
        kernel: Kernel function name. Only ``'rbf'`` (Gaussian kernel) is
            implemented.
        sigma: Bandwidth for the Gaussian kernel; must be positive.

    Returns:
        A scalar tensor holding the MMD loss.

    Raises:
        ValueError: If ``kernel`` is not ``'rbf'`` or ``sigma`` is not
            positive.
    """
    # Previously any ``kernel`` value was silently treated as RBF; fail
    # loudly so a caller cannot believe a different kernel was applied.
    if kernel != "rbf":
        raise ValueError(
            f"Unsupported kernel {kernel!r}; only 'rbf' is implemented."
        )
    # sigma == 0 would divide by zero and yield NaN/degenerate kernel values.
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma!r}.")

    def rbf_kernel(a, b):
        """Return the Gaussian kernel matrix between the rows of a and b."""
        squared_distances = torch.cdist(a, b, p=2) ** 2
        return torch.exp(-squared_distances / (2 * sigma ** 2))

    k_xx = rbf_kernel(x, x)  # similarities within x
    k_yy = rbf_kernel(y, y)  # similarities within y
    k_xy = rbf_kernel(x, y)  # similarities between x and y

    return k_xx.mean() - 2 * k_xy.mean() + k_yy.mean()
76
+
77
+
78
if __name__ == '__main__':
    # Script entry point: load the config and Facilitator weights, map the
    # stored text embeddings through the model, report summary statistics,
    # and save the augmented embedding dictionary.

    # Imported locally: the file's top-level imports only bring in
    # ``Namespace`` from argparse.
    import argparse

    parser = argparse.ArgumentParser(
        description="Run the Facilitator model on stored test embeddings."
    )
    parser.add_argument(
        "--save_dir",
        default=".",
        help="Directory containing stage2_config.json and "
             "BioM3_Facilitator_epoch20.bin (default: current directory).",
    )
    cli_args = parser.parse_args()

    # ``save_dir`` was referenced but never defined, so the script failed
    # with NameError immediately. It must stay a module-level name because
    # prepare_model() reads it as a global.
    save_dir = cli_args.save_dir

    # Load and convert JSON config
    json_path = f"{save_dir}/stage2_config.json"
    config_dict = load_json_config(json_path)
    args = convert_to_namespace(config_dict)

    # load model
    model = prepare_model(args=args)

    # load test dataset (a mapping that provides 'z_t' and 'z_p' entries)
    embedding_dataset = torch.load('./PenCL_test_outputs.pt')

    # Run inference and compute metrics; gradients are not needed here.
    with torch.no_grad():
        z_t = embedding_dataset['z_t']
        z_p = embedding_dataset['z_p']
        z_c = model(z_t)
        embedding_dataset['z_c'] = z_c

        # Compute MSE between embeddings
        mse_zc_zp = F.mse_loss(z_c, z_p)  # facilitated vs. protein embeddings
        mse_zt_zp = F.mse_loss(z_t, z_p)  # text vs. protein embeddings

        # Compute norms (L2 magnitudes) of a single sample (index 0)
        batch_idx = 0
        norm_z_t = torch.norm(z_t[batch_idx], p=2).item()
        norm_z_p = torch.norm(z_p[batch_idx], p=2).item()
        norm_z_c = torch.norm(z_c[batch_idx], p=2).item()

        # Compute MMD between embeddings
        # NOTE(review): this relies on the Facilitator model exposing a
        # ``compute_mmd`` method; confirm against Stage1_source.model.
        MMD_zc_zp = model.compute_mmd(z_c, z_p)
        MMD_zp_zt = model.compute_mmd(z_p, z_t)

    # Print Results
    print("\n=== Facilitator Model Output ===")
    print(f"Shape of z_t (Text Embeddings): {z_t.shape}")
    print(f"Shape of z_p (Protein Embeddings): {z_p.shape}")
    print(f"Shape of z_c (Facilitated Embeddings): {z_c.shape}\n")

    print("=== Norm (L2 Magnitude) Results for Batch Index 0 ===")
    print(f"Norm of z_t (Text Embedding): {norm_z_t:.6f}")
    print(f"Norm of z_p (Protein Embedding): {norm_z_p:.6f}")
    print(f"Norm of z_c (Facilitated Embedding): {norm_z_c:.6f}")

    print("\n=== Mean Squared Error (MSE) Results ===")
    print(f"MSE between Facilitated Embeddings (z_c) and Protein Embeddings (z_p): {mse_zc_zp:.6f}")
    print(f"MSE between Text Embeddings (z_t) and Protein Embeddings (z_p): {mse_zt_zp:.6f}")

    print("\n=== Max Mean Discrepancy (MMD) Results ===")
    print(f"MMD between Facilitated Embeddings (z_c) and Protein Embeddings (z_p): {MMD_zc_zp:.6f}")
    print(f"MMD between Text Embeddings (z_t) and Protein Embeddings (z_p): {MMD_zp_zt:.6f}")

    print("\nFacilitator Model successfully computed facilitated embeddings!")

    # save output embeddings (input entries plus the new 'z_c' entry)
    torch.save(embedding_dataset, 'Facilitator_test_outputs.pt')
138
+
139
+