File size: 4,587 Bytes
66d2e5f
d5de529
 
 
 
 
 
 
 
 
66d2e5f
d5de529
 
 
 
 
 
 
 
66d2e5f
d5de529
 
 
66d2e5f
 
d5de529
66d2e5f
 
 
 
d5de529
66d2e5f
d5de529
 
 
 
66d2e5f
d5de529
 
 
 
 
66d2e5f
 
 
d5de529
 
 
 
66d2e5f
 
 
 
 
 
 
 
 
 
 
 
 
 
d5de529
66d2e5f
 
 
 
 
 
d5de529
66d2e5f
 
d5de529
66d2e5f
 
d5de529
66d2e5f
d5de529
 
66d2e5f
d5de529
 
 
66d2e5f
 
 
 
 
 
d5de529
 
 
 
 
66d2e5f
 
 
 
 
d5de529
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66d2e5f
 
d5de529
66d2e5f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import yaml
from argparse import Namespace
import json
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import Stage1_source.model as mod

# Step 1: Load JSON Configuration
def load_json_config(json_path):
    with open(json_path, "r") as f:
        config = json.load(f)
    return config

# Step 2: Convert JSON dictionary to Namespace
def convert_to_namespace(config_dict):
    for key, value in config_dict.items():
        if isinstance(value, dict):
            config_dict[key] = convert_to_namespace(value)
    return Namespace(**config_dict)

# Step 3: Load Pre-trained Model
def prepare_model(config_args, model_path) -> nn.Module:
    model = mod.Facilitator(
        in_dim=config_args.emb_dim,
        hid_dim=config_args.hid_dim,
        out_dim=config_args.emb_dim,
        dropout=config_args.dropout
    )
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    print("Model loaded successfully with weights!")
    return model

# Step 4: Compute MMD Loss
def compute_mmd_loss(x, y, kernel="rbf", sigma=1.0):
    def rbf_kernel(a, b, sigma):
        pairwise_distances = torch.cdist(a, b, p=2) ** 2
        return torch.exp(-pairwise_distances / (2 * sigma ** 2))

    K_xx = rbf_kernel(x, x, sigma)
    K_yy = rbf_kernel(y, y, sigma)
    K_xy = rbf_kernel(x, y, sigma)

    mmd_loss = K_xx.mean() - 2 * K_xy.mean() + K_yy.mean()
    return mmd_loss

# Step 5: Argument Parser Function
def parse_arguments():
    parser = argparse.ArgumentParser(description="BioM3 Facilitator Model (Stage 2)")
    parser.add_argument('--input_data_path', type=str, required=True,
                        help="Path to the input embeddings (e.g., PenCL_test_outputs.pt)")
    parser.add_argument('--output_data_path', type=str, required=True,
                        help="Path to save the output embeddings (e.g., Facilitator_test_outputs.pt)")
    parser.add_argument('--model_path', type=str, required=True,
                        help="Path to the Facilitator model weights (e.g., BioM3_Facilitator_epoch20.bin)")
    parser.add_argument('--json_path', type=str, required=True,
                        help="Path to the JSON configuration file (stage2_config.json)")
    return parser.parse_args()

# Main Execution
if __name__ == '__main__':
    # Parse arguments
    args = parse_arguments()

    # Load configuration
    config_dict = load_json_config(args.json_path)
    config_args = convert_to_namespace(config_dict)

    # Load model
    model = prepare_model(config_args=config_args, model_path=args.model_path)

    # Load input embeddings
    embedding_dataset = torch.load(args.input_data_path)

    # Run inference to get facilitated embeddings
    with torch.no_grad():
        z_t = embedding_dataset['z_t']
        z_p = embedding_dataset['z_p']
        z_c = model(z_t)
        embedding_dataset['z_c'] = z_c

    # Compute evaluation metrics
    # 1. MSE between embeddings
    mse_zc_zp = F.mse_loss(z_c, z_p)
    mse_zt_zp = F.mse_loss(z_t, z_p)

    # 2. Compute L2 norms for first batch
    batch_idx = 0
    norm_z_t = torch.norm(z_t[batch_idx], p=2).item()
    norm_z_p = torch.norm(z_p[batch_idx], p=2).item()
    norm_z_c = torch.norm(z_c[batch_idx], p=2).item()

    # 3. Compute MMD between embeddings
    mmd_zc_zp = model.compute_mmd(z_c, z_p)
    mmd_zp_zt = model.compute_mmd(z_p, z_t)

    # Print results
    print("\n=== Facilitator Model Output ===")
    print(f"Shape of z_t (Text Embeddings): {z_t.shape}")
    print(f"Shape of z_p (Protein Embeddings): {z_p.shape}")
    print(f"Shape of z_c (Facilitated Embeddings): {z_c.shape}\n")

    print("=== Norm (L2 Magnitude) Results for Batch Index 0 ===")
    print(f"Norm of z_t (Text Embedding): {norm_z_t:.6f}")
    print(f"Norm of z_p (Protein Embedding): {norm_z_p:.6f}")
    print(f"Norm of z_c (Facilitated Embedding): {norm_z_c:.6f}")

    print("\n=== Mean Squared Error (MSE) Results ===")
    print(f"MSE between Facilitated Embeddings (z_c) and Protein Embeddings (z_p): {mse_zc_zp:.6f}")
    print(f"MSE between Text Embeddings (z_t) and Protein Embeddings (z_p): {mse_zt_zp:.6f}")

    print("\n=== Max Mean Discrepancy (MMD) Results ===")
    print(f"MMD between Facilitated Embeddings (z_c) and Protein Embeddings (z_p): {mmd_zc_zp:.6f}")
    print(f"MMD between Text Embeddings (z_t) and Protein Embeddings (z_p): {mmd_zp_zt:.6f}")

    # Save output embeddings
    torch.save(embedding_dataset, args.output_data_path)
    print(f"\nFacilitator embeddings saved to {args.output_data_path}")