# BioM3 Stage 1 inference script: loads a pre-trained pfam_PEN_CL model,
# embeds a small test set, and reports text/protein joint-latent similarities.
import argparse
import yaml
from argparse import Namespace
import json
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import Stage1_source.preprocess as prep
import Stage1_source.model as mod
import Stage1_source.PL_wrapper as PL_wrap
# Step 1: Load JSON Configuration
def load_json_config(json_path):
    """Read a JSON configuration file and return its contents as a dict."""
    with open(json_path, "r") as handle:
        return json.load(handle)
# Step 2: Convert JSON dictionary to Namespace
def convert_to_namespace(config_dict):
    """Recursively convert a (possibly nested) dict into an argparse.Namespace.

    Fix: the original rebound nested values inside the caller's dict
    (mutating shared input); this version builds a fresh dict instead and
    leaves ``config_dict`` untouched.

    Args:
        config_dict: Mapping of configuration keys to values; nested dicts
            become nested Namespaces.

    Returns:
        argparse.Namespace mirroring the structure of ``config_dict``.
    """
    converted = {
        key: convert_to_namespace(value) if isinstance(value, dict) else value
        for key, value in config_dict.items()
    }
    return Namespace(**converted)
# Step 3: Load Pre-trained Model
def prepare_model(config_args, model_path) -> nn.Module:
    """Instantiate the Stage-1 model and load its pre-trained weights.

    The checkpoint is mapped onto CPU and the network is switched to
    evaluation mode before being returned.
    """
    net = mod.pfam_PEN_CL(args=config_args)
    state_dict = torch.load(model_path, map_location="cpu")
    net.load_state_dict(state_dict)
    net.eval()
    print("Model loaded successfully with weights!")
    return net
# Step 4: Prepare Test Dataset
def load_test_dataset(config_args):
    """Assemble a tiny two-sample dataset for smoke-testing inference.

    Builds an in-memory DataFrame of paired sequences/captions and wraps it
    in the project's TextSeqPairing_Dataset.
    """
    accessions = ['A0A009IHW8', 'A0A023I7E1']
    sequences = [
        "MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENARIQSKL...",
        "MRFQVIVAAATITMITSYIPGVASQSTSDGDDLFVPVSNFDPKSIFPEIKHP...",
    ]
    captions = [
        "PROTEIN NAME: 2' cyclic ADP-D-ribose synthase AbTIR...",
        "PROTEIN NAME: Glucan endo-1,3-beta-D-glucosidase 1...",
    ]
    labels = ["['PF13676']", "['PF17652','PF03639']"]
    test_df = pd.DataFrame({
        'primary_Accession': accessions,
        'protein_sequence': sequences,
        '[final]text_caption': captions,
        'pfam_label': labels,
    })
    return prep.TextSeqPairing_Dataset(args=config_args, df=test_df)
# Step 5: Argument Parser Function
def parse_arguments(argv=None):
    """Parse command-line arguments for the Stage-1 inference script.

    Generalized with an optional ``argv`` parameter (backward-compatible:
    ``None`` falls through to ``sys.argv[1:]``, the original behavior) so
    the parser can be exercised programmatically.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line.

    Returns:
        argparse.Namespace with ``json_path`` and ``model_path`` attributes.
    """
    parser = argparse.ArgumentParser(description="BioM3 Inference Script (Stage 1)")
    parser.add_argument('--json_path', type=str, required=True,
                        help="Path to the JSON configuration file (stage1_config.json)")
    parser.add_argument('--model_path', type=str, required=True,
                        help="Path to the pre-trained model weights (pytorch_model.bin)")
    return parser.parse_args(argv)
# Main Execution
def _compute_latents(model, test_dataset):
    """Run the model over every sample and stack the joint latents.

    Inference is done sample-by-sample under ``torch.no_grad()``.

    Returns:
        (z_t_tensor, z_p_tensor): text and protein joint latents, each of
        shape (num_samples, latent_dim).
    """
    z_t_list, z_p_list = [], []
    with torch.no_grad():
        for idx in range(len(test_dataset)):
            x_t, x_p = test_dataset[idx]
            outputs = model(x_t, x_p, compute_masked_logits=False)
            z_t_list.append(outputs['text_joint_latent'])
            z_p_list.append(outputs['seq_joint_latent'])
    return torch.vstack(z_t_list), torch.vstack(z_p_list)


def main():
    """Entry point: load config, model, and data; embed; report similarities."""
    config_args_parser = parse_arguments()
    config_dict = load_json_config(config_args_parser.json_path)
    config_args = convert_to_namespace(config_dict)
    model = prepare_model(config_args=config_args,
                          model_path=config_args_parser.model_path)
    test_dataset = load_test_dataset(config_args)

    z_t_tensor, z_p_tensor = _compute_latents(model, test_dataset)

    # Similarity matrix: rows index proteins, columns index texts.
    dot_product_scores = torch.matmul(z_p_tensor, z_t_tensor.T)
    # dim=0 normalizes over proteins (rows) for each text column;
    # dim=1 normalizes over texts (columns) for each protein row.
    protein_given_text_probs = F.softmax(dot_product_scores, dim=0)
    text_given_protein_probs = F.softmax(dot_product_scores, dim=1)
    # L2 norms of each latent vector, for sanity-checking embedding scale.
    z_p_magnitude = torch.norm(z_p_tensor, dim=1)
    z_t_magnitude = torch.norm(z_t_tensor, dim=1)

    print("\n=== Inference Results ===")
    print(f"Shape of z_p (protein latent): {z_p_tensor.shape}")
    print(f"Shape of z_t (text latent): {z_t_tensor.shape}")
    print(f"\nMagnitudes of z_p vectors: {z_p_magnitude}")
    print(f"Magnitudes of z_t vectors: {z_t_magnitude}")
    print("\n=== Dot Product Scores Matrix ===")
    print(dot_product_scores)
    print("\n=== Normalized Probabilities ===")
    print("Protein-Normalized Probabilities (Softmax across Proteins for each Text):")
    print(protein_given_text_probs)
    print("\nText-Normalized Probabilities (Softmax across Texts for each Protein):")
    print(text_given_protein_probs)


if __name__ == '__main__':
    main()
|