import argparse
import json
from argparse import Namespace

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

import Stage1_source.preprocess as prep
import Stage1_source.model as mod

# Step 1: Load JSON Configuration
def load_json_config(json_path):
    with open(json_path, "r") as f:
        config = json.load(f)
    return config

# Step 2: Convert JSON dictionary to Namespace
def convert_to_namespace(config_dict):
    for key, value in config_dict.items():
        if isinstance(value, dict):
            config_dict[key] = convert_to_namespace(value)
    return Namespace(**config_dict)
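
# Example (illustrative; "model" and "latent_dim" are placeholder keys):
#   cfg = convert_to_namespace({"model": {"latent_dim": 512}})
#   cfg.model.latent_dim  -> 512, via attribute access at every nesting level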

# Step 3: Load Pre-trained Model
def prepare_model(config_args, model_path) -> nn.Module:
    model = mod.pfam_PEN_CL(args=config_args)
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout/batch-norm updates for inference
    print("Model loaded successfully with weights!")
    return model
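
# Note: on PyTorch >= 1.13, torch.load also accepts weights_only=True, which
# restricts unpickling to plain tensors/state dicts. Optional hardening, not
# required for this script to run.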

# Step 4: Prepare Test Dataset
def load_test_dataset(config_args):
    test_dict = {
        'primary_Accession': ['A0A009IHW8', 'A0A023I7E1'],
        'protein_sequence': [
            "MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENARIQSKL...",
            "MRFQVIVAAATITMITSYIPGVASQSTSDGDDLFVPVSNFDPKSIFPEIKHP..."
        ],
        '[final]text_caption': [
            "PROTEIN NAME: 2' cyclic ADP-D-ribose synthase AbTIR...",
            "PROTEIN NAME: Glucan endo-1,3-beta-D-glucosidase 1..."
        ],
        'pfam_label': ["['PF13676']", "['PF17652','PF03639']"]
    }
    test_df = pd.DataFrame(test_dict)
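    # TextSeqPairing_Dataset is assumed to yield one (x_t, x_p) pair per row:
    # a tokenized caption tensor and a tokenized sequence tensor, matching how
    # the inference loop below unpacks each item.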
    test_dataset = prep.TextSeqPairing_Dataset(args=config_args, df=test_df)
    return test_dataset

# Step 5: Argument Parser Function
def parse_arguments():
    parser = argparse.ArgumentParser(description="BioM3 Inference Script (Stage 1)")
    parser.add_argument('--json_path', type=str, required=True,
                        help="Path to the JSON configuration file (stage1_config.json)")
    parser.add_argument('--model_path', type=str, required=True,
                        help="Path to the pre-trained model weights (pytorch_model.bin)")
    return parser.parse_args()
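
# Example invocation (the script filename is a placeholder):
#   python stage1_inference.py --json_path stage1_config.json \
#       --model_path pytorch_model.bin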

# Main Execution
if __name__ == '__main__':
    # Parse command-line arguments
    cli_args = parse_arguments()

    # Load configuration
    config_dict = load_json_config(cli_args.json_path)
    config_args = convert_to_namespace(config_dict)

    # Load model
    model = prepare_model(config_args=config_args, model_path=cli_args.model_path)

    # Load test dataset
    test_dataset = load_test_dataset(config_args)

    # Run inference and store z_t, z_p
    z_t_list = []
    z_p_list = []

    with torch.no_grad():
        for idx in range(len(test_dataset)):  # one sample at a time; no DataLoader needed for two examples
            x_t, x_p = test_dataset[idx]
            outputs = model(x_t, x_p, compute_masked_logits=False)  # joint embeddings only
            z_t = outputs['text_joint_latent']  # Text latent
            z_p = outputs['seq_joint_latent']   # Protein latent
            z_t_list.append(z_t)
            z_p_list.append(z_p)

    # Stack all latent vectors
    z_t_tensor = torch.vstack(z_t_list)  # Shape: (num_samples, latent_dim)
    z_p_tensor = torch.vstack(z_p_list)  # Shape: (num_samples, latent_dim)

    # Compute dot-product similarity scores
    dot_product_scores = torch.matmul(z_p_tensor, z_t_tensor.T)  # Shape: (num_proteins, num_texts)
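
    # Optional (illustrative): if the joint embeddings are not unit-normalized,
    # dot products mix vector magnitude with direction; cosine similarity
    # factors the magnitude out. A supplementary check, not part of the
    # original pipeline.
    cosine_scores = torch.matmul(
        F.normalize(z_p_tensor, dim=1), F.normalize(z_t_tensor, dim=1).T
    )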

    # Normalize scores into probabilities
    protein_given_text_probs = F.softmax(dot_product_scores, dim=0)  # P(protein | text): softmax over proteins, each column sums to 1
    text_given_protein_probs = F.softmax(dot_product_scores, dim=1)  # P(text | protein): softmax over texts, each row sums to 1
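
    # Illustrative retrieval check (assumes the i-th caption describes the
    # i-th protein): matching pairs should score highest, so the argmax along
    # each axis recovers the index of the best match.
    best_protein_per_text = dot_product_scores.argmax(dim=0)  # one index per caption
    best_text_per_protein = dot_product_scores.argmax(dim=1)  # one index per protein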

    # Compute magnitudes (L2 norms) for z_t and z_p
    z_p_magnitude = torch.norm(z_p_tensor, dim=1)  # L2 norm for each protein latent vector
    z_t_magnitude = torch.norm(z_t_tensor, dim=1)  # L2 norm for each text latent vector

    # Print results
    print("\n=== Inference Results ===")
    print(f"Shape of z_p (protein latent): {z_p_tensor.shape}")
    print(f"Shape of z_t (text latent): {z_t_tensor.shape}")
    print(f"\nMagnitudes of z_p vectors: {z_p_magnitude}")
    print(f"Magnitudes of z_t vectors: {z_t_magnitude}")

    print("\n=== Dot Product Scores Matrix ===")
    print(dot_product_scores)

    print("\n=== Normalized Probabilities ===")
    print("Protein-Normalized Probabilities (Softmax across Proteins for each Text):")
    print(protein_given_text_probs)

    print("\nText-Normalized Probabilities (Softmax across Texts for each Protein):")
    print(text_given_protein_probs)